In [None]:
pip install sqlglot




Procsessing AST directly

In [87]:
# Define the classes to mimic the AST structure
class Identifier:
    def __init__(self, this, quoted=False):
        self.this = this
        self.quoted = quoted

class Table:
    def __init__(self, this, db=None, catalog=None):
        self.this = this
        self.db = db
        self.catalog = catalog

class Column:
    def __init__(self, this, table=None):
        self.this = this
        self.table = table

class Alias:
    def __init__(self, this, alias):
        self.this = this
        self.alias = alias

class Cast:
    def __init__(self, this, to):
        self.this = this
        self.to = to

class Paren:
    def __init__(self, this):
        self.this = this

class Div:
    def __init__(self, this, expression, typed=False, safe=False):
        self.this = this
        self.expression = expression
        self.typed = typed
        self.safe = safe

class DataType:
    def __init__(self, this, expressions, nested=False):
        self.this = this
        self.expressions = expressions
        self.nested = nested

class DataTypeParam:
    def __init__(self, this, is_string=False):
        self.this = this
        self.is_string = is_string

class Literal:
    def __init__(self, this, is_string=False):
        self.this = this
        self.is_string = is_string

class Coalesce:
    def __init__(self, this, expressions):
        self.this = this
        self.expressions = expressions

class EQ:
    def __init__(self, this, expression):
        self.this = this
        self.expression = expression

class Boolean:
    def __init__(self, this):
        self.this = this

class CTE:
    def __init__(self, this, alias):
        self.this = this
        self.alias = alias

class From:
    def __init__(self, this):
        self.this = this

class Star:
    pass

class Select:
    def __init__(self, expressions, from_=None, with_=None):
        self.expressions = expressions
        self.from_ = from_
        self.with_ = with_

class TableAlias:
    def __init__(self, this):
        self.this = this

class With:
    def __init__(self, expressions):
        self.expressions = expressions


# The input AST is recreated here based on your input:
ast = Select(
    expressions=[Star()],
    from_=From(
        this=Table(
            this=Identifier(this="renamed", quoted=False)
        )
    ),
    with_=With(
        expressions=[
            CTE(
                this=Select(
                    expressions=[Star()],
                    from_=From(
                        this=Table(
                            this=Identifier(this="raw_products", quoted=False),
                            db=Identifier(this="raw", quoted=False),
                            catalog=Identifier(this="jaffle_shop", quoted=False)
                        )
                    )
                ),
                alias=TableAlias(
                    this=Identifier(this="source", quoted=False)
                )
            ),
            CTE(
                this=Select(
                    expressions=[
                        Alias(
                            this=Column(
                                this=Identifier(this="sku", quoted=False)
                            ),
                            alias=Identifier(this="product_id", quoted=False)
                        ),
                        Alias(
                            this=Column(
                                this=Identifier(this="name", quoted=False)
                            ),
                            alias=Identifier(this="product_name", quoted=False)
                        ),
                        Alias(
                            this=Column(
                                this=Identifier(this="type", quoted=False)
                            ),
                            alias=Identifier(this="product_type", quoted=False)
                        ),
                        Alias(
                            this=Column(
                                this=Identifier(this="description", quoted=False)
                            ),
                            alias=Identifier(this="product_description", quoted=False)
                        ),
                        Alias(
                            this=Cast(
                                this=Paren(
                                    this=Div(
                                        this=Column(
                                            this=Identifier(this="price", quoted=False)
                                        ),
                                        expression=Literal(this=100, is_string=False)
                                    )
                                ),
                                to=DataType(
                                    this="DECIMAL",
                                    expressions=[
                                        DataTypeParam(this=Literal(this=16, is_string=False)),
                                        DataTypeParam(this=Literal(this=2, is_string=False))
                                    ]
                                )
                            ),
                            alias=Identifier(this="product_price", quoted=False)
                        ),
                        Alias(
                            this=Coalesce(
                                this=EQ(
                                    this=Column(
                                        this=Identifier(this="type", quoted=False)
                                    ),
                                    expression=Literal(this="jaffle", is_string=True)
                                ),
                                expressions=[Boolean(this=False)]
                            ),
                            alias=Identifier(this="is_food_item", quoted=False)
                        ),
                        Alias(
                            this=Coalesce(
                                this=EQ(
                                    this=Column(
                                        this=Identifier(this="type", quoted=False)
                                    ),
                                    expression=Literal(this="beverage", is_string=True)
                                ),
                                expressions=[Boolean(this=False)]
                            ),
                            alias=Identifier(this="is_drink_item", quoted=False)
                        )
                    ],
                    from_=From(
                        this=Table(
                            this=Identifier(this="source", quoted=False)
                        )
                    )
                ),
                alias=TableAlias(
                    this=Identifier(this="renamed", quoted=False)
                )
            )
        ]
    )
)


# Define the classes to mimic the AST structure (same as before)
# ... [Classes definitions remain unchanged for brevity] ...

