# CubeMaster: Color Classification Model Training

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mvipin/cubemaster/blob/main/notebooks/colab_training.ipynb)

This notebook enables training of **MLP** and **Shallow CNN** models for Rubik's Cube color classification using Google Colab's GPU resources.

## Contents
1. [Environment Setup](#1-environment-setup)
2. [Data Preparation](#2-data-preparation)
3. [Model Training](#3-model-training)
4. [Weights & Biases Integration](#4-weights--biases-integration)
   - 4.1 Authentication
   - 4.2 Training with Logging
   - 4.3 **Hyperparameter Sweeps** ⚡
   - 4.4 Resume/Join Sweeps
   - 4.5 Sweep Results Analysis
   - 4.6 Colab-Specific Tips
5. [Results Visualization](#5-results-visualization)
6. [Model Export](#6-model-export)

---
## 1. Environment Setup

### 1.1 Check GPU Availability

In [None]:
# Check GPU availability
import torch

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"✅ GPU Available: {gpu_name}")
    print(f"   Memory: {gpu_memory:.1f} GB")
else:
    print("⚠️ No GPU detected. Training will be slower.")
    print("   Go to Runtime > Change runtime type > Hardware accelerator > GPU")

### 1.2 Install Dependencies

In [None]:
# Install required packages
!pip install -q torch torchvision torchmetrics
!pip install -q albumentations opencv-python-headless Pillow
!pip install -q numpy pyyaml tqdm
!pip install -q matplotlib seaborn scikit-learn
!pip install -q onnx onnxruntime
!pip install -q wandb

print("✅ All dependencies installed!")

### 1.3 Clone Repository

In [None]:
import os
from pathlib import Path

# Install Git LFS (required for dataset files)
!apt-get install -qq git-lfs
!git lfs install

# Clone the repository
REPO_URL = "https://github.com/mvipin/cubemaster.git"
REPO_DIR = Path("/content/cubemaster")

if not REPO_DIR.exists():
    !git clone {REPO_URL} {REPO_DIR}
    print(f"✅ Repository cloned to {REPO_DIR}")
else:
    print(f"✅ Repository already exists at {REPO_DIR}")
    # Pull latest changes
    !cd {REPO_DIR} && git pull

# Pull LFS files (dataset images)
print("\n📥 Pulling Git LFS files (dataset images)...")
!cd {REPO_DIR} && git lfs pull
print("✅ LFS files downloaded")

# Change to project directory
os.chdir(REPO_DIR)
print(f"📁 Working directory: {os.getcwd()}")
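If `git lfs pull` fails (the cell above prints its success message regardless), the image files remain small text pointer files and training later fails with confusing image-decoding errors. A minimal sketch for spotting such files, assuming the images live under `data/processed/` as in Section 2.1; `find_lfs_pointers` is a hypothetical helper, not part of the repository. It relies on the fact that every Git LFS pointer file starts with the line `version https://git-lfs.github.com/spec/v1`.

```python
from pathlib import Path

# Every Git LFS pointer file starts with this header line
LFS_POINTER_HEADER = b"version https://git-lfs.github.com/spec/v1"

def find_lfs_pointers(root, patterns=("*.png", "*.jpg")):
    """Return image paths under `root` that are still Git LFS pointer files.

    Hypothetical helper: reads the first bytes of each matching file and
    compares them against the LFS pointer header.
    """
    root = Path(root)
    pointers = []
    for pattern in patterns:
        for path in root.rglob(pattern):
            with open(path, "rb") as f:
                head = f.read(len(LFS_POINTER_HEADER))
            if head == LFS_POINTER_HEADER:
                pointers.append(path)
    return sorted(pointers)

# Usage (in the notebook), after `git lfs pull`:
# leftover = find_lfs_pointers("/content/cubemaster/data/processed")
# print(f"{len(leftover)} files are still LFS pointers")
```

An empty result means the real image bytes were downloaded; a non-empty one usually means rerunning `git lfs pull` (or checking the LFS quota) before training.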

### 1.4 Setup Python Path

In [None]:
import sys
from pathlib import Path

# Add src to Python path
SRC_PATH = Path("/content/cubemaster/src")
if str(SRC_PATH) not in sys.path:
    sys.path.insert(0, str(SRC_PATH))

# Verify imports work
try:
    from cubemaster.models import MLPClassifier, ShallowCNNClassifier, MODEL_REGISTRY
    from cubemaster.utils.config import load_config, get_device, set_seed
    from cubemaster.training.dataset import CubeColorDataset
    from cubemaster.training.augmentations import get_train_transforms, get_val_transforms
    from cubemaster.training.trainer import Trainer, EarlyStopping
    from cubemaster import COLOR_CLASSES
    print("✅ All imports successful!")
    print(f"   Available models: {list(MODEL_REGISTRY.keys())}")
    print(f"   Color classes: {COLOR_CLASSES}")
except ImportError as e:
    print(f"❌ Import error: {e}")
    print("   Make sure the repository is cloned correctly.")

---
## 2. Data Preparation

### 2.1 Dataset Selection

The CubeMaster repository includes a pre-processed dataset with 4,410 images ready for training.

**Dataset options:**
1. **Repository Dataset** (default) - Use the included `data/processed/` dataset
2. **Google Drive** - Mount Google Drive and use your own dataset stored there
3. **Direct Upload** - Upload a zip file directly
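For the direct-upload option, the zip should have the split folders at its top level (`train/`, `val/`, `test/`, each with one subfolder per color class). A minimal sketch for building such an archive on your local machine with the standard library, assuming your images sit in a local folder named `my_dataset/` (a placeholder name); `make_dataset_zip` is a hypothetical helper:

```python
import shutil
from pathlib import Path

def make_dataset_zip(dataset_dir, zip_stem):
    """Zip the *contents* of dataset_dir so train/, val/, test/ sit at the archive root.

    `zip_stem` is the output path without the .zip extension
    (shutil.make_archive appends it). Returns the created archive's path.
    """
    dataset_dir = Path(dataset_dir)
    # root_dir makes member paths relative to dataset_dir,
    # so no extra top-level folder ends up inside the zip
    return shutil.make_archive(str(zip_stem), "zip", root_dir=dataset_dir)

# Example: produces cubemaster_dataset.zip containing train/B/..., train/G/..., etc.
# make_dataset_zip("my_dataset", "cubemaster_dataset")
```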

In [None]:
from pathlib import Path

# Dataset source selection
# Options: 'repository', 'google_drive', 'upload'
DATASET_SOURCE = 'repository'  # Default: use included dataset

if DATASET_SOURCE == 'repository':
    # Use the dataset included in the cloned repository
    DATASET_PATH = Path("/content/cubemaster/data/processed")
    
    if DATASET_PATH.exists():
        print(f"✅ Using repository dataset: {DATASET_PATH}")
        print("   This dataset contains 4,410 images (train: 4180, val: 101, test: 129)")
    else:
        print(f"❌ Repository dataset not found at {DATASET_PATH}")
        print("   Make sure the repository was cloned correctly.")

elif DATASET_SOURCE == 'google_drive':
    # Mount Google Drive and use custom dataset
    from google.colab import drive
    drive.mount('/content/drive')
    
    # Set path to your dataset in Google Drive
    # Expected structure: dataset/train/{B,G,O,R,W,Y}/*.jpg
    DATASET_PATH = Path("/content/drive/MyDrive/CubeMaster/dataset")
    print(f"📁 Using Google Drive dataset: {DATASET_PATH}")

elif DATASET_SOURCE == 'upload':
    # Direct upload via Colab file picker
    from google.colab import files
    import zipfile
    
    print("Upload your dataset as a zip file:")
    print("Expected structure: train/{B,G,O,R,W,Y}/*.jpg")
    uploaded = files.upload()
    
    # Extract uploaded zip
    for filename in uploaded.keys():
        with zipfile.ZipFile(filename, 'r') as zip_ref:
            zip_ref.extractall('/content/dataset')
    DATASET_PATH = Path("/content/dataset")
    print(f"✅ Dataset extracted to {DATASET_PATH}")

else:
    raise ValueError(f"Unknown DATASET_SOURCE: {DATASET_SOURCE}")
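With the upload option, a zip that wraps everything in an extra top-level folder (e.g. `dataset.zip` containing `dataset/train/...`) leaves `DATASET_PATH` one level too high. A small sketch that locates the directory that actually contains `train/`; `find_dataset_root` is a hypothetical helper, not part of the CubeMaster package:

```python
from pathlib import Path

def find_dataset_root(base, max_depth=3):
    """Return the first directory at or below `base` that contains a train/ folder.

    Searches breadth-first, up to `max_depth` levels below `base`;
    returns None if no such directory exists.
    """
    frontier = [Path(base)]
    for _ in range(max_depth + 1):
        next_frontier = []
        for directory in frontier:
            if (directory / "train").is_dir():
                return directory
            # Queue subdirectories for the next level (sorted for determinism)
            next_frontier.extend(sorted(p for p in directory.iterdir() if p.is_dir()))
        frontier = next_frontier
    return None

# Usage after extraction in the 'upload' branch:
# DATASET_PATH = find_dataset_root(Path("/content/dataset")) or Path("/content/dataset")
```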

### 2.2 Verify Dataset Structure

In [None]:
import os
from pathlib import Path
from collections import defaultdict

def verify_dataset(dataset_path: Path):
    """Verify dataset structure and count samples."""
    dataset_path = Path(dataset_path)
    
    if not dataset_path.exists():
        print(f"❌ Dataset path does not exist: {dataset_path}")
        return None
    
    splits = ['train', 'val', 'test']
    expected_classes = ['B', 'G', 'O', 'R', 'W', 'Y']
    stats = defaultdict(dict)
    
    print(f"📊 Dataset Statistics for: {dataset_path}\n")
    print(f"{'Split':<10} {'B':>6} {'G':>6} {'O':>6} {'R':>6} {'W':>6} {'Y':>6} {'Total':>8}")
    print("-" * 60)
    
    for split in splits:
        split_path = dataset_path / split
        if not split_path.exists():
            print(f"⚠️  {split:<10} - Directory not found")
            continue
        
        total = 0
        row = f"{split:<10}"
        for cls in expected_classes:
            cls_path = split_path / cls
            if cls_path.exists():
                count = len(list(cls_path.glob('*.png'))) + len(list(cls_path.glob('*.jpg')))
                stats[split][cls] = count
                total += count
                row += f" {count:>6}"
            else:
                row += f" {'N/A':>6}"
        stats[split]['total'] = total
        row += f" {total:>8}"
        print(row)
    
    print("-" * 60)
    return stats

# Verify dataset
dataset_stats = verify_dataset(DATASET_PATH)
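`verify_dataset` returns a nested dict of the form `stats[split][class] = count`, plus a `'total'` entry per split. A small sketch that turns it into a list of warnings (missing splits, missing or empty class folders, strong class imbalance); the helper name `dataset_warnings` and the 2× imbalance threshold are illustrative assumptions, not project conventions:

```python
# Class folder names used by verify_dataset()
EXPECTED_CLASSES = ['B', 'G', 'O', 'R', 'W', 'Y']

def dataset_warnings(stats, max_ratio=2.0):
    """Return human-readable warnings for the stats dict built by verify_dataset()."""
    if not stats:
        return ["No statistics available (dataset path missing?)"]
    warnings = []
    for split in ['train', 'val', 'test']:
        if split not in stats:
            warnings.append(f"{split}: split directory not found")
            continue
        counts = {cls: stats[split][cls] for cls in EXPECTED_CLASSES if cls in stats[split]}
        missing = [cls for cls in EXPECTED_CLASSES if cls not in counts]
        empty = [cls for cls, n in counts.items() if n == 0]
        if missing:
            warnings.append(f"{split}: missing class folders {missing}")
        if empty:
            warnings.append(f"{split}: class folders with no images {empty}")
        nonzero = [n for n in counts.values() if n > 0]
        # Largest-to-smallest ratio among non-empty classes (illustrative threshold)
        if nonzero and max(nonzero) / min(nonzero) > max_ratio:
            warnings.append(f"{split}: class imbalance, largest/smallest = {max(nonzero) / min(nonzero):.1f}")
    return warnings

# Usage:
# for w in dataset_warnings(dataset_stats):
#     print("⚠️", w)
```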

### 2.3 Setup Data Loaders

In [None]:
from torch.utils.data import DataLoader
from cubemaster.training.dataset import CubeColorDataset
from cubemaster.training.augmentations import get_train_transforms, get_val_transforms

def create_data_loaders(dataset_path, batch_size=32, num_workers=2, image_size=(50, 50)):
    """Create train, validation, and test data loaders."""
    dataset_path = Path(dataset_path)
    
    # Get transforms
    train_transform = get_train_transforms(image_size)
    val_transform = get_val_transforms(image_size)
    
    # Create datasets
    train_dataset = CubeColorDataset(
        root_dir=dataset_path / 'train',
        transform=train_transform
    )
    val_dataset = CubeColorDataset(
        root_dir=dataset_path / 'val',
        transform=val_transform
    )
    test_dataset = CubeColorDataset(
        root_dir=dataset_path / 'test',
        transform=val_transform
    )
    
    # Create data loaders
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True
    )
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True
    )
    
    print(f"✅ Data loaders created:")
    print(f"   Train: {len(train_dataset)} samples, {len(train_loader)} batches")
    print(f"   Val:   {len(val_dataset)} samples, {len(val_loader)} batches")
    print(f"   Test:  {len(test_dataset)} samples, {len(test_loader)} batches")
    
    return train_loader, val_loader, test_loader

# Create data loaders (will be called after config is set)
print("Data loader function defined. Will be created during training.")
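The batch counts printed by `create_data_loaders` follow from `ceil(samples / batch_size)`, because `DataLoader` keeps the final partial batch unless `drop_last=True`. A quick check using the split sizes quoted in Section 2.1 (4180 / 101 / 129) and the default batch size of 32; `num_batches` is a helper defined just for this illustration:

```python
import math

# Split sizes of the repository dataset, as stated in Section 2.1
SPLIT_SIZES = {'train': 4180, 'val': 101, 'test': 129}

def num_batches(num_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields over a map-style dataset."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)

for split, n in SPLIT_SIZES.items():
    print(f"{split}: {num_batches(n, 32)} batches")
# Prints:
# train: 131 batches
# val: 4 batches
# test: 5 batches
```

Note that the last validation batch then holds only 5 images (101 − 3 × 32).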

---
## 3. Model Training

### 3.1 Training Configuration

In [None]:
import yaml
from cubemaster.utils.config import load_config, set_seed

# Training configuration
TRAINING_CONFIG = {
    # Model selection: 'mlp' or 'shallow_cnn'
    'model_type': 'mlp',  # Change to 'shallow_cnn' for CNN training
    
    # Training hyperparameters
    'batch_size': 32,
    'epochs': 50,
    'learning_rate': 0.001,
    'weight_decay': 1e-4,
    
    # MLP-specific
    'mlp_hidden_dims': [256, 128],
    'mlp_dropout': 0.3,
    
    # Shallow CNN-specific
    'cnn_dropout': 0.5,
    
    # Early stopping
    'patience': 10,
    'min_delta': 0.001,
    
    # Reproducibility
    'seed': 42,
    
    # Wandb (optional)
    'use_wandb': False,  # Set to True to enable wandb logging
    'wandb_project': 'cubemaster',
}

# Set seed for reproducibility
set_seed(TRAINING_CONFIG['seed'])
print(f"✅ Configuration set for {TRAINING_CONFIG['model_type'].upper()} training")
print(f"   Epochs: {TRAINING_CONFIG['epochs']}, Batch size: {TRAINING_CONFIG['batch_size']}")
print(f"   Learning rate: {TRAINING_CONFIG['learning_rate']}")

### 3.2 Train MLP Model

In [None]:
import torch
import torch.nn as nn
from cubemaster.models import MLPClassifier
from cubemaster.training.trainer import Trainer, EarlyStopping
from cubemaster.utils.config import get_device

def train_mlp(config, train_loader, val_loader):
    """Train MLP model with given configuration."""
    device = get_device()
    print(f"\n🚀 Training MLP on {device}")
    print("=" * 60)
    
    # Create model
    model = MLPClassifier(
        num_classes=6,
        input_size=(50, 50),
        hidden_dims=config['mlp_hidden_dims'],
        dropout_rate=config['mlp_dropout']
    ).to(device)
    
    params = model.count_parameters()
    print(f"Model parameters: {params['trainable']:,} trainable")
    
    # Setup training components
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay']
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=5
    )
    early_stopping = EarlyStopping(
        patience=config['patience'],
        min_delta=config['min_delta']
    )
    
    # Create trainer
    trainer = Trainer(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        early_stopping=early_stopping,
        checkpoint_dir=Path('/content/cubemaster/models/mlp'),
        use_wandb=config['use_wandb']
    )
    
    # Train
    history = trainer.fit(train_loader, val_loader, epochs=config['epochs'])
    
    print("=" * 60)
    print(f"✅ Training complete! Best val accuracy: {trainer.best_val_acc:.2f}%")
    
    return model, history, trainer

# Train MLP (set model_type='mlp' in config above)
if TRAINING_CONFIG['model_type'] == 'mlp':
    train_loader, val_loader, test_loader = create_data_loaders(
        DATASET_PATH, 
        batch_size=TRAINING_CONFIG['batch_size']
    )
    mlp_model, mlp_history, mlp_trainer = train_mlp(TRAINING_CONFIG, train_loader, val_loader)
else:
    print("ℹ️ Skipping MLP training. Set model_type='mlp' to train.")
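For a sense of scale, a plain MLP's size is dominated by its first layer. A back-of-the-envelope sketch, assuming the model flattens a 3-channel 50×50 input and uses only fully connected layers (hidden dims `[256, 128]`, 6 outputs) with no batch-norm; the real `MLPClassifier` may differ, so treat `count_parameters()` as authoritative:

```python
def mlp_param_count(input_dim, hidden_dims, num_classes):
    """Weights + biases of a plain stack of fully connected layers."""
    dims = [input_dim, *hidden_dims, num_classes]
    return sum(d_in * d_out + d_out for d_in, d_out in zip(dims[:-1], dims[1:]))

input_dim = 3 * 50 * 50  # assumption: RGB 50x50 image flattened to 7,500 values
print(f"{mlp_param_count(input_dim, [256, 128], 6):,}")  # 1,953,926
```

About 98% of that total (1,920,256 parameters) sits in the 7,500 → 256 layer, which is why a convolutional front end is usually far more parameter-efficient on images.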

### 3.3 Train Shallow CNN Model

In [None]:
from cubemaster.models import ShallowCNNClassifier

def train_shallow_cnn(config, train_loader, val_loader):
    """Train Shallow CNN model with given configuration."""
    device = get_device()
    print(f"\n🚀 Training Shallow CNN on {device}")
    print("=" * 60)
    
    # Create model
    model = ShallowCNNClassifier(
        num_classes=6,
        input_size=(50, 50),
        dropout_rate=config['cnn_dropout']
    ).to(device)
    
    params = model.count_parameters()
    print(f"Model parameters: {params['trainable']:,} trainable")
    
    # Setup training components
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay']
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=5
    )
    early_stopping = EarlyStopping(
        patience=config['patience'],
        min_delta=config['min_delta']
    )
    
    # Create trainer
    trainer = Trainer(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        early_stopping=early_stopping,
        checkpoint_dir=Path('/content/cubemaster/models/shallow_cnn'),
        use_wandb=config['use_wandb']
    )
    
    # Train
    history = trainer.fit(train_loader, val_loader, epochs=config['epochs'])
    
    print("=" * 60)
    print(f"✅ Training complete! Best val accuracy: {trainer.best_val_acc:.2f}%")
    
    return model, history, trainer

# Train Shallow CNN (set model_type='shallow_cnn' in config above)
if TRAINING_CONFIG['model_type'] == 'shallow_cnn':
    train_loader, val_loader, test_loader = create_data_loaders(
        DATASET_PATH, 
        batch_size=TRAINING_CONFIG['batch_size']
    )
    cnn_model, cnn_history, cnn_trainer = train_shallow_cnn(TRAINING_CONFIG, train_loader, val_loader)
else:
    print("ℹ️ Skipping Shallow CNN training. Set model_type='shallow_cnn' to train.")

---
## 4. Weights & Biases Integration

### 4.1 Wandb Authentication

In [None]:
# Wandb setup for Colab
import wandb

# Login to wandb (will prompt for API key)
# Get your API key from: https://wandb.ai/authorize
wandb.login()

print("✅ Logged into Weights & Biases!")
print("   Dashboard: https://wandb.ai/home")
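`wandb.login()` prompts for the key interactively. To avoid pasting it in every session, one option is to store it as a Colab secret and read it with `google.colab.userdata.get`. A minimal sketch, assuming a secret named `WANDB_API_KEY` (the name is your choice) and falling back to the environment variable of the same name outside Colab; `get_wandb_key` is a hypothetical helper:

```python
import os

def get_wandb_key(secret_name="WANDB_API_KEY"):
    """Return a wandb API key from Colab secrets, else from the environment, else None."""
    try:
        from google.colab import userdata  # only importable inside Colab
        key = userdata.get(secret_name)
        if key:
            return key
    except Exception:
        # Not running in Colab, or the secret is missing / not shared with this notebook
        pass
    return os.environ.get(secret_name)

# Usage:
# key = get_wandb_key()
# wandb.login(key=key) if key else wandb.login()
```

wandb also reads the `WANDB_API_KEY` environment variable on its own, so outside Colab exporting it is usually enough.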

### 4.2 Train with Wandb Logging

In [None]:
# Enable wandb and retrain
TRAINING_CONFIG['use_wandb'] = True

# Initialize wandb run
run = wandb.init(
    project=TRAINING_CONFIG['wandb_project'],
    name=f"{TRAINING_CONFIG['model_type']}_colab_{TRAINING_CONFIG['seed']}",
    config=TRAINING_CONFIG,
    tags=['colab', TRAINING_CONFIG['model_type']]
)

print(f"📊 Wandb run initialized: {run.url}")

# Create data loaders
train_loader, val_loader, test_loader = create_data_loaders(
    DATASET_PATH, 
    batch_size=TRAINING_CONFIG['batch_size']
)

# Train selected model with wandb
if TRAINING_CONFIG['model_type'] == 'mlp':
    model, history, trainer = train_mlp(TRAINING_CONFIG, train_loader, val_loader)
else:
    model, history, trainer = train_shallow_cnn(TRAINING_CONFIG, train_loader, val_loader)

# Finish wandb run (capture the URL first, since the run is closed afterwards)
run_url = run.url
wandb.finish()
print(f"\n✅ Training logged to wandb: {run_url}")

### 4.3 Hyperparameter Sweeps

Run automated hyperparameter optimization using wandb sweeps. This section allows you to:
- Create new sweeps with predefined configurations
- Run sweep agents to explore hyperparameter space
- Monitor sweep progress in real-time
- Resume or join existing sweeps

In [None]:
# ============================================================
# SWEEP CONFIGURATION
# ============================================================

# Choose which model to sweep
SWEEP_MODEL = 'shallow_cnn'  # Options: 'mlp', 'shallow_cnn'

# Sweep settings
SWEEP_CONFIG = {
    'wandb_project': 'cubemaster',
    'wandb_entity': None,  # Set to your wandb username/team if needed
    'max_runs': 20,  # Maximum runs per sweep (Colab-friendly limit)
    'epochs_per_run': 15,  # Epochs per training run
}

print(f"🔧 Sweep Configuration:")
print(f"   Model: {SWEEP_MODEL}")
print(f"   Max runs: {SWEEP_CONFIG['max_runs']}")
print(f"   Epochs per run: {SWEEP_CONFIG['epochs_per_run']}")

In [None]:
# ============================================================
# SWEEP DEFINITIONS
# ============================================================

# MLP Sweep Configuration
MLP_SWEEP_CONFIG = {
    'method': 'bayes',  # bayes, random, or grid
    'metric': {
        'name': 'val_acc',
        'goal': 'maximize'
    },
    'parameters': {
        'lr': {
            'distribution': 'log_uniform_values',
            'min': 0.00001,
            'max': 0.01
        },
        'batch_size': {
            'values': [16, 32, 64, 128]
        },
        'dropout_rate': {
            'distribution': 'uniform',
            'min': 0.1,
            'max': 0.5
        },
        'hidden_dims': {
            'values': [
                [128, 64],
                [256, 128],
                [512, 256],
                [256, 128, 64],
                [512, 256, 128]
            ]
        },
        'weight_decay': {
            'distribution': 'log_uniform_values',
            'min': 0.00001,
            'max': 0.01
        },
        'optimizer': {
            'values': ['adam', 'adamw']
        },
        'label_smoothing': {
            'distribution': 'uniform',
            'min': 0.0,
            'max': 0.2
        }
    },
    'early_terminate': {
        'type': 'hyperband',
        'min_iter': 5,
        'eta': 3,
        's': 2
    }
}

# Shallow CNN Sweep Configuration
SHALLOW_CNN_SWEEP_CONFIG = {
    'method': 'bayes',
    'metric': {
        'name': 'val_acc',
        'goal': 'maximize'
    },
    'parameters': {
        'lr': {
            'distribution': 'log_uniform_values',
            'min': 0.0001,
            'max': 0.01
        },
        'batch_size': {
            'values': [16, 32, 64, 128]
        },
        'dropout_rate': {
            'distribution': 'uniform',
            'min': 0.2,
            'max': 0.6
        },
        'weight_decay': {
            'distribution': 'log_uniform_values',
            'min': 0.00001,
            'max': 0.01
        },
        'optimizer': {
            'values': ['adam', 'adamw', 'sgd']
        },
        'label_smoothing': {
            'distribution': 'uniform',
            'min': 0.0,
            'max': 0.2
        },
        'rotation_limit': {
            'values': [10, 15, 20, 30]
        },
        'brightness_limit': {
            'distribution': 'uniform',
            'min': 0.1,
            'max': 0.3
        }
    },
    'early_terminate': {
        'type': 'hyperband',
        'min_iter': 5,
        'eta': 3,
        's': 2
    }
}

# Select sweep config based on model choice
ACTIVE_SWEEP_CONFIG = MLP_SWEEP_CONFIG if SWEEP_MODEL == 'mlp' else SHALLOW_CNN_SWEEP_CONFIG
print(f"✅ Loaded {SWEEP_MODEL.upper()} sweep configuration")
print(f"   Method: {ACTIVE_SWEEP_CONFIG['method']}")
print(f"   Parameters: {list(ACTIVE_SWEEP_CONFIG['parameters'].keys())}")
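Both sweep definitions sample `lr` and `weight_decay` with `log_uniform_values`: the value is drawn uniformly in log space between `min` and `max`, so every order of magnitude in the range is equally likely. A stdlib sketch of the idea (an illustration of the distribution, not wandb's own sampler):

```python
import math
import random

def sample_log_uniform(low, high, rng=random):
    """Draw x so that log(x) is uniform on [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))

rng = random.Random(0)
samples = [sample_log_uniform(1e-5, 1e-2, rng) for _ in range(10_000)]

# The MLP lr range 1e-5..1e-2 spans three decades, so about 2/3 of draws fall
# below 1e-3; a plain uniform draw on [1e-5, 1e-2] would put only ~10% there.
frac_below = sum(s < 1e-3 for s in samples) / len(samples)
print(f"fraction below 1e-3: {frac_below:.2f}")
```

This is why log-scale sampling suits learning rates and weight decay: small values get explored as thoroughly as large ones.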

In [None]:
# ============================================================
# SWEEP TRAINING FUNCTION
# ============================================================

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from pathlib import Path
import gc

from cubemaster.models import MLPClassifier, ShallowCNNClassifier
from cubemaster.training.trainer import Trainer, EarlyStopping
from cubemaster.training.dataset import CubeColorDataset
from cubemaster.training.augmentations import get_train_transforms, get_val_transforms
from cubemaster.utils.config import get_device

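def sweep_train():
    """Training function called by the wandb sweep agent.

    train_best_from_sweep() also calls it after starting its own run; in that
    case the active run (and its config) is reused instead of starting a new one.
    """
    # Reuse an already-active run if there is one; otherwise start a new run
    # (under wandb.agent, wandb.init() picks up the sweep's sampled config)
    run = wandb.run if wandb.run is not None else wandb.init()
    config = wandb.config
    
    device = get_device()
    print(f"\n{'='*60}")
    print(f"🚀 Sweep Run: {run.name}")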
    print(f"   Device: {device}")
    print(f"   Config: {dict(config)}")
    print(f"{'='*60}\n")
    
    try:
        # Get hyperparameters from sweep config
        lr = config.get('lr', 0.001)
        batch_size = config.get('batch_size', 32)
        dropout_rate = config.get('dropout_rate', 0.3)
        weight_decay = config.get('weight_decay', 0.0001)
        optimizer_name = config.get('optimizer', 'adam')
        label_smoothing = config.get('label_smoothing', 0.1)
        
        # Augmentation params
        rotation_limit = config.get('rotation_limit', 15)
        brightness_limit = config.get('brightness_limit', 0.2)
        
        # Image size
        image_size = (50, 50)
        
        # Create transforms with sweep augmentation params
        train_aug_config = {
            'rotation_limit': rotation_limit,
            'brightness_limit': brightness_limit,
            'horizontal_flip': True,
            'normalize': True
        }
        train_transform = get_train_transforms(image_size, train_aug_config)
        val_transform = get_val_transforms(image_size)
        
        # Create datasets
        train_dataset = CubeColorDataset(
            root_dir=DATASET_PATH / 'train',
            transform=train_transform
        )
        val_dataset = CubeColorDataset(
            root_dir=DATASET_PATH / 'val',
            transform=val_transform
        )
        
        # Create data loaders
        train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True,
            num_workers=2, pin_memory=True
        )
        val_loader = DataLoader(
            val_dataset, batch_size=batch_size, shuffle=False,
            num_workers=2, pin_memory=True
        )
        
        # Create model based on sweep type
        if SWEEP_MODEL == 'mlp':
            hidden_dims = config.get('hidden_dims', [256, 128])
            model = MLPClassifier(
                num_classes=6,
                input_size=image_size,
                hidden_dims=hidden_dims,
                dropout_rate=dropout_rate
            ).to(device)
        else:
            model = ShallowCNNClassifier(
                num_classes=6,
                input_size=image_size,
                dropout_rate=dropout_rate
            ).to(device)
        
        params = model.count_parameters()
        print(f"Model: {SWEEP_MODEL.upper()}, Parameters: {params['trainable']:,}")
        
        # Create optimizer
        if optimizer_name == 'adam':
            optimizer = torch.optim.Adam(
                model.parameters(), lr=lr, weight_decay=weight_decay
            )
        elif optimizer_name == 'adamw':
            optimizer = torch.optim.AdamW(
                model.parameters(), lr=lr, weight_decay=weight_decay
            )
        else:  # sgd
            optimizer = torch.optim.SGD(
                model.parameters(), lr=lr, weight_decay=weight_decay, momentum=0.9
            )
        
        # Scheduler
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=SWEEP_CONFIG['epochs_per_run'], eta_min=1e-6
        )
        
        # Loss with label smoothing
        criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
        
        # Early stopping
        early_stopping = EarlyStopping(patience=10, min_delta=0.001)
        
        # Trainer (wandb logging handled by run context)
        trainer = Trainer(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            early_stopping=early_stopping,
            checkpoint_dir=Path(f'/content/sweep_checkpoints/{run.name}'),
            use_wandb=True
        )
        
        # Train
        history = trainer.fit(
            train_loader, val_loader, 
            epochs=SWEEP_CONFIG['epochs_per_run']
        )
        
        print(f"\n✅ Run complete! Best val accuracy: {trainer.best_val_acc:.2f}%")
        
    except Exception as e:
        print(f"❌ Run failed: {e}")
        wandb.log({'error': str(e)})
        raise
    
    finally:
        # Clean up GPU memory
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

print("✅ Sweep training function defined")

In [None]:
# ============================================================
# CREATE AND RUN SWEEP
# ============================================================

# Create a new sweep
sweep_id = wandb.sweep(
    sweep=ACTIVE_SWEEP_CONFIG,
    project=SWEEP_CONFIG['wandb_project'],
    entity=SWEEP_CONFIG['wandb_entity']
)

print(f"\n🎯 Sweep created!")
print(f"   Sweep ID: {sweep_id}")
print(f"   Dashboard: https://wandb.ai/{SWEEP_CONFIG['wandb_entity'] or 'your-entity'}/{SWEEP_CONFIG['wandb_project']}/sweeps/{sweep_id}")
print(f"\n   Save this ID to add more runs later with wandb.agent() (Section 4.4).")
print(f"   Note: this sweep config has no 'program' key, so the `wandb agent` CLI cannot run it as-is.")

In [None]:
# ============================================================
# RUN SWEEP AGENT
# ============================================================

# Run the sweep agent
# This will train models with different hyperparameters
print(f"🏃 Starting sweep agent...")
print(f"   Max runs: {SWEEP_CONFIG['max_runs']}")
print(f"   Model: {SWEEP_MODEL}")
print(f"\n   Interrupt the cell (stop button or Runtime > Interrupt execution) to end the sweep early.")
print(f"   Finished runs stay on the wandb server, so you can resume the sweep later (Section 4.4).\n")

wandb.agent(
    sweep_id=sweep_id,
    function=sweep_train,
    count=SWEEP_CONFIG['max_runs'],
    project=SWEEP_CONFIG['wandb_project'],
    entity=SWEEP_CONFIG['wandb_entity']
)

### 4.4 Resume or Join Existing Sweep

Use this section to continue a previous sweep or join an existing one from another machine.

In [None]:
# ============================================================
# RESUME/JOIN EXISTING SWEEP
# ============================================================

# Set this to your existing sweep ID to resume
EXISTING_SWEEP_ID = None  # e.g., 'abc123xy' or 'entity/project/abc123xy'

if EXISTING_SWEEP_ID:
    print(f"🔄 Resuming sweep: {EXISTING_SWEEP_ID}")
    print(f"   Running {SWEEP_CONFIG['max_runs']} more runs...\n")
    
    wandb.agent(
        sweep_id=EXISTING_SWEEP_ID,
        function=sweep_train,
        count=SWEEP_CONFIG['max_runs'],
        project=SWEEP_CONFIG['wandb_project'],
        entity=SWEEP_CONFIG['wandb_entity']
    )
else:
    print("ℹ️ Set EXISTING_SWEEP_ID to resume a previous sweep.")
    print("   Example: EXISTING_SWEEP_ID = 'abc123xy'")
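`EXISTING_SWEEP_ID` may be a bare ID or a full `entity/project/sweep_id` path, while `get_sweep_results` and `train_best_from_sweep` build that path from separate parts. A hypothetical helper (not part of wandb) that splits any of these forms, including the `entity/project/sweeps/id` tail of a dashboard URL, so the pieces can be passed consistently to `wandb.agent` or `wandb.Api().sweep`:

```python
def split_sweep_path(sweep_ref, default_project='cubemaster', default_entity=None):
    """Split a sweep reference into (entity, project, sweep_id).

    Accepts 'id', 'project/id', 'entity/project/id', or the
    'entity/project/sweeps/id' tail of a dashboard URL.
    """
    parts = sweep_ref.strip('/').split('/')
    if len(parts) == 4 and parts[2] == 'sweeps':
        parts = [parts[0], parts[1], parts[3]]
    if len(parts) == 1:
        return default_entity, default_project, parts[0]
    if len(parts) == 2:
        return default_entity, parts[0], parts[1]
    if len(parts) == 3:
        return parts[0], parts[1], parts[2]
    raise ValueError(f"Unrecognized sweep reference: {sweep_ref!r}")

# Usage:
# entity, project, sid = split_sweep_path(EXISTING_SWEEP_ID)
# wandb.agent(sweep_id=sid, function=sweep_train, count=20, project=project, entity=entity)
```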

### 4.5 Sweep Results Analysis

In [None]:
# ============================================================
# ANALYZE SWEEP RESULTS
# ============================================================

def get_sweep_results(sweep_id, project='cubemaster', entity=None):
    """Fetch and analyze sweep results from wandb."""
    api = wandb.Api()
    
    # Get sweep
    sweep_path = f"{entity}/{project}/{sweep_id}" if entity else f"{project}/{sweep_id}"
    sweep = api.sweep(sweep_path)
    
    print(f"📊 Sweep: {sweep.name}")
    print(f"   State: {sweep.state}")
    print(f"   Runs: {len(sweep.runs)}")
    
    # Collect run data
    runs_data = []
    for run in sweep.runs:
        if run.state == 'finished':
            summary = run.summary._json_dict
            config = {k: v for k, v in run.config.items() if not k.startswith('_')}
            runs_data.append({
                'name': run.name,
                'val_acc': summary.get('val_acc', 0),
                'train_acc': summary.get('train_acc', 0),
                'val_loss': summary.get('val_loss', float('inf')),
                **config
            })
    
    if runs_data:
        # Sort by validation accuracy
        runs_data.sort(key=lambda x: x['val_acc'], reverse=True)
        
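        print(f"\n🏆 Top 5 Runs:")
        for i, run in enumerate(runs_data[:5], 1):
            print(f"   {i}. {run['name']}: {run['val_acc']:.2f}% val_acc")
            # Apply float formats only to numbers; 'N/A' cannot take a float format spec
            lr, dropout = run.get('lr'), run.get('dropout_rate')
            lr_str = f"{lr:.6f}" if isinstance(lr, (int, float)) else 'N/A'
            dropout_str = f"{dropout:.2f}" if isinstance(dropout, (int, float)) else 'N/A'
            print(f"      lr={lr_str}, batch={run.get('batch_size', 'N/A')}, dropout={dropout_str}")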
        
        # Best config
        best = runs_data[0]
        print(f"\n🎯 Best Configuration:")
        for key, value in best.items():
            if key not in ['name', 'val_acc', 'train_acc', 'val_loss']:
                print(f"   {key}: {value}")
        
        return runs_data
    else:
        print("   No finished runs yet.")
        return []

# Analyze the current sweep (uncomment after running sweep)
# results = get_sweep_results(sweep_id, project=SWEEP_CONFIG['wandb_project'])
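`get_sweep_results` returns a list of flat dicts (one per finished run, sorted by `val_acc`). Beyond the top-5 listing, it can help to average `val_acc` per value of a categorical hyperparameter such as `batch_size` or `optimizer`. A stdlib sketch; `mean_metric_by` is a hypothetical helper:

```python
from collections import defaultdict
from statistics import mean

def mean_metric_by(runs_data, param, metric='val_acc'):
    """Average `metric` over runs grouped by the value of `param`.

    Runs that lack `param` are skipped. Returns (value, mean, n_runs)
    tuples, highest mean first.
    """
    groups = defaultdict(list)
    for run in runs_data:
        if param in run:
            # str() keeps list-valued params such as hidden_dims usable as dict keys
            groups[str(run[param])].append(run[metric])
    return sorted(((value, mean(vals), len(vals)) for value, vals in groups.items()),
                  key=lambda item: item[1], reverse=True)

# Usage, once `results` exists:
# for value, avg, n in mean_metric_by(results, 'batch_size'):
#     print(f"batch_size={value}: {avg:.2f}% over {n} runs")
```

With only 20 or so runs, groups are small, so treat these averages as hints rather than conclusions.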

In [None]:
# ============================================================
# TRAIN BEST CONFIG FROM SWEEP
# ============================================================

def train_best_from_sweep(sweep_id, project='cubemaster', entity=None, epochs=50):
    """Train a model with the best hyperparameters from a sweep."""
    api = wandb.Api()
    sweep_path = f"{entity}/{project}/{sweep_id}" if entity else f"{project}/{sweep_id}"
    sweep = api.sweep(sweep_path)
    
    # Find best run
    best_run = sweep.best_run()
    if not best_run:
        print("❌ No completed runs found in sweep.")
        return None
    
    best_config = {k: v for k, v in best_run.config.items() if not k.startswith('_')}
    
    print(f"🏆 Training with best config from: {best_run.name}")
    print(f"   Original val_acc: {best_run.summary.get('val_acc', 0):.2f}%")
    print(f"   Config: {best_config}")
    print(f"   Training for {epochs} epochs...\n")
    
    # Initialize new run
    run = wandb.init(
        project=project,
        name=f"{SWEEP_MODEL}_best_config",
        config=best_config,
        tags=['best_from_sweep', SWEEP_MODEL]
    )
    
    # Override epochs for longer training
    original_epochs = SWEEP_CONFIG['epochs_per_run']
    SWEEP_CONFIG['epochs_per_run'] = epochs
    
    try:
        sweep_train()
    finally:
        SWEEP_CONFIG['epochs_per_run'] = original_epochs
        wandb.finish()
    
    return run

# Train with best config (uncomment after analyzing sweep)
# train_best_from_sweep(sweep_id, epochs=50)

### 4.6 Colab-Specific Tips

**Session Management:**
- Colab sessions time out after roughly 90 minutes of inactivity
- The free tier has a maximum runtime of roughly 12 hours
- Keep the browser tab active (or use a browser extension) to avoid idle timeouts

**Sweep Strategy for Colab:**
1. Run sweeps with `max_runs=20-30` to stay within session limits
2. Save your `sweep_id` and resume later if needed
3. Use shorter `epochs_per_run` (10-20) for initial exploration
4. Train best config with more epochs after sweep completes
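Whether a sweep fits in one session comes down to simple arithmetic: runs × (epochs × seconds per epoch + per-run overhead). A sketch for checking this before launching; the 10 s per epoch and 30 s overhead below are placeholders (measure them from a single run), and the 12-hour limit is the approximate free-tier figure above:

```python
def sweep_hours(runs, epochs_per_run, seconds_per_epoch, overhead_per_run_s=30):
    """Estimated wall-clock hours for a sweep.

    overhead_per_run_s is an assumed allowance for run startup and data loading.
    """
    total_s = runs * (epochs_per_run * seconds_per_epoch + overhead_per_run_s)
    return total_s / 3600

SESSION_LIMIT_H = 12  # approximate Colab free-tier limit
est = sweep_hours(runs=20, epochs_per_run=15, seconds_per_epoch=10)
print(f"Estimated {est:.2f} h ({'fits' if est < SESSION_LIMIT_H else 'exceeds'} one session)")
```

Hyperband early termination stops weak runs sooner, so the estimate is roughly an upper bound.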

**Multi-Session Sweeps:**
```python
# Session 1: Create sweep
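sweep_id = wandb.sweep(ACTIVE_SWEEP_CONFIG, project='cubemaster')
print(f"Save this ID: {sweep_id}")  # Copy this!

# Session 2+: Resume sweep (first re-run Sections 1-2 and the sweep setup cells so sweep_train is defined)
wandb.agent(sweep_id='YOUR_SWEEP_ID', function=sweep_train, count=20, project='cubemaster')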
```
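Colab clears notebook variables when the runtime resets, so a `sweep_id` that only lives in cell output is easy to lose. A minimal sketch that writes it to a JSON file, assuming Google Drive is mounted at `/content/drive` (the file path is a placeholder); `save_sweep_id` and `load_sweep_id` are hypothetical helpers:

```python
import json
from pathlib import Path

def save_sweep_id(sweep_id, path, **extra):
    """Write the sweep ID (plus any extra metadata) to a JSON file."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps({'sweep_id': sweep_id, **extra}, indent=2))

def load_sweep_id(path):
    """Return the saved sweep ID, or None if the file does not exist."""
    path = Path(path)
    if not path.exists():
        return None
    return json.loads(path.read_text())['sweep_id']

# Usage:
# SWEEP_FILE = '/content/drive/MyDrive/CubeMaster/sweep_id.json'
# save_sweep_id(sweep_id, SWEEP_FILE, model=SWEEP_MODEL)   # session 1
# EXISTING_SWEEP_ID = load_sweep_id(SWEEP_FILE)            # session 2+
```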

In [None]:
# ============================================================
# QUICK SWEEP: Run a small sweep for testing
# ============================================================

# Minimal sweep config for quick testing
QUICK_SWEEP_CONFIG = {
    'method': 'random',
    'metric': {'name': 'val_acc', 'goal': 'maximize'},
    'parameters': {
        'lr': {'values': [0.001, 0.0005, 0.0001]},
        'batch_size': {'values': [32, 64]},
        'dropout_rate': {'values': [0.3, 0.5]}
    }
}

def run_quick_sweep(num_runs=5, epochs=5):
    """Run a quick sweep for testing."""
    print(f"⚡ Running quick sweep: {num_runs} runs, {epochs} epochs each")
    
    # Save original epochs
    original_epochs = SWEEP_CONFIG['epochs_per_run']
    SWEEP_CONFIG['epochs_per_run'] = epochs
    
    # Create and run sweep
    quick_sweep_id = wandb.sweep(
        sweep=QUICK_SWEEP_CONFIG,
        project=SWEEP_CONFIG['wandb_project']
    )
    
    print(f"   Sweep ID: {quick_sweep_id}")
    
    try:
        wandb.agent(
            sweep_id=quick_sweep_id,
            function=sweep_train,
            count=num_runs
        )
    finally:
        SWEEP_CONFIG['epochs_per_run'] = original_epochs
    
    return quick_sweep_id

# Uncomment to run a quick test sweep
# quick_id = run_quick_sweep(num_runs=3, epochs=3)
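
With `'method': 'random'`, W&B samples from the listed values, so it helps to know how large the space is. The quick config has 3 × 2 × 2 = 12 combinations. Five random runs therefore visit at most 5 of them, while switching to `'method': 'grid'` would run all 12. The helper below is a hypothetical sketch that enumerates those combinations for configs that use only `values` lists or fixed `value` entries.

```python
import itertools

def grid_combinations(sweep_config):
    """List every parameter combination a grid search over `sweep_config` would visit.

    Hypothetical helper: only 'values' lists and fixed 'value' entries are
    supported; distribution-based parameters (min/max) have no finite grid.
    """
    names, value_lists = [], []
    for name, spec in sweep_config['parameters'].items():
        if 'values' in spec:
            values = spec['values']
        elif 'value' in spec:
            values = [spec['value']]
        else:
            raise ValueError(f"Parameter '{name}' has no finite set of values")
        names.append(name)
        value_lists.append(values)
    # Cartesian product of all value lists, returned as one dict per combination
    return [dict(zip(names, combo)) for combo in itertools.product(*value_lists)]

# Same parameter lists as QUICK_SWEEP_CONFIG
example_config = {
    'parameters': {
        'lr': {'values': [0.001, 0.0005, 0.0001]},
        'batch_size': {'values': [32, 64]},
        'dropout_rate': {'values': [0.3, 0.5]},
    }
}
combos = grid_combinations(example_config)
print(len(combos))  # 12
```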

---
## 5. Results Visualization

### 5.1 Training Curves

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_training_curves(history, title="Training Curves"):
    """Plot training and validation curves."""
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    epochs = range(1, len(history['train_loss']) + 1)
    
    # Loss plot
    axes[0].plot(epochs, history['train_loss'], 'b-', label='Train Loss', linewidth=2)
    axes[0].plot(epochs, history['val_loss'], 'r-', label='Val Loss', linewidth=2)
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Loss Curves')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # Accuracy plot
    axes[1].plot(epochs, history['train_acc'], 'b-', label='Train Acc', linewidth=2)
    axes[1].plot(epochs, history['val_acc'], 'r-', label='Val Acc', linewidth=2)
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Accuracy (%)')
    axes[1].set_title('Accuracy Curves')
    axes[1].grid(True, alpha=0.3)
    
    # Mark best epoch (argmax returns the first occurrence; +1 for 1-indexed epochs)
    best_epoch = int(np.argmax(history['val_acc'])) + 1
    best_acc = max(history['val_acc'])
    axes[1].axvline(x=best_epoch, color='g', linestyle='--', alpha=0.5, label=f'Best: {best_acc:.1f}%')
    axes[1].scatter([best_epoch], [best_acc], color='g', s=100, zorder=5)
    # Draw the legend after the best-epoch marker so its label is included
    axes[1].legend()
    
    plt.suptitle(title, fontsize=14)
    plt.tight_layout()
    plt.show()
    
    return fig

# Plot training curves
if 'history' in dir():
    plot_training_curves(history, f"{TRAINING_CONFIG['model_type'].upper()} Training Curves")
else:
    print("ℹ️ No training history available. Run training first.")

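To compare curves across runs, it helps to reduce them to a few numbers. The following is a minimal sketch that assumes `history` is a dict of per-epoch lists with the keys plotted above (`train_acc`, `val_acc`). `summarize_history` is a hypothetical helper. Epochs are 1-indexed to match the plot. A large train/val accuracy gap at the best epoch is only a rough sign of overfitting.

```python
def summarize_history(history):
    """Reduce per-epoch history lists to a few comparable numbers.

    Assumes `history` maps 'train_acc' and 'val_acc' to equal-length lists
    of accuracies in percent (the keys plotted above).
    """
    val_acc = history['val_acc']
    # First epoch with the highest validation accuracy (same tie rule as np.argmax)
    best_idx = max(range(len(val_acc)), key=val_acc.__getitem__)
    return {
        'best_epoch': best_idx + 1,  # 1-indexed, as on the plot's x-axis
        'best_val_acc': val_acc[best_idx],
        # Train minus val accuracy at the best epoch; a large gap hints at overfitting
        'train_val_gap': history['train_acc'][best_idx] - val_acc[best_idx],
        # Epochs trained after the best one (useful when tuning early stopping)
        'epochs_since_best': len(val_acc) - (best_idx + 1),
    }

# Made-up history for illustration
example_history = {
    'train_acc': [80.0, 90.0, 96.0, 99.0],
    'val_acc':   [78.0, 88.0, 92.0, 91.0],
}
print(summarize_history(example_history))
```
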
### 5.2 Confusion Matrix

In [None]:
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
from cubemaster import COLOR_CLASSES

def evaluate_and_plot_confusion_matrix(model, test_loader, device):
    """Evaluate model and plot confusion matrix."""
    model.eval()
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            outputs = model(images)
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.numpy())
    
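    # Calculate accuracy
    accuracy = 100 * sum(p == l for p, l in zip(all_preds, all_labels)) / len(all_labels)
    print(f"\n📊 Test Accuracy: {accuracy:.2f}%\n")
    
    # Pass every class index explicitly so the report and matrix stay aligned
    # with COLOR_CLASSES even if a class is absent from the test set
    class_indices = list(range(len(COLOR_CLASSES)))
    
    # Classification report
    print("Classification Report:")
    print(classification_report(all_labels, all_preds, labels=class_indices,
                                target_names=COLOR_CLASSES, zero_division=0))
    
    # Confusion matrix (rows = true class, columns = predicted class)
    cm = confusion_matrix(all_labels, all_preds, labels=class_indices)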
    
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=COLOR_CLASSES, yticklabels=COLOR_CLASSES, ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title(f'Confusion Matrix (Accuracy: {accuracy:.2f}%)')
    plt.tight_layout()
    plt.show()
    
    return accuracy, cm

# Evaluate on test set
if 'model' in dir() and 'test_loader' in dir():
    device = get_device()
    test_accuracy, cm = evaluate_and_plot_confusion_matrix(model, test_loader, device)
else:
    print("ℹ️ No model or test data available. Run training first.")

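In scikit-learn's `confusion_matrix`, row *i* holds the samples whose true class is *i*, and column *j* holds the samples predicted as *j*. Per-class recall is therefore the diagonal entry divided by its row sum, and precision is the diagonal entry divided by its column sum. The largest off-diagonal cell is the most frequent confusion. The plain-Python sketch below assumes `cm` is a square matrix like the one returned above and `class_names` is a list such as `COLOR_CLASSES`. `per_class_metrics` is a hypothetical helper, and the toy counts are made up.

```python
def per_class_metrics(cm, class_names):
    """Per-class recall/precision and the most frequent confusion from a confusion matrix."""
    n = len(class_names)
    metrics = {}
    for i, name in enumerate(class_names):
        true_count = sum(cm[i][j] for j in range(n))  # row sum: samples truly of class i
        pred_count = sum(cm[r][i] for r in range(n))  # column sum: samples predicted as i
        tp = cm[i][i]
        metrics[name] = {
            'recall': tp / true_count if true_count else 0.0,
            'precision': tp / pred_count if pred_count else 0.0,
        }
    # Largest off-diagonal cell as (count, true class, predicted class)
    worst = max(
        ((cm[i][j], class_names[i], class_names[j])
         for i in range(n) for j in range(n) if i != j),
        default=(0, None, None),
    )
    return metrics, worst

# Toy 3-class matrix with made-up counts
toy_names = ['red', 'orange', 'white']
toy_cm = [[10, 2, 0],
          [1, 8, 1],
          [0, 0, 12]]
metrics, worst = per_class_metrics(toy_cm, toy_names)
print(metrics['red'])
print(worst)  # (2, 'red', 'orange')
```
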
### 5.3 Model Comparison (Optional)

Train both models and compare their performance.

In [None]:
def compare_models(results_dict):
    """Compare multiple models side by side."""
    if len(results_dict) < 2:
        print("ℹ️ Need at least 2 models to compare.")
        return
    
    # Create comparison table
    print("\n" + "=" * 60)
    print("üìä MODEL COMPARISON")
    print("=" * 60)
    print(f"{'Model':<20} {'Test Acc':>12} {'Parameters':>15}")
    print("-" * 60)
    
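    for name, data in results_dict.items():
        acc = data.get('accuracy')
        params = data.get('params')
        # Format only numeric values; fall back to 'N/A' for missing entries
        acc_str = f"{acc:.2f}%" if isinstance(acc, (int, float, np.number)) else "N/A"
        params_str = f"{params:,}" if isinstance(params, (int, np.integer)) else "N/A"
        print(f"{name:<20} {acc_str:>12} {params_str:>15}")
    
    print("=" * 60)
    
    # Bar chart comparison (models without a numeric accuracy are skipped)
    fig, ax = plt.subplots(figsize=(10, 5))
    models = [m for m in results_dict
              if isinstance(results_dict[m].get('accuracy'), (int, float, np.number))]
    accuracies = [results_dict[m]['accuracy'] for m in models]
    
    bars = ax.bar(models, accuracies, color=['#2196F3', '#4CAF50'][:len(models)])
    ax.set_ylabel('Test Accuracy (%)')
    ax.set_title('Model Comparison')
    ax.set_ylim(0, 105)  # headroom for value labels above bars near 100%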
    
    for bar, acc in zip(bars, accuracies):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
                f'{acc:.1f}%', ha='center', fontweight='bold')
    
    plt.tight_layout()
    plt.show()

# Example usage (uncomment after training both models)
# comparison_results = {
#     'MLP': {'accuracy': mlp_test_accuracy, 'params': mlp_model.count_parameters()['trainable']},
#     'Shallow CNN': {'accuracy': cnn_test_accuracy, 'params': cnn_model.count_parameters()['trainable']}
# }
# compare_models(comparison_results)
print("ℹ️ Train both MLP and Shallow CNN to enable comparison.")

---
## 6. Model Export

### 6.1 Save PyTorch Model

In [None]:
import os
from pathlib import Path

def save_pytorch_model(model, trainer, model_name, output_dir='/content/exported_models'):
    """Save trained model in PyTorch format."""
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    
    # Save a full checkpoint: weights plus model config, best validation
    # accuracy, and epoch (optimizer state is not included)
    checkpoint_path = output_dir / f"{model_name}_checkpoint.pth"
    torch.save({
        'model_state_dict': model.state_dict(),
        'model_config': model.get_config(),
        'best_val_acc': trainer.best_val_acc,
        'epoch': trainer.current_epoch,
    }, checkpoint_path)
    print(f"✅ Checkpoint saved: {checkpoint_path}")
    
    # Save model weights only (smaller file)
    weights_path = output_dir / f"{model_name}_weights.pth"
    torch.save(model.state_dict(), weights_path)
    print(f"✅ Weights saved: {weights_path}")
    
    return checkpoint_path, weights_path

# Save trained model
if 'model' in dir() and 'trainer' in dir():
    model_name = TRAINING_CONFIG['model_type']
    checkpoint_path, weights_path = save_pytorch_model(model, trainer, model_name)
else:
    print("ℹ️ No model available. Run training first.")

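To reuse the checkpoint in a later session, load it on the CPU and restore the weights into a freshly constructed model. The following is a minimal sketch that assumes the dict layout written by `save_pytorch_model` above. `TinyClassifier` is a hypothetical stand-in for the real MLP/CNN classes, whose constructor arguments would come from `model_config`. The 6 output classes and the 50×50 input are also assumptions.

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn

class TinyClassifier(nn.Module):
    """Hypothetical stand-in for the CubeMaster model classes."""
    def __init__(self, num_classes=6):
        super().__init__()
        self.net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 50 * 50, num_classes))

    def forward(self, x):
        return self.net(x)

def load_checkpoint(model, checkpoint_path):
    """Restore weights from a checkpoint dict and switch the model to eval mode."""
    # map_location='cpu' lets a GPU-trained checkpoint load on any machine;
    # weights_only=True restricts unpickling to tensors and plain containers
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return checkpoint

# Round trip using the same dict layout as save_pytorch_model
src = TinyClassifier()
path = Path(tempfile.mkdtemp()) / "tiny_checkpoint.pth"
torch.save({'model_state_dict': src.state_dict(),
            'model_config': {'num_classes': 6},
            'best_val_acc': 97.5, 'epoch': 12}, path)

restored = TinyClassifier(num_classes=6)
info = load_checkpoint(restored, path)
print(info['epoch'], info['best_val_acc'])
```
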
### 6.2 Export to ONNX Format

In [None]:
import onnx
import onnxruntime as ort

def export_to_onnx(model, model_name, output_dir='/content/exported_models', input_size=(50, 50)):
    """Export model to ONNX format."""
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    
    onnx_path = output_dir / f"{model_name}.onnx"
    
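    # Set model to eval mode and move it to CPU for export
    # (.cpu() moves the model in place; call .to(device) before further GPU training)
    model.eval()
    model.cpu()
    
    # Create dummy input: batch of 1, 3 RGB channels, input_size[0] x input_size[1]
    dummy_input = torch.randn(1, 3, input_size[0], input_size[1])
    
    # Export to ONNX
    print("\n📦 Exporting to ONNX...")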
    torch.onnx.export(
        model,
        dummy_input,
        str(onnx_path),
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'}
        },
        dynamo=False  # Use legacy exporter for compatibility
    )
    
    # Verify ONNX model
    onnx_model = onnx.load(str(onnx_path))
    onnx.checker.check_model(onnx_model)
    print(f"✅ ONNX model exported and verified: {onnx_path}")
    
    # Test with ONNX Runtime
    ort_session = ort.InferenceSession(str(onnx_path))
    ort_inputs = {ort_session.get_inputs()[0].name: dummy_input.numpy()}
    ort_outputs = ort_session.run(None, ort_inputs)
    print(f"   Output shape: {ort_outputs[0].shape}")
    
    # Get file size
    size_mb = os.path.getsize(onnx_path) / (1024 * 1024)
    print(f"   File size: {size_mb:.2f} MB")
    
    return onnx_path

# Export to ONNX
if 'model' in dir():
    model_name = TRAINING_CONFIG['model_type']
    onnx_path = export_to_onnx(model, model_name)
else:
    print("ℹ️ No model available. Run training first.")

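The ONNX model returns one row of scores per image. This sketch assumes the scores are raw logits, since the confusion-matrix code above takes `argmax` of the same output. A numerically stable softmax turns each row into probabilities, and `argmax` maps the row to a class name. `logits_to_predictions` is a hypothetical helper. In practice its input would be `ort_session.run(None, ort_inputs)[0]`, and the class names would be `COLOR_CLASSES`.

```python
import numpy as np

def logits_to_predictions(logits, class_names):
    """Convert a (batch, num_classes) logit array into (class name, probability) pairs."""
    logits = np.asarray(logits, dtype=np.float64)
    # Subtract each row's max before exponentiating to avoid overflow
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    probs = exp / exp.sum(axis=1, keepdims=True)
    idx = probs.argmax(axis=1)
    return [(class_names[i], float(probs[row, i])) for row, i in enumerate(idx)]

# Made-up logits for two images and three toy classes
example_logits = np.array([[2.0, 0.5, -1.0],
                           [0.0, 0.0, 3.0]])
print(logits_to_predictions(example_logits, ['red', 'orange', 'white']))
```
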
### 6.3 Download Trained Models

In [None]:
import os
import shutil
from pathlib import Path

from google.colab import files

def download_models(output_dir='/content/exported_models'):
    """Zip all exported models and download the archive to your computer."""
    output_dir = Path(output_dir)
    
    if not output_dir.exists():
        print("❌ No exported models found. Export models first.")
        return
    
    print("\n📥 Available models for download:")
    for f in output_dir.iterdir():
        size_mb = os.path.getsize(f) / (1024 * 1024)
        print(f"   - {f.name} ({size_mb:.2f} MB)")
    
    # Create zip archive (make_archive appends '.zip' and returns the full path)
    zip_path = shutil.make_archive('/content/cubemaster_models', 'zip', output_dir)
    print(f"\n📦 Created archive: {zip_path}")
    
    # Download (triggers a browser download; works only inside Colab)
    print("\n⬇️ Starting download...")
    files.download(zip_path)

# Download models
download_models()

### 6.4 Copy to Google Drive (Optional)

In [None]:
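import shutil
from pathlib import Path

def copy_to_drive(source_dir='/content/exported_models', drive_dest='/content/drive/MyDrive/CubeMaster/trained_models'):
    """Copy exported models to Google Drive (Drive must already be mounted)."""
    source_dir = Path(source_dir)
    drive_dest = Path(drive_dest)
    
    if not source_dir.exists():
        print("❌ No exported models found.")
        return
    
    # Without a mounted Drive, mkdir would silently create a local folder that is lost with the session
    if not Path('/content/drive/MyDrive').exists():
        print("❌ Google Drive is not mounted. Run:")
        print("   from google.colab import drive; drive.mount('/content/drive')")
        return
    
    # Create destination directory
    drive_dest.mkdir(parents=True, exist_ok=True)
    
    # Copy files
    print("\n📁 Copying models to Google Drive...")
    for f in source_dir.iterdir():
        dest_file = drive_dest / f.name
        shutil.copy2(f, dest_file)
        print(f"   ✅ {f.name} -> {dest_file}")
    
    print(f"\n✅ All models copied to: {drive_dest}")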

# Copy to Drive (uncomment to run)
# copy_to_drive()

---
## Troubleshooting

### Common Issues

**1. GPU Not Available**
- Go to Runtime > Change runtime type > Hardware accelerator > GPU
- If the GPU quota is exceeded, try again later or fall back to the CPU (slower)

**2. Session Timeout**
- Colab sessions time out after ~90 minutes of inactivity
- Enable wandb logging so metrics from completed epochs survive a disconnect
- Save checkpoints to Google Drive periodically
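
One way to checkpoint periodically is to save every few epochs and to write each file under a temporary name before renaming it. On a local POSIX filesystem the rename is atomic, so a disconnect mid-write does not leave a half-written checkpoint. A network-backed mount such as Drive may not give the same guarantee. Keeping only the newest few files also saves Drive space. The following is a minimal sketch: `periodic_checkpoint` and its `save_fn` parameter are hypothetical. In the notebook, `save_fn` could be `lambda state, f: torch.save(state, f)`, and `ckpt_dir` could point at the mounted Drive.

```python
import os
from pathlib import Path

def periodic_checkpoint(state, epoch, ckpt_dir, save_fn, every_n_epochs=5, keep_last=3):
    """Save `state` every `every_n_epochs` epochs; return the path written, or None.

    Hypothetical helper. `save_fn(state, file_obj)` writes bytes to an open
    binary file; `keep_last` (>= 1) limits how many checkpoints are kept.
    """
    if epoch % every_n_epochs != 0:
        return None
    ckpt_dir = Path(ckpt_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    final_path = ckpt_dir / f"epoch_{epoch:04d}.pth"
    tmp_path = final_path.with_suffix(".tmp")
    # Write to a temporary file first, then swap it into place
    with open(tmp_path, "wb") as f:
        save_fn(state, f)
    os.replace(tmp_path, final_path)
    # Zero-padded names sort chronologically; delete all but the newest `keep_last`
    existing = sorted(ckpt_dir.glob("epoch_*.pth"))
    for old in existing[:-keep_last]:
        old.unlink()
    return final_path
```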

**3. Out of Memory**
- Reduce batch size
- Use a simpler model
- Clear memory: `torch.cuda.empty_cache()`
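
You can automate the batch-size reduction by halving the batch size until one training step fits in memory. The following is a minimal sketch: `find_fitting_batch_size` and `try_step` are hypothetical. In PyTorch, CUDA out-of-memory errors are raised as `torch.cuda.OutOfMemoryError`, a subclass of `RuntimeError`, which is why the default here is `RuntimeError`. In real use, pass the narrower type so that unrelated errors are not swallowed.

```python
def find_fitting_batch_size(try_step, start=256, minimum=8,
                            oom_error=RuntimeError, on_oom=None):
    """Halve the batch size until `try_step(batch_size)` runs without an OOM error.

    Hypothetical helper. `try_step` should run one forward/backward pass at the
    given batch size; `on_oom` can release memory (e.g. clear_gpu_memory).
    """
    batch_size = start
    while batch_size >= minimum:
        try:
            try_step(batch_size)
            return batch_size
        except oom_error:
            if on_oom is not None:
                on_oom()
            batch_size //= 2
    raise RuntimeError(f"No batch size >= {minimum} fit in memory")
```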

**4. Dataset Upload Issues**
- Use Google Drive for large datasets (>100MB)
- Compress dataset as ZIP before uploading
- Check file paths after extraction
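
After uploading a ZIP, you can check file paths by extracting the archive and listing what lands at the top level. Zipping a folder (rather than its contents) can add an extra directory level that a loader may not expect. The following is a minimal sketch that uses only the standard library. `extract_and_inspect` and the example paths are hypothetical, and the expected layout depends on how this repository's dataset loader finds its images.

```python
import zipfile
from pathlib import Path

def extract_and_inspect(zip_path, dest_dir):
    """Extract a ZIP archive and return (sorted top-level entries, total file count)."""
    dest_dir = Path(dest_dir)
    dest_dir.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path) as zf:
        bad = zf.testzip()  # name of the first corrupt member, or None
        if bad is not None:
            raise zipfile.BadZipFile(f"Corrupt member in archive: {bad}")
        zf.extractall(dest_dir)
    top_level = sorted(p.name for p in dest_dir.iterdir())
    file_count = sum(1 for p in dest_dir.rglob("*") if p.is_file())
    return top_level, file_count

# Example (hypothetical paths):
# top, n = extract_and_inspect('/content/dataset.zip', '/content/data')
# print(top, n)  # a single wrapper folder may mean paths need one more level
```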

In [None]:
# Utility: Clear GPU memory
def clear_gpu_memory():
    """Release unused cached GPU memory (tensors that are still referenced are not freed)."""
    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print("✅ GPU memory cleared")
        print(f"   Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
        print(f"   Reserved (cached): {torch.cuda.memory_reserved() / 1e9:.2f} GB")

# Uncomment to clear memory
# clear_gpu_memory()