### Install modules

In [0]:
#!pip install mlflow==2.9.0 langchain==0.1.5 databricks-vectorsearch==0.22 databricks-sdk==0.18.0 mlflow[databricks] langchain-community
!pip install mlflow langchain databricks-vectorsearch databricks-sdk mlflow[databricks] dotenv langchain-community

In [0]:
%python
# Restart the notebook's Python process; presumably so the packages installed in the
# previous cell are the ones imported below.
dbutils.library.restartPython()

In [0]:
# Print the version of the `python` executable found on the shell PATH.
!python --version

Python 3.11.11


### Import modules

In [0]:

import importlib.metadata
import os

from dotenv import load_dotenv

# langchain components ---------

import langchain
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate

# Databricks models and embeddings -------------

from langchain_community.chat_models import ChatDatabricks
from langchain_community.embeddings import DatabricksEmbeddings

# vector search ----------
from databricks.vector_search.client import VectorSearchClient
from langchain_community.vectorstores import DatabricksVectorSearch

# MLflow -------------
import mlflow
from mlflow.models import infer_signature



### Set env

In [0]:
# Build the workspace URL from the Spark configuration and export it through
# the DATABRICKS_HOST environment variable.
host = f"https://{spark.conf.get('spark.databricks.workspaceUrl')}"
os.environ["DATABRICKS_HOST"] = host


In [0]:
# Load variables from a .env file into os.environ. Presumably this is where
# DATABRICKS_TOKEN (read later by get_retriever) comes from — confirm.
load_dotenv()

### Databricks model and embeddings

In [0]:
# Chat model addressed by its endpoint name, with max_tokens set to 200.
chat_model = ChatDatabricks(
    endpoint="databricks-meta-llama-3-3-70b-instruct",
    max_tokens=200,
)

In [0]:
# Embedding model addressed by its endpoint name; handed to the vector store in get_retriever.
embedding_model = DatabricksEmbeddings(
    endpoint="databricks-bge-large-en",
)

In [0]:
# Name of the vector search index and of the endpoint it is looked up on (see get_retriever).
index_name = "analytics.models.docs_idxs"
VECTOR_SEARCH_ENDPOINT_NAME = "document_vector_endpoint"

### Create Vector DB retriever function

In [0]:

    
def get_retriever(persist_dir: str = None):
    """Build a LangChain retriever over the configured vector search index.

    This function is also handed to MLflow as ``loader_fn`` when the chain is
    logged, which is why it accepts ``persist_dir`` even though the value is
    not used: the index is looked up by name instead.

    Args:
        persist_dir: Ignored. Optional; kept so the function can be called
            with or without a positional argument.

    Returns:
        The retriever returned by ``DatabricksVectorSearch.as_retriever()``.

    Raises:
        KeyError: If the ``DATABRICKS_TOKEN`` environment variable is missing
            or empty.
    """
    # Validate the credential up front so a missing token yields an actionable
    # message rather than a bare KeyError from the middle of client setup.
    token = os.environ.get("DATABRICKS_TOKEN")
    if not token:
        raise KeyError(
            "DATABRICKS_TOKEN is not set; define it in the environment "
            "(for example in the .env file) before building the retriever"
        )

    os.environ["DATABRICKS_HOST"] = host

    # Get the vector search index
    vsc = VectorSearchClient(workspace_url=host, personal_access_token=token)
    vs_index = vsc.get_index(
        endpoint_name=VECTOR_SEARCH_ENDPOINT_NAME,
        index_name=index_name,
    )

    # Create the retriever
    vectorstore = DatabricksVectorSearch(
        vs_index, text_column="text", embedding=embedding_model
    )
    return vectorstore.as_retriever()



### Test LangChain RetrievalQA

In [0]:
 
TEMPLATE = """You are an HR Assistant chatbot designed to provide accurate and up-to-date answers to employees HR-related queries. Your responses should be clear, concise, and directly based on the company's official policies. If the question is not related to one of these topics, kindly decline to answer. If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer as concise as possible.  Provide all answers only in English.
Use the following pieces of context to answer the question at the end:
{context}
Question: {question}
Answer:
"""

# The template exposes two placeholders: {context} and {question}.
prompt = PromptTemplate(
    input_variables=["context", "question"],
    template=TEMPLATE,
)

# Question-answering chain: the chat model answers using documents supplied by
# the vector search retriever, formatted with the prompt above.
chain = RetrievalQA.from_chain_type(
    llm=chat_model,
    retriever=get_retriever(),
    chain_type="stuff",
    chain_type_kwargs=dict(prompt=prompt),
)




In [0]:
# Smoke-test the chain with one sample query. `question` and `answer` are reused
# below as the input example and for signature inference when the model is logged.
question = dict(query="How do I do expenses when I travel?")
answer = chain.invoke(question)
print(answer["result"])

### Model Publish

In [0]:



# Point the MLflow model registry at Unity Catalog.
mlflow.set_registry_uri("databricks-uc")
model_name = "analytics.models.chatbot"

with mlflow.start_run(run_name="chatbot_hr", nested=True) as run:
    # Derive the model signature from the sample request/response produced above.
    signature = infer_signature(question, answer)
    model_info = mlflow.langchain.log_model(
        chain,
        # MLflow calls this to rebuild the retriever when the model is loaded.
        loader_fn=get_retriever,
        artifact_path="chain",
        registered_model_name=model_name,
        # Pin every requirement to the version installed in this environment so
        # the logged model is restored with the same library versions.
        pip_requirements=[
            "mlflow==" + mlflow.__version__,
            "langchain==" + langchain.__version__,
            "databricks-vectorsearch==" + importlib.metadata.version("databricks-vectorsearch"),
            "langchain-community==" + importlib.metadata.version("langchain-community"),
        ],
        input_example=question,
        signature=signature,
    )

# No explicit mlflow.end_run() here: the `with` block already ends the run it
# started, and an extra call would end an enclosing parent run if one is active.





          