In [1]:
import torch
from helpers.memory import check_memory
from transformers import AutoTokenizer, AutoModelForCausalLM
from helpers.expert_specialization import get_context_aware_test_data, get_js_distance

In [2]:
# Hugging Face Hub checkpoint used by every later cell.
hf_model_id = 'Qwen/Qwen1.5-MoE-A2.7B-Chat'

# Left-padding tokenizer; the add_bos_token / add_eos_token flags are passed as False.
tokenizer = AutoTokenizer.from_pretrained(
    hf_model_id,
    add_eos_token = False,
    add_bos_token = False,
    padding_side = 'left',
)

# Load the weights in bfloat16 and move the whole model onto the current CUDA device.
model = AutoModelForCausalLM.from_pretrained(
    hf_model_id,
    torch_dtype = torch.bfloat16,
    trust_remote_code = True,
).cuda()

# Project helper from helpers.memory; presumably reports memory usage after the load - confirm there.
check_memory()

In [None]:
# Smoke-test a single forward pass with router logits enabled.
#
# Qwen2 MoE hands back the per-layer router logits when output_router_logits is set, so no forward hooks are needed:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
inputs = tokenizer(
    ['Hi this is dog', 'Where is the beef'],
    return_tensors = 'pt',
    padding = 'max_length',
    truncation = True,
    max_length = 512,
).to(model.device)

with torch.no_grad():
    outputs = model(**inputs, output_router_logits = True)

# One tensor of selected expert IDs per entry of outputs.router_logits (each entry is [B*N, num_experts]).
all_topk_experts = tuple(
    torch.topk(
        torch.softmax(layer_router_logits, dim = -1),
        k = model.config.num_experts_per_tok,
        dim = -1,
    ).indices
    for layer_router_logits in outputs.router_logits
)


In [4]:
@torch.no_grad()
def get_context_awareness_metric(model, test_token_data_list):
    """
    Compute the CA metric for each test token x layer, using JS distance to compare each meaning-specific expert distribution vs. 
     the overall distribution for that token.
    
    Params:
        @model: The model. It is called as `model(input_ids, attention_mask = ..., output_router_logits = True)` and must return
          an object whose `router_logits` holds one (B*N) x n_experts tensor per hidden layer. The top-k expert IDs of every
          token are derived from those logits inside this function.
        @test_token_data_list: A list of dictionaries of the exact format returned by `get_context_aware_test_data`. The keys
          read here are `test_token`, `test_token_id`, `test_meanings` and `dl`; every batch yielded by `dl` must provide
          `input_ids`, `attention_mask` and `test_meanings` (one meaning label per example).

    Returns:
        A dict of format:
            {
                test_token1: {0: .52, 1: .34, ...},
                test_token2: {0: .55, 1: .62, ...},
                ...
            },
          where the keys represent layer indices and the values represent context-awareness scores, i.e. the mean over meanings
          of `get_js_distance(meaning_counts, overall_counts)` (expected to lie between 0 - 1).

    Raises:
        ValueError: If a test item lists no meanings. The score would be an average over zero distances, so this is reported
          before any forward pass is spent on the item.
    """
    results = {}

    n_layers = model.config.num_hidden_layers
    n_experts = model.config.num_experts
    top_k = model.config.num_experts_per_tok

    def _sparse_counts(counts):
        # get_js_distance is fed {expert_id: count} dicts that only list experts which were actually selected
        return {ex_id: c_val for ex_id, c_val in enumerate(counts.tolist()) if c_val > 0}

    for test_item in test_token_data_list:
        test_token = test_item["test_token"]
        test_token_id = test_item["test_token_id"]
        test_meanings = test_item["test_meanings"]
        dl = test_item["dl"]
        n_meanings = len(test_meanings)

        if n_meanings == 0:
            raise ValueError(f"Test token {test_token!r} has no meanings; its context-awareness score is undefined.")

        # meaning_counts_per_layer[l][meaning_idx]: torch.LongTensor of shape (n_experts,), per-expert usage count
        meaning_counts_per_layer = [
            [torch.zeros(n_experts, dtype = torch.long, device = 'cpu') for _ in range(n_meanings)]
            for _ in range(n_layers)
        ]
        # total_counts_per_layer[l]: same shape (n_experts,), per-expert usage count for the baseline (all meanings pooled)
        total_counts_per_layer = [torch.zeros(n_experts, dtype = torch.long, device = 'cpu') for _ in range(n_layers)]

        # Map each meaning string to an index
        meaning_to_idx = {m: i for i, m in enumerate(test_meanings)}

        for batch in dl:
            input_ids = batch["input_ids"].to(model.device)
            attention_mask = batch["attention_mask"].to(model.device)
            batch_meanings = batch["test_meanings"]
            B, N = input_ids.shape

            # Flatten to (B*N,) so positions line up with the rows of each layer's router logits
            flat_ids = input_ids.reshape(-1)

            # Meaning index of every example, repeated N times so that each token position carries its example's meaning
            meaning_ids = torch.tensor([meaning_to_idx[batch_meanings[b_idx]] for b_idx in range(B)], device = model.device)
            flat_meaning_ids = meaning_ids.repeat_interleave(N)

            # Rows holding the test token. These (and the per-meaning split below) do not depend on the layer,
            # so they are computed once per batch instead of once per layer.
            token_rows = (flat_ids == test_token_id).nonzero(as_tuple = True)[0]
            if token_rows.numel() == 0:
                continue  # nothing to count in this batch, so the forward pass is skipped as well
            token_meanings = flat_meaning_ids[token_rows]
            rows_per_meaning = [token_rows[token_meanings == m_idx] for m_idx in range(n_meanings)]

            outputs = model(input_ids, attention_mask = attention_mask, output_router_logits = True)

            for layer_idx in range(n_layers):
                # router logits are (B*N, n_experts); keep only the IDs of the top-k experts => (B*N, top_k)
                _, layer_exps = torch.topk(outputs.router_logits[layer_idx], k = top_k, dim = -1)

                # A) Baseline distribution for the token: every row where the token occurs
                hist_base = torch.bincount(layer_exps[token_rows].reshape(-1), minlength = n_experts)
                total_counts_per_layer[layer_idx] += hist_base.cpu()

                # B) Per-meaning distribution: rows where the token occurs inside an example labelled with that meaning
                for m_idx, meaning_rows in enumerate(rows_per_meaning):
                    if meaning_rows.numel() > 0:
                        hist_m = torch.bincount(layer_exps[meaning_rows].reshape(-1), minlength = n_experts)
                        meaning_counts_per_layer[layer_idx][m_idx] += hist_m.cpu()

        # Average, per layer, the JS distance between each meaning's distribution and the overall distribution
        layer_scores = {}
        for layer_idx in range(n_layers):
            dict_base = _sparse_counts(total_counts_per_layer[layer_idx])
            meaning_dists = [
                get_js_distance(_sparse_counts(meaning_counts_per_layer[layer_idx][m_idx]), dict_base)
                for m_idx in range(n_meanings)
            ]
            layer_scores[layer_idx] = sum(meaning_dists) / n_meanings

        results[test_token] = layer_scores

    return results


