In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from sklearn.linear_model import Ridge
from scipy.linalg import orthogonal_procrustes
import time

# === T4x2 CHANGE ===
torch.backends.cudnn.benchmark = True  # autotune kernels
print(f"CUDA devices: {torch.cuda.device_count()}")

CUDA devices: 2


In [2]:
train_data = np.load("/kaggle/input/aml-competition/train/train/train.npz")
test_data  = np.load("/kaggle/input/aml-competition/test/test/test.clean.npz")

Retraining the 'best model'
Issue: since the seed was not set while training, we couldn't reachieve the best results, other splits happened to be unlucky apparently

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import os
import time
from sklearn.linear_model import Ridge
from scipy.linalg import orthogonal_procrustes

# --- EXPERIMENT 58: "The Giga-VAE" ---
#
# This is a new probabilistic architecture inspired by the DALL-E 2 /
# CLIP Latents "prior" models.
# It learns a probabilistic mapping (A -> Distribution -> B) to
# handle the 1-to-5 text-to-image noise.
#
# 1. Architecture: NEW "VAETranslator"
# 2. Hidden Dim:   *** 8192 *** (Champion Horsepower)
# 3. Regularization: *** "GIGA-BOSS" ***
#    - Dropout:       0.4
#    - Weight Decay:  4e-4
#    - Mixup:         warmup then 0.3
# 4. Time:         60 Epochs (Champion Sprint)
# 5. Loss:         CHAMPION Asymmetric Loss + *** KL Divergence Loss ***
#    + Symmetric InfoNCE + Learnable Temperature
# -----------------------------------------------------------------

# -----------------------------------------------------------------
# 1. Device and data loading (Same as before)
# -----------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

tx_train = train_data["captions/embeddings"]
im_train = train_data["images/embeddings"]
tx_test  = test_data["captions/embeddings"]

repeat_factor = len(tx_train) // len(im_train)
im_train_expanded = np.repeat(im_train, repeat_factor, axis=0)
print(f"Train shapes: {tx_train.shape}, {im_train.shape}, {im_train_expanded.shape}")
print(f"Test shape: {tx_test.shape}")

# -----------------------------------------------------------------
# 2. Data preprocessing (Same as before)
# -----------------------------------------------------------------
cpu_device = torch.device("cpu")
tx_train_t = torch.as_tensor(tx_train, dtype=torch.float32, device=cpu_device)
im_train_t_unique = torch.as_tensor(im_train, dtype=torch.float32, device=cpu_device)
tx_test_t  = torch.as_tensor(tx_test,  dtype=torch.float32, device=cpu_device)

tx_mean = tx_train_t.mean(0, keepdim=True)
im_mean = im_train_t_unique.mean(0, keepdim=True)

tx_train_t = F.normalize(tx_train_t - tx_mean, p=2, dim=1)
im_train_t_unique = F.normalize(im_train_t_unique - im_mean, p=2, dim=1)
tx_test_t = F.normalize(tx_test_t - tx_mean, p=2, dim=1)

im_train_exp = torch.as_tensor(im_train_expanded, dtype=torch.float32, device=cpu_device)
im_train_exp = F.normalize(im_train_exp - im_mean, p=2, dim=1)
print("Data preprocessed and normalized (on CPU).")

# -----------------------------------------------------------------
# 3. *** NEW: The VAE-Translator Model ***
# -----------------------------------------------------------------
class VAETranslator(nn.Module):
    def __init__(self, input_dim=1024, hidden_dim=8192, latent_dim=1536, output_dim=1536, dropout=0.4):
        super().__init__()
        self.latent_dim = latent_dim
        print(f"Initializing Giga-VAE (hidden_dim={hidden_dim}, latent_dim={latent_dim})")

        # --- ENCODER (Text -> Distribution) ---
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),

            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),

            nn.Linear(hidden_dim, latent_dim * 2) # Output mu and logvar
        )

        # --- DECODER (Distribution Sample -> Image) ---
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),

            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),

            nn.Linear(hidden_dim, output_dim)
        )

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        # 1. Encode text to mu and logvar
        mu_logvar = self.encoder(x)
        mu = mu_logvar[..., :self.latent_dim]
        logvar = mu_logvar[..., self.latent_dim:]

        # 2. Get a sample 'z' from the distribution
        z = self.reparameterize(mu, logvar)

        # 3. Decode the sample to the image embedding space
        y_pred = self.decoder(z)

        # During training, return all parts for the loss function
        if self.training:
            return F.normalize(y_pred, p=2, dim=1), mu, logvar

        # During eval, just return the deterministic prediction
        # We use mu directly for a stable prediction
        y_pred_eval = self.decoder(mu)
        return F.normalize(y_pred_eval, p=2, dim=1)

# -----------------------------------------------------------------
# 4. Loss Functions (CHAMPION Loss + KL Divergence) + NEW flags
# -----------------------------------------------------------------
TAU = 0.01   # kept for initialization only
MARGIN = 0.06
LOSS_WEIGHT_CONTRASTIVE = 0.7
LOSS_WEIGHT_TRIPLET = 0.3
MIXUP_ALPHA = 0.3          # gentler mixup after warmup
KL_WEIGHT = 1e-5           # VAE Regularization Strength
SYMMETRIC_INFO_NCE = True  # new
MIXUP_WARMUP_EPOCHS = 3    # new: no mixup for first epochs

triplet_loss_fn = nn.TripletMarginWithDistanceLoss(
    distance_function=lambda x, y: 1 - F.cosine_similarity(x, y),
    margin=MARGIN
)
print(f"Using CHAMPION ASYMMETRIC loss + *** VAE KL Loss (weight={KL_WEIGHT}) ***")

def calculate_kl_loss(mu, logvar):
    # KL divergence, VAE formula
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Return the *mean* loss per item in the batch
    return kl_loss / mu.size(0)

# -----------------------------------------------------------------
# 5. Validation split (Same as before)
# -----------------------------------------------------------------
N = len(tx_train_t)
val_size = int(0.1 * N)
idx_cpu = torch.randperm(N, device="cpu") # <-- New random seed for this run
val_idx, train_idx = idx_cpu[:val_size], idx_cpu[val_size:]

img_indices_train = (train_idx // 5)
img_indices_val = (val_idx // 5)

tx_val_t, im_val_t = tx_train_t[val_idx], im_train_exp[val_idx]
tx_train_t_sub, im_train_exp_sub = tx_train_t[train_idx], im_train_exp[train_idx]

train_dataset = TensorDataset(tx_train_t_sub, im_train_exp_sub, img_indices_train)
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True, num_workers=0)
print(f"Train pairs: {len(train_dataset)}, Validation pairs: {len(val_idx)}")

# -----------------------------------------------------------------
# 6. Memory-Safe Recall@K Utility (Same as before)
# -----------------------------------------------------------------
@torch.no_grad()
def recall_at_k(model, tx_queries, im_database, query_img_indices, chunk_size=1024, ks=(1, 5, 10, 50), k_csls=10):
    model.eval() # <-- CRITICAL. Return deterministic preds
    device_m = next(model.parameters()).device
    all_correct_in_top_k = {k: [] for k in ks}

    Q_embed_list = []
    for i in range(0, len(tx_queries), chunk_size):
        chunk = tx_queries[i:i+chunk_size].to(device_m)
        Q_embed_list.append(model(chunk)) # model(chunk) returns y_pred_eval
    Q_embed = torch.cat(Q_embed_list, dim=0)

    D_embed = im_database.to(device_m)

    mean_knn_d_list = []
    for i in tqdm(range(0, len(D_embed), 512), desc="CSLS (Pass 1/2)", leave=False):
        sim_T_chunk = D_embed[i:i+512] @ Q_embed.T
        knn_d_chunk = torch.topk(sim_T_chunk, k=k_csls, dim=1).values
        mean_knn_d_list.append(knn_d_chunk.mean(1))

    mean_knn_d = torch.cat(mean_knn_d_list, dim=0).unsqueeze(0)

    gt = query_img_indices.to(device_m)

    for i in tqdm(range(0, len(Q_embed), chunk_size), desc="CSLS (Pass 2/2)", leave=False):
        sim_chunk = Q_embed[i:i+chunk_size] @ D_embed.T
        gt_chunk = gt[i:i+chunk_size]

        knn_q_chunk = torch.topk(sim_chunk, k=k_csls, dim=1).values
        mean_knn_q = knn_q_chunk.mean(1, keepdim=True)

        csls_sim_chunk = 2 * sim_chunk - mean_knn_q - mean_knn_d

        top_indices_chunk = torch.argsort(csls_sim_chunk, dim=1, descending=True)

        for k in ks:
            top_k_preds = top_indices_chunk[:, :k]
            correct_in_top_k = (top_k_preds == gt_chunk.unsqueeze(1)).any(dim=1)
            all_correct_in_top_k[k].append(correct_in_top_k)

    recalls = {}
    for k in ks:
        recalls[f"R@{k}"] = torch.cat(all_correct_in_top_k[k]).float().mean().item()
    return recalls

# -----------------------------------------------------------------
# DIAGNOSTIC TOOLS (unchanged logic; kept simple for Kaggle)
# -----------------------------------------------------------------
class DiagnosticMetrics:
    """Comprehensive diagnostic metrics for embedding translation"""

    @staticmethod
    @torch.no_grad()
    def emb2emb_linear_transfer(src_emb, tgt_emb, model_pred_emb=None, alpha=1.0):
        # CPU-safe
        src_np = (src_emb.detach().cpu().numpy()
                  if isinstance(src_emb, torch.Tensor) else src_emb)
        tgt_np = (tgt_emb.detach().cpu().numpy()
                  if isinstance(tgt_emb, torch.Tensor) else tgt_emb)

        ridge = Ridge(alpha=alpha, fit_intercept=False)
        ridge.fit(src_np, tgt_np)

        pred = ridge.predict(src_np)
        pred_t = torch.from_numpy(pred).float()

        if not isinstance(tgt_emb, torch.Tensor):
            tgt_emb = torch.from_numpy(tgt_np).float()
        tgt_t = tgt_emb.float()

        pred_t = F.normalize(pred_t, p=2, dim=1)
        tgt_norm = F.normalize(tgt_t, p=2, dim=1)

        transfer_mse = F.mse_loss(pred_t, tgt_norm).item()
        transfer_cosine = F.cosine_similarity(pred_t, tgt_norm).mean().item()

        if model_pred_emb is not None:
            if not isinstance(model_pred_emb, torch.Tensor):
                model_pred_emb = torch.from_numpy(model_pred_emb).float()
            model_cos = F.cosine_similarity(
                F.normalize(model_pred_emb.float(), p=2, dim=1), tgt_norm
            ).mean().item()
            improvement = transfer_cosine - model_cos
            baseline_name = 'model_cosine'
        else:
            model_cos = 0.0
            improvement = transfer_cosine
            baseline_name = 'random_baseline'

        return {
            'transfer_mse': transfer_mse,
            'transfer_cosine': transfer_cosine,
            baseline_name: model_cos,
            'improvement': improvement
        }

    @staticmethod
    @torch.no_grad()
    def stitching_penalty_monitor(src_emb, tgt_emb):
        # Rectangular Procrustes via SVD (works when dims differ)
        src_np = src_emb.detach().cpu().numpy()
        tgt_np = tgt_emb.detach().cpu().numpy()

        C = src_np.T @ tgt_np  # (d_src x d_tgt)
        U, _, Vt = np.linalg.svd(C, full_matrices=False)
        R_rect = U @ Vt  # (d_src x d_tgt)
        pred_ortho = src_np @ R_rect

        ridge = Ridge(alpha=0.01, fit_intercept=False)
        ridge.fit(src_np, tgt_np)
        pred_full = ridge.predict(src_np)

        mse_ortho = np.mean((pred_ortho - tgt_np) ** 2)
        mse_full  = np.mean((pred_full  - tgt_np) ** 2)

        pred_ortho_t = F.normalize(torch.from_numpy(pred_ortho).float(), p=2, dim=1)
        pred_full_t  = F.normalize(torch.from_numpy(pred_full ).float(), p=2, dim=1)
        tgt_norm     = F.normalize(torch.from_numpy(tgt_np   ).float(), p=2, dim=1)

        cos_ortho = F.cosine_similarity(pred_ortho_t, tgt_norm).mean().item()
        cos_full  = F.cosine_similarity(pred_full_t,  tgt_norm).mean().item()

        gap_pct = 100 * (mse_ortho - mse_full) / (mse_full + 1e-8)

        return {
            'mse_orthogonal': float(mse_ortho),
            'mse_full': float(mse_full),
            'gap_percent': float(gap_pct),
            'cos_orthogonal': cos_ortho,
            'cos_full': cos_full
        }

    @staticmethod
    @torch.no_grad()
    def mutual_knn_alignment(emb1, emb2, k=10):
        sim1 = emb1 @ emb1.T
        sim2 = emb2 @ emb2.T
        _, knn1 = torch.topk(sim1, k=k+1, dim=1)
        _, knn2 = torch.topk(sim2, k=k+1, dim=1)
        knn1 = knn1[:, 1:]
        knn2 = knn2[:, 1:]

        mutual_count = 0
        for i in range(len(emb1)):
            set1 = set(knn1[i].cpu().tolist())
            set2 = set(knn2[i].cpu().tolist())
            mutual_count += len(set1 & set2)
        return mutual_count / (len(emb1) * k)

# -----------------------------------------------------------------
# 7. *** NEW: Training Loop ***
# -----------------------------------------------------------------
EPOCHS = 60 # Champion "Sprint"
START_LR = 3e-4
WEIGHT_DECAY = 4e-4 # Giga-Boss Weight Decay
RUN_ID = f"giga_vae_model_{int(time.time())}"
FINAL_SAVE_PATH = f"giga_vae_model_{RUN_ID}_best.pth"

val_query_subset = tx_val_t[:5000]
val_indices_subset = img_indices_val[:5000]
val_db_subset = im_train_t_unique

# --- Initialize the new Giga-VAE model ---
model = VAETranslator(dropout=0.4, hidden_dim=8192, latent_dim=1536, output_dim=1536).to(device)

# Learnable temperature (CLIP-style), initialized from TAU
logit_scale = nn.Parameter(torch.log(torch.tensor(1.0 / TAU, device=device)))

optimizer = torch.optim.AdamW(
    list(model.parameters()) + [logit_scale],
    lr=START_LR, weight_decay=WEIGHT_DECAY
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
scaler = torch.amp.GradScaler('cuda')

print(f"\n--- Starting New Architecture: Giga-VAE (RUN {RUN_ID}) ---")
print(f"Training (Epochs: {EPOCHS}, LR={START_LR}, WD={WEIGHT_DECAY}, Mixup warmup {MIXUP_WARMUP_EPOCHS}â†’Î±={MIXUP_ALPHA})...\n")

best_val_r10 = 0.0
for epoch in range(1, EPOCHS + 1):
    model.train() # Return all 3 outputs
    total_recon_loss = 0.0
    total_kl_loss = 0.0

    for x_batch, y_batch, img_indices in tqdm(train_loader, desc=f"Epoch {epoch:03d}/{EPOCHS}", leave=False):
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        # --- Mixup warmup ---
        use_mixup = (epoch > MIXUP_WARMUP_EPOCHS)
        if use_mixup:
            idx_shuffle = torch.randperm(x_batch.size(0), device=x_batch.device)
            lam = np.random.beta(MIXUP_ALPHA, MIXUP_ALPHA)
            x_mix = lam * x_batch + (1 - lam) * x_batch[idx_shuffle]
            y_mix = F.normalize(lam * y_batch + (1 - lam) * y_batch[idx_shuffle], p=2, dim=1)
        else:
            x_mix = x_batch
            y_mix = y_batch

        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast('cuda'):
            # --- VAE FORWARD PASS ---
            y_pred, mu, logvar = model(x_mix) # Model is in .train() mode

            # --- Symmetric InfoNCE with learnable temperature ---
            scale = logit_scale.exp().clamp(max=torch.tensor(100.0, device=logit_scale.device))
            logits = (y_pred @ y_mix.T) * scale

            labels = torch.arange(y_pred.size(0), device=device)
            loss_i2t = F.cross_entropy(logits, labels)
            if SYMMETRIC_INFO_NCE:
                loss_t2i = F.cross_entropy(logits.T, labels)
                loss_con = 0.5 * (loss_i2t + loss_t2i)
            else:
                loss_con = loss_i2t

            # --- Hard negative triplet (unchanged logic) ---
            with torch.no_grad():
                sims_no_tau = y_pred @ y_mix.T
                positive_mask = torch.eye(y_batch.size(0), dtype=torch.bool, device=device)
                sims_no_tau.masked_fill_(positive_mask, -float('inf'))
                hard_neg_idx = sims_no_tau.argmax(dim=1)
            y_hard_neg = y_mix[hard_neg_idx]
            loss_tri = triplet_loss_fn(y_pred, y_mix, y_hard_neg)

            recon_loss = (LOSS_WEIGHT_CONTRASTIVE * loss_con) + (LOSS_WEIGHT_TRIPLET * loss_tri)

            # --- KL Divergence Loss ---
            kl_loss = calculate_kl_loss(mu, logvar)

            # --- Total Loss ---
            loss = recon_loss + (KL_WEIGHT * kl_loss)

        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

        total_recon_loss += recon_loss.item()
        total_kl_loss += kl_loss.item()

    current_lr = scheduler.get_last_lr()[0]
    scheduler.step()
    avg_recon_loss = total_recon_loss / len(train_loader)
    avg_kl_loss = total_kl_loss / len(train_loader)
    current_scale = float(logit_scale.exp().clamp(max=100).item())

    # ----- Validation (quick subset) -----
    rec = recall_at_k(model, val_query_subset, val_db_subset, val_indices_subset, chunk_size=1024)
    current_r10 = rec['R@10']

    save_marker = ""
    if current_r10 > best_val_r10:
        best_val_r10 = current_r10
        torch.save(model.state_dict(), FINAL_SAVE_PATH)
        save_marker = "<- Best R@10! Saved."

    print(f"Epoch {epoch:03d}: ReconLoss={avg_recon_loss:.4f} KL_Loss={avg_kl_loss:.2f} "
          f"(LR={current_lr:.1E}) | R@10(Val)={current_r10:.4f} | temp_scale={current_scale:.2f} {save_marker}")

    # ----- Diagnostics (every 5 epochs + first) -----
    if epoch == 1 or (epoch % 5 == 0):
        with torch.no_grad():
            model.eval()
            diag_n = min(2000, len(val_query_subset))
            tx_sample = val_query_subset[:diag_n].to(device)
            # Map each query to its GT image embedding for diagnostics
            tgt_sample = val_db_subset[val_indices_subset[:diag_n]].to(device)
            pred_sample = model(tx_sample)

            # Move to CPU for diagnostics (CPU-safe ops)
            tx_cpu = tx_sample.detach().cpu()
            tgt_cpu = tgt_sample.detach().cpu()
            pred_cpu = pred_sample.detach().cpu()

            emb2emb = DiagnosticMetrics.emb2emb_linear_transfer(tx_cpu, tgt_cpu, model_pred_emb=pred_cpu, alpha=1.0)
            stitch = DiagnosticMetrics.stitching_penalty_monitor(tx_cpu, tgt_cpu)
            mk_align = DiagnosticMetrics.mutual_knn_alignment(
                F.normalize(tx_cpu, p=2, dim=1),
                F.normalize(tgt_cpu, p=2, dim=1),
                k=10
            )

            print(f"[Diagnostics @ epoch {epoch:03d}]")
            print("  emb2emb_linear_transfer:", emb2emb)
            print("  stitching_penalty_monitor:", stitch)
            print(f"  mutual_knn_alignment_fraction: {mk_align:.4f}")

print(f"\nðŸŽ¯ Training complete. Best model (R@10={best_val_r10:.4f}) saved as {FINAL_SAVE_PATH}")

# -----------------------------------------------------------------
# 8. Inference + Submission
# -----------------------------------------------------------------
print(f"Loading BEST model from {FINAL_SAVE_PATH} for inference...")
model = VAETranslator(dropout=0.4, hidden_dim=8192, latent_dim=1536, output_dim=1536).to(device)
model.load_state_dict(torch.load(FINAL_SAVE_PATH, map_location=device))
model.eval() # Deterministic output
print("Best model loaded.")

with torch.no_grad():
    preds_list = []
    for i in tqdm(range(0, len(tx_test_t), 1024), desc="Generating Sub"):
        chunk = tx_test_t[i:i+1024].to(device)
        preds_list.append(model(chunk)) # model is in .eval() mode
    preds = torch.cat(preds_list, dim=0).cpu().numpy()

test_ids = test_data["captions/ids"].astype(int)
submission = pd.DataFrame({
    "id": test_ids,
    "embedding": [list(map(float, row)) for row in preds]
})

submission_filename = f"submission_giga_vae_run_{RUN_ID}.csv"
submission.to_csv(submission_filename, index=False)
print(f"\nâœ… {submission_filename} saved successfully.")

# --- Final Validation (on FULL val set, now OOM-safe) ---
print("Calculating final validation scores on full val set...")
rec_full = recall_at_k(model, tx_val_t, im_train_t_unique, img_indices_val, chunk_size=1024)
print("\n--- Full Validation Results ---")
print(f"Validation Recall@1:  {rec_full['R@1']:<8.4f}")
print(f"Validation Recall@5:  {rec_full['R@5']:<8.4f}")
print(f"Validation Recall@10: {rec_full['R@10']:<8.4f}")
print(f"Validation Recall@50: {rec_full['R@50']:<8.4f}")


Using device: cuda
Train shapes: (125000, 1024), (25000, 1536), (125000, 1536)
Test shape: (1500, 1024)
Data preprocessed and normalized (on CPU).
Using CHAMPION ASYMMETRIC loss + *** VAE KL Loss (weight=1e-05) ***
Train pairs: 112500, Validation pairs: 12500
Initializing Giga-VAE (hidden_dim=8192, latent_dim=1536)

--- Starting New Architecture: Giga-VAE (RUN giga_vae_model_1762439103) ---
Training (Epochs: 60, LR=0.0003, WD=0.0004, Mixup warmup 3â†’Î±=0.3)...



Epoch 001/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 001: ReconLoss=2.5278 KL_Loss=1345.82 (LR=3.0E-04) | R@10(Val)=0.3442 | temp_scale=97.84 <- Best R@10! Saved.
[Diagnostics @ epoch 001]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004918365739285946, 'transfer_cosine': 0.6222695112228394, 'model_cosine': 0.12204770743846893, 'improvement': 0.5002218037843704}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006508729420602322, 'mse_full': 0.0002451127511449158, 'gap_percent': 165.53346803595238, 'cos_orthogonal': 0.5001296401023865, 'cos_full': 0.7917026281356812}
  mutual_knn_alignment_fraction: 0.1164


Epoch 002/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 002: ReconLoss=1.4018 KL_Loss=1847.92 (LR=3.0E-04) | R@10(Val)=0.4384 | temp_scale=96.80 <- Best R@10! Saved.


Epoch 003/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 003: ReconLoss=1.1237 KL_Loss=2059.55 (LR=3.0E-04) | R@10(Val)=0.4912 | temp_scale=95.98 <- Best R@10! Saved.


Epoch 004/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 004: ReconLoss=0.9174 KL_Loss=2178.60 (LR=3.0E-04) | R@10(Val)=0.5232 | temp_scale=95.43 <- Best R@10! Saved.


Epoch 005/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 005: ReconLoss=0.7995 KL_Loss=2257.80 (LR=3.0E-04) | R@10(Val)=0.5452 | temp_scale=95.35 <- Best R@10! Saved.
[Diagnostics @ epoch 005]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004918365739285946, 'transfer_cosine': 0.6222695112228394, 'model_cosine': 0.1611255258321762, 'improvement': 0.46114398539066315}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006508729420602322, 'mse_full': 0.0002451127511449158, 'gap_percent': 165.53346803595238, 'cos_orthogonal': 0.5001296401023865, 'cos_full': 0.7917026281356812}
  mutual_knn_alignment_fraction: 0.1164


Epoch 006/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 006: ReconLoss=0.7209 KL_Loss=2351.58 (LR=2.9E-04) | R@10(Val)=0.5630 | temp_scale=95.74 <- Best R@10! Saved.


Epoch 007/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 007: ReconLoss=0.6319 KL_Loss=2421.15 (LR=2.9E-04) | R@10(Val)=0.5712 | temp_scale=96.63 <- Best R@10! Saved.


Epoch 008/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 008: ReconLoss=0.5906 KL_Loss=2472.11 (LR=2.9E-04) | R@10(Val)=0.5882 | temp_scale=97.71 <- Best R@10! Saved.


Epoch 009/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 009: ReconLoss=0.5416 KL_Loss=2546.30 (LR=2.9E-04) | R@10(Val)=0.6028 | temp_scale=99.36 <- Best R@10! Saved.


Epoch 010/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 010: ReconLoss=0.4680 KL_Loss=2599.69 (LR=2.8E-04) | R@10(Val)=0.6196 | temp_scale=100.00 <- Best R@10! Saved.
[Diagnostics @ epoch 010]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004918365739285946, 'transfer_cosine': 0.6222695112228394, 'model_cosine': 0.16583998501300812, 'improvement': 0.45642952620983124}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006508729420602322, 'mse_full': 0.0002451127511449158, 'gap_percent': 165.53346803595238, 'cos_orthogonal': 0.5001296401023865, 'cos_full': 0.7917026281356812}
  mutual_knn_alignment_fraction: 0.1164


Epoch 011/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 011: ReconLoss=0.4288 KL_Loss=2631.61 (LR=2.8E-04) | R@10(Val)=0.6270 | temp_scale=100.00 <- Best R@10! Saved.


