## MNLE inference with varying numbers of trials

This notebook presents a short summary of the code needed to train MNLE on the drift-diffusion model (DDM) and then to perform inference with MCMC.

The MNLE code itself, as well as the training and MCMC utilities, can be found in `mnle_utils.py` in this folder.

If you have any questions, please create an issue in the repository.

In [1]:
import pickle
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import sbi
import sbibm
import torch
from sbi import inference
from sbi.analysis import pairplot
from sbi.utils import likelihood_nn

from mnle_utils import BernoulliMN, MNLE, run_mcmc, train_choice_net

# Plotting settings: apply the project's matplotlib style sheet to all figures.
plt.style.use("plotting_settings.mplstyle")

# Colorblind color palette shared by the figures below.
colors = [
    "#377eb8", "#ff7f00", "#4daf4a",
    "#f781bf", "#a65628", "#984ea3",
    "#999999", "#e41a1c", "#dede00",
]

In [4]:
# The sbibm "ddm" task bundles the DDM simulator and its prior; fetch both.
task = sbibm.get_task("ddm")
prior, simulator = task.get_prior_dist(), task.get_simulator()

In [6]:
# Generate training data (a local Julia installation is required to run the
# simulator) or load pre-simulated training data from disk.
julia_available = False
training_data_path = Path("ddm_training_data.p")

if julia_available:
    N = 100000
    theta = prior.sample((N,))
    x = simulator(theta)
else:
    if not training_data_path.is_file():
        # Fail early with an actionable message instead of a bare open() error.
        raise FileNotFoundError(
            f"Training data file '{training_data_path}' not found. Place the "
            "pre-simulated DDM training data next to this notebook, or set "
            "`julia_available = True` to simulate it (requires Julia)."
        )
    # NOTE: unpickling can execute arbitrary code; only load files you trust.
    with training_data_path.open("rb") as fh:
        # Assumes theta and x are the first two values of the pickled dict
        # (insertion order) -- TODO confirm against the file's keys.
        theta, x, *_ = pickle.load(fh).values()

assert theta.shape[0] == x.shape[0], "theta and x must have the same number of rows"

# The DDM simulator returns choices encoded as the sign of the reaction times; decode:
rts = abs(x)
choices = torch.ones_like(x)
choices[x < 0] = 0
# Concatenate theta and choices for conditional flow training below.
theta_and_choices = torch.cat((theta, choices), dim=1)

FileNotFoundError: [Errno 2] No such file or directory: 'ddm_training_data.p'

## Hyperparameters

In [9]:
# Network size shared by the choice net and the flow's context layers.
num_hidden_layers, num_hidden_units = 3, 10
# Training-loop settings (validation split, early stopping, batch size).
validation_fraction, stop_after_epochs, training_batch_size = 0.1, 20, 100

# Neural spline flow (NSF) settings for the reaction-time density.
use_log_rts = True  # fit the flow to log(rt) rather than rt
num_transforms, num_bins = 2, 5
base_distribution, tails = "gaussian", "linear"
tail_bound, tail_bound_eps = 10, 1e-7


## Train separate likelihood estimators for choices and reaction times

In [None]:
# MNLE is assembled from two separately trained parts:
#   1. a network over the discrete choices, conditioned on theta;
#   2. a conditional flow over reaction times, conditioned on theta and the choice.
# train choice net.
# `vallp` (second return value of train_choice_net) is not used below;
# presumably a validation log-prob -- TODO confirm in mnle_utils.py.
choice_net, vallp = train_choice_net(
    theta,
    choices,
    # set up NN to learn Bernoulli probs over choices.
    net=BernoulliMN(
        n_hidden_layers=num_hidden_layers, n_hidden_units=num_hidden_units
    ),
    validation_fraction=validation_fraction,
    stop_after_epochs=stop_after_epochs,
    batch_size=training_batch_size,
)

## train flow using sbi routines.
# construct the density estimator (a neural spline flow, "nsf").
# NOTE(review): `num_hidden_spline_context_layers` may not be a keyword of the
# standard sbi `likelihood_nn`; confirm the sbi version/fork this notebook needs.
density_estimator_fun = likelihood_nn(
    model="nsf",
    num_transforms=num_transforms,
    hidden_features=num_hidden_units,
    num_bins=num_bins,
    base_distribution=base_distribution,
    tails=tails,
    tail_bound=tail_bound,
    tail_bound_eps=tail_bound_eps,
    num_hidden_spline_context_layers=num_hidden_layers,
)

