In [None]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
import os

BASE_DIR = "/content/drive/MyDrive/AML Challenge"
os.chdir(BASE_DIR)
print("Current working directory:", os.getcwd())


Current working directory: /content/drive/MyDrive/AML Challenge


# PreProcessing

In [None]:
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import numpy as np, pandas as pd
# ------------------------------------------------------
# 1. Device and data loading
# ------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

train_data = np.load(BASE_DIR+"/train.npz")
test_data  = np.load(BASE_DIR+"/test.clean.npz")

tx_train = train_data["captions/embeddings"]   # (125000, 1024)
im_train = train_data["images/embeddings"]     # (25000, 1536)
tx_test  = test_data["captions/embeddings"]    # (1500, 1024)

print("Train shapes:", tx_train.shape, im_train.shape)
print("Test shape:", tx_test.shape)

# ------------------------------------------------------
# 2. Match each caption to its corresponding image
# ------------------------------------------------------
repeat_factor = len(tx_train) // len(im_train)   # 5 captions per image
im_train_expanded = np.repeat(im_train, repeat_factor, axis=0)

Using device: cuda
Train shapes: (125000, 1024) (25000, 1536)
Test shape: (1500, 1024)


# Variational Autoencoder (VAE)
## Experiment 58: The "Giga-VAE."
# got score as Score: 0.86241

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import os
import time

# --- EXPERIMENT 58: "The Giga-VAE" ---
#
# This is a new probabilistic architecture inspired by the DALL-E 2 /
# CLIP Latents "prior" models.
# It learns a probabilistic mapping (A -> Distribution -> B) to
# handle the 1-to-5 text-to-image noise.
#
# 1. Architecture: NEW "VAETranslator"
# 2. Hidden Dim:   *** 8192 *** (Champion Horsepower)
# 3. Regularization: *** "GIGA-BOSS" ***
#    - Dropout:       0.4
#    - Weight Decay:  4e-4
#    - Mixup:         0.6
# 4. Time:         60 Epochs (Champion Sprint)
# 5. Loss:         CHAMPION Asymmetric Loss + *** KL Divergence Loss ***
# -----------------------------------------------------------------

# -----------------------------------------------------------------
# 1. Device and data loading (Same as before)
# -----------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

train_data = np.load("train.npz")
test_data  = np.load("test.clean.npz")

tx_train = train_data["captions/embeddings"]
im_train = train_data["images/embeddings"]
tx_test  = test_data["captions/embeddings"]

repeat_factor = len(tx_train) // len(im_train)
im_train_expanded = np.repeat(im_train, repeat_factor, axis=0)
print(f"Train shapes: {tx_train.shape}, {im_train.shape}, {im_train_expanded.shape}")
print(f"Test shape: {tx_test.shape}")

# -----------------------------------------------------------------
# 2. Data preprocessing (Same as before)
# -----------------------------------------------------------------
cpu_device = torch.device("cpu")
tx_train_t = torch.as_tensor(tx_train, dtype=torch.float32, device=cpu_device)
im_train_t_unique = torch.as_tensor(im_train, dtype=torch.float32, device=cpu_device)
tx_test_t  = torch.as_tensor(tx_test,  dtype=torch.float32, device=cpu_device)

tx_mean = tx_train_t.mean(0, keepdim=True)
im_mean = im_train_t_unique.mean(0, keepdim=True)

tx_train_t = F.normalize(tx_train_t - tx_mean, p=2, dim=1)
im_train_t_unique = F.normalize(im_train_t_unique - im_mean, p=2, dim=1)
tx_test_t = F.normalize(tx_test_t - tx_mean, p=2, dim=1)

im_train_exp = torch.as_tensor(im_train_expanded, dtype=torch.float32, device=cpu_device)
im_train_exp = F.normalize(im_train_exp - im_mean, p=2, dim=1)
print("Data preprocessed and normalized (on CPU).")

# -----------------------------------------------------------------
# 3. *** NEW: The VAE-Translator Model ***
# -----------------------------------------------------------------
class VAETranslator(nn.Module):
    def __init__(self, input_dim=1024, hidden_dim=8192, latent_dim=1536, output_dim=1536, dropout=0.4):
        super().__init__()
        self.latent_dim = latent_dim
        print(f"Initializing Giga-VAE (hidden_dim={hidden_dim}, latent_dim={latent_dim})")

        # --- ENCODER (Text -> Distribution) ---
        # 1024 -> 8192 -> 8192 -> (1536 * 2)
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),

            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),

            nn.Linear(hidden_dim, latent_dim * 2) # Output mu and logvar
        )

        # --- DECODER (Distribution Sample -> Image) ---
        # 1536 -> 8192 -> 8192 -> 1536
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),

            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),

            nn.Linear(hidden_dim, output_dim)
        )

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        # 1. Encode text to mu and logvar
        mu_logvar = self.encoder(x)
        mu = mu_logvar[..., :self.latent_dim]
        logvar = mu_logvar[..., self.latent_dim:]

        # 2. Get a sample 'z' from the distribution
        z = self.reparameterize(mu, logvar)

        # 3. Decode the sample to the image embedding space
        y_pred = self.decoder(z)

        # During training, return all parts for the loss function
        if self.training:
            return F.normalize(y_pred, p=2, dim=1), mu, logvar

        # During eval, just return the deterministic prediction
        # We use mu directly for a stable prediction
        y_pred_eval = self.decoder(mu)
        return F.normalize(y_pred_eval, p=2, dim=1)

# -----------------------------------------------------------------
# 4. Loss Functions (CHAMPION Loss + KL Divergence)
# -----------------------------------------------------------------
TAU = 0.01
MARGIN = 0.06
LOSS_WEIGHT_CONTRASTIVE = 0.7
LOSS_WEIGHT_TRIPLET = 0.3
MIXUP_ALPHA = 0.6 # Giga-Boss Mixup
KL_WEIGHT = 1e-5 # *** NEW: VAE Regularization Strength ***

triplet_loss_fn = nn.TripletMarginWithDistanceLoss(
    distance_function=lambda x, y: 1 - F.cosine_similarity(x, y),
    margin=MARGIN
)
print(f"Using CHAMPION ASYMMETRIC loss + *** VAE KL Loss (weight={KL_WEIGHT}) ***")

def calculate_kl_loss(mu, logvar):
    # KL divergence, VAE formula
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Return the *mean* loss per item in the batch
    return kl_loss / mu.size(0)

# -----------------------------------------------------------------
# 5. Validation split (Same as before)
# -----------------------------------------------------------------
N = len(tx_train_t)
val_size = int(0.1 * N)
idx_cpu = torch.randperm(N, device="cpu") # <-- New random seed for this run
val_idx, train_idx = idx_cpu[:val_size], idx_cpu[val_size:]

