## Setup: Google Colab Environment

In [None]:
# Inspect the runtime: choose the compute device and report GPU / library versions.
import torch
import os

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

if device == 'cuda':
    print(f'GPU available: {torch.cuda.get_device_name(0)}')
    print(f'Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB')
else:
    print('No GPU available, using CPU (training will be slow)')

# Versions are printed regardless of device (torch.version.cuda is None on CPU-only builds).
print(f'\nPyTorch version: {torch.__version__}')
print(f'CUDA version: {torch.version.cuda}')

In [None]:
# Install required packages
# `!pip` is IPython/Colab shell syntax: this cell only runs inside a notebook, not as a plain .py script.
!pip install -q ultralytics opencv-python-headless pillow pyyaml numpy scipy matplotlib pandas

# NOTE: printed unconditionally; the pip exit status is not checked here.
print('All packages installed successfully')

In [None]:
# Mount Google Drive (if using dataset from Drive)
from google.colab import drive
drive.mount('/content/drive')

# Set paths - MODIFY THESE TO YOUR ACTUAL PATHS
DRIVE_ROOT = '/content/drive/MyDrive/csc173_dataset'
DATA_YAML_PATH = f'{DRIVE_ROOT}/dataset/data.yaml'
PRETRAINED_MODEL = f'{DRIVE_ROOT}/models/custom_ocr_last.pt'  # Your existing checkpoint

# Fail-soft sanity check: surface a wrong DRIVE_ROOT now, instead of deep inside
# model loading or dataset parsing. Only warns; it does not stop the notebook.
for _label, _path in (('Data config', DATA_YAML_PATH), ('Pretrained model', PRETRAINED_MODEL)):
    if not os.path.isfile(_path):
        print(f'WARNING: {_label} not found at {_path} - check DRIVE_ROOT')

# Create local working directory (relative output paths used later resolve inside it)
WORK_DIR = '/content/refined_training'
os.makedirs(WORK_DIR, exist_ok=True)
os.chdir(WORK_DIR)

print('Drive mounted and working directory set')
print(f'Data config: {DATA_YAML_PATH}')
print(f'Pretrained model: {PRETRAINED_MODEL}')
print(f'Working directory: {WORK_DIR}')

## Core Components (Reused from Original Training)

In [None]:
# Character set and similarity matrix (from original)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# 36 classes: the uppercase letters A-Z followed by the digits 0-9.
CHARS = [chr(code) for code in range(ord('A'), ord('Z') + 1)] + list(map(str, range(10)))
NUM_CLASSES = len(CHARS)
IDX_TO_CHAR = dict(enumerate(CHARS))
CHAR_TO_IDX = {ch: idx for idx, ch in IDX_TO_CHAR.items()}

print(f'Number of classes: {NUM_CLASSES}')
print(f'Characters: {"".join(CHARS)}')

# Visually confusable characters. Members of one group get `base_sim` similarity
# to each other. Q is deliberately NOT grouped with O/0, and L is NOT grouped
# with I/1, so those confusions receive the full (dissimilar) penalty.
SIMILAR_GROUPS = [
    ['O', '0'],
    ['I', '1'],
    ['S', '5'],
    ['Z', '2'],
    ['B', '8'],
    ['D', '0'],
    ['G', 'C'],
    ['U', 'V'],
    ['P', 'R'],
]


def create_similarity_matrix(num_classes=NUM_CLASSES, groups=SIMILAR_GROUPS, base_sim=0.6):
    """Build a symmetric (num_classes, num_classes) float32 similarity tensor.

    The diagonal is 1.0. Every pair of distinct characters sharing a group in
    `groups` gets `base_sim`, and all other entries are 0. Characters missing
    from CHAR_TO_IDX are skipped.
    """
    matrix = np.zeros((num_classes, num_classes), dtype=np.float32)
    for group in groups:
        members = [CHAR_TO_IDX[ch] for ch in group if ch in CHAR_TO_IDX]
        # Fill the whole group block (its diagonal included); the diagonal is restored below.
        matrix[np.ix_(members, members)] = base_sim
    np.fill_diagonal(matrix, 1.0)
    return torch.tensor(matrix, dtype=torch.float32)


similarity_matrix = create_similarity_matrix()
print(f'Similarity matrix initialized: {similarity_matrix.shape}')

