# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import matplotlib.pyplot as plt

# Small two-stage convolutional classifier for single-channel images.
class SimpleCNN(nn.Module):
    """Two conv -> ReLU -> 2x2 max-pool stages followed by one linear layer.

    Expects a batched input of shape (N, 1, H, W). The linear layer takes
    64 * 5 * 5 = 1600 features, which assumes a 5x5 feature map after both
    stages. A 28x28 MNIST image yields exactly that:
    28 -conv3-> 26 -pool2-> 13 -conv3-> 11 -pool2-> 5.

    Returns unnormalized logits of shape (N, 10); no softmax is applied.
    """

    def __init__(self):
        super().__init__()
        # Modules are created in the order conv1, conv2, fc. This keeps the
        # parameter names and the sequence of random initializations stable.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.fc = nn.Linear(in_features=64 * 5 * 5, out_features=10)

    def forward(self, x):
        for conv_layer in (self.conv1, self.conv2):
            x = torch.max_pool2d(torch.relu(conv_layer(x)), 2)
        # Collapse every dimension except the batch dimension.
        return self.fc(x.view(x.size(0), -1))

# Load the MNIST training split. Images are converted to tensors and then
# standardized with 0.1307 / 0.3081 (presumably the MNIST pixel mean / std).
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])
mnist_train = torchvision.datasets.MNIST(
    '../../datasets',
    train=True,
    download=True,
    transform=mnist_transform,
)
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=64, shuffle=True)

# Build the network to be trained.
model = SimpleCNN()

# Adam with its library-default hyperparameters; cross-entropy is applied to
# the raw logits the model returns.
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

num_epochs = 5      # kept short for demonstration purposes
log_interval = 100  # report the current batch loss every this many batches

for epoch in range(num_epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        # Standard step: clear old gradients, forward, loss, backward, update.
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if not batch_idx % log_interval:
            print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}')

# Visualize filters
def visualize_filters(layer, num_filters=32, ncols=8, channel=0, figsize=None):
    """Plot the kernels of a 2-D convolution layer as a grid of grayscale images.

    Args:
        layer: Module with a 4-D ``weight`` of shape
            (out_channels, in_channels, kH, kW), e.g. ``nn.Conv2d``.
        num_filters: Number of output filters to draw. Values larger than the
            layer's filter count are clamped to that count.
        ncols: Number of grid columns. The row count is derived (ceil division)
            so that every requested filter gets a cell.
        channel: Index of the input-channel slice drawn for each kernel. For
            layers with several input channels, only this one 2-D slice of
            each kernel is shown.
        figsize: Matplotlib figure size in inches. Defaults to 1.5 inches per
            grid cell, which gives (12, 6) for 32 filters in 8 columns.

    Raises:
        ValueError: If ``num_filters`` or ``ncols`` is less than 1.
    """
    if num_filters < 1 or ncols < 1:
        raise ValueError("num_filters and ncols must be positive")

    # detach() + cpu() so this works for weights that require grad or live on GPU.
    filters = layer.weight.detach().cpu().numpy()
    num_filters = min(num_filters, filters.shape[0])
    nrows = -(-num_filters // ncols)  # ceil(num_filters / ncols)
    if figsize is None:
        figsize = (1.5 * ncols, 1.5 * nrows)

    # squeeze=False always yields a 2-D axes array, even when there is one row.
    _, axs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    for i, ax in enumerate(axs.flat):
        if i < num_filters:
            ax.imshow(filters[i, channel], cmap='gray')
        ax.axis('off')  # also hides the unused cells in the last row
    plt.tight_layout()
    plt.show()

# Visualize first layer filters of the trained model. conv1 has 32 filters
# with a single input channel, so each 3x3 image is the complete kernel.
visualize_filters(model.conv1)

# Visualize second layer filters. conv2 has 64 filters, and each spans 32 input
# channels. visualize_filters draws only the 3x3 slice for input channel 0 of
# each filter by default, so this is a partial view of every kernel.
visualize_filters(model.conv2, num_filters=64)