# Imports and Defaults

In [1]:
import os

import numpy as np
import polars as pl
import seaborn as sns

In [2]:
sns.set_theme(style="darkgrid")
posterior_name = "funnel10"
data_dir = os.path.join("..", "..", "data", posterior_name)

# Load Data

In [3]:
summary = pl.read_parquet(os.path.join(data_dir, "summary.parquet"))
summary.describe()

statistic,chain,damping,max_proposals,metric,probabilistic,reduction_factor,sampler_type,step_count_method,step_size,step_size_factor,hparams,num_nans,max_se1,max_se2,p0_se1,p1_se1,p2_se1,p3_se1,p4_se1,p5_se1,p6_se1,p7_se1,p8_se1,p9_se1,p0_se2,p1_se2,p2_se2,p3_se2,p4_se2,p5_se2,p6_se2,p7_se2,p8_se2,p9_se2,step_count,step_count_factor
str,f64,f64,f64,f64,f64,f64,str,str,f64,f64,str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
"""count""",4405.0,4400.0,4400.0,4400.0,4400.0,4400.0,"""4405""","""4000""",0.0,4400.0,"""4405""",4405.0,4405.0,4405.0,4405.0,4405.0,4405.0,4405.0,4405.0,4405.0,4405.0,4405.0,4405.0,4405.0,4405.0,4405.0,4405.0,4405.0,4405.0,4405.0,4405.0,4405.0,4405.0,4405.0,0.0,400.0
"""null_count""",0.0,5.0,5.0,5.0,5.0,5.0,"""0""","""405""",4405.0,5.0,"""0""",0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,4405.0,4005.0
"""mean""",2.0,0.265455,3.5,1.0,0.0,7.5,,,,3.3,,37592000.0,7.635495,500592.927888,0.8186,1.212679,4.273447,3.522798,0.995649,0.999372,1.134992,1.894176,2.183241,5.270488,15.672466,50478.534945,180589.853603,244501.754969,47085.170651,55906.439687,36900.317117,89425.727235,98213.218906,427794.841883,,0.9
"""std""",1.414374,0.290811,1.118161,0.0,,5.362512,,,,2.821668,,227880000.0,238.104835,22733000.0,3.786864,23.705812,202.916272,132.475016,14.574173,27.918949,35.526589,64.040809,62.398871,233.043133,140.80181,1596000.0,7922400.0,10718000.0,1544600.0,2182600.0,1174300.0,4693200.0,3676700.0,21759000.0,,0.0
"""min""",0.0,0.01,2.0,1.0,0.0,2.0,"""drghmc""","""const_step_cou…",,0.5,"""damping=0.01__…",0.0,0.008626,657.352949,1.0737e-08,3.1117e-10,3.6385e-10,1.5907e-08,2.1731e-09,2.4738e-09,3.3318e-09,1.3619e-13,7.366e-11,6.381e-10,4.678e-07,0.05311,0.001128,0.015753,0.000114,0.019178,0.001389,0.009555,0.00167,0.164942,,0.9
"""25%""",1.0,0.05,3.0,1.0,,4.0,,,,1.0,,0.0,0.118217,6152.419332,0.015954,0.003052,0.004537,0.004778,0.003571,0.003397,0.003836,0.003248,0.003786,0.005072,0.60869,2215.552114,3163.650654,2179.371173,2424.042337,2697.949156,1678.59147,1879.494617,3545.852294,5782.851454,,0.9
"""50%""",2.0,0.1,4.0,1.0,,8.0,,,,2.0,,0.0,0.267004,7960.521565,0.074054,0.019896,0.02597,0.031057,0.021644,0.019958,0.022435,0.017359,0.021224,0.029593,2.241711,3463.102847,4733.005016,3444.387183,3719.115977,4086.503437,2769.913836,3007.977455,5119.953783,7775.538563,,0.9
"""75%""",3.0,0.5,4.0,1.0,,8.0,,,,5.0,,0.0,0.820601,9338.017748,0.242269,0.093202,0.095382,0.142627,0.083586,0.078885,0.090421,0.08403,0.089133,0.108441,6.272388,4371.021976,5824.459017,4410.553605,4695.904149,5156.207261,3678.760114,3892.501688,6271.703371,9132.736529,,0.9
"""max""",4.0,1.0,5.0,1.0,0.0,16.0,"""nuts""","""const_traj_len…",,8.0,"""sampler_type=n…",5814300000.0,14904.869886,1406900000.0,90.782996,1055.416498,13024.938125,7294.765753,498.794175,1727.926927,2228.841747,3953.501377,2884.438543,14904.869886,7230.685296,86146000.0,392800000.0,538640000.0,87763000.0,117420000.0,54094000.0,308730000.0,195670000.0,1406900000.0,,0.9


