# Imports

In [None]:
import importlib
import pandas as pd
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt

In [None]:
import os
import sys

In [None]:
# Put the shared 'pythoncode' directory (two levels up) at the front of the
# module search path so that helper modules located there can be imported.
pythoncodepath = os.path.abspath(os.path.join('..', '..', 'pythoncode'))
sys.path = [pythoncodepath] + sys.path
# NOTE(review): importhelper presumably lives in pythoncodepath, so this import
# must stay after the sys.path change above.
import importhelper
# presumably also registers sub-folders of pythoncodepath — confirm in importhelper
importhelper.addfolders2path(pythoncodepath)

In [None]:
# Project helper modules used throughout this notebook.
import data_utils
import plot_utils
# Apply the project's matplotlib settings before any figure is created.
plot_utils.set_rcParams()

In [None]:
# The figure number is encoded in the name of the current working directory:
# characters 3 and 4 of the directory name are used as the number.
# os.path.basename is used instead of splitting on '/' so the directory name is
# also found when the platform uses a different path separator.
fig_num = os.path.basename(os.getcwd())[3:5]
print(fig_num)

# Get data

In [None]:
# Stimulus amplitudes (mV) keyed by stimulation frequency (Hz).
# Vs_opt: amplitudes shown in the trace panels (presumably the ones used for
# the optimization); Vs_val: the full amplitude series per frequency.
Vs_opt = {
    25: [300, 600],
    40: [150, 300],
}
Vs_val = {
    25: list(range(100, 601, 100)),
    40: list(range(50, 301, 50)),
}

## Experimental data

In [None]:
# Root folder of the parameter-optimization step and, inside it, the folder
# holding the preprocessed experimental data.
folder = os.path.join('..', '..', 'step3_optimize_COMSOL_params')
pp_folder = os.path.join(folder, 'data_preprocessed')

In [None]:
# Show which files are available in the preprocessed-data folder.
os.listdir(pp_folder)

In [None]:
def _load_preprocessed(filename):
    """Load one stored variable from the preprocessed-data folder."""
    return data_utils.load_var(os.path.join(pp_folder, filename))


# Current traces, their time axes and the voltage amplitudes.
ex_currents = _load_preprocessed('raw_currents.pkl')
cur_time = _load_preprocessed('cur_time.pkl')
V_amps = _load_preprocessed('V_amps.pkl')

# Results derived during preprocessing (sinusoidal fits, phase, impedance).
V_ames_sinus_fits_params = _load_preprocessed('V_ames_sinus_fits_params.pkl')
EDL_phase_total = _load_preprocessed('EDL_phase_total.pkl')
currents_fit_sin_params = _load_preprocessed('raw_currents_sinus_fits_params.pkl')
absZ_est = _load_preprocessed('absZ_est.pkl')

# R and C values per frequency and amplitude.
RC_params = _load_preprocessed('RC_params.pkl')

In [None]:
target = data_utils.load_var(
    os.path.join(folder, 'data_validation/target.pkl'))
I_retina_raw = data_utils.load_var(
    os.path.join(folder, 'data_validation/I_retina_validation.pkl'))


def _crop_to_third_period(trace, f):
    """Shift the time axis by two periods and keep the samples in [0, 1/f)."""
    shifted = trace['Time'] - 2/f
    current = trace['Current']
    keep = np.logical_and(shifted >= 0, shifted < 1/f)
    return pd.DataFrame({
        'Time': shifted[keep],
        'Current': current[keep],
    })


# One cropped trace per frequency and validation amplitude.
I_retina = {
    f: {V: _crop_to_third_period(I_retina_raw[f][V], f) for V in Vs}
    for f, Vs in Vs_val.items()
}

## Get optimization data

In [None]:
# Show the available optimization runs.
os.listdir(os.path.join(folder, 'optim_data'))

In [None]:
# Sub-folder (relative to `folder`) of each optimization step, keyed by step.
folders_inf = {
    step: f"optim_data/optimize_CR_step{step}_submission2/"
    for step in (1, 2)
}

