# Profiling Expert Activation in GPT-OSS (Base vs BlockFFN-LoRA)

This notebook profiles Mixture-of-Experts (MoE) router activations in:
- Base: `openai/gpt-oss-20b`
- Fine-tuned: `mcemri/gpt-oss-20b-blockffn-lora`

We run on a downstream task (GSM8K) and record expert activation patterns per layer, computing useful metrics and plots:
- Per-layer activation entropy
- Token-level sparsity (TLS) at threshold
- Locality metric across adjacent tokens (BlockFFN-style)
- Chunk union sparsity (BlockFFN-style)
- Top-1 expert usage heatmaps (layers × experts)

Reference model on the Hub: [mcemri/gpt-oss-20b-blockffn-lora](https://huggingface.co/mcemri/gpt-oss-20b-blockffn-lora)



In [1]:
import os
import re
import json
import math
import time
import random
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm

# Lazy installs if missing
import importlib, subprocess, sys

def ensure(pkg: str, pip_name: Optional[str] = None):
    if importlib.util.find_spec(pkg) is None:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", pip_name or pkg])

for pkg, pipn in [("datasets", None), ("transformers", None), ("accelerate", None), ("seaborn", None)]:
    ensure(pkg, pipn)

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16

# Plotting defaults
sns.set_context("talk")
plt.rcParams["figure.figsize"] = (12, 5)
plt.rcParams["figure.dpi"] = 120



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Configuration
BASE_MODEL_ID = "openai/gpt-oss-20b"
FT_MODEL_ID = "mcemri/gpt-oss-20b-blockffn-lora"  # merged model on Hub

DATASET_NAME = "gsm8k"
DATASET_CONFIG = "main"
SPLIT = "test"  # or "train"
NUM_SAMPLES = 50  # adjust for runtime; large models are heavy
SEED = 42

# Generation / profiling
MAX_NEW_TOKENS = 0  # set >0 to also profile decode steps (slower)
TEMPERATURE = 0.0
TOP_P = 1.0

# Router capture
ROUTER_KEYS = [
    "router_logits",
    "router_probs",
    "router_probabilities",
    "gate_logits",
    "gating_logits",
]
CHUNK_LEN = 8
PROB_TEMPERATURE = 1.0
MIN_PROB_EPS = 1e-6
ENTROPY_EPS = 1e-8
TLS_THRESHOLD = 0.01
SIGMOID_ALPHA = 12.0

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)



<torch._C.Generator at 0x7f09900cb930>

In [None]:
def _flatten_tensors(value) -> List[torch.Tensor]:
    tensors: List[torch.Tensor] = []
    if isinstance(value, torch.Tensor):
        tensors.append(value)
    elif isinstance(value, (list, tuple)):
        for item in value:
            tensors.extend(_flatten_tensors(item))
    elif isinstance(value, dict):
        for item in value.values():
            tensors.extend(_flatten_tensors(item))
    return tensors


