## 1. Imports

In [1]:
!pip install -U bitsandbytes 
!pip install PyMuPDF

import torch
from transformers import (
    AutoTokenizer,
    BitsAndBytesConfig,
    AutoModelForCausalLM)


import bitsandbytes as bnb
from transformers.cache_utils import DynamicCache

Collecting bitsandbytes
  Downloading bitsandbytes-0.45.1-py3-none-manylinux_2_24_x86_64.whl.metadata (5.8 kB)
Downloading bitsandbytes-0.45.1-py3-none-manylinux_2_24_x86_64.whl (69.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.7/69.7 MB[0m [31m25.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.45.1
Collecting PyMuPDF
  Downloading pymupdf-1.25.2-cp39-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (3.4 kB)
Downloading pymupdf-1.25.2-cp39-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (20.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m20.0/20.0 MB[0m [31m78.1 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: PyMuPDF
Successfully installed PyMuPDF-1.25.2


In [2]:
# from huggingface_hub import notebook_login
# notebook_login()
HF_TOKEN = r"YOUR_HF_TOKEN_HERE"
MODEL_NAME = r"meta-llama/Llama-3.2-1B-Instruct"

## 2. Prepare knowledge

In [3]:
import fitz
PDF_PATH = r"Information Retrieval Implementing and Evaluating Search Engines.pdf"
document = fitz.open(PDF_PATH)

text = str()
for page_num in range(len(document)):
    page = document.load_page(page_num)
    text += page.get_text()

In [4]:
x = 50000
knowledge= text[x:x+131070]
print(knowledge)

n a single section from a long
technical manual.
1.3
Working with Electronic Text
13
Other document formats are proprietary, meaning they are associated with the products
of a single software manufacturer. These proprietary formats include Microsoft’s “doc” format.
Until recently, due to the market dominance of Microsoft Oﬃce, this format was widely used for
document exchange and collaboration. Although the technical speciﬁcations for such proprietary
formats are often available, they can be complex and may be modiﬁed substantially from version
to version, entirely at the manufacturer’s discretion. Microsoft and other manufacturers have
now shifted toward XML-based formats (such as the OpenDocument format or Microsoft’s
OOXML), which may ameliorate the complications of indexing.
In practice, HTML may share many of the problems of binary formats. Many HTML pages
include scripts in the JavaScript or Flash programming languages. These scripts may rewrite
the Web page in its entirety and d

## 3. Loading knowledge

In [5]:
def preprocess_knowledge(
    model,
    tokenizer,
    prompt: str) -> DynamicCache:
    """
    Prepare knowledge kv cache for CAG.
    Args:
        model: HuggingFace model with automatic device mapping
        tokenizer: HuggingFace tokenizer
        prompt: The knowledge to preprocess, which is basically a prompt

    Returns:
        DynamicCache: KV Cache
    """
    embed_device = model.model.embed_tokens.weight.device # check which device are used 
    input_ids    = tokenizer.encode(prompt, return_tensors="pt").to(embed_device)
    past_key_values = DynamicCache()
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            use_cache=True,
            output_attentions=False,
            output_hidden_states=False)
    return outputs.past_key_values

## 4. Preparing Knowledge and Creating Key-Value Cache Data


In [6]:
def prepare_kvcache(documents, answer_instruction: str = None):
    # Prepare the knowledges kvcache

    if answer_instruction is None:
        answer_instruction = "Answer the question with a super short answer."

    knowledges = f"""
    <|begin_of_text|>
    <|start_header_id|>system<|end_header_id|>
    You are an medical assistant for giving short answers 
    based on given reports.<|eot_id|>
    <|start_header_id|>user<|end_header_id|>
    Context information is bellow.
    ------------------------------------------------
    {documents}
    ------------------------------------------------
    {answer_instruction}
    Question:
    """
    # Get the knowledge cache
    kv = preprocess_knowledge(model, tokenizer, knowledges)
    kv_len = kv.key_cache[0].shape[-2]
    print("kvlen: ", kv_len)
    return kv, kv_len


# kvlen:  610


## 4.5 past key value cleaning

In [7]:
def clean_up(kv: DynamicCache, origin_len: int):
    """
    Truncate the KV Cache to the original length.
    """
    for i in range(len(kv.key_cache)):
        kv.key_cache[i] = kv.key_cache[i][:, :, :origin_len, :]
        kv.value_cache[i] = kv.value_cache[i][:, :, :origin_len, :]


## 5. Query

In [8]:
def generate(
    model,
    input_ids: torch.Tensor,
    past_key_values,
    max_new_tokens: int = 300
) -> torch.Tensor:
    """
    Generate text with greedy decoding.

    Args:
        model: HuggingFace model with automatic device mapping
        input_ids: Input token ids
        past_key_values: KV Cache for knowledge
        max_new_tokens: Maximum new tokens to generate
    """

    embed_device = model.model.embed_tokens.weight.device

    origin_ids = input_ids # what ?
    input_ids = input_ids.to(embed_device)

    output_ids = input_ids.clone()
    next_token = input_ids

    with torch.no_grad():
        for _ in range(max_new_tokens):
            outputs = model(
                input_ids=next_token, 
                past_key_values=past_key_values,
                use_cache=True
            )
            next_token_logits = outputs.logits[:, -1, :]
            next_token = next_token_logits.argmax(dim=-1).unsqueeze(-1)
            next_token = next_token.to(embed_device)

            past_key_values = outputs.past_key_values

            output_ids = torch.cat([output_ids, next_token], dim=1)

            
            if (next_token.item() in model.config.eos_token_id) and (_ > 0):
                break
    return output_ids[:, origin_ids.shape[-1]:]

## 6. Run it

In [9]:
# Define quantization configuration
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,              # Load model in 4-bit precision
    bnb_4bit_quant_type="nf4",      # Normalize float 4 quantization
    bnb_4bit_compute_dtype=torch.float16,  # Compute dtype for 4-bit base matrices
    bnb_4bit_use_double_quant=True  # Use nested quantization
)


def load_quantized_model(model_name, hf_token=None):
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        token=hf_token
    )

    # Load model with quantization
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto",          # Automatically choose best device
        trust_remote_code=True,     # Required for some models
        token=hf_token
    )

    return tokenizer, model

tokenizer, model = load_quantized_model(model_name=MODEL_NAME, hf_token=HF_TOKEN)


tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/877 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

In [12]:
import time
start_time = time.time()

knowledge_cache, kv_len  = prepare_kvcache(documents =knowledge)
# query = 'which Patient experienced issues with blood glucose meter, what was the problem ?'
query = 'What is this about?.'
clean_up(knowledge_cache, kv_len)
input_ids = tokenizer.encode(query, return_tensors="pt").to(model.device)
output = generate(model, input_ids, knowledge_cache)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True, temperature=None)
print(f"Response of the model:\n {generated_text}")

end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time} seconds")

kvlen:  36016
Response of the model:
 assistant

This is about the cosine similarity and proximity ranking methods used in information retrieval.
Elapsed time: 35.316410303115845 seconds


## 6.5 Save

In [11]:
def write_kv_cache(kv: DynamicCache, path: str):
    """
    Write the KV Cache to a file.
    """
    torch.save(kv, path)

def read_kv_cache(path: str) -> DynamicCache:
    """
    Read the KV Cache from a file.
    """
    kv = torch.load(path, weights_only=True)
    return kv