# setup

In [2]:
import os, math, time, random
from dataclasses import dataclass, replace
from typing import Tuple
import copy
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR, ConstantLR
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms, datasets, models
from torchvision.datasets.utils import download_and_extract_archive

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)


Device: cuda


In [3]:
AUTO_DOWNLOAD_TINYIN = True

@dataclass
class TrainCfg:
    name: str
    dataset: str
    num_classes: int
    input_size: int
    batch_size: int = 128
    epochs: int = 200
    optimizer: str = "SGD"
    lr: float = 0.01
    momentum: float = 0.9
    weight_decay: float = 0.0
    cosine: bool = True
    warmup_epochs: int = 5
    cifar_style_resnet: bool = False

cfgs = {
    # MNIST (1×28×28), FC baselines
    "FC2-M":  TrainCfg("FC2",  "MNIST",         10, input_size=28, epochs=5, optimizer='ADAM', lr=1e-4, cosine=False, warmup_epochs=0, weight_decay=0.0),
    "FC5-M":  TrainCfg("FC5",  "MNIST",         10, input_size=28, epochs=5, optimizer='ADAM', lr=1e-4, cosine=False, warmup_epochs=0, weight_decay=0.0),
    "FC12-M": TrainCfg("FC12", "MNIST",         10, input_size=28, epochs=10, optimizer='ADAM', lr=1e-4, cosine=False, warmup_epochs=0, weight_decay=0.0),

    # Fashion-MNIST (1×28×28), FC baselines
    "FC2-FM":  TrainCfg("FC2",  "FashionMNIST", 10, input_size=28, epochs=5, optimizer='ADAM', lr=1e-4, cosine=False, warmup_epochs=0, weight_decay=0.0),
    "FC5-FM":  TrainCfg("FC5",  "FashionMNIST", 10, input_size=28, epochs=5, optimizer='ADAM', lr=1e-4, cosine=False, warmup_epochs=0, weight_decay=0.0),
    "FC12-FM": TrainCfg("FC12", "FashionMNIST", 10, input_size=28, epochs=10, optimizer='ADAM', lr=1e-4, cosine=False, warmup_epochs=0, weight_decay=0.0),

    # CIFAR-10
    "FC5":     TrainCfg("FC5",    "CIFAR10", 10, input_size=32, lr=0.01, cosine=False, warmup_epochs=0, weight_decay=0.0),
    "FC12":    TrainCfg("FC12",   "CIFAR10", 10, input_size=32, lr=0.01, cosine=False, warmup_epochs=0, weight_decay=0.0),
    "VGG16":   TrainCfg("VGG16",  "CIFAR10", 10, input_size=32, lr=0.01, cosine=True,  warmup_epochs=5, weight_decay=5e-4),
    "AlexNet": TrainCfg("AlexNet","CIFAR10", 10, input_size=224,lr=0.01, cosine=True,  warmup_epochs=5, weight_decay=5e-4),  # upsample to 224

    # CIFAR-100
    "ResNet18_C100": TrainCfg("ResNet18_C100","CIFAR100",100,input_size=32,lr=0.1, cosine=True, warmup_epochs=5, weight_decay=5e-4, cifar_style_resnet=True),
    "ResNet50_C100": TrainCfg("ResNet50_C100","CIFAR100",100,input_size=32,lr=0.1, cosine=True, warmup_epochs=5, weight_decay=5e-4, cifar_style_resnet=True),

    # TinyImageNet (optional; place dataset under ./data/tiny-imagenet-200)
    "ResNet18_TinyIN": TrainCfg("ResNet18_TinyIN","TinyImageNet",200,input_size=64,lr=0.01,cosine=True,warmup_epochs=5,weight_decay=5e-4),
    "ResNet50_TinyIN": TrainCfg("ResNet50_TinyIN","TinyImageNet",200,input_size=64,lr=0.01,cosine=True,warmup_epochs=5,weight_decay=5e-4),
}


In [4]:
class FCNet(nn.Module):
    def __init__(self, in_dim: int, widths):
        super().__init__()
        dims = [in_dim] + list(widths)
        layers = []
        for i in range(len(dims) - 2):
            layers += [nn.Linear(dims[i], dims[i + 1]), nn.ReLU(inplace=True)]
        layers += [nn.Linear(dims[-2], dims[-1])]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        x = torch.flatten(x, 1)
        return self.net(x)

