# litfit: the shortest path from someone else's embedding to your task

litfit learns optimal linear projections from covariance statistics instead of gradient-based fine-tuning. One pass over your pairs, closed-form solution, done.

This notebook walks through the full pipeline on a text retrieval task (AskUbuntu duplicate detection). The same approach works for vision, multimodal, or any dense embeddings.

💻 [GitHub](https://github.com/b0nce/litfit) | 📦 [PyPI](https://pypi.org/project/litfit/)

In [19]:
import torch

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    print(f"✅ GPU: {gpu_name}")
else:
    print("⚠️  No GPU detected: litfit will still work, just slower.")
    print("   To enable GPU: Runtime → Change runtime type → T4 GPU")

!pip install -q litfit

✅ GPU: Tesla T4


## Step 1: Prepare your data

litfit needs three things:
- **embeddings**: vectors from any model (text, vision, multimodal)
- **ids**: a unique identifier per embedding (a plain `list(range(n))` is often enough)
- **id_to_group**: a dict mapping each id to a group label; ids that share a label count as "similar" (duplicates, same class, relevant pairs)
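In this notebook, two ids form a positive pair when they share a group label. To make that contract concrete, the sketch below enumerates the pairs an `id_to_group` mapping implies. The helper `positive_pairs` is hypothetical (not part of litfit); note that an id alone in its group contributes no pairs.

```python
from collections import defaultdict
from itertools import combinations


def positive_pairs(ids, id_to_group):
    """Hypothetical helper: all unordered (a, b) pairs of ids that share a group."""
    members = defaultdict(list)
    for i in ids:
        members[id_to_group[i]].append(i)
    pairs = []
    for group_ids in members.values():
        # A singleton group yields nothing; a group of size m yields m*(m-1)/2 pairs.
        pairs.extend(combinations(group_ids, 2))
    return pairs


ids = [0, 1, 2, 3, 4]
id_to_group = {0: "A", 1: "A", 2: "B", 3: "B", 4: "C"}
print(positive_pairs(ids, id_to_group))  # [(0, 1), (2, 3)]
```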

We'll use the built-in AskUbuntu dataset: questions from the Ask Ubuntu Stack Exchange site, grouped into duplicate clusters.

In [None]:
from litfit import encode_texts, load_askubuntu, split_data

all_ids, all_texts, id_to_group = load_askubuntu()
print(f"{len(all_ids)} questions, {len(set(id_to_group.values()))} duplicate groups")
print(f"\nExample: {all_texts[0][:120]}...")

# Encode with any HuggingFace model: swap this for your own embeddings
embs = encode_texts("intfloat/e5-base-v2", all_texts)
print(f"Embedding shape: {embs.shape}")

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'sentence-transformers/askubuntu' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


Loading AskUbuntu...
  AskUbuntu: 8213 items, 2952 groups
8213 questions, 2952 duplicate groups

Example: wireless network card not working ( trendnet tew-643pi )...
  Encoding 8213 texts with intfloat/e5-base-v2...


BertModel LOAD REPORT from: intfloat/e5-base-v2
Key                     | Status     |  | 
------------------------+------------+--+-
embeddings.position_ids | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


Embedding shape: torch.Size([8213, 768])


In [21]:
# Group-aware split: all items in a group stay together (no leakage)
data = split_data(all_ids, all_texts, embs, id_to_group)

train_ids, _, train_embs, _ = data["train"]
val_ids,   _, val_embs,   _ = data["val"]
test_ids,  _, test_embs,  _ = data["test"]

print(f"Train: {len(train_ids)} | Val: {len(val_ids)} | Test: {len(test_ids)}")

    train: 4873 items, 1771 groups, emb=(4873, 768)
    val: 1674 items, 590 groups, emb=(1674, 768)
    test: 1666 items, 591 groups, emb=(1666, 768)
Train: 4873 | Val: 1674 | Test: 1666
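`split_data` keeps every group entirely on one side of the split. Its exact proportions and seeding aren't shown here, so the following is only a minimal sketch of a group-aware split under our own assumptions: whole groups are shuffled with a fixed seed and divided 60/20/20 (the group counts above, 1771/590/591 of 2952, are close to that). The helper `group_split` is hypothetical.

```python
import random
from collections import defaultdict


def group_split(ids, id_to_group, fracs=(0.6, 0.2, 0.2), seed=0):
    """Hypothetical group-aware split: shuffle groups, then assign whole groups."""
    members = defaultdict(list)
    for i in ids:
        members[id_to_group[i]].append(i)
    groups = sorted(members)  # deterministic order before shuffling
    random.Random(seed).shuffle(groups)

    n_train = int(fracs[0] * len(groups))
    n_val = int(fracs[1] * len(groups))
    parts = {
        "train": groups[:n_train],
        "val": groups[n_train:n_train + n_val],
        "test": groups[n_train + n_val:],
    }
    # Expand each list of groups back into the ids it contains.
    return {name: [i for g in gs for i in members[g]] for name, gs in parts.items()}


ids = list(range(20))
id_to_group = {i: i // 2 for i in ids}  # 10 groups of 2 items each
split = group_split(ids, id_to_group)
print({k: len(v) for k, v in split.items()})  # {'train': 12, 'val': 4, 'test': 4}
```

Splitting by group rather than by item is what prevents leakage: if one duplicate landed in train and its partner in test, the projection would have seen that exact pair during fitting.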


## Step 2: Measure the baseline

How well do raw embeddings perform before litfit touches them?

In [22]:
from litfit import evaluate_retrieval_fast

baseline = evaluate_retrieval_fast(test_embs, test_ids, id_to_group)
print(f"Baseline | R@1: {baseline['R@1']:.4f}  MAP@50: {baseline['MAP@50']:.4f}")

Baseline | R@1: 0.5360  MAP@50: 0.5084
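For reference, here is one common way to compute leave-one-out Recall@1 and MAP@k with cosine similarity: every item queries all the others, and items from the same group are the relevant ones. This is our own sketch, not the code of `evaluate_retrieval_fast`; litfit's definitions (for example, how queries with no duplicate are treated) may differ.

```python
import numpy as np


def retrieval_metrics(embs, ids, id_to_group, k=50):
    """Sketch: each item queries all others; relevant = same group."""
    x = embs / np.linalg.norm(embs, axis=1, keepdims=True)  # cosine via dot product
    sims = x @ x.T
    np.fill_diagonal(sims, -np.inf)  # never retrieve the query itself
    labels = np.array([id_to_group[i] for i in ids])

    r1, aps = [], []
    for q in range(len(ids)):
        relevant = labels == labels[q]
        relevant[q] = False
        n_rel = relevant.sum()
        if n_rel == 0:
            continue  # skip queries that have no duplicate to find
        order = np.argsort(-sims[q])[:k]
        hits = relevant[order]
        r1.append(float(hits[0]))
        # Average precision over the top k: precision@r at each hit rank r, then normalize.
        precision_at_hit = np.cumsum(hits)[hits] / (np.flatnonzero(hits) + 1)
        aps.append(precision_at_hit.sum() / min(n_rel, k))
    return {"R@1": float(np.mean(r1)), f"MAP@{k}": float(np.mean(aps))}


embs = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
ids = [0, 1, 2, 3]
id_to_group = {0: "A", 1: "A", 2: "B", 3: "B"}
print(retrieval_metrics(embs, ids, id_to_group))  # {'R@1': 1.0, 'MAP@50': 1.0}
```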


## Step 3: Compute projections with litfit

The core pipeline:
1. Compute covariance matrices from positive pairs (sufficient statistics)
2. Scan for a useful range of output dimensions
3. Generate ~40 candidate projections (fast mode) using different methods
4. Evaluate them on the validation set with explore-exploit scheduling
5. Pick the best one

All closed-form: no iterative training.
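To give a feel for what "closed-form from covariance statistics" means, here is a minimal NumPy sketch of one classic recipe of this kind: a generalized Rayleigh quotient that favors directions along which positive pairs co-vary strongly relative to the overall (regularized) variance. It is our illustration, not litfit's implementation; we assume, without having checked, that the `Rayleigh` method in the results tables is related. The helper `rayleigh_projection` and its `reg` parameter are hypothetical.

```python
import numpy as np


def rayleigh_projection(embs, pairs, reg=0.01):
    """Sketch: solve C_pos w = lam * (C_all + reg * I) w, columns sorted by lam."""
    x = embs - embs.mean(axis=0)
    d = x.shape[1]

    # Sufficient statistics: overall covariance and symmetrized positive-pair cross-covariance.
    c_all = x.T @ x / len(x)
    a = x[[i for i, _ in pairs]]
    b = x[[j for _, j in pairs]]
    c_pos = (a.T @ b + b.T @ a) / (2 * len(pairs))

    # Cholesky whitening turns the generalized eigenproblem into a standard symmetric one.
    L = np.linalg.cholesky(c_all + reg * np.eye(d))
    m = np.linalg.solve(L, c_pos)          # L^-1 C_pos
    m = np.linalg.solve(L, m.T).T          # L^-1 C_pos L^-T
    m = (m + m.T) / 2                      # remove round-off asymmetry
    lam, v = np.linalg.eigh(m)             # eigenvalues in ascending order
    order = np.argsort(lam)[::-1]          # largest ratio first
    W = np.linalg.solve(L.T, v[:, order])  # back-transform: W = L^-T V
    return W, lam[order]


# Toy data: 200 groups of 2; group identity lives only in dimension 0.
rng = np.random.default_rng(0)
n_groups, d = 200, 8
centers = np.zeros((n_groups, d))
centers[:, 0] = 3.0 * rng.normal(size=n_groups)
embs = np.repeat(centers, 2, axis=0) + rng.normal(size=(2 * n_groups, d))
pairs = [(2 * g, 2 * g + 1) for g in range(n_groups)]

W, lam = rayleigh_projection(embs, pairs)
projected = embs @ W[:, :2]  # as in the notebook, keep only the leading columns
print(np.round(lam[:3], 2))  # the first ratio should stand far above the rest
```

Because columns come out sorted by their ratio, truncating to the first `n_dims` columns (as Step 4 does with `projected[:, :best_n_dims]`) keeps the most pair-consistent directions first.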

In [23]:
import time

from litfit import (
    compute_stats,
    evaluate_projections,
    find_dim_range,
    generate_fast_projections,
)

t0 = time.time()

# Sufficient statistics from training pairs
st = compute_stats(train_embs, train_ids, id_to_group)

# Automatically find useful dimensionality range
dim_fractions = find_dim_range(st, val_embs, val_ids, id_to_group)

# Generate and evaluate ~40 projection candidates
all_W = generate_fast_projections(st)
results, summary = evaluate_projections(
    all_W, val_embs, val_ids, id_to_group,
    test_embs=test_embs, test_ids=test_ids,
    dim_fractions=dim_fractions,
)

print(f"Done in {time.time()-t0:.1f}s")


Dim range scan (baseline MAP@50=0.4986, peak=0.5540, delta=0.0555)
  dims      MAP@50   vs base
----------------------------
    38      0.4632   -0.0354
    76      0.5207   +0.0221
   114      0.5417   +0.0431
   152      0.5459   +0.0473
   190      0.5501   +0.0515
   228      0.5540   +0.0555
   266      0.5530   +0.0544
   304      0.5522   +0.0537
   342      0.5484   +0.0498
   380      0.5479   +0.0493
   418      0.5474   +0.0488
   456      0.5462   +0.0477
   494      0.5457   +0.0471
   532      0.5442   +0.0456
   570      0.5426   +0.0441
   608      0.5417   +0.0431
   646      0.5394   +0.0408
   684      0.5380   +0.0394
   722      0.5353   +0.0367
   760      0.5291   +0.0305
   768      0.5209   +0.0223

Useful range: dims 190-304 (fracs 0.25-0.40)
Returned dim_fractions: (0.049479166666666664, 0.1484375, 0.24739583333333334, 0.2838541666666667, 0.3216145833333333, 0.359375, 0.3958333333333333, 1.0)
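`find_dim_range` reports a "useful range" of output dimensions around the peak, but its exact criterion isn't shown in this notebook. One simple rule, which is our assumption, keeps the contiguous truncation sizes around the peak whose gain over the baseline is at least a fixed fraction of the peak gain. With a fraction of 0.9, this happens to reproduce the 190-304 range printed above on these numbers; litfit's real rule may differ. The helper `useful_dim_range` is hypothetical.

```python
def useful_dim_range(dims, scores, baseline, keep=0.9):
    """Hypothetical rule: contiguous dims around the peak with gain >= keep * peak gain."""
    peak = max(range(len(scores)), key=scores.__getitem__)
    threshold = keep * (scores[peak] - baseline)
    lo = hi = peak
    # Grow the window outward while neighbours still clear the threshold.
    while lo > 0 and scores[lo - 1] - baseline >= threshold:
        lo -= 1
    while hi < len(scores) - 1 and scores[hi + 1] - baseline >= threshold:
        hi += 1
    return dims[lo], dims[hi]


# MAP@50 values copied from the dimension scan above
dims = [38, 76, 114, 152, 190, 228, 266, 304, 342, 380, 418,
        456, 494, 532, 570, 608, 646, 684, 722, 760, 768]
scores = [0.4632, 0.5207, 0.5417, 0.5459, 0.5501, 0.5540, 0.5530, 0.5522, 0.5484, 0.5479, 0.5474,
          0.5462, 0.5457, 0.5442, 0.5426, 0.5417, 0.5394, 0.5380, 0.5353, 0.5291, 0.5209]
print(useful_dim_range(dims, scores, baseline=0.4986))  # (190, 304)
```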


Total: 38 projections, 0 failed
Device: cuda



VAL SET
Method                            5%(38)    15%(114)    25%(190)    28%(218)    32%(247)    36%(276)    40%(304)   full(768)   configs
--------------------------------------------------------------------------------------------------------------------------------------
Ray→AsymRef→MSE                   0.4942      0.5555      0.5625      0.5633      0.5627      0.5631      0.5631      0.5648    8/8  
Ray→AsymRef                       0.4910      0.5521      0.5512      0.5527      0.5546      0.5532      0.5527      0.5205    6/6  
Ray→MSE→AsymRef                   0.4885      0.5439      0.5502      0.5506      0.5511      0.5510      0.5512      0.5514    8/8  
SplitRankRay                      0.4412      0.5241      0.5422      0.5445      0.5464      0.5484      0.5489      0.5210    8/8  
Rayleigh                          0.4408      0.5239      0.5418      0.5440      0.5460      0.5483      0.5479      0.5209    4/4  
Ray→MSE                           0.474

## Step 4: Extract the best projection and compare

In [24]:
import torch

from litfit import evaluate_retrieval_fast

# Find best config by validation MAP@50
flat = []
for key, dim_dict in results.items():
    for n_dims, scores in dim_dict.items():
        flat.append((key, n_dims, scores["MAP@50"]))
flat.sort(key=lambda x: x[2], reverse=True)
best_key, best_n_dims, best_val_score = flat[0]

# Apply projection to test embeddings
W = all_W[best_key]
projected = test_embs @ W
if best_n_dims is not None:
    projected = projected[:, :best_n_dims]

improved = evaluate_retrieval_fast(projected, test_ids, id_to_group)

dims_str = best_n_dims or test_embs.shape[1]
cfg_str = ", ".join(best_key[1:]) if len(best_key) > 1 else "default"

print(f"Best method: {best_key[0]} ({cfg_str})")
print(f"Dimensions:  {test_embs.shape[1]} → {dims_str}")
print()
print(f"{'':15s} {'R@1':>8s}  {'MAP@50':>8s}")
print(f"{'Baseline':15s} {baseline['R@1']:>8.4f}  {baseline['MAP@50']:>8.4f}")
print(f"{'+ litfit':15s} {improved['R@1']:>8.4f}  {improved['MAP@50']:>8.4f}")
print(f"{'Δ':15s} {improved['R@1']-baseline['R@1']:>+8.4f}  {improved['MAP@50']-baseline['MAP@50']:>+8.4f}")

Best method: Ray→AsymRef→MSE (reg=0.01, reg_mse=1.0, reg_refine=0.1)
Dimensions:  768 → 768

                     R@1    MAP@50
Baseline          0.5360    0.5084
+ litfit          0.5816    0.5771
Δ                +0.0456   +0.0687


## Step 5 (optional): Retrain on train+val for a final boost

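Because litfit fits projections in closed form, there is no early stopping or other training-time decision through which validation data could leak. The method, regularization, and output dimension were already chosen on val in Step 4; here we keep that configuration, recompute the statistics on train+val, and evaluate once on the untouched test set.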

In [25]:
tv_embs = torch.cat([train_embs, val_embs], dim=0)
tv_ids = train_ids + val_ids

st_full = compute_stats(tv_embs, tv_ids, id_to_group)
all_W_full = generate_fast_projections(st_full, verbose=False)

W_full = all_W_full[best_key]
projected_full = test_embs @ W_full
if best_n_dims is not None:
    projected_full = projected_full[:, :best_n_dims]

final = evaluate_retrieval_fast(projected_full, test_ids, id_to_group)

print(f"{'':15s} {'R@1':>8s}  {'MAP@50':>8s}")
print(f"{'Baseline':15s} {baseline['R@1']:>8.4f}  {baseline['MAP@50']:>8.4f}")
print(f"{'Train only':15s} {improved['R@1']:>8.4f}  {improved['MAP@50']:>8.4f}")
print(f"{'Train+val':15s} {final['R@1']:>8.4f}  {final['MAP@50']:>8.4f}")

                     R@1    MAP@50
Baseline          0.5360    0.5084
Train only        0.5816    0.5771
Train+val         0.5822    0.5788
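Step 5 simply recomputes the statistics from scratch, which is cheap here. It also shows why closed-form fitting composes well: statistics of this kind are sums over items and pairs, so they can be accumulated chunk by chunk and merged by addition. Because the split is group-aware, no positive pair straddles train and val, so merging per-split sums gives exactly the statistics of the union. We don't know whether litfit's `compute_stats` objects support merging, so this sketch uses plain NumPy accumulators in a hypothetical `PairStats` class.

```python
import numpy as np


class PairStats:
    """Hypothetical streaming accumulator of raw sums for pair statistics."""

    FIELDS = ("n", "s", "sxx", "n_pairs", "sp", "sab")

    def __init__(self, d):
        self.n, self.s, self.sxx = 0, np.zeros(d), np.zeros((d, d))  # items: count, sum x, sum x x^T
        self.n_pairs = 0                  # number of positive pairs
        self.sp = np.zeros(d)             # sum of a + b over pairs
        self.sab = np.zeros((d, d))       # sum of a b^T + b a^T over pairs

    def update(self, embs, pairs):
        a, b = embs[[i for i, _ in pairs]], embs[[j for _, j in pairs]]
        self.n += len(embs)
        self.s += embs.sum(axis=0)
        self.sxx += embs.T @ embs
        self.n_pairs += len(pairs)
        self.sp += a.sum(axis=0) + b.sum(axis=0)
        self.sab += a.T @ b + b.T @ a
        return self

    def __add__(self, other):
        out = PairStats(len(self.s))
        for name in self.FIELDS:
            setattr(out, name, getattr(self, name) + getattr(other, name))
        return out

    def covariances(self):
        """Covariance of items and symmetrized pair cross-covariance, both centered on the item mean."""
        mu = self.s / self.n
        m = self.sp / (2 * self.n_pairs)  # mean of pair endpoints
        c_all = self.sxx / self.n - np.outer(mu, mu)
        c_pos = self.sab / (2 * self.n_pairs) - np.outer(m, mu) - np.outer(mu, m) + np.outer(mu, mu)
        return c_all, c_pos


rng = np.random.default_rng(0)
train, val = rng.normal(size=(6, 4)), rng.normal(size=(4, 4))
train_pairs, val_pairs = [(0, 1), (2, 3), (4, 5)], [(0, 1), (2, 3)]

merged = PairStats(4).update(train, train_pairs) + PairStats(4).update(val, val_pairs)
# Recompute on the union for comparison; val indices shift by len(train).
union_pairs = train_pairs + [(i + 6, j + 6) for i, j in val_pairs]
direct = PairStats(4).update(np.vstack([train, val]), union_pairs)
print(all(np.allclose(m, d) for m, d in zip(merged.covariances(), direct.covariances())))  # True
```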


## Step 6: Export for inference

The projection is just a matrix multiply, so you can export it as a bias-free `torch.nn.Linear` and plug it into any pipeline.

In [26]:
import torch.nn as nn

out_dim = best_n_dims or W_full.shape[1]
layer = nn.Linear(W_full.shape[0], out_dim, bias=False)
layer.weight = nn.Parameter(W_full[:, :out_dim].T.cpu().float())

# Verify: should match train+val results
with torch.no_grad():
    check = layer(test_embs.cpu().float())
check_metrics = evaluate_retrieval_fast(check, test_ids, id_to_group)
print(f"Exported layer | R@1: {check_metrics['R@1']:.4f}  MAP@50: {check_metrics['MAP@50']:.4f}")
print(f"Layer weight shape: {tuple(layer.weight.shape)} (out={out_dim}, in={W_full.shape[0]})")

Exported layer | R@1: 0.5822  MAP@50: 0.5788
Layer weight shape: (768, 768) (out=768, in=768)
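If your serving stack is not PyTorch, the exported layer is still just a matrix. `nn.Linear` computes `x @ weight.T`, so the matrix to ship is `layer.weight.T`, i.e. `W_full[:, :out_dim]`. The NumPy sketch below applies such a matrix `P` and L2-normalizes the result so that inner-product search ranks like cosine similarity; we assume cosine ranking is what you want, as is typical for embedding retrieval. The random `P` and the `project` helper are stand-ins for illustration only.

```python
import numpy as np


def project(x, P, normalize=True):
    """Apply a projection P of shape (d_in, d_out); optionally L2-normalize rows."""
    y = x @ P
    if normalize:
        # After normalization, inner-product search ranks exactly like cosine similarity.
        y = y / np.linalg.norm(y, axis=1, keepdims=True)
    return y


rng = np.random.default_rng(0)
P = rng.normal(size=(768, 768))  # stand-in for the exported projection matrix
queries, docs = rng.normal(size=(3, 768)), rng.normal(size=(100, 768))

scores = project(queries, P) @ project(docs, P).T  # (3, 100) cosine scores
top5 = np.argsort(-scores, axis=1)[:, :5]          # indices of the 5 best docs per query
```

In practice you would write the matrix once, for example with `np.save` on `W_full[:, :out_dim].cpu().numpy()`, and load it wherever the base embeddings are produced.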


## Bonus: Full sweep (800+ projections)

`generate_fast_projections` tests ~40 configs. If you have a GPU (or a bit of patience on CPU), `generate_all_projections` tries 800+, including methods that use negative-pair statistics and more regularization variants. This can squeeze out a little extra performance; on the T4 used here, generating the 861 candidates took about 96 s.
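This notebook doesn't spell out what litfit's negative-pair statistics contain. As a rough illustration only, one simple version is a cross-covariance over randomly sampled pairs from *different* groups, which a method could then try to push apart (for example, by comparing it against the positive-pair covariance). The helpers `sample_negative_pairs` and `pair_cross_cov` are hypothetical and are not what `compute_all_stats` does internally.

```python
import random

import numpy as np


def sample_negative_pairs(ids, id_to_group, n_pairs, seed=0):
    """Hypothetical sampler: random (a, b) position pairs whose groups differ."""
    if len({id_to_group[i] for i in ids}) < 2:
        raise ValueError("need at least two groups to form negative pairs")
    rng = random.Random(seed)
    pairs = []
    while len(pairs) < n_pairs:
        a, b = rng.randrange(len(ids)), rng.randrange(len(ids))
        if id_to_group[ids[a]] != id_to_group[ids[b]]:
            pairs.append((a, b))  # positions into ids and the embedding matrix
    return pairs


def pair_cross_cov(embs, pairs):
    """Symmetrized cross-covariance of centered embeddings over the given pairs."""
    x = embs - embs.mean(axis=0)
    a, b = x[[i for i, _ in pairs]], x[[j for _, j in pairs]]
    return (a.T @ b + b.T @ a) / (2 * len(pairs))


rng = np.random.default_rng(0)
embs = rng.normal(size=(40, 8))
ids = list(range(40))
id_to_group = {i: i // 4 for i in ids}  # 10 groups of 4
neg = sample_negative_pairs(ids, id_to_group, n_pairs=200)
c_neg = pair_cross_cov(embs, neg)  # (8, 8), symmetric
```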

In [27]:
from litfit import compute_all_stats, generate_all_projections

t0 = time.time()

# Compute both positive and negative pair statistics
st_all, neg = compute_all_stats(train_embs, train_ids, id_to_group)

# Generate 800+ projection candidates
all_W_full_sweep = generate_all_projections(st_all, neg, include_neg_methods=True)
print(f"Generated {len(all_W_full_sweep)} projections in {time.time()-t0:.1f}s")

# Evaluate with the same dim range
results_full, summary_full = evaluate_projections(
    all_W_full_sweep, val_embs, val_ids, id_to_group,
    test_embs=test_embs, test_ids=test_ids,
    dim_fractions=dim_fractions,
)

# Best result from full sweep
flat_full = []
for key, dim_dict in results_full.items():
    for n_dims, scores in dim_dict.items():
        flat_full.append((key, n_dims, scores["MAP@50"]))
flat_full.sort(key=lambda x: x[2], reverse=True)
best_key_full, best_n_dims_full, _ = flat_full[0]

W_sweep = all_W_full_sweep[best_key_full]
proj_sweep = test_embs @ W_sweep
if best_n_dims_full is not None:
    proj_sweep = proj_sweep[:, :best_n_dims_full]

sweep = evaluate_retrieval_fast(proj_sweep, test_ids, id_to_group)

dims_full_str = best_n_dims_full or test_embs.shape[1]

print(f"\nBest method: {best_key_full[0]} → {dims_full_str} dims")
print(f"\n{'':15s} {'R@1':>8s}  {'MAP@50':>8s}")
print(f"{'Baseline':15s} {baseline['R@1']:>8.4f}  {baseline['MAP@50']:>8.4f}")
print(f"{'Fast (~40)':15s} {improved['R@1']:>8.4f}  {improved['MAP@50']:>8.4f}")
print(f"{'Full (~800+)':15s} {sweep['R@1']:>8.4f}  {sweep['MAP@50']:>8.4f}")
print(f"\nTotal full sweep time: {time.time()-t0:.1f}s")

Total: 861 projections, 0 failed
Device: cuda
Generated 861 projections in 95.8s



VAL SET
Method                            5%(38)    15%(114)    25%(190)    28%(218)    32%(247)    36%(276)    40%(304)   full(768)   configs
--------------------------------------------------------------------------------------------------------------------------------------
Ray→AsymRef→MSE                   0.4950      0.5555      0.5625      0.5633      0.5627      0.5631      0.5632      0.5649   48/48 
Ray→MSE→AsymRef                   0.4922      0.5558      0.5621      0.5630      0.5633      0.5632      0.5628      0.5628   48/48 
Ray→AsymRef                       0.4913      0.5534      0.5512      0.5527      0.5546      0.5532      0.5527      0.5210   16/16 
Ray→MSE                           0.4788      0.5406      0.5464      0.5475      0.5484      0.5485      0.5479      0.5494   16/16 
Asym→MSE                          0.4788      0.5406      0.5464      0.5475      0.5484      0.5485      0.5479      0.5494   16/16 
SplitRankRay                      0.4

## Using litfit on your own data

Replace the dataset loading with your own embeddings and group labels:
```python
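import numpy as np
import torch

# Your embeddings: any source, any modality (the notebook works with torch tensors)
embs = torch.from_numpy(np.load("my_embeddings.npy")).float()  # shape (n, d)

# Your group labels: ids that share a label should be similar
ids = list(range(len(embs)))
labels = ["A", "A", "B", "B", "C"]  # one label per embedding, e.g. from your metadata
id_to_group = dict(zip(ids, labels))

# Then split into train/val/test with whole groups on one side (as split_data does in Step 1)
# and run the same pipeline from Step 3 onwards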
```

litfit works with any dense vectors: text (e5, bge, MiniLM, OpenAI), vision (SigLIP, CLIP, DINOv2), multimodal, or your own custom model.
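`id_to_group` expects group labels. If your supervision comes as a list of relevant pairs instead (common in query-document datasets), one option, which is our assumption rather than a litfit feature, is to take connected components of the pair graph so that items linked directly or transitively share a group. Transitive merging can chain loosely related items into one large group, so it is worth inspecting group sizes afterwards. The helper `pairs_to_groups` is hypothetical.

```python
def pairs_to_groups(ids, pairs):
    """Hypothetical helper: map each id to the root of its connected component (union-find)."""
    parent = {i: i for i in ids}

    def find(i):
        # Path halving keeps the trees shallow.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    for a, b in pairs:
        ra, rb = find(a), find(b)
        if ra != rb:
            parent[rb] = ra  # merge the two components
    return {i: find(i) for i in ids}


ids = list(range(6))
id_to_group = pairs_to_groups(ids, [(0, 1), (1, 2), (4, 5)])
print(id_to_group)  # {0: 0, 1: 0, 2: 0, 3: 3, 4: 4, 5: 4}
```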