In [None]:
# Imports
import torch
from utils.memory import check_memory, profile_memory, clear_all_cuda_memory
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm

# Configuration
main_device = 'cuda:0'
seed = 1234
# `seed` was previously defined but never applied. torch.manual_seed seeds the RNG on all devices,
# so any stochastic step is reproducible (the evaluation below is greedy/argmax, so this is defensive).
torch.manual_seed(seed)
clear_all_cuda_memory()
check_memory()

## Load Model

In [None]:
hf_model_id = 'deepseek-ai/DeepSeek-V2-Lite'

# Left padding: in padded batches the final position of every row is a real token.
tokenizer = AutoTokenizer.from_pretrained(hf_model_id, add_eos_token = False, add_bos_token = False, padding_side = 'left')
# trust_remote_code: DeepSeek-V2 ships its own modeling code (modeling_deepseek.py).
# Place the model on `main_device` (rather than a bare `.cuda()`), which is the device every input is
# moved to below; otherwise changing `main_device` would cause a device mismatch.
model = AutoModelForCausalLM.from_pretrained(hf_model_id, torch_dtype = torch.bfloat16, trust_remote_code = True).to(main_device)
# from_pretrained already returns the model in eval mode; made explicit because the manual
# forward re-implementations further down assume inference behaviour.
model.eval()

check_memory()

In [None]:
# Smoke test: one forward pass with the MoE gate hooks attached.

# Hooks needed: https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/main/modeling_deepseek.py
# Two short prompts, each padded (per the tokenizer's padding side) to exactly 512 tokens.
inputs = tokenizer(
    ['Hi this is dog', 'Where is the beef'],
    return_tensors = 'pt',
    padding = 'max_length',
    truncation = True,
    max_length = 512,
).to(model.device)

def attach_moe_gate_hooks(model):
    """
    Register a forward hook on every MoE gating module of `model` so that, after a forward pass,
    the per-layer routing decisions (`topk_idx`) can be read back.

    A decoder layer is treated as MoE when its `mlp` has a `gate` submodule; dense-MLP layers
    (layer 0 in DeepSeek-V2-Lite) have none and are skipped.

    Args:
        model: a causal LM that exposes its decoder layers as `model.model.layers`.

    Returns:
        all_expert_ids: a list filled at runtime with one detached `topk_idx` tensor per gate call,
            in the order the gates run (so one entry per MoE layer per forward pass).
        handles: a dict {layer_index: hook_handle}. Call `.remove()` on every handle when done;
            otherwise each later forward pass keeps appending to `all_expert_ids`.
    """
    all_expert_ids = []
    handles = {}

    def gate_forward_hook(module, input, output):
        """
        Runs after MoEGate.forward(...). `output` is expected to be the tuple
        (topk_idx, topk_weight, aux_loss); only topk_idx is kept.
        """
        topk_idx, _, _ = output
        all_expert_ids.append(topk_idx.detach())

    for layer_idx, layer in enumerate(model.model.layers):
        gate = getattr(layer.mlp, 'gate', None)
        if gate is None:
            # Dense MLP layer: no router to hook.
            continue
        # Tag the gate with its layer index so it can be identified when inspecting the module.
        gate._layer_id = layer_idx
        handles[layer_idx] = gate.register_forward_hook(gate_forward_hook)

    return all_expert_ids, handles


all_expert_ids, hook_handles = attach_moe_gate_hooks(model)

try:
    with torch.no_grad():
        outputs = model(**inputs)
finally:
    # Always detach the hooks, even if the forward pass raises; otherwise every later forward
    # pass on `model` keeps appending to `all_expert_ids`.
    for layer_idx, h in hook_handles.items():
        h.remove()

# One tensor per hooked (MoE) layer, in forward order.
for topk_idx_tensor in all_expert_ids:
    print(f"topk_idx shape = {topk_idx_tensor.shape}")
    # Expected [B*N, top_k]: the gate flattens batch and sequence (cf. the manual gate re-implementation below).

In [None]:
# Test forward pass, collecting the routing decisions via the same gate hooks as above.

# This model does need the hooks. For comparison, the OLMoE implementation in transformers:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmoe/modeling_olmoe.py
inputs = tokenizer(['Hi this is dog', 'Where is the beef'], return_tensors = 'pt', padding = 'max_length', truncation = True, max_length = 512).to(model.device)

with torch.no_grad():
    all_topk_experts, hook_handles = attach_moe_gate_hooks(model)
    try:
        outputs = model(**inputs)
    finally:
        # Remove the hooks even if the forward pass fails, so they never accumulate on the model.
        for layer_idx, h in hook_handles.items():
            h.remove()

all_topk_experts[0]

## Get MMLU

In [None]:
""""
Get MMLU data and domains to test
"""
from datasets import load_dataset

mmlu_ds = load_dataset("cais/mmlu", 'all', split = 'test')
print(mmlu_ds[0])

# Only retain domains for high school subjects
# domains_to_keep = ['college_biology', 'college_medicine', 'college_computer_science', 'college_mathematics']
domains_to_keep = ['elementary_mathematics', 'high_school_computer_science', 'high_school_world_history', 'high_school_psychology', 'high_school_chemistry', 'high_school_biology']
domains = [x for x in list(set([x['subject'] for x in mmlu_ds])) if x in domains_to_keep]
print(domains)

# Now let's put the MMLU questions into a list gruoped by domain
all_domain_questions = [
    {
        'domain': domain,
        'questions': 
            [
                {'question': q['question'], 'choices': q['choices'], 'answer_index': q['answer'], 'answer_char': chr(65 + q['answer'])}
                for q in mmlu_ds
                if q['subject'] == domain 
            ]
    }
    for domain in tqdm(domains)
]

all_domain_questions[0]

In [None]:
# all_domains = list(set(mmlu_ds['subject']))
# for domain in all_domains:
#     dom_length = len([q for q in mmlu_ds if q['subject'] == domain])
#     print(f"{domain}: {dom_length}")

In [None]:
"""
Create function to map MMLU data into questions
"""
def prep_question(question, choices):
    
    prompt = f"Question: {question}\nChoices:\n"
    
    for i, option in enumerate(choices):
        letter = chr(65 + i)
        prompt += f"({letter}) {option}\n"

    return prompt

# print(prep_question(mmlu_ds[0]['question'], mmlu_ds[0]['choices']))
# print(set(mmlu_ds['subject']))

mmlu_system_prompt = 'You will be provided with a multiple-choice question, as well as a list of possible answer choices. Respond exactly with: "The correct answer is {X}", substituting in X with the code for the correct choice.'

# One few-shot example from each of these subjects, none of which is in `domains_to_keep`, so no
# evaluated question appears in the prompt. `next(...)` stops at the first match instead of
# materialising every question of the subject (three full dataset passes previously).
fs_subjects = ['anatomy', 'machine_learning', 'astronomy']
fs_ex = [next(q for q in mmlu_ds if q['subject'] == subject) for subject in fs_subjects]

# System prompt followed by alternating user-question / assistant-answer turns.
base_prompt = [{'role': 'system', 'content': mmlu_system_prompt}]
for ex in fs_ex:
    base_prompt += [
        {'role': 'user', 'content': prep_question(ex['question'], ex['choices'])},
        {'role': 'assistant', 'content': 'The correct answer is ' + chr(65 + ex['answer'])},
    ]

print(tokenizer.apply_chat_template(base_prompt, tokenize = False, add_generation_prompt = False, continue_final_message = True))

## Save topk for later analysis

In [None]:
import pandas as pd
from utils.store_topk import convert_topk_to_df

# Evaluate (at most) this many questions per domain.
MAX_QUESTIONS_PER_DOMAIN = 200

cat_results = []

for this_domain in all_domain_questions:

    domain_questions = this_domain['questions'][:MAX_QUESTIONS_PER_DOMAIN]

    count_correct = 0
    count_incorrect = 0
    topk_dfs = []
    results = []
    # enumerate(tqdm(list)) rather than tqdm(enumerate(...)) so the progress bar knows its total.
    for question_ix, q in enumerate(tqdm(domain_questions, desc = this_domain['domain'])):

        input_prompt = tokenizer.apply_chat_template(
            base_prompt + [
                {'role': 'user', 'content': prep_question(q['question'], q['choices'])},
                {'role': 'assistant', 'content': 'The correct answer is'},
            ],
            tokenize = False, add_generation_prompt = False, continue_final_message = True # Otherwise appends eos token
        )
        # A single sequence tokenized without `padding=`, so it contains no padding.
        inputs = tokenizer(input_prompt, return_tensors = 'pt').to(main_device)

        with torch.no_grad():
            all_topk_experts, hook_handles = attach_moe_gate_hooks(model)
            try:
                outputs = model(input_ids = inputs["input_ids"], attention_mask = inputs["attention_mask"])
            finally:
                # Remove the hooks even if the forward pass raises, so they never accumulate on the model.
                for h in hook_handles.values():
                    h.remove()

            topk_df = convert_topk_to_df(all_topk_experts, inputs["input_ids"]).assign(domain = this_domain['domain'], question_ix = question_ix).drop(columns = 'sequence_ix')
            # Only drop pad-token rows when the sequence actually contains padding. Filtering on the
            # token id unconditionally would also discard genuine prompt tokens that share the pad id
            # (e.g. if the pad token is the EOS token the chat template appends after earlier turns — TODO confirm).
            if (inputs["attention_mask"] == 0).any():
                topk_df = topk_df[topk_df['token_id'] != tokenizer.pad_token_id]
            topk_dfs.append(topk_df)

            next_token_logits = outputs['logits'][0, -1, :]
            next_token_id = torch.argmax(next_token_logits).item()
            predicted_text = tokenizer.decode([next_token_id]).strip()

        # The prediction counts only if the decoded token *is* a choice letter, optionally wrapped in
        # punctuation ("A", "(A)", "A."). Taking the first letter-like character anywhere in the token
        # would credit ordinary words, e.g. "Based" -> "B".
        valid_letters = [chr(65 + i) for i in range(len(q['choices']))]
        candidate = predicted_text.strip('()[]{}.,:;*"\' ').upper()
        predicted_letter = candidate if candidate in valid_letters else None

        result = {
            'domain': this_domain['domain'],
            'question_ix': question_ix, 
            'model_output': predicted_text,
            'model_choice': predicted_letter,
            'correct_choice': q['answer_char'],
            'is_correct': 1 if predicted_letter == q['answer_char'] else 0
        }

        if result['is_correct'] == 1: count_correct += 1
        else: count_incorrect += 1

        results.append(result)

    cat_results.append({
        'answer_df': pd.DataFrame(results),
        # pd.concat raises on an empty list; keep an empty frame for a domain with no questions.
        'topks_df': pd.concat(topk_dfs) if topk_dfs else pd.DataFrame()
    })

    n_answered = count_correct + count_incorrect
    accuracy = (count_correct / n_answered) * 100 if n_answered else float('nan')
    print(f'{this_domain["domain"]} | Correct: {str(count_correct)} | Incorrect: {str(count_incorrect)} | Accuracy: {accuracy:.1f}%')

