In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd

from cardinality_estimation.featurizer import Featurizer
from query_representation.query import load_qrep

import glob
import random
import os
import json
import time
from collections import defaultdict

from IPython.display import HTML
from IPython.display import display_html

In [None]:
# Directory holding one sub-directory of pickled queries per template.
# Alternative workloads that can be swapped in:
#   TRAINDIR = os.path.join("queries", "imdb")
#   TRAINDIR = os.path.join("queries", "tpcds")
TRAINDIR = os.path.join("queries", "job")
print(TRAINDIR)

In [None]:

def load_qdata(fns):
    """Load the query representation stored at each path in ``fns``.

    Every loaded representation is annotated in place with two extra keys:

    * ``"name"``          -- the file name of the query (basename of the path)
    * ``"template_name"`` -- the name of the directory that contains the file

    Args:
        fns: iterable of paths accepted by ``load_qrep``.

    Returns:
        A list with one annotated representation per path, in input order.
    """
    loaded = []
    for path in fns:
        rep = load_qrep(path)
        # TODO: can do checks like no queries with zero cardinalities etc.
        rep["name"] = os.path.basename(path)
        rep["template_name"] = os.path.basename(os.path.dirname(path))
        loaded.append(rep)
    return loaded

def get_query_fns(basedir, template_fraction=1.0):
    """Collect the ``*.pkl`` query files below ``basedir``, sampled per template.

    Every sub-directory of ``basedir`` is treated as one template. From each
    template ``max(int(n * template_fraction), 1)`` of its ``n`` pickle files
    are drawn with a fixed seed, so repeated calls return the same files.

    Args:
        basedir: directory whose sub-directories each hold the queries of one
            template. Plain files directly inside ``basedir`` are ignored.
        template_fraction: fraction (at most 1.0) of each template's queries
            to keep; at least one query is kept per non-empty template.

    Returns:
        A flat list of selected file paths, templates in sorted order.

    Raises:
        AssertionError: if ``template_fraction`` is greater than 1.0.
    """
    assert template_fraction <= 1.0
    fns = []
    # glob returns entries in filesystem order; sort so the result is
    # reproducible across machines.
    tmpnames = sorted(glob.glob(os.path.join(basedir, "*")))
    print(tmpnames)
    for qdir in tmpnames:
        if os.path.isfile(qdir):
            continue
        print(qdir)
        # let's first select all the qfns we are going to load
        qfns = sorted(glob.glob(os.path.join(qdir, "*.pkl")))
        if not qfns:
            # Sampling one item from an empty template would raise ValueError.
            continue
        num_samples = max(int(len(qfns) * template_fraction), 1)
        # A private generator with the same seed yields the same sample as
        # random.seed(1234) + random.sample, without resetting the global RNG.
        rng = random.Random(1234)
        fns += rng.sample(qfns, num_samples)
    return fns

In [None]:
# Use every query of every template (fraction 1.0), then load them all.
train_qfns = get_query_fns(TRAINDIR, template_fraction=1.0)
trainqs = load_qdata(train_qfns)

In [None]:
# Number of query representations that were loaded.
num_train_queries = len(trainqs)
print(num_train_queries)

In [None]:
# Per predicate column: the largest number of constants seen in one predicate.
constantmaxs = defaultdict(int)
# Per predicate column: the set of distinct constants seen across all queries.
allconstants = defaultdict(set)

In [None]:
# For every predicate column, gather the distinct constants it is compared
# against (allconstants) and the largest constant list of any single
# predicate on that column (constantmaxs).
for query in trainqs:
    join_graph = query["join_graph"]
    for node in join_graph.nodes():
        node_info = join_graph.nodes()[node]
        if "pred_cols" not in node_info:
            continue
        for ci, col in enumerate(node_info["pred_cols"]):
            consts = node_info["pred_vals"][ci]
            # Values may be wrapped as {"literal": ...}; unwrap first and only
            # then normalise to a list, so that a scalar (e.g. string) literal
            # is counted as one constant rather than iterated per character.
            if isinstance(consts, dict):
                consts = consts["literal"]
            if not isinstance(consts, list):
                consts = [consts]

            for const in consts:
                if isinstance(const, dict):
                    const = const["literal"]
                allconstants[col].add(const)
            constantmaxs[col] = max(constantmaxs[col], len(consts))

