In [None]:
import time
import statistics
import torch
from sentence_transformers import SentenceTransformer


In [None]:
# Candidate embedding models to benchmark, keyed by the short names that
# prep_texts() dispatches on (keys starting with "e5" get an input prefix).
MODELS = dict(
    minilm="sentence-transformers/all-MiniLM-L6-v2",
    e5_small="intfloat/e5-small-v2",
    bge_small="BAAI/bge-small-en-v1.5",
)


In [None]:
def prep_texts(model_key, texts, as_query=False):
    """Return ``texts`` formatted for the model identified by ``model_key``.

    Keys starting with "e5" get every text prefixed with "query: " (when
    ``as_query`` is truthy) or "passage: " (otherwise), returned as a new list.
    Any other key returns ``texts`` itself, unmodified (the same object).
    """
    if not model_key.startswith("e5"):
        # Only the e5 family is prefixed; everything else takes the raw text.
        return texts
    prefix = "passage: "
    if as_query:
        prefix = "query: "
    return [prefix + text for text in texts]


In [None]:
def _sync(device):
    """Wait for queued accelerator work on ``device`` so the timer measures finished work.

    No-op on CPU. CUDA syncs the *given* device (a bare ``torch.cuda.synchronize()``
    would only sync the current device, e.g. cuda:0 while the model runs on cuda:1).
    MPS is synced when the installed torch exposes ``torch.mps.synchronize``.
    """
    name = str(device)  # accept torch.device objects as well as plain strings
    if name.startswith("cuda"):
        torch.cuda.synchronize(device)
    elif name.startswith("mps"):
        mps = getattr(torch, "mps", None)
        if mps is not None and hasattr(mps, "synchronize"):
            mps.synchronize()


def bench(
    model_name,
    texts,
    device="cpu",
    batch_size=64,
    max_seq_length=256,
    runs=10,
    warmup=2,
):
    """Time ``SentenceTransformer.encode`` over ``texts`` and summarize the results.

    Loads ``model_name`` on ``device`` and sets ``model.max_seq_length``. It then runs
    ``warmup`` untimed encodes on the first ``batch_size`` texts and times ``runs``
    full passes over all of ``texts``.

    Args:
        model_name: sentence-transformers / Hugging Face model id.
        texts: Sequence of strings to embed (already prefixed if the model needs it).
        device: Torch device, e.g. "cpu", "cuda", "cuda:1", "mps".
        batch_size: Batch size passed to ``encode``.
        max_seq_length: Value assigned to ``model.max_seq_length`` (token limit per text).
        runs: Number of timed full passes; must be >= 1.
        warmup: Number of untimed warmup encodes.

    Returns:
        dict with the run configuration plus:
        - ``median_s`` / ``p90_s``: seconds per full pass over ``texts``.
        - ``throughput_texts_per_s``: derived from the median.
        - ``latency_ms_per_text``: derived from the median. This is amortized batch
          time per text, not single-request latency.

    Raises:
        ValueError: if ``texts`` is empty or ``runs`` < 1. Both are checked before
            the model is loaded.
    """
    total = len(texts)
    if total == 0:
        raise ValueError("texts must be non-empty: per-text metrics divide by len(texts)")
    if runs < 1:
        raise ValueError(f"runs must be >= 1, got {runs}")

    model = SentenceTransformer(model_name, device=device)
    model.max_seq_length = max_seq_length

    encode_kwargs = dict(
        batch_size=batch_size,
        normalize_embeddings=False,
        show_progress_bar=False,
    )

    # Untimed warmup so one-off first-call costs (typically lazy init / allocation)
    # don't land in the measured runs.
    for _ in range(warmup):
        model.encode(texts[:batch_size], **encode_kwargs)

    times = []
    for _ in range(runs):
        _sync(device)  # don't let earlier queued work leak into this measurement
        t0 = time.perf_counter()
        model.encode(texts, **encode_kwargs)
        _sync(device)  # wait for the device to finish before stopping the clock
        times.append(time.perf_counter() - t0)

    median = statistics.median(times)
    # statistics.quantiles() needs at least two data points; with one run, p90 is that run.
    p90 = statistics.quantiles(times, n=10)[8] if len(times) >= 2 else times[0]

    return {
        "model": model_name,
        "device": device,
        "batch_size": batch_size,
        "max_seq_length": max_seq_length,
        "texts": total,
        "median_s": round(median, 4),
        "p90_s": round(p90, 4),
        "throughput_texts_per_s": round(total / median, 2),
        "latency_ms_per_text": round((median / total) * 1000, 3),
    }


