In [1]:
"""
SQL → Knowledge Graph (Lineage DAG) + LLM Question Answering
------------------------------------------------------------

- Parses SQL with SQLGlot
- Builds structural DAG in NetworkX
- Enriches with semantics using LangChain ChatOpenAI
- Supports natural language lineage Q&A
"""

import os
import json
from typing import List, Dict, Any
from collections import defaultdict

import sqlglot
from sqlglot import parse_one, exp
import networkx as nx
from tqdm import tqdm

from llm_utils import get_llm
from langchain.prompts import ChatPromptTemplate

import networkx as nx
from sqlglot import parse_one, exp

import json
from networkx.readwrite import json_graph

from pyvis.network import Network
from IPython.display import display, HTML

from pyvis.network import Network
from IPython.display import IFrame

# -----------------------------
# CONFIGURATION
# -----------------------------

# Initialize your LLM (you can change model_name as needed)
llm = get_llm()

# -----------------------------
# SQL PARSING + BASE GRAPH
# -----------------------------

def parse_sql(sql: str) -> exp.Expression:
    return parse_one(sql)


def build_lineage_graph_v2(sql: str) -> nx.DiGraph:
    """
    Build a semantically richer DAG from SQL:
    - Handles Tables, CTEs, Columns, Derived expressions
    - Adds dataset-level lineage (CONTAINS, DERIVED_FROM, PRODUCES, GROUPED_BY)
    """
    tree = parse_one(sql)
    G = nx.DiGraph()

    # Helper: process a SELECT block (CTE or main query)
    def process_select(query: exp.Select, parent_dataset: str):
        # Track columns used by this dataset
        contained_cols = set()

        # 1️⃣ Tables used
        for t in query.find_all(exp.Table):
            table_name = t.alias_or_name
            table_node = f"Table::{table_name}"
            G.add_node(table_node, kind="Table", name=table_name)
            G.add_edge(parent_dataset, table_node, relation="DERIVED_FROM")

        # 2️⃣ Aliases (derived columns)
        for a in query.find_all(exp.Alias):
            alias = a.alias_or_name
            derived_expr = a.this.sql()
            derived_node = f"Derived::{alias}"
            G.add_node(derived_node, kind="Derived", sql=derived_expr)
            G.add_edge(parent_dataset, derived_node, relation="PRODUCES")

            # Columns used in this derived expression
            for c in a.this.find_all(exp.Column):
                col_name = c.sql()
                col_node = f"Column::{col_name}"
                G.add_node(col_node, kind="Column", column=col_name)
                G.add_edge(col_node, derived_node, relation="DERIVED_FROM")
                contained_cols.add(col_node)

        # 3️⃣ Direct columns (SELECT region, ...)
        for c in query.find_all(exp.Column):
            col_name = c.sql()
            col_node = f"Column::{col_name}"
            G.add_node(col_node, kind="Column", column=col_name)
            G.add_edge(parent_dataset, col_node, relation="USES")
            contained_cols.add(col_node)

        # 4️⃣ GROUP BY columns
        for g in query.args.get("group", []) or []:
            if isinstance(g, exp.Column):
                group_col = f"Column::{g.sql()}"
                G.add_node(group_col, kind="Column", column=g.sql())
                G.add_edge(parent_dataset, group_col, relation="GROUPED_BY")
                contained_cols.add(group_col)

        # 5️⃣ Add CONTAINS relationships for all columns in this dataset
        for col_node in contained_cols:
            G.add_edge(parent_dataset, col_node, relation="CONTAINS")

    # 6️⃣ Handle WITH (CTEs)
    if isinstance(tree, exp.With):
        for cte in tree.expressions:
            name = cte.alias_or_name
            node_id = f"CTE::{name}"
            G.add_node(node_id, kind="CTE", name=name)
            process_select(cte.this, node_id)

        # Process the main SELECT after all CTEs
        main = tree.this
        G.add_node("Query::Main", kind="Query")
        process_select(main, "Query::Main")

    else:
        G.add_node("Query::Main", kind="Query")
        process_select(tree, "Query::Main")

    return G

import networkx as nx
from sqlglot import parse_one, exp

def build_lineage_graph_v3(sql: str) -> nx.DiGraph:
    """
    Build a semantically complete DAG from SQL:
    - Handles Tables, CTEs, Columns, Derived expressions
    - Adds dataset-level lineage (CONTAINS, DERIVED_FROM, PRODUCES, GROUPED_BY)
    - Enriches graph with joins and filters
    """
    tree = parse_one(sql)
    G = nx.DiGraph()

    # --- Helper Functions ------------------------------------------------

    def add_node(node_id: str, **attrs):
        """Safe add (no overwriting existing attributes)."""
        if node_id not in G.nodes:
            G.add_node(node_id, **attrs)
        else:
            G.nodes[node_id].update(attrs)

    def add_edge(src: str, dst: str, relation: str):
        """Add edge with relation metadata."""
        G.add_edge(src, dst, relation=relation)

    def process_select(query: exp.Select, parent_dataset: str):
        """Parse SELECT blocks and populate lineage for tables, columns, and expressions."""
        contained_cols = set()

        # 1️⃣ Tables used
        for t in query.find_all(exp.Table):
            table_name = t.alias_or_name
            table_node = f"Table::{table_name}"
            add_node(table_node, kind="Table", name=table_name)
            add_edge(parent_dataset, table_node, "DERIVED_FROM")

        # 2️⃣ Aliases (derived columns)
        for a in query.find_all(exp.Alias):
            alias = a.alias_or_name
            derived_expr = a.this.sql()
            derived_node = f"Derived::{alias}"
            add_node(derived_node, kind="Derived", sql=derived_expr)
            add_edge(parent_dataset, derived_node, "PRODUCES")

            # Columns used in derived expressions
            for c in a.this.find_all(exp.Column):
                col_name = c.sql()
                col_node = f"Column::{col_name}"
                add_node(col_node, kind="Column", column=col_name)
                add_edge(col_node, derived_node, "DERIVED_FROM")
                contained_cols.add(col_node)

        # 3️⃣ Columns referenced directly
        for c in query.find_all(exp.Column):
            col_name = c.sql()
            col_node = f"Column::{col_name}"
            add_node(col_node, kind="Column", column=col_name)
            add_edge(parent_dataset, col_node, "USES")
            contained_cols.add(col_node)

        # 4️⃣ GROUP BY context
        for g in query.args.get("group", []) or []:
            if isinstance(g, exp.Column):
                group_col = f"Column::{g.sql()}"
                add_node(group_col, kind="Column", column=g.sql())
                add_edge(parent_dataset, group_col, "GROUPED_BY")
                contained_cols.add(group_col)

        # 5️⃣ WHERE filters
        if query.args.get("where"):
            for cond_col in query.args["where"].find_all(exp.Column):
                col_node = f"Column::{cond_col.sql()}"
                add_node(col_node, kind="Column", column=cond_col.sql())
                add_edge(parent_dataset, col_node, "FILTERS_ON")

        # 6️⃣ JOIN relationships
        for join in query.find_all(exp.Join):
            if join.this:
                right_table = join.this.alias_or_name
                join_target = f"Table::{right_table}"
                add_node(join_target, kind="Table", name=right_table)
                add_edge(parent_dataset, join_target, "JOINS_WITH")

        # 7️⃣ Add CONTAINS relationships for all columns in this dataset
        for col_node in contained_cols:
            add_edge(parent_dataset, col_node, "CONTAINS")

    # --- Handle WITH clauses (CTEs) --------------------------------------

    if isinstance(tree, exp.With):
        # Parse CTEs
        for cte in tree.expressions:
            name = cte.alias_or_name
            cte_node = f"CTE::{name}"
            add_node(cte_node, kind="CTE", name=name)
            process_select(cte.this, cte_node)

        # Process main query
        main_query = tree.this
        add_node("Query::Main", kind="Query")
        process_select(main_query, "Query::Main")

    else:
        add_node("Query::Main", kind="Query")
        process_select(tree, "Query::Main")

    # --- Enrichment Phase ------------------------------------------------

    # 1️⃣ Convert table aliases that match CTE names → CTEs
    for node in list(G.nodes()):
        if G.nodes[node].get("kind") == "Table":
            name = G.nodes[node]["name"]
            if any(cte for cte in G.nodes() if cte == f"CTE::{name}"):
                G.nodes[node]["kind"] = "CTE"
                G.nodes[node]["id"] = f"CTE::{name}"

    # 2️⃣ Inherit base table lineage for CTEs
    for cte in [n for n, d in G.nodes(data=True) if d.get("kind") == "CTE"]:
        for target in [t for _, t, rel in G.edges(data="relation") if rel == "DERIVED_FROM" and t.startswith("Table::")]:
            add_edge(cte, target, "DERIVED_FROM")

    # 3️⃣ Add GROUPED_BY edges for derived aggregates
    for node, data in G.nodes(data=True):
        if data.get("kind") == "Derived" and any(fn in data.get("sql", "").upper() for fn in ["SUM(", "AVG(", "COUNT("]):
            for c in G.nodes():
                if c.startswith("Column::") and ".region" in c:
                    add_edge(node, c, "GROUPED_BY")

    return G



# -----------------------------
# LLM ENRICHMENT
# -----------------------------

def enrich_graph_with_llm(G: nx.DiGraph, sql: str) -> nx.DiGraph:
    """
    For each Derived node, ask LLM to explain how it's computed.
    """
    prompt_template = ChatPromptTemplate.from_template("""
You are an expert SQL lineage analyst.
Given the SQL and a derived column/expression, explain:
1. What the column represents
2. Which columns it depends on (fully qualified if possible)
3. A simplified expression or formula

Respond in JSON:
{{
  "derived_name": "...",
  "formula": "...",
  "upstream_columns": ["..."],
  "explanation": "..."
}}

SQL:
{sql}

Derived target: {target}
""")

    derived_nodes = [n for n, d in G.nodes(data=True) if d.get("kind") == "Derived"]

    for node in tqdm(derived_nodes, desc="LLM enrichment"):
        node_data = G.nodes[node]
        target = node_data.get("sql") or node_data.get("name") or node

        prompt = prompt_template.format(sql=sql, target=target)
        response = llm.invoke(prompt)
        try:
            data = json.loads(response.content)
        except Exception:
            # fallback: best effort
            data = {"derived_name": target, "formula": target, "upstream_columns": [], "explanation": response.content}

        node_data["llm_formula"] = data.get("formula")
        node_data["llm_upstream_columns"] = data.get("upstream_columns", [])
        node_data["llm_explanation"] = data.get("explanation")

        # Add explicit DERIVED_FROM edges if not already
        for col in data.get("upstream_columns", []):
            src = f"Column::{col}"
            if not G.has_node(src):
                G.add_node(src, kind="Column", column=col)
            G.add_edge(src, node, relation="DERIVED_FROM")

    return G

