# TCG Card Embedding Training v2

Train a FastViT-T12 embedding model for card recognition.

**What this notebook does:**
1. Prepares your 673 card images for training
2. Trains an embedding model with ArcFace loss
3. Generates embeddings for all cards
4. Builds an Annoy vector search index
5. Saves everything to Google Drive

**Output files (saved to Drive):**
- `models/embedding/embedding_model.pt` - PyTorch model weights
- `models/embedding/config.json` - Training configuration
- `models/indices/riftbound.ann` - Vector search index
- `ml/data/embeddings/riftbound_embeddings.npy` - All embeddings
- `ml/data/embeddings/riftbound_product_ids.json` - Product ID mapping

**Estimated Time:** 6-12 hours (depends on early stopping)

---
## 1. Setup Environment

In [None]:
# Check GPU - MUST be T4 or better
# Later cells build the Trainer with accelerator='gpu' and move the model to
# 'cuda', so this notebook needs a GPU runtime.
!nvidia-smi

In [None]:
# Install ALL dependencies upfront
# torch/torchvision are upgraded to whatever is latest, the ML libraries below
# are pinned to exact versions, and pyyaml/tqdm/pillow are left unpinned.
# NOTE(review): an unpinned torch upgrade next to pinned timm/lightning may
# drift out of compatibility over time — consider pinning torch as well.
!pip install -q torch torchvision --upgrade
!pip install -q timm==0.9.12
!pip install -q lightning==2.1.0
!pip install -q pytorch-metric-learning==2.3.0
!pip install -q albumentations==1.3.1
!pip install -q annoy==1.17.3
!pip install -q pyyaml tqdm pillow

print("\n‚úÖ All dependencies installed!")

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Set paths
# DRIVE_PROJECT: project root on Drive; input data is read from it and the
# final artifacts are written back to it.
# WORK_DIR: local working directory used for the copied dataset, checkpoints,
# logs and intermediate outputs.
DRIVE_PROJECT = '/content/drive/MyDrive/tcg-scanner'
WORK_DIR = '/content/tcg-scanner'

import os
# Create the working directory if needed and make it the current directory.
os.makedirs(WORK_DIR, exist_ok=True)
os.chdir(WORK_DIR)
print(f"Working directory: {os.getcwd()}")

In [None]:
# Verify data exists before proceeding
from pathlib import Path

print("=" * 50)
print("VERIFYING DATA FILES")
print("=" * 50)

drive_path = Path(DRIVE_PROJECT)
images_path = drive_path / 'ml/data/images/riftbound'
manifest_path = drive_path / 'ml/data/processed/riftbound/training_manifest.json'

# Images: stop immediately if the image root is missing; otherwise count the
# .jpg files sitting directly inside each group sub-directory.
if not images_path.exists():
    print(f"‚ùå Images not found at: {images_path}")
    raise FileNotFoundError("Upload your images to Google Drive first!")

groups = list(filter(Path.is_dir, images_path.iterdir()))
total_images = sum(1 for group in groups for _ in group.glob('*.jpg'))
print(f"‚úÖ Images: {total_images} cards in {len(groups)} groups")

# Manifest: same fail-fast check for the training manifest file.
if not manifest_path.exists():
    print(f"‚ùå Manifest not found at: {manifest_path}")
    raise FileNotFoundError("Upload training_manifest.json to Google Drive first!")

print(f"‚úÖ Manifest: {manifest_path.name}")

print("\n‚úÖ All data files verified! Ready to proceed.")

---
## 2. Prepare Dataset

In [None]:
import shutil
import json
import random
from pathlib import Path
from tqdm import tqdm

# Load manifest
# Decode explicitly as UTF-8 (the standard JSON encoding) so that card names
# with non-ASCII characters are read the same way regardless of the
# platform's default text encoding.
with open(manifest_path, encoding='utf-8') as f:
    manifest = json.load(f)
print(f"Loaded manifest with {len(manifest)} cards")