img_indices_train = (train_idx // 5)
img_indices_val = (val_idx // 5)

tx_val_t, im_val_t = tx_train_t[val_idx], im_train_exp[val_idx]
tx_train_t_sub, im_train_exp_sub = tx_train_t[train_idx], im_train_exp[train_idx]

train_dataset = TensorDataset(tx_train_t_sub, im_train_exp_sub, img_indices_train)
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True, num_workers=0)
print(f"Train pairs: {len(train_dataset)}, Validation pairs: {len(val_idx)}")

# -----------------------------------------------------------------
# 6. Memory-Safe Recall@K Utility (Same as before)
# -----------------------------------------------------------------
@torch.no_grad()
def recall_at_k(model, tx_queries, im_database, query_img_indices, chunk_size=1024, ks=(1, 5, 10, 50), k_csls=10):
    model.eval() # <-- This is CRITICAL. It tells the VAE to return deterministic preds
    device = next(model.parameters()).device
    all_correct_in_top_k = {k: [] for k in ks}

    Q_embed_list = []
    for i in range(0, len(tx_queries), chunk_size):
        chunk = tx_queries[i:i+chunk_size].to(device)
        Q_embed_list.append(model(chunk)) # model(chunk) will now return y_pred_eval
    Q_embed = torch.cat(Q_embed_list, dim=0)

    D_embed = im_database.to(device)

    mean_knn_d_list = []
    for i in tqdm(range(0, len(D_embed), 512), desc="CSLS (Pass 1/2)", leave=False):
        sim_T_chunk = D_embed[i:i+512] @ Q_embed.T
        knn_d_chunk = torch.topk(sim_T_chunk, k=k_csls, dim=1).values
        mean_knn_d_list.append(knn_d_chunk.mean(1))

    mean_knn_d = torch.cat(mean_knn_d_list, dim=0).unsqueeze(0)

    gt = query_img_indices.to(device)

    for i in tqdm(range(0, len(Q_embed), chunk_size), desc="CSLS (Pass 2/2)", leave=False):
        sim_chunk = Q_embed[i:i+chunk_size] @ D_embed.T
        gt_chunk = gt[i:i+chunk_size]

        knn_q_chunk = torch.topk(sim_chunk, k=k_csls, dim=1).values
        mean_knn_q = knn_q_chunk.mean(1, keepdim=True)

        csls_sim_chunk = 2 * sim_chunk - mean_knn_q - mean_knn_d

        top_indices_chunk = torch.argsort(csls_sim_chunk, dim=1, descending=True)

        for k in ks:
            top_k_preds = top_indices_chunk[:, :k]
            correct_in_top_k = (top_k_preds == gt_chunk.unsqueeze(1)).any(dim=1)
            all_correct_in_top_k[k].append(correct_in_top_k)

    recalls = {}
    for k in ks:
        recalls[f"R@{k}"] = torch.cat(all_correct_in_top_k[k]).float().mean().item()
    return recalls

# -----------------------------------------------------------------
# 7. *** NEW: Training Loop ***
# -----------------------------------------------------------------
EPOCHS = 60 # Champion "Sprint"
START_LR = 3e-4
WEIGHT_DECAY = 4e-4 # Giga-Boss Weight Decay
RUN_ID = f"giga_vae_model_{int(time.time())}"
FINAL_SAVE_PATH = f"giga_vae_model_{RUN_ID}_best.pth"

val_query_subset = tx_val_t[:5000]
val_indices_subset = img_indices_val[:5000]
val_db_subset = im_train_t_unique

# --- Initialize the new Giga-VAE model ---
model = VAETranslator(dropout=0.4, hidden_dim=8192, latent_dim=1536, output_dim=1536).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=START_LR, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
scaler = torch.amp.GradScaler('cuda')

print(f"\n--- Starting New Architecture: Giga-VAE (RUN {RUN_ID}) ---")
print(f"Training (Epochs: {EPOCHS}, LR={START_LR}, WD={WEIGHT_DECAY}, Mixup={MIXUP_ALPHA})...\n")

best_val_r10 = 0.0
for epoch in range(1, EPOCHS + 1):
    model.train() # <-- This is CRITICAL. It tells the VAE to return all 3 outputs
    total_recon_loss = 0.0
    total_kl_loss = 0.0

    for x_batch, y_batch, img_indices in tqdm(train_loader, desc=f"Epoch {epoch:03d}/{EPOCHS}", leave=False):
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        # --- GIGA-BOSS MIXUP IMPLEMENTATION ---
        idx_shuffle = torch.randperm(x_batch.size(0))
        lam = np.random.beta(MIXUP_ALPHA, MIXUP_ALPHA)
        x_mix = lam * x_batch + (1 - lam) * x_batch[idx_shuffle]
        y_mix = F.normalize(lam * y_batch + (1 - lam) * y_batch[idx_shuffle], p=2, dim=1)

        optimizer.zero_grad()
        with torch.amp.autocast('cuda'):
            # --- VAE FORWARD PASS ---
            y_pred, mu, logvar = model(x_mix) # Model is in .train() mode

            # --- 1. Reconstruction Loss (Our Champion Loss) ---
            sims_with_tau = y_pred @ y_mix.T / TAU
            labels = torch.arange(y_pred.size(0), device=device)
            loss_con = F.cross_entropy(sims_with_tau, labels)

            with torch.no_grad():
                sims_no_tau = y_pred @ y_mix.T
                positive_mask = torch.eye(y_batch.size(0), dtype=torch.bool, device=device)
                sims_no_tau.masked_fill_(positive_mask, -float('inf'))
                hard_neg_idx = sims_no_tau.argmax(dim=1)
            y_hard_neg = y_mix[hard_neg_idx]
            loss_tri = triplet_loss_fn(y_pred, y_mix, y_hard_neg)

            recon_loss = (LOSS_WEIGHT_CONTRASTIVE * loss_con) + (LOSS_WEIGHT_TRIPLET * loss_tri)

            # --- 2. KL Divergence Loss ---
            kl_loss = calculate_kl_loss(mu, logvar)

            # --- 3. Total Loss ---
            loss = recon_loss + (KL_WEIGHT * kl_loss)

        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

        total_recon_loss += recon_loss.item()
        total_kl_loss += kl_loss.item()

    current_lr = scheduler.get_last_lr()[0]
    scheduler.step()
    avg_recon_loss = total_recon_loss / len(train_loader)
    avg_kl_loss = total_kl_loss / len(train_loader)

    # ----- Validation (quick subset) -----
    rec = recall_at_k(model, val_query_subset, val_db_subset, val_indices_subset, chunk_size=1024)
    current_r10 = rec['R@10']

    save_marker = ""
    if current_r10 > best_val_r10:
        best_val_r10 = current_r10
        torch.save(model.state_dict(), FINAL_SAVE_PATH)
        save_marker = "<- Best R@10! Saved."

    print(f"Epoch {epoch:03d}: ReconLoss={avg_recon_loss:.4f} KL_Loss={avg_kl_loss:.2f} (LR={current_lr:.1E}) | R@10(Val)={current_r10:.4f} {save_marker}")

print(f"\n🎯 Training complete. Best model (R@10={best_val_r10:.4f}) saved as {FINAL_SAVE_PATH}")

# -----------------------------------------------------------------
# 8. Inference + Submission
# -----------------------------------------------------------------
print(f"Loading BEST model from {FINAL_SAVE_PATH} for inference...")
model = VAETranslator(dropout=0.4, hidden_dim=8192, latent_dim=1536, output_dim=1536).to(device)
model.load_state_dict(torch.load(FINAL_SAVE_PATH))
model.eval() # <-- This is CRITICAL. Ensures model(chunk) gives deterministic output
print("Best model loaded.")

with torch.no_grad():
    preds_list = []
    for i in tqdm(range(0, len(tx_test_t), 1024), desc="Generating Sub"):
        chunk = tx_test_t[i:i+1024].to(device)
        preds_list.append(model(chunk)) # model is in .eval() mode
    preds = torch.cat(preds_list, dim=0).cpu().numpy()

test_ids = test_data["captions/ids"].astype(int)
submission = pd.DataFrame({
    "id": test_ids,
    "embedding": [list(map(float, row)) for row in preds]
})

submission_filename = f"submission_giga_vae_run_{RUN_ID}.csv"
submission.to_csv(submission_filename, index=False)
print(f"\n✅ {submission_filename} saved successfully.")

# --- Final Validation (on FULL val set, now OOM-safe) ---
print("Calculating final validation scores on full val set...")
rec_full = recall_at_k(model, tx_val_t, im_train_t_unique, img_indices_val, chunk_size=1024)
print("\n--- Full Validation Results ---")
print(f"Validation Recall@1:  {rec_full['R@1']:<8.4f}")
print(f"Validation Recall@5:  {rec_full['R@5']:<8.4f}")
print(f"Validation Recall@10: {rec_full['R@10']:<8.4f}")
print(f"Validation Recall@50: {rec_full['R@50']:<8.4f}")

Using device: cuda
Train shapes: (125000, 1024), (25000, 1536), (125000, 1536)
Test shape: (1500, 1024)
Data preprocessed and normalized (on CPU).
Using CHAMPION ASYMMETRIC loss + *** VAE KL Loss (weight=1e-05) ***
Train pairs: 112500, Validation pairs: 12500
Initializing Giga-VAE (hidden_dim=8192, latent_dim=1536)

--- Starting New Architecture: Giga-VAE (RUN giga_vae_model_1762268434) ---
Training (Epochs: 60, LR=0.0003, WD=0.0004, Mixup=0.6)...



Epoch 001/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 001: ReconLoss=2.6223 KL_Loss=1208.27 (LR=3.0E-04) | R@10(Val)=0.3122 <- Best R@10! Saved.


Epoch 002/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 002: ReconLoss=1.4283 KL_Loss=1642.54 (LR=3.0E-04) | R@10(Val)=0.4142 <- Best R@10! Saved.


Epoch 003/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 003: ReconLoss=1.1266 KL_Loss=1881.42 (LR=3.0E-04) | R@10(Val)=0.4556 <- Best R@10! Saved.


Epoch 004/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 004: ReconLoss=0.9382 KL_Loss=2032.08 (LR=3.0E-04) | R@10(Val)=0.4988 <- Best R@10! Saved.


Epoch 005/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 005: ReconLoss=0.8200 KL_Loss=2137.36 (LR=3.0E-04) | R@10(Val)=0.5198 <- Best R@10! Saved.


Epoch 006/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 006: ReconLoss=0.7457 KL_Loss=2222.82 (LR=2.9E-04) | R@10(Val)=0.5390 <- Best R@10! Saved.


Epoch 007/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 007: ReconLoss=0.6578 KL_Loss=2310.74 (LR=2.9E-04) | R@10(Val)=0.5492 <- Best R@10! Saved.


Epoch 008/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 008: ReconLoss=0.6161 KL_Loss=2368.72 (LR=2.9E-04) | R@10(Val)=0.5670 <- Best R@10! Saved.


Epoch 009/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 009: ReconLoss=0.5702 KL_Loss=2438.00 (LR=2.9E-04) | R@10(Val)=0.5744 <- Best R@10! Saved.


Epoch 010/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 010: ReconLoss=0.4875 KL_Loss=2491.43 (LR=2.8E-04) | R@10(Val)=0.5946 <- Best R@10! Saved.


Epoch 011/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 011: ReconLoss=0.4511 KL_Loss=2517.03 (LR=2.8E-04) | R@10(Val)=0.6038 <- Best R@10! Saved.


Epoch 012/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 012: ReconLoss=0.4233 KL_Loss=2546.21 (LR=2.8E-04) | R@10(Val)=0.6036 


Epoch 013/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 013: ReconLoss=0.4064 KL_Loss=2547.67 (LR=2.7E-04) | R@10(Val)=0.6128 <- Best R@10! Saved.


Epoch 014/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 014: ReconLoss=0.3870 KL_Loss=2572.00 (LR=2.7E-04) | R@10(Val)=0.6150 <- Best R@10! Saved.


Epoch 015/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 015: ReconLoss=0.3643 KL_Loss=2582.23 (LR=2.6E-04) | R@10(Val)=0.6146 


Epoch 016/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 016: ReconLoss=0.3550 KL_Loss=2576.32 (LR=2.6E-04) | R@10(Val)=0.6232 <- Best R@10! Saved.


Epoch 017/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 017: ReconLoss=0.3362 KL_Loss=2607.47 (LR=2.5E-04) | R@10(Val)=0.6232 


Epoch 018/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 018: ReconLoss=0.3287 KL_Loss=2612.39 (LR=2.4E-04) | R@10(Val)=0.6318 <- Best R@10! Saved.


Epoch 019/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 019: ReconLoss=0.3061 KL_Loss=2632.91 (LR=2.4E-04) | R@10(Val)=0.6320 <- Best R@10! Saved.


Epoch 020/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 020: ReconLoss=0.2896 KL_Loss=2633.89 (LR=2.3E-04) | R@10(Val)=0.6340 <- Best R@10! Saved.


Epoch 021/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 021: ReconLoss=0.2785 KL_Loss=2642.31 (LR=2.2E-04) | R@10(Val)=0.6378 <- Best R@10! Saved.


Epoch 022/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 022: ReconLoss=0.2733 KL_Loss=2636.09 (LR=2.2E-04) | R@10(Val)=0.6432 <- Best R@10! Saved.


Epoch 023/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 023: ReconLoss=0.2673 KL_Loss=2648.99 (LR=2.1E-04) | R@10(Val)=0.6422 


Epoch 024/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 024: ReconLoss=0.2637 KL_Loss=2644.63 (LR=2.0E-04) | R@10(Val)=0.6460 <- Best R@10! Saved.


Epoch 025/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 025: ReconLoss=0.2582 KL_Loss=2649.24 (LR=2.0E-04) | R@10(Val)=0.6474 <- Best R@10! Saved.


Epoch 026/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 026: ReconLoss=0.2528 KL_Loss=2660.48 (LR=1.9E-04) | R@10(Val)=0.6456 


Epoch 027/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 027: ReconLoss=0.2492 KL_Loss=2657.22 (LR=1.8E-04) | R@10(Val)=0.6496 <- Best R@10! Saved.


Epoch 028/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 028: ReconLoss=0.2433 KL_Loss=2662.89 (LR=1.7E-04) | R@10(Val)=0.6508 <- Best R@10! Saved.


Epoch 029/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 029: ReconLoss=0.2383 KL_Loss=2669.38 (LR=1.7E-04) | R@10(Val)=0.6492 


Epoch 030/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 030: ReconLoss=0.2322 KL_Loss=2661.87 (LR=1.6E-04) | R@10(Val)=0.6496 


Epoch 031/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 031: ReconLoss=0.2303 KL_Loss=2660.01 (LR=1.5E-04) | R@10(Val)=0.6496 


Epoch 032/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 032: ReconLoss=0.2271 KL_Loss=2660.53 (LR=1.4E-04) | R@10(Val)=0.6534 <- Best R@10! Saved.


Epoch 033/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 033: ReconLoss=0.2205 KL_Loss=2662.74 (LR=1.3E-04) | R@10(Val)=0.6538 <- Best R@10! Saved.


Epoch 034/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 034: ReconLoss=0.2174 KL_Loss=2669.13 (LR=1.3E-04) | R@10(Val)=0.6544 <- Best R@10! Saved.


Epoch 035/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 035: ReconLoss=0.2162 KL_Loss=2662.95 (LR=1.2E-04) | R@10(Val)=0.6534 


Epoch 036/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 036: ReconLoss=0.2145 KL_Loss=2682.84 (LR=1.1E-04) | R@10(Val)=0.6582 <- Best R@10! Saved.


Epoch 037/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 037: ReconLoss=0.2138 KL_Loss=2674.68 (LR=1.0E-04) | R@10(Val)=0.6590 <- Best R@10! Saved.


Epoch 038/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 038: ReconLoss=0.2082 KL_Loss=2689.49 (LR=9.6E-05) | R@10(Val)=0.6584 


Epoch 039/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 039: ReconLoss=0.2062 KL_Loss=2679.25 (LR=8.9E-05) | R@10(Val)=0.6580 


Epoch 040/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 040: ReconLoss=0.2070 KL_Loss=2670.69 (LR=8.2E-05) | R@10(Val)=0.6590 


Epoch 041/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 041: ReconLoss=0.2037 KL_Loss=2675.57 (LR=7.5E-05) | R@10(Val)=0.6592 <- Best R@10! Saved.


Epoch 042/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 042: ReconLoss=0.2007 KL_Loss=2673.34 (LR=6.8E-05) | R@10(Val)=0.6572 


Epoch 043/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 043: ReconLoss=0.1980 KL_Loss=2681.29 (LR=6.2E-05) | R@10(Val)=0.6584 


Epoch 044/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 044: ReconLoss=0.1962 KL_Loss=2689.12 (LR=5.6E-05) | R@10(Val)=0.6602 <- Best R@10! Saved.


Epoch 045/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 045: ReconLoss=0.1950 KL_Loss=2673.60 (LR=5.0E-05) | R@10(Val)=0.6616 <- Best R@10! Saved.


Epoch 046/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 046: ReconLoss=0.1933 KL_Loss=2685.08 (LR=4.4E-05) | R@10(Val)=0.6600 


Epoch 047/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 047: ReconLoss=0.1922 KL_Loss=2673.53 (LR=3.9E-05) | R@10(Val)=0.6606 


Epoch 048/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 048: ReconLoss=0.1900 KL_Loss=2689.72 (LR=3.3E-05) | R@10(Val)=0.6606 


Epoch 049/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 049: ReconLoss=0.1918 KL_Loss=2662.09 (LR=2.9E-05) | R@10(Val)=0.6610 


Epoch 050/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 050: ReconLoss=0.1901 KL_Loss=2671.69 (LR=2.4E-05) | R@10(Val)=0.6618 <- Best R@10! Saved.


Epoch 051/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 051: ReconLoss=0.1899 KL_Loss=2685.12 (LR=2.0E-05) | R@10(Val)=0.6616 


Epoch 052/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 052: ReconLoss=0.1902 KL_Loss=2674.20 (LR=1.6E-05) | R@10(Val)=0.6616 


Epoch 053/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 053: ReconLoss=0.1892 KL_Loss=2688.49 (LR=1.3E-05) | R@10(Val)=0.6620 <- Best R@10! Saved.


Epoch 054/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 054: ReconLoss=0.1886 KL_Loss=2689.95 (LR=1.0E-05) | R@10(Val)=0.6612 


Epoch 055/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 055: ReconLoss=0.1882 KL_Loss=2674.70 (LR=7.3E-06) | R@10(Val)=0.6616 


Epoch 056/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 056: ReconLoss=0.1867 KL_Loss=2680.64 (LR=5.1E-06) | R@10(Val)=0.6614 


Epoch 057/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 057: ReconLoss=0.1873 KL_Loss=2682.01 (LR=3.3E-06) | R@10(Val)=0.6618 


Epoch 058/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 058: ReconLoss=0.1901 KL_Loss=2677.00 (LR=1.8E-06) | R@10(Val)=0.6620 


Epoch 059/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 059: ReconLoss=0.1865 KL_Loss=2689.06 (LR=8.2E-07) | R@10(Val)=0.6620 


Epoch 060/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 060: ReconLoss=0.1872 KL_Loss=2675.33 (LR=2.1E-07) | R@10(Val)=0.6620 

🎯 Training complete. Best model (R@10=0.6620) saved as giga_vae_model_giga_vae_model_1762268434_best.pth
Loading BEST model from giga_vae_model_giga_vae_model_1762268434_best.pth for inference...
Initializing Giga-VAE (hidden_dim=8192, latent_dim=1536)
Best model loaded.


Generating Sub:   0%|          | 0/2 [00:00<?, ?it/s]


✅ submission_giga_vae_run_giga_vae_model_1762268434.csv saved successfully.
Calculating final validation scores on full val set...


CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/13 [00:00<?, ?it/s]


--- Full Validation Results ---
Validation Recall@1:  0.3070  
Validation Recall@5:  0.5630  
Validation Recall@10: 0.6671  
Validation Recall@50: 0.8613  


#

# Fine tune on previous VAE

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import os
import time

# --- EXPERIMENT 59: "The Giga-VAE" (Looser Brakes) ---
#
# 1. Architecture: CHAMPION "Giga-VAE"
# 2. All "Giga-Boss" settings: hidden_dim=8192, DO=0.4, WD=4e-4, Mixup=0.6
# 3. Time:         60 Epochs
# 4. Loss Change:  *** KL_WEIGHT = 1e-6 *** (Was 1e-5)
# 5. RUN_ID:       *** VAE_exp2 ***
# -----------------------------------------------------------------

# -----------------------------------------------------------------
# 1. Device and data loading
# -----------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if not os.path.exists("train.npz") or not os.path.exists("test.clean.npz"):
    print("ERROR: 'train.npz' or 'test.clean.npz' not found.")
    print("Please make sure these files are in the same directory as your script.")
    # You might want to raise an error here to stop the script
    # raise FileNotFoundError("Required data files not found.")
else:
    print("Found 'train.npz' and 'test.clean.npz'. Loading data...")

train_data = np.load("train.npz")
test_data  = np.load("test.clean.npz")

tx_train = train_data["captions/embeddings"]
im_train = train_data["images/embeddings"]
tx_test  = test_data["captions/embeddings"]

repeat_factor = len(tx_train) // len(im_train)
im_train_expanded = np.repeat(im_train, repeat_factor, axis=0)
print(f"Train shapes: {tx_train.shape}, {im_train.shape}, {im_train_expanded.shape}")
print(f"Test shape: {tx_test.shape}")

# -----------------------------------------------------------------
# 2. Data preprocessing
# -----------------------------------------------------------------
cpu_device = torch.device("cpu")
tx_train_t = torch.as_tensor(tx_train, dtype=torch.float32, device=cpu_device)
im_train_t_unique = torch.as_tensor(im_train, dtype=torch.float32, device=cpu_device)
tx_test_t  = torch.as_tensor(tx_test,  dtype=torch.float32, device=cpu_device)

tx_mean = tx_train_t.mean(0, keepdim=True)
im_mean = im_train_t_unique.mean(0, keepdim=True)

tx_train_t = F.normalize(tx_train_t - tx_mean, p=2, dim=1)
im_train_t_unique = F.normalize(im_train_t_unique - im_mean, p=2, dim=1)
tx_test_t = F.normalize(tx_test_t - tx_mean, p=2, dim=1)

im_train_exp = torch.as_tensor(im_train_expanded, dtype=torch.float32, device=cpu_device)
im_train_exp = F.normalize(im_train_exp - im_mean, p=2, dim=1)
print("Data preprocessed and normalized (on CPU).")

# -----------------------------------------------------------------
# 3. The Giga-VAE Model
# -----------------------------------------------------------------
class VAETranslator(nn.Module):
    def __init__(self, input_dim=1024, hidden_dim=8192, latent_dim=1536, output_dim=1536, dropout=0.4):
        super().__init__()
        self.latent_dim = latent_dim
        print(f"Initializing Giga-VAE (hidden_dim={hidden_dim}, latent_dim={latent_dim})")

        # --- ENCODER (Text -> Distribution) ---
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, latent_dim * 2) # Output mu and logvar
        )

        # --- DECODER (Distribution Sample -> Image) ---
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim)
        )

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu_logvar = self.encoder(x)
        mu = mu_logvar[..., :self.latent_dim]
        logvar = mu_logvar[..., self.latent_dim:]

        if self.training:
            z = self.reparameterize(mu, logvar)
            y_pred = self.decoder(z)
            return F.normalize(y_pred, p=2, dim=1), mu, logvar

        # During eval, use mu directly for a stable prediction
        y_pred_eval = self.decoder(mu)
        return F.normalize(y_pred_eval, p=2, dim=1)

