## Init

In [None]:
"""
Imports
"""
import torch
from utils.memory import check_memory, profile_memory, clear_all_cuda_memory
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import pandas as pd
import numpy as np
from scipy import stats
import plotly
import plotly.express as px 
import plotly.graph_objects as go
from tqdm import tqdm
from utils.store_topk import convert_topk_to_df

# Device that the input tensors are sent to (via `.to(main_device)`) throughout the notebook
main_device = 'cuda:0'
# NOTE(review): `seed` is not referenced anywhere in the visible code (no torch/numpy seeding call) — confirm whether it should be applied
seed = 1234
# Free CUDA memory left over from a previous run, then report usage (project helpers from utils.memory)
clear_all_cuda_memory()
check_memory()

## Load base model

In [None]:
"""
Load the base model
"""
hf_model_id = 'Qwen/Qwen1.5-MoE-A2.7B-Chat'

tokenizer = AutoTokenizer.from_pretrained(hf_model_id, add_eos_token = False, add_bos_token = False, padding_side = 'left')
model = AutoModelForCausalLM.from_pretrained(hf_model_id, torch_dtype = torch.bfloat16, trust_remote_code = True).cuda().eval()

check_memory()

In [None]:
"""
Test a forward pass with topk experts and weights extracted
"""
# No need for hooks! https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py

def test_forward_pass():

    inputs = tokenizer(['Hi this is dog', 'Where is the beef'], return_tensors = 'pt', padding = 'max_length', truncation = True, max_length = 512).to(model.device)
    with torch.no_grad():
        outputs = model(**inputs, output_router_logits = True)

    all_topk_experts = []
    all_topk_weights = []
    for l, layer_router_logits in enumerate(outputs.router_logits):
        # layer_router_logits is shape [B*N, num_experts]
        gating_probs = torch.softmax(layer_router_logits, dim = -1)
        routing_weights, topk_experts = torch.topk(gating_probs, k = model.config.num_experts_per_tok, dim = -1, sorted = True)
        
        all_topk_experts.append(topk_experts.detach().cpu()) 
        all_topk_weights.append(routing_weights.detach().cpu().to(torch.float32))

    print(f"all_expert_weights is of length {len(all_topk_experts)}")
    print(f"all_topk_weights is of length {len(all_topk_weights)}")

    print(f"all_topk_experts[0] shape = {all_topk_experts[0].shape}") # should be BN x topk
    print(f"all_topk_weights[0] shape = {all_topk_weights[0].shape}") # should be BN x topk
    
    print(f"all_topk_experts[0]: {all_topk_experts[0]}")
    print(f"all_topk_weights[0]: {all_topk_weights[0]}")

test_forward_pass()

## Get MMLU

In [None]:
""""
Get MMLU data and domains to test
"""
raw_mmlu_ds = load_dataset("cais/mmlu", 'all', split = 'test')
print(raw_mmlu_ds[0])

# all_domains = list(set(mmlu_ds['subject']))
# for domain in all_domains:
#     dom_length = len([q for q in mmlu_ds if q['subject'] == domain])
#     print(f"{domain}: {dom_length}")

# Dict containing {final_domains: [source1, source2], ...}
mmlu_domain_mappings = {
    'math': ['elementary_mathematics'], 
    'cs': ['high_school_computer_science', 'college_computer_science'], # 100 each
    'history': ['high_school_world_history'],
    # 'psych': ['high_school_psychology'],
    'chemistry': ['high_school_chemistry'],
    'biology': ['high_school_biology']
}

# Now let's put the MMLU questions into a list grouped by domain
def group_mmlu_ds(raw_ds, domain_map, max_questions_per_domain):
    """
    Group raw MMLU rows into per-domain question lists.

    Params:
        @raw_ds: Iterable of MMLU rows, each with 'subject', 'question', 'choices' and 'answer' (int index) keys.
        @domain_map: Dict of {domain: [source_subject, ...]} selecting which subjects feed each domain.
        @max_questions_per_domain: Cap on questions kept per domain; later rows beyond the cap are skipped.

    Returns:
        List of {'domain', 'sources', 'questions'} dicts, one per domain in `domain_map` order. Each question is
        {'question', 'choices', 'answer_index', 'answer_char'}, where answer_char is 'A' for index 0, 'B' for 1, ...
    """
    # Built from the `domain_map` argument (previously the module-level mapping was used by mistake)
    source_to_domain = {source: domain for domain, sources in domain_map.items() for source in sources}
    grouped = {domain: {'sources': sources, 'questions': []} for domain, sources in domain_map.items()}
    for q in raw_ds:
        domain = source_to_domain.get(q['subject'])
        if domain is None:
            continue
        questions = grouped[domain]['questions']
        if len(questions) >= max_questions_per_domain:
            continue
        questions.append({
            'question': q['question'],
            'choices': q['choices'],
            'answer_index': q['answer'],
            'answer_char': chr(65 + q['answer'])
        })
    return [{'domain': domain, **values} for domain, values in grouped.items()] # Convert back to list of dicts

# Group the evaluation subjects into domains, keeping at most 200 questions per domain
mmlu_ds = group_mmlu_ds(raw_ds = raw_mmlu_ds, domain_map = mmlu_domain_mappings, max_questions_per_domain = 200)
mmlu_ds[0]  # notebook display of the first grouped domain


In [None]:
"""
Create function to map MMLU data into string
"""
def prep_question(question, choices):
    prompt = f"Question: {question}\nChoices:\n"
    for i, option in enumerate(choices):
        letter = chr(65 + i)
        prompt += f"({letter}) {option}\n"
    return prompt

def _first_question(subject):
    """Return the first row of `raw_mmlu_ds` whose 'subject' is `subject`, stopping the scan at the match."""
    match = next((q for q in raw_mmlu_ds if q['subject'] == subject), None)
    if match is None:
        # Same exception type the previous `[...][0]` lookup raised for a missing subject
        raise IndexError(f"No MMLU question found for subject '{subject}'")
    return match

