In [1]:
# ===== Tiny Vision Transformer (CIFAR-10) + Activation-EMP pruning =====
import math, copy, time, random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader

# --------- Reuse pruning utils (paste from the first block or import them) ----------

@torch.no_grad()
def model_sparsity(model: nn.Module) -> float:
    total, zeros = 0, 0
    for n, p in model.named_parameters():
        if p.dim() >= 2 and 'weight' in n:
            total += p.numel()
            zeros += (p == 0).sum().item()
    return zeros / max(total, 1)

# ---- Step 1: collect E[|x|] (per in-feature) for each Linear via forward hooks
@torch.no_grad()
def collect_input_magnitudes(model: nn.Module,
                             data_loader,
                             device,
                             num_batches: int = 10):
    model.eval()
    # list Linear modules in traversal order
    linear_list = [m for m in model.modules() if isinstance(m, nn.Linear)]
    sums = []
    counts = []
    handles = []

    for m in linear_list:
        sums.append(torch.zeros(m.in_features, device=device))
        counts.append(torch.tensor(0, device=device))

    index_of = {id(m): i for i, m in enumerate(linear_list)}

    def hook_fn(module, inputs, output):
        idx = index_of[id(module)]
        x = inputs[0].detach()
        # flatten all leading dims except last: [..., in_features]
        x2d = x.flatten(0, -2)  # (B*..., in_features)
        sums[idx] += x2d.abs().sum(dim=0)
        counts[idx] += x2d.shape[0]

    for m in linear_list:
        handles.append(m.register_forward_hook(hook_fn))

    seen = 0
    for data, target in data_loader:
        data = data.to(device)
        _ = model(data)
        seen += 1
        if seen >= num_batches:
            break

    for h in handles:
        h.remove()

    mags = [s / torch.clamp(c.float(), min=1.0) for s, c in zip(sums, counts)]
    # Return in the same order as linear_list
    return linear_list, mags

# ---- Step 2: build the activation-aware mask from N_eff on |W| * E|x|
@torch.no_grad()
def get_linear_mask_emp(module: nn.Linear,
                        in_mag: torch.Tensor) -> (torch.Tensor, torch.Tensor):
    """
    Args:
        module: nn.Linear with weight shape [out, in]
        in_mag: tensor [in] = E[|x|] for this module's input
    Returns:
        mask (bool) with same shape as weight
        neff_row (long) length = out_features
    """
    W = module.weight.data  # [out, in]
    # contributions per input to each neuron:
    contrib = W.abs() * in_mag.unsqueeze(0)  # [out, in]
    row_sum = contrib.sum(dim=1, keepdim=True).clamp(min=1e-12)
    norm = contrib / row_sum                 # \hat c_ji

    neff = torch.floor(1.0 / norm.pow(2).sum(dim=1)).clamp(min=1, max=W.shape[1]).long()  # [out]

    # sort each row by importance and keep top neff[j]
    _, idx = torch.sort(norm, dim=1, descending=True)
    out, in_ = W.shape
    ranks = torch.arange(in_, device=W.device).unsqueeze(0).expand(out, in_)
    keep_sorted = ranks < neff.unsqueeze(1)   # [out, in] (sorted order)

    mask = torch.zeros_like(W, dtype=torch.bool)
    mask.scatter_(1, idx, keep_sorted)
    return mask, neff

# ---- Step 3: prune with EMP (optional L1 row re-normalization)
@torch.no_grad()
def prune_model_emp_activation(model: nn.Module,
                               calib_loader,
                               device,
                               num_calib_batches: int = 10,
                               renormalize: bool = False):
    pruned = copy.deepcopy(model).to(device)
    linear_list, mags = collect_input_magnitudes(pruned, calib_loader, device, num_batches=num_calib_batches)

    layer_neff = {}
    for lin, mu in zip(linear_list, mags):
        W = lin.weight.data
        old_row_l1 = W.abs().sum(dim=1, keepdim=True)
        mask, neff = get_linear_mask_emp(lin, mu.to(W.device))
        # apply mask
        W.mul_(mask)
        if renormalize:
            new_row_l1 = W.abs().sum(dim=1, keepdim=True).clamp(min=1e-8)
            scale = old_row_l1 / new_row_l1
            W.mul_(scale)
        layer_neff[id(lin)] = {
            "name": getattr(lin, "_emp_name", None),
            "neff_row": neff.detach().cpu(),
            "avg_neff": neff.float().mean().item(),
            "in_features": W.shape[1],
            "out_features": W.shape[0],
            "layer_sparsity": float((~mask).sum().item() / mask.numel())
        }

    return pruned, layer_neff

