# Reconstruct CBC stimulus from recordings

In [None]:
import pandas as pd
import numpy as np
import h5py
from matplotlib import pyplot as plt

from scipy.optimize import minimize

In [None]:
import sys
import os

In [None]:
# Absolute path of the sibling 'pythoncode' folder that holds the project's helper modules.
pythoncodepath = os.path.abspath(os.path.join('..', 'pythoncode'))
# Prepend it so it is searched before every other entry on the import path.
sys.path = [pythoncodepath] + sys.path

import importhelper
# Project helper; presumably adds the subfolders of pythoncodepath to sys.path - verify in importhelper.
importhelper.addfolders2path(pythoncodepath)

import plot_peaks
import math_utils
import interpolation_utils

# Load Stimulus

## Load digital stimulus

Load stimulus that was designed digitally.

This stimulus is different from the stimulus that was displayed due to non-linearities and delays in the displaying process.

In [None]:
# Path of the .mat file (opened via h5py) that contains the digital chirp stimulus.
sub1_no_drug_file = os.path.join(
    '..', 'experimental_data', 'data_iGluSnFR', 'cbc_data', 'submission1', 'FrankeEtAl_BCs_2017_v1.mat')

# Read amplitude and time of the digital stimulus as flat arrays and
# collect them in one DataFrame with the columns 'Stim' and 'Time'.
with h5py.File(sub1_no_drug_file, 'r') as sub1_NoDrugsdata_raw:
    stimulus_digital = pd.DataFrame({
        'Stim': np.array(sub1_NoDrugsdata_raw['chirp_stim']).flatten(),
        'Time': np.array(sub1_NoDrugsdata_raw['chirp_stim_time']).flatten(),
    })

In [None]:
# Plot the digital stimulus amplitude against its time column.
stimulus_digital.plot(x='Time', figsize=(15,2))
plt.show()

### Load recorded stimulus

This is a noisy recording of the stimulus that was actually displayed.

- It's too noisy to be used for the models.

- Also the decreasing intensity after ~18s is a low-pass filtering artifact.

In [None]:
# Load recorded stimulus.
# The columns 'Time' and 'Stim' of this table are used throughout the rest of the notebook.
stimulus_recorded = pd.read_csv(os.path.join(
    '..', 'experimental_data', 'data_iGluSnFR', 'cbc_data', 'Franke2017_recorded_stimulus.csv'))

In [None]:
# Plot the recorded stimulus amplitude against its time column.
stimulus_recorded.plot(x='Time', figsize=(15,2))
plt.show()

# Correcting the amplitude of the stimulus

In [None]:
# Parametrized sigmoidal gamma correction to be fitted.
def sigmoid(x, x0, k, L, b):
    """Evaluate the logistic curve ``L / (1 + exp(-k * (x - x0))) + b``.

    x  : input value(s); scalars and numpy arrays both work.
    x0 : position of the midpoint on the x-axis.
    k  : steepness of the curve.
    L  : height of the curve (difference between upper and lower plateau).
    b  : constant offset added to the output.
    """
    exponent = -k*(x-x0)
    return L / (1 + np.exp(exponent)) + b


def correct_amp_sigmoid(monitor_input, sigmoid_params):
    """Map monitor input values to output intensities with a sigmoid.

    monitor_input  : input value(s) passed on to ``sigmoid`` as ``x``.
    sigmoid_params : four values in the order ``[k, x0, L, b]``. Note that
                     this order differs from the argument order of ``sigmoid``.

    Returns the sigmoid evaluated at ``monitor_input``.
    """
    sigmoid_params = np.asarray(sigmoid_params)
    assert sigmoid_params.size==4

    # Read the parameters by position: steepness, midpoint, height, offset.
    k, x0, L, b = (sigmoid_params[i] for i in range(4))

    return sigmoid(monitor_input, x0=x0, k=k, L=L, b=b)

In [None]:
# Monitor input values 0..255 inclusive. The upper bound of arange is exclusive,
# so 256 is needed to include 255, the value used as the full step further below.
test_inputs = np.arange(0, 256)

