In [None]:
import torch
from helpers.memory import check_memory
from transformers import AutoTokenizer, AutoModelForCausalLM
from helpers.expert_specialization import get_context_labeled_data, get_js_distance

In [None]:
hf_model_id = 'moonshotai/Moonlight-16B-A3B'

# NOTE: trust_remote_code = True lets the hub repo's own Python files run locally when loading.
# The tokenizer is configured to add no BOS/EOS tokens and to pad on the left.
tokenizer = AutoTokenizer.from_pretrained(
    hf_model_id,
    add_eos_token = False,
    add_bos_token = False,
    padding_side = 'left',
    trust_remote_code = True
)

# Weights are loaded in bfloat16 and the whole model is moved to the current CUDA device.
model = AutoModelForCausalLM.from_pretrained(
    hf_model_id,
    torch_dtype = torch.bfloat16,
    trust_remote_code = True
).cuda()

check_memory()

In [None]:
# Test forward pass

# Hooks needed: https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/main/modeling_deepseek.py
# Two short prompts, padded/truncated to a fixed length of 512 tokens and moved to the model's device.
inputs = tokenizer(
    ['Hi this is dog', 'Where is the beef'],
    return_tensors = 'pt',
    padding = 'max_length',
    truncation = True,
    max_length = 512
).to(model.device)

def attach_moe_gate_hooks(model):
    """
    Registers forward-hooks on each MoE gating module in 'model'
    so that after a forward pass, we can retrieve the 'topk_idx' tensor
    from each layer that uses MoE.

    A layer counts as MoE when `layer.mlp` has a `gate` attribute; all other layers are skipped.
    As a side effect, each hooked gate module gets a `_layer_id` attribute (debugging aid only).

    Params:
        model: Must expose its decoder blocks as `model.model.layers`, each with an `.mlp` attribute.

    Returns:
        all_expert_ids: A list that will be appended to at runtime (one entry per hooked gate call,
                        in call order) with the detached topk_idx Tensor from each gate.
                        Expected shape is [B*N, top_k] per the gate implementation linked in this
                        notebook - not checked here.
        handles:        A dict {layer_index: hook_handle}, so you can remove them if desired.
    """
    all_expert_ids = []
    handles = {}

    def gate_forward_hook(module, inputs, output):
        """
        This hook is triggered after the gate's forward(...).
        'output' must be a tuple/list whose FIRST element is topk_idx; anything after it
        (topk_weight, and in some gate implementations presumably an aux loss) is ignored.
        Indexing instead of unpacking a fixed arity keeps the hook usable for both layouts.
        """
        if not isinstance(output, (tuple, list)):
            raise TypeError(
                f"Expected the MoE gate to return a tuple starting with topk_idx, got {type(output).__name__}"
            )
        topk_idx = output[0]
        all_expert_ids.append(topk_idx.detach())

    # The DeepseekV3ForCausalLM has `self.model` -> DeepseekV3Model -> .layers
    # Each layer has `self.mlp`, which might be DeepseekV3MoE or DeepseekV3MLP
    for layer_idx, layer in enumerate(model.model.layers):
        # Only MoE blocks carry a `gate` submodule
        if hasattr(layer.mlp, 'gate'):
            gate_module = layer.mlp.gate
            gate_module._layer_id = layer_idx  # optional attribute for debugging
            handles[layer_idx] = gate_module.register_forward_hook(gate_forward_hook)

    return all_expert_ids, handles


all_expert_ids, hook_handles = attach_moe_gate_hooks(model)

# Remove the hooks even if the forward pass raises (e.g. OOM); otherwise they stay registered on the
# long-lived `model` object and keep appending GPU tensors to a list nobody reads.
try:
    with torch.no_grad():
        outputs = model(**inputs)
finally:
    for h in hook_handles.values():
        h.remove()

# One entry per hooked MoE gate, in the order the gates were called.
for topk_idx_tensor in all_expert_ids:
    print(f"topk_idx shape = {topk_idx_tensor.shape}")

