In [23]:
import os
import pandas as pd
import ast
import numpy as np

In [24]:

def calculate_grade(strings_list: list, dataframe: pd.DataFrame) -> list:
    """
    Returns a list of integers with the complexity grade of the input dataframes transformations
    """
    results = []
    substrings_data = {'Transformation': [], 'Occurrences': []}  # Initialize an empty dictionary to store data

    for string in strings_list:
        total_grade = 0
        for index, row in dataframe.iterrows():
            substring = row['General_Alias']
            grade = row['Complexity_Grade']
            occurrences = string.count(substring)
            total_grade += occurrences * grade
            
            # Append to the substrings_data dictionary if occurrences are not 0
            if occurrences != 0:
                substrings_data['Transformation'].append(substring)
                substrings_data['Occurrences'].append(occurrences)
        
        results.append(total_grade)
    
    # Create a DataFrame from the substrings_data dictionary
    substrings_df = pd.DataFrame(substrings_data)
    
    return results, substrings_df
 


In [25]:
# --------------- block code to create the files ready for sankey diagram data sources -> calcviews ------------------


functions_score = pd.read_excel("data/functions_score.xlsx")[["General_Alias", "Complexity_Grade"]]
functions_score = functions_score.drop_duplicates()
DIR = "data/output-tables/lineages"
save_DIR = "report/data"   
list_files = os.listdir(DIR)



try:
    list_files.remove('lineage-merged.csv')
except:
    pass

df_labels = pd.read_csv("data/output-tables/nodes.csv", sep = ',')
#df = pd.read_csv(f"{DIR}/{list_files[0]}", sep = ',')
df_labels = df_labels.dropna(subset=['FUNCTION'])
Sources = pd.DataFrame()

# Iterate over rows of the DataFrame
for index, row in df_labels.iterrows():
    # Check if '@' is not in the 'LABEL_NODE' column of the current row
    if 'table' in row['FUNCTION']:
        # Extract only the desired columns and rename 'FILTER' to 'COUNT'
        filtered_row = row[['LABEL_NODE', 'ID', 'WHERE']].rename({'WHERE': 'COUNT'})
        
        # Append the filtered row to the empty DataFrame
        Sources = pd.concat([Sources, filtered_row.to_frame().transpose()], ignore_index=True)
Sources['COUNT'] = 0
info_calc = {}

Sources

Unnamed: 0,LABEL_NODE,ID,COUNT
0,orders,1,0
1,reviews,3,0
2,order_items,5,0
3,products,6,0
4,payments,8,0
5,customers,11,0
6,loans,13,0
7,transactions,15,0
8,accounts,17,0
9,market_data,19,0


In [26]:
#this takes into account all the xlsx in the folder which are all the tech lineages from all the calc views
for files in list_files:
    df = pd.read_csv(f"{DIR}/{files}", sep = ',')
    source_ids = set(Sources['ID'])
    
    # Iterate over rows in df['SOURCE_NODE']
    for source_id in df['SOURCE_NODE']:
        # Check if the source_id exists in source_ids and has not been counted before
        if source_id in source_ids:
            # Find the index of the ID in Sources
            idx = Sources.index[Sources['ID'] == source_id] # !!! ID
            # Increment the count for the ID by one
            Sources.loc[idx, 'COUNT'] += 1
            # Remove the ID from source_ids to ensure it's only counted once
            source_ids.remove(source_id)

    
    # List of unique nodes
    nodes = list(set(df['TARGET_NODE']) | set(df['SOURCE_NODE']))
    
    # Filter label_nodes and function_nodes based on matching IDs
    label_nodes = df_labels[df_labels['ID'].isin(nodes)][['ID', 'LABEL_NODE']].rename(columns={'LABEL_NODE': 'LABEL_NODE'})
    function_nodes = df_labels[df_labels['ID'].isin(nodes)][['ID', 'FUNCTION']].rename(columns={'FUNCTION': 'FUNCTION'})
    
    # Count occurrences of each node in TARGET_NODE and SOURCE_NODE columns
    target_nodes = df['TARGET_NODE'].value_counts().reset_index().rename(columns={'TARGET_NODE': 'ID', 'count': 'TARGET_COUNT'})
    source_nodes = df['SOURCE_NODE'].value_counts().reset_index().rename(columns={'SOURCE_NODE': 'ID', 'count': 'SOURCE_COUNT'})
    
    # Merge label_nodes and function_nodes on 'ID'
    label_function_nodes = pd.merge(label_nodes, function_nodes, on='ID', how='outer')
    
    # Merge label_function_nodes, target_nodes, and source_nodes on 'ID'
    result = pd.merge(label_function_nodes, target_nodes, left_on='ID', right_on='ID', how='outer')
    result = pd.merge(result, source_nodes, left_on='ID', right_on='ID', how='outer')
    result['TARGET_COUNT'] = result['TARGET_COUNT'].fillna(0)
    result['SOURCE_COUNT'] = result['SOURCE_COUNT'].fillna(0)
    
    info_calc[files.split('-')[1].split('.')[0]] = result
    """ this first part is to calcualtion how much a node is used as a source or a target node based on the columns which are fed into or arise from the node
    """ 