# Plot intensity curve for one example parameter set [k, x0, L, b].
plt.figure(figsize=(12,3))
plt.plot(test_inputs, math_utils.normalize(correct_amp_sigmoid(test_inputs, sigmoid_params=[0.03,190,0.9,0.01])))
plt.xlabel('Monitor Input')
plt.ylabel('normalized(Monitor Output)')
plt.show()

# Time correction

Amplitude correction is easier if the time is correct. So let's find the correct timing first.

In [None]:
# Settings for plot_rng_and_find_step (all in units of the 'Time' column).
pre_rng_dt = 0.2        # Shown in the plot before the start of the search range.
post_rng_dt = 0.01      # Shown in the plot after the end of the search range.
pre_rng_mean_dt = 0.1   # Length of the baseline window directly before the search range.

In [None]:
def plot_rng_and_find_step(rng):
    """Detect a step in the recorded stimulus and plot the search window.

    The baseline is taken from the samples in the ``pre_rng_mean_dt`` window
    directly before ``rng[0]``. The step is the first sample at or after
    ``rng[0]`` whose amplitude differs from the baseline mean by more than
    three baseline standard deviations.

    Parameters
    ----------
    rng : pair of floats
        ``(start, end)`` times. The search starts at ``start`` and is not
        limited by ``end``; ``end`` only affects the plotted window.

    Returns
    -------
    Time of the detected step, taken from ``stimulus_recorded['Time']``.

    Raises
    ------
    IndexError
        If no sample at or after ``rng[0]`` crosses the threshold.
    """
    rec_time = stimulus_recorded['Time'].values
    rec_stim = stimulus_recorded['Stim']

    # Samples inside the plotted window; only used to scale the y-axis.
    rng_idxs = (rec_time >= rng[0]-pre_rng_dt) & (rec_time <= rng[1]+post_rng_dt)

    # Samples directly before the range; they define the baseline.
    pre_rng_idxs = (rec_time >= rng[0]-pre_rng_mean_dt) & (rec_time <= rng[0])

    # Baseline mean and detection threshold (three standard deviations).
    mean = rec_stim.iloc[pre_rng_idxs].mean()
    stdm = 3*rec_stim.iloc[pre_rng_idxs].std()

    # First sample at or after the range start that leaves the baseline band.
    step_candidates = np.flatnonzero((np.abs(rec_stim.values-mean) > stdm) & (rec_time >= rng[0]))
    if step_candidates.size == 0:
        raise IndexError(
            'No step found at or after t={}: no sample deviates more than {} from the baseline mean {}.'.format(
                rng[0], stdm, mean))
    step_time = rec_time[step_candidates[0]]

    plt.figure(figsize=(15,6))
    plt.plot(stimulus_recorded['Time'], stimulus_recorded['Stim'])

    # Red: search range. Blue: start of the baseline window.
    plt.xlim(rng[0]-pre_rng_dt, rng[1]+post_rng_dt)
    plt.axvline(rng[0], c='r')
    plt.axvline(rng[1], c='darkred')
    plt.axvline(rng[0]-pre_rng_mean_dt, c='b')

    # Black: baseline mean (solid) and detection band (dotted).
    plt.axhline(mean, c='k')
    plt.axhline(mean-stdm, c='k', ls=':')
    plt.axhline(mean+stdm, c='k', ls=':')

    # Orange: detected step.
    plt.axvline(step_time, c='orange', ls='--')

    plt.ylim(rec_stim.iloc[rng_idxs].min()-0.01, rec_stim.iloc[rng_idxs].max()+0.01)
    plt.show()

    return step_time

## Find steps

Align the steps of the digital stimulus with the recorded stimulus.

### 1st ON step

In [None]:
# Search window for the first ON step; the search starts at rng[0].
rng = (2.0, 2.01)
t_1st_step_on = plot_rng_and_find_step(rng=rng)
# Show the detected step time as cell output.
t_1st_step_on

### 1st OFF step

In [None]:
# Search window for the first OFF step; the search starts at rng[0].
rng = (4.935, 4.98)
t_1st_step_off = plot_rng_and_find_step(rng=rng)
# Show the detected step time as cell output.
t_1st_step_off

