## Init

In [None]:
"""
Imports
"""
import torch
from utils.memory import check_memory, profile_memory, clear_all_cuda_memory
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import pandas as pd
import numpy as np
from scipy import stats
import plotly.express as px 
from tqdm import tqdm
from utils.store_topk import convert_topk_to_df

# Device that tokenized inputs (and manually built position ids) are moved to in the cells below.
main_device = 'cuda:0'
# Seed constant; not referenced by any of the visible cells (NOTE(review): confirm whether it is used elsewhere).
seed = 1234
# Project helpers imported from utils.memory; presumably they free cached CUDA memory and
# report current usage - confirm against utils/memory.py.
clear_all_cuda_memory()
check_memory()

## Load base model

In [None]:
"""
Load the base model
"""
# Hugging Face Hub id of the checkpoint under study.
hf_model_id = 'deepseek-ai/DeepSeek-V2-Lite'

# The tokenizer is told not to add BOS/EOS itself: prompts are rendered through the chat template
# later on, and an appended EOS would end the assistant turn that the model is asked to continue.
tokenizer = AutoTokenizer.from_pretrained(hf_model_id, add_eos_token = False, add_bos_token = False, padding_side = 'left')
# Place the model on `main_device` (rather than the implicit current CUDA device) so it is guaranteed
# to live on the same device that every input tensor in this notebook is moved to.
model = AutoModelForCausalLM.from_pretrained(hf_model_id, torch_dtype = torch.bfloat16, trust_remote_code = True).to(main_device).eval()

check_memory()

In [None]:
"""
Test a forward pass with hooks needed to extract topk weights
"""
# Hooks needed: https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/main/modeling_deepseek.py

def attach_moe_gate_hooks(model):
    """
    Register a forward hook on the gate of every MoE layer in `model` so that, after a forward pass,
    the routing decisions of each MoE layer can be inspected.

    A layer is treated as dense (and skipped) when the type name of its `mlp` contains 'DeepseekV2MLP';
    every other layer is expected to expose `mlp.gate`.

    Params:
        @model: A causal LM exposing `model.model.layers`, each layer having an `mlp` attribute.

    Returns:
        all_expert_ids: A list filled at runtime. Each gate call appends one CPU tensor of expert ids
          (BN x topk), with the topk columns ordered by descending routing weight.
        all_expert_weights: A list filled at runtime with the matching CPU tensors of routing weights
          (BN x topk), sorted descending. Entry i lines up with all_expert_ids[i].
        handles: A dictionary of {layer_index: hook_handle}; call `.remove()` on each handle to detach.

    Entries are appended in call order, so after ONE forward pass entry i belongs to the i-th hooked layer.
    Running further forward passes without detaching keeps appending to the same lists.
    """
    all_expert_ids = []
    all_expert_weights = []
    handles = {}

    def gate_forward_hook(module, args, output):
        """
        Runs after the gate's forward. `output` is unpacked as (topk_ix, topk_weight, aux_loss),
        where topk_ix and topk_weight are both BN x topk.
        """
        topk_ix, topk_weight, _ = output
        # Re-order every row by descending weight, and carry the expert ids along with the weights.
        sorted_w, sorted_idx = topk_weight.sort(dim=-1, descending=True)
        sorted_ix = torch.gather(topk_ix, dim=-1, index=sorted_idx)

        # Detach and move to CPU so stored results do not pin accelerator memory.
        all_expert_ids.append(sorted_ix.detach().cpu())
        all_expert_weights.append(sorted_w.detach().cpu())

    for layer_ix, layer in enumerate(model.model.layers):
        # Dense layers (layer 0 in this model) have no gate to hook.
        if 'DeepseekV2MLP' in str(type(layer.mlp)):
            continue
        # Tag the gate so it is possible to tell which layer it belongs to.
        layer.mlp.gate._layer_id = layer_ix
        handles[layer_ix] = layer.mlp.gate.register_forward_hook(gate_forward_hook)

    return all_expert_ids, all_expert_weights, handles

def test_forward_pass_with_hooks():
    """
    Smoke test: attach the gate hooks, run one padded two-sentence batch through the model, and print
    how many routing tensors were captured along with the first layer's ids and weights.

    The hooks are always detached afterwards, even if the forward pass raises, so a failure here cannot
    leave stale hooks on the shared global `model`.
    """
    all_expert_ids, all_expert_weights, hook_handles = attach_moe_gate_hooks(model)

    try:
        inputs = tokenizer(['Hi this is dog', 'Where is the beef'], return_tensors = 'pt', padding = 'max_length', truncation = True, max_length = 512).to(model.device)
        with torch.no_grad():
            _ = model(**inputs)
    finally:
        for h in hook_handles.values():
            h.remove()

    # One entry is appended per hooked gate call, so both lengths should match.
    print(f"all_expert_ids is of length {len(all_expert_ids)}")
    print(f"all_expert_weights is of length {len(all_expert_weights)}")

    print(f"all_expert_ids[0] shape = {all_expert_ids[0].shape}") # should be BN x topk
    print(f"all_expert_weights[0] shape = {all_expert_weights[0].shape}") # should be BN x topk

    print(f"all_expert_ids[0]: {all_expert_ids[0]}")
    print(f"all_expert_weights[0]: {all_expert_weights[0]}")

