# Setup 

This first cell sets the working directory to the project root (the parent of `notebooks/`), so that relative paths resolve from there:

In [None]:
import os

# When the notebook is launched from `notebooks/`, step up one level so that
# relative paths used later (e.g. './data', 'weights/...') resolve against the
# project root. Re-running the cell does not move again unless the parent
# directory's name also ends in 'notebooks'.
_started_in_notebooks_dir = os.getcwd().endswith('notebooks')
if _started_in_notebooks_dir:
    os.chdir('..')
print(os.getcwd())

## Libraries

In [None]:
import torch
from torch import nn
from torch.optim import Adam, SGD
import torchmetrics
from torchvision import datasets
import torchvision.transforms as transforms
from torchvision.models import ResNet18_Weights, resnet18

from tqdm import tqdm
from loguru import logger

import numpy as np
import matplotlib.pyplot as plt

import optuna
import wandb
from dotenv import load_dotenv

## Device

In [None]:
# [Optional] Allow float32 matrix multiplications to use faster, slightly
# lower-precision kernels (e.g. TF32 on modern NVIDIA GPUs).
torch.set_float32_matmul_precision(precision='high')

In [None]:
# Pick the compute device: CUDA when PyTorch can see a GPU, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

## Config Parameters

In [None]:
# Random seed used by the seeding cell for PyTorch and NumPy.
seed = 42

# Training hyperparameters (learning rate, batch size, optimizer, dropout and
# number of epochs) are not fixed here: each Optuna trial defines them in the
# `hparams` dict inside `objective`.

# Model parameters
num_classes = 10  # CIFAR10 has 10 classes
# Path for model weights. NOTE(review): not referenced by the Optuna workflow
# in this notebook.
model_path = 'weights/cifar10_model.pt'

# DataLoader settings: number of worker processes per loader.
# Increase these if you have more CPU cores.
train_num_workers = 4  # training loader
test_num_workers = 4   # validation (CIFAR10 test split) loader

In [None]:
## -- Set seeds -- ##
# Seed the random number generators this notebook relies on, so that weight
# initialisation, shuffling and dropout start from the same state on each run.

# PyTorch generator(s): `torch.manual_seed` seeds random number generation
# for PyTorch operations.
torch.manual_seed(seed)

# NumPy global RNG (for any NumPy-based data loading/processing).
np.random.seed(seed)

if torch.cuda.is_available():
    # Seed every visible GPU, not just the current one.
    torch.cuda.manual_seed_all(seed)
    # Performance over reproducibility: leave cuDNN free to choose
    # non-deterministic (faster) algorithms. Set this to True if bit-exact
    # GPU reruns matter more than speed.
    torch.backends.cudnn.deterministic = False

# Optuna Workflow

In [None]:
# Data transforms for the CIFAR10 train and validation splits.
# Both pipelines convert images to tensors and normalise each channel with the
# ImageNet mean/std, matching the ImageNet-pretrained ResNet18 weights loaded
# in `objective`. NOTE(review): no training-time augmentation is applied.
def _cifar_to_normalized_tensor():
    """Return a new ToTensor + ImageNet-statistics Normalize pipeline."""
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ])


train_transformer = _cifar_to_normalized_tensor()
val_transformer = _cifar_to_normalized_tensor()

In [None]:
# Define Optuna objective
def _make_loaders(batch_size):
    """Build the CIFAR10 train and validation DataLoaders for one trial.

    Args:
        batch_size: Mini-batch size used by both loaders.

    Returns:
        ``(train_loader, val_loader)``; the validation loader reads the CIFAR10
        test split (``train=False``).
    """
    train_dataset = datasets.CIFAR10(
        root='./data', train=True, download=True, transform=train_transformer
    )
    val_dataset = datasets.CIFAR10(
        root='./data', train=False, download=True, transform=val_transformer
    )

    # Page-locked host memory only helps host->GPU copies; request it only
    # when training on CUDA (on CPU it buys nothing and PyTorch may warn).
    pin_memory = device.type == "cuda"

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=train_num_workers,
        pin_memory=pin_memory,
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=test_num_workers,  # validation loader has its own setting
        pin_memory=pin_memory,
    )
    return train_loader, val_loader


