# Preprocess optimized stimulus data

- [Define data to load](#Select-data)
- [Load selected data](#Load-data)
- [Create example stimulus for figure](#Example-stimulus)
- [Store data to folder](#Save-data)
- [Copy stimulated data with removed ion channels](#Get-RM-channels-data)

# Imports

In [None]:
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt

In [None]:
import os
import sys

In [None]:
# Make the shared helper code in ../../_pythoncode importable.
pythoncodepath = os.path.abspath(os.path.join('..', '..', '_pythoncode'))
# Put the folder first on the search path. The list is updated in place and any
# previous occurrence is dropped, so re-running this cell does not pile up
# duplicate entries while the folder still takes precedence.
sys.path[:] = [pythoncodepath] + [entry for entry in sys.path if entry != pythoncodepath]
import importhelper
# Project helper; presumably registers the sub-folders of pythoncodepath as
# well -- confirm against importhelper.
importhelper.addfolders2path(pythoncodepath)

In [None]:
import data_utils

# Select data

In [None]:
# Folder with the outputs of the stimulus-optimization step (step 4).
base_folder = os.path.join('..', '..', 'step4_optimize_stimulus', 'optim_data')
# List its contents; as the last expression of the cell the result is displayed.
os.listdir(base_folder)

In [None]:
# Sub-folder of base_folder that all data below is loaded from.
data_folder = os.path.join(base_folder, 'optimize_stimulus_submission2')
# List its contents; as the last expression of the cell the result is displayed.
os.listdir(data_folder)

# Load data

In [None]:
# Time axes and `predur_stim` (presumably the pre-stimulus duration -- confirm)
# as stored in the selected data folder.
rec_time = data_utils.load_var(os.path.join(data_folder, 'rec_time.pkl'))
raw_stim_time = data_utils.load_var(os.path.join(data_folder, 'stim_time.pkl'))
predur_stim = data_utils.load_var(os.path.join(data_folder, 'predur_stim.pkl'))

# Index of the first time point that is >= predur_stim.
# np.argmax returns 0 for an all-False mask, which would silently treat the
# first sample as the onset, so fail loudly when no such time point exists.
reached_predur = raw_stim_time >= predur_stim
if not np.any(reached_predur):
    raise ValueError('No entry of stim_time is >= predur_stim; cannot locate the stimulus onset.')
idx_stim_onset = np.argmax(reached_predur)
# Keep only the samples from the onset on and shift time so the onset is at 0.
stim_time = raw_stim_time[idx_stim_onset:] - raw_stim_time[idx_stim_onset]

# Stimulus generator object saved alongside the data.
stimgen = data_utils.load_var(os.path.join(data_folder, 'stim_generator.pkl'))

In [None]:
# (cell name, cell type) pairs; all loaded data is keyed by the cell type.
cells = [('CBC3a', 'OFF'), ('CBC5o', 'ON')]
cell_data = {celltype: {} for _, celltype in cells}

for cell, celltype in cells:
    entry = cell_data[celltype]

    # Sample distributions: the first element is stored as the prior, the
    # remaining ones as the list of posteriors.
    snpe_folder = os.path.join(data_folder, f'target_{cell}_snpe')
    dists_list = data_utils.load_var(os.path.join(snpe_folder, 'sample_distributions.pkl'))
    entry['prior'], entry['post_list'] = dists_list[0], dists_list[1:]

    # Stimuli: drop the columns before the stimulus onset index.
    post_folder = os.path.join(data_folder, f'post_data_{cell}')
    best_stimuli = data_utils.load_var(os.path.join(post_folder, 'best_stimuli.pkl'))
    sampled_stimuli = data_utils.load_var(os.path.join(post_folder, 'post_sampled_stimuli.pkl'))
    entry['best_stimuli'] = best_stimuli[:, idx_stim_onset:]
    entry['post_sampled_stimuli'] = sampled_stimuli[:, idx_stim_onset:]

    # Note: this file name is built from the cell type, not the cell name.
    rrps_file = os.path.join(data_folder, 'retsim', f'{celltype}_rrps.pkl')
    entry['rrps'] = data_utils.load_var(rrps_file)

In [None]:
import delfi_funcs
# Optimizer object; used below for its `load_samples` method.
delfi_optim = delfi_funcs.EmptyDELFI_Optimizer()

In [None]:
def load_samples(samples_folder):
    """Load every file in ``samples_folder`` through ``delfi_optim.load_samples``.

    The folder entries are listed in sorted order, printed, and then passed
    on as paths joined with ``samples_folder``. Relies on the notebook-level
    ``delfi_optim`` object.

    Parameters
    ----------
    samples_folder : str
        Folder whose entries are the sample files.

    Returns
    -------
    tuple
        ``(samples, d_sort_index, n_samples)``: the three values returned by
        ``delfi_optim.load_samples``, with the second and third swapped
        relative to the order in which that call returns them.
    """
    file_names = sorted(os.listdir(samples_folder))
    print('All files:')
    print(file_names)

    file_paths = [os.path.join(samples_folder, name) for name in file_names]

    loaded = delfi_optim.load_samples(
        files=file_paths,
        concat_traces=True,
        list_traces=False,
        return_sort_idx=True,
        return_n_samples=True,
        verbose=False,
    )
    samples, n_samples, d_sort_index = loaded

    return samples, d_sort_index, n_samples

In [None]:
# Attach the samples and their sort index to each cell type's entry;
# n_samples is returned as well but not stored.
for cell, celltype in cells:
    samples_folder = os.path.join(data_folder, f'target_{cell}_samples')
    samples, d_sort_index, n_samples = load_samples(samples_folder)
    cell_data[celltype].update(samples=samples, d_sort_index=d_sort_index)

# Example stimulus

In [None]:
# Hand-picked parameter vector for the example stimulus.
stimgen_params = np.array([-0.5, -0.7, 0.3, 0.5])
# The vector length must match the number of parameters the generator expects.
assert stimgen_params.size == stimgen.n_params, f'Define exactly {stimgen.n_params} params'

In [None]:
# Build the stimulus for the example parameters.
# NOTE(review): plot=True / filename=None presumably show the stimulus without
# writing it to a file -- confirm against the generator implementation.
stimgen_stim = stimgen.create_stimulus(params=stimgen_params, plot=True, filename=None)

In [None]:
# Rebuild the example stimulus with the routine matching the generator's mode,
# which also returns the anchor points (times and amplitudes) used below.
if stimgen.stim_mode == 'spline':
    stimgen_stim, stim_anchor_points_time, stim_anchor_points_amp =\
        stimgen.create_stimulus_spline(params=stimgen_params, verbose=False)

elif stimgen.stim_mode in ('charge neutral', 'charge neutral var dur'):
    # Both charge-neutral modes use the same routine; they differ only in var_dur.
    stimgen_stim, stim_anchor_points_time, stim_anchor_points_amp =\
        stimgen.create_stimulus_charge_neutral(
            params=stimgen_params, verbose=False,
            var_dur=(stimgen.stim_mode == 'charge neutral var dur'),
        )

else:
    # Without this branch an unknown mode would leave the anchor-point
    # variables undefined (or stale from an earlier run of the notebook).
    raise ValueError(
        f"Unsupported stimgen.stim_mode {stimgen.stim_mode!r}; expected "
        "'spline', 'charge neutral' or 'charge neutral var dur'."
    )

In [None]:
# Indices i at which the stimulus changes sign between sample i and i+1:
# either sample i is negative and sample i+1 is non-negative, or sample i is
# non-negative and sample i+1 is negative.
next_is_nonneg = (stimgen_stim >= 0)[1:]
prev_is_neg = (stimgen_stim < 0)[:-1]
stim_idx_change = np.argwhere(next_is_nonneg == prev_is_neg).flatten()

In [None]:
# If the generator normalizes, scale so the largest absolute value becomes 1;
# otherwise leave the amplitude untouched.
factor = 1/np.max(np.abs(stimgen_stim)) if stimgen.normalize_stim else 1.

# Combined scale including the generator's multiplier (the attribute really is
# spelled `stim_mulitplier`).
amp_scale = factor*stimgen.stim_mulitplier

# Scale the trace and its anchor amplitudes in place by the same amount.
stimgen_stim *= amp_scale
stim_anchor_points_amp *= amp_scale

In [None]:
# Time at the generator's stimulus start index; subtracted below so that the
# example stimulus starts at time zero.
t0 = stimgen.stim_time[stimgen.idx_start_stim]

In [None]:
# Sample range between the generator's start and stop indices.
stim_window = slice(stimgen.idx_start_stim, stimgen.idx_stop_stim)

# Example stimulus for the figure, with times and indices expressed relative
# to the start of the stimulus.
example_stim = {
    'time': stimgen.stim_time[stim_window] - t0,
    'anchor_time': stim_anchor_points_time - t0,
    'anchor_points': stim_anchor_points_amp,
    'stim': stimgen_stim[stim_window],
    'idx_change': stim_idx_change - stimgen.idx_start_stim,
}

In [None]:
# Quick visual check: stimulus trace with its anchor points marked by crosses.
plt.plot(example_stim['time'], example_stim['stim'])
# The trailing semicolon suppresses the textual cell output in the notebook.
plt.plot(example_stim['anchor_time'], example_stim['anchor_points'], 'x');

# Save data

In [None]:
# Output folder (relative to the notebook) that the files below are saved to.
data_utils.make_dir('data')

In [None]:
# Store both time axes (onset-aligned stimulus time first, then recording time).
for pkl_name, time_axis in (('stim_time.pkl', stim_time), ('rec_time.pkl', rec_time)):
    data_utils.save_var(time_axis, os.path.join('data', pkl_name))

In [None]:
# Store the per-cell-type dictionary assembled in the "Load data" section.
data_utils.save_var(cell_data, os.path.join('data', 'cell_data.pkl'))

In [None]:
# Store the example stimulus dictionary built in the "Example stimulus" section.
data_utils.save_var(example_stim, os.path.join('data', 'example_stim.pkl'))

# Get RM channels data

In [None]:
from shutil import copy

# Source folder inside the selected data folder and local destination folder.
rm_src_folder = os.path.join(data_folder, 'removed_ion_channels')
rm_dst_folder = 'removed_ion_channels'

# Create the destination first, then copy every entry of the source folder
# over under the same name.
data_utils.make_dir(rm_dst_folder)

files = os.listdir(rm_src_folder)

for file in files:
    src_path = os.path.join(rm_src_folder, file)
    dst_path = os.path.join(rm_dst_folder, file)
    copy(src_path, dst_path)