In [1]:
import os
import json
import math
from pathlib import Path
from typing import List, Dict, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision.models import resnet50, resnet18, resnet101

In [2]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

def get_loaders(dataset_name, batch_size=128, test_batch_size=1000, data_root='./data'):
    """
    Returns: train_loader, test_loader, input_size, num_classes, meta (dict)
    """
    name = dataset_name.lower()
    meta = {}

    # Generic normalizations (safe defaults). If you want canonical stats, compute them once.
    NORM_1C = transforms.Normalize((0.5,), (0.5,))
    NORM_3C = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    if name == 'mnist':
        # (You already have this; included for completeness.)
        tfm = transforms.Compose([transforms.ToTensor(),
                                  transforms.Normalize((0.1307,), (0.3081,))])
        train = datasets.MNIST(data_root, train=True, download=True, transform=tfm)
        test  = datasets.MNIST(data_root, train=False, download=True, transform=tfm)
        inp, ncls = 28*28, 10

    elif name == 'fashionmnist':
        tfm = transforms.Compose([transforms.ToTensor(), NORM_1C])
        train = datasets.FashionMNIST(data_root, train=True, download=True, transform=tfm)
        test  = datasets.FashionMNIST(data_root, train=False, download=True, transform=tfm)
        inp, ncls = 28*28, 10
        
    elif name == 'cifar10':
        tfm = transforms.Compose([
                transforms.Resize(224),
                transforms.ToTensor(),
                NORM_3C
            ])
        train = datasets.CIFAR10(data_root, train=True,  download=True, transform=tfm)
        test  = datasets.CIFAR10(data_root, train=False, download=True, transform=tfm)
        inp, ncls = 224*224*3, 10

    elif name == 'cifar100':
        tfm = transforms.Compose([
                transforms.Resize(224),
                transforms.ToTensor(),
                NORM_3C
            ])
        train = datasets.CIFAR100(data_root, train=True,  download=True, transform=tfm)
        test  = datasets.CIFAR100(data_root, train=False, download=True, transform=tfm)
        inp, ncls = 224*224*3, 100
        
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True,  num_workers=2, pin_memory=True)
    test_loader  = DataLoader(test,  batch_size=test_batch_size, shuffle=False, num_workers=2, pin_memory=True)
    return train_loader, test_loader, inp, ncls, meta


def train(model, device, train_loader, optimizer, criterion):
    model.train()
    train_loss = 0
    correct = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

    avg_loss = train_loss / len(train_loader)
    accuracy = 100. * correct / len(train_loader.dataset)
    return avg_loss, accuracy

def test(model, device, test_loader, criterion, times=1):
    model.eval()
    accuracy_list = []
    loss_list = []
    for _ in range(times):
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                test_loss += criterion(output, target).item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
                
        test_loss /= len(test_loader.dataset)
        accuracy = 100. * correct / len(test_loader.dataset)
        accuracy_list.append(accuracy)
        loss_list.append(test_loss)
    if times == 1:
        return test_loss, accuracy
    else:
        return loss_list, accuracy_list, sum(accuracy_list) / times

