In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.nn.utils.prune as prune
from torchvision import datasets, transforms
import numpy as np
from thop import profile

In [None]:
# Define the LeNet-5 model for MNIST
class LeNet5(nn.Module):
    """LeNet-5 style CNN for single-channel images.

    Two 5x5 convolutions (20 and 50 filters), each followed by ReLU and a
    2x2 max pool, feed a three-layer fully connected classifier. The
    flattening step assumes the second pooling stage yields 50 x 4 x 4
    features per sample, which is what a 28x28 input produces.

    Args:
        num_classes: Number of output scores produced by the last layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(1, 20, kernel_size=5)
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5)
        # Fully connected classifier head.
        self.fc1 = nn.Linear(50 * 4 * 4, 800)
        self.fc2 = nn.Linear(800, 500)
        self.fc3 = nn.Linear(500, num_classes)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for a batch ``x``."""
        # conv -> ReLU -> 2x2 max pool, applied once per convolutional layer.
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2)
        # Collapse channel and spatial dimensions into one feature vector.
        flat = x.view(-1, 50 * 4 * 4)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

def get_lenet5():
    """Build and return a ``LeNet5`` with its default constructor arguments."""
    model = LeNet5()
    return model

# MNIST preprocessing: convert each image to a tensor, then normalise it with
# the single-channel mean / std constants given below.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# The datasets are read from (and, if missing, downloaded to) ../data.
# Change `root` if the data is stored somewhere else.
trainset = datasets.MNIST(
    root='../data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
)

testset = datasets.MNIST(
    root='../data',
    train=False,
    download=True,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=100,
    shuffle=False,
    num_workers=4,
)

In [None]:
# Training function for baseline model
def train_baseline_model(net, trainloader, criterion, optimizer, epochs=40):
    """Train ``net`` for ``epochs`` epochs with a step learning-rate decay.

    The learning rate of every optimizer parameter group is divided by 10 at
    the start of epochs 20 and 30 (0-based). The change is made in place on
    ``optimizer`` and is still in effect when the function returns.

    Args:
        net: Model to train; it is switched to training mode.
        trainloader: Iterable of ``(inputs, labels)`` batches supporting ``len()``.
        criterion: Callable mapping ``(outputs, labels)`` to a scalar loss.
        optimizer: Optimizer that updates the parameters of ``net``.
        epochs: Number of passes over ``trainloader``.
    """
    # Follow the model's own device instead of hard-coding .cuda(), so the
    # function also works when the model lives on the CPU.
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    net.train()
    for epoch in range(epochs):
        if epoch in (20, 30):
            for param_group in optimizer.param_groups:
                param_group['lr'] /= 10
        running_loss = 0.0
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        print(f'Epoch [{epoch + 1}/{epochs}], Loss: {running_loss / len(trainloader):.4f}')

# Testing function for baseline model
def test_model(net, testloader):
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.cuda(), labels.cuda()
            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print(f'Accuracy: {accuracy:.2f}%')
    return accuracy

# Save model function
def save_model(net, path):
    """Write the state dict of ``net`` to ``path`` using ``torch.save``."""
    weights = net.state_dict()
    torch.save(weights, path)

