# Evaluate Isotropy: Original vs Fine-Tuned

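Compare the embedding distributions produced by two models:
- **Original EmbeddingGemma**: L2-normalized outputs (every embedding lies on the unit hypersphere)
- **LeJEPA fine-tuned**: trained toward an isotropic Gaussian, $\mathcal{N}(0, I)$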

In [None]:
# Numerical / deep-learning stack
import torch
import numpy as np

# Plotting
import matplotlib.pyplot as plt

# Embedding models: the stock SentenceTransformer wrapper and the
# project-local encoder used for the fine-tuned model.
from sentence_transformers import SentenceTransformer
from ragcun import IsotropicGaussianEncoder

print("✅ Imports successful")

## 1. Load Models

In [None]:
# Model identifiers used by this notebook.
ORIGINAL_MODEL_ID = 'google/embeddinggemma-300m'
GAUSSIAN_CHECKPOINT = 'data/embeddings/gaussian_embeddinggemma_final.pt'

# Baseline: stock EmbeddingGemma (normalized outputs, per the notebook intro).
print("Loading original EmbeddingGemma...")
original_model = SentenceTransformer(ORIGINAL_MODEL_ID, trust_remote_code=True)

# Fine-tuned LeJEPA encoder (Gaussian), restored from a local checkpoint file.
print("Loading fine-tuned model...")
gaussian_model = IsotropicGaussianEncoder.from_pretrained(GAUSSIAN_CHECKPOINT)

print("✅ Models loaded")

## 2. Generate Test Embeddings

In [None]:
# Draw a small, diverse sample of texts from Wikipedia. Streaming mode means
# articles are read lazily; only the first N_SAMPLES are consumed.
# NOTE(review): the legacy 'wikipedia' script dataset may not load on recent
# `datasets` releases — confirm, or switch to e.g. 'wikimedia/wikipedia'.
from datasets import load_dataset

print("Loading test data...")
dataset = load_dataset('wikipedia', '20220301.en', split='train', streaming=True)

N_SAMPLES = 1000  # number of articles to embed
MAX_CHARS = 200   # keep only the opening characters of each article

# zip() stops as soon as range() is exhausted, so the stream is never
# advanced past the N_SAMPLES-th article.
test_texts = [
    article['text'][:MAX_CHARS]
    for _, article in zip(range(N_SAMPLES), dataset)
]

print(f"✅ Loaded {len(test_texts)} test texts")

In [None]:
# Embed the same texts with both models so their distributions are comparable.
print("Encoding with original model...")
original_emb = original_model.encode(
    test_texts, show_progress_bar=True, convert_to_numpy=True
)

print("Encoding with Gaussian model...")
# Inference only: disable autograd bookkeeping while encoding.
with torch.no_grad():
    gaussian_emb = gaussian_model.encode(
        test_texts, show_progress=True, convert_to_numpy=True
    )

# One print with embedded newlines yields the same stdout as two prints.
shape_report = (
    f"\nOriginal shape: {original_emb.shape}"
    f"\nGaussian shape: {gaussian_emb.shape}"
)
print(shape_report)

## 3. Statistical Analysis

In [None]:
def analyze_distribution(embeddings, name, *, mean_tol=0.1, scale_tol=0.25,
                         isotropy_tol=None):
    """Print and return distribution statistics for an embedding matrix.

    The statistics are compared against an isotropic standard Gaussian
    N(0, I). They cover per-dimension means and stds, per-row L2 norms,
    and the unbiased sample covariance.

    Args:
        embeddings: Array-like of shape (n_samples, dim). It is converted to
            a float64 NumPy array before any statistics are computed.
        name: Label printed in the section header.
        mean_tol: Upper bound on the average |per-dimension mean|.
        scale_tol: Upper bound on |mean(diag(cov)) - 1|, i.e. how far the
            average per-dimension variance may be from 1.
        isotropy_tol: Upper bound on ||cov - I||_F. ``None`` (default) uses
            1.5x the error expected from sampling noise alone for an exact
            N(0, I) sample of the same shape (see the comment in the body).

    Returns:
        dict with keys ``'mean'``, ``'std'``, ``'norms'``, ``'cov'``,
        ``'is_isotropic'`` (bool), ``'isotropy_error'`` and
        ``'isotropy_tol'``.

    Raises:
        ValueError: If ``embeddings`` is not 2-D or has fewer than 2 rows
            (the unbiased covariance divides by ``n_samples - 1``).
    """
    embeddings = np.asarray(embeddings, dtype=np.float64)
    if embeddings.ndim != 2 or embeddings.shape[0] < 2:
        raise ValueError(
            "expected a 2-D array with at least 2 rows, "
            f"got shape {embeddings.shape}"
        )
    n_samples, dim = embeddings.shape

    print(f"\n{'='*60}")
    print(f"{name}")
    print('='*60)

    # Per-dimension mean and std
    mean = embeddings.mean(axis=0)
    std = embeddings.std(axis=0)

    print("\n1. Mean Statistics:")
    print(f"   Mean of means: {mean.mean():.6f} (want ~0 for Gaussian)")
    print(f"   Std of means:  {mean.std():.6f}")
    print(f"   Max |mean|:    {np.abs(mean).max():.6f}")

    print("\n2. Std Statistics:")
    print(f"   Mean of stds:  {std.mean():.6f} (want ~1 for Gaussian)")
    print(f"   Std of stds:   {std.std():.6f} (want ~0 for isotropy)")

    # Per-row L2 norms
    norms = np.linalg.norm(embeddings, axis=1)
    print("\n3. Norm Statistics:")
    print(f"   Mean norm:     {norms.mean():.6f}")
    print(f"   Std norm:      {norms.std():.6f}")
    print(f"   Min norm:      {norms.min():.6f}")
    print(f"   Max norm:      {norms.max():.6f}")

    # Unbiased sample covariance
    centered = embeddings - mean
    cov = (centered.T @ centered) / (n_samples - 1)

    diag = np.diag(cov)
    # Mask out the diagonal so off-diagonal statistics are taken over the
    # dim*(dim-1) true off-diagonal entries, not diluted by dim zeros.
    off_diag = np.abs(cov[~np.eye(dim, dtype=bool)])
    off_diag_mean = off_diag.mean() if off_diag.size else 0.0
    off_diag_max = off_diag.max() if off_diag.size else 0.0

    print("\n4. Covariance Statistics:")
    print(f"   Mean diagonal:     {diag.mean():.6f} (want ~1)")
    print(f"   Std diagonal:      {diag.std():.6f} (want ~0)")
    print(f"   Mean |off-diag|:   {off_diag_mean:.6f} (want ~0)")
    print(f"   Max |off-diag|:    {off_diag_max:.6f}")

    # Isotropy score. For n i.i.d. N(0, I) samples the unbiased covariance S
    # satisfies E||S - I||_F^2 = dim * (dim + 1) / (n - 1). Even a perfect
    # Gaussian therefore has a non-zero error that grows with dim; e.g. about
    # 24 for dim=768 and n=1000. A fixed cut-off like 10 could never pass.
    isotropy_error = np.linalg.norm(cov - np.eye(dim), ord='fro')
    noise_floor = np.sqrt(dim * (dim + 1) / (n_samples - 1))
    if isotropy_tol is None:
        isotropy_tol = 1.5 * noise_floor
    print(
        f"\n5. Isotropy Error: {isotropy_error:.4f} "
        f"(N(0,I) sampling noise ~{noise_floor:.4f}, want <{isotropy_tol:.4f})"
    )

    # The scale check is needed in addition to the Frobenius test. For
    # L2-normalized embeddings cov is ~0, so ||cov - I||_F is ~sqrt(dim),
    # which is comparable to the sampling-noise floor at these sample sizes.
    is_isotropic = bool(
        np.abs(mean).mean() < mean_tol
        and abs(diag.mean() - 1.0) < scale_tol
        and isotropy_error < isotropy_tol
    )

    print(f"\n✅ Is Isotropic: {is_isotropic}")

    return {
        'mean': mean,
        'std': std,
        'norms': norms,
        'cov': cov,
        'is_isotropic': is_isotropic,
        'isotropy_error': isotropy_error,
        'isotropy_tol': isotropy_tol,
    }

# Run the same analysis on each model's embeddings (original first).
original_stats, gaussian_stats = (
    analyze_distribution(original_emb, "Original EmbeddingGemma"),
    analyze_distribution(gaussian_emb, "LeJEPA Fine-tuned"),
)

## 4. Visualizations

In [None]:
from pathlib import Path

figure_path = Path('data/processed/isotropy_comparison.png')


def _plot_distribution_row(row_axes, stats, embeddings, label, color, *,
                           show_normal=False):
    """Draw one model's row of panels onto three axes.

    The panels are: a histogram of row norms, a histogram of the first
    embedding dimension, and the top-left 50x50 crop of the covariance.
    When ``show_normal`` is True, the first-dimension histogram is
    density-normalized and overlaid with the standard normal pdf.
    """
    norm_ax, dim_ax, cov_ax = row_axes

    norms = stats['norms']
    norm_ax.hist(norms, bins=50, alpha=0.7, color=color)
    norm_ax.axvline(norms.mean(), color='red', linestyle='--',
                    label=f"Mean={norms.mean():.3f}")
    norm_ax.set_title(f'{label}: Norm Distribution')
    norm_ax.legend()

    first_dim = embeddings[:, 0]
    # density=True is required for the histogram to be comparable to a pdf.
    dim_ax.hist(first_dim, bins=50, alpha=0.7, color=color, density=show_normal)
    if show_normal:
        xs = np.linspace(min(first_dim.min(), -4.0), max(first_dim.max(), 4.0), 200)
        dim_ax.plot(xs, np.exp(-0.5 * xs**2) / np.sqrt(2 * np.pi),
                    color='red', linestyle='--', label='N(0,1) pdf')
        dim_ax.legend()
    dim_ax.set_title(f'{label}: First Dimension')

    # Fixed [-1, 1] color range keeps both models' covariances on one scale.
    image = cov_ax.imshow(stats['cov'][:50, :50], cmap='RdBu', vmin=-1, vmax=1)
    cov_ax.figure.colorbar(image, ax=cov_ax)
    cov_ax.set_title(f'{label}: Covariance (50x50)')


fig, axes = plt.subplots(2, 3, figsize=(18, 10))
fig.suptitle('Original vs LeJEPA: Distribution Comparison', fontsize=16)

# Row 1: Original; Row 2: Gaussian (with N(0,1) reference overlay)
_plot_distribution_row(axes[0], original_stats, original_emb, 'Original', 'blue')
_plot_distribution_row(axes[1], gaussian_stats, gaussian_emb, 'LeJEPA', 'green',
                       show_normal=True)

plt.tight_layout()
# savefig does not create missing directories.
figure_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(figure_path, dpi=150)
plt.show()

print("\n✅ Visualization saved to data/processed/isotropy_comparison.png")

## 5. Summary Comparison

In [None]:
import pandas as pd
from pathlib import Path

csv_path = Path('data/processed/isotropy_comparison.csv')


def _summary_column(stats):
    """Format the comparison metrics for one analyze_distribution() result.

    Values are returned in the same order as the 'Metric' column below.
    """
    cov = stats['cov']
    # Exclude the diagonal so the mean runs over the d*(d-1) true
    # off-diagonal entries instead of being diluted by d zeros.
    off_diag = np.abs(cov[~np.eye(cov.shape[0], dtype=bool)])
    off_diag_mean = off_diag.mean() if off_diag.size else 0.0
    return [
        f"{stats['mean'].mean():.6f}",
        f"{stats['std'].mean():.6f}",
        f"{stats['norms'].mean():.6f}",
        f"{stats['norms'].std():.6f}",
        f"{np.diag(cov).mean():.6f}",
        f"{off_diag_mean:.6f}",
        f"{stats['is_isotropic']}",
    ]


# For x ~ N(0, I_d), ||x|| follows a chi distribution with d degrees of
# freedom. Its mean is approximately sqrt(d - 1/2) and its std approximately
# sqrt(1/2) ~= 0.71.
dim = gaussian_stats['cov'].shape[0]

comparison = pd.DataFrame({
    'Metric': [
        'Mean of Means',
        'Mean of Stds',
        'Mean Norm',
        'Std of Norms',
        'Mean Diagonal (Cov)',
        'Mean |Off-Diagonal|',
        'Is Isotropic?',
    ],
    'Original': _summary_column(original_stats),
    'LeJEPA': _summary_column(gaussian_stats),
    'Target (Gaussian)': [
        '~0.0',
        '~1.0',
        f"~{np.sqrt(dim - 0.5):.1f} (~sqrt(d))",
        '~0.71',
        '~1.0',
        '~0.0',
        'True',
    ],
})

print("\n" + "="*80)
print("ISOTROPY COMPARISON SUMMARY")
print("="*80)
print(comparison.to_string(index=False))
print("="*80)

# Save; to_csv does not create missing directories.
csv_path.parent.mkdir(parents=True, exist_ok=True)
comparison.to_csv(csv_path, index=False)
print("\n✅ Saved to data/processed/isotropy_comparison.csv")