In [None]:
# Test forward pass

# NOTE: this cell DOES rely on the gate hooks from `attach_moe_gate_hooks`. The earlier note here
# ("No need for hooks!") pointed at the OLMoE modeling file and does not apply to this model:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmoe/modeling_olmoe.py
inputs = tokenizer(['Hi this is dog', 'Where is the beef'], return_tensors = 'pt', padding = 'max_length', truncation = True, max_length = 512).to(model.device)

all_topk_experts, hook_handles = attach_moe_gate_hooks(model)
# Always detach the hooks, even when the forward pass fails, so they do not accumulate on `model`.
try:
    with torch.no_grad():
        outputs = model(**inputs)
finally:
    for h in hook_handles.values():
        h.remove()

# Display the captured topk_idx tensor of the first hooked MoE gate.
all_topk_experts[0]

In [None]:
# Display the loaded model's config for inspection; the metric functions defined in this notebook
# read `n_routed_experts` from it.
model.config

In [None]:
@torch.no_grad()
def get_ics(model, test_token_data_list, use_topk1 = False):
    """
    Modified from `helpers.expert_specialization.get_ics` for this model.

    For every test token, accumulates per-MoE-layer histograms of the expert ids selected at the
    positions holding that token: once over all occurrences (baseline) and once per labeled meaning.
    The reported value for a layer is the mean of `get_js_distance(meaning_hist, baseline_hist)` over meanings.

    Params:
        @model: Model accepted by `attach_moe_gate_hooks` (exposes `model.model.layers`) whose config
            has `n_routed_experts`.
        @test_token_data_list: List of dicts with keys "test_token", "test_token_id", "test_meanings" and "dl".
            Each batch yielded by "dl" must have "input_ids" (B, N), "attention_mask" and "test_meanings"
            (one meaning label per example, each present in the item's "test_meanings").
        @use_topk1: If truthy, only the first column of each gate's topk_idx is counted
            (presumably the top-1 expert - depends on the gate's ordering).

    Returns:
        A dict {test_token: {moe_layer_idx: avg JS distance}}, where moe_layer_idx enumerates the hooked MoE
        layers in order starting at 0. The value is NaN for an item with no meanings.
    """
    results = {}

    # Count the layers that attach_moe_gate_hooks will actually hook, instead of assuming exactly one leading
    # dense layer. For the checkpoint in this notebook this is expected to equal num_hidden_layers - 1.
    n_layers = sum(1 for layer in model.model.layers if hasattr(layer.mlp, 'gate'))
    n_experts = model.config.n_routed_experts

    for test_item in test_token_data_list:
        test_token = test_item["test_token"]
        test_token_id = test_item["test_token_id"]
        test_meanings = test_item["test_meanings"]
        dl = test_item["dl"]
        n_meanings = len(test_meanings)

        # meaning_counts_per_layer[l][meaning_idx]: LongTensor (n_experts,) of expert usage for that meaning
        meaning_counts_per_layer = [
            [torch.zeros(n_experts, dtype = torch.long, device = 'cpu') for _ in range(n_meanings)]
            for _ in range(n_layers)
        ]
        # total_counts_per_layer[l]: LongTensor (n_experts,) of expert usage over all occurrences (baseline)
        total_counts_per_layer = [torch.zeros(n_experts, dtype = torch.long, device = 'cpu') for _ in range(n_layers)]

        # Map each meaning string to an index
        meaning_to_idx = {m: i for i, m in enumerate(test_meanings)}

        for batch in dl:
            input_ids = batch["input_ids"].to(model.device)
            attention_mask = batch["attention_mask"].to(model.device)
            batch_meanings = batch["test_meanings"]
            B, N = input_ids.shape

            all_topk_experts, hook_handles = attach_moe_gate_hooks(model)
            try:
                model(input_ids = input_ids, attention_mask = attention_mask)
            finally:
                # Never leave hooks behind on the shared model, even if the forward pass raises
                for h in hook_handles.values():
                    h.remove()

            if use_topk1:
                all_topk_experts = [x[..., :1] for x in all_topk_experts]

            flat_ids = input_ids.view(-1)  # (B*N,), aligned with the rows of each captured topk_idx

            # Rows holding the test token. These do not depend on the layer, so compute them once per batch.
            base_mask = (flat_ids == test_token_id)
            base_idx = base_mask.nonzero(as_tuple = True)[0]
            if len(base_idx) == 0:
                continue  # token absent from this batch => nothing to count for any layer or meaning

            # Per-token meaning index: every position of example b carries that example's meaning id
            meaning_id_array = torch.tensor([meaning_to_idx[batch_meanings[b_idx]] for b_idx in range(B)], device = model.device)  # (B,)
            meaning_id_array = meaning_id_array.unsqueeze(1).repeat(1, N).view(-1)  # (B*N,)

            # Row indices of the test token restricted to each meaning (also layer-independent)
            meaning_row_idx = [
                (base_mask & (meaning_id_array == m_idx)).nonzero(as_tuple = True)[0]
                for m_idx in range(n_meanings)
            ]

            for l in range(n_layers):
                layer_exps = all_topk_experts[l]  # (B*N, top_k)

                # A) Baseline distribution for the token
                base_exps = layer_exps[base_idx, :].reshape(-1)  # (#rows * top_k,)
                total_counts_per_layer[l] += torch.bincount(base_exps, minlength = n_experts).cpu()

                # B) Per-meaning distribution for the token
                for m_idx, mm_idx in enumerate(meaning_row_idx):
                    if len(mm_idx) > 0:
                        m_exps = layer_exps[mm_idx, :].reshape(-1)
                        meaning_counts_per_layer[l][m_idx] += torch.bincount(m_exps, minlength = n_experts).cpu()

        # Average JS distance per layer: each meaning vs. the overall distribution, then mean over meanings.
        # NOTE(review): only experts with a non-zero count are passed on; if a meaning never occurred its dict is
        # empty - how get_js_distance treats that is not visible here.
        layer_js_distances = []
        for l in range(n_layers):
            dict_base = {ex_id: c_val for ex_id, c_val in enumerate(total_counts_per_layer[l].tolist()) if c_val > 0}

            meaning_dists = []
            for m_idx in range(n_meanings):
                dict_sense = {ex_id: c_val for ex_id, c_val in enumerate(meaning_counts_per_layer[l][m_idx].tolist()) if c_val > 0}
                meaning_dists.append(get_js_distance(dict_sense, dict_base))

            # Guard against an item without meanings instead of raising ZeroDivisionError
            avg_js = sum(meaning_dists) / len(meaning_dists) if meaning_dists else float('nan')
            layer_js_distances.append(avg_js)

        results[test_token] = {layer_idx: val for layer_idx, val in enumerate(layer_js_distances)}

    return results


