Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/transformers/commands/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,19 @@ def __init__(self, args: ServeArguments):
# Store and process input arguments
self.args = args
self.use_continuous_batching = self.args.continuous_batching
if self.use_continuous_batching:
default_attn_impl = ContinuousBatchingManager.default_attention_implementation()
# checking if attn_implementation is supported by continuous batching
if self.args.attn_implementation is None:
self.args.attn_implementation = default_attn_impl # default to sdpa_paged
logger.info(f"No attn_implementation passed, defaulting to {default_attn_impl}")
supported_attn_impl = ContinuousBatchingManager.supported_attention_implementations()
if self.args.attn_implementation not in supported_attn_impl:
raise ValueError(
Comment on lines +492 to +494
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, if the attn implementation is not supported, you could try to map it to the correct one.
Because for sdpa, eager, and flash, it's easy to map: just prefix with `paged|`.
But up to you!

f"Continuous batching only supports {supported_attn_impl} as attn_implementation, got "
f"{self.args.attn_implementation}"
f"Try setting `--attn_implementation={default_attn_impl}`"
)
Comment on lines +486 to +498
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think about putting this directly in CB's API? It would be better in general, and we could automatically add `paged_{attn_implementation}`, no?

self.enable_cors = self.args.enable_cors

if self.args.default_seed is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,14 @@ def request_id_iter(self, request_id):
if self.batch_processor is not None:
request_cancelled = self.batch_processor.scheduler.request_is_cancelled(request_id)

@staticmethod
def supported_attention_implementations() -> set[str]:
return {"eager_paged", "sdpa_paged", "flash_attention_2"}

@staticmethod
def default_attention_implementation() -> str:
return "sdpa_paged"

@traced
def warmup(self, batch_processor):
stream = torch.cuda.Stream(device=self.model.device)
Expand Down