In [None]:
# token_id -> readable token string ('Ġ' in the raw vocab is rendered as a space).
vocab_map = (
    pd.DataFrame([{"token": token.replace('Ġ', ' '), "token_id": token_id} for token, token_id in tokenizer.get_vocab().items()])
    .sort_values(by = 'token_id')
    # drop=True: the pre-sort index is meaningless; without it it is carried along as an 'index'
    # column and ends up in all_topks (and its CSV) after the merge below.
    .reset_index(drop = True)
)

display(vocab_map)

# ignore_index: each per-domain frame restarts at 0, so give the combined frame a unique index.
all_answers = pd.concat([cat['answer_df'] for cat in cat_results], ignore_index = True)
display(all_answers)

all_topks = pd.concat([cat['topks_df'] for cat in cat_results], ignore_index = True).merge(vocab_map, how = 'left', on = 'token_id')
display(all_topks)

In [None]:
# Persist both result tables (without the pandas index) for offline analysis.
for frame, csv_path in ((all_answers, 'all_answers.csv'), (all_topks, 'all_topks.csv')):
    frame.to_csv(csv_path, index = False)

## Ablation prep

In [None]:
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

# Sanity check for the ablations below: re-implement the forward pass layer by layer and confirm it
# reproduces the model's own logits. Use an explicit, deterministic input (the few-shot prompt)
# rather than whatever `inputs` happened to be left over from the last evaluation-loop iteration.
check_prompt = tokenizer.apply_chat_template(base_prompt, tokenize = False, add_generation_prompt = False, continue_final_message = True)
check_inputs = tokenizer(check_prompt, return_tensors = 'pt').to(main_device)

