In [None]:
import torch
from helpers.memory import check_memory
from transformers import AutoTokenizer, AutoModelForCausalLM
from collections import defaultdict
from helpers.expert_specialization import get_context_aware_test_data, get_js_distance

In [None]:
hf_model_id = 'deepseek-ai/DeepSeek-V2-Lite'

# Left padding with no automatic BOS/EOS tokens, so each test sequence is exactly what the tokenizer produces from the text
tokenizer = AutoTokenizer.from_pretrained(
    hf_model_id,
    add_eos_token = False,
    add_bos_token = False,
    padding_side = 'left',
)

# bf16 weights, loaded with the repository's remote modeling code, then moved to the GPU
model = AutoModelForCausalLM.from_pretrained(
    hf_model_id,
    torch_dtype = torch.bfloat16,
    trust_remote_code = True,
).cuda()

check_memory()

In [None]:
# Smoke-test a plain forward pass before any instrumentation is attached.
# NOTE(review): the OLMoE modeling code returns router logits directly
# (https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmoe/modeling_olmoe.py);
# this notebook loads DeepSeek-V2-Lite via remote code and captures routing with gate hooks in the next step.
inputs = tokenizer(
    ['Hi this is dog', 'Where is the beef'],
    return_tensors = 'pt',
    padding = 'max_length',
    truncation = True,
    max_length = 512,
).to(model.device)

with torch.no_grad():
    outputs = model(**inputs)

def attach_moe_gate_hooks(model):
    """
    Registers forward-hooks on each MoE gating module in 'model' so that after a forward pass,
    we can retrieve the top-k expert IDs ('topk_idx') chosen by each layer's gate.

    A decoder layer counts as MoE when its `mlp` has a `gate` submodule. Layers without one
    (e.g. a dense first layer) are skipped instead of being assumed away by index.
    NOTE(review): this replaces the former hard-coded "skip layer 0" rule; confirm that dense
    MLPs in the target model expose no `gate` attribute.

    Params:
        @model: A causal LM whose decoder layers live at `model.model.layers`.

    Returns:
        all_expert_ids: A list that will be filled at runtime with tuples of
                        (layer_index, topk_idx_tensor), one per gate call. The tensor is the gate's
                        first output, detached and moved to CPU (expected BN x topk — TODO confirm).
        handles: A dictionary of {layer_index: hook_handle}; call `.remove()` on each to detach.
    """
    all_expert_ids = []
    handles = {}

    def gate_forward_hook(module, inputs, output):
        """
        Triggered after the gate's forward(...). The gate is expected to return a tuple whose first
        element is topk_idx (e.g. (topk_idx, topk_weight, aux_loss)); only topk_idx is kept.
        """
        topk_idx = output[0] if isinstance(output, (tuple, list)) else output
        layer_id = getattr(module, "_layer_id", None)
        all_expert_ids.append((layer_id, topk_idx.detach().cpu()))

    for layer_idx, layer in enumerate(model.model.layers):
        gate = getattr(layer.mlp, "gate", None)
        if gate is None:
            # Dense (non-MoE) layer: nothing to hook
            continue
        # Tag the gate so the shared hook knows which layer produced each result
        gate._layer_id = layer_idx
        handles[layer_idx] = gate.register_forward_hook(gate_forward_hook)

    return all_expert_ids, handles

all_expert_ids, hook_handles = attach_moe_gate_hooks(model)

# Run the forward pass with the hooks attached, and always detach them afterwards, even if the pass
# raises. Otherwise later forward passes (including the metric cells below) keep appending to
# `all_expert_ids`, and re-running this cell stacks duplicate hooks on every gate.
try:
    with torch.no_grad():
        outputs = model(**inputs)
finally:
    for hook_handle in hook_handles.values():
        hook_handle.remove()

# One entry per hooked gate call: (layer index, top-k expert IDs), expected shape [B*N, top_k]
for layer_idx, topk_idx_tensor in all_expert_ids:
    print(f"Layer {layer_idx}: topk_idx shape = {topk_idx_tensor.shape}")

