In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import pandas as pd
import sys
sys.path.append("..")

import networkx as nx
from networkx.readwrite import json_graph

from query_representation.utils import *

from torch.utils import data
import glob
import random
import os
import json
import time
import matplotlib.pyplot as plt
import pickle

# Setup file paths / Download query data

In [None]:
import errno
def make_dir(directory):
    """Create `directory` together with any missing parent directories.

    If the path already exists (EEXIST) nothing happens; every other OS error
    is propagated to the caller.
    """
    try:
        os.makedirs(directory)
    except FileExistsError:
        # OSError with errno EEXIST is always raised as FileExistsError.
        pass
            
def eval_alg(alg, eval_funcs, qreps, samples_type, result_dir="./results/"):
    '''Run each evaluation function on an algorithm's estimates and print summary stats.

    Args:
        alg: either an estimator exposing ``test(qreps)``, ``get_exp_name()``
            and ``__str__``, or a precomputed list of per-query estimates.
        eval_funcs: objects with ``eval(qreps, ests, **kwargs)`` that return a
            sequence of per-sample errors.
        qreps: list of loaded query dicts, or list of query-pickle paths
            (these are loaded with ``load_qdata``).
        samples_type: label such as "train" / "test"; forwarded and printed.
        result_dir: root output directory; results go to ``result_dir/<exp_name>``.
            Pass None to skip creating any directory.

    Raises:
        ValueError: if ``qreps`` is empty.
    '''
    if len(qreps) == 0:
        raise ValueError("eval_alg: qreps is empty")

    np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})
    print("start eval alg")

    # Load paths *before* calling alg.test(). Previously the estimator was
    # handed raw file paths, and only the eval functions saw loaded queries.
    if isinstance(qreps[0], str):
        qreps = load_qdata(qreps)

    if isinstance(alg, list):
        ests = alg
        alg_name = "Estimates"
        exp_name = "test"
    else:
        ests = alg.test(qreps)
        alg_name = alg.__str__()
        exp_name = alg.get_exp_name()

    # The output directory does not depend on the eval function, so create it once.
    rdir = None
    if result_dir is not None:
        rdir = os.path.join(result_dir, exp_name)
        make_dir(rdir)

    print("before eval funcs")
    for efunc in eval_funcs:
        # NOTE(review): DB credentials are hard-coded; consider reading them
        # from environment variables before sharing this notebook.
        errors = efunc.eval(qreps, ests, samples_type=samples_type,
                result_dir=rdir,
                num_processes=-1,
                alg_name=alg_name,
                use_wandb=0,
                user="ceb",
                db_name="imdb",
                db_host="localhost",
                password="password",
                port=5432
                )

        print("{}, {}, #samples: {}, {}: mean: {}, median: {}, 99p: {}"
                .format(samples_type, alg_name, len(errors),
                    efunc.__str__(),
                    np.round(np.mean(errors), 3),
                    np.round(np.median(errors), 3),
                    np.round(np.percentile(errors, 99), 3)))

In [None]:

#TESTDIR = os.path.join(os.path.join("..", "queries"), "imdb-unique-plans")
#RESULTDIR = os.path.join("..", "results")
#make_dir(RESULTDIR)

# True cardinalities >= MAXCARD are replaced by 0 when the subplan x query
# matrix `mat` is built; presumably a sentinel for invalid / timed-out counts
# (TODO confirm).
MAXCARD = 150_001_000_000.0

# Query loading helper functions

In [None]:
def load_qrep(fn):
    """Load one pickled query representation from `fn` and rebuild its graphs.

    The unpickled dict's "subset_graph", "join_graph" and (optionally)
    "subset_graph_paths" entries are networkx adjacency-data dicts; they are
    converted back into graph objects in place.

    Raises:
        AssertionError: if `fn` does not contain ".pkl".
        RuntimeError: if the file cannot be opened or unpickled. The original
            exception is chained as the cause.

    NOTE(review): pickle.load can execute arbitrary code stored in the file;
    only load query files from a trusted source.
    """
    assert ".pkl" in fn
    try:
        with open(fn, "rb") as f:
            query = pickle.load(f)
    except Exception as e:
        # Raise instead of a bare `except: exit(-1)`. The old form also caught
        # KeyboardInterrupt, and it hid the real traceback by trying to
        # terminate the interpreter.
        raise RuntimeError(fn + " failed to load...") from e

    # nx.OrderedDiGraph was removed in networkx 3.0. Plain DiGraph keeps
    # insertion order on Python 3.7+, so fall back to it when needed.
    ordered_digraph = getattr(nx, "OrderedDiGraph", nx.DiGraph)

    query["subset_graph"] = \
            ordered_digraph(json_graph.adjacency_graph(query["subset_graph"]))
    query["join_graph"] = json_graph.adjacency_graph(query["join_graph"])
    if "subset_graph_paths" in query:
        query["subset_graph_paths"] = \
                ordered_digraph(json_graph.adjacency_graph(query["subset_graph_paths"]))

    return query

def load_qdata(fns):
    """Load every query pickle in `fns` and tag it with name / template metadata.

    Each returned dict gets:
      "name"          -- the file's basename,
      "template_name" -- the basename of the file's parent directory,
      "workload"      -- an empty string.
    """
    loaded = []
    for path in fns:
        qrep = load_qrep(path)
        # TODO: can do checks like no queries with zero cardinalities etc.
        loaded.append(qrep)
        qrep.update(
            name=os.path.basename(path),
            template_name=os.path.basename(os.path.dirname(path)),
            workload="",
        )
    return loaded