# One few-shot example from each of these subjects (not among the evaluation sources in `mmlu_domain_mappings`)
fs_ex = [_first_question(subject) for subject in ['anatomy', 'machine_learning', 'astronomy']]

base_prompt = [
    {'role': 'system', 'content': 'You will be provided with a multiple-choice question, as well as a list of possible answer choices. Respond exactly with: "The correct answer is {X}", substituting in X with the code for the correct choice.'},
]
# Each example contributes a user turn (the question) followed by an assistant turn (the worked answer)
for ex in fs_ex:
    base_prompt.append({'role': 'user', 'content': prep_question(ex['question'], ex['choices'])})
    base_prompt.append({'role': 'assistant', 'content': 'The correct answer is ' + chr(65 + ex['answer'])})

print(tokenizer.apply_chat_template(base_prompt, tokenize = False, add_generation_prompt = False, continue_final_message = True))

## Run logit lens

In [None]:
# Inspect the first question of the first grouped domain (displayed by the notebook)
mmlu_ds[0]['questions'][0]

In [None]:

# Prompt for the first question of the first domain: few-shot `base_prompt`, then the question, with the assistant
# turn pre-filled as 'The correct answer is' so that the next token should be the answer letter
test_prompt = tokenizer.apply_chat_template(
    base_prompt + [{'role': 'user', 'content': prep_question(mmlu_ds[0]['questions'][0]['question'], mmlu_ds[0]['questions'][0]['choices'])}, {'role': 'assistant', 'content': 'The correct answer is'}],
    tokenize = False, add_generation_prompt = False, continue_final_message = True # Otherwise appends eos token
)
inputs = tokenizer(test_prompt, return_tensors = 'pt').to(main_device)

# Manual pass through the decoder stack, keeping every layer's output hidden state for the logit lens below
with torch.no_grad():
    input_embeds = model.model.embed_tokens(inputs['input_ids'])
    
    # Positions 0..seq_len-1 for a single, unpadded sequence (no KV cache)
    cache_position = torch.arange(0, input_embeds.shape[1], device = input_embeds.device)
    position_ids = cache_position.unsqueeze(0)
    # Private transformers helper; NOTE(review): the two trailing Nones are presumably past_key_values /
    # output_attentions — verify against the installed transformers version, as private signatures may differ
    causal_mask = model.model._update_causal_mask(inputs['attention_mask'], input_embeds, cache_position, None, None)

    hidden_state = input_embeds
    # Rotary position embeddings, computed once and passed to every layer
    position_embeddings = model.model.rotary_emb(hidden_state, position_ids)

    layer_outputs = []
    for layer in model.model.layers:
        # The layer's return value is indexed with [0] to take the new hidden state
        hidden_state = layer(hidden_state, attention_mask = causal_mask, position_ids = position_ids, position_embeddings = position_embeddings)[0]
        layer_outputs.append(hidden_state.detach())

    # Final norm + LM head give the full-model logits (not read by the logit-lens loop that follows)
    hidden_state = model.model.norm(hidden_state)
    logits = model.lm_head(hidden_state)

# Logit lens: decode each layer's hidden state with the final norm + LM head and keep the top tokens
# predicted for the last position of the (single) prompt
top_k_probs = 5
logit_lens_outputs = []
with torch.no_grad():
    for layer_ix, layer_hidden in enumerate(layer_outputs):
        layer_logits = model.lm_head(model.model.norm(layer_hidden)).float()
        next_token_probs = torch.nn.functional.softmax(layer_logits[0, -1, :], dim = 0)
        best_probs, best_ids = torch.topk(next_token_probs, top_k_probs, dim = 0)
        logit_lens_outputs.extend(
            {'layer_ix': layer_ix, 'prob': round(p, 6), 'token_rank': r + 1, 'token': tokenizer.decode([tok_id])}
            for r, (p, tok_id) in enumerate(zip(best_probs.tolist(), best_ids.tolist()))
        )

logit_lens_df = pd.DataFrame(logit_lens_outputs)
logit_lens_df

In [None]:
def plot_logit_lens(logit_lens_df):

    logit_lens_plot_df = logit_lens_df\
        .assign(token_str = lambda df: np.where(
            df['token'].isin([' A', ' B', ' C', ' D']),
            '<span style="font-weight:bold;font-size:13px">' + df['token'] + '</span>',
            '<span style="font-size:11px">' + df['token'] + '</span>')
        )\
        .assign(text = lambda df: df.apply(lambda row: f"{row['token_str']} <span style=\"display:none;font-size:8px;color:lightgray;font-weight:bold\"> {(row['prob'] * 100):.1f}%</span>", axis=1))\
        .pipe(lambda df: df[df['layer_ix'] >= 12])

    neg_colors = plotly.colors.sequential.YlOrRd_r
    pos_colors = plotly.colors.sequential.YlGnBu

    neg_breaks = [-1.0, -0.6, -0.4, -0.2, -0.05, -0.01, 0]  # Define breakpoints for negative values
    pos_breaks = [0, 0.01, 0.05, 0.1, 0.4, 0.6, 1.0] # Define breakpoints for positive values
    all_breaks = neg_breaks + pos_breaks[1:]  # Skip duplicate 0

    # Calculate the normalized positions for each breakpoint
    min_val = all_breaks[0]
    max_val = all_breaks[-1]
    norm_positions = [(val - min_val) / (max_val - min_val) for val in all_breaks]

    # Create hybrid colorscale with Plasma for negative and Viridis for positive
    hybrid_scale = []

    # Add colors from Plasma for the negative part
    for i in range(len(neg_breaks)):
        norm_pos = i / (len(neg_breaks) - 1)
        plasma_idx = min(int(norm_pos * len(neg_colors)), len(neg_colors) - 1)
        hybrid_scale.append([norm_positions[i], neg_colors[plasma_idx]])

    # Add colors from Viridis for the positive part
    for i in range(1, len(pos_breaks)):
        viridis_idx = int((i / (len(pos_breaks) - 1)) * (len(pos_colors) - 1))
        hybrid_scale.append([norm_positions[len(neg_breaks) + i - 1], pos_colors[viridis_idx]])

    # Create pivot tables for the heatmap and text
    pivot_prob = logit_lens_plot_df.pivot(index='token_rank', columns='layer_ix', values='prob')
    pivot_text = logit_lens_plot_df.pivot(index='token_rank', columns='layer_ix', values='text')

    # Create the heatmap with go.Heatmap
    fig =\
        go.Figure(data = go.Heatmap(
            z = pivot_prob.values,
            x = pivot_prob.columns,
            y = pivot_prob.index,
            colorscale = hybrid_scale,
            text = pivot_text.values,
            texttemplate = "%{text}",
            zmin = min_val,
            zmax = max_val,
            colorbar = dict(
                title = "Probability",
                tickvals = all_breaks,
                ticktext = [str(v) for v in all_breaks]
            )
        ))\
        .update_yaxes(autorange="reversed")\
        .update_layout(
            title = "Logit Lens",
            xaxis_title = "Layer Index",
            yaxis_title = "Token Rank",
            width = 1400,
            height = 600
        )

    return fig