Epoch 012/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 012: ReconLoss=0.4052 KL_Loss=2653.54 (LR=2.8E-04) | R@10(Val)=0.6342 | temp_scale=100.00 <- Best R@10! Saved.


Epoch 013/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 013: ReconLoss=0.3832 KL_Loss=2671.68 (LR=2.7E-04) | R@10(Val)=0.6416 | temp_scale=100.00 <- Best R@10! Saved.


Epoch 014/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 014: ReconLoss=0.3591 KL_Loss=2674.44 (LR=2.7E-04) | R@10(Val)=0.6378 | temp_scale=100.00 


Epoch 015/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 015: ReconLoss=0.3470 KL_Loss=2674.38 (LR=2.6E-04) | R@10(Val)=0.6434 | temp_scale=100.00 <- Best R@10! Saved.
[Diagnostics @ epoch 015]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004918365739285946, 'transfer_cosine': 0.6222695112228394, 'model_cosine': 0.170462965965271, 'improvement': 0.45180654525756836}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006508729420602322, 'mse_full': 0.0002451127511449158, 'gap_percent': 165.53346803595238, 'cos_orthogonal': 0.5001296401023865, 'cos_full': 0.7917026281356812}
  mutual_knn_alignment_fraction: 0.1164


Epoch 016/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 016: ReconLoss=0.3244 KL_Loss=2705.41 (LR=2.6E-04) | R@10(Val)=0.6518 | temp_scale=100.00 <- Best R@10! Saved.


Epoch 017/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 017: ReconLoss=0.3140 KL_Loss=2713.78 (LR=2.5E-04) | R@10(Val)=0.6538 | temp_scale=100.00 <- Best R@10! Saved.


Epoch 018/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 018: ReconLoss=0.3002 KL_Loss=2733.43 (LR=2.4E-04) | R@10(Val)=0.6566 | temp_scale=100.00 <- Best R@10! Saved.


Epoch 019/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 019: ReconLoss=0.2816 KL_Loss=2744.52 (LR=2.4E-04) | R@10(Val)=0.6554 | temp_scale=100.00 


Epoch 020/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 020: ReconLoss=0.2686 KL_Loss=2755.99 (LR=2.3E-04) | R@10(Val)=0.6616 | temp_scale=100.00 <- Best R@10! Saved.
[Diagnostics @ epoch 020]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004918365739285946, 'transfer_cosine': 0.6222695112228394, 'model_cosine': 0.17499786615371704, 'improvement': 0.4472716450691223}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006508729420602322, 'mse_full': 0.0002451127511449158, 'gap_percent': 165.53346803595238, 'cos_orthogonal': 0.5001296401023865, 'cos_full': 0.7917026281356812}
  mutual_knn_alignment_fraction: 0.1164


Epoch 021/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 021: ReconLoss=0.2585 KL_Loss=2757.72 (LR=2.2E-04) | R@10(Val)=0.6612 | temp_scale=100.00 


Epoch 022/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 022: ReconLoss=0.2523 KL_Loss=2766.68 (LR=2.2E-04) | R@10(Val)=0.6626 | temp_scale=100.00 <- Best R@10! Saved.


Epoch 023/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 023: ReconLoss=0.2462 KL_Loss=2774.00 (LR=2.1E-04) | R@10(Val)=0.6644 | temp_scale=100.00 <- Best R@10! Saved.


Epoch 024/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 024: ReconLoss=0.2380 KL_Loss=2763.23 (LR=2.0E-04) | R@10(Val)=0.6674 | temp_scale=100.00 <- Best R@10! Saved.


Epoch 025/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 025: ReconLoss=0.2334 KL_Loss=2769.29 (LR=2.0E-04) | R@10(Val)=0.6708 | temp_scale=100.00 <- Best R@10! Saved.
[Diagnostics @ epoch 025]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004918365739285946, 'transfer_cosine': 0.6222695112228394, 'model_cosine': 0.17787212133407593, 'improvement': 0.4443973898887634}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006508729420602322, 'mse_full': 0.0002451127511449158, 'gap_percent': 165.53346803595238, 'cos_orthogonal': 0.5001296401023865, 'cos_full': 0.7917026281356812}
  mutual_knn_alignment_fraction: 0.1164


Epoch 026/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 026: ReconLoss=0.2279 KL_Loss=2765.59 (LR=1.9E-04) | R@10(Val)=0.6690 | temp_scale=100.00 


Epoch 027/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 027: ReconLoss=0.2264 KL_Loss=2770.58 (LR=1.8E-04) | R@10(Val)=0.6696 | temp_scale=100.00 


Epoch 028/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 028: ReconLoss=0.2209 KL_Loss=2771.78 (LR=1.7E-04) | R@10(Val)=0.6700 | temp_scale=100.00 


Epoch 029/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 029: ReconLoss=0.2196 KL_Loss=2770.87 (LR=1.7E-04) | R@10(Val)=0.6740 | temp_scale=100.00 <- Best R@10! Saved.


Epoch 030/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 030: ReconLoss=0.2143 KL_Loss=2779.81 (LR=1.6E-04) | R@10(Val)=0.6732 | temp_scale=100.00 
[Diagnostics @ epoch 030]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004918365739285946, 'transfer_cosine': 0.6222695112228394, 'model_cosine': 0.1793111115694046, 'improvement': 0.44295839965343475}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006508729420602322, 'mse_full': 0.0002451127511449158, 'gap_percent': 165.53346803595238, 'cos_orthogonal': 0.5001296401023865, 'cos_full': 0.7917026281356812}
  mutual_knn_alignment_fraction: 0.1164


Epoch 031/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 031: ReconLoss=0.2102 KL_Loss=2798.61 (LR=1.5E-04) | R@10(Val)=0.6746 | temp_scale=100.00 <- Best R@10! Saved.


Epoch 032/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 032: ReconLoss=0.2071 KL_Loss=2772.90 (LR=1.4E-04) | R@10(Val)=0.6754 | temp_scale=100.00 <- Best R@10! Saved.


Epoch 033/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 033: ReconLoss=0.2022 KL_Loss=2801.35 (LR=1.3E-04) | R@10(Val)=0.6756 | temp_scale=100.00 <- Best R@10! Saved.


Epoch 034/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 034: ReconLoss=0.1988 KL_Loss=2803.99 (LR=1.3E-04) | R@10(Val)=0.6790 | temp_scale=100.00 <- Best R@10! Saved.


Epoch 035/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 035: ReconLoss=0.1957 KL_Loss=2795.09 (LR=1.2E-04) | R@10(Val)=0.6800 | temp_scale=100.00 <- Best R@10! Saved.
[Diagnostics @ epoch 035]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004918365739285946, 'transfer_cosine': 0.6222695112228394, 'model_cosine': 0.1804998219013214, 'improvement': 0.44176968932151794}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006508729420602322, 'mse_full': 0.0002451127511449158, 'gap_percent': 165.53346803595238, 'cos_orthogonal': 0.5001296401023865, 'cos_full': 0.7917026281356812}
  mutual_knn_alignment_fraction: 0.1164


Epoch 036/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 036: ReconLoss=0.1925 KL_Loss=2793.88 (LR=1.1E-04) | R@10(Val)=0.6778 | temp_scale=100.00 


Epoch 037/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 037: ReconLoss=0.1924 KL_Loss=2794.30 (LR=1.0E-04) | R@10(Val)=0.6806 | temp_scale=100.00 <- Best R@10! Saved.


Epoch 038/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 038: ReconLoss=0.1888 KL_Loss=2787.30 (LR=9.6E-05) | R@10(Val)=0.6800 | temp_scale=100.00 


Epoch 039/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 039: ReconLoss=0.1876 KL_Loss=2792.08 (LR=8.9E-05) | R@10(Val)=0.6786 | temp_scale=100.00 


Epoch 040/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 040: ReconLoss=0.1867 KL_Loss=2797.14 (LR=8.2E-05) | R@10(Val)=0.6784 | temp_scale=100.00 
[Diagnostics @ epoch 040]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004918365739285946, 'transfer_cosine': 0.6222695112228394, 'model_cosine': 0.1813887655735016, 'improvement': 0.44088074564933777}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006508729420602322, 'mse_full': 0.0002451127511449158, 'gap_percent': 165.53346803595238, 'cos_orthogonal': 0.5001296401023865, 'cos_full': 0.7917026281356812}
  mutual_knn_alignment_fraction: 0.1164


Epoch 041/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 041: ReconLoss=0.1836 KL_Loss=2790.07 (LR=7.5E-05) | R@10(Val)=0.6802 | temp_scale=100.00 


Epoch 042/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 042: ReconLoss=0.1821 KL_Loss=2780.54 (LR=6.8E-05) | R@10(Val)=0.6810 | temp_scale=100.00 <- Best R@10! Saved.


Epoch 043/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 043: ReconLoss=0.1800 KL_Loss=2790.52 (LR=6.2E-05) | R@10(Val)=0.6806 | temp_scale=100.00 


Epoch 044/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 044: ReconLoss=0.1776 KL_Loss=2776.65 (LR=5.6E-05) | R@10(Val)=0.6820 | temp_scale=100.00 <- Best R@10! Saved.


Epoch 045/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 045: ReconLoss=0.1756 KL_Loss=2794.65 (LR=5.0E-05) | R@10(Val)=0.6810 | temp_scale=100.00 
[Diagnostics @ epoch 045]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004918365739285946, 'transfer_cosine': 0.6222695112228394, 'model_cosine': 0.18177343904972076, 'improvement': 0.4404960721731186}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006508729420602322, 'mse_full': 0.0002451127511449158, 'gap_percent': 165.53346803595238, 'cos_orthogonal': 0.5001296401023865, 'cos_full': 0.7917026281356812}
  mutual_knn_alignment_fraction: 0.1164


Epoch 046/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 046: ReconLoss=0.1760 KL_Loss=2789.88 (LR=4.4E-05) | R@10(Val)=0.6812 | temp_scale=100.00 


Epoch 047/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 047: ReconLoss=0.1745 KL_Loss=2800.41 (LR=3.9E-05) | R@10(Val)=0.6822 | temp_scale=100.00 <- Best R@10! Saved.


Epoch 048/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 048: ReconLoss=0.1740 KL_Loss=2796.49 (LR=3.3E-05) | R@10(Val)=0.6808 | temp_scale=100.00 


Epoch 049/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 049: ReconLoss=0.1717 KL_Loss=2769.98 (LR=2.9E-05) | R@10(Val)=0.6800 | temp_scale=100.00 


Epoch 050/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 050: ReconLoss=0.1699 KL_Loss=2795.23 (LR=2.4E-05) | R@10(Val)=0.6808 | temp_scale=100.00 
[Diagnostics @ epoch 050]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004918365739285946, 'transfer_cosine': 0.6222695112228394, 'model_cosine': 0.18252593278884888, 'improvement': 0.4397435784339905}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006508729420602322, 'mse_full': 0.0002451127511449158, 'gap_percent': 165.53346803595238, 'cos_orthogonal': 0.5001296401023865, 'cos_full': 0.7917026281356812}
  mutual_knn_alignment_fraction: 0.1164


Epoch 051/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 051: ReconLoss=0.1708 KL_Loss=2794.85 (LR=2.0E-05) | R@10(Val)=0.6810 | temp_scale=100.00 


Epoch 052/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 052: ReconLoss=0.1720 KL_Loss=2799.70 (LR=1.6E-05) | R@10(Val)=0.6818 | temp_scale=100.00 


Epoch 053/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 053: ReconLoss=0.1723 KL_Loss=2777.84 (LR=1.3E-05) | R@10(Val)=0.6808 | temp_scale=100.00 


Epoch 054/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 054: ReconLoss=0.1719 KL_Loss=2791.04 (LR=1.0E-05) | R@10(Val)=0.6802 | temp_scale=100.00 


Epoch 055/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 055: ReconLoss=0.1722 KL_Loss=2787.55 (LR=7.3E-06) | R@10(Val)=0.6810 | temp_scale=100.00 
[Diagnostics @ epoch 055]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004918365739285946, 'transfer_cosine': 0.6222695112228394, 'model_cosine': 0.18259023129940033, 'improvement': 0.439679279923439}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006508729420602322, 'mse_full': 0.0002451127511449158, 'gap_percent': 165.53346803595238, 'cos_orthogonal': 0.5001296401023865, 'cos_full': 0.7917026281356812}
  mutual_knn_alignment_fraction: 0.1164


Epoch 056/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 056: ReconLoss=0.1695 KL_Loss=2798.43 (LR=5.1E-06) | R@10(Val)=0.6816 | temp_scale=100.00 


Epoch 057/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 057: ReconLoss=0.1701 KL_Loss=2791.25 (LR=3.3E-06) | R@10(Val)=0.6814 | temp_scale=100.00 


Epoch 058/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 058: ReconLoss=0.1693 KL_Loss=2775.61 (LR=1.8E-06) | R@10(Val)=0.6814 | temp_scale=100.00 


Epoch 059/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 059: ReconLoss=0.1712 KL_Loss=2791.86 (LR=8.2E-07) | R@10(Val)=0.6814 | temp_scale=100.00 


Epoch 060/60:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 060: ReconLoss=0.1721 KL_Loss=2787.73 (LR=2.1E-07) | R@10(Val)=0.6814 | temp_scale=100.00 
[Diagnostics @ epoch 060]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004918365739285946, 'transfer_cosine': 0.6222695112228394, 'model_cosine': 0.18261292576789856, 'improvement': 0.4396565854549408}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006508729420602322, 'mse_full': 0.0002451127511449158, 'gap_percent': 165.53346803595238, 'cos_orthogonal': 0.5001296401023865, 'cos_full': 0.7917026281356812}
  mutual_knn_alignment_fraction: 0.1164

ðŸŽ¯ Training complete. Best model (R@10=0.6822) saved as giga_vae_model_giga_vae_model_1762439103_best.pth
Loading BEST model from giga_vae_model_giga_vae_model_1762439103_best.pth for inference...
Initializing Giga-VAE (hidden_dim=8192, latent_dim=1536)
Best model loaded.


Generating Sub:   0%|          | 0/2 [00:00<?, ?it/s]


âœ… submission_giga_vae_run_giga_vae_model_1762439103.csv saved successfully.
Calculating final validation scores on full val set...


CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/13 [00:00<?, ?it/s]


--- Full Validation Results ---
Validation Recall@1:  0.3284  
Validation Recall@5:  0.5884  
Validation Recall@10: 0.6890  
Validation Recall@50: 0.8690  


In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import os
import time
from sklearn.linear_model import Ridge
import warnings

# --- EXPERIMENT 58+: "Giga-VAE (Retrieval Optimized)" ---
#
# Upgrades in this script:
# - Deterministic training (no reparam sampling; KL=0)
# - Symmetric InfoNCE (t<->i) with learned temperature (softplus param)
# - Label smoothing (no mixup)
# - Memory-bank hard negatives for i->t
# - Ridge distillation (aux loss to strong linear baseline)
# - Diagnostics included (rectangular Procrustes)
# -----------------------------------------------------------------

# -----------------------------
# 0. Repro + Device
# -----------------------------
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# -----------------------------
# 1. Load data
# -----------------------------

tx_train = train_data["captions/embeddings"]      # (N_tx, 1024)
im_train = train_data["images/embeddings"]        # (N_im, 1536)
tx_test  = test_data["captions/embeddings"]       # (N_test, 1024)

repeat_factor = len(tx_train) // len(im_train)
im_train_expanded = np.repeat(im_train, repeat_factor, axis=0)

print(f"Train shapes: {tx_train.shape}, {im_train.shape}, {im_train_expanded.shape}")
print(f"Test shape: {tx_test.shape}")

# -----------------------------
# 2. Preprocess (normalize)
# -----------------------------
cpu_device = torch.device("cpu")
tx_train_t = torch.as_tensor(tx_train, dtype=torch.float32, device=cpu_device)
im_train_t_unique = torch.as_tensor(im_train, dtype=torch.float32, device=cpu_device)
tx_test_t  = torch.as_tensor(tx_test,  dtype=torch.float32, device=cpu_device)

tx_mean = tx_train_t.mean(0, keepdim=True)
im_mean = im_train_t_unique.mean(0, keepdim=True)

tx_train_t = F.normalize(tx_train_t - tx_mean, p=2, dim=1)
im_train_t_unique = F.normalize(im_train_t_unique - im_mean, p=2, dim=1)
tx_test_t = F.normalize(tx_test_t - tx_mean, p=2, dim=1)

im_train_exp = torch.as_tensor(im_train_expanded, dtype=torch.float32, device=cpu_device)
im_train_exp = F.normalize(im_train_exp - im_mean, p=2, dim=1)

print("Data preprocessed and normalized (on CPU).")

