In [1]:
from sqlglot import parse_one, exp
from sqlglot.dialects.tsql import TSQL
import pypyodbc as odbc
odbc.lowercase = False
import configparser
import copy
from collections import defaultdict
from collections import OrderedDict 
import pandas as pd
import pypyodbc as odbc
import configparser
import os
import json
import re
from modules.sql_parser.parse_lineages import *
from modules.sql_parser.parse_nodes import *

odbc.lowercase = False

In [2]:
# load data
# Reads every preprocessed query JSON once. `files[k]` is the base name of the
# file whose parsed content is `queries[k]` (both are appended in the same
# iteration, in os.listdir order).

QUERIES_DIR = "data/preprocessed-queries/gpt-sql-queries-extracted3"

queries = [] # queries content list
files = [] # file names list
for query_file in os.listdir(QUERIES_DIR):
    # keep the text before the first dot as the file's short name
    files.append(query_file.split('.')[0])
    # Use a separate name for the handle so the loop variable is not rebound
    # mid-iteration; read as UTF-8 explicitly instead of the platform default.
    with open(f'{QUERIES_DIR}/{query_file}', 'r', encoding='utf-8') as handle:
        queries.append(json.load(handle))

# load nodes dataset
nodes = pd.read_csv('data/output-tables/nodes.csv')
nodes.head()

