In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

class EnhancedCNN(nn.Module):
    """Three-block convolutional classifier for 1-channel 28x28 images.

    Each block is ``Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2, 2)``, so a
    28x28 input shrinks to 14x14, 7x7 and finally 3x3 with 64 channels. The
    flattened ``64 * 3 * 3`` features pass through two linear layers that emit
    10 raw class scores (logits; no softmax is applied here).
    """

    def __init__(self):
        super().__init__()
        # Padding of 1 with a 3x3 kernel keeps the spatial size unchanged, so
        # only the pooling layer reduces resolution.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 3 * 3, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        """Return class logits of shape ``(batch, 10)``.

        Args:
            x: Tensor of shape ``(batch, 1, 28, 28)``. A single unbatched
                ``(1, 28, 28)`` image is also accepted and treated as a batch
                of one.

        Raises:
            RuntimeError: If the spatial size does not reduce to 3x3 after the
                three pooling steps (raised by ``fc1``).
        """
        if x.dim() == 3:
            # Promote an unbatched image so the per-sample flatten below still
            # yields a (1, features) matrix.
            x = x.unsqueeze(0)
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = self.pool(torch.relu(self.conv3(x)))
        # Flatten per sample. ``view(-1, 576)`` would instead fold surplus
        # features into the batch dimension for inputs of another size and
        # silently return the wrong number of rows.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Load the MNIST training split; pixels become tensors and are normalised
# with mean 0.1307 and std 0.3081.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
trainset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=64,
    shuffle=True,
)

# Build the model, its Adam optimizer and the classification loss.
net = EnhancedCNN()
optimizer = optim.Adam(net.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Function to visualize kernels
def visualize_kernels(epoch):
    """Plot the first 10 kernels of each conv layer of the global ``net``.

    Draws a 3 x 10 grid, one row per conv layer. Each cell shows, for one of
    the first ten output channels, the filter slice acting on input channel 0.

    Args:
        epoch: Label inserted into the figure title (an epoch number or a
            string such as 'Final').
    """
    n_shown = 10
    conv_layers = (net.conv1, net.conv2, net.conv3)
    fig, axs = plt.subplots(len(conv_layers), n_shown, figsize=(20, 8))
    fig.suptitle(f'Kernels after Epoch {epoch}')
    for row_axes, conv in zip(axs, conv_layers):
        kernels = conv.weight.detach().cpu().numpy()
        # Index 0 on the second axis selects the slice for input channel 0.
        for ax, kernel in zip(row_axes, kernels[:n_shown, 0]):
            ax.imshow(kernel, cmap='gray')
            ax.axis('off')
    plt.tight_layout()
    plt.show()

# Function to visualize feature maps
def visualize_feature_maps(epoch):
    """Plot the first 10 feature maps of each conv layer for one training image.

    Takes the first image of a batch drawn from the global ``trainloader``,
    replays the convolutional part of the global ``net`` on it and shows the
    raw (pre-ReLU) output of every conv layer in a 3 x 10 grid, one row per
    layer.

    Args:
        epoch: Label inserted into the figure title (an epoch number or a
            string such as 'Final').
    """
    # Get a sample image. The loader shuffles, so a different image may be
    # shown on every call.
    images, _ = next(iter(trainloader))
    img = images[0].unsqueeze(0)  # Add batch dimension

    # Get feature maps
    feature_maps = []
    x = img
    # Visualisation only: no gradients are needed, so skip the autograd graph.
    with torch.no_grad():
        for conv in (net.conv1, net.conv2, net.conv3):
            x = conv(x)
            feature_maps.append(x)
            # Mirror EnhancedCNN.forward (ReLU, then pooling) so deeper layers
            # receive the same inputs they see during training; skipping the
            # ReLU would display maps the network never actually computes.
            x = net.pool(torch.relu(x))

    # Visualize feature maps
    fig, axs = plt.subplots(3, 10, figsize=(20, 8))
    fig.suptitle(f'Feature Maps after Epoch {epoch}')
    for i, fmap in enumerate(feature_maps):
        # Drop the batch dimension added above -> (channels, H, W).
        fmap = fmap.squeeze(0).cpu().numpy()
        for j in range(10):
            axs[i, j].imshow(fmap[j], cmap='viridis')
            axs[i, j].axis('off')
    plt.tight_layout()
    plt.show()

# Training loop: optimise for a fixed number of epochs, printing the mean loss
# of every 100 batches and plotting kernels/feature maps after each epoch.
num_epochs = 5
log_interval = 100
for epoch in range(num_epochs):
    epoch_label = epoch + 1
    window_loss = 0.0
    for step, (inputs, labels) in enumerate(trainloader, start=1):
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        window_loss += loss.item()
        if step % log_interval:
            continue
        # A full window has been accumulated: report its mean and start over.
        print(f'[{epoch_label}, {step:5d}] loss: {window_loss / log_interval:.3f}')
        window_loss = 0.0

    # Snapshot what the network has learned once the epoch is complete.
    visualize_kernels(epoch_label)
    visualize_feature_maps(epoch_label)

print('Finished Training')

# Final visualizations
visualize_kernels('Final')
visualize_feature_maps('Final')