# Imports

In [None]:
import importlib
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import itertools

In [None]:
import os
import sys

In [None]:
# Make the project-local helper modules importable: ../../pythoncode is put in
# front of sys.path (sys.path is rebound to a new list), then importhelper is
# asked to add folders for that path as well (presumably its sub-folders —
# see importhelper.addfolders2path).
pythoncodepath = os.path.abspath(os.path.join('..', '..', 'pythoncode'))
sys.path = [pythoncodepath] + sys.path
import importhelper
importhelper.addfolders2path(pythoncodepath)

In [None]:
import data_utils
import plot_utils

In [None]:
# Apply the project-wide plotting defaults from plot_utils (presumably matplotlib rcParams).
plot_utils.set_rcParams()

In [None]:
# Figure number = characters 3-4 of the current folder's name (assumed naming
# convention such as "fig05_..." — verify against the folder layout).
# os.path.basename is used instead of splitting on '/' so this also works on
# Windows, where os.getcwd() returns backslash-separated paths.
fig_num = os.path.basename(os.getcwd())[3:5]
print(fig_num)

# Get data.

In [None]:
# Root folder of the threshold simulations; show the entries whose name contains 'bc'.
data_folder = os.path.join(os.pardir, os.pardir, 'step3b_thresholds')
list(filter(lambda entry: 'bc' in entry, os.listdir(data_folder)))

In [None]:
# Sub-folder holding the BC simulation results used below.
# Built with os.path.join like the other paths in this notebook.
bc_data_folder = os.path.join(data_folder, 'bc_data_submission2')
# Explicit check instead of `assert`, which is stripped under `python -O`.
if not os.path.isdir(bc_data_folder):
    raise FileNotFoundError(f'BC data folder not found: {bc_data_folder}')
os.listdir(bc_data_folder)

In [None]:
# Electrode-grid configurations "<A>x<A>" shown in the figure columns, and the
# indices j of the simulated stimulation currents (columns 'I<j> [A]' later on).
AxA_list = [f'{n}x{n}' for n in (1, 2, 4, 10)]
j_list = data_utils.load_var(f'{bc_data_folder}/run_j_list.pkl')

## Cell positions

In [None]:
# (dx, dy) offsets of the simulated cells and their distances from the origin.
dxdy_list = data_utils.load_var(f'{bc_data_folder}/dxdy_list.pkl')
N_cells = len(dxdy_list)

squared_dist = np.sum(dxdy_list**2, axis=1)
dist_list = list(np.sqrt(squared_dist))

# Quick look at the cell layout.
plt.figure(num=1, figsize=(3, 3))
plt.plot(dxdy_list[:, 0], dxdy_list[:, 1], '*')
plt.show()

## Current data

In [None]:
# Preprocessed stimulation data per electrode configuration: "I" holds the
# current traces (see plot_currents), "X" the per-current values used as the
# x-axis of plot_response.
current_data = {
    key: data_utils.load_var(
        os.path.join(data_folder, 'data_preprocessed', f'from_raw_{key}.pkl')
    )
    for key in ("I", "X")
}

In [None]:
# RRP size (presumably readily releasable pool) for every cell-parameter set,
# per BC type; indexed by parameter-set index below.
cell2rrp = {
    cell_type: data_utils.load_var(f'{bc_data_folder}/{cell_type}_rrps.pkl')
    for cell_type in ('ON', 'OFF')
}

## Get GC data.

These threshold values are taken from the NMI paper; the raw data are not available here.
For comparison, see Corna et al.

In [None]:
# GC threshold mean and standard deviation per electrode configuration
# (mC/cm² per the export header below), taken from the NMI paper.
gc_thresh = {
    '1x1':   {'mean': 0.52,  'std': 0.2},
    '2x2':   {'mean': 0.054, 'std': 0.020},
    '4x4':   {'mean': 0.025, 'std': 0.020},
    '10x10': {'mean': 0.017, 'std': 0.011},
}

## Get simulated BC data