@torch.no_grad()
def get_tis(model, test_token_data_list, pad_token_id, use_topk1 = False):
    """
    Modified from `helpers.expert_specialization` for this model (the previous docstring named `get_ics`;
    presumably the TIS counterpart there was meant - confirm).

    For every test token, accumulates per-MoE-layer histograms of selected expert ids for (a) all non-pad
    positions and (b) positions holding the test token, then reports `get_js_distance(token_hist, global_hist)`
    for each layer.

    Params:
        @model: Model accepted by `attach_moe_gate_hooks` (exposes `model.model.layers`) whose config
            has `n_routed_experts`.
        @test_token_data_list: List of dicts with keys "test_token", "test_token_id" and "dl".
            Each batch yielded by "dl" must have "input_ids" (B, N) and "attention_mask".
        @pad_token_id: Token id treated as padding. If None, the non-zero entries of the batch's attention
            mask are used to identify non-pad positions instead.
        @use_topk1: If truthy, only the first column of each gate's topk_idx is counted
            (presumably the top-1 expert - depends on the gate's ordering).

    Returns:
        A dict {test_token: {moe_layer_idx: JS distance}}, where moe_layer_idx enumerates the hooked MoE
        layers in order starting at 0.
    """
    results = {}

    # Count the layers that attach_moe_gate_hooks will actually hook, instead of assuming exactly one leading
    # dense layer. For the checkpoint in this notebook this is expected to equal num_hidden_layers - 1.
    n_layers = sum(1 for layer in model.model.layers if hasattr(layer.mlp, 'gate'))
    n_experts = model.config.n_routed_experts

    for token_item in test_token_data_list:
        token_str = token_item["test_token"]
        token_id = token_item["test_token_id"]
        dl = token_item["dl"]

        # Per-layer expert usage over all non-pad tokens, and over occurrences of this token only
        global_counts_per_layer = [torch.zeros(n_experts, dtype = torch.long) for _ in range(n_layers)]
        token_counts_per_layer = [torch.zeros(n_experts, dtype = torch.long) for _ in range(n_layers)]

        for batch in dl:
            input_ids = batch["input_ids"].to(model.device)
            attention_mask = batch["attention_mask"].to(model.device)

            all_topk_experts, hook_handles = attach_moe_gate_hooks(model)
            try:
                model(input_ids = input_ids, attention_mask = attention_mask)
            finally:
                # Never leave hooks behind on the shared model, even if the forward pass raises
                for h in hook_handles.values():
                    h.remove()

            if use_topk1:
                all_topk_experts = [x[..., :1] for x in all_topk_experts]

            flat_ids = input_ids.view(-1)  # shape B*N

            # (A) Non-pad positions
            if pad_token_id is None:
                nonpad_mask = attention_mask.reshape(-1).bool()
            else:
                nonpad_mask = (flat_ids != pad_token_id)
            # (B) Positions holding this specific token
            token_mask = (flat_ids == token_id)

            # Row indices are the same for every layer, so compute them once per batch
            nonpad_indices = nonpad_mask.nonzero(as_tuple = True)[0]
            token_indices = token_mask.nonzero(as_tuple = True)[0]

            # For each layer, accumulate counts via bincount
            for l in range(n_layers):
                layer_exps = all_topk_experts[l]  # (B*N, topk)

                if len(nonpad_indices) > 0:
                    nonpad_rows = layer_exps[nonpad_indices, :].reshape(-1)  # (#nonpad * top_k,)
                    global_counts_per_layer[l] += torch.bincount(nonpad_rows, minlength = n_experts).cpu()

                if len(token_indices) > 0:
                    token_rows = layer_exps[token_indices, :].reshape(-1)  # (#token * top_k,)
                    token_counts_per_layer[l] += torch.bincount(token_rows, minlength = n_experts).cpu()

        # JS distance for each layer, comparing token expert counts vs. global expert counts.
        # NOTE(review): zero-count experts are kept in these dicts, whereas get_ics in this notebook drops them;
        # whether that matters depends on get_js_distance - confirm.
        layer_js_list = []
        for l in range(n_layers):
            dict_global = dict(enumerate(global_counts_per_layer[l].tolist()))
            dict_token = dict(enumerate(token_counts_per_layer[l].tolist()))
            layer_js_list.append(get_js_distance(dict_token, dict_global))

        results[token_str] = {i: val for i, val in enumerate(layer_js_list)}

    return results