def softmax_router(t: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    x = t.to(torch.float32)
    # Accept [T, E] by assuming batch=1
    if x.dim() == 2:
        x = x.unsqueeze(0)
    elif x.dim() < 2:
        return torch.empty(0, device=x.device)
    if temperature != 1.0:
        x = x / temperature
    probs = torch.softmax(x, dim=-1)
    return torch.clamp(probs, min=MIN_PROB_EPS)


def collect_from_output(outputs, router_keys=ROUTER_KEYS) -> List[torch.Tensor]:
    if isinstance(outputs, dict):
        lookup = outputs
    else:
        lookup = {k: getattr(outputs, k) for k in dir(outputs) if not k.startswith("_")}
    tensors: List[torch.Tensor] = []
    for key in router_keys:
        if key in lookup:
            tensors.extend(_flatten_tensors(lookup[key]))
    return tensors


_LAYER_RE = re.compile(r"layers\.(\d+)\.")


def _extract_layer_idx(name: str) -> Optional[int]:
    m = _LAYER_RE.search(name)
    return int(m.group(1)) if m else None


def register_router_hooks(model) -> Tuple[Dict[int, List[torch.Tensor]], List[torch.utils.hooks.RemovableHandle]]:
    captured: Dict[int, List[torch.Tensor]] = {}
    handles: List[torch.utils.hooks.RemovableHandle] = []

    def make_hook(layer_idx: int):
        def _hook(module, inputs, output):
            # We accept either logits or probs; convert to probs later
            out_tensors = []
            if isinstance(output, torch.Tensor):
                out_tensors.append(output.detach())
            else:
                out_tensors.extend(_flatten_tensors(output))
            if out_tensors:
                captured.setdefault(layer_idx, []).extend([t.detach() for t in out_tensors])
        return _hook

    for name, module in model.named_modules():
        lname = name.lower()
        should_hook = False
        if lname.endswith("mlp.router") or lname.endswith(".router"):
            should_hook = True
        elif ("router" in lname or "gating" in lname or "gate" in lname) and ("mlp" in lname or "expert" in lname):
            should_hook = True
        if should_hook:
            layer_idx = _extract_layer_idx(name)
            if layer_idx is None:
                continue
            handles.append(module.register_forward_hook(make_hook(layer_idx)))

    return captured, handles


def consolidate_captured(captured: Dict[int, List[torch.Tensor]], to_probs: bool = True) -> Dict[int, List[torch.Tensor]]:
    consolidated: Dict[int, List[torch.Tensor]] = {}
    for layer_idx, chunks in captured.items():
        if not chunks:
            continue
        cat = []
        for t in chunks:
            if t.numel() == 0:
                continue
            # Normalize shape to [B, T, E] when needed before converting
            x = t
            if x.dim() == 2:
                x = x.unsqueeze(0)
            if to_probs:
                cat.append(softmax_router(x, temperature=PROB_TEMPERATURE))
            else:
                cat.append(x)
        if cat:
            consolidated[layer_idx] = cat
    return consolidated



In [6]:
def entropy_from_probs(p: torch.Tensor, eps: float = ENTROPY_EPS) -> torch.Tensor:
    p = torch.clamp(p, min=eps)
    return -(p * torch.log(p)).sum(dim=-1)


def estimate_token_sparsity(probs: torch.Tensor, mask: torch.Tensor, threshold: float = TLS_THRESHOLD) -> Optional[torch.Tensor]:
    masked = mask.unsqueeze(-1).float()
    denom = masked.sum()
    if denom == 0:
        return None
    active = (probs > threshold).float()
    ratio = (active * masked).sum() / (denom * probs.size(-1))
    return 1.0 - ratio


def compute_activation_locality_metric(
    probs_list: List[torch.Tensor], token_mask: torch.Tensor, alpha: float = SIGMOID_ALPHA
) -> Tuple[Optional[torch.Tensor], dict]:
    if not probs_list:
        return None, {}
    losses, tls_vals = [], []
    for probs in probs_list:
        if probs.size(1) < 2:
            continue
        mask_forward = token_mask[:, :-1] & token_mask[:, 1:]
        if not mask_forward.any():
            continue
        sharpen_curr = torch.sigmoid(alpha * (probs[:, :-1, :] - 0.5))
        sharpen_next = torch.sigmoid(alpha * (probs[:, 1:, :] - 0.5))
        bce = torch.nn.functional.binary_cross_entropy(sharpen_next, sharpen_curr, reduction="none").mean(dim=-1)
        loss = (bce * mask_forward.float()).sum() / mask_forward.float().sum()
        losses.append(loss.detach())
        tls = estimate_token_sparsity(probs, token_mask)
        if tls is not None:
            tls_vals.append(tls.detach())
    if not losses:
        return None, {}
    metrics = {}
    if tls_vals:
        metrics["tls"] = torch.stack(tls_vals).mean().item()
    return torch.stack(losses).mean(), metrics


def compute_chunk_union_sparsity(
    probs_list: List[torch.Tensor], token_mask: torch.Tensor, chunk_len: int = CHUNK_LEN, eps: float = MIN_PROB_EPS
) -> Tuple[Optional[torch.Tensor], dict]:
    if chunk_len <= 0 or not probs_list:
        return None, {}
    losses, cls_vals = [], []
    for probs in probs_list:
        if probs.size(1) < chunk_len:
            continue
        unfolded_probs = probs.unfold(dimension=1, size=chunk_len, step=1)
        unfolded_mask = token_mask.float().unfold(dimension=1, size=chunk_len, step=1)
        valid_mask = (unfolded_mask.sum(-1) == float(chunk_len))
        if not valid_mask.any():
            continue
        complement = torch.clamp(1.0 - unfolded_probs, min=eps, max=1.0)
        union = 1.0 - torch.prod(complement, dim=-2)
        union_mean = union.mean(dim=-1)
        loss = (union_mean * valid_mask.float()).sum() / valid_mask.float().sum()
        losses.append(loss.detach())
        denom = valid_mask.float().sum()
        if denom > 0:
            sparsity = 1.0 - (union * valid_mask.unsqueeze(-1).float()).sum() / (denom * union.size(-1))
            cls_vals.append(sparsity.detach())
    if not losses:
        return None, {}
    metrics = {}
    if cls_vals:
        metrics["cls"] = torch.stack(cls_vals).mean().item()
    return torch.stack(losses).mean(), metrics


def top1_counts(probs_list: List[torch.Tensor], token_mask: torch.Tensor) -> Optional[torch.Tensor]:
    if not probs_list:
        return None
    counts = None
    for probs in probs_list:
        if probs.numel() == 0:
            continue
        top1 = probs.argmax(dim=-1)  # [B, T]
        mask = token_mask
        if counts is None:
            counts = torch.zeros(probs.size(-1), dtype=torch.long)
        counts.index_add_(0, top1[mask].view(-1).cpu(), torch.ones_like(top1[mask].view(-1), dtype=torch.long))
    return counts


def metrics_from_probs(probs_chunks: List[torch.Tensor], token_mask: torch.Tensor) -> Dict[str, float]:
    metrics: Dict[str, float] = {}
    if not probs_chunks:
        return metrics
    # Entropy
    ent_vals = []
    for probs in probs_chunks:
        ent = entropy_from_probs(probs).mean()  # average over tokens
        ent_vals.append(ent.detach().item())
    metrics["entropy_mean"] = float(np.mean(ent_vals)) if ent_vals else float("nan")
    # Locality and TLS
    loc_loss, loc_metrics = compute_activation_locality_metric(probs_chunks, token_mask)
    if loc_loss is not None:
        metrics["locality_loss"] = loc_loss.item()
    metrics.update({f"{k}": v for k, v in loc_metrics.items()})
    # Chunk union sparsity
    cls_loss, cls_metrics = compute_chunk_union_sparsity(probs_chunks, token_mask)
    if cls_loss is not None:
        metrics["chunk_union"] = cls_loss.item()
    metrics.update({f"{k}": v for k, v in cls_metrics.items()})
    return metrics



In [7]:
def load_model_and_tokenizer(model_id: str):
    tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=DTYPE,
        device_map="auto",
    )
    # Try to enable router outputs if supported by the model
    if hasattr(model, "config") and hasattr(model.config, "output_router_logits"):
        model.config.output_router_logits = True
    return model, tok