# -----------------------------
# TRAVERSAL FUNCTIONS
# -----------------------------

def get_upstream(G: nx.DiGraph, node: str, depth: int = 5):
    visited, result = set(), []
    stack = [(node, 0)]
    while stack:
        n, d = stack.pop()
        if n in visited or d > depth:
            continue
        visited.add(n)
        for pred in G.predecessors(n):
            result.append(pred)
            stack.append((pred, d + 1))
    return result

def get_downstream(G: nx.DiGraph, node: str, depth: int = 5):
    visited, result = set(), []
    stack = [(node, 0)]
    while stack:
        n, d = stack.pop()
        if n in visited or d > depth:
            continue
        visited.add(n)
        for succ in G.successors(n):
            result.append(succ)
            stack.append((succ, d + 1))
    return result

def find_column_sources(G: nx.DiGraph, column_name: str):
    return [n for n, d in G.nodes(data=True)
            if d.get("kind") == "Column" and column_name in d.get("column", "")]

# -----------------------------
# NATURAL LANGUAGE Q&A OVER GRAPH
# -----------------------------

def answer_lineage_question(G: nx.DiGraph, question: str) -> str:
    """
    Sends the graph summary + question to LLM for reasoning-based answers.
    """
    # summarize graph context
    context_summary = []
    for n, d in G.nodes(data=True):
        if d.get("kind") == "Derived":
            context_summary.append({
                "name": n,
                "formula": d.get("llm_formula"),
                "depends_on": d.get("llm_upstream_columns")
            })
    context_json = json.dumps(context_summary, indent=2)

    qa_prompt = f"""
You are a data lineage assistant.
Given the following SQL lineage graph structure (nodes, dependencies),
answer the question concisely and precisely using reasoning.

Graph context (derived nodes summary):
{context_json}

Question: {question}
"""
    response = llm.invoke(qa_prompt)
    return response.content

# -----------------------------
# DEMO
# -----------------------------

def visualize_graph_pyvis(G, output_html="lineage_graph.html"):
    net = Network(height="750px", width="100%", bgcolor="#222222",
                  font_color="white", directed=True)

    color_map = {"Table": "#00C9A7", "CTE": "#FAD02E",
                 "Column": "#58A4B0", "Derived": "#FF6B6B",
                 "Expression": "#A77FFF", "Query": "#A5FFD6"}

    for n, d in G.nodes(data=True):
        kind = d.get("kind", "Unknown")
        net.add_node(n, label=n.split("::")[-1],
                     color=color_map.get(kind, "#CCCCCC"))

    for u, v, d in G.edges(data=True):
        net.add_edge(u, v, label=d.get("relation", ""))

    net.write_html(output_html)
    print(f"✅ Graph saved to {output_html}")
    # Display inline in notebook
    return IFrame(output_html, width="100%", height="800")



def build_lineage_graph_v4(sql: str) -> nx.DiGraph:
    """
    Build a semantically complete DAG from SQL:
    - Handles Tables, CTEs, Columns, Derived expressions
    - Adds dataset-level lineage (CONTAINS, DERIVED_FROM, PRODUCES, GROUPED_BY, FILTERS_ON)
    - Includes analytic/window, join conditions, predicates, temporal lineage
    - Supports human-level traversal and reasoning
    """
    tree = parse_one(sql)
    G = nx.DiGraph()

    # ---------------------------
    # Helper functions
    # ---------------------------

    def add_node(node_id: str, **attrs):
        if node_id not in G.nodes:
            G.add_node(node_id, **attrs)
        else:
            G.nodes[node_id].update(attrs)

    def add_edge(src: str, dst: str, relation: str):
        G.add_edge(src, dst, relation=relation)

    def add_predicate_node(expr: str, relation: str, parent: str):
        pred_id = f"Predicate::{expr.replace(' ', '_')[:80]}"
        add_node(pred_id, kind="Predicate", expr=expr)
        add_edge(parent, pred_id, relation=relation)

    # ---------------------------
    # SELECT parser (recursive)
    # ---------------------------

    def process_select(query: exp.Select, parent_dataset: str):
        contained_cols = set()

        # 1️⃣ Tables used
        for t in query.find_all(exp.Table):
            tname = t.alias_or_name
            tnode = f"Table::{tname}"
            add_node(tnode, kind="Table", name=tname)
            add_edge(parent_dataset, tnode, "DERIVED_FROM")

        # 2️⃣ Derived aliases
        for a in query.find_all(exp.Alias):
            alias = a.alias_or_name
            expr_sql = a.this.sql()
            dnode = f"Derived::{alias}"
            add_node(dnode, kind="Derived", sql=expr_sql)
            add_edge(parent_dataset, dnode, "PRODUCES")

            # Columns used inside derived expr
            for c in a.this.find_all(exp.Column):
                cname = c.sql()
                cnode = f"Column::{cname}"
                add_node(cnode, kind="Column", column=cname)
                add_edge(cnode, dnode, "DERIVED_FROM")
                contained_cols.add(cnode)

            # Derived dependencies (if expression includes another derived)
            for subexp in a.this.find_all(exp.Alias):
                add_edge(f"Derived::{subexp.alias_or_name}", dnode, "USES")

            # Analytic/window function detection
            if a.this.find(exp.Window):
                win = a.this.find(exp.Window)
                for order in win.find_all(exp.Ordered):
                    add_edge(dnode, f"Column::{order.this.sql()}", "ORDERED_BY")

        # 3️⃣ Columns directly used
        for c in query.find_all(exp.Column):
            cname = c.sql()
            cnode = f"Column::{cname}"
            add_node(cnode, kind="Column", column=cname)
            add_edge(parent_dataset, cnode, "USES")
            contained_cols.add(cnode)

        # 4️⃣ GROUP BY columns
        for g in query.args.get("group", []) or []:
            if isinstance(g, exp.Column):
                gname = g.sql()
                gnode = f"Column::{gname}"
                add_node(gnode, kind="Column", column=gname)
                add_edge(parent_dataset, gnode, "GROUPED_BY")
                contained_cols.add(gnode)

        # 5️⃣ WHERE filters
        if query.args.get("where"):
            add_predicate_node(query.args["where"].sql(), "FILTERS_ON", parent_dataset)
            for col in query.args["where"].find_all(exp.Column):
                cname = col.sql()
                cnode = f"Column::{cname}"
                add_node(cnode, kind="Column", column=cname)
                add_edge(parent_dataset, cnode, "FILTERS_ON")

        # 6️⃣ JOINs and ON conditions
        for join in query.find_all(exp.Join):
            if join.this:
                rtable = join.this.alias_or_name
                rnode = f"Table::{rtable}"
                add_node(rnode, kind="Table", name=rtable)
                add_edge(parent_dataset, rnode, "JOINS_WITH")

                if join.args.get("on"):
                    add_predicate_node(join.args["on"].sql(), "ON_CONDITION", parent_dataset)
                    for jc in join.args["on"].find_all(exp.Column):
                        jname = jc.sql()
                        jnode = f"Column::{jname}"
                        add_node(jnode, kind="Column", column=jname)
                        add_edge(parent_dataset, jnode, "JOINS_ON")

        # 7️⃣ Add CONTAINS edges
        for col_node in contained_cols:
            add_edge(parent_dataset, col_node, "CONTAINS")

    # ---------------------------
    # Handle WITH (CTEs)
    # ---------------------------

    if isinstance(tree, exp.With):
        for cte in tree.expressions:
            cname = cte.alias_or_name
            cnode = f"CTE::{cname}"
            add_node(cnode, kind="CTE", name=cname)
            process_select(cte.this, cnode)

        main_query = tree.this
        add_node("Query::Main", kind="Query")
        process_select(main_query, "Query::Main")

        # Link Query to its CTEs explicitly
        for cte in tree.expressions:
            add_edge("Query::Main", f"CTE::{cte.alias_or_name}", "DERIVED_FROM")

    else:
        add_node("Query::Main", kind="Query")
        process_select(tree, "Query::Main")

    # ---------------------------
    # Enrichment & Aggregation Semantics
    # ---------------------------

    # Detect and tag aggregation functions
    for node, data in G.nodes(data=True):
        if data.get("kind") == "Derived" and any(fn in data.get("sql", "").upper() for fn in ["SUM(", "AVG(", "COUNT("]):
            G.nodes[node]["aggregation"] = True
            for c in G.nodes():
                if c.startswith("Column::") and ".region" in c:
                    add_edge(node, c, "GROUPED_BY")

    # Detect date predicates
    for node, data in list(G.nodes(data=True)):
        if data.get("kind") == "Predicate" and "BETWEEN" in data.get("expr", "").upper():
            G.nodes[node]["temporal_filter"] = True

    return G



