# TPU vs GPU Inference Benchmark

**Purpose**: Compare inference cost and performance across hardware to support an investment thesis.

**Hardware Options** (Colab Dec 2024):
- **TPU**: v6e-1, v5e-1
- **GPU**: L4, T4, A100 (high VRAM switch available)

**Models Tested**:
- Embedding: Qwen3-Embedding-4B
- Reranking: Qwen3-Reranker-4B
- Inference: Mistral-7B, Llama-3-8B

In [None]:
# Setup - run this first
!pip install -q torch transformers accelerate sentencepiece
!pip install -q huggingface_hub

import os  # needed for the TPU_NAME check below
import torch
import time
import json
from datetime import datetime

# Detect hardware
if torch.cuda.is_available():
    device = "cuda"
    hw_name = torch.cuda.get_device_name(0)
    hw_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {hw_name} ({hw_mem:.1f}GB)")
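elif 'TPU_NAME' in os.environ:
    # Assumes the runtime sets TPU_NAME and has torch_xla preinstalled
    import torch_xla  # noqa: F401  (makes the "xla" device usable)
    device = "xla"
    hw_name = os.environ.get('TPU_NAME', 'TPU')
    hw_mem = 'N/A'
    print(f"TPU: {hw_name}")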
else:
    device = "cpu"
    hw_name = "CPU"
    hw_mem = 'N/A'
    print("Warning: No GPU/TPU detected")

In [None]:
# Benchmark: Embedding throughput
from transformers import AutoTokenizer, AutoModel

MODEL_ID = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"  # Smaller embedding model that fits the free tier

print(f"Loading {MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True).to(device)
model.eval()

# Test data
texts = [
    "What is the capital of France?",
    "How do neural networks learn?",
    "Explain quantum computing in simple terms.",
] * 100  # 300 texts

# Warmup
with torch.no_grad():
    inputs = tokenizer(texts[:10], return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
    _ = model(**inputs)

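# Benchmark
batch_size = 32
total_tokens = 0

if device == "cuda":
    torch.cuda.synchronize()  # finish warmup work before starting the clock
start = time.perf_counter()

with torch.no_grad():
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i+batch_size]
        enc = tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512)
        # Count real tokens only; numel() would also count padding
        total_tokens += int(enc['attention_mask'].sum())
        inputs = enc.to(device)
        outputs = model(**inputs)

if device == "cuda":
    torch.cuda.synchronize()  # CUDA kernels run asynchronously; wait before stopping the clock
# Note: XLA executes lazily, so TPU timings may also need an explicit sync here
elapsed = time.perf_counter() - start
tokens_per_sec = total_tokens / elapsed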

result = {
    "timestamp": datetime.now().isoformat(),
    "hardware": hw_name,
    "hardware_mem_gb": hw_mem,
    "model": MODEL_ID,
    "task": "embedding",
    "texts_processed": len(texts),
    "total_tokens": total_tokens,
    "elapsed_seconds": round(elapsed, 2),
    "tokens_per_second": round(tokens_per_sec, 1),
    "texts_per_second": round(len(texts) / elapsed, 1),
}

print(json.dumps(result, indent=2))

In [None]:
# Save results for aggregation
import os

results_dir = "/content/benchmark_results"
os.makedirs(results_dir, exist_ok=True)

filename = f"{results_dir}/{hw_name.replace(' ', '_')}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
with open(filename, 'w') as f:
    json.dump(result, f, indent=2)

print(f"Saved to {filename}")
print("\nTo download: Files -> benchmark_results/")

## Investment Thesis Data Points

After running on different hardware tiers, compare:

| Hardware | Type | VRAM | Cost/hr | Tokens/sec | Cost per 1M tokens |
|----------|------|------|---------|------------|--------------------|
| T4 | GPU | 16GB | $0 (free) | ? | ? |
| L4 | GPU | 24GB | ? | ? | ? |
| A100 | GPU | 40/80GB | ? | ? | ? |
| v5e-1 | TPU | - | ? | ? | ? |
| v6e-1 | TPU | - | ? | ? | ? |
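The last column follows from the two before it: cost per 1M tokens = cost/hr ÷ (tokens/sec × 3600) × 1,000,000. Below is a minimal sketch of filling it in from the JSON files the save cell writes. The function names and the `hourly_cost_usd` mapping are hypothetical, and any prices passed in are placeholders to be replaced with the rates actually paid.

```python
import glob
import json
import os


def cost_per_million_tokens(cost_per_hour, tokens_per_second):
    """USD per 1M tokens at a sustained throughput."""
    if tokens_per_second <= 0:
        raise ValueError("tokens_per_second must be positive")
    tokens_per_hour = tokens_per_second * 3600
    return cost_per_hour / tokens_per_hour * 1_000_000


def load_results(results_dir):
    """Read every benchmark JSON file in results_dir into a list of dicts."""
    rows = []
    for path in sorted(glob.glob(os.path.join(results_dir, "*.json"))):
        with open(path) as f:
            rows.append(json.load(f))
    return rows


def summarize(rows, hourly_cost_usd):
    """Attach cost per 1M tokens to each result whose hardware has a known price."""
    table = []
    for row in rows:
        price = hourly_cost_usd.get(row["hardware"])
        if price is None:
            continue  # no price entered yet for this hardware
        usd = cost_per_million_tokens(price, row["tokens_per_second"])
        table.append({
            "hardware": row["hardware"],
            "tokens_per_second": row["tokens_per_second"],
            "usd_per_1m_tokens": round(usd, 4),
        })
    return table
```

A call might look like `summarize(load_results("/content/benchmark_results"), hourly_cost_usd)`, where the keys of `hourly_cost_usd` must match the `hardware` strings stored in the result files.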

**Key questions**:
1. At what scale does TPU beat GPU?
2. What's the break-even vs API pricing (OpenRouter ~$0.10/1M tokens)?
3. Which workloads favor which hardware?
4. L4 vs T4 - worth the upgrade?
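Question 2 reduces to simple arithmetic: self-hosting breaks even with an API once the hardware sustains enough tokens per second that its hourly cost, spread over the tokens served, matches the API price. The sketch below is one way to frame it. The function name and the `utilization` parameter (the fraction of each paid hour the hardware is actually serving) are assumptions for illustration, not part of the notebook.

```python
def break_even_tokens_per_second(cost_per_hour, api_usd_per_1m_tokens, utilization=1.0):
    """Throughput at which self-hosting costs the same per token as the API.

    utilization is the fraction of paid time spent serving requests (0 < u <= 1);
    idle time raises the throughput needed while busy.
    """
    if api_usd_per_1m_tokens <= 0:
        raise ValueError("api_usd_per_1m_tokens must be positive")
    if not 0 < utilization <= 1:
        raise ValueError("utilization must be in (0, 1]")
    # Tokens the hardware must serve per paid hour to match the API price
    tokens_per_hour_needed = cost_per_hour / api_usd_per_1m_tokens * 1_000_000
    return tokens_per_hour_needed / 3600 / utilization


# Placeholder hourly price of $1.00 against the ~$0.10/1M API figure above
needed = break_even_tokens_per_second(1.00, 0.10)
print(f"Break-even throughput: {needed:.0f} tokens/sec")
```

Measured throughput above this threshold favors self-hosting and throughput below it favors the API. Lowering `utilization` shows how quickly idle time erodes the advantage.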

## Cross-Hardware Matrix (The Complete Picture)

The current notebook measures **same-hardware** inference only. A real investment thesis also needs the cross-hardware cases:

|  | **Serve GPU** | **Serve TPU** |
|--|---------------|---------------|
| **Train GPU** | Baseline (PyTorch) | Export to JAX |
| **Train TPU** | Export to ONNX | Native JAX/Flax |

**Next notebooks:**
- `cross_hardware_serving.ipynb` - conversion overhead
- `training_cost.ipynb` - train time comparison
- `total_cost_calculator.ipynb` - optimal split for N requests