In [None]:
# Alphabetically sorted sample-file names of every optimization step.
sample_files = {
    i: sorted(os.listdir(os.path.join(folder, folder_delfi + "samples/")))
    for i, folder_delfi in folders_inf.items()
}
sample_files

In [None]:
def _stack(samples, group, key):
    """Concatenate samples[k][group][key] over all sample files of one step."""
    return np.concatenate([sample[group][key] for sample in samples])


# Load every sample file of every optimization step.
samples_list = {
    i: [
        data_utils.load_var(os.path.join(folder, folder_delfi + 'samples/' + sample_file))
        for sample_file in sample_files[i]
    ]
    for i, folder_delfi in folders_inf.items()
}

eps, sig, losses, losses_sort = {}, {}, {}, {}

for i, samples_i in samples_list.items():
    eps[i] = _stack(samples_i, 'params', 'epsilon_retina')
    sig[i] = _stack(samples_i, 'params', 'sigma_retina')
    losses[i] = _stack(samples_i, 'loss', 'total')
    # Sample indices ordered from smallest to largest loss.
    losses_sort[i] = np.argsort(losses[i])

# Only the concatenated arrays are needed from here on.
del samples_list

In [None]:
# Factors the sampled parameter values are multiplied with before they are
# plotted or exported.
p_unit = dict(
    epsilon_retina=1e6,
    sigma_retina=0.1,
)

# Export data

In [None]:
# Folder that receives all CSV exports written below
# (make_dir presumably creates it only if it is missing — confirm in data_utils).
data_utils.make_dir('source_data')

In [None]:
# One CSV per frequency: time axis first, then the 'wo' traces (exported as
# i_ames) and finally the 'w' traces (exported as i_retina), all in uA.
for f in (25, 40):
    currents_exdf = pd.DataFrame()
    currents_exdf['Time (ms)'] = 1e3*cur_time[f]

    for name, condition in (('i_ames', 'wo'), ('i_retina', 'w')):
        for V in Vs_val[f]:
            column = f'{name}/uA for V0(v_stim)={V}mV'
            currents_exdf[column] = ex_currents[condition][f][f'{V} mV']*1e6

    out_path = f'source_data/recorded_currents_{f}Hz.csv'
    currents_exdf.to_csv(out_path, index=False, float_format="%.6f")

In [None]:
# One CSV per frequency with the two sinusoidal-fit parameters of v_ames
# (element 0 exported as amplitude, element 1 as phase), one row per amplitude.
for f in (25, 40):
    amplitudes = Vs_val[f]
    fits = V_ames_sinus_fits_params[f]

    V_ames_exdf = pd.DataFrame(index=[f'V0(v_stim)={V}mV' for V in amplitudes])
    V_ames_exdf['V0(v_ames)/mV'] = [fits[V][0] for V in amplitudes]
    V_ames_exdf['phi(v_ames)/degree'] = [fits[V][1] for V in amplitudes]

    out_path = f'source_data/v_ames_{f}Hz.csv'
    V_ames_exdf.to_csv(out_path, index=True, float_format="%.4f")

In [None]:
# One CSV per frequency with the sinusoidal-fit parameters of the recorded
# currents: amplitude (scaled to uA) and phase for the 'wo' and 'w' conditions.
for f in (25, 40):
    amplitudes = Vs_val[f]
    currents_exdf = pd.DataFrame(index=[f'V0(v_stim)={V}mV' for V in amplitudes])

    for name, condition in (('i_ames', 'wo'), ('i_retina', 'w')):
        fits = currents_fit_sin_params[condition][f]
        currents_exdf[f'I0({name})/uA'] = [fits[V][0]*1e6 for V in amplitudes]
        currents_exdf[f'phi({name})/degree'] = [fits[V][1] for V in amplitudes]

    filename = f'source_data/fits_to_recorded_currents_{f}Hz.csv'
    currents_exdf.to_csv(filename, index=True, float_format="%.4f")