# -----------------------------------------------------------------
# 4. Loss Functions
# -----------------------------------------------------------------
TAU = 0.01
MARGIN = 0.06
LOSS_WEIGHT_CONTRASTIVE = 0.7
LOSS_WEIGHT_TRIPLET = 0.3
MIXUP_ALPHA = 0.6 # Giga-Boss Mixup
KL_WEIGHT = 1e-6 # *** NEW: 10x WEAKER BRAKE ***

triplet_loss_fn = nn.TripletMarginWithDistanceLoss(
    distance_function=lambda x, y: 1 - F.cosine_similarity(x, y),
    margin=MARGIN
)
print(f"Using CHAMPION ASYMMETRIC loss + VAE KL Loss (weight={KL_WEIGHT})")

def calculate_kl_loss(mu, logvar):
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return kl_loss / mu.size(0)

# -----------------------------------------------------------------
# 5. Validation split
# -----------------------------------------------------------------
N = len(tx_train_t)
val_size = int(0.1 * N)
idx_cpu = torch.randperm(N, device="cpu") # <-- New random seed for this run
val_idx, train_idx = idx_cpu[:val_size], idx_cpu[val_size:]

img_indices_train = (train_idx // 5)
img_indices_val = (val_idx // 5)

tx_val_t, im_val_t = tx_train_t[val_idx], im_train_exp[val_idx]
tx_train_t_sub, im_train_exp_sub = tx_train_t[train_idx], im_train_exp[train_idx]

train_dataset = TensorDataset(tx_train_t_sub, im_train_exp_sub, img_indices_train)
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True, num_workers=0)
print(f"Train pairs: {len(train_dataset)}, Validation pairs: {len(val_idx)}")

# -----------------------------------------------------------------
# 6. Memory-Safe Recall@K Utility
# -----------------------------------------------------------------
@torch.no_grad()
def recall_at_k(model, tx_queries, im_database, query_img_indices, chunk_size=1024, ks=(1, 5, 10, 50), k_csls=10):
    model.eval()
    device = next(model.parameters()).device
    all_correct_in_top_k = {k: [] for k in ks}

    Q_embed_list = []
    for i in range(0, len(tx_queries), chunk_size):
        chunk = tx_queries[i:i+chunk_size].to(device)
        Q_embed_list.append(model(chunk))
    Q_embed = torch.cat(Q_embed_list, dim=0)

    D_embed = im_database.to(device)

    mean_knn_d_list = []
    for i in tqdm(range(0, len(D_embed), 512), desc="CSLS (Pass 1/2)", leave=False):
        sim_T_chunk = D_embed[i:i+512] @ Q_embed.T
        knn_d_chunk = torch.topk(sim_T_chunk, k=k_csls, dim=1).values
        mean_knn_d_list.append(knn_d_chunk.mean(1))

    mean_knn_d = torch.cat(mean_knn_d_list, dim=0).unsqueeze(0)

    gt = query_img_indices.to(device)

    for i in tqdm(range(0, len(Q_embed), chunk_size), desc="CSLS (Pass 2/2)", leave=False):
        sim_chunk = Q_embed[i:i+chunk_size] @ D_embed.T
        gt_chunk = gt[i:i+chunk_size]

        knn_q_chunk = torch.topk(sim_chunk, k=k_csls, dim=1).values
        mean_knn_q = knn_q_chunk.mean(1, keepdim=True)

        csls_sim_chunk = 2 * sim_chunk - mean_knn_q - mean_knn_d

        top_indices_chunk = torch.argsort(csls_sim_chunk, dim=1, descending=True)

        for k in ks:
            top_k_preds = top_indices_chunk[:, :k]
            correct_in_top_k = (top_k_preds == gt_chunk.unsqueeze(1)).any(dim=1)
            all_correct_in_top_k[k].append(correct_in_top_k)

    recalls = {}
    for k in ks:
        recalls[f"R@{k}"] = torch.cat(all_correct_in_top_k[k]).float().mean().item()
    return recalls

# -----------------------------------------------------------------
# 7. Training Loop
# -----------------------------------------------------------------
EPOCHS = 60
START_LR = 3e-4
WEIGHT_DECAY = 4e-4

# --- MODIFICATION: Your specific RUN_ID ---
RUN_ID = "VAE_exp2"
FINAL_SAVE_PATH = f"{RUN_ID}_best.pth"
# --- END MODIFICATION ---

val_query_subset = tx_val_t[:5000]
val_indices_subset = img_indices_val[:5000]
val_db_subset = im_train_t_unique

model = VAETranslator(dropout=0.4, hidden_dim=8192, latent_dim=1536, output_dim=1536).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=START_LR, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
scaler = torch.amp.GradScaler('cuda')

print(f"\n--- Starting Giga-VAE (Looser Brakes, KL_WEIGHT=1e-6) (RUN {RUN_ID}) ---")
print(f"Training (Epochs: {EPOCHS}, LR={START_LR}, WD={WEIGHT_DECAY}, Mixup={MIXUP_ALPHA})...\n")

best_val_r10 = 0.0
for epoch in range(1, EPOCHS + 1):
    model.train()
    total_recon_loss = 0.0
    total_kl_loss = 0.0

    for x_batch, y_batch, img_indices in tqdm(train_loader, desc=f"Epoch {epoch:03d}/{EPOCHS}", leave=False):
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        idx_shuffle = torch.randperm(x_batch.size(0))
        lam = np.random.beta(MIXUP_ALPHA, MIXUP_ALPHA)
        x_mix = lam * x_batch + (1 - lam) * x_batch[idx_shuffle]
        y_mix = F.normalize(lam * y_batch + (1 - lam) * y_batch[idx_shuffle], p=2, dim=1)

        optimizer.zero_grad()
        with torch.amp.autocast('cuda'):
            y_pred, mu, logvar = model(x_mix)

            # --- 1. Reconstruction Loss ---
            sims_with_tau = y_pred @ y_mix.T / TAU
            labels = torch.arange(y_pred.size(0), device=device)
            loss_con = F.cross_entropy(sims_with_tau, labels)

            with torch.no_grad():
                sims_no_tau = y_pred @ y_mix.T
                positive_mask = torch.eye(y_batch.size(0), dtype=torch.bool, device=device)
                sims_no_tau.masked_fill_(positive_mask, -float('inf'))
                hard_neg_idx = sims_no_tau.argmax(dim=1)
            y_hard_neg = y_mix[hard_neg_idx]
            loss_tri = triplet_loss_fn(y_pred, y_mix, y_hard_neg)

            recon_loss = (LOSS_WEIGHT_CONTRASTIVE * loss_con) + (LOSS_WEIGHT_TRIPLET * loss_tri)

            # --- 2. KL Divergence Loss ---
            kl_loss = calculate_kl_loss(mu, logvar)

            # --- 3. Total Loss ---
            loss = recon_loss + (KL_WEIGHT * kl_loss) # KL_WEIGHT is now 1e-6

        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

        total_recon_loss += recon_loss.item()
        total_kl_loss += kl_loss.item()

    current_lr = scheduler.get_last_lr()[0]
    scheduler.step()
    avg_recon_loss = total_recon_loss / len(train_loader)
    avg_kl_loss = total_kl_loss / len(train_loader)

    # ----- Validation (quick subset) -----
    rec = recall_at_k(model, val_query_subset, val_db_subset, val_indices_subset, chunk_size=1024)
    current_r10 = rec['R@10']

    save_marker = ""
    if current_r10 > best_val_r10:
        best_val_r10 = current_r10
        torch.save(model.state_dict(), FINAL_SAVE_PATH)
        save_marker = "<- Best R@10! Saved."

    print(f"Epoch {epoch:03d}: ReconLoss={avg_recon_loss:.4f} KL_Loss={avg_kl_loss:.2f} (LR={current_lr:.1E}) | R@10(Val)={current_r10:.4f} {save_marker}")

print(f"\n🎯 Training complete. Best model (R@10={best_val_r10:.4f}) saved as {FINAL_SAVE_PATH}")

# -----------------------------------------------------------------
# 8. Inference + Submission
# -----------------------------------------------------------------
print(f"Loading BEST model from {FINAL_SAVE_PATH} for inference...")
model = VAETranslator(dropout=0.4, hidden_dim=8192, latent_dim=1536, output_dim=1536).to(device)
model.load_state_dict(torch.load(FINAL_SAVE_PATH))
model.eval()
print("Best model loaded.")

with torch.no_grad():
    preds_list = []
    for i in tqdm(range(0, len(tx_test_t), 1024), desc="Generating Sub"):
        chunk = tx_test_t[i:i+1024].to(device)
        preds_list.append(model(chunk))

    preds = torch.cat(preds_list, dim=0).cpu().numpy()

test_ids = test_data["captions/ids"].astype(int)
submission = pd.DataFrame({
    "id": test_ids,
    "embedding": [list(map(float, row)) for row in preds]
})

# --- MODIFICATION: Your specific submission filename ---
submission_filename = f"submission_{RUN_ID}.csv"
submission.to_csv(submission_filename, index=False)
print(f"\n✅ {submission_filename} saved successfully.")

# --- Final Validation (on FULL val set, now OOM-safe) ---
print("Calculating final validation scores on full val set...")
rec_full = recall_at_k(model, tx_val_t, im_train_t_unique, img_indices_val, chunk_size=1024)
print("\n--- Full Validation Results ---")
print(f"Validation Recall@1:  {rec_full['R@1']:<8.4f}")
print(f"Validation Recall@5:  {rec_full['R@5']:<8.4f}")
print(f"Validation Recall@10: {rec_full['R@10']:<8.4f}")
print(f"Validation Recall@50: {rec_full['R@50']:<8.4f}")

Using device: cuda
Found 'train.npz' and 'test.clean.npz'. Loading data...
Train shapes: (125000, 1024), (25000, 1536), (125000, 1536)
Test shape: (1500, 1024)
Data preprocessed and normalized (on CPU).
Using CHAMPION ASYMMETRIC loss + VAE KL Loss (weight=1e-06)
Train pairs: 112500, Validation pairs: 12500
Initializing Giga-VAE (hidden_dim=8192, latent_dim=1536)

--- Starting Giga-VAE (Looser Brakes, KL_WEIGHT=1e-6) (RUN VAE_exp2) ---
Training (Epochs: 60, LR=0.0003, WD=0.0004, Mixup=0.6)...



Epoch 001/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 001: ReconLoss=2.6321 KL_Loss=1213.50 (LR=3.0E-04) | R@10(Val)=0.3082 <- Best R@10! Saved.


Epoch 002/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 002: ReconLoss=1.4348 KL_Loss=1694.01 (LR=3.0E-04) | R@10(Val)=0.4156 <- Best R@10! Saved.


Epoch 003/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 003: ReconLoss=1.1230 KL_Loss=1978.55 (LR=3.0E-04) | R@10(Val)=0.4636 <- Best R@10! Saved.


Epoch 004/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 004: ReconLoss=0.9361 KL_Loss=2162.57 (LR=3.0E-04) | R@10(Val)=0.4924 <- Best R@10! Saved.


Epoch 005/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 005: ReconLoss=0.8330 KL_Loss=2327.04 (LR=3.0E-04) | R@10(Val)=0.5242 <- Best R@10! Saved.


Epoch 006/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 006: ReconLoss=0.7267 KL_Loss=2446.02 (LR=2.9E-04) | R@10(Val)=0.5444 <- Best R@10! Saved.


Epoch 007/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 007: ReconLoss=0.6673 KL_Loss=2571.19 (LR=2.9E-04) | R@10(Val)=0.5564 <- Best R@10! Saved.


Epoch 008/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 008: ReconLoss=0.6087 KL_Loss=2687.05 (LR=2.9E-04) | R@10(Val)=0.5700 <- Best R@10! Saved.


Epoch 009/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 009: ReconLoss=0.5579 KL_Loss=2782.83 (LR=2.9E-04) | R@10(Val)=0.5904 <- Best R@10! Saved.


Epoch 010/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 010: ReconLoss=0.4849 KL_Loss=2887.66 (LR=2.8E-04) | R@10(Val)=0.6070 <- Best R@10! Saved.


Epoch 011/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 011: ReconLoss=0.4412 KL_Loss=2958.62 (LR=2.8E-04) | R@10(Val)=0.6080 <- Best R@10! Saved.


Epoch 012/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 012: ReconLoss=0.4140 KL_Loss=3003.93 (LR=2.8E-04) | R@10(Val)=0.6180 <- Best R@10! Saved.


Epoch 013/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 013: ReconLoss=0.3962 KL_Loss=3048.94 (LR=2.7E-04) | R@10(Val)=0.6226 <- Best R@10! Saved.


Epoch 014/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 014: ReconLoss=0.3735 KL_Loss=3078.92 (LR=2.7E-04) | R@10(Val)=0.6274 <- Best R@10! Saved.


Epoch 015/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 015: ReconLoss=0.3606 KL_Loss=3094.50 (LR=2.6E-04) | R@10(Val)=0.6344 <- Best R@10! Saved.


Epoch 016/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 016: ReconLoss=0.3473 KL_Loss=3130.90 (LR=2.6E-04) | R@10(Val)=0.6342 


Epoch 017/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 017: ReconLoss=0.3304 KL_Loss=3175.56 (LR=2.5E-04) | R@10(Val)=0.6418 <- Best R@10! Saved.


Epoch 018/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 018: ReconLoss=0.3177 KL_Loss=3214.86 (LR=2.4E-04) | R@10(Val)=0.6444 <- Best R@10! Saved.


Epoch 019/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 019: ReconLoss=0.2991 KL_Loss=3252.18 (LR=2.4E-04) | R@10(Val)=0.6442 


Epoch 020/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 020: ReconLoss=0.2796 KL_Loss=3291.52 (LR=2.3E-04) | R@10(Val)=0.6502 <- Best R@10! Saved.


Epoch 021/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 021: ReconLoss=0.2737 KL_Loss=3303.29 (LR=2.2E-04) | R@10(Val)=0.6528 <- Best R@10! Saved.


Epoch 022/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 022: ReconLoss=0.2640 KL_Loss=3336.04 (LR=2.2E-04) | R@10(Val)=0.6552 <- Best R@10! Saved.


Epoch 023/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 023: ReconLoss=0.2577 KL_Loss=3339.64 (LR=2.1E-04) | R@10(Val)=0.6548 


Epoch 024/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 024: ReconLoss=0.2517 KL_Loss=3372.88 (LR=2.0E-04) | R@10(Val)=0.6524 


Epoch 025/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 025: ReconLoss=0.2457 KL_Loss=3393.62 (LR=2.0E-04) | R@10(Val)=0.6576 <- Best R@10! Saved.


Epoch 026/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 026: ReconLoss=0.2420 KL_Loss=3384.71 (LR=1.9E-04) | R@10(Val)=0.6598 <- Best R@10! Saved.


Epoch 027/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 027: ReconLoss=0.2376 KL_Loss=3407.90 (LR=1.8E-04) | R@10(Val)=0.6596 


Epoch 028/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 028: ReconLoss=0.2312 KL_Loss=3441.66 (LR=1.7E-04) | R@10(Val)=0.6654 <- Best R@10! Saved.


Epoch 029/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 029: ReconLoss=0.2247 KL_Loss=3448.89 (LR=1.7E-04) | R@10(Val)=0.6632 


Epoch 030/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 030: ReconLoss=0.2274 KL_Loss=3428.00 (LR=1.6E-04) | R@10(Val)=0.6676 <- Best R@10! Saved.


Epoch 031/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 031: ReconLoss=0.2198 KL_Loss=3475.81 (LR=1.5E-04) | R@10(Val)=0.6658 


Epoch 032/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 032: ReconLoss=0.2169 KL_Loss=3480.25 (LR=1.4E-04) | R@10(Val)=0.6668 


Epoch 033/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 033: ReconLoss=0.2090 KL_Loss=3498.90 (LR=1.3E-04) | R@10(Val)=0.6696 <- Best R@10! Saved.


Epoch 034/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 034: ReconLoss=0.2105 KL_Loss=3483.60 (LR=1.3E-04) | R@10(Val)=0.6694 


Epoch 035/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 035: ReconLoss=0.2075 KL_Loss=3496.35 (LR=1.2E-04) | R@10(Val)=0.6682 


Epoch 036/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 036: ReconLoss=0.2046 KL_Loss=3513.15 (LR=1.1E-04) | R@10(Val)=0.6730 <- Best R@10! Saved.


Epoch 037/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 037: ReconLoss=0.2028 KL_Loss=3528.61 (LR=1.0E-04) | R@10(Val)=0.6734 <- Best R@10! Saved.


Epoch 038/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 038: ReconLoss=0.1988 KL_Loss=3527.93 (LR=9.6E-05) | R@10(Val)=0.6726 


Epoch 039/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 039: ReconLoss=0.1949 KL_Loss=3534.26 (LR=8.9E-05) | R@10(Val)=0.6752 <- Best R@10! Saved.


Epoch 040/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 040: ReconLoss=0.1956 KL_Loss=3529.61 (LR=8.2E-05) | R@10(Val)=0.6742 


Epoch 041/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 041: ReconLoss=0.1900 KL_Loss=3582.98 (LR=7.5E-05) | R@10(Val)=0.6772 <- Best R@10! Saved.


Epoch 042/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 042: ReconLoss=0.1915 KL_Loss=3555.14 (LR=6.8E-05) | R@10(Val)=0.6800 <- Best R@10! Saved.


Epoch 043/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 043: ReconLoss=0.1901 KL_Loss=3543.63 (LR=6.2E-05) | R@10(Val)=0.6766 


Epoch 044/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 044: ReconLoss=0.1906 KL_Loss=3576.53 (LR=5.6E-05) | R@10(Val)=0.6784 


Epoch 045/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 045: ReconLoss=0.1861 KL_Loss=3593.00 (LR=5.0E-05) | R@10(Val)=0.6778 


Epoch 046/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 046: ReconLoss=0.1838 KL_Loss=3571.09 (LR=4.4E-05) | R@10(Val)=0.6756 


Epoch 047/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 047: ReconLoss=0.1832 KL_Loss=3603.56 (LR=3.9E-05) | R@10(Val)=0.6762 


Epoch 048/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 048: ReconLoss=0.1842 KL_Loss=3578.60 (LR=3.3E-05) | R@10(Val)=0.6784 


Epoch 049/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 049: ReconLoss=0.1825 KL_Loss=3573.86 (LR=2.9E-05) | R@10(Val)=0.6788 


Epoch 050/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 050: ReconLoss=0.1811 KL_Loss=3590.11 (LR=2.4E-05) | R@10(Val)=0.6792 


Epoch 051/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 051: ReconLoss=0.1826 KL_Loss=3615.26 (LR=2.0E-05) | R@10(Val)=0.6790 


Epoch 052/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 052: ReconLoss=0.1795 KL_Loss=3601.33 (LR=1.6E-05) | R@10(Val)=0.6778 


Epoch 053/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 053: ReconLoss=0.1798 KL_Loss=3578.66 (LR=1.3E-05) | R@10(Val)=0.6780 


Epoch 054/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 054: ReconLoss=0.1777 KL_Loss=3602.27 (LR=1.0E-05) | R@10(Val)=0.6770 


Epoch 055/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 055: ReconLoss=0.1766 KL_Loss=3605.42 (LR=7.3E-06) | R@10(Val)=0.6770 


Epoch 056/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 056: ReconLoss=0.1794 KL_Loss=3597.35 (LR=5.1E-06) | R@10(Val)=0.6768 


Epoch 057/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 057: ReconLoss=0.1773 KL_Loss=3602.44 (LR=3.3E-06) | R@10(Val)=0.6766 


Epoch 058/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 058: ReconLoss=0.1757 KL_Loss=3600.17 (LR=1.8E-06) | R@10(Val)=0.6768 


Epoch 059/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 059: ReconLoss=0.1764 KL_Loss=3608.50 (LR=8.2E-07) | R@10(Val)=0.6768 


Epoch 060/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 060: ReconLoss=0.1780 KL_Loss=3614.43 (LR=2.1E-07) | R@10(Val)=0.6768 

🎯 Training complete. Best model (R@10=0.6800) saved as VAE_exp2_best.pth
Loading BEST model from VAE_exp2_best.pth for inference...
Initializing Giga-VAE (hidden_dim=8192, latent_dim=1536)
Best model loaded.


Generating Sub:   0%|          | 0/2 [00:00<?, ?it/s]


✅ submission_VAE_exp2.csv saved successfully.
Calculating final validation scores on full val set...


CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/13 [00:00<?, ?it/s]


--- Full Validation Results ---
Validation Recall@1:  0.3002  
Validation Recall@5:  0.5590  
Validation Recall@10: 0.6649  
Validation Recall@50: 0.8632  


# EXP3 - VAE with hidden_dim = 2048 and high Braker KL_WEIGHT=5e-5

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import os
import time

# --- EXPERIMENT 60: "The Giga-VAE" (Tighter Brakes) ---
#
# VAE_exp1 (KL=1e-5) scored 0.86241
# VAE_exp2 (KL=1e-6) scored 0.85748
# This means MORE regularization (a stronger brake) is BETTER.
#
# 1. Architecture: CHAMPION "Giga-VAE"
# 2. All "Giga-Boss" settings: hidden_dim=8192, DO=0.4, WD=4e-4, Mixup=0.6
# 3. Time:         60 Epochs
# 4. Loss Change:  *** KL_WEIGHT = 5e-5 *** (Was 1e-5. We are tightening the brake)
# 5. RUN_ID:       *** VAE_exp3 ***
# -----------------------------------------------------------------

# -----------------------------------------------------------------
# 1. Device and data loading
# -----------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if not os.path.exists("train.npz") or not os.path.exists("test.clean.npz"):
    print("ERROR: 'train.npz' or 'test.clean.npz' not found.")
    raise FileNotFoundError("Required data files not found.")
else:
    print("Found 'train.npz' and 'test.clean.npz'. Loading data...")

train_data = np.load("train.npz")
test_data  = np.load("test.clean.npz")

tx_train = train_data["captions/embeddings"]
im_train = train_data["images/embeddings"]
tx_test  = test_data["captions/embeddings"]

repeat_factor = len(tx_train) // len(im_train)
im_train_expanded = np.repeat(im_train, repeat_factor, axis=0)
print(f"Train shapes: {tx_train.shape}, {im_train.shape}, {im_train_expanded.shape}")
print(f"Test shape: {tx_test.shape}")

# -----------------------------------------------------------------
# 2. Data preprocessing
# -----------------------------------------------------------------
cpu_device = torch.device("cpu")
tx_train_t = torch.as_tensor(tx_train, dtype=torch.float32, device=cpu_device)
im_train_t_unique = torch.as_tensor(im_train, dtype=torch.float32, device=cpu_device)
tx_test_t  = torch.as_tensor(tx_test,  dtype=torch.float32, device=cpu_device)

tx_mean = tx_train_t.mean(0, keepdim=True)
im_mean = im_train_t_unique.mean(0, keepdim=True)

tx_train_t = F.normalize(tx_train_t - tx_mean, p=2, dim=1)
im_train_t_unique = F.normalize(im_train_t_unique - im_mean, p=2, dim=1)
tx_test_t = F.normalize(tx_test_t - tx_mean, p=2, dim=1)

im_train_exp = torch.as_tensor(im_train_expanded, dtype=torch.float32, device=cpu_device)
im_train_exp = F.normalize(im_train_exp - im_mean, p=2, dim=1)
print("Data preprocessed and normalized (on CPU).")

# -----------------------------------------------------------------
# 3. The Giga-VAE Model
# -----------------------------------------------------------------
class VAETranslator(nn.Module):
    def __init__(self, input_dim=1024, hidden_dim=8192, latent_dim=1536, output_dim=1536, dropout=0.4):
        super().__init__()
        self.latent_dim = latent_dim
        print(f"Initializing Giga-VAE (hidden_dim={hidden_dim}, latent_dim={latent_dim})")

        # --- ENCODER (Text -> Distribution) ---
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, latent_dim * 2) # Output mu and logvar
        )

        # --- DECODER (Distribution Sample -> Image) ---
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim)
        )

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu_logvar = self.encoder(x)
        mu = mu_logvar[..., :self.latent_dim]
        logvar = mu_logvar[..., self.latent_dim:]

        if self.training:
            z = self.reparameterize(mu, logvar)
            y_pred = self.decoder(z)
            return F.normalize(y_pred, p=2, dim=1), mu, logvar

        # During eval, use mu directly for a stable prediction
        y_pred_eval = self.decoder(mu)
        return F.normalize(y_pred_eval, p=2, dim=1)

