<a href="https://colab.research.google.com/github/cs-amy/project-codebase/blob/main/notebooks/colab_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MSc Project Model Training on Google Colab

This notebook sets up the environment for training the letter classification model on Google Colab.

## 1. Clone the GitHub Repository

First, clone the project repository. The command below clones `cs-amy/project-codebase`; if you are working from a fork, replace the URL with that of your own repository.

In [None]:
# Delete the project-codebase directory if it exists
!rm -rf project-codebase

!git clone https://github.com/cs-amy/project-codebase.git
%cd project-codebase

## 2. Mount Google Drive (for data files)

If the project's data files are stored in Google Drive, mount it here.

In [None]:
# Mount Google Drive at /content/drive so that the files stored under
# MyDrive are reachable from this runtime's filesystem.
from google.colab import drive
drive.mount('/content/drive')

In [25]:
# Create symbolic links to the data directory
!ln -s /content/drive/MyDrive/MScProject/data data

## 3. Install Dependencies

Install the required packages from the requirements.txt file.

In [None]:
# Install the project's declared dependencies (path is relative to the
# current working directory, i.e. the cloned repository).
!pip install -r requirements.txt

# Specific versions of PyTorch with CUDA support, taken from the PyTorch
# cu118 wheel index instead of the default package index.
# NOTE(review): this runs after requirements.txt, so these pins presumably
# replace any torch version installed by the line above - confirm intended.
!pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118

## 4. Set Up Python Path

Ensure that the project modules can be imported correctly.

In [None]:
import sys

# Make the cloned repository importable (the `src.*` imports below rely on
# it). Guarded so that re-running the cell does not keep appending the same
# entry to sys.path.
if '/content/project-codebase' not in sys.path:
    sys.path.append('/content/project-codebase')

# Verify imports work
import os
import sys
import argparse
import torch
import logging
from pathlib import Path
import yaml
import json
from datetime import datetime
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn
from rich.panel import Panel
from rich.table import Table
from rich.live import Live
from rich.layout import Layout

from src.data.data_loader import DataLoader
from src.models.letter_classifier import LetterClassifierCNN
from src.train.trainer import ModelTrainer

print("Imports successful!")

## 5. Load Configuration

Load the training configuration from the config file.

In [None]:
import yaml

# Parse the training configuration with the safe loader. The path is
# relative, so this expects the working directory to be the repository root.
with open('configs/train_config.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Only reached if the file was opened and parsed without raising.
print("Configuration loaded successfully!")

In [35]:
from src.utils.config import load_config, get_model_config, get_training_config, get_data_config
from src.models.letter_classifier import get_model
from src.train.trainer import ModelTrainer

In [36]:
from pathlib import Path
from typing import Dict, Tuple, Optional, List
import logging

# PyTorch imports
import torch
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision import transforms
from PIL import Image

# Configure logging.
# force=True matters in notebook runtimes: when the root logger already has a
# handler attached (as Colab/Jupyter do before user code runs), a plain
# basicConfig() call is silently ignored and the INFO messages emitted by the
# data-loading code below would never be shown.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    force=True
)
logger = logging.getLogger(__name__)


def validate_data_directory(data_dir: Path) -> None:
    """
    Validate the data directory structure.

    Every required split must exist *as a directory*; a regular file that
    happens to have the same name is reported as missing. On success, the
    number of PNG files found (recursively) under each split is logged at
    INFO level.

    Args:
        data_dir: Path to the data/characters directory

    Raises:
        FileNotFoundError: If one or more required sub-directories are
            missing. The message lists every missing path, not just the first.
    """
    data_dir = Path(data_dir)  # tolerate plain string paths as well

    required_dirs = (
        "regular/train",
        "regular/test",
        "obfuscated/train",
        "obfuscated/test",
    )

    # is_dir() rather than exists(): the dataset code iterates these paths,
    # so a stray file must not pass validation.
    missing_dirs = [d for d in required_dirs if not (data_dir / d).is_dir()]

    if missing_dirs:
        raise FileNotFoundError(
            f"Missing required directories in {data_dir}:\n" +
            "\n".join(f"- {d}" for d in missing_dirs)
        )

    # Log directory structure
    logger.info("Data directory structure validated:")
    logger.info(f"Root: {data_dir}")
    for dir_path in required_dirs:
        # Count lazily; there is no need to hold every path in memory.
        num_files = sum(1 for _ in (data_dir / dir_path).glob("**/*.png"))
        logger.info(f"- {dir_path}: {num_files} PNG files")


