In [None]:
"""
Runs forward passes with ablation! Run after `calculate-mmlu-domains.ipynb`.
"""
None

In [None]:
"""
Imports
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.loss.loss_utils import ForCausalLMLoss # Cross-entropy loss that handles label shifting
import pandas as pd
import numpy as np
from tqdm import tqdm
from termcolor import colored
import importlib
import os
import gc

from utils.memory import check_memory, clear_all_cuda_memory
from utils.store_topk import convert_topk_to_df
from utils.store_outputs import convert_outputs_to_df
from utils import pretrained_models

import pickle

# Device that each batch's input tensors are moved to before a forward pass.
main_device = 'cuda:0'
# NOTE(review): `seed` is not referenced by any cell visible in this notebook - confirm whether seeding was intended.
seed = 123
# Project helpers from utils.memory; presumably they free cached CUDA memory and report current usage - see utils/memory.py.
clear_all_cuda_memory()
check_memory()

## Load base model

In [None]:
"""
Load the base tokenizer/model

Architectures supported currently for path/expert ablation:
- Qwen3MoE architecture, includes Qwen3-30B-A3B, Qwen3-235B-A22B
"""
# Position of the model to load within the list of supported models inside get_model().
selected_model_index = 0

def get_model(index):
    """
    Look up a supported model by its position in the registry.

    Params:
        @index: Integer position in the list of supported models.

    Returns:
        A tuple of (Hugging Face model id, model prefix, model architecture).
    """
    supported_models = [
        ('Qwen/Qwen3-30B-A3B', 'qwen3moe', 'qwen3moe'),
    ]
    hf_model_id, prefix, architecture = supported_models[index]

    return hf_model_id, prefix, architecture

# model_prefix names the data/<prefix>/ directory used for inputs and outputs; model_architecture selects the utils.pretrained_models submodule.
model_id, model_prefix, model_architecture = get_model(selected_model_index)
# Left padding places pad tokens before the text, so position -1 of a padded sequence is its last real token.
tokenizer = AutoTokenizer.from_pretrained(model_id, add_eos_token = False, add_bos_token = False, padding_side = 'left', trust_remote_code = True)
# Load the weights in bfloat16, move the model to the default CUDA device, and switch it to eval mode.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype = torch.bfloat16, trust_remote_code = True).cuda().eval()

In [None]:
"""
Load the path & expert ablation functions (these run the forward passes), then verify that with no ablation applied their output is equivalent to a standard model() call.
"""
# Resolve the architecture-specific module and look up its two forward-pass entry points by name.
model_module = importlib.import_module(f"utils.pretrained_models.{model_architecture}")
run_model_with_path_ablation = getattr(model_module, f"run_{model_architecture}_with_path_ablation")
run_model_with_expert_ablation = getattr(model_module, f"run_{model_architecture}_with_expert_ablation")

# Minimal path-ablation spec for the equivalence check: layer 1, prefix (0,), target expert 0.
test_path_ablate_dict = {
    1: [
        ((0, ), 0)
    ]
}

# Minimal expert-ablation spec for the equivalence check: layer 1, expert 64.
test_expert_ablate_dict = {
    1: [64]
}

@torch.no_grad()
def test_custom_forward_pass(model, pad_token_id, run_fn, ablate_dict):
    """
    Check that a custom forward-pass function reproduces the stock model() logits, then print a few diagnostics.

    Params:
        @model: The loaded causal LM.
        @pad_token_id: Token id whose positions are replaced by -100 in the labels passed to the LM loss.
        @run_fn: The custom forward function; called as run_fn(model, input_ids, attention_mask, ablate_dict, ablate_if_in_topk = False).
        @ablate_dict: Ablation specification forwarded to run_fn.

    Raises:
        AssertionError if the logits differ from model(**inputs), or if the top-k expert and weight lists differ in length.
    """
    prompts = ['Hi! I am a dog and I like to bark', 'Vegetables are good for']
    inputs = tokenizer(prompts, return_tensors = 'pt', padding = 'max_length', truncation = True, max_length = 512).to(model.device)

    # Reference forward pass, followed by the custom one on the same inputs
    baseline = model(**inputs)
    custom = run_fn(model, inputs['input_ids'], inputs['attention_mask'], ablate_dict, ablate_if_in_topk = False)

    assert torch.equal(baseline.logits, custom['logits']), 'Error in custom forward'
    topk_experts = custom['all_topk_experts']
    topk_weights = custom['all_topk_weights']
    assert len(topk_experts) == len(topk_weights), 'Length of topk IDs and weights not equal'

    print(f"Length of topk: {len(topk_experts)}")
    print(f"Topk size: {topk_experts[0].shape}")
    print(f"First token topk IDs: {topk_experts[0][1,]}")
    print(f"First token topk weights: {topk_weights[0][1,]}")

    # Pad positions become -100 in the labels handed to the loss
    labels = torch.where(inputs['input_ids'] == pad_token_id, torch.tensor(-100), inputs['input_ids'])
    loss = ForCausalLMLoss(custom['logits'], labels, model.config.vocab_size).detach().cpu().item()
    print(f"LM loss: {loss}")

# Run the equivalence check once per ablation function; each call asserts the custom logits equal the stock model() logits.
test_custom_forward_pass(model, tokenizer.pad_token_id, run_model_with_path_ablation, test_path_ablate_dict)
test_custom_forward_pass(model, tokenizer.pad_token_id, run_model_with_expert_ablation, test_expert_ablate_dict)

## Load dataset

In [None]:
"""
Which domain and language to ablate, as well as which k? These should be derived from `calculate-mmlu-domains.ipynb`, and should correspond to
 the output JSONs from that file.
