In [1]:
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

from train import training_loop_
from utils.utils import *
import toml
import torch
import numpy as np
import random
import accelerate
from utils.collate import MultiencoderTokenizedDataset, TokenizedCollator
from utils.model_utils import get_sentence_embedding_dimension, load_encoder
from utils.streaming_utils import load_streaming_embeddings, process_batch
from translators.Discriminator import Discriminator
from torch.optim.lr_scheduler import LambdaLR
from utils.gan import LeastSquaresGAN, RelativisticGAN, VanillaGAN
from utils.eval_utils import EarlyStopper, eval_loop_
from utils.wandb_logger import Logger

In [2]:
cfg_dict = toml.load(f"configs/unsupervised.toml")
cfg = SimpleNamespace(**{k: v for d in cfg_dict.values() for k, v in d.items()})
cfg.num_points = 10
cfg.epochs = 10
use_val_set = hasattr(cfg, "val_size")
cfg

namespace(seed=5,
          sampling_seed=5,
          train_dataset_seed=10,
          val_dataset_seed=42,
          normalize_embeddings=True,
          mixed_precision='fp16',
          weight_init='kaiming',
          dataset='nq',
          max_seq_length=64,
          unsup_emb='stella',
          sup_emb='gte',
          n_embs_per_batch=1,
          finetune_mode=False,
          noise_level=0.0,
          style='res_mlp',
          norm_style='batch',
          depth=3,
          transform_depth=4,
          d_adapter=1024,
          d_hidden=1024,
          d_transform=1024,
          use_small_output_adapters=False,
          use_residual_adapters=True,
          gan_style='least_squares',
          disc_depth=5,
          disc_dim=1024,
          use_residual=True,
          bs=256,
          gradient_accumulation_steps=1,
          lr=2e-05,
          no_scheduler=True,
          max_grad_norm=1000.0,
          loss_coefficient_reverse_rec=0.0,
          loss_coefficient_

In [3]:
random.seed(cfg.seed)
torch.manual_seed(cfg.seed)
np.random.seed(cfg.seed)
torch.cuda.manual_seed(cfg.seed)

In [4]:
mp = "no"
accelerator = accelerate.Accelerator(
    mixed_precision=mp,  # None „Å´„Åó„Å™„ÅÑ
    gradient_accumulation_steps=cfg.gradient_accumulation_steps,
)

# https://github.com/huggingface/transformers/issues/26548
accelerator.dataloader_config.dispatch_batches = False

if hasattr(cfg, "force_wandb_name") and cfg.force_wandb_name:
    save_dir = cfg.save_dir.format(cfg.wandb_name)
else:
    # unknown_cfg „Çí‰Ωø„Çè„Åö„ÄÅcfg ÂÜÖ„ÅÆÊó¢Â≠òÂÄ§„Å†„Åë„Åß wandb_name „ÇíÊ±∫ÂÆö
    if not hasattr(cfg, "wandb_name") or cfg.wandb_name is None:
        cfg.wandb_name = "default_run"  # ‰ªªÊÑè„ÅÆ„Éá„Éï„Ç©„É´„ÉàÂêç
    save_dir = cfg.save_dir.format(
        cfg.latent_dims if hasattr(cfg, "latent_dims") else cfg.wandb_name
    )

logger = Logger(
    project=cfg.wandb_project,
    name=cfg.wandb_name,
    dummy=(cfg.wandb_project is None) or not (cfg.use_wandb),
    config=cfg,
)

print("Running Experiment:", cfg.wandb_name)

[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33minoue0426[0m ([33mSanderLab[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Detected [openai] in use.
[34m[1mwandb[0m: Use W&B Weave for improved LLM call tracing. Weave is installed but not imported. Add `import weave` to the top of your script.
[34m[1mwandb[0m: For more information, check out the docs at: https://weave-docs.wandb.ai/


Running Experiment: unsupervised


In [5]:
sup_encs = {
    cfg.sup_emb: load_encoder(
        cfg.sup_emb,
        mixed_precision=(
            cfg.mixed_precision if hasattr(cfg, "mixed_precision") else None
        ),
    )
}
encoder_dims = {cfg.sup_emb: get_sentence_embedding_dimension(sup_encs[cfg.sup_emb])}
translator = load_n_translator(cfg, encoder_dims)

initializing <class 'torch.nn.modules.linear.Linear'>
initializing <class 'torch.nn.modules.linear.Linear'>
initializing <class 'torch.nn.modules.linear.Linear'>
initializing <class 'torch.nn.modules.linear.Linear'>
initializing <class 'torch.nn.modules.linear.Linear'>
initializing <class 'torch.nn.modules.linear.Linear'>
initializing <class 'torch.nn.modules.linear.Linear'>
initializing <class 'torch.nn.modules.linear.Linear'>
initializing <class 'torch.nn.modules.linear.Linear'>
initializing <class 'torch.nn.modules.linear.Linear'>
initializing <class 'torch.nn.modules.linear.Linear'>
initializing <class 'torch.nn.modules.linear.Linear'>
initializing <class 'torch.nn.modules.linear.Linear'>


In [6]:
unsup_enc = {
    cfg.unsup_emb: load_encoder(
        cfg.unsup_emb,
        mixed_precision=(
            cfg.mixed_precision if hasattr(cfg, "mixed_precision") else None
        ),
    )
}
unsup_dim = {cfg.unsup_emb: get_sentence_embedding_dimension(unsup_enc[cfg.unsup_emb])}
translator.add_encoders(unsup_dim, overwrite_embs=[cfg.unsup_emb])

No sentence-transformers model found with name infgrad/stella-base-en-v2. Creating a new one with mean pooling.


initializing <class 'torch.nn.modules.linear.Linear'>
initializing <class 'torch.nn.modules.linear.Linear'>
initializing <class 'torch.nn.modules.linear.Linear'>
initializing <class 'torch.nn.modules.linear.Linear'>
initializing <class 'torch.nn.modules.linear.Linear'>
initializing <class 'torch.nn.modules.linear.Linear'>
initializing <class 'torch.nn.modules.linear.Linear'>
initializing <class 'torch.nn.modules.linear.Linear'>


In [7]:
cfg.num_params = sum(x.numel() for x in translator.parameters())
num_workers = min(get_num_proc(), 8)
dset = load_streaming_embeddings(cfg.dataset)
print(f"Using {num_workers} workers and {len(dset)} datapoints")

dset_dict = dset.train_test_split(test_size=cfg.val_size, seed=cfg.val_dataset_seed)
dset = dset_dict["train"]
valset = dset_dict["test"]

assert hasattr(cfg, "num_points") or hasattr(cfg, "unsup_points")
dset = dset.shuffle(seed=cfg.train_dataset_seed)
if hasattr(cfg, "num_points"):
    assert cfg.num_points > 0 and cfg.num_points <= len(dset) // 2
    supset = dset.select(range(cfg.num_points))
    unsupset = dset.select(range(cfg.num_points, cfg.num_points * 2))
elif hasattr(cfg, "unsup_points"):
    unsupset = dset.select(range(min(cfg.unsup_points, len(dset))))
    supset = dset.select(
        range(min(cfg.unsup_points, len(dset)), len(dset) - len(unsupset))
    )

Using 8 workers and 5332023 datapoints


In [8]:
# --- „Éá„Éê„Ç§„Çπ„Å´Âøú„Åò„Å¶ pin_memory „ÇíË®≠ÂÆö ---
use_cuda = torch.cuda.is_available()
use_mps = torch.backends.mps.is_available()
pin_memory_flag = True if use_cuda else False

print(
    f"[Init] Device: {'CUDA' if use_cuda else ('MPS' if use_mps else 'CPU')}, pin_memory={pin_memory_flag}"
)

# --- Dataset „ÅÆ‰ΩúÊàê ---
supset = MultiencoderTokenizedDataset(
    dataset=supset,
    encoders=sup_encs,
    n_embs_per_batch=cfg.n_embs_per_batch,
    batch_size=cfg.bs,
    max_length=cfg.max_seq_length,
    seed=cfg.sampling_seed,
)

unsupset = MultiencoderTokenizedDataset(
    dataset=unsupset,
    encoders=unsup_enc,
    n_embs_per_batch=1,
    batch_size=cfg.bs,
    max_length=cfg.max_seq_length,
    seed=cfg.sampling_seed,
)

# --- DataLoader „ÅÆ‰ΩúÊàê ---
sup_dataloader = DataLoader(
    supset,
    batch_size=cfg.bs,
    num_workers=num_workers // 2,
    shuffle=True,
    pin_memory=pin_memory_flag,
    prefetch_factor=None,
    collate_fn=TokenizedCollator(),
    drop_last=True,
)

unsup_dataloader = DataLoader(
    unsupset,
    batch_size=cfg.bs,
    num_workers=num_workers // 2,
    shuffle=True,
    pin_memory=pin_memory_flag,
    prefetch_factor=None,
    collate_fn=TokenizedCollator(),
    drop_last=True,
)

if use_val_set:
    valset = MultiencoderTokenizedDataset(
        dataset=valset,
        encoders={**unsup_enc, **sup_encs},
        n_embs_per_batch=2,
        batch_size=cfg.val_bs,
        max_length=cfg.max_seq_length,
        seed=cfg.sampling_seed,
    )

    valloader = DataLoader(
        valset,
        batch_size=cfg.val_bs if hasattr(cfg, "val_bs") else cfg.bs,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=pin_memory_flag,
        prefetch_factor=(8 if num_workers > 0 else None),
        collate_fn=TokenizedCollator(),
        drop_last=True,
    )
    valloader = accelerator.prepare(valloader)

[Init] Device: MPS, pin_memory=False


In [13]:
valloader.dataset

<utils.collate.MultiencoderTokenizedDataset at 0x33c58c6d0>

In [9]:
opt = torch.optim.Adam(
    translator.parameters(), lr=cfg.lr, fused=False, betas=(0.5, 0.999)
)

In [10]:
######################################################################################
disc = Discriminator(
    latent_dim=translator.in_adapters[cfg.unsup_emb].in_dim,
    discriminator_dim=cfg.disc_dim,
    depth=cfg.disc_depth,
    weight_init=cfg.weight_init,
)
disc_opt = torch.optim.Adam(
    disc.parameters(), lr=cfg.disc_lr, eps=cfg.eps, betas=(0.5, 0.999)
)

cfg.num_disc_params = sum(x.numel() for x in disc.parameters())
# print(f"Number of discriminator parameters:", cfg.num_disc_params)
######################################################################################
sup_disc = Discriminator(
    latent_dim=translator.in_adapters[cfg.sup_emb].in_dim,
    discriminator_dim=cfg.disc_dim,
    depth=cfg.disc_depth,
)
sup_disc_opt = torch.optim.Adam(
    sup_disc.parameters(), lr=cfg.disc_lr, eps=cfg.eps, betas=(0.5, 0.999)
)

cfg.num_sup_disc_params = sum(x.numel() for x in sup_disc.parameters())
# print(f"Number of supervised discriminator parameters:", cfg.num_sup_disc_params)
# print(sup_disc)
######################################################################################
latent_disc = Discriminator(
    latent_dim=cfg.d_adapter,
    discriminator_dim=cfg.disc_dim,
    depth=cfg.disc_depth,
    weight_init=cfg.weight_init,
)
latent_disc_opt = torch.optim.RMSprop(
    latent_disc.parameters(), lr=cfg.disc_lr, eps=cfg.eps
)
cfg.num_latent_disc_params = sum(x.numel() for x in latent_disc.parameters())
# print(f"Number of latent discriminator parameters:", cfg.num_latent_disc_params)
# print(latent_disc)
latent_disc_opt = torch.optim.Adam(
    latent_disc.parameters(), lr=cfg.disc_lr, eps=cfg.eps, betas=(0.5, 0.999)
)
######################################################################################
similarity_disc = Discriminator(
    latent_dim=cfg.bs,
    discriminator_dim=cfg.disc_dim,
    depth=cfg.disc_depth,
    weight_init=cfg.weight_init,
)
similarity_disc_opt = torch.optim.RMSprop(
    similarity_disc.parameters(), lr=cfg.disc_lr, eps=cfg.eps
)
cfg.num_similarity_disc_params = sum(x.numel() for x in similarity_disc.parameters())
# print(f"Number of similarity discriminator parameters:", cfg.num_similarity_disc_params)
# print(similarity_disc)
similarity_disc_opt = torch.optim.Adam(
    similarity_disc.parameters(), lr=cfg.disc_lr, eps=cfg.eps, betas=(0.5, 0.999)
)
######################################################################################

In [11]:
max_num_epochs = int(np.ceil(cfg.epochs))
steps_per_epoch = len(supset) // cfg.bs
total_steps = steps_per_epoch * cfg.epochs / cfg.gradient_accumulation_steps
warmup_length = cfg.warmup_length if hasattr(cfg, "warmup_length") else 100


def lr_lambda(step):
    if step < warmup_length:
        return min(1, step / warmup_length)
    else:
        if hasattr(cfg, "no_scheduler") and cfg.no_scheduler:
            return 1
        return 1 - (step - warmup_length) / max(1, total_steps - warmup_length)


scheduler = LambdaLR(opt, lr_lambda=lr_lambda)
disc_scheduler = LambdaLR(disc_opt, lr_lambda=lr_lambda)
sup_disc_scheduler = LambdaLR(sup_disc_opt, lr_lambda=lr_lambda)
latent_disc_scheduler = LambdaLR(latent_disc_opt, lr_lambda=lr_lambda)
similarity_disc_scheduler = LambdaLR(similarity_disc_opt, lr_lambda=lr_lambda)

if cfg.finetune_mode:
    assert hasattr(cfg, "load_dir")
    print(f"Loading models from {cfg.load_dir}...")
    translator.load_state_dict(
        torch.load(cfg.load_dir + "model.pt", map_location="cpu"), strict=False
    )
    disc.load_state_dict(torch.load(cfg.load_dir + "disc.pt", map_location="cpu"))

translator, opt, scheduler = accelerator.prepare(translator, opt, scheduler)
disc, disc_opt, disc_scheduler = accelerator.prepare(disc, disc_opt, disc_scheduler)
sup_disc, sup_disc_opt, sup_disc_scheduler = accelerator.prepare(
    sup_disc, sup_disc_opt, sup_disc_scheduler
)
latent_disc, latent_disc_opt, latent_disc_scheduler = accelerator.prepare(
    latent_disc, latent_disc_opt, latent_disc_scheduler
)
similarity_disc, similarity_disc_opt, similarity_disc_scheduler = accelerator.prepare(
    similarity_disc, similarity_disc_opt, similarity_disc_scheduler
)
sup_dataloader, unsup_dataloader = accelerator.prepare(sup_dataloader, unsup_dataloader)

In [12]:
if cfg.gan_style == "vanilla":
    gan_cls = VanillaGAN
elif cfg.gan_style == "least_squares":
    gan_cls = LeastSquaresGAN
elif cfg.gan_style == "relativistic":
    gan_cls = RelativisticGAN
else:
    raise ValueError(f"Unknown GAN style: {cfg.gan_style}")
latent_gan = gan_cls(
    cfg=cfg,
    generator=translator,
    discriminator=latent_disc,
    discriminator_opt=latent_disc_opt,
    discriminator_scheduler=latent_disc_scheduler,
    accelerator=accelerator,
)
similarity_gan = gan_cls(
    cfg=cfg,
    generator=translator,
    discriminator=similarity_disc,
    discriminator_opt=similarity_disc_opt,
    discriminator_scheduler=similarity_disc_scheduler,
    accelerator=accelerator,
)
gan = gan_cls(
    cfg=cfg,
    generator=translator,
    discriminator=disc,
    discriminator_opt=disc_opt,
    discriminator_scheduler=disc_scheduler,
    accelerator=accelerator,
)
sup_gan = gan_cls(
    cfg=cfg,
    generator=translator,
    discriminator=sup_disc,
    discriminator_opt=sup_disc_opt,
    discriminator_scheduler=sup_disc_scheduler,
    accelerator=accelerator,
)

sup_iter = None
if hasattr(cfg, "unsup_points"):
    sup_iter = iter(sup_dataloader)

if hasattr(cfg, "val_size") and hasattr(cfg, "patience") and hasattr(cfg, "min_delta"):
    early_stopper = EarlyStopper(
        patience=cfg.patience, min_delta=cfg.min_delta, increase=False
    )
    early_stopping = True
else:
    early_stopping = False

In [13]:
def print_val_summary(epoch, val_res):
    print(f"\nüìä ===== Validation Summary (Epoch {epoch + 1}) =====")

    # ÂÜçÊßãÊàê (reconstruction) Á≥ª
    rec_gte = val_res.get("val/rec_gte_cos", None)
    rec_stella = val_res.get("val/rec_stella_cos", None)
    if rec_gte is not None or rec_stella is not None:
        print(f"  üîπ Reconstruction Cosine Similarity:")
        if rec_gte is not None:
            print(f"    - GTE self-cosine:     {rec_gte:.4f}")
        if rec_stella is not None:
            print(f"    - STELLA self-cosine:  {rec_stella:.4f}")

    # ÁøªË®≥ (translation) Á≥ª
    trans_cos = val_res.get("val/gte_stella_cos", None)
    trans_vsp = val_res.get("val/gte_stella_vsp", None)
    if trans_cos is not None or trans_vsp is not None:
        print(f"  üîπ GTE ‚Üí STELLA translation:")
        if trans_cos is not None:
            print(f"    - Cosine: {trans_cos:.4f}")
        if trans_vsp is not None:
            print(f"    - VSP:    {trans_vsp:.4f}")

    trans_cos2 = val_res.get("val/stella_gte_cos", None)
    trans_vsp2 = val_res.get("val/stella_gte_vsp", None)
    if trans_cos2 is not None or trans_vsp2 is not None:
        print(f"  üîπ STELLA ‚Üí GTE translation:")
        if trans_cos2 is not None:
            print(f"    - Cosine: {trans_cos2:.4f}")
        if trans_vsp2 is not None:
            print(f"    - VSP:    {trans_vsp2:.4f}")

    # Accuracy Á≥ª
    top1 = val_res.get("val/gte_stella_top_1_acc (avg. 4 batches)", None)
    top16 = val_res.get("val/gte_stella_top_16_acc (avg. 4 batches)", None)
    rank = val_res.get("val/gte_stella_rank (avg. 4 batches)", None)
    if top1 is not None or top16 is not None or rank is not None:
        print(f"  üîπ Retrieval metrics (GTE‚ÜíSTELLA):")
        if top1 is not None:
            print(f"    - Top-1 Acc:    {top1:.3f}")
        if top16 is not None:
            print(f"    - Top-16 Acc:   {top16:.3f}")
        if rank is not None:
            print(f"    - Avg. Rank:    {rank:.1f}")

    # Rank variance / SE
    rank_se = val_res.get("val/gte_stella_rank_se (avg. 4 batches)", None)
    if rank_se is not None:
        print(f"    - Rank StdErr:  {rank_se:.3f}")

    print("===========================================")

In [14]:
for epoch in range(max_num_epochs):
    print(f"\n===== Epoch {epoch + 1}/{max_num_epochs} =====")

    # --- Validation ---
    if use_val_set:
        print("[Eval] Running validation...")
        with torch.no_grad(), accelerator.autocast():
            translator.eval()
            val_res = {}
            recons, trans, heatmap_dict, _, _, _ = eval_loop_(
                cfg,
                translator,
                {**sup_encs, **unsup_enc},
                valloader,
                device=accelerator.device,
            )

            print(f"[Eval] Validation finished. Processing results...")
            for flag, res in recons.items():
                for k, v in res.items():
                    if k == "cos":
                        val_res[f"val/rec_{flag}_{k}"] = v

            for target_flag, d in trans.items():
                for flag, res in d.items():
                    for k, v in res.items():
                        if flag == cfg.unsup_emb and target_flag == cfg.unsup_emb:
                            continue
                        val_res[f"val/{flag}_{target_flag}_{k}"] = v

            if len(heatmap_dict) > 0:
                print(f"[Eval] Found {len(heatmap_dict)} heatmap metrics.")
                for k, v in heatmap_dict.items():
                    if "heatmap" in k and "top" not in k:
                        val_res[f"val/{k}"] = v
                    else:
                        val_res[f"val/{k} (avg. {cfg.top_k_batches} batches)"] = v

            translator.train()
            print("[Eval] Validation metrics collected.")

        # --- Early Stopping ---
        if epoch >= getattr(cfg, "min_epochs", 0) and early_stopping:
            score_values = [v for k, v in val_res.items() if "top_rank" in k]
            score = np.mean(score_values) if len(score_values) > 0 else 0.0
            print(f"[Eval] Validation score (mean top_rank): {score:.4f}")

            if early_stopper.early_stop(score):
                print("üõë Early stopping triggered!")
                break
            if early_stopper.counter == 0 and score < early_stopper.opt_val:
                print(
                    f"[Eval] Saving model ‚Äî new best score ({score:.4f} < {early_stopper.opt_val:.4f})"
                )
                save_everything(
                    cfg,
                    translator,
                    opt,
                    [gan, sup_gan, latent_gan, similarity_gan],
                    save_dir,
                )

    print_val_summary(epoch, val_res)

    # --- Training ---
    max_num_batches = None
    print(f"[Train] Starting training loop for epoch {epoch + 1}")
    if epoch + 1 >= max_num_epochs:
        max_num_batches = max(1, int((cfg.epochs - epoch) * len(supset) // cfg.bs))
        print(
            f"[Train] Final epoch detected. Setting max_num_batches = {max_num_batches}"
        )

    sup_iter = training_loop_(
        save_dir=save_dir,
        accelerator=accelerator,
        translator=translator,
        gan=gan,
        sup_gan=sup_gan,
        latent_gan=latent_gan,
        similarity_gan=similarity_gan,
        sup_dataloader=sup_dataloader,
        sup_iter=sup_iter,
        unsup_dataloader=unsup_dataloader,
        sup_encs=sup_encs,
        unsup_enc=unsup_enc,
        cfg=cfg,
        opt=opt,
        scheduler=scheduler,
        logger=logger,
        max_num_batches=max_num_batches,
    )

    print(f"[Train] Epoch {epoch + 1} completed.")
    print("-----------------------------------------")


===== Epoch 1/10 =====
[Eval] Running validation...
[Eval] Validation finished. Processing results...
[Eval] Found 18 heatmap metrics.
[Eval] Validation metrics collected.

üìä ===== Validation Summary (Epoch 1) =====
  üîπ Reconstruction Cosine Similarity:
    - GTE self-cosine:     0.0683
    - STELLA self-cosine:  0.0926
  üîπ GTE ‚Üí STELLA translation:
    - Cosine: 0.0614
    - VSP:    0.3430
  üîπ STELLA ‚Üí GTE translation:
    - Cosine: 0.0270
    - VSP:    0.1270
  üîπ Retrieval metrics (GTE‚ÜíSTELLA):
    - Top-1 Acc:    0.019
    - Top-16 Acc:   0.159
    - Avg. Rank:    203.0
    - Rank StdErr:  3.406
[Train] Starting training loop for epoch 1


Training: 0it [00:07, ?it/s]

[Train] Epoch 1 completed.
-----------------------------------------

===== Epoch 2/10 =====
[Eval] Running validation...





[Eval] Validation finished. Processing results...
[Eval] Found 18 heatmap metrics.
[Eval] Validation metrics collected.

üìä ===== Validation Summary (Epoch 2) =====
  üîπ Reconstruction Cosine Similarity:
    - GTE self-cosine:     0.0683
    - STELLA self-cosine:  0.0926
  üîπ GTE ‚Üí STELLA translation:
    - Cosine: 0.0614
    - VSP:    0.3430
  üîπ STELLA ‚Üí GTE translation:
    - Cosine: 0.0270
    - VSP:    0.1270
  üîπ Retrieval metrics (GTE‚ÜíSTELLA):
    - Top-1 Acc:    0.019
    - Top-16 Acc:   0.159
    - Avg. Rank:    203.0
    - Rank StdErr:  3.406
[Train] Starting training loop for epoch 2


Training: 0it [00:07, ?it/s]

[Train] Epoch 2 completed.
-----------------------------------------

===== Epoch 3/10 =====
[Eval] Running validation...





[Eval] Validation finished. Processing results...
[Eval] Found 18 heatmap metrics.
[Eval] Validation metrics collected.

üìä ===== Validation Summary (Epoch 3) =====
  üîπ Reconstruction Cosine Similarity:
    - GTE self-cosine:     0.0683
    - STELLA self-cosine:  0.0926
  üîπ GTE ‚Üí STELLA translation:
    - Cosine: 0.0614
    - VSP:    0.3430
  üîπ STELLA ‚Üí GTE translation:
    - Cosine: 0.0270
    - VSP:    0.1270
  üîπ Retrieval metrics (GTE‚ÜíSTELLA):
    - Top-1 Acc:    0.019
    - Top-16 Acc:   0.159
    - Avg. Rank:    203.0
    - Rank StdErr:  3.406
[Train] Starting training loop for epoch 3


Training: 0it [00:07, ?it/s]

[Train] Epoch 3 completed.
-----------------------------------------

===== Epoch 4/10 =====
[Eval] Running validation...





[Eval] Validation finished. Processing results...
[Eval] Found 18 heatmap metrics.
[Eval] Validation metrics collected.

üìä ===== Validation Summary (Epoch 4) =====
  üîπ Reconstruction Cosine Similarity:
    - GTE self-cosine:     0.0683
    - STELLA self-cosine:  0.0926
  üîπ GTE ‚Üí STELLA translation:
    - Cosine: 0.0614
    - VSP:    0.3430
  üîπ STELLA ‚Üí GTE translation:
    - Cosine: 0.0270
    - VSP:    0.1270
  üîπ Retrieval metrics (GTE‚ÜíSTELLA):
    - Top-1 Acc:    0.019
    - Top-16 Acc:   0.159
    - Avg. Rank:    203.0
    - Rank StdErr:  3.406
[Train] Starting training loop for epoch 4


Training: 0it [00:07, ?it/s]

[Train] Epoch 4 completed.
-----------------------------------------

===== Epoch 5/10 =====
[Eval] Running validation...





[Eval] Validation finished. Processing results...
[Eval] Found 18 heatmap metrics.
[Eval] Validation metrics collected.

üìä ===== Validation Summary (Epoch 5) =====
  üîπ Reconstruction Cosine Similarity:
    - GTE self-cosine:     0.0683
    - STELLA self-cosine:  0.0926
  üîπ GTE ‚Üí STELLA translation:
    - Cosine: 0.0614
    - VSP:    0.3430
  üîπ STELLA ‚Üí GTE translation:
    - Cosine: 0.0270
    - VSP:    0.1270
  üîπ Retrieval metrics (GTE‚ÜíSTELLA):
    - Top-1 Acc:    0.019
    - Top-16 Acc:   0.159
    - Avg. Rank:    203.0
    - Rank StdErr:  3.406
[Train] Starting training loop for epoch 5


Training: 0it [00:07, ?it/s]

[Train] Epoch 5 completed.
-----------------------------------------

===== Epoch 6/10 =====
[Eval] Running validation...





[Eval] Validation finished. Processing results...
[Eval] Found 18 heatmap metrics.
[Eval] Validation metrics collected.

üìä ===== Validation Summary (Epoch 6) =====
  üîπ Reconstruction Cosine Similarity:
    - GTE self-cosine:     0.0683
    - STELLA self-cosine:  0.0926
  üîπ GTE ‚Üí STELLA translation:
    - Cosine: 0.0614
    - VSP:    0.3430
  üîπ STELLA ‚Üí GTE translation:
    - Cosine: 0.0270
    - VSP:    0.1270
  üîπ Retrieval metrics (GTE‚ÜíSTELLA):
    - Top-1 Acc:    0.019
    - Top-16 Acc:   0.159
    - Avg. Rank:    203.0
    - Rank StdErr:  3.406
[Train] Starting training loop for epoch 6


Training: 0it [00:07, ?it/s]

[Train] Epoch 6 completed.
-----------------------------------------

===== Epoch 7/10 =====
[Eval] Running validation...





[Eval] Validation finished. Processing results...
[Eval] Found 18 heatmap metrics.
[Eval] Validation metrics collected.

üìä ===== Validation Summary (Epoch 7) =====
  üîπ Reconstruction Cosine Similarity:
    - GTE self-cosine:     0.0683
    - STELLA self-cosine:  0.0926
  üîπ GTE ‚Üí STELLA translation:
    - Cosine: 0.0614
    - VSP:    0.3430
  üîπ STELLA ‚Üí GTE translation:
    - Cosine: 0.0270
    - VSP:    0.1270
  üîπ Retrieval metrics (GTE‚ÜíSTELLA):
    - Top-1 Acc:    0.019
    - Top-16 Acc:   0.159
    - Avg. Rank:    203.0
    - Rank StdErr:  3.406
[Train] Starting training loop for epoch 7


Training: 0it [00:07, ?it/s]

[Train] Epoch 7 completed.
-----------------------------------------

===== Epoch 8/10 =====
[Eval] Running validation...





[Eval] Validation finished. Processing results...
[Eval] Found 18 heatmap metrics.
[Eval] Validation metrics collected.

üìä ===== Validation Summary (Epoch 8) =====
  üîπ Reconstruction Cosine Similarity:
    - GTE self-cosine:     0.0683
    - STELLA self-cosine:  0.0926
  üîπ GTE ‚Üí STELLA translation:
    - Cosine: 0.0614
    - VSP:    0.3430
  üîπ STELLA ‚Üí GTE translation:
    - Cosine: 0.0270
    - VSP:    0.1270
  üîπ Retrieval metrics (GTE‚ÜíSTELLA):
    - Top-1 Acc:    0.019
    - Top-16 Acc:   0.159
    - Avg. Rank:    203.0
    - Rank StdErr:  3.406
[Train] Starting training loop for epoch 8


Training: 0it [00:09, ?it/s]

[Train] Epoch 8 completed.
-----------------------------------------

===== Epoch 9/10 =====
[Eval] Running validation...





[Eval] Validation finished. Processing results...
[Eval] Found 18 heatmap metrics.
[Eval] Validation metrics collected.

üìä ===== Validation Summary (Epoch 9) =====
  üîπ Reconstruction Cosine Similarity:
    - GTE self-cosine:     0.0683
    - STELLA self-cosine:  0.0926
  üîπ GTE ‚Üí STELLA translation:
    - Cosine: 0.0614
    - VSP:    0.3430
  üîπ STELLA ‚Üí GTE translation:
    - Cosine: 0.0270
    - VSP:    0.1270
  üîπ Retrieval metrics (GTE‚ÜíSTELLA):
    - Top-1 Acc:    0.019
    - Top-16 Acc:   0.159
    - Avg. Rank:    203.0
    - Rank StdErr:  3.406
[Train] Starting training loop for epoch 9


Training: 0it [00:07, ?it/s]

[Train] Epoch 9 completed.
-----------------------------------------

===== Epoch 10/10 =====
[Eval] Running validation...





[Eval] Validation finished. Processing results...
[Eval] Found 18 heatmap metrics.
[Eval] Validation metrics collected.

üìä ===== Validation Summary (Epoch 10) =====
  üîπ Reconstruction Cosine Similarity:
    - GTE self-cosine:     0.0683
    - STELLA self-cosine:  0.0926
  üîπ GTE ‚Üí STELLA translation:
    - Cosine: 0.0614
    - VSP:    0.3430
  üîπ STELLA ‚Üí GTE translation:
    - Cosine: 0.0270
    - VSP:    0.1270
  üîπ Retrieval metrics (GTE‚ÜíSTELLA):
    - Top-1 Acc:    0.019
    - Top-16 Acc:   0.159
    - Avg. Rank:    203.0
    - Rank StdErr:  3.406
[Train] Starting training loop for epoch 10
[Train] Final epoch detected. Setting max_num_batches = 1


Training: 0it [00:07, ?it/s]

[Train] Epoch 10 completed.
-----------------------------------------