In [None]:
def plot_sim_list(sim_list, plot_jidxs=(0, 1, 3, 4), plot_cellidxs=(0, 1, 4)):
    """Plot the mean 'rate BC' time course for a subset of simulations.

    One panel per current index in ``plot_jidxs``. Within a panel, colour
    encodes the cell-parameter set and alpha fades with the position of the
    cell index in ``plot_cellidxs``.

    Parameters
    ----------
    sim_list : nested list
        Indexed as ``sim_list[current][cell params][cell]``. Each leaf is used
        as ``(results, rec_time)`` with ``results['rate BC']`` supporting
        ``.mean(axis=1)``.
    plot_jidxs, plot_cellidxs : sequence of int, optional
        Current / cell indices to show. Indices beyond the available data
        are skipped instead of raising IndexError.
    """
    jidxs = [jidx for jidx in plot_jidxs if jidx < len(sim_list)]
    if not jidxs:
        return

    # squeeze=False keeps `axs` 2-D even when a single panel is drawn.
    fig, axs = plt.subplots(1, len(jidxs), figsize=(10, 1), sharex=True, squeeze=False)

    for ax, jidx in zip(axs[0], jidxs):
        ax.set_title('j' + str(jidx))
        for p_idx, sim_list_cell_params in enumerate(sim_list[jidx]):
            cellidxs = [c for c in plot_cellidxs if c < len(sim_list_cell_params)]
            for k, cellidx in enumerate(cellidxs):
                rec_time = sim_list_cell_params[cellidx][1]
                mean_rate = sim_list_cell_params[cellidx][0]['rate BC'].mean(axis=1)
                # Fade by position within `cellidxs` (not by the raw cell index)
                # so alpha stays in (0, 1], the range matplotlib accepts.
                ax.plot(
                    rec_time, mean_rate, label='_', c='C' + str(p_idx),
                    alpha=1 - k / len(cellidxs),
                )
    plt.show()

In [None]:
# Normalised vesicle release per BC type ('ON'/'OFF') and electrode configuration.
# After this cell, bc_ves_release[cell][AxA] is an array of shape
# (len(j_list), len(dxdy_list), n_cell_params) = (current, cell position,
# cell-parameter set) holding total release divided by that parameter set's RRP size.
# The inner {} placeholders are replaced by these arrays in the loop.
bc_ves_release = {cell: {AxA: {} for AxA in AxA_list} for cell in ['ON', 'OFF']}

for cell, AxA in itertools.product(['ON', 'OFF'], AxA_list):
    filename = os.path.join(bc_data_folder, f'sim_{cell}_{AxA}.pkl')
    # Nested list sim_list[current][cell params][cell] (see the prints below);
    # each leaf is used as (results, rec_time).
    sim_list = data_utils.load_var(filename)

    print(cell, AxA, filename)
    print('N currents:', len(sim_list))
    print('N cell parameters:', len(sim_list[0]))
    print('N cells:', len(sim_list[0][0]))
    
    plot_sim_list(sim_list)

    # Structure checks: one entry per current, the same number of parameter
    # sets for every current, and one entry per cell position.
    assert len(sim_list) == len(j_list)
    assert np.unique([len(sim_list_i) for sim_list_i in sim_list]).size == 1
    assert np.all([[len(sim_list_ii) == len(dxdy_list) for sim_list_ii in sim_list_i]
                   for sim_list_i in sim_list])

    # NOTE: stays defined at module scope (value from the last iteration) after this cell.
    n_cell_params = len(sim_list[0])

    bc_ves_release[cell][AxA] = np.full((len(j_list), len(dxdy_list), n_cell_params), np.nan)
    for ji, pi, ci in itertools.product(range(len(j_list)), range(n_cell_params), range(len(dxdy_list))):
        rec_time = sim_list[ji][pi][ci][1]
        # Mean over axis 1 of the 'rate BC' array -> one rate time course
        # (what axis 1 represents is not visible here — TODO confirm).
        mean_rate = sim_list[ji][pi][ci][0]['rate BC'].mean(axis=1)
        # Mean rate x recording duration: total release over the run,
        # assuming uniformly sampled rec_time.
        total_mean_release = np.mean(mean_rate)*(rec_time[-1]-rec_time[0])
        # Note the storage order (ji, ci, pi) differs from the loop order (ji, pi, ci).
        bc_ves_release[cell][AxA][ji, ci, pi] = total_mean_release / cell2rrp[cell][pi]
                
# Drop the last loaded simulation list (only the reduced array is kept).
del sim_list

# Plot

In [None]:
# Padding placed between the bold row letter (A-D) and the roman numeral in the
# first-column panel titles of the final figure.
ABC_space = '        '

## Plot params

In [None]:
# Currents to plot (a copy of all simulated current indices) and the roman
# numerals used to label the figure columns.
j_plot_list = j_list.copy()
romans = 'i ii iii iv'.split()

### Define colormapping

