<a href="https://colab.research.google.com/github/giosanchez0208/CSC173-DeepCV-Sanchez/blob/main/refine_segmentation_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Setup: Google Colab Environment

In [None]:
# Check GPU availability
import torch
import os
import sys

# Pick the compute device once and store it in the global `device`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

if device == 'cpu':
    print('No GPU available, using CPU (training will be slow)')
else:
    gpu_index = 0
    gpu_total_bytes = torch.cuda.get_device_properties(gpu_index).total_memory
    print(f'GPU available: {torch.cuda.get_device_name(gpu_index)}')
    print(f'Memory: {gpu_total_bytes / 1e9:.2f} GB')

# Report library versions so a run can be reproduced later.
print(f'\nPyTorch version: {torch.__version__}')
print(f'CUDA version: {torch.version.cuda}')

In [None]:
# Install required packages
# `!pip` is an IPython shell escape, so this cell only runs inside a notebook kernel.
!pip install -q ultralytics opencv-python-headless pillow pyyaml numpy scipy matplotlib pandas gdown

# NOTE(review): this message is printed unconditionally; the exit status of pip is not checked.
print('All packages installed successfully')

## Dataset and Model Download (Reproducible Setup)

In [None]:
import os
import gdown
import zipfile
import yaml
from pathlib import Path
import shutil

# Public Google Drive folder for the dataset
PUBLIC_DRIVE_FOLDER = 'https://drive.google.com/drive/folders/1rYbZXSwnd0DQ49VP_04C8OYkhU6b4p-t'

# Google Drive file IDs.
# Open the file's shareable link in Drive: the ID is the long string
# between /d/ and /view. Both values below are placeholders.
DATASET_ZIP_ID = 'YOUR_DATASET_ZIP_FILE_ID'  # Replace with actual file ID
PRETRAINED_MODEL_ID = 'YOUR_PRETRAINED_MODEL_FILE_ID'  # Replace with actual file ID

# Local working tree (everything lives below LOCAL_ROOT)
LOCAL_ROOT = '/content/ocr_project'
DATASET_ZIP_PATH = LOCAL_ROOT + '/dataset.zip'
LOCAL_DATASET_PATH = LOCAL_ROOT + '/dataset'
LOCAL_MODELS_PATH = LOCAL_ROOT + '/models'
LOCAL_CHECKPOINTS_PATH = LOCAL_ROOT + '/checkpoints'

# Create the root, models and checkpoints directories.
# The dataset directory is not created here.
for directory in (LOCAL_ROOT, LOCAL_MODELS_PATH, LOCAL_CHECKPOINTS_PATH):
    os.makedirs(directory, exist_ok=True)

print('Directory structure created:')
for label, location in (
    ('Root', LOCAL_ROOT),
    ('Dataset', LOCAL_DATASET_PATH),
    ('Models', LOCAL_MODELS_PATH),
    ('Checkpoints', LOCAL_CHECKPOINTS_PATH),
):
    print(f'  {label}: {location}')

In [None]:
def download_file_from_drive(file_id, output_path, retries=3):
    """Download a file from Google Drive using gdown.

    Args:
        file_id: Google Drive file ID (the string between /d/ and /view in a
            shareable link). An empty/None ID is rejected without any attempt.
        output_path: Local path the file is written to.
        retries: Maximum number of download attempts.

    Returns:
        True if a non-empty file exists at ``output_path`` after an attempt,
        False otherwise. Download errors are reported via print, not raised.
    """
    if not file_id:
        print('Download skipped: no Google Drive file ID was provided')
        return False

    url = f'https://drive.google.com/uc?id={file_id}'

    # Make sure the destination directory exists before any attempt.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    for attempt in range(1, retries + 1):
        try:
            print(f'Download attempt {attempt}/{retries} for {output_path}')
            gdown.download(url, output_path, quiet=False)
        except Exception as e:
            # Best-effort retry: report the failure and try again.
            print(f'Download attempt {attempt} failed: {e}')
        else:
            if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
                print(f'Successfully downloaded: {output_path}')
                print(f'File size: {os.path.getsize(output_path) / (1024**2):.2f} MB')
                return True
            print('Download failed: file is empty or does not exist')

        # Drop a zero-byte leftover so that a later "does the file already
        # exist?" check does not mistake it for a completed download.
        if os.path.isfile(output_path) and os.path.getsize(output_path) == 0:
            os.remove(output_path)

    return False

# Fetch the dataset archive unless a copy is already on disk.
if os.path.exists(DATASET_ZIP_PATH):
    print(f'Dataset already exists at: {DATASET_ZIP_PATH}')
else:
    print('Downloading dataset...')
    # IMPORTANT: DATASET_ZIP_ID must hold the real file ID of dataset.zip
    # from the public folder.
    success = download_file_from_drive(DATASET_ZIP_ID, DATASET_ZIP_PATH)
    if not success:
        for message in (
            'ERROR: Failed to download dataset. Please check the file ID.',
            f'Public folder: {PUBLIC_DRIVE_FOLDER}',
            'Please update DATASET_ZIP_ID with the actual file ID from the dataset.zip file',
        ):
            print(message)
        # The dataset is mandatory, so stop here.
        raise FileNotFoundError('Dataset download failed')

# Fetch the pretrained weights unless a copy is already on disk.
PRETRAINED_MODEL_PATH = f'{LOCAL_MODELS_PATH}/custom_ocr_last.pt'
if os.path.exists(PRETRAINED_MODEL_PATH):
    print(f'Pretrained model already exists at: {PRETRAINED_MODEL_PATH}')
else:
    print('\nDownloading pretrained model...')
    # IMPORTANT: PRETRAINED_MODEL_ID must hold the real file ID of
    # custom_ocr_last.pt from the public folder.
    success = download_file_from_drive(PRETRAINED_MODEL_ID, PRETRAINED_MODEL_PATH)
    if not success:
        # Missing weights are tolerated: the path is cleared instead of raising.
        print('WARNING: Failed to download pretrained model. Training will start from scratch.')
        PRETRAINED_MODEL_PATH = None

