In [None]:
import os
import re

import snowflake.snowpark as sp
from sqlglot import exp, parse_one


def get_snowflake_session():
    """Hand back the Snowpark session for this environment.

    ``getOrCreate`` reuses the session the Snowflake environment already
    provides when one exists, so no connection parameters are configured here.
    """
    session_builder = sp.Session.builder
    return session_builder.getOrCreate()



def create_results_table(session):
    """Create PROCESSED_SQL_DATA unless it already exists.

    The table holds, per processed object, its database/schema/table name,
    the original query text and the star-expanded SQL.
    """
    ddl = """
        CREATE TABLE IF NOT EXISTS PROCESSED_SQL_DATA (
            DATABASE_NAME VARCHAR,
            SCHEMA_NAME VARCHAR,
            TABLE_NAME VARCHAR,
            QUERY_TEXT VARCHAR,
            EXPANDED_SQL VARCHAR
        )
    """
    session.sql(ddl).collect()

def fetch_query_text(session):
    """Return every distinct (table_name, query_text) pair from VIEW_SNOWFLAKE_SQL_DATA.

    Each pair is a dict with lowercase keys ``table_name`` and ``query_text``,
    filled from the TABLE_NAME / QUERY_TEXT columns of the result rows.
    """
    query = """
        SELECT DISTINCT table_name, query_text FROM VIEW_SNOWFLAKE_SQL_DATA
    """
    return [
        dict(table_name=r['TABLE_NAME'], query_text=r['QUERY_TEXT'])
        for r in session.sql(query).collect()
    ]

def extract_select_from_ddl(query_text):
    """Pull the SELECT/WITH body out of a ``... AS ( <query> )`` statement.

    The body runs from just after the first ``AS (`` (case-insensitive) up to
    the LAST ``)`` in the text, and is then stripped.

    Args:
        query_text: DDL/DML text as read from the source view. May be
            ``None`` (SQL NULL) or empty.

    Returns:
        The extracted query, or ``None`` when the input is empty, no ``AS (``
        is present, no ``)`` follows it, or the extracted text does not start
        with SELECT/WITH.
    """
    # QUERY_TEXT comes straight from a view column, so NULLs are possible;
    # re.search(None) would raise TypeError and abort the whole batch.
    if not isinstance(query_text, str) or not query_text.strip():
        print("No query text to process.")
        return None

    pattern = re.compile(r"\bAS\s*\(\s*", re.IGNORECASE)
    match = pattern.search(query_text)
    if not match:
        print("No 'AS (' pattern found in the query.")
        return None

    start_pos = match.end()
    # The last ')' closes the wrapper, so parentheses inside the query survive.
    end_pos = query_text.rfind(')')
    if end_pos < start_pos:
        # Covers both "no ')' at all" (-1) and "only ')' before the AS (".
        print("No matching closing parenthesis found.")
        return None

    extracted_sql = query_text[start_pos:end_pos].strip()
    if not extracted_sql.lower().startswith(("select", "with")):
        print("Extracted SQL does not start with 'SELECT' or 'WITH'.")
        return None

    return extracted_sql

def get_table_columns(session, database, schema, table):
    """Return the column names of ``database.schema.table`` in ordinal order.

    Args:
        session: Snowpark session used to run the lookup.
        database: Database whose INFORMATION_SCHEMA is queried. When falsy
            (e.g. ``None`` from an unset environment variable), the
            unqualified INFORMATION_SCHEMA of the session's current database
            is used instead.
        schema: Schema name; ``None`` means the session's current schema.
        table: Table or view name.

    Returns:
        List of COLUMN_NAME values (empty if the object is not found).

    Schema and table are compared case-insensitively: names parsed from SQL
    keep the case they were written in (often lowercase), while Snowflake
    resolves unquoted identifiers to upper case.
    """
    columns_view = (
        f"{database}.INFORMATION_SCHEMA.COLUMNS" if database else "INFORMATION_SCHEMA.COLUMNS"
    )
    # Schema and table are bound (qmark) rather than interpolated, so names
    # containing quotes cannot break or inject into the statement.
    query = f"""
        SELECT COLUMN_NAME
        FROM {columns_view}
        WHERE UPPER(TABLE_SCHEMA) = UPPER(COALESCE(?, CURRENT_SCHEMA()))
          AND UPPER(TABLE_NAME) = UPPER(?)
        ORDER BY ORDINAL_POSITION
    """
    result = session.sql(query, [schema, table]).collect()
    return [row['COLUMN_NAME'] for row in result]

