#### Fraud detection agent

This notebook creates an agent with tools to detect and flag fraudulent transactions.

#### Tool & View Registration

In [0]:
%sql
-- Create the `ai` schema (if it is missing) that the functions and view defined below live in.
CREATE SCHEMA IF NOT EXISTS ${CATALOG}.ai;

In [None]:
%sql
-- Tool function: every event recorded for one transaction id, with the
-- branch name resolved through the branches table (NULL when no branch matches).
CREATE OR REPLACE FUNCTION ${CATALOG}.ai.get_transaction_details(
  tid STRING COMMENT 'transaction id of the transaction'
)
RETURNS TABLE (
  body STRING COMMENT 'Body of the event',
  event_type STRING COMMENT 'The type of event',
  transaction_id STRING COMMENT 'The transaction id',
  ts STRING COMMENT 'The timestamp of the event',
  branch STRING COMMENT 'the branch where the transaction occurred'
)
COMMENT 'Returns all events associated with the transaction id (tid)'
RETURN
  SELECT
    evt.body,
    evt.event_type,
    evt.transaction_id,
    evt.ts,
    brn.name AS branch
  FROM ${CATALOG}.lakeflow.all_events AS evt
  LEFT JOIN ${CATALOG}.simulator.branches AS brn
    ON evt.branch_id = brn.branch_id
  WHERE evt.transaction_id = tid;

In [None]:
%sql
-- Tool function: earliest and latest event timestamps for one transaction id
-- and the whole-second difference between them.
CREATE OR REPLACE FUNCTION ${CATALOG}.ai.get_transaction_processing_time(
  tid STRING COMMENT 'transaction id of the transaction'
)
RETURNS TABLE (
  transaction_id STRING COMMENT 'The transaction id',
  creation_time TIMESTAMP COMMENT 'The timestamp of the first event for the transaction',
  completion_time TIMESTAMP COMMENT 'The timestamp of the last event for the transaction',
  duration_seconds FLOAT COMMENT 'The total duration from the first to the last event in seconds'
)
COMMENT 'Returns the first event time, last event time, and total duration for a given transaction id.'
RETURN
  WITH event_bounds AS (
    -- Single aggregate row: first and last parsed timestamp among the transaction's events.
    SELECT
      MIN(try_to_timestamp(evt.ts)) AS first_event_time,
      MAX(try_to_timestamp(evt.ts)) AS last_event_time
    FROM ${CATALOG}.lakeflow.all_events AS evt
    WHERE evt.transaction_id = tid
  )
  SELECT
    tid AS transaction_id,
    eb.first_event_time AS creation_time,
    eb.last_event_time AS completion_time,
    CAST(
      (UNIX_TIMESTAMP(eb.last_event_time) - UNIX_TIMESTAMP(eb.first_event_time)) AS FLOAT
    ) AS duration_seconds
  FROM event_bounds AS eb;

In [None]:
%sql
-- Per-branch timing profile built from events whose parsed timestamp falls within the last day.
-- Output: one row per branch with P50/P75/P99 of end-to-end minutes and the mean fraud-check seconds.
CREATE OR REPLACE VIEW ${CATALOG}.ai.transaction_patterns_per_branch_view AS
WITH transaction_times AS (
  -- One row per (transaction, branch name): the latest timestamp seen for each milestone event type.
  SELECT
    ae.transaction_id,
    br.name as branch,
    MAX(CASE WHEN ae.event_type = 'transaction_created' THEN try_to_timestamp(ae.ts) END) AS transaction_created_time,
    MAX(CASE WHEN ae.event_type = 'transaction_completed' THEN try_to_timestamp(ae.ts) END) AS transaction_completed_time,
    MAX(CASE WHEN ae.event_type = 'fraud_check_completed' THEN try_to_timestamp(ae.ts) END) AS fraud_check_time
  FROM
    ${CATALOG}.lakeflow.all_events ae
  LEFT JOIN ${CATALOG}.simulator.branches br ON ae.branch_id = br.branch_id
  WHERE
    -- Only events from the trailing one-day window are considered.
    try_to_timestamp(ae.ts) >= CURRENT_TIMESTAMP() - INTERVAL 1 DAY
  GROUP BY
    ae.transaction_id,
    br.name
),
total_transaction_times AS (
  -- Durations per transaction. Both durations are measured from the transaction_created timestamp:
  -- to transaction_completed (in minutes) and to fraud_check_completed (in seconds).
  SELECT
    transaction_id,
    branch,
    (UNIX_TIMESTAMP(transaction_completed_time) - UNIX_TIMESTAMP(transaction_created_time)) / 60 AS total_transaction_time_minutes,
    (UNIX_TIMESTAMP(fraud_check_time) - UNIX_TIMESTAMP(transaction_created_time)) AS fraud_check_duration_seconds
  FROM
    transaction_times
  WHERE
    -- Keep only transactions that have both a created and a completed timestamp in the window.
    transaction_created_time IS NOT NULL
    AND transaction_completed_time IS NOT NULL
)
-- Aggregate the per-transaction durations by branch.
SELECT
  branch,
  PERCENTILE(total_transaction_time_minutes, 0.50) AS P50_minutes,
  PERCENTILE(total_transaction_time_minutes, 0.75) AS P75_minutes,
  PERCENTILE(total_transaction_time_minutes, 0.99) AS P99_minutes,
  AVG(fraud_check_duration_seconds) AS avg_fraud_check_seconds
