In [15]:
from IPython.display import FileLink

In [43]:
Parser("SELECT uuid_field, int_field5 FROM jungle WHERE int_field5 >= 95487 AND int_field5 < 96074").columns

['uuid_field', 'int_field5']

In [1]:
import pandas as pd
import re
from sql_metadata import Parser
from collections import defaultdict
from itertools import combinations
import psycopg2 as pg
from pglast import parser
from random import shuffle

CONNECTION_STRING = "dbname='project1db' user='project1user' password='project1pass' host='localhost'"

def get_table_column_map(conn):
    QUERY = "select table_name, column_name from information_schema.columns where table_schema='public';"
    
    table_columns_map = defaultdict(set)
    with conn.cursor() as cursor:
        cursor.execute(QUERY)
        db_columns = cursor.fetchall()
        for table, column in db_columns:
            table_columns_map[table].add(column)
    return table_columns_map
        


def filter_interesting_queries(queries):
    res = [re.sub(r'^statement: ', '', q)
           for q in queries if q.startswith('statement')]
    res = [q for q in res if 'pg_' not in q and not q.startswith('SHOW ALL') and not q.startswith(
        'COMMIT') and not q.startswith('SET') and not q.startswith('BEGIN')]
    res = [q for q in res if 'WHERE' in q or 'ORDER' in q or 'JOIN' in q or 'join' in q or 'where' in q or 'order' in q]
    return res


def cluster_queries(queries):
    group_counts = defaultdict(int)
    group_repr = {}
    gqueries = set()
    for query in queries:
        try:
            generalized = parser.fingerprint(query)
            group_counts[generalized] += 1
            group_repr[generalized] = query
            gqueries.add(generalized)
        except:
            pass
    return gqueries, group_counts, group_repr


def get_relevant_columns(clusters, cluster_counts, table_column_map):
    column_usage_counts = defaultdict(int)
    tables = []
    table_columns = defaultdict(set)

    for qgroup, query in clusters.items():
        try:
            pq = Parser(query)
            tables.extend(pq.tables)
            columns = pq.columns_dict
            pq_tables = pq.tables
            groups = ['where', 'order_by', 'join']
            for group in groups:
                if group in columns:
                    standardised_columns = []
                    for col in columns[group]:
                        if '.' not in col:
                            for table in pq_tables:
                                if table in table_column_map and col in table_column_map[table]:
                                    col = table + '.' + col
                                    break
                        column_usage_counts[col] += cluster_counts[qgroup]
                        tab, c = col.split('.')
                        table_columns[tab].add(col)
        except:
            pass
    tables = set(tables)
    ordered_columns = list(
        reversed(sorted(list(column_usage_counts.items()), key=lambda x: x[1])))
    return ordered_columns, table_columns


def get_combinations_list(table_columns, max_index_width):
    combs = []
    # all_column_combinations = [_ for _ in ordered_columns]
    
    all_columns_combinations = []
    for table, cols in table_columns.items():
        tcols = list(cols)
        for width in range(1, max_index_width+1):
            all_columns_combinations.extend(combinations(tcols, width))
    for i in range(1, len(all_columns_combinations)+1):
        combs.extend(combinations(all_columns_combinations, i))

    shuffle(combs)
    return combs


def get_columns_from_logs(logs_path):
    QCOL = 13
    df = pd.read_csv(logs_path, header=None)
    queries = filter_interesting_queries(df[QCOL].tolist())
    with pg.connect(CONNECTION_STRING) as conn:
        table_column_map = get_table_column_map(conn)
        print(table_column_map)
    gqueries, group_counts, group_repr = cluster_queries(queries)
    cols, table_columns = get_relevant_columns(group_repr, group_counts, table_column_map)
    return cols, table_columns, group_counts, group_repr


def generate_index_creation_queries(columns):
    query_template = 'CREATE INDEX ON {} ({})'
    queries = []
    
    for column_group in columns:
        
        table_name = column_group[0].split('.')[0]
        
        columns = [column.split('.')[1] for column in column_group]   
        
        queries.append(query_template.format(table_name, ', '.join(columns)))
        
    return queries


