# Generate - Blueprints

## Prerequisites

In [None]:
!pip install -q sqlglot==27.20.0

In [None]:
import re
import typing as t

from datetime import datetime, timezone
from functools import reduce
from pyspark.sql import functions as F, Row
from pyspark.sql.dataframe import DataFrame
from sqlglot import exp, parse_one
from tqdm.auto import tqdm

In [None]:
# Table holding the frame manifest; read by load_manifest() via spark.read.table.
MANIFEST_PATH = "metadata.manifest"

# Source layer for the HOOK step. Tables are addressed as
# <RAW_SCHEMA>.<RAW_PREFIX>__<source>__<name> (see construct_qualified_name).
RAW_SCHEMA = "das__raw"
RAW_PREFIX = "raw"

# HOOK layer: written by the HOOK orchestration cell, read by the Graph cell.
HOOK_SCHEMA = "dab__hook"
HOOK_PREFIX = "frame"

# Graph layer: the node step unions everything into <GRAPH_SCHEMA>.nodes.
GRAPH_SCHEMA = "dar__graph"
GRAPH_PREFIX = "graph"

# Presumably the target for the Unified Star Schema blueprints
# (generate_uss_* are still stubs in this notebook) — confirm.
USS_SCHEMA = "dar__uss"
USS_PREFIX = "uss"

## Helper Functions

In [None]:
def construct_qualified_name(
    schema: str,
    prefix: str,
    source: str,
    name: str
) -> str:
    """Return the table name ``<schema>.<prefix>__<source>__<name>``.

    Args:
        schema: Schema (database) that owns the table.
        prefix: Layer prefix, e.g. ``raw`` or ``frame``.
        source: Source system of the frame.
        name: Frame name.

    Returns:
        The schema-qualified table name. Each part is formatted with
        ``format()``, exactly as an f-string would do.
    """
    template = "{schema}.{prefix}__{source}__{name}"
    return template.format(
        schema=schema,
        prefix=prefix,
        source=source,
        name=name,
    )

In [None]:
def load_manifest(table_path: str) -> DataFrame:
    """Read the manifest table and pin it in the Spark cache.

    Args:
        table_path: Name of the table to read with ``spark.read.table``.

    Returns:
        The cached manifest DataFrame.
    """
    manifest_df = spark.read.table(table_path)
    manifest_df = manifest_df.cache()
    # cache() is lazy; running an action here materialises the cache up front.
    manifest_df.count()
    return manifest_df

In [None]:
def extract_active_select(mlv_name: str) -> exp.Expression | None:
    """Return the parsed SELECT that currently backs a materialized lake view.

    Runs ``SHOW CREATE MATERIALIZED LAKE VIEW`` and extracts the text inside the
    trailing ``AS ( ... )`` of the returned statement. That text is parsed with
    sqlglot's Spark dialect.

    Args:
        mlv_name: Fully qualified MLV name.

    Returns:
        The parsed SELECT expression. Returns None when the definition cannot be
        fetched (presumably most often because the view does not exist yet), or
        when no ``AS (...)`` section is found.
    """
    try:
        statement = spark.sql(
            f"SHOW CREATE MATERIALIZED LAKE VIEW {mlv_name};"
        ).collect()[0][0]
    except Exception:
        # Any query failure means "no active definition". Unlike the previous
        # bare ``except:``, this lets KeyboardInterrupt/SystemExit propagate, so
        # interrupting the notebook does not fall through to DDL statements.
        return None

    # Greedy match up to the last ')' at the end of the statement, across lines.
    match = re.search(r"AS\s*\((.*)\)\s*$", statement, flags=re.DOTALL)
    if not match:
        return None

    select_sql = match.group(1).strip()
    return parse_one(select_sql, dialect="spark")

