In [1]:
import pandas as pd
import json
import re
import math
import numpy as np
from collections import defaultdict
import psycopg2
import time
import os
import sys
from tqdm import tqdm

In [2]:
dataset = 'tpcds_sf1'
tmp_data_dir = 'data/tpcds_sf1'


# column_type_file = os.path.join(os.path.dirname(__file__), f'../zsce/cross_db_benchmark/datasets/{dataset}/column_type.json')
column_type_file = os.path.join(f'/home/wuy/DB/memory_prediction/zsce/cross_db_benchmark/datasets/{dataset}/column_type.json')
with open(column_type_file, 'r') as f:
    column_type = json.load(f)

schema = {}
for table, columns in column_type.items():
    if table == 'dbgen_version':
        continue
    for column, type_ in columns.items():
        if table not in schema:
            schema[table] = []
        schema[table].append(column)

# Define table aliases as their original names for tpcds by iterating tpcds_schema, meanwhile get col2idx
t2alias = {}
col2idx = {}
for table, columns in schema.items():
    for column in columns:
        t2alias[table] = table
        col2idx[table + '.' + column] = len(col2idx)

alias2t = {v: k for k, v in t2alias.items()}

In [3]:
# Database connection parameters
DB_PARAMS = {
    'database': dataset,
    'user': "wuy",
    'host': "127.0.0.1",
    'password': "wuy",
    'port': "5432"
}
conn = psycopg2.connect(**DB_PARAMS)
conn.set_session(autocommit=True)
cur = conn.cursor()

In [4]:
def to_vals(data_list):
    for dat in data_list:
        val = dat[0]
        if val is not None: break
    try:
        float(val)
        return np.array(data_list, dtype=float).squeeze()
    except:
#         print(val)
        res = []
        for dat in data_list:
            try:
                mi = dat[0].timestamp()
            except:
                mi = 0
            res.append(mi)
        return np.array(res)

## Histogram

In [5]:
hist_file = pd.DataFrame(columns=['table','column','bins','table_column'])
# load hist_file if exists
hist_file_path = os.path.join(tmp_data_dir, 'hist_file.csv')
if os.path.exists(hist_file_path):
    hist_file = pd.read_csv(hist_file_path)
else:
    for table,columns in schema.items():
        for column in tqdm(columns, desc=table):
            cmd = 'select {} from {} as {}'.format(column, table,t2alias[table])
            cur.execute(cmd)
            col = cur.fetchall()
            col_array = to_vals(col)
            
            hists = np.nanpercentile(col_array, range(0,101,2), axis=0)
            freqs, _ = np.histogram(col_array, bins=hists)

            freq_bytes = freqs.astype('float32').tobytes()
            freq_hex = freq_bytes.hex()
            
            res_dict = {
                'table':table,
                'column':column,
                'table_column': '.'.join((table, column)),
                'bins':hists,
                'freq': freq_hex
            }
            hist_file = pd.concat([hist_file, pd.DataFrame([res_dict])], ignore_index=True)
        hist_file.to_csv(hist_file_path, index=False)