In [3]:
datasets_name = 'cifar10'
train_loader, test_loader, input_size, num_classes, meta = get_loaders(datasets_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

cifar_10_model_configs = {
    'resnet18': {
        'model': resnet18,
        'pretrained': False,
        'input_size': input_size,
        'num_classes': num_classes,
        'lr': 0.01,
        'epochs': 10
    },
    'resnet50': {
        'model': resnet50,
        'pretrained': False,
        'input_size': input_size,
        'num_classes': num_classes,
        'lr': 1e-3,
        'epochs': 20,
    },
    'resnet101': {
        'model': resnet101,
        'pretrained': False,
        'input_size': input_size,
        'num_classes': num_classes,
        'lr': 3e-4,
        'epochs': 40,
    }
}

criterion = nn.CrossEntropyLoss()

In [4]:
for model_name, config in cifar_10_model_configs.items():
    model = config['model'](pretrained=config['pretrained'], num_classes=config['num_classes']).to(device)
    optimizer = optim.Adam(model.parameters())
    
    for epoch in range(1, config['epochs'] + 1):
        train_loss, train_acc = train(model, device, train_loader, optimizer, criterion)
        test_loss, test_acc = test(model, device, test_loader, criterion)
        print(f"Model: {model_name}, Epoch: {epoch}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}, Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}")
        
    model_path = f"./models/{datasets_name}/{model_name}.pth"
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    torch.save(model.state_dict(), model_path)



Model: resnet18, Epoch: 1, Train Loss: 1.3452, Train Acc: 50.77, Test Loss: 0.0012, Test Acc: 58.70
Model: resnet18, Epoch: 2, Train Loss: 0.8120, Train Acc: 71.27, Test Loss: 0.0008, Test Acc: 73.11
Model: resnet18, Epoch: 3, Train Loss: 0.6069, Train Acc: 78.64, Test Loss: 0.0009, Test Acc: 72.68
Model: resnet18, Epoch: 4, Train Loss: 0.4740, Train Acc: 83.53, Test Loss: 0.0007, Test Acc: 76.88
Model: resnet18, Epoch: 5, Train Loss: 0.3797, Train Acc: 86.74, Test Loss: 0.0007, Test Acc: 78.60
Model: resnet18, Epoch: 6, Train Loss: 0.2926, Train Acc: 89.75, Test Loss: 0.0006, Test Acc: 78.75
Model: resnet18, Epoch: 7, Train Loss: 0.2165, Train Acc: 92.51, Test Loss: 0.0007, Test Acc: 80.76
Model: resnet18, Epoch: 8, Train Loss: 0.1561, Train Acc: 94.46, Test Loss: 0.0006, Test Acc: 82.07
Model: resnet18, Epoch: 9, Train Loss: 0.1081, Train Acc: 96.24, Test Loss: 0.0006, Test Acc: 82.80
Model: resnet18, Epoch: 10, Train Loss: 0.0863, Train Acc: 96.93, Test Loss: 0.0006, Test Acc: 83.85

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torch.optim.lr_scheduler import CosineAnnealingLR
import copy
import time
from tqdm import tqdm

def mask_block(module: nn.Module, beta=1.0, method='magnitude') -> tuple:
    """Apply neff-based masking to a single module's weights."""
    x = module.weight.data
    original_shape = x.shape
    x = x.view(-1)
    
    if method == 'mean':
        x = x - torch.mean(x)
    
    # L1 normalization
    x_abs = torch.abs(x)
    x_norm = x_abs / torch.sum(x_abs)
    
    # Calculate effective number of parameters
    neff = 1 / torch.sum(x_norm ** 2)
    
    # Determine how many weights to keep
    r_neff = torch.floor(beta * neff)
    r_neff = r_neff.clamp(min=1, max=len(x)-1).long()
    
    # Sort by magnitude (using absolute values)
    sorted_vals, indices = torch.sort(x_abs, descending=True)
    
    # Create mask for top r_neff weights
    range_tensor = torch.arange(len(x), device=x.device)
    sorted_mask = range_tensor < r_neff
    
    # Scatter back to original positions
    mask = torch.zeros_like(x, dtype=torch.bool)
    mask.scatter_(0, indices, sorted_mask)
    
    # Reshape mask to match original weight shape
    mask = mask.view(original_shape)
    
    return mask, neff.item()

def model_block(model, renormalize=False, beta=1.0, method='magnitude'):
    """Apply neff-based pruning to entire model."""
    model = copy.deepcopy(model)
    total_params = 0
    pruned_params = 0
    layer_stats = {}
    
    for name, module in model.named_modules():
        # Apply to both Linear and Conv2d layers
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            mask, neff = mask_block(module, beta=beta, method=method)
            mask = mask.to(module.weight.device)
            
            # Count parameters
            total_params += module.weight.numel()
            pruned_params += (~mask).sum().item()
            
            # Store stats
            layer_stats[name] = {
                'neff': neff,
                'total': module.weight.numel(),
                'kept': mask.sum().item(),
                'sparsity': 1 - (mask.sum().item() / module.weight.numel())
            }
            
            with torch.no_grad():
                if renormalize:
                    # For Conv2d, sum over input channels (dim 1)
                    # For Linear, sum over input features (dim 1)
                    pre = module.weight.abs().sum(dim=1 if isinstance(module, nn.Conv2d) else 0, keepdim=True)
                    module.weight *= mask
                    post = module.weight.abs().sum(dim=1 if isinstance(module, nn.Conv2d) else 0, keepdim=True)
                    # Avoid division by zero
                    scale = torch.where(post > 0, pre / post, torch.ones_like(pre))
                    module.weight *= scale
                else:
                    module.weight *= mask
    
    overall_sparsity = pruned_params / total_params if total_params > 0 else 0
    return model, layer_stats, overall_sparsity

def evaluate(model, loader, criterion, device):
    """Evaluate model accuracy and loss."""
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            total_loss += criterion(output, target).item() * data.size(0)
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += data.size(0)
    
    return total_loss / total, 100. * correct / total

def fine_tune(model, train_loader, test_loader, epochs=10, lr=0.001, device='cuda'):
    """Fine-tune pruned model."""
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    scheduler = CosineAnnealingLR(optimizer, T_max=epochs)
    
    model.to(device)
    best_acc = 0
    
    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0
        correct = 0
        total = 0
        
        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}')
        for data, target in pbar:
            data, target = data.to(device), target.to(device)
            
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += data.size(0)
            
            pbar.set_postfix({'loss': train_loss/total, 'acc': 100.*correct/total})
        
        # Evaluation
        test_loss, test_acc = evaluate(model, test_loader, criterion, device)
        
        print(f'Epoch {epoch+1}: Train Acc: {100.*correct/total:.2f}%, '
              f'Test Acc: {test_acc:.2f}%, Test Loss: {test_loss:.4f}')
        
        if test_acc > best_acc:
            best_acc = test_acc
        
        scheduler.step()
    
    return model, best_acc

