In [None]:
import delfi.distribution as dd

import matplotlib as mpl
import numpy as np
import pandas as pd
import time

from delfi.generator import Default
from delfi.utils.viz import plot_pdf
from lfimodels.channelomics.Channel import Channel
from lfimodels.channelomics.ChannelStats import ChannelStats
from matplotlib import pyplot as plt
%matplotlib inline

# Parameter values used to simulate the observation, one array per channel model.
GT = {'k': np.array([9, 25, 0.02, 0.002]),
      'na': np.array([-35, 9, 0.182, 0.124, -50, -75, 5, -65, 6.2, 0.0091, 0.024]),
      'ca': np.array([24.6, 11.3, -.031, 37.1, 2.5, 12.6, 18.9, 420]),
      'ih': np.array([.0015, .02, -87.7, -51.7, -.155, .144, 0.0067, .014, -94.2,
                      -35.5, -.075, .144, 3.086, 4.486e-05, 80, 8.94, 1e-05, 1, 0])}
# Parameter names, in the same order as the GT arrays; used to label posterior plots.
LP = {'k': ['qa', 'tha', 'Ra', 'Rb'],
      'na': ['tha', 'qa', 'Ra', 'Rb', 'thi1', 'thi2', 'qi', 'thinf', 'qinf', 'Rg', 'Rd'],
      'ca': ['p1', 'p2', 'p3', 'p4', 'p5', 'p6', 'p7', 'tau_z'],
      'ih': ['a0', 'b0', 'ah', 'bh', 'ac', 'bc', 'aa0', 'ba0', 'aah', 'bah', 'aac', 'bac',
             'kon', 'koff', 'b', 'bf', 'ai', 'gca', 'shift']}

channel_type = 'k'

gt = GT[channel_type]
labels_params = LP[channel_type]
# A length mismatch would silently attach the wrong names to the plotted marginals.
assert len(labels_params) == len(gt), (
    'GT and LP disagree on the number of parameters for channel {!r}'.format(channel_type))

# Uniform prior box spanning 50%-150% of each ground-truth value. Each row is sorted so
# that lower <= upper also holds for negative values (0.5*x > 1.5*x when x < 0).
# NOTE(review): a ground-truth value of exactly 0 (e.g. the last 'ih' entry) yields a
# zero-width interval for that parameter — confirm this is intended before using 'ih'.
lims = np.sort(np.column_stack((0.5 * gt, 1.5 * gt)), axis=1)
n_params = len(gt)
cython = True

m = Channel(channel_type=channel_type, n_params=n_params, cython=cython)
p = dd.Uniform(lower=lims[:, 0], upper=lims[:, 1])
s = ChannelStats(channel_type=channel_type)
g = Default(model=m, prior=p, summary=s)

# Simulate the observation at the ground-truth parameters and compute its summary stats.
xo = m.gen(gt.reshape(1, -1))
xo_stats = s.calc(xo[0])

In [None]:
from delfi.inference import SNPE

# Flags forwarded unchanged to the SNPE constructor below.
prior_norm = True
svi = False

# Network / mixture settings for SNPE, conditioned on the observed summary statistics.
snpe_kwargs = dict(obs=xo_stats,
                   n_components=2,
                   n_hiddens=[25, 25],
                   svi=svi,
                   prior_norm=prior_norm)
inf_snpe = SNPE(g, **snpe_kwargs)

# Single training round on 2500 simulations; returns (log, training data, posterior).
log, train_data, posterior = inf_snpe.run(n_train=2500, n_rounds=1)

In [None]:
import delfi.utils.io as io
import pickle

# Persist the trained inference object and its training log. File names are keyed by
# channel_type so runs for different channels do not overwrite each other
# (for channel_type == 'k' they are 'run_1_k.pkl' and 'run_1_log_k.pkl').
io.save(inf_snpe, 'run_1_{}.pkl'.format(channel_type))
# `with` guarantees the log file handle is flushed and closed.
with open('run_1_log_{}.pkl'.format(channel_type), 'wb') as log_file:
    pickle.dump(log, log_file)

In [None]:
# Query the fitted SNPE object at the observed summary statistics.
# NOTE(review): this overwrites the `posterior` returned by `inf_snpe.run(...)` in the
# training cell — presumably the same distribution after a single round; confirm.
posterior = inf_snpe.predict(xo_stats)

In [None]:
# Visualise the posterior; the ground-truth values, parameter names and prior-box limits
# are passed through to plot_pdf. Trailing ';' suppresses the return-value repr.
plot_pdf(posterior,
         lims=lims,
         labels_params=labels_params,
         gt=gt,
         ticks=True);

In [None]:
# Mean and covariance of the fitted (mixture) posterior.
mean, S = posterior.calc_mean_and_cov()

# Protocols whose simulated responses are compared below (keys of the simulator output).
prot = ['v_act', 'v_inact', 'v_deact', 'v_ap', 'v_ramp']
num_protocols = len(prot)

num_samp = 10

# Place num_samp parameter sets on the 1-standard-deviation contour around the mean
# (Mahalanobis distance 1): draw random directions on the unit sphere and map them
# through the Cholesky factor L of S (L @ L.T == S), so (x - mean)^T S^-1 (x - mean) == 1.
# Mapping through S itself would scale by the variance rather than the standard deviation
# and would not land on that contour. np.linalg.cholesky raises if S is not positive
# definite. To draw random posterior samples instead, use posterior.gen(n_samples=num_samp).
directions = np.random.randn(n_params, num_samp)
directions /= np.linalg.norm(directions, axis=0)
x_samp = np.dot(np.linalg.cholesky(S), directions).T + mean

# Samples that fall outside the prior box are projected onto its boundary.
x_samp = np.clip(x_samp, lims[:, 0], lims[:, 1])

# Row 0 is the posterior mean, rows 1..num_samp are the contour samples.
params = np.vstack([mean, x_samp])


def _plot_traces(ax, time, traces, title):
    """Plot every row of `traces` against `time` on `ax`, one viridis colour per row.

    Row k is drawn with colour viridis(k / n_rows).
    """
    n_rows = len(traces)
    colors = mpl.cm.viridis(np.arange(n_rows) / n_rows)
    for trace, color in zip(traces, colors):
        ax.plot(time, trace, color=color, lw=2)
    ax.set_xlabel('time (ms)')
    ax.set_ylabel('norm. current')
    ax.set_title(title)


# Grid: row 0 = observation, row 1 = posterior mean, row i + 1 = sample i.
# squeeze=False keeps `axes` 2-D even when there is a single protocol.
fig, axes = plt.subplots(2 + num_samp, num_protocols,
                         figsize=(20, 10 + num_samp * 5), squeeze=False)

obs = xo[0][0]
# `col`/`protocol` avoid shadowing the prior `p` defined in the setup cell.
for i, theta in enumerate(params):
    x = m.gen_single(theta)
    for col, protocol in enumerate(prot):
        traces = x[protocol]['data']
        t = x[protocol]['time']
        if i == 0:
            # The observation is drawn on the time axis of the simulated response,
            # using as many rows as the simulation produced.
            _plot_traces(axes[0, col], t, obs[protocol]['data'][:len(traces)], 'observation')
            _plot_traces(axes[1, col], t, traces, 'posterior mean')
        else:
            _plot_traces(axes[i + 1, col], t, traces, 'sample ' + str(i))

fig.tight_layout()