def get_query_fns(basedir, template_fraction=1.0, sel_templates=None):
    """Collect query-pickle paths from the template sub-directories of `basedir`.

    Args:
        basedir: directory with one sub-directory per template (e.g. "1a"),
            each holding "*.pkl" query files. Plain files directly under
            `basedir` are skipped.
        template_fraction: fraction (<= 1.0) of each template's files to keep.
            At least one file is kept for every non-empty template.
        sel_templates: which templates to include.
            - None or "all" selects every template.
            - The token "no7" additionally drops template "7a".
            - A string may hold several tokens separated by ",", "-" or
              whitespace (e.g. "all-no7", "1a,2b").
            - Any other iterable is treated as an exact collection of names.

    Returns:
        List of file paths, grouped by template in sorted directory order.
        Within a template, files are sampled with a fresh RNG seeded with 1234.
    """
    assert template_fraction <= 1.0

    if sel_templates is None:
        sel_templates = "all"
    if isinstance(sel_templates, str):
        # Tokenize rather than substring-match. With the old
        # `name in sel_templates` check, "11a" also selected template "1a".
        selected = set(sel_templates.replace("-", " ").replace(",", " ").split())
        selected.add(sel_templates)  # still allow an exact hyphenated name
    else:
        selected = set(sel_templates)

    fns = []
    # sorted(): glob order is filesystem-dependent; sorting makes output reproducible.
    for qdir in sorted(glob.glob(os.path.join(basedir, "*"))):
        if os.path.isfile(qdir):
            continue
        template_name = os.path.basename(qdir)

        if "no7" in selected and template_name == "7a":
            continue
        if "all" not in selected and template_name not in selected:
            continue

        qfns = sorted(glob.glob(os.path.join(qdir, "*.pkl")))
        if not qfns:
            # random.sample(qfns, 1) would raise on an empty template directory.
            continue
        num_samples = max(int(len(qfns) * template_fraction), 1)
        # A private RNG seeded per template picks exactly what
        # random.seed(1234); random.sample(...) did, without resetting the
        # global `random` state.
        fns += random.Random(1234).sample(qfns, num_samples)
    return fns

# Evaluation helper functions

# Load queries

In [None]:
QDIR = "imdb-unique-plans"   # other query dirs used before: "job2"
TEMPLATES = "2b"             # training templates, e.g. "all", "all-no7", "2b"
TEST = "2a"                  # held-out test templates

# Root of the CEB checkout. Override it with the CEB_ROOT environment variable
# instead of editing this machine-specific absolute path.
CEB_ROOT = os.environ.get("CEB_ROOT", "/flash1/pari/MyCEB")
TRAINDIR = os.path.join(CEB_ROOT, "queries", QDIR)
RTDIRS = [os.path.join(CEB_ROOT, "runtime_plans", "pg")]
# For JOB instead: TRAINDIR = os.path.join(CEB_ROOT, "queries", "job")
#                  RTDIRS = [os.path.join(CEB_ROOT, "runtime_plans", "JOB")]

qfns = get_query_fns(TRAINDIR, template_fraction=1.0, sel_templates=TEMPLATES)
tqfns = get_query_fns(TRAINDIR, template_fraction=1.0, sel_templates=TEST)

# Train/test overlap silently leaks test queries into training; warn about it.
overlap = set(qfns) & set(tqfns)
if overlap:
    print("WARNING: {} query files are in both the train and test sets".format(len(overlap)))

print(len(qfns))
qdata = load_qdata(qfns)
tqdata = load_qdata(tqfns)

In [None]:
# Gather every <RTDIR>/<run>/Runtimes.csv into one DataFrame.
rtdfs = []

for RTDIR in RTDIRS:
    # sorted(): os.listdir order is arbitrary. load_plandata takes the *first*
    # row per query name, so the row order must be reproducible.
    for rd in sorted(os.listdir(RTDIR)):
        rtfn = os.path.join(RTDIR, rd, "Runtimes.csv")
        if os.path.exists(rtfn):
            rtdfs.append(pd.read_csv(rtfn))

if not rtdfs:
    # pd.concat([]) would otherwise fail with an unhelpful "No objects to concatenate".
    raise FileNotFoundError("No Runtimes.csv found under any of: {}".format(RTDIRS))

rtdf = pd.concat(rtdfs, ignore_index=True)
print("Num RTs: ", len(rtdf))

In [None]:
# Peek at the runtime table instead of rendering the whole frame.
rtdf.head()

In [None]:
from collections import defaultdict
import numpy

subplan_data = defaultdict(list)

# Rows: every distinct subplan (subset-graph node) across the training queries, sorted.
rowkeys = sorted({node for qrep in qdata for node in qrep["subset_graph"].nodes()})
rowidxs = {rk: ri for ri, rk in enumerate(rowkeys)}

# mat[row, col] = true cardinality of subplan `row` in query `col`.
# It is 0 where the subplan does not occur, or where its cardinality is >= MAXCARD.
mat = np.zeros((len(rowidxs), len(qdata)))