Unnamed: 0.1,Unnamed: 0,NAME_NODE,LABEL_NODE,WHERE,FUNCTION,ON,COLOR,ID
0,0,subquery1_5,json_data0@subquery1_5,orders.status = 'Completed',subquery,,grey,0
1,1,orders,orders,,table,,grey,1
2,2,subquery1_4,json_data0@subquery1_4,,subquery,,grey,2
3,3,reviews,reviews,,table,,grey,3
4,4,subquery1_3,json_data0@subquery1_3,,subquery,['order_items.product_id = products.product_id...,grey,4


In [3]:
# extract lineages
# For every preprocessed query file: parse each subquery (deepest first) and
# the main query, collect column-level lineage records, resolve node labels to
# the IDs in `nodes`, write one CSV per file and keep the frame in `lineages_dfs`.

QUERIES_DIR = "data/preprocessed-queries/gpt-sql-queries-extracted3"


def _strip_outer_parens(sql):
    """Remove parentheses that wrap the *entire* statement, and only those.

    ``str.strip("()")`` drops every leading/trailing parenthesis, so a wrapped
    query such as ``(SELECT COUNT(x) FROM t WHERE id IN (1, 2))`` would lose the
    ``)`` closing ``IN (...)``. Here a pair is removed only when the opening
    parenthesis at position 0 is matched by the very last character.

    Queries that do not start with ``(`` are returned unchanged.
    NOTE(review): parentheses inside string literals are counted like any
    other; confirm that is acceptable for the queries being processed.
    """
    if not sql.startswith("("):
        return sql
    sql = sql.strip()
    while sql.startswith("(") and sql.endswith(")"):
        depth = 0
        for pos, char in enumerate(sql):
            if char == "(":
                depth += 1
            elif char == ")":
                depth -= 1
                if depth == 0:
                    break
        # the first "(" must close exactly at the end to be a true wrapper
        if depth != 0 or pos != len(sql) - 1:
            break
        sql = sql[1:-1].strip()
    return sql


lineages_dfs = []
trees = []

for query_file in os.listdir(QUERIES_DIR):
    filename = query_file.split('.')[0]
    print(filename)
    with open(f'{QUERIES_DIR}/{query_file}', 'r', encoding='utf-8') as handle:
        data = json.load(handle)

    # Work from the content just loaded rather than re-appending to the global
    # `files`/`queries` lists and indexing them by position: that made the lists
    # grow on every run and relied on an earlier cell listing the directory in
    # the same order.

    # reverse subqueries dict to start from deepest level
    query_subqueries = dict(reversed(list(data['subquery_dictionary'].items())))
    query_subqueries['main_query'] = data['modified_SQL_query']

    lineages = [] # list of dictionaries with the nodes

    for name_query, query in query_subqueries.items():

        query = _strip_outer_parens(query)

        ast = parse_query(query) # get parsed tree

        # if the query is a subquery then the name is the dict key, else the name is the target table
        target_columns = []
        if 'subquery' in name_query:
            target_node = name_query
        else:
            create_statements = list(ast.find_all(exp.Create))
            if create_statements: # create table statement
                target_node = create_statements[0].this.this.this
            else: # else try with insert into table statement
                insert_statements = list(ast.find_all(exp.Insert))
                if not insert_statements:
                    raise ValueError(
                        f"{filename}/{name_query}: expected a CREATE TABLE or INSERT INTO statement"
                    )
                insert_obj = insert_statements[0]
                target_node = insert_obj.this.this
                # each target column is wrapped in its own single-element list
                target_columns = [[column] for column in insert_obj.find_all(exp.Column)]

        space_table = find_table_w_spaces(ast) # list with tables with spaces (sqlglot cant parse them)

        space_table = list(set(space_table)) # a list of tuples with table names paired (space removed original - original ) Eg. (OrderDetails, Order Details)

        alias_table = get_tables(ast) # parse table name + table alias

        tree = replace_aliases(query) # transform query by removing table aliases

        # target columns taken from an INSERT statement win over the extracted ones
        select_statement, extracted_columns = extract_target_columns(tree) # extract target columns
        if target_columns == []:
            target_columns = extracted_columns

        replaced_trees = [statement.transform(transformer_functions) for statement in select_statement] # replace columns aliases
        trees.append(replaced_trees)

        # add possible transformation to columns
        transformations = extract_transformation(replaced_trees)
        target_columns = list(zip(target_columns, transformations))

        lineages = extract_source_target_transformation(target_columns, lineages, space_table, target_node) # append lineages of node to list

    lineages = pd.DataFrame(lineages)

    # one row per source column; reset_index() keeps the pre-explode row number as an 'index' column
    lineages = lineages.explode('SOURCE_COLUMNS').reset_index()

    lineages['FILE_NAME'] = filename
    lineages['ROW_ID'] = 0
    lineages['LINK_VALUE'] = 1

    # "<node>.<field>": everything before the last dot is the node, the rest is the field
    lineages['SOURCE_NODE'] = lineages['SOURCE_COLUMNS'].apply(lambda x: ".".join(x.split('.')[0:-1]))
    lineages['TARGET_NODE'] = lineages['TARGET_COLUMN'].apply(lambda x: ".".join(x.split('.')[0:-1]))

    lineages['SOURCE_FIELD'] = lineages['SOURCE_COLUMNS'].apply(lambda x: x.split('.')[-1])
    lineages['TARGET_FIELD'] = lineages['TARGET_COLUMN'].apply(lambda x: x.split('.')[-1])

    # subquery nodes are prefixed with the file name
    lineages['SOURCE_NODE'] = [f'{filename}@{node}' if 'subquery' in node else node for node in lineages['SOURCE_NODE']]
    lineages['TARGET_NODE'] = [f'{filename}@{node}' if 'subquery' in node else node for node in lineages['TARGET_NODE']]

    lineages['COLOR'] = ["aliceblue" if transformation == "" else "orangered" for transformation in lineages['TRANSFORMATION']]

    # merge source id, then target id: replace each node label with the matching
    # nodes.ID (left join, so labels missing from `nodes` become NaN)
    for node_column in ('SOURCE_NODE', 'TARGET_NODE'):
        lineages = pd.merge(lineages, nodes[['ID', 'LABEL_NODE']], left_on=node_column, right_on='LABEL_NODE', how='left')
        lineages[node_column] = lineages['ID']
        lineages = lineages.drop(columns=['ID', 'LABEL_NODE'])

    lineages = lineages.drop_duplicates(subset=['SOURCE_COLUMNS', 'TARGET_COLUMN', 'TRANSFORMATION']).reset_index(drop=True)

    # target_node is the one set by the last query processed for this file (the main query)
    lineages.to_csv(f"data/output-tables/lineages/lineage-{target_node}.csv")
    lineages_dfs.append(lineages)

json_data0
json_data1
json_data2
json_data3
json_data4


In [4]:
# Display the lineage table that was appended first to `lineages_dfs`.
lineages_dfs[0]

Unnamed: 0,index,SOURCE_COLUMNS,TARGET_COLUMN,TRANSFORMATION,FILE_NAME,ROW_ID,LINK_VALUE,SOURCE_NODE,TARGET_NODE,SOURCE_FIELD,TARGET_FIELD,COLOR
0,0,orders.customer_id,subquery1_5.customer_id,,json_data0,0,1,1.0,0,customer_id,customer_id,aliceblue
1,1,orders.total_amount,subquery1_5.total_spent,SUM(total_amount),json_data0,0,1,1.0,0,total_amount,total_spent,orangered
2,2,reviews.product_id,subquery1_4.product_id,,json_data0,0,1,3.0,2,product_id,product_id,aliceblue
3,3,reviews.rating,subquery1_4.avg_rating,AVG(rating),json_data0,0,1,3.0,2,rating,avg_rating,orangered
4,4,order_items.order_id,subquery1_3.order_id,,json_data0,0,1,5.0,4,order_id,order_id,aliceblue
5,5,products.product_name,subquery1_3.product_name,,json_data0,0,1,6.0,4,product_name,product_name,aliceblue
6,6,categories.category_name,subquery1_3.category_name,,json_data0,0,1,,4,category_name,category_name,aliceblue
7,7,order_items.quantity,subquery1_3.quantity,,json_data0,0,1,5.0,4,quantity,quantity,aliceblue
8,8,order_items.unit_price,subquery1_3.unit_price,,json_data0,0,1,5.0,4,unit_price,unit_price,aliceblue
9,9,order_items.quantity,subquery1_3.total_item_amount,(quantity * unit_price),json_data0,0,1,5.0,4,quantity,total_item_amount,orangered


In [11]:
# Number of lineage tables collected (one is appended per processed file).
len(lineages_dfs)

5

In [12]:
# Look at the TRANSFORMATION value stored under index label 30 of the first lineage table.
lineages_dfs[0]['TRANSFORMATION'][30] 

# TODO: for INSERT INTO statements, the target node should also include the new columns.
# TODO: multiplying two columns does not count as a transformation; fix that.

'CAST(avg_rating AS NUMERIC(3, 2))'

In [13]:
# TODO: parse `LIMIT 1`.
# TODO: fix the source columns and the target node for INSERT INTO statements.
