In [None]:
"""
This runs forward passes on samples and stores: (1) pre-MLP activations; (2) top-k expert selections; (3) sample metadata.
"""
None

In [None]:
"""
Imports
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.loss.loss_utils import ForCausalLMLoss
from datasets import load_dataset, Dataset
import pandas as pd
import numpy as np
from tqdm import tqdm
import os
import gc
import pickle
from termcolor import colored
import importlib
import random

from utils.memory import check_memory, clear_all_cuda_memory
from utils.store_topk import convert_topk_to_df
from utils.store_outputs import convert_outputs_to_df_fast
from utils.misc import flatten_list
from utils.role_templates import render_single_message

# Device that each batch's input_ids / attention_mask are moved to before the forward pass.
main_device = 'cuda:0'
# Seed used for the streaming dataset shuffle so the sampled documents are reproducible.
seed = 123

# Helpers from utils.memory; presumably they free cached CUDA memory and report current usage,
# so the model load below starts from a clean state - verify against utils/memory.py.
clear_all_cuda_memory()
check_memory()

## Load base model

In [None]:
"""
Load the base tokenizer/model
"""
# Key into the model table in get_model below (1 = openai/gpt-oss-120b).
selected_model_index = 1

def get_model(index):
    """
    Look up the configuration tuple for one of the supported models.

    Params:
        @index: Integer key of the model to use.

    Returns:
        A 6-tuple of (HF model ID, model prefix, model architecture, attn implementation,
        whether to use the HF library implementation, layer count). Raises KeyError for an unknown index.
    """
    flash_attn3 = 'kernels-community/vllm-flash-attn3'

    registry = dict([
        (0, ('openai/gpt-oss-20b', 'gptoss20', 'gptoss', flash_attn3, True, 24)),
        # Will load experts in MXFP4 if triton kernels installed
        (1, ('openai/gpt-oss-120b', 'gptoss120', 'gptoss', flash_attn3, True, 36)),
        (2, ('Qwen/Qwen3-30B-A3B-Thinking-2507', 'qwen3-30b-a3b', 'qwen3moe', None, True, 48)),
        # GLM-4.5 has one dense pre-layer, the count here is MoE layers only
        (3, ('zai-org/GLM-4.5-Air-FP8', 'glm4moe', 'glm4moe', None, True, 45)),
    ])

    return registry[index]

def load_model_and_tokenizer(model_id, model_prefix, model_attn, model_use_hf):
    """
    Fetch the tokenizer and model for `model_id` from HF, reusing the local cache directory when already downloaded.

    Params:
        @model_id: The HF model ID to load.
        @model_prefix: Not used inside this function.
        @model_attn: Forwarded to the model loader as `attn_implementation`.
        @model_use_hf: When True, remote code is NOT trusted for the model (the HF library implementation is used).

    Returns:
        A (tokenizer, model) tuple; the model has been switched to eval mode.
    """
    hf_cache = '/workspace/hf'

    # Left padding, and no automatic BOS/EOS insertion by the tokenizer
    tokenizer_kwargs = dict(
        cache_dir = hf_cache,
        add_eos_token = False,
        add_bos_token = False,
        padding_side = 'left',
        trust_remote_code = True
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id, **tokenizer_kwargs)

    model_kwargs = dict(
        cache_dir = hf_cache,
        dtype = torch.bfloat16,
        trust_remote_code = not model_use_hf,
        device_map = 'auto',
        attn_implementation = model_attn
    )
    model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
    model = model.eval()

    return tokenizer, model

# Unpack the selected model's config tuple, then load its tokenizer and weights.
# model_prefix names the export directory, model_architecture selects the custom forward-pass module,
# and model_n_layers is used to choose which layers' activations are kept.
model_id, model_prefix, model_architecture, model_attn, model_use_hf, model_n_layers = get_model(selected_model_index)
tokenizer, model = load_model_and_tokenizer(model_id, model_prefix, model_attn, model_use_hf)

In [None]:
"""
Some checks for special models (GPT-OSS - check attn + expert precision)
"""
if model_architecture == 'gptoss':
    # Print the first layer's expert down-projection to see which precision the experts loaded in
    print(model.model.layers[0].mlp.experts.down_proj)
    # Print the attention implementation recorded on the model config
    print(model.model.config._attn_implementation)

# Temp hotfix for qwen3 https://github.com/huggingface/accelerate/issues/3870

# # Test model() call
# inputs = tokenizer(['Test string'], return_tensors = 'pt', padding = 'max_length', truncation = True, max_length = 12).to(model.device)
# with torch.no_grad():
#     original_results = model(**inputs, use_cache = False)
# print(original_results['logits'][0, :].detach().float().cpu().numpy())

# # Test custom loader
# import importlib
# import utils.pretrained_models.ringmini2 as test_mod   # the module object
# test_mod = importlib.reload(test_mod)
# run_model_return_topk = test_mod.run_ringmini2_return_topk
# custom_results = run_model_return_topk(model, inputs['input_ids'], inputs['attention_mask'], return_hidden_states = True)
# print(custom_results['logits'][0, :].detach().float().cpu().numpy())

check_memory()

In [None]:
"""
Load reverse-engineered forward pass functions that return topk expert IDs and weights
"""
# Resolve the architecture-specific module, e.g. utils.pretrained_models.gptoss
model_module = importlib.import_module(f"utils.pretrained_models.{model_architecture}")
# Each such module is expected to expose a function named run_<architecture>_return_topk
run_model_return_topk = getattr(model_module, f"run_{model_architecture}_return_topk")

@torch.no_grad()
def test_custom_forward_pass(model, pad_token_id):
    """
    Sanity-check the custom forward pass against the stock `model(...)` call on two short, padded prompts.

    Runs under `torch.no_grad()` so neither forward pass builds an autograd graph. Asserts that the two sets of
    logits are exactly equal and that the custom pass returns as many top-k ID entries as top-k weight entries,
    then prints shapes, an LM loss, and a few example values for manual inspection.

    Params:
        @model: The loaded model; also passed to `run_model_return_topk`.
        @pad_token_id: Token ID treated as padding; those positions are excluded from the loss (label -100).

    Notes:
        Reads the notebook-level `tokenizer` and `run_model_return_topk`.
    """
    inputs = tokenizer(
        ['Hi! I am a dog and I like to bark', 'Vegetables are good for'],
        return_tensors = 'pt', padding = 'max_length', truncation = True, max_length = 512
    ).to(model.device)

    original_results = model(**inputs, use_cache = False)
    custom_results = run_model_return_topk(model, inputs['input_ids'], inputs['attention_mask'], return_hidden_states = True)

    assert torch.equal(original_results.logits, custom_results['logits']), 'Error in custom forward'
    assert len(custom_results['all_topk_experts']) == len(custom_results['all_topk_weights']), 'Length of topk IDs and weights not equal'

    print(f"Length of topk: {len(custom_results['all_topk_experts'])}")
    print(f"Topk size: {custom_results['all_topk_experts'][0].shape}")
    # NOTE(review): these index row 1 of the first entry although the label says "First token" - confirm intended row
    print(f"First token topk IDs: {custom_results['all_topk_experts'][0][1,]}")
    print(f"First token topk weights: {custom_results['all_topk_weights'][0][1,]}")

    # Build labels on the same device as input_ids; padding positions get -100 so the loss ignores them
    labels = inputs['input_ids'].masked_fill(inputs['input_ids'] == pad_token_id, -100)
    loss = ForCausalLMLoss(custom_results['logits'], labels, model.config.vocab_size).detach().cpu().item()
    print(f"LM loss: {loss}")

    print(f"Hidden states layers (pre-mlp | post-layer): {len(custom_results['all_pre_mlp_hidden_states'])} | {len(custom_results['all_hidden_states'])}")
    print(f"Hidden state size (pre-mlp | post-layer): {(custom_results['all_pre_mlp_hidden_states'][0].shape)} | {(custom_results['all_hidden_states'][0].shape)}")
    print(f"Router logits : {(custom_results['all_router_logits'][0].shape)}")

test_custom_forward_pass(model, tokenizer.pad_token_id)

## Get dataset

In [None]:
# Test
# Render a short developer/user/assistant conversation to inspect the raw chat template output.
# `padding`, `truncation` and `max_length` are not passed: they only apply when `tokenize = True`,
# and this call returns the untokenized string.
s = tokenizer.apply_chat_template(
    [
        {'role': 'developer', 'content': 'Test.'},
        {'role': 'user', 'content': 'Hi! I am a dog and I like to bark'},
        {'role': 'assistant', 'content': 'I am a dog and I like to bark'}
    ],
    tokenize = False,
    add_generation_prompt = False
)
print(s)
# print(re.sub(r'(Current date:\s*)\d{4}-\d{2}-\d{2}', '2025-09-01', s))


In [None]:
# Preview how a single 'assistant-cot' message is rendered for the selected architecture.
print(render_single_message(model_architecture, role = 'assistant-cot',  content ='Hi'))

In [None]:
"""
Load dataset - C4 + HPLT2
"""
n_sample_size = 500

def load_raw_ds():
    """
    Pull up to `n_sample_size` documents from the shuffled C4 (en) validation stream.

    Returns:
        A list of dicts, each with keys `text` (the document text) and `source` (the data source label).
    """
    def stream_c4():
        c4_stream = load_dataset('allenai/c4', 'en', split = 'validation', streaming = True)
        return c4_stream.shuffle(seed = seed, buffer_size = 50_000)

    # def get_hplt2():
    #     return load_dataset('HPLT/HPLT2.0_cleaned', 'eng_Latn', split = 'train', streaming = True).shuffle(seed = seed, buffer_size = 50_000)

    def take_samples(ds, n_samples, data_source): # en
        # range() comes first in the zip so the stream is not advanced past the last kept sample;
        # the result is simply shorter if the stream runs out early
        return [
            {'text': sample['text'], 'source': data_source}
            for _, sample in zip(range(n_samples), ds)
        ]

    return take_samples(stream_c4(), n_sample_size, 'c4-en')

raw_data = load_raw_ds()

In [None]:
"""
Function to get sample sequences
"""
# Truncate every raw document to at most 768 tokens, then decode back to text so the role templates
# below wrap the already-truncated string.
truncated_texts = tokenizer.batch_decode(tokenizer([t['text'] for t in raw_data], padding = False, truncation = True, max_length = 768).input_ids)

def get_sample_seqs(sample_str):
    """
    Render the same text once under each role template.

    Params:
        @sample_str: The text to place inside each role's message.

    Returns:
        A list of dicts with keys `role` and `prompt`, one per role, in the order
        system, user, assistant-cot, assistant-final, tool.
    """
    roles = ('system', 'user', 'assistant-cot', 'assistant-final', 'tool')

    rendered = []
    for role in roles:
        prompt = render_single_message(model_architecture, role, sample_str, None)
        rendered.append({'role': role, 'prompt': prompt})
    return rendered

# One record per (text, role) pair: each truncated text is rendered once per role template.
# question_ix identifies the source text; flatten_list presumably collapses the per-text lists into
# a single flat list of records - verify against utils/misc.py.
input_list = flatten_list([
    [
        {'question_ix': i, 'question': x, **p}
        for p in get_sample_seqs(x)
    ]
    for i, x in enumerate(truncated_texts)
])

# prompt_ix is a running 0..N-1 ID over all rendered prompts
input_df =\
    pd.DataFrame(input_list)\
    .assign(prompt_ix = lambda df: list(range(len(df))))

input_df

In [None]:
""" 
Load dataset into a dataloader. The dataloader returns the original tokens - this is important for BPE tokenizers as otherwise it's difficult to reconstruct the correct string later!
"""
from utils.dataset import ReconstructableTextDataset, stack_collate
from torch.utils.data import DataLoader

def chunk_list(input_list, max_length):
    """
    Split `input_list` into consecutive slices of at most `max_length` items each.
    The last slice holds the remainder and may be shorter.
    """
    chunks = []
    for start in range(0, len(input_list), max_length):
        chunks.append(input_list[start:start + max_length])
    return chunks

# Create and chunk into lists of size 500 each - these will be the export breaks:
# run_and_export_topk writes one output directory per dataloader in this list.
# Each dataset also carries the prompt_ix of its prompts so exported rows can be joined back to input_df.
test_dls = [
    DataLoader(
        ReconstructableTextDataset([x['prompt'] for x in data_chunk], tokenizer, max_length = 896, prompt_ix = [x['prompt_ix'] for x in data_chunk]),
        batch_size = 16,
        shuffle = False,
        collate_fn = stack_collate
    )
    for data_chunk in tqdm(chunk_list(input_df.to_dict('records'), 500))
]

## Get expert selections + export

In [None]:
""" 
Run forward passes + export data
"""
def _export_dl_outputs(dl_dir, sample_dfs, topk_dfs, all_router_logits, all_pre_mlp_hidden_states):
    """
    Concatenate the per-batch results collected for one dataloader and write them into `dl_dir`.
    """
    pd.concat(sample_dfs, ignore_index = True).to_pickle(f'{dl_dir}/samples.pkl')
    pd.concat(topk_dfs, ignore_index = True).to_pickle(f'{dl_dir}/topks.pkl')
    torch.save(torch.cat(all_router_logits, dim = 0), f'{dl_dir}/all-router-logits.pt')
    torch.save(torch.cat(all_pre_mlp_hidden_states, dim = 0), f'{dl_dir}/all-pre-mlp-hidden-states.pt')


@torch.no_grad()
def run_and_export_topk(model, model_prefix: str, dls: list[DataLoader], layers_to_keep_acts: list[int], max_batches: None | int = None):
    """
    Run forward passes on the given model and export, per dataloader, the token-level sample metadata, the topk
    expert selections, the router logits, and the pre-MLP hidden states.

    Params:
        @model: The model to run forward passes on. It is passed to `run_model_return_topk`, whose output is read through the keys
          `logits`, `all_topk_experts`, `all_topk_weights`, `all_router_logits`, and `all_pre_mlp_hidden_states`.
        @model_prefix: The model prefix to save the activations under (`activations/{model_prefix}`).
        @dls: A list of dataloaders whose batches provide `input_ids`, `attention_mask`, `original_tokens`, and `prompt_ix`.
        @layers_to_keep_acts: A list of layer indices (0-indexed) used to filter the router logits and pre-MLP hidden states.
        @max_batches: The max number of batches to run, counted across all dataloaders; None runs everything. The limit is
          checked after each batch, so at least one batch is always processed.

    Returns:
        True once the exports are written. Files written:
        - `activations/{model_prefix}/metadata.pkl`: A dict recording `layers_to_keep_acts`.
        - For each dataloader, under `activations/{model_prefix}/{dl_ix:02d}/`:
          - `samples.pkl`: A sample (token)-level dataframe merged with the original token text and `prompt_ix`.
          - `topks.pkl`: A sample (token) x layer_ix x topk_ix level dataframe of the selected experts and their weights.
          - `all-router-logits.pt`: A tensor of size n_samples x layers_to_keep_acts x n_experts with the router logits.
          - `all-pre-mlp-hidden-states.pt`: A tensor of size n_samples x layers_to_keep_acts x D with the pre-MLP hidden state
            for each retained layer.
          Masked (padding) positions are filtered out of both tensors.

    Notes:
        Reads the notebook-level `tokenizer`, `main_device`, and `run_model_return_topk`.
    """
    cross_dl_batch_ix = 0
    output_dir = f'activations/{model_prefix}'
    os.makedirs(output_dir, exist_ok = True)

    # Save metadata
    with open(f'{output_dir}/metadata.pkl', 'wb') as f:
        pickle.dump(
            {'all_pre_mlp_hidden_states_layers': layers_to_keep_acts},
            f
        )

    # Iterate through dataloaders
    for dl_ix, dl in enumerate(dls):
        print(f"Processing {str(dl_ix)} of {len(dls)}...")
        dl_dir = f"{output_dir}/{dl_ix:02d}"
        os.makedirs(dl_dir, exist_ok = True)

        all_router_logits = []
        all_pre_mlp_hidden_states = []
        sample_dfs = []
        topk_dfs = []
        hit_max_batches = False

        for batch in tqdm(dl, total = len(dl)):

            input_ids = batch['input_ids'].to(main_device)
            attention_mask = batch['attention_mask'].to(main_device)
            original_tokens = batch['original_tokens']
            prompt_indices = batch['prompt_ix']

            output = run_model_return_topk(model, input_ids, attention_mask, return_hidden_states = True)

            # Check no bugs by validating output/perplexity (first batch of the whole run only)
            if cross_dl_batch_ix == 0:
                # Labels are built on the same device as input_ids; padding positions get -100 so the loss ignores them
                labels = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100)
                loss = ForCausalLMLoss(output['logits'], labels, model.config.vocab_size).detach().cpu().item()
                for i in range(min(20, input_ids.size(0))):
                    decoded_input = tokenizer.decode(input_ids[i, :], skip_special_tokens = False)
                    next_token_id = torch.argmax(output['logits'][i, -1, :]).item()
                    print('---------\n' + decoded_input + colored(tokenizer.decode([next_token_id], skip_special_tokens = False).replace('\n', '<lb>'), 'green'))
                print("PPL:", torch.exp(torch.tensor(loss)).item())

            original_tokens_df = pd.DataFrame(
                [(seq_i, tok_i, tok) for seq_i, tokens in enumerate(original_tokens) for tok_i, tok in enumerate(tokens)],
                columns = ['sequence_ix', 'token_ix', 'token']
            )

            prompt_indices_df = pd.DataFrame(
                list(enumerate(prompt_indices)),
                columns = ['sequence_ix', 'prompt_ix']
            )

            # Create sample (token) level dataframe
            sample_df =\
                convert_outputs_to_df_fast(input_ids, attention_mask, output['logits'])\
                .merge(original_tokens_df, how = 'left', on = ['token_ix', 'sequence_ix'])\
                .merge(prompt_indices_df, how = 'left', on = ['sequence_ix'])\
                .assign(batch_ix = cross_dl_batch_ix)

            # Create topk x layer_ix x sample level dataframe
            topk_df =\
                convert_topk_to_df(input_ids, attention_mask, output['all_topk_experts'], output['all_topk_weights'])\
                .assign(batch_ix = cross_dl_batch_ix)\
                .drop(columns = 'token_id')

            sample_dfs.append(sample_df)
            topk_dfs.append(topk_df)

            # Store pre-MLP hidden states - the fwd pass as n_layers list as BN x D, collapse to BN x n_layers x D, with BN filtering out masked items
            valid_pos = torch.where(attention_mask.cpu().view(-1) == 1) # Valid (BN, ) positions
            all_router_logits.append(torch.stack(output['all_router_logits'], dim = 1)[valid_pos][:, layers_to_keep_acts, :])
            all_pre_mlp_hidden_states.append(torch.stack(output['all_pre_mlp_hidden_states'], dim = 1)[valid_pos][:, layers_to_keep_acts, :])

            cross_dl_batch_ix += 1
            if max_batches is not None and cross_dl_batch_ix >= max_batches:
                hit_max_batches = True
                break

        # Single export path for both a fully processed dataloader and an early stop at max_batches
        _export_dl_outputs(dl_dir, sample_dfs, topk_dfs, all_router_logits, all_pre_mlp_hidden_states)

        if hit_max_batches:
            return True

    return True

# Export for every dataloader. Only every other layer (even indices) is kept for the stored
# activations; max_batches = None means no early stop.
res = run_and_export_topk(
    model,
    model_prefix,
    test_dls,
    layers_to_keep_acts = list(range(0, model_n_layers, 2)),
    max_batches = None
)

In [None]:
"""
Input mappings to questions
"""
# Save the prompt_ix -> (question_ix, question, role, prompt) table next to the activations so the
# exported samples (which carry prompt_ix) can be joined back to their source text and role.
input_df.to_csv(f'./activations/{model_prefix}/input_mappings.csv', index = False)