# Test for effective activation function

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
from typing import List, Dict, Any, Tuple

import copy
from transformers import BertForSequenceClassification, BertTokenizer, Trainer, TrainingArguments
from datasets import load_dataset
import os
import json

import matplotlib.pyplot as plt
import time
from collections import defaultdict
import tqdm

# Output directory for the checkpoints that later cells write with torch.save;
# exist_ok=True makes re-running this cell harmless.
folder = "test_results"
os.makedirs(folder, exist_ok=True)

## Well-trained MLP (stack of linear layers) on the MNIST dataset

In [2]:
class LinearModel(nn.Module):
    """Fully connected classifier: a stack of Linear+ReLU layers and a log-softmax head.

    Args:
        input_size: number of features per sample after flattening.
        output_size: number of classes.
        hidden_size: iterable with the width of each hidden layer, in order.
            The default is an immutable tuple so it cannot be shared and mutated
            between instances.

    The sub-module names ``layers`` and ``output`` determine the state_dict keys,
    so they must stay stable for saved checkpoints to load.
    """

    def __init__(self, input_size, output_size, hidden_size=(512, 512, 512)):
        super().__init__()
        self.layers = nn.ModuleList()

        prev_size = input_size
        for size in hidden_size:
            self.layers.append(nn.Linear(prev_size, size))
            prev_size = size

        self.output = nn.Linear(prev_size, output_size)

    def forward(self, x):
        """Return per-class log-probabilities of shape ``[batch, output_size]``."""
        # reshape (instead of view) also accepts non-contiguous batches
        x = x.reshape(x.size(0), -1)  # Flatten everything except the batch dim

        for layer in self.layers:
            x = F.relu(layer(x))
        x = self.output(x)
        return F.log_softmax(x, dim=1)
        
def train(model, device, train_loader, optimizer, epoch):
    """Run one optimisation pass over ``train_loader``.

    The model output is fed to ``F.nll_loss``, so it is expected to produce
    log-probabilities. Progress is printed every 100 batches.

    Returns:
        (avg_loss, accuracy): mean of the per-batch losses, and the percentage of
        correctly classified samples relative to ``len(train_loader.dataset)``.
    """
    model.train()
    batch_losses = []
    batch_hits = []

    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())
        # predictions are taken from the pre-step forward pass of this batch
        pred = output.argmax(dim=1, keepdim=True)
        batch_hits.append((pred == target.view_as(pred)).sum().item())

        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

    avg_loss = sum(batch_losses) / len(train_loader)
    accuracy = 100. * sum(batch_hits) / len(train_loader.dataset)
    return avg_loss, accuracy


def test(model, device, test_loader):
    """Evaluate model on test set"""
    model.eval()
    test_loss = 0
    correct = 0
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    
    print(f'\nTest set: Average loss: {test_loss:.4f}, '
          f'Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)\n')
    
    return test_loss, accuracy

In [3]:
batch_size = 64
test_batch_size = 1000
epochs = 10
lr = 3e-4
seed = 0

# Seed torch's global RNG before the model and the shuffling loader are created,
# so weight initialisation and batch order are the same on every Restart & Run All.
torch.manual_seed(seed)

# MNIST dataset (10 digit classes); inputs are normalised with the mean/std given below
transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LinearModel(input_size=28*28, output_size=10, hidden_size=[1024, 512, 512]).to(device)

optimizer = optim.Adam(model.parameters(), lr=lr)

# Per-epoch history; the test lists get one extra leading entry from the
# evaluation that is run before training starts.
result = {
    'train_loss': [],
    'train_accuracy': [],
    'test_loss': [],
    'test_accuracy': []
}


100%|██████████| 9.91M/9.91M [00:01<00:00, 9.60MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 923kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 6.95MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 1.29MB/s]


In [4]:
# Evaluate the untrained model first so the test history has an epoch-0 baseline
test_loss, test_accuracy = test(model, device, test_loader)
result['test_loss'].append(test_loss)
result['test_accuracy'].append(test_accuracy)

