In [1]:
"""
Runs ablation on MMLU subjects.
"""
None

In [2]:
"""
Imports
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.loss.loss_utils import ForCausalLMLoss # Cross-entropy loss that handles label shifting
from datasets import load_dataset
import pandas as pd
import numpy as np
from tqdm import tqdm
from termcolor import colored
import importlib
import os
import gc

from utils.memory import check_memory, clear_all_cuda_memory
from utils.store_topk import convert_topk_to_df
from utils.store_outputs import convert_outputs_to_df
from utils import pretrained_models

import pickle

main_device = 'cuda:0'
seed = 123
# Apply the seed so the configured value is actually used. The forward passes below are deterministic
# (argmax decoding, shuffle = False), but seeding keeps any stochastic step added later reproducible.
torch.manual_seed(seed) # Seeds the torch RNG on all devices (CPU and CUDA)
np.random.seed(seed)
clear_all_cuda_memory()
check_memory()

All CUDA memory cleared on all devices.
Device 0: NVIDIA H100 PCIe
  Allocated: 0.00 GB
  Reserved: 0.00 GB
  Total: 79.10 GB



## Load base model

In [3]:
"""
Load the base tokenizer/model

Architectures supported currently:
- OlMoE architecture, includes OLMoE-1B-7B-0125-Instruct (1B/7B)
- Qwen2MoE architecture, inclues Qwen1.5-MoE-A2.7B-Chat (2.7B/14.3B), Qwen2-57B-A14B (14B/57B)
- Deepseek v2 architecture, includes Deepseek-v2-Lite (2.4B/15.7B), Deepseek-v2 (21B/236B)
- Deepseek v3 architecture, includes Deepseek-v3 (37B/671B), Deepseek-R1 (37B/671B), Moonlight-16B-A3B (3B/16B)
- Qwen3MoE architecture, includes Qwen3-30B-A3B, Qwen3-235B-A22B
"""
selected_model_index = 4

def get_model(index):
    """
    Look up one of the supported models by its position in the registry below.

    Params:
        @index: Position in the registry (plain list indexing, so negative values count from the end).

    Returns:
        A (HF model id, output file prefix, architecture module name) tuple.
    """
    registry = [
        ('allenai/OLMoE-1B-7B-0125-Instruct', 'olmoe', 'olmoe'),
        ('Qwen/Qwen1.5-MoE-A2.7B-Chat', 'qwen1.5moe', 'qwen2moe'),
        ('deepseek-ai/DeepSeek-V2-Lite', 'dsv2', 'dsv2'),
        ('moonshotai/Moonlight-16B-A3B', 'moonlight', 'dsv3'),
        ('Qwen/Qwen3-30B-A3B', 'qwen3moe', 'qwen3moe')
    ]
    entry = registry[index]
    hf_id, file_prefix, arch_module = entry[0], entry[1], entry[2]
    return hf_id, file_prefix, arch_module

model_id, model_prefix, model_architecture = get_model(selected_model_index)
# Left padding keeps each prompt's final token at position -1, which later cells read predictions from (logits[:, -1, :])
tokenizer = AutoTokenizer.from_pretrained(model_id, add_eos_token = False, add_bos_token = False, padding_side = 'left', trust_remote_code = True)
# Place the model on `main_device` explicitly (`.cuda()` would use whichever CUDA device is current),
# so it always matches the device the input batches are moved to in the forward-pass loop.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype = torch.bfloat16, trust_remote_code = True).to(main_device).eval()

Loading checkpoint shards:   0%|          | 0/16 [00:00<?, ?it/s]

In [4]:
"""
Load reverse-engineered forward pass functions that return topk expert IDs and weights
"""
model_module = importlib.import_module(f"utils.pretrained_models.{model_architecture}")
run_model_return_topk = getattr(model_module, f"run_{model_architecture}_return_topk")

def test_custom_forward_pass(model, pad_token_id):
    """
    Sanity-check the custom forward pass against the model's own forward pass on two short prompts.

    Asserts that the custom logits exactly match the HF logits and that the topk ID and weight lists
    have the same length, then prints a few diagnostics and the LM loss with pad positions excluded.

    Params:
        @model: The loaded HF causal LM.
        @pad_token_id: Token ID treated as padding; those positions are ignored in the loss.

    Note: relies on the notebook-level `tokenizer` and `run_model_return_topk`.
    """
    inputs = tokenizer(['Hi! I am a dog and I like to bark', 'Vegetables are good for'], return_tensors = 'pt', padding = 'max_length', truncation = True, max_length = 512).to(model.device)
    # Inference only: without no_grad the HF forward pass would build an autograd graph that is never used
    with torch.no_grad():
        original_results = model(**inputs)
        custom_results = run_model_return_topk(model, inputs['input_ids'], inputs['attention_mask'], return_hidden_states = False)
    assert torch.equal(original_results.logits, custom_results['logits']), 'Error in custom forward'
    assert len(custom_results['all_topk_experts']) == len(custom_results['all_topk_weights']), 'Length of topk IDs and weights not equal'
    print(f"Length of topk: {len(custom_results['all_topk_experts'])}")
    print(f"Topk size: {custom_results['all_topk_experts'][0].shape}")
    # NOTE(review): this prints row 1 of the first layer's topk tensor, which looks flattened over (batch x seq).
    # With left padding that row may be a pad position rather than the first real token — confirm.
    print(f"First token topk IDs: {custom_results['all_topk_experts'][0][1,]}")
    print(f"First token topk weights: {custom_results['all_topk_weights'][0][1,]}")
    # Pads are labelled -100 so the loss ignores them; masked_fill keeps the labels on the same device as input_ids
    labels = inputs['input_ids'].masked_fill(inputs['input_ids'] == pad_token_id, -100)
    loss = ForCausalLMLoss(custom_results['logits'], labels, model.config.vocab_size).detach().cpu().item()
    print(f"LM loss: {loss}")

test_custom_forward_pass(model, tokenizer.pad_token_id)

Length of topk: 48
Topk size: torch.Size([1024, 8])
First token topk IDs: tensor([ 84,  53,   1, 119,  68, 114,  16,  31])
First token topk weights: tensor([0.2197, 0.1924, 0.1465, 0.1357, 0.0820, 0.0796, 0.0762, 0.0679])
LM loss: 3.733102798461914


## Get dataset

In [5]:
""""
Get MMLU data and domains to test
"""
mmlu_ds = {
    lang: load_dataset("CohereLabs/Global-MMLU", lang, split = 'test')
    for lang in ['en']
}

display(pd.DataFrame(mmlu_ds['en']).groupby('subject', as_index = False).agg(count = ('sample_id', 'count')))

Unnamed: 0,subject,count
0,abstract_algebra,100
1,anatomy,135
2,astronomy,152
3,business_ethics,100
4,clinical_knowledge,265
5,college_biology,144
6,college_chemistry,100
7,college_computer_science,100
8,college_mathematics,100
9,college_medicine,173


In [9]:
# Maps each final domain label => the list of MMLU subjects pooled into it (insertion order is kept).
# Uncomment the statistics entry to add a high_school_statistics domain.
mmlu_domain_mappings = dict(
    math = ['elementary_mathematics'],
    # statistics = ['high_school_statistics'],
    compsci = ['high_school_computer_science', 'college_computer_science'],
    chemistry = ['high_school_chemistry'],
    biology = ['high_school_biology'],
)

def get_mmlu_df(raw_ds, domain_map, max_questions_per_domain):
    """
    Clean + prep MMLU dataset.

    Keeps only questions whose subject appears in `domain_map`, relabels each with its final domain,
    and keeps at most `max_questions_per_domain` questions per domain (the first ones in `raw_ds` order).

    Params:
        @raw_ds: Iterable of Global-MMLU rows (dict-like, with subject, sample_id, question,
            option_a..option_d and answer fields).
        @domain_map: Dict of {final_domain: [source_subject, ...]}.
        @max_questions_per_domain: Maximum number of questions kept per final domain.

    Returns:
        A DataFrame with one row per kept question and columns lang_ix (running index over kept rows),
        domain_ix (running index within the domain), source_id, stem, choices, domain and answer_char.
        If nothing matches, the DataFrame is empty but still has those columns.
    """
    source_to_domain_map = {source: domain for domain, sources in domain_map.items() for source in sources} # Map each source => domain
    # Running per-domain counts; avoids rescanning every kept row for each question (quadratic)
    domain_counts = dict.fromkeys(domain_map, 0)
    final_ds = []
    for q in raw_ds:
        domain = source_to_domain_map.get(q['subject'])
        if domain is None:
            continue
        domain_count = domain_counts[domain]
        if domain_count >= max_questions_per_domain:
            continue
        final_ds.append({
            'lang_ix': len(final_ds), # Index among all kept rows
            'domain_ix': domain_count, # Index within this domain
            'source_id': q['sample_id'],
            'stem': q['question'],
            'choices': [q['option_a'], q['option_b'], q['option_c'], q['option_d']],
            'domain': domain,
            'answer_char': q['answer']
        })
        domain_counts[domain] = domain_count + 1

    # Explicit columns so an empty selection still has the expected schema
    return pd.DataFrame(final_ds, columns = ['lang_ix', 'domain_ix', 'source_id', 'stem', 'choices', 'domain', 'answer_char'])

max_questions_per_domain = 200 # Cap per final domain, passed to get_mmlu_df

mmlu_raw_dfs_by_lang = {
    lang: get_mmlu_df(ds, mmlu_domain_mappings, max_questions_per_domain).assign(lang = lang)
    for lang, ds in mmlu_ds.items()
}

# Only meaningful when 'es'/'zh' are also loaded into mmlu_ds: checks every language yields the same questions in the same order
# assert all(mmlu_raw_dfs_by_lang['en']['source_id'] == mmlu_raw_dfs_by_lang['es']['source_id'])
# assert all(mmlu_raw_dfs_by_lang['en']['source_id'] == mmlu_raw_dfs_by_lang['zh']['source_id'])

# ignore_index keeps the index unique when several languages are stacked; q_ix is a running row id across all languages
mmlu_raw_df = pd.concat(mmlu_raw_dfs_by_lang.values(), ignore_index = True).assign(q_ix = lambda df: np.arange(len(df)))
mmlu_raw_df

Unnamed: 0,lang_ix,domain_ix,source_id,stem,choices,domain,answer_char,lang,q_ix
0,0,0,college_computer_science/test/0,The access matrix approach to protection has t...,"[the matrix, if stored directly, is large and ...",compsci,A,en,0
1,1,1,college_computer_science/test/1,An integer c is a common divisor of two intege...,"[{-6,-2, -1, 1, 2, 6}, {-6, -2, -1, 0, 1, 2, 6...",compsci,C,en,1
2,2,2,college_computer_science/test/2,"In the NoNicks operating system, the time requ...","[1:4, 1:3.5, 1:1, 1.1:1]",compsci,B,en,2
3,3,3,college_computer_science/test/3,You want to cluster 7 points into 3 clusters u...,"[C1: (3,3), C2: (4,4), C3: (6,6), C1: (3,3), C...",compsci,A,en,3
4,4,4,college_computer_science/test/4,Any set of Boolean operators that is sufficien...,"[{AND, NOT}, {NOT, OR}, {AND, OR}, {NAND}]",compsci,C,en,4
...,...,...,...,...,...,...,...,...,...
795,795,195,high_school_computer_science/test/95,"Consider the following code segment, which use...","[1 1, 1 2, 2 3, 3 2]",compsci,C,en,795
796,796,196,high_school_computer_science/test/96,A digital photo file contains data representin...,[Determining the likelihood that the photo is ...,compsci,B,en,796
797,797,197,high_school_computer_science/test/97,"In Python 3, what is ['a', 'Chemistry', 0, 1][...","[a, Chemistry, 0, 1]",compsci,B,en,797
798,798,198,high_school_computer_science/test/98,Two computers are built by different manufactu...,[The computers cannot communicate because diff...,compsci,D,en,798


In [10]:
"""
Create function to map MMLU data into an instruct-formatted string
"""
append_think = '\n<think>\n\n</think>\n\n' if model_prefix == 'qwen3moe' else ''

def prep_question(question, choices):
    """
    Format a question and its answer options as plain text.

    Params:
        @question: The question stem.
        @choices: Answer options, labelled (A), (B), ... in order.

    Returns:
        The question line, a "Choices:" line, then one "(<letter>) <option>" line per option,
        each line terminated by a newline.
    """
    header = f"Question: {question}\nChoices:\n"
    option_lines = "".join(f"({chr(65 + i)}) {option}\n" for i, option in enumerate(choices))
    return header + option_lines

# One-shot example: the first 'anatomy' question. Anatomy is not one of the subjects in mmlu_domain_mappings,
# so the example is not also an evaluated question. next() stops at the first match instead of
# materializing every anatomy row just to take the first one.
fs_ex = [
    next((q for q in mmlu_ds['en'] if q['subject'] == 'anatomy'), None)
]
assert fs_ex[0] is not None, 'No anatomy question found for the one-shot example'

base_prompt = [
    {'role': 'system', 'content': 'You will be provided with a multiple-choice question, as well as a list of possible answer choices. Respond exactly with: "The correct answer is {X}", substituting in X with the code for the correct choice.'},
    {'role': 'user', 'content': prep_question(fs_ex[0]['question'], [fs_ex[0]['option_a'], fs_ex[0]['option_b'], fs_ex[0]['option_c'], fs_ex[0]['option_d']])},
    {'role': 'assistant', 'content': 'The correct answer is ' + fs_ex[0]['answer']}
]

# The final assistant turn is left open at "The correct answer is" so the next predicted token is meant to be the answer letter
mmlu_df = \
    mmlu_raw_df\
    .assign(
        input_prompt = lambda df: df.apply(
            lambda q: tokenizer.apply_chat_template(
                base_prompt + [{'role': 'user', 'content': prep_question(q['stem'], q['choices'])}, {'role': 'assistant', 'content': f'{append_think}The correct answer is'}],
                tokenize = False,
                add_generation_prompt = False,
                continue_final_message = True # Otherwise appends eos token
            ), 
            axis = 1
        )
    )

display(mmlu_df)
print(mmlu_df['input_prompt'].iloc[0])

Unnamed: 0,lang_ix,domain_ix,source_id,stem,choices,domain,answer_char,lang,q_ix,input_prompt
0,0,0,college_computer_science/test/0,The access matrix approach to protection has t...,"[the matrix, if stored directly, is large and ...",compsci,A,en,0,<|im_start|>system\nYou will be provided with ...
1,1,1,college_computer_science/test/1,An integer c is a common divisor of two intege...,"[{-6,-2, -1, 1, 2, 6}, {-6, -2, -1, 0, 1, 2, 6...",compsci,C,en,1,<|im_start|>system\nYou will be provided with ...
2,2,2,college_computer_science/test/2,"In the NoNicks operating system, the time requ...","[1:4, 1:3.5, 1:1, 1.1:1]",compsci,B,en,2,<|im_start|>system\nYou will be provided with ...
3,3,3,college_computer_science/test/3,You want to cluster 7 points into 3 clusters u...,"[C1: (3,3), C2: (4,4), C3: (6,6), C1: (3,3), C...",compsci,A,en,3,<|im_start|>system\nYou will be provided with ...
4,4,4,college_computer_science/test/4,Any set of Boolean operators that is sufficien...,"[{AND, NOT}, {NOT, OR}, {AND, OR}, {NAND}]",compsci,C,en,4,<|im_start|>system\nYou will be provided with ...
...,...,...,...,...,...,...,...,...,...,...
795,795,195,high_school_computer_science/test/95,"Consider the following code segment, which use...","[1 1, 1 2, 2 3, 3 2]",compsci,C,en,795,<|im_start|>system\nYou will be provided with ...
796,796,196,high_school_computer_science/test/96,A digital photo file contains data representin...,[Determining the likelihood that the photo is ...,compsci,B,en,796,<|im_start|>system\nYou will be provided with ...
797,797,197,high_school_computer_science/test/97,"In Python 3, what is ['a', 'Chemistry', 0, 1][...","[a, Chemistry, 0, 1]",compsci,B,en,797,<|im_start|>system\nYou will be provided with ...
798,798,198,high_school_computer_science/test/98,Two computers are built by different manufactu...,[The computers cannot communicate because diff...,compsci,D,en,798,<|im_start|>system\nYou will be provided with ...


<|im_start|>system
You will be provided with a multiple-choice question, as well as a list of possible answer choices. Respond exactly with: "The correct answer is {X}", substituting in X with the code for the correct choice.<|im_end|>
<|im_start|>user
Question: A lesion causing compression of the facial nerve at the stylomastoid foramen will cause ipsilateral
Choices:
(A) paralysis of the facial muscles.
(B) paralysis of the facial muscles and loss of taste.
(C) paralysis of the facial muscles, loss of taste and lacrimation.
(D) paralysis of the facial muscles, loss of taste, lacrimation and decreased salivation.
<|im_end|>
<|im_start|>assistant
The correct answer is A<|im_end|>
<|im_start|>user
Question: The access matrix approach to protection has the difficulty that
Choices:
(A) the matrix, if stored directly, is large and can be clumsy to manage
(B) it is not capable of expressing complex protection requirements
(C) deciding whether a process has access to a resource is undecidabl

In [11]:
""" 
Load dataset into a dataloader
"""
from torch.utils.data import Dataset, DataLoader

class TextDataset(Dataset):
    """
    Map-style dataset that pairs pre-tokenized prompts with their MMLU metadata.

    Every per-question argument is a parallel sequence; item `idx` bundles the idx-th entry of each
    together with the idx-th row of the tokenized input_ids / attention_mask.
    """

    def __init__(self, q_indices, lang_indices, domain_indices, domains, stems, choices, answer_chars, tokenized_prompts):
        # Per-question metadata, in the same order as the prompts
        self.q_indices, self.lang_indices, self.domain_indices = q_indices, lang_indices, domain_indices
        self.domains, self.stems, self.choices, self.answer_chars = domains, stems, choices, answer_chars
        # Tokenizer output for the prompts
        self.input_ids, self.attention_mask = tokenized_prompts['input_ids'], tokenized_prompts['attention_mask']

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        # Attribute names double as the keys of the returned item
        fields = ('q_indices', 'lang_indices', 'domain_indices', 'domains', 'stems', 'choices', 'answer_chars', 'input_ids', 'attention_mask')
        return {field: getattr(self, field)[idx] for field in fields}

prompt_max_length = 1024 # Padded/truncated length for this dataloader's prompts

tokenized_prompts = tokenizer(mmlu_df['input_prompt'].tolist(), add_special_tokens = False, max_length = prompt_max_length, padding = 'max_length', truncation = True, return_tensors = 'pt')
# Must be under max length to confirm nothing was truncated (a truncated prompt fills exactly max_length positions)
max_prompt_tokens = tokenized_prompts['attention_mask'].sum(dim = 1).max()
print(max_prompt_tokens)
assert max_prompt_tokens < prompt_max_length, 'Some prompts may have been truncated; increase prompt_max_length'

mmlu_dl = DataLoader(TextDataset(
    mmlu_df['q_ix'].tolist(),
    mmlu_df['lang_ix'].tolist(),
    mmlu_df['domain_ix'].tolist(),
    mmlu_df['domain'].tolist(),
    mmlu_df['stem'].tolist(),
    mmlu_df['choices'].tolist(),
    mmlu_df['answer_char'].tolist(),
    tokenized_prompts
), batch_size = 4, shuffle = False)

print(next(iter(mmlu_dl)))

tensor(554)
{'q_indices': tensor([0, 1, 2, 3]), 'lang_indices': tensor([0, 1, 2, 3]), 'domain_indices': tensor([0, 1, 2, 3]), 'domains': ['compsci', 'compsci', 'compsci', 'compsci'], 'stems': ['The access matrix approach to protection has the difficulty that', 'An integer c is a common divisor of two integers x and y if and only if c is a divisor of x and c is a divisor of y. Which of the following sets of integers could possibly be the set of all common divisors of two integers?', 'In the NoNicks operating system, the time required by a single file-read operation has four nonoverlapping components:\ndisk seek time-25 msec\ndisk latency time-8 msec\ndisk transfer time- 1 msec per 1,000 bytes\noperating system overhead-1 msec per 1,000 bytes + 10 msec\nIn version 1 of the system, the file read retrieved blocks of 1,000 bytes. In version 2, the file read (along with the underlying layout on disk) was modified to retrieve blocks of 4,000 bytes. The ratio of-the time required to read a lar

In [None]:
""" 
Load dataset into a dataloader. The dataloader returns the original tokens - this is important for BPE tokenizers as otherwise it's difficult to reconstruct the correct string later!
"""
from torch.utils.data import Dataset, DataLoader

class ReconstructableTextDataset(Dataset):
    """
    Tokenized prompts that also keep, for every token position, the exact substring of the source text it came from.
    """

    def __init__(self, raw_texts, q_indices, tokenizer, max_length):
        """
        Creates a dataset object that also returns a B x N list of the original tokens in the same position as the input ids.

        Params:
            @raw_texts: A list of samples of text dataset.
            @q_indices: A list of question indices attached to the raw texts.
            @tokenizer: A HF tokenizer object. Offset mappings are requested, so presumably it must be a fast tokenizer.
            @max_length: Length (N) every sample is padded / truncated to.
        """
        tokenized = tokenizer(raw_texts, add_special_tokens = False, max_length = max_length, padding = 'max_length', truncation = True, return_offsets_mapping = True, return_tensors = 'pt')

        self.input_ids = tokenized['input_ids']
        self.attention_mask = tokenized['attention_mask']
        self.offset_mapping = tokenized['offset_mapping']
        self.q_indices = q_indices
        self.original_tokens = self.get_original_tokens(raw_texts)

    def get_original_tokens(self, texts):
        """
        Return the original substring for each B x N token position. This is important for reconstructing the original text when BPE tokenizers are used.

        Reads `self.input_ids` and `self.offset_mapping` (a B x N x 2 tensor of (start, end) character offsets) set in `__init__`.

        Params:
            @texts: The raw texts that were tokenized, in the same order as the rows of `self.input_ids`.

        Returns:
            A list of length B, each a list of length N, where entry [i][j] is texts[i][start:end] for token j's
            (start, end) offset. Positions with a (0, 0) offset (e.g. pads) map to an empty string.
        """
        all_token_substrings = []
        for i in range(self.input_ids.shape[0]):
            text = texts[i]
            # One .tolist() per row (not per token) to get plain (start, end) int pairs
            row_offsets = self.offset_mapping[i].tolist()
            # A (0, 0) offset slices to "" (text[0:0] == ""), so pad positions need no special case
            all_token_substrings.append([text[start_char:end_char] for start_char, end_char in row_offsets])

        return all_token_substrings

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {'input_ids': self.input_ids[idx], 'attention_mask': self.attention_mask[idx], 'original_tokens': self.original_tokens[idx], 'q_indices': self.q_indices[idx]}
    
def collate_fn(batch):
    """
    Collate dataset items into a batch.

    Tensor fields (input_ids, attention_mask) are stacked along a new batch dimension. List-like fields
    (original_tokens, q_indices) stay as plain Python lists, since default collation would transpose the
    per-position token strings.
    """
    stacked = {key: torch.stack([item[key] for item in batch], dim = 0) for key in ('input_ids', 'attention_mask')}
    listed = {key: [item[key] for item in batch] for key in ('original_tokens', 'q_indices')}
    return {**stacked, **listed}

test_max_length = 768 # Padded/truncated length used for the forward passes

test_ds = ReconstructableTextDataset(mmlu_df['input_prompt'].tolist(), mmlu_df['q_ix'].tolist(), tokenizer, max_length = test_max_length)

# Must be under max length to confirm nothing was truncated. Checked on the dataset that is actually used
# (a truncated prompt fills exactly test_max_length positions), not on a separate tokenization at another length.
max_prompt_tokens = test_ds.attention_mask.sum(dim = 1).max()
print(max_prompt_tokens)
assert max_prompt_tokens < test_max_length, 'Some prompts may have been truncated; increase test_max_length'

test_dl = DataLoader(
    test_ds,
    batch_size = 8,
    shuffle = False,
    collate_fn = collate_fn
)

tensor(554)


## Run first pass

In [15]:
"""
Run forward passes - store token level dataframe and topk-level dataframe
"""
sample_dfs = []
topk_dfs = []

for batch_ix, batch in tqdm(enumerate(test_dl), total = len(test_dl)):

    input_ids = batch['input_ids'].to(main_device)
    attention_mask = batch['attention_mask'].to(main_device)
    original_tokens = batch['original_tokens']
    q_indices = batch['q_indices']

    with torch.no_grad():
        output = run_model_return_topk(model, input_ids, attention_mask, return_hidden_states = True)

    # Check no bugs by validating output/perplexity
    if batch_ix == 0:
        loss = ForCausalLMLoss(output['logits'], torch.where(input_ids == tokenizer.pad_token_id, torch.tensor(-100), input_ids), model.config.vocab_size).detach().cpu().item()
        for i in range(min(5, input_ids.size(0))):
            decoded_input = tokenizer.decode(input_ids[i, :], skip_special_tokens = True)
            next_token_id = torch.argmax(output['logits'][i, -1, :]).item()
            print('---------\n' + decoded_input + colored(tokenizer.decode([next_token_id], skip_special_tokens = False).replace('\n', '<lb>'), 'green'))
        print(f"PPL:", torch.exp(torch.tensor(loss)).item())
    
    # Create mapping of seq_ix to q_ix, so we can drop batch_ix and seq_ix and return q_ix
    seq_to_q_map = pd.DataFrame({'sequence_ix': list(range(0, input_ids.shape[0])), 'q_ix': batch['q_indices']})

    # Convert original tokens back to (seq_ix, token_ix) level for later storage
    original_tokens_df = pd.DataFrame(
        [(seq_i, tok_i, tok) for seq_i, tokens in enumerate(original_tokens) for tok_i, tok in enumerate(tokens)], 
        columns = ['sequence_ix', 'token_ix', 'token']
    )

    # Final token outputs - get the decoded outputs for the final tokens only, since convert_outputs_to_df only returns output tokens as IDs
    question_token_outputs = tokenizer.batch_decode(torch.argmax(output['logits'][:, -1, :], dim = 1).tolist())
    question_token_outputs_df = pd.DataFrame({'sequence_ix': list(range(0, input_ids.shape[0])), 'question_output_token': question_token_outputs})

    # Create sample (token) level dataframe
    sample_df =\
        convert_outputs_to_df(input_ids, attention_mask, output['logits'])\
        .merge(question_token_outputs_df, how = 'left', on = ['sequence_ix'])\
        .merge(original_tokens_df, how = 'left', on = ['token_ix', 'sequence_ix'])\
        .merge(seq_to_q_map, how = 'inner', on = ['sequence_ix'])\
        .drop(columns = ['sequence_ix'])
    
    # Create topk x layer_ix x sample level dataframe
    topk_df =\
        convert_topk_to_df(input_ids, attention_mask, output['all_topk_experts'], output['all_topk_weights'])\
        .merge(seq_to_q_map, how = 'inner')\
        .drop(columns = ['sequence_ix', 'token_id'])

    sample_dfs.append(sample_df)
    topk_dfs.append(topk_df)


  0%|          | 0/100 [00:00<?, ?it/s]

---------
system
You will be provided with a multiple-choice question, as well as a list of possible answer choices. Respond exactly with: "The correct answer is {X}", substituting in X with the code for the correct choice.
user
Question: A lesion causing compression of the facial nerve at the stylomastoid foramen will cause ipsilateral
Choices:
(A) paralysis of the facial muscles.
(B) paralysis of the facial muscles and loss of taste.
(C) paralysis of the facial muscles, loss of taste and lacrimation.
(D) paralysis of the facial muscles, loss of taste, lacrimation and decreased salivation.

assistant
The correct answer is A
user
Question: The access matrix approach to protection has the difficulty that
Choices:
(A) the matrix, if stored directly, is large and can be clumsy to manage
(B) it is not capable of expressing complex protection requirements
(C) deciding whether a process has access to a resource is undecidable
(D) there is no way to express who has rights to change the access

100%|██████████| 100/100 [26:15<00:00, 15.75s/it]


In [18]:
"""
Save data
"""
sample_df = pd.concat(sample_dfs).merge(mmlu_df, how = 'inner', on = 'q_ix').drop(columns = 'input_prompt')
topk_df = pd.concat(topk_dfs)

sample_df.to_csv(f'{model_prefix}-samples.csv', mode = 'w', index = False)
topk_df.to_csv(f'{model_prefix}-topks.csv', mode = 'w', index = False)

In [None]:
"""
Quick check for accuracy
"""
pd.concat(sample_dfs).merge(mmlu_df, how = 'inner', on = 'q_ix')\
    .groupby(['q_ix', 'lang_ix', 'domain_ix', 'source_id', 'domain', 'question_output_token', 'answer_char'], as_index = False)\
    .agg(n_tokens = ('q_ix', 'count'))\
    .assign(is_correct = lambda df: np.where(df['question_output_token'].str.strip() == df['answer_char'], 1, 0))