In [None]:
from matplotlib import cm as plt_cm
from matplotlib import colors as plt_colors

# Colour per cell position, looked up as mapper.to_rgba(cell_index).
# vmin=-2 (below the first index 0) keeps index 0 away from the dark end of
# viridis, and vmax=N_cells keeps the last index (N_cells-1) below the bright end.
cmap   = plt_cm.viridis
norm   = plt_colors.Normalize(vmin=-2, vmax=N_cells, clip=False)
mapper = plt_cm.ScalarMappable(norm=norm, cmap=cmap)

In [None]:
# Preview: one marker per cell position in its viridis colour.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(3, 1))
for cell_idx in range(N_cells):
    colour = mapper.to_rgba(cell_idx)
    ax.plot(cell_idx, 0, marker='o', ms=10, c=colour)

In [None]:
# Colour per plotted current, looked up as mapper_I.to_rgba(position in j_plot_list).
# vmin=-3 keeps the first currents away from the near-white low end of Reds.
cmap_I   = plt_cm.Reds
norm_I   = plt_colors.Normalize(vmin=-3, vmax=len(j_plot_list), clip=False)
mapper_I = plt_cm.ScalarMappable(norm=norm_I, cmap=cmap_I)

In [None]:
# Preview: one marker per plotted current in its Reds colour.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(3, 1))
for current_pos in range(len(j_plot_list)):
    colour = mapper_I.to_rgba(current_pos)
    ax.plot(current_pos, 0, marker='o', ms=10, c=colour)

## Plot helper functions

### Plot setup

In [None]:
def plot_setup(ax, AxA, colidx, electrode_pitch=70):
    """Draw the electrode grid and the simulated cell positions on ``ax``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Target axes.
    AxA : str
        Electrode configuration "<A>x<A>" (e.g. "4x4"); an A-by-A electrode
        grid centred on the origin is drawn.
    colidx : int
        Column of the panel in the figure; only column 0 keeps its y-axis.
    electrode_pitch : float, optional
        Centre-to-centre electrode spacing in axis units (µm per the axis
        labels). Defaults to 70.

    Uses the module-level ``dxdy_list`` (cell offsets) and ``mapper``
    (per-cell colours).
    """
    A = int(AxA.split('x')[1])

    # linspace yields exactly A electrodes with exact endpoints, without the
    # "stop + 1" trick np.arange needs to include the last electrode.
    half_span = electrode_pitch * (A - 1) / 2
    el_x = np.linspace(-half_span, half_span, A)

    xx, yy = np.meshgrid(el_x, el_x)
    # Flattened so the grid is a single line with a single 'electrodes' label
    # (2-D input creates one line, and one legend entry, per column).
    ax.plot(
        xx.ravel(), yy.ravel(), markersize=2.9, markeredgecolor=(0.8, 0, 0, 0.0),
        markerfacecolor=(0.8, 0, 0, 0.7), marker='o', label='electrodes',
        zorder=30, ls='None',
    )

    for i, dxdy in enumerate(dxdy_list):
        # Later cells get a lower zorder, so earlier cells are drawn on top.
        ax.plot(
            dxdy[0], dxdy[1], ls='None', marker='+', c=mapper.to_rgba(i),
            label='cells' if i == 0 else None, markersize=8, markeredgewidth=1.2,
            clip_on=False, zorder=20-i, alpha=1
        )

    ax.set_aspect('equal')
    ax.set_xlim(-500, 500)
    ax.set_ylim(-500, 500)

    if colidx != 0:
        # Every column uses the same ±500 limits, so only column 0 shows a y-axis.
        ax.set_yticklabels([])
        ax.set_yticks([])
        ax.spines['left'].set_visible(False)
    else:
        ax.set_ylabel(r"y ($\mu$m)")
        ax.set_yticks((-400, 0, 400))

    ax.set_xlabel(r"x ($\mu$m)")
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    # Spines end at the outer ticks (±400), not at the axis limits (±500).
    ax.spines['left'].set_bounds(-400, 400)
    ax.spines['bottom'].set_bounds(-400, 400)

    ax.set_xticks((-400, 0, 400))

In [None]:
# Preview of the electrode/cell-position row of the final figure.
fig, axs = plt.subplots(nrows=1, ncols=4, figsize=(6.7, 1.))
for colidx, AxA in enumerate(AxA_list[:len(axs)]):
    plot_setup(ax=axs[colidx], AxA=AxA, colidx=colidx)

