# Create Agent Tools as UC Functions

In [None]:
from databricks.connect import DatabricksSession

# Serverless Databricks Connect session; every cell below uses it to run DDL.
spark = (
    DatabricksSession.builder
    .remote(serverless=True)
    .getOrCreate()
)

In [None]:
import os
import mlflow

# TODO: update ./config.yml (databricks.catalog, databricks.schema, tools.tables)
# before running this cell.
configs = mlflow.models.ModelConfig(development_config="./config.yml")
databricks_config, tools_config = configs.get("databricks"), configs.get("tools")

# Target location for the UC functions, and the tables the agent may query.
CATALOG = databricks_config["catalog"]
SCHEMA = databricks_config["schema"]
UC_TABLES = tools_config["tables"]

# Single-quote each table name so the list can be dropped into a SQL IN (...) clause.
table_list_sql = ", ".join(f"'{name}'" for name in UC_TABLES)

In [None]:
# load experiment

## Core Functions

### 1. List Available Tables Function

* Purpose: Returns all tables the agent is allowed to query with descriptions
* Why Essential: Enables dynamic table discovery without hardcoding table lists in prompts

In [None]:
# NOTE(review): the UDF body below calls SparkSession.getActiveSession(); confirm a
# Spark session is available inside Unity Catalog Python functions and that
# `RETURNS TABLE` without a column list is accepted for LANGUAGE PYTHON.
spark.sql(
    f"""
    CREATE OR REPLACE FUNCTION {CATALOG}.{SCHEMA}.list_available_tables()
    RETURNS TABLE
    COMMENT 'Returns all tables the agent is allowed to query with descriptions'
    LANGUAGE PYTHON
    AS $$
        from pyspark.sql import SparkSession
        from pyspark.sql.types import StructType, StructField, StringType
        
        spark = SparkSession.getActiveSession()
        
        # List of allowed tables from config (baked in when the function is created)
        allowed_tables = {UC_TABLES}
        
        results = []
        for table in allowed_tables:
            try:
                # A single DESCRIBE EXTENDED per table supplies both the table
                # comment and the table type. DESCRIBE DETAIL is not used here
                # because it rejects views, which would drop them from the list.
                rows = spark.sql(f"DESCRIBE EXTENDED {{table}}").collect()
            except Exception:
                # Skip tables that don't exist or can't be accessed
                continue
            
            # Column rows come first; table-level metadata ('Type', 'Comment', ...)
            # follows the '# Detailed Table Information' marker. Reading only that
            # section keeps a column whose name contains 'Comment' from being
            # mistaken for the table comment.
            details = {{}}
            in_details = False
            for row in rows:
                key = (row['col_name'] or '').strip()
                if key == '# Detailed Table Information':
                    in_details = True
                elif in_details and key:
                    details.setdefault(key, row['data_type'])
            
            results.append({{
                'full_table_name': table,
                'table_name': table.split('.')[-1],
                'description': details.get('Comment'),
                'table_type': details.get('Type', 'TABLE')
            }})
        
        # Return as DataFrame
        schema = StructType([
            StructField('full_table_name', StringType(), True),
            StructField('table_name', StringType(), True),
            StructField('description', StringType(), True),
            StructField('table_type', StringType(), True)
        ])
        
        return spark.createDataFrame(results, schema)
    $$
    ;
    """
)

### 2. Get Table Schema Function
* Purpose: Returns detailed schema information for a specific table
* Why Essential: Provides column names, types, and descriptions for SQL generation