def experiment_pruning(model_name='vgg16', dataset='cifar10', beta_values=[0.3, 0.5, 0.7, 1.0], 
                       renormalize=True, fine_tune_epochs=10, device='cuda'):
    """Run pruning experiments with different beta values."""
    
    # Import dataloader function
    from torchvision import datasets, transforms
    from torch.utils.data import DataLoader
    
    # Get data loaders
    print(f"Loading {dataset} dataset...")
    train_loader, test_loader, _, num_classes, _ = get_loaders(dataset)
    
    # Load pretrained model
    print(f"Loading {model_name} model...")
    if model_name == 'vgg16':
        model = models.vgg16(pretrained=True)
        # Adjust final layer for dataset
        model.classifier[-1] = nn.Linear(4096, num_classes)
    elif model_name == 'resnet50':
        model = models.resnet50(pretrained=True)
        # Adjust final layer for dataset
        model.fc = nn.Linear(2048, num_classes)
    else:
        raise ValueError(f"Unknown model: {model_name}")
    
    model = model.to(device)
    
    # Evaluate original model
    print("Evaluating original model...")
    criterion = nn.CrossEntropyLoss()
    orig_loss, orig_acc = evaluate(model, test_loader, criterion, device)
    print(f"Original Model - Accuracy: {orig_acc:.2f}%, Loss: {orig_loss:.4f}")
    
    results = {'original': {'accuracy': orig_acc, 'loss': orig_loss, 'sparsity': 0}}
    
    # Test different beta values
    for beta in beta_values:
        print(f"\n{'='*50}")
        print(f"Testing beta = {beta}")
        print(f"{'='*50}")
        
        # Apply pruning
        print("Applying neff-based pruning...")
        pruned_model, layer_stats, overall_sparsity = model_block(
            model, renormalize=renormalize, beta=beta, method='magnitude'
        )
        
        # Print layer statistics
        print(f"\nLayer-wise pruning statistics (beta={beta}):")
        for name, stats in layer_stats.items():
            print(f"  {name}: neff={stats['neff']:.1f}, "
                  f"kept={stats['kept']}/{stats['total']} "
                  f"(sparsity={stats['sparsity']*100:.1f}%)")
        print(f"Overall sparsity: {overall_sparsity*100:.2f}%")
        
        # Evaluate pruned model before fine-tuning
        pruned_loss, pruned_acc = evaluate(pruned_model, test_loader, criterion, device)
        print(f"\nPruned Model (before fine-tuning) - Accuracy: {pruned_acc:.2f}%, Loss: {pruned_loss:.4f}")
        
        # Fine-tune if specified
        if fine_tune_epochs > 0:
            print(f"\nFine-tuning for {fine_tune_epochs} epochs...")
            pruned_model, best_acc = fine_tune(
                pruned_model, train_loader, test_loader, 
                epochs=fine_tune_epochs, lr=0.001, device=device
            )
            print(f"Best accuracy after fine-tuning: {best_acc:.2f}%")
            
            results[f'beta_{beta}'] = {
                'accuracy_before': pruned_acc,
                'accuracy_after': best_acc,
                'sparsity': overall_sparsity * 100,
                'layer_stats': layer_stats
            }
        else:
            results[f'beta_{beta}'] = {
                'accuracy': pruned_acc,
                'loss': pruned_loss,
                'sparsity': overall_sparsity * 100,
                'layer_stats': layer_stats
            }
    
    return results