In [None]:
# Extract dataset if not already extracted.
# The dataset counts as extracted only when data.yaml is at the expected place.
if not os.path.exists(f'{LOCAL_DATASET_PATH}/data.yaml'):
    print('\nExtracting dataset...')

    if not os.path.exists(DATASET_ZIP_PATH):
        print(f'ERROR: Dataset zip not found at {DATASET_ZIP_PATH}')
        raise FileNotFoundError('Dataset zip file not found')

    try:
        with zipfile.ZipFile(DATASET_ZIP_PATH, 'r') as zip_ref:
            # Remember which top-level entries this archive creates, so a
            # "flat" archive can be relocated without touching other folders
            # under LOCAL_ROOT (models, checkpoints, the zip itself).
            archive_top_level = sorted({
                name.lstrip('/').split('/', 1)[0]
                for name in zip_ref.namelist()
                if name.strip('/')
            })
            zip_ref.extractall(LOCAL_ROOT)

        root_dir = Path(LOCAL_ROOT)
        target_dir = Path(LOCAL_DATASET_PATH)

        if not (target_dir / 'data.yaml').exists():
            # Prefer the shallowest data.yaml so the choice is deterministic.
            candidates = sorted(
                root_dir.glob('**/data.yaml'),
                key=lambda p: (len(p.parts), str(p)),
            )

            if candidates:
                dataset_parent = candidates[0].parent
                print(f'Moving dataset from {dataset_parent} to {LOCAL_DATASET_PATH}')

                if dataset_parent == root_dir:
                    # Flat archive: data.yaml was extracted straight into
                    # LOCAL_ROOT. A directory cannot be moved into itself, so
                    # move the archive's own entries one by one instead.
                    target_dir.mkdir(parents=True, exist_ok=True)
                    for entry in archive_top_level:
                        src = root_dir / entry
                        if src == target_dir or not src.exists():
                            continue
                        dest = target_dir / entry
                        if dest.is_dir():
                            shutil.rmtree(dest)
                        elif dest.exists():
                            dest.unlink()
                        shutil.move(str(src), str(dest))
                else:
                    # Stage the source first: it may live inside the target
                    # directory, which is about to be removed.
                    staging_dir = root_dir / '.dataset_staging'
                    if staging_dir.exists():
                        shutil.rmtree(staging_dir)
                    shutil.move(str(dataset_parent), str(staging_dir))
                    if target_dir.exists():
                        shutil.rmtree(target_dir)
                    shutil.move(str(staging_dir), str(target_dir))

        print(f'Dataset extracted to: {LOCAL_DATASET_PATH}')

    except Exception as e:
        # Report and re-raise: later cells cannot run without the dataset.
        print(f'ERROR: Failed to extract dataset: {e}')
        raise
else:
    print(f'Dataset already extracted at: {LOCAL_DATASET_PATH}')

# Verify data.yaml exists
DATA_YAML_PATH = f'{LOCAL_DATASET_PATH}/data.yaml'
if not os.path.exists(DATA_YAML_PATH):
    print(f'ERROR: data.yaml not found at {DATA_YAML_PATH}')
    print('Looking for data.yaml in extracted files...')

    # Fall back to the first data.yaml found anywhere below LOCAL_ROOT.
    for root, dirs, files in os.walk(LOCAL_ROOT):
        if 'data.yaml' in files:
            DATA_YAML_PATH = os.path.join(root, 'data.yaml')
            print(f'Found data.yaml at: {DATA_YAML_PATH}')
            break

    if not os.path.exists(DATA_YAML_PATH):
        raise FileNotFoundError(f'data.yaml not found in {LOCAL_ROOT}')

print(f'\nData configuration: {DATA_YAML_PATH}')
print(f'Pretrained model: {PRETRAINED_MODEL_PATH if PRETRAINED_MODEL_PATH else "None (starting from scratch)"}')

In [None]:
# Sanity check: confirm the expected split folders exist and data.yaml is usable
print('\n=== DATASET SANITY CHECK ===')

# train/val/test, each with an images and a labels folder
required_folders = [
    f'{split}/{kind}'
    for split in ('train', 'val', 'test')
    for kind in ('images', 'labels')
]

all_good = True
for folder in required_folders:
    folder_path = f'{LOCAL_DATASET_PATH}/{folder}'
    if not os.path.exists(folder_path):
        print(f'✗ {folder}: NOT FOUND')
        all_good = False
        continue
    file_count = len(os.listdir(folder_path))
    print(f'✓ {folder}: {file_count} files')

# Inspect data.yaml and repoint it at the local dataset when needed
try:
    with open(DATA_YAML_PATH, 'r') as f:
        data_config = yaml.safe_load(f)

    print(f'\ndata.yaml content:')
    print(f'  Classes: {len(data_config.get("names", []))}')
    print(f'  Path: {data_config.get("path", "Not specified")}')
    print(f'  Train: {data_config.get("train", "Not specified")}')
    print(f'  Val: {data_config.get("val", "Not specified")}')

    # A different root means the file was written for another machine:
    # rewrite the root and the three split locations together.
    if data_config.get('path') != LOCAL_DATASET_PATH:
        print(f'\nFixing data.yaml paths...')
        data_config.update({
            'path': LOCAL_DATASET_PATH,
            'train': 'train/images',
            'val': 'val/images',
            'test': 'test/images',
        })

        with open(DATA_YAML_PATH, 'w') as f:
            yaml.dump(data_config, f, default_flow_style=False)
        print(f'  Updated paths in data.yaml')

except Exception as e:
    # Any problem reading or rewriting the file marks the dataset as broken.
    print(f'ERROR reading data.yaml: {e}')
    all_good = False

if not all_good:
    print('\n✗ Dataset has issues. Please check the structure.')
    raise ValueError('Dataset structure is incorrect')

print('\n✓ All dataset checks passed')

## Enhanced Checkpoint Management System

In [None]:
import pandas as pd
import json
import time
import threading
from datetime import datetime
from pathlib import Path

