In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import sys
sys.path.append("..")

from cardinality_estimation.featurizer import Featurizer
from query_representation.query import load_qrep
from cardinality_estimation.dataset import *
from torch.utils import data

import glob
import random
import os
import json
import time
import matplotlib.pyplot as plt

# Setup file paths / Download query data

In [None]:
import errno
def make_dir(directory):
    """Create ``directory`` (including missing parents) if it is not there yet.

    Calling this on a directory that already exists is a no-op.

    Parameters
    ----------
    directory : str
        Path of the directory to create.

    Raises
    ------
    OSError
        If the directory cannot be created, including the case where
        ``directory`` already exists but is not a directory.
    """
    # exist_ok=True only tolerates an existing *directory*; a regular file at
    # the same path is reported instead of being silently accepted.
    os.makedirs(directory, exist_ok=True)

In [None]:
# Alternative training workload:
#TRAINDIR = os.path.join("..", "queries", "imdb-unique-plans")

# Query directories and the results directory, all relative to the notebook's
# working directory.
TRAINDIR = os.path.join("..", "queries", "synth_2d_gaussian100K_lt")
TESTDIR = os.path.join("..", "queries", "synth_2d_gaussian100K_gt")
RESULTDIR = os.path.join("..", "results")

# Make sure the results directory exists before anything tries to write to it.
make_dir(RESULTDIR)

# Query loading helper functions

In [None]:
def load_qdata(fns):
    """Load the query representations stored at the given file paths.

    Each loaded representation is tagged with three extra keys:

    * ``"name"``: base name of the file it was read from,
    * ``"template_name"``: name of the directory containing that file,
    * ``"workload"``: always the string ``"imdb"``, whatever the input path.

    Parameters
    ----------
    fns : iterable of str
        Paths passed one by one to ``load_qrep``.

    Returns
    -------
    list
        The loaded representations, in the order of ``fns``.
    """
    loaded = []
    for path in fns:
        query = load_qrep(path)
        # TODO: can do checks like no queries with zero cardinalities etc.
        query["name"] = os.path.basename(path)
        query["template_name"] = os.path.basename(os.path.dirname(path))
        query["workload"] = "imdb"
        loaded.append(query)
    return loaded

def get_query_fns(basedir, template_fraction=1.0, sel_templates=None):
    """Collect the ``*.pkl`` query files below ``basedir``, one sample per template.

    Every sub-directory of ``basedir`` is treated as a query template; plain
    files directly inside ``basedir`` are ignored.

    Parameters
    ----------
    basedir : str
        Directory whose sub-directories hold the ``*.pkl`` query files.
    template_fraction : float, optional
        Fraction (at most 1.0) of each template's files to keep. At least one
        file is kept for every template that has any.
    sel_templates : collection of str, optional
        If given, only templates whose directory name is in this collection
        are used.

    Returns
    -------
    list of str
        Selected file paths, grouped by template in sorted template order.
    """
    assert template_fraction <= 1.0

    fns = []
    # Sorted so the result does not depend on the filesystem's listing order.
    for qdir in sorted(glob.glob(os.path.join(basedir, "*"))):
        if os.path.isfile(qdir):
            continue
        template_name = os.path.basename(qdir)
        if sel_templates is not None and template_name not in sel_templates:
            continue

        qfns = sorted(glob.glob(os.path.join(qdir, "*.pkl")))
        if not qfns:
            # Nothing to sample from; sampling one item from an empty list
            # would raise ValueError.
            continue

        num_samples = max(int(len(qfns) * template_fraction), 1)
        # A private generator with the fixed seed draws the same sample as
        # re-seeding the module-level generator would, without resetting the
        # global random state for the rest of the notebook.
        fns += random.Random(1234).sample(qfns, num_samples)
    return fns

# Evaluation helper functions

In [None]:
def get_preds(alg, qreps):
    """Return the estimates ``alg.test`` produces for ``qreps``.

    ``qreps`` may hold either loaded query representations or file paths; the
    first element decides which. Paths are loaded with ``load_qdata`` first.
    """
    given_paths = isinstance(qreps[0], str)
    queries = load_qdata(qreps) if given_paths else qreps
    return alg.test(queries)