@torch.no_grad()
def get_token_specialization_metric(model, test_token_data_list, pad_token_id):
    """
    Computes a token specialization metric for each test token x layer, using the JS distance b/t: (a) the distribution of 
      experts used for that token and (b) the distribution of experts used for *all* tokens (excluding padding).
    
    Params:
        @model: The model. It is called as `model(input_ids, attention_mask = ..., output_router_logits = True)` and must return
          an object whose `router_logits` holds one (B*N) x n_experts tensor per hidden layer. The top-k expert IDs of every
          token are derived from those logits inside this function.
        @test_token_data_list: A list of dictionaries of the exact format returned by `get_context_aware_test_data`. The keys
          read here are `test_token`, `test_token_id` and `dl`; every batch yielded by `dl` must provide `input_ids` and
          `attention_mask`.
        @pad_token_id: The ID used for padding, which we should exclude from the "global usage" distribution. If it is None
          (tokenizer without a pad token), positions with `attention_mask == 0` are excluded instead.

    Returns:
        A dict of format:
            {
                test_token1: {0: .52, 1: .34, ...},
                test_token2: {0: .55, 1: .62, ...},
                ...
            },
          where the keys represent layer indices and the values represent token-specialization scores as returned by
          `get_js_distance(token_counts, global_counts)` (expected to lie between 0 - 1)
    """
    results = {}

    n_layers = model.config.num_hidden_layers
    n_experts = model.config.num_experts
    top_k = model.config.num_experts_per_tok

    for token_item in test_token_data_list:
        token_str = token_item["test_token"]
        token_id = token_item["test_token_id"]
        dl = token_item["dl"]

        # Per layer, per-expert usage counts over (a) all non-pad tokens and (b) occurrences of this token only
        global_counts_per_layer = [torch.zeros(n_experts, dtype = torch.long) for _ in range(n_layers)]
        token_counts_per_layer = [torch.zeros(n_experts, dtype = torch.long) for _ in range(n_layers)]

        for batch in dl:
            input_ids = batch["input_ids"].to(model.device)
            attention_mask = batch["attention_mask"].to(model.device)

            outputs = model(input_ids, attention_mask = attention_mask, output_router_logits = True)

            flat_ids = input_ids.reshape(-1)  # shape B*N, aligned with the rows of each layer's router logits

            # (A) Non-pad positions. Comparing a tensor against None does not produce a mask, so fall back to the attention mask.
            if pad_token_id is None:
                nonpad_mask = attention_mask.reshape(-1).bool()
            else:
                nonpad_mask = (flat_ids != pad_token_id)
            # (B) Positions of this specific token
            token_mask = (flat_ids == token_id)

            # Row indices are identical for every layer, so compute them once per batch
            nonpad_indices = nonpad_mask.nonzero(as_tuple = True)[0]
            token_indices = token_mask.nonzero(as_tuple = True)[0]

            for layer_idx in range(n_layers):
                # Softmax is monotonic, so top-k over the raw logits selects the same experts; skipping it matches the
                # context-awareness metric and avoids ties introduced by rounding probabilities in low precision.
                _, layer_exps = torch.topk(outputs.router_logits[layer_idx], k = top_k, dim = -1)  # (B*N, topk)

                # A) Global usage: only the rows of non-pad tokens
                if nonpad_indices.numel() > 0:
                    hist_global = torch.bincount(layer_exps[nonpad_indices].reshape(-1), minlength = n_experts)
                    global_counts_per_layer[layer_idx] += hist_global.cpu()

                # B) Token usage: only the rows where the test token occurs
                if token_indices.numel() > 0:
                    hist_token = torch.bincount(layer_exps[token_indices].reshape(-1), minlength = n_experts)
                    token_counts_per_layer[layer_idx] += hist_token.cpu()

        # JS distance per layer between the token's expert counts and the global expert counts.
        # Both dicts list every expert ID, including those with a count of 0.
        layer_scores = {}
        for layer_idx in range(n_layers):
            dict_global = dict(enumerate(global_counts_per_layer[layer_idx].tolist()))
            dict_token = dict(enumerate(token_counts_per_layer[layer_idx].tolist()))
            layer_scores[layer_idx] = get_js_distance(dict_token, dict_global)

        results[token_str] = layer_scores

    return results