# Generate expanded SQL
def generate_expanded_sql(original_sql, session):
    """Return ``original_sql`` with star projections replaced by explicit columns.

    CTEs are processed in declaration order, then the first SELECT found in
    the tree. In each SELECT:

    - a bare ``*`` becomes the columns of the first table under FROM. These
      come from an earlier CTE's output names when that table is a CTE, and
      from INFORMATION_SCHEMA (via ``get_table_columns``) otherwise;
    - ``qualifier.*`` becomes the columns for that qualifier, resolved the
      same way.

    The output column names of every CTE are recorded so later stars can
    expand against them.

    Args:
        original_sql: SELECT or WITH ... SELECT text, parsed with sqlglot's
            default dialect (no ``read=`` is passed).
        session: Snowpark session used for column lookups.

    Returns:
        The rewritten statement, pretty-printed and upper-cased in full
        (string literals and quoted identifiers included).

    NOTE(review): known limitations, left unchanged here:
    - Only the first ``exp.Table`` under FROM is considered, so a bare ``*``
      over a join expands to one table's columns.
    - ``qualifier.*`` uses the qualifier verbatim; a table alias is not
      resolved back to its table.
    - CTE bodies that are not plain SELECTs (e.g. UNION) are not handled.
    """
    parsed = parse_one(original_sql)
    # CTE alias -> list of output column names exposed by that CTE.
    cte_columns = {}

    def get_fully_qualified_name(table_expression):
        # Join whichever of catalog/db/table are present with dots.
        parts = []
        if table_expression.args.get("catalog"):
            parts.append(table_expression.args["catalog"].name)
        if table_expression.args.get("db"):
            parts.append(table_expression.args["db"].name)
        if table_expression.args.get("this"):
            parts.append(table_expression.args["this"].name)
        return ".".join(parts)

    def output_name(projection):
        # The name a projection exposes to outer queries: its alias, or the
        # bare column name for `t.col`. Unnamed expressions (e.g. an
        # un-aliased function call) fall back to their SQL text.
        return projection.alias_or_name or projection.sql()

    def resolve_source_columns(select_expr):
        # Columns behind a bare `*` in this SELECT. A SELECT without a FROM
        # table (e.g. `SELECT 1 AS x`) has none.
        from_expr = select_expr.args.get("from")
        from_table = from_expr.find(exp.Table) if from_expr else None
        if from_table is None:
            return []
        source_full = get_fully_qualified_name(from_table)
        if source_full in cte_columns:
            return cte_columns[source_full]
        database, schema, table = split_table_name(source_full)
        return get_table_columns(session, database, schema, table)

    def replace_star_in_select(select_expr, current_columns, all_columns):
        new_expressions = []
        for projection in select_expr.expressions:
            if isinstance(projection, exp.Star):
                # Bare `*`: substitute the columns of this SELECT's FROM source.
                for col in current_columns:
                    new_expressions.append(exp.to_identifier(col))
            elif isinstance(projection, exp.Column) and projection.args.get('table'):
                # `qualifier.*` parses as a Column whose name is "*".
                table_name = projection.args['table'].name
                if projection.name == "*":
                    if table_name in all_columns:
                        for col in all_columns[table_name]:
                            new_expressions.append(exp.column(col, table=table_name))
                    else:
                        database, schema, table = split_table_name(table_name)
                        columns = get_table_columns(session, database, schema, table)
                        for col in columns:
                            new_expressions.append(exp.column(col, table=table_name))
                else:
                    new_expressions.append(projection)
            else:
                new_expressions.append(projection)
        select_expr.set("expressions", new_expressions)

    with_expression = parsed.args.get("with")
    if with_expression:
        for cte in with_expression.expressions:
            select_expr = cte.this
            replace_star_in_select(select_expr, resolve_source_columns(select_expr), cte_columns)
            # Record the names the CTE exposes, not its projection SQL:
            # `id AS customer_id` must surface as `customer_id` (and `t.col`
            # as `col`) when a later query selects `*` from this CTE.
            cte_columns[cte.alias] = [output_name(proj) for proj in select_expr.expressions]

    final_select = parsed.find(exp.Select)
    replace_star_in_select(final_select, resolve_source_columns(final_select), cte_columns)
    return parsed.sql(pretty=True).upper()