In [8]:
def build_prompt(question: str) -> str:
    return (
        "You are a helpful math assistant. Solve the following problem and give the final answer as a number.\n\n"
        f"Problem: {question}\n\nLet's think step by step."
    )


def load_gsm8k(num_samples: int = NUM_SAMPLES, split: str = SPLIT) -> List[dict]:
    ds = load_dataset(DATASET_NAME, DATASET_CONFIG, split=split)
    ds = ds.shuffle(seed=SEED).select(range(min(num_samples, len(ds))))
    samples = []
    for ex in ds:
        samples.append({
            "question": ex["question"],
            "answer": ex.get("answer", ""),
            "prompt": build_prompt(ex["question"]) ,
        })
    return samples



In [None]:
def run_single_forward(model, tok, prompt: str):
    enc = tok(prompt, return_tensors="pt")
    enc = {k: v.to(model.device) for k, v in enc.items()}
    token_mask = enc.get("attention_mask", (enc["input_ids"] != tok.pad_token_id).long()).bool()

    captured, handles = register_router_hooks(model)
    with torch.no_grad():
        try:
            outputs = model(**enc, use_cache=False, output_router_logits=True)
        except TypeError:
            outputs = model(**enc, use_cache=False)
    for h in handles:
        try:
            h.remove()
        except Exception:
            pass

    # Fallback: collect router tensors directly from model outputs if any
    if not captured:
        extra = collect_from_output(outputs, ROUTER_KEYS)
        if extra:
            captured[-1] = [t.detach() for t in extra if isinstance(t, torch.Tensor) and t.numel() > 0]

    probs_by_layer = consolidate_captured(captured, to_probs=True)
    # Ensure shapes are [B, T, E] and on CPU for metrics
    for k, chunks in list(probs_by_layer.items()):
        fixed = []
        for t in chunks:
            if t.numel() == 0:
                continue
            if t.dim() == 2:  # [T, E] -> [1, T, E]
                t = t.unsqueeze(0)
            fixed.append(t.cpu())
        probs_by_layer[k] = fixed

    token_mask_cpu = token_mask.cpu()

    metrics_layer: Dict[int, Dict[str, float]] = {}
    counts_layer: Dict[int, torch.Tensor] = {}

    for layer_idx, probs_chunks in probs_by_layer.items():
        m = metrics_from_probs(probs_chunks, token_mask_cpu)
        metrics_layer[layer_idx] = m
        c = top1_counts(probs_chunks, token_mask_cpu)
        if c is not None:
            counts_layer[layer_idx] = c

    return metrics_layer, counts_layer



