In [297]:
import numpy as np
import pandas as pd
import wandb
from datetime import datetime
from scipy import stats

import warnings
warnings.filterwarnings('ignore')

api = wandb.Api()

In [356]:
def make_table_from_metric(
    metric,
    results,
    val_metric=None,
    ci=0.95,
    latex=False,
    bold=True,
    drop_nans=False,
    show_group=False,
    select_best=True,
    pm=True,
    uncertainty='half_ci'
):
    if val_metric is None:
        val_metric = metric

    alpha = (1 - ci) / 2

    if drop_nans:
        results = results[results[metric].notna()]
        results = results[results[val_metric].notna()]

    def half_ci(group):
        data = group.to_numpy()
        sem = stats.sem(data)
        t2 = stats.t.ppf(1 - alpha, len(data) - 1) - stats.t.ppf(alpha, len(data) - 1)
        return sem * (t2 / 2)
        # return np.std(data)

    def lower_ci(group):
        data = group.to_numpy()
        sem = stats.sem(data)
        mean = data.mean()
        t = stats.t.ppf(alpha, len(data) - 1)
        return mean + sem * t

    def upper_ci(group):
        data = group.to_numpy()
        sem = stats.sem(data)
        mean = data.mean()
        t = stats.t.ppf(1 - alpha, len(data) - 1)
        return mean + sem * t

    def count(group):
        data = group.to_numpy()
        return np.prod(data.shape)
    

    results = (
        results.groupby(by=["group", "method", "dataset"])
        .agg(
            {
                metric: ["mean", "std", "sem", lower_ci, upper_ci, half_ci, count],
                val_metric: [
                    "mean",
                    "std",
                    "sem",
                    lower_ci,
                    upper_ci,
                    half_ci,
                    count,
                ],
            }
        )
        .reset_index()
    )

    if select_best:
        group_max_idx = (
            results.groupby(by=["method", "dataset"]).transform(max)[val_metric]["mean"]
            == results[val_metric]["mean"]
        )
        table = results[group_max_idx]
    else:
        table = results

    # table = table[table["dataset"].isin(["Earthquake", "Fire", "Flood", "Volcano"])]

    if latex:

        def format_result(row):
            if pm:
                return (
                    f"{{{row[metric]['mean']:0.2f}_{{\pm {row[metric][uncertainty]:0.2f}}}}}"
                )
            else:
                return f"{{{row[metric]['mean']:0.2f}}}"


        def bold_result(row):
            return "\\bm" + row["result"] if row["bold"].any() else row["result"]

    else:

        def format_result(row):
            if pm:
                return f"{row[metric]['mean']:0.2f} ± {row[metric][uncertainty]:0.2f}"
            else:
                return f"{row[metric]['mean']:0.2f}"

        def bold_result(row):
            return "* " + row["result"] if row["bold"].any() else row["result"]

    table["group_max"] = table.groupby(by=["dataset"]).transform(max)[metric]["mean"]
    table["group_max"] = table.apply(
        lambda row: table.index[table[metric]["mean"] == row["group_max"].squeeze()][0],
        axis=1,
    )
    table["bold"] = table.apply(
        lambda row: (
            table.loc[row["group_max"], (metric, "mean")].squeeze()
            < row[metric]["upper_ci"]
        )
        or (
            row[metric]["mean"]
            > table.loc[row["group_max"], (metric, "lower_ci")].squeeze()
        ),
        axis=1,
    )

    table["result"] = table.apply(format_result, axis=1)
    if bold:
        table["result"] = table.apply(bold_result, axis=1)

    if latex:
        table["result"] = table.apply(lambda row: "$" + row["result"] + "$", axis=1)

    table["count"] = table[(metric, "count")]

    return table
    cols = (
        ["method", "dataset", "group"] if show_group else ["method", "dataset", "count"]
    )
    table_flat = table[cols].pivot(index="method", columns="dataset")

    return table_flat


In [364]:
runs = api.runs(
    "emilem/equiv-stochastic-diffusion-processes",
    filters={
        "createdAt": {"$gte": "2023-04-20T00:00:00.000Z"}
        # 'config.name': 'fire's
    },
)

summary_list, config_list, name_list = [], [], []

rows = []

