In [None]:
"""
Runs forward passes and stores hidden states for later use
"""
# Pipeline: load Qwen2.5-7B-Instruct -> sample C4 text (en/zh/es) -> run forward passes -> save the retained
# hidden-state layers plus token-level metadata under data/.
None # Bare final expression so Jupyter does not echo the string above as this cell's output

In [None]:
"""
Imports
"""
import importlib
import itertools
import os
import pickle

import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from termcolor import colored
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.loss.loss_utils import ForCausalLMLoss # Cross-entropy loss that handles label shifting

from utils.memory import check_memory, clear_all_cuda_memory
from utils.store_topk import convert_topk_to_df
from utils.store_outputs import convert_outputs_to_df
from utils import pretrained_models

main_device = 'cuda:0' # Device the model and every input batch are placed on
seed = 1234

# Apply the seed to the global RNGs so any stochastic step in this notebook is reproducible
torch.manual_seed(seed)
np.random.seed(seed)

clear_all_cuda_memory()
check_memory()

## Load base model

In [None]:
"""
Load the base tokenizer/model
"""

model_id = 'Qwen/Qwen2.5-7B-Instruct'
model_prefix = 'qwen2.5-7b' # Filename prefix for the exported data files
# Left padding puts each row's last real token at the final position of the sequence
tokenizer = AutoTokenizer.from_pretrained(model_id, add_eos_token = False, add_bos_token = False, padding_side = 'left', trust_remote_code = True)
# bf16 weights, placed on `main_device` so the model lives on the same device the input batches are sent to
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype = torch.bfloat16, trust_remote_code = True).to(main_device).eval()

In [None]:
"""
Run forward pass
"""
@torch.no_grad()
def test_forward_pass(model):
    """
    Smoke-test the model: run one padded forward pass on two short prompts and print the shape of the first
    hidden state flattened to (B*N, D).

    Params:
        @model: A HF causal LM; the tokenized inputs are moved to `model.device`.
    """
    # no_grad: inference only, so don't build an autograd graph through the whole model
    inputs = tokenizer(['Hi! I am a dog and I like to bark', 'Vegetables are good for'], return_tensors = 'pt', padding = 'max_length', truncation = True, max_length = 512).to(model.device)
    output = model(**inputs, output_hidden_states = True)
    # Only the first hidden state is inspected, so don't copy every layer to CPU
    first_hidden = output['hidden_states'][0]
    print(f'Hidden state shape: {first_hidden.reshape(-1, first_hidden.shape[2]).shape}')

test_forward_pass(model)

## Get dataset

In [None]:
"""
Load dataset - C4 mix (en/zh/es)
"""
def load_raw_ds(
    lang_counts: tuple[tuple[str, int], ...] = (('en', 1200), ('zh', 400), ('es', 400)),
    shuffle_seed: int = 123,
    buffer_size: int = 100_000,
) -> list[str]:
    """
    Stream the C4 validation split for each language, shuffle it, and return the raw text of the first n samples per language.

    Params:
        @lang_counts: (C4 config name, number of samples) pairs. Texts are returned grouped in this order.
        @shuffle_seed: Seed for the streaming shuffle buffer.
        @buffer_size: Size of the streaming shuffle buffer.

    Returns:
        A flat list of text strings. A language contributes fewer than its requested count if its stream runs out.
    """
    raw_data = []
    for lang, n_samples in lang_counts:
        ds = load_dataset('allenai/c4', lang, split = 'validation', streaming = True).shuffle(seed = shuffle_seed, buffer_size = buffer_size)
        # islice stops early (without raising) if the stream has fewer than n_samples rows
        raw_data.extend(sample['text'] for sample in itertools.islice(ds, n_samples))

    return raw_data


raw_data = load_raw_ds()

In [None]:
""" 
Load dataset into a dataloader. The dataloader returns the original tokens - this is important for BPE tokenizers as otherwise it's difficult to reconstruct the correct string later!
"""
from torch.utils.data import Dataset, DataLoader

