# Initialise

In [None]:
# print date and time of script execution
import datetime
print("\nNotebook executed at at {} in following directory:".format(datetime.datetime.now()))
%cd /home/luye/workspace/bgcellmodels/GilliesWillshaw/

# print code version (hash of checked out version)
print("\nCode version info:")
!git log -1 # --format="%H"

In [None]:
# Import plotting and data-handling libraries
import matplotlib.pyplot as plt
import seaborn as sns
# Apply seaborn's "white" style to subsequent matplotlib figures
sns.set(style="white")
import numpy as np
import pandas as pd
# Interactive matplotlib figures embedded in the notebook
%matplotlib notebook

# Bokeh for interactive plots
from bokeh.io import push_notebook, output_notebook, show as bokeh_show
from bokeh.plotting import figure as bokeh_figure
# Route Bokeh output into the notebook
output_notebook()

# Import our analysis modules
%load_ext autoreload
# Mode 1: only modules registered with %aimport are reloaded automatically
%autoreload 1
%aimport optimize.bpop_analysis_stn
%aimport optimize.bpop_analysis_pop

# Short aliases for the (auto-reloaded) analysis modules
resp_analysis = optimize.bpop_analysis_stn
pop_analysis = optimize.bpop_analysis_pop

# Pretty-printer for inspecting nested structures (e.g. opt_settings)
import pprint
pp = pprint.PrettyPrinter(indent=2)

## Load data

In [None]:
# Optimisation runs available for analysis (both run for 100 generations;
# the IBEA dataset has 100 individuals).
# Each row: [run name, checkpoints pickle, settings pickle]
checkpoint_files = [
    [
        'IBEA_100gen',
        '/home/luye/cloudstore_m/simdata/marasco_folding/optimization_run_20171103_2/opt_checkpoints_cdf893c2.pkl',
        '/home/luye/cloudstore_m/simdata/marasco_folding/optimization_run_20171103_2/opt_checkpoints_cdf893c2_settings_withparams.pkl',
    ],
    [
        'NSGA2_100gen',
        '/home/luye/cloudstore_m/simdata/marasco_folding/optimization_run_20171103_1/opt_checkpoints_3210b868.pkl',
        '/home/luye/cloudstore_m/simdata/marasco_folding/optimization_run_20171103_1/opt_checkpoints_3210b868_settings.pkl',
    ],
]

# One table row per run, columns named after the fields listed above
opt_data = pd.DataFrame.from_records(
    checkpoint_files,
    columns=['name', 'checkpoints_file', 'settings_file'],
)

In [None]:
# Choose the optimisation run to analyse (must match a value in opt_data['name'])
opt_name = 'IBEA_100gen'

# Address the row by value in the 'name' column; fail with a clear message
# instead of a bare IndexError when the name is unknown
matches = opt_data.index[opt_data['name'] == opt_name]
if len(matches) == 0:
    raise ValueError("Unknown optimisation name '{}'; choose one of {}".format(
        opt_name, list(opt_data['name'])))
idx = matches[0]

cp_file = opt_data.loc[idx, 'checkpoints_file']
settings_file = opt_data.loc[idx, 'settings_file']

print("Analysing data from files:\n{}\n{}".format(cp_file, settings_file))

In [None]:
# Load the optimisation checkpoint (logbook, hall of fame, pareto front).
# NOTE: pickle.load can execute arbitrary code -- only load files you produced yourself.
import cPickle as pickle

# The checkpoint file may contain several objects pickled one after another;
# read until EOF and keep the last one written.
checkpoint = None
num_loaded = 0
with open(cp_file, "rb") as f:
    while True:
        try:
            checkpoint = pickle.load(f)
        except EOFError:
            break
        num_loaded += 1

if num_loaded == 0:
    raise ValueError("No checkpoint could be loaded from file {}".format(cp_file))