@torch.no_grad()
def get_ec(model, test_token_data_list, use_topk1 = False):
    """
    Expert continuity: for each test token and each pair of consecutive MoE layers (l, l+1), the fraction of
    the token's occurrences whose selected experts at layer l share at least one expert id with the experts
    selected at layer l+1.

    Side effect: puts `model` into eval mode.

    Params:
        @model: Model accepted by `attach_moe_gate_hooks` (exposes `model.model.layers`).
        @test_token_data_list: List of dicts with keys "test_token", "test_token_id" and "dl".
            Each batch yielded by "dl" must have "input_ids" (B, N) and "attention_mask".
        @use_topk1: If truthy, only the first column of each gate's topk_idx is compared
            (presumably the top-1 expert - depends on the gate's ordering).

    Returns:
        A dict {test_token: {moe_layer_idx: continuity}}, where moe_layer_idx enumerates the hooked MoE layers
        in order starting at 0. The last MoE layer has no successor and is defined as 0.0. If the token never
        occurred in its dataloader, the continuity is undefined and reported as NaN (previously this raised
        ZeroDivisionError).
    """
    model.eval()
    # Count the layers that attach_moe_gate_hooks will actually hook, instead of assuming exactly one leading
    # dense layer. For the checkpoint in this notebook this is expected to equal num_hidden_layers - 1.
    n_layers = sum(1 for layer in model.model.layers if hasattr(layer.mlp, 'gate'))

    results = {}

    for token_item in test_token_data_list:
        token_str = token_item["test_token"]
        token_id = token_item["test_token_id"]
        dl = token_item["dl"]

        # overlap_count[l] = number of occurrences that keep at least one expert from layer l->l+1
        # total_count[l]   = total number of token occurrences we see for layer l
        overlap_count = [0 for _ in range(n_layers - 1)]
        total_count   = [0 for _ in range(n_layers - 1)]

        for batch in dl:
            input_ids = batch["input_ids"].to(model.device)       # (B,N)
            attention_mask = batch["attention_mask"].to(model.device)

            all_topk_experts, hook_handles = attach_moe_gate_hooks(model)
            try:
                model(input_ids = input_ids, attention_mask = attention_mask)
            finally:
                # Never leave hooks behind on the shared model, even if the forward pass raises
                for h in hook_handles.values():
                    h.remove()

            if use_topk1:
                all_topk_experts = [x[..., :1] for x in all_topk_experts]

            flat_ids = input_ids.view(-1)

            # Row indices (into the flattened B*N axis) where the test token occurs
            valid_indices = (flat_ids == token_id).nonzero(as_tuple = True)[0]
            n_valid = valid_indices.numel()
            if n_valid == 0:
                continue  # nothing to add to either counter for this batch

            # For each layer up to n_layers - 2, check overlap with layer + 1
            for l in range(n_layers - 1):
                exps_l_valid    = all_topk_experts[l][valid_indices, :]      # (#valid, top_k)
                exps_next_valid = all_topk_experts[l + 1][valid_indices, :]  # (#valid, top_k)

                total_count[l] += n_valid

                # Broadcast compare => (#valid, top_k, top_k); a row overlaps if ANY pair matches.
                # For top_k == 1 this reduces to a plain equality check, so no special case is needed.
                # flatten + any(dim = 1) avoids passing a tuple of dims to `any`, which older PyTorch rejects.
                eq_matrix = exps_l_valid.unsqueeze(2) == exps_next_valid.unsqueeze(1)
                overlap_bool = eq_matrix.flatten(1).any(dim = 1)  # (#valid,)
                overlap_count[l] += overlap_bool.sum().item()

        continuity_dict = {}
        for l in range(n_layers - 1):
            if total_count[l] > 0:
                continuity_dict[l] = overlap_count[l] / total_count[l]
            else:
                continuity_dict[l] = float('nan')  # token never observed => continuity undefined

        continuity_dict[n_layers-1] = 0.0 # define last layer's continuity=0.0

        results[token_str] = continuity_dict
    return results