customer_address: 100%|██████████| 13/13 [00:01<00:00, 12.07it/s]
customer_demographics: 100%|██████████| 9/9 [00:19<00:00,  2.16s/it]
date_dim: 100%|██████████| 28/28 [00:02<00:00, 10.79it/s]
warehouse: 100%|██████████| 14/14 [00:00<00:00, 466.47it/s]
ship_mode: 100%|██████████| 6/6 [00:00<00:00, 481.93it/s]
time_dim: 100%|██████████| 10/10 [00:01<00:00,  8.59it/s]
reason: 100%|██████████| 3/3 [00:00<00:00, 301.13it/s]
income_band: 100%|██████████| 3/3 [00:00<00:00, 204.28it/s]
item: 100%|██████████| 22/22 [00:00<00:00, 25.80it/s]
store: 100%|██████████| 29/29 [00:00<00:00, 477.63it/s]
call_center: 100%|██████████| 31/31 [00:00<00:00, 463.94it/s]
customer: 100%|██████████| 18/18 [00:02<00:00,  8.09it/s]
web_site: 100%|██████████| 26/26 [00:00<00:00, 470.67it/s]
store_returns: 100%|██████████| 20/20 [00:11<00:00,  1.70it/s]
household_demographics: 100%|██████████| 5/5 [00:00<00:00, 82.43it/s]
web_page: 100%|██████████| 14/14 [00:00<00:00, 433.40it/s]
promotion: 100%|██████████| 19/19 [

In [6]:
cur.close()
conn.close()

## Sample
### Steps (There may be other easier methods)
1. generate 1000 sample points for each table
2. duplicate database schema from full db
    > pg_dump imdb -s -O > imdb_schema.sql
3. create small base by in psql
    > create database imdb_sample
4. create schema using imdb_schema.sql
5. load the sample data using pandas and sqlalchemy
6. query the small base to get sample bitmaps for each predicate

Step 1

In [11]:
DB_PARAMS = {
    'database': 'tpcds_sample',
    'user': "wuy",
    'host': "127.0.0.1",
    'password': "wuy",
    'port': "5432"
}
conm = psycopg2.connect(**DB_PARAMS)
conm.set_session(autocommit=True)
cur = conm.cursor()

In [8]:
## sampling extension
try:
    cmd = 'CREATE EXTENSION tsm_system_rows'
    cur.execute(cmd)
except Exception as e:
    print(e)
    pass

could not open extension control file "/usr/local/pgsql/share/extension/tsm_system_rows.control": No such file or directory



In [13]:
import pickle
# load sample_data from file if exists
sample_data_file = os.path.join(tmp_data_dir, "sample_data.pkl")
if os.path.exists(sample_data_file):
    with open(sample_data_file, "rb") as f:
        sample_data = pickle.load(f)
else:
    tables = list(schema.keys())
    sample_data = {}
    for table in tables:
        cur.execute("Select * FROM {} LIMIT 0".format(table))
        colnames = [desc[0] for desc in cur.description]

        ts = pd.DataFrame(columns = colnames)

        for num in tqdm(range(1000), desc=table):
            # cmd = 'SELECT * FROM {} TABLESAMPLE SYSTEM_ROWS(1)'.format(table)
            cmd = 'SELECT * FROM {} ORDER BY RANDOM() LIMIT 1'.format(table)
            cur.execute(cmd)
            samples = cur.fetchall()
            for i,row in enumerate(samples):
                ts.loc[num]=row
        
        sample_data[table] = ts
    with open(sample_data_file, "wb") as f:
        pickle.dump(sample_data, f)

customer_address:   0%|          | 0/1000 [00:00<?, ?it/s]

customer_address: 100%|██████████| 1000/1000 [00:00<00:00, 3421.65it/s]
customer_demographics: 100%|██████████| 1000/1000 [00:00<00:00, 3779.10it/s]
date_dim: 100%|██████████| 1000/1000 [00:00<00:00, 3059.00it/s]
warehouse: 100%|██████████| 1000/1000 [00:00<00:00, 3617.66it/s]
ship_mode: 100%|██████████| 1000/1000 [00:00<00:00, 4678.24it/s]
time_dim: 100%|██████████| 1000/1000 [00:00<00:00, 3827.14it/s]
reason: 100%|██████████| 1000/1000 [00:00<00:00, 4330.21it/s]
income_band: 100%|██████████| 1000/1000 [00:00<00:00, 4831.79it/s]
item: 100%|██████████| 1000/1000 [00:00<00:00, 3625.11it/s]
store: 100%|██████████| 1000/1000 [00:00<00:00, 2983.35it/s]
call_center: 100%|██████████| 1000/1000 [00:00<00:00, 2699.83it/s]
customer: 100%|██████████| 1000/1000 [00:00<00:00, 3373.58it/s]
web_site: 100%|██████████| 1000/1000 [00:00<00:00, 2956.99it/s]
store_returns: 100%|██████████| 1000/1000 [00:00<00:00, 3237.13it/s]
household_demographics: 100%|██████████| 1000/1000 [00:00<00:00, 4115.54it/s]
w

In [14]:
sample_data.keys()

dict_keys(['customer_address', 'customer_demographics', 'date_dim', 'warehouse', 'ship_mode', 'time_dim', 'reason', 'income_band', 'item', 'store', 'call_center', 'customer', 'web_site', 'store_returns', 'household_demographics', 'web_page', 'promotion', 'catalog_page', 'inventory', 'catalog_returns', 'web_returns', 'web_sales', 'catalog_sales', 'store_sales'])

Step 5 (Do step 2-4 outside first)

In [16]:
from sqlalchemy import create_engine
engine = create_engine('postgresql://wuy:wuy@localhost:5432/tpcds_sample')

In [17]:
for k,v in tqdm(sample_data.items()):
    try:
        v['sid'] = list(range(1000))
        cmd = 'alter table {} add column sid integer'.format(k)
        cur.execute(cmd)
        v.to_sql(k,engine,if_exists='append',index=False)
    except Exception as e:
        print(e)

  4%|▍         | 1/24 [00:00<00:04,  5.69it/s]

(psycopg2.errors.NotNullViolation) null value in column "ca_address_sk" of relation "customer_address" violates not-null constraint
DETAIL:  Failing row contains (null, null, null, null, null, null, null, null, null, null, null, null, null, 0).

[SQL: INSERT INTO customer_address (ca_address_sk, ca_address_id, ca_street_number, ca_street_name, ca_street_type, ca_suite_number, ca_city, ca_county, ca_state, ca_zip, ca_country, ca_gmt_offset, ca_location_type, sid) VALUES (%(ca_address_sk__0)s, %(ca_ ... 311330 characters truncated ... ca_zip__999)s, %(ca_country__999)s, %(ca_gmt_offset__999)s, %(ca_location_type__999)s, %(sid__999)s)]
[parameters: {'ca_zip__0': None, 'sid__0': 0, 'ca_address_id__0': None, 'ca_gmt_offset__0': None, 'ca_suite_number__0': None, 'ca_street_name__0': None, 'ca_location_type__0': None, 'ca_street_type__0': None, 'ca_address_sk__0': None, 'ca_city__0': None, 'ca_state__0': None, 'ca_street_number__0': None, 'ca_country__0': None, 'ca_county__0': None, 'ca_zip__

 21%|██        | 5/24 [00:00<00:02,  8.79it/s]

(psycopg2.errors.NotNullViolation) null value in column "d_date_sk" of relation "date_dim" violates not-null constraint
DETAIL:  Failing row contains (null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, 0).

[SQL: INSERT INTO date_dim (d_date_sk, d_date_id, d_date, d_month_seq, d_week_seq, d_quarter_seq, d_year, d_dow, d_moy, d_dom, d_qoy, d_fy_year, d_fy_quarter_seq, d_fy_week_seq, d_day_name, d_quarter_name, d_holiday, d_weekend, d_following_holiday, d_first ... 623851 characters truncated ... _999)s, %(d_current_month__999)s, %(d_current_quarter__999)s, %(d_current_year__999)s, %(sid__999)s)]
[parameters: {'d_dom__0': None, 'sid__0': 0, 'd_month_seq__0': None, 'd_same_day_lq__0': None, 'd_date_sk__0': None, 'd_holiday__0': None, 'd_fy_week_seq__0': None, 'd_current_day__0': None, 'd_day_name__0': None, 'd_current_week__0': None, 'd_fy_quarter_seq__0': None, 'd_qoy__0':

 33%|███▎      | 8/24 [00:00<00:01, 13.70it/s]

(psycopg2.errors.NotNullViolation) null value in column "r_reason_sk" of relation "reason" violates not-null constraint
DETAIL:  Failing row contains (null, null, null, 0).

[SQL: INSERT INTO reason (r_reason_sk, r_reason_id, r_reason_desc, sid) VALUES (%(r_reason_sk__0)s, %(r_reason_id__0)s, %(r_reason_desc__0)s, %(sid__0)s), (%(r_reason_sk__1)s, %(r_reason_id__1)s, %(r_reason_desc__1)s, %(sid__1)s), (%(r_reason_sk__2)s, %(r_ ... 83281 characters truncated ... s, %(sid__998)s), (%(r_reason_sk__999)s, %(r_reason_id__999)s, %(r_reason_desc__999)s, %(sid__999)s)]
[parameters: {'r_reason_id__0': None, 'r_reason_sk__0': None, 'r_reason_desc__0': None, 'sid__0': 0, 'r_reason_id__1': None, 'r_reason_sk__1': None, 'r_reason_desc__1': None, 'sid__1': 1, 'r_reason_id__2': None, 'r_reason_sk__2': None, 'r_reason_desc__2': None, 'sid__2': 2, 'r_reason_id__3': None, 'r_reason_sk__3': None, 'r_reason_desc__3': None, 'sid__3': 3, 'r_reason_id__4': None, 'r_reason_sk__4': None, 'r_reason_desc__4': No

 42%|████▏     | 10/24 [00:00<00:01, 11.27it/s]

(psycopg2.errors.NotNullViolation) null value in column "s_store_sk" of relation "store" violates not-null constraint
DETAIL:  Failing row contains (null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, 0).

[SQL: INSERT INTO store (s_store_sk, s_store_id, s_rec_start_date, s_rec_end_date, s_closed_date_sk, s_store_name, s_number_employees, s_floor_space, s_hours, s_manager, s_market_id, s_geography_class, s_market_desc, s_market_manager, s_division_id, s_divi ... 685791 characters truncated ...  %(s_zip__999)s, %(s_country__999)s, %(s_gmt_offset__999)s, %(s_tax_precentage__999)s, %(sid__999)s)]
[parameters: {'s_county__0': None, 'sid__0': 0, 's_closed_date_sk__0': None, 's_street_type__0': None, 's_company_id__0': None, 's_country__0': None, 's_manager__0': None, 's_gmt_offset__0': None, 's_geography_class__0': None, 's_store_id__0': None, 's_market_id__0': None, 's_

 50%|█████     | 12/24 [00:01<00:01, 10.34it/s]

(psycopg2.errors.NotNullViolation) null value in column "c_customer_sk" of relation "customer" violates not-null constraint
DETAIL:  Failing row contains (null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, 0).

[SQL: INSERT INTO customer (c_customer_sk, c_customer_id, c_current_cdemo_sk, c_current_hdemo_sk, c_current_addr_sk, c_first_shipto_date_sk, c_first_sales_date_sk, c_salutation, c_first_name, c_last_name, c_preferred_cust_flag, c_birth_day, c_birth_month,  ... 483900 characters truncated ... ry__999)s, %(c_login__999)s, %(c_email_address__999)s, %(c_last_review_date_sk__999)s, %(sid__999)s)]
[parameters: {'sid__0': 0, 'c_customer_id__0': None, 'c_birth_year__0': None, 'c_first_name__0': None, 'c_first_sales_date_sk__0': None, 'c_salutation__0': None, 'c_login__0': None, 'c_email_address__0': None, 'c_customer_sk__0': None, 'c_birth_country__0': None, 'c_birth_month__0': None, 'c_birth_day__0': None, 'c_last_name__0': None,

 67%|██████▋   | 16/24 [00:01<00:00, 11.73it/s]

(psycopg2.errors.NotNullViolation) null value in column "sr_item_sk" of relation "store_returns" violates not-null constraint
DETAIL:  Failing row contains (null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, 0).

[SQL: INSERT INTO store_returns (sr_returned_date_sk, sr_return_time_sk, sr_item_sk, sr_customer_sk, sr_cdemo_sk, sr_hdemo_sk, sr_addr_sk, sr_store_sk, sr_reason_sk, sr_ticket_number, sr_return_quantity, sr_return_amt, sr_return_tax, sr_return_amt_inc_tax, ... 514698 characters truncated ... h__999)s, %(sr_reversed_charge__999)s, %(sr_store_credit__999)s, %(sr_net_loss__999)s, %(sid__999)s)]
[parameters: {'sr_returned_date_sk__0': None, 'sid__0': 0, 'sr_return_tax__0': None, 'sr_return_quantity__0': None, 'sr_fee__0': None, 'sr_return_ship_cost__0': None, 'sr_ticket_number__0': None, 'sr_return_amt_inc_tax__0': None, 'sr_store_credit__0': None, 'sr_reversed_charge__0': None, 'sr_store_sk__0': None, 'sr_cdemo_

 75%|███████▌  | 18/24 [00:01<00:00, 12.26it/s]

(psycopg2.errors.NotNullViolation) null value in column "cp_catalog_page_sk" of relation "catalog_page" violates not-null constraint
DETAIL:  Failing row contains (null, null, null, null, null, null, null, null, null, 0).

[SQL: INSERT INTO catalog_page (cp_catalog_page_sk, cp_catalog_page_id, cp_start_date_sk, cp_end_date_sk, cp_department, cp_catalog_number, cp_catalog_page_number, cp_description, cp_type, sid) VALUES (%(cp_catalog_page_sk__0)s, %(cp_catalog_page_id__0)s,  ... 252743 characters truncated ... er__999)s, %(cp_catalog_page_number__999)s, %(cp_description__999)s, %(cp_type__999)s, %(sid__999)s)]
[parameters: {'cp_start_date_sk__0': None, 'sid__0': 0, 'cp_catalog_page_number__0': None, 'cp_department__0': None, 'cp_type__0': None, 'cp_end_date_sk__0': None, 'cp_catalog_page_id__0': None, 'cp_catalog_number__0': None, 'cp_description__0': None, 'cp_catalog_page_sk__0': None, 'cp_start_date_sk__1': None, 'sid__1': 1, 'cp_catalog_page_number__1': None, 'cp_department__1': No

 83%|████████▎ | 20/24 [00:01<00:00,  9.68it/s]

(psycopg2.errors.NotNullViolation) null value in column "cr_item_sk" of relation "catalog_returns" violates not-null constraint
DETAIL:  Failing row contains (null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, 0).

[SQL: INSERT INTO catalog_returns (cr_returned_date_sk, cr_returned_time_sk, cr_item_sk, cr_refunded_customer_sk, cr_refunded_cdemo_sk, cr_refunded_hdemo_sk, cr_refunded_addr_sk, cr_returning_customer_sk, cr_returning_cdemo_sk, cr_returning_hdemo_sk, cr_re ... 771124 characters truncated ... h__999)s, %(cr_reversed_charge__999)s, %(cr_store_credit__999)s, %(cr_net_loss__999)s, %(sid__999)s)]
[parameters: {'sid__0': 0, 'cr_refunded_cdemo_sk__0': None, 'cr_fee__0': None, 'cr_warehouse_sk__0': None, 'cr_reversed_charge__0': None, 'cr_return_ship_cost__0': None, 'cr_call_center_sk__0': None, 'cr_returned_date_sk__0': None, 'cr_returning_addr_sk__0': None, 'cr_return_amt

 92%|█████████▏| 22/24 [00:02<00:00,  9.10it/s]

(psycopg2.errors.NotNullViolation) null value in column "ws_item_sk" of relation "web_sales" violates not-null constraint
DETAIL:  Failing row contains (null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, 0).

[SQL: INSERT INTO web_sales (ws_sold_date_sk, ws_sold_time_sk, ws_ship_date_sk, ws_item_sk, ws_bill_customer_sk, ws_bill_cdemo_sk, ws_bill_hdemo_sk, ws_bill_addr_sk, ws_ship_customer_sk, ws_ship_cdemo_sk, ws_ship_hdemo_sk, ws_ship_addr_sk, ws_web_page_sk,  ... 852906 characters truncated ... s_net_paid_inc_ship__933)s, %(ws_net_paid_inc_ship_tax__933)s, %(ws_net_profit__933)s, %(sid__933)s)]
[parameters: {'sid__0': 0, 'ws_warehouse_sk__0': None, 'ws_wholesale_cost__0': None, 'ws_order_number__0': None, 'ws_promo_sk__0': None, 'ws_ext_ship_cost__0': None, 'ws_bill_customer_sk__0': None, 'ws_ext_discount_amt__0': None, 'ws_ext_sales

100%|██████████| 24/24 [00:02<00:00,  9.72it/s]

(psycopg2.errors.NotNullViolation) null value in column "ss_item_sk" of relation "store_sales" violates not-null constraint
DETAIL:  Failing row contains (null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, 0).

[SQL: INSERT INTO store_sales (ss_sold_date_sk, ss_sold_time_sk, ss_item_sk, ss_customer_sk, ss_cdemo_sk, ss_hdemo_sk, ss_addr_sk, ss_store_sk, ss_promo_sk, ss_ticket_number, ss_quantity, ss_wholesale_cost, ss_list_price, ss_sales_price, ss_ext_discount_am ... 586411 characters truncated ... mt__999)s, %(ss_net_paid__999)s, %(ss_net_paid_inc_tax__999)s, %(ss_net_profit__999)s, %(sid__999)s)]
[parameters: {'sid__0': 0, 'ss_wholesale_cost__0': None, 'ss_ticket_number__0': None, 'ss_ext_tax__0': None, 'ss_ext_discount_amt__0': None, 'ss_promo_sk__0': None, 'ss_store_sk__0': None, 'ss_hdemo_sk__0': None, 'ss_quantity__0': None, 'ss_addr_sk__0': None, 'ss_coupon_amt__0': None, 'ss_ext_list_price__0




Step 6

In [18]:
import sqlparse
from sqlparse.sql import Comparison, Where
from sqlparse.tokens import Keyword, DML
import re

def is_number(value):
    """
    Check if the given value is a number.
    
    Args:
        value (str): The value to check.
        
    Returns:
        bool: True if value is a number, False otherwise.
    """
    try:
        float(value)
        return True
    except ValueError:
        return False

def extract_numeric_predicates(sql_query):
    """
    Extract predicates from the SQL WHERE clause where the right side of the operator is a number.
    
    Args:
        sql_query (str): The SQL query string.
        
    Returns:
        list: A list of numeric predicates as strings.
    """
    # Parse the SQL query
    parsed = sqlparse.parse(sql_query)
    if not parsed:
        return []
    
    stmt = parsed[0]
    numeric_predicates = []

    def extract_from_tokens(tokens):
        """
        Recursively traverse tokens to find numeric comparisons.
        
        Args:
            tokens (list): List of sqlparse tokens.
        """
        for token in tokens:
            if isinstance(token, Comparison):
                # Extract the comparison string
                comparison = str(token).strip()
                
                # Regex to split the comparison into left, operator, and right
                match = re.match(r'(.+?)(=|<>|<=|>=|<|>)(.+)', comparison)
                if match:
                    left, operator, right = match.groups()
                    left = left.strip()
                    operator = operator.strip()
                    right = right.strip()
                    
                    # Remove surrounding quotes from strings
                    if right.startswith("'") and right.endswith("'"):
                        continue  # It's a string predicate; skip
                    if right.startswith('"') and right.endswith('"'):
                        continue  # It's a string predicate; skip
                    
                    # Check if the right side is a number
                    if is_number(right):
                        numeric_predicates.append(comparison)
            elif token.is_group:
                # Recursively handle sub-tokens
                extract_from_tokens(token.tokens)

    # Iterate through the tokens to find the WHERE clause
    for token in stmt.tokens:
        if isinstance(token, Where):
            extract_from_tokens(token.tokens)
            break  # Assuming only one WHERE clause

    return numeric_predicates


In [None]:
# table_samples = []
# for i,row in query_file.iterrows():
#     table_sample = {}
#     preds = row['predicate'].split(',')
#     for i in range(0,len(preds),3):
#         left, op, right = preds[i:i+3]
#         alias,col = left.split('.')
#         table = alias2t[alias]
#         pred_string = ''.join((col,op,right))
#         q = 'select sid from {} where {}'.format(table, pred_string)
#         cur.execute(q)
#         sps = np.zeros(1000).astype('uint8')
#         sids = cur.fetchall()
#         sids = np.array(sids).squeeze()
#         if sids.size>1:
#             sps[sids] = 1
#         if table in table_sample:
#             table_sample[table] = table_sample[table] & sps
#         else:
#             table_sample[table] = sps
#     table_samples.append(table_sample)

In [20]:
data_dir = '/home/wuy/DB/pg_mem_data'

# load table_sample from file if exists
table_sample_file = os.path.join(tmp_data_dir, 'table_samples.pkl')
if os.path.exists(table_sample_file):
    with open(table_sample_file, 'rb') as f:
        table_samples = pickle.load(f)
    print('Loaded table_samples from file.')
else:
    with open(os.path.join(data_dir, dataset, 'train_plans.json')) as f:
        plans = json.load(f)

    table_pattern = r'\"([a-zA-Z_]+)\"\.'
    column_pattern = r'\.\"([a-zA-Z_]+)\"'

    table_samples = []
    for plan in tqdm(plans):
        table_sample = {}
        predicates = extract_numeric_predicates(plan['sql'])
        # print(plan['sql'])
        for predicate in predicates:
            try:
                table_name = re.search(table_pattern, predicate).group(1)
                column_name = re.search(column_pattern, predicate).group(1)
                if column_type[table_name][column_name] == 'char':
                    continue
                q = 'select sid from {} where {}'.format(table_name, predicate)
                cur.execute(q)
                sps = np.zeros(1000).astype('uint8')
                sids = cur.fetchall()
                sids = np.array(sids).squeeze()
                if sids.size>1:
                    sps[sids] = 1
                if table_name in table_sample:
                    table_sample[table_name] = table_sample[table_name] & sps
                else:
                    table_sample[table_name] = sps
            except Exception as e:
                print(f"Error: {e}")
        # if len(table_sample) > 0:
        table_samples.append(table_sample)

    import pickle
    # Save table_samples to file
    with open(table_sample_file, 'wb') as f:
        pickle.dump(table_samples, f)


100%|██████████| 40000/40000 [02:56<00:00, 226.18it/s]


In [21]:
table_samples

[{},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},
 {},


In [72]:
query_file = pd.read_csv('data/imdb/workloads/synthetic.csv',sep='#',header=None)
query_file.columns = ['table','join','predicate','card']

In [73]:
query_file.head()

Unnamed: 0,table,join,predicate,card
0,cast_info ci,,"ci.person_id,=,172968",838
1,"title t,movie_info mi",t.id=mi.movie_id,"t.kind_id,<,3,t.production_year,=,2008,mi.info...",297013
2,"title t,cast_info ci",t.id=ci.movie_id,"ci.person_id,<,3194645",31427248
3,"title t,cast_info ci,movie_info mi","t.id=ci.movie_id,t.id=mi.movie_id","ci.person_id,=,1742124,ci.role_id,>,2,mi.info_...",12
4,"title t,cast_info ci,movie_info_idx mi_idx","t.id=ci.movie_id,t.id=mi_idx.movie_id","t.kind_id,=,7,t.production_year,>,0,ci.role_id...",733244


In [76]:
conm = psycopg2.connect(database="imdb", user="wuy", host="127.0.0.1",password="wuy", port="5432")
conm.set_session(autocommit=True)
cur = conm.cursor()

In [77]:
table_samples = []
for i,row in query_file.iterrows():
    table_sample = {}
    preds = row['predicate'].split(',')
    for i in range(0,len(preds),3):
        left, op, right = preds[i:i+3]
        alias,col = left.split('.')
        table = alias2t[alias]
        pred_string = ''.join((col,op,right))
        q = 'select sid from {} where {}'.format(table, pred_string)
        cur.execute(q)
        sps = np.zeros(1000).astype('uint8')
        sids = cur.fetchall()
        sids = np.array(sids).squeeze()
        if sids.size>1:
            sps[sids] = 1
        if table in table_sample:
            table_sample[table] = table_sample[table] & sps
        else:
            table_sample[table] = sps
    table_samples.append(table_sample)

KeyError: 'ci'

In [85]:
# table_samples