def _make_model(dropout_rate):
    """Return an ImageNet-pretrained ResNet18 with a dropout + linear head.

    The final ``fc`` layer is replaced by ``Dropout(dropout_rate)`` followed by
    a ``Linear`` layer with ``num_classes`` outputs. The model is moved to
    ``device`` and wrapped with ``torch.compile``.
    """
    model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    model.fc = nn.Sequential(
        nn.Dropout(dropout_rate),
        nn.Linear(model.fc.in_features, num_classes),
    )
    model = model.to(device)
    return torch.compile(model)


def _train_epoch(model, loader, criterion, optimizer, loss_metric, acc_metric, desc):
    """Run one optimisation pass over ``loader``.

    Resets, then accumulates, the epoch's mean loss and accuracy into
    ``loss_metric`` / ``acc_metric``. Running values are shown on a tqdm bar.
    """
    model.train()
    loss_metric.reset()
    acc_metric.reset()

    progress = tqdm(loader, desc=desc, leave=False)
    for images, labels in progress:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad(set_to_none=True)
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # `loss` is a per-batch mean; weighting by batch size makes the epoch
        # value a true per-sample mean even when the last batch is shorter.
        loss_metric.update(loss.detach(), weight=labels.size(0))
        acc_metric.update(outputs.detach(), labels)

        progress.set_postfix({
            'loss': f'{loss_metric.compute():.3f}',
            'acc': f'{acc_metric.compute():.1%}'
        })


def _validate_epoch(model, loader, criterion, loss_metric, acc_metric, desc):
    """Evaluate ``model`` on ``loader`` under ``torch.inference_mode``.

    Resets, then accumulates, the mean loss and accuracy into ``loss_metric`` /
    ``acc_metric``.
    """
    model.eval()
    loss_metric.reset()
    acc_metric.reset()

    with torch.inference_mode():
        progress = tqdm(loader, desc=desc, leave=False)
        for images, labels in progress:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Batch-size weighting: see `_train_epoch`.
            loss_metric.update(loss, weight=labels.size(0))
            acc_metric.update(outputs, labels)

            progress.set_postfix({
                'loss': f'{loss_metric.compute():.3f}',
                'acc': f'{acc_metric.compute():.1%}'
            })