# -----------------------------
# 3. Model (deterministic translator)
# -----------------------------
class VAETranslator(nn.Module):
    def __init__(self, input_dim=1024, hidden_dim=6144, latent_dim=1024, output_dim=1536, dropout=0.4, use_reparam=False):
        super().__init__()
        self.latent_dim = latent_dim
        self.use_reparam = use_reparam
        print(f"Initializing Giga-VAE (hidden_dim={hidden_dim}, latent_dim={latent_dim}, use_reparam={use_reparam})")

        # Encoder: text -> (mu, logvar)
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),

            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),

            nn.Linear(hidden_dim, latent_dim * 2)
        )

        # Decoder: latent -> image embedding
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),

            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),

            nn.Linear(hidden_dim, output_dim)
        )

    @staticmethod
    def _reparameterize(mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu_logvar = self.encoder(x)
        mu = mu_logvar[..., :self.latent_dim]
        logvar = mu_logvar[..., self.latent_dim:]

        if self.training and self.use_reparam:
            z = self._reparameterize(mu, logvar)
        else:
            z = mu  # deterministic

        y_pred = self.decoder(z)
        if self.training:
            return F.normalize(y_pred, p=2, dim=1), mu, logvar
        y_pred_eval = self.decoder(mu)
        return F.normalize(y_pred_eval, p=2, dim=1)

# -----------------------------
# 4. Training knobs
# -----------------------------
EPOCHS = 50
START_LR = 3e-4
WEIGHT_DECAY = 4e-4
RUN_ID = f"giga_vae_model_{int(time.time())}"
FINAL_SAVE_PATH = f"giga_vae_model_{RUN_ID}_best.pth"

# Retrieval/Loss
LABEL_SMOOTHING = 0.05
SYMMETRIC_INFO_NCE = True
TRIPLET_WEIGHT = 0.0             # triplet off (InfoNCE + hard negs dominate)

# Temperature (softplus param)
INIT_TAU = 0.01                  # initial temperature
LOGITSCALE_LR_MULT = 0.1         # smaller LR for temp
SMALL_EPS = 1e-6

# VAE regularization (off for retrieval)
KL_WEIGHT = 0.0
KL_WARMUP_EPOCHS = 0

# Ridge distillation
DISTILL_ALPHA = 1.0              # ridge alpha (fit)
DISTILL_LOSS_W = 0.05            # small auxiliary loss weight
DISTILL_SAMPLE_MAX = None        # None = use all; or int to subsample for speed/memory

# Memory bank (image side)
MEMBANK_USE = True
MEMBANK_TOPK = 64                # hard negatives per sample
MEMBANK_REFRESH_EVERY = 9999     # fixed bank (images don't change)
MEMBANK_CHUNK = 2048             # chunked sim for mining

print(f"Using CHAMPION ASYMMETRIC loss + *** VAE KL Loss (weight={KL_WEIGHT}) ***")

# -----------------------------
# 5. Validation split
# -----------------------------
N = len(tx_train_t)
val_size = int(0.1 * N)
idx_cpu = torch.randperm(N, device="cpu")
val_idx, train_idx = idx_cpu[:val_size], idx_cpu[val_size:]

img_indices_train = (train_idx // 5)
img_indices_val = (val_idx // 5)

tx_val_t, im_val_t = tx_train_t[val_idx], im_train_exp[val_idx]
tx_train_t_sub, im_train_exp_sub = tx_train_t[train_idx], im_train_exp[train_idx]

train_dataset = TensorDataset(tx_train_t_sub, im_train_exp_sub, img_indices_train)
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True, num_workers=0)
print(f"Train pairs: {len(train_dataset)}, Validation pairs: {len(val_idx)}")

# -----------------------------
# 6. Recall@K (+ CSLS)
# -----------------------------
@torch.no_grad()
def recall_at_k(model, tx_queries, im_database, query_img_indices, chunk_size=1024, ks=(1,5,10,50), k_csls=10):
    model.eval()
    dev = next(model.parameters()).device
    all_correct = {k: [] for k in ks}

    Q = []
    for i in range(0, len(tx_queries), chunk_size):
        chunk = tx_queries[i:i+chunk_size].to(dev)
        Q.append(model(chunk))
    Q = torch.cat(Q, dim=0)

    D = im_database.to(dev)

    mean_knn_d_list = []
    for i in tqdm(range(0, len(D), 512), desc="CSLS (Pass 1/2)", leave=False):
        sim_T_chunk = D[i:i+512] @ Q.T
        knn_d_chunk = torch.topk(sim_T_chunk, k=k_csls, dim=1).values
        mean_knn_d_list.append(knn_d_chunk.mean(1))
    mean_knn_d = torch.cat(mean_knn_d_list, dim=0).unsqueeze(0)

    gt = query_img_indices.to(dev)
    for i in tqdm(range(0, len(Q), chunk_size), desc="CSLS (Pass 2/2)", leave=False):
        sim_chunk = Q[i:i+chunk_size] @ D.T
        gt_chunk = gt[i:i+chunk_size]

        knn_q_chunk = torch.topk(sim_chunk, k=k_csls, dim=1).values
        mean_knn_q = knn_q_chunk.mean(1, keepdim=True)

        csls = 2 * sim_chunk - mean_knn_q - mean_knn_d
        top_idx = torch.argsort(csls, dim=1, descending=True)
        for k in ks:
            top_k = top_idx[:, :k]
            correct = (top_k == gt_chunk.unsqueeze(1)).any(dim=1)
            all_correct[k].append(correct)

    return {f"R@{k}": torch.cat(all_correct[k]).float().mean().item() for k in ks}

# -----------------------------
# 7. Diagnostics
# -----------------------------
class DiagnosticMetrics:
    @staticmethod
    @torch.no_grad()
    def emb2emb_linear_transfer(src_emb, tgt_emb, model_pred_emb=None, alpha=1.0):
        src_np = (src_emb.detach().cpu().numpy()
                  if isinstance(src_emb, torch.Tensor) else src_emb)
        tgt_np = (tgt_emb.detach().cpu().numpy()
                  if isinstance(tgt_emb, torch.Tensor) else tgt_emb)

        ridge = Ridge(alpha=alpha, fit_intercept=False)
        ridge.fit(src_np, tgt_np)

        pred = ridge.predict(src_np)
        pred_t = torch.from_numpy(pred).float()

        if not isinstance(tgt_emb, torch.Tensor):
            tgt_emb = torch.from_numpy(tgt_np).float()
        tgt_t = tgt_emb.float()

        pred_t = F.normalize(pred_t, p=2, dim=1)
        tgt_norm = F.normalize(tgt_t, p=2, dim=1)

        transfer_mse = F.mse_loss(pred_t, tgt_norm).item()
        transfer_cosine = F.cosine_similarity(pred_t, tgt_norm).mean().item()

        if model_pred_emb is not None:
            if not isinstance(model_pred_emb, torch.Tensor):
                model_pred_emb = torch.from_numpy(model_pred_emb).float()
            model_cos = F.cosine_similarity(
                F.normalize(model_pred_emb.float(), p=2, dim=1), tgt_norm
            ).mean().item()
            improvement = transfer_cosine - model_cos
            baseline_name = 'model_cosine'
        else:
            model_cos = 0.0
            improvement = transfer_cosine
            baseline_name = 'random_baseline'

        return {
            'transfer_mse': transfer_mse,
            'transfer_cosine': transfer_cosine,
            baseline_name: model_cos,
            'improvement': improvement
        }

    @staticmethod
    @torch.no_grad()
    def stitching_penalty_monitor(src_emb, tgt_emb):
        # rectangular Procrustes via SVD (dims can differ)
        src_np = src_emb.detach().cpu().numpy()
        tgt_np = tgt_emb.detach().cpu().numpy()
        C = src_np.T @ tgt_np
        U, _, Vt = np.linalg.svd(C, full_matrices=False)
        R_rect = U @ Vt
        pred_ortho = src_np @ R_rect

        ridge = Ridge(alpha=0.01, fit_intercept=False)
        ridge.fit(src_np, tgt_np)
        pred_full = ridge.predict(src_np)

        mse_ortho = np.mean((pred_ortho - tgt_np) ** 2)
        mse_full  = np.mean((pred_full  - tgt_np) ** 2)

        pred_ortho_t = F.normalize(torch.from_numpy(pred_ortho).float(), p=2, dim=1)
        pred_full_t  = F.normalize(torch.from_numpy(pred_full ).float(), p=2, dim=1)
        tgt_norm     = F.normalize(torch.from_numpy(tgt_np   ).float(), p=2, dim=1)

        cos_ortho = F.cosine_similarity(pred_ortho_t, tgt_norm).mean().item()
        cos_full  = F.cosine_similarity(pred_full_t,  tgt_norm).mean().item()

        gap_pct = 100 * (mse_ortho - mse_full) / (mse_full + 1e-8)

        return {
            'mse_orthogonal': float(mse_ortho),
            'mse_full': float(mse_full),
            'gap_percent': float(gap_pct),
            'cos_orthogonal': cos_ortho,
            'cos_full': cos_full
        }

    @staticmethod
    @torch.no_grad()
    def mutual_knn_alignment(emb1, emb2, k=10):
        sim1 = emb1 @ emb1.T
        sim2 = emb2 @ emb2.T
        _, knn1 = torch.topk(sim1, k=k+1, dim=1)
        _, knn2 = torch.topk(sim2, k=k+1, dim=1)
        knn1 = knn1[:, 1:]
        knn2 = knn2[:, 1:]
        mutual_count = 0
        for i in range(len(emb1)):
            set1 = set(knn1[i].cpu().tolist())
            set2 = set(knn2[i].cpu().tolist())
            mutual_count += len(set1 & set2)
        return mutual_count / (len(emb1) * k)

# -----------------------------
# 8. Ridge baseline (fit once)
# -----------------------------
print("Fitting ridge baseline (text->image) for distillation...")
tx_np = tx_train_t.detach().cpu().numpy()
im_np = im_train_exp.detach().cpu().numpy()

if DISTILL_SAMPLE_MAX is not None and DISTILL_SAMPLE_MAX < len(tx_np):
    sel = np.random.choice(len(tx_np), size=DISTILL_SAMPLE_MAX, replace=False)
    tx_np_fit = tx_np[sel]
    im_np_fit = im_np[sel]
else:
    tx_np_fit = tx_np
    im_np_fit = im_np

ridge_map = Ridge(alpha=DISTILL_ALPHA, fit_intercept=False)
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    ridge_map.fit(tx_np_fit, im_np_fit)

print("Ridge fit complete.")

# -----------------------------
# 9. Memory bank (images)
# -----------------------------
if MEMBANK_USE:
    # Fixed bank of unique image embeddings
    mem_bank = im_train_t_unique.clone().to(device)  # (N_im, 1536)
    mem_bank.requires_grad_(False)
else:
    mem_bank = None

# -----------------------------
# 10. Val subsets
# -----------------------------
val_query_subset = tx_val_t[:5000]
val_indices_subset = img_indices_val[:5000]
val_db_subset = im_train_t_unique  # unique images

# -----------------------------
# 11. Build model & optimizer
# -----------------------------
model = VAETranslator(
    input_dim=1024, hidden_dim=6144, latent_dim=1024, output_dim=1536,
    dropout=0.4, use_reparam=False  # deterministic training
).to(device)

# Learnable temperature via softplus param: temp = softplus(phi) + eps
# Initialize so that temp ~ INIT_TAU
init_phi = torch.log(torch.exp(torch.tensor(INIT_TAU)) - 1.0 + SMALL_EPS).to(device)
phi = nn.Parameter(init_phi.clone())

# Two parameter groups: model & temperature with smaller LR
optimizer = torch.optim.AdamW(
    [
        {"params": model.parameters(), "lr": START_LR, "weight_decay": WEIGHT_DECAY},
        {"params": [phi], "lr": START_LR * LOGITSCALE_LR_MULT, "weight_decay": 0.0},
    ]
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
scaler = torch.amp.GradScaler('cuda')

# Label smoothing CE
def ce_with_smoothing(logits, target, smoothing=0.0):
    return F.cross_entropy(logits, target, label_smoothing=smoothing)

print(f"\n--- Starting Training (RUN {RUN_ID}) ---")
print(f"Epochs={EPOCHS}, LR={START_LR}, WD={WEIGHT_DECAY}, SymInfoNCE={SYMMETRIC_INFO_NCE}, MemBank={MEMBANK_USE}, Distill_w={DISTILL_LOSS_W}\n")

# -----------------------------
# 12. Training loop
# -----------------------------
best_val_r10 = 0.0
triplet_loss_fn = nn.TripletMarginWithDistanceLoss(
    distance_function=lambda x, y: 1 - F.cosine_similarity(x, y),
    margin=0.06
)

from torch import amp  # <- make sure this import is present at the top!

import warnings

def mine_hard_negatives(y_pred, img_idx_batch, k=MEMBANK_TOPK):
    """Return (B, k, d) hard negatives from mem_bank in FP32 (safe for fp16 runs, no warnings)."""
    assert mem_bank is not None, "mem_bank not initialized"

    with torch.no_grad():
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=FutureWarning)
            with torch.cuda.amp.autocast(enabled=False):  # keep old API, suppress warning
                y32 = y_pred.float()
                mb32 = mem_bank.float()
                sims = y32 @ mb32.T

                B = y32.size(0)
                ar = torch.arange(B, device=y_pred.device)
                sims[ar, img_idx_batch.long()] = -1e9

                k_eff = int(min(k, mb32.size(0) - 1))
                topk_idx = torch.topk(sims, k=k_eff, dim=1).indices
                negs = mb32[topk_idx]
    return negs

for epoch in range(1, EPOCHS + 1):
    model.train()
    total_main, total_kl, total_distill = 0.0, 0.0, 0.0

    for x_batch, y_batch, img_idx in tqdm(train_loader, desc=f"Epoch {epoch:03d}/{EPOCHS}", leave=False):
        x_batch = x_batch.to(device)          # text emb
        y_batch = y_batch.to(device)          # GT image emb (expanded)
        img_idx  = img_idx.to(device)         # indices into unique image table

        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast('cuda'):
            # Forward (deterministic training)
            y_pred, mu, logvar = model(x_batch)    # (B,d)

            # Learned temperature
            temp = torch.nn.functional.softplus(phi) + SMALL_EPS  # scalar
            scale = 1.0 / temp

            # ------- Symmetric InfoNCE on batch -------
            # i->t (rows=text, cols=images); positives are diagonal
            logits_bt = (y_pred @ y_batch.T) * scale  # (B,B)
            labels = torch.arange(y_pred.size(0), device=device)

            # Add hard negatives from memory bank (i->t side)
            if MEMBANK_USE and MEMBANK_TOPK > 0:
                negs = mine_hard_negatives(y_pred, img_idx, k=MEMBANK_TOPK)   # (B,k,d)
                # compute per-row sims to mined negatives in FP32, then cast to logits dtype
                with torch.cuda.amp.autocast(enabled=False):
                    sims_negs32 = torch.einsum('bd,bkd->bk', y_pred.float(), negs)  # (B,k) float32
                sims_negs = (sims_negs32 * (1.0 / (torch.nn.functional.softplus(phi).detach() + SMALL_EPS))).to(logits_bt.dtype)
                
                # concatenate batch logits (B,B) with mined negatives (B,k)
                logits_i2t = torch.cat([logits_bt, sims_negs], dim=1)
                # labels remain 0..B-1 (positives on the diagonal of first B cols)
            else:
                logits_i2t = logits_bt

            loss_i2t = ce_with_smoothing(logits_i2t, labels, smoothing=LABEL_SMOOTHING)

            # t->i (symmetric) â€” batch-only (no text memory bank)
            if SYMMETRIC_INFO_NCE:
                logits_t2i = (y_batch @ y_pred.T) * scale
                loss_t2i = ce_with_smoothing(logits_t2i, labels, smoothing=LABEL_SMOOTHING)
                loss_con = 0.5 * (loss_i2t + loss_t2i)
            else:
                loss_con = loss_i2t

            # Optional triplet (usually redundant with strong InfoNCE + hard negs)
            if TRIPLET_WEIGHT > 0.0:
                with torch.no_grad():
                    sims_no = y_pred @ y_batch.T
                    pos_mask = torch.eye(y_batch.size(0), dtype=torch.bool, device=device)
                    sims_no.masked_fill_(pos_mask, -float('inf'))
                    hard_neg_idx = sims_no.argmax(dim=1)
                y_hard_neg = y_batch[hard_neg_idx]
                loss_tri = triplet_loss_fn(y_pred, y_batch, y_hard_neg)
                main_loss = loss_con + TRIPLET_WEIGHT * loss_tri
            else:
                main_loss = loss_con

            # KL (off)
            kl_loss = torch.zeros((), device=device)
            if KL_WEIGHT > 0 and epoch <= KL_WARMUP_EPOCHS:
                kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / mu.size(0)
                warm = (epoch / max(1, KL_WARMUP_EPOCHS))
                kl_loss = KL_WEIGHT * warm * kl

            # Ridge distillation: pull decoder toward ridge(text)
            # Compute ridge(x) on CPU numpy for stability, then back to torch
            with torch.no_grad():
                x_np = x_batch.detach().cpu().numpy()
                ridge_pred = ridge_map.predict(x_np)                # (B,1536)
            ridge_pred_t = F.normalize(torch.from_numpy(ridge_pred).to(device).float(), p=2, dim=1)
            distill_loss = DISTILL_LOSS_W * (1.0 - F.cosine_similarity(y_pred, ridge_pred_t).mean())

            loss = main_loss + kl_loss + distill_loss

        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

        total_main += main_loss.item()
        total_kl += kl_loss.item() if torch.is_tensor(kl_loss) else 0.0
        total_distill += distill_loss.item()

    current_lr = scheduler.get_last_lr()[0]
    scheduler.step()

    # --- Validation ---
    rec = recall_at_k(model, val_query_subset, val_db_subset, val_indices_subset, chunk_size=1024)
    current_r10 = rec['R@10']
    temp_value = float((torch.nn.functional.softplus(phi) + SMALL_EPS).item())

    save_marker = ""
    if current_r10 > best_val_r10:
        best_val_r10 = current_r10
        torch.save(model.state_dict(), FINAL_SAVE_PATH)
        save_marker = "<- Best R@10! Saved."

    print(f"Epoch {epoch:03d}: MainLoss={total_main/len(train_loader):.4f} "
          f"KL={total_kl/len(train_loader):.4f} Distill={total_distill/len(train_loader):.4f} "
          f"(LR={current_lr:.1E}) | R@10(Val)={current_r10:.4f} | temp={temp_value:.4f} {save_marker}")

    # Diagnostics every 5 epochs + first
    if epoch == 1 or (epoch % 5 == 0):
        with torch.no_grad():
            model.eval()
            diag_n = min(2000, len(val_query_subset))
            tx_sample = val_query_subset[:diag_n].to(device)
            tgt_sample = val_db_subset[val_indices_subset[:diag_n]].to(device)
            pred_sample = model(tx_sample)

            tx_cpu = tx_sample.detach().cpu()
            tgt_cpu = tgt_sample.detach().cpu()
            pred_cpu = pred_sample.detach().cpu()

            emb2emb = DiagnosticMetrics.emb2emb_linear_transfer(tx_cpu, tgt_cpu, model_pred_emb=pred_cpu, alpha=1.0)
            stitch = DiagnosticMetrics.stitching_penalty_monitor(tx_cpu, tgt_cpu)
            mk_align = DiagnosticMetrics.mutual_knn_alignment(
                F.normalize(tx_cpu, p=2, dim=1),
                F.normalize(tgt_cpu, p=2, dim=1),
                k=10
            )
            print(f"[Diagnostics @ epoch {epoch:03d}]")
            print("  emb2emb_linear_transfer:", emb2emb)
            print("  stitching_penalty_monitor:", stitch)
            print(f"  mutual_knn_alignment_fraction: {mk_align:.4f}")

print(f"\nðŸŽ¯ Training complete. Best model (R@10={best_val_r10:.4f}) saved as {FINAL_SAVE_PATH}")

# -----------------------------
# 13. Inference + Submission
# -----------------------------
print(f"Loading BEST model from {FINAL_SAVE_PATH} for inference...")
model = VAETranslator(
    input_dim=1024, hidden_dim=6144, latent_dim=1024, output_dim=1536,
    dropout=0.4, use_reparam=False
).to(device)
model.load_state_dict(torch.load(FINAL_SAVE_PATH, map_location=device))
model.eval()
print("Best model loaded.")

with torch.no_grad():
    preds_list = []
    for i in tqdm(range(0, len(tx_test_t), 1024), desc="Generating Sub"):
        chunk = tx_test_t[i:i+1024].to(device)
        preds_list.append(model(chunk))
    preds = torch.cat(preds_list, dim=0).cpu().numpy()

test_ids = test_data["captions/ids"].astype(int)
submission = pd.DataFrame({
    "id": test_ids,
    "embedding": [list(map(float, row)) for row in preds]
})

submission_filename = f"submission_giga_vae_run_{RUN_ID}.csv"
submission.to_csv(submission_filename, index=False)
print(f"\nâœ… {submission_filename} saved successfully.")

# --- Final Validation (full val set) ---
print("Calculating final validation scores on full val set...")
rec_full = recall_at_k(model, tx_val_t, im_train_t_unique, img_indices_val, chunk_size=1024)
print("\n--- Full Validation Results ---")
print(f"Validation Recall@1:  {rec_full['R@1']:<8.4f}")
print(f"Validation Recall@5:  {rec_full['R@5']:<8.4f}")
print(f"Validation Recall@10: {rec_full['R@10']:<8.4f}")
print(f"Validation Recall@50: {rec_full['R@50']:<8.4f}")


Using device: cuda
Train shapes: (125000, 1024), (25000, 1536), (125000, 1536)
Test shape: (1500, 1024)
Data preprocessed and normalized (on CPU).
Using CHAMPION ASYMMETRIC loss + *** VAE KL Loss (weight=0.0) ***
Train pairs: 112500, Validation pairs: 12500
Fitting ridge baseline (text->image) for distillation...
Ridge fit complete.
Initializing Giga-VAE (hidden_dim=6144, latent_dim=1024, use_reparam=False)

--- Starting Training (RUN giga_vae_model_1762442313) ---
Epochs=50, LR=0.0003, WD=0.0004, SymInfoNCE=True, MemBank=True, Distill_w=0.05



Epoch 001/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 001: MainLoss=4.8182 KL=0.0000 Distill=0.0445 (LR=3.0E-04) | R@10(Val)=0.3686 | temp=0.0099 <- Best R@10! Saved.
[Diagnostics @ epoch 001]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004988984437659383, 'transfer_cosine': 0.6168460845947266, 'model_cosine': 0.07407386600971222, 'improvement': 0.5427722185850143}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006537790759466588, 'mse_full': 0.00024762985412962735, 'gap_percent': 164.00802013250754, 'cos_orthogonal': 0.4978977143764496, 'cos_full': 0.7889490127563477}
  mutual_knn_alignment_fraction: 0.1152


Epoch 002/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 002: MainLoss=3.7613 KL=0.0000 Distill=0.0424 (LR=3.0E-04) | R@10(Val)=0.4594 | temp=0.0099 <- Best R@10! Saved.


Epoch 003/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 003: MainLoss=3.4168 KL=0.0000 Distill=0.0417 (LR=3.0E-04) | R@10(Val)=0.5160 | temp=0.0098 <- Best R@10! Saved.


Epoch 004/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 004: MainLoss=3.1784 KL=0.0000 Distill=0.0414 (LR=3.0E-04) | R@10(Val)=0.5530 | temp=0.0097 <- Best R@10! Saved.


Epoch 005/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 005: MainLoss=2.9826 KL=0.0000 Distill=0.0411 (LR=3.0E-04) | R@10(Val)=0.5806 | temp=0.0096 <- Best R@10! Saved.
[Diagnostics @ epoch 005]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004988984437659383, 'transfer_cosine': 0.6168460845947266, 'model_cosine': 0.09526350349187851, 'improvement': 0.521582581102848}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006537790759466588, 'mse_full': 0.00024762985412962735, 'gap_percent': 164.00802013250754, 'cos_orthogonal': 0.4978977143764496, 'cos_full': 0.7889490127563477}
  mutual_knn_alignment_fraction: 0.1152


Epoch 006/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 006: MainLoss=2.8122 KL=0.0000 Distill=0.0410 (LR=2.9E-04) | R@10(Val)=0.6080 | temp=0.0096 <- Best R@10! Saved.


Epoch 007/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 007: MainLoss=2.6638 KL=0.0000 Distill=0.0409 (LR=2.9E-04) | R@10(Val)=0.6216 | temp=0.0095 <- Best R@10! Saved.


Epoch 008/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 008: MainLoss=2.5346 KL=0.0000 Distill=0.0409 (LR=2.9E-04) | R@10(Val)=0.6414 | temp=0.0094 <- Best R@10! Saved.


Epoch 009/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 009: MainLoss=2.4140 KL=0.0000 Distill=0.0409 (LR=2.8E-04) | R@10(Val)=0.6464 | temp=0.0094 <- Best R@10! Saved.


Epoch 010/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 010: MainLoss=2.2567 KL=0.0000 Distill=0.0410 (LR=2.8E-04) | R@10(Val)=0.6568 | temp=0.0093 <- Best R@10! Saved.
[Diagnostics @ epoch 010]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004988984437659383, 'transfer_cosine': 0.6168460845947266, 'model_cosine': 0.10597935318946838, 'improvement': 0.5108667314052582}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006537790759466588, 'mse_full': 0.00024762985412962735, 'gap_percent': 164.00802013250754, 'cos_orthogonal': 0.4978977143764496, 'cos_full': 0.7889490127563477}
  mutual_knn_alignment_fraction: 0.1152


Epoch 011/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 011: MainLoss=2.1613 KL=0.0000 Distill=0.0411 (LR=2.7E-04) | R@10(Val)=0.6622 | temp=0.0093 <- Best R@10! Saved.


Epoch 012/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 012: MainLoss=2.1031 KL=0.0000 Distill=0.0411 (LR=2.7E-04) | R@10(Val)=0.6690 | temp=0.0092 <- Best R@10! Saved.


Epoch 013/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 013: MainLoss=2.0441 KL=0.0000 Distill=0.0412 (LR=2.6E-04) | R@10(Val)=0.6724 | temp=0.0092 <- Best R@10! Saved.


Epoch 014/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 014: MainLoss=1.9902 KL=0.0000 Distill=0.0413 (LR=2.5E-04) | R@10(Val)=0.6758 | temp=0.0091 <- Best R@10! Saved.


Epoch 015/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 015: MainLoss=1.9384 KL=0.0000 Distill=0.0414 (LR=2.5E-04) | R@10(Val)=0.6780 | temp=0.0091 <- Best R@10! Saved.
[Diagnostics @ epoch 015]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004988984437659383, 'transfer_cosine': 0.6168460845947266, 'model_cosine': 0.10880851000547409, 'improvement': 0.5080375745892525}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006537790759466588, 'mse_full': 0.00024762985412962735, 'gap_percent': 164.00802013250754, 'cos_orthogonal': 0.4978977143764496, 'cos_full': 0.7889490127563477}
  mutual_knn_alignment_fraction: 0.1152


Epoch 016/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 016: MainLoss=1.8903 KL=0.0000 Distill=0.0415 (LR=2.4E-04) | R@10(Val)=0.6788 | temp=0.0090 <- Best R@10! Saved.


Epoch 017/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 017: MainLoss=1.8442 KL=0.0000 Distill=0.0416 (LR=2.3E-04) | R@10(Val)=0.6876 | temp=0.0090 <- Best R@10! Saved.


Epoch 018/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 018: MainLoss=1.8011 KL=0.0000 Distill=0.0417 (LR=2.2E-04) | R@10(Val)=0.6868 | temp=0.0089 


Epoch 019/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 019: MainLoss=1.7435 KL=0.0000 Distill=0.0417 (LR=2.1E-04) | R@10(Val)=0.6890 | temp=0.0089 <- Best R@10! Saved.


Epoch 020/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 020: MainLoss=1.7020 KL=0.0000 Distill=0.0418 (LR=2.1E-04) | R@10(Val)=0.6896 | temp=0.0089 <- Best R@10! Saved.
[Diagnostics @ epoch 020]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004988984437659383, 'transfer_cosine': 0.6168460845947266, 'model_cosine': 0.1088089570403099, 'improvement': 0.5080371275544167}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006537790759466588, 'mse_full': 0.00024762985412962735, 'gap_percent': 164.00802013250754, 'cos_orthogonal': 0.4978977143764496, 'cos_full': 0.7889490127563477}
  mutual_knn_alignment_fraction: 0.1152


Epoch 021/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 021: MainLoss=1.6820 KL=0.0000 Distill=0.0419 (LR=2.0E-04) | R@10(Val)=0.6888 | temp=0.0088 


Epoch 022/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 022: MainLoss=1.6581 KL=0.0000 Distill=0.0420 (LR=1.9E-04) | R@10(Val)=0.6860 | temp=0.0088 


Epoch 023/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 023: MainLoss=1.6384 KL=0.0000 Distill=0.0420 (LR=1.8E-04) | R@10(Val)=0.6892 | temp=0.0088 


Epoch 024/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 024: MainLoss=1.6216 KL=0.0000 Distill=0.0421 (LR=1.7E-04) | R@10(Val)=0.6944 | temp=0.0087 <- Best R@10! Saved.


Epoch 025/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 025: MainLoss=1.6028 KL=0.0000 Distill=0.0421 (LR=1.6E-04) | R@10(Val)=0.6944 | temp=0.0087 
[Diagnostics @ epoch 025]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004988984437659383, 'transfer_cosine': 0.6168460845947266, 'model_cosine': 0.10787755995988846, 'improvement': 0.5089685246348381}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006537790759466588, 'mse_full': 0.00024762985412962735, 'gap_percent': 164.00802013250754, 'cos_orthogonal': 0.4978977143764496, 'cos_full': 0.7889490127563477}
  mutual_knn_alignment_fraction: 0.1152


Epoch 026/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 026: MainLoss=1.5924 KL=0.0000 Distill=0.0422 (LR=1.5E-04) | R@10(Val)=0.6956 | temp=0.0087 <- Best R@10! Saved.


Epoch 027/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 027: MainLoss=1.5724 KL=0.0000 Distill=0.0422 (LR=1.4E-04) | R@10(Val)=0.6924 | temp=0.0087 


Epoch 028/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 028: MainLoss=1.5564 KL=0.0000 Distill=0.0423 (LR=1.3E-04) | R@10(Val)=0.6958 | temp=0.0086 <- Best R@10! Saved.


Epoch 029/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 029: MainLoss=1.5500 KL=0.0000 Distill=0.0423 (LR=1.2E-04) | R@10(Val)=0.6944 | temp=0.0086 


Epoch 030/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 030: MainLoss=1.5377 KL=0.0000 Distill=0.0423 (LR=1.1E-04) | R@10(Val)=0.6962 | temp=0.0086 <- Best R@10! Saved.
[Diagnostics @ epoch 030]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004988984437659383, 'transfer_cosine': 0.6168460845947266, 'model_cosine': 0.1070181131362915, 'improvement': 0.5098279714584351}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006537790759466588, 'mse_full': 0.00024762985412962735, 'gap_percent': 164.00802013250754, 'cos_orthogonal': 0.4978977143764496, 'cos_full': 0.7889490127563477}
  mutual_knn_alignment_fraction: 0.1152


Epoch 031/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 031: MainLoss=1.5275 KL=0.0000 Distill=0.0424 (LR=1.0E-04) | R@10(Val)=0.6944 | temp=0.0086 


Epoch 032/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 032: MainLoss=1.5185 KL=0.0000 Distill=0.0424 (LR=9.5E-05) | R@10(Val)=0.6960 | temp=0.0085 


Epoch 033/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 033: MainLoss=1.5098 KL=0.0000 Distill=0.0424 (LR=8.6E-05) | R@10(Val)=0.6956 | temp=0.0085 


Epoch 034/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 034: MainLoss=1.4993 KL=0.0000 Distill=0.0425 (LR=7.8E-05) | R@10(Val)=0.6960 | temp=0.0085 


Epoch 035/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 035: MainLoss=1.4929 KL=0.0000 Distill=0.0425 (LR=7.0E-05) | R@10(Val)=0.6976 | temp=0.0085 <- Best R@10! Saved.
[Diagnostics @ epoch 035]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004988984437659383, 'transfer_cosine': 0.6168460845947266, 'model_cosine': 0.10630013048648834, 'improvement': 0.5105459541082382}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006537790759466588, 'mse_full': 0.00024762985412962735, 'gap_percent': 164.00802013250754, 'cos_orthogonal': 0.4978977143764496, 'cos_full': 0.7889490127563477}
  mutual_knn_alignment_fraction: 0.1152


Epoch 036/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 036: MainLoss=1.4861 KL=0.0000 Distill=0.0425 (LR=6.2E-05) | R@10(Val)=0.6976 | temp=0.0085 


Epoch 037/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 037: MainLoss=1.4768 KL=0.0000 Distill=0.0425 (LR=5.4E-05) | R@10(Val)=0.6974 | temp=0.0085 


Epoch 038/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 038: MainLoss=1.4702 KL=0.0000 Distill=0.0425 (LR=4.7E-05) | R@10(Val)=0.6960 | temp=0.0085 


Epoch 039/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 039: MainLoss=1.4714 KL=0.0000 Distill=0.0426 (LR=4.1E-05) | R@10(Val)=0.6962 | temp=0.0085 


Epoch 040/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 040: MainLoss=1.4638 KL=0.0000 Distill=0.0426 (LR=3.4E-05) | R@10(Val)=0.6960 | temp=0.0085 
[Diagnostics @ epoch 040]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004988984437659383, 'transfer_cosine': 0.6168460845947266, 'model_cosine': 0.10582587867975235, 'improvement': 0.5110202059149742}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006537790759466588, 'mse_full': 0.00024762985412962735, 'gap_percent': 164.00802013250754, 'cos_orthogonal': 0.4978977143764496, 'cos_full': 0.7889490127563477}
  mutual_knn_alignment_fraction: 0.1152


Epoch 041/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 041: MainLoss=1.4595 KL=0.0000 Distill=0.0426 (LR=2.9E-05) | R@10(Val)=0.6962 | temp=0.0085 


Epoch 042/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 042: MainLoss=1.4587 KL=0.0000 Distill=0.0426 (LR=2.3E-05) | R@10(Val)=0.6966 | temp=0.0085 


Epoch 043/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 043: MainLoss=1.4556 KL=0.0000 Distill=0.0426 (LR=1.9E-05) | R@10(Val)=0.6976 | temp=0.0084 


