In [None]:
"""
This notebook runs forward passes on text samples and stores: (1) pre-MLP activations; (2) router logits; (3) expert outputs (optional);
 (4) top-k expert selections; (5) sample metadata.
"""
None

In [None]:
"""
Imports
"""
import importlib
import os
import pickle

import numpy as np
import pandas as pd
import safetensors
import safetensors.torch # Needed for safetensors.torch.save_file; a bare `import safetensors` does not import the submodule itself
import torch
from datasets import load_dataset, Dataset
from termcolor import colored
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.loss.loss_utils import ForCausalLMLoss # Cross-entropy loss that handles label shifting

from utils.memory import check_memory, clear_all_cuda_memory
from utils.loader import load_model_and_tokenizer, load_custom_forward_pass
from utils.store_topk import convert_topk_to_df
from utils.store_outputs import convert_outputs_to_df_fast

# Run-wide configuration
main_device = 'cuda:0' # Passed to the model loader and used as the target device for input batches
seed = 123 # Seed for dataset shuffling
ws = '/workspace/interpretable-moes-analysis' # Workspace root under which exports are written

# Memory housekeeping helpers from utils.memory, run before any model is loaded
clear_all_cuda_memory()
check_memory()

# Below are for Mamba replicability - can remove if remove all SSMs
# os.environ['MAMBA_DISABLE_CUDA_KERNELS'] = '1'
# os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':16:8'
# torch.use_deterministic_algorithms(True, warn_only = False)

## Load base model

In [None]:
"""
Load the base tokenizer/model
"""
# Model identifier: passed to the loader and reused later as the export sub-directory name
model_prefix = 'gpt-oss-20b'
# The loader returns five values; the last one is not used in this notebook and is discarded
tokenizer, model, model_architecture, model_n_moe_layers, _ = load_model_and_tokenizer(model_prefix, device = main_device)

# Memory check after the model has been loaded (helper from utils.memory)
check_memory()

In [None]:
"""
Load custom forward pass and verify equality to base model forward pass
"""
# Forward-pass function for this architecture; it is later handed to the export routine as `run_model_return_states`
# NOTE(review): the cell header says equality with the base forward pass is verified - presumably inside load_custom_forward_pass; confirm
run_forward_return_metadata = load_custom_forward_pass(model_architecture, model, tokenizer)

## Get dataset

In [None]:
"""
Load dataset - C4 + HPLT (en/zh/es)
"""
def load_raw_ds():
    """
    Build (or load from cache) the raw multilingual text samples.

    On a cache hit the JSONL cache file is read back and returned as-is. On a cache miss, shuffled samples are
    streamed from C4 (validation split) and HPLT 2.0 cleaned (train split) for English, Chinese and Spanish,
    permuted using the notebook-level `seed`, written to the cache file, and returned.

    Returns:
        A list of dicts with keys `text` and `source` (`en`/`zh`/`es`).
    """
    CACHE_FILE = '/workspace/data/pretrain.jsonl'
    if os.path.exists(CACHE_FILE):
        print('Loading cached dataset...')
        return load_dataset('json', data_files = CACHE_FILE, split = 'train').to_list()

    # Create the cache directory up front so a missing folder fails before (not after) the slow streaming step
    os.makedirs(os.path.dirname(CACHE_FILE), exist_ok = True)

    rng = np.random.default_rng(seed = seed)

    def get_hplt(lang): # eng_Latn/zho_Hans/spa_Latn
        return load_dataset('HPLT/HPLT2.0_cleaned', lang, split = 'train', streaming = True).shuffle(seed = seed, buffer_size = 50_000)

    def get_c4(lang): # en/zh/es
        return load_dataset('allenai/c4', lang, split = 'validation', streaming = True).shuffle(seed = seed, buffer_size = 50_000)

    def get_data(ds, n_samples, data_source): # en/zh/es
        # zip() against range() stops after n_samples without pulling an extra item from the stream,
        # and stops early (returning fewer rows) if the stream runs out first
        return [
            {'text': sample['text'], 'source': data_source}
            for _, sample in zip(range(n_samples), ds)
        ]

    combined_ds =\
        get_data(get_c4('en'), 300, 'en') + get_data(get_hplt('eng_Latn'), 600, 'en') +\
        get_data(get_c4('zh'), 100, 'zh') + get_data(get_hplt('zho_Hans'), 200, 'zh') +\
        get_data(get_c4('es'), 100, 'es') + get_data(get_hplt('spa_Latn'), 200, 'es')

    # Interleave languages/sources so downstream chunks are not grouped by source
    combined_ds = [combined_ds[i] for i in rng.permutation(len(combined_ds))]

    print('Caching dataset to disk...')
    Dataset.from_list(combined_ds).to_json(CACHE_FILE, orient = 'records', lines = True, force_ascii = False)

    return combined_ds

raw_data = load_raw_ds()

In [None]:
"""
Create sample sequences
"""
# One row per raw sample; prompt_ix is a 0-based row id in raw_data order, used downstream to tie tokens back to their prompt
input_df  = pd.DataFrame(raw_data).assign(prompt_ix = lambda df: list(range(len(df))))

