In [None]:
import os

# Keras reads KERAS_BACKEND when it is first imported, so it must be set before
# anything that may import Keras (tensorflow, keras, bayesflow).
os.environ["KERAS_BACKEND"] = "tensorflow"

import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf

import keras
from keras.callbacks import BackupAndRestore, EarlyStopping, ModelCheckpoint
import bayesflow as bf

import EZ2

In [None]:
def prior():
    """Draw one parameter set for the two-choice diffusion model.

    Every parameter is drawn independently from a uniform prior using NumPy's
    global RNG, so seeding ``np.random`` reproduces the same draw.

    Returns a dict with keys:
      vL, vR     -- drift rates toward the left / right response
      a          -- boundary separation
      z          -- relative starting point
      terL, terR -- non-decision times for left / right responses (seconds)
    """
    # (name, lower, upper); draws are made in this order.
    uniform_bounds = (
        ("vL", 0.1, 6.0),
        ("vR", 0.1, 6.0),
        ("a", 0.3, 4.0),
        ("z", 0.1, 0.9),
        ("terL", 0.1, 1.0),
        ("terR", 0.1, 1.0),
    )
    return {name: np.random.uniform(lo, hi) for name, lo, hi in uniform_bounds}

In [None]:
# Helper functions

def _p_bottom_safe(nu, z, a, s=0.1):
    """Bottom-hit probability with a fallback for near-zero nu."""
    if abs(nu) < 1e-12:
        return (a - z) / a
    s2 = s * s
    num = np.exp(-2*a*nu/s2) - np.exp(-2*z*nu/s2)
    den = np.exp(-2*a*nu/s2) - 1.0
    return float(np.clip(num / den, 0.0, 1.0))

def _safe_rddexit(size, nu, z, a, top_boundary):
    """Call EZ2.rddexit and always return a list (even for size==1)."""
    if size <= 0:
        return []
    arr = EZ2.rddexit(size, nu, z, a, top_boundary=top_boundary)
    if np.isscalar(arr):
        return [float(arr)]
    return [float(x) for x in np.asarray(arr).ravel()]

def _sample_times(size, nu, z, a, s=0.1, rng=None):
    """Safer equivalent of rddexitj using robust fallbacks."""
    if rng is None:
        rng = np.random.default_rng()
    p0 = _p_bottom_safe(nu, z, a, s=s)
    n_bottom = rng.binomial(size, p0)
    n_top = size - n_bottom
    et_bottom = _safe_rddexit(n_bottom, nu, z, a, top_boundary=False)
    et_top    = _safe_rddexit(n_top,    nu, z, a, top_boundary=True)
    return et_bottom, et_top

# Forward model

def _simulate_condition(n, nu, z_start, a, s, choice_bottom, choice_top, stim, rng):
    """Simulate ``n`` trials of one stimulus condition (top boundary = correct response).

    Trials exiting through the top boundary get ``choice_top`` and correct=1;
    bottom exits get ``choice_bottom`` and correct=0. Bottom-boundary trials
    come first in the returned arrays; the caller is responsible for shuffling.

    Returns:
        (dts, choices, correct, stimulus): 1-D arrays of length ``n``
        (float64, int64, int64, int64).

    Raises:
        ValueError: if the sampler returns a different number of times than ``n``.
    """
    et_bottom, et_top = _sample_times(n, nu, z_start, a, s=s, rng=rng)
    et_bottom = np.asarray(et_bottom, dtype=np.float64).ravel()
    et_top = np.asarray(et_top, dtype=np.float64).ravel()
    counts = [et_bottom.size, et_top.size]
    if sum(counts) != n:
        raise ValueError(f"expected {n} exit times, got {sum(counts)}")

    dts = np.concatenate([et_bottom, et_top])
    choices = np.repeat(np.array([choice_bottom, choice_top], dtype=np.int64), counts)
    correct = np.repeat(np.array([0, 1], dtype=np.int64), counts)
    stimulus = np.full(n, stim, dtype=np.int64)
    return dts, choices, correct, stimulus