Epoch 044/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 044: MainLoss=1.4540 KL=0.0000 Distill=0.0426 (LR=1.4E-05) | R@10(Val)=0.6974 | temp=0.0084 


Epoch 045/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 045: MainLoss=1.4528 KL=0.0000 Distill=0.0426 (LR=1.1E-05) | R@10(Val)=0.6974 | temp=0.0084 
[Diagnostics @ epoch 045]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004988984437659383, 'transfer_cosine': 0.6168460845947266, 'model_cosine': 0.10567040741443634, 'improvement': 0.5111756771802902}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006537790759466588, 'mse_full': 0.00024762985412962735, 'gap_percent': 164.00802013250754, 'cos_orthogonal': 0.4978977143764496, 'cos_full': 0.7889490127563477}
  mutual_knn_alignment_fraction: 0.1152


Epoch 046/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 046: MainLoss=1.4507 KL=0.0000 Distill=0.0426 (LR=7.3E-06) | R@10(Val)=0.6974 | temp=0.0084 


Epoch 047/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 047: MainLoss=1.4489 KL=0.0000 Distill=0.0426 (LR=4.7E-06) | R@10(Val)=0.6972 | temp=0.0084 


Epoch 048/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 048: MainLoss=1.4480 KL=0.0000 Distill=0.0426 (LR=2.7E-06) | R@10(Val)=0.6974 | temp=0.0084 


Epoch 049/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 049: MainLoss=1.4473 KL=0.0000 Distill=0.0426 (LR=1.2E-06) | R@10(Val)=0.6974 | temp=0.0084 


Epoch 050/50:   0%|          | 0/220 [00:00<?, ?it/s]

CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 050: MainLoss=1.4512 KL=0.0000 Distill=0.0426 (LR=3.0E-07) | R@10(Val)=0.6974 | temp=0.0084 
[Diagnostics @ epoch 050]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004988984437659383, 'transfer_cosine': 0.6168460845947266, 'model_cosine': 0.10564102232456207, 'improvement': 0.5112050622701645}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006537790759466588, 'mse_full': 0.00024762985412962735, 'gap_percent': 164.00802013250754, 'cos_orthogonal': 0.4978977143764496, 'cos_full': 0.7889490127563477}
  mutual_knn_alignment_fraction: 0.1152

ðŸŽ¯ Training complete. Best model (R@10=0.6976) saved as giga_vae_model_giga_vae_model_1762442313_best.pth
Loading BEST model from giga_vae_model_giga_vae_model_1762442313_best.pth for inference...
Initializing Giga-VAE (hidden_dim=6144, latent_dim=1024, use_reparam=False)
Best model loaded.


Generating Sub:   0%|          | 0/2 [00:00<?, ?it/s]


âœ… submission_giga_vae_run_giga_vae_model_1762442313.csv saved successfully.
Calculating final validation scores on full val set...


CSLS (Pass 1/2):   0%|          | 0/49 [00:00<?, ?it/s]

CSLS (Pass 2/2):   0%|          | 0/13 [00:00<?, ?it/s]


--- Full Validation Results ---
Validation Recall@1:  0.3706  
Validation Recall@5:  0.6118  
Validation Recall@10: 0.7010  
Validation Recall@50: 0.8655  


In [14]:
import os, warnings
os.environ["PYTHONWARNINGS"] = "ignore"   # blanket: don't show warnings
warnings.filterwarnings("ignore")         # extra belt-and-suspenders
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", message=".*autocast.*deprecated.*")


In [3]:
# ===========================
# Giga-VAE â€” Retrieval-Focused (LB-parity)
# ===========================
import os, time, warnings
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

# (Optional) silence widget warnings in some notebooks
warnings.filterwarnings("ignore", message=".*Error displaying widget.*")

# -----------------------------
# 0) Device & basic setup
# -----------------------------
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# -----------------------------
# 1) Load data
# -----------------------------
tx_train = train_data["captions/embeddings"]      # (N_tx, 1024)
im_train = train_data["images/embeddings"]        # (N_im, 1536)
tx_test  = test_data["captions/embeddings"]       # (N_test, 1024)

# Expand images 5x to align with captions count
repeat_factor = len(tx_train) // len(im_train)
im_train_expanded = np.repeat(im_train, repeat_factor, axis=0)
print(f"Train shapes: {tx_train.shape}, {im_train.shape}, {im_train_expanded.shape}")
print(f"Test shape: {tx_test.shape}")

# -----------------------------
# 2) Preprocess (mean-center + L2 norm)
# -----------------------------
cpu = torch.device("cpu")
tx_train_t = torch.as_tensor(tx_train, dtype=torch.float32, device=cpu)
im_train_t_unique = torch.as_tensor(im_train, dtype=torch.float32, device=cpu)
tx_test_t  = torch.as_tensor(tx_test,  dtype=torch.float32, device=cpu)

tx_mean = tx_train_t.mean(0, keepdim=True)
im_mean = im_train_t_unique.mean(0, keepdim=True)

tx_train_t = F.normalize(tx_train_t - tx_mean, p=2, dim=1)
im_train_t_unique = F.normalize(im_train_t_unique - im_mean, p=2, dim=1)
tx_test_t  = F.normalize(tx_test_t  - tx_mean, p=2, dim=1)

im_train_exp = torch.as_tensor(im_train_expanded, dtype=torch.float32, device=cpu)
im_train_exp = F.normalize(im_train_exp - im_mean, p=2, dim=1)
print("Data preprocessed and normalized (on CPU).")

# -----------------------------
# 3) Model (VAETranslator)
# -----------------------------
class VAETranslator(nn.Module):
    def __init__(self, input_dim=1024, hidden_dim=8192, latent_dim=1536, output_dim=1536, dropout=0.4):
        super().__init__()
        self.latent_dim = latent_dim
        print(f"Initializing Giga-VAE (hidden_dim={hidden_dim}, latent_dim={latent_dim})")

        # Encoder: text -> (mu, logvar)
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),

            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),

            nn.Linear(hidden_dim, latent_dim * 2)
        )

        # Decoder: latent -> image embedding
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),

            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),

            nn.Linear(hidden_dim, output_dim)
        )

    @staticmethod
    def _reparameterize(mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        # Encode
        mu_logvar = self.encoder(x)
        mu     = mu_logvar[..., :self.latent_dim]
        logvar = mu_logvar[..., self.latent_dim:]

        # Train: stochastic sample; Eval: deterministic (mu)
        if self.training:
            z = self._reparameterize(mu, logvar)
            y = self.decoder(z)
            return F.normalize(y, p=2, dim=1), mu, logvar
        else:
            y_eval = self.decoder(mu)
            return F.normalize(y_eval, p=2, dim=1)

# -----------------------------
# 4) Loss knobs & utilities
# -----------------------------
# Keep identical to your stronger run, but:
# - use Symmetric InfoNCE
# - learnable temperature with a reasonable cap (scale<=50)
TAU_INIT = 0.01        # for initializing temperature
MARGIN = 0.06
LOSS_WEIGHT_CONTRASTIVE = 0.7
LOSS_WEIGHT_TRIPLET = 0.3
MIXUP_ALPHA = 0.3
MIXUP_WARMUP_EPOCHS = 3
KL_WEIGHT = 1e-5       # light VAE regularization (same as your better run)
SYMMETRIC_INFO_NCE = True
MAX_LOGIT_SCALE = 50.0 # <-- important: avoid over-peaky logits (hurts LB)

triplet_loss_fn = nn.TripletMarginWithDistanceLoss(
    distance_function=lambda x, y: 1 - F.cosine_similarity(x, y),
    margin=MARGIN
)
print(f"Using CHAMPION ASYMMETRIC loss + *** VAE KL Loss (weight={KL_WEIGHT}) ***")

def kl_loss_fn(mu, logvar):
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return kl / mu.size(0)

# -----------------------------
# 5) Split train/val
# -----------------------------
N = len(tx_train_t)
val_size = int(0.1 * N)
perm = torch.randperm(N, device=cpu)
val_idx, train_idx = perm[:val_size], perm[val_size:]

img_indices_train = (train_idx // 5)  # indices into unique images
img_indices_val   = (val_idx // 5)

tx_val_t  = tx_train_t[val_idx]
im_val_t  = im_train_exp[val_idx]
tx_train_sub = tx_train_t[train_idx]
im_train_sub = im_train_exp[train_idx]

train_ds = TensorDataset(tx_train_sub, im_train_sub, img_indices_train)
train_loader = DataLoader(train_ds, batch_size=512, shuffle=True, num_workers=0)
print(f"Train pairs: {len(train_ds)}, Validation pairs: {len(val_idx)}")

# -----------------------------
# 6) Recall@K evaluators
# -----------------------------
@torch.no_grad()
def recall_at_k_cosine(model, tx_queries, im_database, gt_img_idx, chunk_size=1024, ks=(1,5,10,50)):
    model.eval()
    dev = next(model.parameters()).device
    # Encode queries
    Q = []
    for i in range(0, len(tx_queries), chunk_size):
        Q.append(model(tx_queries[i:i+chunk_size].to(dev)))
    Q = torch.cat(Q, dim=0)  # (Q,d)
    D = im_database.to(dev)  # (I,d)

    all_correct = {k: [] for k in ks}
    # Cosine sim is dot product since both are L2-normalized
    sim = Q @ D.T  # (Q, I)
    top_idx = torch.argsort(sim, dim=1, descending=True)  # (Q, I)
    gt = gt_img_idx.to(dev)
    for k in ks:
        top_k = top_idx[:, :k]
        correct = (top_k == gt.unsqueeze(1)).any(dim=1)
        all_correct[k].append(correct)
    return {f"R@{k}": torch.cat(all_correct[k]).float().mean().item() for k in ks}

@torch.no_grad()
def recall_at_k_csls(model, tx_queries, im_database, gt_img_idx, chunk_size=1024, ks=(1,5,10,50), k_csls=10):
    model.eval()
    dev = next(model.parameters()).device
    # Encode queries
    Q = []
    for i in range(0, len(tx_queries), chunk_size):
        Q.append(model(tx_queries[i:i+chunk_size].to(dev)))
    Q = torch.cat(Q, dim=0)  # (Q,d)
    D = im_database.to(dev)  # (I,d)

    # CSLS components
    mean_knn_d_list = []
    for i in tqdm(range(0, len(D), 512), desc="CSLS (Pass 1/2)", leave=False):
        sim_T_chunk = D[i:i+512] @ Q.T
        knn_d_chunk = torch.topk(sim_T_chunk, k=k_csls, dim=1).values
        mean_knn_d_list.append(knn_d_chunk.mean(1))
    mean_knn_d = torch.cat(mean_knn_d_list, dim=0).unsqueeze(0)

    all_correct = {k: [] for k in ks}
    gt = gt_img_idx.to(dev)
    for i in tqdm(range(0, len(Q), chunk_size), desc="CSLS (Pass 2/2)", leave=False):
        sim_chunk = Q[i:i+chunk_size] @ D.T
        gt_chunk = gt[i:i+chunk_size]
        knn_q_chunk = torch.topk(sim_chunk, k=k_csls, dim=1).values
        mean_knn_q = knn_q_chunk.mean(1, keepdim=True)
        csls = 2 * sim_chunk - mean_knn_q - mean_knn_d
        top_idx = torch.argsort(csls, dim=1, descending=True)
        for k in ks:
            top_k = top_idx[:, :k]
            correct = (top_k == gt_chunk.unsqueeze(1)).any(dim=1)
            all_correct[k].append(correct)
    return {f"R@{k}": torch.cat(all_correct[k]).float().mean().item() for k in ks}

# -----------------------------
# 7) Diagnostics (unchanged)
# -----------------------------
class DiagnosticMetrics:
    @staticmethod
    @torch.no_grad()
    def emb2emb_linear_transfer(src_emb, tgt_emb, model_pred_emb=None, alpha=1.0):
        from sklearn.linear_model import Ridge
        src_np = (src_emb.detach().cpu().numpy()
                  if isinstance(src_emb, torch.Tensor) else src_emb)
        tgt_np = (tgt_emb.detach().cpu().numpy()
                  if isinstance(tgt_emb, torch.Tensor) else tgt_emb)
        ridge = Ridge(alpha=alpha, fit_intercept=False)
        ridge.fit(src_np, tgt_np)
        pred = ridge.predict(src_np)
        pred_t = torch.from_numpy(pred).float()

        if not isinstance(tgt_emb, torch.Tensor):
            tgt_emb = torch.from_numpy(tgt_np).float()
        tgt_t = tgt_emb.float()

        pred_t = F.normalize(pred_t, p=2, dim=1)
        tgt_norm = F.normalize(tgt_t, p=2, dim=1)
        transfer_mse = F.mse_loss(pred_t, tgt_norm).item()
        transfer_cosine = F.cosine_similarity(pred_t, tgt_norm).mean().item()

        if model_pred_emb is not None:
            if not isinstance(model_pred_emb, torch.Tensor):
                model_pred_emb = torch.from_numpy(model_pred_emb).float()
            model_cos = F.cosine_similarity(
                F.normalize(model_pred_emb.float(), p=2, dim=1), tgt_norm
            ).mean().item()
            improvement = transfer_cosine - model_cos
            baseline_name = 'model_cosine'
        else:
            model_cos = 0.0
            improvement = transfer_cosine
            baseline_name = 'random_baseline'
        return {'transfer_mse': transfer_mse,
                'transfer_cosine': transfer_cosine,
                baseline_name: model_cos,
                'improvement': improvement}

    @staticmethod
    @torch.no_grad()
    def stitching_penalty_monitor(src_emb, tgt_emb):
        # rectangular Procrustes via SVD (handles 1024 vs 1536 dims)
        src_np = src_emb.detach().cpu().numpy()
        tgt_np = tgt_emb.detach().cpu().numpy()
        C = src_np.T @ tgt_np
        U, _, Vt = np.linalg.svd(C, full_matrices=False)
        R_rect = U @ Vt
        pred_ortho = src_np @ R_rect

        from sklearn.linear_model import Ridge
        ridge = Ridge(alpha=0.01, fit_intercept=False)
        ridge.fit(src_np, tgt_np)
        pred_full = ridge.predict(src_np)

        mse_ortho = float(np.mean((pred_ortho - tgt_np) ** 2))
        mse_full  = float(np.mean((pred_full  - tgt_np) ** 2))

        pred_ortho_t = F.normalize(torch.from_numpy(pred_ortho).float(), p=2, dim=1)
        pred_full_t  = F.normalize(torch.from_numpy(pred_full ).float(), p=2, dim=1)
        tgt_norm     = F.normalize(torch.from_numpy(tgt_np   ).float(), p=2, dim=1)

        cos_ortho = F.cosine_similarity(pred_ortho_t, tgt_norm).mean().item()
        cos_full  = F.cosine_similarity(pred_full_t,  tgt_norm).mean().item()
        gap_pct = 100 * (mse_ortho - mse_full) / (mse_full + 1e-8)

        return {'mse_orthogonal': mse_ortho,
                'mse_full': mse_full,
                'gap_percent': float(gap_pct),
                'cos_orthogonal': cos_ortho,
                'cos_full': cos_full}

    @staticmethod
    @torch.no_grad()
    def mutual_knn_alignment(emb1, emb2, k=10):
        sim1 = emb1 @ emb1.T
        sim2 = emb2 @ emb2.T
        _, knn1 = torch.topk(sim1, k=k+1, dim=1)
        _, knn2 = torch.topk(sim2, k=k+1, dim=1)
        knn1 = knn1[:, 1:]
        knn2 = knn2[:, 1:]
        mutual_count = 0
        for i in range(len(emb1)):
            set1 = set(knn1[i].cpu().tolist())
            set2 = set(knn2[i].cpu().tolist())
            mutual_count += len(set1 & set2)
        return mutual_count / (len(emb1) * k)


# -----------------------------
# 8) Train setup
# -----------------------------
EPOCHS = 60
START_LR = 3e-4
WEIGHT_DECAY = 4e-4
RUN_ID = f"giga_vae_model_{int(time.time())}"
FINAL_SAVE_PATH = f"giga_vae_model_{RUN_ID}_best.pth"

val_query_subset   = tx_val_t[:5000]
val_indices_subset = img_indices_val[:5000]
val_db_subset      = im_train_t_unique  # use unique images for DB

model = VAETranslator(dropout=0.4, hidden_dim=8192, latent_dim=1536, output_dim=1536).to(device)

# Learnable temperature (logit scale) initialized from TAU_INIT
logit_scale = nn.Parameter(torch.log(torch.tensor(1.0 / TAU_INIT, device=device)))

optimizer = torch.optim.AdamW(
    list(model.parameters()) + [logit_scale],
    lr=START_LR, weight_decay=WEIGHT_DECAY
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
scaler = torch.amp.GradScaler('cuda')

print(f"\n--- Starting Training (RUN {RUN_ID}) ---")
print(f"Epochs={EPOCHS}, LR={START_LR}, WD={WEIGHT_DECAY}, SymInfoNCE={SYMMETRIC_INFO_NCE}, MixupWarmup={MIXUP_WARMUP_EPOCHS}â†’Î±={MIXUP_ALPHA}\n")

best_val_cos_r50 = 0.0  # early stop based on cosine R@50 (LB-like)
for epoch in range(1, EPOCHS + 1):
    model.train()
    total_recon, total_kl = 0.0, 0.0

    for x_batch, y_batch, _img_idx in tqdm(train_loader, desc=f"Epoch {epoch:03d}/{EPOCHS}", leave=False):
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        # Mixup warmup
        use_mixup = (epoch > MIXUP_WARMUP_EPOCHS)
        if use_mixup:
            idx_shuffle = torch.randperm(x_batch.size(0), device=x_batch.device)
            lam = np.random.beta(MIXUP_ALPHA, MIXUP_ALPHA)
            x_mix = lam * x_batch + (1 - lam) * x_batch[idx_shuffle]
            y_mix = F.normalize(lam * y_batch + (1 - lam) * y_batch[idx_shuffle], p=2, dim=1)
        else:
            x_mix, y_mix = x_batch, y_batch

        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast('cuda'):
            # Forward (VAE)
            y_pred, mu, logvar = model(x_mix)   # (B,d)

            # Learnable temperature with a safe cap on scale
            scale = logit_scale.exp()
            scale = torch.clamp(scale, max=torch.tensor(MAX_LOGIT_SCALE, device=scale.device))
            logits = (y_pred @ y_mix.T) * scale

            labels = torch.arange(y_pred.size(0), device=device)

            # InfoNCE i->t
            loss_i2t = F.cross_entropy(logits, labels)
            # InfoNCE t->i (symmetric)
            if SYMMETRIC_INFO_NCE:
                loss_t2i = F.cross_entropy(logits.T, labels)
                loss_con = 0.5 * (loss_i2t + loss_t2i)
            else:
                loss_con = loss_i2t

            # Triplet (hard negatives within batch)
            with torch.no_grad():
                sims_no_tau = y_pred @ y_mix.T
                pos_mask = torch.eye(y_batch.size(0), dtype=torch.bool, device=device)
                sims_no_tau.masked_fill_(pos_mask, -float('inf'))
                hard_neg_idx = sims_no_tau.argmax(dim=1)
            y_hard_neg = y_mix[hard_neg_idx]
            loss_tri = triplet_loss_fn(y_pred, y_mix, y_hard_neg)

            recon_loss = (LOSS_WEIGHT_CONTRASTIVE * loss_con) + (LOSS_WEIGHT_TRIPLET * loss_tri)

            # KL (light)
            kl = kl_loss_fn(mu, logvar)
            loss = recon_loss + (KL_WEIGHT * kl)

        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

        total_recon += recon_loss.item()
        total_kl += kl.item()

    current_lr = scheduler.get_last_lr()[0]
    scheduler.step()

    avg_recon = total_recon / len(train_loader)
    avg_kl = total_kl / len(train_loader)
    current_scale = float(scale.item())

    # ----- Validation -----
    # 1) Cosine (LB-like) on subset â€” used for early stopping
    rec_cos = recall_at_k_cosine(model, val_query_subset, val_db_subset, val_indices_subset, chunk_size=1024)
    # 2) CSLS on subset â€” just to monitor
    rec_csls = recall_at_k_csls(model, val_query_subset, val_db_subset, val_indices_subset, chunk_size=1024)

    # Early stop key (cosine R@50)
    cos_r50 = rec_cos['R@50']
    save_marker = ""
    if cos_r50 > best_val_cos_r50:
        best_val_cos_r50 = cos_r50
        torch.save(model.state_dict(), FINAL_SAVE_PATH)
        save_marker = "<- Best COS R@50! Saved."

    print(
        f"Epoch {epoch:03d}: ReconLoss={avg_recon:.4f} KL={avg_kl:.2f} "
        f"(LR={current_lr:.1E}) | temp_scale={current_scale:.2f} | "
        f"COS R@10={rec_cos['R@10']:.4f} R@50={rec_cos['R@50']:.4f} | "
        f"CSLS R@10={rec_csls['R@10']:.4f} R@50={rec_csls['R@50']:.4f} {save_marker}"
    )

    # Diagnostics (every 5 epochs + first)
    if epoch == 1 or (epoch % 5 == 0):
        with torch.no_grad():
            model.eval()
            diag_n = min(2000, len(val_query_subset))
            tx_sample = val_query_subset[:diag_n].to(device)
            tgt_sample = val_db_subset[val_indices_subset[:diag_n]].to(device)
            pred_sample = model(tx_sample)

            tx_cpu = tx_sample.detach().cpu()
            tgt_cpu = tgt_sample.detach().cpu()
            pred_cpu = pred_sample.detach().cpu()

            emb2emb = DiagnosticMetrics.emb2emb_linear_transfer(tx_cpu, tgt_cpu, model_pred_emb=pred_cpu, alpha=1.0)
            stitch = DiagnosticMetrics.stitching_penalty_monitor(tx_cpu, tgt_cpu)
            mk_align = DiagnosticMetrics.mutual_knn_alignment(
                F.normalize(tx_cpu, p=2, dim=1),
                F.normalize(tgt_cpu, p=2, dim=1),
                k=10
            )
            print(f"[Diagnostics @ epoch {epoch:03d}]")
            print("  emb2emb_linear_transfer:", emb2emb)
            print("  stitching_penalty_monitor:", stitch)
            print(f"  mutual_knn_alignment_fraction: {mk_align:.4f}")

print(f"\nðŸŽ¯ Training complete. Best model (COS R@50={best_val_cos_r50:.4f}) saved as {FINAL_SAVE_PATH}")

# -----------------------------
# 9) Inference + Submission
# -----------------------------
print(f"Loading BEST model from {FINAL_SAVE_PATH} for inference...")
model = VAETranslator(dropout=0.4, hidden_dim=8192, latent_dim=1536, output_dim=1536).to(device)
model.load_state_dict(torch.load(FINAL_SAVE_PATH, map_location=device))
model.eval()
print("Best model loaded.")

with torch.no_grad():
    preds_list = []
    for i in tqdm(range(0, len(tx_test_t), 1024), desc="Generating Sub"):
        preds_list.append(model(tx_test_t[i:i+1024].to(device)))
    preds = torch.cat(preds_list, dim=0).cpu().numpy()

test_ids = test_data["captions/ids"].astype(int)
submission = pd.DataFrame({
    "id": test_ids,
    "embedding": [list(map(float, row)) for row in preds]
})
out_csv = f"submission_giga_vae_run_{RUN_ID}.csv"
submission.to_csv(out_csv, index=False)
print(f"\nâœ… {out_csv} saved successfully.")

# -----------------------------
# 10) Final validation (full set)
# -----------------------------
print("Calculating final validation scores (cosine & CSLS) on FULL val set...")

rec_cos_full  = recall_at_k_cosine(model, tx_val_t, im_train_t_unique, img_indices_val, chunk_size=1024)
rec_csls_full = recall_at_k_csls(model,   tx_val_t, im_train_t_unique, img_indices_val, chunk_size=1024)

print("\n--- Full Validation Results (COSINE) ---")
print(f"R@1:  {rec_cos_full['R@1']:<8.4f}")
print(f"R@5:  {rec_cos_full['R@5']:<8.4f}")
print(f"R@10: {rec_cos_full['R@10']:<8.4f}")
print(f"R@50: {rec_cos_full['R@50']:<8.4f}")

print("\n--- Full Validation Results (CSLS) ---")
print(f"R@1:  {rec_csls_full['R@1']:<8.4f}")
print(f"R@5:  {rec_csls_full['R@5']:<8.4f}")
print(f"R@10: {rec_csls_full['R@10']:<8.4f}")
print(f"R@50: {rec_csls_full['R@50']:<8.4f}")



Using device: cuda
Train shapes: (125000, 1024), (25000, 1536), (125000, 1536)
Test shape: (1500, 1024)
Data preprocessed and normalized (on CPU).
Using CHAMPION ASYMMETRIC loss + *** VAE KL Loss (weight=1e-05) ***
Train pairs: 112500, Validation pairs: 12500
Initializing Giga-VAE (hidden_dim=8192, latent_dim=1536)

--- Starting Training (RUN giga_vae_model_1762505842) ---
Epochs=60, LR=0.0003, WD=0.0004, SymInfoNCE=True, MixupWarmup=3â†’Î±=0.3



                                                               

Epoch 001: ReconLoss=2.2446 KL=1267.33 (LR=3.0E-04) | temp_scale=50.00 | COS R@10=0.2992 R@50=0.5788 | CSLS R@10=0.3664 R@50=0.6396 <- Best COS R@50! Saved.
[Diagnostics @ epoch 001]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004931012517772615, 'transfer_cosine': 0.6212983131408691, 'model_cosine': 0.23722252249717712, 'improvement': 0.384075790643692}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006547736120410264, 'mse_full': 0.00024529395159333944, 'gap_percent': 166.92746194586994, 'cos_orthogonal': 0.49713391065597534, 'cos_full': 0.7914635539054871}
  mutual_knn_alignment_fraction: 0.1121


                                                               

Epoch 002: ReconLoss=1.3479 KL=1820.64 (LR=3.0E-04) | temp_scale=50.00 | COS R@10=0.3842 R@50=0.6662 | CSLS R@10=0.4506 R@50=0.7220 <- Best COS R@50! Saved.


                                                               

Epoch 003: ReconLoss=1.1067 KL=2055.36 (LR=3.0E-04) | temp_scale=50.00 | COS R@10=0.4392 R@50=0.7134 | CSLS R@10=0.4986 R@50=0.7572 <- Best COS R@50! Saved.


                                                               

Epoch 004: ReconLoss=0.9274 KL=2184.35 (LR=3.0E-04) | temp_scale=50.00 | COS R@10=0.4646 R@50=0.7330 | CSLS R@10=0.5222 R@50=0.7704 <- Best COS R@50! Saved.


                                                               

Epoch 005: ReconLoss=0.8230 KL=2266.38 (LR=3.0E-04) | temp_scale=50.00 | COS R@10=0.4882 R@50=0.7514 | CSLS R@10=0.5384 R@50=0.7888 <- Best COS R@50! Saved.
[Diagnostics @ epoch 005]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004931012517772615, 'transfer_cosine': 0.6212983131408691, 'model_cosine': 0.2600101828575134, 'improvement': 0.3612881302833557}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006547736120410264, 'mse_full': 0.00024529395159333944, 'gap_percent': 166.92746194586994, 'cos_orthogonal': 0.49713391065597534, 'cos_full': 0.7914635539054871}
  mutual_knn_alignment_fraction: 0.1121


                                                               

