In [None]:
!pip install sqlglot==27.20.0

In [None]:
from functools import reduce
from itertools import combinations
from math import comb
from pyspark.sql import functions as F
from sqlglot import exp
from tqdm.auto import tqdm
from typing import Dict, Iterable, List, Tuple, Any

In [None]:
# Manifest table read by `load_manifest`; each row describes one frame
# (its `source`, `name` and `hooks`).
MANIFEST_PATH: str = "metadata.manifest"

# Frame tables are addressed as `<FROM_SCHEMA>.<FROM_PREFIX>__<source>__<name>`.
FROM_SCHEMA: str = "hook"
FROM_PREFIX: str = "frame"

# Schema that receives the generated `nodes` / `edges` materialized lake views.
TO_SCHEMA: str = "graph"

In [None]:
def load_manifest(table_path: str):
    """Read the manifest table and bring all of its rows to the driver.

    Args:
        table_path: Catalog name of the manifest table.

    Returns:
        The list of Spark ``Row`` objects returned by ``DataFrame.collect()``.
    """
    return spark.read.table(table_path).collect()

In [None]:
def is_unique_key(table, key):
    """Return True when every row of ``table`` has a distinct, non-NULL ``key``.

    A single aggregation compares the total row count with
    ``countDistinct(key)``. ``countDistinct`` ignores NULLs while the row count
    does not, so a column containing any NULL is reported as not unique. An
    empty table counts as unique (0 == 0).
    """
    column = F.col(key)
    stats = (
        spark.read.table(table)
        .select(column)
        .agg(
            F.count(F.lit(1)).alias("row_count"),
            F.countDistinct(column).alias("n_unique"),
        )
        .first()  # a global aggregation always yields exactly one row
    )
    return stats["row_count"] == stats["n_unique"]

In [None]:
def construct_qualified_name(
    schema: str,
    prefix: str,
    source: str,
    name: str
) -> str:
    """Build ``<schema>.<prefix>__<source>__<name>``.

    Example: ``("hook", "frame", "crm", "customer")`` -> ``"hook.frame__crm__customer"``.
    """
    table_name = f"{prefix}__{source}__{name}"
    return f"{schema}.{table_name}"

In [None]:
def core_fields(h: dict) -> tuple:
    """Return the ``(name, concept, keyset)`` triple that identifies a hook."""
    name, concept, keyset = h["name"], h["concept"], h["keyset"]
    return name, concept, keyset

In [None]:
def build_edges_for_frame(table: str, hooks: list[dict]) -> list[dict]:
    """Produce unordered, deduplicated edges between the hooks of one frame.

    Hooks are identified by ``core_fields`` (name, concept, keyset). Each
    unordered pair of distinct hooks yields one edge. The pair is put in
    canonical order, so the smaller triple becomes the ``from`` side.
    Repeated pairs caused by duplicate hook entries are emitted once.
    Pairs whose two hooks share identical core fields are skipped, because
    they would describe an edge from a column to itself.

    Args:
        table: Qualified name of the frame table; stored as ``from_frame``.
        hooks: Mappings exposing ``name``, ``concept`` and ``keyset``.

    Returns:
        Edge dicts with keys ``from_frame``, ``from_concept``, ``from_keyset``,
        ``from_hook_name``, ``to_concept``, ``to_keyset`` and ``to_hook_name``.
    """
    if len(hooks) < 2:
        return []

    edges: list[dict] = []
    seen_pairs: set[tuple] = set()

    for a, b in combinations(hooks, 2):
        # Sorting makes (a, b) and (b, a) the same canonical pair.
        u_key, v_key = sorted((core_fields(a), core_fields(b)))
        if u_key == v_key:
            # Duplicate entries of one hook would otherwise form a self-loop.
            continue
        pair_key = (u_key, v_key)
        if pair_key in seen_pairs:
            continue
        seen_pairs.add(pair_key)

        u_name, u_concept, u_keyset = u_key
        v_name, v_concept, v_keyset = v_key

        edges.append({
            "from_frame": table,
            "from_concept": u_concept,
            "from_keyset": u_keyset,
            "from_hook_name": u_name,
            "to_concept": v_concept,
            "to_keyset": v_keyset,
            "to_hook_name": v_name,
        })
    return edges

