In [0]:
%pip install unitycatalog-ai[databricks] databricks-connect

[43mNote: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.[0m


In [0]:
dbutils.library.restartPython()

In [0]:
from unitycatalog.ai.core.databricks import DatabricksFunctionClient

client = DatabricksFunctionClient()

In [0]:
# Example catalog and schema names
CATALOG = "mlops_pj"
SCHEMA = "rag_puneetjain"

In [0]:
from databricks.sdk import WorkspaceClient

workspace_client = WorkspaceClient()

secret_scope = "pj"  # Change me!

# Run this if you don't have the API key set to your secrets scope yet

# if secret_scope not in [scope.name for scope in workspace_client.secrets.list_scopes()]:
#     workspace_client.secrets.create_scope(secret_scope)

# my_secret = "Enter API key"

# workspace_client.secrets.put_secret(scope=secret_scope, key="token", string_value=my_secret)

In [0]:
 workspace_client.secrets.get_secret(scope=secret_scope, key="token").value

'ZGFwaTE5YWY4YmUwNDJmZjJiMTRkZjgwN2FlMzUzMjRiNTBi'

In [0]:
%sql
USE CATALOG mlops_pj;
USE SCHEMA rag_puneetjain;
CREATE OR REPLACE FUNCTION _genie_query(databricks_host STRING, 
                  databricks_token STRING,
                  space_id STRING,
                  question STRING,
                  contextual_history STRING)
RETURNS STRING
LANGUAGE PYTHON
COMMENT 'This is a agent that you can converse with to get answers to questions. Try to provide simple questions and provide history if you had prior conversations.'
AS
$$
    import json
    import os
    import time
    from dataclasses import dataclass
    from datetime import datetime
    from typing import Optional
    
    import pandas as pd
    import requests
    
    
    @dataclass
    class GenieResult:
        space_id: str
        conversation_id: str
        question: str
        content: Optional[str]
        sql_query: Optional[str] = None
        sql_query_description: Optional[str] = None
        sql_query_result: Optional[pd.DataFrame] = None
        error: Optional[str] = None
    
        def to_json_results(self):
            result = {
                "space_id": self.space_id,
                "conversation_id": self.conversation_id,
                "question": self.question,
                "content": self.content,
                "sql_query": self.sql_query,
                "sql_query_description": self.sql_query_description,
                "sql_query_result": self.sql_query_result.to_dict(
                    orient="records") if self.sql_query_result is not None else None,
                "error": self.error,
            }
            jsonified_results = json.dumps(result)
            return f"Genie Results are: {jsonified_results}"
    
        def to_string_results(self):
            results_string = self.sql_query_result.to_dict(orient="records") if self.sql_query_result is not None else None
            return ("Genie Results are: \n"
                    f"Space ID: {self.space_id}\n"
                    f"Conversation ID: {self.conversation_id}\n"
                    f"Question That Was Asked: {self.question}\n"
                    f"Content: {self.content}\n"
                    f"SQL Query: {self.sql_query}\n"
                    f"SQL Query Description: {self.sql_query_description}\n"
                    f"SQL Query Result: {results_string}\n"
                    f"Error: {self.error}")
    
    class GenieClient:
    
        def __init__(self, *,
                     host: Optional[str] = None,
                     token: Optional[str] = None,
                     api_prefix: str = "/api/2.0/genie/spaces"):
            self.host = host or os.environ.get("DATABRICKS_HOST")
            self.token = token or os.environ.get("DATABRICKS_TOKEN")
            assert self.host is not None, "DATABRICKS_HOST is not set"
            assert self.token is not None, "DATABRICKS_TOKEN is not set"
            self._workspace_client = requests.Session()
            self._workspace_client.headers.update({"Authorization": f"Bearer {self.token}"})
            self._workspace_client.headers.update({"Content-Type": "application/json"})
            self.api_prefix = api_prefix
            self.max_retries = 300
            self.retry_delay = 1
            self.new_line = "\r\n"
    
        def _make_url(self, path):
            return f"{self.host.rstrip('/')}/{path.lstrip('/')}"
    
        def start(self, space_id: str, start_suffix: str = "") -> str:
            path = self._make_url(f"{self.api_prefix}/{space_id}/start-conversation")
            resp = self._workspace_client.post(
                url=path,
                headers={"Content-Type": "application/json"},
                json={"content": "starting conversation" if not start_suffix else f"starting conversation {start_suffix}"},
            )
            resp = resp.json()
            print(resp)
            return resp["conversation_id"]
    
        def ask(self, space_id: str, conversation_id: str, message: str) -> GenieResult:
            path = self._make_url(f"{self.api_prefix}/{space_id}/conversations/{conversation_id}/messages")
            # TODO: cleanup into a separate state machine
            resp_raw = self._workspace_client.post(
                url=path,
                headers={"Content-Type": "application/json"},
                json={"content": message},
            )
            resp = resp_raw.json()
            message_id = resp.get("message_id", resp.get("id"))
            if message_id is None:
                print(resp, resp_raw.url, resp_raw.status_code, resp_raw.headers)
                return GenieResult(content=None, error="Failed to get message_id")
    
            attempt = 0
            query = None
            query_description = None
            content = None
    
            while attempt < self.max_retries:
                resp_raw = self._workspace_client.get(
                    self._make_url(f"{self.api_prefix}/{space_id}/conversations/{conversation_id}/messages/{message_id}"),
                    headers={"Content-Type": "application/json"},
                )
                resp = resp_raw.json()
                status = resp["status"]
                if status == "COMPLETED":
                    try:
    
                        query = resp["attachments"][0]["query"]["query"]
                        query_description = resp["attachments"][0]["query"].get("description", None)
                        content = resp["attachments"][0].get("text", {}).get("content", None)
                    except Exception as e:
                        return GenieResult(
                            space_id=space_id,
                            conversation_id=conversation_id,
                            question=message,
                            content=resp["attachments"][0].get("text", {}).get("content", None)
                        )
                    break
    
                elif status == "EXECUTING_QUERY":
                    self._workspace_client.get(
                        self._make_url(
                            f"{self.api_prefix}/{space_id}/conversations/{conversation_id}/messages/{message_id}/query-result"),
                        headers={"Content-Type": "application/json"},
                    )
                elif status in ["FAILED", "CANCELED"]:
                    return GenieResult(
                        space_id=space_id,
                        conversation_id=conversation_id,
                        question=message,
                        content=None,
                        error=f"Query failed with status {status}"
                    )
                elif status != "COMPLETED" and attempt < self.max_retries - 1:
                    time.sleep(self.retry_delay)
                else:
                    return GenieResult(
                        space_id=space_id,
                        conversation_id=conversation_id,
                        question=message,
                        content=None,
                        error=f"Query failed or still running after {self.max_retries * self.retry_delay} seconds"
                    )
                attempt += 1
            resp = self._workspace_client.get(
                self._make_url(
                    f"{self.api_prefix}/{space_id}/conversations/{conversation_id}/messages/{message_id}/query-result"),
                headers={"Content-Type": "application/json"},
            )
            resp = resp.json()
            columns = resp["statement_response"]["manifest"]["schema"]["columns"]
            header = [str(col["name"]) for col in columns]
            rows = []
            output = resp["statement_response"]["result"]
            if not output:
                return GenieResult(
                    space_id=space_id,
                    conversation_id=conversation_id,
                    question=message,
                    content=content,
                    sql_query=query,
                    sql_query_description=query_description,
                    sql_query_result=pd.DataFrame([], columns=header),
                )
            for item in resp["statement_response"]["result"]["data_typed_array"]:
                row = []
                for column, value in zip(columns, item["values"]):
                    type_name = column["type_name"]
                    str_value = value.get("str", None)
                    if str_value is None:
                        row.append(None)
                        continue
                    match type_name:
                        case "INT" | "LONG" | "SHORT" | "BYTE":
                            row.append(int(str_value))
                        case "FLOAT" | "DOUBLE" | "DECIMAL":
                            row.append(float(str_value))
                        case "BOOLEAN":
                            row.append(str_value.lower() == "true")
                        case "DATE":
                            row.append(datetime.strptime(str_value, "%Y-%m-%d").date())
                        case "TIMESTAMP":
                            row.append(datetime.strptime(str_value, "%Y-%m-%d %H:%M:%S"))
                        case "BINARY":
                            row.append(bytes(str_value, "utf-8"))
                        case _:
                            row.append(str_value)
                rows.append(row)
    
            query_result = pd.DataFrame(rows, columns=header)
            return GenieResult(
                space_id=space_id,
                conversation_id=conversation_id,
                question=message,
                content=content,
                sql_query=query,
                sql_query_description=query_description,
                sql_query_result=query_result,
            )
    
    
    assert databricks_host is not None, "host is not set"
    assert databricks_token is not None, "token is not set"
    assert space_id is not None, "space_id is not set"
    assert question is not None, "question is not set"
    assert contextual_history is not None, "contextual_history is not set"
    client = GenieClient(host=databricks_host, token=databricks_token)
    conversation_id = client.start(space_id)
    formatted_message = f"""Use the contextual history to answer the question. The history may or may not help you. Use it if you find it relevant.
    
    Contextual History: {contextual_history}
    
    Question to answer: {question}
    """
    
    result = client.ask(space_id, conversation_id, formatted_message)
    
    return result.to_string_results()

$$;


In [0]:
%sql
CREATE OR REPLACE FUNCTION ask_forecasting_questions(question STRING COMMENT "The question to ask about the updates and adjusted forecast", contextual_history STRING COMMENT "provide relevant history to be able to answer this question, assume genie doesn\'t keep track of history. Use \'no relevant history\' if there is nothing relevant to answer the question.")
RETURNS STRING
LANGUAGE SQL
COMMENT 'This Agent interacts with the Genie space API to provide answers to questions about the lastest updates and adjusted forecasts'  
RETURN SELECT _genie_query(
  "https://adb-984752964297111.11.azuredatabricks.net/",
  secret('pj', 'token'),
  '01efeaca45c813aba0f6cd461d2c9c9f',
  question, -- retrieved from function
  contextual_history -- retrieved from function
);


In [0]:
import os
import requests
import numpy as np
import pandas as pd
import json


import json
import requests
from typing import Any, Dict

def score_model(
    item_id: int,
    location_id: int,
    date: str,
    current_base_forecast: float,
    social_media_index: float,
    local_events_count: int,
    location_weather: str,
    adjusted_demand_forecast: float,
    lag1_pos_sales: float,
    lag2_pos_sales: float,
    secrets : str
) -> str:
    """This is a function which can give updated shipment forecasts for a given item in a given location on a given date.

    Args:
        item_id (int): The unique identifier for the item.
        location_id (int): The unique identifier for the location.
        date (str): The forecast date in 'YYYY-MM-DD' format.
        current_base_forecast (float): The current base forecast value.
        social_media_index (float): The social media index value.
        local_events_count (int): The number of local events.
        location_weather (str): A string describing the location's weather.
        adjusted_demand_forecast (float): The adjusted demand forecast value.
        lag1_pos_sales (float): The previous period's positive sales (lag 1).
        lag2_pos_sales (float): The sales from two periods ago (lag 2).

    Returns:
        Dict[str, Any]: The JSON response from the scoring endpoint.
    
    Raises:
        Exception: If the remote request returns a status code other than 200.
    """
    import json
    import requests
    from typing import Any, Dict

    url = 'https://adb-984752964297111.11.azuredatabricks.net/serving-endpoints/ship_str_ds_forecast/invocations'
    headers = {
        'Authorization': f"Bearer {secrets}",
        'Content-Type': 'application/json'
    }

    query_dict = {
        "index": [0],
        "columns": [
            "item_id", "location_id", "date", "current_base_forecast",
            "social_media_index", "local_events_count", "location_weather",
            "adjusted_demand_forecast", "lag1_pos_sales", "lag2_pos_sales"
        ],
        "data": [[
            item_id, location_id, date, current_base_forecast,
            social_media_index, local_events_count, location_weather,
            adjusted_demand_forecast, lag1_pos_sales, lag2_pos_sales
        ]]
    }
    query_dict = {"dataframe_split": query_dict}
    data_json = json.dumps(query_dict, allow_nan=True)
    response = requests.request(method='POST', headers=headers, url=url, data=data_json)
    
    if response.status_code != 200:
        raise Exception(f'Request failed with status {response.status_code}, {response.text}')
    
    return response.json()

In [0]:
client.create_python_function(func=score_model, catalog=CATALOG, schema=SCHEMA, replace=True)

  check_docstring_signature_consistency(docstring_info.params, params_in_signature, func_name)
  check_function_info(created_function_info)


FunctionInfo(browse_only=None, catalog_name='mlops_pj', comment='This is a function which can give updated shipment forecasts for a given item in a given location on a given date.', created_at=1739570303590, created_by='puneet.jain@databricks.com', data_type=<ColumnTypeName.STRING: 'STRING'>, external_language='Python', external_name=None, full_data_type='STRING', full_name='mlops_pj.rag_puneetjain.score_model', function_id='ebf6a451-b276-4ca1-9564-f21fd2c09136', input_params=FunctionParameterInfos(parameters=[FunctionParameterInfo(name='item_id', type_text='bigint', type_name=<ColumnTypeName.LONG: 'LONG'>, position=0, comment='The unique identifier for the item.', parameter_default=None, parameter_mode=None, parameter_type=<FunctionParameterType.PARAM: 'PARAM'>, type_interval_type=None, type_json='{"name":"item_id","type":"long","nullable":true,"metadata":{"comment":"The unique identifier for the item."}}', type_precision=0, type_scale=0), FunctionParameterInfo(name='location_id', typ

In [0]:
%sql
CREATE OR REPLACE FUNCTION mlops_pj.rag_puneetjain.get_updated_shipment_forecast(
  item_id INTEGER COMMENT "The unique identifier for the item.",
  location_id INTEGER COMMENT "The unique identifier for the location",
  date  STRING COMMENT "The forecast date in 'YYYY-MM-DD' format",
  current_base_forecast FLOAT COMMENT "The current base forecast value",
  social_media_index FLOAT COMMENT "The social media index value.",
  local_events_count INTEGER COMMENT "The number of local events.",
  location_weather STRING COMMENT "A string describing the location's weather.",
  adjusted_demand_forecast float COMMENT "The adjusted demand forecast value.",
  lag1_pos_sales FLOAT COMMENT "The previous period's positive sales (lag 1)",
  lag2_pos_sales FLOAT COMMENT "The sales from two periods ago (lag 2)")
RETURNS STRING
LANGUAGE SQL
COMMENT 'This is a function which can give updated shipment forecasts for a given item in a given location on a given date'  
RETURN SELECT mlops_pj.rag_puneetjain.score_model(
   item_id,
   location_id,
   date,
   current_base_forecast,
   social_media_index,
   local_events_count,
   location_weather,
   adjusted_demand_forecast,
   lag1_pos_sales,
   lag2_pos_sales,
   secret('pj', 'token')
);


In [0]:
score_model(
   item_id = 1,
    location_id =1,
    date= "2024-10-12",
    current_base_forecast=101,
    social_media_index =105,
    local_events_count =1,
    location_weather= "cloudy",
    adjusted_demand_forecast= 103,
    lag1_pos_sales= 100,
    lag2_pos_sales= 97
)

{'predictions': [100.1899873703072]}