In [None]:
import wandb
from collections import defaultdict
import pandas as pd
import numpy as np

import random

import seaborn as sns
import matplotlib.pyplot as plt
from IPython.display import HTML
from IPython.display import display_html

%load_ext autoreload
%autoreload 2

In [None]:
# Bar colour for each featurization label; passed to seaborn as `palette`.
COLORS = {
    "MSCN": "#4260f5",
    "MSCN (Ours)": "Green",
}

In [None]:
# Weights & Biases API client; used below to query runs via `api.runs(...)`.
api = wandb.Api()

In [None]:
# Alternative (disabled) query against a different project, kept for reference.
# runs = api.runs("pari/skdwda-pari",
#     {"$or": [
#        {"config.algs": "dqo"},
#        #{"config.wandb_tags":"run_all,final1"},
#        #{"config.wandb_tags":"run_all,final2"},
#        {"config.wandb_tags":"run_all,final2-fixedBestVal"}
#     ]
#     })

# Fetch the runs of project "pari/MyCEB" that satisfy every filter in the
# "$and" list (algorithm, split kind, tag, bucket count and query directory).
runs = api.runs("pari/MyCEB",
    {"$and": [
       {"config.algs": "mscn"},
       {"config.train_test_split_kind":"custom"},
       #{"tags":"v17"},
       {"tags":"new-custom-runs"},
       {"config.max_discrete_featurizing_buckets":10},
       {"config.query_dir":"queries/imdb"}
    ]
    })
print(f"Found {len(runs)} runs")

In [None]:
# Experiment configuration. KEYNAMES[i] is the run-config key whose expected
# value is VALS[i]; VALS[0] (the training template) also selects the baseline
# numbers and the output file names used by the plotting cells.
KEYNAMES = ["train_tmps", "test_tmps", "max_discrete_featurizing_buckets", "embedding_fn",
            "weight_decay", "feat_onlyseen_preds", "loss_func_name", "feat_separate_alias"]
#VALS = ["1a", "2a,2b", 10, "none", 0, 1, "mse", 0]
VALS = ["2b", "2a", 10, "none", 0, 1, "mse", 0]

# Baseline values per training template, stored as (PG_BASELINE_VAL, PG_BASELINE_COST).
# They are drawn as the red "PostgreSQL" reference lines: PG_BASELINE_VAL on the
# training-template panel and PG_BASELINE_COST on the new-template panel.
PG_BASELINES = {
    "2b": (8.849, 7.599),
    "1a": (4.433, 8.004),
}

# Fail loudly with a useful message instead of a bare `assert False`, which
# carries no explanation and is stripped when Python runs with -O.
if VALS[0] not in PG_BASELINES:
    raise ValueError(
        f"No PostgreSQL baseline known for training template {VALS[0]!r}; "
        f"expected one of {sorted(PG_BASELINES)}"
    )
PG_BASELINE_VAL, PG_BASELINE_COST = PG_BASELINES[VALS[0]]

# NOTE(review): this is a relative path; it looks like it was meant to be the
# absolute "/Users/pari/Desktop/" -- confirm before using it to save figures.
SAVE_DIR = "Users/pari/Desktop/"

In [None]:
import time

# Flatten every finished run into a one-row DataFrame holding its tags, name,
# config values and summary values; the frames are collected in `dfs`.
start = time.time()
summary_list, config_list, name_list = [], [], []
dfs = []

for run in runs:
    # wandb exposes run attributes in lower case (`state`, `tags`, `name`).
    if run.state != "finished":
        continue

    data = defaultdict(list)
    data["Tags"].append(run.tags)
    data["name"].append(run.name)

    # Keep only plain keys: skip internal ones (leading "_") and nested /
    # namespaced ones (containing "/"), exactly as for the summary below.
    for k, v in run.config.items():
        if k.startswith("_") or "/" in k:
            continue
        data[k].append(v)

    # NOTE(review): `_json_dict` is a private attribute of the summary object;
    # confirm it is still available in the installed wandb version.
    for k, v in run.summary._json_dict.items():
        if k.startswith("_") or "/" in k:
            continue
        data[k].append(v)

    dfs.append(pd.DataFrame(data))

print("took: ", time.time()-start)

In [None]:
# Display the first run returned by the query as a quick sanity check.
runs[0]

In [None]:
# Stack the one-row per-run frames into a single table. Each input frame has
# index 0, so renumber the rows to avoid a result whose index labels are all
# duplicates.
df = pd.concat(dfs, ignore_index=True)

In [None]:
# Row count of the combined table, followed by a preview of its first rows.
print(len(df))
df.head(5)

In [None]:
def get_row_featurization(row):
    """Return the featurization label for one run record.

    `row` is anything indexable by column name (e.g. a DataFrame row).

    Returns
    -------
    "MSCN"
        if `table_features` is 1, both `set_column_feature` and
        `join_features` are one of "1", 1 or "onehot", and `onehot_dropout`
        is 0.
    "MSCN (Ours)"
        otherwise, if `onehot_dropout` is 2.
    "unknown"
        in every other case.
    """
    onehot_markers = ("1", 1, "onehot")

    # The `and` chain short-circuits, so later columns are only read when the
    # earlier conditions already hold.
    is_plain_mscn = (
        row["table_features"] == 1
        and row["set_column_feature"] in onehot_markers
        and row["join_features"] in onehot_markers
        and row["onehot_dropout"] == 0
    )
    if is_plain_mscn:
        return "MSCN"
    if row["onehot_dropout"] == 2:
        return "MSCN (Ours)"
    return "unknown"

In [None]:
# Sanity check: number of distinct run names vs. number of rows, then the
# available column names.
print(len(set(df["name"])))
print(len(df))
print(df.keys())