In [None]:
def prepare_dataset(manifest, output_dir, train_ratio=0.85):
    """
    Prepare dataset in class-folder structure.

    Every manifest entry whose image exists is copied to
    ``<output_dir>/<split>/<product_id>/<sanitized name>.jpg``, where
    ``<split>`` is ``train`` or ``val``.

    Args:
        manifest: Iterable of dicts providing the keys 'image_path',
            'product_id' and 'clean_name'. Image paths are resolved relative
            to ``DRIVE_PROJECT/ml/data``.
        output_dir: Destination root. An existing directory at this path is
            deleted before anything is written.
        train_ratio: Fraction of the found cards assigned to the train split;
            the remainder goes to the val split.

    Returns:
        Tuple ``(train_dir, val_dir)`` of ``Path`` objects. Both directories
        exist on return, even when one of the splits is empty.

    Raises:
        ValueError: If none of the manifest images exist on disk.
    """
    output_dir = Path(output_dir)
    train_dir = output_dir / 'train'
    val_dir = output_dir / 'val'

    # Clean up old data if exists
    if output_dir.exists():
        shutil.rmtree(output_dir)

    # Collect all card images
    cards = []
    missing = 0

    for entry in manifest:
        # FIX: Convert Windows paths to Linux paths
        img_rel_path = entry['image_path'].replace('\\', '/')
        img_path = Path(DRIVE_PROJECT) / 'ml/data' / img_rel_path

        if img_path.exists():
            cards.append({
                'path': img_path,
                'product_id': str(entry['product_id']),
                'name': entry['clean_name'],
            })
        else:
            missing += 1

    print(f"Found {len(cards)} card images ({missing} missing)")

    if len(cards) == 0:
        raise ValueError("No cards found!")

    # Create both split roots up front. Class folders are only created while
    # copying, so without this an empty split (e.g. train_ratio=1.0, or a
    # single card) would leave its directory missing and the class count
    # below would fail with FileNotFoundError.
    for split_root in (train_dir, val_dir):
        split_root.mkdir(parents=True, exist_ok=True)

    # Shuffle and split (fixed seed so the split is reproducible)
    # NOTE(review): the split is made per manifest entry, so a product_id
    # with a single image lands in exactly one split. CardDataset assigns
    # label indices per root directory, so val labels will generally not
    # refer to the same product_ids as train labels — confirm that a
    # validation loss computed on this split is meaningful.
    random.seed(42)
    random.shuffle(cards)

    split_idx = int(len(cards) * train_ratio)
    train_cards = cards[:split_idx]
    val_cards = cards[split_idx:]

    print(f"Split: {len(train_cards)} train, {len(val_cards)} val")

    # Copy files to class folders
    for split_name, split_cards, split_dir in [
        ('train', train_cards, train_dir),
        ('val', val_cards, val_dir)
    ]:
        print(f"\nPreparing {split_name} set...")
        for card in tqdm(split_cards, desc=split_name):
            # Create class directory using product_id
            class_dir = split_dir / card['product_id']
            class_dir.mkdir(parents=True, exist_ok=True)

            # Sanitize filename: keep alphanumerics, space, '_' and '-',
            # capped at 50 characters.
            safe_name = "".join(c for c in card['name'] if c.isalnum() or c in ' _-').strip()[:50]
            dst_path = class_dir / f"{safe_name}.jpg"

            # An already-present destination (same class and same sanitized
            # name) is kept; the later duplicate is skipped.
            if not dst_path.exists():
                shutil.copy(card['path'], dst_path)

    # Verify
    train_classes = len(list(train_dir.iterdir()))
    val_classes = len(list(val_dir.iterdir()))

    print(f"\n‚úÖ Dataset prepared!")
    print(f"   Train: {train_classes} classes")
    print(f"   Val: {val_classes} classes")

    return train_dir, val_dir

# Prepare the dataset
# The class-folder copy is written under WORK_DIR (not Drive); train_dir and
# val_dir are consumed by the dataset-creation cell further down.
train_dir, val_dir = prepare_dataset(
    manifest=manifest,
    output_dir=Path(WORK_DIR) / 'data/embedding',
    train_ratio=0.85
)

