In [None]:
try:
    import statsmodels.api as sm
    import bambi
    import numba
except:
    !pip install statsmodels bambi blackjax jax numpyro nutpie numba pyarrow

In [7]:
import sys
import os

import pandas as pd
import numpy as np

import arviz as az
import matplotlib.pyplot as plt
import statsmodels.api as sm

from tqdm.auto import tqdm

import contextlib
import os
import sys

from typing import cast

In [8]:
sys.path.append('../')

In [9]:
from brmspy import brms, prior
import bambi as bmb

In [10]:
# Random seed forwarded to the bambi fits (`random_seed=SEED`).
SEED: int = 42

In [11]:
@contextlib.contextmanager
def silence():
    """Temporarily send ``sys.stdout`` / ``sys.stderr`` writes to ``os.devnull``.

    The previous stream objects are restored on exit, including when the body
    raises. Only the Python-level stream objects are swapped, so anything that
    writes directly to file descriptors 1/2 is not silenced.
    """
    with open(os.devnull, "w") as sink, \
            contextlib.redirect_stdout(sink), \
            contextlib.redirect_stderr(sink):
        yield

In [12]:
# Load bambi's bundled "adults" dataset and show its raw schema.
data = cast(pd.DataFrame, bmb.load_data("adults"))
data.info()

# Convert string (object) columns to pandas categoricals so they are treated as factors.
categorical_cols = data.columns[data.dtypes == object].tolist()
for col in categorical_cols:
    data[col] = data[col].astype("category")

# Binary outcome: 1 when income is ">50K", else 0.
data['y'] = (data['income'] == '>50K').astype(int)
# Guard against a silent label mismatch (e.g. a leading space in the category),
# which would otherwise make y constant and every benchmark meaningless.
assert data['y'].nunique() == 2, "income label '>50K' not found; y would be constant"

# Synthetic cluster id for the random-slope model: consecutive blocks of 50 rows share a gid.
data["gid"] = np.arange(len(data)) // 50

data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 32561 entries, 0 to 32560
Data columns (total 5 columns):
 #   Column   Non-Null Count  Dtype 
---  ------   --------------  ----- 
 0   income   32561 non-null  object
 1   sex      32561 non-null  object
 2   race     32561 non-null  object
 3   age      32561 non-null  int64 
 4   hs_week  32561 non-null  int64 
dtypes: int64(2), object(3)
memory usage: 1.2+ MB
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 32561 entries, 0 to 32560
Data columns (total 7 columns):
 #   Column   Non-Null Count  Dtype   
---  ------   --------------  -----   
 0   income   32561 non-null  category
 1   sex      32561 non-null  category
 2   race     32561 non-null  category
 3   age      32561 non-null  int64   
 4   hs_week  32561 non-null  int64   
 5   y        32561 non-null  int64   
 6   gid      32561 non-null  int64   
dtypes: category(3), int64(4)
memory usage: 1.1 MB


In [None]:
# Benchmark model specifications as {"brms": ..., "bambi": ...} pairs. Where the two
# formula syntaxes differ (splines, GP terms), the bambi term is meant to
# approximate the brms one. The brms string doubles as the lookup key for
# timing rows, so keep it stable between runs.
formulas = [
    {"brms": brms_spec, "bambi": bambi_spec}
    for brms_spec, bambi_spec in (
        # Baseline GLM
        (
            "y ~ sex + race + scale(age) + scale(hs_week)",
            "y ~ sex + race + scale(age) + scale(hs_week)",
        ),
        # Random intercept for factor
        (
            "y ~ sex + scale(age) + scale(hs_week) + (1 | race)",
            "y ~ sex + scale(age) + scale(hs_week) + (1 | race)",
        ),
        # Random intercept + random slope for age (gid = cluster/group id)
        (
            "y ~ sex + race + scale(age) + scale(hs_week) + (1 + scale(age) | gid)",
            "y ~ sex + race + scale(age) + scale(hs_week) + (1 + scale(age) | gid)",
        ),
        # Splines
        (
            "y ~ sex + race +  s(age, k=8, bs='bs')",
            "y ~ sex + race + bs(age, 7, intercept=True)",
        ),
        # 1D Gaussian process
        (
            "y ~ sex + race   + gp(hs_week, k=16, c=1.5, cov='exp_quad', scale=TRUE, iso=TRUE)",
            "y ~ sex + race + hsgp(hs_week, m=16, c=1.5, cov='ExpQuad',  scale=True, iso=True)",
        ),
        # Not benchmarked yet: polynomial nonlinearity (no splines/GP)
        # "y ~ sex + race + poly(age, 3) + scale(hs_week)",
    )
]

