In [1]:
from datasets import load_dataset, Dataset
import pandas as pd
import psycopg
from psycopg.rows import dict_row

QUERY_PATH = 'queries.json'
postgreq_params = {
    "host": "localhost",
    "dbname": "data_catalog",
    "user": "text2sql"
}
query_dataset = load_dataset("json", data_files=QUERY_PATH, split='train')
df = pd.DataFrame(query_dataset)

  from .autonotebook import tqdm as notebook_tqdm


In [23]:
import sqlglot
from sqlglot import exp
import psycopg
# from faker import Faker
import datetime
import re

# Connection setup
# conn = psycopg.connect("dbname=testdb user=yourusername password=yourpassword")

# Faker for data generation
# fake = Faker()

def check_time_stamp(expression):
    # Regular expression to match the pattern now() [+-] INTERVAL 'n' [HOURS|DAYS]
    pattern = r"(now\(\)\s*([+-])\s*INTERVAL.*)(\d+)(.*HOURS|DAYS.*)"
    match = re.match(pattern, expression, re.IGNORECASE)
    if not match:
        return False
    return match

def replace_literal(node):
    if isinstance(node, exp.Interval):
        # Access the existing literal value and unit
        value = node.this
        unit = node.unit
        # Modify the literal value to include the unit
        new_literal_value = f"{value.this}{unit.this.lower()}"
        # Update the Interval node with the new literal
        new_interval = exp.Interval(
            this= exp.Literal(this=new_literal_value, is_string=True),
            unit=""  # Keep the unit unchanged
        )
        return new_interval
    return node

def modify_interval_value(expression, operation):
    match = check_time_stamp(expression)
    prefix = match.group(1)
    new_value = match.group(3)
    interval_type = match.group(4)
    print(prefix, new_value, interval_type, sep='')
    if operation == '>':
        new_value = int(new_value) - 1
    elif operation == '<':
        new_value = int(new_value) + 1
    # Reconstruct the modified expression
    modified_expression = f"{prefix}{new_value}{interval_type}"
    return modified_expression

def analyze_literals(query):
    # Parse the SQL query
    expressions = sqlglot.parse_one(query)
    
    # Find all table references and where conditions
    tables = expressions.find_all(sqlglot.expressions.Table)
    where_conditions = expressions.find_all(sqlglot.expressions.Where)

    needed_inserts = {}

    # Analyze tables and conditions
    for table in tables:
        table_name = table.args['this'].sql()
        if table_name not in needed_inserts:
            needed_inserts[table_name] = []
        
        for condition in where_conditions:
            condition = condition.transform(replace_literal)
            expr = condition.args['this']
            for condition in [sqlglot.expressions.EQ, sqlglot.expressions.LTE, sqlglot.expressions.GTE, sqlglot.expressions.LT, sqlglot.expressions.GT, sqlglot.expressions.Like]:
                comparisons = expr.find_all(condition)
                for comparison in comparisons:
                    column = comparison.args['this'].this.sql() # remove the alias
                    value = comparison.args['expression'].sql().strip('%')
                    
                    if condition == sqlglot.expressions.LT:
                        if check_time_stamp(value):
                            value = modify_interval_value(value, '<')
                        else:
                            value = str(int(value) - 1)
                    elif condition == sqlglot.expressions.GT:
                        if check_time_stamp(value):
                            value = modify_interval_value(value, '>')
                        else:
                            value = str(int(value) + 1)
                    needed_inserts[table_name].append((column, value))
    print(needed_inserts)
    return needed_inserts

