# Bulk table description generator


Customization required:
- Adjust the LLM prompt inside `get_table_comments` to suit your use case.


In [0]:
# Register the notebook's text widgets from (widget name, UI label) pairs.
# Every widget starts out with an empty default value.
_widget_specs = [
    ("Catalog", "Enter Catalog Name (Mandatory):"),
    ("Schema", "Enter Schema Name (Optional):"),
    ("Table", "Enter Table Name (Optional):"),
    ("Output Path", "Enter Output Path (Mandatory):"),
    ("Model Serving Endpoint Name", "Model Serving Endpoint Name (Mandatory):"),
]
for _widget_name, _widget_label in _widget_specs:
    dbutils.widgets.text(_widget_name, "", _widget_label)

In [0]:
# Read the widget values, trimming surrounding whitespace so that a stray
# space typed into a widget does not end up inside table names or paths.
catalog = dbutils.widgets.get("Catalog").strip()
schema = dbutils.widgets.get("Schema").strip()
table = dbutils.widgets.get("Table").strip()
output_path = dbutils.widgets.get("Output Path").strip()
endpoint_name = dbutils.widgets.get("Model Serving Endpoint Name").strip()

# Fail fast when a widget labelled "Mandatory" was left empty, instead of
# letting a later cell run with an unusable catalog, path or endpoint.
_missing_widgets = [
    _name
    for _name, _value in (
        ("Catalog", catalog),
        ("Output Path", output_path),
        ("Model Serving Endpoint Name", endpoint_name),
    )
    if not _value
]
if _missing_widgets:
    raise ValueError(
        f"Missing mandatory widget value(s): {', '.join(_missing_widgets)}"
    )

In [0]:
# Echo the effective configuration as a single comma-separated line.
print(catalog, schema, table, output_path, endpoint_name, sep=",")

In [0]:
from pyspark.sql.utils import AnalysisException

def _quote_identifier(name: str) -> str:
    """Return *name* wrapped in backticks, with embedded backticks doubled."""
    return "`" + name.replace("`", "``") + "`"


def _describe_columns(catalog: str, schema: str, table: str) -> str:
    """Return the table's columns as a comma-separated "col_name col_type" string."""
    qualified_name = ".".join(
        _quote_identifier(part) for part in (catalog, schema, table)
    )
    fields = spark.table(qualified_name).schema.fields
    return ", ".join(
        f"{field.name} {field.dataType.simpleString()}" for field in fields
    )


def get_table_columns(catalog: str, schema: str = None, table: str = None):
    """
    Collect the column names and data types of one or more tables.

    Any of ``catalog``, ``schema`` and ``table`` that is empty/None acts as a
    wildcard: the corresponding level is enumerated with SHOW CATALOGS,
    SHOW SCHEMAS or SHOW TABLES.

    Args:
        catalog: Catalog name; if empty, every catalog listed by SHOW CATALOGS
            is scanned.
        schema: Schema name; if empty, every schema of each catalog is scanned.
        table: Table name; if empty, every table of each schema is scanned.

    Returns:
        A dict keyed by ``(catalog, schema, table)`` tuples whose values are
        comma-separated "col_name col_type" strings. Catalogs, schemas and
        tables that raise ``AnalysisException`` during the scan are skipped.
        Any other error is printed and the results gathered so far are
        returned.
    """
    results = {}

    try:
        # Case 1: a single fully qualified table was requested.
        if catalog and schema and table:
            results[(catalog, schema, table)] = _describe_columns(catalog, schema, table)
            return results

        # Case 2: enumerate whichever levels were left unspecified.
        if catalog:
            catalogs = [catalog]
        else:
            # Read the first column by position: SHOW CATALOGS names it
            # "catalog", so attribute access via "catalog_name" would fail.
            catalogs = [row[0] for row in spark.sql("SHOW CATALOGS").collect()]

        for cat in catalogs:
            try:
                if schema:
                    schemas = [schema]
                else:
                    # Positional access keeps this independent of the name
                    # the runtime gives the schema column.
                    schemas = [
                        row[0]
                        for row in spark.sql(
                            f"SHOW SCHEMAS IN {_quote_identifier(cat)}"
                        ).collect()
                    ]
            except AnalysisException:
                # Skip catalog if not accessible (same policy as schemas/tables).
                continue

            for sch in schemas:
                try:
                    if table:
                        tables = [table]
                    else:
                        tables = [
                            row.tableName
                            for row in spark.sql(
                                "SHOW TABLES IN "
                                f"{_quote_identifier(cat)}.{_quote_identifier(sch)}"
                            ).collect()
                        ]
                except AnalysisException:
                    # Skip schema if not accessible
                    continue

                for tbl in tables:
                    try:
                        results[(cat, sch, tbl)] = _describe_columns(cat, sch, tbl)
                    except AnalysisException:
                        # Skip table if not accessible
                        continue

    except Exception as e:
        # Best effort: report the problem and return whatever was collected.
        print(f"Error while fetching columns: {e}")

    return results