In [None]:
def manage_mlv(
    select_statement: exp.Expression,
    MLV_Identifier: str
) -> DataFrame:
    """Create, recreate or refresh a materialized lake view to match a SELECT.

    The desired SELECT and the currently deployed one are both rendered to
    Spark SQL with the same options, so only semantic differences count as
    changes. Then:

    * same SQL      -> ``REFRESH`` the MLV (change type ``"No Changes"``);
    * different SQL -> rename the existing MLV to
      ``<MLV_Identifier>__<UTC yyyymmdd_HHMMSS>`` as a backup, then ``CREATE``
      it again (``"Recreate MLV"``);
    * no active MLV -> ``CREATE`` it (``"Create new MLV"``).

    Args:
        select_statement: The SELECT the view should be defined by.
        MLV_Identifier: Fully qualified MLV name.

    Returns:
        The DataFrame returned by the REFRESH/CREATE statement, plus one extra
        row ``(metric_name="change_type", metric_value=<change type>)``.
    """

    def _to_spark_sql(expression: exp.Expression) -> str:
        # identify=True quotes all identifiers, so both sides render identically.
        return expression.sql(dialect="spark", identify=True, pretty=True)

    def _with_change_type(result: DataFrame, change_type: str) -> DataFrame:
        marker = spark.createDataFrame(
            [Row(metric_name="change_type", metric_value=change_type)]
        )
        # NOTE(review): the schema returned by the MLV DDL is not visible here.
        # allowMissingColumns lets the union succeed even if that result lacks
        # metric_name/metric_value; missing columns are filled with nulls.
        return result.unionByName(marker, allowMissingColumns=True)

    active_select = extract_active_select(MLV_Identifier)

    spark_sql = _to_spark_sql(select_statement)
    active_spark_sql = (
        _to_spark_sql(active_select) if active_select is not None else None
    )

    if spark_sql == active_spark_sql:
        result = spark.sql(f"REFRESH MATERIALIZED LAKE VIEW {MLV_Identifier};")
        return _with_change_type(result, "No Changes")

    change_type = "Create new MLV"
    if active_select is not None:
        # Keep the old definition as a timestamped backup before recreating.
        utc_ts = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
        _ = spark.sql(
            f"ALTER MATERIALIZED LAKE VIEW {MLV_Identifier} "
            f"RENAME TO {MLV_Identifier}__{utc_ts};"
        )
        change_type = "Recreate MLV"

    result = spark.sql(
        f"CREATE MATERIALIZED LAKE VIEW {MLV_Identifier} AS ({spark_sql});"
    )
    return _with_change_type(result, change_type)

In [None]:
def union_all(queries):
    """Left-fold queries into one ``UNION ALL`` (duplicates are kept).

    ``((q1 UNION ALL q2) UNION ALL q3) ...`` — a single query is returned
    unchanged; an empty iterable raises ``TypeError`` from ``functools.reduce``.
    """

    def _union_all_pair(left, right):
        return left.union(right, distinct=False)

    return reduce(_union_all_pair, queries)