---
## 3. Define Model & Training

In [None]:
import albumentations as A
import lightning as L
import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from albumentations.pytorch import ToTensorV2
from PIL import Image
from pytorch_metric_learning import losses, miners
from torch.utils.data import DataLoader, Dataset

# Training configuration
CONFIG = {
    'model': {
        'backbone': 'fastvit_t12',
        'embedding_dim': 384,
        'dropout': 0.2,
    },
    'training': {
        'epochs': 100,
        'batch_size': 64,
        'learning_rate': 0.0003,
        'weight_decay': 0.05,
        'views_per_card': 4,  # Augmented views per card
    },
    'metric_learning': {
        # pytorch-metric-learning's ArcFaceLoss expects the margin in DEGREES
        # and converts it to radians internally. 28.6 degrees is ~0.5 rad;
        # passing 0.5 here would mean a 0.5-degree margin, i.e. practically
        # no angular margin at all.
        'margin': 28.6,
        'scale': 64,
    },
}

print("Configuration loaded:")
print(f"  Backbone: {CONFIG['model']['backbone']}")
print(f"  Embedding dim: {CONFIG['model']['embedding_dim']}")
print(f"  Batch size: {CONFIG['training']['batch_size']}")
print(f"  Views per card: {CONFIG['training']['views_per_card']}")

In [None]:
def get_train_transforms():
    """Build the training augmentation pipeline.

    Stages run in this order: geometric distortion, lighting variation,
    blur, noise, JPEG compression, then resize to 256x256, a random 224x224
    crop, normalization and conversion to a tensor.
    """
    # Geometric transforms
    geometric = [
        A.Perspective(scale=(0.05, 0.12), p=0.5),
        A.Affine(rotate=(-15, 15), shear=(-8, 8), scale=(0.85, 1.15), p=0.7),
    ]

    # Lighting variations: at most one of these is applied per sample
    lighting = A.OneOf(
        [
            A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=1.0),
            A.HueSaturationValue(hue_shift_limit=8, sat_shift_limit=25, val_shift_limit=25, p=1.0),
            A.RandomGamma(gamma_limit=(70, 130), p=1.0),
        ],
        p=0.6,
    )

    # Blur (camera out of focus): at most one of these is applied per sample
    blur = A.OneOf(
        [
            A.GaussianBlur(blur_limit=(3, 5), p=1.0),
            A.MotionBlur(blur_limit=5, p=1.0),
        ],
        p=0.3,
    )

    # Noise (low light) followed by JPEG compression artifacts
    degradation = [
        A.GaussNoise(var_limit=(5, 50), p=0.3),
        A.ImageCompression(quality_lower=70, quality_upper=100, p=0.3),
    ]

    # Resize and normalize
    finalize = [
        A.Resize(256, 256),
        A.RandomCrop(224, 224),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]

    return A.Compose([*geometric, lighting, blur, *degradation, *finalize])

def get_val_transforms():
    """Build the augmentation-free pipeline used for validation/inference.

    Resizes to 256x256, takes the central 224x224 crop, normalizes with the
    same mean/std as the training pipeline and converts to a tensor.
    """
    resize = A.Resize(256, 256)
    center_crop = A.CenterCrop(224, 224)
    normalize = A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    return A.Compose([resize, center_crop, normalize, ToTensorV2()])