plot_logit_lens(logit_lens_df = logit_lens_df)  # heatmap of the logit-lens table built above

In [None]:
"""
Ablated lens
"""
def get_logit_lens_probs(input_ids, attention_mask, layers_to_ablate = list(range(0, 24)), topk_to_ablate = [0], renorm = False):
    """
    Ablates certain rank-ordered top-k columns for the specified layers.
    
    Params:
        @layers_to_ablate: Which layer indices (0-based) should we ablate?
        @renorm: Whether to rescale the remaining weights to keep the sum unchanged
        @topk_to_ablate: A list of ranks to ablate, e.g. [0] = top-1, [2,3] = 3rd & 4th largest, etc.
    """
    input_embeds = model.model.embed_tokens(input_ids)
    
    cache_position = torch.arange(0, input_embeds.shape[1], device = input_embeds.device)
    position_ids = cache_position.unsqueeze(0)
    causal_mask = model.model._update_causal_mask(attention_mask, input_embeds, cache_position, None, None)

    hidden_state = input_embeds
    position_embeddings = model.model.rotary_emb(hidden_state, position_ids)

    layer_outputs = []
    for layer_ix, layer in enumerate(model.model.layers):
        # SA
        residual = hidden_state
        hidden_state = layer.input_layernorm(hidden_state)
        hidden_state, _, _ = layer.self_attn(hidden_states = hidden_state, attention_mask = causal_mask, position_ids = position_ids, position_embeddings = position_embeddings)
        hidden_state = residual + hidden_state
        residual = hidden_state
        hidden_state = layer.post_attention_layernorm(hidden_state)

        # MoE
        ####### Qwen2MoeSparseMoeBlock - below code replaces hidden_state = layer.mlp(hidden_state)
        batch_size, sequence_length, hidden_dim = hidden_state.shape
        moe_hidden_state = hidden_state.view(-1, hidden_dim)
        router_logits = layer.mlp.gate(moe_hidden_state) # Size (BN, n_experts)

        routing_weights = torch.nn.functional.softmax(router_logits, dim = 1, dtype = torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, layer.mlp.top_k, dim = -1, sorted = True)
        routing_weights = routing_weights.to(moe_hidden_state.dtype)

        #### ABLATION
        if layer_ix in layers_to_ablate:
            row_sum_before = routing_weights.sum(dim = -1, keepdim = True) # Shaype (BN, 1)            
            # For each rank in topk_to_ablate, zero out that column
            for rank in topk_to_ablate:
                routing_weights[:, rank] = 0.0
            if renorm:
                row_sum_after = routing_weights.sum(dim = -1, keepdim = True)
                scale_factor = row_sum_before / (row_sum_after + 1e-9)
                routing_weights *= scale_factor
        ####
        final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim), dtype = moe_hidden_state.dtype, device = moe_hidden_state.device)

        # One hot encode the selected experts to create an expert mask 
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes = layer.mlp.num_experts).permute(2, 1, 0)

        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in range(layer.mlp.num_experts):
            expert_layer = layer.mlp.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])
            # Index the correct hidden states and compute the expert hidden state for the current expert.
            current_state = moe_hidden_state[None, top_x].reshape(-1, hidden_dim)
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
            # However `index_add_` only support torch tensors for indexing so we'll use the `top_x` tensor here.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(moe_hidden_state.dtype))

        shared_expert_output = layer.mlp.shared_expert(moe_hidden_state)
        shared_expert_output = torch.nn.functional.sigmoid(layer.mlp.shared_expert_gate(moe_hidden_state)) * shared_expert_output

        final_hidden_states = (final_hidden_states + shared_expert_output).reshape(batch_size, sequence_length, hidden_dim)
        #######
        hidden_state = final_hidden_states
        hidden_state = residual + hidden_state
        
        layer_outputs.append(hidden_state.detach())

    layer_probs = []
    for layer_ix, layer_output in enumerate(layer_outputs):
        layer_output = model.model.norm(layer_output)
        lm_output = model.lm_head(layer_output).float()
        last_token_logits = lm_output[0, -1, :].detach().cpu()
        layer_probs.append(torch.nn.functional.softmax(last_token_logits, dim = 0))

    return layer_probs


