## General batching

### Sentence Transformers

In [1]:
from sentence_transformers import SentenceTransformer
from concurrent.futures import ThreadPoolExecutor
# Load the embedding model once; the encode helper defined in the next cell
# reads this notebook-level `model`.
model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
# Encode one sentence. As the cell's last expression, the returned array is
# what the notebook displays below.
model.encode("Hello, world!")

  from tqdm.autonotebook import tqdm, trange


array([ 0.3059689 ,  0.79072964,  0.00980721, ...,  0.0644002 ,
       -0.45898244, -0.01831866], dtype=float32)

In [2]:
def simple_encode(text: list[str]):
    """Encode a list of sentences with the notebook-level ``model``.

    Args:
        text: Sentences to embed.

    Returns:
        Whatever ``model.encode`` returns for ``text``.
    """
    embeddings = model.encode(text)
    return embeddings

In [3]:
import timeit

def run_benchmark(encode_fn=None, n_requests: int = 1000, max_workers: int = 128):
    """Submit ``n_requests`` single-sentence encode calls through a thread pool.

    Called with no arguments, this reproduces the original workload:
    1000 requests, 128 worker threads, each request being
    ``["Hello, world!"]`` passed to ``simple_encode``.

    Args:
        encode_fn: Callable invoked once per request with a one-element list.
            When ``None``, the notebook-level ``simple_encode`` is looked up
            at call time.
        n_requests: Number of requests to submit.
        max_workers: Number of threads in the pool.

    Returns:
        None. Results are consumed and discarded; only the elapsed time
        measured by the caller matters.
    """
    if encode_fn is None:
        encode_fn = simple_encode
    requests = [["Hello, world!"] for _ in range(n_requests)]
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # list() consumes the result iterator, so every call has finished and
        # any exception raised in a worker is re-raised here.
        list(executor.map(encode_fn, requests))

# Time exactly one pass of run_benchmark and report the elapsed seconds.
execution_time = timeit.Timer(run_benchmark).timeit(number=1)
print(f"Execution time: {execution_time:.2f} seconds")

Execution time: 28.56 seconds


In [5]:
import batch
import timeit

# Wrap simple_encode with batch.dynamically; the benchmark below calls this
# wrapper instead of simple_encode directly.
# NOTE(review): presumably the wrapper groups concurrent calls into larger
# batches before invoking simple_encode — confirm against the batch docs.
dynamic_encode = batch.dynamically(simple_encode)

def run_benchmark(encode_fn=None, n_requests: int = 1000, max_workers: int = 128):
    """Submit ``n_requests`` single-sentence encode calls through a thread pool.

    Called with no arguments, this reproduces the original workload:
    1000 requests, 128 worker threads, each request being
    ``["Hello, world!"]`` passed to ``dynamic_encode``.

    Args:
        encode_fn: Callable invoked once per request with a one-element list.
            When ``None``, the notebook-level ``dynamic_encode`` is looked up
            at call time.
        n_requests: Number of requests to submit.
        max_workers: Number of threads in the pool.

    Returns:
        None. Results are consumed and discarded.
    """
    if encode_fn is None:
        encode_fn = dynamic_encode
    requests = [["Hello, world!"] for _ in range(n_requests)]
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # list() consumes the result iterator, so every call has finished and
        # any exception raised in a worker is re-raised here.
        list(executor.map(encode_fn, requests))


# Time exactly one pass of run_benchmark and report the elapsed seconds.
execution_time = timeit.Timer(run_benchmark).timeit(number=1)
print(f"Execution time: {execution_time:.2f} seconds")

Execution time: 1.85 seconds


### Ofen

In [13]:
from ofen.models import TextEncoder

# Load the same checkpoint name used in the Sentence Transformers section,
# this time through ofen's TextEncoder.
text_encoder = TextEncoder("mixedbread-ai/mxbai-embed-large-v1")
# Last expression of the cell, so the returned EncodingResult is displayed below.
text_encoder.encode("Hello, world!")