# Training loop: one pass over train_loader per epoch (epochs are numbered from 1)
for epoch in range(1, epochs + 1):
    train_loss, train_accuracy = train(model, device, train_loader, optimizer, epoch)
    result['train_loss'].append(train_loss)
    result['train_accuracy'].append(train_accuracy)

    # Evaluate on the full test set after every epoch
    test_loss, test_accuracy = test(model, device, test_loader)
    result['test_loss'].append(test_loss)
    result['test_accuracy'].append(test_accuracy)

# Save the weights of the model as it is after the last epoch

torch.save(model.state_dict(), folder + '/' +'linear_model.pth')



Test set: Average loss: 2.3008, Accuracy: 1003/10000 (10.03%)


Test set: Average loss: 0.1090, Accuracy: 9644/10000 (96.44%)


Test set: Average loss: 0.0869, Accuracy: 9721/10000 (97.21%)


Test set: Average loss: 0.0704, Accuracy: 9771/10000 (97.71%)


Test set: Average loss: 0.0797, Accuracy: 9756/10000 (97.56%)


Test set: Average loss: 0.0630, Accuracy: 9811/10000 (98.11%)


Test set: Average loss: 0.0664, Accuracy: 9815/10000 (98.15%)


Test set: Average loss: 0.0768, Accuracy: 9794/10000 (97.94%)


Test set: Average loss: 0.0720, Accuracy: 9814/10000 (98.14%)


Test set: Average loss: 0.0733, Accuracy: 9793/10000 (97.93%)


Test set: Average loss: 0.0751, Accuracy: 9815/10000 (98.15%)



In [5]:
# map_location remaps the stored tensors onto the current device, so a checkpoint
# written on a CUDA machine still loads when this kernel only has a CPU.
model.load_state_dict(torch.load(folder + '/' +'linear_model.pth', map_location=device))
model.to(device)

LinearModel(
  (layers): ModuleList(
    (0): Linear(in_features=784, out_features=1024, bias=True)
    (1): Linear(in_features=1024, out_features=512, bias=True)
    (2): Linear(in_features=512, out_features=512, bias=True)
  )
  (output): Linear(in_features=512, out_features=10, bias=True)
)

## Per-column pruning: prune the weights going out of a neuron (columns of the `[out, in]` weight matrix)

In [5]:
def get_linear_mask_per_column(module:nn.Module) -> torch.Tensor:
    """Build a boolean keep-mask for ``module.weight``, decided column by column.

    Every column of ``|weight|`` is divided by its sum; the column's effective
    count is ``floor(1 / sum(share ** 2))`` and that many of its largest entries
    are marked True.

    Returns:
        Bool tensor with the same shape as ``module.weight``.
    """
    weight = module.weight.data
    n_rows, n_cols = weight.shape

    magnitude = weight.abs()
    share = magnitude / magnitude.sum(dim=0, keepdim=True)
    neff = torch.floor(1 / share.pow(2).sum(dim=0))  # one value per column

    # order[k, j] is the row index of the k-th largest entry of column j
    order = torch.sort(share, dim=0, descending=True)[1]
    # rank[k, j] == k, compared against the per-column budget
    rank = torch.arange(n_rows, device=weight.device).unsqueeze(1).expand(n_rows, n_cols)
    keep_sorted = rank < neff

    # scatter the decisions from sorted order back to the original row positions
    mask = torch.zeros_like(weight, dtype=torch.bool)
    mask.scatter_(0, order, keep_sorted)
    return mask

