# GenAI with Azure Databricks - Developing RAG System

### Loading the csv file into the DBFS (Databricks File System)

In [None]:
 %sh
 # Start from a clean folder on the DBFS FUSE mount; -f / -p keep this cell
 # re-runnable on a fresh workspace where /dbfs/rag_lab does not exist yet.
 rm -rf /dbfs/rag_lab
 mkdir -p /dbfs/rag_lab
 wget -O /dbfs/rag_lab/diabetes_faq.csv https://raw.githubusercontent.com/kuljotSB/DatabricksUdemyCourse/refs/heads/main/GenAI/diabetes_treatment_faq.csv

### Loading the CSV file into a DataFrame

In [None]:
from pyspark.sql.functions import *

# Read the FAQ CSV (first row holds the column names) into a Spark DataFrame,
# then preview ten rows and print the schema.
df = (
    spark.read
    .format('csv')
    .option('header', True)
    .load('/rag_lab/diabetes_faq.csv')
)
display(df.limit(10))
df.printSchema()

### Creating an Azure OpenAI Client


In [None]:
from openai import AzureOpenAI
import json

import os

# Credentials come from the environment (e.g. cluster environment variables backed by a
# Databricks secret scope) rather than being hardcoded in the notebook. The placeholders
# remain as fallbacks, so the cell behaves as before when nothing is configured.
openai_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT", "<YOUR-AZURE-OPENAI-ENDPOINT>")
openai_key = os.environ.get("AZURE_OPENAI_API_KEY", "<YOUR-AZURE-OPENAI-KEY>")

# NOTE: this cell runs before the `%pip install` / `dbutils.library.restartPython()` cells
# further down, and restarting Python discards `client` -- re-run this cell after the restart.
client = AzureOpenAI(
    api_key=openai_key,
    api_version="2024-02-15-preview",
    azure_endpoint=openai_endpoint,
)

### Saving the DataFrame as Parquet on DBFS and as a Delta table (the source for the vector index)

In [None]:
# Keep a raw Parquet copy of the FAQ data on DBFS.
df.write.mode("overwrite").parquet("/rag_lab/diabetes_faq.parquet")

# The Delta Sync index created later embeds the text itself (embedding_model_endpoint_name),
# so the source table only needs the raw columns: write `df` as-is.
source_table_name = "<catalog_name>.default.diabetes_faq_table"
df.write.format("delta").mode("overwrite").saveAsTable(source_table_name)

# Delta Sync vector indexes read the source table's change feed, so it must be enabled.
spark.sql(
    f"ALTER TABLE {source_table_name} "
    "SET TBLPROPERTIES (delta.enableChangeDataFeed = true)"
)


### Installing the Databricks Vector Search SDK

In [None]:
%pip install databricks-vectorsearch

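### Restarting the Python environment

Restarting makes the newly installed package importable, but it also clears every Python variable defined so far, including the Azure OpenAI `client`. Re-run the client cell after this step, or move the `%pip install` to the top of the notebook.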

In [None]:
# Restart the Python process so the newly installed databricks-vectorsearch package can be imported.
# NOTE: this clears every Python variable defined in earlier cells (e.g. the Azure OpenAI `client`),
# so any later cell that uses them needs those earlier cells re-run first.
dbutils.library.restartPython()

### Creating the Vector Search endpoint and a Delta Sync index (embeddings computed by Databricks)

In [None]:
from databricks.vector_search.client import VectorSearchClient

# Resource names used below; replace <catalog_name> with your Unity Catalog catalog.
ENDPOINT_NAME = "vector_search_endpoint"
SOURCE_TABLE = "<catalog_name>.default.diabetes_faq_table"
INDEX_NAME = "<catalog_name>.default.diabetes_faq_index"

vector_client = VectorSearchClient()

# Serving endpoint that hosts the vector index.
# NOTE(review): create_endpoint / create_delta_sync_index presumably fail when the resource
# already exists, so re-running this cell needs them guarded -- confirm against the SDK.
vector_client.create_endpoint(name=ENDPOINT_NAME, endpoint_type="STANDARD")

# Delta Sync index over the FAQ table, keyed by "Topic". The "Description" column is embedded
# through the databricks-gte-large-en model endpoint. With a TRIGGERED pipeline, the index is
# re-synced only on request.
index = vector_client.create_delta_sync_index(
    endpoint_name=ENDPOINT_NAME,
    source_table_name=SOURCE_TABLE,
    index_name=INDEX_NAME,
    pipeline_type="TRIGGERED",
    primary_key="Topic",
    embedding_source_column="Description",
    embedding_model_endpoint_name="databricks-gte-large-en",
)

### Querying the vector index (the retriever)

In [None]:
user_question = "what is diabetes?"

# Query with the question itself. `index` is the handle returned by create_delta_sync_index
# in the previous cell. Note that "{user_question}" without an f-prefix would search for the
# literal placeholder text.
results_dict = index.similarity_search(
    query_text=user_question,
    columns=["Topic", "Description"],
    num_results=1,
)

# data_array holds one list per hit, starting with the requested columns; tolerate "no hits".
data_array = results_dict.get("result", {}).get("data_array") or []
content = str(data_array[0]) if data_array else ""
print(content)

### Developing the Generation Component of our RAG architecture


In [None]:
# Generation step: send the question plus the retrieved context to the chat deployment.
# `client` comes from the Azure OpenAI cell. If this raises NameError, re-run that cell,
# because restartPython() cleared it.
gpt_response = client.chat.completions.create(
    model="gpt-4",  # model = "deployment_name".
    messages=[
        {"role": "system", "content": "You are a helpful assistant. You will be passed the user query and the supporting knowledge that can be used to answer the user_query"},
        {"role": "user", "content": f"user query : {user_question} and supporting knowledge: {content}"},
    ],
)
# Must sit at cell top level; indenting it is an IndentationError.
print(gpt_response.choices[0].message.content)

### Developing the RAG model

In [None]:
import mlflow
from mlflow import pyfunc

class RAGModel(pyfunc.PythonModel):
    """Retrieval-augmented generation model packaged as an MLflow pyfunc.

    Retrieves the single best-matching FAQ entry from a Databricks Vector Search
    index and passes its description, together with the user's query, to an
    Azure OpenAI chat deployment.

    Args:
        vector_index: Index handle exposing ``similarity_search`` (e.g. the
            object returned by ``create_delta_sync_index``).
        deployment_name: Azure OpenAI deployment used for chat completions.
        openai_client: Optional ``AzureOpenAI`` client. When omitted, the
            notebook-level ``client`` is used.
    """

    def __init__(self, vector_index, deployment_name="YOUR_MODEL_NAME", openai_client=None):
        self.vector_index = vector_index
        self.deployment_name = deployment_name
        self.openai_client = openai_client

    def retrieve(self, query):
        """Return the raw similarity-search response for the top-1 match of ``query``."""
        return self.vector_index.similarity_search(
            query_text=query,
            columns=["Topic", "Description"],
            num_results=1,
        )

    def chatCompletionsAPI(self, user_query, supporting_knowledge):
        """Answer ``user_query`` with the chat deployment, grounded in ``supporting_knowledge``.

        Returns:
            The text content of the first completion choice.
        """
        chat_client = self.openai_client if self.openai_client is not None else client
        response = chat_client.chat.completions.create(
            model=self.deployment_name,  # model = "deployment_name".
            messages=[
                {"role": "system", "content": "You are a helpful assistant. You will be passed the user query and the supporting knowledge that can be used to answer the user_query"},
                {"role": "user", "content": f"user query : {user_query} and supporting knowledge: {supporting_knowledge}"},
            ],
        )
        return response.choices[0].message.content

    @staticmethod
    def _extract_query(data):
        """Pull a single query string out of a plain string, a dict, or a pandas DataFrame."""
        if isinstance(data, str):
            return data
        query = data["query"]
        # Model serving hands pyfunc models a pandas DataFrame, so data["query"] is a Series there.
        if hasattr(query, "iloc"):
            query = query.iloc[0]
        return query

    def predict(self, context, data):
        """Run retrieval then generation for one query.

        Args:
            context: MLflow model context (unused).
            data: Mapping or DataFrame with a ``"query"`` entry, or a bare query string.

        Returns:
            The generated answer text.
        """
        query = self._extract_query(data)
        retrieved_docs = self.retrieve(query)

        rows = retrieved_docs.get("result", {}).get("data_array") or []
        # Element 1 of the top hit is "Description" (columns come back in the requested order).
        # Plain indexing replaces a Spark round-trip, since Spark is not available when the model is served.
        text_data = [rows[0][1]] if rows else []

        return self.chatCompletionsAPI(query, text_data)
          


      

### Logging the RAG model to MLflow

In [None]:
# Wrap the retriever + generator as a pyfunc model and log it under a new MLflow run.
# NOTE(review): log_model pickles python_model, so the index handle (and any client it
# references) is serialized too -- verify that it pickles and captures no credentials.
rag_model = RAGModel(index)
with mlflow.start_run():
    mlflow.pyfunc.log_model(
        "rag_model",
        python_model=rag_model,
    )