In [None]:
# One CSV per frequency with R (divided by 1e3, column in kOhm) and
# C (multiplied by 1e9, column in nF), one row per amplitude.
for f in (25, 40):
    amplitudes = Vs_val[f]
    per_amplitude = RC_params[f]

    RC_params_exdf = pd.DataFrame(index=[f'V0(v_stim)={V}mV' for V in amplitudes])
    RC_params_exdf['R_e/kOhm'] = [per_amplitude[V]['R']/1e3 for V in amplitudes]
    RC_params_exdf['C_e/nF'] = [per_amplitude[V]['C']*1e9 for V in amplitudes]

    out_path = f'source_data/RC_params_{f}Hz.csv'
    RC_params_exdf.to_csv(out_path, index=True, float_format="%.3f")

In [None]:
# Export all sampled parameters with their discrepancy, step 1 followed by step 2.
# NOTE(review): the row labels assume four consecutive groups of exactly 50
# samples each (200 rows in total) — confirm against the sample files.
round_labels = ("1_Log_Round1", "2_Log_Round2", "3_Lin_Round1", "4_Lin_Round2")
samples_exdf = pd.DataFrame(
    {
        'sigma_retina/(S/m)': np.concatenate([sig[1], sig[2]]) * p_unit['sigma_retina'],
        'epsilon_retina': np.concatenate([eps[1], eps[2]]) * p_unit['epsilon_retina'],
        'discrepancy': np.concatenate([losses[1], losses[2]]),
    },
    index=[label for label in round_labels for _ in range(50)],
)
samples_exdf.to_csv('source_data/Samples_ElectricalParams.csv', float_format="%.6f")

In [None]:
# One CSV per frequency with a (time, current) column pair for every amplitude.
# Each column is wrapped in its own Series so the DataFrame constructor can
# combine traces of different lengths.
for f in (25, 40):
    columns = {}

    for V in Vs_val[f]:
        trace = I_retina[f][V]
        suffix = f' for V0(v_stim)={V}mV'
        columns['Time (ms)' + suffix] = pd.Series(1e3*trace['Time'].values)
        columns['i_retina/uA' + suffix] = pd.Series(1e6*trace['Current'].values)

    currents_sim_exdf = pd.DataFrame(columns)

    out_path = f'source_data/simulated_i_retina_{f}Hz.csv'
    currents_sim_exdf.to_csv(out_path, index=False, float_format="%.6f")

# Plotting function

In [None]:
from PIL import Image

# Load the three circuit schematics.
c1, c2, c3 = (Image.open(f'circuits/circuit{number}.png') for number in (1, 2, 3))

# Quick look at the unannotated images, side by side.
fig, axs = plt.subplots(1, 3, figsize=(12, 4))
for ax, c in zip(axs, (c1, c2, c3)):
    ax.imshow(c)
    ax.axis('off')

## Plot circuits.

