diff --git a/examples/pytorch/continuous_batching.py b/examples/pytorch/continuous_batching.py
index b5ad94ed3f11..7196dc994204 100644
--- a/examples/pytorch/continuous_batching.py
+++ b/examples/pytorch/continuous_batching.py
@@ -229,7 +229,9 @@ def batch_generate(
         use_cuda_graph=args.use_cuda_graph,
         eos_token_id=tokenizer.eos_token_id,
         pad_token_id=tokenizer.pad_token_id,
-        do_sample=False,
+        do_sample=True,
+        temperature=0.8,
+        top_p=0.9,
         num_blocks=args.num_blocks,
         max_batch_tokens=args.max_batch_tokens,
     )
diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py
index 0ddd5bde8968..1c63507abe93 100644
--- a/src/transformers/generation/continuous_batching/continuous_api.py
+++ b/src/transformers/generation/continuous_batching/continuous_api.py
@@ -631,16 +631,31 @@ def _process_logit(self, batch_data, logits):
         self.logit_processor.set_continuous_batching_context(
             batch_data["logits_indices"], batch_data["cu_seq_lens_q"]
         )
-        return self.logit_processor(batch_data["input_ids"], logits)
+
+        # Handle shape compatibility: logit processors expect 2D tensors [batch_size, vocab_size]
+        # but continuous batching always produces 3D tensors [batch_size, seq_len, vocab_size]
+        batch_size, seq_len, vocab_size = logits.shape
+        logits_2d = logits.view(batch_size * seq_len, vocab_size)
+        input_ids_2d = batch_data["input_ids"].view(batch_size * seq_len)
+
+        # Process with 2D tensors
+        processed_logits_2d = self.logit_processor(input_ids_2d, logits_2d)
+
+        # Reshape back to 3D
+        return processed_logits_2d.view(batch_size, seq_len, vocab_size)
 
     @traced(span_name="sampling")
     def _sample(self, batch_processor: ContinuousBatchProcessor, probs):
         if self.do_sample:  # sample
             probs = nn.functional.softmax(probs, dim=-1)
-            next_tokens = torch.multinomial(probs[0], num_samples=1).squeeze(1)
+            # probs[0] has shape [seq_len, vocab_size], multinomial returns [seq_len, 1]
+            next_tokens = torch.multinomial(probs[0], num_samples=1).squeeze(-1)  # Now [seq_len]
+            # Add batch dimension back to match argmax output
+            next_tokens = next_tokens.unsqueeze(0)  # Now [1, seq_len]
         else:
-            next_tokens = torch.argmax(probs, dim=-1)
-        tokens = next_tokens.size(1)
+            next_tokens = torch.argmax(probs, dim=-1)  # Already [1, seq_len]
+
+        tokens = next_tokens.size(1)  # Get seq_len dimension
         batch_processor.output_ids[:, :tokens].copy_(next_tokens)
 
     def _run_generation_loop(self):
diff --git a/src/transformers/integrations/flash_paged.py b/src/transformers/integrations/flash_paged.py
index 352bc82a1e40..00836beabe13 100644
--- a/src/transformers/integrations/flash_paged.py
+++ b/src/transformers/integrations/flash_paged.py
@@ -51,7 +51,7 @@ def paged_attention_forward(
     k, v = cache.update(k, v, module.layer_idx, **kwargs)
 
     sliding_window = (-1, -1) if not getattr(module, "sliding_window", False) else (module.sliding_window, 0)
-    if implementation is not None:
+    if implementation is not None and hasattr(implementation, "flash_attn_varlen_func"):
         flash_attn_varlen_func = implementation.flash_attn_varlen_func
     custom_kwargs = {"s_aux": kwargs.get("s_aux")} if "s_aux" in kwargs else {}
     attn_output = flash_attn_varlen_func(
diff --git a/tests/generation/test_paged_attention.py b/tests/generation/test_paged_attention.py
index d3241e237466..e7673f5f08cd 100644
--- a/tests/generation/test_paged_attention.py
+++ b/tests/generation/test_paged_attention.py
@@ -78,9 +78,72 @@ def test_generate_batch_consistency(self, attn_impl, num_blocks, block_size, max
         )
         for i, req_id in enumerate(batch_outputs):
-            generated = self.tokenizer.decode(batch_outputs[req_id].static_outputs, skip_special_tokens=False).strip()
+            generated = self.tokenizer.decode(
+                batch_outputs[req_id].generated_tokens, skip_special_tokens=False
+            ).strip()
             expected = _EXPECTED_OUTPUTS[i].strip()
             self.assertTrue(
                 generated.startswith(expected),
                 msg=f"[{attn_impl}] Mismatch in request {i}:\nExpected start: {expected}\nGot: {generated}",
             )
+
+    @parameterized.expand(
+        [
+            ("eager_paged", 64, 128, 64),
+            ("sdpa_paged", 32, 256, 128),
+            ("paged_attention", 16, 512, 256),
+            ("flex_paged", 64, 128, 64),
+        ]
+    )
+    def test_generate_batch_with_sampling(self, attn_impl, num_blocks, block_size, max_batch_tokens):
+        """Test batch generation with do_sample=True to verify sampling works correctly."""
+        self.model.config.attn_implementation = attn_impl
+
+        generation_config = GenerationConfig(
+            max_new_tokens=30,
+            do_sample=True,
+            top_k=50,
+            top_p=0.9,
+            temperature=0.8,
+            eos_token_id=self.tokenizer.eos_token_id,
+            pad_token_id=self.tokenizer.pad_token_id,
+            use_cache=False,
+            num_blocks=num_blocks,
+            block_size=block_size,
+            max_batch_tokens=max_batch_tokens,
+        )
+
+        tokenized = self.tokenizer(_TEST_PROMPTS, truncation=True, max_length=512)  # Truncate prompts to keep the test fast
+        batch_inputs = list(tokenized["input_ids"])
+
+        start = time.time()
+        batch_outputs = self.model.generate_batch(
+            inputs=batch_inputs,
+            generation_config=generation_config,
+        )
+        end = time.time()
+        print(
+            f"\n[{attn_impl}] Sampling batch took {end - start:.2f}s with config: blocks={num_blocks}, block_size={block_size}, max_batch_tokens={max_batch_tokens}"
+        )
+
+        # With sampling enabled, we can't check exact outputs, but we should verify:
+        # 1. All requests completed successfully
+        # 2. Generated text is non-empty
+        # 3. At least one token was generated for each request
+        self.assertEqual(len(batch_outputs), len(batch_inputs), f"[{attn_impl}] Not all requests completed")
+
+        for i, req_id in enumerate(batch_outputs):
+            generated = self.tokenizer.decode(
+                batch_outputs[req_id].generated_tokens, skip_special_tokens=False
+            ).strip()
+            self.assertTrue(
+                len(generated) > 0,
+                msg=f"[{attn_impl}] Empty output for request {i}",
+            )
+            # Check that we got at least some tokens generated
+            generated_tokens = batch_outputs[req_id].generated_tokens
+            self.assertGreater(
+                len(generated_tokens),
+                0,
+                msg=f"[{attn_impl}] No tokens generated for request {i}",
+            )
 
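
For reference, a minimal self-contained sketch (plain PyTorch, not the transformers API) of the reshape-and-restore pattern the patch applies in _process_logit and _sample. The top_k_filter helper, the toy tensor shapes, and the temperature value are illustrative assumptions, not code from this patch.

import torch


def top_k_filter(logits_2d: torch.Tensor, k: int = 50) -> torch.Tensor:
    # Keep the top-k logits per row of a 2D [rows, vocab] tensor, mask the rest to -inf.
    kth = torch.topk(logits_2d, k, dim=-1).values[:, -1, None]
    return logits_2d.masked_fill(logits_2d < kth, float("-inf"))


batch_size, seq_len, vocab_size = 1, 7, 128  # toy shapes
logits = torch.randn(batch_size, seq_len, vocab_size)

# Flatten the 3D [batch, seq, vocab] logits to 2D so a row-wise processor can run, then restore 3D.
logits_2d = logits.view(batch_size * seq_len, vocab_size)
processed = top_k_filter(logits_2d / 0.8).view(batch_size, seq_len, vocab_size)

# Sampling mirrors _sample: multinomial operates on the 2D [seq_len, vocab] slice,
# then the batch dimension is added back so the shape matches the greedy argmax path.
probs = torch.softmax(processed, dim=-1)
next_tokens = torch.multinomial(probs[0], num_samples=1).squeeze(-1).unsqueeze(0)
assert next_tokens.shape == (1, seq_len)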