In [None]:
def generate_blueprint_sql(
    manifest: DataFrame,
    sql_function: t.Callable,
    from_schema: str,
    from_prefix: str,
    to_schema: str,
    to_prefix: str,
    single_table: str | None = None,
) -> DataFrame:
    """Generate one blueprint layer of MLVs from the manifest.

    For each manifest row (columns ``name``, ``source``, ``hooks``), calls
    ``sql_function(hooks, from_table)``, where
    ``from_table = <from_schema>.<from_prefix>__<source>__<name>``. Rows for
    which it returns None are skipped.

    Args:
        manifest: Frames to process.
        sql_function: Builds the SELECT expression for one frame, or None.
        from_schema: Schema of the input tables.
        from_prefix: Prefix of the input tables.
        to_schema: Schema of the output MLVs; created if missing.
        to_prefix: Prefix of the per-frame output MLVs.
        single_table: If set, all frame expressions are combined with UNION ALL
            into this one MLV instead of one MLV per frame.

    Returns:
        The per-MLV metrics pivoted by ``combine_results``. When no frame
        produced an expression, returns an empty DataFrame with a single
        ``change_type`` column.
    """
    _ = spark.sql(f"CREATE SCHEMA IF NOT EXISTS {to_schema};")

    statements = []
    for frame in tqdm(manifest.collect(), desc="Generating SQL Statements"):
        frame_name = frame["name"]
        frame_source = frame["source"]
        frame_hooks = frame["hooks"]

        from_table = construct_qualified_name(
            from_schema, from_prefix, frame_source, frame_name
        )
        to_table = construct_qualified_name(
            to_schema, to_prefix, frame_source, frame_name
        )

        expression = sql_function(frame_hooks, from_table)
        if expression is not None:
            statements.append({"table": to_table, "expression": expression})

    if not statements:
        # union_all/combine_results both reduce() without an initial value and
        # would raise an opaque TypeError on an empty list.
        return spark.createDataFrame([], "change_type string")

    results = []
    if single_table:
        # One-item progress bar so the output looks the same as per-table mode.
        for _ in tqdm([None], desc="Creating MLV"):
            union_expr = union_all([s["expression"] for s in statements])
            results.append(
                manage_mlv(select_statement=union_expr, MLV_Identifier=single_table)
            )
    else:
        for s in tqdm(statements, desc="Creating MLVs"):
            results.append(
                manage_mlv(select_statement=s["expression"], MLV_Identifier=s["table"])
            )

    return combine_results(results)

In [None]:
def combine_results(dfs: list) -> DataFrame:
    """Merge per-MLV metric DataFrames into one wide row per input DataFrame.

    Each input is tagged with its position, all inputs are unioned by column
    name, and then ``metric_name`` is pivoted into columns. Each column holds
    the first non-null ``metric_value`` for that input. The position tag is
    dropped before returning.

    Args:
        dfs: DataFrames that each have ``metric_name``/``metric_value`` columns.

    Returns:
        The pivoted DataFrame.
    """
    tagged = [
        metrics_df.withColumn("id", F.lit(position))
        for position, metrics_df in enumerate(dfs)
    ]
    stacked = reduce(lambda left, right: left.unionByName(right), tagged)

    return (
        stacked.groupby("id")
        .pivot("metric_name")
        .agg(F.first("metric_value", ignorenulls=True))
        .drop("id")
    )

## Blueprints

### HOOK

In [None]:
def generate_hook_expression(hook_dict: dict) -> str:
    """Render the SQL for one hook column.

    The result has the shape
    ``CASE WHEN NOT <bk> IS NULL THEN CONCAT('<keyset>|', TRIM(CAST(<bk> AS STRING))) END AS <name>``.

    Args:
        hook_dict: Mapping with keys ``name`` (output column alias), ``keyset``
            (prefix placed before ``|``) and ``business_key_field`` (SQL
            expression for the business key).

    Returns:
        The SQL text produced by ``Expression.sql()`` with sqlglot's default
        dialect. Callers such as ``exp.select`` parse it back into an
        expression.
    """
    name = hook_dict["name"]
    keyset = hook_dict["keyset"]
    business_key_field = hook_dict["business_key_field"]

    # NOTE(review): parsed with sqlglot's default dialect, not "spark". Confirm
    # business_key_field never needs Spark-only syntax (e.g. backtick identifiers).
    column = parse_one(business_key_field)

    key_lit = exp.Literal.string(f"{keyset}|")
    # The column appears twice in the tree (CAST and IS NULL). Copy it so one
    # node instance is not attached to two parents.
    cast = exp.Cast(this=column.copy(), to=exp.DataType.build("STRING"))
    val = exp.Concat(expressions=[key_lit, exp.Trim(this=cast)])

    condition = exp.Is(this=column, expression=exp.Null()).not_()
    return exp.Case().when(condition, val).as_(name).sql()

