In [None]:
"""
This runs forward passes on samples and stores: (1) pre-MLP activations; (2) router logits; (3) expert outputs;
 (4) top-k expert selections and weights; (5) sample (token-level) metadata.
"""
None

In [None]:
"""
Imports
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.loss.loss_utils import ForCausalLMLoss # Cross-entropy loss that handles label shifting
from datasets import load_dataset, Dataset
import pandas as pd
import numpy as np
from tqdm import tqdm
from termcolor import colored
import importlib
import os
import gc
import pickle
import datetime

from utils.memory import check_memory, clear_all_cuda_memory
from utils.store_topk import convert_topk_to_df
from utils.store_outputs import convert_outputs_to_df
from utils import pretrained_models

# Device that each input batch is moved to before the forward pass (see run_and_export_topk). The model
# itself is loaded with device_map = 'auto', so its layers may be spread over several GPUs.
main_device = 'cuda:0'
# Seed for the streaming-dataset shuffles and the sample permutation in load_raw_ds.
seed = 123

# Release CUDA memory left over from a previous run of this notebook, then report usage
# (helpers from utils.memory - NOTE(review): behaviour inferred from their names).
clear_all_cuda_memory()
check_memory()

# Below are for Mamba replicability - can remove if remove all SSMs
# Uncomment to trade speed for run-to-run determinism when the model contains Mamba/SSM layers.
# os.environ['MAMBA_DISABLE_CUDA_KERNELS'] = '1'
# os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':16:8'
# torch.use_deterministic_algorithms(True, warn_only = False)

## Load base model

In [None]:
"""
Load the base tokenizer/model

Architectures supported currently (parameter counts are active/total):
- 0. OLMoE architecture: OLMoE-1B-7B-* (1B/7B)
- 1. Qwen2MoE architecture: Qwen1.5-MoE-A2.7B-* (2.7B/14.3B), Qwen2-57B-A14B-* (14B/57B)
- 2. DeepSeek-V2 architecture: DeepSeek-V2-Lite (2.4B/15.7B), DeepSeek-V2 (21B/236B) -> use trust_remote_code = False
- 3. DeepSeek-V3 architecture: DeepSeek-V3 (37B/671B), DeepSeek-R1 (37B/671B), Moonlight-16B-A3B (3B/16B) -> use trust_remote_code = False
- 4. Qwen3MoE architecture: Qwen3-30B-A3B (3B/30B), Qwen3-235B-A22B (22B/235B), Qwen3-Coder (35B/480B)
- 5. KimiVL architecture: Kimi-VL-A3B-* (3B/16B)
- 6. Granite architecture: Granite-4.0-Tiny-* (1B/7B)
- 7. GLM4MoE architecture: GLM-4.5 (32B/355B), GLM-4.5-Air (12B/106B) (supports multi-GPU)
- 8. GPT-OSS architecture: GPT-OSS-120B (5B/117B)
- 9. GPT-OSS architecture: GPT-OSS-20B (4B/21B)
"""
selected_model_index = 8 # Key into the get_model() table below (8 = openai/gpt-oss-120b)

def get_model(index):
    """
    Look up the configuration of a supported model by its integer index.

    Params:
        @index: Integer key into the table below.

    Returns:
        A 5-tuple (HF model ID, model prefix, model architecture, attn implementation, use_hf) where:
        - model prefix names the export directory (activations-sm/<prefix>),
        - model architecture selects the module utils.pretrained_models.<architecture>,
        - attn implementation is passed to from_pretrained as attn_implementation (None = library default),
        - use_hf is True to use the transformers-native implementation (False -> trust_remote_code).

    Raises:
        KeyError: If `index` is not in the table; the message lists the valid indices.
    """
    # HF model ID, model prefix, model architecture,  attn implementation, whether to use hf lib implementation
    models = {
        0: ('allenai/OLMoE-1B-7B-0125-Instruct', 'olmoe', 'olmoe', None, True),
        1: ('Qwen/Qwen1.5-MoE-A2.7B-Chat', 'qwen1.5moe', 'qwen2moe', None, True),
        2: ('deepseek-ai/DeepSeek-V2-Lite', 'dsv2', 'dsv2', None, True),
        3: ('moonshotai/Moonlight-16B-A3B', 'moonlight', 'dsv3', None, True),
        4: ('Qwen/Qwen3-30B-A3B-Instruct-2507', 'qwen3moe', 'qwen3moe', None, True),
        5: ('moonshotai/Kimi-VL-A3B-Instruct', 'kimivl', 'kimivl', None, False),
        6: ('ibm-granite/granite-4.0-tiny-preview', 'granite', 'granite', None, True),
        7: ('zai-org/GLM-4.5-Air-FP8', 'glm4moe', 'glm4moe', None, True),
        8: ('openai/gpt-oss-120b', 'gptoss120', 'gptoss', 'kernels-community/vllm-flash-attn3', True), # Will load experts in MXFP4 if triton kernels installed
        9: ('openai/gpt-oss-20b', 'gptoss20', 'gptoss', 'kernels-community/vllm-flash-attn3', True)
    }
    if index not in models:
        # Same exception type as a plain dict lookup, but with an actionable message
        raise KeyError(f"Unknown model index {index!r}; valid indices are {sorted(models)}")
    return models[index]

def load_model_and_tokenizer(model_id, model_prefix, model_attn, model_use_hf, cache_dir = '/workspace/hf'):
    """
    Load the tokenizer and model from HF, or from `cache_dir` if already downloaded.

    Params:
        @model_id: HF model ID, e.g. 'openai/gpt-oss-120b'.
        @model_prefix: Unused here; kept so the call site can pass the get_model() fields in order.
        @model_attn: Value passed as `attn_implementation` (None = transformers default).
        @model_use_hf: If True, use the transformers-native model class; if False, load the repo's own
          modelling code via trust_remote_code.
        @cache_dir: HF cache directory (defaults to the previously hard-coded '/workspace/hf').

    Returns:
        (tokenizer, model): the model is loaded with torch_dtype = torch.bfloat16, device_map = 'auto',
        and put in eval mode.
    """
    # Left padding keeps the last position of every row a real token, so logits[:, -1] is the next-token prediction.
    # The tokenizer always trusts remote code (some repos ship custom tokenizer code even when the model class is
    # native). NOTE(review): trust_remote_code executes Python from the Hub repo - only use trusted model IDs.
    tokenizer = AutoTokenizer.from_pretrained(
        model_id, cache_dir = cache_dir, add_eos_token = False, add_bos_token = False, padding_side = 'left', trust_remote_code = True
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id, cache_dir = cache_dir, torch_dtype = torch.bfloat16, trust_remote_code = not model_use_hf, device_map = 'auto', attn_implementation = model_attn
    ).eval()
    return tokenizer, model

# Unpack the selected model's configuration and load it (downloaded from HF unless already cached).
model_id, model_prefix, model_architecture, model_attn, model_use_hf = get_model(selected_model_index)
tokenizer, model = load_model_and_tokenizer(model_id, model_prefix, model_attn, model_use_hf)

In [None]:
"""
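Optional sanity checks (commented out): inspect the embedding/LM-head weights and compare model() logits against a custom forward pass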
"""
# # Quants
# print(model.lm_head.weight)
# print(model.model.embed_tokens.weight)

# # Test model() call
# inputs = tokenizer(['Test string'], return_tensors = 'pt', padding = 'max_length', truncation = True, max_length = 12).to(model.device)
# with torch.no_grad():
#     original_results = model(**inputs, use_cache = False)
# print(original_results['logits'][0, :].detach().float().cpu().numpy())

# # Test custom loader
# import importlib
# import utils.pretrained_models.dsv2 as test_mod   # the module object
# test_mod = importlib.reload(test_mod)
# run_model_return_topk = test_mod.run_dsv2_return_topk
# custom_results = run_model_return_topk(model, inputs['input_ids'], inputs['attention_mask'], return_hidden_states = True)
# print(custom_results['logits'][0, :].detach().float().cpu().numpy())

# check_memory()

In [None]:
# model.model.layers[0].mlp.experts.down_proj

In [None]:
"""
Load reverse-engineered forward pass functions that return topk expert IDs and weights
"""
# Resolve the architecture-specific forward pass, e.g. for 'gptoss': utils.pretrained_models.gptoss.run_gptoss_return_topk
model_module = importlib.import_module(f"utils.pretrained_models.{model_architecture}")
run_model_return_topk = getattr(model_module, f"run_{model_architecture}_return_topk")

@torch.no_grad()
def test_custom_forward_pass(model, pad_token_id):
    """
    Check that the custom forward pass (`run_model_return_topk`) reproduces the model's own logits exactly,
    then print the shapes of the extra per-layer outputs it returns and the LM loss on two short prompts.

    Runs under torch.no_grad(): this is inference only, so no autograd graph is built.

    Params:
        @model: The loaded HF model.
        @pad_token_id: Token ID excluded from the loss labels (padding is also detected via the attention mask).

    Raises:
        AssertionError: If the custom logits differ from model() logits, or the topk ID/weight lists differ in length.
    """
    inputs = tokenizer(['Hi! I am a dog and I like to bark', 'Vegetables are good for'], return_tensors = 'pt', padding = 'max_length', truncation = True, max_length = 512).to(model.device)
    original_results = model(**inputs, use_cache = False)
    custom_results = run_model_return_topk(model, inputs['input_ids'], inputs['attention_mask'], return_hidden_states = True)

    # Exact (bitwise) equality: the custom forward must reproduce model() logits, not merely approximate them.
    assert torch.equal(original_results.logits, custom_results['logits']), 'Error in custom forward'
    assert len(custom_results['all_topk_experts']) == len(custom_results['all_topk_weights']), 'Length of topk IDs and weights not equal'
    print(f"Length of topk: {len(custom_results['all_topk_experts'])}")
    print(f"Topk size: {custom_results['all_topk_experts'][0].shape}")
    # NOTE(review): these print row 1 of element 0 of the topk lists under a "First token" label - confirm the intended row.
    print(f"First token topk IDs: {custom_results['all_topk_experts'][0][1,]}")
    print(f"First token topk weights: {custom_results['all_topk_weights'][0][1,]}")

    # Exclude padding from the loss (label -100). Padding is identified from the attention mask as well as
    # pad_token_id; masked_fill keeps the labels on input_ids' device (no stray CPU tensor).
    pad_mask = inputs['attention_mask'] == 0
    if pad_token_id is not None:
        pad_mask = pad_mask | (inputs['input_ids'] == pad_token_id)
    labels = inputs['input_ids'].masked_fill(pad_mask, -100)
    loss = ForCausalLMLoss(custom_results['logits'], labels, model.config.vocab_size).detach().cpu().item()
    print(f"LM loss: {loss}")
    print(f"Hidden states layers (pre-mlp | post-layer): {len(custom_results['all_pre_mlp_hidden_states'])} | {len(custom_results['all_hidden_states'])}")
    print(f"Hidden state size (pre-mlp | post-layer): {(custom_results['all_pre_mlp_hidden_states'][0].shape)} | {(custom_results['all_hidden_states'][0].shape)}")
    print(f"Expert outputs : {(custom_results['all_expert_outputs'][0].shape)}")
    print(f"Router logits : {(custom_results['all_router_logits'][0].shape)}")

# Runs when the cell executes; its assertions stop "Run All" if the custom forward pass diverges from model().
test_custom_forward_pass(model, tokenizer.pad_token_id)

## Get dataset

In [None]:
"""
Load dataset - C4 + HPLT 2.0 (en/zh/es)
"""
def load_raw_ds():
    """
    Build (or load from cache) the raw text corpus: a shuffled mix of C4 and HPLT 2.0 samples in English,
    Chinese and Spanish.

    On the first call, samples are streamed from the HF Hub: 750 en, 250 zh and 250 es from each of C4 and HPLT,
    i.e. up to 2,500 samples. They are permuted with the global `seed` and written to CACHE_FILE as JSON lines.
    Later calls read that file instead.

    Returns:
        A list of dicts of the form {'text': str, 'source': 'en' | 'zh' | 'es'}.

    NOTE(review): the cache is keyed only by its path - delete CACHE_FILE after changing sample counts or `seed`.
    """
    import itertools

    CACHE_FILE = '/workspace/data/c4_hlpt.jsonl'

    if os.path.exists(CACHE_FILE):
        print('Loading cached dataset...')
        return load_dataset('json', data_files = CACHE_FILE, split = 'train').to_list()

    rng = np.random.default_rng(seed = seed)

    def get_hlpt(lang): # HPLT 2.0 config names: eng_Latn/zho_Hans/spa_Latn
        return load_dataset('HPLT/HPLT2.0_cleaned', lang, split = 'train', streaming = True).shuffle(seed = seed, buffer_size = 50_000)

    def get_c4(lang): # C4 config names: en/zh/es
        return load_dataset('allenai/c4', lang, split = 'validation', streaming = True).shuffle(seed = seed, buffer_size = 50_000)

    def get_data(ds, n_samples, data_source):
        # First n_samples rows of the (shuffled) stream - fewer if it runs out - tagged with their language
        return [{'text': sample['text'], 'source': data_source} for sample in itertools.islice(ds, n_samples)]

    combined_ds =\
        get_data(get_c4('en'), 750, 'en') + get_data(get_hlpt('eng_Latn'), 750, 'en') +\
        get_data(get_c4('zh'), 250, 'zh') + get_data(get_hlpt('zho_Hans'), 250, 'zh') +\
        get_data(get_c4('es'), 250, 'es') + get_data(get_hlpt('spa_Latn'), 250, 'es')

    # Deterministically interleave sources and languages
    combined_ds = [combined_ds[i] for i in rng.permutation(len(combined_ds))]

    # Create the cache directory if needed, and write to a temp file that is then renamed. An interrupted
    # write therefore never leaves a truncated cache that the os.path.exists() check above would load.
    print('Caching dataset to disk …')
    os.makedirs(os.path.dirname(CACHE_FILE), exist_ok = True)
    tmp_file = CACHE_FILE + '.tmp'
    Dataset.from_list(combined_ds).to_json(tmp_file, orient = 'records', lines = True, force_ascii = False)
    os.replace(tmp_file, CACHE_FILE)

    return combined_ds

raw_data = load_raw_ds()

In [None]:
""" 
Wrap the dataset in DataLoaders. Each batch also returns the original tokens - this is important for BPE tokenizers, as otherwise it's difficult to reconstruct the correct string later!
"""
from utils.dataset import ReconstructableTextDataset, stack_collate
from torch.utils.data import DataLoader

def chunk_list(input_list, max_length):
    """
    Split a sequence into consecutive chunks of at most `max_length` items (the last chunk may be shorter).

    Params:
        @input_list: Any sliceable sequence (list, tuple, str, ...).
        @max_length: Positive chunk size.

    Returns:
        A list of slices of `input_list`; empty if `input_list` is empty.

    Raises:
        ValueError: If `max_length` is not positive. Without this check a negative size would silently return [].
    """
    if max_length <= 0:
        raise ValueError(f'max_length must be positive, got {max_length}')
    return [input_list[start:start + max_length] for start in range(0, len(input_list), max_length)]

# Split the corpus into chunks of 250 samples, one DataLoader per chunk - these will be the export breaks:
# run_and_export_topk writes each DataLoader's outputs to its own numbered sub-directory.
# Each DataLoader yields batches of 64 sequences in corpus order (shuffle = False). max_length = 512 is passed to
# ReconstructableTextDataset (NOTE(review): presumably the truncation/padding length - confirm in utils.dataset).
test_dls = [
    DataLoader(
        ReconstructableTextDataset([x['text'] for x in data_chunk], tokenizer, max_length = 512, sources = [x['source'] for x in data_chunk]),
        batch_size = 64,
        shuffle = False,
        collate_fn = stack_collate
    )
    for data_chunk in tqdm(chunk_list(raw_data, 250))
]

## Get expert selections + export

In [None]:
""" 
Run forward passes + export data
"""
def _print_sanity_check(input_ids, attention_mask, logits, vocab_size):
    """Print the greedy next token for up to 10 sequences of a batch, then the batch's LM perplexity."""
    # Exclude padding from the loss (label -100), using both the attention mask and the tokenizer's pad token
    pad_mask = attention_mask == 0
    if tokenizer.pad_token_id is not None:
        pad_mask = pad_mask | (input_ids == tokenizer.pad_token_id)
    labels = input_ids.masked_fill(pad_mask, -100)
    loss = ForCausalLMLoss(logits, labels, vocab_size).detach().cpu().item()
    for i in range(min(10, input_ids.size(0))):
        decoded_input = tokenizer.decode(input_ids[i, :], skip_special_tokens = True)
        # Inputs are left-padded, so the last position holds the next-token prediction
        next_token_id = torch.argmax(logits[i, -1, :]).item()
        print('---------\n' + decoded_input + colored(tokenizer.decode([next_token_id], skip_special_tokens = False).replace('\n', '<lb>'), 'green'))
    print("PPL:", torch.exp(torch.tensor(loss)).item())


def _save_chunk_outputs(dl_dir, sample_dfs, topk_dfs, all_router_logits, all_pre_mlp_hidden_states, all_expert_outputs):
    """Concatenate the per-batch results accumulated for one dataloader and write them under `dl_dir`."""
    if not sample_dfs:
        # pd.concat / torch.cat raise on empty lists; an empty dataloader simply produces no files
        print(f"No batches processed for {dl_dir}; nothing exported.")
        return
    pd.concat(sample_dfs, ignore_index = True).to_pickle(f'{dl_dir}/samples.pkl')
    pd.concat(topk_dfs, ignore_index = True).to_pickle(f'{dl_dir}/topks.pkl')
    torch.save(torch.cat(all_router_logits, dim = 0), f'{dl_dir}/all-router-logits.pt')
    torch.save(torch.cat(all_pre_mlp_hidden_states, dim = 0), f'{dl_dir}/all-pre-mlp-hidden-states.pt')
    torch.save(torch.cat(all_expert_outputs, dim = 0), f'{dl_dir}/all-expert-outputs.pt')


@torch.no_grad()
def run_and_export_topk(model, model_prefix: str, dls: 'list[DataLoader]', layers_to_keep_acts: list[int], layers_to_keep_experts: list[int] | None = None, topk_to_keep: int = 0, max_batches: None | int = None):
    """
    Run forward passes over every dataloader and export, per dataloader, token-level metadata, top-k expert
    selections, router logits, pre-MLP hidden states and expert outputs to disk.

    Params:
        @model: The model to run forward passes on. `run_model_return_topk(model, ...)` must return a dict with keys
          `logits`, `all_topk_experts`, `all_topk_weights`, `all_router_logits`, `all_pre_mlp_hidden_states` and
          `all_expert_outputs`, where the `all_*` values are per-layer lists of tensors.
        @model_prefix: The model prefix - outputs go to `activations-sm/{model_prefix}`.
        @dls: A list of DataLoaders over ReconstructableTextDataset. Each batch must contain `input_ids`,
          `attention_mask`, `original_tokens` and `sources`.
        @layers_to_keep_acts: Layer indices (0-indexed into the per-layer lists) kept for router logits and
          pre-MLP hidden states.
        @layers_to_keep_experts: Layer indices kept for expert outputs. Defaults to `layers_to_keep_acts`.
        @topk_to_keep: Number of leading top-k slots kept in expert outputs (slice 0:topk_to_keep; 0 keeps none).
        @max_batches: Stop after this many batches in total across all dataloaders (None = run everything).

    Writes:
        `{output_dir}/metadata.pkl` with the kept layer indices. For each dataloader `dl_ix`, a directory
        `{output_dir}/{dl_ix:02d}` containing:
        - `samples.pkl`: token-level dataframe from convert_outputs_to_df, left-joined with the original token
          text and the sequence's source, plus `batch_ix`.
        - `topks.pkl`: token x layer x topk dataframe from convert_topk_to_df, plus `batch_ix` (`token_id` dropped).
        - `all-router-logits.pt`: (n_valid_tokens, len(layers_to_keep_acts), ...) tensor.
        - `all-pre-mlp-hidden-states.pt`: (n_valid_tokens, len(layers_to_keep_acts), ...) tensor.
        - `all-expert-outputs.pt`: (n_valid_tokens, len(layers_to_keep_experts), topk_to_keep, ...) tensor.
        Tensor rows are the attention_mask == 1 positions in batch order. They presumably align row-for-row with
        samples.pkl (which is expected to drop masked tokens too) - TODO confirm in utils.store_outputs.

    Returns:
        True once exporting finishes, or once it stops early at `max_batches`.
    """
    if layers_to_keep_experts is None:
        layers_to_keep_experts = layers_to_keep_acts

    output_dir = f'activations-sm/{model_prefix}'
    os.makedirs(output_dir, exist_ok = True)

    # Save metadata
    with open(f'{output_dir}/metadata.pkl', 'wb') as f:
        pickle.dump(
            {'all_pre_mlp_hidden_states_layers': layers_to_keep_acts, 'all_expert_outputs_layers': layers_to_keep_experts},
            f
        )

    cross_dl_batch_ix = 0 # Global batch counter across dataloaders; stored as `batch_ix` in both dataframes

    for dl_ix, dl in enumerate(dls):
        print(f"Processing {dl_ix} of {len(dls)}...")
        dl_dir = f"{output_dir}/{dl_ix:02d}"
        os.makedirs(dl_dir, exist_ok = True)

        all_router_logits = []
        all_pre_mlp_hidden_states = []
        all_expert_outputs = []
        sample_dfs = []
        topk_dfs = []

        for batch in tqdm(dl, total = len(dl)):
            input_ids = batch['input_ids'].to(main_device)
            attention_mask = batch['attention_mask'].to(main_device)

            output = run_model_return_topk(model, input_ids, attention_mask, return_hidden_states = True)

            # Check no bugs by validating output/perplexity on the very first batch only
            if cross_dl_batch_ix == 0:
                _print_sanity_check(input_ids, attention_mask, output['logits'], model.config.vocab_size)

            original_tokens_df = pd.DataFrame(
                [(seq_i, tok_i, tok) for seq_i, tokens in enumerate(batch['original_tokens']) for tok_i, tok in enumerate(tokens)],
                columns = ['sequence_ix', 'token_ix', 'token']
            )
            sources_df = pd.DataFrame(list(enumerate(batch['sources'])), columns = ['sequence_ix', 'source'])

            # Sample (token) level dataframe
            sample_dfs.append(
                convert_outputs_to_df(input_ids, attention_mask, output['logits'])
                .merge(original_tokens_df, how = 'left', on = ['token_ix', 'sequence_ix'])
                .merge(sources_df, how = 'left', on = ['sequence_ix'])
                .assign(batch_ix = cross_dl_batch_ix)
            )

            # Topk x layer_ix x sample level dataframe
            topk_dfs.append(
                convert_topk_to_df(input_ids, attention_mask, output['all_topk_experts'], output['all_topk_weights'])
                .assign(batch_ix = cross_dl_batch_ix)
                .drop(columns = 'token_id')
            )

            # Per-layer lists of (B*N, ...) tensors -> stack to (B*N, n_layers, ...), keep unmasked positions and the
            # requested layers. Move to CPU in case the forward pass returns GPU tensors, so a whole chunk never
            # accumulates on the GPU and the saved files load without a GPU.
            valid_pos = torch.where(attention_mask.cpu().view(-1) == 1) # Valid (BN, ) positions
            all_router_logits.append(torch.stack(output['all_router_logits'], dim = 1)[valid_pos][:, layers_to_keep_acts, :].cpu())
            all_pre_mlp_hidden_states.append(torch.stack(output['all_pre_mlp_hidden_states'], dim = 1)[valid_pos][:, layers_to_keep_acts, :].cpu())
            # BN x n_layers x topk x D - keep only the first `topk_to_keep` top-k slots
            all_expert_outputs.append(torch.stack(output['all_expert_outputs'], dim = 1)[valid_pos][:, layers_to_keep_experts, 0:topk_to_keep, :].cpu())

            cross_dl_batch_ix += 1
            if max_batches is not None and cross_dl_batch_ix >= max_batches:
                _save_chunk_outputs(dl_dir, sample_dfs, topk_dfs, all_router_logits, all_pre_mlp_hidden_states, all_expert_outputs)
                return True

        _save_chunk_outputs(dl_dir, sample_dfs, topk_dfs, all_router_logits, all_pre_mlp_hidden_states, all_expert_outputs)

    return True

# Layer count per model prefix; every layer index in range(count) is kept for activations and expert outputs.
# The previous if/elif chain had no entry for the gpt-oss models, leaving layers_to_keep_acts undefined for them.
N_LAYERS_BY_PREFIX = {
    'olmoe': 16,
    'qwen1.5moe': 24,
    'dsv2': 26,
    'moonlight': 26,
    'qwen3moe': 48,
    'kimivl': 26,
    'granite': 40,
    'glm4moe': 45,
    'gptoss120': 36,
    'gptoss20': 24,
}
if model_prefix not in N_LAYERS_BY_PREFIX:
    raise KeyError(f"No layer count configured for model prefix {model_prefix!r}; add it to N_LAYERS_BY_PREFIX")
layers_to_keep_acts = list(range(N_LAYERS_BY_PREFIX[model_prefix]))

res = run_and_export_topk(
    model,
    model_prefix,
    test_dls,
    layers_to_keep_acts = layers_to_keep_acts,
    layers_to_keep_experts = layers_to_keep_acts,
    topk_to_keep = 1, # Keep only the top-1 expert's output at each layer
    max_batches = None # Set to a small int for a quick trial run
)