In [1]:
# Dependencies
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter



In [2]:
from functools import wraps
from time import time

def timing(f):
    @wraps(f)
    def wrap(*args, **kw):
        ts = time()
        result = f(*args, **kw)
        te = time()
        print("func:%r args:[%r, %r] took: %2.4f sec" % (f.__name__, args, kw, te-ts))
        return result
    return wrap

In [3]:
# MoE monitoring hook. Attaches a callback to routing modules which computes routing metrics

class MoEProbe:
    """
    Hooks into the model to capture router internals without 
    changing the model architecture.
    """
    def __init__(self, top_k=2):
        self.top_k = top_k
        self.logs = [] # Stores per-step data
        self.layer_names = {}
    
    def clear(self):
        self.logs = []

    def register(self, model):
        """Finds all Gate/Router layers and attaches the hook."""
        print(f"Scanning model for routers...")
        count = 0
        for name, module in model.named_modules():
            # In Qwen1.5-MoE, the router is usually a Linear layer named 'gate'
            # inside the MoE block.
            if name.endswith(".gate"): 
                self.layer_names[module] = name
                module.register_forward_hook(self.hook_fn)
                count += 1
        print(f"Attached probes to {count} router layers.")

    def hook_fn(self, module, inputs, outputs):
        """
        Captured during Forward Pass.
        Input: Hidden states entering the router
        Output: Logits (Raw scores for experts)
        """
        # outputs are the raw logits [batch, seq_len, num_experts]
        router_logits = outputs
        
        # Calculate probabilities
        probs = F.softmax(router_logits, dim=-1)
        
        # Metric: Router Entropy (Uncertainty)
        # High entropy = Router is unsure (or load balancing is forcing uniformity)
        # Low entropy = Strong specialization
        entropy = -torch.sum(probs * torch.log(probs + 1e-9), dim=-1).mean()
        
        # Metric: Expert Activation (Load)
        # Manually recalculate Top-K to see which experts won
        topk_weights, topk_indices = torch.topk(probs, self.top_k, dim=-1)
        
        # Store lightweight statistics (move to CPU to save VRAM)
        step_data = {
            "layer": self.layer_names[module],
            "entropy": entropy.item(),
            "active_experts": topk_indices.flatten().cpu().numpy().tolist()
        }
        self.logs.append(step_data)

In [4]:
# Experimental definition

@timing
def run_experiment(model, tokenizer):
    
    # Add profiling probe
    probe = MoEProbe(top_k=4) # Qwen A2.7B uses Top-4 routing usually (check config)
    # Note: Qwen1.5-MoE-A2.7B config: num_experts=60, num_experts_per_tok=4
    probe.register(model)
    
    prompts = {
        "Python Code": "def fibonacci(n):",
        "Creative": "The fog rolled into the ancient harbor, smelling of salt and decay."
    }
    
    results = {}
    
    for domain, text in prompts.items():
        print(f"\nTesting Domain: {domain}")
        probe.clear() # Reset logs
        
        inputs = tokenizer(text, return_tensors="pt").to(model.device)
        
        # Generate tokens
        with torch.no_grad():
            _ = model.generate(**inputs, max_new_tokens=20)
            
        # Analyze captured data
        print(f"   captured {len(probe.logs)} routing events.")
        
        # Aggregate expert usage for this domain
        all_indices = []
        avg_entropy = []
        for log in probe.logs:
            all_indices.extend(log['active_experts'])
            avg_entropy.append(log['entropy'])
            
        results[domain] = {
            "counts": Counter(all_indices),
            "entropy": np.mean(avg_entropy)
        }

    return results

In [5]:
# Visualization

def plot_results(results):
    fig, ax = plt.subplots(1, 1, figsize=(10, 5))
    
    # Expert usage histogram
    domain_idx = 0
    for domain, data in results.items():
        
        expert_ids = np.array(sorted(data['counts'].keys()))
        counts = np.array([data['counts'][i] for i in expert_ids])
        tokens = data['counts'].total()
        freqs = counts / tokens
        entropy = data['entropy']
        bar_width = 0.25
        bar_offset = bar_width * domain_idx
        ax.bar(expert_ids + bar_offset, freqs, width=bar_width, \
                label='%s, tokens: %d entropy: %2.4f' % (domain, tokens, entropy), \
                alpha=0.7)
        domain_idx += 1
    
    ax.set_title("Expert Activation Frequency (Load)")
    ax.set_xlabel("Expert Index (0-59)")
    ax.set_ylabel("Activation Count")
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Entropy comparison
    #ax = axes[1]
    #domains = list(results.keys())
    #entropies = [results[d]['entropy'] for d in domains]
    #ax.bar(domains, entropies, color=['blue', 'orange'])
    #ax.set_title("Router Entropy (Uncertainty)")
    #ax.set_ylabel("Mean Entropy")
    
    plt.tight_layout()
    plt.show()

In [6]:
@timing
def load_model_qwen():
    
    model_id = "Qwen/Qwen1.5-MoE-A2.7B-Chat"
    print(f"Loading {model_id}...")
    
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        #device_map="auto",
        trust_remote_code=True,
        #torch_dtype=torch.float16 
    )
    print("Model loaded.")
    
    return model, tokenizer

In [None]:
#########
# Main
#########

# Load model
# We select Qwen1.5-MoE-A2.7B-Chat
model, tokenizer = load_model_qwen()

# Run experiment, repeating n times
n = 100
results = []
for i in range(n):
    stats = run_experiment(model, tokenizer)
    results.append(stats)

Loading Qwen/Qwen1.5-MoE-A2.7B-Chat...


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

Model loaded.
func:'load_model_qwen' args:[(), {}] took: 8.6077 sec
Scanning model for routers...
Attached probes to 24 router layers.