### 2nd ON step

In [None]:
# Search window for the second ON step; the search starts at rng[0].
rng = (7.895, 7.92)
t_2nd_step_on = plot_rng_and_find_step(rng=rng)
# Show the detected step time as cell output.
t_2nd_step_on

### 2nd OFF step

In [None]:
# Search window for the second OFF step; the search starts at rng[0].
rng = (29.1, 29.13)
t_2nd_step_off = plot_rng_and_find_step(rng=rng)
# Show the detected step time as cell output.
t_2nd_step_off

### Show all steps

In [None]:
# One panel per detected step, zoomed to +-0.1 around the step time (red line).
plt.figure(figsize=(15, 5))

detected_steps = (t_1st_step_on, t_1st_step_off, t_2nd_step_on, t_2nd_step_off)
for i, t in enumerate(detected_steps):
    panel = plt.subplot(1, len(detected_steps), i + 1)
    panel.plot(stimulus_recorded['Time'], stimulus_recorded['Stim'])
    panel.axvline(t, c='r')
    panel.set_xlim(t - 0.1, t + 0.1)
plt.show()

## Define time windows for both chirps

In [None]:
# Manually chosen time windows on the recorded time axis.
# Only the start of the first chirp should be roughly aligned;
# none of the other borders has to be well aligned.
t_1st_chirp_on, t_1st_chirp_off = 10, 18
t_2nd_chirp_on, t_2nd_chirp_off = 19, 28

# Window around the gap between both chirps.
t_between_chirps_on, t_between_chirps_off = 17.6, 19.5

# The gap window has to overlap the end of the first and the start of the second chirp window.
assert t_between_chirps_on < t_1st_chirp_off
assert t_between_chirps_off > t_2nd_chirp_on

In [None]:
# Overview: an example amplitude correction of the digital stimulus (left axis)
# together with the recorded stimulus (right axis) and all time markers.
plt.figure(1, (15, 3))
ax = plt.subplot(111)

example_amp = correct_amp_sigmoid(stimulus_digital['Stim'], sigmoid_params=[0.03,190,0.9,0.01])
ax.plot(stimulus_digital['Time'], example_amp, '-', alpha=0.5, label='example')
ax.legend(loc='upper left')
ax.set_ylim(-0.1, 1.2)

# Red: detected steps. Blue: chirp windows. Green: window between the chirps.
marker_groups = [
    ('r', [t_1st_step_on, t_1st_step_off, t_2nd_step_on, t_2nd_step_off]),
    ('b', [t_1st_chirp_on, t_1st_chirp_off, t_2nd_chirp_on, t_2nd_chirp_off]),
    ('g', [t_between_chirps_on, t_between_chirps_off]),
]
for marker_color, marker_times in marker_groups:
    for t in marker_times:
        ax.axvline(t, c=marker_color)

# The recorded stimulus gets its own y-axis.
ax2 = ax.twinx()
ax2.plot(stimulus_recorded['Time'], stimulus_recorded['Stim'], 'k--', label='recorded')
ax2.legend(loc='upper right')

plt.show()

# Correct stimulus time

In [None]:
# Exponent applied to the absolute error in all loss functions; 2 gives the mean squared error.
p_err = 2

In [None]:
# Set output time to recorded time.
tout = stimulus_recorded['Time'].values.copy()
# Use the time axis of the digital stimulus as input.
tin  = stimulus_digital['Time'].values.copy()

In [None]:
# Mask of the recorded samples inside the first chirp window (start included, end excluded).
tout_idx_1st_chirp = np.logical_and(tout >= t_1st_chirp_on, tout < t_1st_chirp_off)

# Recorded time and amplitude of the first chirp as plain arrays.
rec_time_1st_chirp = stimulus_recorded.loc[tout_idx_1st_chirp, 'Time'].values
rec_amp_1st_chirp = stimulus_recorded.loc[tout_idx_1st_chirp, 'Stim'].values

In [None]:
# Rough cut-out of the first chirp in the digital stimulus (sample indexes).
# Before and after this window the digital stimulus is constant.
digital_1st_chirp_idxs = np.arange(9500, 19000)
digital_time_1st_chirp = tin[digital_1st_chirp_idxs]