In [None]:
# Load the context-labeled test data from the YAML files matching the glob below, using the tokenizer above.
# NOTE(review): 512 and 8 are presumably the max sequence length and the batch size - confirm against
# `get_context_labeled_data` in helpers.expert_specialization.
context_aware_test_dataset = get_context_labeled_data("./../../data/contextual-tokens/samples_*.yaml", tokenizer, 512, 8)

In [None]:
# Score the first 20 test items with each metric, counting every column of the gates' topk_idx
# (use_topk1 is left at its default of False).
ics = get_ics(
    model = model,
    test_token_data_list = context_aware_test_dataset[0:20]
)
tis = get_tis(
    model = model,
    test_token_data_list = context_aware_test_dataset[0:20],
    pad_token_id = tokenizer.pad_token_id
)
ec = get_ec(
    model = model,
    test_token_data_list = context_aware_test_dataset[0:20]
)

In [None]:
# Same first 20 test items, but only the first column of each gate's topk_idx is counted.
ics1 = get_ics(
    model = model,
    test_token_data_list = context_aware_test_dataset[0:20],
    use_topk1 = True
)
tis1 = get_tis(
    model = model,
    test_token_data_list = context_aware_test_dataset[0:20],
    pad_token_id = tokenizer.pad_token_id,
    use_topk1 = True
)

