12 changes: 7 additions & 5 deletions examples/pytorch/continuous_batching.py
@@ -187,18 +187,20 @@ def batch_generate(
"--attn", type=str, default="paged_attention|kernels-community/flash-attn", help="Attention implementation"
)
parser.add_argument("--matmul-precision", "-mp", type=str, default="high") # set to "none" to disable
parser.add_argument("--slice-inputs", action="store_true", default=False)
parser.add_argument("--use-cuda-graph", action="store_true", default=False)
parser.add_argument("--compile", action="store_true", default=False)
parser.add_argument("--no-slice-inputs", action="store_true") # slicing is enabled by default because much faster
parser.add_argument("--use-cuda-graph", "-cg", action="store_true")
parser.add_argument("--compile", action="store_true")

parser.add_argument("--samples", type=int, default=500)
parser.add_argument("--displayed", type=int, default=0, help="Number of samples to display")
parser.add_argument("--output-file", type=str, default=None)
parser.add_argument("--compare", action="store_true", default=False)
parser.add_argument("--metrics", action="store_true", default=False)
parser.add_argument("--compare", action="store_true")
parser.add_argument("--metrics", action="store_true")
parser.add_argument("--profile", type=str, default=None)
args = parser.parse_args()

+args.slice_inputs = not args.no_slice_inputs
+
# If turned on, we setup metrics
if args.metrics:
setup_metrics()
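As context for the hunk above: input slicing is now opt-out rather than opt-in, with the positive setting reconstructed after parsing. A minimal, self-contained sketch of that inverted-flag pattern, using plain argparse and hypothetical invocations rather than the example script itself:

import argparse

# Sketch of the opt-out flag pattern used above: the feature defaults to on,
# and the CLI only exposes the negation.
parser = argparse.ArgumentParser()
parser.add_argument("--no-slice-inputs", action="store_true")  # opt out of input slicing

args = parser.parse_args([])  # no flag passed -> slicing stays enabled
args.slice_inputs = not args.no_slice_inputs
assert args.slice_inputs is True

args = parser.parse_args(["--no-slice-inputs"])
args.slice_inputs = not args.no_slice_inputs
assert args.slice_inputs is False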
26 changes: 18 additions & 8 deletions src/transformers/generation/continuous_batching/cache.py
@@ -198,7 +198,7 @@ def __init__(
# Add the inferred attributes to the class
self.num_blocks = num_blocks
self.max_batch_tokens = max_batch_tokens
-        logger.warning(
+        logger.info(
f"PagedAttentionCache initialized with {self.num_blocks = }, {self.block_size = }, {page_size = }, "
f"{self.max_batch_tokens = } {num_attention_masks = }"
)
@@ -253,7 +253,7 @@ def get_num_free_blocks(self) -> int:
return len(self._free_blocks)

@traced
-    def get_read_indices(
+    def extend_read_indices(
self, request_id: str, past_length: int, query_length: int, read_index: list[list[int]]
) -> None:
"""Retrieve physical cache indices for reading KV states in the cache across all layer groups. This method
@@ -264,7 +264,7 @@ def get_read_indices(
read_indices.extend(indices)

@traced
-    def get_write_indices(
+    def extend_write_indices(
self, request_id: str, past_length: int, query_length: int, write_index: list[list[int]]
) -> None:
"""Retrieve physical cache indices for writing new KV states to the cache across all layer groups. This method
@@ -274,6 +274,16 @@ def get_write_indices(
indices = cm.get_write_indices(request_id, past_length, query_length)
write_indices.extend(indices)

+    @traced
+    def get_seqlens_k(self, request_id: str, past_length: int, query_length: int) -> dict[str, int]:
+        """Retrieve the key sequence length for the given request_id across all layer types. Returns a dictionary of
+        layer types to their corresponding key sequence lengths."""
+        seqlens_k = {}
+        for cm in self.group_cache_managers:
+            attn_type, seqlen_k = cm.get_seqlens_k(request_id, past_length, query_length)
+            seqlens_k[attn_type] = seqlen_k
+        return seqlens_k
+
@traced
def update(
self,
@@ -471,7 +481,7 @@ def compute_num_blocks_and_max_batch_tokens(
b = 2 * (self.group_size * self.page_size * cache_dtype.itemsize + 2 * self.num_groups)
b += m * (self.peak_activation_per_token * self._activation_dtype.itemsize + 28 + 4 * self.num_groups)
c = -cache_memory
logger.info(f"Coefficients of 2nd degree polynomial: {a = }, {b = }, {c = }")
logger.debug(f"Coefficients of 2nd degree polynomial: {a = }, {b = }, {c = }")

# Compute discriminant and greatest solution
discriminant = b**2 - 4 * a * c
@@ -485,11 +495,11 @@
num_pages = floor(greatest_solution)
num_blocks = num_pages // self.block_size
if num_blocks > self._upper_bound_num_blocks:
logger.warning(f"{num_blocks = } is too large, setting to {self._upper_bound_num_blocks = }")
logger.info(f"{num_blocks = } is too large, setting to {self._upper_bound_num_blocks = }")
num_blocks = self._upper_bound_num_blocks
max_batch_tokens = int(greatest_solution * m)
if max_batch_tokens > self._upper_bound_max_batch_tokens:
logger.warning(f"{max_batch_tokens = } is too large, setting to {self._upper_bound_max_batch_tokens = }")
logger.info(f"{max_batch_tokens = } is too large, setting to {self._upper_bound_max_batch_tokens = }")
max_batch_tokens = self._upper_bound_max_batch_tokens
return num_blocks, max_batch_tokens

@@ -517,7 +527,7 @@ def compute_max_batch_tokens(
# Compute max batch tokens and return
max_batch_tokens = floor(num / denum)
if max_batch_tokens > self._upper_bound_max_batch_tokens:
logger.warning(f"{max_batch_tokens = } is too large, setting to {self._upper_bound_max_batch_tokens = }")
logger.info(f"{max_batch_tokens = } is too large, setting to {self._upper_bound_max_batch_tokens = }")
max_batch_tokens = self._upper_bound_max_batch_tokens
return max_batch_tokens

@@ -545,7 +555,7 @@ def compute_num_blocks(
num_pages = floor(num / denum)
num_blocks = num_pages // self.block_size
if num_blocks > self._upper_bound_num_blocks:
logger.warning(f"{num_blocks = } is too large, setting to {self._upper_bound_num_blocks = }")
logger.info(f"{num_blocks = } is too large, setting to {self._upper_bound_num_blocks = }")
num_blocks = self._upper_bound_num_blocks
return num_blocks

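As context for the cache.py hunks above: the new PagedAttentionCache.get_seqlens_k walks every group cache manager and keys each group's key length by its attention type, while the read/write helpers were renamed to extend_* because they extend caller-provided index lists in place. A self-contained sketch of that aggregation, using stand-in allocator classes (FakeFullAllocator, FakeSlidingAllocator) rather than the real ones:

# Stand-in allocators that mimic the (attn_type, seqlen_k) contract from the diff.
class FakeFullAllocator:
    def get_seqlens_k(self, request_id, past_length, query_length):
        return "full_attention", past_length + query_length

class FakeSlidingAllocator:
    def __init__(self, sliding_window):
        self.sliding_window = sliding_window

    def get_seqlens_k(self, request_id, past_length, query_length):
        return "sliding_attention", query_length + min(past_length, self.sliding_window - 1)

group_cache_managers = [FakeFullAllocator(), FakeSlidingAllocator(sliding_window=128)]

# Same aggregation shape as PagedAttentionCache.get_seqlens_k in the hunk above.
seqlens_k = {}
for cm in group_cache_managers:
    attn_type, seqlen_k = cm.get_seqlens_k("req-0", past_length=500, query_length=1)
    seqlens_k[attn_type] = seqlen_k
print(seqlens_k)  # {'full_attention': 501, 'sliding_attention': 128}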
15 changes: 15 additions & 0 deletions src/transformers/generation/continuous_batching/cache_manager.py
@@ -53,6 +53,11 @@ def get_write_indices(self, request_id: str, past_length: int, query_length: int
"""Returns the physical indices of where to write request_id's cache in the cache tensor."""
pass

+    @abstractmethod
+    def get_seqlens_k(self, request_id: str, past_length: int, query_length: int) -> tuple[str, int]:
+        """Returns the attention type of the cache allocator and the key sequence length for the given request_id."""
+        pass
+

class FullAttentionCacheAllocator(CacheAllocator):
"""Cache manager for a group of full attention layers."""
@@ -108,6 +113,11 @@ def get_write_indices(self, request_id: str, past_length: int, query_length: int
physical_indices.append(physical_index)
return physical_indices

+    def get_seqlens_k(self, request_id: str, past_length: int, query_length: int) -> tuple[str, int]:
+        """Returns the attention type of the cache allocator and the key sequence length for the given request_id."""
+        seqlens_k = past_length + query_length
+        return "full_attention", seqlens_k
+

class SlidingAttentionCacheAllocator(CacheAllocator):
"""Cache manager for sliding window attention layers."""
@@ -191,6 +201,11 @@ def get_write_indices(self, request_id: str, past_length: int, query_length: int
physical_indices = [-1] * padding_length + physical_indices
return physical_indices

+    def get_seqlens_k(self, request_id: str, past_length: int, query_length: int) -> tuple[str, int]:
+        """Returns the attention type of the cache allocator and the key sequence length for the given request_id."""
+        seqlens_k = query_length + min(past_length, self.sliding_window - 1)
+        return "sliding_attention", seqlens_k
+

# TODO: test the impact of this
# def get_read_indices(self, request_id: str, past_length: int) -> list[int]:
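As context for the two implementations above: full attention attends to the entire prefix, while sliding attention caps the key length at the window size. A small worked example with made-up decode-step numbers:

# Made-up decode-step numbers (query_length = 1 token being generated).
sliding_window, past_length, query_length = 4096, 10_000, 1

full_seqlen_k = past_length + query_length  # 10001: full attention sees the whole prefix
sliding_seqlen_k = query_length + min(past_length, sliding_window - 1)  # 4096: capped at the window
print(full_seqlen_k, sliding_seqlen_k)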