def split_table_name(full_name):
    """Split a possibly-qualified table name into ``(database, schema, table)``.

    Missing leading parts are filled from the ``database`` / ``schema``
    environment variables (``None`` when unset).

    Args:
        full_name: ``table``, ``schema.table`` or ``database.schema.table``.

    Returns:
        A 3-tuple ``(database, schema, table)``.

    Raises:
        ValueError: if ``full_name`` is empty/None or has more than three
            dot-separated parts.
    """
    if not full_name:
        raise ValueError(f"Invalid table name: {full_name}")
    parts = full_name.split('.')
    if len(parts) == 3:
        return tuple(parts)
    if len(parts) == 2:
        return os.getenv('database'), parts[0], parts[1]
    if len(parts) == 1:
        return os.getenv('database'), os.getenv('schema'), parts[0]
    raise ValueError(f"Invalid table name: {full_name}")

def insert_processed_data(session, database, schema, table, query, expanded_sql):
    """
    Append one row to PROCESSED_SQL_DATA.

    The five values travel as qmark bind parameters instead of being spliced
    into the statement, so quotes inside the SQL text are stored verbatim.
    """
    row_values = (database, schema, table, query, expanded_sql)
    statement = """
        INSERT INTO PROCESSED_SQL_DATA 
        (DATABASE_NAME, SCHEMA_NAME, TABLE_NAME, QUERY_TEXT, EXPANDED_SQL)
        VALUES (?, ?, ?, ?, ?)
    """
    session.sql(statement, row_values).collect()


def main():
    """Rebuild PROCESSED_SQL_DATA from VIEW_SNOWFLAKE_SQL_DATA.

    Steps:
    1. Print the session context.
    2. Ensure the results table exists.
    3. Read the source rows.
    4. Truncate the results table.
    5. For each row, extract the SELECT body, expand its star projections
       and insert the result.

    A failure on one row (unparseable SQL, a bad table name, a failed column
    lookup) is reported and skipped. Because the results table is already
    truncated by then, aborting would leave it only partially rebuilt.
    """
    session = get_snowflake_session()
    result = session.sql("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA(), CURRENT_ROLE()").collect()
    print(result)
    create_results_table(session)

    sql_data = fetch_query_text(session)
    session.sql("TRUNCATE TABLE PROCESSED_SQL_DATA").collect()

    processed, failed = 0, 0
    for record in sql_data:
        table_name = record['table_name']
        query_text = record['query_text']
        try:
            database, schema, table = split_table_name(table_name)
            extracted_sql = extract_select_from_ddl(query_text)
            if not extracted_sql:
                continue
            expanded_sql = generate_expanded_sql(extracted_sql, session)
            insert_processed_data(session, database, schema, table, query_text, expanded_sql)
            processed += 1
        except Exception as e:  # batch boundary: report the row and keep going
            failed += 1
            print(f"Failed to process {table_name}: {e}")

    print(f"Processed {processed} queries; {failed} failed.")


if __name__ == "__main__":
    main()


In [None]:
from snowflake.snowpark.session import Session
from snowflake.snowpark.functions import call_builtin
import json
import re
import logging

# Configure logging for better debugging and visibility
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

# Get the active Snowpark session (in a Snowflake notebook, the session is usually already available)
session = Session.builder.getOrCreate()

# Source of the star-expanded SQL, and destination for the Cortex-derived column lineage.
TABLE_SCHEMA_REF = 'JAFFLE_LINEAGE.LINEAGE_DATA.PROCESSED_SQL_DATA'
COLUMN_LINEAGE_CORTEX = 'JAFFLE_LINEAGE.LINEAGE_DATA.COLUMN_LINEAGE_CORTEX_NEW'

# Model passed to SNOWFLAKE.CORTEX.COMPLETE; change here to try another one.
CORTEX_MODEL = 'mistral-large2'

