## Cache-Augmented Generation
### 🔍 Overview
Retrieval-Augmented Generation (RAG) enhances language models by integrating external knowledge but faces challenges like retrieval latency, errors, and system complexity.

Cache-Augmented Generation (CAG) addresses these by preloading relevant data into the
model's context, leveraging modern LLMs' extended context windows and caching runtime parameters.

This eliminates real-time retrieval during inference, enabling direct response generation.

<img src="https://raw.githubusercontent.com/genieincodebottle/genaicodelab/main/cache_augumeted_generation/images/cag_diagram.png" width="300" height="400" alt="CAG">


### ✨ Advantages of CAG
* **Reduced Latency:** Faster inference by removing real-time retrieval.
* **Improved Reliability:** Avoids retrieval errors and ensures context relevance.
* **Simplified Design:** Offers a streamlined, low-complexity alternative to RAG with comparable or better performance.

### ⚠️ Limitations of CAG
* **Knowledge Size Limits:** Requires fitting all relevant data into the context window, unsuitable for extremely
large datasets.
* **Context Length Issues:** Performance may degrade with very long contexts.

### Step - 1 Install Required Packages

In [1]:
!pip install -qU transformers accelerate bitsandbytes sentence-transformers torch plotly python-dotenv

### Step - 2 Import Libraries

In [14]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
from sentence_transformers import SentenceTransformer
import pandas as pd
import time

# Check if we can use GPU
print(f"GPU available: {torch.cuda.is_available()}")

GPU available: True


### Step - 3 Load the Model