def forward_model_ez2(
    vL, vR, a, z, terL, terR, n_trials=200, rng=None,
    rt_transform="log1p",  # "log1p" (default), "none", or None (same as "none")
):
    """Simulate a two-condition, two-response diffusion experiment with EZ2's sampler.

    The first ``n_trials // 2`` trials are condition A (Left correct) and the
    rest are condition B (Right correct); the trials are shuffled before
    returning. In each condition the diffusion runs between 0 and ``a`` and its
    top boundary is the correct response. The starting point is mirrored
    between conditions so its distance to the Left boundary is the same in both.

    Parameters:
        vL, vR: drift rates toward the correct response in condition A (Left)
            and condition B (Right).
        a: boundary separation; scaled by 0.1 to match the s = 0.1 convention
            used with EZ2. Must be positive.
        z: relative starting point; ``z * a`` is the distance from the Right
            boundary, clamped strictly inside (0, a).
        terL, terR: non-decision times added to Left (choice 0) and Right
            (choice 1) response times.
        n_trials: total number of trials (>= 0).
        rng: ``np.random.Generator`` used for the boundary split and the shuffle;
            created if None. Exit times come from ``EZ2.rddexit``, which does
            not receive it.
        rt_transform: "log1p" applies ``np.log1p`` to the RTs; "none"/None
            returns raw RTs.

    Returns:
        dict with "rts" (float32), "choices" (int64; 0 = Left, 1 = Right),
        "stimulus" (int64; 0 = Left correct, 1 = Right correct) and
        "correct" (int64; 1 if the correct boundary was hit).

    Raises:
        ValueError: for an unknown ``rt_transform``, negative ``n_trials`` or
            non-positive ``a``.
    """
    if rt_transform is None:
        rt_transform = "none"
    if rt_transform not in ("log1p", "none"):
        raise ValueError(f"rt_transform must be 'log1p' or 'none', got {rt_transform!r}")
    if n_trials < 0:
        raise ValueError(f"n_trials must be non-negative, got {n_trials!r}")
    if not float(a) > 0:
        raise ValueError(f"boundary separation a must be positive, got {a!r}")

    if rng is None:
        rng = np.random.default_rng()

    # Rescale to the s = 0.1 diffusion-coefficient convention used with EZ2.
    scale = 0.1
    vL_ez, vR_ez = float(vL) * scale, float(vR) * scale
    a_ez = float(a) * scale

    # Relative -> absolute starting point, kept strictly inside (0, a_ez).
    eps = 1e-9 * a_ez
    z_abs = min(max(float(z) * a_ez, eps), a_ez - eps)

    nA = n_trials // 2
    nB = n_trials - nA

    # Condition A (Left correct): top -> Left (0), bottom -> Right (1).
    cond_A = _simulate_condition(nA, vL_ez, z_abs, a_ez, scale,
                                 choice_bottom=1, choice_top=0, stim=0, rng=rng)
    # Condition B (Right correct): top -> Right (1), bottom -> Left (0);
    # start mirrored to a - z so its distance to the Left boundary matches A.
    cond_B = _simulate_condition(nB, vR_ez, a_ez - z_abs, a_ez, scale,
                                 choice_bottom=0, choice_top=1, stim=1, rng=rng)

    dts, choices, correct, stimulus = [np.concatenate(pair) for pair in zip(cond_A, cond_B)]

    perm = rng.permutation(n_trials)
    dts, choices, correct, stimulus = dts[perm], choices[perm], correct[perm], stimulus[perm]

    # Add the response-specific non-decision time, then optionally log-transform.
    rts = dts + np.where(choices == 0, terL, terR)
    if rt_transform == "log1p":
        rts = np.log1p(rts)

    return {
        "rts": rts.astype(np.float32),
        "choices": choices,
        "stimulus": stimulus,
        "correct": correct,
    }


In [None]:
# Compose the prior and the forward model into a single BayesFlow simulator.
# NOTE(review): presumably BayesFlow feeds the prior's dict entries (vL, vR, a, z,
# terL, terR) to forward_model_ez2 as keyword arguments by name, so the keys
# returned by prior() must match the forward model's parameter names — confirm.
simulator = bf.make_simulator([prior, forward_model_ez2])

In [None]:
param_names = ['vL', 'vR', 'a', 'z', 'terL', 'terR']
# 'correct' is deliberately left out: the forward model sets it to
# (choices == stimulus), so it carries no extra information.
data_names = ['rts', 'stimulus', 'choices']