calc_views = list(info_calc.keys())

filtered_data = []
for key, df in info_calc.items():
    """ part to get the sources which source feed data into the calc view. This could also be other caluculation views"""
    filtered_df = df[df['FUNCTION'] == 'table']
    if not filtered_df.empty:
        label_nodes = filtered_df['LABEL_NODE'].tolist()
        filtered_data.append({'CALC_VIEW': key, 'SOURCE': label_nodes})

filtered_data

[{'CALC_VIEW': 'CUSTOMER_BANK_DETAILS',
  'SOURCE': ['payments',
   'customers',
   'loans',
   'transactions',
   'accounts',
   'branches',
   'account_types']},
 {'CALC_VIEW': 'CUSTOMER_ORDER',
  'SOURCE': ['orders',
   'reviews',
   'order_items',
   'products',
   'payments',
   'customers',
   'categories']},
 {'CALC_VIEW': 'CUSTOMER_SUBSCRIPTION_DETAILS',
  'SOURCE': ['payments',
   'customers',
   'subscription_reviews',
   'subscriptions',
   'subscription_plans']},
 {'CALC_VIEW': 'INVESTOR_OVERVIEW',
  'SOURCE': ['market_data',
   'trades',
   'dividends',
   'stocks',
   'investors',
   'portfolios']},
 {'CALC_VIEW': 'VENDOR_PERFORMANCE_ANALYSIS',
  'SOURCE': ['products',
   'deliveries',
   'product_reviews',
   'delivery_items',
   'vendors',
   'categories',
   'd']}]

In [27]:

"""this part is to make the csv files which are used for the sankey where sources are coupled to the calc views"""
# Create a new dataframe from the filtered data
result_df = pd.DataFrame(filtered_data)
result_df = result_df.explode('SOURCE').reset_index(drop=True)
Nodes_source = list(np.unique(result_df['CALC_VIEW']))
Nodes_source.extend(np.unique(result_df['SOURCE']))
Nodes_source = pd.DataFrame(Nodes_source, columns=['Name'])
result_df['CALC_ID'],result_df['SOURCE_ID'],result_df['LINK_VALUE'],result_df['COLOR'] = 0,0,1,'aliceblue'

for i in range(len(result_df)):
    for j in range(len(Nodes_source)):
        if result_df.at[i, 'CALC_VIEW'] == Nodes_source.at[j, 'Name']:
            result_df.at[i, 'CALC_ID'] = j
        elif result_df.at[i, 'SOURCE'] == Nodes_source.at[j, 'Name']:
            result_df.at[i, 'SOURCE_ID'] = j

            
result_df.to_csv("data/output-tables/analysis/lineage_calc_source.csv", index = False)
Nodes_source.to_csv("data/output-tables/analysis/nodes_calc_source.csv", index = True)

In [28]:

""" these table names correspond with the table names in the report creation file"""
trans_data = pd.DataFrame(columns=['Transformation', 'Occurrences'])

table11 = pd.DataFrame(columns=['Calculation view','Number of nodes', 'Number of transformations', 'Number of filters'])

# --------------- block code to create technical lineages for a calculation view ------------------

table212 = pd.DataFrame(columns=['Calculation view','Node', 'Transformation', 'Complexity Score'])
table2122 = pd.DataFrame(columns=['Calculation view', 'Transformation count', 'Summation complexity Score'])
table211 = pd.DataFrame(columns=['Calculation view','Node', 'Transformation', 'Complexity Score'])

