In [1]:
# ===== Tiny Vision Transformer (CIFAR-10) + Activation-EMP pruning =====
import math, copy, time, random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader

# --------- Reuse pruning utils (paste from the first block or import them) ----------

@torch.no_grad()
def model_sparsity(model: nn.Module) -> float:
    total, zeros = 0, 0
    for n, p in model.named_parameters():
        if p.dim() >= 2 and 'weight' in n:
            total += p.numel()
            zeros += (p == 0).sum().item()
    return zeros / max(total, 1)

# ---- Step 1: collect E[|x|] (per in-feature) for each Linear via forward hooks
@torch.no_grad()
def collect_input_magnitudes(model: nn.Module,
                             data_loader,
                             device,
                             num_batches: int = 10):
    model.eval()
    # list Linear modules in traversal order
    linear_list = [m for m in model.modules() if isinstance(m, nn.Linear)]
    sums = []
    counts = []
    handles = []

    for m in linear_list:
        sums.append(torch.zeros(m.in_features, device=device))
        counts.append(torch.tensor(0, device=device))

    index_of = {id(m): i for i, m in enumerate(linear_list)}

    def hook_fn(module, inputs, output):
        idx = index_of[id(module)]
        x = inputs[0].detach()
        # flatten all leading dims except last: [..., in_features]
        x2d = x.flatten(0, -2)  # (B*..., in_features)
        sums[idx] += x2d.abs().sum(dim=0)
        counts[idx] += x2d.shape[0]

    for m in linear_list:
        handles.append(m.register_forward_hook(hook_fn))

    seen = 0
    for data, target in data_loader:
        data = data.to(device)
        _ = model(data)
        seen += 1
        if seen >= num_batches:
            break

    for h in handles:
        h.remove()

    mags = [s / torch.clamp(c.float(), min=1.0) for s, c in zip(sums, counts)]
    # Return in the same order as linear_list
    return linear_list, mags

# ---- Step 2: build the activation-aware mask from N_eff on |W| * E|x|
@torch.no_grad()
def get_linear_mask_emp(module: nn.Linear,
                        in_mag: torch.Tensor) -> (torch.Tensor, torch.Tensor):
    """
    Args:
        module: nn.Linear with weight shape [out, in]
        in_mag: tensor [in] = E[|x|] for this module's input
    Returns:
        mask (bool) with same shape as weight
        neff_row (long) length = out_features
    """
    W = module.weight.data  # [out, in]
    # contributions per input to each neuron:
    contrib = W.abs() * in_mag.unsqueeze(0)  # [out, in]
    row_sum = contrib.sum(dim=1, keepdim=True).clamp(min=1e-12)
    norm = contrib / row_sum                 # \hat c_ji

    neff = torch.floor(1.0 / norm.pow(2).sum(dim=1)).clamp(min=1, max=W.shape[1]).long()  # [out]

    # sort each row by importance and keep top neff[j]
    _, idx = torch.sort(norm, dim=1, descending=True)
    out, in_ = W.shape
    ranks = torch.arange(in_, device=W.device).unsqueeze(0).expand(out, in_)
    keep_sorted = ranks < neff.unsqueeze(1)   # [out, in] (sorted order)

    mask = torch.zeros_like(W, dtype=torch.bool)
    mask.scatter_(1, idx, keep_sorted)
    return mask, neff

# ---- Step 3: prune with EMP (optional L1 row re-normalization)
@torch.no_grad()
def prune_model_emp_activation(model: nn.Module,
                               calib_loader,
                               device,
                               num_calib_batches: int = 10,
                               renormalize: bool = False):
    pruned = copy.deepcopy(model).to(device)
    linear_list, mags = collect_input_magnitudes(pruned, calib_loader, device, num_batches=num_calib_batches)

    layer_neff = {}
    for lin, mu in zip(linear_list, mags):
        W = lin.weight.data
        old_row_l1 = W.abs().sum(dim=1, keepdim=True)
        mask, neff = get_linear_mask_emp(lin, mu.to(W.device))
        # apply mask
        W.mul_(mask)
        if renormalize:
            new_row_l1 = W.abs().sum(dim=1, keepdim=True).clamp(min=1e-8)
            scale = old_row_l1 / new_row_l1
            W.mul_(scale)
        layer_neff[id(lin)] = {
            "name": getattr(lin, "_emp_name", None),
            "neff_row": neff.detach().cpu(),
            "avg_neff": neff.float().mean().item(),
            "in_features": W.shape[1],
            "out_features": W.shape[0],
            "layer_sparsity": float((~mask).sum().item() / mask.numel())
        }

    return pruned, layer_neff

# ---- Helper: pretty summary
def summarize_emp(layer_neff_dict):
    lines = []
    for k, v in layer_neff_dict.items():
        name = v.get("name") or f"Linear(id={k})"
        lines.append(
            f"{name:30s} | out={v['out_features']:4d} in={v['in_features']:4d} "
            f"| avg N_eff={v['avg_neff']:.1f} | sparsity={v['layer_sparsity']*100:5.1f}%"
        )
    return "\n".join(lines)

# ---- (Optional) attach names to ease reading
def tag_linear_names(model: nn.Module):
    for name, m in model.named_modules():
        if isinstance(m, nn.Linear):
            m._emp_name = name


# --------- Model: Patch embedding, Encoder block, and ViT head ----------
class PatchEmbed(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=128):
        super().__init__()
        assert img_size % patch_size == 0
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.num_patches = (img_size // patch_size) ** 2

    def forward(self, x):                   # x: [B,3,H,W]
        x = self.proj(x)                    # [B,embed,H/P,W/P]
        x = x.flatten(2).transpose(1, 2)    # [B, N, embed]
        return x

class MLP(nn.Module):
    def __init__(self, dim, hidden_dim, p=0.0):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.drop = nn.Dropout(p)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class SelfAttention(nn.Module):
    def __init__(self, dim, num_heads=4, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):         # x: [B, N, C]
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)     # each: [B, N, H, D]
        q = q.transpose(1, 2)            # [B, H, N, D]
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * self.scale        # [B,H,N,N]
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        out = attn @ v                                       # [B,H,N,D]
        out = out.transpose(1, 2).reshape(B, N, C)           # [B,N,C]
        out = self.proj(out)
        out = self.proj_drop(out)
        return out

class EncoderBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=2.0, drop=0.0, attn_drop=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = SelfAttention(dim, num_heads, attn_drop, drop)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp  = MLP(dim, int(dim * mlp_ratio), p=drop)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class TinyViT(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_chans=3,
                 num_classes=10, embed_dim=128, depth=4, num_heads=4, mlp_ratio=2.0, drop=0.0):
        super().__init__()
        self.patch = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        self.pos = nn.Parameter(torch.zeros(1, self.patch.num_patches, embed_dim))
        self.blocks = nn.ModuleList([
            EncoderBlock(embed_dim, num_heads, mlp_ratio, drop, drop) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.patch(x) + self.pos
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        x = x.mean(dim=1)         # global average over patches (no class token)
        x = self.head(x)
        return F.log_softmax(x, dim=-1)

# --------- Data: CIFAR-10 (32x32) ----------
def get_cifar10_loaders(batch_size=128, num_workers=2):
    train_tf = T.Compose([
        T.RandomCrop(32, padding=4),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    test_tf = T.Compose([
        T.ToTensor(),
        T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    train = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_tf)
    test  = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_tf)
    return (DataLoader(train, batch_size=batch_size, shuffle=True,  num_workers=num_workers, pin_memory=True),
            DataLoader(test,  batch_size=256,       shuffle=False, num_workers=num_workers, pin_memory=True))

# --------- Train / Test ----------
def train_one_epoch(model, loader, opt, device):
    model.train()
    total, correct, total_loss = 0, 0, 0.0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        opt.zero_grad()
        out = model(x)
        loss = F.nll_loss(out, y)
        loss.backward()
        opt.step()
        total_loss += loss.item() * x.size(0)
        pred = out.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += x.size(0)
    return total_loss/total, 100*correct/total

@torch.no_grad()
def evaluate(model, loader, device):
    model.eval()
    total, correct, total_loss = 0, 0, 0.0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        out = model(x)
        loss = F.nll_loss(out, y, reduction='sum')
        total_loss += loss.item()
        pred = out.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += x.size(0)
    return total_loss/total, 100*correct/total

# --------- Run everything ----------
if __name__ == "__main__":
    torch.manual_seed(0); random.seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loader, test_loader = get_cifar10_loaders()
    model = TinyViT(embed_dim=128, depth=4, num_heads=4, mlp_ratio=2.0, drop=0.1).to(device)

    print("Params:", sum(p.numel() for p in model.parameters())/1e6, "M")
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)

    # Train a few epochs (tune as desired)
    epochs = 10
    for ep in range(1, epochs+1):
        tr_loss, tr_acc = train_one_epoch(model, train_loader, opt, device)
        te_loss, te_acc = evaluate(model, test_loader, device)
        print(f"Epoch {ep:02d}: train acc {tr_acc:5.2f}%  test acc {te_acc:5.2f}%")

    base_sparsity = model_sparsity(model)
    print(f"Base model sparsity: {base_sparsity:.4f}")

    # Activation-EMP prune using a small calibration slice of the train set
    pruned, layer_neff = prune_model_emp_activation(
        model, calib_loader=train_loader, device=device, num_calib_batches=8, renormalize=False
    )
    print("Layer-wise EMP summary:\n", summarize_emp(layer_neff))
    print(f"EMP sparsity: {model_sparsity(pruned):.4f}")

    # Evaluate pruned model
    te_loss, te_acc = evaluate(pruned, test_loader, device)
    print(f"EMP-pruned ViT — test acc {te_acc:5.2f}%  (loss {te_loss:.4f})")


Params: 0.54593 M




Epoch 01: train acc 33.46%  test acc 42.11%
Epoch 02: train acc 46.32%  test acc 49.63%
Epoch 03: train acc 52.18%  test acc 51.40%
Epoch 04: train acc 55.91%  test acc 58.05%
Epoch 05: train acc 58.67%  test acc 60.36%
Epoch 06: train acc 60.84%  test acc 62.72%
Epoch 07: train acc 62.39%  test acc 63.98%
Epoch 08: train acc 64.09%  test acc 66.32%
Epoch 09: train acc 65.35%  test acc 66.15%
Epoch 10: train acc 66.41%  test acc 65.86%
Base model sparsity: 0.0000
Layer-wise EMP summary:
 Linear(id=6060751744)          | out= 384 in= 128 | avg N_eff=86.7 | sparsity= 32.2%
Linear(id=6060746848)          | out= 128 in= 128 | avg N_eff=81.6 | sparsity= 36.3%
Linear(id=6060751216)          | out= 256 in= 128 | avg N_eff=86.8 | sparsity= 32.2%
Linear(id=6060747376)          | out= 128 in= 256 | avg N_eff=165.7 | sparsity= 35.3%
Linear(id=6060752320)          | out= 384 in= 128 | avg N_eff=88.9 | sparsity= 30.6%
Linear(id=6060750976)          | out= 128 in= 128 | avg N_eff=87.1 | sparsity= 32

In [1]:
# EMP vs Row-NEFF pruning on a Vision Transformer (CIFAR-10)
# -----------------------------------------------------------
# Requirements: torch, torchvision, numpy
# Optional: tqdm (for pretty progress bars)

import math, os, copy, random, time
from dataclasses import dataclass
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import datasets, transforms

try:
    from tqdm import tqdm
    TQDM = True
except Exception:
    TQDM = False

# -----------------------------
# Utils
# -----------------------------
def set_seed(seed: int = 1337):
    random.seed(seed); torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

def accuracy(logits, targets):
    pred = logits.argmax(dim=1)
    return (pred == targets).float().mean().item()

def count_zeros(t: torch.Tensor) -> int:
    return (t == 0).sum().item()

def model_sparsity(model: nn.Module) -> float:
    total = 0; zeros = 0
    for n, p in model.named_parameters():
        if p.ndim >= 2 and "weight" in n:
            total += p.numel()
            zeros += count_zeros(p)
    return zeros / max(total, 1)

# -----------------------------
# Data
# -----------------------------
def get_cifar10(batch_size=128, num_workers=2):
    mean = (0.4914, 0.4822, 0.4465)
    std  = (0.2470, 0.2435, 0.2616)
    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.4,0.4,0.4,0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    test_tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

    train = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_tf)
    test  = datasets.CIFAR10(root="./data", train=False, download=True, transform=test_tf)

    train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True,
                                               num_workers=num_workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test, batch_size=256, shuffle=False,
                                              num_workers=num_workers, pin_memory=True)
    return train_loader, test_loader

# -----------------------------
# ViT components
# -----------------------------
class DropPath(nn.Module):
    """Stochastic depth (per sample)"""
    def __init__(self, drop_prob=0.0):
        super().__init__()
        self.drop_prob = drop_prob
    def forward(self, x):
        if self.drop_prob == 0. or not self.training: return x
        keep = 1 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        rnd = x.new_empty(shape).bernoulli_(keep)
        return x * rnd / keep

class PatchEmbed(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=256):
        super().__init__()
        assert img_size % patch_size == 0
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.num_patches = (img_size // patch_size) ** 2
    def forward(self, x):
        # B, C, H, W -> B, N, D
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)
        return x

class MLP(nn.Module):
    def __init__(self, dim, hidden_dim, drop=0.0):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.drop = nn.Dropout(drop)
    def forward(self, x):
        x = self.fc1(x); x = self.act(x); x = self.drop(x)
        x = self.fc2(x); x = self.drop(x)
        return x

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=True, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x)                         # B, N, 3C
        qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2,0,3,1,4)
        q, k, v = qkv[0], qkv[1], qkv[2]          # each: B, heads, N, head_dim
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1,2).reshape(B,N,C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4.0, drop=0.0, attn_drop=0.0, drop_path=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path1 = DropPath(drop_path)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, int(dim * mlp_ratio), drop=drop)
        self.drop_path2 = DropPath(drop_path)
    def forward(self, x):
        x = x + self.drop_path1(self.attn(self.norm1(x)))
        x = x + self.drop_path2(self.mlp(self.norm2(x)))
        return x