In [None]:
# Enhanced Similarity-Aware Loss with Adaptive Weighting
class RefinedSimilarityAwareTopKLoss(nn.Module):
    """
    Cross-entropy blended with a similarity-weighted top-k penalty, for fine-tuning.

    For each sample whose target lies in [0, num_classes):
        loss = w_ce * CE(logits, target) + w_sim * sim_penalty
        sim_penalty = sum over the top-k temperature-scaled probabilities p_j of
                      p_j * penalty_scale * (1 - S[target, j])
    Probability mass on characters *similar* to the target (high S) therefore
    costs LESS than mass on dissimilar characters. Mass on the target itself
    (S = 1) costs nothing.

    Weights start from (base_weight, topk_weight) and are shifted by the detached
    top-1 probability c: w_ce += (1 - c) * adaptive_strength and
    w_sim += c * adaptive_strength. They are then renormalised to sum to 1.
    NOTE(review): an *uncertain* prediction therefore leans on CE, and a *confident*
    one leans on the similarity penalty. That is the opposite of what the original
    inline comments described. Behaviour is kept as-is; confirm the intent.

    The temperature anneals linearly from `initial_temperature` down by
    `temperature_decay` over `epochs`, floored at `min_temperature`.

    Args:
        num_classes: number of classes (size of each similarity-matrix axis).
        similarity_matrix: (num_classes, num_classes) tensor. Defaults to
            create_similarity_matrix().
        k: number of top predictions included in the penalty.
        initial_temperature: softmax temperature at epoch 0.
        base_weight, topk_weight: base mixing weights for CE and the penalty.
        epochs: length of the annealing schedule.
        min_temperature: lower bound of the annealed temperature.
        temperature_decay: total temperature decrease over `epochs`.
        penalty_scale: multiplier applied to (1 - similarity).
        adaptive_strength: how far the top-1 confidence shifts the weights.
    """
    def __init__(self, num_classes=NUM_CLASSES, similarity_matrix=None,
                 k=3, initial_temperature=0.5, base_weight=0.5, topk_weight=0.5,
                 epochs=40, min_temperature=0.3, temperature_decay=0.2,
                 penalty_scale=1.5, adaptive_strength=0.2):
        super().__init__()
        self.num_classes = num_classes
        self.k = k
        self.initial_temperature = initial_temperature
        self.base_weight = base_weight
        self.topk_weight = topk_weight
        self.epochs = epochs
        self.min_temperature = min_temperature
        self.temperature_decay = temperature_decay
        self.penalty_scale = penalty_scale
        self.adaptive_strength = adaptive_strength
        self.current_epoch = 0

        # A buffer (not a parameter) so it follows .to(device) but is never trained.
        if similarity_matrix is not None:
            self.register_buffer('similarity_matrix', similarity_matrix)
        else:
            self.register_buffer('similarity_matrix', create_similarity_matrix())

    def update_epoch(self, epoch):
        """Update current epoch for temperature annealing."""
        self.current_epoch = epoch

    def get_temperature(self):
        """Return the annealed temperature (0.5 -> ~0.3 with the defaults)."""
        progress = self.current_epoch / max(self.epochs, 1)
        return max(self.min_temperature,
                   self.initial_temperature - progress * self.temperature_decay)

    def forward(self, logits, targets):
        """Return the scalar blended loss, averaged over valid samples.

        Args:
            logits: (B, num_classes) unnormalised scores.
            targets: (B,) integer class indices. Entries outside [0, num_classes)
                are ignored rather than passed to cross-entropy.
        """
        valid = (targets >= 0) & (targets < self.num_classes)
        if not bool(valid.any()):
            # Nothing to supervise: return a zero that still participates in autograd.
            return logits.sum() * 0.0
        logits = logits[valid]
        targets = targets[valid]

        temperature = self.get_temperature()

        # Standard cross-entropy (untempered logits)
        ce_loss = F.cross_entropy(logits, targets, reduction='none')

        # Softmax with temperature, then the k most probable classes per sample
        probs = F.softmax(logits / temperature, dim=1)
        topk_probs, topk_indices = torch.topk(probs, min(self.k, self.num_classes), dim=1)

        # Similarity of each target to each of its top-k predictions: (N, k)
        sims = self.similarity_matrix[targets].gather(1, topk_indices)
        penalties = (1.0 - sims) * self.penalty_scale
        sim_loss = (topk_probs * penalties).sum(dim=1).float()

        # Confidence = top-1 tempered probability; detached so it only reweights.
        confidence = topk_probs[:, 0].detach().float()
        adaptive_base = self.base_weight + (1 - confidence) * self.adaptive_strength
        adaptive_topk = self.topk_weight + confidence * self.adaptive_strength

        # Normalize so the two weights sum to 1 per sample
        total_weight = adaptive_base + adaptive_topk
        adaptive_base = adaptive_base / total_weight
        adaptive_topk = adaptive_topk / total_weight

        total_loss = adaptive_base * ce_loss + adaptive_topk * sim_loss
        return total_loss.mean()