def eval_alg(alg, eval_funcs, qreps, samples_type, result_dir="./results/"):
    """Run ``alg`` on ``qreps`` and print summary statistics for each metric.

    Parameters
    ----------
    alg
        Estimator exposing ``test(qreps)``, ``get_exp_name()`` and ``__str__``.
    eval_funcs : iterable
        Metric objects; each one's ``eval`` is called once and must return a
        sequence of per-sample errors.
    qreps : list
        Loaded query representations, or file paths (loaded via ``load_qdata``
        when the first element is a string).
    samples_type : str
        Label forwarded to the metrics and printed in the summary line.
    result_dir : str or None, optional
        When not None, ``<result_dir>/<experiment name>`` is created and
        passed to every metric.

    Returns
    -------
    None
        Results are only printed. Note that numpy's global float print format
        is changed as a side effect.
    """
    np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})

    alg_name = alg.__str__()
    exp_name = alg.get_exp_name()

    if isinstance(qreps[0], str):
        # only file paths sent
        qreps = load_qdata(qreps)

    ests = alg.test(qreps)

    summary_fmt = "{}, {}, #samples: {}, {}: mean: {}, median: {}, 99p: {}"
    for efunc in eval_funcs:
        if result_dir is None:
            rdir = None
        else:
            rdir = os.path.join(result_dir, exp_name)
            make_dir(rdir)

        errors = efunc.eval(
            qreps,
            ests,
            samples_type=samples_type,
            result_dir=rdir,
            num_processes=-1,
            alg_name=alg_name,
            use_wandb=0,
        )

        print(summary_fmt.format(
            samples_type,
            alg,
            len(errors),
            efunc.__str__(),
            np.round(np.mean(errors), 3),
            np.round(np.median(errors), 3),
            np.round(np.percentile(errors, 99), 3),
        ))

# Load queries

In [None]:
# Set template_fraction < 1.0 to test quickly with a smaller sample of each template.
# Example (NOTE: VALDIR is not defined in the path-setup cell of this notebook):
# train_qfns = get_query_fns(TRAINDIR, template_fraction = 0.001)
# val_qfns = get_query_fns(VALDIR, template_fraction = 1.0)
# test_qfns = get_query_fns(TESTDIR, template_fraction = 1.0)

# Use every query file for training and testing; no validation set is used.
train_qfns = get_query_fns(TRAINDIR, template_fraction = 1.0)
val_qfns = []
test_qfns = get_query_fns(TESTDIR, template_fraction = 1.0)

print("Selected {} training queries, {} validation queries, {} test queries".\
      format(len(train_qfns), len(val_qfns), len(test_qfns)))

In [None]:
# Metrics handed to eval_alg. Only QError is enabled; SimplePlanCost is
# imported but its registration below is commented out.
from evaluation.eval_fns import QError, SimplePlanCost
EVAL_FNS = []
EVAL_FNS.append(QError())
#EVAL_FNS.append(SimplePlanCost())

In [None]:
def init_featurizer(featurization_type):
    """Build a Featurizer configured from saved database stats and the training queries.

    NOTE: besides its argument, this function reads the notebook globals
    ``TRAINDIR``, ``onehot_dropout`` and ``trainqs``. The latter two are
    assigned in cells further down, so those cells must be executed before
    this function is called.

    Parameters
    ----------
    featurization_type
        Forwarded unchanged to ``Featurizer.setup`` (this notebook passes ``"set"``).

    Returns
    -------
    Featurizer
        The featurizer after setup and after its statistics were updated from
        ``trainqs``.
    """
    # Load database specific data, e.g., information about columns, tables etc.
    dbdata_fn = os.path.join(TRAINDIR, "dbdata.json")
    # NOTE(review): the positional arguments look like user / password / database
    # name placeholders -- confirm against the Featurizer constructor.
    featurizer = Featurizer("user", "pwd", "synth", None, None)


    with open(dbdata_fn, "r") as f:
        dbdata = json.load(f)

    # Statistics come from the saved JSON file loaded above.
    featurizer.update_using_saved_stats(dbdata)

    featurizer.setup(ynormalization="log",
        feat_separate_alias = 0,
        onehot_dropout = onehot_dropout,
        feat_mcvs = 0,
        heuristic_features = 1,
        featurization_type=featurization_type,
        table_features=1,
        flow_features = 0,
        join_features= "onehot",
        set_column_feature= "onehot",
        max_discrete_featurizing_buckets=10,
        max_like_featurizing_buckets=10,
        embedding_fn = "none",
        embedding_pooling = None,
        implied_pred_features = 0,
        feat_onlyseen_preds = 1)
    # The calls below all derive their statistics from the training queries;
    # their relative order is kept exactly as written.
    featurizer.update_ystats(trainqs)

    featurizer.update_max_sets(trainqs)
    featurizer.update_workload_stats(trainqs)
    featurizer.init_feature_mapping()
    #featurizer.update_ystats(trainqs)


    # if feat_onlyseen_preds:
    # just do it always
    featurizer.update_seen_preds(trainqs)

    return featurizer