# Composite-key filter shared by the lookup and the delete. EQUAL_NULL treats
# NULL as equal to NULL, so keys whose database/schema is NULL are still found
# (and replaced) on later runs instead of accumulating duplicates.
KEY_FILTER = (
    "EQUAL_NULL(DATABASE_NAME, ?) AND EQUAL_NULL(SCHEMA_NAME, ?) "
    "AND EQUAL_NULL(TABLE_NAME, ?)"
)

# Prompt sentences are joined with explicit spaces so they don't run together.
LINEAGE_PROMPT = " ".join([
    "You are an expert in SQL lineage analysis.",
    "Given the following SQL query, identify the source tables and columns for each final column in the SELECT statement.",
    "If multiple columns are used to derive the final column, mention all of them separated by commas. Similarly, if these columns come from multiple tables, mention those tables as well, separated by commas.",
    "Additionally, provide simple reasoning in business-friendly language explaining the transformation for each column. A non-technical person should be able to understand it.",
    "Provide the results as a JSON array of objects, where each object has the keys: FINAL_COLUMN, SOURCE_TABLE, SOURCE_DATABASE, SOURCE_SCHEMA, SOURCE_COLUMNS, and REASONING.",
    "SOURCE_TABLE should be returned as a table name only; it should not contain the schema or database name.",
    "SQL Query:",
])


def _key_params(composite_key):
    """Bind values for KEY_FILTER, in placeholder order."""
    return [
        composite_key['DATABASE_NAME'],
        composite_key['SCHEMA_NAME'],
        composite_key['TABLE_NAME'],
    ]


def _call_cortex(sql_query):
    """Return Cortex's raw text answer to the lineage prompt for ``sql_query``."""
    prompt = f"{LINEAGE_PROMPT} {sql_query}"
    # The prompt is bound rather than spliced into a quoted literal. With a
    # quoted literal, quotes need doubling and backslashes are escape
    # characters; binding delivers the SQL to the model unchanged.
    response_rows = session.sql(
        f"SELECT SNOWFLAKE.CORTEX.COMPLETE('{CORTEX_MODEL}', ?) AS LINEAGE_RESPONSE",
        [prompt],
    ).collect()
    return response_rows[0]['LINEAGE_RESPONSE']


def _parse_lineage_response(lineage_response):
    """Extract lineage objects from the model's answer.

    Parses the span from the first ``[`` to the last ``]`` when present,
    otherwise the whole text. A single JSON object is accepted as a
    one-element list; non-object items are dropped. Returns ``None`` (after
    logging) when nothing can be parsed.
    """
    if not isinstance(lineage_response, str):
        logging.error(f"Unexpected Cortex response: {lineage_response!r}")
        return None
    json_match = re.search(r'\[.*\]', lineage_response, re.DOTALL)
    try:
        parsed = json.loads(json_match.group(0) if json_match else lineage_response)
    except json.JSONDecodeError as e:
        logging.error(f"Error parsing JSON response: {e}")
        logging.error(f"Response: {lineage_response}")
        print(f"Error parsing JSON response: {e}")
        return None
    if isinstance(parsed, dict):
        parsed = [parsed]
    if not isinstance(parsed, list):
        logging.error(f"Unexpected JSON type in Cortex response: {type(parsed).__name__}")
        return None
    return [item for item in parsed if isinstance(item, dict)]


def _as_text(value):
    """Render a model-supplied value as text.

    Lists are comma-joined, and a missing/null value becomes 'Unknown'.
    """
    if value is None:
        return 'Unknown'
    if isinstance(value, (list, tuple)):
        return ', '.join(str(item) for item in value)
    return str(value)


def _build_lineage_rows(composite_key, sql_query, parsed_records):
    """Turn parsed lineage objects into rows for COLUMN_LINEAGE_CORTEX.

    Key order is kept as before. NOTE(review): save_as_table appends by
    column position by default, so this order is assumed to match the target
    table's columns. Confirm against its definition.
    """
    return [
        {
            'DATABASE_NAME': composite_key['DATABASE_NAME'],
            'SCHEMA_NAME': composite_key['SCHEMA_NAME'],
            'TABLE_NAME': composite_key['TABLE_NAME'],
            'REFERENCE': None,  # If REFERENCE is needed, you can fetch it from source if available
            'EXPANDED_SQL': sql_query,
            'FINAL_COLUMN': _as_text(record.get('FINAL_COLUMN')),
            'SOURCE_TABLE': _as_text(record.get('SOURCE_TABLE')),
            'SOURCE_DATABASE': _as_text(record.get('SOURCE_DATABASE')),
            'SOURCE_SCHEMA': _as_text(record.get('SOURCE_SCHEMA')),
            'SOURCE_COLUMNS': _as_text(record.get('SOURCE_COLUMNS')),
            'REASONING': _as_text(record.get('REASONING')),
        }
        for record in parsed_records
    ]


