# EfficientNet-B0 Fine-tuning for Comic Reverse Image Search

This notebook fine-tunes EfficientNet-B0 for comic reverse image search using triplet loss.

**Key advantages of EfficientNet-B0:**
- **Small size**: ~18MB ONNX (vs 331MB for ViT-Base)
- **High quality**: 0.136 similarity separation in our tests
- **Fast inference**: ~264ms per batch
- **Proven architecture**: EfficientNet is well-established

**Training approach:**
- Uses triplet loss for similarity learning
- Pre-trained on ImageNet for better initialization
- Enhanced data augmentation for robustness (blur, rotation, partial frames)
- Regularization techniques to prevent overfitting
- Mixed precision training for efficiency

**Improvements for robustness:**
- **Blurred images**: Random Gaussian blur augmentation
- **Rotated images**: Random affine transformations with up to 30° rotation
- **Partial frame images**: Random resized crops to simulate incomplete views
- **Regularization**: Dropout, weight decay, learning rate scheduling


In [None]:
%pip install torch torchvision pytorch-lightning timm matplotlib tqdm --quiet


In [None]:
import os
import pickle
from pathlib import Path
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TripletMarginLoss
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import functional as TF
from PIL import Image, ImageFilter
import pytorch_lightning as pl
import timm
import matplotlib.pyplot as plt
from tqdm import tqdm

# Fix the global random seed so repeated runs are comparable
pl.seed_everything(42)

# Prefer CUDA when PyTorch reports it as available; otherwise run on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# =============================================
# Custom augmentation transforms for robustness
# =============================================
class RandomGaussianBlur:
    """Randomly soften a PIL image with a Gaussian blur.

    With probability ``p`` a radius is drawn uniformly from ``radius_range``.
    The blur is applied only when that radius is strictly positive; in every
    other case the input image is returned unchanged.
    """

    def __init__(self, p=0.5, radius_range=(0, 2.5)):
        self.p = p
        self.radius_range = radius_range

    def __call__(self, img):
        # Guard clause: nothing else happens when the coin flip fails.
        if not random.random() < self.p:
            return img
        blur_radius = random.uniform(*self.radius_range)
        if not blur_radius > 0:
            return img
        return img.filter(ImageFilter.GaussianBlur(radius=blur_radius))

class RandomPerspective:
    """Randomly warp an image by moving its four corners.

    With probability ``p`` each corner is shifted by an independent integer
    offset of at most ``distortion_scale * min(width, height)`` pixels per
    axis and the image is warped with ``TF.perspective`` (bilinear).
    Otherwise the image is returned unchanged.
    """

    def __init__(self, p=0.3, distortion_scale=0.2):
        self.p = p
        self.distortion_scale = distortion_scale

    def __call__(self, img):
        if not random.random() < self.p:
            return img
        # Corner order: top-left, top-right, bottom-right, bottom-left.
        source_corners = [
            (0, 0),
            (img.width, 0),
            (img.width, img.height),
            (0, img.height),
        ]
        target_corners = self._get_random_endpoints(img)
        return TF.perspective(
            img,
            startpoints=source_corners,
            endpoints=target_corners,
            interpolation=TF.InterpolationMode.BILINEAR,
        )

    def _get_random_endpoints(self, img):
        """Return the four jittered corners in the same order as the start points."""
        w, h = img.width, img.height
        d = int(min(w, h) * self.distortion_scale)

        def jitter():
            return random.randint(-d, d)

        # x is jittered before y for each corner, corner by corner.
        base_corners = ((0, 0), (w, 0), (w, h), (0, h))
        return [(cx + jitter(), cy + jitter()) for cx, cy in base_corners]