In [10]:
def profile_model(samples: List[dict], model_id: str):
    model, tok = load_model_and_tokenizer(model_id)
    try:
        per_layer_metrics: Dict[int, List[Dict[str, float]]] = {}
        per_layer_counts: Dict[int, torch.Tensor] = {}

        for ex in tqdm(samples, desc=f"Profiling {model_id}"):
            m, c = run_single_forward(model, tok, ex["prompt"])
            for layer_idx, md in m.items():
                per_layer_metrics.setdefault(layer_idx, []).append(md)
            for layer_idx, counts in c.items():
                if layer_idx not in per_layer_counts:
                    per_layer_counts[layer_idx] = counts.clone()
                else:
                    per_layer_counts[layer_idx] += counts

        # Aggregate metrics by mean across samples
        agg_metrics: Dict[int, Dict[str, float]] = {}
        all_keys = set()
        for arr in per_layer_metrics.values():
            if arr:
                all_keys.update(arr[0].keys())
        for layer_idx, arr in per_layer_metrics.items():
            if not arr:
                continue
            d = {}
            for k in all_keys:
                vals = [x.get(k) for x in arr if k in x and not (x.get(k) is None or np.isnan(x.get(k)))]
                d[k] = float(np.mean(vals)) if vals else float("nan")
            agg_metrics[layer_idx] = d

        # Normalize counts to fractions per layer
        frac_counts: Dict[int, np.ndarray] = {}
        for layer_idx, counts in per_layer_counts.items():
            total = counts.sum().item()
            if total > 0:
                frac_counts[layer_idx] = (counts.float() / total).numpy()
        return agg_metrics, frac_counts
    finally:
        try:
            del model
            torch.cuda.empty_cache()
        except Exception:
            pass



In [None]:
def plots_compare_metrics(base_metrics: Dict[int, Dict[str, float]], ft_metrics: Dict[int, Dict[str, float]]):
    layers = sorted(set(base_metrics.keys()) | set(ft_metrics.keys()))
    metric_names = set()
    for d in base_metrics.values():
        metric_names.update(d.keys())
    for d in ft_metrics.values():
        metric_names.update(d.keys())
    metric_names = [m for m in ["entropy_mean", "tls", "locality_loss", "chunk_union", "cls"] if m in metric_names]

    n = len(metric_names)
    if n == 0:
        print("No metrics to plot.")
        return None
    fig, axes = plt.subplots(n, 1, figsize=(12, 4*n), sharex=True)
    if n == 1:
        axes = [axes]

    for ax, m in zip(axes, metric_names):
        yb = [base_metrics.get(l, {}).get(m, np.nan) for l in layers]
        yf = [ft_metrics.get(l, {}).get(m, np.nan) for l in layers]
        ax.plot(layers, yb, label=f"Base {m}", marker="o")
        ax.plot(layers, yf, label=f"FT {m}", marker="o")
        ax.set_ylabel(m)
        ax.grid(True, alpha=0.3)
        ax.legend()
    axes[-1].set_xlabel("Layer")
    plt.tight_layout()
    return fig