# ---- Helper: pretty summary
def summarize_emp(layer_neff_dict):
    lines = []
    for k, v in layer_neff_dict.items():
        name = v.get("name") or f"Linear(id={k})"
        lines.append(
            f"{name:30s} | out={v['out_features']:4d} in={v['in_features']:4d} "
            f"| avg N_eff={v['avg_neff']:.1f} | sparsity={v['layer_sparsity']*100:5.1f}%"
        )
    return "\n".join(lines)

# ---- (Optional) attach names to ease reading
def tag_linear_names(model: nn.Module):
    for name, m in model.named_modules():
        if isinstance(m, nn.Linear):
            m._emp_name = name


# --------- Model: Patch embedding, Encoder block, and ViT head ----------
class PatchEmbed(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=128):
        super().__init__()
        assert img_size % patch_size == 0
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.num_patches = (img_size // patch_size) ** 2

    def forward(self, x):                   # x: [B,3,H,W]
        x = self.proj(x)                    # [B,embed,H/P,W/P]
        x = x.flatten(2).transpose(1, 2)    # [B, N, embed]
        return x

class MLP(nn.Module):
    def __init__(self, dim, hidden_dim, p=0.0):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.drop = nn.Dropout(p)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class SelfAttention(nn.Module):
    def __init__(self, dim, num_heads=4, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):         # x: [B, N, C]
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)     # each: [B, N, H, D]
        q = q.transpose(1, 2)            # [B, H, N, D]
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * self.scale        # [B,H,N,N]
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        out = attn @ v                                       # [B,H,N,D]
        out = out.transpose(1, 2).reshape(B, N, C)           # [B,N,C]
        out = self.proj(out)
        out = self.proj_drop(out)
        return out

class EncoderBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=2.0, drop=0.0, attn_drop=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = SelfAttention(dim, num_heads, attn_drop, drop)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp  = MLP(dim, int(dim * mlp_ratio), p=drop)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class TinyViT(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_chans=3,
                 num_classes=10, embed_dim=128, depth=4, num_heads=4, mlp_ratio=2.0, drop=0.0):
        super().__init__()
        self.patch = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        self.pos = nn.Parameter(torch.zeros(1, self.patch.num_patches, embed_dim))
        self.blocks = nn.ModuleList([
            EncoderBlock(embed_dim, num_heads, mlp_ratio, drop, drop) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.patch(x) + self.pos
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        x = x.mean(dim=1)         # global average over patches (no class token)
        x = self.head(x)
        return F.log_softmax(x, dim=-1)

# --------- Data: CIFAR-10 (32x32) ----------
def get_cifar10_loaders(batch_size=128, num_workers=2):
    train_tf = T.Compose([
        T.RandomCrop(32, padding=4),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    test_tf = T.Compose([
        T.ToTensor(),
        T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    train = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_tf)
    test  = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_tf)
    return (DataLoader(train, batch_size=batch_size, shuffle=True,  num_workers=num_workers, pin_memory=True),
            DataLoader(test,  batch_size=256,       shuffle=False, num_workers=num_workers, pin_memory=True))

# --------- Train / Test ----------
def train_one_epoch(model, loader, opt, device):
    model.train()
    total, correct, total_loss = 0, 0, 0.0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        opt.zero_grad()
        out = model(x)
        loss = F.nll_loss(out, y)
        loss.backward()
        opt.step()
        total_loss += loss.item() * x.size(0)
        pred = out.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += x.size(0)
    return total_loss/total, 100*correct/total

@torch.no_grad()
def evaluate(model, loader, device):
    model.eval()
    total, correct, total_loss = 0, 0, 0.0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        out = model(x)
        loss = F.nll_loss(out, y, reduction='sum')
        total_loss += loss.item()
        pred = out.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += x.size(0)
    return total_loss/total, 100*correct/total

# --------- Run everything ----------
if __name__ == "__main__":
    torch.manual_seed(0); random.seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loader, test_loader = get_cifar10_loaders()
    model = TinyViT(embed_dim=128, depth=4, num_heads=4, mlp_ratio=2.0, drop=0.1).to(device)

    print("Params:", sum(p.numel() for p in model.parameters())/1e6, "M")
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)

    # Train a few epochs (tune as desired)
    epochs = 10
    for ep in range(1, epochs+1):
        tr_loss, tr_acc = train_one_epoch(model, train_loader, opt, device)
        te_loss, te_acc = evaluate(model, test_loader, device)
        print(f"Epoch {ep:02d}: train acc {tr_acc:5.2f}%  test acc {te_acc:5.2f}%")

    base_sparsity = model_sparsity(model)
    print(f"Base model sparsity: {base_sparsity:.4f}")

    # Activation-EMP prune using a small calibration slice of the train set
    pruned, layer_neff = prune_model_emp_activation(
        model, calib_loader=train_loader, device=device, num_calib_batches=8, renormalize=False
    )
    print("Layer-wise EMP summary:\n", summarize_emp(layer_neff))
    print(f"EMP sparsity: {model_sparsity(pruned):.4f}")

    # Evaluate pruned model
    te_loss, te_acc = evaluate(pruned, test_loader, device)
    print(f"EMP-pruned ViT — test acc {te_acc:5.2f}%  (loss {te_loss:.4f})")


Params: 0.54593 M




Epoch 01: train acc 33.46%  test acc 42.11%
Epoch 02: train acc 46.32%  test acc 49.63%
Epoch 03: train acc 52.18%  test acc 51.40%
Epoch 04: train acc 55.91%  test acc 58.05%
Epoch 05: train acc 58.67%  test acc 60.36%
Epoch 06: train acc 60.84%  test acc 62.72%
Epoch 07: train acc 62.39%  test acc 63.98%
Epoch 08: train acc 64.09%  test acc 66.32%
Epoch 09: train acc 65.35%  test acc 66.15%
Epoch 10: train acc 66.41%  test acc 65.86%
Base model sparsity: 0.0000
Layer-wise EMP summary:
 Linear(id=6060751744)          | out= 384 in= 128 | avg N_eff=86.7 | sparsity= 32.2%
Linear(id=6060746848)          | out= 128 in= 128 | avg N_eff=81.6 | sparsity= 36.3%
Linear(id=6060751216)          | out= 256 in= 128 | avg N_eff=86.8 | sparsity= 32.2%
Linear(id=6060747376)          | out= 128 in= 256 | avg N_eff=165.7 | sparsity= 35.3%
Linear(id=6060752320)          | out= 384 in= 128 | avg N_eff=88.9 | sparsity= 30.6%
Linear(id=6060750976)          | out= 128 in= 128 | avg N_eff=87.1 | sparsity= 32

In [2]:
# EMP vs Row-NEFF pruning on a Vision Transformer (CIFAR-10)
# -----------------------------------------------------------
# Requirements: torch, torchvision, numpy
# Optional: tqdm (for pretty progress bars)

import math, os, copy, random, time
from dataclasses import dataclass
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import datasets, transforms

try:
    from tqdm import tqdm
    TQDM = True
except Exception:
    TQDM = False

# -----------------------------
# Utils
# -----------------------------
def set_seed(seed: int = 1337):
    random.seed(seed); torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

def accuracy(logits, targets):
    pred = logits.argmax(dim=1)
    return (pred == targets).float().mean().item()

def count_zeros(t: torch.Tensor) -> int:
    return (t == 0).sum().item()

def model_sparsity(model: nn.Module) -> float:
    total = 0; zeros = 0
    for n, p in model.named_parameters():
        if p.ndim >= 2 and "weight" in n:
            total += p.numel()
            zeros += count_zeros(p)
    return zeros / max(total, 1)

# -----------------------------
# Data
# -----------------------------
def get_cifar10(batch_size=128, num_workers=2):
    mean = (0.4914, 0.4822, 0.4465)
    std  = (0.2470, 0.2435, 0.2616)
    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.4,0.4,0.4,0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    test_tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

    train = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_tf)
    test  = datasets.CIFAR10(root="./data", train=False, download=True, transform=test_tf)

    train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True,
                                               num_workers=num_workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test, batch_size=256, shuffle=False,
                                              num_workers=num_workers, pin_memory=True)
    return train_loader, test_loader

# -----------------------------
# ViT components
# -----------------------------
class DropPath(nn.Module):
    """Stochastic depth (per sample)"""
    def __init__(self, drop_prob=0.0):
        super().__init__()
        self.drop_prob = drop_prob
    def forward(self, x):
        if self.drop_prob == 0. or not self.training: return x
        keep = 1 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        rnd = x.new_empty(shape).bernoulli_(keep)
        return x * rnd / keep

class PatchEmbed(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=256):
        super().__init__()
        assert img_size % patch_size == 0
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.num_patches = (img_size // patch_size) ** 2
    def forward(self, x):
        # B, C, H, W -> B, N, D
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)
        return x

class MLP(nn.Module):
    def __init__(self, dim, hidden_dim, drop=0.0):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.drop = nn.Dropout(drop)
    def forward(self, x):
        x = self.fc1(x); x = self.act(x); x = self.drop(x)
        x = self.fc2(x); x = self.drop(x)
        return x

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=True, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x)                         # B, N, 3C
        qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2,0,3,1,4)
        q, k, v = qkv[0], qkv[1], qkv[2]          # each: B, heads, N, head_dim
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1,2).reshape(B,N,C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4.0, drop=0.0, attn_drop=0.0, drop_path=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path1 = DropPath(drop_path)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, int(dim * mlp_ratio), drop=drop)
        self.drop_path2 = DropPath(drop_path)
    def forward(self, x):
        x = x + self.drop_path1(self.attn(self.norm1(x)))
        x = x + self.drop_path2(self.mlp(self.norm2(x)))
        return x

