In [0]:
# Databricks notebook source
%pip install langchain-core langchain-openai langchain-community pyyaml

# COMMAND ----------

# Left blank on purpose: paste a key here only for ad-hoc runs.
open_ai_key = ""

# Fallback values shown in the notebook widgets.
table_address_default = "vanhack.mobi_data.silver_stations"
data_contract_default_path = "/Workspace/Users/jazz@jazzgrewal.com/mobi_agent_starter/datacontract.yml"

# Register the text widgets so both inputs can be overridden per run.
dbutils.widgets.text("table_address", table_address_default)
dbutils.widgets.text("data_contract_path", data_contract_default_path)

# Pull the effective values back out of the widgets, in declaration order.
table_address, data_contract_path = (
    dbutils.widgets.get(widget_name)
    for widget_name in ("table_address", "data_contract_path")
)

# COMMAND ----------

# ------------------------------------------------------
# 0. SETUP: Resolve the OpenAI key
# ------------------------------------------------------
# An explicitly assigned key (see the config cell) wins so ad-hoc runs keep
# working; otherwise the key is read from the `openai` secret scope so that no
# credential has to be stored in the notebook source.
open_ai_key = open_ai_key or dbutils.secrets.get("openai", "api_key")

from openai import OpenAI
import json
import yaml
import re


def read_contract(file_path: str) -> dict:
    """Read a YAML data contract file and return it as a dictionary.

    An empty (or otherwise falsy) document yields an empty dict.

    Raises:
        FileNotFoundError: if ``file_path`` does not exist.
        ValueError: if the parsed document is not a mapping.
    """
    try:
        with open(file_path, "r", encoding="utf-8") as handle:
            raw_text = handle.read()
    except FileNotFoundError as exc:
        message = f"Data contract file not found at `{file_path}`."
        raise FileNotFoundError(message) from exc

    parsed = yaml.safe_load(raw_text)
    if not parsed:
        # Empty documents parse to None; normalise every falsy result to {}.
        parsed = {}

    if isinstance(parsed, dict):
        return parsed
    raise ValueError("Data contract must deserialize to a dictionary.")


def extract_column_descriptions(contract: dict) -> dict:
    """Recursively collect column descriptions from a data contract.

    The contract is walked through nested dicts and lists. Whenever a mapping
    key equals ``"columns"`` (case-insensitive), its value is read as a
    ``{column_name: payload}`` mapping where ``payload`` is either a plain
    description string or a dict carrying ``description`` / ``comment``.
    The contents of a ``columns`` section are not walked any further.

    Args:
        contract: Parsed data contract (arbitrarily nested dicts/lists).

    Returns:
        Mapping of column name (as ``str``) to its stripped, non-empty
        description. If a column name appears in several ``columns``
        sections, the last one visited wins.
    """

    descriptions = {}

    def _extract_from_columns(columns_obj):
        # Only mapping-shaped column sections are understood.
        if not isinstance(columns_obj, dict):
            return
        for column_name, payload in columns_obj.items():
            if isinstance(payload, dict):
                description = payload.get("description") or payload.get("comment")
                if isinstance(description, str) and description.strip():
                    descriptions[str(column_name)] = description.strip()
            elif isinstance(payload, str) and payload.strip():
                descriptions[str(column_name)] = payload.strip()

    def _walk(node):
        if isinstance(node, dict):
            for key, value in node.items():
                # YAML keys are not guaranteed to be strings (e.g. `1:` or
                # `on:` parse to int/bool), so guard before calling .lower().
                if isinstance(key, str) and key.lower() == "columns":
                    _extract_from_columns(value)
                else:
                    _walk(value)
        elif isinstance(node, list):
            for item in node:
                _walk(item)

    _walk(contract)
    return descriptions


def generate_placeholder_description(column_name: str, data_type: str) -> str:
    """Produce a human-style one-line fallback description.

    Args:
        column_name: Raw column name; underscores are turned into spaces.
        data_type: Type label to mention in the description.

    Returns:
        ``"Unique identifier for the <words> record."`` for identifier-like
        columns, otherwise ``"<Words> value stored as <data_type>."``.
    """

    words = column_name.replace("_", " ")
    words = re.sub(r"\s+", " ", words).strip()
    if not words:
        # Name consisted only of underscores/whitespace; fall back to it as-is.
        words = column_name

    # Identifier heuristic: the name must END WITH an id token, not merely
    # with the letters "id" (which would also catch "paid", "valid", "grid").
    # Accepted forms: `id`, `*_id`, `*_uuid`, `*_guid` (any case) and the
    # camelCase suffixes `Id` / `ID` (e.g. `stationId`).
    is_identifier = (
        re.search(r"(?:^|_)(?:id|uuid|guid)$", column_name, re.IGNORECASE)
        or re.search(r"[a-z0-9]I[dD]$", column_name)
    )
    if is_identifier:
        return f"Unique identifier for the {words.lower()} record.".capitalize()

    return f"{words.capitalize()} value stored as {data_type}."

