In [None]:
import copy
import os
import numpy as np
import sys
import types
from scipy.interpolate import CubicSpline

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

In [None]:
# Select the interactive ipympl backend for the figures below;
# the commented-out line is the static inline alternative.
%matplotlib ipympl
#%matplotlib inline

# Notebook setup (path trick) and local import

In [None]:
# Path trick: the parent of the current working directory is treated as the package
# root so that the `src.*` modules can be imported from this notebook.
PACKAGE_ROOT = os.path.dirname(os.path.abspath(''))
print(PACKAGE_ROOT)
# Keep the cell idempotent: re-running it must not pile up duplicate sys.path entries.
if PACKAGE_ROOT not in sys.path:
    sys.path.append(PACKAGE_ROOT)

In [None]:
from src.class_ode_model import ODEModel
from src.class_ode_stimulus import ODEStimulus
from src.defined_stimulus import get_npulses_from_tspan, delta_fn_amp, stimulus_pulsewave
from src.preset_ode_model import PRESETS_ODE_MODEL, ode_model_preset_factory
from src.utils_timeseries import time_to_habituate, preprocess_signal_habituation
from src.settings import DIR_OUTPUT

from src.analyze_freq_amp_sensitivity import analyze_freq_amp_presets

# 1. Choose model and simulate one trajectory

In [None]:
# ------------------------------------------------------------------
# Configuration
# ------------------------------------------------------------------
use_preset = False  # True: build the model from PRESETS_ODE_MODEL; False: assemble it manually below

integration_style_stitching = True  # True: ode_model.trajectory_pulsewave(...); False: ode_model.propagate(...)
stitch_kwargs = {
    'forcetol': (1e-8, 1e-4),  # ATOL, RTOL = 1e-8, 1e-4   --or--   1e-12, 1e-6
    'max_step': np.inf,        # np.inf default (the np.Inf alias was removed in NumPy 2.0)
}
stitch_dynamic_max_step = True  # True: replace 'max_step' by ode_model.ode_base.max_step_augment

# ------------------------------------------------------------------
# Build the model; `traj_kwargs` collects the branch-specific integrator arguments
# ------------------------------------------------------------------
if use_preset:
    ode_preset_base = PRESETS_ODE_MODEL['ode_custom_1_S1']
    ode_model = ODEModel(*ode_preset_base['args'], **ode_preset_base['kwargs'])
    # the preset path passes no explicit initial condition / time span to the integrator
    traj_kwargs = {}
else:
    traj_label = 'not_preset'
    ode_name = 'ode_custom_4_innerout'
    scale_tspan = True  # if True, stretch tspan based on pulse period (larger period => longer tspan)

    # MANUALLY REDEFINE DEPENDING ON ODE_NAME
    selected_params_ode = np.array([
        0.1,  # k_x1+ = K_1 / T_1
        0.1,  # k_x1- =   1 / T_1
        0.5,  # k_x2+ = K_2 / T_2
        5.0,  # k_x2- =   1 / T_2
        0,    # x1_high; set as a function of others param below
        1.1,  # x1_high_mult (alpha)
        1e6,  # 1/epsilon for 'singular perturbation' (last state eqn becomes output)
    ])

    # ODE settings
    # ============================================================
    x0, tspan, base_params_ode, ode_pvary_dict = analyze_freq_amp_presets(ode_name)
    tspan = [0, 100.5]  # override tspan above; for default period of 1.0: 100 pulses

    # use default params or not? (here: replace the defaults by the manual selection above)
    base_params_ode = selected_params_ode

    # STIMULUS settings
    # ============================================================
    stim_fn = stimulus_pulsewave
    S1_duty = 0.01
    S1_period = 0.265 #1.0, 0.01, 0.1
    base_params_stim = [
        delta_fn_amp(S1_duty, S1_period),  # amplitude
        S1_duty,    # duty in [0,1]
        S1_period,  # period of stimulus pulses (level 1)
    ]
    stim_instance = ODEStimulus(
        stim_fn,
        base_params_stim,
        label='S1_manual'
    )
    control_objects = [stim_instance]

    if scale_tspan:
        assert tspan[0] == 0.0
        tspan[1] = tspan[1] * S1_period  # e.g., if t0, t1 = 0, 100.5 and period = 2.0: t1 = 201.0

    model_argprep = ode_model_preset_factory(traj_label, ode_name, control_objects,
                                             tspan[0], tspan[1],
                                             ode_params=base_params_ode, init_cond=x0)
    ode_model = ODEModel(*model_argprep['args'], **model_argprep['kwargs'])
    traj_kwargs = {'init_cond': x0, 't0': tspan[0], 't1': tspan[1]}