class ViT(nn.Module):
    def __init__(self, img_size=32, patch=4, in_chans=3, num_classes=10,
                 emb_dim=256, depth=6, num_heads=8, mlp_ratio=4.0, drop=0.1, attn_drop=0.1, drop_path=0.1):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch, in_chans, emb_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + num_patches, emb_dim))
        self.pos_drop = nn.Dropout(drop)

        dpr = torch.linspace(0, drop_path, steps=depth).tolist()
        self.blocks = nn.ModuleList([
            Block(emb_dim, num_heads, mlp_ratio, drop, attn_drop, dpr[i])
            for i in range(depth)
        ])
        self.norm = nn.LayerNorm(emb_dim)
        self.head = nn.Linear(emb_dim, num_classes)

        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init)

    @staticmethod
    def _init(m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None: nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight); nn.init.zeros_(m.bias)

    def forward(self, x):
        B = x.size(0)
        x = self.patch_embed(x)
        cls = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls, x], dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks: x = blk(x)
        x = self.norm(x)[:, 0]      # CLS
        x = self.head(x)
        return x

# -----------------------------
# Train / Eval
# -----------------------------
@dataclass
class TrainCfg:
    epochs: int = 20
    lr: float = 3e-4
    warmup_epochs: int = 2
    weight_decay: float = 0.05
    label_smoothing: float = 0.1
    grad_clip: float = 1.0

def train_one_epoch(model, loader, opt, device, loss_fn, scheduler=None):
    model.train()
    total_acc, total = 0.0, 0
    iterator = tqdm(loader, leave=False) if TQDM else loader
    for x, y in iterator:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        opt.zero_grad()
        logits = model(x)
        loss = loss_fn(logits, y)
        loss.backward()
        if scheduler is not None and hasattr(scheduler, "optimizer"):  # for some schedulers
            pass
        if math.isfinite(cfg.grad_clip): nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
        opt.step()
        total_acc += accuracy(logits.detach(), y) * x.size(0)
        total += x.size(0)
    if scheduler is not None:
        scheduler.step()
    return total_acc / max(total,1)