def create_hypothetical_indexes(index_queries, conn):
    hypo_template = "SELECT * FROM hypopg_create_index('{}');"
    with conn.cursor() as cur:
        for index_creation_query in index_queries:
            cur.execute(hypo_template.format(index_creation_query))
            res = cur.fetchall()


def get_scaled_loss(group_counts, costs):
    cost = 0.
    for cluster, count in group_counts.items():
        cost += count*costs[cluster]
    return cost


def remove_hypo_indexes(conn):
    reset_indexes_q = 'SELECT * FROM hypopg_reset();'
    with conn.cursor() as cur:
        cur.execute(reset_indexes_q)
        res = cur.fetchall()


def enable_hypopg(conn):
    with conn.cursor() as cur:
        cur.execute('CREATE EXTENSION IF NOT EXISTS hypopg;')


def get_query_costs(query_clusters, conn):
    costs = dict()
    with conn.cursor() as cur:
        for cluster, query in query_clusters.items():
            cur.execute('EXPLAIN (FORMAT JSON) '+query)
            explain_res = cur.fetchall()
            costs[cluster] = explain_res[0][0][0]['Plan']['Total Cost']
    return costs


def find_best_index(log_file_path):
    cols, table_columns, group_counts, group_repr = get_columns_from_logs(
        log_file_path)
    cmbns = get_combinations_list(cols, table_columns)
    best_config = []
    with pg.connect("dbname='project1db' user='project1user' password='project1pass' host='localhost'") as conn:
        baseline_costs = get_query_costs(group_repr, conn)
        baseline_cost = get_scaled_loss(group_counts, baseline_costs)
        baseline_cost
        min_cost = baseline_cost
        best_config = []
        enable_hypopg(conn)
        for cmb in cmbns:
            index_q = generate_index_creation_queries(cmb)
            create_hypothetical_indexes(index_q, conn)
            costs = get_query_costs(group_repr, conn)
            cost = get_scaled_loss(group_counts, costs)
            remove_hypo_indexes(conn)
            print(cmb, cost)
            if cost < min_cost or (cost == min_cost and len(cmb) < len(best_config)):
                min_cost = cost
                best_config = cmb

    index_creation_queries = generate_index_creation_queries(best_config)
    with open('actions.sql', 'w') as f:
        for query in index_creation_queries:
            
            f.write("{};\n".format(query))
    with open('config.json', 'w') as f:
        f.write('{"VACUUM": false}')

In [3]:
# log_file_path = 'epinions.csv'
log_file_path = 'indexjungle.csv'

cols, table_columns, group_counts, group_repr = get_columns_from_logs(
        log_file_path)

  df = pd.read_csv(logs_path, header=None)


defaultdict(<class 'set'>, {'jungle': {'varchar_field8', 'timestamp_field3', 'int_field9', 'timestamp_field8', 'timestamp_field1', 'int_field0', 'int_field1', 'timestamp_field0', 'varchar_field5', 'int_field5', 'varchar_field2', 'varchar_field4', 'float_field7', 'float_field8', 'int_field4', 'int_field7', 'float_field2', 'timestamp_field7', 'varchar_field0', 'timestamp_field5', 'float_field1', 'varchar_field7', 'int_field6', 'int_field2', 'timestamp_field4', 'int_field3', 'varchar_field1', 'varchar_field6', 'float_field9', 'float_field3', 'varchar_field9', 'float_field0', 'timestamp_field6', 'timestamp_field9', 'float_field4', 'timestamp_field2', 'uuid_field', 'int_field8', 'varchar_field3', 'float_field5', 'float_field6'}, 'hypopg_list_indexes': {'am_name', 'table_name', 'index_name', 'indexrelid', 'schema_name'}})


In [5]:
print(len(table_columns['jungle']))
table_columns

14


defaultdict(set,
            {'jungle': {'jungle.float_field0',
              'jungle.float_field1',
              'jungle.float_field2',
              'jungle.float_field3',
              'jungle.float_field4',
              'jungle.float_field5',
              'jungle.float_field6',
              'jungle.float_field7',
              'jungle.float_field8',
              'jungle.float_field9',
              'jungle.int_field1',
              'jungle.int_field5',
              'jungle.int_field7',
              'jungle.uuid_field'}})