for files in list_files:
    df = pd.read_csv(f"{DIR}/{files}", sep = ',')
    #df = pd.read_csv(f"{DIR}/{list_files[8]}", sep = ',')
    nodes = list(set(df['TARGET_NODE']) | set(df['SOURCE_NODE']))
    
    # Filter label_nodes and function_nodes based on matching IDs
    label_nodes = df_labels[df_labels['ID'].isin(nodes)][['ID', 'LABEL_NODE', 'FUNCTION', 'ON','WHERE']].rename(columns={'LABEL_NODE': 'LABEL_NODE', 'FUNCTION': 'FUNCTION', 'ON': 'ON', 'WHERE': 'WHERE'})
    label_nodes = label_nodes.reset_index(drop=True)
    
    Data = df[['SOURCE_NODE','SOURCE_FIELD','TARGET_NODE','TARGET_FIELD','TRANSFORMATION']].copy()
    sub_join = list(label_nodes['ON'])
    
    sub_join = [ast.literal_eval(item) if isinstance(item, str) else item for item in sub_join]
    # Iterate through the list and add 'LABEL_NODE' to dictionaries
    for i in range(len(sub_join)):
        if isinstance(sub_join[i], dict):
            print(i)
            # Extract the corresponding 'LABEL_NODE' from label_nodes DataFrame based on its index
            label_node = label_nodes.iloc[i]['LABEL_NODE']
            sub_join[i]['LABEL_NODE'] = label_node
    sub_join = [item for item in sub_join if isinstance(item, dict)]
    list_filters = []

    print(label_nodes['WHERE'])
    for filter_value, label_node in zip(label_nodes['WHERE'], label_nodes['LABEL_NODE']):
        if isinstance(filter_value, str):
            list_filters.append({'filter': filter_value, 'LABEL_NODE': label_node, 'Field' : filter_value})#.split('"')[1]})
    
    
    Data["ON"] = np.nan
    Data["WHERE"] = np.nan
    
    for i in range(len(Data)):
        for j in range(len(label_nodes)):
            if Data.at[i, 'SOURCE_NODE'] == label_nodes.at[j, 'ID']:
                Data.at[i, 'SOURCE_NODE'] = label_nodes.at[j, 'LABEL_NODE']
            if Data.at[i, 'TARGET_NODE'] == label_nodes.at[j, 'ID']:
                Data.at[i, 'TARGET_NODE'] = label_nodes.at[j, 'LABEL_NODE']
        for k in range(len(sub_join)):
            if 'LABEL_NODE' in sub_join[k] and sub_join[k]['LABEL_NODE'] == Data.at[i, 'TARGET_NODE'] and sub_join[k].get('JoinVariable') == Data.at[i, 'TARGET_FIELD']:
                updated_dict = sub_join[k].copy()  # Make a copy of the dictionary
                updated_dict.pop('LABEL_NODE', None)  # Remove 'LABEL_NODE' key
                Data.at[i, 'ON'] = str(updated_dict)
        for l in range(len(list_filters)):
            if list_filters[l]['LABEL_NODE'] == Data.at[i, 'TARGET_NODE'] and list_filters[l].get('Field') == Data.at[i, 'TARGET_FIELD']: 
                Data.at[i, 'WHERE'] = list_filters[l].get('filter')
    
    strings_list = list(Data["TRANSFORMATION"])
    strings_list = [str(x) for x in strings_list]
    grades_list, substrings_df = calculate_grade(strings_list, functions_score)
    substrings_df = substrings_df.groupby('Transformation', as_index=False).sum().reset_index(drop=True)
    Data["Complexity_Score"] = grades_list
    #Data = Data.dropna(subset=['JOIN_ARGU', 'FILTER', 'TRANSFORMATION'], how='all')
    trans_data = pd.concat([trans_data, substrings_df])


    #------------------- block for aggregation ---------------------------
    
    unique_trans = Data['TRANSFORMATION'].dropna().nunique()
    unique_filter = Data['WHERE'].dropna().nunique()
    function_counts = label_nodes["FUNCTION"].value_counts()
    Data_final_tech = Data.dropna(subset=['ON', 'WHERE', 'TRANSFORMATION'], how='all')
    
    
    Data_final = Data.dropna(subset=['TRANSFORMATION'], how='all')
    
    # Drop duplicate rows based on selected columns
    filtered_df = Data_final.drop_duplicates(subset=['SOURCE_NODE', 'TRANSFORMATION']).reset_index(drop=True)
    for index, row in filtered_df.iterrows():
        if str(row["SOURCE_FIELD"]) in str(row["SOURCE_NODE"]):
            filtered_df.drop(index=index, inplace=True)
    filtered_df = filtered_df.reset_index(drop=True)   
    filtered_df = filtered_df[['SOURCE_NODE', 'TRANSFORMATION', 'Complexity_Score']]
    filtered_df = filtered_df.rename(columns={'SOURCE_NODE': "Node", 'TRANSFORMATION' : 'Transformation', 'Complexity_Score' : 'Complexity Score'})
    filtered_df['Calculation view'] = files.split('-')[1].split('.')[0]
    filtered_df = filtered_df.reindex(columns=['Calculation view', 'Node', 'Transformation', 'Complexity Score'])
    
    table212 = pd.concat([table212, filtered_df])
    filtered_df = filtered_df.sort_values(by='Complexity Score', ascending=False).head(1)
    table211 = pd.concat([table211, filtered_df])
    
    temp = {'Calculation view': files.split('-')[1].split('.')[0],'Number of nodes' : sum(function_counts), 'Number of transformations' : unique_trans, 'Number of filters' : unique_filter}
    table11.loc[len(table11)] = temp
    calc_scores =  Data[Data['TRANSFORMATION'].notna()].drop_duplicates(subset='TRANSFORMATION').reset_index(drop=True)
    temp = {'Calculation view' : files.split('-')[1].split('.')[0], 'Transformation count' : unique_trans, 'Summation complexity Score' : sum(calc_scores["Complexity_Score"])}
    table2122.loc[len(table2122)] = temp
    if files == 'lineage-Q_AccountsPayable.csv':
        substrings_df.to_csv(f"{save_DIR}/substrings_df.csv",index = False)
        Data_final_tech.to_csv(f"{save_DIR}/Data_final_tech .csv",index = False)
        Account_payable_tech_lineage = Data
