In [14]:
from scalable_gps.wandb_utils import load_runs_from_regex
import numpy as np
import os

def get_splits(dataset):
    if dataset == '3droad':
        return [0, 1, 2, 4]
    elif dataset == 'houseelectric':
        return [0, 1, 2]
    else:
        return [0, 1, 2, 3, 4]

datasets = ['pol',
            'elevators',
            'bike',
            # 'kin40k',
            'protein',
            'keggdirected',
            '3droad',
            'song',
            'buzz',
            'houseelectric']

models = ['sgd', 'cg', 'precondcg']#, 'vi']

config_keys = ['model_name', 'dataset_config.split', 'override_noise_scale']
metric_keys = ['wall_clock_time', 'normalised_test_rmse']

rmse_dict_path = "./table_rmse.npy"

if os.path.isfile(rmse_dict_path):
    rmse_dict = np.load(rmse_dict_path, allow_pickle=True).item()
else:
    rmse_dict = dict()

for dataset in datasets:
    if dataset not in rmse_dict.keys():
        rmse_dict[dataset] = dict()

    splits = get_splits(dataset)
    split_regex = f"{splits}".replace(", ", "|")
    n_splits = len(splits)

    for model in models:
        if model in rmse_dict[dataset].keys():
            print(f"rmse for {dataset}, {model} already exists")
            continue
        
        rmse_dict[dataset][model] = dict()

        for metric in metric_keys:
            rmse_dict[dataset][model][metric] = np.zeros((n_splits, 2))
        
        regex = f"^final_{dataset}_{model}_{split_regex}.*"

        print(f"Downloading results for {dataset}, {model}")
        for metric in metric_keys:
            configs_and_metrics = load_runs_from_regex(regex, config_keys=config_keys, metric_keys=[metric])

            for (configs, metrics) in configs_and_metrics:
                split = splits.index(configs['dataset_config.split'])
                assert model == configs['model_name']
                # print(dataset, split, model)
                print(metrics)
                idx = 0 if configs['override_noise_scale'] == -1 else 1
                rmse_dict[dataset][model][metric][split, idx] = metrics[metric][-1]
        np.save(rmse_dict_path, rmse_dict)


rmse for pol, sgd already exists
rmse for pol, cg already exists
rmse for pol, precondcg already exists
rmse for elevators, sgd already exists
rmse for elevators, cg already exists
rmse for elevators, precondcg already exists
rmse for bike, sgd already exists
rmse for bike, cg already exists
rmse for bike, precondcg already exists
rmse for protein, sgd already exists
rmse for protein, cg already exists
rmse for protein, precondcg already exists
rmse for keggdirected, sgd already exists
rmse for keggdirected, cg already exists
rmse for keggdirected, precondcg already exists
rmse for 3droad, sgd already exists
rmse for 3droad, cg already exists
rmse for 3droad, precondcg already exists
rmse for song, sgd already exists
rmse for song, cg already exists
rmse for song, precondcg already exists
rmse for buzz, sgd already exists
rmse for buzz, cg already exists
rmse for buzz, precondcg already exists
rmse for houseelectric, sgd already exists
rmse for houseelectric, cg already exists
rmse for

In [3]:
from scalable_gps.wandb_utils import load_runs_from_regex
import numpy as np
import os

datasets = ['pol',
            'elevators',
            'bike',
            # 'kin40k',
            'protein',
            'keggdirected',
            '3droad',
            'song',
            'buzz',
            'houseelectric']

ns = [15000, 16599, 17379, 45730, 48827, 434874, 515345, 583250, 2049280]

ds = [26, 18, 17, 9, 20, 3, 90, 77, 11]
regression_table_filepath = "./regression_table.tex"
with open(regression_table_filepath, 'w') as table:
    table.write("\\begin{table}[]\n")
    table.write("\\centering\n")
    table.write("\\renewcommand{\\arraystretch}{1.5}\n")
    table.write("\\setlength\\tabcolsep{2pt}\n")
    table.write("\\resizebox{\\textwidth}{!}{%\n")
    table.write("\\begin{tabular}{@{}cccclclcllclclclclclcl@{}}\n")
    table.write("\\toprule\n")
    table.write("\\multicolumn{2}{c}{\\multirow{2}{*}{}} &")
    for dataset in datasets:
        table.write("\n\\multicolumn{2}{c}{\\texttt{" + dataset + "}} &")
    table.write("\\\\ \\cmidrule(l){3-21} &\n")
    
    table.write("\\multicolumn{2}{c}{} &\n")
    
    for n, d in zip(ns, ds):
        table.write("\\multicolumn{2}{l}{$n$=$" + n "$, $d$=$" + d + "$} &\n")
    table.write("\\\\ \\midrule")
    
    

SyntaxError: closing parenthesis '}' does not match opening parenthesis '(' (3192435809.py, line 36)

In [None]:
\multicolumn{2}{c}{} &
  \multicolumn{2}{l}{$n$=$15k$, $d$=$26$} &
  \multicolumn{2}{l}{n=15k, d=91} &
  \multicolumn{3}{l}{n=15k, d=9} &
  \multicolumn{2}{l}{n=15k, d=9} &
  \multicolumn{2}{l}{n=15k, d=9} &
  \multicolumn{2}{l}{n=15k, d=9} &
  \multicolumn{2}{l}{n=15k, d=9} &
  \multicolumn{2}{l}{n=15k, d=9} &
  \multicolumn{2}{l}{n=15k, d=9} \\ \midrule