Epoch 006: ReconLoss=0.7487 KL=2365.08 (LR=2.9E-04) | temp_scale=50.00 | COS R@10=0.5008 R@50=0.7698 | CSLS R@10=0.5590 R@50=0.8016 <- Best COS R@50! Saved.


                                                               

Epoch 007: ReconLoss=0.6837 KL=2423.70 (LR=2.9E-04) | temp_scale=50.00 | COS R@10=0.5222 R@50=0.7802 | CSLS R@10=0.5734 R@50=0.8078 <- Best COS R@50! Saved.


                                                               

Epoch 008: ReconLoss=0.6415 KL=2494.48 (LR=2.9E-04) | temp_scale=50.00 | COS R@10=0.5346 R@50=0.7902 | CSLS R@10=0.5862 R@50=0.8152 <- Best COS R@50! Saved.


                                                               

Epoch 009: ReconLoss=0.5950 KL=2554.84 (LR=2.9E-04) | temp_scale=50.00 | COS R@10=0.5474 R@50=0.7982 | CSLS R@10=0.5952 R@50=0.8226 <- Best COS R@50! Saved.


                                                               

Epoch 010: ReconLoss=0.5300 KL=2598.86 (LR=2.8E-04) | temp_scale=50.00 | COS R@10=0.5700 R@50=0.8128 | CSLS R@10=0.6130 R@50=0.8354 <- Best COS R@50! Saved.
[Diagnostics @ epoch 010]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004931012517772615, 'transfer_cosine': 0.6212983131408691, 'model_cosine': 0.27196529507637024, 'improvement': 0.3493330180644989}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006547736120410264, 'mse_full': 0.00024529395159333944, 'gap_percent': 166.92746194586994, 'cos_orthogonal': 0.49713391065597534, 'cos_full': 0.7914635539054871}
  mutual_knn_alignment_fraction: 0.1121


                                                               

Epoch 011: ReconLoss=0.4929 KL=2654.20 (LR=2.8E-04) | temp_scale=50.00 | COS R@10=0.5724 R@50=0.8178 | CSLS R@10=0.6158 R@50=0.8366 <- Best COS R@50! Saved.


                                                               

Epoch 012: ReconLoss=0.4707 KL=2658.54 (LR=2.8E-04) | temp_scale=50.00 | COS R@10=0.5802 R@50=0.8220 | CSLS R@10=0.6238 R@50=0.8390 <- Best COS R@50! Saved.


                                                               

Epoch 013: ReconLoss=0.4456 KL=2685.48 (LR=2.7E-04) | temp_scale=50.00 | COS R@10=0.5874 R@50=0.8208 | CSLS R@10=0.6288 R@50=0.8398 


                                                               

Epoch 014: ReconLoss=0.4319 KL=2693.87 (LR=2.7E-04) | temp_scale=50.00 | COS R@10=0.5896 R@50=0.8260 | CSLS R@10=0.6286 R@50=0.8440 <- Best COS R@50! Saved.


                                                               

Epoch 015: ReconLoss=0.4202 KL=2720.62 (LR=2.6E-04) | temp_scale=50.00 | COS R@10=0.5996 R@50=0.8270 | CSLS R@10=0.6354 R@50=0.8446 <- Best COS R@50! Saved.
[Diagnostics @ epoch 015]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004931012517772615, 'transfer_cosine': 0.6212983131408691, 'model_cosine': 0.27919232845306396, 'improvement': 0.3421059846878052}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006547736120410264, 'mse_full': 0.00024529395159333944, 'gap_percent': 166.92746194586994, 'cos_orthogonal': 0.49713391065597534, 'cos_full': 0.7914635539054871}
  mutual_knn_alignment_fraction: 0.1121


                                                               

Epoch 016: ReconLoss=0.4020 KL=2742.55 (LR=2.6E-04) | temp_scale=50.00 | COS R@10=0.6054 R@50=0.8266 | CSLS R@10=0.6408 R@50=0.8460 


                                                               

Epoch 017: ReconLoss=0.3859 KL=2765.41 (LR=2.5E-04) | temp_scale=50.00 | COS R@10=0.6060 R@50=0.8302 | CSLS R@10=0.6386 R@50=0.8464 <- Best COS R@50! Saved.


                                                               

Epoch 018: ReconLoss=0.3663 KL=2783.01 (LR=2.4E-04) | temp_scale=50.00 | COS R@10=0.6086 R@50=0.8318 | CSLS R@10=0.6442 R@50=0.8474 <- Best COS R@50! Saved.


                                                               

Epoch 019: ReconLoss=0.3475 KL=2783.82 (LR=2.4E-04) | temp_scale=50.00 | COS R@10=0.6128 R@50=0.8346 | CSLS R@10=0.6460 R@50=0.8502 <- Best COS R@50! Saved.


                                                               

Epoch 020: ReconLoss=0.3317 KL=2815.23 (LR=2.3E-04) | temp_scale=50.00 | COS R@10=0.6142 R@50=0.8354 | CSLS R@10=0.6512 R@50=0.8540 <- Best COS R@50! Saved.
[Diagnostics @ epoch 020]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004931012517772615, 'transfer_cosine': 0.6212983131408691, 'model_cosine': 0.28493013978004456, 'improvement': 0.3363681733608246}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006547736120410264, 'mse_full': 0.00024529395159333944, 'gap_percent': 166.92746194586994, 'cos_orthogonal': 0.49713391065597534, 'cos_full': 0.7914635539054871}
  mutual_knn_alignment_fraction: 0.1121


                                                               

Epoch 021: ReconLoss=0.3244 KL=2809.60 (LR=2.2E-04) | temp_scale=50.00 | COS R@10=0.6160 R@50=0.8376 | CSLS R@10=0.6480 R@50=0.8550 <- Best COS R@50! Saved.


                                                               

Epoch 022: ReconLoss=0.3189 KL=2829.47 (LR=2.2E-04) | temp_scale=50.00 | COS R@10=0.6218 R@50=0.8388 | CSLS R@10=0.6502 R@50=0.8554 <- Best COS R@50! Saved.


                                                               

Epoch 023: ReconLoss=0.3146 KL=2842.23 (LR=2.1E-04) | temp_scale=50.00 | COS R@10=0.6208 R@50=0.8370 | CSLS R@10=0.6518 R@50=0.8564 


                                                               

Epoch 024: ReconLoss=0.3074 KL=2828.30 (LR=2.0E-04) | temp_scale=50.00 | COS R@10=0.6268 R@50=0.8394 | CSLS R@10=0.6562 R@50=0.8580 <- Best COS R@50! Saved.


                                                               

Epoch 025: ReconLoss=0.3006 KL=2826.42 (LR=2.0E-04) | temp_scale=50.00 | COS R@10=0.6250 R@50=0.8414 | CSLS R@10=0.6580 R@50=0.8578 <- Best COS R@50! Saved.
[Diagnostics @ epoch 025]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004931012517772615, 'transfer_cosine': 0.6212983131408691, 'model_cosine': 0.2878269851207733, 'improvement': 0.3334713280200958}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006547736120410264, 'mse_full': 0.00024529395159333944, 'gap_percent': 166.92746194586994, 'cos_orthogonal': 0.49713391065597534, 'cos_full': 0.7914635539054871}
  mutual_knn_alignment_fraction: 0.1121


                                                               

Epoch 026: ReconLoss=0.2951 KL=2836.11 (LR=1.9E-04) | temp_scale=50.00 | COS R@10=0.6272 R@50=0.8410 | CSLS R@10=0.6612 R@50=0.8576 


                                                               

Epoch 027: ReconLoss=0.2902 KL=2848.69 (LR=1.8E-04) | temp_scale=50.00 | COS R@10=0.6278 R@50=0.8408 | CSLS R@10=0.6602 R@50=0.8580 


                                                               

Epoch 028: ReconLoss=0.2847 KL=2845.76 (LR=1.7E-04) | temp_scale=50.00 | COS R@10=0.6292 R@50=0.8440 | CSLS R@10=0.6602 R@50=0.8584 <- Best COS R@50! Saved.


                                                               

Epoch 029: ReconLoss=0.2778 KL=2838.43 (LR=1.7E-04) | temp_scale=50.00 | COS R@10=0.6296 R@50=0.8428 | CSLS R@10=0.6608 R@50=0.8582 


                                                               

Epoch 030: ReconLoss=0.2750 KL=2847.14 (LR=1.6E-04) | temp_scale=50.00 | COS R@10=0.6318 R@50=0.8448 | CSLS R@10=0.6614 R@50=0.8574 <- Best COS R@50! Saved.
[Diagnostics @ epoch 030]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004931012517772615, 'transfer_cosine': 0.6212983131408691, 'model_cosine': 0.28772658109664917, 'improvement': 0.33357173204421997}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006547736120410264, 'mse_full': 0.00024529395159333944, 'gap_percent': 166.92746194586994, 'cos_orthogonal': 0.49713391065597534, 'cos_full': 0.7914635539054871}
  mutual_knn_alignment_fraction: 0.1121


                                                               

Epoch 031: ReconLoss=0.2751 KL=2857.30 (LR=1.5E-04) | temp_scale=50.00 | COS R@10=0.6310 R@50=0.8420 | CSLS R@10=0.6632 R@50=0.8600 


                                                               

Epoch 032: ReconLoss=0.2724 KL=2857.60 (LR=1.4E-04) | temp_scale=50.00 | COS R@10=0.6342 R@50=0.8438 | CSLS R@10=0.6630 R@50=0.8588 


                                                               

Epoch 033: ReconLoss=0.2672 KL=2862.01 (LR=1.3E-04) | temp_scale=50.00 | COS R@10=0.6372 R@50=0.8438 | CSLS R@10=0.6624 R@50=0.8598 


                                                               

Epoch 034: ReconLoss=0.2666 KL=2860.44 (LR=1.3E-04) | temp_scale=50.00 | COS R@10=0.6380 R@50=0.8446 | CSLS R@10=0.6644 R@50=0.8600 


                                                               

Epoch 035: ReconLoss=0.2679 KL=2848.76 (LR=1.2E-04) | temp_scale=50.00 | COS R@10=0.6388 R@50=0.8432 | CSLS R@10=0.6664 R@50=0.8592 
[Diagnostics @ epoch 035]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004931012517772615, 'transfer_cosine': 0.6212983131408691, 'model_cosine': 0.2886382043361664, 'improvement': 0.33266010880470276}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006547736120410264, 'mse_full': 0.00024529395159333944, 'gap_percent': 166.92746194586994, 'cos_orthogonal': 0.49713391065597534, 'cos_full': 0.7914635539054871}
  mutual_knn_alignment_fraction: 0.1121


                                                               

Epoch 036: ReconLoss=0.2638 KL=2873.48 (LR=1.1E-04) | temp_scale=50.00 | COS R@10=0.6358 R@50=0.8446 | CSLS R@10=0.6642 R@50=0.8596 


                                                               

Epoch 037: ReconLoss=0.2638 KL=2859.84 (LR=1.0E-04) | temp_scale=50.00 | COS R@10=0.6378 R@50=0.8426 | CSLS R@10=0.6652 R@50=0.8596 


                                                               

Epoch 038: ReconLoss=0.2589 KL=2852.34 (LR=9.6E-05) | temp_scale=50.00 | COS R@10=0.6380 R@50=0.8444 | CSLS R@10=0.6658 R@50=0.8612 


                                                               

Epoch 039: ReconLoss=0.2583 KL=2862.84 (LR=8.9E-05) | temp_scale=50.00 | COS R@10=0.6358 R@50=0.8444 | CSLS R@10=0.6652 R@50=0.8596 


                                                               

Epoch 040: ReconLoss=0.2574 KL=2851.60 (LR=8.2E-05) | temp_scale=50.00 | COS R@10=0.6360 R@50=0.8456 | CSLS R@10=0.6658 R@50=0.8590 <- Best COS R@50! Saved.
[Diagnostics @ epoch 040]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004931012517772615, 'transfer_cosine': 0.6212983131408691, 'model_cosine': 0.28912270069122314, 'improvement': 0.332175612449646}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006547736120410264, 'mse_full': 0.00024529395159333944, 'gap_percent': 166.92746194586994, 'cos_orthogonal': 0.49713391065597534, 'cos_full': 0.7914635539054871}
  mutual_knn_alignment_fraction: 0.1121


                                                               

Epoch 041: ReconLoss=0.2581 KL=2862.47 (LR=7.5E-05) | temp_scale=50.00 | COS R@10=0.6380 R@50=0.8450 | CSLS R@10=0.6664 R@50=0.8596 


                                                               

Epoch 042: ReconLoss=0.2559 KL=2857.28 (LR=6.8E-05) | temp_scale=50.00 | COS R@10=0.6386 R@50=0.8468 | CSLS R@10=0.6648 R@50=0.8608 <- Best COS R@50! Saved.


                                                               

Epoch 043: ReconLoss=0.2535 KL=2852.78 (LR=6.2E-05) | temp_scale=50.00 | COS R@10=0.6382 R@50=0.8464 | CSLS R@10=0.6660 R@50=0.8606 


                                                               

Epoch 044: ReconLoss=0.2531 KL=2861.19 (LR=5.6E-05) | temp_scale=50.00 | COS R@10=0.6364 R@50=0.8448 | CSLS R@10=0.6668 R@50=0.8606 


                                                               

Epoch 045: ReconLoss=0.2538 KL=2873.38 (LR=5.0E-05) | temp_scale=50.00 | COS R@10=0.6388 R@50=0.8452 | CSLS R@10=0.6656 R@50=0.8610 
[Diagnostics @ epoch 045]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004931012517772615, 'transfer_cosine': 0.6212983131408691, 'model_cosine': 0.2900967299938202, 'improvement': 0.33120158314704895}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006547736120410264, 'mse_full': 0.00024529395159333944, 'gap_percent': 166.92746194586994, 'cos_orthogonal': 0.49713391065597534, 'cos_full': 0.7914635539054871}
  mutual_knn_alignment_fraction: 0.1121


                                                               

Epoch 046: ReconLoss=0.2526 KL=2854.74 (LR=4.4E-05) | temp_scale=50.00 | COS R@10=0.6384 R@50=0.8454 | CSLS R@10=0.6656 R@50=0.8604 


                                                               

Epoch 047: ReconLoss=0.2516 KL=2854.32 (LR=3.9E-05) | temp_scale=50.00 | COS R@10=0.6384 R@50=0.8468 | CSLS R@10=0.6668 R@50=0.8614 


                                                               

Epoch 048: ReconLoss=0.2508 KL=2853.07 (LR=3.3E-05) | temp_scale=50.00 | COS R@10=0.6378 R@50=0.8456 | CSLS R@10=0.6660 R@50=0.8610 


                                                               

Epoch 049: ReconLoss=0.2515 KL=2856.53 (LR=2.9E-05) | temp_scale=50.00 | COS R@10=0.6390 R@50=0.8458 | CSLS R@10=0.6658 R@50=0.8616 


                                                               

Epoch 050: ReconLoss=0.2485 KL=2863.71 (LR=2.4E-05) | temp_scale=50.00 | COS R@10=0.6398 R@50=0.8460 | CSLS R@10=0.6660 R@50=0.8612 
[Diagnostics @ epoch 050]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004931012517772615, 'transfer_cosine': 0.6212983131408691, 'model_cosine': 0.29009780287742615, 'improvement': 0.331200510263443}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006547736120410264, 'mse_full': 0.00024529395159333944, 'gap_percent': 166.92746194586994, 'cos_orthogonal': 0.49713391065597534, 'cos_full': 0.7914635539054871}
  mutual_knn_alignment_fraction: 0.1121


                                                               

Epoch 051: ReconLoss=0.2498 KL=2873.48 (LR=2.0E-05) | temp_scale=50.00 | COS R@10=0.6400 R@50=0.8460 | CSLS R@10=0.6656 R@50=0.8616 


                                                               

Epoch 052: ReconLoss=0.2483 KL=2870.58 (LR=1.6E-05) | temp_scale=50.00 | COS R@10=0.6394 R@50=0.8462 | CSLS R@10=0.6660 R@50=0.8618 


                                                               

Epoch 053: ReconLoss=0.2491 KL=2877.59 (LR=1.3E-05) | temp_scale=50.00 | COS R@10=0.6388 R@50=0.8458 | CSLS R@10=0.6670 R@50=0.8610 


                                                               

Epoch 054: ReconLoss=0.2484 KL=2875.00 (LR=1.0E-05) | temp_scale=50.00 | COS R@10=0.6390 R@50=0.8462 | CSLS R@10=0.6672 R@50=0.8608 


                                                               

Epoch 055: ReconLoss=0.2483 KL=2858.39 (LR=7.3E-06) | temp_scale=50.00 | COS R@10=0.6386 R@50=0.8456 | CSLS R@10=0.6670 R@50=0.8612 
[Diagnostics @ epoch 055]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004931012517772615, 'transfer_cosine': 0.6212983131408691, 'model_cosine': 0.29018136858940125, 'improvement': 0.3311169445514679}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006547736120410264, 'mse_full': 0.00024529395159333944, 'gap_percent': 166.92746194586994, 'cos_orthogonal': 0.49713391065597534, 'cos_full': 0.7914635539054871}
  mutual_knn_alignment_fraction: 0.1121


                                                               

Epoch 056: ReconLoss=0.2470 KL=2850.92 (LR=5.1E-06) | temp_scale=50.00 | COS R@10=0.6392 R@50=0.8458 | CSLS R@10=0.6672 R@50=0.8608 


                                                               

Epoch 057: ReconLoss=0.2483 KL=2861.25 (LR=3.3E-06) | temp_scale=50.00 | COS R@10=0.6394 R@50=0.8462 | CSLS R@10=0.6672 R@50=0.8610 


                                                               

Epoch 058: ReconLoss=0.2480 KL=2869.02 (LR=1.8E-06) | temp_scale=50.00 | COS R@10=0.6396 R@50=0.8460 | CSLS R@10=0.6674 R@50=0.8610 


                                                               

Epoch 059: ReconLoss=0.2483 KL=2870.19 (LR=8.2E-07) | temp_scale=50.00 | COS R@10=0.6392 R@50=0.8460 | CSLS R@10=0.6672 R@50=0.8610 


                                                               

Epoch 060: ReconLoss=0.2498 KL=2881.87 (LR=2.1E-07) | temp_scale=50.00 | COS R@10=0.6392 R@50=0.8460 | CSLS R@10=0.6672 R@50=0.8610 
[Diagnostics @ epoch 060]
  emb2emb_linear_transfer: {'transfer_mse': 0.0004931012517772615, 'transfer_cosine': 0.6212983131408691, 'model_cosine': 0.29012826085090637, 'improvement': 0.33117005228996277}
  stitching_penalty_monitor: {'mse_orthogonal': 0.0006547736120410264, 'mse_full': 0.00024529395159333944, 'gap_percent': 166.92746194586994, 'cos_orthogonal': 0.49713391065597534, 'cos_full': 0.7914635539054871}
  mutual_knn_alignment_fraction: 0.1121

ðŸŽ¯ Training complete. Best model (COS R@50=0.8468) saved as giga_vae_model_giga_vae_model_1762505842_best.pth
Loading BEST model from giga_vae_model_giga_vae_model_1762505842_best.pth for inference...
Initializing Giga-VAE (hidden_dim=8192, latent_dim=1536)
Best model loaded.


Generating Sub: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 2/2 [00:00<00:00, 13.49it/s]



âœ… submission_giga_vae_run_giga_vae_model_1762505842.csv saved successfully.
Calculating final validation scores (cosine & CSLS) on FULL val set...


OutOfMemoryError: CUDA out of memory. Tried to allocate 3.49 GiB. GPU 0 has a total capacity of 14.74 GiB of which 3.15 GiB is free. Process 3223 has 11.59 GiB memory in use. Of the allocated memory 10.87 GiB is allocated by PyTorch, and 601.40 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# ===========================================
# Giga-VAE++  â€” LB-focused, single-run recipe (NO EMA)
# ===========================================
# Upgrades:
# - Multi-Positive SupCon (t<->i) using image IDs
# - Grouped sampler (pack multiple captions per image)
# - Temp cap @ 35
# - Uniformity loss (Î»=0.01)
# - Whitening head (+ orthoreg 1e-4)
# - NoiseConditioningAugmentor (Ïƒ = 0.05 * input_std)
# - Inference: DynamicThresholdingClamp (99.5p), Ridge blend (Î±=0.1)
# - Optional: k-reciprocal re-ranking (toggle USE_RERANK)
#
# Notes:
# - Early-stops on COSINE R@50 (LB-like)
# - CSLS is reported for reference
# - No EMA anywhere per user request

import os, time, math, warnings
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, BatchSampler

from tqdm import tqdm
from sklearn.linear_model import Ridge

warnings.filterwarnings("ignore", message=".*Error displaying widget.*")

# -----------------------------
# 0) Device & basic setup
# -----------------------------
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# -----------------------------
# 1) Load data
# -----------------------------

tx_train = train_data["captions/embeddings"]      # (125000, 1024)
im_train = train_data["images/embeddings"]        # (25000, 1536)
tx_test  = test_data["captions/embeddings"]       # (N_test, 1024)

# Expand images 5x to align with captions count (assumption holds in your data)
repeat_factor = len(tx_train) // len(im_train)
im_train_expanded = np.repeat(im_train, repeat_factor, axis=0)
print(f"Train shapes: {tx_train.shape}, {im_train.shape}, {im_train_expanded.shape}")
print(f"Test shape: {tx_test.shape}")

# -----------------------------
# 2) Preprocess (mean-center + L2 norm)
# -----------------------------
cpu = torch.device("cpu")
tx_train_t = torch.as_tensor(tx_train, dtype=torch.float32, device=cpu)
im_train_t_unique = torch.as_tensor(im_train, dtype=torch.float32, device=cpu)
tx_test_t  = torch.as_tensor(tx_test,  dtype=torch.float32, device=cpu)

tx_mean = tx_train_t.mean(0, keepdim=True)
im_mean = im_train_t_unique.mean(0, keepdim=True)

tx_train_t = F.normalize(tx_train_t - tx_mean, p=2, dim=1)
im_train_t_unique = F.normalize(im_train_t_unique - im_mean, p=2, dim=1)
tx_test_t  = F.normalize(tx_test_t  - tx_mean, p=2, dim=1)

im_train_exp = torch.as_tensor(im_train_expanded, dtype=torch.float32, device=cpu)
im_train_exp = F.normalize(im_train_exp - im_mean, p=2, dim=1)
print("Data preprocessed and normalized (on CPU).")

# Phase-2: NoiseConditioningAugmentor level
INPUT_STD = float(tx_train_t.std().item())
NOISE_LEVEL = 0.05  # effective sigma = NOISE_LEVEL * INPUT_STD