# -----------------------------------------------------------------
# 4. Loss Functions
# -----------------------------------------------------------------
TAU = 0.01
MARGIN = 0.06
LOSS_WEIGHT_CONTRASTIVE = 0.7
LOSS_WEIGHT_TRIPLET = 0.3
MIXUP_ALPHA = 0.6 # Giga-Boss Mixup
KL_WEIGHT = 5e-5 # *** NEW: 5x STRONGER BRAKE ***

triplet_loss_fn = nn.TripletMarginWithDistanceLoss(
    distance_function=lambda x, y: 1 - F.cosine_similarity(x, y),
    margin=MARGIN
)
print(f"Using CHAMPION ASYMMETRIC loss + VAE KL Loss (weight={KL_WEIGHT})")

def calculate_kl_loss(mu, logvar):
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return kl_loss / mu.size(0)

# -----------------------------------------------------------------
# 5. Validation split
# -----------------------------------------------------------------
N = len(tx_train_t)
val_size = int(0.1 * N)
idx_cpu = torch.randperm(N, device="cpu") # <-- New random seed for this run
val_idx, train_idx = idx_cpu[:val_size], idx_cpu[val_size:]

img_indices_train = (train_idx // 5)
img_indices_val = (val_idx // 5)

tx_val_t, im_val_t = tx_train_t[val_idx], im_train_exp[val_idx]
tx_train_t_sub, im_train_exp_sub = tx_train_t[train_idx], im_train_exp[train_idx]

train_dataset = TensorDataset(tx_train_t_sub, im_train_exp_sub, img_indices_train)
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True, num_workers=0)
print(f"Train pairs: {len(train_dataset)}, Validation pairs: {len(val_idx)}")

# -----------------------------------------------------------------
# 6. Memory-Safe Recall@K Utility
# -----------------------------------------------------------------
@torch.no_grad()
def recall_at_k(model, tx_queries, im_database, query_img_indices, chunk_size=1024, ks=(1, 5, 10, 50), k_csls=10):
    model.eval()
    device = next(model.parameters()).device
    all_correct_in_top_k = {k: [] for k in ks}

    Q_embed_list = []
    for i in range(0, len(tx_queries), chunk_size):
        chunk = tx_queries[i:i+chunk_size].to(device)
        Q_embed_list.append(model(chunk))
    Q_embed = torch.cat(Q_embed_list, dim=0)

    D_embed = im_database.to(device)

    mean_knn_d_list = []
    for i in tqdm(range(0, len(D_embed), 512), desc="CSLS (Pass 1/2)", leave=False):
        sim_T_chunk = D_embed[i:i+512] @ Q_embed.T
        knn_d_chunk = torch.topk(sim_T_chunk, k=k_csls, dim=1).values
        mean_knn_d_list.append(knn_d_chunk.mean(1))

    mean_knn_d = torch.cat(mean_knn_d_list, dim=0).unsqueeze(0)

    gt = query_img_indices.to(device)

    for i in tqdm(range(0, len(Q_embed), chunk_size), desc="CSLS (Pass 2/2)", leave=False):
        sim_chunk = Q_embed[i:i+chunk_size] @ D_embed.T
        gt_chunk = gt[i:i+chunk_size]

        knn_q_chunk = torch.topk(sim_chunk, k=k_csls, dim=1).values
        mean_knn_q = knn_q_chunk.mean(1, keepdim=True)

        csls_sim_chunk = 2 * sim_chunk - mean_knn_q - mean_knn_d

        top_indices_chunk = torch.argsort(csls_sim_chunk, dim=1, descending=True)

        for k in ks:
            top_k_preds = top_indices_chunk[:, :k]
            correct_in_top_k = (top_k_preds == gt_chunk.unsqueeze(1)).any(dim=1)
            all_correct_in_top_k[k].append(correct_in_top_k)

    recalls = {}
    for k in ks:
        recalls[f"R@{k}"] = torch.cat(all_correct_in_top_k[k]).float().mean().item()
    return recalls

# -----------------------------------------------------------------
# 7. Training Loop
# -----------------------------------------------------------------
EPOCHS = 60
START_LR = 3e-4
WEIGHT_DECAY = 4e-4

# --- MODIFICATION: Your specific RUN_ID ---
RUN_ID = "VAE_exp3"
FINAL_SAVE_PATH = f"{RUN_ID}_best.pth"
# --- END MODIFICATION ---

val_query_subset = tx_val_t[:5000]
val_indices_subset = img_indices_val[:5000]
val_db_subset = im_train_t_unique

model = VAETranslator(dropout=0.4, hidden_dim=8192, latent_dim=1536, output_dim=1536).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=START_LR, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
scaler = torch.amp.GradScaler('cuda')

print(f"\n--- Starting Giga-VAE (Tighter Brakes, KL_WEIGHT=5e-5) (RUN {RUN_ID}) ---")
print(f"Training (Epochs: {EPOCHS}, LR={START_LR}, WD={WEIGHT_DECAY}, Mixup={MIXUP_ALPHA})...\n")

best_val_r10 = 0.0
for epoch in range(1, EPOCHS + 1):
    model.train()
    total_recon_loss = 0.0
    total_kl_loss = 0.0

    for x_batch, y_batch, img_indices in tqdm(train_loader, desc=f"Epoch {epoch:03d}/{EPOCHS}", leave=False):
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        idx_shuffle = torch.randperm(x_batch.size(0))
        lam = np.random.beta(MIXUP_ALPHA, MIXUP_ALPHA)
        x_mix = lam * x_batch + (1 - lam) * x_batch[idx_shuffle]
        y_mix = F.normalize(lam * y_batch + (1 - lam) * y_batch[idx_shuffle], p=2, dim=1)

        optimizer.zero_grad()
        with torch.amp.autocast('cuda'):
            y_pred, mu, logvar = model(x_mix)

            # --- 1. Reconstruction Loss ---
            sims_with_tau = y_pred @ y_mix.T / TAU
            labels = torch.arange(y_pred.size(0), device=device)
            loss_con = F.cross_entropy(sims_with_tau, labels)

            with torch.no_grad():
                sims_no_tau = y_pred @ y_mix.T
                positive_mask = torch.eye(y_batch.size(0), dtype=torch.bool, device=device)
                sims_no_tau.masked_fill_(positive_mask, -float('inf'))
                hard_neg_idx = sims_no_tau.argmax(dim=1)
            y_hard_neg = y_mix[hard_neg_idx]
            loss_tri = triplet_loss_fn(y_pred, y_mix, y_hard_neg)

            recon_loss = (LOSS_WEIGHT_CONTRASTIVE * loss_con) + (LOSS_WEIGHT_TRIPLET * loss_tri)

            # --- 2. KL Divergence Loss ---
            kl_loss = calculate_kl_loss(mu, logvar)

            # --- 3. Total Loss ---
            loss = recon_loss + (KL_WEIGHT * kl_loss) # KL_WEIGHT is now 5e-5

        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

        total_recon_loss += recon_loss.item()
        total_kl_loss += kl_loss.item()

    current_lr = scheduler.get_last_lr()[0]
    scheduler.step()
    avg_recon_loss = total_recon_loss / len(train_loader)
    avg_kl_loss = total_kl_loss / len(train_loader)

    # ----- Validation (quick subset) -----
    rec = recall_at_k(model, val_query_subset, val_db_subset, val_indices_subset, chunk_size=1024)
    current_r10 = rec['R@10']

    save_marker = ""
    if current_r10 > best_val_r10:
        best_val_r10 = current_r10
        torch.save(model.state_dict(), FINAL_SAVE_PATH)
        save_marker = "<- Best R@10! Saved."

    print(f"Epoch {epoch:03d}: ReconLoss={avg_recon_loss:.4f} KL_Loss={avg_kl_loss:.2f} (LR={current_lr:.1E}) | R@10(Val)={current_r10:.4f} {save_marker}")

print(f"\n🎯 Training complete. Best model (R@10={best_val_r10:.4f}) saved as {FINAL_SAVE_PATH}")

# -----------------------------------------------------------------
# 8. Inference + Submission
# -----------------------------------------------------------------
print(f"Loading BEST model from {FINAL_SAVE_PATH} for inference...")
model = VAETranslator(dropout=0.4, hidden_dim=8192, latent_dim=1536, output_dim=1536).to(device)
model.load_state_dict(torch.load(FINAL_SAVE_PATH))
model.eval()
print("Best model loaded.")

with torch.no_grad():
    preds_list = []
    for i in tqdm(range(0, len(tx_test_t), 1024), desc="Generating Sub"):
        chunk = tx_test_t[i:i+1024].to(device)
        preds_list.append(model(chunk))

    preds = torch.cat(preds_list, dim=0).cpu().numpy()

test_ids = test_data["captions/ids"].astype(int)
submission = pd.DataFrame({
    "id": test_ids,
    "embedding": [list(map(float, row)) for row in preds]
})

# --- MODIFICATION: Your specific submission filename ---
submission_filename = f"submission_{RUN_ID}.csv"
submission.to_csv(submission_filename, index=False)
print(f"\n✅ {submission_filename} saved successfully.")

# --- Final Validation (on FULL val set, now OOM-safe) ---
print("Calculating final validation scores on full val set...")
rec_full = recall_at_k(model, tx_val_t, im_train_t_unique, img_indices_val, chunk_size=1024)
print("\n--- Full Validation Results ---")
print(f"Validation Recall@1:  {rec_full['R@1']:<8.4f}")
print(f"Validation Recall@5:  {rec_full['R@5']:<8.4f}")
print(f"Validation Recall@10: {rec_full['R@10']:<8.4f}")
print(f"Validation Recall@50: {rec_full['R@50']:<8.4f}")

Using device: cuda
Found 'train.npz' and 'test.clean.npz'. Loading data...
Train shapes: (125000, 1024), (25000, 1536), (125000, 1536)
Test shape: (1500, 1024)
Data preprocessed and normalized (on CPU).
Using CHAMPION ASYMMETRIC loss + VAE KL Loss (weight=5e-05)
Train pairs: 112500, Validation pairs: 12500
Initializing Giga-VAE (hidden_dim=8192, latent_dim=1536)

--- Starting Giga-VAE (Tighter Brakes, KL_WEIGHT=5e-5) (RUN VAE_exp3) ---
Training (Epochs: 60, LR=0.0003, WD=0.0004, Mixup=0.6)...



Epoch 001/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 001: ReconLoss=2.6204 KL_Loss=1155.09 (LR=3.0E-04) | R@10(Val)=0.3298 <- Best R@10! Saved.


Epoch 002/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 002: ReconLoss=1.4413 KL_Loss=1465.44 (LR=3.0E-04) | R@10(Val)=0.4096 <- Best R@10! Saved.


Epoch 003/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 003: ReconLoss=1.1353 KL_Loss=1594.25 (LR=3.0E-04) | R@10(Val)=0.4654 <- Best R@10! Saved.


Epoch 004/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 004: ReconLoss=0.9582 KL_Loss=1643.55 (LR=3.0E-04) | R@10(Val)=0.4984 <- Best R@10! Saved.


Epoch 005/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 005: ReconLoss=0.8485 KL_Loss=1677.13 (LR=3.0E-04) | R@10(Val)=0.5228 <- Best R@10! Saved.


Epoch 006/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 006: ReconLoss=0.7750 KL_Loss=1704.61 (LR=2.9E-04) | R@10(Val)=0.5476 <- Best R@10! Saved.


Epoch 007/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 007: ReconLoss=0.6942 KL_Loss=1721.48 (LR=2.9E-04) | R@10(Val)=0.5582 <- Best R@10! Saved.


Epoch 008/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 008: ReconLoss=0.6439 KL_Loss=1737.78 (LR=2.9E-04) | R@10(Val)=0.5702 <- Best R@10! Saved.


Epoch 009/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 009: ReconLoss=0.5886 KL_Loss=1748.16 (LR=2.9E-04) | R@10(Val)=0.5782 <- Best R@10! Saved.


Epoch 010/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 010: ReconLoss=0.5174 KL_Loss=1744.39 (LR=2.8E-04) | R@10(Val)=0.5964 <- Best R@10! Saved.


Epoch 011/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 011: ReconLoss=0.4759 KL_Loss=1737.05 (LR=2.8E-04) | R@10(Val)=0.5992 <- Best R@10! Saved.


Epoch 012/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 012: ReconLoss=0.4492 KL_Loss=1722.90 (LR=2.8E-04) | R@10(Val)=0.6106 <- Best R@10! Saved.


Epoch 013/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 013: ReconLoss=0.4363 KL_Loss=1714.65 (LR=2.7E-04) | R@10(Val)=0.6128 <- Best R@10! Saved.


Epoch 014/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 014: ReconLoss=0.4180 KL_Loss=1716.27 (LR=2.7E-04) | R@10(Val)=0.6140 <- Best R@10! Saved.


Epoch 015/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 015: ReconLoss=0.3985 KL_Loss=1705.12 (LR=2.6E-04) | R@10(Val)=0.6184 <- Best R@10! Saved.


Epoch 016/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 016: ReconLoss=0.3855 KL_Loss=1696.46 (LR=2.6E-04) | R@10(Val)=0.6238 <- Best R@10! Saved.


Epoch 017/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 017: ReconLoss=0.3674 KL_Loss=1697.05 (LR=2.5E-04) | R@10(Val)=0.6278 <- Best R@10! Saved.


Epoch 018/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 018: ReconLoss=0.3563 KL_Loss=1685.36 (LR=2.4E-04) | R@10(Val)=0.6336 <- Best R@10! Saved.


Epoch 019/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 019: ReconLoss=0.3320 KL_Loss=1685.99 (LR=2.4E-04) | R@10(Val)=0.6392 <- Best R@10! Saved.


Epoch 020/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 020: ReconLoss=0.3220 KL_Loss=1674.81 (LR=2.3E-04) | R@10(Val)=0.6356 


Epoch 021/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 021: ReconLoss=0.3212 KL_Loss=1668.73 (LR=2.2E-04) | R@10(Val)=0.6410 <- Best R@10! Saved.


Epoch 022/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 022: ReconLoss=0.3115 KL_Loss=1666.97 (LR=2.2E-04) | R@10(Val)=0.6404 


Epoch 023/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 023: ReconLoss=0.3004 KL_Loss=1662.29 (LR=2.1E-04) | R@10(Val)=0.6426 <- Best R@10! Saved.


Epoch 024/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 024: ReconLoss=0.2962 KL_Loss=1652.44 (LR=2.0E-04) | R@10(Val)=0.6450 <- Best R@10! Saved.


Epoch 025/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 025: ReconLoss=0.2918 KL_Loss=1655.43 (LR=2.0E-04) | R@10(Val)=0.6428 


Epoch 026/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 026: ReconLoss=0.2881 KL_Loss=1644.55 (LR=1.9E-04) | R@10(Val)=0.6450 


Epoch 027/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 027: ReconLoss=0.2779 KL_Loss=1642.84 (LR=1.8E-04) | R@10(Val)=0.6432 


Epoch 028/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 028: ReconLoss=0.2771 KL_Loss=1643.42 (LR=1.7E-04) | R@10(Val)=0.6476 <- Best R@10! Saved.


Epoch 029/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 029: ReconLoss=0.2700 KL_Loss=1637.07 (LR=1.7E-04) | R@10(Val)=0.6504 <- Best R@10! Saved.


Epoch 030/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 030: ReconLoss=0.2713 KL_Loss=1628.19 (LR=1.6E-04) | R@10(Val)=0.6510 <- Best R@10! Saved.


Epoch 031/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 031: ReconLoss=0.2659 KL_Loss=1633.50 (LR=1.5E-04) | R@10(Val)=0.6532 <- Best R@10! Saved.


Epoch 032/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 032: ReconLoss=0.2576 KL_Loss=1635.62 (LR=1.4E-04) | R@10(Val)=0.6544 <- Best R@10! Saved.


Epoch 033/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 033: ReconLoss=0.2573 KL_Loss=1617.85 (LR=1.3E-04) | R@10(Val)=0.6552 <- Best R@10! Saved.


Epoch 034/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 034: ReconLoss=0.2508 KL_Loss=1622.60 (LR=1.3E-04) | R@10(Val)=0.6554 <- Best R@10! Saved.


Epoch 035/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 035: ReconLoss=0.2516 KL_Loss=1619.49 (LR=1.2E-04) | R@10(Val)=0.6572 <- Best R@10! Saved.


Epoch 036/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 036: ReconLoss=0.2484 KL_Loss=1616.13 (LR=1.1E-04) | R@10(Val)=0.6570 


Epoch 037/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 037: ReconLoss=0.2457 KL_Loss=1611.20 (LR=1.0E-04) | R@10(Val)=0.6566 


Epoch 038/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 038: ReconLoss=0.2403 KL_Loss=1617.05 (LR=9.6E-05) | R@10(Val)=0.6582 <- Best R@10! Saved.


Epoch 039/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 039: ReconLoss=0.2415 KL_Loss=1609.86 (LR=8.9E-05) | R@10(Val)=0.6606 <- Best R@10! Saved.


Epoch 040/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 040: ReconLoss=0.2354 KL_Loss=1604.92 (LR=8.2E-05) | R@10(Val)=0.6588 


Epoch 041/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 041: ReconLoss=0.2375 KL_Loss=1596.34 (LR=7.5E-05) | R@10(Val)=0.6586 


Epoch 042/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 042: ReconLoss=0.2322 KL_Loss=1604.68 (LR=6.8E-05) | R@10(Val)=0.6578 


Epoch 043/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 043: ReconLoss=0.2353 KL_Loss=1598.83 (LR=6.2E-05) | R@10(Val)=0.6578 


Epoch 044/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 044: ReconLoss=0.2313 KL_Loss=1602.85 (LR=5.6E-05) | R@10(Val)=0.6590 


Epoch 045/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 045: ReconLoss=0.2297 KL_Loss=1599.17 (LR=5.0E-05) | R@10(Val)=0.6602 


Epoch 046/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 046: ReconLoss=0.2296 KL_Loss=1595.31 (LR=4.4E-05) | R@10(Val)=0.6608 <- Best R@10! Saved.


Epoch 047/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 047: ReconLoss=0.2263 KL_Loss=1598.52 (LR=3.9E-05) | R@10(Val)=0.6606 


Epoch 048/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 048: ReconLoss=0.2274 KL_Loss=1594.54 (LR=3.3E-05) | R@10(Val)=0.6618 <- Best R@10! Saved.


Epoch 049/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 049: ReconLoss=0.2232 KL_Loss=1597.40 (LR=2.9E-05) | R@10(Val)=0.6622 <- Best R@10! Saved.


Epoch 050/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 050: ReconLoss=0.2248 KL_Loss=1591.71 (LR=2.4E-05) | R@10(Val)=0.6624 <- Best R@10! Saved.


Epoch 051/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 051: ReconLoss=0.2240 KL_Loss=1595.69 (LR=2.0E-05) | R@10(Val)=0.6618 


Epoch 052/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 052: ReconLoss=0.2251 KL_Loss=1590.20 (LR=1.6E-05) | R@10(Val)=0.6626 <- Best R@10! Saved.


Epoch 053/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 053: ReconLoss=0.2193 KL_Loss=1592.94 (LR=1.3E-05) | R@10(Val)=0.6626 


Epoch 054/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 054: ReconLoss=0.2197 KL_Loss=1593.82 (LR=1.0E-05) | R@10(Val)=0.6618 


Epoch 055/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 055: ReconLoss=0.2220 KL_Loss=1588.87 (LR=7.3E-06) | R@10(Val)=0.6620 


Epoch 056/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 056: ReconLoss=0.2222 KL_Loss=1588.96 (LR=5.1E-06) | R@10(Val)=0.6614 


Epoch 057/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 057: ReconLoss=0.2213 KL_Loss=1594.40 (LR=3.3E-06) | R@10(Val)=0.6610 


Epoch 058/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 058: ReconLoss=0.2272 KL_Loss=1584.31 (LR=1.8E-06) | R@10(Val)=0.6610 


Epoch 059/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 059: ReconLoss=0.2218 KL_Loss=1587.54 (LR=8.2E-07) | R@10(Val)=0.6608 


Epoch 060/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 060: ReconLoss=0.2193 KL_Loss=1592.41 (LR=2.1E-07) | R@10(Val)=0.6608 

🎯 Training complete. Best model (R@10=0.6626) saved as VAE_exp3_best.pth
Loading BEST model from VAE_exp3_best.pth for inference...
Initializing Giga-VAE (hidden_dim=8192, latent_dim=1536)
Best model loaded.


Generating Sub:   0%|          | 0/2 [00:00<?, ?it/s]


✅ submission_VAE_exp3.csv saved successfully.
Calculating final validation scores on full val set...


CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/13 [00:00<?, ?it/s]


--- Full Validation Results ---
Validation Recall@1:  0.3066  
Validation Recall@5:  0.5583  
Validation Recall@10: 0.6611  
Validation Recall@50: 0.8569  


## Fine tune to VAE, changing the latent_space from 1546 --> 2048

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import os
import time

# --- EXPERIMENT 61: "The Giga-VAE" (Bottleneck) ---
#
# VAE_exp1 (KL=1e-5, latent=1536) is our champion: 0.86241
# We've proven 1e-5 is the best KL_WEIGHT.
# Now we tune the next knob: latent_dim.
#
# 1. Architecture: CHAMPION "Giga-VAE"
# 2. All "Giga-Boss" settings: hidden_dim=8192, DO=0.4, WD=4e-4, Mixup=0.6
# 3. Champion "Brake": KL_WEIGHT = 1e-5
# 4. Loss Change:  *** latent_dim = 1024 *** (Was 1536. Creating a true bottleneck)
# 5. RUN_ID:       *** VAE_exp4 ***
# -----------------------------------------------------------------

# -----------------------------------------------------------------
# 1. Device and data loading
# -----------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if not os.path.exists("train.npz") or not os.path.exists("test.clean.npz"):
    print("ERROR: 'train.npz' or 'test.clean.npz' not found.")
    raise FileNotFoundError("Required data files not found.")
else:
    print("Found 'train.npz' and 'test.clean.npz'. Loading data...")

train_data = np.load("train.npz")
test_data  = np.load("test.clean.npz")

tx_train = train_data["captions/embeddings"]
im_train = train_data["images/embeddings"]
tx_test  = test_data["captions/embeddings"]

repeat_factor = len(tx_train) // len(im_train)
im_train_expanded = np.repeat(im_train, repeat_factor, axis=0)
print(f"Train shapes: {tx_train.shape}, {im_train.shape}, {im_train_expanded.shape}")
print(f"Test shape: {tx_test.shape}")

# -----------------------------------------------------------------
# 2. Data preprocessing
# -----------------------------------------------------------------
cpu_device = torch.device("cpu")
tx_train_t = torch.as_tensor(tx_train, dtype=torch.float32, device=cpu_device)
im_train_t_unique = torch.as_tensor(im_train, dtype=torch.float32, device=cpu_device)
tx_test_t  = torch.as_tensor(tx_test,  dtype=torch.float32, device=cpu_device)

tx_mean = tx_train_t.mean(0, keepdim=True)
im_mean = im_train_t_unique.mean(0, keepdim=True)

tx_train_t = F.normalize(tx_train_t - tx_mean, p=2, dim=1)
im_train_t_unique = F.normalize(im_train_t_unique - im_mean, p=2, dim=1)
tx_test_t = F.normalize(tx_test_t - tx_mean, p=2, dim=1)

im_train_exp = torch.as_tensor(im_train_expanded, dtype=torch.float32, device=cpu_device)
im_train_exp = F.normalize(im_train_exp - im_mean, p=2, dim=1)
print("Data preprocessed and normalized (on CPU).")

# -----------------------------------------------------------------
# 3. The Giga-VAE Model (*** MODIFIED latent_dim ***)
# -----------------------------------------------------------------
class VAETranslator(nn.Module):
    # --- MODIFIED: latent_dim is now an argument ---
    def __init__(self, input_dim=1024, hidden_dim=8192, latent_dim=1024, output_dim=1536, dropout=0.4):
        super().__init__()
        self.latent_dim = latent_dim
        print(f"Initializing Giga-VAE (hidden_dim={hidden_dim}, latent_dim={latent_dim})")

        # --- ENCODER (Text -> Distribution) ---
        # 1024 -> 8192 -> 8192 -> (1024 * 2)
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, latent_dim * 2) # Output mu and logvar
        )

        # --- DECODER (Distribution Sample -> Image) ---
        # 1024 -> 8192 -> 8192 -> 1536
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim)
        )

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu_logvar = self.encoder(x)
        mu = mu_logvar[..., :self.latent_dim]
        logvar = mu_logvar[..., self.latent_dim:]

        if self.training:
            z = self.reparameterize(mu, logvar)
            y_pred = self.decoder(z)
            return F.normalize(y_pred, p=2, dim=1), mu, logvar

        y_pred_eval = self.decoder(mu)
        return F.normalize(y_pred_eval, p=2, dim=1)