class ViT(nn.Module):
    def __init__(self, img_size=32, patch=4, in_chans=3, num_classes=10,
                 emb_dim=256, depth=6, num_heads=8, mlp_ratio=4.0, drop=0.1, attn_drop=0.1, drop_path=0.1):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch, in_chans, emb_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + num_patches, emb_dim))
        self.pos_drop = nn.Dropout(drop)

        dpr = torch.linspace(0, drop_path, steps=depth).tolist()
        self.blocks = nn.ModuleList([
            Block(emb_dim, num_heads, mlp_ratio, drop, attn_drop, dpr[i])
            for i in range(depth)
        ])
        self.norm = nn.LayerNorm(emb_dim)
        self.head = nn.Linear(emb_dim, num_classes)

        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init)

    @staticmethod
    def _init(m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None: nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight); nn.init.zeros_(m.bias)

    def forward(self, x):
        B = x.size(0)
        x = self.patch_embed(x)
        cls = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls, x], dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks: x = blk(x)
        x = self.norm(x)[:, 0]      # CLS
        x = self.head(x)
        return x

# -----------------------------
# Train / Eval
# -----------------------------
@dataclass
class TrainCfg:
    epochs: int = 20
    lr: float = 3e-4
    warmup_epochs: int = 2
    weight_decay: float = 0.05
    label_smoothing: float = 0.1
    grad_clip: float = 1.0

def train_one_epoch(model, loader, opt, device, loss_fn, scheduler=None):
    model.train()
    total_acc, total = 0.0, 0
    iterator = tqdm(loader, leave=False) if TQDM else loader
    for x, y in iterator:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        opt.zero_grad()
        logits = model(x)
        loss = loss_fn(logits, y)
        loss.backward()
        if scheduler is not None and hasattr(scheduler, "optimizer"):  # for some schedulers
            pass
        if math.isfinite(cfg.grad_clip): nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
        opt.step()
        total_acc += accuracy(logits.detach(), y) * x.size(0)
        total += x.size(0)
    if scheduler is not None:
        scheduler.step()
    return total_acc / max(total,1)

@torch.no_grad()
def evaluate(model, loader, device, loss_fn=None):
    model.eval()
    total_acc, total, total_loss = 0.0, 0, 0.0
    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        logits = model(x)
        total_acc += accuracy(logits, y) * x.size(0)
        if loss_fn is not None:
            total_loss += F.cross_entropy(logits, y, reduction="sum").item()
        total += x.size(0)
    return total_acc / max(total,1), (total_loss / max(total,1) if loss_fn is not None else None)

# -----------------------------
# NEFF helpers (weights-only)
# -----------------------------
def per_row_neff_from_weight(W: torch.Tensor) -> torch.Tensor:
    # W: [out, in]
    absW = W.abs()
    denom = absW.sum(dim=1, keepdim=True).clamp_min(1e-12)
    p = absW / denom
    neff = torch.floor(1.0 / p.pow(2).sum(dim=1)).clamp(min=1, max=W.size(1))
    return neff  # [out]

def get_linear_mask_per_row(module: nn.Linear) -> torch.Tensor:
    W = module.weight.data
    out, in_ = W.shape
    absW = W.abs()
    p = absW / absW.sum(dim=1, keepdim=True).clamp_min(1e-12)
    neff = torch.floor(1.0 / (p.pow(2).sum(dim=1))).clamp_(min=1, max=in_).long()
    contrib = absW  # weight-only
    _, idx = torch.sort(contrib, dim=1, descending=True)
    ranks = torch.arange(in_, device=W.device).unsqueeze(0).expand_as(idx)
    mask_sorted = ranks < neff.unsqueeze(1)
    mask = torch.zeros_like(W, dtype=torch.bool)
    mask.scatter_(1, idx, mask_sorted)
    return mask

def prune_model_neff_per_row(model: nn.Module, renorm=False) -> Tuple[nn.Module, List[Tuple[str,float,float]]]:
    pruned = copy.deepcopy(model)
    layer_info = []
    with torch.no_grad():
        for name, m in pruned.named_modules():
            if isinstance(m, nn.Linear):
                mask = get_linear_mask_per_row(m).to(m.weight.device)
                m.weight.mul_(mask)
                if renorm:
                    s = m.weight.abs().sum(dim=1, keepdim=True).clamp_min(1e-12)
                    m.weight.div_(s)  # optional
                neff = per_row_neff_from_weight(m.weight).float().mean().item()
                layer_spars = count_zeros(m.weight) / m.weight.numel()
                layer_info.append((f"{name}", neff, layer_spars))
    return pruned, layer_info

