# Step 5: Model fitting

In [None]:
import os    
import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import dill as pickle
import matplotlib.pyplot as plt
import numpy as np
import os
import theano.tensor as tt
import time

from delfi.neuralnet.NeuralNet import NeuralNet
from delfi.neuralnet.Trainer import Trainer

import sys; sys.path.append('../')
from common import col, svg, plot_pdf

from model.ChannelOmni import ChannelOmni
from model.ChannelOmniStats import ChannelOmniStats as ChannelStats

In [None]:
def load_simulations(data_dir="./data/", n_params=8, n_stats=55):
    """Load and stack every ``*theta.npy`` / ``*stats.npy`` file under *data_dir*.

    Files of each kind are visited in order of their path with the suffix
    removed. The i-th parameter row therefore lines up with the i-th
    statistics row whenever each ``<prefix>theta.npy`` has a matching
    ``<prefix>stats.npy`` (TODO confirm this naming convention of the
    simulation output). The order is also reproducible across machines.

    Every file is flattened before concatenation, so files holding different
    numbers of rows can be mixed.

    Args:
        data_dir: root directory that is searched recursively.
        n_params: number of columns of the parameter array.
        n_stats: number of columns of the summary-statistics array.

    Returns:
        (params, stats): arrays of shape (N, n_params) and (N, n_stats);
        both are (0, width) when no matching files exist.

    Raises:
        ValueError: if parameters and statistics have different row counts.
    """
    paths = [
        os.path.join(root, name)
        for root, _dirs, files in os.walk(data_dir)
        for name in files
    ]

    def _stack(suffix, width):
        # Sort on the shared prefix so theta/stats files of one run get the same rank.
        matching = sorted(
            (path for path in paths if path.endswith(suffix)),
            key=lambda path: path[:-len(suffix)],
        )
        if not matching:
            return np.empty((0, width))
        chunks = [np.load(path).reshape(-1) for path in matching]
        return np.concatenate(chunks).reshape(-1, width)

    params = _stack('theta.npy', n_params)
    stats = _stack('stats.npy', n_stats)
    if params.shape[0] != stats.shape[0]:
        raise ValueError(
            'found {} parameter rows but {} summary-statistic rows under {!r}'.format(
                params.shape[0], stats.shape[0], data_dir))
    return params, stats


params, stats = load_simulations("./data/")

In [None]:
# Box prior bounds: one row per model parameter, column 0 = lower bound,
# column 1 = upper bound.
prior_lims = np.array([
    [0.0, 1.0],
    [-10.0, 10.0],
    [-120.0, 120.0],
    [0.0, 2000.0],
    [0.0, 0.5],
    [0.0, 0.05],
    [0.0, 0.5],
    [0.0, 0.05],
])

m = ChannelOmni(third_exp_model=False)                          # model
p = dd.Uniform(lower=prior_lims[:, 0], upper=prior_lims[:, 1])  # prior
s = ChannelStats()                                              # summary statistics
g = dg.Default(model=m, prior=p, summary=s)                     # generator

# APT with a MAF density estimator; the statistics of the first loaded
# simulation are passed as `obs`.
res = infer.APT(
    g,
    obs=stats[0, :],
    pilot_samples=(params, stats),
    prior_norm=True,
    impute_missing=False,
    density='maf',
    n_mades=5,
    n_hiddens=[250, 250],
    seed=101,
    verbose=True,
)

In [None]:
# A single training round, sized by the number of loaded parameter rows.
# `val_frac` holds out a validation split; `patience` presumably stops
# training after that many epochs without validation improvement (see delfi).
log, train_data, posterior = res.run(
    n_rounds=1,
    n_train=params.shape[0],
    proposal='mog',
    epochs=1000,
    minibatch=100,
    val_frac=0.05,
    patience=30,
    monitor_every=1,
    silent_fail=False,
    verbose=True,
)

In [None]:
# Persist the trained network and the whole inference object.
# `pickle` is dill here, so objects the stdlib pickle rejects can be stored.
os.makedirs('results/net_maf', exist_ok=True)
with open('results/net_maf/net.pkl', 'wb') as fh:
    pickle.dump(res.network, fh)