# Get variables
hof = checkpoint['halloffame']          # best individuals (iterated below)
log = checkpoint['logbook']
pareto_front = checkpoint['paretofront']

In [None]:
# Load optimisation settings (e.g. 'opt_param_names', 'free_params' used below).
# Open in binary mode like the checkpoint file: pickle data is binary, and
# text mode can corrupt it where line endings are translated.
with open(settings_file, 'rb') as f:
    opt_settings = pickle.load(f)

# To inspect the settings: pp.pprint(opt_settings)

# Validation protocol responses

- make a random validation protocol
- set up models and params like in opt notebook
- plot all responses like in l5pc_analysis.py
- calculate scores
- select based on scores
    + write the selected parameters to a pickle file so we can instantiate them and run a bagged optimisation

In [None]:
# Distributed logging (project helper used below to silence loggers)
from common import logutils

# BluePyOpt
import bluepyopt.ephys as ephys

# Custom BluePyOpt modules
from optimize.bpop_cellmodels import StnFullModel, StnReducedModel
from optimize.bpop_protocols_stn import BpopProtocolWrapper
# NOTE(review): rebinds resp_analysis to the same module already aliased in the first import cell
import optimize.bpop_analysis_stn as resp_analysis
import optimize.bpop_features_stn as features_stn

# Physiology parameters
# NOTE(review): StnModel is not used in the visible cells
from evalmodel.cellpopdata import StnModel
from evalmodel.proto_common import StimProtocol

# Short aliases for the StimProtocol members used in this notebook
CLAMP_PLATEAU = StimProtocol.CLAMP_PLATEAU
CLAMP_REBOUND = StimProtocol.CLAMP_REBOUND
MIN_SYN_BURST = StimProtocol.MIN_SYN_BURST
SYN_BACKGROUND_HIGH = StimProtocol.SYN_BACKGROUND_HIGH

# Adjust verbosity of loggers: set the listed project and BluePyOpt loggers to 'quiet'
logutils.setLogLevel('quiet', ['marasco', 'folding', 'redops', 'stn_protos', 'bpop_ext',
                               'bluepyopt.ephys.parameters', 'bluepyopt.ephys.simulators',
                               'bluepyopt.ephys.efeatures', 'bluepyopt.ephys.recordings'])

## Choose validation protocol

In [None]:
# Stimulation protocol used to validate the optimised models
validation_proto = SYN_BACKGROUND_HIGH
proto_wrapper = BpopProtocolWrapper.make(validation_proto)

# Collect all frozen mechanisms and parameters the protocol needs in order to run
proto_mechs, proto_params = BpopProtocolWrapper.all_mechs_params(
    [proto_wrapper])

## Run on full model

In [None]:
# Simulate the validation protocol on the full (unreduced) model; its responses
# provide the feature targets that the reduced-model individuals are scored against.
full_model = StnFullModel(name='StnGillies', mechs=proto_mechs, params=proto_params)

# NEURON simulator with a fixed time step (variable-step CVode disabled)
nrnsim = ephys.simulators.NrnSimulator(dt=0.025, cvode_active=False)

# NOTE(review): project-specific flag; presumably also records traces of
# contained/sub-protocols -- confirm in optimize.bpop_protocols_stn
proto_wrapper.ephys_protocol.record_contained_traces = True

full_responses = proto_wrapper.ephys_protocol.run(
    cell_model=full_model,
    param_values={},
    sim=nrnsim,
    isolate=True,
)

In [None]:
# eFEL feature objects for the validation protocol, structured as
# dict(StimProtocol : dict(feature_name : tuple(efeature, weight)))
stimprotos_feats = features_stn.make_opt_features([proto_wrapper])

# Use the full model's responses (keyed by protocol name) as feature targets
validation_proto_name = proto_wrapper.ephys_protocol.name
full_responses_dict = dict([(validation_proto_name, full_responses)])
features_stn.calc_feature_targets(stimprotos_feats, full_responses_dict)

## Run on individuals

