In [None]:
import pandas as pd
import scipy.stats

In [None]:
def _format_duration(seconds):
    """Render a duration in seconds as ``HH:MM:SS``; missing values pass through unchanged."""
    if pd.isna(seconds):
        return seconds
    # Round the total once, up front. Formatting each field separately with
    # ``.0f`` let the seconds field round up to 60 (59.7 s became "00:00:60").
    total = int(round(seconds))
    hours, remainder = divmod(total, 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"


def preprocess_average(name, files):
    """Combine several result CSVs of one configuration into per-row summaries.

    Each entry of ``files`` is read with ``pd.read_csv(..., index_col=0)`` and
    must provide a ``status`` and a ``time`` column. The frames are aligned on
    their index and reduced to four columns:

    * ``{name}_status``   -- ``"error"`` if any file reports ``"error"`` for the
      row, otherwise ``"success"`` if any file reports ``"success"``, otherwise
      ``"mismatch"``.
    * ``{name}_time``     -- mean of the ``time`` columns across files.
    * ``{name}_time_std`` -- standard deviation of the ``time`` columns across
      files (pandas default, i.e. sample standard deviation).
    * ``{name}_duration`` -- ``{name}_time`` rendered as ``HH:MM:SS``; rows
      without a mean time keep their missing value.

    Parameters
    ----------
    name : str
        Prefix used for the produced column names.
    files : sequence
        Paths (or file-like objects) accepted by ``pd.read_csv``, one per run.

    Returns
    -------
    pandas.DataFrame
        The four summary columns described above, indexed like the inputs.

    Raises
    ------
    ValueError
        If ``files`` is empty.
    """
    if not files:
        raise ValueError(f"No result files given for {name!r}")

    n = len(files)
    dfs = []
    for i, f in enumerate(files):
        run = pd.read_csv(f, index_col=0)
        # Prefix every column with the configuration name and run number so
        # the runs can sit side by side in one frame.
        run.columns = [f"{name}_{i}_{x}" for x in run.columns]
        dfs.append(run)
    df = pd.concat(dfs, axis=1)

    status_cols = [f"{name}_{i}_status" for i in range(n)]
    time_cols = [f"{name}_{i}_time" for i in range(n)]
    status_col = f"{name}_status"

    # Start from the fallback label, then overwrite; "error" is applied last
    # so that it takes precedence over "success".
    df[status_col] = "mismatch"
    df.loc[df[status_cols].eq("success").any(axis=1), status_col] = "success"
    df.loc[df[status_cols].eq("error").any(axis=1), status_col] = "error"

    df[f"{name}_time"] = df[time_cols].mean(axis=1)
    df[f"{name}_time_std"] = df[time_cols].std(axis=1)
    df[f"{name}_duration"] = df[f"{name}_time"].apply(_format_duration)

    return df[[status_col, f"{name}_time", f"{name}_time_std", f"{name}_duration"]]

In [None]:
# Result CSVs for each benchmarked configuration, one file per repeated run.
pyro_comprehensive_files = [
    "210113_0819_pyro_comprehensive_0.csv",
    "210113_0820_pyro_comprehensive_1.csv",
    "210113_0820_pyro_comprehensive_2.csv",
    "210113_0818_pyro_comprehensive_3.csv",
    "210113_0818_pyro_comprehensive_4.csv",
]
numpyro_comprehensive_files = [
    "210127_0806_numpyro_comprehensive_0.csv",
    "210127_0844_numpyro_comprehensive_1.csv",
    "210127_0923_numpyro_comprehensive_2.csv",
    "210127_1001_numpyro_comprehensive_3.csv",
    "210127_1039_numpyro_comprehensive_4.csv",
]
numpyro_mixed_files = [
    "210112_1425_numpyro_mixed_0.csv",
    "210112_1459_numpyro_mixed_1.csv",
    "210112_1532_numpyro_mixed_2.csv",
    "210112_1605_numpyro_mixed_3.csv",
    "210112_1639_numpyro_mixed_4.csv",
]
numpyro_generative_files = [
    "210127_0713_numpyro_generative_0.csv",
    "210127_0720_numpyro_generative_1.csv",
    "210127_0728_numpyro_generative_2.csv",
    "210127_0735_numpyro_generative_3.csv",
    "210127_0743_numpyro_generative_4.csv",
]
stan_files = [
    "210126_0558_stan_0.csv",
    "210126_0857_stan_1.csv",
    "210126_1155_stan_2.csv",
    "210126_1454_stan_3.csv",
    "210126_1752_stan_4.csv",
]

# Summarise every configuration; the name doubles as the column prefix.
pyro_comprehensive = preprocess_average(name="pyro_comprehensive", files=pyro_comprehensive_files)
numpyro_comprehensive = preprocess_average(name="numpyro_comprehensive", files=numpyro_comprehensive_files)
numpyro_mixed = preprocess_average(name="numpyro_mixed", files=numpyro_mixed_files)
numpyro_generative = preprocess_average(name="numpyro_generative", files=numpyro_generative_files)
stan = preprocess_average(name="stan", files=stan_files)

In [None]:
# Put the per-configuration summaries side by side, aligned on the shared index.
mean_res = pd.concat(
    [
        pyro_comprehensive,
        numpyro_mixed,
        numpyro_comprehensive,
        numpyro_generative,
        stan,
    ],
    axis=1,
)

# Each index label is split on "-": the second piece is stored as the example,
# the first piece as the data.
mean_res['example'] = mean_res.index.map(lambda label: label.split("-")[1])
mean_res['data'] = mean_res.index.map(lambda label: label.split("-")[0])
mean_res = mean_res.sort_values(by='example')

# Stan time relative to NumPyro (comprehensive) time, kept only for rows where
# the NumPyro comprehensive status is "success"; every other row gets NaN.
stan_over_numpyro = mean_res.stan_time / mean_res.numpyro_comprehensive_time
mean_res['speedup'] = stan_over_numpyro.where(
    mean_res.numpyro_comprehensive_status == "success"
)

In [None]:
# Row-wise criteria, evaluated on the unfiltered table.
stan_succeeded = mean_res.stan_status == "success"
stan_relative_std = mean_res.stan_time_std / mean_res.stan_time

print(f"Stan successes: {stan_succeeded.sum()}")
print(f"High relative std: {mean_res.index[stan_relative_std > 1.0].tolist()}")

# Keep rows where Stan succeeded and its relative std is below 1.0.
mean_res = mean_res[stan_succeeded & (stan_relative_std < 1.0)]

print(f"Valid benchs: {len(mean_res)}")
print(f"Valid speedup: {mean_res['speedup'].notna().sum()}")

In [None]:
# Geometric mean over the rows of mean_res that have a speedup value.
available_speedups = mean_res["speedup"].dropna()
print(f"average speedup: {scipy.stats.gmean(available_speedups)}")

# For each backend: mean of the per-row std divided by mean of the per-row time.
numpyro_rel_std = (
    mean_res["numpyro_comprehensive_time_std"].mean()
    / mean_res["numpyro_comprehensive_time"].mean()
)
print(f"Relative std numpyro: {numpyro_rel_std}")

pyro_rel_std = (
    mean_res["pyro_comprehensive_time_std"].mean()
    / mean_res["pyro_comprehensive_time"].mean()
)
print(f"Relative std pyro: {pyro_rel_std}")

stan_rel_std = mean_res["stan_time_std"].mean() / mean_res["stan_time"].mean()
print(f"Relative std stan: {stan_rel_std}")

In [None]:
# Columns of the Stan vs. NumPyro (comprehensive) comparison, printed as a Markdown table.
markdown_columns = [
    "stan_time",
    "numpyro_comprehensive_status",
    "numpyro_comprehensive_time",
    "speedup",
]
print(mean_res[markdown_columns].to_markdown())

In [None]:
# Columns of the LaTeX results table, in display order.
latex_columns = [
    "example",
    "data",
    "stan_status",
    "stan_duration",
    "pyro_comprehensive_status",
    "pyro_comprehensive_duration",
    "numpyro_comprehensive_status",
    "numpyro_comprehensive_duration",
    "numpyro_mixed_status",
    "numpyro_mixed_duration",
    "numpyro_generative_status",
    "numpyro_generative_duration",
    "speedup",
]

# Format on a copy: overwriting mean_res['speedup'] with strings made this cell
# fail when run a second time and broke any later numeric use of the column.
latex_table = mean_res[latex_columns].copy()
latex_table["speedup"] = latex_table["speedup"].apply(
    # Missing speedups become an empty cell directly, so no text-wide "nan"
    # replacement is needed afterwards.
    lambda x: "" if pd.isna(x) else f"{x:02.2f}"
)

print(
    # na_rep="" blanks the remaining missing cells (e.g. durations). The
    # previous blanket .replace("nan", "") on the rendered text would also have
    # removed "nan" from inside example or data names.
    latex_table.to_latex(index=False, na_rep="")
    # Status labels are swapped for LaTeX macros after rendering. Raw strings
    # are required: "\s", "\e" and "\m" are invalid escape sequences in
    # ordinary string literals.
    # NOTE(review): these replacements still act on the whole table text, so a
    # name containing "success", "error" or "mismatch" would be altered too.
    .replace("success", r"\smark")
    .replace("error", r"\emark")
    .replace("mismatch", r"\mmark")
)