# Setup 

This first cell sets the working directory to the project directory:

In [5]:
import os

# When the kernel starts inside a folder whose path ends with `notebooks`,
# step up one level so relative paths resolve from its parent directory.
current_dir = os.getcwd()
if current_dir.endswith('notebooks'):
    os.chdir(os.pardir)
print(os.getcwd())

/home/cmcouto-silva/Projects/github/pytorch-workflow-mastery


## Libraries

In [6]:
import torch
import torchmetrics
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.optim import Adam
from torchvision.models import ResNet18_Weights, resnet18

from tqdm import tqdm
from loguru import logger

import numpy as np
import matplotlib.pyplot as plt

import wandb
from dotenv import load_dotenv

## Device

In [7]:
# [Optional] Allow TF32 in float32 matmuls for better performance on modern NVIDIA GPUs
torch.set_float32_matmul_precision("high")

In [8]:
# Prefer the GPU (cuda) when PyTorch can see one; otherwise run on the CPU
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Using device: {device}")

Using device: cuda


## Config Parameters

In [9]:
# Reproducibility
seed = 42

# Training schedule
num_epochs = 3
batch_size = 128       # Larger batches for faster training
learning_rate = 1e-3

# Model
num_classes = 10                          # CIFAR10 has 10 classes
model_path = 'weights/cifar10_model.pt'   # Path to save/load model weights

# DataLoader worker processes (increase these if you have more CPU cores)
train_num_workers = 4
test_num_workers = 4

In [10]:
## -- Set seeds -- ##

# PyTorch seed
torch.manual_seed(seed)  # Controls random number generation for PyTorch operations

# NumPy seed (for data loading/processing)
np.random.seed(seed)     # Controls random number generation for NumPy operations

# If GPU is available
if torch.cuda.is_available():
    # GPU seed: seed every visible GPU, not only the current one
    torch.cuda.manual_seed_all(seed)
    # Deliberately leave cuDNN free to pick non-deterministic algorithms:
    # this notebook prioritizes performance over reproducibility, so results
    # may still differ slightly between runs despite the fixed seeds.
    # Set this to True to trade speed for deterministic cuDNN operations.
    torch.backends.cudnn.deterministic = False

# Weights & Biases Integration

In [None]:
# Load environment variables via python-dotenv
load_dotenv()

# Fail fast when the W&B API key is missing (never print its value!)
assert "WANDB_API_KEY" in os.environ, "WANDB_API_KEY not found in environment variables"

In [None]:
# Hyperparameters and model description recorded alongside the W&B run
run_config = dict(
    learning_rate=learning_rate,
    epochs=num_epochs,
    batch_size=batch_size,
    model="ResNet18",
    optimizer="Adam",
    architecture="Modified for CIFAR10",
)

# Start the W&B run for this training session
wandb.init(project="pytorch-cifar10", config=run_config)

# Dataset

For real-life applications, the image folder is usually structured as:

```text
data/
├── train/
│   ├── airplane/        # This folder name becomes class 0
│   │   ├── img1.jpg
│   │   ├── img2.jpg
│   ├── automobile/      # This folder name becomes class 1
│   │   ├── img1.jpg
│   │   ├── img2.jpg
│   └── ...
├── val/
│   ├── airplane/
│   ├── automobile/
│   └── ...
└── test/
    ├── airplane/
    ├── automobile/
    └── ...
```

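With that layout, `torchvision.datasets.ImageFolder` can load each split, and the resulting dataset is then passed to a `DataLoader`. This notebook uses the built-in `torchvision.datasets.CIFAR10` dataset instead, but the transformer and `DataLoader` usage shown below is the same.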

In [None]:
# Data Transformers

# Per-channel mean / std handed to `Normalize` in every pipeline below
norm_mean = (0.485, 0.456, 0.406)
norm_std = (0.229, 0.224, 0.225)

# Training: tensor conversion + normalization
# (resizing and data augmentation steps would usually be added here as well)
train_transformer = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(norm_mean, norm_std)]
)

# Validation: usually resizing and normalization, without data augmentation
val_transformer = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(norm_mean, norm_std)]
)

# Test: usually the same as the validation transformation
test_transformer = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(norm_mean, norm_std)]
)

In [None]:
## -- Download & load data -- ##

# Datasets: CIFAR10's `train=True` split is used for training and its
# `train=False` split for validation.
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transformer)
val_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=val_transformer)

# DataLoaders
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,                    # Reshuffle the training samples every epoch
    num_workers=train_num_workers,
    pin_memory=True,
)

val_loader = torch.utils.data.DataLoader(
    val_dataset,                     # Evaluate on the validation split, not on the training data
    batch_size=batch_size,
    shuffle=False,                   # Order does not matter for evaluation
    num_workers=test_num_workers,
    pin_memory=True,
)

# Test set -> not available in this example, but real-life applications require separate test data - mainly when tuning hyperparameters!!

# Model

In [None]:
# Load a ResNet18 initialised with the IMAGENET1K_V1 pretrained weights
model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

# Replace the final layer so it outputs one logit per class
# (`num_classes` comes from the config cell: 10 for CIFAR-10)
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Set it to target device
model = model.to(device)

# Compile model to make it faster (torch>=2.0)
model = torch.compile(model)

# Set up loss function
criterion = nn.CrossEntropyLoss()

# Set up optimizer
optimizer = Adam(model.parameters(), lr=learning_rate)

In [None]:
# Metrics: running means for the losses and multiclass accuracy, kept on the
# same device as the model outputs. The number of classes comes from the
# config cell so it stays in sync with the model head.
train_loss = torchmetrics.MeanMetric().to(device)
val_loss = torchmetrics.MeanMetric().to(device)
train_accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes).to(device)
val_accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes).to(device)

## Train

In [None]:
logger.info("Starting training...")

for epoch in range(num_epochs):

    ## -- Train step -- ##

    model.train()
    train_loss.reset()
    train_accuracy.reset()

    train_progress = tqdm(train_loader, desc=f'• Epoch {epoch + 1}/{num_epochs} [Train]', leave=False)

    for images, labels in train_progress:
        # The loaders use pinned memory, so the copies can be issued asynchronously
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # Update metrics. `loss` is the mean over the batch, so weight it by the
        # batch size: the last batch can be smaller and would otherwise be
        # over-represented in the epoch average.
        train_loss.update(loss.detach(), weight=labels.size(0))
        train_accuracy.update(outputs.detach(), labels)

        train_progress.set_postfix({
            'loss': f'{train_loss.compute():.3f}',
            'acc': f'{train_accuracy.compute():.1%}'
        })

    ## -- Validation step -- ##

    model.eval()
    val_loss.reset()
    val_accuracy.reset()

    with torch.inference_mode():
        val_progress = tqdm(
            val_loader, desc=f'• Epoch {epoch + 1}/{num_epochs} [Valid]', leave=False
        )

        for images, labels in val_progress:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Update metrics (loss weighted by batch size, as in the train step)
            val_loss.update(loss, weight=labels.size(0))
            val_accuracy.update(outputs, labels)

            val_progress.set_postfix({
                'loss': f'{val_loss.compute():.3f}',
                'acc': f'{val_accuracy.compute():.1%}'
            })

    # Compute the epoch metrics once and convert them to plain Python floats,
    # so the summary line and the wandb log share the same values.
    epoch_metrics = {
        'train_loss': train_loss.compute().item(),
        'train_accuracy': train_accuracy.compute().item(),
        'val_loss': val_loss.compute().item(),
        'val_accuracy': val_accuracy.compute().item(),
    }

    # Print epoch summary
    logger.debug(
        f"Epoch {epoch+1}/{num_epochs}: "
        f"Train Loss: {epoch_metrics['train_loss']:.3f} | "
        f"Train Acc: {epoch_metrics['train_accuracy']:.1%} | "
        f"Val Loss: {epoch_metrics['val_loss']:.3f} | "
        f"Val Acc: {epoch_metrics['val_accuracy']:.1%}"
    )

    # Log to wandb (`epoch` is the zero-based loop index)
    wandb.log({'epoch': epoch, **epoch_metrics})

logger.success('Model trained!')

In [None]:
# Make sure the target folder exists before saving the checkpoint file
checkpoint_dir = os.path.dirname(model_path)
if checkpoint_dir:  # Empty when `model_path` is a bare file name
    os.makedirs(checkpoint_dir, exist_ok=True)

# Save fine-tuned model and training state.
# The state dict is taken from the compiled model, which is why the inference
# cell compiles its model before calling `load_state_dict`.
torch.save({
   'model_state_dict': model.state_dict(),          # Model weights
   'optimizer_state_dict': optimizer.state_dict(),  # Optimizer state
   'epoch': num_epochs,                             # Number of trained epochs
}, model_path)

logger.info(f"Model saved to {model_path}")

## Inference

In [None]:
# Load checkpoint. `map_location` places the stored tensors on the current
# device, so a checkpoint saved on a GPU can also be restored on a CPU-only machine.
checkpoint = torch.load(model_path, map_location=device, weights_only=True)

# Initialize model architecture only: the checkpoint loaded below supplies the
# weights, so the pretrained ImageNet weights do not need to be loaded here.
model = resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Move to the device and compile as in training (the checkpoint was saved from
# a compiled model) before loading the weights
model = model.to(device)
model = torch.compile(model)

# Load model weights
model.load_state_dict(checkpoint['model_state_dict']);

In [None]:
# Set random seed for reproducibility (same `seed` as the config cell)
np.random.seed(seed)

# Number of samples to visualize
n_samples = 10

# Get random indices
val_indices = np.random.choice(len(val_dataset), size=n_samples, replace=False)

# Get class names from CIFAR10
classes = val_dataset.classes

# Mean / std passed to `Normalize` in the transformers; needed to undo it for display.
# Built once here instead of on every loop iteration.
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])

# Set up the plot: a single row with one panel per sample
fig, axes = plt.subplots(1, n_samples, figsize=(15, 3), squeeze=False)

model.eval()
with torch.inference_mode():
    for ax, sample_idx in zip(axes[0], val_indices):

        # Get the (normalized) image and label
        image, true_label = val_dataset[int(sample_idx)]

        # Add batch dimension, move to device and get the prediction
        output = model(image.unsqueeze(0).to(device))
        predicted_label = output.argmax(1).item()

        # Convert image for display: CxHxW -> HxWxC, then denormalize
        img = image.permute(1, 2, 0) * std + mean
        # Rounding can push values slightly outside [0, 1], which makes
        # `imshow` clip the data and emit a warning; clamp explicitly instead.
        img = img.clamp(0, 1)

        # Plot
        ax.imshow(img.numpy())
        ax.axis('off')
        color = 'green' if predicted_label == true_label else 'red'
        ax.set_title(
            f'Pred: {classes[predicted_label]}\n'
            f'True: {classes[true_label]}',
            color=color
        )

fig.tight_layout()
plt.show()