# Bulk table column description generator


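Customization required
- Adjust the LLM prompt in `get_column_comments` to suit your data and naming conventions.
- Note: sample rows from each table are embedded in the prompt sent to the model serving endpoint, so review this before running against sensitive data.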

Authors
- Scott Eade
- Sierra Yap

In [0]:
# Declare the notebook parameters as text widgets: (name, default value, label).
for _widget_name, _widget_default, _widget_label in (
    ("Catalog", "", "Enter Catalog Name (Mandatory):"),
    ("Schema", "", "Enter Schema Name (Optional):"),
    ("Table", "", "Enter Table Name (Optional):"),
    ("Output Path", "", "Enter Output Path (Mandatory):"),
    ("Model Serving Endpoint Name", "databricks-meta-llama-3-3-70b-instruct", "Model Serving Endpoint Name (Mandatory):"),
    ("Sample Data Limit", "5", "Sample Data Limit (Mandatory):"),
):
    dbutils.widgets.text(_widget_name, _widget_default, _widget_label)
# Drop the loop variables so the notebook namespace is left unchanged.
del _widget_name, _widget_default, _widget_label

In [0]:
# Read the notebook parameters, trimming stray whitespace typed into the widgets.
catalog = dbutils.widgets.get("Catalog").strip()
schema = dbutils.widgets.get("Schema").strip()
table = dbutils.widgets.get("Table").strip()
output_path = dbutils.widgets.get("Output Path").strip()
endpoint_name = dbutils.widgets.get("Model Serving Endpoint Name").strip()
data_limit = int(dbutils.widgets.get("Sample Data Limit").strip())

# Fail fast on bad parameters instead of part-way through a long run.
if not catalog:
    raise ValueError("The 'Catalog' widget is mandatory.")
if not output_path:
    raise ValueError("The 'Output Path' widget is mandatory.")
if not endpoint_name:
    raise ValueError("The 'Model Serving Endpoint Name' widget is mandatory.")
if data_limit < 0:
    raise ValueError(f"'Sample Data Limit' must be >= 0, got {data_limit}.")
# The table filter is only applied when a schema is also given; without one a
# table-only selection would silently expand to every table in the catalog.
if table and not schema:
    raise ValueError("'Table' requires 'Schema' to be set as well.")

In [0]:
# Echo the resolved parameters as one comma-separated line for the run log.
print(",".join(str(value) for value in (catalog, schema, table, output_path, endpoint_name, data_limit)))

In [0]:
from pyspark.sql import functions as F
from pyspark.sql import Row

def quote_identifier(name):
    """Return *name* as a backtick-quoted Spark SQL identifier.

    Embedded backticks are escaped by doubling them, so names containing
    spaces, hyphens, dots or backticks can be referenced safely.
    """
    return "`" + str(name).replace("`", "``") + "`"


def get_sample_data(catalog, limit, schema="", table=""):
    """
    Fetch up to `limit` rows from each matching table.

    Tables are listed from system.information_schema.tables for `catalog`,
    optionally narrowed to `schema` and then, only when `schema` is also
    given, to `table` (a `table` without a `schema` is ignored).

    Args:
        catalog: catalog to scan (required).
        limit: maximum number of rows to sample per table; must be >= 0.
        schema: optional schema name filter.
        table: optional table name filter (only used together with `schema`).

    Returns:
        dict mapping (catalog, schema, table) -> list of row dicts. If a table
        cannot be read, its value is [{"error": <exception message>}] instead.

    Raises:
        ValueError: if `limit` is negative or not convertible to int.
    """
    limit = int(limit)
    if limit < 0:
        raise ValueError(f"limit must be >= 0, got {limit}")

    query = """
        SELECT table_catalog, table_schema, table_name
        FROM system.information_schema.tables
        WHERE table_catalog = :catalog
    """
    # Only bind the parameters the query actually references.
    args = {"catalog": catalog}
    if schema:
        query += " AND table_schema = :schema"
        args["schema"] = schema
        if table:
            query += " AND table_name = :table"
            args["table"] = table

    sample_data = {}
    for t in spark.sql(query, args=args).collect():
        key = (t.table_catalog, t.table_schema, t.table_name)
        # Quote each name part so tables with special characters resolve.
        full_table_name = ".".join(quote_identifier(part) for part in key)
        data_query = f"SELECT * FROM {full_table_name} LIMIT {limit}"
        try:
            rows = spark.sql(data_query).toPandas().to_dict(orient="records")
        except Exception as e:
            # Best-effort: record the failure for this table and keep going.
            rows = [{"error": str(e)}]
        sample_data[key] = rows

    return sample_data
  

In [0]:
# Preview the sample rows that will later be embedded in the LLM prompts.
get_sample_data(catalog, data_limit, schema=schema, table=table)

In [0]:
from pyspark.sql import Row
from pyspark.sql.types import StructType, StructField, StringType

def build_column_prompt(catalog_name, schema_name, table_name, column_name, data_type, sample_rows, prompt_rows=3):
    """Build the LLM prompt asking for a one-sentence description of one column.

    Args:
        catalog_name, schema_name, table_name, column_name: location of the column.
        data_type: the column's data type string.
        sample_rows: list of row dicts sampled from the table (may be empty).
        prompt_rows: maximum number of sample rows embedded in the prompt.

    Returns:
        The prompt string.
    """
    sample_text = str(sample_rows[:prompt_rows])
    return (
        f'Generate a one-sentence description of the type of information in column "{column_name}" '
        f'from table "{table_name}" in schema "{schema_name}" within catalog "{catalog_name}". '
        f'The column data type is "{data_type}". '
        f'Here are some sample rows from the table: {sample_text}. '
        f'This will be used as a column description, so avoid mentioning schema or catalog.'
    )


def _sql_string_literal(value):
    """Return *value* as a single-quoted Spark SQL string literal (backslash-escaped)."""
    escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
    return f"'{escaped}'"


def get_column_comments(catalog, limit, schema="", table="", *, endpoint=None, prompt_rows=3, verbose=True):
    """
    Generate a description for every column in scope using ai_query.

    Each prompt combines the column's name, table, schema, catalog and data type
    with up to `prompt_rows` sample rows of its table (from get_sample_data).
    All prompts are sent in one batched ai_query call. The results are
    collected to the driver, so displaying or writing the returned DataFrame
    does not call the model again.

    Args:
        catalog, limit, schema, table: passed to get_sample_data and used to
            filter system.information_schema.columns (the `table` filter is
            only applied together with `schema`).
        endpoint: model serving endpoint name; defaults to the notebook's
            `endpoint_name` widget value.
        prompt_rows: maximum sample rows embedded in each prompt.
        verbose: print every prompt before querying the model.

    Returns:
        PySpark DataFrame with columns table_catalog, table_schema, table_name,
        column_name, ordinal_position (int), existing_comment,
        replace_comment (bool: True when the existing comment is NULL or
        empty) and new_comment.
    """
    from pyspark.sql import functions as F
    from pyspark.sql.types import BooleanType, IntegerType, StringType, StructField, StructType

    if endpoint is None:
        endpoint = endpoint_name  # notebook-level widget value

    # Step 1: sample rows per table.
    samples = get_sample_data(catalog, limit, schema, table)

    # Step 2: column metadata, in table/ordinal order.
    query = """
      SELECT c.table_catalog, c.table_schema, c.table_name, c.column_name, c.data_type,
             c.ordinal_position,
             c.comment IS NULL or length(c.comment) == 0 AS replace_comment,
             c.comment AS existing_comment
      FROM system.information_schema.columns AS c
      JOIN system.information_schema.tables AS t
           USING (table_catalog, table_schema, table_name)
      WHERE c.table_catalog = :catalog
    """
    args = {"catalog": catalog}
    if schema:
        query += " AND c.table_schema = :schema"
        args["schema"] = schema
        if table:
            query += " AND c.table_name = :table"
            args["table"] = table
    query += " ORDER BY c.table_catalog, c.table_schema, c.table_name, c.ordinal_position"

    columns = spark.sql(query, args=args).collect()

    # Step 3: one prompt per column.
    # NOTE(review): if sampling a table failed, get_sample_data stores
    # [{"error": ...}] and that error text is what the prompt receives.
    prompts = []
    for col in columns:
        prompt = build_column_prompt(
            col.table_catalog,
            col.table_schema,
            col.table_name,
            col.column_name,
            col.data_type,
            samples.get((col.table_catalog, col.table_schema, col.table_name), []),
            prompt_rows,
        )
        if verbose:
            print(prompt)
        prompts.append(prompt)

    # Step 4: a single batched ai_query over all prompts, collected eagerly.
    new_comments = []
    if prompts:
        prompt_schema = StructType([
            StructField("idx", IntegerType(), False),
            StructField("prompt", StringType(), False),
        ])
        prompt_df = spark.createDataFrame(list(enumerate(prompts)), schema=prompt_schema)
        ai_expr = F.expr(f"ai_query({_sql_string_literal(endpoint)}, prompt)")
        results = prompt_df.select("idx", ai_expr.alias("new_comment")).collect()
        by_idx = {r.idx: r.new_comment for r in results}
        new_comments = [by_idx[i] for i in range(len(prompts))]

    # Step 5: assemble the output with types matching the source metadata.
    schema_out = StructType([
        StructField("table_catalog", StringType(), True),
        StructField("table_schema", StringType(), True),
        StructField("table_name", StringType(), True),
        StructField("column_name", StringType(), True),
        StructField("ordinal_position", IntegerType(), True),
        StructField("existing_comment", StringType(), True),
        StructField("replace_comment", BooleanType(), True),
        StructField("new_comment", StringType(), True),
    ])
    rows_out = [
        (
            col.table_catalog,
            col.table_schema,
            col.table_name,
            col.column_name,
            col.ordinal_position,
            col.existing_comment,
            col.replace_comment,
            new_comment,
        )
        for col, new_comment in zip(columns, new_comments)
    ]
    return spark.createDataFrame(rows_out, schema=schema_out)
  

In [0]:
# Generate a description for every column in scope, then show it for review.
commented_columns = get_column_comments(
    catalog,
    data_limit,
    schema=schema,
    table=table,
)
display(commented_columns)

### Optional: override generated column comments after reviewing them

In [0]:
from pyspark.sql import functions as F

# Manual overrides to apply after reviewing the generated comments. Each tuple is
# (catalog, schema, table, column, new comment, replace_comment flag as bool), e.g.:
# updates = [
#     (catalog, "fgac", "cust", "rec_id", "This is the new updated comment for cust's record_id.", True),
#     (catalog, "fgac", "cust", "ssn", "This is the new updated comment for SSN in cust.", True),
# ]

updates = []

if not updates:
    commented_columns_updated = commented_columns
else:
    from pyspark.sql.types import BooleanType, StringType, StructField, StructType

    _key_cols = ["table_catalog", "table_schema", "table_name", "column_name"]

    # Last entry wins when a column is listed more than once; duplicates would
    # otherwise multiply that column's row in the left join below.
    _latest_updates = {}
    for _cat, _sch, _tbl, _col, _comment, _flag in updates:
        _latest_updates[(_cat, _sch, _tbl, _col)] = (_comment, _flag)

    # Explicit schema so updates containing only None in a column still work.
    updates_df = spark.createDataFrame(
        [(*key, *value) for key, value in _latest_updates.items()],
        schema=StructType([
            StructField("table_catalog", StringType(), True),
            StructField("table_schema", StringType(), True),
            StructField("table_name", StringType(), True),
            StructField("column_name", StringType(), True),
            StructField("updated_comment", StringType(), True),
            StructField("replace_comment_update", BooleanType(), True),
        ]),
    )

    # Report updates that match no generated column (e.g. a typo); the left
    # join below would otherwise drop them silently.
    for _r in updates_df.join(commented_columns, on=_key_cols, how="left_anti").collect():
        print(
            f"WARNING: no generated comment for "
            f"{_r.table_catalog}.{_r.table_schema}.{_r.table_name}.{_r.column_name}; update ignored."
        )

    # Cast the override flag to the type of the generated flag column so the
    # merged column keeps a single, consistent type.
    _flag_type = dict(commented_columns.dtypes)["replace_comment"]

    commented_columns_updated = (
        commented_columns
        .join(updates_df, on=_key_cols, how="left")
        # Use the override when provided, otherwise keep the generated value.
        .withColumn("new_comment", F.coalesce(F.col("updated_comment"), F.col("new_comment")))
        .withColumn(
            "replace_comment",
            F.coalesce(F.col("replace_comment_update").cast(_flag_type), F.col("replace_comment")),
        )
        .drop("updated_comment", "replace_comment_update")  # cleanup temp columns
    )

In [0]:
# Show the final comments (generated plus any manual overrides) before export.
display(commented_columns_updated)

In [0]:
# Export the reviewed comments (generated plus any manual overrides above).
# Choose your desired file format: uncomment the CSV line to also write CSV.
if not output_path:
    raise ValueError("The 'Output Path' widget is empty; set it before exporting.")
# Strip trailing slashes so "path/" does not become "path//json".
_output_base = output_path.rstrip("/")
# commented_columns_updated.coalesce(1).write.mode("overwrite").option("header", "true").csv(_output_base + "/csv")
commented_columns_updated.coalesce(1).write.mode("overwrite").json(_output_base + "/json")

In [0]:
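# Legacy version, kept for reference only (commented out, not executed): generates column comments with a single ai_query SQL statement and no sample data. Note its signature differs from get_column_comments above (no `limit` argument).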
# def get_column_comments(catalog, schema="", table=""):
#   query = f"""
#     SELECT c.table_catalog, c.table_schema, c.table_name, c.column_name, c.ordinal_position, c.comment IS NULL or length(c.comment) == 0 AS replace_comment, c.comment AS existing_comment
#     , ai_query('{endpoint_name}', 'Generate a 1 sentence description of the type of information that the column "' || c.column_name || '" from the table "' || c.table_name || '" in schema "' || c.table_schema || '" within the catalog "' || c.table_catalog || '" would contain (the data type of the column is "' || c.data_type || '"). This will be used as a column description, so there is no need to mention that this is a column within a schema within a catalog.') AS new_comment
#     FROM system.information_schema.columns AS c
#     JOIN system.information_schema.tables AS t USING (table_catalog, table_schema, table_name)
#     WHERE table_catalog = :catalog
#     """
#   if schema:
#     query += " AND table_schema = :schema"
#     if table:
#       query += " AND table_name = :table"
#   query += " ORDER BY table_catalog, table_schema, table_name, ordinal_position"
#   # query += " LIMIT 5"
#   column_comments = spark.sql(query, args = {"catalog": catalog, "schema": schema, "table": table})
#   return column_comments