class ReconstructableTextDataset(Dataset):
    """
    A map-style dataset of fixed-length tokenized texts that also keeps, for every token position, the substring of the
    original text that produced that token.
    """

    def __init__(self, text_dataset, tokenizer, max_length):
        """
        Creates a dataset object that also returns a B x N list of the original tokens in the same position as the input ids.

        Params:
            @text_dataset: A list of B samples of text dataset.
            @tokenizer: A HF tokenizer object. It must support `return_offsets_mapping = True` (i.e. a "fast" tokenizer).
            @max_length: N. Every sample is truncated or padded to exactly this many tokens.
        """
        tokenized = tokenizer(text_dataset, add_special_tokens = False, max_length = max_length, padding = 'max_length', truncation = True, return_offsets_mapping = True, return_tensors = 'pt')

        self.input_ids = tokenized['input_ids']
        self.attention_mask = tokenized['attention_mask']
        self.offset_mapping = tokenized['offset_mapping']
        self.original_tokens = self.get_original_tokens(text_dataset)

    def get_original_tokens(self, text_dataset):
        """
        Return the original substring associated with each B x N position, using `self.offset_mapping`.
        This is important for reconstructing the original text when BPE tokenizers are used.

        Params:
            @text_dataset: The list of B texts that produced `self.input_ids` / `self.offset_mapping`.

        Returns:
            A list of length B, each a list of length N, containing the substring of the original text that produced the
            token ID at the same position of `self.input_ids`. Positions with offset [0, 0] (padding) map to "".

        NOTE(review): byte-level BPE tokenizers may split one character (e.g. some CJK characters) into several tokens that
        share the same offset span; each of those tokens then gets the full character, so naively joining the substrings
        can duplicate characters. Confirm whether downstream code joins them.
        """
        # Convert once to nested Python lists; indexing the tensor element-by-element is very slow for B x N positions
        offsets = self.offset_mapping.tolist()
        all_token_substrings = []
        for i in range(self.input_ids.shape[0]):
            text = text_dataset[i]
            token_substrings = []
            for start_char, end_char in offsets[i]:
                if start_char == 0 and end_char == 0: # Pads have offset [0, 0]; store an empty string for those positions
                    token_substrings.append("")
                else:
                    token_substrings.append(text[start_char:end_char])

            all_token_substrings.append(token_substrings)

        return all_token_substrings

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {'input_ids': self.input_ids[idx], 'attention_mask': self.attention_mask[idx], 'original_tokens': self.original_tokens[idx]}
    
def collate_fn(batch):
    """
    Collate a list of dataset items into one batch dict.

    `input_ids` and `attention_mask` are stacked along a new leading batch dimension. `original_tokens` is kept as a plain
    list of per-sample token lists (PyTorch's default collate would regroup nested lists position-wise instead).
    """
    collated = {}
    for key in ('input_ids', 'attention_mask'):
        collated[key] = torch.stack([item[key] for item in batch], dim = 0)
    collated['original_tokens'] = [item['original_tokens'] for item in batch]
    return collated

# Batches of 4 samples, each tokenized to exactly 1024 positions, kept in the original data order
test_dl = DataLoader(
    dataset = ReconstructableTextDataset(raw_data, tokenizer, max_length = 1024),
    collate_fn = collate_fn,
    shuffle = False,
    batch_size = 4,
)

## Run forward passes + export hidden states

In [None]:
""" 
Run forward passes + export data
"""

def _print_sanity_check(output, input_ids: torch.Tensor, attention_mask: torch.Tensor, vocab_size: int) -> None:
    """
    Print the greedy next token for up to two samples plus the batch perplexity, as a quick check that outputs look sane.
    """
    # Mask targets by the attention mask (not pad_token_id). Also drop each row's first real token as a target: it has no
    # real context, and with left padding its prediction would come from a pad position.
    is_first_real = (attention_mask == 1) & (attention_mask.cumsum(dim = 1) == 1)
    labels = input_ids.masked_fill((attention_mask == 0) | is_first_real, -100)
    loss = ForCausalLMLoss(output['logits'], labels, vocab_size).item()

    for i in range(min(2, input_ids.size(0))):
        valid = attention_mask[i].bool()
        if not valid.any():
            continue
        # Select real tokens by mask so this is correct for both left and right padding
        decoded_input = tokenizer.decode(input_ids[i][valid], skip_special_tokens = True)
        last_pos = valid.nonzero()[-1].item() # Logits at the last real position predict the next token
        next_token_id = torch.argmax(output['logits'][i, last_pos, :]).item()
        print(decoded_input + colored(tokenizer.decode([next_token_id], skip_special_tokens = True), 'green'))
    print("PPL:", torch.exp(torch.tensor(loss)).item())