In [95]:
# Sampler backends benchmarked for each library.
bambi_backends = ["blackjax", "numpyro", "nutpie"]  # passed as bambi's inference_method
brms_backends = ["cmdstanr"]                        # passed as brms.fit(backend=...)

In [96]:
import time
# Reuse timings already collected in this kernel session (if any), so re-running
# the benchmark cells skips finished configurations instead of starting over.
timing_rows = globals().get('timing_rows', [])
N = 1  # timed repeats per (library, backend, formula) configuration

In [97]:
def row_exists(library, backend, formula):
    """Return True if ``timing_rows`` already holds a timing for this configuration.

    Rows match on ``library`` and ``backend``. Formulas are compared by their
    ``'brms'`` string when given as a dict (``.get`` avoids a KeyError), or
    as-is when already a plain string. This lets rows whose formula was
    flattened to a string still match a dict formula.
    """
    def _brms_key(spec):
        return spec.get('brms') if isinstance(spec, dict) else spec

    wanted = _brms_key(formula)
    return any(
        row['library'] == library
        and row['backend'] == backend
        and _brms_key(row['formula']) == wanted
        for row in timing_rows
    )

In [98]:
# Time brms fits for every (formula, backend) pair not yet present in timing_rows.
# NOTE(review): unlike the bambi fits, these fits are not seeded. Consider forwarding
# SEED if brmspy passes a `seed` argument through to brms.
for formula in tqdm(formulas):
    used_formula = formula['brms']

    for backend in brms_backends:
        key = f"brms_{backend}_{used_formula}"
        print(key)
        row_base = {
            'library': 'brms',
            'backend': backend,
            'formula': formula,
        }
        # Resume support: skip configurations already timed in this session.
        if row_exists(**row_base):
            continue

        # Collect all N repeats before publishing them. A failure part-way through
        # must not leave a partial set that row_exists() would later treat as done.
        run_rows = []
        try:
            for _ in range(N):
                start = time.perf_counter()
                brms_model = brms.fit(
                    formula=used_formula,
                    data=data,
                    chains=4,
                    cores=4,
                    warmup=1000,
                    iter=2000,
                    backend=backend,
                    silent=1,
                    sample=True
                )
                elapsed = time.perf_counter() - start
                run_rows.append({**row_base, 'seconds': elapsed})
        except Exception as exc:
            # Report and move on to the next configuration. repr() keeps the
            # exception type (e.g. a KeyError otherwise prints only its key).
            print(f"{key} failed: {exc!r}")
        else:
            timing_rows.extend(run_rows)

  0%|          | 0/4 [00:00<?, ?it/s]

brms_cmdstanr_y ~ sex + race + scale(age) + scale(hs_week)
brms_cmdstanr_y ~ sex + scale(age) + scale(hs_week) + (1 | race)
Fitting model with brms (backend: cmdstanr)...


R callback write-console: Model executable is up to date!
  
R callback write-console: Start sampling
  


Running MCMC with 4 parallel chains...

Chain 1 Iteration:    1 / 2000 [  0%]  (Warmup) 
Chain 2 Iteration:    1 / 2000 [  0%]  (Warmup) 
Chain 3 Iteration:    1 / 2000 [  0%]  (Warmup) 
Chain 4 Iteration:    1 / 2000 [  0%]  (Warmup) 
Chain 2 Iteration:  100 / 2000 [  5%]  (Warmup) 
Chain 1 Iteration:  100 / 2000 [  5%]  (Warmup) 
Chain 3 Iteration:  100 / 2000 [  5%]  (Warmup) 
Chain 4 Iteration:  100 / 2000 [  5%]  (Warmup) 
Chain 2 Iteration:  200 / 2000 [ 10%]  (Warmup) 
Chain 1 Iteration:  200 / 2000 [ 10%]  (Warmup) 
Chain 4 Iteration:  200 / 2000 [ 10%]  (Warmup) 
Chain 3 Iteration:  200 / 2000 [ 10%]  (Warmup) 
Chain 2 Iteration:  300 / 2000 [ 15%]  (Warmup) 
Chain 1 Iteration:  300 / 2000 [ 15%]  (Warmup) 
Chain 4 Iteration:  300 / 2000 [ 15%]  (Warmup) 
Chain 2 Iteration:  400 / 2000 [ 20%]  (Warmup) 
Chain 3 Iteration:  300 / 2000 [ 15%]  (Warmup) 
Chain 1 Iteration:  400 / 2000 [ 20%]  (Warmup) 
Chain 4 Iteration:  400 / 2000 [ 20%]  (Warmup) 
Chain 3 Iteration:  400 / 200


  