class RandomAffine:
    """Randomly rotate, translate, scale and shear an image.

    With probability ``p`` the parameters are drawn uniformly: the angle from
    ``[-degrees, degrees]``, the translation as a fraction of the image size
    bounded by ``translate``, the scale from the ``scale`` range and the shear
    from ``[-shear, shear]``. They are applied with ``TF.affine`` (bilinear).
    Otherwise the image is returned unchanged.
    """

    def __init__(self, degrees=30, translate=(0.1, 0.1), scale=(0.8, 1.2), shear=10, p=0.7):
        self.degrees = degrees
        self.translate = translate
        self.scale = scale
        self.shear = shear
        self.p = p

    def __call__(self, img):
        if not random.random() < self.p:
            return img
        max_dx, max_dy = self.translate[0], self.translate[1]
        # Draw order: angle, x shift, y shift, scale, shear.
        params = {
            "angle": random.uniform(-self.degrees, self.degrees),
            "translate": (
                random.uniform(-max_dx, max_dx) * img.width,
                random.uniform(-max_dy, max_dy) * img.height,
            ),
            "scale": random.uniform(*self.scale),
            "shear": random.uniform(-self.shear, self.shear),
        }
        return TF.affine(img, interpolation=TF.InterpolationMode.BILINEAR, **params)