# ------------------------------------------------------------------
# Simulate one trajectory (same logic for both branches; only traj_kwargs differs)
# ------------------------------------------------------------------
if integration_style_stitching:
    if stitch_dynamic_max_step:
        print('ode_model.ode_base.max_step_augment', ode_model.ode_base.max_step_augment)
        stitch_kwargs['max_step'] = ode_model.ode_base.max_step_augment
    r, times = ode_model.trajectory_pulsewave(update_history=True, **traj_kwargs, **stitch_kwargs)
    print('trajectory style: stitching')
else:
    r, times = ode_model.propagate(update_history=True, **traj_kwargs,
                                   params_solver=None, verbose=False)
    print('trajectory style: propagate')

In [None]:
# Use the static inline backend for this overview figure (ipympl alternative kept below).
%matplotlib inline
#%matplotlib ipympl

# plotting
# - simple trajectory - stacked subplots, x_i(t) vs t
# NOTE(review): presumably draws the history stored via update_history=True in the cell above -- confirm.
ode_model.plot_simple_trajectory()

### 1.1 Testing scipy.signal.find_peaks for edge cases, prominence, etc.

In [None]:
# Clear the current pyplot figure, then switch to the ipympl backend
# for the peak-detection figures of this section.
plt.clf()
%matplotlib ipympl

In [None]:
# Rebuild the model's output signal from the simulated state trajectory `r`.
control_obj = ode_model.ode_base.control_objects[0]  # first control object attached to the model

# Stimulus evaluated on the time grid returned by the simulation cell.
stim_of_t = control_obj.fn_prepped(times)

output = ode_model.ode_base.output_fn(r, stim_of_t)
# NOTE(review): `times` is rebound only AFTER it was used to evaluate the stimulus above;
# this presumably relies on history_times matching the grid returned together with `r`
# (update_history=True) -- confirm, otherwise `output` and `times` can differ in length.
times = ode_model.history_times

plt.figure()
plt.plot(times, output)
plt.title('Signal we want to find particular peaks of')
plt.show()

In [None]:
# Peak detection on the output signal with find_peaks called without any optional filters.
# NOTE: `peaks` is reassigned by the time_to_habituate(...) call in a later cell.
from scipy.signal import find_peaks

peaks, properties = find_peaks(output)

In [None]:
# Recover the pulse-wave parameters from the stimulus object and count the pulses
# over the simulated time window [times[0], times[-1]].
stim_object = ode_model.ode_base.control_objects[0]
assert stim_object.stim_fn.__name__ == 'stimulus_pulsewave'  # the 3-way unpacking below assumes pulse-wave params
amp, duty, period = stim_object.params
###period = stim_object.params[2]
num_pulses_applied = get_npulses_from_tspan([times[0], times[-1]], duty, period)

In [None]:
# Library implementation of the habituation analysis on the output signal.
# NOTE(review): tth_continuous / tth_discrete are presumably the time-to-habituate as a
# time value and as a pulse count, mirroring the manual derivation later in this notebook.
(does_habituate, tth_continuous, tth_discrete,
 reldiff, peaks, troughs) = time_to_habituate(
    times, output, num_pulses_applied, tth_threshold=0.01,
)
# Compare the number of detected peaks with the number of applied pulses.
print(len(peaks), num_pulses_applied, sep='\n')

In [None]:
# Annotated view of the output signal: detected peaks/troughs, the stimulus-on windows
# (shaded) and the time-to-habituate (vertical dashed line).
scale_plot_by_period = True  # True: express the x-axis in units of the stimulus period

times_local = times
xlabel = 'wall time'
tth_local = tth_continuous
if scale_plot_by_period:
    # `period` comes from stim_object.params (unpacked in an earlier cell); unlike S1_period
    # it is also defined when the model was built from a preset.
    times_local = times_local / period
    xlabel = 'wall time / period'
    # NOTE(review): tth_continuous is assumed to be None when no habituation was detected.
    if tth_continuous is not None:
        tth_local = tth_continuous / period

plt.figure()
plt.plot(times_local, output, '-x', linewidth=1, markersize=3, color='skyblue')
plt.plot(times_local[peaks], output[peaks], 'o', markersize=6, color='orange', label='peak')
plt.plot(times_local[troughs], output[troughs], 'o', markersize=6, color='red', label='trough')
# Shade the output wherever the stimulus is on. The stimulus itself is evaluated on the
# unscaled `times`, but it must be drawn on the same (possibly rescaled) x-axis as the
# curves above, hence `times_local` as x-coordinate.
plt.fill_between(times_local, 0, output * control_obj.fn_prepped(times) / control_obj.params[0],
                 facecolor='gainsboro', alpha=1.0)
plt.title('Annotated peaks/troughs of signal')
plt.xlabel(xlabel)
plt.ylabel('output')
if tth_local is not None:
    plt.axvline(tth_local, linestyle='--', label=r'TTH for $\epsilon=0.01$')