# -----------------------------
# 3) Model (VAETranslator + Whitening head)
# -----------------------------
class VAETranslator(nn.Module):
    def __init__(self, input_dim=1024, hidden_dim=8192, latent_dim=1536, output_dim=1536, dropout=0.4,
                 use_whitening=True):
        super().__init__()
        self.latent_dim = latent_dim
        self.use_whitening = use_whitening
        print(f"Initializing Giga-VAE (hidden_dim={hidden_dim}, latent_dim={latent_dim})")

        # Encoder: text -> (mu, logvar)
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),

            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),

            nn.Linear(hidden_dim, latent_dim * 2)
        )

        # Decoder: latent -> image embedding
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),

            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),

            nn.Linear(hidden_dim, output_dim)
        )

        if use_whitening:
            self.whiten = nn.Linear(output_dim, output_dim, bias=False)
            nn.init.orthogonal_(self.whiten.weight)
        else:
            self.whiten = None

    @staticmethod
    def _reparameterize(mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def _apply_head(self, y):
        if self.whiten is not None:
            y = self.whiten(y)
        return F.normalize(y, p=2, dim=1)

    def forward(self, x):
        mu_logvar = self.encoder(x)
        mu     = mu_logvar[..., :self.latent_dim]
        logvar = mu_logvar[..., self.latent_dim:]

        if self.training:
            z = self._reparameterize(mu, logvar)
            y = self.decoder(z)
            y = self._apply_head(y)
            return y, mu, logvar
        else:
            y_eval = self.decoder(mu)
            y_eval = self._apply_head(y_eval)
            return y_eval

# -----------------------------
# 4) Loss knobs & utilities
# -----------------------------
TAU_INIT = 0.01
MAX_LOGIT_SCALE = 35.0     # tighter cap for LB
KL_WEIGHT = 1e-5
UNIFORMITY_LAMBDA = 0.01   # small but effective
ORTHO_REG = 1e-4           # whitening head orthogonality penalty
MIXUP_ALPHA = 0.3
MIXUP_WARMUP_EPOCHS = 3

def kl_loss_fn(mu, logvar):
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return kl / mu.size(0)

def uniformity_loss(x, sample_pairs=1024):
    # x: (B, d) normalized
    if x.size(0) < 2:
        return x.new_tensor(0.0)
    B = x.size(0)
    idx1 = torch.randint(0, B, (sample_pairs,), device=x.device)
    idx2 = torch.randint(0, B, (sample_pairs,), device=x.device)
    diff = x[idx1] - x[idx2]
    sq = (diff * diff).sum(dim=1)
    return torch.exp(-2.0 * sq).mean()

def supcon_multi_positive(anchors, targets, labels, logit_scale):
    """
    Multi-positive supervised contrastive loss:
      - anchors: (B,d) L2-normalized
      - targets: (B,d) L2-normalized
      - labels:  (B,)  image ids; all items with same id are positives (excluding self)
      - logit_scale: scalar tensor
    Returns: scalar loss
    """
    sim = (anchors @ targets.T) * logit_scale  # (B,B)
    pos_mask = labels.unsqueeze(1).eq(labels.unsqueeze(0))
    pos_mask = pos_mask & (~torch.eye(len(labels), dtype=torch.bool, device=labels.device))

    # FP16-safe masking (no huge negative constants)
    neg_fill = torch.finfo(sim.dtype).min  # ~-65504 for float16
    pos_logits = sim.masked_fill(~pos_mask, neg_fill)

    logsumexp_pos = torch.logsumexp(pos_logits, dim=1)  # sum over positives
    logsumexp_all = torch.logsumexp(sim,       dim=1)   # sum over all
    pos_count = pos_mask.sum(dim=1).clamp(min=1)
    loss = -(logsumexp_pos - logsumexp_all) / pos_count
    return loss.mean()

def orthogonality_penalty(W):
    I = torch.eye(W.size(0), device=W.device, dtype=W.dtype)
    WT_W = W.T @ W
    return ((WT_W - I) ** 2).mean()

# -----------------------------
# 5) Split train/val + Grouped Sampler
# -----------------------------
N = len(tx_train_t)
val_size = int(0.1 * N)
perm = torch.randperm(N, device=cpu)
val_idx, train_idx = perm[:val_size], perm[val_size:]

# image ids (0..24999) aligned with 5 captions each
img_indices = torch.arange(N, device=cpu) // 5
img_indices_train = img_indices[train_idx]
img_indices_val   = img_indices[val_idx]

tx_val_t   = tx_train_t[val_idx]
im_val_t   = im_train_exp[val_idx]      # expanded images (aligned to captions)
tx_train_s = tx_train_t[train_idx]
im_train_s = im_train_exp[train_idx]

class GroupedByImageBatchSampler(BatchSampler):
    """
    Build batches containing M captions per image, for K images per batch.
    batch_size = M * K.
    Ensures multi-positive structure every step.
    """
    def __init__(self, img_ids_tensor, batch_size=512, m_per_image=4, drop_last=True):
        self.img_ids = img_ids_tensor.cpu().numpy()
        self.batch_size = batch_size
        self.m = m_per_image
        assert batch_size % m_per_image == 0, "batch_size must be divisible by m_per_image"
        self.k_images = batch_size // m_per_image
        self.drop_last = drop_last

        self.id_to_indices = {}
        for idx, iid in enumerate(self.img_ids):
            self.id_to_indices.setdefault(int(iid), []).append(idx)

        self.valid_ids = [iid for iid, idxs in self.id_to_indices.items() if len(idxs) >= 2]

    def __iter__(self):
        rng = np.random.default_rng()
        ids = self.valid_ids.copy()
        rng.shuffle(ids)
        batch = []
        for iid in ids:
            idxs = self.id_to_indices[iid]
            if len(idxs) >= self.m:
                chose = rng.choice(idxs, size=self.m, replace=False)
            else:
                chose = rng.choice(idxs, size=self.m, replace=True)
            batch.extend(chose.tolist())
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        approx = (len(self.valid_ids) // self.k_images)
        return approx if self.drop_last else approx + 1

train_ds = TensorDataset(tx_train_s, im_train_s, img_indices_train)
M_PER_IMAGE = 4
BATCH_SIZE = 512
batch_sampler = GroupedByImageBatchSampler(img_indices_train, batch_size=BATCH_SIZE, m_per_image=M_PER_IMAGE, drop_last=True)
train_loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=0)
print(f"Train pairs: {len(train_ds)}, Validation pairs: {len(val_idx)}")

# -----------------------------
# 6) Recall@K (cosine & CSLS)
# -----------------------------
@torch.no_grad()
def recall_at_k_cosine(model, tx_queries, im_database, gt_img_idx, chunk_size=1024, ks=(1,5,10,50)):
    model.eval()
    dev = next(model.parameters()).device
    Q = []
    for i in range(0, len(tx_queries), chunk_size):
        Q.append(model(tx_queries[i:i+chunk_size].to(dev)))
    Q = torch.cat(Q, dim=0)
    D = im_database.to(dev)
    sim = Q @ D.T
    top_idx = torch.argsort(sim, dim=1, descending=True)
    gt = gt_img_idx.to(dev)
    out = {}
    for k in ks:
        top_k = top_idx[:, :k]
        correct = (top_k == gt.unsqueeze(1)).any(dim=1).float().mean().item()
        out[f"R@{k}"] = correct
    return out

@torch.no_grad()
def recall_at_k_csls(model, tx_queries, im_database, gt_img_idx, chunk_size=1024, ks=(1,5,10,50), k_csls=10):
    model.eval()
    dev = next(model.parameters()).device
    Q = []
    for i in range(0, len(tx_queries), chunk_size):
        Q.append(model(tx_queries[i:i+chunk_size].to(dev)))
    Q = torch.cat(Q, dim=0)
    D = im_database.to(dev)
    # pass 1: mean_knn_d
    mean_knn_d_list = []
    for i in tqdm(range(0, len(D), 512), desc="CSLS (Pass 1/2)", leave=False):
        sim_T_chunk = D[i:i+512] @ Q.T
        knn_d_chunk = torch.topk(sim_T_chunk, k=k_csls, dim=1).values
        mean_knn_d_list.append(knn_d_chunk.mean(1))
    mean_knn_d = torch.cat(mean_knn_d_list, dim=0).unsqueeze(0)
    out = {}
    gt = gt_img_idx.to(dev)
    for i in tqdm(range(0, len(Q), chunk_size), desc="CSLS (Pass 2/2)", leave=False):
        sim_chunk = Q[i:i+chunk_size] @ D.T
        gt_chunk = gt[i:i+chunk_size]
        knn_q_chunk = torch.topk(sim_chunk, k=k_csls, dim=1).values
        mean_knn_q = knn_q_chunk.mean(1, keepdim=True)
        csls = 2 * sim_chunk - mean_knn_q - mean_knn_d
        top_idx = torch.argsort(csls, dim=1, descending=True)
        for k in ks:
            top_k = top_idx[:, :k]
            correct = (top_k == gt_chunk.unsqueeze(1)).any(dim=1).float().mean().item()
            out[f"R@{k}"] = out.get(f"R@{k}", 0) + correct * (len(sim_chunk)/len(Q))
    return out

# -----------------------------
# 7) Train setup
# -----------------------------
EPOCHS = 60
START_LR = 3e-4
WEIGHT_DECAY = 4e-4
RUN_ID = f"giga_vae_model_{int(time.time())}"
FINAL_SAVE_PATH = f"giga_vae_model_{RUN_ID}_best.pth"

val_query_subset   = tx_val_t[:5000]
val_indices_subset = img_indices_val[:5000]
val_db_subset      = im_train_t_unique  # DB = unique images

model = VAETranslator(dropout=0.4, hidden_dim=8192, latent_dim=1536, output_dim=1536, use_whitening=True).to(device)

# Learnable temperature (logit scale) with cap
logit_scale = nn.Parameter(torch.log(torch.tensor(1.0 / TAU_INIT, device=device)))

optimizer = torch.optim.AdamW(
    list(model.parameters()) + [logit_scale],
    lr=START_LR, weight_decay=WEIGHT_DECAY
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
scaler = torch.amp.GradScaler('cuda')

print(f"\n--- Starting Training (RUN {RUN_ID}) ---")
print(f"Epochs={EPOCHS}, LR={START_LR}, WD={WEIGHT_DECAY}, MultiPosSupCon=True, MixupWarmup={MIXUP_WARMUP_EPOCHS}â†’Î±={MIXUP_ALPHA}\n")

best_val_cos_r50 = 0.0
for epoch in range(1, EPOCHS + 1):
    model.train()
    total_main, total_kl, total_uni, total_ortho = 0.0, 0.0, 0.0, 0.0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch:03d}/{EPOCHS}", leave=False):
        x_batch, y_batch, img_ids = batch
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        img_ids = img_ids.to(device).long()

        # Mixup warmup + noise conditioning
        use_mixup = (epoch > MIXUP_WARMUP_EPOCHS)
        noise = torch.randn_like(x_batch) * (NOISE_LEVEL * INPUT_STD)
        x_aug = x_batch + noise

        if use_mixup:
            idx_shuffle = torch.randperm(x_aug.size(0), device=x_aug.device)
            lam = np.random.beta(MIXUP_ALPHA, MIXUP_ALPHA)
            x_aug = lam * x_aug + (1 - lam) * x_aug[idx_shuffle]
            y_batch = F.normalize(lam * y_batch + (1 - lam) * y_batch[idx_shuffle], p=2, dim=1)
            # keep img_ids unchanged to preserve grouping

        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast('cuda'):
            y_pred, mu, logvar = model(x_aug)   # (B,d)

            # temperature cap
            scale = logit_scale.exp()
            scale = torch.clamp(scale, max=torch.tensor(MAX_LOGIT_SCALE, device=scale.device))

            # Multi-Positive SupCon, symmetric
            loss_i2t = supcon_multi_positive(y_pred, y_batch, labels=img_ids, logit_scale=scale)
            loss_t2i = supcon_multi_positive(y_batch, y_pred, labels=img_ids, logit_scale=scale)
            main_loss = 0.5 * (loss_i2t + loss_t2i)

            # KL (light)
            kl = kl_loss_fn(mu, logvar)

            # Uniformity on both spaces (small)
            uni = 0.5 * (uniformity_loss(y_pred) + uniformity_loss(y_batch))

            # Ortho reg on whitening head
            ortho = model.whiten is not None and ORTHO_REG > 0
            ortho_loss = ORTHO_REG * orthogonality_penalty(model.whiten.weight) if ortho else 0.0

            loss = main_loss + KL_WEIGHT * kl + UNIFORMITY_LAMBDA * uni + (ortho_loss if ortho else 0.0)

        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

        total_main += float(main_loss.item())
        total_kl   += float(kl.item())
        total_uni  += float(uni.item())
        total_ortho += (float(ortho_loss) if ortho else 0.0)

    current_lr = scheduler.get_last_lr()[0]
    scheduler.step()
    avg_main = total_main / len(train_loader)
    avg_kl   = total_kl   / len(train_loader)
    avg_uni  = total_uni  / len(train_loader)
    avg_ortho= total_ortho/ len(train_loader)

    # ----- Validation (no EMA) -----
    rec_cos = recall_at_k_cosine(model, val_query_subset, val_db_subset, val_indices_subset, chunk_size=1024)
    rec_csls= recall_at_k_csls(model,   val_query_subset, val_db_subset, val_indices_subset, chunk_size=1024)

    cos_r50 = rec_cos['R@50']
    save_marker = ""
    if cos_r50 > best_val_cos_r50:
        best_val_cos_r50 = cos_r50
        torch.save(model.state_dict(), FINAL_SAVE_PATH)
        save_marker = "<- Best COS R@50! Saved."

    print(
        f"Epoch {epoch:03d}: Main={avg_main:.4f} KL={avg_kl:.2f} Uni={avg_uni:.4f} Ortho={avg_ortho:.6f} "
        f"(LR={current_lr:.1E}) | temp_scaleâ‰ˆ{float(torch.clamp(logit_scale.exp(), max=MAX_LOGIT_SCALE).item()):.2f} | "
        f"COS R@10={rec_cos['R@10']:.4f} R@50={rec_cos['R@50']:.4f} | "
        f"CSLS R@10={rec_csls['R@10']:.4f} R@50={rec_csls['R@50']:.4f} {save_marker}"
    )

print(f"\nðŸŽ¯ Training complete. Best model (COS R@50={best_val_cos_r50:.4f}) saved as {FINAL_SAVE_PATH}")

# -----------------------------
# 8) Inference helpers (clamp, ridge, re-rank)
# -----------------------------
def dynamic_thresholding_clamp(mat, q=99.5):
    """
    Clamp extreme activations per-feature to the +/- qth percentile (in-chunk).
    mat: (N,d) tensor (normalized downstream)
    """
    abs_mat = mat.abs()
    perc = torch.quantile(abs_mat, q/100.0, dim=0, keepdim=True)
    clamped = torch.clamp(mat, min=-perc, max=perc)
    return clamped

def fit_ridge_text2img(tx_cpu, im_cpu, alpha=1.0):
    ridge = Ridge(alpha=alpha, fit_intercept=False)
    ridge.fit(tx_cpu, im_cpu)
    return ridge

@torch.no_grad()
def encode_queries(model, X, chunk=1024):
    model.eval()
    dev = next(model.parameters()).device
    out = []
    for i in tqdm(range(0, len(X), chunk), desc="Encode queries"):
        chunk_x = X[i:i+chunk].to(dev)
        y = model(chunk_x)
        # DynamicThresholdingClamp (Phase-2)
        y = dynamic_thresholding_clamp(y, q=99.5)
        y = F.normalize(y, p=2, dim=1)
        out.append(y)
    return torch.cat(out, dim=0)

# Optional: k-reciprocal re-ranking (toggle if allowed/time permits)
USE_RERANK = False

def rerank_k_reciprocal(sim, topk=200, k1=20, k2=6, lambda_j=0.3):
    """
    Basic k-reciprocal re-ranking (Zhong et al., 2017).
    sim: (Q, I) cosine similarity matrix
    Returns new similarity matrix of same shape.
    """
    Q, I = sim.shape
    dist = (1 - sim).cpu().numpy()
    initial_rank = np.argsort(dist, axis=1)
    V = np.zeros_like(dist, dtype=np.float32)
    for i in range(Q):
        forward_k = initial_rank[i, :k1 + 1]
        reciprocal = []
        for j in forward_k:
            backward_k = initial_rank[j, :k1 + 1] if j < I else []
            if i in backward_k:
                reciprocal.append(j)
        reciprocal = np.array(reciprocal, dtype=np.int64)
        V[i, reciprocal] = np.exp(-dist[i, reciprocal])
        if k2 > 1:
            qe_idx = initial_rank[i, :k2]
            V[i, :] = (V[i, :] + V[qe_idx, :].mean(axis=0)) / 2.0
    V_norm = V / (V.sum(axis=1, keepdims=True) + 1e-12)
    jaccard = 1 - V_norm
    final_dist = (1 - lambda_j) * jaccard + lambda_j * dist
    final_sim = 1 - final_dist
    return torch.from_numpy(final_sim).to(sim.device)

# -----------------------------
# 9) Inference + Submission
# -----------------------------
print(f"Loading BEST model from {FINAL_SAVE_PATH} for inference...")
model = VAETranslator(dropout=0.4, hidden_dim=8192, latent_dim=1536, output_dim=1536, use_whitening=True).to(device)
model.load_state_dict(torch.load(FINAL_SAVE_PATH, map_location=device))
model.eval()
print("Best model loaded.")

# Fit ridge distiller once (CPU) for inference blending
print("Fitting ridge baseline (text->image) on full train (CPU) for inference blend Î±=0.1 ...")
ridge = fit_ridge_text2img(tx_train_t.numpy(), im_train_t_unique.numpy(), alpha=1.0)
ALPHA_BLEND = 0.10

# Encode queries (test)
Q = encode_queries(model, tx_test_t, chunk=1024)           # (Nt, d)
# Ridge predictions + blend (in chunks to save RAM)
with torch.no_grad():
    R = torch.from_numpy(ridge.predict(tx_test_t.numpy())).to(Q.device).float()
    R = F.normalize(R, p=2, dim=1)
    Q = F.normalize((1 - ALPHA_BLEND) * Q + ALPHA_BLEND * R, p=2, dim=1)

# Database = unique normalized image embeddings (ground-truth)
D = im_train_t_unique.to(Q.device)  # (Ni, d)

# Retrieve (cosine), optional re-ranking
print("Computing retrieval for submission (cosine)...")
sim = Q @ D.T  # (Nt, Ni)
if USE_RERANK:
    print("Applying k-reciprocal re-ranking (this may take a while)...")
    sim = rerank_k_reciprocal(sim, topk=200, k1=20, k2=6, lambda_j=0.3)

# Output submission embeddings (same format you used)
test_ids = test_data["captions/ids"].astype(int)
preds = Q  # final blended, normalized predictions
submission = pd.DataFrame({
    "id": test_ids,
    "embedding": [list(map(float, row)) for row in preds.cpu().numpy()]
})

out_csv = f"submission_giga_vae_run_{RUN_ID}.csv"
submission.to_csv(out_csv, index=False)
print(f"\nâœ… {out_csv} saved successfully.")

# -----------------------------
# 10) Final validation (cosine & CSLS) on FULL val set (same inference path)
# -----------------------------
print("Calculating final validation scores (cosine & CSLS) on FULL val set...")

# Build predictions for validation queries with clamp + blend, like test
Q_val = encode_queries(model, tx_val_t, chunk=1024)
with torch.no_grad():
    R_val = torch.from_numpy(ridge.predict(tx_val_t.numpy())).to(Q_val.device).float()
    R_val = F.normalize(R_val, p=2, dim=1)
    Q_val = F.normalize((1 - ALPHA_BLEND) * Q_val + ALPHA_BLEND * R_val, p=2, dim=1)

D_val = im_train_t_unique.to(Q_val.device)

# Cosine
sim_val = Q_val @ D_val.T
if USE_RERANK:
    sim_val = rerank_k_reciprocal(sim_val, topk=200, k1=20, k2=6, lambda_j=0.3)

top_idx = torch.argsort(sim_val, dim=1, descending=True)
gt = img_indices_val.to(Q_val.device)
def recall_from_top(top_idx, gt, ks=(1,5,10,50)):
    out = {}
    for k in ks:
        top_k = top_idx[:, :k]
        out[f"R@{k}"] = (top_k == gt.unsqueeze(1)).any(dim=1).float().mean().item()
    return out
rec_cos_full = recall_from_top(top_idx, gt)

# CSLS (recompute with current Q_val)
@torch.no_grad()
def csls_from_embeds(Q, D, gt, k_csls=10, chunk=1024, ks=(1,5,10,50)):
    mean_knn_d_list = []
    for i in tqdm(range(0, len(D), 512), desc="CSLS Val (Pass 1/2)", leave=False):
        sim_T_chunk = D[i:i+512] @ Q.T
        knn_d_chunk = torch.topk(sim_T_chunk, k=k_csls, dim=1).values
        mean_knn_d_list.append(knn_d_chunk.mean(1))
    mean_knn_d = torch.cat(mean_knn_d_list, dim=0).unsqueeze(0)
    out = {}
    for i in tqdm(range(0, len(Q), chunk), desc="CSLS Val (Pass 2/2)", leave=False):
        sim_chunk = Q[i:i+chunk] @ D.T
        gt_chunk = gt[i:i+chunk]
        knn_q_chunk = torch.topk(sim_chunk, k=k_csls, dim=1).values
        mean_knn_q = knn_q_chunk.mean(1, keepdim=True)
        csls = 2 * sim_chunk - mean_knn_q - mean_knn_d
        top_idx = torch.argsort(csls, dim=1, descending=True)
        for k in ks:
            top_k = top_idx[:, :k]
            correct = (top_k == gt_chunk.unsqueeze(1)).any(dim=1).float().mean().item()
            out[f"R@{k}"] = out.get(f"R@{k}", 0) + correct * (len(sim_chunk)/len(Q))
    return out

rec_csls_full = csls_from_embeds(Q_val, D_val, gt)

print("\n--- Full Validation Results (COSINE, clamp+blend) ---")
print(f"R@1:  {rec_cos_full['R@1']:<8.4f}")
print(f"R@5:  {rec_cos_full['R@5']:<8.4f}")
print(f"R@10: {rec_cos_full['R@10']:<8.4f}")
print(f"R@50: {rec_cos_full['R@50']:<8.4f}")

print("\n--- Full Validation Results (CSLS, clamp+blend) ---")
print(f"R@1:  {rec_csls_full['R@1']:<8.4f}")
print(f"R@5:  {rec_csls_full['R@5']:<8.4f}")
print(f"R@10: {rec_csls_full['R@10']:<8.4f}")
print(f"R@50: {rec_csls_full['R@50']:<8.4f}")


Using device: cuda
Train shapes: (125000, 1024), (25000, 1536), (125000, 1536)
Test shape: (1500, 1024)
Data preprocessed and normalized (on CPU).
Train pairs: 112500, Validation pairs: 12500
Initializing Giga-VAE (hidden_dim=8192, latent_dim=1536)

--- Starting Training (RUN giga_vae_model_1762509428) ---
Epochs=60, LR=0.0003, WD=0.0004, MultiPosSupCon=True, MixupWarmup=3â†’Î±=0.3



                                                               

Epoch 001: Main=0.7513 KL=1130.34 Uni=0.0293 Ortho=0.000000 (LR=3.0E-04) | temp_scaleâ‰ˆ35.00 | COS R@10=0.2334 R@50=0.4974 | CSLS R@10=0.2948 R@50=0.5688 <- Best COS R@50! Saved.


                                                               

Epoch 002: Main=0.4483 KL=1496.93 Uni=0.0299 Ortho=0.000000 (LR=3.0E-04) | temp_scaleâ‰ˆ35.00 | COS R@10=0.3038 R@50=0.5950 | CSLS R@10=0.3796 R@50=0.6490 <- Best COS R@50! Saved.


                                                               

Epoch 003: Main=0.3816 KL=1624.92 Uni=0.0293 Ortho=0.000000 (LR=3.0E-04) | temp_scaleâ‰ˆ35.00 | COS R@10=0.3488 R@50=0.6492 | CSLS R@10=0.4224 R@50=0.7018 <- Best COS R@50! Saved.


                                                               

Epoch 004: Main=1.4339 KL=1684.87 Uni=0.0661 Ortho=0.000000 (LR=3.0E-04) | temp_scaleâ‰ˆ35.00 | COS R@10=0.1692 R@50=0.3784 | CSLS R@10=0.2756 R@50=0.5362 


                                                               

Epoch 005: Main=1.3991 KL=1521.74 Uni=0.0460 Ortho=0.000000 (LR=3.0E-04) | temp_scaleâ‰ˆ35.00 | COS R@10=0.3106 R@50=0.6026 | CSLS R@10=0.4002 R@50=0.6696 


                                                               

Epoch 006: Main=1.2531 KL=1482.38 Uni=0.0334 Ortho=0.000000 (LR=2.9E-04) | temp_scaleâ‰ˆ35.00 | COS R@10=0.3644 R@50=0.6462 | CSLS R@10=0.4322 R@50=0.7052 


                                                               

Epoch 007: Main=1.2986 KL=1492.21 Uni=0.0342 Ortho=0.000000 (LR=2.9E-04) | temp_scaleâ‰ˆ35.00 | COS R@10=0.3688 R@50=0.6518 | CSLS R@10=0.4356 R@50=0.7064 <- Best COS R@50! Saved.


                                                               

Epoch 008: Main=1.0784 KL=1519.38 Uni=0.0301 Ortho=0.000000 (LR=2.9E-04) | temp_scaleâ‰ˆ35.00 | COS R@10=0.4038 R@50=0.6846 | CSLS R@10=0.4618 R@50=0.7332 <- Best COS R@50! Saved.


                                                               

Epoch 009: Main=1.2739 KL=1526.87 Uni=0.0307 Ortho=0.000000 (LR=2.9E-04) | temp_scaleâ‰ˆ35.00 | COS R@10=0.3820 R@50=0.6734 | CSLS R@10=0.4514 R@50=0.7212 


                                                               

Epoch 010: Main=1.2527 KL=1510.43 Uni=0.0306 Ortho=0.000000 (LR=2.8E-04) | temp_scaleâ‰ˆ35.00 | COS R@10=0.4098 R@50=0.6932 | CSLS R@10=0.4718 R@50=0.7398 <- Best COS R@50! Saved.


                                                               

Epoch 011: Main=1.3290 KL=1494.21 Uni=0.0310 Ortho=0.000000 (LR=2.8E-04) | temp_scaleâ‰ˆ35.00 | COS R@10=0.4206 R@50=0.7034 | CSLS R@10=0.4786 R@50=0.7448 <- Best COS R@50! Saved.


                                                               

Epoch 012: Main=1.2606 KL=1464.50 Uni=0.0300 Ortho=0.000000 (LR=2.8E-04) | temp_scaleâ‰ˆ35.00 | COS R@10=0.4380 R@50=0.7154 | CSLS R@10=0.4926 R@50=0.7606 <- Best COS R@50! Saved.


                                                               

Epoch 013: Main=1.2669 KL=1451.67 Uni=0.0300 Ortho=0.000000 (LR=2.7E-04) | temp_scaleâ‰ˆ35.00 | COS R@10=0.4312 R@50=0.7202 | CSLS R@10=0.4910 R@50=0.7578 <- Best COS R@50! Saved.


                                                               

Epoch 014: Main=1.2634 KL=1436.63 Uni=0.0295 Ortho=0.000000 (LR=2.7E-04) | temp_scaleâ‰ˆ35.00 | COS R@10=0.4502 R@50=0.7352 | CSLS R@10=0.5122 R@50=0.7748 <- Best COS R@50! Saved.


                                                               

Epoch 015: Main=1.2189 KL=1425.04 Uni=0.0295 Ortho=0.000000 (LR=2.6E-04) | temp_scaleâ‰ˆ35.00 | COS R@10=0.4508 R@50=0.7416 | CSLS R@10=0.5142 R@50=0.7758 <- Best COS R@50! Saved.


                                                               

Epoch 016: Main=1.1800 KL=1434.34 Uni=0.0284 Ortho=0.000000 (LR=2.6E-04) | temp_scaleâ‰ˆ35.00 | COS R@10=0.4530 R@50=0.7448 | CSLS R@10=0.5090 R@50=0.7778 <- Best COS R@50! Saved.


                                                               

Epoch 017: Main=1.2574 KL=1428.02 Uni=0.0289 Ortho=0.000000 (LR=2.5E-04) | temp_scaleâ‰ˆ35.00 | COS R@10=0.4558 R@50=0.7408 | CSLS R@10=0.5184 R@50=0.7770 


                                                               

Epoch 018: Main=1.2578 KL=1404.78 Uni=0.0291 Ortho=0.000000 (LR=2.4E-04) | temp_scaleâ‰ˆ35.00 | COS R@10=0.4724 R@50=0.7508 | CSLS R@10=0.5246 R@50=0.7806 <- Best COS R@50! Saved.


                                                               

Epoch 019: Main=1.2099 KL=1411.67 Uni=0.0286 Ortho=0.000000 (LR=2.4E-04) | temp_scaleâ‰ˆ35.00 | COS R@10=0.4802 R@50=0.7632 | CSLS R@10=0.5364 R@50=0.7882 <- Best COS R@50! Saved.


                                                               

Epoch 020: Main=1.2353 KL=1413.35 Uni=0.0289 Ortho=0.000000 (LR=2.3E-04) | temp_scaleâ‰ˆ35.00 | COS R@10=0.4824 R@50=0.7586 | CSLS R@10=0.5390 R@50=0.7856 


                                                               

Epoch 021: Main=1.3040 KL=1395.51 Uni=0.0292 Ortho=0.000000 (LR=2.2E-04) | temp_scaleâ‰ˆ35.00 | COS R@10=0.4776 R@50=0.7590 | CSLS R@10=0.5334 R@50=0.7854 


                                                               

Epoch 022: Main=1.2568 KL=1378.64 Uni=0.0291 Ortho=0.000000 (LR=2.2E-04) | temp_scaleâ‰ˆ35.00 | COS R@10=0.4900 R@50=0.7662 | CSLS R@10=0.5432 R@50=0.7918 <- Best COS R@50! Saved.


                                                               

Epoch 023: Main=1.2611 KL=1372.03 Uni=0.0289 Ortho=0.000000 (LR=2.1E-04) | temp_scaleâ‰ˆ35.00 | COS R@10=0.4980 R@50=0.7706 | CSLS R@10=0.5422 R@50=0.7942 <- Best COS R@50! Saved.


                                                               

Epoch 024: Main=1.2161 KL=1370.33 Uni=0.0286 Ortho=0.000000 (LR=2.0E-04) | temp_scaleâ‰ˆ35.00 | COS R@10=0.5060 R@50=0.7724 | CSLS R@10=0.5518 R@50=0.7966 <- Best COS R@50! Saved.


                                                               

Epoch 025: Main=1.3186 KL=1366.12 Uni=0.0286 Ortho=0.000000 (LR=2.0E-04) | temp_scaleâ‰ˆ35.00 | COS R@10=0.5016 R@50=0.7732 | CSLS R@10=0.5484 R@50=0.7994 <- Best COS R@50! Saved.


                                                               

Epoch 026: Main=1.2430 KL=1357.42 Uni=0.0286 Ortho=0.000000 (LR=1.9E-04) | temp_scaleâ‰ˆ35.00 | COS R@10=0.5106 R@50=0.7790 | CSLS R@10=0.5552 R@50=0.7984 <- Best COS R@50! Saved.


                                                               

Epoch 027: Main=1.2988 KL=1343.50 Uni=0.0287 Ortho=0.000000 (LR=1.8E-04) | temp_scaleâ‰ˆ35.00 | COS R@10=0.5116 R@50=0.7794 | CSLS R@10=0.5616 R@50=0.8016 <- Best COS R@50! Saved.


                                                               

Epoch 028: Main=1.2316 KL=1351.27 Uni=0.0283 Ortho=0.000000 (LR=1.7E-04) | temp_scaleâ‰ˆ35.00 | COS R@10=0.5136 R@50=0.7808 | CSLS R@10=0.5616 R@50=0.8046 <- Best COS R@50! Saved.


                                                               

Epoch 029: Main=1.1459 KL=1370.25 Uni=0.0280 Ortho=0.000000 (LR=1.7E-04) | temp_scaleâ‰ˆ35.00 | COS R@10=0.5140 R@50=0.7850 | CSLS R@10=0.5622 R@50=0.8036 <- Best COS R@50! Saved.


                                                               

Epoch 030: Main=1.2951 KL=1349.11 Uni=0.0282 Ortho=0.000000 (LR=1.6E-04) | temp_scaleâ‰ˆ35.00 | COS R@10=0.5146 R@50=0.7802 | CSLS R@10=0.5580 R@50=0.8070 


                                                               

Epoch 031: Main=1.3044 KL=1333.92 Uni=0.0283 Ortho=0.000000 (LR=1.5E-04) | temp_scaleâ‰ˆ35.00 | COS R@10=0.5150 R@50=0.7824 | CSLS R@10=0.5638 R@50=0.8066 


                                                               

KeyboardInterrupt: 

In [None]:
# ============================================================
# Giga-VAE (Champion rollback) + Low-risk upgrades, NO EMA
# ============================================================
# Upgrades kept (safe):
# - Mixup warmup (â†’ alpha=0.3)
# - Noise conditioning (sigma = 0.02 * input_std)
# - DynamicThresholdingClamp at inference (q=99.5)
# - Ridge distillation blend at inference AND validation (alpha=0.10)
# - CSLS + cosine recalls
# - Diagnostics: emb2emb transfer, stitching penalty, mutual kNN (safe to dims)
#
# Core loss: Asymmetric CE @ tau + hard-negative triplet (margin=0.06)
# Temp is FIXED via tau to avoid run-away scales.
# ============================================================

import os, time, math, warnings
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from tqdm import tqdm
from sklearn.linear_model import Ridge
from scipy.linalg import orthogonal_procrustes

warnings.filterwarnings("ignore", message=".*Error displaying widget.*")

# -----------------------------
# 0) Device & setup
# -----------------------------
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# -----------------------------
# 1) Load data
# -----------------------------
tx_train = train_data["captions/embeddings"]      # (125000, 1024)
im_train = train_data["images/embeddings"]        # (25000, 1536)
tx_test  = test_data["captions/embeddings"]       # (N_test, 1024)

repeat_factor = len(tx_train) // len(im_train)
im_train_expanded = np.repeat(im_train, repeat_factor, axis=0)
print(f"Train shapes: {tx_train.shape}, {im_train.shape}, {im_train_expanded.shape}")
print(f"Test shape: {tx_test.shape}")

# -----------------------------
# 2) Preprocess (mean-center + L2)
# -----------------------------
cpu = torch.device("cpu")
tx_train_t = torch.as_tensor(tx_train, dtype=torch.float32, device=cpu)
im_train_t_unique = torch.as_tensor(im_train, dtype=torch.float32, device=cpu)
tx_test_t  = torch.as_tensor(tx_test,  dtype=torch.float32, device=cpu)

tx_mean = tx_train_t.mean(0, keepdim=True)
im_mean = im_train_t_unique.mean(0, keepdim=True)

tx_train_t = F.normalize(tx_train_t - tx_mean, p=2, dim=1)
im_train_t_unique = F.normalize(im_train_t_unique - im_mean, p=2, dim=1)
tx_test_t  = F.normalize(tx_test_t  - tx_mean, p=2, dim=1)

im_train_exp = torch.as_tensor(im_train_expanded, dtype=torch.float32, device=cpu)
im_train_exp = F.normalize(im_train_exp - im_mean, p=2, dim=1)
print("Data preprocessed and normalized (on CPU).")

# -----------------------------
# 3) Train/val split
# -----------------------------
N = len(tx_train_t)
val_size = int(0.1 * N)
perm = torch.randperm(N, device=cpu)
val_idx, train_idx = perm[:val_size], perm[val_size:]

# map each caption to image index (groups of 5)
img_indices = torch.arange(N, device=cpu) // 5
img_indices_train = img_indices[train_idx]
img_indices_val   = img_indices[val_idx]

tx_val_t   = tx_train_t[val_idx]
im_val_t   = im_train_exp[val_idx]
tx_train_s = tx_train_t[train_idx]
im_train_s = im_train_exp[train_idx]

train_dataset = TensorDataset(tx_train_s, im_train_s, img_indices_train)
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True, num_workers=0)
print(f"Train pairs: {len(train_dataset)}, Validation pairs: {len(val_idx)}")

