<a target="_parent" href="https://colab.research.google.com/github/gretelai/gretel-blueprints/blob/main/docs/notebooks/demo/dbrx-blog-demo/gretel-demo-chat-bot-rag.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Gretel x Databricks LLM Chatbot with RAG

This notebook can be used in tandem with the workflow notebook to work through the demo. 
[Original Source](https://notebooks.databricks.com/demos/llm-rag-chatbot/index.html#) 

In [None]:
%pip install -U --quiet databricks-sdk==0.28.0 databricks-agents mlflow-skinny mlflow mlflow[gateway] databricks-vectorsearch langchain==0.2.1 langchain_core==0.2.5 langchain_community==0.2.4
# Restart the Python process after the install; presumably so the cells below import the freshly installed package versions.
dbutils.library.restartPython()

In [None]:
from databricks.vector_search.client import VectorSearchClient
import time

# Ask for the Vector Search endpoint name; the endpoint is created further down if it is not listed yet.
VECTOR_SEARCH_ENDPOINT_NAME = input("Enter the name of the Vector Search endpoint to use:")

def endpoint_exists(client, endpoint_name):
    """Report whether an endpoint called ``endpoint_name`` is listed by ``client``.

    Args:
        client: Object exposing ``list_endpoints()``; the result is read
            like a dict with an optional ``'endpoints'`` list of dicts.
        endpoint_name: Name compared against each entry's ``'name'``.

    Returns:
        True when a listed endpoint has that name, False otherwise
        (including when the listing has no ``'endpoints'`` key).
    """
    listing = client.list_endpoints()
    if 'endpoints' not in listing:
        return False
    for candidate in listing['endpoints']:
        if candidate['name'] == endpoint_name:
            return True
    return False

def wait_for_vs_endpoint_to_be_ready(client, endpoint_name, timeout=1200, interval=10):
    """Poll a Vector Search endpoint until it reports the ``ONLINE`` state.

    Args:
        client: Object exposing ``get_endpoint(name)`` that returns a
            dict-like description of the endpoint.
        endpoint_name: Name of the endpoint to poll.
        timeout: Maximum number of seconds to keep polling.
        interval: Seconds to sleep between two polls.

    Returns:
        True as soon as ``endpoint_status.state`` equals ``'ONLINE'``.

    Raises:
        TimeoutError: If the endpoint is not ``ONLINE`` within ``timeout``
            seconds.
    """
    start_time = time.time()
    while time.time() - start_time < timeout:
        endpoint = client.get_endpoint(endpoint_name)
        # A description without a status, or a status without a 'state'
        # entry, is treated as "not ready yet" so polling continues instead
        # of failing with a KeyError.
        status = endpoint.get('endpoint_status') or {}
        if status.get('state') == 'ONLINE':
            return True
        time.sleep(interval)
    raise TimeoutError(f"Endpoint {endpoint_name} did not become ready within {timeout} seconds.")

# Vector Search client reused by the following cells.
vsc = VectorSearchClient(disable_notice=True)

# Create the endpoint only when no endpoint with that name is listed yet.
if not endpoint_exists(vsc, VECTOR_SEARCH_ENDPOINT_NAME):
    vsc.create_endpoint(name=VECTOR_SEARCH_ENDPOINT_NAME, endpoint_type="STANDARD")

# Block until the endpoint reports ONLINE; a TimeoutError is raised otherwise.
wait_for_vs_endpoint_to_be_ready(vsc, VECTOR_SEARCH_ENDPOINT_NAME)
print(f"Endpoint named {VECTOR_SEARCH_ENDPOINT_NAME} is ready.")

In [None]:
from databricks.sdk import WorkspaceClient
import databricks.sdk.service.catalog as c

## Catalog and schema to read from; these should be where the synthetic data was written in the previous notebook.
## They are also used below to build the fully qualified index and model names.
catalog = input('Catalog to read from:')
db = input('Schema to read from:')

def index_exists(client, endpoint_name, index_name):
    """Report whether ``index_name`` is listed on the given endpoint.

    Args:
        client: Object exposing ``list_indexes(endpoint_name)``; the result
            is read like a dict with an optional ``'vector_indexes'`` list.
        endpoint_name: Endpoint whose indexes are listed.
        index_name: Name compared against each entry's ``'name'``.

    Returns:
        True when a listed index has that name, False otherwise (including
        when the listing has no ``'vector_indexes'`` key).
    """
    listing = client.list_indexes(endpoint_name)
    if 'vector_indexes' not in listing:
        return False
    for candidate in listing['vector_indexes']:
        if candidate['name'] == index_name:
            return True
    return False

def wait_for_index_to_be_ready(client, endpoint_name, index_name, timeout=1200, interval=10):
    """Poll a Vector Search index until its detailed state starts with ``ONLINE``.

    Args:
        client: Object exposing ``get_index(endpoint_name, index_name)``; the
            returned index must provide ``describe()`` returning a dict.
        endpoint_name: Endpoint hosting the index.
        index_name: Name of the index to poll.
        timeout: Maximum number of seconds to keep polling.
        interval: Seconds to sleep between two polls.

    Returns:
        True as soon as ``status.detailed_state`` starts with ``'ONLINE'``.

    Raises:
        TimeoutError: If the index is not online within ``timeout`` seconds.
    """
    start_time = time.time()
    while time.time() - start_time < timeout:
        index = client.get_index(endpoint_name, index_name)
        # A description that lacks 'status' or 'detailed_state' is treated as
        # "not ready yet" instead of crashing on None.
        status = index.describe().get('status') or {}
        detailed_state = status.get('detailed_state') or ''
        if detailed_state.startswith('ONLINE'):
            return True
        time.sleep(interval)
    # Name the index here: it is the index, not the endpoint, that timed out.
    raise TimeoutError(
        f"Index {index_name} on endpoint {endpoint_name} did not become ready within {timeout} seconds."
    )


table_name = input('Table to index:')
# Fully qualified name (catalog.schema.table) of the table we'd like to index
source_table_fullname = f"{catalog}.{db}.{table_name}"
index_name = input('Index Name:')
# Fully qualified name (catalog.schema.index) under which the index is stored
vs_index_fullname = f"{catalog}.{db}.{index_name}"

if not index_exists(vsc, VECTOR_SEARCH_ENDPOINT_NAME, vs_index_fullname):
  print(f"Creating index {vs_index_fullname} on endpoint {VECTOR_SEARCH_ENDPOINT_NAME}...")
  vsc.create_delta_sync_index(
    endpoint_name=VECTOR_SEARCH_ENDPOINT_NAME,
    index_name=vs_index_fullname,
    source_table_name=source_table_fullname,
    pipeline_type="TRIGGERED",
    primary_key="customer_id",
    embedding_source_column='customer_query', #The column containing our text
    embedding_model_endpoint_name='databricks-gte-large-en' #The embedding endpoint used to create the embeddings
  )
  #Let's wait for the index to be ready and all our embeddings to be created and indexed
  wait_for_index_to_be_ready(vsc, VECTOR_SEARCH_ENDPOINT_NAME, vs_index_fullname)
else:
  #The index already exists: wait until it is online, then trigger a sync to update
  #our vs content with the new data saved in the table.
  #NOTE(review): nothing waits for the triggered sync to finish before the message below is printed.
  wait_for_index_to_be_ready(vsc, VECTOR_SEARCH_ENDPOINT_NAME, vs_index_fullname)
  vsc.get_index(VECTOR_SEARCH_ENDPOINT_NAME, vs_index_fullname).sync()

print(f"index {vs_index_fullname} on table {source_table_fullname} is ready")

In [None]:
import mlflow.deployments
# MLflow deployments client for Databricks (not used by the query below).
deploy_client = mlflow.deployments.get_deploy_client("databricks")

question = "How can I close my account"

# Smoke-test the index: ask for the single closest row to the sample question.
results = vsc.get_index(VECTOR_SEARCH_ENDPOINT_NAME, vs_index_fullname).similarity_search(
  query_text=question,
  columns=["interaction_type", "intent", "customer_query", "response"],
  num_results=1)
# Extract the returned rows, falling back to an empty list when the keys are absent.
docs = results.get('result', {}).get('data_array', [])
# Bare expression so the notebook displays the rows as the cell output.
docs

In [None]:
import yaml
# Configuration consumed by chain.py (through rag_chain_config.yaml) and by the model-logging cell.
rag_chain_config = {
    # Names of the serving endpoints the chain talks to.
    "databricks_resources": {
        "llm_endpoint_name": "databricks-dbrx-instruct",
        "vector_search_endpoint_name": VECTOR_SEARCH_ENDPOINT_NAME,
    },
    # Example request; also passed as input_example when the chain is logged.
    "input_example": {
        "messages": [{"content": "Sample user question", "role": "user"}]
    },
    "llm_config": {
        # Forwarded to ChatDatabricks as extra_params in chain.py.
        "llm_parameters": {"max_tokens": 1500, "temperature": 0.01},
        "llm_prompt_template": "You are a trusted AI assistant that helps answer questions based only on the provided information. If you do not know the answer to a question, you truthfully say you do not know. Here is the history of the current conversation you are having with your user: {chat_history}. And here is some context which may or may not help you answer the following question: {context}.  Answer directly, do not repeat the question, do not start with something like: the answer to the question, do not add AI in front of your answer, do not say: here is the answer, do not mention the context or the question. Based on this context, answer this question: {question}",
        # Must list the placeholders used in llm_prompt_template.
        "llm_prompt_template_variables": ["context", "chat_history", "question"],
    },
    "retriever_config": {
        # Template applied to every retrieved document before it goes into the prompt.
        "chunk_template": "Passage: {chunk_text}\n",
        "data_pipeline_tag": "poc",
        # Passed to the retriever as search_kwargs in chain.py.
        "parameters": {"k": 5, "query_type": "ann"},
        # Column names of the indexed table; note that no "document_uri" entry is defined here.
        "schema": {"chunk_text": "customer_query", "primary_key": "customer_id"},
        "vector_search_index": f"{vs_index_fullname}",
    },
}
# Persist the configuration so that chain.py can load it from rag_chain_config.yaml.
try:
    with open('rag_chain_config.yaml', 'w') as f:
        yaml.dump(rag_chain_config, f)
except Exception as exc:
    # Writing is deliberately best-effort (the original message refers to a
    # build job), so the cell keeps going. Catching Exception instead of using
    # a bare except lets KeyboardInterrupt and SystemExit propagate, and the
    # real cause is reported instead of being discarded.
    print('pass to work on build job')
    print(f'(rag_chain_config.yaml was not written: {exc!r})')
# Load the configuration back through MLflow; later cells read input_example from it.
model_config = mlflow.models.ModelConfig(development_config='rag_chain_config.yaml')

In [None]:
%%writefile chain.py
import os
import mlflow
from operator import itemgetter
from databricks.vector_search.client import VectorSearchClient
from langchain_community.chat_models import ChatDatabricks
from langchain_community.vectorstores import DatabricksVectorSearch
from langchain_core.runnables import RunnableLambda
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

## Enable MLflow Tracing by turning on MLflow's LangChain autologging for this module
mlflow.langchain.autolog()

def extract_user_query_string(chat_messages_array):
    """Return the string contents of the most recent message from the user.

    The last element of ``chat_messages_array`` is taken and its
    ``"content"`` entry is returned.
    """
    latest_message = chat_messages_array[-1]
    return latest_message["content"]

def extract_previous_messages(chat_messages_array):
    """Render every message except the last one as ``"role: content"`` lines.

    Args:
        chat_messages_array: Sequence of dicts with string ``"role"`` and
            ``"content"`` entries; the final element (the current question)
            is skipped.

    Returns:
        A string that starts with a newline, followed by one
        ``"<role>: <content>\\n"`` line per earlier message. With zero or one
        message the result is just ``"\\n"``.
    """
    # Build the history with a single join instead of repeated `+=`, which
    # re-copies the growing string on every iteration.
    history_lines = (
        msg["role"] + ": " + msg["content"] + "\n"
        for msg in chat_messages_array[:-1]
    )
    return "\n" + "".join(history_lines)

def combine_all_messages_for_vector_search(chat_messages_array):
    """Build the text used to query the retriever.

    The rendered earlier messages are followed directly by the content of
    the latest message, so the search sees the whole conversation.
    """
    history_text = extract_previous_messages(chat_messages_array)
    latest_query = extract_user_query_string(chat_messages_array)
    return history_text + latest_query

# Get the conf from the local conf file (rag_chain_config.yaml is used as the development config)
model_config = mlflow.models.ModelConfig(development_config='rag_chain_config.yaml')

# Split the configuration into the three sections used below.
databricks_resources = model_config.get("databricks_resources")
retriever_config = model_config.get("retriever_config")
llm_config = model_config.get("llm_config")

# Connect to the Vector Search Index named in the configuration
vs_client = VectorSearchClient(disable_notice=True)
vs_index = vs_client.get_index(
    endpoint_name=databricks_resources.get("vector_search_endpoint_name"),
    index_name=retriever_config.get("vector_search_index"),
)
# Mapping of logical field names (e.g. "chunk_text", "primary_key") to index column names
vector_search_schema = retriever_config.get("schema")

# Turn the Vector Search index into a LangChain retriever.
# Schema entries that are not configured (e.g. "document_uri" when the config
# only defines "chunk_text" and "primary_key") come back as None from .get();
# None is not a column name, so such entries are left out of the request.
vector_search_as_retriever = DatabricksVectorSearch(
    vs_index,
    text_column=vector_search_schema.get("chunk_text"),
    columns=[
        column
        for column in (
            vector_search_schema.get("primary_key"),
            vector_search_schema.get("chunk_text"),
            vector_search_schema.get("document_uri"),
        )
        if column is not None
    ],
).as_retriever(search_kwargs=retriever_config.get("parameters"))

# Required to:
# 1. Enable the RAG Studio Review App to properly display retrieved chunks
# 2. Enable evaluation suite to measure the retriever
# NOTE(review): doc_uri is None whenever the schema has no "document_uri" entry;
# confirm that this is intended for the configuration in use.
mlflow.models.set_retriever_schema(
    primary_key=vector_search_schema.get("primary_key"),
    text_column=vector_search_schema.get("chunk_text"),
    doc_uri=vector_search_schema.get("document_uri")
)

# Method to format the docs returned by the retriever into the prompt
def format_context(docs):
    """Concatenate the retrieved documents into one context string.

    Each document's ``page_content`` is substituted for ``{chunk_text}`` in
    the configured ``chunk_template``; the rendered chunks are joined with
    no separator.
    """
    chunk_template = retriever_config.get("chunk_template")
    return "".join(
        chunk_template.format(chunk_text=doc.page_content) for doc in docs
    )

# Prompt Template for generation; template text and variable names both come from llm_config
prompt = PromptTemplate(
    template=llm_config.get("llm_prompt_template"),
    input_variables=llm_config.get("llm_prompt_template_variables"),
)

# FM for generation: chat model served by the configured LLM endpoint,
# called with the configured llm_parameters as extra parameters
model = ChatDatabricks(
    endpoint=databricks_resources.get("llm_endpoint_name"),
    extra_params=llm_config.get("llm_parameters"),
)

# RAG Chain
# The input is a dict with a "messages" list. Three prompt variables are derived from it:
#   question     - content of the latest message
#   context      - documents retrieved for the whole conversation text, formatted with format_context
#   chat_history - all earlier messages rendered as "role: content" lines
# The filled prompt is sent to the model and the reply is parsed into a plain string.
chain = (
    {
        "question": itemgetter("messages") | RunnableLambda(extract_user_query_string),
        "context": itemgetter("messages")
        | RunnableLambda(combine_all_messages_for_vector_search)
        | vector_search_as_retriever
        | RunnableLambda(format_context),
        "chat_history": itemgetter("messages") | RunnableLambda(extract_previous_messages)
    }
    | prompt
    | model
    | StrOutputParser()
)

# Tell MLflow logging where to find your chain.
mlflow.models.set_model(model=chain)

# COMMAND ----------

# Example invocation (left disabled):
# chain.invoke(model_config.get("input_example"))

In [None]:
import os
# Log the chain (code file plus configuration) to MLflow inside a named run.
with mlflow.start_run(run_name=f"dbdemos_rag_quickstart"):
    logged_chain_info = mlflow.langchain.log_model(
        lc_model=os.path.join(os.getcwd(), 'chain.py'),  # Chain code file, e.g. /path/to/the/chain.py
        model_config='rag_chain_config.yaml',  # Chain configuration
        artifact_path="chain",  # Required by MLflow
        input_example=model_config.get("input_example"),  # Save the chain's input schema. MLflow will execute the chain before logging & capture its output schema.
    )

# Test the chain locally: load the logged model back and run it on the example input
chain = mlflow.langchain.load_model(logged_chain_info.model_uri)
chain.invoke(model_config.get("input_example"))

In [None]:
from databricks import agents
# Short model name entered by the user.
MODEL_NAME = input("Enter the name of the model to be reviewed (e.g., dbdemos_rag_demo):")
# Fully qualified name (catalog.schema.model) used for registration and deployment below.
MODEL_NAME_FQN = f"{catalog}.{db}.{MODEL_NAME}"

In [None]:
# Markdown shown to reviewers in the Review App (set via agents.set_review_instructions below).
# The heading previously read "Testing the our ..."; the text has no placeholders, so a plain
# string is used instead of an f-string.
instructions_to_reviewer = """### Instructions for Testing our Databricks Documentation Chatbot assistant

Your inputs are invaluable for the development team. By providing detailed feedback and corrections, you help us fix issues and improve the overall quality of the application. We rely on your expertise to identify any gaps or areas needing enhancement.

1. **Variety of Questions**:
   - Please try a wide range of questions that you anticipate the end users of the application will ask. This helps us ensure the application can handle the expected queries effectively.

2. **Feedback on Answers**:
   - After asking each question, use the feedback widgets provided to review the answer given by the application.
   - If you think the answer is incorrect or could be improved, please use "Edit Answer" to correct it. Your corrections will enable our team to refine the application's accuracy.

3. **Review of Returned Documents**:
   - Carefully review each document that the system returns in response to your question.
   - Use the thumbs up/down feature to indicate whether the document was relevant to the question asked. A thumbs up signifies relevance, while a thumbs down indicates the document was not useful.

Thank you for your time and effort in testing our assistant. Your contributions are essential to delivering a high-quality product to our end users."""

def wait_for_model_serving_endpoint_to_be_ready(endpoint_name, timeout=1200, interval=10):
    """Poll a model serving endpoint until its state reports ``READY``.

    Args:
        endpoint_name: Name of the serving endpoint to poll.
        timeout: Maximum number of seconds to keep polling.
        interval: Seconds to sleep between two polls.

    Returns:
        True as soon as the endpoint's ``state.ready`` value is ``READY``.

    Raises:
        TimeoutError: If the endpoint is not ready within ``timeout`` seconds.
    """
    client = WorkspaceClient()
    start_time = time.time()
    while time.time() - start_time < timeout:
        endpoint = client.serving_endpoints.get(name=endpoint_name)
        # Checking `'ready' in dir(endpoint.state)` only tests that the
        # attribute exists, which holds for a not-yet-ready endpoint too, so
        # the wait ended on the first poll. Compare the value instead; the
        # attribute is expected to be an enum (READY / NOT_READY), hence the
        # `.value` lookup with a fallback to the raw attribute.
        ready = getattr(endpoint.state, 'ready', None)
        if getattr(ready, 'value', ready) == 'READY':
            return True
        time.sleep(interval)
    raise TimeoutError(f"Endpoint {endpoint_name} did not become ready within {timeout} seconds.")

# Register the chain to UC under the fully qualified model name
uc_registered_model_info = mlflow.register_model(model_uri=logged_chain_info.model_uri, name=MODEL_NAME_FQN)

# Deploy the registered version to enable the Review APP and create an API endpoint
deployment_info = agents.deploy(model_name=MODEL_NAME_FQN, model_version=uc_registered_model_info.version, scale_to_zero=True)

# Add the user-facing instructions to the Review App
agents.set_review_instructions(MODEL_NAME_FQN, instructions_to_reviewer)

# Block until the serving endpoint created by the deployment is ready (raises TimeoutError otherwise)
wait_for_model_serving_endpoint_to_be_ready(deployment_info.endpoint_name)