# Fetch distinct (database, schema, table, expanded SQL) combinations to process.
df = session.sql(f"""
    SELECT DISTINCT DATABASE_NAME, SCHEMA_NAME, TABLE_NAME, EXPANDED_SQL
    FROM {TABLE_SCHEMA_REF}
    WHERE EXPANDED_SQL IS NOT NULL
""")

rows = df.collect()

for row in rows:
    input_record = row.as_dict()
    sql_query = input_record['EXPANDED_SQL']

    composite_key = {
        'DATABASE_NAME': input_record['DATABASE_NAME'],
        'SCHEMA_NAME': input_record['SCHEMA_NAME'],
        'TABLE_NAME': input_record['TABLE_NAME'],
    }
    key_params = _key_params(composite_key)

    # Does lineage already exist for this key, and was it built from the same SQL?
    existing_records = session.sql(
        f"SELECT EXPANDED_SQL FROM {COLUMN_LINEAGE_CORTEX} WHERE {KEY_FILTER}",
        key_params,
    ).collect()

    if existing_records and existing_records[0]['EXPANDED_SQL'] == sql_query:
        logging.info(f"SQL unchanged for {composite_key}. Skipping record.")
        print(f"SQL unchanged for {composite_key}. Skipping record.")
        continue
    if existing_records:
        logging.info(f"SQL changed for {composite_key}. Replacing old records.")
        print(f"SQL changed for {composite_key}. Replacing old records.")
    else:
        logging.info(f"New record for {composite_key}. Processing.")
        print(f"New record for {composite_key}. Processing.")

    # Generate the new lineage BEFORE touching existing rows, so a failed
    # Cortex call or an unparseable answer leaves the previous lineage intact.
    try:
        lineage_response = _call_cortex(sql_query)
    except Exception as e:
        logging.error(f"Error calling Cortex LLM: {e}")
        print(f"Error calling Cortex LLM: {e}")
        continue  # Skip to the next SQL query

    parsed_records = _parse_lineage_response(lineage_response)
    if not parsed_records:
        if parsed_records is not None:
            logging.warning(f"No lineage objects returned for {composite_key}. Skipping record.")
        continue

    if existing_records:
        try:
            session.sql(
                f"DELETE FROM {COLUMN_LINEAGE_CORTEX} WHERE {KEY_FILTER}",
                key_params,
            ).collect()
        except Exception as e:
            logging.error(f"Error deleting existing records: {e}")
            continue  # Skip to next record

    insert_rows = _build_lineage_rows(composite_key, sql_query, parsed_records)
    try:
        # One DataFrame write per key rather than one per lineage row.
        session.create_dataframe(insert_rows).write.mode('append').save_as_table(COLUMN_LINEAGE_CORTEX)
        logging.info(f"Inserted {len(insert_rows)} lineage records for {composite_key}.")
    except Exception as e:
        logging.error(f"Error inserting records into {COLUMN_LINEAGE_CORTEX}: {e}")
        logging.error(f"Record data: {json.dumps(insert_rows)}")

logging.info("Processing completed.")


In [None]:
import json
import pandas as pd
from snowflake.snowpark import Session
from snowflake.snowpark.functions import col
from snowflake.snowpark.exceptions import SnowparkSQLException

# Get the active Snowpark session (in a Snowflake notebook, the session is usually already available).
# read_lineage_from_snowflake and write_lineage_to_snowflake use this module-level
# `session` directly rather than taking one as a parameter.
session: Session = Session.builder.getOrCreate()