In [0]:
# Function to retrieve the table column comments for a given catalog, schema, table.
def get_table_comments(catalog, table_column_details, schema=None, table=None):
    """
    Generate AI-assisted comments for tables, dynamically including schema details
    from the table_column_details dict for each row.
    """
    if table_column_details is None:
        raise ValueError("table_column_details dict must be provided")

    # Convert the dict into a DataFrame for joining
    rows = [
        (cat, sch, tbl, col_str)
        for (cat, sch, tbl), col_str in table_column_details.items()
    ]
    schema_str_df = spark.createDataFrame(rows, ["catalog", "schema", "table", "schema_str"])
    schema_str_df.createOrReplaceTempView("table_columns_view")

    # Build the query with a LEFT JOIN to dynamically get schema_str per row
    query = f"""
        SELECT 
            t.table_catalog,
            t.table_schema,
            t.table_name,
            t.comment IS NULL OR length(t.comment) == 0 AS replace_comment,
            t.comment AS existing_comment,
            ai_query(
                '{endpoint_name}',
                'Generate a paragraph of description of the type of information that the table "' ||
                    t.table_name ||
                    '" in schema "' ||
                    t.table_schema ||
                    '" within the catalog "' ||
                    t.table_catalog ||
                    '" would contain (based on the name and datatypes of the columns: ' ||
                    COALESCE(c.schema_str, '') ||
                    '). This will be used as a table description, so there is no need to mention that this is a table within a schema within a catalog.'
            ) AS new_comment
        FROM system.information_schema.tables AS t
        LEFT JOIN table_columns_view AS c
          ON t.table_catalog = c.catalog
         AND t.table_schema = c.schema
         AND t.table_name = c.table
        WHERE t.table_catalog = :catalog
    """

    if schema:
        query += " AND t.table_schema = :schema"
    if table:
        query += " AND t.table_name = :table"

    query += " ORDER BY t.table_catalog, t.table_schema, t.table_name"

    table_comments = spark.sql(query, args={"catalog": catalog, "schema": schema, "table": table})
    return table_comments


In [0]:
# Collect the "col_name col_type" listing of every table selected by the widgets.
table_column_details = get_table_columns(catalog=catalog, schema=schema, table=table)

In [0]:
# Request a generated description for each selected table and show the proposals.
commented_tables = get_table_comments(
    catalog=catalog,
    table_column_details=table_column_details,
    schema=schema,
    table=table,
)
display(commented_tables)


### Optional: override selected table comments after reviewing the generated ones

In [0]:
from pyspark.sql import functions as F

# Manual overrides, one tuple per table:
#     (table_catalog, table_schema, table_name, replacement comment)
# Example:
# updates = [
#     (catalog, "fgac", "customer", "This is the updated comment for customer."),
#     (catalog, "fgac", "customer_pii_data_parquet", "This is the updated comment for customer_pii_data_parquet."),
# ]

updates = []

if updates:
    # Turn the overrides into a DataFrame keyed like the generated comments.
    updates_df = spark.createDataFrame(
        updates,
        ["table_catalog", "table_schema", "table_name", "updated_comment"],
    )

    # The left join keeps every generated row; tables without an override get
    # a NULL updated_comment.
    _joined = commented_tables.join(
        updates_df,
        on=["table_catalog", "table_schema", "table_name"],
        how="left",
    )

    # Prefer the manual override when present, otherwise keep the generated
    # text, then drop the helper column.
    commented_tables_updated = _joined.withColumn(
        "new_comment",
        F.coalesce(F.col("updated_comment"), F.col("new_comment")),
    ).drop("updated_comment")
else:
    # Nothing to override: pass the generated comments through unchanged.
    commented_tables_updated = commented_tables

In [0]:
# Show the comments that will be exported by the next cell.
display(commented_tables_updated)

In [0]:
# Export the final comments. Trailing slashes are trimmed from the output path
# so the target does not contain a doubled "/" separator.
_output_root = output_path.rstrip("/")

# Choose your desired file format
# commented_tables_updated.coalesce(1).write.mode("overwrite").option("header", "true").csv(_output_root + "/csv")
# coalesce(1) writes the result from a single partition; the "header" option
# only applies to CSV, so it is not set for the JSON writer.
commented_tables_updated.coalesce(1).write.mode("overwrite").json(_output_root + "/json")