In [1]:
import numpy as np
import torch

from sbi.inference.snle import SNLE_A
from sbi.utils.torchutils import BoxUniform
from fslm.utils import OptimisedPrior, includes_nan
from fslm.experiment_helper import SimpleDB
from fslm4expdata import CalibratedLikelihoodEstimator, BiasEstimator, CalibratedPrior
import pickle
import subprocess
from fslm.metrics import sample_kl
from tqdm import tqdm
import warnings
from sbi.analysis.plot import pairplot
import matplotlib.pyplot as plt
import pandas as pd

# import sys
# sys.path.append('../code')

ephys_helper not found. Required for HH simulation.
ephys_helper not found. Required for HH simulation.


In [2]:
def perturb(x, eps=1, rng=0):
    """Return ``x`` with seeded, zero-mean Gaussian noise added to every row.

    Args:
        x: 2-D numpy array (rows are samples, columns are features). Rows may
            contain NaN/inf; they are kept in the output but ignored when the
            per-column noise scale is estimated.
        eps: Multiplier applied to the diagonal covariance matrix.
        rng: Seed handed to ``np.random.seed``. Note that this reseeds numpy's
            *global* random state as a side effect.

    Returns:
        ``x + noise`` where ``noise`` has the same shape as ``x`` and is cast
        to float32 before the addition.

    NOTE(review): the per-column standard deviations are placed on the
    diagonal of the *covariance* matrix, so the noise standard deviation is
    ``sqrt(eps * std)`` rather than ``eps * std`` - confirm this is intended.
    """
    np.random.seed(rng)
    n_rows, n_cols = x.shape[0], x.shape[1]

    # Estimate the column-wise spread from fully finite rows only (no NaN, no +/-inf).
    finite_rows = np.all(np.isfinite(x), axis=1)
    column_std = x[finite_rows].std(axis=0)

    # One independent noise vector per row of the *original* array.
    noise = np.random.multivariate_normal(
        mean=np.zeros(n_cols),
        cov=eps * np.diag(column_std),
        size=n_rows,
    )
    return x + noise.astype(np.float32)

def optimise_base_prior_samples(sample_shape, bias_log_prob, base_prior_samples, seed=0):
    """Pick unique indices of pre-drawn prior samples via rejection sampling.

    Each pass shuffles all base samples, accepts sample ``i`` with probability
    ``exp(bias_log_prob(theta_i))`` and adds the accepted indices to the result
    until the requested number of *distinct* indices has been collected.

    Args:
        sample_shape: Shape of the requested sample batch; only its total
            number of elements is used.
        bias_log_prob: Callable mapping a batch of parameters to log acceptance
            probabilities (one value per row after flattening).
        base_prior_samples: Indexable collection of candidate parameters with a
            length, e.g. a ``(N, dim)`` tensor.
        seed: If not ``None``, seeds both torch's and numpy's global RNGs.

    Returns:
        1-D ``torch.long`` tensor with ``prod(sample_shape)`` unique indices
        into ``base_prior_samples``.

    Raises:
        ValueError: If more unique samples are requested than base samples
            exist (the loop could never terminate in that case).

    Note:
        The loop also cannot terminate if fewer than the requested number of
        base samples have a non-zero acceptance probability.
    """
    n_samples = torch.Size(sample_shape).numel()
    N = len(base_prior_samples)
    if n_samples > N:
        raise ValueError(
            f"Requested {n_samples} unique samples but only {N} base prior samples are available."
        )

    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)

    n = 0
    samples = []
    # rejection sampling | could be replaced by sbi's rejection_sample
    while n < n_samples:
        base_sample_idxs = torch.randperm(N)
        theta = base_prior_samples[base_sample_idxs]

        p_accept = torch.exp(bias_log_prob(theta)).view(-1)
        accepted = p_accept > torch.rand_like(p_accept)
        # Take at most as many accepted candidates as are still missing ...
        samples += base_sample_idxs[accepted][:n_samples - n].tolist()
        # ... and drop indices that were already selected in an earlier pass.
        samples = list(set(samples))
        n = len(samples)
    # Explicit dtype keeps the result usable as an index even when it is empty.
    return torch.tensor(samples, dtype=torch.long)

In [3]:
# Open the SimpleDB store located at ../data/2d3M; later cells call .write() and .query() on this handle.
# NOTE(review): no mode argument is passed here (the sample DB further down is opened with "r") -
# presumably the default mode permits writing; confirm in fslm.experiment_helper.
db_2d3M = SimpleDB("../data/2d3M")