### Plot currents.

In [None]:
currents_cols = ['I'+str(j)+ " [A]" for j in j_plot_list]

def plot_currents(ax, AxA, colidx):
    """Plot the stimulation current traces of configuration ``AxA`` on ``ax``.

    Time is shown in ms (x-range 0-10) and current in µA. There is one trace
    per column in ``currents_cols``, coloured via ``mapper_I``. Only column 0
    (``colidx == 0``) gets a y-label. Y-ticks are placed symmetrically at
    ±tickmax and 0.
    """
    currents_df = current_data['I'][AxA]
    time_ms = currents_df['Time [s]'] * 1e3

    ax.set_xlim(0, 10)
    for ci, currents_col in enumerate(currents_cols):
        ax.plot(time_ms, 1e6 * currents_df[currents_col],
                c=mapper_I.to_rgba(ci), lw=1, zorder=-ci)

    ax.set_xlabel(r'Time (ms)')
    if colidx == 0:
        ax.set_ylabel(r'Current ($\mu$A)')
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['bottom'].set_bounds(0, 10)

    # Peak |current| in µA over the plotted traces. Columns are selected by
    # name (the same ones plotted above) rather than by position, so the scale
    # cannot come from a different column if the column order changes.
    I_max = 1e6 * np.max(np.abs(currents_df[currents_cols].to_numpy()))

    # Round the tick down to a multiple of 5 µA. For peaks below 5 µA, fall
    # back to one significant digit so the ticks do not all collapse onto 0.
    tickmax = 5 * np.floor(I_max / 5)
    if tickmax == 0 and I_max > 0:
        magnitude = 10 ** np.floor(np.log10(I_max))
        tickmax = magnitude * np.floor(I_max / magnitude)
    ax.set_yticks([-tickmax, 0, tickmax])

    ax.spines['left'].set_bounds(-I_max, I_max)
    ax.set_ylim(-I_max*1.2, I_max*1.2)
    ax.set_xticks([0, 5, 10])

In [None]:
# Preview of the stimulation-current row of the final figure.
fig, axs = plt.subplots(nrows=1, ncols=4, figsize=(6.7, 6.7/4))
for colidx, AxA in enumerate(AxA_list[:len(axs)]):
    plot_currents(ax=axs[colidx], AxA=AxA, colidx=colidx)

##### Export data

In [None]:
# Output folder for the exported source-data CSV files (whether make_dir
# tolerates an existing folder is not visible here — verify in data_utils).
data_utils.make_dir('source_data')

In [None]:
# Export the plotted stimulation currents (time in ms, currents in µA) for
# every electrode configuration. Current columns are renumbered I1..In in
# plotting order.
for AxA in AxA_list:
    currents_df = current_data['I'][AxA]
    current_exdf = pd.DataFrame({'Time/ms': currents_df['Time [s]'].to_numpy()*1e3})

    for ci, currents_col in enumerate(currents_cols):
        # .to_numpy() assigns by position. Assigning the Series itself would
        # align on its index and silently yield NaNs if that index is not
        # the 0..n-1 RangeIndex of current_exdf.
        current_exdf['I'+str(ci+1)+'/uA'] = currents_df[currents_col].to_numpy()*1e6

    current_exdf.to_csv(
        os.path.join('source_data', 'Stimulation_currents_'+AxA+'.csv'),
        float_format='%.6f', index=False,
    )

In [None]:
# Round-trip check: re-read each exported current CSV and plot it.
fig, axs = plt.subplots(nrows=1, ncols=4, figsize=(12, 3))
for AxA, ax in zip(AxA_list, axs):
    current_exdf = pd.read_csv(f'source_data/Stimulation_currents_{AxA}.csv')
    current_exdf.plot(x='Time/ms', ax=ax, title=AxA)

### Plot response