In [None]:
# Synthetic corpus: 2000 copies of one long paragraph, each with a unique "<i>. " prefix.
# Each text is ~3.5k characters, likely well past max_seq_length=256 tokens, so every
# input is probably truncated to the same length during the benchmark.
base_text = 60 * "This is a sample paragraph used for embedding benchmarks. "
texts = [str(idx) + ". " + base_text for idx in range(2000)]

(len(texts), len(texts[0]))


(2000, 3483)

In [None]:
# Benchmark on the GPU when PyTorch can see one; otherwise fall back to CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
device


'cpu'

In [None]:
# Run every candidate over the same corpus, treating the texts as passages
# (prep_texts only changes inputs for keys starting with "e5").
results = []
for key, model_name in MODELS.items():
    results.append(
        bench(
            model_name,
            prep_texts(key, texts, as_query=False),
            device=device,
            batch_size=64,
            max_seq_length=256,
        )
    )

results


[{'model': 'sentence-transformers/all-MiniLM-L6-v2',
  'device': 'cpu',
  'batch_size': 64,
  'max_seq_length': 256,
  'texts': 2000,
  'median_s': 23.9747,
  'p90_s': 24.7109,
  'throughput_texts_per_s': 83.42,
  'latency_ms_per_text': 11.987},
 {'model': 'intfloat/e5-small-v2',
  'device': 'cpu',
  'batch_size': 64,
  'max_seq_length': 256,
  'texts': 2000,
  'median_s': 46.1582,
  'p90_s': 47.037,
  'throughput_texts_per_s': 43.33,
  'latency_ms_per_text': 23.079},
 {'model': 'BAAI/bge-small-en-v1.5',
  'device': 'cpu',
  'batch_size': 64,
  'max_seq_length': 256,
  'texts': 2000,
  'median_s': 45.9232,
  'p90_s': 46.3682,
  'throughput_texts_per_s': 43.55,
  'latency_ms_per_text': 22.962}]

In [None]:
import pandas as pd

# One row per benchmarked model; columns follow the keys of bench()'s result dict.
df = pd.DataFrame(data=results)
df


Unnamed: 0,model,device,batch_size,max_seq_length,texts,median_s,p90_s,throughput_texts_per_s,latency_ms_per_text
0,sentence-transformers/all-MiniLM-L6-v2,cpu,64,256,2000,23.9747,24.7109,83.42,11.987
1,intfloat/e5-small-v2,cpu,64,256,2000,46.1582,47.037,43.33,23.079
2,BAAI/bge-small-en-v1.5,cpu,64,256,2000,45.9232,46.3682,43.55,22.962


I want to use bge-small-en-v1.5 (`BAAI/bge-small-en-v1.5`) on the `text` column of the `chunks` table in my `scrape.db`. I am using fastlite.

##### 🤖Reply🤖<!-- SOLVEIT_SEPARATOR_7f3a9b2c -->

🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠🧠

```python
from fastlite import database
```

```python
db = database("scrape.db")
db.t.chunks
```

In [None]:
from fastlite import database

In [None]:
from pathlib import Path

# Opening a SQLite path that doesn't exist creates a new, empty database file.
# In that file every table reports "(does not exist yet)", which is what the chunks
# lookup showed. Fail loudly if the file isn't where we expect, instead of
# silently working on an empty DB.
DB_PATH = Path("scrape.db")
if not DB_PATH.exists():
    raise FileNotFoundError(
        f"{DB_PATH.resolve()} not found; check the notebook's working directory"
    )
db = database(DB_PATH)
db.t.chunks

<Table chunks (does not exist yet)>

In [None]:
# `db.t` has no `.all()` method. Attribute access on `db.t` looks up a table by that
# name (`db.t.chunks` -> <Table chunks>), so `db.t.all` refers to a table called "all".
# Ask the database for the tables it actually contains.
db.table_names()