def plots_compare_heatmaps(base_frac: Dict[int, np.ndarray], ft_frac: Dict[int, np.ndarray]):
    layers = sorted(set(base_frac.keys()) | set(ft_frac.keys()))
    if not layers:
        print("No top-1 usage captured.")
        return None
    # Determine expert count by the largest vector length
    max_e = max([arr.shape[0] for arr in list(base_frac.values()) + list(ft_frac.values())])

    def stack(frac_dict):
        mat = np.zeros((len(layers), max_e), dtype=np.float32)
        for i, l in enumerate(layers):
            arr = frac_dict.get(l)
            if arr is not None:
                mat[i, : arr.shape[0]] = arr
        return mat

    base_mat = stack(base_frac)
    ft_mat = stack(ft_frac)

    fig, axes = plt.subplots(1, 2, figsize=(18, 6), sharey=True)
    sns.heatmap(base_mat, ax=axes[0], cmap="magma", cbar=True)
    axes[0].set_title("Base: Top-1 Expert Fraction per Layer")
    axes[0].set_xlabel("Expert")
    axes[0].set_ylabel("Layer (0=bottom)")

    sns.heatmap(ft_mat, ax=axes[1], cmap="magma", cbar=True)
    axes[1].set_title("BlockFFN-LoRA: Top-1 Expert Fraction per Layer")
    axes[1].set_xlabel("Expert")

    plt.tight_tight_layout = plt.tight_layout
    plt.tight_tight_layout()
    return fig

In [None]:
def save_results(out_dir: str, base_metrics, ft_metrics, base_frac, ft_frac):
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, "metrics_base.json"), "w") as f:
        json.dump(base_metrics, f, indent=2)
    with open(os.path.join(out_dir, "metrics_ft.json"), "w") as f:
        json.dump(ft_metrics, f, indent=2)
    # Save fraction matrices as CSVs per layer
    def save_frac(prefix, frac):
        for l, arr in frac.items():
            pd.Series(arr).to_csv(os.path.join(out_dir, f"{prefix}_layer{l}.csv"), index_label="expert", header=["fraction"])
    save_frac("base_top1", base_frac)
    save_frac("ft_top1", ft_frac)

    # Plots (save the figures created by the plotting functions)
    fig1 = plots_compare_metrics(base_metrics, ft_metrics)
    if fig1 is not None:
        fig1.savefig(os.path.join(out_dir, "metrics_compare.png"), bbox_inches="tight")
        plt.close(fig1)

    fig2 = plots_compare_heatmaps(base_frac, ft_frac)
    if fig2 is not None:
        fig2.savefig(os.path.join(out_dir, "top1_heatmaps.png"), bbox_inches="tight")
        plt.close(fig2)

    print(f"Saved results to: {out_dir}")



In [None]:
samples = load_gsm8k(NUM_SAMPLES, SPLIT)
print(f"Loaded {len(samples)} {DATASET_NAME}/{DATASET_CONFIG}:{SPLIT} samples for profiling.")

base_metrics, base_frac = profile_model(samples, BASE_MODEL_ID)
ft_metrics, ft_frac = profile_model(samples, FT_MODEL_ID)

# Display plots inline using returned figures
fig = plots_compare_metrics(base_metrics, ft_metrics)
if fig is not None:
    plt.show()
fig = plots_compare_heatmaps(base_frac, ft_frac)
if fig is not None:
    plt.show()

# Save
timestamp = time.strftime("%Y%m%d-%H%M%S")
out_dir = f"/data/mert_cemri/LLaMA-Factory/saves/profiling_gsm8k_{timestamp}"
save_results(out_dir, base_metrics, ft_metrics, base_frac, ft_frac)



Loaded 50 gsm8k/main:test samples for profiling.


Loading checkpoint shards: 100%|██████████| 3/3 [00:02<00:00,  1.06it/s]
Profiling openai/gpt-oss-20b: 100%|██████████| 50/50 [00:01<00:00, 28.98it/s]
Loading checkpoint shards: 100%|██████████| 9/9 [00:07<00:00,  1.20it/s]
Profiling mcemri/gpt-oss-20b-blockffn-lora: 100%|██████████| 50/50 [00:01<00:00, 29.23it/s]


No metrics to plot.
No top-1 usage captured.
No metrics to plot.
No top-1 usage captured.
Saved results to: /data/mert_cemri/LLaMA-Factory/saves/profiling_gsm8k_20251117-092331