# Function to read data from Snowflake using Snowpark
def read_lineage_from_snowflake(table_name="COLUMN_LINEAGE_CORTEX_NEW"):
    """Load the column-lineage table into pandas, normalised for matching.

    Args:
        table_name: Table to read; defaults to COLUMN_LINEAGE_CORTEX_NEW,
            resolved against the session's current database and schema.

    Returns:
        A DataFrame whose lineage-matching text columns are upper-cased and
        stripped. Listed columns the table does not have are skipped rather
        than raising ``KeyError``.
    """
    df_lineage = session.table(table_name).to_pandas()

    # Upper-case/strip the matching columns so comparisons are case-insensitive.
    # TRANSFORMATION is optional: the Cortex loader writes REASONING, and
    # process_database_lineage reads TRANSFORMATION with row.get(..., '').
    for col_name in ('TABLE_NAME', 'FINAL_COLUMN', 'SOURCE_TABLE', 'SOURCE_COLUMNS',
                     'DATABASE_NAME', 'SCHEMA_NAME', 'TRANSFORMATION'):
        if col_name in df_lineage.columns:
            df_lineage[col_name] = df_lineage[col_name].str.upper().str.strip()

    return df_lineage

# Function to write lineage records to Snowflake
def write_lineage_to_snowflake(lineage_records, table_name='COLUMN_LINEAGE_MERGED'):
    """Replace the contents of ``table_name`` with ``lineage_records``.

    The table is created if missing, emptied, and then loaded through
    ``session.write_pandas``.

    Args:
        lineage_records: list of dicts keyed by the table's column names.
        table_name: target table; defaults to COLUMN_LINEAGE_MERGED.
    """
    # Create first, then truncate. This needs no exception-driven control
    # flow, and a genuine TRUNCATE failure (e.g. a missing privilege) now
    # surfaces instead of being reported as "table does not exist".
    session.sql(f"""
    CREATE TABLE IF NOT EXISTS {table_name} (
        SOURCE_TYPE VARCHAR(255),
        SOURCE_NAME VARCHAR(255),
        SOURCE_TABLE VARCHAR(255),
        SOURCE_DATABASE VARCHAR(255),
        TARGET_TYPE VARCHAR(255),
        TARGET_NAME VARCHAR(255),
        TARGET_TABLE VARCHAR(255),
        TARGET_DATABASE VARCHAR(255),
        TRANSFORMATION VARCHAR(16777216),
        WORKBOOK VARCHAR(255),
        DASHBOARD VARCHAR(255),
        DATASOURCE VARCHAR(255),
        SHEET VARCHAR(255)
    )
    """).collect()
    session.sql(f"TRUNCATE TABLE {table_name}").collect()

    if not lineage_records:
        # A column-less empty DataFrame has nothing to load; skip write_pandas.
        print(f"No lineage records to write; {table_name} left empty.")
        return

    df_records = pd.DataFrame(lineage_records)
    # Cast to object first so every missing value becomes None (SQL NULL);
    # on numeric dtypes where(..., None) may leave NaN in place.
    df_records = df_records.astype(object).where(pd.notnull(df_records), None)

    # Write records to Snowflake using Snowpark
    session.write_pandas(df_records, table_name, auto_create_table=False)

    print(f"Successfully wrote {len(lineage_records)} rows to {table_name}.")

def _json_text(obj, key):
    """Stripped string value of ``obj[key]``; '' when the key is missing or null."""
    return (obj.get(key) or '').strip()


def _json_list(obj, key):
    """List value of ``obj[key]``; [] when the key is missing or null."""
    return obj.get(key) or []


