In [None]:
"""
Runs ablation on MMLU subjects.
"""
None

In [None]:
"""
Imports
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.loss.loss_utils import ForCausalLMLoss # Cross-entropy loss that handles label shifting
from datasets import load_dataset
import pandas as pd
import numpy as np
from tqdm import tqdm
from termcolor import colored
import importlib
import os
import gc

from utils.memory import check_memory, clear_all_cuda_memory
from utils.store_topk import convert_topk_to_df
from utils.store_outputs import convert_outputs_to_df
from utils import pretrained_models

import pickle

# Target CUDA device identifier for this notebook.
main_device = 'cuda:0'
# Seed value. NOTE(review): no RNG is seeded with it in the cells visible here — confirm it is consumed later.
seed = 123
# Project helpers from utils.memory, run once at startup in this order.
# Presumably they free cached GPU memory and then report usage — TODO confirm against utils/memory.
clear_all_cuda_memory()
check_memory()

## Load base model

In [None]:
"""
Load the base tokenizer/model

Architectures supported currently:
- OlMoE architecture, includes OLMoE-1B-7B-0125-Instruct (1B/7B)
- Qwen2MoE architecture, includes Qwen1.5-MoE-A2.7B-Chat (2.7B/14.3B), Qwen2-57B-A14B (14B/57B)
- Deepseek v2 architecture, includes Deepseek-v2-Lite (2.4B/15.7B), Deepseek-v2 (21B/236B)
- Deepseek v3 architecture, includes Deepseek-v3 (37B/671B), Deepseek-R1 (37B/671B), Moonlight-16B-A3B (3B/16B)
- Qwen3MoE architecture, includes Qwen3-30B-A3B, Qwen3-235B-A22B
"""
# Position in the model table inside get_model(); 4 selects 'Qwen/Qwen3-30B-A3B'.
selected_model_index = 4

def get_model(index):
    """
    Look up one entry of the supported-model table by position.

    Params:
        @index: Position in the table; any valid sequence index (negative values count from the end).

    Returns:
        A 3-tuple of strings: (Hugging Face model ID, short model prefix, architecture name).

    Raises:
        IndexError: If `index` falls outside the table.
    """
    supported_models = (
        ('allenai/OLMoE-1B-7B-0125-Instruct', 'olmoe', 'olmoe'),
        ('Qwen/Qwen1.5-MoE-A2.7B-Chat', 'qwen1.5moe', 'qwen2moe'),
        ('deepseek-ai/DeepSeek-V2-Lite', 'dsv2', 'dsv2'),
        ('moonshotai/Moonlight-16B-A3B', 'moonlight', 'dsv3'),
        ('Qwen/Qwen3-30B-A3B', 'qwen3moe', 'qwen3moe'),
    )

    hf_id, prefix, architecture = supported_models[index]
    return hf_id, prefix, architecture

# Resolve the Hugging Face repo ID, the short prefix and the architecture name for the selected model.
model_id, model_prefix, model_architecture = get_model(selected_model_index)

# BOS/EOS insertion is disabled and padding is applied on the left.
# NOTE(review): trust_remote_code = True allows Python code shipped in the model repo to run locally;
# only point model_id at repositories you trust.
tokenizer = AutoTokenizer.from_pretrained(model_id, add_eos_token = False, add_bos_token = False, padding_side = 'left', trust_remote_code = True)

# Load in bfloat16, move to the configured device (rather than whatever CUDA device happens to be
# current, which is what a bare .cuda() uses), and switch to eval mode.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype = torch.bfloat16, trust_remote_code = True).to(main_device).eval()

In [None]:
"""
Load reverse-engineered forward pass functions that return topk expert IDs and weights
"""
# Resolve the architecture-specific module and its `run_<architecture>_with_ablation` function by name.
model_module = importlib.import_module(f"utils.pretrained_models.{model_architecture}")
run_model_with_ablation = getattr(model_module, f"run_{model_architecture}_with_ablation")

# Ablation spec passed to run_model_with_ablation by the sanity check.
# NOTE(review): looks like {layer index: {tuple of expert IDs: replacement value}} — confirm against the utils module.
test_ablate_dict = {1: {(0, ): 0}}

def test_custom_forward_pass(model, pad_token_id):
    """
    Sanity-check the ablation forward pass against the model's own forward pass on two short prompts.

    Params:
        @model: Loaded causal LM; called directly and also passed to `run_model_with_ablation`.
        @pad_token_id: Token ID treated as padding; matching positions are given label -100 for the loss.

    Raises:
        AssertionError: If the two sets of logits are not exactly equal, or if the returned lists of
            top-k expert IDs and top-k weights differ in length.
    """
    inputs = tokenizer(['Hi! I am a dog and I like to bark', 'Vegetables are good for'], return_tensors = 'pt', padding = 'max_length', truncation = True, max_length = 512).to(model.device)

    # Inference only: .eval() does not disable autograd, so without no_grad both forward passes
    # would record a graph and keep activations alive for a backward pass that never happens.
    with torch.no_grad():
        original_results = model(**inputs)
        custom_results = run_model_with_ablation(model, inputs['input_ids'], inputs['attention_mask'], test_ablate_dict)

    assert torch.equal(original_results.logits, custom_results['logits']), 'Error in custom forward'
    assert len(custom_results['all_topk_experts']) == len(custom_results['all_topk_weights']), 'Length of topk IDs and weights not equal'

    print(f"Length of topk: {len(custom_results['all_topk_experts'])}")
    print(f"Topk size: {custom_results['all_topk_experts'][0].shape}")
    # NOTE(review): these print row index 1 of the first list entry while the label says "First token" — confirm intent.
    print(f"First token topk IDs: {custom_results['all_topk_experts'][0][1,]}")
    print(f"First token topk weights: {custom_results['all_topk_weights'][0][1,]}")

    # Replace pad positions with -100 (presumably the loss's ignore index — confirm) so padding does not count.
    labels = torch.where(inputs['input_ids'] == pad_token_id, torch.tensor(-100), inputs['input_ids'])
    loss = ForCausalLMLoss(custom_results['logits'], labels, model.config.vocab_size).detach().cpu().item()
    print(f"LM loss: {loss}")

# Run the sanity check once on the loaded model; the tokenizer's pad token ID is used to mask padding out of the loss.
test_custom_forward_pass(model, tokenizer.pad_token_id)

## Load dataset

In [None]:
# Both CSVs are looked up in the current working directory and are named after the selected model's prefix.
samples_csv_path = f'{model_prefix}-samples.csv'
topks_csv_path = f'{model_prefix}-topks.csv'

sample_df = pd.read_csv(samples_csv_path)
topk_df = pd.read_csv(topks_csv_path)

In [None]:
# Per-domain accuracy: collapse sample_df to one row per question, mark a question correct when its
# stripped output token equals the answer character, then aggregate by domain.
# NOTE(review): groupby drops rows whose key columns are NaN by default — confirm no questions are lost.
question_key_columns = ['q_ix', 'lang_ix', 'domain_ix', 'source_id', 'domain', 'question_output_token', 'answer_char']

(
    sample_df
    .groupby(question_key_columns, as_index = False)
    .agg(n_tokens = ('q_ix', 'count'))
    .assign(is_correct = lambda per_q: np.where(per_q['question_output_token'].str.strip() == per_q['answer_char'], 1, 0))
    .groupby('domain', as_index = False)
    .agg(n_accurate = ('is_correct', 'sum'), n_total = ('q_ix', 'count'))
    .assign(accuracy = lambda per_domain: per_domain['n_accurate'] / per_domain['n_total'])
)

In [None]:
# Show only the question index and domain columns of the samples table.
question_domain_columns = ['q_ix', 'domain']
sample_df[question_domain_columns]

In [None]:
# Keep the rows whose topk_ix equals 1, then show one row per distinct expert, ordered by expert.
topk_ix_one_rows = topk_df[topk_df['topk_ix'] == 1]
topk_ix_one_rows.drop_duplicates(subset = 'expert').sort_values(by = 'expert')