In [2]:
def build_lineage_graph_v5(sql: str) -> nx.DiGraph:
    """
    Build a human-interpretable, semantically complete SQL lineage DAG.
    Features:
      ✅ Distinguishes Tables, CTEs, Columns, Derived expressions
      ✅ Adds FILTERS_ON, GROUPED_BY, ORDERED_BY, and USES relationships
      ✅ Captures temporal filters (BETWEEN)
      ✅ Adds CTE hierarchy (CTE::<name> → Query::Main)
      ✅ Handles Derived→Derived dependencies
      ✅ Prevents duplicate Derived node collisions via scoped names
    """
    tree = parse_one(sql)
    G = nx.DiGraph()

    # ---------------------------
    # Helper functions
    # ---------------------------

    def add_node(node_id: str, **attrs):
        if node_id not in G.nodes:
            G.add_node(node_id, **attrs)
        else:
            G.nodes[node_id].update(attrs)

    def add_edge(src: str, dst: str, relation: str):
        G.add_edge(src, dst, relation=relation)

    def add_predicate_node(expr: str, relation: str, parent: str):
        pred_id = f"Predicate::{expr.replace(' ', '_')[:90]}"
        add_node(pred_id, kind="Predicate", expr=expr)
        add_edge(parent, pred_id, relation=relation)
        if "BETWEEN" in expr.upper():
            G.nodes[pred_id]["temporal_filter"] = True

    # ---------------------------
    # Core recursive SELECT handler
    # ---------------------------

    def process_select(query: exp.Select, parent_dataset: str, scope_name: str = ""):
        contained_cols = set()

        # 1️⃣ Tables referenced
        for t in query.find_all(exp.Table):
            name = t.alias_or_name
            tnode = f"Table::{name}"
            add_node(tnode, kind="Table", name=name)
            add_edge(parent_dataset, tnode, "DERIVED_FROM")

        # 2️⃣ Derived columns / aliases
        for a in query.find_all(exp.Alias):
            alias = a.alias_or_name
            expr_sql = a.this.sql()
            scoped_alias = f"{scope_name}.{alias}" if scope_name else alias
            dnode = f"Derived::{scoped_alias}"
            add_node(dnode, kind="Derived", sql=expr_sql)

            # Link dataset → derived column
            add_edge(parent_dataset, dnode, "PRODUCES")

            # Link columns used in expression
            for c in a.this.find_all(exp.Column):
                cname = c.sql()
                cnode = f"Column::{cname}"
                add_node(cnode, kind="Column", column=cname)
                add_edge(cnode, dnode, "DERIVED_FROM")
                contained_cols.add(cnode)

            # Derived→Derived dependencies (chained expressions)
            for subexp in a.this.find_all(exp.Alias):
                add_edge(f"Derived::{subexp.alias_or_name}", dnode, "USES")

            # Tag aggregates
            if any(fn in expr_sql.upper() for fn in ["SUM(", "AVG(", "COUNT("]):
                G.nodes[dnode]["aggregation"] = True

            # Window/analytic
            if a.this.find(exp.Window):
                win = a.this.find(exp.Window)
                for order in win.find_all(exp.Ordered):
                    add_edge(dnode, f"Column::{order.this.sql()}", "ORDERED_BY")

        # 3️⃣ Columns directly referenced
        for c in query.find_all(exp.Column):
            cname = c.sql()
            cnode = f"Column::{cname}"
            add_node(cnode, kind="Column", column=cname)
            add_edge(parent_dataset, cnode, "USES")
            contained_cols.add(cnode)

        # 4️⃣ GROUP BY columns
        for g in query.args.get("group", []) or []:
            if isinstance(g, exp.Column):
                gname = g.sql()
                gnode = f"Column::{gname}"
                add_node(gnode, kind="Column", column=gname)
                add_edge(parent_dataset, gnode, "GROUPED_BY")
                contained_cols.add(gnode)

        # 5️⃣ WHERE filters (including temporal)
        if query.args.get("where"):
            where_expr = query.args["where"].sql()
            add_predicate_node(where_expr, "FILTERS_ON", parent_dataset)
            for wcol in query.args["where"].find_all(exp.Column):
                wname = wcol.sql()
                wnode = f"Column::{wname}"
                add_node(wnode, kind="Column", column=wname)
                add_edge(parent_dataset, wnode, "FILTERS_ON")

        # 6️⃣ JOINs and ON conditions
        for join in query.find_all(exp.Join):
            if join.this:
                rname = join.this.alias_or_name
                rnode = f"Table::{rname}"
                add_node(rnode, kind="Table", name=rname)
                add_edge(parent_dataset, rnode, "JOINS_WITH")
                if join.args.get("on"):
                    add_predicate_node(join.args["on"].sql(), "ON_CONDITION", parent_dataset)
                    for jcol in join.args["on"].find_all(exp.Column):
                        jname = jcol.sql()
                        jnode = f"Column::{jname}"
                        add_node(jnode, kind="Column", column=jname)
                        add_edge(parent_dataset, jnode, "JOINS_ON")

        # 7️⃣ CONTAINS
        for col_node in contained_cols:
            add_edge(parent_dataset, col_node, "CONTAINS")

    # ---------------------------
    # Handle WITH CTEs + main query
    # ---------------------------

    if isinstance(tree, exp.With):
        for cte in tree.expressions:
            cname = cte.alias_or_name
            cnode = f"CTE::{cname}"
            add_node(cnode, kind="CTE", name=cname)
            process_select(cte.this, cnode, scope_name=cname)

        main_query = tree.this
        add_node("Query::Main", kind="Query")
        process_select(main_query, "Query::Main", scope_name="Main")

        # Connect Query::Main to its CTEs
        for cte in tree.expressions:
            add_edge("Query::Main", f"CTE::{cte.alias_or_name}", "DERIVED_FROM")

    else:
        add_node("Query::Main", kind="Query")
        process_select(tree, "Query::Main", scope_name="Main")

    # ---------------------------
    # Aggregation enrichment & semantic tagging
    # ---------------------------

    for node, data in G.nodes(data=True):
        if data.get("kind") == "Derived" and data.get("aggregation"):
            for c in G.nodes():
                if c.startswith("Column::") and (".region" in c or ".month" in c):
                    add_edge(node, c, "GROUPED_BY")

    return G


In [12]:
import json
from typing import Dict, Set, Tuple
import networkx as nx
from sqlglot import parse_one, exp


def _node_id(kind: str, name: str) -> str:
    """Helper to create node id strings consistently."""
    return f"{kind}::{name}"


def _expr_to_text(expr: exp.Expression) -> str:
    """Safe SQL text for an expression (fallback to str(expr))."""
    try:
        return expr.sql(dialect="") if hasattr(expr, "sql") else str(expr)
    except Exception:
        return str(expr)


