# Synaptic stimulation protocols tests

In [None]:
# Enable interactive plots with backend 'notebook'
%matplotlib notebook
# %matplotlib inline
# Enable plotting of figures after exceptions
import bgcellmodels.common.jupyterutil as jupyterutil
jupyterutil.notebook_show_figs_after_exception()

# print code version (hash of checked out version)
!git log -1 --format="%H"

# print date and time of script execution
import os, datetime
gillies_model_dir = '../../../GilliesWillshaw'
os.chdir(gillies_model_dir)
print("\nNotebook executed at {} in following directory:\n{}".format(
        datetime.datetime.now(), os.getcwd()))

## Import optimization module

In [None]:
# Python standard library
import pickle, pprint
# Pretty-printer (2-space indent) for interactive inspection of nested results
pp = pprint.PrettyPrinter(indent=2)
from bgcellmodels.common import logutils

# BluePyOpt
import bluepyopt.ephys as ephys

# Custom BluePyOpt modules
# NOTE(review): import order is kept as-is in case these modules load NEURON
# mechanisms with order-dependent side effects — TODO confirm before reordering
from cersei_cellmodel import StnCellReduced
from optimize.bpop_protocols_stn import BpopProtocolWrapper
from optimize.bpop_analysis_stn import (
    run_proto_responses, plot_proto_responses, 
    save_proto_responses, load_proto_responses,
    plot_responses
)

import optimize.bpop_features_stn as features_stn
# Uncomment the three lines below to live-reload the features module while editing it:
# %load_ext autoreload
# %autoreload 1
# %aimport optimize.bpop_features_stn as features_stn

# Physiology parameters
from evalmodel.cellpopdata import StnModel
from evalmodel.proto_common import StimProtocol
# Short alias: members such as SP.SYN_BACKGROUND_LOW select the protocols used below
SP = StimProtocol


In [None]:
# Silence the chattiest loggers (reduction algorithms and BluePyOpt internals)
quiet_loggers = [
    'marasco',
    'folding',
    'redops',
    'bluepyopt.ephys.parameters',
    'bluepyopt.ephys.recordings',
]
logutils.setLogLevel('quiet', quiet_loggers)

# Full Model
## Create Protocols

In [None]:
# Protocol to use for optimisation (SETPARAM)
opt_proto = SP.SYN_BACKGROUND_LOW
proto_kwargs = { # SETPARAM: extra keyword arguments for validation protocol
    'impl_proto': opt_proto,
    'base_seed': 8,
    'num_syn_gpe': 12,
}
# Key the wrapper by `opt_proto` itself so that changing the protocol above
# cannot leave a stale, mismatching dict key behind.
stimprotos_wrappers = {
    opt_proto: BpopProtocolWrapper.make(opt_proto, **proto_kwargs)
}
# Materialise as lists: later cells index these (e.g. proto_wrappers[0]),
# which only works on Python 2 lists, not on Python 3 dict views.
proto_wrappers = list(stimprotos_wrappers.values())
opt_stim_protocols = list(stimprotos_wrappers.keys())
ephys_protos = [p.ephys_protocol for p in proto_wrappers]

# Collect all frozen mechanisms and parameters required for protocols to work
proto_mechs, proto_params = BpopProtocolWrapper.all_mechs_params(proto_wrappers)

## (Optional) Inspect model

In [None]:
# Build the full (unreduced) model purely for interactive inspection
cell_model = StnCellReduced(
    reduction_method=None,
    name='StnGillies',
    mechs=proto_mechs,
    params=proto_params,
)

# Fixed-step simulator (no CVode)
nrnsim = ephys.simulators.NrnSimulator(dt=0.025, cvode_active=False)
param_values = {}

# Freeze parameters, then instantiate the cell between the first protocol's
# pre- and post-instantiation hooks so its stimuli/synapses are attached
inspect_proto = ephys_protos[0]
cell_model.freeze(param_values)
inspect_proto.pre_model_instantiate(cell_model=cell_model, sim=nrnsim)
cell_model.instantiate(sim=nrnsim)
inspect_proto.post_model_instantiate(cell_model=cell_model, sim=nrnsim)

# Importing neuron.gui starts the NEURON GUI. Synapses can then be browsed via:
# Tools > ModelView > Soma > Point Processes > GABAsyn & GLUsyn
from neuron import gui

## Run Protocols

In [None]:
# Simulate every protocol on the full (unreduced) model to get target responses
full_model = StnCellReduced(
    reduction_method=None,
    name='StnGillies',
    mechs=proto_mechs,
    params=proto_params,
)

nrnsim = ephys.simulators.NrnSimulator(dt=0.025, cvode_active=False)

full_responses = {}
for ephys_proto in ephys_protos:
    # Ensure the protocol's recording functions are executed during the run
    ephys_proto.record_contained_traces = True

    response = ephys_proto.run(
        cell_model=full_model,
        param_values={},
        sim=nrnsim,
        isolate=False,
    )
    full_responses[ephys_proto.name] = response

