# Vec2Vec: Colab Reproduction with WGAN Adversarial Loss (~50k Examples)

This notebook reproduces and extends the main vec2vec experiment from the paper:
**"Harnessing the Universal Geometry of Embeddings"** (Jha et al., 2025)

## Experiment Details
- **Source Model**: Stella (`stella` - unsupervised embedding model)
- **Target Model**: GTE (`gte` - supervised embedding model)
- **Dataset**: Natural Questions (NQ)
- **Training Size**: ~50k examples (25k per encoder)
- **Architecture**: ResNet MLP with adapters + adversarial training

## Adversarial Loss Comparison
We compare two adversarial training schemes:

1. **Baseline (GAN)**: Standard GAN with least-squares loss (original vec2vec approach)
2. **WGAN-GP**: Wasserstein GAN with gradient penalty
   - Treats discriminators as critics with unbounded outputs
   - Uses Wasserstein distance estimation
   - Enforces Lipschitz constraint via gradient penalty

The goal is to explore whether Wasserstein-style matching of embedding/latent distributions improves translation quality.

In [1]:
# Check GPU and environment
!nvidia-smi

import sys
print(f"\nPython version: {sys.version}")

# Check CUDA availability
try:
    import torch
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA device: {torch.cuda.get_device_name(0)}")
        print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
except ImportError:
    print("PyTorch will be installed in the next step")

zsh:1: command not found: nvidia-smi

Python version: 3.11.14 (main, Oct 21 2025, 18:27:30) [Clang 20.1.8 ]
PyTorch version: 2.9.1
CUDA available: False


In [None]:
# Install dependencies
# Core packages for vec2vec
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install -q transformers>=4.29.0 sentence-transformers>=2.2.0
!pip install -q datasets>=2.12.0 huggingface_hub>=0.15.0
!pip install -q accelerate>=0.20.0
!pip install -q wandb
!pip install -q scikit-learn scipy matplotlib seaborn
!pip install -q toml pandas
!pip install -q nltk
!pip install -q sentence-transformers
!pip install -q accelerate

