In [None]:
"""
Runs forward passes on samples and stores: (1) top-k expert selections; (2) sample metadata.
"""
None

In [None]:
"""
Imports
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.loss.loss_utils import ForCausalLMLoss # Cross-entropy loss that handles label shifting
from datasets import load_dataset
import pandas as pd
import numpy as np
from tqdm import tqdm
from termcolor import colored
import importlib
import os

from utils.memory import check_memory, clear_all_cuda_memory
from utils.store_topk import convert_topk_to_df
from utils.store_outputs import convert_outputs_to_df
from utils import pretrained_models

# Device that each batch's input_ids / attention_mask are moved to before a forward pass.
main_device = 'cuda:0'
# Seed handed to the streaming-dataset shuffles so the same documents are sampled on every run.
seed = 1234

# Helpers from utils.memory - presumably release cached CUDA memory and report current usage
# before the model is loaded; TODO confirm against utils/memory.py.
clear_all_cuda_memory()
check_memory()

## Load base model

In [None]:
"""
Load the base tokenizer/model

Architectures supported currently:
- OLMoE architecture, includes OLMoE-1B-7B-0125-Instruct (1B/7B)
- Qwen2MoE architecture, includes Qwen1.5-MoE-A2.7B-Chat (2.7B/14.3B), Qwen2-57B-A14B (14B/57B)
- Deepseek v2 architecture, includes Deepseek-v2-Lite (2.4B/15.7B), Deepseek-v2 (21B/236B)
- Deepseek v3 architecture, includes Deepseek-v3 (37B/671B), Deepseek-R1 (37B/671B), Moonlight-16B-A3B (3B/16B)
- Qwen3MoE architecture, includes Qwen3-30B-A3B, Qwen3-235B-A22B
"""
# Position in the table inside `get_model`; change this to switch checkpoints.
selected_model_index = 4

def get_model(index):
    """
    Look up one of the supported checkpoints by its position in the table below.

    Params:
        @index: Position of the desired entry in the table.

    Returns:
        A tuple of (HF model id, model prefix used for output paths, architecture key used to pick the forward-pass module).
    """
    supported_models = (
        ('allenai/OLMoE-1B-7B-0125-Instruct', 'olmoe', 'olmoe'),
        ('Qwen/Qwen1.5-MoE-A2.7B-Chat', 'qwen1.5moe', 'qwen2moe'),
        ('deepseek-ai/DeepSeek-V2-Lite', 'dsv2', 'dsv2'),
        ('moonshotai/Moonlight-16B-A3B', 'moonlight', 'dsv3'),
        ('Qwen/Qwen3-30B-A3B', 'qwen3moe', 'qwen3moe'),
    )
    hf_id, prefix, architecture = supported_models[index]

    return hf_id, prefix, architecture

# Resolve the selected checkpoint, then load its tokenizer and weights.
model_id, model_prefix, model_architecture = get_model(selected_model_index)

# The tokenizer is configured to add no BOS/EOS tokens and to pad on the left.
# NOTE(review): trust_remote_code = True executes Python shipped inside the checkpoint repo - only use with model ids you trust.
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    add_eos_token = False,
    add_bos_token = False,
    padding_side = 'left',
    trust_remote_code = True
)

# Place the model on `main_device` (the same device the batches are moved to later) rather than the implicit current CUDA device,
# so changing `main_device` cannot leave the inputs and the model on different GPUs.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype = torch.bfloat16, trust_remote_code = True).to(main_device).eval()

In [None]:
"""
Load reverse-engineered forward pass functions that return topk expert IDs and weights
"""
# The architecture key selects a module under utils.pretrained_models; that module is expected to expose
# a function named `run_<architecture>_return_topk`, which is bound here under an architecture-independent name.
model_module = importlib.import_module(f"utils.pretrained_models.{model_architecture}")
run_model_return_topk = getattr(model_module, f"run_{model_architecture}_return_topk")

@torch.no_grad()
def test_custom_forward_pass(model, pad_token_id):
    """
    Sanity-check the custom forward pass against the model's own forward pass on two short prompts.

    Runs without gradient tracking: both passes are inference-only, so building an autograd graph would only waste memory.
    Reads the module-level `tokenizer` and `run_model_return_topk`.

    Params:
        @model: The loaded causal LM; inputs are moved to `model.device`.
        @pad_token_id: Token id treated as padding; those positions are excluded from the loss (label -100).

    Raises:
        AssertionError if the custom logits are not exactly equal to the model's logits, or if the
        returned top-k expert and weight lists have different lengths.
    """
    inputs = tokenizer(['Hi! I am a dog and I like to bark', 'Vegetables are good for'], return_tensors = 'pt', padding = 'max_length', truncation = True, max_length = 512).to(model.device)
    original_results = model(**inputs)
    custom_results = run_model_return_topk(model, inputs['input_ids'], inputs['attention_mask'])

    # Exact (not approximate) equality: the custom pass is meant to reproduce the original computation
    assert torch.equal(original_results.logits, custom_results['logits']), 'Error in custom forward'
    assert len(custom_results['all_topk_experts']) == len(custom_results['all_topk_weights']), 'Length of topk IDs and weights not equal'

    print(f"Length of topk: {len(custom_results['all_topk_experts'])}")
    print(f"Topk size: {custom_results['all_topk_experts'][0].shape}")
    # NOTE(review): this prints row 1 of the first list entry, while the label says "First token" - confirm which row is intended.
    print(f"First token topk IDs: {custom_results['all_topk_experts'][0][1,]}")
    print(f"First token topk weights: {custom_results['all_topk_weights'][0][1,]}")

    # Pad positions get label -100 so the loss ignores them
    labels = torch.where(inputs['input_ids'] == pad_token_id, torch.tensor(-100), inputs['input_ids'])
    loss = ForCausalLMLoss(custom_results['logits'], labels, model.config.vocab_size).detach().cpu().item()
    print(f"LM loss: {loss}")

test_custom_forward_pass(model, tokenizer.pad_token_id)

## Get dataset

In [None]:
"""
Load dataset (C4 validation split: English, Chinese, and Spanish subsets)
"""
def load_raw_ds():
    """
    Stream the shuffled C4 validation split for English, Chinese and Spanish and take a fixed number of documents from each.

    Returns:
        A list of dicts with keys `text` and `source`; all `en` samples come first, then `zh`, then `es`.
    """
    def open_stream(language):
        # Streaming + buffered shuffle, seeded with the notebook-level `seed`
        return load_dataset('allenai/c4', language, split = 'validation', streaming = True).shuffle(seed = seed, buffer_size = 100_000)

    def get_data(ds, n_samples, data_source):
        # zip() stops as soon as either the counter or the stream runs out, so a short stream simply yields fewer samples
        return [{'text': sample['text'], 'source': data_source} for _, sample in zip(range(n_samples), ds)]

    ds_en = open_stream('en')
    ds_zh = open_stream('zh')
    ds_es = open_stream('es')
    # ds = load_dataset('HuggingFaceFW/fineweb-edu', 'CC-MAIN-2024-51', split = 'train', streaming = True).shuffle(seed = 123, buffer_size = 1_000_000)

    en_samples = get_data(ds_en, 25_000, 'en')
    zh_samples = get_data(ds_zh, 10_000, 'zh')
    es_samples = get_data(ds_es, 10_000, 'es')

    return en_samples + zh_samples + es_samples

raw_data = load_raw_ds()

In [None]:
""" 
Load dataset into a dataloader. The dataloader returns the original tokens - this is important for BPE tokenizers as otherwise it's difficult to reconstruct the correct string later!
"""
from torch.utils.data import Dataset, DataLoader

class ReconstructableTextDataset(Dataset):
    """
    Dataset over pre-tokenized text that also keeps, for every token position, the original substring
    of the input text and the per-sample source label.
    """

    def __init__(self, text_dataset, tokenizer, max_length):
        """
        Tokenizes all samples up front and records the original substring behind every token position.

        Params:
            @text_dataset: A list of B samples of text dataset. Each sample is a dict with keys `text` and `source`.
            @tokenizer: A HF tokenizer object; it is called with `return_offsets_mapping = True`.
            @max_length: Every sample is padded/truncated to this many tokens (N).
        """
        texts = [x['text'] for x in text_dataset]
        sources = [x['source'] for x in text_dataset]
        tokenized = tokenizer(texts, add_special_tokens = False, max_length = max_length, padding = 'max_length', truncation = True, return_offsets_mapping = True, return_tensors = 'pt')

        self.input_ids = tokenized['input_ids']
        self.sources = sources
        self.attention_mask = tokenized['attention_mask']
        self.offset_mapping = tokenized['offset_mapping']
        self.original_tokens = self.get_original_tokens(texts)

    def get_original_tokens(self, texts):
        """
        Return the original substring associated with each B x N position. This is important for reconstructing the original text when BPE tokenizers are used.

        Reads `self.offset_mapping` (B x N x 2 character offsets, from `tokenizer(..., return_offsets_mapping = True)`).

        Params:
            @texts: The list of B raw strings that were tokenized, in the same order as the rows of `self.offset_mapping`.

        Returns:
            A list of length B, each with length N, containing the original substring for the token at the same position of input_ids.
            Positions whose offsets are (0, 0) (e.g. padding) map to the empty string.
        """
        # Convert the whole tensor once; reading B x N single elements out of a tensor is far slower than plain list access
        offsets = self.offset_mapping.tolist()

        # Slicing with (0, 0) naturally yields "", so pad positions need no special case
        return [
            [texts[i][start_char:end_char] for start_char, end_char in row]
            for i, row in enumerate(offsets)
        ]

    def __len__(self):
        """Number of samples (B)."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return one sample's tensors plus its original token substrings and source label (under the key `sources`)."""
        return {'input_ids': self.input_ids[idx], 'attention_mask': self.attention_mask[idx], 'original_tokens': self.original_tokens[idx], 'sources': self.sources[idx]}
    
