In [None]:
!pip install torch torchvision pytorch-lightning pytorch-metric-learning transformers --quiet


# Comic Reverse Image Search — Large Dataset Fine-Tuning

This notebook fine-tunes a ViT-based embedding model with triplet loss for reverse image search on comic images.

**Optimized for large datasets (~500k files):**
- Single-pass, **cached file indexing** (avoids repeated, slow recursive scans)
- **Recursive subdirectory** support
- **Single GPU by default**; optional **DDP** toggle
- Mixed precision (FP16) for speed
- Live loss in progress bar + **loss plot** after training
- **ONNX export** for TypeScript/Node.js use

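👉 Set `DATASET_PATH` to your dataset root: the directory whose immediate subfolders are the groups (each subfolder is treated as one class, and images in nested subfolders count toward it).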


In [None]:
import os, pickle
from pathlib import Path
import glob, random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TripletMarginLoss
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pytorch_lightning as pl
from transformers import AutoImageProcessor, AutoModel, ViTImageProcessor, ViTModel
import matplotlib.pyplot as plt
import glob
from tqdm import tqdm  # for a progress bar

# Fix the global seed through Lightning's helper so that runs are repeatable.
pl.seed_everything(42)

# =============================================
# 1) Fast, cached file indexing (class = immediate subfolder)
# =============================================
def get_image_files(root_dir, batch_cache=10000):
    """Collect image paths from every class folder directly under ``root_dir``.

    Each immediate subdirectory of ``root_dir`` is searched recursively;
    entries sitting directly in ``root_dir`` are ignored. An entry counts as
    an image when its suffix (case-insensitive) is ``.jpg``, ``.jpeg`` or
    ``.png``.

    Args:
        root_dir: Dataset root whose immediate subdirectories are scanned.
        batch_cache: Print a progress line each time the number of collected
            files reaches a multiple of this value. A falsy value (``0`` or
            ``None``) disables progress output.

    Returns:
        A list of matching paths as strings, in directory-iteration order.
    """
    print(f"Scanning for image files in {root_dir} ...")
    exts = {'.jpg', '.jpeg', '.png'}
    all_files = []

    for cls_dir in Path(root_dir).iterdir():
        if not cls_dir.is_dir():
            continue
        for f in cls_dir.glob('**/*'):
            if f.suffix.lower() not in exts:
                continue
            all_files.append(str(f))
            # Report only right after a file was added; checking on every
            # globbed entry would repeat the same milestone (or "0 files")
            # for each non-image entry encountered.
            if batch_cache and len(all_files) % batch_cache == 0:
                print(f"Indexed {len(all_files)} files so far...")

    print(f"Done. Found {len(all_files)} files.")
    return all_files

def build_class_index(files, root_dir):
    """Group image paths by their immediate subfolder under ``root_dir``.

    Args:
        files: Iterable of image paths located under ``root_dir``.
        root_dir: Dataset root; the first path component below it is used as
            the class name.

    Returns:
        A dict mapping class name to the list of its image paths (as given in
        ``files``). Classes with fewer than two images are dropped because a
        triplet needs an anchor and a positive from the same class.

    Raises:
        ValueError: If a path in ``files`` is not located under ``root_dir``.
    """
    # Resolve the root once instead of once per file.
    root = Path(root_dir).resolve()
    class_to_images = {}
    for f in files:
        parts = Path(f).resolve().relative_to(root).parts
        if len(parts) < 2:
            continue  # skip images directly in root
        # parts[0] is the immediate subfolder name (e.g., "AR 101")
        class_to_images.setdefault(parts[0], []).append(f)

    # Keep only classes with >= 2 images (needed for triplets)
    class_to_images = {k: v for k, v in class_to_images.items() if len(v) >= 2}

    if not class_to_images:
        print("⚠️ No valid classes found! Checking dataset structure...")
        # Inspect the directory that was actually passed in (not a global),
        # and only its top level.
        top_level = next(os.walk(root_dir), None)
        if top_level is not None:
            dirpath, _dirnames, filenames = top_level
            print(f"{dirpath}: {len(filenames)} files")
    else:
        print(f"✅ Found {len(class_to_images)} classes")
        for cls, imgs in list(class_to_images.items())[:5]:
            print(f"Class: {cls}, {len(imgs)} images, sample: {imgs[0]}")

    return class_to_images

# =============================================
# 2) Dataset using pre-indexed class mapping
# =============================================
class ComicTripletDataset(Dataset):
    """Randomly sampled (anchor, positive, negative) image triplets.

    Args:
        class_to_images: Mapping from class name to a list of image paths.
            Every class needs at least two images, and at least two classes
            are needed to draw a triplet.
        transform: Optional callable applied to each loaded RGB image.

    The index passed to ``__getitem__`` is ignored: every access draws a fresh
    random triplet, so ``__len__`` only controls how many triplets make up an
    epoch.
    """

    def __init__(self, class_to_images, transform=None):
        self.transform = transform
        self.class_to_images = class_to_images
        self.classes = list(self.class_to_images.keys())

    def __len__(self):
        # You can scale this; here we do 10 * num_classes for balanced sampling
        return max(len(self.classes) * 10, 1)

    @staticmethod
    def _load_rgb(path):
        """Open ``path``, return an RGB copy, and close the file handle."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, idx):
        """Return a random ``(anchor, positive, negative)`` triplet.

        Raises:
            ValueError: If fewer than two classes are available.
        """
        if len(self.classes) < 2:
            raise ValueError(
                "ComicTripletDataset needs at least 2 classes to draw a "
                f"negative sample, got {len(self.classes)}"
            )

        # Two distinct classes in one draw; avoids building a filtered copy of
        # the whole class list for every sample.
        anchor_class, negative_class = random.sample(self.classes, 2)

        anchor_path, positive_path = random.sample(self.class_to_images[anchor_class], 2)
        negative_path = random.choice(self.class_to_images[negative_class])

        images = [self._load_rgb(p) for p in (anchor_path, positive_path, negative_path)]

        if self.transform is not None:
            images = [self.transform(img) for img in images]

        return tuple(images)

# =============================================
# 3) Model (ViT backbone + projection head) with loss tracking
# =============================================
class ComicEmbeddingModel(pl.LightningModule):
    """ViT backbone plus a linear projection head, trained with triplet margin loss.

    Args:
        model_name: Hugging Face identifier of the pretrained ViT checkpoint.
        lr: Learning rate handed to AdamW.
        embed_dim: Size of the produced embedding vectors.

    The per-step loss values are also collected in ``training_losses`` so they
    can be plotted once training has finished.
    """

    def __init__(self, model_name="google/vit-base-patch16-224", lr=2e-5, embed_dim=512):
        super().__init__()
        self.save_hyperparameters()
        self.processor = ViTImageProcessor.from_pretrained(model_name, use_fast=True)
        self.backbone = ViTModel.from_pretrained(model_name, add_pooling_layer=True)
        hidden_size = self.backbone.config.hidden_size
        self.fc = nn.Linear(hidden_size, embed_dim)
        self.loss_fn = TripletMarginLoss(margin=0.2)
        self.training_losses = []

    def forward(self, x):
        """Map a batch of image tensors to L2-normalized embeddings."""
        pooled = self.backbone(x).pooler_output
        projected = self.fc(pooled)
        return F.normalize(projected, p=2, dim=1)

    def training_step(self, batch, batch_idx):
        """Embed the three triplet members, then compute, record and log the loss."""
        anchor, positive, negative = batch
        emb_a, emb_p, emb_n = map(self, (anchor, positive, negative))
        loss = self.loss_fn(emb_a, emb_p, emb_n)
        self.training_losses.append(loss.item())
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters at the configured learning rate."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)
        return optimizer

# =============================================
# 4) User-config: dataset path & caching
# =============================================
DATASET_PATH = "/kaggle/input/inducks-entry-images/covers-by-storycode"
CACHE_DIR = "/kaggle/working" if Path('/kaggle/working').exists() else "."
FILE_CACHE = os.path.join(CACHE_DIR, "file_list.pkl")
CLASS_CACHE = os.path.join(CACHE_DIR, "class_to_images.pkl")

# To force a fresh index, remove both cache files first:
# !rm -f /kaggle/working/file_list.pkl /kaggle/working/class_to_images.pkl

# Two cache levels: the finished class mapping is preferred; failing that the
# raw file list is reused; only when neither exists is the dataset scanned.
if not Path(CLASS_CACHE).exists():
    if not Path(FILE_CACHE).exists():
        print("Indexing files (first run may take a while)...")
        all_files = get_image_files(DATASET_PATH)
        with open(FILE_CACHE, 'wb') as cache_fh:
            pickle.dump(all_files, cache_fh)
        print(f"Indexed {len(all_files)} image files and cached to {FILE_CACHE}")
    else:
        with open(FILE_CACHE, 'rb') as cache_fh:
            all_files = pickle.load(cache_fh)
        print(f"Loaded {len(all_files)} file paths from cache")

    class_to_images = build_class_index(all_files, DATASET_PATH)
    with open(CLASS_CACHE, 'wb') as cache_fh:
        pickle.dump(class_to_images, cache_fh)
    print(f"Built class mapping: {len(class_to_images)} classes cached to {CLASS_CACHE}")
else:
    with open(CLASS_CACHE, 'rb') as cache_fh:
        class_to_images = pickle.load(cache_fh)
    print(f"Loaded class mapping from cache: {len(class_to_images)} classes")

num_images = sum(map(len, class_to_images.values()))
print(f"Ready: {len(class_to_images)} classes, {num_images} images with >=2 per class.")

# =============================================
# 5) Transforms, Dataset, DataLoader (I/O-friendly settings)
# =============================================
# Augmentation + preprocessing applied to every triplet member.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomRotation(5),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Train on a random subset of at most `max_classes` classes. The sample size
# is capped at the number of available classes: random.sample raises
# ValueError when asked for more items than the population holds.
max_classes = 5000
num_sampled_classes = min(max_classes, len(class_to_images))
sampled_classes = dict(random.sample(list(class_to_images.items()), num_sampled_classes))
dataset = ComicTripletDataset(sampled_classes, transform=transform)

# Try modest workers to avoid oversubscribing I/O on Kaggle
BATCH_SIZE = 16
NUM_WORKERS = 2
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    # Pinned host memory only helps host-to-GPU copies.
    pin_memory=torch.cuda.is_available(),
    # DataLoader rejects persistent_workers=True when num_workers == 0.
    persistent_workers=NUM_WORKERS > 0,
)

# =============================================
# 6) Trainer: single GPU by default, DDP optional
# =============================================
num_gpus = torch.cuda.device_count()
USE_DDP = False  # Set to True AFTER you confirm fast startup with single GPU

model = ComicEmbeddingModel(lr=2e-5)

if num_gpus > 0:
    accelerator = "gpu"
    precision = 16
    if USE_DDP and num_gpus > 1:
        # Honor the toggle: use every visible GPU with the notebook-safe DDP
        # launcher (plain "ddp" can break in notebooks).
        # NOTE(review): with a multi-process strategy, values appended to
        # model.training_losses inside worker processes may not be visible in
        # this process after fit() — confirm before relying on the loss plot.
        devices = num_gpus
        strategy = "ddp_notebook"
    else:
        devices = 1   # ✅ stick to 1 GPU for Kaggle stability
        strategy = "auto"
else:
    accelerator = "cpu"
    devices = 1
    strategy = "auto"
    precision = 32

trainer = pl.Trainer(
    max_epochs=1,              # just 1 epoch per run
    accelerator=accelerator,
    devices=devices,
    strategy=strategy,
    precision=precision,
    limit_train_batches=0.1,   # train on only 10% of batches
    log_every_n_steps=5
)
trainer.fit(model, loader)

# Save the bare weights (reloaded by the ONNX export cell) and a full
# Lightning checkpoint.
OUT_DIR = "/kaggle/working" if Path('/kaggle/working').exists() else "."
pt_path = os.path.join(OUT_DIR, "comic_embedding_model.pt")
torch.save(model.state_dict(), pt_path)
print(f"Model saved to {pt_path}")

trainer.save_checkpoint(os.path.join(OUT_DIR, "comic_model_epoch1.ckpt"))

# =============================================
# 7) Plot training loss
# =============================================
# Draw the per-batch losses collected on the model during training.
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(model.training_losses, label="Training Loss")
ax.set_xlabel("Batch")
ax.set_ylabel("Loss")
ax.set_title("Training Loss Over Time")
ax.legend()
ax.grid(True)
plt.show()


In [None]:
# =============================================
# 8) Export to ONNX for TypeScript/Node.js
# =============================================
import os
import torch
from pathlib import Path

OUT_DIR = "/kaggle/working" if Path('/kaggle/working').exists() else "."
pt_path = os.path.join(OUT_DIR, "comic_embedding_model.pt")
onnx_path = os.path.join(OUT_DIR, "comic_embedding_model.onnx")

# Rebuild the architecture with default hyperparameters, then load the
# trained weights saved by the training cell.
export_model = ComicEmbeddingModel()
# The file holds a plain state_dict, so restrict unpickling to tensors and
# primitive containers.
state_dict = torch.load(pt_path, map_location="cpu", weights_only=True)
export_model.load_state_dict(state_dict)
export_model.eval()

# Example input used for export; the batch dimension is declared dynamic below.
dummy_input = torch.randn(1, 3, 224, 224)

torch.onnx.export(
    export_model,
    dummy_input,
    onnx_path,
    input_names=["input"],
    output_names=["embedding"],
    dynamic_axes={"input": {0: "batch"}, "embedding": {0: "batch"}},
    opset_version=17,
)
print(f"ONNX model saved to {onnx_path}")


## Tips for very large datasets
- Keep `USE_DDP=False` for the first successful run; turn it on only if startup is fast and stable.
- Consider packing images into **WebDataset (.tar)** shards for faster sequential I/O.
- Increase `BATCH_SIZE` if you have headroom (watch GPU memory).
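- The notebook already writes `file_list.pkl` and `class_to_images.pkl` to the cache directory; keep those files and reuse them across sessions to skip re-indexing, and delete them whenever `DATASET_PATH` changes, because an existing cache is loaded without checking which dataset it was built from.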