# Mark the first and the last selected sample on the digital stimulus.
plt.figure(1, (15, 3))
plt.plot(stimulus_digital['Stim'])
for edge_idx in (digital_1st_chirp_idxs[0], digital_1st_chirp_idxs[-1]):
    plt.axvline(edge_idx, c='r')
plt.show()

In [None]:
# Function to fit first chirp.
def loss_1st_chirp(t0_dt, digital_amp_1st_chirp, plot=False, return_array=False):
    """Score a time offset and frame duration for the first chirp.

    t0_dt                 : pair ``(t0, dt)`` with the time offset and the time per frame.
    digital_amp_1st_chirp : amplitudes of the digital first chirp, one value per frame.
    plot                  : not used by this function.
    return_array          : if true, return the interpolated amplitudes instead of the loss.

    Returns the mean of ``|proposal - recording| ** p_err`` over the recorded
    first chirp, or the proposal amplitudes at the recorded times.
    """
    offset, frame_dt = t0_dt[0], t0_dt[1]

    # Proposed time axis: shifted start of the cut-out plus one frame duration per sample.
    frame_idxs = np.arange(digital_amp_1st_chirp.size)
    time_proposal = offset+digital_time_1st_chirp[0]+frame_idxs*frame_dt

    # Evaluate the proposal at the recorded time points.
    amp_proposal = interpolation_utils.in_ex_polate(time_proposal, digital_amp_1st_chirp, rec_time_1st_chirp)

    residual = amp_proposal - rec_amp_1st_chirp
    loss = np.mean(np.abs(residual)**p_err)

    return amp_proposal if return_array else loss

In [None]:
# Mask of the recorded samples inside the second chirp window (start included, end excluded).
tout_idx_2nd_chirp = np.logical_and(tout >= t_2nd_chirp_on, tout < t_2nd_chirp_off)

# Recorded time and amplitude of the second chirp as plain arrays.
rec_time_2nd_chirp = stimulus_recorded.loc[tout_idx_2nd_chirp, 'Time'].values
rec_amp_2nd_chirp = stimulus_recorded.loc[tout_idx_2nd_chirp, 'Stim'].values

In [None]:
# Rough cut-out of the second chirp in the digital stimulus (sample indexes).
# Before and after this window the digital stimulus is constant.
digital_2nd_chirp_idxs = np.arange(19000, 29000)
digital_time_2nd_chirp = tin[digital_2nd_chirp_idxs]

# Mark the first and the last selected sample on the digital stimulus.
plt.figure(1, (15, 3))
plt.plot(stimulus_digital['Stim'])
for edge_idx in (digital_2nd_chirp_idxs[0], digital_2nd_chirp_idxs[-1]):
    plt.axvline(edge_idx, c='r')
plt.show()

In [None]:
# Function to fit second chirp.
def loss_2nd_chirp(t0_dt, digital_amp_2nd_chirp, plot=False, return_array=False):
    """Score a time offset and frame duration for the second chirp.

    t0_dt                 : pair ``(t0, dt)`` with the time offset and the time per frame.
    digital_amp_2nd_chirp : amplitudes of the digital second chirp, one value per frame.
    plot                  : not used by this function.
    return_array          : if true, return the interpolated amplitudes instead of the loss.

    Returns the mean of ``|proposal - recording| ** p_err`` over the recorded
    second chirp, or the proposal amplitudes at the recorded times.
    """
    offset, frame_dt = t0_dt[0], t0_dt[1]

    # Proposed time axis: shifted start of the cut-out plus one frame duration per sample.
    frame_idxs = np.arange(digital_amp_2nd_chirp.size)
    time_proposal = offset+digital_time_2nd_chirp[0]+frame_idxs*frame_dt

    # Evaluate the proposal at the recorded time points.
    amp_proposal = interpolation_utils.in_ex_polate(time_proposal, digital_amp_2nd_chirp, rec_time_2nd_chirp)

    residual = amp_proposal - rec_amp_2nd_chirp
    loss = np.mean(np.abs(residual)**p_err)

    return amp_proposal if return_array else loss

