# Inference for the OFF-BC

- [Define Target and Stimulus](#Target-and-Stimulus)
- [Create the BC model](#Cell)
- [Select loss function and parameters](#Optimizer)
- [Run inference](#Inference)
- Plot inference results:
    - [Plot results](#Plot-results)
    - [Posterior](#Posterior)
    - [Best sample(s)](#Best-sample(s))

## <font color='red'> Select mode: full_inference  / load_only / test</font>

- *full_inference*
    - Runs the whole inference.
    - Takes a long time
- *load_only*
    - Does not generate new samples; instead, loads the data generated for the paper.
- *test*
    - Runs the whole inference, but with fewer samples.
    - Illustrates how the inference works without spending too much CPU power and time.
    - However, because so few samples are generated, the inference may be poor.
    - Don't use these results in subsequent steps.

In [None]:
# Pick exactly one mode: 'test', 'full_inference' or 'load_only' (see above).
# inference_mode = 'test'
# inference_mode = 'full_inference'
inference_mode = 'load_only'

# Imports

In [None]:
import importlib

In [None]:
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import os
import sys

In [None]:
# Put the shared helper code in ../_pythoncode at the front of the import path,
# then let importhelper add its subfolders as well.
pythoncodepath = os.path.abspath(os.path.join('..', '_pythoncode'))
sys.path.insert(0, pythoncodepath)
import importhelper
importhelper.addfolders2path(pythoncodepath)

In [None]:
import data_utils

# Target and Stimulus

In [None]:
# Time range (start, end) on the 'Time' axis used for plotting and for the optimizer.
stim_t_rng = 1, 32

In [None]:
# Load the preprocessed experimental target trace and the corrected light stimulus.
data_folder = os.path.join('..', 'ExperimentalData', 'PreprocessedData')
target_dF_F = pd.read_csv(
    os.path.join(data_folder, 'Franke2017_Release_BC3a_Strychnine.csv')
)
stimulus = pd.read_csv(
    os.path.join(data_folder, 'Franke2017_stimulus_time_and_amp_corrected.csv')
)

In [None]:
# Stimulus (top) and target (bottom), both restricted to the stimulus time range.
fig, axs = plt.subplots(nrows=2, ncols=1, figsize=(12, 4), subplot_kw={'xlim': stim_t_rng})
for _ax, _frame in zip(axs, (stimulus, target_dF_F)):
    _frame.plot(x='Time', ax=_ax)
plt.show()

# Cell

In [None]:
bp_type = 'CBC3a'
predur = 10

# Default cell parameters and their units, plus the two cpl dictionaries
# (optimization vs. final), all stored as cell_params/<bp_type>_*.pkl.
params_default = data_utils.load_var(os.path.join('cell_params', bp_type + '_cell_params_default.pkl'))
params_unit    = data_utils.load_var(os.path.join('cell_params', bp_type + '_cell_params_unit.pkl'))

optimize_cpl_dict = data_utils.load_var(os.path.join('cell_params', bp_type + '_optimize_cpl_dict.pkl'))
final_cpl_dict    = data_utils.load_var(os.path.join('cell_params', bp_type + '_final_cpl_dict.pkl'))

# The optimization cpl values override the corresponding defaults.
params_default.update(optimize_cpl_dict)

## Create model

In [None]:
import retsim_cells

# Bipolar cell model driven by the light stimulus. It uses the OFF-optimization
# density/channel files and two experiment files; expt_file_list[0] is later
# used as expt_idx=0 and [1] as expt_test_idx=1 by the Optimizer.
cell = retsim_cells.CBC(
    bp_type=bp_type,
    predur=predur,
    t_rng=(1.9, 2.3),
    params_default=params_default,
    params_unit=params_unit,
    stimulus=stimulus,
    stim_type='Light',
    cone_densfile='dens_cone_optimized_submission2.n',
    bp_densfile='dens_CBC3a_optimize_OFF.n',
    nval_file='nval_optimize_CBCs.n',
    chanparams_file='chanparams_CBC3a_optimize_OFF.n',
    expt_file_list=['optimize_OFF', 'test_stability_Vclamp_OFF'],
    expt_base_file_list=[
        'retsim_files/expt_CBC_base.cc',
        'retsim_files/expt_test_stability_Vclamp_base.cc',
    ],
    # Absolute path to the retsim model folder, passed with a trailing '/'.
    retsim_path=os.path.abspath(os.path.join('..', 'NeuronC', 'models', 'retsim')) + '/',
)

In [None]:
import plot_cell_morph

# Morphology nodes the cones connect to (also used for the expt file below).
cones_connect_to_nodes = [686, 1037, 828, 950, 879]

# 3D morphology plot, given those nodes as node_list.
plot_cell_morph.plot_3D_cell(
    morph_data=cell.morph_data,
    node_list=cones_connect_to_nodes,
)

In [None]:
# Generate the retsim C++ experiment file(s) for this cell, passing the
# cone-connected nodes as off2cone_nodes.
cell.create_retsim_expt_file(verbose=False, off2cone_nodes=cones_connect_to_nodes)
# Compile via an IPython shell escape; the parentheses run `cd ... && make` in a
# subshell, so the notebook's working directory is not changed.
!(cd {cell.retsim_path} && make)

In [None]:
# Initialise retsim with the optimize cpl settings.
cell.update_cpl(**optimize_cpl_dict)
cell.init_retsim(verbose=False)

# 2D compartment view including the connections to the cone nodes.
plot_cell_morph.plot_2D_cell(
    morph_data=cell.morph_data,
    comp_data=cell.comp_data,
    node_list=cones_connect_to_nodes,
    plot_connections=True,
)

# Initialise once more, this time with init_retsim's own plot enabled.
cell.init_retsim(plot=True, verbose=False)

In [None]:
# Switch to the final cpl settings and show the resulting compartments.
cell.update_cpl(**final_cpl_dict)

plot_cell_morph.plot_2D_cell(
    morph_data=cell.morph_data,
    comp_data=cell.comp_data,
    node_list=cones_connect_to_nodes,
)

cell.init_retsim(verbose=False)

# Test model

This step can be skipped.

In [None]:
# Short simulation window and the 'optimize' recording type for the quick tests below.
cell.update_t_rng((1.95, 2.15))
cell.rec_type = 'optimize'

In [None]:
# Run once with the optimize cpl settings; the %time magic reports how long
# this single simulation takes.
cell.update_cpl(**optimize_cpl_dict)
%time rec_data1, rec_time1, rec_stim1 = cell.run(plot=True, verbose=False)

In [None]:
# Same simulation with the final cpl settings, for comparison with the run above.
cell.update_cpl(**final_cpl_dict)
%time rec_data2, rec_time2, rec_stim2 = cell.run(plot=True, verbose=False)

In [None]:
# Compare soma Vm (left) and rate (right) of the two cpl settings.
# They do not have to be equal, but should be relatively close.
plt.figure(1, figsize=(12, 3))

plt.subplot(1, 2, 1)
plt.plot(rec_time1, rec_data1['BC Vm Soma'], c='red')
plt.plot(rec_time2, rec_data2['BC Vm Soma'], c='blue', ls='--')

plt.subplot(1, 2, 2)
line_opt, = plt.plot(rec_time1, rec_data1['rate BC'], c='red', label='optimize cpl')
line_final, = plt.plot(rec_time2, rec_data2['rate BC'], c='blue', ls='--', label='final cpl')
plt.legend(handles=[line_opt, line_final])
plt.show()

## Test cones

Set the optimization folder to the cone folder; this simulates the cones and compares their output to the desired output.

In [None]:
import retsim_cell_tests

# Post-processed output of the optimized cone model (step 1a).
cone_post_data_folder = os.path.join(
    '..', 'step1a_optimize_cones', 'optim_data', 'optimize_cone_submission2', 'post_data',
)

# Simulate the cones and compare them with the stored final cone model output.
retsim_cell_tests.test_cones(
    cell,
    os.path.join(cone_post_data_folder, 'final_model_output.pkl'),
    t_rng=(1, 2.5),
);

## Test if parameters are used in retsim

In [None]:
import retsim_params_test

# Shortened settings for the parameter check.
cell.predur = 2
cell.update_t_rng((4.9, 5.2))
cell.rec_type = 'optimize'
cell.update_cpl(**optimize_cpl_dict)

# Check whether the parameters in params_default are used by retsim;
# the two returned lists are checked in the next cell.
all_equal_params, all_close_params = retsim_params_test.test_if_params_are_used(
    cell=cell,
    params=params_default,
    Vm_name='BC Vm Soma',
    rate_name='rate BC',
)

In [None]:
# Both lists returned by the parameter test must be empty. On failure, show
# their contents so the offending parameters are visible immediately.
assert len(all_equal_params) == 0, 'all_equal_params: {}'.format(all_equal_params)
assert len(all_close_params) == 0, 'all_close_params: {}'.format(all_close_params)

# Optimizer

In [None]:
# Folder below optim_data/ that results are read from / written to:
# 'load_only' uses the data generated for the paper, the other modes a fresh folder.
if inference_mode == 'load_only':
    output_folder = 'optimize_OFF_submission2'
elif inference_mode in ('test', 'full_inference'):
    output_folder = 'optimize_OFF'
else:
    # Name the bad value so a typo in inference_mode is easy to spot.
    raise NotImplementedError(
        "Unknown inference_mode {!r}; expected 'full_inference', 'load_only' or 'test'.".format(
            inference_mode
        )
    )

print('Inference:', inference_mode, '--> Folder:', output_folder)

## Save number of compartments.

In [None]:
# Number of BC compartments for each cpl setting. The optimize setting is applied
# last, so the cell ends up in the optimize cpl state.
n_cpl_dict = {}
for _cpl_name, _cpl_values in (('final', final_cpl_dict), ('optimize', optimize_cpl_dict)):
    cell.update_cpl(**_cpl_values)
    n_cpl_dict[_cpl_name] = cell.n_bc_comps

data_utils.save_var(n_cpl_dict, os.path.join('optim_data', output_folder, 'n_cpl_dict.pkl'))

## Parameters

In [None]:
# Load the optimization parameters (defaults and ranges).
opt_params_default = data_utils.load_var(os.path.join('cell_params', bp_type + '_opt_params_default.pkl'))
opt_params_range   = data_utils.load_var(os.path.join('cell_params', bp_type + '_opt_params_range.pkl'))

# Store copies of the inputs of this run next to its results.
_run_folder = os.path.join('optim_data', output_folder)
for _obj, _fname in (
    (opt_params_default, 'opt_params_default.pkl'),
    (optimize_cpl_dict, 'optimize_cpl_dict.pkl'),
    (final_cpl_dict, 'final_cpl_dict.pkl'),
):
    data_utils.save_var(_obj, os.path.join(_run_folder, _fname))

In [None]:
import param_funcs

# Wrap defaults and ranges in a Parameters object and store it with the run.
params = param_funcs.Parameters(p_default=opt_params_default, p_range=opt_params_range)
data_utils.save_var(params, os.path.join('optim_data', output_folder, 'params.pkl'))

In [None]:
# Overview of the optimization parameters, plotted with opt_bounds=(0, 1).
params.plot(opt_bounds=(0, 1))

## Optimizer

In [None]:
import retsim_test_2nd_eq

# Test object handed to the Optimizer below (test=...).
test = retsim_test_2nd_eq.test_class()

In [None]:
import optim_funcs

# Restore the cell settings that the test cells above may have changed.
cell.predur = predur
cell.update_cpl(**optimize_cpl_dict)
cell.set_stim(stimulus)

optim = optim_funcs.Optimizer(
    cell=cell,
    params=params,
    t_rng=stim_t_rng,
    timeout=60 * 60 * 30,  # 108000 (30 h if the unit is seconds)
    output_folder=output_folder,
    raw_data_labels=['rate BC', 'BC Vm Soma'],
    raw2model_data_labels={'rate BC': 'rate', 'BC Vm Soma': 'Vm'},
    expt_idx=0,
    expt_test_idx=1,
    test=test,
)

# Recorded data; allow_loading=True permits reusing stored data.
optim.init_rec_data(allow_loading=True, force_loading=False, verbose=True)

## Loss

In [None]:
import loss_funcs
importlib.reload(loss_funcs);

# Loss against the target trace; t_drop lies 0.5 after the start of the
# optimizer's time range.
loss = loss_funcs.LossOptimizeCell(
    target=target_dF_F,
    rec_time=optim.get_rec_time(),
    t_drop=optim.get_t_rng()[0] + 0.5,
    loss_params='BC OFF no rate limit',
    absolute=False,
    mode='gauss',
)
optim.loss = loss

data_utils.save_var(loss, os.path.join('optim_data', optim.output_folder, 'loss.pkl'))

In [None]:
# Visualise the loss configuration.
loss.plot_loss_params()

In [None]:
# Evaluate (and plot) the loss on the data prepared by init_rec_data.
loss_output = loss.calc_loss(
    optim.rec_data['Data'], plot=True, verbose=True,
);

## Prior

In [None]:
from TruncatedNormal import TruncatedNormal

# Prior in optimization-parameter space: an uncorrelated normal per parameter,
# truncated to [0, 1], centred on the converted default value with std 0.3.
_n_params = len(params.p_names)
lower = np.zeros(_n_params)
upper = np.ones(_n_params)

mean = [optim.params.sim_param2opt_param(opt_params_default[name], name) for name in params.p_names]
std = [0.3] * _n_params

prior = TruncatedNormal(m=np.array(mean), S=np.diag(np.array(std) ** 2), lower=lower, upper=upper)

In [None]:
import plot_sampling_dists
importlib.reload(plot_sampling_dists);

# Plot the prior alone (no posteriors yet) on [-0.5, 1.5] for every parameter.
PP = plot_sampling_dists.SamplingDistPlotter(
    params=params,
    prior=prior,
    posterior_list=[],
    lbs=np.full(params.p_N, -0.5),
    ubs=np.full(params.p_N, 1.5),
)
PP.plot_sampling_dists_1D(plot_peak_lines=False, figsize=(12, 8), opt_x=False)

# Inference

In [None]:
import os
# Select the GNU MKL threading layer before gpu_test is imported.
os.environ["MKL_THREADING_LAYER"] = "GNU"

import gpu_test
# Inference is only started when the GPU test succeeds.
assert gpu_test.run(verbose=False)

In [None]:
# Position of the 'iGluSnFR' component in the loss dict; it is passed to SNPE
# as pseudo_obs_dim below.
_loss_names = list(
    optim.model_output2dict({}, 0, rec_data=optim.rec_data['Data'])['loss'].keys()
)
# `list == 'iGluSnFR'` is a single False, not an element-wise mask, so the former
# np.argmax(...) always returned 0. Look the name up explicitly; this raises
# ValueError if the component is missing instead of silently using index 0.
pseudo_obs_dim = _loss_names.index('iGluSnFR')
print(pseudo_obs_dim)

In [None]:
# Simulation budget: small in 'test' mode, otherwise the paper's settings.
_is_test_mode = inference_mode == 'test'
n_samples_per_round, max_rounds, gen_minibatch = (40, 2, 20) if _is_test_mode else (2000, 4, 200)
if _is_test_mode:
    print('WARNING: Test mode selected. Results will differ from paper data!')

print(n_samples_per_round, '*', max_rounds, 'samples')

In [None]:
import delfi_funcs

delfi_optim = delfi_funcs.DELFI_Optimizer(
    optim=optim, prior=prior, n_parallel=30,
    gen_minibatch=gen_minibatch, scalar_loss=False,
    post_as_truncated_normal=True,
)

# These flags are passed to run_SNPE below but were never defined anywhere,
# so 'test' / 'full_inference' mode raised a NameError. Default to a fresh run.
# NOTE(review): presumably True resumes an earlier run / reuses initial training
# data - confirm against delfi_funcs.DELFI_Optimizer.run_SNPE before changing.
continue_optimization = False
load_init_tds = False

# In 'load_only' mode nothing is simulated; results are loaded in the next cells.
if inference_mode != 'load_only':
    delfi_optim.init_SNPE(
        verbose                 = False,
        pseudo_obs_dim          = pseudo_obs_dim,
        pseudo_obs_n            = 1,
        kernel_bandwidth        = 0.25,
        kernel_bandwidth_perc   = 20,
        pseudo_obs_use_all_data = False,
        n_components            = 1,
    )

    delfi_optim.run_SNPE(
        max_duration_minutes  = 60*24,
        max_rounds            = max_rounds,
        n_samples_per_round   = n_samples_per_round,
        continue_optimization = continue_optimization,
        load_init_tds         = load_init_tds,
    )

# Plot results

## Load data

In [None]:
def _load_snpe_var(name):
    """Load ``<name>.pkl`` from the SNPE output folder of ``delfi_optim``."""
    return data_utils.load_var(os.path.join(delfi_optim.snpe_folder, name + '.pkl'))


tds                  = _load_snpe_var('tds')
inf_snpes            = _load_snpe_var('inf_snpes')
sample_distributions = _load_snpe_var('sample_distributions')
logs                 = _load_snpe_var('logs')
pseudo_obs           = _load_snpe_var('pseudo_obs')
kernel_bandwidths    = _load_snpe_var('kernel_bandwidths')

# The first stored distribution is the prior (it replaces the `prior` defined
# earlier); the remaining ones are the posteriors.
prior, posteriors = sample_distributions[0], sample_distributions[1:]

In [None]:
# All SNPE samples (traces concatenated), sample counts and the loss sort order.
samples, n_samples, d_sort_index = delfi_optim.load_samples(
    concat_traces=True,
    list_traces=False,
    return_sort_idx=True,
    return_n_samples=True,
    verbose=True,
)

# The first sort index marks the sample with the smallest total loss.
d_min_idx = d_sort_index[0]
_d_min = samples['loss']['total'][d_min_idx]
print('\nd_min = {:.5f}'.format(_d_min))

## Plot training data

In [None]:
import plot_obs_and_bw
importlib.reload(plot_obs_and_bw)

# Pseudo-observation and kernel bandwidth of the pseudo_obs_dim component for each stored entry.
_obs_values = [obs_i[0, pseudo_obs_dim] for obs_i in pseudo_obs]
_bw_values = [bw_i[pseudo_obs_dim] for bw_i in kernel_bandwidths]
plot_obs_and_bw.plot(_obs_values, _bw_values)
plot_obs_and_bw.plot_logs(logs)

In [None]:
import plot_iws
importlib.reload(plot_iws)

# Plot for the training data sets (iws presumably = importance weights),
# without pseudo-observation / bandwidth overlays.
plot_iws.plot_iws(
    tds, pseudo_obs_dim,
    pseudo_obs=None, kernel_bandwidths=None,
)

In [None]:
import plot_samples
importlib.reload(plot_samples);

# Execution time per sample, with lines at the sample counts n_samples.
plot_samples.plot_execution_time(
    samples,
    lines=n_samples,
)

In [None]:
import plot_samples
importlib.reload(plot_samples);

# Loss of the samples per round (equal_x=True).
plot_samples.plot_loss_rounds(
    samples, n_samples,
    equal_x=True,
)

## Plot samples

In [None]:
import plot_samples
importlib.reload(plot_samples);

# Best sample, then the 20 best samples, of all SNPE samples.
plot_samples.plot_best_samples(
    samples=samples, time=optim.get_rec_time(), loss=optim.loss, n=1,
)
plot_samples.plot_best_samples(
    samples=samples, time=optim.get_rec_time(), loss=optim.loss, n=20,
)

In [None]:
import plot_peaks
importlib.reload(plot_peaks);

# Full optimization window followed by several zoomed-in windows.
xlims = [optim.get_t_rng(), (1.5, 2.5), (4, 6), (10, 14), (13.5, 17), (12, 13), (22, 25), (27, 31)]

# Target vs. iGluSnFR trace derived from the rate of the lowest-loss sample.
label_list = ['target', 'fit']
trace_list = [loss.target, loss.rate2best_iGluSnFR_trace(samples['rate'][d_min_idx])[0]]
time_list = optim.loss.target_time
params_dict_list = [
    {'height_pos': 0.06, 'height_neg': 0.01, 'prom': 0.12},
    {'height_pos': 0.01, 'prom': 0.1},
    {'height_pos': 0.01, 'prom': 0.03},
]

trace_peaks = plot_peaks.compare_peaks_in_traces(
    trace_list=trace_list,
    time_list=time_list,
    plot_single=False,
    plot_hist=True,
    plot=True,
    params_dict_list=params_dict_list,
    color_list=['r', 'b'],
    label_list=label_list,
    xlims=xlims,
    base_trace_i=0,
    ignore_rec_times=[(15, 24)],  # Ignore noisy parts.
)

# Posterior

In [None]:
# Output folder for the posterior analysis of this run.
post_data_folder = os.path.join(
    'optim_data', optim.output_folder, 'post_data',
)
data_utils.make_dir(post_data_folder)

In [None]:
# The last stored posterior is used for all further analysis.
final_posterior = posteriors[-1]

## Plot posteriors

In [None]:
import plot_sampling_dists
importlib.reload(plot_sampling_dists);

# Prior and all posteriors, bounded by the prior's truncation limits.
PP = plot_sampling_dists.SamplingDistPlotter(
    params=params,
    prior=delfi_optim.prior,
    posterior_list=posteriors,
    lbs=prior.lower,
    ubs=prior.upper,
)

# Parameters of the first 10 entries of d_sort_index (the best samples), overlaid.
_all_tds_params = np.concatenate([tds_i[0] for tds_i in tds])
samples_to_plot = _all_tds_params[d_sort_index[:10], :]

PP.plot_sampling_dists_1D(
    opt_x=False, params=None, plot_peak_lines=False, figsize=(12, 8), opt_samples=samples_to_plot,
)

## Sample from posterior

In [None]:
import analyze_posterior_utils

# Fewer posterior samples in test mode.
post_n_samples = 20 if inference_mode == 'test' else 200

post_opt_params = analyze_posterior_utils.get_samples(
    posterior=final_posterior,
    n_samples=post_n_samples,
    seed=777,
    plot=True,
    prior=prior,
    params=params,
    plot_opt_x=True,
)

### Run with opt. cpl

In [None]:
# Load (in 'load_only' mode) or simulate the posterior samples with the optimize cpl.
load_rec_data_list = inference_mode == 'load_only'

post_model_output_list = analyze_posterior_utils.gen_or_load_samples(
    optim=optim,
    opt_params=post_opt_params,
    filename=os.path.join(post_data_folder, 'post_model_output_list_optimize_cpl.pkl'),
    load=load_rec_data_list,
)

In [None]:
import print_num_failed
importlib.reload(print_num_failed);

# Report failed runs; the returned list holds the indices of the successful ones.
post_success_list = print_num_failed.print_num_failed(
    post_model_output_list
)

### Stack all samples

In [None]:
# Posterior samples on their own, and merged with the SNPE samples.
post_samples = optim.stack_model_output_list(post_model_output_list)
all_samples = optim.stack_model_output_list([samples, post_samples])

# Order of all samples by total loss (ascending).
_all_total_loss = all_samples['loss']['total']
all_samples_sort_idx = np.argsort(_all_total_loss)

### Plot

In [None]:
# Best and 20 best posterior samples (optimize cpl).
plot_samples.plot_best_samples(samples=post_samples, time=optim.get_rec_time(), loss=optim.loss, n=1)
plot_samples.plot_best_samples(samples=post_samples, time=optim.get_rec_time(), loss=optim.loss, n=20)

In [None]:
importlib.reload(plot_samples)

# Treat the posterior samples as one extra round appended after the SNPE rounds.
_n_post = post_samples['loss']['total'].size
_n_samples_with_post = np.append(n_samples, n_samples[-1] + _n_post)
plot_samples.plot_loss_rounds(all_samples, n_samples=_n_samples_with_post, equal_x=True)

In [None]:
# Sorted total loss of all samples (log-log), titled with the 7 best indices and losses.
_sorted_loss = all_samples['loss']['total'][all_samples_sort_idx]
plt.loglog(np.arange(1, all_samples_sort_idx.size + 1), _sorted_loss, '.')
_best_losses = ["{:.3f}".format(_v) for _v in _sorted_loss[:7]]
plt.title(str(all_samples_sort_idx[:7]) + '\n' + str(_best_losses))
plt.show()

### Run with final CPL

In [None]:
# Switch to the final cpl settings, raise the timeout, use 20 parallel runs,
# and re-initialise retsim.
cell.update_cpl(**final_cpl_dict)
cell.timeout = 100000.
optim.n_parallel = 20

cell.init_retsim()

In [None]:
# Load (in 'load_only' mode) or simulate the same posterior samples with the final cpl.
load_rec_data_list = inference_mode == 'load_only'

post_model_output_list_final_cpl = analyze_posterior_utils.gen_or_load_samples(
    optim=optim,
    opt_params=post_opt_params,
    filename=os.path.join(post_data_folder, 'post_model_output_list.pkl'),
    load=load_rec_data_list,
)

post_samples_final_cpl = optim.stack_model_output_list(post_model_output_list_final_cpl)

In [None]:
# Report failed final-cpl runs; keep the indices of the successful ones.
post_success_list_final_cpl = print_num_failed.print_num_failed(
    post_model_output_list_final_cpl
)

In [None]:
import plot_samples
importlib.reload(plot_samples);

# Best and 20 best posterior samples (final cpl).
plot_samples.plot_best_samples(samples=post_samples_final_cpl, time=optim.get_rec_time(), loss=optim.loss, n=1)
plot_samples.plot_best_samples(samples=post_samples_final_cpl, time=optim.get_rec_time(), loss=optim.loss, n=20)

### Compare CPLs

In [None]:
import plot_opt_cpl_vs_final_cpl
importlib.reload(plot_opt_cpl_vs_final_cpl);

# Losses of the posterior samples: optimize cpl vs. final cpl.
_loss_opt_cpl = post_samples['loss']
_loss_final_cpl = post_samples_final_cpl['loss']
plot_opt_cpl_vs_final_cpl.plot_post_vs_marg_sample_loss(
    post_loss=_loss_opt_cpl,
    post_loss_final_cpl=_loss_final_cpl,
)

## Summarize posterior samples and save to file

In [None]:
# Collect the traces of the successful final-cpl posterior runs (plotted below).
# The traces come from post_model_output_list_final_cpl, so the success indices
# must also come from the final-cpl runs. The optimize-cpl list
# (post_success_list) may include runs that failed with the final cpl.
iGlus = np.full((len(post_success_list_final_cpl), optim.loss.target_time.size), np.nan)
rates = np.full((len(post_success_list_final_cpl), optim.rec_ex_size), np.nan)
Vms   = np.full((len(post_success_list_final_cpl), optim.rec_ex_size), np.nan)

# Row idx_l of each array holds output idx_r of the final-cpl list.
for idx_l, idx_r in enumerate(post_success_list_final_cpl):
    output_i = post_model_output_list_final_cpl[idx_r]
    iGlus[idx_l, :] = loss.rate2best_iGluSnFR_trace(output_i['rate'])[0]
    rates[idx_l, :] = output_i['rate']
    Vms[idx_l, :]   = output_i['Vm']

In [None]:
# Save the posterior summary: time axis, traces, final posterior and the
# sampled parameters converted to simulation parameters.
data_utils.save_var(optim.get_rec_time(), os.path.join(post_data_folder, 'rec_time.pkl'))
for _obj, _fname in (
    (iGlus, 'iGlus.pkl'),
    (Vms, 'Vms.pkl'),
    (rates, 'rates.pkl'),
    (final_posterior, 'distribution.pkl'),
):
    data_utils.save_var(_obj, os.path.join(post_data_folder, _fname))

_s_params_list = [params.opt_params2sim_params(opt_params) for opt_params in post_opt_params]
data_utils.save_var(_s_params_list, os.path.join(post_data_folder, 's_params_list.pkl'))

In [None]:
import plot_rates_and_Vm
importlib.reload(plot_rates_and_Vm);

# Collected posterior traces against the target.
plot_rates_and_Vm.plot_rates_Vms_iGlus(
    iGlus=iGlus,
    target=loss.target,
    rates=rates,
    Vms=Vms,
    ts_iGlus=optim.loss.target_time,
    ts_rec=optim.get_rec_time(),
)

# Best sample(s)

## With final CPL

In [None]:
# Best posterior sample under the final cpl settings.
post_best_idx = np.argmin(post_samples_final_cpl['loss']['total'])
post_best_samples_final_cpl = post_model_output_list_final_cpl[post_best_idx].copy()

_best_rate = post_best_samples_final_cpl['rate']
_best_Vm = post_best_samples_final_cpl['Vm']

# Bundle the best sample's traces (raw and offset-corrected to their first value)
# with the simulation settings, stimulus, target, units, parameters and loss.
final_model_output = {
    'rate':        _best_rate,
    'rate-off':    _best_rate - _best_rate[0],
    'iGlu':        loss.rate2best_iGluSnFR_trace(_best_rate)[0],
    'Vm':          _best_Vm,
    'Vm-off':      _best_Vm - _best_Vm[0],
    'Time':        delfi_optim.optim.get_rec_time(),
    'predur':      predur,
    't_rng':       delfi_optim.optim.get_t_rng(),
    'Stimulus':    stimulus,
    'Target':      loss.target,
    'Time-Target': loss.target_time,
    'params_unit': params_unit.copy(),
    'params':      post_best_samples_final_cpl['params'].copy(),
    'loss':        post_best_samples_final_cpl['loss'].copy(),
}

data_utils.save_var(final_model_output, os.path.join(post_data_folder, 'final_model_output.pkl'))

In [None]:
import plot_rates_and_Vm
importlib.reload(plot_rates_and_Vm);

# Same overview as before, additionally given the best sample (final_model_output).
plot_rates_and_Vm.plot_rates_Vms_iGlus(
    iGlus=iGlus,
    rates=rates,
    Vms=Vms,
    target=optim.loss.target,
    ts_iGlus=optim.loss.target_time,
    ts_rec=optim.get_rec_time(),
    final_model_output=final_model_output,
)

In [None]:
import plot_peaks
importlib.reload(plot_peaks);

# Full optimization window followed by several zoomed-in windows.
xlims = [optim.get_t_rng(), (1.5, 2.5), (4, 6), (10, 14), (13.5, 17), (12, 13), (22, 25), (27, 31)]

# Target vs. iGluSnFR trace of the best posterior sample (final cpl).
_best_fit_trace = loss.rate2best_iGluSnFR_trace(final_model_output['rate'])[0]
trace_peaks = plot_peaks.compare_peaks_in_traces(
    trace_list=[loss.target, _best_fit_trace],
    time_list=optim.loss.target_time,
    plot_single=False,
    plot_hist=True,
    plot=True,
    params_dict_list=[
        {'height_pos': 0.1, 'prom': 0.16},
        {'height_pos': 0.1, 'prom': 0.05},
    ],
    color_list=['r', 'b'],
    label_list=['target', 'fit'],
    xlims=xlims,
    base_trace_i=0,
    ignore_rec_times=[(17, 24)],  # Ignore noisy parts.
)

## With optimize CPL

In [None]:
# Index of the best sample under the final cpl. post_model_output_list was generated
# from the same post_opt_params, so (assuming both lists keep the input order) this
# selects the same parameter set, simulated with the optimize cpl.
_best_final_cpl_idx = np.argmin(post_samples_final_cpl['loss']['total'])
post_best_sample = post_model_output_list[_best_final_cpl_idx].copy()

# Start from the final-cpl output and replace only the trace entries.
final_model_output_optimize_cpl = final_model_output.copy()
_rate, _Vm = post_best_sample['rate'], post_best_sample['Vm']
final_model_output_optimize_cpl.update({
    'rate':     _rate,
    'rate-off': _rate - _rate[0],
    'iGlu':     loss.rate2best_iGluSnFR_trace(_rate)[0],
    'Vm':       _Vm,
    'Vm-off':   _Vm - _Vm[0],
})

data_utils.save_var(
    final_model_output_optimize_cpl,
    os.path.join(post_data_folder, 'final_model_output_optimize_cpl.pkl'),
)

## Check runtime

In [None]:
%%time
# Time one simulation of the best sample's parameters with the optimize cpl.
# (%%time must stay the first line of the cell.)
cell.update_cpl(**optimize_cpl_dict)
_ = optim.run(sim_params=final_model_output['params'], verbose=True)

In [None]:
%%time
# Same timing with the final cpl, for comparison with the cell above.
cell.update_cpl(**final_cpl_dict)
_ = optim.run(sim_params=final_model_output['params'], verbose=True)

# Sample from Marginals

In [None]:
# Back to the optimize cpl for sampling from the marginals.
cell.update_cpl(**optimize_cpl_dict)
cell.init_retsim(verbose=True)

In [None]:
final_posterior.reseed(1356)

# Each column comes from its own, independent batch of posterior samples, so
# the columns are draws from the 1D marginals (correlations are discarded).
marginal_o_params_arr = np.empty((post_n_samples, params.p_N))
for _col in range(params.p_N):
    _batch = final_posterior.gen(post_n_samples)
    marginal_o_params_arr[:, _col] = _batch[:, _col]

In [None]:
import plot_sampling_dists
importlib.reload(plot_sampling_dists);

# Final posterior with the marginal samples overlaid, on [-0.5, 1.5] per parameter.
PP = plot_sampling_dists.SamplingDistPlotter(
    params=params,
    prior=prior,
    posterior_list=[final_posterior],
    lbs=np.full(params.p_N, -0.5),
    ubs=np.full(params.p_N, 1.5),
)
PP.plot_sampling_dists_1D(
    plot_peak_lines=False, figsize=(12, 8), opt_x=True, opt_samples=marginal_o_params_arr,
)

In [None]:
load_rec_data_list = inference_mode == 'load_only'

# Results of the marginal samples get their own folder.
marg_data_folder = os.path.join('optim_data', optim.output_folder, 'marginal_post_data')
data_utils.make_dir(marg_data_folder)

import analyze_posterior_utils
importlib.reload(analyze_posterior_utils);

marginal_model_output_list = analyze_posterior_utils.gen_or_load_samples(
    optim=optim,
    opt_params=marginal_o_params_arr,
    load=load_rec_data_list,
    filename=os.path.join(marg_data_folder, 'rec_data_list_from_marginals.pkl'),
)

# Exactly one model output per marginal sample.
assert len(marginal_model_output_list) == post_n_samples

In [None]:
import print_num_failed
importlib.reload(print_num_failed);

# Report how many of the marginal-sample runs failed.
print_num_failed.print_num_failed(
    marginal_model_output_list
);

## Compare marginals to full posterior

In [None]:
import plot_post_vs_marg

# Outputs of full-posterior samples vs. samples drawn from the marginals.
plot_post_vs_marg.plot_post_vs_marg(
    post_model_output_list,
    marginal_model_output_list,
)