In [1]:
import os
import subprocess
import textwrap
import time

from absl import app
from absl import flags
from absl import logging
from mako import template
import numpy as np
import pandas as pd
import psycopg2 as pg
from pyspark.sql import SparkSession

from neurocard_lib import common,datasets,join_utils,utils
from neurocard_lib.factorized_sampler import FactorizedSamplerIterDataset

In [2]:
def MakeTablesKey(table_names):
    """Return a canonical key for a set of tables.

    The names are sorted so that the same tables always map to the same key
    regardless of input order, then joined with '-'.

    Args:
      table_names: iterable of table-name strings.

    Returns:
      The sorted names joined by '-' (an empty string for no names).
    """
    ordered = list(table_names)
    ordered.sort()
    return '-'.join(ordered)

In [11]:
def MakeQueries(num_queries, tables_in_templates, table_names, join_keys, rng,
                min_filters=1, max_filters=8, join_root=None):
    """Sample tuples from the join result, then place filters on them.

    For every sampled join tuple a random subset of its non-null content
    columns is chosen and turned into `col op literal` predicates, where the
    literal is the sampled tuple's own value for that column.

    Args:
      num_queries: number of queries to generate for this join template.
      tables_in_templates: table objects of the template, handed to
        FactorizedSamplerIterDataset and common.ConcatTables.  main() builds
        this list in the same order as `table_names`.
      table_names: names of the tables being joined.
      join_keys: dict mapping table name -> join key written as
        "table.column".
      rng: random source providing `randint` and `choice`
        (main() passes an np.random.RandomState).
      min_filters: lower bound (inclusive) for the number of filters.
      max_filters: the upper bound handed to `rng.randint` is
        max(ncols // 2, max_filters); RandomState.randint excludes it.
      join_root: root table of the join spec.  When None, "users" is used if
        it is one of `table_names`, otherwise the first table name.

    Returns:
      A tuple (queries, true_cards).  `queries` is a list of
      (cols, ops, vals) triples.  `true_cards` is always an empty list
      because query execution is not performed here.

    Raises:
      ValueError: if queries are requested but the template has no content
        columns to filter on (the sampling loop could never finish).
    """
    # TODO: this assumes single equiv class.
    # Chain the join keys pairwise: k0 = k1 AND k1 = k2 ...
    join_key_values = list(join_keys.values())
    join_clauses_list = [
        '{} = {}'.format(lhs, rhs)
        for lhs, rhs in zip(join_key_values, join_key_values[1:])
    ]
    join_clauses = '\n AND '.join(join_clauses_list)

    # Take only the content columns.  Only range (numerical) columns are used
    # as filter candidates; categorical columns are not included.
    content_cols = []
    for table_name in table_names:
        for c in datasets.JoinOrderBenchmark.RANGE_COLUMNS[table_name]:
            content_cols.append(
                common.JoinTableAndColumnNames(table_name, c, sep='.'))
    numericals = set(content_cols)

    if join_root is None:
        # NOTE(review): the root used to be hard-coded to "users", which is
        # not part of every template (e.g. badges-comments).  Falling back to
        # the first table assumes any table may act as the root when all keys
        # belong to one equivalence class -- confirm against join_utils.
        join_root = "users" if "users" in table_names else table_names[0]

    # Build a concat table representing the join result schema.
    join_keys_list = [join_keys[n] for n in table_names]
    print(join_keys_list)
    join_spec = join_utils.get_join_spec({
        "join_tables": table_names,
        "join_keys": dict(
            zip(table_names, [[k.split(".")[1]] for k in join_keys_list])),
        "join_root": join_root,
        "join_how": "inner",
    })
    ds = FactorizedSamplerIterDataset(tables_in_templates,
                                      join_spec,
                                      sample_batch_size=num_queries,
                                      disambiguate_column_names=False,
                                      add_full_join_indicators=False,
                                      add_full_join_fanouts=False)
    concat_table = common.ConcatTables(tables_in_templates,
                                       join_keys_list,
                                       sample_from_join_dataset=ds)

    template_for_execution = template.Template(
        textwrap.dedent("""
        SELECT COUNT(*)
        FROM ${', '.join(table_names)}
        WHERE ${join_clauses}
        AND ${filter_clauses};
    """).strip())

    true_inner_join_card = ds.sampler.join_card
    print('True inner join card', true_inner_join_card)

    ncols = len(content_cols)
    if num_queries > 0 and ncols == 0:
        # Without this guard every sampled row would be rejected below and
        # the while-loop would spin forever.
        raise ValueError(
            'No content columns to place filters on for tables {}'.format(
                list(table_names)))

    queries = []
    sql_queries = []

    while len(queries) < num_queries:
        sampled_df = ds.sampler.run()[content_cols]

        for _, tup in sampled_df.iterrows():
            # Drawn before the null check so the RNG stream does not depend
            # on which rows get rejected.
            num_filters = rng.randint(min_filters,
                                      max(ncols // 2, max_filters))

            # Positions where the values are non-null.
            non_null_indices = np.flatnonzero(~pd.isnull(tup).values)
            if len(non_null_indices) < num_filters:
                continue
            print('{} filters out of {} content cols'.format(
                num_filters, ncols))

            # Place {'<=', '>=', '='} on numericals and '=' on categoricals.
            idxs = rng.choice(non_null_indices, replace=False, size=num_filters)
            # idxs are positions, not labels: use .iloc explicitly (plain []
            # with integers on a string-labelled Series is deprecated).
            vals = tup.iloc[idxs].values
            cols = np.take(content_cols, idxs)
            ops = rng.choice(['<=', '>=', '='], size=num_filters)
            sensible_to_do_range = [c in numericals for c in cols]
            ops = np.where(sensible_to_do_range, ops, '=')

            print('cols', cols, 'ops', ops, 'vals', vals)

            queries.append((cols, ops, vals))

            # Quote string literals (doubling embedded single quotes) and
            # leave numeric literals alone.
            predicates = []
            for col, op, val in zip(cols, ops, vals):
                if concat_table[col].data.dtype in [np.int64, np.float64]:
                    predicates.append('{} {} {}'.format(col, op, val))
                else:
                    literal = str(val).replace("'", "''")
                    predicates.append('{} {} \'{}\''.format(col, op, literal))
            filter_clauses = '\n AND '.join(predicates)

            sql = template_for_execution.render(table_names=table_names,
                                                join_clauses=join_clauses,
                                                filter_clauses=filter_clauses)
            sql_queries.append(sql)

            if len(queries) >= num_queries:
                break

    # The generated SQL is only printed; nothing is executed, so no true
    # cardinalities are collected.
    for sql_query in sql_queries:
        print(sql_query)

    true_cards = []
    return queries, true_cards

In [7]:
import dataset_util
def main():
    """Generate filter queries for every join template of the STATS workload.

    Loads the tables and the chosen workload, groups the workload by the set
    of joined tables (one "template" per distinct set), and calls MakeQueries
    for each template so that roughly NUM_QUERIES queries are produced in
    total.  The last template also receives the remainder of the division.

    If the output CSV already exists, its line count is used to skip the
    templates that a previous run is assumed to have completed.

    Returns:
      None.  Progress and the generated SQL are printed.
    """
    dataset_dir = 'data/STATS_NEW/'
    queries_path = 'workload/stats_query_chosen.csv'
    output_csv = 'workload/stats_gen.csv'
    NUM_QUERIES = 100

    def _count_lines(path):
        # Counted locally so this cell does not rely on a helper that is not
        # defined in the notebook; an empty file yields 0.
        with open(path) as f:
            return sum(1 for _ in f)

    tables = dataset_util.LoadSTATS(data_dir=dataset_dir, use_cols=None)
    print("Load tables.")
    workload_queries = dataset_util.STATSToQuery(queries_path,
                                                 use_alias_keys=False)
    print("Load queries.")

    # One entry per distinct set of joined tables.
    tables_to_join_keys = {}
    for query in workload_queries:
        key = MakeTablesKey(query[0])
        if key in tables_to_join_keys:
            continue
        # Disambiguate: title->id changed to title->title.id.
        # TODO: only support a single join key
        # A fresh dict is built so the parsed workload entry is not mutated.
        tables_to_join_keys[key] = {
            table_name: common.JoinTableAndColumnNames(
                table_name, next(iter(keys)), sep='.')
            for table_name, keys in query[1].items()
        }

    num_templates = len(tables_to_join_keys)
    print(num_templates, 'join templates.')
    if num_templates == 0:
        # Nothing to generate; also avoids dividing by zero below.
        return
    num_queries_per_template = NUM_QUERIES // num_templates

    rng = np.random.RandomState(1234)
    generated_queries = []  # [(cols, ops, vals)]

    # Disambiguate to not prune away stuff during join sampling.
    for table in tables.values():
        for col in table.columns:
            col.name = common.JoinTableAndColumnNames(table.name,
                                                      col.name,
                                                      sep='.')
        table.data.columns = [col.name for col in table.columns]

    # Generate queries.
    last_run_queries = (_count_lines(output_csv)
                        if os.path.exists(output_csv) else 0)
    # With more templates than NUM_QUERIES the per-template count is 0;
    # do not divide by it.
    if num_queries_per_template > 0:
        next_template_idx = last_run_queries // num_queries_per_template
    else:
        next_template_idx = 0
    print('next_template_idx', next_template_idx)
    print(tables_to_join_keys.items())

    for i, (tables_to_join, join_keys) in enumerate(tables_to_join_keys.items()):

        if i < next_template_idx:
            print('Skipping template:', tables_to_join)
            continue
        print('Template:', tables_to_join)

        if i == num_templates - 1:
            num_queries_per_template += NUM_QUERIES % num_templates

        # Generate num_queries_per_template.
        table_names = tables_to_join.split('-')

        tables_in_templates = [tables[n] for n in table_names]

        # MakeQueries returns (queries, true_cards); only the query triples
        # belong in the accumulated list.
        template_queries, _ = MakeQueries(num_queries_per_template,
                                          tables_in_templates, table_names,
                                          join_keys, rng)
        generated_queries.extend(template_queries)


In [3]:
import csv
def choose_queries(queries_path, use_alias_keys=True):
    """Copy the queries that join at most three tables into a new file.

    Each line of `queries_path` is treated as '#'-separated fields whose first
    field is a comma-separated table list.  Lines whose table list has three
    entries or fewer are written, unchanged and in order, to a sibling file
    named "<queries_path without extension>_chosen.csv".

    Args:
      queries_path: path of the workload file to filter.
      use_alias_keys: unused; kept so existing callers keep working.

    Returns:
      None.
    """
    with open(queries_path) as sql_file:
        chosen_lines = [
            line for line in sql_file
            if len(line.split("#")[0].split(',')) <= 3
        ]

    # splitext instead of slicing off four characters, so paths without a
    # ".csv"-length extension are not truncated.
    base_path, _ = os.path.splitext(queries_path)
    with open(base_path + "_chosen.csv", "w") as out_file:
        out_file.writelines(chosen_lines)

# choose_queries('workload/stats_sub_query.csv')

In [12]:
# Run the query-generation pipeline defined in main().
# NOTE(review): the output saved with this cell ends in
# "AssertionError: ['badges', 'users']"; re-run from a fresh kernel once that
# is resolved so the recorded output matches the code.
main()

Loaded parsed Table from data/STATS_NEW/badges.table
Loaded parsed Table from data/STATS_NEW/votes.table
Loaded parsed Table from data/STATS_NEW/postHistory.table
Loaded parsed Table from data/STATS_NEW/posts.table
Loaded parsed Table from data/STATS_NEW/users.table
Loaded parsed Table from data/STATS_NEW/comments.table
Loaded parsed Table from data/STATS_NEW/postLinks.table
Loaded parsed Table from data/STATS_NEW/tags.table
Load tables.
Load queries.
17 join templates.
next_template_idx 0
dict_items([('badges-users', defaultdict(<class 'set'>, {'badges': 'badges.UserId', 'users': 'users.Id'})), ('badges-comments', defaultdict(<class 'set'>, {'comments': 'comments.UserId', 'badges': 'badges.UserId'})), ('comments-postHistory', defaultdict(<class 'set'>, {'comments': 'comments.UserId', 'postHistory': 'postHistory.UserId'})), ('comments-votes', defaultdict(<class 'set'>, {'comments': 'comments.UserId', 'votes': 'votes.UserId'})), ('badges-posts', defaultdict(<class 'set'>, {'badges': 'ba

AssertionError: ['badges', 'users']