print('Refined similarity-aware loss defined')

In [None]:
# OCR Metrics (reused from original)
class OCRMetrics:
    """Accumulate per-character OCR validation metrics across batches.

    'CER' is reported as 1 - per-character accuracy. It is a per-character
    classification error rate, not an edit-distance CER over whole strings.
    """
    def __init__(self, similarity_matrix=None):
        # Falls back to create_similarity_matrix() when no matrix is supplied.
        self.similarity_matrix = similarity_matrix if similarity_matrix is not None else create_similarity_matrix()
        self.reset()

    def reset(self):
        """Clear all running counters."""
        self.total_chars = 0
        self.correct_chars = 0
        self.top2_correct = 0
        self.top3_correct = 0
        self.similarity_score = 0.0

    def update(self, predictions, targets, top_k_preds=None):
        """Add one batch of predictions.

        Args:
            predictions: 1-D tensor of predicted class indices.
            targets: 1-D tensor of ground-truth class indices (same length).
            top_k_preds: optional (N, k) tensor of ranked class indices. Top-2 /
                top-3 hits are counted only when k >= 2 / k >= 3.
        """
        predictions = predictions.cpu().numpy()
        targets = targets.cpu().numpy()

        self.total_chars += len(targets)
        self.correct_chars += (predictions == targets).sum()

        # Similarity-aware accuracy: credit S[target, pred]. Pairs with an index
        # outside the matrix add nothing but still count in total_chars.
        sim = torch.as_tensor(self.similarity_matrix).detach().cpu().numpy()
        size = len(sim)
        in_range = (targets >= 0) & (targets < size) & (predictions >= 0) & (predictions < size)
        self.similarity_score += float(
            sim[targets[in_range], predictions[in_range]].astype(np.float64).sum()
        )

        # Top-k accuracy: vectorised "target appears in the first k columns" test.
        if top_k_preds is not None:
            top_k_preds = top_k_preds.cpu().numpy()
            hits = top_k_preds == targets[:, None]
            if top_k_preds.shape[1] >= 2:
                self.top2_correct += int(hits[:, :2].any(axis=1).sum())
            if top_k_preds.shape[1] >= 3:
                self.top3_correct += int(hits[:, :3].any(axis=1).sum())

    def compute(self):
        """Return the metric dict, or {} when no characters have been seen."""
        if self.total_chars == 0:
            return {}

        return {
            'CER': 1.0 - (self.correct_chars / self.total_chars),
            'char_accuracy': self.correct_chars / self.total_chars,
            'top2_accuracy': self.top2_correct / self.total_chars,
            'top3_accuracy': self.top3_correct / self.total_chars,
            'similarity_aware_accuracy': self.similarity_score / self.total_chars,
        }

print('OCR metrics module loaded')

## Refined Training Strategy

### Phase 1: Classifier Head Fine-Tuning (Epochs 1-12)
- Freeze backbone and segmentation head
- Focus exclusively on improving character classification
- Start from a higher initial learning rate (after a short warmup) to escape the plateau

### Phase 2: Progressive Unfreezing (Epochs 13-24)
- Unfreeze the segmentation (mask) head in addition to the classifier head
- Lower learning rate for stable refinement
- Continue with similarity-aware loss

### Phase 3: Full Fine-Tuning (Epochs 25-40)
- All layers unfrozen
- Very low learning rate for final polish
- Focus on reducing classification loss below 0.35

