In [None]:
import copy
import os
import numpy as np
import sys
import types
from scipy import signal
from scipy.interpolate import CubicSpline
from scipy.stats import norm  # for u(t) as gaussians

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.gridspec import GridSpec

In [None]:
# Interactive figures via the ipympl backend; the commented-out inline backend gives static images
%matplotlib ipympl
#%matplotlib inline

In [None]:
# Serif text throughout; math text uses the matching DejaVu serif glyph set
plt.rcParams.update({
    "font.family": "serif",
    "mathtext.fontset": "dejavuserif",
})

# Notebook setup: put the source roots on `sys.path` and import the local `src` modules

In [None]:
# The notebook is run from its own directory: SRC_ROOT is the parent of the working
# directory and PACKAGE_ROOT the grandparent. Both go on sys.path so `src` is importable.
SRC_ROOT = os.path.dirname(os.path.abspath(''))
PACKAGE_ROOT = os.path.dirname(SRC_ROOT)

print('appending to path SRC_ROOT...', SRC_ROOT)
if SRC_ROOT not in sys.path:  # re-running the cell must not stack duplicate entries
    sys.path.append(SRC_ROOT)

print('appending to path PACKAGE_ROOT...', PACKAGE_ROOT)
if PACKAGE_ROOT not in sys.path:
    sys.path.append(PACKAGE_ROOT)

# All figures produced by this notebook are written here
NB_OUTPUT = os.path.join(SRC_ROOT, 'output')
os.makedirs(NB_OUTPUT, exist_ok=True)  # no exists()/makedirs race

In [None]:
from src.class_ode_model import ODEModel
from src.class_ode_stimulus import ODEStimulus
from src.defined_stimulus import get_npulses_from_tspan, delta_fn_amp, stimulus_pulsewave, stimulus_pulsewave_of_pulsewaves, stimulus_constant
from src.preset_ode_model import PRESETS_ODE_MODEL, ode_model_preset_factory
from src.utils_timeseries import time_to_habituate, preprocess_signal_habituation
from src.settings import DIR_OUTPUT

from src.analyze_freq_amp_sensitivity import analyze_freq_amp_presets

# Plot styling: colors and line settings

In [None]:
# Palette: input u(t) in cyan, memory/internal states in tan, response y(t) in purple
color_input = '#00AEEF'
color_memory = '#C1A16B'  # tan shades tried, light to dark: '#DEB87A', '#D0AD73', '#C1A16B'
color_response = '#662D91'

# Line geometry shared by every trace
linewidth = 1.0
linestyle = '-'

# Keyword bundles passed straight to Axes.plot(...)
line_kwargs_input = {'color': color_input, 'linewidth': linewidth, 'linestyle': linestyle}
line_kwargs_memory = {'color': color_memory, 'linewidth': linewidth, 'linestyle': linestyle}
line_kwargs_memory_sigma = dict(line_kwargs_memory)  # currently same as memory (a 'dotted' variant was tried)
line_kwargs_response = {'color': color_response, 'linewidth': linewidth, 'linestyle': linestyle}

# Simulate each ODE in the list under two input types (a constant step and a pulse train) and chart their responses