def draw_logit_lens_topk(logit_lens_probs, top_k_probs = 8):
    """
    Plot, for every layer, the `top_k_probs` vocabulary entries with the largest absolute value.

    Entries are ranked by magnitude but plotted with their signed value, so this works both for raw per-layer
    probabilities and for per-layer probability differences.
    """
    rows = []
    for layer_ix, layer_probs in enumerate(logit_lens_probs):
        top_indices = torch.topk(layer_probs.abs(), top_k_probs, dim = 0).indices
        signed_values = layer_probs[top_indices]
        for rank, (token_idx, value) in enumerate(zip(top_indices.tolist(), signed_values.tolist())):
            rows.append({'layer_ix': layer_ix, 'prob': round(value, 6), 'token_rank': rank + 1, 'token': tokenizer.decode([token_idx])})

    return plot_logit_lens(pd.DataFrame(rows))

# Zero routing-weight ranks 0-3 in layer 0 only (no renormalisation) and plot the resulting lens
logit_lens_probs = get_logit_lens_probs(
    inputs['input_ids'], inputs['attention_mask'],
    layers_to_ablate = [0], topk_to_ablate = [0, 1, 2, 3], renorm = False
)
draw_logit_lens_topk(logit_lens_probs)

In [None]:
# Baseline: an empty `topk_to_ablate` leaves every routing weight untouched
baseline_logit_lens_probs = get_logit_lens_probs(inputs['input_ids'], inputs['attention_mask'], layers_to_ablate = [0], topk_to_ablate = [], renorm = False)
logit_lens_probs = get_logit_lens_probs(inputs['input_ids'], inputs['attention_mask'], layers_to_ablate = [0], topk_to_ablate = [0, 1, 2, 3], renorm = False)
# Per-layer change in next-token probability caused by the ablation
probdiffs = [ablated - base for ablated, base in zip(logit_lens_probs, baseline_logit_lens_probs)]
draw_logit_lens_topk(probdiffs)

In [None]:
# NOTE(review): `no_expert_probs` (ranks 0-3 ablated in layers 6-7) is computed but never used below; the plot
# compares the ranks-1-3-ablated run against `baseline_logit_lens_probs` from the previous cell — confirm intent.
no_expert_probs = get_logit_lens_probs(inputs['input_ids'], inputs['attention_mask'], layers_to_ablate = [6, 7], topk_to_ablate = [0, 1, 2, 3], renorm = False)
# Ablate ranks 1-3 (keeping rank 0) in layers 6 and 7
logit_lens_probs = get_logit_lens_probs(inputs['input_ids'], inputs['attention_mask'], layers_to_ablate = [6, 7], topk_to_ablate = [1, 2, 3], renorm = False)
probdiffs = [ablated - base for ablated, base in zip(logit_lens_probs, baseline_logit_lens_probs)]
draw_logit_lens_topk(probdiffs)

## Save topk for later analysis

In [None]:
@torch.no_grad()
def run_model_and_return_topk(mmlu_ds):
    """
    Evaluate the unmodified model on grouped MMLU questions and collect per-token top-k routing.

    Each question is rendered with the few-shot `base_prompt` and an assistant turn pre-filled with
    'The correct answer is'; the argmax next token at the last position is parsed into an answer letter.
    Router logits from the standard forward (`output_router_logits = True`) are converted to per-layer top-k
    experts/weights via softmax + topk.

    Params:
        @mmlu_ds: List of {'domain', 'questions'} dicts, e.g. from `group_mmlu_ds`; each question needs
          'question', 'choices' and 'answer_char'.

    Returns:
        Dict with 'question_df' (one row per question), 'topk_df' (routing rows from `convert_topk_to_df`,
        tagged with question_ix and domain), 'n_correct', 'n_total', 'n_invalid', and 'accuracy'
        (n_correct / n_total, NaN when there are no questions).
    """
    def _concat_or_empty(frames):
        # pd.concat raises on an empty list; an empty domain should not crash the whole run
        return pd.concat(frames) if frames else pd.DataFrame()

    final_results = []

    for this_domain in mmlu_ds:
        domain_results = []

        for question_ix, q in enumerate(this_domain['questions']):
            input_prompt = tokenizer.apply_chat_template(
                base_prompt + [{'role': 'user', 'content': prep_question(q['question'], q['choices'])}, {'role': 'assistant', 'content': 'The correct answer is'}],
                tokenize = False, add_generation_prompt = False, continue_final_message = True # Otherwise appends eos token
            )
            inputs = tokenizer(input_prompt, return_tensors = 'pt').to(main_device)

            outputs = model(input_ids = inputs["input_ids"], attention_mask = inputs["attention_mask"], output_router_logits = True)
            all_topk_experts = []
            all_topk_weights = []
            for layer_router_logits in outputs.router_logits:
                # layer_router_logits is shape [B*N, num_experts]
                gating_probs = torch.softmax(layer_router_logits, dim = -1)
                routing_weights, topk_experts = torch.topk(gating_probs, k = model.config.num_experts_per_tok, dim = -1, sorted = True)
                all_topk_experts.append(topk_experts.detach().cpu())
                all_topk_weights.append(routing_weights.detach().cpu().to(torch.float32))

            topk_df = convert_topk_to_df(inputs["input_ids"], all_topk_experts, all_topk_weights).assign(question_ix = question_ix).drop(columns = 'sequence_ix')
            topk_df = topk_df[topk_df['token_id'] != tokenizer.pad_token_id] # Drop rows whose token is the pad token

            predicted_text = tokenizer.decode([torch.argmax(outputs['logits'][0, -1, :]).item()]).strip()
            valid_letters = [chr(65 + i) for i in range(len(q['choices']))]
            # NOTE(review): the first character whose upper-case form is a valid letter wins, so a word token such as
            # 'about' scores as 'A'. Kept unchanged so results stay comparable with `evaluate_with_ablation`.
            predicted_letter = next((c.upper() for c in predicted_text if c.upper() in valid_letters), None)

            domain_results.append({
                'question_ix': question_ix,
                'model_output': predicted_text,
                'model_choice': predicted_letter,
                'correct_choice': q['answer_char'],
                'is_correct': 1 if predicted_letter == q['answer_char'] else 0,
                'is_valid': 1 if predicted_letter is not None else 0,
                'topk_df': topk_df
            })

        n_total = len(domain_results)
        n_correct = sum(1 for x in domain_results if x['is_correct'] == 1)
        n_invalid = sum(1 for x in domain_results if x['is_valid'] == 0)
        accuracy_pct = (n_correct / n_total) * 100 if n_total else float('nan')
        print(f'{this_domain["domain"]} | Correct: {str(n_correct)} | Incorrect: {str(n_total - n_correct)} | Invalid: {str(n_invalid)} | Accuracy: {accuracy_pct:.1f}%')

        final_results.append({
            'domain': this_domain['domain'],
            'question_df': pd.DataFrame([{k: v for k, v in x.items() if k != 'topk_df'} for x in domain_results]).assign(domain = this_domain['domain']),
            'topk_df': _concat_or_empty([x['topk_df'] for x in domain_results]).assign(domain = this_domain['domain']),
            'n_correct': n_correct,
            'n_total': n_total,
            'n_invalid': n_invalid,
        })

    total_correct = sum(d['n_correct'] for d in final_results)
    total_questions = sum(d['n_total'] for d in final_results)
    return {
        'question_df': _concat_or_empty([d['question_df'] for d in final_results]),
        'topk_df': _concat_or_empty([d['topk_df'] for d in final_results]),
        'n_correct': total_correct,
        'n_total': total_questions,
        'n_invalid': sum(d['n_invalid'] for d in final_results),
        'accuracy': total_correct / total_questions if total_questions else float('nan')
        }

