# EfficientNet-B0 Fine-tuning for Comic Reverse Image Search

This notebook fine-tunes EfficientNet-B0 for comic reverse image search using triplet loss.

**Key advantages of EfficientNet-B0:**
- **Small size**: ~18MB ONNX (vs 331MB for ViT-Base)
- **High quality**: 0.136 similarity separation in our tests
- **Fast inference**: ~264ms per batch
- **Proven architecture**: EfficientNet is well-established

**Training approach:**
- Uses triplet loss for similarity learning
- Starts from ImageNet-pretrained weights for better initialization
- ImageNet normalization plus light augmentation (color jitter, small rotation, horizontal flip)
- Mixed-precision (FP16) training when a GPU is available


In [None]:
%pip install torch torchvision pytorch-lightning timm matplotlib tqdm --quiet


In [None]:
import os
import pickle
from pathlib import Path
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TripletMarginLoss
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pytorch_lightning as pl
import timm
import matplotlib.pyplot as plt
from tqdm import tqdm

# Reproducibility: Lightning's seed_everything seeds Python's `random`,
# NumPy and torch in one call.
pl.seed_everything(42)

# Compute device: CUDA when PyTorch reports a usable GPU, CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")


In [None]:
class EfficientNetEmbeddingModel(pl.LightningModule):
    """Image-embedding model: a timm backbone plus a linear projection, trained with triplet loss.

    Images are mapped to L2-normalised ``embed_dim``-dimensional vectors, so the
    dot product of two embeddings equals their cosine similarity.

    Args:
        model_name: timm model identifier used for the backbone.
        lr: Learning rate for AdamW.
        embed_dim: Size of the output embedding.
        margin: Margin of the triplet loss (0.2 keeps the previous behaviour).
        pretrained: Whether timm should load pretrained backbone weights.
    """

    def __init__(self, model_name="efficientnet_b0", lr=1e-4, embed_dim=512, margin=0.2, pretrained=True):
        super().__init__()
        # Stores every constructor argument in self.hparams (and in checkpoints).
        self.save_hyperparameters()

        # num_classes=0 asks timm for the backbone without its classifier head,
        # so it outputs `num_features`-dimensional features.
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        self.feature_dim = self.backbone.num_features  # e.g. 1280 for efficientnet_b0

        # Projection head: backbone features -> embedding space.
        self.fc = nn.Linear(self.feature_dim, embed_dim)

        self.loss_fn = TripletMarginLoss(margin=margin)
        # Per-step loss values (Python floats), kept for plotting after training.
        # NOTE(review): filled only in the process that runs training_step; with a
        # multi-process strategy the notebook's copy likely stays empty — confirm.
        self.training_losses = []

        # Report the actual backbone name rather than a hard-coded one.
        print(f"{model_name}: feature_dim={self.feature_dim}, embed_dim={embed_dim}")

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Image batch accepted by the backbone (presumably (N, 3, H, W) floats).

        Returns:
            Embeddings of shape (N, embed_dim), L2-normalised along dim 1.
        """
        features = self.backbone(x)
        embeddings = self.fc(features)
        return F.normalize(embeddings, p=2, dim=1)

    def training_step(self, batch, batch_idx):
        """Compute the triplet loss for one (anchor, positive, negative) batch."""
        anchor, positive, negative = batch
        emb_a = self(anchor)
        emb_p = self(positive)
        emb_n = self(negative)
        loss = self.loss_fn(emb_a, emb_p, emb_n)
        self.training_losses.append(loss.item())
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """AdamW over all parameters (backbone and projection head) at ``hparams.lr``."""
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)


In [None]:
# Dataset indexing helpers: list image files, then group them by class folder.
def get_image_files(root_dir, batch_cache=10000):
    """Recursively list image files inside each first-level sub-folder of ``root_dir``.

    Files placed directly in ``root_dir`` are ignored. Matching uses the
    lower-cased suffix (.jpg, .jpeg, .png). Directories and files are visited
    in sorted order, so the returned list (and anything derived from its order)
    does not depend on filesystem iteration order.

    Args:
        root_dir: Dataset root whose immediate sub-folders are the classes.
        batch_cache: Progress interval. A message is printed each time this many
            image files have been indexed; 0/None disables it. Despite the name,
            this function caches nothing; callers persist the result themselves.

    Returns:
        List of image file paths as strings.
    """
    print(f"Scanning for image files in {root_dir} ...")
    exts = {'.jpg', '.jpeg', '.png'}
    all_files = []

    for cls_dir in sorted(Path(root_dir).iterdir()):
        if not cls_dir.is_dir():
            continue
        for f in sorted(cls_dir.glob('**/*')):
            if f.suffix.lower() not in exts:
                continue
            all_files.append(str(f))
            # Report only right after a file was added. Checking on every entry
            # printed "0 files" (or repeated counts) for non-image entries.
            if batch_cache and len(all_files) % batch_cache == 0:
                print(f"Indexed {len(all_files)} files so far...")

    print(f"Done. Found {len(all_files)} files.")
    return all_files

def build_class_index(files, root_dir, max_classes=None):
    """Group image paths by their first-level sub-folder under ``root_dir``.

    Args:
        files: Iterable of image paths (str or Path).
        root_dir: Dataset root; the first path component below it is the class name.
        max_classes: If truthy, keep only this many classes, preferring those with
            the most images (equal sizes keep first-seen order). ``None`` or ``0``
            means no limit.

    Returns:
        Dict mapping class name -> list of the given paths. Only classes with at
        least two images are kept, because a triplet needs an anchor and a
        distinct positive from the same class.
    """
    # Resolving touches the filesystem, so do the root only once.
    root_resolved = Path(root_dir).resolve()
    class_to_images = {}
    skipped_outside = 0
    for f in files:
        try:
            rel = Path(f).resolve().relative_to(root_resolved)
        except ValueError:
            # Path is not under root_dir (e.g. a stale cached file list from a
            # different dataset path). Skip it rather than abort the whole index.
            skipped_outside += 1
            continue
        parts = rel.parts
        if len(parts) < 2:
            continue  # file sits directly in root_dir, so it has no class folder
        cls = parts[0]  # immediate subfolder name (e.g., "AR 101")
        class_to_images.setdefault(cls, []).append(f)

    if skipped_outside:
        print(f"⚠️ Skipped {skipped_outside} files outside {root_dir}")

    # Keep only classes with >= 2 images (needed for triplets)
    class_to_images = {k: v for k, v in class_to_images.items() if len(v) >= 2}

    if max_classes and len(class_to_images) > max_classes:
        # sorted() is stable (also with reverse=True): equal-sized classes keep
        # their first-seen order.
        largest_first = sorted(class_to_images.items(), key=lambda kv: len(kv[1]), reverse=True)
        class_to_images = dict(largest_first[:max_classes])

    if len(class_to_images) == 0:
        print("⚠️ No valid classes found! Checking dataset structure...")
        for dirpath, dirnames, filenames in os.walk(root_dir):
            print(f"{dirpath}: {len(filenames)} files")
            break  # only show top-level
    else:
        print(f"✅ Found {len(class_to_images)} classes")
        for cls, imgs in list(class_to_images.items())[:5]:
            print(f"Class: {cls}, {len(imgs)} images, sample: {imgs[0]}")

    return class_to_images

class ComicTripletDataset(Dataset):
    """Randomly sampled (anchor, positive, negative) image triplets.

    Each ``__getitem__`` call draws a fresh triplet. Anchor and positive are two
    distinct images of one class; the negative is an image from a different
    class. The index argument is ignored, so an "epoch" is ``len(self)`` random
    draws made with Python's ``random`` module.

    Args:
        class_to_images: Mapping class name -> list of image paths. Any class
            drawn as an anchor must have at least two images.
        transform: Optional callable applied to each RGB PIL image.

    Raises:
        ValueError: If fewer than two classes are given, so no negative could be drawn.
    """

    def __init__(self, class_to_images, transform=None):
        self.transform = transform
        self.class_to_images = class_to_images
        self.classes = list(self.class_to_images.keys())
        if len(self.classes) < 2:
            # Fail here with a clear message instead of an IndexError raised
            # inside a DataLoader worker on the first sample.
            raise ValueError(
                f"ComicTripletDataset needs at least 2 classes, got {len(self.classes)}"
            )

    def __len__(self):
        # 20 draws per class, with a floor of 1000 samples per epoch.
        return max(len(self.classes) * 20, 1000)

    def _sample_triplet_paths(self):
        """Draw (anchor_path, positive_path, negative_path) using ``random``."""
        n = len(self.classes)
        anchor_idx = random.randrange(n)
        # Offset by 1..n-1 (mod n): uniform over the *other* classes in O(1),
        # instead of rebuilding an (n-1)-element list for every sample.
        negative_idx = (anchor_idx + random.randrange(1, n)) % n
        anchor_class = self.classes[anchor_idx]
        negative_class = self.classes[negative_idx]

        anchor_path, positive_path = random.sample(self.class_to_images[anchor_class], 2)
        negative_path = random.choice(self.class_to_images[negative_class])
        return anchor_path, positive_path, negative_path

    def _load(self, path):
        """Open ``path`` as RGB (closing the file handle), then apply the transform."""
        with Image.open(path) as img:
            # convert() yields an independent, fully loaded image, so it stays
            # valid after the source file is closed.
            rgb = img.convert("RGB")
        return self.transform(rgb) if self.transform else rgb

    def __getitem__(self, idx):
        anchor_path, positive_path, negative_path = self._sample_triplet_paths()
        return self._load(anchor_path), self._load(positive_path), self._load(negative_path)


In [None]:
# Configuration
# NOTE(review): on Kaggle, mounted inputs live at the absolute path
# "/kaggle/input/..."; this relative path only resolves if the working
# directory contains a "kaggle/" folder — confirm for your environment.
DATASET_PATH = "kaggle/input/inducks-entry-images/covers-by-storycode"
OUTPUT_DIR = "."
MAX_CLASSES = 1000  # Keep only the N largest classes; adjust to your dataset size
EPOCHS = 3
BATCH_SIZE = 16
LEARNING_RATE = 1e-4
EMBED_DIM = 512

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Index caches.
# - The class mapping depends on MAX_CLASSES, so MAX_CLASSES is part of its
#   filename; previously, changing it silently reused a stale mapping.
# - Delete file_list.pkl after changing DATASET_PATH.
# - These pickles are written by this notebook. Never load untrusted pickle
#   files: unpickling can execute arbitrary code.
CACHE_DIR = OUTPUT_DIR
FILE_CACHE = os.path.join(CACHE_DIR, "file_list.pkl")
CLASS_CACHE = os.path.join(CACHE_DIR, f"class_to_images_max{MAX_CLASSES}.pkl")

# Load or build dataset index
if Path(CLASS_CACHE).exists():
    with open(CLASS_CACHE, 'rb') as f:
        class_to_images = pickle.load(f)
    print(f"Loaded class mapping from cache: {len(class_to_images)} classes")
else:
    if Path(FILE_CACHE).exists():
        with open(FILE_CACHE, 'rb') as f:
            all_files = pickle.load(f)
        print(f"Loaded {len(all_files)} file paths from cache")
    else:
        print("Indexing files (first run may take a while)...")
        all_files = get_image_files(DATASET_PATH)
        if all_files:
            with open(FILE_CACHE, 'wb') as f:
                pickle.dump(all_files, f)
            print(f"Indexed {len(all_files)} image files and cached to {FILE_CACHE}")
        else:
            # An empty index (e.g. wrong DATASET_PATH) would otherwise be
            # reloaded from cache on every later run.
            print(f"⚠️ No image files found under {DATASET_PATH}; not caching the empty file list")

    class_to_images = build_class_index(all_files, DATASET_PATH, MAX_CLASSES)
    if class_to_images:
        with open(CLASS_CACHE, 'wb') as f:
            pickle.dump(class_to_images, f)
        print(f"Built class mapping: {len(class_to_images)} classes cached to {CLASS_CACHE}")
    else:
        print("⚠️ Class mapping is empty; not caching it")

num_images = sum(len(v) for v in class_to_images.values())
print(f"Ready: {len(class_to_images)} classes, {num_images} images with >=2 per class.")


In [None]:
# Training preprocessing: fixed 224x224 resize, mild augmentation, then the
# ImageNet mean/std normalisation matching the ImageNet-pretrained backbone.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        # Mild augmentation: colour jitter, up to +/-5 degree rotation, random mirror.
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.RandomRotation(5),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Triplet dataset over the class index plus its loader.
dataset = ComicTripletDataset(class_to_images, transform=transform)

NUM_WORKERS = 4 if device.type == 'cuda' else 2
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    # Page-locked host memory only pays off when batches are copied to a GPU.
    pin_memory=device.type == 'cuda',
    # Reuse worker processes across epochs; PyTorch requires num_workers > 0 for this.
    persistent_workers=NUM_WORKERS > 0,
)

print(f"Dataset created: {len(dataset)} samples per epoch")
print(f"Batch size: {BATCH_SIZE}, Workers: {NUM_WORKERS}")


In [None]:
# Create model
model = EfficientNetEmbeddingModel(
    model_name="efficientnet_b0",
    lr=LEARNING_RATE,
    embed_dim=EMBED_DIM,
)

# Trainer setup: GPU(s) with automatic mixed precision when available, else CPU fp32.
num_gpus = torch.cuda.device_count()

if num_gpus > 0:
    accelerator = "gpu"
    # NOTE(review): with 2 GPUs inside a notebook, Lightning launches a
    # multi-process strategy. State kept on the module object (e.g.
    # model.training_losses) is likely not copied back to this process —
    # confirm, or use devices=1 if you need the loss plot.
    devices = min(num_gpus, 2)  # Use up to 2 GPUs
    strategy = "auto"
    # Lightning 2.x name for AMP; the bare integer 16 is a legacy alias that
    # emits a deprecation warning.
    precision = "16-mixed"
else:
    accelerator = "cpu"
    devices = 1
    strategy = "auto"
    precision = 32

# No validation loop is defined (no val dataloader or validation_step), so no
# val_check_interval is configured.
trainer = pl.Trainer(
    max_epochs=EPOCHS,
    accelerator=accelerator,
    devices=devices,
    strategy=strategy,
    precision=precision,
    log_every_n_steps=10,
    enable_progress_bar=True,
    enable_model_summary=True,
)

print(f"Trainer configured: {accelerator}, {devices} devices, precision={precision}")


In [None]:
# Output locations for the weights-only file and the Lightning checkpoint.
model_path = os.path.join(OUTPUT_DIR, "efficientnet_b0_comic_embedding.pt")
checkpoint_path = os.path.join(OUTPUT_DIR, f"efficientnet_b0_epoch{EPOCHS}.ckpt")

print(f"\n🚀 Starting training for {EPOCHS} epochs...")
trainer.fit(model, loader)

# Weights only: load with EfficientNetEmbeddingModel(...).load_state_dict(...).
torch.save(model.state_dict(), model_path)
print(f"✅ Model saved to {model_path}")

# Lightning checkpoint; it also carries the hyperparameters stored by
# save_hyperparameters().
trainer.save_checkpoint(checkpoint_path)
print(f"✅ Checkpoint saved to {checkpoint_path}")


In [None]:
# Plot the per-step training loss recorded on the model.
losses = model.training_losses
if losses:
    # Per-batch triplet loss is noisy. Also compute a moving average with a
    # window of ~2% of the run (window 1, i.e. no smoothing, for short runs).
    window = max(1, len(losses) // 50)
    prefix_sums = [0.0]
    for value in losses:
        prefix_sums.append(prefix_sums[-1] + value)
    smoothed = [
        (prefix_sums[i + window] - prefix_sums[i]) / window
        for i in range(len(losses) - window + 1)
    ]

    plt.figure(figsize=(10, 6))
    plt.plot(losses, label="Training Loss", alpha=0.7)
    if window > 1:
        plt.plot(
            range(window - 1, len(losses)),
            smoothed,
            label=f"Moving average ({window} steps)",
        )
    plt.xlabel("Batch")
    plt.ylabel("Loss")
    plt.title("EfficientNet-B0 Training Loss")
    plt.legend()
    plt.grid(True, alpha=0.3)

    plot_path = os.path.join(OUTPUT_DIR, "training_loss.png")
    plt.savefig(plot_path, dpi=300, bbox_inches='tight')
    plt.show()
    print(f"✅ Training loss plot saved to {plot_path}")

    print(f"\n📊 Final training loss: {losses[-1]:.4f}")
    # Compare averaged windows rather than two single, noisy batches.
    print(f"📈 Loss improvement: {smoothed[0]:.4f} → {smoothed[-1]:.4f} (mean over {window} step(s))")
else:
    print(
        "⚠️ No training losses recorded in this process; skipping the loss plot. "
        "(With multi-process GPU training they may have been recorded in worker processes only.)"
    )


In [None]:
# Export to ONNX. This is best-effort: a failure is reported, not raised, but
# the final summary no longer claims success when the export failed.
onnx_exported = False
try:
    model.eval()  # inference-mode BatchNorm/Dropout in the exported graph
    # Create the example input on the model's own device so export works
    # whether the weights are on CPU or GPU.
    dummy_input = torch.randn(1, 3, 224, 224, device=model.device)

    onnx_path = os.path.join(OUTPUT_DIR, "efficientnet_b0_comic_embedding.onnx")
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        input_names=["input"],
        output_names=["embedding"],
        dynamic_axes={"input": {0: "batch"}, "embedding": {0: "batch"}},
        opset_version=17,
        export_params=True
    )
    onnx_exported = True

    onnx_size_mb = os.path.getsize(onnx_path) / (1024 * 1024)
    print(f"✅ ONNX model exported to {onnx_path} ({onnx_size_mb:.1f} MB)")

    # Compare with original ViT-Base size
    vit_size_mb = 331.1
    size_reduction = (1 - onnx_size_mb / vit_size_mb) * 100
    print(f"📊 Size reduction vs ViT-Base: {size_reduction:.1f}% ({vit_size_mb:.1f}MB → {onnx_size_mb:.1f}MB)")

except Exception as e:
    print(f"⚠️ ONNX export failed: {e}")

if onnx_exported:
    print("\n🎉 Training completed successfully!")
else:
    print("\n⚠️ Training finished, but the ONNX export did not succeed (see message above).")
print(f"📁 All outputs saved to: {OUTPUT_DIR}")
