In [None]:
import time

import matplotlib as mpl
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

import delfi.distribution as dd
from delfi.generator import Default
from delfi.inference import SNPE
from delfi.utils.viz import plot_pdf

from lfimodels.channelomics.ChannelMPGenerator import ChannelMPGenerator
from lfimodels.channelomics.ChannelSingle import ChannelSingle
from lfimodels.channelomics.ChannelStats import ChannelStats
from lfimodels.channelomics.ChannelSuper import ChannelSuper

%matplotlib inline

In [None]:
# Ground-truth parameter vectors and their labels, keyed by channel model.
GT = {'kd': np.array([4, -63, 0.032, 15, 5, 0.5, 10, 40]),
      'kslow': np.array([1, 35, 10, 3.3, 20])}
LP = {'kd': ['power', 'vt', 'Ra', 'tha', 'sa', 'Rb', 'thb', 'sb'],
      'kslow': ['power', 'v_shift', 'v_scale', 'tau_scale', 'tau_shift']}

# Per-channel settings (defined here, not used further in this cell).
E_channel = dict(kd=-90.0, kslow=-90.0)
fact_inward = dict(kd=1, kslow=1)

# Which channel model to fit, and whether ChannelSingle uses its cython backend.
channel_type = 'kd'
cython = True

gt = GT[channel_type]
labels_params = LP[channel_type]
n_params = gt.size

# Uniform prior box from 30% to 130% of each ground-truth value. Sorting each
# row keeps lower <= upper when the ground truth is negative (e.g. vt = -63).
lims = np.sort(np.column_stack((0.3 * gt, 1.3 * gt)), axis=1)

# Simulator, prior, summary statistics and the generator that ties them together.
m = ChannelSingle(channel_type=channel_type, n_params=n_params, cython=cython)
p = dd.Uniform(lower=lims[:, 0], upper=lims[:, 1])
s = ChannelStats(channel_type=channel_type)
g = Default(model=m, prior=p, summary=s)

# "Observed" data: a separate simulator instance run once at the ground truth.
n_params_obs = n_params
m_obs = ChannelSingle(channel_type=channel_type, n_params=n_params_obs, cython=cython)
xo = m_obs.gen(gt[np.newaxis, :])
xo_stats = s.calc(xo[0])

In [None]:
# Fit the posterior with SNPE: a mixture-density network (2 components,
# two hidden layers of 50 units) trained on 1000 simulations in one round,
# conditioned on the observed summary statistics `xo_stats`.
svi = False         # passed through to SNPE as its `svi` option
prior_norm = True   # passed through to SNPE as its `prior_norm` option
inf_snpe = SNPE(g, obs=xo_stats, pilot_samples=100, n_components=2,
                n_hiddens=[50, 50], svi=svi, prior_norm=prior_norm)
log, train_data, posterior = inf_snpe.run(n_train=1000, n_rounds=1)

In [None]:
# Evaluate the trained network at the observed summary statistics to obtain the
# posterior used by the cells below; this rebinds the `posterior` from `run()`.
# NOTE(review): presumably run() returns per-round posteriors rather than a single
# distribution, which is why predict() is called here — confirm for this delfi version.
posterior = inf_snpe.predict(xo_stats)

In [None]:
# Marginal and pairwise posterior densities with the ground truth marked.
# Labels follow the selected `channel_type` (previously hard-coded to 'kd',
# which mislabels the axes for any other channel).
plot_pdf(posterior, gt=gt, labels_params=labels_params,
         lims=lims, figsize=(18, 15));

In [None]:
# Posterior mean vector and covariance matrix.
mean, S = posterior.calc_mean_and_cov()

# Protocols to simulate and plot (keys of the simulator's output dict).
prot = ['v_act', 'v_inact', 'v_deact', 'v_ap', 'v_ramp']
num_protocols = len(prot)

# Number of parameter sets drawn on the 1-standard-deviation contour.
num_samp = 10
SEED = 42  # fixes the contour directions so the figure is reproducible
rng = np.random.RandomState(SEED)

# Draw random directions on the unit sphere and map them onto the ellipsoid at
# Mahalanobis distance 1 from the mean: x = mean + L u with L L^T = S.
# Multiplying by S itself would scale each direction by a variance instead of
# a standard deviation, so the points would not lie on any density contour.
directions = rng.randn(n_params, num_samp)
directions /= np.linalg.norm(directions, axis=0)
chol_S = np.linalg.cholesky(S)
x_samp = chol_S.dot(directions).T + mean

# Samples falling outside the uniform prior box are projected onto its boundary.
x_samp = np.clip(x_samp, lims[:, 0], lims[:, 1])

# Row 0: posterior mean; rows 1..num_samp: contour samples.
params = np.vstack((mean, x_samp))

# Layout: row 0 = observed traces, row 1 = simulation at the posterior mean,
# rows 2.. = simulations at the contour samples; one column per protocol.
fig, axes = plt.subplots(2 + num_samp, num_protocols,
                         figsize=(20, 10 + num_samp * 5), squeeze=False)


def _plot_protocol(ax, t, I, title):
    """Plot every row of ``I`` (one trace per stimulus level) against ``t`` on ``ax``.

    Rows are colored along the viridis colormap in level order; axes are
    labelled and titled with ``title``.
    """
    cmap = mpl.cm.viridis
    num_levels = I.shape[0]
    for j in range(num_levels):
        ax.plot(t, I[j], color=cmap(1. * j / num_levels), lw=2)
    ax.set_xlabel('time (ms)')
    ax.set_ylabel('norm. current')
    ax.set_title(title)


# Observation: drawn against its own time vector and with its own level count.
# The loop variable is not named `p`, so the prior `p` is no longer clobbered.
obs_traces = xo[0][0]
for col, protocol in enumerate(prot):
    _plot_protocol(axes[0, col], obs_traces[protocol]['time'],
                   obs_traces[protocol]['data'], 'observation')

# Simulations: params[0] is the posterior mean (not the mode), params[1:] the
# contour samples, numbered 1..num_samp in the order they were drawn.
for i, theta in enumerate(params):
    sim_traces = m.gen_single(theta)
    title = 'posterior mean' if i == 0 else 'sample ' + str(i)
    for col, protocol in enumerate(prot):
        _plot_protocol(axes[i + 1, col], sim_traces[protocol]['time'],
                       sim_traces[protocol]['data'], title)

fig.tight_layout()