In [None]:
# Find high frequency area. It will not be used to fit the amplitudes, because it's not correct in amplitude.
t_high_f_on = 12.5
t_high_f_off = 17.7

# Mask of the recorded samples inside the high frequency area (both borders included).
idxs_high_f = np.logical_and(tout >= t_high_f_on, tout <= t_high_f_off)

plt.figure(1, (15, 3))

# Left: full recording with the borders of the high frequency area in red.
plt.subplot(121)
plt.plot(stimulus_recorded['Time'], stimulus_recorded['Stim'])
for t_border in (t_high_f_on, t_high_f_off):
    plt.axvline(t_border, c='r')

# Right: recording without the high frequency area.
plt.subplot(122)
plt.plot(stimulus_recorded['Time'].values[~idxs_high_f], stimulus_recorded['Stim'].values[~idxs_high_f])
plt.show()

In [None]:
# Corrects both time and amplitude and returns the loss (or the fitted array).
# Final loss will not be computed on the high frequency area.
# High frequency range will be used to fit the timing though.
def transform_stim(sigmoid_params, return_array=False, plot_steps=False):
    """Build the corrected stimulus on the recorded time axis and score it.

    The digital stimulus is first amplitude-corrected with
    ``correct_amp_sigmoid``. The output trace is then assembled piece by
    piece on ``tout``: constant levels for the steps, a time-fitted version of
    each chirp (inner ``minimize`` calls on ``loss_1st_chirp`` and
    ``loss_2nd_chirp``), a cubic transition into the first chirp and constant
    levels between and after the chirps. Later pieces overwrite earlier ones,
    so the order of the statements matters.

    Parameters
    ----------
    sigmoid_params : sequence of four floats
        Sigmoid parameters ``[k, x0, L, b]`` as read by ``correct_amp_sigmoid``.
    return_array : bool
        If true, return ``(loss, yout)`` instead of only the loss.
    plot_steps : bool
        If true, plot the intermediate trace after every stage.

    Returns
    -------
    loss, or ``(loss, yout)`` if ``return_array`` is true. The loss is the mean
    of ``|recorded - yout| ** p_err`` over all samples outside ``idxs_high_f``.
    """
    sigmoid_params = np.asarray(sigmoid_params)

    # Output trace on the recorded time axis; NaN marks samples not filled yet.
    yout = np.full(tout.size, np.nan)
    yin = correct_amp_sigmoid(stimulus_digital['Stim'], sigmoid_params=sigmoid_params).values.copy()

    # Corrected amplitudes of the three constant input levels 127, 255 and 0.
    half_step_amp = correct_amp_sigmoid(np.array([127]), sigmoid_params=sigmoid_params)
    full_step_amp = correct_amp_sigmoid(np.array([255]), sigmoid_params=sigmoid_params)
    base_amp      = correct_amp_sigmoid(np.array([0]),   sigmoid_params=sigmoid_params)

    # Correct steps before first chirp.
    yout[tout < t_1st_step_on] = base_amp
    yout[(tout >= t_1st_step_on) & (tout < t_1st_step_off)] = full_step_amp
    yout[(tout >= t_1st_step_off) & (tout < t_2nd_step_on)] = base_amp
    yout[(tout >= t_2nd_step_on) & (tout < t_1st_chirp_on)] = half_step_amp

    if plot_steps: plot(yout, title='Step 1')

    # Correct first chirp.
    # Use the same sample indexes that define digital_time_1st_chirp, so the
    # amplitudes and the start time used in loss_1st_chirp cannot get out of sync.
    digital_amp_1st_chirp = yin[digital_1st_chirp_idxs]
    best_params_1 = minimize(
        loss_1st_chirp, args=(digital_amp_1st_chirp,), x0=inital_params_1, bounds=bounds_params_1).x
    best_fit_1st_chirp = loss_1st_chirp(best_params_1, digital_amp_1st_chirp, return_array=True)
    yout[tout_idx_1st_chirp] = best_fit_1st_chirp

    # Make transition smooth.
    idx0_1st_chirp = np.argwhere(tout_idx_1st_chirp)[0][0]

    # Support points for the interpolation: 100 samples well before the chirp
    # start and the first 100 samples of the chirp.
    idx_1st_chirp_smooth_in = np.concatenate([np.arange(idx0_1st_chirp-300, idx0_1st_chirp-200),
                                              np.arange(idx0_1st_chirp, idx0_1st_chirp+100)])

    # Samples that get replaced: everything from the first to the last support point (last one excluded).
    idx_1st_chirp_smooth_out = np.arange(idx_1st_chirp_smooth_in[0], idx_1st_chirp_smooth_in[-1])

    if plot_steps: plot(yout, title='Step 2', xlims=[(0,32), (tout[idx_1st_chirp_smooth_out][0],
                                                              tout[idx_1st_chirp_smooth_out][-1])])
    # Make transition smooth.
    yout[idx_1st_chirp_smooth_out] = interpolation_utils.in_ex_polate(
        x_old=tout[idx_1st_chirp_smooth_in], y_old=yout[idx_1st_chirp_smooth_in],
        x_new=tout[idx_1st_chirp_smooth_out], kind='cubic'
    )

    if plot_steps: plot(yout, title='Step 3', xlims=[(0,32), (tout[idx_1st_chirp_smooth_out][0],
                                                              tout[idx_1st_chirp_smooth_out][-1])])

    # Correct 2nd chirp.
    # Same reasoning as for the first chirp: share the indexes with loss_2nd_chirp.
    digital_amp_2nd_chirp = yin[digital_2nd_chirp_idxs]
    best_params_2 = minimize(
        loss_2nd_chirp, args=(digital_amp_2nd_chirp,), x0=inital_params_2, bounds=bounds_params_2).x
    best_fit_2nd_chirp = loss_2nd_chirp(best_params_2, digital_amp_2nd_chirp, return_array=True)
    yout[tout_idx_2nd_chirp] = best_fit_2nd_chirp

    if plot_steps: plot(yout, title='Step 4')

    # Correct between chirps. This overwrites the ends of both chirp fits.
    between_chirps_on_idxs = (tout >= t_between_chirps_on) & (tout < t_between_chirps_off)
    yout[between_chirps_on_idxs] = half_step_amp

    if plot_steps: plot(yout, title='Step 5', xlims=[(0,32), (tout[between_chirps_on_idxs][0]-1,
                                                              tout[between_chirps_on_idxs][-1]+1)])

    # Correct last step.
    yout[(tout >= t_2nd_chirp_off) & (tout < t_2nd_step_off)] = half_step_amp
    yout[tout >= t_2nd_step_off] = base_amp

    if plot_steps: plot(yout, title='Step 6', xlims=[(0,32), (tout[tout >= t_2nd_chirp_off][0]-1,
                                                              tout[tout >= t_2nd_chirp_off][-1])])

    # Do not include high frequency in loss.
    loss = np.mean(np.abs((stimulus_recorded['Stim'].values[~idxs_high_f] - yout[~idxs_high_f]))**p_err)

    if return_array:
        return loss, yout

    return loss