class CheckpointManager:
    """
    Per-experiment checkpoint bookkeeping.

    On every ``save_progress`` call the manager:
    1. Appends one row of metrics to a local CSV
    2. Copies the given model file to local ``last_model.pt`` (and
       ``best_model.pt`` when flagged as best)
    3. Mirrors the CSV and model files to a Drive directory in background
       threads, when a Drive directory is configured

    ``get_last_checkpoint`` / ``load_progress`` read that state back.
    """

    # Columns written to a freshly created progress CSV.
    CSV_COLUMNS = [
        'epoch', 'timestamp',
        'train/cls_loss', 'val/cls_loss',
        'train/seg_loss', 'val/seg_loss',
        'train/box_loss', 'val/box_loss',
        'metrics/precision(M)', 'metrics/recall(M)',
        'metrics/mAP50(M)', 'metrics/mAP50-95(M)',
        'learning_rate', 'phase',
        'ocr_char_accuracy', 'ocr_top2_accuracy', 'ocr_top3_accuracy'
    ]

    def __init__(self, experiment_name, local_checkpoint_dir, drive_checkpoint_dir=None):
        """
        Initialize checkpoint manager.

        Args:
            experiment_name: Name for this training run (used as sub-directory)
            local_checkpoint_dir: Local directory for fast checkpoint access
            drive_checkpoint_dir: Google Drive directory for permanent storage (optional)
        """
        self.experiment_name = experiment_name
        self.local_dir = Path(local_checkpoint_dir) / experiment_name
        self.drive_dir = Path(drive_checkpoint_dir) / experiment_name if drive_checkpoint_dir else None

        # Create directories
        self.local_dir.mkdir(parents=True, exist_ok=True)
        if self.drive_dir:
            self.drive_dir.mkdir(parents=True, exist_ok=True)

        # File paths
        self.csv_path = self.local_dir / 'training_progress.csv'
        self.config_path = self.local_dir / 'training_config.json'
        self.best_model_local = self.local_dir / 'best_model.pt'
        self.last_model_local = self.local_dir / 'last_model.pt'

        # Drive paths only exist when a Drive directory was supplied.
        if self.drive_dir:
            self.best_model_drive = self.drive_dir / 'best_model.pt'
            self.last_model_drive = self.drive_dir / 'last_model.pt'
            self.csv_drive = self.drive_dir / 'training_progress.csv'

        # Initialize CSV if it doesn't exist
        self._initialize_csv()

        print(f'Checkpoint Manager initialized:')
        print(f'  Experiment: {experiment_name}')
        print(f'  Local directory: {self.local_dir}')
        if self.drive_dir:
            print(f'  Drive directory: {self.drive_dir}')

    def _initialize_csv(self):
        """Create a header-only progress CSV unless one already exists."""
        if not self.csv_path.exists():
            pd.DataFrame(columns=self.CSV_COLUMNS).to_csv(self.csv_path, index=False)

    def save_config(self, config):
        """Write ``config`` as JSON locally and, if configured, to Drive."""
        with open(self.config_path, 'w') as f:
            json.dump(config, f, indent=2)

        # Also save to drive if available
        if self.drive_dir:
            drive_config_path = self.drive_dir / 'training_config.json'
            with open(drive_config_path, 'w') as f:
                json.dump(config, f, indent=2)

    def save_progress(self, epoch, metrics, model_path=None, is_best=False):
        """
        Save training progress for an epoch.

        Args:
            epoch: Current epoch number
            metrics: Dictionary of metrics (not modified)
            model_path: Path to model file to save; ignored if it does not exist
            is_best: Whether this is the best model so far
        """
        # Work on a copy so the caller's dict is left untouched.
        record = dict(metrics)
        record['epoch'] = epoch
        record['timestamp'] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        # Load existing CSV and append new row
        history = pd.read_csv(self.csv_path) if self.csv_path.exists() else pd.DataFrame()
        new_row = pd.DataFrame([record])
        if history.empty:
            # Concatenating with an empty frame is deprecated in pandas;
            # build the same column layout (existing header first) directly.
            ordered_columns = list(history.columns) + [
                col for col in new_row.columns if col not in history.columns
            ]
            history = new_row.reindex(columns=ordered_columns)
        else:
            history = pd.concat([history, new_row], ignore_index=True)

        # Save to local CSV
        history.to_csv(self.csv_path, index=False)

        # Save model if provided
        if model_path and Path(model_path).exists():
            # Save last model
            shutil.copy2(model_path, self.last_model_local)

            # Save best model if applicable
            if is_best:
                shutil.copy2(model_path, self.best_model_local)
                print(f'  [Checkpoint] New best model saved (epoch {epoch})')

            # Async save to Google Drive
            if self.drive_dir:
                self._async_save_to_drive(model_path, is_best)

        # Async save CSV to Google Drive
        if self.drive_dir:
            self._async_save_csv_to_drive()

        print(f'[Checkpoint] Progress saved for epoch {epoch}')

    def _async_save_to_drive(self, model_path, is_best):
        """Copy the model file to Drive in a background daemon thread."""
        def save_task():
            try:
                # Save last model
                shutil.copy2(model_path, self.last_model_drive)

                # Save best model if applicable
                if is_best:
                    shutil.copy2(model_path, self.best_model_drive)

                print(f'  [Checkpoint] Model backed up to Drive')
            except Exception as e:
                # Drive backup is best-effort; training must not stop for it.
                print(f'  [Checkpoint] Warning: Failed to save model to Drive: {e}')

        # Start async save
        thread = threading.Thread(target=save_task, daemon=True)
        thread.start()

    def _async_save_csv_to_drive(self):
        """Copy the progress CSV to Drive in a background daemon thread."""
        def save_task():
            try:
                shutil.copy2(self.csv_path, self.csv_drive)
            except Exception as e:
                # Best-effort, same as the model backup.
                print(f'  [Checkpoint] Warning: Failed to save CSV to Drive: {e}')

        thread = threading.Thread(target=save_task, daemon=True)
        thread.start()

    def get_last_checkpoint(self):
        """
        Get information about the last checkpoint.

        Returns:
            Dict with keys ``exists``, ``last_epoch``, ``best_model_path`` and
            ``last_model_path``. Only the local directory is inspected.
        """
        info = {
            'exists': False,
            'last_epoch': 0,
            'best_model_path': None,
            'last_model_path': None
        }

        # Check local first
        if self.csv_path.exists():
            df = pd.read_csv(self.csv_path)
            if len(df) > 0:
                info['exists'] = True
                # Tolerate a missing/blank epoch cell instead of crashing.
                if 'epoch' in df.columns and pd.notna(df['epoch'].iloc[-1]):
                    info['last_epoch'] = int(df['epoch'].iloc[-1])

                # Check for model files
                if self.last_model_local.exists():
                    info['last_model_path'] = str(self.last_model_local)
                if self.best_model_local.exists():
                    info['best_model_path'] = str(self.best_model_local)

        return info

    def load_progress(self):
        """Return the progress CSV as a DataFrame (empty if no CSV exists)."""
        if self.csv_path.exists():
            return pd.read_csv(self.csv_path)
        return pd.DataFrame()

# Initialize checkpoint manager
# NOTE(review): the experiment name embeds the current timestamp, so each run
# gets its own new checkpoint sub-directory. A later get_last_checkpoint()
# call on this manager will presumably never see a previous run's files —
# confirm whether a stable experiment name is intended for resuming.
experiment_name = f'refine_{datetime.now().strftime("%Y%m%d_%H%M%S")}'
checkpoint_manager = CheckpointManager(
    experiment_name=experiment_name,
    local_checkpoint_dir=LOCAL_CHECKPOINTS_PATH,
    drive_checkpoint_dir=None  # Will be set after mounting Drive
)

print('\nCheckpoint system ready. CSV will be saved every epoch.')

## Mount Google Drive for Permanent Storage (Optional)

In [None]:
# Mount Google Drive if available
DRIVE_CHECKPOINT_DIR = None
try:
    from google.colab import drive
    drive.mount('/content/drive')

    # Prepare the Drive directory for this experiment before touching the
    # manager, so a failure here leaves the manager in local-only mode.
    drive_root = '/content/drive/MyDrive/ocr_checkpoints'
    drive_experiment_dir = Path(drive_root) / experiment_name
    drive_experiment_dir.mkdir(parents=True, exist_ok=True)  # also creates drive_root

    # Wire the Drive paths into the checkpoint manager. drive_dir is assigned
    # last because it is the switch that enables Drive backups.
    checkpoint_manager.best_model_drive = drive_experiment_dir / 'best_model.pt'
    checkpoint_manager.last_model_drive = drive_experiment_dir / 'last_model.pt'
    checkpoint_manager.csv_drive = drive_experiment_dir / 'training_progress.csv'
    checkpoint_manager.drive_dir = drive_experiment_dir

    DRIVE_CHECKPOINT_DIR = drive_root

    print(f'Google Drive mounted successfully')
    print(f'Checkpoints will be saved to: {DRIVE_CHECKPOINT_DIR}')

except Exception as e:
    # Drive is optional (e.g. not running on Colab): continue local-only.
    print(f'Note: Google Drive not mounted. Checkpoints will only be saved locally.')
    print(f'Error: {e}')

# Save training configuration
training_config = {
    'experiment_name': experiment_name,
    'dataset_path': LOCAL_DATASET_PATH,
    'pretrained_model': PRETRAINED_MODEL_PATH,
    'data_yaml': DATA_YAML_PATH,
    'device': device,
    'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S")
}

