4 changes: 3 additions & 1 deletion examples/pytorch/continuous_batching.py
@@ -229,7 +229,9 @@ def batch_generate(
         use_cuda_graph=args.use_cuda_graph,
         eos_token_id=tokenizer.eos_token_id,
         pad_token_id=tokenizer.pad_token_id,
-        do_sample=False,
+        do_sample=True,
+        temperature=0.8,
+        top_p=0.9,
         num_blocks=args.num_blocks,
         max_batch_tokens=args.max_batch_tokens,
     )
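The example script now enables sampling with `temperature=0.8` and `top_p=0.9`. As a rough, self-contained sketch of what those two knobs do to the next-token distribution (not the library's actual warper pipeline; the function name and toy shapes are illustrative only):

```python
import torch


def sample_with_temperature_top_p(logits: torch.Tensor, temperature: float = 0.8, top_p: float = 0.9) -> torch.Tensor:
    """Sample one token id per row from a [num_rows, vocab_size] logits tensor."""
    # Temperature < 1.0 sharpens the distribution, > 1.0 flattens it.
    logits = logits / temperature

    # Nucleus (top-p) filtering: keep the smallest prefix of tokens (sorted by
    # probability) whose cumulative mass reaches top_p, mask out the rest.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    sorted_probs = torch.softmax(sorted_logits, dim=-1)
    cumulative = sorted_probs.cumsum(dim=-1)
    # Drop tokens whose preceding cumulative mass already exceeds top_p;
    # the highest-probability token is always kept.
    remove = (cumulative - sorted_probs) > top_p
    sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
    filtered = torch.full_like(logits, float("-inf")).scatter(-1, sorted_indices, sorted_logits)

    probs = torch.softmax(filtered, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)


# Toy usage: one sampled token id per row of a random 2 x 10 logits tensor.
print(sample_with_temperature_top_p(torch.randn(2, 10)))
```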
23 changes: 19 additions & 4 deletions src/transformers/generation/continuous_batching/continuous_api.py
@@ -631,16 +631,31 @@ def _process_logit(self, batch_data, logits):
         self.logit_processor.set_continuous_batching_context(
             batch_data["logits_indices"], batch_data["cu_seq_lens_q"]
         )
-        return self.logit_processor(batch_data["input_ids"], logits)
+
+        # Handle shape compatibility: logit processors expect 2D tensors [batch_size, vocab_size]
+        # but continuous batching always produces 3D tensors [batch_size, seq_len, vocab_size]
+        batch_size, seq_len, vocab_size = logits.shape
+        logits_2d = logits.view(batch_size * seq_len, vocab_size)
+        input_ids_2d = batch_data["input_ids"].view(batch_size * seq_len)
+
+        # Process with 2D tensors
+        processed_logits_2d = self.logit_processor(input_ids_2d, logits_2d)
+
+        # Reshape back to 3D
+        return processed_logits_2d.view(batch_size, seq_len, vocab_size)
 
     @traced(span_name="sampling")
     def _sample(self, batch_processor: ContinuousBatchProcessor, probs):
         if self.do_sample:  # sample
             probs = nn.functional.softmax(probs, dim=-1)
-            next_tokens = torch.multinomial(probs[0], num_samples=1).squeeze(1)
+            # probs[0] has shape [seq_len, vocab_size], multinomial returns [seq_len, 1]
+            next_tokens = torch.multinomial(probs[0], num_samples=1).squeeze(-1)  # Now [seq_len]
+            # Add batch dimension back to match argmax output
+            next_tokens = next_tokens.unsqueeze(0)  # Now [1, seq_len]
         else:
-            next_tokens = torch.argmax(probs, dim=-1)
-        tokens = next_tokens.size(1)
+            next_tokens = torch.argmax(probs, dim=-1)  # Already [1, seq_len]
+
+        tokens = next_tokens.size(1)  # Get seq_len dimension
         batch_processor.output_ids[:, :tokens].copy_(next_tokens)
 
     def _run_generation_loop(self):
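A minimal sketch of the shape handling that the patched `_process_logit` and `_sample` perform; `toy_logit_processor` and the tensor sizes are made up for illustration and are not the library's real processor:

```python
import torch


def toy_logit_processor(input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
    """Stand-in for a logits processor that expects 2D scores [num_tokens, vocab_size]."""
    scores = scores.clone()
    scores[:, 0] = float("-inf")  # e.g. forbid token id 0
    return scores


batch_size, seq_len, vocab_size = 1, 5, 32
logits = torch.randn(batch_size, seq_len, vocab_size)  # 3D, as produced by continuous batching
input_ids = torch.randint(1, vocab_size, (batch_size, seq_len))

# Flatten to 2D so the per-row processor can be applied, then restore the 3D shape.
processed = toy_logit_processor(
    input_ids.view(batch_size * seq_len),
    logits.view(batch_size * seq_len, vocab_size),
).view(batch_size, seq_len, vocab_size)

# Sampling path: drop the batch dim, draw one token per position, then add the
# dim back so the result matches the greedy argmax shape of [1, seq_len].
probs = torch.softmax(processed, dim=-1)
sampled = torch.multinomial(probs[0], num_samples=1).squeeze(-1).unsqueeze(0)
greedy = torch.argmax(processed, dim=-1)
assert sampled.shape == greedy.shape == (batch_size, seq_len)
```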
2 changes: 1 addition & 1 deletion src/transformers/integrations/flash_paged.py
@@ -51,7 +51,7 @@ def paged_attention_forward(
     k, v = cache.update(k, v, module.layer_idx, **kwargs)
 
     sliding_window = (-1, -1) if not getattr(module, "sliding_window", False) else (module.sliding_window, 0)
-    if implementation is not None:
+    if implementation is not None and hasattr(implementation, "flash_attn_varlen_func"):
         flash_attn_varlen_func = implementation.flash_attn_varlen_func
         custom_kwargs = {"s_aux": kwargs.get("s_aux")} if "s_aux" in kwargs else {}
         attn_output = flash_attn_varlen_func(
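With the added `hasattr` check, a kernel object that does not expose `flash_attn_varlen_func` no longer raises `AttributeError` on the following line and presumably falls through to the default attention path outside this hunk. A self-contained sketch of the guard's behavior, using made-up stand-in objects rather than real kernel wrappers:

```python
from types import SimpleNamespace


def pick_attention_path(implementation) -> str:
    """Mirror the guard above: take the custom kernel path only when the object
    actually exposes flash_attn_varlen_func; otherwise fall back."""
    if implementation is not None and hasattr(implementation, "flash_attn_varlen_func"):
        return "custom flash_attn_varlen_func"
    return "default fallback"


# Hypothetical kernel wrappers used only to exercise the guard.
with_func = SimpleNamespace(flash_attn_varlen_func=lambda *args, **kwargs: None)
without_func = SimpleNamespace()

print(pick_attention_path(with_func))     # custom flash_attn_varlen_func
print(pick_attention_path(without_func))  # default fallback (previously an AttributeError)
print(pick_attention_path(None))          # default fallback
```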
65 changes: 64 additions & 1 deletion tests/generation/test_paged_attention.py
@@ -78,9 +78,72 @@ def test_generate_batch_consistency(self, attn_impl, num_blocks, block_size, max
         )
 
         for i, req_id in enumerate(batch_outputs):
-            generated = self.tokenizer.decode(batch_outputs[req_id].static_outputs, skip_special_tokens=False).strip()
+            generated = self.tokenizer.decode(
+                batch_outputs[req_id].generated_tokens, skip_special_tokens=False
+            ).strip()
             expected = _EXPECTED_OUTPUTS[i].strip()
             self.assertTrue(
                 generated.startswith(expected),
                 msg=f"[{attn_impl}] Mismatch in request {i}:\nExpected start: {expected}\nGot: {generated}",
             )

+    @parameterized.expand(
+        [
+            ("eager_paged", 64, 128, 64),
+            ("sdpa_paged", 32, 256, 128),
+            ("paged_attention", 16, 512, 256),
+            ("flex_paged", 64, 128, 64),
+        ]
+    )
+    def test_generate_batch_with_sampling(self, attn_impl, num_blocks, block_size, max_batch_tokens):
+        """Test batch generation with do_sample=True to verify that sampling works correctly."""
+        self.model.config.attn_implementation = attn_impl
+
+        generation_config = GenerationConfig(
+            max_new_tokens=30,
+            do_sample=True,
+            top_k=50,
+            top_p=0.9,
+            temperature=0.8,
+            eos_token_id=self.tokenizer.eos_token_id,
+            pad_token_id=self.tokenizer.pad_token_id,
+            use_cache=False,
+            num_blocks=num_blocks,
+            block_size=block_size,
+            max_batch_tokens=max_batch_tokens,
+        )
+
+        tokenized = self.tokenizer(_TEST_PROMPTS, truncation=True, max_length=512)  # truncate prompts to keep the test fast
+        batch_inputs = list(tokenized["input_ids"])
+
+        start = time.time()
+        batch_outputs = self.model.generate_batch(
+            inputs=batch_inputs,
+            generation_config=generation_config,
+        )
+        end = time.time()
+        print(
+            f"\n[{attn_impl}] Sampling batch took {end - start:.2f}s with config: blocks={num_blocks}, block_size={block_size}, max_batch_tokens={max_batch_tokens}"
+        )
+
+        # With sampling enabled we cannot check exact outputs, so verify instead that:
+        # 1. all requests completed successfully,
+        # 2. the decoded text for each request is non-empty,
+        # 3. at least one token was generated for each request.
+        self.assertEqual(len(batch_outputs), len(batch_inputs), f"[{attn_impl}] Not all requests completed")
+
+        for i, req_id in enumerate(batch_outputs):
+            generated = self.tokenizer.decode(
+                batch_outputs[req_id].generated_tokens, skip_special_tokens=False
+            ).strip()
+            self.assertTrue(
+                len(generated) > 0,
+                msg=f"[{attn_impl}] Empty output for request {i}",
+            )
+            # Check that at least some tokens were generated
+            generated_tokens = batch_outputs[req_id].generated_tokens
+            self.assertGreater(
+                len(generated_tokens),
+                0,
+                msg=f"[{attn_impl}] No tokens generated for request {i}",
+            )