In [6]:
def prune_model_neff_per_column(model, renormalize=False):
    """Return a deep copy of ``model`` whose nn.Linear weights are pruned per column.

    Args:
        model: module to copy; the original is left untouched.
        renormalize: when True, each column is rescaled after masking so that its
            L1 norm equals the L1 norm it had before pruning. (The previous code
            divided by the signed column sum clamped at 1e-8, which multiplied any
            column with a non-positive sum by about 1e8.)

    Returns:
        The pruned copy.
    """
    pruned = copy.deepcopy(model)
    for module in pruned.modules():
        if not isinstance(module, nn.Linear):
            continue
        with torch.no_grad():
            weight = module.weight
            # column L1 mass before masking, needed for the optional rescale
            old_col_l1 = weight.abs().sum(dim=0, keepdim=True)
            mask = get_linear_mask_per_column(module).to(weight.device)
            weight.mul_(mask)
            if renormalize:
                new_col_l1 = weight.abs().sum(dim=0, keepdim=True).clamp(min=1e-8)
                weight.mul_(old_col_l1 / new_col_l1)
    return pruned

def model_sparsity(model):
    """Return the fraction of exactly-zero entries among the model's weight parameters.

    Only parameters whose name contains ``'weight'`` are counted (biases are
    ignored). A model without any such parameter yields 0.0 instead of raising
    ZeroDivisionError.

    Note: a later cell of this notebook defines another ``model_sparsity``; once
    that cell has run, the name refers to the later definition.
    """
    total_params = 0
    zero_params = 0

    for name, param in model.named_parameters():
        if 'weight' in name:
            total_params += param.numel()
            zero_params += torch.sum(param == 0).item()

    if total_params == 0:
        return 0.0
    return zero_params / total_params

In [7]:
# Column-wise N_eff pruning of a copy of the trained model (no renormalization)
prune_model = prune_model_neff_per_column(model, renormalize=False).to(device)