In [None]:
# Load the training query representations into memory. They are needed both
# by init_featurizer (which reads the global `trainqs`) and by the training call.
trainqs = load_qdata(train_qfns)

In [None]:
# Load the test query representations; used for evaluation and for the
# attribution analysis further down.
testqs = load_qdata(test_qfns)

In [None]:
# Hyper-parameters and options for the MSCN model built in a later cell.

# -- optimisation -------------------------------------------------------------
max_epochs = 100
lr = 0.001
training_opt = "none"
opt_lr = 0.1
swa_start = 5
optimizer_name = "adamw"
clip_gradient = 20.0
loss_func_name = "mse"
weight_decay = 0.0
mb_size = 1024

# -- model / loss options -----------------------------------------------------
mask_unseen_subplans = 0
subplan_level_outputs = 0
normalize_flow_loss = 1
heuristic_unseen_preds = 0
cost_model = "C"
hidden_layer_size = 8
num_hidden_layers = 2

# -- one-hot dropout / regularisation ------------------------------------------
# `onehot_dropout` is also read by init_featurizer and by the final save cell.
onehot_dropout = 0
onehot_mask_truep = 0.8
onehot_reg = 0
onehot_reg_decay = 0.0

# -- data loading, evaluation and logging ---------------------------------------
load_padded_mscn_feats = 1
load_query_together = 0
use_wandb = 0
eval_fns = "qerr,plancost"
eval_epoch = 20000
result_dir = "./results"

In [None]:
#from cardinality_estimation.mscn import MSCN as MSCN2
from cardinality_estimation.mscn import MSCNCaptum as MSCN2

# Build the featurizer (set-based featurization) and the MSCN model from the
# hyper-parameters defined in the configuration cell above.
featurizer = init_featurizer("set")

mscn = MSCN2(max_epochs = max_epochs, lr=lr,
                training_opt = training_opt,
                test_random_bitmap = 0,
                inp_dropout = 0.0,
                hl_dropout = 0.0,
                comb_dropout = 0.0,
                max_num_tables = -1,
                opt_lr = opt_lr,
                swa_start = swa_start,
                mask_unseen_subplans = mask_unseen_subplans,
                subplan_level_outputs=subplan_level_outputs,
                normalize_flow_loss = normalize_flow_loss,
                heuristic_unseen_preds = heuristic_unseen_preds,
                cost_model = cost_model,
                use_wandb = use_wandb,
                eval_fns = eval_fns,
                load_padded_mscn_feats = load_padded_mscn_feats,
                mb_size = mb_size,
                weight_decay = weight_decay,
                load_query_together = load_query_together,
                result_dir = result_dir,
                onehot_dropout=onehot_dropout,
                onehot_mask_truep=onehot_mask_truep,
                onehot_reg=onehot_reg,
                onehot_reg_decay=onehot_reg_decay,
                eval_epoch = eval_epoch,
                optimizer_name=optimizer_name,
                clip_gradient=clip_gradient,
                loss_func_name = loss_func_name,
                hidden_layer_size = hidden_layer_size,
                other_hid_units = hidden_layer_size,
                # Taken from the configuration cell (currently 2) instead of a
                # literal, so editing the setting there actually takes effect.
                num_hidden_layers = num_hidden_layers,
                early_stopping = False,
                random_bitmap_idx = False,
                reg_loss = False,
                )

In [None]:
# Train on the training queries only: no validation or test queries are passed
# to the trainer. Results go to RESULTDIR.
mscn.train(trainqs, valqs=None, testqs=None,
    featurizer=featurizer, result_dir=RESULTDIR)

In [None]:
# NOTE: the import and EVAL_FNS setup repeat an earlier cell of this notebook.
from evaluation.eval_fns import QError, SimplePlanCost
EVAL_FNS = []
EVAL_FNS.append(QError())
#EVAL_FNS.append(SimplePlanCost())

# Evaluate the trained model on the training and the test queries.
# NOTE: no result_dir is passed, so eval_alg uses its default "./results/"
# rather than RESULTDIR.
eval_alg(mscn, EVAL_FNS, trainqs, "train")