# Unablated reference run: prints per-domain accuracy, shows the routing table, then overall accuracy
baseline = run_model_and_return_topk(mmlu_ds = mmlu_ds)
display(baseline['topk_df'])
print(baseline['accuracy'])

In [None]:
# Token-id -> token-string lookup ordered by id. Only the byte-level-BPE space marker 'Ġ' is mapped back to ' ';
# any other byte-level markers are left in their raw vocab form.
vocab_rows = [{"token": tok.replace('Ġ', ' '), "token_id": tok_id} for tok, tok_id in tokenizer.get_vocab().items()]
vocab_map = pd.DataFrame(vocab_rows).sort_values(by = 'token_id').reset_index(drop = True)
display(vocab_map)

all_answers = baseline['question_df']
display(all_answers)

# Round routing weights to 4 decimals to keep the exported table compact
all_topks = baseline['topk_df'].assign(weight = lambda df: df['weight'].round(4))
display(all_topks)

In [None]:
# Export the vocab lookup, per-question answers, and per-token routing for offline analysis
for csv_path, frame in [('vocab.csv', vocab_map), ('all_answers.csv', all_answers), ('all_topks.csv', all_topks)]:
    frame.to_csv(csv_path, index = False)

## Ablation prep

In [None]:
# Prompt for the first question of the first domain, with the assistant turn pre-filled
first_q = mmlu_ds[0]['questions'][0]
input_prompt = tokenizer.apply_chat_template(
    base_prompt + [
        {'role': 'user', 'content': prep_question(first_q['question'], first_q['choices'])},
        {'role': 'assistant', 'content': 'The correct answer is'},
    ],
    tokenize = False, add_generation_prompt = False, continue_final_message = True, # Otherwise appends eos token
)
inputs = tokenizer(input_prompt, return_tensors = 'pt').to(main_device)

with torch.no_grad():
    # 1) Full causal-LM forward
    outputs = model(input_ids = inputs["input_ids"], attention_mask = inputs['attention_mask'])
    print(outputs['logits'])

    # 2) Backbone only, then the LM head applied by hand; the printed logits should match (1)
    outputs = model.model(input_ids = inputs["input_ids"], attention_mask = inputs['attention_mask'])
    hidden_state = outputs[0]
    logits = model.lm_head(hidden_state)
    print(logits)