for qi, qrep in enumerate(qdata):
    sg_nodes = qrep["subset_graph"].nodes()
    for node in sg_nodes:
        if node not in rowidxs:
            continue
        truec = sg_nodes[node]["cardinality"]["actual"]
        if truec >= MAXCARD:
            truec = 0.0
        mat[rowidxs[node], qi] = truec

In [None]:
def load_plandata(qdata, rtdf):
    """Build true-cardinality matrices for the queries that have a recorded plan.

    Args:
        qdata: list of loaded query dicts (with "name" and "subset_graph").
        rtdf: runtime DataFrame with "qname" and "exp_analyze" columns; the
            first row per query name is used.

    Returns:
        A tuple ``(mat, planmat, rtdata, subplan_masks, rowidxs)``:
          mat: (num_subplans, num_kept_queries) true cardinalities of every subplan.
          planmat: same shape, but non-zero only for subplans whose alias list
              appears in the query's executed plan (as reported by explain_to_nx).
          rtdata: the kept queries, in column order.
          subplan_masks: per kept query, the list of alias lists seen in its plan.
          rowidxs: subplan node -> row index.

    Queries missing from `rtdf`, or whose "exp_analyze" text cannot be parsed,
    are dropped entirely. This keeps rtdata, subplan_masks and the matrix
    columns aligned one-to-one.
    """
    known_names = set(rtdf["qname"].values)

    # Pass 1: keep only queries that have a runtime row AND a parseable plan.
    rtdata = []
    subplan_masks = []
    for qrep in qdata:
        if qrep["name"] not in known_names:
            continue
        exp = rtdf.loc[rtdf["qname"] == qrep["name"], "exp_analyze"].values[0]
        try:
            # NOTE(review): eval() on text read from a CSV executes arbitrary
            # code; ast.literal_eval would be safer if plans are plain literals.
            exp = eval(exp)
        except Exception:
            # Previously such queries stayed in rtdata without a mask, which
            # shifted every later subplan_masks entry by one column.
            continue
        G = explain_to_nx(exp)
        rtdata.append(qrep)
        subplan_masks.append([ndata["aliases"] for _, ndata in G.nodes(data=True)])

    # Pass 2: one row per distinct subplan of the kept queries.
    rowkeys = sorted({node for qrep in rtdata for node in qrep["subset_graph"].nodes()})
    rowidxs = {rk: ri for ri, rk in enumerate(rowkeys)}

    mat = np.zeros((len(rowidxs), len(rtdata)))
    planmat = np.zeros((len(rowidxs), len(rtdata)))

    for qi, (qrep, seen_subplans) in enumerate(zip(rtdata, subplan_masks)):
        sg_nodes = qrep["subset_graph"].nodes()
        for node in sg_nodes:
            truec = sg_nodes[node]["cardinality"]["actual"]
            mat[rowidxs[node], qi] = truec
            # Same comparison as before: node vs. the plan's alias lists.
            if list(node) in seen_subplans:
                planmat[rowidxs[node], qi] = truec

    return mat, planmat, rtdata, subplan_masks, rowidxs

In [None]:
# Thin SVD of the subplan x query cardinality matrix:
# mat == P @ diag(S) @ Q up to floating-point error.
P, S, Q = np.linalg.svd(mat, full_matrices=False)
for shaped in (mat, S):
    print(shaped.shape)

In [None]:
np.round(S, 2)

In [None]:
# 90th percentile of the singular values and of both factor matrices.
print(*[np.percentile(factor, 90) for factor in (S, P, Q)])

In [None]:
max_card = np.max(mat)
print(np.percentile(S, 90) - max_card)
print(max_card)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Singular-value spectrum on a log scale. A fast decay suggests mat is close to low rank.
fig, ax = plt.subplots()
ax.plot(S)
ax.set_yscale("log")
ax.set_title("Singular values of the subplan x query cardinality matrix")
ax.set_xlabel("Index (descending order)")
ax.set_ylabel("Singular value (log scale)")
plt.show()

In [None]:
# First index at which the cumulative share of singular-value mass exceeds 90%.
cds = np.cumsum(S) / np.sum(S)
r90 = np.flatnonzero(cds > 0.90).min()
r90

In [None]:
def omega_approx(beta):
    """Return an approximate omega value for given beta. Equation (5) from Gavish 2014."""
    return 0.56 * beta**3 - 0.95 * beta**2 + 1.82 * beta + 1.43

def lambda_star(beta):
    """Return lambda*(beta), the optimal SVHT coefficient when the noise level is known.

    From Gavish & Donoho (2014): the threshold is lambda*(beta) * sqrt(n) * sigma.
    For square matrices lambda*(1) == 4 / sqrt(3).
    """
    return np.sqrt(2 * (beta + 1) + (8 * beta) / (beta + 1 + np.sqrt(beta**2 + 14 * beta + 1)))