FROM
  total_transaction_times
GROUP BY
  branch

In [None]:
%sql
-- Tool function: the timing-pattern row for a single branch, read from
-- transaction_patterns_per_branch_view.
CREATE OR REPLACE FUNCTION ${CATALOG}.ai.get_branch_patterns(
  br STRING COMMENT 'Branch name as a string'
)
RETURNS TABLE (
  branch STRING COMMENT 'Branch name',
  P50_minutes FLOAT COMMENT '50th percentile transaction time in minutes',
  P75_minutes FLOAT COMMENT '75th percentile transaction time in minutes',
  P99_minutes FLOAT COMMENT '99th percentile transaction time in minutes',
  avg_fraud_check_seconds FLOAT COMMENT 'Average fraud check duration in seconds'
)
COMMENT 'Returns transaction timing patterns and fraud check durations for a specific branch'
RETURN
  SELECT
    patterns.branch,
    patterns.P50_minutes,
    patterns.P75_minutes,
    patterns.P99_minutes,
    patterns.avg_fraud_check_seconds
  FROM ${CATALOG}.ai.transaction_patterns_per_branch_view AS patterns
  WHERE patterns.branch = br;

#### Model

In [0]:
%pip install -U -qqqq mlflow-skinny[databricks] langgraph==0.3.4 databricks-langchain databricks-agents uv
# Restart the Python process after the install (presumably so the packages installed above are the ones imported below).
dbutils.library.restartPython()

In [0]:
# Notebook parameters read from Databricks widgets: the catalog used in the table,
# function and model names below, and the LLM serving endpoint name substituted into agent.py.
CATALOG = dbutils.widgets.get("CATALOG")
LLM_MODEL = dbutils.widgets.get("LLM_MODEL")

In [0]:
import re
from IPython.core.magic import register_cell_magic

@register_cell_magic
def writefilev(line, cell):
    """
    %%writefilev file.py
    Allows {{var}} substitutions while leaving normal {} intact.

    Writes the cell body to the file named on the magic line. Every
    ``{{expr}}`` span is replaced by ``str()`` of ``expr`` evaluated as a
    Python expression against the globals of the module defining this magic;
    single braces are copied unchanged.

    Args:
        line: Text after ``%%writefilev``; stripped and used as the output path.
        cell: Cell body to substitute and write.

    Raises:
        ValueError: If no filename was given on the magic line.
        Exception: Whatever evaluating an ``{{expr}}`` raises (e.g. ``NameError``).
    """
    filename = line.strip()
    if not filename:
        # Fail before doing any evaluation instead of surfacing an opaque open("") error.
        raise ValueError("usage: %%writefilev <filename>")

    def replacer(match):
        # NOTE: eval() executes arbitrary expressions taken from the cell text;
        # only use this magic on cells you wrote yourself.
        # Evaluate against globals only, so this helper's own local (`match`)
        # cannot shadow a notebook variable of the same name.
        return str(eval(match.group(1), globals()))

    # Replace only double braces {{var}}
    content = re.sub(r"\{\{(.*?)\}\}", replacer, cell)

    # Explicit UTF-8 so the written file does not depend on the platform's default encoding.
    with open(filename, "w", encoding="utf-8") as f:
        f.write(content)
    print(f"Wrote file with substitutions: {filename}")

In [None]:
%%writefilev agent.py
from typing import Any, Generator, Optional, Sequence, Union