"""
# These three values are interpolated into the ablation-target JSON filenames that get loaded and the CSV filenames that get written.
test_domain = 'biology'
test_lang = 'en'
test_k = 2

In [None]:
"""
Load dataset, dataloader
"""
from utils.dataset import ReconstructableTextDataset, stack_collate
from torch.utils.data import DataLoader

def load_metadata():
    """
    Read the pickled metadata bundle for the selected model.

    Returns:
        A tuple of (mmlu_df, test_ds), taken from the 'mmlu_df' and 'test_ds' keys of ./data/{model_prefix}/metadata.pkl.
    """
    metadata_path = f'./data/{model_prefix}/metadata.pkl'
    # pickle.load can execute arbitrary code - only load files produced by this project
    with open(metadata_path, 'rb') as handle:
        bundle = pickle.load(handle)
    mmlu_table, test_dataset = bundle['mmlu_df'], bundle['test_ds']
    return mmlu_table, test_dataset

# mmlu_df is joined on q_ix inside check_accuracy(); test_ds is the dataset iterated by the ablation loops.
mmlu_df, test_ds = load_metadata()
# shuffle = False, so batches are yielded in dataset order on every pass over test_dl.
test_dl = DataLoader(test_ds, batch_size = 8, shuffle = False, collate_fn = stack_collate)

In [None]:
"""
Load path & expert ablation targets
"""
import json 

def load_path_ablation_targets_from_json(filepath):
    """
    Parse JSON, converting it into a dictionary mapping layer integers to a list of tuples:
    {
        5: [
            ((0,), 2), # Ablate expert #2 in layer #5 if previous expert was expert #0
            ((1, 2), 10) # Ablate expert #10 in layer #5 if previous expert was expert #1 -> expert #2 (in previous 2 layers)
        ]
    }

    Params:
        @filepath: Path to a JSON object whose keys are layer indices (as strings) and whose values are lists of
            [prefix, target_expert] pairs. `prefix` is either a list of ints or a single int (treated as a singleton list);
            `target_expert` is an int.

    Returns:
        A dict {layer_int: [(prefix_tuple, target_expert), ...]}. Malformed rules are skipped with a printed warning,
        and layers with no valid rules are omitted.
    """
    with open(filepath, 'r') as f:
        raw_targets = json.load(f)

    ablation_targets_parsed = {}

    for layer_str, rules_list in raw_targets.items():
        layer_int = int(layer_str)
        layer_rules = []
        for rule_pair in rules_list:
            # A rule must be a 2-element sequence; anything else is reported instead of crashing on len()/indexing
            is_pair = isinstance(rule_pair, (list, tuple)) and len(rule_pair) == 2
            prefix = rule_pair[0] if is_pair else None
            # The prefix is either a list of ints or a bare int. The previous check tested rule_pair[1] twice,
            # so the prefix was never validated and e.g. a string prefix was silently accepted.
            prefix_ok = isinstance(prefix, int) or (isinstance(prefix, list) and all(isinstance(p, int) for p in prefix))

            if is_pair and prefix_ok and isinstance(rule_pair[1], int):
                prefix_list = prefix if isinstance(prefix, list) else [prefix] # If an int, cast to a singleton list
                layer_rules.append((tuple(prefix_list), rule_pair[1]))
            else:
                print(f"Warning: Invalid rule format {rule_pair} in layer {layer_int}.")

        if layer_rules:
            ablation_targets_parsed[layer_int] = layer_rules

    return ablation_targets_parsed

def load_expert_ablation_targets_from_json(filepath):
    """
    Read an expert-ablation JSON file and return a dictionary from layer index to the experts to ablate:
    {
        5: [0, 2] # Ablate experts 0 and 2 in layer 5
    }

    Params:
        @filepath: Path to a JSON object keyed by layer index (as a string). Each value is either a single int
            or a list of values that are each cast with int().

    Returns:
        A dict {layer_int: [expert, ...]}. Values of any other type print a warning, and layers that end up with
        no experts are left out.
    """
    with open(filepath, 'r') as handle:
        layer_specs = json.load(handle)

    parsed_targets = {}

    for layer_key, spec in layer_specs.items():
        layer_id = int(layer_key)

        if isinstance(spec, int):
            experts = [spec]
        elif isinstance(spec, list):
            experts = [int(expert_id) for expert_id in spec]
        else:
            print(f"Warning: Invalid format.")
            experts = []

        # Skip layers with nothing to ablate
        if experts:
            parsed_targets[layer_id] = experts

    return parsed_targets


# Path rules for the configured domain/language/k; the printed number is the total rule count summed over all layers.
path_ablation_targets = load_path_ablation_targets_from_json(f'data/{model_prefix}/path_ablation_targets_{test_domain}_{test_lang}_{test_k}.json')
print(f'Paths to ablate: {sum([len(x) for _, x in path_ablation_targets.items()])}')

# Expert targets for the same configuration; the printed number is the total expert count summed over all layers.
expert_ablation_targets = load_expert_ablation_targets_from_json(f'data/{model_prefix}/expert_ablation_targets_{test_domain}_{test_lang}_{test_k}.json')
print(f'Experts to ablate: {sum([len(x) for _, x in expert_ablation_targets.items()])}')

# multipath_ablation_targets = load_path_ablation_targets_from_json(f'data/{model_prefix}/multipath_ablation_targets_{test_domain}_{test_lang}_{test_k}.json')
# print(f'Multipaths to ablate: {sum([len(x) for _, x in multipath_ablation_targets.items()])}')

In [None]:
"""
Helper to check accuracy stats
"""
def check_accuracy(sample_level_df):
    """
    Compute answer accuracy per (domain, lang) and compare it with the baseline stored in test_accuracy.csv.

    Params:
        @sample_level_df: A token-level or question-level dataframe containing `q_ix` and `question_output_token`.

    Returns:
        A dataframe with one row per (domain, lang) holding n_accurate, n_total, accuracy, base_accuracy and chg,
        where chg is the change in accuracy relative to base_accuracy.
    """
    question_keys = ['q_ix', 'lang_ix', 'domain_ix', 'source_id', 'domain', 'lang', 'question_output_token', 'answer_char']

    # Attach question metadata and collapse repeated rows (e.g. one per token) into one row per key combination
    per_question = (
        sample_level_df
        .merge(mmlu_df, how = 'inner', on = 'q_ix')
        .groupby(question_keys, as_index = False)
        .agg(n_tokens = ('q_ix', 'count'))
    )
    # A question counts as correct when the whitespace-stripped output token equals the answer character
    matches_answer = per_question['question_output_token'].str.strip() == per_question['answer_char']
    per_question = per_question.assign(is_correct = np.where(matches_answer, 1, 0))

    per_domain = (
        per_question
        .groupby(['domain', 'lang'], as_index = False)
        .agg(n_accurate = ('is_correct', 'sum'), n_total = ('q_ix', 'count'))
    )
    per_domain = per_domain.assign(accuracy = per_domain['n_accurate']/per_domain['n_total'])

    # Join the stored baseline and express the difference relative to it
    baseline_df = pd.read_csv(f'data/{model_prefix}/test_accuracy.csv')[['domain', 'lang', 'base_accuracy']]
    accuracy_df = per_domain.merge(baseline_df, on = ['domain', 'lang'], how = 'inner')
    accuracy_df = accuracy_df.assign(chg = (accuracy_df['accuracy'] - accuracy_df['base_accuracy'])/accuracy_df['base_accuracy'])

    return accuracy_df

## Run path ablation

In [None]:
"""
Run forward passes with path ablation
"""
# Accumulates one token-level dataframe per batch; concatenated in the export cell.
path_sample_dfs = []

for batch_ix, batch in tqdm(enumerate(test_dl), total = len(test_dl)):

    input_ids = batch['input_ids'].to(main_device)
    attention_mask = batch['attention_mask'].to(main_device)
    original_tokens = batch['original_tokens']
    # NOTE(review): `q_indices` is not read below; the q_ix mapping uses batch['q_indices'] directly.
    q_indices = batch['q_indices']

    # NOTE(review): this run passes ablate_if_in_topk = True while the expert-ablation run passes False - confirm the difference is intentional.
    with torch.no_grad():
        output = run_model_with_path_ablation(model, input_ids, attention_mask, ablation_targets = path_ablation_targets, ablate_if_in_topk = True)
    
    # Check no bugs by validating output/perplexity (first batch only)
    if batch_ix == 0:
        # Positions equal to the pad token id are replaced by -100 in the labels passed to the loss
        loss = ForCausalLMLoss(output['logits'], torch.where(input_ids == tokenizer.pad_token_id, torch.tensor(-100), input_ids), model.config.vocab_size).detach().cpu().item()
        # Print up to 5 prompts, each followed by the argmax next token (shown in green, newlines rendered as <lb>)
        for i in range(min(5, input_ids.size(0))):
            decoded_input = tokenizer.decode(input_ids[i, :], skip_special_tokens = True)
            next_token_id = torch.argmax(output['logits'][i, -1, :]).item()
            print('---------\n' + decoded_input + colored(tokenizer.decode([next_token_id], skip_special_tokens = False).replace('\n', '<lb>'), 'green'))
        print(f"PPL:", torch.exp(torch.tensor(loss)).item())
    
    # Create mapping of within-batch sequence_ix to q_ix, so sequence_ix can be dropped and rows keyed by q_ix instead
    seq_to_q_map = pd.DataFrame({'sequence_ix': list(range(0, input_ids.shape[0])), 'q_ix': batch['q_indices']})

    # Convert original tokens back to (seq_ix, token_ix) level for later storage
    original_tokens_df = pd.DataFrame(
        [(seq_i, tok_i, tok) for seq_i, tokens in enumerate(original_tokens) for tok_i, tok in enumerate(tokens)], 
        columns = ['sequence_ix', 'token_ix', 'token']
    )

    # Final token outputs - get the decoded outputs for the final tokens only, since convert_outputs_to_df only returns output tokens as IDs
    question_token_outputs = tokenizer.batch_decode(torch.argmax(output['logits'][:, -1, :], dim = 1).tolist())
    question_token_outputs_df = pd.DataFrame({'sequence_ix': list(range(0, input_ids.shape[0])), 'question_output_token': question_token_outputs})

    # Create sample (token) level dataframe: attach the decoded answer token, the original token text, and q_ix
    sample_df =\
        convert_outputs_to_df(input_ids, attention_mask, output['logits'])\
        .merge(question_token_outputs_df, how = 'left', on = ['sequence_ix'])\
        .merge(original_tokens_df, how = 'left', on = ['token_ix', 'sequence_ix'])\
        .merge(seq_to_q_map, how = 'inner', on = ['sequence_ix'])\
        .drop(columns = ['sequence_ix'])
    
    path_sample_dfs.append(sample_df)

    # Every 20 batches, print the ablation counter from the latest output and show running accuracy over all batches so far
    if batch_ix > 0 and batch_ix % 20 == 0:
        print(output['num_ablations_applied'])
        display(check_accuracy(pd.concat(path_sample_dfs)))

In [None]:
"""
Export results
"""
# Combine the per-batch frames and compute the final per-(domain, lang) accuracy table.
path_sample_df = pd.concat(path_sample_dfs)
path_accuracy_df = check_accuracy(path_sample_df)
display(path_accuracy_df)

# Persist both the token-level samples and the accuracy summary, named by domain/language/k.
path_sample_df.to_csv(f"data/{model_prefix}/path_ablation_samples_{test_domain}_{test_lang}_{test_k}.csv", index = False)
path_accuracy_df.to_csv(f"data/{model_prefix}/path_ablation_accuracy_{test_domain}_{test_lang}_{test_k}.csv", index = False)

## Run expert ablation

In [None]:
"""
Run forward passes with expert ablation
"""
# Accumulates one token-level dataframe per batch; concatenated in the export cell.
expert_sample_dfs = []

for batch_ix, batch in tqdm(enumerate(test_dl), total = len(test_dl)):

    input_ids = batch['input_ids'].to(main_device)
    attention_mask = batch['attention_mask'].to(main_device)
    original_tokens = batch['original_tokens']
    # NOTE(review): `q_indices` is not read below; the q_ix mapping uses batch['q_indices'] directly.
    q_indices = batch['q_indices']

    # NOTE(review): this run passes ablate_if_in_topk = False while the path-ablation run passes True - confirm the difference is intentional.
    with torch.no_grad():
        output = run_model_with_expert_ablation(model, input_ids, attention_mask, ablation_targets = expert_ablation_targets, ablate_if_in_topk = False)
    
    # Check no bugs by validating output/perplexity (first batch only)
    if batch_ix == 0:
        # Positions equal to the pad token id are replaced by -100 in the labels passed to the loss
        loss = ForCausalLMLoss(output['logits'], torch.where(input_ids == tokenizer.pad_token_id, torch.tensor(-100), input_ids), model.config.vocab_size).detach().cpu().item()
        # Print up to 5 prompts, each followed by the argmax next token (shown in green, newlines rendered as <lb>)
        for i in range(min(5, input_ids.size(0))):
            decoded_input = tokenizer.decode(input_ids[i, :], skip_special_tokens = True)
            next_token_id = torch.argmax(output['logits'][i, -1, :]).item()
            print('---------\n' + decoded_input + colored(tokenizer.decode([next_token_id], skip_special_tokens = False).replace('\n', '<lb>'), 'green'))
        print(f"PPL:", torch.exp(torch.tensor(loss)).item())
    
    # Create mapping of within-batch sequence_ix to q_ix, so sequence_ix can be dropped and rows keyed by q_ix instead
    seq_to_q_map = pd.DataFrame({'sequence_ix': list(range(0, input_ids.shape[0])), 'q_ix': batch['q_indices']})

    # Convert original tokens back to (seq_ix, token_ix) level for later storage
    original_tokens_df = pd.DataFrame(
        [(seq_i, tok_i, tok) for seq_i, tokens in enumerate(original_tokens) for tok_i, tok in enumerate(tokens)], 
        columns = ['sequence_ix', 'token_ix', 'token']
    )

    # Final token outputs - get the decoded outputs for the final tokens only, since convert_outputs_to_df only returns output tokens as IDs
    question_token_outputs = tokenizer.batch_decode(torch.argmax(output['logits'][:, -1, :], dim = 1).tolist())
    question_token_outputs_df = pd.DataFrame({'sequence_ix': list(range(0, input_ids.shape[0])), 'question_output_token': question_token_outputs})

    # Create sample (token) level dataframe: attach the decoded answer token, the original token text, and q_ix
    sample_df =\
        convert_outputs_to_df(input_ids, attention_mask, output['logits'])\
        .merge(question_token_outputs_df, how = 'left', on = ['sequence_ix'])\
        .merge(original_tokens_df, how = 'left', on = ['token_ix', 'sequence_ix'])\
        .merge(seq_to_q_map, how = 'inner', on = ['sequence_ix'])\
        .drop(columns = ['sequence_ix'])
    
    expert_sample_dfs.append(sample_df)

    # Every 20 batches, print the ablation counter from the latest output and show running accuracy over all batches so far
    if batch_ix > 0 and batch_ix % 20 == 0:
        print(output['num_ablations_applied'])
        display(check_accuracy(pd.concat(expert_sample_dfs)))

In [None]:
"""
Export results
"""
# Combine the per-batch frames and compute the final per-(domain, lang) accuracy table.
expert_sample_df = pd.concat(expert_sample_dfs)
expert_accuracy_df = check_accuracy(expert_sample_df)
display(expert_accuracy_df)

# Persist both the token-level samples and the accuracy summary, named by domain/language/k.
expert_sample_df.to_csv(f'data/{model_prefix}/expert_ablation_samples_{test_domain}_{test_lang}_{test_k}.csv', index = False)
expert_accuracy_df.to_csv(f'data/{model_prefix}/expert_ablation_accuracy_{test_domain}_{test_lang}_{test_k}.csv', index = False)

## Multipath testing

In [None]:
# """
# Run forward passes with path ablation
# """
# multipath_sample_dfs = []