def svht(X, sigma=None, sv=None):
    """Return the optimal singular value hard threshold (SVHT) value.

    `X` is any m-by-n matrix. `sigma` is the standard deviation of the
    noise, if known. Optionally supply the vector of singular values `sv`
    for the matrix (only necessary when `sigma` is unknown). If `sigma`
    is unknown and `sv` is not supplied, the singular values are computed
    from `X`.

    Raises:
        ValueError: if `X` is not a non-empty 2-D matrix, or `sv` is not 1-D.
    """
    try:
        m, n = sorted(X.shape)  # ensures m <= n
    except (AttributeError, TypeError, ValueError) as e:
        raise ValueError('invalid input matrix') from e
    if n == 0:
        raise ValueError('invalid input matrix')
    beta = m / n  # ratio between 0 and 1
    if sigma is None:  # sigma unknown
        if sv is None:
            # Same values as scipy.linalg.svdvals, without depending on scipy.
            sv = np.linalg.svd(X, compute_uv=False)
        sv = np.squeeze(sv)
        if sv.ndim != 1:
            raise ValueError('vector of singular values must be 1-dimensional')
        return np.median(sv) * omega_approx(beta)
    # sigma known: previously called an undefined `lambda_star` (NameError).
    return lambda_star(beta) * np.sqrt(n) * sigma

# find tau star hat when sigma is unknown
# tau = svht(D, sv=sv)

# # find tau star when sigma is known
# tau = svht(D, sigma=0.5)

In [None]:
# Gavish-Donoho hard threshold (noise level unknown), and how many singular
# values survive it. The old mid-cell bare `tau` displayed nothing.
tau = svht(mat, sv=S)
rank = np.count_nonzero(S > tau)
print("SVHT threshold tau:", tau)
rank

In [None]:
np.shape(mat)

In [None]:
# Full-rank reconstruction: B should reproduce `mat` up to floating-point error.
B = np.dot(P, np.dot(np.diag(S), Q))

print("Sum Recon: ", np.sum(B), "Sum Orig: ", np.sum(mat), "Diff: ", np.sum(B)-np.sum(mat))

# Identical summary statistics for both matrices. The Recon line used to print
# mat's 99.9th percentile instead of B's.
for label, arr in (("Orig", mat), ("Recon", B)):
    print(label + " Stats, Min: ", np.min(arr), "Max: ", np.max(arr),
          "50p: ", np.percentile(arr, 50), "90p:", np.percentile(arr, 90),
          "99p: ", np.percentile(arr, 99), "999p: ", np.percentile(arr, 99.9))

In [None]:
print(*(factor.shape for factor in (P, Q)))

In [None]:
# Rank-RANK truncated-SVD reconstruction of `mat`.
RANK = 25
B = np.dot(P[:, 0:RANK], np.dot(np.diag(S[0:RANK]), Q[0:RANK, :]))
print("Sum Recon: ", np.sum(B), "Sum Orig: ", np.sum(mat), "Diff: ", np.sum(B)-np.sum(mat))

# Identical summary statistics for both matrices. The Recon line used to print
# mat's 99.9th percentile instead of B's.
for label, arr in (("Orig", mat), ("Recon", B)):
    print(label + " Stats, Min: ", np.min(arr), "Max: ", np.max(arr),
          "50p: ", np.percentile(arr, 50), "90p:", np.percentile(arr, 90),
          "99p: ", np.percentile(arr, 99), "999p: ", np.percentile(arr, 99.9))

In [None]:
def matrix_to_ests(estmat, rowidxs, qdata):
    """Turn a (subplan x query) estimate matrix into per-query estimate dicts.

    Column `col` of `estmat` belongs to `qdata[col]`. Every subset-graph node of
    that query is looked up through `rowidxs`, and its estimate is floored at 1.0.
    """
    return [
        {node: max(estmat[rowidxs[node], col], 1.0)
         for node in query["subset_graph"].nodes()}
        for col, query in enumerate(qdata)
    ]

In [None]:
#matests = matrix_to_ests(B, rowidxs, qdata)

In [None]:
#eval_alg(matests, EVAL_FNS, qdata, "train")

In [None]:
from evaluation.eval_fns import QError, SimplePlanCost, PostgresPlanCost, AbsError

# EVAL_FNS = []
# #EVAL_FNS.append(SimplePlanCost())
# EVAL_FNS.append(QError())
# EVAL_FNS.append(PostgresPlanCost(cost_model="C"))

# Reference error values keyed by QDIR + TEMPLATES. get_rank_effects draws them
# as horizontal baselines: red for the PG_* entries, green for TRUE_PERRS.
# Presumably they are the errors of the Postgres estimates / true cardinalities
# (TODO confirm).
PG_PERRS = {
    "job2all": 471442.0,
    "imdb-unique-plans1a": 6208987,
    "imdb-unique-plansall-no7": 17062409,
}

PG_QERRS = {
    "job2all": 4974.446,
    "imdb-unique-plans1a": 190.0,
    "imdb-unique-plansall-no7": 89941,
}

TRUE_PERRS = {
    "job2all": 159146.0,
    "imdb-unique-plans1a": 1200973,
    "imdb-unique-plansall-no7": 8022977,
}

PG_ABSERRS = {
    "job2all": 6785997.991,
    "imdb-unique-plans1a": 1282985,
    "imdb-unique-plansall-no7": 41657676,
}


def _plot_rank_panel(ax, ranks, values, ylabel, baselines):
    """Draw one log-scale error-vs-rank panel with optional horizontal reference lines.

    `baselines` is a list of (y_value, color) pairs; pairs whose y_value is None are skipped.
    """
    sns.lineplot(x=ranks, y=values, ax=ax)
    for yval, color in baselines:
        if yval is not None:
            ax.hlines(y=yval, xmin=1, xmax=ranks[-1], colors=color, linestyles='-', lw=4)
    ax.set_ylabel(ylabel, fontsize=16)
    ax.set_xlabel("Rank", fontsize=16)
    ax.tick_params(axis='both', which='both', labelsize=16)
    ax.set_yscale("log")