def objective(trial):
    """Train ResNet18 on CIFAR10 with hyperparameters sampled from ``trial``.

    Each epoch's train/validation metrics are logged to a dedicated W&B run.
    Validation accuracy is reported to Optuna after every epoch so the study's
    pruner can stop unpromising trials early.

    Args:
        trial: The ``optuna.Trial`` being evaluated.

    Returns:
        float: The best validation accuracy (a fraction) seen over all epochs.

    Raises:
        optuna.TrialPruned: When the study's pruner decides to stop the trial.
    """
    # Sample hyperparameters
    hparams = {
        "learning_rate": trial.suggest_float("learning_rate", 1e-4, 1e-2, log=True),
        "batch_size": trial.suggest_categorical("batch_size", [64, 128, 256]),
        "optimizer": trial.suggest_categorical("optimizer", ["Adam", "SGD"]),
        "dropout_rate": trial.suggest_float("dropout_rate", 0.1, 0.5),
        "num_epochs": 3,  # Fixed for quick trials
        "model": "ResNet18",
        "trial_number": trial.number
    }
    num_epochs = hparams["num_epochs"]

    # Initialize W&B
    run = wandb.init(
        project="pytorch-cifar10-optuna",
        config=hparams,
        name=f"trial_{trial.number}",
        reinit=True
    )

    # Set to True only when the trial ends normally (completed or pruned), so
    # an unexpected error marks the W&B run as failed instead of leaving it open.
    finished_cleanly = False
    try:
        train_loader, val_loader = _make_loaders(hparams["batch_size"])
        model = _make_model(hparams["dropout_rate"])

        # Loss and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer_cls = Adam if hparams["optimizer"] == "Adam" else SGD
        optimizer = optimizer_cls(model.parameters(), lr=hparams["learning_rate"])

        # Metrics
        train_loss = torchmetrics.MeanMetric().to(device)
        val_loss = torchmetrics.MeanMetric().to(device)
        train_accuracy = torchmetrics.Accuracy(
            task="multiclass", num_classes=num_classes
        ).to(device)
        val_accuracy = torchmetrics.Accuracy(
            task="multiclass", num_classes=num_classes
        ).to(device)

        best_accuracy = 0.0
        for epoch in range(num_epochs):
            epoch_tag = f'• Epoch {epoch + 1}/{num_epochs}'
            _train_epoch(
                model, train_loader, criterion, optimizer,
                train_loss, train_accuracy, desc=f'{epoch_tag} [Train]',
            )
            _validate_epoch(
                model, val_loader, criterion,
                val_loss, val_accuracy, desc=f'{epoch_tag} [Valid]',
            )

            # Plain Python floats: avoids handing (possibly CUDA) tensors to
            # W&B and Optuna.
            metrics = {
                "epoch": epoch,
                "train_loss": train_loss.compute().item(),
                "train_accuracy": train_accuracy.compute().item(),
                "val_loss": val_loss.compute().item(),
                "val_accuracy": val_accuracy.compute().item()
            }

            wandb.log(metrics)
            logger.debug(
                f"Epoch {epoch+1}/{hparams['num_epochs']}: "
                f"Train Loss: {metrics['train_loss']:.3f} | "
                f"Train Acc: {metrics['train_accuracy']:.1%} | "
                f"Val Loss: {metrics['val_loss']:.3f} | "
                f"Val Acc: {metrics['val_accuracy']:.1%}"
            )

            best_accuracy = max(best_accuracy, metrics['val_accuracy'])

            # Report to Optuna; the pruner may stop the trial here.
            trial.report(metrics['val_accuracy'], epoch)
            if trial.should_prune():
                finished_cleanly = True
                raise optuna.TrialPruned()

        finished_cleanly = True
        return best_accuracy
    finally:
        # exit_code=1 flags the run as failed on W&B when training raised.
        run.finish(exit_code=None if finished_cleanly else 1)

In [None]:
# Load W&B credentials: populate environment variables from a `.env` file,
# if python-dotenv finds one.
load_dotenv()

# Fail fast, before any trial starts, if the key is missing or empty. Every
# trial calls `wandb.init`, and this notebook expects the key to come from the
# environment. An explicit raise is used instead of `assert`, which Python
# skips under `-O`.
if not os.getenv("WANDB_API_KEY"):
    raise RuntimeError("WANDB_API_KEY not found in environment variables")

In [None]:
# Create the Optuna study; trials are launched later by `study.optimize`.
study = optuna.create_study(
    # `objective` returns validation accuracy, so higher is better.
    direction="maximize",
    # Seed the sampler so the sequence of sampled hyperparameters is
    # reproducible. The global NumPy/PyTorch seeds set earlier do not drive
    # Optuna's own RNG. TPE is what create_study uses by default for a
    # single-objective study.
    # NOTE(review): TPESampler samples randomly for its first
    # `n_startup_trials` trials (documented default 10 — verify for your Optuna
    # version), so a 10-trial study is effectively random search.
    sampler=optuna.samplers.TPESampler(seed=seed),
    # Median pruning: no pruning until 5 trials have finished, and no pruning
    # of a trial during its warm-up steps (`n_warmup_steps=1`).
    pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=1),
)

In [None]:
# Launch the search: Optuna calls `objective` once per trial, 10 trials total.
logger.info("Starting hyperparameter optimization...")
study.optimize(func=objective, n_trials=10)

In [None]:
# Print results: value and hyperparameters of the study's best trial.
logger.info("Best trial:")
trial = study.best_trial  # keep the `trial` name: the next cell reads it
logger.info(f"  Value: {trial.value:.3f}")
logger.info("  Params: ")
for param_line in (f"    {name}: {val}" for name, val in trial.params.items()):
    logger.info(param_line)

In [None]:
# Display the number of the best trial (`trial` was bound to
# `study.best_trial` in the previous cell).
trial.number