In [None]:
# Inspect the routing recorded for layer 2.
# NOTE(review): nothing in this notebook sets `mlp.last_topk_idx`, so unless the remote modeling code does,
# plain attribute access raises AttributeError on a fresh kernel; fall back to the gate-hook results instead.
layer2_mlp = model.model.layers[2].mlp
layer2_mlp.last_topk_idx if hasattr(layer2_mlp, 'last_topk_idx') else [topk for (layer_id, topk) in all_expert_ids if layer_id == 2]

In [None]:
def _router_topk_experts(model, input_ids, attention_mask):
    """
    Run one forward pass and return the top-k expert IDs selected at every routed layer.

    Params:
        @model: A model whose forward accepts `output_router_logits = True` and returns `router_logits`, a tuple with
          one router-logit tensor per routed layer. NOTE(review): OLMoE's modeling code does this; confirm the
          DeepSeek-V2 remote code loaded in this notebook supports it too.
        @input_ids, @attention_mask: B x N tensors, already on `model.device`.

    Returns:
        A tuple with one (B*N) x topk tensor of expert IDs per routed layer, rows in `input_ids.reshape(-1)` order.

    Raises:
        ValueError: if the output has no router logits, or a layer's logits do not have one row per input token.
    """
    outputs = model(input_ids = input_ids, attention_mask = attention_mask, output_router_logits = True)
    router_logits = getattr(outputs, "router_logits", None)
    if router_logits is None:
        raise ValueError("Model output has no `router_logits`; cannot derive per-layer expert assignments")

    n_tokens = input_ids.numel()
    k = model.config.num_experts_per_tok
    topk_per_layer = []
    for layer_router_logits in router_logits:
        # Accept either [B*N, num_experts] or [B, N, num_experts]
        flat_logits = layer_router_logits.reshape(-1, layer_router_logits.shape[-1])
        if flat_logits.shape[0] != n_tokens:
            raise ValueError(
                f"Router logits have {flat_logits.shape[0]} rows but the batch has {n_tokens} tokens; "
                "cannot align experts with input_ids"
            )
        # Softmax in float32 so bf16 rounding cannot create ties that the raw logits did not have
        gating_probs = torch.softmax(flat_logits.float(), dim = -1)
        topk_per_layer.append(torch.topk(gating_probs, k = k, dim = -1).indices)
    return tuple(topk_per_layer)


def _add_expert_counts(counter, expert_ids):
    """Add the number of occurrences of each expert ID in `expert_ids` (any shape) to `counter`."""
    if expert_ids.numel() == 0:
        return
    ids, counts = torch.unique(expert_ids.reshape(-1), return_counts = True)
    for ex_id, cnt in zip(ids.tolist(), counts.tolist()):
        counter[ex_id] += cnt