# Function to process fields recursively and generate lineage records
def process_field(field, context, df_lineage, lineage_records, visited_nodes):
    """Emit lineage records for one Tableau field and everything upstream of it.

    - For each entry in ``upstreamFields``: append a Tableau-field ->
      Tableau-field record, then process that field recursively.
    - For each (upstreamColumn, upstreamTable) pair: append a database-column
      -> Tableau-field record, then follow the column's database lineage via
      ``process_database_lineage``.

    Args:
        field: one field dict from the Tableau lineage JSON.
        context: WORKBOOK/DASHBOARD/DATASOURCE/SHEET values for this field.
            It is updated in place with the TARGET_* and TRANSFORMATION keys,
            so callers pass a copy.
        df_lineage: database lineage table, forwarded to process_database_lineage.
        lineage_records: output list, appended to.
        visited_nodes: database node ids already expanded (cycle guard),
            forwarded to process_database_lineage.

    JSON values that are null (presumably e.g. ``formula`` on non-calculated
    fields) are treated like missing keys instead of raising.
    """
    field_name = _json_text(field, 'name')
    formula = _json_text(field, 'formula')
    context['TARGET_NAME'] = field_name
    context['TRANSFORMATION'] = formula
    context['TARGET_TYPE'] = 'Tableau Field'
    context['TARGET_TABLE'] = context.get('SHEET', '')
    context['TARGET_DATABASE'] = ''  # Not applicable for Tableau fields

    # Tableau field -> Tableau field edges, followed recursively.
    for upstream_field in _json_list(field, 'upstreamFields'):
        lineage_records.append({
            'SOURCE_TYPE': 'Tableau Field',
            'SOURCE_NAME': _json_text(upstream_field, 'name'),
            'SOURCE_TABLE': context.get('SHEET', ''),
            'SOURCE_DATABASE': '',
            'TARGET_TYPE': context['TARGET_TYPE'],
            'TARGET_NAME': context['TARGET_NAME'],
            'TARGET_TABLE': context['TARGET_TABLE'],
            'TARGET_DATABASE': '',
            'TRANSFORMATION': context['TRANSFORMATION'],
            'WORKBOOK': context.get('WORKBOOK', ''),
            'DASHBOARD': context.get('DASHBOARD', ''),
            'DATASOURCE': context.get('DATASOURCE', ''),
            'SHEET': context.get('SHEET', ''),
        })
        process_field(upstream_field, context.copy(), df_lineage, lineage_records, visited_nodes)

    # Database column -> Tableau field edges, then the column's own database lineage.
    for upstream_column in _json_list(field, 'upstreamColumns'):
        column_name = _json_text(upstream_column, 'name').upper()
        # The database is read from the column, not the table, so resolve it once.
        databases = _json_list(upstream_column, 'upstreamDatabases')
        database_name = _json_text(databases[0], 'name').upper() if databases else ''
        for upstream_table in _json_list(upstream_column, 'upstreamTables'):
            table_name = _json_text(upstream_table, 'name').upper()
            lineage_records.append({
                'SOURCE_TYPE': 'Database Column',
                'SOURCE_NAME': column_name,
                'SOURCE_TABLE': table_name,
                'SOURCE_DATABASE': database_name,
                'TARGET_TYPE': context['TARGET_TYPE'],
                'TARGET_NAME': context['TARGET_NAME'],
                'TARGET_TABLE': context['TARGET_TABLE'],
                'TARGET_DATABASE': '',
                'TRANSFORMATION': context['TRANSFORMATION'],
                'WORKBOOK': context.get('WORKBOOK', ''),
                'DASHBOARD': context.get('DASHBOARD', ''),
                'DATASOURCE': context.get('DATASOURCE', ''),
                'SHEET': context.get('SHEET', ''),
            })
            process_database_lineage(database_name, table_name, column_name, df_lineage, lineage_records, visited_nodes)

def _split_csv(value):
    """Split a comma-separated cell into stripped, upper-cased, non-empty parts.

    Non-string cells (None/NaN for SQL NULL) yield [].
    """
    if not isinstance(value, str):
        return []
    return [part.strip().upper() for part in value.split(',') if part.strip()]


