# Group-level comparison across confidence SNR

This notebook simulates 20 participants (100 trials each) across 5 levels of type-2 noise
(lower noise = higher confidence SNR) and compares four group-level approaches:

- MLE per-subject mean
- MLE pooled counts (per condition)
- Pooled Bayesian (per condition)
- Hierarchical Bayesian (rm1way, condition offsets)


In [1]:
import os
os.environ.setdefault("PYTENSOR_FLAGS", "compiledir=/tmp/pytensor")
os.environ.setdefault("MPLCONFIGDIR", "/tmp/matplotlib")

import numpy as np
import pandas as pd
import arviz as az
import matplotlib.pyplot as plt

from metadpy.utils import type2_SDT_simuation, ratings2df, responseSimulation, trials2counts
from metadpy.mle import metad
from metadpy.bayesian import hmetad, hmetad_pooled


In [2]:
np.random.seed(123)

n_subjects = 20
n_trials = 500
n_ratings = 4
d = 1.0
c = 0.0

noise_levels = np.array([1.0, 0.8, 0.6, 0.4, 0.2])
snr_levels = 1.0 / noise_levels

num_samples = 400
num_chains = 2
tune = 400
target_accept = 0.9
random_seed = 123

calib_trials = 200000
mc_reps = 200


In [3]:
n_levels = len(noise_levels)
nR_S1_all = np.zeros((n_subjects, n_levels, 2 * n_ratings))
nR_S2_all = np.zeros_like(nR_S1_all)
frames = []

for subj in range(n_subjects):
    for idx, noise in enumerate(noise_levels):
        nR_S1, nR_S2 = type2_SDT_simuation(
            d=d,
            noise=float(noise),
            c=c,
            nRatings=n_ratings,
            nTrials=n_trials,
        )
        nR_S1 = nR_S1.astype(int)
        nR_S2 = nR_S2.astype(int)
        nR_S1_all[subj, idx, :] = nR_S1
        nR_S2_all[subj, idx, :] = nR_S2
        df = ratings2df(nR_S1, nR_S2)
        df["Subject"] = subj
        df["Condition"] = snr_levels[idx]
        frames.append(df)

sim_df = pd.concat(frames, ignore_index=True)
sim_df.head()


In [4]:
mle_df = metad(
    data=sim_df,
    nRatings=n_ratings,
    stimuli="Stimuli",
    accuracy="Accuracy",
    confidence="Confidence",
    subject="Subject",
    within="Condition",
    padding=False,
    verbose=0,
)

mle_group = (
    mle_df.groupby("Condition")["m_ratio"]
    .agg(mean="mean", sem=lambda x: x.std(ddof=1) / np.sqrt(x.count()))
    .reset_index()
)
mle_group["lower"] = mle_group["mean"] - 1.96 * mle_group["sem"]
mle_group["upper"] = mle_group["mean"] + 1.96 * mle_group["sem"]
mle_group["method"] = "MLE (per-subject mean)"


  dx = [(x0[i] + h[i]) - x0[i] for i in range(n)]
  self.f = self.J.dot(self.x)
  return ufunc.reduce(obj, axis, dtype, out, **passkwargs)
  c_ineq + s))
  return diag_elements*vec
  results_df = pd.concat([results_df, results])


In [5]:
pooled_rows = []
for idx, snr in enumerate(snr_levels):
    nR_S1_pool = nR_S1_all[:, idx, :].sum(axis=0)
    nR_S2_pool = nR_S2_all[:, idx, :].sum(axis=0)
    res = metad(
        nR_S1=nR_S1_pool,
        nR_S2=nR_S2_pool,
        nRatings=n_ratings,
        padding=False,
        verbose=0,
    )
    pooled_rows.append(
        {
            "Condition": float(snr),
            "mean": float(res["m_ratio"].iloc[0]),
            "lower": float(res["m_ratio"].iloc[0]),
            "upper": float(res["m_ratio"].iloc[0]),
            "method": "MLE (pooled counts)",
        }
    )

pooled_mle = pd.DataFrame(pooled_rows)


In [6]:
def summarize_posterior(da, hdi_prob=0.94):
    if hasattr(da, "to_array"):
        da = da.to_array()
    mean = da.mean(dim=("chain", "draw")).values
    hdi = az.hdi(da, hdi_prob=hdi_prob)
    if hasattr(hdi, "to_array"):
        hdi = hdi.to_array()
    lower = hdi.sel(hdi="lower").values
    upper = hdi.sel(hdi="higher").values
    return np.squeeze(mean), np.squeeze(lower), np.squeeze(upper)

pooled_models = hmetad_pooled(
    data=sim_df,
    nRatings=n_ratings,
    stimuli="Stimuli",
    accuracy="Accuracy",
    confidence="Confidence",
    within="Condition",
    sample_model=True,
    num_samples=num_samples,
    num_chains=num_chains,
    tune=tune,
    target_accept=target_accept,
    random_seed=random_seed,
    progressbar=False,
    output="model",
)

pooled_bayes_rows = []
for cond, (model, trace) in pooled_models.items():
    mratio_post = trace.posterior["meta_d"] / trace.posterior["d1"]
    mean, lower, upper = summarize_posterior(mratio_post)
    pooled_bayes_rows.append(
        {
            "Condition": float(cond),
            "mean": float(mean),
            "lower": float(lower),
            "upper": float(upper),
            "method": "Bayes (pooled)",
        }
    )