In [6]:
MAX_WIDTH = 1
cmbns = get_combinations_list(table_columns, MAX_WIDTH)

In [7]:
len(cmbns)

16383

In [15]:
cmbns[:10]

[(('jungle.int_field5',),
  ('jungle.float_field6',),
  ('jungle.float_field4',),
  ('jungle.float_field9',),
  ('jungle.float_field5',),
  ('jungle.int_field1',)),
 (('jungle.float_field8',),
  ('jungle.float_field6',),
  ('jungle.float_field4',),
  ('jungle.float_field7',),
  ('jungle.uuid_field',),
  ('jungle.float_field1',),
  ('jungle.float_field5',)),
 (('jungle.float_field8',),
  ('jungle.int_field5',),
  ('jungle.float_field4',),
  ('jungle.float_field9',),
  ('jungle.uuid_field',),
  ('jungle.float_field0',)),
 (('jungle.float_field8',),
  ('jungle.int_field5',),
  ('jungle.float_field4',),
  ('jungle.float_field7',),
  ('jungle.float_field9',),
  ('jungle.float_field1',),
  ('jungle.float_field3',),
  ('jungle.float_field5',),
  ('jungle.float_field0',)),
 (('jungle.float_field8',),
  ('jungle.int_field5',),
  ('jungle.float_field2',),
  ('jungle.uuid_field',),
  ('jungle.float_field1',),
  ('jungle.float_field3',),
  ('jungle.float_field5',),
  ('jungle.float_field0',)),
 ((

In [8]:
generate_index_creation_queries(cmbns[100])

['CREATE INDEX ON jungle (float_field8)',
 'CREATE INDEX ON jungle (int_field5)',
 'CREATE INDEX ON jungle (float_field6)',
 'CREATE INDEX ON jungle (float_field7)',
 'CREATE INDEX ON jungle (uuid_field)',
 'CREATE INDEX ON jungle (float_field1)',
 'CREATE INDEX ON jungle (float_field3)',
 'CREATE INDEX ON jungle (int_field1)']

In [9]:
best_config = []
with open('config.json', 'w') as f:
    f.write('{"VACUUM": false}')
with pg.connect(CONNECTION_STRING) as conn:
    baseline_costs = get_query_costs(group_repr, conn)
    baseline_cost = get_scaled_loss(group_counts, baseline_costs)
    min_cost = baseline_cost
    best_config = []
    enable_hypopg(conn)
    for i, cmb in enumerate(cmbns):
        index_q = generate_index_creation_queries(cmb)
        create_hypothetical_indexes(index_q, conn)
        costs = get_query_costs(group_repr, conn)
        cost = get_scaled_loss(group_counts, costs)
        remove_hypo_indexes(conn)
#         print(cmb, cost)
        if cost < min_cost or (cost == min_cost and len(cmb) < len(best_config)):
            min_cost = cost
            best_config = cmb
        if i%1000 == 0:
            with open('actions.sql', 'w') as f:
                for query in index_q:
                    f.write("{};\n".format(query))

index_creation_queries = generate_index_creation_queries(best_config)
with open('actions.sql', 'w') as f:
    for query in index_creation_queries:

        f.write("{};\n".format(query))

# print('#')

In [10]:
index_creation_queries

['CREATE INDEX ON jungle (int_field5)',
 'CREATE INDEX ON jungle (int_field7)',
 'CREATE INDEX ON jungle (uuid_field)',
 'CREATE INDEX ON jungle (int_field1)']

In [11]:
# width 3
min_cost

940959.0200000004

In [37]:
# width 1
min_cost

484797.90999999957

In [30]:
index_creation_queries

['CREATE INDEX ON item (i_id)',
 'CREATE INDEX ON review (i_id, u_id)',
 'CREATE INDEX ON review (u_id, creation_date)',
 'CREATE INDEX ON useracct (u_id)',
 'CREATE INDEX ON trust (source_u_id)',
 'CREATE INDEX ON trust (target_u_id, source_u_id)']

In [12]:
# width 2
min_cost, baseline_cost

(940959.0200000004, 15786295.0)