In [None]:
# One line per column: name, number of distinct constants, longest constant list.
for col_name, distinct_consts in allconstants.items():
    num_distinct = len(distinct_consts)
    print(col_name, num_distinct, constantmaxs[col_name])

In [None]:
# Build one row of predicate statistics for every base table (join-graph node)
# that carries at least one filter predicate. `data` maps column name -> list
# of per-row values and is turned into a DataFrame afterwards.
data = defaultdict(list)
for query in trainqs:
    jg = query["join_graph"]
    sg = query["subset_graph"]

    for node in jg.nodes():
        node_info = jg.nodes()[node]
        # Tables without filter predicates contribute no row.
        if "pred_cols" not in node_info:
            continue
        if len(node_info["pred_cols"]) == 0:
            continue

        # The single-table entry of the subset graph is keyed by a 1-tuple
        # holding the node.
        cards = sg.nodes()[(node,)]["cardinality"]
        # Clamp at 1.0 in case "actual" exceeds "total".
        sel = min(cards["actual"] / cards["total"], 1.00)
        curcard = cards["actual"]

        seencols = []
        seenops = []
        consts = []

        for ci, col in enumerate(node_info["pred_cols"]):
            op = node_info["pred_types"][ci]
            if op not in seenops:
                seenops.append(op)
            if col not in seencols:
                # Record the column itself (previously the list was appended
                # to itself, so columns were never deduplicated).
                seencols.append(col)

            vals = node_info["pred_vals"][ci]
            # Unwrap {"literal": ...} and treat any scalar (int, float, str)
            # as a single constant instead of extending by its characters.
            if isinstance(vals, dict):
                vals = vals["literal"]
            if not isinstance(vals, (list, tuple)):
                vals = [vals]
            consts.extend(vals)

        # These two columns are summarised by a later cell of this notebook
        # (df[["selectivity", "cardinality"]]), so they must be populated.
        data["selectivity"].append(sel)
        data["cardinality"].append(curcard)

        data["input"].append(node_info["real_name"])

        # 0/1 indicator flags for the operator kinds present on this table.
        data["like_ops"].append(int("like" in seenops))
        data["cont_ops"].append(int("lt" in seenops))
        data["in_ops"].append(int("in" in seenops))

        has_discrete = "in" in seenops or "eq" in seenops
        data["discrete_ops"].append(int(has_discrete))
        # Counts the constants of all predicates on this table, but only when
        # at least one of them is an "in"/"eq" predicate.
        data["num_discrete_consts"].append(len(consts) if has_discrete else 0)

        data["num_ops"].append(len(node_info["pred_types"]))
        data["num_cols_all"].append(len(node_info["pred_cols"]))
        data["num_unique_ops"].append(len(seenops))
        data["unique_filter_cols"].append(len(seencols))
        # Placeholder: always 0.0 here.
        data["equal_dates"].append(0.0)

In [None]:
# One DataFrame column per key of `data`, one row per filtered table.
df = pd.DataFrame(data=data)

In [None]:
# List the columns available in the statistics frame.
column_index = df.keys()
print(column_index)

In [None]:
# Summary statistics (incl. 90th / 99th percentile) of the count-like columns.
count_cols = ["num_ops", "num_unique_ops", "unique_filter_cols",
              "num_discrete_consts"]
count_stats = df[count_cols].describe(percentiles=[0.9, 0.99]).reset_index()
HTML(count_stats.to_html(index=False))

In [None]:
# Summary statistics of the 0/1 operator flags. "equal_dates" is left out on
# purpose: it is filled with the constant 0.0 when the rows are built.
flag_cols = ["like_ops", "discrete_ops", "cont_ops", "in_ops"]
flag_stats = df[flag_cols].describe(percentiles=[0.9, 0.99]).reset_index()
HTML(flag_stats.to_html(index=False))

In [None]:
# Summary statistics of per-table selectivity and cardinality, rounded to
# three decimals for display.
card_cols = ["selectivity", "cardinality"]
card_stats = df[card_cols].describe(percentiles=[0.9, 0.99]).reset_index()
HTML(card_stats.round(3).to_html(index=False))

In [None]:
#df.groupby("input").count()

In [None]:
# Preview the first five rows of the statistics frame.
df.head(n=5)