## Initial parameters

In [None]:
# Mean sample interval of the digital stimulus; used as reference for the time per frame.
dt_0 = np.diff(stimulus_digital['Time']).mean()

# Start values and bounds for [time offset, time per frame] of the first chirp.
inital_params_1 = np.array([-0.1, 0.96 * dt_0])
bounds_params_1 = [(-1, 1), (0.8 * dt_0, 1.1 * dt_0)]

# Start values and bounds for [time offset, time per frame] of the second chirp.
inital_params_2 = np.array([-0.44, 0.96 * dt_0])
bounds_params_2 = [(-1, 1), (0.8 * dt_0, 1.1 * dt_0)]

## Show example

Run the time correction, given a specific amplitude correction.

In [None]:
def plot(yout, xlims=None, title=None):
    """Plot a fitted trace on top of the recorded stimulus.

    Parameters
    ----------
    yout : array
        Fitted amplitudes on the recorded time axis ``tout``.
    xlims : sequence of (min, max) pairs, optional
        One subplot is drawn per pair, restricted to that x-range.
        Defaults to a single panel with the range ``(0, 32)``.
    title : str, optional
        Title set on every subplot.
    """
    # A None sentinel avoids a mutable list as default argument.
    if xlims is None:
        xlims = [(0, 32)]

    # Plot.
    plt.figure(figsize=(12,3))

    for idx, xlim in enumerate(xlims):
        ax = plt.subplot(1,len(xlims),idx+1)
        if title is not None: ax.set_title(title)
        ax.set_xlim(xlim)
        ax.plot(stimulus_recorded['Time'], stimulus_recorded['Stim'])
        ax.plot(tout, yout, alpha=0.8)

    # Labels the current axes, i.e. the last subplot created above.
    plt.xlabel('Time [s]')
    plt.tight_layout()
    plt.show()