#preds = mscn.test(testqs)
eval_alg(mscn, EVAL_FNS, testqs, "test")

In [None]:
#eval_alg(mscn, EVAL_FNS, testqs, "test")

In [None]:
# Per-experiment results path, relative to the notebook's working directory.
# NOTE(review): this is rooted at "results/", not at RESULTDIR ("../results") -- confirm
# which location is intended.
models_path = "results/" + mscn.get_exp_name() + "/"
print(models_path)

In [None]:
# imports from captum library
from captum.attr import LayerConductance, LayerActivation, LayerIntegratedGradients
from captum.attr import IntegratedGradients, DeepLift, GradientShap, NoiseTunnel, FeatureAblation, Lime

In [None]:
#featurizer = init_featurizer("set")
# Small datasets for the attribution analysis: the first 10 training queries
# and the first 3 test queries, featurized with the same featurizer as the model.
# NOTE(review): the meaning of the positional `True` argument is not visible
# here -- check QueryDataset's signature.
ds = QueryDataset(trainqs[0:10], featurizer,
        True,
        load_padded_mscn_feats=True)
# batch_size=1 with shuffle=False keeps the dataset order when iterating.
loader = data.DataLoader(ds,
        batch_size=1, shuffle=False,
        collate_fn=mscn_collate_fn_together,
        )

testds = QueryDataset(testqs[0:3], featurizer,
        True,
        load_padded_mscn_feats=True)

testloader = data.DataLoader(testds,
        batch_size=1, shuffle=False,
        collate_fn=mscn_collate_fn_together,
        )

In [None]:
# Fresh iterator over the test loader; re-run this cell to start again from
# the first batch. Swap in the commented line to inspect training queries instead.
iloader = iter(testloader)
#iloader = iter(loader)

In [None]:
# Pull the next batch. Every execution of this cell advances the iterator, while
# later cells pair `xbatch` with testqs[0] -- so run it exactly once after
# (re)creating `iloader`.
# NOTE(review): `torch` is not imported explicitly in the visible cells;
# presumably it arrives through one of the star imports -- confirm.
xbatch,y,info = next(iloader)
# Totals of predicate-feature columns 20 and 21 over the whole batch.
print(torch.sum(xbatch["pred"][:,:,20]))
print(torch.sum(xbatch["pred"][:,:,21]))

In [None]:
# Display the raw featurized batch for inspection.
xbatch

In [None]:
import sqlparse
# Pretty-print the SQL of the first test query. This matches `xbatch` only when
# `xbatch` is the first batch drawn from the test loader.
#sql = trainqs[0]["sql"]
sql = testqs[0]["sql"]
print(sqlparse.format(sql, reindent=True, keyword_case='upper'))

In [None]:
# Entry for the column one-hot block (unpacked as (start, length) by the
# attribution helpers below), followed by the column -> one-hot index mapping.
print(featurizer.featurizer_type_idxs["col_onehot"])
print(featurizer.columns_onehot_idx)
# NOTE(review): the notes below appear to be leftovers from an earlier run on a
# different workload (column names and a list of feature indices) -- confirm
# whether they are still relevant.
# mii.info
# n.name
# n.name_pcode_cf
# n.surname_pcode
#[0  1  6  8  9 11 17 18 20 21 34 37 38 39 40 41 42 43 44 45]

In [None]:
# Convert the predicate features to numpy and sum over the first two axes,
# leaving one total per predicate feature (last axis).
curx = xbatch["pred"].detach().numpy()
print(curx.shape)
xsum = curx.sum(axis=0).sum(axis=0)

print(xsum.shape)
print(xsum)

In [None]:
def get_attr_vecs(curx, curattrs):
    """Collapse per-element attributions into one score per active feature.

    Parameters
    ----------
    curx : numpy.ndarray
        3-D feature array; features are indexed along the last axis.
    curattrs : numpy.ndarray
        Attributions with the same shape as ``curx``.

    Returns
    -------
    idxs : numpy.ndarray
        Last-axis indices of the features whose values in ``curx`` do not sum
        to zero (the features considered "in use").
    attr_sum : numpy.ndarray
        For each kept feature: the sum of absolute attributions divided by the
        number of non-zero entries of that feature in ``curx``.
    """
    xsum = curx.sum(axis=0).sum(axis=0)
    active = xsum != 0
    idxs = np.arange(len(xsum))[active]

    curx = curx[:, :, active]
    curattrs = curattrs[:, :, active]
    assert curx.shape == curattrs.shape

    # Only the magnitude of an attribution is kept; its sign is discarded.
    attr_sum = np.abs(curattrs).sum(axis=0).sum(axis=0)
    assert attr_sum.shape[0] == curx.shape[-1]

    # A feature can be set in several elements, so report the average over the
    # entries where it is actually non-zero.
    # TODO: decide whether an additional L1 normalisation is wanted here.
    nonzero_counts = (curx != 0).sum(axis=0).sum(axis=0)
    return idxs, attr_sum / nonzero_counts