In [None]:
# NOTE(review): the UDF body below calls SparkSession.getActiveSession(); confirm a
# Spark session is available inside Unity Catalog Python functions.
spark.sql(
    f"""
    CREATE OR REPLACE FUNCTION {CATALOG}.{SCHEMA}.get_table_schema(
        table_name STRING COMMENT 'Fully qualified table name (catalog.schema.table)'
    )
    RETURNS TABLE
    COMMENT 'Returns detailed schema information for a table including column names, types, and descriptions'
    LANGUAGE PYTHON
    AS $$
        from pyspark.sql import SparkSession
        from pyspark.sql.types import StructType, StructField, StringType, IntegerType
        
        spark = SparkSession.getActiveSession()
        
        # Security check: only allow querying approved tables. Unity Catalog names
        # are case-insensitive, so compare trimmed, lower-cased names; NULL or
        # blank input is denied.
        allowed_tables = [t.lower() for t in {UC_TABLES}]
        requested = (table_name or '').strip()
        if requested.lower() not in allowed_tables:
            raise ValueError(f"Access denied: table {{table_name}} is not in the allowed list")
        
        try:
            desc_df = spark.sql(f"DESCRIBE EXTENDED {{requested}}")
            
            results = []
            for row in desc_df.collect():
                col_name = row['col_name']
                
                # Column rows come first. The first blank or '#'-prefixed row
                # (e.g. '# Partition Information', '# Detailed Table Information')
                # marks the end of the column list, so stop there.
                if not col_name or col_name.startswith('#'):
                    break
                
                data_type = row['data_type']
                comment = row['comment']
                
                # NOTE(review): DESCRIBE output may never include 'NOT NULL' in
                # data_type, in which case every column reports 'YES' - confirm.
                is_nullable = 'NOT NULL' not in str(data_type).upper()
                
                results.append({{
                    'column_name': col_name,
                    'data_type': data_type,
                    'comment': comment if comment else None,
                    'is_nullable': 'YES' if is_nullable else 'NO',
                    'ordinal_position': len(results) + 1
                }})
            
            # Return as DataFrame
            schema = StructType([
                StructField('column_name', StringType(), True),
                StructField('data_type', StringType(), True),
                StructField('comment', StringType(), True),
                StructField('is_nullable', StringType(), True),
                StructField('ordinal_position', IntegerType(), True)
            ])
            
            return spark.createDataFrame(results, schema)
            
        except Exception as e:
            raise ValueError(f"Failed to get schema for {{table_name}}: {{str(e)}}") from e
    $$
    ;
    """
)

### 3. Get Sample Data Function
* Purpose: Returns sample rows from a table to help understand data patterns
* Why Essential: LLMs need concrete examples to understand data formats, value patterns, and data types

In [None]:
# SQL table function returning up to 10 sample rows of an allowed table.
# - The allow-list (UC_TABLES, compared case-insensitively) matches the security
#   check in get_table_schema; other tables raise an 'Access denied' error when
#   rows are read.
# - The parameter is qualified as get_sample_data.table_name so a column named
#   table_name in the sampled table cannot shadow it.
# - num_rows is clamped to [0, 10] because a negative LIMIT is an error.
# NOTE(review): UC_TABLES must be non-empty, otherwise the IN () list is invalid SQL.
spark.sql(
    f"""
    CREATE OR REPLACE FUNCTION {CATALOG}.{SCHEMA}.get_sample_data(
        table_name STRING COMMENT 'Fully qualified table name',
        num_rows INT DEFAULT 3 COMMENT 'Number of sample rows (default: 3, max: 10)'
    )
    RETURNS TABLE
    COMMENT 'Returns sample rows from a table to help understand data patterns'
    LANGUAGE SQL
    RETURN
        SELECT *
        FROM IDENTIFIER(table_name)
        WHERE CASE
            WHEN lower(get_sample_data.table_name) IN ({table_list_sql.lower()}) THEN TRUE
            ELSE raise_error(concat('Access denied: table ', get_sample_data.table_name, ' is not in the allowed list'))
        END
        LIMIT GREATEST(LEAST(num_rows, 10), 0)
    ;
    """
)

### 4. Get Table Relationships Function
* Purpose: Returns foreign key relationships for a table (the current implementation is a placeholder that always returns an empty result)
* Why Essential: Required for multi-table queries — the agent learns which columns to JOIN on

