# Catalog Table Sizes
This notebook lists the total size of every table across all schemas in the selected catalog(s) using [DiscoverX](https://github.com/databrickslabs/discoverx).

Use the widgets below to select one or more catalogs and, optionally, a target table and write mode for saving the results, then run the remaining cells.

In [0]:
%pip install dbl-discoverx
# Restart the Python process after the install (presumably so the freshly
# installed package can be imported by the cells that follow).
dbutils.library.restartPython()

In [0]:
# Catalog picker: every catalog reported by SHOW CATALOGS plus a
# "None Selected" placeholder, which is also the widget's default value.
catalogs = [
    *(entry.catalog for entry in spark.sql("SHOW CATALOGS").collect()),
    "None Selected",
]
dbutils.widgets.multiselect("1.catalogs", "None Selected", catalogs)

# Output settings: a target table name (empty by default) and the write mode.
dbutils.widgets.text(
    "2.target_table",
    "",
    label="Target table (catalog.schema.table)",
)
dbutils.widgets.dropdown(
    "2.write_mode",
    "overwrite",
    ["overwrite", "append"],
    label="Write mode",
)


In [0]:
# Parse the comma-separated multiselect value into catalog names, dropping
# empty entries and the "None Selected" placeholder so that the placeholder
# is never scanned as if it were a real catalog.
catalog_list = [
    c
    for c in dbutils.widgets.get("1.catalogs").split(',')
    if c and c != "None Selected"
]

In [0]:
from pyspark.sql import Row
from discoverx import DX

dx = DX()

def human_size(size_bytes):
    """Format a byte count as a human-readable string using 1024-based units.

    Args:
        size_bytes: Size in bytes (int or float), or None when the size is
            unknown (e.g. a NULL coming from a Spark column or aggregate).

    Returns:
        A string such as "1.50 KB" with two decimals, or None when
        ``size_bytes`` is None. Values too large for PB are reported in EB.
    """
    # NULL sizes cannot be compared with 1024; propagate them as None instead
    # of raising TypeError (this matters when the function is used as a UDF).
    if size_bytes is None:
        return None
    for unit in ['B','KB','MB','GB','TB','PB','EB']:
        # 'EB' is the last unit, so it is used no matter how large the value is.
        if size_bytes < 1024 or unit == 'EB':
            return f"{size_bytes:.2f} {unit}"
        size_bytes /= 1024

def table_size(tbl):
    """Look up the size of a single table via ``DESCRIBE DETAIL``.

    Args:
        tbl: Object exposing ``catalog``, ``schema`` and ``table`` attributes.

    Returns:
        dict with the keys 'table' (dotted, unquoted name), 'size' (the
        ``sizeInBytes`` value of the first DESCRIBE DETAIL row) and
        'size_human' (that value formatted by ``human_size``).
    """
    name_parts = (tbl.catalog, tbl.schema, tbl.table)
    # Backtick-quote each part for the SQL statement.
    quoted_name = "`{}`.`{}`.`{}`".format(*name_parts)
    detail = spark.sql(f"DESCRIBE DETAIL {quoted_name}")
    first_row = detail.select('sizeInBytes').collect()[0]
    size_in_bytes = first_row[0]

    record = {'table': "{}.{}.{}".format(*name_parts)}
    record['size'] = size_in_bytes
    record['size_human'] = human_size(size_in_bytes)
    return record

# Collect one {'table', 'size', 'size_human'} record per table found in each
# selected catalog.
results = []
for cat in catalog_list:
    results.extend(dx.from_tables(f'{cat}.*.*').map(table_size))

# Use an explicit schema: schema inference raises on an empty list, which
# happens when no catalog is selected or no tables are found.
# Column order is size, size_human, table.
result_schema = "`size` long, `size_human` string, `table` string"
df = spark.createDataFrame(
    [(r['size'], r['size_human'], r['table']) for r in results],
    schema=result_schema,
)

# Add total row for all tables.
# sum() over zero rows (or only NULL sizes) yields None; report that as 0 bytes.
total_size = df.agg({'size': 'sum'}).collect()[0][0] or 0
total_row = spark.createDataFrame(
    [(total_size, human_size(total_size), 'ALL_TABLES')],
    schema=result_schema,
)
# Match columns by name rather than by position.
df = df.unionByName(total_row)

target_table = dbutils.widgets.get("2.target_table").strip()
write_mode = dbutils.widgets.get("2.write_mode")

# Persist the result only when a target table name was provided.
if target_table:
    (df
     .write
     .mode(write_mode)
     .option("overwriteSchema", "true")
     .saveAsTable(target_table)
     )

display(df)

In [None]:
from pyspark.sql import functions as F

# Aggregate table sizes to the schema level.
# The synthetic ALL_TABLES total row is excluded first. The dotted table name
# is then split on literal dots; the split pattern is a regular expression,
# so the dot is escaped in a raw string.
schema_sizes = (
    df.filter(F.col("table") != "ALL_TABLES")
      .withColumn("catalog", F.split("table", r"\.").getItem(0))
      .withColumn("schema", F.split("table", r"\.").getItem(1))
      .groupBy("catalog", "schema")
      .agg(F.sum("size").alias("size"))
)

# Register the aggregate under a name so it can be queried with spark.sql.
schema_sizes.createOrReplaceTempView("schema_sizes")

# Build the edges of a similarity graph over schemas: ai_similarity scores
# each pair of schema names, and a pair scoring at least 0.7 becomes an edge.
# The connected components of that graph are the dataset groups, so schemas
# can end up together through a chain of intermediate matches.
similar_pairs = spark.sql(
        """
        SELECT
            a.catalog AS catalog_a,
            a.schema AS schema_a,
            b.catalog AS catalog_b,
            b.schema AS schema_b,
            ai_similarity(a.schema, b.schema) AS similarity
        FROM schema_sizes a
        CROSS JOIN schema_sizes b
        WHERE a.schema <= b.schema
        """
)
similar_pairs = similar_pairs.where(F.col("similarity") >= 0.7)
similar_pairs = similar_pairs.select("catalog_a", "schema_a", "catalog_b", "schema_b")

# Edges as ("catalog.schema", "catalog.schema") tuples on the driver.
pairs = [
    (f"{edge.catalog_a}.{edge.schema_a}", f"{edge.catalog_b}.{edge.schema_b}")
    for edge in similar_pairs.collect()
]

# Graph nodes: one "catalog.schema" string per aggregated schema.
schema_rows = schema_sizes.select("catalog", "schema").collect()
nodes = [f"{entry.catalog}.{entry.schema}" for entry in schema_rows]

# Union-find forest: every node starts out as its own parent.
parents = dict(zip(nodes, nodes))

def find(node):
    """Return the root of ``node`` in the module-level ``parents`` forest.

    While walking up, each visited node is re-pointed at its grandparent,
    which shortens the chain for later lookups.
    """
    current = node
    while True:
        parent = parents[current]
        if parent == current:
            return current
        # Skip one level: link the current node to its grandparent and
        # continue the walk from there.
        grandparent = parents[parent]
        parents[current] = grandparent
        current = grandparent


def union(a, b):
    """Merge the sets containing ``a`` and ``b`` in the ``parents`` forest.

    When the two roots differ, the root that compares smaller becomes the
    parent of the other. Nothing changes when both already share a root.
    """
    root_a = find(a)
    root_b = find(b)
    if root_a != root_b:
        if root_a < root_b:
            keep, absorb = root_a, root_b
        else:
            keep, absorb = root_b, root_a
        parents[absorb] = keep


# Merge the two endpoints of every similar pair.
for left, right in pairs:
    union(left, right)

# Bucket each node under the root of its component.
clusters = {}
for name in nodes:
    representative = find(name)
    if representative not in clusters:
        clusters[representative] = []
    clusters[representative].append(name)

# Emit one (catalog, schema, dataset label) tuple per schema. The label is the
# schema name of the first member in sorted order, with a "(+N more)" suffix
# when the group has more than one member.
dataset_rows = []
for group in clusters.values():
    group.sort()
    split_members = [entry.split(".", 1) for entry in group]
    first_schema = split_members[0][1]
    extra_count = len(split_members) - 1
    if extra_count == 0:
        label = first_schema
    else:
        label = f"{first_schema} (+{extra_count} more)"
    dataset_rows.extend(
        (catalog_name, schema_name, label)
        for catalog_name, schema_name in split_members
    )

# Map every (catalog, schema) to its dataset label. The schema is given
# explicitly because Spark cannot infer column types from an empty list,
# which is what dataset_rows is when there are no schemas to group.
dataset_assignments = spark.createDataFrame(
    dataset_rows,
    schema="`catalog` string, `schema` string, `dataset` string",
)

# Wrap human_size as a UDF so it can format the aggregated sizes.
human_size_udf = F.udf(human_size)

# Total size per dataset, largest first.
dataset_totals = (
    schema_sizes.join(dataset_assignments, ["catalog", "schema"], "inner")
                .groupBy("dataset")
                .agg(F.sum("size").alias("size"))
                .withColumn("size_human", human_size_udf("size"))
                .orderBy(F.desc("size"))
)

display(dataset_totals.select("dataset", "size", "size_human"))