def collate_fn(batch):
    """
    Collate a list of dataset items into one batch dict.

    Tensors are stacked along a new leading dimension; `original_tokens` and `sources` stay as plain
    Python lists (one entry per sample) instead of going through the default collation.
    """
    def column(key):
        return [item[key] for item in batch]

    return {
        'input_ids': torch.stack(column('input_ids'), dim = 0),
        'attention_mask': torch.stack(column('attention_mask'), dim = 0),
        'original_tokens': column('original_tokens'),
        'sources': column('sources'),
    }


def chunk_list(input_list, max_length):
    """
    Split `input_list` into consecutive slices of at most `max_length` items; the final slice may be shorter.
    """
    chunks = []
    for start in range(0, len(input_list), max_length):
        chunks.append(input_list[start:start + max_length])
    return chunks

# Split the raw samples into chunks of `samples_per_export` - each chunk gets its own dataloader,
# and the export step writes results once per dataloader, so these chunks are the export breaks.
samples_per_export = 5_000
max_seq_length = 1024 # every sample is padded/truncated to this many tokens
forward_batch_size = 4

test_dls = [
    DataLoader(
        ReconstructableTextDataset(chunk, tokenizer, max_length = max_seq_length),
        batch_size = forward_batch_size,
        shuffle = False,
        collate_fn = collate_fn
    )
    for chunk in tqdm(chunk_list(raw_data, samples_per_export))
]

