In [1]:
import pandas as pd
import scipy.stats
import os

In [2]:
def preprocess_average(name, logdir="logs"):
    """Aggregate repeated benchmark runs of one configuration into one table.

    Every file in ``logdir`` whose name contains ``f"_{name}_"`` is read as a
    CSV (first column used as the index) and must provide ``status`` and
    ``time`` columns. The runs are aligned on that index.

    Parameters
    ----------
    name : str
        Configuration name. It selects the log files and prefixes the
        returned column names.
    logdir : str, default "logs"
        Directory holding the per-run CSV logs.

    Returns
    -------
    pandas.DataFrame
        ``{name}_status``
            "error" if any run errored, otherwise "success" if any run
            succeeded, otherwise "mismatch".
        ``{name}_time``
            Mean time over runs.
        ``{name}_time_std``
            Standard deviation of the time over runs.
        ``{name}_duration``
            The mean time formatted as ``HH:MM:SS``, treating time as
            seconds; NaN stays NaN.

    Raises
    ------
    ValueError
        If no file in ``logdir`` matches ``name``.
    """

    def format_duration(seconds):
        # Round to whole seconds *before* splitting. Formatting the float
        # remainder with ".0f" turned e.g. 119.7 s into "00:01:60".
        if pd.isna(seconds):
            return seconds
        hours, rest = divmod(int(round(seconds)), 3600)
        minutes, secs = divmod(rest, 60)
        return f"{hours:02d}:{minutes:02d}:{secs:02d}"

    # os.listdir returns entries in arbitrary order. Sorting makes the run
    # index i (and therefore the column layout) deterministic.
    files = sorted(f for f in os.listdir(logdir) if f"_{name}_" in f)
    if not files:
        # Without this, pd.concat([]) raises an opaque "No objects to concatenate".
        raise ValueError(f"no log files matching '_{name}_' found in {logdir!r}")

    runs = []
    for i, filename in enumerate(files):
        run = pd.read_csv(os.path.join(logdir, filename), index_col=0)
        run.columns = [f"{name}_{i}_{col}" for col in run.columns]
        runs.append(run)
    df = pd.concat(runs, axis=1)

    n = len(files)
    status_cols = df[[f"{name}_{i}_status" for i in range(n)]]
    time_cols = df[[f"{name}_{i}_time" for i in range(n)]]

    # "error" takes precedence over "success". Rows where no run reports
    # either value are marked "mismatch".
    status = pd.Series("mismatch", index=df.index)
    status.loc[status_cols.eq("success").any(axis=1)] = "success"
    status.loc[status_cols.eq("error").any(axis=1)] = "error"
    df[f"{name}_status"] = status

    df[f"{name}_time"] = time_cols.mean(axis=1)
    df[f"{name}_time_std"] = time_cols.std(axis=1)
    df[f"{name}_duration"] = df[f"{name}_time"].apply(format_duration)

    return df[[f"{name}_status", f"{name}_time", f"{name}_time_std", f"{name}_duration"]]

In [3]:
# Aggregate the repeated runs of each framework configuration (read from the
# default "logs" directory). numpyro_generative is currently not loaded; add
# it to both lists below to include it.
pyro_comprehensive, numpyro_comprehensive, numpyro_mixed, stan = map(
    preprocess_average,
    ["pyro_comprehensive", "numpyro_comprehensive", "numpyro_mixed", "stan"],
)

In [4]:
# One row per "<data>-<example>" posterior name, sorted by example.
# Further frameworks (pyro_comprehensive, numpyro_mixed, numpyro_generative)
# can be added to the concat once they are loaded.
mean_res = (
    pd.concat([stan, numpyro_comprehensive], axis=1)
    .assign(
        example=lambda frame: frame.index.map(lambda key: key.split("-")[1]),
        data=lambda frame: frame.index.map(lambda key: key.split("-")[0]),
    )
    .sort_values(by="example")
)

# Speedup = Stan time / NumPyro time, kept only where NumPyro succeeded.
# Assigning the filtered Series relies on index alignment to leave NaN elsewhere.
speedups = (mean_res["stan_time"] / mean_res["numpyro_comprehensive_time"])[
    mean_res["numpyro_comprehensive_status"].eq("success")
]
mean_res["speedup"] = speedups

In [5]:
# Keep only benchmarks Stan solved whose Stan timing is stable enough
# (std / mean below the threshold) to give a meaningful speedup.
MAX_RELATIVE_STD = 1.0