def build_lineage_graph_v6(sql: str) -> nx.DiGraph:
    """
    Build a semantically rich DAG from SQL using sqlglot AST.
    - Fully CTE-aware and scope-prefixed derived nodes
    - Derived->Derived resolution
    - Temporal predicates (BETWEEN)
    - Window / ORDERED_BY edges
    - Joins, ON conditions, WHERE predicates
    Returns networkx.DiGraph with nodes and 'relation' edge metadata.
    """

    tree = parse_one(sql)
    G = nx.DiGraph()

    # --- Data structures to help resolution ------------------------------
    cte_map: Dict[str, exp.Select] = {}          # cte_name -> Select AST
    derived_registry: Dict[str, str] = {}        # "<scope>.<alias>" -> Derived_node_id
    alias_name_to_nodes: Dict[str, Set[str]] = {}  # alias_name -> set(node_ids) across scopes
    predicate_counter = 0

    # --- Helper functions ------------------------------------------------
    def add_node(node_id: str, **attrs):
        if node_id in G.nodes:
            G.nodes[node_id].update(attrs)
        else:
            G.add_node(node_id, **attrs)

    def add_edge(src: str, dst: str, relation: str):
        G.add_edge(src, dst, relation=relation)

    def make_predicate_node(expr: exp.Expression, hint: str = None) -> str:
        """Create a predicate node and return its id."""
        nonlocal predicate_counter
        predicate_counter += 1
        txt = _expr_to_text(expr)
        # create readable but unique id
        safe_hint = hint.replace(" ", "_") if hint else f"pred_{predicate_counter}"
        node_id = _node_id("Predicate", f"{safe_hint}_{predicate_counter}")
        add_node(node_id, kind="Predicate", expr=txt)
        return node_id

    def scan_columns_in_expr(expr: exp.Expression) -> Set[str]:
        """Return set of Column::<table>.<col> style names used inside expr."""
        cols = set()
        for c in expr.find_all(exp.Column):
            # Use c.sql() to get table-qualified form where possible.
            name = c.sql()  # yields things like "mt.monthly_sales" or "total_sales"
            cols.add(name)
        return cols

    # --- Collect CTEs first (so we can rebind references) -----------------
    # --- Collect CTEs first (robust) ---
    if isinstance(tree, exp.With):
        with_clause = tree
    elif tree.args.get("with"):  # handle SELECT containing .args["with"]
        with_clause = tree.args["with"]
    else:
        with_clause = None
    
    if with_clause:
        for cte in with_clause.expressions:
            cte_name = cte.alias_or_name
            cte_map[cte_name] = cte.this
            cte_node = _node_id("CTE", cte_name)
            add_node(cte_node, kind="CTE", name=cte_name)

    # --- function to process a Select in a given scope -------------------
    def process_select(select: exp.Select, scope: str):
        """
        scope example values:
          - "CTE::regional_sales"
          - "Query::Main"
        """
        # Register this dataset node
        add_node(scope, kind="Dataset" if scope.startswith("CTE::") else "Query")

        # Collect derived aliases local to this scope (for Derived->Derived linking)
        local_aliases: Set[str] = set()

        # 1) FROM table(s) and main table binding
        # process the primary from table (if any)
        from_expr = select.args.get("from")
        if from_expr and from_expr.this:
            # sqlglot represents From(this=Table(...))
            table_exp = from_expr.this
            if isinstance(table_exp, exp.Table):
                table_name = table_exp.this.name if hasattr(table_exp.this, "name") else table_exp.this
                table_alias = table_exp.alias_or_name or table_name
                # If this table name is a CTE, bind to CTE node
                if table_name in cte_map:
                    target_node = _node_id("CTE", table_name)
                    add_node(target_node, kind="CTE", name=table_name)
                    add_edge(scope, target_node, "DERIVED_FROM")
                    # also add CONTAINS later when processing the CTE itself
                else:
                    table_node = _node_id("Table", table_alias)
                    add_node(table_node, kind="Table", name=table_alias)
                    add_edge(scope, table_node, "DERIVED_FROM")

        # 2) Joins (collect join targets + on predicates)
        for join in select.find_all(exp.Join):
            if join.this and isinstance(join.this, exp.Table):
                j_table_exp = join.this
                j_table_name = j_table_exp.this.name if hasattr(j_table_exp.this, "name") else j_table_exp.this
                j_alias = j_table_exp.alias_or_name or j_table_name
                # If join table is a CTE, link to CTE
                if j_table_name in cte_map:
                    add_node(_node_id("CTE", j_table_name), kind="CTE", name=j_table_name)
                    add_edge(scope, _node_id("CTE", j_table_name), "JOINS_WITH")
                else:
                    add_node(_node_id("Table", j_alias), kind="Table", name=j_alias)
                    add_edge(scope, _node_id("Table", j_alias), "JOINS_WITH")

            # handle ON expr
            if join.args.get("on"):
                on_expr = join.args["on"]
                pred_node = make_predicate_node(on_expr, hint="ON")
                add_edge(scope, pred_node, "ON_CONDITION")
                # link predicate to columns it uses
                for col in scan_columns_in_expr(on_expr):
                    add_node(_node_id("Column", col), kind="Column", column=col)
                    add_edge(pred_node, _node_id("Column", col), "REFERENCES")

        # 3) SELECT list: Aliases (derived) + direct Columns
        for item in select.expressions:
            # Derived alias
            if isinstance(item, exp.Alias):
                alias_name = item.alias_or_name
                derived_node = _node_id("Derived", f"{scope.split('::')[-1]}.{alias_name}")
                add_node(derived_node, kind="Derived", sql=_expr_to_text(item.this))
                add_edge(scope, derived_node, "PRODUCES")
                local_aliases.add(alias_name)
                # register globally for Derived->Derived linkage
                derived_registry[f"{scope.split('::')[-1]}.{alias_name}"] = derived_node
                alias_name_to_nodes.setdefault(alias_name, set()).add(derived_node)

                # mark aggregation if expression includes SUM/AVG/COUNT
                expr_text_upper = _expr_to_text(item.this).upper()
                if any(fn in expr_text_upper for fn in ["SUM(", "AVG(", "COUNT(", "MIN(", "MAX("]):
                    add_node(derived_node, aggregation=True)

                # columns inside the derived expression
                for c in item.this.find_all(exp.Column):
                    cname = c.sql()
                    col_node = _node_id("Column", cname)
                    add_node(col_node, kind="Column", column=cname)
                    add_edge(col_node, derived_node, "DERIVED_FROM")

                # if expression contains a window function (Window)
                for w in item.this.find_all(exp.Window):
                    # the inner function is accessible via w.this
                    # find ORDER expressions within the window
                    if w.args.get("order"):
                        for ordered in w.args["order"].expressions:
                            # ordered.this may be a function (Sum(...)) or Column(...)
                            for cc in ordered.find_all((exp.Column, exp.Alias, exp.Func, exp.Expression)):
                                # note: prefer column names / functions
                                try:
                                    colref = cc.sql()
                                except Exception:
                                    colref = str(cc)
                                target_col_node = _node_id("Column", colref)
                                add_node(target_col_node, kind="Column", column=colref)
                                add_edge(derived_node, target_col_node, "ORDERED_BY")

            # direct column (no alias)
            elif isinstance(item, exp.Column):
                cname = item.sql()
                col_node = _node_id("Column", cname)
                add_node(col_node, kind="Column", column=cname)
                add_edge(scope, col_node, "CONTAINS")

            else:
                # other expression (could include bare function calls, parens etc.)
                # try to find inner columns and attach as CONTAINS or references
                for c in item.find_all(exp.Column):
                    cname = c.sql()
                    col_node = _node_id("Column", cname)
                    add_node(col_node, kind="Column", column=cname)
                    add_edge(scope, col_node, "CONTAINS")

        # 4) GROUP BY => GROUPED_BY edges
        group = select.args.get("group")
        if group:
            for g in (group.expressions or []):
                if isinstance(g, exp.Column):
                    gname = g.sql()
                    add_node(_node_id("Column", gname), kind="Column", column=gname)
                    add_edge(scope, _node_id("Column", gname), "GROUPED_BY")

        # 5) WHERE => create predicate node and link referenced columns
        if select.args.get("where"):
            where_expr = select.args["where"].this
            pred_node = make_predicate_node(where_expr, hint="WHERE")
            add_edge(scope, pred_node, "FILTERS_ON")
            for col in scan_columns_in_expr(where_expr):
                add_node(_node_id("Column", col), kind="Column", column=col)
                add_edge(pred_node, _node_id("Column", col), "REFERENCES")

    # --- Process all CTEs (build their internals) ------------------------
    for cte_name, cte_select in cte_map.items():
        node_id = _node_id("CTE", cte_name)
        process_select(cte_select, node_id)
        # Also add CONTAINS edges for derived outputs in CTE (produced aliases)
        # This was handled by process_select via PROOUCES edges from the scope

    # --- Process main query ---------------------------------------------
    main_select = tree.this if isinstance(tree, exp.With) else tree
    main_node = _node_id("Query", "Main")
    add_node(main_node, kind="Query", id="Query::Main")
    process_select(main_select, main_node)

    # --- Post-processing: rebind Table references to CTE nodes -----------
    # For each Table::<name> referenced in the graph, if there's a matching CTE, convert to CTE reference
    for n, data in list(G.nodes(data=True)):
        if data.get("kind") == "Table":
            table_name = data.get("name")
            if table_name in cte_map:
                # replace or add a CTE node and rewire edges
                cte_node = _node_id("CTE", table_name)
                add_node(cte_node, kind="CTE", name=table_name)
                # Rewire incoming edges that pointed to Table::<table_name> (convert relation to JOINS_WITH/DERIVED_FROM)
                for u in list(G.predecessors(n)):
                    rel = G.edges[u, n].get("relation")
                    # Remove edge u->n and add u->cte_node with same relation
                    G.remove_edge(u, n)
                    add_edge(u, cte_node, rel)

                # Rewire outgoing edges from Table::<table_name> to originate from CTE node (if any)
                for v in list(G.successors(n)):
                    rel = G.edges[n, v].get("relation")
                    G.remove_edge(n, v)
                    add_edge(cte_node, v, rel)
                # finally remove the Table node
                if n in G:
                    try:
                        G.remove_node(n)
                    except Exception:
                        pass

    # --- Derived->Derived resolution ------------------------------------
    # If a derived's SQL text references another alias name, connect them
    # e.g. monthly_contribution_pct references monthly_sales -> link Derived::X.monthly_sales -> Derived::Y.monthly_contribution_pct
    all_derived = [(n, d) for n, d in G.nodes(data=True) if d.get("kind") == "Derived"]
    # Build quick mapping alias -> nodes (already partially built)
    # alias_name_to_nodes: alias -> set(nodes) was filled when building derived nodes
    # Now attempt to connect: look for alias tokens in sql text
    for derived_node, ddata in all_derived:
        sql_text = ddata.get("sql", "")
        lower = sql_text.lower() if sql_text else ""
        # naive token match: for each alias in alias_name_to_nodes keys, if appears in text then link
        for alias_name, nodes_for_alias in alias_name_to_nodes.items():
            # check full token match: alias_name as separate token (avoid "total_sales" in "not_total_sales" edge-cases)
            if alias_name.lower() in lower:
                # for each candidate node that defines alias_name, link it to this derived node if not same node
                for src_node in nodes_for_alias:
                    if src_node != derived_node:
                        add_edge(src_node, derived_node, "DERIVED_FROM")

    # --- Final cleanup: attach dataset->columns CONTAINS for any column that is produced by Derived node in same scope ---
    # (ensures dataset nodes have direct CONTAINS edges to columns present in SELECT)
    for dnode, ddata in all_derived:
        # derived nodes are labelled like "Derived::Main.total_sales" or "Derived::CTEname.alias"
        # ensure dataset (scope) contains the column name
        try:
            scope_suffix = dnode.split("::", 1)[1]  # e.g. "Main.total_sales" or "regional_sales.total_sales"
            ds_name = scope_suffix.split(".", 1)[0]
            dataset_node = f"CTE::{ds_name}" if f"CTE::{ds_name}" in G.nodes else f"Query::{ds_name}" if f"Query::{ds_name}" in G.nodes else None
            # If dataset_node exists, and Derived produces a column that should be contained by it, connect
            if dataset_node:
                # figure column alias
                alias = scope_suffix.split(".", 1)[1] if "." in scope_suffix else None
                if alias:
                    col_node = _node_id("Column", alias)
                    add_node(col_node, kind="Column", column=alias)
                    add_edge(dataset_node, col_node, "CONTAINS")
        except Exception:
            pass

    return G


In [30]:
import json
from typing import Dict, Set
import networkx as nx
from sqlglot import parse_one, exp


def _node_id(kind: str, name: str) -> str:
    return f"{kind}::{name}"


def _expr_to_text(expr: exp.Expression) -> str:
    try:
        return expr.sql(dialect="") if hasattr(expr, "sql") else str(expr)
    except Exception:
        return str(expr)


