In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache, BitsAndBytesConfig

DynamicCache = KV 캐시 관리 유틸리티 클래스

✅ 자동 position 관리	position ID를 자동으로 이어붙임

✅ 캐시 append	새 입력이 들어오면 KV를 자동으로 확장

✅ 캐시 초기화	토큰 리셋, 잘라내기 등 지원 (메모리 관리에 유리)

✅ 모델 간 호환	다양한 LLM 구조에 맞춰 추상화되어 있음

# preloading: 문서 KV 캐시 생성

In [None]:
model_name = "meta-llama/Llama-3.1-8B-Instruct"
HF_TOKEN = ""

tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
    token=HF_TOKEN
    )

tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/184 [00:00<?, ?B/s]



In [3]:
def preprocess_knowledge(
    model,
    tokenizer,
    prompt: str,
) -> DynamicCache:
    """
    Prepare knowledge kv cache for CAG.
    Args:
        model: HuggingFace model with automatic device mapping
        tokenizer: HuggingFace tokenizer
        prompt: The knowledge to preprocess, which is basically a prompt

    Returns:
        DynamicCache: KV Cache
    """
    embed_device = model.model.embed_tokens.weight.device
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(embed_device)
    past_key_values = DynamicCache()
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            use_cache=True,
            output_attentions=False,
            output_hidden_states=False
        )
    return outputs.past_key_values


In [4]:
doc_text = "Ronan Takizawa is a Colorado College computer science student, cybersecurity researcher, and tech content creator with over 100,000 followers across social media platforms. Ronan Takizawa has built a machine learning boxing analytics app (Punch Analytics), a zero-knowledge proof CI pipeline (Noname), a REST API for international schools, a website automation system for the Ireland-Japan Chamber of Commerce, and a text-to-speech Chrome extension (TeleSpeech) that won HackHarvard 2023. Ronan Takizawa has worked with technologies including Python, TypeScript, Rust, Java, Shell, SQL, React, NodeJS, MongoDB, Docker, Kubernetes, AWS, GCP, and tools like Firebase, OpenCV, and GraphQL."

answer_instruction = "Answer the question with a super short answer."

knowledges = f"""
    <|begin_of_text|>
    <|start_header_id|>system<|end_header_id|>
    You are an assistant for giving short answers based on given context.<|eot_id|>
    <|start_header_id|>user<|end_header_id|>
    Context information is bellow.
    ------------------------------------------------
    {doc_text}
    ------------------------------------------------
    {answer_instruction}
    Question:
    """

kv_cache = preprocess_knowledge(model, tokenizer, knowledges)
kv_len = kv_cache.key_cache[0].shape[-2]

In [5]:
print(f"len key_cache: {len(kv_cache.key_cache)}")
print(f"len value_cache: {len(kv_cache.value_cache)}")
print(f"key_cache.shape: {kv_cache.key_cache[0].shape}")
print(f"value_cache.shape: {kv_cache.value_cache[0].shape}")
print(kv_cache.key_cache[0])
print(kv_cache.value_cache[0])

