# Inference for the Cone

- [Define Target and Stimulus](#Target-and-Stimulus)
- [Create the cone model](#Cell)
- [Select loss function and parameters](#Optimizer)
- [Run inference](#Inference)
- Plot inference results:
    - [Plot results](#Plot-results)
    - [Posterior](#Posterior)
    - [Best sample(s)](#Best-sample(s))

## Select mode: full_inference / load_only / test

- *full_inference*
    - Runs the whole inference.
    - Takes a long time
- *load_only*
    - Does not generate new samples; instead, loads the data generated for the paper.
- *test*
    - Runs the whole inference, but with fewer samples.
    - Illustrates how the inference works without spending too much CPU power and time.
    - However, because too few samples are generated, the inference results may be poor.
    - Don't use these results in subsequent steps.

In [None]:
# Inference mode; one of 'test', 'full_inference' or 'load_only' (see the mode description above).
inference_mode = 'load_only'

# Imports

In [None]:
import importlib  # Used to reload the project modules after they are edited.

In [None]:
# Third-party libraries used throughout the notebook.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
# Make the project's python helpers in ../pythoncode (and its subfolders) importable.
# `os` and `sys` are needed here but were not imported by any earlier cell, so import them now.
import os
import sys

pythoncodepath = os.path.abspath(os.path.join('..', 'pythoncode'))
# Prepend so the project modules take precedence over identically named installed packages.
sys.path = [pythoncodepath] + sys.path
import importhelper
importhelper.addfolders2path(pythoncodepath)

In [None]:
import data_utils  # Project helper used below for load_var / save_var / make_dir.

# Target and Stimulus

In [None]:
# Passed to the cell model as `predur` (presumably a pre-simulation duration in seconds — confirm).
predur = 5.0

In [None]:
# Time window (start, end) used as x-limits for the data plots and as the optimizer's t_rng.
# NOTE(review): presumably in seconds, matching the 'Time' column of the data — confirm.
stim_t_rng = (1, 32)

In [None]:
# Load the preprocessed experimental data: the mean release trace (target) and the light stimulus.
data_folder = os.path.join('..', 'step0_preprocess_iGluSnFR_data', 'data_preprocessed')
target_csv = os.path.join(data_folder, 'ConeData_ReleaseMeanData.csv')
stimulus_csv = os.path.join(data_folder, 'ConeData_stimulus_time_and_amp_corrected.csv')

target_dF_F = pd.read_csv(target_csv)
stimulus = pd.read_csv(stimulus_csv)

In [None]:
# Stimulus (top) and target dF/F (bottom), both restricted to the stimulus time window.
fig, axs = plt.subplots(nrows=2, ncols=1, figsize=(12, 4), subplot_kw={'xlim': stim_t_rng})
ax_stimulus, ax_target = axs
stimulus.plot(x='Time', ax=ax_stimulus)
target_dF_F.plot(x='Time', ax=ax_target)
plt.show()

# Cell

## Create model

In [None]:
# Default cell parameters and their units.
cell_params_folder = 'cell_params'
params_default = data_utils.load_var(os.path.join(cell_params_folder, 'cone_cell_params_default.pkl'))
params_unit    = data_utils.load_var(os.path.join(cell_params_folder, 'cone_cell_params_unit.pkl'))

In [None]:
import retsim_cells
importlib.reload(retsim_cells)

# Absolute path of the retsim model folder, with a trailing '/'.
retsim_path = os.path.abspath(os.path.join('..', 'neuronc', 'models', 'retsim')) + '/'

# Create the cone cell model driven by the light stimulus.
cell = retsim_cells.Cone(
    predur=predur,
    t_rng=(1.9, 2.2),
    params_default=params_default,
    params_unit=params_unit,
    stimulus=stimulus,
    stim_type='Light',
    cone_densfile='dens_cone_optimize_cone.n',
    nval_file='nval_cone_optimize_cone.n',
    chanparams_file='chanparams_cone_optimize_cone.n',
    expt_file_list=['optimize_cones'],
    expt_base_file_list=[os.path.join('retsim_files', 'expt_optimize_cones.cc')],
    retsim_path=retsim_path,
)

In [None]:
# Generate the experiment C++ source for retsim, then run `make` inside the retsim folder
# (IPython `!` shell escape; {cell.retsim_path} is expanded by IPython).
cell.create_retsim_expt_file(verbose=False) # Create c++ file.
!(cd {cell.retsim_path} && make) # Compile c++ file.

In [None]:
# Initialize the retsim model of the cell with the default simulation parameters.
cell.init_retsim(verbose=True)

# Tests

Skip this step if you trust the model.

In [None]:
# Run (and time, via IPython's %time) a single simulation with the 'test' recording type.
cell.rec_type = 'test'
%time rec_data, rec_time, rec_stim = cell.run(plot=True, verbose=True)

## Test number of compartments

In [None]:
# Re-initialize the model with different sim_params and print the compartment diameters.
# Each entry: (expected outcome printed as header, list of (sim_param name, values to try)).
compartment_checks = [
    ('Should be the same', [('c_rm', [1, 10, 20, 40, 100]), ('c_ri', [1, 100, 200])]),
    ('Should be the same size', [('c_cm', [0.9, 1, 2])]),
    ('Should not be the same', [('cpl_axon', [0.002, 1])]),
]

for expectation, sweeps in compartment_checks:
    print(expectation)
    for sim_param_name, sim_param_values in sweeps:
        for sim_param_value in sim_param_values:
            cell.init_retsim(sim_params={sim_param_name: sim_param_value})
            print('\t', cell.comp_data['dia'].values)

# Restore the default model and show its compartments for reference.
cell.init_retsim()
print('Default:')
print('\t', cell.comp_data['dia'].values)

## Test parameters in model

Note that channels might not have an influence in the given voltage range (e.g. if they stay below threshold).
Also, the effect of some parameters might be small. They may become more important when other parameters change the overall dynamics, so don't discard them until you know more.

In [None]:
import retsim_params_test
importlib.reload(retsim_params_test)

# Use a short time window and the 'optimize' recording type for the parameter-usage test.
cell.update_t_rng((1.95, 2.15))
cell.rec_type = 'optimize'

all_equal_params, all_close_params = retsim_params_test.test_if_params_are_used(
    cell=cell,
    params=params_default,
    Vm_name='Vm Soma',
    rate_name='rate Cone',
)

In [None]:
# Every tested parameter should change the model output. The messages list the offending
# parameters so a failure shows what to inspect.
assert len(all_equal_params) == 0, f'Parameters whose variation left the output equal: {all_equal_params}'
assert len(all_close_params) == 0, f'Parameters whose variation left the output (nearly) equal: {all_close_params}'

# Optimizer

In [None]:
# Choose the optimizer output folder: 'load_only' reads the precomputed paper data,
# 'test' and 'full_inference' write to a fresh folder.
if inference_mode == 'load_only':
    output_folder = 'optimize_cone_submission2'
elif inference_mode in ('test', 'full_inference'):
    output_folder = 'optimize_cone'
else:
    # Name the offending value so a typo in inference_mode is easy to spot.
    raise NotImplementedError(f'Unknown inference_mode: {inference_mode!r}')

print('Inference:', inference_mode, '--> Folder:', output_folder)

## Parameters

In [None]:
# Load the default values and the ranges of the parameters to optimize.
opt_params_default, opt_params_range = [
    data_utils.load_var(os.path.join('cell_params', pkl_name))
    for pkl_name in ('cone_opt_params_default.pkl', 'cone_opt_params_range.pkl')
]

In [None]:
import param_funcs
importlib.reload(param_funcs)

# Wrap the defaults and ranges of the optimized parameters in a Parameters object.
params = param_funcs.Parameters(p_default=opt_params_default, p_range=opt_params_range)

In [None]:
# Visualize the optimized parameters (Parameters.plot).
params.plot()

## Optimizer

In [None]:
import optim_funcs
importlib.reload(optim_funcs)

# Re-assign predur on the cell from the notebook-level value.
# NOTE(review): presumably guards against earlier cells having changed it — confirm.
cell.predur = predur

# Optimizer over the stimulus time window. raw2model_data_labels maps the recorded signal names
# to the names used in the model output ('rate', 'Vm').
optim = optim_funcs.Optimizer(
    cell=cell, params=params, t_rng=stim_t_rng, output_folder=output_folder,
    raw_data_labels       = ['rate Cone', 'Vm Soma'],
    raw2model_data_labels = {'rate Cone': 'rate', 'Vm Soma': 'Vm'},
)

# Initialize the recording data (timed with IPython's %time); allow_loading=True presumably
# reuses previously stored data when available — confirm in optim_funcs.
%time optim.init_rec_data(allow_loading=True, force_loading=False, verbose=True)

## Loss

In [None]:
import loss_funcs
importlib.reload(loss_funcs)

# Loss between the target dF/F and the model output. t_drop lies 0.5 after the start of the
# optimizer's time range (presumably data before t_drop is ignored by the loss — confirm).
loss_rec_time = optim.get_rec_time()
loss_t_drop = 0.5 + optim.get_t_rng()[0]

loss = loss_funcs.LossOptimizeCell(
    target=target_dF_F,
    rec_time=loss_rec_time,
    t_drop=loss_t_drop,
    loss_params='Cone eq and range',
    absolute=False,
    mode='gauss',
)
optim.loss = loss

In [None]:
# Plot the parameters of the loss function.
loss.plot_loss_params()

In [None]:
# Sanity check: evaluate (and plot) the loss on the optimizer's recording data.
loss_output = loss.calc_loss(optim.rec_data['Data'], plot=True, verbose=True);

## Prior

In [None]:
from TruncatedNormal import TruncatedNormal

# Prior: Gaussian truncated to [0, 1] in every dimension of the optimization space, centred at
# the default parameters, with std 0.3 per dimension and a diagonal covariance.
n_params = len(params.p_names)
lower = np.zeros(n_params)
upper = np.ones(n_params)

mean = [
    optim.params.sim_param2opt_param(opt_params_default[p_name], p_name)
    for p_name in params.p_names
]
std = [0.3] * n_params

prior = TruncatedNormal(m=np.array(mean), S=np.diag(np.square(std)), lower=lower, upper=upper)

In [None]:
import plot_sampling_dists

# 1D plots of the prior alone (no posteriors yet), drawn on [-0.5, 1.5] in every dimension.
n_dims = params.p_N
PP = plot_sampling_dists.SamplingDistPlotter(
    params=params,
    prior=prior,
    posterior_list=[],
    lbs=np.full(n_dims, -0.5),
    ubs=np.full(n_dims, 1.5),
)
PP.plot_sampling_dists_1D(opt_x=False, plot_peak_lines=False, figsize=(12, 8))

## Save optimization parameters

In [None]:
# Save the optimization setup next to the optimizer output. All four files use the same folder,
# optim.output_folder, which the later post-data and marginal-data folders are also built from,
# so the setup files always end up together.
optim_data_folder = os.path.join('optim_data', optim.output_folder)
data_utils.save_var(opt_params_default, os.path.join(optim_data_folder, 'opt_params_default.pkl'))
data_utils.save_var(opt_params_range, os.path.join(optim_data_folder, 'opt_params_range.pkl'))
data_utils.save_var(loss, os.path.join(optim_data_folder, 'loss.pkl'))
data_utils.save_var(params, os.path.join(optim_data_folder, 'params.pkl'))

# Inference

In [None]:
import os
# Select the GNU threading layer for MKL before the GPU test runs.
# NOTE(review): presumably required for compatibility of the inference backend with MKL — confirm.
os.environ["MKL_THREADING_LAYER"] = "GNU"

import gpu_test
# Abort the notebook here if the GPU test does not pass.
assert gpu_test.run(verbose=False)

In [None]:
# Index of the 'iGluSnFR' loss component; it is used as the pseudo-observation dimension.
# Look the key up explicitly. Comparing a plain list to a string yields a single False, and
# np.argmax(False) is always 0, whatever the key order. list.index raises ValueError if the key is missing.
loss_component_names = list(
    optim.model_output2dict({}, 0, rec_data=optim.rec_data['Data'])['loss'].keys()
)
pseudo_obs_dim = loss_component_names.index('iGluSnFR')
print(pseudo_obs_dim)

In [None]:
# Simulation budget: a small one for quick tests, the paper's settings otherwise.
is_test_mode = inference_mode == 'test'
n_samples_per_round, max_rounds, gen_minibatch = (40, 2, 20) if is_test_mode else (2000, 4, 200)
if is_test_mode:
    print('WARNING: Test mode selected. Results will differ from paper data!')

print(n_samples_per_round, '*', max_rounds, 'samples')

In [None]:
import delfi_funcs

# Flags passed to run_SNPE below. They were never defined in this notebook, so 'test' and
# 'full_inference' mode failed with a NameError. Both default to False here, which presumably
# starts a fresh run. Set them to True to resume a previous optimization or to load initial
# training data (confirm against delfi_funcs.DELFI_Optimizer.run_SNPE).
continue_optimization = False
load_init_tds = False

delfi_optim = delfi_funcs.DELFI_Optimizer(
    optim=optim, prior=prior, n_parallel=30,
    gen_minibatch=gen_minibatch, scalar_loss=False,
    post_as_truncated_normal=True,
)

# In 'load_only' mode no new samples are generated; the stored results are loaded in the next section.
if inference_mode != 'load_only':
    delfi_optim.init_SNPE(
        verbose                 = False,
        pseudo_obs_dim          = pseudo_obs_dim,
        pseudo_obs_n            = 1,
        kernel_bandwidth        = 0.25,
        kernel_bandwidth_perc   = 20,
        pseudo_obs_use_all_data = False,
        n_components            = 1,
    )

    delfi_optim.run_SNPE(
        max_duration_minutes  = 60*24,
        max_rounds            = max_rounds,
        n_samples_per_round   = n_samples_per_round,
        continue_optimization = continue_optimization,
        load_init_tds         = load_init_tds,
    )

# Plot results

## Load data

In [None]:
def load_snpe_file(filename):
    """Load `filename` from the SNPE output folder with data_utils.load_var."""
    return data_utils.load_var(os.path.join(delfi_optim.snpe_folder, filename))


tds                  = load_snpe_file('tds.pkl')
inf_snpes            = load_snpe_file('inf_snpes.pkl')
sample_distributions = load_snpe_file('sample_distributions.pkl')
logs                 = load_snpe_file('logs.pkl')
pseudo_obs           = load_snpe_file('pseudo_obs.pkl')
kernel_bandwidths    = load_snpe_file('kernel_bandwidths.pkl')

# The first stored distribution is the prior; all remaining ones are posteriors.
prior = sample_distributions[0]
posteriors = sample_distributions[1:]

In [None]:
samples, n_samples, d_sort_index = delfi_optim.load_samples(
    concat_traces=True,
    list_traces=False,
    return_sort_idx=True,
    return_n_samples=True,
    verbose=True,
)

# First entry of the sort index (presumably the sample with the lowest total loss — confirm).
d_min_idx = d_sort_index[0]
d_min = samples['loss']['total'][d_min_idx]
print('\nd_min = {:.5f}'.format(d_min))

## Plot training data

In [None]:
import plot_obs_and_bw

# Pseudo-observation and kernel bandwidth of the iGluSnFR loss dimension, one entry per stored
# element (presumably one per SNPE round).
obs_per_entry = [obs_i[0, pseudo_obs_dim] for obs_i in pseudo_obs]
bw_per_entry = [bw_i[pseudo_obs_dim] for bw_i in kernel_bandwidths]

plot_obs_and_bw.plot(obs_per_entry, bw_per_entry)
plot_obs_and_bw.plot_logs(logs)

In [None]:
# Plot the training data weights along the pseudo-observation dimension
# (NOTE(review): 'iws' presumably means importance weights — confirm).
import plot_iws
plot_iws.plot_iws(tds, pseudo_obs_dim, pseudo_obs=None, kernel_bandwidths=None)

In [None]:
# Plot the simulation execution times, with lines at the sample counts in n_samples
# (presumably the round boundaries — confirm).
import plot_samples
plot_samples.plot_execution_time(samples, lines=n_samples)

In [None]:
# Plot the sample losses grouped by round.
plot_samples.plot_loss_rounds(samples, n_samples)

## Plot samples

In [None]:
# Plot the best sample, then the best 20 samples.
plot_samples.plot_best_samples(samples=samples, time=optim.get_rec_time(), loss=optim.loss, n=1)
plot_samples.plot_best_samples(samples=samples, time=optim.get_rec_time(), loss=optim.loss, n=20)

In [None]:
import plot_peaks

# Full optimized time range followed by zoomed-in windows.
xlims = [optim.get_t_rng(), (1.5, 2.5), (4, 6), (10, 14), (13.5, 17), (12, 13), (22, 25), (27, 31)]

# iGluSnFR trace of the sample at d_min_idx, compared against the target.
best_sample_iGlu = loss.rate2best_iGluSnFR_trace(samples['rate'][d_min_idx])[0]

trace_peaks = plot_peaks.compare_peaks_in_traces(
    trace_list=[loss.target, best_sample_iGlu],
    time_list=optim.loss.target_time,
    label_list=['target', 'fit'],
    color_list=['r', 'b'],
    # Peak-detection settings, one dict per trace (same order as trace_list).
    params_dict_list=[
        {'height_pos': 0.2, 'prom': 0.16},
        {'height_pos': 0.1, 'prom': 0.05},
    ],
    xlims=xlims,
    base_trace_i=0,
    ignore_rec_times=[(15, 24)],  # Ignore noisy parts.
)

# Posterior

In [None]:
# Folder for data derived from the final posterior (created via data_utils.make_dir).
post_data_folder = os.path.join('optim_data', optim.output_folder, 'post_data')
data_utils.make_dir(post_data_folder)

In [None]:
# The last stored posterior is used for all following analyses.
final_posterior = posteriors[-1]

## Plot Posteriors

In [None]:
import plot_sampling_dists

PP = plot_sampling_dists.SamplingDistPlotter(
    params=params,
    prior=delfi_optim.prior,
    posterior_list=posteriors,
    lbs=prior.lower,
    ubs=prior.upper,
)

# Stack the first element of every tds entry (presumably the sampled parameters — confirm) and
# mark the first 10 samples in d_sort_index order in the 1D plots.
stacked_tds_params = np.concatenate([tds_i[0] for tds_i in tds])
top10_params = stacked_tds_params[d_sort_index[:10], :]

PP.plot_sampling_dists_1D(
    opt_x=True,
    params=None,
    plot_peak_lines=False,
    figsize=(12, 8),
    opt_samples=top10_params,
)

In [None]:
# Plot parameter correlations from the final posterior's matrix S (the covariance-type argument
# also used to construct the TruncatedNormal prior).
PP.plot_correlation(final_posterior.S)

## Sample from posterior

In [None]:
import analyze_posterior_utils

post_n_samples = 20 if inference_mode == 'test' else 200

# Draw parameter samples from the final posterior; the fixed seed makes them reproducible.
post_opt_params = analyze_posterior_utils.get_samples(
    posterior=final_posterior,
    n_samples=post_n_samples,
    seed=777,
    plot=True,
    prior=prior,
    params=params,
    plot_opt_x=False,
)

In [None]:
# In 'load_only' mode the stored simulations are loaded; otherwise the samples are simulated.
load_rec_data_list = (inference_mode == 'load_only')
post_outputs_file = os.path.join(post_data_folder, 'post_model_output_list.pkl')

post_model_output_list = analyze_posterior_utils.gen_or_load_samples(
    optim=optim,
    opt_params=post_opt_params,
    load=load_rec_data_list,
    filename=post_outputs_file,
)

# Every posterior sample must have an output (successful or not).
assert len(post_model_output_list) == post_n_samples

In [None]:
post_samples = optim.stack_model_output_list(post_model_output_list)
# Combine the SNPE samples with the posterior samples and rank them by total loss (ascending).
all_samples = optim.stack_model_output_list([samples, post_samples])
all_samples_sort_idx = np.argsort(all_samples['loss']['total'])

### Plot

In [None]:
# Report failed posterior simulations. The returned list holds the indices of the successful
# ones, which are used below to index post_model_output_list.
import print_num_failed
post_success_list = print_num_failed.print_num_failed(post_model_output_list)

In [None]:
# Plot the best posterior sample, then the best 20 posterior samples.
plot_samples.plot_best_samples(post_samples, optim.get_rec_time(), loss=optim.loss, n=1)
plot_samples.plot_best_samples(post_samples, optim.get_rec_time(), loss=optim.loss, n=20)

In [None]:
# Treat the posterior samples as one additional "round" after the SNPE rounds
# (assumes n_samples holds cumulative sample counts — confirm).
n_samples_with_post = np.append(n_samples, n_samples[-1] + len(post_model_output_list))
plot_samples.plot_loss_rounds(all_samples, n_samples=n_samples_with_post, equal_x=True)

## Summarize posterior samples and save to file

In [None]:
# Collect the traces of all successful posterior simulations (nothing is plotted here).
# Row i corresponds to post_success_list[i]; unfilled entries would stay NaN.
n_success = len(post_success_list)
iGlus = np.full((n_success, optim.loss.target_time.size), np.nan)
rates = np.full((n_success, optim.rec_ex_size), np.nan)
Vms   = np.full((n_success, optim.rec_ex_size), np.nan)

for row, sample_idx in enumerate(post_success_list):
    model_output = post_model_output_list[sample_idx]
    iGlus[row, :] = loss.rate2best_iGluSnFR_trace(model_output['rate'])[0]
    rates[row, :] = model_output['rate']
    Vms[row, :]   = model_output['Vm']

In [None]:
def save_post_var(var, filename):
    """Save `var` as `filename` inside post_data_folder."""
    data_utils.save_var(var, os.path.join(post_data_folder, filename))


# Save the posterior traces, the final posterior and the simulation parameters of each sample.
save_post_var(optim.get_rec_time(), 'rec_time.pkl')
save_post_var(iGlus, 'iGlus.pkl')
save_post_var(Vms, 'Vms.pkl')
save_post_var(rates, 'rates.pkl')
save_post_var(final_posterior, 'distribution.pkl')
save_post_var(
    [params.opt_params2sim_params(opt_params) for opt_params in post_opt_params],
    's_params_list.pkl',
)

## Posterior marginals

In [None]:
# Fixed seed so the marginal samples are reproducible (the order of the gen calls below matters).
final_posterior.reseed(1356)

marg_n_samples = 20 if inference_mode=='test' else 200
marginal_o_params_arr = np.empty((marg_n_samples, params.p_N))

# Fill each parameter column from a separate posterior draw. Assuming gen returns fresh random
# samples on each call, the columns are independent: the 1D marginals are kept, but the
# correlations between parameters are removed.
for p_idx in range(params.p_N):
    marginal_o_params_arr[:,p_idx] = final_posterior.gen(marg_n_samples)[:,p_idx]

In [None]:
# Final posterior on [-0.5, 1.5] per dimension, with the marginal samples marked.
PP = plot_sampling_dists.SamplingDistPlotter(
    params=params,
    prior=prior,
    posterior_list=[final_posterior],
    lbs=np.full(params.p_N, -0.5),
    ubs=np.full(params.p_N, 1.5),
)
PP.plot_sampling_dists_1D(
    opt_x=True,
    opt_samples=marginal_o_params_arr,
    plot_peak_lines=False,
    figsize=(12, 8),
)

In [None]:
# In 'load_only' mode the stored simulations are loaded; otherwise the marginal samples are simulated.
load_rec_data_list = (inference_mode == 'load_only')

marg_data_folder = os.path.join('optim_data', optim.output_folder, 'marginal_post_data')
data_utils.make_dir(marg_data_folder)
marg_outputs_file = os.path.join(marg_data_folder, 'rec_data_list_from_marginals.pkl')

marginal_model_output_list = analyze_posterior_utils.gen_or_load_samples(
    optim=optim,
    opt_params=marginal_o_params_arr,
    load=load_rec_data_list,
    filename=marg_outputs_file,
)

assert len(marginal_model_output_list) == marg_n_samples

# Best sample(s)

## Loss

In [None]:
# Total loss of all samples sorted ascending, against rank on log-log axes. The title lists the
# indices and losses of the 7 best samples.
sorted_total_losses = all_samples['loss']['total'][all_samples_sort_idx]
ranks = np.arange(1, all_samples_sort_idx.size + 1)

plt.loglog(ranks, sorted_total_losses, '.')
top7_loss_strs = ["{:.3f}".format(loss_value) for loss_value in sorted_total_losses[:7]]
plt.title(str(all_samples_sort_idx[:7]) + '\n' + str(top7_loss_strs))
plt.show()

## Summarize

In [None]:
# Save the output of the best sample overall (lowest total loss among SNPE and posterior samples).
best_idx = all_samples_sort_idx[0]
best_rate = all_samples['rate'][best_idx]
best_Vm = all_samples['Vm'][best_idx]

final_model_output = {
    'rate':        best_rate,
    'rate-off':    best_rate - best_rate[0],   # Offset so the trace starts at 0.
    'iGlu':        loss.rate2best_iGluSnFR_trace(best_rate)[0],
    'Vm':          best_Vm,
    'Vm-off':      best_Vm - best_Vm[0],       # Offset so the trace starts at 0.
    'Time':        delfi_optim.optim.get_rec_time(),
    'predur':      predur,
    't_rng':       delfi_optim.optim.get_t_rng(),
    'Stimulus':    stimulus,
    'Target':      loss.target,
    'Time-Target': loss.target_time,
    'params_unit': params_unit.copy(),
    'params':      {name: values[best_idx] for name, values in all_samples['params'].items()},
    'loss':        {name: values[best_idx] for name, values in all_samples['loss'].items()},
}

data_utils.save_var(final_model_output, os.path.join(post_data_folder, 'final_model_output.pkl'))

In [None]:
# Display the loss components of the best sample.
final_model_output['loss']

In [None]:
# Display the parameters of the best sample.
final_model_output['params'] 

## Plot traces

In [None]:
import plot_rates_and_Vm

# Plot the successful posterior traces (iGluSnFR, rates, Vm) with the target and the best
# model output.
plot_rates_and_Vm.plot_rates_Vms_iGlus(
    iGlus=iGlus, rates=rates, Vms=Vms, target=optim.loss.target,
    ts_iGlus=optim.loss.target_time, ts_rec=optim.get_rec_time(),
    final_model_output=final_model_output,
)

## Plot peak times

In [None]:
import plot_peaks

# Full optimized time range followed by zoomed-in windows.
xlims = [optim.get_t_rng(), (1.5, 2.5), (4, 6), (10, 14), (13.5, 17), (12, 13), (22, 25), (27, 31)]

# Compare the target with the iGluSnFR trace of the best model output.
trace_list = [loss.target, loss.rate2best_iGluSnFR_trace(final_model_output['rate'])[0]]
time_list = optim.loss.target_time
label_list = ['target', 'fit']
# Presumably peak-detection settings (height / prominence), one dict per trace in trace_list.
params_dict_list = [
    {'height_pos': 0.2, 'prom': 0.16},
    {'height_pos': 0.1, 'prom': 0.05},
]

trace_peaks = plot_peaks.compare_peaks_in_traces(
    trace_list=trace_list,
    time_list=time_list,
    label_list=label_list,
    color_list=['r', 'b'],
    params_dict_list=params_dict_list,
    xlims=xlims,
    base_trace_i=0,
    ignore_rec_times=[(15, 24)],  # Ignore noisy parts.
    plot=True,
    plot_single=False,
    plot_hist=True,
)