In [None]:
# Custom Trainer for Refined Training
from ultralytics.models.yolo.segment import SegmentationTrainer
from ultralytics import YOLO

class RefinedSegmentationTrainer(SegmentationTrainer):
    """
    Segmentation trainer with a three-phase progressive-unfreezing schedule.

    Phase 1 (epochs [0, PHASE2_START)): only the head's classification branch
        ('cls' / 'cv3' parameters of the last layer) is trainable.
    Phase 2 (epochs [PHASE2_START, PHASE3_START)): additionally the head's mask
        branches ('proto' / 'cv4', or any 'seg' / 'mask' names).
    Phase 3 (epochs >= PHASE3_START): every parameter except ALWAYS_FROZEN.

    The base trainer fires events through run_callbacks(event), calling each
    function registered with add_callback() as fn(trainer). It does not call
    same-named methods on the trainer, so __init__ registers an adapter for
    'on_train_epoch_start'.

    NOTE(review): in current Ultralytics versions the training loop computes the
    loss through the model's own criterion and never calls `compute_loss`,
    `on_val_start` or `on_val_end` on the trainer. The similarity-aware loss
    therefore does not contribute to the optimised objective, and the OCR
    metrics stay empty. Wiring them in presumably needs a custom model criterion.
    TODO: confirm against the installed version.
    """

    # 0-based epochs at which Phase 2 / Phase 3 begin.
    PHASE2_START = 12
    PHASE3_START = 24
    # Parameter-name fragments kept frozen in every phase (the DFL projection
    # conv, which Ultralytics itself keeps frozen).
    ALWAYS_FROZEN = ('.dfl',)

    def __init__(self, cfg=None, overrides=None, _callbacks=None):
        # Forward cfg only when explicitly given. Passing None would replace the
        # parent's default config, and merging overrides into None fails
        # (get_cfg in current Ultralytics - confirm).
        if cfg is None:
            super().__init__(overrides=overrides, _callbacks=_callbacks)
        else:
            super().__init__(cfg, overrides, _callbacks)

        # Get total epochs from config
        self.total_epochs = getattr(self.args, 'epochs', 40)

        # Initialize refined loss on the trainer's own device when it has one
        self.character_loss_fn = RefinedSimilarityAwareTopKLoss(
            num_classes=NUM_CLASSES,
            similarity_matrix=similarity_matrix,
            k=3,
            initial_temperature=0.5,
            base_weight=0.5,
            topk_weight=0.5,
            epochs=self.total_epochs,
        ).to(getattr(self, 'device', device))

        # OCR metrics
        self.ocr_metrics = OCRMetrics(similarity_matrix=similarity_matrix)

        # Training phase tracking
        self.phase = 1
        self.freeze_applied = False

        # Drive phase transitions from the trainer's event system.
        self.add_callback('on_train_epoch_start', self._epoch_start_callback)

    def _epoch_start_callback(self, trainer):
        """Callback adapter: run this trainer's epoch-start logic (ignores other trainers)."""
        if trainer is self:
            self.on_train_epoch_start()

    def _phase_for_epoch(self, epoch):
        """Return the phase (1, 2 or 3) that a 0-based `epoch` belongs to."""
        if epoch < self.PHASE2_START:
            return 1
        if epoch < self.PHASE3_START:
            return 2
        return 3

    def _base_model(self):
        """Return self.model with any DDP/DataParallel `.module` wrapper removed."""
        return getattr(self.model, 'module', self.model)

    @staticmethod
    def _head_prefix(net):
        """Parameter-name prefix of the last layer (the task head), or '' if unknown.

        '' makes every name match, which falls back to whole-model name matching.
        """
        layers = getattr(net, 'model', None)
        try:
            return f'model.{len(layers) - 1}.'
        except TypeError:
            return ''

    def _is_trainable(self, name, phase, head):
        """Decide whether parameter `name` should receive gradients in `phase`."""
        if any(tag in name for tag in self.ALWAYS_FROZEN):
            return False
        if phase >= 3:
            return True
        if not name.startswith(head):
            return False
        lname = name.lower()
        # Classification branch of the detect/segment head
        if 'cls' in lname or 'cv3' in lname:
            return True
        # Segment head: 'proto' (prototype masks) and 'cv4' (mask coefficients)
        return phase >= 2 and any(tag in lname for tag in ('seg', 'mask', 'proto', 'cv4'))

    def _apply_phase(self, phase):
        """Set requires_grad on every parameter for `phase` and report the counts."""
        net = self._base_model()
        head = self._head_prefix(net)
        for name, param in net.named_parameters():
            param.requires_grad = self._is_trainable(name, phase, head)
        self.phase = phase

        trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
        total = sum(p.numel() for p in net.parameters())
        print(f'Trainable parameters: {trainable:,} / {total:,} ({100*trainable/total:.1f}%)')

    def _setup_train(self, *args, **kwargs):
        """Run the base setup, then freeze for Phase 1 when starting inside it."""
        super()._setup_train(*args, **kwargs)

        if not self.freeze_applied and self.epoch < self.PHASE2_START:
            print(f'\n=== PHASE 1: Classifier Head Fine-Tuning (Epochs 1-{self.PHASE2_START}) ===')
            print('Freezing backbone and segmentation layers...')
            self._apply_phase(1)
            self.freeze_applied = True

    def on_train_epoch_start(self):
        """Anneal the loss temperature and switch phases when an epoch crosses a boundary."""
        # Update temperature in loss
        self.character_loss_fn.update_epoch(self.epoch)

        target = self._phase_for_epoch(self.epoch)
        if target == self.phase:
            return

        if target == 2:
            print(f'\n=== PHASE 2: Progressive Unfreezing (Epochs {self.PHASE2_START + 1}-{self.PHASE3_START}) ===')
            print('Unfreezing segmentation head...')
        elif target == 3:
            print(f'\n=== PHASE 3: Full Fine-Tuning (Epochs {self.PHASE3_START + 1}-{self.total_epochs}) ===')
            print('Unfreezing all layers...')
        self._apply_phase(target)

    def on_val_start(self):
        """Reset the OCR metric accumulators (the base trainer has no such method)."""
        self.ocr_metrics.reset()

    def on_val_end(self):
        """Print accumulated OCR metrics, if any were collected."""
        ocr_results = self.ocr_metrics.compute()
        if ocr_results:
            print(f'\n[Epoch {self.epoch}] OCR Metrics:')
            for key, value in ocr_results.items():
                print(f'  {key}: {value:.4f}')

    def compute_loss(self, preds, batch):
        """Compute loss with refined similarity-aware classification.

        NOTE(review): see the class docstring. This method relies on
        `super().compute_loss` and on `preds[3]` holding per-character logits,
        neither of which is established in this file. It is kept unchanged for
        reference and is not used by the training loop as written.
        """
        # Get base YOLO losses
        base_loss = super().compute_loss(preds, batch)

        # Add custom similarity-aware character classification loss
        if len(preds) > 3:
            cls_logits = preds[3]
            cls_targets = batch['cls'].long()

            if cls_logits is not None and cls_targets is not None:
                cls_logits_flat = cls_logits.view(-1, NUM_CLASSES)
                cls_targets_flat = cls_targets.view(-1)

                valid_mask = cls_targets_flat >= 0
                if valid_mask.sum() > 0:
                    # Compute refined similarity-aware loss
                    char_loss = self.character_loss_fn(
                        cls_logits_flat[valid_mask],
                        cls_targets_flat[valid_mask]
                    )

                    # Update OCR metrics
                    with torch.no_grad():
                        preds_cls = cls_logits_flat[valid_mask].argmax(dim=1)
                        top_k_preds = torch.topk(cls_logits_flat[valid_mask], k=3, dim=1)[1]
                        self.ocr_metrics.update(
                            preds_cls,
                            cls_targets_flat[valid_mask],
                            top_k_preds
                        )

                    # Phase-dependent weighting: 0.7 / 0.5 / 0.3 for phases 1 / 2 / 3
                    if self.phase == 1:
                        cls_weight = 0.7
                    elif self.phase == 2:
                        cls_weight = 0.5
                    else:
                        cls_weight = 0.3

                    total_loss = (1 - cls_weight) * base_loss + cls_weight * char_loss
                    return total_loss

        return base_loss