checkpoint_manager.save_config(training_config)
print(f'Training configuration saved')

## Core Components (Reused from Original Training)

In [None]:
# Character set and similarity matrix (from original)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Alphabet: uppercase A-Z followed by the digits 0-9
CHARS = [chr(code) for code in range(ord('A'), ord('Z') + 1)] + list('0123456789')
NUM_CLASSES = len(CHARS)
CHAR_TO_IDX = dict(zip(CHARS, range(NUM_CLASSES)))
IDX_TO_CHAR = dict(enumerate(CHARS))

print(f'Number of classes: {NUM_CLASSES}')
print(f'Characters: {"".join(CHARS)}')

# Groups of characters treated as visually similar to each other
SIMILAR_GROUPS = [
    list(pair)
    for pair in ('O0', 'I1', 'S5', 'Z2', 'B8', 'D0', 'GC', 'UV', 'PR')
]


def create_similarity_matrix(num_classes=NUM_CLASSES, groups=SIMILAR_GROUPS, base_sim=0.6):
    """Build a (num_classes, num_classes) float32 similarity tensor.

    The diagonal is 1.0, every ordered pair of distinct characters that share
    a group gets ``base_sim``, and all other entries are 0.0. Characters not
    present in ``CHAR_TO_IDX`` are ignored.
    """
    matrix = np.eye(num_classes, dtype=np.float32)
    for members in groups:
        member_idx = [CHAR_TO_IDX[ch] for ch in members if ch in CHAR_TO_IDX]
        # All ordered (row, col) pairs of distinct members of this group
        pairs = [(r, c) for r in member_idx for c in member_idx if r != c]
        if pairs:
            rows, cols = zip(*pairs)
            matrix[list(rows), list(cols)] = base_sim
    return torch.tensor(matrix, dtype=torch.float32)


similarity_matrix = create_similarity_matrix()
print(f'Similarity matrix initialized: {similarity_matrix.shape}')

In [None]:
# Enhanced Similarity-Aware Loss with Adaptive Weighting
class RefinedSimilarityAwareTopKLoss(nn.Module):
    """
    Cross-entropy combined with a similarity-aware top-k penalty.

    For each sample the loss is a normalized, confidence-dependent blend of:
    - standard cross-entropy on the raw logits, and
    - a penalty over the top-k tempered-softmax predictions, where each
      prediction is weighted by ``1.5 * (1 - similarity(target, prediction))``.

    The softmax temperature is annealed with the epoch (see get_temperature).
    """
    def __init__(self, num_classes=NUM_CLASSES, similarity_matrix=None,
                 k=3, initial_temperature=0.5, base_weight=0.5, topk_weight=0.5,
                 epochs=40):
        """
        Args:
            num_classes: Number of character classes.
            similarity_matrix: (num_classes, num_classes) tensor; built with
                create_similarity_matrix() when omitted.
            k: Number of top predictions entering the similarity penalty.
            initial_temperature: Softmax temperature at epoch 0.
            base_weight: Base weight of the cross-entropy term.
            topk_weight: Base weight of the similarity term.
            epochs: Total number of epochs, used for temperature annealing.
        """
        super().__init__()
        self.num_classes = num_classes
        self.k = k
        self.initial_temperature = initial_temperature
        self.base_weight = base_weight
        self.topk_weight = topk_weight
        self.epochs = epochs
        self.current_epoch = 0

        # Registered as a buffer so it follows the module across devices.
        if similarity_matrix is not None:
            self.register_buffer('similarity_matrix', similarity_matrix)
        else:
            self.register_buffer('similarity_matrix', create_similarity_matrix())

    def update_epoch(self, epoch):
        """Update current epoch for temperature annealing."""
        self.current_epoch = epoch

    def get_temperature(self):
        """Return the temperature for the current epoch.

        Decreases linearly by 0.2 over ``epochs`` and never drops below 0.3.
        """
        progress = self.current_epoch / max(self.epochs, 1)
        return max(0.3, self.initial_temperature - progress * 0.2)

    def forward(self, logits, targets):
        """Compute the mean blended loss over the batch.

        Args:
            logits: (batch, classes) raw scores.
            targets: (batch,) integer class indices. Entries outside
                [0, num_classes) contribute no similarity penalty.

        Returns:
            Scalar tensor.
        """
        temperature = self.get_temperature()

        # Standard cross-entropy
        ce_loss = F.cross_entropy(logits, targets, reduction='none')

        # Softmax with temperature
        probs = F.softmax(logits / temperature, dim=1)
        topk_probs, topk_indices = torch.topk(probs, min(self.k, self.num_classes), dim=1)

        # Samples whose target can be looked up in the similarity matrix.
        valid = (targets >= 0) & (targets < self.num_classes)
        if not bool(valid.any()):
            return ce_loss.mean()

        # Clamp only to keep the lookup in range; invalid rows are zeroed below.
        safe_targets = targets.clamp(min=0, max=self.num_classes - 1)
        sims = self.similarity_matrix[safe_targets].gather(1, topk_indices)

        # Similarity-aware penalty: dissimilar predictions cost more (x1.5).
        penalties = (1.0 - sims) * 1.5
        sim_loss = (topk_probs * penalties).sum(dim=1)
        sim_loss = torch.where(valid, sim_loss, torch.zeros_like(sim_loss))

        # Top-1 tempered probability; detached so the blend weights act as
        # constants during backpropagation. Computed for every sample so the
        # weights always match the batch size.
        confidence = topk_probs[:, 0].detach()

        # The CE weight grows as confidence drops; the similarity weight grows
        # with confidence.
        # NOTE(review): earlier comments on this arithmetic described the
        # opposite relationship (confident -> rely on CE); the arithmetic is
        # kept as-is — confirm the intended direction.
        adaptive_base = self.base_weight + (1 - confidence) * 0.2
        adaptive_topk = self.topk_weight + confidence * 0.2

        # Normalize so the two weights sum to 1 per sample
        total_weight = adaptive_base + adaptive_topk
        adaptive_base = adaptive_base / total_weight
        adaptive_topk = adaptive_topk / total_weight

        total_loss = adaptive_base * ce_loss + adaptive_topk * sim_loss
        return total_loss.mean()

print('Refined similarity-aware loss defined')