def get_rank_effects(mat, rowidxs, qdata, max_rank=100, step=2):
    """Sweep truncated-SVD ranks of `mat` and plot the resulting estimation errors.

    For each rank r:
      1. the rank-r reconstruction of `mat` is converted into per-query
         estimates with matrix_to_ests;
      2. the estimates are scored with QError, AbsError and PostgresPlanCost
         (cost_model "C").

    The three mean-error curves are plotted against rank and saved to
    "SVD-Recon-Errors-<QDIR>-<TEMPLATES>.pdf". Red/green reference lines come
    from PG_QERRS / PG_ABSERRS / PG_PERRS / TRUE_PERRS and are drawn only when
    an entry for QDIR + TEMPLATES exists.

    Args:
        mat: (num_subplans, num_queries) matrix whose columns align with `qdata`.
        rowidxs: subplan node -> row index of `mat`.
        qdata: list of query dicts, one per column of `mat`.
        max_rank, step: the ranks tried are max(r, 1) for r in
            range(0, max_rank, step), capped at the number of singular values
            and de-duplicated. The defaults reproduce the original sweep 1, 2, 4, ..., 98.

    Returns:
        (ranks, mean_qerrs, abs_errs, perrs): lists aligned by index.

    Raises:
        ValueError: if there is no rank to evaluate (empty matrix or max_rank <= 0).
    """
    # Decompose the matrix we were given. The old code silently used the
    # notebook-global P, S, Q and ignored `mat`.
    U, sv, Vt = np.linalg.svd(mat, full_matrices=False)

    rank_grid = []
    for r in range(0, max_rank, step):
        r = min(max(r, 1), len(sv))
        if r >= 1 and r not in rank_grid:
            rank_grid.append(r)
    if not rank_grid:
        raise ValueError("get_rank_effects: no ranks to evaluate (empty matrix or max_rank <= 0)")

    eval_kwargs = dict(samples_type="train", result_dir=None, num_processes=-1,
                       alg_name="SVD", use_wandb=0)
    # NOTE(review): hard-coded DB credentials; prefer environment variables.
    db_kwargs = dict(user="ceb", db_name="imdb", db_host="localhost",
                     password="password", port=5432)

    ranks, mean_qerrs, abs_errs, perrs = [], [], [], []
    for rank in rank_grid:
        B = np.dot(U[:, 0:rank], np.dot(np.diag(sv[0:rank]), Vt[0:rank, :]))
        matests = matrix_to_ests(B, rowidxs, qdata)

        ranks.append(rank)
        mean_qerrs.append(np.mean(QError().eval(qdata, matests, **eval_kwargs)))
        abs_errs.append(np.mean(AbsError().eval(qdata, matests, **eval_kwargs)))
        plan_costs = PostgresPlanCost(cost_model="C").eval(qdata, matests,
                                                           **eval_kwargs, **db_kwargs)
        perrs.append(np.mean(plan_costs))

        print(ranks)
        print(mean_qerrs)

    # .get(): a missing baseline used to raise KeyError only after every
    # expensive evaluation had already finished.
    key = QDIR + TEMPLATES
    fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(30, 6))
    _plot_rank_panel(axs[0], ranks, mean_qerrs, "QError",
                     [(PG_QERRS.get(key), 'r')])
    _plot_rank_panel(axs[1], ranks, abs_errs, "Absolute Errors",
                     [(PG_ABSERRS.get(key), 'r')])
    _plot_rank_panel(axs[2], ranks, perrs, "Plan Costs",
                     [(PG_PERRS.get(key), 'r'), (TRUE_PERRS.get(key), 'g')])

    FN = "SVD-Recon-Errors-{}-{}.pdf".format(QDIR, TEMPLATES)
    fig.suptitle("{}-{}".format(QDIR, TEMPLATES), fontsize=20)
    print(FN)
    fig.savefig(FN, bbox_inches="tight")
    plt.show()

    return ranks, mean_qerrs, abs_errs, perrs

In [None]:
#ranks, qerrs, abserrs, perrs = get_rank_effects(mat, rowidxs, qdata)

In [None]:
#ranks0, qerrs0, abserrs0, perrs0 = ranks, qerrs, abserrs, perrs

In [None]:
# mean_qerrs = qerrs0 + qerrs
# abs_errs = abserrs0 + abserrs
# perrs = perrs0 + perrs
# ranks = ranks0 + ranks

# fig,axs = plt.subplots(nrows=1,ncols=3,figsize=(30,6))
# ax = axs[0]
# sns.lineplot(x=ranks, y = mean_qerrs, ax = ax)

# pgqerr = PG_QERRS[QDIR+TEMPLATES]
# ax.hlines(y=pgqerr, xmin=1, xmax=ranks[-1], colors='r', linestyles='-', lw=4)

# ax.set_ylabel("QError", fontsize=16)
# ax.set_xlabel("Rank", fontsize=16)
# ax.tick_params(axis='both', which='both', labelsize=16)

# ax.set_yscale("log")

# ax = axs[1]
# sns.lineplot(x=ranks, y = abs_errs, ax = ax)