print('Refined segmentation trainer defined')

## Load Pretrained Model and Configure Training

In [None]:
# Load your existing trained model
print(f'Loading pretrained model from: {PRETRAINED_MODEL}\n')

model = YOLO(PRETRAINED_MODEL)
# NOTE(review): in current Ultralytics releases, Model.train() instantiates its
# trainer from its own `trainer=` argument (default: the task's trainer), so this
# attribute assignment alone most likely does NOT make training use
# RefinedSegmentationTrainer. To use it, call
# model.train(trainer=RefinedSegmentationTrainer, ...) - confirm against the
# installed version before relying on the custom phases.
model.trainer = RefinedSegmentationTrainer

print('Model loaded; RefinedSegmentationTrainer assigned to model.trainer')
print('  Starting from epoch 68 checkpoint')
# Keep in sync with REFINE_EPOCHS in the configuration cell.
print('  Will train for 40 additional epochs with progressive refinement')

## Training Configuration

In [None]:
# Refined training hyperparameters
REFINE_EPOCHS = 40
BATCH_SIZE = 16
IMG_SIZE = 224

# Learning rate. After warmup, Ultralytics decays the LR from lr0 toward lr0 * lrf
# (linearly with the default cos_lr=False). This is not a cyclic schedule.
# Start higher than the previous run (0.001) to shake the model out of its plateau.
LR0 = 0.005
# NOTE: `lrf` is a FRACTION of lr0, so the final LR is LR0 * LRF (= 5e-7 here), not 1e-4.
# NOTE(review): if an absolute final LR of 1e-4 was intended, use LRF = 0.02.
LRF = 0.0001
FINAL_LR = LR0 * LRF