with torch.no_grad():
    # 1) Full model forward (reference).
    reference_logits = model(input_ids = check_inputs["input_ids"], attention_mask = check_inputs['attention_mask'])['logits']
    print(reference_logits)

    # 2) Backbone + LM head.
    backbone_out = model.model(input_ids = check_inputs["input_ids"], attention_mask = check_inputs['attention_mask'])
    backbone_logits = model.lm_head(backbone_out[0])
    print(backbone_logits)

    # 3) Manual layer-by-layer forward (template for the ablation functions).
    B, N = check_inputs["input_ids"].shape[:2]
    position_ids = torch.arange(0, N, dtype = torch.long, device = main_device).unsqueeze(0)

    inputs_embeds = model.model.embed_tokens(check_inputs["input_ids"])
    attention_mask_4d = _prepare_4d_causal_attention_mask(check_inputs['attention_mask'], (B, N), inputs_embeds, 0)

    hidden_state = inputs_embeds
    for decoder_layer in model.model.layers:
        # Mirrors decoder_layer(hidden_state, attention_mask = ..., position_ids = ...)
        residual = hidden_state
        hidden_state = decoder_layer.input_layernorm(hidden_state)
        # Self attention
        hidden_state, _, _ = decoder_layer.self_attn(hidden_states = hidden_state, attention_mask = attention_mask_4d, position_ids = position_ids)
        hidden_state = residual + hidden_state
        # Fully connected
        residual = hidden_state
        hidden_state = decoder_layer.post_attention_layernorm(hidden_state)
        if 'DeepseekV2MLP' in str(type(decoder_layer.mlp)):
            # Dense MLP layer
            hidden_state = decoder_layer.mlp(hidden_state)
        else:
            mlp = decoder_layer.mlp
            identity = hidden_state
            orig_shape = hidden_state.shape
            ### MoE gate — originally topk_idx, topk_weight, aux_loss = mlp.gate(hidden_state)
            # NOTE(review): softmax -> greedy top-k -> scale by routed_scaling_factor, with no top-k
            # renormalisation; assumes the checkpoint's gate is configured that way. The comparison
            # printed at the end of this cell is the check.
            flat_hidden = hidden_state.view(-1, orig_shape[-1])  # [B*N, hidden]
            router_logits = torch.nn.functional.linear(flat_hidden.type(torch.float32), mlp.gate.weight.type(torch.float32), None)
            scores = router_logits.softmax(dim = -1, dtype = torch.float32)
            topk_weight, topk_idx = torch.topk(scores, k = mlp.gate.top_k, dim = -1, sorted = False)
            topk_weight = topk_weight * mlp.gate.routed_scaling_factor
            ### moe_infer: group token copies by expert id, run each expert once, scatter back.
            cnts = topk_idx.new_zeros((topk_idx.shape[0], len(mlp.experts)))
            cnts.scatter_(1, topk_idx, 1)
            tokens_per_expert = cnts.sum(dim = 0).cpu().numpy()
            idxs = topk_idx.view(-1).argsort()
            sorted_tokens = flat_hidden[idxs // topk_idx.shape[1]]
            expert_outputs = []
            start_idx = 0
            for i, num_tokens in enumerate(tokens_per_expert):
                end_idx = start_idx + num_tokens
                if num_tokens == 0:
                    continue
                expert = mlp.experts[i + mlp.ep_rank * mlp.experts_per_rank]
                expert_outputs.append(expert(sorted_tokens[start_idx:end_idx]))
                start_idx = end_idx
            outs = torch.cat(expert_outputs, dim = 0) if len(expert_outputs) else sorted_tokens.new_empty(0)
            new_x = torch.empty_like(outs)
            new_x[idxs] = outs
            final_out = new_x.view(*topk_idx.shape, -1).type(topk_weight.dtype).mul_(topk_weight.unsqueeze(dim = -1)).sum(dim = 1).type(new_x.dtype)
            y = final_out.view(*orig_shape)
            if mlp.config.n_shared_experts is not None:
                y = y + mlp.shared_experts(identity)
            hidden_state = y

        hidden_state = residual + hidden_state

    hidden_state = model.model.norm(hidden_state)
    manual_logits = model.lm_head(hidden_state)
    print(manual_logits)

# Compare against the reference in float32 (dtypes may differ between the paths).
for path_name, candidate_logits in [('backbone + lm_head', backbone_logits), ('manual layer loop', manual_logits)]:
    max_abs_diff = (candidate_logits.float() - reference_logits.float()).abs().max().item()
    same_next_token = bool(candidate_logits[0, -1].argmax() == reference_logits[0, -1].argmax())
    print(f'{path_name}: max |logit diff| = {max_abs_diff:.4g} | same next-token argmax: {same_next_token}')

In [None]:
def run_model_no_ablation(input_ids, attention_mask):
    """Baseline forward function for `evaluate_with_ablation`: the unmodified model.

    Args:
        input_ids: token ids from the tokenizer.
        attention_mask: the matching attention mask.

    Returns:
        The model's `logits` output.
    """
    with torch.no_grad():
        # Pass by keyword instead of relying on forward()'s positional order; no KV cache is needed
        # for a single scoring pass.
        return model(input_ids = input_ids, attention_mask = attention_mask, use_cache = False)['logits']

def evaluate_with_ablation(run_forward_fn, *args, max_questions_per_domain = 200, **kwargs):
    """Score MMLU questions in every domain using a custom forward function.

    For each question, builds the same few-shot chat prompt as the baseline evaluation, calls
    `run_forward_fn(input_ids, attention_mask, *args, **kwargs)`, and takes the greedy next token
    after "The correct answer is" as the model's answer.

    Args:
        run_forward_fn: callable returning logits indexable as [batch, position, vocab].
        *args, **kwargs: forwarded unchanged to `run_forward_fn` (e.g. ablation settings).
        max_questions_per_domain: evaluate at most this many questions per domain (default 200).

    Returns:
        A list with one dict per domain, `{'answer_df': DataFrame}` of per-question results.
        (Previously the results were built and then discarded.)
    """
    cat_results = []

    for this_domain in all_domain_questions:

        domain_questions = this_domain['questions'][:max_questions_per_domain]

        count_correct = 0
        count_incorrect = 0
        results = []
        # enumerate(tqdm(list)) so the progress bar knows its total.
        for question_ix, q in enumerate(tqdm(domain_questions, desc = this_domain['domain'])):

            input_prompt = tokenizer.apply_chat_template(
                base_prompt + [
                    {'role': 'user', 'content': prep_question(q['question'], q['choices'])},
                    {'role': 'assistant', 'content': 'The correct answer is'},
                ],
                tokenize = False, add_generation_prompt = False, continue_final_message = True # Otherwise appends eos token
            )
            inputs = tokenizer(input_prompt, return_tensors = 'pt').to(main_device)

            outputs = run_forward_fn(inputs['input_ids'], inputs['attention_mask'], *args, **kwargs)
            next_token_logits = outputs[0, -1, :]
            next_token_id = torch.argmax(next_token_logits).item()
            predicted_text = tokenizer.decode([next_token_id]).strip()

            # The prediction counts only if the decoded token *is* a choice letter, optionally wrapped
            # in punctuation ("A", "(A)", "A."); scanning for the first letter-like character would
            # credit ordinary words, e.g. "Based" -> "B".
            valid_letters = [chr(65 + i) for i in range(len(q['choices']))]
            candidate = predicted_text.strip('()[]{}.,:;*"\' ').upper()
            predicted_letter = candidate if candidate in valid_letters else None

            result = {
                'domain': this_domain['domain'],
                'question_ix': question_ix, 
                'model_output': predicted_text,
                'model_choice': predicted_letter,
                'correct_choice': q['answer_char'],
                'is_correct': 1 if predicted_letter == q['answer_char'] else 0
            }

            if result['is_correct'] == 1: count_correct += 1
            else: count_incorrect += 1

            results.append(result)

        cat_results.append({
            'answer_df': pd.DataFrame(results)
        })

        n_answered = count_correct + count_incorrect
        accuracy = (count_correct / n_answered) * 100 if n_answered else float('nan')
        print(f'{this_domain["domain"]} | Correct: {str(count_correct)} | Incorrect: {str(count_incorrect)} | Accuracy: {accuracy:.1f}%')

    return cat_results

# Keep the per-domain results so ablations can be compared against this baseline.
baseline_results = evaluate_with_ablation(run_model_no_ablation)

## Ablation 1 - ablate the top-1 expert (MoE layers 1-3 only, no renormalization)

In [None]:
def run_model_ablate_top1(input_ids, attention_mask, ablate_layers = range(1, 4)):
    """Forward pass that zeroes each token's highest routing weight in selected MoE layers.

    Re-implements the decoder stack layer by layer (as in the ablation-prep cell). In every MoE
    layer whose index is in `ablate_layers`, the largest of each token's top-k routing weights is
    set to 0, so that expert contributes nothing; the remaining weights are NOT renormalised.

    Args:
        input_ids, attention_mask: tokenizer outputs for one batch.
        ablate_layers: decoder-layer indices to ablate. Default: layers 1-3, as before; e.g.
            pass `range(len(model.model.layers))` (via `evaluate_with_ablation(..., ablate_layers=...)`)
            to ablate every MoE layer. Dense-MLP layers are never ablated.

    Returns:
        LM-head logits.
    """
    with torch.no_grad():
        B, N = input_ids.shape[:2]
        position_ids = torch.arange(0, N, dtype=torch.long, device = main_device).unsqueeze(0)

        inputs_embeds = model.model.embed_tokens(input_ids)
        attention_mask = _prepare_4d_causal_attention_mask(attention_mask, (B, N), inputs_embeds, 0,)

        hidden_state = inputs_embeds
        for layer_ix, decoder_layer in enumerate(model.model.layers):
            # Mirrors decoder_layer(hidden_state, attention_mask = attention_mask, position_ids = position_ids)
            residual = hidden_state
            hidden_state = decoder_layer.input_layernorm(hidden_state)
            # Self Attention
            hidden_state, _, _ = decoder_layer.self_attn(hidden_states = hidden_state, attention_mask = attention_mask, position_ids = position_ids)
            hidden_state = residual + hidden_state
            # Fully Connected
            residual = hidden_state
            hidden_state = decoder_layer.post_attention_layernorm(hidden_state)
            ## MLP
            if 'DeepseekV2MLP' in str(type(decoder_layer.mlp)):
                # Dense MLP layer: nothing to ablate.
                hidden_state = decoder_layer.mlp(hidden_state)
            else:
                identity = hidden_state
                orig_shape = hidden_state.shape
                ### moegate - originally topk_idx, topk_weight, aux_loss = decoder_layer.mlp.gate(hidden_state)
                x = hidden_state.view(-1, orig_shape[-1])  # [B*N, hidden]
                logits = torch.nn.functional.linear(x.type(torch.float32), decoder_layer.mlp.gate.weight.type(torch.float32), None)
                scores = logits.softmax(dim=-1, dtype=torch.float32)
                topk_weight, topk_idx = torch.topk(scores, k=decoder_layer.mlp.gate.top_k, dim=-1, sorted=False)
                topk_weight = topk_weight * decoder_layer.mlp.gate.routed_scaling_factor  # [B*N, top_k]
                ######################## ABLATE TOP-1
                if layer_ix in ablate_layers:
                    # Zero the single largest weight in every row (token). Vectorised equivalent of
                    # the former per-row Python loop, which issued B*N tiny GPU writes per layer.
                    top1_col = topk_weight.argmax(dim=-1, keepdim=True)
                    topk_weight = topk_weight.scatter(1, top1_col, 0.0)
                ######################## The rest is unchanged (moe_infer); the ablated expert is still
                ######################## run but its output is multiplied by 0.
                topk_ids = topk_idx
                cnts = topk_ids.new_zeros((topk_ids.shape[0], len(decoder_layer.mlp.experts)))
                cnts.scatter_(1, topk_ids, 1)
                tokens_per_expert = cnts.sum(dim=0).cpu().numpy()
                idxs = topk_ids.view(-1).argsort()
                sorted_tokens = x[idxs // topk_ids.shape[1]]
                expert_outputs = []
                start_idx = 0
                for i, num_tokens in enumerate(tokens_per_expert):
                    end_idx = start_idx + num_tokens
                    if num_tokens == 0:
                        continue
                    expert = decoder_layer.mlp.experts[i + decoder_layer.mlp.ep_rank * decoder_layer.mlp.experts_per_rank]
                    expert_outputs.append(expert(sorted_tokens[start_idx:end_idx]))
                    start_idx = end_idx
                outs = torch.cat(expert_outputs, dim=0) if len(expert_outputs) else sorted_tokens.new_empty(0)
                new_x = torch.empty_like(outs)
                new_x[idxs] = outs
                final_out = (new_x.view(*topk_ids.shape, -1).type(topk_weight.dtype).mul_(topk_weight.unsqueeze(dim=-1)).sum(dim=1).type(new_x.dtype))
                ###
                y = final_out.view(*orig_shape)
                if decoder_layer.mlp.config.n_shared_experts is not None:
                    y = y + decoder_layer.mlp.shared_experts(identity)
                hidden_state = y

            hidden_state = residual + hidden_state

        hidden_state = model.model.norm(hidden_state)
        logits = model.lm_head(hidden_state)
        return logits
    
# Keep the per-domain results for comparison with the baseline.
ablate_top1_results = evaluate_with_ablation(run_model_ablate_top1)

In [None]:
[*range(1, 3)]  # scratch check: layer indices covered by range(1, 3)

## Ablation 2 - ablate the top-1 expert and renormalize the remaining weights (all MoE layers)

In [None]:
def run_model_ablate_top1_renorm(input_ids, attention_mask, ablate_layers = None):
    """Forward pass that removes each token's top-1 expert and renormalises the rest.

    Re-implements the decoder stack layer by layer (as in the ablation-prep cell). In each ablated
    MoE layer, the largest of every token's top-k routing weights is set to 0 and the remaining
    weights are rescaled so the row keeps its original total weight.

    Args:
        input_ids, attention_mask: tokenizer outputs for one batch.
        ablate_layers: decoder-layer indices to ablate; None (default, as before) ablates every MoE layer.

    Returns:
        LM-head logits.
    """
    with torch.no_grad():
        B, N = input_ids.shape[:2]
        position_ids = torch.arange(0, N, dtype=torch.long, device = main_device).unsqueeze(0)

        inputs_embeds = model.model.embed_tokens(input_ids)
        attention_mask = _prepare_4d_causal_attention_mask(attention_mask, (B, N), inputs_embeds, 0,)

        hidden_state = inputs_embeds
        for layer_ix, decoder_layer in enumerate(model.model.layers):
            # Mirrors decoder_layer(hidden_state, attention_mask = attention_mask, position_ids = position_ids)
            residual = hidden_state
            hidden_state = decoder_layer.input_layernorm(hidden_state)
            # Self Attention
            hidden_state, _, _ = decoder_layer.self_attn(hidden_states = hidden_state, attention_mask = attention_mask, position_ids = position_ids)
            hidden_state = residual + hidden_state
            # Fully Connected
            residual = hidden_state
            hidden_state = decoder_layer.post_attention_layernorm(hidden_state)
            ## MLP
            if 'DeepseekV2MLP' in str(type(decoder_layer.mlp)):
                # Dense MLP layer: nothing to ablate.
                hidden_state = decoder_layer.mlp(hidden_state)
            else:
                identity = hidden_state
                orig_shape = hidden_state.shape
                ### moegate - originally topk_idx, topk_weight, aux_loss = decoder_layer.mlp.gate(hidden_state)
                x = hidden_state.view(-1, orig_shape[-1])  # [B*N, hidden]
                logits = torch.nn.functional.linear(x.type(torch.float32), decoder_layer.mlp.gate.weight.type(torch.float32), None)
                scores = logits.softmax(dim=-1, dtype=torch.float32)
                topk_weight, topk_idx = torch.topk(scores, k=decoder_layer.mlp.gate.top_k, dim=-1, sorted=False)
                topk_weight = topk_weight * decoder_layer.mlp.gate.routed_scaling_factor  # [B*N, top_k]
                ######################## ABLATE TOP-1, THEN RENORMALISE
                if ablate_layers is None or layer_ix in ablate_layers:
                    row_sum_before = topk_weight.sum(dim=-1, keepdim=True)  # [B*N, 1]
                    # Zero the largest weight in every row (vectorised; was a per-row Python loop).
                    top1_col = topk_weight.argmax(dim=-1, keepdim=True)
                    topk_weight = topk_weight.scatter(1, top1_col, 0.0)
                    row_sum_after = topk_weight.sum(dim=-1, keepdim=True)
                    # Rescale so each row's total weight is unchanged; the epsilon avoids 0/0 when top_k == 1.
                    topk_weight = topk_weight * (row_sum_before / (row_sum_after + 1e-9))
                ######################## The rest is unchanged (moe_infer)
                topk_ids = topk_idx
                cnts = topk_ids.new_zeros((topk_ids.shape[0], len(decoder_layer.mlp.experts)))
                cnts.scatter_(1, topk_ids, 1)
                tokens_per_expert = cnts.sum(dim=0).cpu().numpy()
                idxs = topk_ids.view(-1).argsort()
                sorted_tokens = x[idxs // topk_ids.shape[1]]
                expert_outputs = []
                start_idx = 0
                for i, num_tokens in enumerate(tokens_per_expert):
                    end_idx = start_idx + num_tokens
                    if num_tokens == 0:
                        continue
                    expert = decoder_layer.mlp.experts[i + decoder_layer.mlp.ep_rank * decoder_layer.mlp.experts_per_rank]
                    expert_outputs.append(expert(sorted_tokens[start_idx:end_idx]))
                    start_idx = end_idx
                outs = torch.cat(expert_outputs, dim=0) if len(expert_outputs) else sorted_tokens.new_empty(0)
                new_x = torch.empty_like(outs)
                new_x[idxs] = outs
                final_out = (new_x.view(*topk_ids.shape, -1).type(topk_weight.dtype).mul_(topk_weight.unsqueeze(dim=-1)).sum(dim=1).type(new_x.dtype))
                ###
                y = final_out.view(*orig_shape)
                if decoder_layer.mlp.config.n_shared_experts is not None:
                    y = y + decoder_layer.mlp.shared_experts(identity)
                hidden_state = y

            hidden_state = residual + hidden_state

        hidden_state = model.model.norm(hidden_state)
        logits = model.lm_head(hidden_state)
        
        return logits

# Keep the per-domain results for comparison with the baseline.
ablate_top1_renorm_results = evaluate_with_ablation(run_model_ablate_top1_renorm)

## Ablation 3 - keep only the top-1 expert (zero all other routed experts; weights not renormalized)

In [None]:
def run_model_ablate_only_top1_renorm(input_ids, attention_mask, renormalize = False, ablate_layers = None):
    """Forward pass that keeps only each token's top-1 routed expert.

    Re-implements the decoder stack layer by layer (as in the ablation-prep cell). In each ablated
    MoE layer, every top-k routing weight except the largest one per token is set to 0 (shared
    experts are untouched).

    Despite the function name, renormalisation was disabled (commented out) in the original
    experiment; `renormalize=False` keeps that behaviour. With `renormalize=True` the surviving
    weight is rescaled to the row's original total.

    Args:
        input_ids, attention_mask: tokenizer outputs for one batch.
        renormalize: rescale the kept weight so each row's total weight is unchanged.
        ablate_layers: decoder-layer indices to ablate; None (default, as before) ablates every MoE layer.

    Returns:
        LM-head logits.
    """
    with torch.no_grad():
        B, N = input_ids.shape[:2]
        position_ids = torch.arange(0, N, dtype=torch.long, device = main_device).unsqueeze(0)

        inputs_embeds = model.model.embed_tokens(input_ids)
        attention_mask = _prepare_4d_causal_attention_mask(attention_mask, (B, N), inputs_embeds, 0,)

        hidden_state = inputs_embeds
        for layer_ix, decoder_layer in enumerate(model.model.layers):
            # Mirrors decoder_layer(hidden_state, attention_mask = attention_mask, position_ids = position_ids)
            residual = hidden_state
            hidden_state = decoder_layer.input_layernorm(hidden_state)
            # Self Attention
            hidden_state, _, _ = decoder_layer.self_attn(hidden_states = hidden_state, attention_mask = attention_mask, position_ids = position_ids)
            hidden_state = residual + hidden_state
            # Fully Connected
            residual = hidden_state
            hidden_state = decoder_layer.post_attention_layernorm(hidden_state)
            ## MLP
            if 'DeepseekV2MLP' in str(type(decoder_layer.mlp)):
                # Dense MLP layer: nothing to ablate.
                hidden_state = decoder_layer.mlp(hidden_state)
            else:
                identity = hidden_state
                orig_shape = hidden_state.shape
                ### moegate - originally topk_idx, topk_weight, aux_loss = decoder_layer.mlp.gate(hidden_state)
                x = hidden_state.view(-1, orig_shape[-1])  # [B*N, hidden]
                logits = torch.nn.functional.linear(x.type(torch.float32), decoder_layer.mlp.gate.weight.type(torch.float32), None)
                scores = logits.softmax(dim=-1, dtype=torch.float32)
                topk_weight, topk_idx = torch.topk(scores, k=decoder_layer.mlp.gate.top_k, dim=-1, sorted=False)
                topk_weight = topk_weight * decoder_layer.mlp.gate.routed_scaling_factor  # [B*N, top_k]
                ######################## ABLATE ALL EXCEPT TOP-1
                if ablate_layers is None or layer_ix in ablate_layers:
                    row_sum_before = topk_weight.sum(dim=-1, keepdim=True)  # sum of all k
                    # Keep only the largest weight in each row, zero the others (vectorised; was a
                    # per-row Python loop).
                    top1_col = topk_weight.argmax(dim=-1, keepdim=True)
                    topk_weight = torch.zeros_like(topk_weight).scatter(1, top1_col, topk_weight.gather(1, top1_col))
                    if renormalize:
                        row_sum_after = topk_weight.sum(dim=-1, keepdim=True)
                        topk_weight = topk_weight * (row_sum_before / (row_sum_after + 1e-9))
                ######################## The rest is unchanged (moe_infer)
                topk_ids = topk_idx
                cnts = topk_ids.new_zeros((topk_ids.shape[0], len(decoder_layer.mlp.experts)))
                cnts.scatter_(1, topk_ids, 1)
                tokens_per_expert = cnts.sum(dim=0).cpu().numpy()
                idxs = topk_ids.view(-1).argsort()
                sorted_tokens = x[idxs // topk_ids.shape[1]]
                expert_outputs = []
                start_idx = 0
                for i, num_tokens in enumerate(tokens_per_expert):
                    end_idx = start_idx + num_tokens
                    if num_tokens == 0:
                        continue
                    expert = decoder_layer.mlp.experts[i + decoder_layer.mlp.ep_rank * decoder_layer.mlp.experts_per_rank]
                    expert_outputs.append(expert(sorted_tokens[start_idx:end_idx]))
                    start_idx = end_idx
                outs = torch.cat(expert_outputs, dim=0) if len(expert_outputs) else sorted_tokens.new_empty(0)
                new_x = torch.empty_like(outs)
                new_x[idxs] = outs
                final_out = (new_x.view(*topk_ids.shape, -1).type(topk_weight.dtype).mul_(topk_weight.unsqueeze(dim=-1)).sum(dim=1).type(new_x.dtype))
                ###
                y = final_out.view(*orig_shape)
                if decoder_layer.mlp.config.n_shared_experts is not None:
                    y = y + decoder_layer.mlp.shared_experts(identity)
                hidden_state = y

            hidden_state = residual + hidden_state

        hidden_state = model.model.norm(hidden_state)
        logits = model.lm_head(hidden_state)
        
        return logits

# Keep the per-domain results for comparison with the baseline.
ablate_only_top1_results = evaluate_with_ablation(run_model_ablate_only_top1_renorm)

## Ablate top1, last expert only (TODO: not yet implemented)

## Check topk expert distribution

In [None]:
"""
First, let's check the distributions of the topks so we can see why topk=1 matters so much
"""
def run_model_return_topk_weights(input_ids, attention_mask):
    with torch.no_grad():
        B, N = input_ids.shape[:2]
        position_ids = torch.arange(0, N, dtype=torch.long, device = main_device).unsqueeze(0)

        inputs_embeds = model.model.embed_tokens(input_ids)
        attention_mask = _prepare_4d_causal_attention_mask(attention_mask, (B, N), inputs_embeds, 0,)

        hidden_state = inputs_embeds
        topk_weights = []
        for decoder_layer in model.model.layers:
            # layer_outputs = decoder_layer(hidden_state, attention_mask = attention_mask, position_ids = position_ids,)
            residual = hidden_state
            hidden_state = decoder_layer.input_layernorm(hidden_state)
            # Self Attention
            hidden_state, self_attn_weights, present_key_value = decoder_layer.self_attn(hidden_states = hidden_state, attention_mask = attention_mask, position_ids = position_ids)
            hidden_state = residual + hidden_state
            # Fully Connected
            residual = hidden_state
            hidden_state = decoder_layer.post_attention_layernorm(hidden_state)
            ## MLP
            if 'DeepseekV2MLP' in str(type(decoder_layer.mlp)):
                hidden_state = decoder_layer.mlp(hidden_state)
            else:
                identity = hidden_state
                orig_shape = hidden_state.shape
                ### moegate - originally topk_idx, topk_weight, aux_loss = decoder_layer.mlp.gate(hidden_state)
                bsz, seq_len, h = hidden_state.shape
                moe_hidden_state = hidden_state.view(-1, h)
                logits = torch.nn.functional.linear(moe_hidden_state.type(torch.float32), decoder_layer.mlp.gate.weight.type(torch.float32), None)
                scores = logits.softmax(dim=-1, dtype=torch.float32)
                topk_weight, topk_idx = torch.topk(scores, k=decoder_layer.mlp.gate.top_k, dim=-1, sorted=True)
                topk_weight = topk_weight * decoder_layer.mlp.gate.routed_scaling_factor
                hidden_state = hidden_state.view(-1, hidden_state.shape[-1])
                flat_topk_idx = topk_idx.view(-1)
                ### moe infer
                x = hidden_state
                topk_ids = topk_idx
                cnts = topk_ids.new_zeros((topk_ids.shape[0], len(decoder_layer.mlp.experts)))
                cnts.scatter_(1, topk_ids, 1)
                tokens_per_expert = cnts.sum(dim=0)
                idxs = topk_ids.view(-1).argsort()
                sorted_tokens = x[idxs // topk_ids.shape[1]]
                tokens_per_expert = tokens_per_expert.cpu().numpy()
                outputs = []
                start_idx = 0
                for i, num_tokens in enumerate(tokens_per_expert):
                    end_idx = start_idx + num_tokens
                    if num_tokens == 0:
                        continue
                    expert = decoder_layer.mlp.experts[i + decoder_layer.mlp.ep_rank * decoder_layer.mlp.experts_per_rank]
                    tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
                    expert_out = expert(tokens_for_this_expert)
                    outputs.append(expert_out)
                    start_idx = end_idx
                outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
                new_x = torch.empty_like(outs)
                new_x[idxs] = outs
                final_out = (new_x.view(*topk_ids.shape, -1).type(topk_weight.dtype).mul_(topk_weight.unsqueeze(dim=-1)).sum(dim=1).type(new_x.dtype))
                ###
                y = final_out.view(*orig_shape)
                if decoder_layer.mlp.config.n_shared_experts is not None:
                    y = y + decoder_layer.mlp.shared_experts(identity)
                hidden_state = y

                topk_weights.append(topk_weight)

            hidden_state = residual + hidden_state

        hidden_state = model.model.norm(hidden_state)
        logits = model.lm_head(hidden_state)
        return logits, topk_weights # topk_weights is BN x k

def _topk_probs_to_long_df(np_probs):
    """
    Flatten a 3D (layer, token, topk) array of routing probabilities into a long DataFrame.

    Returns one row per cell in C order (layer outermost, topk innermost) with integer columns
    'layer_ix', 'token_ix', 'topk_ix' and the value in 'prob'.
    """
    n_layers, n_tokens, n_topk = np_probs.shape
    return (
        pd.MultiIndex.from_product(
            [range(n_layers), range(n_tokens), range(n_topk)],
            names = ['layer_ix', 'token_ix', 'topk_ix']
        )
        .to_frame(index = False)
        .assign(prob = np_probs.reshape(-1))
    )

def get_topk_weights():
    """
    Answer up to 200 multiple-choice questions per domain and record every MoE layer's routing weights.

    Each question is formatted with `base_prompt` + the chat template and run through
    `run_model_return_topk_weights`. The prediction is the first valid choice letter in the decoded
    argmax next token. Per-domain accuracy is printed.

    Returns:
        A list with one dict per domain:
            'answer_df': DataFrame of per-question predictions and correctness
            'topk_weights': long DataFrame (layer_ix, token_ix, topk_ix, prob, domain, question_ix) of each
                token's top-k routing weights, renormalized to sum to 1; layer_ix counts MoE layers only
    """
    cat_results = []

    for this_domain in all_domain_questions:

        domain_questions = this_domain['questions']

        results = []
        topk_weight_dfs = []
        # NOTE(review): evaluate_with_ablation uses domain_questions[3:203]; if the first questions are few-shot
        # examples inside base_prompt, this slice includes them — confirm which range is intended.
        for question_ix, q in tqdm(enumerate(domain_questions[0:200])):

            input_prompt = tokenizer.apply_chat_template(
                base_prompt + [
                    {'role': 'user', 'content': prep_question(q['question'], q['choices'])},
                    {'role': 'assistant', 'content': 'The correct answer is'},
                ],
                tokenize = False, add_generation_prompt = False, continue_final_message = True # Otherwise appends eos token
            )
            inputs = tokenizer(input_prompt, return_tensors = 'pt').to(main_device)

            logits, topk_weights = run_model_return_topk_weights(inputs['input_ids'], inputs['attention_mask'])
            next_token_id = torch.argmax(logits[0, -1, :]).item()
            predicted_text = tokenizer.decode([next_token_id]).strip()

            # First character of the decoded token that is a valid choice letter (A, B, C, ...)
            valid_letters = [chr(65 + i) for i in range(len(q['choices']))]
            predicted_letter = next((c.upper() for c in predicted_text if c.upper() in valid_letters), None)

            results.append({
                'domain': this_domain['domain'],
                'question_ix': question_ix, 
                'model_output': predicted_text,
                'model_choice': predicted_letter,
                'correct_choice': q['answer_char'],
                'is_correct': 1 if predicted_letter == q['answer_char'] else 0
            })

            ### Renormalize each token's top-k weights to sum to 1, then flatten to long format ###
            stacked_topk_weights = torch.stack(topk_weights)  # presumably (n_moe_layers, B*N, k)
            np_probs = (stacked_topk_weights / stacked_topk_weights.sum(dim = -1, keepdim = True)).cpu().numpy()
            topk_weight_dfs.append(
                _topk_probs_to_long_df(np_probs).assign(domain = this_domain['domain'], question_ix = question_ix)
            )

        cat_results.append({
            'answer_df': pd.DataFrame(results),
            # ignore_index: each per-question frame starts at 0, so keep a unique index; guard the empty case
            'topk_weights': pd.concat(topk_weight_dfs, ignore_index = True) if topk_weight_dfs else pd.DataFrame()
        })

        count_correct = sum(r['is_correct'] for r in results)
        count_incorrect = len(results) - count_correct
        # nan instead of ZeroDivisionError when a domain has no questions in the slice
        accuracy = (count_correct / len(results)) * 100 if results else float('nan')
        print(f'{this_domain["domain"]} | Correct: {str(count_correct)} | Incorrect: {str(count_incorrect)} | Accuracy: {accuracy:.1f}%')
    return cat_results

topk_res = get_topk_weights()

In [None]:
# Long-format routing probabilities for every domain and question.
# ignore_index gives a fresh unique index; the per-domain frames would otherwise repeat labels.
layer_x_token_x_topk = pd.concat([x['topk_weights'] for x in topk_res], ignore_index = True)
layer_x_token_x_topk

In [None]:
# Small reproducible peek at the first MoE layer.
# layer_ix counts MoE layers only, so 0 is the first MoE decoder layer.
# Capped at the row count because sample() raises if n exceeds it.
token_x_topk = (
    layer_x_token_x_topk[layer_x_token_x_topk['layer_ix'] == 0]
    .pipe(lambda df: df.sample(n = min(100, len(df)), random_state = seed))
)
token_x_topk

In [None]:
import plotly.express as px

# Reproducible sample of up to 10k rows from the first MoE layer.
# layer_ix counts MoE layers only, so 0 is the first MoE decoder layer.
# sample() raises if n exceeds the available rows, hence the min().
token_x_topk = (
    layer_x_token_x_topk[layer_x_token_x_topk['layer_ix'] == 0]
    .pipe(lambda df: df.sample(n = min(10000, len(df)), random_state = seed))
)
display(token_x_topk)

fig_violin = px.violin(
    token_x_topk,
    x="topk_ix",
    y="prob",
    color="topk_ix",
    box=True,  # Include box plot inside violin
    points="all",  # Show all points
    title="Probability Distribution by Top-K Index"
)

fig_violin.show()

In [None]:
import plotly.express as px  # used below for the color palette; don't rely on an earlier cell's import
import plotly.graph_objects as go
import numpy as np
from scipy import stats

def _hex_to_rgba(hex_color, alpha):
    """Convert a '#RRGGBB' color string to an 'rgba(r, g, b, alpha)' string for translucent fills."""
    r, g, b = (int(hex_color[pos:pos + 2], 16) for pos in (1, 3, 5))
    return f'rgba({r}, {g}, {b}, {alpha})'

# Get unique topk_ix values
topk_indices = sorted(token_x_topk["topk_ix"].unique())

# Create figure
fig = go.Figure()

# Plotly's qualitative palette; cycled so more top-k ranks than colors does not raise IndexError
palette = px.colors.qualitative.Plotly

# Add a KDE plot for each topk_ix
for i, topk_ix in enumerate(topk_indices):
    values = token_x_topk.loc[token_x_topk["topk_ix"] == topk_ix, "prob"].values

    # gaussian_kde needs >= 2 points and non-zero variance (a singular covariance raises LinAlgError)
    if len(values) < 2 or np.ptp(values) == 0:
        continue

    # Shrink the default (Scott's rule) bandwidth factor to 10% so sharp peaks are not smoothed away
    bandwidth = 0.1 * stats.gaussian_kde(values).factor
    kde = stats.gaussian_kde(values, bw_method=bandwidth)

    kde_x = np.linspace(min(values), max(values), 1000)
    kde_y = kde(kde_x)

    color = palette[i % len(palette)]
    fig.add_trace(go.Scatter(
        x=kde_x,
        y=kde_y,
        mode='lines',
        name=f'Top-K {topk_ix}',
        line=dict(width=2, color=color),
        fill='tozeroy',  # Fill to the x-axis
        fillcolor=_hex_to_rgba(color, 0.3)  # Translucent fill
    ))

# Update layout
fig.update_layout(
    title="Probability Density Distribution by Top-K Index",
    xaxis_title="Probability",
    yaxis_title="Density",
    legend_title="Top-K Index",
    xaxis=dict(
        showgrid=True,
        zeroline=True,
    ),
    yaxis=dict(
        showgrid=True,
        zeroline=True,
    )
)

fig.show()

## Ablation 4 - ablate the lower-ranked top-k experts

## Ablation 5 -- ablate lower topks

In [None]:
# Mean routing weight across the top-k slots of the first row of `topk_weight`
torch.mean(topk_weight[0, :])

In [None]:
# Flatten one set of per-MoE-layer routing weights into a long DataFrame.
# NOTE(review): relies on a `topk_weights` list already in the kernel from an earlier cell.
stacked_topk_weights = torch.stack(topk_weights)  # presumably (n_moe_layers, B*N, k)
# Renormalize each token's top-k weights to sum to 1 (stack computed once, not twice)
np_probs = (stacked_topk_weights / stacked_topk_weights.sum(dim = -1, keepdim = True)).cpu().numpy()

# Vectorized replacement for the triple Python loop: one row per (layer_ix, token_ix, topk_ix), C order
n_layers, n_tokens, n_topk = np_probs.shape
layer_x_token_x_topk = (
    pd.MultiIndex.from_product(
        [range(n_layers), range(n_tokens), range(n_topk)],
        names = ['layer_ix', 'token_ix', 'topk_ix']
    )
    .to_frame(index = False)
    .assign(prob = np_probs.reshape(-1))
)
token_x_topk = layer_x_token_x_topk[layer_x_token_x_topk['layer_ix'] == 0]

In [None]:
# Violin plot of `token_x_topk` probabilities split by top-k rank, with an inner box plot and every point overlaid
fig_violin = px.violin(
    data_frame = token_x_topk,
    x = "topk_ix",
    y = "prob",
    color = "topk_ix",
    box = True,
    points = "all",
    title = "Probability Distribution by Top-K Index",
)
fig_violin.show()

In [None]:
# Last collected layer's top-k weights, renormalized so each row sums to 1
torch.div(topk_weights[-1], topk_weights[-1].sum(dim = -1, keepdim = True))

In [None]:
# `topk_weight` renormalized so each row sums to 1
torch.div(topk_weight, topk_weight.sum(dim = -1, keepdim = True))

In [None]:
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

def run_model_with_ablation(input_ids, attention_mask, experts_to_ablate: list[int]):
    """
    Manually re-run the model's forward pass, zeroing the output of one routed expert per MoE layer.

    Tokens are still routed to the ablated expert and its routing weight is kept, but its contribution is
    replaced with zeros. All other routed experts and the shared experts run normally.

    Params:
        @input_ids: token ids; the first two dims are read as (B, N)
        @attention_mask: padding mask expanded to a 4D causal mask below (assumed (B, N) — confirm)
        @experts_to_ablate: one expert ID per MoE layer, in layer order; element i is the expert ablated in the i-th MoE layer

    Returns:
        logits from model.lm_head
    """
    with torch.no_grad():
        B, N = input_ids.shape[:2]
        position_ids = torch.arange(0, N, dtype=torch.long, device = main_device)
        position_ids = position_ids.unsqueeze(0)

        inputs_embeds = model.model.embed_tokens(input_ids)
        attention_mask = _prepare_4d_causal_attention_mask(attention_mask, (B, N), inputs_embeds, 0,)

        hidden_state = inputs_embeds
        moe_layer_ix = 0
        for decoder_layer in model.model.layers:
            residual = hidden_state
            hidden_state = decoder_layer.input_layernorm(hidden_state)
            # Self Attention
            hidden_state, self_attn_weights, present_key_value = decoder_layer.self_attn(hidden_states = hidden_state, attention_mask = attention_mask, position_ids = position_ids)
            hidden_state = residual + hidden_state
            # Fully Connected
            residual = hidden_state
            hidden_state = decoder_layer.post_attention_layernorm(hidden_state)
            ## MLP
            if 'DeepseekV2MLP' in str(type(decoder_layer.mlp)):
                # Dense MLP layer: nothing to ablate
                hidden_state = decoder_layer.mlp(hidden_state)
            else:
                expert_to_ablate = experts_to_ablate[moe_layer_ix]

                identity = hidden_state
                orig_shape = hidden_state.shape
                ### moegate, re-implemented inline (originally topk_idx, topk_weight, aux_loss = decoder_layer.mlp.gate(hidden_state))
                h = hidden_state.shape[-1]
                moe_hidden_state = hidden_state.view(-1, h)
                logits = torch.nn.functional.linear(moe_hidden_state.type(torch.float32), decoder_layer.mlp.gate.weight.type(torch.float32), None)
                scores = logits.softmax(dim=-1, dtype=torch.float32)
                topk_weight, topk_idx = torch.topk(scores, k=decoder_layer.mlp.gate.top_k, dim=-1, sorted=False)
                topk_weight = topk_weight * decoder_layer.mlp.gate.routed_scaling_factor
                ### moe infer
                x = moe_hidden_state
                topk_ids = topk_idx
                cnts = topk_ids.new_zeros((topk_ids.shape[0], len(decoder_layer.mlp.experts)))
                cnts.scatter_(1, topk_ids, 1)
                tokens_per_expert = cnts.sum(dim=0)
                idxs = topk_ids.view(-1).argsort()
                sorted_tokens = x[idxs // topk_ids.shape[1]]
                tokens_per_expert = tokens_per_expert.cpu().numpy()
                outputs = []
                start_idx = 0
                for i, num_tokens in enumerate(tokens_per_expert):
                    end_idx = start_idx + num_tokens
                    if num_tokens == 0:
                        continue
                    # Global expert index (experts may be split across ranks via ep_rank / experts_per_rank)
                    expert_idx = i + decoder_layer.mlp.ep_rank * decoder_layer.mlp.experts_per_rank
                    tokens_for_this_expert = sorted_tokens[start_idx:end_idx]

                    # Each expert is evaluated at most once; the ablated one is never run
                    if expert_idx == expert_to_ablate:
                        # ABLATE THIS EXPERT: contribute zeros instead of its output
                        expert_out = torch.zeros_like(tokens_for_this_expert)
                    else:
                        # Run the normal expert forward
                        expert_out = decoder_layer.mlp.experts[expert_idx](tokens_for_this_expert)

                    outputs.append(expert_out)
                    start_idx = end_idx
                outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
                new_x = torch.empty_like(outs)
                new_x[idxs] = outs
                final_out = (new_x.view(*topk_ids.shape, -1).type(topk_weight.dtype).mul_(topk_weight.unsqueeze(dim=-1)).sum(dim=1).type(new_x.dtype))
                ###
                y = final_out.view(*orig_shape)
                if decoder_layer.mlp.config.n_shared_experts is not None:
                    y = y + decoder_layer.mlp.shared_experts(identity)
                hidden_state = y

                moe_layer_ix += 1

            hidden_state = residual + hidden_state

        hidden_state = model.model.norm(hidden_state)
        logits = model.lm_head(hidden_state)
        return logits
    
run_model_with_ablation(inputs["input_ids"], inputs['attention_mask'], experts_to_ablate = [7] * 26)

In [None]:
# Does the config request renormalizing the top-k weights? The inline gate re-implementations here only apply routed_scaling_factor
getattr(model.model.config, 'norm_topk_prob')

In [None]:
def evaluate_with_ablation(run_forward_fn, *args, **kwargs):
    """
    Score a (possibly ablated) forward function on multiple-choice questions from every domain.

    For each domain in `all_domain_questions`:
    - questions [3:203] are formatted with `base_prompt` and the chat template;
    - each is tokenized on its own and passed to `run_forward_fn(input_ids, attention_mask, *args, **kwargs)`;
    - the prediction is the first valid choice letter in the decoded argmax next token.
    Per-domain accuracy is printed.

    Params:
        @run_forward_fn: callable taking (input_ids, attention_mask, ...) and returning logits indexed [batch, position, vocab]
        @args, @kwargs: extra arguments forwarded after the per-question inputs (e.g. experts_to_ablate = [...]).
            For backward compatibility, if the first two positional args are tensors they are discarded.
            Older calls passed a fixed batch such as `inputs["input_ids"], inputs['attention_mask']`.
            Each question's own tokenized prompt is used instead.

    Returns:
        A list with one dict per domain: {'answer_df': DataFrame of per-question results}.
    """
    # Previously the fixed tensors in *args were what actually ran, so every question evaluated the same input
    # and the per-question `inputs` below were never used. Drop such leading tensors.
    if len(args) >= 2 and torch.is_tensor(args[0]) and torch.is_tensor(args[1]):
        args = args[2:]

    cat_results = []

    for this_domain in all_domain_questions:

        domain_questions = this_domain['questions']

        results = []
        # NOTE(review): starts at question 3 (get_topk_weights starts at 0). Presumably the first questions are
        # few-shot examples in base_prompt — confirm. question_ix counts from 0 within this slice.
        for question_ix, q in tqdm(enumerate(domain_questions[3:203])):

            input_prompt = tokenizer.apply_chat_template(
                base_prompt + [
                    {'role': 'user', 'content': prep_question(q['question'], q['choices'])},
                    {'role': 'assistant', 'content': 'The correct answer is'},
                ],
                tokenize = False, add_generation_prompt = False, continue_final_message = True # Otherwise appends eos token
            )
            inputs = tokenizer(input_prompt, return_tensors = 'pt').to(main_device)

            # Run on THIS question's tokens; extra args (e.g. experts_to_ablate) pass through unchanged
            outputs = run_forward_fn(inputs['input_ids'], inputs['attention_mask'], *args, **kwargs)
            next_token_id = torch.argmax(outputs[0, -1, :]).item()
            predicted_text = tokenizer.decode([next_token_id]).strip()

            # First character of the decoded token that is a valid choice letter (A, B, C, ...)
            valid_letters = [chr(65 + i) for i in range(len(q['choices']))]
            predicted_letter = next((c.upper() for c in predicted_text if c.upper() in valid_letters), None)

            results.append({
                'domain': this_domain['domain'],
                'question_ix': question_ix, 
                'model_output': predicted_text,
                'model_choice': predicted_letter,
                'correct_choice': q['answer_char'],
                'is_correct': 1 if predicted_letter == q['answer_char'] else 0
            })

        cat_results.append({
            'answer_df': pd.DataFrame(results)
        })

        count_correct = sum(r['is_correct'] for r in results)
        count_incorrect = len(results) - count_correct
        # nan instead of ZeroDivisionError when a domain has no questions in the slice
        accuracy = (count_correct / len(results)) * 100 if results else float('nan')
        print(f'{this_domain["domain"]} | Correct: {str(count_correct)} | Incorrect: {str(count_incorrect)} | Accuracy: {accuracy:.1f}%')

    return cat_results

ablation_res = evaluate_with_ablation(run_model_with_ablation, experts_to_ablate = [7] * 26)

## Ablation (method 2 - ablate top1 expert by zeroing it)

In [None]:
def run_model_with_top1_zero(input_ids, attention_mask):
    """
    Manually re-run the model's forward pass, zeroing each token's highest-weighted routed expert in every MoE layer.

    The other k-1 routed experts and the shared experts are unchanged. The remaining routing weights are
    NOT renormalized.

    Params:
        @input_ids: token ids; the first two dims are read as (B, N)
        @attention_mask: padding mask expanded to a 4D causal mask below (assumed (B, N) — confirm)

    Returns:
        logits from model.lm_head
    """
    with torch.no_grad():
        B, N = input_ids.shape[:2]
        position_ids = torch.arange(0, N, dtype=torch.long, device = main_device)
        position_ids = position_ids.unsqueeze(0)

        inputs_embeds = model.model.embed_tokens(input_ids)
        attention_mask = _prepare_4d_causal_attention_mask(attention_mask, (B, N), inputs_embeds, 0,)

        hidden_state = inputs_embeds
        for decoder_layer in model.model.layers:
            residual = hidden_state
            hidden_state = decoder_layer.input_layernorm(hidden_state)
            # Self Attention
            hidden_state, self_attn_weights, present_key_value = decoder_layer.self_attn(hidden_states = hidden_state, attention_mask = attention_mask, position_ids = position_ids)
            hidden_state = residual + hidden_state
            # Fully Connected
            residual = hidden_state
            hidden_state = decoder_layer.post_attention_layernorm(hidden_state)

            ## MLP
            if 'DeepseekV2MLP' in str(type(decoder_layer.mlp)):
                # Dense MLP layer: no routing to ablate
                hidden_state = decoder_layer.mlp(hidden_state)
            else:
                identity = hidden_state
                orig_shape = hidden_state.shape
                topk_idx, topk_weight, _ = decoder_layer.mlp.gate(hidden_state)
                #####################################
                # Zero the top-1 slot for each token (topk_weight is (BN, k)).
                # Find it with argmax instead of assuming column 0: the gate's topk (replicated elsewhere with
                # sorted=False) does not guarantee its k experts come back ordered by weight.
                top1_slot = topk_weight.argmax(dim = -1, keepdim = True)
                topk_weight = topk_weight.scatter(1, top1_slot, 0.0)
                #####################################
                ### moe infer
                x = hidden_state.view(-1, hidden_state.shape[-1])
                topk_ids = topk_idx
                cnts = topk_ids.new_zeros((topk_ids.shape[0], len(decoder_layer.mlp.experts)))
                cnts.scatter_(1, topk_ids, 1)
                tokens_per_expert = cnts.sum(dim=0)
                idxs = topk_ids.view(-1).argsort()
                sorted_tokens = x[idxs // topk_ids.shape[1]]
                tokens_per_expert = tokens_per_expert.cpu().numpy()
                outputs = []
                start_idx = 0
                for i, num_tokens in enumerate(tokens_per_expert):
                    end_idx = start_idx + num_tokens
                    if num_tokens == 0:
                        continue
                    expert = decoder_layer.mlp.experts[i + decoder_layer.mlp.ep_rank * decoder_layer.mlp.experts_per_rank]
                    tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
                    expert_out = expert(tokens_for_this_expert)
                    outputs.append(expert_out)
                    start_idx = end_idx
                outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
                new_x = torch.empty_like(outs)
                new_x[idxs] = outs
                final_out = (new_x.view(*topk_ids.shape, -1).type(topk_weight.dtype).mul_(topk_weight.unsqueeze(dim=-1)).sum(dim=1).type(new_x.dtype))
                ###
                y = final_out.view(*orig_shape)
                if decoder_layer.mlp.config.n_shared_experts is not None:
                    y = y + decoder_layer.mlp.shared_experts(identity)
                hidden_state = y

            hidden_state = residual + hidden_state

        hidden_state = model.model.norm(hidden_state)
        logits = model.lm_head(hidden_state)
        return logits

top1_zero_res = evaluate_with_ablation(run_model_with_top1_zero, inputs["input_ids"], inputs['attention_mask'])

## Ablation 3: ablate the top-1 expert and shift in the (k+1)-th ranked expert

In [None]:
def run_model_with_top1_shift(input_ids, attention_mask):
    """
    Manually re-run the model's forward pass with each token's top-1 routed expert swapped out in every MoE layer.

    In every MoE layer:
    - the gate is queried for k+1 experts;
    - the slots are ranked by routing weight and the highest-weighted slot is dropped;
    - the remaining k (expert id, weight) pairs are used as-is, with no renormalization.
    So each token still uses k routed experts, the (k+1)-th taking the top-1's place. The gate's `top_k` is
    restored afterwards so later forward passes are unaffected.

    Params:
        @input_ids: token ids; the first two dims are read as (B, N)
        @attention_mask: padding mask expanded to a 4D causal mask below (assumed (B, N) — confirm)

    Returns:
        logits from model.lm_head
    """
    with torch.no_grad():
        B, N = input_ids.shape[:2]
        position_ids = torch.arange(0, N, dtype=torch.long, device = main_device)
        position_ids = position_ids.unsqueeze(0)

        inputs_embeds = model.model.embed_tokens(input_ids)
        attention_mask = _prepare_4d_causal_attention_mask(attention_mask, (B, N), inputs_embeds, 0,)

        hidden_state = inputs_embeds
        for decoder_layer in model.model.layers:
            residual = hidden_state
            hidden_state = decoder_layer.input_layernorm(hidden_state)
            # Self Attention
            hidden_state, self_attn_weights, present_key_value = decoder_layer.self_attn(hidden_states = hidden_state, attention_mask = attention_mask, position_ids = position_ids)
            hidden_state = residual + hidden_state
            # Fully Connected
            residual = hidden_state
            hidden_state = decoder_layer.post_attention_layernorm(hidden_state)

            ## MLP
            if 'DeepseekV2MLP' in str(type(decoder_layer.mlp)):
                # Dense MLP layer: no routing to ablate
                hidden_state = decoder_layer.mlp(hidden_state)
            else:
                identity = hidden_state
                orig_shape = hidden_state.shape
                #####################################
                k = decoder_layer.mlp.num_experts_per_tok
                gate = decoder_layer.mlp.gate
                # Temporarily ask the gate for k+1 experts. Restore top_k (even on error); otherwise the model's
                # normal forward and every later ablation would silently keep routing to k+1 experts.
                original_top_k = gate.top_k
                gate.top_k = k + 1
                try:
                    topkplus1_idx, topkplus1_weight, _ = gate(hidden_state)
                finally:
                    gate.top_k = original_top_k
                # Rank the k+1 slots by weight (descending) instead of trusting column order, because the gate's
                # topk (replicated elsewhere with sorted=False) is not guaranteed sorted. Then drop the top-1 slot.
                slot_order = topkplus1_weight.argsort(dim = -1, descending = True)[:, 1:]
                new_topk_idx = topkplus1_idx.gather(1, slot_order)        # shape (BN, k)
                new_topk_weight = topkplus1_weight.gather(1, slot_order)  # shape (BN, k)
                #####################################
                ### moe infer
                x = hidden_state.view(-1, hidden_state.shape[-1])
                topk_ids = new_topk_idx

                cnts = topk_ids.new_zeros((topk_ids.shape[0], len(decoder_layer.mlp.experts)))
                cnts.scatter_(1, topk_ids, 1)
                tokens_per_expert = cnts.sum(dim=0)
                idxs = topk_ids.view(-1).argsort()
                sorted_tokens = x[idxs // topk_ids.shape[1]]
                tokens_per_expert = tokens_per_expert.cpu().numpy()
                outputs = []
                start_idx = 0
                for i, num_tokens in enumerate(tokens_per_expert):
                    end_idx = start_idx + num_tokens
                    if num_tokens == 0:
                        continue
                    expert = decoder_layer.mlp.experts[i + decoder_layer.mlp.ep_rank * decoder_layer.mlp.experts_per_rank]
                    tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
                    expert_out = expert(tokens_for_this_expert)
                    outputs.append(expert_out)
                    start_idx = end_idx
                outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
                new_x = torch.empty_like(outs)
                new_x[idxs] = outs
                # Weight by the shifted top-k weights (the old code referenced an undefined `topk_weight` here)
                final_out = (new_x.view(*topk_ids.shape, -1).type(new_topk_weight.dtype).mul_(new_topk_weight.unsqueeze(dim=-1)).sum(dim=1).type(new_x.dtype))
                ###
                y = final_out.view(*orig_shape)
                if decoder_layer.mlp.config.n_shared_experts is not None:
                    y = y + decoder_layer.mlp.shared_experts(identity)
                hidden_state = y

            hidden_state = residual + hidden_state

        hidden_state = model.model.norm(hidden_state)
        logits = model.lm_head(hidden_state)
        return logits

top1_shift_res = evaluate_with_ablation(run_model_with_top1_shift, inputs["input_ids"], inputs['attention_mask'])

In [None]:
# Display the full model config (the trailing '.' previously made this cell a SyntaxError)
model.model.config

In [None]:
# Routed experts per token, read from the last decoder layer's MoE block (only layer 0 is dense).
# Avoids depending on a `decoder_layer` variable left over from another cell.
model.model.layers[-1].mlp.num_experts_per_tok

## Ablation study (naive method, all layers)

In [None]:
import numpy as np

TEST_DOMAIN = 'high_school_biology'

# Token counts per (layer, expert_1) — presumably expert_1 is each token's top-1 expert id; confirm —
# split into the test domain vs. all other domains
res = (
    all_topks
    .assign(is_test_domain = lambda df: np.where(df['domain'] == TEST_DOMAIN, 1, 0))
    .groupby(['is_test_domain', 'layer_ix', 'expert_1'])
    .agg(count = ('token_ix', 'count'))
    .reset_index()
    .pivot(index = ['layer_ix', 'expert_1'], columns = 'is_test_domain', values = 'count')
    .rename(columns = {0: 'other_domain_count', 1: 'test_domain_count'})
    # (layer, expert) pairs seen in only one group come out of the pivot as NaN: that means 0 tokens
    .fillna(0)
    .astype(int)
    # Drop the 'is_test_domain' name the pivot leaves on the column index (set_axis(df.columns) kept it)
    .rename_axis(columns = None)
    .reset_index()
)

res

In [None]:
# Sanity check for the manual forward passes above: the full model's logits should match running the
# decoder stack (model.model) and applying lm_head ourselves.
with torch.no_grad():
    outputs = model(input_ids = inputs["input_ids"], attention_mask = inputs['attention_mask'])
    full_logits = outputs['logits']
    print(full_logits)

with torch.no_grad():
    outputs = model.model(input_ids = inputs["input_ids"], attention_mask = inputs['attention_mask'])
    hidden_state = outputs[0]
    logits = model.lm_head(hidden_state)
    print(logits)

# Quantify the agreement instead of eyeballing two printed tensors (compared in float32)
print(f'Max abs logit difference: {(full_logits.float() - logits.float()).abs().max().item():.3e}')

In [None]:
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

def _is_dense_mlp(mlp) -> bool:
    """True for the dense `DeepseekV2MLP` feed-forward block, False for the routed MoE block."""
    return 'DeepseekV2MLP' in str(type(mlp))


def _as_expert_set(spec) -> frozenset:
    """Normalize one layer's ablation spec to a frozenset of integer expert ids.

    Accepts a single expert id (the original interface), a list/tuple/set/range of ids,
    or None (no ablation in that layer).
    """
    if spec is None:
        return frozenset()
    if isinstance(spec, (list, tuple, set, frozenset, range)):
        return frozenset(int(e) for e in spec)
    return frozenset({int(spec)})


def _moe_forward_with_ablation(moe, hidden_state, ablated_experts: frozenset):
    """Run the MoE block's inference routing, replacing outputs of `ablated_experts` with zeros.

    Path: gate -> dispatch tokens to experts -> gate-weighted combine -> add shared experts.
    Routing weights are NOT renormalised, so an ablated expert simply contributes zero to the
    weighted sum (the "naive" ablation). Ablated experts are never executed.
    """
    identity = hidden_state
    orig_shape = hidden_state.shape
    topk_idx, topk_weight, _aux_loss = moe.gate(hidden_state)
    x = hidden_state.view(-1, hidden_state.shape[-1])

    # Sort the flattened (token, slot) assignments by expert id so each expert gets one contiguous slice.
    cnts = topk_idx.new_zeros((topk_idx.shape[0], len(moe.experts)))
    cnts.scatter_(1, topk_idx, 1)
    tokens_per_expert = cnts.sum(dim=0).cpu().numpy()
    idxs = topk_idx.view(-1).argsort()
    sorted_tokens = x[idxs // topk_idx.shape[1]]

    outputs = []
    start_idx = 0
    for i, num_tokens in enumerate(tokens_per_expert):
        end_idx = start_idx + num_tokens
        if num_tokens == 0:
            continue
        # Global expert index (experts may be placed across ranks via ep_rank / experts_per_rank).
        expert_idx = i + moe.ep_rank * moe.experts_per_rank
        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
        if expert_idx in ablated_experts:
            # Ablated: skip the expert and contribute zeros. The later view(*orig_shape) implies
            # expert outputs have the hidden width, so zeros shaped like the input slice fit.
            expert_out = torch.zeros_like(tokens_for_this_expert)
        else:
            expert_out = moe.experts[expert_idx](tokens_for_this_expert)
        outputs.append(expert_out)
        start_idx = end_idx

    outs = torch.cat(outputs, dim=0) if outputs else sorted_tokens.new_empty(0)
    # Undo the sort, then take the gate-weighted sum over each token's top-k slots.
    new_x = torch.empty_like(outs)
    new_x[idxs] = outs
    final_out = (
        new_x.view(*topk_idx.shape, -1)
        .type(topk_weight.dtype)
        .mul_(topk_weight.unsqueeze(dim=-1))
        .sum(dim=1)
        .type(new_x.dtype)
    )
    y = final_out.view(*orig_shape)
    if moe.config.n_shared_experts is not None:
        y = y + moe.shared_experts(identity)
    return y


def run_model_with_ablation(input_ids, attention_mask, experts_to_ablate: list[int]):
    """
    Run the notebook-global `model` layer by layer, zero-ablating routed experts in every MoE layer.

    Params:
        @input_ids: token ids; only the first two dims (B, N) are read for shapes.
        @attention_mask: 2D padding mask, expanded with `_prepare_4d_causal_attention_mask`.
        @experts_to_ablate: one entry per MoE layer, in layer order. Each entry is a single expert ID
            to ablate in that layer; a list/tuple/set/range of IDs, or None (no ablation), is also accepted.

    Returns:
        Logits from `model.lm_head` for every position.

    Raises:
        ValueError: if `experts_to_ablate` has fewer entries than the model has MoE layers.
    """
    n_moe_layers = sum(1 for layer in model.model.layers if not _is_dense_mlp(layer.mlp))
    if len(experts_to_ablate) < n_moe_layers:
        raise ValueError(
            f'experts_to_ablate has {len(experts_to_ablate)} entries but the model has {n_moe_layers} MoE layers'
        )

    with torch.no_grad():
        B, N = input_ids.shape[:2]
        # Build position ids on the inputs' device rather than a hard-coded one.
        position_ids = torch.arange(0, N, dtype=torch.long, device=input_ids.device).unsqueeze(0)

        inputs_embeds = model.model.embed_tokens(input_ids)
        attention_mask = _prepare_4d_causal_attention_mask(attention_mask, (B, N), inputs_embeds, 0)

        hidden_state = inputs_embeds
        moe_layer_ix = 0
        for decoder_layer in model.model.layers:
            # Self-attention sub-block (pre-norm + residual).
            residual = hidden_state
            hidden_state = decoder_layer.input_layernorm(hidden_state)
            hidden_state, _self_attn_weights, _present_key_value = decoder_layer.self_attn(
                hidden_states = hidden_state,
                attention_mask = attention_mask,
                position_ids = position_ids
            )
            hidden_state = residual + hidden_state

            # Feed-forward sub-block (pre-norm + residual): dense MLP or routed MoE with ablation.
            residual = hidden_state
            hidden_state = decoder_layer.post_attention_layernorm(hidden_state)
            if _is_dense_mlp(decoder_layer.mlp):
                hidden_state = decoder_layer.mlp(hidden_state)
            else:
                ablated = _as_expert_set(experts_to_ablate[moe_layer_ix])
                hidden_state = _moe_forward_with_ablation(decoder_layer.mlp, hidden_state, ablated)
                moe_layer_ix += 1
            hidden_state = residual + hidden_state

        hidden_state = model.model.norm(hidden_state)
        logits = model.lm_head(hidden_state)
        return logits
    
# Smoke test: zero expert 7 in each MoE layer (26 entries, one per MoE layer) on the current `inputs` batch.
run_model_with_ablation(
    input_ids = inputs["input_ids"],
    attention_mask = inputs['attention_mask'],
    experts_to_ablate = [7] * 26,
)

In [None]:
import pandas as pd

# Ablated evaluation: score each domain's questions [3:203] with the experts below zeroed out.
# One entry per MoE layer (the list has 26 entries, matching the smoke test above).
EXPERTS_TO_ABLATE = [7] * 26

cat_results = []

for this_domain in all_domain_questions:

    domain_questions = this_domain['questions']

    count_correct = 0
    count_incorrect = 0
    results = []
    # Wrap the sliced list (not the enumerate object) so tqdm knows the total and renders a progress bar.
    # NOTE: question_ix is the position within domain_questions[3:203], i.e. original index - 3.
    for question_ix, q in enumerate(tqdm(domain_questions[3:203], desc = this_domain['domain'])):
        # Valid answer letters for this question: 'A', 'B', ... one per choice.
        valid_letters = [chr(65 + i) for i in range(len(q['choices']))]

        input_prompt = tokenizer.apply_chat_template(
            base_prompt + [
                {'role': 'user', 'content': prep_question(q['question'], q['choices'])},
                {'role': 'assistant', 'content': 'The correct answer is'},
            ],
            tokenize = False, add_generation_prompt = False, continue_final_message = True # Otherwise appends eos token
        )
        inputs = tokenizer(input_prompt, return_tensors = 'pt').to(main_device)

        outputs = run_model_with_ablation(inputs["input_ids"], inputs["attention_mask"], experts_to_ablate = EXPERTS_TO_ABLATE)
        # Greedy next token after the assistant prefix.
        next_token_logits = outputs[0, -1, :]
        next_token_id = torch.argmax(next_token_logits).item()
        predicted_text = tokenizer.decode([next_token_id]).strip()

        # First character of the decoded token that is a valid choice letter (case-insensitive), else None.
        predicted_letter = next((c.upper() for c in predicted_text if c.upper() in valid_letters), None)

        result = {
            'domain': this_domain['domain'],
            'question_ix': question_ix,
            'model_output': predicted_text,
            'model_choice': predicted_letter,
            'correct_choice': q['answer_char'],
            'is_correct': 1 if predicted_letter == q['answer_char'] else 0
        }

        if result['is_correct'] == 1: count_correct += 1
        else: count_incorrect += 1

        results.append(result)

    cat_results.append({
        'answer_df': pd.DataFrame(results)
    })

    # Guard against an empty question slice, which would otherwise raise ZeroDivisionError.
    n_scored = count_correct + count_incorrect
    accuracy = (count_correct / n_scored) * 100 if n_scored else float('nan')
    print(f'{this_domain["domain"]} | Correct: {str(count_correct)} | Incorrect: {str(count_incorrect)} | Accuracy: {accuracy:.1f}%')

In [None]:
# Check: layer 0's feed-forward block is the dense DeepseekV2MLP rather than the MoE block.
'DeepseekV2MLP' in repr(type(model.model.layers[0].mlp))

In [None]:
# Drop a leftover large tensor from the kernel namespace; a no-op if it was never defined,
# so this cell no longer raises NameError on a fresh kernel.
# NOTE(review): no visible cell defines `hidden_states`; the logits sanity-check cell defines
# `hidden_state` (singular) — confirm which name was intended.
globals().pop('hidden_states', None)

In [None]:
# from datasets import load_dataset

# mmlu_ds = load_dataset("TIGER-Lab/MMLU-Pro", split = 'test')