# Manual forward pass that re-implements each decoder layer (attention + sparse MoE block) step by step, recording
# the top-k expert ids and routing weights per layer; the printed logits can be compared with those printed above.
with torch.no_grad():
    input_embeds = model.model.embed_tokens(inputs['input_ids'])
    
    cache_position = torch.arange(0, input_embeds.shape[1], device = input_embeds.device)
    position_ids = cache_position.unsqueeze(0)
    # Private transformers helper; NOTE(review): the trailing Nones are presumably past_key_values / output_attentions — verify for the installed version
    causal_mask = model.model._update_causal_mask(inputs['attention_mask'], input_embeds, cache_position, None, None)

    hidden_state = input_embeds
    position_embeddings = model.model.rotary_emb(hidden_state, position_ids)

    all_topk_experts = []
    all_topk_weights = []
    for layer in model.model.layers:
        # SA
        # Pre-norm self-attention with a residual connection
        residual = hidden_state
        hidden_state = layer.input_layernorm(hidden_state)
        hidden_state, _, _ = layer.self_attn(hidden_states = hidden_state, attention_mask = causal_mask, position_ids = position_ids, position_embeddings = position_embeddings)
        hidden_state = residual + hidden_state
        residual = hidden_state
        hidden_state = layer.post_attention_layernorm(hidden_state)

        # MoE
        ####### Qwen2MoeSparseMoeBlock - below code replaces hidden_state = layer.mlp(hidden_state)
        # Flatten to (batch * seq, hidden) so routing is done per token
        batch_size, sequence_length, hidden_dim = hidden_state.shape
        moe_hidden_state = hidden_state.view(-1, hidden_dim)
        router_logits = layer.mlp.gate(moe_hidden_state) # Size (BN, n_experts)

        # Softmax in float32, keep the top_k experts (sorted by weight), then cast back to the model dtype
        # NOTE(review): no top-k renormalisation is applied here; confirm the config's `norm_topk_prob` is False,
        # otherwise this diverges from `layer.mlp(...)`
        routing_weights = torch.nn.functional.softmax(router_logits, dim = 1, dtype = torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, layer.mlp.top_k, dim = -1, sorted = True)
        routing_weights = routing_weights.to(moe_hidden_state.dtype)

        final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim), dtype = moe_hidden_state.dtype, device = moe_hidden_state.device)

        # One hot encode the selected experts to create an expert mask 
        # After permute the mask is indexed [expert, rank, token]
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes = layer.mlp.num_experts).permute(2, 1, 0)

        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in range(layer.mlp.num_experts):
            expert_layer = layer.mlp.experts[expert_idx]
            # idx = rank slot, top_x = token index, for every token routed to this expert
            idx, top_x = torch.where(expert_mask[expert_idx])
            # Index the correct hidden states and compute the expert hidden state for the current expert.
            current_state = moe_hidden_state[None, top_x].reshape(-1, hidden_dim)
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
            # However `index_add_` only support torch tensors for indexing so we'll use the `top_x` tensor here.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(moe_hidden_state.dtype))

        # Always-on shared expert, scaled by a sigmoid gate computed per token
        shared_expert_output = layer.mlp.shared_expert(moe_hidden_state)
        shared_expert_output = torch.nn.functional.sigmoid(layer.mlp.shared_expert_gate(moe_hidden_state)) * shared_expert_output

        final_hidden_states = (final_hidden_states + shared_expert_output).reshape(batch_size, sequence_length, hidden_dim)
        #######
        hidden_state = final_hidden_states
        hidden_state = residual + hidden_state

        # Per-layer routing, each of shape (batch * seq, top_k), moved to CPU
        all_topk_experts.append(selected_experts.detach().cpu())
        all_topk_weights.append(routing_weights.detach().cpu().to(torch.float32))

    hidden_state = model.model.norm(hidden_state)
    logits = model.lm_head(hidden_state)
    print(logits)
    print(len(all_topk_experts))
    print(all_topk_experts[0].shape)
    print(all_topk_weights[0].shape)

In [None]:
@torch.no_grad()
def run_model_no_ablation(input_ids, attention_mask):
    """
    Manual, unablated forward pass that also records each layer's top-k routing.

    Re-implements the decoder loop (pre-norm self-attention, then the sparse MoE block with its gated shared
    expert) so the selected experts and their routing weights can be captured per layer.

    Params:
        @input_ids, @attention_mask: Tokenized inputs on the model's device.

    Returns:
        Dict with
          'logits': LM-head output for every position, shape (batch, seq, vocab)
          'all_topk_experts': per-layer expert-id tensors of shape (batch * seq, top_k), on CPU
          'all_topk_weights': per-layer float32 routing-weight tensors of shape (batch * seq, top_k), on CPU
    """
    input_embeds = model.model.embed_tokens(input_ids)

    cache_position = torch.arange(0, input_embeds.shape[1], device = input_embeds.device)
    position_ids = cache_position.unsqueeze(0)
    causal_mask = model.model._update_causal_mask(attention_mask, input_embeds, cache_position, None, None)

    hidden_state = input_embeds
    position_embeddings = model.model.rotary_emb(hidden_state, position_ids)

    all_topk_experts = []
    all_topk_weights = []
    for layer in model.model.layers:
        # SA
        residual = hidden_state
        hidden_state = layer.input_layernorm(hidden_state)
        hidden_state, _, _ = layer.self_attn(hidden_states = hidden_state, attention_mask = causal_mask, position_ids = position_ids, position_embeddings = position_embeddings)
        hidden_state = residual + hidden_state
        residual = hidden_state
        hidden_state = layer.post_attention_layernorm(hidden_state)

        # MoE
        ####### Qwen2MoeSparseMoeBlock - below code replaces hidden_state = layer.mlp(hidden_state)
        batch_size, sequence_length, hidden_dim = hidden_state.shape
        moe_hidden_state = hidden_state.view(-1, hidden_dim)
        router_logits = layer.mlp.gate(moe_hidden_state) # Size (BN, n_experts)

        routing_weights = torch.nn.functional.softmax(router_logits, dim = 1, dtype = torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, layer.mlp.top_k, dim = -1, sorted = True)
        # Mirror the HF block's optional top-k renormalisation so this matches `layer.mlp(...)` for any config
        if getattr(layer.mlp, 'norm_topk_prob', False):
            routing_weights /= routing_weights.sum(dim = -1, keepdim = True)
        routing_weights = routing_weights.to(moe_hidden_state.dtype)

        final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim), dtype = moe_hidden_state.dtype, device = moe_hidden_state.device)

        # One hot encode the selected experts to create an expert mask, indexed [expert, rank, token] after permute
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes = layer.mlp.num_experts).permute(2, 1, 0)

        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in range(layer.mlp.num_experts):
            expert_layer = layer.mlp.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])
            # Index the correct hidden states and compute the expert hidden state for the current expert.
            current_state = moe_hidden_state[None, top_x].reshape(-1, hidden_dim)
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
            # However `index_add_` only support torch tensors for indexing so we'll use the `top_x` tensor here.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(moe_hidden_state.dtype))

        # Always-on shared expert, scaled by a per-token sigmoid gate
        shared_expert_output = layer.mlp.shared_expert(moe_hidden_state)
        shared_expert_output = torch.nn.functional.sigmoid(layer.mlp.shared_expert_gate(moe_hidden_state)) * shared_expert_output

        final_hidden_states = (final_hidden_states + shared_expert_output).reshape(batch_size, sequence_length, hidden_dim)
        #######
        hidden_state = residual + final_hidden_states

        all_topk_experts.append(selected_experts.detach().cpu())
        all_topk_weights.append(routing_weights.detach().cpu().to(torch.float32))

    hidden_state = model.model.norm(hidden_state)
    logits = model.lm_head(hidden_state)
    return {'logits': logits, 'all_topk_experts': all_topk_experts, 'all_topk_weights': all_topk_weights}
    