test_forward_pass_with_hooks()

## Get MMLU

In [None]:
""""
Get MMLU data and domains to test
"""
# Fetch the test split of the combined ('all') MMLU configuration and peek at one record.
raw_mmlu_ds = load_dataset("cais/mmlu", 'all', split = 'test')
print(raw_mmlu_ds[0])

# (To see how many questions each subject has, count the rows per value of the 'subject' field.)

# Final domain name => list of MMLU source subjects that are pooled into that domain.
mmlu_domain_mappings = dict(
    math = ['elementary_mathematics'],
    cs = ['high_school_computer_science', 'college_computer_science'], # 100 each
    history = ['high_school_world_history'],
    # psych = ['high_school_psychology'],
    chemistry = ['high_school_chemistry'],
    biology = ['high_school_biology'],
)

# Now let's put the MMLU questions into a list grouped by domain
def group_mmlu_ds(raw_ds, domain_map, max_questions_per_domain):
    """
    Group raw MMLU records into domains.

    Params:
        @raw_ds: Iterable of records with keys 'subject', 'question', 'choices' and 'answer' (an int index).
        @domain_map: Dict of {domain: [source_subject, ...]}. Records whose subject is not listed are dropped.
        @max_questions_per_domain: Cap on questions kept per domain; the first ones encountered are kept.

    Returns:
        A list with one dict per domain, in `domain_map` order:
        {'domain': str, 'sources': [...], 'questions': [{'question', 'choices', 'answer_index', 'answer_char'}, ...]}
    """
    # Invert the mapping from the `domain_map` argument (not a module-level global) so the function
    # honours whatever mapping the caller passes in.
    source_to_domain = {source: domain for domain, sources in domain_map.items() for source in sources}
    questions_by_domain = {domain: [] for domain in domain_map}

    for q in raw_ds:
        if q['subject'] not in source_to_domain:
            continue
        bucket = questions_by_domain[source_to_domain[q['subject']]]
        if len(bucket) >= max_questions_per_domain:
            continue
        bucket.append({
            'question': q['question'],
            'choices': q['choices'],
            'answer_index': q['answer'],
            'answer_char': chr(65 + q['answer'])  # 0 -> 'A', 1 -> 'B', ...
        })

    return [
        {'domain': domain, 'sources': sources, 'questions': questions_by_domain[domain]}
        for domain, sources in domain_map.items()
    ]

# Group the raw questions by domain, keeping at most 200 questions per domain.
mmlu_ds = group_mmlu_ds(raw_mmlu_ds, mmlu_domain_mappings, 200)
# Bare expression so the notebook cell displays the first domain entry as a sanity check.
mmlu_ds[0]


In [None]:
"""
Create function to map MMLU data into string
"""
def prep_question(question, choices):
    """
    Render one multiple-choice question as prompt text.

    The output is a "Question: ..." line, a "Choices:" line, and then one "(A) ...", "(B) ..." line
    per entry of `choices`; every line, including the last, ends with a newline.
    """
    header = f"Question: {question}\nChoices:\n"
    option_lines = [f"({chr(65 + position)}) {text}\n" for position, text in enumerate(choices)]
    return header + "".join(option_lines)

def _first_question_of(subject):
    """
    Return the first record of `raw_mmlu_ds` whose 'subject' equals `subject`.

    Stops at the first hit instead of filtering the whole dataset into a list. Raises IndexError when
    the subject does not occur, which is what indexing an empty filtered list with [0] would raise.
    """
    for q in raw_mmlu_ds:
        if q['subject'] == subject:
            return q
    raise IndexError(f"No MMLU question found for subject '{subject}'")

# Few-shot examples come from subjects that are not part of the evaluated domains.
fs_ex = [_first_question_of(subject) for subject in ('anatomy', 'machine_learning', 'astronomy')]

# System instruction followed by one (user question, assistant answer) pair per few-shot example.
base_prompt = [
    {'role': 'system', 'content': 'You will be provided with a multiple-choice question, as well as a list of possible answer choices. Respond exactly with: "The correct answer is {X}", substituting in X with the code for the correct choice.'}
] + [
    message
    for ex in fs_ex
    for message in (
        {'role': 'user', 'content': prep_question(ex['question'], ex['choices'])},
        {'role': 'assistant', 'content': 'The correct answer is ' + chr(65 + ex['answer'])},
    )
]

# Render (without tokenizing) to eyeball the final prompt text produced by the chat template.
print(tokenizer.apply_chat_template(base_prompt, tokenize = False, add_generation_prompt = False, continue_final_message = True))