# set up sbi training object.
inference_method = inference.SNLE(
    density_estimator=density_estimator_fun,
    prior=prior,
)
# append data and train.
# The "parameters" handed to SNLE are theta concatenated with the choice, and the
# "data" are the (optionally log-transformed) reaction times, so the flow models
# rt conditioned on (theta, choice).
# `from_round=0` presumably tags these simulations as round 0 (single-round
# training) -- confirm against the sbi version in use.
inference_method = inference_method.append_simulations(
    theta=theta_and_choices,
    x=torch.log(rts) if use_log_rts else rts,
    from_round=0,
)
rt_flow = inference_method.train(
    training_batch_size=training_batch_size,
    show_train_summary=False,
    stop_after_epochs=stop_after_epochs,
)

# Combine the choice net and the rt flow into a single MNLE object; it must know
# whether the flow was trained on log reaction times.
mnle = MNLE(choice_net, rt_flow, use_log_rts=use_log_rts)

## Visualize learned likelihood estimate

In [None]:
# Compare the MNLE likelihood with the analytical DDM likelihood for one random
# parameter set, over a grid of signed reaction times (the sign encodes the
# choice, as for the training data).
l_lower_bound = 1e-7
test_theta = prior.sample((1,))
# 1000 evenly spaced points in [-5, 5]. np.arange(-5, 5, 1000) would use 1000 as
# the *step* and yield a single point; it is also a numpy array, which
# torch.ones_like below does not accept.
test_data = torch.linspace(-5, 5, 1000)
# Separate rts and choices. New names are used so that the training `rts` are not
# overwritten; the choices are decoded from `test_data`, not from the training `x`.
test_rts = abs(test_data)
cs = torch.ones_like(test_data)
cs[test_data < 0] = 0

# Evaluation only: no gradients are needed.
with torch.no_grad():
    # The analytical likelihood gets the signed data, so that the rt and the
    # choice are evaluated jointly. Passing abs(rt) would drop the choice.
    analytical_likelihoods = torch.tensor([
        float(task.get_log_likelihood(test_theta,
                                      test_x.reshape(-1, 1),
                                      l_lower_bound=l_lower_bound))
        for test_x in test_data
    ])

    mnle_likelihoods = torch.tensor([
        float(mnle.log_prob(r.reshape(-1, 1),
                            c.reshape(-1, 1),
                            test_theta))
        for r, c in zip(test_rts, cs)
    ])

In [None]:
# Overlay the analytical and MNLE likelihoods over the signed reaction-time grid.
# figsize must be a (width, height) tuple; `figsize=15, 5` is a SyntaxError.
fig, ax = plt.subplots(figsize=(15, 5))
ax.plot(test_data, analytical_likelihoods.exp(), label="Analytical L", c=colors[0])
ax.plot(test_data, mnle_likelihoods.exp(), label="MNLE", ls="-", c=colors[2])
ax.set(
    title=r"Analytical vs. MNLE likelihood for one prior sample $\theta$",
    ylabel=r"$L(x | \theta)$",
    # The x axis holds signed reaction times: negative values are choice 0.
    xlabel="reaction time [s] (negative: choice 0)",
)
ax.legend();

## Run inference with MCMC

In [None]:
# Observations from the sbibm DDM benchmark with different numbers of trials
# (the variable names suggest 1, 10 and 100 trials for observations 1, 101, 201).
xo1 = task.get_observation(1)
xo10 = task.get_observation(101)
xo100 = task.get_observation(201)
num_samples = 1000

# MCMC settings passed through to run_mcmc.
mcmc_parameters = {
    "num_chains": 10,
    "warmup_steps": 100,
    "thin": 10,
    "init_strategy": "prior",
}
# MNLE-based potential for observation `xo1`, then sample it with MCMC.
potential_fun = mnle.get_potential_fn(
    data=xo1,
    prior=prior,
    transforms=task.get_transforms(),
)
posterior_samples = run_mcmc(prior, potential_fun, mcmc_parameters, num_samples)

## Visualize posteriors

In [None]:
obs = 1  # change to 101 or 201 depending on observation chosen above.
reference_posterior_samples = task.get_reference_posterior_samples(obs)

# Axis range (and tick positions) for each of the 4 DDM parameters (v, a, w, tau).
param_limits = [[-2, 2], [0.5, 2.0], [.3, .7], [.2, 1.8]]

fig, ax1 = pairplot(
    [reference_posterior_samples, posterior_samples],
    # Reuse the task loaded above instead of re-instantiating it via sbibm.
    points=task.get_true_parameters(obs),
    limits=param_limits,
    ticks=param_limits,
    samples_colors=colors[:2],
    diag="kde",
    upper="contour",
    kde_offdiag=dict(bw_method="scott", bins=50),
    contour_upper=dict(levels=[0.1], percentile=False),
    points_offdiag=dict(marker="+", markersize=10),
    points_colors=["k"],
    labels=[r"$v$", r"$a$", r"$w$", r"$\tau$"],
)

# Put one shared legend next to the top-left panel.
plt.sca(ax1[0, 0])
plt.legend(
    ["Reference", "MNLE", r"Ground truth $\theta$"],
    bbox_to_anchor=(-.1, -2.2),
    loc=2,
)