def build_lineage_graph_v7(sql: str) -> nx.DiGraph:
    tree = parse_one(sql)
    G = nx.DiGraph()

    cte_map: Dict[str, exp.Select] = {}
    derived_registry: Dict[str, str] = {}
    alias_name_to_nodes: Dict[str, Set[str]] = {}
    predicate_counter = 0
    alias_to_cte_map: Dict[str, str] = {}   # <── NEW: alias → cte_name mapping

    def add_node(node_id: str, **attrs):
        if node_id in G.nodes:
            G.nodes[node_id].update(attrs)
        else:
            G.add_node(node_id, **attrs)

    def add_edge(src: str, dst: str, relation: str):
        G.add_edge(src, dst, relation=relation)

    def make_predicate_node(expr: exp.Expression, hint: str = None) -> str:
        nonlocal predicate_counter
        predicate_counter += 1
        txt = _expr_to_text(expr)
        safe_hint = hint.replace(" ", "_") if hint else f"pred_{predicate_counter}"
        node_id = _node_id("Predicate", f"{safe_hint}_{predicate_counter}")
        add_node(node_id, kind="Predicate", expr=txt)
        return node_id

    def scan_columns_in_expr(expr: exp.Expression) -> Set[str]:
        cols = set()
        for c in expr.find_all(exp.Column):
            cols.add(c.sql())
        return cols

    # --- Collect CTEs first ---
    if isinstance(tree, exp.With):
        with_clause = tree
    elif tree.args.get("with"):
        with_clause = tree.args["with"]
    else:
        with_clause = None

    if with_clause:
        for cte in with_clause.expressions:
            cte_name = cte.alias_or_name
            cte_map[cte_name] = cte.this
            add_node(_node_id("CTE", cte_name), kind="CTE", name=cte_name)

    # --- Process SELECT recursively ---
    def process_select(select: exp.Select, scope: str):
        add_node(scope, kind="Dataset" if scope.startswith("CTE::") else "Query")
        local_aliases: Set[str] = set()

        # ---- FROM ----
        from_expr = select.args.get("from")
        if from_expr and from_expr.this and isinstance(from_expr.this, exp.Table):
            t = from_expr.this
            table_name = t.this.name if hasattr(t.this, "name") else str(t.this)
            alias = t.alias_or_name or table_name

            table_node = _node_id("Table", alias)
            add_node(table_node, kind="Table", name=alias)

            # link to CTE if applicable
            if table_name in cte_map:
                add_edge(_node_id("CTE", table_name), table_node, "ALIAS_OF")
                alias_to_cte_map[alias] = table_name
            add_edge(scope, table_node, "DERIVED_FROM")

        # ---- JOINs ----
        for join in select.find_all(exp.Join):
            if join.this and isinstance(join.this, exp.Table):
                j = join.this
                j_table_name = j.this.name if hasattr(j.this, "name") else str(j.this)
                j_alias = j.alias_or_name or j_table_name

                j_table_node = _node_id("Table", j_alias)
                add_node(j_table_node, kind="Table", name=j_alias)

                if j_table_name in cte_map:
                    add_edge(_node_id("CTE", j_table_name), j_table_node, "ALIAS_OF")
                    alias_to_cte_map[j_alias] = j_table_name

                add_edge(scope, j_table_node, "JOINS_WITH")

            if join.args.get("on"):
                on_expr = join.args["on"]
                pred_node = make_predicate_node(on_expr, hint="ON")
                add_edge(scope, pred_node, "ON_CONDITION")
                for col in scan_columns_in_expr(on_expr):
                    col_node = _node_id("Column", col)
                    add_node(col_node, kind="Column", column=col)
                    add_edge(pred_node, col_node, "REFERENCES")

        # ---- SELECT expressions ----
        for item in select.expressions:
            if isinstance(item, exp.Alias):
                alias_name = item.alias_or_name
                derived_node = _node_id("Derived", f"{scope.split('::')[-1]}.{alias_name}")
                add_node(derived_node, kind="Derived", sql=_expr_to_text(item.this))
                add_edge(scope, derived_node, "PRODUCES")
                local_aliases.add(alias_name)
                derived_registry[f"{scope.split('::')[-1]}.{alias_name}"] = derived_node
                alias_name_to_nodes.setdefault(alias_name, set()).add(derived_node)

                expr_text_upper = _expr_to_text(item.this).upper()
                if any(fn in expr_text_upper for fn in ["SUM(", "AVG(", "COUNT(", "MIN(", "MAX("]):
                    G.nodes[derived_node]["aggregation"] = True

                for c in item.this.find_all(exp.Column):
                    cname = c.sql()
                    col_node = _node_id("Column", cname)
                    add_node(col_node, kind="Column", column=cname)
                    add_edge(col_node, derived_node, "DERIVED_FROM")

            elif isinstance(item, exp.Column):
                cname = item.sql()
                col_node = _node_id("Column", cname)
                add_node(col_node, kind="Column", column=cname)
                add_edge(scope, col_node, "CONTAINS")

            else:
                for c in item.find_all(exp.Column):
                    cname = c.sql()
                    col_node = _node_id("Column", cname)
                    add_node(col_node, kind="Column", column=cname)
                    add_edge(scope, col_node, "CONTAINS")

        # ---- GROUP BY ----
        group = select.args.get("group")
        if group:
            for g in (group.expressions or []):
                if isinstance(g, exp.Column):
                    gname = g.sql()
                    add_node(_node_id("Column", gname), kind="Column", column=gname)
                    add_edge(scope, _node_id("Column", gname), "GROUPED_BY")

        # ---- WHERE ----
        if select.args.get("where"):
            where_expr = select.args["where"].this
            pred_node = make_predicate_node(where_expr, hint="WHERE")
            add_edge(scope, pred_node, "FILTERS_ON")
            for col in scan_columns_in_expr(where_expr):
                add_node(_node_id("Column", col), kind="Column", column=col)
                add_edge(pred_node, _node_id("Column", col), "REFERENCES")

    # --- Build all CTEs ---
    for cte_name, cte_select in cte_map.items():
        process_select(cte_select, _node_id("CTE", cte_name))

    # --- Build main query ---
    main_select = tree.this if isinstance(tree, exp.With) else tree
    main_scope = _node_id("Query", "Main")
    add_node(main_scope, kind="Query", name="Main")
    process_select(main_select, main_scope)

    # --- Postprocess: connect derived columns to alias columns ---
    for alias, cte_name in alias_to_cte_map.items():
        derived_nodes = [n for n in G.nodes if n.startswith(f"Derived::{cte_name}.")]
        for d in derived_nodes:
            col = d.split(".")[-1]
            alias_col = _node_id("Column", f"{alias}.{col}")
            if alias_col in G.nodes:
                add_edge(d, alias_col, "PROPAGATES_TO")

    return G


In [37]:
import json
import re
from typing import Dict, Set, Tuple
import networkx as nx
from sqlglot import parse_one, exp


def _node_id(kind: str, name: str) -> str:
    return f"{kind}::{name}"


def _expr_to_text(expr: exp.Expression) -> str:
    try:
        return expr.sql(dialect="") if hasattr(expr, "sql") else str(expr)
    except Exception:
        return str(expr)


