In [14]:
import hydra
import torch
from ddm_stride.pipeline.evaluate import load_experimental_data
from ddm_stride.pipeline.infer import build_posterior, build_prior
from ddm_stride.utils.data_names import *

import warnings
# Install a global filter that ignores every warning raised after this cell runs.
# NOTE(review): this also hides deprecation and numerical warnings from any library;
# consider narrowing the filter to specific categories or modules.
warnings.filterwarnings('ignore')

### Load data

Open your `config/task` file. You should already have specified the `experimental_data_path` during the simulation phase. The subsequent cell will read in the data again, in case you want to make changes to the data or plot it. 

In [2]:
# Compose the Hydra configuration named 'config' from the '../config' directory.
with hydra.initialize(config_path='../config'):
    cfg = hydra.compose(config_name='config')

# Load the experimental data for this configuration; later cells read `experimental_data`.
experimental_data = load_experimental_data(cfg)
# Bare last expression: display the loaded data as the cell output.
experimental_data

Unnamed: 0,monkey,rt,coh,correct,choice
0,1,0.355,0.512,1.0,0.0
1,1,0.359,0.256,1.0,1.0
2,1,0.525,0.128,1.0,1.0
3,1,0.332,0.512,1.0,1.0
4,1,0.302,0.032,0.0,0.0
...,...,...,...,...,...
6144,2,0.627,0.032,1.0,1.0
6145,2,0.581,0.256,1.0,1.0
6146,2,0.293,0.512,1.0,1.0
6147,2,0.373,0.128,1.0,0.0


### Plotting the experimental data and the posterior

The `group_by` configuration in `config/task` allows you to group results by experimental conditions. If you leave the configuration empty, the posterior will marginalize over all experimental conditions, i.e.  
$P(\theta | x) = \sum_{\pi} P(\theta | x, \pi)$ with parameters $\theta$, data $x$ and experimental conditions $\pi$. By specifying at least one experimental condition, the posterior will be computed for each combination of the specified experimental conditions separately.

Example:  
The experimental data specifies three levels of task difficulty via the experimental condition `coh`. The subsequent plot shows an example for `evaluate/pdf_and_posterior.png` when defining `group_by: coh`. The plots on the left visualize the experimental data as well as the potential function $P(x| \theta, \pi) \cdot P(\theta)$. The title of the plot indicates the experimental condition that the data has been grouped by. The right side of the plot shows the posterior for each parameter $\theta$.

TODO: plot

The following plot shows an example of marginalizing out the task difficulty.  

TODO: plot

The best parameters (for each group) are saved in `best_thetas.json` in the `diagnose` subfolder. TODO: confidence interval

### Posterior predictive check

To determine how well different posterior samples reproduce the experimental data, a posterior predictive check is performed. For each plot, one posterior sample that has been inferred 
from the experimental data is used to simulate observations. Simulations generated from posterior samples should resemble the experimental data, especially if the posterior sample has a high probability under the posterior distribution. 

TODO: plot

### Run evaluate step

In [3]:
# Path of the results folder, relative to this notebook.
# NOTE(review): `${result_folder}` is a placeholder that Python does not interpolate;
# replace it with the actual result folder name before using this path.
# NOTE(review): the name `dir` shadows the built-in `dir()`; consider renaming it
# if no other cell depends on this name.
dir = '../results/${result_folder}'

In [None]:
# Execute the pipeline script inside this kernel, passing `run_evaluate=True` as its argument.
%run ../ddm_stride/run.py run_evaluate=True    
# TODO: show output, interpretation/recommendations

### Access the posterior

The subsequent cell lets you access the posterior. `ddm_stride/sbi_extensions/mcmc.py` provides functions for sampling from the posterior or computing $\log \left( P(x | \theta, \pi) \cdot P(\theta) \right)$.

In [15]:
# Compose the Hydra configuration named 'config' from the '../config' directory.
with hydra.initialize(config_path='../config'):
    cfg = hydra.compose(config_name='config')

# Specify one configuration of experimental conditions, if available.
# Give one value per experimental condition, in the same order as the names returned by
# get_experimental_condition_names(cfg). Leave the list empty to use all trials.
exp_cond = []

if exp_cond:
    cond_names = list(get_experimental_condition_names(cfg))
    if len(exp_cond) != len(cond_names):
        raise ValueError(
            f"exp_cond has {len(exp_cond)} value(s) but {len(cond_names)} "
            f"experimental condition(s) are configured: {cond_names}"
        )

    # Keep only the trials whose experimental conditions all equal exp_cond.
    # The comparison yields one boolean column per condition; .all(axis=1) reduces it to the
    # one-dimensional row mask that .loc expects.
    matches = (experimental_data.loc[:, cond_names] == exp_cond).all(axis=1)
    exp_cond_data = experimental_data.loc[matches]
    if exp_cond_data.empty:
        raise ValueError(
            f"No trials match the experimental conditions {dict(zip(cond_names, exp_cond))}"
        )

    # x contains the observations
    x = exp_cond_data.loc[:, get_observation_names(cfg)].values
else:
    # x contains the observations
    x = experimental_data.loc[:, get_observation_names(cfg)].values

# Build the posterior object used by the sampling and log-probability cells.
posterior = build_posterior(cfg)

In [18]:
# Draw 10 samples from the posterior for the observations `x` and conditions `exp_cond`.
# To draw many samples, raise `num_chains` and `num_workers` below.
sampling_kwargs = dict(
    x=torch.Tensor(x),
    exp_cond=torch.Tensor(exp_cond),
    num_chains=1,
    num_workers=1,
)
samples = posterior.sample((10,), **sampling_kwargs)
print(f"samples: {samples}")

Tuning bracket width...: 100%|██████████| 50/50 [01:35<00:00,  1.90s/it]s]
Generating samples: 100%|██████████| 100/100 [03:11<00:00,  1.92s/it]
Generating samples: 100%|██████████| 10/10 [00:18<00:00,  1.85s/it]
Running 1 MCMC chains in 1 batches.: 100%|██████████| 1/1 [05:05<00:00, 305.60s/it]

samples: tensor([[-0.8296,  1.7672,  0.6999],
        [-0.8365,  1.7620,  0.6999],
        [-0.8289,  1.7657,  0.6995],
        [-0.8049,  1.7589,  0.6999],
        [-0.8288,  1.7598,  0.6999],
        [-0.8342,  1.7658,  0.6999],
        [-0.8191,  1.7641,  0.6999],
        [-0.8167,  1.7548,  0.6998],
        [-0.8193,  1.7537,  0.7000],
        [-0.8214,  1.7635,  0.6999]])





In [10]:
# Draw two parameter sets from the prior (built on the CPU)
prior = build_prior(cfg, device='cpu')
theta = prior.sample((2,))
print(f"theta: {theta}")

# Evaluate both parameter sets with the posterior's log_prob and potential methods.
# NOTE(review): the recorded output shows identical values for both calls, which suggests
# log_prob is not normalized here; confirm against the posterior implementation.
log_prob = posterior.log_prob(
    theta,
    x=torch.Tensor(x),
    exp_cond=torch.Tensor(exp_cond),
)
potential = posterior.potential(
    theta,
    x=torch.Tensor(x),
    exp_cond=torch.Tensor(exp_cond),
)

print(f"log_prob: {log_prob}, \npotential: {potential}")

theta: tensor([[-0.7571,  1.2192,  0.3646],
        [-0.7012,  1.1473,  0.3808]])
log_prob: tensor([[-13791.8652, -14407.0352]]), 
potential: tensor([[-13791.8652, -14407.0352]])