## Save topk for later analysis

In [None]:
@torch.no_grad()
def run_model_and_return_topk(mmlu_ds):
    """
    Run the unmodified model on every question and collect its answers together with the MoE routing.

    For each question the few-shot `base_prompt` is extended with the question and the assistant prefix
    'The correct answer is'; the arg-max token at the last position is decoded and scanned for the first
    character that is a valid choice letter.

    Params:
        @mmlu_ds: List of domain dicts with keys 'domain' and 'questions', where each question has
          'question', 'choices' and 'answer_char'.

    Returns:
        Dict with 'question_df' (one row per question), 'topk_df' (routing rows for all questions),
        and the totals 'n_correct', 'n_total', 'n_invalid', 'accuracy'.
        Domains without questions are reported and skipped instead of dividing by zero.
    """
    final_results = []

    for this_domain in mmlu_ds:

        domain_questions = this_domain['questions']
        if not domain_questions:
            # Nothing to score; an empty domain would otherwise break the accuracy and the concat below.
            print(f'{this_domain["domain"]} | No questions, skipping')
            continue
        domain_results = []

        for question_ix, q in enumerate(domain_questions):

            input_prompt = tokenizer.apply_chat_template(
                base_prompt + [{'role': 'user', 'content': prep_question(q['question'], q['choices'])}, {'role': 'assistant', 'content': 'The correct answer is'}],
                tokenize = False, add_generation_prompt = False, continue_final_message = True # Otherwise appends eos token
            )
            inputs = tokenizer(input_prompt, return_tensors = 'pt').to(main_device)

            all_expert_ids, all_expert_weights, hook_handles = attach_moe_gate_hooks(model)
            try:
                outputs = model(input_ids = inputs["input_ids"], attention_mask = inputs["attention_mask"])
            finally:
                # Always detach, so a failed forward pass cannot leave hooks stacked on the model.
                for h in hook_handles.values():
                    h.remove()

            topk_df = convert_topk_to_df(inputs["input_ids"], all_expert_ids, all_expert_weights).assign(question_ix = question_ix).drop(columns = 'sequence_ix')
            topk_df = topk_df[topk_df['token_id'] != tokenizer.pad_token_id] # Drop rows that belong to padding tokens

            predicted_text = tokenizer.decode([torch.argmax(outputs['logits'][0, -1, :]).item()]).strip()
            # Valid answers are 'A', 'B', ... up to the number of choices; take the first matching character.
            valid_letters = {chr(65 + i) for i in range(len(q['choices']))}
            predicted_letter = next((c.upper() for c in predicted_text if c.upper() in valid_letters), None)

            domain_results.append({
                'question_ix': question_ix,
                'model_output': predicted_text,
                'model_choice': predicted_letter,
                'correct_choice': q['answer_char'],
                'is_correct': 1 if predicted_letter == q['answer_char'] else 0,
                'is_valid': 1 if predicted_letter is not None else 0,
                'topk_df': topk_df
            })

        n_total = len(domain_results)
        n_correct = len([x for x in domain_results if x['is_correct'] == 1])
        n_invalid = len([x for x in domain_results if x['is_valid'] == 0])
        print(f'{this_domain["domain"]} | Correct: {str(n_correct)} | Incorrect: {str(n_total - n_correct)} | Invalid: {str(n_invalid)} | Accuracy: {(n_correct / (n_total)) * 100:.1f}%')

        final_results.append({
            'domain': this_domain['domain'],
            'question_df': pd.DataFrame([{k: v for k, v in x.items() if k != 'topk_df'} for x in domain_results]).assign(domain = this_domain['domain']),
            'topk_df': pd.concat([x['topk_df'] for x in domain_results]).assign(domain = this_domain['domain']),
            'n_correct': n_correct,
            'n_total': n_total,
            'n_invalid': n_invalid,
        })


    return {
        'question_df': pd.concat([d['question_df'] for d in final_results]),
        'topk_df': pd.concat([d['topk_df'] for d in final_results]),
        'n_correct': sum([d['n_correct'] for d in final_results]),
        'n_total': sum([d['n_total'] for d in final_results]),
        'n_invalid': sum([d['n_invalid'] for d in final_results]),
        'accuracy': sum([d['n_correct'] for d in final_results])/sum([d['n_total'] for d in final_results])
        }

baseline = run_model_and_return_topk(mmlu_ds)
display(baseline['topk_df'])
print(baseline['accuracy'])

In [None]:
# Token <-> id lookup table, ordered by token id. The byte-level marker 'Ġ' is shown as a plain space.
vocab_map = (
    pd.DataFrame([
        {"token": token.replace('Ġ', ' '), "token_id": token_id}
        for token, token_id in tokenizer.get_vocab().items()
    ])
    .sort_values(by = 'token_id')
    .reset_index(drop = True)
)
display(vocab_map)

# Per-question answers from the baseline run.
all_answers = baseline['question_df']
display(all_answers)