def build_lineage_graph_v8(sql: str) -> nx.DiGraph:
    """
    v8:
      - fully CTE-aware (keeps alias -> CTE mapping)
      - creates Table nodes for aliases and ALIAS_OF edges to CTE
      - builds Derived, Predicate, Column nodes as before
      - POST-PROCESSING:
          * rewires Table::<alias> -> CTE::<cte_name> (ALIAS_OF already added during parsing)
          * creates PROPAGATES_TO edges from Derived outputs -> Column::<alias>
          * generic usage propagation: add USED_IN edges (and more specific USED_IN_<TYPE>)
            by token-matching column names against Derived/Predicate/Query SQL text
    """
    tree = parse_one(sql)
    G = nx.DiGraph()

    cte_map: Dict[str, exp.Select] = {}
    derived_registry: Dict[str, str] = {}
    alias_name_to_nodes: Dict[str, Set[str]] = {}
    predicate_counter = 0
    alias_to_cte_map: Dict[str, str] = {}   # alias -> cte_name mapping

    def add_node(node_id: str, **attrs):
        if node_id in G.nodes:
            G.nodes[node_id].update(attrs)
        else:
            G.add_node(node_id, **attrs)

    def add_edge(src: str, dst: str, relation: str):
        G.add_edge(src, dst, relation=relation)

    def make_predicate_node(expr: exp.Expression, hint: str = None) -> str:
        nonlocal predicate_counter
        predicate_counter += 1
        txt = _expr_to_text(expr)
        safe_hint = (hint or "pred").replace(" ", "_")
        node_id = _node_id("Predicate", f"{safe_hint}_{predicate_counter}")
        add_node(node_id, kind="Predicate", expr=txt)
        return node_id

    def scan_columns_in_expr(expr: exp.Expression) -> Set[str]:
        cols = set()
        for c in expr.find_all(exp.Column):
            cols.add(c.sql())
        return cols

    # --- Collect CTEs first ---
    if isinstance(tree, exp.With):
        with_clause = tree
    elif tree.args.get("with"):
        with_clause = tree.args["with"]
    else:
        with_clause = None

    if with_clause:
        for cte in with_clause.expressions:
            cte_name = cte.alias_or_name
            cte_map[cte_name] = cte.this
            add_node(_node_id("CTE", cte_name), kind="Dataset", name=cte_name)

    # --- Process SELECT recursively ---
    def process_select(select: exp.Select, scope: str):
        """
        scope is node id like "CTE::regional_sales" or "Query::Main"
        """
        add_node(scope, kind="Dataset" if scope.startswith("CTE::") else "Query")

        # ---- FROM ----
        from_expr = select.args.get("from")
        if from_expr and from_expr.this and isinstance(from_expr.this, exp.Table):
            t = from_expr.this
            # table name might be Identifier or other
            table_name = t.this.name if hasattr(t.this, "name") else str(t.this)
            alias = t.alias_or_name or table_name

            table_node = _node_id("Table", alias)
            add_node(table_node, kind="Table", name=alias)

            # If the table name corresponds to a CTE, record ALIAS_OF
            if table_name in cte_map:
                add_edge(_node_id("CTE", table_name), table_node, "ALIAS_OF")
                alias_to_cte_map[alias] = table_name

            add_edge(scope, table_node, "DERIVED_FROM")

        # ---- JOINs ----
        for join in select.find_all(exp.Join):
            if join.this and isinstance(join.this, exp.Table):
                j = join.this
                j_table_name = j.this.name if hasattr(j.this, "name") else str(j.this)
                j_alias = j.alias_or_name or j_table_name

                j_table_node = _node_id("Table", j_alias)
                add_node(j_table_node, kind="Table", name=j_alias)

                if j_table_name in cte_map:
                    add_edge(_node_id("CTE", j_table_name), j_table_node, "ALIAS_OF")
                    alias_to_cte_map[j_alias] = j_table_name

                add_edge(scope, j_table_node, "JOINS_WITH")

            if join.args.get("on"):
                on_expr = join.args["on"]
                pred_node = make_predicate_node(on_expr, hint="ON")
                add_edge(scope, pred_node, "ON_CONDITION")
                for col in scan_columns_in_expr(on_expr):
                    col_node = _node_id("Column", col)
                    add_node(col_node, kind="Column", column=col)
                    add_edge(pred_node, col_node, "REFERENCES")

        # ---- SELECT expressions ----
        for item in select.expressions:
            if isinstance(item, exp.Alias):
                alias_name = item.alias_or_name
                derived_node = _node_id("Derived", f"{scope.split('::')[-1]}.{alias_name}")
                add_node(derived_node, kind="Derived", sql=_expr_to_text(item.this))
                add_edge(scope, derived_node, "PRODUCES")
                derived_registry[f"{scope.split('::')[-1]}.{alias_name}"] = derived_node
                alias_name_to_nodes.setdefault(alias_name, set()).add(derived_node)

                expr_text_upper = _expr_to_text(item.this).upper()
                if any(fn in expr_text_upper for fn in ["SUM(", "AVG(", "COUNT(", "MIN(", "MAX("]):
                    G.nodes[derived_node]["aggregation"] = True

                # columns inside derived expression
                for c in item.this.find_all(exp.Column):
                    cname = c.sql()
                    col_node = _node_id("Column", cname)
                    add_node(col_node, kind="Column", column=cname)
                    add_edge(col_node, derived_node, "DERIVED_FROM")

            elif isinstance(item, exp.Column):
                cname = item.sql()
                col_node = _node_id("Column", cname)
                add_node(col_node, kind="Column", column=cname)
                add_edge(scope, col_node, "CONTAINS")

            else:
                # complex expression without alias: still scan columns
                for c in item.find_all(exp.Column):
                    cname = c.sql()
                    col_node = _node_id("Column", cname)
                    add_node(col_node, kind="Column", column=cname)
                    add_edge(scope, col_node, "CONTAINS")

        # ---- GROUP BY ----
        group = select.args.get("group")
        if group:
            for g in (group.expressions or []):
                if isinstance(g, exp.Column):
                    gname = g.sql()
                    add_node(_node_id("Column", gname), kind="Column", column=gname)
                    add_edge(scope, _node_id("Column", gname), "GROUPED_BY")

        # ---- WHERE ----
        if select.args.get("where"):
            where_expr = select.args["where"].this
            pred_node = make_predicate_node(where_expr, hint="WHERE")
            add_edge(scope, pred_node, "FILTERS_ON")
            for col in scan_columns_in_expr(where_expr):
                add_node(_node_id("Column", col), kind="Column", column=col)
                add_edge(pred_node, _node_id("Column", col), "REFERENCES")

    # --- Build all CTEs ---
    if cte_map:
        for cte_name, cte_select in cte_map.items():
            process_select(cte_select, _node_id("CTE", cte_name))

    # --- Build main query ---
    main_select = tree.this if isinstance(tree, exp.With) else tree
    main_scope = _node_id("Query", "Main")
    add_node(main_scope, kind="Query", name="Main")
    process_select(main_select, main_scope)

    # --- Postprocess: connect derived columns to alias columns (PROPAGATES_TO) ---
    for alias, cte_name in alias_to_cte_map.items():
        # derived nodes defined inside the CTE: e.g. "regional_sales.total_sales"
        derived_nodes = [n for n in G.nodes if n.startswith(f"Derived::{cte_name}.")]
        for d in derived_nodes:
            col = d.split(".")[-1]
            # the alias column in the main query will be like "rs.total_sales" if alias=rs
            alias_col = _node_id("Column", f"{alias}.{col}")
            # if alias column exists, connect derived -> alias_col as PROPAGATES_TO
            if alias_col in G.nodes:
                add_edge(d, alias_col, "PROPAGATES_TO")
            else:
                # create alias column and link
                add_node(alias_col, kind="Column", column=f"{alias}.{col}")
                add_edge(d, alias_col, "PROPAGATES_TO")

    # --- Ensure dataset CONTAINS for each derived produced alias (dataset-level columns) ---
    for derived_node, ddata in [(n, d) for n, d in G.nodes(data=True) if d.get("kind") == "Derived"]:
        # derived node ids look like Derived::Scope.alias
        try:
            suffix = derived_node.split("::", 1)[1]
            ds, alias = suffix.split(".", 1)
            dataset_node = f"CTE::{ds}" if f"CTE::{ds}" in G.nodes else f"Query::{ds}" if f"Query::{ds}" in G.nodes else None
            if dataset_node:
                col_node = _node_id("Column", alias)
                add_node(col_node, kind="Column", column=alias)
                add_edge(dataset_node, col_node, "CONTAINS")
        except Exception:
            pass

    # --- Rewire Table nodes that are aliases of CTEs (optional) ---
    # Add DERIVED_FROM edges from Query/Main to the CTE node if alias used
    for alias, cte_name in alias_to_cte_map.items():
        table_node = _node_id("Table", alias)
        cte_node = _node_id("CTE", cte_name)
        # ensure cte_node exists
        add_node(cte_node, kind="Dataset", name=cte_name)
        if table_node in G.nodes:
            # keep ALIAS_OF edges (already added), but also link Query/Main to CTE directly
            # find any dataset that DERIVED_FROM Table::<alias> and also add DERIVED_FROM -> CTE::<cte_name>
            for u, v, d in list(G.edges(data=True)):
                if v == table_node and d.get("relation") == "DERIVED_FROM":
                    if not G.has_edge(u, cte_node):
                        add_edge(u, cte_node, "DERIVED_FROM")

    # -------------------------------------------------------------------------
    # --- GENERIC USAGE PROPAGATION (CORE v8 FEATURE) -------------------------
    # For every consumer node (Derived, Predicate, Dataset/Query) with SQL or expr text,
    # find Column nodes and create USED_IN edges from Column -> consumer.
    #
    # We match tokens with word-boundaries for both qualified (a.b) and unqualified (b).
    # Also we create a more specific relation when consumer is a Predicate or Derived.
    # -------------------------------------------------------------------------
    # Build column lookup: node_id -> (qualified, unqualified)
    col_nodes = {}
    for n, d in G.nodes(data=True):
        if d.get("kind") == "Column":
            col_text = d.get("column", "")  # e.g., "mt.monthly_sales" or "amount"
            qual = col_text.lower()
            unqual = qual.split(".")[-1]
            col_nodes[n] = (qual, unqual)

    # Helper for token match using regex word boundaries
    def token_appears_in(token: str, text: str) -> bool:
        if not token or not text:
            return False
        # escape token for regex; allow dot in token by escaping dot
        pattern = r"\b" + re.escape(token) + r"\b"
        return re.search(pattern, text, flags=re.IGNORECASE) is not None

    # Iterate consumers
    for t_node, t_data in list(G.nodes(data=True)):
        kind = t_data.get("kind")
        # gather textual content to search (sql for Derived, expr for Predicate, name for Dataset/Query)
        if kind == "Derived":
            sql_text = (t_data.get("sql") or "").lower()
            rel_type = "USED_IN_DERIVATION"
        elif kind == "Predicate":
            sql_text = (t_data.get("expr") or "").lower()
            rel_type = "USED_IN_PREDICATE"
        elif kind in {"Dataset", "Query"}:
            # include contained columns names and produced derived sqls as context
            sql_text = ""
            # also try to include node 'name' if present
            sql_text = (t_data.get("sql") or t_data.get("expr") or t_data.get("name") or "").lower()
            rel_type = "USED_IN_DATASET"
        else:
            continue

        # If we have no textual SQL, also try to infer context by collecting PRODUCTION/CONTAINS edges
        # but primary approach is textual matching
        if not sql_text:
            # Gather text by concatenating contained derived/expr strings if available
            parts = []
            # derived nodes produced by this dataset
            for u, v, ed in G.edges(data=True):
                # edges: dataset -> derived with relation PRODUCES
                if u == t_node and ed.get("relation") == "PRODUCES":
                    dv = v
                    parts.append((G.nodes[v].get("sql") or ""))
            sql_text = " ".join(parts).lower()

        if not sql_text:
            continue

        # For each column node, test match against sql_text using qualified and unqualified token checks
        for col_node_id, (qualified, unqualified) in col_nodes.items():
            if token_appears_in(qualified, sql_text) or token_appears_in(unqualified, sql_text):
                # avoid creating duplicate edges
                if not G.has_edge(col_node_id, t_node):
                    add_edge(col_node_id, t_node, relation=rel_type)
                # also add a generic USED_IN marker for easier traversal
                if not G.has_edge(col_node_id, t_node):
                    add_edge(col_node_id, t_node, relation="USED_IN")

    # Final: collapse duplicated relations (if any edges were added with same src/dst but different labels,
    # networkx will keep last one — acceptable for now). If you need multi-relations, switch to MultiDiGraph.

    return G


In [13]:
sql = """

WITH regional_sales AS (
    SELECT region,
           SUM(amount) AS total_sales,
           COUNT(DISTINCT customer_id) AS unique_customers
    FROM sales
    WHERE sale_date BETWEEN '2023-01-01' AND '2023-12-31'
    GROUP BY region
),
top_customers AS (
    SELECT c.customer_id,
           c.customer_name,
           SUM(s.amount) AS total_spent,
           RANK() OVER (ORDER BY SUM(s.amount) DESC) AS rank
    FROM customers c
    JOIN sales s ON c.customer_id = s.customer_id
    WHERE s.sale_date BETWEEN '2023-01-01' AND '2023-12-31'
    GROUP BY c.customer_id, c.customer_name
),
category_performance AS (
    SELECT p.category,
           SUM(s.amount) AS total_sales,
           AVG(s.amount) AS avg_order_value,
           COUNT(DISTINCT s.order_id) AS total_orders
    FROM sales s
    JOIN products p ON s.product_id = p.product_id
    GROUP BY p.category
),
high_value_orders AS (
    SELECT order_id,
           customer_id,
           amount,
           CASE
               WHEN amount > 10000 THEN 'VIP'
               WHEN amount BETWEEN 5000 AND 10000 THEN 'Premium'
               ELSE 'Regular'
           END AS order_tier
    FROM sales
),
monthly_trends AS (
    SELECT DATE_TRUNC('month', sale_date) AS month,
           region,
           SUM(amount) AS monthly_sales,
           AVG(amount) AS avg_sale_value,
           COUNT(order_id) AS order_count
    FROM sales
    GROUP BY DATE_TRUNC('month', sale_date), region
)
SELECT
    rs.region,
    rs.total_sales,
    rs.unique_customers,
    tc.customer_name AS top_customer,
    tc.total_spent AS top_customer_spent,
    cp.category,
    cp.total_sales AS category_sales,
    hv.order_tier,
    mt.month,
    mt.monthly_sales,
    mt.avg_sale_value,
    (mt.monthly_sales / NULLIF(rs.total_sales, 0)) * 100 AS monthly_contribution_pct
FROM regional_sales rs
JOIN top_customers tc ON rs.region = (
    SELECT region
    FROM customers c
    JOIN sales s ON c.customer_id = s.customer_id
    WHERE c.customer_id = tc.customer_id
    LIMIT 1
)
JOIN category_performance cp ON 1=1
LEFT JOIN high_value_orders hv ON hv.customer_id = tc.customer_id
JOIN monthly_trends mt ON mt.region = rs.region
WHERE mt.monthly_sales > 100000
  AND cp.total_sales > 500000
ORDER BY rs.total_sales DESC, mt.month DESC;

"""

