# BERT Reranker Inference on AWS Trainium2

## Introduction

This notebook demonstrates how to compile and run the [Alibaba-NLP/gte-multilingual-reranker-base](https://huggingface.co/Alibaba-NLP/gte-multilingual-reranker-base) BERT reranker model for accelerated inference on AWS Trainium2.

This notebook was tested on a **trn2.3xlarge** instance with **LNC=1** (Logical NeuronCore Configuration 1),
which exposes each of the 8 physical NeuronCores as its own logical core.

## Setup

This notebook requires a Trainium2 instance with the following packages installed (the first two ship with the Neuron SDK):

- `torch-neuronx`
- `neuronx-cc`
- `transformers` (4.48 through 4.53)

For a step-by-step guide on launching an instance, see [Getting Started with Inferentia or Trainium](https://repost.aws/articles/ARgiH8VXXuQ22iSUmwX7ffiQ/getting-started-with-inferentia-or-trainium).

**Important**: The version of `transformers` affects the performance of the compiled model. Versions 4.48
through 4.53 produce an optimal TorchScript graph for `torch_neuronx.trace()`. Other versions (including the
4.56 shipped with the Neuron DLAMI) produce a compiled model that is ~20% slower. Pin the version before compiling:

```bash
pip install "transformers>=4.48,<=4.53.3"
```
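
A notebook can also guard against an accidental upgrade before compiling. The sketch below is illustrative rather than part of the original workflow: `is_supported_transformers` is a hypothetical helper, and the 4.48 to 4.53 window is simply the range stated above.

```python
import re
from importlib import metadata

# Range taken from the note above; treat it as specific to this model and SDK release.
MIN_VERSION, MAX_VERSION = (4, 48), (4, 53)


def is_supported_transformers(version: str) -> bool:
    """Return True if `version` (e.g. "4.53.3") falls in the 4.48 to 4.53 window."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        return False
    major_minor = (int(match.group(1)), int(match.group(2)))
    return MIN_VERSION <= major_minor <= MAX_VERSION


if __name__ == "__main__":
    try:
        installed = metadata.version("transformers")
    except metadata.PackageNotFoundError:
        print("transformers is not installed")
    else:
        status = "OK" if is_supported_transformers(installed) else "outside the tested range"
        print(f"transformers {installed}: {status}")
```

Only the major and minor components are compared, so any 4.53.x patch release passes the check.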

**LNC Configuration**: This notebook uses LNC=1. Before running, set the runtime environment variable:

```bash
export NEURON_LOGICAL_NC_CONFIG=1
```

If you are using VS Code with Remote SSH on the AWS Neuron DLAMI, you can make the pre-installed environment available as a kernel:

```bash
ln -s /opt/aws_neuronx_venv_pytorch_inference_vllm_0_13 ~/.venv
```

In [1]:
import os
os.environ["NEURON_LOGICAL_NC_CONFIG"] = "1"

import torch
import torch_neuronx
from transformers import AutoTokenizer, AutoModelForSequenceClassification

print(f"PyTorch version: {torch.__version__}")
print(f"torch-neuronx version: {torch_neuronx.__version__}")
print(f"NEURON_LOGICAL_NC_CONFIG: {os.environ.get('NEURON_LOGICAL_NC_CONFIG')}")

PyTorch version: 2.9.0+cu128
torch-neuronx version: 2.9.0.2.11.19912+e48cd891
NEURON_LOGICAL_NC_CONFIG: 1


## Configuration

Set the model parameters. For best performance on Trainium2, we recommend the following compiler flags:

- `--target trn2`: Targets the Trainium2 hardware
- `--lnc 1`: Uses LNC=1, giving 8 logical cores on trn2.3xlarge (1 physical core per logical core)
- `--auto-cast matmult`: Enables BF16 for matrix multiplications (~60% faster, negligible accuracy impact)
- `--optlevel 2`: Standard compiler optimizations

### Why LNC=1?

Trainium2 supports two LNC modes:
- **LNC=1**: Each physical NeuronCore is one logical core (8 logical cores on trn2.3xlarge)
- **LNC=2**: Two physical NeuronCores form one logical core (4 logical cores on trn2.3xlarge)

For this model (~306M parameters), LNC=1 is significantly more efficient. LNC=2 doubles the
hardware per core without a proportional performance gain, resulting in lower per-physical-core
efficiency.

**DataParallel**: With LNC=1 on trn2.3xlarge, `torch_neuronx.DataParallel` loads the compiled model
onto all 8 cores. Each core runs its own full batch=16 independently.
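
The core counts and the total DataParallel batch follow from simple arithmetic. A minimal sketch, where `logical_core_count` is a hypothetical helper, the 8 physical cores are the trn2.3xlarge figure quoted above, and the per-core batch is assumed to be the same in both modes:

```python
def logical_core_count(physical_cores: int, lnc: int) -> int:
    """Number of logical NeuronCores when `lnc` physical cores form one logical core."""
    if lnc not in (1, 2):
        raise ValueError("this sketch only covers LNC=1 and LNC=2")
    if physical_cores % lnc != 0:
        raise ValueError("physical_cores must be divisible by lnc")
    return physical_cores // lnc


PHYSICAL_CORES = 8   # trn2.3xlarge, as stated above
BATCH_PER_CORE = 16  # batch size the model is compiled for

for lnc in (1, 2):
    cores = logical_core_count(PHYSICAL_CORES, lnc)
    # DataParallel feeds one full batch to every logical core.
    total = cores * BATCH_PER_CORE
    print(f"LNC={lnc}: {cores} logical cores, {total} pairs per DataParallel call")
```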

In [2]:
# Model configuration
MODEL_ID = "Alibaba-NLP/gte-multilingual-reranker-base"
SEQUENCE_LENGTH = 1024
BATCH_SIZE = 16
AUTOCAST = "matmult"  # Use 'matmult' for best performance, 'none' for full precision
OPTLEVEL = 2
LNC = 1

print(f"Model: {MODEL_ID}")
print(f"Sequence Length: {SEQUENCE_LENGTH}")
print(f"Batch Size: {BATCH_SIZE}")
print(f"Autocast: {AUTOCAST}")
print(f"Optlevel: {OPTLEVEL}")
print(f"LNC: {LNC}")

Model: Alibaba-NLP/gte-multilingual-reranker-base
Sequence Length: 1024
Batch Size: 16
Autocast: matmult
Optlevel: 2
LNC: 1


## Load Model and Tokenizer

In [3]:
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_ID,
    torchscript=True,
    trust_remote_code=True
)
model.eval()

# Model info
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Hidden size: {model.config.hidden_size}")
print(f"Layers: {model.config.num_hidden_layers}")

Parameters: 305,959,681
Hidden size: 768
Layers: 12


## Prepare Example Input

Create example inputs for tracing. `torch_neuronx.trace()` compiles for fixed input shapes, so inputs at inference time must match the shape used here (`BATCH_SIZE` x `SEQUENCE_LENGTH`).

In [4]:
# Create example input
queries = ["What is machine learning?"] * BATCH_SIZE
docs = ["Machine learning is a subset of artificial intelligence."] * BATCH_SIZE

encoded = tokenizer(
    queries,
    docs,
    padding="max_length",
    max_length=SEQUENCE_LENGTH,
    truncation=True,
    return_tensors="pt"
)

example_inputs = (encoded["input_ids"], encoded["attention_mask"])
print(f"Input shape: {example_inputs[0].shape}")

Input shape: torch.Size([16, 1024])
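
Because the compiled graph has a fixed shape, a request with fewer than `BATCH_SIZE` pairs has to be padded up to the full batch, and the padded rows dropped from the output. One way to do this is sketched below with a hypothetical `pad_pairs` helper that is not part of the original notebook; using empty strings as filler is an assumption.

```python
from typing import List, Tuple

Pair = Tuple[str, str]  # (query, document)


def pad_pairs(pairs: List[Pair], batch_size: int, filler: Pair = ("", "")) -> Tuple[List[Pair], int]:
    """Pad `pairs` to exactly `batch_size` entries.

    Returns the padded list and the number of real entries, so the caller can
    slice the model output back down with `logits[:n_real]`.
    """
    if len(pairs) > batch_size:
        raise ValueError(f"got {len(pairs)} pairs, but the model was compiled for {batch_size}")
    n_real = len(pairs)
    padded = list(pairs) + [filler] * (batch_size - n_real)
    return padded, n_real


pairs = [("What is machine learning?", "Machine learning is a subset of artificial intelligence.")] * 5
padded, n_real = pad_pairs(pairs, batch_size=16)
# The tokenizer takes queries and documents as two separate lists.
queries, docs = map(list, zip(*padded))
print(f"{len(padded)} rows sent to the model, {n_real} real")
```

The resulting `queries` and `docs` lists can be passed to the tokenizer exactly as in the cell above.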


## Compile Model

Compile the model using `torch_neuronx.trace()` with Trainium2-specific settings.

The `--target trn2` and `--lnc 1` flags are required for Trainium2 compilation.
A model compiled with a given LNC value must be run with the matching
`NEURON_LOGICAL_NC_CONFIG` environment variable.
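
A mismatch between the compile-time LNC and the runtime setting is easier to diagnose if it is caught before the model is loaded. A minimal sketch, assuming the runtime reads `NEURON_LOGICAL_NC_CONFIG` as described above; `check_lnc` is a hypothetical helper, and since this notebook does not cover what happens when the variable is unset, the sketch treats that case as an error.

```python
import os
from typing import Mapping, Optional


def check_lnc(compiled_lnc: int, env: Optional[Mapping[str, str]] = None) -> None:
    """Raise if NEURON_LOGICAL_NC_CONFIG does not match the LNC used at compile time."""
    env = os.environ if env is None else env
    runtime_lnc = env.get("NEURON_LOGICAL_NC_CONFIG")
    if runtime_lnc is None:
        raise RuntimeError("NEURON_LOGICAL_NC_CONFIG is not set; export it before loading the model")
    if runtime_lnc != str(compiled_lnc):
        raise RuntimeError(
            f"model compiled with --lnc {compiled_lnc}, "
            f"but NEURON_LOGICAL_NC_CONFIG={runtime_lnc}"
        )


# Passing the environment explicitly keeps the example runnable on any machine.
check_lnc(1, env={"NEURON_LOGICAL_NC_CONFIG": "1"})
print("LNC setting matches")
```

In the notebook itself, `check_lnc(LNC)` with no second argument would read the real process environment.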

In [5]:
import time

# Build compiler arguments
compiler_args = [
    f"--optlevel={OPTLEVEL}",
    "--target", "trn2",
    "--lnc", str(LNC),
]
if AUTOCAST == "matmult":
    compiler_args.extend(["--auto-cast", "matmult"])

print(f"Compiler args: {compiler_args}")

# Compile
print("\nCompiling model (this takes ~8 minutes)...")
start = time.time()
model_neuron = torch_neuronx.trace(
    model,
    example_inputs,
    compiler_args=compiler_args
)
compile_time = time.time() - start

# Save
output_path = f"bert_reranker_trn2_lnc{LNC}_seq{SEQUENCE_LENGTH}_batch{BATCH_SIZE}.pt"
torch.jit.save(model_neuron, output_path)

import os
model_size_mb = os.path.getsize(output_path) / (1024 * 1024)
print(f"Compilation time: {compile_time:.1f}s")
print(f"Model size: {model_size_mb:.1f} MB")
print(f"Saved to: {output_path}")

Compiler args: ['--optlevel=2', '--target', 'trn2', '--lnc', '1', '--auto-cast', 'matmult']

Compiling model (this takes ~8 minutes)...


......................
Completed run_backend_driver.
Compiler status PASS


Compilation time: 434.2s
Model size: 717.7 MB
Saved to: bert_reranker_trn2_lnc1_seq1024_batch16.pt


## Verify Compilation

Run a quick test to verify the compiled model produces valid output.
The model is saved to disk so it can be loaded in future sessions without recompiling.
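
Reloading in a later session could look like the sketch below. It assumes the artifact was written with `torch.jit.save` as in the compile cell, and that `torch_neuronx` is imported before loading so the Neuron operators are registered; `load_compiled` is a hypothetical helper name.

```python
from pathlib import Path


def load_compiled(path: str):
    """Load a Neuron-compiled TorchScript artifact saved with torch.jit.save."""
    artifact = Path(path)
    if not artifact.is_file():
        raise FileNotFoundError(f"{artifact} not found; run the compile cell first")
    # Imported lazily so the path check above works even without the Neuron SDK.
    import torch
    import torch_neuronx  # noqa: F401  (imported for its side effect of registering Neuron ops)
    return torch.jit.load(str(artifact))


# Usage on a Trainium2 instance, with NEURON_LOGICAL_NC_CONFIG=1 already set:
# model_neuron = load_compiled("bert_reranker_trn2_lnc1_seq1024_batch16.pt")
```

The reloaded model expects the same input shape it was compiled with, here `[16, 1024]`.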

In [6]:
# Test inference with the compiled model
with torch.no_grad():
    outputs = model_neuron(*example_inputs)

# Get score
if isinstance(outputs, tuple):
    logits = outputs[0]
else:
    logits = outputs

score = torch.sigmoid(logits[0]).item()
print(f"Sample score: {score:.4f}")
print("Model compiled and running successfully on Trainium2")

Sample score: 0.9479
Model compiled and running successfully on Trainium2
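
In a retrieval pipeline, these scores are used to reorder candidate documents. A minimal sketch in plain Python, assuming one logit per query-document pair as in the cell above; `rank_documents` is a hypothetical helper and the example logits are made up for illustration, not measured.

```python
import math
from typing import List, Sequence, Tuple


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def rank_documents(docs: Sequence[str], logits: Sequence[float]) -> List[Tuple[str, float]]:
    """Return (document, score) pairs sorted from most to least relevant."""
    if len(docs) != len(logits):
        raise ValueError("need exactly one logit per document")
    scored = [(doc, sigmoid(logit)) for doc, logit in zip(docs, logits)]
    return sorted(scored, key=lambda item: item[1], reverse=True)


docs = ["doc A", "doc B", "doc C"]
logits = [0.3, 2.9, -1.2]  # made-up values for illustration
for doc, score in rank_documents(docs, logits):
    print(f"{score:.4f}  {doc}")
```

With the compiled model, the list of logits could be obtained with something like `logits.flatten().tolist()`, assuming the output holds one value per pair.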


## Benchmark Single-Core Performance

In [7]:
import numpy as np

WARMUP_ITERATIONS = 10
BENCHMARK_ITERATIONS = 50

print("Warming up...")
for _ in range(WARMUP_ITERATIONS):
    with torch.no_grad():
        _ = model_neuron(*example_inputs)

print(f"Benchmarking ({BENCHMARK_ITERATIONS} iterations)...")
latencies = []
for _ in range(BENCHMARK_ITERATIONS):
    start = time.perf_counter()
    with torch.no_grad():
        _ = model_neuron(*example_inputs)
    end = time.perf_counter()
    latencies.append((end - start) * 1000)

latencies = np.array(latencies)
p50 = np.percentile(latencies, 50)
p90 = np.percentile(latencies, 90)
p99 = np.percentile(latencies, 99)
throughput = (BATCH_SIZE * 1000) / np.mean(latencies)

print("\n=== Single-Core Results ===")
print(f"Latency p50: {p50:.2f} ms")
print(f"Latency p90: {p90:.2f} ms")
print(f"Latency p99: {p99:.2f} ms")
print(f"Throughput: {throughput:.2f} queries/second")

Warming up...


Benchmarking (50 iterations)...



=== Single-Core Results ===
Latency p50: 191.28 ms
Latency p90: 191.34 ms
Latency p99: 191.38 ms
Throughput: 83.64 queries/second


## Benchmark with DataParallel

`torch_neuronx.DataParallel` loads the same compiled model onto each available logical NeuronCore
and runs them in parallel. Each core processes its own full batch independently, multiplying
total throughput by the number of cores.

On trn2.3xlarge with LNC=1 (8 logical cores), each core gets its own batch=16.

**Important**: When consecutive device IDs are used (the default), `DataParallel` loads all cores
via a single batch call and sets `num_workers=2`. This limits the thread pool to 2 concurrent
dispatches, serializing execution across the remaining cores. Set `model_dp.num_workers` to at
least the number of cores for full parallel throughput.
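
The effect of `num_workers` can be estimated with a simple wave model. This is a back-of-the-envelope sketch, not a description of `DataParallel` internals: it assumes each worker thread drives one core at a time and that every core takes the same time per batch. `estimated_dp_latency_ms` is a hypothetical helper.

```python
import math


def estimated_dp_latency_ms(per_core_ms: float, num_cores: int, num_workers: int) -> float:
    """Estimate wall-clock latency when `num_workers` threads dispatch to `num_cores` cores."""
    if num_cores < 1 or num_workers < 1:
        raise ValueError("num_cores and num_workers must be positive")
    # Workers beyond the number of cores have nothing to do.
    active_workers = min(num_workers, num_cores)
    waves = math.ceil(num_cores / active_workers)
    return waves * per_core_ms


PER_CORE_MS = 191.3  # single-core p50 measured earlier in this notebook, rounded
BATCH = 16
NUM_CORES = 8
for workers in (2, 8):
    latency = estimated_dp_latency_ms(PER_CORE_MS, NUM_CORES, workers)
    throughput = NUM_CORES * BATCH * 1000 / latency
    print(f"num_workers={workers}: ~{latency:.0f} ms per call, ~{throughput:.0f} pairs/s")
```

Under these assumptions, two workers need four waves to cover eight cores, which caps throughput at roughly twice the single-core figure.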

In [8]:
import json
import subprocess

def get_neuron_core_count():
    """Detect the number of Neuron logical cores using neuron-ls."""
    try:
        result = subprocess.run(
            ["neuron-ls", "--json-output"],
            capture_output=True, text=True, timeout=10
        )
        if result.returncode == 0:
            devices = json.loads(result.stdout)
            return sum(d["nc_count"] for d in devices)
    except Exception:
        pass
    return 1

num_cores = get_neuron_core_count()
print(f"Neuron logical cores detected: {num_cores}")

if num_cores > 1:
    # Load the same compiled model onto all cores
    model_dp = torch_neuronx.DataParallel(model_neuron)
    
    # Fix: default num_workers=2 serializes across cores; set to num_cores for full parallelism
    model_dp.num_workers = num_cores
    print(f"DataParallel num_workers set to: {model_dp.num_workers}")
    
    # Each core gets a full batch, so total input is BATCH_SIZE * num_cores
    dp_total_batch = BATCH_SIZE * num_cores
    dp_queries = ["What is machine learning?"] * dp_total_batch
    dp_docs = ["Machine learning is a subset of artificial intelligence."] * dp_total_batch
    dp_encoded = tokenizer(
        dp_queries, dp_docs,
        padding="max_length", max_length=SEQUENCE_LENGTH,
        truncation=True, return_tensors="pt"
    )
    dp_inputs = (dp_encoded["input_ids"], dp_encoded["attention_mask"])
    print(f"DataParallel input shape: {dp_inputs[0].shape}  ({BATCH_SIZE} per core x {num_cores} cores)")
    
    print("\nWarming up DataParallel...")
    for _ in range(WARMUP_ITERATIONS):
        with torch.no_grad():
            _ = model_dp(*dp_inputs)
    
    print(f"Benchmarking DataParallel ({BENCHMARK_ITERATIONS} iterations)...")
    latencies_dp = []
    for _ in range(BENCHMARK_ITERATIONS):
        start = time.perf_counter()
        with torch.no_grad():
            _ = model_dp(*dp_inputs)
        end = time.perf_counter()
        latencies_dp.append((end - start) * 1000)
    
    latencies_dp = np.array(latencies_dp)
    p50_dp = np.percentile(latencies_dp, 50)
    p90_dp = np.percentile(latencies_dp, 90)
    p99_dp = np.percentile(latencies_dp, 99)
    throughput_dp = (dp_total_batch * 1000) / np.mean(latencies_dp)
    
    print(f"\n=== DataParallel Results ({num_cores} cores) ===")
    print(f"Latency p50: {p50_dp:.2f} ms")
    print(f"Latency p90: {p90_dp:.2f} ms")
    print(f"Latency p99: {p99_dp:.2f} ms")
    print(f"Throughput: {throughput_dp:.2f} queries/second")
    print(f"\nSpeedup vs single-core: {throughput_dp/throughput:.2f}x")
else:
    print("Only 1 Neuron core available - skipping DataParallel benchmark.")
    print("Use an instance with multiple cores for DataParallel.")

Neuron logical cores detected: 8


DataParallel num_workers set to: 8
DataParallel input shape: torch.Size([128, 1024])  (16 per core x 8 cores)

Warming up DataParallel...


Benchmarking DataParallel (50 iterations)...



=== DataParallel Results (8 cores) ===
Latency p50: 201.98 ms
Latency p90: 202.22 ms
Latency p99: 202.32 ms
Throughput: 633.83 queries/second

Speedup vs single-core: 7.58x


## Results Summary

Measured on **trn2.3xlarge** with LNC=1 (torch-neuronx 2.9.0, neuronx-cc 2.22, transformers 4.53.3):

| Configuration | Batch Size | Seq Length | Latency (p50) | Throughput |
|--------------|------------|------------|---------------|------------|
| Single-Core | 16 | 1024 | 191.3 ms | 83.6 qps |
| DataParallel (8 cores) | 16 per core | 1024 | 202.0 ms | 633.8 qps |
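
The speedup and per-core scaling efficiency follow directly from the two throughput figures. A short worked calculation using the values printed by the benchmark cells above (`scaling_efficiency` is a hypothetical helper):

```python
from typing import Tuple


def scaling_efficiency(single_core_qps: float, parallel_qps: float, num_cores: int) -> Tuple[float, float]:
    """Return (speedup, efficiency), where efficiency is the fraction of ideal linear scaling."""
    speedup = parallel_qps / single_core_qps
    return speedup, speedup / num_cores


# Throughput values as printed by the single-core and DataParallel benchmark cells.
speedup, efficiency = scaling_efficiency(83.64, 633.83, num_cores=8)
print(f"Speedup: {speedup:.2f}x, efficiency: {efficiency:.1%}")
```

This works out to about 7.58x, or roughly 95% of ideal linear scaling across 8 cores.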

**Compiler flags explained**:
- `--target trn2`: Generates code optimized for Trainium2 hardware
- `--lnc 1`: Maps one physical NeuronCore to one logical core (8 cores on trn2.3xlarge)
- `--auto-cast matmult`: Casts matrix multiplications to BF16, providing ~60% speedup with negligible accuracy impact (<0.0004)
- `--optlevel 2`: Enables standard compiler optimizations (optlevel 3 provides no additional benefit for this model)

**Key Findings**:
- LNC=1 is significantly more efficient than LNC=2 for this model size (~306M parameters)
- DataParallel provides 7.6x throughput scaling across 8 cores (requires setting `num_workers`)
- Default `DataParallel` `num_workers=2` limits scaling to ~2x; set `num_workers >= num_cores`
- Model size with autocast=matmult: ~718 MB
- The `transformers` version used at compile time significantly affects performance (see Setup section)