# TCG Card Embedding Training

Train a FastViT-T12 embedding model with ArcFace loss for card recognition.

**Features:**
- Heavy augmentation for single-sample-per-class learning
- Multi-view training (4 augmented views per card)
- Memory bank for cross-batch hard negative mining
- Recall@K validation metrics

**Prerequisites:**
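- Card images in Google Drive under `MyDrive/tcg-scanner/ml/data/images/riftbound` (optionally with a manifest at `MyDrive/tcg-scanner/ml/data/processed/riftbound/training_manifest.json` that maps each image to its card ID and name)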

**Estimated Time:** ~10-12 hours for full training

## 1. Setup Environment

In [None]:
# Check GPU: list the NVIDIA device(s) attached to this runtime, if any
!nvidia-smi

In [None]:
# Install dependencies
# NOTE(review): no versions are pinned, so every run installs whatever is
# current on PyPI — consider pinning for reproducible training.
!pip install -q torch torchvision --upgrade
!pip install -q timm lightning pytorch-metric-learning
!pip install -q albumentations pyyaml tqdm pillow
!pip install -q wandb

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Set paths
# DRIVE_PROJECT: project folder on the mounted Drive (source data, saved models)
# WORK_DIR: local working directory used for the prepared dataset and run outputs
DRIVE_PROJECT = '/content/drive/MyDrive/tcg-scanner'
WORK_DIR = '/content/tcg-scanner'

import os
# Create the working directory if needed and make it the current directory
os.makedirs(WORK_DIR, exist_ok=True)
os.chdir(WORK_DIR)
print(f"Working directory: {os.getcwd()}")

In [None]:
# Optional: Initialize Weights & Biases
# Only the import runs by default; logging stays off until the calls below
# are uncommented.
import wandb

# Uncomment and run to enable wandb logging
# wandb.login()
# wandb.init(project='tcg-scanner', name='embedding-fastvit-t12')

## 2. Prepare Data

Organize card images into class directories (one folder per card).

In [None]:
import shutil
from pathlib import Path
import json
import random

# Source data
cards_src = Path(DRIVE_PROJECT) / 'ml/data/images/riftbound'
manifest_path = Path(DRIVE_PROJECT) / 'ml/data/processed/riftbound/training_manifest.json'

# Load manifest (falls back to None so later cells use the directory layout)
if manifest_path.exists():
    # Read as UTF-8 explicitly instead of relying on the runtime's locale
    # encoding, so non-ASCII card names decode the same way everywhere.
    with open(manifest_path, encoding='utf-8') as f:
        manifest = json.load(f)
    print(f"Loaded manifest with {len(manifest)} cards")
else:
    print("Manifest not found - using directory structure")
    manifest = None

In [None]:
def _safe_file_stem(name):
    """Replace path separators so a card name cannot point outside its class folder."""
    return str(name).replace('/', '_')


def prepare_embedding_dataset(src_dir, output_dir, manifest=None, train_ratio=0.85):
    """Copy card images into a ``<split>/<card id>/<name>.jpg`` folder tree.

    Images are shuffled with a fixed seed and split per image (not per card
    id) into ``train`` and ``val``. Files that already exist at the
    destination are not copied again.

    Args:
        src_dir: Directory searched recursively for ``*.jpg`` files when no
            manifest is given; the card id is the file stem up to the first
            underscore.
        output_dir: Directory that receives the ``train`` and ``val`` trees.
        manifest: Optional iterable of entries with ``image_path``,
            ``product_id`` and ``clean_name`` keys. ``image_path`` is
            resolved against ``DRIVE_PROJECT / 'ml/data'``; entries whose
            file does not exist are skipped.
        train_ratio: Fraction of the images placed in the training split.

    Returns:
        Tuple ``(train_dir, val_dir)`` of ``Path`` objects.
    """
    src_dir = Path(src_dir)
    output_dir = Path(output_dir)

    train_dir = output_dir / 'train'
    val_dir = output_dir / 'val'

    # Create both split roots up front: with no images (or an empty split)
    # they would otherwise never exist and the class count below would raise.
    for split_root in (train_dir, val_dir):
        split_root.mkdir(parents=True, exist_ok=True)

    # Collect all images
    if manifest:
        # Use manifest for proper card-to-image mapping
        data_root = Path(DRIVE_PROJECT) / 'ml/data'
        cards = []
        for entry in manifest:
            img_path = data_root / entry['image_path']
            if img_path.exists():
                cards.append({
                    'path': img_path,
                    'id': str(entry['product_id']),
                    'name': entry['clean_name'],
                })
    else:
        # Fall back to directory structure. rglob() yields files in
        # filesystem order, so sort first to make the seeded split reproducible.
        cards = [
            {
                'path': img_path,
                'id': img_path.stem.split('_')[0],
                'name': img_path.stem,
            }
            for img_path in sorted(src_dir.rglob('*.jpg'))
        ]

    print(f"Found {len(cards)} card images")

    # Shuffle and split
    random.seed(42)
    random.shuffle(cards)

    split_idx = int(len(cards) * train_ratio)
    train_cards = cards[:split_idx]
    val_cards = cards[split_idx:]

    print(f"Train: {len(train_cards)} cards, Val: {len(val_cards)} cards")

    # Create class directories and copy files
    for split_cards, split_dir in ((train_cards, train_dir), (val_cards, val_dir)):
        for card in split_cards:
            class_dir = split_dir / card['id']
            class_dir.mkdir(parents=True, exist_ok=True)

            # A '/' in a manifest name would otherwise be read as a sub-directory
            dst_path = class_dir / f"{_safe_file_stem(card['name'])}.jpg"
            if not dst_path.exists():
                shutil.copy(card['path'], dst_path)

    # Count classes (one directory entry per card id)
    train_classes = len(list(train_dir.iterdir()))
    val_classes = len(list(val_dir.iterdir()))

    print(f"\nDataset prepared:")
    print(f"  Train: {train_classes} classes in {train_dir}")
    print(f"  Val: {val_classes} classes in {val_dir}")

    return train_dir, val_dir

In [None]:
# Prepare dataset
train_dir, val_dir = prepare_embedding_dataset(
    src_dir=cards_src,
    output_dir=Path(WORK_DIR) / 'data/embedding',
    manifest=manifest,
    train_ratio=0.85
)

## 3. Define Model and Training Components

In [None]:
import albumentations as A
import lightning as L
import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from albumentations.pytorch import ToTensorV2
from PIL import Image
from pytorch_metric_learning import losses, miners
from torch.utils.data import DataLoader, Dataset, Sampler
from tqdm import tqdm

# Configuration: single source of hyperparameters read by the model,
# the Lightning module, the datasets/dataloaders and the trainer below.
CONFIG = {
    'model': {
        'backbone': 'fastvit_t12',  # model name handed to timm.create_model
        'embedding_dim': 384,  # output size of the embedding head / ArcFace embedding_size
        'dropout': 0.2,  # dropout inside the embedding head
    },
    'training': {
        'epochs': 100,  # Trainer max_epochs and cosine schedule T_max
        'batch_size': 64,  # used by both train and val DataLoaders
        'learning_rate': 0.0003,  # AdamW lr
        'weight_decay': 0.05,  # AdamW weight_decay
        'views_per_card': 4,  # augmented views per image in the training dataset
    },
    'metric_learning': {
        'margin': 0.5,  # ArcFaceLoss margin
        'scale': 64,  # ArcFaceLoss scale
        'mining_margin': 0.3,  # TripletMarginMiner margin
    },
}

In [None]:
def get_train_transforms():
    """Build the training augmentation pipeline.

    Stages run in this order: geometric warps, lighting, shadows, blur,
    noise, JPEG compression, occlusion, then resize / random crop / flip /
    normalize / tensor conversion.
    """
    geometric = [
        A.Perspective(scale=(0.05, 0.15), p=0.5),
        A.Affine(rotate=(-20, 20), shear=(-10, 10), scale=(0.8, 1.2), p=0.8),
    ]

    # Exactly one lighting change is picked when this stage fires
    lighting = A.OneOf([
        A.RandomBrightnessContrast(brightness_limit=0.4, contrast_limit=0.4, p=1.0),
        A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=30, val_shift_limit=30, p=1.0),
        A.RandomGamma(gamma_limit=(60, 140), p=1.0),
    ], p=0.7)

    shadows = A.RandomShadow(shadow_roi=(0, 0, 1, 1), p=0.3)

    blur = A.OneOf([
        A.GaussianBlur(blur_limit=(3, 7), p=1.0),
        A.MotionBlur(blur_limit=7, p=1.0),
    ], p=0.4)

    noise = A.OneOf([
        A.GaussNoise(var_limit=(10, 80), p=1.0),
        A.ISONoise(p=1.0),
    ], p=0.5)

    compression = A.ImageCompression(quality_lower=60, quality_upper=100, p=0.5)

    occlusion = A.CoarseDropout(max_holes=3, max_height=40, max_width=40, p=0.3)

    # Final sizing and tensor conversion.
    # NOTE(review): HorizontalFlip mirrors card text/art — confirm that
    # mirror-invariance is wanted for card recognition.
    finalize = [
        A.Resize(256, 256),
        A.RandomCrop(224, 224),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]

    return A.Compose([
        *geometric,
        lighting,
        shadows,
        blur,
        noise,
        compression,
        occlusion,
        *finalize,
    ])

def get_val_transforms():
    """Build the evaluation pipeline: resize, center crop, normalize, to tensor.

    No random transform is included here.
    """
    resize = A.Resize(256, 256)
    center_crop = A.CenterCrop(224, 224)
    normalize = A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    to_tensor = ToTensorV2()
    return A.Compose([resize, center_crop, normalize, to_tensor])

In [None]:
class MultiViewCardDataset(Dataset):
    """Dataset generating multiple augmented views per card.

    Expects ``root_dir/<class name>/<image>``. Each image is exposed
    ``views_per_card`` times; every access re-applies ``transform``, so a
    random transform yields a different view each time.

    Args:
        root_dir: Directory whose sub-directories are the classes.
        transform: Callable invoked as ``transform(image=ndarray)`` that
            returns a mapping with an ``'image'`` entry.
        views_per_card: How many dataset items map to each image.
    """

    # File extensions (compared case-insensitively) treated as images
    IMAGE_SUFFIXES = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, transform, views_per_card=4):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.views_per_card = views_per_card
        self.samples = []  # (image path, class index) pairs
        self.class_to_idx = {}  # class directory name -> class index

        # Number only the directories: a stray file in root_dir must not
        # consume an index, otherwise labels could reach num_classes or beyond.
        class_dirs = sorted(p for p in self.root_dir.iterdir() if p.is_dir())
        for idx, class_dir in enumerate(class_dirs):
            self.class_to_idx[class_dir.name] = idx
            # Sorted so the sample order does not depend on the filesystem
            for img_path in sorted(class_dir.iterdir()):
                if img_path.is_file() and img_path.suffix.lower() in self.IMAGE_SUFFIXES:
                    self.samples.append((img_path, idx))

        print(f"Loaded {len(self.samples)} cards, {len(self.class_to_idx)} classes")

    def __len__(self):
        """Return the number of images times ``views_per_card``."""
        return len(self.samples) * self.views_per_card

    def __getitem__(self, idx):
        """Return ``(transformed image, class index)`` for view ``idx``."""
        # Consecutive indices map to the same underlying image
        card_idx = idx // self.views_per_card
        img_path, label = self.samples[card_idx]

        # Context manager closes the file handle as soon as pixels are decoded
        with Image.open(img_path) as img:
            image = np.array(img.convert('RGB'))
        augmented = self.transform(image=image)

        return augmented['image'], label

    @property
    def num_classes(self):
        """Number of class directories found under ``root_dir``."""
        return len(self.class_to_idx)

In [None]:
class CardEmbeddingModel(nn.Module):
    """Backbone plus MLP head producing L2-normalized embeddings.

    Args:
        backbone: Model name passed to ``timm.create_model`` (created with
            ``pretrained=True`` and ``num_classes=0``).
        embedding_dim: Size of the returned embedding vectors.
        dropout: Dropout probability inside the embedding head.
    """

    def __init__(self, backbone='fastvit_t12', embedding_dim=384, dropout=0.2):
        super().__init__()

        self.backbone = timm.create_model(backbone, pretrained=True, num_classes=0)

        # Probe the backbone's output width with one dummy image. Run the probe
        # in eval mode: in train mode any BatchNorm layers would fold this
        # random input into their running statistics even under no_grad.
        was_training = self.backbone.training
        self.backbone.eval()
        with torch.no_grad():
            dummy = torch.randn(1, 3, 224, 224)
            feature_dim = self.backbone(dummy).shape[-1]
        self.backbone.train(was_training)

        self.embedding_head = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(512, embedding_dim),
        )

    def forward(self, x):
        """Return unit-length (L2-normalized along dim 1) embeddings for ``x``."""
        features = self.backbone(x)
        embeddings = self.embedding_head(features)
        return F.normalize(embeddings, p=2, dim=1)

In [None]:
class CardEmbeddingModule(L.LightningModule):
    """Lightning module that trains ``CardEmbeddingModel`` with ArcFace loss.

    Validation embeddings are collected across batches and turned into
    Recall@1 / Recall@5 at the end of every validation epoch.

    Args:
        config: Nested dict with ``model``, ``training`` and
            ``metric_learning`` sections (see the keys read below).
        num_classes: Number of classes for the ArcFace loss.
    """

    def __init__(self, config, num_classes):
        super().__init__()
        self.save_hyperparameters()
        self.config = config

        self.model = CardEmbeddingModel(
            backbone=config['model']['backbone'],
            embedding_dim=config['model']['embedding_dim'],
            dropout=config['model']['dropout'],
        )

        self.loss_fn = losses.ArcFaceLoss(
            num_classes=num_classes,
            embedding_size=config['model']['embedding_dim'],
            margin=config['metric_learning']['margin'],
            scale=config['metric_learning']['scale'],
        )

        self.miner = miners.TripletMarginMiner(
            margin=config['metric_learning']['mining_margin'],
            type_of_triplets='hard',
        )

        # Per-epoch buffers filled by validation_step
        self.val_embeddings = []
        self.val_labels = []

    def forward(self, x):
        """Return embeddings for a batch of images."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the ArcFace loss for one batch, weighted by mined hard triplets."""
        images, labels = batch
        embeddings = self(images)
        # Hand the mined triplets to the loss as its indices tuple; previously
        # the miner ran on every step but its result was discarded.
        hard_triplets = self.miner(embeddings, labels)
        loss = self.loss_fn(embeddings, labels, hard_triplets)

        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and stash embeddings/labels for the recall metrics."""
        images, labels = batch
        embeddings = self(images)
        # NOTE(review): this loss indexes the ArcFace classes with the
        # validation labels, so it is only meaningful if the validation set
        # uses the same class indices as training — confirm against the datasets.
        loss = self.loss_fn(embeddings, labels)

        self.val_embeddings.append(embeddings.detach().cpu())
        self.val_labels.append(labels.detach().cpu())

        self.log('val_loss', loss, prog_bar=True)

    @staticmethod
    def _recall_at_1_and_5(embeddings, labels):
        """Return (Recall@1, Recall@5) using Euclidean nearest neighbours, excluding self."""
        n = embeddings.shape[0]
        if n < 2:
            # No other sample exists that could be retrieved
            return 0.0, 0.0

        # cdist needs an N x N matrix; broadcasting the differences would
        # materialize N x N x D.
        distances = torch.cdist(embeddings, embeddings)
        distances.fill_diagonal_(float('inf'))

        # Never request more neighbours than there are other samples, so the
        # query itself can never appear in its own top-k.
        k = min(5, n - 1)
        nearest = distances.topk(k, dim=1, largest=False).indices
        hits = labels[nearest] == labels.unsqueeze(1)

        recall_1 = hits[:, 0].float().mean().item()
        recall_5 = hits.any(dim=1).float().mean().item()
        return recall_1, recall_5

    def on_validation_epoch_end(self):
        """Compute and log Recall@1 / Recall@5 over the collected embeddings."""
        if not self.val_embeddings:
            return

        # Upcast so distances are computed in float32 whatever dtype was stored
        embeddings = torch.cat(self.val_embeddings, dim=0).float()
        labels = torch.cat(self.val_labels, dim=0)

        # Reset the buffers for the next validation epoch
        self.val_embeddings = []
        self.val_labels = []

        recall_1, recall_5 = self._recall_at_1_and_5(embeddings, labels)

        self.log('val_recall_at_1', recall_1, prog_bar=True)
        self.log('val_recall_at_5', recall_5, prog_bar=True)

    def configure_optimizers(self):
        """Return AdamW with a per-epoch cosine-annealing schedule."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.config['training']['learning_rate'],
            weight_decay=self.config['training']['weight_decay'],
        )

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=self.config['training']['epochs'],
        )

        return {'optimizer': optimizer, 'lr_scheduler': {'scheduler': scheduler, 'interval': 'epoch'}}

## 4. Train Model

In [None]:
# Create datasets
# Training: every image is served views_per_card times through the random
# augmentation pipeline.
train_dataset = MultiViewCardDataset(
    train_dir,
    transform=get_train_transforms(),
    views_per_card=CONFIG['training']['views_per_card'],
)

# Validation: one view per image through the non-random pipeline.
# NOTE(review): the recall metrics look for another validation sample with
# the same label; if each validation class holds a single image, recall
# will stay at 0 — verify against the prepared data.
val_dataset = MultiViewCardDataset(
    val_dir,
    transform=get_val_transforms(),
    views_per_card=1,
)

print(f"\nTraining: {len(train_dataset)} samples")
print(f"Validation: {len(val_dataset)} samples")

In [None]:
# Create dataloaders
# Settings shared by the training and validation loaders
loader_kwargs = dict(
    batch_size=CONFIG['training']['batch_size'],
    num_workers=4,
    pin_memory=True,
)

# Training batches are shuffled and the last incomplete batch is dropped
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **loader_kwargs)

# Validation keeps dataset order and every sample
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Create model (the ArcFace loss is sized from the training class count)
model = CardEmbeddingModule(CONFIG, num_classes=train_dataset.num_classes)

# Callbacks
# Keep the three best checkpoints by validation Recall@1, plus the last one
checkpoint_cb = L.pytorch.callbacks.ModelCheckpoint(
    dirpath=Path(WORK_DIR) / 'runs/embedding/checkpoints',
    filename='best-{epoch}-{val_recall_at_1:.4f}',
    monitor='val_recall_at_1',
    mode='max',
    save_top_k=3,
    save_last=True,
)
# Stop once Recall@1 has not improved for 15 validation runs
early_stop_cb = L.pytorch.callbacks.EarlyStopping(
    monitor='val_recall_at_1',
    patience=15,
    mode='max',
)
lr_monitor_cb = L.pytorch.callbacks.LearningRateMonitor(logging_interval='epoch')

callbacks = [checkpoint_cb, early_stop_cb, lr_monitor_cb]

# Trainer
trainer = L.Trainer(
    max_epochs=CONFIG['training']['epochs'],
    accelerator='auto',
    precision='16-mixed',
    callbacks=callbacks,
    default_root_dir=Path(WORK_DIR) / 'runs/embedding',
    log_every_n_steps=10,
)

In [None]:
# Train! Runs the training loop, validating on val_loader; checkpointing and
# early stopping are handled by the callbacks given to the trainer.
print("Starting training...")
trainer.fit(model, train_loader, val_loader)

In [None]:
# Save final model
# This stores the weights the module holds right now (the embedding network
# only, without the loss), not the best checkpoint kept by ModelCheckpoint.
output_dir = Path(WORK_DIR) / 'runs/embedding'
# Make sure the target exists even if this cell runs without a prior fit()
output_dir.mkdir(parents=True, exist_ok=True)
torch.save(model.model.state_dict(), output_dir / 'final_model.pt')
print(f"Model saved to {output_dir / 'final_model.pt'}")

## 5. Evaluate Model

In [None]:
# Generate embeddings for the validation images (only val_dir is read here,
# not the training split)
model.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

all_embeddings = []
all_labels = []
all_paths = []  # NOTE(review): never filled in this cell

val_simple = MultiViewCardDataset(val_dir, get_val_transforms(), views_per_card=1)

# One image at a time: add a batch dimension, embed, keep the single row
with torch.no_grad():
    for img, label in tqdm(val_simple):
        img = img.unsqueeze(0).to(device)
        emb = model(img).cpu().numpy()[0]
        all_embeddings.append(emb)
        all_labels.append(label)

# Row i of `embeddings` belongs to `labels[i]`
embeddings = np.array(all_embeddings)
labels = np.array(all_labels)

print(f"Generated {len(embeddings)} embeddings")

In [None]:
# Calculate final metrics
# Pairwise Euclidean distances; the diagonal is set to inf so a sample is
# never ranked as its own nearest neighbour.
distances = np.linalg.norm(embeddings[:, None] - embeddings[None, :], axis=2)
np.fill_diagonal(distances, np.inf)

# Rank the neighbours once: every Recall@K reads a prefix of the same ordering
neighbor_ranking = np.argsort(distances, axis=1)

for k in [1, 5, 10]:
    nearest_k = neighbor_ranking[:, :k]
    # A query is a hit when any of its k nearest neighbours shares its label
    recall_k = np.mean(np.any(labels[nearest_k] == labels[:, None], axis=1))
    print(f"Recall@{k}: {recall_k:.4f} ({recall_k*100:.2f}%)")

## 6. Save to Drive

In [None]:
# Copy model to Drive
drive_models = Path(DRIVE_PROJECT) / 'models/embedding'
drive_models.mkdir(parents=True, exist_ok=True)

shutil.copy(output_dir / 'final_model.pt', drive_models / 'final_model.pt')
print(f"Model saved to Drive: {drive_models / 'final_model.pt'}")

# Save embeddings for vector index
# NOTE(review): only the embedding matrix is written; the matching `labels`
# array is not saved here — confirm the row-to-card mapping is stored elsewhere.
embeddings_dir = Path(DRIVE_PROJECT) / 'ml/data/embeddings'
embeddings_dir.mkdir(parents=True, exist_ok=True)
np.save(embeddings_dir / 'riftbound.npy', embeddings)
print(f"Embeddings saved: {embeddings_dir / 'riftbound.npy'}")

## 7. Build Vector Index

In [None]:
# Install Annoy, the library used for the nearest-neighbour index below
!pip install -q annoy

In [None]:
from annoy import AnnoyIndex

# Build index
# One item per embedding row, so item id i corresponds to embeddings[i] / labels[i]
embedding_dim = embeddings.shape[1]
index = AnnoyIndex(embedding_dim, 'angular')

for i, emb in enumerate(embeddings):
    index.add_item(i, emb)

index.build(10)  # 10 trees

# Save index
index_dir = Path(DRIVE_PROJECT) / 'models/indices'
index_dir.mkdir(parents=True, exist_ok=True)
# save() is given a plain string path rather than a Path object
index.save(str(index_dir / 'riftbound.ann'))
print(f"Index saved: {index_dir / 'riftbound.ann'}")

In [None]:
# Test index: query with an item that is already stored in the index
test_idx = 0
neighbors, distances = index.get_nns_by_item(test_idx, 5, include_distances=True)

print(f"Query card label: {labels[test_idx]}")
print("Top 5 neighbors:")
# Ranks are printed starting at 1
for rank, (neighbor_id, neighbor_dist) in enumerate(zip(neighbors, distances), start=1):
    print(f"  {rank}. Label: {labels[neighbor_id]}, Distance: {neighbor_dist:.4f}")