In [None]:
import os

# Keras selects its backend the first time it is imported, so this must be set
# BEFORE `import keras` (and before bayesflow, which imports keras).
os.environ["KERAS_BACKEND"] = "tensorflow"

import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
import keras
import bayesflow as bf

import EZ2

# Seed Python's `random`, NumPy's global RNG (used by `prior`) and the Keras
# backend so simulations and training are reproducible on Restart & Run All.
SEED = 2024
keras.utils.set_random_seed(SEED)

In [None]:
def prior():
    """Draw one EZ-diffusion parameter set from independent uniform priors.

    Draws come from NumPy's global random state, taken in the order
    vL, vR, a, z, terL, terR.

    Returns:
        dict with keys
          'vL', 'vR'     -- drift rates toward the left / right responses,
          'a'            -- boundary separation,
          'z'            -- relative starting point (a fraction of ``a``),
          'terL', 'terR' -- non-decision times for left / right responses (seconds).
    """
    draw = np.random.uniform
    return {
        # Drift rates toward left and right responses
        'vL': draw(0.1, 6.0),
        'vR': draw(0.1, 6.0),
        # Boundary separation
        'a': draw(0.3, 4.0),
        # Relative starting point (scaled by `a` in the forward model)
        'z': draw(0.1, 0.9),
        # Non-decision times, in seconds
        'terL': draw(0.1, 1.0),
        'terR': draw(0.1, 1.0),
    }

In [None]:
def forward_model_ez(vL, vR, a, z, terL, terR):
    """Map EZ-diffusion parameters to the six summary statistics used for inference.

    Each response side is evaluated with the EZ2 functions ``cmrt``, ``cvrt`` and
    ``pe`` (presumably conditional mean RT, conditional RT variance and error
    probability -- confirm against the EZ2 module), always with ``s=1``.
    The right side uses the absolute starting point ``z * a``; the left side uses
    the mirrored starting point ``a - z * a``. Non-decision time is added to the
    mean RT only. No sampling noise is added in this function.

    Args:
        vL, vR: drift rates for the left / right side.
        a: boundary separation.
        z: relative starting point (fraction of ``a``).
        terL, terR: non-decision times added to the left / right mean RT.

    Returns:
        dict with keys 'mrtL', 'vrtL', 'peL', 'mrtR', 'vrtR', 'peR' (in that order).
    """
    start_right = z * a            # relative -> absolute starting point
    start_left = a - start_right   # mirrored starting point for the left side

    def _side_stats(drift, start, ter):
        # Mean RT (shifted by non-decision time), RT variance, error probability.
        mean_rt = EZ2.cmrt(drift, start, a, s=1) + ter
        var_rt = EZ2.cvrt(drift, start, a, s=1)
        p_err = EZ2.pe(drift, start, a, s=1)
        return mean_rt, var_rt, p_err

    mrtR, vrtR, peR = _side_stats(vR, start_right, terR)
    mrtL, vrtL, peL = _side_stats(vL, start_left, terL)

    return {
        'mrtL': mrtL, 'vrtL': vrtL, 'peL': peL,
        'mrtR': mrtR, 'vrtR': vrtR, 'peR': peR,
    }

In [None]:
# Chain the two functions into a single simulator: `prior` draws a parameter
# dict, and `forward_model_ez` (whose parameter names match the keys returned
# by `prior`) turns it into the six EZ summary statistics.
simulator = bf.make_simulator([prior, forward_model_ez])

In [None]:
# Parameters to infer -- must match the keys returned by `prior`.
par_names = ["vL", "vR", "a", "z", "terL", "terR"]
# Summary statistics to condition on -- keys returned by `forward_model_ez`.
data_names = ["mrtL", "vrtL", "peL", "mrtR", "vrtR", "peR"]

# Keep only these keys, convert them to arrays, cast float64 -> float32, and
# pack parameters / statistics into the two variables the workflow consumes.
adapter = (
    bf.adapters.Adapter()
    .keep([*par_names, *data_names])
    .to_array()
    .convert_dtype("float64", "float32")
    .concatenate(par_names, into="inference_variables")
    .concatenate(data_names, into="summary_variables")
)

In [None]:
from keras import Model, Input
from keras.layers import Layer