# Function to extract source columns and transformation from any expression node
def extract_source_columns_and_transformation(node):
    if isinstance(node, Column):
        source_column = node.this.this  # Column name
        return [source_column], source_column  # No transformation applied
    elif isinstance(node, Literal):
        return [], f"'{node.this}'" if node.is_string else str(node.this)
    elif isinstance(node, Boolean):
        return [], str(node.this)
    elif isinstance(node, Paren):
        # Process the expression inside the parentheses
        source_columns, transformation = extract_source_columns_and_transformation(node.this)
        return source_columns, f"({transformation})"
    elif isinstance(node, Div):
        # Handle division
        left_columns, left_transformation = extract_source_columns_and_transformation(node.this)
        right_columns, right_transformation = extract_source_columns_and_transformation(node.expression)
        source_columns = left_columns + right_columns
        transformation = f"{left_transformation} / {right_transformation}"
        return source_columns, transformation
    elif isinstance(node, Cast):
        # Handle casting
        source_columns, inner_transformation = extract_source_columns_and_transformation(node.this)
        data_type = node.to.this  # Data type, e.g., DECIMAL
        # Get parameters if any
        params = ', '.join([str(param.this.this) for param in node.to.expressions])
        transformation = f"{inner_transformation}::{data_type}({params})"
        return source_columns, transformation
    elif isinstance(node, Coalesce):
        # Handle Coalesce
        coalesce_args = []
        source_columns = []
        for arg in [node.this] + node.expressions:
            arg_columns, arg_transformation = extract_source_columns_and_transformation(arg)
            source_columns.extend(arg_columns)
            coalesce_args.append(arg_transformation)
        transformation = f"COALESCE({', '.join(coalesce_args)})"
        return source_columns, transformation
    elif isinstance(node, EQ):
        # Handle equality
        left_columns, left_transformation = extract_source_columns_and_transformation(node.this)
        right_columns, right_transformation = extract_source_columns_and_transformation(node.expression)
        source_columns = left_columns + right_columns
        transformation = f"{left_transformation} = {right_transformation}"
        return source_columns, transformation
    else:
        # For other nodes, attempt to process their children
        source_columns = []
        transformations = []
        for attr in dir(node):
            if not attr.startswith('__') and not callable(getattr(node, attr)):
                child = getattr(node, attr)
                if isinstance(child, (Column, Literal, Paren, Div, Cast, Coalesce, EQ, Boolean)):
                    child_columns, child_transformation = extract_source_columns_and_transformation(child)
                    source_columns.extend(child_columns)
                    transformations.append(child_transformation)
                elif isinstance(child, list):
                    for item in child:
                        if isinstance(item, (Column, Literal, Paren, Div, Cast, Coalesce, EQ, Boolean)):
                            child_columns, child_transformation = extract_source_columns_and_transformation(item)
                            source_columns.extend(child_columns)
                            transformations.append(child_transformation)
        transformation = ' '.join(transformations).strip()
        return source_columns, transformation

# Function to extract final columns from the AST
def extract_final_columns(ast_node):
    final_columns = []

    if isinstance(ast_node, Alias):
        alias_name = ast_node.alias.this

        source_columns, transformation = extract_source_columns_and_transformation(ast_node.this)
        source_columns = list(set(source_columns))  # Remove duplicates

        # Assuming all source columns are from 'raw_products' unless specified otherwise
        source_table = "raw_products" if all(col in ["sku", "name", "type", "description", "price"] for col in source_columns) else "unknown"

        final_columns.append({
            "Final Column": alias_name,
            "Source Table": source_table,
            "Source Column": ', '.join(source_columns),
            "Transformation": transformation
        })

    # Traverse child nodes
    for attr in dir(ast_node):
        if not attr.startswith('__') and not callable(getattr(ast_node, attr)):
            child = getattr(ast_node, attr)
            if isinstance(child, list):
                for item in child:
                    final_columns.extend(extract_final_columns(item))
            elif isinstance(child, (Alias, Select, From, With, CTE)):
                final_columns.extend(extract_final_columns(child))

    return final_columns

# Use the function on the AST to extract final columns and their transformations
final_columns = extract_final_columns(ast)

# Display the final columns, their source tables/columns, and transformations
for col in final_columns:
    print(f"Final Column: {col['Final Column']}")
    print(f"Source Table: {col['Source Table']}")
    print(f"Source Column: {col['Source Column']}")
    print(f"Transformation: {col['Transformation']}\n")




Final Column: product_id
Source Table: raw_products
Source Column: sku
Transformation: sku

Final Column: product_name
Source Table: raw_products
Source Column: name
Transformation: name

Final Column: product_type
Source Table: raw_products
Source Column: type
Transformation: type

Final Column: product_description
Source Table: raw_products
Source Column: description
Transformation: description

Final Column: product_price
Source Table: raw_products
Source Column: price
Transformation: (price / 100)::DECIMAL(16, 2)

Final Column: is_food_item
Source Table: raw_products
Source Column: type
Transformation: COALESCE(type = 'jaffle', False)

Final Column: is_drink_item
Source Table: raw_products
Source Column: type
Transformation: COALESCE(type = 'beverage', False)



Generate AST from SQL

In [18]:
import sqlglot

def extract_source_columns_and_transformation(node):
    if isinstance(node, sqlglot.expressions.Column):
        source_column = node.name  # Column name
        return [source_column], source_column
    elif isinstance(node, sqlglot.expressions.Identifier):
        return [node.name], node.name
    elif isinstance(node, sqlglot.expressions.Literal):
        return [], f"'{node.this}'" if node.is_string else str(node.this)
    elif isinstance(node, sqlglot.expressions.Boolean):
        return [], str(node.this).lower()
    elif isinstance(node, sqlglot.expressions.Paren):
        source_columns, transformation = extract_source_columns_and_transformation(node.this)
        return source_columns, f"({transformation})"
    elif isinstance(node, sqlglot.expressions.Div):
        left_columns, left_transformation = extract_source_columns_and_transformation(node.left)
        right_columns, right_transformation = extract_source_columns_and_transformation(node.right)
        source_columns = left_columns + right_columns
        transformation = f"{left_transformation} / {right_transformation}"
        return source_columns, transformation
    elif isinstance(node, sqlglot.expressions.Cast):
        source_columns, inner_transformation = extract_source_columns_and_transformation(node.this)
        data_type = node.to.sql()
        transformation = f"{inner_transformation}::{data_type}"
        return source_columns, transformation
    elif isinstance(node, sqlglot.expressions.Coalesce):
        coalesce_args = []
        source_columns = []
        if node.this:
            arg_columns, arg_transformation = extract_source_columns_and_transformation(node.this)
            source_columns.extend(arg_columns)
            coalesce_args.append(arg_transformation)
        for arg in node.expressions:
            arg_columns, arg_transformation = extract_source_columns_and_transformation(arg)
            source_columns.extend(arg_columns)
            coalesce_args.append(arg_transformation)
        transformation = f"COALESCE({', '.join(coalesce_args)})"
        return source_columns, transformation
    elif isinstance(node, sqlglot.expressions.EQ):
        left_columns, left_transformation = extract_source_columns_and_transformation(node.this)
        right_columns, right_transformation = extract_source_columns_and_transformation(node.expression)
        source_columns = left_columns + right_columns
        transformation = f"{left_transformation} = {right_transformation}"
        return source_columns, transformation
    else:
        source_columns = []
        transformations = []
        for arg_key in node.arg_types:
            child = node.args.get(arg_key)
            if isinstance(child, list):
                for item in child:
                    if isinstance(item, sqlglot.expressions.Expression):
                        child_columns, child_transformation = extract_source_columns_and_transformation(item)
                        source_columns.extend(child_columns)
                        transformations.append(child_transformation)
            elif isinstance(child, sqlglot.expressions.Expression):
                child_columns, child_transformation = extract_source_columns_and_transformation(child)
                source_columns.extend(child_columns)
                transformations.append(child_transformation)
        transformation = ' '.join(transformations).strip()
        return source_columns, transformation

