# Simulation-Based Inference for Neuroscience: The BOLD Signal

## Table of Contents:
* [Setup](#set-up)
* [Train a Density Estimator](#density-estimator)
* [Loading Simulation Results](#simulation)
* [Training the Neural Network](#training)
* [Validating Results](#validation)
* [Plotting and Saving the Figures](#figures)

## Setup<a class="anchor" id="set-up"></a>

First, import the required libraries (and set the current working directory if needed).

In [None]:
# If needed, before starting, change the current working directory by uncommenting and inserting the right path:
import os
# os.chdir("/home/coder/projects/lorenz_sbi")

# General libraries:
import numpy as np
import pylab as plt
import argparse 
import torch 

from sbi.inference import SNPE, SNLE, SNRE

# For plotting:
from sbi.analysis import pairplot, conditional_pairplot
from utils import marginal_correlation

# Functions:
from train import train

# Absolute path of the working directory; every data, model and figure path
# used later in this notebook is built by appending to this string.
cwd_path = os.path.abspath(os.curdir)

## Train a density estimator<a class="anchor" id="density-estimator"></a>

In [None]:
# Configuration for training a density estimator, exposed as command-line style
# options so the same code can also run as a script. If desired, raise
# --num_threads above its default of 1.
# parse_known_args() puts unrecognised arguments into `unknown` instead of
# raising, presumably so extra arguments passed by the notebook kernel are ignored.
_cli_options = (
    # (flag, value type, default, help text)
    ("--data", str, cwd_path + "/data/X.npy", "Path to the data file."),
    ("--method", str, "SNPE", "Inference method."),
    ("--density_estimator", str, "maf", "Density estimator."),
    ("--num_threads", int, 1, "Number of threads."),
    ("--device", str, "cpu", "Device."),
)

parser = argparse.ArgumentParser(description="Train a density estimator.")
for flag, value_type, default_value, help_text in _cli_options:
    parser.add_argument(flag, type=value_type, default=default_value, help=help_text)

args, unknown = parser.parse_known_args()

## Creating the Simulation Matrix X (Parameters and Beta Values)

In [None]:
# Build the simulation matrix X = [parameters | beta values] and cache it as data/X.npy.
# The beta values (input to our NN) are stored in numbered batch files
# Betas_01.npy ... Betas_10.npy, which are stacked row-wise in that order.
num_beta_batches = 10
beta_batches = [
    # Zero-padded two-digit numbering covers both "01".."09" and "10".
    np.load(cwd_path + f"/data/Betas_{number:02d}.npy")
    for number in range(1, num_beta_batches + 1)
]
# Concatenate once; growing the array inside the loop re-copied every
# previously loaded row on each iteration.
betas = np.concatenate(beta_batches)

# Read in our parameters.
total_combinations = np.load(cwd_path + "/data/total_combinations.npy")

# Every simulation needs exactly one row of parameters and one row of beta values.
lengths_match = len(total_combinations) == len(betas)
print("Does the length match between the parameters and the beta value (returns True if this is true)?", lengths_match)
if not lengths_match:
    # Fail with a clear message instead of numpy's generic concatenate error below.
    raise ValueError(
        f"Row count mismatch: {len(total_combinations)} parameter rows vs "
        f"{len(betas)} beta rows."
    )

# Concatenate both together (parameters first, beta values last).
X = np.concatenate((total_combinations, betas), axis=1)

print("The shape of X is: ", X.shape)
print("The first row/simulation in X is: ", "\n", X[0])

np.save(cwd_path + "/data/X.npy", X)

## Loading Simulation Results<a class="anchor" id="simulation"></a>

In [7]:
# Load the simulation matrix X (one row per simulation) from the --data path.
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects from a
# path that can come from the command line. X.npy is written by np.save from a
# numeric array, so the flag looks unnecessary; only load trusted files.
X = np.load(args.data, allow_pickle=True)
print(X.shape[0])
print(X[0])

# Split X into train and test data: the last `num_test` rows (all columns) form
# X_test, and every row before them forms X_train.
# As a small test only the last simulation is used as test data. Slicing from the
# end keeps that true however many simulations X contains. A hard-coded row index
# only selects the last row when X has exactly one more row than the index.
num_test = 1
X_train = X[:-num_test, :]
X_test = X[-num_test:, :]

60000
[5.91164975 5.91920401 5.96471115 0.13882868 0.11480275 0.1577574
 0.11480275 0.10333426 0.11946977 0.1577574  0.11946977 0.19903269
 1.         0.95849652 0.94904747 0.95849652 1.         0.83305496
 0.94904747 0.83305496 1.         8.50382875 5.87806224 7.00083841]
[[5.91164975 5.91920401 5.96471115 0.13882868 0.11480275 0.1577574
  0.11480275 0.10333426 0.11946977 0.1577574  0.11946977 0.19903269
  1.         0.95849652 0.94904747 0.95849652 1.         0.83305496
  0.94904747 0.83305496 1.         8.50382875 5.87806224 7.00083841]]


## Training the Neural Network <a class="anchor" id="training"></a>

A pretrained posterior is available in the `models` folder (`models/posterior.pt`). If you want to retrain it, uncomment the code block below before running it.

In [None]:
# Retraining cell (disabled). Uncommenting it trains a posterior on X_train and
# overwrites models/posterior.pt with the result.
# Column convention used here: the last 4 columns of X_train are the parameters
# (theta) and the remaining columns are the summary statistics (x). Any code that
# conditions the posterior on observed data must slice its rows the same way.

# # Establish how many simulations are in the data 
# num_simulations = X_train.shape[0]

# # Separate simulation parameters and summary statistics
# params = X_train[:, -4:]
# stats  = X_train[:, :-4]

# # torch works on tensors, so the NumPy arrays are converted to float32 torch tensors
# theta = torch.from_numpy(params).float()
# x = torch.from_numpy(stats).float()

# # Train the posterior with all the arguments needed 
# posterior = train(num_simulations,
#                     x,
#                     theta,
#                     num_threads         = args.num_threads,
#                     method              = args.method,
#                     device              = args.device,
#                     density_estimator   = args.density_estimator
#                     )

# # Save posterior (intermediate result).
# torch.save(posterior, cwd_path + "/models/posterior.pt")

## Validating Results <a class="anchor" id="validation"></a>

In [None]:
# Known values (for testing the model): split the first held-out simulation into
# the summary statistics x the posterior is conditioned on, and the parameters
# theta that generated them.
# The column split mirrors the training cell: the last `num_param_columns` columns
# are the parameters (theta), and the remaining columns are the summary statistics (x).
# The previous slices were swapped, which conditioned the posterior on parameter
# columns instead of the statistics it was trained on.
# NOTE(review): confirm this matches how models/posterior.pt was actually trained.
num_param_columns = 4
obs_x = torch.from_numpy(X_test[0, :-num_param_columns]).float()  # observed statistics
obs_theta = torch.from_numpy(X_test[0, -num_param_columns:]).float()  # ground-truth params

print(X_test)
print(obs_x)
print(obs_theta)

# Load posterior.
# NOTE: torch.load unpickles an arbitrary Python object, so only load trusted files.
# From torch 2.6 on, torch.load defaults to weights_only=True, which can reject a
# pickled posterior object; pass weights_only=False there if loading fails.
posterior = torch.load(cwd_path + "/models/posterior.pt")

num_samples = 100000

# Condition the posterior on the observed statistics and draw parameter samples.
posterior.set_default_x(obs_x)
posterior_samples = posterior.sample((num_samples,))

## Plotting and Saving the Figures <a class="anchor" id="figures"></a>

In [None]:
# Axis labels NU_1 ... NU_d, one per posterior dimension. They are derived from
# the samples so the label count always matches the number of inferred parameters.
num_dims = posterior_samples.shape[1]
labels = [rf"$NU_{i}$" for i in range(1, num_dims + 1)]

# sbi's pairplot creates its own figure, so no plt.figure() is needed beforehand.
fig, ax = pairplot(samples=posterior_samples, labels=labels, figsize=(20, 20))
# Save BEFORE plt.show(): with inline / GUI backends show() closes the figure, and
# a later plt.savefig would write a new, empty figure.
plt.savefig(cwd_path + "/png/pairplot.png")
plt.show()

# NOTE(review): plt.figure() is kept in case utils.marginal_correlation draws on
# the current figure rather than creating its own. Confirm, and drop it if redundant.
plt.figure()
fig, ax = marginal_correlation(samples=posterior_samples, labels=labels, figsize=(10, 10))
# Save before showing, for the same reason as above.
plt.savefig(cwd_path + "/png/marginal_correlation.png")
plt.show()