In [1]:
import pandas as pd
import wandb
from tqdm.notebook import tqdm
import pickle
from os.path import exists
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import math
from matplotlib.ticker import MaxNLocator

from data.get_uci import all_datasets
from analysis.util import fetch, init_md22_dict, get_uci_info

In [2]:
filters = {
    "group": "md22"
}
raw = fetch("soft-gp-2", filters)

100%|██████████| 45/45 [00:16<00:00,  2.70it/s]


In [6]:
md22_info = [
    ("Ac-Ala3-NHMe", 85109, 42 * 3),
    ("DHA", 69753, 56 * 3),
    ("stachyose", 27272, 87 * 3),
    ("AT-AT", 20001, 118 * 3),
    ("AT-AT-CG-CG", 10153, 118 * 3)
]


In [7]:
runs = {}
md22_dict = {}
for exp in raw:
    model = exp.config["model.name"]
    dataset = exp.config["dataset.name"]
    num_inducing = exp.config["model.num_inducing"]
    dtype = exp.config["model.dtype"]
    seed = exp.config["training.seed"]
    train_frac = float(exp.config["dataset.train_frac"])
    if model == "svi-gp" and not exp.config["model.learn_noise"]:
        continue 
    if model == "sv-gp" and not exp.config["model.learn_noise"]:
        continue 
    md22_dict[(dataset, seed, num_inducing, train_frac, model)] = exp.history
    runs[(dataset, seed, num_inducing, train_frac, model)] = exp.run.id

print(md22_dict.keys())