# Routing rows from the baseline run, with the weight column rounded to 4 decimals.
all_topks = baseline['topk_df'].assign(weight = lambda frame: np.around(frame['weight'], 4))
display(all_topks)

In [None]:
# Write the three analysis tables to CSV files in the working directory, without the index column.
for frame, csv_path in (
    (vocab_map, 'vocab.csv'),
    (all_answers, 'all_answers.csv'),
    (all_topks, 'all_topks.csv'),
):
    frame.to_csv(csv_path, index = False)

## Ablation prep

In [None]:
# Build the few-shot prompt for the very first question of the first domain.
first_question = mmlu_ds[0]['questions'][0]
input_prompt = tokenizer.apply_chat_template(
    base_prompt + [
        {'role': 'user', 'content': prep_question(first_question['question'], first_question['choices'])},
        {'role': 'assistant', 'content': 'The correct answer is'},
    ],
    tokenize = False,
    add_generation_prompt = False,
    continue_final_message = True, # Otherwise appends eos token
)
inputs = tokenizer(input_prompt, return_tensors = 'pt').to(main_device)

with torch.no_grad():
    # Reference 1: the model's own end-to-end forward pass.
    outputs = model(input_ids = inputs["input_ids"], attention_mask = inputs['attention_mask'])
    print(outputs['logits'])

    # Reference 2: the same computation split into the decoder stack followed by the LM head.
    outputs = model.model(input_ids = inputs["input_ids"], attention_mask = inputs['attention_mask'])
    hidden_state = outputs[0]
    logits = model.lm_head(hidden_state)
    print(logits)

from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

