# Phase 4: CNN Implementation (Transfer Learning)

This notebook implements pretrained CNN models for waste classification:
- ResNet18 transfer learning (primary)
- EfficientNet-B0 transfer learning (optional)
- Complete training pipeline with early stopping
- Hyperparameter tuning
- Performance evaluation

In [None]:
import sys
sys.path.append('..')

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import models
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import json
from tqdm.auto import tqdm
from datetime import datetime
import copy

from src.config import load_config
from src.load_data import load_data, TrashNetDataset
from src.transforms import get_transforms

config = load_config()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Output directories for checkpoints and plots/histories.
# parents=True so nested paths from the config (e.g. "outputs/models") work too.
Path(config['paths']['models_dir']).mkdir(parents=True, exist_ok=True)
Path(config['paths']['results_dir']).mkdir(parents=True, exist_ok=True)

## 1. Dataset and DataLoader Setup

The `TrashNetDataset` class from `src/load_data.py` now supports transforms for CNN training.

In [None]:
def create_dataloaders(batch_size=16, num_workers=0):
    """Create train, validation, and test dataloaders with augmentation.

    The training split uses transforms built from a copy of the global config
    with augmentation forced on. Validation and test use the ``'val'``
    transforms built from the unmodified global config.

    Args:
        batch_size: Samples per batch for all three loaders.
        num_workers: Worker processes per loader (0 = load in the main process).

    Returns:
        (train_loader, val_loader, test_loader). Only the training loader
        shuffles.
    """

    # Load datasets
    train_hf, val_hf, test_hf = load_data(split_data=True)

    # Enable augmentation for training.
    # Deep copy: a shallow ``config.copy()`` shares the nested 'augmentation'
    # dict, so flipping 'enabled' would silently mutate the global config too.
    config_aug = copy.deepcopy(config)
    config_aug['augmentation']['enabled'] = True

    # Get transforms (using ImageNet normalization for pretrained models)
    train_transform = get_transforms(config_aug, split='train')
    val_transform = get_transforms(config, split='val')

    # Create PyTorch datasets using TrashNetDataset
    train_dataset = TrashNetDataset(train_hf, transform=train_transform)
    val_dataset = TrashNetDataset(val_hf, transform=val_transform)
    test_dataset = TrashNetDataset(test_hf, transform=val_transform)

    # Pinned host memory only helps when batches are copied to a CUDA device
    pin_memory = torch.cuda.is_available()

    def _make_loader(dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader sharing this call's batch/worker settings."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory
        )

    train_loader = _make_loader(train_dataset, shuffle=True)
    val_loader = _make_loader(val_dataset, shuffle=False)
    test_loader = _make_loader(test_dataset, shuffle=False)

    print(f"\nDataLoader Summary:")
    print(f"  Training samples:   {len(train_dataset)} ({len(train_loader)} batches)")
    print(f"  Validation samples: {len(val_dataset)} ({len(val_loader)} batches)")
    print(f"  Test samples:       {len(test_dataset)} ({len(test_loader)} batches)")
    print(f"  Batch size:         {batch_size}")
    print(f"  Augmentation:       Enabled for training")

    return train_loader, val_loader, test_loader

# Default loaders shared by every later cell (batch size taken from the config)
train_loader, val_loader, test_loader = create_dataloaders(
    batch_size=config['training']['batch_size'],
)

## 2. Model Architectures

### 2.1 ResNet18 Transfer Learning

In [None]:
def create_resnet18(num_classes=6, pretrained=True, freeze_backbone=False, dropout=None):
    """Create ResNet18 model with transfer learning.

    Args:
        num_classes: Number of logits produced by the new classification head.
        pretrained: If True, initialise from torchvision's default ImageNet weights.
        freeze_backbone: If True, disable gradients for every existing layer.
            Only the freshly created head stays trainable.
        dropout: Dropout probability applied before the final linear layer.
            When None, ``config['cnn']['dropout']`` is used.

    Returns:
        The ResNet18 module with ``fc`` replaced by ``Dropout -> Linear``.
    """
    if dropout is None:
        dropout = config['cnn']['dropout']

    # torchvision >= 0.13 replaced the boolean `pretrained` flag with the
    # `weights` enum (the flag is deprecated); fall back on older releases.
    weights_enum = getattr(models, 'ResNet18_Weights', None)
    if weights_enum is not None:
        model = models.resnet18(weights=weights_enum.DEFAULT if pretrained else None)
    else:
        model = models.resnet18(pretrained=pretrained)

    # Freeze backbone if specified (done before the head swap so the new
    # head keeps requires_grad=True)
    if freeze_backbone:
        for param in model.parameters():
            param.requires_grad = False

    # Replace final layer
    num_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(dropout),
        nn.Linear(num_features, num_classes)
    )

    return model

# Smoke-test the ResNet18 factory on a single dummy 224x224 RGB input
test_resnet = create_resnet18(
    num_classes=config['data']['num_classes'],
    pretrained=config['cnn']['pretrained'],
    freeze_backbone=config['cnn']['freeze_backbone']
)
# Inference mode: disables dropout and stops BatchNorm from updating its
# running statistics with the random input
test_resnet.eval()
test_input = torch.randn(1, 3, 224, 224)
with torch.no_grad():  # shape check only; no autograd graph needed
    test_output = test_resnet(test_input)

print(f"\nResNet18 Architecture:")
print(f"  Output shape: {test_output.shape}")
print(f"  Total parameters: {sum(p.numel() for p in test_resnet.parameters()):,}")
print(f"  Trainable parameters: {sum(p.numel() for p in test_resnet.parameters() if p.requires_grad):,}")
print(f"  Pretrained: {config['cnn']['pretrained']}")
print(f"  Freeze backbone: {config['cnn']['freeze_backbone']}")

### 2.2 EfficientNet-B0 Transfer Learning (Optional)

In [None]:
def create_efficientnet_b0(num_classes=6, pretrained=True, freeze_backbone=False, dropout=None):
    """Create EfficientNet-B0 model with transfer learning.

    Args:
        num_classes: Number of logits produced by the new classification head.
        pretrained: If True, initialise from torchvision's default ImageNet weights.
        freeze_backbone: If True, disable gradients for every existing layer.
            Only the freshly created head stays trainable.
        dropout: Dropout probability applied before the final linear layer.
            When None, ``config['cnn']['dropout']`` is used.

    Returns:
        The EfficientNet-B0 module with ``classifier`` replaced by
        ``Dropout -> Linear``. Returns None when the architecture (or its
        weights) cannot be constructed.
    """
    if dropout is None:
        dropout = config['cnn']['dropout']

    # Only model construction is treated as best-effort: older torchvision
    # releases lack the architecture and weight downloads can fail. Errors in
    # the code below should surface rather than masquerade as "not available".
    try:
        weights_enum = getattr(models, 'EfficientNet_B0_Weights', None)
        if weights_enum is not None:
            model = models.efficientnet_b0(weights=weights_enum.DEFAULT if pretrained else None)
        else:
            model = models.efficientnet_b0(pretrained=pretrained)
    except Exception as e:
        print(f"EfficientNet not available: {e}")
        return None

    # Freeze backbone if specified (before the head swap so the new head stays trainable)
    if freeze_backbone:
        for param in model.parameters():
            param.requires_grad = False

    # Replace final layer; classifier[1] is the original Linear head
    num_features = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(dropout),
        nn.Linear(num_features, num_classes)
    )

    return model

# Smoke-test the EfficientNet-B0 factory (it returns None when unavailable)
test_efficientnet = create_efficientnet_b0(
    num_classes=config['data']['num_classes'],
    pretrained=config['cnn']['pretrained'],
    freeze_backbone=config['cnn']['freeze_backbone']
)

if test_efficientnet is not None:
    # Inference mode + no autograd: this is only a shape/parameter check
    test_efficientnet.eval()
    with torch.no_grad():
        test_output = test_efficientnet(test_input)
    print(f"\nEfficientNet-B0 Architecture:")
    print(f"  Output shape: {test_output.shape}")
    print(f"  Total parameters: {sum(p.numel() for p in test_efficientnet.parameters()):,}")
    print(f"  Trainable parameters: {sum(p.numel() for p in test_efficientnet.parameters() if p.requires_grad):,}")
else:
    print("\nEfficientNet-B0 not available in this PyTorch version")

## 3. Training Pipeline

### 3.1 Early Stopping

In [None]:
class EarlyStopping:
    """Early stopping to prevent overfitting.

    Tracks the lowest validation loss seen so far. Sets ``early_stop`` once
    ``patience`` consecutive calls fail to improve on it by at least
    ``min_delta``. A deep copy of the model's ``state_dict`` at the best loss
    is kept so it can be restored with ``load_best_model``.
    """

    def __init__(self, patience=10, min_delta=0.001, verbose=True):
        self.patience = patience        # non-improving calls tolerated before stopping
        self.min_delta = min_delta      # minimum decrease that counts as an improvement
        self.verbose = verbose          # print the counter on each non-improving call
        self.counter = 0                # consecutive non-improving calls so far
        self.best_loss = None           # lowest non-NaN loss seen; None until the first one
        self.early_stop = False         # becomes True once counter reaches patience
        self.best_model_state = None    # deep copy of model.state_dict() at best_loss

    def __call__(self, val_loss, model):
        """Record this epoch's ``val_loss`` and update the stopping state.

        A loss is an improvement when it is at most ``best_loss - min_delta``
        (the first non-NaN loss always is). NaN losses count as "no
        improvement". If NaN were stored as the best, every later comparison
        would be False, every epoch would count as an improvement, and
        stopping could never trigger.
        """
        if np.isnan(val_loss):
            improved = False
        elif self.best_loss is None:
            improved = True
        else:
            improved = val_loss <= self.best_loss - self.min_delta

        if improved:
            self.best_loss = val_loss
            self.best_model_state = copy.deepcopy(model.state_dict())
            self.counter = 0
            return

        self.counter += 1
        if self.verbose:
            print(f"  EarlyStopping counter: {self.counter}/{self.patience}")
        if self.counter >= self.patience:
            self.early_stop = True

    def load_best_model(self, model):
        """Load the best model state (no-op if no valid loss was ever recorded)."""
        if self.best_model_state is not None:
            model.load_state_dict(self.best_model_state)

### 3.2 Training and Validation Functions

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run a single optimisation pass over ``dataloader``.

    Puts ``model`` in training mode and performs one optimizer step per batch.
    A progress bar shows the running loss and accuracy.

    Returns:
        (epoch_loss, epoch_acc): per-sample mean loss (assuming ``criterion``
        averages over the batch) and accuracy in percent.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    progress = tqdm(dataloader, desc='Training')
    for batch_x, batch_y in progress:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        logits = model(batch_x)
        batch_loss = criterion(logits, batch_y)
        batch_loss.backward()
        optimizer.step()

        # Re-weight the batch loss by batch size so the epoch mean is per sample
        loss_sum += batch_loss.item() * batch_x.size(0)
        n_seen += batch_y.size(0)
        n_correct += (logits.argmax(dim=1) == batch_y).sum().item()

        progress.set_postfix({'loss': f"{batch_loss.item():.4f}", 'acc': f"{100 * n_correct / n_seen:.2f}%"})

    return loss_sum / n_seen, 100 * n_correct / n_seen

def validate(model, dataloader, criterion, device):
    """Score ``model`` on ``dataloader`` without touching its weights.

    Runs in eval mode under ``torch.no_grad()``. A progress bar shows the
    running loss and accuracy.

    Returns:
        (epoch_loss, epoch_acc): per-sample mean loss (assuming ``criterion``
        averages over the batch) and accuracy in percent.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    with torch.no_grad():
        progress = tqdm(dataloader, desc='Validation')
        for batch_x, batch_y in progress:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)

            logits = model(batch_x)
            batch_loss = criterion(logits, batch_y)

            loss_sum += batch_loss.item() * batch_x.size(0)
            n_seen += batch_y.size(0)
            n_correct += (logits.argmax(dim=1) == batch_y).sum().item()

            progress.set_postfix({'loss': f"{batch_loss.item():.4f}", 'acc': f"{100 * n_correct / n_seen:.2f}%"})

    return loss_sum / n_seen, 100 * n_correct / n_seen

### 3.3 Complete Training Loop

In [None]:
def train_model(model, train_loader, val_loader, config, model_name='model', save_best=True):
    """
    Complete training pipeline with early stopping and checkpointing.

    The model is moved to the module-level ``device`` and trained with
    cross-entropy loss. Everything else is read from ``config['training']``:
    the optimizer (Adam when ``optimizer == 'adam'``, otherwise SGD with
    momentum), an optional StepLR scheduler, and optional early stopping.

    Args:
        model: PyTorch model
        train_loader: Training dataloader
        val_loader: Validation dataloader
        config: Configuration dictionary
        model_name: Name for saving the model
        save_best: Whether to save the best-validation-accuracy checkpoint to
            ``<models_dir>/<model_name>_best.pth``

    Returns:
        history: Dictionary of per-epoch lists ``train_loss``, ``train_acc``,
        ``val_loss``, ``val_acc`` and ``learning_rates`` (the LR used during
        that epoch).

    Note:
        The checkpoint is selected by validation *accuracy*, while early
        stopping monitors validation *loss*. If early stopping fires, the
        lowest-loss weights are restored into ``model``. Otherwise ``model``
        keeps the final epoch's weights.
    """

    model = model.to(device)
    train_cfg = config['training']
    es_cfg = train_cfg['early_stopping']

    # Loss function
    criterion = nn.CrossEntropyLoss()

    # Optimizer
    if train_cfg['optimizer'] == 'adam':
        optimizer = optim.Adam(
            model.parameters(),
            lr=train_cfg['learning_rate'],
            weight_decay=train_cfg['weight_decay']
        )
    else:
        optimizer = optim.SGD(
            model.parameters(),
            lr=train_cfg['learning_rate'],
            momentum=train_cfg['momentum'],
            weight_decay=train_cfg['weight_decay']
        )

    # Learning rate scheduler
    scheduler = None
    if train_cfg['scheduler']['enabled']:
        scheduler = optim.lr_scheduler.StepLR(
            optimizer,
            step_size=train_cfg['scheduler']['step_size'],
            gamma=train_cfg['scheduler']['gamma']
        )

    # Early stopping
    early_stopping = None
    if es_cfg['enabled']:
        early_stopping = EarlyStopping(
            patience=es_cfg['patience'],
            min_delta=es_cfg['min_delta'],
            verbose=True
        )

    # Training history
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
        'learning_rates': []
    }

    # Training loop
    num_epochs = train_cfg['epochs']
    # Tracked regardless of save_best so the final summary is meaningful even
    # without checkpointing. None (not 0.0) so the first epoch always produces
    # a checkpoint, even at 0% accuracy.
    best_val_acc = None
    model_path = Path(config['paths']['models_dir']) / f"{model_name}_best.pth"

    print(f"\n{'='*60}")
    print(f"Training: {model_name}")
    print(f"{'='*60}")
    print(f"Epochs: {num_epochs} | Batch size: {train_cfg['batch_size']}")
    print(f"Optimizer: {train_cfg['optimizer'].upper()} | LR: {train_cfg['learning_rate']}")
    print(f"Early stopping: {es_cfg['enabled']} (patience={es_cfg['patience']})")
    print(f"{'='*60}\n")

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print('-' * 60)

        # Train
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

        # Validate
        val_loss, val_acc = validate(model, val_loader, criterion, device)

        # Record the LR used for this epoch *before* the scheduler advances it
        current_lr = optimizer.param_groups[0]['lr']
        if scheduler is not None:
            scheduler.step()

        # Save history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['learning_rates'].append(current_lr)

        # Print epoch summary
        print(f"\nEpoch {epoch+1} Summary:")
        print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"  Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.2f}%")
        print(f"  LR: {current_lr:.6f}")

        # Track (and optionally save) the best model by validation accuracy
        if best_val_acc is None or val_acc > best_val_acc:
            best_val_acc = val_acc
            if save_best:
                model_path.parent.mkdir(parents=True, exist_ok=True)
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_acc': val_acc,
                    'val_loss': val_loss,
                    'config': config
                }, model_path)
                print(f"  ✓ Best model saved (Val Acc: {val_acc:.2f}%)")

        # Early stopping
        if early_stopping is not None:
            early_stopping(val_loss, model)
            if early_stopping.early_stop:
                print(f"\n⚠ Early stopping triggered at epoch {epoch+1}")
                print(f"  Loading best model weights...")
                early_stopping.load_best_model(model)
                break

    best_acc_str = f"{best_val_acc:.2f}%" if best_val_acc is not None else "n/a"

    print(f"\n{'='*60}")
    print(f"Training Complete: {model_name}")
    print(f"Best Validation Accuracy: {best_acc_str}")
    print(f"Total Epochs: {len(history['train_loss'])}")
    print(f"{'='*60}\n")

    return history

### 3.4 Plotting Functions

In [None]:
def plot_training_curves(history, model_name='Model'):
    """Plot training and validation loss/accuracy curves side by side.

    The best validation epoch is marked with a star on each panel: the
    minimum for loss, the maximum for accuracy. The figure is saved as
    ``<results_dir>/<model_name lowercased, spaces -> '_'>_curves.png`` and
    then shown.

    Args:
        history: Dict as returned by ``train_model`` (uses ``train_loss``,
            ``val_loss``, ``train_acc``, ``val_acc``).
        model_name: Display name used in titles and the output file name.
    """
    if not history['train_loss']:
        print(f"No epochs recorded for {model_name}; nothing to plot")
        return

    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(15, 5))
    epochs = range(1, len(history['train_loss']) + 1)

    # (axis, history key suffix, label word, y-axis label, "best" selector)
    panels = (
        (ax_loss, 'loss', 'Loss', 'Loss', np.argmin),
        (ax_acc, 'acc', 'Accuracy', 'Accuracy (%)', np.argmax),
    )
    for ax, key, word, ylabel, pick_best in panels:
        train_vals = history[f'train_{key}']
        val_vals = history[f'val_{key}']

        ax.plot(epochs, train_vals, 'b-', label=f'Training {word}', linewidth=2, marker='o', markersize=4)
        ax.plot(epochs, val_vals, 'r-', label=f'Validation {word}', linewidth=2, marker='s', markersize=4)

        # Mark best validation epoch
        best_idx = int(pick_best(val_vals))
        ax.plot(best_idx + 1, val_vals[best_idx], 'r*', markersize=15, label='Best')

        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.set_title(f'{model_name} - {word} Curves', fontsize=14, fontweight='bold')
        # Build the legend *after* plotting the star so 'Best' is included
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)

    fig.tight_layout()

    # Save plot
    save_path = Path(config['paths']['results_dir']) / f"{model_name.lower().replace(' ', '_')}_curves.png"
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"Training curves saved to {save_path}")

    plt.show()

def save_training_history(history, model_name='model'):
    """Write ``history`` as indented JSON to ``<results_dir>/<model_name>_history.json``."""
    out_file = Path(config['paths']['results_dir']) / f"{model_name}_history.json"
    out_file.write_text(json.dumps(history, indent=2))
    print(f"Training history saved to {out_file}")

## 4. Train Models

### 4.1 Train ResNet18

In [None]:
# Fresh ResNet18 configured from the `cnn` section of the config
resnet18 = create_resnet18(
    num_classes=config['data']['num_classes'],
    freeze_backbone=config['cnn']['freeze_backbone'],
    pretrained=config['cnn']['pretrained'],
)

# Full training pipeline; the best checkpoint is written as resnet18_best.pth
resnet_history = train_model(resnet18, train_loader, val_loader, config, model_name='resnet18')

# Persist learning curves and the raw per-epoch history
plot_training_curves(resnet_history, 'ResNet18')
save_training_history(resnet_history, 'resnet18')

### 4.2 Train EfficientNet-B0 (Optional)

In [None]:
# Fresh EfficientNet-B0 (the factory returns None when it is unavailable)
efficientnet = create_efficientnet_b0(
    num_classes=config['data']['num_classes'],
    freeze_backbone=config['cnn']['freeze_backbone'],
    pretrained=config['cnn']['pretrained'],
)

if efficientnet is None:
    print("Skipping EfficientNet-B0 (not available)")
else:
    # Same pipeline as ResNet18; checkpoint written as efficientnet_b0_best.pth
    efficientnet_history = train_model(
        efficientnet, train_loader, val_loader, config, model_name='efficientnet_b0'
    )
    plot_training_curves(efficientnet_history, 'EfficientNet-B0')
    save_training_history(efficientnet_history, 'efficientnet_b0')

## 5. Hyperparameter Tuning

### 5.1 Grid Search Function

In [None]:
def hyperparameter_search(model_fn, param_grid, model_base_name='model'):
    """
    Perform an exhaustive grid search over training hyperparameters.

    Each combination is written into a private deep copy of the global config
    (under ``config['training']``). A fresh model is then built with
    ``model_fn`` and trained with ``train_model``. Runs that raise are
    reported and skipped so one bad configuration does not abort the search.

    Args:
        model_fn: Zero-argument callable that creates a new, untrained model
        param_grid: Mapping of ``config['training']`` key -> list of values to
            try. A ``'batch_size'`` entry also rebuilds the dataloaders per run.
        model_base_name: Base name for saved checkpoints and the results file

    Returns:
        results: List of per-run summary dicts, sorted by best validation
        accuracy (highest first)
    """
    from itertools import product
    from math import prod

    results = []

    param_names = list(param_grid.keys())
    param_values = list(param_grid.values())
    total_runs = prod(len(values) for values in param_values)

    # Keys outside config['training'] cannot influence training; say so once
    # instead of silently ignoring them.
    unknown = [name for name in param_names if name not in config['training']]
    if unknown:
        print(f"⚠ Parameters not found in config['training'] will be ignored: {unknown}")

    print(f"\n{'='*80}")
    print(f"Starting Hyperparameter Search")
    print(f"Total configurations to test: {total_runs}")
    print(f"Parameters: {list(param_grid.keys())}")
    print(f"{'='*80}\n")

    for run_num, param_combo in enumerate(product(*param_values), start=1):
        params = dict(zip(param_names, param_combo))

        print(f"\n{'='*80}")
        print(f"Configuration {run_num}/{total_runs}")
        print(f"Parameters: {params}")
        print(f"{'='*80}\n")

        # Deep copy so each run starts from the pristine global config. A
        # shallow copy shares the nested 'training' dict, so the overrides
        # would leak into the global config.
        config_copy = copy.deepcopy(config)
        for key, value in params.items():
            if key in config_copy['training']:
                config_copy['training'][key] = value

        # Create dataloaders with new batch size if needed
        if 'batch_size' in params:
            train_loader_temp, val_loader_temp, _ = create_dataloaders(batch_size=params['batch_size'])
        else:
            train_loader_temp, val_loader_temp = train_loader, val_loader

        # Create model
        model = model_fn()

        # Create model name
        param_str = '_'.join([f"{k}{v}" for k, v in params.items()])
        model_name = f"{model_base_name}_{param_str}"

        # Train (best-effort: a failing configuration is reported and skipped)
        try:
            history = train_model(
                model,
                train_loader_temp,
                val_loader_temp,
                config_copy,
                model_name=model_name,
                save_best=True
            )
        except Exception as e:
            print(f"✗ Error training with params {params}: {e}")
            continue

        if not history['val_acc']:
            print(f"✗ No epochs were run for params {params}; skipping")
            continue

        # Save results
        result = {
            'params': params,
            'best_val_acc': max(history['val_acc']),
            'best_val_loss': min(history['val_loss']),
            'final_train_acc': history['train_acc'][-1],
            'final_val_acc': history['val_acc'][-1],
            'epochs_trained': len(history['train_loss'])
        }
        results.append(result)

        print(f"\n✓ Result: Best Val Acc = {result['best_val_acc']:.2f}%")

    # Print summary
    print(f"\n{'='*80}")
    print(f"Hyperparameter Search Complete")
    print(f"{'='*80}\n")

    # Sort by best validation accuracy
    results.sort(key=lambda x: x['best_val_acc'], reverse=True)

    if not results:
        print("No configuration completed successfully.")
    else:
        print("\nTop 5 Configurations:")
        print("-" * 80)
        for i, result in enumerate(results[:5], 1):
            print(f"{i}. {result['params']}")
            print(f"   Best Val Acc: {result['best_val_acc']:.2f}% | Val Loss: {result['best_val_loss']:.4f} | Epochs: {result['epochs_trained']}")
            print()

    # Save results
    results_path = Path(config['paths']['results_dir']) / f"{model_base_name}_tuning_results.json"
    with open(results_path, 'w') as f:
        json.dump(results, f, indent=2)
    print(f"Full results saved to {results_path}")

    return results

### 5.2 Tune ResNet18

In [None]:
# Define hyperparameter grid (keys must exist in config['training'])
param_grid = {
    'learning_rate': [1e-3, 1e-4, 1e-5],
    'batch_size': [16, 32, 64],
}

# Model creation function
def create_resnet_for_tuning():
    """Return a fresh ImageNet-initialised, fully trainable ResNet18 for one tuning run."""
    return create_resnet18(
        num_classes=config['data']['num_classes'],
        pretrained=True,
        freeze_backbone=False
    )

# Set to True to launch the grid search.
# WARNING: trains one model per grid combination (9 with the grid above) and may take hours!
RUN_TUNING = False

# Number of combinations for whatever keys the grid contains
n_configs = int(np.prod([len(values) for values in param_grid.values()]))

if RUN_TUNING:
    tuning_results = hyperparameter_search(
        create_resnet_for_tuning,
        param_grid,
        model_base_name='resnet18_tuning'
    )
else:
    print("\n" + "="*80)
    print("Hyperparameter Tuning")
    print("="*80)
    print("\nTo run hyperparameter tuning, set RUN_TUNING = True above.")
    print(f"\nThis will train {n_configs} models:")
    for param_name, values in param_grid.items():
        print(f"  - {param_name}: {values}")
    print("\nEstimated time: 2-6 hours (depending on hardware)")
    print("="*80)

## 6. Model Evaluation on Test Set

In [None]:
def evaluate_model(model, test_loader, model_name='Model'):
    """Measure top-1 accuracy of ``model`` on ``test_loader``.

    Expects ``model`` to already live on the module-level ``device``; only the
    batches are moved there.

    Returns:
        (accuracy, predictions, labels): accuracy in percent, plus flat lists
        of predicted and true class indices in loader order.
    """
    model.eval()
    n_correct = n_total = 0
    preds_out, labels_out = [], []

    with torch.no_grad():
        for batch_x, batch_y in tqdm(test_loader, desc=f'Testing {model_name}'):
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            predicted = model(batch_x).argmax(dim=1)

            n_total += batch_y.size(0)
            n_correct += (predicted == batch_y).sum().item()

            preds_out.extend(predicted.cpu().numpy())
            labels_out.extend(batch_y.cpu().numpy())

    accuracy = 100 * n_correct / n_total
    print(f"\n{model_name} Test Accuracy: {accuracy:.2f}% ({n_correct}/{n_total})")
    return accuracy, preds_out, labels_out

In [None]:
# Load and evaluate best models
print("\n" + "="*80)
print("Model Evaluation on Test Set")
print("="*80 + "\n")

models_dir = Path(config['paths']['models_dir'])

# ResNet18: architecture only (pretrained=False) - the checkpoint supplies
# every weight, so downloading ImageNet weights here would be wasted work.
resnet18_eval = create_resnet18(num_classes=config['data']['num_classes'], pretrained=False)
# map_location lets a GPU-trained checkpoint load on a CPU-only machine
checkpoint = torch.load(models_dir / 'resnet18_best.pth', map_location=device)
resnet18_eval.load_state_dict(checkpoint['model_state_dict'])
resnet18_eval.to(device)
resnet_acc, resnet_preds, resnet_labels = evaluate_model(resnet18_eval, test_loader, 'ResNet18')

# EfficientNet-B0 (if trained)
efficient_acc = None
efficientnet_path = models_dir / 'efficientnet_b0_best.pth'
if efficientnet_path.exists():
    efficientnet_eval = create_efficientnet_b0(num_classes=config['data']['num_classes'], pretrained=False)
    if efficientnet_eval is None:
        print("\nEfficientNet-B0 checkpoint found but the architecture is unavailable (skipped)")
    else:
        checkpoint = torch.load(efficientnet_path, map_location=device)
        efficientnet_eval.load_state_dict(checkpoint['model_state_dict'])
        efficientnet_eval.to(device)
        efficient_acc, efficient_preds, efficient_labels = evaluate_model(efficientnet_eval, test_loader, 'EfficientNet-B0')
else:
    print("\nEfficientNet-B0 model not found (skipped during training)")

# Summary
print(f"\n{'='*80}")
print(f"Final Test Set Results")
print(f"{'='*80}")
print(f"ResNet18:        {resnet_acc:.2f}%")
# `is not None` so a (legitimate) 0.00% result is still reported
if efficient_acc is not None:
    print(f"EfficientNet-B0: {efficient_acc:.2f}%")
print(f"{'='*80}\n")

## 7. Summary

This notebook implemented:

### Models:
- ResNet18 with transfer learning (ImageNet pretrained)
- EfficientNet-B0 with transfer learning (optional)

### Training Pipeline:
- DataLoader with data augmentation (rotation, flips, color jitter)
- Cross-entropy loss function
- Adam or SGD (with momentum) optimizer with weight decay, selected by `training.optimizer` in the config
- Learning rate scheduling (StepLR)
- Early stopping with patience
- Model checkpointing (saves best model)

### Hyperparameter Tuning:
- Grid search framework for LR and batch size
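- Search space: learning rates [1e-3, 1e-4, 1e-5] × batch sizes [16, 32, 64] (disabled by default; must be enabled manually because it trains 9 models)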

### Results:
- Training/validation curves saved to `results/`
- Best models saved to `models/`
- Training history saved as JSON

All deliverables complete!