@torch.no_grad()
def evaluate_with_ablation(run_forward_fn, return_topk, *args, **kwargs):
    """
    Evaluate modified model on every question in the module-level `mmlu_ds`.

    Each question is rendered with the few-shot `base_prompt` and an assistant turn pre-filled with
    'The correct answer is'; the argmax next token at the last position is parsed into an answer letter.

    Params:
        @run_forward_fn: A function that returns a model forward pass with inputs input_ids, attention_mask, and optional *args/**kwargs.
          The function must return a dict with key `logits`.
        @return_topk: Whether to return the expert IDs and weights as well. If truthy, `run_forward_fn` must also return keys
         `all_topk_experts` and `all_topk_weights`.
        @*args/**kwargs: Additional arguments to pass to `run_forward_fn`.

    Returns:
        Dict with 'question_df', 'topk_df' (None unless `return_topk`), 'n_correct', 'n_total', 'n_invalid' and
        'accuracy' (n_correct / n_total, NaN when there are no questions).
    """
    def _concat_or_empty(frames):
        # pd.concat raises on an empty list; an empty domain should not crash the evaluation
        return pd.concat(frames) if frames else pd.DataFrame()

    final_results = []

    for this_domain in mmlu_ds:
        domain_results = []

        for question_ix, q in enumerate(this_domain['questions']):
            input_prompt = tokenizer.apply_chat_template(
                base_prompt + [{'role': 'user', 'content': prep_question(q['question'], q['choices'])}, {'role': 'assistant', 'content': 'The correct answer is'}],
                tokenize = False, add_generation_prompt = False, continue_final_message = True # Otherwise appends eos token
            )
            inputs = tokenizer(input_prompt, return_tensors = 'pt').to(main_device)

            outputs = run_forward_fn(inputs['input_ids'], inputs['attention_mask'], *args, **kwargs)

            topk_df = None
            if return_topk:
                topk_df = convert_topk_to_df(inputs["input_ids"], outputs['all_topk_experts'], outputs['all_topk_weights']).assign(question_ix = question_ix).drop(columns = 'sequence_ix')
                topk_df = topk_df[topk_df['token_id'] != tokenizer.pad_token_id] # Drop rows whose token is the pad token

            predicted_text = tokenizer.decode([torch.argmax(outputs['logits'][0, -1, :]).item()]).strip()
            valid_letters = [chr(65 + i) for i in range(len(q['choices']))]
            # NOTE(review): the first character whose upper-case form is a valid letter wins, so a word token such as
            # 'about' scores as 'A'. Kept unchanged so results stay comparable with `run_model_and_return_topk`.
            predicted_letter = next((c.upper() for c in predicted_text if c.upper() in valid_letters), None)

            domain_results.append({
                'question_ix': question_ix,
                'model_output': predicted_text,
                'model_choice': predicted_letter,
                'correct_choice': q['answer_char'],
                'is_correct': 1 if predicted_letter == q['answer_char'] else 0,
                'is_valid': 1 if predicted_letter is not None else 0,
                'topk_df': topk_df
            })

        n_total = len(domain_results)
        n_correct = sum(1 for x in domain_results if x['is_correct'] == 1)
        n_invalid = sum(1 for x in domain_results if x['is_valid'] == 0)
        accuracy_pct = (n_correct / n_total) * 100 if n_total else float('nan')
        print(f'{this_domain["domain"]} | Correct: {str(n_correct)} | Incorrect: {str(n_total - n_correct)} | Invalid: {str(n_invalid)} | Accuracy: {accuracy_pct:.1f}%')

        final_results.append({
            'domain': this_domain['domain'],
            'question_df': pd.DataFrame([{k: v for k, v in x.items() if k != 'topk_df'} for x in domain_results]).assign(domain = this_domain['domain']),
            'topk_df': _concat_or_empty([x['topk_df'] for x in domain_results]).assign(domain = this_domain['domain']) if return_topk else None,
            'n_correct': n_correct,
            'n_total': n_total,
            'n_invalid': n_invalid,
        })

    total_correct = sum(d['n_correct'] for d in final_results)
    total_questions = sum(d['n_total'] for d in final_results)
    return {
        'question_df': _concat_or_empty([d['question_df'] for d in final_results]),
        'topk_df': _concat_or_empty([d['topk_df'] for d in final_results]) if return_topk else None,
        'n_correct': total_correct,
        'n_total': total_questions,
        'n_invalid': sum(d['n_invalid'] for d in final_results),
        'accuracy': total_correct / total_questions if total_questions else float('nan')
        }

# Sanity check: score the manual (unablated) forward pass, for comparison with `baseline['accuracy']`
evaluate_with_ablation(run_forward_fn = run_model_no_ablation, return_topk = False)

## Ablation testing