# Optimizer settings
MOMENTUM = 0.937
WEIGHT_DECAY = 5e-4
WARMUP_EPOCHS = 3.0

# Augmentations - more aggressive for character robustness
AUG_HSV_H = 0.02  # Increased hue variation
AUG_HSV_S = 0.8   # Increased saturation variation
AUG_HSV_V = 0.5   # Increased brightness variation
AUG_ERASING = 0.5  # Increased random erasing
AUG_DEGREES = 5.0  # Small rotation for character variation
AUG_SHEAR = 2.0    # Small shear (slant) variation, in degrees

# Disabled augmentations (not useful for OCR)
AUG_FLIPLR = 0.0
AUG_MOSAIC = 0.0
AUG_MIXUP = 0.0

print('Refined Training Configuration:')
print(f'  Epochs: {REFINE_EPOCHS}')
print(f'  Batch size: {BATCH_SIZE}')
print(f'  Learning rate: {LR0} → {FINAL_LR:.1e} (final = lr0 * lrf, lrf={LRF})')
print('  Augmentations: Enhanced HSV + Erasing + Geometric')
print('\nTraining Strategy:')
print('  Phase 1 (1-12): Classifier head only')
print('  Phase 2 (13-24): + Segmentation head')
print('  Phase 3 (25-40): All layers')

## Execute Refined Training

In [None]:
import datetime


def _stamp():
    """Current wall-clock time, formatted for the start/end log lines."""
    return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")


def _banner(title):
    """Print `title` framed by 80-character '=' rules, with a blank line before and after."""
    rule = '=' * 80
    print('\n' + rule)
    print(title)
    print(rule + '\n')


# Keyword arguments for YOLO.train(), grouped by concern.
train_params = {
    # Data and schedule
    'data': DATA_YAML_PATH,
    'epochs': REFINE_EPOCHS,
    'batch': BATCH_SIZE,
    'imgsz': IMG_SIZE,

    # Optimizer
    'optimizer': 'SGD',
    'lr0': LR0,
    'lrf': LRF,
    'momentum': MOMENTUM,
    'weight_decay': WEIGHT_DECAY,

    # Warmup
    'warmup_epochs': WARMUP_EPOCHS,
    'warmup_momentum': 0.8,
    'warmup_bias_lr': 0.1,

    # Augmentations
    'hsv_h': AUG_HSV_H,
    'hsv_s': AUG_HSV_S,
    'hsv_v': AUG_HSV_V,
    'erasing': AUG_ERASING,
    'degrees': AUG_DEGREES,
    'shear': AUG_SHEAR,
    'fliplr': AUG_FLIPLR,
    'mosaic': AUG_MOSAIC,
    'mixup': AUG_MIXUP,

    # Output: a timestamped run folder under the relative 'refined_training' project dir
    'project': 'refined_training',
    'name': 'refine_' + datetime.datetime.now().strftime("%Y%m%d_%H%M%S"),
    'exist_ok': True,

    # Validation and checkpointing
    'val': True,
    'save': True,
    'save_period': 10,

    # System
    'device': device,
    'amp': True,  # automatic mixed precision for faster training
    'seed': 42,
    'deterministic': True,

    # A fresh fine-tuning run from the loaded weights, not a resume of the old run.
    'resume': False,
}