@torch.no_grad()
def evaluate(model, loader, device, loss_fn=None):
    model.eval()
    total_acc, total, total_loss = 0.0, 0, 0.0
    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        logits = model(x)
        total_acc += accuracy(logits, y) * x.size(0)
        if loss_fn is not None:
            total_loss += F.cross_entropy(logits, y, reduction="sum").item()
        total += x.size(0)
    return total_acc / max(total,1), (total_loss / max(total,1) if loss_fn is not None else None)

# -----------------------------
# NEFF helpers (weights-only)
# -----------------------------
def per_row_neff_from_weight(W: torch.Tensor) -> torch.Tensor:
    # W: [out, in]
    absW = W.abs()
    denom = absW.sum(dim=1, keepdim=True).clamp_min(1e-12)
    p = absW / denom
    neff = torch.floor(1.0 / p.pow(2).sum(dim=1)).clamp(min=1, max=W.size(1))
    return neff  # [out]

def get_linear_mask_per_row(module: nn.Linear) -> torch.Tensor:
    W = module.weight.data
    out, in_ = W.shape
    absW = W.abs()
    p = absW / absW.sum(dim=1, keepdim=True).clamp_min(1e-12)
    neff = torch.floor(1.0 / (p.pow(2).sum(dim=1))).clamp_(min=1, max=in_).long()
    contrib = absW  # weight-only
    _, idx = torch.sort(contrib, dim=1, descending=True)
    ranks = torch.arange(in_, device=W.device).unsqueeze(0).expand_as(idx)
    mask_sorted = ranks < neff.unsqueeze(1)
    mask = torch.zeros_like(W, dtype=torch.bool)
    mask.scatter_(1, idx, mask_sorted)
    return mask

def prune_model_neff_per_row(model: nn.Module, renorm=False) -> Tuple[nn.Module, List[Tuple[str,float,float]]]:
    pruned = copy.deepcopy(model)
    layer_info = []
    with torch.no_grad():
        for name, m in pruned.named_modules():
            if isinstance(m, nn.Linear):
                mask = get_linear_mask_per_row(m).to(m.weight.device)
                m.weight.mul_(mask)
                if renorm:
                    s = m.weight.abs().sum(dim=1, keepdim=True).clamp_min(1e-12)
                    m.weight.div_(s)  # optional
                neff = per_row_neff_from_weight(m.weight).float().mean().item()
                layer_spars = count_zeros(m.weight) / m.weight.numel()
                layer_info.append((f"{name}", neff, layer_spars))
    return pruned, layer_info