# for batch_ix, batch in tqdm(enumerate(test_dl), total = len(test_dl)):

#     input_ids = batch['input_ids'].to(main_device)
#     attention_mask = batch['attention_mask'].to(main_device)
#     original_tokens = batch['original_tokens']
#     q_indices = batch['q_indices']

#     with torch.no_grad():
#         output = run_model_with_path_ablation(model, input_ids, attention_mask, ablation_targets = multipath_ablation_targets, ablate_if_in_topk = False)
    
#     # Check no bugs by validating output/perplexity
#     if batch_ix == 0:
#         loss = ForCausalLMLoss(output['logits'], torch.where(input_ids == tokenizer.pad_token_id, torch.tensor(-100), input_ids), model.config.vocab_size).detach().cpu().item()
#         for i in range(min(5, input_ids.size(0))):
#             decoded_input = tokenizer.decode(input_ids[i, :], skip_special_tokens = True)
#             next_token_id = torch.argmax(output['logits'][i, -1, :]).item()
#             print('---------\n' + decoded_input + colored(tokenizer.decode([next_token_id], skip_special_tokens = False).replace('\n', '<lb>'), 'green'))
#         print(f"PPL:", torch.exp(torch.tensor(loss)).item())
    
#     # Create mapping of seq_ix to q_ix, so we can drop batch_ix and seq_ix and return q_ix
#     seq_to_q_map = pd.DataFrame({'sequence_ix': list(range(0, input_ids.shape[0])), 'q_ix': batch['q_indices']})

