In [0]:
%pip install databricks-vectorsearch
%pip install langchain langchain_community
# Restart the Python process so that later cells import the packages installed above.
dbutils.library.restartPython()

In [0]:
# Notebook parameter: the catalog name, exposed as a text widget whose
# default value is "geekcoders_dev".
dbutils.widgets.text("catalog_name", "geekcoders_dev")
catalog_name = dbutils.widgets.get("catalog_name")


In [0]:
# Vector Search endpoint name and the index name, the latter qualified with
# the catalog chosen in the widget (<catalog>.gold.<index>).
vs_endpoint_name = "databricks_llm_geekcoders"
vs_index_name = f"{catalog_name}.gold.master_dim_patient_info_index"

In [0]:
from databricks.vector_search.client import VectorSearchClient

# Smoke test: open the index and run one similarity search against the
# `merged_desc` column; the result is displayed as the cell output.
vsc = VectorSearchClient(disable_notice=True)
vs_index = vsc.get_index(
    endpoint_name=vs_endpoint_name,
    index_name=vs_index_name,
)
vs_index.similarity_search(
    query_text='P_258, Brooke,Davis,132 Main St, City 31 ',
    columns=["merged_desc"],
    num_results=3,
)

In [0]:
from langchain.vectorstores import DatabricksVectorSearch
from langchain.embeddings import DatabricksEmbeddings
from databricks.vector_search.client import VectorSearchClient
from langchain.chains import RetrievalQA
from langchain import PromptTemplate
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_community.llms import Databricks
import json
import os
import mlflow
from mlflow.models.signature import ModelSignature
from mlflow.types.schema import Schema, ColSpec
import numpy as np
import pandas as pd
from mlflow.deployments import get_deploy_client
from mlflow.models import validate_serving_input

import mlflow

# Vector Search client used by get_retriever().
vsc = VectorSearchClient()

# Endpoint / index names under the aliases that get_retriever() reads.
vector_search_endpoint_name, vector_index_name = vs_endpoint_name, vs_index_name

# Unity Catalog location (catalog.schema.model) under which the model is registered.
UC_CATALOG = f"{catalog_name}"
UC_SCHEMA = "gold"
MODEL_NAME = "geekcoders_chatbot_model"

# Define retriever function
def get_retriever():
    """Return a retriever over the Vector Search index.

    The index is looked up through the module-level ``vsc`` client using
    ``vector_search_endpoint_name`` / ``vector_index_name``, wrapped in a
    ``DatabricksVectorSearch`` store that reads the ``merged_desc`` text
    column, and exposed as a retriever with ``search_kwargs={"k": 3}``.
    """
    index_handle = vsc.get_index(
        endpoint_name=vector_search_endpoint_name,
        index_name=vector_index_name,
    )
    store = DatabricksVectorSearch(index_handle, text_column="merged_desc")
    search_options = {"k": 3}
    return store.as_retriever(search_kwargs=search_options)

# Create the retriever
# NOTE(review): build_qa_chain() calls get_retriever() itself, so nothing in the
# visible code reads this module-level `retriever` — confirm before removing it.
retriever = get_retriever()

def transform_input(**request):
    """Turn a prompt-style request into a chat-style request.

    The ``prompt`` entry is removed and a ``messages`` entry is set to a
    two-item list: a fixed system message followed by the prompt as the user
    message. Every other keyword argument is passed through unchanged.

    Returns:
        The request dict with ``messages`` in place of ``prompt``.

    Raises:
        KeyError: if no ``prompt`` keyword was supplied.
    """
    user_prompt = request.pop("prompt")
    request["messages"] = [
        {"role": "system", "content": "You are an AI assistant that extracts information from text."},
        {"role": "user", "content": user_prompt},
    ]
    return request

# LLM wrapper for the Databricks serving endpoint named below; transform_input
# is registered as its request-transform hook.
llm = Databricks(
    transform_input_fn=transform_input,
    endpoint_name="databricks-meta-llama-3-1-405b-instruct",
)