In [None]:
def generate_hook_sql(
    frame_hooks: list,
    from_table: str
) -> exp.Expression:
    """Build ``SELECT <hook columns>, * FROM <from_table>`` for one frame.

    Args:
        frame_hooks: Hook definitions. Each one is rendered by
            ``generate_hook_expression``.
        from_table: Fully qualified source table.

    Returns:
        The SELECT expression. The hook columns come first, then ``*``.
    """
    select_list = []
    for hook in frame_hooks:
        select_list.append(generate_hook_expression(hook))
    select_list.append(exp.Star())

    return exp.select(*select_list).from_(from_table)

### Graph

In [None]:
def is_unique_key(table, key):
    """Return True if ``key`` has a distinct value in every row of ``table``.

    Compares the total row count with ``countDistinct(key)``. countDistinct
    ignores NULLs, so any NULL in ``key`` makes this False. An empty table
    returns True (0 == 0).
    """
    key_col = F.col(key)
    stats = (
        spark.read.table(table)
        .select(key_col)
        .agg(
            F.count(F.lit(1)).alias("row_count"),
            F.countDistinct(key_col).alias("n_unique"),
        )
        .first()
    )

    return stats["row_count"] == stats["n_unique"]

#### Nodes

In [None]:
def scan_for_nodes(
    table: str,
    hooks: list[dict]
) -> list[dict]:
    """Collect the hooks of ``table`` whose values are unique per row.

    Args:
        table: Fully qualified table to check.
        hooks: Hook definitions with ``name``, ``concept`` and ``keyset``.

    Returns:
        One dict per unique hook, with keys ``frame``, ``concept``, ``keyset``
        and ``hook_name``, in input order.
    """
    found: list[dict] = []
    for hook in hooks:
        # Read every field first, so a malformed hook fails before any query runs.
        hook_name, concept, keyset = hook["name"], hook["concept"], hook["keyset"]

        if not is_unique_key(table, hook_name):
            continue

        found.append({
            "frame": table,
            "concept": concept,
            "keyset": keyset,
            "hook_name": hook_name,
        })

    return found

In [None]:
def generate_graph_node_sql(
    frame_hooks: list,
    from_table: str,
) -> exp.Expression | None:
    """Build the node query for one frame, or None if it has no unique hooks.

    For each hook that ``scan_for_nodes`` reports as unique, selects the
    literals frame/concept/keyset/hook_name, the hook column as ``hook_value``,
    and ``to_json(STRUCT(*))`` as ``attributes``. The per-hook selects are
    combined with UNION ALL.
    """
    nodes = scan_for_nodes(from_table, frame_hooks)
    if not nodes:
        return None

    def _node_query(node: dict) -> exp.Expression:
        frame = node["frame"]
        concept = node["concept"]
        keyset = node["keyset"]
        hook_name = node["hook_name"]

        columns = [
            exp.Literal.string(frame).as_("frame"),
            exp.Literal.string(concept).as_("concept"),
            exp.Literal.string(keyset).as_("keyset"),
            exp.Literal.string(hook_name).as_("hook_name"),
            exp.column(hook_name).as_("hook_value"),
            exp.func("to_json", exp.Struct(expressions=[exp.Star()])).as_("attributes"),
        ]
        return exp.select(*columns).from_(frame)

    return union_all([_node_query(node) for node in nodes])


#### Edges