In [4]:
# define prior
# Display labels for the 13 model parameters - presumably in the same order as the bounds below.
model_param_names = np.array(['C', r'$R_{input}$', r'$\tau$', r'$g_{Nat}$', r'$g_{Na}$', r'$g_{Kd}$', r'$g_{M}$',
                         r'$g_{Kv31}$', r'$g_{L}$', r'$E_{leak}$', r'$\tau_{max}$', 'VT', 'rate_to_SS_factor'])
# Lower / upper bounds of the box-uniform base prior, one entry per parameter.
prior_min = [0.1,  20,  0.1,    0,        0,      0,      0,      0,      0, -130,    50,    -90,   0.1]
prior_max = [15,   1000,   70,   250,     100,      30,    3,     250,     3,  -50,  4000,   -35,    3]
base_prior = BoxUniform(low=torch.as_tensor(prior_min), high=torch.as_tensor(prior_max))

db_2d3M.write("base_prior", base_prior)

# All 25 degree Celsius mouse motor cortex (M1) electrophysiological data, preprocessed
# NOTE: pickle.load can execute arbitrary code - only load this file from a trusted source.
with open('../data/M1_features.pickle', 'rb') as features_file:
    M1_25degree = pickle.load(features_file)
Xo = M1_25degree['X_o']

# write observations to database
# Converts the table into {row label: [feature values]} (Xo is presumably a pandas DataFrame).
X_o = {key: list(val.values()) for key, val in Xo.T.to_dict().items()}
db_2d3M.write("X_o", X_o)

# import simulations for training of NaN optimised prior
num_sims = 30_000
# Open the archive once and close it again; the arrays read inside the block are
# in-memory copies and stay valid after the file is closed.
with np.load("../data/full_batch.npz") as full_batch:
    theta = torch.from_numpy(full_batch['theta'])
    X = full_batch['stats']

# Seeded random subset of num_sims simulations.
torch.manual_seed(0)
rd_idxs = torch.randperm(len(theta))[:num_sims]
x4opt = torch.from_numpy(X)[rd_idxs]
theta4opt = theta[rd_idxs]

# add noise to simulations and write to database
# perturb() is applied to all of X before sub-selecting, so the noise scale is
# estimated from every finite row of X, not only from the selected subset.
x2d4opt = torch.from_numpy(perturb(X, rng=0))[rd_idxs]
db_2d3M.write("theta4opt", theta4opt)
db_2d3M.write("x4opt2d", x2d4opt)

In [None]:
## train optimised prior on
# 30k noised data incl. NaNs -> prior_opt
# The bias estimator is trained on (theta4opt, ~includes_nan(x2d4opt)); the target is
# presumably a per-simulation flag that is True when the noisy summary stats contain no NaN - TODO confirm.
nan_bias = BiasEstimator(input_dim = theta.shape[1], model = "resnet")
# z-scoring statistics are taken from the first 10,000 rows of the full theta array
nan_bias.z_score_inputs(theta[:10_000])
nan_bias.train(theta4opt, ~includes_nan(x2d4opt))
prior_opt = CalibratedPrior(base_prior, nan_bias)
prior_opt.bias.summarywriter = None # fix to allow pickling
db_2d3M.write("prior", prior_opt)

## prepare training 3M noisy data points for posterior (incl. BiasEstimator)
# optimise_base_prior_samples returns *unique* indices into theta, so theta must hold at least
# 3,000,000 rows that the bias estimator can accept; otherwise its sampling loop does not finish.
train_idxs = optimise_base_prior_samples((3_000_000,), prior_opt.bias.log_prob, theta, seed=1)
theta_train = theta[train_idxs]
# Same X and same seed as for x2d4opt, so perturb() reproduces the identical noise realisation.
x2d_train = torch.from_numpy(perturb(X, rng=0))[train_idxs]

db_2d3M.write("x_2dopt", x2d_train)
db_2d3M.write("theta_2dopt", theta_train)

In [26]:
# train posterior and bias estimator on the data
# for s in range(10):
#     subprocess.run(f"python3 train_calibration_estimator.py -r {s} -d ../data/2d3M -t 2dopt", shell=True, check=True)
#     subprocess.run(f"python3 train_nle_inference.py -r {s} -d 2d3M -f ../data/ -t 2dopt", shell=True, check=True)

In [None]:
# combine likelihood and bias estimator to create nan calibrated posterior
# NOTE(review): "inference_2d" and "calibration_2d" are not written anywhere in this notebook -
# presumably they are produced by the (commented-out) training scripts in the cell above; confirm.
trained_inference_obj = db_2d3M.query("inference_2d")
calibration_estimator = db_2d3M.query("calibration_2d")

