In [17]:
# import & set-up
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [18]:
# pixel normalization
transform = transforms.Compose([
    transforms.ToTensor()
])

# download dataset
full_train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_size = int(0.9 * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size
train_data, val_data = random_split(full_train_dataset, [train_size, val_size])

train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=1000, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

print(f"Train set size: {len(train_dataset)}, Test set size: {len(test_dataset)}")

Train set size: 60000, Test set size: 10000


In [19]:
# LeNet Model
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [20]:
# train, validation, test
def train(model, device, train_loader, optimizer, criterion, epoch):
    model.train()
    total_loss = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)

def validate(model, device, val_loader, criterion):
    model.eval()
    correct = 0
    loss = 0
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss += criterion(output, target).item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    return loss / len(val_loader), 100. * correct / len(val_loader.dataset)

def test(model, device, test_loader, criterion):
    model.eval()
    correct = 0
    loss = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss += criterion(output, target).item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    return loss / len(test_loader), 100. * correct / len(test_loader.dataset)


In [21]:
# prunning
def prune_by_percentile(model, prune_percent, current_mask):
    all_weights = []
    for name, param in model.named_parameters():
        if name in current_mask:
            masked_weights = param.data[current_mask[name] == 1]
            all_weights += list(masked_weights.abs().cpu().detach().numpy())
    threshold = np.percentile(np.array(all_weights), prune_percent)

    new_mask = {}
    for name, param in model.named_parameters():
        if name in current_mask:
            new_mask[name] = (param.data.abs() > threshold).float() * current_mask[name]
    return new_mask

def apply_mask_and_reset(model, initial_weights, mask_dict):
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in mask_dict:
                param.copy_(initial_weights[name] * mask_dict[name])

In [22]:
# iterative prunning
def run_experiment(prune_percent=20, max_iterations=5, early_stop_patience=3):
    all_final_accuracies = []

    for trial in range(5):
        print(f"\n Trial {trial + 1}")
        model = LeNet().to(device)
        initial_weights = {k: v.clone() for k, v in model.state_dict().items()}
        current_mask = {k: torch.ones_like(v) for k, v in initial_weights.items() if 'weight' in k}

        for iteration in range(max_iterations):
            print(f"\n Iteration {iteration + 1}/{max_iterations}")
            model.load_state_dict(initial_weights)
            apply_mask_and_reset(model, initial_weights, current_mask)

            optimizer = optim.Adam(model.parameters(), lr=0.001)
            criterion = nn.CrossEntropyLoss()

            best_val_acc = 0
            patience = 0
            epoch = 0
            max_epoch = 50
            last_test_acc = 0

            while epoch < max_epoch and patience < early_stop_patience:
                train_loss = train(model, device, train_loader, optimizer, criterion, epoch)
                val_loss, val_acc = validate(model, device, val_loader, criterion)
                test_loss, test_acc = test(model, device, test_loader, criterion)

                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    patience = 0
                else:
                    patience += 1

                epoch += 1
                print(f"[Epoch {epoch}] Val Acc: {val_acc:.2f}%, Test Acc: {test_acc:.2f}%")
                last_test_acc = test_acc

            all_final_accuracies.append(last_test_acc)
            current_mask = prune_by_percentile(model, prune_percent, current_mask)

    mean_acc = np.mean(all_final_accuracies)
    std_acc = np.std(all_final_accuracies)
    print(f"\n Final Test Accuracy over {5 * max_iterations} runs: {mean_acc:.2f}% ± {std_acc:.2f}%")
    return all_final_accuracies


In [None]:
result_accs = run_experiment(prune_percent=20, max_iterations=5, early_stop_patience=3)


 Trial 1

 Iteration 1/5
[Epoch 1] Val Acc: 96.88%, Test Acc: 97.06%
[Epoch 2] Val Acc: 97.68%, Test Acc: 98.19%
[Epoch 3] Val Acc: 98.35%, Test Acc: 98.59%
[Epoch 4] Val Acc: 98.60%, Test Acc: 98.48%
[Epoch 5] Val Acc: 98.58%, Test Acc: 98.54%
[Epoch 6] Val Acc: 98.58%, Test Acc: 98.53%
[Epoch 7] Val Acc: 98.67%, Test Acc: 98.85%
[Epoch 8] Val Acc: 98.67%, Test Acc: 98.92%
[Epoch 9] Val Acc: 98.63%, Test Acc: 98.63%
[Epoch 10] Val Acc: 98.43%, Test Acc: 98.61%

 Iteration 2/5
[Epoch 1] Val Acc: 96.87%, Test Acc: 97.44%
[Epoch 2] Val Acc: 97.85%, Test Acc: 98.20%
[Epoch 3] Val Acc: 98.47%, Test Acc: 98.72%
[Epoch 4] Val Acc: 98.60%, Test Acc: 98.67%
[Epoch 5] Val Acc: 98.58%, Test Acc: 98.77%
[Epoch 6] Val Acc: 98.67%, Test Acc: 98.85%
[Epoch 7] Val Acc: 98.65%, Test Acc: 98.87%
[Epoch 8] Val Acc: 98.78%, Test Acc: 98.84%
[Epoch 9] Val Acc: 98.70%, Test Acc: 99.00%
[Epoch 10] Val Acc: 99.05%, Test Acc: 98.98%
[Epoch 11] Val Acc: 99.05%, Test Acc: 99.03%
[Epoch 12] Val Acc: 98.38%, Tes

In [13]:
# final plot

plt.plot(result_accs, marker='o')
plt.xlabel('Iteration')
plt.ylabel('Test Accuracy (%)')
plt.title('Test Accuracy across Iterative Pruning (5 Trials)')
plt.grid(True)
plt.show()