def _valid_layer_hidden_states(hidden_states, attention_mask: torch.Tensor, layers_to_keep: list[int]) -> torch.Tensor:
    """
    Gather the requested entries of `hidden_states` at non-padding positions and move them to CPU.

    Returns:
        A (n_valid_positions, len(layers_to_keep), D) CPU tensor. Rows are in flattened (sequence, token) order.
    """
    valid_mask = attention_mask.reshape(-1) == 1
    kept = []
    for layer_ix in layers_to_keep:
        h = hidden_states[layer_ix]
        # Filter layers and positions on-device so only the retained data is copied to CPU
        kept.append(h.reshape(-1, h.shape[-1])[valid_mask])
    return torch.stack(kept, dim = 1).cpu()


@torch.no_grad()
def run_and_export_activations(model, dl: DataLoader, layers_to_keep: list[int], max_batches: None | int = None):
    """
    Run forward passes on the given model and collect token-level metadata plus the hidden states of selected layers.

    Params:
        @model: A HF causal LM, called as `model(input_ids, attention_mask = ..., output_hidden_states = True)`. Its output
          must contain `logits` and `hidden_states`.
        @dl: A DataLoader (e.g. over `ReconstructableTextDataset` with `collate_fn`) yielding dicts with `input_ids`,
          `attention_mask`, and `original_tokens`.
        @layers_to_keep: Indices into `output['hidden_states']` to retain. (For HF models index 0 is presumably the
          embedding output and index i the output of decoder layer i - TODO confirm for this model.)
        @max_batches: The max number of batches to run; None runs the whole dataloader.

    Returns:
        A dict with keys:
        - `sample_df`: A sample (token)-level dataframe built by `convert_outputs_to_df`, merged with the original token
            text and tagged with `batch_ix`.
        - `all_post_layer_hidden_states`: A tensor of size n_valid_positions x len(layers_to_keep) x D with the hidden
            state of each retained layer at every non-padding position. Rows are intended to line up with the rows of
            `sample_df` (assuming `convert_outputs_to_df` drops masked tokens in the same order - verify).

    Raises:
        ValueError: If `layers_to_keep` is empty or no batches were processed.
    """
    if not layers_to_keep:
        raise ValueError('layers_to_keep must contain at least one layer index')

    all_post_layer_hidden_states = []
    sample_dfs = []
    n_batches = len(dl) if max_batches is None else min(len(dl), max_batches)

    # islice(dl, None) iterates the whole dataloader; otherwise it stops before fetching batch `max_batches`
    for batch_ix, batch in tqdm(enumerate(itertools.islice(dl, max_batches)), total = n_batches):

        input_ids = batch['input_ids'].to(main_device)
        attention_mask = batch['attention_mask'].to(main_device)
        original_tokens = batch['original_tokens']

        output = model(input_ids, attention_mask = attention_mask, output_hidden_states = True)

        # Check no bugs by validating output/perplexity on the first batch
        if batch_ix == 0:
            _print_sanity_check(output, input_ids, attention_mask, model.config.vocab_size)

        original_tokens_df = pd.DataFrame(
            [(seq_i, tok_i, tok) for seq_i, tokens in enumerate(original_tokens) for tok_i, tok in enumerate(tokens)],
            columns = ['sequence_ix', 'token_ix', 'token']
        )

        # Create sample (token) level dataframe
        sample_df = (
            convert_outputs_to_df(input_ids, attention_mask, output['logits'])
            .merge(original_tokens_df, how = 'left', on = ['token_ix', 'sequence_ix'])
            .assign(batch_ix = batch_ix)
        )
        sample_dfs.append(sample_df)

        all_post_layer_hidden_states.append(_valid_layer_hidden_states(output['hidden_states'], attention_mask, layers_to_keep))

        # Release this batch's logits/hidden states before the next forward pass allocates new ones
        del output

    if not sample_dfs:
        raise ValueError('No batches were processed (empty dataloader or max_batches = 0)')

    return {'sample_df': pd.concat(sample_dfs), 'all_post_layer_hidden_states': torch.cat(all_post_layer_hidden_states, dim = 0)}

res = run_and_export_activations(model, test_dl, layers_to_keep = list(range(0, 8)), max_batches = None)

In [None]:
# Note: layers_to_keep = range(0, 8) above, so only the first 8 entries of `hidden_states` are exported, to save space
# (for HF models entry 0 is presumably the embedding output - TODO confirm)
os.makedirs('data', exist_ok = True) # torch.save / open() raise if the output directory does not exist
torch.save(res['all_post_layer_hidden_states'], f'data/{model_prefix}-all-post-layer-hidden-states.pt')

with open(f'data/{model_prefix}-metadata.pkl', 'wb') as f:
    pickle.dump({'sample_df': res['sample_df']}, f)

In [None]:
# Display the loaded model's config for reference (e.g. the layer count when choosing `layers_to_keep`)
model.config