# DataLoader function (from your code)
def get_loaders(dataset_name, batch_size=128, test_batch_size=1000, data_root='./data'):
    """Returns: train_loader, test_loader, input_size, num_classes, meta (dict)"""
    from torchvision import datasets, transforms
    from torch.utils.data import DataLoader
    
    name = dataset_name.lower()
    meta = {}

    # Generic normalizations
    NORM_1C = transforms.Normalize((0.5,), (0.5,))
    NORM_3C = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    if name == 'mnist':
        tfm = transforms.Compose([transforms.ToTensor(),
                                  transforms.Normalize((0.1307,), (0.3081,))])
        train = datasets.MNIST(data_root, train=True, download=True, transform=tfm)
        test  = datasets.MNIST(data_root, train=False, download=True, transform=tfm)
        inp, ncls = 28*28, 10

    elif name == 'fashionmnist':
        tfm = transforms.Compose([transforms.ToTensor(), NORM_1C])
        train = datasets.FashionMNIST(data_root, train=True, download=True, transform=tfm)
        test  = datasets.FashionMNIST(data_root, train=False, download=True, transform=tfm)
        inp, ncls = 28*28, 10
        
    elif name == 'cifar10':
        tfm = transforms.Compose([
                transforms.Resize(224),
                transforms.ToTensor(),
                NORM_3C
            ])
        train = datasets.CIFAR10(data_root, train=True,  download=True, transform=tfm)
        test  = datasets.CIFAR10(data_root, train=False, download=True, transform=tfm)
        inp, ncls = 224*224*3, 10

    elif name == 'cifar100':
        tfm = transforms.Compose([
                transforms.Resize(224),
                transforms.ToTensor(),
                NORM_3C
            ])
        train = datasets.CIFAR100(data_root, train=True,  download=True, transform=tfm)
        test  = datasets.CIFAR100(data_root, train=False, download=True, transform=tfm)
        inp, ncls = 224*224*3, 100
        
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True,  num_workers=2, pin_memory=True)
    test_loader  = DataLoader(test,  batch_size=test_batch_size, shuffle=False, num_workers=2, pin_memory=True)
    return train_loader, test_loader, inp, ncls, meta

