<a href="https://colab.research.google.com/github/mshojaei77/RAG_CAG_SFT/blob/main/cag.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q -U bitsandbytes

In [None]:
import torch
from transformers import (
    AutoTokenizer,
    BitsAndBytesConfig,
    AutoModelForCausalLM)

import bitsandbytes as bnb
from transformers.cache_utils import DynamicCache

In [None]:
# Interactive Hugging Face Hub login for this notebook session.
# NOTE(review): presumably required so the model download below is authorised —
# confirm whether the chosen model repo actually needs authentication.
from huggingface_hub import notebook_login
notebook_login()

In [None]:
# Checkpoint shared by the tokenizer and the model.
model_id = "huihui-ai/Llama-3.2-3B-Instruct-abliterated"

# 4-bit NF4 weights with double quantization; matmuls are computed in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)

# device_map="auto" lets the loader decide where each module is placed.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=bnb_config,
)

In [None]:
from google.colab import files

# Upload knowledge base (.md or .txt) file(s) directly to Colab
uploaded = files.upload()

# Fail early with a clear message instead of a NameError on `knowledge` later.
if not uploaded:
    raise ValueError("No file was uploaded; a knowledge base file is required.")

uploaded_texts = []
for fn, content in uploaded.items():
    print(f'User uploaded file "{fn}" with length {len(content)} bytes')
    uploaded_texts.append(content.decode('utf-8'))  # Assuming it's UTF-8 text

# Keep the content of every uploaded file (not just the last one); with a
# single upload this is exactly that file's text.
knowledge = "\n\n".join(uploaded_texts)

# Preview: first 80 characters of the knowledge text
print(knowledge[:80])


# Preloading Knowledge

In [None]:
def preprocess_knowledge(model, tokenizer, prompt: str) -> DynamicCache:
    """
    Run the knowledge prompt through the model once and keep its KV cache.

    Args:
        model: HuggingFace causal LM (may be spread over devices).
        tokenizer: HuggingFace tokenizer matching the model.
        prompt: Text holding the knowledge to preload.

    Returns:
        DynamicCache: the `past_key_values` produced by the forward pass.
    """
    # Inputs must live on the same device as the token embedding weights.
    target_device = model.model.embed_tokens.weight.device
    token_ids = tokenizer.encode(prompt, return_tensors="pt")
    token_ids = token_ids.to(target_device)

    forward_kwargs = dict(
        input_ids=token_ids,
        past_key_values=DynamicCache(),  # start from an empty cache
        use_cache=True,
        output_attentions=False,
        output_hidden_states=False,
    )

    # Pure inference: no gradients are needed to fill the cache.
    with torch.no_grad():
        result = model(**forward_kwargs)

    return result.past_key_values

In [None]:
def prepare_kvcache(documents, answer_instruction: str = None):
    """
    Build the knowledge prompt around `documents` and precompute its KV cache.

    Relies on the module-level `model` and `tokenizer`.

    Args:
        documents: Knowledge text inserted into the prompt as context.
        answer_instruction: Instruction placed after the context; a default
            "super short answer" instruction is used when None.

    Returns:
        tuple: (kv, kv_len) where `kv` is the cache returned by
        `preprocess_knowledge` and `kv_len` is the number of cached tokens,
        i.e. the length to truncate back to before each new query.
    """
    if answer_instruction is None:
        answer_instruction = "Answer the question with a super short answer."

    # The prompt deliberately ends at "Question:" so the user query can be
    # appended later on top of the cached prefix.
    knowledges = f"""
    <|begin_of_text|>
    <|start_header_id|>system<|end_header_id|>
    You are an ai assistant for giving short answers
    based on given documents.<|eot_id|>
    <|start_header_id|>user<|end_header_id|>
    Context information is bellow.
    ------------------------------------------------
    {documents}
    ------------------------------------------------
    {answer_instruction}
    Question:
    """
    # Get the knowledge cache
    kv = preprocess_knowledge(model, tokenizer, knowledges)
    # Use the public accessor rather than reaching into the cache's internal
    # per-layer tensor lists to read the sequence length.
    kv_len = kv.get_seq_length()
    print("kvlen: ", kv_len)
    return kv, kv_len