In [None]:
def plot_response(ax, AxA, colidx, cell, ABC='', isbottompanel=True):
    """Plot normalised BC release against stimulus charge density on ``ax``.

    For each cell position, the mean over cell-parameter sets of
    ``bc_ves_release[cell][AxA]`` is drawn as a faint line, and as markers
    with ±1 std error bars. The x-values are
    ``0.1 * current_data['X'][AxA]['I<j>']`` for ``j`` in ``j_plot_list``,
    sorted ascending. The GC threshold mean (dashed line) and a ±1 std band
    are overlaid.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
    AxA : str
        Electrode configuration key, e.g. "4x4".
    colidx : int
        Figure column; only column 0 keeps its y-axis and label.
    cell : str
        BC type key, 'ON' or 'OFF'.
    ABC : str
        Not used; kept for call compatibility.
    isbottompanel : bool
        If True, draw the x-axis label and offset spine, otherwise hide the x-axis.
    """
    charges = np.array([0.1*current_data['X'][AxA]['I' + str(j)] for j in j_plot_list])
    order = np.argsort(charges)
    charges = charges[order]

    release = bc_ves_release[cell][AxA]
    for ci, cell_dist in enumerate(dist_list):
        # Rows = currents (re-ordered by charge), columns = cell-parameter sets.
        per_position = release[:, ci, :][order, :]
        mean_release = np.mean(per_position, axis=1)
        std_release = np.std(per_position, axis=1)
        colour = mapper.to_rgba(ci)

        ax.plot(
            charges, mean_release, ls='-', marker='+', label='_',
            c=colour, zorder=-ci, alpha=0.3
        )
        ax.errorbar(
            charges, mean_release, yerr=std_release,
            ls='None', marker='o', label="{:.2g} um".format(cell_dist),
            c=colour, zorder=-ci, lw=1, markersize=4, clip_on=False,
        )

    ax.set_ylim(0, None)
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)

    # GC threshold: dashed mean line plus a shaded ±1 std band spanning the y-range.
    gc_mean = gc_thresh[AxA]['mean']
    gc_std = gc_thresh[AxA]['std']
    ax.axvline(gc_mean, c='dimgray', linestyle='--', zorder=20)
    y_top = ax.get_ylim()[1]
    ax.fill_between(
        [gc_mean - gc_std, gc_mean + gc_std],
        [0, 0], [y_top, y_top],
        facecolor='dimgray', alpha=0.5, zorder=-30
    )

    if isbottompanel:
        ax.set_xlabel(r'Charge dens. (mC/cm²)')
        ax.spines['bottom'].set_position(("axes", -0.1))
    else:
        ax.spines['bottom'].set_visible(False)
        ax.set_xticks([])

    if colidx != 0:
        ax.set_yticks([])
        ax.spines['left'].set_visible(False)
    else:
        ax.set_ylabel(r'Release / $v^{max}_{RRP}$')

In [None]:
# Preview of the response rows: first OFF cells, then ON cells, one figure each.
for cell_type in ('OFF', 'ON'):
    fig, axs = plt.subplots(nrows=1, ncols=4, figsize=(6.7, 6.7/4))
    for colidx, (ax, AxA) in enumerate(zip(axs, AxA_list)):
        plot_response(ax=ax, AxA=AxA, cell=cell_type, colidx=colidx)

##### Export data

In [None]:
# Export the plotted BC responses (release / RRP, i.e. the y-values of the
# response panels) per BC type and electrode configuration.
# There is one row per (cell position, cell-parameter set). The index label
# CellK names parameter set K; 'Distance/um' identifies the cell position.
# Assumes j_plot_list equals j_list, because rows of bc_ves_release are
# indexed by position in j_list.
columns = ['Distance/um'] + ['Release/RRP for I'+str(i+1) for i in range(len(j_plot_list))]

for cell in ['OFF', 'ON']:
    for AxA in AxA_list:
        # Shape (current, cell position, cell-parameter set), already divided
        # by the RRP size of each parameter set.
        release_per_rrp = bc_ves_release[cell][AxA]
        n_params = release_per_rrp.shape[2]
        index = np.tile(np.array(['Cell'+str(i+1) for i in range(n_params)]), len(dist_list))

        cell_data_ex = np.full((len(dist_list)*n_params, 1+len(j_plot_list)), np.nan)
        cell_data_ex[:, 0] = np.repeat(np.array(dist_list), n_params)

        for i in range(len(j_plot_list)):
            # Exported as-is (not multiplied by the RRP size), so the values match
            # both the 'Release/RRP' header and the plotted data. Row-major
            # flatten is position-major / parameter-minor, matching np.repeat
            # (distances) and np.tile (index labels) above.
            cell_data_ex[:, i+1] = release_per_rrp[i, :, :].flatten()

        cell_data_exdf = pd.DataFrame(cell_data_ex, columns=columns, index=index)

        # Round-trip check by row position rather than by distance value, which
        # would mix rows up if two cells lay at the same distance.
        distances = cell_data_exdf['Distance/um'].to_numpy()
        for i in range(len(j_plot_list)):
            exported = cell_data_exdf['Release/RRP for I'+str(i+1)].to_numpy()
            for ci, cell_dist in enumerate(dist_list):
                rows = slice(ci*n_params, (ci+1)*n_params)
                assert np.all(distances[rows] == cell_dist)
                assert np.allclose(release_per_rrp[i, ci, :], exported[rows])

        cell_data_exdf.to_csv('source_data/BC_response_'+cell+'_'+AxA+'.csv', float_format='%.6f')