def _first_key_with_value(mapping, value, default):
    """Return the first key of ``mapping`` whose value equals ``value``, else ``default``."""
    for key, val in mapping.items():
        if val == value:
            return key
    return default


def get_mscn_attrs(xbatch, attrs, featurizer, normalize=True):
    """Turn raw attributions for an MSCN batch into (labels, values) for plotting.

    Parameters
    ----------
    xbatch : dict
        Batch with tensors under the keys ``"table"``, ``"pred"`` and ``"join"``.
    attrs : sequence
        Attribution tensors; items 0, 1 and 2 belong to the table, predicate
        and join inputs respectively.
    featurizer
        Object providing ``table_featurizer``, ``join_featurizer``,
        ``columns_onehot_idx`` and ``featurizer_type_idxs``.
    normalize : bool, optional
        If True, scale the values by their L1 norm. All-zero values are left
        untouched instead of being turned into NaN.

    Returns
    -------
    xlabels : list
        One label per value, ordered tables, joins, predicates. Indices with
        no known name get ``"table-unknown"``, ``"unknown"`` (joins),
        ``"col-unknown"`` or ``"pred-unknown"``, so labels and values always
        line up.
    attrs : numpy.ndarray
        Attribution values in the same order as ``xlabels``.
    """
    batchsize = xbatch["table"].shape[0]
    assert batchsize == attrs[0].shape[0]

    tabidxs, tabattrs = get_attr_vecs(xbatch["table"].detach().numpy(),
                                      attrs[0].detach().numpy())
    predidxs, predattrs = get_attr_vecs(xbatch["pred"].detach().numpy(),
                                        attrs[1].detach().numpy())
    joinidxs, joinattrs = get_attr_vecs(xbatch["join"].detach().numpy(),
                                        attrs[2].detach().numpy())

    # TODO: need to do sample_bitmaps
    tablabels = [
        _first_key_with_value(featurizer.table_featurizer, tidx, "table-unknown")
        for tidx in tabidxs
    ]
    # TODO: join-stats
    joinlabels = [
        _first_key_with_value(featurizer.join_featurizer, jidx, "unknown")
        for jidx in joinidxs
    ]

    # Each entry is unpacked as (start, length) of one block of the predicate
    # feature vector.
    # TODO: if stats used, also handle "col_stats"
    type_idxs = featurizer.featurizer_type_idxs
    colstart, collen = type_idxs["col_onehot"]
    cmp_start, cmplen = type_idxs["cmp_op"]
    cstart, clen = type_idxs["constant_continuous"]
    lstart, llen = type_idxs["constant_like"]
    dstart, dlen = type_idxs["constant_discrete"]
    hstart, hlen = type_idxs["heuristic_ests"]

    predlabels = []
    for pi in predidxs:
        if colstart <= pi < colstart + collen:
            label = _first_key_with_value(featurizer.columns_onehot_idx,
                                          pi - colstart, "col-unknown")
        elif cmp_start <= pi < cmp_start + cmplen:
            label = "cmp"
        elif pi == cstart:
            label = "range-filter1"
        elif pi == cstart + 1:
            label = "range-filter2"
        elif lstart <= pi < lstart + llen:
            label = "Like-Hashes"
        elif dstart <= pi < dstart + dlen:
            label = "Constant-Hashes"
        elif pi == hstart:
            label = "PostgreSQL Est (Table)"
        elif pi == hstart + 1:
            label = "PostgreSQL Est (Subplan)"
        else:
            # Index outside every known block: keep labels aligned with values.
            label = "pred-unknown"
        predlabels.append(label)

    assert len(predidxs) == len(predlabels)

    attrs = np.concatenate([tabattrs, joinattrs, predattrs])
    xlabels = tablabels + joinlabels + predlabels

    if normalize:
        total = np.linalg.norm(attrs, ord=1)
        if total != 0:
            attrs = attrs / total

    return xlabels, attrs

