# Step 3: Manual fits

We fit the submodel manually (per-curve least squares with `scipy.optimize.curve_fit`) to the steady-state (`inf`) and time-constant (`tau`) activation curves of the pow1 models. The curves were extracted by Chaitanya (using pyNeuroML) and are stored in a pickle at `support_files/manualfit_curves_complete_df.pkl`. 

This notebook does manual parameter fitting. 

In [None]:
import pickle 
import numpy as np

from scipy.optimize import curve_fit

In [None]:
def sigmoid(x, a, b):
    """Logistic curve used for the steady-state (``inf``) fits.

    Evaluates ``1 / (1 + exp(b - a*x))``: `a` sets the slope and ``b / a``
    is the midpoint where the curve equals 0.5.
    """
    exponent = b - a * x
    return 1 / (1 + np.exp(exponent))


def tau_fun(x, a, b, c, d, e, f):
    """Time-constant curve model used for the ``tau`` fits.

    With ``dv = x - a`` this returns
    ``exp(b) / (exp(-(log(c)*dv + log(d)*dv**2)) + exp(log(e)*dv + log(f)*dv**2))``.
    `a` shifts the curve along x and ``exp(b)`` scales it.  `c`, `d`, `e`, `f`
    enter through ``np.log``, so non-positive values give nan/-inf results.
    """
    dv = x - a
    rising = np.log(c) * dv + np.log(d) * dv ** 2
    falling = np.log(e) * dv + np.log(f) * dv ** 2
    return np.exp(b) / (np.exp(-rising) + np.exp(falling))


# Alias under which the tau model is referred to as the fitting function.
tau_fit_fcn = tau_fun

In [None]:
# most models have 61 data points
with open('./support_files/manualfit_curves_complete_df.pkl', 'rb') as fh:
    df = pickle.load(fh)

with open('./support_files/kvact_only.pkl', 'rb') as fh:
    kvactonly = pickle.load(fh)
pow1 = kvactonly['pow1']
pow4 = kvactonly['pow4']

# Grid every curve is fitted on; assumes each trace has 61 samples
# (see note above) -- TODO confirm for the models that differ.
V = np.linspace(-150, 150, 61)

# dats[channel_name] -> {'inf': [popt_sigmoid], 'tau': [popt_tau]}
dats = {}

# Not filled in this cell; retained in case later cells reference them.
infs = []
taus = []
names = []

for i in range(len(df)):
    row = df.iloc[i]
    path_parts = row['Name'].split('/')
    # Only channels under the 'K' directory whose name is in the pow1 set.
    if not (path_parts[2] == 'K' and path_parts[3] in pow1):
        continue

    name = path_parts[3]
    dats.setdefault(name, {})

    trace = row['Trace'][0].reshape(-1)
    # The three characters before the 4-character suffix select the curve:
    # 'inf' is fitted with `sigmoid`, anything else is treated as tau.
    curve_type = row['Name'][-7:-4]

    try:
        if curve_type == 'inf':
            popt_inf_ls, _ = curve_fit(sigmoid, V, trace)
            dats[name]['inf'] = [popt_inf_ls]
        else:
            # Initial guess: shift `a` at the peak of the curve, scale exp(b)
            # at its maximum; c/e start at exp(0.5), d/f at exp(0) = 1.
            p0 = [V[np.argmax(trace)], np.log(np.max(trace)),
                  np.exp(0.5), np.exp(0), np.exp(0.5), np.exp(0)]
            popt_tau_ls, _ = curve_fit(tau_fun, V, trace, p0=p0)
            dats[name]['tau'] = [popt_tau_ls]
    except Exception as err:
        # Best effort: report the failed fit and continue with the next curve.
        # (Previously a stale/undefined `popt_tau_ls` was appended here.)
        print('err processing')
        print(row)
        print(repr(err))
        continue

Store the results to a pickle.

In [None]:
# dats is a dict (keys are channel names) of dicts (keys are inf/tau)
# The context manager closes (and flushes) the file once the dump completes.
with open('./support_files/manualfit_params.pkl', 'wb') as fh:
    pickle.dump(dats, fh)

Manual fits successful for

In [None]:
# Count the channels that received both an 'inf' and a 'tau' fit.
np.sum([len(fits) == 2 for fits in dats.values()])

channel models. 

## Simulating responses using manually fitted parameters

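We simulate the manually fitted parameters with the OmniModel. For each channel we compute the correlation coefficient between the simulated and ground-truth traces, and the L2 distance between their summary statistics. Plots are generated for every channel whose correlation coefficient exceeds 0.9. 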

In [None]:
import dill as pickle
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

from box import Box
from delfi.distribution import TransformedNormal
from delfi.utils.viz import plot_pdf
from model.ChannelOmni import ChannelOmni
from model.ChannelOmniStats import ChannelOmniStats as ChannelStats
from tqdm import tqdm_notebook as tqdm
from support_files.pickle_macos import pickle_load

import sys; sys.path.append('../')
from common import col, svg, samples_nd

%matplotlib inline

In [None]:
# Simulator (m.gen_single) and summary-statistics calculator (s.calc).
m, s = ChannelOmni(), ChannelStats()

In [None]:
# Protocol matrices of the pow1 channels, indexed below as mats[protocol]['data'/'names'].
mats_raw = pickle_load('./support_files/pow1_mats_comp_lfs.pkl')
mats = Box(mats_raw)

In [None]:
# Number of channels = length of the first axis of the 'ap' data (372 at time of writing).
N_chans = len(mats['ap']['data'])

In [None]:
# Display the number of channels in the 'ap' protocol data.
N_chans

In [None]:
dataset = []

# Protocol order used both for the ground-truth dict and for concatenation.
sim_protocols = ['act', 'inact', 'deact', 'ap', 'ramp']

for idx_chan in tqdm(range(N_chans)):
    name_gt = mats['ap']['names'][idx_chan]
    datum = {'idx_mats': idx_chan, 'name': name_gt}

    # Manual fits are keyed by '<name>.mod' and need both an inf and a tau fit.
    try:
        fits = dats[name_gt + '.mod']
        params_manual_inf = fits['inf'][0]
        params_manual_tau = fits['tau'][0]
    except KeyError:
        print('no params found for {}'.format(name_gt))
        continue
    params_manual = np.hstack((params_manual_inf, params_manual_tau)).reshape(-1)
    datum['params_manual'] = params_manual

    try:
        trace_manual = m.gen_single(params_manual)

        # gt trace simulated with neuron
        # note that we are introducing an offset to make length of data match
        # using 6: for now, alternatively could do :-6
        trace_gt = {
            'v_' + protocol: {'data': mats[protocol]['data'][idx_chan, 6:, 1:].T}
            for protocol in sim_protocols
        }

        # Keep the traces: the plotting cell reads datum['traces_gt'] and
        # datum['traces_manual'] (this makes the saved pickle larger).
        datum['traces_manual'] = trace_manual
        datum['traces_gt'] = trace_gt

        # concat'ed timeseries over all protocols
        trace_manual_concat = np.concatenate(
            [trace_manual['v_' + protocol]['data'].reshape(-1) for protocol in sim_protocols])
        trace_gt_concat = np.concatenate(
            [trace_gt['v_' + protocol]['data'].reshape(-1) for protocol in sim_protocols])

        # calculate corrcoef
        corrcoef = np.corrcoef(trace_manual_concat, trace_gt_concat)[0, 1]
        datum['cc_manual'] = corrcoef

        # calculate L2 between summary statistics
        stats_gt = s.calc([trace_gt])
        stats_manual = s.calc([trace_manual])
        l2 = np.linalg.norm(stats_gt - stats_manual)
        datum['l2_manual'] = l2
    except Exception as err:
        # Best effort: report the failing channel (and why) and keep going.
        print('error with : {}'.format(idx_chan))
        print(repr(err))
        continue

    dataset.append(datum)


In [None]:
# save dataset
import os

# ./results is otherwise only created later in the notebook (plotting cell),
# so make sure it exists before writing the pickle.
os.makedirs('./results', exist_ok=True)
with open('./results/manual_fits_lfs.pkl', 'wb') as fh:
    pickle.dump(dataset, fh)

In [None]:
# range of manual params: parameter vectors of all fits with cc > 0.9
ccs = []
params_manual = [entry['params_manual'] for entry in dataset
                 if entry['cc_manual'] > 0.9]

In [None]:
# Per-parameter mean over the selected (cc > 0.9) manual fits.
np.mean(np.array(params_manual), axis=0)

In [None]:
import os

# For every channel whose manual fit reaches cc > 0.9, plot the ground-truth
# traces (top row, blue) above the traces simulated from the manual fit
# (bottom row, grey), one column per protocol, and save one SVG per channel.
protocols = ['ap', 'act', 'inact', 'deact', 'ramp']

# Create the output directory once, instead of running three `!mkdir -p`
# shell commands for every figure (and without relying on IPython magics).
os.makedirs('./results/manual_fit/svg', exist_ok=True)


def _downsample_step(protocol):
    """Return the plotting stride: every 10th sample for 'ap', every 100th otherwise."""
    return 10 if protocol == 'ap' else 100


def _set_color_cycle(cmap, cmap_r, n_traces):
    """Set the current axes' colour cycle for `n_traces` traces.

    A single trace uses the reversed map `cmap_r` at position 0; several
    traces are spread over the 0.3-1 range of `cmap`.
    """
    if n_traces == 1:
        colors = [cmap_r(c) for c in np.linspace(0., 1, n_traces)]
    else:
        colors = [cmap(c) for c in np.linspace(0.3, 1, n_traces)]
    plt.gca().set_prop_cycle('color', colors)


for i, datum in enumerate(dataset):
    # The comparison is False for NaN correlations, so those are skipped too.
    if not datum['cc_manual'] > 0.9:
        continue

    name_gt = datum['name']
    corrcoef = datum['cc_manual']
    # These keys exist only if the simulation cell stored the traces.
    trace_gt = datum['traces_gt']
    trace_manual = datum['traces_manual']

    with mpl.rc_context(fname='../.matplotlibrc'):
        plt.figure(figsize=(20/2.54, 5/2.54))  # 20 cm x 5 cm

        # Top row: ground-truth traces.
        for p, protocol in enumerate(protocols):
            key = 'v_' + protocol
            ds = _downsample_step(protocol)

            plt.subplot(2, 5, p + 1)
            if p == 0:
                plt.gca().set_title('Channel {} · cc={:.5f}'.format(name_gt, corrcoef),
                                    loc='left', pad=15,
                                    fontdict={'fontsize': 10})

            _set_color_cycle(plt.cm.Blues, plt.cm.Blues_r,
                             trace_gt[key]['data'].shape[0])
            # GT data was trimmed to the simulated length, so the simulated
            # time axis is reused for it.
            plt.plot(trace_manual[key]['time'][::ds],
                     trace_gt[key]['data'].T[::ds],
                     linewidth=1.)
            plt.xticks([])
            plt.yticks([])
            sns.despine(left=True, bottom=True, offset=5)

        # Bottom row: traces simulated from the manual fit, plus a 100 ms scale bar.
        for p, protocol in enumerate(protocols):
            key = 'v_' + protocol
            ds = _downsample_step(protocol)

            plt.subplot(2, 5, p + 6)
            _set_color_cycle(plt.cm.Greys, plt.cm.Greys_r,
                             trace_manual[key]['data'].shape[0])
            plt.plot(trace_manual[key]['time'][::ds],
                     trace_manual[key]['data'].T[::ds],
                     linewidth=1., alpha=1.0)
            plt.xticks([])
            plt.yticks([])
            sns.despine(left=True, bottom=True, offset=5)

            plt.plot([0., 100.], [-0.1, -0.1], color='k', linewidth=2)
            plt.text(0.0, -0.4, '100ms', fontsize=8)

        plt.savefig('./results/manual_fit/svg/{}.svg'.format(i))
        plt.close()