# Parameter inference in models of decision making

Welcome to week 3 of module 1! 

As discussed in the lecture, a central challenge in mechanistic modeling is to use Bayesian inference to identify parameters that are consistent with both the model and the data. This is especially challenging for mechanistic models, since their likelihood function is commonly intractable.

After working with a complex mechanistic model for the last two weeks, this week we will turn to a simpler mechanistic model already discussed in the lecture: the drift diffusion model. We chose this model because it generates data very quickly (especially compared to the Wong & Wang model from last week). This will enable you to try out two methods for Bayesian parameter inference yourself: a classical approach called rejection ABC and a more modern one based on density estimation.

## A simple mechanistic model of decision making: the drift diffusion model (DDM)

We will use the DDM as a simple example:

The DDM simulates a perceptual decision-making process in a [two-alternative forced choice (2AFC) experiment](https://en.wikipedia.org/wiki/Two-alternative_forced_choice) with a single scalar variable $x$. For example, the task for the subject could be to report the direction of motion of a cloud of dots, a certain fraction of which move coherently to the left or right.

You can think of $x$ as the accumulated sensory evidence for one or the other choice: it starts at a neutral position, e.g., at zero, and moves up or down according to a mean drift $\mu$ (the first parameter of the model) plus noise (external and internal) scaled by $\sigma$ (the second parameter of the model). A decision for one or the other side is made when $x$ hits a predefined decision boundary.

In a way, $x$ performs a random walk with a drift towards one or the other choice, which is why the model is called the drift ($\mu$) and diffusion ($\sigma$) model (for more details, see the closely related [Ornstein-Uhlenbeck process](https://en.wikipedia.org/wiki/Ornstein%E2%80%93Uhlenbeck_process), which additionally includes a leak term). The name also refers to the drift and diffusion terms of the [Fokker-Planck equation](https://en.wikipedia.org/wiki/Fokker%E2%80%93Planck_equation), which describes how the probability density of $x$ evolves over time.

Formally, the drift diffusion model is defined by the following equation:
$$
\mathrm{d}X = \mu \, \mathrm{d}t + \sigma \, \mathrm{d}W
$$
where $X$ is the decision variable, $\mu$ the drift, and $\sigma$ the diffusion parameter, which scales the amplitude (standard deviation) of the noise coming from the [Wiener process](https://en.wikipedia.org/wiki/Wiener_process) $W$. This equation is a stochastic differential equation (SDE). Although the unbounded process has the simple solution $X(t) = X(0) + \mu t + \sigma W(t)$, what we need is the time at which it first reaches a decision boundary, and here we obtain it by simulation: given a drift $\mu$ and a diffusion parameter $\sigma$, we use an SDE solver to integrate the equation for a given number of time steps and obtain the trace of $X$. The reaction time is the time at which $X$ first crosses the predefined decision boundary, and the decision is given by the sign of $X$ at that time.
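A minimal NumPy sketch of this procedure is shown below. It is not the notebook's implementation (which uses `sdeint.itoEuler`); the function name `simulate_ddm_euler` and its handling of undecided trials are illustrative assumptions. It applies the Euler-Maruyama update $X_{n+1} = X_n + \mu\,\Delta t + \sigma\sqrt{\Delta t}\,\xi_n$ with $\xi_n \sim \mathcal{N}(0, 1)$ and reads off the first time step at which $|X| \geq a$.

```python
import numpy as np

def simulate_ddm_euler(mu, sigma=0.2, a=1.0, T=5.0, dt=0.01, num_trials=100, seed=None):
    """Simulate DDM trials with the Euler-Maruyama scheme (hypothetical helper).

    Returns reaction times and decisions (+1: upper boundary, -1: lower boundary).
    Trials that never reach +-a are assigned the final time T and the sign of X(T),
    loosely mirroring the notebook's convention of enforcing a decision in the last bin.
    """
    rng = np.random.default_rng(seed)
    num_steps = int(round(T / dt))
    # Euler-Maruyama increments: mu * dt + sigma * sqrt(dt) * xi, with xi ~ N(0, 1)
    increments = mu * dt + sigma * np.sqrt(dt) * rng.standard_normal((num_trials, num_steps))
    # traces start at X(0) = 0 and accumulate the increments over time
    x = np.concatenate([np.zeros((num_trials, 1)), np.cumsum(increments, axis=1)], axis=1)
    t = dt * np.arange(num_steps + 1)

    crossed = np.abs(x) >= a                    # True wherever a boundary has been reached
    decided = crossed.any(axis=1)
    # argmax on a boolean array returns the index of the first True entry
    first_idx = np.where(decided, crossed.argmax(axis=1), num_steps)
    rts = t[first_idx]
    decisions = np.sign(x[np.arange(num_trials), first_idx])
    return rts, decisions

rts, decisions = simulate_ddm_euler(mu=0.5, seed=0)
print("fraction of upper-boundary decisions:", np.mean(decisions > 0))
print("mean reaction time:", rts.mean())
```

For this SDE (constant drift, additive noise), the Euler-Maruyama update samples $X$ exactly at the grid points; the remaining discretization error is that boundary crossings between grid points are missed, which slightly delays the detected reaction times.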

- You can find a short introduction to the DDM in the Neuronal Dynamics book (Section 16.4.2): https://neuronaldynamics.epfl.ch/online/Ch16.S4.html (note that the book uses a slightly different notation).
- There are three additional packages you need to install in your conda environment for this exercise: 
  - To simulate the DDM, please install `sdeint`, e.g., with `pip install sdeint`
  - For neural network training, install PyTorch by following [this guide](https://pytorch.org/get-started/) and selecting your platform, `Conda`, and `None` for the `CUDA` option. We also recommend getting familiar with PyTorch syntax, which is very close to NumPy, e.g., by following this tutorial: https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html
  - [tqdm](https://github.com/tqdm/tqdm) for progress bars: `conda install tqdm`

## Exercise 1: Understand the model.

To complete the exercises in this notebook, we provide you with several functions needed to generate data with the DDM. It is important that you understand what happens in each of these functions.

**Overall goal**: Understand the code, how the model generates data and how summary statistics are calculated. 

**Hint**: The functions all depend on each other. It can be helpful to execute each function with inputs generated by the other functions to understand what they do. Have a look at exercises 1 b) and c) for an example.

a) Read the code carefully and write [docstrings](https://www.datacamp.com/community/tutorials/docstrings-python) and additional explanatory inline comments for all functions. You can use any docstring format you like; [here is a style suggestion](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html). The docstrings should document the inputs and the return values of each function.

In [None]:
# Make sure to install the required packages so that all imports are available.
import numpy as np
import sdeint
import matplotlib.pyplot as plt
%matplotlib inline
import torch
torch.set_default_dtype(torch.float32)  # Use 32-bit floats by default; tensors are created on the CPU by default
import tqdm

#### Functions for the DDM

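To run inference with the model, we need two functions: `prior` (the prior over parameters) and `simulator` (the model). The simulator uses two helper functions, `find_rts_and_decisions` and `calculate_histogram_stats`: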

In [None]:
def prior(num_samples=1):
    dist = torch.distributions.Uniform(low=-2.0, high=2.0).expand([1])
    return dist.sample((num_samples,))


def simulator(parameters, num_trials=100, sigma=0.2, T=5, y0=0, a=1.0, return_traces=False, nbins=20, dt=0.01):
    
    if isinstance(parameters, float):
        parameters = torch.tensor([parameters])
    if parameters.ndim == 1:
        parameters = parameters.reshape(1, -1)
        
    # number of parameters to simulate
    num_parameters = parameters.shape[0]
    
    x_traces = []
    stats = []

    # run simulation one by one as this is faster than batching in sdeint.
    for idx_param in range(num_parameters): 
        
        num_simulations = num_trials
        mu = parameters[idx_param].numpy()

        # this repeats each element in mu num_trials times.
        mu_expanded = np.repeat(mu, num_trials)

        tspan = np.linspace(0.0, T, int(T/dt))
        y0_expanded = y0 * np.ones(num_simulations)

        def f(x, t):
            return np.eye(num_simulations).dot(mu_expanded)

        def G(x,t):
            return sigma * np.eye(num_simulations)
    
        x_traces_param = sdeint.itoEuler(f, G, y0_expanded, tspan).T

        rts, decisions = find_rts_and_decisions(x_traces_param, 1, num_trials, tspan, a=a)
        stats_param = calculate_histogram_stats(rts, decisions, T=T, nbins=nbins)
        
        # collect result as torch tensors (to be able to use PyTorch Neural Nets later on)
        x_traces.append(torch.as_tensor(x_traces_param, dtype=torch.float32))
        stats.append(torch.as_tensor(stats_param, dtype=torch.float32))
    
    x_traces = torch.cat(x_traces)
    stats = torch.cat(stats)
    
    if return_traces: 
        return x_traces, stats
    else:
        return stats


def find_rts_and_decisions(x_traces, num_parameters, num_trials, tspan, a=1.0): 
    
    assert x_traces.shape[0] == num_parameters * num_trials
    
    rts = []
    decisions = []
    
    for param_idx in range(num_parameters):
        
        # get traces for this param.
        x_traces_param = x_traces[param_idx * num_trials: (param_idx + 1)*num_trials, ]
            
        # find crossing of decision threshold
        rows, _ = ((abs(x_traces_param) >= a)).nonzero()
        # the unique rows are the ones with decisions. 
        unique_rows = np.unique(rows)
        
        # find undecisive trials
        undecided_idx = [i for i in range(num_trials) if i not in unique_rows]
        # enforce decision in last time bin. 
        x_traces_param[undecided_idx, -1] = a * np.sign(x_traces_param[undecided_idx, -1])
        
        # now search again for decision threshold, all trials are decisive now.
        rows, cols = ((abs(x_traces_param) >= a)).nonzero()
        # get first indices for every trial
        unique_rows, trial_idx = np.unique(rows, return_index=True)
        
        # get first decision time bin idx for every trial
        decision_idx = cols[trial_idx]
        
        # find the decision of every trial by looking at the sign at the decision idx.
        decisions_param = np.sign(x_traces_param[unique_rows, decision_idx])
        rt_param = tspan[decision_idx]

        # collect rts and decisions; rts are multiplied by the decision sign later, in calculate_histogram_stats.
        rts.append(rt_param)
        decisions.append(decisions_param)
        
    return np.array(rts), np.array(decisions)


def calculate_histogram_stats(rts, decisions, T, nbins=20): 
    # mark rts with decision direction sign 
    sign_rts = rts * decisions

    # fixed bin edges: nbins edges define nbins - 1 histogram bins
    bins = np.linspace(-T, T, nbins)

    counts = []
    for param_idx in range(sign_rts.shape[0]):
        # count the signed rts of all trials for this parameter set
        param_count, *_ = np.histogram(sign_rts[param_idx, :], bins=bins)
        counts.append(param_count)
    
    return np.array(counts)

## Exercise 1 continued:

Look at the two cells below. What is happening here? 

b) Describe what is happening in the first cell below.

c) Describe the figure in the second cell below. Change the title and labels of the figure accordingly.

In [None]:
num_trials, T, dt = 50, 5, 0.01
x_traces, _ = simulator(torch.tensor([[-0.3], [0.3]]), num_trials=num_trials, T=T, dt=dt, return_traces=True)

In [None]:
plt.figure(figsize=[18, 5])
tspan = np.linspace(0, T, int(T/dt))

plt.plot(tspan, x_traces[:num_trials, :].numpy().T, c="C0")
plt.plot(tspan, x_traces[num_trials:, :].numpy().T, c="C1")
plt.title("Title", fontsize=20)
plt.ylabel('y label', fontsize=20)
plt.xlabel('x label [unit]', fontsize=20);

d) We defined two functions, `prior` and `simulator`. The prior returns samples of the drift parameter from a uniform distribution on [-2.0, +2.0]. Look at the three lines of code in the cell below. What is happening here? What does the output of the simulator represent?

In [None]:
param = prior(1)
stats = simulator(param)
stats

## Exercise 2: Parameter inference with rejection ABC

In the lecture and the tutorial, you learned about rejection ABC, the root of all simulation-based inference algorithms (more details in the lecture slides or [here](https://en.wikipedia.org/wiki/Approximate_Bayesian_computation#The_ABC_rejection_algorithm)). Here, you will implement this algorithm and use it to do inference over the drift parameter $\mu$ of the DDM, given observed data $s_o$.

**Overall goal**: Obtain 100 samples from the approximate posterior using rejection ABC.

a) Write a function called `rejection_abc` with the function signature given in the cell below. The function takes as arguments a prior, a simulator, the observed data, a simulation budget, a distance function and a quantile. The quantile determines which parameters are accepted: after sorting the distances between simulated and observed data, keep the parameters whose distances fall into the lowest `q`-quantile, i.e., the fraction `q` of simulations closest to the observed data. We provide the function signature, the return statement, and the distance function. Make sure to write docstrings and comments in your function. (**Hint**: you do not need loops to implement this algorithm; you can pass all parameters at once to the simulator and to the distance function.)

b) Test your function with small simulation budgets to save time. Later, for the final inference, we suggest a budget of around 10000 simulations (this can take a couple of minutes to run). 

c) Run inference using the observed data and obtain 100 samples from the posterior (quantile q=0.01 for a budget of 10000 simulations).

d) Make a clear figure with a title, axis labels, readable font sizes, etc., showing the histogram of accepted posterior samples. Plot the ground-truth parameter $\mu_o$ as a vertical line. Are the posterior samples close to the ground truth?
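To see the accept/reject logic in isolation, here is a minimal sketch on a hypothetical toy model rather than the DDM: $\theta \sim \mathcal{U}(-2, 2)$, and the summary statistic is the mean of $n = 20$ draws from $\mathcal{N}(\theta, 1)$. Because the posterior of this toy model is approximately $\mathcal{N}(s_o, 1/n)$ (the prior truncation is negligible for $s_o$ well inside the prior range), it offers a check on the ABC samples. The function `toy_rejection_abc` is illustrative only and uses NumPy instead of the PyTorch interface required above.

```python
import numpy as np

def toy_rejection_abc(s_obs, num_simulations=10_000, q=0.01, n=20, seed=0):
    """Rejection ABC for a hypothetical toy model (not the DDM).

    Prior: theta ~ U(-2, 2). Summary statistic: mean of n draws from N(theta, 1).
    Returns the parameters whose simulated summaries are the q-fraction closest to s_obs.
    """
    rng = np.random.default_rng(seed)
    # 1) sample parameters from the prior (all at once, no loop needed)
    theta = rng.uniform(-2.0, 2.0, size=num_simulations)
    # 2) simulate: the mean of n unit-variance Gaussians is N(theta, 1/n), so sample it directly
    s_sim = theta + rng.standard_normal(num_simulations) / np.sqrt(n)
    # 3) distance between each simulated summary and the observed one
    distances = np.abs(s_sim - s_obs)
    # 4) keep the parameters with the smallest distances (the lowest q-quantile)
    num_accept = int(round(q * num_simulations))
    accepted_idx = np.argsort(distances)[:num_accept]
    return theta[accepted_idx]

samples = toy_rejection_abc(s_obs=0.5)
print(samples.shape, samples.mean(), samples.std())
```

With `q=0.01` and 10,000 simulations, 100 samples are kept; their standard deviation should be close to the analytic value $1/\sqrt{20} \approx 0.22$, slightly inflated by the nonzero tolerance.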

### The observed data: 
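The observed data `s_o` below has 19 entries even though `simulator` uses `nbins=20` by default: `calculate_histogram_stats` passes `np.linspace(-T, T, nbins)` to `np.histogram` as bin *edges*, and `nbins` edges define `nbins - 1` bins. The entries of `s_o` also sum to 100, the default `num_trials` of `simulator`. The short check below uses made-up signed reaction times (hypothetical values, for illustration only).

```python
import numpy as np

T, nbins = 5, 20
bin_edges = np.linspace(-T, T, nbins)   # nbins edges -> nbins - 1 bins, as in calculate_histogram_stats
print(len(bin_edges) - 1)               # 19

# Hypothetical signed reaction times (negative: lower boundary, positive: upper boundary)
signed_rts = np.array([-0.8, 1.2, 1.3, 2.7, 4.9])
counts, _ = np.histogram(signed_rts, bins=bin_edges)
print(counts.shape, counts.sum())       # (19,) 5
```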

In [None]:
s_o = torch.tensor([[ 1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., 13., 22., 20., 14.,
         11.,  7.,  4.,  4.,  4.]])
mu_o = torch.tensor([[0.4588]])

In [None]:
def l2_distance(observation, simulated_data):
    return torch.norm((observation - simulated_data), dim=-1)

def rejection_abc(prior, simulator, observed_data, num_simulations, distance_fun, q):
    
    # Your code for rejection ABC.
    # Hint: Note that prior, simulator and distance_fun all are callable functions, 
    # while observed_data is a tensor, num_simulations is a integer and q a float.
    # Also note that most functions like "sort, "argsort" you might know from numpy are 
    # available in PyTorch as well.
    
    return samples  # return accepted parameters

In [None]:
# c) run inference using your rejection abc function. 

# samples = rejection_abc(...)

In [None]:
# d) plot the results.

## Exercise 3: Parameter inference with Conditional Density Estimation

Rejection ABC, as the name says, is based on rejecting samples, which can be very inefficient for high-dimensional data, because the fraction of simulations that land close to the observation shrinks rapidly as the dimension grows. Alternative methods that use neural networks for density estimation have been developed in recent years, starting with Papamakarios & Murray (2016): https://papers.nips.cc/paper/6084-fast-free-inference-of-simulation-models-with-bayesian-conditional-density-estimation
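The effect of dimensionality on rejection can be illustrated with a small, self-contained experiment (a hypothetical stand-in, not the DDM): draw points uniformly from $[-1, 1]^d$, treat the origin as the observation, and accept points within distance $\varepsilon = 0.5$. The exact acceptance probabilities are 0.5, $\pi/16 \approx 0.2$, about 0.005, and about $2 \cdot 10^{-6}$ for $d = 1, 2, 5, 10$, so for the same tolerance almost every simulation is wasted in ten dimensions.

```python
import numpy as np

def acceptance_rate(dim, eps=0.5, num_samples=100_000, seed=0):
    """Fraction of uniform draws in [-1, 1]^dim within distance eps of the origin."""
    rng = np.random.default_rng(seed)
    # stand-in for simulated summary statistics of dimension `dim`
    x = rng.uniform(-1.0, 1.0, size=(num_samples, dim))
    # stand-in for the observation: the origin; accept draws closer than eps
    return np.mean(np.linalg.norm(x, axis=1) < eps)

for dim in [1, 2, 5, 10]:
    print(dim, acceptance_rate(dim))
```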

Here, we coded up a simple version of this approach: a neural network that takes the simulated data as input and outputs the mean and the logarithm of the standard deviation of a Normal distribution. This Normal distribution is then used to approximate the posterior distribution.

After training, any observed data point can be passed to the neural network (now a conditional density estimator), and it will return the parameters of the Normal distribution that approximates the corresponding posterior. Thus, by passing $s_o$ to the network, one obtains an approximation of the posterior $p(\mu | s_o)$.
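Written out, the training cell below minimizes the average negative log-likelihood of the parameters $\mu_i$ (drawn from the prior) under the Normal predicted from the corresponding simulated data $s_i$. With network outputs $m(s)$ (the mean) and $\ell(s)$ (the log standard deviation), the loss for a minibatch of size $N$ is

$$
\mathcal{L} = -\frac{1}{N}\sum_{i=1}^{N} \log \mathcal{N}\big(\mu_i;\, m(s_i),\, e^{2\ell(s_i)}\big)
= \frac{1}{N}\sum_{i=1}^{N}\left[\ell(s_i) + \frac{\big(\mu_i - m(s_i)\big)^2}{2\,e^{2\ell(s_i)}}\right] + \frac{1}{2}\log 2\pi
$$

Predicting $\ell$ and exponentiating it keeps the standard deviation positive. In expectation over the prior and the simulator, minimizing this loss is equivalent to minimizing $\mathrm{KL}\big(p(\mu | s)\,\|\,q(\mu | s)\big)$ averaged over $s$, where $q$ is the predicted Normal; a sufficiently flexible network trained on enough simulations therefore approaches the best Normal approximation of the posterior in this sense.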


**Your exercises**: 

a) Read the code carefully to understand what is happening. Consider revisiting a [tutorial on PyTorch](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html).

b) Use the `tqdm` package to update the code such that it shows progress bars of the training epochs.

c) Train the network with 10000 simulations. Then get the posterior $p(\mu | s_o)$ by passing the observed data `s_o` from above into the network and use the resulting parameters to instantiate a Normal distribution. Plot its probability density across the support of the prior (-2 to 2).

d) Draw samples from the posterior and plot them in a histogram. Compare the histogram to the one you obtained by Rejection ABC. Do they differ? If so, what might be reasons for the difference? 

As always, make sure to add title and labels to your figures.

Below is some PyTorch code training a conditional density estimator for inference:

In [None]:
# Create a TensorDataset which we will use for training the density estimator
# From a TensorDataset, we construct a DataLoader that allows splitting data into shuffled batches
# during training of the NN. For background see PyTorch tutorials on data handling
from torch.utils.data import TensorDataset, DataLoader
import torch.nn as nn

# You may want to change the num_samples argument to play around.
parameters = prior(num_samples=10000)
data = simulator(parameters)

train_data = TensorDataset(parameters, data)
data_loader = DataLoader(train_data, batch_size=10, shuffle=True)

Next, we build a conditional density estimation neural network, which maps the data to the mean and the log standard deviation of a Normal distribution:

In [None]:
num_hiddens = [50, 50]
input_size = data.shape[1] 

# The NN takes data (observation) as input
# There are 2 hidden layers of 50 units each
# The regression is onto 2 numbers that will represent the parameters of a Normal distribution
network = nn.Sequential(
    nn.Linear(input_size, num_hiddens[0]),
    nn.ReLU(),
    nn.Linear(num_hiddens[0], num_hiddens[1]),
    nn.ReLU(),
    nn.Linear(num_hiddens[1], 2),
    )

# Create an optimizer
optim = torch.optim.Adam(network.parameters(), lr=0.01)

In [None]:
# You need to change this cell for exercise 3b. 
# import tqdm for progress bars
from tqdm import tqdm

# Train the parameters of the NN in minibatches
losses = []
for epoch in range(100):
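    # each batch contains parameters (drawn from the prior) and the corresponding simulated data
    for params_batch, data_batch in data_loader:
        nn_out = network(data_batch)
        # first output: mean; second output: log std (exp ensures a positive std)
        cond_dist = torch.distributions.Normal(loc=nn_out[:, 0], scale=torch.exp(nn_out[:, 1]))
        # negative log-likelihood of the true parameters under the predicted Normal
        loss = -1. * cond_dist.log_prob(params_batch.reshape(-1)).mean()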

        optim.zero_grad()
        loss.backward()
        optim.step()

        losses.append(loss.item())

# Plot loss of NN training
plt.figure(figsize=(18, 5))
plt.title('Training loss')
plt.xlabel('training step (minibatch)')
plt.ylabel('negative log-likelihood')
plt.plot(losses);

c)

In [None]:
# Your code for exercise 3c:



d)

In [None]:
# Your code for exercise 3d:



## Exercise 4 (Optional) 

If you are interested in trying out more advanced CDE methods, have a look at `sbi`, a toolbox for simulation-based inference that we are currently developing in the Mackelab: https://github.com/mackelab/sbi

This exercise is optional. You can find the documentation at https://www.mackelab.org/sbi/. If you want to proceed, clone the repository, follow the installation instructions, and try to use the `prior` and `simulator` above to run inference on the DDM with `sbi`.