#     # Convert original tokens back to (seq_ix, token_ix) level for later storage
#     original_tokens_df = pd.DataFrame(
#         [(seq_i, tok_i, tok) for seq_i, tokens in enumerate(original_tokens) for tok_i, tok in enumerate(tokens)], 
#         columns = ['sequence_ix', 'token_ix', 'token']
#     )

#     # Final token outputs - get the decoded outputs for the final tokens only, since convert_outputs_to_df only returns output tokens as IDs
#     question_token_outputs = tokenizer.batch_decode(torch.argmax(output['logits'][:, -1, :], dim = 1).tolist())
#     question_token_outputs_df = pd.DataFrame({'sequence_ix': list(range(0, input_ids.shape[0])), 'question_output_token': question_token_outputs})

#     # Create sample (token) level dataframe
#     sample_df =\
#         convert_outputs_to_df(input_ids, attention_mask, output['logits'])\
#         .merge(question_token_outputs_df, how = 'left', on = ['sequence_ix'])\
#         .merge(original_tokens_df, how = 'left', on = ['token_ix', 'sequence_ix'])\
#         .merge(seq_to_q_map, how = 'inner', on = ['sequence_ix'])\
#         .drop(columns = ['sequence_ix'])
    
#     multipath_sample_dfs.append(sample_df)

#     if batch_ix > 0 and batch_ix % 20 == 0:
#         print(output['num_ablations_applied'])
#         display(check_accuracy(pd.concat(multipath_sample_dfs)))