# ax.set_ylabel("Absolute Errors", fontsize=16)
# ax.set_xlabel("Rank", fontsize=16)
# ax.tick_params(axis='both', which='both', labelsize=16)

# pgaerr = PG_ABSERRS[QDIR+TEMPLATES]
# ax.hlines(y=pgaerr, xmin=1, xmax=ranks[-1], colors='r', linestyles='-', lw=4)

# ax.set_yscale("log")

# ax = axs[2]
# sns.lineplot(x=ranks, y = perrs, ax = ax)

# ax.set_yscale("log")
# ax.set_ylabel("Plan Costs", fontsize=16)
# ax.set_xlabel("Rank", fontsize=16)
# ax.tick_params(axis='both', which='both', labelsize=16)

# pgperr = PG_PERRS[QDIR+TEMPLATES]
# ax.hlines(y=pgperr, xmin=1, xmax=ranks[-1], colors='r', linestyles='-', lw=4)

# trueperr = TRUE_PERRS[QDIR+TEMPLATES]
# ax.hlines(y=trueperr, xmin=1, xmax=ranks[-1], colors='g', linestyles='-', lw=4)

# FN_TMP = "SVD-Recon-Errors-{}-{}.pdf"

# FN = FN_TMP.format(QDIR, TEMPLATES)

# fig.suptitle("{}-{}".format(QDIR, TEMPLATES), fontsize=20)
# print(FN)
# plt.savefig(FN, bbox_inches="tight")
# plt.show()

In [None]:
def qerr(mat, newmat):
    """Print Q-error percentiles between two equally-shaped arrays and return the mean.

    Both arrays are clamped below at 1 before taking the element-wise ratio
    max(mat / newmat, newmat / mat). Entries that are <= 1 on both sides
    (including zeros) therefore contribute a Q-error of exactly 1. The mean
    is taken over *all* entries.
    """
    mat = np.maximum(mat, 1)
    newmat = np.maximum(newmat, 1)
    qerrs = np.maximum(mat / newmat, newmat / mat)
    print("QError --> Mean: {}, 50p: {}. 90p: {}, 99p: {}".format(
          np.mean(qerrs), np.percentile(qerrs, 50), np.percentile(qerrs, 90),
          np.percentile(qerrs, 99))
         )
    return np.mean(qerrs)

def qerr_known(mat, newmat):
    """Mean Q-error over the entries where the ground truth `mat` is non-zero.

    Zero entries of `mat` are treated as "not evaluated" and excluded from the
    average. Both arrays are clamped below at 1 before taking the ratio
    max(truth / est, est / truth).

    Returns:
        The mean Q-error, or NaN when `mat` has no non-zero entries.
    """
    mat = np.asarray(mat, dtype=float)
    newmat = np.asarray(newmat, dtype=float)
    known = mat != 0.0
    if not np.any(known):
        return float("nan")
    truth = np.maximum(mat[known], 1)
    est = np.maximum(newmat[known], 1)
    # Average only over the known entries. The old code summed Q-errors over
    # *all* entries, so every excluded zero still added 1 to the numerator.
    return np.mean(np.maximum(truth / est, est / truth))

# Creating a new matrix with some values as zeros

In [None]:
import copy

def create_incomplete_matrix(mat, mfrac, rng=None):
    """Return a copy of 2-D `mat` with a fraction of each column set to 0.

    In every column, int(n_rows * mfrac) distinct row positions are chosen
    uniformly at random and zeroed. `mat` itself is left unchanged.

    Args:
        mat: 2-D array.
        mfrac: fraction (0..1) of each column to blank out.
        rng: optional numpy Generator / RandomState for reproducible masks.
            When None, the global np.random state is used (the original behavior).
    """
    chooser = np.random if rng is None else rng
    nrows, ncols = mat.shape[0], mat.shape[1]
    nmask = int(nrows * mfrac)

    newmat = np.zeros(mat.shape)
    for col in range(ncols):
        curcol = mat[:, col].copy()  # copy so `mat` is never modified
        indices = chooser.choice(np.arange(nrows), replace=False, size=nmask)
        curcol[indices] = 0.0
        newmat[:, col] = curcol

    return newmat

In [None]:
# newmat = create_incomplete_matrix(mat, mfrac=0.1)
# print(np.sum(mat), np.sum(newmat))
# print("MSE: ", ((mat - newmat)**2).mean(axis=None))
# print("QError: ", qerr(mat, newmat))

In [None]:
# #from fancyimpute import NuclearNormMinimization
# from fancyimpute import KNN, NuclearNormMinimization, SoftImpute, BiScaler
# #new_mat
# newmat[newmat == 0] = np.nan

# solver = KNN(
#     min_value=1.0
#     )

# # X_incomplete has missing data which is represented with NaN values
# mat_filled = solver.fit_transform(newmat)

In [None]:
# qerr(mat, mat_filled)

In [None]:
import copy

def zero_percentage(newmat):
    """Return the fraction of entries of `newmat` that are 0 or NaN (i.e. missing).

    The old sentinel-based version also counted entries that were genuinely
    -1.0 as missing. This version checks NaN / zero directly.
    """
    arr = np.asarray(newmat, dtype=float)
    return np.mean(np.isnan(arr) | (arr == 0.0))