In [5]:
# Build the evaluation items from every YAML file matching the glob below.
# NOTE(review): 512 and 16 are passed positionally - presumably max sequence length and batch size; confirm against
# get_context_aware_test_data's signature.
context_aware_test_dataset = get_context_aware_test_data(
    "./../../data/contextual-tokens/samples_*.yaml",
    tokenizer,
    512,
    16,
)

In [6]:
# Both metrics are computed on the same slice: the first 20 entries of the test dataset.
context_awareness = get_context_awareness_metric(
    model,
    context_aware_test_dataset[0:20],
)
token_specialization = get_token_specialization_metric(
    model,
    context_aware_test_dataset[0:20],
    tokenizer.pad_token_id,
)

In [7]:
import pandas as pd 
import plotly.express as px 

def _plot_layer_average(metric_data, value_col, title, yaxis_title):
    """
    Shared implementation of `plot_ca` / `plot_ts`: average a per-token, per-layer metric over tokens and draw one line.

    Params:
        @metric_data: A dict of format {token: {layer_idx: value, ...}, ...}. The number of layers is taken from the first
          entry, and every token must provide a value for each of those layer indices.
        @value_col: Name of the dataframe column holding the averaged values (also the y column handed to plotly).
        @title: Figure title.
        @yaxis_title: Label of the y axis.

    Raises:
        ValueError: If `metric_data` is empty, since there is no layer count and nothing to average.
    """
    if not metric_data:
        raise ValueError("Cannot plot an empty metric dict: expected {token: {layer_idx: value}}.")

    n_layers = len(next(iter(metric_data.values())))

    # Create a list of (layer, average over tokens) pairs
    layer_vals = []
    for layer_idx in range(n_layers):
        per_token_vals = [layer_dict[layer_idx] for layer_dict in metric_data.values()]
        layer_vals.append((layer_idx, sum(per_token_vals) / len(per_token_vals)))

    df = pd.DataFrame(layer_vals, columns = ["layer", value_col])

    fig = px.line(df, x = 'layer', y = value_col, title = title, range_y = (0, 1), markers = True)

    fig.update_layout(xaxis_title = 'Layer', yaxis_title = yaxis_title, yaxis = dict(tickformat = '.2f'), width = 600, height = 400)
    fig.add_hline(y = 0.5, line_dash = 'dash', line_color = 'red')  # dashed reference line at 0.5
    fig.show()


