# 3 - RAG Retriever
* Notebook by Adam Lang
* Notebook was adopted from the Databricks webinar in June 2024 that streamed on the Databricks YouTube channel.
* This is the 3rd notebook and the next step after creating the delta tables, vector endpoint, and vector index table. 
* Date: 4/30/2025

## 1. Install Dependencies

In [0]:
%pip install mlflow==2.10.1 langchain==0.1.5 databricks-vectorsearch==0.22 databricks-sdk==0.18.0 mlflow[databricks]

## restart kernel after installing
dbutils.library.restartPython()

[43mNote: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.[0m


## 2. Set Workspace Parameters
* In order to do this you will need to creat a secret scope. The steps to do this using the Databricks CLI or SDK are:

1. Create a Secret Scope: First, create a secret scope if you don't already have one. You can do this using the Databricks CLI or the Databricks SDK for Python.

Using Databricks CLI:
```
%bash
databricks secrets create-scope --scope demo
```

2. Add a Secret to the Scope: Add your secret token to the created scope. You can do this using the Databricks CLI.

Using Databricks CLI:
```
%bash
databricks secrets put --scope demo --key azure3-token --string-value <your-secret-token>
```

3. Access token in notebook with code below. 

## 3. Build Retriever

In [0]:
from databricks.vector_search.client import VectorSearchClient
from langchain_community.vectorstores import DatabricksVectorSearch
from langchain_community.embeddings import DatabricksEmbeddings

# 0. Set the Databricks host environment variable
os.environ["DATABRICKS_HOST"] = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().get()

## 1. init embedding model -- same embedding model as vector index
embedding_model = DatabricksEmbeddings(endpoint="databricks-gte-large-en")

## vector search endpoint name
VECTOR_SEARCH_ENDPOINT_NAME = "doc_vector_endpoint"
INDEX_NAME='workspace.llm_rag_demos.docs_idx' ## name of delta table with vector index

## 2. Retriever function
# Define the retriever function
def get_retriever():
    vsc = VectorSearchClient(workspace_url=os.environ["DATABRICKS_HOST"], disable_notice=True)
    vsc_index = vsc.get_index(endpoint_name=VECTOR_SEARCH_ENDPOINT_NAME, index_name=INDEX_NAME)
    vectorstore = DatabricksVectorSearch(vsc_index, text_column="text", embedding=embedding_model)
    return vectorstore.as_retriever()
    


## 4. Create RAG using LangChain
* Hugging Face model card for the databricks model: https://huggingface.co/databricks/dbrx-instruct
* Note, the model is not considered "multilingual" as it was trained on mostly english text.

In [0]:
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.chat_models import ChatDatabricks

## 1. define LLM chat model
llm = ChatDatabricks(endpoint="databricks-dbrx-instruct",
                     max_tokens=400)

## 2. Chat Template
CHAT_TEMPLATE = """You are an expert assistant for home appliance users. You are tasked with answering hot to perform maintenance and troubleshooting questions about home appliances that you will have data on. If the question you are given is not related to one of these topics, please decline to answer. If you don't know the answer, simply say "I don't know", do not try to make up an answer without context. If the question appears to be for an appliance you don't have data on, please state "I do not have any data to answer that question." Keep your answer as concise as possible. Provide all answers in English only. If the question is in another language such as spanish than translate it and answer in English.
Use the following context to answer the question at the end.
{context}
Question: {question}
Answer:
"""
## 3. Set up PromptTemplate
prompt = PromptTemplate(template=CHAT_TEMPLATE, input_variables=["context", "question"])

## 4. Create RAG chain 
chain = RetrievalQA.from_chain_type(llm=llm,
                                  chain_type="stuff",
                                  retriever=get_retriever(),
                                  chain_type_kwargs={"prompt": prompt})



## 5. Test LangChain Retriever


In [0]:
## test query
question = {"query": "What does a SUDS message mean?"}
answer = chain.run(question)
print(answer)

The "SUDS" message on a washing machine indicates that there are too many suds in the machine. This can be caused by using too much detergent or the wrong type of detergent. To fix this issue, you can try running a rinse cycle to remove the excess suds, and then adjust the amount or type of detergent you are using. If the problem persists, you may need to consult the user manual or contact a service center for further assistance.


## 6. Register Chain as a Model in the Databricks Unity Catalog

In [0]:
%python
from mlflow.models import infer_signature
import mlflow
import langchain

## setup registry
mlflow.set_registry_uri("databricks-uc")
model_name="workspace.llm_rag_demos.rag_model" ## use same directory as data


with mlflow.start_run(run_name="appliance_chatbot_run") as run:
  signature = infer_signature(question, answer)
  model_info = mlflow.langchain.log_model(
    chain,
    loader_fn=lambda: get_retriever, ## pass token to loader function
    artifact_path="chain",
    registered_model_name=model_name,
    pip_requirements=[
      "mlflow==" + mlflow.__version__,
      "langchain==" + langchain.__version__,
      "databricks-vectorsearch",
    ],
    input_example=question,
    signature=signature
  )

## show model_info
display(model_info)

2025/04/30 19:39:15 INFO mlflow.models.utils: We convert input dictionaries to pandas DataFrames such that each key represents a column, collectively constituting a single row of data. If you would like to save data as multiple rows, please convert your data to a pandas DataFrame before passing to input_example.
Registered model 'workspace.llm_rag_demos.rag_model' already exists. Creating a new version of this model...


[0;31m---------------------------------------------------------------------------[0m
[0;31mMlflowException[0m                           Traceback (most recent call last)
File [0;32m<command-5282819306779848>, line 12[0m
[1;32m     10[0m [38;5;28;01mwith[39;00m mlflow[38;5;241m.[39mstart_run(run_name[38;5;241m=[39m[38;5;124m"[39m[38;5;124mappliance_chatbot_run[39m[38;5;124m"[39m) [38;5;28;01mas[39;00m run:
[1;32m     11[0m   signature [38;5;241m=[39m infer_signature(question, answer)
[0;32m---> 12[0m   model_info [38;5;241m=[39m mlflow[38;5;241m.[39mlangchain[38;5;241m.[39mlog_model(
[1;32m     13[0m     chain,
[1;32m     14[0m     loader_fn[38;5;241m=[39m[38;5;28;01mlambda[39;00m: get_retriever, [38;5;66;03m## pass token to loader function[39;00m
[1;32m     15[0m     artifact_path[38;5;241m=[39m[38;5;124m"[39m[38;5;124mchain[39m[38;5;124m"[39m,
[1;32m     16[0m     registered_model_name[38;5;241m=[39mmodel_name,
[1;32m     17[