In [None]:
class CardDataset(Dataset):
    """Dataset that generates multiple augmented views per card.

    Expects ``root_dir`` to contain one sub-directory per class, each holding
    ``*.jpg`` images. Every image is exposed ``views_per_card`` times; each
    access re-applies ``transform``, so a random transform yields a different
    view each time.

    Args:
        root_dir: Directory containing the class sub-directories.
        transform: Callable invoked as ``transform(image=ndarray)`` that
            returns a mapping with an ``'image'`` entry.
        views_per_card: Number of dataset items generated per image.
    """

    def __init__(self, root_dir, transform, views_per_card=1):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.views_per_card = views_per_card

        # Collect all images and their class indices
        self.samples = []
        self.class_to_idx = {}
        self.idx_to_class = {}

        # Filter to directories BEFORE enumerating so label indices are
        # contiguous (0..num_classes-1). Enumerating first would let a stray
        # file (e.g. '.DS_Store') consume an index and push labels past
        # num_classes.
        class_dirs = sorted(p for p in self.root_dir.iterdir() if p.is_dir())

        for idx, class_dir in enumerate(class_dirs):
            class_name = class_dir.name
            self.class_to_idx[class_name] = idx
            self.idx_to_class[idx] = class_name

            # Sorted so the sample order does not depend on filesystem
            # listing order.
            for img_path in sorted(class_dir.glob('*.jpg')):
                self.samples.append((img_path, idx))

        print(f"Loaded {len(self.samples)} images, {len(self.class_to_idx)} classes")

    def __len__(self):
        """Return the number of items: images times views per image."""
        return len(self.samples) * self.views_per_card

    def __getitem__(self, idx):
        """Return ``(transformed image, class index)`` for item ``idx``."""
        # Map idx to actual sample (allows multiple views): consecutive
        # indices within a block of `views_per_card` point at the same image.
        sample_idx = idx // self.views_per_card
        img_path, label = self.samples[sample_idx]

        # Load and transform image; the context manager closes the file
        # handle as soon as the pixels are copied into the array.
        with Image.open(img_path) as img:
            image = np.array(img.convert('RGB'))
        augmented = self.transform(image=image)

        return augmented['image'], label

    @property
    def num_classes(self):
        """Number of class sub-directories found under ``root_dir``."""
        return len(self.class_to_idx)

In [None]:
class EmbeddingModel(nn.Module):
    """FastViT backbone with embedding head.

    The backbone is created through timm with pretrained weights and
    ``num_classes=0``; its output is fed through a small MLP head and the
    result is L2-normalized.

    Args:
        backbone: timm model name passed to ``timm.create_model``.
        embedding_dim: Size of the output embedding vector.
        dropout: Dropout probability used inside the head.
    """

    def __init__(self, backbone='fastvit_t12', embedding_dim=384, dropout=0.2):
        super().__init__()

        # Load pretrained backbone
        self.backbone = timm.create_model(backbone, pretrained=True, num_classes=0)

        # Get feature dimension with a dummy forward pass. The probe runs in
        # eval mode: in train mode, BatchNorm layers update their running
        # statistics even under no_grad, which would let the random input
        # perturb the pretrained statistics. The previous mode is restored
        # afterwards.
        was_training = self.backbone.training
        self.backbone.eval()
        try:
            with torch.no_grad():
                dummy = torch.randn(1, 3, 224, 224)
                features = self.backbone(dummy)
                feature_dim = features.shape[-1]
        finally:
            self.backbone.train(was_training)

        # Embedding head
        self.head = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(512, embedding_dim),
        )

        print(f"Model created: {backbone} -> {feature_dim} -> {embedding_dim}")

    def forward(self, x):
        """Return L2-normalized embeddings (normalized along dim 1) for ``x``."""
        features = self.backbone(x)
        embeddings = self.head(features)
        # L2 normalize embeddings
        return F.normalize(embeddings, p=2, dim=1)