In [None]:
import pandas as pd 
import plotly.express as px 

def plot_token_layer_metrics(input_data, title, yaxis_title, color, show_token_lines):
    """
    Draws a plotly line chart of a per-layer metric: one bold line for the across-token mean and,
    optionally, one faint line per token. The y-axis is fixed to [0, 1] with a dashed guide at 0.5.

    Params:
        @input_data: A dict {token: {layer_idx: value}}.
        @title: Figure title.
        @yaxis_title: Label for the y-axis.
        @color: Line color applied to every trace.
        @show_token_lines: When equal to True, per-token lines are drawn in addition to the average.
    """
    # Flatten the nested dict into one record per (token, layer)
    records = []
    for token, layer_dict in input_data.items():
        for layer_idx, val in layer_dict.items():
            records.append({"token": token, "layer": layer_idx, "avg_value": val})
    per_token_df = pd.DataFrame(records)

    # Mean across tokens for every layer, tagged with a sentinel token name
    mean_df = per_token_df.groupby("layer", as_index = False)["avg_value"].mean()
    mean_df = mean_df.assign(token = "__AVG__")

    plot_df = pd.concat([per_token_df, mean_df], ignore_index = True) if show_token_lines == True else mean_df

    fig = px.line(plot_df, x = "layer", y = "avg_value", color = "token", title = title, range_y = (0,1), markers = True)

    # Emphasize the average trace; fade the per-token traces
    for trace in fig.data:
        if trace.name != "__AVG__":
            trace.update(line = dict(color = color, width = 1), opacity = 0.2, marker = dict(size = 1))
            continue
        trace.update(line = dict(color = color, width = 2), opacity = 1.0, marker = dict(size = 6), name = "Average")

    # Horizontal dashed guide at y = 0.5 spanning the full plot width
    fig.add_shape(
        type = "line",
        x0 = 0, x1 = 1, xref = "paper",
        y0 = 0.5, y1 = 0.5, yref = "y",
        line = dict(color = "gray", dash = "dash")
    )
    fig.update_layout(
        width = 600,
        height = 400,
        xaxis_title = "Layer",
        yaxis_title = yaxis_title,
        showlegend = False
    )
    fig.show()

# (data, title, y-axis label, color) for each figure, drawn in this order with per-token lines shown
metric_plots = [
    (ics,  'ICS by layer',              'ICS(l)', 'firebrick'),
    (tis,  'TIS by layer',              'TIS(l)', 'cornflowerblue'),
    (ec,   'EC by layer',               'EC(l)',  'forestgreen'),
    (ics1, 'ICS by layer for Topk = 1', 'ICS(l)', 'firebrick'),
    (tis1, 'TIS by layer for Topk = 1', 'TIS(l)', 'cornflowerblue'),
]

for plot_data, plot_title, plot_ylabel, plot_color in metric_plots:
    plot_token_layer_metrics(plot_data, plot_title, plot_ylabel, color = plot_color, show_token_lines = True)