In [None]:
def plot_circuit(ax, c, im_d=10):
    """Draw circuit image ``c`` on ``ax`` and annotate it with its symbols.

    ax   : axes to draw on; its frame and ticks are switched off.
    c    : image to show; the annotations are selected by comparing it with
           the module-level images c1, c2 and c3 (no labels if none matches).
    im_d : margin, in pixels, added around the image on every side.
    """
    width, height = c.size[0], c.size[1]

    ax.imshow(c, origin='lower')
    ax.axis('off')
    ax.set_xlim((-im_d, width+im_d))
    # Upper limit first: the y axis runs from top to bottom.
    ax.set_ylim((height+im_d, -im_d))

    # Every entry: (x fraction of width, y fraction of height, text, va, ha).
    if c == c1:
        labels = [
            (0.21, 0.32, r'$v_{ames}$', 'center', 'right'),
            (0.33, 0.77, r'$i_{ames}$', 'center', 'right'),
            (0.47-0.08, 0.35, r'$\epsilon_{ames}$', 'center', 'left'),
            (0.83-0.08, 0.35, r'$\sigma_{ames}$', 'center', 'right'),
        ]
    elif c == c2:
        labels = [
            (0.35, 0.35, r'$R_e$', 'bottom', 'right'),
            (0.19, 0.35, r'$C_e$', 'bottom', 'right'),
            (0.23, 0.77, r'$i_{ames}$', 'bottom', 'right'),
            (0.5, 0.58, r'$v_{stim}$ - $v_{ames}$', 'top', 'left'),
        ]
    elif c == c3:
        labels = [
            (0.31, 0.43, r'$R_e$', 'bottom', 'right'),
            (0.17, 0.43, r'$C_e$', 'bottom', 'right'),
            (0.2, 0.83, r'$v_{stim}$', 'bottom', 'right'),
            (0.82, 0.83, r'$i_{retina}$', 'bottom', 'right'),
            (0.47, 0.65, r'$\epsilon_{retina}$', 'center', 'left'),
            (0.83, 0.65, r'$\sigma_{retina}$', 'center', 'right'),
            (0.47, 0.35, r'$\epsilon_{ames}$', 'center', 'left'),
            (0.83, 0.35, r'$\sigma_{ames}$', 'center', 'right'),
        ]
    else:
        labels = []

    for x_frac, y_frac, text, va, ha in labels:
        ax.text(width*x_frac, height*y_frac, text, va=va, ha=ha)

In [None]:
# Stand-alone preview of the three annotated circuits.
fig, axs = plt.subplots(1,3,figsize=(8,2))
for c, ax in zip([c1, c2, c3], axs):
    plot_circuit(ax, c)

## Plot traces

In [None]:
def sin(A, phi, time, f):
    """Evaluate a sinusoid with amplitude ``A``, frequency ``f`` and phase ``phi``.

    ``phi`` is given in degrees and converted to radians here; ``time`` may be
    a scalar or a NumPy array.
    """
    phase_rad = phi/180*np.pi
    angle = 2*np.pi*time*f + phase_rad
    return A*np.sin(angle)

In [None]:
# Line style shared by all trace plots.
line_kw = {'lw': 0.9, 'alpha': 1.0}

In [None]:
def make_trace_axis(axs):
    """Label the time axes of a pair of trace panels.

    Every axes in ``axs`` gets the x label 'Time (ms)'; the first axes gets
    x ticks at 0 and 40, the second at 0 and 25.
    """
    for ax in axs:
        ax.set_xlabel('Time (ms)', labelpad=-5)

    tick_positions = ([0, 40], [0, 25])
    for position, ticks in enumerate(tick_positions):
        axs[position].set_xticks(ticks)

In [None]:
from matplotlib import cm
# Reversed viridis palette resampled to 8 discrete colours.
# plt.get_cmap is used because cm.get_cmap was deprecated in matplotlib 3.7 and
# removed in 3.9; it accepts the same (name, lut) arguments.
mapper = plt.get_cmap('viridis_r', 8)

def V2color(V, f):
    """Return the palette colour assigned to amplitude ``V`` at frequency ``f``.

    The colour index is the position of ``V`` within ``Vs_val[f]``, so the
    n-th amplitude of every frequency shares the n-th palette colour.
    ``V`` must be one of the amplitudes listed in ``Vs_val[f]``.
    """
    assert V in Vs_val[f]
    return mapper(np.argmax(np.array(Vs_val[f]) == V))

### Plot stim voltage

In [None]:
def plot_voltage(axs):
    """Plot the sinusoidal stimulus voltage, one panel per frequency.

    ``axs`` must be an array of exactly two axes; they are paired with the
    entries of ``Vs_opt`` in order. Each amplitude is drawn as a zero-phase
    sinusoid over ``cur_time[f]`` with its amplitude-specific colour.
    """
    assert axs.size == 2

    for ax in axs:
        ax.set_title(r"$v_{stim}$")

    for (f, Vs), ax in zip(Vs_opt.items(), axs):
        t = cur_time[f]
        for V in Vs:
            ax.plot(1e3*t, sin(V, 0.0, t, f), color=V2color(V, f), **line_kw)

    axs[0].set_ylabel('V (mV)')
    make_trace_axis(axs)