def analyze_joins(sql):
    # Parse the SQL query into an AST
    expression = sqlglot.parse_one(sql)
    join_info = []

    def find_joins(node):
        if isinstance(node, exp.Join):
            # Extract the left and right tables
            #print(node.args)
            from_clause = expression.find(exp.From)
            if from_clause:
                left_table = from_clause.this.this
            right_table = node.this.this
            # Extract the join condition (ON clause)
            condition = node.args.get('on')
            # print(left_table, right_table, condition, sep=' ')
            # left_table = condition.args.get('this')
            # right_table = condition.args.get('expression')
            

            if condition:
                for condition_node in condition.find_all(exp.Condition):
                    # Extract the join attributes from the condition
                    left_column = condition_node.args.get('this')
                    right_column = condition_node.args.get('expression')
                    if isinstance(left_column, exp.Column) and isinstance(right_column, exp.Column):
                        join_info.append({
                            'left_table': left_table.this,
                            'left_column': left_column.this.this,
                            'right_table': right_table.this,
                            'right_column': right_column.this.this
                        }) 
        return node

    # Traverse the AST and find joins
    expression.transform(find_joins)
    return join_info


    # Insert data based on analysis
    # with conn.cursor() as cur:
    #     for table, conditions in needed_inserts.items():
    #         if not conditions:  # No specific conditions, just insert generic data
    #             cur.execute(f"INSERT INTO {table} DEFAULT VALUES;")
    #         else:
    #             for column, value in conditions:
    #                 cur.execute(f"""
    #                 INSERT INTO {table} ({column}) VALUES (%s)
    #                 ON CONFLICT DO NOTHING;
    #                 """, (value,))
    #     conn.commit()

# Example query to analyze and potentially trigger insertions
temp = analyze_literals("select t3.short_name, t3.long_name, t3.description from node_table as t1 join edge_has_table_col as t2 on t1.node_id = t2.source_node_id join node_column as t3 on t2.target_node_id = t3.node_id where t1.short_name = 'bill_of_materials' and t3.type_id = 3")
temp

{'node_table': [('short_name', "'bill_of_materials'"), ('type_id', '3')], 'edge_has_table_col': [], 'node_column': []}


{'node_table': [('short_name', "'bill_of_materials'"), ('type_id', '3')],
 'edge_has_table_col': [],
 'node_column': []}

In [29]:
import json
import random
import time
from datetime import datetime, timedelta
import string
import nltk
from nltk.corpus import words


def get_taxonomy_info(taxonomy_json_path, child_node, path=None):
    if taxonomy_json_path is not None:
        taxonomy_json = {}
        for file_path in taxonomy_json_path:
            with open(file_path, 'r') as file:
                data = json.load(file)
                taxonomy_json.update(data)
    if path is None:
        path = []
    for key, values in taxonomy_json.items():
        if child_node in values and key not in path:
            path.append(key)
            get_taxonomy_info(taxonomy_json_path, key, path)            
    return path

def fetch_table_schema(postgreq_conn):
        table_schema = {}
        query = """
            SELECT table_name, column_name
            FROM information_schema.columns
            WHERE table_schema = 'public'
            ORDER BY table_name, ordinal_position;
        """
        with postgreq_conn.cursor() as cur:
            cur.execute(query)
            current_table = None
            columns = []
            for row in cur.fetchall():
                table_name, column_name = row
                if table_name != current_table:
                    if current_table is not None:
                        table_schema[current_table] = (current_table, columns)
                        columns = []
                    current_table = table_name
                columns.append(column_name)
            if current_table is not None:  
                table_schema[current_table] = (current_table, columns)
        return table_schema

def insert_data(postgreq_conn, table_name, columns, data):
    with postgreq_conn.cursor() as cur:
        placeholders = ', '.join(['%s'] * len(columns))
        column_headers = ', '.join(columns)           
        cur.execute(f"INSERT INTO {table_name} ({column_headers}) VALUES ({placeholders})", data)
    postgreq_conn.commit()

def fetch_id(postgreq_conn):
    try:
        # Create a cursor object
        cursor = postgreq_conn.cursor() 
        # Query to fetch the largest node_id from the node table
        cursor.execute("SELECT MAX(node_id) FROM node")
        largest_node_id = cursor.fetchone()[0]
        
        # Query to fetch the largest edge_id from the edge table
        cursor.execute("SELECT MAX(edge_id) FROM edge")
        largest_edge_id = cursor.fetchone()[0]
        
        # Close the cursor and connection
        cursor.close()
        return largest_node_id, largest_edge_id

    except psycopg.Error as e:
        print(f"Error fetching largest IDs: {e}")
        return None, None