def extract_final_columns(node, current_table_alias=None, table_columns=None):
    if table_columns is None:
        table_columns = {}
    final_columns = []

    for arg_key in node.arg_types:
        child = node.args.get(arg_key)
        if isinstance(child, list):
            for item in child:
                if isinstance(item, sqlglot.expressions.Expression):
                    final_columns.extend(extract_final_columns(item, current_table_alias, table_columns))
        elif isinstance(child, sqlglot.expressions.Expression):
            final_columns.extend(extract_final_columns(child, current_table_alias, table_columns))

    if isinstance(node, sqlglot.expressions.CTE):
        cte_alias = node.alias_or_name
        cte_columns = extract_final_columns(node.this, cte_alias, table_columns)
        table_columns[cte_alias] = [col['Final Column'] for col in cte_columns if col['Final Column'] != '*']
        final_columns.extend(cte_columns)
    elif isinstance(node, sqlglot.expressions.Select):
        from_ = node.args.get('from')
        if from_ and isinstance(from_.this, sqlglot.expressions.Table):
            current_table_alias = from_.this.alias_or_name or from_.this.name
        elif from_ and isinstance(from_.this, sqlglot.expressions.Subquery):
            current_table_alias = from_.this.alias_or_name
        for exp in node.expressions:
            if isinstance(exp, sqlglot.expressions.Alias):
                alias_name = exp.alias_or_name
                source_columns, transformation = extract_source_columns_and_transformation(exp.this)
                source_columns = list(set(source_columns))
                source_table = "raw_products" if all(
                    col in ["sku", "name", "type", "description", "price"]
                    for col in source_columns
                ) else "unknown"

                final_columns.append({
                    "Final Column": alias_name,
                    "Source Table": source_table,
                    "Source Column": ', '.join(source_columns),
                    "Transformation": transformation
                })
            elif isinstance(exp, sqlglot.expressions.Star):
                if current_table_alias and current_table_alias in table_columns:
                    for col_name in table_columns[current_table_alias]:
                        final_columns.append({
                            "Final Column": col_name,
                            "Source Table": current_table_alias,
                            "Source Column": col_name,
                            "Transformation": col_name
                        })
                else:
                    final_columns.append({
                        "Final Column": '*',
                        "Source Table": 'unknown',
                        "Source Column": '*',
                        "Transformation": 'Select all columns'
                    })
    return final_columns

def main():
    query = input("Enter your SQL query:\n")
    try:
        parsed_ast = sqlglot.parse_one(query)
        # Uncomment the next line to see the AST representation for debugging
        print(repr(parsed_ast))
        print("SQL parsed successfully.\n")
    except Exception as e:
        print(f"Error parsing SQL: {e}")
        return

    table_columns = {}
    final_columns = extract_final_columns(parsed_ast, table_columns=table_columns)

    # if final_columns:
    #     for col in final_columns:
    #         print(f"Final Column: {col.get('Final Column', '')}")
    #         print(f"Source Table: {col.get('Source Table', '')}")
    #         print(f"Source Column: {col.get('Source Column', '')}")
    #         print(f"Transformation: {col.get('Transformation', '')}\n")
    # else:
    #     print("No columns extracted. Please check the SQL query and ensure it is correct.")

if __name__ == "__main__":
    main()


