# Image Classification with PyTorch

This notebook demonstrates how to build and train a Convolutional Neural Network (CNN) for image classification using the CIFAR-10 dataset.

## What you'll learn:
- Loading and preprocessing image datasets
- Building CNN architectures
- Training and evaluating models
- Visualizing results

## 1. Import Libraries

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import sys
# Make the project root importable (models/ and utils/ live one level up)
sys.path.append('..')

from models.cnn import SimpleCNN
from utils.data_loader import get_cifar10_loaders
from utils.visualization import show_batch, plot_training_history
from utils.training import train_one_epoch, evaluate

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print(f'PyTorch version: {torch.__version__}')
print(f'CUDA available: {torch.cuda.is_available()}')
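
Seeding `torch` and `numpy` covers most sources of randomness, but Python's own `random` module and cuDNN's autotuner can still cause run-to-run differences. The following is a minimal sketch of a fuller seeding helper. `seed_everything` is a hypothetical name, not part of this project's `utils`, and the cuDNN flags trade some speed for determinism.

```python
import random

import numpy as np
import torch


def seed_everything(seed: int = 42) -> None:
    """Seed the common RNGs (hypothetical helper, not from utils)."""
    random.seed(seed)        # Python's built-in RNG
    np.random.seed(seed)     # NumPy's global RNG
    torch.manual_seed(seed)  # PyTorch RNG on all devices
    # Ask cuDNN for deterministic kernels and disable its autotuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Re-seeding should reproduce the same random draw.
seed_everything(42)
first = torch.rand(3)
seed_everything(42)
second = torch.rand(3)
print(torch.equal(first, second))
```

Even with this helper, some GPU operations may remain non-deterministic, so small differences between runs are still possible.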

## 2. Load and Explore the Dataset

CIFAR-10 consists of 60,000 32x32 color images (50,000 for training and 10,000 for testing) in 10 classes, with 6,000 images per class:
- airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck

In [None]:
# Load CIFAR-10 dataset
batch_size = 32
train_loader, test_loader = get_cifar10_loaders(
    batch_size=batch_size,
    num_workers=2,
    data_dir='../data'
)

print(f'Training samples: {len(train_loader.dataset)}')
print(f'Test samples: {len(test_loader.dataset)}')
print(f'Number of batches (train): {len(train_loader)}')
print(f'Number of batches (test): {len(test_loader)}')
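
The batch counts printed above follow directly from the dataset sizes. Here is a quick check, assuming the standard CIFAR-10 split (50,000 training and 10,000 test images) and assuming `get_cifar10_loaders` keeps the final partial batch, which is the `DataLoader` default (`drop_last=False`). The helper's actual setting is not shown in this notebook, and `expected_batches` is an illustrative name.

```python
import math


def expected_batches(num_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields per epoch."""
    if drop_last:
        # The incomplete final batch is discarded.
        return num_samples // batch_size
    # The incomplete final batch is kept as a smaller batch.
    return math.ceil(num_samples / batch_size)


batch_size = 32
train_batches = expected_batches(50_000, batch_size)
test_batches = expected_batches(10_000, batch_size)
print(train_batches, test_batches)  # 1563 313
```

If the printed loader lengths differ from these values, the helper is probably using `drop_last=True` or a different split.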

In [None]:
# Visualize a batch of images
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

# Get a batch of training images
images, labels = next(iter(train_loader))
show_batch(images[:16], labels[:16], class_names, nrow=4)

## 3. Build the CNN Model

In [None]:
# Initialize the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleCNN(num_classes=10).to(device)

print(model)
print(f'\nDevice: {device}')

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'\nTotal parameters: {total_params:,}')
print(f'Trainable parameters: {trainable_params:,}')
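
The definition of `SimpleCNN` lives in `models/cnn.py` and is not shown in this notebook. To give a sense of what a small CNN for 32x32 RGB inputs can look like, here is a hypothetical stand-in called `TinyCNN`. The layer sizes are assumptions chosen for illustration and may not match `SimpleCNN` at all.

```python
import torch
import torch.nn as nn


class TinyCNN(nn.Module):
    """Illustrative two-block CNN (an assumption, not the project's SimpleCNN)."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),   # keeps 32x32
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),                              # 32x32 -> 16x16
            nn.Conv2d(32, 64, kernel_size=3, padding=1),  # keeps 16x16
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),                              # 16x16 -> 8x8
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(64 * 8 * 8, num_classes),           # raw class scores (logits)
        )

    def forward(self, x):
        return self.classifier(self.features(x))


# Shape check with a dummy batch of four CIFAR-sized images.
model_sketch = TinyCNN(num_classes=10)
dummy = torch.randn(4, 3, 32, 32)
logits = model_sketch(dummy)
sketch_params = sum(p.numel() for p in model_sketch.parameters())
print(logits.shape, sketch_params)
```

The model returns raw logits rather than probabilities because `nn.CrossEntropyLoss` applies log-softmax internally.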

## 4. Define Loss Function and Optimizer

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=2
)
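
With `mode='min'`, `factor=0.5`, and `patience=2`, `ReduceLROnPlateau` halves the learning rate once the monitored loss has failed to improve for more than two consecutive epochs. The sketch below feeds a constant, made-up loss to a throwaway optimizer to show when the reduction happens. The `demo_*` names and the loss values are illustrative only.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# A single dummy parameter is enough to construct an optimizer.
param = nn.Parameter(torch.zeros(1))
demo_optimizer = optim.Adam([param], lr=0.001)
demo_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    demo_optimizer, mode='min', factor=0.5, patience=2
)

# The first value sets the best loss; the next three show no improvement.
fake_val_losses = [1.0, 1.0, 1.0, 1.0]
lrs = []
for fake_loss in fake_val_losses:
    demo_scheduler.step(fake_loss)
    lrs.append(demo_optimizer.param_groups[0]['lr'])

print(lrs)
```

The learning rate stays at 0.001 for the first three calls and drops to 0.0005 on the fourth, which is the third epoch in a row without improvement. With only 5 training epochs, the scheduler in this notebook may therefore never trigger.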

## 5. Train the Model

In [None]:
num_epochs = 5  # Use more epochs for better results

train_losses = []
val_losses = []
train_accs = []
val_accs = []

for epoch in range(1, num_epochs + 1):
    print(f'\nEpoch {epoch}/{num_epochs}')
    print('=' * 50)
    
    # Train
    train_loss, train_acc = train_one_epoch(
        model, train_loader, criterion, optimizer, device, epoch
    )
    
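    # Evaluate (note: the test set doubles as the validation set here, so the
    # reported "Val" numbers are not a fully held-out estimate)
    val_loss, val_acc = evaluate(model, test_loader, criterion, device)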
    
    # Update scheduler
    scheduler.step(val_loss)
    
    # Record history
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accs.append(train_acc)
    val_accs.append(val_acc)
    
    print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')
    print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')
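
The loop above relies on `train_one_epoch` and `evaluate` from `utils/training.py`, which are not shown here. A minimal sketch of what such helpers typically do is given below, assuming they return the mean loss and the accuracy as a percentage. The `*_sketch` names are hypothetical, and the real `train_one_epoch` also takes an `epoch` argument (presumably for logging) that is omitted here.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch_sketch(model, loader, criterion, optimizer, device):
    """One pass over the training data; returns (mean loss, accuracy in %)."""
    model.train()
    loss_sum, correct, seen = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()              # clear gradients from the last step
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()                    # backpropagate
        optimizer.step()                   # update the weights
        loss_sum += loss.item() * labels.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        seen += labels.size(0)
    return loss_sum / seen, 100.0 * correct / seen


@torch.no_grad()
def evaluate_sketch(model, loader, criterion, device):
    """Forward passes only, with dropout and batch norm in eval mode."""
    model.eval()
    loss_sum, correct, seen = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss_sum += criterion(outputs, labels).item() * labels.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        seen += labels.size(0)
    return loss_sum / seen, 100.0 * correct / seen


# Smoke test on random tensors shaped like CIFAR-10 (not real images).
toy_data = TensorDataset(torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,)))
toy_loader = DataLoader(toy_data, batch_size=16)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=0.001)
toy_criterion = nn.CrossEntropyLoss()
cpu = torch.device('cpu')

toy_train_loss, toy_train_acc = train_one_epoch_sketch(
    toy_model, toy_loader, toy_criterion, toy_optimizer, cpu
)
toy_val_loss, toy_val_acc = evaluate_sketch(toy_model, toy_loader, toy_criterion, cpu)
print(f'{toy_train_loss:.4f} {toy_train_acc:.2f}% | {toy_val_loss:.4f} {toy_val_acc:.2f}%')
```

Weighting each batch loss by its size keeps the epoch average correct even when the last batch is smaller than the rest.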

## 6. Visualize Training Results

In [None]:
plot_training_history(train_losses, val_losses, train_accs, val_accs)
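
`plot_training_history` comes from `utils/visualization.py` and is not shown here. A rough sketch of an equivalent two-panel plot is given below, assuming the four lists hold one value per epoch. `plot_history_sketch` is a hypothetical name and the real helper's layout may differ.

```python
import matplotlib.pyplot as plt


def plot_history_sketch(train_losses, val_losses, train_accs, val_accs):
    """Plot loss and accuracy curves side by side; returns the figure."""
    epochs = range(1, len(train_losses) + 1)
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))

    # Left panel: loss per epoch.
    ax_loss.plot(epochs, train_losses, marker='o', label='Train')
    ax_loss.plot(epochs, val_losses, marker='o', label='Val')
    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('Loss')
    ax_loss.set_title('Loss')
    ax_loss.legend()

    # Right panel: accuracy per epoch, in percent.
    ax_acc.plot(epochs, train_accs, marker='o', label='Train')
    ax_acc.plot(epochs, val_accs, marker='o', label='Val')
    ax_acc.set_xlabel('Epoch')
    ax_acc.set_ylabel('Accuracy (%)')
    ax_acc.set_title('Accuracy')
    ax_acc.legend()

    fig.tight_layout()
    return fig
```

When reading either version of the plot, a training curve that keeps improving while the validation curve flattens or worsens is the usual sign of overfitting.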

## 7. Make Predictions

In [None]:
from utils.visualization import visualize_predictions

# Get a batch of test images
images, labels = next(iter(test_loader))
images_device = images.to(device)

# Make predictions (argmax over the class dimension gives the top-1 label)
model.eval()
with torch.no_grad():
    outputs = model(images_device)
    predicted = outputs.argmax(dim=1)

# Visualize predictions (moved back to the CPU for plotting)
visualize_predictions(images, labels, predicted.cpu(), class_names, num_images=10)
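
Beyond the predicted labels, it is often useful to know how confident the model is and how many predictions in the batch are correct. The sketch below assumes the model outputs raw logits, which is consistent with the use of `nn.CrossEntropyLoss` in this notebook. `summarize_predictions` is a hypothetical helper and the example logits are made up.

```python
import torch


def summarize_predictions(outputs, labels):
    """Return top-1 labels, their softmax confidence, and batch accuracy in %."""
    probs = torch.softmax(outputs, dim=1)       # logits -> class probabilities
    confidence, predicted = probs.max(dim=1)    # best probability and its index
    accuracy = 100.0 * (predicted == labels).float().mean().item()
    return predicted, confidence, accuracy


# Two made-up samples over three classes.
example_logits = torch.tensor([[2.0, 0.5, 0.1],
                               [0.1, 0.2, 3.0]])
example_labels = torch.tensor([0, 1])

pred, conf, acc = summarize_predictions(example_logits, example_labels)
print(pred.tolist(), acc)  # [0, 2] 50.0
```

In the notebook, the same function could be called as `summarize_predictions(outputs.cpu(), labels)` after the prediction cell.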

## 8. Exercise: Try Different Things!

Now that you've trained a basic model, try experimenting with:

1. **Model Architecture**: Modify the SimpleCNN or try SimpleResNet
2. **Hyperparameters**: Change learning rate, batch size, number of epochs
3. **Optimization**: Try different optimizers (SGD, RMSprop, etc.)
4. **Data Augmentation**: Add more transformations in the data loader
5. **Regularization**: Adjust dropout rates or add weight decay

Change one setting at a time and compare the validation loss and accuracy against the baseline run above to see what actually helps.
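
As a starting point for the optimizer and regularization experiments, here is a small sketch that builds different optimizers with optional weight decay. `build_optimizer` is a hypothetical helper, and the momentum, learning rate, and weight decay values are common starting points rather than tuned settings.

```python
import torch.nn as nn
import torch.optim as optim


def build_optimizer(model, name='adam', lr=0.001, weight_decay=0.0):
    """Create an optimizer by name (illustrative helper, not from utils)."""
    params = model.parameters()
    if name == 'adam':
        return optim.Adam(params, lr=lr, weight_decay=weight_decay)
    if name == 'sgd':
        return optim.SGD(params, lr=lr, momentum=0.9, weight_decay=weight_decay)
    if name == 'rmsprop':
        return optim.RMSprop(params, lr=lr, weight_decay=weight_decay)
    raise ValueError(f'Unknown optimizer: {name}')


# Placeholder model so the sketch runs on its own; use your CNN in practice.
placeholder_model = nn.Linear(8, 2)
sgd = build_optimizer(placeholder_model, name='sgd', lr=0.01, weight_decay=5e-4)
print(type(sgd).__name__, sgd.param_groups[0]['weight_decay'])
```

For a fair comparison, re-create the model and the scheduler before each run so that every experiment starts from fresh weights.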

In [None]:
# Your experiments here!