table212 = table212.sort_values(by='Complexity Score', ascending=False).head(5).reset_index(drop=True)  
table2122 = table2122.sort_values(by='Summation complexity Score', ascending=False).head(5).reset_index(drop=True)  

trans_data = trans_data.groupby('Transformation', as_index=False).sum().reset_index(drop=True)   
table112 = Sources.sort_values(by='COUNT',ascending=False) 
table112 = table112.drop(columns=['ID'], axis=1).head(5).reset_index(drop=True)
#print(sorted_sources)
calc_names = []
table22 = pd.DataFrame(columns=['Calculation view','Input calculation view'])
table31 = pd.DataFrame(columns=['Calculation view','Data source','Columns used','Columns in source','Percentage columns used'])
columns_tables = pd.read_excel("data/Columns_sources.xlsx").dropna().reset_index(drop=True)
for files in list_files:
    calc_names.append(files.split('-')[1].split('.')[0])
for i in info_calc.keys():
    for j in range(len(info_calc[i])):
        if info_calc[i]['LABEL_NODE'][j] in calc_names:
            temp = {'Calculation view': i,'Input calculation view' : info_calc[i]['LABEL_NODE'][j]}
            table22.loc[len(table22)] = temp
        for k in range(len(columns_tables)):  
            if info_calc[i]['LABEL_NODE'][j] == columns_tables['LABEL_NODE'][k] and info_calc[i]['FUNCTION'][j] == "DataSources":
                temp = {'Calculation view': i,'Data source' : info_calc[i]['LABEL_NODE'][j], 'Columns used' : info_calc[i]['SOURCE_COUNT'][j], 'Columns in source' :  columns_tables['COUNT'][k], "Percentage columns used" : info_calc[i]['SOURCE_COUNT'][j]/columns_tables['COUNT'][k]}
                table31.loc[len(table31)] = temp
    
table113 = trans_data.sort_values(by='Occurrences', ascending=False).head(5).reset_index(drop=True)