In [None]:
# OCR Metrics (reused from original)
class OCRMetrics:
    """Running character-level OCR statistics accumulated across batches."""

    def __init__(self, similarity_matrix=None):
        # Fall back to the default similarity matrix when none is given.
        if similarity_matrix is None:
            similarity_matrix = create_similarity_matrix()
        self.similarity_matrix = similarity_matrix
        self.reset()

    def reset(self):
        """Zero every counter."""
        self.total_chars = self.correct_chars = 0
        self.top2_correct = self.top3_correct = 0
        self.similarity_score = 0.0

    def update(self, predictions, targets, top_k_preds=None):
        """Fold one batch into the counters.

        Args:
            predictions: tensor of predicted class indices.
            targets: tensor of ground-truth class indices.
            top_k_preds: optional 2-D tensor of ranked class indices per
                sample; enables the top-2 / top-3 counters.
        """
        pred_arr = predictions.cpu().numpy()
        target_arr = targets.cpu().numpy()

        self.total_chars += len(target_arr)
        self.correct_chars += (pred_arr == target_arr).sum()

        # Similarity-aware accuracy: add similarity(target, prediction) for
        # every pair whose indices fall inside the matrix.
        n_classes = len(self.similarity_matrix)
        for guess, truth in zip(pred_arr, target_arr):
            if not (0 <= truth < n_classes and 0 <= guess < n_classes):
                continue
            self.similarity_score += self.similarity_matrix[truth][guess].item()

        if top_k_preds is None:
            return

        # Top-k accuracy: count targets appearing among the first 2 / 3 ranks.
        ranked = top_k_preds.cpu().numpy()
        for row, truth in enumerate(target_arr):
            width = ranked.shape[1]
            if width >= 2 and truth in ranked[row, :2]:
                self.top2_correct += 1
            if width >= 3 and truth in ranked[row, :3]:
                self.top3_correct += 1

    def compute(self):
        """Return each counter divided by the number of characters seen.

        Returns an empty dict when nothing has been accumulated.
        """
        total = self.total_chars
        if total == 0:
            return {}

        counters = (
            ('ocr_char_accuracy', self.correct_chars),
            ('ocr_top2_accuracy', self.top2_correct),
            ('ocr_top3_accuracy', self.top3_correct),
            ('ocr_similarity_aware_accuracy', self.similarity_score),
        )
        return {name: count / total for name, count in counters}

print('OCR metrics module loaded')

## Enhanced Trainer with Checkpoint Integration

In [None]:
# Custom Trainer for Refined Training
from ultralytics.models.yolo.segment import SegmentationTrainer
from ultralytics import YOLO

class RefinedSegmentationTrainer(SegmentationTrainer):
    """
    Refined trainer with:
    - Progressive layer unfreezing
    - Enhanced loss function
    - OCR-specific metrics tracking
    - Integrated checkpoint management

    NOTE(review): the on_* methods and compute_loss below call the parent
    implementation via super(). Whether the installed Ultralytics trainer
    actually defines and invokes these as methods (rather than dispatching
    epoch/validation events through registered callbacks) is not visible
    here — TODO confirm, otherwise this logic never runs.
    """
    
    def __init__(self, cfg=None, overrides=None, _callbacks=None, checkpoint_manager=None):
        """
        Args:
            cfg, overrides, _callbacks: forwarded unchanged to the parent trainer.
            checkpoint_manager: optional object exposing save_progress(...);
                when None, no per-epoch checkpointing is done by this class.
        """
        super().__init__(cfg, overrides, _callbacks)
        
        self.checkpoint_manager = checkpoint_manager
        
        # Get total epochs from config
        total_epochs = self.args.epochs if hasattr(self.args, 'epochs') else 40

        # Initialize refined loss
        # `device` and `similarity_matrix` are module-level globals, not attributes.
        self.character_loss_fn = RefinedSimilarityAwareTopKLoss(
            num_classes=NUM_CLASSES,
            similarity_matrix=similarity_matrix,
            k=3,
            initial_temperature=0.5,
            base_weight=0.5,
            topk_weight=0.5,
            epochs=total_epochs
        ).to(device)

        # OCR metrics
        self.ocr_metrics = OCRMetrics(similarity_matrix=similarity_matrix)

        # Training phase tracking
        self.phase = 1
        self.freeze_applied = False
        
        # Track best loss for checkpoint saving
        self.best_cls_loss = float('inf')
        
        print(f'Trainer initialized with {total_epochs} total epochs')
        
    def _setup_train(self, world_size):
        """Override to apply layer freezing for Phase 1.

        Runs the parent setup first, then (once, and only while epoch < 12)
        disables gradients for every parameter whose name contains neither
        'cls' nor 'cv3'.
        """
        super()._setup_train(world_size)

        if not self.freeze_applied and self.epoch < 12:
            print(f'\nPHASE 1: Classifier Head Fine-Tuning (Epochs 1-12)')
            print('Freezing backbone and segmentation layers...')

            # Freeze all layers except classification head
            # NOTE(review): selection is by substring of the parameter name;
            # confirm 'cls'/'cv3' match the intended layers of the loaded model.
            for name, param in self.model.named_parameters():
                # Keep classification layers trainable
                if 'cls' in name.lower() or 'cv3' in name.lower():
                    param.requires_grad = True
                else:
                    param.requires_grad = False

            trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
            total = sum(p.numel() for p in self.model.parameters())
            print(f'Trainable parameters: {trainable:,} / {total:,} ({100*trainable/total:.1f}%)')
            self.freeze_applied = True

    def on_train_epoch_start(self):
        """Handle phase transitions and progressive unfreezing.

        Epoch 12 switches to phase 2 (parameters named '*seg*'/'*mask*' become
        trainable); epoch 24 switches to phase 3 (all parameters trainable).
        """
        super().on_train_epoch_start()

        # Update temperature in loss
        self.character_loss_fn.update_epoch(self.epoch)

        # Phase 2: Progressive unfreezing (epochs 12-24)
        # NOTE(review): transitions fire only when self.epoch equals exactly
        # 12 or 24, so a run that starts beyond those epochs skips them.
        if self.epoch == 12:
            self.phase = 2
            print(f'\nPHASE 2: Progressive Unfreezing (Epochs 13-24)')
            print('Unfreezing segmentation head...')

            for name, param in self.model.named_parameters():
                if 'seg' in name.lower() or 'mask' in name.lower():
                    param.requires_grad = True

            trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
            total = sum(p.numel() for p in self.model.parameters())
            print(f'Trainable parameters: {trainable:,} / {total:,} ({100*trainable/total:.1f}%)')

        # Phase 3: Full fine-tuning (epochs 24+)
        elif self.epoch == 24:
            self.phase = 3
            print(f'\nPHASE 3: Full Fine-Tuning (Epochs 25-40)')
            print('Unfreezing all layers...')

            for param in self.model.parameters():
                param.requires_grad = True

            trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
            total = sum(p.numel() for p in self.model.parameters())
            print(f'Trainable parameters: {trainable:,} / {total:,} ({100*trainable/total:.1f}%)')

    def on_train_epoch_end(self):
        """Save checkpoint after each epoch.

        "Best" means the lowest 'val/cls_loss' seen so far in this object.
        """
        super().on_train_epoch_end()
        
        if self.checkpoint_manager:
            # Collect metrics
            metrics = self._collect_metrics()
            
            # Determine if this is the best model
            # NOTE(review): _collect_metrics defaults 'val/cls_loss' to 0.0
            # when the validator dict lacks that key; in that case the first
            # epoch would be "best" and no later epoch could beat it — confirm
            # the key is actually present.
            current_cls_loss = metrics.get('val/cls_loss', float('inf'))
            is_best = current_cls_loss < self.best_cls_loss
            
            if is_best:
                self.best_cls_loss = current_cls_loss
            
            # Save progress
            # The model file is passed on only if it already exists on disk.
            model_path = Path(self.save_dir) / 'weights' / 'last.pt'
            self.checkpoint_manager.save_progress(
                epoch=self.epoch,
                metrics=metrics,
                model_path=str(model_path) if model_path.exists() else None,
                is_best=is_best
            )

    def _collect_metrics(self):
        """Collect metrics from training.

        Returns:
            Dict with phase, learning rate, available train losses, validator
            metrics (when a validator with results_dict exists) and the
            current OCR metrics.
        """
        metrics = {
            'phase': self.phase,
            'learning_rate': self.optimizer.param_groups[0]['lr'] if hasattr(self, 'optimizer') else 0.0,
        }
        
        # Get loss values
        # NOTE(review): indices 0/1/2 are labelled cls/seg/box here; verify
        # this matches the order in which the parent trainer reports
        # loss_items — TODO confirm.
        if hasattr(self, 'loss_items') and self.loss_items is not None:
            if len(self.loss_items) > 0:
                metrics['train/cls_loss'] = float(self.loss_items[0])
            if len(self.loss_items) > 1:
                metrics['train/seg_loss'] = float(self.loss_items[1])
            if len(self.loss_items) > 2:
                metrics['train/box_loss'] = float(self.loss_items[2])
        
        # Get validation metrics from validator
        # Missing keys fall back to 0.0.
        if hasattr(self, 'validator') and hasattr(self.validator, 'metrics'):
            val_metrics = self.validator.metrics
            if hasattr(val_metrics, 'results_dict'):
                val_dict = val_metrics.results_dict
                metrics.update({
                    'val/cls_loss': val_dict.get('val/cls_loss', 0.0),
                    'val/seg_loss': val_dict.get('val/seg_loss', 0.0),
                    'val/box_loss': val_dict.get('val/box_loss', 0.0),
                    'metrics/precision(M)': val_dict.get('metrics/precision(M)', 0.0),
                    'metrics/recall(M)': val_dict.get('metrics/recall(M)', 0.0),
                    'metrics/mAP50(M)': val_dict.get('metrics/mAP50(M)', 0.0),
                    'metrics/mAP50-95(M)': val_dict.get('metrics/mAP50-95(M)', 0.0),
                })
        
        # Add OCR metrics
        ocr_results = self.ocr_metrics.compute()
        metrics.update(ocr_results)
        
        return metrics

    def on_val_start(self):
        """Reset the OCR counters before validation."""
        super().on_val_start()
        self.ocr_metrics.reset()

    def on_val_end(self):
        """Print the accumulated OCR metrics, if any."""
        super().on_val_end()

        # Log OCR metrics
        ocr_results = self.ocr_metrics.compute()
        if ocr_results:
            print(f'\n[Epoch {self.epoch}] OCR Metrics:')
            for key, value in ocr_results.items():
                print(f'  {key}: {value:.4f}')

    def compute_loss(self, preds, batch):
        """Compute loss with refined similarity-aware classification.

        Blends the parent's loss with the character loss using a
        phase-dependent weight (0.7 / 0.5 / 0.3 for phases 1 / 2 / 3). Falls
        back to the parent's loss when preds has no element at index 3 or no
        target is >= 0.
        """
        # Get base YOLO losses
        base_loss = super().compute_loss(preds, batch)

        # Add custom similarity-aware character classification loss
        # NOTE(review): assumes preds[3] holds class logits reshapeable to
        # (-1, NUM_CLASSES) and aligned one-to-one with batch['cls'] — TODO confirm.
        if len(preds) > 3:
            cls_logits = preds[3]
            cls_targets = batch['cls'].long()

            if cls_logits is not None and cls_targets is not None:
                cls_logits_flat = cls_logits.view(-1, NUM_CLASSES)
                cls_targets_flat = cls_targets.view(-1)

                # Negative targets are excluded from the character loss.
                valid_mask = cls_targets_flat >= 0
                if valid_mask.sum() > 0:
                    # Compute refined similarity-aware loss
                    char_loss = self.character_loss_fn(
                        cls_logits_flat[valid_mask],
                        cls_targets_flat[valid_mask]
                    )

                    # Update OCR metrics
                    # The counters are updated on every call of this method.
                    with torch.no_grad():
                        preds_cls = cls_logits_flat[valid_mask].argmax(dim=1)
                        top_k_preds = torch.topk(cls_logits_flat[valid_mask], k=3, dim=1)[1]
                        self.ocr_metrics.update(
                            preds_cls,
                            cls_targets_flat[valid_mask],
                            top_k_preds
                        )

                    # Phase-dependent weighting
                    if self.phase == 1:
                        # Phase 1: Heavy emphasis on classification
                        cls_weight = 0.7
                    elif self.phase == 2:
                        # Phase 2: Balanced
                        cls_weight = 0.5
                    else:
                        # Phase 3: Standard weighting
                        cls_weight = 0.3

                    total_loss = (1 - cls_weight) * base_loss + cls_weight * char_loss
                    return total_loss

        return base_loss