Enter your SQL query:
with  source as (      select * from jaffle_shop.raw.raw_products  ),  renamed as (      select          sku as product_id,          name as product_name,         type as product_type,         description as product_description,          (price / 100)::numeric(16, 2) as product_price,          coalesce(type = 'jaffle', false) as is_food_item,          coalesce(type = 'beverage', false) as is_drink_item      from source  )  select * from renamed
Select(
  expressions=[
    Star()],
  from=From(
    this=Table(
      this=Identifier(this=renamed, quoted=False))),
  with=With(
    expressions=[
      CTE(
        this=Select(
          expressions=[
            Star()],
          from=From(
            this=Table(
              this=Identifier(this=raw_products, quoted=False),
              db=Identifier(this=raw, quoted=False),
              catalog=Identifier(this=jaffle_shop, quoted=False)))),
        alias=TableAlias(
          this=Identifier(this=source, quoted

## Query Type 1

In [6]:
import sqlglot
from sqlglot import expressions as exp

def get_operator_symbol(node):
    operator_map = {
        exp.EQ: '=', exp.NEQ: '!=', exp.GT: '>', exp.GTE: '>=',
        exp.LT: '<', exp.LTE: '<=', exp.Add: '+', exp.Sub: '-',
        exp.Mul: '*', exp.Div: '/',
    }
    return operator_map.get(type(node), 'UNKNOWN_OPERATOR')

def trace_column_lineage(column, table, cte_definitions, visited=None):
    if visited is None:
        visited = set()

    if table in visited:
        return [], table, '' # Added '' to return 3 values

    visited.add(table)

    if table not in cte_definitions:
        return [column], table, column

    cte_info = cte_definitions[table]
    if column in cte_info:
        source_columns = cte_info[column]['source_columns']
        source_table = cte_info[column]['source_table']
        transformation = cte_info[column]['transformation']
    elif '*' in cte_info:
        source_columns = [column]
        source_table = cte_info['*']['source_table']
        transformation = column
    else:
        return [column], table, column # This already returns 3 values

    final_columns = []
    final_tables = set()
    final_transformations = []

    for src_col in source_columns:
        traced_columns, traced_table, traced_transformation = trace_column_lineage(src_col, source_table, cte_definitions, visited)
        final_columns.extend(traced_columns)
        final_tables.add(traced_table)
        final_transformations.append(traced_transformation)

    return final_columns, ', '.join(final_tables), transformation # This already returns 3 values

def extract_source_columns_and_transformation(node, table_alias_map, cte_definitions):
    if isinstance(node, exp.Column):
        source_column = node.name
        table_name = node.table if node.table else "unknown"
        real_table_name = table_alias_map.get(table_name, table_name)

        source_columns, source_table, transformation = trace_column_lineage(source_column, real_table_name, cte_definitions)
        return source_columns, transformation, source_table

    elif isinstance(node, exp.Literal):
        return [], str(node.this), "constant"

    elif isinstance(node, exp.Paren):
        source_columns, transformation, table_name = extract_source_columns_and_transformation(node.this, table_alias_map, cte_definitions)
        return source_columns, f"({transformation})", table_name

    elif isinstance(node, (exp.EQ, exp.NEQ, exp.GT, exp.GTE, exp.LT, exp.LTE, exp.Add, exp.Sub, exp.Mul, exp.Div)):
        left_columns, left_transformation, left_table = extract_source_columns_and_transformation(node.left, table_alias_map, cte_definitions)
        right_columns, right_transformation, right_table = extract_source_columns_and_transformation(node.right, table_alias_map, cte_definitions)
        operator = get_operator_symbol(node)
        return left_columns + right_columns, f"{left_transformation} {operator} {right_transformation}", f"{left_table}, {right_table}"

    elif isinstance(node, (exp.Count, exp.Sum, exp.Min, exp.Max)):
        source_columns, inner_transformation, source_table = extract_source_columns_and_transformation(node.this, table_alias_map, cte_definitions)
        func_name = node.__class__.__name__.upper()
        transformation = f"{func_name}({inner_transformation})"
        if node.args.get("distinct"):
            transformation = f"{func_name}(DISTINCT {inner_transformation})"
        return source_columns, transformation, source_table

    elif isinstance(node, exp.Case):
        case_conditions = []
        source_columns = []
        source_tables = set()
        for when_clause in node.args.get("ifs", []):
            cond_columns, cond_transformation, cond_table = extract_source_columns_and_transformation(when_clause.this, table_alias_map, cte_definitions)
            true_columns, true_transformation, true_table = extract_source_columns_and_transformation(when_clause.args['true'], table_alias_map, cte_definitions)
            source_columns.extend(cond_columns + true_columns)
            source_tables.update([cond_table, true_table])
            case_conditions.append(f"WHEN {cond_transformation} THEN {true_transformation}")

        else_expr = node.args.get('default')
        if else_expr:
            else_columns, else_transformation, else_table = extract_source_columns_and_transformation(else_expr, table_alias_map, cte_definitions)
            source_columns.extend(else_columns)
            source_tables.add(else_table)
            default_case = f"ELSE {else_transformation}"
        else:
            default_case = ""

        transformation = f"CASE {' '.join(case_conditions)} {default_case} END"
        return source_columns, transformation, ", ".join(source_tables)

    else:
        return [], str(node), "unknown"

def process_cte(cte, table_alias_map, cte_definitions):
    cte_name = cte.alias
    cte_columns = {}

    if isinstance(cte.this, exp.Select):
        from_clause = cte.this.args.get("from")
        if from_clause and isinstance(from_clause.this, exp.Table):
            source_table = from_clause.this.name
            if from_clause.this.db:
                source_table = f"{from_clause.this.db}.{source_table}"
            if from_clause.this.catalog:
                source_table = f"{from_clause.this.catalog}.{source_table}"
            table_alias_map[cte_name] = source_table

        for select_expr in cte.this.expressions:
            if isinstance(select_expr, exp.Star):
                cte_columns['*'] = {
                    'source_columns': ['*'],
                    'transformation': '*',
                    'source_table': source_table
                }
            elif isinstance(select_expr, exp.Alias):
                alias_name = select_expr.alias
                source_columns, transformation, source_table = extract_source_columns_and_transformation(select_expr.this, table_alias_map, cte_definitions)
                cte_columns[alias_name] = {
                    'source_columns': source_columns,
                    'transformation': transformation,
                    'source_table': source_table
                }
            elif isinstance(select_expr, exp.Column):
                column_name = select_expr.name
                source_columns, transformation, source_table = extract_source_columns_and_transformation(select_expr, table_alias_map, cte_definitions)
                cte_columns[column_name] = {
                    'source_columns': source_columns,
                    'transformation': transformation,
                    'source_table': source_table
                }

    cte_definitions[cte_name] = cte_columns

def process_with_and_select(ast):
    table_alias_map = {}
    cte_definitions = {}
    final_columns = []

    # Process CTEs
    with_clause = ast.args.get("with")
    if with_clause:
        for cte in with_clause.expressions:
            process_cte(cte, table_alias_map, cte_definitions)

    # Process main SELECT
    if isinstance(ast, exp.Select):
        from_clause = ast.args.get("from")
        if from_clause and isinstance(from_clause.this, exp.Table):
            main_table = from_clause.this.name
            table_alias_map[main_table] = main_table

        for select_expr in ast.expressions:
            if isinstance(select_expr, exp.Star):
                if main_table in cte_definitions:
                    for col, info in cte_definitions[main_table].items():
                        source_columns, source_table, transformation = trace_column_lineage(col, main_table, cte_definitions)
                        final_columns.append({
                            "Final Column": col,
                            "Source Table": source_table,
                            "Source Columns": source_columns,
                            "Transformation": transformation
                        })
                else:
                    final_columns.append({
                        "Final Column": "*",
                        "Source Table": main_table,
                        "Source Columns": ["*"],
                        "Transformation": "Select all columns"
                    })
            elif isinstance(select_expr, exp.Alias):
                alias_name = select_expr.alias
                source_columns, transformation, source_table = extract_source_columns_and_transformation(select_expr.this, table_alias_map, cte_definitions)
                final_columns.append({
                    "Final Column": alias_name,
                    "Source Table": source_table,
                    "Source Columns": source_columns,
                    "Transformation": transformation
                })
            elif isinstance(select_expr, exp.Column):
                column_name = select_expr.name
                source_columns, transformation, source_table = extract_source_columns_and_transformation(select_expr, table_alias_map, cte_definitions)
                final_columns.append({
                    "Final Column": column_name,
                    "Source Table": source_table,
                    "Source Columns": source_columns,
                    "Transformation": transformation
                })

    return final_columns

def main():
    # Your SQL query here
    sql_query = """
    with
    customers as (
        select * from jaffle_shop.dbt_jmarwaha.stg_customers
    ),
    orders as (
        select * from jaffle_shop.dbt_jmarwaha.orders
    ),
    customer_orders_summary as (
        select
            orders.customer_id,
            count(distinct orders.order_id) as count_lifetime_orders,
            count(distinct orders.order_id) > 1 as is_repeat_buyer,
            min(orders.ordered_at) as first_ordered_at,
            max(orders.ordered_at) as last_ordered_at,
            sum(orders.subtotal) as lifetime_spend_pretax,
            sum(orders.tax_paid) as lifetime_tax_paid,
            sum(orders.order_total) as lifetime_spend
        from orders
        group by 1
    ),
    joined as (
        select
            customers.*,
            customer_orders_summary.count_lifetime_orders,
            customer_orders_summary.first_ordered_at,
            customer_orders_summary.last_ordered_at,
            customer_orders_summary.lifetime_spend_pretax,
            customer_orders_summary.lifetime_tax_paid,
            customer_orders_summary.lifetime_spend,
            case
                when customer_orders_summary.is_repeat_buyer then 'returning'
                else 'new'
            end as customer_type
        from customers
        left join customer_orders_summary
            on customers.customer_id = customer_orders_summary.customer_id
    )
    select * from joined
    """

    try:
        ast = sqlglot.parse_one(sql_query)
        print("SQL parsed successfully.\n")
    except Exception as e:
        print(f"Error parsing SQL: {e}")
        return

    final_columns = process_with_and_select(ast)
    #print(repr(ast))

    if final_columns:
        print("Final columns and their source information:")
        for col in final_columns:
            print(f"Final Column: {col.get('Final Column', 'Unknown')}")
            print(f"Source Table: {col.get('Source Table', 'Unknown')}")
            print(f"Source Columns: {', '.join(col.get('Source Columns', ['Unknown']))}")
            print(f"Transformation: {col.get('Transformation', 'Unknown')}\n")
    else:
        print("No final columns found or SQL does not contain valid select expressions.")

if __name__ == "__main__":
    main()

SQL parsed successfully.

Final columns and their source information:
Final Column: *
Source Table: jaffle_shop.dbt_jmarwaha.stg_customers
Source Columns: *
Transformation: *

Final Column: count_lifetime_orders
Source Table: jaffle_shop.dbt_jmarwaha.orders
Source Columns: count_lifetime_orders
Transformation: count_lifetime_orders

Final Column: first_ordered_at
Source Table: jaffle_shop.dbt_jmarwaha.orders
Source Columns: first_ordered_at
Transformation: first_ordered_at

Final Column: last_ordered_at
Source Table: jaffle_shop.dbt_jmarwaha.orders
Source Columns: last_ordered_at
Transformation: last_ordered_at

Final Column: lifetime_spend_pretax
Source Table: jaffle_shop.dbt_jmarwaha.orders
Source Columns: lifetime_spend_pretax
Transformation: lifetime_spend_pretax

Final Column: lifetime_tax_paid
Source Table: jaffle_shop.dbt_jmarwaha.orders
Source Columns: lifetime_tax_paid
Transformation: lifetime_tax_paid

Final Column: lifetime_spend
Source Table: jaffle_shop.dbt_jmarwaha.orders

## Query Type 2

In [122]:
import sqlglot
from sqlglot import expressions as exp

def extract_source_columns_and_transformation(node, table_name="unknown", cte_columns=None):
    if isinstance(node, exp.Column):
        source_column = node.name  # Column name
        table_name = node.table if node.table else table_name
        # If it's referencing a CTE, we need to get the actual base table
        if cte_columns and table_name in cte_columns:
            table_name = cte_columns[table_name][0]["Source Table"]
        return [source_column], source_column, table_name
    elif isinstance(node, exp.Identifier):
        return [node.name], node.name, table_name
    elif isinstance(node, exp.Literal):
        return [], f"'{node.this}'" if node.is_string else str(node.this), table_name
    elif isinstance(node, exp.Boolean):
        return [], str(node.this).lower(), table_name
    elif isinstance(node, exp.Paren):
        source_columns, transformation, table_name = extract_source_columns_and_transformation(node.this, table_name, cte_columns)
        return source_columns, f"({transformation})", table_name
    elif isinstance(node, exp.Div):
        left_columns, left_transformation, left_table = extract_source_columns_and_transformation(node.this, table_name, cte_columns)
        right_columns, right_transformation, right_table = extract_source_columns_and_transformation(node.expression, table_name, cte_columns)
        source_columns = left_columns + right_columns
        transformation = f"{left_transformation} / {right_transformation}"
        return source_columns, transformation, left_table if left_table != "unknown" else right_table
    elif isinstance(node, exp.Cast):
        source_columns, inner_transformation, source_table = extract_source_columns_and_transformation(node.this, table_name, cte_columns)
        cast_type = node.to.sql()
        transformation = f"{inner_transformation}::{cast_type}"
        return source_columns, transformation, source_table
    elif isinstance(node, exp.Coalesce):
        coalesce_args = []
        source_columns = []
        if node.this:
            arg_columns, arg_transformation, arg_table = extract_source_columns_and_transformation(node.this, table_name, cte_columns)
            source_columns.extend(arg_columns)
            coalesce_args.append(arg_transformation)
            table_name = arg_table
        for arg in node.expressions:
            arg_columns, arg_transformation, arg_table = extract_source_columns_and_transformation(arg, table_name, cte_columns)
            source_columns.extend(arg_columns)
            coalesce_args.append(arg_transformation)
            if table_name == "unknown":
                table_name = arg_table
        transformation = f"COALESCE({', '.join(coalesce_args)})"
        return source_columns, transformation, table_name
    elif isinstance(node, exp.EQ):
        left_columns, left_transformation, left_table = extract_source_columns_and_transformation(node.this, table_name, cte_columns)
        right_columns, right_transformation, right_table = extract_source_columns_and_transformation(node.expression, table_name, cte_columns)
        source_columns = left_columns + right_columns
        transformation = f"{left_transformation} = {right_transformation}"
        return source_columns, transformation, left_table if left_table != "unknown" else right_table
    elif isinstance(node, exp.Join):
        # Handle JOIN scenarios by recursively extracting the source columns from both sides of the join
        left_columns, left_transformation, left_table = extract_source_columns_and_transformation(node.this, table_name, cte_columns)
        right_columns, right_transformation, right_table = extract_source_columns_and_transformation(node.expression, table_name, cte_columns)
        source_columns = left_columns + right_columns
        transformation = f"JOIN {left_transformation} ON {right_transformation}"
        return source_columns, transformation, left_table if left_table != "unknown" else right_table
    else:
        return [], "", table_name

def extract_final_columns(ast, cte_columns, table_name=None):
    final_columns = []

    # Process the WITH clause if present to extract columns from CTEs
    if isinstance(ast, exp.With):
        for cte in ast.expressions:
            if isinstance(cte, exp.CTE) and isinstance(cte.this, exp.Select):
                cte_name = cte.alias_or_name
                cte_columns[cte_name] = []

                # Look for the real table in the CTE's FROM clause
                from_clause = cte.this.args.get("from")
                real_table_name = None
                if from_clause and isinstance(from_clause.this, exp.Table):
                    real_table_name = from_clause.this.name  # Dynamically get the actual base table name

                for select_exp in cte.this.expressions:
                    if isinstance(select_exp, exp.Alias):
                        alias_name = select_exp.alias_or_name
                        source_columns, transformation, source_table = extract_source_columns_and_transformation(select_exp.this, real_table_name, cte_columns)
                        source_columns = list(set(source_columns))
                        cte_columns[cte_name].append({
                            "Final Column": alias_name,
                            "Source Table": source_table if source_table else real_table_name,
                            "Source Columns": ', '.join(source_columns),
                            "Transformation": transformation
                        })
                    elif isinstance(select_exp, exp.Star):
                        # Handle 'select *' by indicating all columns are selected
                        cte_columns[cte_name].append({
                            "Final Column": "* (all columns)",
                            "Source Table": real_table_name,
                            "Source Columns": "*",
                            "Transformation": "Select all columns"
                        })

    # Process the main SELECT statement to determine final columns
    if isinstance(ast, exp.Select):
        from_clause = ast.args.get("from")
        if from_clause and isinstance(from_clause.this, exp.Table):
            table_name = from_clause.this.name

            # If the table in the main SELECT query matches a CTE, get its columns
            if table_name in cte_columns:
                final_columns.extend(cte_columns[table_name])
            else:
                # If not a CTE, assume it's a regular table
                for select_exp in ast.expressions:
                    if isinstance(select_exp, exp.Alias):
                        alias_name = select_exp.alias_or_name
                        source_columns, transformation, source_table = extract_source_columns_and_transformation(select_exp.this, table_name, cte_columns)
                        source_columns = list(set(source_columns))
                        final_columns.append({
                            "Final Column": alias_name,
                            "Source Table": source_table if source_table else table_name,
                            "Source Columns": ', '.join(source_columns),
                            "Transformation": transformation
                        })
                    elif isinstance(select_exp, exp.Star):
                        final_columns.append({
                            "Final Column": "* (all columns)",
                            "Source Table": table_name,
                            "Source Columns": "*",
                            "Transformation": "Select all columns"
                        })

    return final_columns


def process_with_and_select(ast):
    final_columns = []
    cte_columns = {}

    # If there's a WITH clause, process it to extract the CTEs
    with_clause = ast.args.get("with")
    if with_clause:
        extract_final_columns(with_clause, cte_columns)

    # Process the final SELECT statement
    final_columns.extend(extract_final_columns(ast, cte_columns))

    # Handle case when * is selected but we need to expand it from a CTE
    if any(col["Final Column"] == "* (all columns)" for col in final_columns):
        final_columns = [col for col in final_columns if col["Final Column"] != "* (all columns)"]
        # Expand the '*' using the CTE columns if available
        from_clause = ast.args.get("from")
        if from_clause and isinstance(from_clause.this, exp.Table):
            table_name = from_clause.this.name
            if table_name in cte_columns:
                final_columns.extend(cte_columns[table_name])

    return final_columns


def main():
    # Input SQL query
    sql_query = """
    with
    source as (
        select * from jaffle_shop.raw.raw_products
    ),
    renamed as (
        select
            sku as product_id,
            name as product_name,
            type as product_type,
            description as product_description,
            (price / 100)::numeric(16, 2) as product_price,
            coalesce(type = 'jaffle', false) as is_food_item,
            coalesce(type = 'beverage', false) as is_drink_item
        from source
    )
    select * from renamed
    """

    # Parse the SQL query to get the AST
    try:
        ast = sqlglot.parse_one(sql_query)
        print("SQL parsed successfully.\n")
    except Exception as e:
        print(f"Error parsing SQL: {e}")
        return

    # Process the AST to extract the final columns
    final_columns = process_with_and_select(ast)

    # Output the final columns with their source table and transformations
    if final_columns:
        print("Final columns and their source information:")
        for col in final_columns:
            print(f"Final Column: {col.get('Final Column', '')}")
            print(f"Source Table: {col.get('Source Table', '')}")
            print(f"Source Columns: {col.get('Source Columns', '')}")
            print(f"Transformation: {col.get('Transformation', '')}\n")
    else:
        print("No final columns found or SQL does not contain valid select expressions.")

if __name__ == "__main__":
    main()


SQL parsed successfully.

Final columns and their source information:
Final Column: product_id
Source Table: raw_products
Source Columns: sku
Transformation: sku

Final Column: product_name
Source Table: raw_products
Source Columns: name
Transformation: name

Final Column: product_type
Source Table: raw_products
Source Columns: type
Transformation: type

Final Column: product_description
Source Table: raw_products
Source Columns: description
Transformation: description

Final Column: product_price
Source Table: raw_products
Source Columns: price
Transformation: (price / 100)::DECIMAL(16, 2)

Final Column: is_food_item
Source Table: raw_products
Source Columns: type
Transformation: COALESCE(type = 'jaffle', false)

Final Column: is_drink_item
Source Table: raw_products
Source Columns: type
Transformation: COALESCE(type = 'beverage', false)



## Final Code for both queries

In [23]:
import sqlglot
from sqlglot import expressions as exp

def get_operator_symbol(node):
    operator_map = {
        exp.EQ: '=', exp.NEQ: '!=', exp.GT: '>', exp.GTE: '>=',
        exp.LT: '<', exp.LTE: '<=', exp.Add: '+', exp.Sub: '-',
        exp.Mul: '*', exp.Div: '/',
    }
    return operator_map.get(type(node), 'UNKNOWN_OPERATOR')

def trace_column_lineage(column, table, cte_definitions, visited=None):
    if visited is None:
        visited = set()

    if table in visited:
        return [], table, ''  # Keep table as a string

    visited.add(table)

    if table not in cte_definitions:
        return [column], table, column  # Keep table as a string

    cte_info = cte_definitions[table]
    if column in cte_info:
        source_columns = cte_info[column]['source_columns']
        source_table = cte_info[column]['source_table']
        transformation = cte_info[column]['transformation']
    elif '*' in cte_info:
        source_columns = [column]
        source_table = cte_info['*']['source_table']
        transformation = column
    else:
        return [column], table, column  # Keep table as a string

    final_columns = []
    final_tables = set()
    final_transformations = []

    for src_col in source_columns:
        traced_columns, traced_table, traced_transformation = trace_column_lineage(src_col, source_table, cte_definitions, visited)
        final_columns.extend(traced_columns)
        final_tables.add(traced_table)
        final_transformations.append(traced_transformation)

    return final_columns, ', '.join(final_tables), transformation  # Keep table as a string

def extract_source_columns_and_transformation(node, table_alias_map, cte_definitions, current_table=None):
    if isinstance(node, exp.Column):
        source_column = node.name
        table_name = node.table if node.table else current_table
        if table_name is None:
            table_name = "unknown"
        real_table_name = table_alias_map.get(table_name, table_name)

        source_columns, source_table, transformation = trace_column_lineage(source_column, real_table_name, cte_definitions)
        return source_columns, transformation, source_table

    elif isinstance(node, exp.Identifier):
        return [node.name], node.name, current_table or "unknown"

    elif isinstance(node, exp.Boolean):
        return [], str(node.this).lower(), "constant"

    elif isinstance(node, exp.Literal):
        return [], str(node.this), "constant"

    elif isinstance(node, exp.Paren):
        source_columns, transformation, source_table = extract_source_columns_and_transformation(
            node.this, table_alias_map, cte_definitions, current_table=current_table
        )
        return source_columns, f"({transformation})", source_table

    elif isinstance(node, (exp.EQ, exp.NEQ, exp.GT, exp.GTE, exp.LT, exp.LTE, exp.Add, exp.Sub, exp.Mul, exp.Div)):
        left_columns, left_transformation, left_table = extract_source_columns_and_transformation(
            node.left, table_alias_map, cte_definitions, current_table=current_table
        )
        right_columns, right_transformation, right_table = extract_source_columns_and_transformation(
            node.right, table_alias_map, cte_definitions, current_table=current_table
        )
        operator = get_operator_symbol(node)
        source_tables = set(filter(lambda x: x not in ["unknown", "constant"], [left_table, right_table]))
        source_table = ', '.join(source_tables) if source_tables else current_table or "unknown"
        return left_columns + right_columns, f"{left_transformation} {operator} {right_transformation}", source_table

    elif isinstance(node, (exp.Count, exp.Sum, exp.Min, exp.Max)):
        source_columns, inner_transformation, source_table = extract_source_columns_and_transformation(
            node.this, table_alias_map, cte_definitions, current_table=current_table
        )
        func_name = node.__class__.__name__.upper()
        transformation = f"{func_name}({inner_transformation})"
        if node.args.get("distinct"):
            transformation = f"{func_name}(DISTINCT {inner_transformation})"
        return source_columns, transformation, source_table

    elif isinstance(node, exp.Case):
        case_conditions = []
        source_columns = []
        source_tables = set()
        for when_clause in node.args.get("ifs", []):
            cond_columns, cond_transformation, cond_table = extract_source_columns_and_transformation(
                when_clause.this, table_alias_map, cte_definitions, current_table=current_table
            )
            true_columns, true_transformation, true_table = extract_source_columns_and_transformation(
                when_clause.args['true'], table_alias_map, cte_definitions, current_table=current_table
            )
            source_columns.extend(cond_columns + true_columns)
            source_tables.update([cond_table, true_table])
            case_conditions.append(f"WHEN {cond_transformation} THEN {true_transformation}")

        else_expr = node.args.get('default')
        if else_expr:
            else_columns, else_transformation, else_table = extract_source_columns_and_transformation(
                else_expr, table_alias_map, cte_definitions, current_table=current_table
            )
            source_columns.extend(else_columns)
            source_tables.add(else_table)
            default_case = f"ELSE {else_transformation}"
        else:
            default_case = ""

        transformation = f"CASE {' '.join(case_conditions)} {default_case} END"
        source_tables.discard("unknown")
        source_tables.discard("constant")
        source_table = ', '.join(source_tables) if source_tables else current_table or "unknown"
        return source_columns, transformation, source_table

    elif isinstance(node, exp.Cast):
        source_columns, inner_transformation, source_table = extract_source_columns_and_transformation(
            node.this, table_alias_map, cte_definitions, current_table=current_table
        )
        cast_type = node.args.get('to')
        if cast_type:
            cast_type_str = cast_type.sql()
        else:
            cast_type_str = "UNKNOWN_TYPE"
        transformation = f"{inner_transformation}::{cast_type_str}"
        return source_columns, transformation, source_table

    elif isinstance(node, exp.Coalesce):
        coalesce_args = []
        source_columns = []
        source_tables = set()

        # Process node.this if it exists
        if node.this:
            arg_columns, arg_transformation, arg_table = extract_source_columns_and_transformation(
                node.this, table_alias_map, cte_definitions, current_table=current_table
            )
            source_columns.extend(arg_columns)
            coalesce_args.append(arg_transformation)
            source_tables.add(arg_table)

        # Process node.expressions
        for arg in node.expressions:
            arg_columns, arg_transformation, arg_table = extract_source_columns_and_transformation(
                arg, table_alias_map, cte_definitions, current_table=current_table
            )
            source_columns.extend(arg_columns)
            coalesce_args.append(arg_transformation)
            source_tables.add(arg_table)

        transformation = f"COALESCE({', '.join(coalesce_args)})"
        source_tables.discard("unknown")
        source_tables.discard("constant")
        source_table = ', '.join(source_tables) if source_tables else current_table or "unknown"
        return source_columns, transformation, source_table

    else:
        return [], str(node), current_table or "unknown"

def process_cte(cte, table_alias_map, cte_definitions):
    cte_name = cte.alias
    cte_columns = {}

    if isinstance(cte.this, exp.Select):
        from_clause = cte.this.args.get("from")
        source_table = "unknown"
        if from_clause:
            if isinstance(from_clause.this, exp.Table):
                source_table = from_clause.this.name
                if from_clause.this.db:
                    source_table = f"{from_clause.this.db}.{source_table}"
                if from_clause.this.catalog:
                    source_table = f"{from_clause.this.catalog}.{source_table}"
            elif isinstance(from_clause.this, exp.Subquery):
                # Handle subqueries in FROM clause
                subquery_alias = from_clause.alias_or_name
                process_cte(from_clause, table_alias_map, cte_definitions)
                source_table = subquery_alias
            table_alias_map[cte_name] = source_table

        for select_expr in cte.this.expressions:
            if isinstance(select_expr, exp.Star):
                cte_columns['*'] = {
                    'source_columns': ['*'],
                    'transformation': '*',
                    'source_table': source_table
                }
            elif isinstance(select_expr, exp.Alias):
                alias_name = select_expr.alias
                source_columns, transformation, source_table_expr = extract_source_columns_and_transformation(
                    select_expr.this, table_alias_map, cte_definitions, current_table=source_table
                )
                cte_columns[alias_name] = {
                    'source_columns': source_columns,
                    'transformation': transformation,
                    'source_table': source_table_expr
                }
            elif isinstance(select_expr, exp.Column):
                column_name = select_expr.name
                source_columns, transformation, source_table_expr = extract_source_columns_and_transformation(
                    select_expr, table_alias_map, cte_definitions, current_table=source_table
                )
                cte_columns[column_name] = {
                    'source_columns': source_columns,
                    'transformation': transformation,
                    'source_table': source_table_expr
                }

    cte_definitions[cte_name] = cte_columns

def process_with_and_select(ast):
    table_alias_map = {}
    cte_definitions = {}
    final_columns = []

    # Process CTEs
    with_clause = ast.args.get("with")
    if with_clause:
        for cte in with_clause.expressions:
            process_cte(cte, table_alias_map, cte_definitions)

    # Process main SELECT
    if isinstance(ast, exp.Select):
        from_clause = ast.args.get("from")
        if from_clause and isinstance(from_clause.this, exp.Table):
            main_table = from_clause.this.name
            table_alias_map[main_table] = main_table

        elif from_clause and isinstance(from_clause.this, exp.Subquery):
            # Handle subquery in FROM clause
            subquery_alias = from_clause.alias_or_name
            process_cte(from_clause, table_alias_map, cte_definitions)
            main_table = subquery_alias

        else:
            main_table = "unknown"

        current_table = main_table

        for select_expr in ast.expressions:
            if isinstance(select_expr, exp.Star):
                if main_table in cte_definitions:
                    for col, info in cte_definitions[main_table].items():
                        source_columns, source_table, transformation = trace_column_lineage(col, main_table, cte_definitions)
                        final_columns.append({
                            "Final Column": col,
                            "Source Table": source_table,
                            "Source Columns": source_columns,
                            "Transformation": transformation
                        })
                else:
                    final_columns.append({
                        "Final Column": "*",
                        "Source Table": main_table,
                        "Source Columns": ["*"],
                        "Transformation": "Select all columns"
                    })
            elif isinstance(select_expr, exp.Alias):
                alias_name = select_expr.alias
                source_columns, transformation, source_table_expr = extract_source_columns_and_transformation(
                    select_expr.this, table_alias_map, cte_definitions, current_table=current_table
                )
                final_columns.append({
                    "Final Column": alias_name,
                    "Source Table": source_table_expr,
                    "Source Columns": source_columns,
                    "Transformation": transformation
                })
            elif isinstance(select_expr, exp.Column):
                column_name = select_expr.name
                source_columns, transformation, source_table_expr = extract_source_columns_and_transformation(
                    select_expr, table_alias_map, cte_definitions, current_table=current_table
                )
                final_columns.append({
                    "Final Column": column_name,
                    "Source Table": source_table_expr,
                    "Source Columns": source_columns,
                    "Transformation": transformation
                })

    return final_columns

def main():
    # Prompt the user to enter the SQL query
    sql_query = input("Please enter your SQL query:\n")

    print("\nProcessing SQL Query:\n")
    try:
        ast = sqlglot.parse_one(sql_query)
        print("SQL parsed successfully.\n")
    except Exception as e:
        print(f"Error parsing SQL: {e}")
        return

    final_columns = process_with_and_select(ast)

    if final_columns:
        print("Final columns and their source information:")
        for col in final_columns:
            print(f"Final Column: {col.get('Final Column', 'Unknown')}")
            print(f"Source Table: {col.get('Source Table', 'Unknown')}")
            print(f"Source Columns: {', '.join(col.get('Source Columns', ['Unknown']))}")
            print(f"Transformation: {col.get('Transformation', 'Unknown')}\n")
    else:
        print("No final columns found or SQL does not contain valid select expressions.")

if __name__ == "__main__":
    main()


Please enter your SQL query:
with  customers as (      select * from jaffle_shop.dbt_jmarwaha.stg_customers  ),  orders as (      select * from jaffle_shop.dbt_jmarwaha.orders  ),  customer_orders_summary as (      select         orders.customer_id,          count(distinct orders.order_id) as count_lifetime_orders,         count(distinct orders.order_id) > 1 as is_repeat_buyer,         min(orders.ordered_at) as first_ordered_at,         max(orders.ordered_at) as last_ordered_at,         sum(orders.subtotal) as lifetime_spend_pretax,         sum(orders.tax_paid) as lifetime_tax_paid,         sum(orders.order_total) as lifetime_spend      from orders      group by 1  ),  joined as (      select         customers.*,          customer_orders_summary.count_lifetime_orders,         customer_orders_summary.first_ordered_at,         customer_orders_summary.last_ordered_at,         customer_orders_summary.lifetime_spend_pretax,         customer_orders_summary.lifetime_tax_paid,         customer