plt.legend()
plt.show()

In [None]:
# reldiff plotted against its index, with a zero reference line.
fig, ax = plt.subplots()
ax.plot(reldiff, '--x')
ax.axhline(0)

In [None]:
tth_threshold = 0.01
# Zero out every reldiff entry lying strictly inside (-tth_threshold, tth_threshold);
# entries outside that band are kept unchanged.
within_band = np.abs(reldiff) < tth_threshold
reldiff_jitter_chopped = np.where(within_band, 0, reldiff)

fig, ax = plt.subplots()
ax.plot(reldiff_jitter_chopped, '--x')
ax.axhline(0)

In [None]:
# Show both time-to-habituate values, one per line.
print(tth_continuous, tth_discrete, sep='\n')

In [None]:
# Manual re-derivation of the time-to-habituate (TTH) from `reldiff`, for debugging.
print(reldiff)

# Index of the first strictly positive reldiff entry; None when there is none
# (indexing the empty argwhere result directly would raise IndexError).
positive_idx = np.argwhere(reldiff > 0)
first_positive_reldiff = positive_idx[0, 0] if len(positive_idx) > 0 else None
print(first_positive_reldiff)

peaks_below_threshold = np.argwhere(reldiff <= tth_threshold)
print(peaks_below_threshold)
print(peaks_below_threshold.shape)

# Row `first_positive_reldiff` of peaks_below_threshold is read below, so that row must exist;
# checking only for a non-empty array is not enough. If the check fails, tth_continuous and
# tth_discrete keep the values assigned by an earlier cell.
if first_positive_reldiff is not None and len(peaks_below_threshold) > first_positive_reldiff:
    # i.e. TTH_THRESHOLD is met: the curve has habituated, pick the first appropriate timepoint
    npulse_idx = peaks_below_threshold[first_positive_reldiff, 0]  # note: the accessing index > 0 if peaks initially increasing
    tth_discrete = npulse_idx + 1  # note: reldiff is shifted to the right by one, need to add one here
    tth_time_idx = peaks[tth_discrete]
    tth_continuous = times[tth_time_idx]

print(tth_continuous)
print(tth_discrete)


In [None]:
# Scratch figure: a straight line (y = 4x for x = 0..9) under a LaTeX-formatted title.
# NOTE(review): presumably just a rendering check for the title string.
xs = np.arange(10)
fig, ax = plt.subplots()
ax.plot(xs, 4 * xs)
ax.set_title(r'$\arg \min_{k} \: P(k)$')
plt.show()

In [None]:
# Re-run the library implementation, this time passing the `tth_threshold` variable
# instead of a literal; this rebinds all six result names.
does_habituate, tth_continuous, tth_discrete, reldiff, peaks, troughs = time_to_habituate(
            times, output, num_pulses_applied, tth_threshold=tth_threshold)

In [None]:
# Step-by-step TTH computation on top of preprocess_signal_habituation(...).
does_habituate, reldiff, peaks, troughs = preprocess_signal_habituation(
    times, output, num_pulses_applied, tth_threshold=tth_threshold)

# Reset every result of this cell so that nothing stale from an earlier cell is reported
# when the signal does not habituate or no valid index is found.
tth_continuous = None
tth_discrete = None
first_positive_reldiff = None
peaks_below_threshold = None
npulse_idx = None

if does_habituate:
    # Method:
    # - find k, the index where reldiff first becomes positive
    # - require that the TTH index is >= k
    # Entries before k are non-positive, hence <= tth_threshold (for a non-negative threshold),
    # so they fill rows 0..k-1 of peaks_below_threshold; row k is then the first index >= k
    # whose reldiff is at or below the threshold.
    positive_idx = np.argwhere(reldiff > 0)
    peaks_below_threshold = np.argwhere(reldiff <= tth_threshold)

    if len(positive_idx) > 0:  # otherwise positive_idx[0, 0] would raise IndexError
        first_positive_reldiff = positive_idx[0, 0]

        # Row `first_positive_reldiff` must exist before it is read.
        if len(peaks_below_threshold) > first_positive_reldiff:
            # i.e. TTH_THRESHOLD is met: the curve has habituated, pick the first appropriate timepoint
            npulse_idx = peaks_below_threshold[first_positive_reldiff, 0]  # note: the accessing index > 0 if peaks initially increasing
            tth_discrete = npulse_idx + 1  # note: reldiff is shifted to the right by one, need to add one here
            tth_time_idx = peaks[tth_discrete]
            tth_continuous = times[tth_time_idx]

print(reldiff)
print(peaks_below_threshold)
print(first_positive_reldiff)
print(npulse_idx)  # equals peaks_below_threshold[first_positive_reldiff, 0] when a TTH index was found, else None