print('Refined segmentation trainer defined with checkpoint integration')

## Load Pretrained Model and Configure Training

In [None]:
# Choose the starting weights (checkpoint vs. pretrained) and build the model
from ultralytics import YOLO

# Ask the checkpoint manager what the previous run (if any) left behind
checkpoint_info = checkpoint_manager.get_last_checkpoint()
resume_available = bool(checkpoint_info['exists'] and checkpoint_info['last_model_path'])

# Resuming is opt-in: only an explicit 'y' answer continues from the checkpoint
wants_resume = False
if resume_available:
    print(f'Found previous checkpoint at epoch {checkpoint_info["last_epoch"]}')
    print(f'Last model: {checkpoint_info["last_model_path"]}')
    answer = input('Do you want to resume training from checkpoint? (y/n): ').lower().strip()
    wants_resume = answer == 'y'

RESUME_TRAINING = wants_resume
model_path = checkpoint_info['last_model_path'] if RESUME_TRAINING else PRETRAINED_MODEL_PATH

if RESUME_TRAINING:
    print(f'Resuming from checkpoint: {model_path}')
elif resume_available:
    print(f'Starting fresh training from: {model_path}')
else:
    print(f'No checkpoint found. Starting fresh training from: {model_path}')

# Try the selected weights first; `model` stays None when they cannot be used
model = None
if model_path and os.path.exists(model_path):
    try:
        model = YOLO(model_path)
    except Exception as e:
        print(f'ERROR loading model {model_path}: {e}')
    else:
        print('Model loaded successfully')
else:
    print(f'Model path not found: {model_path}')

# Shared fallback for both failure modes above (missing file or load error)
if model is None:
    print('Loading default YOLO model...')
    model = YOLO('yolo11n-seg.pt')

# Attach the custom trainer class to the model object
model.trainer = RefinedSegmentationTrainer

print('\nModel ready for training')

## Training Configuration

In [None]:
# Hyperparameters for the refinement run, grouped by concern

# Run size
REFINE_EPOCHS, BATCH_SIZE, IMG_SIZE = 40, 16, 224

# Learning rate: initial value and the final value shown in the summary below
LR0, LRF = 0.005, 0.0001

# Optimizer
MOMENTUM, WEIGHT_DECAY, WARMUP_EPOCHS = 0.937, 5e-4, 3.0