with open('results/net_maf/res.pkl', 'wb') as fh:
    pickle.dump(res, fh)

## Simulate at the posterior mode of each channel, plot, and compute correlations

In [None]:
import dill as pickle
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import os
import seaborn as sns
import sys; sys.path.append('../')

from box import Box
from common import plot_pdf, samples_nd, col, svg
from support_files.pickle_macos import pickle_load
from tqdm import tqdm

%matplotlib inline

# support
# Curve-fit parameter values (file handle closed deterministically).
with open('support_files/manualfit_params.pkl', 'rb') as fh:
    dats = pickle.load(fh)
# Measured data, indexed downstream as mats[protocol]['data'] / mats['ap']['names'].
mats = Box(pickle_load('support_files/pow1_mats_comp.pkl'))

# model
protocols = ['ap', 'act', 'inact', 'deact', 'ramp']

In [None]:
# Mode sample
import cma

# [lower, upper] bounds per parameter, taken from the prior object.
prior_lims = np.vstack((p.lower, p.upper)).T


def mode_sample(posterior_, init=None, rate=0.01, lims=None):
    """Search for the maximum of ``posterior_.eval`` inside a box with CMA-ES.

    Despite the name, this returns the best point CMA-ES finds (an estimate
    of the posterior mode), not a random sample.

    Args:
        posterior_: object exposing ``eval(x)`` for a single parameter vector
            (assumed to return a scalar density or log-density — TODO confirm).
        init: starting point; defaults to the centre of ``lims``.
        rate: initial CMA-ES step size (``sigma0``).
        lims: (n_params, 2) array of [lower, upper] bounds; defaults to the
            module-level ``prior_lims``. Bounds and default start point are
            now always derived from the same array at call time.

    Returns:
        ``es.best.x``, the best candidate found by the optimiser.
    """
    if lims is None:
        lims = prior_lims
    lims = np.asarray(lims)
    lower, upper = lims[:, 0], lims[:, 1]
    if init is None:
        init = (upper + lower) / 2

    es = cma.CMAEvolutionStrategy(
        init,
        rate,
        {'scaling_of_variables': (upper - lower),
         'bounds': [list(lower), list(upper)]
         })
    # CMA-ES minimises, so negate the posterior to find its maximum.
    es.optimize(lambda x: -1. * posterior_.eval(x))
    es.result_pretty()
    return es.best.x

In [None]:
# Inference
# For every recorded channel: compute summary statistics of the recorded
# traces, condition the trained network on them (`res.predict`), locate the
# posterior mode with `mode_sample`, simulate the model there and compare the
# simulation with the recording. Collected results:
#   ccs  - Pearson correlation of concatenated simulated vs. recorded data
#   l2s  - Euclidean distance between recorded and simulated summary statistics
#   idxs - channel index belonging to each ccs/l2s entry (lists stay aligned)
#   errs - (channel index, exception) for every channel that failed
# One figure per channel is written to ./results/net_maf/svg/<index>.svg.
errs = []
sams = 1

N_chans = mats['ap']['data'].shape[0]

idxs = []
ccs = []
l2s = []

# Column order of the figure.
protocols = ['ap', 'act', 'inact', 'deact', 'ramp']
# Order in which the protocol traces are concatenated for the correlation.
concat_keys = ['v_act', 'v_inact', 'v_deact', 'v_ap', 'v_ramp']

svg_dir = './results/net_maf/svg'
# Create the output directory once instead of shelling out for every channel.
os.makedirs(svg_dir, exist_ok=True)


def _concat_traces(trace):
    """Flatten each protocol's ``'data'`` array and join them in ``concat_keys`` order."""
    return np.concatenate([trace[key]['data'].reshape(-1) for key in concat_keys])


