From 84f892ea0b8e2ae38afaf99faea8aa7f573cfa9d Mon Sep 17 00:00:00 2001
From: Luc Georges
Date: Tue, 2 Sep 2025 12:45:49 +0200
Subject: [PATCH 1/7] feat: err when unsupported attn impl is set w/
 `--continuous_batching`

---
 src/transformers/commands/serving.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/src/transformers/commands/serving.py b/src/transformers/commands/serving.py
index 622f50378dfd..c6ccff2e526b 100644
--- a/src/transformers/commands/serving.py
+++ b/src/transformers/commands/serving.py
@@ -483,6 +483,16 @@ def __init__(self, args: ServeArguments):
         # Store and process input arguments
         self.args = args
         self.use_continuous_batching = self.args.continuous_batching
+        supported_cb_attn_implementations = {"eager_paged", "sdpa_paged", "flash_attention_2"}
+        if self.use_continuous_batching:
+            # checking if attn_implementation is supported by continuous batching
+            if self.args.attn_implementation is None:
+                self.args.attn_implementation = "sdpa_paged"  # default to sdpa_paged
+            if self.args.attn_implementation not in supported_cb_attn_implementations:
+                raise ValueError(
+                    f"Continuous batching only supports {supported_cb_attn_implementations} as attn_implementation, got "
+                    f"{self.args.attn_implementation}"
+                )
         self.enable_cors = self.args.enable_cors
 
         if self.args.default_seed is not None:

From 88bc48a136dd01130516256702544c8140ad59b5 Mon Sep 17 00:00:00 2001
From: Luc Georges
Date: Wed, 3 Sep 2025 13:25:14 +0200
Subject: [PATCH 2/7] refactor: move defaults and support list to CB code

---
 src/transformers/commands/serving.py                 | 9 +++++----
 .../generation/continuous_batching/continuous_api.py | 8 ++++++++
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/src/transformers/commands/serving.py b/src/transformers/commands/serving.py
index c6ccff2e526b..49d1896e1f9d 100644
--- a/src/transformers/commands/serving.py
+++ b/src/transformers/commands/serving.py
@@ -483,14 +483,15 @@ def __init__(self, args: ServeArguments):
         # Store and process input arguments
         self.args = args
         self.use_continuous_batching = self.args.continuous_batching
-        supported_cb_attn_implementations = {"eager_paged", "sdpa_paged", "flash_attention_2"}
         if self.use_continuous_batching:
+            default_attn_impl = ContinuousBatchingManager.default_attention_implementation()
             # checking if attn_implementation is supported by continuous batching
             if self.args.attn_implementation is None:
-                self.args.attn_implementation = "sdpa_paged"  # default to sdpa_paged
-            if self.args.attn_implementation not in supported_cb_attn_implementations:
+                self.args.attn_implementation =  # default to sdpa_paged
+            supported_attn_impl = ContinuousBatchingManager.supported_attention_implementations()
+            if self.args.attn_implementation not in supported_attn_impl:
                 raise ValueError(
-                    f"Continuous batching only supports {supported_cb_attn_implementations} as attn_implementation, got "
+                    f"Continuous batching only supports {supported_attn_impl} as attn_implementation, got "
                     f"{self.args.attn_implementation}"
                 )
         self.enable_cors = self.args.enable_cors

diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py
index 1c63507abe93..48a10e13b0e9 100644
--- a/src/transformers/generation/continuous_batching/continuous_api.py
+++ b/src/transformers/generation/continuous_batching/continuous_api.py
@@ -595,6 +595,14 @@ def request_id_iter(self, request_id):
         if self.batch_processor is not None:
             request_cancelled = self.batch_processor.scheduler.request_is_cancelled(request_id)
 
+    @staticmethod
+    def supported_attention_implementations() -> set[str]:
+        return {"eager_paged", "sdpa_paged", "flash_attention_2"}
+
+    @staticmethod
+    def default_attention_implementation() -> str:
+        return "sdpa_paged"
+
     @traced
     def warmup(self, batch_processor):
         stream = torch.cuda.Stream(device=self.model.device)
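The two static methods added in PATCH 2/7 make ContinuousBatchingManager the single source of truth for which attention implementations continuous batching accepts. A minimal runnable sketch of that registry pattern, using a stand-in class that reproduces only those two methods, not the real manager:

# Stand-in reproducing only the two static methods from PATCH 2/7;
# the real ContinuousBatchingManager lives in continuous_api.py.
class ContinuousBatchingManager:
    @staticmethod
    def supported_attention_implementations() -> set[str]:
        return {"eager_paged", "sdpa_paged", "flash_attention_2"}

    @staticmethod
    def default_attention_implementation() -> str:
        return "sdpa_paged"


# Callers query the class instead of hard-coding the set or the default,
# so serving.py and the CB code cannot drift apart:
assert (
    ContinuousBatchingManager.default_attention_implementation()
    in ContinuousBatchingManager.supported_attention_implementations()
)

Note that the `+            self.args.attn_implementation =  # default to sdpa_paged` line this refactor leaves behind is missing its right-hand side; PATCH 4/7 below restores it.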
From de490f6c255198f5baf87e973df0e899fea542da Mon Sep 17 00:00:00 2001
From: Luc Georges
Date: Wed, 3 Sep 2025 13:25:43 +0200
Subject: [PATCH 3/7] feat: add action item in error msg

---
 src/transformers/commands/serving.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/transformers/commands/serving.py b/src/transformers/commands/serving.py
index 49d1896e1f9d..2514c528d901 100644
--- a/src/transformers/commands/serving.py
+++ b/src/transformers/commands/serving.py
@@ -493,6 +493,7 @@ def __init__(self, args: ServeArguments):
                 raise ValueError(
                     f"Continuous batching only supports {supported_attn_impl} as attn_implementation, got "
                     f"{self.args.attn_implementation}"
+                    f"Try setting `--attn_implementation={default_attn_impl}`"
                 )
         self.enable_cors = self.args.enable_cors
 

From 87817c390b0b454309b494a91a2efc083d8ec39a Mon Sep 17 00:00:00 2001
From: Luc Georges
Date: Fri, 5 Sep 2025 14:57:00 +0200
Subject: [PATCH 4/7] fix(serve): add default attn implementation

---
 src/transformers/commands/serving.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/commands/serving.py b/src/transformers/commands/serving.py
index 2514c528d901..7af3d60ae9ae 100644
--- a/src/transformers/commands/serving.py
+++ b/src/transformers/commands/serving.py
@@ -487,7 +487,7 @@ def __init__(self, args: ServeArguments):
             default_attn_impl = ContinuousBatchingManager.default_attention_implementation()
             # checking if attn_implementation is supported by continuous batching
             if self.args.attn_implementation is None:
-                self.args.attn_implementation =  # default to sdpa_paged
+                self.args.attn_implementation = default_attn_impl  # default to sdpa_paged
             supported_attn_impl = ContinuousBatchingManager.supported_attention_implementations()
             if self.args.attn_implementation not in supported_attn_impl:
                 raise ValueError(

From a2b348d4a55facb2a3f308738fdea20c53176cd6 Mon Sep 17 00:00:00 2001
From: Luc Georges
Date: Fri, 5 Sep 2025 14:57:59 +0200
Subject: [PATCH 5/7] feat(serve): add log when `attn_implementation` is
 `None`

---
 src/transformers/commands/serving.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/transformers/commands/serving.py b/src/transformers/commands/serving.py
index 7af3d60ae9ae..6c5bbed3cfa4 100644
--- a/src/transformers/commands/serving.py
+++ b/src/transformers/commands/serving.py
@@ -488,6 +488,7 @@ def __init__(self, args: ServeArguments):
             # checking if attn_implementation is supported by continuous batching
             if self.args.attn_implementation is None:
                 self.args.attn_implementation = default_attn_impl  # default to sdpa_paged
+                logger.info(f"No attn_implementation passed, defaulting to {default_attn_impl}")
             supported_attn_impl = ContinuousBatchingManager.supported_attention_implementations()
             if self.args.attn_implementation not in supported_attn_impl:
                 raise ValueError(

From 2421a8909b87e3281ac886c31b9e215c24b1b705 Mon Sep 17 00:00:00 2001
From: Luc Georges
Date: Mon, 8 Sep 2025 16:16:47 +0200
Subject: [PATCH 6/7] feat: raise Exception when attn_implementation is not
 supported by CB

---
 .../generation/continuous_batching/continuous_api.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py
index 48a10e13b0e9..eff36ee8da64 100644
--- a/src/transformers/generation/continuous_batching/continuous_api.py
+++ b/src/transformers/generation/continuous_batching/continuous_api.py
@@ -442,6 +442,10 @@ def __init__(
             max_queue_size: Maximum size of the request queue (0 = unlimited)
             streaming: Whether to stream tokens as they are generated
         """
+        if model.config._attn_implementation not in self.supported_attention_implementations():
+            raise ValueError(
+                f"Model attention implementation '{model.config._attn_implementation}' is not supported for continuous batching. Supported implementations: {self.supported_attention_implementations()}"
+            )
         self.model = model.eval()
         generation_config = model.generation_config if generation_config is None else generation_config
         self.generation_config = generation_config
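PATCH 6/7 duplicates the guard inside ContinuousBatchingManager.__init__ itself, keyed off `model.config._attn_implementation` rather than the serve CLI arguments, so library users who construct the manager directly would hit the same error. A self-contained sketch of that guard, assuming only that `config._attn_implementation` names the active implementation; the SimpleNamespace model is a hypothetical stand-in:

from types import SimpleNamespace

# Mirrors supported_attention_implementations() from PATCH 2/7.
SUPPORTED_CB_ATTN_IMPLS = {"eager_paged", "sdpa_paged", "flash_attention_2"}


def check_cb_attn_implementation(model) -> None:
    # Same shape as the constructor guard added in PATCH 6/7: read the
    # active implementation from the model config, reject unsupported ones.
    attn_impl = model.config._attn_implementation
    if attn_impl not in SUPPORTED_CB_ATTN_IMPLS:
        raise ValueError(
            f"Model attention implementation '{attn_impl}' is not supported for "
            f"continuous batching. Supported implementations: {SUPPORTED_CB_ATTN_IMPLS}"
        )


model = SimpleNamespace(config=SimpleNamespace(_attn_implementation="sdpa_paged"))
check_cb_attn_implementation(model)  # passes; "eager" would raise ValueError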
From b69e38fc6eebfdff6b4fc5e48e652f4dddffcb64 Mon Sep 17 00:00:00 2001
From: Luc Georges
Date: Mon, 8 Sep 2025 16:22:38 +0200
Subject: [PATCH 7/7] revert: raise Exception when attn_implementation is not
 supported

---
 .../generation/continuous_batching/continuous_api.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py
index eff36ee8da64..48a10e13b0e9 100644
--- a/src/transformers/generation/continuous_batching/continuous_api.py
+++ b/src/transformers/generation/continuous_batching/continuous_api.py
@@ -442,10 +442,6 @@ def __init__(
             max_queue_size: Maximum size of the request queue (0 = unlimited)
             streaming: Whether to stream tokens as they are generated
         """
-        if model.config._attn_implementation not in self.supported_attention_implementations():
-            raise ValueError(
-                f"Model attention implementation '{model.config._attn_implementation}' is not supported for continuous batching. Supported implementations: {self.supported_attention_implementations()}"
-            )
         self.model = model.eval()
         generation_config = model.generation_config if generation_config is None else generation_config
         self.generation_config = generation_config
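Since PATCH 7/7 reverts the constructor guard, the net effect of the series is confined to serving.py: when `--continuous_batching` is set, the serve command defaults a missing `attn_implementation` to `sdpa_paged`, logs the fallback, and raises an actionable error for unsupported values. A hedged reconstruction of that final flow as a free function; ServeArguments here is a minimal stand-in dataclass and `logger` a plain stdlib logger, not the real transformers objects:

import logging
from dataclasses import dataclass
from typing import Optional

logger = logging.getLogger(__name__)  # stand-in for the transformers logger

# Mirror the values the CB code exposes via ContinuousBatchingManager.
SUPPORTED_CB_ATTN_IMPLS = {"eager_paged", "sdpa_paged", "flash_attention_2"}
DEFAULT_CB_ATTN_IMPL = "sdpa_paged"


@dataclass
class ServeArguments:  # stand-in: only the fields this flow touches
    continuous_batching: bool = False
    attn_implementation: Optional[str] = None


def apply_cb_attn_checks(args: ServeArguments) -> None:
    """Net serving.py behavior once all seven patches are applied."""
    if not args.continuous_batching:
        return
    if args.attn_implementation is None:
        args.attn_implementation = DEFAULT_CB_ATTN_IMPL
        logger.info(f"No attn_implementation passed, defaulting to {DEFAULT_CB_ATTN_IMPL}")
    if args.attn_implementation not in SUPPORTED_CB_ATTN_IMPLS:
        raise ValueError(
            f"Continuous batching only supports {SUPPORTED_CB_ATTN_IMPLS} as attn_implementation, got "
            f"{args.attn_implementation}. Try setting `--attn_implementation={DEFAULT_CB_ATTN_IMPL}`"
        )


args = ServeArguments(continuous_batching=True)
apply_cb_attn_checks(args)
assert args.attn_implementation == "sdpa_paged"

From the CLI this corresponds to invocations like `transformers serve --continuous_batching --attn_implementation flash_attention_2`, assuming the standard `transformers serve` entry point.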