In [None]:
# Run the full correction once for an example parameter set and plot every intermediate stage.
yout_test_loss, yout_test = transform_stim(sigmoid_params=[0.03, 180, 0.9, 0.01], return_array=True, plot_steps=True)
print('Loss = ', yout_test_loss)
plot(yout_test, title='Final')

In [None]:
# Second example with different sigmoid parameters; only the final result is plotted.
yout_test_loss, yout_test = transform_stim(sigmoid_params=[0.1, 150, 0.9, 0.1], return_array=True)
print('Loss = ', yout_test_loss)
plot(yout_test)

# Optimize

## Run once for testing

In [None]:
# Start values for the sigmoid parameters, in the order read by correct_amp_sigmoid: k, x0, L, b.
initial_sigmoid_params0 = np.array([0.03, 180, 0.9, 0.01])
np.random.seed(1353)
# Fit the sigmoid parameters; transform_stim returns the loss to be minimized.
solution = minimize(transform_stim, x0=initial_sigmoid_params0, method='Nelder-Mead')

In [None]:
# Stop here if the optimizer did not report success.
assert solution.success

In [None]:
# Keep the fitted parameters of the test run as the current best solution.
best_sigmoid_params = solution.x
# Show them as cell output.
best_sigmoid_params

In [None]:
# Re-evaluate the best parameters to obtain the fitted trace and its loss.
loss_best, yout_best = transform_stim(sigmoid_params=best_sigmoid_params, return_array=True)
plot(yout_best, title='Loss='+str(loss_best))

# Sanity check: the re-evaluation has to reproduce the loss reported by the optimizer exactly.
assert solution.fun == loss_best

## Optimize with random initializations

In [None]:
def draw_params():
    """Draw random start values for the four sigmoid parameters.

    Every parameter is sampled independently from a normal distribution with
    its own mean and standard deviation. The values are drawn and returned in
    the order in which they are listed below.

    Returns a numpy array with four floats.
    """
    # (mean, standard deviation) per parameter.
    means_and_stds = ((0.03, 0.1), (180, 30), (0.9, 0.1), (0.01, 0.1))
    return np.array([np.random.normal(mean, std) for mean, std in means_and_stds])

In [None]:
# Repeat the optimization from ten random start points and keep the best result.
np.random.seed(1353)
for i in range(10):
    print(i, end='\t')

    # Redraw until neither the first nor the third parameter is zero or negative.
    initial_sigmoid_params = draw_params()
    while np.any(initial_sigmoid_params[[0, 2]] <= 0):
        initial_sigmoid_params = draw_params()

    solution = minimize(transform_stim, x0=initial_sigmoid_params, method='Nelder-Mead')

    if not solution.success:
        print('Did not terminate!', end='\t')

    print('Loss = {:.8f}'.format(solution.fun), ' Params=', solution.x)

    # Remember this run if it beats the best loss so far.
    if solution.fun < loss_best:
        best_sigmoid_params, loss_best = solution.x, solution.fun

## Show best solution

In [None]:
# Recompute the trace for the overall best parameters and plot every intermediate stage.
loss_best, yout_best = transform_stim(sigmoid_params=best_sigmoid_params, return_array=True, plot_steps=True)
plot(yout_best)

print(loss_best)

### Plot details.

