In [1]:
import torch  # PyTorch library for tensor computations and deep learning
import torch.nn as nn  # Neural network modules
import torch.nn.functional as F  # Functional interface for neural network operations
import torch.optim as optim  # Optimization algorithms
import torchvision  # Computer vision datasets and models
import torchvision.transforms as transforms  # Image transformations
import matplotlib.pyplot as plt  # Plotting library
import numpy as np  # Numerical computations
from sklearn.metrics import confusion_matrix  # For confusion matrix
import seaborn as sns  # Visualization library for confusion matrix

# Set random seed for reproducibility across runs
torch.manual_seed(42)

# Define data transformations for preprocessing
# - RandomHorizontalFlip: Randomly flip images horizontally for data augmentation
# - RandomRotation: Randomly rotate images by up to 10 degrees for augmentation
# - ToTensor: Convert images to PyTorch tensors (HWC to CHW format)
# - Normalize: Normalize RGB channels with mean=0.5 and std=0.5
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load CIFAR-10 training dataset
# - root: Directory to store dataset
# - train: True for training set
# - download: Download dataset if not present
# - transform: Apply defined transformations
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
# Create DataLoader for training set
# - batch_size: Number of images per batch (64)
# - shuffle: Randomly shuffle data for better training
# - num_workers: Number of subprocesses for data loading
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=2)

# Load CIFAR-10 test dataset
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
# Create DataLoader for test set
# - shuffle: False to maintain order for evaluation
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                         shuffle=False, num_workers=2)

# Define class labels for CIFAR-10
classes = ('airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Define CNN architecture
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # First convolutional layer: 3 input channels (RGB), 32 output channels, 3x3 kernel
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        # Second convolutional layer: 32 input channels, 64 output channels, 3x3 kernel
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        # Max pooling layer: 2x2 kernel, stride 2
        self.pool = nn.MaxPool2d(2, 2)
        # Batch normalization for first conv layer
        self.bn1 = nn.BatchNorm2d(32)
        # Batch normalization for second conv layer
        self.bn2 = nn.BatchNorm2d(64)
        # First fully connected layer: Input size calculated from conv output (64*8*8), 512 units
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        # Second fully connected layer: 512 units to 10 output classes
        self.fc2 = nn.Linear(512, 10)
        # Dropout layer with 50% probability to prevent overfitting
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        # Forward pass through the network
        # Conv1 -> BatchNorm -> ReLU -> MaxPool
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        # Conv2 -> BatchNorm -> ReLU -> MaxPool
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        # Flatten the output for fully connected layers
        x = x.view(-1, 64 * 8 * 8)
        # Fully connected layer 1 -> ReLU
        x = F.relu(self.fc1(x))
        # Apply dropout
        x = self.dropout(x)
        # Final fully connected layer for classification
        x = self.fc2(x)
        return x

# Initialize model and move to appropriate device (GPU if available, else CPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Net().to(device)

# Define loss function (CrossEntropyLoss for multi-class classification)
criterion = nn.CrossEntropyLoss()
# Define optimizer (Adam with learning rate 0.001)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10  # Number of training epochs
train_losses = []  # Store loss per epoch
train_accuracies = []  # Store accuracy per epoch

for epoch in range(num_epochs):
    model.train()  # Set model to training mode
    running_loss = 0.0  # Track total loss for epoch
    correct = 0  # Track correct predictions
    total = 0  # Track total samples
    for i, data in enumerate(trainloader, 0):
        # Get inputs and labels, move to device
        inputs, labels = data[0].to(device), data[1].to(device)
        # Zero the parameter gradients
        optimizer.zero_grad()
        # Forward pass
        outputs = model(inputs)
        # Compute loss
        loss = criterion(outputs, labels)
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
        
        # Update running loss
        running_loss += loss.item()
        # Calculate accuracy
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    
    # Calculate and store epoch metrics
    epoch_loss = running_loss / len(trainloader)
    epoch_acc = 100 * correct / total
    train_losses.append(epoch_loss)
    train_accuracies.append(epoch_acc)
    print(f"Epoch {epoch + 1}, Loss: {epoch_loss:.3f}, Accuracy: {epoch_acc:.2f}%")

# Evaluate model on test set
model.eval()  # Set model to evaluation mode
correct = 0
total = 0
all_preds = []  # Store predictions for confusion matrix
all_labels = []  # Store true labels
with torch.no_grad():  # Disable gradient computation for evaluation
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Calculate and print test accuracy
test_accuracy = 100 * correct / total
print(f"\nTest Accuracy: {test_accuracy:.2f}%")

# Plot training metrics (loss and accuracy)
plt.figure(figsize=(12, 4))

# Plot training loss
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Training Loss')
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

# Plot training accuracy
plt.subplot(1, 2, 2)
plt.plot(train_accuracies, label='Training Accuracy')
plt.title('Training Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()

plt.tight_layout()
plt.savefig('training_metrics.png')  # Save plot
plt.close()

# Plot confusion matrix
cm = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=classes, yticklabels=classes)
plt.title('Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.savefig('confusion_matrix.png')  # Save plot
plt.close()

# Function to unnormalize and display images
def imshow(img):
    img = img / 2 + 0.5  # Unnormalize
    npimg = img.numpy()
    return np.transpose(npimg, (1, 2, 0))  # Convert from CHW to HWC

# Show sample test images with predictions
dataiter = iter(testloader)
images, labels = next(dataiter)
images, labels = images[:8].to(device), labels[:8]
outputs = model(images)
_, predicted = torch.max(outputs, 1)

plt.figure(figsize=(12, 6))
for i in range(8):
    plt.subplot(2, 4, i + 1)
    plt.imshow(imshow(images[i].cpu()))
    plt.title(f'Pred: {classes[predicted[i]]}\nTrue: {classes[labels[i]]}')
    plt.axis('off')
plt.savefig('sample_predictions.png')  # Save plot
plt.close()

print("Training complete. Plots saved as 'training_metrics.png', 'confusion_matrix.png', and 'sample_predictions.png'")

100%|██████████| 170M/170M [00:07<00:00, 24.0MB/s] 


Epoch 1, Loss: 1.535, Accuracy: 44.58%
Epoch 2, Loss: 1.254, Accuracy: 54.91%
Epoch 3, Loss: 1.144, Accuracy: 59.47%
Epoch 4, Loss: 1.071, Accuracy: 62.25%
Epoch 5, Loss: 1.017, Accuracy: 64.26%
Epoch 6, Loss: 0.975, Accuracy: 65.75%
Epoch 7, Loss: 0.945, Accuracy: 67.02%
Epoch 8, Loss: 0.916, Accuracy: 68.16%
Epoch 9, Loss: 0.892, Accuracy: 68.86%
Epoch 10, Loss: 0.865, Accuracy: 70.02%

Test Accuracy: 73.62%
Training complete. Plots saved as 'training_metrics.png', 'confusion_matrix.png', and 'sample_predictions.png'
