In [1]:
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor, as_completed

import wandb
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Download the data

In [24]:
from collections import defaultdict


api = wandb.Api()
target_names = ["many_well"]  # many_well
dims = [64]

algorithm_names = [
    "pis",
    "dds",
    "lv",
    "tb",
    # "gfniw",
    "gfnsmc",
    # "gfnbuf",
    # "gfnrbuf",
    # "gfnlbuf",
    # "gfnuiwbuf",
    "gfnpiwbuf",
    "gfnsmcbuf",
    "gfnsmcpiwbuf",
]

wandb_tag_filter = {
    "$all": ["final", "main_w_lp"],
    "$nin": ["hidden", "legacy"],
}

wandb_filter = {
    "tags": wandb_tag_filter,
    "config.algorithm_model_use_lp": True,
}

runs = defaultdict(dict)

# first key is target_name-dim
for target_name, dim in zip(target_names, dims):
    wandb_filter["config.target_name"] = target_name
    wandb_filter["config.target_dim"] = dim

    for algorithm_name in algorithm_names:
        wandb_filter["config.wandb_name"] = algorithm_name
        runs[f"{target_name}-{dim}d"][algorithm_name] = api.runs(f"sanghyeok-choi/sampling_bench", filters=wandb_filter)
        print(f"Number of runs in {target_name}-{dim}d/{algorithm_name}: {len(runs[f'{target_name}-{dim}d'][algorithm_name])}")


Number of runs in many_well-64d/pis: 5
Number of runs in many_well-64d/dds: 5
Number of runs in many_well-64d/lv: 5
Number of runs in many_well-64d/tb: 5
Number of runs in many_well-64d/gfnsmc: 5
Number of runs in many_well-64d/gfnpiwbuf: 5
Number of runs in many_well-64d/gfnsmcbuf: 5
Number of runs in many_well-64d/gfnsmcpiwbuf: 5


In [25]:
### Prepare dataframes

# Prepare metrics (columns)
metrics = ["KL/eubo", "KL/elbo", "logZ/reverse", "discrepancies/sd", "discrepancies/mmd"]
metrics_std = [f"{m}_std" for m in metrics]

# make a dataframe with group_keys as multi-index and metrics as columns
metrics_dfs = {}
for target_name_dim in runs.keys():
    # row is algorithm_name
    # column is metrics_columns
    metrics_dfs[target_name_dim] = pd.DataFrame(columns=metrics + metrics_std) 

In [26]:

def process_group_key(alg_name, alg_runs, metrics, timesteps, n_seeds):
    # Save to numpy array first
    metrics_arr = np.zeros((n_seeds, len(metrics)))

    for i, run in enumerate(alg_runs):
        # Fetch last 5 metrics
        last_5_df = run.history(samples=timesteps[-1] + 1, keys=metrics)
        last_5_df.set_index("_step", inplace=True)
        last_5_values = last_5_df.loc[timesteps, metrics].values
        metrics_arr[i] = last_5_values.mean(axis=0)

    return {
        'key': alg_name,
        'mean': metrics_arr.mean(axis=0),  # average over 5 seeds
        'std': metrics_arr.std(axis=0),  # std over 5 seeds
    }


last_5_timesteps = [38400, 38800, 39200, 39600, 39999]  # to be averaged
n_seeds = 5

for target_name_dim in runs.keys():
    print(f"Downloading runs for {target_name_dim}")

    target_runs = runs[target_name_dim]
    keys = target_runs.keys()
    with ThreadPoolExecutor(max_workers=min(32, len(runs[target_name_dim]))) as executor:
        futures = {
            executor.submit(
                process_group_key, 
                key,
                target_runs[key],
                metrics,
                last_5_timesteps,
                n_seeds,
            ): key for key in keys
        }

        for future in as_completed(futures):
            result = future.result()
            key = result['key']

            # Store results in dataframes
            metrics_dfs[target_name_dim].loc[key, metrics] = result['mean']
            metrics_dfs[target_name_dim].loc[key, metrics_std] = result['std']

Downloading runs for many_well-64d


KeyError: "None of [Index([38400, 38800, 39200, 39600, 39999], dtype='int64', name='_step')] are in the [index]"

### Main Results

In [23]:

def print_latex_table(
    target_names: list[str],
    target_names_to_display_name: dict[str, str],
    metrics: list[str],
    metric_names: list[str],
    final_metrics_dfs: dict[str, pd.DataFrame],
):
    # header = f"{'Algorithm': <150}"
    # for target_name in target_names:
    #     name, dim = target_name.replace('_', '').split('-')
    #     for metric_name in metric_names:
    #         col = f"{name.upper()}({dim}) {metric_name}"
    #         header += f" & {col: <24}"
    # header += "\\\\"
    # print(header)

    # Print latex table with column: target_names[0]-elbo & target_names[0]-eubo & target_names[1]-elbo & energy_names[1]-eubo & ...
    indices = final_metrics_dfs[target_names[0]].index
    for df in final_metrics_dfs.values():
        if not np.all(df.index == indices):
            raise ValueError("All dataframes must have the same index")

    out = f"{'Algorithm': <20}"
    for target_name in target_names:
        name, dim = target_name.split('-')
        for metric_name in metric_names:
            col = f"{target_names_to_display_name[target_name]} {metric_name}"
            out += f" & {col: <24}"
    out += "\n"

    for idx in indices:
        out += f"{idx: <20}\n"
        for target_name in target_names:
            temp_df = final_metrics_dfs[target_name].loc[idx]
            temp_df = final_metrics_dfs[target_name].loc[idx]
            for metric in metrics:
                val, std = temp_df[metric], temp_df[f"{metric}_std"]
                if val < -1e5 or val > 1e5:
                    # Convert to scientific notation
                    record = f"{val:.2e}\scriptsize$\pm${std:0.2e} \n"
                else:
                    if metric == "discrepancies/mmd":
                        record = f"{val:0.3f}\scriptsize$\pm${std:0.3f} \n"
                    else:
                        record = f"{val:0.2f}\scriptsize$\pm${std:0.2f} \n"
                out += f"\t& {record: <24}"
        out += "\\\\"
        out += "\n"
    print(out)
    

analysis_name = "main"
lp = [False]

metrics = ["discrepancies/mmd", "discrepancies/sd"]
metric_names = ["MMD", "Sinkhorn"]

run_name_to_display_name = OrderedDict({
    "pis": "PIS",
    "dds": "DDS",
    "lv": "LV",
    "tb": "TB",
    "gfnsmc": "\hspace\{2pt\}+ SMC",
    # "gfnbuf": "\hspace\{2pt\}+ SMC",
    # "gfnrbuf": "\hspace\{2pt\}+ R-Buf",
    # "gfnlbuf": "\hspace\{2pt\}+ L-Buf",
    # "gfnuiwbuf": "\hspace\{2pt\}+ UIW-Buf",
    "gfnpiwbuf": "\hspace\{2pt\}+ IW-Buf",
    # "gfnsmcbuf": "\hspace\{2pt\}+ SMC + Buf",
    "gfnsmcpiwbuf": "\hspace\{2pt\}+ SMC + IW-Buf",
})

target_names_to_display_name = {
    "funnel-10d": "Funnel (10d)",
    "planar_robot_4goal-10d": "Robot4 (10d)",
    "student_t_mixture-50d": "MoS (50d)",
    "gaussian_mixture40-50d": "GMM40 (50d)",
    "many_well-64d": "ManyWell (64d)",
}

new_dfs = {key: df.copy() for key, df in metrics_dfs.items()}
for key, df in new_dfs.items():
    df.index = df.index.map(run_name_to_display_name)
    # Filter out keys that are not in run_name_to_display_name
    df = df[df.index.isin(list(run_name_to_display_name.values()))]
    df = df.reindex(list(run_name_to_display_name.values()))
    new_dfs[key] = df

# save_final_metrics_dfs_to_csv(energy_names, metrics, metric_names, new_df, analysis_name)
print_latex_table(
    list(runs.keys()),
    target_names_to_display_name,
    metrics,
    metric_names,
    new_dfs,
)


Algorithm            & Robot4 (10d) MMD         & Robot4 (10d) Sinkhorn    & MoS (50d) MMD            & MoS (50d) Sinkhorn       & GMM40 (50d) MMD          & GMM40 (50d) Sinkhorn    
PIS                 
	& 0.593\scriptsize$\pm$0.026 
	& 3106.73\scriptsize$\pm$684.05 
	& 0.366\scriptsize$\pm$0.025 
	& 2324.69\scriptsize$\pm$81.66 
	& 0.110\scriptsize$\pm$0.002 
	& 16425.73\scriptsize$\pm$261.31 
\\
DDS                 
	& 1.381\scriptsize$\pm$0.004 
	& 6.22e+05\scriptsize$\pm$8.01e+05 
	& 0.245\scriptsize$\pm$0.016 
	& 2170.75\scriptsize$\pm$24.29 
	& 0.050\scriptsize$\pm$0.001 
	& 6882.66\scriptsize$\pm$125.25 
\\
LV                  
	& 0.422\scriptsize$\pm$0.002 
	& 1.71\scriptsize$\pm$0.01 
	& 0.350\scriptsize$\pm$0.007 
	& 2175.86\scriptsize$\pm$16.75 
	& 0.036\scriptsize$\pm$0.000 
	& 3952.22\scriptsize$\pm$97.14 
\\
TB                  
	& 0.424\scriptsize$\pm$0.001 
	& 1.72\scriptsize$\pm$0.01 
	& 0.315\scriptsize$\pm$0.023 
	& 2128.50\scriptsize$\pm$64.94 
	& 0.036\scriptsize$