# Estimate RC params of retina - Inference

- [Load preprocessed experimental data](#Load-preprocessed-data)
- Estimate the electrical parameters of the retina ($\sigma_{retina}$ and $\epsilon_{retina}$). Inference is split in two parts.
    - [Part 1, using logarithmic parameters](#Optimization---Part-1---Logarithmic)
    - [Part 2, using linear parameters](#Optimization---Part-2---Linear)
- [Postprocessing, test and validate results](#Postprocessing)

## <font color='red'> Select mode: full_inference  / load_only / test</font>

- *full_inference*
    - Runs the whole inference. $\Rightarrow$ **COMSOL is required.**
    - Takes a long time
- *load_only*
    - Will not generate new samples, but loads the data generated for the paper.
- *test*
    - Runs the whole inference, but with fewer samples. $\Rightarrow$ **COMSOL is required.**
    - Illustrates how the inference works, without spending too much CPU power and time.
    - However, because only a few samples are generated, the inference results may be poor.
    - Don't use these results in subsequent steps.

In [None]:
# Select exactly one mode (see the description above):
#   'test'           - short run with few samples (requires COMSOL)
#   'full_inference' - complete run (requires COMSOL)
#   'load_only'      - reuse the samples generated for the paper
inference_mode = 'load_only'

# Imports

In [None]:
import importlib
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt

In [None]:
import os
import sys

In [None]:
# Make the shared helper code in ../_pythoncode importable, searched before
# any other location, then let importhelper extend the path further
# (presumably with its sub-folders - see importhelper.addfolders2path).
pythoncodepath = os.path.abspath(os.path.join('..', '_pythoncode'))
sys.path.insert(0, pythoncodepath)

import importhelper
importhelper.addfolders2path(pythoncodepath)

In [None]:
import data_utils

# Load preprocessed data

In [None]:
# Load the preprocessed experimental data (pickles in PreprocessedData/).
V_amps = data_utils.load_var('PreprocessedData/V_amps.pkl')
EDL_phase_total = data_utils.load_var('PreprocessedData/EDL_phase_total.pkl')
fit_sin_params = data_utils.load_var(
    'PreprocessedData/raw_currents_sinus_fits_params.pkl'
)
absZ_est = data_utils.load_var('PreprocessedData/absZ_est.pkl')

In [None]:
def sin_A_phi_time(A, phi, time, freq=None):
    """Evaluate the sinusoid A * sin(2*pi*freq*time + phi), with phi in degrees.

    Parameters
    ----------
    A : amplitude of the sinusoid.
    phi : phase offset in degrees.
    time : scalar or numpy array of time points.
    freq : frequency of the sinusoid. If omitted, the module-level variable
        ``f0`` is used. The callers in this notebook rely on that fallback:
        they loop ``for f0 in [25, 40]``. Pass ``freq`` explicitly to avoid
        depending on that hidden global.

    Returns
    -------
    ``A * sin(2*pi*time*freq + phi/180*pi)``, broadcast to the shape of ``time``.
    """
    if freq is None:
        # Backward-compatible fallback: the original version always read the
        # global ``f0``.
        freq = f0
    # Degrees -> radians, kept as ``phi/180*pi`` so results are bit-identical
    # to the previous implementation.
    return A * np.sin(2 * np.pi * time * freq + phi / 180 * np.pi)

In [None]:
# Create dictionary for EDL phase.
# Re-key the EDL phases by voltage amplitude instead of list position:
# EDL_phase_total_dict[f][V] == EDL_phase_total[f][position of V in V_amps[f]].
EDL_phase_total_dict = {
    f: {V: EDL_phase_total[f][i_V] for i_V, V in enumerate(V_amps[f])}
    for f in [25, 40]
}

In [None]:
# Choose Voltage amplitudes to use for optimization.
V_amps_opt = {
    25: [300, 600],
    40: [150, 300],
}

In [None]:
# Reconstruct the measured currents from their sine fits on a common time grid
# (0 to 0.12, 1000 points). The loop variable must stay named f0:
# sin_A_phi_time reads the frequency from the global f0.
time0 = np.linspace(0, 0.12, 1000)

target = {}
for f0 in [25, 40]:
    target[f0] = {}
    for V0 in V_amps_opt[f0]:
        amp, phase = fit_sin_params["w"][f0][V0][0], fit_sin_params["w"][f0][V0][1]
        target[f0][V0] = pd.DataFrame({
            'Time': time0,
            'Current': sin_A_phi_time(amp, phase, time0),
        })

In [None]:
# Plot targets.
# Plot every target current (time in ms, current in uA), then store the
# targets for later use.
plt.figure(figsize=(12, 3))
for f0 in [25, 40]:
    for V0 in V_amps_opt[f0]:
        plt.plot(
            target[f0][V0]['Time'] * 1e3,
            target[f0][V0]['Current'] * 1e6,
            label=f'{f0} {V0}',
        )

plt.xlabel('Time [ms]')
plt.ylabel('Current [uA]')
plt.legend()
plt.show()

data_utils.save_var(target, 'PreprocessedData/target.pkl')

# Parameter estimation

In [None]:
# Output folders of the two optimization steps. 'load_only' reads the results
# generated for the paper; 'test' and 'full_inference' write new ones.
_output_folders = {
    'load_only': ('optimize_CR_step1_submission2', 'optimize_CR_step2_submission2'),
    'test': ('optimize_CR_step1', 'optimize_CR_step2'),
    'full_inference': ('optimize_CR_step1', 'optimize_CR_step2'),
}
if inference_mode not in _output_folders:
    # Same exception type as before, but say what went wrong.
    raise NotImplementedError(
        "Unknown inference_mode {!r}; expected one of {}".format(
            inference_mode, sorted(_output_folders)))
output_folder_step1, output_folder_step2 = _output_folders[inference_mode]

## Optimization - Part 1 - Logarithmic

In [None]:
print(f'Inference: {inference_mode} --> Folder: {output_folder_step1}')

### Params

Define the optimization parameters and assign a unit to each of them.

In [None]:
import param_funcs
param_funcs = importlib.reload(param_funcs)

# Part 1: both parameters start at 1 (multiplied by p_unit to obtain physical
# values) and every parameter is declared logarithmic via p_use_log.
p_default = dict(epsilon_retina=1, sigma_retina=1)
p_unit = dict(epsilon_retina=1e6, sigma_retina=0.1)

params = param_funcs.Parameters(p_default=p_default, p_use_log=list(p_default))

### Optimizer

Create the optimizer, a helper object for SNPE.

In [None]:
import optimize_COMSOL_params
optimize_COMSOL_params = importlib.reload(optimize_COMSOL_params)

# Optimizer for Part 1. reset is requested only when new samples will be
# generated, i.e. not in 'load_only' mode.
optim = optimize_COMSOL_params.OptimizerCOMSOLparams(
    params=params,
    output_folder=output_folder_step1,
    p_unit=p_unit,
    reset=inference_mode != 'load_only',
    V_amps=V_amps_opt,
    EDL_phase_total=EDL_phase_total_dict,
    absZ_est=absZ_est,
)

### Loss

In [None]:
import loss_funcs_COMSOL_params
loss_funcs_COMSOL_params = importlib.reload(loss_funcs_COMSOL_params)

# Loss against the target currents, attached to the optimizer.
# NOTE(review): t_drop presumably excludes the initial transient (it is drawn
# as a vertical line in the fit plots) - confirm in loss_funcs_COMSOL_params.
loss = optim.loss = loss_funcs_COMSOL_params.Loss(target=target, t_drop=0.04)

### Prior

Define and plot the priors.

In [None]:
from delfi import distribution

# Uncorrelated Gaussian prior in optimization-parameter space, centred on the
# defaults, with standard deviation 2 for every parameter.
mean = [optim.params.sim_param2opt_param(p_default[name], name) for name in params.p_names]
std = [2 for _ in params.p_names]

prior = distribution.Gaussian(m=np.array(mean), S=np.diag(np.square(std)))

In [None]:
# Histogram of the prior for every parameter, in physical units.
# The prior is sampled once and shared by all panels (previously 1000 draws
# were made separately per parameter); the grid adapts to the parameter count.
n_prior_draws = 1000
prior_draws = np.array([
    list(params.opt_params2sim_params(prior.gen(1).flatten()).values())
    for _ in range(n_prior_draws)
])

n_params = len(params.p_names)
plt.figure(figsize=(12, 3))
for idx, param in enumerate(params.p_names):
    ax = plt.subplot(1, n_params, idx + 1)
    plt.title(param)
    # Column idx of the draws is assumed to belong to params.p_names[idx]
    # (same assumption as before).
    plt.hist(prior_draws[:, idx] * optim.get_unit(param))
    ax.set_yscale('log')
    ax.set_ylim(0.5, None)
plt.show()

### Inference

Run the inference. If *load_only* this step is skipped.

In [None]:
import os
# Set the MKL threading layer before gpu_test is imported; statement order
# matters here.
# NOTE(review): presumably needed by the numerical backend used with delfi
# (MKL threading conflict) - confirm.
os.environ["MKL_THREADING_LAYER"] = "GNU"

import gpu_test
# Stop the notebook with an AssertionError if gpu_test.run reports failure.
# Note: this check also runs in 'load_only' mode.
assert gpu_test.run(verbose=False)

In [None]:
# SNPE run settings, shared by Part 1 and Part 2.
if inference_mode == 'test':
    # Few samples: only illustrates the workflow; results are not reliable.
    n_samples_per_round = 20
    max_rounds = 2
else:
    n_samples_per_round = 50
    max_rounds = 2

# Both run_SNPE calls pass these two flags, which were never defined before,
# so the 'test' and 'full_inference' modes failed with a NameError.
# Defaults describe a fresh run: do not continue a previous optimization and
# do not load initial training data.
# NOTE(review): exact semantics live in delfi_funcs.DELFI_Optimizer.run_SNPE - confirm.
continue_optimization = False
load_init_tds = False

print(n_samples_per_round, '*', max_rounds, 'samples')

In [None]:
import delfi_funcs
delfi_funcs = importlib.reload(delfi_funcs)

# SNPE wrapper around the COMSOL optimizer (n_parallel=1, gen_minibatch=1).
delfi_optim = delfi_funcs.DELFI_Optimizer(
    optim=optim,
    prior=prior,
    n_parallel=1,
    gen_minibatch=1,
    post_as_truncated_normal=False,
)

# Only generate new samples outside 'load_only'; otherwise the stored results
# in delfi_optim.snpe_folder are read further below.
if inference_mode != 'load_only':
    print('Create new network')
    delfi_optim.init_SNPE(
        verbose=False,
        pseudo_obs_n=1,
        prior_mixin=0.0,
        kernel_bandwidth_perc=25,
        kernel_bandwidth_min=0.0,
        use_all_trn_data=False,
        n_components=1,
        loss_failed_sims=loss.max_loss['total'],
    )

    # Network training settings.
    delfi_optim.nn_epochs = 200
    delfi_optim.nn_minibatch = 8

    delfi_optim.run_SNPE(
        max_duration_minutes=60 * 12,
        max_rounds=max_rounds,
        n_samples_per_round=n_samples_per_round,
        continue_optimization=continue_optimization,
        load_init_tds=load_init_tds,
    )

### Load data

In [None]:
# Load data.
# Results of the Part-1 SNPE run stored in delfi_optim.snpe_folder.
inf_snpes = data_utils.load_var(delfi_optim.snpe_folder + '/inf_snpes.pkl')
sample_distributions = data_utils.load_var(delfi_optim.snpe_folder + '/sample_distributions.pkl')
logs = data_utils.load_var(delfi_optim.snpe_folder + '/logs.pkl')
tds = data_utils.load_var(delfi_optim.snpe_folder + '/tds.pkl')
pseudo_obs = data_utils.load_var(delfi_optim.snpe_folder + '/pseudo_obs.pkl')
kernel_bandwidths = data_utils.load_var(delfi_optim.snpe_folder + '/kernel_bandwidths.pkl')
n_samples = data_utils.load_var(delfi_optim.snpe_folder + '/n_samples.pkl')

# The first sampling distribution is the prior; all following ones are posteriors.
prior, posteriors = sample_distributions[0], sample_distributions[1:]

In [None]:
import plot_obs_and_bw
plot_obs_and_bw = importlib.reload(plot_obs_and_bw)

# Part-1 diagnostics: pseudo observations and kernel bandwidths, then the logs.
plot_obs_and_bw.plot(pseudo_obs, kernel_bandwidths)
plot_obs_and_bw.plot_logs(logs)

### Analyse results

In [None]:
# List the stored sample files; their number must match len(tds).
sample_files = sorted(os.listdir(delfi_optim.samples_folder))
print('All files:', sample_files, sep='\n')
assert len(sample_files) == len(tds)

In [None]:
# Load data.
# Load all samples together with the index order used below (d_sort_index[0]
# is treated as the lowest-loss sample).
samples, n_samples, d_sort_index = delfi_optim.load_samples(
    files=sample_files,
    concat_traces=True,
    list_traces=False,
    return_sort_idx=True,
    return_n_samples=True,
    verbose=True,
)

d_min_idx = d_sort_index[0]  # index of the best sample
n_best_samples = 5

print('\nd_min = ' + str(samples['loss']['total'][d_min_idx]))

## Get prior bounds for Part 2

In [None]:
# Part-2 prior for epsilon_retina: range spanned by the best 10 % of the
# Part-1 samples and the value of the single best sample.
# At least one sample is kept, so .min()/.max() never act on an empty array
# (int(0.1 * n) is 0 for fewer than 10 samples).
eps_sorted = samples['params']['epsilon_retina'][d_sort_index]
n_top = max(1, int(d_sort_index.size*0.1))
eps_lb = eps_sorted[:n_top].min()
eps_ub = eps_sorted[:n_top].max()
eps_mu = eps_sorted[0]

In [None]:
# Report the epsilon_retina range found in Part 1.
print('epsilon_retina')
print(f"lb: {eps_lb:.2f}")
print(f"ub: {eps_ub:.2f}")
print(f"bst: {eps_mu:.2f}")

In [None]:
# Part-2 prior for sigma_retina: range spanned by the best 10 % of the Part-1
# samples and the value of the single best sample.
# At least one sample is kept, so .min()/.max() never act on an empty array
# (int(0.1 * n) is 0 for fewer than 10 samples).
sig_sorted = samples['params']['sigma_retina'][d_sort_index]
n_top = max(1, int(d_sort_index.size*0.1))
sig_lb = sig_sorted[:n_top].min()
sig_ub = sig_sorted[:n_top].max()
sig_mu = sig_sorted[0]

In [None]:
# Report the sigma_retina range found in Part 1 (the header previously said
# 'epsilon_retina' although sigma values are printed).
print('sigma_retina')
print("lb: {:.2f}".format(sig_lb))
print("ub: {:.2f}".format(sig_ub))
print("bst: {:.2f}".format(sig_mu))

In [None]:
i = 0  # rank of the sample to show (0 = first entry of d_sort_index)

for iV0 in [0, 1]:
    plt.figure(figsize=(12, 4))

    # Parameter values of the shown sample, converted with p_unit.
    plt.text(0.04, 0.5, 'eps_r = {:.4g}'.format(
        samples['params']['epsilon_retina'][d_sort_index[i]] * p_unit['epsilon_retina']))
    plt.text(0.04, -.5, 'sig_r = {:.4g}'.format(
        samples['params']['sigma_retina'][d_sort_index[i]] * p_unit['sigma_retina']))

    # Simulated currents of that sample (default colour cycle) ...
    for freq in [25, 40]:
        sim = samples['data'][d_sort_index[i]][0][freq][V_amps_opt[freq][iV0]]
        plt.plot(sim['Time'], sim['Current'] * 1e6, label='Fit ' + str(freq))

    # ... and the corresponding targets (dashed).
    for freq, fmt in [(25, 'k--'), (40, 'r--')]:
        tgt = target[freq][V_amps_opt[freq][iV0]]
        plt.plot(tgt['Time'], tgt['Current'] * 1e6, fmt, label='Target ' + str(freq))

    plt.axvline(loss.t_drop)

    plt.legend()
    plt.ylabel('Current [uA]')
    plt.xlabel('Time [s]')
    plt.show()

In [None]:
# Total loss of every Part-1 sample versus each parameter (physical units via
# p_unit). Both panels now use the same orientation (parameter on x, loss on
# y) and log-log axes, and they are labelled.
fig, axs = plt.subplots(1, 2, figsize=(12, 3))
for ax, param in zip(axs, ['sigma_retina', 'epsilon_retina']):
    ax.loglog(
        samples['params'][param][d_sort_index] * p_unit[param],
        samples['loss']['total'][d_sort_index],
        'k.',
    )
    ax.set_xlabel(param)
    ax.set_ylabel('Total loss')

plt.show()

## Optimization - Part 2 - Linear

### Params

In [None]:
import param_funcs
param_funcs = importlib.reload(param_funcs)

# Part 2 uses linear parameters: defaults are the best Part-1 sample and the
# ranges are the bounds derived from the best Part-1 samples.
p_default = dict(epsilon_retina=eps_mu, sigma_retina=sig_mu)
p_unit = dict(epsilon_retina=1e6, sigma_retina=0.1)
p_range = dict(
    epsilon_retina=(eps_lb, eps_ub),
    sigma_retina=(sig_lb, sig_ub),
)

params = param_funcs.Parameters(p_default=p_default, p_range=p_range)

In [None]:
# Plot the Part-2 parameter definition (Parameters.plot from param_funcs).
params.plot()

### Optimizer

In [None]:
import optimize_COMSOL_params
optimize_COMSOL_params = importlib.reload(optimize_COMSOL_params)

# Optimizer for Part 2, using output_folder_step2. Unlike Part 1, reset is
# always False here.
optim = optimize_COMSOL_params.OptimizerCOMSOLparams(
    params=params,
    output_folder=output_folder_step2,
    p_unit=p_unit,
    reset=False,
    V_amps=V_amps_opt,
    EDL_phase_total=EDL_phase_total_dict,
    absZ_est=absZ_est,
)

### Loss

In [None]:
import loss_funcs_COMSOL_params
loss_funcs_COMSOL_params = importlib.reload(loss_funcs_COMSOL_params)

# Same loss definition as in Part 1, attached to the Part-2 optimizer.
loss = optim.loss = loss_funcs_COMSOL_params.Loss(target=target, t_drop=0.04)

### Prior

In [None]:
import TruncatedNormal

# Truncated normal prior on [0, 1] in optimization-parameter space, centred on
# the defaults, standard deviation 0.3, no correlations.
# NOTE(review): [0, 1] presumably corresponds to p_range - confirm in param_funcs.
lower = np.zeros(len(params.p_names), dtype=int)
upper = np.ones(len(params.p_names), dtype=int)

mean = [optim.params.sim_param2opt_param(p_default[name], name) for name in params.p_names]
std = [0.3 for _ in params.p_names]

prior = TruncatedNormal.TruncatedNormal(
    m=np.array(mean), S=np.diag(np.square(std)), lower=lower, upper=upper,
)

In [None]:
import plot_sampling_dists
plot_sampling_dists = importlib.reload(plot_sampling_dists)

# Show the Part-2 prior alone (no posteriors yet) on [-0.5, 1.5] per parameter.
PP = plot_sampling_dists.SamplingDistPlotter(
    params=params,
    prior=prior,
    posterior_list=[],
    lbs=np.full(params.p_N, -0.5),
    ubs=np.full(params.p_N, 1.5),
)
PP.plot_sampling_dists_1D(plot_peak_lines=False, figsize=(12, 8), opt_x=False)

### Inference

In [None]:
import delfi_funcs
delfi_funcs = importlib.reload(delfi_funcs)

# SNPE wrapper for Part 2 (n_parallel=1, gen_minibatch=1).
delfi_optim = delfi_funcs.DELFI_Optimizer(
    optim=optim,
    prior=prior,
    n_parallel=1,
    gen_minibatch=1,
    post_as_truncated_normal=False,
)

# New samples are only generated outside 'load_only'.
if inference_mode != 'load_only':
    delfi_optim.init_SNPE(
        verbose=False,
        pseudo_obs_n=1,
        prior_mixin=0.0,
        kernel_bandwidth_perc=25,
        kernel_bandwidth_min=0.0,
        use_all_trn_data=False,
        n_components=1,
        loss_failed_sims=loss.max_loss['total'],
    )

    # Network training settings.
    delfi_optim.nn_epochs = 200
    delfi_optim.nn_minibatch = 8

    delfi_optim.run_SNPE(
        max_duration_minutes=60 * 12,
        max_rounds=max_rounds,
        n_samples_per_round=n_samples_per_round,
        continue_optimization=continue_optimization,
        load_init_tds=load_init_tds,
    )

### Load data

In [None]:
# Results of the Part-2 SNPE run stored in delfi_optim.snpe_folder.
inf_snpes = data_utils.load_var(delfi_optim.snpe_folder + '/inf_snpes.pkl')
sample_distributions = data_utils.load_var(delfi_optim.snpe_folder + '/sample_distributions.pkl')
logs = data_utils.load_var(delfi_optim.snpe_folder + '/logs.pkl')
tds = data_utils.load_var(delfi_optim.snpe_folder + '/tds.pkl')
pseudo_obs = data_utils.load_var(delfi_optim.snpe_folder + '/pseudo_obs.pkl')
kernel_bandwidths = data_utils.load_var(delfi_optim.snpe_folder + '/kernel_bandwidths.pkl')
n_samples = data_utils.load_var(delfi_optim.snpe_folder + '/n_samples.pkl')

# First entry: prior; remaining entries: posteriors.
prior = sample_distributions[0]
posteriors = sample_distributions[1:]

In [None]:
import plot_obs_and_bw
plot_obs_and_bw = importlib.reload(plot_obs_and_bw)

# Part-2 diagnostics: pseudo observations and kernel bandwidths, then the logs.
plot_obs_and_bw.plot(pseudo_obs, kernel_bandwidths)
plot_obs_and_bw.plot_logs(logs)

### Analyse results

In [None]:
# List the stored Part-2 sample files; their number must match len(tds).
sample_files = sorted(os.listdir(delfi_optim.samples_folder))
print('All files:', sample_files, sep='\n')
assert len(sample_files) == len(tds)

In [None]:
# Load data.
# Load all Part-2 samples together with the index order used below
# (d_sort_index[0] is treated as the lowest-loss sample).
samples, n_samples, d_sort_index = delfi_optim.load_samples(
    files=sample_files,
    concat_traces=True,
    list_traces=False,
    return_sort_idx=True,
    return_n_samples=True,
    verbose=True,
)

d_min_idx = d_sort_index[0]  # index of the best sample
n_best_samples = 5

print('\nd_min = ' + str(samples['loss']['total'][d_min_idx]))

In [None]:
# sigma_retina of the best Part-2 sample.
best_sig_r = samples['params']['sigma_retina'][d_min_idx]
print(f'sigma_retina = {best_sig_r:.4g}')

In [None]:
# epsilon_retina of the best Part-2 sample.
best_eps_r = samples['params']['epsilon_retina'][d_min_idx]
print(f'epsilon_retina = {best_eps_r:.4g}')

# Postprocessing

## Test

In [None]:
# Re-simulate the best Part-2 parameters with COMSOL (skipped in 'load_only').
if inference_mode != 'load_only':
    optim.create_inputs(sim_params=dict(sigma_retina=best_sig_r, epsilon_retina=best_eps_r))
    optim.run_COMSOL()

    rec_data = optim.read_outputs()

In [None]:
# Compare the re-simulated currents with the targets (only when COMSOL ran above).
if inference_mode != 'load_only':
    for iV0 in [0, 1]:
        plt.figure(figsize=(12, 4))

        # Simulated currents (default colour cycle) ...
        for freq in [25, 40]:
            sim = rec_data[freq][V_amps_opt[freq][iV0]]
            plt.plot(sim['Time'], sim['Current'] * 1e6, label='Fit ' + str(freq))

        # ... and the corresponding targets (dashed).
        for freq, fmt in [(25, 'k--'), (40, 'r--')]:
            tgt = target[freq][V_amps_opt[freq][iV0]]
            plt.plot(tgt['Time'], tgt['Current'] * 1e6, fmt, label='Target ' + str(freq))

        plt.axvline(loss.t_drop)

        plt.legend()
        plt.ylabel('Current [uA]')
        plt.xlabel('Time [s]')
        plt.show()

## Validate

### <font color='red'> Run COMSOL "step4_flat_w_validate.mph" </font>

In [None]:
# Validation needs the output of the COMSOL model "step4_flat_w_validate.mph",
# which is run manually. Fail early with the explanatory message if the output
# folder is missing or empty; os.listdir alone would raise FileNotFoundError
# for a missing folder and hide the message.
if inference_mode != 'load_only':
    comsol_output_dir = 'COMSOL_output/'
    assert os.path.isdir(comsol_output_dir) and len(os.listdir(comsol_output_dir)) > 0, \
        'Somethings wrong with the COMSOL output, did you run the COMSOL model?'

In [None]:
# Validate on all measured voltage amplitudes, not only those used for fitting.
optim.V_amps = {
    25: [100, 200, 300, 400, 500, 600],
    40: [50, 100, 150, 200, 250, 300],
}

# Read the fresh COMSOL output, or the stored validation currents in 'load_only'.
rec_data = (
    optim.read_outputs(verbose=True)
    if inference_mode != 'load_only'
    else data_utils.load_var('ValidationData/I_retina_validation.pkl')
)

In [None]:
# Targets for every validation amplitude in optim.V_amps, reconstructed from
# the sine fits on the same time grid as before. The loop variable must stay
# named f0 because sin_A_phi_time reads the global f0.
time0 = np.linspace(0, 0.12, 1000)

target = {}
for f0 in [25, 40]:
    target[f0] = {}
    for V0 in optim.V_amps[f0]:
        amp, phase = fit_sin_params["w"][f0][V0][0], fit_sin_params["w"][f0][V0][1]
        target[f0][V0] = pd.DataFrame({
            'Time': time0,
            'Current': sin_A_phi_time(amp, phase, time0),
        })

In [None]:
# Validation plot: simulated currents (coloured) vs. targets (black dashed) for
# every voltage amplitude, one subplot per frequency.
plt.figure(figsize=(12, 6))
for fi, f in enumerate([25, 40]):
    plt.subplot(2, 1, fi + 1)
    for V in optim.V_amps[f]:
        plt.plot(
            rec_data[f][V]['Time'], rec_data[f][V]['Current']*1e6,
            label='Fit @ ' + str(f) + ' Hz, ' + str(V) + ' mV'
        )
        plt.plot(
            target[f][V]['Time'], target[f][V]['Current']*1e6,
            'k--', label=None
        )

    # Decorate each subplot once, after all amplitudes are drawn (previously
    # the t_drop line, legend and labels were redone for every amplitude).
    plt.axvline(loss.t_drop)
    plt.legend()
    plt.ylabel('Current [uA]')
    plt.xlabel('Time [s]')
plt.show()

### Save

In [None]:
# Store the validation currents and their targets in ValidationData/.
data_utils.make_dir('ValidationData')
for obj, path in [
    (rec_data, 'ValidationData/I_retina_validation.pkl'),
    (target, 'ValidationData/target.pkl'),
]:
    data_utils.save_var(obj, path)