
Question about KV cache #1883

Closed
martinigoyanes opened this issue May 13, 2024 · 3 comments

Comments

@martinigoyanes (Contributor) commented May 13, 2024

In flash_causal_lm.py:

  • What does BLOCK_SIZE refer to? Why is it hardcoded to 16, and why does it scale all of the calculations?
  • What exactly is total_cache_size?
  • Why is max_batch_total_tokens computed as num_blocks * BLOCK_SIZE? Is it because one block of memory can store 16 (BLOCK_SIZE) tokens? (A worked size sketch follows the snippet below.)
    def warmup(self, batch: FlashCausalLMBatch):
        # The warmup batch is the biggest batch we could ever receive
        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
            torch.cuda.empty_cache()
        elif IS_XPU_SYSTEM:
            torch.xpu.empty_cache()
        try:
            cache_manager = set_cache_manager(
                batch.blocks,
                self.num_layers,
                self.num_kv_heads,
                self.head_size,
                self.sliding_window is not None,
                self.dtype,
                self.device,
            )
            max_bt = batch.max_blocks
            max_s = max_bt * get_cache_manager().block_size
            _, batch, _ = self.generate_token(batch)
        except torch.cuda.OutOfMemoryError as e:
            raise RuntimeError(
                f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "
                f"You need to decrease `--max-batch-prefill-tokens`"
            ) from e


        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
            torch.cuda.synchronize(self.device)
        elif IS_XPU_SYSTEM:
            torch.xpu.synchronize(self.device)


        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
        # Calculate the number of blocks that can be allocated with the free memory
        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size


        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
            total_free_memory, _ = torch.cuda.mem_get_info(self.device)
            total_gpu_memory = torch.cuda.get_device_properties(
                self.device
            ).total_memory


            free_memory = max(
                0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
            )
        elif IS_XPU_SYSTEM:
            total_gpu_memory = torch.xpu.get_device_properties(self.device).total_memory
            free_memory = int(total_gpu_memory * 0.5)
        else:
            raise NotImplementedError("FlashModel is only available on GPU")


        num_blocks = (
            # Leave 5% for some wiggle room
            int((free_memory * 0.95) // total_cache_size)
            # Add batch.blocks as we allocated it above, so it is included in the peak memory.
            + cache_manager.num_blocks
        )


        del batch
        del cache_manager


        set_cache_manager(
            num_blocks,
            self.num_layers,
            self.num_kv_heads,
            self.head_size,
            self.sliding_window is not None,
            self.dtype,
            self.device,
        )


        if CUDA_GRAPHS:
            try:
                logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
                # Warmup cuda graphs
                for bs in CUDA_GRAPHS:
                    if self.speculate is None or self.speculate + 1 <= bs:
                        self.cuda_graph_warmup(bs, max_s, max_bt)
            except torch.cuda.OutOfMemoryError:
                logger.exception(f"Decode cuda graph warmup failed")
        else:
            logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")


        return int(num_blocks * BLOCK_SIZE)
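
For concreteness, the same arithmetic can be traced with made-up numbers. The sketch below is not TGI code: the model configuration (32 layers, 32 KV heads, head size 128, fp16) and the 20 GiB free-memory figure are hypothetical, chosen only to make the units visible.

    # Minimal sketch of the warmup() sizing arithmetic (hypothetical numbers, not TGI code)
    import torch

    BLOCK_SIZE = 16       # tokens stored per KV-cache block (paged-attention page size)
    num_layers = 32       # hypothetical Llama-2-7B-like configuration
    num_kv_heads = 32
    head_size = 128
    dtype = torch.float16

    dtype_size = torch.tensor([], dtype=dtype).element_size()          # 2 bytes for fp16
    # elements needed for the keys (or values) of one block in one layer
    cache_block_size = BLOCK_SIZE * num_kv_heads * head_size           # 65_536 elements
    # bytes for one block across all layers, keys *and* values (hence the factor of 2)
    total_cache_size = num_layers * cache_block_size * 2 * dtype_size  # 8_388_608 B = 8 MiB

    free_memory = 20 * 1024**3                                         # hypothetical 20 GiB free
    # the real code also adds back the blocks already allocated for the warmup batch
    num_blocks = int((free_memory * 0.95) // total_cache_size)         # 2_432 blocks

    # each block holds BLOCK_SIZE tokens, so the token budget returned by warmup() is
    max_batch_total_tokens = num_blocks * BLOCK_SIZE                   # 38_912 tokens
    print(total_cache_size, num_blocks, max_batch_total_tokens)

Under these assumptions each block costs 8 MiB of KV cache across all layers, and num_blocks * BLOCK_SIZE is the total number of tokens whose keys and values can be resident at once, which appears to be why that product is returned as max_batch_total_tokens.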
@martinigoyanes (Contributor, Author)

Hey @Venkat2811, maybe this is of interest to you!

@Venkat2811 commented May 15, 2024

#1863 (comment)

Yes, I am interested in this, and I'll have to spend some time with the code and the research papers to understand the finer implementation details. Thanks for tagging me, though :)

I'm not from the HF team, just sharing a suggestion: since this is more of a Q&A than an actual issue, consider moving it to Discussions?

@martinigoyanes (Contributor, Author)

You are right! I have created a discussion here: #1897

@martinigoyanes closed this as not planned on May 15, 2024