def _infer_in_dim(cfg: TrainCfg) -> int:
    c = 1 if cfg.dataset in ("MNIST", "FashionMNIST") else 3
    return c * cfg.input_size * cfg.input_size

def build_model(cfg: TrainCfg) -> nn.Module:
    in_dim = _infer_in_dim(cfg)

    if cfg.name == "FC2":
        model = FCNet(in_dim, [100, cfg.num_classes])
    elif cfg.name == "FC5":
        model = FCNet(in_dim, [1000, 600, 300, 100, cfg.num_classes])
    elif cfg.name == "FC12":
        model = FCNet(in_dim, [1000, 900, 800, 750, 700, 650, 600, 500, 400, 200, 100, cfg.num_classes])

    elif cfg.name == "AlexNet":
        model = models.alexnet(weights=None)
        if isinstance(model.classifier, nn.Sequential):
            in_features = model.classifier[-1].in_features
            model.classifier[-1] = nn.Linear(in_features, cfg.num_classes)

    elif cfg.name == "VGG16":
        model = models.vgg16(weights=None)
        model.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        model.classifier = nn.Linear(512, cfg.num_classes)

    elif "ResNet18" in cfg.name:
        model = models.resnet18(weights=None, num_classes=cfg.num_classes)
        if cfg.cifar_style_resnet:
            model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            model.maxpool = nn.Identity()

    elif "ResNet50" in cfg.name:
        model = models.resnet50(weights=None, num_classes=cfg.num_classes)
        if cfg.cifar_style_resnet:
            model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            model.maxpool = nn.Identity()
    else:
        raise ValueError(f"Unknown model name {cfg.name}")

    return model.to(device)

def count_params(model: nn.Module):
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable


In [5]:
# Dataset statistics
MNIST_MEAN,   MNIST_STD   = (0.1307,), (0.3081,)
FASHION_MEAN, FASHION_STD = (0.2860,), (0.3530,)
CIFAR10_MEAN, CIFAR10_STD   = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
CIFAR100_MEAN, CIFAR100_STD = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)
IMAGENET_MEAN, IMAGENET_STD = (0.485, 0.456, 0.406),    (0.229, 0.224, 0.225)

