# Initialise

In [None]:
# print date and time of script execution
import datetime
print("\nNotebook executed at at {} in following directory:".format(datetime.datetime.now()))
%cd /home/luye/workspace/bgcellmodels/GilliesWillshaw/

# print code version (hash of checked out version)
print("\nCode version info:")
!git log -1 # --format="%H"

In [None]:
# Import plotting libraries
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style="white")
import numpy as np
import pandas as pd
# Interactive matplotlib backend for the figures in this notebook
%matplotlib notebook

# Bokeh for interactive plots
from bokeh.io import push_notebook, output_notebook, show as bokeh_show
from bokeh.plotting import figure as bokeh_figure
output_notebook()

# Import our analysis modules
# `%autoreload 1` only reloads modules registered with `%aimport`, so edits to
# the two analysis modules below are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 1
%aimport optimize.bpop_analysis_stn
%aimport optimize.bpop_analysis_pop

# Short aliases for the analysis modules
resp_analysis = optimize.bpop_analysis_stn
pop_analysis = optimize.bpop_analysis_pop

# Pretty-printer used to display settings, parameter dicts and objective names
import pprint
pp = pprint.PrettyPrinter(indent=2)

## Load data

In [None]:
# Table of optimisation runs that can be analysed, one row per run:
#   [run name, checkpoints pickle file, settings pickle file]
# The run to analyse is selected by name in the next cell (`opt_name`).
# NOTE(review): absolute local paths -- adjust when running on another machine.
# 100 individuals, 100 generations IBEA dataset
checkpoint_files = [
    # ONE-SHOT OPTIMISATIONS
    ['IBEA_100gen',
     '/home/luye/cloudstore_m/simdata/marasco_folding/optimization_run_20171103_2/opt_checkpoints_cdf893c2.pkl',
     '/home/luye/cloudstore_m/simdata/marasco_folding/optimization_run_20171103_2/opt_checkpoints_cdf893c2_settings_withparamsobjs.pkl'],
    ['NSGA2_100gen',
     '/home/luye/cloudstore_m/simdata/marasco_folding/optimization_run_20171103_1/opt_checkpoints_3210b868.pkl',
     '/home/luye/cloudstore_m/simdata/marasco_folding/optimization_run_20171103_1/opt_checkpoints_3210b868_settings.pkl'],
    ['IBEA_100gen_BACKGROUND',
     '/home/luye/cloudstore_m/simdata/marasco_folding/optimization_run_20171106_1/opt_checkpoints_6b30ea0c.pkl',
     '/home/luye/cloudstore_m/simdata/marasco_folding/optimization_run_20171106_1/opt_checkpoints_6b30ea0c_settings.pkl'],
    ['IBEA_100gen_BACKGROUND_2',
     '/home/luye/Documents/optimization_run_20171108_1/opt_checkpoints_aeee0fa8.pkl',
     '/home/luye/Documents/optimization_run_20171108_1/opt_checkpoints_aeee0fa8_settings_withparamsobjs.pkl'],
    # INCREMENTAL OPTIMISATIONS
    ['INCR_REB_MINSYN',
     '/home/luye/Documents/optimization_run_20171105_1/opt_checkpoints_28d1c4f6.pkl',
     '/home/luye/Documents/optimization_run_20171105_1/opt_checkpoints_28d1c4f6_settings.pkl'],
    ['INCR_BACKGROUND',
     '/home/luye/Documents/optimization_run_20171106_2/opt_checkpoints_7b14ca54.pkl',
     '/home/luye/Documents/optimization_run_20171106_2/opt_checkpoints_7b14ca54_settings.pkl'],
]

# One row per run, with named columns for lookup by run name
opt_data = pd.DataFrame(checkpoint_files, columns=['name', 'checkpoints_file', 'settings_file'])

In [None]:
# Choose optimisation to analyse
opt_name = 'IBEA_100gen' # SETPARAM: optimisation run to analyse

# from ipywidgets import interact, interactive, fixed, interact_manual
# import ipywidgets as widgets
# def f(x):
#     global opt_name
#     opt_name = x
# interact(f, x=opt_data['name'])

In [None]:
# Look up the checkpoint and settings files of the selected run by its name.
# Fail with a readable message (instead of a bare IndexError) for unknown names.
matching_rows = opt_data.index[opt_data['name'] == opt_name]
if len(matching_rows) == 0:
    raise ValueError("Unknown optimisation run '{}'. Available runs: {}".format(
                     opt_name, list(opt_data['name'])))
idx = matching_rows[0]

# Explicit label-based access instead of chained indexing
cp_file = opt_data.loc[idx, 'checkpoints_file']
settings_file = opt_data.loc[idx, 'settings_file']

print("Analysing data for {} from files:\n{}\n{}".format(opt_name, cp_file, settings_file))

In [None]:
# Load the optimisation checkpoint (hall of fame, logbook, Pareto front)
import cPickle as pickle

# The file may contain several consecutive pickles (presumably one per saved
# checkpoint): read until EOF and keep the last one.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
checkpoint = None
with open(cp_file, "rb") as f:
    while True:
        try:
            checkpoint = pickle.load(f)
        except EOFError:
            break

if checkpoint is None:
    raise ValueError("No checkpoint could be loaded from file {}".format(cp_file))

# Get variables
hof = checkpoint['halloffame']
log = checkpoint['logbook']
pareto_front = checkpoint['paretofront']

In [None]:
# Load settings file
with open(settings_file, 'r') as f:
    opt_settings = pickle.load(f)
    
pp.pprint(opt_settings)

# Validation protocol responses

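- Make a randomised validation protocol that was not used during optimisation.
- Set up the models and parameters as in the optimisation notebook.
- Plot all responses as in `l5pc_analysis.py`.
- Calculate the scores (weighted feature errors) on the optimised and validation protocols.
- Select individuals based on their scores.
    + Write the selected parameters to a pickle file, so that they can be instantiated and used in a bagged optimisation.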

In [None]:
# Distributed logging
from bgcellmodels.common import logutils

# BluePyOpt
import bluepyopt.ephys as ephys

# Custom BluePyOpt modules
from optimize.bpop_cellmodels import StnFullModel, StnReducedModel
from optimize.bpop_protocols_stn import BpopProtocolWrapper
# NOTE: rebinds the same module that the first cell already aliased as resp_analysis
import optimize.bpop_analysis_stn as resp_analysis
import optimize.bpop_features_stn as features_stn

# Physiology parameters
from bgcellmodels.cellpopdata import StnModel
from evalmodel.proto_common import StimProtocol

# Short aliases for the stimulation protocols used below
CLAMP_PLATEAU = StimProtocol.CLAMP_PLATEAU
CLAMP_REBOUND = StimProtocol.CLAMP_REBOUND
MIN_SYN_BURST = StimProtocol.MIN_SYN_BURST
SYN_BACKGROUND_HIGH = StimProtocol.SYN_BACKGROUND_HIGH
# Protocol used as the held-out validation protocol in this notebook
VALIDATION = StimProtocol.SYN_BACKGROUND_LOW

# Adjust verbosity of loggers
logutils.setLogLevel('quiet', ['marasco', 'folding', 'redops', 'stn_protos', 'bpop_ext',
                               'bluepyopt.ephys.parameters', 'bluepyopt.ephys.simulators',
                               'bluepyopt.ephys.efeatures', 'bluepyopt.ephys.recordings'])

## Choose validation protocol

In [None]:
# First add the protocols that were used during optimisation
validation_protos = [StimProtocol.from_str(proto) for proto in opt_settings['protos_feats'].keys()]

# Add the held-out validation protocol (only once, even if it was also optimised)
val_proto = VALIDATION # SETPARAM: choose the validation protocol
if val_proto not in validation_protos:
    validation_protos.append(val_proto)
opt_protos = [sp for sp in validation_protos if sp != val_proto]

# SETPARAM: copy possible protocol parameters from optimisation notebook
proto_kwargs = {
    CLAMP_REBOUND: {},
    MIN_SYN_BURST: {},
    VALIDATION: {
        'impl_proto': val_proto,
        'base_seed': 11,
    }
}

# Make all protocol data. Protocols without an entry in proto_kwargs (e.g. those
# of background-input optimisations) are built with default arguments instead of
# raising a KeyError.
# NOTE(review): confirm that the defaults suffice for every optimised protocol.
stimprotos_wrappers = {
    stim_proto: BpopProtocolWrapper.make(stim_proto, **proto_kwargs.get(stim_proto, {}))
    for stim_proto in validation_protos
}

# Collect all frozen mechanisms and parameters required for protocols to work
all_proto_mechs, all_proto_params = BpopProtocolWrapper.all_mechs_params(stimprotos_wrappers.values())

## Run on full model

In [None]:
# Simulate every protocol on the full (unreduced) model. Its responses serve as
# the reference: validation feature targets are computed from them further below.
full_model = StnFullModel(
    name='StnGillies',
    mechs=all_proto_mechs,
    params=all_proto_params,
)

# Fixed-step simulator settings.
# NOTE(review): nrnsim is not passed to run_proto_responses below.
nrnsim = ephys.simulators.NrnSimulator(dt=0.025, cvode_active=False)

protocols_to_run = [wrapper.ephys_protocol for wrapper in stimprotos_wrappers.values()]
full_resp_dict = resp_analysis.run_proto_responses(full_model, protocols_to_run)

# Flatten {protocol name: {response name: response}} into
# {response name: response}, the same format as evaluator.evaluate(individual)
full_responses = {}
for proto_responses in full_resp_dict.itervalues():
    full_responses.update(proto_responses)

In [None]:
# Plot full model responses (all protocols in full_resp_dict), for visual sanity checks
resp_analysis.plot_proto_responses(full_resp_dict)

## Make objectives

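Create the features (with their weights) and wrap each feature in a singleton objective: first the features used during optimisation, then those of the validation protocol.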

In [None]:
# 1. Optimised objectives ######################################################
# The saved objectives may be a dict or an ordered sequence; take the first
# (presumably only) feature of each singleton objective.
opt_objectives = opt_settings['objectives_ordered']
saved_objectives = opt_objectives.values() if isinstance(opt_objectives, dict) else opt_objectives
opt_features = [objective.features[0] for objective in saved_objectives]

# 2. Validation objectives #####################################################

# EFEL feature objects for the validation protocol only:
# dict(StimProtocol : dict(feature_name : tuple(efeature, weight)))
val_wrappers = [stimprotos_wrappers[val_proto]]
stimprotos_feats = features_stn.make_opt_features(val_wrappers)

# Flatten into list(EFeature) and list(weight)
val_features, val_weights = features_stn.all_features_weights(stimprotos_feats.values())

# Targets are only computed (from the full-model responses) for the validation
# features; the optimised features already carry theirs in exp_mean.
features_stn.calc_feature_targets(stimprotos_feats, full_resp_dict)

# 3. Merge them: optimised features first, validation features last ############
all_features = opt_features + val_features

# Get training & validation errors

In [None]:
# Names of the optimised parameters, in the same order as the genes of each individual
opt_param_names = opt_settings['opt_param_names']
# Free parameter objects (locations etc, not values)
opt_free_params = opt_settings['free_params'].values()

# Reduced (folded) model with the protocol parameters plus the free parameters
red_model = StnReducedModel(
    name='StnFolded',
    fold_method='marasco',
    num_passes=7,
    mechs=all_proto_mechs,
    params=all_proto_params + opt_free_params,
)

# One singleton objective per feature (optimised + validation), wrapped in a calculator
all_objectives = [ephys.objectives.SingletonObjective(feat.name, feat) for feat in all_features]
fitcalc = ephys.objectivescalculators.ObjectivesCalculator(all_objectives)

# Protocols to run per evaluation, keyed by protocol name
opt_ephys_protos = dict(
    (stim_proto.name, wrapper.ephys_protocol)
    for stim_proto, wrapper in stimprotos_wrappers.iteritems())

nrnsim = ephys.simulators.NrnSimulator(dt=0.025, cvode_active=False)

from optimize.bpop_extensions import CellEvaluatorCaching
evaluator = CellEvaluatorCaching(
    cell_model=red_model,
    param_names=opt_param_names,  # fitted parameters (same order as saved individuals!)
    fitness_protocols=opt_ephys_protos,
    fitness_calculator=fitcalc,
    sim=nrnsim,
    isolate_protocols=True,
)

# Have the evaluator save the responses it computes next to the checkpoint file;
# later cells read them back as ind_resp_file_prefix + <evaluation id>.
import os.path
folder, tail = os.path.split(cp_file)
newtail = 'response_ind_'
evaluator.set_responses_filename(folder, newtail)
ind_resp_file_prefix = os.path.join(folder, newtail)

In [None]:
# NOTE: no `from datetime import datetime` here -- it would shadow the
# `datetime` module imported in the first cell and break re-running it.
from ipyparallel import Client
import socket

# Create a connection to the server
rc = Client() # if profile specified: searches JSON file in ~/.ipython/profile_name
print('Using ipyparallel with %d engines' % len(rc))

# Load-balanced view: each task is scheduled on whichever engine is free
lview = rc.load_balanced_view()

# A load-balanced apply runs on a single engine only; use a direct view over
# all engines to get the host name of every ipengine.
host_names = rc[:].apply_sync(socket.gethostname)
if isinstance(host_names, str):
    host_names = [host_names]


def map_func(func, *sequences):
    """
    Blocking parallel map over the ipyparallel engines (load-balanced).

    @param func         function to apply to each tuple of arguments
    @param sequences    sequences of matching length with sequence i
                        containing the i-th argument for func

    @return             list with the result of each call, in input order
    """
    return lview.map_sync(func, *sequences)

In [None]:
# Evaluate all hall-of-fame individuals on all protocols (optimised + validation).
# Each individual's hall-of-fame index is passed as its evaluation ID; presumably
# it names the saved responses file (read back via ind_resp_file_prefix + ID).
eval_inds = hof
ind_ids = range(len(eval_inds))
all_ind_scores = map_func(evaluator.evaluate_with_lists, eval_inds, ind_ids)

In [None]:
# Evaluate the non-optimised model: each free parameter at the value stored in
# its saved parameter object (all default params), in individual gene order
unopt_ind = [opt_settings['free_params'][name].value for name in opt_param_names]

# 'unopt' is the evaluation ID, i.e. the suffix of the saved responses file
unopt_scores = map_func(evaluator.evaluate_with_lists, [unopt_ind], ['unopt'])

In [None]:
# Score table: one row per individual (hall of fame in order, unoptimised model
# as last row), one column per objective
ind_scores_as_rows = np.array(all_ind_scores + unopt_scores)
obj_names = [objective.name for objective in all_objectives]
ind_scores_data = pd.DataFrame(ind_scores_as_rows, columns=obj_names)

# Validation objectives are recognised by the validation protocol name in their name
val_obj_names = [name for name in obj_names if val_proto.name in name]
ind_val_data = ind_scores_data[val_obj_names]

# All remaining objectives belong to the optimised protocols
opt_obj_names = [name for name in obj_names if name not in val_obj_names]
ind_opt_data = ind_scores_data[opt_obj_names]

print("\nObjectives used in optimisation:")
pp.pprint(opt_obj_names)
print("\nObjectives used in validation:")
pp.pprint(val_obj_names)

print("\nScores for first individuals in hall of fame:")
print(ind_opt_data.iloc[0:10])

## Plot validation scores

In [None]:
# Calculate training & validation errors (scores): sum of weighted errors per
# individual over the optimised protocols and over the validation protocol
sum_opt_scores = ind_opt_data.sum(axis=1)
sum_val_scores = ind_val_data.sum(axis=1)

# Individuals whose validation error sum reaches this threshold are excluded
# from the plots (presumably failed evaluations). SETPARAM
MAX_VAL_ERROR_SUM = 250.0
i_nofail = np.where(sum_val_scores < MAX_VAL_ERROR_SUM)[0]


def _bar_error_sums(scores, title):
    """Bar plot (log y-axis) of per-individual error sums, non-failed individuals only."""
    fig, ax = plt.subplots()
    sns.barplot(x=scores.index[i_nofail], y=scores.values[i_nofail], ax=ax)
    ax.set_title(title)
    ax.set_ylabel(r'$\sum w_i e_i$')  # raw string: '\s' is not a valid escape
    ax.set_yscale('log')
    return fig, ax


# Plot sum of optimisation errors
fig, ax = _bar_error_sums(sum_opt_scores, 'Sum of optimisation errors')

# Plot sum of validation errors
fig, ax = _bar_error_sums(sum_val_scores, 'Sum of validation errors')
ax.set_ylim((5, 35))  # SETPARAM: zoom on the relevant range

In [None]:
# Plot each validation error for each (non-failed) individual, one panel per objective.
# squeeze=False: always a 2-D array of axes, so a single validation objective works too.
f, all_ax = plt.subplots(len(val_obj_names), 1, sharex=True, squeeze=False)
xdata = ind_val_data.index[i_nofail]
for i, ax in enumerate(all_ax[:, 0]):
    obj_name = ind_val_data.columns[i]
    ydata = ind_val_data[obj_name].values[i_nofail]
    sns.barplot(x=xdata, y=ydata, ax=ax)
    ax.set_title(obj_name)
    ax.set_ylabel('$w_i e_i$')
    ax.set_ylim((0, 15))  # SETPARAM: common y-range for all panels

In [None]:
# Rank individuals (including the unoptimised model, last row) by their error on
# one validation objective: lowest error first
target_obj = val_obj_names[1] # SETPARAM: objective to sort on
target_errors = ind_val_data[target_obj].values
sorted_index = np.argsort(target_errors)
best_ids_validated = sorted_index

print("Objective used for ranking: {}".format(target_obj))
print('\nIndices of best ranked individuals:\n{}'.format(sorted_index))


In [None]:
# Load the responses that the evaluator saved for every hall-of-fame individual,
# in hall-of-fame order
num_best = len(hof)


def _load_responses(eval_id):
    """Unpickle the responses saved by the evaluator for evaluation ID `eval_id`."""
    with open(ind_resp_file_prefix + str(eval_id), 'rb') as resp_fh:
        return pickle.load(resp_fh)


ind_responses = [_load_responses(ind_idx) for ind_idx in range(len(hof))]

In [None]:
# Add response of the default (unoptimised) model as the entry after the hall of fame.
# Truncate first so that re-running this cell does not append it a second time.
del ind_responses[len(hof):]
with open(ind_resp_file_prefix + 'unopt', 'rb') as f:
    ind_responses.append(pickle.load(f))

## Compare raster plots

In [None]:
# Get spike times
import efel

def get_peaktimes(tvresp, proto):
    """
    Extract spike (peak) times from a TimeVoltageResponse using eFEL.

    @param  tvresp  response indexable by 'time' and 'voltage'
    @param  proto   protocol wrapper whose `response_interval` (start, end)
                    is used as the eFEL stimulus window

    @return         peak times as returned by eFEL, or an empty array when
                    eFEL returns no value (e.g. no spikes detected)
    """
    # Reset global eFEL settings, then set the spike detection threshold
    efel.reset()
    efel.setThreshold(-20.0) # eFEL default value

    # Prepare trace
    efel_trace = {
        'T': tvresp['time'],
        'V': tvresp['voltage'],
        'stim_start': [proto.response_interval[0]],
        'stim_end': [proto.response_interval[1]],
    }

    # Calculate spike times from response
    values = efel.getFeatureValues([efel_trace], ['peak_time'], raise_warnings=True)
    peak_times = values[0]['peak_time']

    # eFEL may report a feature it cannot compute as None; downstream code
    # concatenates spike trains, so return an empty train instead.
    if peak_times is None:
        return np.array([])
    return peak_times

In [None]:
# Get spike times for the full model on the validation protocol (reference train)
val_wrapper = stimprotos_wrappers[val_proto]
val_response = full_resp_dict[val_proto.name]
# NOTE(review): takes the first response of the validation protocol; assumes it
# records a single (somatic) voltage trace -- confirm.
full_spiketimes = get_peaktimes(next(iter(val_response.values())), val_wrapper)


def _val_response(responses):
    """Return the first response whose name contains the validation protocol name."""
    for resp_name, resp in responses.items():
        if val_proto.name in resp_name:
            return resp
    raise KeyError("No response for validation protocol {} among {}".format(
                   val_proto.name, sorted(responses.keys())))


# Spike times of all hall-of-fame individuals, followed by the unoptimised model
ind_spiketimes = [get_peaktimes(_val_response(resp), val_wrapper)
                  for resp in ind_responses[:len(hof) + 1]]

In [None]:
from bgcellmodels.common import analysis
import collections

# Spike trains in validation-rank order (best first), labelled by hall-of-fame
# index; the index past the hall of fame is the default (unoptimised) model.
all_spk = collections.OrderedDict()
for i_hof in best_ids_validated:
    label = 'def' if i_hof >= len(hof) else 'ind{}'.format(i_hof)
    all_spk[label] = ind_spiketimes[i_hof]


def _as_train(spk):
    """Spike times as a float array; None (no value from eFEL) becomes an empty train."""
    return np.array([] if spk is None else spk, dtype=float)


# Best ranked individual ends up on the top row; the full model gets an extra
# row above all individuals.
spike_labels = list(reversed(all_spk.keys()))
spike_vecs = [_as_train(all_spk[label]) for label in spike_labels]
full_train = _as_train(full_spiketimes)
full_row = len(spike_vecs)
tick_labels = spike_labels + ['full']

# X data is concatenated spike times, Y data is the row (trace ID) of each spike
x_data = np.concatenate(spike_vecs)
y_data = np.concatenate([np.zeros_like(vec) + j for j, vec in enumerate(spike_vecs)])

# Keep only spikes within the analysed response interval
time_range = val_wrapper.response_interval
mask = (x_data > time_range[0]) & (x_data < time_range[1])
x_data = x_data[mask]
y_data = y_data[mask]

# Plot data as scatter plot (full model in red)
fig, ax = plt.subplots()
ax.scatter(x_data, y_data, s=4, c='b', lw=0, marker='.') # marker=',' is thicker
ax.plot(full_train, np.zeros_like(full_train) + full_row, c='r', lw=0, markersize=2, marker='.')

# Axes
ax.set_xlim(time_range)
ax.grid(True, axis='x')
ax.set_yticks(range(len(tick_labels)))
ax.set_yticklabels(tick_labels, rotation='horizontal')
ax.set_xlabel('time (ms)')
ax.set_title('Validation protocol spike trains', loc='center')

fig.subplots_adjust(left=0.15) # Tweak spacing to prevent clipping of tick-labels
# Scale height with the number of rows, with a floor so small sets stay readable
fig.set_figheight(max(2.0, 0.2 * len(tick_labels)))


# Inspect best individuals

In [None]:
def plot_responses(ind_responses, full_responses=None, plot_kws=None, highlight=None):
    """
    Plot the responses of one individual, overlaid on the full-model responses.

    @param  ind_responses   dict {str: responses.TimeVoltageResponse}, e.g. the
                            responses saved for one evaluated individual
    @param  full_responses  optional dict of the same form; entries whose key
                            also occurs in ind_responses are drawn in red
                            underneath, as the target
    @param  plot_kws        optional keyword arguments for plotting the
                            individual's traces
    @param  highlight       substring of response names whose subplot title is
                            coloured red (others green); defaults to the name
                            of the global validation protocol `val_proto`

    @return                 (fig, axes): one subplot per response, in sorted
                            order of response name
    """
    if highlight is None:
        highlight = val_proto.name
    if plot_kws is None:
        plot_kws = {}
    if full_responses is None:
        full_responses = {}

    # squeeze=False: always get a 2-D axes array, even for a single response
    fig, axes = plt.subplots(len(ind_responses), squeeze=False)
    axes = axes[:, 0]

    for ax, resp_name in zip(axes, sorted(ind_responses.keys())):
        # Red title marks the validation protocol, green the other responses
        ax.set_title(resp_name, color='r' if highlight in resp_name else 'g')

        # Plot target responses first (lowest z-order)
        if resp_name in full_responses:
            target = full_responses[resp_name]
            ax.plot(target['time'], target['voltage'], 'r-', linewidth=1, label='full')

        # Plot individual response
        response = ind_responses[resp_name]
        ax.plot(response['time'], response['voltage'], label='reduced', **plot_kws)

    fig.tight_layout()

    return fig, axes

## Compare Vm of best optimised

In [None]:
# Inspect the individuals at the top of the hall of fame (presumably sorted
# best-first, so the rank is the hall-of-fame index)
ranks_toplot = range(3)  # SETPARAM: hall-of-fame ranks to inspect
for rank in ranks_toplot:
    ind_id = rank
    val_position = np.where(best_ids_validated == ind_id)[0]

    print("Hall of fame index: {}".format(ind_id))
    print("Validation ranking index: {}\n".format(val_position))

    # Parameter values of this individual, keyed by parameter name
    genes = hof[ind_id]
    pp.pprint({pname: genes[k] for k, pname in enumerate(opt_param_names)})

    # Weighted errors on the optimised and on the validation objectives
    for heading, score_table in (("\nOptimisation errors:", ind_opt_data),
                                 ("\nValidation errors:", ind_val_data)):
        print(heading)
        print(score_table.iloc[ind_id])

    # Plot optimised & validation responses
    fig, axes = plot_responses(ind_responses[ind_id], full_responses)

## Compare Vm of best validated

In [None]:
# Inspect the individuals that rank best on the validation objective. The ranking
# also contains the unoptimised model (index len(hof)), which is not in `hof`.
ranks_toplot = range(min(3, len(best_ids_validated)))  # SETPARAM: validation ranks to inspect
for val_rank in ranks_toplot:
    ind_id = best_ids_validated[val_rank]
    is_unopt = ind_id >= len(hof)

    if is_unopt:
        print("Hall of fame index: {} (unoptimised model)".format(ind_id))
    else:
        print("Hall of fame index: {}".format(ind_id))
    print("Validation ranking index: {}\n".format(val_rank))

    # Print model parameters (default values for the unoptimised model)
    genes = unopt_ind if is_unopt else hof[ind_id]
    ind_param_dict = {pname: genes[i] for i, pname in enumerate(opt_param_names)}
    pp.pprint(ind_param_dict)

    # Print errors (score tables and responses include the unoptimised model as last entry)
    print("\nOptimisation errors:")
    print(ind_opt_data.iloc[ind_id])
    print("\nValidation errors:")
    print(ind_val_data.iloc[ind_id])

    # Plot optimised & validation responses
    fig, axes = plot_responses(ind_responses[ind_id], full_responses)