In [None]:
from cardinality_estimation.nets import *

def get_model_attrs(mscn_model, weights, xbatch):
    """Compute Integrated Gradients attributions for one MSCN batch.

    A fresh ``SetConvNoFlow`` network is built with the same input widths as
    ``mscn_model`` and loaded with ``weights``; attributions are computed on
    that copy.

    NOTE: the hidden size and the featurizer are taken from the notebook
    globals ``hidden_layer_size`` and ``featurizer``.

    Parameters
    ----------
    mscn_model
        Network whose ``sample_mlp1``, ``predicate_mlp1`` and ``join_mlp1``
        layers define the input feature widths.
    weights
        State dict loaded into the freshly built network.
    xbatch : dict
        Batch with the tensors ``"table"``, ``"pred"``, ``"join"``,
        ``"tmask"``, ``"pmask"`` and ``"jmask"``.

    Returns
    -------
    tuple
        ``(xlabels, igattrs)`` as returned by ``get_mscn_attrs`` with
        ``normalize=False``.
    """
    attr_net = SetConvNoFlow(
        mscn_model.sample_mlp1.in_features,
        mscn_model.predicate_mlp1.in_features,
        mscn_model.join_mlp1.in_features,
        hidden_layer_size,
        n_out=1,
        dropouts=[0.0, 0.0, 0.0])
    attr_net.load_state_dict(weights)
    attr_net.eval()

    # Other captum explainers (NoiseTunnel, DeepLift, GradientShap,
    # FeatureAblation, Lime) were tried here; only Integrated Gradients is used.
    explainer = IntegratedGradients(attr_net)
    model_inputs = (
        xbatch["table"],
        xbatch["pred"],
        xbatch["join"],
        xbatch["tmask"],
        xbatch["pmask"],
        xbatch["jmask"],
    )
    ig_attr_test = explainer.attribute(model_inputs, n_steps=50)

    return get_mscn_attrs(xbatch, ig_attr_test, featurizer, normalize=False)

In [None]:
# Attributions for the whole batch at once, computed on a copy of the trained
# network that is loaded with its current weights.
weights = mscn.net.state_dict()
xlabels,attrs = get_model_attrs(mscn.net, weights, xbatch)

In [None]:
# Display the (unnormalised) attribution values computed above.
attrs

In [None]:
# Display the trained network object.
mscn.net

In [None]:
import seaborn as sns

def plot_attrs(xlabels, attrs, ax=None):
    """Draw a horizontal bar chart with one bar per attribution label.

    Parameters
    ----------
    xlabels : sequence
        Labels placed on the y axis.
    attrs : sequence
        Attribution values plotted along the x axis.
    ax : matplotlib axes, optional
        Axes to draw on. When omitted, a new 20x20 figure and axes are created.

    Returns
    -------
    None
        The plot is drawn on ``ax``; nothing is returned.
    """
    if ax is None:
        plt.figure(figsize=(20, 20))
        ax = plt.axes()

    #plt.yticks(fontsize=20)
    # NOTE(review): `ci` and the long-form orient="horizontal" may not be
    # accepted by newer seaborn releases -- confirm against the installed version.
    sns.barplot(x=attrs, y=xlabels, color='#4260f5', orient="horizontal", ax=ax, ci="sd")

In [None]:
# import pickle

# if onehot_dropout:
#     with open("dropout-attrs.pkl", "wb") as f:
#         pickle.dump([xlabels, attrs], f)
# else:
#     with open("default-attrs.pkl", "wb") as f:
#         pickle.dump([xlabels, attrs], f)

In [None]:
# Display the SQL text of the query being explained.
sql

In [None]:
# Plot the batch-level attributions and save the figure.
# NOTE: the file name is fixed; unlike the pickle files written at the end of
# the notebook it does not depend on `onehot_dropout`.
plot_attrs(xlabels, attrs)
plt.yticks(fontsize=20)
#plt.show()
plt.savefig("attribution-dropout.png")

In [None]:
# Display the mapping from feature block name to its entry (unpacked as
# (start, length) by the attribution helpers).
featurizer.featurizer_type_idxs

In [None]:
# Per-feature totals over the first two axes of each input tensor: shows which
# table / predicate / join features are set anywhere in the batch.
print(xbatch["table"].sum(axis=[0,1]))
print(xbatch["pred"].sum(axis=[0,1]))
print(xbatch["join"].sum(axis=[0,1]))

