# Remora Speedup Benchmark

This notebook demonstrates remora's Triton-accelerated VLM inference, comparing:

1. **Stock PyTorch** - baseline nn.Linear layers
2. **W8A16 Surgery** - int8 weights with fp16 activations (roughly halves weight memory traffic relative to fp16 weights)
3. **Full VLM Surgery** - W8A16 + fused vision projector (described under "What's Next?"; not timed in this notebook)
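
The bandwidth saving claimed for W8A16 above comes from weight storage alone, which a quick back-of-the-envelope calculation can show. The layer size (4096 x 4096) and the choice of one fp16 scale per output channel are assumptions for illustration, not values taken from remora or SmolVLM.

```python
def weight_bytes(out_features, in_features, bytes_per_weight, scale_bytes=0):
    """Bytes read per forward pass for one linear layer's weights (plus scales)."""
    return out_features * in_features * bytes_per_weight + out_features * scale_bytes

# Hypothetical 4096 x 4096 projection.
fp16_bytes = weight_bytes(4096, 4096, bytes_per_weight=2)
# int8 weights plus one fp16 scale per output channel (assumed layout).
int8_bytes = weight_bytes(4096, 4096, bytes_per_weight=1, scale_bytes=2)

ratio = fp16_bytes / int8_bytes
print(f"fp16: {fp16_bytes / 2**20:.1f} MiB, int8: {int8_bytes / 2**20:.1f} MiB, ratio: {ratio:.3f}x")
```

The ratio lands just under 2x because of the scale vectors. It only translates into a comparable speedup when the layer is memory-bound, as is typical for small-batch decoding.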

We also showcase **JaggedTensor** support for variable-length sequences without padding.

## Setup

Uncomment to install dependencies if needed.

In [None]:
# Optional: install remora with VLMEval extras.
# !pip install -e ".[eval]"
# Or minimal install without extras:
# !pip install -e .


## Load Model (Stock PyTorch Baseline)

First, we load SmolVLM-Base without any optimizations to establish a baseline.

In [None]:
import time
import torch

from remora.models import load_model_and_tokenizer
from remora import is_triton_available

# Configuration
PRESET = "smolvlm-base"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
WARMUP_RUNS = 2
BENCHMARK_RUNS = 5
MAX_NEW_TOKENS = 32

print(f"Device: {DEVICE}")
print(f"Triton available: {is_triton_available()}")
print(f"\nLoading {PRESET}...")

# Load the model. The W8A16 surgery later modifies it in place, so run the baseline first.
model, tokenizer = load_model_and_tokenizer(PRESET, device=DEVICE)
print(f"Model loaded: {model.__class__.__name__}")


## Benchmark Helper

Define a function to measure generation time and tokens per second.
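
The helper in this section reports a mean over a handful of runs, so a single slow run can skew the result. One possible complement, sketched here under the assumption that per-run wall times in seconds are kept, is to report the median and standard deviation as well. The function name `summarize_times` and the sample timings are hypothetical and are not measurements.

```python
import statistics

def summarize_times(times_s, new_tokens_per_run):
    """Summarize per-run wall times (seconds) for a fixed number of new tokens."""
    median_s = statistics.median(times_s)
    stdev_s = statistics.stdev(times_s) if len(times_s) > 1 else 0.0
    return {
        "mean_ms": statistics.fmean(times_s) * 1000,
        "median_ms": median_s * 1000,
        "stdev_ms": stdev_s * 1000,
        # The median is less sensitive to one-off stalls than the mean.
        "tokens_per_sec_median": new_tokens_per_run / median_s,
    }

# Made-up timings with one outlier, purely for illustration.
summary = summarize_times([0.50, 0.52, 0.51, 0.90, 0.50], new_tokens_per_run=32)
print(summary)
```

With an outlier like the fourth value, the mean lands well above the median, which is the case where the two summaries tell different stories.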

In [None]:
def benchmark_generation(model, tokenizer, prompt, num_runs=5, warmup=2, max_new_tokens=32):
    """Benchmark model generation, returning avg time and tokens/sec."""
    # Prepare inputs
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    # Warmup runs
    for _ in range(warmup):
        with torch.inference_mode():
            _ = model.generate(**inputs, max_new_tokens=max_new_tokens, use_cache=True)
    
    # Synchronize before timing
    if DEVICE == "cuda":
        torch.cuda.synchronize()
    
    # Timed runs
    times = []
    total_tokens = 0
    for _ in range(num_runs):
        start = time.perf_counter()
        with torch.inference_mode():
            output = model.generate(**inputs, max_new_tokens=max_new_tokens, use_cache=True)
        if DEVICE == "cuda":
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
        times.append(elapsed)
        total_tokens += output.shape[-1] - inputs["input_ids"].shape[-1]
    
    avg_time = sum(times) / len(times)
    avg_tokens = total_tokens / len(times)
    tps = avg_tokens / avg_time
    
    return {
        "avg_time_ms": avg_time * 1000,
        "tokens_per_sec": tps,
        "avg_new_tokens": avg_tokens,
    }

# Test prompt
PROMPT = "Explain the benefits of quantized inference in three sentences."
print(f"Prompt: {PROMPT[:50]}...")


## Benchmark 1: Stock PyTorch (Baseline)

Run generation with the unmodified model.

In [None]:
print("Running stock PyTorch baseline...")
stock_results = benchmark_generation(
    model, tokenizer, PROMPT, 
    num_runs=BENCHMARK_RUNS, warmup=WARMUP_RUNS, max_new_tokens=MAX_NEW_TOKENS
)
print(f"  Avg time: {stock_results['avg_time_ms']:.1f} ms")
print(f"  Tokens/sec: {stock_results['tokens_per_sec']:.1f}")
print(f"  New tokens: {stock_results['avg_new_tokens']:.0f}")


## Benchmark 2: Remora W8A16 Surgery

Apply W8A16 quantization (int8 weights, fp16 activations) to all Linear layers.
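
As a plain-PyTorch reference for what W8A16 computes, the sketch below quantizes each output row symmetrically to int8 and dequantizes before the matmul. It illustrates the general scheme only and is not remora's Triton kernel: the names `quantize_per_channel_ref` and `w8a16_linear_ref` are hypothetical, remora's own quantizer may differ in details, and the arithmetic runs in fp32 so that the example also works on CPU.

```python
import torch

def quantize_per_channel_ref(weight: torch.Tensor):
    """Symmetric int8 quantization with one scale per output row."""
    w = weight.float()
    # The largest magnitude in each row maps to 127; clamp avoids division by zero.
    scales = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    w_int8 = torch.round(w / scales).clamp(-127, 127).to(torch.int8)
    return w_int8, scales

def w8a16_linear_ref(x: torch.Tensor, w_int8: torch.Tensor, scales: torch.Tensor):
    """Dequantize the weights, then compute y = x @ W^T in fp32 and cast back."""
    w_deq = w_int8.float() * scales
    return (x.float() @ w_deq.t()).to(x.dtype)

torch.manual_seed(0)
W = torch.randn(128, 64)                      # (out_features, in_features)
x = torch.randn(4, 64, dtype=torch.float16)   # fp16 activations

w_int8, scales = quantize_per_channel_ref(W)
y_q = w8a16_linear_ref(x, w_int8, scales)
y_ref = x.float() @ W.t()

print(f"max abs output error: {(y_q.float() - y_ref).abs().max().item():.4f}")
```

Each dequantized weight differs from the original by at most half of its row's scale, so rows with small weights are not penalized by a large outlier elsewhere in the matrix.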

In [None]:
from remora import hijack_model, is_triton_available
from remora.surgery import TritonBitLinear

print("Applying W8A16 surgery to all Linear layers...")
num_replaced = hijack_model(model, verbose=False)

# Diagnostic: Check a TritonBitLinear layer
triton_layers = [m for m in model.modules() if isinstance(m, TritonBitLinear)]
if triton_layers:
    layer = triton_layers[0]
    print(f"  Found {len(triton_layers)} TritonBitLinear layers")
    print(f"  Sample layer weight device: {layer.weight.device}")
    print(f"  Sample layer weight_int8 device: {layer.weight_int8.device}")
    print(f"  Triton available: {is_triton_available()}")
    print(f"  CUDA available: {torch.cuda.is_available()}")

print("\nRunning W8A16 benchmark...")
w8a16_results = benchmark_generation(
    model, tokenizer, PROMPT,
    num_runs=BENCHMARK_RUNS, warmup=WARMUP_RUNS, max_new_tokens=MAX_NEW_TOKENS
)
print(f"  Avg time: {w8a16_results['avg_time_ms']:.1f} ms")
print(f"  Tokens/sec: {w8a16_results['tokens_per_sec']:.1f}")

speedup = stock_results['avg_time_ms'] / w8a16_results['avg_time_ms']
print(f"\n  ⚡ Speedup vs stock: {speedup:.2f}x")


## Results Summary

Compare all benchmarks side by side.


In [None]:
print("=" * 50)
print("BENCHMARK SUMMARY")
print("=" * 50)
print(f"Model: {PRESET}")
print(f"Device: {DEVICE}")
print(f"Max new tokens: {MAX_NEW_TOKENS}")
print()

results = [
    ("Stock PyTorch", stock_results),
    ("W8A16 Surgery", w8a16_results),
]

print(f"{'Method':<20} {'Time (ms)':<12} {'Tokens/s':<12} {'Speedup':<10}")
print("-" * 54)

baseline_time = stock_results['avg_time_ms']
for name, r in results:
    speedup = baseline_time / r['avg_time_ms']
    speedup_str = f"{speedup:.2f}x" if name != "Stock PyTorch" else "baseline"
    print(f"{name:<20} {r['avg_time_ms']:<12.1f} {r['tokens_per_sec']:<12.1f} {speedup_str:<10}")


## Bonus: JaggedTensor Demo

Remora supports variable-length sequences without padding using `JaggedTensor`. This is useful for mixed image+text batches where padding wastes compute.
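
The idea behind a packed layout can be sketched in plain PyTorch: concatenate the sequences along the token dimension and keep cumulative offsets that mark where each one starts and ends. This is a conceptual sketch under assumed names (`pack_ref`, `unpack_ref`) and is not how remora's `JaggedTensor` is necessarily implemented.

```python
import torch

def pack_ref(seqs):
    """Concatenate sequences along dim 0 and record cumulative boundaries."""
    lengths = torch.tensor([s.shape[0] for s in seqs], dtype=torch.int32)
    cu_seqlens = torch.zeros(len(seqs) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(lengths, dim=0)
    return torch.cat(seqs, dim=0), cu_seqlens

def unpack_ref(data, cu_seqlens):
    """Slice the packed tensor back into the original sequences."""
    bounds = cu_seqlens.tolist()
    return [data[start:end] for start, end in zip(bounds[:-1], bounds[1:])]

seqs = [torch.randn(n, 64) for n in (576, 128, 256)]
data, cu_seqlens = pack_ref(seqs)
restored = unpack_ref(data, cu_seqlens)

padded_tokens = len(seqs) * max(s.shape[0] for s in seqs)
print(f"packed tokens: {data.shape[0]}, padded tokens: {padded_tokens}")
print(f"cu_seqlens: {cu_seqlens.tolist()}")
```

Because the packed tensor is an ordinary 2-D tensor of shape `(total_tokens, hidden)`, any per-token operation such as a linear projection can run on it directly, with no work spent on padding rows.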


In [None]:
from remora import JaggedTensor, pack_sequences, unpack_sequences, pad_jagged

# Simulate 3 sequences of different lengths (like image + varying text).
# fp16 is used to match the fp16 activations of W8A16.
seq1 = torch.randn(576, 64, device=DEVICE, dtype=torch.float16)   # 576 image tokens
seq2 = torch.randn(128, 64, device=DEVICE, dtype=torch.float16)   # 128 text tokens
seq3 = torch.randn(256, 64, device=DEVICE, dtype=torch.float16)   # 256 text tokens

# Pack into JaggedTensor - NO PADDING!
jagged = pack_sequences([seq1, seq2, seq3])

padded_tokens = 576 * 3  # every sequence padded to the longest one (576)
bytes_per_token = 64 * jagged.data.element_size()

print("JaggedTensor properties:")
print(f"  Total tokens: {jagged.total_tokens} (vs 576 * 3 = {padded_tokens} if padded)")
print(f"  cu_seqlens: {jagged.cu_seqlens.tolist()}")
print(f"  Batch size: {jagged.batch_size}")
print(f"  Memory saved: {(padded_tokens - jagged.total_tokens) * bytes_per_token / 1024:.1f} KB")

# Can apply W8A16 operations directly on jagged data
from remora import w8a16_gemm, quantize_weight_per_channel

# Create a test projection
W = torch.randn(128, 64, device=DEVICE, dtype=torch.float16)
w_int8, scales = quantize_weight_per_channel(W)

# Apply to all tokens at once - no padding needed!
out = w8a16_gemm(jagged.data, w_int8, scales)
print(f"\nProjected jagged data: {jagged.data.shape} -> {out.shape}")


## What's Next?

Remora provides several optimizations beyond `torch.compile`:

| Feature | Description | torch.compile? |
|---------|-------------|----------------|
| **W8A16 GEMM** | int8 weights, fp16 activations | ❌ Struggles with int8 |
| **JaggedTensor** | No padding for variable sequences | ❌ Recompiles on shape change |
| **Fused GELU+Linear** | Activation fused with matmul | ⚠️ Unreliable |
| **Vision Projector** | Fused 2-layer MLP for VLMs | ❌ Breaks on shape transitions |

For full VLM optimization including vision projector fusion:
```python
from remora import full_vlm_surgery
full_vlm_surgery(model)  # Replaces projector + all Linear layers
```


In [None]:
# Cleanup
del model, tokenizer
if DEVICE == "cuda":
    torch.cuda.empty_cache()
print("Done! 🎉")