In [None]:
# Stand-alone preview of the stimulus-voltage panels.
fig, axs = plt.subplots(1,2,figsize=(8,2))
plot_voltage(axs)

### Plot measured currents.

In [None]:
def plot_measured_I(axs):
    """Plot the 'wo' current traces and their sinusoidal fits per frequency.

    For every amplitude in ``Vs_opt`` the trace (scaled by 1e6) is drawn in
    its amplitude colour and the fitted sinusoid is overlaid as a dashed
    black line.
    """
    for ax in axs:
        ax.set_title(r"$i_{ames}$")

    recorded = ex_currents['wo']
    fits = currents_fit_sin_params['wo']

    for (f, Vs), ax in zip(Vs_opt.items(), axs):
        t = cur_time[f]
        t_ms = 1e3*t
        for V in Vs:
            ax.plot(t_ms, recorded[f][str(V) + " mV"]*1e6, color=V2color(V, f), **line_kw)

            amplitude, phase = fits[f][V][0], fits[f][V][1]
            ax.plot(t_ms, sin(amplitude, phase, t, f)*1e6, 'k--', alpha=0.8, lw=line_kw['lw'])

    axs[0].set_ylabel(r'I ($\mu$A)')
    make_trace_axis(axs)

In [None]:
# Stand-alone preview of the measured-current panels.
fig, axs = plt.subplots(1,2,figsize=(8,2))
plot_measured_I(axs)

### Plot V_ames

In [None]:
def plot_V_ames(axs):
    """Plot the fitted v_ames sinusoids, one panel per frequency.

    Only the fit parameters stored in ``V_ames_sinus_fits_params`` are drawn,
    for the amplitudes listed in ``Vs_opt``.
    """
    for ax in axs:
        ax.set_title(r"$v_{ames}$")

    for (f, Vs), ax in zip(Vs_opt.items(), axs):
        t = cur_time[f]
        for V in Vs:
            fit = V_ames_sinus_fits_params[f][V]
            ax.plot(1e3*t, sin(fit[0], fit[1], t, f), color=V2color(V, f), **line_kw)

    axs[0].set_ylabel(r'V (mV)')
    make_trace_axis(axs)

In [None]:
# Stand-alone preview of the v_ames panels.
fig, axs = plt.subplots(1,2,figsize=(8,2))
plot_V_ames(axs)

## Plot others

### Plot RC params

In [None]:
# Per-parameter plot settings: factor applied to the stored value, panel title
# and y-axis label.
RC_plot_info = dict(
    R=dict(unit=1e-3, title=r'$R_e$', label=r'R (k$\Omega$)'),
    C=dict(unit=1e9, title=r'$C_e$', label='C (nF)'),
)

In [None]:
import matplotlib.patheffects as path_effects

# Line colour per stimulation frequency.
colordata_f = {25: 'k', 40: 'dimgray'}

def plot_RC_params(ax, param):
    """Plot one RC parameter ('R' or 'C') against the stimulus amplitude.

    One line per frequency is drawn; the last and the third amplitude of each
    frequency are additionally marked with a diamond in their amplitude colour.
    """
    info = RC_plot_info[param]
    data = {
        f: [RC_params[f][V][param]*info['unit'] for V in Vs_val[f]]
        for f in (25, 40)
    }

    highlight_kw = dict(marker='d', clip_on=False, markersize=6, markeredgewidth=0)
    for f, xdata in Vs_val.items():
        ydata = data[f]
        # Diamonds first, so the line and its dots are drawn afterwards.
        for position in (-1, 2):
            ax.plot(xdata[position], ydata[position],
                    color=V2color(xdata[position], f), **highlight_kw)
        ax.plot(xdata, ydata, '.-', c=colordata_f[f], clip_on=False,
                markersize=4, lw=1.0, label=str(f)+ ' Hz')

    ax.set_xlabel(r'V (mV)', labelpad=-5)
    ax.set_title(info['title'])
    ax.set_ylabel(info['label'])
    ax.set_xticks([50, 600])

