This solution accelerator notebook is available at [Databricks Industry Solutions](https://github.com/databricks-industry-solutions/semantic-caching).

# Create and deploy a RAG chain with semantic caching

In this notebook, we will build a RAG chatbot with semantic caching. We’ll use [Mosaic AI Vector Search](https://docs.databricks.com/en/generative-ai/vector-search.html) as the cache, taking advantage of its high-performance similarity search. In the following cells, we will create and warm the cache, build a chain with a semantic caching layer, log and register it using MLflow and Unity Catalog, and finally deploy it behind a [Databricks Mosaic AI Model Serving](https://docs.databricks.com/en/machine-learning/model-serving/index.html) endpoint.

## Cluster configuration
We recommend using a cluster with the following specifications to run this solution accelerator:
- Unity Catalog-enabled cluster
- Databricks Runtime 15.4 LTS ML or above
- Single-node cluster: e.g. `m6id.2xlarge` on AWS or `Standard_D8ds_v4` on Azure Databricks.

In [0]:
%pip install -r requirements.txt --quiet
dbutils.library.restartPython()

In [0]:
from config import Config
config = Config()

In [0]:
%run ./99_init $reset_all_data=false

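Here, we read the workspace URL and an API token from the notebook context (`HOST` and `TOKEN`) and export them as the environment variables `DATABRICKS_HOST` and `DATABRICKS_TOKEN`, so that the chain, including when it runs on our Model Serving endpoint, can authenticate against our Vector Search index.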

In [0]:
import os

HOST = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().get()
TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()

os.environ['DATABRICKS_HOST'] = HOST
os.environ['DATABRICKS_TOKEN'] = TOKEN
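Because `chain/chain_cache.py` reads these variables with `os.environ[...]`, a missing variable only surfaces as a bare `KeyError`. Below is a minimal sketch of a fail-fast check you could run before building clients; the helper `read_databricks_credentials` is hypothetical and not part of this accelerator.

```python
import os
from typing import Mapping, Tuple


def read_databricks_credentials(env: Mapping[str, str] = os.environ) -> Tuple[str, str]:
    """Return (host, token) from the environment, failing fast with a clear message.

    Hypothetical helper for illustration; the accelerator itself reads
    os.environ['DATABRICKS_HOST'] and os.environ['DATABRICKS_TOKEN'] directly.
    """
    missing = [name for name in ("DATABRICKS_HOST", "DATABRICKS_TOKEN") if not env.get(name)]
    if missing:
        raise RuntimeError(f"Missing environment variable(s): {', '.join(missing)}")
    # Strip a trailing slash so URLs built from the host do not contain "//".
    return env["DATABRICKS_HOST"].rstrip("/"), env["DATABRICKS_TOKEN"]
```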

## Create and warm a cache 

We instantiate a Vector Search client to interact with a Vector Search endpoint and create the cache. The cache is an additional Vector Search index that stores previously answered questions: when an incoming question is similar enough to a cached one (a cache hit), the stored answer is returned immediately, without running the RAG chain.

In [0]:
from databricks.vector_search.client import VectorSearchClient
from cache import Cache

# Create a Vector Search Client
vsc = VectorSearchClient(
    workspace_url=HOST,
    personal_access_token=TOKEN,
    disable_notice=True,
    )

# Initialize the cache
semantic_cache = Cache(vsc, config)
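Conceptually, a semantic cache lookup embeds the incoming question, finds the most similar cached question, and returns its answer only if the similarity clears a threshold; on a miss, the chain runs and the new pair is stored. The sketch below illustrates that flow with an in-memory list, a toy bag-of-words embedding, and cosine similarity. These are stand-ins chosen to keep the example self-contained, not what `Cache` or Vector Search does internally: the real cache uses an embedding model and a Vector Search index, and its score scale and threshold semantics may differ. The returned dict shape (`question`, plus an empty `answer` on a miss) is inferred from the router defined later in `chain/chain_cache.py`.

```python
import math
import re
from collections import Counter
from typing import Dict, List, Tuple


def embed(text: str) -> Counter:
    # Toy bag-of-words "embedding"; a real cache would call an embedding model.
    return Counter(re.findall(r"[a-z0-9']+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    dot = sum(count * b[token] for token, count in a.items())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


class InMemorySemanticCache:
    """Hypothetical stand-in for the accelerator's Cache class (illustration only)."""

    def __init__(self, threshold: float):
        self.threshold = threshold
        self.entries: List[Tuple[Counter, str, str]] = []  # (embedding, question, answer)

    def store_in_cache(self, question: str, answer: str) -> None:
        self.entries.append((embed(question), question, answer))

    def get_from_cache(self, question: str) -> Dict[str, str]:
        query = embed(question)
        best_score, best_answer = 0.0, ""
        for vector, _, answer in self.entries:
            score = cosine(query, vector)
            if score > best_score:
                best_score, best_answer = score, answer
        # On a miss, return an empty answer so a router can fall through to the RAG chain.
        hit = best_score >= self.threshold
        return {"question": question, "answer": best_answer if hit else ""}


cache = InMemorySemanticCache(threshold=0.6)
cache.store_in_cache("What is Model Serving?", "<cached answer>")
print(cache.get_from_cache("what is model serving")["answer"])  # hit: the stored answer
print(repr(cache.get_from_cache("How do I create a Delta table?")["answer"]))  # miss: ''
```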

We first delete the cache if it already exists.

In [0]:
semantic_cache.clear_cache()

We then create a cache.

In [0]:
semantic_cache.create_cache()

Finally, we warm the cache with predefined Q&A pairs from `/data/synthetic_qa.txt`. This synthetic dataset contains questions that have already been answered, so similar incoming questions can hit the cache from the start.

In [0]:
semantic_cache.warm_cache()
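Warming amounts to bulk-inserting known Q&A pairs before any traffic arrives. The layout of `synthetic_qa.txt` is not shown in this notebook, so the sketch below assumes, purely for illustration, JSON Lines input with `question` and `answer` keys, and a cache object exposing a `store_in_cache(question, answer)` method like the one the chain calls. `warm_from_jsonl` is a hypothetical name, not the implementation of `warm_cache()`.

```python
import json
from typing import Iterable, Protocol


class SupportsStore(Protocol):
    def store_in_cache(self, question: str, answer: str) -> None: ...


def warm_from_jsonl(cache: SupportsStore, lines: Iterable[str]) -> int:
    """Store every Q&A pair from JSON Lines text and return how many were stored.

    Assumption: one JSON object per line with "question" and "answer" keys;
    this is not necessarily the format of /data/synthetic_qa.txt.
    """
    count = 0
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        record = json.loads(line)
        cache.store_in_cache(question=record["question"], answer=record["answer"])
        count += 1
    return count
```

Usage would look like `with open("qa.jsonl") as f: warm_from_jsonl(some_cache, f)`, where the file name and cache object are placeholders.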

## Create and register a chain to MLflow 

The next cell defines our RAG chain with a semantic cache using LangChain. When executed, it writes the cell contents to the `chain/chain_cache.py` file, which is then used to log the chain in MLflow.

In [0]:
%%writefile chain/chain_cache.py
from databricks.vector_search.client import VectorSearchClient
from langchain_community.vectorstores import DatabricksVectorSearch
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.chat_models import ChatDatabricks
from operator import itemgetter
from datetime import datetime
from uuid import uuid4
import os
import mlflow
from cache import Cache
from config import Config


# Set up logging
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logging.getLogger("py4j.java_gateway").setLevel(logging.ERROR)
logging.getLogger("py4j.clientserver").setLevel(logging.ERROR)

# Enable MLflow Tracing
mlflow.langchain.autolog()

# Get configuration
config = Config()

# Connect to Vector Search
vsc = VectorSearchClient(
    workspace_url=os.environ['DATABRICKS_HOST'],
    personal_access_token=os.environ['DATABRICKS_TOKEN'],
    disable_notice=True,
)

# Get the Vector Search index
vs_index = vsc.get_index(
    index_name=config.VS_INDEX_FULLNAME,
    endpoint_name=config.VECTOR_SEARCH_ENDPOINT_NAME,
    )

# Instantiate a Cache object
semantic_cache = Cache(vsc, config)

# Turn the Vector Search index into a LangChain retriever
vector_search_as_retriever = DatabricksVectorSearch(
    vs_index,
    text_column="content",
    columns=["id", "content", "url"],
).as_retriever(search_kwargs={"k": 3}) # Number of search results that the retriever returns

# Method to retrieve the context from the Vector Search index
def retrieve_context(qa):
    return vector_search_as_retriever.invoke(qa["question"])

# Let the Review App and MLflow properly track and display retrieved chunks for evaluation
mlflow.models.set_retriever_schema(primary_key="id", text_column="content", doc_uri="url")

# Method to format the docs returned by the retriever into the prompt (keep only the text from chunks)
def format_context(docs):
    chunk_contents = [f"Passage: {d.page_content}\n" for d in docs]
    return "".join(chunk_contents)

# Create a prompt template for response generation
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", f"{config.LLM_PROMPT_TEMPLATE}"),
        ("user", "{question}"),
    ]
)

# Define our foundation model answering the final prompt
model = ChatDatabricks(
    endpoint=config.LLM_MODEL_SERVING_ENDPOINT_NAME,
    extra_params={"temperature": 0.01, "max_tokens": 500}
)

# Call the foundation model, then store the new question-answer pair in the cache
def call_model(prompt):
    response = model.invoke(prompt)
    semantic_cache.store_in_cache(
        question=prompt.messages[1].content,  # messages[0] is the system prompt, messages[1] the user question
        answer=response.content,
    )
    return response

# Return the content of the most recent message (the user's question) to use as the input question
def extract_user_query_string(chat_messages_array):
    return chat_messages_array[-1]["content"]

# Router: on a cache miss (empty answer), run the RAG chain; on a hit, return the cached answer
def router(qa):
    if qa["answer"] == "":
        return rag_chain
    else:
        return qa["answer"]

# RAG chain
rag_chain = (
    {
        "question": lambda x: x["question"],
        "context": RunnablePassthrough()
        | RunnableLambda(retrieve_context)
        | RunnableLambda(format_context),
    }
    | prompt
    | RunnableLambda(call_model)
)

# Full chain with cache
full_chain = (
    itemgetter("messages")
    | RunnableLambda(extract_user_query_string)
    | RunnableLambda(semantic_cache.get_from_cache)
    | RunnableLambda(router)
    | StrOutputParser()
)

# Tell MLflow logging where to find your chain.
mlflow.models.set_model(model=full_chain)
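Stripped of LangChain, the control flow of `full_chain` is: take the last message, look it up in the cache, and either return the cached answer or run the RAG chain, whose new answer is stored for next time. The plain-Python sketch below mirrors that flow to make the routing explicit. `answer_with_cache`, `lookup`, `run_rag`, and `store` are hypothetical names standing in for `semantic_cache.get_from_cache`, `rag_chain`, and `semantic_cache.store_in_cache`; in `chain_cache.py` the store step happens inside `call_model` rather than in the router.

```python
from typing import Callable, Dict, List


def extract_user_query_string(chat_messages: List[Dict[str, str]]) -> str:
    # Same idea as in chain_cache.py: the latest message holds the user's question.
    return chat_messages[-1]["content"]


def answer_with_cache(
    request: Dict[str, List[Dict[str, str]]],
    lookup: Callable[[str], Dict[str, str]],
    run_rag: Callable[[str], str],
    store: Callable[[str, str], None],
) -> str:
    """Hypothetical plain-Python rendering of full_chain's routing (illustration only)."""
    question = extract_user_query_string(request["messages"])
    cached = lookup(question)
    if cached["answer"]:
        return cached["answer"]  # cache hit: skip retrieval and generation
    answer = run_rag(question)  # cache miss: retrieve context and call the LLM
    store(question, answer)  # make the new pair available to future lookups
    return answer
```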

In this cell, we log the chain to MLflow. Note that this time we pass `cache.py` and `utils.py` along with `config.py` as dependencies, so that the custom classes and functions the chain needs are available when it is loaded in another compute environment or on a Model Serving endpoint. MLflow returns a trace of the inference that shows a detailed breakdown of the latency and the input and output of each step in the chain.

In [0]:
# Log the model to MLflow
config_file_path = "config.py"
cache_file_path = "cache.py"
utils_file_path = "utils.py"

with mlflow.start_run(run_name="rag_chatbot"):
    logged_chain_info = mlflow.langchain.log_model(
        lc_model=os.path.join(os.getcwd(), 'chain/chain_cache.py'),  # Chain code file, e.g., /path/to/the/chain.py
        artifact_path="chain",  # Required by MLflow
        input_example=config.INPUT_EXAMPLE,  # MLflow executes the chain before logging and captures its output schema
        code_paths=[cache_file_path, config_file_path, utils_file_path],  # Custom modules the chain imports at load time
    )

# Test the chain locally
chain = mlflow.langchain.load_model(logged_chain_info.model_uri)
chain.invoke(config.INPUT_EXAMPLE)

Let's ask the chain a question for which we know no similar question has been asked before, so it does not exist in the cache. The trace shows that the entire chain is executed.

In [0]:
chain.invoke({'messages': [{'content': "How does Databricks' feature Genie automate feature engineering for machine learning models?", 'role': 'user'}]})

If we rephrase the question without changing its meaning, we get the response from the cache, because the previous question and its answer were upserted into the cache when the chain ran. The trace confirms this, and the execution time is less than half.

In [0]:
chain.invoke({'messages': [{'content': "What is the role of Databricks' feature Genie in automating feature engineering for machine learning models?", 'role': 'user'}]})

Where to set the similarity threshold (`0.01` in this demo, defined in `config.py`) is arguably the most important design decision for your solution. A threshold that is too high reduces the hit rate and undermines the benefit of semantic caching, while a threshold that is too low can produce many false positives. To strike the right balance, refer to the exploratory data analysis performed in the `00_introduction` notebook.
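One way to make this choice concrete is to sweep candidate thresholds over a small labeled set of (similarity score, same meaning?) pairs and compare the hit rate with the number of false positives at each setting. The sketch below assumes that higher scores mean more similar and that a lookup hits when the score is at or above the threshold; check how `cache.py` compares scores, and note that the score scale depends on the index's similarity metric. `sweep_thresholds` and the toy pairs are hypothetical; real labels would come from an analysis like the one in `00_introduction`.

```python
from typing import List, Sequence, Tuple


def sweep_thresholds(
    labeled: Sequence[Tuple[float, bool]],
    thresholds: Sequence[float],
) -> List[Tuple[float, float, int]]:
    """For each threshold, return (threshold, hit_rate, false_positives).

    labeled: (score of the best cached match, True if that match really has the
    same meaning as the incoming question). Assumes a hit means score >= threshold.
    """
    results = []
    for t in thresholds:
        hits = [same for score, same in labeled if score >= t]
        hit_rate = len(hits) / len(labeled) if labeled else 0.0
        false_positives = sum(1 for same in hits if not same)
        results.append((t, hit_rate, false_positives))
    return results


# Made-up pairs, for illustration only.
toy = [(0.95, True), (0.80, True), (0.75, False), (0.40, False)]
for t, rate, fp in sweep_thresholds(toy, [0.5, 0.78, 0.9]):
    print(f"threshold={t:.2f} hit_rate={rate:.2f} false_positives={fp}")
```

Raising the threshold in a sweep like this trades hit rate for fewer false positives, which is exactly the balance described above.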

If we are happy with the chain, we register it in Unity Catalog.

In [0]:
# Register to UC
uc_registered_model_info = mlflow.register_model(model_uri=logged_chain_info.model_uri, name=config.MODEL_FULLNAME_CACHE)

## Deploy the chain to a Model Serving endpoint

We deploy the chain using custom functions defined in the `utils.py` script.

In [0]:
import utils
utils.deploy_model_serving_endpoint(
  spark, 
  config.MODEL_FULLNAME_CACHE,
  config.CATALOG_CACHE,
  config.LOGGING_SCHEMA_CACHE,
  config.ENDPOINT_NAME_CACHE,
  HOST,
  TOKEN,
  )

Wait until the endpoint is ready. This may take some time (~15 minutes), so grab a coffee!

In [0]:
utils.wait_for_model_serving_endpoint_to_be_ready(config.ENDPOINT_NAME_CACHE)
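`utils.wait_for_model_serving_endpoint_to_be_ready` is defined in `utils.py`, which is not shown here. A readiness wait like this is typically a poll-with-timeout loop. The sketch below shows that pattern with a hypothetical `get_state` callable standing in for whatever status lookup `utils.py` performs; the default `ready_state` of `"READY"` is an assumption, and this is not that function's implementation.

```python
import time
from typing import Callable


def wait_until_ready(
    get_state: Callable[[], str],
    ready_state: str = "READY",
    timeout_s: float = 1800.0,
    poll_interval_s: float = 30.0,
    sleep: Callable[[float], None] = time.sleep,
    clock: Callable[[], float] = time.monotonic,
) -> None:
    """Poll get_state() until it returns ready_state or the timeout expires.

    get_state is a hypothetical callable, e.g. a wrapper around an API call
    that reports the serving endpoint's readiness.
    """
    deadline = clock() + timeout_s
    while True:
        state = get_state()
        if state == ready_state:
            return
        if clock() >= deadline:
            raise TimeoutError(f"Endpoint not ready after {timeout_s:.0f}s (last state: {state})")
        sleep(poll_interval_s)
```

Injecting `sleep` and `clock` keeps the loop testable without real waiting.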

Once the endpoint is up and running, let's send a request and see how it responds. If the following cell fails with a 404 Not Found error, wait a minute and re-run the cell.

In [0]:
import utils
data = {
    "inputs": {
        "messages": [
            {
                "content": "What is Model Serving?",
                "role": "user"
            }
        ]
    }
}
# Send the request in the chat-messages format the chain expects
utils.send_request_to_endpoint(config.ENDPOINT_NAME_CACHE, data)
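`utils.send_request_to_endpoint` is also defined in `utils.py`. For reference, Databricks Model Serving endpoints accept scoring requests as a POST to `<workspace-url>/serving-endpoints/<endpoint-name>/invocations` with a bearer token. The sketch below builds such a request using only the standard library and reuses the `data` payload shape from the cell above; `build_invocation_request` and `send_invocation` are hypothetical names, not copies of the helper in `utils.py`.

```python
import json
import urllib.request
from typing import Any, Dict


def build_invocation_request(host: str, token: str, endpoint_name: str, payload: Dict[str, Any]) -> urllib.request.Request:
    """Build (but do not send) a POST request to a Model Serving endpoint."""
    url = f"{host.rstrip('/')}/serving-endpoints/{endpoint_name}/invocations"
    return urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Authorization": f"Bearer {token}", "Content-Type": "application/json"},
        method="POST",
    )


def send_invocation(request: urllib.request.Request, timeout_s: float = 60.0) -> Dict[str, Any]:
    # urlopen raises urllib.error.HTTPError for non-2xx responses, such as a 404
    # while the endpoint is still starting.
    with urllib.request.urlopen(request, timeout=timeout_s) as response:
        return json.loads(response.read().decode("utf-8"))
```

In this notebook the call would look like `send_invocation(build_invocation_request(HOST, TOKEN, config.ENDPOINT_NAME_CACHE, data))`.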

In this notebook, we built a RAG chatbot with semantic caching. In the next `04_evaluate` notebook, we will compare the two chains we built. 

© 2024 Databricks, Inc. All rights reserved. The source in this notebook is provided subject to the Databricks License.