# Active augmentations: colour jitter, random erasing, mild geometric distortion
AUG_HSV_H, AUG_HSV_S, AUG_HSV_V = 0.02, 0.8, 0.5
AUG_ERASING, AUG_DEGREES, AUG_SHEAR = 0.5, 5.0, 2.0

# Switched off (not useful for OCR)
AUG_FLIPLR = AUG_MOSAIC = AUG_MIXUP = 0.0

# Emit the whole summary in one call; the text matches a line-by-line print
_summary_lines = [
    'Refined Training Configuration:',
    f'  Epochs: {REFINE_EPOCHS}',
    f'  Batch size: {BATCH_SIZE}',
    f'  Learning rate: {LR0} -> {LRF}',
    '  Augmentations: Enhanced HSV + Erasing + Geometric',
    '\nTraining Strategy:',
    '  Phase 1 (1-12): Classifier head only',
    '  Phase 2 (13-24): + Segmentation head',
    '  Phase 3 (25-40): All layers',
]
print('\n'.join(_summary_lines))

## Execute Refined Training

In [None]:
import datetime

# Keyword settings handed to the trainer as its `cfg`; values come from the
# configuration cell and from the resume decision made when the model was loaded.
train_params = dict(
    data=DATA_YAML_PATH,
    epochs=REFINE_EPOCHS,
    batch=BATCH_SIZE,
    imgsz=IMG_SIZE,

    # Optimizer
    optimizer='SGD',
    lr0=LR0,
    lrf=LRF,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY,

    # Warmup
    warmup_epochs=WARMUP_EPOCHS,
    warmup_momentum=0.8,
    warmup_bias_lr=0.1,

    # Augmentations
    hsv_h=AUG_HSV_H,
    hsv_s=AUG_HSV_S,
    hsv_v=AUG_HSV_V,
    erasing=AUG_ERASING,
    degrees=AUG_DEGREES,
    shear=AUG_SHEAR,
    fliplr=AUG_FLIPLR,
    mosaic=AUG_MOSAIC,
    mixup=AUG_MIXUP,

    # Output settings
    project='refined_training',
    name=experiment_name,
    exist_ok=True,

    # Validation and saving
    val=True,
    save=True,
    save_period=5,  # Save local copy every 5 epochs

    # System
    device=device,
    amp=True,
    seed=42,
    deterministic=True,

    # Resume handling
    resume=RESUME_TRAINING,
)

# Create a custom trainer instance with checkpoint manager
trainer = RefinedSegmentationTrainer(
    cfg=train_params,
    checkpoint_manager=checkpoint_manager
)

# The banner and timestamp format live in variables on purpose: reusing the
# enclosing quote character inside an f-string replacement field is only legal
# from Python 3.12 (PEP 701) and is a SyntaxError on earlier interpreters.
banner = '=' * 80
timestamp_format = '%Y-%m-%d %H:%M:%S'

print(f'\n{banner}')
if RESUME_TRAINING:
    print('RESUMING TRAINING FROM CHECKPOINT')
else:
    print('STARTING FRESH TRAINING')
print(f'{banner}\n')
print(f'Start time: {datetime.datetime.now().strftime(timestamp_format)}')
print(f'Device: {device}')
print(f'Experiment: {experiment_name}')
print(f'Checkpoint CSV: {checkpoint_manager.csv_path}')
if checkpoint_manager.drive_dir:
    print(f'Drive backup: {checkpoint_manager.drive_dir}')
print()

# Execute training
try:
    results = trainer.train(model)

    print(f'\n{banner}')
    print('TRAINING COMPLETED SUCCESSFULLY')
    print(f'{banner}\n')
    print(f'End time: {datetime.datetime.now().strftime(timestamp_format)}')

except KeyboardInterrupt:
    # A manual stop is not an error: report where progress lives and carry on
    print('\nTraining interrupted by user')
    print(f'Progress saved to: {checkpoint_manager.csv_path}')

except Exception as e:
    # Report where progress lives, then re-raise so the failure stays visible
    print(f'\nTraining failed with error: {e}')
    print(f'Progress saved to: {checkpoint_manager.csv_path}')
    raise

## Save Final Models and Export Results

In [None]:
# Collect trained weights and progress logs into one export folder
print('\n=== FINAL MODEL EXPORT ===')

# Prefer the trainer's own run directory when a trainer object is available
trainer_available = 'trainer' in locals() and hasattr(trainer, 'save_dir')

# Both code paths export into the same per-experiment folder
export_dir = Path(LOCAL_ROOT) / 'final_models' / experiment_name

if trainer_available:
    results_dir = trainer.save_dir
    weights_dir = Path(results_dir) / 'weights'

    # Candidate weight files: the trainer's own outputs plus the checkpoint manager's copies
    candidates = [
        ('best', weights_dir / 'best.pt'),
        ('last', weights_dir / 'last.pt'),
        ('checkpoint_best', checkpoint_manager.best_model_local),
        ('checkpoint_last', checkpoint_manager.last_model_local),
    ]
    models_found = [(label, source) for label, source in candidates if source.exists()]

    export_dir.mkdir(parents=True, exist_ok=True)

    for label, source in models_found:
        destination = export_dir / f'{label}.pt'
        shutil.copy2(source, destination)
        print(f'Exported {label} model to: {destination}')

        # Mirror each file to Drive when a Drive directory is configured
        if checkpoint_manager.drive_dir:
            drive_destination = checkpoint_manager.drive_dir / 'final_models' / f'{label}.pt'
            drive_destination.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(source, drive_destination)
            print(f'  Backup to Drive: {drive_destination}')

    # Trainer-written metrics log
    results_csv = Path(results_dir) / 'results.csv'
    if results_csv.exists():
        destination = export_dir / 'training_results.csv'
        shutil.copy2(results_csv, destination)
        print(f'\nExported results CSV to: {destination}')

    # Per-epoch progress log kept by the checkpoint manager
    if checkpoint_manager.csv_path.exists():
        destination = export_dir / 'checkpoint_progress.csv'
        shutil.copy2(checkpoint_manager.csv_path, destination)
        print(f'Exported checkpoint CSV to: {destination}')

    print(f'\nAll models and results exported to: {export_dir}')

else:
    print('Trainer not found. Using checkpoint manager models.')

    export_dir.mkdir(parents=True, exist_ok=True)

    # (source file, exported file name, message prefix) for each fallback artifact
    fallback_exports = (
        (checkpoint_manager.best_model_local, 'best_model.pt', 'Exported best model to: '),
        (checkpoint_manager.last_model_local, 'last_model.pt', 'Exported last model to: '),
        (checkpoint_manager.csv_path, 'training_progress.csv', '\nExported progress CSV to: '),
    )
    for source, file_name, message in fallback_exports:
        if source.exists():
            destination = export_dir / file_name
            shutil.copy2(source, destination)
            print(f'{message}{destination}')

print('\n=== EXPORT COMPLETE ===')

## Performance Analysis

In [None]:
# Load and analyze training progress
import pandas as pd
import matplotlib.pyplot as plt