In [None]:
# NOTE: rebinds `qdata` and `rowidxs` to the subset of queries that have runtime plans.
plan_results = load_plandata(qdata, rtdf)
full_mat, plan_mat, qdata, subplan_masks, rowidxs = plan_results

print("{} {}".format(full_mat.shape, plan_mat.shape))
print(zero_percentage(plan_mat))

In [None]:
# Keep only subplan rows that appear in at least one executed plan.
# (Despite its name, zero_idxs marks the rows that are NOT all-zero.)
zero_idxs = np.any(plan_mat != 0, axis=1)
plan_mat, full_mat = plan_mat[zero_idxs], full_mat[zero_idxs]

# fmask: 1 where the true cardinality is non-zero.
# pmask: 1 where the plan did not reveal the cardinality (plan_mat == 0).
fmask = (full_mat != 0).astype(np.float32)
pmask = (plan_mat == 0).astype(np.float32)

print(full_mat.shape, plan_mat.shape)

In [None]:
qerr(mat=full_mat, newmat=plan_mat)

In [None]:
# Shift by one so the following log maps unknown (0) entries to 0 and stays finite.
plan_mat = plan_mat + 1

In [None]:
# Move plan cardinalities to log space. After the +1 shift in the previous cell,
# unknown (0) entries become log(1) == 0.
# NOTE(review): not idempotent -- re-running this cell applies the log again.
plan_mat = np.log(plan_mat)

In [None]:
plan_mat.min()

In [None]:
#from fancyimpute import NuclearNormMinimization
from fancyimpute import KNN, NuclearNormMinimization, SoftImpute, BiScaler
#new_mat
# Zero entries of plan_mat are unknown: mark them as NaN (how fancyimpute
# represents missing data) and let KNN fill them in.
tmp = np.where(plan_mat == 0, np.nan, plan_mat)

solver = KNN(min_value=1.0)
plan_filled = solver.fit_transform(tmp)

In [None]:
# KNN-imputed plan matrix, still in log(x + 1) space.
plan_filled

In [None]:
# Undo the earlier log(plan_mat + 1) transform. A plain exp() left every value
# off by +1, so expm1 (exp(x) - 1) is the exact inverse.
plan_filled = np.expm1(plan_filled)

In [None]:
# 1) imputed matrix vs. truth over all entries
qerr(full_mat, plan_filled)

# 2) restricted to entries whose true cardinality is non-zero
plan_filledm = plan_filled * fmask
qerr(full_mat, plan_filledm)

# 3) restricted further to the entries that had to be imputed (pmask == 1)
plan_filled2 = plan_filledm * pmask
full_mat2 = full_mat * pmask
qerr(full_mat2, plan_filled2)
print("QError Unknown: ", qerr_known(full_mat2, plan_filled2))

In [None]:
# tmp = copy.deepcopy(plan_mat)
# tmp[tmp == 0] = np.nan

# solver = NuclearNormMinimization(
#     min_value=1.0
#     )

# # X_incomplete has missing data which is represented with NaN values
# plan_filled = solver.fit_transform(tmp)

# Training MSCN model

In [None]:
#from cardinality_estimation.mscn import MSCN as MSCN2
from cardinality_estimation.mscn import MSCN as MSCN

from cardinality_estimation.featurizer import Featurizer
#from query_representation.query import load_qrep
from cardinality_estimation.dataset import *

# ---- optimisation ----
max_epochs = 500
lr = 1e-4
training_opt = "none"
opt_lr = 0.1
swa_start = 5
mb_size = 1024
weight_decay = 0.0
optimizer_name = "adamw"
clip_gradient = 20.0
loss_func_name = "mse"
eval_epoch = 200

# ---- model / loss options ----
mask_unseen_subplans = 0
subplan_level_outputs = 0
normalize_flow_loss = 1
heuristic_unseen_preds = 0
heuristic_features = 1
cost_model = "C"
hidden_layer_size = 128
num_hidden_layers = 2

# ---- one-hot regularisation ----
onehot_dropout = 2
onehot_mask_truep = 0.8
onehot_reg = 0
onehot_reg_decay = 0.1

# ---- bookkeeping / I/O ----
use_wandb = 0
eval_fns = "qerr,plancost"
load_padded_mscn_feats = 1
load_query_together = 0
result_dir = "./results"

# ---- bitmap features (directory named after the query dir) ----
joinbitmap = True
samplebitmap = False
bitmapdir = os.path.join("../queries/allbitmaps_new/", os.path.basename(TRAINDIR))

print(TRAINDIR)
print("BitmapDir: ", bitmapdir)