@torch.no_grad()
def get_context_awareness_metric(model, test_token_data_list):
    """
    Compute the CA metric for each test token x routed layer. For each meaning, take the JS distance between that
      meaning's expert distribution and the token's overall distribution, then average over meanings.

    Params:
        @model: A model whose forward accepts `output_router_logits = True` and returns `router_logits` (one tensor per
          routed layer). Top-k expert IDs are taken from these with k = `model.config.num_experts_per_tok`.
        @test_token_data_list: A list of dictionaries of the exact format returned by `get_context_aware_test_data`.
          Each needs "test_token", "test_token_id", "test_meanings" and "dl". The batches of "dl" carry "input_ids",
          "attention_mask" and "test_meanings" (one meaning label per row).

    Returns:
        A dict of format:
            {
                test_token1: {0: .52, 1: .34, ...},
                test_token2: {0: .55, 1: .62, ...},
                ...
            },
          where the keys index the routed layers in `router_logits` order. The values are the mean JS distances
          from `get_js_distance`, presumably between 0 - 1 (depends on that helper).
          Occurrences at padded positions (attention_mask == 0) are ignored. A meaning never observed at a layer is
          left out of that layer's average, and a layer with no observations at all scores 0.0.
    """
    results = {}
    model.eval()

    for test_item in test_token_data_list:
        test_token = test_item["test_token"]
        test_token_id = test_item["test_token_id"]
        test_meanings = test_item["test_meanings"]
        meaning_to_idx = {m: i for i, m in enumerate(test_meanings)}

        # meaning_counts[l][meaning_idx][expert_id] -> count for that meaning
        meaning_counts = defaultdict(lambda: [defaultdict(int) for _ in test_meanings])
        # total_counts[l][expert_id] -> count pooled over all meanings (the baseline)
        total_counts = defaultdict(lambda: defaultdict(int))
        # Number of routed layers, taken from the model outputs rather than num_hidden_layers
        n_routed_layers = 0

        for batch in test_item["dl"]:
            input_ids = batch["input_ids"].to(model.device)
            attention_mask = batch["attention_mask"].to(model.device)
            batch_meanings = batch["test_meanings"]
            seq_len = input_ids.shape[1]

            all_topk_experts = _router_topk_experts(model, input_ids, attention_mask)
            n_routed_layers = max(n_routed_layers, len(all_topk_experts))

            # Flat positions of real (non-padding) occurrences of the test token
            is_hit = (input_ids.reshape(-1) == test_token_id) & attention_mask.reshape(-1).bool()
            hit_positions = is_hit.nonzero(as_tuple = True)[0]
            if hit_positions.numel() == 0:
                continue

            # Row b of the batch owns flat positions [b*N, (b+1)*N)
            hit_rows = (hit_positions // seq_len).tolist()
            hit_meaning_idx = torch.tensor([meaning_to_idx[batch_meanings[r]] for r in hit_rows])

            for l, layer_experts in enumerate(all_topk_experts):
                # (num_hits, topk) expert IDs, moved to CPU once per layer instead of one sync per token
                hit_experts = layer_experts[hit_positions.to(layer_experts.device)].cpu()
                _add_expert_counts(total_counts[l], hit_experts)
                for meaning_idx in hit_meaning_idx.unique().tolist():
                    _add_expert_counts(meaning_counts[l][meaning_idx], hit_experts[hit_meaning_idx == meaning_idx])

        layer_scores = {}
        for l in range(n_routed_layers):
            # An all-zero count has no distribution to compare, so unobserved meanings are skipped
            sense_dists = [
                get_js_distance(meaning_counts[l][s_idx], total_counts[l])
                for s_idx in range(len(test_meanings))
                if meaning_counts[l][s_idx]
            ]
            layer_scores[l] = sum(sense_dists) / len(sense_dists) if sense_dists else 0.0

        results[test_token] = layer_scores

    return results


@torch.no_grad()
def get_token_specialization_metric(model, test_token_data_list, pad_token_id):
    """
    Computes a token specialization metric for each test token x routed layer. It is the JS distance between
      (a) the distribution of experts used for that token and (b) the distribution of experts used for *all*
      non-padding tokens.

    Params:
        @model: A model whose forward accepts `output_router_logits = True` and returns `router_logits` (one tensor per
          routed layer). Top-k expert IDs are taken from these with k = `model.config.num_experts_per_tok`.
        @test_token_data_list: A list of dictionaries of the exact format returned by `get_context_aware_test_data`.
          Each needs "test_token", "test_token_id" and "dl", whose batches carry "input_ids" and "attention_mask".
        @pad_token_id: Token ID treated as padding and excluded from the "global usage" distribution. Positions with
          attention_mask == 0 are always excluded as well, so this may be None.

    Returns:
        A dict of format:
            {
                test_token1: {0: .52, 1: .34, ...},
                test_token2: {0: .55, 1: .62, ...},
                ...
            },
          where the keys index the routed layers in `router_logits` order and the values are the JS distances from
          `get_js_distance`, presumably between 0 - 1 (depends on that helper).
    """
    model.eval()
    results = {}

    for token_item in test_token_data_list:
        token_str = token_item["test_token"]
        token_id = token_item["test_token_id"]

        # token_expert_counts[l][expert_id] = # of times `token_id` is assigned to expert_id
        token_expert_counts = defaultdict(lambda: defaultdict(int))
        # global_expert_counts[l][expert_id] = # of times ANY non-pad token is assigned to expert_id
        global_expert_counts = defaultdict(lambda: defaultdict(int))
        n_routed_layers = 0

        for batch in token_item["dl"]:
            input_ids = batch["input_ids"].to(model.device)
            attention_mask = batch["attention_mask"].to(model.device)

            all_topk_experts = _router_topk_experts(model, input_ids, attention_mask)
            n_routed_layers = max(n_routed_layers, len(all_topk_experts))

            flat_ids = input_ids.reshape(-1)
            is_real = attention_mask.reshape(-1).bool()
            if pad_token_id is not None:
                # Not in-place: `.bool()`/`.to()` may return the caller's tensor itself
                is_real = is_real & (flat_ids != pad_token_id)
            is_target = is_real & (flat_ids == token_id)

            for l, layer_experts in enumerate(all_topk_experts):
                _add_expert_counts(global_expert_counts[l], layer_experts[is_real.to(layer_experts.device)])
                _add_expert_counts(token_expert_counts[l], layer_experts[is_target.to(layer_experts.device)])

        results[token_str] = {
            l: get_js_distance(token_expert_counts[l], global_expert_counts[l])
            for l in range(n_routed_layers)
        }

    return results

In [None]:
# Build the per-token test items (one dict per test token) from the YAML sample files
context_aware_test_dataset = get_context_aware_test_data(
    "./../../data/contextual-tokens/samples_*.yaml",
    tokenizer,
    512,  # NOTE(review): presumably the max sequence length — confirm against the helper's signature
    128,  # NOTE(review): presumably the batch size — confirm against the helper's signature
)

In [None]:
# Both metrics are computed on the same subset: the first 9 test tokens
eval_token_items = context_aware_test_dataset[0:9]
context_awareness = get_context_awareness_metric(model, eval_token_items)
token_specialization = get_token_specialization_metric(model, eval_token_items, tokenizer.pad_token_id)

In [None]:
import pandas as pd 
import plotly.express as px 

def _plot_layer_average(per_token_scores, value_col, title, y_label):
    """
    Average per-token, per-layer scores across tokens and show them as a line chart with one point per layer.

    Params:
        @per_token_scores: {token: {layer_idx: score, ...}, ...}. Tokens may cover different layer sets; each layer
          is averaged over the tokens that have a score for it.
        @value_col: Column name for the averaged score in the plotted frame.
        @title: Figure title.
        @y_label: Y-axis title.

    Raises:
        ValueError: if `per_token_scores` is empty.
    """
    if not per_token_scores:
        raise ValueError("No per-token scores to plot")

    # Outer keys (tokens) become columns and inner keys (layers) the index; the row mean skips missing layers
    scores_by_layer = pd.DataFrame(per_token_scores).astype(float)
    layer_means = scores_by_layer.mean(axis = 1).sort_index()
    df = pd.DataFrame({"layer": layer_means.index, value_col: layer_means.to_numpy()})

    fig = px.line(df, x = 'layer', y = value_col, title = title, range_y = (0, 1), markers = True)

    fig.update_layout(xaxis_title = 'Layer', yaxis_title = y_label, yaxis = dict(tickformat = '.2f'), width = 600, height = 400)
    # Reference line at 0.5 on the [0, 1] score scale
    fig.add_hline(y = 0.5, line_dash = 'dash', line_color = 'red')
    fig.show()


def plot_ca(ca_data):
    """Plot the context-awareness score averaged across tokens for each layer ({token: {layer: score}} input)."""
    _plot_layer_average(ca_data, "avg_ca", "Average Context-Awareness by Layer", 'Average CA(l)')


def plot_ts(ts_data):
    """Plot the token-specialization score averaged across tokens for each layer ({token: {layer: score}} input)."""
    _plot_layer_average(ts_data, "avg_ts", "Average Token Specialization by Layer", 'Average TS(l)')


# Render both per-layer summaries: context-awareness first, then token specialization
for plot_fn, per_token_scores in ((plot_ca, context_awareness), (plot_ts, token_specialization)):
    plot_fn(per_token_scores)