## Init

In [None]:
"""
Imports
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.loss.loss_utils import ForCausalLMLoss # Cross-entropy loss that handles label shifting
from datasets import load_dataset
import pandas as pd
import numpy as np
from tqdm import tqdm
from termcolor import colored
import importlib 

from utils.memory import check_memory, clear_all_cuda_memory
from utils.store_topk import convert_topk_to_df
from utils.store_outputs import convert_outputs_to_df
from utils.vocab import export_vocab_as_csv
from utils import pretrained_models

# Device that the model and every input batch are placed on
main_device = 'cuda:0'

# Seed the torch and numpy RNGs so any stochastic step in this notebook is repeatable across kernel restarts
seed = 1234
torch.manual_seed(seed)
np.random.seed(seed)

# Start from a clean GPU state and report the memory available before loading the model
clear_all_cuda_memory()
check_memory()

## Load base model

In [None]:
"""
Load the base tokenizer/model

Architectures supported currently:
- OlMoE architecture, includes OLMoE-1B-7B-0125-Instruct (1B/7B)
- Qwen2MoE architecture, includes Qwen1.5-MoE-A2.7B-Chat (2.7B/14.3B), Qwen2-57B-A14B (14B/57B)
- Deepseek v2 architecture, includes Deepseek-v2-Lite (2.4B/15.7B), Deepseek-v2 (21B/236B)
- Deepseek v3 architecture, includes Deepseek-v3 (37B/671B), Deepseek-R1 (37B/671B), Moonlight-16B-A3B (3B/16B)
"""
# Position of the checkpoint to load in the list inside get_model:
# 0 = OLMoE-1B-7B-0125-Instruct, 1 = Qwen1.5-MoE-A2.7B-Chat, 2 = DeepSeek-V2-Lite, 3 = Moonlight-16B-A3B
selected_model_index = 1

def get_model(index):
    """
    Look up one of the supported checkpoints by its position in the list below.

    Params:
        @index: Integer position in the list of supported models.

    Returns:
        A tuple of (Hugging Face model id, short prefix for output files, architecture key).
    """
    supported_models = [
        ('allenai/OLMoE-1B-7B-0125-Instruct', 'olmoe', 'olmoe'),
        ('Qwen/Qwen1.5-MoE-A2.7B-Chat', 'qwen1.5moe', 'qwen2moe'),
        ('deepseek-ai/DeepSeek-V2-Lite', 'dsv2', 'dsv2'),
        ('moonshotai/Moonlight-16B-A3B', 'moonlight', 'dsv3')
    ]
    hf_id, prefix, architecture = supported_models[index]
    return hf_id, prefix, architecture

model_id, model_prefix, model_architecture = get_model(selected_model_index)

# Left padding keeps the final prompt token at the last position of every padded row
tokenizer = AutoTokenizer.from_pretrained(model_id, add_eos_token = False, add_bos_token = False, padding_side = 'left', trust_remote_code = True)

# Place the model on main_device (rather than the implicit current CUDA device) so it always matches the device input batches are moved to
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype = torch.bfloat16, trust_remote_code = True).to(main_device).eval()

In [None]:
"""
Load reverse-engineered forward pass functions that return topk expert IDs and weights
"""
# Import utils.pretrained_models.<architecture> by name and fetch its run_<architecture>_return_topk function,
# so the rest of the notebook can call a single entry point regardless of which model was selected
model_module = importlib.import_module(f"utils.pretrained_models.{model_architecture}")
run_model_return_topk = getattr(model_module, f"run_{model_architecture}_return_topk")

def test_custom_forward_pass(model, pad_token_id):
    """
    Sanity-check the custom forward pass against the model's own forward pass on two short prompts.

    Uses the notebook-level `tokenizer` and `run_model_return_topk`. Prints the number of topk entries,
    the shape and a sample of the topk IDs/weights, and the LM loss on the test prompts.

    Params:
        @model: The loaded causal LM; inputs are moved to `model.device`.
        @pad_token_id: Token id used for padding; these positions are replaced by -100 in the loss labels.

    Raises:
        AssertionError if the custom logits differ from the original logits, or if the topk expert
        and topk weight lists have different lengths.
    """
    inputs = tokenizer(['Hi! I am a dog and I like to bark', 'Vegetables are good for'], return_tensors = 'pt', padding = 'max_length', truncation = True, max_length = 512).to(model.device)

    # Inference-only comparison: no_grad avoids building an autograd graph for the two forward passes
    with torch.no_grad():
        original_results = model(**inputs)
        custom_results = run_model_return_topk(model, inputs['input_ids'], inputs['attention_mask'])

    assert torch.equal(original_results.logits, custom_results['logits']), 'Error in custom forward'
    assert len(custom_results['all_topk_experts']) == len(custom_results['all_topk_weights']), 'Length of topk IDs and weights not equal'
    print(f"Length of topk: {len(custom_results['all_topk_experts'])}")
    print(f"Topk size: {custom_results['all_topk_experts'][0].shape}")
    # NOTE(review): [1,] selects row index 1 of the first topk entry, not row 0 - confirm this is the intended row
    print(f"First token topk IDs: {custom_results['all_topk_experts'][0][1,]}")
    print(f"First token topk weights: {custom_results['all_topk_weights'][0][1,]}")

    # Labels of -100 at padded positions exclude them from the loss
    loss = ForCausalLMLoss(custom_results['logits'], torch.where(inputs['input_ids'] == pad_token_id, torch.tensor(-100), inputs['input_ids']), model.config.vocab_size).detach().cpu().item()
    print(f"LM loss: {loss}")

# Run the check once now so a mismatch in the custom forward pass surfaces before any MMLU batches are processed
test_custom_forward_pass(model, tokenizer.pad_token_id)

## Get MMLU

In [None]:
""""
Get MMLU data and domains to test
"""
raw_mmlu_ds = load_dataset("cais/mmlu", 'all', split = 'test')

# Count questions per subject in a single pass over the subject column (most frequent first),
# instead of rescanning the whole dataset once per subject
print('\n'.join(
    f"{domain}: {n_questions}"
    for domain, n_questions in pd.Series(list(raw_mmlu_ds['subject'])).value_counts().items()
))

# Dict containing {final_domains: [source1, source2], ...}
mmlu_domain_mappings = {
    'math': ['elementary_mathematics'], 
    'statistics': ['high_school_statistics'],
    'cs': ['high_school_computer_science', 'college_computer_science'], # 100 each
    'chemistry': ['high_school_chemistry'],
    'biology': ['high_school_biology'],
    'nutrition': ['nutrition'],
    'psych': ['high_school_psychology']
}

# Now let's put the MMLU questions into a list grouped by domain
def group_mmlu_ds(raw_ds, domain_map, max_questions_per_domain):
    """
    Group raw MMLU questions by target domain, keeping at most a fixed number of questions per domain.

    Params:
        @raw_ds: Iterable of question dicts with keys `question`, `choices`, `subject`, and `answer`
          (the 0-based index of the correct choice).
        @domain_map: Dict of {domain: [source_subject, ...]} describing which subjects feed each domain.
        @max_questions_per_domain: Cap on questions per domain; once reached, further questions for that
          domain are skipped in dataset order.

    Returns:
        A list with one dict per domain, in `domain_map` order, with keys `domain`, `sources`, and
        `questions`. Each question dict holds `question`, `choices`, `n_choices`, `domain`, `source`,
        `answer_index`, and `answer_char` (A, B, C, ...).
    """
    # Built from the domain_map argument (not the notebook-level mapping) so the function honors what the caller passes
    source_to_domain_map = {source: domain for domain, sources in domain_map.items() for source in sources} # Map each source => domain
    final_ds = {domain: {'sources': sources, 'questions': []} for domain, sources in domain_map.items()} # Create empty dict to fill

    for q in raw_ds:
        if q['subject'] not in source_to_domain_map:
            continue
        domain = source_to_domain_map[q['subject']]
        domain_questions = final_ds[domain]['questions']
        if len(domain_questions) >= max_questions_per_domain:
            continue
        domain_questions.append({
            'question': q['question'],
            'choices': q['choices'],
            'n_choices': len(q['choices']),
            'domain': domain,
            'source': q['subject'],
            'answer_index': q['answer'],
            'answer_char': chr(65 + q['answer']) # 0 => 'A', 1 => 'B', ...
        })

    return [{'domain': domain, **values} for domain, values in final_ds.items()] # Convert back to list of dicts

# Keep at most 200 questions per domain, then print how many questions each domain ended up with
mmlu_ds_grouped = group_mmlu_ds(raw_mmlu_ds, mmlu_domain_mappings, 200)
print([len(domain['questions']) for domain in mmlu_ds_grouped])

In [None]:
"""
Create function to map MMLU data into an instruct-formatted string
"""
def prep_question(question, choices):
    """
    Render one multiple-choice question as the text block shown to the model.

    Params:
        @question: The question text.
        @choices: The answer options, in order; they are lettered (A), (B), (C), ...

    Returns:
        A string with the question line, a `Choices:` line, and one lettered line per option,
        each terminated by a newline.
    """
    option_lines = [f"({chr(65 + position)}) {text}\n" for position, text in enumerate(choices)]
    return f"Question: {question}\nChoices:\n" + ''.join(option_lines)

# Few-shot examples: the first question of three subjects that are not part of mmlu_domain_mappings,
# so no evaluated question doubles as an example. next() stops at the first match instead of
# materializing every question of the subject.
fs_ex = [
    next(q for q in raw_mmlu_ds if q['subject'] == subject)
    for subject in ['anatomy', 'machine_learning', 'astronomy']
]

# System instruction followed by one user/assistant exchange per few-shot example
base_prompt = [
    {'role': 'system', 'content': 'You will be provided with a multiple-choice question, as well as a list of possible answer choices. Respond exactly with: "The correct answer is {X}", substituting in X with the code for the correct choice.'}
] + [
    turn
    for ex in fs_ex
    for turn in (
        {'role': 'user', 'content': prep_question(ex['question'], ex['choices'])},
        {'role': 'assistant', 'content': 'The correct answer is ' + chr(65 + ex['answer'])}
    )
]

# One row per selected question: the question fields, a running index (q_ix) across all domains,
# and the full chat-formatted prompt ending in the unfinished assistant turn "The correct answer is"
mmlu_df = pd.DataFrame([
    {
        **q,
        'q_ix': q_ix,
        'input_prompt': tokenizer.apply_chat_template(
            [
                *base_prompt,
                {'role': 'user', 'content': prep_question(q['question'], q['choices'])},
                {'role': 'assistant', 'content': 'The correct answer is'}
            ],
            tokenize = False,
            add_generation_prompt = False,
            continue_final_message = True # Otherwise appends eos token
        )
    }
    for q_ix, q in enumerate(
        domain_q for domain in mmlu_ds_grouped for domain_q in domain['questions'] # Flatten domains in order
    )
])

display(mmlu_df)
print(mmlu_df['input_prompt'][0])

In [None]:
""" 
Load dataset into a dataloader
"""
from torch.utils.data import Dataset, DataLoader

class TextDataset(Dataset):
    """
    Dataset that pairs each question's metadata with its tokenized prompt.

    All constructor arguments are parallel, index-aligned sequences; `tokenized_prompts` must expose
    `input_ids` and `attention_mask` by key. The dataset length is the length of `input_ids`.
    """
    # Per-example fields, in the order they appear in each returned item
    _ITEM_FIELDS = ('q_indices', 'questions', 'choices', 'n_choices', 'domains', 'answer_chars', 'input_ids', 'attention_mask')

    def __init__(self, q_indices, questions, choices, n_choices, domains, answer_chars, tokenized_prompts):
        self.q_indices, self.questions, self.choices = q_indices, questions, choices
        self.n_choices, self.domains, self.answer_chars = n_choices, domains, answer_chars
        self.input_ids = tokenized_prompts['input_ids']
        self.attention_mask = tokenized_prompts['attention_mask']

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        # Every field is stored under an attribute with the same name as its item key
        return {field: getattr(self, field)[idx] for field in self._ITEM_FIELDS}

# Every prompt is padded (or truncated) to exactly this many tokens
max_prompt_tokens = 1200
tokenized_prompts = tokenizer(mmlu_df['input_prompt'].tolist(), add_special_tokens = False, max_length = max_prompt_tokens, padding = 'max_length', truncation = True, return_tensors = 'pt')

# Must be under max length to confirm nothing was truncated; fail loudly rather than evaluating on cut-off prompts
longest_prompt_tokens = tokenized_prompts['attention_mask'].sum(dim = 1).max()
print(longest_prompt_tokens)
assert longest_prompt_tokens.item() < max_prompt_tokens, f'Longest prompt fills all {max_prompt_tokens} tokens, so at least one prompt may have been truncated'

# shuffle = False keeps batches in q_ix order
mmlu_dl = DataLoader(TextDataset(
    mmlu_df['q_ix'].tolist(),
    mmlu_df['question'].tolist(),
    mmlu_df['choices'].tolist(),
    mmlu_df['n_choices'].tolist(),
    mmlu_df['domain'].tolist(),
    mmlu_df['answer_char'].tolist(),
    tokenized_prompts
), batch_size = 8, shuffle = False)

print(next(iter(mmlu_dl)))

## Run model on MMLU & store topks

In [None]:
"""
Run forward passes & evaluate correctness
"""
output_dfs = []
topk_dfs = []

for batch_ix, batch in tqdm(enumerate(mmlu_dl), total = len(mmlu_dl)):

    input_ids = batch['input_ids'].to(main_device)
    attention_mask = batch['attention_mask'].to(main_device)

    # Inference only: no_grad keeps autograd from retaining activations for a backward pass that never happens
    with torch.no_grad():
        output = run_model_return_topk(model, input_ids, attention_mask)

    # Check no bugs by validating output/perplexity
    if batch_ix == 0:
        loss = ForCausalLMLoss(output['logits'], torch.where(input_ids == tokenizer.pad_token_id, torch.tensor(-100), input_ids), model.config.vocab_size).detach().cpu().item()
        for i in range(min(2, input_ids.size(0))):
            # The tokenizer pads on the left, so select the real tokens with the mask instead of slicing from position 0
            decoded_input = tokenizer.decode(input_ids[i][attention_mask[i].bool()], skip_special_tokens = True)
            next_token_id = torch.argmax(output['logits'][i, -1, :]).item()
            print(decoded_input + colored(tokenizer.decode([next_token_id], skip_special_tokens = True), 'green'))
        print("PPL:", torch.exp(torch.tensor(loss)).item())

    # Decode answers and correctness: the prediction is the argmax token at the last position of each sequence
    predicted_texts = tokenizer.batch_decode(torch.argmax(output['logits'][:, -1, :], dim = 1).tolist())
    predicted_choices = []

    for c_ix, c in enumerate(predicted_texts):
        predicted_choice = None
        # Only a space followed by a valid choice letter counts as an answer; anything else stays None (invalid)
        if c.upper() in [' ' + chr(65 + i) for i in range(int(batch['n_choices'][c_ix]))]:
            predicted_choice = c.upper().strip()
        predicted_choices.append(predicted_choice)

    # Plain Python ints so the pandas columns below do not depend on how pandas converts a tensor
    batch_q_indices = batch['q_indices'].tolist()

    q_df = pd.DataFrame({
        'q_ix': batch_q_indices,
        'predicted_texts': predicted_texts,
        'predicted_choice': predicted_choices,
        'correct_choice': batch['answer_chars'],
        'is_correct': [1 if predicted_choices[c_ix] == batch['answer_chars'][c_ix] else 0 for c_ix in range(0, len(predicted_texts))],
        'is_valid': [1 if predicted_choices[c_ix] is not None else 0 for c_ix in range(0, len(predicted_texts))]
    })

    # Convert to df and map sequence indices batch to question indices
    seq_to_q_map = pd.DataFrame({'sequence_ix': list(range(0, len(batch['questions']))), 'q_ix': batch_q_indices})

    # Keep only the row with the highest token_ix per sequence, then attach question ids and predictions
    output_df =\
        convert_outputs_to_df(input_ids, attention_mask, output['logits'])\
        .pipe(lambda df: df[df['token_ix'] == df.groupby('sequence_ix')['token_ix'].transform('max')])\
        .merge(seq_to_q_map, how = 'inner').drop(columns = 'sequence_ix')\
        .merge(q_df, how = 'inner', on = 'q_ix')\
        .drop(columns = ['token_ix', 'token_id'])

    # Same last-token filter for the expert routing data; weights are rounded to 3 decimals
    topk_df =\
        convert_topk_to_df(input_ids, attention_mask, output['all_topk_experts'], output['all_topk_weights'])\
        .pipe(lambda df: df[df['token_ix'] == df.groupby('sequence_ix')['token_ix'].transform('max')])\
        .merge(seq_to_q_map, how = 'inner').drop(columns = 'sequence_ix')\
        .assign(weight = lambda df: df['weight'].round(3))\
        .drop(columns = ['token_ix', 'token_id'])

    output_dfs.append(output_df)
    topk_dfs.append(topk_df)

full_output_df = pd.concat(output_dfs).merge(mmlu_df[['q_ix', 'question', 'domain', 'choices', 'source']], how = 'inner', on = 'q_ix')
display(full_output_df)
full_topk_df = pd.concat(topk_dfs)
display(full_topk_df)

In [None]:
"""
Check model accuracy
"""
full_output_df = pd.concat(output_dfs).merge(mmlu_df[['q_ix', 'question', 'domain', 'choices', 'source']], how = 'inner', on = 'q_ix')

# Double-quoted f-strings: reusing the outer quote character inside an f-string expression is a SyntaxError before Python 3.12
# is_correct / is_valid are 0/1 flags, so their mean is the share of correct / valid answers
print(f"Overall accuracy: {full_output_df['is_correct'].mean() * 100:.1f}%")
print(f"Overall validity: {full_output_df['is_valid'].mean() * 100:.1f}%")

# Per-domain counts and rates
display(full_output_df\
    .groupby('domain')\
    .agg(
        n_accurate = ('is_correct', 'sum'),
        n_valid = ('is_valid', 'sum'), 
        n_total = ('q_ix', 'count')
    ).reset_index(drop = False)\
    .assign(
        accuracy = lambda df: df['n_accurate']/df['n_total'],
        validity = lambda df: df['n_valid']/df['n_total']
    ))

In [None]:
"""
Export
"""
# Write the outputs to CSV files in the working directory, each prefixed with the selected model's short name
# NOTE(review): export_vocab_as_csv is a project helper; presumably it dumps the tokenizer vocabulary - confirm
export_vocab_as_csv(tokenizer, f'{model_prefix}-vocab.csv')
# Per-question predictions and correctness (mode = 'w' overwrites any existing file)
full_output_df.to_csv(f'{model_prefix}-mmlu-questions.csv', mode = 'w', index = False)
# Last-token topk expert data per question
full_topk_df.to_csv(f'{model_prefix}-mmlu-topk.csv', mode = 'w', index = False)