In [1]:
import pandas as pd

In [19]:
def preprocess(name, csvfile):
    df = pd.read_csv(csvfile, index_col=0)
    tfrmt = lambda x: x if pd.isna(x) else f"{x // 3600:02.0f}:{x // 60 % 60:02.0f}:{x % 60:02.0f}"
    df['duration'] = df['time'].apply(tfrmt)
    df.columns = [f"{name}_{x}" for x in df.columns]
    return df

In [20]:
pyro_mixed = preprocess("pyro_mixed", "201104_1633_pyro_mixed.csv")
numpyro_mixed = preprocess("numpyro_mixed", "201113_0434_numpyro_mixed.csv")
numpyro_comprehensive = preprocess("numpyro_comprehensive", "201113_0515_numpyro_comprehensive.csv")
numpyro_generative = preprocess("numpyro_generative", "201118_1505_numpyro_generative.csv")
stan = preprocess("stan", "201114_0927_stan_mixed.csv")

In [21]:
results = pd.concat([pyro_mixed, numpyro_mixed, numpyro_comprehensive, numpyro_generative, stan], axis=1)
results = results[results.stan_status != "mismatch"]
results['example'] = results.index.map(lambda x: x.split("-")[1])
results['data'] = results.index.map(lambda x: x.split("-")[0])
results = results.sort_values(by='example')
results['speedup'] = results.stan_time / results.numpyro_mixed_time
results['speedup'] = results.speedup.apply(lambda x: f"{x:2.1f}")

In [22]:
summary = results[["example", "data", "stan_time", "stan_duration", "numpyro_mixed_time", "speedup"]]
print(summary.to_markdown(index=False))

| example                   | data                 |   stan_time | stan_duration   |   numpyro_mixed_time |     speedup |
|:--------------------------|:---------------------|------------:|:----------------|---------------------:|------------:|
| arK                       | arK                  |    55.5189  | 00:00:56        |             11.8642  |   4.67952   |
| dogs                      | dogs                 |   115.64    | 00:01:56        |             41.6432  |   2.77693   |
| dogs_log                  | dogs                 |    86.5672  | 00:01:27        |             42.0574  |   2.05831   |
| earn_height               | earnings             |    79.0459  | 00:01:19        |             11.7242  |   6.74213   |
| eight_schools_centered    | eight_schools        |    12.0772  | 00:00:12        |              6.8801  |   1.75538   |
| eight_schools_noncentered | eight_schools        |     1.95722 | 00:00:02        |             47.4012  |   0.0412905 |
| garch11               

In [6]:
len(summary)

32

In [23]:
all = results[["example", "data", "stan_status", "stan_duration","pyro_mixed_status", "pyro_mixed_duration", "numpyro_mixed_status", "numpyro_mixed_duration", "numpyro_comprehensive_status", "numpyro_comprehensive_duration","numpyro_generative_status", "numpyro_generative_duration",
"speedup"]]

import re
tex_summary = all.replace("success", "smark").replace("mismatch", "mmark").replace("error", "emark").replace("NaN", "").replace("nan","")
print(re.sub(
    pattern=r"([sem]mark)",
    repl="\\\\\\1",
    string = tex_summary.to_latex(index=False)
))


\begin{tabular}{llllllllllllr}
\toprule
                   example &                  data & stan\_status & stan\_duration & pyro\_mixed\_status & pyro\_mixed\_duration & numpyro\_mixed\_status & numpyro\_mixed\_duration & numpyro\_comprehensive\_status & numpyro\_comprehensive\_duration & numpyro\_generative\_status & numpyro\_generative\_duration &    speedup \\
\midrule
                       arK &                   arK &       \smark &      00:00:56 &             \smark &            20:49:07 &                \mmark &               00:00:12 &                        \mmark &                       00:00:06 &                     \mmark &                    00:00:12 &   4.679518 \\
                      dogs &                  dogs &       \smark &      00:01:56 &             \smark &            15:50:05 &                \mmark &               00:00:42 &                        \emark &                            NaN &                     \mmark &                    00:00:42 &   2.776931

In [8]:
speedup = results["stan_time"] / results["numpyro_mixed_time"]

In [9]:
speedup

arK-arK                                         4.679518
dogs-dogs                                       2.776931
dogs-dogs_log                                   2.058313
earnings-earn_height                            6.742132
eight_schools-eight_schools_centered            1.755376
eight_schools-eight_schools_noncentered         0.041291
garch-garch11                                        NaN
gp_pois_regr-gp_regr                            0.219060
bball_drive_event_0-hmm_drive_0                      NaN
hmm_example-hmm_example                              NaN
kidiq-kidscore_interaction                     10.128142
kidiq_with_mom_work-kidscore_interaction_c2     1.702784
kidiq_with_mom_work-kidscore_mom_work           2.106571
kidiq-kidscore_momhs                            0.937358
kidiq-kidscore_momhsiq                          3.845383
kidiq-kidscore_momiq                            1.909499
kilpisjarvi_mod-kilpisjarvi                     3.918809
earnings-log10earn_height      

In [10]:
import scipy.stats

scipy.stats.gmean(speedup.dropna())

3.6953204111892295

In [11]:
speedup.std()

8.619813843161419

In [12]:
stan_cache = preprocess("stan_cache", "201114_0927_stan_mixed.csv")
stan_compile = preprocess("stan_compile", "201118_0506_stan_mixed_compile.csv")
numpyro_cache = preprocess("numpyro_cache", "201118_0759_numpyro_mixed.csv")
numpyro_compile = preprocess("numpyro_compile", "201118_0858_numpyro_mixed_compile.csv")
numpyro_compile_dune = preprocess("numpyro_compile_dune", "201118_0723_numpyro_mixed_compile.csv")

In [13]:
compilation = pd.concat([stan_cache, stan_compile, numpyro_cache, numpyro_compile, numpyro_compile_dune], axis=1)
compilation = compilation[compilation.stan_cache_status != "mismatch"]

In [14]:
compilation["stan_compilation"] = compilation.stan_compile_time - compilation.stan_cache_time
compilation["numpyro_compilation"] = compilation.numpyro_compile_time - compilation.numpyro_cache_time
compilation["numpyro_compilation_dune"] = compilation.numpyro_compile_dune_time - compilation.numpyro_cache_time

In [15]:
compilation[["stan_compilation", "numpyro_compilation", "numpyro_compilation_dune"]].mean()

stan_compilation            10.986828
numpyro_compilation          0.450057
numpyro_compilation_dune     0.715769
dtype: float64

In [16]:
compilation[["stan_compilation", "numpyro_compilation", "numpyro_compilation_dune"]].std(axis=0)

stan_compilation            8.470852
numpyro_compilation         0.925645
numpyro_compilation_dune    0.497894
dtype: float64