## Get expert selections + export

In [None]:
""" 
Run forward passes + export data
"""

def _export_topk_chunk(output_dir, sample_dfs, topk_dfs, is_first_export):
    """
    Concatenate the accumulated per-batch frames and write them to the three CSVs in `output_dir`.

    The first export overwrites the files and writes a header; later exports append without one.
    """
    combined_sample_df = pd.concat(sample_dfs, ignore_index = True)
    combined_topk_df = pd.concat(topk_dfs, ignore_index = True)
    csv_kwargs = {'mode': 'w' if is_first_export else 'a', 'index': False, 'header': is_first_export}

    combined_sample_df.to_csv(f'{output_dir}/samples.csv', **csv_kwargs)
    combined_topk_df.to_csv(f'{output_dir}/topks.csv', **csv_kwargs)
    # NOTE(review): this assumes `topk_ix == 1` marks the highest-ranked expert (1-based) - confirm against convert_topk_to_df.
    combined_topk_df[combined_topk_df['topk_ix'] == 1].to_csv(f'{output_dir}/topk1s.csv', **csv_kwargs)


@torch.no_grad()
def run_and_export_topk(model, model_prefix, dls: list[DataLoader], max_batches: None | int = None):
    """
    Run forward passes on given model and saves the top-k expert ids and sample information after every dataloader

    Reads the module-level `run_model_return_topk`, `tokenizer` and `main_device`.

    Params:
        @model: The model to run forward passes on. It is passed to `run_model_return_topk`, which should return a dict with keys `logits`, `all_topk_experts`, `all_topk_weights`.
        @model_prefix: The model prefix - used for file saving (outputs go to `topks/<model_prefix>/`).
        @dls: A list of dataloaders, each yielding batches with `input_ids`, `attention_mask`, `original_tokens`, and `sources`.
        @max_batches: The max number of batches to run (counted across all dataloaders). None runs everything.

    Returns:
        True on completion
    """
    cross_dl_batch_ix = 0
    # Tracks whether anything has been written yet. Keying the write mode on this, rather than on the dataloader index,
    # keeps the header/overwrite behaviour correct even if the first dataloader produces no batches.
    is_first_export = True
    output_dir = f'topks/{model_prefix}'
    os.makedirs(output_dir, exist_ok = True)

    for dl in dls:

        sample_dfs = []
        topk_dfs = []
        reached_max_batches = False

        for batch in tqdm(dl, total = len(dl)):

            input_ids = batch['input_ids'].to(main_device)
            attention_mask = batch['attention_mask'].to(main_device)
            original_tokens = batch['original_tokens']
            sources = batch['sources']

            output = run_model_return_topk(model, input_ids, attention_mask, return_hidden_states = False)

            # Check no bugs by validating output/perplexity (first batch only)
            if cross_dl_batch_ix == 0:
                labels = torch.where(input_ids == tokenizer.pad_token_id, torch.tensor(-100), input_ids)
                loss = ForCausalLMLoss(output['logits'], labels, model.config.vocab_size).detach().cpu().item()
                for i in range(min(5, input_ids.size(0))):
                    decoded_input = tokenizer.decode(input_ids[i, :], skip_special_tokens = True)
                    next_token_id = torch.argmax(output['logits'][i, -1, :]).item()
                    print('---------\n' + decoded_input + colored(tokenizer.decode([next_token_id], skip_special_tokens = False).replace('\n', '<lb>'), 'green'))
                print("PPL:", torch.exp(torch.tensor(loss)).item())

            # Long-format lookups keyed by the within-batch sequence index (and token index)
            original_tokens_df = pd.DataFrame(
                [(seq_i, tok_i, tok) for seq_i, tokens in enumerate(original_tokens) for tok_i, tok in enumerate(tokens)],
                columns = ['sequence_ix', 'token_ix', 'token']
            )

            sources_df = pd.DataFrame(
                list(enumerate(sources)),
                columns = ['sequence_ix', 'source']
            )

            # Create sample (token) level dataframe
            sample_df =\
                convert_outputs_to_df(input_ids, attention_mask, output['logits'])\
                .merge(original_tokens_df, how = 'left', on = ['token_ix', 'sequence_ix'])\
                .merge(sources_df, how = 'left', on = ['sequence_ix'])\
                .assign(batch_ix = cross_dl_batch_ix)

            # Create topk x layer_ix x sample level dataframe
            topk_df =\
                convert_topk_to_df(input_ids, attention_mask, output['all_topk_experts'], output['all_topk_weights'])\
                .assign(batch_ix = cross_dl_batch_ix)\
                .drop(columns = 'token_id')

            sample_dfs.append(sample_df)
            topk_dfs.append(topk_df)

            cross_dl_batch_ix += 1
            if max_batches is not None and cross_dl_batch_ix >= max_batches:
                reached_max_batches = True
                break

        # Flush this dataloader's results (also covers the partial chunk when max_batches was hit)
        if sample_dfs:
            _export_topk_chunk(output_dir, sample_dfs, topk_dfs, is_first_export)
            is_first_export = False

        if reached_max_batches:
            return True

    return True


run_and_export_topk(model, model_prefix, test_dls, max_batches = None)