* How to get your Hugging Face token and store it in Google Colab:
  
  * Visit the [Hugging Face Tokens Page](https://huggingface.co/settings/tokens)
  * Create a new token with read access

  * Copy the Huggigface toekn
  * In Google Colab, navigate to the Secret section, add a secret with the name HF_TOKEN, and paste the token you copied in the previous step.
  * <img src="https://raw.githubusercontent.com/genieincodebottle/genaicodelab/main/cache_augumeted_generation/images/google_cloab_secret_page.png" width="300" height="200" alt="Colab Secret Screenshot">

In [15]:
def load_model(model_name="meta-llama/Llama-3.2-1B-Instruct"):
    """Load the language model and tokenizer"""

    # Set up model settings based on available hardware
    model_settings = {
        "device_map": "auto",
        "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
    }

    # Load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(model_name, **model_settings)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    return model, tokenizer

# Load the model
model, tokenizer = load_model()

### Step - 4 Download and Load Dataset

In [16]:
# Download the sample dataset
!wget https://raw.githubusercontent.com/genieincodebottle/genaicodelab/main/cache_augumeted_generation/datasets/sample_qa_dataset.csv

# Load the dataset
df = pd.read_csv("sample_qa_dataset.csv")
print("Dataset preview:")
display(df.head())

--2025-01-22 07:10:59--  https://raw.githubusercontent.com/genieincodebottle/genaicodelab/main/cache_augumeted_generation/datasets/sample_qa_dataset.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 43105 (42K) [text/plain]
Saving to: ‘sample_qa_dataset.csv.8’


2025-01-22 07:11:00 (4.56 MB/s) - ‘sample_qa_dataset.csv.8’ saved [43105/43105]

Dataset preview:


Unnamed: 0,topic,text,sample_question,sample_ground_truth
0,Setting Up a Mobile Device for Company Email,**Setting Up a Mobile Device for Company Email...,"""How do I set up my company email on my mobile...",To set up your company email on your mobile de...
1,Resetting a Forgotten PIN,**Resetting a Forgotten PIN**\n\nIf you have f...,"I forgot my PIN, how can I reset it?","Don't worry, I'm here to help To reset your fo..."
2,Configuring VPN Access for Remote Workers,**Configuring VPN Access for Remote Workers**\...,How do I set up VPN access on my laptop so I c...,To set up VPN access on your laptop and access...
3,Troubleshooting Issues with Microsoft Office,**Troubleshooting Issues with Microsoft Office...,"""My Microsoft Word keeps freezing every time I...",I'd be happy to help you troubleshoot the issu...
4,Setting Up a Conference Call on Cisco Webex,"To set up a conference call on Cisco Webex, fo...",How do I set up a conference call on Cisco Web...,To set up a conference call on Cisco Webex wit...


### Step - 5 Define Helper Functions

In [22]:
def create_prompt(question, context=""):
    """Create a prompt for the model"""
    if context:
        return f"""
        Context information:
        {context}

        Question: {question}
        Answer:"""
    return f"Question: {question}\nAnswer:"

def prepare_kv_cache(
        documents: str,
        instruction: str = None
    ) :
        """Prepare the KV cache for generation."""
        start_time = time.time()

        instruction = instruction or "Answer the question with a short answer."
        prompt = f"""
        <|begin_of_text|>
        <|start_header_id|>system<|end_header_id|>
        You are an assistant for giving short answers based on given context.
        <|eot_id|>
        <|start_header_id|>user<|end_header_id|>
        Context information is below.
        ------------------------------------------------
        {documents}
        ------------------------------------------------
        {instruction}
        Question:
        """
        try:
            input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
            past_key_values = DynamicCache()

            with torch.no_grad():
                outputs = model(
                    input_ids=input_ids,
                    past_key_values=past_key_values,
                    use_cache=True,
                    output_attentions=False,
                    output_hidden_states=False
                )

            if not outputs.past_key_values or len(outputs.past_key_values) == 0:
                raise ValueError("Empty KV cache generated")

            return outputs.past_key_values, time.time() - start_time

        except Exception as e:
            print(f"Error preparing KV cache: {str(e)}")
            return DynamicCache(), time.time() - start_time

def generate_answer(model, tokenizer, prompt, past_key_values, max_tokens=300):
    """Generate an answer using the model"""
    # Encode the prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    origin_ids = input_ids
    output_ids = input_ids.clone()
    next_token = input_ids

    with torch.no_grad():
        for _ in range(max_tokens):
            outputs = model(
              input_ids=next_token,
              past_key_values=past_key_values,
              use_cache=True
            )
            next_token_logits = outputs.logits[:, -1, :]
            next_token = next_token_logits.argmax(dim=-1).unsqueeze(-1)
            next_token = next_token.to(model.device)
            past_key_values = outputs.past_key_values
            output_ids = torch.cat([output_ids, next_token], dim=1)

            if next_token.item() in model.config.eos_token_id:
                break

    output = output_ids[:, origin_ids.shape[-1]:]
    return tokenizer.decode(output[0], skip_special_tokens=True)

def calculate_similarity(text1, text2, similarity_model):
    """Calculate similarity between two texts"""
    # Encode both texts
    embedding1 = similarity_model.encode(text1, convert_to_tensor=True)
    embedding2 = similarity_model.encode(text2, convert_to_tensor=True)

    # Calculate similarity
    similarity = torch.nn.functional.cosine_similarity(embedding1.unsqueeze(0),
                                                     embedding2.unsqueeze(0))
    return similarity.item()

### Step - 6 Test the System

In [23]:
# Initialize similarity model
similarity_model = SentenceTransformer('all-MiniLM-L6-v2')

# Get a sample of questions (adjust number as needed)
num_questions = 5
test_questions = df['sample_question'].iloc[:num_questions].tolist()
ground_truths = df['sample_ground_truth'].iloc[:num_questions].tolist()

# Prepare context (combine all relevant texts)
context = '\n\n'.join(df["text"].tolist())

knowledge_cache, prep_time = prepare_kv_cache(context)

# Store results
results = []

# Process each question
for question, ground_truth in zip(test_questions, ground_truths):
    # Start timing
    start_time = time.time()

    # Create prompt and generate answer
    prompt = create_prompt(question, context)
    answer = generate_answer(model, tokenizer, prompt, knowledge_cache)

    # Calculate generation time
    generation_time = time.time() - start_time

    # Calculate similarity with ground truth
    similarity = calculate_similarity(answer, ground_truth, similarity_model)

    # Store results
    results.append({
        'Question': question,
        'Generated Answer': answer,
        'Ground Truth': ground_truth,
        'Similarity': similarity,
        'Generation Time': generation_time
    })

# Convert results to DataFrame
results_df = pd.DataFrame(results)
display(results_df)

Unnamed: 0,Question,Generated Answer,Ground Truth,Similarity,Generation Time
0,"""How do I set up my company email on my mobile...","""To set up your company email on your mobile ...",To set up your company email on your mobile de...,0.815589,4.682934
1,"I forgot my PIN, how can I reset it?","To reset your PIN, follow the steps outlined ...","Don't worry, I'm here to help To reset your fo...",0.790417,5.030656
2,How do I set up VPN access on my laptop so I c...,"To set up VPN access on your laptop, follow t...",To set up VPN access on your laptop and access...,0.892379,26.6961
3,"""My Microsoft Word keeps freezing every time I...","To fix the issue, try the following steps:\n\...",I'd be happy to help you troubleshoot the issu...,0.771017,32.753245
4,How do I set up a conference call on Cisco Web...,To set up a conference call on Cisco Webex wi...,To set up a conference call on Cisco Webex wit...,0.9752,38.280641


### Step - 7 Visualize Results

In [24]:
import plotly.graph_objects as go

# Create performance visualization
fig = go.Figure()

# Add generation time trace
fig.add_trace(go.Scatter(
    x=list(range(len(results))),
    y=results_df['Generation Time'],
    name='Generation Time',
    mode='lines+markers'
))

# Add similarity trace
fig.add_trace(go.Scatter(
    x=list(range(len(results))),
    y=results_df['Similarity'],
    name='Answer Similarity',
    mode='lines+markers',
    yaxis='y2'
))

# Update layout
fig.update_layout(
    title='Performance Metrics',
    xaxis_title='Question Number',
    yaxis_title='Generation Time (seconds)',
    yaxis2=dict(
        title='Similarity Score',
        overlaying='y',
        side='right'
    )
)

fig.show()

# Print summary statistics
print("\nSummary Statistics:")
print(f"Average Generation Time: {results_df['Generation Time'].mean():.2f} seconds")
print(f"Average Similarity Score: {results_df['Similarity'].mean():.2f}")


Summary Statistics:
Average Generation Time: 21.49 seconds
Average Similarity Score: 0.85
