In [None]:
import torch
import time
from transformers import AutoTokenizer, AutoModel

# Configuration
# Display name -> Hugging Face Hub model id of every encoder to benchmark.
MODELS = {
    # ~149M params
    "ModernBERT-base": "answerdotai/ModernBERT-base",
    # ~33M params
    "BGE-v1.5-small": "BAAI/bge-small-en-v1.5",
}

# Use the GPU in half precision when CUDA is available, else the CPU in
# full precision.
if torch.cuda.is_available():
    DEVICE, DTYPE = "cuda", torch.float16
else:
    DEVICE, DTYPE = "cpu", torch.float32

BATCH_SIZE = 1  # Key for RL (latency)
SEQ_LEN = 128  # Example length of a state/action text
ITERATIONS = 1  # Number of timed forward passes per model

def _sync_device():
    """Wait for queued CUDA work to finish; do nothing when running on CPU."""
    if DEVICE == "cuda":
        torch.cuda.synchronize()


def _masked_mean_pool(last_hidden_state, attention_mask):
    """Average token embeddings over the sequence axis, ignoring padding.

    Args:
        last_hidden_state: Encoder output; expected to be laid out as
            (batch, seq_len, hidden), as the sequence axis is reduced (dim=1).
        attention_mask: Tokenizer mask laid out as (batch, seq_len), with
            1 for real tokens and 0 for padding.

    Returns:
        One pooled vector per batch row, in the dtype of ``last_hidden_state``.
    """
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
    # Accumulate in float32 so half-precision sums cannot overflow.
    summed = (last_hidden_state * mask).sum(dim=1, dtype=torch.float32)
    # Clamp so a row that is entirely padding does not divide by zero.
    token_count = mask.sum(dim=1, dtype=torch.float32).clamp(min=1.0)
    return (summed / token_count).to(last_hidden_state.dtype)


def benchmark():
    """Measure the average embedding latency of every model in ``MODELS``.

    For each model: load the tokenizer and weights, build a dummy batch of
    ``BATCH_SIZE`` sentences padded to ``SEQ_LEN`` tokens, run one warmup
    forward pass, then time ``ITERATIONS`` passes of forward + mean pooling
    and print the average latency (ms) and the parameter count.

    Raises:
        ValueError: If ``ITERATIONS`` is smaller than 1.
    """
    # Fail before any model is downloaded instead of dividing by zero later.
    if ITERATIONS < 1:
        raise ValueError(f"ITERATIONS must be >= 1, got {ITERATIONS}")

    print(f"Benchmarking on {DEVICE} ({DTYPE})...\n")

    for name, model_id in MODELS.items():
        # Loading
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModel.from_pretrained(model_id, torch_dtype=DTYPE).to(DEVICE).eval()
        print(f"Loaded {name} model.")

        # Data preparation (dummy input)
        dummy_text = ["This is a sample sentence for RL state representation."] * BATCH_SIZE
        inputs = tokenizer(dummy_text, return_tensors="pt",
                           padding="max_length", truncation=True,
                           max_length=SEQ_LEN).to(DEVICE)

        with torch.no_grad():
            # Warmup: keeps one-off initialization cost out of the measurement.
            model(**inputs)

            # Timing. CUDA kernels run asynchronously, so synchronize on both
            # sides of the timed region to measure the actual compute time.
            _sync_device()
            start_time = time.perf_counter()

            for _ in range(ITERATIONS):
                # Embedding extraction: take last_hidden_state and mean-pool
                # it over the real (non-padding) tokens only.
                outputs = model(**inputs)
                _masked_mean_pool(outputs.last_hidden_state, inputs["attention_mask"])

            _sync_device()
            end_time = time.perf_counter()

        total_time = end_time - start_time
        avg_latency = (total_time / ITERATIONS) * 1000  # ms

        print(f"[{name}]")
        print(f"  Avg Latency: {avg_latency:.2f} ms")
        print(f"  Params: {sum(p.numel() for p in model.parameters()) / 1e6:.1f} M")
        print("-" * 30)

        # Release this model before the next one is loaded so that two models
        # are never resident at the same time.
        del model, tokenizer, inputs, outputs
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

# Run the benchmark only when this file is executed directly, not on import.
if __name__ == "__main__":
    benchmark()

Benchmarking on cpu (torch.float32)...