def _cifar_transforms(size: int, mean, std, resize_to_224: bool = False):
    if size == 32 and not resize_to_224:
        train_tfms = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        test_tfms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
    else:
        train_tfms = transforms.Compose([
            transforms.Resize(size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        test_tfms = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
    return train_tfms, test_tfms

def _gray_transforms(size: int, mean, std):
    train_tfms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    test_tfms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    return train_tfms, test_tfms

def get_dataloaders(cfg: TrainCfg, data_root: str = "./data") -> Tuple[DataLoader, DataLoader]:
    if cfg.dataset == "MNIST":
        train_tfms, test_tfms = _gray_transforms(28, MNIST_MEAN, MNIST_STD)
        train_set = datasets.MNIST(root=data_root, train=True,  download=True, transform=train_tfms)
        test_set  = datasets.MNIST(root=data_root, train=False, download=True, transform=test_tfms)

    elif cfg.dataset == "FashionMNIST":
        train_tfms, test_tfms = _gray_transforms(28, FASHION_MEAN, FASHION_STD)
        train_set = datasets.FashionMNIST(root=data_root, train=True,  download=True, transform=train_tfms)
        test_set  = datasets.FashionMNIST(root=data_root, train=False, download=True, transform=test_tfms)

    elif cfg.dataset == "CIFAR10":
        resize_to_224 = (cfg.input_size == 224)
        train_tfms, test_tfms = _cifar_transforms(cfg.input_size, CIFAR10_MEAN, CIFAR10_STD, resize_to_224=resize_to_224)
        train_set = datasets.CIFAR10(root=data_root, train=True,  download=True, transform=train_tfms)
        test_set  = datasets.CIFAR10(root=data_root, train=False, download=True, transform=test_tfms)

    elif cfg.dataset == "CIFAR100":
        train_tfms, test_tfms = _cifar_transforms(cfg.input_size, CIFAR100_MEAN, CIFAR100_STD, resize_to_224=False)
        train_set = datasets.CIFAR100(root=data_root, train=True,  download=True, transform=train_tfms)
        test_set  = datasets.CIFAR100(root=data_root, train=False, download=True, transform=test_tfms)

    elif cfg.dataset == "TinyImageNet":
        train_dir = os.path.join(data_root, "tiny-imagenet-200", "train")
        val_dir   = os.path.join(data_root, "tiny-imagenet-200", "val")
        train_tfms = transforms.Compose([
            transforms.RandomResizedCrop(cfg.input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
        ])
        test_tfms = transforms.Compose([
            transforms.Resize(cfg.input_size + 8),
            transforms.CenterCrop(cfg.input_size),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
        ])
        train_set = datasets.ImageFolder(train_dir, transform=train_tfms)
        test_set  = datasets.ImageFolder(val_dir,   transform=test_tfms)

    else:
        raise ValueError(f"Unknown dataset {cfg.dataset}")

    pin = torch.cuda.is_available()
    train_loader = DataLoader(train_set, batch_size=cfg.batch_size, shuffle=True,  num_workers=4, pin_memory=pin)
    test_loader  = DataLoader(test_set,  batch_size=cfg.batch_size, shuffle=False, num_workers=4, pin_memory=pin)
    return train_loader, test_loader


In [6]:
AUTO_DOWNLOAD_TINYIN = True

@dataclass
class TrainCfg:
    name: str
    dataset: str
    num_classes: int
    input_size: int
    batch_size: int = 128
    epochs: int = 200
    optimizer: str = "SGD"
    lr: float = 0.01
    momentum: float = 0.9
    weight_decay: float = 0.0
    cosine: bool = True
    warmup_epochs: int = 5
    cifar_style_resnet: bool = False

cfgs = {
    # MNIST (1×28×28), FC baselines
    "FC2-M":  TrainCfg("FC2",  "MNIST",         10, input_size=28, epochs=5, optimizer='ADAM', lr=1e-4, cosine=False, warmup_epochs=0, weight_decay=0.0),
    "FC5-M":  TrainCfg("FC5",  "MNIST",         10, input_size=28, epochs=5, optimizer='ADAM', lr=1e-4, cosine=False, warmup_epochs=0, weight_decay=0.0),
    "FC12-M": TrainCfg("FC12", "MNIST",         10, input_size=28, epochs=10, optimizer='ADAM', lr=1e-4, cosine=False, warmup_epochs=0, weight_decay=0.0),

    # Fashion-MNIST (1×28×28), FC baselines
    "FC2-FM":  TrainCfg("FC2",  "FashionMNIST", 10, input_size=28, epochs=5, optimizer='ADAM', lr=1e-4, cosine=False, warmup_epochs=0, weight_decay=0.0),
    "FC5-FM":  TrainCfg("FC5",  "FashionMNIST", 10, input_size=28, epochs=5, optimizer='ADAM', lr=1e-4, cosine=False, warmup_epochs=0, weight_decay=0.0),
    "FC12-FM": TrainCfg("FC12", "FashionMNIST", 10, input_size=28, epochs=10, optimizer='ADAM', lr=1e-4, cosine=False, warmup_epochs=0, weight_decay=0.0),

    # CIFAR-10
    "FC5":     TrainCfg("FC5",    "CIFAR10", 10, input_size=32, lr=0.01, cosine=False, warmup_epochs=0, weight_decay=0.0),
    "FC12":    TrainCfg("FC12",   "CIFAR10", 10, input_size=32, lr=0.01, cosine=False, warmup_epochs=0, weight_decay=0.0),
    "VGG16":   TrainCfg("VGG16",  "CIFAR10", 10, input_size=32, lr=0.01, cosine=True,  warmup_epochs=5, weight_decay=5e-4),
    "AlexNet": TrainCfg("AlexNet","CIFAR10", 10, input_size=224,lr=0.01, cosine=True,  warmup_epochs=5, weight_decay=5e-4),  # upsample to 224

    # CIFAR-100
    "ResNet18_C100": TrainCfg("ResNet18_C100","CIFAR100",100,input_size=32,lr=0.1, cosine=True, warmup_epochs=5, weight_decay=5e-4, cifar_style_resnet=True),
    "ResNet50_C100": TrainCfg("ResNet50_C100","CIFAR100",100,input_size=32,lr=0.1, cosine=True, warmup_epochs=5, weight_decay=5e-4, cifar_style_resnet=True),

    # TinyImageNet (optional; place dataset under ./data/tiny-imagenet-200)
    "ResNet18_TinyIN": TrainCfg("ResNet18_TinyIN","TinyImageNet",200,input_size=64,lr=0.01,cosine=True,warmup_epochs=5,weight_decay=5e-4),
    "ResNet50_TinyIN": TrainCfg("ResNet50_TinyIN","TinyImageNet",200,input_size=64,lr=0.01,cosine=True,warmup_epochs=5,weight_decay=5e-4),
}


@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, criterion):
    model.eval()
    total_loss, total_correct, n = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = criterion(logits, y)
        total_loss += loss.item() * y.size(0)
        total_correct += (logits.argmax(1) == y).sum().item()
        n += y.size(0)
    return total_loss / n, 100.0 * total_correct / n

In [7]:
def EMP_global_magnitude(model, beta=1.0):
    
    model1 = copy.deepcopy(model)
    params = []
    for m in model1.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            params.append(m.weight.data.view(-1))
    if not params:
        return model, 0.0
    
    s = torch.cat(params)
    x_norm = torch.abs(s) / torch.sum(torch.abs(s))
    neff = 1/torch.sum((x_norm ** 2))
    r_neff = torch.floor(beta * neff)
    r_neff = r_neff.clamp(min=1, max=len(s)-1)
    _, indices = torch.sort(torch.abs(s), descending=True)
    thresh = torch.abs(s)[indices[int(r_neff)]]
    
    total, kept = 0, 0
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            w = m.weight.data
            mask = (torch.abs(w) >= thresh)
            kept += mask.sum().item()
            total += mask.numel()
            w.mul_(mask)
    
    sparsity = 1.0 - (kept / total)
    return sparsity

def EMP_loss_change(model: nn.Module, loader: DataLoader, beta: float = 1.0):
    """Taylor-based importance scoring with predefined threshold"""
    model1 = copy.deepcopy(model)
    model1.train()
    criterion = nn.CrossEntropyLoss()
    
    # Compute gradients on a batch
    data_iter = iter(loader)
    x, y = next(data_iter)
    x, y = x.to(device), y.to(device)
    
    model1.zero_grad()
    outputs = model1(x)
    loss = criterion(outputs, y)
    loss.backward()
    
    # Calculate importance scores (gradient * weight)
    scores = []
    for m in model1.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            if m.weight.grad is not None:
                importance = torch.abs(m.weight.data * m.weight.grad)
            else:
                importance = torch.abs(m.weight.data)
            scores.append(importance.view(-1))
    
    if not scores:
        return model, 0.0
    
    all_scores = torch.cat(scores)
    #print(len(all_scores))
    w = torch.abs(all_scores)/torch.sum(torch.abs(all_scores))
    neff = 1/torch.sum(w ** 2)
    r_neff = torch.clamp(torch.floor(beta * neff), 1, len(all_scores)-1)
    #print(r_neff)
    threshold = torch.quantile(torch.abs(all_scores), r_neff / len(all_scores))
    
    # Apply mask
    idx = 0
    total, kept = 0, 0
    for m in model1.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            numel = m.weight.numel()
            if m.weight.grad is not None:
                importance = torch.abs(m.weight.data * m.weight.grad)
            else:
                importance = torch.abs(m.weight.data)
            mask = importance > threshold
            m.weight.data.mul_(mask)
            kept += mask.sum().item()
            total += numel
    
    actual_sparsity = 1.0 - (kept / total)
    return model1, actual_sparsity


def EMP_saliency(model: nn.Module, loader: DataLoader, beta: float = 1.0):
    """Gradient-based saliency with predefined threshold"""
    model1 = copy.deepcopy(model)
    model1.train()
    criterion = nn.CrossEntropyLoss()
    
    # Accumulate gradients over multiple batches for stability
    num_batches = min(10, len(loader))
    data_iter = iter(loader)
    
    # Initialize gradient accumulator
    grad_accum = {}
    for name, m in model1.named_modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            grad_accum[name] = torch.zeros_like(m.weight)
    
    # Accumulate gradients
    for _ in range(num_batches):
        x, y = next(data_iter)
        x, y = x.to(device), y.to(device)
        
        model.zero_grad()
        outputs = model(x)
        loss = criterion(outputs, y)
        loss.backward()
        
        for name, m in model.named_modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)) and m.weight.grad is not None:
                grad_accum[name] += torch.abs(m.weight.grad)
    
    # Average gradients
    for name in grad_accum:
        grad_accum[name] /= num_batches
    
    # Collect all saliency scores
    scores = []
    for name, m in model1.named_modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            scores.append(grad_accum[name].view(-1))
    
    if not scores:
        return model, 0.0
    
    all_scores = torch.cat(scores)
    w = torch.abs(all_scores)/torch.sum(torch.abs(all_scores))
    neff = 1/torch.sum(w ** 2)
    r_neff = torch.clamp(torch.floor(beta * neff), 1, len(all_scores)-1)
    threshold = torch.quantile(torch.abs(all_scores), r_neff / len(all_scores))

    # Apply mask
    total, kept = 0, 0
    for name, m in model1.named_modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            mask = grad_accum[name] > threshold
            m.weight.data.mul_(mask)
            kept += mask.sum().item()
            total += mask.numel()
    
    actual_sparsity = 1.0 - (kept / total)
    return model1, actual_sparsity


# loss change, fc5, fc12, vgg, alexnet, res18, res50

In [7]:
keys = ['FC5', 'FC12', 'VGG16', 'AlexNet', 'ResNet18_C100', 'ResNet50_C100', 'ResNet18_TinyIN', 'ResNet50_TinyIN']
criterion = nn.CrossEntropyLoss()

In [8]:
for key in keys:
    model_config = cfgs[key]
    train_loader, test_loader = get_dataloaders(model_config)
    model = build_model(model_config)
    tag = f"{model_config.name}_{model_config.dataset}"
    ckpt_path = f'./checkpoints/{tag}_best.pth'
    if os.path.exists(ckpt_path):
        state = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(state["model"])
        print("Loaded", ckpt_path)
        
        pre_val_loss, pre_val_acc = evaluate(model, test_loader, criterion)
    print(f"before pruning: val_loss={pre_val_loss} | val_acc={pre_val_acc}%")

    sparsity = EMP_global_magnitude(model)
    print(f"Achieved global sparsity: {sparsity*100}%")

    criterion = nn.CrossEntropyLoss()
    val_loss, val_acc = evaluate(model, test_loader, criterion)
    print(f"After pruning: val_loss={val_loss} | val_acc={val_acc}% | change of loss={pre_val_loss-val_loss}  |  | change of Acc={pre_val_acc-val_acc}%")
    

Loaded ./checkpoints/FC5_CIFAR10_best.pth
before pruning: val_loss=1.287621498298645 | val_acc=57.34%
Achieved global sparsity: 48.594308524336846%
After pruning: val_loss=1.2694350761413575 | val_acc=57.52% | change of loss=0.018186422157287385  |  | change of Acc=-0.17999999999999972%
Loaded ./checkpoints/FC12_CIFAR10_best.pth
before pruning: val_loss=1.2540489635467529 | val_acc=58.26%
Achieved global sparsity: 39.328941208866866%
After pruning: val_loss=1.231849143409729 | val_acc=58.42% | change of loss=0.02219982013702393  |  | change of Acc=-0.1600000000000037%
Loaded ./checkpoints/VGG16_CIFAR10_best.pth
before pruning: val_loss=0.4234209942817688 | val_acc=91.12%
Achieved global sparsity: 59.47641629445355%
After pruning: val_loss=0.31842048754692076 | val_acc=90.98% | change of loss=0.10500050673484806  |  | change of Acc=0.14000000000000057%
Loaded ./checkpoints/AlexNet_CIFAR10_best.pth
before pruning: val_loss=0.48460997281074525 | val_acc=89.84%
Achieved global sparsity: 60

# swap beta

In [12]:
keys = ["FC2-M","FC5-M","FC12-M","FC2-FM","FC5-FM","FC12-FM"]
beta_list = [0.5, 0.75, 1.0, 1.25, 1.5, 2.0]

In [17]:
for key in keys:
    model_config = cfgs[key]
    train_loader, test_loader = get_dataloaders(model_config)
    model = build_model(model_config)
    tag = f"{model_config.name}_{model_config.dataset}"
    ckpt_path = f'./checkpoints/{tag}_best.pth'
    if os.path.exists(ckpt_path):
        state = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(state["model"])
        print("Loaded", ckpt_path)
        
        pre_val_loss, pre_val_acc = evaluate(model, test_loader, criterion)
    print(f"before pruning: val_loss={pre_val_loss} | val_acc={pre_val_acc}%")
    
    for beta in beta_list:
        print(f'\nbeta = {beta}')
        new_model = copy.deepcopy(model)
        sparsity = EMP_global_magnitude(new_model, beta)
        print(f"Achieved global sparsity: {sparsity*100}%")

        criterion = nn.CrossEntropyLoss()
        val_loss, val_acc = evaluate(new_model, test_loader, criterion)
        print(f"After pruning: val_loss={val_loss} | val_acc={val_acc}% | change of loss={pre_val_loss-val_loss}  |  change of Acc={pre_val_acc-val_acc}%")

Loaded ./checkpoints/FC2_MNIST_best.pth
before pruning: val_loss=0.21534884063005447 | val_acc=93.8%

beta = 0.5
Achieved global sparsity: 68.48362720403023%
After pruning: val_loss=0.24116790586709977 | val_acc=93.57% | change of loss=-0.025819065237045302  |  | change of Acc=0.23000000000000398%

beta = 0.75
Achieved global sparsity: 52.72544080604534%
After pruning: val_loss=0.22307048230171203 | val_acc=93.81% | change of loss=-0.007721641671657564  |  | change of Acc=-0.010000000000005116%

beta = 1.0
Achieved global sparsity: 36.968513853904284%
After pruning: val_loss=0.21815758819580078 | val_acc=93.78% | change of loss=-0.0028087475657463112  |  | change of Acc=0.01999999999999602%

beta = 1.25
Achieved global sparsity: 21.210327455919398%
After pruning: val_loss=0.2155759802699089 | val_acc=93.81% | change of loss=-0.00022713963985443453  |  | change of Acc=-0.010000000000005116%

beta = 1.5
Achieved global sparsity: 5.4521410579345115%
After pruning: val_loss=0.2153797119855

# EMP loss change & saliency

In [8]:
model_config = cfgs['VGG16']
train_loader, test_loader = get_dataloaders(model_config)
model = build_model(model_config)
tag = f"{model_config.name}_{model_config.dataset}"
ckpt_path = f'./checkpoints/{tag}_best.pth'
if os.path.exists(ckpt_path):
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state["model"])
    print("Loaded", ckpt_path)
    criterion = nn.CrossEntropyLoss()
    pre_val_loss, pre_val_acc = evaluate(model, test_loader, criterion)