pooled_bayes = pd.DataFrame(pooled_bayes_rows)


Initializing NUTS using jitter+adapt_diag...
ERROR (pytensor.graph.rewriting.basic): SequentialGraphRewriter apply <pytensor.tensor.rewriting.elemwise.FusionOptimizer object at 0x11db08950>
ERROR (pytensor.graph.rewriting.basic): Traceback:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/Users/yifei/anaconda3/envs/metadpy-dev/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 289, in apply
    sub_prof = rewriter.apply(fgraph)
               ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/yifei/anaconda3/envs/metadpy-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 886, in apply
    scalar_inputs, scalar_outputs = self.elemwise_to_scalar(inputs, outputs)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/yifei/anaconda3/envs/metadpy-dev/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py", line 538, in elemwise_to_scalar
    scalar_inputs = [replacement[inp]

In [7]:
hb_model, hb_trace = hmetad(
    data=sim_df,
    nRatings=n_ratings,
    stimuli="Stimuli",
    accuracy="Accuracy",
    confidence="Confidence",
    subject="Subject",
    within="Condition",
    sample_model=True,
    num_samples=num_samples,
    num_chains=num_chains,
    tune=tune,
    target_accept=target_accept,
    random_seed=random_seed,
    progressbar=False,
)

group_meta = hb_trace.posterior["mu_meta_d"] + hb_trace.posterior["condition_offset"]
group_mratio = group_meta / hb_trace.posterior["mu_d1"]
hb_mean, hb_lower, hb_upper = summarize_posterior(group_mratio)

hb_rows = []
for cond, mean, lower, upper in zip(snr_levels, hb_mean, hb_lower, hb_upper):
    hb_rows.append(
        {
            "Condition": float(cond),
            "mean": float(mean),
            "lower": float(lower),
            "upper": float(upper),
            "method": "Bayes (hierarchical)",
        }
    )

hb_summary = pd.DataFrame(hb_rows)


                                                                                
                              Step      Grad      Sampli…                       
  Progre…   Draws   Diverg…   size      evals     Speed     Elapsed   Remaini…  
 ────────────────────────────────────────────────────────────────────────────── 
  ━━━━━━━   1400    400       0.007     3         19.81     0:01:10   0:00:00   
                                                  draws/s                       
  ━━━━━━━   1400    1         0.022     255       9.83      0:02:22   0:00:00   
                                                  draws/s                       
                                                                                


In [8]:
summary = pd.concat([mle_group, pooled_mle, pooled_bayes, hb_summary], ignore_index=True)
summary = summary.sort_values(["method", "Condition"])

fig, ax = plt.subplots(figsize=(8, 4))
methods = ["MLE (per-subject mean)", "MLE (pooled counts)", "Bayes (pooled)", "Bayes (hierarchical)"]
markers = {
    "MLE (per-subject mean)": "o",
    "MLE (pooled counts)": "s",
    "Bayes (pooled)": "D",
    "Bayes (hierarchical)": "^",
}

for method in methods:
    sub = summary[summary["method"] == method].sort_values("Condition")
    yerr = np.vstack([sub["mean"] - sub["lower"], sub["upper"] - sub["mean"]])
    ax.errorbar(
        sub["Condition"],
        sub["mean"],
        yerr=yerr,
        label=method,
        marker=markers[method],
        capsize=3,
    )

ax.set_xlabel("Confidence SNR (1 / noise)")
ax.set_ylabel("Group m_ratio")
ax.set_title("Group-level meta-d'/d' across type-2 noise levels")
ax.legend(frameon=False)
fig.savefig("notebooks/group_level_methods_comparison.png", dpi=150, bbox_inches="tight")
plt.show()




In [9]:
effective_mratio = []
for noise in noise_levels:
    nR_S1_truth, nR_S2_truth = type2_SDT_simuation(
        d=d,
        noise=float(noise),
        c=c,
        nRatings=n_ratings,
        nTrials=calib_trials,
    )
    nR_S1_truth = nR_S1_truth.astype(int)
    nR_S2_truth = nR_S2_truth.astype(int)
    calib_res = metad(
        nR_S1=nR_S1_truth,
        nR_S2=nR_S2_truth,
        nRatings=n_ratings,
        padding=True,
        verbose=0,
    )
    effective_mratio.append(float(calib_res["m_ratio"].iloc[0]))

effective_mratio = np.array(effective_mratio)

truth_rows = []
for snr, mratio in zip(snr_levels, effective_mratio):
    mc_vals = []
    for rep in range(mc_reps):
        sim_truth = responseSimulation(
            d=d,
            metad=mratio * d,
            c=c,
            nRatings=n_ratings,
            nTrials=n_trials,
            nSubjects=n_subjects,
        )
        nR_S1_pool, nR_S2_pool = trials2counts(
            data=sim_truth,
            stimuli="Stimuli",
            accuracy="Accuracy",
            confidence="Confidence",
            nRatings=n_ratings,
            padding=True,
        )
        res = metad(
            nR_S1=nR_S1_pool,
            nR_S2=nR_S2_pool,
            nRatings=n_ratings,
            padding=True,
            verbose=0,
        )
        val = float(res["m_ratio"].iloc[0])
        if np.isfinite(val):
            mc_vals.append(val)
    truth_rows.append(
        {
            "Condition": float(snr),
            "truth_mratio": float(np.mean(mc_vals)),
        }
    )

truth_df = pd.DataFrame(truth_rows)
summary_with_truth = summary.merge(truth_df, on="Condition", how="left")
summary_with_truth["diff"] = summary_with_truth["mean"] - summary_with_truth["truth_mratio"]
summary_with_truth["diff_lower"] = summary_with_truth["lower"] - summary_with_truth["truth_mratio"]
summary_with_truth["diff_upper"] = summary_with_truth["upper"] - summary_with_truth["truth_mratio"]
summary_with_truth


In [10]:
fig, ax = plt.subplots(figsize=(8, 4))
for method in methods:
    sub = summary_with_truth[summary_with_truth["method"] == method].sort_values("Condition")
    yerr = np.vstack([sub["diff"] - sub["diff_lower"], sub["diff_upper"] - sub["diff"]])
    ax.errorbar(
        sub["Condition"],
        sub["diff"],
        yerr=yerr,
        label=method,
        marker=markers[method],
        capsize=3,
    )

ax.axhline(0, color="black", linewidth=1, alpha=0.5)
ax.set_xlabel("Confidence SNR (1 / noise)")
ax.set_ylabel("Estimate - ground truth (m_ratio)")
ax.set_title("Bias relative to ground truth across type-2 noise levels")
ax.legend(frameon=False)
fig.savefig("notebooks/group_level_methods_comparison_bias.png", dpi=150, bbox_inches="tight")
plt.show()




## Notes
- Lower noise (higher SNR) should yield higher estimated m_ratio across methods.
- The hierarchical model typically provides more stable group-level trends
  when per-subject estimates are noisy (low SNR / few trials).
- Sampling settings are intentionally light for speed; increase draws for final runs.


## Low-trial scenario (fewer trials per subject)

This section repeats the comparison with fewer trials per subject to highlight
the advantage of hierarchical Bayesian shrinkage under noisy individual estimates.


In [11]:
n_trials_low = 100
n_subjects_low = n_subjects
padding_low = True

n_levels = len(noise_levels)
nR_S1_low = np.zeros((n_subjects_low, n_levels, 2 * n_ratings))
nR_S2_low = np.zeros_like(nR_S1_low)
frames_low = []

for subj in range(n_subjects_low):
    for idx, noise in enumerate(noise_levels):
        nR_S1, nR_S2 = type2_SDT_simuation(
            d=d,
            noise=float(noise),
            c=c,
            nRatings=n_ratings,
            nTrials=n_trials_low,
        )
        nR_S1 = nR_S1.astype(int)
        nR_S2 = nR_S2.astype(int)
        nR_S1_low[subj, idx, :] = nR_S1
        nR_S2_low[subj, idx, :] = nR_S2
        df = ratings2df(nR_S1, nR_S2)
        df["Subject"] = subj
        df["Condition"] = snr_levels[idx]
        frames_low.append(df)

sim_df_low = pd.concat(frames_low, ignore_index=True)
sim_df_low.head()


In [12]:
mle_rows_low = []
subjects_low = pd.unique(sim_df_low["Subject"])
conditions_low = pd.unique(sim_df_low["Condition"])
for sub in subjects_low:
    for cond in conditions_low:
        subset = sim_df_low[(sim_df_low["Subject"] == sub) & (sim_df_low["Condition"] == cond)]
        nR_S1_sub, nR_S2_sub = trials2counts(
            data=subset,
            stimuli="Stimuli",
            accuracy="Accuracy",
            confidence="Confidence",
            nRatings=n_ratings,
            padding=padding_low,
        )
        try:
            res = metad(
                nR_S1=nR_S1_sub,
                nR_S2=nR_S2_sub,
                nRatings=n_ratings,
                padding=False,
                verbose=0,
            )
            res = res.assign(Subject=sub, Condition=cond)
        except Exception:
            res = pd.DataFrame(
                {
                    "dprime": [np.nan],
                    "meta_d": [np.nan],
                    "m_ratio": [np.nan],
                    "m_diff": [np.nan],
                    "Subject": [sub],
                    "Condition": [cond],
                }
            )
        mle_rows_low.append(res)

mle_df_low = pd.concat(mle_rows_low, ignore_index=True)

mle_group_low = (
    mle_df_low.groupby("Condition")["m_ratio"]
    .agg(mean="mean", sem=lambda x: x.std(ddof=1) / np.sqrt(x.count()))
    .reset_index()
)
mle_group_low["lower"] = mle_group_low["mean"] - 1.96 * mle_group_low["sem"]
mle_group_low["upper"] = mle_group_low["mean"] + 1.96 * mle_group_low["sem"]
mle_group_low["method"] = "MLE (per-subject mean)"


In [13]:
pooled_rows_low = []
for idx, snr in enumerate(snr_levels):
    nR_S1_pool = nR_S1_low[:, idx, :].sum(axis=0)
    nR_S2_pool = nR_S2_low[:, idx, :].sum(axis=0)
    try:
        res = metad(
            nR_S1=nR_S1_pool,
            nR_S2=nR_S2_pool,
            nRatings=n_ratings,
            padding=padding_low,
            verbose=0,
        )
        mean_val = float(res["m_ratio"].iloc[0])
    except Exception:
        mean_val = np.nan
    pooled_rows_low.append(
        {
            "Condition": float(snr),
            "mean": mean_val,
            "lower": mean_val,
            "upper": mean_val,
            "method": "MLE (pooled counts)",
        }
    )

pooled_mle_low = pd.DataFrame(pooled_rows_low)


In [14]:
initvals_low = {
    "c1": 0.0,
    "d1": 1.0,
    "meta_d": 1.0,
    "cS1_hn": np.full(n_ratings - 1, 0.5),
    "cS2_hn": np.full(n_ratings - 1, 0.5),
}

try:
    pooled_models_low = hmetad_pooled(
        data=sim_df_low,
        nRatings=n_ratings,
        stimuli="Stimuli",
        accuracy="Accuracy",
        confidence="Confidence",
        within="Condition",
        padding=padding_low,
        sample_model=True,
        num_samples=num_samples,
        num_chains=1,
        tune=tune,
        target_accept=target_accept,
        random_seed=random_seed,
        progressbar=False,
        output="model",
        init="adapt_diag",
        initvals=initvals_low,
        jitter_max_retries=0,
        cores=1,
    )

    pooled_bayes_rows_low = []
    for cond, (model, trace) in pooled_models_low.items():
        mratio_post = trace.posterior["meta_d"] / trace.posterior["d1"]
        mean, lower, upper = summarize_posterior(mratio_post)
        pooled_bayes_rows_low.append(
            {
                "Condition": float(cond),
                "mean": float(mean),
                "lower": float(lower),
                "upper": float(upper),
                "method": "Bayes (pooled)",
            }
        )

    pooled_bayes_low = pd.DataFrame(pooled_bayes_rows_low)
except Exception:
    pooled_bayes_low = pd.DataFrame(
        {
            "Condition": snr_levels.astype(float),
            "mean": np.nan,
            "lower": np.nan,
            "upper": np.nan,
            "method": "Bayes (pooled)",
        }
    )


In [15]:
hb_model_low, hb_trace_low = hmetad(
    data=sim_df_low,
    nRatings=n_ratings,
    stimuli="Stimuli",
    accuracy="Accuracy",
    confidence="Confidence",
    subject="Subject",
    within="Condition",
    padding=padding_low,
    sample_model=True,
    num_samples=num_samples,
    num_chains=num_chains,
    tune=tune,
    target_accept=target_accept,
    random_seed=random_seed,
    progressbar=False,
)

group_meta_low = hb_trace_low.posterior["mu_meta_d"] + hb_trace_low.posterior["condition_offset"]
group_mratio_low = group_meta_low / hb_trace_low.posterior["mu_d1"]
hb_mean_low, hb_lower_low, hb_upper_low = summarize_posterior(group_mratio_low)

hb_rows_low = []
for cond, mean, lower, upper in zip(snr_levels, hb_mean_low, hb_lower_low, hb_upper_low):
    hb_rows_low.append(
        {
            "Condition": float(cond),
            "mean": float(mean),
            "lower": float(lower),
            "upper": float(upper),
            "method": "Bayes (hierarchical)",
        }
    )

hb_summary_low = pd.DataFrame(hb_rows_low)


SamplingError: Initial evaluation of model at starting point failed!
Starting values:
{'mu_d1': array(-0.12020242), 'sigma_d1_log__': array(0.10291319), 'd1': array([[-0.42676494, -0.47243314, -0.2790231 , -0.10287953, -0.4502611 ],
       [-0.18607105, -0.77220957, -0.11944165, -0.17509695,  0.83033442],
       [-0.12229742, -0.51434947, -1.03117288,  0.72560764,  0.85613197],
       [-0.16940658, -0.0641974 , -0.70056302,  0.39876904, -0.10460456],
       [-0.69952706,  0.74035956,  0.18874597, -0.34404821,  0.62120573],
       [ 0.11523019,  0.54782616, -0.54868389, -0.19100467, -0.19460093],
       [-1.0952814 ,  0.35572742, -0.25001411,  0.27746568, -0.96097275],
       [-0.89380012, -0.29982289, -0.3535083 , -0.07675784, -0.96379132],
       [ 0.86903144,  0.10349025,  0.35848654, -0.8487804 ,  0.00193005],
       [-0.39389313, -0.62123198,  0.19445306, -0.32966312, -0.61925221],
       [ 0.81567693,  0.60100236, -0.54218276, -0.31493611,  0.28382398],
       [-0.07558377, -0.92050364,  0.50345236,  0.54358072, -0.77633891],
       [ 0.73673788,  0.11384026,  0.28001775,  0.00663031,  0.78788095],
       [-1.08227899,  0.21014344,  0.77867508, -0.63547993,  0.54065587],
       [-0.6168622 ,  0.28135834,  0.23331462,  0.39245985, -0.66071539],
       [ 0.656924  ,  0.35777586,  0.84821459, -0.87141374, -0.36569395],
       [-0.38602853, -1.11562025, -1.10055788,  0.78207111,  0.05867159],
       [-0.28162306, -1.04187009,  0.49417851,  0.41422476, -0.06643469],
       [-0.80054636, -0.5790098 ,  0.49092572,  0.74623221,  0.05355142],
       [-1.04599058,  0.1118164 ,  0.71667673, -1.08498777,  0.33192297]]), 'mu_c1': array(-0.08150676), 'sigma_c1_log__': array(-0.78045794), 'c1': array([[-0.79126686,  0.64489063, -0.96093728,  0.01448246,  0.54762627],
       [ 0.1200497 , -0.64624387, -0.66195155, -0.94082867, -0.36909859],
       [-1.0678165 , -0.28488921,  0.78530195, -0.55691758, -0.37222104],
       [-0.74507823, -0.60360864, -0.62430917,  0.63360368, -0.06584033],
       [-0.31640423,  0.80077652, -0.49925883, -0.53842279, -0.09794889],
       [-0.44693086, -0.20511269, -0.47494298, -0.14604402, -0.31994003],
       [ 0.48115201, -0.62484883,  0.59039413,  0.11037339, -0.81053186],
       [ 0.08992065,  0.13259192,  0.13011534, -0.02336079,  0.7707327 ],
       [-0.76011225,  0.76670613, -0.72564688, -0.47108303, -0.57421988],
       [-0.48719373, -0.64342236,  0.31270582, -0.7054362 , -0.69070566],
       [-0.72795611, -0.4671311 , -1.01332345, -0.16221177,  0.79451246],
       [-0.49768913, -0.30897634, -1.01303673,  0.49641932, -0.73084576],
       [-1.02386185, -0.11752069,  0.51485668,  0.68569552, -0.50646106],
       [-0.96911703,  0.56482451, -0.08655838, -0.8754314 ,  0.47507308],
       [-0.26822983,  0.18520215, -0.76301554,  0.33683526,  0.60603962],
       [-1.04001048,  0.52766716,  0.19636119, -0.5297796 , -0.25037574],
       [-0.05176268, -0.56799117, -0.11038595,  0.39948477, -1.01359022],
       [-1.0154746 ,  0.91541571, -0.43956876, -0.28225313,  0.32118957],
       [-0.50371402, -0.08921941, -0.66109224, -0.6543417 , -0.8220932 ],
       [-0.87120076,  0.69128012, -0.54607637,  0.1344137 ,  0.59499951]]), 'mu_meta_d': array(0.33704139), 'sigma_meta_d_log__': array(0.4896409), 'condition_offset_raw': array([-0.64673298,  0.24590522, -0.76221138, -0.55097013,  0.81194894]), 'meta_d': array([[ 0.231567  ,  0.13408993,  0.40570414,  0.07315806,  2.12151076],
       [ 0.74362127,  0.48064513, -1.22313962,  0.06923801,  1.18106203],
       [-0.71646526, -0.12164159,  0.10103707,  0.54650914,  0.41782817],
       [ 0.49333097, -0.02640675, -0.15896202,  0.72362567,  1.75806734],
       [ 0.48069895,  1.00924435, -1.10134591,  0.72821772,  0.91634603],
       [-0.22471118,  0.92205637, -0.52041942, -0.62118771,  1.22776933],
       [ 0.20576651,  1.41399844, -1.22800007, -0.51573014,  0.55531216],
       [-0.46901197,  0.12032874, -0.228032  , -0.64166266,  1.20404838],
       [ 0.03513162,  0.30715551, -0.12563093, -0.90667419,  0.49604409],
       [-0.08610721,  1.13532007, -0.61162705, -0.70004676,  1.52229054],
       [-0.32955034,  1.74235719, -0.27776982, -0.57494164,  1.69345108],
       [ 0.4838903 ,  0.74704914, -1.16160954, -0.31458794,  1.97427899],
       [-0.3093764 ,  1.21802907,  0.16651124, -0.22125455,  1.86471581],
       [-0.45113588,  1.39214339, -0.20294691,  0.72011344,  1.66745551],
       [ 0.18807813,  0.82347082, -0.31185724, -0.05985667,  1.35295662],
       [ 0.82151721,  0.89047128, -0.19559443, -0.23483639,  2.12436103],
       [ 0.68665684,  0.5181407 , -0.07954605, -0.16676602,  0.51635636],
       [ 0.75511555,  1.08964779,  0.42381269,  0.4849316 ,  0.6894075 ],
       [ 0.53404124,  0.34723021,  0.22130342,  0.05723246,  0.71506328],
       [ 0.7950206 ,  1.51546539, -0.05601881, -0.51066984,  1.7677972 ]]), 'cS1_hn_log__': array([[[ 0.51717868, -0.25375894,  0.68947006],
        [-0.85976022, -0.43425619, -0.91153015],
        [ 0.53402856,  0.2125602 , -0.78081866],
        [-0.51078527, -0.37732263, -0.99875683],
        [ 0.87788296, -0.60959447, -0.94086678]],

       [[ 0.19374248, -0.50314207, -0.77944489],
        [ 0.90487951, -0.60958645,  0.81455171],
        [ 0.24862346, -0.55897172,  0.78187192],
        [ 0.35445544, -0.32507202, -0.55579906],
        [ 0.36718197, -0.20063423, -0.4814253 ]],

       [[ 0.76913025, -0.90319348,  0.78172442],
        [-0.75722937, -0.66441525,  0.3837943 ],
        [-0.81104795, -0.1217847 ,  0.40819451],
        [ 0.76847535, -0.09367238, -0.39326715],
        [ 0.7500062 ,  0.22457957, -0.8623508 ]],

       [[-0.07467212,  0.97348677, -0.23045091],
        [-0.05937395,  0.24378578,  0.94078953],
        [-0.59614151, -0.27007707,  0.48209489],
        [-0.78187388,  0.09225218, -0.97300692],
        [-0.98677659,  0.42107417,  0.43693793]],

       [[-0.06155859,  0.86540996,  0.54066376],
        [-0.91876829, -0.98084993,  0.27628879],
        [-0.06254894,  0.79172487,  0.93845661],
        [ 0.21806513, -0.8727984 , -0.92475572],
        [ 0.64621207,  0.44290613, -0.73639298]],

       [[ 0.4138628 ,  0.91975097, -0.66224609],
        [ 0.02502639,  0.68924428, -0.84842468],
        [-0.87607787, -0.81636635,  0.79243038],
        [ 0.51424779, -0.25393616, -0.7094427 ],
        [ 0.70767252, -0.30302118,  0.87613389]],

       [[ 0.41450771,  0.4797352 ,  0.71663198],
        [-0.26481571, -0.54273706, -0.13183637],
        [ 0.80124535, -0.81760492, -0.84460112],
        [-0.23931685,  0.04882757, -0.18305058],
        [-0.8273003 ,  0.72694326,  0.84394579]],

       [[-0.79801842,  0.13695138, -0.54443462],
        [-0.20227226, -0.35005262, -0.21529137],
        [-0.09551046, -0.26167325, -0.14201092],
        [ 0.15809761, -0.18266572, -0.06954113],
        [-0.33655411, -0.95187343,  0.27961597]],

       [[ 0.02140124,  0.00543867,  0.73227281],
        [ 0.80083898, -0.58348465, -0.66466929],
        [ 0.69073869, -0.16817609,  0.66527711],
        [ 0.08005729, -0.89023506,  0.38823977],
        [-0.58846667,  0.2848531 , -0.58623478]],

       [[-0.53646794,  0.66104969,  0.34747469],
        [ 0.16285697,  0.49466082, -0.5529614 ],
        [-0.56323398, -0.96747192, -0.18576511],
        [ 0.3883785 ,  0.13988547, -0.22261981],
        [ 0.37832148, -0.14244202,  0.24830874]],

       [[-0.3182968 ,  0.45327953, -0.98417489],
        [ 0.50829866, -0.09732722,  0.34966443],
        [ 0.42356612, -0.35323322, -0.44475418],
        [ 0.76587544,  0.60953463,  0.04248326],
        [ 0.1967324 , -0.24468959, -0.74996378]],

       [[ 0.75994549,  0.95152953,  0.75915546],
        [-0.1669717 ,  0.27329527,  0.12080269],
        [-0.26757655,  0.47685487, -0.30450108],
        [ 0.26311679, -0.46242166,  0.84805561],
        [ 0.22339395,  0.67171873, -0.83861292]],

       [[-0.9839439 , -0.41405666, -0.99428918],
        [ 0.14996969,  0.46601583,  0.7935609 ],
        [ 0.26417994, -0.97605869,  0.68084008],
        [-0.33288066,  0.32966381,  0.78029228],
        [ 0.06947038,  0.15964048, -0.08045663]],

       [[-0.06684787,  0.92286115,  0.58171034],
        [ 0.88120927,  0.33829956, -0.13201883],
        [ 0.85663368, -0.01343116, -0.53947141],
        [-0.57945604,  0.70359235,  0.77782595],
        [ 0.83750099, -0.50726547, -0.01205957]],

       [[-0.97018979,  0.95686909,  0.73909673],
        [ 0.07647175, -0.19902978, -0.88494483],
        [ 0.22159784,  0.99739737,  0.68098979],
        [-0.08677807,  0.15806507,  0.19376193],
        [-0.96455155,  0.77147212,  0.52007124]],

       [[-0.76058713, -0.08818221,  0.92316504],
        [-0.53078737,  0.79764598,  0.32456365],
        [ 0.82808507, -0.59828962,  0.31601739],
        [ 0.19889662,  0.28084078,  0.90394111],
        [-0.3914067 , -0.92086791, -0.53997555]],

       [[-0.71357576, -0.83643316,  0.87033467],
        [-0.2354967 , -0.85493599,  0.50130529],
        [ 0.071005  ,  0.04276896, -0.31132649],
        [-0.41194575, -0.28861263,  0.64934809],
        [-0.5007465 , -0.30520457,  0.65030942]],

       [[-0.48442099,  0.99730046,  0.45227079],
        [-0.80478681, -0.66060227,  0.9640569 ],
        [-0.40721808, -0.65381967, -0.93274801],
        [ 0.21512027,  0.66603201, -0.39533187],
        [-0.5462721 , -0.4119909 , -0.41870648]],

       [[-0.72987402, -0.39344214,  0.87151135],
        [ 0.69919678,  0.3738528 ,  0.05178205],
        [ 0.55563689,  0.53015513,  0.12372975],
        [-0.00530384, -0.94577299, -0.76118103],
        [-0.6038747 ,  0.45674843,  0.71120282]],

       [[-0.27208289,  0.58054302, -0.31020689],
        [-0.85644525, -0.50350828,  0.55377986],
        [-0.44263926,  0.67086878, -0.97557581],
        [-0.25763163,  0.6590971 , -0.48121907],
        [-0.720391  ,  0.40570763,  0.40628193]]]), 'cS2_hn_log__': array([[[ 0.89202842, -0.34397873, -0.18881438],
        [-0.45398429,  0.61475278,  0.0189285 ],
        [ 0.34075085,  0.60598185,  0.12561371],
        [ 0.32384164, -0.7479509 , -0.96637048],
        [-0.90896579,  0.65718938,  0.74197947]],

       [[-0.1145236 ,  0.94625102,  0.13329269],
        [-0.42184613,  0.26137906, -0.07204344],
        [ 0.22471433, -0.42125159,  0.8983852 ],
        [ 0.62223558, -0.54592935, -0.59882646],
        [ 0.77162214, -0.29948305,  0.51315269]],

       [[ 0.61063872,  0.4438244 ,  0.39168955],
        [ 0.12193155, -0.41673584, -0.53490565],
        [-0.67084898,  0.95288712, -0.22463703],
        [-0.61280437, -0.8475317 ,  0.39004677],
        [-0.7657975 , -0.8744204 ,  0.68358893]],

       [[ 0.93032781, -0.4064417 ,  0.73168369],
        [ 0.12983857,  0.46565042, -0.08381385],
        [-0.28174873,  0.73171207,  0.45322102],
        [ 0.48513649,  0.39350561, -0.32359483],
        [-0.61867336, -0.75400248,  0.10217243]],

       [[-0.87730259, -0.85718866, -0.0355113 ],
        [-0.89604026, -0.41655796, -0.96183526],
        [ 0.85683433, -0.43435358,  0.41898985],
        [-0.69005288,  0.25510429, -0.82675906],
        [-0.5135225 , -0.15638321, -0.01260582]],

       [[ 0.10379008,  0.29225401, -0.5946698 ],
        [ 0.42470436, -0.90087423,  0.68919703],
        [ 0.9514079 , -0.21449985,  0.5043796 ],
        [-0.52383682, -0.88017823, -0.36449815],
        [-0.38175939, -0.58769252,  0.25464269]],

       [[ 0.72464465, -0.3346732 ,  0.7042992 ],
        [-0.7626376 , -0.2663362 , -0.57057686],
        [-0.88739177, -0.44920063, -0.17209387],
        [-0.3568623 ,  0.51185919, -0.09986633],
        [-0.49857045, -0.99152259,  0.52776458]],

       [[-0.44695981,  0.93082801, -0.23174429],
        [ 0.89267593,  0.55462375, -0.5021282 ],
        [ 0.89985542, -0.78474215,  0.37215182],
        [-0.74260432, -0.40922237,  0.8494274 ],
        [ 0.10578043,  0.77756467,  0.96372968]],

       [[-0.8103314 ,  0.32474805,  0.48540617],
        [ 0.42244286,  0.09028272,  0.24662216],
        [-0.61149475,  0.90115938,  0.3236237 ],
        [-0.26432614,  0.90243228, -0.14583202],
        [ 0.80098743, -0.99876903, -0.50075835]],

       [[ 0.66972008,  0.16459021,  0.10196985],
        [ 0.2130512 ,  0.82322914,  0.39533194],
        [ 0.11484619, -0.34627897, -0.5288459 ],
        [ 0.62530286,  0.9923042 , -0.30224995],
        [-0.27511032, -0.75406878, -0.39509431]],

       [[ 0.2186081 ,  0.24826595,  0.72741577],
        [ 0.25358646,  0.44822589,  0.69903533],
        [-0.52637327, -0.02515031,  0.19833494],
        [ 0.12070856, -0.10573947, -0.15438643],
        [-0.55476868, -0.23502974, -0.81446486]],

       [[ 0.92020237, -0.95896954, -0.9492946 ],
        [-0.20961011, -0.66069116, -0.7455996 ],
        [ 0.87983372,  0.7570718 ,  0.08543895],
        [-0.79281014, -0.84790257,  0.86806865],
        [-0.68999495, -0.91383864,  0.34075375]],

       [[ 0.3100884 , -0.53484827, -0.21989245],
        [ 0.15828634, -0.36875642, -0.46994264],
        [-0.09970053,  0.18148514, -0.09079875],
        [-0.12852389,  0.52027152,  0.99538526],
        [-0.46024732, -0.0549999 ,  0.27539407]],

       [[-0.89921361, -0.45144179,  0.7530228 ],
        [-0.81874585,  0.583342  , -0.75137567],
        [ 0.7354828 ,  0.04066658, -0.79475253],
        [-0.5045665 ,  0.82701141,  0.97414424],
        [-0.77025616,  0.2657258 , -0.59543423]],

       [[-0.77353318,  0.44435306, -0.97912743],
        [-0.78189484,  0.46923128, -0.06106454],
        [ 0.15436734,  0.82084372, -0.25465398],
        [ 0.36770289, -0.04372776,  0.23961174],
        [ 0.52793116, -0.75484551, -0.73692791]],

       [[-0.81698734,  0.52556803,  0.43718977],
        [ 0.07040013,  0.08524615,  0.78132547],
        [-0.12348076,  0.91335295, -0.86823869],
        [ 0.48340196,  0.55383399, -0.01666316],
        [ 0.33169329, -0.28367585,  0.87616486]],

       [[-0.10224197,  0.46823122,  0.74484368],
        [ 0.86080202,  0.91950463, -0.71634674],
        [-0.16804781,  0.19367331, -0.46768866],
        [ 0.21995176, -0.93820073, -0.6347216 ],
        [ 0.98548702,  0.88004774,  0.93079373]],

       [[ 0.34812969, -0.04871295,  0.44462982],
        [-0.44187664,  0.62055588, -0.52097393],
        [-0.51343478,  0.71020096,  0.06950624],
        [ 0.81651469,  0.37860778, -0.47223691],
        [ 0.03985006,  0.81150637,  0.00258419]],

       [[ 0.254715  ,  0.94546758, -0.13365157],
        [ 0.22352677, -0.54269785, -0.94629483],
        [-0.96035907, -0.52128461, -0.92941125],
        [-0.79675926, -0.71035268, -0.68606377],
        [ 0.61634162, -0.62307385,  0.0629763 ]],

       [[ 0.89316184,  0.51442253,  0.4276674 ],
        [-0.28765026,  0.99036898, -0.76334477],
        [ 0.33360978, -0.44018035,  0.49516502],
        [-0.9314654 , -0.75872115,  0.56314406],
        [ 0.36627858,  0.95932099, -0.62451662]]])}

Logp initial evaluation results:
{'mu_d1': np.float64(-0.93), 'sigma_d1': np.float64(-0.74), 'd1': np.float64(-115.97), 'mu_c1': np.float64(-0.92), 'sigma_c1': np.float64(-1.11), 'c1': np.float64(-91.65), 'mu_meta_d': np.float64(-0.98), 'sigma_meta_d': np.float64(-1.07), 'condition_offset_raw': np.float64(-5.61), 'meta_d': np.float64(-146.3), 'cS1_hn': np.float64(-344.11), 'cS2_hn': np.float64(-349.53), 'H': np.float64(-533.95), 'FA': np.float64(-948.16), 'CR_counts': np.float64(-inf), 'FA_counts': np.float64(-inf), 'M_counts': np.float64(-inf), 'H_counts': np.float64(-inf)}
You can call `model.debug()` for more details.

In [None]:
summary_low = pd.concat([mle_group_low, pooled_mle_low, pooled_bayes_low, hb_summary_low], ignore_index=True)
summary_low = summary_low.sort_values(["method", "Condition"])

fig, ax = plt.subplots(figsize=(8, 4))
for method in methods:
    sub = summary_low[summary_low["method"] == method].sort_values("Condition")
    yerr = np.vstack([sub["mean"] - sub["lower"], sub["upper"] - sub["mean"]])
    ax.errorbar(
        sub["Condition"],
        sub["mean"],
        yerr=yerr,
        label=method,
        marker=markers[method],
        capsize=3,
    )

ax.set_xlabel("Confidence SNR (1 / noise)")
ax.set_ylabel("Group m_ratio")
ax.set_title("Group-level meta-d'/d' with low trials per subject")
ax.legend(frameon=False)
fig.savefig("notebooks/group_level_methods_comparison_low_trials.png", dpi=150, bbox_inches="tight")
plt.show()


In [None]:
truth_rows_low = []
for snr, mratio in zip(snr_levels, effective_mratio):
    mc_vals = []
    for rep in range(mc_reps):
        sim_truth = responseSimulation(
            d=d,
            metad=mratio * d,
            c=c,
            nRatings=n_ratings,
            nTrials=n_trials_low,
            nSubjects=n_subjects_low,
        )
        nR_S1_pool, nR_S2_pool = trials2counts(
            data=sim_truth,
            stimuli="Stimuli",
            accuracy="Accuracy",
            confidence="Confidence",
            nRatings=n_ratings,
            padding=True,
        )
        try:
            res = metad(
                nR_S1=nR_S1_pool,
                nR_S2=nR_S2_pool,
                nRatings=n_ratings,
                padding=True,
                verbose=0,
            )
            val = float(res["m_ratio"].iloc[0])
        except Exception:
            val = np.nan
        if np.isfinite(val):
            mc_vals.append(val)
    truth_rows_low.append(
        {
            "Condition": float(snr),
            "truth_mratio": float(np.mean(mc_vals)) if mc_vals else np.nan,
        }
    )

truth_df_low = pd.DataFrame(truth_rows_low)
summary_with_truth_low = summary_low.merge(truth_df_low, on="Condition", how="left")
summary_with_truth_low["diff"] = summary_with_truth_low["mean"] - summary_with_truth_low["truth_mratio"]
summary_with_truth_low["diff_lower"] = summary_with_truth_low["lower"] - summary_with_truth_low["truth_mratio"]
summary_with_truth_low["diff_upper"] = summary_with_truth_low["upper"] - summary_with_truth_low["truth_mratio"]
summary_with_truth_low


In [None]:
fig, ax = plt.subplots(figsize=(8, 4))
for method in methods:
    sub = summary_with_truth_low[summary_with_truth_low["method"] == method].sort_values("Condition")
    yerr = np.vstack([sub["diff"] - sub["diff_lower"], sub["diff_upper"] - sub["diff"]])
    ax.errorbar(
        sub["Condition"],
        sub["diff"],
        yerr=yerr,
        label=method,
        marker=markers[method],
        capsize=3,
    )

ax.axhline(0, color="black", linewidth=1, alpha=0.5)
ax.set_xlabel("Confidence SNR (1 / noise)")
ax.set_ylabel("Estimate - ground truth (m_ratio)")
ax.set_title("Bias relative to ground truth (low trials)")
ax.legend(frameon=False)
fig.savefig("notebooks/group_level_methods_comparison_bias_low_trials.png", dpi=150, bbox_inches="tight")
plt.show()