len key_cache: 32
len value_cache: 32
key_cache.shape: torch.Size([1, 8, 210, 128])
value_cache.shape: torch.Size([1, 8, 210, 128])
tensor([[[[ 5.5859e-01,  9.8438e-01,  1.0635e+00,  ...,  1.3291e+00,
           -9.2102e-02,  4.1577e-01],
          [ 2.1406e+00,  1.8154e+00,  1.7656e+00,  ...,  1.3733e-01,
           -1.6738e+00, -1.3262e+00],
          [-3.7207e+00, -2.3789e+00, -1.4238e+00,  ..., -5.3418e-01,
           -1.5645e+00, -1.9434e+00],
          ...,
          [ 7.0781e+00,  3.8633e+00,  2.7891e+00,  ...,  2.6172e-01,
           -1.5244e+00, -1.8643e+00],
          [ 2.9961e+00,  2.8848e+00,  3.3633e+00,  ...,  1.0938e-01,
           -1.8770e+00, -1.3115e+00],
          [-2.3613e+00,  7.2754e-01,  1.0791e+00,  ..., -5.3418e-01,
           -1.5645e+00, -1.9434e+00]],

         [[ 1.7868e-02, -7.5317e-02,  1.9730e-02,  ..., -1.1904e+00,
            8.1836e-01,  6.0394e-02],
          [ 1.5137e-01, -9.2871e-01,  5.9424e-01,  ...,  1.5156e+00,
           -1.1816e+00, -6.6943e-

# Inference

In [6]:
def clean_up(kv: DynamicCache, origin_len: int):
    """
    Truncate the KV Cache to the original length.
    """
    for i in range(len(kv.key_cache)):
        kv.key_cache[i] = kv.key_cache[i][:, :, :origin_len, :]
        kv.value_cache[i] = kv.value_cache[i][:, :, :origin_len, :]

In [7]:
def generate(
    model,
    input_ids: torch.Tensor,
    past_key_values,
    max_new_tokens: int = 300
) -> torch.Tensor:
    """
    Generate text with greedy decoding.

    Args:
        model: HuggingFace model with automatic device mapping
        input_ids: Input token ids
        past_key_values: KV Cache for knowledge
        max_new_tokens: Maximum new tokens to generate
    """

    embed_device = model.model.embed_tokens.weight.device

    origin_ids = input_ids
    input_ids = input_ids.to(embed_device)

    output_ids = input_ids.clone()
    next_token = input_ids

    with torch.no_grad():
        for _ in range(max_new_tokens):
            outputs = model(
                input_ids=next_token,
                past_key_values=past_key_values,
                use_cache=True
            )
            next_token_logits = outputs.logits[:, -1, :] # (batch_size, sequence_length, vocab_size), 가장 마지막 토큰의 출력(logits)
            next_token = next_token_logits.argmax(dim=-1).unsqueeze(-1)
            next_token = next_token.to(embed_device)

            past_key_values = outputs.past_key_values

            output_ids = torch.cat([output_ids, next_token], dim=1)

            if next_token.item() in model.config.eos_token_id:
                break
    return output_ids[:, origin_ids.shape[-1]:]

In [8]:
question = "Who is Ronan Takizawa?"
prompt = f"""
    {question}<|eot_id|>
    <|start_header_id|>assistant<|end_header_id|>
    """
clean_up(kv_cache, kv_len)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
output = generate(model, input_ids, kv_cache)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True, temperature=None)
generated_text

' Ronan Takizawa is a Colorado College computer science student and cybersecurity researcher with over 100,000 social media followers.'

In [9]:
generated_text

' Ronan Takizawa is a Colorado College computer science student and cybersecurity researcher with over 100,000 social media followers.'

In [10]:
question = "What are his main projects?"
prompt = f"""
    {question}<|eot_id|>
    <|start_header_id|>assistant<|end_header_id|>
    """
clean_up(kv_cache, kv_len)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
output = generate(model, input_ids, kv_cache)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True, temperature=None)
generated_text

' Punch Analytics, Noname, REST API for international schools, website automation system for the Ireland-Japan Chamber of Commerce, and TeleSpeech.'

In [11]:
generated_text

' Punch Analytics, Noname, REST API for international schools, website automation system for the Ireland-Japan Chamber of Commerce, and TeleSpeech.'

In [12]:
question = "What technologies has he worked with?"
prompt = f"""
    {question}<|eot_id|>
    <|start_header_id|>assistant<|end_header_id|>
    """
clean_up(kv_cache, kv_len)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
output = generate(model, input_ids, kv_cache)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True, temperature=None)
generated_text

' Python, TypeScript, Rust, Java, Shell, SQL, React, NodeJS, MongoDB, Docker, Kubernetes, AWS, GCP, Firebase, OpenCV, GraphQL.'

In [13]:
generated_text

' Python, TypeScript, Rust, Java, Shell, SQL, React, NodeJS, MongoDB, Docker, Kubernetes, AWS, GCP, Firebase, OpenCV, GraphQL.'

In [14]:
tokenizer.decode(model.config.eos_token_id)

'<|end_of_text|><|eom_id|><|eot_id|>'