In [13]:
import torch
from transformers import AutoTokenizer, AutoModel
from batch import inference
import time

# Checkpoint name shared by the tokenizer and the model weights.
model_name = "bert-base-uncased"

# Choose the compute device: CUDA when torch reports it available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pre-trained tokenizer and model for that checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Place the model on the selected device.
model.to(device)

# Batch inference entry point. The decorator is configured with a batch size
# of 16 and a 10 ms timeout; the exact batching rules live in the `batch`
# package, not in this file.
@inference.dynamically(batch_size=16, timeout_ms=10.0)
def forward(features):
    """Run ``model`` on ``features`` and return its last hidden state.

    ``features`` is unpacked into keyword arguments for ``model``. Gradient
    tracking is disabled for the duration of the model call.
    """
    with torch.no_grad():
        return model(**features).last_hidden_state

def batch_encode(texts):
    """Tokenize ``texts``, run them through ``forward`` and return NumPy output.

    The texts are tokenized with padding and truncation enabled and returned
    as PyTorch tensors, each of which is moved to ``device``. The tensors are
    then handed to ``forward`` as keyword arguments, and the result is copied
    to the CPU and converted to a NumPy array.

    NOTE(review): ``forward`` is declared with a single ``features`` parameter
    but is called here with keyword tensors; presumably the decorator applied
    to it regroups them -- confirm against the ``batch`` package.
    """
    encoded = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    on_device = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        hidden_states = forward(**on_device)

    # Bring the last hidden states back to host memory as a NumPy array.
    return hidden_states.cpu().numpy()

# Example usage: 4 distinct sentences repeated 25 times gives 100 inputs.
texts = [
    "Hello, how are you?",
    "The weather is nice today.",
    "I love programming in Python!",
    "Machine learning is fascinating.",
] * 25  # Repeat the list 25 times to create 100 items

# Measure time for batched inference.
# time.perf_counter() is a monotonic, high-resolution clock intended for
# measuring intervals; time.time() follows the wall clock and can jump if the
# system clock is adjusted, which would corrupt the measured duration.
start_time = time.perf_counter()
batched_results = batch_encode(texts)
end_time = time.perf_counter()

print(f"Batched inference time: {end_time - start_time:.2f} seconds")
print(f"Number of processed items: {len(texts)}")
# batched_results[0] is the output for the first input text.
print(f"Shape of the first result: {batched_results[0].shape}")

# Print batch processing statistics exposed by the decorated `forward`.
print("\nBatch processing statistics:")
print(forward.stats)


Batched inference time: 0.39 seconds
Number of processed items: 100
Shape of the first result: (8, 768)

Batch processing statistics:
BatchProcessorStats(queue_size=0, total_processed=100, total_batches=7, avg_batch_size=14.285714285714286, avg_processing_time=0.05156169618879046)