In [None]:
# Time windows to inspect, one panel per window.
xlims = [(0, 32), (1.5,2.5),(3,6), (6, 9), (9, 11), (11, 13), (15, 17), (17, 18), (18.5, 20), (23, 25), (25, 28), (28, 31)]

n_panels = len(xlims)
plt.figure(figsize=(12, 3 * n_panels))

for idx, xlim in enumerate(xlims):
    ax = plt.subplot(n_panels, 1, idx + 1)
    ax.set_xlim(xlim)
    ax.plot(stimulus_recorded['Time'], stimulus_recorded['Stim'], label='recorded')
    ax.plot(tout, yout_best, label='fit')
    ax.legend()

    # Scale the y-axis to the fitted trace inside the window, with a margin of 0.05.
    idx1 = np.flatnonzero(tout >= xlim[0])[0]
    idx2 = np.flatnonzero(tout <= xlim[1])[-1]
    fit_in_window = yout_best[idx1:idx2]
    ax.set_ylim(fit_in_window.min() - 0.05, fit_in_window.max() + 0.05)

plt.xlabel('Time [s]')
plt.tight_layout()
plt.show()

### Compare peak times

In [None]:
# Inputs for plot_peaks.compare_peaks_in_traces. All lists have one entry per trace:
# first the fitted trace, then the recorded stimulus.
trace_list = [
    yout_best,
    stimulus_recorded['Stim'],
]

# Time axis belonging to each trace.
time_list = [
    tout,
    stimulus_recorded['Time'],
]
color_list = ['r', 'b']
label_list = ['fit', 'target']

# Per-trace peak detection settings.
# NOTE(review): the meaning of the keys is defined in plot_peaks - confirm there.
params_dict_list = [
    {'height_pos': 0.02, 'height_neg': 0.00, 'prom': 0.05},
    {'height_pos': 0.03, 'prom': 0.08}
] 

# Time windows passed to the comparison as xlims.
xlims = [(0, 32), (1.5,2.5),(3,6), (6, 9), (9, 11), (11, 13), (15, 17), (17, 18), (18.5, 20), (23, 25), (25, 28), (28, 31)]

In [None]:
# Compare the peaks of the fitted trace with the peaks of the recorded stimulus.
# NOTE(review): the plot flags and the return value are defined in plot_peaks - confirm there.
trace_peaks = plot_peaks.compare_peaks_in_traces(
    trace_list=trace_list,
    time_list=time_list,
    plot_single=False,
    plot_hist=True,
    plot=True,
    params_dict_list=params_dict_list,
    color_list=color_list,
    label_list=label_list,
    xlims=xlims
)

# Save to file

In [None]:
# Final check before saving: fitted trace and recorded stimulus in one plot.
plt.figure(figsize=(15,6))
ax = plt.subplot(111)
ax.plot(tout, yout_best)
ax.plot(stimulus_recorded['Time'], stimulus_recorded['Stim'])
plt.show()

In [None]:
# Collect the corrected stimulus: recorded time axis and best fitted amplitudes.
stimulus_franke2017_amp_and_time_corrected = pd.DataFrame({
    'Time': tout,
    'Stim': yout_best,
})

In [None]:
# Normalize: shift the trace so that the first sample is 0, then scale it so that the maximum is 1.
corrected_stim = stimulus_franke2017_amp_and_time_corrected['Stim']
corrected_stim = corrected_stim - corrected_stim.iloc[0]
stimulus_franke2017_amp_and_time_corrected['Stim'] = corrected_stim / corrected_stim.max()

# Print the first sample and the maximum of the normalized trace.
print(stimulus_franke2017_amp_and_time_corrected['Stim'][0])
print(stimulus_franke2017_amp_and_time_corrected['Stim'].max())

In [None]:
# Plot the normalized stimulus against time.
stimulus_franke2017_amp_and_time_corrected.plot(x='Time')

In [None]:
# Save.
stim_file = os.path.join('data_preprocessed', 'Franke2017_stimulus_time_and_amp_corrected.csv')
# Create the output folder first; writing the CSV fails if the folder does not exist.
os.makedirs(os.path.dirname(stim_file), exist_ok=True)
stimulus_franke2017_amp_and_time_corrected.to_csv(stim_file, index=False)