In [None]:
def get_ode_traj_nb(ode_name, ode_params, stimulus_suffix):
    """Build an ODEModel for one preset ODE driven by a manual stimulus, and integrate it.

    Args:
        ode_name: preset name passed to ``analyze_freq_amp_presets`` and
            ``ode_model_preset_factory``.
        ode_params: parameter array for the ODE; if None, the preset's default
            parameters (from ``analyze_freq_amp_presets``) are used.
        stimulus_suffix: 'S2' (``stimulus_pulsewave_of_pulsewaves``), 'S1'
            (``stimulus_pulsewave``) or 'S0_one' (``stimulus_constant`` with amplitude 1.0).

    Returns:
        The integrated ``ODEModel`` (``traj_and_runtime`` is called with
        ``update_history=True``).

    Raises:
        AssertionError: if ``stimulus_suffix`` is not one of the supported values.
    """
    integration_style_stitching = True  # may be switched off by the stimulus branch below
    stitch_kwargs = {
        'forcetol': (1e-8, 1e-4),  # (ATOL, RTOL); a stricter alternative is (1e-12, 1e-6)
        'max_step': np.inf,        # no step bound by default (np.Inf was removed in NumPy 2.0)
    }
    stitch_dynamic_max_step = True  # use the model's max_step_augment as max_step when available
    traj_label = 'not_preset'

    # ODE settings
    # ============================================================
    # Only x0 and (optionally) the default parameters come from the preset; the preset
    # tspan is always replaced by the stimulus-specific tspan chosen below.
    x0, _, preset_params_ode, _ = analyze_freq_amp_presets(ode_name)
    base_params_ode = preset_params_ode if ode_params is None else ode_params

    # STIMULUS settings
    # ============================================================
    assert stimulus_suffix in ['S2', 'S1', 'S0_one'], \
        'unsupported stimulus_suffix: %r' % (stimulus_suffix,)  # TODO extend
    if stimulus_suffix == 'S2':
        tspan = [0, 200.5]  # 200 periods at period 1.0; presumably ~100 pulses with 5 on / 5 rest
        stim_fn = stimulus_pulsewave_of_pulsewaves
        S1_duty = 0.01
        S1_period = 1.0  # other values tried: 0.01, 0.1
        base_params_stim = [
            delta_fn_amp(S1_duty, S1_period),  # amplitude
            S1_duty,    # duty in [0,1]
            S1_period,  # period of stimulus pulses (level 1)
            5,          # npulse
            5,          # nrest
        ]
        stim_label = 'S2_manual'
    elif stimulus_suffix == 'S1':
        tspan = [0, 20.5]  # ~20 pulses at period 1.0
        stim_fn = stimulus_pulsewave
        S1_duty = 0.01
        S1_period = 1.0  # other values tried: 0.01, 0.1
        base_params_stim = [
            delta_fn_amp(S1_duty, S1_period),  # amplitude
            S1_duty,    # duty in [0,1]
            S1_period,  # period of stimulus pulses (level 1)
        ]
        stim_label = 'S1_manual'
    else:  # 'S0_one'
        tspan = [0, 20.5]
        stim_fn = stimulus_constant
        step_amplitude = 1.0
        base_params_stim = [step_amplitude]
        stim_label = 'S_step_manual'
        # presumably pulse-wave stitching does not apply to a constant input
        integration_style_stitching = False

    stim_instance = ODEStimulus(stim_fn, base_params_stim, label=stim_label)
    control_objects = [stim_instance]

    # PREPARE a new instance of ODEModel
    # ============================================================
    model_argprep = ode_model_preset_factory(traj_label, ode_name, control_objects,
                                             tspan[0], tspan[1],
                                             ode_params=base_params_ode, init_cond=x0)
    ode_model = ODEModel(*model_argprep['args'], **model_argprep['kwargs'])

    traj_kwargs = dict(verbose=True, traj_pulsewave_stitch=integration_style_stitching)
    if stitch_dynamic_max_step:
        max_step_augment = ode_model.ode_base.max_step_augment
        print('ode_model.ode_base.max_step_augment', max_step_augment)
        if max_step_augment is not None:
            # keep the np.inf default rather than forwarding max_step=None to the solver
            stitch_kwargs['max_step'] = max_step_augment
    _, _, runtime = ode_model.traj_and_runtime(update_history=True, **traj_kwargs, **stitch_kwargs)

    print('...done | runtime:', runtime, '(s)')
    return ode_model

In [None]:
# NOTE(review): disabled example kept as an inert string literal. It is stale: it relies
# on an `ode_model` variable that no cell defines at top level, and it calls
# plot_uxry_stack_for_ode(ode_model, title, fpath), whereas the function defined later in
# this notebook takes (ode_name, ode_title, ode_sol, fpath).
"""
#title = r'%s $u(t)$ with filter: %s'  % (fmod1, fmod2)
rate_decay = ode_model.ode_base.params[0]
rate_grow = ode_model.ode_base.params[1]
title = r'plot_uxry_stack_for_ode(...) | rectangle $u(t)$; $\alpha=%.2f$, $\beta=%.2f$' % (rate_decay, rate_grow)

fpath = NB_OUTPUT + os.sep + 'plot_uxry_stack_for_ode.pdf'
plot_uxry_stack_for_ode(ode_model, title, fpath)
"""