brms_cmdstanr_y ~ sex + race +  s(age, k=8, bs='bs')
brms_cmdstanr_y ~ sex + race   + gp(hs_week, k=16, c=1.5, cov='exp_quad', scale=TRUE, iso=TRUE)


In [99]:
import re

In [None]:
# Time bambi fits for every (formula, backend) pair not yet present in timing_rows.
# NOTE(review): no `family` is passed, so bambi's default likelihood is used even
# though y is binary. Confirm this matches what the brms fits use.
for formula in tqdm(formulas):
    # Collapse runs of spaces in the bambi formula to single spaces before parsing.
    used_formula = re.sub(r" {2,}", " ", formula['bambi'])

    for backend in bambi_backends:
        key = f"bambi_{backend}_{used_formula}"
        print(key)
        row_base = {
            'library': 'bambi',
            'backend': backend,
            'formula': formula,
        }
        # Resume support: skip configurations already timed in this session.
        if row_exists(**row_base):
            continue

        # Collect all N repeats before publishing them. A failure part-way through
        # must not leave a partial set that row_exists() would later treat as done.
        run_rows = []
        try:
            for _ in range(N):
                with silence():
                    # Timing covers model construction plus sampling.
                    start = time.perf_counter()
                    bmb_model = bmb.Model(used_formula, data)
                    bmb_fitted = bmb_model.fit(
                        tune=1000, draws=1000,
                        random_seed=SEED, progressbar=False,
                        inference_method=backend,
                        chains=4,
                        cores=4,
                    )
                    elapsed = time.perf_counter() - start
                run_rows.append({**row_base, 'seconds': elapsed})
        except Exception as exc:
            # Printed after silence() has restored the streams. repr() keeps the
            # exception type (the bare message ' m=16' alone is not diagnosable).
            print(f"{key} failed: {exc!r}")
        else:
            timing_rows.extend(run_rows)

  0%|          | 0/4 [00:00<?, ?it/s]