In [None]:
# Plot all full-model responses, then the extra traces each protocol recorded
# (enabled via record_contained_traces in the previous cell)
plot_proto_responses(full_responses)
for ephys_proto in ephys_protos:
    ephys_proto.plot_contained_traces()

## Calculate Feature Targets

In [None]:
# Create the eFEL feature objects for every protocol wrapper, then compute
# their target values from the full-model responses.
stimprotos_feats = features_stn.make_opt_features(
    proto_wrappers)

features_stn.calc_feature_targets(
    stimprotos_feats,
    full_responses)

# Reduced Model
## Make Protocols

In [None]:
# Rebuild the protocol wrappers with the same settings as for the full model
# (`opt_proto`, `proto_kwargs`), presumably so the reduced model does not reuse
# protocol objects that were already run on the full model.
# To test another protocol (e.g. SP.SYN_BACKGROUND_HIGH), change `opt_proto`
# in the full-model section so the feature targets stay consistent.
stimprotos_wrappers = {
    opt_proto: BpopProtocolWrapper.make(opt_proto, **proto_kwargs)
}
# Lists rather than dict views: later cells index them by position
proto_wrappers = list(stimprotos_wrappers.values())
opt_stim_protocols = list(stimprotos_wrappers.keys())
ephys_protos = [p.ephys_protocol for p in proto_wrappers]

# Collect all frozen mechanisms and parameters required for protocols to work
proto_mechs, proto_params = BpopProtocolWrapper.all_mechs_params(proto_wrappers)

## Run Reduced Model

In [None]:
# Build the reduced model (BushSejnowski reduction) and simulate every protocol on it
red_model = StnCellReduced(
    reduction_method='BushSejnowski',
    name='StnFolded',
    mechs=proto_mechs,
    params=proto_params,
)

nrnsim = ephys.simulators.NrnSimulator(dt=0.025, cvode_active=False)

red_responses = {}
for ephys_proto in ephys_protos:
    # Ensure the protocol's recording functions are executed during the run
    ephys_proto.record_contained_traces = True

    # NOTE: isolate=False only if model not previously built
    response = ephys_proto.run(
        cell_model=red_model,
        param_values={},
        sim=nrnsim,
        isolate=False,
    )
    red_responses[ephys_proto.name] = response

In [None]:
# Plot all reduced-model responses, then the extra traces each protocol recorded
plot_proto_responses(red_responses)
for ephys_proto in ephys_protos:
    ephys_proto.plot_contained_traces()

## Calculate Feature Distances

In [None]:
# Score the reduced-model responses against the feature targets computed from
# the full model, optionally rescaling each feature's exp_std.
#
# NOTE: score = distance = sum(feat[i] - exp_mean) / N / exp_std  => so exp_std determines weight.
# Assuming exp_std is still 1.0 here, the score equals the numerator, so setting
# exp_std = score / weight makes that feature contribute its desired weight.
ADJUST_EXP_STD = False  # SETPARAM: True to rescale exp_std in place (mutates the features)

for stimproto, featdict in stimprotos_feats.items():
    for efeat, weight in featdict.values():
        score = efeat.calculate_score(red_responses[stimproto.name])
        if ADJUST_EXP_STD:
            # Divide the numerator so the feature has the desired weight
            efeat.exp_std = score / weight

        print('Calculates {} score: {}'.format(efeat.name, score))

## Calculate ISI Voltage Distance

In [None]:
from optimize.efeatures_fast_numba import calc_ISI_voltage_distance_dt_equal
import efel
import logging
# numpy is only imported in a later cell; import it here so the cell works on
# Restart & Run All
import numpy as np

logger = logging.getLogger(__name__)