# Squared Error Across Chains

WLOG consider a model parameter with true value $\theta$ and an estimate $\hat{\theta} = \frac{1}{n} \sum_{i=1}^n \theta_i$, with $\theta_i$ s drawn from our sampler. The squared error of our estimate is then:

$$ \text{SE} = (\hat{\theta} - \theta)^2 $$

We compute the squared error across different gradient evaluations to see how the squared error changes as we generate more draws:

$$ \text{SE}_i = (\hat{\theta}_i - \theta)^2 $$

where $\hat{\theta}_i = \frac{1}{i} \sum_{j=1}^i \theta_j$

To compute squared error *across* chains we cannot simply average the squared error because it is not a linear calculation. Instead we concatonate chains into one long series of draws and then compute the squared error.

For $C$ chains, construct $\hat{\theta}_i$ as follows:

$$ \hat{\theta}_i = \frac{1}{i} \sum_{j=1}^i \sum_{k=1}^C \theta_{jk} $$

where $\theta_{jk}$ is the $j$ th draw from the $k$ th chain.

In [43]:
def parse_summary_row(summary_row):
    important_keys = set([
        "hparams",
        "chain",
        "damping",
        "max_proposals", 
        "metric",
        "probabilistic",
        "reduction_factor",
        "sampler_type",
        "step_count_method",
        "step_size",
        "step_size_factor",
        "num_nans",
        "step_count",
        "step_count_factor",
    ])
    
    hparam_dict = dict()
    for key in important_keys:
        hparam_dict[key] = summary_row[key]
    return hparam_dict

def get_dir_name(row_dict):
    if row_dict["sampler_type"] == "nuts":
        return f"burn_in=0__chain={row_dict['chain']}__gradient_budget=1000000__metric={row_dict['metric']}__sampler_type={row_dict['sampler_type']}__seed=1234"
    
    elif row_dict["sampler_type"] == "drghmc":
        return f"burn_in=0__chain={row_dict['chain']}__damping={row_dict['damping']}__gradient_budget=1000000__max_proposals={row_dict['max_proposals']}__metric={row_dict['metric']}__probabilistic={row_dict['probabilistic']}__reduction_factor={row_dict['reduction_factor']}__sampler_type={row_dict['sampler_type']}__seed=1234__step_count_method={row_dict['step_count_method']}__step_size={row_dict['step_size']}__step_size_factor={row_dict['step_size_factor']}"
    
    elif row_dict["sampler_type"] == "drhmc":
        return f"burn_in=0__chain={row_dict['chain']}__damping={int(row_dict['damping'])}__gradient_budget=1000000__max_proposals={row_dict['max_proposals']}__metric={row_dict['metric']}__probabilistic={row_dict['probabilistic']}__reduction_factor={row_dict['reduction_factor']}__sampler_type={row_dict['sampler_type']}__seed=1234__step_count={row_dict['step_count']}__step_count_factor={row_dict['step_count_factor']}__step_size={row_dict['step_size']}__step_size_factor={int(row_dict['step_size_factor'])}"
    
    else:
        raise ValueError(f"Unknown sampler type: {row_dict['sampler_type']}")
    