import mlflow
from databricks_langchain import (
    ChatDatabricks,
    VectorSearchRetrieverTool,
    DatabricksFunctionClient,
    UCFunctionToolkit,
    set_uc_function_client,
)
from langchain_core.language_models import LanguageModelLike
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.tools import BaseTool
from langgraph.graph import END, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt.tool_node import ToolNode
from mlflow.langchain.chat_agent_langgraph import ChatAgentState, ChatAgentToolNode
from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import (
    ChatAgentChunk,
    ChatAgentMessage,
    ChatAgentResponse,
    ChatContext,
)

# Turn on MLflow's LangChain autologging for this module.
mlflow.langchain.autolog()

# Create a Databricks function client and register it with set_uc_function_client
# (presumably the client the UC toolkit below uses to resolve functions — confirm).
client = DatabricksFunctionClient()
set_uc_function_client(client)

############################################
# Define your LLM endpoint and system prompt
############################################
# The double-brace LLM_MODEL placeholder below is replaced with the notebook's
# LLM_MODEL value by the %%writefilev cell magic when agent.py is written.
LLM_ENDPOINT_NAME = f"{{LLM_MODEL}}"
llm = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME)

# Instructions sent to the model as the system message: which tools to call, in
# what order, the fraud indicators to weigh, and the JSON shape of the final answer.
system_prompt = """You are FraudDetectorGPT, a banking security agent responsible for analyzing transactions for potential fraud.

    You can call tools to gather the information you need. Start with a `transaction_id`.

    Instructions:
    1. Call `get_transaction_details(transaction_id)` first to get event history and confirm the id is valid and the transaction was processed.
    2. Analyze the transaction processing time by calling `get_transaction_processing_time(transaction_id)`.
    3. Extract the branch information (either directly or from the event body).
    4. Call `get_branch_patterns(branch)` to get typical transaction patterns for that branch.
    5. Compare this transaction's characteristics to typical patterns to identify anomalies.

    Fraud indicators to look for:
    - Unusually fast processing times (bypassing normal checks)
    - Transactions outside normal branch patterns
    - Multiple rapid transactions
    - Unusual amounts or frequencies

    Output a single-line JSON with these fields:
    - `fraud_risk_score` (float 0.0-1.0, where 1.0 is highest risk),
    - `fraud_classification` ("none" | "low" | "medium" | "high"),
    - `reason` (short human explanation of why this transaction is flagged or cleared)

    You must return only the JSON. No extra text or markdown."""

###############################################################################
## Tools the agent can call to retrieve data or act beyond text generation.
## More tool types and usage examples:
## https://docs.databricks.com/generative-ai/agent-framework/agent-tool.html
###############################################################################
# Fully qualified Unity Catalog functions exposed to the agent, in registration order.
uc_tool_names = [
    f"{{CATALOG}}.ai.get_transaction_details",
    f"{{CATALOG}}.ai.get_branch_patterns",
    f"{{CATALOG}}.ai.get_transaction_processing_time",
]
uc_toolkit = UCFunctionToolkit(function_names=uc_tool_names)

# The agent's tool list is exactly the toolkit's tools.
tools = [*uc_toolkit.tools]

#####################
## Define agent logic
#####################


def create_tool_calling_agent(
    model: LanguageModelLike,
    tools: Union[Sequence[BaseTool], ToolNode],
    system_prompt: Optional[str] = None,
) -> CompiledGraph:
    """Build and compile a two-node tool-calling graph.

    The "agent" node invokes the model (bound to ``tools``); the "tools" node
    runs the requested tool calls. Control loops agent -> tools -> agent until
    the newest message carries no tool calls.

    Args:
        model: Chat model; ``bind_tools`` is called on it with ``tools``.
        tools: Tools bound to the model and wrapped in ``ChatAgentToolNode``.
        system_prompt: When truthy, prepended as a system message on every
            model call.

    Returns:
        The compiled ``StateGraph`` over ``ChatAgentState``.
    """
    tool_bound_model = model.bind_tools(tools)

    def should_continue(state: ChatAgentState):
        # Route to the tool node only when the newest message requests tool calls.
        newest = state["messages"][-1]
        return "continue" if newest.get("tool_calls") else "end"

    # Choose the message-building step once: with or without a leading system message.
    preprocessor = (
        RunnableLambda(
            lambda state: [{"role": "system", "content": system_prompt}]
            + state["messages"]
        )
        if system_prompt
        else RunnableLambda(lambda state: state["messages"])
    )
    model_runnable = preprocessor | tool_bound_model

    def call_model(
        state: ChatAgentState,
        config: RunnableConfig,
    ):
        # The model's reply is returned as a one-element message update.
        return {"messages": [model_runnable.invoke(state, config)]}

    graph = StateGraph(ChatAgentState)
    graph.add_node("agent", RunnableLambda(call_model))
    graph.add_node("tools", ChatAgentToolNode(tools))
    graph.set_entry_point("agent")

    # "continue" hands off to the tools node; "end" terminates the run.
    routes = {"continue": "tools", "end": END}
    graph.add_conditional_edges("agent", should_continue, routes)
    graph.add_edge("tools", "agent")

    return graph.compile()