# Shapes of the three input tensors.
print(xbatch["table"].shape)
print(xbatch["pred"].shape)
print(xbatch["join"].shape)

In [None]:
def get_attr_vecs_single(curx, curattrs):
    """Collapse the attributions of one batch element into per-feature scores.

    Parameters
    ----------
    curx : numpy.ndarray
        3-D feature array; features are indexed along the last axis.
    curattrs : numpy.ndarray
        Attributions with the same shape as ``curx``.

    Returns
    -------
    idxs : numpy.ndarray
        Last-axis indices of the features whose values in ``curx`` do not sum
        to zero.
    attr_sum : numpy.ndarray
        Sum of absolute attributions per kept feature, divided by how many
        entries of that feature are non-zero in ``curx``.
    """
    feature_totals = curx.sum(axis=0).sum(axis=0)
    in_use = feature_totals != 0
    idxs = np.arange(len(feature_totals))[in_use]

    curx = curx[:, :, in_use]
    curattrs = curattrs[:, :, in_use]
    assert curx.shape == curattrs.shape

    # TODO: abs values or also accept negative correlations?
    attr_sum = np.abs(curattrs).sum(axis=0).sum(axis=0)
    assert attr_sum.shape[0] == curx.shape[-1]

    # Do this because different features have different number of copies in
    # the same set, e.g, subplan features are in every vector
    occurrences = (curx != 0).sum(axis=0).sum(axis=0)
    return idxs, attr_sum / occurrences


def _key_for_index_single(mapping, idx, default):
    """Return the first key of ``mapping`` mapped to ``idx``, or ``default`` if none is."""
    return next((key for key, val in mapping.items() if val == idx), default)


def get_mscn_attrs_single(xbatch, xi, attrs, featurizer, normalize=True):
    """Label the attributions of a single element of an MSCN batch.

    Parameters
    ----------
    xbatch : dict
        Batch with tensors under the keys ``"table"``, ``"pred"`` and ``"join"``.
    xi : int
        Index of the batch element that ``attrs`` was computed for.
    attrs : sequence
        Attribution tensors for element ``xi`` only; items 0, 1 and 2 belong
        to the table, predicate and join inputs respectively.
    featurizer
        Object providing ``table_featurizer``, ``join_featurizer``,
        ``columns_onehot_idx`` and ``featurizer_type_idxs``.
    normalize : bool, optional
        If True, divide the values by their sum. All-zero values are left
        untouched instead of being turned into NaN.

    Returns
    -------
    xlabels : list
        One label per value, ordered tables, joins, predicates. Indices with
        no known name get ``"table-unknown"``, ``"unknown"`` (joins),
        ``"col-unknown"`` or ``"pred-unknown"``, so labels and values always
        line up.
    attrs : numpy.ndarray
        Attribution values in the same order as ``xlabels``.
    """
    def element(key):
        # Slice (rather than index) to keep the leading batch axis.
        return xbatch[key][xi:xi+1].detach().numpy()

    tabidxs, tabattrs = get_attr_vecs_single(element("table"),
                                             attrs[0].detach().numpy())
    predidxs, predattrs = get_attr_vecs_single(element("pred"),
                                               attrs[1].detach().numpy())
    joinidxs, joinattrs = get_attr_vecs_single(element("join"),
                                               attrs[2].detach().numpy())

    # TODO: need to do sample_bitmaps
    tablabels = [
        _key_for_index_single(featurizer.table_featurizer, tidx, "table-unknown")
        for tidx in tabidxs
    ]
    # TODO: join-stats
    joinlabels = [
        _key_for_index_single(featurizer.join_featurizer, jidx, "unknown")
        for jidx in joinidxs
    ]

    # Each entry is unpacked as (start, length) of one block of the predicate
    # feature vector.
    # TODO: if stats used, also handle "col_stats"
    blocks = featurizer.featurizer_type_idxs
    colstart, collen = blocks["col_onehot"]
    cmp_start, cmplen = blocks["cmp_op"]
    cstart, clen = blocks["constant_continuous"]
    lstart, llen = blocks["constant_like"]
    dstart, dlen = blocks["constant_discrete"]
    hstart, hlen = blocks["heuristic_ests"]

    predlabels = []
    for pi in predidxs:
        if colstart <= pi < colstart + collen:
            label = _key_for_index_single(featurizer.columns_onehot_idx,
                                          pi - colstart, "col-unknown")
        elif cmp_start <= pi < cmp_start + cmplen:
            label = "cmp"
        elif pi == cstart or pi == cstart + 1:
            # Both continuous-constant slots share one label here.
            label = "range-filter"
        elif lstart <= pi < lstart + llen:
            label = "Like-Hashes"
        elif dstart <= pi < dstart + dlen:
            label = "Constant-Hashes"
        elif pi == hstart:
            label = "PostgreSQL Est (Table)"
        elif pi == hstart + 1:
            label = "PostgreSQL Est (Subplan)"
        else:
            # Index outside every known block: keep labels aligned with values.
            label = "pred-unknown"
        predlabels.append(label)

    assert len(predidxs) == len(predlabels)

    attrs = np.concatenate([tabattrs, joinattrs, predattrs])
    xlabels = tablabels + joinlabels + predlabels

    if normalize:
        total = np.sum(attrs)
        if total != 0:
            attrs = attrs / total

    return xlabels, attrs