for run in runs:
    # .summary contains the output keys/values for metrics like accuracy.
    #  We call ._json_dict to omit large files
    summary_list.append(run.summary._json_dict)

    # config = {"config/" + k: v for k, v in run.config.items() if not k.startswith("_")}
    config = {k: v for k, v in run.config.items() if not k.startswith("_")}

    # .name is the human-readable name of the run.
    name_list.append(run.name)

    rows.append(
        {
            "group": run.group,
            **run.summary._json_dict,
            **config,
        }
    )

runs_df = pd.DataFrame(rows)

In [365]:

def make_method(row):
    if "MultiOutputAttentionModel" in row["net"]:
        if "RBFVec" in row["kernel"]:
            return "NDP (SE)"
        elif "WhiteVec" in row["kernel"]:
            return "NDP (White)"
        else:
            raise
    elif "TransformerModule" in row["net"]:
        if row["net/attention"] == True:
            return "SE(3)-Transformer"
        else:
            return "Tensor Field"

# runs_df = runs_df.dropna()

runs_df["kernel"] = runs_df["sde/limiting_kernel/_target_"]
runs_df["net"] = runs_df["net/_target_"]
runs_df = runs_df[~runs_df["kernel"].isna()]
runs_df["method"] = runs_df.apply(make_method, axis=1)
runs_df["dataset"] = runs_df["data/kernel/_target_"].replace(
    {
        "neural_diffusion_processes.kernels.RBFDivFree": "Div-free",
        "neural_diffusion_processes.kernels.RBFCurlFree": "Curl-free",
        "neural_diffusion_processes.kernels.RBFVec": "SE",
    }
)
runs_df["lengthscale"] = runs_df["kernel/params/lengthscale"]
runs_df["variance"] = runs_df["kernel/params/variance"]
runs_df["noise"] = runs_df["kernel/noise"]
runs_df["beta1"] = runs_df["beta_schedule/beta1"]
runs_df["std_trick"] = runs_df["sde/std_trick"]
runs_df["residual_trick"] = runs_df["sde/residual_trick"]
runs_df["is_score_preconditioned"] = runs_df["sde/is_score_preconditioned"]
runs_df["n_points"] = runs_df["data/n_points"].apply(lambda x: str(x))


def query(data_frame, query_string):
    if query_string == "all":
        return data_frame
    return data_frame.query(query_string)


criteria = [
    # "`name` == 'context'",
    "`lengthscale` == 1.",
    "`variance` == 1.",
    "`beta1` == 15",
    # "`noise` == 0.1",
    # "`optim/n_steps` == 100000",
    # "(`data/n_train` == 80000. | `data/num_samples_train` == 80000.)",
    "`n_points` == '[25, 648]'",
]
criteria = ["all"] if criteria == [] else criteria
runs_df = query(runs_df, " & ".join(criteria))

In [370]:
import matplotlib
import matplotlib.pyplot as plt

plt.rcParams["text.usetex"] = True
plt.rcParams["font.family"] = ["Latin Modern Roman"]
plt.rcParams.update({"font.size": 10.95})
pw = '397.48499pt'
lw = pw

ci = 0.95
alpha = (1 - ci) / 2

def half_ci(group):
    data = group.to_numpy()
    sem = stats.sem(data)
    t2 = stats.t.ppf(1 - alpha, len(data) - 1) - stats.t.ppf(alpha, len(data) - 1)
    return sem * (t2 / 2)
    # return np.std(data)

metric = "prior_logp"
metric = "cond_logp"
val_metric = metric
results = runs_df.query("`dataset` == 'Div-free' & `data/n_train` >= 50")

results = (
        results.groupby(by=["method", "dataset", "data/n_train"])
        .agg(
            {
                metric: ["mean", "std", "sem", half_ci, "count"],
                val_metric: ["mean", "std", "sem", half_ci, "count"],
            }
        )
        .reset_index()
    )

fig, ax = plt.subplots(1, 1, sharex=True, sharey=True, figsize=(pw, pw / 4))
fig.subplots_adjust(wspace=0, hspace=0.0)
# methods = ["NDP (White)", "Tensor Field", "SE(3)-Transformer"]
methods = ["NDP (White)", "Tensor Field"]
for i, method in enumerate(methods):
    idx = results["method"] == method
    x = results[idx]["data/n_train"]
    y = results[idx][metric]["mean"]
    # y_err = results[idx][metric]["sem"]
    y_err = results[idx][metric]["half_ci"]
    print(y_err)
    ax.errorbar(x, y, y_err, color=f"C{i}", label=method)
    ax.set_xscale('log')
    # ax.set_yscale('log')
    ax.set_ylim(0.5, .8)
    # ax.set_ylim(-100, .8)