# -----------------------------------------------------------------
# 4. Loss Functions
# -----------------------------------------------------------------
TAU = 0.01
MARGIN = 0.06
LOSS_WEIGHT_CONTRASTIVE = 0.7
LOSS_WEIGHT_TRIPLET = 0.3
MIXUP_ALPHA = 0.6
KL_WEIGHT = 1e-5 # *** CHAMPION "BRAKE" SETTING ***

triplet_loss_fn = nn.TripletMarginWithDistanceLoss(
    distance_function=lambda x, y: 1 - F.cosine_similarity(x, y),
    margin=MARGIN
)
print(f"Using CHAMPION ASYMMETRIC loss + VAE KL Loss (weight={KL_WEIGHT})")

def calculate_kl_loss(mu, logvar):
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return kl_loss / mu.size(0)

# -----------------------------------------------------------------
# 5. Validation split
# -----------------------------------------------------------------
N = len(tx_train_t)
val_size = int(0.1 * N)
idx_cpu = torch.randperm(N, device="cpu")
val_idx, train_idx = idx_cpu[:val_size], idx_cpu[val_size:]

img_indices_train = (train_idx // 5)
img_indices_val = (val_idx // 5)

tx_val_t, im_val_t = tx_train_t[val_idx], im_train_exp[val_idx]
tx_train_t_sub, im_train_exp_sub = tx_train_t[train_idx], im_train_exp[train_idx]

train_dataset = TensorDataset(tx_train_t_sub, im_train_exp_sub, img_indices_train)
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True, num_workers=0)
print(f"Train pairs: {len(train_dataset)}, Validation pairs: {len(val_idx)}")

# -----------------------------------------------------------------
# 6. Memory-Safe Recall@K Utility
# -----------------------------------------------------------------
@torch.no_grad()
def recall_at_k(model, tx_queries, im_database, query_img_indices, chunk_size=1024, ks=(1, 5, 10, 50), k_csls=10):
    model.eval()
    device = next(model.parameters()).device
    all_correct_in_top_k = {k: [] for k in ks}

    Q_embed_list = []
    for i in range(0, len(tx_queries), chunk_size):
        chunk = tx_queries[i:i+chunk_size].to(device)
        Q_embed_list.append(model(chunk))
    Q_embed = torch.cat(Q_embed_list, dim=0)

    D_embed = im_database.to(device)

    mean_knn_d_list = []
    for i in tqdm(range(0, len(D_embed), 512), desc="CSLS (Pass 1/2)", leave=False):
        sim_T_chunk = D_embed[i:i+512] @ Q_embed.T
        knn_d_chunk = torch.topk(sim_T_chunk, k=k_csls, dim=1).values
        mean_knn_d_list.append(knn_d_chunk.mean(1))

    mean_knn_d = torch.cat(mean_knn_d_list, dim=0).unsqueeze(0)

    gt = query_img_indices.to(device)

    for i in tqdm(range(0, len(Q_embed), chunk_size), desc="CSLS (Pass 2/2)", leave=False):
        sim_chunk = Q_embed[i:i+chunk_size] @ D_embed.T
        gt_chunk = gt[i:i+chunk_size]

        knn_q_chunk = torch.topk(sim_chunk, k=k_csls, dim=1).values
        mean_knn_q = knn_q_chunk.mean(1, keepdim=True)

        csls_sim_chunk = 2 * sim_chunk - mean_knn_q - mean_knn_d

        top_indices_chunk = torch.argsort(csls_sim_chunk, dim=1, descending=True)

        for k in ks:
            top_k_preds = top_indices_chunk[:, :k]
            correct_in_top_k = (top_k_preds == gt_chunk.unsqueeze(1)).any(dim=1)
            all_correct_in_top_k[k].append(correct_in_top_k)

    recalls = {}
    for k in ks:
        recalls[f"R@{k}"] = torch.cat(all_correct_in_top_k[k]).float().mean().item()
    return recalls

# -----------------------------------------------------------------
# 7. Training Loop
# -----------------------------------------------------------------
EPOCHS = 60
START_LR = 3e-4
WEIGHT_DECAY = 4e-4

# --- MODIFICATION: Your specific RUN_ID ---
RUN_ID = "VAE_exp4"
FINAL_SAVE_PATH = f"{RUN_ID}_best.pth"
# --- END MODIFICATION ---

val_query_subset = tx_val_t[:5000]
val_indices_subset = img_indices_val[:5000]
val_db_subset = im_train_t_unique

# --- MODIFICATION: Initialize model with latent_dim=1024 ---
model = VAETranslator(dropout=0.4, hidden_dim=8192, latent_dim=1024, output_dim=1536).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=START_LR, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
scaler = torch.amp.GradScaler('cuda')

print(f"\n--- Starting Giga-VAE (Bottleneck, latent_dim=1024) (RUN {RUN_ID}) ---")
print(f"Training (Epochs: {EPOCHS}, LR={START_LR}, WD={WEIGHT_DECAY}, Mixup={MIXUP_ALPHA})...\n")

best_val_r10 = 0.0
for epoch in range(1, EPOCHS + 1):
    model.train()
    total_recon_loss = 0.0
    total_kl_loss = 0.0

    for x_batch, y_batch, img_indices in tqdm(train_loader, desc=f"Epoch {epoch:03d}/{EPOCHS}", leave=False):
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        idx_shuffle = torch.randperm(x_batch.size(0))
        lam = np.random.beta(MIXUP_ALPHA, MIXUP_ALPHA)
        x_mix = lam * x_batch + (1 - lam) * x_batch[idx_shuffle]
        y_mix = F.normalize(lam * y_batch + (1 - lam) * y_batch[idx_shuffle], p=2, dim=1)

        optimizer.zero_grad()
        with torch.amp.autocast('cuda'):
            y_pred, mu, logvar = model(x_mix)

            # --- 1. Reconstruction Loss ---
            sims_with_tau = y_pred @ y_mix.T / TAU
            labels = torch.arange(y_pred.size(0), device=device)
            loss_con = F.cross_entropy(sims_with_tau, labels)

            with torch.no_grad():
                sims_no_tau = y_pred @ y_mix.T
                positive_mask = torch.eye(y_batch.size(0), dtype=torch.bool, device=device)
                sims_no_tau.masked_fill_(positive_mask, -float('inf'))
                hard_neg_idx = sims_no_tau.argmax(dim=1)
            y_hard_neg = y_mix[hard_neg_idx]
            loss_tri = triplet_loss_fn(y_pred, y_mix, y_hard_neg)

            recon_loss = (LOSS_WEIGHT_CONTRASTIVE * loss_con) + (LOSS_WEIGHT_TRIPLET * loss_tri)

            # --- 2. KL Divergence Loss ---
            kl_loss = calculate_kl_loss(mu, logvar)

            # --- 3. Total Loss ---
            loss = recon_loss + (KL_WEIGHT * kl_loss) # KL_WEIGHT is 1e-5

        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

        total_recon_loss += recon_loss.item()
        total_kl_loss += kl_loss.item()

    current_lr = scheduler.get_last_lr()[0]
    scheduler.step()
    avg_recon_loss = total_recon_loss / len(train_loader)
    avg_kl_loss = total_kl_loss / len(train_loader)

    # ----- Validation (quick subset) -----
    rec = recall_at_k(model, val_query_subset, val_db_subset, val_indices_subset, chunk_size=1024)
    current_r10 = rec['R@10']

    save_marker = ""
    if current_r10 > best_val_r10:
        best_val_r10 = current_r10
        torch.save(model.state_dict(), FINAL_SAVE_PATH)
        save_marker = "<- Best R@10! Saved."

    print(f"Epoch {epoch:03d}: ReconLoss={avg_recon_loss:.4f} KL_Loss={avg_kl_loss:.2f} (LR={current_lr:.1E}) | R@10(Val)={current_r10:.4f} {save_marker}")

print(f"\n🎯 Training complete. Best model (R@10={best_val_r10:.4f}) saved as {FINAL_SAVE_PATH}")

# -----------------------------------------------------------------
# 8. Inference + Submission
# -----------------------------------------------------------------
print(f"Loading BEST model from {FINAL_SAVE_PATH} for inference...")
# --- MODIFICATION: Load model with latent_dim=1024 ---
model = VAETranslator(dropout=0.4, hidden_dim=8192, latent_dim=1024, output_dim=1536).to(device)
model.load_state_dict(torch.load(FINAL_SAVE_PATH))
model.eval()
print("Best model loaded.")

with torch.no_grad():
    preds_list = []
    for i in tqdm(range(0, len(tx_test_t), 1024), desc="Generating Sub"):
        chunk = tx_test_t[i:i+1024].to(device)
        preds_list.append(model(chunk))

    preds = torch.cat(preds_list, dim=0).cpu().numpy()

test_ids = test_data["captions/ids"].astype(int)
submission = pd.DataFrame({
    "id": test_ids,
    "embedding": [list(map(float, row)) for row in preds]
})

submission_filename = f"submission_{RUN_ID}.csv"
submission.to_csv(submission_filename, index=False)
print(f"\n✅ {submission_filename} saved successfully.")

# --- Final Validation (on FULL val set, now OOM-safe) ---
print("Calculating final validation scores on full val set...")
rec_full = recall_at_k(model, tx_val_t, im_train_t_unique, img_indices_val, chunk_size=1024)
print("\n--- Full Validation Results ---")
print(f"Validation Recall@1:  {rec_full['R@1']:<8.4f}")
print(f"Validation Recall@5:  {rec_full['R@5']:<8.4f}")
print(f"Validation Recall@10: {rec_full['R@10']:<8.4f}")
print(f"Validation Recall@50: {rec_full['R@50']:<8.4f}")