## Inference of Neocortical Layer 5 Pyramidal Cell

In [None]:
import os
# Remember the directory the notebook was started from; it is handed back to
# os.chdir once training has finished.
path1 = os.getcwd()
# change the current working directory to path
# NOTE(review): machine-specific absolute path; relative paths used further down
# (e.g. 'config/features.json') are resolved against it — adjust per machine.
path2 = '/home/pedro/repos/lfi-models/lfimodels/l5pc/'
os.chdir(path2)

In [None]:
import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
import delfi.summarystats as ds
import lfimodels.l5pc.utils as utils
import lfimodels.l5pc.l5pc_model as l5pc_model
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pickle

from lfimodels.l5pc.L5Pyramidal import L5Pyramidal
from lfimodels.l5pc.L5PCStats import L5PCStats
from lfimodels.l5pc.L5PCStatsMoments import L5PCStatsMoments
from delfi.utils.viz import plot_pdf

%matplotlib inline

In [None]:
import json

# Load the feature configuration. The context manager closes the file handle
# as soon as the JSON has been parsed instead of leaving it to the garbage
# collector.
with open('config/features.json') as features_file:
    feature_configs = json.load(features_file)
# print(sorted(feature_configs.items()))

# Walk the protocol -> location -> feature nesting in sorted order, printing
# every level and collecting the feature names in the order they are visited.
efel_feature_name_ls = []
for protocol_name, locations in sorted(feature_configs.items()):
    print(protocol_name)
    for location, features in sorted(locations.items()):
        print(location)
        # Only the feature names are needed; dictionary keys are unique, so
        # sorting the keys gives the same order as sorting the items.
        for efel_feature_name in sorted(features):
            efel_feature_name_ls.append(efel_feature_name)
            print(efel_feature_name)

## retrieve observed data

In [None]:
# Parameter values and matching labels returned by the project helper; they
# are passed as `true_params` / `labels_params` to the prior further down.
true_params, labels_params = utils.obs_params()

# Seed later passed to the prior and the summary-statistics object.
# NOTE(review): presumably None means "unseeded / not reproducible" — confirm.
seed = None

# Observation generated from `true_params` and the summary statistics of it.
# NOTE(review): presumably a synthetic simulation; the meaning of
# summary_stats=2 is not visible here — confirm in lfimodels.l5pc.utils.
obs = utils.syn_obs_data(params=true_params)
obs_stats = utils.syn_obs_stats(data=obs,params=true_params,summary_stats=2)

## generate instances of prior, model, and summary stats

In [None]:
# Number of model instances; more than one selects the multi-process generator.
n_processes = 10
use_multiprocessing = n_processes > 1

if use_multiprocessing:
    # one model per process, seeded 1..n_processes
    seeds_model = np.arange(1, n_processes + 1, 1)
    m = [L5Pyramidal(seed=seeds_model[worker]) for worker in range(n_processes)]
else:
    seed = None
    m = L5Pyramidal(seed=seed)

# Prior and summary statistics are built the same way for both generators.
p = utils.prior(true_params=true_params, labels_params=labels_params, seed=seed)
# s = L5PCStats(seed=seed)
s = L5PCStatsMoments(seed=seed)

if use_multiprocessing:
    g = dg.MPGenerator(models=m, prior=p, summary=s)
else:
    g = dg.Default(model=m, prior=p, summary=s)

In [None]:
def plot_responses(responses,color):
    """Plot each response in its own vertically stacked subplot.

    Parameters
    ----------
    responses : dict
        Maps a response name to an object that can be indexed with 'time'
        and 'voltage'. Subplots follow the sorted order of the items.
    color : matplotlib color
        Colour used for every trace.

    Notes
    -----
    An empty `responses` is not supported: a grid with zero rows cannot be
    created.
    """
    # squeeze=False keeps `axes` a 2-D array even when there is only one
    # response; with the default, a single subplot comes back as a bare Axes
    # object and indexing it fails.
    fig, axes = plt.subplots(len(responses), figsize=(10,10), squeeze=False)
    for ax, (resp_name, response) in zip(axes[:, 0], sorted(responses.items())):
        ax.plot(response['time'], response['voltage'], color = color, label=resp_name)
        ax.set_title(resp_name)
    fig.tight_layout()
    fig.show()

In [None]:
# Show the voltage traces stored under obs['data'], one subplot per response.
plot_responses(obs['data'],color = 'b')