def build_qa_chain():
    """Build the RetrievalQA chain used to answer questions.

    The chain combines the module-level ``llm`` with a fresh retriever from
    ``get_retriever()``, uses the "stuff" chain type with the extraction prompt
    defined below, and is configured to return the source documents.

    Returns:
        The object returned by ``RetrievalQA.from_chain_type``.
    """
    prompt_text = """You are an AI assistant that extracts information from text. 

### Instruction:
Extract the relevant details from the provided context. Parse the `merged_desc` field, which contains values joined with the `|` character. If no relevant information is found, respond with: "Not available."

return only the final answer with meaningfull sentence

### Context:
{context}

### Question:
{question}

### Response:
"""

    qa_prompt = PromptTemplate(
        template=prompt_text,
        input_variables=['context', 'question'],
    )

    # Options forwarded to the underlying "stuff" chain.
    stuff_chain_options = {"verbose": True, "prompt": qa_prompt}

    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=get_retriever(),
        return_source_documents=True,
        chain_type_kwargs=stuff_chain_options,
    )

class ChatBotModel(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc model that answers one question with the retrieval QA chain."""

    def __init__(self):
        # The chain is created in load_context(), not at construction time.
        self.chain = None

    def load_context(self, context):
        """Build the QA chain and keep it on the instance."""
        self.chain = build_qa_chain()

    def predict(self, context, model_input):
        """Answer a single question.

        Args:
            context: Unused by this method.
            model_input: Either a pandas DataFrame with a ``question`` column
                (only the first row is read) or the question itself.

        Returns:
            The ``"result"`` entry of the chain's output.
        """
        question = (
            model_input["question"].iloc[0]
            if isinstance(model_input, pd.DataFrame)
            else model_input
        )
        return self.chain({"query": question})["result"]
    
    
        


In [0]:
# Workspace host name (taken from the notebook context) and the token read
# from the `geekcoders_llm` secret scope; both are handed to the serving
# endpoint as environment variables by deploy_to_serving_endpoint().
workspace_url = (
    dbutils.notebook.entry_point.getDbutils()
    .notebook()
    .getContext()
    .browserHostName()
    .get()
)
token = dbutils.secrets.get(scope="geekcoders_llm", key="databricks_token")



In [0]:
def save_model_to_uc():
    """Log ``ChatBotModel`` with MLflow and register it in Unity Catalog.

    The registry URI is set to ``databricks-uc`` and the model is logged inside
    a new MLflow run under the registered name
    ``{UC_CATALOG}.{UC_SCHEMA}.{MODEL_NAME}``, with a signature of one string
    input column ``question`` and one string output column ``answer``.

    Returns:
        A ``(model_uri, model_info)`` tuple, where ``model_info`` is the object
        returned by ``mlflow.pyfunc.log_model``.
    """
    print("Saving model to Unity Catalog...")

    # Signature: string "question" in, string "answer" out.
    signature = ModelSignature(
        inputs=Schema([ColSpec("string", "question")]),
        outputs=Schema([ColSpec("string", "answer")]),
    )

    # Register models in Unity Catalog.
    mlflow.set_registry_uri("databricks-uc")

    requirements = [
        "langchain",
        "databricks-vectorsearch",
        "pandas",
        "numpy",
        "mlflow",
        "langchain_community",
        "databricks-sdk",
    ]

    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            artifact_path="model",
            python_model=ChatBotModel(),
            signature=signature,
            registered_model_name=f"{UC_CATALOG}.{UC_SCHEMA}.{MODEL_NAME}",
            pip_requirements=requirements,
        )
        print(f"Model saved to: {model_info.model_uri}")
        return (model_info.model_uri, model_info)
    
def deploy_to_serving_endpoint(model_uri, model_version, endpoint_name="geekcoders-chatbot-model"):
    """Serve ``model_version`` of the registered model on ``endpoint_name``.

    If the endpoint can be looked up, its configuration is updated to serve the
    given version; otherwise the endpoint is created. The served entity is
    ``{UC_CATALOG}.{UC_SCHEMA}.{MODEL_NAME}`` on a small CPU workload with
    scale-to-zero enabled.

    Args:
        model_uri: URI of the logged model. Not used by this function; kept so
            existing callers keep working.
        model_version: Registered model version to serve.
        endpoint_name: Name of the serving endpoint to create or update.

    Returns:
        ``endpoint_name``.

    Raises:
        Exception: any error raised by the deployment client while creating or
            updating the endpoint is propagated to the caller.
    """
    mlflow.set_registry_uri("databricks-uc")
    client = get_deploy_client("databricks")

    # Single definition of the endpoint config, shared by create and update.
    serving_config = {
        "served_entities": [
            {
                "entity_name": f"{UC_CATALOG}.{UC_SCHEMA}.{MODEL_NAME}",
                "entity_version": model_version,
                "workload_size": "Small",
                "workload_type": "CPU",
                "scale_to_zero_enabled": True,
                # NOTE(review): the token is sent as a plain-text environment
                # variable. Consider a secret reference such as
                # "{{secrets/geekcoders_llm/databricks_token}}" instead —
                # confirm against the Databricks model serving docs.
                "environment_vars": {
                    "DATABRICKS_HOST": workspace_url,
                    "DATABRICKS_TOKEN": token,
                },
            }
        ],
    }

    try:
        client.get_endpoint(endpoint_name)
    except Exception:
        # A failed lookup is treated as "endpoint does not exist yet". If the
        # lookup failed for another reason, create_endpoint raises too and its
        # traceback carries the lookup error as context.
        client.create_endpoint(name=endpoint_name, config=serving_config)
    else:
        # The endpoint exists: update it. This branch sits outside the `except`
        # so a failed update surfaces instead of falling through to a create
        # call that cannot succeed.
        client.update_endpoint_config(endpoint_name, config=serving_config)

    print(f"Model deployed to endpoint: {endpoint_name}")
    return endpoint_name

# Manual test helper (disabled): builds the QA chain and answers one question.
# def answer_question(question):
#     qa_chain = build_qa_chain()
#     result = qa_chain({"query": question})
#     return result["result"]
    

def main():
    """Register the chatbot model in Unity Catalog, then deploy that version."""
    # save_model_to_uc() returns (model_uri, model_info).
    model_uri, model_info = save_model_to_uc()

    # Deploy the version that was just registered.
    deploy_to_serving_endpoint(model_uri, model_info.registered_model_version)
    
    
   

# Run the save-and-deploy flow when this code executes as the top-level module
# (presumably the case when the cell is run in a notebook — confirm).
if __name__ == "__main__":
    main()
    


In [0]:
import time  # imported here so this cell does not rely on an import elsewhere

# Wait for the serving endpoint to come up. Polling is throttled and bounded so
# a failed or stuck deployment cannot spin this cell forever.
POLL_INTERVAL_SECONDS = 30
POLL_TIMEOUT_SECONDS = 60 * 60

status = None  # retained from the original cell; not read here
client = get_deploy_client("databricks")
deadline = time.monotonic() + POLL_TIMEOUT_SECONDS

while True:
    endpoint_state = client.get_endpoint("geekcoders-chatbot-model")['state']
    state = endpoint_state['ready']
    # `config_update` may be absent; .get() then yields None, which is treated
    # as "no update in progress".
    config_update = endpoint_state.get('config_update')
    print(state)

    if config_update == 'UPDATE_FAILED':
        raise RuntimeError(
            "Serving endpoint 'geekcoders-chatbot-model' failed to deploy "
            f"(state: {endpoint_state})"
        )

    # An endpoint that was already serving stays READY while a new config rolls
    # out, so also wait for the update itself to finish.
    if state == 'READY' and config_update != 'IN_PROGRESS':
        break

    if time.monotonic() >= deadline:
        raise TimeoutError(
            "Serving endpoint 'geekcoders-chatbot-model' was not ready after "
            f"{POLL_TIMEOUT_SECONDS} seconds (last state: {endpoint_state})"
        )

    time.sleep(POLL_INTERVAL_SECONDS)