In [None]:
# Systems to compare. Each entry is (preset ODE name, panel title, parameter array).
# NOTE: the two AIC arrays contain only integer literals, so numpy gives them an integer dtype.
ode_list_to_sim = [
    (
        'ode_linear_filter_1d',
        'Habituating filter (direct)',
        np.array([
            0.1,  # alpha: timescale for x decay -- prop to x(t)
            1,    # beta: timescale for x growth -- prop to u(t)
        ]),
    ),
    (
        'ode_linear_filter_1d_lifted',
        'Habituating filter',
        np.array([
            0.1,   # alpha: timescale for x decay -- prop to x(t)
            1,     # beta: timescale for x growth -- prop to u(t)
            2,     # N: hill function, filter saturation
            1e-2,  # epsilon: output target synchronization timescale (should be the fastest)
        ]),
    ),
    (
        'ode_custom_2',
        'Sniffer',  # (Tyson)
        np.array([
            1,    # k_x+
            0.1,  # k_x-
            1,    # k_z+
            10,   # k_z-
        ]),
    ),
    # ('ode_custom_1', 'IFF', None),  # disabled
    (
        'ode_custom_5',
        'Negative feedback',
        np.array([
            1,    # k_x+
            1,    # k_x-
            1,    # k_y+
            0.1,  # k_y-
            1,    # k_z+
            10,   # k_z-   (100 also tried)
        ]),
    ),
    (
        'ode_custom_7',
        'AIC',
        np.array([
            1,   # a1
            1,   # a2
            1,   # k_0
            10,  # a3
            1,   # b3
        ]),
    ),
    (
        'ode_custom_7_ferrell',
        'AIC (4d-Ferrell)',
        np.array([
            1,   # a1    AIC-k_6
            1,   # a2    AIC-k_4
            10,  # k_0   AIC-k_5        def: 50
            10,  # a3    output prod.   def: 2
            1,   # b3    output decay
            1,   # x_+
            1,   # x_-                  def: 3
        ]),
    ),
    (
        'ode_hallmark5',
        'ode_H5 (1d)',
        np.array([
            0.2,  # alpha: timescale for x decay -- prop to x(t)
            4,    # beta: timescale for x growth -- prop to u(t)
        ]),
    ),
    (
        'ode_hallmark5_lifted',
        'ode_H5 (lift)',
        np.array([
            0.2,   # alpha: timescale for x decay -- prop to x(t)
            4,     # beta: timescale for x growth -- prop to u(t)
            2,     # N: hill function, filter saturation
            1e-2,  # epsilon: output target synchronization timescale (should be the fastest)
        ]),
    ),
]

# Stimuli to simulate: constant step input, then a single-level pulse wave
ulist = ['S0_one', 'S1']

In [None]:
# Uniform resampling grid size used when the model provides no finite max step
N_INTERP_POINTS = 10000

# ode_solutions_LoD[ode_name][stimulus_name] -> {'t', 'u', 'x', 'y'} arrays (filled below)
ode_solutions_LoD = {
    ode_props[0]: {
        k: {'t': None, 'u': None, 'x': None, 'y': None} for k in ulist
    }
    for ode_props in ode_list_to_sim}

for ustr in ulist:
    for ode_props in ode_list_to_sim:
        ode_name, ode_title, ode_params = ode_props
        print('\nworking on %s (u=%s)...' % (ode_name, ustr))
        ode_model = get_ode_traj_nb(ode_name, ode_params, ustr)

        # Resample the trajectory on a uniform grid: step = the model's max step when it is
        # finite, otherwise the integrated time span split into N_INTERP_POINTS intervals.
        max_step_augment = ode_model.ode_base.max_step_augment
        if max_step_augment is None or np.isinf(max_step_augment):
            force_dt = (ode_model.history_times[-1] - ode_model.history_times[0]) / N_INTERP_POINTS
        else:
            force_dt = max_step_augment
        ode_state, ode_t = ode_model.interpolate_trajectory(force_dt=force_dt)
        # evaluate the (single) stimulus on the same grid
        ode_u = ode_model.ode_base.control_objects[0].fn_prepped(ode_t)

        sol = ode_solutions_LoD[ode_name][ustr]
        sol['t'] = ode_t
        sol['u'] = ode_u
        sol['x'] = ode_state
        # use the model's output map rather than assuming y is the last state component
        sol['y'] = ode_model.ode_base.output_fn(ode_state, ode_u)

## Plot the data (vertical mode)

In [None]:
# Line styles for the overview plots (plot_uxry_stack_for_ode below also reads these names)
line_input_kwargs = dict(color=color_input, linewidth=1)
line_response_kwargs = dict(color=color_response, linewidth=1)

# Rows: the input u(t) on top, then one output y(t) per ODE; columns: one per stimulus in ulist
fig, ax = plt.subplots(1 + len(ode_list_to_sim), len(ulist), squeeze=False, figsize=(4, 10), sharex=True)

for uidx, ustr in enumerate(ulist):
    # the stimulus depends only on ustr, so take u(t) from the first ODE's solution
    input_sol = ode_solutions_LoD[ode_list_to_sim[0][0]][ustr]
    ax[0, uidx].plot(input_sol['t'], input_sol['u'], **line_input_kwargs)
    ax[0, uidx].set_title(r'$u(t)$: %s' % ustr)

    for j, (ode_name, ode_title, ode_params) in enumerate(ode_list_to_sim):
        sol = ode_solutions_LoD[ode_name][ustr]
        ax[j + 1, uidx].plot(sol['t'], sol['y'], **line_response_kwargs)
        ax[j + 1, uidx].set_title(ode_title)

    ax[-1, uidx].set_xlabel(r'$t$')

fig.tight_layout()
plt.show()