In [None]:
def build_nodes_for_frame(table: str, hooks: list[dict]) -> list[dict]:
    """Produce one node per distinct hook whose column is a unique key of the frame.

    Hooks are identified by ``core_fields`` (name, concept, keyset), the same
    identity ``build_edges_for_frame`` deduplicates on. Repeated hook entries
    therefore yield a single node. Uniqueness is decided by ``is_unique_key``,
    which runs one Spark aggregation over ``table``. That result is cached per
    column name, so a column referenced by several hooks is scanned only once.

    Args:
        table: Qualified name of the frame table; stored as ``frame``.
        hooks: Mappings exposing ``name``, ``concept`` and ``keyset``.

    Returns:
        Node dicts with keys ``frame``, ``concept``, ``keyset`` and
        ``hook_name``, in first-seen hook order.
    """
    nodes: list[dict] = []
    seen: set[tuple] = set()
    column_is_unique: dict[str, bool] = {}

    for h in hooks:
        fields = core_fields(h)
        if fields in seen:
            continue
        seen.add(fields)

        name, concept, keyset = fields
        if name not in column_is_unique:
            column_is_unique[name] = is_unique_key(table, name)
        if column_is_unique[name]:
            nodes.append({
                "frame": table,
                "concept": concept,
                "keyset": keyset,
                "hook_name": name,
            })
    return nodes

In [None]:
def scan_frames(manifest_path: str) -> dict:
    """Collect graph nodes and edges for every frame listed in the manifest.

    Args:
        manifest_path: Name of the manifest table, read with ``load_manifest``.

    Returns:
        ``{"nodes": [...], "edges": [...]}`` accumulated across all frames.
    """
    frames = load_manifest(manifest_path)

    nodes: list[dict] = []
    edges: list[dict] = []

    for frame in tqdm(frames, desc="Scanning frames"):
        table = construct_qualified_name(
            FROM_SCHEMA, FROM_PREFIX, frame["source"], frame["name"]
        )

        # A manifest row may carry a NULL `hooks` value. Treat it as a frame
        # with no hooks rather than crashing inside the builders.
        hooks = frame["hooks"] or []

        edges.extend(build_edges_for_frame(table, hooks))
        nodes.extend(build_nodes_for_frame(table, hooks))

    return {"nodes": nodes, "edges": edges}

nodes_and_edges = scan_frames(MANIFEST_PATH)

In [None]:
def union_all(queries):
    """Fold queries into a left-nested ``UNION ALL`` chain.

    Args:
        queries: Iterable of objects exposing ``union(other, distinct=...)``,
            e.g. sqlglot ``Select`` expressions.

    Returns:
        The single query when one is given; otherwise
        ``((q1 UNION ALL q2) UNION ALL q3) ...``.

    Raises:
        ValueError: If ``queries`` is empty.
    """
    queries = list(queries)
    if not queries:
        raise ValueError("union_all() requires at least one query")
    return reduce(lambda left, right: left.union(right, distinct=False), queries)

In [None]:
def _node_query(node: dict) -> exp.Select:
    """Build the SELECT that emits one node row per record of ``node["frame"]``.

    ``frame``, ``concept``, ``keyset`` and ``hook_name`` are projected as string
    literals. ``hook_value`` is the frame column named by ``node["hook_name"]``.
    """
    frame = node["frame"]
    hook_name = node["hook_name"]
    return (
        exp.select(
            exp.Literal.string(frame).as_("frame"),
            exp.Literal.string(node["concept"]).as_("concept"),
            exp.Literal.string(node["keyset"]).as_("keyset"),
            exp.Literal.string(hook_name).as_("hook_name"),
            exp.column(hook_name).as_("hook_value"),
        )
        .from_(frame)
    )