In [None]:
# NOTE(review): the UDF body below calls SparkSession.getActiveSession(); confirm a
# Spark session is available inside Unity Catalog Python functions.
spark.sql(
    f"""
    CREATE OR REPLACE FUNCTION {CATALOG}.{SCHEMA}.get_table_relationships(
        table_name STRING COMMENT 'Fully qualified table name'
    )
    RETURNS TABLE
    COMMENT 'Returns foreign key relationships for a table to enable JOIN generation'
    LANGUAGE PYTHON
    AS $$
        from pyspark.sql import SparkSession
        from pyspark.sql.types import StructType, StructField, StringType
        
        spark = SparkSession.getActiveSession()
        
        schema = StructType([
            StructField('from_table', StringType(), True),
            StructField('foreign_key_column', StringType(), True),
            StructField('to_table', StringType(), True),
            StructField('referenced_column', StringType(), True),
            StructField('constraint_name', StringType(), True)
        ])
        
        # Placeholder: foreign-key discovery is not implemented yet, so this
        # returns an empty result for every table_name.
        # TODO: read FK constraints from the catalog's information_schema
        # (e.g. referential_constraints joined with key_column_usage) and map
        # them onto the columns above.
        return spark.createDataFrame([], schema)
    $$
    ;
    """
)

### 5. Validate Query Function (Python)
* Purpose: Validates SQL syntax and ensures read-only operations
* Why Essential: Prevents write operations and SQL injection attempts

In [None]:
# Python scalar UDF returning {is_valid, error_message}. It only inspects the text
# of the query and never executes it.
spark.sql(
    f"""
    CREATE OR REPLACE FUNCTION {CATALOG}.{SCHEMA}.validate_query(
        query STRING COMMENT 'SQL query to validate'
    )
    RETURNS STRUCT<is_valid: BOOLEAN, error_message: STRING>
    COMMENT 'Validates SQL syntax and ensures read-only operations (no INSERT/UPDATE/DELETE)'
    LANGUAGE PYTHON
    AS $$
        import re

        # NULL or blank input is never a valid query (NULL would also crash .upper()).
        if query is None or not query.strip():
            return {{
                'is_valid': False,
                'error_message': 'Query is empty.'
            }}

        # Normalize query for checking
        query_upper = query.upper()
        query_stripped = query.strip()

        # Block write operations. Whole-word matching avoids false positives on
        # identifiers such as update_date. The check is deliberately strict: it
        # also fires inside string literals and comments, and REPLACE also
        # rejects the replace() string function.
        write_keywords = [
            'INSERT', 'UPDATE', 'DELETE', 'DROP', 'CREATE',
            'ALTER', 'TRUNCATE', 'MERGE', 'GRANT', 'REVOKE',
            'REPLACE', 'RENAME', 'COPY', 'LOAD', 'UNLOAD'
        ]

        for keyword in write_keywords:
            if re.search(rf'\\b{{keyword}}\\b', query_upper):
                return {{
                    'is_valid': False,
                    'error_message': f'Operation {{keyword}} is not allowed. Only SELECT queries are permitted.'
                }}

        # Ensure query starts with SELECT or WITH (for CTEs)
        if not re.match(r'^(SELECT|WITH)\\b', query_stripped, re.IGNORECASE):
            return {{
                'is_valid': False,
                'error_message': 'Only SELECT queries (and WITH/CTEs) are allowed.'
            }}

        # Reject stacked statements ("SELECT ...; <anything>"); trailing
        # semicolons are allowed. DROP/DELETE/UPDATE anywhere in the text were
        # already rejected above as whole words, so no separate pattern list is
        # needed for them. NOTE: a ';' inside a string literal is also rejected.
        if ';' in re.sub(r'[\\s;]+$', '', query_stripped):
            return {{
                'is_valid': False,
                'error_message': 'Multiple SQL statements are not allowed.'
            }}

        # Basic syntax validation (check for balanced parentheses)
        if query.count('(') != query.count(')'):
            return {{
                'is_valid': False,
                'error_message': 'Syntax error: Unbalanced parentheses.'
            }}

        return {{'is_valid': True, 'error_message': None}}
    $$
    ;
    """
)

### 6. Execute Query Function (Python)
* Purpose: Executes SQL query with safety validation and row limits
* Why Essential: Safe query execution with automatic limits and error handling

