# ViT Canonical Training + Phase 3 Export (Colab)

**Model:** Vision Transformer (ViT-Tiny via timm)

**Objective:** Replicate the ResNet50 canonical pipeline with ViT:
- Phase 1: Canonical splits (load from Drive)
- Phase 2: Canonical classes (27 classes, fp=cdfa70b13f7390e6)
- Phase 3: Export contract (.npz + _meta.json) with strict validation

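**Expected outputs:**
- `STORE/artifacts/exports/vit_rerun_canonical_smoke/val.npz`
- `STORE/artifacts/exports/vit_rerun_canonical_smoke/val_meta.json`
- Test-logit export cell: `STORE/artifacts/exports/vit_canonical/` (split `test`, raw logits + metadata JSON)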

**Validation:**
- split_signature must match ResNet50: `cf53f8eb169b3531`
- classes_fp must equal canonical: `cdfa70b13f7390e6`
- idx order must align with ResNet50 for fusion compatibility

In [None]:
from pathlib import Path
import os

from google.colab import drive
# Mount Google Drive so the code snapshot and the persistent store are reachable.
drive.mount("/content/drive")

# --- EDIT THESE PATHS ONCE ---
DRIVE_CODE_SNAPSHOT = Path("/content/drive/MyDrive/DS_rakuten_colab")  # copied into the runtime by the next cell
DRIVE_STORE = Path("/content/drive/MyDrive/DS_rakuten_store")          # persistent data / checkpoints / exports
DRIVE_SPLITS_SRC = DRIVE_STORE.joinpath("splits")  # expects train_idx.txt / val_idx.txt / test_idx.txt
# ----------------------------

# The code snapshot must already exist; the store is created on first use.
assert DRIVE_CODE_SNAPSHOT.exists(), f"Missing code snapshot: {DRIVE_CODE_SNAPSHOT}"
DRIVE_STORE.mkdir(exist_ok=True, parents=True)

# Later cells resolve the store location through this environment variable.
os.environ.update(DS_RAKUTEN_STORE=str(DRIVE_STORE))

for _label, _path in (
    ("✓ DRIVE_CODE_SNAPSHOT:", DRIVE_CODE_SNAPSHOT),
    ("✓ DRIVE_STORE:", DRIVE_STORE),
    ("✓ DRIVE_SPLITS_SRC:", DRIVE_SPLITS_SRC),
):
    print(_label, _path)

In [None]:
import shutil
import sys
from pathlib import Path

import importlib

RUNTIME_ROOT = Path("/content/DS_rakuten")

# Clean and copy for deterministic imports: the runtime tree always mirrors the
# Drive snapshot exactly, with no leftovers from a previous copy.
if RUNTIME_ROOT.exists():
    shutil.rmtree(RUNTIME_ROOT)

shutil.copytree(DRIVE_CODE_SNAPSHOT, RUNTIME_ROOT)

# A fresh copy alone is not enough when this cell is re-run in the same kernel:
# repeated inserts would stack duplicate sys.path entries, and project modules
# already imported (package ``src``) would stay cached in sys.modules, so the
# next import would silently reuse the stale code.
_runtime_root = str(RUNTIME_ROOT)
sys.path[:] = [entry for entry in sys.path if entry != _runtime_root]
sys.path.insert(0, _runtime_root)

for _mod_name in [m for m in sys.modules if m == "src" or m.startswith("src.")]:
    del sys.modules[_mod_name]
# Drop finder caches so the freshly copied files are picked up.
importlib.invalidate_caches()

print("✓ Runtime code ready:", RUNTIME_ROOT)
print("✓ sys.path[0]:", sys.path[0])

In [None]:
from pathlib import Path
import shutil

# Split directory inside the freshly copied repo. This is presumably where
# load_splits() looks, since it is called later without an explicit path.
runtime_splits_dir = RUNTIME_ROOT / "data" / "splits"
runtime_splits_dir.mkdir(parents=True, exist_ok=True)

src_files = ["train_idx.txt", "val_idx.txt", "test_idx.txt"]

# Validate every source before copying anything, so that a missing file cannot
# leave a half-synced mix of fresh and stale split files behind.
missing = [DRIVE_SPLITS_SRC / fn for fn in src_files if not (DRIVE_SPLITS_SRC / fn).exists()]
assert not missing, f"Missing split file(s) in Drive: {[str(p) for p in missing]}"

# Copy txt files from the Drive persistent store into the /content runtime repo.
for fn in src_files:
    shutil.copy2(DRIVE_SPLITS_SRC / fn, runtime_splits_dir / fn)

print("✓ Splits synced to:", runtime_splits_dir)
# Sorted so the listing is stable across runs (glob order is filesystem-dependent).
print("✓ Contents:", sorted(runtime_splits_dir.glob("*.txt"))[:10])

In [None]:
# Install timm (Vision Transformer architectures used by the export cell) and
# wandb (logging for the commented-out training config, which sets use_wandb=True).
!pip -q install timm wandb
import wandb
# NOTE(review): wandb.login() may prompt interactively for an API key, even though
# the export cell below never calls wandb; consider skipping it when only exporting.
wandb.login()

# Uncomment if your session is missing other packages:
# !pip -q install gdown
# !pip -q install scikit-learn

In [None]:
from pathlib import Path

IMAGE_FILE_ID = "15ZkS0iTQ7j3mHpxil4mABlXwP-jAN_zi"

BASE_DIR = Path("/content/images")
TMP_DIR = Path("/content/tmp")
ZIP_PATH = TMP_DIR / "images.zip"

BASE_DIR.mkdir(parents=True, exist_ok=True)
TMP_DIR.mkdir(parents=True, exist_ok=True)

if not ZIP_PATH.exists():
    print("Downloading images zip...")
    !gdown --id $IMAGE_FILE_ID -O {str(ZIP_PATH)}
else:
    print("Zip already present:", ZIP_PATH)

print("Unzipping images...")
!unzip -q -o {str(ZIP_PATH)} -d {str(BASE_DIR)}

def count_jpgs(p: Path, limit: int = 2000) -> int:
    """Count ``*.jpg`` files under *p* (recursively), stopping early at *limit*.

    Returns 0 when *p* does not exist or holds no JPEGs. The scan stops once the
    running count reaches *limit*, which bounds the work on large image folders.
    Because the check happens after counting, at least one match is counted
    whenever one exists.
    """
    if not p.exists():
        return 0
    found = 0
    for found, _jpg in enumerate(p.rglob("*.jpg"), start=1):
        if found >= limit:
            break
    return found

# Likely layouts of the unzipped archive, checked in this order.
candidates = [
    BASE_DIR / "images" / "image_train",
    BASE_DIR / "image_train",
    BASE_DIR / "images" / "images" / "image_train",
]


def _richest_dir(dirs, current, current_count):
    """Scan *dirs* and return ``(dir, count)`` with the highest count_jpgs().

    Uses a strictly-greater comparison, so on ties the earlier winner (or
    *current*) is kept.
    """
    for folder in dirs:
        found = count_jpgs(folder)
        if found > current_count:
            current, current_count = folder, found
    return current, current_count


best, best_count = _richest_dir(candidates, None, 0)

# Fallback: none of the usual layouts had JPEGs, so try any folder named image_train.
if best_count == 0:
    best, best_count = _richest_dir(
        (d for d in BASE_DIR.rglob("image_train") if d.is_dir()),
        best,
        best_count,
    )

assert best is not None and best_count > 0, (
    "Could not find an image_train directory with jpg files under /content/images. "
    "Check zip content and unzip path."
)

IMG_ROOT = best
sample_jpg = next(IMG_ROOT.rglob("*.jpg"))

print("✓ IMG_ROOT detected:", IMG_ROOT)
print("✓ sample jpg:", sample_jpg)

In [None]:
from src.data.image_dataset import RakutenImageDataset
from src.train.image_vit import ViTConfig, run_vit_canonical

# Echo the imported project objects to confirm the imports resolved.
for _label, _obj in (
    ("✓ RakutenImageDataset:", RakutenImageDataset),
    ("✓ ViTConfig:", ViTConfig),
    ("✓ run_vit_canonical:", run_vit_canonical),
):
    print(_label, _obj)

In [None]:
from src.data.split_manager import load_splits, split_signature

# Load the train/val/test index lists and compute their signature.
splits = load_splits(verbose=True)
sig = split_signature(splits)

split_sizes = {name: len(ids) for name, ids in splits.items()}
print("✓ signature:", sig)
print(split_sizes)

In [None]:
# ============================================================
# EXPORT TEST LOGITS (without retraining)
# Load pretrained ViT checkpoint and export test set logits
# ============================================================

import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from pathlib import Path
from tqdm import tqdm

import timm

from src.data.data_colab import load_data_colab
from src.data.split_manager import load_splits, split_signature
from src.data.label_mapping import (
    CANONICAL_CLASSES,
    CANONICAL_CLASSES_FP,
    encode_labels,
)
from src.export.model_exporter import export_predictions, load_predictions
from src.data.image_dataset import RakutenImageDataset

# ---- Configuration: paths under the persistent store ----
STORE = Path(os.environ["DS_RAKUTEN_STORE"])
CKPT_PATH = STORE.joinpath("checkpoints", "image_vit", "best_model.pth")
OUT_DIR = STORE.joinpath("artifacts", "exports")
RAW_DIR = str(STORE.joinpath("data_raw"))

# ---- Inference settings ----
# NUM_CLASSES and VIT_MODEL_NAME must match the checkpoint for load_state_dict to succeed.
IMG_SIZE, BATCH_SIZE = 224, 32
NUM_CLASSES = 27
DROPOUT_RATE = 0.1
VIT_MODEL_NAME = "vit_tiny_patch16_224"  # embed_dim=192, confirmed from checkpoint
MODEL_NAME = "vit_canonical"  # Same directory as val.npz

_banner = "=" * 60
print(_banner)
print("EXPORT TEST LOGITS - ViT")
print(_banner)
print(f"Checkpoint: {CKPT_PATH}")
print(f"Output dir: {OUT_DIR / MODEL_NAME}")

# ---- Fail fast if the checkpoint is absent ----
assert CKPT_PATH.exists(), f"Checkpoint not found: {CKPT_PATH}"
print(f"✓ Checkpoint found: {CKPT_PATH}")

# ---- Build model architecture ----
def build_vit(num_classes: int, dropout_rate: float, model_name: str) -> nn.Module:
    """Instantiate the timm architecture *model_name* with a ``num_classes``-way head.

    Pretrained weights are not fetched (``pretrained=False``) because the
    fine-tuned checkpoint is loaded into the model afterwards.
    """
    create_kwargs = {
        "num_classes": int(num_classes),
        "drop_rate": float(dropout_rate),
    }
    return timm.create_model(model_name, pretrained=False, **create_kwargs)

# ---- Load checkpoint ----
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

model = build_vit(NUM_CLASSES, DROPOUT_RATE, VIT_MODEL_NAME).to(device)
# NOTE(review): recent torch releases default torch.load to weights_only=True. That is
# fine for tensors plus plain Python metadata; if the checkpoint stores e.g. NumPy
# scalars the load may fail and need weights_only=False (trusted files only).
ckpt = torch.load(CKPT_PATH, map_location=device)
model.load_state_dict(ckpt["model_state_dict"])
model.eval()

# Format the stored metric defensively: a missing key (or a non-numeric value)
# must not crash with "Unknown format code 'f' for object of type 'str'".
try:
    best_val_f1_txt = f"{ckpt['best_val_f1']:.4f}"
except (KeyError, TypeError, ValueError):
    best_val_f1_txt = "?"

print(f"✓ Model loaded from epoch {ckpt.get('epoch', '?')}")
print(f"✓ Best val F1: {best_val_f1_txt}")
print(f"✓ Split signature: {ckpt.get('split_signature', '?')}")

# ---- Load data ----
pack = load_data_colab(
    raw_dir=RAW_DIR,
    img_root=IMG_ROOT,
    splitted=False,
    verbose=True,
)
X, y = pack["X"], pack["y"]

# ---- Splits and labels ----
splits = load_splits(verbose=True)
sig = split_signature(splits)
y_encoded = encode_labels(y, CANONICAL_CLASSES).astype(int)

print(f"✓ Split signature: {sig}")
print(f"✓ Test set size: {len(splits['test_idx'])}")

# Refuse to export test logits from a model whose checkpoint records different
# splits: its training rows could then overlap today's test rows, and fusion
# with other models would be misaligned. Only enforced when a signature was saved.
ckpt_sig = ckpt.get("split_signature")
if ckpt_sig is not None and str(ckpt_sig) != str(sig):
    raise RuntimeError(
        f"Checkpoint split_signature {ckpt_sig!r} does not match current splits {sig!r}; "
        "refusing to export test logits from a model trained on different splits."
    )

# ---- Prepare test dataset ----
# Deterministic eval preprocessing: resize, center-crop to IMG_SIZE, ImageNet mean/std.
val_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

full_df = X.copy()
full_df["encoded_label"] = y_encoded

# IndexedDataset to track real indices
class IndexedDataset(Dataset):
    """Subset view over *base_dataset* that also yields each sample's source index.

    Item ``i`` is ``base_dataset[indices[i]]`` with that source index appended,
    i.e. ``(img, label, real_idx)``. This lets inference check that its outputs
    follow the requested index order.
    """

    def __init__(self, base_dataset: Dataset, indices: np.ndarray):
        self.base = base_dataset
        # Copy into an integer array so lists and arrays behave the same.
        self.indices = np.asarray(indices).astype(int)

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, i: int):
        source_idx = int(self.indices[i])
        sample = self.base[source_idx]
        img, label = sample
        return img, label, source_idx

# Wrap the whole (re-indexed) dataframe. Split indices are presumably positional
# row numbers, which is why the index is reset before wrapping.
full_dataset = RakutenImageDataset(
    dataframe=full_df.reset_index(drop=True),
    image_dir=str(IMG_ROOT),
    transform=val_transform,
    label_col="encoded_label",
)

test_idx = splits["test_idx"]
test_dataset = IndexedDataset(full_dataset, test_idx)

use_pinned_memory = device.startswith("cuda")
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,   # keep test_idx order for the alignment check after inference
    num_workers=0,   # Colab stability
    pin_memory=use_pinned_memory,
)

print(f"✓ Test DataLoader ready: {len(test_loader)} batches")

# ---- Inference: Extract LOGITS (not softmax probs) ----
@torch.no_grad()
def predict_logits(model, loader, device):
    """Run *model* over *loader* and collect raw logits with their dataset indices.

    Each batch is expected as ``(images, labels, real_idx)``; labels are ignored.
    Returns ``(logits, idx)`` as NumPy arrays concatenated along axis 0 in loader
    order. No softmax is applied.
    """
    model.eval()
    batch_logits, batch_idx = [], []
    for images, _labels, real_idx in tqdm(loader, desc="Test Inference", ncols=100):
        out = model(images.to(device, non_blocking=True))
        batch_logits.append(out.detach().cpu().numpy())
        batch_idx.append(real_idx.detach().cpu().numpy())
    return np.concatenate(batch_logits, axis=0), np.concatenate(batch_idx, axis=0)

# ---- Run inference (the loader is unshuffled, so output order should follow test_idx) ----
print("Running inference on test set...")
test_logits, seen_idx = predict_logits(model, test_loader, device)

# ---- Verify alignment: exported rows must follow test_idx exactly ----
aligned = np.array_equal(seen_idx, test_idx)
if not aligned:
    raise AssertionError("Index order mismatch during test inference!")

print(f"✓ Inference complete: logits shape = {test_logits.shape}")

# ---- Ground-truth labels for exactly the exported rows ----
y_true_test = y_encoded[test_idx].astype(int)

# ---- Export test logits ----
# Provenance passed to the exporter as extra metadata.
test_extra_meta = {
    "source": "image_04_vit_v2_logits.ipynb",
    "model_architecture": f"timm.{VIT_MODEL_NAME}",
    "img_dir": str(IMG_ROOT),
    "img_size": IMG_SIZE,
    "batch_size": BATCH_SIZE,
    "dropout_rate": DROPOUT_RATE,
    "output_type": "logits",
    "checkpoint_path": str(CKPT_PATH),
    "classes_fp": CANONICAL_CLASSES_FP,
    "split_signature": sig,
}
export_result = export_predictions(
    out_dir=OUT_DIR,
    model_name=MODEL_NAME,
    split_name="test",
    idx=seen_idx,
    split_signature=sig,
    logits=test_logits,  # raw LOGITS, not probs
    classes=CANONICAL_CLASSES,
    y_true=y_true_test,
    extra_meta=test_extra_meta,
)

_divider = "=" * 60
print()
print(_divider)
print("EXPORT COMPLETE")
print(_divider)
for _key, _label in (
    ("npz_path", "NPZ path"),
    ("meta_json_path", "Meta JSON"),
    ("num_samples", "Samples"),
    ("split_signature", "Split signature"),
):
    print(f"{_label}: {export_result[_key]}")

# ---- Verify export: reload with signature / classes_fp / y_true checks requested ----
loaded = load_predictions(
    npz_path=export_result["npz_path"],
    verify_split_signature=sig,
    verify_classes_fp=CANONICAL_CLASSES_FP,
    require_y_true=True,
)

print()
print("✓ Export verification passed!")
loaded_meta = loaded["metadata"]
print(f"  - model: {loaded_meta['model_name']}")
print(f"  - split: {loaded_meta['split_name']}")
print(f"  - output_type: {loaded_meta.get('output_type', 'probs')}")
print(f"  - logits shape: {loaded['logits'].shape}")
print(f"  - has_y_true: {loaded_meta['has_y_true']}")

In [None]:
# ============================================================
# TRAINING CODE (SKIP - already trained)
# ============================================================
# The model has already been trained and checkpoint saved.
# Run cell-8 above (the test-logit export cell) to export test logits without retraining.
#
# NOTE(review): this reference config sets vit_model_name="vit_base_patch16_224",
# but the export cell rebuilds the checkpoint as "vit_tiny_patch16_224"
# (embed_dim=192). As written, this config would not reproduce that checkpoint;
# confirm which configuration actually produced best_model.pth before re-running.
#
# Original training code is commented out below for reference:
# ------------------------------------------------------------
# cfg = ViTConfig(
#     raw_dir=str(STORE / "data_raw"),
#     img_dir=str(IMG_ROOT),  
#     out_dir=str(STORE / "artifacts" / "exports"),
#     ckpt_dir=str(STORE / "checkpoints" / "image_vit"),
#     img_size=224,
#     batch_size=32, 
#     num_workers=8,
#     num_epochs=5, 
#     lr=1e-4,
#     use_amp=True,
#     label_smoothing=0.1,
#     dropout_rate=0.1,
#     vit_model_name="vit_base_patch16_224",
#     vit_pretrained=True,
#     force_colab_loader=True,
#     model_name="vit_base_canonical_smoke",
#     export_split="val",
#     use_wandb=True,
#     wandb_project="rakuten-vit-colab-smoke"
# )
# result = run_vit_canonical(cfg)
print("Training code skipped - model already trained.")