def generate_random_date(start_year=1970, end_year=datetime.now().year):
    # Generate a random date between start_year and end_year
    start_date = datetime(start_year, 1, 1)
    end_date = datetime(end_year, 12, 31)
    delta = end_date - start_date
    random_days = random.randint(0, delta.days)
    random_date = start_date + timedelta(days=random_days)
    return random_date.strftime("%c")

def insert_literal(postgresq_conn, taxonomy_json_path, missed_data):
    nltk.download('words')
    # Iterate over the data to insert
    data_catalog = fetch_table_schema(postgresq_conn)
    for table, columns in data_catalog.values():
        if table not in missed_data.keys():
            continue
        if not missed_data[table]:
            continue
        # Build Data
        data = {column: None for column in columns}
        for key, value in data.items():
            if "date" in key.lower():
                data[key] = generate_random_date()
            elif "num" in key.lower() or "length" in key.lower() or "size" in key.lower():
                data[key] = random.randint(1, 1000)
            else:
                word_list = words.words()
                length = random.randint(1, 5)
                random_words = [random.choice(word_list) for _ in range(length)]
                data[key] = ' '.join(random_words)
                if 'short_name' in key.lower():
                    data[key] = '_'.join(random_words).lower()

        # load the type_id
        node_type_schema = json.load(open("../schema/node_type_records.json"))
        edge_type_schema = json.load(open("../schema/edge_type_records.json"))
        # fetch the largest node_id or edge_id
        largest_node_id, largest_edge_id = fetch_id(postgresq_conn)
        if 'node' in table.lower():
            data['node_id'] = largest_node_id + 1
            data['type_id'] = node_type_schema[table]['type_id']
        elif 'edge' in table.lower():
            data['edge_id'] = largest_edge_id + 1
            data['type_id'] = edge_type_schema[table]['type_id']
        underscore_position = table.find('_')
        # Fetch the string after the first underscore
        if underscore_position != -1:
            data['type_name'] = table[underscore_position + 1:]
        else:
            data['type_name'] = table  # In case there is no underscore in the string
        # update the value
        for column, value in missed_data[table]:
            data[column] = value
        value_list = [data.get(column) for column in columns]
        print(table, columns, value_list, sep='\n')
        #insert_data(postgresq_conn, table, columns, value_list)
        print(f"Data inserted into table {table}.")
        # Handle taxonomy information
        if taxonomy_json_path is not None:
            for parent in get_taxonomy_info(taxonomy_json_path, table):
                parent_columns = data_catalog[parent][1]
                parent_value_list = [data.get(column) for column in parent_columns]
                #insert_data(parent, parent_columns, parent_value_list)
                print(parent, parent_columns, parent_value_list, sep='\n')
                print(f"Data inserted into table {parent}.")

postgreq_params = {
            "host": "localhost",
            "dbname": "data_catalog",
            "user": "text2sql"
        }
postgreq_conn = psycopg.connect(**postgreq_params)
insert_literal(postgreq_conn, ['../schema/node_type_taxonomy.json','../schema/edge_type_taxonomy.json'], temp)

[nltk_data] Downloading package words to /u/z/z/zzheng/nltk_data...
[nltk_data]   Package words is already up-to-date!


node_table
['node_id', 'type_id', 'type_name', 'short_name', 'long_name', 'description', 'creation_date', 'modified_date', 'num_cols', 'num_rows']
[999, '3', 'table', "'bill_of_materials'", 'boxty increment preabstract', 'Irvingite thermometrograph', 'Tue Nov 20 00:00:00 1984', 'Sat Nov  9 00:00:00 2013', 511, 742]
Data inserted into table node_table.
node
['node_id', 'type_id', 'type_name', 'short_name', 'long_name', 'description', 'creation_date', 'modified_date']
[999, '3', 'table', "'bill_of_materials'", 'boxty increment preabstract', 'Irvingite thermometrograph', 'Tue Nov 20 00:00:00 1984', 'Sat Nov  9 00:00:00 2013']
Data inserted into table node.