fig.legend()
fig.show()

results

TypeError: unsupported operand type(s) for /: 'str' and 'int'

In [360]:
# Manually adding the CNP results!

columns = ['group', 'method', "config/seed", 'dataset', 'cond_logp']
rows = [
    ["ConvCNP", "ConvCNP", 1, "Curl-free", -1.7677894592285157],
    ["ConvCNP", "ConvCNP", 2, "Curl-free", -1.7930784861246745],
    ["ConvCNP", "ConvCNP", 3, "Curl-free", -1.7732168833414714],
    ["ConvCNP", "ConvCNP", 4, "Curl-free", -1.7641785939534504],
    ["ConvCNP", "ConvCNP", 5, "Curl-free", -1.7661763509114583],
    ["ConvCNP", "ConvCNP", 1, "Div-free", -1.7574120839436849],
    ["ConvCNP", "ConvCNP", 2, "Div-free", -1.7526724497477213],
    ["ConvCNP", "ConvCNP", 3, "Div-free", -1.7614798227945963],
    ["ConvCNP", "ConvCNP", 4, "Div-free", -1.7596895853678385],
    ["ConvCNP", "ConvCNP", 5, "Div-free", -1.7613946278889974],
    ["ConvCNP", "ConvCNP", 1, "SE", -1.7010875701904298],
    ["ConvCNP", "ConvCNP", 2, "SE", -1.7113407135009766],
    ["ConvCNP", "ConvCNP", 3, "SE", -1.7199483235677084],
    ["ConvCNP", "ConvCNP", 4, "SE", -1.7112767537434896],
    ["ConvCNP", "ConvCNP", 5, "SE", -1.6989725748697917],
    ["C4", "SteerCNP", 1, "Curl-free", -1.5672126770019532],
    ["C4", "SteerCNP", 2, "Curl-free", -1.5728504180908203],
    ["C4", "SteerCNP", 3, "Curl-free", -1.5713305155436197],
    ["C4", "SteerCNP", 4, "Curl-free", -1.5724315643310547],
    ["C4", "SteerCNP", 5, "Curl-free", -1.5712619781494142],
    ["C4", "SteerCNP", 1, "Div-free", -1.5733726501464844],
    ["C4", "SteerCNP", 2, "Div-free", -1.5614351908365884],
    ["C4", "SteerCNP", 3, "Div-free", -1.574195353190104],
    ["C4", "SteerCNP", 4, "Div-free", -1.5771334330240885],
    ["C4", "SteerCNP", 1, "SE", -1.6106020609537761],
    ["C4", "SteerCNP", 2, "SE", -1.6136614481608074],
    ["C4", "SteerCNP", 3, "SE", -1.6125297546386719],
    ["C4", "SteerCNP", 4, "SE", -1.6106975555419922],
    ["C4", "SteerCNP", 5, "SE", -1.6165941874186198],
    ["GP", "GP", 1, "Curl-free", 0.6598717212677002],
    ["GP", "GP", 2, "Curl-free", 0.6598289966583252],
    ["GP", "GP", 3, "Curl-free", 0.6598330497741699],
    ["GP", "GP", 4, "Curl-free", 0.6598613739013672],
    ["GP", "GP", 5, "Curl-free", 0.6598188877105713],
    ["GP", "GP", 1, "Div-free", 0.6602359294891358],
    ["GP", "GP", 2, "Div-free", 0.6602742195129394],
    ["GP", "GP", 3, "Div-free", 0.6602205753326416],
    ["GP", "GP", 4, "Div-free", 0.6602634906768798],
    ["GP", "GP", 5, "Div-free", 0.6602737426757812],
    ["GP", "GP", 1, "SE", 0.5573758125305176],
    ["GP", "GP", 2, "SE", 0.5573612689971924],
    ["GP", "GP", 3, "SE", 0.5573762893676758],
    ["GP", "GP", 4, "SE", 0.5574379444122315],
    ["GP", "GP", 5, "SE", 0.5574140071868896],
    ["GP", "GP (diag.)", 1, "Curl-free", -1.4716421127319337],
    ["GP", "GP (diag.)", 2, "Curl-free", -1.4714653968811036],
    ["GP", "GP (diag.)", 3, "Curl-free", -1.4716914176940918],
    ["GP", "GP (diag.)", 4, "Curl-free", -1.471421241760254],
    ["GP", "GP (diag.)", 5, "Curl-free", -1.471923542022705],
    ["GP", "GP (diag.)", 1, "Div-free", -1.466759204864502],
    ["GP", "GP (diag.)", 2, "Div-free", -1.4690001487731934],
    ["GP", "GP (diag.)", 3, "Div-free", -1.4697052001953126],
    ["GP", "GP (diag.)", 4, "Div-free", -1.4679842948913575],
    ["GP", "GP (diag.)", 5, "Div-free", -1.4686187744140624],
    ["GP", "GP (diag.)", 1, "SE", -1.559639072418213],
    ["GP", "GP (diag.)", 2, "SE", -1.5616438865661622],
    ["GP", "GP (diag.)", 3, "SE", -1.5610824584960938],
    ["GP", "GP (diag.)", 4, "SE", -1.5596550941467284],
    ["GP", "GP (diag.)", 5, "SE", -1.5608834266662597],
    ]
