# Implementing Semantic Cache to Improve a RAG System with FAISS

In this example, we explore a typical RAG solution that uses an open-source model and the Chroma vector database. We integrate a **semantic cache system** that stores user queries and decides whether the information used to enrich the prompt is fetched from the vector database or from the cache.

A **semantic caching system** aims to identify similar or identical user requests. When a matching request is found, the system retrieves the corresponding information from the **cache**, reducing the need to fetch it from the *original source*.

Because the comparison takes into account the semantic meaning of the requests, they do not have to be identical for the system to recognize them as the same question. They can be phrased differently or contain errors, whether typographical or in the sentence structure, and we can still identify that the user is requesting the same information.

For example, the three queries "What is the capital of France?", "Tell me the name of the capital of France?", and "What the capital of France is?" all convey the same intent and should be identified as the same question.

Although the model's response may differ depending on how the request is phrased (for example, when a concise answer is requested), the information retrieved from the vector database should be the same. Therefore, the cache system is inserted between the user and the vector database, instead of between the user and the LLM.
```
Documents --> ChromaDB <--> SemanticCache <-- UserQuery
                                |--> AugmentedPrompt --> LLM
```

To enhance the performance of a RAG system in production, we may need one or more semantic caches. A cache retains the results of previous requests, and before resolving a new request, it checks whether a similar one has been received before. If so, instead of re-executing the process, it retrieves the information from the cache.
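
The check-then-fallback flow described above can be sketched without any external library. In this minimal illustration, `toy_embed` is a hypothetical bag-of-words stand-in for a real sentence encoder, `slow_retrieve` stands in for the vector database, and the threshold value is arbitrary. None of these names come from the libraries used later in this example.

```python
import math
from collections import Counter

def toy_embed(text):
    """Hypothetical stand-in for a sentence encoder: lowercase word counts."""
    return Counter(text.lower().replace('?', '').split())

def distance(a, b):
    """Euclidean distance between two sparse word-count vectors."""
    keys = set(a) | set(b)
    return math.sqrt(sum((a[k] - b[k]) ** 2 for k in keys))

class ToyCache:
    def __init__(self, retrieve, threshold=1.5):
        self.retrieve = retrieve      # slow source, e.g. a vector database
        self.threshold = threshold    # arbitrary value for this illustration
        self.entries = []             # list of (embedding, result) pairs

    def ask(self, question):
        emb = toy_embed(question)
        # Look for the closest previously seen question
        best = min(self.entries, key=lambda e: distance(emb, e[0]), default=None)
        if best is not None and distance(emb, best[0]) <= self.threshold:
            return best[1], 'cache'
        # Cache miss: fetch from the slow source and remember the result
        result = self.retrieve(question)
        self.entries.append((emb, result))
        return result, 'source'

def slow_retrieve(question):
    """Hypothetical placeholder for the vector database lookup."""
    return f"documents for: {question}"

toy = ToyCache(slow_retrieve)
first = toy.ask('What is the capital of France?')    # miss, goes to the source
second = toy.ask('What the capital of France is?')   # same words, served from cache
third = toy.ask('How do vaccines work?')             # unrelated, goes to the source
```

A real encoder captures meaning rather than word overlap, but the control flow stays the same: embed, look up the nearest stored question, and fall back to the source only when nothing is close enough.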

In a RAG system, there are two points that are time consuming:
* Retrieving the information used to construct the enriched prompt
* Calling the LLM to obtain the response

A semantic cache can be implemented at either point, and we could even have two caches, one for each. Note that placing the cache at the model's response point may reduce our control over the response obtained. For example, our cache system may consider "Explain the French Revolution in 10 words" and "Explain the French Revolution in 100 words" as the same query. If the cache stores model responses, the responses may not follow the user's instructions accurately.

In this example, we will place the semantic cache system between the user's request and the information retrieval from the vector database.

## Setup

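In this example, we will need:
* `sentence_transformers` - transforms sentences into fixed-length vectors, also known as embeddings
* `xformers` - provides optimized building blocks for working with `transformers` models
* `chromadb` - the vector database
* `accelerate` - handles device placement so the model can run on a GPU
* `faiss-cpu` - the in-memory similarity index used for the cache
* `datasets` - loads the dataset used in this example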

In [None]:
!pip install -qU transformers accelerate sentence-transformers xformers chromadb datasets faiss-cpu torch

In [None]:
import numpy as np
import pandas as pd

## Load the Dataset

In [None]:
from datasets import load_dataset

data = load_dataset(
    'keivalya/MedQuad-MedicalQnADataset',
    split='train'
)

ChromaDB requires that the data has a unique identifier.

In [None]:
data = data.to_pandas()
data['id'] = data.index
data.head(5)

In [None]:
# some constants
MAX_ROWS = 15000
DOCUMENT = 'Answer'
TOPIC = 'qtype'

subset_data = data.head(MAX_ROWS)

## Import and configure the vector database

In [None]:
import chromadb

# set the path where the vector database will be stored
chroma_client = chromadb.PersistentClient(path='./chroma_db')

## Fill and query the chromaDB database

The data in ChromaDB is stored in collections. If a collection with the same name already exists, we delete it first so that the example starts from a clean state.

In [None]:
collection_name = 'news_collection'

# list_collections() returns Collection objects or plain names depending on the Chroma version
existing = [getattr(c, 'name', c) for c in chroma_client.list_collections()]
if collection_name in existing:
    chroma_client.delete_collection(name=collection_name)

collection = chroma_client.create_collection(name=collection_name)

Now we will add the data to the collection with the following information:
* `documents` to store the content of the `Answer` column in the dataset
* `metadatas` to store a list of topics, here the `qtype` column
* `ids` to store a unique identifier for each row

In [None]:
collection.add(
    documents=subset_data[DOCUMENT].tolist(),
    metadatas=[{TOPIC: topic} for topic in subset_data[TOPIC].tolist()],
    ids=[f"id{x}" for x in range(len(subset_data))]
)
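
`collection.add` expects the three lists to line up one-to-one and the ids to be unique. A small sanity check before inserting can catch mistakes early. This is a plain-Python sketch: `validate_batch` is a hypothetical helper, not part of ChromaDB, and the sample values are placeholders.

```python
def validate_batch(documents, metadatas, ids):
    """Raise ValueError if the lists passed to collection.add are inconsistent."""
    if not (len(documents) == len(metadatas) == len(ids)):
        raise ValueError(
            f"length mismatch: {len(documents)} documents, "
            f"{len(metadatas)} metadatas, {len(ids)} ids"
        )
    if len(set(ids)) != len(ids):
        raise ValueError('ids must be unique')
    return True

# Placeholder data shaped like the lists built from the dataset
docs = ['answer one', 'answer two', 'answer three']
metas = [{'qtype': 'symptoms'}, {'qtype': 'treatment'}, {'qtype': 'causes'}]
ids = [f"id{x}" for x in range(len(docs))]

validate_batch(docs, metas, ids)  # passes silently
```

A malformed `ids` list, for instance one that accidentally contains a single string, would be reported here as a length mismatch.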

Once we have the information in the database, we can query it and ask for data that matches our needs. The search is performed against the content of the documents, and the results are ranked by the similarity between the search terms and that content.

Metadata is not directly involved in the initial search process, but it can be used to filter or refine the results after retrieval, enabling further customization and precision.

In [None]:
def query_database(query_text, n_results=10):
    results = collection.query(
        query_texts=query_text,
        n_results=n_results
    )
    return results
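
Chroma's `query` method also accepts a `where` argument that restricts the results by metadata. The following is a minimal sketch, assuming the `qtype` metadata key stored above. `query_by_topic` is a hypothetical helper, and the collection is passed in explicitly so the function does not depend on a global.

```python
def query_by_topic(collection, query_text, topic, n_results=10):
    """Query a Chroma collection, keeping only documents whose
    `qtype` metadata equals `topic`."""
    return collection.query(
        query_texts=[query_text],
        n_results=n_results,
        where={'qtype': topic},  # metadata filter
    )
```

It could be called as `query_by_topic(collection, 'fever in children', 'symptoms', n_results=3)`, where `'symptoms'` is only an example value for `qtype`.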

## Create the semantic cache system

To implement the cache system, we will use FAISS, a library for storing and searching embeddings in memory. It is similar to what Chroma does, but without persistence. Here we create a class called `SemanticCache` that works with its own encoder and provides the functions the user needs to perform queries.

In this class, we first query the FAISS-backed cache, which contains the previous requests. If the nearest match falls within a specified distance threshold, the class returns the cached content. Otherwise, it fetches the result from the Chroma database and adds it to the cache. The cache is stored in a JSON file.

In [None]:
import faiss
from sentence_transformers import SentenceTransformer
import time
import json

We will implement an `init_cache()` function, which uses the `FlatL2` index. We choose this index because it fits the example well: it can be used with high-dimensional vectors, consumes little memory, and performs well with small datasets.

There are other indexing options available with FAISS:
* `FlatL2` or `FlatIP` - well suited for small datasets; it may not be the fastest, but its memory consumption is not excessive
* `LSH` - works effectively with small datasets and is recommended for use with vectors of up to 128 dimensions
* `HNSW` - very fast but demands a substantial amount of RAM
* `IVF` - works well with large datasets without consuming much memory or compromising performance
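
Conceptually, a flat index compares the query against every stored vector. The dependency-free sketch below illustrates that exhaustive search. It is not FAISS code, and `flat_l2_search` is a hypothetical name. Like `IndexFlatL2`, it reports squared L2 distances, and it uses an index of -1 to signal a missing neighbor.

```python
def flat_l2_search(stored, query, k=1):
    """Brute-force k-nearest-neighbor search using squared L2 distance.

    Returns (distances, indices). When fewer than k vectors are stored,
    the result is padded with (inf, -1).
    """
    scored = sorted(
        (sum((q - s) ** 2 for q, s in zip(query, vec)), i)
        for i, vec in enumerate(stored)
    )[:k]
    missing = k - len(scored)
    distances = [d for d, _ in scored] + [float('inf')] * missing
    indices = [i for _, i in scored] + [-1] * missing
    return distances, indices

stored = [[0.0, 0.0], [1.0, 1.0], [3.0, 4.0]]
D, I = flat_l2_search(stored, [1.0, 2.0], k=2)
print(D, I)
```

Because every vector is visited, the cost grows linearly with the number of stored entries, which is acceptable for a small cache.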

In [None]:
def init_cache():
    index = faiss.IndexFlatL2(768)
    if index.is_trained:
        print('Index trained')

    # Initialize sentence transformer model
    encoder = SentenceTransformer('all-mpnet-base-v2')

    return index, encoder

We also need a `retrieve_cache` function to retrieve a JSON file from disk in case there is a need to reuse the cache across sessions.

In [None]:
def retrieve_cache(json_file):
    try:
        with open(json_file, 'r') as file:
            cache = json.load(file)
    except FileNotFoundError:
        cache = {'questions': [], 'embeddings': [], 'answers': [], 'response_text': []}

    return cache

The `store_cache` function saves the file containing the cache data to disk.

In [None]:
def store_cache(json_file, cache):
    with open(json_file, 'w') as file:
        json.dump(cache, file)
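
The JSON file persists the questions, embeddings, and answers, but a FAISS index created in memory starts empty in each session. One possible way to make a reloaded cache searchable again is to re-add the stored embeddings to the index. This is only a sketch: `rebuild_index` is a hypothetical helper that is not part of the original notebook, and it assumes the index exposes FAISS's `add(array)` method.

```python
import numpy as np

def rebuild_index(index, cache):
    """Add every embedding stored in the cache dict back into a vector index.

    `index` is assumed to expose an `add(array)` method that takes a
    2-D float32 array, as FAISS indexes do.
    """
    if cache['embeddings']:
        vectors = np.asarray(cache['embeddings'], dtype='float32')
        index.add(vectors)
    return index
```

With the functions above, this could be used as `rebuild_index(index, retrieve_cache('cache_file.json'))` right after the index is created.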

In [None]:
class SemanticCache:
    def __init__(self, json_file='cache_file.json', threshold=0.35, max_response=100, eviction_policy=None):
        """Initialize the semantic cache

        Parameters
        ----------
        json_file: str
            The name of the JSON file where the cache is stored
        threshold: float
            The threshold for the Euclidean distance to determine if a question is similar
        max_response: int
            The maximum number of responses the cache can store
        eviction_policy: str
            The policy for evicting items from the cache.
            This can be any policy, but 'FIFO' (First In First Out) has been implemented for now.
            If None, no eviction policy will be applied
        """
        # Initialize FAISS index with Euclidean distance
        self.index, self.encoder = init_cache()

        # Set Euclidean distance threshold
        # A distance of 0 means identical sentences
        # (IndexFlatL2 reports squared L2 distances)
        # We only return from cache sentences under this threshold
        self.euclidean_threshold = threshold

        self.json_file = json_file
        self.cache = retrieve_cache(self.json_file)
        self.max_response = max_response
        self.eviction_policy = eviction_policy

    def evict(self):
        """Evict items from the cache based on the eviction policy"""
        if self.eviction_policy and len(self.cache['questions']) > self.max_response:
            for _ in range(len(self.cache['questions']) - self.max_response):
                if self.eviction_policy == 'FIFO':
                    self.cache['questions'].pop(0)
                    self.cache['embeddings'].pop(0)
                    self.cache['answers'].pop(0)
                    self.cache['response_text'].pop(0)
                    # keep the FAISS index rows aligned with the cache rows
                    self.index.remove_ids(np.array([0], dtype='int64'))

    def ask(self, question: str) -> str:
        # Method to retrieve an answer from the cache or generate a new one
        start_time = time.time()
        try:
            # First we obtain the embeddings corresponding to the user question
            embedding = self.encoder.encode([question])

            # Search for the nearest neighbor in the index
            # (a flat index is searched exhaustively, so no nprobe setting is needed)
            D, I = self.index.search(embedding, 1)

            if D[0][0] >= 0:
                if I[0][0] >= 0 and D[0][0] <= self.euclidean_threshold:
                    row_id = int(I[0][0])

                    print('Answer recovered from Cache.')
                    print(f"{D[0][0]:.3f} smaller than {self.euclidean_threshold}")
                    print(f"Found cache in row: {row_id} with score {D[0][0]:.3f}")
                    print(f"response_text: {self.cache['response_text'][row_id]}")

                    end_time = time.time()
                    elapsed_time = end_time - start_time
                    print(f"Time taken: {elapsed_time:.3f} seconds")

                    return self.cache['response_text'][row_id]

            # Handle the case when there are not enough results
            # or Euclidean distance is not met, asking to chromaDB
            answer = query_database([question], 1)
            response_text = answer['documents'][0][0]

            self.cache['questions'].append(question)
            self.cache['embeddings'].append(embedding[0].tolist())
            self.cache['answers'].append(answer)
            self.cache['response_text'].append(response_text)

            print('Answer recovered from ChromaDB.')
            print(f"response_text: {response_text}")

            self.index.add(embedding)
            self.evict()

            store_cache(self.json_file, self.cache)
            end_time = time.time()
            elapsed_time = end_time - start_time
            print(f"Time taken: {elapsed_time:.3f} seconds")

            return response_text

        except Exception as e:
            # chain the original exception so the full traceback is preserved
            raise RuntimeError(f"Error during 'ask' method: {e}") from e
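
The FIFO policy can also be looked at in isolation. The sketch below works on a plain dict of parallel lists shaped like the cache above. `evict_fifo` is a hypothetical helper and the sample entries are placeholders. Any vector index built over the same rows has to drop the same entries, otherwise the row ids returned by a search no longer match the cache rows.

```python
def evict_fifo(cache, max_items):
    """Drop the oldest entries until at most `max_items` remain.

    `cache` is a dict of parallel lists; all lists are trimmed together
    so that row i keeps referring to the same question in each of them.
    Returns the number of evicted rows.
    """
    overflow = len(cache['questions']) - max_items
    if overflow > 0:
        for key in cache:
            del cache[key][:overflow]
    return max(overflow, 0)

toy_cache = {
    'questions': ['q1', 'q2', 'q3', 'q4'],
    'embeddings': [[0.1], [0.2], [0.3], [0.4]],
    'answers': ['a1', 'a2', 'a3', 'a4'],
    'response_text': ['r1', 'r2', 'r3', 'r4'],
}
removed = evict_fifo(toy_cache, max_items=3)  # the oldest row, 'q1', is dropped
```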

Test the `SemanticCache` class

In [None]:
cache = SemanticCache('4cache.json')

In [None]:
results = cache.ask('How do vaccines work?')

If we send a second question that is quite different, the response should also be retrieved from ChromaDB. This is because the previously stored question is so dissimilar that its Euclidean distance exceeds the specified threshold.

In [None]:
results = cache.ask('Explain briefly what is a Sydenham chorea')

Now we test it with a question very similar to the one we just asked. The response should come directly from the cache, without the need to access the ChromaDB database.

In [None]:
results = cache.ask('Briefly explain me what is a Sydenham chorea')

The previous two questions are so similar that their Euclidean distance is truly minimal, almost as if they were identical.

Now let's ask a more distinct question:

In [None]:
question_ref = 'Write in 20 words what is a Sydenham chorea'
results = cache.ask(question_ref)

We see that the Euclidean distance has increased, but it still remains within the specified threshold.
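
To get a feel for what a given threshold means, the distance can be translated into a cosine similarity. For two unit-length vectors, the squared L2 distance `d` and the cosine similarity are related by `d = 2 - 2 * cos`. The sketch below assumes that the encoder's embeddings are unit-normalized (worth verifying for the model in use) and that the distance is the squared L2 value reported by `IndexFlatL2`. Both function names are hypothetical.

```python
def squared_l2_to_cosine(squared_distance):
    """Cosine similarity of two unit-length vectors, given their squared L2 distance."""
    return 1.0 - squared_distance / 2.0

def cosine_to_squared_l2(cosine):
    """Inverse conversion, again valid only for unit-length vectors."""
    return 2.0 - 2.0 * cosine

print(squared_l2_to_cosine(0.35))
```

Under those assumptions, the default threshold of 0.35 corresponds to a cosine similarity of about 0.825.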

## Load the model and create the prompt

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# use the current GPU if one is available, otherwise fall back to the CPU
device = f"cuda:{torch.cuda.current_device()}" if torch.cuda.is_available() else 'cpu'

model_id = 'google/gemma-2b-it'

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map=device,
    torch_dtype=torch.bfloat16
)

## Create the extended prompt

To create the prompt, we use the result from querying the `SemanticCache` class and the question entered by the user.

The prompt has two parts: the **relevant context**, which is the information recovered from the database, and the **user's question**. We only need to put the two parts together to create the prompt and then send it to the model.

In [None]:
prompt_template = f"Relevant context: {results}\n\nThe user's question: {question_ref}"
prompt_template
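
Wrapping the template in a function makes it reusable for any context and question pair. `build_prompt` is a hypothetical helper name, the template text is the one used above, and the context string is a placeholder.

```python
def build_prompt(context, question):
    """Combine the retrieved context and the user's question into one prompt string."""
    return f"Relevant context: {context}\n\nThe user's question: {question}"

example = build_prompt(
    '<retrieved document text>',  # placeholder for the cache or database result
    'Write in 20 words what is a Sydenham chorea',
)
print(example)
```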

In [None]:
input_ids = tokenizer(
    prompt_template,
    return_tensors='pt'
).to(model.device)

outputs = model.generate(
    **input_ids,
    max_new_tokens=256
)
print(tokenizer.decode(outputs[0]))
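
For decoder-only models, `generate` returns the prompt tokens followed by the newly generated ones, so the decoded text above repeats the prompt. A possible way to show only the model's answer is to slice off the prompt before decoding. This is a sketch, and `decode_new_tokens` is a hypothetical helper.

```python
def decode_new_tokens(tokenizer, output_ids, prompt_length):
    """Decode only the tokens generated after the prompt.

    `output_ids` is one generated sequence (prompt tokens followed by new
    tokens) and `prompt_length` is the number of prompt tokens.
    """
    new_tokens = output_ids[prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
```

With the variables defined above, it could be called as `decode_new_tokens(tokenizer, outputs[0], input_ids['input_ids'].shape[1])`.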