table11.to_csv(f"{save_DIR}/table11.csv",index = False)
table112.to_csv(f"{save_DIR}/table112.csv",index = False)
table113.to_csv(f"{save_DIR}/table113.csv",index = False)
table2122.to_csv(f"{save_DIR}/table2122.csv",index = False)
table211.to_csv(f"{save_DIR}/table211.csv",index = False)
table212.to_csv(f"{save_DIR}/table212.csv",index = False)
table22.to_csv(f"{save_DIR}/table22.csv",index = False)
table31.to_csv(f"{save_DIR}/table31.csv",index = False)

0                                                  NaN
1                                                  NaN
2         COMPARE(loans.end_date, CURRENT_TIMESTAMP())
3                                                  NaN
4    COMPARE(transactions.transaction_date, DATETIM...
5                                                  NaN
6    COMPARE(customers.join_date, DATETIME_ADDITION...
7                                                  NaN
8                                                  NaN
9                                                  NaN
Name: WHERE, dtype: object
0                        EQ(orders.status, 'Completed')
1                                                   NaN
2                                                   NaN
3                                                   NaN
4                                                   NaN
5                                                   NaN
6                                                   NaN
7                              

  Data.at[i, 'SOURCE_NODE'] = label_nodes.at[j, 'LABEL_NODE']
  Data.at[i, 'TARGET_NODE'] = label_nodes.at[j, 'LABEL_NODE']
  Data.at[i, 'TARGET_NODE'] = label_nodes.at[j, 'LABEL_NODE']
  Data.at[i, 'SOURCE_NODE'] = label_nodes.at[j, 'LABEL_NODE']


0                                                  NaN
1                                                  NaN
2                                                  NaN
3                                                  NaN
4    COMPARE(payments.payment_date, DATETIME_ADDITI...
5                                                  NaN
6    COMPARE(customers.signup_date, DATETIME_ADDITI...
7                                                  NaN
8                                                  NaN
Name: WHERE, dtype: object
0                                                   NaN
1                                                   NaN
2                                                   NaN
3                                                   NaN
4     COMPARE(market_data.market_date, DATETIME_ADDI...
5     COMPARE(dividends.dividend_date, DATETIME_ADDI...
6                                                   NaN
7     COMPARE(trades.trade_date, DATETIME_ADDITION('...
8                             

  Data.at[i, 'TARGET_NODE'] = label_nodes.at[j, 'LABEL_NODE']
  Data.at[i, 'SOURCE_NODE'] = label_nodes.at[j, 'LABEL_NODE']
  Data.at[i, 'TARGET_NODE'] = label_nodes.at[j, 'LABEL_NODE']
  Data.at[i, 'SOURCE_NODE'] = label_nodes.at[j, 'LABEL_NODE']


0                                                   NaN
1                                                   NaN
2                                                   NaN
3                                                   NaN
4                                                   NaN
5                                                   NaN
6                                                   NaN
7     COMPARE(deliveries.delivery_date, DATETIME_ADD...
8     COMPARE(vendors.contract_start_date, DATETIME_...
9                                                   NaN
10                                                  NaN
11                                                  NaN
Name: WHERE, dtype: object


  Data.at[i, 'TARGET_NODE'] = label_nodes.at[j, 'LABEL_NODE']
  Data.at[i, 'SOURCE_NODE'] = label_nodes.at[j, 'LABEL_NODE']


In [31]:
from sqlglot import parse_one, exp
from sqlglot.dialects.ma import MA
from sqlglot.dialects.tsql import TSQL
import sqlglot
import copy
import os
import json
import re


def open_query(dir:str) -> list:
    """
    Open sql queries from one text file
    """ 
    with open(dir, 'r') as file: 
        file = file.read().strip().split(';')
        sql_queries = [re.sub(r'\s+', ' ', query.strip().replace('\n', ' ').replace('\t', ' ')) for query in file if query.strip()]
    return sql_queries

def transformer_functions(node):
    """
    Replaces column objects within the functions with simple column names
    """
    if isinstance(node, exp.Column):
        return parse_one(node.name)
    return node

def extract_subqueries(ast: sqlglot.expressions) -> dict:
    """
    Extract all subqueries from a query and saves them in a dictionary with structured format
    """
    count = 0
    selects = list(ast.find_all(exp.Select))  

    selects = [select for select in selects] # problem with tsql conversion: column aliases get dropped in case of no transformation
    nested_subqueries = {}

    for i, subquery_i in enumerate(selects):
        for j, subquery_j in enumerate(selects):
    

            if str(subquery_j) in str(subquery_i) and j != i:                      
                count +=1
                print(subquery_i.find(subquery_j))#.replace(sqlglot.exp.Table(this="subquery_x"))

                subquery_i = subquery_i.replace(subquery_j, f"subquery_{count}")
                nested_subqueries[f"subquery_{i}"] = subquery_i
                nested_subqueries[f"subquery_{j}"] = subquery_j

            #elif i!=j and len(subquery_j) < len(subquery_i)-30:
             #   pass
                #print(subquery_j)
                #print(subquery_i) 
                #print()
                    

    #main_query = nested_subqueries['subquery_0']

    subqueries = {k: v for k, v in nested_subqueries.items() if k != 'subquery_0'}
    print(count)

    return subqueries


def replace_nested_subqueries_in_subqueries(subqueries: dict) -> dict:
    """
    Replace nested subqueries in subqueries with the key
    """
        
    for i, query_i in subqueries.items():
        for j, query_j in subqueries.items():
      

            if j in query_i:
                subqueries[i] = subqueries[i].replace(j, query_j)
    return subqueries


def replace_subqueries_in_mainquery(ast: sqlglot.expressions, subqueries_global:dict) -> str:
    """
    Replace nested subqueries in main_query with the key
    """
    subqueries = subqueries_global.copy()
    try:
        main_query = list(repr((ast.find_all(exp.Create))))[0]
    except:
        print('error')
        #main_query = str(list(ast.find_all(exp.Insert))[0])

    subqueries = replace_nested_subqueries_in_subqueries(subqueries)

    for key_i, value_i in subqueries.items():
            
        main_query = main_query.replace(f"({str(value_i)})", str(key_i))
        main_query = main_query.replace(value_i, str(key_i))

    return main_query


#def replace_subqueries_in_mainquery(main_query: str, subqueries_global:dict) -> str:
#    """
#    Replace nested subqueries in main_query with the key
#    """
#    subqueries = subqueries_global.copy()
#
#    subqueries = replace_nested_subqueries_in_subqueries(subqueries)
#
#    for key_i, value_i in subqueries.items():
#            
#        main_query = main_query.replace(f"({str(value_i)})", str(key_i))
#        main_query = main_query.replace(value_i, str(key_i))
#
#    return main_query


def save_preprocessed_query(preprocessed_query, idx):
    filename = f'data/preprocessed-queries/json_data{idx}.json'
    with open(filename, 'w') as json_file:
        json.dump(preprocessed_query, json_file, indent=4)


def preprocess_queries(dir:str) -> dict:
    """
    Orchestrates the preprocessing and extraction of the SQL queries
    """
    preprocessed_queries = []
    sql_queries = open_query(dir)
    for i, query in enumerate(sql_queries):
        ast = parse_one(query, read="tsql")
        subqueries = extract_subqueries(ast)
        main_query = replace_subqueries_in_mainquery(ast, subqueries)
        #main_query = replace_subqueries_in_mainquery(query, subqueries)

        preprocessed_query = {'modified_SQL_query': main_query, 'subquery_dictionary': subqueries}
        save_preprocessed_query(preprocessed_query, i)
        preprocessed_queries.append(preprocessed_query)

    return preprocessed_queries



preprocessed_queries = preprocess_queries('data/queries-txts/TEST.txt') # 'data/queries-txts/WorldWideImporters 1.txt'

preprocessed_queries


0


TypeError: isinstance() arg 2 must be a type, a tuple of types, or a union

In [45]:
from sqlglot import parse_one, exp
from sqlglot.dialects.ma import MA
from sqlglot.dialects.tsql import TSQL
import sqlglot
import copy
import os
import json
import re

query = """CREATE VIEW INVESTOR_OVERVIEW AS
SELECT
    i.investor_id,
    i.first_name,
    i.last_name,
    i.email,
    p.portfolio_id,
    p.portfolio_name,
    recent_trades.trade_id,
    recent_trades.trade_date,
    recent_trades.ticker,
    recent_trades.company_name,
    recent_trades.trade_type,
    recent_trades.quantity,
    recent_trades.price_per_share,
    recent_trades.total_trade_value,
    CAST(dividends_received.total_dividends AS DECIMAL(10,2)) AS total_dividends,
    CAST(average_performance.avg_performance AS DECIMAL(10,2)) AS average_performance,
    CAST(total_investment.total_value AS DECIMAL(10,2)) AS total_investment_value
FROM
    investors i
JOIN portfolios p ON i.investor_id = p.investor_id
JOIN (
    SELECT
        t.trade_id,
        t.portfolio_id,
        s.ticker,
        s.company_name,
        t.trade_date,
        t.trade_type,
        t.quantity,
        t.price_per_share,
        (t.quantity * t.price_per_share) AS total_trade_value
    FROM
        trades t
    JOIN stocks s ON t.stock_id = s.stock_id
    WHERE
        t.trade_date >= DATEADD(month, -1, GETDATE())
) recent_trades ON p.portfolio_id = recent_trades.portfolio_id
JOIN (
    SELECT
        t.portfolio_id,
        SUM(d.dividend_amount * t.quantity) AS total_dividends
    FROM
        trades t
    JOIN dividends d ON t.stock_id = d.stock_id
    WHERE
        d.dividend_date >= DATEADD(year, -1, GETDATE())
    GROUP BY
        t.portfolio_id
) dividends_received ON p.portfolio_id = dividends_received.portfolio_id
JOIN (
    SELECT
        t.portfolio_id,
        AVG(md.closing_price) AS avg_performance
    FROM
        trades t
    JOIN market_data md ON t.stock_id = md.stock_id
    WHERE
        md.market_date >= DATEADD(year, -1, GETDATE())
    GROUP BY
        t.portfolio_id
) average_performance ON p.portfolio_id = average_performance.portfolio_id
JOIN (
    SELECT
        t.portfolio_id,
        SUM(t.quantity * md.closing_price) AS total_value
    FROM
        trades t
    JOIN market_data md ON t.stock_id = md.stock_id
    WHERE
        md.market_date = (SELECT MAX(market_date) FROM market_data)
    GROUP BY
        t.portfolio_id
) total_investment ON p.portfolio_id = total_investment.portfolio_id
WHERE
    i.join_date <= DATEADD(year, -1, GETDATE())
ORDER BY
    i.investor_id,
    p.portfolio_id;
"""


import sqlglot

def extract_subqueries(query: str) -> dict:
    """
    Extract all subqueries from a query and saves them in a dictionary with structured format
    """
    ast = parse_one(query, dialect = 'tsql')
    count = 0
    selects = list(ast.find_all(exp.Select))   
    selects = [select for select in selects] # problem with tsql conversion: column aliases get dropped in case of no transformation
    subqueries = {}
    
    for i, select in enumerate(selects):
        print(type(select))
        subqueries[f'subquery_{i}'] = select


    subqueries = {k: v for k, v in subqueries.items() if k != 'subquery_0'}


    return subqueries


subqueries = extract_subqueries(query)

type(subqueries['subquery_1'])

<class 'sqlglot.expressions.Select'>
<class 'sqlglot.expressions.Select'>
<class 'sqlglot.expressions.Select'>
<class 'sqlglot.expressions.Select'>
<class 'sqlglot.expressions.Select'>
<class 'sqlglot.expressions.Select'>


sqlglot.expressions.Select

In [46]:
import sqlglot

def extract_subqueries(query: str) -> dict:
    """
    Extract all subqueries from a query and saves them in a dictionary with structured format
    """
    ast = parse_one(query, dialect = 'tsql')
    count = 0
    selects = list(ast.find_all(exp.Select))   
    selects = [select for select in selects] # problem with tsql conversion: column aliases get dropped in case of no transformation
    subqueries = {}
    
    for i, select in enumerate(selects):
        print(type(select))
        subqueries[f'subquery_{i}'] = select


    subqueries = {k: v for k, v in subqueries.items() if k != 'subquery_0'}


    return subqueries


subqueries = extract_subqueries(query)

type(subqueries['subquery_1'])

<class 'sqlglot.expressions.Select'>
<class 'sqlglot.expressions.Select'>
<class 'sqlglot.expressions.Select'>
<class 'sqlglot.expressions.Select'>
<class 'sqlglot.expressions.Select'>
<class 'sqlglot.expressions.Select'>


sqlglot.expressions.Select

In [48]:
# EXTRACT SUBQUERIES FROM MAIN QUERY WITHOUT STRINGS


def replace_subquery_with_table(node):

    if type(node) == sqlglot.expressions.Select:
        for name, subquery in subqueries.items():    
            if node.sql() == subquery.sql():  # Check if the node is a subquery
                return sqlglot.exp.Table(this=name)  # Replace with a table node
    return node


# Parse the SQL query into an abstract syntax tree (AST)
parsed = sqlglot.parse_one(query, dialect = 'tsql')

# Apply the transformation to the AST
transformed = parsed.transform(replace_subquery_with_table)

# Generate the modified SQL query
#new_query = transformed.sql()

print("Original Query:")
print(query)
print("\nModified Query:")
print(transformed)


Original Query:
CREATE VIEW INVESTOR_OVERVIEW AS
SELECT
    i.investor_id,
    i.first_name,
    i.last_name,
    i.email,
    p.portfolio_id,
    p.portfolio_name,
    recent_trades.trade_id,
    recent_trades.trade_date,
    recent_trades.ticker,
    recent_trades.company_name,
    recent_trades.trade_type,
    recent_trades.quantity,
    recent_trades.price_per_share,
    recent_trades.total_trade_value,
    CAST(dividends_received.total_dividends AS DECIMAL(10,2)) AS total_dividends,
    CAST(average_performance.avg_performance AS DECIMAL(10,2)) AS average_performance,
    CAST(total_investment.total_value AS DECIMAL(10,2)) AS total_investment_value
FROM
    investors i
JOIN portfolios p ON i.investor_id = p.investor_id
JOIN (
    SELECT
        t.trade_id,
        t.portfolio_id,
        s.ticker,
        s.company_name,
        t.trade_date,
        t.trade_type,
        t.quantity,
        t.price_per_share,
        (t.quantity * t.price_per_share) AS total_trade_value
    FROM


In [43]:
# EXTRACT NESTED SUBQUERIES FROM SUBQUERIES WITHOUT STRINGS

def extract_subqueries(ast: sqlglot.expressions.Select) -> dict:
    """
    Extract all subqueries from a query and saves them in a dictionary with structured format
    """
    count = 0
    selects = list(ast.find_all(exp.Select))   
    selects = [select for select in selects] # problem with tsql conversion: column aliases get dropped in case of no transformation
    subqueries = {}
    
    for i, select in enumerate(selects):
        print(type(select))
        subqueries[f'subquery_{i}'] = select

    subqueries = {k: v for k, v in subqueries.items() if k != 'subquery_0'}

    return subqueries


def replace_subquery_with_table(node):

    if type(node) == sqlglot.expressions.Select:
        for name, subquery, name_old, subquery_old in zip(subqueries.items(), old_subqueries.items()):    
            if node.sql() == subquery.sql():  # Check if the node is a subquery
                return sqlglot.exp.Table(this=name_old)  # Replace with a table node
    return node

def replace_subquery_with_table_subquery(node):

    global old_subqueries
    global subqueries 


    old_subqueries = subqueries.copy()
    subqueries = extract_subqueries(node)
    print(subqueries)


    for select in node.find_all(exp.Select):
        transformed = node.transform(replace_subquery_with_table)
        return transformed

    return None


for name, subquery in subqueries.items():

    # Apply the transformation to the AST
    transformed = replace_subquery_with_table_subquery(subquery)

    # Generate the modified SQL query
    #new_query = transformed.sql()

    print("Original Query:")
    print(subquery)
    print("Modified Query:")
    print(transformed)
    print()


AttributeError: 'list' object has no attribute 'items'

<class 'sqlglot.expressions.Select'>
<class 'sqlglot.expressions.Select'>
<class 'sqlglot.expressions.Select'>
<class 'sqlglot.expressions.Select'>
<class 'sqlglot.expressions.Select'>
<class 'sqlglot.expressions.Select'>
SELECT * FROM (SELECT id, name FROM users) AS subquery
*
FROM (SELECT id, name FROM users) AS subquery
(SELECT id, name FROM users) AS subquery
SELECT id, name FROM users
subquery
subquery
Original Query:

SELECT *
FROM (
    SELECT id, name
    FROM users
) AS subquery


Modified Query:
SELECT * FROM (replaced_table) AS subquery