# OpenAI client used for the agent call further down.
client = OpenAI(api_key=open_ai_key)

# Parse the contract once, then derive the per-column guidance from it.
data_contract = read_contract(data_contract_path)
baseline_contract_columns = extract_column_descriptions(data_contract)

# Both payloads are rendered the same way; default=str stringifies any value
# json cannot serialize natively instead of raising.
data_contract_json, contract_column_guidance_json = (
    json.dumps(payload, indent=2, default=str)
    for payload in (data_contract, baseline_contract_columns)
)

def escape_table_address(table_address: str) -> str:
    """Backtick-escape a three-part Unity Catalog table address.

    Args:
        table_address: Address in the form ``catalog.schema.table``.

    Returns:
        The address with each part wrapped in backticks, e.g.
        ``` `catalog`.`schema`.`table` ```.

    Raises:
        ValueError: if the address does not have exactly three parts, if any
            part is empty, or if any part contains a backtick.
    """
    parts = table_address.split(".")
    if len(parts) != 3:
        raise ValueError(f"Invalid UC table address: {table_address}")

    for part in parts:
        if not part.strip():
            raise ValueError(f"Invalid UC table address: {table_address}")
        # Same rule as escape_identifier: a backtick inside a part would close
        # the quoted identifier early and let arbitrary SQL follow it.
        if "`" in part:
            raise ValueError(f"Identifier contains illegal backtick: {part}")

    catalog, schema, table = parts
    return f"`{catalog}`.`{schema}`.`{table}`"


def escape_identifier(identifier: str) -> str:
    """Wrap a single Unity Catalog identifier in backticks.

    Raises:
        ValueError: if the identifier itself contains a backtick.
    """
    if "`" not in identifier:
        return f"`{identifier}`"
    raise ValueError(f"Identifier contains illegal backtick: {identifier}")


def sql_escape_literal(value: str) -> str:
    """Escape a string for use inside a single-quoted Spark SQL literal.

    Backslashes are doubled first, because Spark SQL treats ``\\`` as an
    escape character inside string literals by default; an unescaped trailing
    backslash would otherwise swallow the closing quote. Single quotes are
    then doubled.

    Args:
        value: Raw text to embed between single quotes.

    Returns:
        The escaped text (without the surrounding quotes).
    """
    # Order matters: escape backslashes before touching quotes so that the
    # quote handling never introduces characters that get re-escaped.
    return value.replace("\\", "\\\\").replace("'", "''")