# Load model function
def load_model(net, path):
    """Load a state dict stored at ``path`` into ``net`` in place.

    The checkpoint tensors are mapped onto the device that ``net`` currently
    lives on, so a file written from a CUDA model can also be restored on a
    CPU-only machine.

    Args:
        net: Model whose parameters and buffers are overwritten.
        path: File previously written with ``torch.save(state_dict, path)``.
    """
    first_param = next(net.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    state_dict = torch.load(path, map_location=target_device)
    net.load_state_dict(state_dict)

In [None]:
# Pruning method definition
class PruningMethod(prune.BasePruningMethod):
    """Structured pruning that zeroes explicitly chosen slices of a tensor.

    Args:
        filters_selected: Indices (list, NumPy array, tensor or single int) of
            the slices to prune along ``dim``.
        dim: Dimension the indices refer to. The default ``0`` selects whole
            output filters of a convolution / rows of a linear weight.
    """

    PRUNING_TYPE = 'structured'

    def __init__(self, filters_selected, dim=0):
        self.filters_selected = filters_selected
        self.dim = dim

    def compute_mask(self, t, default_mask):
        """Return a copy of ``default_mask`` with the selected slices set to 0.

        Args:
            t: Tensor being pruned; the mask itself is derived from
                ``default_mask`` only.
            default_mask: Mask to start from; it is not modified.

        Returns:
            A new mask with the same shape as ``default_mask``.
        """
        mask = default_mask.clone()
        if mask.dim() == 0:
            # A scalar has no slices to remove.
            return mask

        index = torch.as_tensor(self.filters_selected, device=mask.device)
        if index.dtype != torch.bool:
            # Normalise lists / arrays / scalars to a 1-D integer index. This
            # also makes an empty selection a harmless no-op.
            index = index.long().reshape(-1)

        # Index along self.dim for any tensor rank, instead of special-casing
        # ranks 1, 2 and 4 and always pruning dimension 0.
        selector = [slice(None)] * mask.dim()
        selector[self.dim] = index
        mask[tuple(selector)] = 0
        return mask

def prune_filters_by_similarity(net, prune_ratio_conv1, prune_ratio_conv2):
    """Mask out the most redundant filters of the conv1 / conv2 layers.

    For every ``nn.Conv2d`` whose module name contains ``conv1`` or ``conv2``,
    filter pairs are ranked by cosine similarity. Walking down that ranking,
    the second filter of each pair is selected until
    ``int(ratio * out_channels)`` distinct filters are chosen. The selection
    is then applied with ``PruningMethod``. Other Conv2d layers have no ratio
    assigned and are skipped.

    Args:
        net: Model whose convolutions are pruned in place.
        prune_ratio_conv1: Fraction of conv1's ``out_channels`` to prune per call.
        prune_ratio_conv2: Fraction of conv2's ``out_channels`` to prune per call.
    """
    for name, module in net.named_modules():
        if not isinstance(module, nn.Conv2d):
            continue
        if 'conv1' in name:
            prune_ratio = prune_ratio_conv1
        elif 'conv2' in name:
            prune_ratio = prune_ratio_conv2
        else:
            # No ratio is defined for this layer, so leave it untouched.
            continue
        # int() truncates, so a ratio below 1 / out_channels prunes nothing.
        num_filters_to_prune = int(prune_ratio * module.out_channels)

        weight_orig = getattr(module, 'weight_orig', None)
        weight_mask = getattr(module, 'weight_mask', None)
        if weight_orig is not None and weight_mask is not None:
            # Already pruned: rebuild the effective weight from the stored
            # parameter and mask. `module.weight` is only refreshed during a
            # forward pass and may be stale right after load_state_dict().
            weight = (weight_orig * weight_mask).detach()
            active = weight_mask.reshape(module.out_channels, -1).sum(dim=1) != 0
        else:
            weight = module.weight.detach()
            active = torch.ones(module.out_channels, dtype=torch.bool, device=weight.device)

        # When pruning is stacked, torch passes compute_mask() only the
        # filters that are still active. Ranking those alone makes the
        # selected indices line up with what the pruning method receives.
        filters = weight.reshape(module.out_channels, -1)[active]
        num_active = filters.size(0)

        chosen = []
        if num_filters_to_prune > 0 and num_active > 1:
            cos_sim = F.cosine_similarity(filters.unsqueeze(1), filters.unsqueeze(0), dim=-1).cpu().numpy()
            # Each unordered pair appears once in the strict upper triangle.
            pair_rows, pair_cols = np.triu_indices(num_active, k=1)
            ranking = np.argsort(-cos_sim[pair_rows, pair_cols], kind='stable')
            already_chosen = set()
            for pair in ranking:
                candidate = int(pair_cols[pair])
                if candidate in already_chosen:
                    # The same filter can be the second member of several pairs.
                    continue
                already_chosen.add(candidate)
                chosen.append(candidate)
                if len(chosen) == num_filters_to_prune:
                    break
        selected_filters = np.asarray(chosen, dtype=np.int64)

        PruningMethod.apply(module, 'weight', filters_selected=selected_filters)

        print(f'Pruned {len(selected_filters)} filters from {module}')
        print(f'Number of remaining filters: {num_active - len(selected_filters)}')

def calculate_similarity(t):
    """Return the pairwise cosine similarity of the slices ``t[i]``.

    Every ``t[i]`` is flattened to a vector. The result is a square matrix
    holding the cosine similarity of vectors ``i`` and ``j`` at ``[i, j]`` for
    ``i < j``; the diagonal and the lower triangle are zero.

    The matrix is computed with a single matrix product on the device (and in
    the dtype) of ``t``, and stays differentiable with respect to ``t``.

    Args:
        t: Floating-point tensor with at least one dimension.

    Returns:
        Tensor of shape ``(t.size(0), t.size(0))``.
    """
    flat = t.reshape(t.size(0), -1)
    # Unit-length rows turn the Gram matrix into cosine similarities. eps
    # matches F.cosine_similarity's default, so zero rows give 0, not NaN.
    unit = F.normalize(flat, p=2, dim=1, eps=1e-8)
    return torch.triu(unit @ unit.t(), diagonal=1)

def calculate_flops_and_params(model, input_size):
    """Profile ``model`` on one random single-channel input.

    Args:
        model: Model to profile.
        input_size: Spatial size of the input, e.g. ``(28, 28)``; the profiled
            tensor has shape ``(1, 1, *input_size)``.

    Returns:
        The ``(flops, params)`` pair reported by ``profile``.
    """
    # Create the dummy input on the model's own device instead of hard-coding
    # .cuda(), so profiling also works for a CPU model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    dummy_input = torch.randn(1, 1, *input_size, device=device)
    # NOTE(review): mask-based pruning keeps tensor shapes unchanged, so these
    # counts presumably do not shrink after pruning — confirm against thop.
    flops, params = profile(model, inputs=(dummy_input,))
    return flops, params

def _filter_similarity_penalty(net):
    """Return the sum over Conv2d layers of exp(-sum of the top 2% filter similarities)."""
    penalty = 0
    for module in net.modules():
        if isinstance(module, nn.Conv2d):
            # Use module.weight itself rather than .data: the detached tensor
            # made the penalty a constant that contributed no gradient.
            filters = module.weight.reshape(module.out_channels, -1)
            sim_values = calculate_similarity(filters).reshape(-1)

            # Keep only the largest 2% of the similarity entries.
            top_count = int(0.02 * sim_values.numel())
            top_idx = torch.topk(sim_values, top_count, largest=True).indices
            penalty = penalty + torch.exp(-torch.sum(sim_values[top_idx]))
    return penalty


def iterative_pruning(net, trainloader, testloader, criterion, optimizer, prune_ratio_conv1, prune_ratio_conv2, prune_limit_conv1, prune_limit_conv2, alpha, beta, epochs=40):
    """Alternate similarity-based filter pruning with regularised fine-tuning.

    Each round reloads the best checkpoint (from the second round on), prunes
    conv filters, reports FLOPs / parameters, fine-tunes for ``epochs`` epochs
    on ``criterion + alpha * similarity penalty`` and evaluates on
    ``testloader``. The best model so far is written to ``best_model.pth`` and
    the final state of ``net`` to ``pruned_model.pth``.

    Args:
        net: Model to prune; modified in place.
        trainloader: Iterable of ``(inputs, labels)`` training batches.
        testloader: Iterable of ``(inputs, labels)`` evaluation batches.
        criterion: Callable mapping ``(outputs, labels)`` to a scalar loss.
        optimizer: Optimizer over the parameters of ``net``. Its learning rate
            on entry is the starting rate of every pruning round.
        prune_ratio_conv1: Per-round pruning ratio forwarded for conv1.
        prune_ratio_conv2: Per-round pruning ratio forwarded for conv2.
        prune_limit_conv1: With ``prune_ratio_conv1``, bounds the round count.
        prune_limit_conv2: With ``prune_ratio_conv2``, bounds the round count.
        alpha: Weight of the filter-similarity penalty in the training loss.
        beta: Accepted for interface compatibility; not used by this function.
        epochs: Fine-tuning epochs per pruning round.
    """
    best_accuracy = 0.0
    best_model_path = 'best_model.pth'
    initial_accuracy = test_model(net, testloader)

    # Follow the model's own device instead of hard-coding .cuda().
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # Remember the starting learning rates: the schedule below divides them
    # by 100 per round, which would otherwise compound across rounds.
    base_lrs = [param_group['lr'] for param_group in optimizer.param_groups]

    num_rounds = int(max(prune_limit_conv1 / prune_ratio_conv1, prune_limit_conv2 / prune_ratio_conv2))
    for prune_iter in range(num_rounds):
        print(f'Pruning Iteration {prune_iter + 1}')

        # Load the best model before pruning
        if prune_iter > 0:
            load_model(net, best_model_path)

        prune_filters_by_similarity(net, prune_ratio_conv1, prune_ratio_conv2)

        flops, params = calculate_flops_and_params(net, (28, 28))
        print(f'FLOPs after pruning: {flops / 1e6:.2f}M')
        print(f'Parameters after pruning: {params / 1e6:.2f}M')

        # Restart the step schedule for this round.
        for param_group, base_lr in zip(optimizer.param_groups, base_lrs):
            param_group['lr'] = base_lr

        for epoch in range(epochs):
            if epoch in (10, 20):
                for param_group in optimizer.param_groups:
                    param_group['lr'] /= 10
            net.train()
            running_loss = 0.0
            old_running_loss = 0.0
            for inputs, labels in trainloader:
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()
                outputs = net(inputs)
                old_loss = criterion(outputs, labels)
                old_running_loss += old_loss.item()

                # Computed after the forward pass so that, for pruned layers,
                # module.weight is the freshly masked weight of this step.
                regularization_term = _filter_similarity_penalty(net)

                new_loss = old_loss + alpha * regularization_term
                new_loss.backward()
                optimizer.step()

                running_loss += new_loss.item()
            print(f'Pruning Iteration {prune_iter + 1}, Epoch [{epoch + 1}/{epochs}], Old Loss: {old_running_loss / len(trainloader):.4f}, New Loss: {running_loss / len(trainloader):.4f}')

        accuracy = test_model(net, testloader)

        # Save the best model based on accuracy
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            save_model(net, best_model_path)
            print(f'Best model saved with accuracy: {best_accuracy:.2f}%')

        # Stop if the accuracy drop is less than 2%
        # NOTE(review): this ends pruning as soon as the best accuracy is
        # within 2 points of the initial one. If the intent was to keep
        # pruning until the drop exceeds 2%, the comparison is inverted — confirm.
        if abs(initial_accuracy - best_accuracy) < 2:
            print(f'Pruning stopped as the accuracy drop is less than 2%. Final accuracy: {best_accuracy:.2f}%')
            break

    # Save the final pruned model (the current state of `net`, which is not
    # necessarily the best checkpoint).
    save_model(net, 'pruned_model.pth')

In [None]:
# Main script: train a baseline LeNet-5, then prune it iteratively.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

net = get_lenet5().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    net.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=1e-4,
)

# Baseline training.
train_baseline_model(net, trainloader, criterion, optimizer, epochs=40)

# Baseline evaluation. The weights are stored twice: as the baseline
# checkpoint and as the initial "best" checkpoint used during pruning.
initial_accuracy = test_model(net, testloader)
for checkpoint_path in ('baseline_model.pth', 'best_model.pth'):
    save_model(net, checkpoint_path)

# Pruning configuration.
# NOTE(review): conv1 has 20 output channels and prune_filters_by_similarity
# uses int(ratio * out_channels), so 0.04 truncates to 0 filters per round —
# confirm that leaving conv1 unpruned is intended.
prune_ratio_conv1 = 0.04
prune_ratio_conv2 = 0.12
prune_limit_conv1 = 0.999
prune_limit_conv2 = 0.999
alpha = 0.01
beta = 0.02  # passed through to iterative_pruning, which does not read it

# Iterative pruning with the similarity regulariser. The optimizer used for
# baseline training is reused, including its decayed learning rate.
iterative_pruning(
    net,
    trainloader,
    testloader,
    criterion,
    optimizer,
    prune_ratio_conv1,
    prune_ratio_conv2,
    prune_limit_conv1,
    prune_limit_conv2,
    alpha,
    beta,
    epochs=100,
)