# Cross-Hardware Serving Benchmark

**Question**: What's the overhead of training on one hardware and serving on another?

## The Matrix

|  | **Serve GPU** | **Serve TPU** |
|--|---------------|---------------|
| **Train GPU** | Baseline | PyTorch→JAX |
| **Train TPU** | JAX→ONNX→TensorRT | Native |

## Conversion Paths

1. **PyTorch → JAX** (GPU train → TPU serve)
   - `torch` → `jax.numpy` weight conversion
   - Or: `torch` → ONNX → `jax_onnx`
   - Friction: High (only the weights transfer; the model itself must be reimplemented in JAX/Flax with matching parameter layouts)

2. **JAX → ONNX → TensorRT** (TPU train → GPU serve)
   - `jax2onnx`, or `jax2tf` followed by `tf2onnx` (also usable for Flax models)
   - ONNX → TensorRT for GPU optimization
   - Friction: Medium (well-supported path)

3. **PyTorch → ONNX → TensorRT** (GPU train → GPU serve optimized)
   - Standard optimization path
   - Friction: Low (native ecosystem)
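For path 1, weight conversion is largely a renaming and layout exercise. Below is a minimal sketch that assumes a Flax `nn.Dense`-style target, whose kernel is stored as `(in_features, out_features)`, whereas PyTorch `nn.Linear.weight` is `(out_features, in_features)`. The helper `torch_linear_to_flax_dense` and its `weight` → `kernel` naming scheme are hypothetical. Real checkpoints need a per-architecture mapping: embeddings, for instance, must not be transposed, and LayerNorm weights become `scale`. The arrays are plain NumPy here. In practice they would come from `tensor.detach().cpu().numpy()` and could be moved to the accelerator with `jnp.asarray`.

```python
import numpy as np


def torch_linear_to_flax_dense(state_dict: dict[str, np.ndarray]) -> dict[str, dict[str, np.ndarray]]:
    """Convert PyTorch nn.Linear entries ("<name>.weight", "<name>.bias") into
    Flax nn.Dense-style params ({"<name>": {"kernel": ..., "bias": ...}}).

    Hypothetical helper: assumes every 2-D ".weight" belongs to a Linear layer
    and rejects anything else rather than guessing.
    """
    params: dict[str, dict[str, np.ndarray]] = {}
    for key, value in state_dict.items():
        layer, _, kind = key.rpartition(".")
        arr = np.asarray(value)
        if kind == "weight" and arr.ndim == 2:
            # PyTorch stores (out, in); Flax Dense expects (in, out)
            params.setdefault(layer, {})["kernel"] = arr.T.copy()
        elif kind == "bias":
            params.setdefault(layer, {})["bias"] = arr.copy()
        else:
            raise ValueError(f"unhandled parameter: {key}")
    return params


# Sanity check: a Linear(3 -> 2) layer gives the same output in both layouts
rng = np.random.default_rng(0)
w = rng.standard_normal((2, 3)).astype(np.float32)  # PyTorch layout (out, in)
b = rng.standard_normal(2).astype(np.float32)
x = rng.standard_normal((4, 3)).astype(np.float32)

flax_params = torch_linear_to_flax_dense({"proj.weight": w, "proj.bias": b})
torch_style = x @ w.T + b  # nn.Linear computes x W^T + b
flax_style = x @ flax_params["proj"]["kernel"] + flax_params["proj"]["bias"]  # Dense: x K + b
print(np.allclose(torch_style, flax_style))  # True
```

Raising on unknown keys is deliberate. A silent pass-through would hide layers that still need a mapping.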

In [None]:
# Setup
!pip install -q torch transformers onnx onnxruntime-gpu
!pip install -q "optimum[onnxruntime-gpu]"

import torch
import time
import json
import os
from datetime import datetime

# Detect hardware (only the ONNX Runtime session in Step 3 uses the GPU; the PyTorch baseline stays on CPU)
if torch.cuda.is_available():
    device = "cuda"
    hw_name = torch.cuda.get_device_name(0)
    print(f"GPU: {hw_name}")
else:
    device = "cpu"
    hw_name = "CPU"
    print("Warning: No GPU detected")

In [None]:
# Step 1: Load PyTorch model (simulating GPU-trained model)
from transformers import AutoTokenizer, AutoModel

MODEL_ID = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"

print(f"Loading {MODEL_ID} in PyTorch...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model_pt = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True)

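# Baseline PyTorch inference on CPU (model_pt is never moved to `device`)
test_text = "What is the meaning of life?"
inputs = tokenizer(test_text, return_tensors="pt")

N_ITERS_PT = 10  # keep small: a 1.5B-parameter forward pass is slow on CPU
model_pt.eval()
with torch.no_grad():
    _ = model_pt(**inputs)  # warmup pass, excluded from the timing
    start = time.perf_counter()
    for _ in range(N_ITERS_PT):
        _ = model_pt(**inputs)
    pt_time = (time.perf_counter() - start) / N_ITERS_PT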

print(f"PyTorch CPU baseline: {pt_time*1000:.1f}ms/inference")
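
The baseline above runs on CPU, whereas the matrix's Baseline cell is train GPU → serve GPU. Timing a PyTorch model on `cuda` needs one extra step: kernels launch asynchronously, so the device must be synchronized before each clock read. Below is a minimal, framework-agnostic sketch. `benchmark_ms` is a hypothetical helper, and on GPU its `sync` argument would be `torch.cuda.synchronize`.

```python
import statistics
import time
from typing import Callable, Optional


def benchmark_ms(
    fn: Callable[[], object],
    warmup: int = 3,
    iters: int = 20,
    sync: Optional[Callable[[], None]] = None,
) -> dict[str, float]:
    """Time fn() and return mean/median/min latency in milliseconds.

    If sync is given, it runs after warmup and after every timed call, so each
    clock read happens only once queued asynchronous work (e.g. CUDA kernels)
    has finished.
    """
    for _ in range(warmup):
        fn()
    if sync:
        sync()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        if sync:
            sync()
        samples.append((time.perf_counter() - start) * 1000)
    return {
        "mean_ms": statistics.mean(samples),
        "median_ms": statistics.median(samples),
        "min_ms": min(samples),
    }


# Stand-in workload. On GPU this might be (model_gpu/inputs_gpu are hypothetical,
# called under torch.no_grad()):
#   benchmark_ms(lambda: model_gpu(**inputs_gpu), sync=torch.cuda.synchronize)
stats = benchmark_ms(lambda: sum(range(10_000)), warmup=2, iters=5)
print(stats)
```

Reporting the median alongside the mean makes a single slow iteration (e.g. a garbage-collection pause) easier to spot.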

In [None]:
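# Step 2: Export to ONNX
from pathlib import Path

# At ~1.5B parameters the weights exceed ONNX's 2 GB protobuf limit, so the
# exporter writes them as external data files next to model.onnx.
# Export into a dedicated directory so the total size can be measured.
onnx_dir = Path("/content/onnx_model")
onnx_dir.mkdir(parents=True, exist_ok=True)
onnx_path = onnx_dir / "model.onnx"

# Disable the KV cache so the traced graph returns only the hidden states
model_pt.config.use_cache = False

print("Exporting to ONNX...")
export_start = time.time()

# Dynamic axes for variable batch size and sequence length
dynamic_axes = {
    'input_ids': {0: 'batch', 1: 'sequence'},
    'attention_mask': {0: 'batch', 1: 'sequence'},
    'last_hidden_state': {0: 'batch', 1: 'sequence'},
}

with torch.no_grad():
    torch.onnx.export(
        model_pt,
        (inputs['input_ids'], inputs['attention_mask']),
        str(onnx_path),
        input_names=['input_ids', 'attention_mask'],
        output_names=['last_hidden_state'],
        dynamic_axes=dynamic_axes,
        opset_version=14,
    )

export_time = time.time() - export_start
# Total size = graph file + every external-data file written alongside it
onnx_size = sum(f.stat().st_size for f in onnx_dir.iterdir() if f.is_file()) / 1e9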

print(f"ONNX export: {export_time:.1f}s, size: {onnx_size:.2f}GB")
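
Path 3's final hop (ONNX → TensorRT) can also be exercised inside ONNX Runtime, which offers a `TensorrtExecutionProvider` in builds compiled with TensorRT support. Below is a minimal sketch that orders providers by preference and drops any the installed build lacks. `pick_providers` is a hypothetical helper; in this notebook its `available` argument would be `ort.get_available_providers()`.

```python
from typing import Sequence

# Preference order: TensorRT first, then CUDA, with CPU as the last-resort fallback
PREFERRED = ("TensorrtExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider")


def pick_providers(available: Sequence[str], preferred: Sequence[str] = PREFERRED) -> list[str]:
    """Return the preferred providers that are actually available, in preference order."""
    chosen = [p for p in preferred if p in available]
    if not chosen:
        raise RuntimeError(f"none of {list(preferred)} available; got {list(available)}")
    return chosen


# Example: a GPU build that reports no TensorRT provider
print(pick_providers(["CUDAExecutionProvider", "CPUExecutionProvider"]))
# → ['CUDAExecutionProvider', 'CPUExecutionProvider']
```

The resulting list would be passed as `providers=` to `ort.InferenceSession`. `get_available_providers()` reports the providers compiled into the build, which may not guarantee that the TensorRT libraries are installed, so checking `session.get_providers()` afterwards is still worthwhile. TensorRT engine building can also make setup and the first runs slow, so those runs belong in the warmup, not the timed loop.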

In [None]:
# Step 3: ONNX Runtime inference (GPU)
import onnxruntime as ort

print("Loading ONNX model with GPU provider...")
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
session = ort.InferenceSession(str(onnx_path), providers=providers)

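# Check which providers the session actually uses: if the CUDA provider
# fails to load, ONNX Runtime falls back to CPU
active_providers = session.get_providers()
print(f"Session providers: {active_providers}")
if 'CUDAExecutionProvider' not in active_providers:
    print("Warning: CUDAExecutionProvider not active; timings below are CPU timings")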

# Prepare inputs
ort_inputs = {
    'input_ids': inputs['input_ids'].numpy(),
    'attention_mask': inputs['attention_mask'].numpy(),
}

# Warmup
for _ in range(3):
    _ = session.run(None, ort_inputs)

# Benchmark
start = time.time()
for _ in range(100):
    _ = session.run(None, ort_inputs)
onnx_time = (time.time() - start) / 100

print(f"ONNX Runtime GPU: {onnx_time*1000:.1f}ms/inference")
print(f"Speedup vs PyTorch CPU: {pt_time/onnx_time:.1f}x")
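
A speedup only counts if the exported graph computes the same activations. Below is a minimal parity-check sketch. It assumes the reference is `model_pt(**inputs).last_hidden_state.numpy()` (under `torch.no_grad()`) and the candidate is `session.run(['last_hidden_state'], ort_inputs)[0]`. `parity_report` and its tolerances are illustrative choices, not values from this benchmark.

```python
import numpy as np


def parity_report(reference: np.ndarray, candidate: np.ndarray,
                  atol: float = 1e-3, rtol: float = 1e-3) -> dict:
    """Compare two activation tensors; the tolerances are illustrative defaults."""
    if reference.shape != candidate.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {candidate.shape}")
    ref64 = reference.astype(np.float64)
    cand64 = candidate.astype(np.float64)
    abs_diff = np.abs(ref64 - cand64)
    # Cosine similarity of the flattened tensors: a coarse, scale-free check
    cos = float(np.dot(ref64.ravel(), cand64.ravel())
                / (np.linalg.norm(ref64) * np.linalg.norm(cand64)))
    return {
        "max_abs_diff": float(abs_diff.max()),
        "cosine": cos,
        "allclose": bool(np.allclose(cand64, ref64, atol=atol, rtol=rtol)),
    }


# Synthetic stand-in: a tensor and a slightly perturbed copy
rng = np.random.default_rng(0)
ref = rng.standard_normal((1, 8, 16)).astype(np.float32)
print(parity_report(ref, ref + 1e-5))
```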

In [None]:
# Step 4: Results summary
results = {
    "timestamp": datetime.now().isoformat(),
    "model": MODEL_ID,
    "hardware": hw_name,
    "pytorch_cpu_ms": round(pt_time * 1000, 1),
    "onnx_export_seconds": round(export_time, 1),
    "onnx_size_gb": round(onnx_size, 2),
    "onnx_gpu_ms": round(onnx_time * 1000, 1),
    "speedup": round(pt_time / onnx_time, 1),
    "conversion_overhead": "one-time",
}

print("\n" + "="*50)
print("CROSS-HARDWARE SERVING RESULTS")
print("="*50)
print(json.dumps(results, indent=2))

# Investment insight
breakeven_inferences = export_time / (pt_time - onnx_time) if pt_time > onnx_time else float('inf')
print(f"\nBreak-even: {breakeven_inferences:.0f} inferences to recoup export cost")
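
The break-even line above divides the one-time export cost by the per-inference saving: N = T_export / (t_baseline − t_converted), with both terms in the same unit. The sketch below adds explicit units, rounds up to whole requests, and returns infinity when the converted path is not faster. The input numbers are made-up placeholders, not measurements.

```python
import math


def breakeven_inferences(export_s: float, baseline_ms: float, converted_ms: float) -> float:
    """Inferences needed before per-request savings repay a one-time export cost.

    Returns math.inf when the converted path is not faster than the baseline.
    """
    saving_ms = baseline_ms - converted_ms
    if saving_ms <= 0:
        return math.inf
    # export_s * 1000 converts seconds to ms so numerator and denominator share a unit
    return math.ceil(export_s * 1000.0 / saving_ms)


# Placeholder numbers (not measured): 120 s export, 50 ms -> 10 ms per request
print(breakeven_inferences(120.0, 50.0, 10.0))  # → 3000
```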

## Investment Thesis Insights

**Conversion overhead is a one-time cost**: it is paid once per exported model, while the per-inference savings recur on every request.

| Scenario | Export Time (s) | Per-Inference Gain (ms) | Break-even (inferences) |
|----------|-----------------|-------------------------|-------------------------|
| PT→ONNX→GPU | ? | ? | ? |
| PT→JAX→TPU | ? | ? | ? |
| JAX→ONNX→GPU | ? | ? | ? |

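**Key insight**: the break-even count is N = export time ÷ per-inference gain. If the deployment will serve more than N requests, the cross-hardware conversion pays off. (This counts only export time, not the engineering effort of building the conversion path.)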

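Fill in the table with results from runs on T4, L4, A100, v5e, and v6e. This notebook measures only the PT→ONNX→GPU row (against a PyTorch CPU baseline); the PT→JAX→TPU and JAX→ONNX→GPU rows need separate JAX pipelines.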
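
To fill the table from several runs, each run's Step 4 `results` dict can be saved (e.g. with `json.dumps`) and collected afterwards. Below is a minimal sketch that renders such dicts as Markdown rows for the PT→ONNX→GPU scenario. It assumes the key names used in Step 4 and measures the gain against the notebook's PyTorch CPU baseline. `to_markdown_rows` is a hypothetical helper, and the sample record is a placeholder, not a measurement.

```python
import math


def to_markdown_rows(runs: list[dict]) -> list[str]:
    """Render one PT→ONNX→GPU table row per hardware from Step 4 result dicts."""
    rows = []
    for r in runs:
        gain_ms = r["pytorch_cpu_ms"] - r["onnx_gpu_ms"]
        if gain_ms > 0:
            # seconds -> ms, then round up to whole inferences
            breakeven = str(math.ceil(r["onnx_export_seconds"] * 1000 / gain_ms))
        else:
            breakeven = "never"
        rows.append(
            f"| PT→ONNX→GPU ({r['hardware']}) | {r['onnx_export_seconds']} "
            f"| {gain_ms:.1f} | {breakeven} |"
        )
    return rows


# Placeholder record shaped like Step 4's `results` (numbers are not measurements)
sample = {"hardware": "Tesla T4", "onnx_export_seconds": 90.0,
          "pytorch_cpu_ms": 400.0, "onnx_gpu_ms": 25.0}
print("\n".join(to_markdown_rows([sample])))
```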