# -----------------------------
# Activation-aware EMP (W*x)
# -----------------------------
@torch.no_grad()
def collect_activation_means(model: nn.Module, loader, device, num_batches=8) -> Dict[nn.Linear, torch.Tensor]:
    """
    For each Linear layer, estimate E[|x|] per input channel (last dim).
    """
    stats_sum: Dict[nn.Linear, torch.Tensor] = {}
    stats_count: Dict[nn.Linear, int] = {}
    handles = []

    def register(m: nn.Linear):
        stats_sum[m] = torch.zeros(m.in_features, device=device)
        stats_count[m] = 0
        def hook(mod, inp, out):
            x = inp[0].detach()
            # reduce over all dims except the last (feature dim)
            red_dims = tuple(range(x.dim() - 1))
            s = x.abs().sum(dim=red_dims)
            stats_sum[mod] += s
            stats_count[mod] += (x.numel() // mod.in_features)
        return m.register_forward_hook(hook)

    for m in model.modules():
        if isinstance(m, nn.Linear):
            handles.append(register(m))

    model.eval()
    it = iter(loader)
    for b in range(num_batches):
        try:
            x, _ = next(it)
        except StopIteration:
            break
        x = x.to(device, non_blocking=True)
        model(x)

    for h in handles: h.remove()
    means = {m: (stats_sum[m] / max(stats_count[m],1)).clamp_min(1e-12) for m in stats_sum.keys()}
    return means

def prune_model_emp_activation(model: nn.Module, calib_loader, device, num_calib_batches=8, renorm=False):
    """
    Activation-aware EMP pruning: keep, in each row, the top N_eff elements using
    p_ij ∝ |w_ij| * E[|x_j|]
    """
    pruned = copy.deepcopy(model).to(device)
    act_means = collect_activation_means(pruned, calib_loader, device, num_batches=num_calib_batches)

    layer_info = []
    with torch.no_grad():
        for name, m in pruned.named_modules():
            if isinstance(m, nn.Linear):
                W = m.weight.data
                out, in_ = W.shape
                mean_abs_x = act_means[m]  # [in]
                contrib = W.abs() * mean_abs_x.unsqueeze(0)  # [out, in]
                denom = contrib.sum(dim=1, keepdim=True).clamp_min(1e-12)
                p = contrib / denom
                neff = torch.floor(1.0 / p.pow(2).sum(dim=1)).clamp_(min=1, max=in_).long()

                # build mask row-wise
                _, idx = torch.sort(contrib, dim=1, descending=True)
                ranks = torch.arange(in_, device=W.device).unsqueeze(0).expand_as(idx)
                mask_sorted = ranks < neff.unsqueeze(1)
                mask = torch.zeros_like(W, dtype=torch.bool)
                mask.scatter_(1, idx, mask_sorted)
                m.weight.mul_(mask)
                if renorm:
                    s = m.weight.abs().sum(dim=1, keepdim=True).clamp_min(1e-12)
                    m.weight.div_(s)

                # report
                avg_neff = neff.float().mean().item()
                layer_spars = count_zeros(m.weight) / m.weight.numel()
                layer_info.append((f"{name}", avg_neff, layer_spars))

    return pruned, layer_info

# -----------------------------
# Reporting helpers
# -----------------------------
def print_layerwise_report(tag: str, layer_info: List[Tuple[str,float,float]]):
    print(f"Layer-wise {tag} summary:")
    for (name, neff_avg, spars) in layer_info:
        # try to parse layer shape for nicer printing if possible
        print(f" {name:<30s} | avg N_eff={neff_avg:5.1f} | sparsity={100*spars:5.1f}%")
    print()

def eval_and_report(model, test_loader, device, tag="Model"):
    acc, loss = evaluate(model, test_loader, device, loss_fn=True)
    print(f"{tag:>20s} — test acc {acc*100:5.2f}%  (loss {loss:.4f})")
    return acc, loss

# -----------------------------
# Main
# -----------------------------
if __name__ == "__main__":
    set_seed(123)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, test_loader = get_cifar10(batch_size=128, num_workers=2)

    # Build a stronger baseline than a tiny ViT, but still lightweight
    model = ViT(
        img_size=32, patch=4, emb_dim=256, depth=6, num_heads=8,
        mlp_ratio=4.0, drop=0.1, attn_drop=0.1, drop_path=0.1, num_classes=10
    ).to(device)

    cfg = TrainCfg(epochs=20, lr=3e-4, warmup_epochs=2, weight_decay=0.05, label_smoothing=0.1, grad_clip=1.0)

    # Optimizer + cosine schedule with warmup
    opt = optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay, betas=(0.9, 0.999))
    sched = CosineAnnealingLR(opt, T_max=cfg.epochs - cfg.warmup_epochs)
    loss_fn = nn.CrossEntropyLoss(label_smoothing=cfg.label_smoothing)

    # Warmup loop
    warmup_steps = cfg.warmup_epochs * len(train_loader)
    if warmup_steps > 0:
        warmup_factor = cfg.lr / warmup_steps
        cur_lr = 0.0

    print("Training baseline ViT…")
    for epoch in range(1, cfg.epochs + 1):
        if epoch <= cfg.warmup_epochs:
            # linear warmup
            for pgroup in opt.param_groups:
                pgroup["lr"] = min(cfg.lr, (epoch-1)*len(train_loader)*warmup_factor + warmup_factor*1)
        else:
            # cosine decay after warmup
            pass

        train_acc = train_one_epoch(model, train_loader, opt, device, loss_fn,
                                    scheduler=None if epoch <= cfg.warmup_epochs else sched)
        test_acc, _ = evaluate(model, test_loader, device, loss_fn=True)
        print(f"Epoch {epoch:02d}: train acc {train_acc*100:5.2f}%  test acc {test_acc*100:5.2f}%")

    print(f"Base model sparsity: {model_sparsity(model):.4f}")
    base_acc, base_loss = eval_and_report(model, test_loader, device, tag="Baseline ViT")

    # ---------------- Row-NEFF pruning ----------------
    row_model, row_info = prune_model_neff_per_row(model, renorm=False)
    print_layerwise_report("Row-NEFF", row_info)
    print(f"Row-NEFF sparsity: {model_sparsity(row_model):.4f}")
    eval_and_report(row_model, test_loader, device, tag="Row-NEFF pruned ViT")

    # --------------- Activation EMP (W*x) -------------
    # Use a small calibration subset (first few batches of training loader)
    emp_model, emp_info = prune_model_emp_activation(model, calib_loader=train_loader,
                                                     device=device, num_calib_batches=8, renorm=False)
    print_layerwise_report("EMP (activation)", emp_info)
    print(f"EMP sparsity: {model_sparsity(emp_model):.4f}")
    eval_and_report(emp_model, test_loader, device, tag="EMP-pruned ViT")