_banner('STARTING REFINED TRAINING')
print('Start time: ' + _stamp())
print(f'Device: {device}\n')

# NOTE(review): no `trainer=` argument is passed here, so (in current Ultralytics)
# the default segmentation trainer is used rather than RefinedSegmentationTrainer.
results = model.train(**train_params)

_banner('TRAINING COMPLETED')
print('End time: ' + _stamp())
print(f'Results directory: {results.save_dir}')

## Export Best Model to Drive

In [None]:
import shutil
import os
from pathlib import Path

# Artifacts produced by the training run.
weights_dir = Path(results.save_dir) / 'weights'
best_model = weights_dir / 'best.pt'
last_model = weights_dir / 'last.pt'
results_csv = Path(results.save_dir) / 'results.csv'

# Destination folder on Google Drive (created if missing).
export_dir = f'{DRIVE_ROOT}/refined_models'
os.makedirs(export_dir, exist_ok=True)

# One timestamp shared by every exported file (the analysis cell reuses it too).
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

# (source path, destination file stem, extension, label for the log line).
# Missing sources are skipped silently.
for src, stem, ext, what in (
    (best_model, 'refined_best', 'pt', 'Best model'),
    (last_model, 'refined_last', 'pt', 'Last model'),
    (results_csv, 'refined_results', 'csv', 'Results CSV'),
):
    if not src.exists():
        continue
    dst = f'{export_dir}/{stem}_{timestamp}.{ext}'
    shutil.copy2(src, dst)
    print(f'{what} exported to: {dst}')

print('\nAll files exported to Google Drive')
print(f'  Location: {export_dir}')

## Performance Analysis

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Phase start epochs on results.csv's 'epoch' axis, which counts from 1 (Ultralytics
# writes epoch + 1 - confirm). Phase 2 starts at epoch 13 and Phase 3 at epoch 25,
# matching the phase plan.
PHASE_START_EPOCHS = (13, 25)


def _mark_phases(ax, labels=(None, None)):
    """Draw a dotted vertical line at each phase start, with optional legend labels."""
    for x, label in zip(PHASE_START_EPOCHS, labels):
        ax.axvline(x=x, color='gray', linestyle=':', alpha=0.5, label=label)


# Load results
results_csv_path = Path(results.save_dir) / 'results.csv'