In [None]:
def plot_RC(axs):
    """Plot R on the first and C on the second axes; put the legend on the second."""
    for position, param in enumerate(('R', 'C')):
        plot_RC_params(axs[position], param=param)

    legend_kw = dict(
        loc='lower right',
        frameon=False,
        borderpad=0.0,
        labelspacing=0.01,
        handlelength=0.8,
        handletextpad=0.1,
        bbox_to_anchor=(1.1,-0.2),
    )
    axs[1].legend(**legend_kw)

In [None]:
# Stand-alone preview of the R and C panels.
fig, axs = plt.subplots(1,2,figsize=(3.3,1))
plot_RC(axs)

### Plot optimized parameters

In [None]:
# Marker styles for the sampled parameters: step 1, step 2 (drawn on top of
# step 1) and the best sample (drawn on top of everything).
_opt_marker_base = dict(marker='.', ls='None', markeredgewidth=0.0, clip_on=False)
opt_params_step1_kw = dict(_opt_marker_base, c='dimgray', markersize=4)
opt_params_step2_kw = dict(_opt_marker_base, c='k', markersize=4, zorder=50)
opt_params_best_kw = dict(_opt_marker_base, c='r', markersize=6, zorder=100)

In [None]:
def plot_retina_param(ax, param):
    """Scatter the loss of every sample against one retina parameter (log-log).

    ``param`` selects the data: names containing 'epsilon' use ``eps``, all
    others use ``sig``. Step 1 and step 2 samples are drawn separately and the
    step 2 sample with the smallest loss is highlighted.
    """
    ax.set(xscale='log', yscale='log')

    if 'epsilon' in param:
        data = eps
    else:
        data = sig
    scale = p_unit[param]
    best = losses_sort[2][0]

    ax.plot(data[1]*scale, losses[1], label='Log.', **opt_params_step1_kw)
    ax.plot(data[2]*scale, losses[2], label='Lin.', **opt_params_step2_kw)
    ax.plot(data[2][best]*scale, losses[2][best], label='Best', **opt_params_best_kw)

In [None]:
from matplotlib import ticker

def plot_epsilon_and_sigma(axs):
    """Plot loss versus epsilon_retina (first axes) and sigma_retina (second axes).

    Both panels use fixed tick positions with scientific-notation log labels
    on the x and on the y axis; the legend is attached to the second axes.
    """
    plot_retina_param(axs[0], param='epsilon_retina')
    axs[0].set_title(r'$\epsilon_{retina}$')
    axs[0].set_ylabel(r'Discrepancy')
    axs[0].set_xlabel(r'$\epsilon$', labelpad=-10)
    axs[0].set_xticks([1e5, 1e7])
    axs[0].get_xaxis().set_major_formatter(ticker.LogFormatterSciNotation())
    axs[0].set_yticks([1e-2, 1e0])
    # The formatter after set_yticks belongs to the y axis; it used to be
    # applied to the x axis a second time.
    axs[0].get_yaxis().set_major_formatter(ticker.LogFormatterSciNotation())

    plot_retina_param(axs[1], param='sigma_retina')
    axs[1].set_title(r'$\sigma_{retina}$')
    axs[1].set_xlabel(r'$\sigma$ (S/m)', labelpad=-10)
    axs[1].set_xticks([1e-2, 1e0])
    axs[1].get_xaxis().set_major_formatter(ticker.LogFormatterSciNotation())
    axs[1].set_yticks([1e-2, 1e0])
    axs[1].get_yaxis().set_major_formatter(ticker.LogFormatterSciNotation())

    axs[1].legend(loc='lower left', frameon=False, borderpad=0, borderaxespad=0,
                  labelspacing=0.01, handlelength=0.5, handletextpad=0.1, bbox_to_anchor=(0,-0.2))