In [None]:
class EfficientNetEmbeddingModel(pl.LightningModule):
    """timm backbone plus a projection head, trained with triplet margin loss.

    The backbone is created with ``num_classes=0`` (classifier removed). Its
    features pass through Linear -> BatchNorm1d -> ReLU -> Dropout -> Linear,
    and ``forward`` returns L2-normalised embeddings of size ``embed_dim``.

    Args:
        model_name: timm model identifier used for the backbone.
        lr: AdamW learning rate.
        embed_dim: Size of the output embedding.
        dropout: Dropout probability inside the projection head.
        weight_decay: AdamW weight decay.
    """

    def __init__(self, model_name="efficientnet_b0", lr=1e-4, embed_dim=512, dropout=0.3, weight_decay=0.01):
        super().__init__()
        self.save_hyperparameters()

        # Load EfficientNet-B0 backbone
        self.backbone = timm.create_model(model_name, pretrained=True, num_classes=0)  # Remove classifier
        self.feature_dim = self.backbone.num_features  # 1280 for EfficientNet-B0

        # Two-layer projection head with dropout for regularization
        self.fc1 = nn.Linear(self.feature_dim, self.feature_dim)
        self.bn = nn.BatchNorm1d(self.feature_dim)  # Batch normalization for stability
        self.dropout = nn.Dropout(dropout)  # Dropout to prevent overfitting
        self.fc2 = nn.Linear(self.feature_dim, embed_dim)

        # Loss function
        self.loss_fn = TripletMarginLoss(margin=0.2)
        # Per-step loss values, collected for plotting after training
        self.training_losses = []

        print(f"EfficientNet-B0: feature_dim={self.feature_dim}, embed_dim={embed_dim}, dropout={dropout}")

    def forward(self, x):
        """Return unit-length embeddings for a batch of images."""
        features = self.backbone(x)
        x = self.fc1(features)
        x = self.bn(x)
        x = F.relu(x)
        x = self.dropout(x)
        embeddings = self.fc2(x)
        return F.normalize(embeddings, p=2, dim=1)

    def training_step(self, batch, batch_idx):
        """Embed an (anchor, positive, negative) batch and return the triplet loss."""
        anchor, positive, negative = batch
        emb_a = self(anchor)
        emb_p = self(positive)
        emb_n = self(negative)
        loss = self.loss_fn(emb_a, emb_p, emb_n)
        self.training_losses.append(loss.item())
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def _scheduler_horizon(self, fallback=1000):
        """Return how many optimizer steps the cosine schedule should span.

        Uses the attached trainer's estimate of total optimizer steps so the
        schedule covers the whole run. ``fallback`` (the previous hard-coded
        value) is used when no usable estimate is available.
        """
        try:
            total_steps = self.trainer.estimated_stepping_batches
        except Exception:
            # Best-effort estimate only: never let it block optimizer setup.
            return fallback
        if not isinstance(total_steps, (int, float)):
            return fallback
        if total_steps == float("inf") or total_steps < 1:
            return fallback
        return int(total_steps)

    def configure_optimizers(self):
        """Build AdamW and a per-step cosine annealing schedule."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay  # L2 regularization
        )
        # The scheduler is stepped once per batch ("interval": "step"), so
        # T_max must be the total number of optimizer steps. A shorter T_max
        # would finish the anneal early and the learning rate would not decay
        # monotonically over the run.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=self._scheduler_horizon(),
            eta_min=1e-7
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step"
            }
        }


In [None]:
# Dataset loading functions (same as original)
def get_image_files(root_dir, batch_cache=10000):
    """Collect image paths from every class directory under ``root_dir``.

    Each immediate subdirectory of ``root_dir`` is searched recursively for
    entries whose suffix is ``.jpg``, ``.jpeg`` or ``.png`` (case-insensitive).
    Entries directly inside ``root_dir`` are ignored.

    Args:
        root_dir: Directory whose subdirectories are scanned.
        batch_cache: Print a progress line each time this many files have been
            collected. A falsy value disables progress output.

    Returns:
        A list of path strings, in sorted traversal order so the result does
        not depend on the filesystem's listing order.
    """
    print(f"Scanning for image files in {root_dir} ...")
    exts = {'.jpg', '.jpeg', '.png'}
    all_files = []

    for cls_dir in sorted(Path(root_dir).iterdir()):
        if not cls_dir.is_dir():
            continue
        for f in sorted(cls_dir.glob('**/*')):
            if f.suffix.lower() not in exts:
                continue
            all_files.append(str(f))
            # Report only right after the count changes, so a message is never
            # printed for zero files or repeated while skipping other entries.
            if batch_cache and len(all_files) % batch_cache == 0:
                print(f"Indexed {len(all_files)} files so far...")

    print(f"Done. Found {len(all_files)} files.")
    return all_files

def build_class_index(files, root_dir, max_classes=None):
    """Group image paths by their immediate subfolder under ``root_dir``.

    Args:
        files: Iterable of image path strings located under ``root_dir``.
        root_dir: Dataset root; the first path component below it is the class.
        max_classes: If truthy, keep only this many classes, preferring the
            ones with the most images.

    Returns:
        Dict mapping class name to its list of paths (as given in ``files``).
        Classes with fewer than two images are dropped because a triplet needs
        an anchor and a positive from the same class.

    Raises:
        ValueError: If a path in ``files`` is not located under ``root_dir``.
    """
    # Resolve the root once instead of once per file.
    resolved_root = Path(root_dir).resolve()

    class_to_images = {}
    for f in files:
        parts = Path(f).resolve().relative_to(resolved_root).parts
        if len(parts) < 2:
            continue  # skip images directly in root
        # parts[0] is the immediate subfolder name (e.g., "AR 101")
        class_to_images.setdefault(parts[0], []).append(f)

    # Keep only classes with >= 2 images (needed for triplets)
    class_to_images = {k: v for k, v in class_to_images.items() if len(v) >= 2}

    # Limit number of classes if specified, largest classes first
    if max_classes and len(class_to_images) > max_classes:
        largest_first = sorted(class_to_images.items(), key=lambda kv: len(kv[1]), reverse=True)
        class_to_images = dict(largest_first[:max_classes])

    if not class_to_images:
        print("‚ö†Ô∏è No valid classes found! Checking dataset structure...")
        for dirpath, dirnames, filenames in os.walk(root_dir):
            print(f"{dirpath}: {len(filenames)} files")
            break  # only show top-level
    else:
        print(f"‚úÖ Found {len(class_to_images)} classes")
        for cls, imgs in list(class_to_images.items())[:5]:
            print(f"Class: {cls}, {len(imgs)} images, sample: {imgs[0]}")

    return class_to_images

class ComicTripletDataset(Dataset):
    """Dataset of randomly sampled (anchor, positive, negative) image triplets.

    Anchor and positive are two distinct images of one class; the negative is
    an image of a different class. Triplets are drawn at random on every
    access, so the index passed to ``__getitem__`` is not used.

    Args:
        class_to_images: Dict mapping class name to a list of image paths.
            Every class needs at least two images.
        transform: Optional callable applied to each loaded RGB image.
        samples_per_class: Multiplier that sets the nominal epoch length.

    Raises:
        ValueError: If fewer than two classes are supplied, since no negative
            could ever be drawn.
    """

    def __init__(self, class_to_images, transform=None, samples_per_class=20):
        self.transform = transform
        self.class_to_images = class_to_images
        self.classes = list(self.class_to_images.keys())
        self.samples_per_class = samples_per_class
        if len(self.classes) < 2:
            raise ValueError(
                f"ComicTripletDataset needs at least 2 classes to sample negatives, got {len(self.classes)}"
            )

    def __len__(self):
        # Nominal epoch length: classes * samples_per_class, never below 1000.
        return max(len(self.classes) * self.samples_per_class, 1000)

    def _sample_triplet_paths(self):
        """Return (anchor_path, positive_path, negative_path) chosen at random."""
        num_classes = len(self.classes)
        anchor_idx = random.randrange(num_classes)
        # Draw from the other num_classes - 1 slots and step over the anchor.
        # This is uniform over the other classes without building a filtered
        # list of every class for each sample.
        negative_idx = random.randrange(num_classes - 1)
        if negative_idx >= anchor_idx:
            negative_idx += 1

        anchor_class = self.classes[anchor_idx]
        negative_class = self.classes[negative_idx]
        anchor_path, positive_path = random.sample(self.class_to_images[anchor_class], 2)
        negative_path = random.choice(self.class_to_images[negative_class])
        return anchor_path, positive_path, negative_path

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an RGB image, closing the source file."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, idx):
        paths = self._sample_triplet_paths()
        images = [self._load_rgb(path) for path in paths]
        if self.transform:
            images = [self.transform(img) for img in images]
        anchor_img, positive_img, negative_img = images
        return anchor_img, positive_img, negative_img


In [None]:
# Configuration
# Candidate dataset locations on Kaggle (absolute paths); the first one that exists is used
POSSIBLE_PATHS = [
    "/kaggle/input/inducks-entry-images/covers-by-storycode",
    "/kaggle/input/inducks-entry-images/Inducks entry images/covers-by-storycode",  # With spaces
]

DATASET_PATH = None
for path in POSSIBLE_PATHS:
    abs_path = Path(path).resolve()
    if abs_path.exists():
        DATASET_PATH = str(abs_path)
        print(f"‚úÖ Found dataset at: {DATASET_PATH}")
        break

if DATASET_PATH is None:
    print("‚ö†Ô∏è Dataset path not found. Checking available directories...")
    # List /kaggle/input (two levels deep) to help the user locate the dataset
    if Path("/kaggle/input").exists():
        print("\nAvailable directories in /kaggle/input:")
        for item in Path("/kaggle/input").iterdir():
            print(f"  - {item}")
            if item.is_dir():
                try:
                    for subitem in item.iterdir():
                        print(f"    - {subitem}")
                except OSError:
                    # Diagnostic listing only: skip directories that cannot be read
                    pass
    raise FileNotFoundError(
        f"Could not find dataset. Please set DATASET_PATH manually to an absolute path.\n"
        f"Expected structure: /kaggle/input/<dataset-name>/covers-by-storycode/<storycode-folders>/<images>"
    )

# Ensure DATASET_PATH is absolute
DATASET_PATH = str(Path(DATASET_PATH).resolve())
print(f"üìÅ Using absolute dataset path: {DATASET_PATH}")

# OUTPUT_DIR should also be absolute
OUTPUT_DIR = "/kaggle/working" if Path("/kaggle/working").exists() else str(Path(".").resolve())
OUTPUT_DIR = str(Path(OUTPUT_DIR).resolve())
print(f"üìÅ Using absolute output directory: {OUTPUT_DIR}")

# Training configuration
EPOCHS = 3
BATCH_SIZE = 16
LEARNING_RATE = 1e-4
EMBED_DIM = 512
SAMPLES_PER_CLASS = 20  # Samples per class per epoch

# Limit number of classes for training (to keep training time reasonable)
# Strategy: Use top classes by number of images (these have most examples = better learning)
# The model learns a GENERAL similarity function, so it can detect ANY storycode at inference,
# but training on diverse classes helps generalization
MAX_CLASSES_FOR_TRAINING = 5000  # Adjust based on available compute time
# With 5000 classes × 20 samples = 100k triplets per epoch (~6.25k batches at batch size 16);
# 3 epochs = 300k triplets (~18.75k batches in total)
# This should take ~2-3 hours per epoch instead of 20 hours
#
# Note: This model uses triplet loss for similarity learning, NOT classification.
# At inference, it compares query embeddings against ALL storycodes in the database,
# so it CAN detect storycodes not seen during training. Quality may vary.

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Dataset path and caching (using absolute paths)
# NOTE: the caches are not keyed by DATASET_PATH; delete them if the dataset changes.
CACHE_DIR = OUTPUT_DIR
FILE_CACHE = str(Path(CACHE_DIR) / "file_list.pkl")
CLASS_CACHE = str(Path(CACHE_DIR) / "class_to_images.pkl")


def _select_training_classes(all_classes, max_classes):
    """Return the classes to train on, keeping the ``max_classes`` largest when over the limit."""
    if len(all_classes) > max_classes:
        print(f"\n‚ö†Ô∏è Found {len(all_classes)} classes, limiting to top {max_classes} for training")
        # Sort by number of images (descending) so the best-populated classes are kept
        sorted_classes = sorted(all_classes.items(), key=lambda x: len(x[1]), reverse=True)
        selected = dict(sorted_classes[:max_classes])
        print(f"‚úÖ Using top {len(selected)} classes (by number of images)")
        print(f"   Note: Model learns general similarity, so it can detect ALL storycodes at inference")
        return selected
    print(f"‚úÖ Using all {len(all_classes)} classes")
    return all_classes


# Load or build dataset index
# SECURITY: pickle.load executes arbitrary code; only load cache files this notebook wrote itself.
if Path(CLASS_CACHE).exists():
    with open(CLASS_CACHE, 'rb') as f:
        all_class_to_images = pickle.load(f)
    print(f"Loaded class mapping from cache: {len(all_class_to_images)} total classes")

    # Apply training limit if needed
    class_to_images = _select_training_classes(all_class_to_images, MAX_CLASSES_FOR_TRAINING)
else:
    if Path(FILE_CACHE).exists():
        with open(FILE_CACHE, 'rb') as f:
            all_files = pickle.load(f)
        print(f"Loaded {len(all_files)} file paths from cache")
    else:
        print("Indexing files (first run may take a while)...")
        all_files = get_image_files(DATASET_PATH)
        with open(FILE_CACHE, 'wb') as f:
            pickle.dump(all_files, f)
        print(f"Indexed {len(all_files)} image files and cached to {FILE_CACHE}")

    # Build class index - the full mapping is cached, the training limit is applied afterwards
    all_class_to_images = build_class_index(all_files, DATASET_PATH, max_classes=None)
    with open(CLASS_CACHE, 'wb') as f:
        pickle.dump(all_class_to_images, f)
    print(f"Built class mapping: {len(all_class_to_images)} total classes found")

    # Now limit to top classes for training
    limit_applied = len(all_class_to_images) > MAX_CLASSES_FOR_TRAINING
    class_to_images = _select_training_classes(all_class_to_images, MAX_CLASSES_FOR_TRAINING)
    if limit_applied:
        print(f"   Training on diverse classes helps generalization to unseen storycodes")

num_images = sum(len(v) for v in class_to_images.values())
avg_images_per_class = num_images / len(class_to_images) if len(class_to_images) > 0 else 0
print(f"\nüìä Training Dataset Summary:")
print(f"  - Classes: {len(class_to_images)}")
print(f"  - Total images: {num_images}")
print(f"  - Avg images per class: {avg_images_per_class:.1f}")
# Nominal figures from the configuration; the dataset reports its own length once it is created
print(f"  - Samples per epoch: {len(class_to_images) * SAMPLES_PER_CLASS}")
print(f"  - Estimated batches per epoch: {(len(class_to_images) * SAMPLES_PER_CLASS) // BATCH_SIZE}")


In [None]:
# Training-time augmentation pipeline, aimed at making the embedding robust to:
# - blurred images (RandomGaussianBlur)
# - rotated images (RandomAffine with up to 30° rotation)
# - partially visible frames (RandomResizedCrop with scale variation)
# - perspective distortions (RandomPerspective)

# Steps that operate on the PIL image, in application order
pil_stage = [
    # Resize larger than the final crop so there is room to crop from
    transforms.Resize((256, 256)),
    # Keep 70-100% of the area with mild aspect-ratio variation, output 224x224
    transforms.RandomResizedCrop(size=224, scale=(0.7, 1.0), ratio=(0.8, 1.25)),
    # Color augmentation
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
    # Occasional blur
    RandomGaussianBlur(p=0.5, radius_range=(0, 2.5)),
    # Rotation (up to 30°), translation, scale and shear
    RandomAffine(degrees=30, translate=(0.15, 0.15), scale=(0.75, 1.25), shear=15, p=0.7),
    # Perspective warp
    RandomPerspective(p=0.3, distortion_scale=0.2),
    # Horizontal flip
    transforms.RandomHorizontalFlip(p=0.5),
]

# Steps that operate on the tensor
tensor_stage = [
    transforms.ToTensor(),
    # ImageNet normalization for EfficientNet
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    # Random erasing (cutout) for additional regularization
    transforms.RandomErasing(p=0.2, scale=(0.02, 0.1), ratio=(0.3, 3.3)),
]

transform = transforms.Compose(pil_stage + tensor_stage)

# Create dataset and dataloader
dataset = ComicTripletDataset(class_to_images, transform=transform, samples_per_class=SAMPLES_PER_CLASS)

on_cuda = device.type == 'cuda'
NUM_WORKERS = 4 if on_cuda else 2
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=on_cuda,
    persistent_workers=NUM_WORKERS > 0,
)

print(f"\n‚úÖ Dataset created: {len(dataset)} samples per epoch")
print(f"   Batch size: {BATCH_SIZE}, Workers: {NUM_WORKERS}")
print(f"   Actual batches per epoch: {len(loader)}")


In [None]:
# Build the embedding model; dropout and weight decay are its regularizers
model = EfficientNetEmbeddingModel(
    model_name="efficientnet_b0",
    lr=LEARNING_RATE,
    embed_dim=EMBED_DIM,
    dropout=0.3,
    weight_decay=0.01,
)

# Trainer hardware settings
num_gpus = torch.cuda.device_count()
has_gpu = num_gpus > 0

# A single device is used in both cases (multi-GPU can cause hangs in notebooks)
devices = 1
strategy = "auto"
accelerator = "gpu" if has_gpu else "cpu"
# Mixed precision on GPU, full 32-bit precision on CPU
precision = "16-mixed" if has_gpu else "32-true"

# Custom callback to print progress
class ProgressCallback(pl.Callback):
    """Print the epoch, batch position and loss every 100 training batches."""

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Only batches 0, 100, 200, ... produce output
        if batch_idx % 100 != 0:
            return
        current_epoch = trainer.current_epoch + 1
        total_batches = len(trainer.train_dataloader)
        batch_number = batch_idx + 1
        progress = batch_number / total_batches * 100
        # The step output is either a dict holding the loss or the loss itself
        loss_value = outputs['loss'] if isinstance(outputs, dict) else outputs
        loss = loss_value.item()
        print(f"Epoch {current_epoch}/{EPOCHS}, Batch {batch_number}/{total_batches} ({progress:.1f}%), Loss: {loss:.4f}")

# Summarize the dataset and a rough time estimate before training
batches_per_epoch = len(loader)
estimated_time_per_epoch = batches_per_epoch * 0.3  # assumes ~0.3s per batch
print(f"\nüìä Dataset Info:")
print(f"  - Total samples per epoch: {len(dataset)}")
print(f"  - Batch size: {BATCH_SIZE}")
print(f"  - Batches per epoch: {batches_per_epoch}")
print(f"  - Estimated time per epoch: ~{estimated_time_per_epoch/60:.1f} minutes ({estimated_time_per_epoch:.0f} seconds)")

# Smoke test: pull one batch so data problems surface before the trainer starts
print(f"\nüß™ Testing data loading...")
try:
    test_batch = next(iter(loader))
    batch_shapes = [x.shape for x in test_batch]
    print(f"‚úÖ Data loading works! Batch shape: {batch_shapes}")
except Exception as e:
    print(f"‚ùå Data loading failed: {e}")
    raise

# No validation set is used, so no validation-related options are passed
trainer_options = {
    "max_epochs": EPOCHS,
    "accelerator": accelerator,
    "devices": devices,
    "strategy": strategy,
    "precision": precision,
    "log_every_n_steps": 100,
    "enable_progress_bar": True,
    "enable_model_summary": True,
    # Gradient clipping to prevent exploding gradients
    "gradient_clip_val": 1.0,
    # Periodic progress printing
    "callbacks": [ProgressCallback()],
}
trainer = pl.Trainer(**trainer_options)

total_batches_planned = len(loader) * EPOCHS
print(f"\n‚úÖ Trainer configured: {accelerator}, {devices} device(s), precision={precision}")
print(f"   Will train for {EPOCHS} epochs (~{total_batches_planned} total batches)")
print(f"   Estimated total time: ~{estimated_time_per_epoch * EPOCHS / 60:.1f} minutes")


In [None]:
# Run training, report wall-clock time, then persist the weights and a checkpoint
print(f"\nüöÄ Starting training for {EPOCHS} epochs...")
print(f"   Dataset: {len(dataset)} samples, {len(loader)} batches per epoch")
print(f"   Progress will be printed every 100 batches\n")

import time
start_time = time.time()

try:
    trainer.fit(model, loader)
    elapsed = time.time() - start_time
    elapsed_minutes = elapsed / 60
    print(f"\n‚úÖ Training completed!")
    print(f"   Total time: {elapsed_minutes:.1f} minutes ({elapsed:.1f} seconds)")
    print(f"   Average time per epoch: {elapsed/EPOCHS/60:.1f} minutes")
except KeyboardInterrupt:
    # Manual stop: report how long it ran, then propagate
    elapsed = time.time() - start_time
    elapsed_minutes = elapsed / 60
    print(f"\n‚ö†Ô∏è Training interrupted by user after {elapsed_minutes:.1f} minutes")
    raise
except Exception as e:
    # Any other failure: report the duration and traceback, then propagate
    elapsed = time.time() - start_time
    elapsed_minutes = elapsed / 60
    print(f"\n‚ùå Training failed after {elapsed_minutes:.1f} minutes")
    print(f"   Error: {e}")
    import traceback
    traceback.print_exc()
    raise

output_root = Path(OUTPUT_DIR)

# Plain state dict (weights only)
model_path = str(output_root / "efficientnet_b0_comic_embedding.pt")
torch.save(model.state_dict(), model_path)
print(f"‚úÖ Model saved to {model_path}")

# Trainer checkpoint
checkpoint_path = str(output_root / f"efficientnet_b0_epoch{EPOCHS}.ckpt")
trainer.save_checkpoint(checkpoint_path)
print(f"‚úÖ Checkpoint saved to {checkpoint_path}")


In [None]:
# Plot the per-batch training loss, save the figure and print summary statistics
if model.training_losses:
    import numpy as np

    losses = np.array(model.training_losses)

    # Moving average for a smoother curve
    window_size = 100
    if len(losses) > window_size:
        moving_avg = np.convolve(losses, np.ones(window_size)/window_size, mode='valid')
        # mode='valid' returns len(losses) - window_size + 1 points. Deriving x
        # from that length keeps x and y the same size for any window size;
        # the offset places each average near the middle of its window.
        moving_avg_x = np.arange(len(moving_avg)) + window_size // 2
    else:
        moving_avg = losses
        moving_avg_x = np.arange(len(losses))

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))

    # Plot 1: Full loss with moving average
    ax1.plot(losses, label="Training Loss", alpha=0.3, linewidth=0.5)
    if len(losses) > window_size:
        ax1.plot(moving_avg_x, moving_avg, label=f"Moving Average (window={window_size})",
                linewidth=2, color='red')
    ax1.set_xlabel("Batch")
    ax1.set_ylabel("Loss")
    ax1.set_title("EfficientNet-B0 Training Loss (Full)")
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Plot 2: Last 25% of training to see final convergence
    last_quarter_start = len(losses) // 4 * 3
    ax2.plot(range(last_quarter_start, len(losses)), losses[last_quarter_start:],
            label="Training Loss (Last 25%)", alpha=0.5, linewidth=0.5)
    if len(losses) > window_size:
        last_quarter_mask = moving_avg_x >= last_quarter_start
        ax2.plot(moving_avg_x[last_quarter_mask], moving_avg[last_quarter_mask],
                label=f"Moving Average", linewidth=2, color='red')
    ax2.set_xlabel("Batch")
    ax2.set_ylabel("Loss")
    ax2.set_title("Training Loss - Final 25% (Convergence Check)")
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plot_path = str(Path(OUTPUT_DIR) / "training_loss.png")
    plt.savefig(plot_path, dpi=300, bbox_inches='tight')
    plt.show()
    print(f"‚úÖ Training loss plot saved to {plot_path}")

    # Calculate statistics
    initial_loss = np.mean(losses[:100])
    final_loss = np.mean(losses[-100:])
    min_loss = np.min(losses)
    max_loss = np.max(losses)
    std_loss = np.std(losses)
    final_std = np.std(losses[-1000:]) if len(losses) > 1000 else std_loss
    # Relative reduction from the first to the last 100 batches; reported as 0
    # when the initial loss is 0 so the division cannot produce nan/inf.
    relative_reduction = (initial_loss-final_loss)/initial_loss if initial_loss > 0 else 0.0

    print(f"\nüìä Training Loss Statistics:")
    print(f"  - Initial loss (first 100 batches): {initial_loss:.4f}")
    print(f"  - Final loss (last 100 batches): {final_loss:.4f}")
    print(f"  - Overall improvement: {initial_loss:.4f} ‚Üí {final_loss:.4f} ({(relative_reduction*100):.1f}% reduction)")
    print(f"  - Minimum loss: {min_loss:.4f}")
    print(f"  - Maximum loss: {max_loss:.4f}")
    print(f"  - Overall std dev: {std_loss:.4f}")
    print(f"  - Final std dev (last 1000 batches): {final_std:.4f}")

    # Assessment
    print(f"\nüìà Assessment:")
    if final_loss < 0.1:
        print(f"  ‚úÖ Final loss is low (< 0.1), indicating good learning")
    elif final_loss < 0.15:
        print(f"  ‚ö†Ô∏è Final loss is moderate (0.1-0.15), model may benefit from more training")
    else:
        print(f"  ‚ö†Ô∏è Final loss is high (> 0.15), consider adjusting hyperparameters")

    if final_std < 0.05:
        print(f"  ‚úÖ Low variance in final loss, stable training")
    elif final_std < 0.1:
        print(f"  ‚ö†Ô∏è Moderate variance - acceptable for triplet loss with random sampling")
    else:
        print(f"  ‚ö†Ô∏è High variance - consider reducing learning rate or using smoother sampling")

    if relative_reduction > 0.5:
        print(f"  ‚úÖ Significant improvement (>50% reduction), model learned effectively")
    else:
        print(f"  ‚ö†Ô∏è Limited improvement, may need more training or hyperparameter tuning")


In [None]:
# Export the trained model to ONNX (best effort: a failure is reported, not raised)
try:
    model.eval()
    # Allocate the dummy input on the model's own device so the export trace
    # does not fail with a device mismatch if the module is not on the CPU.
    dummy_input = torch.randn(1, 3, 224, 224, device=model.device)

    onnx_path = str(Path(OUTPUT_DIR) / "efficientnet_b0_comic_embedding.onnx")
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        input_names=["input"],
        output_names=["embedding"],
        # Batch dimension stays dynamic so any batch size can be fed at inference
        dynamic_axes={"input": {0: "batch"}, "embedding": {0: "batch"}},
        opset_version=17,
        export_params=True
    )

    onnx_size_mb = os.path.getsize(onnx_path) / (1024 * 1024)
    print(f"‚úÖ ONNX model exported to {onnx_path} ({onnx_size_mb:.1f} MB)")

    # Compare with original ViT-Base size
    vit_size_mb = 331.1
    size_reduction = (1 - onnx_size_mb / vit_size_mb) * 100
    print(f"üìä Size reduction vs ViT-Base: {size_reduction:.1f}% ({vit_size_mb:.1f}MB ‚Üí {onnx_size_mb:.1f}MB)")

except Exception as e:
    print(f"‚ö†Ô∏è ONNX export failed: {e}")

print("\nüéâ Training completed successfully!")
print(f"üìÅ All outputs saved to: {OUTPUT_DIR}")