cnp_runs_df = pd.DataFrame(rows, columns=columns)

runs_df = pd.concat([runs_df, cnp_runs_df])

In [361]:
results = runs_df.query("`method` == 'NDP (White)' & `dataset` == 'Div-free'")#["cond_logp"]
results = (
        results.groupby(by=["group", "method", "dataset"])
        .agg(
            {
                metric: ["mean", "std", "sem", "count"],
            }
        )
        .reset_index()
    )
results

Unnamed: 0_level_0,group,method,dataset,cond_logp,cond_logp,cond_logp,cond_logp
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,mean,std,sem,count
0,"context_data.n_points=[25,648],data.n_train=80000,data_kernel=divfree,net=mattn",NDP (White),Div-free,0.622912,0.005073,0.002269,5


In [362]:
def format(table, row_names=[]):
    cols = ["method", "dataset", "result"]
    table = table[cols].pivot(index="method", columns="dataset")
    if len(row_names) > 0:
        table = table.reindex(row_names)
    table.index = [f"\\scshape {x}" for x in table.index]

    table = table.droplevel(level=0, axis=1)
    table = table.droplevel(level=0, axis=1)
    table = table[["SE", "Curl-free", "Div-free"]]
    table.columns = [f"\\scshape {x}" for x in table.columns]
    table.columns.name = "\\textsc{Model}"
    table.index.name = None
    return table

metric = "cond_logp"
# metric = "prior_logp"
# metric = "true_logp"


non_gp_idx = runs_df["method"].isin(["GP", "GP (diag.)"])

test_table = make_table_from_metric(
    metric, runs_df[~non_gp_idx], val_metric=metric, drop_nans=False, latex=True, bold=True, show_group=False, select_best=True
)
test_table["method"] = test_table["method"].replace(
    {"NDP (White)": "NDP",
    "Tensor Field": "Equiv NDP"},
)

print(test_table[["group", "result"]])
row_names=["ConvCNP", "SteerCNP", "NDP", "SE(3)-Transformer", "Equiv NDP"]
table = format(test_table, row_names=row_names)
print(table)
gp_test_table = make_table_from_metric(
    metric, runs_df[non_gp_idx], val_metric=metric, drop_nans=False, latex=True, bold=False, show_group=False, select_best=True, pm=True, uncertainty='sem'
)
# # print(gp_test_table)
gp_table = format(gp_test_table, row_names=["GP (diag.)", "GP"])
table = pd.concat([table, gp_table])
# row_names = ["ConvCNP", "SteerCNP", "GP (diag.)", "NDP (White)", "SE(3)-Transformer", "Tensor Field", "GP"]
row_names = ["ConvCNP", "SteerCNP", "GP (diag.)", "NDP", "Equiv NDP", "SE(3)-Transformer", "GP"]
table = table.reindex([f"\\scshape {row}" for row in row_names])

import os
latex_path = "table.tex"
filename = os.path.join(os.getcwd(), latex_path)
print(filename)
table.style.to_latex(
    buf=filename, hrules=True, multicol_align="c", column_format="lrrr"
)
table

len(data) 5
stats.t.ppf(1 - alpha, len(data) - 1) 2.7764451051977987
stats.t.ppf(alpha, len(data) - 1) -2.7764451051977987
t2 5.5528902103955975
sem 0.0009998398116324934


RuntimeError: No active exception to reraise