Training baseline ViT…


                                                 

Epoch 01: train acc 14.53%  test acc 23.80%


                                                 

Epoch 02: train acc 28.50%  test acc 37.76%


                                                 

Epoch 03: train acc 39.15%  test acc 45.97%


                                                 

Epoch 04: train acc 45.27%  test acc 52.45%


                                                 

Epoch 05: train acc 49.40%  test acc 55.83%


                                                 

Epoch 06: train acc 51.95%  test acc 58.17%


                                                 

Epoch 07: train acc 53.92%  test acc 58.78%


                                                 

Epoch 08: train acc 55.99%  test acc 61.19%


                                                 

Epoch 09: train acc 57.34%  test acc 62.27%


                                                 

Epoch 10: train acc 58.90%  test acc 62.02%


                                                 

Epoch 11: train acc 60.26%  test acc 64.30%


                                                 

Epoch 12: train acc 61.53%  test acc 64.69%


                                                 

Epoch 13: train acc 62.29%  test acc 66.02%


                                                 

Epoch 14: train acc 63.49%  test acc 67.71%


                                                 

Epoch 15: train acc 64.27%  test acc 67.25%


                                                 

Epoch 16: train acc 64.81%  test acc 67.19%


                                                 

Epoch 17: train acc 65.64%  test acc 69.34%


                                                 

Epoch 18: train acc 65.95%  test acc 69.02%


                                                 

Epoch 19: train acc 66.15%  test acc 69.39%


                                                 

Epoch 20: train acc 66.50%  test acc 69.45%
Base model sparsity: 0.0000
        Baseline ViT — test acc 69.45%  (loss 0.9024)
Layer-wise Row-NEFF summary:
 blocks.0.attn.qkv              | avg N_eff=133.3 | sparsity= 36.5%
 blocks.0.attn.proj             | avg N_eff=134.0 | sparsity= 36.2%
 blocks.0.mlp.fc1               | avg N_eff=133.7 | sparsity= 36.4%
 blocks.0.mlp.fc2               | avg N_eff=534.3 | sparsity= 36.4%
 blocks.1.attn.qkv              | avg N_eff=133.4 | sparsity= 36.4%
 blocks.1.attn.proj             | avg N_eff=133.4 | sparsity= 36.5%
 blocks.1.mlp.fc1               | avg N_eff=133.4 | sparsity= 36.4%
 blocks.1.mlp.fc2               | avg N_eff=535.4 | sparsity= 36.3%
 blocks.2.attn.qkv              | avg N_eff=133.5 | sparsity= 36.4%
 blocks.2.attn.proj             | avg N_eff=133.7 | sparsity= 36.4%
 blocks.2.mlp.fc1               | avg N_eff=133.5 | sparsity= 36.4%
 blocks.2.mlp.fc2               | avg N_eff=534.7 | sparsity= 36.4%
 blocks.3.attn.qkv           