In [0]:
%pip install dspy pyyaml
%pip install --upgrade --force-reinstall databricks-vectorsearch 
# Restart the Python process so the packages installed above are the ones picked up
# by the imports in the following cells.
# NOTE(review): none of the installs are version-pinned — consider pinning for reproducibility.
dbutils.library.restartPython()

In [0]:
import json
import requests
from pyspark.sql import Row
import yaml

from databricks.vector_search.client import VectorSearchClient

import dspy
from dspy import Databricks
from dspy.retrieve.databricks_rm import DatabricksRM
import os
import pandas as pd

import mlflow
from mlflow import MlflowClient

In [0]:
# Load the configuration
def load_config(config_path):
    """Parse a YAML file and return its contents.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        The parsed YAML document. For the files used in this notebook that is
        a nested mapping (section name -> setting name -> value).

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
    """
    # Pin the encoding so parsing does not depend on the platform's default locale.
    with open(config_path, "r", encoding="utf-8") as file:
        # safe_load only builds plain Python objects, avoiding arbitrary code execution.
        return yaml.safe_load(file)

# Parse the general settings and the separately stored credentials.
config = load_config("config.yaml")
secrets = load_config(".secrets")

# Workspace user folder used to locate the module archive.
user_path = config["user"]["path"]

# Catalog / schema / table that will hold the source data.
_database_cfg = config["database"]
catalog_name = _database_cfg["catalog_name"]
schema_name = _database_cfg["schema_name"]
source_table_name = _database_cfg["source_table_name"]

# Vector Search endpoint and index names.
_vector_search_cfg = config["vector_search"]
vector_search_endpoint_name = _vector_search_cfg["endpoint_name"]
vs_index = _vector_search_cfg["index"]

# Model identifiers: embedding endpoint and the LLM used for generation.
_models_cfg = config["models"]
embedding_model_endpoint = _models_cfg["embedding_model_endpoint"]
llm = _models_cfg["llm"]

# API token read from the secrets file.
API_TOKEN = secrets["auth"]["api_token"]

# Names under which the model and its serving endpoint are registered.
_registered_cfg = config["registered"]
registered_model_name = _registered_cfg["model_name"]
registered_endpoint_name = _registered_cfg["endpoint_name"]

# API URL of the workspace, taken from the current notebook context.
_notebook_context = dbutils.notebook.entry_point.getDbutils().notebook().getContext()
API_ROOT = _notebook_context.apiUrl().get()


### Create vector database from files

In [0]:
# Read the archive of previously written DSPy modules.
archive_path = f"/Workspace/Users/{user_path}/ADAS_DSPy/script_archive/dspy_modules.json"
with open(archive_path, "r") as file:
    data = json.load(file)

# Flatten to (id, module) pairs; the id is "<file_path>_<position of the module in its file>".
all_modules = []
for entry in data:
    source_path = entry['file_path']
    all_modules.extend(
        (f"{source_path}_{position}", module)
        for position, module in enumerate(entry['modules'])
    )

# One row per module, with the columns the vector index is built on.
source_df = spark.createDataFrame(
    [Row(id=module_id, module=module) for module_id, module in all_modules]
)
display(source_df)

In [0]:
# Both objects live in the same catalog and schema, so build that prefix once.
_namespace = f"{catalog_name}.{schema_name}"

# Fully qualified name of the source Delta table
source_table_fullname = f"{_namespace}.{source_table_name}"
# Fully qualified name of the vector index
vs_index_fullname = f"{_namespace}.{vs_index}"

In [0]:
# Persist the modules as a Delta table, passing the change-data-feed option on the write.
(
    source_df.write
    .format("delta")
    .option("delta.enableChangeDataFeed", "true")
    .saveAsTable(source_table_fullname)
)

In [0]:
# Build the vector store: provision the endpoint, then create a Delta Sync index on it.
vsc = VectorSearchClient()

# Endpoint provisioning is asynchronous. Block until it is ready so the index
# is not created against an endpoint that is still starting up.
vsc.create_endpoint_and_wait(
    name=vector_search_endpoint_name,
    endpoint_type="STANDARD",
)

endpoint = vsc.get_endpoint(name=vector_search_endpoint_name)

# Index creation (including the initial embedding sync) is asynchronous as well.
# Later cells query the index immediately, so wait here until it is ready.
index = vsc.create_delta_sync_index_and_wait(
    endpoint_name=vector_search_endpoint_name,
    source_table_name=source_table_fullname,
    index_name=vs_index_fullname,
    pipeline_type="TRIGGERED",
    primary_key="id",
    embedding_source_column="module",
    embedding_model_endpoint_name=embedding_model_endpoint,
)
index.describe()


### Defining agent