def get_ISI_Vdist(tvresp1, tvresp2, proto):
    """
    Compute the ISI voltage distance between two TimeVoltageResponses.

    eFEL locates the AP begin/end indices of both traces within
    ``proto.response_interval``; traces and indices are then passed to
    ``calc_ISI_voltage_distance_dt_equal``.

    @param tvresp1  target (reference) response, indexable by 'time' and 'voltage'
    @param tvresp2  response that is compared against the target
    @param proto    protocol wrapper whose ``response_interval`` is (start, stop)

    @return         distance computed by calc_ISI_voltage_distance_dt_equal,
                    or NaN if eFEL could not compute the AP indices

    @raise ValueError  if the two traces were sampled with different time steps
    """
    efel_feats = ['AP_begin_indices', 'AP_end_indices']
    efel_traces = []
    feat_vals = []

    for tvresp in (tvresp1, tvresp2):
        efel_trace = {
            'T': tvresp['time'],
            'V': tvresp['voltage'],
            'stim_start': [proto.response_interval[0]],
            'stim_end': [proto.response_interval[1]],
        }
        efel_traces.append(efel_trace)
        # One result dict per input trace; we pass a single trace
        feat_vals.append(efel.getFeatureValues(
            [efel_trace],
            efel_feats,
            raise_warnings=True
        )[0])

    tar_feats, cur_feats = feat_vals
    tar_AP_begin = tar_feats['AP_begin_indices']
    tar_AP_end   = tar_feats['AP_end_indices']
    cur_AP_begin = cur_feats['AP_begin_indices']
    cur_AP_end   = cur_feats['AP_end_indices']

    # np.asarray accepts pandas.Series as well as numpy arrays and makes the
    # [0]/[1] lookups positional rather than label-based
    tar_T  = np.asarray(efel_traces[0]['T'])
    tar_Vm = np.asarray(efel_traces[0]['V'])
    cur_T  = np.asarray(efel_traces[1]['T'])
    cur_Vm = np.asarray(efel_traces[1]['V'])
    tar_dt = tar_T[1] - tar_T[0]
    cur_dt = cur_T[1] - cur_T[0]

    dt_equal = abs(tar_dt - cur_dt) <= 0.00001
    if not dt_equal:
        raise ValueError("ISI voltage distance only implemented for traces calculated with equal time step (dt_old={}, dt_new={}).".format(tar_dt, cur_dt))

    # eFEL returns None for features it cannot compute. Compare against the
    # abstract np.integer: comparing with builtin `int` only matches np.int_
    # on recent NumPy and would reject int32 index arrays.
    ap_indices = (tar_AP_begin, tar_AP_end, cur_AP_begin, cur_AP_end)
    if not all(v is not None and np.issubdtype(np.asarray(v).dtype, np.integer)
               for v in ap_indices):
        logger.warning("Calculation of AP indices failed")
        efel.reset()
        return float('NaN')

    return calc_ISI_voltage_distance_dt_equal(
                            tar_Vm, cur_Vm, 
                            tar_AP_begin, cur_AP_begin,
                            tar_AP_end, cur_AP_end,
                            proto.response_interval[0], proto.response_interval[1], tar_dt)

# Compare the first recording of the first protocol for full vs. reduced model
resp1 = next(iter(full_responses.values()))
resp2 = next(iter(red_responses.values()))
dist = get_ISI_Vdist(next(iter(resp1.values())),
                     next(iter(resp2.values())),
                     list(proto_wrappers)[0])
print(dist)

## Calculate PSTH

In [None]:
%load_ext autoreload
%autoreload 1
%aimport bgcellmodels.common.analysis
import numpy as np
import matplotlib.pyplot as plt
import efel; efel.reset()

resp_dict = red_responses[opt_stim_protocols[0].name]
TVresp = resp_dict.items()[0][1]
stim_start, stim_end = 300.0, 1800.0

efel_trace = {
    'T': TVresp['time'],
    'V': TVresp['voltage'],
    'stim_start': [stim_start],
    'stim_end': [stim_end],
}

# Get spike times using eFEL
efel_feat = 'peak_time'
feat_vals = efel.getFeatureValues(
    [efel_trace],
    [efel_feat],
    raise_warnings = True
)
resp_spike_times = feat_vals[0][efel_feat]
print(type(resp_spike_times), resp_spike_times)

# Compute psth/rates

bin_width = 50.0
min_spk = 2

psth1 = common.analysis.nrn_sum_psth(
                [resp_spike_times], 
                stim_start, stim_end,
                binwidth=bin_width).as_numpy()

rates1 = common.analysis.nrn_avg_rate_adaptive(
                [resp_spike_times], 
                stim_start, stim_end,
                binwidth=bin_width,
                minsum=min_spk).as_numpy()

print(psth1)
print(rates1)
print('Exptected num bins = (tstop-tstart)/binwidth + 2 = {}'.format(int((stim_end-stim_start)/bin_width) + 2))
print('Got num bins: {}'.format(psth1.size))


psth2 = common.analysis.numpy_sum_psth(
                [resp_spike_times], 
                stim_start, stim_end,
                binwidth=bin_width)

rates2 = common.analysis.numpy_avg_rate_simple(
                [resp_spike_times], 
                stim_start, stim_end,
                bin_width)

print(psth2)
print(rates2)

In [None]:
# Plot psth/rates computed with the NEURON-based helpers in the previous cell.
# (`psth`/`rates` were never defined; the previous cell names them psth1/rates1.)
# NOTE(review): the x-axis keeps the original spacing of bin_width/2 between
# consecutive values — TODO confirm against the bin layout of nrn_sum_psth.
fig_psth, ax_psth = plt.subplots()
ax_psth.plot(stim_start + np.arange(psth1.size) * (bin_width / 2), psth1)
ax_psth.set(title='PSTH (nrn_sum_psth)', xlabel='time (ms)', ylabel='PSTH')

fig_rates, ax_rates = plt.subplots()
ax_rates.plot(stim_start + np.arange(rates1.size) * (bin_width / 2), rates1)
ax_rates.set(title='Rate (nrn_avg_rate_adaptive)', xlabel='time (ms)', ylabel='rate');