# ======================================================
# 1. METADATA UPDATER TOOL — REAL UNITY CATALOG UPDATES
# ======================================================
def metadata_creator(table_address: str, data_contract: dict) -> str:
    """
    Updates Unity Catalog table & column comments based on a data contract.

    Every column in the table's schema receives a comment. The text is taken,
    in order of preference, from ``data_contract["columns"]``, from the
    module-level ``baseline_contract_columns`` mapping, or from
    ``generate_placeholder_description``. Each SQL statement is printed
    before it is executed with ``spark.sql``.

    Example contract:
    {
        "table_comment": "Table description",
        "columns": {
            "col1": "Comment1",
            "col2": "Comment2"
        }
    }

    Args:
        table_address: Table address; must split into exactly three
            dot-separated parts (see ``escape_table_address``).
        data_contract: Mapping with an optional ``table_comment`` and an
            optional ``columns`` mapping of column name to comment.

    Returns:
        A newline-joined summary of the actions taken, or a single error
        message string if the table could not be loaded.

    Raises:
        ValueError: propagated from ``escape_table_address`` /
            ``escape_identifier`` for malformed identifiers.
    """

    summary = []
    
    # --------------------------------------------------
    # Validate table exists
    # --------------------------------------------------
    # Failure here is reported as a returned message rather than an exception.
    try:
        table = spark.table(table_address)
    except Exception as e:
        return f"‚ùå Table not found: {table_address}\n{str(e)}"

    # The table's own schema is the source of truth for which columns exist.
    schema_fields = {field.name: field for field in table.schema.fields}
    schema_cols = list(schema_fields.keys())
    escaped_table_address = escape_table_address(table_address)


    # --------------------------------------------------
    # 1. Update table comment
    # --------------------------------------------------
    # Skipped when the key is absent or its value is falsy.
    if "table_comment" in data_contract and data_contract["table_comment"]:
        comment = str(data_contract["table_comment"]).strip()
        sql = f"COMMENT ON TABLE {escaped_table_address} IS '{sql_escape_literal(comment)}'"
        print(sql)
        spark.sql(sql)
        summary.append(f"‚úÖ Updated table comment ‚Üí {comment}")


    # --------------------------------------------------
    # 2. Update column-level comments
    # --------------------------------------------------
    # Only string keys of a dict-shaped "columns" entry are considered.
    columns_payload = {}
    if "columns" in data_contract and isinstance(data_contract["columns"], dict):
        columns_payload = {k: v for k, v in data_contract["columns"].items() if isinstance(k, str)}

    # Contract columns missing from the schema are reported and never applied.
    # The comparison is an exact (case-sensitive) string match.
    unknown_columns = sorted(set(columns_payload.keys()) - set(schema_cols))
    if unknown_columns:
        summary.append(
            "‚ö†Ô∏è Ignored unknown columns " + ", ".join(f"`{col}`" for col in unknown_columns)
        )

    # Iterate over the schema (not the contract) so every column gets a comment.
    for col in schema_cols:
        field = schema_fields[col]
        base_comment = columns_payload.get(col)
        comment_text = str(base_comment).strip() if base_comment is not None else ""

        # Missing or blank comment: fall back to the baseline contract text,
        # then to a generated placeholder.
        if not comment_text:
            baseline_comment = baseline_contract_columns.get(col)
            if baseline_comment:
                comment_text = baseline_comment.strip()
                summary.append(f"‚ÑπÔ∏è Reused baseline contract comment for `{col}` ‚Üí {comment_text}")
            else:
                comment_text = generate_placeholder_description(col, field.dataType.simpleString())
                summary.append(f"‚ÑπÔ∏è Generated placeholder comment for `{col}` ‚Üí {comment_text}")

        column_ref = f"{escaped_table_address}.{escape_identifier(col)}"
        sql = f"COMMENT ON COLUMN {column_ref} IS '{sql_escape_literal(comment_text)}'"
        print(sql)
        # An exception raised here propagates and stops the remaining columns.
        spark.sql(sql)
        summary.append(f"üîπ Updated column `{col}` ‚Üí {comment_text}")


    # --------------------------------------------------
    # 3. Done
    # --------------------------------------------------
    return "\n".join(summary)


# ------------------------------------------------------
# Build schema overview for prompt engineering context
# ------------------------------------------------------
# Fail fast with a clear message if the target table cannot be read.
try:
    table = spark.table(table_address)
except Exception as exc:
    raise RuntimeError(
        f"Unable to load table `{table_address}` to build schema overview."
    ) from exc

# One entry per column: name, Spark type, nullability and any comment already
# stored in the field metadata (empty string when there is none).
schema_overview = []
for field in table.schema.fields:
    comment = (
        field.metadata["comment"]
        if (field.metadata and "comment" in field.metadata)
        else ""
    )
    schema_overview.append(
        dict(
            name=field.name,
            type=field.dataType.simpleString(),
            nullable=field.nullable,
            existing_comment=comment,
        )
    )

schema_overview_json = json.dumps(schema_overview, indent=2)


# ======================================================
# 2. TOOL SCHEMA FOR OPENAI FUNCTION CALLING
# ======================================================
# Schema for `data_contract.columns`: free-form column names mapped to string
# comments, with at least one entry.
_columns_schema = {
    "type": "object",
    "description": "Mapping of column name to desired comment.",
    "minProperties": 1,
    "additionalProperties": {"type": "string"},
}

# Schema for the `data_contract` argument: `columns` is mandatory,
# `table_comment` is optional, nothing else is allowed.
_data_contract_schema = {
    "type": "object",
    "additionalProperties": False,
    "properties": {
        "table_comment": {
            "type": "string",
            "description": "Optional table-level description",
        },
        "columns": _columns_schema,
    },
    "required": ["columns"],
}

# Top-level arguments of the `metadata_creator` tool.
_tool_parameters = {
    "type": "object",
    "additionalProperties": False,
    "properties": {
        "table_address": {
            "type": "string",
            "description": "Unity Catalog table address in the form catalog.schema.table.",
        },
        "data_contract": _data_contract_schema,
    },
    "required": ["table_address", "data_contract"],
}