stan_ok = mean_res[mean_res.stan_status == "success"]
print(f"Stan successes: {len(stan_ok)}")

relative_std = stan_ok.stan_time_std / stan_ok.stan_time
# `< MAX_RELATIVE_STD` is False both for ratio == threshold and for NaN
# (e.g. a single run has no std). Report every row that is actually dropped,
# not just those with ratio > threshold as before.
stable = relative_std < MAX_RELATIVE_STD
print(f"High relative std: {stable[~stable].index.tolist()}")

mean_res = stan_ok[stable]
print(f"Valid benchs: {len(mean_res)}")
print(f"Valid speedup: {len(mean_res['speedup'].dropna())}")

Stan successes: 36
High relative std: ['arma-arma11']
Valid benchs: 35
Valid speedup: 1


In [6]:
# Geometric mean over the benchmarks that have a valid speedup (NaN skipped).
print(f"average speedup: {scipy.stats.gmean(mean_res['speedup'].dropna())}")

# Relative std below is mean(time_std) / mean(time) over the remaining
# benchmarks, i.e. a ratio of means rather than a mean of per-benchmark ratios.
# NOTE(review): unlike the speedup, the numpyro figure also covers benchmarks
# where numpyro did not succeed; confirm that is intended.
print(
    "Relative std numpyro: "
    f"{mean_res['numpyro_comprehensive_time_std'].mean() / mean_res['numpyro_comprehensive_time'].mean()}"
)
# pyro_comprehensive is not part of mean_res; report its relative std the same
# way once it is concatenated in.
print(
    "Relative std stan: "
    f"{mean_res['stan_time_std'].mean() / mean_res['stan_time'].mean()}"
)

average speedup: 2.1258117964489553
Relative std numpyro: 0.6941714032306678
Relative std stan: 0.04136124286874295


In [7]:
# Quick markdown view of Stan's wall-clock duration next to the speedup.
print(mean_res.loc[:, ["stan_duration", "speedup"]].to_markdown())

|                                             | stan_duration   |   speedup |
|:--------------------------------------------|:----------------|----------:|
| mcycle_gp-accel_gp                          | 00:18:23        | nan       |
| arK-arK                                     | 00:00:59        | nan       |
| dogs-dogs                                   | 00:02:02        | nan       |
| dogs-dogs_log                               | 00:01:26        | nan       |
| earnings-earn_height                        | 00:01:19        | nan       |
| eight_schools-eight_schools_centered        | 00:00:02        | nan       |
| eight_schools-eight_schools_noncentered     | 00:00:02        | nan       |
| garch-garch11                               | 00:00:17        | nan       |
| gp_pois_regr-gp_regr                        | 00:00:03        | nan       |
| bball_drive_event_0-hmm_drive_0             | 00:04:09        | nan       |
| hmm_example-hmm_example                     | 00:00:31        

In [None]:
# Render the per-benchmark results as a LaTeX table. Status words are swapped
# for the marker macros \smark / \emark / \mmark (presumably defined in the
# LaTeX document's preamble; confirm) and missing values become empty cells.
latex_columns = [
    "example",
    "data",
    "stan_status",
    "stan_duration",
    "pyro_comprehensive_status",
    "pyro_comprehensive_duration",
    "numpyro_comprehensive_status",
    "numpyro_comprehensive_duration",
    "numpyro_mixed_status",
    "numpyro_mixed_duration",
    "numpyro_generative_status",
    "numpyro_generative_duration",
    "speedup",
]
# Only frameworks concatenated into mean_res have columns. Selecting the
# others unconditionally raised a KeyError.
latex_columns = [col for col in latex_columns if col in mean_res.columns]

# Format the speedup on a copy. Overwriting mean_res['speedup'] with strings
# made this cell crash on re-run (f"{'2.13':02.2f}" raises ValueError).
latex_table = mean_res[latex_columns].copy()
latex_table["speedup"] = latex_table["speedup"].apply(
    lambda x: "" if pd.isna(x) else f"{x:02.2f}"
)

# Blank out NaN in the frame instead of stripping "nan"/"NaN" from the
# rendered text, which could also clobber names containing those letters.
# Raw strings: "\s", "\e" and "\m" are invalid escape sequences otherwise.
print(
    latex_table.fillna("")
    .to_latex(index=False)
    .replace("success", r"\smark")
    .replace("error", r"\emark")
    .replace("mismatch", r"\mmark")
)