In [None]:
# Display the first row of the observed summary statistics.
obs_stats[0]

## SNPE-C

In [None]:
# ---------------------------------------------------------------------------
# SNPE hyper-parameters
# ---------------------------------------------------------------------------

seed = None

# training schedule
n_train = 80000
n_rounds = 1

# fitting setup
minibatch = 100
epochs = 1000

# network setup
n_hiddens = [50, 50]
reg_lambda = 0.01

# convenience
pilot_samples = 80000
svi = False
verbose = True
prior_norm = True

# SNPE-C parameters
n_null = minibatch - 1

# MAF parameters
mode = 'random'      # ordering of variables for MADEs
n_mades = 5          # number of MADES
act_fun = 'tanh'
batch_norm = False   # batch-normalization currently not supported
train_on_all = True  # now supported feature

# When training on all rounds, turn the single epoch count into one entry per
# round, dividing the budget by the (1-based) round number.
if train_on_all:
    epochs = [epochs // round_number for round_number in range(1, n_rounds + 1)]

In [None]:
# control MAF seed: numpy's global random module is seeded and passed on as
# the random number generator of the inference object
rng = np.random
rng.seed(seed)

# options of the inference object
snpe_options = dict(
    obs=obs_stats,
    n_hiddens=n_hiddens,
    seed=seed,
    reg_lambda=reg_lambda,
    pilot_samples=pilot_samples,
    svi=svi,
    n_mades=n_mades,  # providing this argument triggers usage of MAFs (vs. MDNs)
    act_fun=act_fun,
    mode=mode,
    rng=rng,
    batch_norm=batch_norm,
    verbose=verbose,
    prior_norm=prior_norm,
)
res = infer.SNPEC(g, **snpe_options)

# options of the training run
run_options = dict(
    n_train=n_train,
    proposal='discrete',
    moo='resample',
    n_null=n_null,
    n_rounds=n_rounds,
    train_on_all=train_on_all,
    minibatch=minibatch,
    silent_fail=False,
    verbose=True,
    epochs=epochs,
)

# train
log, train_data, posterior = res.run(**run_options)

In [None]:
# change the current working directory back to notebook path, so that the
# relative './results/' path used for saving and loading resolves next to the
# notebook rather than inside the model directory
os.chdir(path1)

In [None]:
# Tag the file name with the `svi` setting used for this run.
svi_flag = '_svi' if svi else '_nosvi'

save_path = './results/'
filename1 = 'l5pc_run_1_param20{}_rounds{}_nsims{}_snpec_lognorm.pkl'.format(
    svi_flag, n_rounds, n_train*n_rounds)

# Make sure the output directory exists before writing, so that a finished
# training run is not lost because of a missing folder.
os.makedirs(save_path, exist_ok=True)
io.save_pkl((log, train_data, posterior),save_path+filename1)

In [None]:
# Rebuild the file name used when saving and read the results back.
# Relies on `svi_flag`, `n_rounds` and `n_train` still holding the values of
# the run that is to be loaded.
save_path = './results/'
name_parts = [
    'l5pc_run_1_param20',
    svi_flag,
    '_rounds', str(n_rounds),
    '_nsims', str(n_train*n_rounds),
    '_snpec_lognorm.pkl',
]
filename1 = ''.join(name_parts)
log, train_data, posterior = io.load_pkl(save_path + filename1)

## analyse results

In [None]:
# Training loss, one panel per round.
fig = plt.figure(figsize=(15,10))

# res.round gives the number of rounds to plot. At least two panels and two
# shades are laid out so that runs with one or two rounds look exactly as
# before, while runs with more rounds no longer run out of subplot slots or
# colours.
n_rounds_done = res.round
n_panels = max(2, n_rounds_done)

# skip the lightest shades of the colour map (index 0 .. col_min-1)
col_min = 1
num_colors = n_panels+col_min
cm1 = mpl.cm.Blues
col1 = [cm1(1.*i/num_colors) for i in range(col_min,num_colors)]

for i in range(n_rounds_done):
    plt.subplot(n_panels,1,i+1)
    plt.plot(log[i]['loss'], color=col1[i], lw=2)
    plt.xlabel('iteration')
    plt.ylabel('loss')
    plt.title('round'+str(i+1));

In [None]:
# Alternative (disabled): derive the plotting limits from the model's
# parameter bounds instead of the prior moments.
# param_names_bounds = [(str(param.name),param.bounds) for param in l5pc_model.define_original_parameters()
#                       if not param.frozen]
# param_names_bounds.sort(key=lambda x: labels_params.index(x[0]))
# param_bounds = [param_names_bounds[i][1] for i in range(len(param_names_bounds))]
# param_bounds = np.array(param_bounds)
# param_bounds[[5,11,13,18],0] -= [0.0005,20,0.0005,20]
# param_bounds[:,1] += 10
# param_bounds = np.log(param_bounds+1e-4)

# param_bounds_min = param_bounds[:,0]
# param_bounds_max = param_bounds[:,1]

# prior_lims = param_bounds*1


# Plotting limits: three prior standard deviations either side of the prior mean.
prior_dist = res.generator.prior
prior_min = prior_dist.mean - 3*prior_dist.std
prior_max = prior_dist.mean + 3*prior_dist.std

# one row per parameter, columns are (lower, upper)
prior_lims = np.hstack((prior_min.reshape(-1, 1), prior_max.reshape(-1, 1)))

# One figure per round: 1000 posterior samples against the prior and the
# ground-truth parameters.
for round_idx in range(res.round):
    fig = plt.figure(figsize=(15, 15))
    round_posterior = posterior[round_idx]
    round_posterior.ndim = len(true_params)
    round_samples = round_posterior.gen(1000).T
    plot_pdf(p, lims=prior_lims, samples=round_samples, gt=true_params, figsize=(15, 15))
    plt.show()
    plt.close()

In [None]:
# Keep only the posterior of the last round; from here on `posterior` is a
# single distribution rather than the per-round collection.
# NOTE(review): this rebinding makes the cell order-dependent — running it a
# second time indexes into the already selected posterior. Reload or retrain
# before re-running.
posterior = posterior[n_rounds-1]

In [None]:
# Plot 1000 samples of the selected posterior against the prior and the
# ground-truth parameters, within the limits given by `prior_lims`.
# NOTE(review): plot_pdf also receives figsize; if it opens its own figure,
# the one created here stays empty — confirm against delfi.utils.viz.
fig = plt.figure(figsize=(15,15))
# plot_pdf(p, lims=prior_lims, samples=posterior.gen(1000).T, gt=true_params,figsize=(15,15),
#          labels_params=labels_params, ticks=True)
plot_pdf(p, lims=prior_lims, samples=posterior.gen(1000).T, gt=true_params,figsize=(15,15))
plt.show()
plt.close()

In [None]:
def plot_responses(responses,color):
    """Plot each response in its own vertically stacked subplot.

    Parameters
    ----------
    responses : dict
        Maps a response name to an object that can be indexed with 'time'
        and 'voltage'. Subplots follow the sorted order of the items.
    color : matplotlib color
        Colour used for every trace.

    Notes
    -----
    An empty `responses` is not supported: a grid with zero rows cannot be
    created.
    """
    # squeeze=False keeps `axes` a 2-D array even when there is only one
    # response; with the default, a single subplot comes back as a bare Axes
    # object and indexing it fails.
    fig, axes = plt.subplots(len(responses), figsize=(10,10), squeeze=False)
    for ax, (resp_name, response) in zip(axes[:, 0], sorted(responses.items())):
        ax.plot(response['time'], response['voltage'], color = color, label=resp_name)
        ax.set_title(resp_name)
    fig.tight_layout()
    fig.show()
    
def plot_all_responses(responses,colors):
    """Overlay several sets of responses, one subplot per response name.

    Parameters
    ----------
    responses : sequence of dict
        Each dict maps a response name to an object that can be indexed with
        'time' and 'voltage'. The first entry fixes the number of subplots;
        within every set, traces are assigned to subplots by the sorted order
        of the items.
    colors : sequence of matplotlib colors
        `colors[i]` is used for all traces of `responses[i]`.
    """
    # squeeze=False keeps the axes in a 2-D array even when each set holds a
    # single response; with the default, one subplot comes back as a bare
    # Axes object and indexing it fails.
    fig, axes = plt.subplots(len(responses[0]), figsize=(10,10), squeeze=False)
    axes = axes[:, 0]
    for i, response_set in enumerate(responses):
        for index, (resp_name, response) in enumerate(sorted(response_set.items())):
            axes[index].plot(response['time'], response['voltage'], color = colors[i], label=resp_name)
            axes[index].set_title(resp_name)
    fig.tight_layout()
    fig.show()

In [None]:
# parameter set with closest summary statistic to observed data
last_round = train_data[n_rounds-1]
train_stats = last_round[1]

# Standardise the observation with the statistics' mean and std stored on the
# inference object.
# NOTE(review): presumably `train_stats` is already standardised the same way — confirm.
obs_zt = (obs_stats - res.stats_mean) / res.stats_std
# obs_zt[0,res.stats_std==0] = (obs_stats[0,res.stats_std==0]-\
#                               res.stats_mean[res.stats_std==0])/res.stats_std[res.stats_std==0]

# Statistics with zero std are left out of the distance.
varying = res.stats_std != 0
dist_train_data = np.linalg.norm(train_stats[:, varying] - obs_zt[0][varying], axis=1)
dist_argsort = np.argsort(dist_train_data)

# Undo the parameter standardisation for the nearest training sample.
nearest = dist_argsort[0]
x_closest = last_round[0][nearest, :] * res.params_std + res.params_mean

In [None]:
# Difference between the closest training parameters and the ground truth.
x_closest-true_params

In [None]:
# requested number of posterior samples
num_samp = 10

# # sampling at contour of 1 covariance away from mean
# # (if samples from outside the prior box, contour is at prior box)
# x_samp = np.random.randn(n_params,num_samp)
# x_samp = np.divide(x_samp,np.linalg.norm(x_samp,axis=0))
# x_samp = (np.dot(S,x_samp)).T+mn_post

# sample from posterior
x_samp = posterior.gen(n_samples=num_samp)

# reject samples outside the prior box: keep a row only if every coordinate
# lies strictly between the lower and upper limit
ind = np.logical_and(x_samp > prior_min, x_samp < prior_max)
x_samp = x_samp[ind.all(axis=1)]

# number of samples that survived the rejection step
num_samp = x_samp.shape[0]

num_samp

In [None]:
# change the current working directory to path
# NOTE(review): machine-specific absolute path; the working directory is not
# restored afterwards.
path2 = '/home/pedro/repos/lfi-models/lfimodels/l5pc/'
os.chdir(path2)

# No figure is opened here: plot_all_responses creates its own, so a figure
# opened at this point would only be displayed empty.

mn_post, S = posterior.calc_mean_and_cov()
# mn_post = posterior.xs[0].m
# S = posterior.xs[0].S
n_params = len(mn_post)
y_obs = obs['data']
# t = obs['time']
# duration = np.max(t)

# colours of the reference traces
COL = {}
COL['GT']   = (35/255,86/255,167/255)
COL['SNPE'] = (0, 174/255,239/255)
COL['CLOSEST'] = (1,0,0)

num_samp = 4

# # sampling at contour of 1 covariance away from mean
# # (if samples from outside the prior box, contour is at prior box)
# x_samp = np.random.randn(n_params,num_samp)
# x_samp = np.divide(x_samp,np.linalg.norm(x_samp,axis=0))
# x_samp = (np.dot(S,x_samp)).T+mn_post

# sample from posterior
x_samp = posterior.gen(n_samples=num_samp)

# # reject samples outside the prior box
# ind = (x_samp > prior_min) & (x_samp < prior_max)
# x_samp = x_samp[np.prod(ind,axis=1)==1]

# One orange shade per posterior sample (the lightest shade, index 0, is
# dropped when plotting), followed by the colours for the posterior mean, the
# closest training sample and the observation.
num_samp = len(x_samp[:,0])
num_colors = num_samp+1
cm1 = mpl.cm.Oranges
col1 = [cm1(1.*i/num_colors) for i in range(num_colors)]
col1.append(COL['SNPE'])
col1.append(COL['CLOSEST'])
col1.append(COL['GT'])

# rows: posterior samples, then posterior mean, then closest training sample
params = np.concatenate((x_samp,np.array([mn_post,x_closest])))

# Simulate every parameter set with its own seeded model instance.
# `m` is rebound on each iteration; the last instance is reused by the
# summary-statistics analysis further down.
V = []
for i in range(2+num_samp):
    m = L5Pyramidal(seed=230+i)
    x = m.gen_single(params[i,:])
    V.append(x['data'])

# the observation is drawn last, in the 'GT' colour
V.append(y_obs)

# plotting simulation
plot_all_responses(V,col1[1:])
plt.xlabel('time (ms)')
plt.ylabel('voltage (mV)');

## analysis of summary statistics

In [None]:
# Tick labels: seven statistics for each of the first three groups, five for
# each of the remaining three, every name suffixed with its group number.
# labels_sum_stats = ['sp_t','r_pot','r_pot_std','m1','m2','m3','m4']*3+['sp_t','r_pot','r_pot_std','m1','m2']*3
stat_names_long = ['sp_t', 'r_pot', 'r_pot_std', 'm1', 'm2', 'm3', 'm4']
stat_names_short = stat_names_long[:5]
labels_sum_stats = (
    [name + str(group) for group in (1, 2, 3) for name in stat_names_long]
    + [name + str(group) for group in (4, 5, 6) for name in stat_names_short]
)

n_summary_stats = len(obs_stats[0])
# x position of every summary statistic (0, 1, ..., n_summary_stats-1)
stat_pos = np.linspace(0, n_summary_stats - 1, n_summary_stats)
# statistics whose stored std is zero
zero_std = res.stats_std == 0


def _scale_by_stats_std(stats):
    """Divide `stats` by res.stats_std, falling back to res.stats_mean where the std is zero."""
    scaled = stats / res.stats_std
    scaled[zero_std] = stats[zero_std] / res.stats_mean[zero_std]
    return scaled


# summary statistics of a simulation at the posterior mean
# NOTE(review): the curves and bars below label this 'mode' although it is
# computed from `mn_post` — confirm the intended wording.
sum_stats_post = res.generator.summary.calc([m.gen_single(mn_post)])[0]

fig = plt.figure(figsize=(20, 15))

# --- top left: raw feature values ---
ax = plt.subplot(321)
ax.plot(obs_stats[0], color=COL['GT'], lw=2, label='observation')
ax.plot(sum_stats_post, color=COL['SNPE'], lw=2, label='mode')
ax.set_xticks(stat_pos)
# ax.set_xticklabels(labels_sum_stats)
ax.set_ylabel('feature value')
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[::-1], labels[::-1], bbox_to_anchor=(1.2, 1), loc='upper right')

