In [None]:
import delfi.distribution as dd

import matplotlib as mpl
import numpy as np
import pandas as pd
import time

from delfi.generator import Default
from delfi.utils.viz import plot_pdf
from lfimodels.channelomics.Channel import Channel
from lfimodels.channelomics.ChannelStats import ChannelStats
from matplotlib import pyplot as plt
%matplotlib inline

# Ground-truth parameter vectors per channel type. The entry order must match
# the parameter names listed in LP for the same channel type.
GT = {'k': np.array([9, 25, 0.02, 0.002]),
      'na': np.array([-35, 9, 0.182, 0.124, -50, -75, 5, -65, 6.2, 0.0091, 0.024])}
# Parameter names per channel type (used as axis labels when plotting).
LP = {'k': ['qa', 'tha', 'Ra', 'Rb'],
      'na': ['tha', 'qa', 'Ra', 'Rb', 'thi1', 'thi2', 'qi', 'thinf', 'qinf', 'Rg', 'Rd']}

channel_type = 'k'  # alternative: 'na'


def prior_limits(center, rel_width=0.5):
    """Return an (n, 2) array of [lower, upper] bounds around ``center``.

    Each row is ``center * (1 - rel_width)`` and ``center * (1 + rel_width)``,
    sorted so that lower <= upper also holds for negative entries.

    Parameters
    ----------
    center : array_like, shape (n,)
        Reference values (here: ground-truth parameters).
    rel_width : float
        Half-width of the interval relative to ``|center|``. The default of 0.5
        gives the same limits as a [0.5*x, 1.5*x] box.
    """
    center = np.asarray(center, dtype=float)
    bounds = np.column_stack(((1 - rel_width) * center, (1 + rel_width) * center))
    # Row-wise sort: for negative centers the "(1 + w)" bound is the smaller one.
    return np.sort(bounds, axis=1)


gt = GT[channel_type]
labels_params = LP[channel_type]
n_params = len(gt)
# Guard against the two dicts drifting apart; a mismatch would mislabel plots.
assert len(labels_params) == n_params, 'GT and LP disagree on the number of parameters'
lims = prior_limits(gt)

# Assemble the simulator, the uniform box prior over `lims`, the summary
# statistics, and the delfi generator that combines them.
m = Channel(channel_type=channel_type,
            n_params=n_params)
p = dd.Uniform(lower=lims[:, 0],
               upper=lims[:, 1])
s = ChannelStats(channel_type=channel_type)
g = Default(model=m,
            prior=p,
            summary=s)

# Draw one sample from the generator; `out` is not used in the visible cells,
# so this presumably serves as a quick check that the pipeline runs.
out = g.gen(1)

In [None]:
from delfi.inference import Basic

# Fit delfi's basic (non-sequential) inference on simulations drawn from `g`.
inf_basic = Basic(
    generator=g,
    n_hiddens=[25, 25],  # hidden-layer sizes of the density network
    n_components=2,      # number of mixture components -- TODO confirm in delfi docs
)
log, train_data = inf_basic.run(n_train=2500)

# Author's note: this cell took ~13.5 h for 2500 simulations.

In [None]:
import delfi.utils.io as io
import pickle

# Persist the expensive training result and its log so the long run does not
# have to be repeated. Only reload these pickles from trusted locations.
RUN_NAME = 'run_1'
io.save(inf_basic, RUN_NAME + '.pkl')
# A context manager guarantees the log file is flushed and closed, even if
# pickling raises; the previous bare open() leaked the handle.
with open(RUN_NAME + '_log.pkl', 'wb') as log_file:
    pickle.dump(log, log_file)

In [None]:
# Simulate at the ground-truth parameters, passed as a single row of shape
# (1, n_params), and reduce the first returned simulation to summary statistics.
xo = m.gen(
    gt.reshape(1, -1)
)
xo_stats = s.calc(
    xo[0]
)

# Condition the trained estimator on the observed statistics.
posterior = inf_basic.predict(xo_stats)

In [None]:
# Plot the posterior within the prior limits, marking the ground truth.
plot_pdf(
    posterior,
    lims=lims,
    gt=gt,
    labels_params=labels_params,
    ticks=False,
);