# -----------------------------
# 4) Model (VAE translator)
# -----------------------------
class VAETranslator(nn.Module):
    def __init__(self, input_dim=1024, hidden_dim=8192, latent_dim=1536, output_dim=1536, dropout=0.4):
        super().__init__()
        self.latent_dim = latent_dim
        print(f"Initializing Giga-VAE (hidden_dim={hidden_dim}, latent_dim={latent_dim})")

        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),

            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),

            nn.Linear(hidden_dim, latent_dim * 2) # mu | logvar
        )

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),

            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),

            nn.Linear(hidden_dim, output_dim)
        )

    @staticmethod
    def reparameterize(mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu_logvar = self.encoder(x)
        mu     = mu_logvar[..., :self.latent_dim]
        logvar = mu_logvar[..., self.latent_dim:]
        if self.training:
            z = self.reparameterize(mu, logvar)
            y = self.decoder(z)
            return F.normalize(y, p=2, dim=1), mu, logvar
        else:
            y_eval = self.decoder(mu)
            return F.normalize(y_eval, p=2, dim=1)

# -----------------------------
# 5) Losses & helpers (Champion)
# -----------------------------
TAU = 0.01                 # fixed temperature (no runaway scales)
MARGIN = 0.06
LOSS_WEIGHT_CONTRASTIVE = 0.7
LOSS_WEIGHT_TRIPLET = 0.3
KL_WEIGHT = 1e-5

MIXUP_ALPHA = 0.3
MIXUP_WARMUP_EPOCHS = 3
INPUT_STD = float(tx_train_t.std().item())
NOISE_LEVEL = 0.02  # sigma = 0.02 * input_std

triplet_loss_fn = nn.TripletMarginWithDistanceLoss(
    distance_function=lambda x, y: 1 - F.cosine_similarity(x, y),
    margin=MARGIN
)
print(f"Using CHAMPION ASYMMETRIC loss + *** VAE KL Loss (weight={KL_WEIGHT}) ***")

def kl_loss(mu, logvar):
    v = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return v / mu.size(0)

# -----------------------------
# 6) Recall@K (cosine & CSLS)
# -----------------------------
@torch.no_grad()
def recall_at_k_csls_from_embeds(Q_embed, D_embed, gt_idx, ks=(1,5,10,50), k_csls=10, chunk=1024):
    out = {}
    # Pass 1: mean knn in D against Q
    mean_knn_d_list = []
    for i in tqdm(range(0, len(D_embed), 512), desc="CSLS (Pass 1/2)", leave=False):
        sim_T_chunk = D_embed[i:i+512] @ Q_embed.T
        knn_d_chunk = torch.topk(sim_T_chunk, k=k_csls, dim=1).values
        mean_knn_d_list.append(knn_d_chunk.mean(1))
    mean_knn_d = torch.cat(mean_knn_d_list, dim=0).unsqueeze(0)

    for i in tqdm(range(0, len(Q_embed), chunk), desc="CSLS (Pass 2/2)", leave=False):
        sim_chunk = Q_embed[i:i+chunk] @ D_embed.T
        gt_chunk  = gt_idx[i:i+chunk]
        knn_q_chunk = torch.topk(sim_chunk, k=k_csls, dim=1).values
        mean_knn_q = knn_q_chunk.mean(1, keepdim=True)
        csls = 2 * sim_chunk - mean_knn_q - mean_knn_d
        top_idx = torch.argsort(csls, dim=1, descending=True)
        for k in ks:
            top_k = top_idx[:, :k]
            correct = (top_k == gt_chunk.unsqueeze(1)).any(dim=1).float().mean().item()
            out[f"R@{k}"] = out.get(f"R@{k}", 0.0) + correct * (len(sim_chunk)/len(Q_embed))
    return out

@torch.no_grad()
def recall_at_k_cos_from_embeds(Q_embed, D_embed, gt_idx, ks=(1,5,10,50)):
    sim = Q_embed @ D_embed.T
    top_idx = torch.argsort(sim, dim=1, descending=True)
    out = {}
    for k in ks:
        top_k = top_idx[:, :k]
        out[f"R@{k}"] = (top_k == gt_idx.unsqueeze(1)).any(dim=1).float().mean().item()
    return out

# -----------------------------
# 7) Diagnostics (dim-safe)
# -----------------------------
class DiagnosticMetrics:
    @staticmethod
    @torch.no_grad()
    def emb2emb_linear_transfer(src_emb, tgt_emb, model_pred_emb=None, alpha=1.0):
        src_np = (src_emb.detach().cpu().numpy()
                  if isinstance(src_emb, torch.Tensor) else src_emb)
        tgt_np = (tgt_emb.detach().cpu().numpy()
                  if isinstance(tgt_emb, torch.Tensor) else tgt_emb)

        ridge = Ridge(alpha=alpha, fit_intercept=False)
        ridge.fit(src_np, tgt_np)
        pred = ridge.predict(src_np)

        pred_t = F.normalize(torch.from_numpy(pred).float(), p=2, dim=1)
        tgt_t  = F.normalize(torch.from_numpy(tgt_np).float(), p=2, dim=1)

        transfer_mse = F.mse_loss(pred_t, tgt_t).item()
        transfer_cosine = F.cosine_similarity(pred_t, tgt_t).mean().item()

        if model_pred_emb is not None:
            if not isinstance(model_pred_emb, torch.Tensor):
                model_pred_emb = torch.from_numpy(model_pred_emb).float()
            model_cos = F.cosine_similarity(
                F.normalize(model_pred_emb.float(), p=2, dim=1), tgt_t
            ).mean().item()
            improvement = transfer_cosine - model_cos
            baseline_name = 'model_cosine'
        else:
            model_cos = 0.0
            improvement = transfer_cosine
            baseline_name = 'random_baseline'

        return {
            'transfer_mse': transfer_mse,
            'transfer_cosine': transfer_cosine,
            baseline_name: model_cos,
            'improvement': improvement
        }

    @staticmethod
    @torch.no_grad()
    def stitching_penalty_monitor(src_emb, tgt_emb):
        # Ridge map to tgt dim first (handles 1024->1536), then ortho vs ridge
        src_np = src_emb.detach().cpu().numpy()
        tgt_np = tgt_emb.detach().cpu().numpy()

        ridge = Ridge(alpha=0.01, fit_intercept=False)
        ridge.fit(src_np, tgt_np)
        pred_full = ridge.predict(src_np)  # (B, D_tgt)

        R, _ = orthogonal_procrustes(pred_full, tgt_np)
        pred_ortho = pred_full @ R

        mse_full = float(np.mean((pred_full  - tgt_np) ** 2))
        mse_orth = float(np.mean((pred_ortho - tgt_np) ** 2))
        gap_pct  = 100.0 * (mse_orth - mse_full) / (mse_full + 1e-8)

        pred_ortho_t = F.normalize(torch.from_numpy(pred_ortho).float(), p=2, dim=1)
        pred_full_t  = F.normalize(torch.from_numpy(pred_full ).float(), p=2, dim=1)
        tgt_norm     = F.normalize(torch.from_numpy(tgt_np   ).float(), p=2, dim=1)

        cos_orth = F.cosine_similarity(pred_ortho_t, tgt_norm).mean().item()
        cos_full = F.cosine_similarity(pred_full_t , tgt_norm).mean().item()

        return {
            'mse_orthogonal': mse_orth,
            'mse_full': mse_full,
            'gap_percent': gap_pct,
            'cos_orthogonal': cos_orth,
            'cos_full': cos_full
        }

    @staticmethod
    @torch.no_grad()
    def mutual_knn_alignment(emb1, emb2, k=10):
        sim1 = emb1 @ emb1.T
        sim2 = emb2 @ emb2.T
        _, knn1 = torch.topk(sim1, k=k+1, dim=1)
        _, knn2 = torch.topk(sim2, k=k+1, dim=1)
        knn1 = knn1[:, 1:]
        knn2 = knn2[:, 1:]
        mutual = 0
        for i in range(len(emb1)):
            set1 = set(knn1[i].cpu().tolist())
            set2 = set(knn2[i].cpu().tolist())
            mutual += len(set1 & set2)
        return mutual / (len(emb1) * k)

# -----------------------------
# 8) Clamp & Ridge helpers
# -----------------------------
def dynamic_thresholding_clamp(mat, q=99.5):
    # per-dimension clamp at percentile q
    with torch.no_grad():
        abs_mat = mat.abs()
        perc = torch.quantile(abs_mat, q/100.0, dim=0, keepdim=True)
    return torch.clamp(mat, min=-perc, max=perc)

def fit_ridge_text2img(tx_cpu_np, im_cpu_np, alpha=1.0):
    ridge = Ridge(alpha=alpha, fit_intercept=False)
    ridge.fit(tx_cpu_np, im_cpu_np)
    return ridge

@torch.no_grad()
def encode_queries(model, X, chunk=1024, clamp_q=None):
    model.eval()
    dev = next(model.parameters()).device
    outs = []
    for i in range(0, len(X), chunk):
        y = model(X[i:i+chunk].to(dev))
        if clamp_q is not None:
            y = dynamic_thresholding_clamp(y, q=clamp_q)
        y = F.normalize(y, p=2, dim=1)
        outs.append(y)
    return torch.cat(outs, dim=0)

# -----------------------------
# 9) Training setup
# -----------------------------
EPOCHS = 60
START_LR = 3e-4
WEIGHT_DECAY = 4e-4
RUN_ID = f"giga_vae_model_{int(time.time())}"
FINAL_SAVE_PATH = f"giga_vae_model_{RUN_ID}_best.pth"

val_query_subset   = tx_val_t[:5000]
val_indices_subset = img_indices_val[:5000]
val_db_subset      = im_train_t_unique  # DB = unique images

print("Fitting ridge baseline (text->image) for validation/test blending (alpha=1.0)...")
# Use expanded image embeddings so X and y have the same number of samples (125k).
ridge = fit_ridge_text2img(tx_train_t.numpy(), im_train_exp.numpy(), alpha=1.0)
ALPHA_BLEND = 0.10

model = VAETranslator(dropout=0.4, hidden_dim=8192, latent_dim=1536, output_dim=1536).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=START_LR, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
scaler = torch.amp.GradScaler('cuda')

print(f"\n--- Starting Training (RUN {RUN_ID}) ---")
print(f"Epochs={EPOCHS}, LR={START_LR}, WD={WEIGHT_DECAY}, MixupWarmup={MIXUP_WARMUP_EPOCHS}â†’Î±={MIXUP_ALPHA}\n")

best_csls_r50 = 0.0

for epoch in range(1, EPOCHS + 1):
    model.train()
    total_recon, total_kl = 0.0, 0.0

    for x_batch, y_batch, img_ids in tqdm(train_loader, desc=f"Epoch {epoch:03d}/{EPOCHS}", leave=False):
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        # Noise conditioning
        noise = torch.randn_like(x_batch) * (NOISE_LEVEL * INPUT_STD)
        x_aug = x_batch + noise

        # Mixup (after warmup)
        if epoch > MIXUP_WARMUP_EPOCHS:
            idx_shuffle = torch.randperm(x_aug.size(0), device=x_aug.device)
            lam = np.random.beta(MIXUP_ALPHA, MIXUP_ALPHA)
            x_aug = lam * x_aug + (1 - lam) * x_aug[idx_shuffle]
            y_batch = F.normalize(lam * y_batch + (1 - lam) * y_batch[idx_shuffle], p=2, dim=1)

        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast('cuda'):
            y_pred, mu, logvar = model(x_aug)

            # Champion asymmetric CE
            sims_with_tau = (y_pred @ y_batch.T) / TAU
            labels = torch.arange(y_pred.size(0), device=device)
            loss_con = F.cross_entropy(sims_with_tau, labels)

            # Hard negative mining (no-tau space)
            with torch.no_grad():
                sims_no_tau = y_pred @ y_batch.T
                positive_mask = torch.eye(y_batch.size(0), dtype=torch.bool, device=device)
                sims_no_tau.masked_fill_(positive_mask, -float('inf'))
                hard_neg_idx = sims_no_tau.argmax(dim=1)
            y_hard_neg = y_batch[hard_neg_idx]
            loss_tri = triplet_loss_fn(y_pred, y_batch, y_hard_neg)

            recon_loss = (LOSS_WEIGHT_CONTRASTIVE * loss_con) + (LOSS_WEIGHT_TRIPLET * loss_tri)
            kl = kl_loss(mu, logvar)
            loss = recon_loss + (KL_WEIGHT * kl)

        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

        total_recon += float(recon_loss.item())
        total_kl    += float(kl.item())

    current_lr = scheduler.get_last_lr()[0]
    scheduler.step()
    avg_recon = total_recon / len(train_loader)
    avg_kl    = total_kl    / len(train_loader)

    # ---------- Validation (QUICK, test-parity: clamp + ridge-blend) ----------
    with torch.no_grad():
        # Queries: encode, clamp, L2
        Q = encode_queries(model, val_query_subset, chunk=1024, clamp_q=99.5)

        # Ridge predictions + blend
        R = torch.from_numpy(ridge.predict(val_query_subset.numpy())).to(Q.device).float()
        R = F.normalize(R, p=2, dim=1)
        Qb = F.normalize((1 - ALPHA_BLEND) * Q + ALPHA_BLEND * R, p=2, dim=1)

        D = im_train_t_unique.to(Q.device)

        rec_cos  = recall_at_k_cos_from_embeds(Qb, D, val_indices_subset.to(Q.device))
        rec_csls = recall_at_k_csls_from_embeds(Qb, D, val_indices_subset.to(Q.device))

    save_marker = ""
    if rec_csls['R@50'] > best_csls_r50:
        best_csls_r50 = rec_csls['R@50']
        torch.save(model.state_dict(), FINAL_SAVE_PATH)
        save_marker = "<- Best (CSLS R@50, clamp+blend) Saved."

    print(
        f"\nEpoch {epoch:03d}: Recon={avg_recon:.4f} KL={avg_kl:.2f} (LR={current_lr:.1E}) | "
        f"COS R@10={rec_cos['R@10']:.4f} R@50={rec_cos['R@50']:.4f} | "
        f"CSLS R@10={rec_csls['R@10']:.4f} R@50={rec_csls['R@50']:.4f} {save_marker}"
    )

    # -------- Diagnostics (epoch 1 and every 5) --------
    if epoch == 1 or epoch % 5 == 0:
        with torch.no_grad():
            k = min(2000, len(tx_val_t))
            tx_cpu  = tx_val_t[:k].float().cpu()
            tgt_cpu = im_val_t[:k].float().cpu()
            # model preds for the same slice
            preds = []
            for i in range(0, k, 1024):
                preds.append(model(tx_cpu[i:i+1024].to(device)).cpu())
            pred_cpu = torch.cat(preds, dim=0)

            emb2emb = DiagnosticMetrics.emb2emb_linear_transfer(tx_cpu, tgt_cpu, model_pred_emb=pred_cpu, alpha=1.0)
            stitch  = DiagnosticMetrics.stitching_penalty_monitor(tx_cpu, tgt_cpu)
            mk_align= DiagnosticMetrics.mutual_knn_alignment(
                F.normalize(tx_cpu, p=2, dim=1),
                F.normalize(tgt_cpu, p=2, dim=1),
                k=10
            )
        print(f"[Diagnostics @ epoch {epoch:03d}]")
        print(f"  emb2emb_linear_transfer: {emb2emb}")
        print(f"  stitching_penalty_monitor: {stitch}")
        print(f"  mutual_knn_alignment_fraction: {mk_align:.4f}")

print(f"\nðŸŽ¯ Training complete. Best model (CSLS R@50 clamp+blend={best_csls_r50:.4f}) saved as {FINAL_SAVE_PATH}")

# -----------------------------
# 10) Inference + Submission (clamp + ridge blend)
# -----------------------------
print(f"Loading BEST model from {FINAL_SAVE_PATH} for inference...")
model = VAETranslator(dropout=0.4, hidden_dim=8192, latent_dim=1536, output_dim=1536).to(device)
model.load_state_dict(torch.load(FINAL_SAVE_PATH, map_location=device))
model.eval()
print("Best model loaded.")

# Encode test queries
Q_test = encode_queries(model, tx_test_t, chunk=1024, clamp_q=99.5)

# Ridge predictions + blend
R_test = torch.from_numpy(ridge.predict(tx_test_t.numpy())).to(Q_test.device).float()
R_test = F.normalize(R_test, p=2, dim=1)
Q_test = F.normalize((1 - ALPHA_BLEND) * Q_test + ALPHA_BLEND * R_test, p=2, dim=1)

# Build submission
test_ids = test_data["captions/ids"].astype(int)
submission = pd.DataFrame({
    "id": test_ids,
    "embedding": [list(map(float, row)) for row in Q_test.cpu().numpy()]
})
sub_name = f"submission_giga_vae_run_{RUN_ID}.csv"
submission.to_csv(sub_name, index=False)
print(f"\nâœ… {sub_name} saved successfully.")

# -----------------------------
# 11) Full validation (cosine + CSLS) with clamp+blend
# -----------------------------
print("Calculating final validation scores (COSINE & CSLS, clamp+blend)...")
Q_val = encode_queries(model, tx_val_t, chunk=1024, clamp_q=99.5)
R_val = torch.from_numpy(ridge.predict(tx_val_t.numpy())).to(Q_val.device).float()
R_val = F.normalize(R_val, p=2, dim=1)
Q_valb = F.normalize((1 - ALPHA_BLEND) * Q_val + ALPHA_BLEND * R_val, p=2, dim=1)

D_val = im_train_t_unique.to(Q_valb.device)
gt    = img_indices_val.to(Q_valb.device)

rec_cos_full  = recall_at_k_cos_from_embeds(Q_valb, D_val, gt)
rec_csls_full = recall_at_k_csls_from_embeds(Q_valb, D_val, gt)

print("\n--- Full Validation Results (COSINE, clamp+blend) ---")
print(f"R@1:  {rec_cos_full['R@1']:<8.4f}")
print(f"R@5:  {rec_cos_full['R@5']:<8.4f}")
print(f"R@10: {rec_cos_full['R@10']:<8.4f}")
print(f"R@50: {rec_cos_full['R@50']:<8.4f}")

print("\n--- Full Validation Results (CSLS, clamp+blend) ---")
print(f"R@1:  {rec_csls_full['R@1']:<8.4f}")
print(f"R@5:  {rec_csls_full['R@5']:<8.4f}")
print(f"R@10: {rec_csls_full['R@10']:<8.4f}")
print(f"R@50: {rec_csls_full['R@50']:<8.4f}")


Using device: cuda
Train shapes: (125000, 1024), (25000, 1536), (125000, 1536)
Test shape: (1500, 1024)
Data preprocessed and normalized (on CPU).
Train pairs: 112500, Validation pairs: 12500
Using CHAMPION ASYMMETRIC loss + *** VAE KL Loss (weight=1e-05) ***
Fitting ridge baseline (text->image) for validation/test blending (alpha=1.0)...
Initializing Giga-VAE (hidden_dim=8192, latent_dim=1536)

--- Starting Training (RUN giga_vae_model_1762511739) ---
Epochs=60, LR=0.0003, WD=0.0004, MixupWarmup=3â†’Î±=0.3



                                                               


Epoch 001: Recon=2.5464 KL=1241.21 (LR=3.0E-04) | COS R@10=0.3394 R@50=0.6114 | CSLS R@10=0.4118 R@50=0.6666 <- Best (CSLS R@50, clamp+blend) Saved.
[Diagnostics @ epoch 001]
  emb2emb_linear_transfer: {'transfer_mse': 0.000489781319629401, 'transfer_cosine': 0.6238479614257812, 'model_cosine': 0.1311386376619339, 'improvement': 0.49270932376384735}
  stitching_penalty_monitor: {'mse_orthogonal': 0.00024280583602376282, 'mse_full': 0.00024280586512759328, 'gap_percent': -1.198596740844772e-05, 'cos_orthogonal': 0.7938115000724792, 'cos_full': 0.7938115000724792}
  mutual_knn_alignment_fraction: 0.1223


                                                               


Epoch 002: Recon=1.4297 KL=1695.65 (LR=3.0E-04) | COS R@10=0.4036 R@50=0.6856 | CSLS R@10=0.4706 R@50=0.7238 <- Best (CSLS R@50, clamp+blend) Saved.


                                                               


Epoch 003: Recon=1.1478 KL=1939.18 (LR=3.0E-04) | COS R@10=0.4630 R@50=0.7188 | CSLS R@10=0.5082 R@50=0.7584 <- Best (CSLS R@50, clamp+blend) Saved.


                                                               


Epoch 004: Recon=0.9690 KL=2066.59 (LR=3.0E-04) | COS R@10=0.4732 R@50=0.7436 | CSLS R@10=0.5262 R@50=0.7716 <- Best (CSLS R@50, clamp+blend) Saved.


                                                               


Epoch 005: Recon=0.8398 KL=2163.82 (LR=3.0E-04) | COS R@10=0.4968 R@50=0.7586 | CSLS R@10=0.5454 R@50=0.7850 <- Best (CSLS R@50, clamp+blend) Saved.
[Diagnostics @ epoch 005]
  emb2emb_linear_transfer: {'transfer_mse': 0.000489781319629401, 'transfer_cosine': 0.6238479614257812, 'model_cosine': 0.16082946956157684, 'improvement': 0.4630184918642044}
  stitching_penalty_monitor: {'mse_orthogonal': 0.00024280583602376282, 'mse_full': 0.00024280586512759328, 'gap_percent': -1.198596740844772e-05, 'cos_orthogonal': 0.7938115000724792, 'cos_full': 0.7938115000724792}
  mutual_knn_alignment_fraction: 0.1223


                                                               


Epoch 006: Recon=0.7511 KL=2253.60 (LR=2.9E-04) | COS R@10=0.5178 R@50=0.7738 | CSLS R@10=0.5598 R@50=0.8028 <- Best (CSLS R@50, clamp+blend) Saved.


                                                               


Epoch 007: Recon=0.6765 KL=2334.25 (LR=2.9E-04) | COS R@10=0.5312 R@50=0.7818 | CSLS R@10=0.5802 R@50=0.8094 <- Best (CSLS R@50, clamp+blend) Saved.


                                                               


Epoch 008: Recon=0.6095 KL=2406.21 (LR=2.9E-04) | COS R@10=0.5448 R@50=0.7880 | CSLS R@10=0.5900 R@50=0.8152 <- Best (CSLS R@50, clamp+blend) Saved.


                                                               


Epoch 009: Recon=0.5659 KL=2491.30 (LR=2.9E-04) | COS R@10=0.5596 R@50=0.7964 | CSLS R@10=0.6024 R@50=0.8220 <- Best (CSLS R@50, clamp+blend) Saved.


                                                               


Epoch 010: Recon=0.4896 KL=2547.76 (LR=2.8E-04) | COS R@10=0.5736 R@50=0.8068 | CSLS R@10=0.6118 R@50=0.8304 <- Best (CSLS R@50, clamp+blend) Saved.
[Diagnostics @ epoch 010]
  emb2emb_linear_transfer: {'transfer_mse': 0.000489781319629401, 'transfer_cosine': 0.6238479614257812, 'model_cosine': 0.17442266643047333, 'improvement': 0.4494252949953079}
  stitching_penalty_monitor: {'mse_orthogonal': 0.00024280583602376282, 'mse_full': 0.00024280586512759328, 'gap_percent': -1.198596740844772e-05, 'cos_orthogonal': 0.7938115000724792, 'cos_full': 0.7938115000724792}
  mutual_knn_alignment_fraction: 0.1223


                                                               


Epoch 011: Recon=0.4421 KL=2572.13 (LR=2.8E-04) | COS R@10=0.5810 R@50=0.8118 | CSLS R@10=0.6168 R@50=0.8320 <- Best (CSLS R@50, clamp+blend) Saved.


                                                               


Epoch 012: Recon=0.4137 KL=2587.99 (LR=2.8E-04) | COS R@10=0.5862 R@50=0.8156 | CSLS R@10=0.6186 R@50=0.8362 <- Best (CSLS R@50, clamp+blend) Saved.


                                                               


Epoch 013: Recon=0.3939 KL=2614.73 (LR=2.7E-04) | COS R@10=0.5872 R@50=0.8164 | CSLS R@10=0.6180 R@50=0.8342 


                                                               


Epoch 014: Recon=0.3757 KL=2626.68 (LR=2.7E-04) | COS R@10=0.5928 R@50=0.8190 | CSLS R@10=0.6260 R@50=0.8390 <- Best (CSLS R@50, clamp+blend) Saved.


                                                               


Epoch 015: Recon=0.3598 KL=2648.87 (LR=2.6E-04) | COS R@10=0.5988 R@50=0.8230 | CSLS R@10=0.6300 R@50=0.8414 <- Best (CSLS R@50, clamp+blend) Saved.
[Diagnostics @ epoch 015]
  emb2emb_linear_transfer: {'transfer_mse': 0.000489781319629401, 'transfer_cosine': 0.6238479614257812, 'model_cosine': 0.17873616516590118, 'improvement': 0.44511179625988007}
  stitching_penalty_monitor: {'mse_orthogonal': 0.00024280583602376282, 'mse_full': 0.00024280586512759328, 'gap_percent': -1.198596740844772e-05, 'cos_orthogonal': 0.7938115000724792, 'cos_full': 0.7938115000724792}
  mutual_knn_alignment_fraction: 0.1223


                                                               


Epoch 016: Recon=0.3404 KL=2672.19 (LR=2.6E-04) | COS R@10=0.6024 R@50=0.8266 | CSLS R@10=0.6306 R@50=0.8412 


                                                               


Epoch 017: Recon=0.3253 KL=2680.88 (LR=2.5E-04) | COS R@10=0.6066 R@50=0.8294 | CSLS R@10=0.6340 R@50=0.8430 <- Best (CSLS R@50, clamp+blend) Saved.


                                                               


Epoch 018: Recon=0.3119 KL=2698.49 (LR=2.4E-04) | COS R@10=0.6092 R@50=0.8268 | CSLS R@10=0.6374 R@50=0.8434 <- Best (CSLS R@50, clamp+blend) Saved.


                                                               


Epoch 019: Recon=0.2907 KL=2721.32 (LR=2.4E-04) | COS R@10=0.6134 R@50=0.8286 | CSLS R@10=0.6410 R@50=0.8494 <- Best (CSLS R@50, clamp+blend) Saved.


                                                               


Epoch 020: Recon=0.2712 KL=2723.42 (LR=2.3E-04) | COS R@10=0.6160 R@50=0.8312 | CSLS R@10=0.6406 R@50=0.8472 
[Diagnostics @ epoch 020]
  emb2emb_linear_transfer: {'transfer_mse': 0.000489781319629401, 'transfer_cosine': 0.6238479614257812, 'model_cosine': 0.18533162772655487, 'improvement': 0.4385163336992264}
  stitching_penalty_monitor: {'mse_orthogonal': 0.00024280583602376282, 'mse_full': 0.00024280586512759328, 'gap_percent': -1.198596740844772e-05, 'cos_orthogonal': 0.7938115000724792, 'cos_full': 0.7938115000724792}
  mutual_knn_alignment_fraction: 0.1223


                                                               


Epoch 021: Recon=0.2649 KL=2748.54 (LR=2.2E-04) | COS R@10=0.6192 R@50=0.8346 | CSLS R@10=0.6488 R@50=0.8472 


                                                               


Epoch 022: Recon=0.2584 KL=2739.24 (LR=2.2E-04) | COS R@10=0.6208 R@50=0.8366 | CSLS R@10=0.6446 R@50=0.8482 


                                                               


Epoch 023: Recon=0.2519 KL=2740.69 (LR=2.1E-04) | COS R@10=0.6238 R@50=0.8348 | CSLS R@10=0.6482 R@50=0.8526 <- Best (CSLS R@50, clamp+blend) Saved.


                                                               


Epoch 024: Recon=0.2467 KL=2747.31 (LR=2.0E-04) | COS R@10=0.6244 R@50=0.8348 | CSLS R@10=0.6476 R@50=0.8498 


                                                               


Epoch 025: Recon=0.2407 KL=2754.46 (LR=2.0E-04) | COS R@10=0.6240 R@50=0.8360 | CSLS R@10=0.6500 R@50=0.8518 
[Diagnostics @ epoch 025]
  emb2emb_linear_transfer: {'transfer_mse': 0.000489781319629401, 'transfer_cosine': 0.6238479614257812, 'model_cosine': 0.18753179907798767, 'improvement': 0.4363161623477936}
  stitching_penalty_monitor: {'mse_orthogonal': 0.00024280583602376282, 'mse_full': 0.00024280586512759328, 'gap_percent': -1.198596740844772e-05, 'cos_orthogonal': 0.7938115000724792, 'cos_full': 0.7938115000724792}
  mutual_knn_alignment_fraction: 0.1223


                                                               


Epoch 026: Recon=0.2348 KL=2753.10 (LR=1.9E-04) | COS R@10=0.6236 R@50=0.8360 | CSLS R@10=0.6548 R@50=0.8538 <- Best (CSLS R@50, clamp+blend) Saved.


                                                               


Epoch 027: Recon=0.2299 KL=2751.23 (LR=1.8E-04) | COS R@10=0.6280 R@50=0.8352 | CSLS R@10=0.6570 R@50=0.8520 


                                                               


Epoch 028: Recon=0.2255 KL=2759.82 (LR=1.7E-04) | COS R@10=0.6326 R@50=0.8402 | CSLS R@10=0.6582 R@50=0.8526 


                                                               


Epoch 029: Recon=0.2210 KL=2769.73 (LR=1.7E-04) | COS R@10=0.6280 R@50=0.8404 | CSLS R@10=0.6590 R@50=0.8538 


                                                               


Epoch 030: Recon=0.2163 KL=2751.50 (LR=1.6E-04) | COS R@10=0.6316 R@50=0.8394 | CSLS R@10=0.6602 R@50=0.8532 
[Diagnostics @ epoch 030]
  emb2emb_linear_transfer: {'transfer_mse': 0.000489781319629401, 'transfer_cosine': 0.6238479614257812, 'model_cosine': 0.189003586769104, 'improvement': 0.43484437465667725}
  stitching_penalty_monitor: {'mse_orthogonal': 0.00024280583602376282, 'mse_full': 0.00024280586512759328, 'gap_percent': -1.198596740844772e-05, 'cos_orthogonal': 0.7938115000724792, 'cos_full': 0.7938115000724792}
  mutual_knn_alignment_fraction: 0.1223


                                                               


Epoch 031: Recon=0.2145 KL=2785.71 (LR=1.5E-04) | COS R@10=0.6320 R@50=0.8424 | CSLS R@10=0.6598 R@50=0.8546 <- Best (CSLS R@50, clamp+blend) Saved.


                                                               


Epoch 032: Recon=0.2090 KL=2768.89 (LR=1.4E-04) | COS R@10=0.6294 R@50=0.8418 | CSLS R@10=0.6608 R@50=0.8542 


                                                               


Epoch 033: Recon=0.2039 KL=2767.32 (LR=1.3E-04) | COS R@10=0.6324 R@50=0.8436 | CSLS R@10=0.6594 R@50=0.8526 


                                                               


Epoch 034: Recon=0.2016 KL=2770.03 (LR=1.3E-04) | COS R@10=0.6322 R@50=0.8434 | CSLS R@10=0.6630 R@50=0.8542 


                                                               


Epoch 035: Recon=0.2010 KL=2762.93 (LR=1.2E-04) | COS R@10=0.6312 R@50=0.8440 | CSLS R@10=0.6616 R@50=0.8542 
[Diagnostics @ epoch 035]
  emb2emb_linear_transfer: {'transfer_mse': 0.000489781319629401, 'transfer_cosine': 0.6238479614257812, 'model_cosine': 0.19159509241580963, 'improvement': 0.4322528690099716}
  stitching_penalty_monitor: {'mse_orthogonal': 0.00024280583602376282, 'mse_full': 0.00024280586512759328, 'gap_percent': -1.198596740844772e-05, 'cos_orthogonal': 0.7938115000724792, 'cos_full': 0.7938115000724792}
  mutual_knn_alignment_fraction: 0.1223


                                                               


Epoch 036: Recon=0.1938 KL=2773.52 (LR=1.1E-04) | COS R@10=0.6306 R@50=0.8426 | CSLS R@10=0.6618 R@50=0.8534 


                                                               


Epoch 037: Recon=0.1946 KL=2783.84 (LR=1.0E-04) | COS R@10=0.6350 R@50=0.8458 | CSLS R@10=0.6644 R@50=0.8562 <- Best (CSLS R@50, clamp+blend) Saved.


                                                               


Epoch 038: Recon=0.1928 KL=2785.75 (LR=9.6E-05) | COS R@10=0.6350 R@50=0.8432 | CSLS R@10=0.6618 R@50=0.8560 


                                                               


Epoch 039: Recon=0.1869 KL=2779.84 (LR=8.9E-05) | COS R@10=0.6348 R@50=0.8444 | CSLS R@10=0.6610 R@50=0.8548 


                                                               


Epoch 040: Recon=0.1888 KL=2780.63 (LR=8.2E-05) | COS R@10=0.6352 R@50=0.8466 | CSLS R@10=0.6626 R@50=0.8548 
[Diagnostics @ epoch 040]
  emb2emb_linear_transfer: {'transfer_mse': 0.000489781319629401, 'transfer_cosine': 0.6238479614257812, 'model_cosine': 0.19284290075302124, 'improvement': 0.43100506067276}
  stitching_penalty_monitor: {'mse_orthogonal': 0.00024280583602376282, 'mse_full': 0.00024280586512759328, 'gap_percent': -1.198596740844772e-05, 'cos_orthogonal': 0.7938115000724792, 'cos_full': 0.7938115000724792}
  mutual_knn_alignment_fraction: 0.1223


                                                               


Epoch 041: Recon=0.1820 KL=2764.17 (LR=7.5E-05) | COS R@10=0.6352 R@50=0.8458 | CSLS R@10=0.6656 R@50=0.8556 


                                                               


Epoch 042: Recon=0.1836 KL=2801.22 (LR=6.8E-05) | COS R@10=0.6358 R@50=0.8460 | CSLS R@10=0.6656 R@50=0.8556 


                                                               


Epoch 043: Recon=0.1806 KL=2770.70 (LR=6.2E-05) | COS R@10=0.6372 R@50=0.8472 | CSLS R@10=0.6644 R@50=0.8558 


                                                               


Epoch 044: Recon=0.1773 KL=2779.94 (LR=5.6E-05) | COS R@10=0.6392 R@50=0.8478 | CSLS R@10=0.6662 R@50=0.8556 


                                                               


Epoch 045: Recon=0.1782 KL=2787.80 (LR=5.0E-05) | COS R@10=0.6402 R@50=0.8474 | CSLS R@10=0.6654 R@50=0.8550 
[Diagnostics @ epoch 045]
  emb2emb_linear_transfer: {'transfer_mse': 0.000489781319629401, 'transfer_cosine': 0.6238479614257812, 'model_cosine': 0.1938338279724121, 'improvement': 0.43001413345336914}
  stitching_penalty_monitor: {'mse_orthogonal': 0.00024280583602376282, 'mse_full': 0.00024280586512759328, 'gap_percent': -1.198596740844772e-05, 'cos_orthogonal': 0.7938115000724792, 'cos_full': 0.7938115000724792}
  mutual_knn_alignment_fraction: 0.1223


                                                               


Epoch 046: Recon=0.1768 KL=2777.08 (LR=4.4E-05) | COS R@10=0.6422 R@50=0.8466 | CSLS R@10=0.6674 R@50=0.8552 


                                                               


Epoch 047: Recon=0.1774 KL=2782.83 (LR=3.9E-05) | COS R@10=0.6434 R@50=0.8468 | CSLS R@10=0.6674 R@50=0.8568 <- Best (CSLS R@50, clamp+blend) Saved.


                                                               


Epoch 048: Recon=0.1748 KL=2782.07 (LR=3.3E-05) | COS R@10=0.6410 R@50=0.8484 | CSLS R@10=0.6670 R@50=0.8566 


                                                               


Epoch 049: Recon=0.1739 KL=2774.45 (LR=2.9E-05) | COS R@10=0.6400 R@50=0.8472 | CSLS R@10=0.6674 R@50=0.8562 


                                                               


Epoch 050: Recon=0.1739 KL=2790.00 (LR=2.4E-05) | COS R@10=0.6402 R@50=0.8460 | CSLS R@10=0.6680 R@50=0.8566 
[Diagnostics @ epoch 050]
  emb2emb_linear_transfer: {'transfer_mse': 0.000489781319629401, 'transfer_cosine': 0.6238479614257812, 'model_cosine': 0.1943666636943817, 'improvement': 0.42948129773139954}
  stitching_penalty_monitor: {'mse_orthogonal': 0.00024280583602376282, 'mse_full': 0.00024280586512759328, 'gap_percent': -1.198596740844772e-05, 'cos_orthogonal': 0.7938115000724792, 'cos_full': 0.7938115000724792}
  mutual_knn_alignment_fraction: 0.1223


                                                               


Epoch 051: Recon=0.1717 KL=2764.16 (LR=2.0E-05) | COS R@10=0.6406 R@50=0.8478 | CSLS R@10=0.6672 R@50=0.8570 <- Best (CSLS R@50, clamp+blend) Saved.


                                                               


Epoch 052: Recon=0.1718 KL=2785.36 (LR=1.6E-05) | COS R@10=0.6418 R@50=0.8482 | CSLS R@10=0.6688 R@50=0.8572 <- Best (CSLS R@50, clamp+blend) Saved.


                                                               


Epoch 053: Recon=0.1715 KL=2778.67 (LR=1.3E-05) | COS R@10=0.6420 R@50=0.8486 | CSLS R@10=0.6680 R@50=0.8572 <- Best (CSLS R@50, clamp+blend) Saved.


                                                               


Epoch 054: Recon=0.1728 KL=2779.73 (LR=1.0E-05) | COS R@10=0.6420 R@50=0.8474 | CSLS R@10=0.6680 R@50=0.8572 <- Best (CSLS R@50, clamp+blend) Saved.


                                                               


Epoch 055: Recon=0.1711 KL=2778.41 (LR=7.3E-06) | COS R@10=0.6424 R@50=0.8482 | CSLS R@10=0.6678 R@50=0.8572 
[Diagnostics @ epoch 055]
  emb2emb_linear_transfer: {'transfer_mse': 0.000489781319629401, 'transfer_cosine': 0.6238479614257812, 'model_cosine': 0.1943126916885376, 'improvement': 0.42953526973724365}
  stitching_penalty_monitor: {'mse_orthogonal': 0.00024280583602376282, 'mse_full': 0.00024280586512759328, 'gap_percent': -1.198596740844772e-05, 'cos_orthogonal': 0.7938115000724792, 'cos_full': 0.7938115000724792}
  mutual_knn_alignment_fraction: 0.1223


                                                               


Epoch 056: Recon=0.1720 KL=2777.56 (LR=5.1E-06) | COS R@10=0.6428 R@50=0.8490 | CSLS R@10=0.6676 R@50=0.8568 


                                                               


Epoch 057: Recon=0.1715 KL=2774.95 (LR=3.3E-06) | COS R@10=0.6422 R@50=0.8490 | CSLS R@10=0.6684 R@50=0.8572 


                                                               


Epoch 058: Recon=0.1716 KL=2792.53 (LR=1.8E-06) | COS R@10=0.6416 R@50=0.8486 | CSLS R@10=0.6682 R@50=0.8572 


                                                               


Epoch 059: Recon=0.1706 KL=2782.46 (LR=8.2E-07) | COS R@10=0.6416 R@50=0.8486 | CSLS R@10=0.6682 R@50=0.8572 


                                                               


Epoch 060: Recon=0.1698 KL=2776.88 (LR=2.1E-07) | COS R@10=0.6418 R@50=0.8488 | CSLS R@10=0.6682 R@50=0.8572 
[Diagnostics @ epoch 060]
  emb2emb_linear_transfer: {'transfer_mse': 0.000489781319629401, 'transfer_cosine': 0.6238479614257812, 'model_cosine': 0.1943119466304779, 'improvement': 0.42953601479530334}
  stitching_penalty_monitor: {'mse_orthogonal': 0.00024280583602376282, 'mse_full': 0.00024280586512759328, 'gap_percent': -1.198596740844772e-05, 'cos_orthogonal': 0.7938115000724792, 'cos_full': 0.7938115000724792}
  mutual_knn_alignment_fraction: 0.1223

ðŸŽ¯ Training complete. Best model (CSLS R@50 clamp+blend=0.8572) saved as giga_vae_model_giga_vae_model_1762511739_best.pth
Loading BEST model from giga_vae_model_giga_vae_model_1762511739_best.pth for inference...
Initializing Giga-VAE (hidden_dim=8192, latent_dim=1536)
Best model loaded.

âœ… submission_giga_vae_run_giga_vae_model_1762511739.csv saved successfully.
Calculating final validation scores (COSINE & CSLS, clamp

OutOfMemoryError: CUDA out of memory. Tried to allocate 3.49 GiB. GPU 0 has a total capacity of 14.74 GiB of which 2.74 GiB is free. Process 4234 has 12.00 GiB memory in use. Of the allocated memory 11.21 GiB is allocated by PyTorch, and 662.17 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)