# Synaptic stimulation protocols tests

In [None]:
# Enable interactive plots with backend 'notebook'
# %matplotlib notebook
%matplotlib inline

# print code version (hash of checked out version)
!git log -1 --format="%H"

# print date and time of script execution
import datetime
print("\nNotebook executed at {} in following directory:".format(datetime.datetime.now()))
%cd ..

## Import optimization modules

In [None]:
# Python standard library
import pickle, pprint
# Pretty-printer for interactively inspecting nested protocol/feature dicts
pp = pprint.PrettyPrinter(indent=2)

# Distributed logging
from common import logutils

# BluePyOpt
import bluepyopt.ephys as ephys

# Custom BluePyOpt modules
from optimize.bpop_cellmodels import StnFullModel, StnReducedModel
from optimize.bpop_protocols_stn import BpopProtocolWrapper
# NOTE(review): of these helpers only plot_proto_responses is used in the cells
# visible here; the others (and `pickle`) look like leftovers - confirm before pruning.
from optimize.bpop_analysis_stn import (
    run_proto_responses, plot_proto_responses, 
    save_proto_responses, load_proto_responses,
    plot_responses
)

import optimize.bpop_features_stn as features_stn
# Optional (disabled): live-reload the features module while developing it.
# %load_ext autoreload
# %autoreload 1
# %aimport optimize.bpop_features_stn as features_stn

# Physiology parameters
from evalmodel.cellpopdata import StnModel
from evalmodel.proto_common import StimProtocol

# Short module-level aliases for StimProtocol members
CLAMP_PLATEAU = StimProtocol.CLAMP_PLATEAU
CLAMP_REBOUND = StimProtocol.CLAMP_REBOUND
MIN_SYN_BURST = StimProtocol.MIN_SYN_BURST
SYN_BACKGROUND_HIGH = StimProtocol.SYN_BACKGROUND_HIGH

In [None]:
# Set the listed loggers to the 'quiet' level to reduce output noise
quiet_loggers = ['marasco', 'folding', 'redops', 'bluepyopt.ephys.parameters']
logutils.setLogLevel('quiet', quiet_loggers)

# Full Model
## Create Protocols

In [None]:
# Stimulation protocols used as optimisation targets
opt_stim_protocols = [SYN_BACKGROUND_HIGH]

# Wrap each stimulation protocol, then extract the BluePyOpt ephys protocol it holds
proto_wrappers = list(map(BpopProtocolWrapper.make, opt_stim_protocols))
ephys_protos = [wrapper.ephys_protocol for wrapper in proto_wrappers]

# Gather all frozen mechanisms and parameters the protocols need in order to run
proto_mechs, proto_params = BpopProtocolWrapper.all_mechs_params(proto_wrappers)

## Run Protocols

In [None]:
# Build the full cell model, including the mechanisms/parameters the protocols require
full_model = StnFullModel(
    name='StnGillies',
    mechs=proto_mechs,
    params=proto_params,
)

# Fixed-step NEURON simulator (variable-step CVode disabled)
nrnsim = ephys.simulators.NrnSimulator(dt=0.025, cvode_active=False)

# Run every protocol on the full model; responses are keyed by protocol name
full_responses = {}
for ephys_proto in ephys_protos:
    # Ensure the protocol's recording functions are executed during the run
    ephys_proto.record_contained_traces = True
    full_responses[ephys_proto.name] = ephys_proto.run(
        cell_model=full_model,
        param_values={},
        sim=nrnsim,
        isolate=True,
    )

In [None]:
# Plot the full-model responses, then the traces recorded inside each protocol
plot_proto_responses(full_responses)
for ephys_proto in ephys_protos:
    ephys_proto.plot_contained_traces()

## Calculate Feature Targets

In [None]:
# Build the eFEL feature objects for every wrapped protocol
stimprotos_feats = features_stn.make_opt_features(proto_wrappers)

# The full model is the reference: derive the feature target values from its responses
features_stn.calc_feature_targets(stimprotos_feats, full_responses)

# Reduced Model
## Make Protocols

In [None]:
# The reduced model is evaluated on the same stimulation protocols as the full
# model (opt_stim_protocols from the full-model section). The protocol objects
# are rebuilt here, presumably so the reduced-model runs start from fresh state.
proto_wrappers = list(map(BpopProtocolWrapper.make, opt_stim_protocols))
ephys_protos = [wrapper.ephys_protocol for wrapper in proto_wrappers]

# Gather all frozen mechanisms and parameters the protocols need in order to run
proto_mechs, proto_params = BpopProtocolWrapper.all_mechs_params(proto_wrappers)