# --- top right: absolute values and absolute deviation on a log scale ---
ax = plt.subplot(322)
ax.semilogy(np.abs(obs_stats[0]), color=COL['GT'], linestyle='--', lw=2, label='observation')
ax.semilogy(np.abs(sum_stats_post - obs_stats[0]), color=COL['SNPE'], lw=2, label='mode')
ax.set_xticks(stat_pos)
# ax.set_xticklabels(labels_sum_stats);
ax.set_ylabel(r'$f^*$ - f')

# --- middle row: scaled statistics side by side ---
ax = plt.subplot(3, 2, (3, 4))
obs_stats_norm = _scale_by_stats_std(obs_stats[0])
post_stats_norm = _scale_by_stats_std(sum_stats_post)
closest_stats = res.generator.summary.calc([m.gen_single(x_closest)])[0]
closest_stats_norm = _scale_by_stats_std(closest_stats)
width = 0.3
ax.bar(stat_pos, obs_stats_norm,
       width, color=COL['GT'], label='observation')
ax.bar(stat_pos + width, post_stats_norm,
       width, color=COL['SNPE'], label='mode')
ax.bar(stat_pos + 2*width, closest_stats_norm,
       width, color=COL['CLOSEST'], label='closest summary statistics')
ax.set_ylabel(r'$\frac{f}{\sigma_{f \ PRIOR}}$')
ax.set_xlim(-1.5*width, n_summary_stats + width/2)
ax.set_xticks(stat_pos)
ax.yaxis.set_major_formatter(mpl.ticker.FormatStrFormatter('%.1f'))
ax.set_xticklabels(labels_sum_stats, rotation=45, ha='right')
ax.legend(bbox_to_anchor=(1.2, 1), loc='upper right')

# --- bottom row: scaled deviation from the observation ---
ax = plt.subplot(3, 2, (5, 6))
ax.bar(stat_pos, obs_stats_norm - post_stats_norm,
       width, color=COL['SNPE'], label='mode')
ax.bar(stat_pos + width, obs_stats_norm - closest_stats_norm,
       width, color=COL['CLOSEST'], label='closest summary statistics')
ax.set_ylabel(r'$\frac{f^*-f}{\sigma_{f \ PRIOR}}$')
ax.set_xlim(-1.5*width, n_summary_stats + width/2)
ax.set_xticks(stat_pos)
ax.yaxis.set_major_formatter(mpl.ticker.FormatStrFormatter('%.1f'))
ax.set_xticklabels(labels_sum_stats, rotation=45, ha='right');