# Accuracy of the pruned copy on the test set
test_loss, test_accuracy = test(prune_model, device, test_loader)
print(f'Pruned Model without Renormalization - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

# Keep the pruned weights next to the dense checkpoint
torch.save(prune_model.state_dict(), f'{folder}/pruned_linear_model.pth')


Test set: Average loss: 0.0685, Accuracy: 9803/10000 (98.03%)

Pruned Model without Renormalization - Test Loss: 0.0685, Test Accuracy: 98.03%


In [8]:
# Fraction of zero weights in the column-pruned copy ...
sparsity = model_sparsity(prune_model)
print(f'Sparsity of the pruned model without renormalization: {sparsity:.4f}')
# ... compared with the dense model it was copied from (`sparsity` is reused)
sparsity = model_sparsity(model)
print(f'Sparsity of the original model: {sparsity:.4f}')

Sparsity of the pruned model without renormalization: 0.3420
Sparsity of the original model: 0.0000


## Per-row pruning: prune the weights going into a neuron (rows of the `[out, in]` weight matrix)

In [9]:
def get_linear_mask_per_row(module:nn.Module) -> torch.Tensor:
    """Build a boolean keep-mask for ``module.weight``, decided row by row.

    Every row of ``|weight|`` is divided by its sum; the row's effective count is
    ``floor(1 / sum(share ** 2))`` and that many of its largest entries are
    marked True.

    Returns:
        Bool tensor with the same shape as ``module.weight``.
    """
    weight = module.weight.data
    n_rows, n_cols = weight.shape

    magnitude = weight.abs()
    share = magnitude / magnitude.sum(dim=1, keepdim=True)
    # kept as a column vector so it broadcasts across each row below
    neff = torch.floor(1 / share.pow(2).sum(dim=1, keepdim=True))

    # order[i, k] is the column index of the k-th largest entry of row i
    order = torch.sort(share, dim=1, descending=True)[1]
    # rank[i, k] == k, compared against the per-row budget
    rank = torch.arange(n_cols, device=weight.device).repeat(n_rows, 1)
    keep_sorted = rank < neff

    # scatter the decisions from sorted order back to the original column positions
    mask = torch.zeros_like(weight, dtype=torch.bool)
    mask.scatter_(1, order, keep_sorted)
    return mask

In [10]:
def prune_model_neff_per_row(model, renormalize=False):
    """Return a deep copy of ``model`` whose nn.Linear weights are pruned per row.

    Args:
        model: module to copy; the original is left untouched.
        renormalize: when True, each row is rescaled after masking so that its L1
            norm equals the L1 norm it had before pruning. (The previous code
            summed over dim=0, i.e. columns, although the mask is built per row,
            and divided by a signed sum clamped at 1e-8, which multiplied any
            non-positive-sum slice by about 1e8.)

    Returns:
        The pruned copy.
    """
    pruned = copy.deepcopy(model)
    for module in pruned.modules():
        if not isinstance(module, nn.Linear):
            continue
        with torch.no_grad():
            weight = module.weight
            # row L1 mass before masking, needed for the optional rescale
            old_row_l1 = weight.abs().sum(dim=1, keepdim=True)
            mask = get_linear_mask_per_row(module).to(weight.device)
            weight.mul_(mask)
            if renormalize:
                new_row_l1 = weight.abs().sum(dim=1, keepdim=True).clamp(min=1e-8)
                weight.mul_(old_row_l1 / new_row_l1)
    return pruned


In [11]:
# Row-wise N_eff pruning of a copy of the trained model (no renormalization)
prune_model_row = prune_model_neff_per_row(model, renormalize=False).to(device)

# Accuracy of the pruned copy on the test set
test_loss, test_accuracy = test(prune_model_row, device, test_loader)
print(f'Pruned Model without Renormalization - Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

# Keep the pruned weights next to the dense checkpoint
torch.save(prune_model_row.state_dict(), f'{folder}/pruned_linear_model_row.pth')


Test set: Average loss: 0.0707, Accuracy: 9803/10000 (98.03%)

Pruned Model without Renormalization - Test Loss: 0.0707, Test Accuracy: 98.03%


In [12]:
# Fraction of zero weights in the row-pruned copy ...
sparsity = model_sparsity(prune_model_row)
print(f'Sparsity of the pruned model without renormalization: {sparsity:.4f}')
# ... compared with the dense model it was copied from (`sparsity` is reused)
sparsity = model_sparsity(model)
print(f'Sparsity of the original model: {sparsity:.4f}')

Sparsity of the pruned model without renormalization: 0.3646
Sparsity of the original model: 0.0000


## EMP ACT: activation-aware EMP pruning (|W| · E|x|)

In [13]:
# ===== Activation-EMP (W·x) pruning utilities for nn.Linear =====
import copy
import torch
import torch.nn as nn

@torch.no_grad()
def model_sparsity(model: nn.Module) -> float:
    """Fraction of exactly-zero entries over all weight tensors with at least 2 dims.

    Parameters are selected by having ``'weight'`` in their name and
    ``dim() >= 2``; if none match, the result is 0.0.
    """
    weight_tensors = [
        p for n, p in model.named_parameters()
        if p.dim() >= 2 and 'weight' in n
    ]
    total = sum(p.numel() for p in weight_tensors)
    zeros = sum((p == 0).sum().item() for p in weight_tensors)
    return zeros / max(total, 1)

# ---- Step 1: collect E[|x|] (per in-feature) for each Linear via forward hooks
@torch.no_grad()
def collect_input_magnitudes(model: nn.Module,
                             data_loader,
                             device,
                             num_batches: int = 10):
    """Estimate E[|x|] per input feature for every nn.Linear in ``model``.

    Forward hooks are attached to each Linear, up to ``num_batches`` batches from
    ``data_loader`` are pushed through the model, and the absolute values of each
    layer's input are averaged over all leading dimensions.

    Args:
        model: module to probe; it is switched to eval mode and left that way.
        data_loader: iterable yielding ``(data, target)`` pairs; only ``data`` is used.
        device: device the batches are moved to and the accumulators are created on.
        num_batches: stop after this many batches (at least one batch is always
            processed if the loader is non-empty).

    Returns:
        (linear_list, mags): the Linear modules in ``model.modules()`` order and,
        aligned with them, one tensor of length ``in_features`` per module.
    """
    model.eval()
    # list Linear modules in traversal order
    linear_list = [m for m in model.modules() if isinstance(m, nn.Linear)]
    sums = [torch.zeros(m.in_features, device=device) for m in linear_list]
    counts = [0] * len(linear_list)
    index_of = {id(m): i for i, m in enumerate(linear_list)}

    def hook_fn(module, inputs, output):
        idx = index_of[id(module)]
        x = inputs[0].detach()
        # collapse all leading dims: (..., in_features) -> (N, in_features);
        # unlike flatten(0, -2) this also accepts an unbatched 1-D input
        x2d = x.reshape(-1, x.shape[-1])
        sums[idx] += x2d.abs().sum(dim=0)
        counts[idx] += x2d.shape[0]

    handles = []
    try:
        for m in linear_list:
            handles.append(m.register_forward_hook(hook_fn))

        seen = 0
        for data, target in data_loader:
            data = data.to(device)
            _ = model(data)
            seen += 1
            if seen >= num_batches:
                break
    finally:
        # always detach the hooks, even if a forward pass raised, so the model
        # is not left with stale hooks accumulating into dead buffers
        for h in handles:
            h.remove()

    # max(c, 1) avoids 0/0 for layers that never saw an input
    mags = [s / max(c, 1) for s, c in zip(sums, counts)]
    # Return in the same order as linear_list
    return linear_list, mags

# ---- Step 2: build the activation-aware mask from N_eff on |W| * E|x|
@torch.no_grad()
def get_linear_mask_emp(module: nn.Linear,
                        in_mag: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build an activation-aware keep-mask for one Linear layer.

    Args:
        module: nn.Linear with weight shape [out, in]
        in_mag: tensor [in] = E[|x|] for this module's input
    Returns:
        mask (bool) with same shape as weight
        neff_row (long) length = out_features, the number of entries kept per row
    """
    W = module.weight.data  # [out, in]
    # contribution of every input to every output neuron: |W_ji| * E|x_i|
    contrib = W.abs() * in_mag.unsqueeze(0)  # [out, in]
    # clamp keeps an all-zero row from producing 0/0
    row_sum = contrib.sum(dim=1, keepdim=True).clamp(min=1e-12)
    norm = contrib / row_sum                 # \hat c_ji, each row sums to 1 (or is all zero)

    # N_eff per row = floor(1 / sum_i \hat c_ji^2), bounded to [1, in]
    neff = torch.floor(1.0 / norm.pow(2).sum(dim=1)).clamp(min=1, max=W.shape[1]).long()  # [out]

    # sort each row by importance and keep top neff[j]
    _, idx = torch.sort(norm, dim=1, descending=True)
    out, in_ = W.shape
    ranks = torch.arange(in_, device=W.device).unsqueeze(0).expand(out, in_)
    keep_sorted = ranks < neff.unsqueeze(1)   # [out, in] (sorted order)

    # map the keep decisions from sorted order back to the original columns
    mask = torch.zeros_like(W, dtype=torch.bool)
    mask.scatter_(1, idx, keep_sorted)
    return mask, neff

# ---- Step 3: prune with EMP (optional L1 row re-normalization)
@torch.no_grad()
def prune_model_emp_activation(model: nn.Module,
                               calib_loader,
                               device,
                               num_calib_batches: int = 10,
                               renormalize: bool = False):
    """Prune every nn.Linear of a copy of ``model`` with the activation-aware mask.

    Args:
        model: module to copy; the original is not modified.
        calib_loader: iterable of ``(data, target)`` batches used to estimate E[|x|].
        device: device the copy and the calibration batches are moved to.
        num_calib_batches: number of calibration batches to use.
        renormalize: when True, each pruned row is rescaled so that its L1 norm
            equals the L1 norm it had before pruning.

    Returns:
        (pruned, layer_neff): the pruned copy and a dict keyed by ``id(layer)``
        holding per-layer statistics. ``"name"`` is the layer's ``_emp_name``
        attribute if present, otherwise its qualified name inside the copy, so
        summaries are readable without tagging the model first.
    """
    pruned = copy.deepcopy(model).to(device)
    linear_list, mags = collect_input_magnitudes(pruned, calib_loader, device, num_batches=num_calib_batches)

    # fallback labels for layers that carry no _emp_name tag
    qualified_names = {
        id(m): n for n, m in pruned.named_modules() if isinstance(m, nn.Linear)
    }

    layer_neff = {}
    for lin, mu in zip(linear_list, mags):
        W = lin.weight.data
        old_row_l1 = W.abs().sum(dim=1, keepdim=True)
        mask, neff = get_linear_mask_emp(lin, mu.to(W.device))
        # apply mask
        W.mul_(mask)
        if renormalize:
            new_row_l1 = W.abs().sum(dim=1, keepdim=True).clamp(min=1e-8)
            scale = old_row_l1 / new_row_l1
            W.mul_(scale)
        layer_neff[id(lin)] = {
            # an empty qualified name (model itself is the Linear) is reported as None
            "name": getattr(lin, "_emp_name", None) or qualified_names.get(id(lin)) or None,
            "neff_row": neff.detach().cpu(),
            "avg_neff": neff.float().mean().item(),
            "in_features": W.shape[1],
            "out_features": W.shape[0],
            "layer_sparsity": float((~mask).sum().item() / mask.numel())
        }

    return pruned, layer_neff

# ---- Helper: pretty summary
def summarize_emp(layer_neff_dict):
    """Render the per-layer statistics dict as one aligned text line per layer.

    Entries whose ``"name"`` is missing or empty are labelled ``Linear(id=<key>)``.
    Returns the lines joined with newlines (an empty string for an empty dict).
    """
    def _row(k, v):
        name = v.get("name") or f"Linear(id={k})"
        return (
            f"{name:30s} | out={v['out_features']:4d} in={v['in_features']:4d} "
            f"| avg N_eff={v['avg_neff']:.1f} | sparsity={v['layer_sparsity']*100:5.1f}%"
        )

    return "\n".join(_row(k, v) for k, v in layer_neff_dict.items())

# ---- (Optional) attach names to ease reading
def tag_linear_names(model: nn.Module):
    """Store each nn.Linear's qualified module name on it as ``_emp_name``.

    The model is modified in place; nothing is returned.
    """
    linear_items = (
        (qualified_name, layer)
        for qualified_name, layer in model.named_modules()
        if isinstance(layer, nn.Linear)
    )
    for qualified_name, layer in linear_items:
        layer._emp_name = qualified_name


In [14]:
# Requires the trained `model`, `train_loader`, `test_loader`, `device` and the
# `test(...)` function from the cells above.

# 1) Build a pruned copy; E[|x|] is estimated from 10 batches of the training loader
pruned_emp, layer_neff = prune_model_emp_activation(
    model, calib_loader=train_loader, device=device, num_calib_batches=10, renormalize=False
)
print("Layer-wise EMP summary:\n", summarize_emp(layer_neff))

# 2) Evaluate the pruned copy on the test set and report its overall weight sparsity
test_loss, test_acc = test(pruned_emp, device, test_loader)
print(f"EMP-pruned (W·x) - Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")
print(f"EMP-pruned model sparsity: {model_sparsity(pruned_emp):.4f}")

# (Optional) compare with the weight-only row/column pruning results above.


Layer-wise EMP summary:
 Linear(id=13268523312)         | out=1024 in= 784 | avg N_eff=355.7 | sparsity= 54.6%
Linear(id=13268522976)         | out= 512 in=1024 | avg N_eff=270.8 | sparsity= 73.6%
Linear(id=4891302336)          | out= 512 in= 512 | avg N_eff=250.4 | sparsity= 51.1%
Linear(id=4891300320)          | out=  10 in= 512 | avg N_eff=253.3 | sparsity= 50.5%

Test set: Average loss: 0.0851, Accuracy: 9749/10000 (97.49%)

EMP-pruned (W·x) - Test Loss: 0.0851, Test Acc: 97.49%
EMP-pruned model sparsity: 0.6026