def _plot_protocol_row(data_trace, time_trace, color, first_subplot,
                       title=None, scale_bar=False, **plot_kw):
    """Draw one row of the 2x5 figure grid, one panel per entry of ``protocols``.

    ``data_trace`` supplies the y-values and the palette size, ``time_trace``
    the x-values (its ``'time'`` arrays). ``first_subplot`` is the 1-based
    subplot index of the row's first panel. Extra keyword arguments are passed
    on to ``plt.plot``.
    """
    for i_prot, protocol in enumerate(protocols):
        key = 'v_' + protocol
        ds = 10 if protocol == 'ap' else 100  # plot only every ds-th sample
        plt.subplot(2, 5, first_subplot + i_prot)

        if title is not None and i_prot == 0:
            plt.gca().set_title(title, loc='left', pad=15,
                                fontdict={'fontsize': 8})

        n_traces = data_trace[key]['data'].shape[0]
        # A lone trace uses reverse=True, presumably so it is drawn in `color`
        # itself rather than the pale end of the palette.
        palette = sns.light_palette(color, n_traces, reverse=(n_traces == 1))
        plt.gca().set_prop_cycle('color', palette)

        plt.plot(time_trace[key]['time'][::ds],
                 data_trace[key]['data'].T[::ds],
                 linewidth=1., **plot_kw)

        plt.xticks([])
        plt.yticks([])
        sns.despine(left=True, bottom=True, offset=5)

        if scale_bar:
            plt.plot([0., 100.], [-0.1, -0.1], color='k', linewidth=2)
            plt.text(0.0, -0.4, '100ms', fontsize=8)


for idx_chan in tqdm(range(N_chans)):
    try:
        name_gt = mats['ap']['names'][idx_chan]

        # Skip the first 6 entries along axis 1 and the first along axis 2 of
        # each recording, then transpose (TODO confirm what is being dropped).
        trace_gt = {
            'v_act':   {'data': mats['act']['data'][idx_chan,   6:, 1:].T},
            'v_inact': {'data': mats['inact']['data'][idx_chan, 6:, 1:].T},
            'v_deact': {'data': mats['deact']['data'][idx_chan, 6:, 1:].T},
            'v_ap':    {'data': mats['ap']['data'][idx_chan,    6:, 1:].T},
            'v_ramp':  {'data': mats['ramp']['data'][idx_chan,  6:, 1:].T},
        }

        stats_gt = s.calc([trace_gt])
        posterior_normed_gt = res.predict(stats_gt.astype(np.float32))

        # Simulate at the posterior mode.
        params_sam = np.asarray(mode_sample(posterior_normed_gt)).reshape(-1)
        trace_sam = m.gen_single(params_sam)
        stats_sam = s.calc([trace_sam])

        corrcoef = np.corrcoef(_concat_traces(trace_sam),
                               _concat_traces(trace_gt))[0, 1]
        l2 = np.linalg.norm(stats_gt - stats_sam)

        # Record all metrics together so ccs, l2s and idxs stay index-aligned.
        ccs.append(corrcoef)
        l2s.append(l2)
        idxs.append(idx_chan)

        with mpl.rc_context(fname='../.matplotlibrc'):
            f = 0.6  # figure scale; sizes below are presumably in cm (/2.54 -> inch)
            plt.figure(figsize=(f*20/2.54, f*5/2.54))
            try:
                # Top row: recording; bottom row: simulation at the mode.
                _plot_protocol_row(
                    trace_gt, trace_sam, col['GT'], first_subplot=1,
                    title='Channel {} · cc={:.3f}'.format(name_gt, corrcoef))
                _plot_protocol_row(
                    trace_sam, trace_sam, col['CONSISTENT1'], first_subplot=6,
                    scale_bar=True, alpha=1.0)
                plt.savefig(os.path.join(svg_dir, '{}.svg'.format(idx_chan)))
            finally:
                # Close even if plotting fails so figures don't accumulate.
                plt.close()
    except Exception as err:
        # Best effort: skip the failing channel, but keep a record instead of
        # silently discarding it (and let KeyboardInterrupt stop the loop).
        errs.append((idx_chan, err))
        tqdm.write('channel {}: {}: {}'.format(idx_chan, type(err).__name__, err))

In [None]:
# Persist the per-channel metrics; each list goes to results/net_maf/<name>.pkl.
# `with` closes every file handle right after writing.
os.makedirs('results/net_maf', exist_ok=True)
for metric_name, metric_values in (('ccs', ccs), ('l2s', l2s), ('idxs', idxs)):
    with open('results/net_maf/{}.pkl'.format(metric_name), 'wb') as fh:
        pickle.dump(metric_values, fh)