class LangGraphChatAgent(ChatAgent):
    """MLflow ``ChatAgent`` wrapper around a compiled LangGraph agent.

    Both entry points stream the graph in "updates" mode and read the
    "messages" list from each node's update.
    """

    def __init__(self, agent: CompiledStateGraph):
        # The compiled graph that is streamed on every request.
        self.agent = agent

    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        """Run the graph to completion and return every message it produced.

        Args:
            messages: Conversation so far; converted to dicts before being
                handed to the graph.
            context: Accepted for interface compatibility; not used here.
            custom_inputs: Accepted for interface compatibility; not used here.

        Returns:
            A ``ChatAgentResponse`` holding the messages from all node updates,
            in the order the graph emitted them.
        """
        request = {"messages": self._convert_messages_to_dict(messages)}

        # Separate name so the `messages` parameter is not rebound to the output list.
        response_messages = []
        for event in self.agent.stream(request, stream_mode="updates"):
            for node_data in event.values():
                response_messages.extend(
                    ChatAgentMessage(**msg) for msg in node_data.get("messages", [])
                )
        return ChatAgentResponse(messages=response_messages)

    def predict_stream(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> Generator[ChatAgentChunk, None, None]:
        """Run the graph and yield one chunk per message as nodes finish.

        Args:
            messages: Conversation so far; converted to dicts before being
                handed to the graph.
            context: Accepted for interface compatibility; not used here.
            custom_inputs: Accepted for interface compatibility; not used here.

        Yields:
            A ``ChatAgentChunk`` whose ``delta`` is each message in a node update.
        """
        request = {"messages": self._convert_messages_to_dict(messages)}
        for event in self.agent.stream(request, stream_mode="updates"):
            for node_data in event.values():
                # Same lookup as predict(): an update without a "messages" key
                # yields nothing instead of raising KeyError mid-stream.
                yield from (
                    ChatAgentChunk(delta=msg)
                    for msg in node_data.get("messages", [])
                )


# Create the agent object, and specify it as the agent object to use when
# loading the agent back for inference via mlflow.models.set_model()
agent = create_tool_calling_agent(llm, tools, system_prompt)
# Wrap the compiled graph in the ChatAgent interface; AGENT is the name the notebook imports for evaluation.
AGENT = LangGraphChatAgent(agent)
mlflow.models.set_model(AGENT)

In [None]:
# get an actual transaction_id
# Collect first and check for emptiness: indexing [0] on an empty result would
# raise a bare IndexError that says nothing about the missing data.
sample_rows = spark.sql(f"""
    SELECT transaction_id 
    FROM {CATALOG}.lakeflow.all_events 
    WHERE event_type='transaction_completed'
    LIMIT 1
""").collect()
if not sample_rows:
    raise ValueError(
        f"No 'transaction_completed' events found in {CATALOG}.lakeflow.all_events; "
        "cannot pick a sample transaction_id."
    )
sample_transaction_id = sample_rows[0]['transaction_id']

In [None]:
# Sanity check: the sampled transaction_id must be non-null before it is used as the model's input example.
assert sample_transaction_id is not None
print(sample_transaction_id)

In [None]:
import mlflow
from agent import LLM_ENDPOINT_NAME, tools
from databricks_langchain import VectorSearchRetrieverTool
from mlflow.models.resources import DatabricksFunction, DatabricksServingEndpoint
from pkg_resources import get_distribution
from unitycatalog.ai.langchain.toolkit import UnityCatalogTool

# Installed versions are read through the standard library; pkg_resources
# (get_distribution) is deprecated by setuptools and is no longer used by this cell.
from importlib.metadata import version as installed_version

# Resources declared with the logged model: the LLM serving endpoint plus one
# entry per Unity Catalog function tool.
resources = [DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT_NAME)]
for tool in tools:
    resources.append(DatabricksFunction(function_name=tool.uc_function_name))