# Tool list handed to the chat completion request.
tools = [
    {
        "type": "function",
        "function": {
            "name": "metadata_creator",
            "description": "Applies Unity Catalog table and column comments using the provided data contract.",
            "parameters": _tool_parameters,
        },
    }
]



# ======================================================
# 3. CALL THE AGENT
# ======================================================
# Ask the model for column descriptions. The request names `metadata_creator`
# explicitly in `tool_choice` so the reply is always a call to that tool;
# with "auto" the model is free to answer in prose, in which case no metadata
# would be written at all.
response = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[
        {
            "role": "system",
            "content": (
                "You are a meticulous Data Governance Agent working in Databricks. "
                "Generate column-level Unity Catalog metadata updates. "
                "Always reply by calling the provided tool with valid JSON arguments, "
                "include every column exactly once, and avoid hallucinating new columns."
            )
        },
        {
            "role": "user",
            "content": f"""
Please craft concise, high-quality descriptions for every column in the Unity Catalog table "{table_address}".

Table schema (JSON):
{schema_overview_json}

Data contract (YAML ‚Üí JSON):
{data_contract_json}

Column guidance extracted from contract:
{contract_column_guidance_json}

Requirements:
- Respond by invoking the `metadata_creator` tool only.
- `data_contract.columns` must include all columns listed above with helpful descriptions.
- Reuse or refine `existing_comment` and data contract guidance when informative; otherwise write a new, human-friendly, one-sentence description that clarifies the column's business meaning.
- Omit `table_comment` unless you have a strong justification.
- Do not add or remove columns, and do not include commentary outside the tool call.
            """
        }
    ],
    tools=tools,
    tool_choice={"type": "function", "function": {"name": "metadata_creator"}}
)

# Only the first choice is requested/used.
msg = response.choices[0].message



# ======================================================
# 4. HANDLE TOOL CALL (Execute Python function)
# ======================================================
if msg.tool_calls:
    for call in msg.tool_calls:
        name = call.function.name

        # Only one tool is offered; anything else is reported, not executed.
        if name != "metadata_creator":
            print(f"Ignoring unexpected tool call: {name}")
            continue

        # Tool arguments are model output and may not be valid JSON.
        try:
            args = json.loads(call.function.arguments)
        except json.JSONDecodeError as exc:
            print(f"Could not parse arguments for tool call `{name}`: {exc}")
            continue

        contract_arg = args.get("data_contract") if isinstance(args, dict) else None
        if not isinstance(contract_arg, dict):
            print(f"Tool call `{name}` did not supply a valid `data_contract` object; skipping.")
            continue

        # The DDL must target the table this notebook was configured for (and
        # whose schema was sent in the prompt), never a table name chosen by
        # the model.
        requested_table = args.get("table_address")
        if requested_table != table_address:
            print(
                f"Model requested table `{requested_table}`; "
                f"using configured table `{table_address}` instead."
            )

        # Pass arguments explicitly so unexpected keys cannot raise TypeError.
        result = metadata_creator(
            table_address=table_address,
            data_contract=contract_arg,
        )
        print("=== METADATA UPDATE RESULT ===")
        print(result)

else:
    print("LLM response (no tool call):")
    print(msg.content)


COMMENT ON COLUMN `vanhack`.`mobi_data`.`silver_stations`.`station_id` IS 'Unique identifier for the station id record.'
COMMENT ON COLUMN `vanhack`.`mobi_data`.`silver_stations`.`name` IS 'Name value stored as string.'
COMMENT ON COLUMN `vanhack`.`mobi_data`.`silver_stations`.`lat` IS 'Lat value stored as double.'
COMMENT ON COLUMN `vanhack`.`mobi_data`.`silver_stations`.`lon` IS 'Lon value stored as double.'
COMMENT ON COLUMN `vanhack`.`mobi_data`.`silver_stations`.`is_virtual_station` IS 'Is virtual station value stored as boolean.'
COMMENT ON COLUMN `vanhack`.`mobi_data`.`silver_stations`.`capacity` IS 'Capacity value stored as bigint.'
COMMENT ON COLUMN `vanhack`.`mobi_data`.`silver_stations`.`vehicle_type_capacity` IS 'Vehicle type capacity value stored as struct<1:bigint,2:bigint,3:bigint>.'
COMMENT ON COLUMN `vanhack`.`mobi_data`.`silver_stations`.`is_valet_station` IS 'Is valet station value stored as boolean.'
COMMENT ON COLUMN `vanhack`.`mobi_data`.`silver_stations`.`is_char