In [1]:
import datetime
import itertools
import math
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, Subset

# Create the "plots" output directory at import time (no error if it already
# exists). Both the figures and the results CSV are written into it later.
os.makedirs("plots", exist_ok=True)

# Progress banner only: the dataset itself is loaded inside run_experiment().
print("Loading data")

# Load the FashionMNIST data
def get_fashionmnist_data():
    """Build the FashionMNIST training dataset used by the experiment.

    Each image is converted to a tensor and then normalized with mean 0.5 and
    standard deviation 0.5. The dataset is rooted at ``./data`` and created
    with ``download=True``.

    Returns:
        The ``torchvision.datasets.FashionMNIST`` training split.
    """
    preprocessing_steps = [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
    return torchvision.datasets.FashionMNIST(
        root='./data',
        train=True,
        download=True,
        transform=transforms.Compose(preprocessing_steps),
    )

# Define the Simple CNN Model
class SimpleCNN(nn.Module):
    """Small convolutional classifier that emits 10 logits per input image.

    The network is two blocks of 3x3 convolution (padding 1) + ReLU + 2x2 max
    pooling, followed by a flatten, a 128-unit hidden layer with ReLU and
    dropout, and a final 10-way linear layer. The hidden layer expects
    ``64 * 7 * 7`` flattened features, which matches single-channel 28x28
    inputs after the two poolings.

    Args:
        dropout: Drop probability of the dropout layer placed before the
            output layer.
    """

    def __init__(self, dropout=0.5):
        super().__init__()
        # Layers are created in forward order and packed into one Sequential,
        # so parameter names stay ``net.<index>.*``.
        conv_stack = [
            nn.Conv2d(1, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        head = [
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 10),
        ]
        self.net = nn.Sequential(*conv_stack, *head)

    def forward(self, x):
        """Return the class logits for the batch ``x``."""
        logits = self.net(x)
        return logits

# Noam Learning Rate Scheduler
def noam_lr_schedule(step, d_model=512, warmup=4000):
    """Return the Noam learning rate for optimizer step ``step``.

    Computes ``d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)``:
    the rate grows linearly in ``step`` while the second term is the smaller
    one, then decays proportionally to ``step ** -0.5``.

    Args:
        step: Global optimizer step. Values below 1 are treated as 1.
        d_model: Scale term of the schedule.
        warmup: Warm-up length in steps; must be positive.

    Returns:
        The learning rate as a float.
    """
    # ``0 ** -0.5`` raises ZeroDivisionError and a negative base yields a
    # complex number, so clamp to the first valid step instead of failing.
    step = max(step, 1)
    return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

# Cosine Learning Rate Scheduler
def cosine_lr_schedule(t, eta_max, T, T0):
    """Return the learning rate at step ``t`` for a warm-up + cosine schedule.

    For ``t <= T0`` the rate rises linearly from 1e-4 (at ``t = 0``) to
    ``eta_max`` (at ``t = T0``). Afterwards it follows a quarter cosine from
    ``eta_max`` down towards 0 at ``t = T``, with a constant 1e-6 added.

    Args:
        t: Current step.
        eta_max: Peak learning rate reached at the end of warm-up.
        T: Total number of steps.
        T0: Number of warm-up steps.

    Returns:
        The learning rate for step ``t``.
    """
    in_warmup = t <= T0
    if not in_warmup:
        decay_progress = (t - T0) / (T - T0)
        angle = (math.pi / 2) * decay_progress
        return eta_max * np.cos(angle) + 1e-6

    warmup_progress = t / T0
    return 1e-4 + (eta_max - 1e-4) * warmup_progress

# Train and Validate the Model
def train_and_validate(train_loader, val_loader, model, optimizer, scheduler_fn, epochs=5):
    """Train ``model`` with a per-step learning-rate schedule and validate each epoch.

    Args:
        train_loader: Iterable of ``(images, labels)`` batches that supports
            ``len()``; iterated once per epoch for training.
        val_loader: Iterable of ``(images, labels)`` batches that supports
            ``len()``; iterated once per epoch for evaluation.
        model: Module mapping an image batch to class logits. It is moved to
            CUDA when available, otherwise to CPU.
        optimizer: Optimizer over ``model``'s parameters. The ``lr`` of every
            parameter group is overwritten on each training step.
        scheduler_fn: Callable ``(step, T, T0) -> lr`` where ``step`` is the
            1-based global step, ``T`` is ``epochs * len(train_loader)`` and
            ``T0`` is ``T // 5``.
        epochs: Number of passes over ``train_loader``.

    Returns:
        ``(train_acc, train_loss, val_acc, val_loss)``: four lists with one
        entry per epoch. Accuracies are percentages; losses are the mean of
        the per-batch cross-entropy losses.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    criterion = nn.CrossEntropyLoss()

    train_acc, val_acc = [], []
    train_loss, val_loss = [], []
    global_step = 0
    T = epochs * len(train_loader)
    T0 = T // 5

    for _ in range(epochs):
        model.train()
        running_loss, correct, total = 0.0, 0, 0

        for images, labels in train_loader:
            global_step += 1
            images, labels = images.to(device), labels.to(device)

            # Set the scheduled rate BEFORE the update so that step t is
            # taken with lr(t). Writing it after optimizer.step() would make
            # the first update use the optimizer's constructor lr and shift
            # the whole schedule by one step. Apply it to every group.
            lr = scheduler_fn(global_step, T, T0)
            for group in optimizer.param_groups:
                group['lr'] = lr

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        train_acc.append(100 * correct / total)
        train_loss.append(running_loss / len(train_loader))

        # Validation: no gradient tracking, eval mode (disables dropout).
        model.eval()
        val_loss_epoch, correct, total = 0.0, 0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                val_loss_epoch += criterion(outputs, labels).item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

        val_acc.append(100 * correct / total)
        val_loss.append(val_loss_epoch / len(val_loader))

    return train_acc, train_loss, val_acc, val_loss

# Cross-validation for hyperparameter tuning
def cross_validate_model(dataset, model_fn, params, k_folds=5, epochs=5, random_state=42):
    """Run k-fold cross-validation for one hyperparameter configuration.

    Args:
        dataset: Indexable dataset supporting ``len()``; folds are built as
            ``Subset`` views over it.
        model_fn: Callable accepting ``dropout=...`` and returning a fresh
            model for each fold.
        params: Dict with keys ``'batch_size'``, ``'dropout'``,
            ``'momentum'``, ``'weight_decay'`` and ``'lr_scheduler'``
            (``'noam'`` selects the Noam schedule, anything else the cosine
            schedule with ``eta_max=0.1``).
        k_folds: Number of folds.
        epochs: Training epochs per fold.
        random_state: Seed for the fold shuffle. A fixed value makes every
            configuration train and validate on the same splits, so their
            scores are comparable. Pass ``None`` for a different shuffle on
            every call.

    Returns:
        Dict mapping fold index to a dict with per-epoch lists under
        ``'train_acc'``, ``'train_loss'``, ``'val_acc'`` and ``'val_loss'``.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    kfold = KFold(n_splits=k_folds, shuffle=True, random_state=random_state)

    # Worker processes and pinned memory are only used when a GPU is present.
    loader_kwargs = {
        'batch_size': params['batch_size'],
        'num_workers': 4 if use_cuda else 0,
        'pin_memory': use_cuda,
    }

    # The schedule depends only on params, so pick it once for all folds.
    if params['lr_scheduler'] == 'noam':
        def scheduler_fn(step, *_):
            return noam_lr_schedule(step)
    else:
        def scheduler_fn(step, T, T0):
            return cosine_lr_schedule(step, eta_max=0.1, T=T, T0=T0)

    results = {}
    # Split over an index array: KFold only needs the number of samples.
    sample_indices = np.arange(len(dataset))
    for fold, (train_idx, val_idx) in enumerate(kfold.split(sample_indices)):
        print(f"Training fold {fold+1}/{k_folds}...")

        train_loader = DataLoader(Subset(dataset, train_idx), shuffle=True, **loader_kwargs)
        val_loader = DataLoader(Subset(dataset, val_idx), shuffle=False, **loader_kwargs)

        model = model_fn(dropout=params['dropout']).to(device)
        # The lr given here is a placeholder; the training loop overwrites it
        # with the scheduled value.
        optimizer = optim.SGD(
            model.parameters(),
            lr=1e-3,
            momentum=params['momentum'],
            weight_decay=params['weight_decay'],
        )

        train_acc, train_loss, val_acc, val_loss = train_and_validate(
            train_loader, val_loader, model, optimizer, scheduler_fn, epochs
        )

        results[fold] = {
            'train_acc': train_acc,
            'train_loss': train_loss,
            'val_acc': val_acc,
            'val_loss': val_loss
        }

    return results

# Grid Search with Batch Size and Other Parameters
def _mean_final_epoch(results, metric):
    """Average the last-epoch value of ``metric`` over all folds in ``results``."""
    return np.mean([fold_result[metric][-1] for fold_result in results.values()])


def grid_search(dataset, model_fn, param_grid, k_folds=5, epochs=5):
    """Cross-validate every combination in ``param_grid`` and pick the best.

    Args:
        dataset: Dataset forwarded to ``cross_validate_model``.
        model_fn: Model factory forwarded to ``cross_validate_model``.
        param_grid: Dict with list values under the keys ``'dropout'``,
            ``'momentum'``, ``'batch_size'``, ``'weight_decay'`` and
            ``'lr_scheduler'``.
        k_folds: Number of cross-validation folds per combination.
        epochs: Training epochs per fold.

    Returns:
        ``(best_params, results_df)``. ``best_params`` is the combination
        with the highest fold-averaged final-epoch validation accuracy
        (``None`` if the grid is empty). ``results_df`` has one row per
        combination with the hyperparameters plus ``val_acc``, ``train_acc``,
        ``train_loss``, ``val_loss`` and ``overfitting``.
    """
    param_names = ('dropout', 'momentum', 'batch_size', 'weight_decay', 'lr_scheduler')

    best_params = None
    # Start below any reachable accuracy so that a configuration scoring
    # exactly 0% can still be selected.
    best_val_acc = float('-inf')
    results_data = []

    # product() iterates with the last key varying fastest, i.e. the same
    # order as nested loops written in param_names order.
    for combination in itertools.product(*(param_grid[name] for name in param_names)):
        params = dict(zip(param_names, combination))
        print(f"Evaluating: {params}")

        results = cross_validate_model(dataset, model_fn, params, k_folds, epochs)

        avg_val_acc = _mean_final_epoch(results, 'val_acc')
        avg_train_acc = _mean_final_epoch(results, 'train_acc')
        avg_train_loss = _mean_final_epoch(results, 'train_loss')
        avg_val_loss = _mean_final_epoch(results, 'val_loss')

        print(f"Avg Validation Accuracy: {avg_val_acc:.2f}%")

        results_data.append({
            **params,
            'val_acc': avg_val_acc,
            'train_acc': avg_train_acc,
            'train_loss': avg_train_loss,
            'val_loss': avg_val_loss,
            # Generalization gap: positive when training accuracy is higher.
            'overfitting': avg_train_acc - avg_val_acc,
        })

        if avg_val_acc > best_val_acc:
            best_val_acc = avg_val_acc
            best_params = params

    print(f"Best Params: {best_params}")

    # One row per configuration, for plotting and CSV export.
    results_df = pd.DataFrame(results_data)

    return best_params, results_df

# Visualization functions
def _save_and_close_figure(stem, timestamp):
    """Lay out the current figure, save it as plots/<stem>_<timestamp>.png and close it."""
    plt.tight_layout()
    plt.savefig(f"plots/{stem}_{timestamp}.png", dpi=300)
    # pyplot keeps every figure alive until it is closed explicitly.
    plt.close()


def _plot_mean_val_acc_bars(results_df, column, xlabel, title):
    """Draw a bar chart of mean validation accuracy grouped by ``column``."""
    plt.figure(figsize=(12, 8))
    grouped = results_df.groupby(column)['val_acc'].mean().reset_index()
    sns.barplot(x=column, y='val_acc', data=grouped)
    plt.xlabel(xlabel)
    plt.ylabel('Average Validation Accuracy')
    plt.title(title)
    plt.grid(True, alpha=0.3)


def create_visualizations(results_df, timestamp):
    """Create various visualizations to understand hyperparameter effects.

    Seven PNG files are written to the ``plots/`` directory, each suffixed
    with ``timestamp``: a dropout/batch-size heatmap, a batch-size scatter, a
    weight-decay bar chart, a scheduler bar chart, a train-vs-validation
    scatter, a parallel-coordinates plot and a top-10 bar chart.

    Args:
        results_df: DataFrame with the columns ``dropout``, ``momentum``,
            ``batch_size``, ``weight_decay``, ``lr_scheduler``, ``val_acc``
            and ``train_acc``.
        timestamp: String appended to every output file name.
    """
    # Do not depend on some other code having created the output directory.
    os.makedirs("plots", exist_ok=True)

    # 1. Heatmap of batch size vs dropout effect on validation accuracy
    plt.figure(figsize=(12, 8))
    pivot_table = results_df.pivot_table(
        values='val_acc',
        index='dropout',
        columns='batch_size',
        aggfunc='mean'
    )
    sns.heatmap(pivot_table, annot=True, cmap='viridis', fmt='.2f')
    plt.title('Effect of Dropout and Batch Size on Validation Accuracy')
    _save_and_close_figure("heatmap_dropout_batchsize", timestamp)

    # 2. Scatter plot of batch size vs validation accuracy colored by dropout
    plt.figure(figsize=(12, 8))
    for dropout in results_df['dropout'].unique():
        subset = results_df[results_df['dropout'] == dropout]
        plt.scatter(subset['batch_size'], subset['val_acc'],
                    label=f'Dropout: {dropout}', alpha=0.7, s=80)
    plt.xlabel('Batch Size')
    plt.ylabel('Validation Accuracy')
    plt.title('Validation Accuracy by Batch Size and Dropout')
    plt.legend()
    plt.grid(True, alpha=0.3)
    _save_and_close_figure("scatter_batchsize_accuracy", timestamp)

    # 3. Bar plot for weight decay effect
    _plot_mean_val_acc_bars(
        results_df, 'weight_decay', 'Weight Decay',
        'Effect of Weight Decay on Validation Accuracy',
    )
    _save_and_close_figure("barplot_weightdecay", timestamp)

    # 4. Compare scheduler effects
    _plot_mean_val_acc_bars(
        results_df, 'lr_scheduler', 'Learning Rate Scheduler',
        'Effect of Learning Rate Scheduler on Validation Accuracy',
    )
    _save_and_close_figure("barplot_scheduler", timestamp)

    # 5. Overfitting analysis
    plt.figure(figsize=(12, 8))
    plt.scatter(results_df['train_acc'], results_df['val_acc'], alpha=0.7, s=80)
    # Reference diagonal: at least 75..100, widened when the data falls
    # outside that range so the line always spans the plotted points.
    accuracies = results_df[['train_acc', 'val_acc']]
    diag_lo = min(75, accuracies.min().min())
    diag_hi = max(100, accuracies.max().max())
    plt.plot([diag_lo, diag_hi], [diag_lo, diag_hi], 'r--', alpha=0.5)
    plt.xlabel('Training Accuracy')
    plt.ylabel('Validation Accuracy')
    plt.title('Training vs Validation Accuracy (Diagonal Line Indicates No Overfitting)')
    plt.grid(True, alpha=0.3)
    _save_and_close_figure("overfitting_analysis", timestamp)

    # 6. Parallel coordinates plot for all parameters (lines colored by batch size)
    plt.figure(figsize=(15, 8))
    pd.plotting.parallel_coordinates(
        results_df, 'batch_size',
        cols=['dropout', 'momentum', 'weight_decay', 'val_acc'],
        colormap='viridis'
    )
    plt.title('Parallel Coordinates Plot of Hyperparameters')
    plt.grid(True, alpha=0.3)
    _save_and_close_figure("parallel_coordinates", timestamp)

    # 7. Top 10 configurations
    top_10 = results_df.sort_values('val_acc', ascending=False).head(10)
    plt.figure(figsize=(14, 8))

    # One two-line tick label per configuration
    x_labels = [
        f"D:{row['dropout']},M:{row['momentum']}\nB:{row['batch_size']},W:{row['weight_decay']},S:{row['lr_scheduler']}"
        for _, row in top_10.iterrows()
    ]

    positions = range(len(top_10))
    plt.bar(positions, top_10['val_acc'], alpha=0.7)
    plt.xticks(positions, x_labels, rotation=45, ha='right')
    plt.ylabel('Validation Accuracy')
    plt.title('Top 10 Hyperparameter Configurations')
    plt.grid(True, alpha=0.3)
    _save_and_close_figure("top10_configurations", timestamp)

    print(f"Visualizations saved in plots/ directory with timestamp {timestamp}")

# Run the Experiment
def run_experiment():
    """Run the full grid search on FashionMNIST and persist its results.

    Loads the dataset, evaluates every combination of the parameter grid with
    5-fold cross-validation for 5 epochs, writes the results table to
    ``plots/hyperparameter_results_<timestamp>.csv`` and then generates the
    plots. Progress and timing are printed to stdout.
    """
    dataset = get_fashionmnist_data()

    # Use the original parameter grid
    param_grid = {
        'dropout': [0.3, 0.5, 0.7],
        'momentum': [0.8, 0.9, 0.95],
        'batch_size': [32, 64, 128],
        'weight_decay': [0.0, 0.0005, 0.001],
        'lr_scheduler': ['noam', 'cosine']
    }

    # Generate timestamp for file naming
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

    print("Starting grid search")
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments during a multi-hour run.
    start_time = time.perf_counter()

    # Run grid search
    best_params, results_df = grid_search(dataset, SimpleCNN, param_grid, k_folds=5, epochs=5)

    execution_time = (time.perf_counter() - start_time) / 3600  # hours

    print(f"Grid search completed in {execution_time:.2f} hours")
    print(f"Best Hyperparameters: {best_params}")

    # Save the full results first: they are expensive to recompute and must
    # not be lost if plotting fails.
    results_df.to_csv(f"plots/hyperparameter_results_{timestamp}.csv", index=False)

    # Create visualizations
    create_visualizations(results_df, timestamp)

    print(f"All results saved. Total execution time: {execution_time:.2f} hours")

# Script entry point: run the experiment only when this file is executed
# directly. `time` is a module-level import, so run_experiment() also works
# when this module is imported rather than run as a script.
if __name__ == '__main__':
    run_experiment()

Loading data
Starting grid search
Evaluating: {'dropout': 0.3, 'momentum': 0.8, 'batch_size': 32, 'weight_decay': 0.0, 'lr_scheduler': 'noam'}
Training fold 1/5...
Training fold 2/5...
Training fold 3/5...
Training fold 4/5...
Training fold 5/5...
Avg Validation Accuracy: 82.59%
Evaluating: {'dropout': 0.3, 'momentum': 0.8, 'batch_size': 32, 'weight_decay': 0.0, 'lr_scheduler': 'cosine'}
Training fold 1/5...
Training fold 2/5...
Training fold 3/5...


KeyboardInterrupt: 