# Ingredient Recognition Training - Google Colab

This notebook trains ResNet-50 or SE-ResNet-50 models for ingredient recognition using HuggingFace datasets.

## Quick Start

1. **Enable GPU**: Runtime → Change runtime type → GPU (T4 or better recommended)
2. **Clone Repository**: Run Cell 1 to clone from GitHub
3. **Run all cells** sequentially
4. Training will use streaming HuggingFace datasets (memory efficient)

## Features
- ✅ GPU acceleration
- ✅ Git-based setup (clone from GitHub)
- ✅ HuggingFace dataset streaming (no download needed)
- ✅ Wandb integration (automatic API key from config)
- ✅ Modular trainer structure
- ✅ Checkpoint saving


## 1. Clone Repository from GitHub

Set `REPO_URL` below to your GitHub repository URL.


In [None]:
# ============================================
# UPDATE THIS: Your GitHub repository URL
# ============================================
REPO_URL = "https://github.com/yourusername/your-repo.git"  # UPDATE THIS!

import os

# Extract the repo name from the URL (tolerates a trailing slash; strips only a final .git)
repo_name = REPO_URL.rstrip('/').split('/')[-1].removesuffix('.git')

# Absolute target path, so re-running this cell after %cd does not clone a nested copy
repo_dir = os.path.join('/content', repo_name)

# Clone repository
if not os.path.exists(repo_dir):
    print(f"📦 Cloning repository from {REPO_URL}...")
    !git clone {REPO_URL} {repo_dir}
    if os.path.isdir(repo_dir):
        print("✓ Repository cloned")
    else:
        print("✗ Clone failed; check REPO_URL and repository access")
else:
    print(f"✓ Repository '{repo_name}' already exists")

# Change to project directory
%cd {repo_dir}
print(f"✓ Changed to directory: {os.getcwd()}")

# Verify project structure
required_dirs = ['models', 'trainer', 'configs']
print("\n📁 Verifying project structure:")
for dir_name in required_dirs:
    if os.path.isdir(dir_name):
        print(f"  ✓ {dir_name}/ found")
    else:
        print(f"  ✗ {dir_name}/ missing")
        print(f"     Make sure your repository contains the {dir_name}/ folder")



## 2. Install Dependencies




In [None]:
# Install required packages.
# Colab ships with a CUDA-enabled PyTorch preinstalled; when it is present, pip
# usually reports torch as "already satisfied" and keeps that build.
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
%pip install datasets transformers pillow pyyaml wandb scikit-learn matplotlib seaborn tqdm psutil

print("✓ Dependencies installed")

# Check what's in the current directory
import sys
import os

current_dir = os.getcwd()
print(f"Current working directory: {current_dir}")
print(f"\nContents of {current_dir}:")
print("="*50)

# List directories and files
items = os.listdir(current_dir)
dir_names = [d for d in items if os.path.isdir(d)]
file_names = [f for f in items if os.path.isfile(f)]

print("📁 Directories:")
for d in sorted(dir_names):
    print(f"  - {d}/")

print("\n📄 Files (showing first 10):")
for f in sorted(file_names)[:10]:
    print(f"  - {f}")
if len(file_names) > 10:
    print(f"  ... and {len(file_names) - 10} more files")

print("="*50)

# Check that the trainer and models folders exist
trainer_exists = os.path.isdir('trainer')
models_exists = os.path.isdir('models')

print("\n📦 Required folders:")
print(f"  trainer/ {'✓ EXISTS' if trainer_exists else '✗ MISSING'}")
print(f"  models/  {'✓ EXISTS' if models_exists else '✗ MISSING'}")

# Set up the Python path
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)
    print(f"\n✓ Added {current_dir} to Python path")

if '/content' not in sys.path:
    sys.path.append('/content')
    print("✓ Added /content to Python path")

# If folders are missing, print setup instructions
if not trainer_exists or not models_exists:
    print("\n" + "="*50)
    print("⚠️  REQUIRED FOLDERS MISSING")
    print("="*50)
    print("\nPlease sync/upload your project files:")
    print("\n1. VS Code Colab Extension:")
    print("   - Make sure you opened this notebook from VS Code")
    print("   - Files should auto-sync from your local directory")
    print("   - Check that trainer/ and models/ exist locally")
    print("\n2. Manual Upload:")
    print("   - Use Colab's file browser (folder icon on left)")
    print("   - Upload trainer/ and models/ folders")
    print("\n3. Google Drive:")
    print("   - Upload project to Drive")
    print("   - Mount Drive: from google.colab import drive; drive.mount('/content/drive')")
    print("   - Copy files from Drive")
    print("\n4. Git:")
    print("   - Push project to GitHub")
    print("   - Clone in Colab: !git clone <your-repo-url>")
    print("\nAfter syncing files, re-run this cell to verify.")
else:
    print("\n✓ Required folders found! Proceeding to import test...")




## 3. Check GPU Availability


In [None]:
# Check GPU availability
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
else:
    print("⚠️  GPU not available. Please enable GPU: Runtime → Change runtime type → GPU")




## 4. Set Up Python Path and Test Imports

**IMPORTANT**: The `trainer/` and `models/` folders must be present (cloned or synced in the cells above) before you run this cell.


In [None]:
# Set up the Python path and test imports
import sys
import os

# Add current directory to Python path
current_dir = os.getcwd()
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)
    print(f"✓ Added {current_dir} to Python path")

# Test imports
print("="*50)
print("Testing Imports")
print("="*50)

try:
    from trainer.hf_dataset import HuggingFaceStreamDataset, get_hf_data_loaders
    print("✓ trainer.hf_dataset imported successfully")
except ImportError as e:
    print(f"✗ Failed to import trainer.hf_dataset: {e}")

try:
    from trainer.config import load_config
    print("✓ trainer.config imported successfully")
except ImportError as e:
    print(f"✗ Failed to import trainer.config: {e}")

try:
    from trainer.metrics import calculate_metrics
    print("✓ trainer.metrics imported successfully")
except ImportError as e:
    print(f"✗ Failed to import trainer.metrics: {e}")

try:
    from trainer.validation import validate
    print("✓ trainer.validation imported successfully")
except ImportError as e:
    print(f"✗ Failed to import trainer.validation: {e}")

try:
    from models import create_resnet50, create_se_resnet50
    print("✓ models imported successfully")
except ImportError as e:
    print(f"✗ Failed to import models: {e}")

print("="*50)

# Create checkpoints directory
os.makedirs('checkpoints', exist_ok=True)
print("\n✓ Checkpoints directory created")
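The per-module `try`/`except` blocks above can be condensed into a loop. The sketch below is a hypothetical helper (`check_imports` is not part of the project) that uses `importlib.import_module` to report which modules fail to import. Unlike the cell above, it only checks that each module imports, not that specific names such as `load_config` exist inside it.

```python
import importlib


def check_imports(module_names):
    """Try to import each module; map its name to None on success or to the error text.

    Hypothetical helper for illustration: it checks modules only, not
    individual names (e.g. load_config) defined inside them.
    """
    results = {}
    for name in module_names:
        try:
            importlib.import_module(name)
            results[name] = None
        except ImportError as exc:  # ModuleNotFoundError is a subclass of ImportError
            results[name] = str(exc)
    return results


# Demo with one standard-library module and one that does not exist
print(check_imports(["json", "no_such_module_xyz_123"]))
```

In the notebook you would pass the project modules instead, e.g. `check_imports(['trainer.hf_dataset', 'trainer.config', 'trainer.metrics', 'trainer.validation', 'models'])`, and inspect any non-`None` entries.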


## 5. Configuration Setup


In [None]:
# Set your config file path here
CONFIG_PATH = 'configs/resnet50_config.yaml'  # Change this to your config file
PREVIEW_LINES = 15

# Verify config exists
import os
if os.path.exists(CONFIG_PATH):
    print(f"✓ Config file found: {CONFIG_PATH}")
    # Display a preview of the first PREVIEW_LINES lines
    with open(CONFIG_PATH, 'r') as f:
        lines = f.readlines()
    print("\nConfig file preview:")
    print("="*50)
    print(''.join(lines[:PREVIEW_LINES]))
    # Only mark as truncated when the file really is longer than the preview
    if len(lines) > PREVIEW_LINES:
        print("... (truncated)")
else:
    print(f"⚠️  Config file not found: {CONFIG_PATH}")
    print("Please check the path or update CONFIG_PATH")
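The inline training cells further down index several config keys directly with `cfg['...']`, so a missing key only surfaces as a `KeyError` partway through a run. A minimal sketch of an up-front check follows; `REQUIRED_KEYS` is inferred from the notebook code rather than from the project's config schema, `missing_config_keys` is a hypothetical helper, and `load_config` may already fill in defaults for some of these keys.

```python
# Keys the inline training cells read with cfg['...'] (inferred from this
# notebook, not from the project's own schema).
REQUIRED_KEYS = (
    "model", "epochs", "batch_size", "lr", "weight_decay",
    "image_size", "num_workers", "data_dir", "save_dir",
)


def missing_config_keys(cfg, required=REQUIRED_KEYS):
    """Return the required keys absent from a loaded config dict, in order."""
    return [key for key in required if key not in cfg]


# Hypothetical partial config for illustration
example_cfg = {"model": "resnet50", "epochs": 30, "batch_size": 32, "lr": 0.001}
print(missing_config_keys(example_cfg))
```

To check the real file, load it with PyYAML first, e.g. `with open(CONFIG_PATH) as f: print(missing_config_keys(yaml.safe_load(f)))`.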


## 6. Load Configuration


In [None]:
# Import the trainer module
import sys
import os

# Add current directory to path (works for both Colab and git clone)
current_dir = os.getcwd()
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)

# Now we can import from trainer
from trainer.config import load_config

# Load configuration
print(f"Loading configuration from: {CONFIG_PATH}")
print(f"Current working directory: {os.getcwd()}")
cfg = load_config(CONFIG_PATH)

# Print configuration summary
print("\n" + "="*50)
print("Configuration Summary")
print("="*50)
print(f"Model: {cfg['model']}")
print(f"Dataset: {cfg.get('dataset_name', 'N/A')}")
print(f"Epochs: {cfg['epochs']}")
print(f"Batch size: {cfg['batch_size']}")
print(f"Learning rate: {cfg['lr']}")
print(f"Optimizer: {cfg.get('optimizer', 'Adam')}")
scheduler_name = cfg.get('scheduler', {}).get('type', 'StepLR')
print(f"Scheduler: {scheduler_name}")
print(f"Wandb: {'Enabled' if cfg.get('use_wandb', False) else 'Disabled'}")
print("="*50)


## 7. Start Training

Training runs on the GPU when one is available. If `use_wandb` is enabled in your config, you can monitor progress in Weights & Biases.


In [None]:
# Run training using the trainer module
# This uses the same code as your local training script
import sys
import os
import torch

# Ensure we're in the right directory and paths are set
current_dir = os.getcwd()
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)
if '/content' not in sys.path:
    sys.path.append('/content')

# Set up sys.argv to simulate command line call
original_argv = sys.argv.copy()
sys.argv = ['train.py', CONFIG_PATH]

try:
    # Import and run main function
    from trainer.train import main
    
    print("="*50)
    print("Starting Training")
    print("="*50)
    print(f"Using GPU: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
    print("="*50)
    
    # Run training
    main()
finally:
    # Restore original argv
    sys.argv = original_argv
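The `try`/`finally` around `sys.argv` above is the classic "temporarily patch global state" pattern, and it can be packaged as a context manager. The sketch below is a hypothetical helper (`patched_argv` is not part of the project); it assumes `trainer.train.main()` reads its config path from `sys.argv`, as the cell above implies.

```python
import sys
from contextlib import contextmanager


@contextmanager
def patched_argv(*args):
    """Temporarily replace sys.argv, restoring it even if the body raises."""
    original = sys.argv[:]
    sys.argv = list(args)
    try:
        yield sys.argv
    finally:
        sys.argv = original
```

Usage would be `with patched_argv('train.py', CONFIG_PATH): main()`. If `main()` exits through `SystemExit` (for example on an `argparse` error), `sys.argv` is still restored. The standard library offers an equivalent via `unittest.mock.patch.object(sys, 'argv', ['train.py', CONFIG_PATH])`.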


## 8. Download Results (Optional)

After training completes, download checkpoints and results to your local machine.


In [None]:
# Download checkpoints and results
from google.colab import files
import os

if os.path.exists('checkpoints'):
    # Collect model checkpoints (.pth) and results files (.json)
    checkpoint_files = [f for f in os.listdir('checkpoints') if f.endswith('.pth')]
    results_files = [f for f in os.listdir('checkpoints') if f.endswith('.json')]
    
    print("Available files to download:")
    for f in checkpoint_files + results_files:
        print(f"  - checkpoints/{f}")
    
    # Download every listed file
    for f in checkpoint_files + results_files:
        files.download(f'checkpoints/{f}')
        print(f"✓ Downloaded: {f}")
else:
    print("No checkpoints directory found")
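Calling `files.download` once per file can make the browser ask for permission for each download. One alternative is to bundle the directory into a single zip first. This is a sketch: `archive_checkpoints` and the `checkpoints_bundle` archive name are hypothetical, and it relies only on the standard-library `shutil.make_archive`.

```python
import os
import shutil


def archive_checkpoints(save_dir="checkpoints", archive_base="checkpoints_bundle"):
    """Zip every file under save_dir into <archive_base>.zip and return its path."""
    if not os.path.isdir(save_dir):
        raise FileNotFoundError(f"No directory named {save_dir!r}")
    # root_dir makes archive entries relative to save_dir (e.g. 'resnet50_best.pth')
    return shutil.make_archive(archive_base, "zip", root_dir=save_dir)
```

In Colab this would be used as `files.download(archive_checkpoints('checkpoints'))`, producing a single download instead of one per file.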


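## 9. Inline Training Pipeline: Imports

The cells below reimplement training directly in the notebook instead of calling `trainer.train.main()`, so you normally need either this pipeline or Section 7, not both. Note that this pipeline reads images from local `ImageFolder` directories (`train/` and `val/` under `cfg['data_dir']`) rather than the streaming HuggingFace dataset described in the Quick Start.


In [None]: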
# Imports and device setup for the inline pipeline
import os
import torch

# Create __init__.py for models package if it doesn't exist
if not os.path.exists('models/__init__.py'):
    with open('models/__init__.py', 'w') as f:
        f.write('''
from .resnet50 import ResNet50, create_resnet50
from .se_resnet50 import SEResNet50, create_se_resnet50

__all__ = [
    'ResNet50',
    'create_resnet50',
    'SEResNet50',
    'create_se_resnet50',
]
''')

# Import model classes
from models import create_resnet50, create_se_resnet50

# Import other dependencies
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from tqdm import tqdm
import json
from datetime import datetime
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix

# Device used by the model, the training loop, and the wandb config below
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

print("✓ All imports successful")


## 10. Initialize Wandb (Optional)


In [None]:
use_wandb = cfg.get('use_wandb', False)

if use_wandb:
    try:
        import wandb
        
        # Login to wandb (first time only - uncomment to login)
        # wandb.login()  # Run this once, then comment it out
        
        # Generate run name if not provided
        run_name = cfg.get('wandb_run_name')
        if run_name is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            run_name = f"{cfg['model']}_{timestamp}"
        
        wandb.init(
            project=cfg.get('wandb_project', 'ingredient-recognition'),
            entity=cfg.get('wandb_entity'),
            name=run_name,
            tags=cfg.get('wandb_tags', []),
            config={
                'model': cfg['model'],
                'epochs': cfg['epochs'],
                'batch_size': cfg['batch_size'],
                'learning_rate': cfg['lr'],
                'weight_decay': cfg['weight_decay'],
                'image_size': cfg['image_size'],
                'num_workers': cfg['num_workers'],
                'se_reduction': cfg.get('se_reduction') if cfg['model'] == 'se_resnet50' else None,
                'device': str(device),
            }
        )
        print(f"✓ Wandb initialized: {wandb.run.url}")
    except ImportError:
        print("⚠️  wandb not installed. Continuing without wandb.")
        use_wandb = False
else:
    print("Wandb disabled in config")


## 11. Set Up Data Loaders


In [None]:
def get_data_loaders(data_dir, batch_size=32, num_workers=2, image_size=224):
    """Create data loaders for training and validation"""
    # Data augmentation for training
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
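    # Validation transform (no augmentation): resize the shorter side to
    # image_size / 0.875 (256 for 224) so the center crop always fits
    val_transform = transforms.Compose([
        transforms.Resize(int(round(image_size / 0.875))),
        transforms.CenterCrop(image_size),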
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    # Load datasets
    train_dataset = ImageFolder(os.path.join(data_dir, 'train'), transform=train_transform)
    val_dataset = ImageFolder(os.path.join(data_dir, 'val'), transform=val_transform)
    
    num_classes = len(train_dataset.classes)
    
    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )
    
    return train_loader, val_loader, num_classes, train_dataset.classes

# Load data
print("\nLoading datasets...")
train_loader, val_loader, num_classes, class_names = get_data_loaders(
    cfg['data_dir'], cfg['batch_size'], cfg['num_workers'], cfg['image_size']
)
print(f"Number of classes: {num_classes}")
print(f"Training samples: {len(train_loader.dataset)}")
print(f"Validation samples: {len(val_loader.dataset)}")
print(f"\nFirst 10 classes: {class_names[:10]}")
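`ImageFolder` assigns label indices from the sorted list of subdirectory names, and the `train` and `val` datasets above each compute that mapping independently. If `val/` lacks a class folder that `train/` has, the same index can refer to different ingredients in the two splits, which silently corrupts validation accuracy. A minimal pre-flight check, with hypothetical helper names, mirrors that sorted-subdirectory rule:

```python
import os


def imagefolder_classes(split_dir):
    """Class names as ImageFolder derives them: sorted names of subdirectories."""
    return sorted(entry.name for entry in os.scandir(split_dir) if entry.is_dir())


def check_split_classes(data_dir, splits=("train", "val")):
    """Return (reference_classes, mismatches) for the given split folders.

    mismatches maps each split whose class list differs from the first
    split's list to its own class list; an empty dict means they agree.
    """
    classes = {}
    for split in splits:
        split_dir = os.path.join(data_dir, split)
        if not os.path.isdir(split_dir):
            raise FileNotFoundError(f"Missing split directory: {split_dir}")
        classes[split] = imagefolder_classes(split_dir)
    reference = classes[splits[0]]
    mismatches = {s: c for s, c in classes.items() if c != reference}
    return reference, mismatches
```

Running `check_split_classes(cfg['data_dir'])` before building the loaders and stopping on a non-empty `mismatches` dict would catch the problem before any training time is spent.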


## 12. Create Model


In [None]:
# Create model
print(f"\nCreating {cfg['model']} model...")
if cfg['model'] == 'resnet50':
    model = create_resnet50(num_classes=num_classes, pretrained=cfg.get('pretrained', True))
else:
    model = create_se_resnet50(num_classes=num_classes, pretrained=cfg.get('pretrained', True), 
                               reduction=cfg.get('se_reduction', 16))

model = model.to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Log model info to wandb
if use_wandb:
    wandb.config.update({
        'total_parameters': total_params,
        'trainable_parameters': trainable_params,
        'num_classes': num_classes,
        'train_samples': len(train_loader.dataset),
        'val_samples': len(val_loader.dataset),
    })
    wandb.watch(model, log='all', log_freq=100)
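When comparing the `Total parameters` printout for `resnet50` and `se_resnet50`, the difference should roughly match the squeeze-and-excitation overhead. The sketch below estimates it by hand. It assumes the standard SE design (one SE block with two fully connected layers, C → C/r → C, with biases, on the output of every ResNet-50 bottleneck), which may not match `models/se_resnet50.py` exactly; if that file uses bias-free 1x1 convolutions, subtract the bias terms.

```python
# Assumed ResNet-50 layout: (bottleneck output channels, number of blocks) per stage
RESNET50_STAGES = [(256, 3), (512, 4), (1024, 6), (2048, 3)]


def se_block_params(channels, reduction=16):
    """Weights + biases of an SE block with two FC layers C -> C/r -> C."""
    hidden = channels // reduction
    squeeze = channels * hidden + hidden      # FC C -> C/r, plus bias
    excite = hidden * channels + channels     # FC C/r -> C, plus bias
    return squeeze + excite


def se_extra_params(reduction=16, stages=RESNET50_STAGES):
    """Extra parameters if one SE block is added to every bottleneck."""
    return sum(blocks * se_block_params(c, reduction) for c, blocks in stages)


print(f"Extra SE parameters (r=16): {se_extra_params(16):,}")
# → Extra SE parameters (r=16): 2,530,992
```

Here `reduction` corresponds to `cfg['se_reduction']` (default 16 in the model cell).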


## 13. Set Up Optimizer and Scheduler


In [None]:
# Loss and optimizer
criterion = nn.CrossEntropyLoss()

# Create optimizer based on config
optimizer_type = cfg.get('optimizer', 'Adam').lower()
if optimizer_type == 'sgd':
    sgd_cfg = cfg.get('sgd', {})
    optimizer = optim.SGD(
        model.parameters(), 
        lr=cfg['lr'], 
        weight_decay=cfg['weight_decay'],
        momentum=sgd_cfg.get('momentum', 0.9),
        nesterov=sgd_cfg.get('nesterov', False)
    )
else:  # Default to Adam
    optimizer = optim.Adam(model.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay'])

# Create scheduler based on config
scheduler_cfg = cfg.get('scheduler', {})
scheduler_type = scheduler_cfg.get('type', 'StepLR').lower()
if scheduler_type == 'cosineannealinglr':
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=cfg['epochs']
    )
elif scheduler_type == 'reducelronplateau':
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=scheduler_cfg.get('gamma', 0.1), patience=5
    )
else:  # Default to StepLR
    scheduler = optim.lr_scheduler.StepLR(
        optimizer, 
        step_size=scheduler_cfg.get('step_size', 15), 
        gamma=scheduler_cfg.get('gamma', 0.1)
    )

print(f"Optimizer: {optimizer_type}")
print(f"Scheduler: {scheduler_type}")
print(f"Initial Learning Rate: {cfg['lr']}")
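Before launching a long run, it can help to preview the learning rate each epoch will see. The training loop calls `scheduler.step()` once per epoch, so epoch `e` (0-based) trains at the value given by the closed forms that the PyTorch documentation lists for `StepLR` and `CosineAnnealingLR`. This sketch reimplements those formulas in plain Python; the function names and the example settings (`lr=0.001`, 30 epochs) are hypothetical. `ReduceLROnPlateau` depends on the validation loss and has no closed form.

```python
import math


def step_lr(base_lr, epoch, step_size=15, gamma=0.1):
    """LR during `epoch` (0-based) under StepLR: base_lr * gamma ** (epoch // step_size)."""
    return base_lr * gamma ** (epoch // step_size)


def cosine_lr(base_lr, epoch, t_max, eta_min=0.0):
    """LR during `epoch` (0-based) under CosineAnnealingLR's closed form."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


# Hypothetical settings: lr=0.001, 30 epochs, StepLR defaults from the cell above
for epoch in (0, 14, 15, 29):
    print(epoch, f"step={step_lr(0.001, epoch):.2e}", f"cosine={cosine_lr(0.001, epoch, 30):.2e}")
```

With the defaults above (`step_size=15`, `gamma=0.1`), StepLR keeps the initial rate for epochs 0 to 14 and divides it by 10 from epoch 15 on, while the cosine schedule decays smoothly toward `eta_min` by the final epoch.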


## 14. Training Functions


In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device, epoch, use_wandb=False):
    """Train for one epoch"""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1} [Train]')
    for batch_idx, (images, labels) in enumerate(pbar):
        images = images.to(device)
        labels = labels.to(device)
        
        # Forward pass
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Statistics
        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
        # Log batch metrics to wandb
        if use_wandb and batch_idx % 10 == 0:
            wandb.log({
                'train/batch_loss': loss.item(),
                'train/batch_acc': 100 * (predicted == labels).sum().item() / labels.size(0),
                'train/epoch': epoch + 1,
                'train/batch': batch_idx
            })
        
        # Update progress bar
        pbar.set_postfix({
            'loss': f'{running_loss / (batch_idx + 1):.4f}',
            'acc': f'{100 * correct / total:.2f}%'
        })
    
    epoch_loss = running_loss / len(train_loader)
    epoch_acc = 100 * correct / total
    
    return epoch_loss, epoch_acc


def validate(model, val_loader, criterion, device, use_wandb=False):
    """Validate the model"""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        pbar = tqdm(val_loader, desc='Validation')
        for images, labels in pbar:
            images = images.to(device)
            labels = labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            
            pbar.set_postfix({
                'loss': f'{running_loss / (pbar.n + 1):.4f}',
                'acc': f'{100 * correct / total:.2f}%'
            })
    
    epoch_loss = running_loss / len(val_loader)
    epoch_acc = 100 * correct / total
    
    return epoch_loss, epoch_acc, all_preds, all_labels


def calculate_metrics(all_preds, all_labels, num_classes):
    """Calculate precision, recall, and F1-score"""
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average='weighted', zero_division=0
    )
    
    precision_per_class, recall_per_class, f1_per_class, _ = precision_recall_fscore_support(
        all_labels, all_preds, average=None, zero_division=0
    )
    
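    # Pass labels explicitly so the matrix is num_classes x num_classes even if
    # some classes never appear in the validation labels or predictions
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(num_classes)))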
    
    return {
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'precision_per_class': precision_per_class.tolist(),
        'recall_per_class': recall_per_class.tolist(),
        'f1_per_class': f1_per_class.tolist(),
        'confusion_matrix': cm.tolist()
    }

print("✓ Training functions defined")
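`train_epoch` and `validate` report `running_loss / len(loader)`, the mean of per-batch mean losses. That gives a smaller final batch the same weight as a full one. The difference is usually small, but a per-sample average is easy to keep instead. The sketch below is a hypothetical `RunningMeter` class, not part of the project; it assumes the criterion uses the default `reduction='mean'`, as `nn.CrossEntropyLoss()` in the optimizer cell does.

```python
class RunningMeter:
    """Sample-weighted running loss and accuracy."""

    def __init__(self):
        self.loss_sum = 0.0  # sum of per-sample losses
        self.correct = 0
        self.seen = 0

    def update(self, batch_mean_loss, batch_correct, batch_size):
        # The criterion returns the batch *mean*, so scale back up by the batch size
        self.loss_sum += batch_mean_loss * batch_size
        self.correct += batch_correct
        self.seen += batch_size

    @property
    def avg_loss(self):
        return self.loss_sum / max(self.seen, 1)

    @property
    def accuracy(self):
        return 100.0 * self.correct / max(self.seen, 1)
```

Inside either loop, the update would be `meter.update(loss.item(), (predicted == labels).sum().item(), labels.size(0))`, after which `meter.avg_loss` and `meter.accuracy` replace the hand-rolled totals.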


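## 15. Training Loop

Trains and validates for `cfg['epochs']` epochs, saving the best and latest checkpoints to `cfg['save_dir']`.


In [None]: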
# Create save directory
os.makedirs(cfg['save_dir'], exist_ok=True)

# Resume from checkpoint if specified
start_epoch = 0
best_val_acc = 0.0
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

if cfg.get('resume'):
    print(f"\nResuming from checkpoint: {cfg['resume']}")
    checkpoint = torch.load(cfg['resume'], map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Restore the LR schedule too; older checkpoints may not contain it
    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    start_epoch = checkpoint['epoch']
    best_val_acc = checkpoint['best_val_acc']
    history = checkpoint['history']

# Training loop
print("\n" + "="*50)
print("Starting Training")
print("="*50)

for epoch in range(start_epoch, cfg['epochs']):
    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, 
                                       optimizer, device, epoch, use_wandb=use_wandb)
    
    # Validate
    val_loss, val_acc, val_preds, val_labels = validate(model, val_loader, 
                                                       criterion, device, use_wandb=use_wandb)
    
    # Update learning rate
    if scheduler_type == 'reducelronplateau':
        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]['lr']
    else:
        scheduler.step()
        current_lr = scheduler.get_last_lr()[0]
    
    # Save history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    
    # Log epoch metrics to wandb
    if use_wandb:
        wandb.log({
            'epoch': epoch + 1,
            'train/epoch_loss': train_loss,
            'train/epoch_acc': train_acc,
            'val/epoch_loss': val_loss,
            'val/epoch_acc': val_acc,
            'learning_rate': current_lr,
        })
    
    # Print epoch summary
    print(f"\nEpoch {epoch+1}/{cfg['epochs']}")
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
    print(f"Learning Rate: {current_lr:.6f}")
    
    # Save checkpoint
    checkpoint = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),  # needed to resume the LR schedule
        'best_val_acc': best_val_acc,
        'history': history,
        'num_classes': num_classes,
        'class_names': class_names,
        'model_type': cfg['model']
    }
    
    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        checkpoint['best_val_acc'] = best_val_acc
        best_model_path = os.path.join(cfg['save_dir'], f"{cfg['model']}_best.pth")
        torch.save(checkpoint, best_model_path)
        print(f"✓ Saved best model (Val Acc: {val_acc:.2f}%)")
        
        # Log best model to wandb
        if use_wandb:
            wandb.run.summary['best_val_acc'] = best_val_acc
            wandb.run.summary['best_epoch'] = epoch + 1
            wandb.save(best_model_path)
    
    # Save latest checkpoint
    latest_model_path = os.path.join(cfg['save_dir'], f"{cfg['model']}_latest.pth")
    torch.save(checkpoint, latest_model_path)
    
    # Log checkpoint to wandb
    if use_wandb:
        wandb.save(latest_model_path)

print("\n‚úì Training complete!")
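Colab sessions can disconnect mid-run, and if that happens during `torch.save`, `*_latest.pth` may be left truncated. Writing to a temporary file in the same directory and swapping it in with `os.replace` keeps the previous checkpoint intact if the write is interrupted. This is a sketch with a hypothetical `atomic_save` helper. It uses `pickle` by default so it runs standalone; since `torch.save` also accepts a file object, you would pass `save_fn=torch.save` in the loop above. The atomic-rename guarantee holds on POSIX filesystems for a same-filesystem rename and may not hold on network mounts such as Google Drive.

```python
import os
import pickle
import tempfile


def atomic_save(obj, path, save_fn=None):
    """Write obj to a temp file next to path, then rename it over path.

    save_fn(obj, file_obj) does the serialization; pickle.dump is the default.
    """
    directory = os.path.dirname(os.path.abspath(path))
    os.makedirs(directory, exist_ok=True)
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            (save_fn or pickle.dump)(obj, f)
        os.replace(tmp_path, path)  # atomic rename on POSIX within one filesystem
    except BaseException:
        # Remove the partial temp file; the previous checkpoint at `path` is untouched
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
```

In the training loop this would replace `torch.save(checkpoint, latest_model_path)` with `atomic_save(checkpoint, latest_model_path, save_fn=torch.save)`.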


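## 16. Final Evaluation

Re-runs validation on the model as it stands after the last epoch (not necessarily the best checkpoint), logs the metrics, and writes them to `cfg['save_dir']`.


In [None]: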
# Final evaluation with detailed metrics
print("\n" + "="*50)
print("Final Evaluation")
print("="*50)

final_val_loss, final_val_acc, final_preds, final_labels = validate(
    model, val_loader, criterion, device
)

metrics = calculate_metrics(final_preds, final_labels, num_classes)

print(f"\nFinal Validation Results:")
print(f"Accuracy: {final_val_acc:.2f}%")
print(f"Precision: {metrics['precision']:.4f}")
print(f"Recall: {metrics['recall']:.4f}")
print(f"F1-Score: {metrics['f1']:.4f}")

# Log final metrics to wandb
if use_wandb:
    wandb.run.summary.update({
        'final_accuracy': final_val_acc,
        'final_precision': metrics['precision'],
        'final_recall': metrics['recall'],
        'final_f1': metrics['f1'],
    })
    
    # Log confusion matrix
    try:
        import matplotlib.pyplot as plt
        import seaborn as sns
        import numpy as np
        
        cm = np.array(metrics['confusion_matrix'])
        # Plot top 20 classes for readability
        if len(class_names) > 20:
            class_counts = cm.sum(axis=1)
            top_indices = np.argsort(class_counts)[-20:]
            cm_plot = cm[np.ix_(top_indices, top_indices)]
            class_names_plot = [class_names[i] for i in top_indices]
        else:
            cm_plot = cm
            class_names_plot = class_names
        
        plt.figure(figsize=(12, 10))
        sns.heatmap(cm_plot, annot=True, fmt='d', cmap='Blues',
                   xticklabels=class_names_plot, yticklabels=class_names_plot)
        plt.title('Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.tight_layout()
        
        wandb.log({'confusion_matrix': wandb.Image(plt)})
        plt.close()
    except Exception as e:
        print(f"Warning: Could not log confusion matrix to wandb: {e}")

# Save final metrics
results = {
    'model': cfg['model'],
    'num_classes': num_classes,
    'final_accuracy': final_val_acc,
    'final_precision': metrics['precision'],
    'final_recall': metrics['recall'],
    'final_f1': metrics['f1'],
    'history': history,
    'timestamp': datetime.now().isoformat(),
    'config': cfg
}

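results_path = os.path.join(cfg['save_dir'], f"{cfg['model']}_results.json")
with open(results_path, 'w') as f:
    # default=str guards against config values that are not JSON-serializable
    json.dump(results, f, indent=2, default=str)

# Log results file to wandb; read the run URL before finish(), which sets wandb.run to None
wandb_url = None
if use_wandb:
    wandb.save(results_path)
    wandb_url = wandb.run.url
    wandb.finish()

print(f"\n✓ Results saved to {cfg['save_dir']}")
if wandb_url:
    print(f"✓ Wandb run: {wandb_url}")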
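For readability, the plotting code above keeps only the 20 classes with the most validation samples (row sums of the confusion matrix), using `np.argsort`, which orders them by ascending support. A pure-Python sketch of the same selection, ordered by descending support instead, with a hypothetical `top_k_confusion` helper:

```python
def top_k_confusion(cm, class_names, k=20):
    """Sub-matrix of cm restricted to the k classes with the most true samples.

    cm is a square list of lists (rows = true labels, columns = predictions).
    Returns (sub_matrix, names), ordered by descending support; ties keep
    their original class order because Python's sort is stable.
    """
    supports = [sum(row) for row in cm]
    top = sorted(range(len(cm)), key=lambda i: supports[i], reverse=True)[:k]
    sub = [[cm[i][j] for j in top] for i in top]
    return sub, [class_names[i] for i in top]
```

A caveat that applies to both versions: the cropped matrix drops predictions that land in classes outside the top k, so row sums in the plot can be smaller than the true class supports. For example, `top_k_confusion(metrics['confusion_matrix'], class_names)` works directly on the list stored in `metrics`.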


## 17. Download Results

Download checkpoints and results to your local machine:


In [None]:
# Download checkpoints
from google.colab import files

# Download best model
best_model_path = os.path.join(cfg['save_dir'], f"{cfg['model']}_best.pth")
if os.path.exists(best_model_path):
    print(f"Downloading {best_model_path}...")
    files.download(best_model_path)

# Download results JSON
results_path = os.path.join(cfg['save_dir'], f"{cfg['model']}_results.json")
if os.path.exists(results_path):
    print(f"Downloading {results_path}...")
    files.download(results_path)

print("✓ Files downloaded")


## 18. Save to Google Drive (Optional)

Files in the Colab runtime are deleted when the session ends, so copy checkpoints to Google Drive for permanent storage. Drive must be mounted first.


In [None]:
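# Mount Drive and copy checkpoints (uncomment to enable)
# from google.colab import drive
# drive.mount('/content/drive')
# !mkdir -p /content/drive/MyDrive/ai_coursework
# !cp -r checkpoints /content/drive/MyDrive/ai_coursework/

print("To save to Drive, uncomment the lines above and update the path")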
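A pure-Python alternative to `!cp -r` is `shutil.copytree` with `dirs_exist_ok=True` (Python 3.8+). Without that flag, `copytree` raises `FileExistsError` when the destination already exists, so the flag lets you repeat the backup after later sessions. This is a sketch with a hypothetical `backup_directory` helper; the destination path in the usage note is only an example.

```python
import os
import shutil


def backup_directory(src, dst):
    """Copy src into dst (creating dst if needed), overwriting same-named files.

    Returns the sorted top-level entries of dst after the copy.
    """
    if not os.path.isdir(src):
        raise FileNotFoundError(f"Source directory not found: {src}")
    shutil.copytree(src, dst, dirs_exist_ok=True)
    return sorted(os.listdir(dst))
```

After `drive.mount('/content/drive')`, a call such as `backup_directory('checkpoints', '/content/drive/MyDrive/ai_coursework/checkpoints')` copies the folder and lists what is now on Drive.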