[31mERROR: Could not find a version that satisfies the requirement torch (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for torch[0m[31m
[0mzsh:1: 4.29.0 not found
zsh:1: 2.12.0 not found
zsh:1: 0.20.0 not found


  from .autonotebook import tqdm as notebook_tqdm



Installed versions:
  PyTorch: 2.9.1
  Transformers: 4.57.1
  Sentence-Transformers: 5.1.2
  Accelerate: 1.11.0
  CUDA available: False


In [2]:
# Download NLTK data
import nltk
nltk.download('punkt', quiet=True)

# Verify installation
import torch
import transformers
import sentence_transformers
import accelerate

print(f"\nInstalled versions:")
print(f"  PyTorch: {torch.__version__}")
print(f"  Transformers: {transformers.__version__}")
print(f"  Sentence-Transformers: {sentence_transformers.__version__}")
print(f"  Accelerate: {accelerate.__version__}")
print(f"  CUDA available: {torch.cuda.is_available()}")

  from .autonotebook import tqdm as notebook_tqdm



Installed versions:
  PyTorch: 2.9.1
  Transformers: 4.57.1
  Sentence-Transformers: 5.1.2
  Accelerate: 1.11.0
  CUDA available: False


## Data & Embedding Preparation

Vec2vec uses the **Natural Questions (NQ)** dataset from the BeIR benchmark. The embeddings are generated on-the-fly during training using the source (Stella) and target (GTE) embedding models.

The data loading pipeline:
1. Loads NQ corpus from HuggingFace datasets
2. Splits into train/validation sets
3. Creates tokenized batches for both encoders
4. Generates embeddings during forward pass

We'll use **25,000 samples per encoder** (50k total), which is sufficient to demonstrate the approach while keeping training time reasonable.

In [None]:
# !pip install vec2text
# !pip install -U typing_extensions
# !pip install -q torch torchvision torchaudio
# !pip install -q transformers>=4.29.0 sentence-transformers>=2.2.0
# !pip install -q datasets
# !pip install beir

Collecting beir
  Downloading beir-2.2.0-py3-none-any.whl.metadata (28 kB)
Collecting pytrec-eval-terrier (from beir)
  Downloading pytrec_eval_terrier-0.5.10-cp311-cp311-macosx_10_9_universal2.whl.metadata (1.1 kB)
Downloading beir-2.2.0-py3-none-any.whl (77 kB)
Downloading pytrec_eval_terrier-0.5.10-cp311-cp311-macosx_10_9_universal2.whl (136 kB)
Installing collected packages: pytrec-eval-terrier, beir
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2/2[0m [beir]
[1A[2KSuccessfully installed beir-2.2.0 pytrec-eval-terrier-0.5.10


In [10]:
# Test data loading to ensure everything is properly set up
import sys
sys.path.insert(0, '../vec2vec')

from utils.streaming_utils import load_streaming_embeddings

# Load the NQ dataset
print("Loading NQ dataset...")
dset = load_streaming_embeddings("nq")
print(f"Dataset loaded: {len(dset)} examples")
print(f"\nSample entry keys: {list(dset[0].keys())}")

# Show a sample
sample = dset[0]
text_key = 'text' if 'text' in sample else list(sample.keys())[0]
print(f"\nSample text (truncated): {sample[text_key][:200]}...")

# Confirm we have enough data
TRAIN_SIZE = 25000  # per encoder
VAL_SIZE = 4096
REQUIRED = TRAIN_SIZE * 2 + VAL_SIZE

if len(dset) >= REQUIRED:
    print(f"\n Dataset has {len(dset)} examples (need {REQUIRED} for this run)")
else:
    print(f"\n Dataset has {len(dset)} examples, adjusting train size...")
    TRAIN_SIZE = (len(dset) - VAL_SIZE) // 2
    print(f"  New train size: {TRAIN_SIZE} per encoder")

Loading NQ dataset...


Generating train split: 100%|██████████| 5332023/5332023 [00:03<00:00, 1714157.84 examples/s]
Generating dev split: 100%|██████████| 849508/849508 [00:00<00:00, 1672897.71 examples/s]


Dataset loaded: 5332023 examples

Sample entry keys: ['text']

Sample text (truncated): to a short list of finalists. Ties can occur if the panel decides both entries show equal merit, however they are encouraged to choose a single winner. The judges are selected from a public applicatio...

 Dataset has 5332023 examples (need 54096 for this run)


## Adversarial Loss Implementation

### Original vec2vec GAN Loss
The original implementation uses **Least Squares GAN** (LSGAN) with:
- Discriminator loss: `0.5 * (D(real)^2 + (D(fake) - 1)^2)`
- Generator loss: `0.5 * D(fake)^2`

This is applied at:
- **Embedding level**: D1 (unsup space), D2 (sup space)
- **Latent level**: D_latent (shared latent space)

### WGAN-GP (Wasserstein GAN with Gradient Penalty)
We implement WGAN-GP with:
- **Critic loss**: `E[D(fake)] - E[D(real)] + lambda_gp * gradient_penalty`
- **Generator loss**: `-E[D(fake)]`
- **Gradient penalty**: `E[(||grad_D(x_interp)||_2 - 1)^2]`

Key differences:
1. Discriminators are treated as **critics** (no sigmoid, unbounded outputs)
2. Uses Wasserstein distance instead of JS divergence
3. Gradient penalty enforces 1-Lipschitz constraint
4. Generally more stable training dynamics

We add a new `gan_style="wgan"` option to toggle this behavior.

In [11]:
# Add WGAN-GP implementation to the codebase
import os

# Read the current gan.py
gan_path = '../vec2vec/utils/gan.py'
with open(gan_path, 'r') as f:
    gan_code = f.read()

# WGAN-GP implementation to add
wgan_code = '''

class WassersteinGAN(VanillaGAN):
    """Wasserstein GAN with Gradient Penalty (WGAN-GP).
    
    Uses Wasserstein distance estimation with gradient penalty
    to enforce Lipschitz constraint on the critic.
    """
    
    def compute_wgan_gradient_penalty(self, real_data: torch.Tensor, fake_data: torch.Tensor) -> torch.Tensor:
        """Compute gradient penalty for WGAN-GP.
        
        Interpolates between real and fake samples, computes critic output,
        and penalizes deviation of gradient norm from 1.
        """
        batch_size = real_data.size(0)
        device = real_data.device
        
        # Random interpolation coefficient
        epsilon = torch.rand(batch_size, 1, device=device)
        
        # Interpolate between real and fake
        interpolated = epsilon * real_data + (1 - epsilon) * fake_data
        interpolated = interpolated.requires_grad_(True)
        
        # Get critic output for interpolated samples
        d_interpolated = self.discriminator(interpolated)
        
        # Compute gradients
        gradients = torch.autograd.grad(
            outputs=d_interpolated,
            inputs=interpolated,
            grad_outputs=torch.ones_like(d_interpolated),
            create_graph=True,
            retain_graph=True,
        )[0]
        
        # Compute gradient penalty: (||grad||_2 - 1)^2
        gradients = gradients.view(batch_size, -1)
        gradient_norm = gradients.norm(2, dim=1)
        gradient_penalty = ((gradient_norm - 1) ** 2).mean()
        
        return gradient_penalty
    
    def _step_discriminator(self, real_data: torch.Tensor, fake_data: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, float, float]:
        """WGAN critic update step.
        
        Critic loss = E[D(fake)] - E[D(real)] + lambda_gp * gradient_penalty
        """
        real_data = real_data.detach()
        fake_data = fake_data.detach()
        
        # Critic outputs (no sigmoid - raw scores)
        d_real = self.discriminator(real_data)
        d_fake = self.discriminator(fake_data)
        
        # Wasserstein distance estimate (negative because we want to maximize)
        # Critic wants: D(real) high, D(fake) low
        # So critic loss = E[D(fake)] - E[D(real)]
        wasserstein_dist = d_real.mean() - d_fake.mean()
        critic_loss = -wasserstein_dist  # Minimize negative = maximize distance
        
        # Gradient penalty
        gp_lambda = getattr(self.cfg, 'gp_lambda', 10.0)  # Default lambda=10
        gradient_penalty = self.compute_wgan_gradient_penalty(real_data, fake_data)
        
        # Total critic loss
        total_critic_loss = critic_loss + gp_lambda * gradient_penalty
        
        # "Accuracy" metrics (for compatibility with logging)
        # In WGAN, we use the sign of the output as a proxy
        disc_acc_real = (d_real > 0).float().mean().item()
        disc_acc_fake = (d_fake < 0).float().mean().item()
        
        # Backward pass
        self.generator.train()
        self.discriminator_opt.zero_grad()
        self.accelerator.backward(total_critic_loss * self.cfg.loss_coefficient_disc)
        self.accelerator.clip_grad_norm_(
            self.discriminator.parameters(),
            self.cfg.max_grad_norm
        )
        self.discriminator_opt.step()
        self.discriminator_scheduler.step()
        
        return gradient_penalty.detach(), critic_loss.detach(), disc_acc_real, disc_acc_fake
    
    def _step_generator(self, real_data: torch.Tensor, fake_data: torch.Tensor) -> tuple[torch.Tensor, float]:
        """WGAN generator update step.
        
        Generator loss = -E[D(fake)]
        Generator wants critic to think fake samples are real (high scores).
        """
        # Get critic score for fake samples
        d_fake = self.discriminator(fake_data)
        
        # Generator wants to maximize D(fake), so minimize -D(fake)
        gen_loss = -d_fake.mean()
        
        # "Accuracy" metric (proxy)
        gen_acc = (d_fake > 0).float().mean().item()
        
        return gen_loss, gen_acc
'''

# Check if WGAN is already added
if 'class WassersteinGAN' not in gan_code:
    # Add WGAN-GP implementation
    with open(gan_path, 'a') as f:
        f.write(wgan_code)
    print("Added WassersteinGAN class to utils/gan.py")
else:
    print("WassersteinGAN already exists in utils/gan.py")

# Now update train.py to support the wgan style
train_path = '../vec2vec/train.py'
with open(train_path, 'r') as f:
    train_code = f.read()

# Check if we need to add WGAN import and handling
if 'WassersteinGAN' not in train_code:
    # Update the imports
    old_import = 'from utils.gan import LeastSquaresGAN, RelativisticGAN, VanillaGAN'
    new_import = 'from utils.gan import LeastSquaresGAN, RelativisticGAN, VanillaGAN, WassersteinGAN'
    train_code = train_code.replace(old_import, new_import)
    
    # Update the GAN style selection
    old_selection = '''    if cfg.gan_style == "vanilla":
        gan_cls = VanillaGAN
    elif cfg.gan_style == "least_squares":
        gan_cls = LeastSquaresGAN
    elif cfg.gan_style == "relativistic":
        gan_cls = RelativisticGAN
    else:
        raise ValueError(f"Unknown GAN style: {cfg.gan_style}")'''
    
    new_selection = '''    if cfg.gan_style == "vanilla":
        gan_cls = VanillaGAN
    elif cfg.gan_style == "least_squares":
        gan_cls = LeastSquaresGAN
    elif cfg.gan_style == "relativistic":
        gan_cls = RelativisticGAN
    elif cfg.gan_style == "wgan":
        gan_cls = WassersteinGAN
    else:
        raise ValueError(f"Unknown GAN style: {cfg.gan_style}")'''
    
    train_code = train_code.replace(old_selection, new_selection)
    
    with open(train_path, 'w') as f:
        f.write(train_code)
    print("Updated train.py to support gan_style='wgan'")
else:
    print("train.py already supports WassersteinGAN")

print("\nCodebase modifications complete!")

WassersteinGAN already exists in utils/gan.py
train.py already supports WassersteinGAN

Codebase modifications complete!


## Training Configuration

We'll use the `unsupervised.toml` config as the base and run two experiments:

### Run 1: Baseline GAN (Least Squares)
- `gan_style = "least_squares"` (original vec2vec)
- Standard discriminator with MSE-based loss

### Run 2: WGAN-GP
- `gan_style = "wgan"`
- Wasserstein distance with gradient penalty
- `gp_lambda = 10` (gradient penalty coefficient)

### Shared Hyperparameters for Colab:
- **num_points**: 25,000 (samples per encoder, 50k total)
- **epochs**: 8 (reduced for faster iteration)
- **batch_size**: 128 (reduced for GPU memory)
- **learning_rate**: 2e-5 (default)
- **mixed_precision**: fp16 (for memory efficiency)

All other components (architecture, reconstruction loss, VSP loss, etc.) remain identical.

In [12]:
# Display and define configurations for both runs
import toml
import os

# Load base config
config_path = '../vec2vec/configs/unsupervised.toml'
with open(config_path, 'r') as f:
    config = toml.load(f)

print("=" * 70)
print("BASE CONFIGURATION (unsupervised.toml)")
print("=" * 70)

# Key settings
print(f"\n[General]")
print(f"  Dataset: {config['general']['dataset']}")
print(f"  Unsup Model: {config['general']['unsup_emb']}")
print(f"  Sup Model: {config['general']['sup_emb']}")

print(f"\n[Translator]")
print(f"  Style: {config['translator']['style']}")
print(f"  Adapter Dim: {config['translator']['d_adapter']}")

print(f"\n[Discriminator]")
print(f"  GAN Style: {config['discriminator']['gan_style']}")
print(f"  Depth: {config['discriminator']['disc_depth']}")

# Define shared Colab settings
SHARED_CONFIG = {
    'num_points': 25000,       # 25k per encoder (50k total)
    'epochs': 8,               # Reduced for Colab
    'bs': 128,                 # Reduced batch size for T4
    'val_size': 4096,          # Validation set size
    'use_wandb': False,        # Disable W&B for simplicity
    'min_epochs': 4,           # Lower minimum epochs
    'patience': 3,             # Earlier stopping
}

print("\n" + "=" * 70)
print("SHARED COLAB SETTINGS")
print("=" * 70)
for k, v in SHARED_CONFIG.items():
    print(f"  --{k} {v}")

print("\n" + "=" * 70)
print("EXPERIMENT CONFIGURATIONS")
print("=" * 70)

print("\n[Run 1: Baseline GAN (Least Squares)]")
print("  --gan_style least_squares")
print("  Output: outputs/vec2vec_colab_50k_gan/")

print("\n[Run 2: WGAN-GP]")
print("  --gan_style wgan")
print("  --gp_lambda 10")
print("  Output: outputs/vec2vec_colab_50k_wgan/")

print("\n" + "=" * 70)

BASE CONFIGURATION (unsupervised.toml)

[General]
  Dataset: nq
  Unsup Model: stella
  Sup Model: gte

[Translator]
  Style: res_mlp
  Adapter Dim: 1024

[Discriminator]
  GAN Style: least_squares
  Depth: 5

SHARED COLAB SETTINGS
  --num_points 25000
  --epochs 8
  --bs 128
  --val_size 4096
  --use_wandb False
  --min_epochs 4
  --patience 3

EXPERIMENT CONFIGURATIONS

[Run 1: Baseline GAN (Least Squares)]
  --gan_style least_squares
  Output: outputs/vec2vec_colab_50k_gan/

[Run 2: WGAN-GP]
  --gan_style wgan
  --gp_lambda 10
  Output: outputs/vec2vec_colab_50k_wgan/



In [13]:
# Run 1: Train with baseline GAN (Least Squares)
import os
os.chdir('../vec2vec')

# Create output directory
OUTPUT_DIR_GAN = '../vec2vec/outputs/vec2vec_colab_50k_gan'
os.makedirs(OUTPUT_DIR_GAN, exist_ok=True)

# Build training command
CMD_GAN = f"""
python train.py unsupervised \\
    --num_points 25000 \\
    --epochs 8 \\
    --bs 128 \\
    --val_size 4096 \\
    --use_wandb false \\
    --min_epochs 4 \\
    --patience 3 \\
    --gan_style least_squares \\
    --save_dir '{OUTPUT_DIR_GAN}/{{}}' \\
    --force_wandb_name true \\
    --wandb_name gan_baseline
"""

print("=" * 70)
print("RUN 1: BASELINE GAN (Least Squares)")
print("=" * 70)
print("\nTraining command:")
print(CMD_GAN)
print("\nStarting training... (this may take 30-60 minutes on a T4)")
print("=" * 70 + "\n")

# Run training
!{CMD_GAN}

RUN 1: BASELINE GAN (Least Squares)

Training command:

python train.py unsupervised \
    --num_points 25000 \
    --epochs 8 \
    --bs 128 \
    --val_size 4096 \
    --use_wandb false \
    --min_epochs 4 \
    --patience 3 \
    --gan_style least_squares \
    --save_dir '../vec2vec/outputs/vec2vec_colab_50k_gan/{}' \
    --force_wandb_name true \
    --wandb_name gan_baseline


Starting training... (this may take 30-60 minutes on a T4)

^C
object address  : 0x3420bf220
object refcount : 2
object type     : 0x102d509b8
object type name: KeyboardInterrupt
object repr     : KeyboardInterrupt()
lost sys.stderr


In [None]:
# Run 2: Train with WGAN-GP
import os
os.chdir('../vec2vec')

# Create output directory
OUTPUT_DIR_WGAN = '../vec2vec/outputs/vec2vec_colab_50k_wgan'
os.makedirs(OUTPUT_DIR_WGAN, exist_ok=True)

# Build training command
CMD_WGAN = f"""
python train.py unsupervised \\
    --num_points 25000 \\
    --epochs 8 \\
    --bs 128 \\
    --val_size 4096 \\
    --use_wandb false \\
    --min_epochs 4 \\
    --patience 3 \\
    --gan_style wgan \\
    --gp_lambda 10 \\
    --save_dir '{OUTPUT_DIR_WGAN}/{{}}' \\
    --force_wandb_name true \\
    --wandb_name wgan_gp
"""

print("=" * 70)
print("RUN 2: WGAN-GP")
print("=" * 70)
print("\nTraining command:")
print(CMD_WGAN)
print("\nStarting training... (this may take 30-60 minutes on a T4)")
print("=" * 70 + "\n")

# Run training
!{CMD_WGAN}

## Evaluation

We evaluate both trained translators on the same held-out test data using the following metrics:

1. **Mean Cosine Similarity**: Average cosine similarity between translated embeddings and true target embeddings (higher is better)
2. **Top-1 Accuracy**: Percentage of samples where the translated embedding's nearest neighbor is the correct target (higher is better)
3. **Mean Rank**: Average rank of the correct target among all candidates (lower is better)
4. **VSP (Vector Space Preservation)**: How well pairwise similarities are preserved after translation (lower is better)

The comparison will show whether WGAN-GP improves translation quality over the standard least-squares GAN.

In [None]:
# Evaluate both models and compare results
import os
import sys
import glob
import json
import torch
import numpy as np
import pandas as pd
from types import SimpleNamespace

sys.path.insert(0, '/content/vec2vec')
os.chdir('/content/vec2vec')

import toml
import accelerate
from torch.utils.data import DataLoader

from utils.collate import MultiencoderTokenizedDataset, TokenizedCollator
from utils.eval_utils import eval_loop_, create_heatmap
from utils.model_utils import get_sentence_embedding_dimension, load_encoder
from utils.utils import load_n_translator, get_num_proc
from utils.streaming_utils import load_streaming_embeddings, process_batch

def find_checkpoint(base_dir):
    """Find the model checkpoint directory."""
    checkpoint_dirs = glob.glob(f"{base_dir}/**/model.pt", recursive=True)
    if checkpoint_dirs:
        return os.path.dirname(sorted(checkpoint_dirs, key=os.path.getmtime)[-1])
    return None

def evaluate_model(checkpoint_path, test_size=4096):
    """Evaluate a trained model and return metrics."""
    if not checkpoint_path:
        return None
    
    # Load config
    config_file = os.path.join(checkpoint_path, 'config.toml')
    if not os.path.exists(config_file):
        print(f"Config not found: {config_file}")
        return None
    
    cfg = SimpleNamespace(**toml.load(config_file))
    
    # Setup accelerator
    accelerator = accelerate.Accelerator(
        mixed_precision=cfg.mixed_precision if hasattr(cfg, 'mixed_precision') else None
    )
    accelerator.dataloader_config.dispatch_batches = False
    
    # Load encoders
    sup_encs = {cfg.sup_emb: load_encoder(cfg.sup_emb, mixed_precision=getattr(cfg, 'mixed_precision', None))}
    unsup_enc = {cfg.unsup_emb: load_encoder(cfg.unsup_emb, mixed_precision=getattr(cfg, 'mixed_precision', None))}
    
    # Load translator
    encoder_dims = {cfg.sup_emb: get_sentence_embedding_dimension(sup_encs[cfg.sup_emb])}
    translator = load_n_translator(cfg, encoder_dims)
    unsup_dim = {cfg.unsup_emb: get_sentence_embedding_dimension(unsup_enc[cfg.unsup_emb])}
    translator.add_encoders(unsup_dim, overwrite_embs=[cfg.unsup_emb])
    
    # Load weights
    translator.load_state_dict(torch.load(os.path.join(checkpoint_path, 'model.pt'), map_location='cpu'), strict=False)
    translator = accelerator.prepare(translator)
    translator.eval()
    
    # Load test data
    dset = load_streaming_embeddings(cfg.dataset)
    dset_dict = dset.train_test_split(test_size=test_size, seed=cfg.val_dataset_seed)
    testset = dset_dict["test"]
    
    num_workers = min(get_num_proc(), 4)
    evalset = MultiencoderTokenizedDataset(
        dataset=testset,
        encoders={**unsup_enc, **sup_encs},
        n_embs_per_batch=2,
        batch_size=cfg.val_bs if hasattr(cfg, 'val_bs') else 256,
        max_length=cfg.max_seq_length,
        seed=cfg.sampling_seed,
    )
    evalloader = DataLoader(
        evalset,
        batch_size=cfg.val_bs if hasattr(cfg, 'val_bs') else 256,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=True,
        collate_fn=TokenizedCollator(),
        drop_last=True,
    )
    evalloader = accelerator.prepare(evalloader)
    
    # Run evaluation
    with torch.no_grad():
        recons, trans, heatmap_dict, _, _, _ = eval_loop_(
            cfg, translator, {**sup_encs, **unsup_enc}, evalloader, device=accelerator.device
        )
    
    # Extract metrics
    metrics = {
        'gan_style': cfg.gan_style,
        'train_size': cfg.num_points,
        'test_size': test_size,
    }
    
    # Get translation metrics (unsup -> sup)
    trans_key = f"{cfg.unsup_emb}_{cfg.sup_emb}"
    if cfg.sup_emb in trans and cfg.unsup_emb in trans[cfg.sup_emb]:
        t = trans[cfg.sup_emb][cfg.unsup_emb]
        metrics['cosine'] = t.get('cos', 0)
        metrics['vsp'] = t.get('vsp', 0)
    
    # Get top-1 and rank from heatmap_dict
    for k, v in heatmap_dict.items():
        if f"{cfg.unsup_emb}_{cfg.sup_emb}_top_1_acc" in k:
            metrics['top1'] = v
        elif f"{cfg.unsup_emb}_{cfg.sup_emb}_rank" in k and 'var' not in k:
            metrics['rank'] = v
    
    return metrics

# Find checkpoints
gan_checkpoint = find_checkpoint('/content/vec2vec/outputs/vec2vec_colab_50k_gan')
wgan_checkpoint = find_checkpoint('/content/vec2vec/outputs/vec2vec_colab_50k_wgan')

print("=" * 70)
print("EVALUATING TRAINED MODELS")
print("=" * 70)

results = []

# Evaluate GAN baseline
if gan_checkpoint:
    print(f"\nEvaluating GAN baseline: {gan_checkpoint}")
    gan_metrics = evaluate_model(gan_checkpoint)
    if gan_metrics:
        gan_metrics['adv_type'] = 'GAN (LS)'
        results.append(gan_metrics)
        print(f"  Done!")
else:
    print("\nGAN baseline checkpoint not found")

# Evaluate WGAN
if wgan_checkpoint:
    print(f"\nEvaluating WGAN-GP: {wgan_checkpoint}")
    wgan_metrics = evaluate_model(wgan_checkpoint)
    if wgan_metrics:
        wgan_metrics['adv_type'] = 'WGAN-GP'
        results.append(wgan_metrics)
        print(f"  Done!")
else:
    print("\nWGAN-GP checkpoint not found")

print("\n" + "=" * 70)

In [None]:
# Create comparison table
import pandas as pd

if results:
    # Create DataFrame
    df = pd.DataFrame(results)
    
    # Reorder columns
    cols = ['adv_type', 'train_size', 'test_size', 'cosine', 'top1', 'rank', 'vsp']
    df = df[[c for c in cols if c in df.columns]]
    
    # Rename for display
    df.columns = ['Adv Type', 'Train Size', 'Test Size', 'Cosine', 'Top-1', 'Rank', 'VSP']
    
    print("\n" + "=" * 70)
    print("COMPARISON RESULTS")
    print("=" * 70)
    print("\nModel pair: stella -> gte (unsupervised -> supervised)")
    print("Dataset: Natural Questions (NQ)")
    print("\n")
    
    # Format the table
    pd.set_option('display.float_format', lambda x: '%.4f' % x if abs(x) < 100 else '%.1f' % x)
    print(df.to_string(index=False))
    
    print("\n")
    print("Metrics explanation:")
    print("  Cosine: Mean cosine similarity (higher is better)")
    print("  Top-1:  Nearest neighbor accuracy (higher is better)")
    print("  Rank:   Mean rank of correct target (lower is better)")
    print("  VSP:    Vector space preservation error (lower is better)")
    print("\n" + "=" * 70)
    
    # Save results
    results_path = '/content/vec2vec/outputs/comparison_results.csv'
    df.to_csv(results_path, index=False)
    print(f"\nResults saved to: {results_path}")
else:
    print("\nNo results to compare. Please run both training cells first.")

In [None]:
# Create comparison plots
import matplotlib.pyplot as plt
import numpy as np

if results and len(results) >= 2:
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    
    # Extract data
    labels = [r['adv_type'] for r in results]
    colors = ['#2ecc71', '#3498db']  # Green for GAN, Blue for WGAN
    
    # Plot 1: Cosine Similarity
    if 'cosine' in results[0]:
        values = [r.get('cosine', 0) for r in results]
        bars = axes[0].bar(labels, values, color=colors)
        axes[0].set_ylabel('Cosine Similarity')
        axes[0].set_title('Cosine Similarity (higher is better)')
        axes[0].set_ylim([min(values) * 0.95, max(values) * 1.02])
        for bar, val in zip(bars, values):
            axes[0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.001,
                        f'{val:.4f}', ha='center', va='bottom', fontsize=10)
    
    # Plot 2: Top-1 Accuracy
    if 'top1' in results[0]:
        values = [r.get('top1', 0) for r in results]
        bars = axes[1].bar(labels, values, color=colors)
        axes[1].set_ylabel('Top-1 Accuracy')
        axes[1].set_title('Top-1 Accuracy (higher is better)')
        axes[1].set_ylim([0, max(values) * 1.2])
        for bar, val in zip(bars, values):
            axes[1].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                        f'{val:.4f}', ha='center', va='bottom', fontsize=10)
    
    # Plot 3: Mean Rank
    if 'rank' in results[0]:
        values = [r.get('rank', 0) for r in results]
        bars = axes[2].bar(labels, values, color=colors)
        axes[2].set_ylabel('Mean Rank')
        axes[2].set_title('Mean Rank (lower is better)')
        axes[2].set_ylim([0, max(values) * 1.3])
        for bar, val in zip(bars, values):
            axes[2].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
                        f'{val:.1f}', ha='center', va='bottom', fontsize=10)
    
    # Plot 4: VSP Error
    if 'vsp' in results[0]:
        values = [r.get('vsp', 0) for r in results]
        bars = axes[3].bar(labels, values, color=colors)
        axes[3].set_ylabel('VSP Error')
        axes[3].set_title('VSP Error (lower is better)')
        axes[3].set_ylim([0, max(values) * 1.3])
        for bar, val in zip(bars, values):
            axes[3].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.001,
                        f'{val:.4f}', ha='center', va='bottom', fontsize=10)
    
    plt.tight_layout()
    plt.savefig('/content/vec2vec/outputs/comparison_plots.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    print("\nPlots saved to: /content/vec2vec/outputs/comparison_plots.png")
else:
    print("Need results from both runs to create comparison plots.")

## Summary & Discussion

### What This Notebook Reproduces

This notebook trains and compares two vec2vec translators:

1. **Baseline (GAN)**: Original vec2vec with Least Squares GAN loss
2. **WGAN-GP**: Vec2vec with Wasserstein GAN + Gradient Penalty

Both use:
- **50k training examples** (25k per encoder)
- **Stella -> GTE** model pair (unsupervised -> supervised)
- **ResNet MLP architecture** with adapters
- Identical reconstruction, VSP, and cross-chain losses

Only the adversarial component differs:
- GAN: `L_D = 0.5 * (D(real)^2 + (D(fake)-1)^2)`, `L_G = 0.5 * D(fake)^2`
- WGAN: `L_C = E[D(fake)] - E[D(real)] + 10 * GP`, `L_G = -E[D(fake)]`

---

### Interpreting Results

Key questions to consider:

1. **Did WGAN-GP improve cosine similarity?**
   - Higher cosine = better alignment with target embeddings

2. **Did Top-1 accuracy change?**
   - Measures exact retrieval performance

3. **Did mean rank improve?**
   - Lower rank = better overall ranking quality

4. **Was training more stable?**
   - WGAN-GP typically provides smoother gradients

---

### Tuning WGAN-GP

If WGAN-GP underperforms, try:

```python
# Adjust gradient penalty coefficient
--gp_lambda 1     # Less regularization
--gp_lambda 100   # More regularization

# Adjust critic learning rate
--disc_lr 5e-5    # Higher for WGAN

# Multiple critic steps per generator step (not implemented here)
# Would require modifying training loop
```

---

### Scaling Up

To increase training scale:

```python
# More training data
--num_points 100000

# More epochs
--epochs 30

# Larger batch size (if GPU memory allows)
--bs 256
```

---

### Changing Configurations

**Different model pairs:**
```python
--unsup_emb gte --sup_emb gtr    # GTE to GTR
--unsup_emb e5 --sup_emb gte     # E5 to GTE
```

**Different datasets:**
```python
--dataset fineweb
--dataset msmarco-corpus
```

**Toggle adversarial modes:**
```python
--gan_style vanilla         # Standard BCE GAN
--gan_style least_squares   # LSGAN (default)
--gan_style relativistic    # Relativistic GAN
--gan_style wgan            # WGAN-GP
```

---

### Citation

```bibtex
@misc{jha2025harnessinguniversalgeometryembeddings,
      title={Harnessing the Universal Geometry of Embeddings}, 
      author={Rishi Jha and Collin Zhang and Vitaly Shmatikov and John X. Morris},
      year={2025},
      eprint={2505.12540},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2505.12540}, 
}
```