# CIFAR-10 CNN Classifier with PyTorch

This notebook trains a Convolutional Neural Network to classify images from the CIFAR-10 dataset.

**Dataset:** 60,000 32x32 color images (50,000 for training, 10,000 for testing) in 10 classes (plane, car, bird, cat, deer, dog, frog, horse, ship, truck)

**Runtime:** Enable a GPU for faster training; the notebook falls back to the CPU if no GPU is available.
- In Google Colab, go to Runtime → Change runtime type → Hardware accelerator → GPU

## 1. Import Libraries

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# When running on CUDA, also report the name of GPU 0
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Define CNN Architecture

In [None]:
class SimpleCNN(nn.Module):
    """Small three-block CNN classifier for 3-channel 32x32 images.

    Each block is a 3x3 convolution (padding 1, so the spatial size is kept),
    a ReLU and a 2x2 max-pool that halves height and width. After three
    blocks a 32x32 input becomes a 128x4x4 feature map, which is flattened
    and passed through two fully connected layers.

    Args:
        num_classes: Number of output logits. Defaults to 10.
        dropout: Drop probability applied between the two fully connected
            layers. Defaults to 0.5.
    """

    # (channels, height, width) the classifier head expects after the conv stack
    _FEATURE_SHAPE = (128, 4, 4)

    def __init__(self, num_classes=10, dropout=0.5):
        super().__init__()
        # Convolutional layers
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)

        # Pooling layer (stateless, shared by all three blocks)
        self.pool = nn.MaxPool2d(2, 2)

        # Fully connected layers
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, num_classes)

        # Dropout for regularization
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return raw class logits for a batch of images.

        Args:
            x: Image tensor whose last three dimensions are (3, H, W), with
                H and W such that three 2x2 poolings leave a 4x4 map
                (e.g. 32x32).

        Returns:
            Tensor of shape (N, num_classes) holding unnormalized logits.

        Raises:
            ValueError: If the conv stack does not produce a 128x4x4 feature
                map for the given input size.
        """
        # Conv block 1
        x = self.pool(torch.relu(self.conv1(x)))  # 32x32 -> 16x16
        # Conv block 2
        x = self.pool(torch.relu(self.conv2(x)))  # 16x16 -> 8x8
        # Conv block 3
        x = self.pool(torch.relu(self.conv3(x)))  # 8x8 -> 4x4

        # Without this check, view(-1, ...) would silently fold extra spatial
        # positions into the batch dimension for larger inputs.
        if tuple(x.shape[-3:]) != self._FEATURE_SHAPE:
            raise ValueError(
                f"SimpleCNN expects a {self._FEATURE_SHAPE} feature map before the "
                f"classifier but got {tuple(x.shape[-3:])}; use 32x32 inputs"
            )

        # Flatten
        x = x.view(-1, 128 * 4 * 4)

        # Fully connected layers
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)

        return x

# Build the network, then move its parameters onto the selected device
model = SimpleCNN()
model = model.to(device)
print(model)

## 3. Load and Prepare CIFAR-10 Dataset

In [None]:
# Training pipeline: random horizontal flips and random 32x32 crops taken from
# a 4-pixel-padded image, followed by tensor conversion and per-channel
# normalization with mean 0.5 / std 0.5
train_steps = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
transform_train = transforms.Compose(train_steps)

# Test pipeline: the same conversion and normalization, without augmentation
test_steps = [
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
transform_test = transforms.Compose(test_steps)

# Fetch the data into ./data (if needed) and wrap each split in a loader
print("Downloading CIFAR-10 dataset...")
dataset_args = dict(root='./data', download=True)
loader_args = dict(batch_size=64, num_workers=0)

trainset = torchvision.datasets.CIFAR10(
    train=True, transform=transform_train, **dataset_args
)
trainloader = DataLoader(trainset, shuffle=True, **loader_args)

testset = torchvision.datasets.CIFAR10(
    train=False, transform=transform_test, **dataset_args
)
testloader = DataLoader(testset, shuffle=False, **loader_args)

# Human-readable class names, indexed by the integer label
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

print(f"Training samples: {len(trainset)}")
print(f"Test samples: {len(testset)}")

## 4. Visualize Sample Images

In [None]:
# Get one batch of training images (shuffled and augmented by the loader)
dataiter = iter(trainloader)
images, labels = next(dataiter)


def imshow(img, ax=None):
    """Display one normalized (C, H, W) image tensor.

    Args:
        img: Image tensor laid out as (channels, height, width) whose values
            were normalized with mean 0.5 and std 0.5.
        ax: Optional matplotlib Axes to draw on. When omitted, the image is
            drawn through the pyplot interface onto the current axes.
    """
    # detach()/cpu() keep the helper usable for GPU or grad-tracking tensors
    img = img.detach().cpu() / 2 + 0.5  # Denormalize
    npimg = np.transpose(img.numpy(), (1, 2, 0))  # (C, H, W) -> (H, W, C)
    if ax is None:
        plt.imshow(npimg)
        plt.axis('off')
    else:
        ax.imshow(npimg)
        ax.axis('off')


# Show up to 8 images in a 2x4 grid
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
axes = axes.ravel()

for i, ax in enumerate(axes):
    if i < len(images):
        imshow(images[i], ax=ax)
        ax.set_title(classes[labels[i]])
    else:
        # Batch smaller than the grid: leave the remaining cells blank
        ax.axis('off')

plt.tight_layout()
plt.show()

## 5. Setup Training

In [None]:
# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Report the configuration from the objects themselves so the messages cannot
# drift from what was actually constructed above
print("Setup complete!")
print(f"Optimizer: {type(optimizer).__name__} with lr={optimizer.param_groups[0]['lr']}")
print(f"Loss function: {type(criterion).__name__}")

## 6. Train the Model

In [None]:
def train_model(epochs=10):
    """Train the module-level ``model`` on ``trainloader``.

    Uses the module-level ``criterion``, ``optimizer`` and ``device``. Prints
    the mean loss of every 100 consecutive batches and the training accuracy
    at the end of each epoch.

    Args:
        epochs: Number of full passes over ``trainloader``.

    Returns:
        A ``(train_losses, train_accuracies)`` tuple of lists with one entry
        per epoch: the mean per-sample training loss and the training
        accuracy in percent. Both are measured in train mode while the
        weights are still being updated.
    """
    train_losses = []
    train_accuracies = []

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0  # loss summed over the current 100-batch print window
        epoch_loss = 0.0    # sample-weighted loss summed over the whole epoch
        correct = 0
        total = 0

        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            # Statistics
            batch_size = labels.size(0)
            batch_loss = loss.item()
            running_loss += batch_loss
            # Assumes criterion returns the batch mean (its default reduction),
            # so weighting by batch size keeps a short final batch from being
            # over-represented in the epoch average.
            epoch_loss += batch_loss * batch_size
            predicted = outputs.argmax(dim=1)
            total += batch_size
            correct += (predicted == labels).sum().item()

            if i % 100 == 99:
                print(f'Epoch [{epoch+1}/{epochs}], Step [{i+1}/{len(trainloader)}], '
                      f'Loss: {running_loss/100:.4f}')
                running_loss = 0.0

        # Epoch statistics
        train_losses.append(epoch_loss / total)
        epoch_acc = 100 * correct / total
        train_accuracies.append(epoch_acc)
        print(f'\nEpoch [{epoch+1}/{epochs}] - Training Accuracy: {epoch_acc:.2f}%\n')

    return train_losses, train_accuracies

# Train the model for 10 epochs and keep the two lists train_model returns
print("Starting training...\n")
train_losses, train_accuracies = train_model(epochs=10)
print("\nTraining complete!")

## 7. Evaluate the Model

In [None]:
def evaluate_model():
    """Evaluate the module-level ``model`` on ``testloader`` and print a report.

    Prints the overall accuracy followed by one accuracy line for every class
    in ``classes`` that occurs at least once in the evaluated data. Returns
    nothing.
    """
    model.eval()
    num_classes = len(classes)
    class_correct = torch.zeros(num_classes, dtype=torch.long)
    class_total = torch.zeros(num_classes, dtype=torch.long)

    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = model(inputs).argmax(dim=1)

            # Per-class tallies, counted on the CPU. bincount works for any
            # batch size, including a final batch that holds a single sample.
            labels_cpu = labels.cpu()
            hits = (predicted == labels).cpu()
            class_total += torch.bincount(labels_cpu, minlength=num_classes)
            class_correct += torch.bincount(labels_cpu[hits], minlength=num_classes)

    # Overall counts are the sums of the per-class counts
    total = int(class_total.sum())
    correct = int(class_correct.sum())
    print(f'Overall Test Accuracy: {100 * correct / total:.2f}%\n')

    # Print per-class accuracy
    print('Per-class Accuracy:')
    per_class = zip(classes, class_correct.tolist(), class_total.tolist())
    for name, n_correct, n_seen in per_class:
        if n_seen > 0:
            acc = 100 * n_correct / n_seen
            print(f'  {name:>10s}: {acc:.2f}%')

# Run the test-set evaluation and print the accuracy report
evaluate_model()

## 8. Visualize Predictions

In [None]:
def visualize_predictions(num_images=8):
    """Plot images from the first test batch with true and predicted labels.

    Uses the module-level ``model``, ``testloader``, ``device`` and
    ``classes``. Titles are green when the prediction matches the label and
    red otherwise.

    Args:
        num_images: Number of images to show. Values larger than the batch
            are clamped to the batch size; the grid is 4 columns wide and
            grows by rows as needed. Nothing is drawn for values below 1.
    """
    model.eval()
    images, labels = next(iter(testloader))

    # Inference only: skip autograd bookkeeping. `images` itself stays on the
    # CPU so it can be converted to NumPy for plotting.
    with torch.no_grad():
        outputs = model(images.to(device))
    predicted = outputs.argmax(dim=1).tolist()
    labels = labels.tolist()

    count = min(num_images, len(labels))
    if count <= 0:
        return

    # 4 columns, 3.5 inches per row: the default of 8 images gives the
    # 2x4 grid at figsize (14, 7).
    cols = 4
    rows = (count + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(14, 3.5 * rows), squeeze=False)
    axes = axes.ravel()

    # Blank out grid cells that will not receive an image
    for ax in axes[count:]:
        ax.axis('off')

    for i in range(count):
        img = images[i].numpy().transpose((1, 2, 0))  # (C, H, W) -> (H, W, C)
        img = img * 0.5 + 0.5  # Denormalize

        axes[i].imshow(img)
        axes[i].set_title(f'True: {classes[labels[i]]}\nPred: {classes[predicted[i]]}',
                          color='green' if labels[i] == predicted[i] else 'red',
                          fontsize=12, weight='bold')
        axes[i].axis('off')

    plt.tight_layout()
    plt.show()

# Show the default 8 test images with their true and predicted labels
visualize_predictions()

## 9. Save the Model

In [None]:
# Persist only the learned parameters (the state dict) to disk
weights = model.state_dict()
torch.save(weights, 'cifar_cnn_model.pth')
print("Model saved as 'cifar_cnn_model.pth'")

# To download the model in Colab:
# from google.colab import files
# files.download('cifar_cnn_model.pth')

## 10. (Optional) Load Model for Inference

In [None]:
# To load the model later:
# model = SimpleCNN().to(device)
# model.load_state_dict(torch.load('cifar_cnn_model.pth', map_location=device))
# model.eval()
# print("Model loaded successfully!")