def parse_history(hparams_dict, history_keys):
    path = os.path.join(data_dir, get_dir_name(hparams_dict), "history.npz")
    history = np.load(path)
    
    rows = []
    if "draws" in history_keys:
        hparams_cp = hparams_dict.copy()
        for draw in history["draws"]:
            for idx, param in enumerate(draw):
                hparams_cp[f"p{idx}"] = param
                rows.append(hparams_cp)
        history_keys.remove("draws")
        
    for k in history_keys:
        hparams_cp = hparams_dict.copy()
        for v in history[k]:
            hparams_cp[k] = v
            rows.append(hparams_cp)
        
    return rows

def parse_summary(summary, history_keys=None, metric_keys=None):
    rows = []
    for summary_row in summary.rows(named=True):
        hparams_dict = parse_summary_row(summary_row)
        
        if history_keys:
            history_dict = parse_history(hparams_dict, history_keys)
            rows.extend(history_dict)
            
        # if metric_keys:
            # metric_dict = parse_metrics(hparams, metrics_keys)
            # rows.extend(metric_dict)
        break
    
    return pl.DataFrame(rows)

In [44]:
filtered = summary.filter(pl.col("sampler_type") == "drhmc")
df = parse_summary(filtered, history_keys=["draws", "acceptance"], metric_keys=None)
df.describe()

statistic,probabilistic,hparams,damping,reduction_factor,metric,max_proposals,step_count_method,step_size,step_count_factor,step_size_factor,num_nans,step_count,chain,sampler_type,p0,p1,p2,p3,p4,p5,p6,p7,p8,p9
str,f64,str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
"""count""",36916.0,"""36916""",36916.0,36916.0,36916.0,36916.0,0.0,0.0,36916.0,36916.0,36916.0,0.0,36916.0,"""36916""",33560.0,33560.0,33560.0,33560.0,33560.0,33560.0,33560.0,33560.0,33560.0,33560.0
"""null_count""",0.0,"""0""",0.0,0.0,0.0,0.0,36916.0,36916.0,0.0,0.0,0.0,36916.0,0.0,"""0""",3356.0,3356.0,3356.0,3356.0,3356.0,3356.0,3356.0,3356.0,3356.0,3356.0
"""mean""",0.0,,1.0,8.0,1.0,2.0,,,0.9,5.0,162161.0,,0.0,,-0.136604,-0.59991,-0.935525,1.900947,-0.194565,0.41155,-0.264543,-0.041403,-0.973988,-0.125921
"""std""",,,0.0,0.0,0.0,0.0,,,2.2205e-16,0.0,0.0,,0.0,,0.0,1.1102e-16,1.1102e-16,2.2205e-16,2.7756e-17,0.0,5.5512e-17,6.939e-18,1.1102e-16,0.0
"""min""",0.0,"""damping=1__max…",1.0,8.0,1.0,2.0,,,0.9,5.0,162161.0,,0.0,"""drhmc""",-0.136604,-0.59991,-0.935525,1.900947,-0.194565,0.41155,-0.264543,-0.041403,-0.973988,-0.125921
"""25%""",,,1.0,8.0,1.0,2.0,,,0.9,5.0,162161.0,,0.0,,-0.136604,-0.59991,-0.935525,1.900947,-0.194565,0.41155,-0.264543,-0.041403,-0.973988,-0.125921
"""50%""",,,1.0,8.0,1.0,2.0,,,0.9,5.0,162161.0,,0.0,,-0.136604,-0.59991,-0.935525,1.900947,-0.194565,0.41155,-0.264543,-0.041403,-0.973988,-0.125921
"""75%""",,,1.0,8.0,1.0,2.0,,,0.9,5.0,162161.0,,0.0,,-0.136604,-0.59991,-0.935525,1.900947,-0.194565,0.41155,-0.264543,-0.041403,-0.973988,-0.125921
"""max""",0.0,"""damping=1__max…",1.0,8.0,1.0,2.0,,,0.9,5.0,162161.0,,0.0,"""drhmc""",-0.136604,-0.59991,-0.935525,1.900947,-0.194565,0.41155,-0.264543,-0.041403,-0.973988,-0.125921