In [None]:
# Names of the optimised parameters, in the same order as the genes of each
# individual: individual[i] is the value of parameter opt_param_names[i].
opt_param_names = opt_settings['opt_param_names']

# Free-parameter objects (locations etc., not values). Materialised as a list so
# that `proto_params + opt_params` also works where dict.values() is a view (Python 3).
opt_params = list(opt_settings['free_params'].values())

In [None]:
%%capture

# Create reduced model and get parameters
red_model = StnReducedModel(
                name		= 'StnFolded',
                fold_method	= 'marasco',
                num_passes	= 7,
                mechs		= proto_mechs,
                params		= proto_params + opt_params)

all_ind_responses = []
all_ind_scores = []

for ind in hof:
    
    ind_param_dict = {pname: ind[i] for i,pname in enumerate(opt_param_names)}
    
    # Run with individual's parameters
    ind_responses = proto_wrapper.ephys_protocol.run(
                            cell_model		= red_model, 
                            param_values	= ind_param_dict,
                            sim				= nrnsim,
                            isolate			= True)
    
    all_ind_responses.append(ind_responses)
    # resp_analysis.plot_responses(ind_responses)
                         
    # Calculate feature scores
    # (iterate over dict(StimProtocol : dict(feature_name : tuple(efeature, weight))))
    all_ind_scores.append({})
    for stimproto, featdict in stimprotos_feats.iteritems():
        for efeat, weight in featdict.values():

            # NOTE: score = distance = sum(feat[i] - exp_mean) / N / exp_std  => so exp_std determines weight
            score = efeat.calculate_score(ind_responses)
            dist = score * efeat.exp_std
            
            all_ind_scores[-1][efeat.name] = dist
            print("Score (unweighted) for {} is {}".format(efeat.name, dist))

### Compare voltage responses

In [None]:
def plot_responses(response_dicts):
    """
    Plot every response of every individual in its own subplot.

    Subplots are stacked vertically and grouped per individual: a block of
    `num_resp` rows is reserved for each individual (num_resp = largest number
    of responses of any individual), filled with that individual's responses
    sorted by response name.

    @param response_dicts    list(dict<str, TimeVoltageResponse>), one dict per
                             individual; each response must support
                             response['time'] and response['voltage']

    @return                  tuple (fig, axes) where axes is a flat 1-D array
                             of matplotlib Axes

    @raise ValueError        if there is nothing to plot
    """
    if not response_dicts:
        raise ValueError("plot_responses: response_dicts is empty")

    # Rows per individual: use the maximum so no individual indexes out of range
    num_resp = max(len(responses) for responses in response_dicts)
    num_axes = len(response_dicts) * num_resp
    if num_axes == 0:
        raise ValueError("plot_responses: no responses to plot")

    # squeeze=False always yields a 2-D array, so a single subplot needs no special case
    fig, axes = plt.subplots(num_axes, squeeze=False)
    axes = axes.ravel()

    for i_ind, responses in enumerate(response_dicts):
        for i_resp, (resp_name, response) in enumerate(sorted(responses.items())):
            # Flat index into this individual's block of rows (the previous
            # i_ind + i_resp overlapped blocks whenever num_resp > 1)
            ax = axes[i_ind * num_resp + i_resp]
            ax.plot(response['time'], response['voltage'], label=resp_name)

    return fig, axes

In [None]:
# Reference trace: first recorded response of the full model, drawn in green
f1 = plt.figure()
response = next(iter(full_responses.values()))
l1 = plt.plot(response['time'], response['voltage'], color='g')

# One subplot per (individual, response) for the hall-of-fame individuals
f2, a2 = plot_responses(all_ind_responses)

### Compare raster plots & PSTH

In [None]:
# Get spike times (spike detection with eFEL)
import efel
# Reset eFEL's global settings, then set the spike-detection threshold explicitly
# (same units as the voltage traces)
efel.reset()
efel.setThreshold(-20.0) # eFEL default value