def init_featurizer(featurization_type, trainqs):
    """Create a Featurizer from TRAINDIR/dbdata.json and fit it to `trainqs`.

    Several setup options come from notebook-level config globals
    (onehot_dropout, heuristic_features, bitmapdir, joinbitmap, samplebitmap).

    Args:
        featurization_type: forwarded to Featurizer.setup (e.g. "set").
        trainqs: list of query dicts used to fit workload statistics.

    Returns:
        The configured Featurizer.
    """
    featurizer = Featurizer(None, None, None, None, None)

    # Database-specific statistics (columns, tables, ...) stored next to the queries.
    stats_path = os.path.join(TRAINDIR, "dbdata.json")
    with open(stats_path, "r") as stats_file:
        saved_stats = json.load(stats_file)
    featurizer.update_using_saved_stats(saved_stats)

    setup_options = {
        "ynormalization": "log",
        "feat_separate_alias": 0,
        "onehot_dropout": onehot_dropout,
        "feat_mcvs": 0,
        "heuristic_features": heuristic_features,
        "featurization_type": featurization_type,
        "table_features": 1,
        "flow_features": 0,
        "join_features": "onehot",
        "set_column_feature": "onehot",
        "max_discrete_featurizing_buckets": 10,
        "max_like_featurizing_buckets": 10,
        "embedding_fn": "none",
        "embedding_pooling": None,
        "implied_pred_features": 0,
        "feat_onlyseen_preds": 1,
        "bitmap_dir": bitmapdir,
        "join_bitmap": joinbitmap,
        "sample_bitmap": samplebitmap,
    }
    featurizer.setup(**setup_options)

    # Fit workload-dependent statistics, then build the feature mapping.
    featurizer.update_ystats(trainqs)
    featurizer.update_max_sets(trainqs)
    featurizer.update_workload_stats(trainqs)
    featurizer.init_feature_mapping()

    # Seen predicates are always recorded (setup uses feat_onlyseen_preds=1).
    featurizer.update_seen_preds(trainqs)

    return featurizer

# NOTE(review): init_featurizer runs update_ystats / update_workload_stats /
# update_seen_preds on qdata + tqdata, so the test queries influence the
# featurization. Pass only `qdata` if the test templates must stay unseen.
featurizer = init_featurizer("set", qdata + tqdata)

mscn = MSCN(max_epochs=max_epochs, lr=lr,
            training_opt=training_opt,
            inp_dropout=0.0,
            hl_dropout=0.0,
            comb_dropout=0.0,
            max_num_tables=-1,
            opt_lr=opt_lr,
            swa_start=swa_start,
            mask_unseen_subplans=mask_unseen_subplans,
            subplan_level_outputs=subplan_level_outputs,
            normalize_flow_loss=normalize_flow_loss,
            heuristic_unseen_preds=heuristic_unseen_preds,
            cost_model=cost_model,
            use_wandb=use_wandb,
            eval_fns=eval_fns,
            load_padded_mscn_feats=load_padded_mscn_feats,
            mb_size=mb_size,
            weight_decay=weight_decay,
            load_query_together=load_query_together,
            result_dir=result_dir,
            onehot_dropout=onehot_dropout,
            onehot_mask_truep=onehot_mask_truep,
            onehot_reg=onehot_reg,
            onehot_reg_decay=onehot_reg_decay,
            save_mscn_feats=False,
            eval_epoch=eval_epoch,
            optimizer_name=optimizer_name,
            clip_gradient=clip_gradient,
            loss_func_name=loss_func_name,
            hidden_layer_size=hidden_layer_size,
            other_hid_units=hidden_layer_size,
            # Use the configured value. A literal 2 here silently ignored the
            # num_hidden_layers setting.
            num_hidden_layers=num_hidden_layers,
            early_stopping=False,
            random_bitmap_idx=False,
            reg_loss=False,
            )

In [None]:
# Number of training queries (qdata was rebound to the queries with runtime plans).
len(qdata)

In [None]:
# Train on the plan-filtered queries. subplan_masks is aligned with qdata by
# load_plandata.
mscn.train(
    qdata,
    valqs=[],
    testqs=[],
    featurizer=featurizer,
    result_dir="results",
    subplan_mask=subplan_masks,
)

In [None]:
from evaluation.eval_fns import QError, SimplePlanCost, PostgresPlanCost
# Evaluation metrics used by eval_alg (SimplePlanCost() is another option).
EVAL_FNS = [
    QError(),
    PostgresPlanCost(cost_model="C"),
]

In [None]:
eval_alg(mscn, EVAL_FNS, qdata, samples_type="train")

In [None]:
eval_alg(mscn, EVAL_FNS, tqdata, samples_type="test")

In [None]:
# MSCN predictions for each training query. They are used below as dicts
# mapping subplan -> estimate.
ests = mscn.test(qdata)

In [None]:
# Sanity check: expected to equal len(qdata) (one estimate dict per query).
len(ests)

In [None]:
# Lay the per-query estimates out as a (subplan x query) matrix. Subplans
# without an estimate keep the default value 1.
estmat = np.ones((len(rowidxs), len(qdata)))

for col, est in enumerate(ests):
    for subplan, card in est.items():
        estmat[rowidxs[subplan], col] = card

In [None]:
## matrices are only defined over subset of rows that are seen in data
## fmask: only selects where true matrix has non-zeros
## pmask: 1 only where plan_mat has zeros; i.e., unknown ones.

## Matrices are only defined over the subplan rows kept by zero_idxs.
## fmask: 1 only where the true matrix has non-zeros.
## pmask: 1 only where plan_mat had zeros, i.e. the unknown entries.

estmat2 = estmat[zero_idxs] * fmask
qerr(full_mat, estmat2)

# fmask is already applied above (multiplying by a 0/1 mask twice is a no-op).
estmat2 = estmat2 * pmask
full_mat2 = full_mat * pmask

qerr(full_mat2, estmat2)

print(qerr_known(full_mat2, estmat2))

In [None]:
pmask.sum()  # number of entries the plans did not reveal

In [None]:
np.shape(full_mat)