In [0]:
%pip install 'databricks-sdk==0.61.0' 'databricks-connect==16.4.2' 'pyarrow<20' 'databricks-sdk[notebook]' 'databricks-agents==1.4.0' 'mlflow<=3.2' 'mlflow[databricks]' 'databricks-vectorsearch==0.57' 'langchain==0.3.27' 'langchain-mcp' 'langchain_core==0.3.74' 'databricks-langchain==0.7.1' 'bs4' 'dotenv' 'psycopg2-binary==2.9.9' 'pgvector==0.2.5' 'langgraph==0.3.4'
import os
if os.environ.get("DATABRICKS_RUNTIME_VERSION"):
    dbutils.library.restartPython()

[43mNote: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.[0m


In [0]:
%%writefile chain_postgres_genie.py
import functools
import os
import uuid
import json
from typing import Any, Generator, Literal, Optional, Dict, List

import mlflow
import pydantic
from mlflow.models import ModelConfig
from databricks.sdk import WorkspaceClient
from databricks_langchain import (
    ChatDatabricks,
    UCFunctionToolkit,
    DatabricksFunctionClient,
    set_uc_function_client
)
from databricks_langchain.genie import GenieAgent
from langchain_core.runnables import RunnableLambda
from langgraph.graph import END, StateGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import create_react_agent
from mlflow.langchain.chat_agent_langgraph import ChatAgentState
from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import (
    ChatAgentChunk,
    ChatAgentMessage,
    ChatAgentResponse,
    ChatContext,
)
from pydantic import BaseModel
from sqlalchemy import create_engine, text, event
from pgvector.psycopg2 import register_vector
from databricks.sdk import WorkspaceClient
from databricks_langchain import DatabricksEmbeddings
from databricks_langchain.chat_models import ChatDatabricks
from langchain.tools import Tool
from langchain.agents import create_tool_calling_agent, AgentExecutor
from langchain_core.prompts import ChatPromptTemplate
from langchain.schema.runnable import RunnableLambda
from mlflow.entities import SpanType, Document
from langchain_core.messages.ai import AIMessage

# Enable MLflow Tracing for LangChain
mlflow.autolog()
mlflow.langchain.autolog()

# Load chain configuration provided at logging/deployment time.
model_config: ModelConfig = mlflow.models.ModelConfig()

# Pydantic models for input validation
class Message(pydantic.BaseModel):
    role: str
    content: str
    name: Optional[str] = None

class Filters(pydantic.BaseModel):
    user_name: str  # Required
    chat_id: Optional[str] = None

class CustomInputs(pydantic.BaseModel):
    filters: Filters
    k: Optional[int] = None  # Optional, will default to model_config value

class ChatRequest(pydantic.BaseModel):
    messages: List[Message]
    custom_inputs: Optional[CustomInputs] = None

class ChatResponse(pydantic.BaseModel):
    messages: List[Message]
    finish_reason: Optional[str] = None


def _get_required_env(name: str) -> str:
    value = os.environ.get(name)
    if not value:
        raise RuntimeError(f"Missing required environment variable: {name}")
    return value


def get_postgres_connection(
    client: WorkspaceClient,
    db_instance_name: str,
    database_name: Optional[str] = "databricks_postgres",
) -> str:
    """
    Build a PostgreSQL SQLAlchemy URL (psycopg2) using Databricks Database credentials.

    Uses POSTGRES_GROUP env var as username if set; otherwise current user.
    Always enforces sslmode=require.
    """
    database = client.database.get_database_instance(db_instance_name)
    credentials = client.database.generate_database_credential(
        instance_names=[db_instance_name], request_id=str(uuid.uuid4())
    )

    postgres_group = os.getenv("POSTGRES_GROUP")
    username = (
        postgres_group if postgres_group else client.current_user.me().user_name
    )

    host = database.read_write_dns
    port = "5432"
    password = credentials.token
    db_name = database_name or "databricks_postgres"

    # SQLAlchemy URL with psycopg2 driver
    sqlalchemy_url = (
        f"postgresql+psycopg2://{username}:{password}@{host}:{port}/{db_name}?sslmode=require"
    )
    return sqlalchemy_url


# --- Databricks Auth (required for both embeddings and DB credentials) ---
_DATABRICKS_HOST = _get_required_env("DATABRICKS_HOST")
_DATABRICKS_TOKEN = _get_required_env("DATABRICKS_TOKEN")

workspace_client = WorkspaceClient(host=_DATABRICKS_HOST, token=_DATABRICKS_TOKEN)


# --- Postgres Engine (pgvector) ---
def _build_engine() -> Any:
    # Allow configuration via model_config or environment variables
    db_instance_name = (
        os.environ.get("DATABASE_INSTANCE_NAME")
        or model_config.get("database_instance_name")
    )
    if not db_instance_name:
        raise RuntimeError(
            "A Postgres database instance name is required. Set env 'DATABASE_INSTANCE_NAME' "
            "or include 'database_instance_name' in the model_config."
        )

    postgres_database_name = (
        os.environ.get("POSTGRES_DATABASE_NAME")
        or model_config.get("postgres_database_name")
        or "databricks_postgres"
    )

    database_url = get_postgres_connection(
        workspace_client, db_instance_name, postgres_database_name
    )

    engine = create_engine(database_url, pool_pre_ping=True)

    @event.listens_for(engine, "connect")
    def _register_vector(dbapi_connection, connection_record):  # noqa: ANN001
        # Map Python lists to pgvector type for psycopg2
        register_vector(dbapi_connection)

    return engine


engine = _build_engine()


# --- Embeddings ---
embeddings = DatabricksEmbeddings(
    endpoint=model_config.get("embedding_model"),
    token=_DATABRICKS_TOKEN,
)

# --- Vector similarity search over Postgres (pgvector) ---
@mlflow.trace
def pg_vector_similarity_search(
    query_text: str,
    k: int = 3,
    filters: Optional[Dict[str, Any]] = None,
) -> str:
    """
    Perform similarity search against message embeddings in Postgres (pgvector).

    Schema expectations:
    - message_embeddings(me: id, message_id, user_name, chat_id, embedding vector)
    - chat_history(ch: id, message_content, message_type, created_at, message_order)
    """
    filters = filters or {}

    # 1) Embed the query
    query_embedding = embeddings.embed_query(query_text)

    # 2) WHERE clause from filters
    where_conditions: List[str] = []
    params: Dict[str, Any] = {}

    if "user_name" in filters:
        where_conditions.append("me.user_name = :user_name")
        params["user_name"] = filters["user_name"]

    where_clause = ""
    if where_conditions:
        where_clause = "WHERE " + " AND ".join(where_conditions)

    # 3) Query using cosine distance operator (<=>) provided by pgvector
    sql = text(
        f"""
        SELECT
            ch.message_content,
            me.user_name,
            me.chat_id,
            ch.message_type,
            ch.created_at,
            ch.message_order,
            (me.embedding <=> CAST(:query_embedding AS vector)) AS distance
        FROM message_embeddings me
        JOIN chat_history ch ON me.message_id = ch.id
        {where_clause}
        ORDER BY me.embedding <=> CAST(:query_embedding AS vector)
        LIMIT :k
        """
    )

    span = mlflow.get_current_active_span()
    span.set_outputs([Document(page_content=sql)])

    with engine.connect() as conn:
        rows = conn.execute(
            sql, {"query_embedding": query_embedding, "k": k, **params}
        ).fetchall()

    passages = [f"Passage: {r.message_content}" for r in rows]
    return "\n".join(passages)
  

def create_context_aware_vector_search_tool(state, custom_k: Optional[int] = None):
    """Create a vector search tool that has access to user context from state"""
    
    def filtered_vector_search(query: str) -> str:
        # Extract user context from state
        user_context = state.get("user_context", {})
        filters = user_context.get("filters", {})
        
        # Use custom k if provided, otherwise fall back to model_config default
        k = custom_k if custom_k is not None else model_config.get('k')
        
        # Use your existing pg_vector_similarity_search with filters and custom k
        return pg_vector_similarity_search(
            query_text=query, 
            k=k, 
            filters=filters
        )
    
    return Tool(
        name="search_chat_history",
        description="Retrieve chat history from Postgres (pgvector) for the current user; use only if the immediate conversation context is insufficient. The input to this function should be the user message.",
        func=filtered_vector_search,
    )


# Marketing Policy Agent - Knowledge Assistant Integration
class MarketingPolicyAgent:
    """Agent for validating marketing policy compliance using Databricks Knowledge Assistant"""
    
    def __init__(self, endpoint_name: str, client: WorkspaceClient, description: str):
        self.endpoint_name = endpoint_name
        self.client = client
        self.description = description
        
    def invoke(self, state):
        """Invoke the marketing policy agent via Databricks serving endpoint"""
        try:
            messages = state.get("messages", [])
            
            # Build the request for the knowledge assistant
            payload = { 'input': messages}
            
            # Call the knowledge assistant endpoint
            response = self.client.api_client.do(
                method="POST",
                path=f"/serving-endpoints/{self.endpoint_name}/invocations",
                headers={"Content-Type": "application/json"},
                data=json.dumps(payload)
            )
            
            # Extract content from response
            content = ""
            if isinstance(response, dict):
                content = response['output'][0]["content"][0]['text']
            else:
                content = str(response)
            
            # Return in the expected format
            return {
                "messages": [AIMessage(content=content)]
            }
            
        except Exception as e:
            return {
                "messages": [{
                    "role": "assistant", 
                    "content": f"Error validating marketing policy compliance: {str(e)}"
                }]
            }


genie_agent_description = model_config.get('genie_agent_description')
general_assistant_description = model_config.get('general_assistant_description')
marketing_policy_agent_description = model_config.get('marketing_policy_agent_description')

genie_agent = GenieAgent(
    genie_space_id=model_config.get('genie_space_id'),
    genie_agent_name="Genie",
    description=genie_agent_description,
    client=workspace_client,
    include_context=True,
)

# Create Marketing Policy Agent
marketing_policy_agent = MarketingPolicyAgent(
    endpoint_name=model_config.get('marketing_policy_endpoint'),
    client=workspace_client,
    description=marketing_policy_agent_description
)

# Max number of interactions between agents
MAX_ITERATIONS = 3

worker_descriptions = {
    "Genie": genie_agent_description,
    "General": general_assistant_description,
    "MarketingPolicy": marketing_policy_agent_description,
}

formatted_descriptions = "\n".join(
    f"- {name}: {desc}" for name, desc in worker_descriptions.items()
)

system_prompt = f"""You are routing between specialized agents. Route to:
- Genie: For data queries requiring database access.
- General: To synthesize and present final answers when sufficient data is available.
- MarketingPolicy: To validate marketing compliance, policy adherence, and brand guidelines.
- FINISH: When a complete answer has been provided

Available agents:
{formatted_descriptions}"""

options = ["FINISH"] + list(worker_descriptions.keys())
FINISH = {"next_node": "FINISH"}
MARKETING_POLICY = {"next_node": "MarketingPolicy"}

# Our foundation model answering the final prompt
model = ChatDatabricks(
    endpoint=model_config.get("llm_model_serving_endpoint_name"),
    extra_params={"temperature": 0.01, "max_tokens": 500}
)

# Custom Static Tools
tools = []

def supervisor_agent(state):
    count = state.get("iteration_count", 0) + 1
    if count > MAX_ITERATIONS:
        return FINISH
    
    # Check if Genie just provided a data-rich response
    messages = state.get("messages", [])
    if messages:
        last_message = messages[-1] if messages else {}
        
        if (isinstance(last_message, dict) and 
            last_message.get("name") == "Genie" and 
            last_message.get("content", "").strip() and
            len(last_message.get("content", "")) > 50):  # Assume substantial data
            return FINISH
    
    class nextNode(BaseModel):
        next_node: Literal[tuple(options)]

    preprocessor = RunnableLambda(
        lambda state: [{"role": "system", "content": system_prompt}] + state["messages"]
    )
    supervisor_chain = preprocessor | model.with_structured_output(nextNode)
    next_node = supervisor_chain.invoke(state).next_node
    
    # if routed back to the same node, exit the loop
    if state.get("next_node") == next_node:
        return FINISH
    return {
        "iteration_count": count,
        "next_node": next_node
    }

#######################################
# Define our multiagent graph structure
#######################################


def agent_node(state, agent, name):
    result = agent.invoke(state)
    return {
        "messages": [
            {
                "role": "assistant",
                "content": result["messages"][-1].content,
                "name": name,
            }
        ]
    }


def final_answer(state):
    # Check if we have data-rich responses from Genie
    messages = state.get("messages", [])
    prompt = "Using only the content in the messages, respond to the previous user question using the answer given by the other assistant messages."
    
    preprocessor = RunnableLambda(
        lambda state: state["messages"] + [{"role": "user", "content": prompt}]
    )
    final_answer_chain = preprocessor | model
    return {"messages": [final_answer_chain.invoke(state)]}


def agent_node_with_context(state, agent, name, custom_k: Optional[int] = None):
    """Enhanced agent node that injects context-aware tools"""
    
    # Create the shared vector search tool with current state context and custom k
    vector_search_tool = create_context_aware_vector_search_tool(state, custom_k)
    
    if name == "Genie":
        # Genie already has its tools, just add vector search
        enhanced_agent = agent  # Genie agent already configured
        
    elif name == "MarketingPolicy":
        # Marketing Policy agent already configured
        enhanced_agent = agent
        
    elif name == "General":
        # Add vector search tool to General agent
        enhanced_tools = tools + [vector_search_tool]
        enhanced_agent = create_react_agent(model, tools=enhanced_tools)
        
    # Execute with enhanced agent
    result = enhanced_agent.invoke(state)
    return {
        "messages": [{
            "role": "assistant",
            "content": result["messages"][-1].content,
            "name": name,
        }]
    }

class AgentState(ChatAgentState):
    next_node: str
    iteration_count: int
    user_context: Optional[Dict[str, Any]] = None
    custom_k: Optional[int] = None

# Create enhanced agent nodes
def enhanced_genie_node(state):
    custom_k = state.get("custom_k")
    return agent_node_with_context(state, genie_agent, "Genie", custom_k)

def enhanced_general_node(state):
    custom_k = state.get("custom_k")
    return agent_node_with_context(state, None, "General", custom_k)

def enhanced_marketing_policy_node(state):
    custom_k = state.get("custom_k")
    return agent_node_with_context(state, marketing_policy_agent, "MarketingPolicy", custom_k)

workflow = StateGraph(AgentState)
# Agent States
workflow.add_node("Genie", enhanced_genie_node)
workflow.add_node("General", enhanced_general_node)
workflow.add_node("MarketingPolicy", enhanced_marketing_policy_node)
# Supervisor States
workflow.add_node("supervisor", supervisor_agent)
workflow.add_node("final_answer", final_answer)

workflow.set_entry_point("supervisor")
# We want our workers to ALWAYS "report back" to the supervisor when done
for worker in worker_descriptions.keys():
    workflow.add_edge(worker, "supervisor")

# Let the supervisor decide which next node to go
workflow.add_conditional_edges(
    "supervisor",
    lambda x: x["next_node"],
    {**{k: k for k in worker_descriptions.keys()}, "FINISH": "final_answer"},
)
workflow.add_edge("final_answer", END)
multi_agent = workflow.compile()

###################################
# Streaming LangGraph ChatAgent
###################################

class PostgresGenieChatAgent(ChatAgent):
    def __init__(self, agent: CompiledStateGraph):
        self.agent = agent
    
    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        """Non-streaming predict method for backward compatibility"""
        # Extract user context and custom_k from custom_inputs
        user_context = {}
        custom_k = None
        
        if custom_inputs:
            if "filters" in custom_inputs:
                user_context["filters"] = custom_inputs["filters"]
            custom_k = custom_inputs.get("k")
        
        agent_request = {
            "messages": [m.model_dump_compat(exclude_none=True) for m in messages],
            "user_context": user_context,
            "custom_k": custom_k
        }

        response_messages = []
        for event in self.agent.stream(agent_request, stream_mode="updates"):
            for node_data in event.values():
                response_messages.extend(
                    ChatAgentMessage(**msg) for msg in node_data.get("messages", [])
                )
        
        return ChatAgentResponse(messages=response_messages)

    def predict_stream(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> Generator[ChatAgentChunk, None, None]:
        """Streaming predict method - yields incremental responses as they're generated"""
        # Extract user context and custom_k from custom_inputs
        user_context = {}
        custom_k = None
        
        if custom_inputs:
            if "filters" in custom_inputs:
                user_context["filters"] = custom_inputs["filters"]
            custom_k = custom_inputs.get("k")
        
        agent_request = {
            "messages": [m.model_dump_compat(exclude_none=True) for m in messages],
            "user_context": user_context,
            "custom_k": custom_k
        }

        for event in self.agent.stream(agent_request, stream_mode="updates"):
            for node_data in event.values():
                yield from (
                    ChatAgentChunk(**{"delta": msg})
                    for msg in node_data.get("messages", [])
                )

# Create the streaming model instance and set it for MLflow
streaming_model_instance = PostgresGenieChatAgent(multi_agent)
mlflow.models.set_model(model=streaming_model_instance)

Overwriting chain_postgres_genie.py


In [0]:
dbutils.widgets.text("embedding_model", "databricks-gte-large-en")
dbutils.widgets.text("database_instance_name", "adtech-series-do-not-delete")
dbutils.widgets.text("postgres_database_name", "databricks_postgres")
dbutils.widgets.text("llm_model_serving_endpoint_name", "databricks-claude-3-7-sonnet")
dbutils.widgets.text("target_catalog", "tanner_wendland")
dbutils.widgets.text("target_schema", "default")
dbutils.widgets.text("genie_space_id", "01f07c4ba44615aab8989b10e0a95420")
dbutils.widgets.text("marketing_policy_endpoint", "ka-c9af3fe4-endpoint")

In [0]:
embedding_model = dbutils.widgets.get("embedding_model")
database_instance_name = dbutils.widgets.get("database_instance_name")
postgres_database_name = dbutils.widgets.get("postgres_database_name")
llm_model_serving_endpoint_name = dbutils.widgets.get("llm_model_serving_endpoint_name")
target_catalog = dbutils.widgets.get("target_catalog")
target_schema = dbutils.widgets.get("target_schema")
genie_space_id = dbutils.widgets.get("genie_space_id")
marketing_policy_endpoint = dbutils.widgets.get("marketing_policy_endpoint")

In [0]:
import functools
import os
from typing import Any, Generator, Literal, Optional

In [0]:
import mlflow
mlflow.set_registry_uri("databricks-uc")
mlflow.autolog()
mlflow.langchain.autolog()



## Chain Config

In [0]:
chain_config = {
    "llm_model_serving_endpoint_name": llm_model_serving_endpoint_name,
    "embedding_model": embedding_model,
    "database_instance_name": database_instance_name,
    "postgres_database_name": postgres_database_name,
    "genie_space_id": genie_space_id,
    "marketing_policy_endpoint": marketing_policy_endpoint,
    "k": 3,
    "genie_agent_description": """This Genie agent is designed for marketing analysis, enabling detailed exploration and reporting on audience segments, and individual demographic profiles. The agent integrates campaign data, segment definitions, and audience census profiles. The agent allows to: Understand audience segment definitions and membership, Profiling the campaign audience on demographic dimensions such as age and gender, Developing custom ad-hoc queries for marketing analysts""",
    "general_assistant_description": "The General Assistant synthesizes information from data sources and other agents to provide confident, direct answers. When data is available from previous analysis, present it clearly without unnecessary hedging or disclaimers.",
    "marketing_policy_agent_description": """This agent is a knowledge assistant to help us adher to marketing policies. Use this agent to validate campaign recommendations, creative content, and data interpretations for policy compliance.""",
}

## Chain PY

In [0]:
from mlflow.models.resources import (
    DatabricksVectorSearchIndex,
    DatabricksServingEndpoint,
    DatabricksGenieSpace,
)

chain_file_path = os.path.join(os.getcwd(), "chain_postgres_genie.py")
if not os.path.exists(chain_file_path):
    raise FileNotFoundError(f"Chain file not found at {chain_file_path}")

workspace_url = spark.conf.get("spark.databricks.workspaceUrl")
print(f"Workspace URL: {workspace_url}")
os.environ["DATABRICKS_HOST"] = f"https://{workspace_url}"
os.environ["DATABRICKS_TOKEN"] = (
    dbutils.entry_point.getDbutils().notebook().getContext().apiToken().get()
)

# Pydantic-based input example - schema will be auto-inferred from type hints
# Note: MLflow pyfunc expects List[InputType], so wrap in a list
input_example = {
    "messages": [{"role": "user", "content": "What was my chat history idea?"}]
}

model_config = mlflow.models.ModelConfig(development_config=chain_config)

# Use pyfunc log_model with the file path (models-from-code approach)
with mlflow.start_run(run_name="adtech_chat_history_agent_postgres_genie"):
    logged_chain_info = mlflow.pyfunc.log_model(
        python_model=chain_file_path,  # Path to the chain file
        model_config=chain_config,
        artifact_path="chat_agent",
        input_example=input_example,  # Schema auto-inferred from List[ChatRequest] type hint
        # Specify resources for automatic authentication passthrough
        resources=[
            DatabricksServingEndpoint(
                endpoint_name=model_config.get("llm_model_serving_endpoint_name")
            ),
            DatabricksGenieSpace(genie_space_id=model_config.get("genie_space_id")),
            DatabricksServingEndpoint(
                endpoint_name=model_config.get("marketing_policy_endpoint")
            ),
        ],
        pip_requirements=[
            "mlflow==3.2.0",
            "databricks-agents==1.4.0",
            "databricks-langchain==0.7.1",
            "langchain==0.3.27",
            "pgvector==0.2.5",
            "psycopg2-binary==2.9.9",
            "pydantic==2.11.7",
            "sqlalchemy==2.0.43",
            "tornado==6.3.2",
            "langgraph==0.3.4",
        ],
    )

model_name = "chat_history_agent_postgres_genie"
MODEL_NAME_FQN = f"{target_catalog}.{target_schema}.{model_name}"
# Register to UC
uc_registered_model_info = mlflow.register_model(
    model_uri=logged_chain_info.model_uri, name=MODEL_NAME_FQN
)



Workspace URL: e2-demo-field-eng.cloud.databricks.com


🔗 View Logged Model at: https://e2-demo-field-eng.cloud.databricks.com/ml/experiments/380662602719503/models/m-3645148db6e44ad7abb32cf10ca49e1a?o=1444828305810485
2025/09/19 17:41:17 INFO mlflow.pyfunc: Predicting on input example to validate output
 - pydantic (current: 2.11.9, required: pydantic==2.11.7)
 - tornado (current: 6.4.1, required: tornado==6.3.2)
To fix the mismatches, call `mlflow.pyfunc.get_model_dependencies(model_uri)` to fetch the model's environment and install dependencies using the resulting environment file.
 - pydantic (current: 2.11.9, required: pydantic==2.11.7)
 - tornado (current: 6.4.1, required: tornado==6.3.2)
To fix the mismatches, call `mlflow.pyfunc.get_model_dependencies(model_uri)` to fetch the model's environment and install dependencies using the resulting environment file.
2025/09/19 17:41:42 INFO mlflow.models.model: Found the following environment variables used during model inference: [DATABRICKS_HOST, DATABRICKS_TOKEN]. Please check if you need

Downloading artifacts:   0%|          | 0/13 [00:00<?, ?it/s]

Uploading artifacts:   0%|          | 0/14 [00:00<?, ?it/s]

🔗 Created version '55' of model 'tanner_wendland.default.chat_history_agent_postgres_genie': https://e2-demo-field-eng.cloud.databricks.com/explore/data/models/tanner_wendland/default/chat_history_agent_postgres_genie/version/55?o=1444828305810485


In [0]:
AGENT = mlflow.pyfunc.load_model(logged_chain_info.model_uri)

Downloading artifacts:   0%|          | 0/13 [00:00<?, ?it/s]

 - pydantic (current: 2.11.9, required: pydantic==2.11.7)
 - tornado (current: 6.4.1, required: tornado==6.3.2)
To fix the mismatches, call `mlflow.pyfunc.get_model_dependencies(model_uri)` to fetch the model's environment and install dependencies using the resulting environment file.


## Test Document Retreival

In [0]:
# Test with custom k parameter using pydantic structure
# Note: Wrap in list because MLflow pyfunc expects List[ChatRequest]
test_input_custom_k = [{
    "messages": [{"role": "user", "content": "What was my chat history idea?"}],
    "custom_inputs": {
        "filters": {"user_name": "tanner.wendland@databricks.com"},
        "k": 10  # Request 10 records instead of default 3
    }
}]

answer = AGENT.predict(test_input_custom_k)
print("=== Test with custom k=10 ===")
print(answer)

# Test with default k (no custom_inputs provided)
test_input_default = [{
    "messages": [{"role": "user", "content": "What campaign has the most dog owners?"}]
    # No custom_inputs - should use defaults
}]

answer_default = AGENT.predict(test_input_default)
print("\n=== Test with no custom_inputs (should use defaults) ===")
print(answer_default)

# Test with required user_name but no k (should use default k)
test_input_required_user = [{
    "messages": [{"role": "user", "content": "What policies do we have regarding tracking dog ownership?"}],
    "custom_inputs": {
        "filters": {"user_name": "tanner.wendland@databricks.com"}
        # No k specified - should use default from model_config
    }
}]

answer_required = AGENT.predict(test_input_required_user)
print("\n=== Test with required user_name but default k ===")
print(answer_required)

=== Test with custom k=10 ===
{'messages': [{'role': 'assistant', 'content': 'Based on your chat history, I don\'t see a specific "chat history idea" that you previously mentioned. I can see that you\'ve had conversations about:\n\n1. Marketing campaigns data (Terrific Tacos, Best Burgers, Moving Movie, Fresh Flowers, Super Savings, Cool Car, Happy Dogs)\n2. Census data related to these campaigns\n3. Queries about unique people counts and demographic information\n\nIf you\'re referring to a specific idea about chat history functionality or something else that isn\'t appearing in these results, could you please provide more details about what you\'re looking for? Perhaps it was from an earlier conversation or you\'re referring to something specific that needs additional context.', 'name': 'General', 'id': '651044fa-b2c9-4040-b374-c97d77e38c7e'}, {'role': 'assistant', 'content': 'I don\'t see any previous assistant messages in our conversation history that contain an answer to your quest

[Trace(trace_id=tr-7166ff915f5b65444e9448cf059d79da), Trace(trace_id=tr-bbeea7e0910a097519e3430398b75432), Trace(trace_id=tr-32b8f465221b824b74a76de314a128ca)]

## Test with Predict Stream

In [0]:
test_input_custom_k = [{
    "messages": [{"role": "user", "content": "What was my chat history idea?"}],
    "custom_inputs": {
        "filters": {"user_name": "tanner.wendland@databricks.com"},
        "k": 10  # Request 10 records instead of default 3
    }
}]

answer = AGENT.predict_stream(test_input_custom_k)
for a in answer:
  print(a)

{'delta': {'role': 'assistant', 'content': 'Based on searching your chat history, I don\'t see any specific "idea" you mentioned about chat history itself. What I can see is that you\'ve previously discussed:\n\n1. Marketing campaigns data including:\n   - Six campaigns: Terrific Tacos, Best Burgers, Moving Movie, Fresh Flowers, Super Savings, and Cool Car\n   - A Happy Dogs campaign was also mentioned\n   - Statistics about unique people counts across campaigns\n   - Demographics data like senior male counts\n\nIf you were referring to a different idea related to chat history, could you provide more details or context about what specific idea you\'re trying to recall?', 'name': 'General', 'id': 'e1f93135-bca7-4b8d-9a77-f61be91c7485'}}
{'delta': {'role': 'assistant', 'content': 'I don\'t see any previous assistant messages in our conversation history that contain an answer to your question about a "chat history idea." Our conversation has only included discussion about marketing campai

Trace(trace_id=tr-250dc2071a08df0bdbb76204ff635674)

## Deploy

In [0]:
dbutils.widgets.text("secert_scope", "adtech-series", "Secret Scope")
dbutils.widgets.text("secret_key", "app-secret", "Secret Key")
dbutils.widgets.text("permission_group", "Adtech Series DB Access Role", "Permission Group")

In [0]:
secret_scope = dbutils.widgets.get("secert_scope")
secret_key = dbutils.widgets.get("secret_key")
permission_group = dbutils.widgets.get("permission_group")

secret_value = dbutils.secrets.get(scope=secret_scope, key=secret_key)

In [0]:
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import (
    EndpointCoreConfigInput,
    ServedEntityInput,
    AiGatewayConfig,
    ServingEndpointAccessControlRequest,
    ServingEndpointPermissionLevel,
)
from databricks import agents

workspace_client = WorkspaceClient()

version = uc_registered_model_info.version
serving_endpoint_name = model_name.replace(".", "-")

workspace_url = spark.conf.get("spark.databricks.workspaceUrl")

config = {
    "served_entities": [
        {
            "name": serving_endpoint_name,
            "entity_name": MODEL_NAME_FQN,
            "entity_version": version,
            "workload_size": "Small",
            "scale_to_zero_enabled": True,
            "environment_vars": {
                "DATABRICKS_HOST": workspace_url,
                "DATABRICKS_TOKEN": secret_value,
            },
        }
    ]
}

ai_gateway_config = {
    "inference_table_config": {
        "enabled": True,
        "catalog_name": target_catalog,
        "schema_name": target_schema,
        "table_name": "chat_history_agent_postgres_genie_inference",
    }
}


def does_endpoint_exists(endpoint_name):
    try:
        workspace_client.serving_endpoints.get(endpoint_name)
        return True
    except:
        return False


print(f"Creating endpoint {serving_endpoint_name}...")
deployment = agents.deploy(
    model_name=MODEL_NAME_FQN,
    model_version=version,
    scale_to_zero=True,
    environment_vars={
        "DATABRICKS_HOST": workspace_url,
        "DATABRICKS_TOKEN": secret_value,
    },
    workload_size="Small",
    endpoint_name=serving_endpoint_name,
)

# Grant permissions to the specified group after endpoint deployment
print(f"Granting permissions to group: {permission_group}")
try:
    agents.set_permissions(model_name=MODEL_NAME_FQN, users=[permission_group], permission_level="CAN_QUERY")
    print(f"Successfully granted CAN_QUERY permission to group: {permission_group}")
except Exception as e:
    print(f"Failed to grant permissions to group {permission_group}: {str(e)}")
    print("Please manually grant permissions to the endpoint if needed.")

Creating endpoint chat_history_agent_postgres_genie...


Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

[0;31m---------------------------------------------------------------------------[0m
[0;31mBadRequest[0m                                Traceback (most recent call last)
File [0;32m<command-4640557381755985>, line 53[0m
[1;32m     49[0m         [38;5;28;01mreturn[39;00m [38;5;28;01mFalse[39;00m
[1;32m     52[0m [38;5;28mprint[39m([38;5;124mf[39m[38;5;124m"[39m[38;5;124mCreating endpoint [39m[38;5;132;01m{[39;00mserving_endpoint_name[38;5;132;01m}[39;00m[38;5;124m...[39m[38;5;124m"[39m)
[0;32m---> 53[0m deployment [38;5;241m=[39m agents[38;5;241m.[39mdeploy(
[1;32m     54[0m     model_name[38;5;241m=[39mMODEL_NAME_FQN,
[1;32m     55[0m     model_version[38;5;241m=[39mversion,
[1;32m     56[0m     scale_to_zero[38;5;241m=[39m[38;5;28;01mTrue[39;00m,
[1;32m     57[0m     environment_vars[38;5;241m=[39m{
[1;32m     58[0m         [38;5;124m"[39m[38;5;124mDATABRICKS_HOST[39m[38;5;124m"[39m: workspace_url,
[1;32m     59[0m       