You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
What is BLOCK_SIZE referring to? Why is it hardcoded to 16, and why is it used to scale all of these calculations?
What exactly is total_cache_size?
Why is max_batch_total_tokens computed as num_blocks * BLOCK_SIZE? Is it because one block of memory is able to store 16 (BLOCK_SIZE) tokens?
def warmup(self, batch: FlashCausalLMBatch):
    """Run one prefill pass on the largest possible batch, then size the KV cache.

    Executes a single `generate_token` on `batch` (the biggest batch the server
    could ever receive) to measure peak memory use, computes how many KV-cache
    blocks fit in the remaining free memory, and re-creates the cache manager
    with that block count.

    Args:
        batch: the warmup prefill batch.

    Returns:
        int: the total token capacity of the KV cache, `num_blocks * BLOCK_SIZE`
        (each cache block stores BLOCK_SIZE tokens per layer, per head — see the
        `cache_block_size` computation below).

    Raises:
        RuntimeError: if the warmup prefill itself runs out of GPU memory.
        NotImplementedError: on systems that are neither CUDA/ROCm nor XPU.
    """
    # The warmup batch is the biggest batch we could ever receive.
    # Clear any cached allocations first so the free-memory probe below is accurate.
    if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
        torch.cuda.empty_cache()
    elif IS_XPU_SYSTEM:
        torch.xpu.empty_cache()
    try:
        # Temporary cache manager sized exactly for the warmup batch; its
        # allocation is part of the peak memory we want to measure.
        cache_manager = set_cache_manager(
            batch.blocks,
            self.num_layers,
            self.num_kv_heads,
            self.head_size,
            self.sliding_window is not None,
            self.dtype,
            self.device,
        )
        # Saved for the CUDA-graph warmup below: max blocks per request and
        # the corresponding max sequence length in tokens.
        max_bt = batch.max_blocks
        max_s = max_bt * get_cache_manager().block_size
        _, batch, _ = self.generate_token(batch)
    except torch.cuda.OutOfMemoryError as e:
        raise RuntimeError(
            f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
            f"You need to decrease `--max-batch-prefill-tokens`"
        ) from e
    # Wait for the prefill kernels to finish before reading memory stats.
    if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
        torch.cuda.synchronize(self.device)
    elif IS_XPU_SYSTEM:
        torch.xpu.synchronize(self.device)
    # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
    # Calculate the number of blocks that can be allocated with the free memory.
    dtype_size = torch.tensor([], dtype=self.dtype).element_size()
    # Elements needed to cache one block (BLOCK_SIZE tokens) for one layer.
    cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
    # Bytes per block across all layers; the factor 2 presumably accounts for
    # the separate key and value tensors of the KV cache — confirm upstream.
    total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
    if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
        total_free_memory, _ = torch.cuda.mem_get_info(self.device)
        total_gpu_memory = torch.cuda.get_device_properties(
            self.device
        ).total_memory
        # Reserve (1 - MEMORY_FRACTION) of the device for other consumers.
        free_memory = max(
            0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
        )
    elif IS_XPU_SYSTEM:
        # XPU has no mem_get_info equivalent here; assume half the device is usable.
        total_gpu_memory = torch.xpu.get_device_properties(self.device).total_memory
        free_memory = int(total_gpu_memory * 0.5)
    else:
        raise NotImplementedError("FlashModel is only available on GPU")
    num_blocks = (
        # Leave 5% for some wiggle room
        int((free_memory * 0.95) // total_cache_size)
        # Add batch.blocks as we allocated it above, so it is included in the peak memory.
        + cache_manager.num_blocks
    )
    # Drop the warmup batch and the temporary cache manager before allocating
    # the full-size cache, so their memory is released first.
    del batch
    del cache_manager
    set_cache_manager(
        num_blocks,
        self.num_layers,
        self.num_kv_heads,
        self.head_size,
        self.sliding_window is not None,
        self.dtype,
        self.device,
    )
    if CUDA_GRAPHS:
        try:
            logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
            # Warmup cuda graphs
            for bs in CUDA_GRAPHS:
                # Skip graph sizes too small to hold speculative tokens.
                if self.speculate is None or self.speculate + 1 <= bs:
                    self.cuda_graph_warmup(bs, max_s, max_bt)
        except torch.cuda.OutOfMemoryError:
            # Best-effort: graph capture failure is logged, not fatal.
            logger.exception(f"Decode cuda graph warmup failed")
    else:
        logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
    # Total number of tokens the KV cache can hold (BLOCK_SIZE tokens per block).
    return int(num_blocks * BLOCK_SIZE)
The text was updated successfully, but these errors were encountered:
Yes, I am interested in this and have to spend some time with code & research papers to understand finer implementation details. Thanks for tagging me though :)
I'm not from the HF team, just sharing my suggestion. Since this is more of a Q&A than an actual issue, consider moving it to Discussions?
In `flash_causal_lm.py`:
The text was updated successfully, but these errors were encountered: