# Simulation-Based Inference for Neuroscience: The BOLD Signal

## Table of Contents:
* [Setup](#set-up)
* [Train a Density Estimator](#density-estimator)
* [Creating X](#creating-X)
* [Loading Simulation Results](#simulation)
* [Training the Neural Network](#training)
* [Validating Results & Saving Plots](#validation)

## Setup<a class="anchor" id="set-up"></a>

First, import the required libraries (and set the current working directory if needed).

In [None]:
# If needed, before starting, change the current working directory by uncommenting and inserting the right path:
import os
# os.chdir("/home/coder/projects/lorenz_sbi")

# General libraries:
import numpy as np
import matplotlib.pyplot as plt  # pyplot provides the figure/savefig/close calls used below
import argparse 
import torch 

from sbi.inference import SNPE, SNLE, SNRE

# For plotting:
from sbi.analysis import pairplot, conditional_pairplot
from utils import marginal_correlation

# Functions:
from train import train

# Get the path of the current working directory: 
cwd_path = os.getcwd()

## Train a density estimator<a class="anchor" id="density-estimator"></a>

In [None]:
# Train a density estimator.
# If desired, change the number of threads to a higher number (by default it's 1).
parser = argparse.ArgumentParser(description="Train a density estimator.")
parser.add_argument("--data", type=str, default=cwd_path + "/data/X_train_01.npy", help="Path to the data file.")
parser.add_argument("--method", type=str, default="SNPE", help="Inference method.")
parser.add_argument("--density_estimator", type=str, default="maf", help="Density estimator.")
parser.add_argument("--num_threads", type=int, default=1, help="Number of threads.")
parser.add_argument("--device", type=str, default="cpu", help="Device.")

args, unknown = parser.parse_known_args()
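
The cell above calls `parse_known_args` rather than `parse_args`. Inside a notebook this matters because the kernel is started with command-line arguments of its own, which a strict parser would reject. The sketch below illustrates the behaviour on a made-up argument list (the `-f kernel.json` pair only stands in for whatever the kernel passes); it is an illustration, not part of the original pipeline.

```python
import argparse

parser = argparse.ArgumentParser(description="Train a density estimator.")
parser.add_argument("--method", type=str, default="SNPE", help="Inference method.")
parser.add_argument("--num_threads", type=int, default=1, help="Number of threads.")

# Hypothetical argument list: one known override plus arguments the parser
# does not define, similar to the extras a notebook kernel receives.
example_argv = ["--num_threads", "4", "-f", "kernel.json"]

# Known options are parsed; everything else is returned untouched in `unknown`.
args, unknown = parser.parse_known_args(example_argv)
print(args.method, args.num_threads)
print(unknown)
```

Passing an explicit list like this is also a convenient way to override a default from within the notebook without editing the `add_argument` calls.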

## Creating X<a class="anchor" id="creating-X"></a>

In [None]:
# Read in the beta files (input to our NN), stored as 10 batches: Betas_01.npy ... Betas_10.npy.
# File numbers are zero-padded to two digits; load every batch and stack them row-wise:
betas = np.concatenate(
    [np.load(cwd_path + f"/data/Betas_{number:02d}.npy") for number in range(1, 11)]
)

# Read in our parameters.
total_combinations = np.load(cwd_path + "/data/total_combinations.npy")


# Check whether the number of rows matches between the parameters and the beta values.
print("Row counts match between the parameters and the beta values:", len(total_combinations) == len(betas))

# Concatenate both together:
X = np.concatenate((total_combinations, betas), axis = 1)

print("The shape of X is: ", X.shape)
print("The first row/simulation in X is: ", "\n", X[0])

np.save(cwd_path + "/data/X.npy", X)
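
The steps in this cell (load the zero-padded batches, check the row counts, concatenate column-wise) could also be wrapped in small helpers that raise an error on a mismatch instead of printing a boolean. The following is a minimal sketch under that assumption; the helper names `load_batches` and `build_X` and the synthetic shapes are hypothetical and only serve the demonstration.

```python
import os
import tempfile

import numpy as np


def load_batches(data_dir, prefix, num_batches):
    """Load <prefix>01.npy ... <prefix>NN.npy from data_dir and stack them row-wise."""
    batches = []
    for number in range(1, num_batches + 1):
        path = os.path.join(data_dir, f"{prefix}{number:02d}.npy")
        batches.append(np.load(path))
    return np.concatenate(batches)


def build_X(params, betas):
    """Put the parameters in the first columns and the beta values in the last columns."""
    if len(params) != len(betas):
        raise ValueError(
            f"Row mismatch: {len(params)} parameter rows vs {len(betas)} beta rows."
        )
    return np.concatenate((params, betas), axis=1)


# Demonstration on small synthetic files (shapes are made up for illustration).
with tempfile.TemporaryDirectory() as tmp:
    for number in range(1, 4):
        batch = np.full((5, 4), number, dtype=float)
        np.save(os.path.join(tmp, f"Betas_{number:02d}.npy"), batch)
    demo_betas = load_batches(tmp, "Betas_", 3)

demo_params = np.zeros((15, 8))
demo_X = build_X(demo_params, demo_betas)
print(demo_X.shape)  # (15, 12)
```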

## Loading Simulation Results<a class="anchor" id="simulation"></a>

Because the shuffle is not deterministic (a seed still needs to be included, as `np.random.seed()` does not work as intended), we save our `X_train` and `X_test` to disk and load them from there instead.

In [7]:
X = np.load(cwd_path + "/data/X.npy")
print(X.shape)
print(X[0])

# # Split this matrix into train and test data.

# # Shuffle the X to make sure different combinations end up in the train and test sets.
# np.random.shuffle(X)

# # Take 10% of the data for test (there are 11520 simulations in X, 10% is 1152 which goes to test, the rest goes to X_train).
# X_train = X[:10368, :]
# X_test = X[10368:, :]

# # Save this train and test set.
# np.save(cwd_path + "/data/X_train_01.npy", X_train)
# np.save(cwd_path + "/data/X_test_01.npy", X_test)

60000
[5.91164975 5.91920401 5.96471115 0.13882868 0.11480275 0.1577574
 0.11480275 0.10333426 0.11946977 0.1577574  0.11946977 0.19903269
 1.         0.95849652 0.94904747 0.95849652 1.         0.83305496
 0.94904747 0.83305496 1.         8.50382875 5.87806224 7.00083841]
[[5.91164975 5.91920401 5.96471115 0.13882868 0.11480275 0.1577574
  0.11480275 0.10333426 0.11946977 0.1577574  0.11946977 0.19903269
  1.         0.95849652 0.94904747 0.95849652 1.         0.83305496
  0.94904747 0.83305496 1.         8.50382875 5.87806224 7.00083841]]
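
One possible way to make the commented-out split reproducible is to draw a row permutation from a seeded `np.random.default_rng` generator instead of shuffling in place. This is a sketch of that alternative, not the procedure that produced the saved files; the helper name `split_train_test` and the seed value are assumptions.

```python
import numpy as np


def split_train_test(X, test_fraction=0.1, seed=0):
    """Shuffle the rows of X with a seeded generator and split off a test set."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(X))  # same row order every time for a given seed
    num_test = int(round(len(X) * test_fraction))
    test_idx, train_idx = order[:num_test], order[num_test:]
    return X[train_idx], X[test_idx]


# Synthetic stand-in with the row count mentioned in the comments above (11520).
demo = np.arange(11520 * 2, dtype=float).reshape(11520, 2)
X_train_demo, X_test_demo = split_train_test(demo, test_fraction=0.1, seed=0)
print(X_train_demo.shape, X_test_demo.shape)  # (10368, 2) (1152, 2)
```

Because the generator is local to the function, the result does not depend on any global random state set elsewhere in the notebook.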


## Training the Neural Network <a class="anchor" id="training"></a>

A pretrained posterior is available in the `models` folder. To retrain it, uncomment the block of code below before running the cell.

In [None]:
X_train_01 = np.load(args.data, allow_pickle=True)
X_test_01 = np.load(cwd_path + "/data/X_test_01.npy")

# # Establish how many simulations are in the data.
# num_simulations = X_train_01.shape[0]

# # Separate simulation parameters and summary statistics.
# # X was built as (parameters, betas), so the summary statistics are the last 4 columns;
# # this is the same split the validation cell uses for obs_x and obs_theta.
# params = X_train_01[:, :-4]
# stats  = X_train_01[:, -4:]

# # Torch expects tensors, so convert the NumPy arrays to float tensors.
# theta = torch.from_numpy(params).float()
# x = torch.from_numpy(stats).float()

# # Train the posterior with all the arguments needed 
# posterior = train(num_simulations,
#                     x,
#                     theta,
#                     num_threads         = args.num_threads,
#                     method              = args.method,
#                     device              = args.device,
#                     density_estimator   = args.density_estimator
#                     )

# # Save posterior (intermediate result).
# torch.save(posterior, cwd_path + "/models/posterior_01.pt")
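
Since parameters and summary statistics are stored in one matrix, the column split has to be identical in training and validation. A small helper can keep that convention in a single place. The sketch below assumes the layout produced when `X` was created (parameters first, beta values in the last `num_stats` columns); the helper name `split_params_stats` is hypothetical, and the default of 4 is taken from the slices used in this notebook.

```python
import numpy as np


def split_params_stats(X, num_stats=4):
    """Return (params, stats), assuming the stats occupy the last `num_stats` columns."""
    X = np.atleast_2d(X)
    if not 0 < num_stats < X.shape[1]:
        raise ValueError("num_stats must be between 1 and the number of columns minus 1.")
    return X[:, :-num_stats], X[:, -num_stats:]


# Synthetic example: 3 simulations, 8 parameter columns followed by 4 stat columns.
demo = np.hstack([np.zeros((3, 8)), np.ones((3, 4))])
params, stats = split_params_stats(demo)
print(params.shape, stats.shape)  # (3, 8) (3, 4)
```

The returned arrays can then be converted with `torch.from_numpy(...).float()`, as in the cell above.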

## Validating Results & Saving Plots <a class="anchor" id="validation"></a>

*Warning: running this code saves 1152 pair plots and 1152 correlation plots to the `plot_images` folder (2 plots for each of the 1152 test simulations).*

To-do: change the for-loop to display the results in a different way (without pair plots) and incorporate the error between the actual and the inferred parameter values.

In [None]:
# # Iterate through all test simulations and, using the beta values, see how well the original parameters can be inferred (all the plots are saved).

# posterior = torch.load(cwd_path + "/models/posterior_01.pt") # Load the posterior once, outside the loop

# for number in range(X_test_01.shape[0]): # Valid row indices only; range(0, 1153) runs one past the 1152 test rows
#     obs_x = torch.from_numpy(X_test_01[number, -4:]).float() # Observed stats (beta values) as a float tensor
#     obs_theta = X_test_01[number, :-4] # True parameters of this simulation

#     num_samples = 100000

#     posterior.set_default_x(obs_x)
#     posterior_samples = posterior.sample((num_samples,))

#     # Plotting. pairplot returns its own figure (marginal_correlation is assumed to do the same,
#     # since it also returns fig, ax), so an extra plt.figure() would only leave empty figures open.
#     labels = [rf"$NU_{i}$" for i in range(1, 9)]

#     fig, ax = pairplot(samples=posterior_samples, labels=labels, figsize=(20, 20))
#     fig.savefig(cwd_path + "/plot_images/pairplot_01_" + str(number) + ".png") # Save figure
#     plt.close(fig)

#     fig, ax = marginal_correlation(samples=posterior_samples, labels=labels, figsize=(10, 10))
#     fig.savefig(cwd_path + "/plot_images/marginal_correlation_01_" + str(number) + ".png") # Save figure
#     plt.close(fig)
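
The to-do above asks for the error between the actual and the inferred parameter values. One possible way to summarize this without plots is to compare the posterior mean with the true parameters and to check whether each true value falls inside a central credible interval. The sketch below runs on synthetic Gaussian samples; the helper name `posterior_summary` and the 95% level are assumptions. In the notebook, `samples` would correspond to `posterior_samples` converted to NumPy and `true_theta` to `obs_theta`.

```python
import numpy as np


def posterior_summary(samples, true_theta, level=0.95):
    """Compare posterior samples of shape (num_samples, dim) with true parameters of shape (dim,)."""
    samples = np.asarray(samples, dtype=float)
    true_theta = np.asarray(true_theta, dtype=float)
    mean = samples.mean(axis=0)
    # Bounds of the central credible interval, computed per parameter.
    lower, upper = np.quantile(samples, [(1 - level) / 2, (1 + level) / 2], axis=0)
    return {
        "abs_error": np.abs(mean - true_theta),  # per-parameter error of the posterior mean
        "rmse": float(np.sqrt(np.mean((mean - true_theta) ** 2))),
        "covered": (lower <= true_theta) & (true_theta <= upper),  # true value inside interval?
    }


# Synthetic stand-in for posterior samples centred on the true parameters.
rng = np.random.default_rng(0)
true_theta = np.array([1.0, -2.0, 0.5])
samples = true_theta + 0.1 * rng.standard_normal((10000, 3))

summary = posterior_summary(samples, true_theta)
print(summary["abs_error"], summary["rmse"], summary["covered"])
```

Collecting these summaries over all test simulations would give one table of errors and coverage rates instead of two image files per simulation.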