In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

class CNNClassifier(nn.Module):
    """Three-stage convolutional classifier for single-channel images.

    Every stage is ``Conv2d(kernel_size=3)`` (no padding) -> ``ReLU`` ->
    2x2 max-pool with stride 2. For a 28x28 input (the MNIST images used in
    this notebook) the spatial size shrinks 28 -> 13 -> 5 -> 1, so each image
    ends up as 64 features, which is the input width of the first ``Linear``
    layer in the head. ``forward`` returns 10 raw (unnormalised) class scores.
    """

    def __init__(self):
        super().__init__()
        stages = []
        # (in_channels, out_channels) per stage, in construction order.
        for in_ch, out_ch in ((1, 64), (64, 128), (128, 64)):
            stages.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=3),
                nn.ReLU(),
                nn.MaxPool2d((2, 2), stride=2),
            ])
        # Sequential indices (net.0, net.3, net.6 for the convs) are the same
        # as a hand-written Sequential, so state_dict keys do not change.
        self.net = nn.Sequential(*stages)
        self.classification_head = nn.Sequential(
            nn.Linear(64, 20, bias=True),
            nn.ReLU(),
            nn.Linear(20, 10, bias=True),
        )

    def forward(self, x):
        """Map a batch of images ``x`` to class logits."""
        feats = self.net(x)
        # Keep the leading dimension, collapse everything else.
        flat = feats.view(feats.shape[0], -1)
        return self.classification_head(flat)

# Preprocessing: ToTensor converts images to float tensors in [0, 1];
# Normalize with mean 0.5 / std 0.5 then rescales them to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# MNIST train and test splits, fetched into ./data if not already present.
train_dataset = datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)

# Only the training split is shuffled; test order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Model, loss, and optimizer used by the train/evaluate calls further down.
model = CNNClassifier()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

def count_parameters(model):
    """Return how many scalar values in ``model`` receive gradients.

    Parameters with ``requires_grad=False`` (frozen) are excluded.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def train(model, train_loader, criterion, optimizer, epochs=5, device=None):
    """Train ``model`` in place for ``epochs`` passes over ``train_loader``.

    After each epoch the mean per-batch loss is printed as
    ``Epoch [k/epochs], Loss: <value>``.

    Args:
        model: The ``nn.Module`` to optimise; it is switched to train mode.
        train_loader: Iterable of ``(inputs, targets)`` batches.
        criterion: Loss called as ``criterion(outputs, targets)``.
        optimizer: Optimizer over ``model``'s parameters.
        epochs: Number of full passes over ``train_loader``.
        device: Optional device every batch is moved to before the forward
            pass (e.g. the device ``model`` lives on). ``None`` (default)
            feeds batches exactly as the loader yields them.

    Returns:
        List with one entry per epoch: the mean per-batch loss, or ``nan``
        for an epoch in which the loader yielded no batches.
    """
    model.train()
    epoch_losses = []
    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for images, labels in train_loader:
            if device is not None:
                images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        # Dividing by the batches actually seen (not len(train_loader))
        # avoids ZeroDivisionError on an empty loader and also works for
        # iterables without __len__.
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        epoch_losses.append(mean_loss)
        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {mean_loss}")
    return epoch_losses


def evaluate(model, test_loader, device=None):
    """Predict on every batch of ``test_loader`` and build a confusion matrix.

    The model is put in eval mode and run under ``torch.no_grad()``. The
    predicted class is the index of the maximum output along dimension 1.

    Args:
        model: Classifier returning per-class scores along dimension 1.
        test_loader: Iterable of ``(inputs, labels)`` batches.
        device: Optional device each input batch is moved to before the
            forward pass. ``None`` (default) feeds batches as-is.

    Returns:
        ``sklearn.metrics.confusion_matrix(y_true, y_pred)``: rows are true
        labels, columns are predicted labels.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()
    all_preds = []
    all_labels = []

    # torch.no_grad() is a context manager and must be entered with ``with``.
    # Used as a ``while`` condition the object is always truthy, so the test
    # set was re-scanned forever and autograd was never actually disabled.
    with torch.no_grad():
        for images, labels in test_loader:
            if device is not None:
                images = images.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            # Collect on the CPU so the .numpy() conversion below also works
            # when the model runs on an accelerator.
            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())

    if not all_preds:
        raise ValueError("test_loader yielded no batches; nothing to evaluate")

    y_pred = torch.cat(all_preds).numpy()
    y_true = torch.cat(all_labels).numpy()

    return confusion_matrix(y_true, y_pred)


# Plotting the confusion matrix
def plot_confusion_matrix(cm, classes):
    """Render ``cm`` as an annotated heatmap and show it.

    Args:
        cm: Matrix of counts; annotations use ``fmt='d'``, so entries must be
            integers. Rows are labelled as true labels, columns as predicted.
        classes: Tick labels applied to both axes, in matrix order.
    """
    _, ax = plt.subplots(figsize=(10, 7))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=classes,
        yticklabels=classes,
        ax=ax,
    )
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    ax.set_title('Confusion Matrix')
    plt.show()


# Train for five epochs, then score the model on the held-out test split.
train(model, train_loader, criterion, optimizer, epochs=5)
conf_matrix = evaluate(model, test_loader)

# Show per-digit results; tick labels are the strings "0" through "9".
digit_names = list(map(str, range(10)))
plot_confusion_matrix(conf_matrix, classes=digit_names)

# Report model size (trainable parameters only).
num_params = count_parameters(model)
print(f"\nNumber of learnable parameters in the model: {num_params}")

Epoch [1/5], Loss: 0.2712869399538967
Epoch [2/5], Loss: 0.07871793104615658
Epoch [3/5], Loss: 0.05589824144928051
Epoch [4/5], Loss: 0.04317764701565722
Epoch [5/5], Loss: 0.035717154020049637