# _neural_net is a private attribute of the inference object (leading underscore);
# it is used here as the trained likelihood estimator.
likelihood_estimator = trained_inference_obj._neural_net
cal_likelihood = CalibratedLikelihoodEstimator(likelihood_estimator, calibration_estimator)
# Build the posterior from the calibrated likelihood; sampling is done with MCMC.
cal_nle_posterior = trained_inference_obj.build_posterior(density_estimator=cal_likelihood, sample_with="mcmc")

# input = prior.sample((1,))
# x_o_test = torch.tensor([X_o["20180918_sample_1"]])[:,:-4]
# assert cal_likelihood.log_prob(x_o_test, input.view(1,1,-1))

# Persist the calibrated posterior under the key "posterior_2d3M".
db_2d3M.write("posterior_2d3M", cal_nle_posterior)

In [None]:
# sample posterior for all observations
# s = 0
# for i in range(len(X_o)):
#     subprocess.run(f"python3 sample_hh_nle.py -o {i} -n xo_sweep -f ../data/ -d 2d3M -w 1 -t 2dopt -r {s}", shell=True, check=True)

In [6]:
# load npe posterior and database with nle samples
# NOTE: pickle.load can execute arbitrary code - only load this file from a trusted source.
with open("../data/training_schedule_2d.pickle", "rb") as f:
    npe_posterior = pickle.load(f)
# "r" is passed as the second SimpleDB argument - presumably read-only mode; TODO confirm.
nle_sample_db = SimpleDB("../data/2d3M_xo_sweep", "r")

In [None]:
# compute and rank kls for all observations
# For every observed cell: draw 1000 NPE posterior samples, load the stored NLE samples
# and evaluate sample_kl in both argument orders.

lkls = {}
rkls = {}

# Seed both global RNGs before the NPE draws below.
np.random.seed(0)
torch.manual_seed(0)

for cell, xo in tqdm(X_o.items()):
    # xo[:-4] drops the last four entries of the observation - presumably features the
    # NPE posterior was not trained on; TODO confirm.
    npe_samples = npe_posterior.sample((1000,), xo[:-4], show_progress_bars=False)
    # `{0}` formats the literal 0, so the key always contains "_0_" (presumably the run/seed index).
    nle_samples = nle_sample_db.query(f"samples_nle_2dopt_mcmc_{0}_{cell}_all_dims")
    lkls[cell] = sample_kl(nle_samples, npe_samples)
    rkls[cell] = sample_kl(npe_samples, nle_samples)
# NOTE(review): assuming sample_kl estimates a KL divergence, the mean of both directions is a
# symmetrised KL, not the Jensen-Shannon divergence the name `jsds` suggests.
jsds = {cell:0.5*rkls[cell]+0.5*lkls[cell] for cell in X_o}

# Re-order each dict by ascending divergence value.
sorted_rkls = dict(sorted(rkls.items(), key=lambda item: item[1]))
sorted_lkls = dict(sorted(lkls.items(), key=lambda item: item[1]))
sorted_jsds = dict(sorted(jsds.items(), key=lambda item: item[1]))

# Position of each cell in X_o, then cells and positions listed in ascending-rkl order.
cell_idxs = {k:v for k,v in zip(list(X_o), range(len(X_o)))}
sorted_cells = list(sorted_rkls)
sorted_cell_idxs = [cell_idxs[n] for n in sorted_rkls]

# NOTE(review): np.vstack stacks ints, cell-name strings and KL values into one array with a
# single dtype, so all three DataFrame columns likely end up as strings - confirm before doing
# numeric work on `sorted_kls` in memory.
sorted_kls = pd.DataFrame(data=np.vstack([sorted_cell_idxs, sorted_cells, list(sorted_rkls.values())]).T, columns=["cell_idx", "cell", "kl"])
sorted_kls.to_csv("../data/2d3M_xo_sweep/sorted_kls.csv", index=False)

In [None]:
# set(list(sorted_rkls)[:50]).difference(set(os.listdir("../data/fslm_top50")))

In [None]:
# run fslm for the 50 observations with the lowest KL between the NPE and NLE posteriors
# import os
# my_env = os.environ.copy()
# my_env["LD_LIBRARY_PATH"] = ":/home/jnsbck/Applications/anaconda3/envs/hh_sbi/lib"

# run fslm for 50 observations
# for cell in top50_cells:
#     for seed in range(5):
#     subprocess.run(f"python3 hh_fslm_tree.py -r {seed} -n fslm_top50/20180608_sample_2 -w 1 -d 2d3M/ -p 2d3M/ -f ../data/ -o 20180608_sample_2 -t 2dopt_mcmc_0_full", shell=True, check=True, env=my_env)