# Define a simple identity summary network:
# the simulator already emits summary statistics, so they are passed through
# unchanged. The class is registered as Keras-serializable so that `.keras`
# files containing this network can be deserialized by class name when reloaded.
@keras.saving.register_keras_serializable(package="ez")
class IdentitySummaryNet(Model):
    """Summary network that returns its inputs unchanged (no trainable weights)."""

    def __init__(self, **kwargs):
        """Forward all keyword arguments (e.g. ``name``) to ``keras.Model``."""
        super().__init__(**kwargs)

    def call(self, inputs):
        """Return ``inputs`` as-is."""
        return inputs

    def compute_metrics(self, inputs, stage=None):
        """Return the pass-through outputs under the ``"outputs"`` key.

        ``stage`` is accepted for interface compatibility and ignored.
        NOTE(review): this overrides ``keras.Model.compute_metrics`` with a
        different signature; presumably BayesFlow calls it as
        ``compute_metrics(x, stage=...)`` -- confirm for the installed version.
        """
        return {"outputs": self(inputs)}

    def compute_output_shape(self, input_shape):
        """The output shape equals the input shape."""
        return input_shape

from bayesflow.networks import CouplingFlow
from bayesflow.workflows import BasicWorkflow

# Pass-through summary network (inputs are already summary statistics).
summary_net = IdentitySummaryNet()

# Inference network: spline coupling flow.
# NOTE(review): these keyword names (num_coupling_layers, hidden_units,
# coupling_type, batch_norm, tail_bound) look like the BayesFlow 1.x API;
# confirm the installed CouplingFlow accepts them and does not silently drop them.
flow = CouplingFlow(
    coupling_type="spline",
    num_coupling_layers=6,
    hidden_units=[128, 128],
    batch_norm=True,
    dropout=0.05,
    tail_bound=5.0,
)

# Bundle simulator, adapter and both networks for online training.
wf = BasicWorkflow(
    simulator=simulator,
    adapter=adapter,
    inference_network=flow,
    summary_network=summary_net,
    inference_variables=["inference_variables"],
    summary_variables=["summary_variables"],
    standardize=["summary_variables"],
)

In [None]:
# BayesFlow trains a Keras 3 model, so take the callbacks from `keras` itself.
# `tensorflow.keras` can resolve to legacy Keras 2 (tf_keras), depending on the
# TensorFlow version/config, and its callbacks are not compatible with Keras 3.
from keras.callbacks import BackupAndRestore, EarlyStopping, ModelCheckpoint

# Stop when the training loss has not improved by at least 1e-3 for 20 epochs,
# then restore the best weights seen.
es = EarlyStopping(
    monitor="loss",
    min_delta=0.001,
    patience=20,
    restore_best_weights=True,
)

callbacks = [
    # Back up training state to ./ez_backup so an interrupted run can resume.
    BackupAndRestore(backup_dir="./ez_backup"),
    # Save the model whenever the training loss reaches a new best.
    ModelCheckpoint("ez_model_ckpt.keras", monitor="loss", save_best_only=True),
    es,
]

In [None]:
# Online training: the simulator generates new training data as training runs.
# `callbacks` is the list built in the callbacks cell (backup, checkpoint,
# early stopping); early stopping may end training well before 2000 epochs.
history = wf.fit_online(
    epochs=2000,
    num_batches_per_epoch=200,
    batch_size=64,
    callbacks=callbacks,
)

In [None]:
# Plot the loss curve(s) recorded in `history` by `wf.fit_online`.
f = bf.diagnostics.plots.loss(history)

In [None]:
# Posterior draws per validation dataset.
num_samples = 1000

# 200 freshly simulated datasets (not part of the training stream).
val_sims = simulator.sample(200)

# Approximate posterior: `num_samples` draws for each validation dataset.
post_draws = wf.sample(num_samples=num_samples, conditions=val_sims)

# Parameter recovery: posterior estimates vs. the true generating parameters.
f = bf.diagnostics.plots.recovery(
    variable_names=par_names,
    targets=val_sims,
    estimates=post_draws,
)

In [None]:
from pathlib import Path

# Save the trained approximator as standard_model.keras in the kernel's
# current working directory.
out = Path.cwd().joinpath("standard_model.keras")
out.parent.mkdir(parents=True, exist_ok=True)  # parent is the cwd; harmless guard
wf.approximator.save(out)