In [None]:
# Rebuild a stand-alone copy of the trained network for the per-subplan
# attribution loop below (same construction as inside get_model_attrs).
# Requires `weights` from the earlier cell that calls mscn.net.state_dict().
# NOTE: `n_out` is assigned here, but the constructor call passes the literal 1.
n_out = 1
sfeats = mscn.net.sample_mlp1.in_features
pfeats = mscn.net.predicate_mlp1.in_features
jfeats = mscn.net.join_mlp1.in_features
    
net = SetConvNoFlow(sfeats,
    pfeats, jfeats,
    hidden_layer_size,
    n_out=1,
    dropouts=[0.0, 0.0, 0.0])
net.load_state_dict(weights)

In [None]:
# Per-subplan attributions for the first test query.
# Assumes `xbatch` is the batch built from testqs[0] (the first batch of the
# test loader); the assert below only checks that the number of batch elements
# equals the number of subset-graph nodes.
qrep = testqs[0]
subsetg = qrep["subset_graph"]
node_list = list(subsetg.nodes())
# Sorted so that position xi in the batch can be paired with node_list[xi].
# NOTE(review): this presumes the dataset emits subplans in the same sorted
# order -- confirm against QueryDataset.
node_list.sort()

# NOTE: these three values repeat the previous cell and are not used below.
sfeats = mscn.net.sample_mlp1.in_features
pfeats = mscn.net.predicate_mlp1.in_features
jfeats = mscn.net.join_mlp1.in_features


model = net
model.eval()
ig = IntegratedGradients(model)

assert xbatch["table"].shape[0] == len(node_list)

allsqls = []
allxlabels = []
alligattrs = []

# NOTE: the loop reuses the global names `xlabels` and `ig_attr_test`, so the
# batch-level values computed earlier are overwritten once this cell has run.
for xi in range(xbatch["table"].shape[0]):
    # SQL text of the subplan covering the tables in node_list[xi].
    # NOTE(review): nx_graph_to_query is not imported explicitly in the visible
    # cells; presumably it arrives through one of the star imports -- confirm.
    subjg = qrep["join_graph"].subgraph(node_list[xi])
    subsql = nx_graph_to_query(subjg)
    #print(subsql)
    
    # Integrated Gradients on this single element; xi:xi+1 keeps the batch axis.
    ig_attr_test = ig.attribute(tuple([xbatch["table"][xi:xi+1], xbatch["pred"][xi:xi+1],
                            xbatch["join"][xi:xi+1], 
                            xbatch["tmask"][xi:xi+1], xbatch["pmask"][xi:xi+1], 
                                   xbatch["jmask"][xi:xi+1]]), n_steps=50)

    #print("ig done")
    xlabels, igattrs = get_mscn_attrs_single(xbatch, xi, 
                                ig_attr_test, featurizer, 
                                             normalize=True)
    #print("Xlabels: ", xlabels)
    #print(igattrs)
    #break
    allsqls.append(subsql)
    allxlabels.append(xlabels)
    alligattrs.append(igattrs)

In [None]:
## saving attribute importance file; can be visualized further in the captum jupyter notebook
import pickle
# Choose the output file from the dropout setting, then write the SQL strings,
# labels and attribution values of every subplan in a single pickle.
if onehot_dropout:
    attrs_fn = "subplan-dropout-attrs.pkl"
else:
    print("saving default")
    attrs_fn = "subplan-default-attrs.pkl"

with open(attrs_fn, "wb") as f:
    pickle.dump([allsqls, allxlabels, alligattrs], f, protocol=3)