print("Building base lineage graph...")
G_1 = build_lineage_graph_v3(sql)
print(f"Graph built with {len(G_1.nodes())} nodes and {len(G_1.edges())} edges")

Building base lineage graph...
Graph built with 53 nodes and 87 edges


In [11]:
print("Building base lineage graph...")
G_1 = build_lineage_graph_v4(sql)
print(f"Graph built with {len(G_1.nodes())} nodes and {len(G_1.edges())} edges")

Building base lineage graph...
Graph built with 61 nodes and 95 edges


In [4]:
print("Building base lineage graph...")
G_1 = build_lineage_graph_v5(sql)
print(f"Graph built with {len(G_1.nodes())} nodes and {len(G_1.edges())} edges")

Building base lineage graph...
Graph built with 61 nodes and 113 edges


In [14]:
print("Building base lineage graph...")
G_1 = build_lineage_graph_v6(sql)
print(f"Graph built with {len(G_1.nodes())} nodes and {len(G_1.edges())} edges")

Building base lineage graph...
Graph built with 84 nodes and 126 edges


In [31]:
print("Building base lineage graph...")
G_1 = build_lineage_graph_v7(sql)
print(f"Graph built with {len(G_1.nodes())} nodes and {len(G_1.edges())} edges")

Building base lineage graph...
Graph built with 71 nodes and 111 edges


In [38]:
print("Building base lineage graph...")
G_1 = build_lineage_graph_v8(sql)
print(f"Graph built with {len(G_1.nodes())} nodes and {len(G_1.edges())} edges")

Building base lineage graph...
Graph built with 90 nodes and 202 edges


In [39]:
data = json_graph.node_link_data(G_1)

# Print or save it
json_str = json.dumps(data, indent=2)
print(json_str)