In [None]:
# Stand-alone preview of the optimized-parameter panels.
fig, axs = plt.subplots(1,2,figsize=(8,2))
plot_epsilon_and_sigma(axs)

### Plot I retina

In [None]:
def plot_I_retina(axs):
    """Plot the i_retina traces for all validation amplitudes, one panel per frequency.

    For every amplitude the pre-cropped ``I_retina`` trace is drawn as a solid
    line and the ``target`` trace, cropped here to the same one-period window,
    as a dashed line of the same colour. Only the curves of the last amplitude
    carry legend labels, so the legend shows one entry per line style.
    """
    for ax in axs:
        ax.set_title(r"$i_{retina}$")

    for ax, (f, Vs) in zip(axs, Vs_val.items()):
        # Label the last amplitude instead of the hard-coded index 5, so the
        # legend does not silently go empty when the amplitude list changes.
        last_idx = len(Vs) - 1
        for Vidx, V in enumerate(Vs):

            color = V2color(V, f)
            # Labels starting with '_' are left out of the legend by matplotlib.
            in_legend = Vidx == last_idx

            # NOTE(review): I_retina is exported elsewhere in this notebook as
            # "simulated_i_retina" but is labelled 'rec' here, while target is
            # labelled 'fit' — confirm the two labels are not swapped.
            ax.plot(1e3*I_retina[f][V]['Time'], 1e6*I_retina[f][V]['Current'],
                    c=color, zorder=5, **line_kw, label='rec' if in_legend else '_')

            # Shift by two periods and keep exactly one period, [0, 1/f).
            time = target[f][V]['Time'] - 2/f
            current = target[f][V]['Current']
            idx = np.logical_and(time >= 0, time < 1/f)
            ax.plot(1e3*time[idx], 1e6*current[idx], '--', c=color, zorder=10,
                    **line_kw, label='fit' if in_legend else '_')

    axs[0].legend(frameon=False, borderpad=0.0, labelspacing=0.01, borderaxespad=0.3,
                  handletextpad=0.3, handlelength=1.2, loc='lower left', bbox_to_anchor=(0,-0.2))
    axs[0].set_ylabel(r'I ($\mu$A)')
    make_trace_axis(axs)

In [None]:
# Stand-alone preview of the i_retina panels.
fig, axs = plt.subplots(1,2,figsize=(8,2))
plot_I_retina(axs)

# Make figure

In [None]:
# Reduce the gap between axes and their titles for the compact final figure.
plt.rcParams['axes.titlepad'] = 0.01

In [None]:
# Assemble the final multi-panel figure: a top row with the three circuit
# schematics and a 3x4 grid of data panels below it, then save it as PDF.
from matplotlib import pyplot as plt

fig = plt.figure(figsize=(5.6, 3.7))

# Top row: three circuit axes occupying the upper gs1n_yfrac of the figure.
gs1n = 3
gs1n_yfrac = 0.33
gs1 = fig.add_gridspec(1, gs1n, width_ratios=[0.88,1,1], hspace=0)
gs1axs = [fig.add_subplot(gs1[idx]) for idx in range(gs1n)]
gs1.tight_layout(fig, rect=[None, 1-gs1n_yfrac, None, 1], w_pad=-5, h_pad=0, pad=0)

# Data grid: 3 rows x 5 columns, where the narrow middle column (index 2) is
# only a spacer. gs2axs is a flat, row-major list with four axes per row.
gs2nx = 5
gs2ny = 3
gs2 = fig.add_gridspec(gs2ny, gs2nx, width_ratios=[1,1,0.01,1,1])
gs2axs = [fig.add_subplot(gs2[idxrow, idxcol]) for idxrow in range(gs2ny) for idxcol in [0,1,3,4]]