EncodingResult(embeddings=array([[[ 0.01729077,  0.04468535,  0.00055422, ...,  0.00363935,
         -0.0259378 , -0.00103522],
        [-0.00385489,  0.02841953,  0.0072808 , ..., -0.00641318,
         -0.00887358,  0.0145831 ],
        [ 0.02255955,  0.03204089,  0.01209312, ...,  0.00168586,
         -0.03093472,  0.01115614],
        [-0.01648994,  0.0239604 , -0.01059136, ..., -0.01044504,
         -0.04768613,  0.02494775],
        [ 0.00466856,  0.04424845,  0.00212991, ...,  0.01685553,
         -0.01741113,  0.01215234],
        [ 0.00623544,  0.03078754,  0.00834455, ...,  0.00311966,
         -0.0418628 , -0.02152618]]], dtype=float32), total_tokens=6)

In [14]:
def simple_encode(text: list[str]):
    """Encode a list of sentences with the notebook-level ``text_encoder``.

    This redefinition replaces the ``simple_encode`` from the Sentence
    Transformers section for all cells that run after it.

    Args:
        text: Sentences to embed.

    Returns:
        The ``embeddings`` attribute of the result of ``text_encoder.encode``.
    """
    encoding_result = text_encoder.encode(text)
    return encoding_result.embeddings

def run_benchmark(encode_fn=None, n_requests: int = 1000, max_workers: int = 128):
    """Submit ``n_requests`` single-sentence encode calls through a thread pool.

    Called with no arguments, this reproduces the original workload:
    1000 requests, 128 worker threads, each request being
    ``["Hello, world!"]`` passed to ``simple_encode``.

    Args:
        encode_fn: Callable invoked once per request with a one-element list.
            When ``None``, the notebook-level ``simple_encode`` is looked up
            at call time.
        n_requests: Number of requests to submit.
        max_workers: Number of threads in the pool.

    Returns:
        None. Results are consumed and discarded.
    """
    if encode_fn is None:
        encode_fn = simple_encode
    requests = [["Hello, world!"] for _ in range(n_requests)]
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # list() consumes the result iterator, so every call has finished and
        # any exception raised in a worker is re-raised here.
        list(executor.map(encode_fn, requests))


# Time exactly one pass of run_benchmark and report the elapsed seconds.
execution_time = timeit.Timer(run_benchmark).timeit(number=1)
print(f"Execution time: {execution_time:.2f} seconds")

Execution time: 23.13 seconds


In [15]:
import batch
import timeit

# Wrap the current simple_encode with batch.dynamically; the benchmark below
# calls this wrapper instead of simple_encode directly.
# NOTE(review): presumably the wrapper groups concurrent calls into larger
# batches before invoking simple_encode — confirm against the batch docs.
dynamic_encode = batch.dynamically(simple_encode)

def run_benchmark(encode_fn=None, n_requests: int = 1000, max_workers: int = 128):
    """Submit ``n_requests`` single-sentence encode calls through a thread pool.

    Called with no arguments, this reproduces the original workload:
    1000 requests, 128 worker threads, each request being
    ``["Hello, world!"]`` passed to ``dynamic_encode``.

    Args:
        encode_fn: Callable invoked once per request with a one-element list.
            When ``None``, the notebook-level ``dynamic_encode`` is looked up
            at call time.
        n_requests: Number of requests to submit.
        max_workers: Number of threads in the pool.

    Returns:
        None. Results are consumed and discarded.
    """
    if encode_fn is None:
        encode_fn = dynamic_encode
    requests = [["Hello, world!"] for _ in range(n_requests)]
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # list() consumes the result iterator, so every call has finished and
        # any exception raised in a worker is re-raised here.
        list(executor.map(encode_fn, requests))