if results_csv_path.exists():
    df = pd.read_csv(results_csv_path)
    df.columns = df.columns.str.strip()  # Ultralytics pads CSV headers with spaces

    print('Refined Training Results Summary')
    print('=' * 80)
    print(f'Total epochs: {len(df)}')
    print()

    # Final metrics
    print('Final Epoch Metrics:')
    print('-' * 80)
    print(f'Classification Loss (val): {df["val/cls_loss"].iloc[-1]:.4f}')
    print(f'Segmentation mAP@50-95:    {df["metrics/mAP50-95(M)"].iloc[-1]:.4f}')
    print(f'Segmentation Precision:     {df["metrics/precision(M)"].iloc[-1]:.4f}')
    print(f'Segmentation Recall:        {df["metrics/recall(M)"].iloc[-1]:.4f}')
    print()

    # Best metrics (idxmin/idxmax return index labels, so look them up with .loc)
    best_cls_loss_idx = df['val/cls_loss'].idxmin()
    best_map_idx = df['metrics/mAP50-95(M)'].idxmax()

    print('Best Performance:')
    print('-' * 80)
    print(f'Best Classification Loss:   {df["val/cls_loss"].loc[best_cls_loss_idx]:.4f} (epoch {df["epoch"].loc[best_cls_loss_idx]:.0f})')
    print(f'Best Segmentation mAP:      {df["metrics/mAP50-95(M)"].loc[best_map_idx]:.4f} (epoch {df["epoch"].loc[best_map_idx]:.0f})')
    print()

    # Improvement over baseline
    baseline_cls_loss = 0.4321  # From epoch 68 of original training
    baseline_map = 0.4799

    final_cls_loss = df['val/cls_loss'].iloc[-1]
    final_map = df['metrics/mAP50-95(M)'].iloc[-1]

    # Positive = better (lower loss / higher mAP)
    cls_improvement = ((baseline_cls_loss - final_cls_loss) / baseline_cls_loss) * 100
    map_improvement = ((final_map - baseline_map) / baseline_map) * 100

    print('Improvement Over Baseline (Epoch 68):')
    print('-' * 80)
    print(f'Classification Loss: {baseline_cls_loss:.4f} → {final_cls_loss:.4f} ({cls_improvement:+.2f}%)')
    print(f'Segmentation mAP:    {baseline_map:.4f} → {final_map:.4f} ({map_improvement:+.2f}%)')
    print()

    # Plot training curves
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle('Refined Training Performance', fontsize=16, fontweight='bold')

    # Plot 1: Classification Loss
    ax = axes[0, 0]
    ax.plot(df['epoch'], df['train/cls_loss'], label='Train', linewidth=2, alpha=0.7)
    ax.plot(df['epoch'], df['val/cls_loss'], label='Validation', linewidth=2)
    ax.axhline(y=baseline_cls_loss, color='red', linestyle='--', label=f'Baseline ({baseline_cls_loss:.4f})', alpha=0.5)
    _mark_phases(ax, labels=('Phase 2', 'Phase 3'))
    ax.set_xlabel('Epoch', fontweight='bold')
    ax.set_ylabel('Loss', fontweight='bold')
    ax.set_title('Classification Loss (Lower is Better)', fontweight='bold')
    ax.grid(True, alpha=0.3)
    ax.legend()

    # Plot 2: Segmentation mAP
    ax = axes[0, 1]
    ax.plot(df['epoch'], df['metrics/mAP50-95(M)'], linewidth=2, color='green')
    ax.axhline(y=baseline_map, color='red', linestyle='--', label=f'Baseline ({baseline_map:.4f})', alpha=0.5)
    _mark_phases(ax)
    ax.set_xlabel('Epoch', fontweight='bold')
    ax.set_ylabel('mAP@50-95', fontweight='bold')
    ax.set_title('Segmentation Quality (Higher is Better)', fontweight='bold')
    ax.grid(True, alpha=0.3)
    ax.legend()
    ax.set_ylim([0.4, 1.0])

    # Plot 3: Precision & Recall
    ax = axes[1, 0]
    ax.plot(df['epoch'], df['metrics/precision(M)'], label='Precision', linewidth=2)
    ax.plot(df['epoch'], df['metrics/recall(M)'], label='Recall', linewidth=2)
    _mark_phases(ax)
    ax.set_xlabel('Epoch', fontweight='bold')
    ax.set_ylabel('Score', fontweight='bold')
    ax.set_title('Precision & Recall', fontweight='bold')
    ax.grid(True, alpha=0.3)
    ax.legend()
    ax.set_ylim([0.7, 1.0])

    # Plot 4: Learning Rate
    ax = axes[1, 1]
    ax.plot(df['epoch'], df['lr/pg0'], linewidth=2, color='purple')
    _mark_phases(ax, labels=('Phase transitions', None))
    ax.set_xlabel('Epoch', fontweight='bold')
    ax.set_ylabel('Learning Rate', fontweight='bold')
    ax.set_title('Learning Rate Schedule', fontweight='bold')
    ax.grid(True, alpha=0.3)
    ax.legend()
    ax.set_yscale('log')

    plt.tight_layout()

    # Save plot
    plot_path = Path(results.save_dir) / 'refined_training_analysis.png'
    plt.savefig(plot_path, dpi=150, bbox_inches='tight')
    print(f'Training curves saved to: {plot_path}')

    # Also save to Drive (export_dir / timestamp come from the export cell)
    drive_plot_path = f'{export_dir}/refined_training_analysis_{timestamp}.png'
    shutil.copy2(plot_path, drive_plot_path)
    print(f'Plots exported to Drive: {drive_plot_path}')

    plt.show()

else:
    print('Results CSV not found. Training may not have completed.')