# Reference 3: re-implement the decoder forward pass layer by layer, with the MoE gate and expert
# dispatch written out inline, so routing can later be modified. The printed logits are meant to be
# compared by eye with the two reference passes printed earlier in this cell.
with torch.no_grad():
    B, N = inputs["input_ids"].shape[:2]
    # One row of positions 0..N-1, shared by every sequence in the batch.
    position_ids = torch.arange(0, N, dtype=torch.long, device = main_device).unsqueeze(0)

    inputs_embeds = model.model.embed_tokens(inputs["input_ids"])
    # Expand the 2D padding mask into the 4D causal mask that the attention modules are called with
    # (the trailing 0 is the past key/value length: no cache is used here).
    attention_mask = _prepare_4d_causal_attention_mask(inputs['attention_mask'], (B, N), inputs_embeds, 0,)

    hidden_state = inputs_embeds
    # Filled with one entry per MoE layer, in layer order.
    all_topk_experts = []
    all_topk_weights = []

    for decoder_layer in model.model.layers:
        # Equivalent single call: decoder_layer(hidden_state, attention_mask = attention_mask, position_ids = position_ids,)
        residual = hidden_state
        hidden_state = decoder_layer.input_layernorm(hidden_state)
        # Self Attention
        hidden_state, self_attn_weights, present_key_value = decoder_layer.self_attn(hidden_states = hidden_state, attention_mask = attention_mask, position_ids = position_ids)
        hidden_state = residual + hidden_state
        # Fully Connected
        residual = hidden_state
        hidden_state = decoder_layer.post_attention_layernorm(hidden_state)
        ## MLP
        if 'DeepseekV2MLP' in str(type(decoder_layer.mlp)):
            # Dense layer: call the MLP module as-is.
            hidden_state = decoder_layer.mlp(hidden_state)
        else:
            # MoE layer: keep the normed input for the shared experts and the shape to restore at the end.
            identity = hidden_state
            orig_shape = hidden_state.shape
            ### Gate, written out; replaces: topk_idx, topk_weight, aux_loss = decoder_layer.mlp.gate(hidden_state)
            bsz, seq_len, h = hidden_state.shape
            # Flatten to one row per token: [bsz * seq_len, h].
            moe_hidden_state = hidden_state.view(-1, h)
            # Router scores are computed in float32 and softmaxed over the experts.
            logits = torch.nn.functional.linear(moe_hidden_state.type(torch.float32), decoder_layer.mlp.gate.weight.type(torch.float32), None)
            scores = logits.softmax(dim=-1, dtype=torch.float32)
            # sorted=False: the top_k columns of each row come back in no guaranteed order.
            topk_weight, topk_idx = torch.topk(scores, k=decoder_layer.mlp.gate.top_k, dim=-1, sorted=False)
            topk_weight = topk_weight * decoder_layer.mlp.gate.routed_scaling_factor
            ####
            hidden_state = hidden_state.view(-1, hidden_state.shape[-1])
            # NOTE(review): flat_topk_idx is never read afterwards.
            flat_topk_idx = topk_idx.view(-1)
            ### Expert dispatch (the "moe infer" step)
            x = hidden_state
            topk_ids = topk_idx
            # cnts[token, expert] = 1 where the token is routed to that expert; column sums = load per expert.
            cnts = topk_ids.new_zeros((topk_ids.shape[0], len(decoder_layer.mlp.experts)))
            cnts.scatter_(1, topk_ids, 1)
            tokens_per_expert = cnts.sum(dim=0)
            # Order all (token, slot) pairs by expert id; integer division maps a flat slot back to its token row.
            idxs = topk_ids.view(-1).argsort()
            sorted_tokens = x[idxs // topk_ids.shape[1]]
            tokens_per_expert = tokens_per_expert.cpu().numpy()
            outputs = []
            start_idx = 0
            # Walk the expert-sorted tokens in contiguous slices, one slice per expert that received tokens.
            for i, num_tokens in enumerate(tokens_per_expert):
                end_idx = start_idx + num_tokens
                if num_tokens == 0:
                    continue
                expert = decoder_layer.mlp.experts[i + decoder_layer.mlp.ep_rank * decoder_layer.mlp.experts_per_rank]
                tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
                expert_out = expert(tokens_for_this_expert)
                outputs.append(expert_out)
                start_idx = end_idx
            outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
            # Scatter the expert outputs back into their original (token, slot) order.
            new_x = torch.empty_like(outs)
            new_x[idxs] = outs
            # Weight each slot's output by its routing weight and sum over the top_k slots of every token.
            final_out = (new_x.view(*topk_ids.shape, -1).type(topk_weight.dtype).mul_(topk_weight.unsqueeze(dim=-1)).sum(dim=1).type(new_x.dtype))
            ###
            y = final_out.view(*orig_shape)
            # Shared experts, when configured, see every token and are added on top of the routed output.
            if decoder_layer.mlp.config.n_shared_experts is not None:
                y = y + decoder_layer.mlp.shared_experts(identity)
            hidden_state = y

            all_topk_experts.append(topk_ids)
            all_topk_weights.append(topk_weight)

        hidden_state = residual + hidden_state

    # Final norm + LM head, then print the logits and the size/shape of the collected routing data.
    hidden_state = model.model.norm(hidden_state)
    logits = model.lm_head(hidden_state)
    print(logits)
    print(len(all_topk_experts))
    print(all_topk_experts[0].shape)
    print(all_topk_weights[0].shape)

In [None]:
@torch.no_grad()
def run_model_no_ablation(input_ids, attention_mask):
    """
    Manual, unmodified forward pass through the global `model`, with the MoE gate and expert dispatch
    written out inline. Serves as the baseline that the ablation variant is derived from.

    Params:
        @input_ids: B x N tensor of token ids.
        @attention_mask: B x N padding mask; expanded internally to a 4D causal mask.

    Returns:
        Dict with:
          `logits`: output of the LM head for every position.
          `all_topk_experts`: list with one [B*N, top_k] tensor of expert ids per MoE layer.
          `all_topk_weights`: list with the matching [B*N, top_k] routing weights.
        The top_k columns are NOT ordered by weight (topk is called with sorted=False).
    """
    B, N = input_ids.shape[:2]
    # Positions live on the same device as the inputs rather than on a hard-coded one.
    position_ids = torch.arange(0, N, dtype=torch.long, device = input_ids.device).unsqueeze(0)

    inputs_embeds = model.model.embed_tokens(input_ids)
    # Trailing 0 = past key/value length (no cache).
    attention_mask = _prepare_4d_causal_attention_mask(attention_mask, (B, N), inputs_embeds, 0,)

    hidden_state = inputs_embeds
    all_topk_experts = []
    all_topk_weights = []
    for decoder_layer in model.model.layers:
        # Equivalent single call: decoder_layer(hidden_state, attention_mask = attention_mask, position_ids = position_ids,)
        residual = hidden_state
        hidden_state = decoder_layer.input_layernorm(hidden_state)
        # Self attention; the attention weights and cache entry it also returns are not needed.
        hidden_state, _, _ = decoder_layer.self_attn(hidden_states = hidden_state, attention_mask = attention_mask, position_ids = position_ids)
        hidden_state = residual + hidden_state
        # Feed-forward part
        residual = hidden_state
        hidden_state = decoder_layer.post_attention_layernorm(hidden_state)
        if 'DeepseekV2MLP' in str(type(decoder_layer.mlp)):
            # Dense layer: call the MLP module as-is.
            hidden_state = decoder_layer.mlp(hidden_state)
        else:
            moe = decoder_layer.mlp
            identity = hidden_state
            orig_shape = hidden_state.shape
            # Flatten to one row per token: [B*N, h].
            x = hidden_state.view(-1, orig_shape[-1])

            ### Gate, written out; replaces: topk_idx, topk_weight, aux_loss = moe.gate(hidden_state)
            router_logits = torch.nn.functional.linear(x.type(torch.float32), moe.gate.weight.type(torch.float32), None)
            scores = router_logits.softmax(dim=-1, dtype=torch.float32)
            topk_weight, topk_ids = torch.topk(scores, k=moe.gate.top_k, dim=-1, sorted=False)
            topk_weight = topk_weight * moe.gate.routed_scaling_factor

            ### Expert dispatch
            # cnts[token, expert] = 1 where the token is routed to that expert; column sums = load per expert.
            cnts = topk_ids.new_zeros((topk_ids.shape[0], len(moe.experts)))
            cnts.scatter_(1, topk_ids, 1)
            tokens_per_expert = cnts.sum(dim=0).cpu().numpy()
            # Order all (token, slot) pairs by expert id; integer division maps a flat slot back to its token row.
            idxs = topk_ids.view(-1).argsort()
            sorted_tokens = x[idxs // topk_ids.shape[1]]
            expert_outputs = []
            start_idx = 0
            # Contiguous slices of the expert-sorted tokens, one slice per expert that received tokens.
            for i, num_tokens in enumerate(tokens_per_expert):
                if num_tokens == 0:
                    continue
                end_idx = start_idx + num_tokens
                expert = moe.experts[i + moe.ep_rank * moe.experts_per_rank]
                expert_outputs.append(expert(sorted_tokens[start_idx:end_idx]))
                start_idx = end_idx
            outs = torch.cat(expert_outputs, dim=0) if len(expert_outputs) else sorted_tokens.new_empty(0)
            # Put the expert outputs back into their original (token, slot) order.
            new_x = torch.empty_like(outs)
            new_x[idxs] = outs
            # Weight each slot's output by its routing weight and sum over the top_k slots of every token.
            final_out = (new_x.view(*topk_ids.shape, -1).type(topk_weight.dtype).mul_(topk_weight.unsqueeze(dim=-1)).sum(dim=1).type(new_x.dtype))

            y = final_out.view(*orig_shape)
            # Shared experts, when configured, see every token and are added on top of the routed output.
            if moe.config.n_shared_experts is not None:
                y = y + moe.shared_experts(identity)
            hidden_state = y

            all_topk_experts.append(topk_ids)
            all_topk_weights.append(topk_weight)

        hidden_state = residual + hidden_state

    hidden_state = model.model.norm(hidden_state)
    logits = model.lm_head(hidden_state)
    return {'logits': logits, 'all_topk_experts': all_topk_experts, 'all_topk_weights': all_topk_weights}
    

@torch.no_grad()
def evaluate_with_ablation(run_forward_fn, return_topk, *args, **kwargs):
    """
    Evaluate a (possibly modified) forward function on every question of the global `mmlu_ds`.

    Params:
        @run_forward_fn: A function that returns a model forward pass with inputs input_ids, attention_mask, and optional *args/**kwargs.
          The function must return a dict with key `logits`.
        @return_topk: Whether to return the expert IDs and weights as well. If truthy, `run_forward_fn` must also return keys
         `all_topk_experts` and `all_topk_weights`.
        @*args/**kwargs: Additional arguments to pass to `run_forward_fn`.

    Returns:
        Dict with 'question_df' (one row per question), 'topk_df' (routing rows, or None when
        `return_topk` is falsy), and the totals 'n_correct', 'n_total', 'n_invalid', 'accuracy'.
        Domains without questions are reported and skipped instead of dividing by zero.
    """
    final_results = []

    for this_domain in mmlu_ds:

        domain_questions = this_domain['questions']
        if not domain_questions:
            # Nothing to score; an empty domain would otherwise break the accuracy and the concat below.
            print(f'{this_domain["domain"]} | No questions, skipping')
            continue
        domain_results = []

        for question_ix, q in enumerate(domain_questions):

            input_prompt = tokenizer.apply_chat_template(
                base_prompt + [{'role': 'user', 'content': prep_question(q['question'], q['choices'])}, {'role': 'assistant', 'content': 'The correct answer is'}],
                tokenize = False, add_generation_prompt = False, continue_final_message = True # Otherwise appends eos token
            )
            inputs = tokenizer(input_prompt, return_tensors = 'pt').to(main_device)

            outputs = run_forward_fn(inputs['input_ids'], inputs['attention_mask'], *args, **kwargs)

            topk_df = None
            if return_topk:
                topk_df = convert_topk_to_df(inputs["input_ids"], outputs['all_topk_experts'], outputs['all_topk_weights']).assign(question_ix = question_ix).drop(columns = 'sequence_ix')
                topk_df = topk_df[topk_df['token_id'] != tokenizer.pad_token_id] # Drop rows that belong to padding tokens

            predicted_text = tokenizer.decode([torch.argmax(outputs['logits'][0, -1, :]).item()]).strip()
            # Valid answers are 'A', 'B', ... up to the number of choices; take the first matching character.
            valid_letters = {chr(65 + i) for i in range(len(q['choices']))}
            predicted_letter = next((c.upper() for c in predicted_text if c.upper() in valid_letters), None)

            domain_results.append({
                'question_ix': question_ix,
                'model_output': predicted_text,
                'model_choice': predicted_letter,
                'correct_choice': q['answer_char'],
                'is_correct': 1 if predicted_letter == q['answer_char'] else 0,
                'is_valid': 1 if predicted_letter is not None else 0,
                'topk_df': topk_df
            })

        n_total = len(domain_results)
        n_correct = len([x for x in domain_results if x['is_correct'] == 1])
        n_invalid = len([x for x in domain_results if x['is_valid'] == 0])
        print(f'{this_domain["domain"]} | Correct: {str(n_correct)} | Incorrect: {str(n_total - n_correct)} | Invalid: {str(n_invalid)} | Accuracy: {(n_correct / (n_total)) * 100:.1f}%')

        final_results.append({
            'domain': this_domain['domain'],
            'question_df': pd.DataFrame([{k: v for k, v in x.items() if k != 'topk_df'} for x in domain_results]).assign(domain = this_domain['domain']),
            'topk_df': pd.concat([x['topk_df'] for x in domain_results]).assign(domain = this_domain['domain']) if return_topk else None,
            'n_correct': n_correct,
            'n_total': n_total,
            'n_invalid': n_invalid,
        })

    return {
        'question_df': pd.concat([d['question_df'] for d in final_results]),
        'topk_df': pd.concat([d['topk_df'] for d in final_results]) if return_topk else None,
        'n_correct': sum([d['n_correct'] for d in final_results]),
        'n_total': sum([d['n_total'] for d in final_results]),
        'n_invalid': sum([d['n_invalid'] for d in final_results]),
        'accuracy': sum([d['n_correct'] for d in final_results])/sum([d['n_total'] for d in final_results])
        }

evaluate_with_ablation(run_model_no_ablation, False)

## Ablation testing

In [None]:
@torch.no_grad()
def run_model_with_ablation(input_ids, attention_mask, layers_to_ablate = None, topk_to_ablate = None, renorm = False):
    """
    Manual forward pass through the global `model` that zeroes the routing weight of certain
    rank-ordered top-k experts in the specified layers. The ablated experts are still executed;
    their output is simply multiplied by a weight of 0.

    Params:
        @input_ids: B x N tensor of token ids.
        @attention_mask: B x N padding mask; expanded internally to a 4D causal mask.
        @layers_to_ablate: Which layer indices (0-based) should we ablate? Defaults to layers 1..26.
          Dense (non-MoE) layers listed here are left untouched because they have no routing weights.
        @topk_to_ablate: A list of ranks to ablate, e.g. [0] = top-1, [2,3] = 3rd & 4th largest, etc. Defaults to [0].
        @renorm: Whether to rescale the remaining weights to keep the sum unchanged

    Returns:
        Dict with `logits`, `all_topk_experts` (one [B*N, top_k] tensor of expert ids per MoE layer) and
        `all_topk_weights` (the matching routing weights AFTER ablation / renormalisation).
    """
    # None sentinels instead of mutable default arguments; the effective defaults are unchanged.
    ablate_layers = set(range(1, 27)) if layers_to_ablate is None else set(layers_to_ablate)
    ablate_ranks = [0] if topk_to_ablate is None else list(topk_to_ablate)

    B, N = input_ids.shape[:2]
    # Positions live on the same device as the inputs rather than on a hard-coded one.
    position_ids = torch.arange(0, N, dtype=torch.long, device = input_ids.device).unsqueeze(0)

    inputs_embeds = model.model.embed_tokens(input_ids)
    # Trailing 0 = past key/value length (no cache).
    attention_mask = _prepare_4d_causal_attention_mask(attention_mask, (B, N), inputs_embeds, 0,)

    hidden_state = inputs_embeds
    all_topk_experts = []
    all_topk_weights = []
    for layer_ix, decoder_layer in enumerate(model.model.layers):
        # Equivalent single call: decoder_layer(hidden_state, attention_mask = attention_mask, position_ids = position_ids,)
        residual = hidden_state
        hidden_state = decoder_layer.input_layernorm(hidden_state)
        # Self attention; the attention weights and cache entry it also returns are not needed.
        hidden_state, _, _ = decoder_layer.self_attn(hidden_states = hidden_state, attention_mask = attention_mask, position_ids = position_ids)
        hidden_state = residual + hidden_state
        # Feed-forward part
        residual = hidden_state
        hidden_state = decoder_layer.post_attention_layernorm(hidden_state)
        if 'DeepseekV2MLP' in str(type(decoder_layer.mlp)):
            # Dense layer: call the MLP module as-is.
            hidden_state = decoder_layer.mlp(hidden_state)
        else:
            moe = decoder_layer.mlp
            identity = hidden_state
            orig_shape = hidden_state.shape
            # Flatten to one row per token: [B*N, h].
            x = hidden_state.view(-1, orig_shape[-1])

            ### Gate, written out; replaces: topk_idx, topk_weight, aux_loss = moe.gate(hidden_state)
            router_logits = torch.nn.functional.linear(x.type(torch.float32), moe.gate.weight.type(torch.float32), None)
            scores = router_logits.softmax(dim=-1, dtype=torch.float32)
            topk_weight, topk_ids = torch.topk(scores, k=moe.gate.top_k, dim=-1, sorted=True)
            topk_weight = topk_weight * moe.gate.routed_scaling_factor

            ######################## ABLATION
            # topk_weight is [B*N, top_k]
            if layer_ix in ablate_layers:
                # sorted_idx[:, r] is the column of topk_weight that holds the rank-r (r-th largest) weight of each row.
                _, sorted_idx = topk_weight.sort(dim=-1, descending=True)
                row_sum_before = topk_weight.sum(dim=-1, keepdim=True)

                if ablate_ranks:
                    # Zero all requested ranks of all rows in one scatter, instead of a Python loop over
                    # rows that forces a device sync (.item()) per token.
                    topk_weight.scatter_(-1, sorted_idx[:, ablate_ranks], 0.0)

                # Re-scale the remaining top-k weights to keep sum the same
                if renorm:
                    row_sum_after = topk_weight.sum(dim=-1, keepdim=True)
                    # The epsilon keeps rows whose weights were all ablated at zero instead of dividing by zero.
                    scale_factor = row_sum_before / (row_sum_after + 1e-9)
                    topk_weight *= scale_factor
            ######################## The rest matches the un-ablated forward pass

            ### Expert dispatch
            # cnts[token, expert] = 1 where the token is routed to that expert; column sums = load per expert.
            cnts = topk_ids.new_zeros((topk_ids.shape[0], len(moe.experts)))
            cnts.scatter_(1, topk_ids, 1)
            tokens_per_expert = cnts.sum(dim=0).cpu().numpy()
            # Order all (token, slot) pairs by expert id; integer division maps a flat slot back to its token row.
            idxs = topk_ids.view(-1).argsort()
            sorted_tokens = x[idxs // topk_ids.shape[1]]
            expert_outputs = []
            start_idx = 0
            # Contiguous slices of the expert-sorted tokens, one slice per expert that received tokens.
            for i, num_tokens in enumerate(tokens_per_expert):
                if num_tokens == 0:
                    continue
                end_idx = start_idx + num_tokens
                expert = moe.experts[i + moe.ep_rank * moe.experts_per_rank]
                expert_outputs.append(expert(sorted_tokens[start_idx:end_idx]))
                start_idx = end_idx
            outs = torch.cat(expert_outputs, dim=0) if len(expert_outputs) else sorted_tokens.new_empty(0)
            # Put the expert outputs back into their original (token, slot) order.
            new_x = torch.empty_like(outs)
            new_x[idxs] = outs
            # Weight each slot's output by its (possibly ablated) routing weight and sum over the top_k slots.
            final_out = (new_x.view(*topk_ids.shape, -1).type(topk_weight.dtype).mul_(topk_weight.unsqueeze(dim=-1)).sum(dim=1).type(new_x.dtype))

            y = final_out.view(*orig_shape)
            # Shared experts, when configured, see every token and are never ablated here.
            if moe.config.n_shared_experts is not None:
                y = y + moe.shared_experts(identity)
            hidden_state = y

            all_topk_experts.append(topk_ids)
            all_topk_weights.append(topk_weight)

        hidden_state = residual + hidden_state

    hidden_state = model.model.norm(hidden_state)
    logits = model.lm_head(hidden_state)
    return {'logits': logits, 'all_topk_experts': all_topk_experts, 'all_topk_weights': all_topk_weights}
    
# evaluate_with_ablation(run_model_with_ablation, return_topk = False, layers_to_ablate = list(range(1, 6)), topk_to_ablate = [0], renorm = False)

In [None]:
# Layer ranges (end-exclusive) and rank ranges (end-exclusive) that make up the ablation grid.
layers_to_ablate_list = [list(range(lo, hi)) for lo, hi in [(1, 27), (22, 27), (17, 27), (7, 27), (0, 4)]]
topk_to_ablate_list = [list(range(lo, hi)) for lo, hi in [(0, 1), (1, 6), (0, 6)]]

# Evaluate every (layers, ranks) combination without renormalisation and keep the full result of each run.
all_res = []
for layers_to_ablate in layers_to_ablate_list:
    layers_label = ','.join(map(str, layers_to_ablate))
    for topk_to_ablate in topk_to_ablate_list:
        topk_label = ','.join(map(str, topk_to_ablate))
        print('\n\n---------------')
        print(f"Layers ablated: {layers_label}")
        print(f"Topk ablated: {topk_label}")
        res = evaluate_with_ablation(
            run_model_with_ablation,
            return_topk = False,
            layers_to_ablate = layers_to_ablate,
            topk_to_ablate = topk_to_ablate,
            renorm = False,
        )
        print(res['accuracy'])
        all_res.append({'layers_ablated': layers_to_ablate, 'topk_ablated': topk_to_ablate, 'results': res})