def _mark_phase_transitions(ax, with_labels=False):
    """Draw dotted guides at epochs 12 and 24 (the phase 1→2 and 2→3 switches).

    Labels are attached only when ``with_labels`` is true, so that a single
    subplot carries the legend entries for the guides.
    """
    for epoch_mark, label in ((12, 'Phase 1→2'), (24, 'Phase 2→3')):
        extra = {'label': label} if with_labels else {}
        ax.axvline(x=epoch_mark, color='gray', linestyle=':', alpha=0.5, **extra)


# Load progress from checkpoint manager
progress_df = checkpoint_manager.load_progress()

if not progress_df.empty:
    print('\n=== TRAINING PERFORMANCE ANALYSIS ===')
    print(f'Total epochs completed: {len(progress_df)}')

    # Get final metrics
    final_row = progress_df.iloc[-1]

    print('\nFinal Epoch Metrics:')
    print('-' * 80)

    # Classification metrics (lower is better).
    # "Best epoch" is the row position + 1 — assumes one row per epoch with
    # the default 0-based index; TODO confirm against load_progress().
    if 'val/cls_loss' in progress_df.columns:
        final_cls_loss = final_row['val/cls_loss']
        best_cls_loss = progress_df['val/cls_loss'].min()
        best_cls_epoch = progress_df['val/cls_loss'].idxmin() + 1

        print('Classification Loss:')
        print(f'  Final: {final_cls_loss:.4f}')
        print(f'  Best:  {best_cls_loss:.4f} (epoch {best_cls_epoch})')

    # OCR metrics (higher is better)
    if 'ocr_char_accuracy' in progress_df.columns:
        final_ocr_acc = final_row['ocr_char_accuracy']
        best_ocr_acc = progress_df['ocr_char_accuracy'].max()
        best_ocr_epoch = progress_df['ocr_char_accuracy'].idxmax() + 1

        print('\nOCR Character Accuracy:')
        print(f'  Final: {final_ocr_acc:.4f} ({final_ocr_acc*100:.1f}%)')
        print(f'  Best:  {best_ocr_acc:.4f} ({best_ocr_acc*100:.1f}%) (epoch {best_ocr_epoch})')

    # Segmentation metrics (higher is better)
    if 'metrics/mAP50-95(M)' in progress_df.columns:
        final_map = final_row['metrics/mAP50-95(M)']
        best_map = progress_df['metrics/mAP50-95(M)'].max()
        best_map_epoch = progress_df['metrics/mAP50-95(M)'].idxmax() + 1

        print('\nSegmentation mAP@50-95:')
        print(f'  Final: {final_map:.4f}')
        print(f'  Best:  {best_map:.4f} (epoch {best_map_epoch})')

    # X axis shared by every subplot. It is bound up front so that plots 2-4
    # still work when 'val/cls_loss' is absent; without an 'epoch' column the
    # row position + 1 is used, matching the "best epoch" convention above.
    if 'epoch' in progress_df.columns:
        epochs = progress_df['epoch']
    else:
        epochs = progress_df.index + 1

    # Plot training curves
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle(f'Training Performance - {experiment_name}', fontsize=16, fontweight='bold')

    # Plot 1: Classification Loss
    if 'val/cls_loss' in progress_df.columns:
        ax = axes[0, 0]

        if 'train/cls_loss' in progress_df.columns:
            ax.plot(epochs, progress_df['train/cls_loss'], label='Train', linewidth=2, alpha=0.7)

        ax.plot(epochs, progress_df['val/cls_loss'], label='Validation', linewidth=2)

        # Only this subplot labels the phase guides in its legend
        _mark_phase_transitions(ax, with_labels=True)

        ax.set_xlabel('Epoch', fontweight='bold')
        ax.set_ylabel('Loss', fontweight='bold')
        ax.set_title('Classification Loss', fontweight='bold')
        ax.grid(True, alpha=0.3)
        ax.legend()

    # Plot 2: OCR Accuracy (fractions scaled to percent)
    if 'ocr_char_accuracy' in progress_df.columns:
        ax = axes[0, 1]
        ax.plot(epochs, progress_df['ocr_char_accuracy'] * 100, label='Char Accuracy', linewidth=2)

        if 'ocr_top2_accuracy' in progress_df.columns:
            ax.plot(epochs, progress_df['ocr_top2_accuracy'] * 100, label='Top-2 Accuracy', linewidth=2, alpha=0.7)

        if 'ocr_top3_accuracy' in progress_df.columns:
            ax.plot(epochs, progress_df['ocr_top3_accuracy'] * 100, label='Top-3 Accuracy', linewidth=2, alpha=0.5)

        _mark_phase_transitions(ax)

        ax.set_xlabel('Epoch', fontweight='bold')
        ax.set_ylabel('Accuracy (%)', fontweight='bold')
        ax.set_title('OCR Performance', fontweight='bold')
        ax.grid(True, alpha=0.3)
        ax.legend()
        ax.set_ylim([0, 105])

    # Plot 3: Segmentation mAP
    if 'metrics/mAP50-95(M)' in progress_df.columns:
        ax = axes[1, 0]
        ax.plot(epochs, progress_df['metrics/mAP50-95(M)'], linewidth=2, color='green')

        _mark_phase_transitions(ax)

        ax.set_xlabel('Epoch', fontweight='bold')
        ax.set_ylabel('mAP@50-95', fontweight='bold')
        ax.set_title('Segmentation Quality', fontweight='bold')
        ax.grid(True, alpha=0.3)
        ax.set_ylim([0, 1])

    # Plot 4: Learning Rate (log scale)
    if 'learning_rate' in progress_df.columns:
        ax = axes[1, 1]
        ax.plot(epochs, progress_df['learning_rate'], linewidth=2, color='purple')

        _mark_phase_transitions(ax)

        ax.set_xlabel('Epoch', fontweight='bold')
        ax.set_ylabel('Learning Rate', fontweight='bold')
        ax.set_title('Learning Rate Schedule', fontweight='bold')
        ax.grid(True, alpha=0.3)
        ax.set_yscale('log')

    plt.tight_layout()

    # Save plot. The folder is derived here (same location the export step
    # uses) so this cell also works when the export cell has not been run.
    analysis_dir = Path(LOCAL_ROOT) / 'final_models' / experiment_name
    analysis_dir.mkdir(parents=True, exist_ok=True)
    plot_path = analysis_dir / 'training_analysis.png'
    plt.savefig(plot_path, dpi=150, bbox_inches='tight')
    print(f'\nTraining analysis plot saved to: {plot_path}')

    plt.show()

else:
    print('No training progress data found.')

### Key Improvements:

1. **Reproducibility**:
   - Dataset downloaded from public Google Drive folder
   - No account-specific paths - works from any account once the Drive file IDs are filled in
   - Automatic dataset verification and fixing
   
2. **Robust Checkpoint System**:
   - CSV saved EVERY epoch (no exceptions)
   - Models saved to Google Drive every epoch (if mounted)
   - Resume capability: the last checkpoint is detected automatically and you are asked whether to continue from it
   - Threaded async saves to prevent training slowdown
   
3. **Clean Output**:
   - Clear, professional logging
   
4. **Sanity Checks**:
   - Dataset structure verification
   - File existence checks
   - Model loading error handling
   - Automatic path fixing in data.yaml