def plot_ca(ca_data):
    """
    Plot the context-awareness score averaged over all test tokens, one point per layer.

    Params:
        @ca_data: A dict of format {token: {layer_idx: score, ...}, ...}, as returned by `get_context_awareness_metric`.

    Raises:
        ValueError: If `ca_data` is empty.
    """
    _plot_layer_average(ca_data, "avg_ca", "Average Context-Awareness by Layer", 'Average CA(l)')


def plot_ts(ts_data):
    """
    Plot the token-specialization score averaged over all test tokens, one point per layer.

    Params:
        @ts_data: A dict of format {token: {layer_idx: score, ...}, ...}, as returned by `get_token_specialization_metric`.

    Raises:
        ValueError: If `ts_data` is empty.
    """
    _plot_layer_average(ts_data, "avg_ts", "Average Token Specialization by Layer", 'Average TS(l)')


# Draw the token-averaged context-awareness and token-specialization curves, one point per layer.
plot_ca(context_awareness)
plot_ts(token_specialization)

In [9]:
def _plot_per_token_lines(metric_data, value_col, short_name, color):
    """
    Shared implementation of `plot_ca2` / `plot_ts2`: one faint line per token plus a highlighted across-token average.

    Params:
        @metric_data: A dict of format {token: {layer_idx: value, ...}, ...}.
        @value_col: Name of the dataframe column holding the values; plotly uses it as the y column and in hover labels.
        @short_name: Metric abbreviation used in the title, the y-axis label and the name of the average trace.
        @color: Line color applied to every trace.

    Raises:
        ValueError: If `metric_data` is empty, since the resulting dataframe would have no columns to group or plot.
    """
    if not metric_data:
        raise ValueError("Cannot plot an empty metric dict: expected {token: {layer_idx: value}}.")

    # Long format: one row per (token, layer)
    df_all = pd.DataFrame([
        {"token": token, "layer": layer_idx, value_col: val}
        for token, layer_dict in metric_data.items()
        for layer_idx, val in layer_dict.items()
    ])

    # Per-layer mean over tokens, added as a pseudo-token so that it becomes its own trace
    df_avg = df_all.groupby("layer", as_index = False)[value_col].mean().assign(token = "__AVG__")
    df_plot = pd.concat([df_all, df_avg], ignore_index = True)

    fig = px.line(df_plot, x = "layer", y = value_col, color = "token", title = f"{short_name} Per Layer", range_y = (0,1), markers = True)

    # Emphasise the average; fade the individual tokens into the background
    for trace in fig.data:
        if trace.name == "__AVG__":
            trace.update(line = dict(color = color, width = 2), opacity = 1.0, marker = dict(size = 6), name = f"Average {short_name}")
        else:
            trace.update(line = dict(color = color, width = 1), opacity = 0.2, marker = dict(size = 1))

    # Dashed reference line at 0.5 spanning the full plot width
    fig.add_shape(type="line", x0=0, x1=1, xref="paper", y0=0.5, y1=0.5, yref="y", line=dict(color="gray", dash="dash"))
    fig.update_layout(width=600, height=400, xaxis_title="Layer", yaxis_title=f"{short_name}(l)")
    fig.show()


def plot_ca2(ca_data):
    """
    Plot context-awareness per layer for every test token, with the across-token average highlighted.

    Params:
        @ca_data: A dict of format {token: {layer_idx: score, ...}, ...}, as returned by `get_context_awareness_metric`.

    Raises:
        ValueError: If `ca_data` is empty.
    """
    _plot_per_token_lines(ca_data, "ca_value", "CA", "firebrick")


def plot_ts2(ts_data):
    """
    Plot token specialization per layer for every test token, with the across-token average highlighted.

    Params:
        @ts_data: A dict of format {token: {layer_idx: score, ...}, ...}, as returned by `get_token_specialization_metric`.

    Raises:
        ValueError: If `ts_data` is empty.
    """
    # The value column is named "ts_value" (not "ca_value") so hover labels name the right metric
    _plot_per_token_lines(ts_data, "ts_value", "TS", "cornflowerblue")

# Draw the per-token curves for both metrics; each figure also contains the across-token average trace.
plot_ca2(context_awareness)
plot_ts2(token_specialization)