Testing Domain: Python Code
   captured 480 routing events.

Testing Domain: Creative
   captured 480 routing events.
func:'run_experiment' args:[(Qwen2MoeForCausalLM(
  (model): Qwen2MoeModel(
    (embed_tokens): Embedding(151936, 2048)
    (layers): ModuleList(
      (0-23): 24 x Qwen2MoeDecoderLayer(
        (self_attn): Qwen2MoeSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): Qwen2MoeRotaryEmbedding()
        )
        (mlp): Qwen2MoeSparseMoeBlock(
          (gate): Linear(in_features=2048, out_features=60, bias=False)
          (experts): ModuleList(
          

   captured 480 routing events.

Testing Domain: Creative
   captured 480 routing events.
func:'run_experiment' args:[(Qwen2MoeForCausalLM(
  (model): Qwen2MoeModel(
    (embed_tokens): Embedding(151936, 2048)
    (layers): ModuleList(
      (0-23): 24 x Qwen2MoeDecoderLayer(
        (self_attn): Qwen2MoeSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): Qwen2MoeRotaryEmbedding()
        )
        (mlp): Qwen2MoeSparseMoeBlock(
          (gate): Linear(in_features=2048, out_features=60, bias=False)
          (experts): ModuleList(
            (0-59): 60 x Qwen2MoeMLP(
              (gate_proj): Linear(in_features=2048, out_features=1408, bias=False)
              (up_proj): Linear(in_features=2048, out

   captured 480 routing events.

Testing Domain: Creative
   captured 480 routing events.
func:'run_experiment' args:[(Qwen2MoeForCausalLM(
  (model): Qwen2MoeModel(
    (embed_tokens): Embedding(151936, 2048)
    (layers): ModuleList(
      (0-23): 24 x Qwen2MoeDecoderLayer(
        (self_attn): Qwen2MoeSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): Qwen2MoeRotaryEmbedding()
        )
        (mlp): Qwen2MoeSparseMoeBlock(
          (gate): Linear(in_features=2048, out_features=60, bias=False)
          (experts): ModuleList(
            (0-59): 60 x Qwen2MoeMLP(
              (gate_proj): Linear(in_features=2048, out_features=1408, bias=False)
              (up_proj): Linear(in_features=2048, out

   captured 480 routing events.

Testing Domain: Creative
   captured 480 routing events.
func:'run_experiment' args:[(Qwen2MoeForCausalLM(
  (model): Qwen2MoeModel(
    (embed_tokens): Embedding(151936, 2048)
    (layers): ModuleList(
      (0-23): 24 x Qwen2MoeDecoderLayer(
        (self_attn): Qwen2MoeSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): Qwen2MoeRotaryEmbedding()
        )
        (mlp): Qwen2MoeSparseMoeBlock(
          (gate): Linear(in_features=2048, out_features=60, bias=False)
          (experts): ModuleList(
            (0-59): 60 x Qwen2MoeMLP(
              (gate_proj): Linear(in_features=2048, out_features=1408, bias=False)
              (up_proj): Linear(in_features=2048, out

   captured 480 routing events.

Testing Domain: Creative
   captured 480 routing events.
func:'run_experiment' args:[(Qwen2MoeForCausalLM(
  (model): Qwen2MoeModel(
    (embed_tokens): Embedding(151936, 2048)
    (layers): ModuleList(
      (0-23): 24 x Qwen2MoeDecoderLayer(
        (self_attn): Qwen2MoeSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): Qwen2MoeRotaryEmbedding()
        )
        (mlp): Qwen2MoeSparseMoeBlock(
          (gate): Linear(in_features=2048, out_features=60, bias=False)
          (experts): ModuleList(
            (0-59): 60 x Qwen2MoeMLP(
              (gate_proj): Linear(in_features=2048, out_features=1408, bias=False)
              (up_proj): Linear(in_features=2048, out

   captured 480 routing events.

Testing Domain: Creative
   captured 480 routing events.
func:'run_experiment' args:[(Qwen2MoeForCausalLM(
  (model): Qwen2MoeModel(
    (embed_tokens): Embedding(151936, 2048)
    (layers): ModuleList(
      (0-23): 24 x Qwen2MoeDecoderLayer(
        (self_attn): Qwen2MoeSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): Qwen2MoeRotaryEmbedding()
        )
        (mlp): Qwen2MoeSparseMoeBlock(
          (gate): Linear(in_features=2048, out_features=60, bias=False)
          (experts): ModuleList(
            (0-59): 60 x Qwen2MoeMLP(
              (gate_proj): Linear(in_features=2048, out_features=1408, bias=False)
              (up_proj): Linear(in_features=2048, out

   captured 480 routing events.

Testing Domain: Creative
   captured 480 routing events.
func:'run_experiment' args:[(Qwen2MoeForCausalLM(
  (model): Qwen2MoeModel(
    (embed_tokens): Embedding(151936, 2048)
    (layers): ModuleList(
      (0-23): 24 x Qwen2MoeDecoderLayer(
        (self_attn): Qwen2MoeSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): Qwen2MoeRotaryEmbedding()
        )
        (mlp): Qwen2MoeSparseMoeBlock(
          (gate): Linear(in_features=2048, out_features=60, bias=False)
          (experts): ModuleList(
            (0-59): 60 x Qwen2MoeMLP(
              (gate_proj): Linear(in_features=2048, out_features=1408, bias=False)
              (up_proj): Linear(in_features=2048, out

In [None]:
# Stack results
results_aggregate = {}
for domain in results[0]:
    results_aggregate[domain] = {
        "counts": np.sum([results[i][domain]["counts"] for i in range(n)]),
        "entropy": np.mean([results[i][domain]["entropy"] for i in range(n)])
    }

# Plot
plot_results(results_aggregate)