##### Test

In [None]:
# Round-trip check: re-read each exported BC response CSV and plot it against
# distance, one figure per BC type. The legend goes on the 10x10 panel only.
for cell in ['OFF', 'ON']:
    print(cell)
    fig, axs = plt.subplots(nrows=1, ncols=4, figsize=(6.7, 6.7/4), sharey=True)
    for AxA, ax in zip(AxA_list, axs):
        exported = pd.read_csv(f'source_data/BC_response_{cell}_{AxA}.csv')
        exported.plot(x='Distance/um', marker='.', title=AxA, ax=ax, legend=False)
        if AxA != '10x10':
            continue
        ax.legend(bbox_to_anchor=(1, 1))
    plt.show()

##### Export

In [None]:
# Export the RRP size of every cell-parameter set (rows Cell1..CellN) for both
# BC types. The row count comes from the data, so it stays in sync with the
# number of parameter sets.
n_rrp = len(cell2rrp['ON'])
pd.DataFrame(cell2rrp, index=['Cell'+str(i+1) for i in range(n_rrp)]).to_csv('source_data/RRP_sizes.csv', float_format="%.4f")

In [None]:
# Export the GC threshold mean/std per electrode configuration.
# The frame is built in one step from float lists, so both columns are
# float64 by construction and float_format applies to them. No leftover
# loop variable is read.
gc_tresh_df = pd.DataFrame(
    {
        'Threshold_mean/(mC/cm^2)': [gc_thresh[AxA]['mean'] for AxA in AxA_list],
        'Threshold_std/(mC/cm^2)': [gc_thresh[AxA]['std'] for AxA in AxA_list],
    },
    index=AxA_list,
)

gc_tresh_df.to_csv('source_data/GC_thresholds.csv', float_format='%.3f')
gc_tresh_df

# Make figure

In [None]:
# Plot.
fig, axs = plt.subplots(
    ncols=len(AxA_list), nrows=4, figsize=(5.6, 5.6), squeeze=False,
    gridspec_kw={'height_ratios': [0.85, 1.25, 0.85, 0.85]}
)
        
for colidx, (ax, AxA) in enumerate(zip(axs[0,:], AxA_list)):
    plot_currents(ax=ax, AxA=AxA, colidx=colidx)

for colidx, (ax, AxA) in enumerate(zip(axs[1,:], AxA_list)):
    plot_setup(ax=ax, AxA=AxA, colidx=colidx)
    
for colidx, (ax, AxA) in enumerate(zip(axs[2,:], AxA_list)):
    plot_response(ax=ax, AxA=AxA, cell='OFF', colidx=colidx, isbottompanel=False, ABC='C')  
    
for colidx, (ax, AxA) in enumerate(zip(axs[3,:], AxA_list)):
    plot_response(ax=ax, AxA=AxA, cell='ON', colidx=colidx, isbottompanel=True, ABC='D')  
            
fig.align_ylabels()

for row_idx, row_title in enumerate('ABCD'):
    axs[row_idx, 0].set_title(row_title + ABC_space + romans[0], loc='left', horizontalalignment='right', fontweight="bold")
    for col_idx in np.arange(1, len(AxA_list)):
        axs[row_idx, col_idx].set_title(romans[col_idx], loc='left', horizontalalignment='right', fontweight="bold")

for ax, AxA in zip(axs[0,:], AxA_list):
    ax.set_title(AxA +'\n')

plt.tight_layout(pad=1, w_pad=0, h_pad=0.3)
    
for ax in axs[2,:]:
    box = np.array(ax.get_position().bounds)
    box[1] -= 0.02
    ax.set_position(box)
    
for ax in axs[0,:]:
    box = np.array(ax.get_position().bounds)
    box[1] += 0.02
    box[3] -= 0.02
    ax.set_position(box)

fig.align_ylabels(axs[:,0])
    
plt.savefig(os.path.join('..', '_figures', f'fig{fig_num}_thresholds_ensemble.pdf'))