In [45]:
import sqlglot
import sqlglot.expressions as exp

def find_same_value_attributes(joins):
    # Create a dictionary to map each attribute to its group
    attr_to_group = {}
    groups = []

    def get_group(attr):
        # If the attribute already has a group, return it
        if attr in attr_to_group:
            return attr_to_group[attr]
        # Otherwise, create a new group for this attribute
        new_group = [attr]
        groups.append(new_group)
        attr_to_group[attr] = new_group
        return new_group

    for join in joins:
        left_attr = (join['left_table'], join['left_column'])
        right_attr = (join['right_table'], join['right_column'])

        left_group = get_group(left_attr)
        right_group = get_group(right_attr)

        # If the left and right attributes are not already in the same group, merge the groups
        if left_group is not right_group:
            left_group.extend(right_group)
            for attr in right_group:
                attr_to_group[attr] = left_group
            groups.remove(right_group)

    return groups

def extract_joins(sql):
    # Parse the SQL query into an AST
    expression = sqlglot.parse_one(sql)
    join_info = []
    alias_map = {}

    # Extract table aliases
    for table in expression.find_all(exp.Table):
        table_name = table.name
        alias = table.alias_or_name
        alias_map[alias] = table_name

    def find_joins(node):
        if isinstance(node, exp.Join):
            # Extract the right table
            right_table_alias = node.this.alias_or_name
            right_table = alias_map.get(right_table_alias)

            # Extract the join condition (ON clause)
            condition = node.args.get('on')

            if condition:
                for condition_node in condition.find_all(exp.Condition):
                    # Extract the join attributes from the condition
                    left_column = condition_node.args.get('this')
                    right_column = condition_node.args.get('expression')
                    
                    if isinstance(left_column, exp.Column) and isinstance(right_column, exp.Column):
                        left_table_alias = left_column.args.get('table').this
                        left_table = alias_map.get(left_table_alias)
                        join_info.append({
                            'left_table': left_table,
                            'left_column': left_column.this.this,
                            'right_table': right_table,
                            'right_column': right_column.this.this
                        })
        return node

    # Traverse the AST and find joins
    expression.transform(find_joins)
    return find_same_value_attributes(join_info)

# Example usage
sql_query = """
select t3.short_name, t3.long_name, t3.description from node_table as t1 join edge_has_table_col as t2 on t1.node_id = t2.source_node_id join node_column as t3 on t2.source_node_id = t3.node_id and t1.hello = t2.hello where t1.short_name = 'bill_of_materials' and t3.short_name = 's_date'
"""

join_details = extract_joins(sql_query)
join_details

[[('node_table', 'node_id'),
  ('edge_has_table_col', 'source_node_id'),
  ('node_column', 'node_id')],
 [('node_table', 'hello'), ('node_column', 'hello')]]

In [None]:
def insert_join(postgresq_conn, taxonomy_json_path, join_info):
    nltk.download('words')
    # Iterate over the data to insert
    data_catalog = fetch_table_schema(postgresq_conn)
    for table, columns in data_catalog.values():
        if table not in join.keys():
            continue
        if not missed_data[table]:
            continue
        # Build Data
        data = {column: None for column in columns}
        for key, value in data.items():
            if "date" in key.lower():
                data[key] = generate_random_date()
            elif "num" in key.lower() or "length" in key.lower() or "size" in key.lower():
                data[key] = random.randint(1, 1000)
            else:
                word_list = words.words()
                length = random.randint(1, 5)
                random_words = [random.choice(word_list) for _ in range(length)]
                data[key] = ' '.join(random_words)
                if 'short_name' in key.lower():
                    data[key] = '_'.join(random_words).lower()
        # load the type_id
        node_type_schema = json.load(open("../schema/node_type_records.json"))
        edge_type_schema = json.load(open("../schema/edge_type_records.json"))
        # fetch the largest node_id or edge_id
        largest_node_id, largest_edge_id = fetch_id(postgresq_conn)
        if 'node' in table.lower():
            data['node_id'] = largest_node_id + 1
            data['type_id'] = node_type_schema[table]['type_id']
        elif 'edge' in table.lower():
            data['edge_id'] = largest_edge_id + 1
            data['type_id'] = edge_type_schema[table]['type_id']
        underscore_position = table.find('_')
        # Fetch the string after the first underscore
        if underscore_position != -1:
            data['type_name'] = table[underscore_position + 1:]
        else:
            data['type_name'] = table  # In case there is no underscore in the string
        # update the value
        for column, value in missed_data[table]:
            data[column] = value
        value_list = [data.get(column) for column in columns]
        print(table, columns, value_list, sep='\n')
        #insert_data(postgresq_conn, table, columns, value_list)
        print(f"Data inserted into table {table}.")
        # Handle taxonomy information
        if taxonomy_json_path is not None:
            for parent in get_taxonomy_info(taxonomy_json_path, table):
                parent_columns = data_catalog[parent][1]
                parent_value_list = [data.get(column) for column in parent_columns]
                #insert_data(parent, parent_columns, parent_value_list)
                print(parent, parent_columns, parent_value_list, sep='\n')
                print(f"Data inserted into table {parent}.")

