In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import StepLR
from sklearn.decomposition import PCA
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Training hyperparameters.
num_epochs = 1          # full passes over the training set
batch_size = 64         # samples per optimisation step
learning_rate = 0.001   # step size handed to the optimiser
num_classes = 10        # size of the classifier's output layer

# Seed the random number generators so that weight initialisation, data
# shuffling and dropout are repeatable across "Restart & Run All".
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

In [None]:
# Preprocessing applied to every image: convert it to a tensor, then
# normalise the single channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)

In [None]:
# Fashion-MNIST train and test splits, stored under ./data (downloaded when
# missing) and preprocessed with the shared `transform` pipeline.
train_dataset = torchvision.datasets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
test_dataset = torchvision.datasets.FashionMNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

In [None]:
# Mini-batch iterators: the training split is reshuffled on every pass,
# the test split keeps its stored order.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
class ModifiedAlexNet(nn.Module):
    """AlexNet-style convolutional classifier built from 3x3 convolutions.

    The feature extractor is five convolutions with three 2x2 max-pools; an
    adaptive average pool then forces a 6x6 spatial map so the classifier's
    input size does not depend on the image resolution.

    Args:
        num_classes: Number of output logits (10 for the Fashion-MNIST setup
            used in this notebook).
        in_channels: Number of channels in the input images. Defaults to 1
            (grayscale), which matches the previous hard-coded behaviour.
    """

    def __init__(self, num_classes=10, in_channels=1):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Fixed 6x6 output regardless of the incoming feature-map size.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        """Map a batch of images ``(N, in_channels, H, W)`` to ``(N, num_classes)`` logits."""
        x = self.features(x)
        x = self.avgpool(x)
        # Flatten everything but the batch dimension; unlike ``view`` with a
        # hard-coded size this also works for non-contiguous activations.
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [None]:
def weights_init(m):
    """Xavier-uniform initialise the weight tensor of conv and linear layers.

    Intended to be passed to ``nn.Module.apply``. Modules of any other type
    are skipped, and biases are left at whatever values they already hold.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return
    nn.init.xavier_uniform_(m.weight.data)

In [None]:
# Build the network, place it on the selected device, then re-initialise its
# conv/linear weights. The final expression shows the module summary.
model = ModifiedAlexNet(num_classes)
model = model.to(device)
model.apply(weights_init)

In [None]:
# Cross-entropy loss on the class logits, minimised with Adam at the
# configured learning rate.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
def train_model(model, criterion, optimizer, train_loader, num_epochs):
    """Train ``model`` in place for ``num_epochs`` passes over ``train_loader``.

    Args:
        model: Network to optimise. It is switched to training mode first.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor.
        optimizer: Optimiser already bound to ``model``'s parameters.
        train_loader: Sized iterable yielding ``(images, labels)`` batches.
        num_epochs: Number of full passes over ``train_loader``.

    Returns:
        None. The loss of the current batch is printed every 100 steps.
    """
    # Send batches to wherever the model lives instead of relying on a
    # notebook-level ``device`` variable being defined and in sync.
    first_param = next(model.parameters(), None)
    target_device = None if first_param is None else first_param.device

    # Evaluation leaves the model in eval mode; without this, any training
    # after an evaluation would silently run with dropout disabled.
    model.train()

    total_step = len(train_loader)
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_loader):
            if target_device is not None:
                images = images.to(target_device)
                labels = labels.to(target_device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (i + 1) % 100 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                      .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))

In [None]:
def test_model(model, test_loader):
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        accuracy = correct / total * 100
        return accuracy

In [None]:
# Fit the network on the training split.
train_model(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    train_loader=train_loader,
    num_epochs=num_epochs,
)

In [None]:
# Reference accuracy on the test split, measured right after training.
baseline_accuracy = test_model(model=model, test_loader=test_loader)
print(f"Baseline Accuracy: {baseline_accuracy:.2f}%")

In [None]:
# Pruning rates for evaluation
pruning_rates = [0.2, 0.5, 0.8]


def _prune_by_magnitude_PCA(model, prune_ratio, threshold_label):
    """Zero, in place, every parameter entry below a global magnitude percentile.

    All parameters of ``model`` (weights and biases alike) are flattened into
    one vector; entries whose absolute value is strictly below the
    ``prune_ratio * 100``-th percentile of that vector are set to zero.

    Args:
        model: Network whose parameters are pruned in place.
        prune_ratio: Fraction in ``[0, 1]`` used as the percentile cut-off.
        threshold_label: Prefix for the printed threshold line.

    Returns:
        The summed explained-variance ratio of a one-component PCA fitted on
        the flattened parameter vector.
    """
    weights = torch.cat(
        [param.detach().reshape(-1) for param in model.parameters()]
    ).cpu().numpy()

    # NOTE(review): the PCA is fitted on a single feature column, so with one
    # component the explained-variance ratio is expected to be 1.0 whatever
    # the weights are, and it does not influence which entries get pruned.
    # Confirm whether a per-layer / per-neuron PCA was intended.
    pca = PCA(n_components=1)
    pca.fit(weights.reshape(-1, 1))
    explained_variance_ratio = np.sum(pca.explained_variance_ratio_)

    threshold = float(np.percentile(np.abs(weights), prune_ratio * 100))
    print(f"{threshold_label} Threshold (PCA):", threshold)

    # In-place edit of leaf parameters must not be tracked by autograd.
    with torch.no_grad():
        for param in model.parameters():
            param.masked_fill_(param.abs() < threshold, 0)

    return explained_variance_ratio


def prune_nodes_PCA(model, prune_ratio):
    """Prune ``model`` in place for the "nodes" experiment; returns the PCA ratio.

    NOTE(review): this applies the same global magnitude criterion as
    ``prune_connections_PCA``; no whole neurons/channels are removed.
    """
    return _prune_by_magnitude_PCA(model, prune_ratio, "Node")


def prune_connections_PCA(model, prune_ratio):
    """Prune ``model`` in place for the "connections" experiment; returns the PCA ratio."""
    return _prune_by_magnitude_PCA(model, prune_ratio, "Weights")


# Snapshot of the parameters as they are before any pruning in this cell.
# Every experiment below starts from this state, so results do not depend on
# the order of `pruning_rates` and the second experiment is not run on
# weights that the first one already zeroed.
trained_state = {
    name: tensor.detach().clone() for name, tensor in model.state_dict().items()
}

for prune_rate in pruning_rates:
    # Prune nodes using PCA
    model.load_state_dict(trained_state)
    explained_variance_ratio_nodes_pca = prune_nodes_PCA(model, prune_rate)
    accuracy_nodes_pruned_pca = test_model(model, test_loader)
    print(f"PCA Pruning Rate: {prune_rate}, Accuracy after PCA pruning : {accuracy_nodes_pruned_pca:.2f}%")
    print(f"Explained Variance Ratio (PCA): {explained_variance_ratio_nodes_pca:.2f}")

    # Prune connections using PCA
    model.load_state_dict(trained_state)
    explained_variance_ratio_connections_pca = prune_connections_PCA(model, prune_rate)
    accuracy_connections_pruned_pca = test_model(model, test_loader)
    print(f"PCA Pruning Rate: {prune_rate}, Accuracy after pruning connections: {accuracy_connections_pruned_pca:.2f}%")
    print(f"Explained Variance Ratio (Connections - PCA): {explained_variance_ratio_connections_pca:.2f}")

# `model` is left pruned at the last rate evaluated; reload `trained_state`
# to get the unpruned weights back.