{
  "directed": true,
  "multigraph": false,
  "graph": {},
  "nodes": [
    {
      "kind": "Dataset",
      "name": "regional_sales",
      "id": "CTE::regional_sales"
    },
    {
      "kind": "Dataset",
      "name": "top_customers",
      "id": "CTE::top_customers"
    },
    {
      "kind": "Dataset",
      "name": "category_performance",
      "id": "CTE::category_performance"
    },
    {
      "kind": "Dataset",
      "name": "high_value_orders",
      "id": "CTE::high_value_orders"
    },
    {
      "kind": "Dataset",
      "name": "monthly_trends",
      "id": "CTE::monthly_trends"
    },
    {
      "kind": "Table",
      "name": "sales",
      "id": "Table::sales"
    },
    {
      "kind": "Column",
      "column": "region",
      "id": "Column::region"
    },
    {
      "kind": "Derived",
      "sql": "SUM(amount)",
      "aggregation": true,
      "id": "Derived::regional_sales.total_sales"
    },
    {
      "kind": "Column",
      "column": "amount",
      "id": "C

In [40]:
visualize_graph_pyvis(G_1)

✅ Graph saved to lineage_graph.html


In [16]:
def find_target_nodes(G, query, llm=False):
    query_norm = query.lower().replace(" ", "_")
    candidates = []
    
    for n, d in G.nodes(data=True):
        text = (n + " " + str(d)).lower()
        if query_norm in text or query.lower() in text:
            candidates.append(n)
    
    if not candidates and llm:
        # Ask the LLM to identify possible nodes
        context = [n for n in G.nodes() if "::" in n]
        prompt = f"""
        You are a lineage assistant. 
        From the following list of nodes, which best matches the user query "{query}"?
        Nodes: {context[:100]}  # (truncate for efficiency)
        Respond with a JSON array of node IDs that most closely match.
        """
        try:
            resp = llm.invoke(prompt)
            candidates = json.loads(resp.content)
        except Exception:
            candidates = []
    
    return candidates


def infer_traversal_direction(question, llm=False):
    q = question.lower()
    if any(k in q for k in ["where", "source", "come from", "derived from"]):
        return "upstream"
    if any(k in q for k in ["depend on", "used in", "downstream", "impact"]):
        return "downstream"
    if llm:
        resp = llm.invoke(f"Does this question ask about upstream or downstream lineage? {question}")
        return "upstream" if "upstream" in resp.content.lower() else "downstream"
    return "context"

def infer_traversal_direction(question, llm=False):
    q = question.lower()
    if any(k in q for k in ["where", "source", "come from", "derived from"]):
        return "upstream"
    if any(k in q for k in ["depend on", "used in", "downstream", "impact"]):
        return "downstream"
    if llm:
        resp = llm.invoke(f"Does this question ask about upstream or downstream lineage? {question}")
        return "upstream" if "upstream" in resp.content.lower() else "downstream"
    return "context"


def traverse_lineage(G, target_node, direction="upstream", depth=3):
    if direction == "upstream":
        return nx.ancestors(G, target_node)
    elif direction == "downstream":
        return nx.descendants(G, target_node)
    else:
        # Include neighbors (context)
        return set(G.predecessors(target_node)) | set(G.successors(target_node))

def summarize_traversal(G, nodes_subset, question, llm=False):
    sub_edges = [
        (u, v, d.get("relation"))
        for u, v, d in G.edges(data=True)
        if u in nodes_subset and v in nodes_subset
    ]
    
    if not llm:
        summary = "\n".join([f"{u} --{rel}--> {v}" for u, v, rel in sub_edges])
        return f"Traversal summary:\n{summary}"
    
    context = {
        "nodes": {n: G.nodes[n] for n in nodes_subset},
        "edges": sub_edges
    }
    prompt = f"""
    You are a data lineage analyst.
    Given the following lineage subgraph and the question "{question}",
    explain the answer in concise, natural language.

    Subgraph JSON:
    {json.dumps(context, indent=2)}
    """
    resp = llm.invoke(prompt)
    return resp.content


def answer_lineage_question(G, question, llm=False):
    targets = find_target_nodes(G, question, llm=llm)
    if not targets:
        return "No relevant node found."
    
    direction = infer_traversal_direction(question, llm=llm)
    
    # For now, just take the first matched node
    target = targets[0]
    
    nodes_subset = traverse_lineage(G, target, direction=direction, depth=3)
    nodes_subset.add(target)
    
    return summarize_traversal(G, nodes_subset, question, llm=llm)


llm = get_llm()

In [44]:
from IPython.display import display, Markdown
import json

def summarize_traversal(G, nodes_subset, question, llm=False, direction=None):
    sub_edges = [
        (u, v, d.get("relation"))
        for u, v, d in G.edges(data=True)
        if u in nodes_subset and v in nodes_subset
    ]
    
    # --- Build structured context ---
    context = {
        "nodes": {n: G.nodes[n] for n in nodes_subset},
        "edges": [
            {"from": u, "to": v, "relation": rel} for u, v, rel in sub_edges
        ],
    }

    # --- If no LLM: structured textual summary ---
    if not llm:
        summary = "\n".join([f"{u} --{rel}--> {v}" for u, v, rel in sub_edges])
        answer = f"### Traversal Summary\n\n**Direction:** {direction or 'context'}\n\n{summary}"
        display(Markdown(answer))
        return answer

    # --- Unified LLM prompt for both upstream and downstream ---
    prompt = f"""
    You are a lineage analyst.
    Summarize the findings for this question in a **structured format** with these sections:
    
    1. **Short Answer** – concise 1–2 sentences.
    2. **Detailed Explanation** – a clear explanation of how data or logic flows (or how impact propagates), referencing key nodes and relationships.
    3. **What Actually Changes** – list real changes or effects (values, classifications, aggregations, etc.).
    4. **Recommended Next Steps** – what should be validated or tested downstream or upstream.
    
    Question: "{question}"
    Traversal Direction: {direction or 'context'}
    Subgraph (JSON):
    {json.dumps(context, indent=2)}
    """
    
    try:
        resp = llm.invoke(prompt)
        answer = resp.content.strip()
    except Exception as e:
        answer = f"LLM failed: {e}"

    # --- Render as formatted Markdown in Jupyter ---
    display(Markdown(answer))
    return answer


In [19]:
question = "Where did you get the total sales from?"
answer_lineage_question(G_1, question, llm=llm)

'The total_sales value is computed in the regional_sales CTE as SUM(amount). The amount column used in that aggregation comes from the high_value_orders CTE, and the main query reads regional_sales to get total_sales.'

In [34]:
question = "Where did you get the total sales from?"
answer_lineage_question(G_1, question, llm=llm)

'The total_sales value is an aggregation (SUM(amount)) produced in the regional_sales CTE. The SUM is computed over the amount column, and that amount column comes from the high_value_orders CTE.'

In [41]:
question = "Where did you get the total sales from?"
answer_lineage_question(G_1, question, llm=llm)

'The total_sales value is computed as SUM(amount) in the regional_sales CTE. It aggregates the amount column (s.amount), which comes from the high_value_orders CTE; the regional_sales dataset itself is produced by the Main query.'

In [20]:
question = "Where did you get the category sales from?"
answer_lineage_question(G_1, question, llm=llm)

'category_sales in the Main query comes from the category_performance CTE (aliased cp): Main.category_sales = cp.total_sales. That cp.total_sales is computed by the category_performance CTE as SUM(s.amount) — i.e., it’s an aggregation over the s.amount column. (The query also brings in regional_sales.total_sales — SUM(amount) — and uses cp.total_sales in the WHERE filter cp.total_sales > 500000.)'

In [35]:
question = "Where did you get the category sales from?"
answer_lineage_question(G_1, question, llm=llm)

'The Main.category_sales value comes from cp.total_sales — the total_sales column produced by the category_performance CTE. That total_sales is computed as SUM(s.amount) (an aggregation over s.amount), and the Main query also uses cp.total_sales in its WHERE clause (cp.total_sales > 500,000).'

In [43]:
question = "Where did you get the category sales from?"
answer_lineage_question(G_1, question, llm=llm)

'The Main.category_sales value comes from cp.total_sales — the total_sales column produced by the category_performance CTE. That CTE computes total_sales as SUM(s.amount) (i.e., it aggregates the underlying order amount values, coming from the source amount/high_value_orders data).'

In [45]:
question = "what if i change the order tier variable? whats going to be the impact?"
answer_lineage_question(G_1, question, llm=llm)

**Short Answer**  
Changing the order_tier logic in Derived::high_value_orders.order_tier will directly change the values stored in hv.order_tier (Column::hv.order_tier); any downstream consumers that use hv.order_tier (reports, aggregations, joins, filters) can therefore change. The provided subgraph shows a direct PROPAGATES_TO link from the derived expression to the hv column, so the change is immediately visible wherever that column is used.

**Detailed Explanation**  
- Current logic (Derived::high_value_orders.order_tier):  
  CASE WHEN amount > 10000 THEN 'VIP' WHEN amount BETWEEN 5000 AND 10000 THEN 'Premium' ELSE 'Regular' END  
  This expression classifies each record by its amount into three categorical labels.  
- Relationship in subgraph: Derived::high_value_orders.order_tier PROPAGATES_TO Column::hv.order_tier. That means the computed labels from the CASE expression are what populate hv.order_tier.  
- How impact propagates: any modification to the CASE (thresholds, inclusive/exclusive boundaries, label names, null handling, or logic order) will change the computed label for affected rows. Those changed labels will appear in hv.order_tier, and any downstream logic that reads hv.order_tier (grouping, filtering, joins, metric calculations, SCD keys, partitions) will reflect the new values. Because the provided subgraph ends at hv.order_tier, you should expand lineage to enumerate all consumers, but conceptually the propagation is immediate and deterministic: changed expression -> changed column values -> changed consumer outputs.

**What Actually Changes**  
- Values: hv.order_tier values for records whose amount falls into ranges affected by your change (e.g., altered thresholds move some records from 'Premium' to 'VIP' or vice versa).  
- Record classification: per-row tier assignments will be different for boundary and moved ranges.  
- Aggregations and counts: totals by tier (counts, sums, averages) will change.  
- Filters and cohorts: any cohort segmentation or queries filtering on tier will return different record sets.  
- Derived metrics and dashboards: KPIs computed by tier (conversion rates, revenue per tier) will update.  
- Joins/keys: if other datasets join on specific tier labels, join results may change (or fail if labels renamed).  
- Potential metadata: enum/classification definitions, documentation, tests that assert exact tier values may become invalid.  
- Edge cases: null or borderline amounts (5000 and 10000 are currently included in 'Premium' due to BETWEEN) may change if boundaries are altered.

**Recommended Next Steps**  
1. Expand lineage: identify all downstream consumers of Column::hv.order_tier (ETL jobs, models, dashboards, scheduled reports, alerts, data APIs).  
2. Regression tests: run before/after comparisons on a representative dataset:  
   - Row-level diff for tier column (SELECT id, amount, old_tier, new_tier WHERE old_tier != new_tier).  
   - Aggregation diffs: counts and sums grouped by tier; KPI snapshots.  
   - Boundary tests for amount = 5000 and 10000 and for NULL amount.  
3. Validate semantics: confirm desired business rules (inclusive/exclusive boundaries, null handling, exact labels) with stakeholders.  
4. Update consumers: where labels are used as keys or in code, update joins, filters, hard-coded strings, and documentation. Communicate changes to downstream owners.  
5. Deploy strategy: consider phased rollout (backfill vs. in-place replace), or create a transitional column (e.g., order_tier_v2) and run both in parallel to measure impact.  
6. Monitoring: after change, monitor tier counts, dashboards and alerts for unexpected shifts; set up automated anomaly detection on tier distributions.

"**Short Answer**  \nChanging the order_tier logic in Derived::high_value_orders.order_tier will directly change the values stored in hv.order_tier (Column::hv.order_tier); any downstream consumers that use hv.order_tier (reports, aggregations, joins, filters) can therefore change. The provided subgraph shows a direct PROPAGATES_TO link from the derived expression to the hv column, so the change is immediately visible wherever that column is used.\n\n**Detailed Explanation**  \n- Current logic (Derived::high_value_orders.order_tier):  \n  CASE WHEN amount > 10000 THEN 'VIP' WHEN amount BETWEEN 5000 AND 10000 THEN 'Premium' ELSE 'Regular' END  \n  This expression classifies each record by its amount into three categorical labels.  \n- Relationship in subgraph: Derived::high_value_orders.order_tier PROPAGATES_TO Column::hv.order_tier. That means the computed labels from the CASE expression are what populate hv.order_tier.  \n- How impact propagates: any modification to the CASE (threshol

In [46]:
question = "what if i change the total spent? whats going to be the impact?"
answer_lineage_question(G_1, question, llm=llm)

1. Short Answer
- Changing the source column total_spent will directly change the values of Derived::Main.top_customer_spent (it’s defined as tc.total_spent). Any downstream entity that reads that derived column or relies on its values (rankings, aggregates, reports) may also be affected.

2. Detailed Explanation
- Relationship: Column::total_spent --[USED_IN_DERIVATION]--> Derived::Main.top_customer_spent. The derived column is a direct passthrough (sql: tc.total_spent), so the derived value is exactly the source value at materialization time.
- Propagation: when total_spent values are updated at the source, the derived column will reflect those updates the next time Derived::Main is recomputed/materialized. If Derived::Main is used by other datasets, BI dashboards, alerts, or business logic, those consumers will see changed values and any computations that use top_customer_spent (sums, averages, top-N, thresholds) may produce different results.
- Key nodes and flow: source column (total_spent) -> immediate impacted node (Derived::Main.top_customer_spent) -> any downstream consumers of Derived::Main (not shown in this subgraph) will inherit the changed values.

3. What Actually Changes
- Values: the cell values of Derived::Main.top_customer_spent will change to match the new total_spent values after recomputation.
- Aggregations/derivations: any aggregates or derived metrics computed from top_customer_spent (totals, averages, top-N selections, segments) will potentially change.
- Reports/visualizations/alerts: dashboards or alerts that use Derived::Main.top_customer_spent may show different numbers or trigger different alert states.
- No schema-level change: column name/type/lineage metadata remain the same unless you change schema; only the data content changes.

4. Recommended Next Steps
- Identify downstream consumers: list dashboards, reports, tables, or processes that read Derived::Main.top_customer_spent and prioritize by business impact.
- Run tests and comparisons: recompute Derived::Main in a test environment and run diff checks (row-level and aggregate-level) between old and new top_customer_spent to quantify impact.
- Validate business logic: ensure any thresholds, rankings, or segment rules that use top_customer_spent still behave correctly under new values.
- Check materialization cadence: confirm when Derived::Main will be refreshed and whether a backfill/rebuild is required.
- Monitor and revert plan: if this is a planned change, schedule monitoring after deployment and have rollback or corrective queries ready if critical dashboards break.
- Communicate: notify owners of high-impact downstream consumers so they can validate or adjust reports/alerts.

'1. Short Answer\n- Changing the source column total_spent will directly change the values of Derived::Main.top_customer_spent (it’s defined as tc.total_spent). Any downstream entity that reads that derived column or relies on its values (rankings, aggregates, reports) may also be affected.\n\n2. Detailed Explanation\n- Relationship: Column::total_spent --[USED_IN_DERIVATION]--> Derived::Main.top_customer_spent. The derived column is a direct passthrough (sql: tc.total_spent), so the derived value is exactly the source value at materialization time.\n- Propagation: when total_spent values are updated at the source, the derived column will reflect those updates the next time Derived::Main is recomputed/materialized. If Derived::Main is used by other datasets, BI dashboards, alerts, or business logic, those consumers will see changed values and any computations that use top_customer_spent (sums, averages, top-N, thresholds) may produce different results.\n- Key nodes and flow: source col

In [47]:
question = "where did you get the total spent from?"
answer_lineage_question(G_1, question, llm=llm)

**Short Answer** – The total_spent value is computed in the top_customers CTE as the aggregation SUM(s.amount); that s.amount comes from the amount column contained in the high_value_orders CTE.

**Detailed Explanation**  
- Key nodes: CTE::high_value_orders (contains Column::amount), Column::s.amount (the aliased column used in the aggregation), and CTE::top_customers which PRODUCES Derived::top_customers.total_spent.  
- Flow: high_value_orders CONTAINS the base column amount. In the top_customers query an alias s.amount is referenced (Column::s.amount), and the derived metric total_spent is defined as SUM(s.amount) (Derived::top_customers.total_spent). The graph records that the base Column::amount is USED_IN_DERIVATION and that Column::s.amount is DERIVED_FROM into the total_spent derivation.  
- Effectively, top_customers.total_spent = SUM(over rows of s.amount), where s.amount is sourced from the amount column in high_value_orders.

**What Actually Changes**  
- A new derived field/metric is produced: top_customers.total_spent.  
- Its value is an aggregation (SUM) computed from the amount column values (s.amount).  
- Any change to rows, filters, or values in high_value_orders.amount will change total_spent.  
- Classification: total_spent is an aggregated numeric measure (not a raw column).  
- Downstream consumers see an aggregated value rather than row-level amounts.

**Recommended Next Steps**  
- Inspect the top_customers SQL to confirm the alias s and the exact source dataset reference (ensure s points to high_value_orders).  
- Verify any GROUP BY or join logic in top_customers to confirm aggregation granularity (per customer, overall, etc.).  
- Validate upstream high_value_orders: check filters, date ranges, and that amount is numeric and in expected currency/units.  
- Run test queries: compare SUM(amount) from high_value_orders to top_customers.total_spent for matching scopes to detect double-counting or missing rows.  
- Add tests for nulls/zero values and distinct vs non-distinct summation if relevant.  
- Review downstream consumers to ensure they expect an aggregated metric and use appropriate aggregation keys/time windows.

'**Short Answer** – The total_spent value is computed in the top_customers CTE as the aggregation SUM(s.amount); that s.amount comes from the amount column contained in the high_value_orders CTE.\n\n**Detailed Explanation**  \n- Key nodes: CTE::high_value_orders (contains Column::amount), Column::s.amount (the aliased column used in the aggregation), and CTE::top_customers which PRODUCES Derived::top_customers.total_spent.  \n- Flow: high_value_orders CONTAINS the base column amount. In the top_customers query an alias s.amount is referenced (Column::s.amount), and the derived metric total_spent is defined as SUM(s.amount) (Derived::top_customers.total_spent). The graph records that the base Column::amount is USED_IN_DERIVATION and that Column::s.amount is DERIVED_FROM into the total_spent derivation.  \n- Effectively, top_customers.total_spent = SUM(over rows of s.amount), where s.amount is sourced from the amount column in high_value_orders.\n\n**What Actually Changes**  \n- A new de

In [None]:
question = "how is total spent calculated?"
answer_lineage_question(G_1, question, llm=llm)