In [None]:
class TrainingModule(L.LightningModule):
    """Lightning module that trains an ``EmbeddingModel`` with ArcFace loss.

    Args:
        config: Nested dict with 'model', 'training' and 'metric_learning'
            sections (see the keys read below).
        num_classes: Number of classes handed to the ArcFace loss.
    """

    def __init__(self, config, num_classes):
        super().__init__()
        self.save_hyperparameters()
        self.config = config

        model_cfg = config['model']
        metric_cfg = config['metric_learning']

        # Embedding network
        self.model = EmbeddingModel(
            backbone=model_cfg['backbone'],
            embedding_dim=model_cfg['embedding_dim'],
            dropout=model_cfg['dropout'],
        )

        # Metric-learning objective. It is stored as an attribute of this
        # module, so any parameters it owns are included in
        # self.parameters() and handed to the optimizer.
        # NOTE(review): pytorch-metric-learning documents the ArcFaceLoss
        # margin in degrees — confirm the configured value uses that unit.
        self.loss_fn = losses.ArcFaceLoss(
            num_classes=num_classes,
            embedding_size=model_cfg['embedding_dim'],
            margin=metric_cfg['margin'],
            scale=metric_cfg['scale'],
        )

    def forward(self, x):
        """Return the embeddings produced by the wrapped model."""
        return self.model(x)

    def _arcface_loss(self, batch):
        """Embed the images of ``batch`` and score them against its labels."""
        images, labels = batch
        return self.loss_fn(self(images), labels)

    def training_step(self, batch, batch_idx):
        """Compute and log 'train_loss' for one training batch."""
        loss = self._arcface_loss(batch)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log 'val_loss' for one validation batch.

        NOTE(review): the loss interprets labels as indices into the classes
        it was built with — confirm validation labels use the same indexing
        as training labels.
        """
        loss = self._arcface_loss(batch)
        self.log('val_loss', loss, prog_bar=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        """Return AdamW plus a per-epoch cosine-annealing LR schedule."""
        train_cfg = self.config['training']

        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=train_cfg['learning_rate'],
            weight_decay=train_cfg['weight_decay'],
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=train_cfg['epochs'],
        )

        lr_scheduler = {'scheduler': scheduler, 'interval': 'epoch'}
        return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}

---
## 4. Create Datasets & Train

In [None]:
# Create datasets
print("Creating datasets...\n")

# Training set: each image is exposed `views_per_card` times, and every
# access applies the random training augmentations.
train_dataset = CardDataset(
    train_dir,
    transform=get_train_transforms(),
    views_per_card=CONFIG['training']['views_per_card'],
)

# Validation set: a single resize + center-crop view per image.
# NOTE(review): CardDataset assigns class indices per root directory, so
# val label i and train label i may refer to different product_ids —
# confirm the validation loss is meaningful with this split.
val_dataset = CardDataset(
    val_dir,
    transform=get_val_transforms(),
    views_per_card=1,
)

print(f"\nTraining samples: {len(train_dataset)} ({train_dataset.num_classes} cards √ó {CONFIG['training']['views_per_card']} views)")
print(f"Validation samples: {len(val_dataset)}")

In [None]:
# Create dataloaders
# Training: shuffled every epoch; the final incomplete batch is dropped.
train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['training']['batch_size'],
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=True,
)

# Validation: dataset order is kept and every sample is used.
val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG['training']['batch_size'],
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

In [None]:
# Create model and trainer
# The number of ArcFace classes comes from the TRAINING dataset.
model = TrainingModule(CONFIG, num_classes=train_dataset.num_classes)

# Callbacks
# - ModelCheckpoint: keeps the 3 checkpoints with the lowest val_loss plus the
#   most recent one, under WORK_DIR/checkpoints.
# - EarlyStopping: stops after 15 epochs without a val_loss improvement.
# - LearningRateMonitor: logs the learning rate once per epoch.
# NOTE(review): both selection callbacks depend on val_loss — confirm that
# metric is meaningful for the way the validation split is built.
callbacks = [
    L.pytorch.callbacks.ModelCheckpoint(
        dirpath=Path(WORK_DIR) / 'checkpoints',
        filename='best-{epoch:02d}-{val_loss:.4f}',
        monitor='val_loss',
        mode='min',
        save_top_k=3,
        save_last=True,
    ),
    L.pytorch.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=15,
        mode='min',
        verbose=True,
    ),
    L.pytorch.callbacks.LearningRateMonitor(logging_interval='epoch'),
]

# Trainer
# Single GPU with 16-bit mixed precision; logs are written under WORK_DIR/logs.
trainer = L.Trainer(
    max_epochs=CONFIG['training']['epochs'],
    accelerator='gpu',
    devices=1,
    precision='16-mixed',
    callbacks=callbacks,
    default_root_dir=Path(WORK_DIR) / 'logs',
    log_every_n_steps=10,
)

print("\n‚úÖ Ready to train!")

In [None]:
# START TRAINING
# Prints a short banner describing the run, then hands the model and both
# dataloaders to the Lightning trainer.
print("=" * 50)
print("STARTING TRAINING")
print("=" * 50)
print(f"Training {train_dataset.num_classes} card classes")
print(f"Max epochs: {CONFIG['training']['epochs']}")
# The patience shown here is a literal; keep it in sync with the
# EarlyStopping callback if that value changes.
print(f"Early stopping patience: 15 epochs")
print("")
print("Monitor 'val_loss' - lower is better")
print("Training will stop automatically when val_loss stops improving")
print("=" * 50)

trainer.fit(model, train_loader, val_loader)

In [None]:
# Save the trained model
print("\nSaving model...")

# Local output directory; later cells also write the Annoy index here.
output_dir = Path(WORK_DIR) / 'output'
output_dir.mkdir(exist_ok=True)

# Save model weights
# Only the EmbeddingModel weights (model.model) are saved; the ArcFace loss
# module is not included.
# NOTE(review): this saves the weights `model` holds when fit() returns, not
# the best checkpoint tracked by ModelCheckpoint — confirm that is intended.
model_path = output_dir / 'embedding_model.pt'
torch.save(model.model.state_dict(), model_path)
print(f"‚úÖ Model saved: {model_path}")

---
## 5. Generate Embeddings for All Cards

In [None]:
# Generate embeddings for ALL cards (train + val)
print("Generating embeddings for all cards...\n")

# Switch to eval mode and move the model to the GPU for inference.
model.eval()
device = torch.device('cuda')
model = model.to(device)

# Simple transform for embedding generation
embed_transform = get_val_transforms()

# Parallel lists: all_embeddings[i] belongs to all_product_ids[i]. Later
# cells use the position i as the item id in the vector index.
all_embeddings = []
all_product_ids = []

data_dir = Path(WORK_DIR) / 'data/embedding'

# Process all cards (train + val)
for split in ['train', 'val']:
    split_dir = data_dir / split
    print(f"Processing {split}...")
    
    for class_dir in tqdm(sorted(split_dir.iterdir())):
        if not class_dir.is_dir():
            continue
            
        # The class folder name is the product_id.
        product_id = class_dir.name
        
        for img_path in class_dir.glob('*.jpg'):
            # Load and transform
            image = np.array(Image.open(img_path).convert('RGB'))
            transformed = embed_transform(image=image)
            # Add a batch dimension of 1; images are embedded one at a time.
            img_tensor = transformed['image'].unsqueeze(0).to(device)
            
            # Generate embedding
            with torch.no_grad():
                embedding = model(img_tensor).cpu().numpy()[0]
            
            all_embeddings.append(embedding)
            all_product_ids.append(product_id)

# Convert to numpy array
embeddings = np.array(all_embeddings)
print(f"\n‚úÖ Generated {len(embeddings)} embeddings")
print(f"   Shape: {embeddings.shape}")

---
## 6. Build Vector Search Index

In [None]:
from annoy import AnnoyIndex

print("Building Annoy index...")

# Create index
embedding_dim = embeddings.shape[1]
# Annoy's 'angular' metric is the Euclidean distance between normalized
# vectors, sqrt(2 * (1 - cos)); it ranks neighbours the same way as cosine
# similarity but is not the similarity value itself.
index = AnnoyIndex(embedding_dim, 'angular')

# Add all embeddings
# Item id i is the row position in `embeddings`, which is also the position
# of the matching entry in all_product_ids.
for i, emb in enumerate(embeddings):
    index.add_item(i, emb)

# Build with 10 trees (good balance of speed/accuracy)
index.build(10)

# Save locally
index_path = output_dir / 'riftbound.ann'
index.save(str(index_path))
print(f"‚úÖ Index saved: {index_path}")
print(f"   Vectors: {len(embeddings)}")
print(f"   Dimensions: {embedding_dim}")

In [None]:
# Test the index with a few queries
print("\n" + "=" * 50)
print("TESTING VECTOR INDEX")
print("=" * 50)

# Query with the first, middle and last indexed items (fixed positions, not
# random).
test_indices = [0, len(embeddings) // 2, len(embeddings) - 1]

for test_idx in test_indices:
    # Ask for the 5 nearest items to an item already in the index.
    neighbors, distances = index.get_nns_by_item(test_idx, 5, include_distances=True)
    
    query_id = all_product_ids[test_idx]
    print(f"\nQuery: Product ID {query_id}")
    print("Top 5 matches:")
    
    for rank, (n, d) in enumerate(zip(neighbors, distances), 1):
        # Translate the index item id back to its product_id.
        match_id = all_product_ids[n]
        is_self = "(self)" if n == test_idx else ""
        print(f"  {rank}. ID: {match_id}, Distance: {d:.4f} {is_self}")

---
## 7. Save Everything to Google Drive

In [None]:
import json

print("Saving all outputs to Google Drive...\n")

# Create directories
drive_models = Path(DRIVE_PROJECT) / 'models/embedding'
drive_indices = Path(DRIVE_PROJECT) / 'models/indices'
drive_embeddings = Path(DRIVE_PROJECT) / 'ml/data/embeddings'

drive_models.mkdir(parents=True, exist_ok=True)
drive_indices.mkdir(parents=True, exist_ok=True)
drive_embeddings.mkdir(parents=True, exist_ok=True)

# 1. Copy model
shutil.copy(output_dir / 'embedding_model.pt', drive_models / 'embedding_model.pt')
print(f"‚úÖ Model: {drive_models / 'embedding_model.pt'}")

# 2. Copy index
shutil.copy(output_dir / 'riftbound.ann', drive_indices / 'riftbound.ann')
print(f"‚úÖ Index: {drive_indices / 'riftbound.ann'}")

# 3. Save embeddings
np.save(drive_embeddings / 'riftbound_embeddings.npy', embeddings)
print(f"‚úÖ Embeddings: {drive_embeddings / 'riftbound_embeddings.npy'}")

# 4. Save product ID mapping
# Stored as a JSON list: position i holds the product_id for index item i
# and for row i of the embeddings array.
with open(drive_embeddings / 'riftbound_product_ids.json', 'w') as f:
    json.dump(all_product_ids, f)
print(f"‚úÖ Product IDs: {drive_embeddings / 'riftbound_product_ids.json'}")

# 5. Save config for reference
with open(drive_models / 'config.json', 'w') as f:
    json.dump(CONFIG, f, indent=2)
print(f"‚úÖ Config: {drive_models / 'config.json'}")

In [None]:
# Final summary
# Prints a tree of the files written to Drive by the previous cell, followed
# by the manual next steps. Output only; nothing is written here.
print("\n" + "=" * 60)
print("                    TRAINING COMPLETE!")
print("=" * 60)
print("\nFiles saved to Google Drive:")
print(f"\nüìÅ {DRIVE_PROJECT}/")
print("   ‚îú‚îÄ‚îÄ models/")
print("   ‚îÇ   ‚îú‚îÄ‚îÄ embedding/")
print("   ‚îÇ   ‚îÇ   ‚îú‚îÄ‚îÄ embedding_model.pt      <- PyTorch model")
print("   ‚îÇ   ‚îÇ   ‚îî‚îÄ‚îÄ config.json             <- Training config")
print("   ‚îÇ   ‚îî‚îÄ‚îÄ indices/")
print("   ‚îÇ       ‚îî‚îÄ‚îÄ riftbound.ann           <- Vector search index")
print("   ‚îî‚îÄ‚îÄ ml/data/embeddings/")
print("       ‚îú‚îÄ‚îÄ riftbound_embeddings.npy    <- All card embeddings")
print("       ‚îî‚îÄ‚îÄ riftbound_product_ids.json  <- Product ID mapping")
print("\n" + "=" * 60)
print("\nNEXT STEPS:")
print("1. Download these files from Google Drive")
print("2. Convert to TFLite locally (I'll help with this)")
print("3. Add to Flutter app assets")
print("4. Test the app!")
print("\n" + "=" * 60)