# -----------------------------
# Activation-aware EMP (W*x)
# -----------------------------
@torch.no_grad()
def collect_activation_means(model: nn.Module, loader, device, num_batches=8) -> Dict[nn.Linear, torch.Tensor]:
    """
    For each Linear layer, estimate E[|x|] per input channel (last dim).
    """
    stats_sum: Dict[nn.Linear, torch.Tensor] = {}
    stats_count: Dict[nn.Linear, int] = {}
    handles = []

    def register(m: nn.Linear):
        stats_sum[m] = torch.zeros(m.in_features, device=device)
        stats_count[m] = 0
        def hook(mod, inp, out):
            x = inp[0].detach()
            # reduce over all dims except the last (feature dim)
            red_dims = tuple(range(x.dim() - 1))
            s = x.abs().sum(dim=red_dims)
            stats_sum[mod] += s
            stats_count[mod] += (x.numel() // mod.in_features)
        return m.register_forward_hook(hook)

    for m in model.modules():
        if isinstance(m, nn.Linear):
            handles.append(register(m))

    model.eval()
    it = iter(loader)
    for b in range(num_batches):
        try:
            x, _ = next(it)
        except StopIteration:
            break
        x = x.to(device, non_blocking=True)
        model(x)

    for h in handles: h.remove()
    means = {m: (stats_sum[m] / max(stats_count[m],1)).clamp_min(1e-12) for m in stats_sum.keys()}
    return means

def prune_model_emp_activation(model: nn.Module, calib_loader, device, num_calib_batches=8, renorm=False):
    """
    Activation-aware EMP pruning: keep, in each row, the top N_eff elements using
    p_ij ∝ |w_ij| * E[|x_j|]
    """
    pruned = copy.deepcopy(model).to(device)
    act_means = collect_activation_means(pruned, calib_loader, device, num_batches=num_calib_batches)

    layer_info = []
    with torch.no_grad():
        for name, m in pruned.named_modules():
            if isinstance(m, nn.Linear):
                W = m.weight.data
                out, in_ = W.shape
                mean_abs_x = act_means[m]  # [in]
                contrib = W.abs() * mean_abs_x.unsqueeze(0)  # [out, in]
                denom = contrib.sum(dim=1, keepdim=True).clamp_min(1e-12)
                p = contrib / denom
                neff = torch.floor(1.0 / p.pow(2).sum(dim=1)).clamp_(min=1, max=in_).long()

                # build mask row-wise
                _, idx = torch.sort(contrib, dim=1, descending=True)
                ranks = torch.arange(in_, device=W.device).unsqueeze(0).expand_as(idx)
                mask_sorted = ranks < neff.unsqueeze(1)
                mask = torch.zeros_like(W, dtype=torch.bool)
                mask.scatter_(1, idx, mask_sorted)
                m.weight.mul_(mask)
                if renorm:
                    s = m.weight.abs().sum(dim=1, keepdim=True).clamp_min(1e-12)
                    m.weight.div_(s)

                # report
                avg_neff = neff.float().mean().item()
                layer_spars = count_zeros(m.weight) / m.weight.numel()
                layer_info.append((f"{name}", avg_neff, layer_spars))

    return pruned, layer_info

# -----------------------------
# Reporting helpers
# -----------------------------
def print_layerwise_report(tag: str, layer_info: List[Tuple[str,float,float]]):
    print(f"Layer-wise {tag} summary:")
    for (name, neff_avg, spars) in layer_info:
        # try to parse layer shape for nicer printing if possible
        print(f" {name:<30s} | avg N_eff={neff_avg:5.1f} | sparsity={100*spars:5.1f}%")
    print()

def eval_and_report(model, test_loader, device, tag="Model"):
    acc, loss = evaluate(model, test_loader, device, loss_fn=True)
    print(f"{tag:>20s} — test acc {acc*100:5.2f}%  (loss {loss:.4f})")
    return acc, loss

# -----------------------------
# Main
# -----------------------------
if __name__ == "__main__":
    set_seed(123)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, test_loader = get_cifar10(batch_size=128, num_workers=2)

    # Build a stronger baseline than a tiny ViT, but still lightweight
    model = ViT(
        img_size=32, patch=4, emb_dim=256, depth=6, num_heads=8,
        mlp_ratio=4.0, drop=0.1, attn_drop=0.1, drop_path=0.1, num_classes=10
    ).to(device)

    cfg = TrainCfg(epochs=20, lr=3e-4, warmup_epochs=2, weight_decay=0.05, label_smoothing=0.1, grad_clip=1.0)

    # Optimizer + cosine schedule with warmup
    opt = optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay, betas=(0.9, 0.999))
    sched = CosineAnnealingLR(opt, T_max=cfg.epochs - cfg.warmup_epochs)
    loss_fn = nn.CrossEntropyLoss(label_smoothing=cfg.label_smoothing)

    # Warmup loop
    warmup_steps = cfg.warmup_epochs * len(train_loader)
    if warmup_steps > 0:
        warmup_factor = cfg.lr / warmup_steps
        cur_lr = 0.0

    print("Training baseline ViT…")
    for epoch in range(1, cfg.epochs + 1):
        if epoch <= cfg.warmup_epochs:
            # linear warmup
            for pgroup in opt.param_groups:
                pgroup["lr"] = min(cfg.lr, (epoch-1)*len(train_loader)*warmup_factor + warmup_factor*1)
        else:
            # cosine decay after warmup
            pass

        train_acc = train_one_epoch(model, train_loader, opt, device, loss_fn,
                                    scheduler=None if epoch <= cfg.warmup_epochs else sched)
        test_acc, _ = evaluate(model, test_loader, device, loss_fn=True)
        print(f"Epoch {epoch:02d}: train acc {train_acc*100:5.2f}%  test acc {test_acc*100:5.2f}%")

    print(f"Base model sparsity: {model_sparsity(model):.4f}")
    base_acc, base_loss = eval_and_report(model, test_loader, device, tag="Baseline ViT")

    # ---------------- Row-NEFF pruning ----------------
    row_model, row_info = prune_model_neff_per_row(model, renorm=False)
    print_layerwise_report("Row-NEFF", row_info)
    print(f"Row-NEFF sparsity: {model_sparsity(row_model):.4f}")
    eval_and_report(row_model, test_loader, device, tag="Row-NEFF pruned ViT")

    # --------------- Activation EMP (W*x) -------------
    # Use a small calibration subset (first few batches of training loader)
    emp_model, emp_info = prune_model_emp_activation(model, calib_loader=train_loader,
                                                     device=device, num_calib_batches=8, renorm=False)
    print_layerwise_report("EMP (activation)", emp_info)
    print(f"EMP sparsity: {model_sparsity(emp_model):.4f}")
    eval_and_report(emp_model, test_loader, device, tag="EMP-pruned ViT")


Training baseline ViT…


                                                 

Epoch 01: train acc 14.28%  test acc 23.67%


                                                 

Epoch 02: train acc 28.55%  test acc 36.53%


                                                 

Epoch 03: train acc 38.99%  test acc 47.52%


                                                 

Epoch 04: train acc 45.62%  test acc 52.58%


                                                 

Epoch 05: train acc 49.84%  test acc 53.18%


                                                 

Epoch 06: train acc 52.63%  test acc 58.29%


                                                 

Epoch 07: train acc 54.49%  test acc 59.13%


                                                 

Epoch 08: train acc 55.90%  test acc 61.06%


                                                 

Epoch 09: train acc 57.65%  test acc 62.42%


                                                 

Epoch 10: train acc 58.92%  test acc 63.18%


                                                 

Epoch 11: train acc 59.85%  test acc 64.63%


                                                 

Epoch 12: train acc 61.29%  test acc 64.75%


                                                 

Epoch 13: train acc 62.51%  test acc 66.06%


                                                 

Epoch 14: train acc 63.31%  test acc 66.06%


                                                 

Epoch 15: train acc 64.43%  test acc 66.31%


                                                 

Epoch 16: train acc 64.64%  test acc 67.69%


                                                 

Epoch 17: train acc 65.40%  test acc 68.43%


                                                 

Epoch 18: train acc 65.76%  test acc 68.71%


                                                 

Epoch 19: train acc 66.04%  test acc 69.22%


                                                 

Epoch 20: train acc 66.18%  test acc 69.18%
Base model sparsity: 0.0000
        Baseline ViT — test acc 69.18%  (loss 0.8982)
Layer-wise Row-NEFF summary:
 blocks.0.attn.qkv              | avg N_eff=133.5 | sparsity= 36.4%
 blocks.0.attn.proj             | avg N_eff=134.1 | sparsity= 36.2%
 blocks.0.mlp.fc1               | avg N_eff=133.5 | sparsity= 36.4%
 blocks.0.mlp.fc2               | avg N_eff=534.7 | sparsity= 36.4%
 blocks.1.attn.qkv              | avg N_eff=133.7 | sparsity= 36.3%
 blocks.1.attn.proj             | avg N_eff=133.1 | sparsity= 36.6%
 blocks.1.mlp.fc1               | avg N_eff=133.6 | sparsity= 36.4%
 blocks.1.mlp.fc2               | avg N_eff=534.3 | sparsity= 36.4%
 blocks.2.attn.qkv              | avg N_eff=133.8 | sparsity= 36.3%
 blocks.2.attn.proj             | avg N_eff=133.2 | sparsity= 36.5%
 blocks.2.mlp.fc1               | avg N_eff=133.4 | sparsity= 36.4%
 blocks.2.mlp.fc2               | avg N_eff=534.7 | sparsity= 36.4%
 blocks.3.attn.qkv           

In [None]:
"""
NEFF Pruning for Language Models: Complete Implementation
Includes perplexity evaluation for LLMs and full BERT experiments
For ICLR 2026 submission
"""

import math
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
import numpy as np
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM, AutoTokenizer,
    BertForSequenceClassification, BertTokenizer,
    AutoModelForMaskedLM,
    get_linear_schedule_with_warmup
)
from datasets import load_dataset
from torch.optim import AdamW
import warnings
warnings.filterwarnings('ignore')

# =====================================
# Core NEFF Pruning Functions
# =====================================

