# Generation overhead: 3.25 GPU syncs per token + PyTorch dispatch overhead
### System Info

- `transformers` version: 5.0.0.dev0 (main branch)
- Platform: Linux
- Python version: 3.12
- PyTorch version: 2.x with CUDA
- GPU: NVIDIA (tested)
### Who can help?

### Information

- My own modified scripts

### Tasks

- My own task or dataset (give details below)
### Reproduction

We benchmarked generation overhead using a tiny model (hidden_size=16, 1 layer, vocab_size=256) to isolate framework overhead from actual compute.

#### Benchmark Script
```python
import time
import warnings

import torch
import torch.nn.functional as F

from transformers import LlamaForCausalLM, LlamaConfig
from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList

warnings.filterwarnings("ignore")


class AlwaysPassStoppingCriteria(StoppingCriteria):
    def __call__(self, input_ids, scores, **kwargs):
        # Never signal "finished", so generation always runs for max_new_tokens
        return torch.zeros(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)


class ExtraSoftmaxLogitsProcessor(LogitsProcessor):
    def __call__(self, input_ids, scores):
        # Cheap extra processing step to exercise the logits-processor path
        return torch.log(F.softmax(scores, dim=-1) + 1e-10)


# Tiny model to minimize compute and expose framework overhead
config = LlamaConfig(
    vocab_size=256, hidden_size=16, intermediate_size=16,
    num_hidden_layers=1, num_attention_heads=1, num_key_value_heads=1,
    max_position_embeddings=2048, use_cache=True,
)

for device in ["cpu", "cuda"]:
    model = LlamaForCausalLM(config).to(device).eval()
    input_ids = torch.tensor([[1]], device=device)
    attention_mask = torch.ones_like(input_ids)
    stopping_criteria = StoppingCriteriaList([AlwaysPassStoppingCriteria()])
    logits_processor = LogitsProcessorList([ExtraSoftmaxLogitsProcessor()])

    # Warmup
    with torch.no_grad():
        for _ in range(3):
            model.generate(input_ids, attention_mask=attention_mask,
                           min_new_tokens=64, max_new_tokens=64,
                           stopping_criteria=stopping_criteria,
                           logits_processor=logits_processor,
                           do_sample=False, pad_token_id=0)
    if device == "cuda":
        torch.cuda.synchronize()

    # Benchmark
    times = []
    with torch.no_grad():
        for _ in range(10):
            start = time.perf_counter()
            model.generate(input_ids, attention_mask=attention_mask,
                           min_new_tokens=64, max_new_tokens=64,
                           stopping_criteria=stopping_criteria,
                           logits_processor=logits_processor,
                           do_sample=False, pad_token_id=0)
            if device == "cuda":
                torch.cuda.synchronize()
            times.append(time.perf_counter() - start)

    mean_ms = sum(times) / len(times) * 1000
    print(f"{device.upper()}: {mean_ms:.1f} ms for 64 tokens ({mean_ms/64:.2f} ms/token)")
```

#### Results
| Device | Time (64 tokens) | Per Token |
|---|---|---|
| CPU | ~226 ms | ~3.5 ms |
| GPU | ~578 ms | ~9.0 ms |
GPU is 2.5x slower than CPU on this tiny model because overhead dominates.
#### Root Cause Analysis
**1. CPU Slowness: PyTorch Dispatch Overhead (~3.5 ms/token)**

Even on CPU, each token takes ~3.5 ms despite trivial compute. Profiling shows:
| Operation | Time per call |
|---|---|
| Raw matmul (16x16) | 0.002 ms |
| `nn.Linear(16, 16)` | 0.24 ms |
| Full forward pass | ~3.0 ms |
The model has ~8 linear layers per forward pass. Each `nn.Linear` call incurs ~0.24 ms of PyTorch dispatch overhead (Python→C++ transition, tensor metadata handling, op dispatch). This is per-call framework overhead that eager-mode Python/PyTorch cannot avoid.
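A minimal sketch of how this per-call overhead can be measured on CPU (the 16x16 sizes match the issue; the `time_per_call` helper is ours, and absolute numbers will vary by machine):

```python
import time
import torch

# Micro-benchmark: raw matmul vs. an nn.Linear call of the same size,
# to estimate per-call dispatch overhead on CPU.
x = torch.randn(1, 16)
w = torch.randn(16, 16)
linear = torch.nn.Linear(16, 16)

def time_per_call(fn, iters=10_000):
    # Warmup to exclude one-time allocation costs
    for _ in range(100):
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters * 1000  # ms per call

with torch.no_grad():
    print(f"raw matmul : {time_per_call(lambda: x @ w):.4f} ms/call")
    print(f"nn.Linear  : {time_per_call(lambda: linear(x)):.4f} ms/call")
```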
**2. GPU Slowness: 3.25 GPU→CPU Syncs Per Token (~5.5 ms/token overhead)**

Profiling with `torch.profiler` reveals 26 `aten::_local_scalar_dense` calls for 8 generated tokens (this is the op that performs the actual GPU→CPU sync):

```
aten::is_nonzero    26 calls    41 ms total    ~1.5 ms each
aten::item          26 calls    41 ms total    ~1.5 ms each
```

26 syncs / 8 tokens = 3.25 syncs per token
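A hedged sketch of how these counts can be reproduced with `torch.profiler`, reusing `model`, `input_ids`, `attention_mask`, `stopping_criteria`, and `logits_processor` from the CUDA iteration of the benchmark above (exact counts vary with the transformers version):

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Profile a short CUDA generation and count the host-sync op.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    with torch.no_grad():
        model.generate(input_ids, attention_mask=attention_mask,
                       min_new_tokens=8, max_new_tokens=8,
                       stopping_criteria=stopping_criteria,
                       logits_processor=logits_processor,
                       do_sample=False, pad_token_id=0)

for evt in prof.key_averages():
    if evt.key == "aten::_local_scalar_dense":
        print(f"{evt.key}: {evt.count} calls")
```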
Tracing the call sites:

| Location | Calls per Token | Purpose |
|---|---|---|
| `utils.py:2674` (`_has_unfinished_sequences`) | 1 | Check if generation is done |
| `utils.py:522` (`_cache_dependant_input_preparation`) | 1 | Check cache state |
| `masking_utils.py:253` (`_ignore_causal_mask_sdpa`) | 1 | Check mask skip condition |
Each sync forces a GPU→CPU data transfer and stalls the pipeline, costing ~1.5 ms.

Total sync overhead: 3.25 syncs × ~1.5 ms ≈ 5 ms per token, which accounts for most of the ~5.5 ms/token gap between GPU and CPU in the benchmark above.
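A minimal sketch (not from the issue) of why each of these boolean checks is expensive: calling `.item()` or `bool()` on a CUDA tensor blocks the host until all queued kernels have finished. The matrix size is arbitrary and the timing is illustrative only:

```python
import time
import torch

# Requires a CUDA device.
a = torch.randn(4096, 4096, device="cuda")

# Warm up cuBLAS/CUDA context so the timing below reflects only the sync.
(a @ a).sum().item()
torch.cuda.synchronize()

b = a @ a                 # kernel is launched asynchronously, returns immediately
flag = b.sum() > 0        # still a 0-dim CUDA bool tensor, no sync yet
t0 = time.perf_counter()
done = bool(flag)         # implicit .item(): host blocks until the matmul finishes
t1 = time.perf_counter()
print(f"bool(flag) blocked the host for {(t1 - t0) * 1000:.2f} ms")
```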
### Expected behavior
Generation should have minimal per-token overhead, especially for:
- Small/medium models where compute doesn't dominate
- Latency-sensitive applications
- Edge deployment scenarios
### Potential Improvements

- **Reduce sync frequency** - Check stopping criteria every N tokens instead of every token (see PR Skip attention_mask.all() GPU-CPU sync during generation #43088)
- **Async stopping criteria** - Run sync-causing checks in a separate CUDA stream so they don't block the main compute stream (see PR Add async_stopping_criteria flag to reduce GPU-CPU syncs during generation #43085)
- **Batch boolean checks** - Combine multiple boolean tensor checks into a single sync point (a sketch follows this list)
- **Lazy evaluation** - Defer `_ignore_causal_mask_sdpa` and similar checks when not needed
- **Compile-friendly paths** - `torch.compile` could potentially fuse operations and reduce sync points
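As referenced above, a hedged sketch of the "batch boolean checks" idea. The pattern below is hypothetical (not an existing transformers API); it only illustrates how several 0-dim boolean flags could be read with one GPU→CPU sync instead of three:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-ins for the per-token boolean checks (in the real loop these would be
# e.g. `unfinished_sequences.max() == 0` and `attention_mask.all()`).
is_done = torch.tensor(False, device=device)
mask_is_full = torch.tensor(True, device=device)
cache_ready = torch.tensor(True, device=device)

# Naive pattern: bool(is_done), bool(mask_is_full), bool(cache_ready)
# would each trigger a separate GPU->CPU sync.

# Batched pattern: stack the 0-dim flags on-device and transfer once,
# paying for a single sync per token.
is_done_h, mask_is_full_h, cache_ready_h = torch.stack(
    [is_done, mask_is_full, cache_ready]
).tolist()
print(is_done_h, mask_is_full_h, cache_ready_h)
```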
### Impact
For real models, this overhead gets amortized by actual compute. But for:
- Speculative decoding with small draft models
- On-device/edge models
- High-throughput serving with small models
- Distilled/quantized models
...this 3.25 syncs/token overhead becomes significant. Reducing it would improve generation latency across the board.
### Related PRs
- Add async_stopping_criteria flag to reduce GPU-CPU syncs during generation #43085 - Async stopping criteria to reduce GPU-CPU syncs
- Skip attention_mask.all() GPU-CPU sync during generation #43088 - Reduce stopping criteria check frequency