plot_utils.move_xaxis_outward(gs2axs)

# Plot circuits.
for c, ax in zip([c1, c2, c3], gs1axs):
    plot_circuit(ax, c)

# Plot traces (left pair of axes in rows 0, 1 and 2).
plot_voltage(np.array(gs2axs[0:2]))
plot_measured_I(np.array(gs2axs[4:6]))
plot_V_ames(np.array(gs2axs[8:10]))

# Plot optimization data (right pair of axes in rows 0, 1 and 2).
plot_RC(np.array(gs2axs[2:4]))
plot_epsilon_and_sigma(np.array(gs2axs[6:8]))
plot_I_retina(np.array(gs2axs[10:12]))

# Fit the data grid into the area below the circuit row (with a 0.05 overlap).
gs2.tight_layout(fig, rect=[None, 0, None, 1-gs1n_yfrac+0.05], w_pad=0.0, h_pad=0.3)

sns.despine()

# Align the y labels within the first, second and third axes column.
fig.align_ylabels([gs2axs[i] for i in np.arange(0,12,4)])
fig.align_ylabels([gs2axs[i] for i in np.arange(1,12,4)])
fig.align_ylabels([gs2axs[i] for i in np.arange(2,12,4)])

# Panel letters: every pair of neighbouring axes forms one panel with
# sub-panels i and ii. Pairs are visited row by row (left pair, right pair),
# which yields B, E / C, F / D, G.
abc = 'BECFDG'
for ii, i in enumerate(np.arange(0,12,2)):
    gs2axs[i].set_title(abc[ii] + '          i', loc='left', ha='right', va='bottom', fontweight="bold", pad=0.1)
    gs2axs[i+1].set_title('ii', loc='left', ha='right', va='bottom', fontweight="bold", pad=0.1)

# Titles for the circuit row: three tiny invisible axes, evenly spaced over
# the width spanned by the data grid, carry the labels A i, ii and iii.
x0 = gs2axs[0].get_position().bounds[0]
xw = gs2axs[3].get_position().bounds[0] + gs2axs[3].get_position().bounds[2] - x0
gs1title_axs = [fig.add_axes(
    np.array([x+x0, 0.93, 0.02, 0.02])) for x in np.linspace(0,xw,4)[:3]]

for ax, title in zip(gs1title_axs, ['A          i', 'ii', 'iii']):
    ax.axis('off')
    ax.set_title(title, loc='left', ha='right', va='bottom', fontweight="bold")

# The file name contains the figure number derived from the working directory.
plt.savefig(f'../_figures/fig{fig_num}_comsol.pdf', dpi=300)
plt.show()

# Export Data for text

In [None]:
# Smallest loss reached in every optimization step.
for r, losses_i in losses.items():
    smallest = losses_i.min()
    print(r, smallest)

In [None]:
# Scaled parameters of the step 2 sample with the smallest loss.
best_idx = losses_sort[2][0]
best_sigma = sig[2][best_idx]*p_unit['sigma_retina']
best_epsilon = eps[2][best_idx]*p_unit['epsilon_retina']

In [None]:
# Display the best sigma_retina value.
best_sigma

In [None]:
# Display the best epsilon_retina value.
best_epsilon

In [None]:
from datetime import datetime

# First line: a LaTeX comment with the creation time; then one \newcommand per
# optimized parameter, each value formatted with two significant digits.
text_output = ['%' + str(datetime.now()) + '\n']
for macro, value in (('optimizedSIG', best_sigma), ('optimizedEPS', best_epsilon)):
    text_output.append("\\newcommand\\" + macro + "{" + f"{value:.2g}" + "}\n")

data_utils.make_dir('text_data')
with open('text_data/optimizedElParams.tex', 'w') as f:
    f.writelines(text_output)

In [None]:
# Display the lines that were written to the .tex file.
text_output