# Display the frame as the cell output
input_df

In [None]:
""" 
Load dataset into a dataloader which returns original tokens - important for BPE tokenizers to reconstruct the correct string later
"""
from utils.dataset import ReconstructableTextDataset, stack_collate
from torch.utils.data import DataLoader

def chunk_list(input_list, max_length):
    """
    Split a sliceable sequence into consecutive pieces of at most `max_length` items.

    Params:
        @input_list: Any object supporting `len()` and slice indexing.
        @max_length: (int) Maximum number of items per piece.

    Returns:
        A list of slices in original order; only the last one may hold fewer than `max_length` items.
    """
    pieces = []
    for start in range(0, len(input_list), max_length):
        stop = start + max_length
        pieces.append(input_list[start:stop])
    return pieces

# Split the prompts into fixed-size groups - each group gets its own DataLoader and becomes one export break
prompts_per_export = 250
max_seq_length = 512
batch_size = 32

test_dls = [
    DataLoader(
        ReconstructableTextDataset(
            chunk_df['text'].tolist(),
            tokenizer,
            max_length = max_seq_length,
            prompt_ix = chunk_df['prompt_ix'].tolist()
        ),
        batch_size = batch_size,
        shuffle = False, # Keep prompt order stable across runs
        collate_fn = stack_collate
    )
    for chunk_df in tqdm(chunk_list(input_df, prompts_per_export))
]

## Run exports