# Precompute the KV cache for the uploaded knowledge and remember its length.
knowledge_cache, kv_len = prepare_kvcache(documents=knowledge)

In [None]:
def clean_up(kv: DynamicCache, origin_len: int):
    """
    Truncate the KV cache in place back to its original length.

    Args:
        kv: Cache to truncate. Either exposes a `crop(max_length)` method or
            `key_cache` / `value_cache` lists of 4-D tensors whose
            second-to-last axis is sliced.
        origin_len: Number of cached positions to keep.

    Returns:
        None. The cache object passed in is modified.
    """
    # Prefer the cache's own cropping method when it has one, so the cache can
    # keep any internal bookkeeping consistent.
    # NOTE(review): presumably `DynamicCache.crop` is available on the installed
    # transformers version — confirm; the fallback below covers caches without it.
    crop = getattr(kv, "crop", None)
    if callable(crop):
        crop(origin_len)
        return

    # Fallback: slice every layer's key/value tensors along the sequence axis.
    for layer, (keys, values) in enumerate(zip(kv.key_cache, kv.value_cache)):
        kv.key_cache[layer] = keys[:, :, :origin_len, :]
        kv.value_cache[layer] = values[:, :, :origin_len, :]

In [None]:
def generate(
    model,
    input_ids: torch.Tensor,
    past_key_values,
    max_new_tokens: int = 300
) -> torch.Tensor:
    """
    Generate text with greedy decoding on top of a preloaded KV cache.

    Args:
        model: HuggingFace model with automatic device mapping
        input_ids: Input token ids, shape (1, seq_len); the EOS check uses
            `.item()`, so only a single sequence is supported.
        past_key_values: KV Cache for knowledge. It is fed to the model at each
            step and replaced by the cache the model returns.
        max_new_tokens: Maximum new tokens to generate

    Returns:
        Tensor of shape (1, n_generated) holding only the newly generated
        token ids (the prompt is not included; a terminating EOS token is).
    """
    embed_device = model.model.embed_tokens.weight.device

    # `eos_token_id` may be a single int, a list of ints, or None depending on
    # the model config; normalise once so the membership test never raises.
    eos_token_id = model.config.eos_token_id
    if eos_token_id is None:
        eos_ids = frozenset()
    elif isinstance(eos_token_id, int):
        eos_ids = frozenset((eos_token_id,))
    else:
        eos_ids = frozenset(int(token_id) for token_id in eos_token_id)

    # The first forward pass consumes the whole query; later passes consume
    # only the token produced by the previous step.
    next_token = input_ids.to(embed_device)
    new_tokens = []

    with torch.no_grad():
        for step in range(max_new_tokens):
            outputs = model(
                input_ids=next_token,
                past_key_values=past_key_values,
                use_cache=True
            )
            # Greedy choice: highest-scoring token at the last position.
            next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
            next_token = next_token.to(embed_device)

            past_key_values = outputs.past_key_values
            new_tokens.append(next_token)

            # An EOS predicted at the very first step does not stop generation.
            if step > 0 and next_token.item() in eos_ids:
                break

    if not new_tokens:
        # max_new_tokens <= 0: nothing generated, return an empty (batch, 0) tensor.
        return next_token[:, :0]
    return torch.cat(new_tokens, dim=1)

In [None]:
query = 'What is Cache-Augmented Generation (CAG)?'

# Truncate the cache back to the knowledge prefix so tokens appended by an
# earlier query do not leak into this one.
clean_up(knowledge_cache, kv_len)
input_ids = tokenizer.encode(query, return_tensors="pt").to(model.device)
output = generate(model, input_ids, knowledge_cache)
# `temperature` is a sampling setting, not a decode() option, so it is not
# passed here; decoding only needs the ids and the special-token flag.
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(f"Response of the model:\n {generated_text}")