In [None]:
def generate_graph_edge_sql(
    frame_hooks: list,
    from_table: str
) -> list[dict]:
    """List a candidate edge for every distinct pair of hooks in one frame.

    NOTE(review): despite the ``_sql`` name, this does not build SQL yet. It
    returns plain edge descriptors. Do not pass it to ``generate_blueprint_sql``
    as ``sql_function``, because that function expects a sqlglot expression
    or None.

    Pairs are de-duplicated on each hook's ``(name, concept, keyset)``. The
    smaller key becomes the ``from`` side, so the orientation is deterministic.

    Args:
        frame_hooks: Hook definitions with ``name``, ``concept`` and ``keyset``.
        from_table: Fully qualified table the hooks belong to.

    Returns:
        A list of edge dicts. The list is empty when there are fewer than two hooks.
    """
    from itertools import combinations
    from math import comb

    def _core_fields(hook) -> tuple:
        # Identity of a hook, used for de-duplication and ordering.
        return hook["name"], hook["concept"], hook["keyset"]

    hooks = list(frame_hooks or [])
    n = len(hooks)
    if n < 2:
        return []

    edges: list[dict] = []
    seen_pairs: set[tuple] = set()

    for a, b in tqdm(
        combinations(hooks, 2),
        total=comb(n, 2),
        desc="Scanning for edges...",
        leave=False,
        position=2,
    ):
        ka, kb = _core_fields(a), _core_fields(b)
        pair_key = (ka, kb) if ka <= kb else (kb, ka)
        if pair_key in seen_pairs:
            continue
        seen_pairs.add(pair_key)

        (u_name, u_concept, u_keyset), (v_name, v_concept, v_keyset) = pair_key

        edges.append({
            "from_frame": from_table,
            "from_concept": u_concept,
            "from_keyset": u_keyset,
            "from_hook_name": u_name,
            "to_concept": v_concept,
            "to_keyset": v_keyset,
            "to_hook_name": v_name,
        })

    return edges

### Unified Star Schema

#### Puppini Bridge

In [None]:
def generate_uss_bridge_sql(
    frame_hooks: list,
    from_table: str
) -> exp.Expression:
    """Placeholder for the Unified Star Schema "Puppini Bridge" blueprint.

    Not implemented yet. It always returns None, which ``generate_blueprint_sql``
    treats as "no statement for this frame".
    """
    return None

#### Peripherals

In [None]:
def generate_uss_peripheral_sql(
    frame_hooks: list,
    from_table: str
) -> exp.Expression:
    """Placeholder for the Unified Star Schema "Peripherals" blueprint.

    Not implemented yet. It always returns None, which ``generate_blueprint_sql``
    treats as "no statement for this frame".
    """
    return None

## Orchestrate

In [None]:
# Load and cache the frame manifest once; every blueprint step below reads it.
manifest = load_manifest(table_path=MANIFEST_PATH)

### HOOK

In [None]:
# Only frames flagged generate == True get HOOK views (re)built.
# `== True` is intentional: it builds a Spark Column comparison, so do not
# rewrite it as `is True`.
frames_to_generate = manifest.filter(F.col("generate") == True)

# One HOOK MLV per frame: <RAW_SCHEMA>.<RAW_PREFIX>__<source>__<name>
# feeds <HOOK_SCHEMA>.<HOOK_PREFIX>__<source>__<name>.
hook_results = generate_blueprint_sql(
    manifest=frames_to_generate,
    sql_function=generate_hook_sql,
    from_schema=RAW_SCHEMA,
    from_prefix=RAW_PREFIX,
    to_schema=HOOK_SCHEMA,
    to_prefix=HOOK_PREFIX,
)

display(hook_results)

### Graph

In [None]:
# Scan the HOOK table of every manifest frame for unique hooks, and UNION ALL
# the resulting node rows into the single MLV <GRAPH_SCHEMA>.nodes.
# NOTE(review): this uses the full `manifest`, not `frames_to_generate`, so it
# also reads hook tables of frames with generate != True. Confirm those tables
# are expected to exist.
graph_node_results = generate_blueprint_sql(
    manifest=manifest,
    sql_function=generate_graph_node_sql,
    from_schema=HOOK_SCHEMA,
    from_prefix=HOOK_PREFIX,
    to_schema=GRAPH_SCHEMA,
    to_prefix=GRAPH_PREFIX,
    single_table=f"{GRAPH_SCHEMA}.nodes"
)

display(graph_node_results)

### Unified Star Schema