In [1]:
import torch, math, time, contextlib, os
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.models.llama.modeling_llama import LlamaAttention

In [2]:
# Define Model and datatype for benchmarking
MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
DTYPE    = torch.bfloat16

In [3]:
@contextlib.contextmanager
def timed_cuda(label: str):
    torch.cuda.synchronize()
    t0 = time.time()
    try:
        yield
    finally:
        torch.cuda.synchronize()
        print(f"{label:<25}: {(time.time() - t0)*1e3:8.2f} ms")

In [4]:
class LlamaAttentionKV(LlamaAttention):
    """Same attention, but keeps at most `MAX_CACHE` tokens in the returned cache."""
    MAX_CACHE = 2048

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        **kwargs,
    ):

        attn_output, attn_weights, present_kv = super().forward(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            **kwargs,
        )

        # crop the cache if it exists and is too large
        if present_kv is not None and self.MAX_CACHE > 0:
            if hasattr(present_kv, 'key_cache') and hasattr(present_kv, 'value_cache'):
                # DynamicCache or StaticCache format
                if present_kv.key_cache[0].size(2) > self.MAX_CACHE:
                    present_kv.key_cache[0] = present_kv.key_cache[0][:, :, -self.MAX_CACHE:, :]
                    present_kv.value_cache[0] = present_kv.value_cache[0][:, :, -self.MAX_CACHE:, :]
            elif isinstance(present_kv, tuple) and len(present_kv) == 2:
                k, v = present_kv
                if hasattr(k, 'size') and k.size(2) > self.MAX_CACHE:
                    k = k[:, :, -self.MAX_CACHE:, :]
                    v = v[:, :, -self.MAX_CACHE:, :]
                    present_kv = (k, v)
            else:
                if hasattr(present_kv, '__len__'):
                    print(f"Length: {len(present_kv)}")
                    if len(present_kv) > 0:
                        print(f"First element type: {type(present_kv[0])}")

        return attn_output, attn_weights, present_kv

In [5]:
def main():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, torch_dtype=DTYPE, device_map="auto"
    ).eval()

    for layer in model.model.layers:
        layer.self_attn.__class__ = LlamaAttentionKV
    print("Custom KV cache initialized")

    prompt = "Why do cats purr?"
    def run(use_cache: bool):
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.inference_mode():
            return model.generate(
                **inputs,
                max_new_tokens=256,
                use_cache=use_cache,
                do_sample=False,     # deterministic
                temperature=None, top_p=None,
                pad_token_id=tokenizer.eos_token_id,
            )

    with timed_cuda("No‑cache (baseline)"):
        _ = run(False)

    with timed_cuda("With KV cache"):
        _ = run(True)

In [6]:
main()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Custom KV cache initialized
No‑cache (baseline)      :  6471.98 ms
With KV cache            :  5155.16 ms
