# RAG Evaluation: Original vs LeJEPA Fine-tuned

Compare retrieval performance:
- **Original**: Cosine similarity on normalized embeddings
- **LeJEPA**: Euclidean distance on isotropic Gaussian embeddings

In [None]:
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from ragcun import GaussianRetriever
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt

print("✅ Imports successful")

## 1. Setup Models

In [None]:
# Baseline: the unmodified embedding model, scored later with cosine similarity.
# NOTE: trust_remote_code=True allows code shipped with the model repo to run locally.
print("Loading original model...")
original_model = SentenceTransformer(
    'google/embeddinggemma-300m',
    trust_remote_code=True,
)

# Candidate: the LeJEPA fine-tuned checkpoint, wrapped in a retriever whose
# results are treated as (document, distance) pairs further down.
print("Loading LeJEPA model...")
lejepa_retriever = GaussianRetriever(model_path='data/embeddings/gaussian_embeddinggemma_final.pt')

print("✅ Models loaded")

## 2. Create Test Dataset

We'll use a labeled QA dataset to measure retrieval accuracy.

In [None]:
from datasets import load_dataset

# Stream (query, positive, negative) triplets from the MS MARCO hard-negatives
# dataset. Note that this reads the 'train' split of the 'triplet' config.
print("Loading test data...")
dataset = load_dataset(
    'sentence-transformers/msmarco-hard-negatives',
    'triplet',
    split='train',
    streaming=True,
)

# Keep the first 500 triplets, copying only the three text fields we need.
test_data = []
for position, triplet in enumerate(dataset):
    if position >= 500:
        break
    test_data.append({field: triplet[field] for field in ('query', 'positive', 'negative')})

print(f"✅ Loaded {len(test_data)} test examples")

## 3. Build Document Collections

In [None]:
# Flatten the triplets into one corpus laid out as
# [pos_0, neg_0, pos_1, neg_1, ...]: the positive of query i sits at index 2*i.
all_docs = []
for triplet in test_data:
    all_docs.extend((triplet['positive'], triplet['negative']))

# Query position -> index of its relevant (positive) document in all_docs.
query_to_correct_doc_idx = {query_idx: 2 * query_idx for query_idx in range(len(test_data))}

print(f"Total documents: {len(all_docs)}")
print(f"Queries: {len(test_data)}")

# Hand the same corpus to the LeJEPA retriever so both systems search identical documents.
print("\nIndexing with LeJEPA...")
lejepa_retriever.add_documents(all_docs, batch_size=32)

In [None]:
# Embed the whole corpus with the baseline model. Embeddings are requested
# normalized because the baseline is scored with cosine similarity.
print("Encoding with original model...")
encode_kwargs = dict(
    batch_size=32,
    show_progress_bar=True,
    normalize_embeddings=True,
)
original_doc_embs = original_model.encode(all_docs, **encode_kwargs)

print(f"✅ Original embeddings shape: {original_doc_embs.shape}")

## 4. Evaluate Retrieval Performance

In [None]:
def evaluate_retrieval(model, queries, doc_embs, query_to_correct, method='cosine', top_k=10):
    """
    Evaluate retrieval quality over a list of queries.

    Two scoring paths exist:

    * ``method == 'cosine'``: queries are encoded with ``model.encode`` and
      every row of ``doc_embs`` is ranked by cosine similarity.
    * any other ``method``: ``model`` and ``doc_embs`` are ignored. The
      module-level ``lejepa_retriever`` is asked for its ``top_k`` results
      (iterated as ``(document, distance)`` pairs) and the relevant document
      is looked up by text via the module-level ``all_docs``.

    Args:
        model: Object exposing ``encode(...)``; used on the cosine path only.
        queries: Sequence of query strings.
        doc_embs: Document embeddings, one row per document index; used on
            the cosine path only.
        query_to_correct: Mapping from query position to the index of its
            relevant document.
        method: ``'cosine'`` for the similarity path; anything else selects
            the retriever path.
        top_k: Cut-off used for Recall@k and MRR, and the number of results
            requested from the retriever.

    Returns:
        dict with keys ``'recall@k'``, ``'mrr'``, ``'avg_pos_score'``,
        ``'avg_neg_score'`` and ``'separation'`` (avg pos minus avg neg).
        A metric with no contributing samples is ``nan``.

    Notes:
        The score columns are not directly comparable between the two paths.
        On the cosine path the negative score averages over *all* other
        documents; on the retriever path it averages the negated distances
        of the other *retrieved* documents only, and a relevant document that
        was not retrieved contributes a fixed penalty of -100 as its
        positive score.

        The retriever path matches the relevant document by text, so a
        passage that occurs more than once in ``all_docs`` still counts as a
        hit; the cosine path matches by index.
    """
    # Stand-in positive score when the relevant document is absent from the
    # retriever results (its real distance is unknown at that point).
    miss_penalty = -100

    def mean_or_nan(values):
        # np.mean([]) emits a RuntimeWarning; return nan quietly instead.
        return np.mean(values) if values else float('nan')

    recalls = []
    reciprocal_ranks = []
    pos_scores = []
    neg_scores = []

    use_cosine = method == 'cosine'
    if use_cosine:
        query_embs = model.encode(queries, show_progress_bar=True, normalize_embeddings=True)

    for i, query in enumerate(queries):
        correct_idx = query_to_correct[i]
        hit_rank = None  # 1-based rank of the relevant doc inside the top_k, if present

        if use_cosine:
            similarities = cosine_similarity(query_embs[i:i + 1], doc_embs)[0]

            # Higher similarity ranks first; only the top_k matter for the metrics.
            top_indices = np.argsort(-similarities)[:top_k]
            hits = np.flatnonzero(top_indices == correct_idx)
            if hits.size:
                hit_rank = int(hits[0]) + 1

            pos_scores.append(similarities[correct_idx])
            negatives = np.delete(similarities, correct_idx)
            if negatives.size:
                neg_scores.append(negatives.mean())
        else:
            results = lejepa_retriever.retrieve(query, top_k=top_k)
            retrieved_docs = [doc for doc, _ in results]
            distances = [dist for _, dist in results]
            correct_doc = all_docs[correct_idx]

            # Locate the relevant document by text. hit_pos is recomputed for
            # every query, so a miss can never reuse a previous query's position.
            try:
                hit_pos = retrieved_docs.index(correct_doc)
            except ValueError:
                hit_pos = None

            if hit_pos is None:
                pos_scores.append(miss_penalty)
                neg_distances = distances
            else:
                # Distances are negated so that "higher = better" on both paths.
                pos_scores.append(-distances[hit_pos])
                # Drop the hit by position, not by value, so ties are kept.
                neg_distances = distances[:hit_pos] + distances[hit_pos + 1:]
                if hit_pos < top_k:
                    hit_rank = hit_pos + 1

            if neg_distances:
                neg_scores.append(-np.mean(neg_distances))

        # Recall@k and reciprocal rank both follow from the hit rank.
        if hit_rank is None:
            recalls.append(0)
            reciprocal_ranks.append(0.0)
        else:
            recalls.append(1)
            reciprocal_ranks.append(1.0 / hit_rank)

    avg_pos = mean_or_nan(pos_scores)
    avg_neg = mean_or_nan(neg_scores)
    return {
        'recall@k': mean_or_nan(recalls),
        'mrr': mean_or_nan(reciprocal_ranks),
        'avg_pos_score': avg_pos,
        'avg_neg_score': avg_neg,
        'separation': avg_pos - avg_neg
    }

# Query texts in dataset order; position i matches key i of query_to_correct_doc_idx.
queries = [triplet['query'] for triplet in test_data]

# Baseline: encode the queries and rank the pre-computed document embeddings.
print("\nEvaluating Original (Cosine Similarity)...")
original_results = evaluate_retrieval(
    model=original_model,
    queries=queries,
    doc_embs=original_doc_embs,
    query_to_correct=query_to_correct_doc_idx,
    method='cosine',
    top_k=10,
)

# LeJEPA: no model or embeddings are passed; this path uses the retriever indexed earlier.
print("\nEvaluating LeJEPA (Euclidean Distance)...")
lejepa_results = evaluate_retrieval(
    model=None,
    queries=queries,
    doc_embs=None,
    query_to_correct=query_to_correct_doc_idx,
    method='euclidean',
    top_k=10,
)

print("✅ Evaluation complete!")

## 5. Results Comparison

In [None]:
import os

# Result keys in display order, paired with the labels shown in the table.
metric_keys = ['recall@k', 'mrr', 'avg_pos_score', 'avg_neg_score', 'separation']
metric_labels = ['Recall@10', 'MRR', 'Avg Pos Score', 'Avg Neg Score', 'Separation']

# Raw pos/neg scores live on different scales for the two systems, so a
# difference is only reported for these metrics; the rest show "-".
delta_keys = {'recall@k', 'mrr', 'separation'}

# Build the table column by column; every value is formatted to 4 decimals.
comparison = pd.DataFrame({
    'Metric': metric_labels,
    'Original (Cosine)': [f"{original_results[key]:.4f}" for key in metric_keys],
    'LeJEPA (Euclidean)': [f"{lejepa_results[key]:.4f}" for key in metric_keys],
    'Improvement': [
        f"{(lejepa_results[key] - original_results[key]):.4f}" if key in delta_keys else "-"
        for key in metric_keys
    ]
})

print("\n" + "="*80)
print("RAG RETRIEVAL PERFORMANCE COMPARISON")
print("="*80)
print(comparison.to_string(index=False))
print("="*80)

# Save. to_csv does not create missing parent directories, so make sure the
# output folder exists first.
os.makedirs('data/processed', exist_ok=True)
comparison.to_csv('data/processed/rag_performance_comparison.csv', index=False)
print("\n✅ Saved to data/processed/rag_performance_comparison.csv")

## 6. Visualization

In [None]:
import os

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
metrics_ax, separation_ax = axes

# Left panel: grouped bars for the two ranking metrics.
metrics = ['Recall@10', 'MRR']
original_vals = [original_results['recall@k'], original_results['mrr']]
lejepa_vals = [lejepa_results['recall@k'], lejepa_results['mrr']]

x = np.arange(len(metrics))
width = 0.35  # bar width; each pair is centred on its tick

metrics_ax.bar(x - width/2, original_vals, width, label='Original (Cosine)', alpha=0.8)
metrics_ax.bar(x + width/2, lejepa_vals, width, label='LeJEPA (Euclidean)', alpha=0.8)
metrics_ax.set_ylabel('Score')
metrics_ax.set_title('Retrieval Performance')
metrics_ax.set_xticks(x)
metrics_ax.set_xticklabels(metrics)
metrics_ax.legend()
metrics_ax.grid(axis='y', alpha=0.3)

# Right panel: positive-minus-negative score gap for each system.
separations = [original_results['separation'], lejepa_results['separation']]
models = ['Original\n(Cosine)', 'LeJEPA\n(Euclidean)']

separation_ax.bar(models, separations, alpha=0.8, color=['blue', 'green'])
separation_ax.set_ylabel('Pos-Neg Separation')
separation_ax.set_title('Score Separation (Higher = Better)')
separation_ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
# savefig does not create missing directories; ensure the output folder exists.
os.makedirs('data/processed', exist_ok=True)
plt.savefig('data/processed/rag_performance.png', dpi=150)
plt.show()

print("\n✅ Visualization saved to data/processed/rag_performance.png")

## 7. Summary

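**Key Metrics:**

1. **Recall@10**: Fraction of queries whose correct document appears in the top 10 results.
2. **MRR**: Mean reciprocal rank, i.e. the average of 1/rank of the correct document (counted as 0 when it is not in the top 10).
3. **Separation**: Average positive score minus average negative score (higher is better). For LeJEPA the scores are negated Euclidean distances.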

**Expected Results:**
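- LeJEPA should show **higher separation** (better discrimination). Note that the two separation values are on different scales (cosine similarity vs. negated Euclidean distance), so compare them qualitatively rather than by raw difference.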
- Similar or better Recall/MRR
- More calibrated scores for downstream ranking