In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd

from cardinality_estimation.featurizer import Featurizer
from query_representation.query import load_qrep

import glob
import random
import os
import json
import time
from collections import defaultdict

In [None]:
# Relative path (resolved against the current working directory) of the
# directory that holds one sub-directory of query files per template.
TRAINDIR = os.path.join("queries", "mlsys1-train")

In [None]:

def load_qdata(fns):
    """Load the query representation stored at each path in ``fns``.

    Every loaded object is tagged in place with two extra keys:
    ``"name"`` (the base name of its file) and ``"template_name"`` (the name
    of the directory that contains the file).

    Returns:
        list of the loaded objects, in the same order as ``fns``.
    """
    loaded = []
    for path in fns:
        qrep = load_qrep(path)
        # TODO: can do checks like no queries with zero cardinalities etc.
        qrep["name"] = os.path.basename(path)
        qrep["template_name"] = os.path.basename(os.path.dirname(path))
        loaded.append(qrep)
    return loaded

def get_query_fns(basedir, template_fraction=1.0):
    """Pick a reproducible subset of query files from every template directory.

    Each entry directly under ``basedir`` that is not a regular file is
    treated as one template directory. Its ``*.pkl`` files are sorted and a
    fixed-seed random sample of them is taken, so the same files are chosen
    on every call.

    Args:
        basedir: directory whose sub-directories contain ``*.pkl`` files.
        template_fraction: fraction (at most 1.0) of each template's files to
            keep. At least one file is kept for every non-empty template.

    Returns:
        list of the selected file paths, grouped by template directory with
        the template directories visited in sorted order.

    Raises:
        AssertionError: if ``template_fraction`` is greater than 1.0.
    """
    assert template_fraction <= 1.0, "template_fraction must be <= 1.0"

    fns = []
    # glob returns entries in filesystem order; sort so the result order does
    # not vary between machines.
    for qdir in sorted(glob.glob(os.path.join(basedir, "*"))):
        if os.path.isfile(qdir):
            continue
        # Sort before sampling so the seeded sample always sees the same list.
        qfns = sorted(glob.glob(os.path.join(qdir, "*.pkl")))
        if not qfns:
            # Nothing to sample; random.sample would raise on an empty
            # population because at least one sample is always requested.
            continue
        num_samples = max(int(len(qfns) * template_fraction), 1)
        # A private generator seeded with 1234 yields the same picks as
        # random.seed(1234) + random.sample, without resetting the
        # process-wide random state for the rest of the notebook.
        fns += random.Random(1234).sample(qfns, num_samples)
    return fns

In [None]:
# Keep 10% of the query files of each template directory under TRAINDIR,
# then load the selected files.
train_qfns = get_query_fns(TRAINDIR, template_fraction=0.1)
trainqs = load_qdata(train_qfns)

In [None]:
# Per-column accumulators, keyed by predicate column:
#   allconstants[col] -> set of every distinct constant collected for col
#   constantmaxs[col] -> largest number of constants collected at once for col
allconstants: defaultdict = defaultdict(set)
constantmaxs: defaultdict = defaultdict(int)

In [None]:
# Walk every node of every loaded query's join graph and record, per
# predicate column, the constants it is compared against.
for query in trainqs:
    join_graph = query["join_graph"]
    for node in join_graph.nodes():
        node_attrs = join_graph.nodes()[node]
        # "pred_vals" is indexed in parallel with "pred_cols": entry ci holds
        # the constants that belong to column ci.
        for ci, col in enumerate(node_attrs["pred_cols"]):
            consts = node_attrs["pred_vals"][ci]
            for const in consts:
                allconstants[col].add(const)
            # Track the longest constant list seen so far for this column.
            constantmaxs[col] = max(constantmaxs[col], len(consts))
            # Debugging aid: to inspect unusually long constant lists
            # (e.g. len(consts) > 20), print consts, query["sql"] and
            # query["name"] at this point.

In [None]:
# One line per column: column, number of distinct constants, and the longest
# constant list recorded for that column.
for column, values in allconstants.items():
    distinct_count = len(values)
    print(f"{column} {distinct_count} {constantmaxs[column]}")