print(f"before pruning: val_loss={pre_val_loss} | val_acc={pre_val_acc}%")

Loaded ./checkpoints/VGG16_CIFAR10_best.pth
before pruning: val_loss=0.4234209942817688 | val_acc=91.12%


In [9]:
model1 = copy.deepcopy(model)
model1, sparsity = EMP_loss_change(model1, train_loader)
print(f"EMP loss change")
print(f"Achieved global sparsity: {sparsity*100}%")
criterion = nn.CrossEntropyLoss()
val_loss, val_acc = evaluate(model1, test_loader, criterion)
print(f"After pruning: val_loss={val_loss} | val_acc={val_acc}% | change of loss={pre_val_loss-val_loss}  |  | change of Acc={pre_val_acc-val_acc}%")

EMP loss change
Achieved global sparsity: 2.0700775450026354%
After pruning: val_loss=0.4233589924812317 | val_acc=91.11% | change of loss=6.200180053711479e-05  |  | change of Acc=0.010000000000005116%


In [10]:
model2 = copy.deepcopy(model)
model2, sparsity = EMP_saliency(model2, test_loader)
print(f"EMP loss change")
print(f"Achieved global sparsity: {sparsity*100}%")
criterion = nn.CrossEntropyLoss()
val_loss, val_acc = evaluate(model2, test_loader, criterion)
print(f"After pruning: val_loss={val_loss} | val_acc={val_acc}% | change of loss={pre_val_loss-val_loss}  |  | change of Acc={pre_val_acc-val_acc}%")

EMP loss change
Achieved global sparsity: 25.72811245547577%
After pruning: val_loss=0.3816164824008942 | val_acc=90.97% | change of loss=0.04180451188087464  |  | change of Acc=0.15000000000000568%