## Plot the data (horizontal mode)

In [None]:
# Entries 1-4 of ode_list_to_sim (a slice, so a shorter list simply yields fewer columns)
ode_list_to_sim_subset = ode_list_to_sim[1:5]

# Rows: one per stimulus in ulist; columns: the input u(t), then each selected ODE's output y(t)
fig, ax = plt.subplots(len(ulist), 1 + len(ode_list_to_sim_subset), squeeze=False, figsize=(10, 3), sharex=True)

for uidx, ustr in enumerate(ulist):
    input_sol = ode_solutions_LoD[ode_list_to_sim_subset[0][0]][ustr]
    ax[uidx, 0].plot(input_sol['t'], input_sol['u'], **line_input_kwargs)

    # Pad the u(t) axis by 10% of the stimulus peak, e.g. (-0.1, 1.1) for a unit step.
    # Derived from the data so the limits follow each row even if ulist is reordered.
    u_peak = np.max(input_sol['u'])
    if u_peak > 0:
        ax[uidx, 0].set_ylim(-0.1 * u_peak, 1.1 * u_peak)

    for j, (ode_name, ode_title, ode_params) in enumerate(ode_list_to_sim_subset):
        sol = ode_solutions_LoD[ode_name][ustr]
        ax[uidx, j + 1].plot(sol['t'], sol['y'], **line_response_kwargs)
        if uidx == 0:  # titles only on the top row
            ax[uidx, j + 1].set_title(ode_title)

fig.tight_layout()
fpath = os.path.join(NB_OUTPUT, 'adapt_vs_hab')
fig.savefig(fpath + '.pdf')
fig.savefig(fpath + '.svg')
plt.show()

# Plot $u(t)$ and the states $x_1, \dots, x_N$ for one of the ODEs (the last state is drawn in the response color)

In [None]:
def plot_uxry_stack_for_ode(ode_name, ode_title, ode_sol, fpath, figsize=(12, 3)):
    """Plot the input u(t) and every state component x_j(t) of one ODE solution as stacked rows.

    Args:
        ode_name: preset name of the ODE (currently unused; kept for call compatibility).
        ode_title: title drawn above the top (input) panel.
        ode_sol: dict with keys 't', 'u' and 'x'. ``ode_sol['x']`` is indexed as
            ``x[:, j]`` against ``t`` (time along axis 0); a 1-D array is treated as a
            single state.
        fpath: file path the figure is saved to (format inferred from the extension).
        figsize: matplotlib figure size in inches.

    Returns:
        Array of Axes with shape (1 + nstates, 1).
    """
    t = ode_sol['t']
    states = np.asarray(ode_sol['x'])
    if states.ndim == 1:
        states = states[:, np.newaxis]  # single state trajectory -> one column
    nstates = states.shape[1]

    fig, ax = plt.subplots(1 + nstates, 1, squeeze=False, figsize=figsize, sharex=True)

    ax[0, 0].plot(t, ode_sol['u'], **line_input_kwargs)
    ax[0, 0].set_ylabel(r'$u$')

    for j in range(nstates):
        # Last state drawn in the response colour, earlier ones in the memory colour.
        # NOTE(review): this treats the last state as the output — confirm per preset.
        style = line_response_kwargs if j == nstates - 1 else line_kwargs_memory
        ax[1 + j, 0].plot(t, states[:, j], **style)
        ax[1 + j, 0].set_ylabel(r'$x_{%d}$' % (j + 1))

    ax[-1, 0].set_xlabel(r'$t$')
    ax[0, 0].set_title(ode_title)

    fig.savefig(fpath)
    plt.show()
    return ax

In [None]:
# Base filename; the stimulus name is appended so each u(t) gets its own PDF
# (previously both stimulus cells wrote to the same file, the second overwriting the first)
fpath = NB_OUTPUT + os.sep + 'nb_adapt_vs_hab_traj.pdf'

u_name = ulist[0]
ode_name, ode_title, ode_params = ode_list_to_sim[0]
selected_ode_sol = ode_solutions_LoD[ode_name][u_name]

fpath_u = os.path.splitext(fpath)[0] + '_%s.pdf' % u_name
plot_uxry_stack_for_ode(ode_name, ode_title, selected_ode_sol, fpath_u);


In [None]:
u_name = ulist[1]
ode_name, ode_title, ode_params = ode_list_to_sim[0]
selected_ode_sol = ode_solutions_LoD[ode_name][u_name]

# Own, stimulus-specific output file so the previous cell's PDF is not overwritten
fpath_u = os.path.join(NB_OUTPUT, 'nb_adapt_vs_hab_traj_%s.pdf' % u_name)
plot_uxry_stack_for_ode(ode_name, ode_title, selected_ode_sol, fpath_u);