def compute_neff(weights: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """
    Compute N_eff = floor(1 / sum(p_i^2)) where p_i = |w_i| / sum(|w_j|)
    
    Mathematical foundation:
    - Effective number of parameters based on weight distribution entropy
    - Captures the "effective dimensionality" of each neuron
    """
    abs_weights = weights.abs()
    weight_sum = abs_weights.sum(dim=dim, keepdim=True).clamp_min(1e-12)
    p = abs_weights / weight_sum
    neff = torch.floor(1.0 / (p.pow(2).sum(dim=dim))).clamp(min=1, max=weights.shape[dim])
    return neff

def get_neff_mask_per_row(weight: torch.Tensor) -> torch.Tensor:
    """Generate pruning mask based on per-row NEFF"""
    out_dim, in_dim = weight.shape
    abs_weight = weight.abs()
    
    # Compute NEFF for each row
    neff = compute_neff(weight, dim=1).long()
    
    # Sort by magnitude and keep top NEFF elements per row
    _, indices = torch.sort(abs_weight, dim=1, descending=True)
    ranks = torch.arange(in_dim, device=weight.device).unsqueeze(0).expand_as(indices)
    mask_sorted = ranks < neff.unsqueeze(1)
    
    # Scatter back to original positions
    mask = torch.zeros_like(weight, dtype=torch.bool)
    mask.scatter_(1, indices, mask_sorted)
    return mask

@torch.no_grad()
def collect_activation_statistics(
    model: nn.Module,
    dataloader: DataLoader,
    num_batches: int = 32,
    model_type: str = "causal_lm"
) -> Dict[nn.Linear, torch.Tensor]:
    """Collect activation statistics for EMP pruning"""
    stats_sum = {}
    stats_count = {}
    handles = []
    
    def register_hook(module: nn.Linear):
        stats_sum[module] = torch.zeros(module.in_features, device=next(module.parameters()).device)
        stats_count[module] = 0
        
        def hook(mod, inp, out):
            x = inp[0].detach()
            if x.dim() == 3:  # [batch, seq, features]
                x = x.reshape(-1, x.size(-1))
            elif x.dim() == 2:  # [batch, features]
                pass
            else:
                return
            
            stats_sum[mod] += x.abs().sum(dim=0)
            stats_count[mod] += x.size(0)
        
        return module.register_forward_hook(hook)
    
    # Register hooks for all Linear layers
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and 'lm_head' not in name:
            handles.append(register_hook(module))
    
    # Run calibration
    model.eval()
    for i, batch in enumerate(tqdm(dataloader, desc="Collecting activations", total=min(num_batches, len(dataloader)))):
        if i >= num_batches:
            break
        
        if model_type == "causal_lm":
            inputs = batch['input_ids'].to(next(model.parameters()).device)
            with torch.no_grad():
                model(inputs)
        elif model_type == "bert_classification":
            inputs = {k: v.to(next(model.parameters()).device) 
                     for k, v in batch.items() if k in ['input_ids', 'attention_mask']}
            with torch.no_grad():
                model(**inputs)
    
    # Clean up hooks
    for handle in handles:
        handle.remove()
    
    # Compute means
    activation_means = {
        module: (stats_sum[module] / max(stats_count[module], 1)).clamp_min(1e-12)
        for module in stats_sum
    }
    
    return activation_means

def prune_model_neff(
    model: nn.Module,
    method: str = "weight_only",
    activation_stats: Optional[Dict] = None,
    keep_modules: List[str] = ['lm_head', 'classifier', 'embeddings']
) -> Tuple[nn.Module, Dict]:
    """
    Apply NEFF pruning to model
    
    Args:
        model: Model to prune
        method: "weight_only" for standard NEFF, "emp" for activation-aware
        activation_stats: Required for EMP method
        keep_modules: Module names to skip pruning
    """
    pruned_model = copy.deepcopy(model)
    pruning_info = {
        'layer_sparsity': {},
        'layer_neff': {},
        'total_params': 0,
        'pruned_params': 0
    }
    
    with torch.no_grad():
        for name, module in pruned_model.named_modules():
            # Skip non-Linear and protected modules
            if not isinstance(module, nn.Linear):
                continue
            if any(keep_name in name for keep_name in keep_modules):
                continue
            
            weight = module.weight.data
            original_weight = weight.clone()
            
            if method == "weight_only":
                mask = get_neff_mask_per_row(weight)
            elif method == "emp" and activation_stats is not None:
                # Activation-aware EMP
                if module in activation_stats:
                    act_mean = activation_stats[module]
                    effective_weight = weight.abs() * act_mean.unsqueeze(0)
                    
                    # Compute NEFF with activation weighting
                    weight_sum = effective_weight.sum(dim=1, keepdim=True).clamp_min(1e-12)
                    p = effective_weight / weight_sum
                    neff = torch.floor(1.0 / (p.pow(2).sum(dim=1))).clamp(min=1, max=weight.size(1)).long()
                    
                    # Create mask based on effective contributions
                    _, indices = torch.sort(effective_weight, dim=1, descending=True)
                    ranks = torch.arange(weight.size(1), device=weight.device).unsqueeze(0).expand_as(indices)
                    mask_sorted = ranks < neff.unsqueeze(1)
                    mask = torch.zeros_like(weight, dtype=torch.bool)
                    mask.scatter_(1, indices, mask_sorted)
                else:
                    mask = get_neff_mask_per_row(weight)
            else:
                raise ValueError(f"Unknown method: {method}")
            
            # Apply mask
            module.weight.data = weight * mask.float()
            
            # Calculate statistics
            num_zeros = (module.weight.data == 0).sum().item()
            num_params = weight.numel()
            sparsity = num_zeros / num_params
            avg_neff = compute_neff(module.weight.data, dim=1).float().mean().item()
            
            pruning_info['layer_sparsity'][name] = sparsity
            pruning_info['layer_neff'][name] = avg_neff
            pruning_info['total_params'] += num_params
            pruning_info['pruned_params'] += num_zeros
    
    pruning_info['overall_sparsity'] = pruning_info['pruned_params'] / pruning_info['total_params']
    
    return pruned_model, pruning_info

# =====================================
# Perplexity Evaluation for LLMs
# =====================================

@torch.no_grad()
def evaluate_perplexity(
    model: nn.Module,
    tokenizer,
    dataset_name: str = "wikitext",
    dataset_config: str = "wikitext-2-raw-v1",
    split: str = "test",
    max_length: int = 1024,
    stride: int = 512,
    batch_size: int = 1,
    max_samples: Optional[int] = None
) -> float:
    """
    Evaluate perplexity on a dataset
    
    Perplexity = exp(average_negative_log_likelihood)
    Lower is better!
    """
    model.eval()
    device = next(model.parameters()).device
    
    # Load dataset
    dataset = load_dataset(dataset_name, dataset_config, split=split)
    if max_samples:
        dataset = dataset.select(range(min(max_samples, len(dataset))))
    
    encodings = tokenizer("\n\n".join(dataset["text"]), return_tensors="pt")
    
    total_loss = 0
    total_tokens = 0
    
    # Sliding window approach
    for i in tqdm(range(0, encodings.input_ids.size(1), stride), desc="Computing perplexity"):
        begin_loc = max(i + stride - max_length, 0)
        end_loc = min(i + stride, encodings.input_ids.size(1))
        trg_len = end_loc - i
        
        input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
        target_ids = input_ids.clone()
        target_ids[:, :-trg_len] = -100  # Mask non-target tokens
        
        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)
            loss = outputs.loss * trg_len
            
        total_loss += loss.item()
        total_tokens += trg_len
        
        if end_loc == encodings.input_ids.size(1):
            break
    
    perplexity = math.exp(total_loss / total_tokens)
    return perplexity

