# BERT Reranker Inference on AWS Neuron

## Introduction

This notebook demonstrates how to compile and run the [Alibaba-NLP/gte-multilingual-reranker-base](https://huggingface.co/Alibaba-NLP/gte-multilingual-reranker-base) BERT reranker model for accelerated inference on AWS Neuron (Inferentia2/Trainium).

This notebook was tested on an **inf2.8xlarge** instance.

## Setup

This notebook requires an Inferentia2 or Trainium instance with the following packages installed (the first two are part of the Neuron SDK):

- `torch-neuronx`
- `neuronx-cc`
- `transformers` (4.48 - 4.53)

For a step-by-step guide on launching an instance, see [Getting Started with Inferentia or Trainium](https://repost.aws/articles/ARgiH8VXXuQ22iSUmwX7ffiQ/getting-started-with-inferentia-or-trainium).

**Important**: The `transformers` version used at compile time affects the inference speed of the compiled model.
Versions 4.48 through 4.53 produce an optimal TorchScript graph for `torch_neuronx.trace()`. Other versions
(including the 4.56 shipped with the Neuron DLAMI) produce a ~20% slower compiled model. Pin the version before compiling:

```bash
pip install "transformers>=4.48,<=4.53.3"
```
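
To confirm the pin took effect before starting a long compile, a small check along these lines can help. This is a minimal sketch: the helper names (`major_minor`, `is_recommended`, `check_transformers`) are illustrative and not part of any library, and the range simply mirrors the 4.48 through 4.53 recommendation above.

```python
from importlib import metadata

# Recommended (major, minor) range, taken from the text above.
MIN_VERSION = (4, 48)
MAX_VERSION = (4, 53)


def major_minor(version: str) -> tuple[int, int]:
    """Extract (major, minor) from a version string such as '4.53.0'."""
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


def is_recommended(version: str) -> bool:
    """Return True if the version falls inside the recommended range."""
    return MIN_VERSION <= major_minor(version) <= MAX_VERSION


def check_transformers() -> None:
    """Print whether the installed transformers version is in range."""
    try:
        installed = metadata.version("transformers")
    except metadata.PackageNotFoundError:
        print("transformers is not installed")
        return
    status = "OK" if is_recommended(installed) else "outside the recommended 4.48-4.53 range"
    print(f"transformers {installed}: {status}")


check_transformers()
```

Only the major and minor components are compared, so every 4.53.x patch release is accepted.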

If you are using VS Code with Remote SSH on the AWS Neuron DLAMI, you can make the pre-installed environment available as a kernel:

```bash
ln -s /opt/aws_neuronx_venv_pytorch_inference_vllm_0_13 ~/.venv
```

In [None]:
import torch
import torch_neuronx
from transformers import AutoTokenizer, AutoModelForSequenceClassification

print(f"PyTorch version: {torch.__version__}")
print(f"torch-neuronx version: {torch_neuronx.__version__}")

## Configuration

Set the model and compiler parameters. For best performance, we recommend:

- `--auto-cast matmult`: Casts matrix multiplications to BF16 (57-120% faster than `none` in our tests, with minimal accuracy impact)
- `--optlevel 2`: Standard compiler optimizations

**DataParallel**: If using an instance with multiple Neuron cores (e.g., inf2.8xlarge has 2 cores),
you can use `torch_neuronx.DataParallel` to split batches across cores for higher throughput.

In [None]:
# Model configuration
MODEL_ID = "Alibaba-NLP/gte-multilingual-reranker-base"
SEQUENCE_LENGTH = 1024
BATCH_SIZE = 16
AUTOCAST = "matmult"  # Use 'matmult' for best performance, 'none' for full precision
OPTLEVEL = 2

print(f"Model: {MODEL_ID}")
print(f"Sequence Length: {SEQUENCE_LENGTH}")
print(f"Batch Size: {BATCH_SIZE}")
print(f"Autocast: {AUTOCAST}")
print(f"Optlevel: {OPTLEVEL}")

## Load Model and Tokenizer

In [None]:
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_ID,
    torchscript=True,
    trust_remote_code=True
)
model.eval()

# Model info
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Hidden size: {model.config.hidden_size}")
print(f"Layers: {model.config.num_hidden_layers}")

## Prepare Example Input

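Create example inputs for tracing. `torch_neuronx.trace()` compiles for fixed tensor shapes, so inputs at inference time must have the same batch size and sequence length as the example inputs used here.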
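
Because the compiled shape is fixed, a request with fewer pairs than the compiled batch size has to be padded before it reaches the model. The following is a minimal sketch of one way to do that on the Python side. The helpers `pad_to_batch` and `iter_batches` and the empty-string filler are assumptions for illustration, not part of the notebook or of `torch-neuronx`.

```python
def pad_to_batch(pairs, batch_size, filler=("", "")):
    """Pad (query, document) pairs up to a multiple of batch_size.

    Returns the padded list and the number of real pairs, so that scores
    for the filler entries can be discarded after inference.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    n_real = len(pairs)
    n_pad = (batch_size - n_real % batch_size) % batch_size
    return list(pairs) + [filler] * n_pad, n_real


def iter_batches(pairs, batch_size):
    """Yield consecutive slices of exactly batch_size items from a padded list."""
    for start in range(0, len(pairs), batch_size):
        yield pairs[start:start + batch_size]


# Example: 5 real pairs padded for a model compiled with batch size 4.
padded, n_real = pad_to_batch([("query", "document")] * 5, batch_size=4)
print(len(padded), n_real)
```

Each batch can then be tokenized with `padding="max_length"` as in the cell below, and only the first `n_real` scores kept.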

In [None]:
# Create example input
queries = ["What is machine learning?"] * BATCH_SIZE
docs = ["Machine learning is a subset of artificial intelligence."] * BATCH_SIZE

encoded = tokenizer(
    queries,
    docs,
    padding="max_length",
    max_length=SEQUENCE_LENGTH,
    truncation=True,
    return_tensors="pt"
)

example_inputs = (encoded["input_ids"], encoded["attention_mask"])
print(f"Input shape: {example_inputs[0].shape}")

## Compile Model

Compile the model using `torch_neuronx.trace()` with optimized settings.

In [None]:
import time

# Build compiler arguments
compiler_args = [f"--optlevel={OPTLEVEL}"]
if AUTOCAST == "matmult":
    compiler_args.extend(["--auto-cast", "matmult"])

print(f"Compiler args: {compiler_args}")

# Compile
print("\nCompiling model...")
start = time.time()
model_neuron = torch_neuronx.trace(
    model,
    example_inputs,
    compiler_args=compiler_args
)
compile_time = time.time() - start

# Save
output_path = f"bert_reranker_seq{SEQUENCE_LENGTH}_batch{BATCH_SIZE}.pt"
torch.jit.save(model_neuron, output_path)

print(f"Compilation time: {compile_time:.1f}s")
print(f"Saved to: {output_path}")

## Verify Compilation

Run a quick test to verify that the compiled model produces valid output.
The previous cell also saved the compiled model to disk, so future sessions can load it without recompiling.
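
Loading the saved artifact in a later session could look like the sketch below. The helper names `compiled_model_path` and `load_compiled_model` are illustrative. The file name pattern matches the one used in the compile cell, and the sketch assumes that `torch_neuronx` must be imported before `torch.jit.load` so the Neuron operators are available when the model is deserialized.

```python
def compiled_model_path(sequence_length: int, batch_size: int) -> str:
    """Rebuild the file name used by the compile cell."""
    return f"bert_reranker_seq{sequence_length}_batch{batch_size}.pt"


def load_compiled_model(path: str):
    """Load a previously compiled Neuron model from disk."""
    import torch
    import torch_neuronx  # noqa: F401  (assumed necessary to register Neuron ops)

    return torch.jit.load(path)


print(compiled_model_path(1024, 16))
```

On a Neuron instance, `model_neuron = load_compiled_model(compiled_model_path(SEQUENCE_LENGTH, BATCH_SIZE))` would then stand in for the compile step.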

In [None]:
# Test inference with the compiled model
with torch.no_grad():
    outputs = model_neuron(*example_inputs)

# Get score
if isinstance(outputs, tuple):
    logits = outputs[0]
else:
    logits = outputs

score = torch.sigmoid(logits[0]).item()
print(f"Sample score: {score:.4f}")
print("Model compiled and running successfully")
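
The cell above scores only the first pair. In a reranking setting, every candidate document is scored and the list is sorted by that score. Here is a minimal sketch of that step in plain Python; `rank_documents` is a hypothetical helper, and it assumes one logit per document, as this model produces.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def rank_documents(docs, logits):
    """Return (document, score) pairs sorted from most to least relevant."""
    if len(docs) != len(logits):
        raise ValueError("docs and logits must have the same length")
    scored = [(doc, sigmoid(logit)) for doc, logit in zip(docs, logits)]
    return sorted(scored, key=lambda item: item[1], reverse=True)


# Example with made-up logits for three candidate documents.
for doc, score in rank_documents(["doc A", "doc B", "doc C"], [0.0, 2.0, -1.0]):
    print(f"{score:.4f}  {doc}")
```

With the tensors from the cell above, `logits.squeeze(-1).tolist()` would give the list of floats this helper expects.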

## Benchmark Single-Core Performance

In [None]:
import numpy as np

WARMUP_ITERATIONS = 10
BENCHMARK_ITERATIONS = 50

print("Warming up...")
for _ in range(WARMUP_ITERATIONS):
    with torch.no_grad():
        _ = model_neuron(*example_inputs)

print(f"Benchmarking ({BENCHMARK_ITERATIONS} iterations)...")
latencies = []
for _ in range(BENCHMARK_ITERATIONS):
    start = time.perf_counter()
    with torch.no_grad():
        _ = model_neuron(*example_inputs)
    end = time.perf_counter()
    latencies.append((end - start) * 1000)

latencies = np.array(latencies)
p50 = np.percentile(latencies, 50)
p90 = np.percentile(latencies, 90)
p99 = np.percentile(latencies, 99)
throughput = (BATCH_SIZE * 1000) / np.mean(latencies)

print("\n=== Single-Core Results ===")
print(f"Latency p50: {p50:.2f} ms")
print(f"Latency p90: {p90:.2f} ms")
print(f"Latency p99: {p99:.2f} ms")
print(f"Throughput: {throughput:.2f} queries/second")
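
The same measurement loop is repeated for DataParallel in the next section. One way to avoid the duplication is to factor it into a helper, sketched here with the standard library only. The names `percentile` and `benchmark` are illustrative. The percentile uses linear interpolation, which is also the default method of `np.percentile`, and throughput is computed from the mean latency as in the cell above.

```python
import time
from statistics import mean


def percentile(sorted_values, q):
    """Linearly interpolated percentile (q in [0, 100]) of pre-sorted values."""
    if not sorted_values:
        raise ValueError("need at least one value")
    position = (len(sorted_values) - 1) * q / 100.0
    lower = int(position)
    upper = min(lower + 1, len(sorted_values) - 1)
    fraction = position - lower
    return sorted_values[lower] + (sorted_values[upper] - sorted_values[lower]) * fraction


def benchmark(run_once, batch_size, warmup=10, iterations=50):
    """Time run_once() and return latency percentiles (ms) and throughput."""
    for _ in range(warmup):
        run_once()
    latencies_ms = []
    for _ in range(iterations):
        start = time.perf_counter()
        run_once()
        latencies_ms.append((time.perf_counter() - start) * 1000)
    latencies_ms.sort()
    return {
        "p50_ms": percentile(latencies_ms, 50),
        "p90_ms": percentile(latencies_ms, 90),
        "p99_ms": percentile(latencies_ms, 99),
        # Items processed per second, based on the mean latency.
        "throughput_qps": batch_size * 1000 / mean(latencies_ms),
    }


# Stand-in workload so the sketch runs without Neuron hardware.
stats = benchmark(lambda: sum(range(10_000)), batch_size=16, warmup=2, iterations=10)
print(stats)
```

On a Neuron instance the callable would wrap the model call, for example `lambda: model_neuron(*example_inputs)` executed inside `torch.no_grad()`.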

## Benchmark with DataParallel

`torch_neuronx.DataParallel` loads the same compiled model onto each available Neuron core
and runs the copies in parallel. The input batch is split along dimension 0, so each core processes
its own full batch independently and total throughput scales close to linearly with the number of cores.

On inf2.8xlarge (2 cores), each core receives its own batch of 16, which nearly doubles throughput (about 1.88x in our measurements).
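
Conceptually, the split works like the sketch below: a batch of `BATCH_SIZE * num_cores` items is cut into contiguous shards, one per core. This is only an illustration of the idea in plain Python. `split_batch` is a hypothetical helper, not the `torch_neuronx.DataParallel` implementation, and the library's handling of uneven batch sizes may differ.

```python
def split_batch(items, num_cores):
    """Split items into num_cores contiguous shards of near-equal size."""
    if num_cores <= 0:
        raise ValueError("num_cores must be positive")
    base, extra = divmod(len(items), num_cores)
    shards = []
    start = 0
    for core in range(num_cores):
        # The first `extra` shards take one additional item each.
        size = base + (1 if core < extra else 0)
        shards.append(items[start:start + size])
        start += size
    return shards


# 32 inputs on 2 cores: each core receives a full batch of 16.
shards = split_batch(list(range(32)), num_cores=2)
print([len(shard) for shard in shards])
```

Sizing the total input as a multiple of the compiled batch size, as the cell below does, keeps every shard at the shape the model was compiled for.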

In [None]:
import json
import subprocess

def get_neuron_core_count():
    """Detect the number of Neuron cores using neuron-ls."""
    try:
        result = subprocess.run(
            ["neuron-ls", "--json-output"],
            capture_output=True, text=True, timeout=10
        )
        if result.returncode == 0:
            devices = json.loads(result.stdout)
            return sum(d["nc_count"] for d in devices)
    except Exception:
        pass
    return 1

num_cores = get_neuron_core_count()
print(f"Neuron cores detected: {num_cores}")

if num_cores > 1:
    # Load the same compiled model onto all cores
    model_dp = torch_neuronx.DataParallel(model_neuron)
    
    # Each core gets a full batch, so total input is BATCH_SIZE * num_cores
    dp_total_batch = BATCH_SIZE * num_cores
    dp_queries = ["What is machine learning?"] * dp_total_batch
    dp_docs = ["Machine learning is a subset of artificial intelligence."] * dp_total_batch
    dp_encoded = tokenizer(
        dp_queries, dp_docs,
        padding="max_length", max_length=SEQUENCE_LENGTH,
        truncation=True, return_tensors="pt"
    )
    dp_inputs = (dp_encoded["input_ids"], dp_encoded["attention_mask"])
    print(f"DataParallel input shape: {dp_inputs[0].shape}  ({BATCH_SIZE} per core x {num_cores} cores)")
    
    print("\nWarming up DataParallel...")
    for _ in range(WARMUP_ITERATIONS):
        with torch.no_grad():
            _ = model_dp(*dp_inputs)
    
    print(f"Benchmarking DataParallel ({BENCHMARK_ITERATIONS} iterations)...")
    latencies_dp = []
    for _ in range(BENCHMARK_ITERATIONS):
        start = time.perf_counter()
        with torch.no_grad():
            _ = model_dp(*dp_inputs)
        end = time.perf_counter()
        latencies_dp.append((end - start) * 1000)
    
    latencies_dp = np.array(latencies_dp)
    p50_dp = np.percentile(latencies_dp, 50)
    p90_dp = np.percentile(latencies_dp, 90)
    p99_dp = np.percentile(latencies_dp, 99)
    throughput_dp = (dp_total_batch * 1000) / np.mean(latencies_dp)
    
    print(f"\n=== DataParallel Results ({num_cores} cores) ===")
    print(f"Latency p50: {p50_dp:.2f} ms")
    print(f"Latency p90: {p90_dp:.2f} ms")
    print(f"Latency p99: {p99_dp:.2f} ms")
    print(f"Throughput: {throughput_dp:.2f} queries/second")
    print(f"\nSpeedup vs single-core: {throughput_dp/throughput:.2f}x")
else:
    print("Only 1 Neuron core available - skipping DataParallel benchmark.")
    print("Use an instance with multiple cores (e.g., inf2.8xlarge) for DataParallel.")

## Results Summary

Measured on **inf2.8xlarge** (torch-neuronx 2.9.0, neuronx-cc 2.22, transformers 4.53.0, sa-east-1):

| Configuration | Batch Size | Seq Length | Latency (p50) | Throughput |
|--------------|------------|------------|---------------|------------|
| Single-Core | 16 | 1024 | 214.51 ms | 74.59 qps |
| DataParallel (2 cores) | 16 per core | 1024 | 228.15 ms | 140.25 qps |

**Key Findings**:
- `--auto-cast matmult` provides a 57-120% speedup vs `none` with negligible accuracy impact
- DataParallel provides ~1.88x throughput by running each core independently
- Model size with `--auto-cast matmult`: ~720 MB (vs ~830 MB without)
- The `transformers` version used at compile time significantly affects performance (see Setup section)
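
The headline numbers can be cross-checked with a few lines of arithmetic. The sketch below uses only the values from the table above. Note that the notebook computes throughput from the mean latency, so recomputing it from p50 gives slightly different figures.

```python
# Values copied from the results table above.
single_core_qps = 74.59
data_parallel_qps = 140.25
single_core_p50_ms = 214.51
data_parallel_p50_ms = 228.15
batch_size = 16
num_cores = 2

# Throughput gain and how close it is to ideal linear scaling.
speedup = data_parallel_qps / single_core_qps
scaling_efficiency = speedup / num_cores

# Approximate throughput recomputed from p50 latency instead of the mean.
single_from_p50 = batch_size * 1000 / single_core_p50_ms
dp_from_p50 = batch_size * num_cores * 1000 / data_parallel_p50_ms

print(f"Speedup: {speedup:.2f}x")
print(f"Scaling efficiency: {scaling_efficiency:.0%}")
print(f"Single-core qps from p50: {single_from_p50:.2f}")
print(f"DataParallel qps from p50: {dp_from_p50:.2f}")
```

The efficiency figure shows how much of the ideal 2x is realized; p50 latency is higher in the DataParallel run (228.15 ms vs 214.51 ms), which accounts for the gap.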