In [None]:
# NOTE(review): the UDF body below calls SparkSession.getActiveSession() and
# spark.sql(..., args=...) (named parameter markers, PySpark 3.5+); confirm both
# are available inside Unity Catalog Python functions.
spark.sql(
    f"""
    CREATE OR REPLACE FUNCTION {CATALOG}.{SCHEMA}.execute_query(
        sql_query STRING COMMENT 'SQL query to execute',
        row_limit INT DEFAULT 1000 COMMENT 'Maximum rows to return (default: 1000)'
    )
    RETURNS TABLE
    COMMENT 'Executes a read-only SQL query after safety validation, capping results at row_limit (max 10,000)'
    LANGUAGE PYTHON
    AS $$
        import re
        from pyspark.sql import SparkSession

        spark = SparkSession.getActiveSession()

        if sql_query is None or not sql_query.strip():
            raise ValueError("Query validation failed: query is empty")

        # Drop trailing semicolons/whitespace: harmless on their own, but they
        # would break the subquery wrapper built below.
        query = re.sub(r'[\\s;]+$', '', sql_query)

        # Validate query is read-only using our validation function. The query
        # is bound as a named parameter rather than spliced into a string
        # literal, so quotes or backslashes in it cannot alter this statement.
        validation = spark.sql(
            "SELECT {CATALOG}.{SCHEMA}.validate_query(:query) AS result",
            args={{"query": query}},
        ).collect()[0]['result']

        if not validation['is_valid']:
            raise ValueError(f"Query validation failed: {{validation['error_message']}}")

        # Row cap: NULL falls back to the declared default (1000) and the value
        # is clamped to [0, 10000] (a negative LIMIT is invalid).
        safe_limit = max(0, min(1000 if row_limit is None else row_limit, 10000))

        # Run the query unchanged only when it already ends in a LIMIT no larger
        # than the cap, which keeps ORDER BY ... LIMIT semantics intact. Otherwise
        # wrap it, because a LIMIT inside a subquery, or one above the cap, does
        # not bound the result. The newlines around the query stop a trailing
        # '--' comment from swallowing the closing parenthesis.
        trailing_limit = re.search(r'\\bLIMIT\\s+(\\d+)\\s*$', query, re.IGNORECASE)
        if trailing_limit and int(trailing_limit.group(1)) <= safe_limit:
            safe_query = query
        else:
            safe_query = f"SELECT * FROM (\\n{{query}}\\n) LIMIT {{safe_limit}}"

        try:
            return spark.sql(safe_query)
        except Exception as e:
            raise ValueError(f"Query execution failed: {{str(e)}}") from e
    $$
    ;
    """
)

## Enhanced Functions (High Value)

### 7. Search Tables by Keyword Function
* Purpose: Searches for tables and columns matching a keyword (not yet implemented in this notebook)
* Why High Value: Enables semantic schema discovery - agent can find relevant tables for "revenue", "customers", etc.

## Advanced Functions (Future Enhancements)

* Business Glossary Supporting Table
    * Purpose: Maps business terminology to technical column names
    * Why Advanced: Critical for enterprise deployments where business users don't know technical schema
* Get Business Terms Function
    * Purpose: Returns business terminology mappings for domain-specific queries
    * Why Advanced: Enables non-technical users to query using familiar business language
* Query History Table
    * Purpose: Stores successful query patterns for few-shot learning
    * Why Advanced: Enables continuous improvement from successful patterns
    * Planned approach: implement via an **inference table**
* Get Similar Queries Function
    * Purpose: Returns similar successful queries for few-shot learning
    * Why Advanced: Improves accuracy by learning from past successful patterns
    * Planned approach: implement via a **vector search index**
* Format Query Results Function (Python)
    * Purpose: Executes query and formats results in requested format
    * Why Advanced: Flexible output formatting for different consumption patterns
* Log Query Execution Function
    * Purpose: Logs query execution for audit and learning
    * Why Advanced: Enables continuous improvement and compliance
    * Planned approach: implement via an **inference table**