if __name__ == "__main__":
    # Example usage
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # Run experiments
    results = experiment_pruning(
        model_name='vgg16',  # or 'resnet50'
        dataset='cifar10',
        beta_values=[1.0],
        renormalize=True,
        fine_tune_epochs=5,  # Reduce for faster testing
        device=device
    )
    
    # Print summary
    print("\n" + "="*60)
    print("EXPERIMENT SUMMARY")
    print("="*60)
    for config, metrics in results.items():
        if config == 'original':
            print(f"{config}: Acc={metrics['accuracy']:.2f}%, Sparsity=0%")
        else:
            if 'accuracy_after' in metrics:
                print(f"{config}: Before={metrics['accuracy_before']:.2f}%, "
                      f"After={metrics['accuracy_after']:.2f}%, "
                      f"Sparsity={metrics['sparsity']:.1f}%")
            else:
                print(f"{config}: Acc={metrics['accuracy']:.2f}%, "
                      f"Sparsity={metrics['sparsity']:.1f}%")

Loading cifar10 dataset...
Loading vgg16 model...




Evaluating original model...
Original Model - Accuracy: 10.20%, Loss: 2.3202

Testing beta = 1.0
Applying neff-based pruning...

Layer-wise pruning statistics (beta=1.0):
  features.0: neff=920.7, kept=920/1728 (sparsity=46.8%)
  features.2: neff=17836.2, kept=17836/36864 (sparsity=51.6%)
  features.5: neff=38119.3, kept=38119/73728 (sparsity=48.3%)
  features.7: neff=82056.5, kept=82056/147456 (sparsity=44.4%)
  features.10: neff=166045.7, kept=166045/294912 (sparsity=43.7%)
  features.12: neff=349219.0, kept=349219/589824 (sparsity=40.8%)
  features.14: neff=342429.1, kept=342429/589824 (sparsity=41.9%)
  features.17: neff=679219.2, kept=679219/1179648 (sparsity=42.4%)
  features.19: neff=1421596.8, kept=1421596/2359296 (sparsity=39.7%)
  features.21: neff=1420363.6, kept=1420363/2359296 (sparsity=39.8%)
  features.24: neff=1438711.9, kept=1438711/2359296 (sparsity=39.0%)
  features.26: neff=1467173.1, kept=1467173/2359296 (sparsity=37.8%)
  features.28: neff=1423944.0, kept=1423944/

Epoch 1/5: 100%|██████████| 391/391 [01:34<00:00,  4.13it/s, loss=0.00433, acc=81]  


Epoch 1: Train Acc: 80.97%, Test Acc: 89.53%, Test Loss: 0.3062


Epoch 2/5: 100%|██████████| 391/391 [01:35<00:00,  4.11it/s, loss=0.00208, acc=91]  


Epoch 2: Train Acc: 90.97%, Test Acc: 90.68%, Test Loss: 0.2749


Epoch 3/5: 100%|██████████| 391/391 [01:33<00:00,  4.17it/s, loss=0.00142, acc=93.6]


Epoch 3: Train Acc: 93.62%, Test Acc: 91.59%, Test Loss: 0.2448


Epoch 4/5: 100%|██████████| 391/391 [01:33<00:00,  4.18it/s, loss=0.000986, acc=95.6]


Epoch 4: Train Acc: 95.62%, Test Acc: 92.20%, Test Loss: 0.2384


Epoch 5/5: 100%|██████████| 391/391 [01:34<00:00,  4.13it/s, loss=0.000789, acc=96.5]


Epoch 5: Train Acc: 96.55%, Test Acc: 92.40%, Test Loss: 0.2336
Best accuracy after fine-tuning: 92.40%

EXPERIMENT SUMMARY
original: Acc=10.20%, Sparsity=0%
beta_1.0: Before=11.01%, After=92.40%, Sparsity=38.6%