# Adapter: simulator output dict -> network inputs.
adapter = bf.adapters.Adapter()
adapter = adapter.keep(param_names + data_names)
adapter = adapter.to_array()
adapter = adapter.convert_dtype("float64", "float32")
# Give each per-trial series a trailing feature axis so they can be stacked.
adapter = adapter.expand_dims("rts", axis=-1)
adapter = adapter.expand_dims("choices", axis=-1)
adapter = adapter.expand_dims("stimulus", axis=-1)
# NOTE(review): choices/stimulus come out of the simulator as int64 and the
# convert_dtype above only targets float64, so concatenating them with float32
# rts likely promotes summary_variables to float64 — confirm this is intended.
adapter = adapter.concatenate(param_names, into="inference_variables")
adapter = adapter.concatenate(data_names, into="summary_variables")

In [None]:
# Training callbacks. These are the standalone `keras.callbacks` classes (the same
# Keras configured through KERAS_BACKEND), not `tensorflow.keras.callbacks`, which
# can resolve to a different Keras version depending on the TensorFlow release.

# Stop once the training loss has not improved by at least 0.01 for 20 epochs
# (monitoring starts after epoch 5), restoring the best weights when it stops early.
es = EarlyStopping(
    monitor="loss",
    mode="min",
    min_delta=0.01,
    patience=20,
    restore_best_weights=True,
    start_from_epoch=5,
)

callbacks = [
    # Lets an interrupted fit resume from its last backup in ./standard_backup.
    BackupAndRestore(backup_dir="./standard_backup"),
    # Keep only the checkpoint with the lowest training loss.
    ModelCheckpoint("standard_model_ckpt.keras", monitor="loss", save_best_only=True),
    es,
]

In [None]:
from bayesflow.networks import CouplingFlow, DeepSet
from bayesflow.workflows import BasicWorkflow

# Summary network: pools the per-trial rows of `summary_variables` into a
# 16-dimensional summary vector.
summary_net = DeepSet(
    summary_dim=16,
    dropout=0.1
)

# Inference network: normalizing flow over `inference_variables` (the 6 parameters).
# NOTE(review): these keyword names (num_coupling_layers, hidden_units,
# coupling_type, batch_norm, tail_bound) resemble the BayesFlow 1.x
# InvertibleNetwork API. Verify them against the installed CouplingFlow signature
# (BayesFlow 2 appears to use depth / transform / subnet_kwargs); unrecognized
# kwargs may be silently ignored or rejected — TODO confirm the intended
# architecture is actually built.
flow = CouplingFlow(
    num_coupling_layers=6,
    hidden_units=[128, 128],
    coupling_type="spline",
    batch_norm=True,
    dropout=0.05,
    tail_bound=6.0
)

# Wire simulator, adapter and both networks together.
# NOTE(review): inference_variables / summary_variables name the adapter outputs;
# presumably redundant when an explicit adapter is passed — confirm.
wf = BasicWorkflow(
    simulator=simulator,
    adapter=adapter,
    summary_network=summary_net,
    inference_network=flow,
    inference_variables=["inference_variables"],
    summary_variables=["summary_variables"]
)


In [None]:
# Train on data simulated on the fly: up to 2000 epochs of 150 batches, each
# batch holding 64 simulated datasets (the callbacks may stop training earlier).
history = wf.fit_online(epochs=2000, num_batches_per_epoch=150, batch_size=64, callbacks=callbacks)

In [None]:
from pathlib import Path

# Save the trained approximator into the current working directory.
out = Path.cwd().joinpath("standard_model.keras")
out.parent.mkdir(exist_ok=True, parents=True)
wf.approximator.save(out)

In [None]:
# Plot the loss curve recorded in `history` by fit_online.
f = bf.diagnostics.plots.loss(history)

In [None]:
# Number of posterior draws per validation dataset.
num_samples = 1000

# 200 freshly simulated datasets; training simulated its data on the fly, so
# these exact datasets were not used for fitting.
val_sims = simulator.sample(200)

post_draws = wf.sample(num_samples=num_samples, conditions=val_sims)

# Parameter recovery: posterior estimates against the parameters that generated each dataset.
f = bf.diagnostics.plots.recovery(estimates=post_draws, targets=val_sims, variable_names=param_names)