# Function to process database lineage recursively
def process_database_lineage(database_name, table_name, column_name, df_lineage, lineage_records, visited_nodes):
    """Append column-to-column lineage records upstream of one database column.

    Finds rows of ``df_lineage`` whose DATABASE_NAME/TABLE_NAME/FINAL_COLUMN
    match the given column. It emits one record per (source table, source
    column) pair and recurses into each source column. A node already in
    ``visited_nodes`` is not expanded again, which also stops cycles.

    SOURCE_TABLE and SOURCE_COLUMNS are comma-separated strings. When a row
    lists a single source table, every source column is attributed to it;
    otherwise tables and columns are paired positionally. Source columns are
    assumed to live in the same database as the target.
    """
    node_id = f"{database_name}.{table_name}.{column_name}"
    if node_id in visited_nodes:
        # Avoid infinite recursion due to cycles
        return
    visited_nodes.add(node_id)

    matching_rows = df_lineage[
        (df_lineage['DATABASE_NAME'] == database_name) &
        (df_lineage['TABLE_NAME'] == table_name) &
        (df_lineage['FINAL_COLUMN'] == column_name)
    ]
    for _, row in matching_rows.iterrows():
        source_tables = _split_csv(row['SOURCE_TABLE'])
        source_columns = _split_csv(row['SOURCE_COLUMNS'])
        # One table owning several columns is the common case; positional
        # pairing alone would keep only the first column.
        if len(source_tables) == 1:
            source_tables = source_tables * len(source_columns)
        transformations = row.get('TRANSFORMATION', '')
        for src_table, src_column in zip(source_tables, source_columns):
            lineage_records.append({
                'SOURCE_TYPE': 'Database Column',
                'SOURCE_NAME': src_column,
                'SOURCE_TABLE': src_table,
                'SOURCE_DATABASE': database_name,
                'TARGET_TYPE': 'Database Column',
                'TARGET_NAME': column_name,
                'TARGET_TABLE': table_name,
                'TARGET_DATABASE': database_name,
                'TRANSFORMATION': transformations,
                'WORKBOOK': '',
                'DASHBOARD': '',
                'DATASOURCE': '',
                'SHEET': '',
            })
            # Recursive call for further database lineage
            process_database_lineage(database_name, src_table, src_column, df_lineage, lineage_records, visited_nodes)

# Main function to process the entire Tableau lineage
def process_tableau_lineage(tableau_data, df_lineage):
    """Walk workbooks -> dashboards -> upstream datasources -> sheets -> fields.

    Every field is handed to ``process_field`` with its own copy of a context
    dict carrying the WORKBOOK, DASHBOARD, DATASOURCE and SHEET names it was
    found under. One ``visited_nodes`` set is shared across the whole walk.

    Returns:
        The list of lineage records produced.
    """
    records = []
    seen = set()

    for wb in tableau_data.get('workbooks', []):
        wb_ctx = {'WORKBOOK': wb.get('name', '').strip()}
        for dash in wb.get('dashboards', []):
            dash_ctx = {**wb_ctx, 'DASHBOARD': dash.get('name', '').strip()}
            for ds in dash.get('upstreamDatasources', []):
                ds_ctx = {**dash_ctx, 'DATASOURCE': ds.get('name', '').strip()}
                for sheet in ds.get('sheets', []):
                    sheet_ctx = {**ds_ctx, 'SHEET': sheet.get('name', '').strip()}
                    for fld in sheet.get('upstreamFields', []):
                        process_field(fld, dict(sheet_ctx), df_lineage, records, seen)

    return records

# Function to process all database lineage and generate lineage records
def process_all_database_lineage(df_lineage):
    """Follow database lineage from every distinct (database, table, column) target.

    All starting points share one ``visited_nodes`` set, so a column already
    expanded from an earlier start is not expanded again.

    Returns:
        The accumulated list of lineage records.
    """
    records = []
    seen = set()
    key_cols = ['DATABASE_NAME', 'TABLE_NAME', 'FINAL_COLUMN']
    targets = df_lineage[key_cols].drop_duplicates()
    for db, tbl, column in targets.itertuples(index=False, name=None):
        process_database_lineage(db, tbl, column, df_lineage, records, seen)
    return records


# Main execution function
def main(tableau_json_path='tableau_lineage.json'):
    """Merge Tableau and database lineage and write it via write_lineage_to_snowflake.

    Args:
        tableau_json_path: path to the Tableau lineage export (JSON).
    """
    # JSON is UTF-8 by spec; be explicit rather than relying on the platform default.
    with open(tableau_json_path, 'r', encoding='utf-8') as f:
        tableau_data = json.load(f)

    # Read database lineage from Snowflake
    df_lineage = read_lineage_from_snowflake()

    # Tableau-side lineage (fields -> database columns -> upstream columns)
    tableau_lineage_records = process_tableau_lineage(tableau_data, df_lineage)

    # Database-side lineage for every target column in the lineage table
    database_lineage_records = process_all_database_lineage(df_lineage)

    write_lineage_to_snowflake(tableau_lineage_records + database_lineage_records)


if __name__ == "__main__":
    main()
