In [None]:
"""
Runs forward passes with different top-k values ablated (set to zero).
"""
None

In [None]:
"""
Imports
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.loss.loss_utils import ForCausalLMLoss # Cross-entropy loss that handles label shifting
from datasets import load_dataset
import pandas as pd
import numpy as np
from tqdm import tqdm
from termcolor import colored
import importlib
import os
import gc
import pickle

from utils.memory import check_memory, clear_all_cuda_memory
from utils.store_topk import convert_topk_to_df
from utils.store_outputs import convert_outputs_to_df
from utils import pretrained_models

# Device that input batches are moved to before every ablation forward pass.
main_device = 'cuda:0'
# Seed shared by the streaming-dataset shuffle and the NumPy permutation used when sampling evaluation text.
seed = 123

# Project helpers from utils.memory - presumably release cached CUDA memory and report current usage
# before the model is loaded (TODO confirm against utils/memory.py).
clear_all_cuda_memory()
check_memory()

## Load base model

In [None]:
"""
Load the base tokenizer/model

Architectures supported currently:
- OLMoE architecture, includes OLMoE-1B-7B-0125-Instruct (1B/7B)
- Qwen2MoE architecture, includes Qwen1.5-MoE-A2.7B-Chat (2.7B/14.3B), Qwen2-57B-A14B (14B/57B)
- Deepseek v2 architecture, includes Deepseek-v2-Lite (2.4B/15.7B), Deepseek-v2 (21B/236B)
- Deepseek v3 architecture, includes Deepseek-v3 (37B/671B), Deepseek-R1 (37B/671B), Moonlight-16B-A3B (3B/16B)
- Qwen3MoE architecture, includes Qwen3-30B-A3B (3B/30B), Qwen3-235B-A22B (22B/235B)
"""
selected_model_index = 1

def get_model(index):
    """
    Look up one of the supported checkpoints by its position in the registry below.

    Params:
        @index: Position of the desired entry; ordinary Python list indexing applies.

    Returns:
        A tuple of three strings: (Hugging Face model id, model prefix, model architecture).
    """
    supported_models = [
        ('allenai/OLMoE-1B-7B-0125-Instruct', 'olmoe', 'olmoe'),
        ('Qwen/Qwen1.5-MoE-A2.7B-Chat', 'qwen1.5moe', 'qwen2moe'),
        ('deepseek-ai/DeepSeek-V2-Lite', 'dsv2', 'dsv2'),
        ('moonshotai/Moonlight-16B-A3B', 'moonlight', 'dsv3'),
        ('Qwen/Qwen3-30B-A3B', 'qwen3moe', 'qwen3moe')
    ]

    hf_id, prefix, architecture = supported_models[index]
    return hf_id, prefix, architecture

# Resolve the checkpoint selected above into its hub id, short prefix, and architecture key.
model_id, model_prefix, model_architecture = get_model(selected_model_index)

# NOTE: trust_remote_code = True executes Python shipped in the model repository; only use with trusted checkpoints.
# Padding is applied on the left, and automatic BOS/EOS insertion is switched off via the tokenizer kwargs.
tokenizer = AutoTokenizer.from_pretrained(model_id, add_eos_token = False, add_bos_token = False, padding_side = 'left', trust_remote_code = True)

# Load in bfloat16 and place the weights on main_device - the same device every input batch is moved to -
# instead of relying on whatever the current default CUDA device happens to be.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype = torch.bfloat16, trust_remote_code = True).to(main_device).eval()

In [None]:
"""
Test ablation function and ensure that without ablated experts (topk_to_ablate = []), it returns the same response as the base model call
"""
# Resolve the architecture-specific ablation runner, e.g. utils.pretrained_models.olmoe.run_olmoe_with_topk_ablation
model_module = importlib.import_module(f"utils.pretrained_models.{model_architecture}")
run_model_with_ablation = getattr(model_module, f"run_{model_architecture}_with_topk_ablation")

@torch.no_grad()
def test_custom_forward_pass(model, pad_token_id):
    """
    Sanity-check the custom forward pass: with no top-k positions ablated it must reproduce the logits of the
    stock model call exactly. Also prints the LM loss of the custom pass on two short prompts.

    Runs under torch.no_grad() since this is inference only; without it both forward passes would keep
    their autograd graphs alive on the GPU.

    Params:
        @model: The loaded causal LM to validate.
        @pad_token_id: Token id treated as padding; matching positions are excluded from the loss.

    Raises:
        AssertionError if the custom logits differ from the stock logits.
    """
    inputs = tokenizer(['Hi! I am a dog and I like to bark', 'Vegetables are good for'], return_tensors = 'pt', padding = 'max_length', truncation = True, max_length = 512).to(model.device)
    original_results = model(**inputs)
    # Every layer index is listed but no top-k position is, so nothing should actually be ablated
    custom_results = run_model_with_ablation(model, inputs['input_ids'], inputs['attention_mask'], layers_to_ablate = list(range(0, 100)), topk_to_ablate = [])
    assert torch.equal(original_results.logits, custom_results['logits']), 'Error in custom forward'
    # -100 is the ignore index of the loss, so padded positions do not contribute
    labels = torch.where(inputs['input_ids'] == pad_token_id, torch.tensor(-100), inputs['input_ids'])
    loss = ForCausalLMLoss(custom_results['logits'], labels, model.config.vocab_size).item()
    print(f"LM loss: {loss}")

test_custom_forward_pass(model, tokenizer.pad_token_id)

## Get dataset

In [None]:
"""
Load dataset - FW edu
"""
def load_raw_ds():
    """
    Sample evaluation documents from the streamed FineWeb-Edu `CC-MAIN-2024-51` train split.

    The stream is shuffled with a 50,000-item buffer using the notebook-level `seed`; up to 100 documents
    are taken from it and then re-ordered with a NumPy permutation drawn from the same seed.

    Returns:
        A list of `{'text': ...}` dicts, one per sampled document.
    """
    rng = np.random.default_rng(seed = seed)
    ds_en = load_dataset(
        'HuggingFaceFW/fineweb-edu',
        'CC-MAIN-2024-51',
        split = 'train',
        streaming = True
    ).shuffle(seed = seed, buffer_size = 50_000)

    def get_data(ds, n_samples):
        # zip() advances range() first, so the stream is pulled at most n_samples times and
        # collection simply stops early if the stream runs dry.
        return [{'text': sample['text']} for _, sample in zip(range(n_samples), ds)]

    combined_ds = get_data(ds_en, 100)
    perm = rng.permutation(len(combined_ds))

    return [combined_ds[i] for i in perm]

raw_data = load_raw_ds()

In [None]:
""" 
Load dataset into a dataloader.
"""
from utils.dataset import ReconstructableTextDataset, stack_collate
from torch.utils.data import DataLoader

# Evaluation loader: the sampled documents are wrapped in the project dataset class (given the tokenizer and
# max_length = 512) and served in their existing order, 16 per batch, collated by stack_collate.
test_dl = DataLoader(
    dataset = ReconstructableTextDataset(
        [sample['text'] for sample in raw_data],
        tokenizer,
        max_length = 512
    ),
    collate_fn = stack_collate,
    shuffle = False,
    batch_size = 16
)

## Ablation tests

In [None]:
@torch.no_grad()
def run_all_with_ablation(model, layers_to_ablate: list, topk_to_ablate: list, renorm: bool, verbose = False):
    """
    Run forward passes over every batch of `test_dl` with experts ablated (set to 0), where the experts are
    identified by top-k position, and aggregate the language-modeling loss.

    Params:
        @model: The model to run forward passes on.
        @layers_to_ablate: The layers to ablate, 0-indexed.
        @topk_to_ablate: The top-k positions to ablate, 0-indexed.
        @renorm: Whether to renormalize the expert weights after ablation.
        @verbose: Whether to show a progress bar and, for the first batch, print sample continuations
            together with the batch token count and perplexity.

    Returns:
        A dict with:
        - `total_loss`: Sum over batches of (mean batch loss * number of scored targets).
        - `total_tokens`: Number of target tokens that contributed to the loss.
        - `avg_loss`: `total_loss / total_tokens`.
        - `avg_ppl`: `exp(avg_loss)`.
    """
    total_loss = 0
    total_tokens = 0

    for batch_ix, batch in tqdm(enumerate(test_dl), total = len(test_dl), disable = not verbose):

        input_ids = batch['input_ids'].to(main_device)
        attention_mask = batch['attention_mask'].to(main_device)

        output = run_model_with_ablation(model, input_ids, attention_mask, layers_to_ablate = layers_to_ablate, topk_to_ablate = topk_to_ablate, renorm = renorm)

        # -100 is the ignore index of the loss, so padded positions do not contribute
        labels = torch.where(input_ids == tokenizer.pad_token_id, torch.tensor(-100), input_ids)
        batch_loss = ForCausalLMLoss(output['logits'], labels, model.config.vocab_size).detach().cpu().item()
        # ForCausalLMLoss shifts the labels by one position and averages over the non-ignored *shifted*
        # targets, so column 0 is never scored. Count the same set of targets here; counting the unshifted
        # labels over-weights batches whose rows carry no left padding.
        token_count = (labels[:, 1:] != -100).sum().item()

        # Check no bugs by validating output/perplexity
        if batch_ix == 0 and verbose:
            for i in range(min(5, input_ids.size(0))):
                decoded_input = tokenizer.decode(input_ids[i, :], skip_special_tokens = True)
                # Greedy next token predicted from the final position of the row
                next_token_id = torch.argmax(output['logits'][i, -1, :]).item()
                print('---------\n' + decoded_input + colored(tokenizer.decode([next_token_id], skip_special_tokens = False).replace('\n', '<lb>'), 'green'))
            print(f"Tokens: {token_count:d} | PPL: {torch.exp(torch.tensor(batch_loss)).item():.2f}")

        # batch_loss is a per-target mean, so re-weight by the target count before summing across batches
        total_loss += batch_loss * token_count
        total_tokens += token_count
    
    avg_loss = total_loss/total_tokens
    avg_ppl = torch.exp(torch.tensor(avg_loss)).item()

    return {
        'total_loss': total_loss,
        'total_tokens': total_tokens,
        'avg_loss': avg_loss,
        'avg_ppl': avg_ppl
    }


# (number of layer indices, number of top-k positions) swept for each supported model prefix.
# NOTE(review): the published Qwen1.5-MoE-A2.7B config appears to list 24 hidden layers; the value 25 is
# kept unchanged here - confirm whether the extra index is intentional.
ablation_dims_by_prefix = {
    'olmoe': (16, 8),
    'qwen1.5moe': (25, 4),
    'dsv2': (26, 6),
    'moonlight': (26, 6),
    'qwen3moe': (48, 8)
}

# Fail loudly on an unknown prefix rather than leaving all_layers/all_topk undefined
# (or silently stale from an earlier run of this cell with a different model).
if model_prefix not in ablation_dims_by_prefix:
    raise ValueError(f"No ablation configuration defined for model prefix '{model_prefix}'")

n_layers, n_topk = ablation_dims_by_prefix[model_prefix]
all_layers = list(range(n_layers))
all_topk = list(range(n_topk))

# Baseline: no layers and no top-k positions ablated
base_result = run_all_with_ablation(model, [], [], False)
base_result

## Basic topk ablations

In [None]:
"""
Ablating a single top-k position at a time
"""
# Ablate one top-k position at a time across all layers: first without, then with,
# renormalization of the expert weights after ablation.
for banner, use_renorm in (('----- No Renorm -----', False), ('----- Renorm -----', True)):
    print(banner)
    for ablation_k in all_topk:
        ablate_res = run_all_with_ablation(model, all_layers, [ablation_k], use_renorm)
        print(f"Ablated topk={ablation_k} => PPL: {ablate_res['avg_ppl']:.2f}")

In [None]:
"""
Ablating each top-k position together with every position after it (topk >= k)
"""
# For each starting position k, ablate positions k..max(all_topk) together across all layers:
# first without, then with, renormalization of the expert weights after ablation.
for banner, use_renorm in (('----- No Renorm -----', False), ('----- Renorm -----', True)):
    print(banner)
    for ablation_k in all_topk:
        ablate_res = run_all_with_ablation(
            model,
            all_layers,
            list(range(ablation_k, max(all_topk) + 1)),
            use_renorm
        )
        print(f"Ablated topk>={ablation_k} => PPL: {ablate_res['avg_ppl']:.2f}")

## Layer-targeted

In [None]:
# Layers 4-12 only: for each starting position k, ablate top-k positions k..max(all_topk) without renormalization.
for ablation_k in all_topk:
    ablate_res = run_all_with_ablation(
        model,
        layers_to_ablate = list(range(4, 13)),
        topk_to_ablate = [*range(ablation_k, max(all_topk) + 1)],
        renorm = False
    )

    print(f"Ablated topk>={ablation_k} => PPL: {ablate_res['avg_ppl']:.2f}")

In [None]:
# Layers 0-2 only: for each starting position k, ablate top-k positions k..max(all_topk) without renormalization.
for ablation_k in all_topk:
    ablate_res = run_all_with_ablation(
        model,
        layers_to_ablate = list(range(3)),
        topk_to_ablate = [*range(ablation_k, max(all_topk) + 1)],
        renorm = False
    )

    print(f"Ablated topk>={ablation_k} => PPL: {ablate_res['avg_ppl']:.2f}")

In [None]:
# Layers 13-15 only: for each starting position k, ablate top-k positions k..max(all_topk) without renormalization.
for ablation_k in all_topk:
    ablate_res = run_all_with_ablation(
        model,
        layers_to_ablate = list(range(13, 16)),
        topk_to_ablate = [*range(ablation_k, max(all_topk) + 1)],
        renorm = False
    )

    print(f"Ablated topk>={ablation_k} => PPL: {ablate_res['avg_ppl']:.2f}")

In [None]:
# Layers 10-12 only: for each starting position k, ablate top-k positions k..max(all_topk) without renormalization.
for ablation_k in all_topk:
    ablate_res = run_all_with_ablation(
        model,
        layers_to_ablate = list(range(10, 13)),
        topk_to_ablate = [*range(ablation_k, max(all_topk) + 1)],
        renorm = False
    )

    print(f"Ablated topk>={ablation_k} => PPL: {ablate_res['avg_ppl']:.2f}")

In [None]:
# Layers 4-6 only: for each starting position k, ablate top-k positions k..max(all_topk) without renormalization.
for ablation_k in all_topk:
    ablate_res = run_all_with_ablation(
        model,
        layers_to_ablate = list(range(4, 7)),
        topk_to_ablate = [*range(ablation_k, max(all_topk) + 1)],
        renorm = False
    )

    print(f"Ablated topk>={ablation_k} => PPL: {ablate_res['avg_ppl']:.2f}")