# Time exactly one pass of run_benchmark and report the elapsed seconds.
execution_time = timeit.Timer(run_benchmark).timeit(number=1)
print(f"Execution time: {execution_time:.2f} seconds")

Execution time: 1.82 seconds


## Inference batching

In [16]:
from ofen.models import TextEncoder

# Build a new TextEncoder instance and rebind `text_encoder` to it; the
# following cell replaces this instance's `forward` attribute.
text_encoder = TextEncoder("mixedbread-ai/mxbai-embed-large-v1")
# Last expression of the cell, so the returned EncodingResult is displayed below.
text_encoder.encode("Hello, world!")

EncodingResult(embeddings=array([[[ 0.01729077,  0.04468535,  0.00055422, ...,  0.00363935,
         -0.0259378 , -0.00103522],
        [-0.00385489,  0.02841953,  0.0072808 , ..., -0.00641318,
         -0.00887358,  0.0145831 ],
        [ 0.02255955,  0.03204089,  0.01209312, ...,  0.00168586,
         -0.03093472,  0.01115614],
        [-0.01648994,  0.0239604 , -0.01059136, ..., -0.01044504,
         -0.04768613,  0.02494775],
        [ 0.00466856,  0.04424845,  0.00212991, ...,  0.01685553,
         -0.01741113,  0.01215234],
        [ 0.00623544,  0.03078754,  0.00834455, ...,  0.00311966,
         -0.0418628 , -0.02152618]]], dtype=float32), total_tokens=6)

In [17]:
from batch import inference
from concurrent.futures import ThreadPoolExecutor
import timeit

# Replace text_encoder.forward with a version that goes through
# batch.inference.dynamically.
# NOTE(review): presumably this lets concurrent encode() calls share one model
# invocation — confirm against the batch docs.
#
# The patch is guarded so that re-running this cell is a no-op. Without the
# guard, a second run would capture the already-patched forward as
# `simple_forward` and stack another dynamic wrapper on top of the first.
if not getattr(text_encoder.forward, "_dynamically_batched", False):
    # Keep a handle on the unpatched forward before it is replaced below.
    simple_forward = text_encoder.forward

    # The wrapped callable receives the features as one mapping, expands it
    # into keyword arguments for the original forward, and returns only the
    # "embeddings" entry of its result.
    dynamic_forward = inference.dynamically(lambda features: simple_forward(**features)["embeddings"])

    def _batched_forward(**features):
        """Forward replacement: route the keyword features through ``dynamic_forward``.

        Returns a dict containing only the ``"embeddings"`` key.
        """
        return {"embeddings": dynamic_forward(features)}

    # Marker read by the guard above on later runs of this cell.
    _batched_forward._dynamically_batched = True
    text_encoder.forward = _batched_forward

def run_benchmark(encode_fn=None, n_requests: int = 1000, max_workers: int = 128):
    """Submit ``n_requests`` single-sentence encode calls through a thread pool.

    Called with no arguments, this reproduces the original workload:
    1000 requests, 128 worker threads, each request being
    ``["Hello, world!"]`` passed to ``text_encoder.encode``.

    Args:
        encode_fn: Callable invoked once per request with a one-element list.
            When ``None``, ``text_encoder.encode`` is looked up at call time.
        n_requests: Number of requests to submit.
        max_workers: Number of threads in the pool.

    Returns:
        None. Results are consumed and discarded.
    """
    if encode_fn is None:
        encode_fn = text_encoder.encode
    requests = [["Hello, world!"] for _ in range(n_requests)]
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # list() consumes the result iterator, so every call has finished and
        # any exception raised in a worker is re-raised here.
        list(executor.map(encode_fn, requests))

# Time exactly one pass of run_benchmark and report the elapsed seconds.
execution_time = timeit.Timer(run_benchmark).timeit(number=1)
print(f"Execution time: {execution_time:.2f} seconds")


Execution time: 2.18 seconds