# =====================================
# BERT Experiments
# =====================================

class BertExperiments:
    def __init__(self, model_name: str = "bert-base-uncased", task: str = "sst2"):
        self.model_name = model_name
        self.task = task
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # Load model and tokenizer
        if task in ["sst2", "mrpc", "cola", "qqp"]:
            self.model = BertForSequenceClassification.from_pretrained(
                model_name, num_labels=2
            ).to(self.device)
        else:
            raise ValueError(f"Unsupported task: {task}")
        
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        
        # Load dataset
        self.dataset = self._load_dataset()
    
    def _load_dataset(self):
        """Load GLUE dataset"""
        task_to_dataset = {
            "sst2": ("glue", "sst2"),
            "mrpc": ("glue", "mrpc"),
            "cola": ("glue", "cola"),
            "qqp": ("glue", "qqp")
        }
        
        dataset_name, config = task_to_dataset[self.task]
        dataset = load_dataset(dataset_name, config)
        return dataset
    
    def preprocess_function(self, examples):
        """Tokenize dataset"""
        if self.task == "sst2":
            return self.tokenizer(
                examples["sentence"],
                padding="max_length",
                truncation=True,
                max_length=128
            )
        elif self.task in ["mrpc", "qqp"]:
            return self.tokenizer(
                examples["sentence1" if self.task == "mrpc" else "question1"],
                examples["sentence2" if self.task == "mrpc" else "question2"],
                padding="max_length",
                truncation=True,
                max_length=128
            )
        elif self.task == "cola":
            return self.tokenizer(
                examples["sentence"],
                padding="max_length",
                truncation=True,
                max_length=128
            )
    
    def create_dataloader(self, split: str = "validation", batch_size: int = 32):
        """Create DataLoader for evaluation"""
        dataset = self.dataset[split]
        
        # Tokenize
        dataset = dataset.map(self.preprocess_function, batched=True)
        dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
        
        return DataLoader(dataset, batch_size=batch_size, shuffle=False)
    
    @torch.no_grad()
    def evaluate_accuracy(self, model: nn.Module, dataloader: DataLoader) -> Dict:
        """Evaluate model accuracy"""
        model.eval()
        correct = 0
        total = 0
        total_loss = 0
        
        for batch in tqdm(dataloader, desc="Evaluating"):
            inputs = {
                "input_ids": batch["input_ids"].to(self.device),
                "attention_mask": batch["attention_mask"].to(self.device),
                "labels": batch["label"].to(self.device)
            }
            
            outputs = model(**inputs)
            logits = outputs.logits
            loss = outputs.loss
            
            predictions = torch.argmax(logits, dim=-1)
            correct += (predictions == batch["label"].to(self.device)).sum().item()
            total += len(batch["label"])
            total_loss += loss.item() * len(batch["label"])
        
        accuracy = correct / total
        avg_loss = total_loss / total
        
        return {
            "accuracy": accuracy,
            "loss": avg_loss,
            "correct": correct,
            "total": total
        }
    
    @torch.no_grad()
    def evaluate_mlm_perplexity(self, model: nn.Module, max_samples: int = 1000) -> float:
        """Evaluate masked language model perplexity for BERT"""
        # Use BERT for MLM
        mlm_model = AutoModelForMaskedLM.from_pretrained(self.model_name).to(self.device)
        
        # Copy pruned weights to MLM model (excluding pooler which MLM doesn't have)
        if hasattr(model, 'bert'):
            source_state = model.bert.state_dict()
            target_state = mlm_model.bert.state_dict()
            
            # Only copy weights that exist in both models
            filtered_state = {k: v for k, v in source_state.items() 
                            if k in target_state and 'pooler' not in k}
            
            mlm_model.bert.load_state_dict(filtered_state, strict=False)
        
        mlm_model.eval()
        
        # Load WikiText for perplexity
        dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
        if max_samples:
            dataset = dataset.select(range(min(max_samples, len(dataset))))
        
        total_loss = 0
        total_predictions = 0
        
        for idx, text in enumerate(tqdm(dataset["text"], desc="Computing MLM perplexity", total=min(max_samples, len(dataset["text"])))):
            if idx >= max_samples:
                break
            if not text.strip():
                continue
            
            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                max_length=512,
                truncation=True,
                padding=True
            ).to(self.device)
            
            # Create masked inputs (15% masking)
            input_ids = inputs.input_ids.clone()
            labels = inputs.input_ids.clone()
            
            # Random masking
            rand = torch.rand(input_ids.shape).to(self.device)
            mask_arr = (rand < 0.15) * (input_ids != self.tokenizer.pad_token_id)
            
            for i in range(input_ids.shape[0]):
                selection = torch.where(mask_arr[i])[0]
                if len(selection) > 0:
                    input_ids[i, selection] = self.tokenizer.mask_token_id
                    # Only compute loss on masked tokens
                    labels[i, ~mask_arr[i]] = -100
            
            outputs = mlm_model(input_ids=input_ids, labels=labels)
            
            if outputs.loss is not None:
                total_loss += outputs.loss.item() * mask_arr.sum().item()
                total_predictions += mask_arr.sum().item()
        
        if total_predictions > 0:
            perplexity = math.exp(total_loss / total_predictions)
        else:
            perplexity = float('inf')
        
        return perplexity
    
    def run_full_evaluation(self, calibration_samples: int = 128):
        """Run complete evaluation pipeline"""
        print(f"\n{'='*50}")
        print(f"BERT Experiments on {self.task.upper()}")
        print(f"{'='*50}\n")
        
        # Create evaluation dataloader
        eval_dataloader = self.create_dataloader("validation", batch_size=32)
        calib_dataloader = self.create_dataloader("train", batch_size=8)
        
        # 1. Evaluate original model
        print("1. Evaluating Original BERT Model")
        print("-" * 30)
        orig_results = self.evaluate_accuracy(self.model, eval_dataloader)
        orig_perplexity = self.evaluate_mlm_perplexity(self.model, max_samples=100)
        
        print(f"Accuracy: {orig_results['accuracy']:.4f}")
        print(f"Loss: {orig_results['loss']:.4f}")
        print(f"MLM Perplexity: {orig_perplexity:.2f}")
        print(f"Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
        
        # 2. NEFF Weight-only Pruning
        print("\n2. NEFF Weight-only Pruning")
        print("-" * 30)
        neff_model, neff_info = prune_model_neff(
            self.model,
            method="weight_only",
            keep_modules=['embeddings', 'classifier', 'pooler']
        )
        
        neff_results = self.evaluate_accuracy(neff_model, eval_dataloader)
        neff_perplexity = self.evaluate_mlm_perplexity(neff_model, max_samples=100)
        
        print(f"Accuracy: {neff_results['accuracy']:.4f} (Δ: {(neff_results['accuracy'] - orig_results['accuracy'])*100:.2f}%)")
        print(f"Loss: {neff_results['loss']:.4f}")
        print(f"MLM Perplexity: {neff_perplexity:.2f} (Δ: {neff_perplexity - orig_perplexity:.2f})")
        print(f"Overall Sparsity: {neff_info['overall_sparsity']:.2%}")
        
        # 3. EMP Activation-aware Pruning
        print("\n3. EMP Activation-aware Pruning")
        print("-" * 30)
        
        # Collect activation statistics
        activation_stats = collect_activation_statistics(
            self.model,
            calib_dataloader,
            num_batches=calibration_samples // 8,
            model_type="bert_classification"
        )
        
        emp_model, emp_info = prune_model_neff(
            self.model,
            method="emp",
            activation_stats=activation_stats,
            keep_modules=['embeddings', 'classifier', 'pooler']
        )
        
        emp_results = self.evaluate_accuracy(emp_model, eval_dataloader)
        emp_perplexity = self.evaluate_mlm_perplexity(emp_model, max_samples=100)
        
        print(f"Accuracy: {emp_results['accuracy']:.4f} (Δ: {(emp_results['accuracy'] - orig_results['accuracy'])*100:.2f}%)")
        print(f"Loss: {emp_results['loss']:.4f}")
        print(f"MLM Perplexity: {emp_perplexity:.2f} (Δ: {emp_perplexity - orig_perplexity:.2f})")
        print(f"Overall Sparsity: {emp_info['overall_sparsity']:.2%}")
        
        # Summary comparison
        print(f"\n{'='*50}")
        print("SUMMARY COMPARISON")
        print(f"{'='*50}")
        print(f"{'Method':<20} {'Accuracy':<12} {'MLM PPL':<12} {'Sparsity':<12}")
        print("-" * 56)
        print(f"{'Original':<20} {orig_results['accuracy']:.4f} {orig_perplexity:>11.2f} {'0.00%':>11}")
        print(f"{'NEFF (weight)':<20} {neff_results['accuracy']:.4f} {neff_perplexity:>11.2f} {neff_info['overall_sparsity']:>11.2%}")
        print(f"{'EMP (activation)':<20} {emp_results['accuracy']:.4f} {emp_perplexity:>11.2f} {emp_info['overall_sparsity']:>11.2%}")
        
        return {
            "original": {"accuracy": orig_results['accuracy'], "perplexity": orig_perplexity},
            "neff": {"accuracy": neff_results['accuracy'], "perplexity": neff_perplexity, "sparsity": neff_info['overall_sparsity']},
            "emp": {"accuracy": emp_results['accuracy'], "perplexity": emp_perplexity, "sparsity": emp_info['overall_sparsity']}
        }

# =====================================
# LLM Perplexity Testing
# =====================================

def test_llm_perplexity(
    model_name: str = "facebook/opt-125m",
    calibration_samples: int = 128,
    eval_samples: int = 100
):
    """Test perplexity on a small LLM with different pruning methods"""
    print(f"\n{'='*50}")
    print(f"LLM Perplexity Evaluation: {model_name}")
    print(f"{'='*50}\n")
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Load model and tokenizer
    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    # 1. Original model perplexity
    print("\n1. Original Model")
    print("-" * 30)
    orig_ppl = evaluate_perplexity(
        model, tokenizer,
        dataset_name="wikitext",
        dataset_config="wikitext-2-raw-v1",
        max_samples=eval_samples
    )
    print(f"Perplexity: {orig_ppl:.2f}")
    print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    # Prepare calibration data
    calib_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    calib_texts = [text for text in calib_dataset["text"][:calibration_samples] if text.strip()]
    
    calib_data = []
    for text in calib_texts[:calibration_samples]:
        tokens = tokenizer(text, return_tensors="pt", max_length=512, truncation=True, padding=True)
        calib_data.append({"input_ids": tokens["input_ids"]})
    
    calib_loader = DataLoader(calib_data, batch_size=1, shuffle=False)
    
    # 2. NEFF weight-only pruning
    print("\n2. NEFF Weight-only Pruning")
    print("-" * 30)
    neff_model, neff_info = prune_model_neff(
        model,
        method="weight_only",
        keep_modules=['lm_head', 'embed_tokens', 'embed_positions']
    )
    
    neff_ppl = evaluate_perplexity(
        neff_model, tokenizer,
        dataset_name="wikitext",
        dataset_config="wikitext-2-raw-v1",
        max_samples=eval_samples
    )
    print(f"Perplexity: {neff_ppl:.2f} (Δ: +{neff_ppl - orig_ppl:.2f})")
    print(f"Sparsity: {neff_info['overall_sparsity']:.2%}")
    
    # 3. EMP activation-aware pruning
    print("\n3. EMP Activation-aware Pruning")
    print("-" * 30)
    
    activation_stats = collect_activation_statistics(
        model,
        calib_loader,
        num_batches=min(32, len(calib_loader)),
        model_type="causal_lm"
    )
    
    emp_model, emp_info = prune_model_neff(
        model,
        method="emp",
        activation_stats=activation_stats,
        keep_modules=['lm_head', 'embed_tokens', 'embed_positions']
    )
    
    emp_ppl = evaluate_perplexity(
        emp_model, tokenizer,
        dataset_name="wikitext",
        dataset_config="wikitext-2-raw-v1",
        max_samples=eval_samples
    )
    print(f"Perplexity: {emp_ppl:.2f} (Δ: +{emp_ppl - orig_ppl:.2f})")
    print(f"Sparsity: {emp_info['overall_sparsity']:.2%}")
    
    # Summary
    print(f"\n{'='*50}")
    print("SUMMARY - LLM Perplexity Comparison")
    print(f"{'='*50}")
    print(f"{'Method':<20} {'Perplexity':<15} {'Sparsity':<15}")
    print("-" * 50)
    print(f"{'Original':<20} {orig_ppl:<15.2f} {'0.00%':<15}")
    print(f"{'NEFF (weight)':<20} {neff_ppl:<15.2f} {neff_info['overall_sparsity']:<15.2%}")
    print(f"{'EMP (activation)':<20} {emp_ppl:<15.2f} {emp_info['overall_sparsity']:<15.2%}")
    
    return {
        "original": orig_ppl,
        "neff": {"perplexity": neff_ppl, "sparsity": neff_info['overall_sparsity']},
        "emp": {"perplexity": emp_ppl, "sparsity": emp_info['overall_sparsity']}
    }

# =====================================
# Main Execution
# =====================================

if __name__ == "__main__":
    # Set random seed for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)
    
    # Test 1: LLM Perplexity
    print("\n[TEST 1: Language Model Perplexity]")
    llm_results = test_llm_perplexity(
        model_name="facebook/opt-125m",  # Small model for demo
        calibration_samples=128,
        eval_samples=100  # Increase for paper
    )
    
    # Test 2: BERT on GLUE tasks
    print("\n[TEST 2: BERT Classification Tasks]")
    
    # Run on multiple tasks for comprehensive evaluation
    tasks = ["sst2", "mrpc"]  # Add more: "cola", "qqp" for paper
    all_bert_results = {}
    
    for task in tasks:
        bert_exp = BertExperiments(
            model_name="bert-base-uncased",
            task=task
        )
        results = bert_exp.run_full_evaluation(calibration_samples=128)
        all_bert_results[task] = results
    
    # Final summary
    print("\n" + "=" * 60)
    print("FINAL RESULTS SUMMARY")
    print("=" * 60)
    
    print("\nLLM Perplexity Results:")
    print(f"  NEFF achieves {llm_results['neff']['sparsity']:.1%} sparsity with {(llm_results['neff']['perplexity']/llm_results['original'] - 1)*100:.1f}% perplexity increase")
    print(f"  EMP achieves {llm_results['emp']['sparsity']:.1%} sparsity with {(llm_results['emp']['perplexity']/llm_results['original'] - 1)*100:.1f}% perplexity increase")
    
    print("\nBERT Results Summary:")
    for task, results in all_bert_results.items():
        print(f"\n  {task.upper()}:")
        orig_acc = results['original']['accuracy']
        neff_acc = results['neff']['accuracy']
        emp_acc = results['emp']['accuracy']
        print(f"    NEFF: {results['neff']['sparsity']:.1%} sparsity, {(neff_acc/orig_acc - 1)*100:.2f}% accuracy change")
        print(f"    EMP:  {results['emp']['sparsity']:.1%} sparsity, {(emp_acc/orig_acc - 1)*100:.2f}% accuracy change")
    

NEFF PRUNING: Complete Evaluation Suite
For ICLR 2026 Submission

[TEST 1: Language Model Perplexity]

LLM Perplexity Evaluation: facebook/opt-125m

Loading model...

1. Original Model
------------------------------


Computing perplexity:  91%|█████████ | 10/11 [00:00<00:00, 70.34it/s]


Perplexity: 28.11
Parameters: 125,239,296

2. NEFF Weight-only Pruning
------------------------------


Computing perplexity:  91%|█████████ | 10/11 [00:00<00:00, 68.53it/s]


Perplexity: 55.25 (Δ: +27.14)
Sparsity: 40.18%

3. EMP Activation-aware Pruning
------------------------------


Collecting activations: 100%|██████████| 32/32 [00:00<00:00, 165.00it/s]
Computing perplexity:  91%|█████████ | 10/11 [00:00<00:00, 68.89it/s]
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Perplexity: 55.25 (Δ: +27.14)
Sparsity: 40.18%

SUMMARY - LLM Perplexity Comparison
Method               Perplexity      Sparsity       
--------------------------------------------------
Original             28.11           0.00%          
NEFF (weight)        55.25           40.18%         
EMP (activation)     55.25           40.18%         

[TEST 2: BERT Classification Tasks]

BERT Experiments on SST2



Map:   0%|          | 0/872 [00:00<?, ? examples/s]

Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

1. Evaluating Original BERT Model
------------------------------


Evaluating: 100%|██████████| 28/28 [00:00<00:00, 30.05it/s]
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Computing MLM perplexity: 100%|██████████| 100/100 [00:00<00:00, 266.05it/s]


Accuracy: 0.4908
Loss: 0.7275
MLM Perplexity: 15.50
Parameters: 109,483,778

2. NEFF Weight-only Pruning
------------------------------


Evaluating: 100%|██████████| 28/28 [00:00<00:00, 29.64it/s]
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Computing MLM perplexity: 100%|██████████| 100/100 [00:00<00:00, 275.74it/s]


Accuracy: 0.4908 (Δ: 0.00%)
Loss: 0.7263
MLM Perplexity: 20.07 (Δ: 4.57)
Overall Sparsity: 37.35%

3. EMP Activation-aware Pruning
------------------------------


Collecting activations: 100%|██████████| 16/16 [00:00<00:00, 93.25it/s]
Evaluating: 100%|██████████| 28/28 [00:00<00:00, 29.97it/s]
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Computing MLM perplexity: 100%|██████████| 100/100 [00:00<00:00, 292.62it/s]
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-bas

Accuracy: 0.4908 (Δ: 0.00%)
Loss: 0.7263
MLM Perplexity: 14.23 (Δ: -1.28)
Overall Sparsity: 37.35%

SUMMARY COMPARISON
Method               Accuracy     MLM PPL      Sparsity    
--------------------------------------------------------
Original             0.4908       15.50       0.00%
NEFF (weight)        0.4908       20.07      37.35%
EMP (activation)     0.4908       14.23      37.35%


mrpc/train-00000-of-00001.parquet:   0%|          | 0.00/649k [00:00<?, ?B/s]

mrpc/validation-00000-of-00001.parquet:   0%|          | 0.00/75.7k [00:00<?, ?B/s]

mrpc/test-00000-of-00001.parquet:   0%|          | 0.00/308k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/3668 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/408 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1725 [00:00<?, ? examples/s]


BERT Experiments on MRPC



Map:   0%|          | 0/408 [00:00<?, ? examples/s]

Map:   0%|          | 0/3668 [00:00<?, ? examples/s]

1. Evaluating Original BERT Model
------------------------------


Evaluating: 100%|██████████| 13/13 [00:00<00:00, 28.32it/s]
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Computing MLM perplexity: 100%|██████████| 100/100 [00:00<00:00, 288.73it/s]


Accuracy: 0.6838
Loss: 0.6257
MLM Perplexity: 18.34
Parameters: 109,483,778

2. NEFF Weight-only Pruning
------------------------------


Evaluating: 100%|██████████| 13/13 [00:00<00:00, 29.49it/s]
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Computing MLM perplexity: 100%|██████████| 100/100 [00:00<00:00, 259.23it/s]


Accuracy: 0.6838 (Δ: 0.00%)
Loss: 0.6228
MLM Perplexity: 18.44 (Δ: 0.10)
Overall Sparsity: 37.35%

3. EMP Activation-aware Pruning
------------------------------


Collecting activations: 100%|██████████| 16/16 [00:00<00:00, 93.49it/s]
Evaluating: 100%|██████████| 13/13 [00:00<00:00, 30.21it/s]
HTTP Error 429 thrown while requesting HEAD https://huggingface.co/bert-base-uncased/resolve/main/config.json
Retrying in 1s [Retry 1/5].
HTTP Error 429 thrown while requesting HEAD https://huggingface.co/bert-base-uncased/resolve/main/config.json
Retrying in 2s [Retry 2/5].
HTTP Error 429 thrown while requesting HEAD https://huggingface.co/bert-base-uncased/resolve/main/config.json
Retrying in 4s [Retry 3/5].
HTTP Error 429 thrown while requesting HEAD https://huggingface.co/bert-base-uncased/resolve/main/config.json
Retrying in 8s [Retry 4/5].
HTTP Error 429 thrown while requesting HEAD https://huggingface.co/bert-base-uncased/resolve/main/config.json
Retrying in 8s [Retry 5/5].
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relatio

Accuracy: 0.6838 (Δ: 0.00%)
Loss: 0.6228
MLM Perplexity: 20.32 (Δ: 1.98)
Overall Sparsity: 37.35%

SUMMARY COMPARISON
Method               Accuracy     MLM PPL      Sparsity    
--------------------------------------------------------
Original             0.6838       18.34       0.00%
NEFF (weight)        0.6838       18.44      37.35%
EMP (activation)     0.6838       20.32      37.35%

FINAL RESULTS SUMMARY

LLM Perplexity Results:
  NEFF achieves 40.2% sparsity with 96.5% perplexity increase
  EMP achieves 40.2% sparsity with 96.5% perplexity increase

BERT Results Summary:

  SST2:
    NEFF: 37.3% sparsity, 0.00% accuracy change
    EMP:  37.3% sparsity, 0.00% accuracy change

  MRPC:
    NEFF: 37.3% sparsity, 0.00% accuracy change
    EMP:  37.3% sparsity, 0.00% accuracy change

Ready for ICLR 2026! 🚀



