# Train LeJEPA Isotropic Gaussian Embeddings

This notebook trains EmbeddingGemma-300M with LeJEPA loss to produce isotropic Gaussian embeddings for improved RAG retrieval.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ctn/ragcun/blob/main/notebooks/lejepa_training.ipynb)

## 1. Setup GPU and Install Packages

In [1]:
# Check GPU
!nvidia-smi

import torch
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}")

Thu Nov 13 15:55:08 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   54C    P8             10W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
# Install dependencies
!pip install -q transformers>=4.45.0 sentence-transformers>=3.0.0 datasets
!pip install -q faiss-gpu accelerate
!pip install -q lejepa || pip install -q git+https://github.com/rbalestr-lab/lejepa.git

print("✅ All packages installed!")

[31mERROR: Could not find a version that satisfies the requirement faiss-gpu (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for faiss-gpu[0m[31m
[0m[31mERROR: Could not find a version that satisfies the requirement lejepa (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for lejepa[0m[31m
[0m  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.6/61.6 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for lejepa (pyproject.toml) ... [?25l[?25hdone
✅ All packages installed!


## 2. Imports

In [3]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from sentence_transformers import SentenceTransformer
import lejepa
import numpy as np
from tqdm.auto import tqdm
import json
from pathlib import Path

print("✅ Imports successful!")
print(f"PyTorch: {torch.__version__}")
print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")

✅ Imports successful!
PyTorch: 2.8.0+cu126
Device: cuda


## 3. Model: Gaussian EmbeddingGemma

This model wraps EmbeddingGemma-300M and adds a projection layer to produce **unnormalized isotropic Gaussian embeddings**.

In [4]:
class GaussianEmbeddingGemma(nn.Module):
    """
    EmbeddingGemma with LeJEPA-trained projection to isotropic Gaussian space.

    Key features:
    - Starts with EmbeddingGemma-300M (state-of-the-art for RAG)
    - Projects to unnormalized Gaussian space (NO L2 normalization)
    - Trained with LeJEPA SIGReg loss for isotropy
    - Uses Euclidean distance for retrieval (not cosine similarity)
    """

    def __init__(self, output_dim=512, freeze_early_layers=True):
        super().__init__()

        print("Loading EmbeddingGemma-300M...")
        self.base = SentenceTransformer(
            'google/embeddinggemma-300m',
            trust_remote_code=True
        )

        # Projection: 768 (normalized) → output_dim (Gaussian)
        # NO normalization layers!
        self.projection = nn.Sequential(
            nn.Linear(768, 768 * 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(768 * 2, output_dim)
        )

        # Make base trainable
        for param in self.base.parameters():
            param.requires_grad = True

        # Optionally freeze early layers
        if freeze_early_layers:
            frozen = 0
            for name, param in self.base.named_parameters():
                if any(f'encoder.layer.{i}.' in name for i in range(4)):
                    param.requires_grad = False
                    frozen += 1
            print(f"Froze {frozen} parameters in early layers")

        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Total params: {total:,}")
        print(f"Trainable: {trainable:,} ({100*trainable/total:.1f}%)")

        self.output_dim = output_dim

    def encode(self, texts, batch_size=32, show_progress=False):
        """Encode texts to Gaussian embeddings (NOT normalized)"""
        base_emb = self.base.encode(
            texts,
            batch_size=batch_size,
            convert_to_tensor=True,
            show_progress_bar=show_progress,
            normalize_embeddings=True  # Base does L2 norm
        )

        # Project to Gaussian space (undoes normalization)
        gaussian_emb = self.projection(base_emb)
        return gaussian_emb

    def forward(self, texts):
        return self.encode(texts, show_progress=False)

print("✅ Model class defined!")

✅ Model class defined!


## 4. Load Training Data

In [None]:
from datasets import load_dataset
from tqdm.auto import tqdm # Import tqdm for progress bar

# Load MS MARCO dataset (query-positive-negative triplets)
print("Loading training data...")

# --- NEW: Load MS MARCO queries and passages for text mapping ---
print("Loading MS MARCO v1.1 train dataset for queries. This may take a moment and consume memory...")
msmarco_v11_train_dataset = load_dataset("ms_marco", "v1.1", split='train')

query_id_to_text = {}

print("Populating query ID to text mapping...")
# Iterate over the msmarco_v11_train_dataset to populate query map
for example in tqdm(msmarco_v11_train_dataset, desc="Processing MS MARCO v1.1 queries"):
    query_id_to_text[str(example['query_id'])] = example['query']

print(f"Loaded {len(query_id_to_text)} MS MARCO queries.")

print("Loading full MS MARCO passage collection. This may take a moment and consume significant memory...")
# Load the dedicated MS MARCO passage collection, e.g., from Tevatron
# This dataset contains all passages and their IDs.
passage_collection_dataset = load_dataset('Tevatron/msmarco-passage', 'v1', split='train')
passage_id_to_text = {str(example['pid']): example['text'] for example in tqdm(passage_collection_dataset, desc="Populating passage map")}
print(f"Loaded {len(passage_id_to_text)} MS MARCO passages.")
# --- END NEW ---

# Load MS MARCO hard negatives (query-positive-negative triplets of IDs)
# This dataset is loaded in streaming mode as we only take a subset
hard_negatives_dataset = load_dataset(
    'sentence-transformers/msmarco-hard-negatives',
    'default', # Changed 'triplet' to 'default' in previous turn
    split='train',
    streaming=True
)

# Take 5000 samples for quick training
num_samples = 5000
data = []
collected_triplets = 0

# --- MODIFIED: Use mappings to get actual texts ---
for i, example in enumerate(hard_negatives_dataset):
    if collected_triplets >= num_samples:
        break

    # Get IDs from the hard_negatives_dataset (based on kernel state example)
    query_id = str(example['qid'])
    positive_id = str(example['pos'][0]) # Assuming 'pos' is a list and we take the first

    # Assuming 'neg' is a dict and we take the first from 'bm25' key
    negative_id = None
    if 'bm25' in example['neg'] and example['neg']['bm25']:
        negative_id = str(example['neg']['bm25'][0])

    # Lookup texts using the maps
    query_text = query_id_to_text.get(query_id)
    positive_text = passage_id_to_text.get(positive_id)
    negative_text = passage_id_to_text.get(negative_id)

    if query_text and positive_text and negative_text: # Only add if all texts are found
        data.append({
            'query': query_text,
            'positive': positive_text,
            'negative': negative_text
        })
        collected_triplets += 1

    if (i + 1) % 1000 == 0:
        print(f"  Processed {i + 1} hard negative examples from stream, collected {collected_triplets} full triplets...")

# Train/val split
train_size = int(0.9 * len(data))
train_data = data[:train_size]
val_data = data[train_size:]

print(f"✅ Train: {len(train_data)}, Val: {len(val_data)}")

Loading training data...
Loading MS MARCO v1.1 train dataset. This may take a moment and consume memory...
Populating query and passage ID to text mappings...


Processing MS MARCO v1.1 train data:   0%|          | 0/82326 [00:00<?, ?it/s]

Loaded 82326 MS MARCO queries.
Loaded 0 MS MARCO passages.


Repo card metadata block was not found. Setting CardData to empty.


  Processed 1000 hard negative examples from stream, collected 0 full triplets...
  Processed 2000 hard negative examples from stream, collected 0 full triplets...
  Processed 3000 hard negative examples from stream, collected 0 full triplets...
  Processed 4000 hard negative examples from stream, collected 0 full triplets...
  Processed 5000 hard negative examples from stream, collected 0 full triplets...
  Processed 6000 hard negative examples from stream, collected 0 full triplets...
  Processed 7000 hard negative examples from stream, collected 0 full triplets...
  Processed 8000 hard negative examples from stream, collected 0 full triplets...
  Processed 9000 hard negative examples from stream, collected 0 full triplets...
  Processed 10000 hard negative examples from stream, collected 0 full triplets...
  Processed 11000 hard negative examples from stream, collected 0 full triplets...
  Processed 12000 hard negative examples from stream, collected 0 full triplets...
  Processed 1

## 5. Create DataLoaders

In [1]:
class TripletDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def collate_fn(batch):
    return {
        'queries': [item['query'] for item in batch],
        'positives': [item['positive'] for item in batch],
        'negatives': [item['negative'] for item in batch]
    }

batch_size = 16  # Adjust based on GPU (T4: 8-16, A100: 32-64)

train_loader = DataLoader(
    TripletDataset(train_data),
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn
)

val_loader = DataLoader(
    TripletDataset(val_data),
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn
)

print(f"✅ DataLoaders: {len(train_loader)} train, {len(val_loader)} val batches")

NameError: name 'Dataset' is not defined

## 6. Initialize Model and LeJEPA Loss

In [None]:
# Model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GaussianEmbeddingGemma(output_dim=512).to(device)

# LeJEPA loss (SIGReg)
print("\nInitializing LeJEPA SIGReg...")
sigreg = lejepa.multivariate.SlicingUnivariateTest(
    univariate_test=lejepa.univariate.EppsPulley(num_points=17),
    num_slices=1024
).to(device)

# Optimizer
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-5,
    weight_decay=0.05
)

# Scheduler
num_epochs = 5
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=num_epochs * len(train_loader)
)

# Loss weights (from LeJEPA paper)
lambda_contrastive = 1.0
lambda_isotropy = 0.01

print("✅ Training setup complete!")

## 7. Training Utilities

In [None]:
def check_isotropy(embeddings, verbose=True):
    """Verify embeddings are isotropic Gaussian N(0, I)"""
    embeddings = embeddings.detach().cpu()

    mean = embeddings.mean(dim=0)
    centered = embeddings - mean
    cov = (centered.T @ centered) / (embeddings.shape[0] - 1)

    mean_norm = torch.norm(mean).item()
    cov_error = torch.norm(cov - torch.eye(cov.shape[0]), p='fro').item()
    diag_mean = torch.diag(cov).mean().item()

    off_diag = cov.clone()
    off_diag.fill_diagonal_(0)
    off_diag_mean = off_diag.abs().mean().item()

    if verbose:
        print(f"  Mean norm: {mean_norm:.4f} (want ~0)")
        print(f"  Cov error: {cov_error:.4f} (want <5)")
        print(f"  Diag mean: {diag_mean:.4f} (want ~1)")
        print(f"  Off-diag: {off_diag_mean:.4f} (want ~0)")

    return {
        'mean_norm': mean_norm,
        'cov_error': cov_error,
        'is_isotropic': mean_norm < 0.5 and cov_error < 10.0
    }

def save_checkpoint(model, optimizer, epoch, path):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, path)
    print(f"✅ Saved: {path}")

print("✅ Utilities defined!")

## 8. Training Loop

In [None]:
def train_epoch(model, loader, optimizer, scheduler, epoch):
    model.train()
    total_loss = 0
    total_contrastive = 0
    total_isotropy = 0

    pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for batch in pbar:
        # Encode triplets
        q_emb = model(batch['queries'])
        pos_emb = model(batch['positives'])
        neg_emb = model(batch['negatives'])

        # Euclidean contrastive loss
        pos_dist = torch.norm(q_emb - pos_emb, dim=1)
        neg_dist = torch.norm(q_emb - neg_emb, dim=1)
        contrastive_loss = torch.relu(pos_dist - neg_dist + 1.0).mean()

        # LeJEPA isotropy loss
        all_emb = torch.cat([q_emb, pos_emb, neg_emb], dim=0)
        isotropy_loss = sigreg(all_emb)

        # Combined
        loss = lambda_contrastive * contrastive_loss + lambda_isotropy * isotropy_loss

        # Backward
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()

        total_loss += loss.item()
        total_contrastive += contrastive_loss.item()
        total_isotropy += isotropy_loss.item()

        pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'contr': f'{contrastive_loss.item():.4f}',
            'iso': f'{isotropy_loss.item():.4f}'
        })

    return {
        'loss': total_loss / len(loader),
        'contrastive': total_contrastive / len(loader),
        'isotropy': total_isotropy / len(loader)
    }

def validate(model, loader):
    model.eval()
    total_loss = 0
    all_embeddings = []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Validating"):
            q_emb = model(batch['queries'])
            pos_emb = model(batch['positives'])
            neg_emb = model(batch['negatives'])

            pos_dist = torch.norm(q_emb - pos_emb, dim=1)
            neg_dist = torch.norm(q_emb - neg_emb, dim=1)
            contrastive_loss = torch.relu(pos_dist - neg_dist + 1.0).mean()

            all_emb = torch.cat([q_emb, pos_emb, neg_emb], dim=0)
            isotropy_loss = sigreg(all_emb)

            loss = lambda_contrastive * contrastive_loss + lambda_isotropy * isotropy_loss
            total_loss += loss.item()
            all_embeddings.append(all_emb)

    all_embeddings = torch.cat(all_embeddings, dim=0)
    print("\nIsotropy check:")
    isotropy_metrics = check_isotropy(all_embeddings)

    return {
        'loss': total_loss / len(loader),
        'isotropy_metrics': isotropy_metrics
    }

print("✅ Training functions defined!")

## 9. Run Training

In [None]:
Path('checkpoints').mkdir(exist_ok=True)
best_val_loss = float('inf')

print("="*60)
print("Starting Training!")
print("="*60)

for epoch in range(num_epochs):
    print(f"\n{'='*60}")
    print(f"Epoch {epoch + 1}/{num_epochs}")
    print('='*60)

    # Train
    train_metrics = train_epoch(model, train_loader, optimizer, scheduler, epoch)
    print(f"\nTrain Loss: {train_metrics['loss']:.4f}")

    # Validate
    val_metrics = validate(model, val_loader)
    print(f"Val Loss: {val_metrics['loss']:.4f}")
    print(f"Isotropic: {val_metrics['isotropy_metrics']['is_isotropic']}")

    # Save best
    if val_metrics['loss'] < best_val_loss:
        best_val_loss = val_metrics['loss']
        save_checkpoint(model, optimizer, epoch, 'checkpoints/best_model.pt')
        print("  ⭐ New best model!")

print("\n" + "="*60)
print("✅ Training complete!")
print("="*60)

## 10. Test Inference

In [None]:
# Load best model
checkpoint = torch.load('checkpoints/best_model.pt')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Test
test_texts = [
    "What is machine learning?",
    "Machine learning is a branch of AI",
    "How to cook pasta",
    "Python programming tutorial"
]

with torch.no_grad():
    embeddings = model(test_texts)
    print(f"\nShape: {embeddings.shape}")
    print(f"Mean norm: {embeddings.norm(dim=1).mean():.4f}")
    print(f"Std norm: {embeddings.norm(dim=1).std():.4f}")

check_isotropy(embeddings)

## 11. Save for RAGCUN

Save the trained model to use in the RAGCUN retriever.

In [None]:
# Save final model
final_path = 'gaussian_embeddinggemma_final.pt'
torch.save(model.state_dict(), final_path)
print(f"✅ Model saved: {final_path}")

# Download to local machine
from google.colab import files
files.download(final_path)

print("\n📥 Download complete!")
print("Next step: Add this model to /Users/ctn/src/ctn/ragcun/data/embeddings/")