insert_join(postgreq_conn, ['../schema/node_type_taxonomy.json','../schema/edge_type_taxonomy.json'], join_details)

In [33]:
for index, row in df.iterrows():
    with psycopg.connect(**postgreq_params) as conn:
        with conn.cursor(row_factory=dict_row) as cursor:
            # Process the gold standard query
            try:
                cursor.execute(row['goldSqlQuery'])
                gold_results = cursor.fetchall()
                if gold_results:
                    print("gold query is not empty")
                else:
                    print(row['goldSqlQuery'])
                    print("gold query is empty")
                    needed_inserts = analyze_literals(row['goldSqlQuery'])
                    print(needed_inserts)
            except psycopg.Error as e:
                print("Gold query can not be runned.\nError:", e)

select t3.short_name from node_business_term as t1 join edge_assoc_term_col as t2 on t1.node_id = t2.source_node_id join node_column as t3 on t2.target_node_id = t3.node_id where t1.creation_date >= now() - interval '10 days'
gold query is empty
{'node_business_term': [('creation_date', "NOW() - INTERVAL '10days'")], 'edge_assoc_term_col': [], 'node_column': []}
{'node_business_term': [('creation_date', "NOW() - INTERVAL '10days'")], 'edge_assoc_term_col': [], 'node_column': []}
select t3.short_name from node_business_term as t1 join edge_assoc_term_col as t2 on t1.node_id = t2.source_node_id join node_column as t3 on t2.target_node_id = t3.node_id where t1.creation_date >= now() - interval '4 days'
gold query is empty
{'node_business_term': [('creation_date', "NOW() - INTERVAL '4days'")], 'edge_assoc_term_col': [], 'node_column': []}
{'node_business_term': [('creation_date', "NOW() - INTERVAL '4days'")], 'edge_assoc_term_col': [], 'node_column': []}
select t3.short_name from node_busi

gold query is not empty
gold query is not empty
gold query is not empty
gold query is not empty
gold query is not empty
gold query is not empty
gold query is not empty
gold query is not empty
gold query is not empty
gold query is not empty
gold query is not empty
gold query is not empty
gold query is not empty
select short_name from node_file where extension = 'csv' order by creation_date asc
gold query is empty
{'node_file': [('extension', "'csv'")]}
{'node_file': [('extension', "'csv'")]}
select short_name from node_file where extension = 'csv' order by creation_date desc
gold query is empty
{'node_file': [('extension', "'csv'")]}
{'node_file': [('extension', "'csv'")]}
select distinct t1.short_name from node_column as t1 join edge_assoc_term_col as t2 on t1.node_id = t2.target_node_id
gold query is empty
{'node_column': [], 'edge_assoc_term_col': []}
{'node_column': [], 'edge_assoc_term_col': []}
select distinct t1.short_name from node_column as t1 join edge_assoc_term_col as t2 on 