class CharacterDataset(Dataset):
    """Dataset for character images (regular or obfuscated).

    ``data_dir`` is expected to contain one sub-directory per letter
    (``a`` ... ``z``, matched case-insensitively); the ``*.png`` files directly
    inside each sub-directory are the samples for that letter. Directories
    whose name is not a single letter a-z are skipped with a warning.
    """

    # Class-level character mapping
    CHAR_TO_IDX = {chr(97 + i): i for i in range(26)}  # a-z to 0-25
    IDX_TO_CHAR = {i: chr(97 + i) for i in range(26)}  # 0-25 to a-z

    def __init__(
        self,
        data_dir: str | Path,
        image_size: Tuple[int, int] = (28, 28),
        transform: Optional[transforms.Compose] = None,
        is_training: bool = True
    ):
        """
        Initialize the dataset.

        Args:
            data_dir: Directory containing character images (e.g., data/characters/regular/train)
            image_size: Target size for images (height, width)
            transform: Optional transformation pipeline; when given it replaces
                the default pipeline entirely
            is_training: Whether this is a training dataset (enables the random
                augmentations in the default pipeline)

        Raises:
            FileNotFoundError: If ``data_dir`` does not exist.
            ValueError: If ``data_dir`` contains no sub-directories.
        """
        self.data_dir = Path(data_dir)
        if not self.data_dir.exists():
            raise FileNotFoundError(f"Data directory not found: {self.data_dir}")

        self.image_size = image_size
        self.is_training = is_training

        # Get all character directories (a-z)
        self.char_dirs = sorted([d for d in self.data_dir.iterdir() if d.is_dir()])
        if not self.char_dirs:
            raise ValueError(f"No character directories found in {self.data_dir}")

        # Get all image paths and labels
        self.images, self.labels = self._load_dataset()

        # Set up transformations
        self.transform = transform if transform is not None else self._get_default_transforms()

        # Log dataset statistics. Counting in a single pass keeps this O(n)
        # instead of rescanning the label list once per letter.
        logger.info(f"Loaded {len(self.images)} images from {self.data_dir}")
        counts = [0] * len(self.CHAR_TO_IDX)
        for label in self.labels:
            counts[label] += 1
        logger.info("Character distribution:")
        for char, idx in self.CHAR_TO_IDX.items():
            logger.info(f"- {char}: {counts[idx]} images")

    def _load_dataset(self) -> Tuple[List[Path], List[int]]:
        """Load all image paths and their corresponding labels.

        Returns:
            Two parallel lists: image paths and the integer label of each.
        """
        images, labels = [], []

        for char_dir in self.char_dirs:
            char = char_dir.name.lower()
            if char not in self.CHAR_TO_IDX:
                logger.warning(f"Skipping unknown character directory: {char}")
                continue

            label = self.CHAR_TO_IDX[char]

            # Get all PNG images in this directory. glob() yields entries in
            # filesystem order, so sort to make sample indices (and therefore
            # any index-based split) stable across runs and machines.
            char_images = sorted(char_dir.glob("*.png"))
            images.extend(char_images)
            labels.extend([label] * len(char_images))

        return images, labels

    def _get_default_transforms(self) -> transforms.Compose:
        """Get default transformation pipeline.

        Always: resize, single-channel grayscale, tensor conversion and
        normalisation with mean 0.5 / std 0.5. In training mode a random
        rotation (up to 10 degrees) and a random affine jitter are inserted
        right after the resize.
        """
        transform_list = [
            transforms.Resize(self.image_size),
            transforms.Grayscale(1),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ]

        if self.is_training:
            transform_list.insert(1, transforms.RandomRotation(10))
            transform_list.insert(2, transforms.RandomAffine(
                degrees=0,
                translate=(0.1, 0.1),
                scale=(0.9, 1.1)
            ))

        return transforms.Compose(transform_list)

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.images)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Get a single sample.

        Returns:
            ``(image, label)``. If the image cannot be read or transformed,
            the error is logged and an all-zero tensor of shape
            ``(1, *image_size)`` is returned with the sample's label, so one
            unreadable file does not abort an epoch.
        """
        image_path = self.images[idx]
        label = self.labels[idx]

        # Load and convert image
        try:
            # The context manager releases the file handle as soon as the
            # pixel data has been converted.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert('L')  # Convert to grayscale
            image = self.transform(image)
            return image, label
        except Exception as e:
            # Deliberately broad: this is a best-effort fallback for any
            # decoding or transform failure.
            logger.error(f"Error loading image {image_path}: {e}")
            # Return a blank image and the label if there's an error
            return torch.zeros((1, *self.image_size)), label

    @classmethod
    def get_char_mapping(cls) -> Tuple[Dict[str, int], Dict[int, str]]:
        """
        Get the character to index and index to character mappings.

        Returns:
            Tuple of (char_to_idx, idx_to_char) dictionaries
        """
        return cls.CHAR_TO_IDX, cls.IDX_TO_CHAR


def _without_augmentation(dataset):
    """Return a shallow copy of a CharacterDataset that uses the evaluation transforms."""
    import copy  # local import: only this helper needs it

    clone = copy.copy(dataset)  # shares the (read-only) image/label lists
    clone.is_training = False
    clone.transform = clone._get_default_transforms()
    return clone


def get_data_loaders(
    data_dir: str | Path,
    batch_size: int = 32,
    image_size: Tuple[int, int] = (28, 28),
    num_workers: int = 4,
    augment: bool = True
) -> Dict[str, DataLoader]:
    """
    Create data loaders for training, validation, and testing.

    Regular and obfuscated samples are concatenated per split. The training
    data is divided 80/20 at random into train and validation subsets; the
    split is not seeded here, so it follows torch's global RNG state.

    Args:
        data_dir: Path to the data/characters directory (e.g., data/characters)
        batch_size: Batch size for all three loaders
        image_size: Target size for images
        num_workers: Number of worker processes for data loading
        augment: Whether to apply random augmentation to the training samples.
            Validation and test samples are never augmented.

    Returns:
        Dictionary containing train, val, and test data loaders.

    Raises:
        FileNotFoundError: If the directory layout is incomplete
            (raised by validate_data_directory / CharacterDataset).
    """
    data_dir = Path(data_dir)

    # Validate directory structure
    validate_data_directory(data_dir)

    sources = ("regular", "obfuscated")

    # Load training datasets for both regular and obfuscated characters.
    # `augment` (previously ignored) decides whether the random transforms
    # are part of the training pipeline.
    train_parts = [
        CharacterDataset(
            data_dir / source / "train",
            image_size=image_size,
            is_training=augment
        )
        for source in sources
    ]

    # Load test datasets for both regular and obfuscated characters
    test_parts = [
        CharacterDataset(
            data_dir / source / "test",
            image_size=image_size,
            is_training=False
        )
        for source in sources
    ]

    # Combine datasets for training and testing
    train_dataset = ConcatDataset(train_parts)
    test_dataset = ConcatDataset(test_parts)

    # Create train/validation split (80/20) on the training dataset
    train_length = int(0.8 * len(train_dataset))
    val_length = len(train_dataset) - train_length
    train_subset, val_subset = torch.utils.data.random_split(train_dataset, [train_length, val_length])

    if augment:
        # random_split only partitions indices, so the validation subset would
        # otherwise inherit the random training augmentations and produce
        # noisy metrics. Re-point the same indices at a non-augmented view.
        eval_view = ConcatDataset([_without_augmentation(part) for part in train_parts])
        val_subset = torch.utils.data.Subset(eval_view, val_subset.indices)

    # Log dataset statistics
    logger.info("Combined dataset statistics:")
    logger.info(f"- Training set (after split): {len(train_subset)} images")
    logger.info(f"- Validation set: {len(val_subset)} images")
    logger.info(f"- Test set: {len(test_dataset)} images")

    # Create data loaders for each split.
    # torch.utils.data.DataLoader is referenced explicitly because this file
    # also imports a project-level class named `DataLoader`; whichever import
    # cell ran last would otherwise decide which one is used here.
    # Pinned host memory only helps CUDA transfers; elsewhere it just warns.
    shared_kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": torch.cuda.is_available(),
    }

    train_loader = torch.utils.data.DataLoader(train_subset, shuffle=True, **shared_kwargs)
    val_loader = torch.utils.data.DataLoader(val_subset, shuffle=False, **shared_kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, **shared_kwargs)

    return {
        "train": train_loader,
        "val": val_loader,
        "test": test_loader
    }

## 6. Train Model

Now you can run the training code.

In [38]:
# Configure logging.
# force=True replaces any handler already attached to the root logger (by the
# notebook runtime or by an earlier cell); without it this call is ignored.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    force=True
)
logger = logging.getLogger(__name__)

# Initialize rich console
console = Console()

def setup_device():
    """
    Pick the torch device to train on and announce the choice on the console.

    Backends are probed in priority order: MPS (Apple Silicon) first, then
    CUDA (NVIDIA); if neither is available the CPU is used.

    Returns:
        torch.device: The selected device.
    """
    accelerators = (
        (
            torch.backends.mps.is_available,
            "mps",
            "[green]GPU available: Using Metal Performance Shaders (MPS)[/green]",
        ),
        (
            torch.cuda.is_available,
            "cuda",
            "[green]GPU available: Using CUDA[/green]",
        ),
    )

    # First backend that reports itself available wins.
    for is_available, backend, announcement in accelerators:
        if is_available():
            selected = torch.device(backend)
            console.print(announcement)
            return selected

    selected = torch.device("cpu")
    console.print("[yellow]No GPU detected: Using CPU[/yellow]")
    return selected

def get_optimal_batch_size(image_size, available_memory_gb=None):
    """
    Estimate a batch size from a memory budget and the image dimensions.

    One sample is costed as height * width float32 values (4 bytes each), and
    20% of the available memory is treated as the budget for a batch. The
    result is always clamped to the range 16..128.

    Args:
        image_size (tuple): Image dimensions (height, width)
        available_memory_gb (float): Available memory in GB. If None, a fixed
            guess is used: 16 when an MPS or CUDA device is available, else 8.

    Returns:
        int: Batch size between 16 and 128 (inclusive)
    """
    if available_memory_gb is None:
        # No measurement is taken here; these are fixed conservative guesses.
        has_accelerator = torch.backends.mps.is_available() or torch.cuda.is_available()
        available_memory_gb = 16 if has_accelerator else 8

    height, width = image_size[0], image_size[1]
    float32_bytes = 4
    per_sample_bytes = height * width * float32_bytes

    # Only a fifth of the memory is budgeted for the batch itself.
    budget_bytes = available_memory_gb * 1e9 * 0.2

    batch_size = int(budget_bytes / per_sample_bytes)

    # Clamp: upper bound 128 first, then lower bound 16.
    if batch_size > 128:
        batch_size = 128
    if batch_size < 16:
        batch_size = 16
    return batch_size

def resume_training(trainer, checkpoint_path):
    """
    Resume training from a checkpoint if available.

    Args:
        trainer (ModelTrainer): Training instance; its ``load_checkpoint``
            method is called with the checkpoint path.
        checkpoint_path (str | Path): Path to checkpoint file

    Returns:
        bool: True if a checkpoint file was found and loaded, False if there
        is no regular file at ``checkpoint_path`` (nothing is loaded then).
    """
    checkpoint_path = Path(checkpoint_path)  # accept plain strings too

    # is_file() rather than exists(): a directory at this path is not a
    # checkpoint and must not be handed to the trainer.
    if not checkpoint_path.is_file():
        return False

    trainer.load_checkpoint(checkpoint_path)
    console.print(f"[green]Resumed training from {checkpoint_path}[/green]")
    return True

def monitor_memory():
    """
    Report current memory usage for the active compute backend.

    Checks MPS first, then CUDA, and falls back to process/system RAM (via
    psutil) when neither is available. This is purely diagnostic, so any
    failure to obtain the figures yields a "Not available" message instead of
    raising.

    Returns:
        str: A one-line, human-readable memory summary.
    """
    if torch.backends.mps.is_available():
        try:
            used_memory = torch.mps.current_allocated_memory() / 1e9
            return f"GPU Memory Used: {used_memory:.2f} GB"
        except Exception:
            # `Exception`, not a bare except: Ctrl-C / SystemExit must still
            # be able to interrupt training.
            return "GPU Memory: Not available"

    if torch.cuda.is_available():
        try:
            used_memory = torch.cuda.memory_allocated() / 1e9
            total_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
            return f"GPU Memory: {used_memory:.2f}GB / {total_memory:.2f}GB"
        except Exception:
            return "GPU Memory: Not available"

    try:
        import psutil  # optional third-party dependency
    except ImportError:
        # Mirror the GPU branches: missing tooling must not abort training.
        return "CPU Memory: Not available"

    process = psutil.Process()
    used_memory = process.memory_info().rss / 1e9
    total_memory = psutil.virtual_memory().total / 1e9
    return f"CPU Memory: {used_memory:.2f}GB / {total_memory:.2f}GB"

def create_layout():
    """Build the root layout of the training display.

    Returns:
        Layout: A root layout split into three named regions - "header"
        (size 3), "body" (no fixed size) and "footer" (size 3).
    """
    root = Layout()
    header_region = Layout(name="header", size=3)
    body_region = Layout(name="body")
    footer_region = Layout(name="footer", size=3)
    root.split(header_region, body_region, footer_region)
    return root

def create_header():
    """Build the static title panel shown in the header region.

    Returns:
        Panel: Panel containing the training title, styled white on blue.
    """
    title_markup = "[bold blue]Letter Classification Model Training[/bold blue]"
    return Panel(title_markup, style="white on blue")

def create_footer(epoch, total_epochs, train_loss, train_acc, val_loss, val_acc, lr):
    """Build the footer panel summarising the metrics of the current epoch.

    Losses are shown with 4 decimals, accuracies with 2 decimals followed by a
    percent sign, and the learning rate with 6 decimals.

    Returns:
        Panel: Single-line panel with the fields separated by " | ".
    """
    fields = (
        f"Epoch: {epoch}/{total_epochs}",
        f"Train Loss: {train_loss:.4f}",
        f"Train Acc: {train_acc:.2f}%",
        f"Val Loss: {val_loss:.4f}",
        f"Val Acc: {val_acc:.2f}%",
        f"LR: {lr:.6f}",
    )
    summary = " | ".join(fields)
    return Panel(summary, style="white on blue")

def create_metrics_table(train_loss, train_acc, val_loss, val_acc):
    """Build a two-column table of the latest training/validation metrics.

    Returns:
        Table: Table titled "Training Metrics" with one row per metric.
    """
    metrics = Table(title="Training Metrics")
    for heading, colour in (("Metric", "cyan"), ("Value", "green")):
        metrics.add_column(heading, style=colour)

    rows = (
        ("Training Loss", f"{train_loss:.4f}"),
        ("Training Accuracy", f"{train_acc:.2f}%"),
        ("Validation Loss", f"{val_loss:.4f}"),
        ("Validation Accuracy", f"{val_acc:.2f}%"),
    )
    for label, value in rows:
        metrics.add_row(label, value)

    return metrics

def train():
    """Main function for training the model.

    Steps, in order: load ``configs/train_config.yaml``; create a fresh
    timestamped directory under ``outputs/letter_classifier`` and store a
    copy of the config there; build the data loaders and the model; run the
    epoch loop inside a rich live display (checkpoint every 5 epochs,
    confusion matrix every 10); finally ask the trainer to save the final
    model and the training-history plot.

    Returns:
        None. If the config file is missing, an error is printed and the
        function returns without training.
    """
    # Load config
    config_path = Path("configs/train_config.yaml")
    if not config_path.exists():
        console.print(f"[red]Config file not found: {config_path}[/red]")
        return

    config = load_config(config_path)
    model_config = get_model_config(config)
    training_config = get_training_config(config)
    data_config = get_data_config(config)

    # Create output directory with timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = Path("outputs/letter_classifier") / timestamp
    output_dir.mkdir(parents=True, exist_ok=True)

    # Save config
    with open(output_dir / "config.yaml", "w", encoding="utf-8") as f:
        yaml.dump(config, f, default_flow_style=False)

    image_size = model_config["input_shape"][:2]

    # Calculate optimal batch size
    # NOTE(review): the estimate replaces the configured value whenever the
    # two differ, so it can also *raise* the batch size above what the config
    # asks for - confirm this is intended rather than a pure memory cap.
    optimal_batch_size = get_optimal_batch_size(image_size)
    if optimal_batch_size != training_config["batch_size"]:
        console.print(f"[yellow]Adjusting batch size from {training_config['batch_size']} to {optimal_batch_size} based on available memory[/yellow]")
        training_config["batch_size"] = optimal_batch_size

    # Never request more loader workers than the machine has cores;
    # previously this was a fixed 4 regardless of the runtime.
    num_workers = min(4, os.cpu_count() or 1)

    # Create data loaders
    console.print("\n[bold cyan]Loading datasets...[/bold cyan]")
    data_loaders = get_data_loaders(
        data_dir="/content/drive/MyDrive/MScProject/data/characters",
        batch_size=training_config["batch_size"],
        image_size=image_size,
        num_workers=num_workers,
        augment=data_config.get("augmentation", {}).get("use", True)
    )

    # Print dataset statistics
    train_size = len(data_loaders["train"].dataset)
    val_size = len(data_loaders["val"].dataset)
    console.print("\n[green]Dataset Statistics:[/green]")
    console.print(f"- Training set: {train_size:,} images")
    console.print(f"- Validation set: {val_size:,} images")
    console.print(f"- Batch size: {training_config['batch_size']}")

    # Create model
    console.print("\n[bold cyan]Initializing model...[/bold cyan]")
    model = get_model(model_config["architecture"], model_config)
    console.print(f"- Input shape: {model_config['input_shape']}")
    console.print(f"- Number of classes: {model_config['num_classes']}")
    console.print(f"- Model architecture: {model_config['architecture']}")

    # Set up device
    device = setup_device()

    # Create trainer
    trainer = ModelTrainer(
        model=model,
        train_loader=data_loaders["train"],
        val_loader=data_loaders["val"],
        config=training_config,
        output_dir=output_dir,
        device=device
    )

    # Try to resume from checkpoint
    # NOTE(review): output_dir was created just above with a new timestamp,
    # so this file can only exist if the trainer wrote it during construction;
    # as written, resuming from an earlier run never happens. The epoch loop
    # below also always starts at 1. Confirm the intended resume behavior.
    checkpoint_path = output_dir / "latest_checkpoint.pth"
    if resume_training(trainer, checkpoint_path):
        console.print("[green]Successfully resumed training from checkpoint[/green]")

    # Create layout for training display
    layout = create_layout()

    # Start training with rich display
    console.print("\n[bold cyan]Starting training...[/bold cyan]")
    with Live(layout, refresh_per_second=4) as live:
        for epoch in range(1, training_config["epochs"] + 1):
            # Update header
            layout["header"].update(create_header())

            # Monitor and display memory usage
            memory_status = monitor_memory()
            console.print(f"\n[cyan]Memory Status: {memory_status}[/cyan]")

            # Train for one epoch
            train_loss, train_acc = trainer.train_epoch()

            # Validate
            val_loss, val_acc, predictions, targets = trainer.validate()

            # Update footer with current metrics
            layout["footer"].update(create_footer(
                epoch, training_config["epochs"],
                train_loss, train_acc,
                val_loss, val_acc,
                trainer.optimizer.param_groups[0]['lr']
            ))

            # Update metrics table
            layout["body"].update(create_metrics_table(
                train_loss, train_acc,
                val_loss, val_acc
            ))

            # Save checkpoint and visualizations
            if epoch % 5 == 0:
                trainer.save_checkpoint(epoch)
                # Monitor memory after checkpoint save
                memory_status = monitor_memory()
                console.print(f"[cyan]Memory Status after checkpoint: {memory_status}[/cyan]")

            if epoch % 10 == 0:
                trainer.plot_confusion_matrix(predictions, targets, epoch)

    # Save final model and plots
    trainer.save_model("final")
    trainer.plot_history()

    # Final memory status
    memory_status = monitor_memory()
    console.print(f"\n[cyan]Final Memory Status: {memory_status}[/cyan]")

    console.print("\n[bold green]Training completed![/bold green]")
    console.print(f"Results saved to: {output_dir}")

In [None]:
# Train model: runs the whole pipeline defined in train() (config loading,
# data loading, training loop, saving of the final model).
train()

## 7. Save Results to Google Drive

Save the trained model and results to Google Drive for persistence.

In [None]:
# Create directory for results if it doesn't exist
!mkdir -p /content/drive/MyDrive/MScProject/results

# Copy results to Google Drive
!cp -r results/* /content/drive/MyDrive/MScProject/results