def generate_node_sql(nodes: list[dict]) -> exp.Expression:
    """Combine one SELECT per node into a single ``UNION ALL`` query.

    Args:
        nodes: Node dicts as produced by ``build_nodes_for_frame`` (keys
            ``frame``, ``concept``, ``keyset``, ``hook_name``).

    Returns:
        The sqlglot expression for the unioned query.

    Raises:
        ValueError: If ``nodes`` is empty, since there is no query to build.
    """
    queries = [_node_query(node) for node in nodes]
    if not queries:
        raise ValueError("generate_node_sql() needs at least one node")
    # Union once, after every SELECT has been built.
    return union_all(queries)

In [None]:
def _edge_query(edge: dict) -> exp.Select:
    """Build the SELECT that emits one edge row per record of ``edge["from_frame"]``.

    Rows are kept only when both hook columns are non-NULL. Descriptive fields
    are projected as string literals. ``from_hook_value`` and ``to_hook_value``
    are the frame columns named by the two hook names.
    """
    from_frame = edge["from_frame"]
    from_hook = edge["from_hook_name"]
    to_hook = edge["to_hook_name"]
    return (
        exp.select(
            exp.Literal.string(from_frame).as_("from_frame"),
            exp.Literal.string(edge["from_concept"]).as_("from_concept"),
            exp.Literal.string(edge["from_keyset"]).as_("from_keyset"),
            exp.Literal.string(from_hook).as_("from_hook_name"),
            exp.column(from_hook).as_("from_hook_value"),

            exp.Literal.string(edge["to_concept"]).as_("to_concept"),
            exp.Literal.string(edge["to_keyset"]).as_("to_keyset"),
            exp.Literal.string(to_hook).as_("to_hook_name"),
            exp.column(to_hook).as_("to_hook_value"),
        )
        .from_(from_frame)
        .where(
            exp.and_(
                exp.Is(this=exp.column(from_hook), expression=exp.Null()).not_(),
                exp.Is(this=exp.column(to_hook), expression=exp.Null()).not_(),
            )
        )
    )


def generate_edge_sql(edges: list[dict]) -> exp.Expression:
    """Combine one SELECT per edge into a single ``UNION ALL`` query.

    Args:
        edges: Edge dicts as produced by ``build_edges_for_frame``.

    Returns:
        The sqlglot expression for the unioned query.

    Raises:
        ValueError: If ``edges`` is empty, since there is no query to build.
    """
    queries = [_edge_query(edge) for edge in edges]
    if not queries:
        raise ValueError("generate_edge_sql() needs at least one edge")
    # Union once, after every SELECT has been built.
    return union_all(queries)

In [None]:
# Make sure the target schema for the graph views exists; `_` discards the result.
create_schema_sql = f"CREATE SCHEMA IF NOT EXISTS {TO_SCHEMA};"
_ = spark.sql(create_schema_sql)

In [None]:
# Turn the scanned nodes and edges into one UNION ALL query each.
node_query = generate_node_sql(nodes=nodes_and_edges["nodes"])
edge_query = generate_edge_sql(edges=nodes_and_edges["edges"])

In [None]:
mlvs = {
    "nodes": {"query": node_query, "partition_by": ["frame"]},
    "edges": {"query": edge_query, "partition_by": ["from_frame"]},
}

# Materialize each generated query as a partitioned lake view in TO_SCHEMA.
# With IF NOT EXISTS, an existing view is left as-is, so re-running this cell
# does not replace a view whose query has since changed.
for mlv, spec in mlvs.items():
    spark_sql = spec["query"].sql(dialect="spark", identify=True, pretty=True)
    partition_by = ", ".join(spec["partition_by"])
    ddl = (
        f"CREATE MATERIALIZED LAKE VIEW IF NOT EXISTS {TO_SCHEMA}.{mlv} "
        f"PARTITIONED BY ({partition_by}) AS ({spark_sql});"
    )
    result = spark.sql(ddl)

    display(result)