# A single user message carrying the sampled transaction id.
input_example = {
    "messages": [
        {
            "role": "user",
            "content": f"{sample_transaction_id}"
        }
    ]
}

# Pin each dependency to the version installed in this environment.
pinned_packages = ("databricks-connect", "mlflow", "databricks-langchain", "langgraph")
pip_requirements = [f"{pkg}=={installed_version(pkg)}" for pkg in pinned_packages]

with mlflow.start_run():
    # Models-from-code logging: the model is defined by the agent.py file written above.
    logged_agent_info = mlflow.pyfunc.log_model(
        name="fraud_agent_v2",
        python_model="agent.py",
        input_example=input_example,
        resources=resources,
        pip_requirements=pip_requirements,
    )

mlflow.set_active_model(model_id = logged_agent_info.model_id)

#### Evaluation

In [None]:
# Pull up to 10 completed transactions to use as evaluation inputs.
completed_rows = spark.sql(f"""
        SELECT transaction_id 
        FROM {CATALOG}.lakeflow.all_events 
        WHERE event_type='transaction_completed'
        LIMIT 10
    """).collect()
fraud_queries = [row['transaction_id'] for row in completed_rows]

# One evaluation record per transaction id, shaped as a single user chat message.
data = [
    {"inputs": {"messages": [{"role": "user", "content": query}]}}
    for query in fraud_queries
]

print(data)

In [None]:
# create guideline, run evals

from mlflow.genai.scorers import Guidelines
import mlflow
import sys
import os
# Add the current working directory to the import path so `agent` can be imported below.
sys.path.append(os.getcwd())

# Also derive the directory portion of this notebook's path from the notebook context.
notebook_path = dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()
project_directory = os.path.dirname(notebook_path)

# Add the project directory to the system path
sys.path.append(project_directory)

# Must come after the sys.path changes above.
from agent import AGENT

# Scorer holding one natural-language guideline about what a fraud assessment may be based on.
fraud_detection_guideline = Guidelines(
    name="fraud_detection_reason",
    guidelines=["Fraud risk assessments must be based on transaction patterns and timing anomalies, not on customer identity or demographics."]
)


# Run the agent over the records in `data` and score the outputs with the guideline scorer.
results = mlflow.genai.evaluate(
    data=data,
    scorers=[fraud_detection_guideline],
    # presumably invoked with each record's "inputs" as keyword arguments, hence the `messages` parameter — confirm
    predict_fn = lambda messages: AGENT.predict({"messages": messages})
)

#### Register the fraud detector in Unity Catalog (`UC`)

In [None]:
# Point the MLflow model registry at "databricks-uc".
mlflow.set_registry_uri("databricks-uc")

# Name under which the logged agent is registered: the `fraud_detector` model in the catalog's `ai` schema.
UC_MODEL_NAME = f"{CATALOG}.ai.fraud_detector"

# register the model to UC
uc_registered_model_info = mlflow.register_model(
    model_uri=logged_agent_info.model_uri, name=UC_MODEL_NAME
)

#### deploy the agent to model serving

In [None]:
from databricks import agents

# Deploy the registered model version to the serving endpoint named by the
# FRAUD_AGENT_ENDPOINT_NAME widget.
# The widget value is passed directly: reusing double quotes inside a
# double-quoted f-string is a SyntaxError before Python 3.12, and the f-string
# wrapper added nothing to a value that is already a string.
deployment_info = agents.deploy(
    model_name=UC_MODEL_NAME,
    model_version=uc_registered_model_info.version,
    scale_to_zero=False,
    endpoint_name=dbutils.widgets.get("FRAUD_AGENT_ENDPOINT_NAME"),
)

In [0]:
# Show the deployment details returned by agents.deploy().
print(deployment_info)

##### record model in state

In [None]:
# Also add to UC-state
import sys
# Make the sibling utils directory importable; the relative path presumably resolves
# against the notebook's working directory — confirm.
sys.path.append('../utils')
from uc_state import add

# Record the deployment info under the "endpoints" key for this catalog
# (storage semantics are defined by uc_state.add, which is not shown here).
add(dbutils.widgets.get("CATALOG"), "endpoints", deployment_info)