In [0]:
class RAG(dspy.Module):
    """Retrieval-augmented DSPy program that drafts a new module from retrieved examples.

    For each request, the most similar archived modules are fetched from the
    Databricks Vector Search index and handed to the language model as
    ``example_modules``; the model's answer is returned in ``new_module``.
    """

    def __init__(self, num_passages=3):
        """Set up the retriever, the language model and the predictor.

        Args:
            num_passages: Number of documents to retrieve from the vector
                index per request. This used to be ignored in favour of a
                hard-coded ``k=2``; it is now passed through to the retriever.
        """
        super().__init__()

        # Retriever that fetches relevant documents from the Databricks Vector Search index.
        # Host and token are read from the environment so the same class works
        # wherever DATABRICKS_HOST / DATABRICKS_TOKEN are provided.
        self.retriever = DatabricksRM(
            databricks_endpoint=os.getenv("DATABRICKS_HOST"),
            databricks_index_name=vs_index_fullname,
            databricks_token=os.getenv("DATABRICKS_TOKEN"),
            text_column_name="module",
            docs_id_column_name="id",
            k=num_passages,
        )
        # Language model used for response generation
        self.lm = dspy.LM(llm)

        # Program signature: the generator receives "example_modules" (retrieved
        # documents) and "module_request" (the user's request) and returns "new_module".
        signature = "example_modules, module_request -> new_module"

        # Response generator
        self.response_generator = dspy.Predict(signature)

    def forward(self, request):
        """Retrieve example modules for ``request`` and generate a new module.

        Args:
            request: Text describing the module to generate; also used as the
                vector-search query.

        Returns:
            The result of the ``dspy.Predict`` call for the signature above,
            i.e. a prediction carrying the ``new_module`` output field.
        """
        # Obtain context by executing a Databricks Vector Search query
        retrieved_context = self.retriever(request)

        # Generate with this module's own LM rather than any globally configured one
        with dspy.context(lm=self.lm):
            response = self.response_generator(
                example_modules=retrieved_context.docs, module_request=request
            )

        return response



### Deploying agent to endpoint

In [0]:
# Instantiate the agent and run one request as a smoke test.
rag = RAG()
# Invoke the module itself rather than .forward(): calling the instance goes through
# dspy.Module.__call__, which is the supported entry point and then dispatches to forward().
rag("Program to do multi-hop reasoning")

In [0]:
# Log the DSPy program inside an MLflow run and register it under the configured name.
_input_example = "Program to do multi-hop reasoning"

with mlflow.start_run():
    model_info = mlflow.dspy.log_model(
        dspy_model=rag,
        artifact_path="model",
        input_example=_input_example,
        pip_requirements=["dspy"],
        registered_model_name=registered_model_name,
    )

In [0]:
# Name of the model serving endpoint
endpoint_name = registered_endpoint_name

# Name of the registered MLflow model
model_name = registered_model_name

# Get the latest version of the MLflow model.
# The order of search results is not guaranteed to be by version number, so pick
# the highest version explicitly instead of relying on the first element.
mlflow_client = MlflowClient()
registered_versions = mlflow_client.search_model_versions(
    filter_string=f"name='{registered_model_name}'"
)
if not registered_versions:
    # Same exception type the former `[0]` lookup raised, with an actionable message.
    raise IndexError(f"No registered versions found for model '{registered_model_name}'")
# int(): version numbers may be reported as strings, and "10" must rank above "9".
model_version = max(registered_versions, key=lambda mv: int(mv.version)).version

# Specify the type of compute (CPU, GPU_SMALL, GPU_LARGE, etc.)
# "CPU_SMALL" is not an accepted workload type; the CPU option is spelled "CPU"
# and its size is controlled separately through workload_size below.
workload_type = "CPU"

# Specify the scale-out size of compute (Small, Medium, Large, etc.)
workload_size = "Small"

# Specify Scale to Zero(only supported for CPU endpoints)
scale_to_zero = False

# Workspace host name, used to build the serving REST API URL
serving_host = spark.conf.get("spark.databricks.workspaceUrl")

In [0]:
# Deploy the model to a model serving endpoint in Databricks.
# NOTE(review): PUT .../serving-endpoints/{name}/config updates an existing endpoint's
# configuration — confirm the endpoint has already been created.
data = {
    "name": endpoint_name,
    "served_entities": [
        {
            "entity_name": registered_model_name,
            "entity_version": model_version,
            "workload_size": workload_size,
            "scale_to_zero_enabled": scale_to_zero,
            "workload_type": workload_type,
            # Credentials the served model reads through os.getenv() at inference time.
            # NOTE(review): the token is sent as plain text in the endpoint config —
            # consider a secret reference instead.
            "environment_vars": {
                "DATABRICKS_TOKEN": f"{API_TOKEN}",
                "DATABRICKS_HOST": f"{API_ROOT}",
            },
        }
    ],
}

# The header name was previously misspelled ("Context-Type") with a non-standard value.
headers = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {API_TOKEN}",
}

url = f"https://{serving_host}/api/2.0/serving-endpoints/{endpoint_name}/config"

# A timeout keeps the cell from hanging indefinitely if the API does not answer.
REQUEST_TIMEOUT_SECONDS = 60
response = requests.put(
    url=url, json=data, headers=headers, timeout=REQUEST_TIMEOUT_SECONDS
)

# Show whatever the API returned; error responses are not guaranteed to be JSON.
try:
    response_body = response.json()
except ValueError:
    print(response.text)
else:
    print(json.dumps(response_body, indent=4))

# Fail the cell on an HTTP error instead of silently continuing after printing it.
response.raise_for_status()