bambi_blackjax_y ~ sex + race + scale(age) + scale(hs_week)
skipping {'library': 'bambi', 'backend': 'blackjax', 'formula': {'brms': 'y ~ sex + race + scale(age) + scale(hs_week)', 'bambi': 'y ~ sex + race + scale(age) + scale(hs_week)'}}
bambi_numpyro_y ~ sex + race + scale(age) + scale(hs_week)
skipping {'library': 'bambi', 'backend': 'numpyro', 'formula': {'brms': 'y ~ sex + race + scale(age) + scale(hs_week)', 'bambi': 'y ~ sex + race + scale(age) + scale(hs_week)'}}
bambi_nutpie_y ~ sex + race + scale(age) + scale(hs_week)
skipping {'library': 'bambi', 'backend': 'nutpie', 'formula': {'brms': 'y ~ sex + race + scale(age) + scale(hs_week)', 'bambi': 'y ~ sex + race + scale(age) + scale(hs_week)'}}
bambi_blackjax_y ~ sex + scale(age) + scale(hs_week) + (1 | race)
skipping {'library': 'bambi', 'backend': 'blackjax', 'formula': {'brms': 'y ~ sex + scale(age) + scale(hs_week) + (1 | race)', 'bambi': 'y ~ sex + scale(age) + scale(hs_week) + (1 | race)'}}
bambi_numpyro_y ~ sex + scale(ag

There were 58 divergences after tuning. Increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details


bambi_numpyro_y ~ sex + race + hsgp(hs_week, m=16, c=1.5, cov='ExpQuad', scale=True, iso=True)


There were 54 divergences after tuning. Increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details


bambi_nutpie_y ~ sex + race + hsgp(hs_week, m=16, c=1.5, cov='ExpQuad', scale=True, iso=True)
' m=16'


In [116]:
def parse_row(r):
    """Return a flat copy of a timing row, suitable for building a DataFrame.

    A dict-valued ``'formula'`` is replaced by its ``'brms'`` string in the copy.
    The input row is left untouched. Previously it was mutated in place, which
    silently rewrote the entries of ``timing_rows`` as a side effect of building
    the summary table.
    """
    row = dict(r)
    if isinstance(row['formula'], dict):
        row['formula'] = row['formula']['brms']
    return row


In [117]:
# Mean wall-clock seconds per (library, backend, formula), averaged over the N repeats.
df_timing = (
    pd.DataFrame([parse_row(row) for row in timing_rows])
    .groupby(['library', 'backend', 'formula'])
    .agg(seconds=('seconds', 'mean'))
)

In [118]:
# Display the aggregated timings as this cell's output.
df_timing

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,seconds
library,backend,formula,Unnamed: 3_level_1
bambi,blackjax,"y ~ sex + race + gp(hs_week, k=16, c=1.5, cov='exp_quad', scale=TRUE, iso=TRUE)",734.470137
bambi,blackjax,"y ~ sex + race + s(age, k=8, bs='bs')",428.794206
bambi,blackjax,y ~ sex + race + scale(age) + scale(hs_week),9.241696
bambi,blackjax,y ~ sex + scale(age) + scale(hs_week) + (1 | race),71.597975
bambi,numpyro,"y ~ sex + race + gp(hs_week, k=16, c=1.5, cov='exp_quad', scale=TRUE, iso=TRUE)",777.371472
bambi,numpyro,"y ~ sex + race + s(age, k=8, bs='bs')",439.739242
bambi,numpyro,y ~ sex + race + scale(age) + scale(hs_week),9.523644
bambi,numpyro,y ~ sex + scale(age) + scale(hs_week) + (1 | race),67.536308
bambi,nutpie,y ~ sex + race + scale(age) + scale(hs_week),7.741469
bambi,nutpie,y ~ sex + scale(age) + scale(hs_week) + (1 | race),28.55236


In [119]:
# Persist the aggregated timings (overwriting any previous file), but only once
# more than one configuration has been timed.
# NOTE(review): confirm the "> 1" threshold is intended. A single timed
# configuration is never written.
if df_timing.shape[0] > 1:
    df_timing.to_csv("debug_performance_timings.csv")

# Visualisation

In [120]:
# Read the persisted timings back from disk rather than reusing df_timing directly.
df_vis = pd.read_csv(filepath_or_buffer="debug_performance_timings.csv")

In [121]:
# One row per formula, one (library, backend) column per sampler. Cells hold mean seconds.
df_pivot = df_vis.pivot_table(
    index="formula",
    columns=["library", "backend"],
    values="seconds",
)

# Fastest sampler per formula (missing runs are skipped), shown as "library/backend"
# instead of the raw column tuple.
winner = df_pivot.idxmin(axis=1).map("/".join).rename("winner")

# Prepend the winner as the outer index level so it is shown beside each formula.
df_pivot = df_pivot.set_index(winner, append=True).reorder_levels(["winner", "formula"])

# Colour each row from fastest (green) to slowest (red). Missing runs are left blank.
df_pivot.style.background_gradient(
    cmap="RdYlGn_r",
    axis=1
).format("{:.2f}", na_rep="")

Unnamed: 0_level_0,library,bambi,bambi,bambi,brms
Unnamed: 0_level_1,backend,blackjax,numpyro,nutpie,cmdstanr
winner,formula,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2
"('brms', 'cmdstanr')","y ~ sex + race + gp(hs_week, k=16, c=1.5, cov='exp_quad', scale=TRUE, iso=TRUE)",734.47,777.37,,494.41
"('bambi', 'blackjax')","y ~ sex + race + s(age, k=8, bs='bs')",428.79,439.74,,473.35
"('bambi', 'nutpie')",y ~ sex + race + scale(age) + scale(hs_week),9.24,9.52,7.74,20.07
"('bambi', 'nutpie')",y ~ sex + scale(age) + scale(hs_week) + (1 | race),71.6,67.54,28.55,216.24