In [None]:
@torch.no_grad()
def run_model_with_ablation(input_ids, attention_mask, layers_to_ablate = None, topk_to_ablate = None, renorm = False):
    """
    Manual forward pass of the Qwen2-MoE model that zeroes selected rank-ordered top-k routing weights
    in selected layers before the routed experts are mixed.

    The MoE section re-implements Qwen2MoeSparseMoeBlock (instead of calling layer.mlp) so the routing
    weights can be edited between the top-k selection and the weighted expert sum.

    Params:
        @input_ids: Token ids, shape (B, N).
        @attention_mask: Padding mask, shape (B, N); passed to the model's causal-mask builder.
        @layers_to_ablate: Which layer indices (0-based) to ablate. None (default) = every layer in
            model.model.layers (the earlier hard-coded default was list(range(0, 24))).
        @topk_to_ablate: Ranks to ablate, e.g. [0] = top-1, [2, 3] = 3rd & 4th largest. None (default) = [0].
        @renorm: Whether to rescale the remaining weights so each row keeps its pre-ablation sum.

    Returns:
        dict with
            'logits': (B, N, vocab_size) output of model.lm_head
            'all_topk_experts': one (B*N, top_k) tensor of selected expert ids per layer, on CPU
            'all_topk_weights': one (B*N, top_k) float32 tensor of routing weights per layer, on CPU,
                recorded AFTER ablation/renorm (ablated ranks show as 0)

    Raises:
        IndexError: if a rank in topk_to_ablate is outside [-top_k, top_k) for an ablated layer.
    """
    # None sentinels avoid shared mutable list defaults
    if layers_to_ablate is None:
        layers_to_ablate = range(len(model.model.layers))
    layers_to_ablate = set(layers_to_ablate)
    if topk_to_ablate is None:
        topk_to_ablate = [0]

    input_embeds = model.model.embed_tokens(input_ids)

    cache_position = torch.arange(0, input_embeds.shape[1], device = input_embeds.device)
    position_ids = cache_position.unsqueeze(0)
    # Private transformers helper; positional args presumably (attention_mask, input_tensor, cache_position,
    # past_key_values, output_attentions) — re-check when upgrading transformers.
    causal_mask = model.model._update_causal_mask(attention_mask, input_embeds, cache_position, None, None)

    hidden_state = input_embeds
    position_embeddings = model.model.rotary_emb(hidden_state, position_ids)

    all_topk_experts = []
    all_topk_weights = []
    for layer_ix, layer in enumerate(model.model.layers):
        # Self-attention sub-block (pre-norm, residual add)
        residual = hidden_state
        hidden_state = layer.input_layernorm(hidden_state)
        hidden_state, _, _ = layer.self_attn(hidden_states = hidden_state, attention_mask = causal_mask, position_ids = position_ids, position_embeddings = position_embeddings)
        hidden_state = residual + hidden_state
        residual = hidden_state
        hidden_state = layer.post_attention_layernorm(hidden_state)

        ####### MoE: replaces hidden_state = layer.mlp(hidden_state) (Qwen2MoeSparseMoeBlock)
        batch_size, sequence_length, hidden_dim = hidden_state.shape
        moe_hidden_state = hidden_state.view(-1, hidden_dim)  # Shape (BN, D)
        router_logits = layer.mlp.gate(moe_hidden_state)  # Shape (BN, n_experts)

        # Softmax in float32; sorted top-k so column r holds the rank-r expert
        routing_weights = torch.nn.functional.softmax(router_logits, dim = 1, dtype = torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, layer.mlp.top_k, dim = -1, sorted = True)
        # Mirror the reference block's optional top-k renormalization (no-op when norm_topk_prob is False)
        if getattr(layer.mlp, 'norm_topk_prob', False):
            routing_weights = routing_weights / routing_weights.sum(dim = -1, keepdim = True)
        routing_weights = routing_weights.to(moe_hidden_state.dtype)  # Shape (BN, top_k)

        #### ABLATION
        if layer_ix in layers_to_ablate:
            top_k = routing_weights.shape[-1]
            row_sum_before = routing_weights.sum(dim = -1, keepdim = True)  # Shape (BN, 1)
            # For each rank in topk_to_ablate, zero out that column
            for rank in topk_to_ablate:
                if not -top_k <= rank < top_k:
                    raise IndexError(f'topk_to_ablate rank {rank} is out of range for top_k = {top_k}')
                routing_weights[:, rank] = 0.0
            if renorm:
                row_sum_after = routing_weights.sum(dim = -1, keepdim = True)
                # Epsilon avoids 0/0 when every rank is ablated; such rows stay all-zero
                scale_factor = row_sum_before / (row_sum_after + 1e-9)
                routing_weights *= scale_factor
        ####
        final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim), dtype = moe_hidden_state.dtype, device = moe_hidden_state.device)

        # Shape (n_experts, top_k, BN): expert_mask[e, r, t] == 1 iff token t picked expert e at rank r
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes = layer.mlp.num_experts).permute(2, 1, 0)

        for expert_idx in range(layer.mlp.num_experts):
            rank_idx, token_idx = torch.where(expert_mask[expert_idx])
            if token_idx.numel() == 0:
                continue  # No token routed here; an empty forward would add nothing
            expert_layer = layer.mlp.experts[expert_idx]
            current_state = moe_hidden_state[token_idx]  # Shape (n_routed_tokens, D)
            current_hidden_states = expert_layer(current_state) * routing_weights[token_idx, rank_idx, None]
            final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(moe_hidden_state.dtype))

        # Shared expert, scaled per token by a sigmoid gate
        shared_expert_output = layer.mlp.shared_expert(moe_hidden_state)
        shared_expert_output = torch.nn.functional.sigmoid(layer.mlp.shared_expert_gate(moe_hidden_state)) * shared_expert_output

        final_hidden_states = (final_hidden_states + shared_expert_output).reshape(batch_size, sequence_length, hidden_dim)
        #######
        hidden_state = residual + final_hidden_states

        all_topk_experts.append(selected_experts.detach().cpu())
        all_topk_weights.append(routing_weights.detach().cpu().to(torch.float32))

    hidden_state = model.model.norm(hidden_state)
    logits = model.lm_head(hidden_state)
    return {'logits': logits, 'all_topk_experts': all_topk_experts, 'all_topk_weights': all_topk_weights}
    
# Ablate rank columns 0-3 of the routing weights in layer 6 only, with no renormalization.
# NOTE(review): if top_k is 4 (per config num_experts_per_tok — confirm), this zeroes every routed
# expert in that layer, leaving only the shared expert's contribution.
evaluate_with_ablation(
    run_model_with_ablation,
    return_topk = False,
    layers_to_ablate = [6],
    topk_to_ablate = list(range(4)),
    renorm = False,
)

In [None]:
# Preview of layer indices 0..23
list(range(24))