In [None]:
""" 
Run forward passes + export data
"""
@torch.no_grad()
def run_and_export_topk(model, tokenizer, *, model_prefix: str, run_model_return_states, dls: list[DataLoader], layer_indices: list[int], return_expert_outputs: bool = False, retain_topk: int|None = None):
    """
    Run forward passes over each dataloader and export token metadata, top-k routing, router logits and
    pre-MLP hidden states to disk (one sub-directory per dataloader).

    Relies on notebook-level globals: `ws` (workspace root), `main_device` (device for input batches) and
    `input_df` (prompt-level dataframe, saved next to the exports).

    Params:
        @model: The model to run.
        @tokenizer: The tokenizer object corresponding to the model.
        @model_prefix: (str) Model prefix used for file saving.
        @run_model_return_states: (function) Called as `fn(model, input_ids, attention_mask, return_hidden_states = True)`;
          returns a dict with keys: `logits`, `all_topk_experts`, `all_topk_weights`, `all_router_logits`,
          `all_pre_mlp_hidden_states`, `all_expert_outputs`.
        @dls: List of DataLoaders whose batches provide `prompt_ix`, `input_ids`, `attention_mask`, and `original_tokens`.
          Each DataLoader is exported to its own sub-directory.
        @layer_indices: (list(int)) List of layer indices (0-indexed by MoE layers) to keep from the per-layer outputs.
        @return_expert_outputs: (bool) Whether to export expert outputs (skipped if the forward pass returns None for them).
        @retain_topk: (int|None) If return_expert_outputs=True, the number of leading top-k slots for which expert
          outputs are kept; None keeps all of them.

    Files written under `{ws}/experiments/geometry/activations/{model_prefix}`:
        - `metadata.pkl`: Pickled dict `{'layer_mappings': layer_indices}`.
        - `prompts.feather`: Copy of `input_df`.
        - `{dl_ix:02d}/samples.feather`: A token-level df with input/output ids + text (removes masked tokens).
        - `{dl_ix:02d}/topks.feather`: A sample token x layer_ix x topk_ix level df giving the selected expert ID.
        - `{dl_ix:02d}/activations.safetensors` with tensors:
            - `all_router_logits`: (n_samples, layer_indices, n_experts) with the pre-softmax router logits.
            - `all_pre_mlp_hs`: (n_samples, layer_indices, D) with states. Each n_sample corresponds to a row of the samples df.
            - `all_expert_outputs`: (n_samples, layer_indices, retain_topk, D) of MLP outputs; only present when exported.

    Returns:
        True once every dataloader has been exported.
    """
    cross_dl_batch_ix = 0 # Counts batches across all dataloaders; the sanity check below only runs on the very first one

    # Save metadata
    output_dir = f'{ws}/experiments/geometry/activations/{model_prefix}'
    os.makedirs(output_dir, exist_ok = True)
    with open(f'{output_dir}/metadata.pkl', 'wb') as f:
        pickle.dump({'layer_mappings': layer_indices}, f)
    input_df.to_feather(f'{output_dir}/prompts.feather')

    # Iterate through dataloaders
    for dl_ix, dl in enumerate(dls):
        print(f"Processing {str(dl_ix)} of {len(dls)}...")
        dl_dir = f"{output_dir}/{dl_ix:02d}"
        os.makedirs(dl_dir, exist_ok = True)

        all_router_logits = []
        all_pre_mlp_hs = []
        all_expert_outputs = []
        sample_dfs = []
        topk_dfs = []

        for batch in tqdm(dl, total = len(dl)):
            input_ids = batch['input_ids'].to(main_device)
            attention_mask = batch['attention_mask'].to(main_device)
            original_tokens = batch['original_tokens']
            prompt_indices = batch['prompt_ix']

            output = run_model_return_states(model, input_ids, attention_mask, return_hidden_states = True)

            # Check no bugs by validating output/perplexity
            if cross_dl_batch_ix == 0:
                # Ignore padded positions via the attention mask rather than by comparing against the pad token id:
                # this stays correct when the pad id doubles as a real token or is unset
                labels = input_ids.masked_fill(attention_mask == 0, -100)
                loss = ForCausalLMLoss(output['logits'], labels, output['logits'].size(-1)).item()
                for i in range(min(20, input_ids.size(0))):
                    decoded_input = tokenizer.decode(input_ids[i, :], skip_special_tokens = False)
                    # NOTE(review): this reads the prediction at the final position, which is padding for right-padded rows
                    next_token_id = torch.argmax(output['logits'][i, -1, :]).item()
                    print('---------\n' + decoded_input + colored(tokenizer.decode([next_token_id], skip_special_tokens = False).replace('\n', '<lb>'), 'green'))
                print('PPL:', torch.exp(torch.tensor(loss)).item())

            original_tokens_df = pd.DataFrame(
                [(seq_i, tok_i, tok) for seq_i, tokens in enumerate(original_tokens) for tok_i, tok in enumerate(tokens)],
                columns = ['sequence_ix', 'token_ix', 'token']
            )

            prompt_indices_df = pd.DataFrame(
                [(seq_i, seq_source) for seq_i, seq_source in enumerate(prompt_indices)],
                columns = ['sequence_ix', 'prompt_ix']
            )

            # Create sample (token) level dataframe
            sample_df =\
                convert_outputs_to_df_fast(input_ids, attention_mask, output['logits'])\
                .merge(original_tokens_df, how = 'inner', on = ['token_ix', 'sequence_ix'])\
                .merge(prompt_indices_df, how = 'inner', on = ['sequence_ix'])

            # Create topk x layer_ix x sample level dataframe
            topk_df =\
                convert_topk_to_df(input_ids, attention_mask, output['all_topk_experts'], output['all_topk_weights'])\
                .merge(sample_df[['sequence_ix', 'token_ix', 'prompt_ix']], how = 'inner', on = ['sequence_ix', 'token_ix'])

            # sequence_ix is only unique within a batch, so drop it once prompt_ix has been attached
            sample_df = sample_df.drop(columns = ['sequence_ix'])
            topk_df = topk_df.drop(columns = ['sequence_ix'])

            sample_dfs.append(sample_df)
            topk_dfs.append(topk_df)

            valid_pos = torch.where(attention_mask.cpu().view(-1) == 1) # Valid (BN, ) positions
            # Per-layer outputs arrive as n_layers lists of (BN, ...) tensors: stack to (BN, n_layers, ...), drop masked
            # positions, keep the requested layers. `.cpu()` keeps the accumulated chunk off the GPU (no-op if already on CPU).
            all_router_logits.append(torch.stack(output['all_router_logits'], dim = 1)[valid_pos][:, layer_indices, :].cpu())
            all_pre_mlp_hs.append(torch.stack(output['all_pre_mlp_hidden_states'], dim = 1)[valid_pos][:, layer_indices, :].cpu())
            if return_expert_outputs and all(x is not None for x in output['all_expert_outputs']): # Some models never return this regardless
                # (BN, n_layers, topk, D) - keep only the first `retain_topk` top-k slots (all of them when retain_topk is None)
                all_expert_outputs.append(torch.stack(output['all_expert_outputs'], dim = 1)[valid_pos][:, layer_indices, 0:retain_topk, :].cpu())

            cross_dl_batch_ix += 1

        sample_df = pd.concat(sample_dfs, ignore_index = True)
        topk_df = pd.concat(topk_dfs, ignore_index = True)

        sample_df.to_feather(f'{dl_dir}/samples.feather')
        topk_df.to_feather(f'{dl_dir}/topks.feather')

        tensors = {
            'all_router_logits': torch.cat(all_router_logits, dim = 0).contiguous(),
            'all_pre_mlp_hs': torch.cat(all_pre_mlp_hs, dim = 0).contiguous()
        }
        if len(all_expert_outputs) > 0:
            tensors['all_expert_outputs'] = torch.cat(all_expert_outputs, dim = 0).contiguous()

        safetensors.torch.save_file(
            tensors,
            f"{dl_dir}/activations.safetensors",
            metadata = {'layer_indices': ','.join(map(str, layer_indices))} # safetensors metadata values must be strings
        )

    return True

# Export every layer index in range(model_n_moe_layers). Expert outputs are disabled, so `retain_topk` has no effect on this run.
# The function returns True; the exported data is written to disk rather than returned.
res = run_and_export_topk(
    model,
    tokenizer, 
    model_prefix = model_prefix,
    run_model_return_states = run_forward_return_metadata,
    dls = test_dls,
    layer_indices = list(range(model_n_moe_layers)),
    return_expert_outputs = False,
    retain_topk = 1
)