def get_peaktimes(tvresp, proto):
    """
    Extract spike (peak) times from a TimeVoltageResponse using eFEL.

    Spike detection uses eFEL's global settings (e.g. the threshold set with
    efel.setThreshold).

    @param tvresp    response supporting tvresp['time'] and tvresp['voltage']
    @param proto     protocol (wrapper) whose `response_interval` (start, end)
                     defines the eFEL stimulus window

    @return          peak times as returned by eFEL, or an empty numpy array
                     if eFEL returned None for 'peak_time'
    """
    # Prepare trace in eFEL's expected format
    efel_trace = {
        'T': tvresp['time'],
        'V': tvresp['voltage'],
        'stim_start': [proto.response_interval[0]],
        'stim_end': [proto.response_interval[1]],
    }

    # Calculate spike times from response
    values = efel.getFeatureValues([efel_trace], ['peak_time'], raise_warnings=True)
    peak_times = values[0]['peak_time']

    # eFEL may return None for a feature it could not compute; normalise to an
    # empty array so callers can always concatenate / plot the result
    if peak_times is None:
        return np.array([])
    return peak_times
    

# Spike times of the first recorded response of every individual
all_ind_spiketimes = [
    get_peaktimes(next(iter(ind_responses.values())), proto_wrapper)
    for ind_responses in all_ind_responses
]

# Spike times of the full (reference) model
full_spiketimes = get_peaktimes(next(iter(full_responses.values())), proto_wrapper)

In [None]:
from common import analysis
import collections


def _as_spike_array(times):
    """Return spike times as a 1-D float array; None (no value from eFEL) becomes empty."""
    return np.array([] if times is None else times, dtype=float).ravel()


# Spike times of each hall-of-fame individual, labelled ind0, ind1, ...
all_spk = collections.OrderedDict()
for i, st in enumerate(all_ind_spiketimes):
    all_spk['ind{}'.format(i)] = st

# Row j of the raster shows spike_labels[j]; labels are reversed so that ind0
# ends up at the top of the individuals' rows.
spike_labels = list(reversed(all_spk.keys()))
spike_vecs = [_as_spike_array(all_spk[label]) for label in spike_labels]

# X data: concatenated spike times; Y data: row index of each spike.
# Guard the empty case (np.concatenate needs at least one array).
if spike_vecs:
    x_data = np.concatenate(spike_vecs)
    y_data = np.concatenate([np.zeros_like(vec) + j for j, vec in enumerate(spike_vecs)])
else:
    x_data = np.array([])
    y_data = np.array([])

# Keep only spikes inside the protocol's response interval
timeRange = proto_wrapper.response_interval
mask = (x_data > timeRange[0]) & (x_data < timeRange[1])
x_data = x_data[mask]
y_data = y_data[mask]

# The full (reference) model gets its own row above all individuals
full_row = len(spike_vecs)
full_spk = _as_spike_array(full_spiketimes)

# Plot data as scatter plot
fig, ax = plt.subplots()
ax.scatter(x_data, y_data, s=4, c='b', lw=0, marker='.')  # marker=',' is thicker
ax.plot(full_spk, np.zeros_like(full_spk) + full_row, c='r', lw=0, markersize=2, marker='.')

# Axes
ax.set_xlim(timeRange)
ax.grid(True, axis='x')

# One tick label per row, including the reference row (previously unlabelled)
ax.set_yticks(range(full_row + 1))
ax.set_yticklabels(spike_labels + ['full'], rotation='horizontal')
ax.set_xlabel('time (ms)')
ax.set_title('Best individual spike times', loc='center')

fig.subplots_adjust(left=0.15)  # tweak spacing to prevent clipping of tick labels
# 0.2 inch per row (individuals + reference row), at least 2 inches so small sets stay readable
fig.set_figheight(max(2.0, 0.2 * (full_row + 1)))