dict_keys([('DHA', 92357, 512, 0.9, 'sv-gp'), ('DHA', 92357, 512, 0.9, 'svi-gp'), ('DHA', 92357, 512, 0.9, 'soft-gp'), ('DHA', 8830, 512, 0.9, 'sv-gp'), ('DHA', 8830, 512, 0.9, 'svi-gp'), ('DHA', 8830, 512, 0.9, 'soft-gp'), ('DHA', 6535, 512, 0.9, 'sv-gp'), ('DHA', 6535, 512, 0.9, 'svi-gp'), ('DHA', 6535, 512, 0.9, 'soft-gp'), ('stachyose', 92357, 512, 0.9, 'sv-gp'), ('stachyose', 92357, 512, 0.9, 'svi-gp'), ('stachyose', 92357, 512, 0.9, 'soft-gp'), ('stachyose', 8830, 512, 0.9, 'sv-gp'), ('stachyose', 8830, 512, 0.9, 'svi-gp'), ('stachyose', 8830, 512, 0.9, 'soft-gp'), ('stachyose', 6535, 512, 0.9, 'sv-gp'), ('stachyose', 6535, 512, 0.9, 'svi-gp'), ('stachyose', 6535, 512, 0.9, 'soft-gp'), ('AT-AT', 92357, 512, 0.9, 'sv-gp'), ('AT-AT', 92357, 512, 0.9, 'svi-gp'), ('AT-AT', 92357, 512, 0.9, 'soft-gp'), ('AT-AT', 8830, 512, 0.9, 'sv-gp'), ('AT-AT', 8830, 512, 0.9, 'svi-gp'), ('AT-AT', 8830, 512, 0.9, 'soft-gp'), ('AT-AT', 6535, 512, 0.9, 'sv-gp'), ('AT-AT', 6535, 512, 0.9, 'svi-gp'), (

In [9]:
seeds = [6535, 8830, 92357]
num_inducings = [512, 1024]
fracs = [0.9]
# fracs = [0.44, 0.89]


MD22_INFO = {
    "N": [int(np.floor(N * 0.9)) for _, N, _, in md22_info],
    "D": [D for _, _, D in md22_info],
}
KZZ = {}
all_bins = {}

models = ["soft-gp", "svi-gp", "sv-gp"]

# models = ["sv-gp"]
for seed in seeds:
    for model in models:
        for num_inducing in num_inducings:
            for frac in fracs:
                xs = []
                ts = []
                K_zzs = []
                bins1 = []
                bins2 = []
                bins3 = []
                bins4 = []
                bins5 = []
                bins6 = []
                for dataset, _, _ in md22_info:
                    try:
                        xs += [float(md22_dict[(dataset, seed, num_inducing, frac, model)]["test_rmse"][49])]
                        ts += [float(np.array(md22_dict[(dataset, seed, num_inducing, frac, model)]["epoch_time"][49]).mean())]
                        K_zzs += [md22_dict[(dataset, seed, num_inducing, frac, model)]["K_zz"][i] for i in range(5)]
                        bins1 += [md22_dict[(dataset, seed, num_inducing, frac, model)]["K_zz_bin_0.0"][49]]
                        bins2 += [md22_dict[(dataset, seed, num_inducing, frac, model)]["K_zz_bin_1e-20"][49]]
                        bins3 += [md22_dict[(dataset, seed, num_inducing, frac, model)]["K_zz_bin_1e-10"][49]]
                        bins4 += [md22_dict[(dataset, seed, num_inducing, frac, model)]["K_zz_bin_1e-05"][49]]
                        bins5 += [md22_dict[(dataset, seed, num_inducing, frac, model)]["K_zz_bin_0.01"][49]]
                        bins6 += [md22_dict[(dataset, seed, num_inducing, frac, model)]["K_zz_bin_0.5"][49]]
                    except Exception as e:
                        xs += [np.nan]
                        ts += [np.nan]
                        bins1 += [np.nan]
                        bins2 += [np.nan]
                        bins3 += [np.nan]
                        bins4 += [np.nan]
                        bins5 += [np.nan]
                        bins6 += [np.nan]
                        print("Exception", e, model, dataset)

                MD22_INFO[f"{model}-{num_inducing}-{frac}-{seed}"] = xs
                MD22_INFO[f"time-{model}-{num_inducing}-{frac}-{seed}"] = ts
                all_bins[f"0.0-{model}-{num_inducing}-{frac}-{seed}"] = bins1
                all_bins[f"1e-20-{model}-{num_inducing}-{frac}-{seed}"] = bins2
                all_bins[f"1e-10-{model}-{num_inducing}-{frac}-{seed}"] = bins3
                all_bins[f"1e-05-{model}-{num_inducing}-{frac}-{seed}"] = bins4
                all_bins[f"0.01-{model}-{num_inducing}-{frac}-{seed}"] = bins5
                all_bins[f"0.5-{model}-{num_inducing}-{frac}-{seed}"] = bins6
                KZZ[f"kzz-{model}-{num_inducing}-{frac}-{seed}"] = K_zzs
df = pd.DataFrame(data=MD22_INFO)
df.index = [name.capitalize().replace("_", "-") for name, _, _ in md22_info]
df

Exception ('Ac-Ala3-NHMe', 6535, 1024, 0.9, 'soft-gp') soft-gp Ac-Ala3-NHMe
Exception ('DHA', 6535, 1024, 0.9, 'soft-gp') soft-gp DHA
Exception ('stachyose', 6535, 1024, 0.9, 'soft-gp') soft-gp stachyose
Exception ('AT-AT', 6535, 1024, 0.9, 'soft-gp') soft-gp AT-AT
Exception ('AT-AT-CG-CG', 6535, 1024, 0.9, 'soft-gp') soft-gp AT-AT-CG-CG
Exception ('Ac-Ala3-NHMe', 6535, 1024, 0.9, 'svi-gp') svi-gp Ac-Ala3-NHMe
Exception ('DHA', 6535, 1024, 0.9, 'svi-gp') svi-gp DHA
Exception ('stachyose', 6535, 1024, 0.9, 'svi-gp') svi-gp stachyose
Exception ('AT-AT', 6535, 1024, 0.9, 'svi-gp') svi-gp AT-AT
Exception ('AT-AT-CG-CG', 6535, 1024, 0.9, 'svi-gp') svi-gp AT-AT-CG-CG
Exception ('Ac-Ala3-NHMe', 6535, 1024, 0.9, 'sv-gp') sv-gp Ac-Ala3-NHMe
Exception ('DHA', 6535, 1024, 0.9, 'sv-gp') sv-gp DHA
Exception ('stachyose', 6535, 1024, 0.9, 'sv-gp') sv-gp stachyose
Exception ('AT-AT', 6535, 1024, 0.9, 'sv-gp') sv-gp AT-AT
Exception ('AT-AT-CG-CG', 6535, 1024, 0.9, 'sv-gp') sv-gp AT-AT-CG-CG
Exception 

Unnamed: 0,N,D,soft-gp-512-0.9-6535,time-soft-gp-512-0.9-6535,soft-gp-1024-0.9-6535,time-soft-gp-1024-0.9-6535,svi-gp-512-0.9-6535,time-svi-gp-512-0.9-6535,svi-gp-1024-0.9-6535,time-svi-gp-1024-0.9-6535,...,soft-gp-1024-0.9-92357,time-soft-gp-1024-0.9-92357,svi-gp-512-0.9-92357,time-svi-gp-512-0.9-92357,svi-gp-1024-0.9-92357,time-svi-gp-1024-0.9-92357,sv-gp-512-0.9-92357,time-sv-gp-512-0.9-92357,sv-gp-1024-0.9-92357,time-sv-gp-1024-0.9-92357
Ac-ala3-nhme,76598,126,0.668285,1.204293,,,0.886015,2.613808,,,...,,,0.893591,1.882314,,,0.866946,0.019085,,
Dha,62777,168,0.596839,1.277841,,,0.901041,1.571823,,,...,,,0.915041,1.630764,,,0.898777,0.017134,,
Stachyose,24544,261,0.367465,0.531578,,,0.703209,0.625723,,,...,,,0.709013,0.550848,,,0.680496,0.014004,,
At-at,18000,354,0.457886,0.28333,,,0.693333,0.459321,,,...,,,0.719648,0.454159,,,0.606635,0.012916,,
At-at-cg-cg,9137,354,0.394379,0.203767,,,0.608052,0.203198,,,...,,,0.680726,0.204729,,,0.629547,0.01227,,


In [10]:
def pm_var(df, model):
    m = df[[f'{model}-6535', f'{model}-8830', f'{model}-92357']].mean(axis=1).round(3).astype(str)
    v = df[[f'{model}-6535', f'{model}-8830', f'{model}-92357']].std(axis=1).round(3).astype(str).apply(lambda x: f" $\pm$ {x}")
    return (m + v).apply(lambda x: x.replace("nan $\pm$ nan", "-"))
df_rmse = pd.DataFrame()
df_rmse[['N', 'D']] = df[['N', 'D']]
for model in models:
    for num_inducing in num_inducings:
        for frac in fracs:
            df_rmse[f'{model}-{num_inducing}-{frac}'] = pm_var(df, f'{model}-{num_inducing}-{frac}')

df_rmse = df_rmse.sort_values(by=['D'], ascending=[True])
print("RMSE")
df_rmse

RMSE


Unnamed: 0,N,D,soft-gp-512-0.9,soft-gp-1024-0.9,svi-gp-512-0.9,svi-gp-1024-0.9,sv-gp-512-0.9,sv-gp-1024-0.9
Ac-ala3-nhme,76598,126,0.669 $\pm$ 0.006,-,0.887 $\pm$ 0.006,-,0.861 $\pm$ 0.006,-
Dha,62777,168,0.592 $\pm$ 0.005,-,0.914 $\pm$ 0.012,-,0.899 $\pm$ 0.012,-
Stachyose,24544,261,0.372 $\pm$ 0.004,-,0.708 $\pm$ 0.004,-,0.68 $\pm$ 0.003,-
At-at,18000,354,0.459 $\pm$ 0.012,-,0.706 $\pm$ 0.013,-,0.595 $\pm$ 0.011,-
At-at-cg-cg,9137,354,0.398 $\pm$ 0.006,-,0.634 $\pm$ 0.04,-,0.595 $\pm$ 0.03,-


In [11]:
df2 = df_rmse[["N", "D", "soft-gp-512-0.9", "sv-gp-512-0.9", "svi-gp-512-0.9"]]
df2 = df2.sort_values(by=['D'], ascending=[True])
print("RMSE")
df2

RMSE


Unnamed: 0,N,D,soft-gp-512-0.9,sv-gp-512-0.9,svi-gp-512-0.9
Ac-ala3-nhme,76598,126,0.669 $\pm$ 0.006,0.861 $\pm$ 0.006,0.887 $\pm$ 0.006
Dha,62777,168,0.592 $\pm$ 0.005,0.899 $\pm$ 0.012,0.914 $\pm$ 0.012
Stachyose,24544,261,0.372 $\pm$ 0.004,0.68 $\pm$ 0.003,0.708 $\pm$ 0.004
At-at,18000,354,0.459 $\pm$ 0.012,0.595 $\pm$ 0.011,0.706 $\pm$ 0.013
At-at-cg-cg,9137,354,0.398 $\pm$ 0.006,0.595 $\pm$ 0.03,0.634 $\pm$ 0.04


In [12]:
latex_table = df2.to_latex(
    index=True,
    escape=False,
    float_format="{:0.3f}".format,
)
print(latex_table)

\begin{tabular}{lrrlll}
\toprule
 & N & D & soft-gp-512-0.9 & sv-gp-512-0.9 & svi-gp-512-0.9 \\
\midrule
Ac-ala3-nhme & 76598 & 126 & 0.669 $\pm$ 0.006 & 0.861 $\pm$ 0.006 & 0.887 $\pm$ 0.006 \\
Dha & 62777 & 168 & 0.592 $\pm$ 0.005 & 0.899 $\pm$ 0.012 & 0.914 $\pm$ 0.012 \\
Stachyose & 24544 & 261 & 0.372 $\pm$ 0.004 & 0.68 $\pm$ 0.003 & 0.708 $\pm$ 0.004 \\
At-at & 18000 & 354 & 0.459 $\pm$ 0.012 & 0.595 $\pm$ 0.011 & 0.706 $\pm$ 0.013 \\
At-at-cg-cg & 9137 & 354 & 0.398 $\pm$ 0.006 & 0.595 $\pm$ 0.03 & 0.634 $\pm$ 0.04 \\
\bottomrule
\end{tabular}

