# Create Agent Tools as UC Functions

In [None]:
from databricks.connect import DatabricksSession

# Open (or reuse) a Databricks Connect session that runs on serverless compute.
spark = (
    DatabricksSession.builder
    .remote(serverless=True)
    .getOrCreate()
)

In [None]:
import os
import mlflow

# Load the project configuration; every constant below is read from config.yml.
configs = mlflow.models.ModelConfig(development_config="./config.yml")
databricks_config = configs.get("databricks")
tools_config = configs.get("tools")

# Unity Catalog location in which the tool functions are created.
CATALOG, SCHEMA = (databricks_config[key] for key in ("catalog", "schema"))

# SQL warehouse id and workspace host referenced by the SQL statements below.
SQL_WAREHOUSE_ID = databricks_config["sql_warehouse_id"]
WORKSPACE_URL = databricks_config["workspace_url"]

# Secret scope / key that hold the personal access token used as bearer token.
SECRET_SCOPE_NAME, SECRET_KEY_NAME = (
    databricks_config["databricks_pat"][field]
    for field in ("secret_scope_name", "secret_key_name")
)

# Tables the tools are allowed to touch, and the names of the two HTTP connections.
UC_TABLES = tools_config["tables"]
UC_CONNECTION = tools_config["uc_connection"]["name"]
UC_CONNECTION_SQL_EXEC = UC_CONNECTION + "_sql_exec"

# Allowlist rendered as a comma-separated list of SQL string literals,
# ready to be spliced into an `IN (...)` clause.
table_list_sql = ", ".join(f"'{table}'" for table in UC_TABLES)

## Set Up UC Connections

In [None]:
# Register the Unity Catalog HTTP connection that targets the workspace's
# `/api/2.1/` REST endpoints (referenced by get_table_metadata and
# get_table_schema below).
# The bearer token is resolved through secret(<scope>, <key>) instead of being
# written into the statement, and IF NOT EXISTS leaves an already-registered
# connection untouched when this cell is re-run.
spark.sql(
    f"""
    CREATE CONNECTION IF NOT EXISTS {UC_CONNECTION} 
    TYPE HTTP
    OPTIONS (
        host '{WORKSPACE_URL}',
        port '443',
        base_path '/api/2.1/',
        bearer_token secret('{SECRET_SCOPE_NAME}', '{SECRET_KEY_NAME}'
        )
    )
    """
)

In [None]:
# Register a second HTTP connection, identical to the first except that its
# base path is `/api/2.0/`; execute_query below posts to `sql/statements/`
# through this connection.
# IF NOT EXISTS leaves an already-registered connection untouched on re-run.
spark.sql(
    f"""
    CREATE CONNECTION IF NOT EXISTS {UC_CONNECTION_SQL_EXEC} 
    TYPE HTTP
    OPTIONS (
        host '{WORKSPACE_URL}',
        port '443',
        base_path '/api/2.0/',
        bearer_token secret('{SECRET_SCOPE_NAME}', '{SECRET_KEY_NAME}'
        )
    )
    """
)

## Core Functions

### Get Table Metadata
* Purpose: Returns comprehensive table-level metadata including owner, storage location, timestamps, and table type via Unity Catalog REST API
* Why Essential: Provides high-level table information that helps the agent understand what tables are available, who owns them, and when they were last updated - critical for discovery and context before schema introspection

In [None]:
# Create (or replace) the scalar SQL UDF `get_table_metadata(table_name)`.
#
# Behaviour:
#   * Sends GET unity-catalog/tables/<table_name> through the UC HTTP connection
#     and parses the JSON body into the STRUCT declared in RETURNS.
#   * Returns NULL when the table is not in the configured allowlist or when the
#     REST call does not answer with HTTP 200 (the scalar subquery is empty).
#
# Why the allowlist guard: `table_name` is supplied by the agent and is appended
# verbatim to the REST path that is called with the stored access token. The
# other tools (get_table_schema, get_sample_data) already restrict themselves to
# the configured tables; gating the request with CASE keeps this function
# consistent with them and avoids issuing requests for arbitrary paths.
spark.sql(
    f"""
    CREATE OR REPLACE FUNCTION {CATALOG}.{SCHEMA}.get_table_metadata(
        table_name STRING COMMENT 'Fully qualified table name (catalog.schema.table)'
    )
    RETURNS STRUCT<
        name: STRING,
        catalog_name: STRING,
        schema_name: STRING,
        table_type: STRING,
        data_source_format: STRING,
        comment: STRING,
        owner: STRING,
        created_at: BIGINT,
        updated_at: BIGINT,
        storage_location: STRING,
        full_name: STRING
    >
    COMMENT 'Returns enhanced table metadata via Databricks REST API including owner, storage location, and timestamps'
    LANGUAGE SQL
    RETURN (
        WITH api_response AS (
            SELECT
                CASE
                    -- Only call the REST API for allowlisted tables; otherwise the
                    -- response stays NULL and the outer filter yields no row.
                    WHEN table_name IN ({table_list_sql}) THEN
                        http_request(
                            conn => '{UC_CONNECTION}',
                            method => 'GET',
                            path => concat('unity-catalog/tables/', table_name),
                            headers => map('Accept', 'application/json')
                        )
                END as response
        )
        SELECT
            from_json(
                response.text,
                'name STRING, catalog_name STRING, schema_name STRING, table_type STRING, data_source_format STRING, comment STRING, owner STRING, created_at BIGINT, updated_at BIGINT, storage_location STRING, full_name STRING'
            ) as metadata
        FROM api_response
        WHERE response.status_code = 200
    );
    """
)

In [None]:
# Smoke test: call get_table_metadata on the first table listed in the config.
spark.sql(f"SELECT {CATALOG}.{SCHEMA}.get_table_metadata('{UC_TABLES[0]}')")

### Get Table Schema
* Purpose: Returns detailed column-level metadata including names, data types, descriptions, nullability, and position from Unity Catalog via REST API
* Why Essential: Core function that provides the schema blueprint the LLM needs to generate accurate SQL - includes column descriptions that are crucial for semantic understanding of what each field represents

In [None]:
# Create (or replace) the SQL table function `get_table_schema(table_name)`.
#
# The function sends GET unity-catalog/tables/<table_name> through the UC HTTP
# connection, parses the `columns` array out of the JSON body, and explodes it
# into one row per column with the fields column_name, data_type,
# column_description, is_nullable and ordinal_position, ordered by position.
#
# The columns array is NULL, so no rows are produced, when the table is not in
# the configured allowlist or when the REST call does not return HTTP 200.
#
# NOTE(review): the allowlist test lives in the CASE that post-processes the
# response, while http_request sits unfiltered in the innermost subquery, so the
# request itself is presumably still sent for non-allowlisted names. Confirm
# whether that is acceptable.
spark.sql(
    f"""
    CREATE OR REPLACE FUNCTION {CATALOG}.{SCHEMA}.get_table_schema(
        table_name STRING COMMENT 'Fully qualified table name (catalog.schema.table)'
    )
    RETURNS TABLE
    COMMENT 'Returns detailed schema information via REST API including column names, types, and descriptions'
    LANGUAGE SQL
    RETURN 
        SELECT 
            col.name as column_name,
            col.type_text as data_type,
            col.comment as column_description,
            col.nullable as is_nullable,
            col.position as ordinal_position
        FROM (
            SELECT 
                CASE 
                    WHEN table_name NOT IN ({table_list_sql}) THEN
                        NULL
                    WHEN response.status_code != 200 THEN
                        NULL
                    ELSE
                        from_json(
                            response.text,
                            'columns ARRAY<STRUCT<name:STRING, type_text:STRING, type_name:STRING, comment:STRING, nullable:BOOLEAN, position:INT>>'
                        ).columns
                END as columns_array
            FROM (
                SELECT http_request(
                    conn => '{UC_CONNECTION}',
                    method => 'GET',
                    path => CONCAT('unity-catalog/tables/', table_name),
                    headers => map('Accept', 'application/json')
                ) as response
            )
        )
        LATERAL VIEW explode(columns_array) exploded_table AS col
        WHERE col.name IS NOT NULL
        ORDER BY col.position
    """
)

In [None]:
# Smoke test: fetch the schema of the first configured table ...
result = spark.sql(
    f"SELECT * FROM {CATALOG}.{SCHEMA}.get_table_schema('{UC_TABLES[0]}')"
)

# ... and show at most five of its column rows.
display(result.limit(5))

### Validate Query Function
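* Purpose: Screens a query with lightweight checks (write-keyword blocklist, SELECT/WITH prefix, injection patterns, balanced parentheses) to enforce read-only operations; it does not fully parse the SQL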
* Why Essential: Prevents write operations and SQL injection attempts

In [None]:
# Create (or replace) the Python UDF `validate_query(query)`.
#
# Returns STRUCT<is_valid, error_message>:
#   * (false, <reason>) for empty input, for any write/DDL keyword, for
#     statements that do not start with SELECT or WITH, for the listed injection
#     patterns, and for unbalanced parentheses;
#   * (true, NULL) otherwise.
#
# The checks are purely textual: a keyword inside a string literal or comment
# (e.g. WHERE action = 'DELETE') is rejected as well.
#
# NOTE: the SQL text below is a regular (non-raw) Python f-string, so the regex
# word-boundary token has to be written as `\\b` here in order to reach the UDF
# body as `\b`. With a single backslash Python turns it into a backspace
# character and the keyword blocklist never matches anything.
spark.sql(
    f"""
    CREATE OR REPLACE FUNCTION {CATALOG}.{SCHEMA}.validate_query(query STRING)
    RETURNS STRUCT<is_valid: BOOLEAN, error_message: STRING>
    LANGUAGE PYTHON
    AS $$
    import re

    def validate_query(query):
        # Reject NULL, empty and whitespace-only input up front.
        if not query or not query.strip():
            return (False, "Query is empty")

        query_upper = query.upper()

        # Block write operations (whole-word match, case-insensitive via upper()).
        write_keywords = [
            'INSERT', 'UPDATE', 'DELETE', 'DROP', 'CREATE', 'ALTER',
            'TRUNCATE', 'MERGE', 'REPLACE', 'GRANT', 'REVOKE'
        ]

        for kw in write_keywords:
            if re.search(r'\\b' + kw + r'\\b', query_upper):
                return (False, "Operation not allowed: " + kw)

        # Ensure it's a SELECT or CTE (WITH)
        stripped = query_upper.strip()
        if not (stripped.startswith('SELECT') or stripped.startswith('WITH')):
            return (False, "Only SELECT queries and CTEs are allowed")

        # Check for SQL injection patterns (stacked statements, comment tricks)
        injection_patterns = [
            r';.*?DROP',
            r';.*?DELETE',
            r';.*?INSERT',
            r'--.*?;',
        ]

        for pattern in injection_patterns:
            if re.search(pattern, query_upper, re.DOTALL):
                return (False, "Potential SQL injection detected")

        # Validate balanced parentheses (count only, nesting order is not checked)
        if query.count('(') != query.count(')'):
            return (False, "Unbalanced parentheses")

        return (True, None)

    return validate_query(query)
    $$
    """
)

In [None]:
# Smoke test: a plain SELECT should come back as valid.
spark.sql(f"SELECT {CATALOG}.{SCHEMA}.validate_query('select * from balance_sheet')")

In [None]:
# Smoke test: a DDL statement should be rejected (is_valid = false).
spark.sql(f"SELECT {CATALOG}.{SCHEMA}.validate_query('drop function some_function')")

### Execute Query Function (SQL)
* Purpose: Executes SQL query with safety validation and row limits
* Why Essential: Safe query execution with automatic limits and error handling

In [None]:
# Create (or replace) the scalar SQL UDF `execute_query(sql_query, row_limit)`.
#
# Behaviour:
#   1. Runs validate_query on the statement; on failure returns
#      {"error": "Validation failed: <reason>"}.
#   2. Otherwise POSTs the statement to the Statement Execution API
#      (`sql/statements/`) through the `_sql_exec` HTTP connection and returns
#      the raw response body for 2xx answers, or an error JSON that embeds the
#      status code and response body for anything else.
#
# Request body: built with to_json(named_struct(...)) instead of manual string
# concatenation, so quotes, backslashes, newlines and other control characters
# in the statement are JSON-escaped correctly. The previous hand-rolled escaping
# did not escape double quotes, and flattening newlines let a `--` comment
# swallow the rest of the query.
#
# Row cap: passed as the API's `row_limit` field, clamped to 1..10000 with a
# default of 1000. This replaces the `LIKE '%LIMIT%'` heuristic, which skipped
# the cap whenever the text "limit" appeared anywhere in the query (for example
# in a column name). A LIMIT written in the query itself still applies; the API
# cap is enforced on top of it.
spark.sql(
    f"""
    CREATE OR REPLACE FUNCTION {CATALOG}.{SCHEMA}.execute_query(
        sql_query STRING COMMENT 'SQL query to execute (SELECT only)',
        row_limit INT DEFAULT 1000 COMMENT 'Maximum number of rows to return (default 1000, max 10000)'
    )
    RETURNS STRING
    COMMENT 'Executes validated SQL query via Statement Execution API with safety checks and row limits. Returns JSON string with results or error message.'
    LANGUAGE SQL
    RETURN (
        SELECT
            CASE
                -- Step 1: Validate the query first
                WHEN {CATALOG}.{SCHEMA}.validate_query(sql_query).is_valid = false THEN
                    CONCAT('{{"error": "Validation failed: ', {CATALOG}.{SCHEMA}.validate_query(sql_query).error_message, '"}}')

                -- Step 2: If valid, execute via API
                ELSE
                    (
                        SELECT
                            CASE
                                WHEN response.status_code BETWEEN 200 AND 299 THEN
                                    response.text
                                ELSE
                                    CONCAT(
                                        '{{"error": "Query execution failed", "status_code": ',
                                        CAST(response.status_code AS STRING),
                                        ', "details": ',
                                        COALESCE(response.text, '""'),
                                        '}}'
                                    )
                            END
                        FROM (
                            SELECT http_request(
                                conn => '{UC_CONNECTION_SQL_EXEC}',
                                method => 'POST',
                                path => 'sql/statements/',
                                headers => map('Accept', 'application/json', 'Content-Type', 'application/json'),
                                -- to_json performs the JSON string escaping of the statement
                                json => to_json(
                                    named_struct(
                                        'warehouse_id', '{SQL_WAREHOUSE_ID}',
                                        'statement', sql_query,
                                        -- Server-side cap, clamped to 1..10000 (default 1000)
                                        'row_limit', LEAST(GREATEST(COALESCE(row_limit, 1000), 1), 10000),
                                        'wait_timeout', '30s',
                                        'on_wait_timeout', 'CANCEL',
                                        'format', 'JSON_ARRAY'
                                    )
                                )
                            ) as response
                        )
                    )
            END
    )
"""
)

In [None]:
# Smoke test: run a one-row SELECT against the first configured table.
spark.sql(
    f"SELECT {CATALOG}.{SCHEMA}.execute_query('select * from {UC_TABLES[0]} limit 1')"
)

### Get Sample Data Function
* Purpose: Returns sample rows from a table to help understand data patterns
* Why Essential: LLMs need concrete examples to understand data formats, value patterns, and data types

In [None]:
# Create (or replace) the scalar SQL UDF `get_sample_data(table_name, num_rows)`.
#
# The CASE branches are evaluated in order:
#   1. table not in the configured allowlist -> {"error": "Table not in allowlist: <name>"}
#   2. num_rows outside 1..10                -> {"error": "num_rows must be between 1 and 10"}
#   3. otherwise delegate to execute_query with `SELECT * FROM <table_name>`
#      and the row count as its row_limit argument.
#
# The table name is concatenated into the query text only after it has passed
# the allowlist comparison in branch 1.
# NOTE(review): the LEAST/GREATEST clamp in branch 3 looks redundant after the
# bounds check in branch 2; it presumably only matters for a NULL num_rows,
# which COALESCE maps to 3. Confirm.
spark.sql(
    f"""
    CREATE OR REPLACE FUNCTION {CATALOG}.{SCHEMA}.get_sample_data(
        table_name STRING COMMENT 'Fully qualified table name (catalog.schema.table)',
        num_rows INT DEFAULT 3 COMMENT 'Number of sample rows to return (default 3, max 10)'
    )
    RETURNS STRING
    COMMENT 'Returns sample rows from a table to help understand data patterns and formats. Validates table allowlist and enforces row limits.'
    LANGUAGE SQL
    RETURN (
        SELECT 
            CASE 
                -- Validate table is in allowlist
                WHEN table_name NOT IN ({table_list_sql}) THEN
                    CONCAT('{{"error": "Table not in allowlist: ', table_name, '"}}')
                
                -- Validate num_rows is within bounds
                WHEN num_rows < 1 OR num_rows > 10 THEN
                    '{{"error": "num_rows must be between 1 and 10"}}'
                
                -- Execute query via execute_query function
                ELSE
                    {CATALOG}.{SCHEMA}.execute_query(
                        CONCAT('SELECT * FROM ', table_name),
                        LEAST(GREATEST(COALESCE(num_rows, 3), 1), 10)
                    )
            END
    )
    """
)

In [None]:
# Smoke test: sample the first configured table using the default num_rows.
spark.sql(f"SELECT {CATALOG}.{SCHEMA}.get_sample_data('{UC_TABLES[0]}')")

## Future TODOs

### Get Table Relationships Function
* Purpose: Returns foreign key relationships for a table
* Why Essential: Required for multi-table queries - agent learns correct JOIN syntax

### Other Advanced Functions

* Business Glossary Supporting Table
    * Purpose: Maps business terminology to technical column names
    * Why Advanced: Critical for enterprise deployments where business users don't know technical schema
* Get Business Terms Function
    * Purpose: Returns business terminology mappings for domain-specific queries
    * Why Advanced: Enables non-technical users to query using familiar business language
* Query History Table
    * Purpose: Stores successful query patterns for few-shot learning
    * Why Advanced: Enables continuous improvement from successful patterns
    * WE WILL DO THIS VIA **INFERENCE TABLE**
* Get Similar Queries Function
    * Purpose: Returns similar successful queries for few-shot learning
    * Why Advanced: Improves accuracy by learning from past successful patterns
    * WE WILL DO THIS VIA **VECTOR SEARCH INDEX**
* Format Query Results Function (Python)
    * Purpose: Executes query and formats results in requested format
    * Why Advanced: Flexible output formatting for different consumption patterns
* Log Query Execution Function
    * Purpose: Logs query execution for audit and learning
    * Why Advanced: Enables continuous improvement and compliance
    * WE WILL DO THIS VIA **INFERENCE TABLE**