## Run Reduced Model

In [None]:
# Build the reduced (folded) cell model with the protocol mechanisms/parameters
red_model = StnReducedModel(
    name='StnFolded',
    fold_method='marasco',
    num_passes=7,
    mechs=proto_mechs,
    params=proto_params,
)

# Fixed-step NEURON simulator (variable-step CVode disabled)
nrnsim = ephys.simulators.NrnSimulator(dt=0.025, cvode_active=False)

# Run every protocol on the reduced model; responses are keyed by protocol name
red_responses = {}
for ephys_proto in ephys_protos:
    # Ensure the protocol's recording functions are executed during the run
    ephys_proto.record_contained_traces = True
    red_responses[ephys_proto.name] = ephys_proto.run(
        cell_model=red_model,
        param_values={},
        sim=nrnsim,
        isolate=True,
    )

In [None]:
# Plot the reduced-model responses, then the traces recorded inside each protocol
plot_proto_responses(red_responses)
for ephys_proto in ephys_protos:
    ephys_proto.plot_contained_traces()

## Calculate Feature Distances

In [None]:
# Score every reduced-model feature against the targets computed from the full model.
#
# NOTE: score = distance = sum(feat[i] - exp_mean) / N / exp_std, so exp_std acts
# as an inverse weight. exp_std is still 1.0 here, so the printed score is the raw
# numerator.
# TODO: to give each feature its desired weight, set
#       efeat.exp_std = score / weight
# after computing the score (currently disabled).
#
# `.items()` works on both Python 2 and 3 (`.iteritems()` is Python-2-only).
for stimproto, featdict in stimprotos_feats.items():
    for efeat, weight in featdict.values():
        score = efeat.calculate_score(red_responses[stimproto.name])
        print('Calculated {} score: {}'.format(efeat.name, score))

## Calculate PSTH

In [None]:
%load_ext autoreload
%autoreload 1
%aimport common.analysis
import numpy as np
import matplotlib.pyplot as plt
import efel; efel.reset()

resp_dict = red_responses[opt_stim_protocols[0].name]
TVresp = resp_dict.items()[0][1]
stim_start, stim_end = 300.0, 1800.0

efel_trace = {
    'T': TVresp['time'],
    'V': TVresp['voltage'],
    'stim_start': [stim_start],
    'stim_end': [stim_end],
}

# Get spike times using eFEL
efel_feat = 'peak_time'
feat_vals = efel.getFeatureValues(
    [efel_trace],
    [efel_feat],
    raise_warnings = True
)
resp_spike_times = feat_vals[0][efel_feat]
print(type(resp_spike_times), resp_spike_times)

# Compute psth/rates

bin_width = 50.0
min_spk = 2

psth1 = common.analysis.nrn_sum_psth(
                [resp_spike_times], 
                stim_start, stim_end,
                binwidth=bin_width).as_numpy()

rates1 = common.analysis.nrn_avg_rate_adaptive(
                [resp_spike_times], 
                stim_start, stim_end,
                binwidth=bin_width,
                minsum=min_spk).as_numpy()

print(psth1)
print(rates1)
print('Exptected num bins = (tstop-tstart)/binwidth + 2 = {}'.format(int((stim_end-stim_start)/bin_width) + 2))
print('Got num bins: {}'.format(psth.size))


psth2 = common.analysis.numpy_sum_psth(
                [resp_spike_times], 
                stim_start, stim_end,
                binwidth=bin_width)

rates2 = common.analysis.numpy_avg_rate_simple(
                [resp_spike_times], 
                stim_start, stim_end,
                bin_width)

print(psth2)
print(rates2)

In [None]:
# Plot the PSTH and the adaptive firing rates computed by the NEURON-based helpers.
# (Previously referenced undefined names `psth` / `rates` -> NameError.)
# NOTE(review): samples are placed every bin_width/2 starting at stim_start; confirm
# this matches the bin layout produced by nrn_sum_psth / nrn_avg_rate_adaptive.
fig_psth, ax_psth = plt.subplots()
ax_psth.plot(stim_start + np.arange(psth1.size) * (bin_width / 2), psth1)
ax_psth.set(title='PSTH (reduced model)', xlabel='time', ylabel='PSTH')

fig_rates, ax_rates = plt.subplots()
ax_rates.plot(stim_start + np.arange(rates1.size) * (bin_width / 2), rates1)
ax_rates.set(title='Adaptive firing rate (reduced model)', xlabel='time', ylabel='rate');