In [None]:
# Label every run (row) with its featurization name.
df["Featurization"] = df.apply(get_row_featurization, axis=1)

In [None]:
# Per featurization label, the number of non-null entries in each column.
df.groupby(["Featurization"]).count()

In [None]:
# Rows to plot: runs with a recognised featurization that also report a test
# plan-cost value.
has_known_featurization = df["Featurization"] != "unknown"
has_test_cost = df["Final-Relative-TotalPPCost-test"].notna()
pdf = df[has_known_featurization & has_test_cost]

In [None]:
# Preview the rows that will be plotted.
pdf.head(5)

In [None]:
import matplotlib.patches as mpatches

# Two-panel bar chart of the relative plan cost per featurization, with the
# PostgreSQL baseline drawn as a red horizontal line, saved as a PDF.

# matplotlib 3.6 renamed the bundled seaborn style sheets to "seaborn-v0_8-*"
# and later releases removed the old names, so pick whichever name this
# installation provides.
WHITE_STYLE = ("seaborn-v0_8-white" if "seaborn-v0_8-white" in plt.style.available
               else "seaborn-white")
plt.style.use(WHITE_STYLE)
fig, axs = plt.subplots(figsize=(14,14), nrows=1, ncols=2)

# Left panel: validation metric ("Queries from training template").
ax = axs[0]
sns.barplot(data=pdf, hue="Featurization", y="Final-Relative-TotalPPCost-val",
            x="Featurization", dodge=False, ax=ax,
            palette=COLORS)
ax.set_ylabel("Relative PostgreSQL Cost", fontsize=20)
ax.set_title("Queries from training template", fontsize=20)
ax.tick_params(axis='both', which='major', labelsize=20)
ax.set_xlabel("")
ax.axhline(y=PG_BASELINE_VAL, color="red")

# Build the legend once through the public API: the bar handles seaborn
# registered on the axes plus a red patch standing for the baseline line.
handles, labels = ax.get_legend_handles_labels()
handles.append(mpatches.Patch(color='red', label='PostgreSQL'))
labels.append("PostgreSQL")
ax.legend(handles=handles, labels=labels, loc='upper left',
          bbox_to_anchor=(0.6,1.1), ncol=3, prop={'size': 16},
          title="Estimator")

# Right panel: test metric ("Queries from new template"); the legend on the
# left panel serves both, so this one is removed.
ax = axs[1]
sns.barplot(data=pdf, hue="Featurization", y="Final-Relative-TotalPPCost-test",
            x="Featurization", dodge=False, ax=ax,
            palette=COLORS)
ax.set_title("Queries from new template", fontsize=20)
ax.legend().remove()
ax.tick_params(axis='both', which='major', labelsize=20)
ax.set_ylabel("Relative PostgreSQL Cost", fontsize=20)
ax.set_xlabel("")
ax.axhline(y=PG_BASELINE_COST, color="red")

# The output file is named after the training template (VALS[0]).
FN = "./" + VALS[0] + "-PPC.pdf"
print(FN)

plt.savefig(FN)

In [None]:
# Two-panel bar chart of the mean Q-Error per featurization, saved as a PDF.

# matplotlib 3.6 renamed the bundled seaborn style sheets to "seaborn-v0_8-*"
# and later releases removed the old names, so pick whichever name this
# installation provides.
WHITE_STYLE = ("seaborn-v0_8-white" if "seaborn-v0_8-white" in plt.style.available
               else "seaborn-white")
plt.style.use(WHITE_STYLE)
fig, axs = plt.subplots(figsize=(14,10), nrows=1, ncols=2)

# Left panel: validation metric ("Queries from training template"); it keeps
# the legend for the whole figure.
ax = axs[0]
sns.barplot(data=pdf, hue="Featurization", y="Final-QError-val-mean",
            x="Featurization", dodge=False, ax=ax,
            palette=COLORS)
ax.legend(title="", loc="center left", bbox_to_anchor=(0.60, 1.10), ncol=2, frameon=False,
          fontsize=20)
ax.set_ylabel("Q-Error", fontsize=20)
ax.set_title("Queries from training template", fontsize=20)
ax.tick_params(axis='both', which='major', labelsize=20)
ax.set_xlabel("")

# Right panel: test metric ("Queries from new template"), without legend or
# axis labels.
ax = axs[1]
sns.barplot(data=pdf, hue="Featurization", y="Final-QError-test-mean",
            x="Featurization", dodge=False, ax=ax,
            palette=COLORS)
ax.set_title("Queries from new template", fontsize=20)
ax.legend().remove()
ax.set_ylabel("")
ax.tick_params(axis='both', which='major', labelsize=20)
ax.set_xlabel("")

# NOTE(review): unlike the "-PPC.pdf" figure there is no "-" between the
# template name and the suffix (e.g. "2bQ-Error.pdf"); confirm this is intended.
FN = "./" + VALS[0] + "Q-Error.pdf"
plt.savefig(FN)

In [None]:
# sns.barplot(data=pdf, hue="Featurization", y="Final-Relative-TotalPPCost-val", 
#                 x="Featurization", dodge=False)

In [None]:
# List the columns available in the plotted subset.
pdf.keys()

In [None]:
# Quick look: mean test Q-Error per featurization, on default axes.
sns.barplot(data=pdf, hue="Featurization", y="Final-QError-test-mean", 
                x="Featurization", dodge=False)

In [None]:
# Quick look: test plan cost of each run name, coloured by featurization.
sns.barplot(data=pdf, hue="Featurization", y="Final-Relative-TotalPPCost-test", 
                x="name", dodge=False)

In [None]:
# Distinct values of two config flags across all fetched runs.
print(set(df["feat_onlyseen_preds"]))
print(set(df["feat_separate_alias"]))