In [None]:
import copy
import os
import numpy as np
import sys
import types
from scipy import signal
from scipy.interpolate import CubicSpline
from scipy.stats import norm  # for u(t) as gaussians

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.gridspec import GridSpec

In [None]:
%matplotlib ipympl
#%matplotlib inline

In [None]:
plt.rcParams["font.family"] = "serif"
plt.rcParams["mathtext.fontset"] = "dejavuserif"

# Notebook setup (path trick) and local import

In [None]:
# os.path.abspath('') is the kernel's current working directory, so SRC_ROOT is
# its parent and PACKAGE_ROOT its grandparent.
SRC_ROOT = os.path.dirname(os.path.abspath(''))
print('appending to path SRC_ROOT...', SRC_ROOT)
# Guard against duplicates: re-running this cell must not keep growing sys.path.
if SRC_ROOT not in sys.path:
    sys.path.append(SRC_ROOT)

PACKAGE_ROOT = os.path.dirname(SRC_ROOT)
print('appending to path PACKAGE_ROOT...', PACKAGE_ROOT)
if PACKAGE_ROOT not in sys.path:
    sys.path.append(PACKAGE_ROOT)

# All figures produced by this notebook are written below NB_OUTPUT.
NB_OUTPUT = SRC_ROOT + os.sep + 'output'

# exist_ok avoids the check-then-create race of `if not exists: makedirs`.
os.makedirs(NB_OUTPUT, exist_ok=True)

In [None]:
from src.class_ode_model import ODEModel
from src.class_ode_stimulus import ODEStimulus
from src.defined_stimulus import get_npulses_from_tspan, delta_fn_amp, stimulus_pulsewave, stimulus_pulsewave_of_pulsewaves, stimulus_constant, stimulus_rectangle
from src.preset_ode_model import PRESETS_ODE_MODEL, ode_model_preset_factory

from src.analyze_freq_amp_sensitivity import analyze_freq_amp_presets

# Plotting functions

In [None]:
# Shared colour palette: one colour per plotted quantity (input, memory, response).
color_input = '#00AEEF'
#color_memory = '#DEB87A'  
color_memory = '#C1A16B'  # dark: #DEB87A || darker: #D0AD73 ||| darkest: #C1A16B
color_response = '#662D91'

# Line width and style shared by every kwargs dict below.
linewidth = 1.0
linestyle = '-'

# Keyword bundles splatted into ax.plot(...) by the plotting helpers in this notebook.
line_kwargs_input = dict(color=color_input, linewidth=linewidth, linestyle=linestyle)
line_kwargs_memory = dict(color=color_memory, linewidth=linewidth, linestyle=linestyle)
#line_kwargs_memory_sigma = dict(color=color_memory, linewidth=linewidth, linestyle='dotted')
# Currently identical to line_kwargs_memory (the dotted variant above is disabled).
line_kwargs_memory_sigma = dict(color=color_memory, linewidth=linewidth, linestyle=linestyle)
line_kwargs_response = dict(color=color_response, linewidth=linewidth, linestyle=linestyle)

In [None]:
def plot_uy_stack(t, u, y, title, fpath, u_shadow=True, simplify=False, figsize=(3, 2), xlims=None):
    """
    Draw a two-row figure (input on top, response below), save it and show it.

    The current pyplot figure is closed first. Line styling comes from the
    module-level `line_kwargs_input` / `line_kwargs_response` dicts.

    Args:
        t: sample times shared by `u` and `y`
        u: input signal sampled on `t`
        y: response signal sampled on `t`
        title: figure suptitle
        fpath: output path without extension; '.pdf' and '.svg' files are written
        u_shadow: if True, overlay a faint black copy of `u` on the response panel
        simplify: if True, hide ticks, tick labels and the top/right spines and use
            plain 'input' / 'response' labels; otherwise draw grid and legends
        figsize: figure size passed to plt.figure
        xlims: optional x-limits applied to both panels
    Returns:
        list [input axes, response axes]
    """
    grid_alpha = 0.6
    label_font = {'family': 'arial', 'size': 12}

    plt.close()
    fig = plt.figure(figsize=figsize)
    fig.suptitle(title)

    # the response panel is twice as tall as the input panel
    layout = GridSpec(2, 1, height_ratios=[1, 2], hspace=0.2)
    ax_input = fig.add_subplot(layout[0])
    ax_response = fig.add_subplot(layout[1], sharex=ax_input)
    axarr = [ax_input, ax_response]

    ax_input.plot(t, u, label=r'$u(t)$', **line_kwargs_input)
    ax_response.plot(t, y, label=r'$y(t)=u(t) r(t)$', **line_kwargs_response)

    for row, panel in enumerate(axarr):
        # only panels below the input panel get the faint copy of u
        if u_shadow and row > 0:
            panel.plot(t, u, '-k', alpha=0.1)
        if xlims is not None:
            panel.set_xlim(xlims)

    if simplify:
        for panel, panel_label in zip(axarr, ['input', 'response']):
            panel.set_ylabel(panel_label, **label_font)
            panel.tick_params(labelbottom=False, labelleft=False)
            panel.spines[['right', 'top']].set_visible(False)
            panel.tick_params(which='both', size=0, labelsize=0)
        ax_response.set_xlabel('t', loc='right', style='italic', family='arial', size=12)
    else:
        ax_input.set_ylabel(r'input $u(t)$')
        ax_response.set_ylabel(r'output $y(t)$')
        for panel in axarr:
            panel.grid(alpha=grid_alpha)
            panel.legend()
        ax_response.set_xlabel(r'$t$', loc='right', **label_font)

    print(fpath)
    for extension in ('.pdf', '.svg'):
        fig.savefig(fpath + extension)
    plt.show()

    return axarr

In [None]:
def plot_uxry_stack(t, u, rate_decay, rate_grow, title, fpath):
    """
    Plot the four-row stack u(t), x(t), r(t), y(t) for the Hill (N=2) memory filter.

    x(t) comes from W_of_t_convolve, r(t) from filter_pointwise_expmemory_hill
    and y(t) = u(t) * r(t). The current pyplot figure is closed first; the new
    figure is saved to `fpath` and shown.

    Args:
        t: sample times
        u: input signal sampled on `t`
        rate_decay: decay rate forwarded to both helper functions
        rate_grow: growth rate forwarded to both helper functions
        title: title placed above the top panel
        fpath: full output path handed to savefig
    Returns:
        the (4, 1) array of axes created by plt.subplots
    """
    grid_alpha = 0.6

    memory = W_of_t_convolve(t, u, rate_decay, rate_grow)
    gain = filter_pointwise_expmemory_hill(t, u, rate_decay, rate_grow, N=2, plot=False)
    response = u * gain

    plt.close()
    fig, axarr = plt.subplots(4, 1, sharex=True, squeeze=False, figsize=(7, 5))
    ax_u, ax_x, ax_r, ax_y = axarr[:, 0]
    panels = (ax_u, ax_x, ax_r, ax_y)

    ax_u.set_title(title)
    panel_labels = (r'input $u(t)$', r'memory $x(t)$', r'filter $r(t)$', r'output $y(t)$')
    for panel, panel_label in zip(panels, panel_labels):
        panel.set_ylabel(panel_label)
    ax_y.set_xlabel(r'$t$')

    # input
    ax_u.plot(t, u, label=r'$u(t)$')

    # memory, with a dashed reference line at x = 1
    ax_x.plot(t, memory, label=r"$x(t)=\int_{0}^t \,\beta e^{-\alpha (t-t')} u(t') dt'$")
    ax_x.axhline(1, linestyle='--', alpha=0.5)
    ax_x.set_ylim(-0.05, np.max(memory)*1.05)

    # filter gain, with a dashed reference line at r = 0.5
    ax_r.plot(t, gain, label=r'$r(t)=\sigma(x(t))$')
    ax_r.axhline(0.5, linestyle='--', alpha=0.5)
    ax_r.set_ylim(-0.05, 1.05)

    # response, with a dashed reference line at half of the input maximum
    ax_y.plot(t, response, label=r'$y(t)=u(t) r(t)$')
    ax_y.axhline(0.5 * np.max(u), linestyle='--', alpha=0.5)

    for row, panel in enumerate(panels):
        if row > 0:
            # faint copy of the input on every panel below the first
            panel.plot(t, u, '-k', alpha=0.1)
        panel.grid(alpha=grid_alpha)
        panel.legend()

    fig.savefig(fpath)
    plt.show()
    return axarr

In [None]:
#%matplotlib widget
%matplotlib ipympl

## 0. Construct some signals u(t)

In [None]:
def signal_rectangle(n_pulse, n_rest, n_cycles, times=None, period_S1=1.0, duty=0.01, pulse_area=1.0):
    """
    Build a train of rectangular pulses.

    Each cycle holds `n_pulse` pulses spaced `period_S1` apart followed by
    `n_rest` empty periods; the cycle is repeated `n_cycles` times. A pulse has
    width duty * period_S1 and ENDS at its reference time period_S1 * (n + L_cycle * k).

    Args:
        n_pulse: pulses per cycle
        n_rest: empty periods after the pulses of each cycle
        n_cycles: number of cycles
        times: sample times; if None a grid from -2 to the end of the last cycle
            with step 0.001 is generated
        period_S1: spacing between consecutive pulses
        duty: pulse width as a fraction of period_S1
        pulse_area: area of each pulse
    Returns:
        (times, u_rect)
    """

    def ubase_rect(t, t_mid, duty, period_S1):
        # unit-area rectangle: height is the reciprocal of the width
        scale_for_unit_area = 1 / (duty * period_S1)
        rising_edge = np.heaviside(t - (t_mid - duty * period_S1), 0)
        falling_edge = np.heaviside(t - t_mid, 0)
        return scale_for_unit_area * (rising_edge - falling_edge)

    cycle_length = n_pulse + n_rest

    if times is None:
        dt = 0.001
        tmax = cycle_length * n_cycles * period_S1
        times = np.arange(-2, tmax + dt, dt)
    else:
        dt = times[1] - times[0]

    pulse_times = []
    for k in range(n_cycles):
        for n in range(1, n_pulse + 1):
            pulse_times.append(period_S1*(n + cycle_length * k))

    # each pulse must span more than 5 samples so it is resolved on the grid
    assert duty * period_S1 > 5 * dt
    pulses = [ubase_rect(times, t_mid, duty, period_S1) for t_mid in pulse_times]
    u_rect = pulse_area * np.sum(pulses, axis=0)

    return times, u_rect


def signal_gaussian(n_pulse, n_rest, n_cycles, period_S1=1.0, times=None, sigma=0.05, pulse_area=1.0):
    """
    Build a train of Gaussian pulses.

    Each cycle holds `n_pulse` pulses spaced `period_S1` apart followed by
    `n_rest` empty periods; the cycle is repeated `n_cycles` times. Pulses are
    centred at period_S1 * (n + L_cycle * k) for n = 1..n_pulse, k = 0..n_cycles-1.

    Args:
        n_pulse: pulses per cycle
        n_rest: empty periods after the pulses of each cycle
        n_cycles: number of cycles
        period_S1: spacing between consecutive pulses
        times: sample times; if None a grid from -2 to the end of the last cycle
            with step 0.001 is generated
        sigma: standard deviation of each Gaussian
        pulse_area: area of each pulse (a normalised pdf scaled by this factor)
    Returns:
        (times, u_gaussian)
    """

    def ubase_gaussian(t, t_mid, sigma):
        # normalised Gaussian pdf centred at t_mid
        return norm.pdf(t, t_mid, sigma)

    L_cycle = n_pulse + n_rest

    if times is None:
        dt = 0.001
        tmax = L_cycle * n_cycles * period_S1
        times = np.arange(-2, tmax + dt, dt)
    # A caller-supplied grid is used as-is: the sample spacing is not needed here,
    # so grids with fewer than two samples are accepted.

    tmid_list = [period_S1*(n + L_cycle * k) for k in range(n_cycles) for n in range(1, n_pulse+1)]

    # build u: wave of gaussians
    u_gaussian = pulse_area * np.sum([ubase_gaussian(times, t_mid, sigma) for t_mid in tmid_list], axis=0)

    return times, u_gaussian

In [None]:
def W_of_t_manual(t, u, rate_decay, rate_grow):
    """
    Exponentially weighted running integral of u, evaluated by explicit quadrature.

    Slow reference alternative to W_of_t_convolve: for every sample time a
    trapezoid integral of u(t') * exp(-rate_decay * (t_i - t')) is computed,
    then scaled by rate_grow.

    Args:
        t: uniformly spaced sample times (the spacing is taken from t[1] - t[0])
        u: input signal sampled on `t`
        rate_decay: decay rate of the exponential weight
        rate_grow: multiplicative factor applied to the integral
    Returns:
        array with the same shape as `t`
    """
    # Take the sample spacing from the grid itself rather than from a notebook global.
    dt = t[1] - t[0] if len(t) > 1 else 0.0
    # np.trapz was renamed np.trapezoid in NumPy 2.0; use whichever is available.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz

    def do_integral(t_specific, t_idx):
        # NOTE(review): the slice stops one sample BEFORE t_specific, so the
        # integrand at t_specific itself is not included - confirm this is intended.
        ur = u[:t_idx]
        tr = t[:t_idx]
        integrand = ur * np.exp(rate_decay * (tr - t_specific))
        return trapezoid(integrand, dx=dt)

    W_of_t = np.zeros_like(t)
    print('Working on W_of_t_manual integrals...')
    for idx, tau in enumerate(t):
        W_of_t[idx] = do_integral(tau, idx)

    W_of_t = rate_grow * W_of_t

    return W_of_t


def W_of_t_convolve(t, u, rate_decay, rate_grow, nmult=5):
    """
    Exponentially weighted running sum of u, computed with a single convolution.

    Fast alternative to solving the ODE or integrating directly: `u` is
    convolved with exp(-rate_decay * tau) sampled on the grid spacing of `t`,
    scaled by dt * rate_grow and truncated to len(t) samples.

    Args:
        t: uniformly spaced sample times (spacing taken from t[1] - t[0])
        u: input signal sampled on `t`
        rate_decay: decay rate of the exponential kernel
        rate_grow: multiplicative factor applied to the result
        nmult: the kernel is sampled on roughly nmult * len(t) points
    Returns:
        array of len(t) samples
    """
    n_samples = len(t)
    step = t[1] - t[0]

    # exponential kernel, sampled on a window about nmult times as long as the signal
    kernel_times = np.arange(0, (nmult * n_samples) * step + step, step)
    kernel = np.exp(-rate_decay * kernel_times)
    assert len(kernel_times) > n_samples

    scaled = step * rate_grow * signal.convolve(u, kernel, 'full')
    # keep only the leading samples, i.e. those aligned with u
    return scaled[0:n_samples]

def filter_pointwise_expsimple(t, u, rate_decay):
    """Return the plain exponential gain exp(-rate_decay * t); `u` is accepted but not used."""
    exponent = -rate_decay * t
    return np.exp(exponent)


def filter_pointwise_expmemory_tanh(t, u, rate_decay, rate_grow, plot=False):
    """
    Gain r(t) = 1 - tanh(x(t)), where x(t) is the memory from W_of_t_convolve.

    Prints the lengths of `u` and `x`; with plot=True also draws u, x, r and
    y = u * r on a new figure.

    Args:
        t: sample times
        u: input signal sampled on `t`
        rate_decay: decay rate forwarded to W_of_t_convolve
        rate_grow: growth rate forwarded to W_of_t_convolve
        plot: if True, draw a diagnostic figure
    Returns:
        r(t) sampled on `t`
    """
    memory = W_of_t_convolve(t, u, rate_decay, rate_grow)

    print('u', len(u))
    print('x_of_t', len(memory))

    gain = 1 - np.tanh(memory)
    response = u * gain

    if not plot:
        return gain

    plt.figure()
    plt.plot(t, u, linewidth=0.25, label=r'$u(t)$')
    plt.plot(t, memory, label=r'$x(t)$: weighted sum')
    plt.plot(t, gain, label=r'$r(t) = 1 - \mathrm{tanh}(x)$')
    plt.plot(t, response, label=r'$y(t) = u(t) (1 - \mathrm{tanh}(x))$')
    plt.grid(alpha=0.5)
    plt.title('plot inside filter_pointwise_expmemory')
    plt.xlabel('t')
    plt.legend()

    return gain


def filter_pointwise_expmemory_hill(t, u, rate_decay, rate_grow, N=2, plot=False):
    """
    Hill-type gain r(t) = 1 / (1 + x(t)^N), where x(t) is the memory from W_of_t_convolve.

    Prints several array lengths as a sanity check; with plot=True also draws
    u, x, r and y = u * r on a new figure.

    Args:
        t: sample times
        u: input signal sampled on `t`
        rate_decay: decay rate forwarded to W_of_t_convolve
        rate_grow: growth rate forwarded to W_of_t_convolve
        N: exponent applied to x(t)
        plot: if True, draw a diagnostic figure
    Returns:
        r(t) sampled on `t`
    """
    memory = W_of_t_convolve(t, u, rate_decay, rate_grow)

    print('u', len(u))
    print('W_of_t', len(memory))

    gain = 1 / (1 + memory ** N)

    # length check of everything that gets multiplied or plotted together
    for series in (t, u, memory, gain):
        print(len(series))

    response = u * gain

    if not plot:
        return gain

    plt.figure()
    plt.plot(t, u, linewidth=0.25, label=r'$u(t)$')
    plt.plot(t, memory, label=r'$x(t)$: weighted sum')
    plt.plot(t, gain, label=r'$r(t) = 1 / (1 + x^N)$')
    plt.plot(t, response, label=r'$y(t) = u(t) / (1 + x^N)$')
    plt.grid(alpha=0.5)
    plt.title('plot inside filter_pointwise_expmemory')
    plt.xlabel('t')
    plt.legend()

    return gain

In [None]:
# Pulse-train layout: n_pulse pulses spaced period_S1 apart, then n_rest empty
# periods, repeated n_cycles times.
period_S1 = 1.0 # period between consecutive pulses
n_pulse = 10
n_rest = 1
L_cycle = n_pulse + n_rest
n_cycles = 1  # 3
# reference time of every pulse (same formula as inside signal_rectangle / signal_gaussian)
tmid_list = [period_S1*(n + L_cycle * k) for k in range(n_cycles) for n in range(1, n_pulse+1)]

# prep t to sample signals from
dt = 0.001
#tmax = L_cycle * period_S1 * 4
tmax = L_cycle * n_cycles * period_S1
t = np.arange(-0.01, tmax + dt, dt)

# choose pulse_area
pulse_area = 1.0

# build u: heaviside step
u_step = np.heaviside(t, 0)

# build u: heaviside staircase (one unit step added at every pulse time)
u_stairs = np.sum([np.heaviside(t - t_mid, 0) for t_mid in tmid_list], axis=0)   

# build u: wave of gaussians
_, u_gaussian = signal_gaussian(n_pulse, n_rest, n_cycles, period_S1=period_S1, times=t, sigma=0.05, pulse_area=pulse_area)
_, u_gaussian_wide = signal_gaussian(n_pulse, n_rest, n_cycles, period_S1=period_S1, times=t, sigma=0.1, pulse_area=pulse_area)

# build u: wave of rectangles
_, u_rect = signal_rectangle(n_pulse, n_rest, n_cycles, period_S1=period_S1, times=t, duty=0.01, pulse_area=pulse_area)    
_, u_rect_wide = signal_rectangle(n_pulse, n_rest, n_cycles, period_S1=period_S1, times=t, duty=0.25, pulse_area=pulse_area)

# Overview figure; only the wide gaussian / rectangle variants are drawn.
plt.figure()
ax = plt.gca()
ax.plot(t, u_step, label=r'$u(t)$ step')
ax.plot(t, u_stairs, label=r'$u(t)$ staircase')
ax.plot(t, u_gaussian_wide, label=r'$u(t)$ gaussian pulses')
ax.plot(t, u_rect_wide, label=r'$u(t)$ rectangle pulses')
ax.axhline(0, linestyle='-', alpha=0.2, color='k')
ax.grid(alpha=0.5)
ax.set_xlabel(r'$t$')
ax.set_ylabel(r'signal')
ax.legend()
#plt.xlim(0.9, 1.1)
plt.title(r'Example input signals $u(t)$')
plt.show()

## 2b. Stacked plots u, y

In [None]:
# Figure size passed as `figsize` to plot_uy_stack in the Hallmark 1 cells below.
figsize_main = (2.2, 1.6)

### H.1 - Hallmark 1 (Habituation)

In [None]:
# Hallmark 1 inputs: one cycle of 10 pulses followed by 1 rest period.
# This rebinds the notebook-level names t, u_gaussian*, u_rect* used by the cells below.
period_S1 = 1.0 # period between consecutive pulses
n_pulse = 10
n_rest = 1
L_cycle = n_pulse + n_rest
n_cycles = 1  # 3
# reference time of every pulse (not referenced again in this cell)
tmid_list = [period_S1*(n + L_cycle * k) for k in range(n_cycles) for n in range(1, n_pulse+1)]

# prep t to sample signals from
dt = 0.001
#tmax = L_cycle * period_S1 * 4
tmax = L_cycle * n_cycles * period_S1
t = np.arange(-0.01, tmax + dt, dt)

# choose pulse_area
pulse_area = 1.0

# build u: wave of gaussians
_, u_gaussian = signal_gaussian(n_pulse, n_rest, n_cycles, period_S1=period_S1, times=t, sigma=0.05, pulse_area=pulse_area)
_, u_gaussian_wide = signal_gaussian(n_pulse, n_rest, n_cycles, period_S1=period_S1, times=t, sigma=0.1, pulse_area=pulse_area)

# build u: wave of rectangles
_, u_rect = signal_rectangle(n_pulse, n_rest, n_cycles, period_S1=period_S1, times=t, duty=0.01, pulse_area=pulse_area)    
_, u_rect_wide = signal_rectangle(n_pulse, n_rest, n_cycles, period_S1=period_S1, times=t, duty=0.25, pulse_area=pulse_area)

In [None]:
# Baseline figure: the response is the input itself (y = u).
u = u_rect_wide
y = u 

fpath = NB_OUTPUT + os.sep + 'uy_stack_comb_identity'

_ = plot_uy_stack(t, u, y, '', fpath, simplify=True, figsize=figsize_main)

In [None]:
# Hand-built gain: an exponential decay in t sitting on a constant floor ymin.
# rate_decay and ymin are also read by the r_of_t expressions in the Hallmark 2 cells below.
rate_decay = 0.5
ymin = 0.2
r_of_t = (ymin + np.exp(- rate_decay * t))

# Responses for each input variant: pointwise product of input and gain.
y_rect = u_rect * r_of_t
y_rect_wide = u_rect_wide * r_of_t
y_gaussian = u_gaussian * r_of_t
y_gaussian_wide = u_gaussian_wide * r_of_t

In [None]:
# NOTE(review): the input panel shows the wide rectangles while the response panel
# shows y_gaussian (built from u_gaussian) - presumably deliberate for the schematic; confirm.
u_to_plot = u_rect_wide
y_to_plot = y_gaussian

#fmod1 = 'rectangle'
#fmod2 = 'shifted exponential'
# `title` is assigned but the call below passes '' as the figure title.
title = r'$H_1$'
fpath = NB_OUTPUT + os.sep + 'uy_stack_hallmark_1'

_ = plot_uy_stack(t, u_to_plot, y_to_plot, '', fpath, u_shadow=False, simplify=True, figsize=figsize_main)

In [None]:
# Same figure layout, but the gain now comes from the Hill memory filter
# (positional arguments: rate_decay=0.1, rate_grow=0.5).
u_to_plot = u_rect_wide
r_of_t = filter_pointwise_expmemory_hill(t, u_to_plot, 0.1, 0.5, N=2, plot=False)
y_to_plot = u_to_plot * r_of_t

#fmod1 = 'rectangle'
#fmod2 = 'shifted exponential'
# `title` is assigned but the call below passes '' as the figure title.
title = r'$H_1$'
fpath = NB_OUTPUT + os.sep + 'uy_stack_hallmark_1_linearfilter'

_ = plot_uy_stack(t, u_to_plot, y_to_plot, '', fpath, u_shadow=False, simplify=True, figsize=figsize_main)

### H.2 - Hallmark 2 (Recovery)

In [None]:
# Hallmark 2 inputs: two cycles, each with 6 pulses followed by 5 rest periods.
# This rebinds the notebook-level names t, u_gaussian*, u_rect* used by the cells below.
period_S1 = 1.0 # period between consecutive pulses
n_pulse = 6
n_rest = 5
L_cycle = n_pulse + n_rest
n_cycles = 2  # 3
# reference time of every pulse (not referenced again in this cell)
tmid_list = [period_S1*(n + L_cycle * k) for k in range(n_cycles) for n in range(1, n_pulse+1)]

# prep t to sample signals from
dt = 0.001
#tmax = L_cycle * period_S1 * 4
tmax = L_cycle * n_cycles * period_S1
t = np.arange(-0.01, tmax + dt, dt)

# choose pulse_area
pulse_area = 1.0

# build u: wave of gaussians
_, u_gaussian = signal_gaussian(n_pulse, n_rest, n_cycles, period_S1=period_S1, times=t, sigma=0.05, pulse_area=pulse_area)
_, u_gaussian_wide = signal_gaussian(n_pulse, n_rest, n_cycles, period_S1=period_S1, times=t, sigma=0.1, pulse_area=pulse_area)
# build u: wave of rectangles
_, u_rect = signal_rectangle(n_pulse, n_rest, n_cycles, period_S1=period_S1, times=t, duty=0.01, pulse_area=pulse_area)    
_, u_rect_wide = signal_rectangle(n_pulse, n_rest, n_cycles, period_S1=period_S1, times=t, duty=0.25, pulse_area=pulse_area)

In [None]:
# x-range ends two periods after the last pulse of the final cycle
xlims = (-0.1, ((n_cycles - 1) * L_cycle  + n_pulse + 2) * period_S1)

u_to_plot = u_rect_wide

# The piecewise gain below is written for exactly two cycles.
assert n_cycles == 2
t0 = 0.0
t1 = L_cycle * period_S1
# Decaying gain that restarts at t1: the first term is active for t0 < t < t1, the
# second for t > t1. Uses ymin and rate_decay defined in an earlier cell.
r_of_t = (ymin + np.exp(- rate_decay * (t-t0))) * np.heaviside(t - t0, 0) * np.heaviside(t1 - t, 0) + (ymin + np.exp(- rate_decay * (t-t1))) * np.heaviside(t - t1, 0)
# NOTE(review): the response is built from u_gaussian while the input panel shows u_rect_wide.
y_to_plot = u_gaussian * r_of_t

#fmod1 = 'rectangle'
#fmod2 = 'shifted exponential'
#title = r'$H_1$'
fpath = NB_OUTPUT + os.sep + 'uy_stack_hallmark_2_synthetic'

_ = plot_uy_stack(t, u_to_plot, y_to_plot, '', fpath, u_shadow=False, simplify=True, figsize=(2.2, 1.6), xlims=xlims)

In [None]:
# Hallmark 2 with the Hill memory filter (rate_decay=0.1, rate_grow=0.5) on the wide gaussians.
# `xlims` is the value set in the previous cell.
u_to_plot = u_gaussian_wide
r_of_t = filter_pointwise_expmemory_hill(t, u_to_plot, 0.1, 0.5, N=2, plot=False)
y_to_plot = u_to_plot * r_of_t

#fmod1 = 'rectangle'
#fmod2 = 'shifted exponential'
#title = r'$H_1$'
fpath = NB_OUTPUT + os.sep + 'uy_stack_hallmark_2_linearfilter'

_ = plot_uy_stack(t, u_to_plot, y_to_plot, '', fpath, u_shadow=False, simplify=True, figsize=(2.2, 1.6), xlims=xlims)

### H.2 - Hallmark 2 (Recovery) - one late pulse

In [None]:
# Inputs for the "one late pulse" variant: two cycles, each with 10 pulses and 10 rest periods.
# This rebinds the notebook-level names t, tmid_list, u_gaussian*, u_rect* used by the cells below.
period_S1 = 1.0 # period between consecutive pulses
n_pulse = 10
n_rest = 10
L_cycle = n_pulse + n_rest
n_cycles = 2  # 3
# reference time of every pulse; indexed by the next two cells to locate the cutoff
tmid_list = [period_S1*(n + L_cycle * k) for k in range(n_cycles) for n in range(1, n_pulse+1)]

# prep t to sample signals from
dt = 0.001
#tmax = L_cycle * period_S1 * 4
tmax = L_cycle * n_cycles * period_S1
t = np.arange(-0.01, tmax + dt, dt)

# choose pulse_area
pulse_area = 1.0

# build u: wave of gaussians
_, u_gaussian = signal_gaussian(n_pulse, n_rest, n_cycles, period_S1=period_S1, times=t, sigma=0.05, pulse_area=pulse_area)
_, u_gaussian_wide = signal_gaussian(n_pulse, n_rest, n_cycles, period_S1=period_S1, times=t, sigma=0.1, pulse_area=pulse_area)
# build u: wave of rectangles
_, u_rect = signal_rectangle(n_pulse, n_rest, n_cycles, period_S1=period_S1, times=t, duty=0.01, pulse_area=pulse_area)    
_, u_rect_wide = signal_rectangle(n_pulse, n_rest, n_cycles, period_S1=period_S1, times=t, duty=0.25, pulse_area=pulse_area)

In [None]:
# x-range ends six periods into the final cycle
xlims = (-1, ((n_cycles - 1) * L_cycle  + 6) * period_S1)

u_to_plot = np.copy(u_rect_wide)
u_to_use_for_y = np.copy(u_gaussian_wide)

# tmid_list[n_pulse] is the first pulse of the second cycle; everything from half a
# period after it onwards is zeroed, so only that single late pulse remains.
tcutoff = tmid_list[n_pulse] + 0.5 * period_S1
u_to_plot = np.where(t < tcutoff, u_to_plot, 0)
u_to_use_for_y = np.where(t < tcutoff, u_to_use_for_y, 0)

# The piecewise gain below is written for exactly two cycles.
assert n_cycles == 2
t0 = 0.0
t1 = L_cycle * period_S1
# Decaying gain that restarts at t1; uses ymin and rate_decay defined in an earlier cell.
r_of_t = (ymin + np.exp(- rate_decay * (t-t0))) * np.heaviside(t - t0, 0) * np.heaviside(t1 - t, 0) + (ymin + np.exp(- rate_decay * (t-t1))) * np.heaviside(t - t1, 0)
y_to_plot = u_to_use_for_y * r_of_t

#fmod1 = 'rectangle'
#fmod2 = 'shifted exponential'
#title = r'$H_1$'
fpath = NB_OUTPUT + os.sep + 'uy_stack_hallmark_2_waitpulse_synthetic'

_ = plot_uy_stack(t, u_to_plot, y_to_plot, '', fpath, u_shadow=False, simplify=True, figsize=(2.2, 1.6), xlims=xlims)

In [None]:
# Same signals and gain as the previous cell; only the x-range (first cycle plus
# two periods) and the output file name differ.
xlims = (-1, (n_pulse + 2) * period_S1)

u_to_plot = np.copy(u_rect_wide)
u_to_use_for_y = np.copy(u_gaussian_wide)

# zero everything from half a period after the first pulse of the second cycle
tcutoff = tmid_list[n_pulse] + 0.5 * period_S1
u_to_plot = np.where(t < tcutoff, u_to_plot, 0)
u_to_use_for_y = np.where(t < tcutoff, u_to_use_for_y, 0)

# The piecewise gain below is written for exactly two cycles.
assert n_cycles == 2
t0 = 0.0
t1 = L_cycle * period_S1
# Decaying gain that restarts at t1; uses ymin and rate_decay defined in an earlier cell.
r_of_t = (ymin + np.exp(- rate_decay * (t-t0))) * np.heaviside(t - t0, 0) * np.heaviside(t1 - t, 0) + (ymin + np.exp(- rate_decay * (t-t1))) * np.heaviside(t - t1, 0)
y_to_plot = u_to_use_for_y * r_of_t

#fmod1 = 'rectangle'
#fmod2 = 'shifted exponential'
#title = r'$H_1$'
fpath = NB_OUTPUT + os.sep + 'uy_stack_hallmark_1_alt_synthetic'

_ = plot_uy_stack(t, u_to_plot, y_to_plot, '', fpath, u_shadow=False, simplify=True, figsize=(2.2, 1.6), xlims=xlims)

### Plots for: (3) Freq Sens, (4) Amp Sens

In [None]:
# True: the A/B/C pulse trains below all use pulse_area = 1.0 and B/C rescale their
# duty by period_S1_A / period. False: B/C keep duty_A and scale the pulse area instead.
FLAG_PULSE_AREA_FIXED = True

In [None]:
def plot_tuytuy_stack(t1, u1, y1, t2, u2, y2, title, fpath, u_shadow=True, simplify=False, figsize=(3, 2), xlims=None):
    """
    Draw two input/response stacks side by side (column A and column B), save and show.

    The current pyplot figure is closed first. Line styling comes from the
    module-level `line_kwargs_input` / `line_kwargs_response` dicts.

    Args:
        t1, u1, y1: times, input and response for the left column
        t2, u2, y2: times, input and response for the right column
        title: figure suptitle
        fpath: output path without extension; '.pdf' and '.svg' files are written
        u_shadow: if True, overlay a faint black copy of the input on each response panel
        simplify: if True, hide ticks, tick labels and the top/right spines and use
            plain 'input' / 'response' labels; otherwise draw grid and legends
        figsize: figure size passed to plt.figure
        xlims: optional x-limits applied to every panel
    Returns:
        the axes list of the LAST column only ([ax2, ax3]); see the note at the return.
    """
    # plot params
    grid_alpha = 0.6

    plt.close(); 
    ###nrows = 2
    ###fig, axarr = plt.subplots(nrows, 1, sharex=True, squeeze=False, figsize=figsize)
    fig = plt.figure(figsize=figsize)
    fig.suptitle(title)
    
    # 2 x 2 grid; the response row is twice as tall as the input row
    gs = GridSpec(2, 2, height_ratios=[1, 2], hspace=0.2)
    
    # x is shared within a column, y is shared within a row
    ax0 = fig.add_subplot(gs[0, 0])
    ax1 = fig.add_subplot(gs[1, 0], sharex=ax0)
    ax2 = fig.add_subplot(gs[0, 1], sharey=ax0)
    ax3 = fig.add_subplot(gs[1, 1], sharex=ax2, sharey=ax1)
    
    axarr_A = [ax0, ax1]
    axarr_B = [ax2, ax3]
    
    titles = [r'$u A$', r'$u B$']
    plottuples = [(t1, u1, y1), (t2, u2, y2)]
    
    # one pass per column; `axarr` is rebound to that column's [input, response] axes
    for col_idx, axarr in enumerate([axarr_A, axarr_B]):
        
        t, u, y = plottuples[col_idx]
        
        axarr[0].set_title(titles[col_idx])
        axarr[0].plot(t, u, label=r'$u(t)$', **line_kwargs_input)
        axarr[1].plot(t, y, label=r'$y(t)=u(t) r(t)$', **line_kwargs_response)
        
        for idx, ax in enumerate(axarr):
            # only the response panel (idx 1) gets the faint copy of u
            if idx > 0 and u_shadow:
                ax.plot(t, u, '-k', alpha=0.1)
            if xlims is not None:
                ax.set_xlim(xlims)
        
        if simplify:
            ylabels = ['input', 'response']
            fprop = {'family': 'arial', 'size': 12}
            for idx, ax in enumerate(axarr):
                # y-labels only on the left column
                if col_idx == 0:
                    ax.set_ylabel(ylabels[idx], **fprop)
                ax.tick_params(labelbottom=False, labelleft=False)
                ax.spines[['right', 'top']].set_visible(False)
                ax.tick_params(which='both', size=0, labelsize=0)
            axarr[-1].set_xlabel('t', loc='right', style='italic', family='arial', size=12)
        else:
            fprop = {'family': 'arial', 'size': 12}
            if col_idx == 0:
                ax0.set_ylabel(r'input $u(t)$')
                ax1.set_ylabel(r'output $y(t)$')
            for idx, ax in enumerate(axarr):
                ax.grid(alpha=grid_alpha)
                ax.legend()
            axarr[-1].set_xlabel(r'$t$', loc='right', **fprop)
        
    print(fpath)
    plt.savefig(fpath + '.pdf')
    plt.savefig(fpath + '.svg')
    plt.show()
    
    # NOTE(review): `axarr` is the loop variable, so this returns only the right
    # column's axes (axarr_B); the left column's axes are not returned.
    return axarr

In [None]:
# Pulse-train layout shared by the A / B / C signals below: one cycle of 30 pulses
# followed by 10 rest periods.
n_pulse = 30
n_rest = 10
L_cycle = n_pulse + n_rest
n_cycles = 1  # 3

In [None]:
# Signal set A: reference period. duty_A and amp_A are also read by the B and C cells.
period_S1_A = 1.0 # period between consecutive pulses

# reference time of every pulse (not referenced again in this cell)
tmid_list = [period_S1_A * (n + L_cycle * k) for k in range(n_cycles) for n in range(1, n_pulse+1)]

# prep t to sample signals from
dt = 0.001

duty_A = 0.1
# reciprocal of the pulse width duty_A * period_S1_A
amp_A = 1/(period_S1_A * duty_A)

# choose pulse_area
if FLAG_PULSE_AREA_FIXED:
    pulse_area_A = 1.0  
else:
    pulse_area_A = duty_A * period_S1_A * amp_A

#tmax = L_cycle * period_S1 * 4
tmax_A = L_cycle * n_cycles * period_S1_A
t_A = np.arange(-0.01, tmax_A + dt, dt)

# build u: wave of gaussians
_, u_gaussian_A = signal_gaussian(n_pulse, n_rest, n_cycles, period_S1=period_S1_A, times=t_A, sigma=0.05, pulse_area=pulse_area_A)
_, u_gaussian_wide_A = signal_gaussian(n_pulse, n_rest, n_cycles, period_S1=period_S1_A, times=t_A, sigma=0.1, pulse_area=pulse_area_A)
# build u: wave of rectangles
_, u_rect_A = signal_rectangle(n_pulse, n_rest, n_cycles, period_S1=period_S1_A, times=t_A, duty=duty_A, pulse_area=pulse_area_A)    
#_, u_rect_wide_A = signal_rectangle(n_pulse, n_rest, n_cycles, period_S1=period_S1_A, times=t_A, duty=0.25, pulse_area=pulse_area_A)

In [None]:
# Signal set B: same layout as A with period 2.0. Requires duty_A, period_S1_A and
# amp_A from the signal-A cell.
period_S1_B = 2.0 # period between consecutive pulses

# reference time of every pulse (not referenced again in this cell)
tmid_list = [period_S1_B * (n + L_cycle * k) for k in range(n_cycles) for n in range(1, n_pulse+1)]

# prep t to sample signals from
dt = 0.001
# choose pulse_area
if FLAG_PULSE_AREA_FIXED:
    pulse_area_B = 1.0
    # shrink the duty so that duty_B * period_S1_B equals duty_A * period_S1_A
    duty_B = duty_A * period_S1_A / period_S1_B
else:
    duty_B = duty_A
    pulse_area_B = duty_B * period_S1_B * amp_A  # make this A d T for fixed Ad, T is w.e it is above
    
#tmax = L_cycle * period_S1 * 4
tmax_B = L_cycle * n_cycles * period_S1_B
t_B = np.arange(-0.01, tmax_B + dt, dt)

# build u: wave of gaussians
_, u_gaussian_B = signal_gaussian(n_pulse, n_rest, n_cycles, period_S1=period_S1_B, times=t_B, sigma=0.05, pulse_area=pulse_area_B)
_, u_gaussian_wide_B = signal_gaussian(n_pulse, n_rest, n_cycles, period_S1=period_S1_B, times=t_B, sigma=0.1, pulse_area=pulse_area_B)
# build u: wave of rectangles
_, u_rect_B = signal_rectangle(n_pulse, n_rest, n_cycles, period_S1=period_S1_B, times=t_B, duty=duty_B, pulse_area=pulse_area_B)    
#_, u_rect_wide_B = signal_rectangle(n_pulse, n_rest, n_cycles, period_S1=period_S1_B, times=t_B, duty=0.25, pulse_area=pulse_area_B)

In [None]:
# Signal set C: same layout as A with period 5.0. Requires duty_A, period_S1_A and
# amp_A from the signal-A cell.
period_S1_C = 5.0 # period between consecutive pulses

# reference time of every pulse (not referenced again in this cell)
tmid_list = [period_S1_C * (n + L_cycle * k) for k in range(n_cycles) for n in range(1, n_pulse+1)]

# prep t to sample signals from
dt = 0.001
# choose pulse_area
if FLAG_PULSE_AREA_FIXED:
    pulse_area_C = 1.0
    # shrink the duty so that duty_C * period_S1_C equals duty_A * period_S1_A
    duty_C = duty_A * period_S1_A / period_S1_C
else:
    duty_C = duty_A
    pulse_area_C = duty_C * period_S1_C * amp_A  # make this A d T for fixed Ad, T is w.e it is above
    
#tmax = L_cycle * period_S1 * 4
tmax_C = L_cycle * n_cycles * period_S1_C
t_C = np.arange(-0.01, tmax_C + dt, dt)

# build u: wave of gaussians
_, u_gaussian_C = signal_gaussian(n_pulse, n_rest, n_cycles, period_S1=period_S1_C, times=t_C, sigma=0.05, pulse_area=pulse_area_C)
_, u_gaussian_wide_C = signal_gaussian(n_pulse, n_rest, n_cycles, period_S1=period_S1_C, times=t_C, sigma=0.1, pulse_area=pulse_area_C)
# build u: wave of rectangles
_, u_rect_C = signal_rectangle(n_pulse, n_rest, n_cycles, period_S1=period_S1_C, times=t_C, duty=duty_C, pulse_area=pulse_area_C)    

In [None]:
# Visual check of the three rectangle trains, with time rescaled by each train's
# own period; only the first ~4 periods are shown.
fig, axarr = plt.subplots(3, 1, figsize=(4,4), sharex=True, squeeze=False)
axarr[0,0].plot(t_A/period_S1_A, u_rect_A)
axarr[1,0].plot(t_B/period_S1_B, u_rect_B)
axarr[2,0].plot(t_C/period_S1_C, u_rect_C)
axarr[2,0].set_xlim(0,4.1)
axarr[2,0].set_xlabel(r'$t/T$')
plt.tight_layout()

print(duty_A, duty_B, duty_C)

In [None]:
# Inputs for the frequency-sensitivity figure: wide gaussian trains at periods A, B, C.
u_to_plot_A = u_gaussian_wide_A
u_to_plot_B = u_gaussian_wide_B
u_to_plot_C = u_gaussian_wide_C

# Responses from the Hill memory filter (rate_decay=0.1, rate_grow=0.5, N=2).
# y_to_plot_C is not drawn here; it is used by the peak-comparison cell further down.
y_to_plot_A = u_to_plot_A * filter_pointwise_expmemory_hill(t_A, u_to_plot_A, 0.1, 0.5, N=2, plot=False)
y_to_plot_B = u_to_plot_B * filter_pointwise_expmemory_hill(t_B, u_to_plot_B, 0.1, 0.5, N=2, plot=False)
y_to_plot_C = u_to_plot_C * filter_pointwise_expmemory_hill(t_C, u_to_plot_C, 0.1, 0.5, N=2, plot=False)

xlims = None

#fmod1 = 'rectangle'
#fmod2 = 'shifted exponential'
#title = r'$H_1$'
fpath = NB_OUTPUT + os.sep + 'uyuy_stack_hallmark_H4-FreqSens_linearfilter'
# The annotated (simplify=False) version gets its own file name; reusing `fpath`
# would overwrite the simplified figure's .pdf/.svg written by the first call.
fpath_detailed = fpath + '_detailed'

_ = plot_tuytuy_stack(t_A, u_to_plot_A, y_to_plot_A, 
                      t_B, u_to_plot_B, y_to_plot_B, 
                      '', fpath, u_shadow=False, simplify=True, figsize=(3.2, 2.6), xlims=xlims)
_ = plot_tuytuy_stack(t_A, u_to_plot_A, y_to_plot_A, 
                      t_B, u_to_plot_B, y_to_plot_B, 
                      '', fpath_detailed, u_shadow=False, simplify=False, figsize=(3.2, 2.6), xlims=xlims)

In [None]:
# Quick look at response A against time in units of its period.
plt.close('all')
plt.plot(t_A / period_S1_A, y_to_plot_A); plt.xlabel('t/T')
plt.show()

In [None]:
# Quick look at response B against time in units of its period.
plt.close('all')
plt.plot(t_B/period_S1_B, y_to_plot_B); plt.xlabel('t/T')
plt.show()

In [None]:
# Check freq sens curves: compare the successive local maxima of each response.
from scipy.signal import find_peaks

# indices of the local maxima of each response
peaks_A, _ = find_peaks(y_to_plot_A)
peaks_B, _ = find_peaks(y_to_plot_B)
peaks_C, _ = find_peaks(y_to_plot_C)

# peak height against peak index, one curve per period
plt.close('all')
plt.plot(np.arange(len(peaks_A)), y_to_plot_A[peaks_A], '--ok', label=r'$T=%.2f$' % period_S1_A)
plt.plot(np.arange(len(peaks_B)), y_to_plot_B[peaks_B], '--or', label=r'$T=%.2f$' % period_S1_B)
plt.plot(np.arange(len(peaks_C)), y_to_plot_C[peaks_C], '--og', label=r'$T=%.2f$' % period_S1_C)
plt.title(r'Expect higher $f$ (lower $T$) implies faster TTH_discrete')
plt.legend()
plt.show()

### Check freq. sens.

In [None]:
def sigma_of_x_main(x, N=2):
    """Decreasing Hill-type function 1 / (1 + x^N)."""
    denominator = 1 + x ** N
    return 1 / denominator

def h_inverter_H5(u, N=2):
    """Input pre-processing map 2u / (1 + u^N)."""
    numerator = 2 * u
    return numerator / (1 + u ** N)

def local_relu(z, theta=0):
    """Shifted ReLU: z - theta where that difference is positive, 0 elsewhere."""
    shifted = z - theta
    is_positive = np.asarray(shifted) > 0
    return np.where(is_positive, shifted, 0)


# Current HW fns for Hallmark 4
# ===========================================
def h_of_u_H4(u):
    """Hallmark 4 input map: identity (the input is returned unchanged)."""
    return u
def g_of_x_u_H4(x, u, N=2):
    """Hallmark 4 output map: input u scaled by the gain sigma_of_x_main(x, N)."""
    gain = sigma_of_x_main(x, N=N)
    return u * gain


# Current HW fns for Hallmark 5
# ===========================================
def h_of_u_H5(u, N=2):
    """Hallmark 5 input map 2u / (1 + u^N)."""
    numerator = 2 * u
    return numerator / (1 + u ** N)
def g_of_x_u_H5(x, u, gamma=100, N=2):
    """Hallmark 5 output map: tanh(gamma * u) scaled by the gain sigma_of_x_main(x, N)."""
    squashed_input = np.tanh(gamma * u)
    return squashed_input * sigma_of_x_main(x, N=N)


# Current HW fns for Hallmark 3
# ===========================================
def h_of_u_H3(u):
    """Hallmark 3 input map: identity (the input is returned unchanged)."""
    return u
def g_of_x_u_H3(x, u, gamma=2, p_steep=1, p_threshold=5):
    """
    Hallmark 3 output map: u times a tanh switch that falls from 1 to 0 as x
    rises through p_threshold (steepness p_steep).

    `gamma` is accepted but not used by the current formula.
    """
    switch = np.tanh(p_steep * (p_threshold - x))
    output_factor = 0.5 * (1 + switch)
    return u * output_factor

# Current HW fns for Hallmark 6
# ===========================================
def h_of_u_H6(u):
    """Hallmark 6 input map: identity (the input is returned unchanged)."""
    return u
def g_of_x_u_H6(x, u, p_threshold=0):
    """Hallmark 6 output map: local_relu of (u - x) with threshold p_threshold."""
    difference = u - x
    return local_relu(difference, theta=p_threshold)


def series_form_of_unit(t, u, nchain, alphas, betas, 
                        fn_h_of_u, fn_g_of_x_u, 
                        params_h, params_g
                        ):
    """
    Run `u` through a chain of `nchain` identical units connected in series.

    For each unit: the input is mapped through fn_h_of_u, the memory x is
    obtained with W_of_t_convolve using that unit's (alpha, beta), and the
    output is fn_g_of_x_u(x, input). Each unit's output is the next unit's input.

    Args:
        t: sample times
        u: input signal to the first unit
        nchain: number of units in the chain
        alphas: per-unit decay rates (indexed 0..nchain-1)
        betas: per-unit growth rates (indexed 0..nchain-1)
        fn_h_of_u: callable h(u, **params_h)
        fn_g_of_x_u: callable g(x, u, **params_g)
        params_h: keyword arguments for fn_h_of_u
        params_g: keyword arguments for fn_g_of_x_u
    Returns:
        (arr_unit_output, arr_unit_x), each of shape (len(t), nchain)
    """
    n_times = len(t)
    arr_unit_x = np.zeros((n_times, nchain))
    arr_unit_output = np.zeros((n_times, nchain))

    stage_input = np.copy(u)

    for stage in range(nchain):
        alpha = alphas[stage]
        beta = betas[stage]

        preprocessed = fn_h_of_u(stage_input, **params_h)
        stage_x = W_of_t_convolve(t, preprocessed, alpha, beta)
        # g receives the unit's raw input, not the h-preprocessed one
        stage_output = fn_g_of_x_u(stage_x, stage_input, **params_g)

        arr_unit_x[:, stage] = stage_x
        arr_unit_output[:, stage] = stage_output

        # the output of this unit feeds the next one
        stage_input = np.copy(stage_output)

    return arr_unit_output, arr_unit_x


"""
def construct_HW_args(foo_h, foo_g, p_h, p_g):
    return foo_h, foo_g, p_h, p_g
"""


In [None]:
# Check freq sens curves
# (find_peaks is imported a second time here; an earlier cell already imports it.)
from scipy.signal import find_peaks

# Three shades of purple, darkest to lightest, used for the three curves in
# plot_nchain_response_filter_tuptupab.
color_1 = '#663290'           # regular y(t) purp OR 'darkblue'
color_2 = '#8965AC'           # medium purp       OR 'royalblue'
color_3 = '#B19CC9'           # light purp        OR 'cornflowerblue'


def plot_nchain_response_filter_tuptupab(t1, u1, label1, t2, u2, label2, t3, u3, label3, 
                                         alphas, betas, model_unit_tuple, suffix=''):
    """
    Run three inputs through the same chain of units and compare their peak responses.

    Each (t, u) pair is passed to series_form_of_unit with len(alphas) units; the
    response is the output of the last unit. Three figures are produced in turn:
    peak value against peak index (saved as .png and .svg under NB_OUTPUT), the
    three responses over time, and the three inputs over time.

    Args:
        t1, u1, label1: times, input and legend label of the first signal
        t2, u2, label2: same for the second signal
        t3, u3, label3: same for the third signal
        alphas: per-unit decay rates; its length sets the number of units
        betas: per-unit growth rates
        model_unit_tuple: (h_of_u, g_of_x_u, params_h, params_g), unpacked into
            series_form_of_unit
        suffix: appended to the saved file name 'fig_tuptupab'
    Returns:
        12-tuple: for each signal in turn (arr_unit_output, arr_unit_x, peak indices,
        peak values)
    """
    #h_of_u, g_of_x_u, params_h, params_g = model_tuple
    
    arr_unit_output_1, arr_unit_x_1 = series_form_of_unit(t1, u1, len(alphas), alphas, betas, *model_unit_tuple)
    arr_unit_output_2, arr_unit_x_2 = series_form_of_unit(t2, u2, len(alphas), alphas, betas, *model_unit_tuple)
    arr_unit_output_3, arr_unit_x_3 = series_form_of_unit(t3, u3, len(alphas), alphas, betas, *model_unit_tuple)
        
    # the response of the chain is the output of its last unit
    y_to_plot_1 = arr_unit_output_1[:, -1]
    y_to_plot_2 = arr_unit_output_2[:, -1]
    y_to_plot_3 = arr_unit_output_3[:, -1]
    
    # local maxima with height of at least 1e-8
    peaks_1, _ = find_peaks(y_to_plot_1, height=1e-8)
    peaks_2, _ = find_peaks(y_to_plot_2, height=1e-8)
    peaks_3, _ = find_peaks(y_to_plot_3, height=1e-8)
    
    # --- figure 1: peak value against peak index (this one is saved to disk) ---
    plt.close('all')
    #plt.figure(figsize=(6.5,4.5))
    #plt.figure(figsize=(4,3))
    fig = plt.figure(figsize=(2.8,2))
    print('TODO for amplitude sensitivity... how to rescale y-values to 1.0? Divide by initial peak?')
    
    ls_kwargs=dict(
        markersize=3,
        marker='o',
        linestyle='-',
        linewidth=1,
    )
    
    # signal 3 is drawn first and signal 1 last
    plt.plot(np.arange(len(peaks_3)), y_to_plot_3[peaks_3], c=color_3, label=label3, **ls_kwargs)
    plt.plot(np.arange(len(peaks_2)), y_to_plot_2[peaks_2], c=color_2, label=label2, **ls_kwargs)
    plt.plot(np.arange(len(peaks_1)), y_to_plot_1[peaks_1], c=color_1, label=label1, **ls_kwargs)
    
    #plt.suptitle(r'Expect higher $f$ (lower $T$) implies faster TTH_discrete')
    plt.title(r'nchain=%d, alphas=%s, betas=%s' % (len(alphas), alphas, betas), fontsize=8)
    plt.xlabel(r'period index $k$')
    plt.ylabel(r'peak values $y[k]$')
    plt.legend()
    plt.grid(alpha=0.4)
    #plt.tight_layout()
    
    fpath = NB_OUTPUT + os.sep + 'fig_tuptupab%s' % suffix
    plt.savefig(fpath + '.png')
    plt.savefig(fpath + '.svg')
    plt.show()
    
    # --- figure 2: the three responses over time (shown only, not saved) ---
    plt.close('all')
    _, ax = plt.subplots(3,1, figsize=(6,2), sharex=True)
    plt.suptitle('response')
    ax[0].plot(t1, y_to_plot_1, label=label1, c=color_1)
    ax[0].legend()
    ax[1].plot(t2, y_to_plot_2, label=label2, c=color_2)
    ax[1].legend()
    ax[2].plot(t3, y_to_plot_3, label=label3, c=color_3)
    ax[2].legend()
    #ax[0].set_xlim(0, 4*max(period_S1_A, period_S1_B))
    
    plt.show()
    
    # --- figure 3: the three inputs over time (shown only, not saved) ---
    plt.close('all')
    _, ax = plt.subplots(3,1, figsize=(6,2),sharex=True)
    plt.suptitle('input')
    ax[0].plot(t1, u1, label=label1, c=color_1)
    ax[0].legend()
    ax[1].plot(t2, u2, label=label2, c=color_2)
    ax[1].legend()
    ax[2].plot(t3, u3, label=label3, c=color_3)
    ax[2].legend()
    #ax[0].set_xlim(0, 4*max(period1, period2, period3))
    
    print('Habituation stats compare: (TODO)')
    #print('\t (A) TTH_discrete (eps=%.1e) = %d' % (tth_epsilon, tth_disctete_A))
    #print('\t (B) TTH_discrete (eps=%.1e) = %d' % (tth_epsilon, tth_disctete_B))
    
    plt.show()
    return arr_unit_output_1, arr_unit_x_1, peaks_1, y_to_plot_1[peaks_1], arr_unit_output_2, arr_unit_x_2, peaks_2, y_to_plot_2[peaks_2], arr_unit_output_3, arr_unit_x_3, peaks_3, y_to_plot_3[peaks_3]

In [None]:
# (h, g, params_h, params_g) bundle for the Hallmark 4 unit, in the order
# series_form_of_unit expects after alphas and betas.
model_unit_tuple_H4 = (h_of_u_H4, g_of_x_u_H4, dict(), dict(N=2))

In [None]:
# Three-unit chain: one decay rate per unit, the same growth rate (0.2) for every unit.
alphas = [0.5, 0.25, 0.1]
betas = [0.2 for i in alphas]
assert len(alphas) == len(betas)

suffix = '_K3_checkH4'

# Compare the rectangle trains at periods A, B, C through the Hallmark 4 unit.
outs_nchain3 = plot_nchain_response_filter_tuptupab(t_A, u_rect_A, r'$T=%.2f$' % period_S1_A, 
                                                    t_B, u_rect_B, r'$T=%.2f$' % period_S1_B,
                                                    t_C, u_rect_C, r'$T=%.2f$' % period_S1_C, 
                                                    alphas, betas, model_unit_tuple_H4, suffix=suffix)

# Unpack the 12-tuple: (unit outputs, unit memories, peak indices, peak values) per signal.
nMulti_arr_unit_output_1, nMulti_arr_unit_x_1, nMulti_peaks_1, nMulti_y_peaks_1, nMulti_arr_unit_output_2, nMulti_arr_unit_x_2, nMulti_peaks_2, nMulti_y_peaks_2, nMulti_arr_unit_output_3, nMulti_arr_unit_x_3, nMulti_peaks_3, nMulti_y_peaks_3 = outs_nchain3

In [None]:
# Single-unit chain for comparison with the three-unit run above.
alphas = [0.1]
betas = [0.2]
assert len(alphas) == len(betas)

suffix = '_K1_checkH4'

# Compare the rectangle trains at periods A, B, C through the Hallmark 4 unit.
outs_nchain1 = plot_nchain_response_filter_tuptupab(t_A, u_rect_A, r'$T=%.2f$' % period_S1_A, 
                                                    t_B, u_rect_B, r'$T=%.2f$' % period_S1_B,
                                                    t_C, u_rect_C, r'$T=%.2f$' % period_S1_C, 
                                                    alphas, betas, model_unit_tuple_H4, suffix=suffix)
# Unpack the 12-tuple: (unit outputs, unit memories, peak indices, peak values) per signal.
n1_arr_unit_output_1, n1_arr_unit_x_1, n1_peaks_1, n1_y_peaks_1, n1_arr_unit_output_2, n1_arr_unit_x_2, n1_peaks_2, n1_y_peaks_2, n1_arr_unit_output_3, n1_arr_unit_x_3, n1_peaks_3, n1_y_peaks_3 = outs_nchain1

# Hallmark 5 - amplitude sensitivity

Idea: Hammerstein version of the Wiener-Hammerstein generalization of LTI dynamics. 
- need to include a static nonlinear preprocessing step on the input signal that acts like an "inverter"
- $u'=h(u)$ where $h$ is some biologically implementable inverter 
- within some threshold of input $u$: high $u$ gets mapped to low $u'$ and vice versa

In [None]:
def sigma_inverter_digital(u_in, p_u_prime_max=1.0, p_u_in_max=10.0):
    """
    Piecewise-linear ("digital") inverter h(u).

    h(u) = p_u_prime_max - (p_u_prime_max / p_u_in_max) * u   for 0 <= u < p_u_in_max
    h(u) = 0                                                  otherwise

    so h falls linearly from p_u_prime_max at u=0 to zero at u=p_u_in_max, and
    negative inputs as well as inputs at or beyond p_u_in_max map to zero.

    Args:
        u_in: scalar or array-like of input amplitudes
        p_u_prime_max: output value at u=0 (default 1.0)
        p_u_in_max: input value at/after which the output is zero (default 10.0)

    Returns:
        numpy array with the shape of u_in (0-d array for a scalar input)
    """
    u_in = np.asarray(u_in)
    slope_mid_segment = p_u_prime_max / p_u_in_max

    # single mask for the linear segment; everything outside it is zero
    on_ramp = (u_in >= 0) & (u_in < p_u_in_max)
    return np.where(on_ramp, p_u_prime_max - slope_mid_segment * u_in, 0.0)

def sigma_inverter_A2(u_in):
    """Smooth inverter h_A(u) = 2u / (1 + u^2), i.e. the N=2 member of the h_A family."""
    return 2 * u_in / (1 + u_in ** 2)
    
def sigma_inverter_A3(u_in):
    """Smooth inverter h_A(u) = 2u / (1 + u^3), i.e. the N=3 member of the h_A family."""
    return 2 * u_in / (1 + u_in ** 3)

def sigma_inverter_B(u_in):
    """Monotone inverter h_B(u) = 1 / (1 + u^1)."""
    hill_exponent = 1
    return 1 / (1 + u_in ** hill_exponent)


In [None]:
# Sample each candidate inverter h(u). The smooth variants are evaluated on a
# strictly positive grid; the digital one on a grid that starts slightly below zero.
u_in = np.linspace(-0.1, 12.0, 1000)
u_in_pos = np.linspace(0.001, 12.0, 1000)

u_out_digital = sigma_inverter_digital(u_in)  # computed but not drawn in this figure
u_out_A = sigma_inverter_A2(u_in_pos)
u_out_B = sigma_inverter_A3(u_in_pos)
u_out_C = sigma_inverter_B(u_in_pos)          # computed but not drawn in this figure

plt.close('all')
fig_inverter, ax_inverter = plt.subplots(figsize=(1.4, 1.8))

# only the two h_A curves are shown (solid: N=2, dotted: N=3)
ax_inverter.plot(u_in_pos, u_out_A, label=r'$h_A(u)$ $(N=2)$', c=color_input, zorder=20)
ax_inverter.plot(u_in_pos, u_out_B, label=r'$h_A(u)$ $(N=3)$', c=color_input, linestyle=':', zorder=16)

ax_inverter.set_xlabel(r'$u$')
ax_inverter.set_ylabel(r'$h(u)$')

# reference lines: u = 10 (dashed) and h = 0 (solid)
ax_inverter.axvline(10, linestyle='--', color='k', alpha=0.4)
ax_inverter.axhline(0, linestyle='-', color='k', alpha=0.4)

fpath = NB_OUTPUT + os.sep + 'fig_h(u)_inverter'
for extension in ('.png', '.svg'):
    fig_inverter.savefig(fpath + extension)
plt.show()

### H5: first show that orig. system does not clearly satisfy H5 

In [None]:
# H5 check without the input inverter: a single unit (K=1) driven by the same
# pulse train at three amplitude multipliers (1x, 2x, 5x).
alphas, betas = [0.5], [0.2]
assert len(betas) == len(alphas)

period_select, t_select, u_select = period_S1_A, t_A, u_rect_A

suffix = '_K1_checkH5_noInverter'

outs_nchain1 = plot_nchain_response_filter_tuptupab(
    t_select, u_select, r'$A=%.2f$' % 1.0,
    t_select, 2.0 * u_select, r'$A=%.2f$' % 2.0,
    t_select, 5.0 * u_select, r'$A=%.2f$' % 5.0,
    alphas, betas, model_unit_tuple_H4, suffix=suffix,
)

# The 12-tuple holds, per stimulus: unit outputs, unit memory x, peak indices, peak values.
(
    nMulti_arr_unit_output_1, nMulti_arr_unit_x_1, nMulti_peaks_1, nMulti_y_peaks_1,
    nMulti_arr_unit_output_2, nMulti_arr_unit_x_2, nMulti_peaks_2, nMulti_y_peaks_2,
    nMulti_arr_unit_output_3, nMulti_arr_unit_x_3, nMulti_peaks_3, nMulti_y_peaks_3,
) = outs_nchain1

In [None]:
# Quick look at the raw stimulus samples (plotted against sample index, not time).
plt.close('all')
fig_stimulus, ax_stimulus = plt.subplots()
ax_stimulus.plot(u_rect_A)
plt.show()

### H5: (`use_inverter=True`) now show that adding the input nonlinear mod makes it satisfy H5 

In [None]:
# Unit definition for the Hallmark 5 check with the input inverter enabled.
# NOTE(review): the two dicts look like keyword arguments for h and g respectively -- confirm.
model_unit_tuple_H5 = (h_of_u_H5, g_of_x_u_H5, {'N': 2}, {'N': 2})

In [None]:
# H5 check with the input inverter: a single unit (K=1) driven by the same
# pulse train at three amplitude multipliers c1 < c2 < c3.
alphas, betas = [0.2], [4]  # beta was 0.2 in the no-inverter check
assert len(betas) == len(alphas)

period_select, t_select, u_select = period_S1_A, t_A, u_rect_A

suffix = '_K1_checkH5_useInverter'

c1, c2, c3 = 0.1, 0.5, 1.0  # mults for amplitude

outs_nchain1 = plot_nchain_response_filter_tuptupab(
    t_select, c1 * u_select, r'$A=%.2f$' % c1,
    t_select, c2 * u_select, r'$A=%.2f$' % c2,
    t_select, c3 * u_select, r'$A=%.2f$' % c3,
    alphas, betas, model_unit_tuple_H5, suffix=suffix,
)

# The 12-tuple holds, per stimulus: unit outputs, unit memory x, peak indices, peak values.
(
    nMulti_arr_unit_output_1, nMulti_arr_unit_x_1, nMulti_peaks_1, nMulti_y_peaks_1,
    nMulti_arr_unit_output_2, nMulti_arr_unit_x_2, nMulti_peaks_2, nMulti_y_peaks_2,
    nMulti_arr_unit_output_3, nMulti_arr_unit_x_3, nMulti_peaks_3, nMulti_y_peaks_3,
) = outs_nchain1

# Overlay of the three runs: outputs on the top axis, memory x on the bottom axis.
plt.close('all')
fig, axarr = plt.subplots(2, 1, figsize=(6, 2))
overlay_runs = (
    (nMulti_arr_unit_output_1, nMulti_arr_unit_x_1, color_1, 3),
    (nMulti_arr_unit_output_2, nMulti_arr_unit_x_2, color_2, 2),
    (nMulti_arr_unit_output_3, nMulti_arr_unit_x_3, color_3, 1),
)
for overlay_output, overlay_x, overlay_color, overlay_zorder in overlay_runs:
    axarr[0].plot(overlay_output, c=overlay_color, zorder=overlay_zorder)
    axarr[1].plot(overlay_x, c=overlay_color)

## Hallmark 4: Heuristic plot for frequency sensitivity including 'recovery'

In [None]:
# Numeric reference for the heuristic plot: single unit (K=1), stimuli A, B, C.
alphas, betas = [0.1], [0.2]
assert len(betas) == len(alphas)

suffix = '_K1_checkH4'

outs_freq_sens = plot_nchain_response_filter_tuptupab(
    t_A, u_rect_A, r'$T=%.2f$' % period_S1_A,
    t_B, u_rect_B, r'$T=%.2f$' % period_S1_B,
    t_C, u_rect_C, r'$T=%.2f$' % period_S1_C,
    alphas, betas, model_unit_tuple_H4, suffix=suffix,
)

# The 12-tuple holds, per stimulus: unit outputs, unit memory x, peak indices, peak values.
(
    arr_output_1, arr_x_1, peaks_1, y_peaks_1,
    arr_output_2, arr_x_2, peaks_2, y_peaks_2,
    arr_output_3, arr_x_3, peaks_3, y_peaks_3,
) = outs_freq_sens

In [None]:
# Numerically integrated memory variable x(t) for the three stimuli.
plt.close('all')
fig, ax_x_numeric = plt.subplots(figsize=(5.6, 2))
for t_numeric, x_numeric, c_numeric in ((t_A, arr_x_1, color_1),
                                        (t_B, arr_x_2, color_2),
                                        (t_C, arr_x_3, color_3)):
    ax_x_numeric.plot(t_numeric, x_numeric, c=c_numeric)
ax_x_numeric.set_title('x values (numeric)')
plt.show()

In [None]:
# Module-level model and stimulus parameters used by the heuristic curves below.
alpha, beta = 0.1, 0.2

#dirac_area = 1.0  # numeric plot uses rectangle pulses with duty d=0.1 and amplitude A=10
amplitude = 10    # numeric plot uses rectangle pulses with duty d=0.1 and amplitude A=10
duty = 0.1


def get_limitcycle_xhigh_xlow(alpha, beta, period_S1, pulse_area):
    """
    Heuristic limit-cycle extremes of x for a periodic train of delta pulses.

    Each pulse adds beta * pulse_area to x, and between pulses x decays by
    exp(-alpha * period_S1), so on the limit cycle
        xhigh = beta * pulse_area / (1 - exp(-alpha * period_S1))
        xlow  = xhigh * exp(-alpha * period_S1)

    Args:
        alpha: decay rate of x
        beta: growth rate of x
        period_S1: pulse period T
        pulse_area: area of each delta pulse (previously accepted but ignored,
            which was equivalent to pulse_area = 1)

    Returns:
        (xhigh, xlow)
    """
    decay_factor = np.exp(- alpha * period_S1)
    xhigh = beta * pulse_area / (1 - decay_factor)
    xlow = xhigh * decay_factor
    return xhigh, xlow


def get_traj_xhigh_xlow(alpha, beta, period_S1, duty, amplitude, nperiods):
    """
    Heuristic per-period extremes of x for a train of rectangle pulses, starting from x = 0.

    Args:
        alpha: decay rate of x
        beta: growth rate of x
        period_S1: pulse period T
        duty: fraction of each period during which the pulse is on
        amplitude: pulse height
        nperiods: number of periods to tabulate

    Returns:
        (arr_n, arr_xhigh, arr_xlow), each of length nperiods, where
          arr_n     = 0, 1, ..., nperiods - 1
          arr_xhigh = Q * (1 - q^(n+1)) / (1 - q)
          arr_xlow  = previous period's xhigh decayed by q^(1 - duty); arr_xlow[0] = 0
        with q = exp(-alpha * T) and Q = beta * amplitude * (1 - q^duty) / alpha
        (written below via pulse_area so the algebra mirrors the derivation).
    """
    pulse_area = amplitude * period_S1 * duty

    q = np.exp(- alpha * period_S1)
    Q = beta * pulse_area * (1 - q ** duty) / (alpha * period_S1 * duty)

    arr_n = np.arange(0, nperiods)

    # partial geometric sum of the per-pulse gain Q
    arr_xhigh = Q * (1 - q ** (arr_n + 1)) / (1 - q)
    arr_xlow = np.zeros_like(arr_xhigh)
    arr_xlow[1:] = arr_xhigh[:-1] * q ** (1 - duty)

    return arr_n, arr_xhigh, arr_xlow


def sigma_of_x(x, N=2):
    """Decreasing Hill function 1 / (1 + x^N)."""
    return 1 / (1 + x ** N)


def get_curves_recovery_using_x_decay(x0, period_S1, amplitude, nperiods,
                                      rate_decay=None, pulse_duty=None):
    """
    Heuristic recovery curves: let x decay freely from x0, then probe with one pulse.

    Entry i of the returned arrays uses k = i + 1 in
        x = x0 * exp(-rate_decay * period_S1) ** (k - pulse_duty)
        y = amplitude * sigma_of_x(x)
    and recovery_arr_k[i] = i is the matching count of "pulses skipped before
    applying another pulse".

    Args:
        x0: value of x at the start of the decay
        period_S1: pulse period T
        amplitude: height of the probe pulse
        nperiods: number of k values to tabulate
        rate_decay: decay rate of x; None (default) falls back to the module-level `alpha`
        pulse_duty: duty of the pulse train; None (default) falls back to the module-level `duty`

    Returns:
        (recovery_arr_k, recovery_arr_x_of_k, recovery_arr_y_of_k), each of length nperiods
    """
    # The module-level fallbacks keep the original four-argument call working unchanged.
    decay_rate = alpha if rate_decay is None else rate_decay
    duty_used = duty if pulse_duty is None else pulse_duty

    recovery_arr_k = np.arange(0, nperiods)
    decay_factor = np.exp(- decay_rate * period_S1)

    recovery_arr_x_of_k = x0 * decay_factor ** (recovery_arr_k + 1 - duty_used)
    recovery_arr_y_of_k = amplitude * sigma_of_x(recovery_arr_x_of_k)

    return recovery_arr_k, recovery_arr_x_of_k, recovery_arr_y_of_k

In [None]:
# Marker/line style shared by the per-pulse curves.
ls_kwargs = {
    'markersize': 3,
    'marker': 'o',
    'linestyle': '-',
    'linewidth': 1,
}

# Recovery curves use the same style with white-filled markers.
recovery_ls_kwargs = dict(ls_kwargs, markerfacecolor='w')

# Legend labels, one per stimulus period.
label1, label2, label3 = (
    r'$T=%.1f$' % period_value
    for period_value in (period_S1_A, period_S1_B, period_S1_C)
)

In [None]:
# Habituation + recovery figure.
#   left half : peak response to each of the first `nperiods` pulses
#               (black = numeric peaks, colour = closed-form heuristic)
#   right half: heuristic response to one more pulse after skipping k pulses (recovery)
nperiods = 30

z1, z2, z3 = 13, 12, 11

arr_n_1, arr_xhigh_1, arr_xlow_1 = get_traj_xhigh_xlow(alpha, beta, period_S1_A, duty_A, amplitude, nperiods)
arr_n_2, arr_xhigh_2, arr_xlow_2 = get_traj_xhigh_xlow(alpha, beta, period_S1_B, duty_B, amplitude, nperiods)
arr_n_3, arr_xhigh_3, arr_xlow_3 = get_traj_xhigh_xlow(alpha, beta, period_S1_C, duty_C, amplitude, nperiods)

# LEFT: heuristic response to the n-th pulse, evaluated at the x value just before that pulse
arr_y1 = amplitude * sigma_of_x(arr_xlow_1)
arr_y2 = amplitude * sigma_of_x(arr_xlow_2)
arr_y3 = amplitude * sigma_of_x(arr_xlow_3)

# RIGHT: heuristic response to the (nperiods + k)-th pulse (recovery), starting from the last xhigh.
# NOTE(review): get_curves_recovery_using_x_decay reads the module-level `alpha` and `duty`,
# not duty_A / duty_B / duty_C -- confirm that is intended when the three duties differ.
recovery_arr_k_1, recovery_arr_x_of_k_1, recovery_arr_y_of_k_1 = get_curves_recovery_using_x_decay(arr_xhigh_1[-1], period_S1_A, amplitude, nperiods)
recovery_arr_k_2, recovery_arr_x_of_k_2, recovery_arr_y_of_k_2 = get_curves_recovery_using_x_decay(arr_xhigh_2[-1], period_S1_B, amplitude, nperiods)
recovery_arr_k_3, recovery_arr_x_of_k_3, recovery_arr_y_of_k_3 = get_curves_recovery_using_x_decay(arr_xhigh_3[-1], period_S1_C, amplitude, nperiods)

plt.close('all')
fig = plt.figure(figsize=(5.6,2))

# numeric peak values from the simulated runs (black)
plt.plot(np.arange(len(peaks_3)), y_peaks_3, zorder=z3, c='k', **ls_kwargs)
plt.plot(np.arange(len(peaks_2)), y_peaks_2, zorder=z2, c='k', **ls_kwargs)
plt.plot(np.arange(len(peaks_1)), y_peaks_1, zorder=z1, c='k', **ls_kwargs)

# heuristic curves; each y array is now paired with its own index array
# (the three index arrays are identical, so the drawn figure is unchanged)
plt.plot(arr_n_3, arr_y3, c=color_3, label=label3, zorder=z3, **ls_kwargs)
plt.plot(arr_n_2, arr_y2, c=color_2, label=label2, zorder=z2, **ls_kwargs)
plt.plot(arr_n_1, arr_y1, c=color_1, label=label1, zorder=z1, **ls_kwargs)

# boundary between the habituation half and the recovery half
plt.axvline(nperiods, linestyle='--', color='k')

plt.plot(recovery_arr_k_1 + nperiods, recovery_arr_y_of_k_1, c=color_1, zorder=z1, **recovery_ls_kwargs)
plt.plot(recovery_arr_k_2 + nperiods, recovery_arr_y_of_k_2, c=color_2, zorder=z2, **recovery_ls_kwargs)
plt.plot(recovery_arr_k_3 + nperiods, recovery_arr_y_of_k_3, c=color_3, zorder=z3, **recovery_ls_kwargs)

# the title reports the `alphas` / `betas` lists left by the numeric cell above
plt.title(r'HEURISTIC (all) nchain=%d, alphas=%s, betas=%s' % (len(alphas), alphas, betas), fontsize=8)
plt.ylabel(r'peak values $y[k]$')
plt.legend()
plt.grid(alpha=0.4)

# relabel the x axis so that 0 sits at the habituation/recovery boundary;
# derived from nperiods (gives ticks 0..60 labelled -30..30 for nperiods = 30)
tick_positions = list(range(0, 2 * nperiods + 1, 10))
plt.gca().set_xticks(tick_positions)
plt.gca().set_xticklabels([tick_position - nperiods for tick_position in tick_positions])
plt.xlabel(r'num. pulses previously applied | num. pulses skipped before next pulse')

fpath = NB_OUTPUT + os.sep + 'fig_H4_add_recovery'
plt.savefig(fpath + '.png')
plt.savefig(fpath + '.svg')
plt.show()


# Sanity check: overlay the heuristic x_low (red) / x_high (green) points on the numeric x(t).
plt.close('all')
fig, ax_check_x = plt.subplots(figsize=(5.6, 2))
for t_check, x_check, c_check in ((t_A, arr_x_1, color_1),
                                  (t_B, arr_x_2, color_2),
                                  (t_C, arr_x_3, color_3)):
    ax_check_x.plot(t_check, x_check, c=c_check)

buffer = 1e-6

# sample times for x_low: one per period, starting at (1 - duty) * T
tlow_A = np.arange((1 - duty_A) * period_S1_A, period_S1_A * nperiods + buffer, period_S1_A)
tlow_B = np.arange((1 - duty_B) * period_S1_B, period_S1_B * nperiods + buffer, period_S1_B)
tlow_C = np.arange((1 - duty_C) * period_S1_C, period_S1_C * nperiods + buffer, period_S1_C)
for tlow_check, xlow_check in ((tlow_A, arr_xlow_1), (tlow_B, arr_xlow_2), (tlow_C, arr_xlow_3)):
    ax_check_x.plot(tlow_check, xlow_check, 'o', color='red', markersize=2)

# sample times for x_high: just before the end of each period
thigh_A = np.arange(period_S1_A - buffer, period_S1_A * nperiods + buffer, period_S1_A)
thigh_B = np.arange(period_S1_B - buffer, period_S1_B * nperiods + buffer, period_S1_B)
thigh_C = np.arange(period_S1_C - buffer, period_S1_C * nperiods + buffer, period_S1_C)
for thigh_check, xhigh_check, z_check in ((thigh_A, arr_xhigh_1, z1),
                                          (thigh_B, arr_xhigh_2, z2),
                                          (thigh_C, arr_xhigh_3, z3)):
    ax_check_x.plot(thigh_check, xhigh_check, 'o', color='green', markersize=2, zorder=z_check)

# files are written before the title is set, so the saved images carry no title
fpath = NB_OUTPUT + os.sep + 'fig_H4_add_recovery_Check-x'
fig.savefig(fpath + '.png')
fig.savefig(fpath + '.svg')
ax_check_x.set_title('x values (numeric vs heuristic)')
plt.show()


# Numeric output y(t) for the three stimuli (only numeric curves are drawn here).
plt.close('all')
fig, ax_check_y = plt.subplots(figsize=(5.6, 2))
for t_check, y_check, c_check in ((t_A, arr_output_1, color_1),
                                  (t_B, arr_output_2, color_2),
                                  (t_C, arr_output_3, color_3)):
    ax_check_y.plot(t_check, y_check, c=c_check)

# files are written before the title is set, so the saved images carry no title
fpath = NB_OUTPUT + os.sep + 'fig_H4_add_recovery_Check-y'
fig.savefig(fpath + '.png')
fig.savefig(fpath + '.svg')
ax_check_y.set_title('y values (numeric vs heuristic)')
plt.show()

## Heatmaps for frequency and amplitude sensitivity
- show that the original system does not satisfy both hallmarks, whereas adding the amplitude gate satisfies both

In [None]:
"""
These functions are to return y[infty] / y[0] for each model
(1) Wiener model
(2) Hammerstein-Wiener model
"""

FIX_AREA_GIVEN_AMP_BY_VARYING_DUTY = False

# Guideline for Fig. 3 - A=10, T=1, and d=0.1 (three curves that conserve area = 1.0, with T=1, 2, 5)
#   - note d varies from d=0.1 (T=1) to d=0.05 (T=2) to d=0.02 (T=5)
#   - alpha=0.1, beta=0.2
#   - here rho = y[inf]/y[0] = sigma(x[inf])
#       - setting this to 0.5 can get the 0.5 contour...
#       - if sigma = 1/(1+x^N), then sigma(z)=0.5 at z=1 (indep. of N)
#       - x[inf]=1 means...
# Guideline for Fig. 4 - T=1 and d=0.1 is fixed (three curves that do NOT conserve area, with A=10, 5, 1)
#   - alpha=0.2, beta=4


BASE_DUTY = 0.1            # originally 0.1
BASE_PERIOD = 1.0          # used for computing varying duty
BASE_AMPLITUDE = 10.0
BASE_AREA = BASE_AMPLITUDE * BASE_DUTY * BASE_PERIOD  # 1.0


def heatmap_sigma_of_x(x, N=2):
    """Decreasing Hill function 1 / (1 + x^N)."""
    return 1 / (1 + x ** N)


# hammerstein portion - modifies input amplitude using function u'=h(u)
def heatmap_h_of_u(u, N=2):
    """Input nonlinearity h(u) = 2u / (1 + u^N)."""
    return 2 * u / (1 + u ** N)


def _heatmap_xlow_limit_cycle(val_amp, val_T, val_duty, alpha, beta):
    """Limit-cycle value of x just before a pulse, for rectangle pulses of height val_amp."""
    pulse_area = val_amp * val_T * val_duty
    q = np.exp(- alpha * val_T)
    # gain in x from a single pulse starting at x = 0
    Q = beta * pulse_area * (1 - q ** val_duty) / (alpha * val_T * val_duty)

    xhigh_LC = Q / (1 - q)
    return xhigh_LC * q ** (1 - val_duty)


def get_delta_y_model1_W(val_amp, val_T, val_duty=BASE_DUTY, alpha=0.1, beta=0.2):
    """
    y[infty] / y[0] for the Wiener model (model 1).

    Args:
        val_amp: pulse amplitude A
        val_T: pulse period T
        val_duty: pulse duty d (default BASE_DUTY)
        alpha: decay rate of x (default 0.1, the value noted for the current Fig. 3)
        beta: growth rate of x (default 0.2, the value noted for the current Fig. 3)

    Returns:
        y_infty / y0, with y0 = A (i.e. A * sigma(0), sigma(0) = 1) and
        y_infty = A * sigma(x_low) on the limit cycle. The ratio is NaN when val_amp is 0.
    """
    # part 1
    y0 = val_amp  # by definition... this is A * sigma(0) with sigma(0)=1

    # part 2
    xlow_LC = _heatmap_xlow_limit_cycle(val_amp, val_T, val_duty, alpha, beta)
    y_infty = val_amp * heatmap_sigma_of_x(xlow_LC)

    return y_infty / y0


def get_delta_y_model2_HW(val_amp, val_T, val_duty=BASE_DUTY, alpha=0.2, beta=4, gamma=100):
    """
    y[infty] / y[0] for the Hammerstein-Wiener model (model 2).

    The memory x is driven by the transformed amplitude A' = heatmap_h_of_u(A), while
    the output amplitude is tanh(gamma * A).

    Args:
        val_amp: pulse amplitude A
        val_T: pulse period T
        val_duty: pulse duty d (default BASE_DUTY)
        alpha: decay rate of x (default 0.2, the value noted for the current Fig. 4)
        beta: growth rate of x (default 4, the value noted for the current Fig. 4)
        gamma: gain inside the tanh output amplitude (default 100; HW model only)

    Returns:
        y_infty / y0, with y0 = tanh(gamma * A) and y_infty = tanh(gamma * A) * sigma(x_low).
        The ratio is NaN when val_amp is 0.
    """
    val_amp_prime = heatmap_h_of_u(val_amp)

    # part 1
    y0 = np.tanh(gamma * val_amp)  # by definition... this is tanh(gamma * A) * sigma(0) with sigma(0)=1

    # part 2
    xlow_LC = _heatmap_xlow_limit_cycle(val_amp_prime, val_T, val_duty, alpha, beta)
    y_infty = np.tanh(gamma * val_amp) * heatmap_sigma_of_x(xlow_LC)

    return y_infty / y0

In [None]:
# Linear-axis heatmaps of rho = y[infty]/y[0] over (period T, amplitude A) for both models.
A_range = np.linspace(0.1, 10, 40)
T_range = np.linspace(0.1, 10, 40)

# M[i, j] <-> (A_i, T_j): rows index amplitude, columns index period; NaN marks "not computed"
T_vs_A_data_1 = np.full((len(A_range), len(T_range)), np.nan)
T_vs_A_data_2 = np.full((len(A_range), len(T_range)), np.nan)

flag_vary_duty = FIX_AREA_GIVEN_AMP_BY_VARYING_DUTY
for j, val_T in enumerate(T_range):
    # when the duty is varied it scales as 1/T, and columns needing a duty above 1.0 are skipped
    val_duty = BASE_DUTY * BASE_PERIOD / (val_T) if flag_vary_duty else BASE_DUTY
    valid_region = not (flag_vary_duty and val_duty > 1.0)
    if not valid_region:
        continue

    for i, val_amp in enumerate(A_range):
        T_vs_A_data_1[i,j] = get_delta_y_model1_W(val_amp, val_T, val_duty=val_duty)
        T_vs_A_data_2[i,j] = get_delta_y_model2_HW(val_amp, val_T, val_duty=val_duty)

# plot the data
fig, (ax1, ax2) = plt.subplots(figsize=(9, 3), ncols=2)

imshow_kwargs = {
    'interpolation': 'none',
    'origin': 'lower',
    'cmap': 'Spectral',
    'extent': [min(T_range), max(T_range), min(A_range), max(A_range)],
    'aspect': 'auto',
    'vmin': 0.0,
    'vmax': 1.0,
}

T_vs_A_data_1_masknans = np.ma.masked_where(np.isnan(T_vs_A_data_1), T_vs_A_data_1)

# left panel: model 1
im1 = ax1.imshow(T_vs_A_data_1_masknans, **imshow_kwargs)
fig.colorbar(im1, ax=ax1, shrink=0.8)
ax1.set_title(r'Model 1 - $\Delta y = 1$ means no Hab')
ax1.set_xlabel(r'period $T$')
ax1.set_ylabel(r'amplitude $A$')

# right panel: model 2 (shares the colour scale fixed by vmin/vmax above)
print('try force vmin, vmax 0, 1')
im2 = ax2.imshow(T_vs_A_data_2, **imshow_kwargs)
cbar2 = fig.colorbar(im2, ax=ax2, shrink=0.8)
cbar2.set_label(r'$\Delta y$')
ax2.set_title(r'Model 2 - $\Delta y = 1$ means no Hab')
ax2.set_xlabel(r'period $T$')

fpath = NB_OUTPUT + os.sep + 'fig_hab_heatmap_v1_dVary-areaFixed-%s' % FIX_AREA_GIVEN_AMP_BY_VARYING_DUTY
fig.savefig(fpath + '.png')
fig.savefig(fpath + '.svg')
plt.show()

### Repeat plot above but now logscale

In [None]:
import matplotlib.colors as colors
from matplotlib.colors import ListedColormap, LinearSegmentedColormap


# cmap code from here https://stackoverflow.com/questions/74731282/how-to-create-a-linear-colormap-with-color-defined-at-specific-values-with-matpl
def get_colormap(values, colors, name="custom"):
    """
    Build a 512-level LinearSegmentedColormap from a list of colors.

    Args:
        values: anchor positions for the colors, or None. When given they are sorted
            and linearly rescaled so the smallest maps to 0 and the largest to 1.
        colors: sequence of color specs, paired in order with the sorted values
        name: name given to the colormap

    Returns:
        the colormap; with values=None the colors are spaced evenly
    """
    if values is None:
        return LinearSegmentedColormap.from_list(name, colors, N=512)  # will assume equidistant values

    anchors = np.sort(np.array(values))
    anchors = np.interp(anchors, (anchors.min(), anchors.max()), (0., 1.))
    return LinearSegmentedColormap.from_list(name, list(zip(anchors, colors)), N=512)

# Purple ramp used for the rho heatmaps: darkest at 0, lightest at 1, with the
# anchors bunched near 0 so that small rho values remain distinguishable.
# (earlier variant: values [0., 0.1, 0.5, 0.9, 0.999, 1.] with
#  colors ["#66348F", "#8966AC", "#B19CC8", "#f0e8f7", "#f7f3fb", "#fcfafd"])
list_of_values = np.array([0., 0.01, 0.2, 0.5, 0.9, 1.])
list_of_colors = [
    "#572c7a",
    "#66348F",
    "#8966AC",
    "#B19CC8",
    "#f0e8f7",
    "#f7f3fb",
]  # darkest at front, lightest at end (1.0)

cmap_lingseg_purples = get_colormap(values=list_of_values, colors=list_of_colors)


def make_loglog_AT_heatmaps(flag_vary_duty=FIX_AREA_GIVEN_AMP_BY_VARYING_DUTY):
    """
    Draw and save a row of four log-log heatmaps over (period T, amplitude A).

    Panels, left to right:
      1. rho = y[infty]/y[0] for the W model (get_delta_y_model1_W)
      2. rho for the HW model (get_delta_y_model2_HW)
      3. the duty d used in each cell
      4. the pulse area d * A * T in each cell
    Panels 1 and 2 also get a dashed contour labelled rho = 0.5.

    Args:
        flag_vary_duty: if True, each column uses d = BASE_DUTY * BASE_PERIOD / T and
            columns where that exceeds 1.0 are left as NaN (drawn grey); if False,
            d = BASE_DUTY everywhere.

    Side effects:
        prints a progress message, writes
        NB_OUTPUT/fig_hab_heatmap_v2_log_dVary-areaFixed-<flag>.png and .svg,
        and shows the figure. Returns None.
    """
    # ------------------------------------------------------------------
    # generate the data
    # ------------------------------------------------------------------
    A_range = np.logspace(-1.1, 2, 60)
    T_range = np.logspace(-1.3, 2, 60)  # was -1.5 in v1

    # M[i, j] <-> (A_i, T_j): rows index amplitude, columns index period; NaN marks "not computed"
    grid_shape = (len(A_range), len(T_range))
    T_vs_A_data_1 = np.full(grid_shape, np.nan)
    T_vs_A_data_2 = np.full(grid_shape, np.nan)
    duty_grid = np.full(grid_shape, np.nan)
    area_grid = np.full(grid_shape, np.nan)

    for j, val_T in enumerate(T_range):
        if flag_vary_duty:
            val_duty = BASE_DUTY * BASE_PERIOD / (val_T)  # need this to be less than 1.0
            if val_duty > 1.0:
                continue  # invalid region: leave the whole column as NaN
        else:
            val_duty = BASE_DUTY

        for i, val_amp in enumerate(A_range):
            T_vs_A_data_1[i,j] = get_delta_y_model1_W(val_amp, val_T, val_duty=val_duty)
            T_vs_A_data_2[i,j] = get_delta_y_model2_HW(val_amp, val_T, val_duty=val_duty)

            duty_grid[i,j] = val_duty
            area_grid[i,j] = val_duty * val_amp * val_T
    print('Done data loop')

    # ------------------------------------------------------------------
    # plot the data
    # ------------------------------------------------------------------
    fig, (ax1, ax2, ax3, ax4) = plt.subplots(figsize=(16, 2.5), ncols=4)

    # Colormaps are copied before set_bad so that neither the registered colormaps nor the
    # module-level cmap_lingseg_purples are modified. plt.get_cmap is used because
    # matplotlib.cm.get_cmap no longer exists in recent Matplotlib releases.
    grey_for_nan = '#D3D3D3'
    cmap_rho = cmap_lingseg_purples.copy()
    cmap_duty = plt.get_cmap("Blues").copy()
    cmap_area = plt.get_cmap("Reds").copy()
    for cmap_with_nan in (cmap_rho, cmap_duty, cmap_area):
        cmap_with_nan.set_bad(grey_for_nan, 1.)

    pcolormesh_kwargs = dict(
        cmap=cmap_rho,
        shading='auto',
        vmin=0.0, vmax=1.0,
    )
    pcolormesh_duty_kwargs = dict(
        cmap=cmap_duty,
        shading='auto',
        norm=colors.LogNorm(),
    )
    pcolormesh_area_kwargs = dict(
        cmap=cmap_area,
        shading='auto',
    )

    # panel 1: W model
    im1 = ax1.pcolormesh(T_range, A_range, T_vs_A_data_1, **pcolormesh_kwargs)
    cbar1 = fig.colorbar(im1, ax=ax1, shrink=0.8)
    ax1.set_title(r'W - $\rho = 1$ -- no Hab')
    ax1.set_ylabel(r'amplitude $A$')

    # panel 2: HW model
    im2 = ax2.pcolormesh(T_range, A_range, T_vs_A_data_2, **pcolormesh_kwargs)
    cbar2 = fig.colorbar(im2, ax=ax2, shrink=0.8)
    cbar2.set_label(r'$\rho \,=\, \frac{y[\infty]}{y[0]}$')
    ax2.set_title(r'WH - $\rho = 1$ -- no Hab')

    # panel 3: duty
    im3 = ax3.pcolormesh(T_range, A_range, duty_grid, **pcolormesh_duty_kwargs)
    cbar3 = fig.colorbar(im3, ax=ax3, shrink=0.8)
    cbar3.set_label(r'$d$')
    ax3.set_title(r'Base duty $d_{0}=%.1e$' % BASE_DUTY)

    # panel 4: area
    im4 = ax4.pcolormesh(T_range, A_range, area_grid, **pcolormesh_area_kwargs)
    cbar4 = fig.colorbar(im4, ax=ax4, shrink=0.8)
    cbar4.set_label(r'$\Lambda$')
    ax4.set_title(r'Area $\Lambda$')

    for ax in (ax1, ax2, ax3, ax4):
        ax.set_xlabel(r'period $T$')
        ax.set_xscale('log')
        ax.set_yscale('log')

    # ------------------------------------------------------------------
    # add the rho = 0.5 contours
    # ------------------------------------------------------------------
    Tmesh, Amesh = np.meshgrid(T_range, A_range)

    def add_rho_half_contour(ax, rate_decay, rate_grow, amp_mesh):
        # Limit-cycle x_low on the grid; sigma(x) = 0.5 at x = 1, so the level-1
        # contour of x_low is the rho = 0.5 contour.
        q_mesh = np.exp(- rate_decay * Tmesh)
        x_low_mesh = rate_grow/rate_decay * amp_mesh * (1 - q_mesh ** duty_grid) / (1 - q_mesh) * q_mesh ** (1 - duty_grid)
        contour_set = ax.contour(Tmesh, Amesh, x_low_mesh, levels=[1], linestyles=['--'], colors=['k'])
        fmt = dict(zip(contour_set.levels, [r'$\rho = 0.5$']))
        ax.clabel(contour_set, contour_set.levels, inline=True, fmt=fmt, fontsize=10)

    # Model 1 (W): alpha = 0.1, beta = 0.2, driven by the raw amplitude
    add_rho_half_contour(ax1, 0.1, 0.2, Amesh)
    # Model 2 (HW): alpha = 0.2, beta = 4.0, driven by the transformed amplitude h(A)
    add_rho_half_contour(ax2, 0.2, 4.0, heatmap_h_of_u(Amesh))

    # ------------------------------------------------------------------
    # end and output
    # ------------------------------------------------------------------
    # raw string: '\L' is not a valid escape sequence in a normal string literal
    plt.suptitle(r'fix $\Lambda$, vary $d$ [%s]' % flag_vary_duty)

    fpath = NB_OUTPUT + os.sep + 'fig_hab_heatmap_v2_log_dVary-areaFixed-%s' % flag_vary_duty
    plt.savefig(fpath + '.png')
    plt.savefig(fpath + '.svg')
    plt.show()
    
# Produce both variants: duty varied with T first, then duty held at BASE_DUTY.
for vary_duty_choice in (True, False):
    make_loglog_AT_heatmaps(flag_vary_duty=vary_duty_choice)

# Hallmark 6: implementation using $y=\mathrm{ReLU}(u-x)$

In [None]:
# Unit definition for the Hallmark 6 figures.
# NOTE(review): the two dicts look like keyword arguments for h and g respectively -- confirm.
model_unit_tuple_H6 = (h_of_u_H6, g_of_x_u_H6, {}, {'p_threshold': 0})

In [None]:
from matplotlib.patches import Rectangle

# Figure geometry and fonts for the H6 figures
WIDTH_ONECOL = 3.54331      # used below as a figsize width
GLOBAL_FS = 10              # fontsize
GLOBAL_TICKLABELS = 9       # fontsize
GLOBAL_FPROP = dict(family='arial', size=GLOBAL_FS)

grid_alpha = 0.4

# Palette notes kept for reference (not executed):
#   color_input = '#00AEEF'
#   color_memory = '#C1A16B'  # dark: #DEB87A || darker: #D0AD73 ||| darkest: #C1A16B
#   color_response = '#662D91'
#   these colors taken from first version of pdf plots
#   color_martin_input_u = '#000000'
#   color_martin_memory_green = '#2BB34B'
#   color_martin_response_blue = '#1870B8'     # 3953A4 darker, 1870B8 lighter
#   color_martin_sensitize_orange = '#D48149'  # D48149 darker, FBA919 lighter?

color_input_choice = color_input
color_memory_choice = color_memory
color_response_choice = color_response
color_aux_stimulus_sensitize = '#B3446C'

zorder = 5
linewidth = 1.0
linestyle = '-'

# These replace the line_kwargs_* dicts defined near the top of the notebook; the new
# versions additionally carry a zorder.
_h6_solid_line = dict(linewidth=linewidth, linestyle=linestyle, zorder=zorder)
line_kwargs_input = dict(color=color_input_choice, **_h6_solid_line)
line_kwargs_memory = dict(color=color_memory_choice, **_h6_solid_line)
line_kwargs_response = dict(color=color_response_choice, **_h6_solid_line)
# no linestyle entry here on purpose (a 'densely dashed' style was considered)
line_kwargs_response_approx = dict(color=color_response_choice, linewidth=linewidth, zorder=zorder)
line_kwargs_stim2 = dict(color=color_aux_stimulus_sensitize, **_h6_solid_line)

In [None]:
def figure_stack_H6(t1, u1, x1, y1, t2, u2, x2, y2, size_small=True):
    """
    Four-row stacked figure comparing two stimulus protocols (1: solid lines, 2: dashed lines).

    Rows, top to bottom (all sharing one x axis):
        0: input u1
        1: input u2
        2: memory x1 and x2, plus a dashed horizontal line at max(u1)
        3: response y1 and y2
    All curves are drawn against t - 20. Rows 0 and 1 also get a light grey
    rectangle from 0.75 to 12.25 on that shifted axis, as tall as the input maximum.

    Args:
        t1, u1, x1, y1: time, input, memory and response arrays of protocol 1
        t2, u2, x2, y2: the same for protocol 2; max(u2) must equal max(u1)
        size_small: selects the half-column layout (file suffix '_small') instead of
            the full-column layout

    Side effects:
        closes all open figures, prints the output path, writes
        NB_OUTPUT/figure_stack_H6<suffix>.pdf and .svg, and shows the figure.

    Returns:
        list of the four Axes, top to bottom
    """
    time_shift = -20
    dashes_input = [1.5, 0.75]   # on/off lengths for the dashed u2 curve
    dashes_state = [0.7, 0.7]    # on/off lengths for the dashed x2 and y2 curves

    plt.close('all')

    # layout="constrained" keeps axis labels from being cut off
    height_ratios = [0.3, 0.3, 0.5, 0.35]  # v1: [0.3, 0.3, 0.66, 0.35]
    if size_small:
        figsize, hspace, fname_size = (WIDTH_ONECOL/2*1.2, 3), 0.1, '_small'
    else:
        figsize, hspace, fname_size = (WIDTH_ONECOL, 3.6), 0.09, ''

    fig = plt.figure(figsize=figsize, layout="constrained")

    nrows = 4
    gs = GridSpec(nrows, 1, figure=fig, height_ratios=height_ratios, hspace=hspace)
    ax_u1 = fig.add_subplot(gs[0])
    ax_u2 = fig.add_subplot(gs[1], sharex=ax_u1)
    ax_x = fig.add_subplot(gs[2], sharex=ax_u1)
    ax_y = fig.add_subplot(gs[3], sharex=ax_u1)
    axarr = [ax_u1, ax_u2, ax_x, ax_y]

    # row 0: input of protocol 1
    ax_u1.plot(t1 + time_shift, u1, label=r'$u_1$', **line_kwargs_input)
    ymax = np.max(u1)
    ax_u1.set_ylim(-0.2*ymax, ymax*1.2)

    # row 1: input of protocol 2 (dashed)
    curve_u2, = ax_u2.plot(t2 + time_shift, u2, label=r'$u_2$', **line_kwargs_input)
    ymax = np.max(u2)
    ax_u2.set_ylim(-0.2*ymax, ymax*1.2)
    curve_u2.set_dashes(dashes_input)

    # row 2: memory variable of both protocols
    ax_x.plot(t1 + time_shift, x1, label=r'$x_1$', **line_kwargs_memory)
    curve_x2, = ax_x.plot(t2 + time_shift, x2, label=r'$x_2$', **line_kwargs_memory)
    curve_x2.set_dashes(dashes_state)

    # reference level at the (common) input maximum
    assert np.max(u1) == np.max(u2)
    ax_x.axhline(np.max(u1), color='k', linewidth=0.5, linestyle='--', zorder=10, alpha=0.5)

    # row 3: response of both protocols; y-limits are scaled by the input maximum (ymax of u2)
    ax_y.plot(t1 + time_shift, y1, label=r'$y_1$', **line_kwargs_response)
    curve_y2, = ax_y.plot(t2 + time_shift, y2, label=r'$y_2$', **line_kwargs_response)
    ax_y.set_ylim(-0.1*ymax, ymax*1.1)
    curve_y2.set_dashes(dashes_state)

    ax_y.set_yticks([0, 5, 10])
    ax_y.set_yticklabels([0, 5, 10])

    # only the bottom row shows x tick labels
    for ax in axarr[:-1]:
        ax.tick_params(labelbottom=False)

    # grey background rectangle on the two input rows
    buffer = 0.5
    rect_botleft = (0 + 1.5*buffer, 0)
    rect_width = 13 - 3*buffer
    for ax, u_curve in ((ax_u1, u1), (ax_u2, u2)):
        ax.add_patch(
            Rectangle(rect_botleft, rect_width, np.max(u_curve), color='#EEEFEF', alpha=0.8, zorder=0))

    # grid on the two state rows only; legend and tick sizes on every row
    for idx, ax in enumerate(axarr):
        if idx > 1:
            ax.grid(alpha=0.25)
        ax.legend(fontsize=GLOBAL_FS - 1, loc='upper right')
        for axis_name in ('x', 'y'):
            ax.tick_params(axis=axis_name, labelsize=GLOBAL_TICKLABELS)
    axarr[-1].set_xlabel(r'$t/T$', **GLOBAL_FPROP)

    # note: plt.tight_layout() cannot be combined with the constrained layout used here

    fpath = NB_OUTPUT + os.sep + 'figure_stack_H6%s' % fname_size
    print('Saving figure to', fpath)
    plt.savefig(fpath + '.pdf', dpi=450)
    plt.savefig(fpath + '.svg', dpi=450)
    plt.show()

    return axarr

In [None]:
# Module-level display switches for the stacked u/x/r/y plot defined below.
# flag_ylabels and flag_axhlines are read inside MOD_2_plot_uxry_stack.
# NOTE(review): the remaining flags are presumably read further down in that function -- confirm.
flag_ylabels = False
flag_legends = False

flag_grid = False
flag_right_and_top_ax = False
flag_axhlines = False

# plot params
# (MOD_2_plot_uxry_stack assigns its own local grid_alpha, so this module-level value
#  only affects code outside that function)
grid_alpha = 0.4


def MOD_2_plot_uxry_stack(t, u, rate_decay, rate_grow, model_unit_tuple, title, fpath, 
                          xlims=(-5, 60), x_thresh=None):
    """
    Plot a four-row stack (input u, memory x, difference r = u - x, output y) for a
    single unit and save it to ``fpath + '.pdf'`` and ``fpath + '.svg'``.

    Args:
        t: time samples
        u: input signal sampled on ``t``
        rate_decay: passed to ``series_form_of_unit`` as the single-element alpha list
        rate_grow: passed to ``series_form_of_unit`` as the single-element beta list
        model_unit_tuple: unpacked as the trailing positional args of ``series_form_of_unit``
        title: title placed on the top row
        fpath: output path without file extension
        xlims: x-limits applied through ``plt.xlim``; ``None`` leaves them untouched
        x_thresh: optional height of a dashed horizontal line drawn in the memory row

    Returns:
        The 4x1 array of axes created by ``plt.subplots``.

    Notes:
        Reads the module-level switches ``flag_ylabels``, ``flag_axhlines``, ``flag_grid``,
        ``flag_right_and_top_ax`` and ``flag_legends`` plus the shared line styles.
        The third row always shows u - x, whatever nonlinearity the model tuple defines
        (hence the warning printed on every call).
    """
    # A one-unit chain: the last column holds the (only) unit's state / output.
    unit_outputs, unit_states = series_form_of_unit(t, u, 1, [rate_decay], [rate_grow], *model_unit_tuple)
    x_of_t = unit_states[:, -1]
    y_of_t = unit_outputs[:, -1]
    r_of_t = u - x_of_t
    print('warning r_of_t is not defined for specified model | g = relu only')
    amp = np.max(u)

    # plot params (local on purpose: shadows the notebook-level grid_alpha)
    local_grid_alpha = 0.6
    dashed_ref_kwargs = dict(linestyle='--', alpha=0.5, color='k')

    plt.close()
    fig, axarr = plt.subplots(4, 1, sharex=True, squeeze=False, figsize=(7, 5))
    ax_u, ax_x, ax_r, ax_y = axarr[:, 0]
    stacked_axes = (ax_u, ax_x, ax_r, ax_y)

    ax_u.set_title(title)
    if flag_ylabels:
        row_ylabels = (r'input $u(t)$', r'memory $x(t)$', r'filter $r(t)$', r'output $y(t)$')
        for ax, ylabel in zip(stacked_axes, row_ylabels):
            ax.set_ylabel(ylabel)
    ax_y.set_xlabel(r'$t$')

    # Row 0: input, with ticks only at 0 and the peak amplitude
    ax_u.plot(t, u, label=r'$u(t)$', **line_kwargs_input)
    ax_u.set_yticks([0, amp])
    ax_u.set_yticklabels(['0', '%d' % int(amp)])

    # Row 1: memory trace
    ax_x.plot(t, x_of_t, label=r"$x(t)=\int_{0}^t \,\beta e^{-\alpha (t-t')} u(t') dt'$", **line_kwargs_memory)
    if flag_axhlines:
        ax_x.axhline(1, **dashed_ref_kwargs)
    ax_x.set_ylim(-0.05, np.max(x_of_t)*1.05)
    if x_thresh is not None:
        ax_x.axhline(x_thresh, linestyle='--', c='k', linewidth=linewidth, zorder=2)

    # Row 2: pre-nonlinearity difference, drawn dashed over a solid zero line
    (response_before_relu,) = ax_r.plot(t, r_of_t, label=r'$r(t) = u - x$', **line_kwargs_response)
    if flag_axhlines:
        ax_r.axhline(0.5, **dashed_ref_kwargs)
    response_before_relu.set_dashes([2, 1])  # 2pt line, 1pt break
    ax_r.axhline(0, linestyle='-', c='k', linewidth=linewidth, zorder=2)

    # Row 3: model output
    ax_y.plot(t, y_of_t, label=r'$y(t)=\textrm{ReLU}(u - x)$', **line_kwargs_response)
    if flag_axhlines:
        ax_y.axhline(0.5 * amp, **dashed_ref_kwargs)

    for row, ax in enumerate(stacked_axes):
        if row > 0:
            # faint copy of the input, rescaled to the current top of this row
            height_scale = ax.get_ylim()[1] / amp
            ax.plot(t, u * height_scale, '-k', alpha=0.05, zorder=1)
        if flag_grid:
            ax.grid(alpha=local_grid_alpha)
        if not flag_right_and_top_ax:
            ax.spines[['right', 'top']].set_visible(False)
        if flag_legends:
            ax.legend()

    if xlims is not None:
        plt.xlim(xlims)
    for extension in ('.pdf', '.svg'):
        plt.savefig(fpath + extension)
    plt.show()
    return axarr

In [None]:
# Hallmark 6 settings
# Memory rates: decay (alpha) and growth (beta), as reported in the plot titles
H6_rate_decay, H6_rate_grow = 0.1, 4

# Pulse-train parameters handed to signal_rectangle (duty, pulse_area, period_S1)
H6_duty, H6_pulse_area, H6_period = 0.1, 1, 1.0  # pulse_area alternative: 0.5

In [None]:
# Hallmark 6, input 1: build u1 on t1 from signal_rectangle, zero it after a cutoff,
# then draw the quick-look stack with the H6 rates.
n_pulse = 20
n_rest = 12

L_cycle = n_pulse + n_rest  # cycle length counted in periods (pulses + rest)
n_cycles = 2  # 3
# Times H6_period * n for n = 1..n_pulse in every cycle (presumably the pulse times);
# not used again in this cell.
tmid_list = [H6_period * (n + L_cycle * k) for k in range(n_cycles) for n in range(1, n_pulse + 1)]

# u1 is set to 0 for t1 > cutoff_pulse_end * H6_period (see np.where below),
# i.e. 4 periods into the second cycle.
cutoff_pulse_end = L_cycle + 4

# prep t to sample signals from
dt = 0.001
#tmax = L_cycle * period_S1 * 4
tmax = 55 #L_cycle * n_cycles * period_S1
t1 = np.arange(-0.5, tmax + dt, dt)

_, u1 = signal_rectangle(n_pulse, n_rest, n_cycles, times=t1, 
                         period_S1=H6_period, duty=H6_duty, pulse_area=H6_pulse_area)

# truncate the second pulse train
u1 = np.where(t1 > cutoff_pulse_end * H6_period, 0, u1)

# fmod is used both in the title and in the output file name
fmod = 'rectangle_1'
title = r'%s $u(t)$ with nonlinear filter: $\alpha=%.2f$, $\beta=%.2f$' % (fmod, H6_rate_decay, H6_rate_grow)
fpath = NB_OUTPUT + os.sep + 'mod_uxry_stack_%s' % fmod
_ = MOD_2_plot_uxry_stack(t1, u1, H6_rate_decay, H6_rate_grow, model_unit_tuple_H6, title, fpath)

In [None]:
# Hallmark 6, input 2: same construction as input 1, but u2 is additionally zeroed
# before a front cutoff so that only a short first train remains.
n_pulse = 20  # dummy npulse for u2 (will truncate below to 4)
n_rest = 12

L_cycle = n_pulse + n_rest  # cycle length counted in periods (pulses + rest)
n_cycles = 2  # 3
# Times H6_period * n for n = 1..n_pulse in every cycle (presumably the pulse times);
# not used again in this cell.
tmid_list = [H6_period * (n + L_cycle * k) for k in range(n_cycles) for n in range(1, n_pulse + 1)]

# u2 is kept only for cutoff_pulse_front * H6_period <= t2 <= cutoff_pulse_end * H6_period
cutoff_pulse_front = 20 - 4  # "4" is the real npulse
cutoff_pulse_end = L_cycle + 4  # "4" is the real npulse

# prep t to sample signals from
dt = 0.001
#tmax = L_cycle * period_S1 * 4
tmax = 55 #L_cycle * n_cycles * period_S1
t2 = np.arange(-0.5, tmax + dt, dt)

_, u2 = signal_rectangle(n_pulse, n_rest, n_cycles, times=t2, 
                         period_S1=H6_period, duty=H6_duty, pulse_area=H6_pulse_area)

# truncate the front of the first train and the tail of the second train
u2 = np.where(t2 < cutoff_pulse_front * H6_period, 0, u2)
u2 = np.where(t2 > cutoff_pulse_end * H6_period, 0, u2)

# fmod is used both in the title and in the output file name
fmod = 'rectangle_2'
title = r'%s $u(t)$ with nonlinear filter: $\alpha=%.2f$, $\beta=%.2f$' % (fmod, H6_rate_decay, H6_rate_grow)
fpath = NB_OUTPUT + os.sep + 'mod_uxry_stack_%s' % fmod
_ = MOD_2_plot_uxry_stack(t2, u2, H6_rate_decay, H6_rate_grow, model_unit_tuple_H6, title, fpath)

In [None]:
# Memory x (via W_of_t_convolve) and response y = local_relu(u - x) for both Hallmark 6 inputs.
# NOTE(review): the quick-look stacks above obtain x and y from series_form_of_unit with
# model_unit_tuple_H6 instead - confirm both routes are meant to agree.
x1 = W_of_t_convolve(t1, u1, H6_rate_decay, H6_rate_grow)
y1 = local_relu(u1 - x1)
 
x2 = W_of_t_convolve(t2, u2, H6_rate_decay, H6_rate_grow)
y2 = local_relu(u2 - x2)

# Render the figure in both size variants
figure_stack_H6(t1, u1, x1, y1, t2, u2, x2, y2, size_small=True)
figure_stack_H6(t1, u1, x1, y1, t2, u2, x2, y2, size_small=False)


# Hallmark 8 - dishabituation

In [None]:
def figure_stack_H8(t, u1, u2, x, y, 
                    size_small=True, 
                    hide_fill_between_glitch=True, xlims=None, hide_grids=False, hide_legends=False):
    """
    Four-row stacked figure for Hallmark 8 (dishabituation): main stimulus ``u1`` (legend
    ``u``), auxiliary stimulus ``u2`` (legend ``v``), memory ``x`` and response ``y``.

    Args:
        t: shared time axis for all traces
        u1, u2: stimulus traces for rows 0 and 1
        x, y: memory and response traces for rows 2 and 3
        size_small: selects figure size, row spacing and the ``'_small'`` file-name suffix
        hide_fill_between_glitch: when False, shade ``np.max(x) * u2`` in the x and y rows
        xlims: optional x-limits (applied to the shared x axis through the bottom row)
        hide_grids: when True, no grid is drawn in the x and y rows
        hide_legends: when True, no legends are drawn

    Returns:
        list of the four axes, top to bottom.

    Side effects:
        Closes all open figures, saves ``figure_stack_H8[_small]`` as .pdf and .svg under
        ``NB_OUTPUT`` and shows the figure.
    """
    plt.close('all')

    # Only canvas size, row spacing and file suffix depend on the size variant.
    # (layout="constrained" avoids cut-off axis labels; tight_layout cannot be used here.)
    height_ratios = [0.3, 0.3, 0.5, 0.35]
    figsize, hspace, fname_size = (
        ((WIDTH_ONECOL/2*1.2, 3), 0.1, '_small') if size_small
        else ((WIDTH_ONECOL, 2.9), 0.08, '')
    )

    fig = plt.figure(figsize=figsize, layout="constrained")
    gs = GridSpec(4, 1, figure=fig, height_ratios=height_ratios, hspace=hspace)

    ax_u = fig.add_subplot(gs[0])
    ax_v, ax_x, ax_y = (fig.add_subplot(gs[row], sharex=ax_u) for row in (1, 2, 3))
    axarr = [ax_u, ax_v, ax_x, ax_y]

    # Stimulus rows: same treatment, different trace / legend tag / line style
    stimulus_rows = (
        (ax_u, u1, r'$u$', line_kwargs_input),
        (ax_v, u2, r'$v$', line_kwargs_stim2),
    )
    for ax, trace, tag, style in stimulus_rows:
        ax.plot(t, trace, label=tag, **style)
        ax.tick_params(labelbottom=False)
        peak = np.max(trace)
        ax.set_ylim(-0.2*peak, peak*1.2)

    # Memory and response rows
    ax_x.plot(t, x, label=r'$x$', **line_kwargs_memory)
    ax_x.tick_params(labelbottom=False)
    ax_y.plot(t, y, label=r'$y$', **line_kwargs_response)

    if not hide_fill_between_glitch:
        # shade the window of the auxiliary stimulus, scaled to the peak of x
        for ax in (ax_x, ax_y):
            ax.fill_between(
                t, 0, np.max(x) * u2,
                facecolor=color_aux_stimulus_sensitize,
                alpha=0.25,
                zorder=2)

    y_peak = np.max(y)
    ax_y.set_ylim(-0.1*y_peak, y_peak*1.1)

    if not hide_grids:
        for ax in (ax_x, ax_y):
            ax.grid(alpha=grid_alpha)
    for ax in axarr:
        for axis_name in ('x', 'y'):
            ax.tick_params(axis=axis_name, labelsize=GLOBAL_TICKLABELS)

    ax_y.set_xlabel(r'$t/T$', **GLOBAL_FPROP)

    if xlims is not None:
        ax_y.set_xlim(xlims)

    if not hide_legends:
        for ax in axarr:
            ax.legend(fontsize=9)

    fpath = NB_OUTPUT + os.sep + 'figure_stack_H8%s' % fname_size
    print('Saving figure to', fpath)
    for extension in ('.pdf', '.svg'):
        plt.savefig(fpath + extension, dpi=450)
    plt.show()

    return axarr

#### Hallmark 8: settings

In [None]:
# Hallmark 8 settings
# Memory rates: decay (alpha) and growth (beta)
H8_rate_decay, H8_rate_grow = 0.1, 1

# Pulse-train parameters handed to signal_rectangle (duty, pulse_area, period_S1)
H8_duty, H8_pulse_area, H8_period = 0.1, 1.0, 1.0

In [None]:
# Hallmark 8 input: a single cycle (n_pulse pulses then n_rest rest periods) sampled on t1.
# NOTE: this rebinds the notebook-level t1 / u1 that earlier Hallmark 6 cells also used.
n_pulse = 20
n_rest = 10

L_cycle = n_pulse + n_rest  # cycle length counted in periods (pulses + rest)
n_cycles = 1  # 3
# Times H8_period * n for n = 1..n_pulse in every cycle (presumably the pulse times);
# not used again in this cell.
tmid_list = [H8_period * (n + L_cycle * k) for k in range(n_cycles) for n in range(1, n_pulse + 1)]

# prep t to sample signals from
dt = 0.001
#tmax = L_cycle * period_S1 * 4
tmax = 40 #L_cycle * n_cycles * period_S1
t1 = np.arange(-0.5, tmax + dt, dt)

_, u1 = signal_rectangle(n_pulse, n_rest, n_cycles, times=t1, 
                         period_S1=H8_period, duty=H8_duty, pulse_area=H8_pulse_area)

In [None]:
# Quick-look stack for the Hallmark 8 input.
# The rates passed to the plot now match the ones printed in the title (previously the
# title reported the H8 rates while the traces were computed with the H6 rates).
# NOTE(review): model_unit_tuple_H6 is reused here (no H8-specific tuple is defined nearby),
# and fpath is the same file name the Hallmark 6 'rectangle_1' cell writes to - confirm both.
fmod = 'rectangle_1'
title = r'%s $u(t)$ with nonlinear filter: $\alpha=%.2f$, $\beta=%.2f$' % (fmod, H8_rate_decay, H8_rate_grow)
fpath = NB_OUTPUT + os.sep + 'mod_uxry_stack_%s' % fmod
_ = MOD_2_plot_uxry_stack(t1, u1, H8_rate_decay, H8_rate_grow, model_unit_tuple_H6, title, fpath, xlims=(-5, 45))

#### Compute ODE traj
- One way to work around needing two control objects (JIT issues) is to pass the $v(t)$ arguments (height, location, duration) as ODE parameters instead: ugly but quick.

In [None]:
# Time-dependent decay rate alpha(t) on the Hallmark 8 time grid t1:
# base_alpha everywhere, except inside period `apply_sensitizer_at_period_k`
# (trimmed by `period_integer_buffer` on both sides) where it equals kappa.
base_alpha = H8_rate_decay

apply_sensitizer_at_period_k = 15
period_integer_buffer = 0.05
kappa = 10

# step 1: base_alpha before the window opens, kappa afterwards
timedep_alpha = np.where(t1 < (apply_sensitizer_at_period_k + period_integer_buffer) * H8_period, base_alpha, kappa)
# step 2: back to base_alpha once the window has closed
timedep_alpha = np.where(t1 > (apply_sensitizer_at_period_k + 1 - period_integer_buffer) * H8_period, base_alpha, timedep_alpha)

In [None]:
def get_ode_traj_nb(ode_name, ode_params, stimulus_suffix, additional_control_objects=None):
    """
    Build an ``ODEModel`` for the preset ``ode_name`` driven by one of the notebook's manual
    stimuli, integrate it once, and return the model (trajectory stored in its history).

    Args:
        ode_name: preset name understood by ``analyze_freq_amp_presets`` and
            ``ode_model_preset_factory``
        ode_params: ODE parameters to use; ``None`` falls back to the preset parameters
        stimulus_suffix: one of ``'S2'`` (pulse wave of pulse waves), ``'S1'`` (pulse wave)
            or ``'S0_one'`` (constant input of amplitude 1)
        additional_control_objects: optional sequence of extra control objects appended after
            the main stimulus (e.g. the second stimulus required by ``ode_hallmark8``)

    Returns:
        The ``ODEModel`` instance after ``traj_and_runtime(update_history=True, ...)`` ran.

    Raises:
        AssertionError: if ``stimulus_suffix`` is not one of the supported values.
    """
    # Avoid a shared mutable default; accept any sequence (list or tuple).
    if additional_control_objects is None:
        additional_control_objects = []

    integration_style_stitching = True
    stitch_kwargs = {
        'forcetol': (1e-8, 1e-4),  # ATOL, RTOL = 1e-8, 1e-4   --or--   1e-12, 1e-6
        'max_step': np.inf,        # np.inf default (the alias np.Inf was removed in NumPy 2.0)
    }
    stitch_dynamic_max_step = True
    
    traj_label = 'not_preset'

    # ODE settings
    # ============================================================
    # The preset tspan is read here but overridden for every supported stimulus below.
    x0, tspan, preset_params_ode, _ = analyze_freq_amp_presets(ode_name)
    base_params_ode = preset_params_ode if ode_params is None else ode_params
    
    # STIMULUS settings
    # ============================================================
    assert stimulus_suffix in ['S2', 'S1', 'S0_one']  # todo extend
    if stimulus_suffix == 'S2':
        tspan = [0, 200.5]  # override the preset tspan
        stim_fn = stimulus_pulsewave_of_pulsewaves
        S1_duty = 0.1
        S1_period = 1.0  #1.0, 0.01, 0.1
        base_params_stim = [
            delta_fn_amp(S1_duty, S1_period),  # amplitude
            S1_duty,    # duty in [0,1]
            S1_period,  # period of stimulus pulses (level 1)
            5,  # npulse
            5,  # nrest
        ]
        stim_label = 'S2_manual'
    elif stimulus_suffix == 'S1':
        tspan = [0, 20.5]  # override the preset tspan
        stim_fn = stimulus_pulsewave
        S1_duty = 0.1
        S1_period = 1.0  #1.0, 0.01, 0.1
        base_params_stim = [
            delta_fn_amp(S1_duty, S1_period),  # amplitude
            S1_duty,    # duty in [0,1]
            S1_period,  # period of stimulus pulses (level 1)
        ]
        stim_label = 'S1_manual'
    else:
        assert stimulus_suffix == 'S0_one'
        tspan = [0, 20.5]  # override the preset tspan
        stim_fn = stimulus_constant
        step_amplitude = 1.0
        base_params_stim = [
            step_amplitude
        ]
        stim_label = 'S_step_manual'
        integration_style_stitching = False  # override the setting near top of function
    
    stim_instance = ODEStimulus(
        stim_fn,
        base_params_stim,
        label=stim_label
    )
    # second term is non-empty for e.g. ode_hallmark8 (req. > 1 stimuli)
    control_objects = [stim_instance] + list(additional_control_objects)
    
    # NOTE: tspan is used as-is; it is not rescaled by the pulse period.
    
    # PREPARE a new instance of ODEModel
    # ============================================================
    model_argprep = ode_model_preset_factory(traj_label, ode_name, control_objects,
                                             tspan[0], tspan[1],
                                             ode_params=base_params_ode, init_cond=x0)
    ode_model = ODEModel(*model_argprep['args'], **model_argprep['kwargs'])
    
    traj_kwargs = dict(verbose=True, traj_pulsewave_stitch=integration_style_stitching)
    if stitch_dynamic_max_step:
        # let the model dictate the maximum integrator step
        print('ode_model.ode_base.max_step_augment', ode_model.ode_base.max_step_augment)
        stitch_kwargs['max_step'] = ode_model.ode_base.max_step_augment
    _, _, runtime = ode_model.traj_and_runtime(update_history=True, **traj_kwargs, **stitch_kwargs)
    
    print('...done | runtime:', runtime, '(s)')
    
    return ode_model

In [None]:
# Auxiliary stimulus v(t): a single rectangle of height 1 that starts
# `period_integer_buffer` after period `apply_sensitizer_at_period_k` begins and
# lasts 1 - 2 * period_integer_buffer.
stim_params_rectangle = (1.0, 
                         apply_sensitizer_at_period_k + period_integer_buffer, 
                         1 - 2 * period_integer_buffer)  # height, start, duration (assumes period_S1=1)
auxilliary_control_H8 = ODEStimulus(
        stimulus_rectangle,
        stim_params_rectangle,
        label=r'$v(t)$'
    )

# Each entry is (ode_name, ode_title, ode_params)
ode_list_to_sim = [
    ('ode_hallmark8', 'ode_hallmark8',
     np.array([
            H8_rate_decay,  # alpha: timescale for x decay -- prop to x(t)
            H8_rate_grow,  # beta: timescale for x growth -- prop to u(t)
            kappa,  # presumably the rate applied while v(t) is on -- confirm against ode_hallmark8
        ])
     ),
    ]

# Main stimuli to simulate (stimulus_suffix values accepted by get_ode_traj_nb)
ulist = ['S0_one', 'S1']

# Results container: ode_name -> stimulus -> traces (a 'v' entry is added in the loop below)
ode_solutions_LoD = {
    ode_props[0]: {
        k: {'t': None, 'u': None, 'x': None, 'y': None} for k in ulist
    }
    for ode_props in ode_list_to_sim}

# NOTE: the loop variables (ode_t, ode_u_main, ode_u_aux, ...) stay defined after the loop
# with the values of the last iteration, and a later cell reads them.
for uidx, ustr in enumerate(ulist):

    for idx, ode_props in enumerate(ode_list_to_sim): 
        
        ode_name, ode_title, ode_params = ode_props
        print('\nworking on %s (u=%s)...' % (ode_name, ustr))
        ode_model = get_ode_traj_nb(ode_name, ode_params, ustr, additional_control_objects=[auxilliary_control_H8])
        
        # Resampling step: the model's max_step_augment when finite, else 1e4 uniform points
        if ode_model.ode_base.max_step_augment is None or np.isinf(ode_model.ode_base.max_step_augment):
            force_dt = (ode_model.history_times[-1] - ode_model.history_times[0]) / 1e4  # interpolate uniformly using N=1e4 points
        else:
            force_dt = ode_model.ode_base.max_step_augment
        ode_state, ode_t = ode_model.interpolate_trajectory(force_dt=force_dt)  # choose multiple of average timestep
        
        # Evaluate both stimuli on the resampled time grid (index 0: main u, index 1: auxiliary v)
        ode_u_main = ode_model.ode_base.control_objects[0].fn_prepped(ode_t)
        ode_u_aux = ode_model.ode_base.control_objects[1].fn_prepped(ode_t)
        
        # Same two stimuli via fn_jit, stacked column-wise as the input to output_fn
        uvals_arr = np.array([ode_model.ode_base.control_objects[0].fn_jit(ode_t),
                              ode_model.ode_base.control_objects[1].fn_jit(ode_t)]).T
        print(uvals_arr.shape)
        
        ode_solutions_LoD[ode_name][ustr]['t'] = ode_t
        ode_solutions_LoD[ode_name][ustr]['u'] = ode_u_main
        ode_solutions_LoD[ode_name][ustr]['v'] = ode_u_aux
        ode_solutions_LoD[ode_name][ustr]['x'] = ode_state
        #ode_solutions_LoD[ode_name][ustr]['y'] = ode_state[:, -1]  # TODO care, this assumes output is just the last
        ode_solutions_LoD[ode_name][ustr]['y'] = ode_model.ode_base.output_fn(ode_state, uvals_arr) #np.array(ode_u_main, ode_u_aux))
        
        # state variable
    
    #ax[idx, 0].plot(ode_t, ode_state[:, -1])

In [None]:
# Overview grid of the simulated runs: one column per main stimulus in `ulist`;
# rows are u(t), v(t), then one response row per simulated ODE.
line_input_kwargs = dict(color=color_input, linewidth=1)
line_response_kwargs = dict(color=color_response, linewidth=1)

# NOTE: the grid has exactly 2 columns, so this assumes len(ulist) <= 2
_, ax = plt.subplots(2 + len(ode_list_to_sim), 2, squeeze=False, figsize=(4, 5), sharex=True)

###ax[0, 0].plot(ode_list_to_sim, ode_u)

for uidx, ustr in enumerate(ulist):
    
    # The stimulus rows use the traces stored for the first ODE in the list
    topplot_t = ode_solutions_LoD[ode_list_to_sim[0][0]][ustr]['t']
    topplot_u = ode_solutions_LoD[ode_list_to_sim[0][0]][ustr]['u']
    ax[0, uidx].plot(topplot_t, topplot_u, **line_input_kwargs)
    
    topplot_t = ode_solutions_LoD[ode_list_to_sim[0][0]][ustr]['t']
    topplot_v = ode_solutions_LoD[ode_list_to_sim[0][0]][ustr]['v']
    ax[1, uidx].plot(topplot_t, topplot_v, **line_input_kwargs)
    
    # One response row per ODE, titled with its ode_title
    for j, ode_props in enumerate(ode_list_to_sim): 
        
        ode_name, ode_title, ode_params = ode_props
        ode_t = ode_solutions_LoD[ode_name][ustr]['t']
        ode_y = ode_solutions_LoD[ode_name][ustr]['y']
        ax[j+2, uidx].plot(ode_t, ode_y, **line_response_kwargs)
        ax[j+2, uidx].set_title(ode_title)

plt.tight_layout()
plt.show()

In [None]:
# Take every trace of the pulse-train ('S1') run from the results dict, so this cell does
# not depend on which iteration of the simulation / plotting loops happened to run last.
sol_H8_S1 = ode_solutions_LoD['ode_hallmark8']['S1']
t1 = sol_H8_S1['t']
u1 = sol_H8_S1['u']
u2 = sol_H8_S1['v']
x1 = sol_H8_S1['x']
y1 = sol_H8_S1['y']

# Placeholders for a second input; figure_stack_H8 does not take them
t2 = t1
x2 = np.zeros_like(t1)
y2 = np.zeros_like(t1)

# Render the figure in both size variants
figure_stack_H8(t1, u1, u2, x1, y1, xlims=None, hide_grids=False, size_small=True)
figure_stack_H8(t1, u1, u2, x1, y1, xlims=None, hide_grids=False, size_small=False)

# Hallmark 3 - potentiation of habituation

Idea: Hammerstein version of the Wiener-Hammerstein generalization of LTI dynamics. 
- we need to apply a static nonlinear preprocessing step to the input signal that acts like an "inverter"
- $u'=h(u)$, where $h$ is some biologically implementable inverter
- within some threshold of input $u$: high $u$ gets mapped to low $u'$, and vice versa

In [None]:
# This set works "V1"
p_threshold = 8
# Model tuple layout: (h_of_u, g_of_x_u, params_h, params_g)
model_unit_tuple_H3 = (h_of_u_H3, g_of_x_u_H3, dict(), dict(gamma=10, p_steep=1, p_threshold=p_threshold))  # threshold set by p_threshold above

# settings for hallmark 3
H3_rate_decay = 0.05
H3_rate_grow = 1

# Pulse-train parameters handed to signal_rectangle (duty, pulse_area, period_S1)
H3_duty = 0.1
H3_pulse_area = 1.0
H3_period = 1.0

In [None]:
# Hallmark 3 input: three cycles of n_pulse pulses followed by n_rest rest periods,
# then the quick-look stack with the H3 model and its threshold line.
# NOTE: this rebinds the notebook-level t1 / u1 / n_pulse used by earlier sections.
n_pulse = 15
n_rest = 15

L_cycle = n_pulse + n_rest  # cycle length counted in periods (pulses + rest)
n_cycles = 3  # 3
# Times H3_period * n for n = 1..n_pulse in every cycle (presumably the pulse times);
# not used again in this cell.
tmid_list = [H3_period * (n + L_cycle * k) for k in range(n_cycles) for n in range(1, n_pulse + 1)]

#cutoff_pulse_end = L_cycle

# prep t to sample signals from
dt = 0.001
#tmax = L_cycle * period_S1 * 4
tmax = 100 #L_cycle * n_cycles * period_S1
t1 = np.arange(-0.5, tmax + dt, dt)

_, u1 = signal_rectangle(n_pulse, n_rest, n_cycles, times=t1, 
                         period_S1=H3_period, duty=H3_duty, pulse_area=H3_pulse_area)

#u1 = np.where(t1 > cutoff_pulse_end * H6_period, 0, u1)

# fmod is used both in the title and in the output file name
fmod = 'rectangle_1'
title = r'%s $u(t)$ with nonlinear filter: $\alpha=%.2f$, $\beta=%.2f$' % (fmod, H3_rate_decay, H3_rate_grow)
fpath = NB_OUTPUT + os.sep + 'mod_uxry_stack_%s' % fmod
_ = MOD_2_plot_uxry_stack(t1, u1, H3_rate_decay, H3_rate_grow, model_unit_tuple_H3, title, fpath, x_thresh=p_threshold, xlims=(-0.5, tmax))

In [None]:
def figure_stack_H3(t1, u1, x1, y1, x_thresh=None, xlims=None):
    """
    Four-row stacked figure for Hallmark 3 (potentiation of habituation).

    Row 0 shows the input ``u1``, row 1 the memory ``x1`` (optionally with a dashed threshold
    line), row 2 only a dashed reference line at ``max(u1)``, and row 3 the response ``y1``.
    A light grey rectangle is drawn behind the first two rows.

    Args:
        t1: time axis shared by all traces
        u1: input trace
        x1: memory trace
        y1: response trace
        x_thresh: optional height of a dashed horizontal line in the memory row
        xlims: optional x-limits for the shared x axis; ``None`` keeps the autoscaled limits

    Returns:
        list of the four axes, top to bottom.

    Side effects:
        Closes all open figures, saves ``figure_stack_H3`` as .pdf and .svg under
        ``NB_OUTPUT`` and shows the figure.

    Notes:
        The layout is adapted from the Hallmark 6 figure; the rectangle geometry and the
        y-ticks of the response row are hard-coded.
    """
    plt.close('all')
    # Figure init notes
    # - use layout constrained to avoid axis labels cut off (else adjust manually...)
    SIZE_SMALL = True
    height_ratios = [0.3, 0.3, 0.66, 0.35]
    if SIZE_SMALL:
        figsize = (WIDTH_ONECOL/2, 2.8)
        hspace = 0.1
    else:
        figsize = (WIDTH_ONECOL, 3.6)
        hspace = 0.09
    
    fig = plt.figure(figsize=figsize, layout="constrained") 
    
    nrows = 4
    gs = GridSpec(nrows, 1, figure=fig, height_ratios=height_ratios, hspace=hspace)
    ax0 = fig.add_subplot(gs[0])
    ax1 = fig.add_subplot(gs[1], sharex=ax0)
    ax2 = fig.add_subplot(gs[2], sharex=ax0)
    ax3 = fig.add_subplot(gs[3], sharex=ax0)
    axarr = [ax0, ax1, ax2, ax3]
    
    umax = np.max(u1)

    # Row 0: input
    axarr[0].plot(t1, u1, 
                  label=r'$u$', **line_kwargs_input)
    axarr[0].set_ylim(-0.2 * umax, umax * 1.2)
    
    # Row 1: memory (axarr is a plain list, so it takes a single index)
    axarr[1].plot(t1, x1, 
                  label=r'$x$', **line_kwargs_memory)
    if x_thresh is not None:
        axarr[1].axhline(x_thresh, linestyle='--', c='k', linewidth=linewidth, zorder=2)
     
    # Row 2: reference line at the input amplitude
    line_kwargs_horiz_0 = dict(color='k', linewidth=0.5, linestyle='--', zorder=10, alpha=0.5)
    axarr[2].axhline(umax, **line_kwargs_horiz_0)
    
    # Row 3: response
    axarr[3].plot(t1, y1, 
                  label=r'$y$', **line_kwargs_response)
    ymax = np.max(y1)
    axarr[3].set_ylim(-0.1*ymax, ymax*1.1)
    
    axarr[3].set_yticks([0, 5, 10])
    axarr[3].set_yticklabels([0, 5, 10])
    
    # only the bottom row keeps its x tick labels
    for ax in axarr[:-1]:
        ax.tick_params(labelbottom=False)
    
    # Grey background rectangle behind the first two rows. Its height is the input
    # amplitude; the function no longer reads a global `u2` that is not one of its arguments.
    buffer = 0.5
    rect_botleft = (0 + 1.5 * buffer, 0)
    rect_width = 13 - 3 * buffer
    for ax in axarr[:2]:
        ax.add_patch(
            Rectangle(rect_botleft, rect_width, umax, color='#EEEFEF', alpha=0.8, zorder=0))

    for idx, ax in enumerate(axarr):
        if idx > 1:
            ax.grid(alpha=0.25)
        # skip the legend on rows without labelled artists (row 2) to avoid a warning
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend(fontsize=GLOBAL_FS - 1, loc='upper right')
        ax.tick_params(axis='x', labelsize=GLOBAL_TICKLABELS)
        ax.tick_params(axis='y', labelsize=GLOBAL_TICKLABELS)
    axarr[-1].set_xlabel(r'$t/T$', **GLOBAL_FPROP)

    if xlims is not None:
        axarr[-1].set_xlim(xlims)

    #plt.tight_layout()  # can't use this
    
    # dedicated file name, so the Hallmark 6 figure is not overwritten
    fpath = NB_OUTPUT + os.sep + 'figure_stack_H3'
    print('Saving figure to', fpath)
    plt.savefig(fpath + '.pdf', dpi=450)
    plt.savefig(fpath + '.svg', dpi=450)
    plt.show()

    return axarr

In [None]:
# Build the `times_to_slice` argument of plot_nchain_response_filter_Kslices(...):
# first the (start, end) time of every cycle, then the matching sample indices into t1.

cycle_tval_start_end = [
    (H3_period * (L_cycle * k), H3_period * (L_cycle * (1+k)))
    for k in range(n_cycles)
]
print(cycle_tval_start_end)

cycle_tidx_start_end = [
    tuple(np.searchsorted(t1, t_edge) for t_edge in t_start_end)
    for t_start_end in cycle_tval_start_end
]
print(cycle_tidx_start_end)

In [None]:
def plot_nchain_response_filter_Kslices(t1, u1, label1, alphas, betas, model_unit_tuple, times_to_slice, suffix=''):
    """
    Simulate a chain of units on the input ``u1`` and produce three figures:

    1. peak values of the last unit's output within each index slice (one curve per slice),
    2. time series of input, last unit's state and last unit's output, with grey rectangles
       marking part of each slice,
    3. the static nonlinearity ``g_of_x_u`` for two threshold values.

    Args:
        t1: time samples
        u1: input signal sampled on ``t1``
        label1: unused
        alphas, betas: per-unit rates passed to ``series_form_of_unit`` (same length)
        model_unit_tuple: has form (h_of_u, g_of_x_u, params_h, params_g);
            ``params_g`` must contain ``'p_threshold'``
        times_to_slice: list of (start_index, end_index) pairs into ``t1``
        suffix: appended to the output file names

    Returns:
        (arr_unit_x, arr_unit_output) as returned by ``series_form_of_unit`` (swapped order).

    Side effects:
        Saves ``fig_H3_Kslices``, ``fig_H3_timeseries`` and ``fig_H3_nonlin_tanh`` (each with
        ``suffix``) as .png and .svg under ``NB_OUTPUT``.

    Notes:
        NOTE(review): the rectangle offset in the second figure reads the notebook-level
        global ``n_pulse`` - confirm it matches the input that is passed in.
    """
    _, g_of_x_u, _, params_g = model_unit_tuple
    
    arr_unit_output, arr_unit_x = series_form_of_unit(t1, u1, len(alphas), alphas, betas, *model_unit_tuple)
        
    # Output of the last unit within each slice, and the indices of its local maxima
    list_of_y_to_plot = [0] * len(times_to_slice)
    list_of_peaks = [0] * len(times_to_slice)
    
    for idx, times_slice_pair in enumerate(times_to_slice):
        t_idx_start, t_idx_end = times_slice_pair
        list_of_y_to_plot[idx] = arr_unit_output[t_idx_start : t_idx_end, -1]
        peaks_slice, _ = find_peaks(list_of_y_to_plot[idx])
        list_of_peaks[idx] = peaks_slice

    # First plot: peak values per slice ----------------------------------------
    plt.close('all')
    plt.figure(figsize=(2.8, 2))
    
    ls_kwargs=dict(
        markersize=3,
        marker='o',
        linestyle='-',
        linewidth=1,
    )
    
    # compact summary instead of dumping the full output arrays
    print('peaks found per slice:', [len(peaks) for peaks in list_of_peaks])
    # one colormap value per slice, reversed so that the first slice gets the largest value
    vals_for_cmap = [(k)/ (1 + len(list_of_peaks)) for k in range(len(list_of_peaks))][::-1]
    
    for k in range(len(list_of_peaks)):
        peaks = list_of_peaks[k]
        y_to_plot = list_of_y_to_plot[k]
        val_for_cmap = vals_for_cmap[k]
        
        plt.plot(np.arange(len(peaks)), y_to_plot[peaks], c=cmap_lingseg_purples(val_for_cmap), label=r'$K=%d$' % k, **ls_kwargs)
    
    plt.title(r'nchain=%d, alphas=%s, betas=%s' % (len(alphas), alphas, betas), fontsize=8)
    plt.xlabel(r'pulse index $k$ (in $K^{th}$ cycle)')
    plt.ylabel(r'peak values $y[k]$')
    plt.legend()
    plt.grid(alpha=0.4)
    
    fpath = NB_OUTPUT + os.sep + 'fig_H3_Kslices%s' % suffix
    plt.savefig(fpath + '.png')
    plt.savefig(fpath + '.svg')
    plt.show()
    
    # Second plot -----------------------------------------------------------
    plt.close('all')
    nrows = 3
    fig = plt.figure(figsize=(5, 2.4))
    gs = GridSpec(nrows, 1, figure=fig, height_ratios=[1,2,2], hspace=0.1)
    ax0 = fig.add_subplot(gs[0])
    ax1 = fig.add_subplot(gs[1], sharex=ax0)
    ax2 = fig.add_subplot(gs[2], sharex=ax0)
    axarr = [ax0, ax1, ax2]
    
    plt.suptitle('Potentiation of Habituation')
    axarr[0].plot(t1, u1, label='$u$', c=color_input, linewidth=1)
    axarr[1].plot(t1, arr_unit_x[:, -1], label='$x$', c=color_memory, linewidth=1)
    axarr[2].plot(t1, arr_unit_output[:, -1], label='$y$', c=color_response, linewidth=1)
    
    umax = u1.max()
    axarr[0].set_ylim(-0.1 * umax, 1.2 * umax)
    
    xmax = arr_unit_x[:, -1].max()
    axarr[1].set_ylim(-0.01 * xmax, 1.1 * xmax)
    
    ymax = arr_unit_output[:, -1].max()
    axarr[2].set_ylim(-0.01 * ymax, 1.2 * ymax)
    
    axarr[2].set_xlabel(r'$t/T$')
    
    # threshold of the nonlinearity, drawn in the state row
    axarr[1].axhline(params_g['p_threshold'], linestyle='--', c='k', linewidth=linewidth, zorder=2)

    for idx in range(nrows):
        axarr[idx].legend()
        
    # plot rectangles
    buffer = 0.9
    last_tidx = len(t1) - 1
    for idx, times_slice_pair in enumerate(times_to_slice):
        t_idx_start, t_idx_end = times_slice_pair
        rect_botleft = (t1[t_idx_start] + 1.5*buffer + n_pulse, 0) 
        # an end index equal to len(t1) (slice runs to the last sample) is clamped so that
        # the lookup cannot raise IndexError
        rect_width = t1[min(t_idx_end, last_tidx)] - rect_botleft[0] - buffer
        
        print(idx, rect_botleft, rect_width)
        umax = np.max(u1)
        axarr[0].add_patch(
            Rectangle(rect_botleft, rect_width, umax, color='#EEEFEF', alpha=0.8, zorder=0)
        )
        
        axarr[1].add_patch(
            Rectangle(rect_botleft, rect_width, 5, color='#EEEFEF', alpha=0.8, zorder=0)
        )
        
        axarr[2].add_patch(
            Rectangle(rect_botleft, rect_width, 5, color='#EEEFEF', alpha=0.8, zorder=0)
        )
    
    fpath = NB_OUTPUT + os.sep + 'fig_H3_timeseries%s' % suffix
    plt.savefig(fpath + '.png')
    plt.savefig(fpath + '.svg')
    plt.show()    
    
    # Third plot --------------------------------------------
    plt.close('all')
    
    xx = np.linspace(-5, 30, 1000)
    amp = np.max(u1)
    
    plt.figure(figsize=(3, 2))
    # NOTE(review): these values are hard-coded and independent of params_g
    p_steep = 0.2
    p_threshold = 8
    # each label is derived from the threshold actually plotted, so curve and label cannot
    # disagree (the second curve used to be labelled theta = 4 while using threshold 0)
    for theta, line_fmt in ((p_threshold, '-g'), (0, '--g')):
        plt.plot(xx, 
                 g_of_x_u(xx, amp, p_threshold=theta, p_steep=p_steep)/amp, 
                 line_fmt, label=r'$\theta = %g$' % theta)
    
    plt.xlabel(r'$x$')
    plt.ylabel(r'$g(x,A)$')
    plt.xlim(-1, xx.max())
    
    plt.axhline(0, linestyle='-', color='k', linewidth=0.5)
    plt.axvline(0, linestyle='-', color='k', linewidth=0.5)
    plt.axvline(p_threshold, linestyle='--', color='k', linewidth=0.5)
    
    fpath = NB_OUTPUT + os.sep + 'fig_H3_nonlin_tanh%s' % suffix
    plt.savefig(fpath + '.png')
    plt.savefig(fpath + '.svg')
    plt.show()
    
    return arr_unit_x, arr_unit_output


# Single-unit chain with the Hallmark 3 rates
alphas = [H3_rate_decay]
betas = [H3_rate_grow] #[0.2]
assert len(alphas) == len(betas)

# appended to the file names of all three saved figures
suffix = '_K1_checkH3'

# One slice per stimulation cycle (cycle_tidx_start_end); the label argument is a placeholder
arr_unit_x, arr_unit_output = plot_nchain_response_filter_Kslices(t1, u1, r'LABEL?', 
                                                                  alphas, betas, model_unit_tuple_H3, 
                                                                  cycle_tidx_start_end, suffix=suffix)


## [Separate] Constructing Staddon-inspired outputs like $\phi = 1-\sum_i x_i$
### Address Hallmark 4 part (iii) (recovery faster for higher freq.)

In [None]:
# Shared cycle layout for the pulse trains built in the following cells
n_pulse, n_rest = 60, 40
L_cycle = n_pulse + n_rest  # cycle length counted in periods (pulses + rest)
n_cycles = 1  # 3

In [None]:
# Pulse train A: reference train with period 1.0, built with signal_rectangle on t_A
period_S1_A = 1.0 # period between consecutive pulses

n_pulse_A = n_pulse
n_rest_A = n_rest
L_cycle_A = n_pulse_A + n_rest_A  # cycle length counted in periods (pulses + rest)

# Times period_S1_A * n for n = 1..n_pulse_A in every cycle (presumably the pulse times);
# not used again in this cell.
tmid_list = [period_S1_A * (n + L_cycle_A * k) for k in range(n_cycles) for n in range(1, n_pulse_A+1)]

# prep t to sample signals from
dt = 0.001

duty_A = 0.1
amp_A = 1/(period_S1_A * duty_A)  # amplitude for which duty * period * amplitude = 1

# choose pulse_area
# (FLAG_PULSE_AREA_FIXED is defined outside this cell; with amp_A as above both branches
#  evaluate to a pulse area of 1, up to floating-point rounding)
if FLAG_PULSE_AREA_FIXED:
    pulse_area_A = 1.0  
else:
    pulse_area_A = duty_A * period_S1_A * amp_A

#tmax = L_cycle * period_S1 * 4
tmax_A = L_cycle_A * n_cycles * period_S1_A
t_A = np.arange(-0.01, tmax_A + dt, dt)

# build u: wave of rectangles
_, u_rect_A = signal_rectangle(n_pulse_A, n_rest_A, n_cycles, period_S1=period_S1_A, times=t_A, duty=duty_A, pulse_area=pulse_area_A)    
#_, u_rect_wide_A = signal_rectangle(n_pulse, n_rest, n_cycles, period_S1=period_S1_A, times=t_A, duty=0.25, pulse_area=pulse_area_A)

In [None]:
# --- Stimulus B: rectangle pulse train with period T = 2.0 -------------------
dt = 0.001          # step of the time grid the signals are sampled on
period_S1_B = 2.0   # period between consecutive pulses

# Pulse count is the stimulus-A count divided by this period (truncated to int);
# the rest count is the shared one.
n_pulse_B = int(n_pulse_A / period_S1_B)
n_rest_B = n_rest
L_cycle_B = n_pulse_B + n_rest_B

# Times period * (n + L_cycle * k) for pulse n = 1..n_pulse_B in each cycle k.
tmid_list = [
    period_S1_B * (n + L_cycle_B * k)
    for k in range(n_cycles)
    for n in range(1, n_pulse_B+1)
]

# Choose duty and pulse area.
if not FLAG_PULSE_AREA_FIXED:
    # Same duty as A; area is A * d * T with A and d fixed and T this stimulus' period.
    duty_B = duty_A
    pulse_area_B = duty_B * period_S1_B * amp_A
else:
    # Unit area; duty rescaled so that duty * period equals that of stimulus A.
    pulse_area_B = 1.0
    duty_B = duty_A * period_S1_A / period_S1_B

# Time grid starts slightly before t = 0 and runs through all cycles.
#tmax = L_cycle * period_S1 * 4
tmax_B = L_cycle_B * n_cycles * period_S1_B
t_B = np.arange(-0.01, tmax_B + dt, dt)

# Build u(t) as a wave of rectangles; the first returned value is not needed here.
_, u_rect_B = signal_rectangle(
    n_pulse_B, n_rest_B, n_cycles,
    period_S1=period_S1_B, times=t_B, duty=duty_B, pulse_area=pulse_area_B,
)
#_, u_rect_wide_B = signal_rectangle(n_pulse, n_rest, n_cycles, period_S1=period_S1_B, times=t_B, duty=0.25, pulse_area=pulse_area_B)

In [None]:
# --- Stimulus C: rectangle pulse train with period T = 5.0 -------------------
period_S1_C = 5.0 # period between consecutive pulses

# Derive the pulse count from n_pulse_A, the same reference the T_B stimulus
# uses, so the stimuli cannot drift apart if n_pulse_A is edited on its own.
# NOTE(review): dividing by the period alone presumably matches the pulse-train
# duration of stimulus A only while period_S1_A == 1.0 -- confirm.
n_pulse_C = int(n_pulse_A / period_S1_C)
n_rest_C = n_rest
L_cycle_C = n_pulse_C + n_rest_C

# Times period * (n + L_cycle * k) for pulse n = 1..n_pulse_C in each cycle k.
tmid_list = [period_S1_C * (n + L_cycle_C * k) for k in range(n_cycles) for n in range(1, n_pulse_C+1)]

# prep t to sample signals from
dt = 0.001
# choose pulse_area
if FLAG_PULSE_AREA_FIXED:
    # Unit area; duty rescaled so that duty * period equals that of stimulus A.
    pulse_area_C = 1.0
    duty_C = duty_A * period_S1_A / period_S1_C
else:
    # Same duty as A; area is A * d * T with A and d fixed and T this stimulus' period.
    duty_C = duty_A
    pulse_area_C = duty_C * period_S1_C * amp_A

# Time grid starts slightly before t = 0 and runs through all cycles.
#tmax = L_cycle * period_S1 * 4
tmax_C = L_cycle_C * n_cycles * period_S1_C
t_C = np.arange(-0.01, tmax_C + dt, dt)

# build u: wave of rectangles (the first returned value is not needed here)
_, u_rect_C = signal_rectangle(n_pulse_C, n_rest_C, n_cycles, period_S1=period_S1_C, times=t_C, duty=duty_C, pulse_area=pulse_area_C)

In [None]:
# Select the output nonlinearity for the Staddon-like model and the matching
# (alpha, beta) lists; the notes on h and g are the notebook author's.
FLAG_STADDON_VIA_SIGMA = True

if not FLAG_STADDON_VIA_SIGMA:
    # - input: h(u) = u
    # - output: g(x,u) = ReLU(u-x)
    # For g = ReLU, use
    model_unit_tuple_STADDON = (h_of_u_H4, g_of_x_u_H6, dict(), dict(p_threshold=0))
    alphas = [0.25, 0.005]
    betas = [4, 1]
else:
    # - input: h(u) = u
    # - output: g(x,u) = u * sigma(x, N)
    # For g = hill AKA sigma, use
    model_unit_tuple_STADDON = (h_of_u_H4, g_of_x_u_H4, dict(), dict(N=2))
    alphas = [0.2, 0.02]
    betas = [1, 0.5]

assert len(betas) == len(alphas)

# Both runs receive the same three stimuli: (time grid, signal, label) for A, B, C.
staddon_stimuli = (
    t_A, u_rect_A, r'$T=%.2f$' % period_S1_A,
    t_B, u_rect_B, r'$T=%.2f$' % period_S1_B,
    t_C, u_rect_C, r'$T=%.2f$' % period_S1_C,
)

# Run 1: only the first (alpha, beta) pair.
suffix = '_K1_checkH4'
n1_outs_freq_sens = plot_nchain_response_filter_tuptupab(
    *staddon_stimuli,
    [alphas[0]], [betas[0]], model_unit_tuple_STADDON,
    suffix=suffix,
)
# Each result holds, per stimulus (1, 2, 3): output array, x array, peaks, y at peaks.
(n1_arr_output_1, n1_arr_x_1, n1_peaks_1, n1_y_peaks_1,
 n1_arr_output_2, n1_arr_x_2, n1_peaks_2, n1_y_peaks_2,
 n1_arr_output_3, n1_arr_x_3, n1_peaks_3, n1_y_peaks_3) = n1_outs_freq_sens

# Run 2: the full alphas/betas lists; the suffix records how many pairs were used.
suffix = '_K%d_checkH4' % len(alphas)
nMulti_outs_freq_sens = plot_nchain_response_filter_tuptupab(
    *staddon_stimuli,
    alphas, betas, model_unit_tuple_STADDON,
    suffix=suffix,
)
(nMulti_arr_output_1, nMulti_arr_x_1, nMulti_peaks_1, nMulti_y_peaks_1,
 nMulti_arr_output_2, nMulti_arr_x_2, nMulti_peaks_2, nMulti_y_peaks_2,
 nMulti_arr_output_3, nMulti_arr_x_3, nMulti_peaks_3, nMulti_y_peaks_3) = nMulti_outs_freq_sens

In [None]:
# Purple shades from dark to light; indexed below in the order T_A, T_B, T_C.
color_1 = '#663290'           # regular y(t) purp OR 'darkblue'
color_2 = '#8965AC'           # medium purp       OR 'royalblue'
color_3 = '#B19CC9'           # light purp        OR 'cornflowerblue'

colors_purp = [color_1, color_2, color_3]

# Two stacked panels sharing the time axis:
#   top    - run made with only the first (alpha, beta) pair  -> A - x
#   bottom - run made with the full alphas/betas lists        -> A - sum_i x_i
fig, axarr = plt.subplots(2, 1, sharex=True, figsize=(8, 5), squeeze=False)

ms = 0.5

# Staddon-like outputs. The n1 run uses column 0 of its x arrays; the nMulti
# run sums its x arrays over axis 1 (presumably one column per unit, since
# shape[1] is what the panel titles report as K).
n1_staddon_out_Ta = amp_A - n1_arr_x_1[:, 0]
n1_staddon_out_Tb = amp_A - n1_arr_x_2[:, 0]
n1_staddon_out_Tc = amp_A - n1_arr_x_3[:, 0]

nMulti_staddon_out_Ta = amp_A - np.sum(nMulti_arr_x_1, axis=1)
nMulti_staddon_out_Tb = amp_A - np.sum(nMulti_arr_x_2, axis=1)
nMulti_staddon_out_Tc = amp_A - np.sum(nMulti_arr_x_3, axis=1)

# Per-stimulus data shared by both panels (order: A, B, C).
stad_times = (t_A, t_B, t_C)
stad_periods = (period_S1_A, period_S1_B, period_S1_C)
stad_peak_markers = ('s', '^', 'o')

# One entry per panel: (axes, outputs per stimulus, peak indices per stimulus, K, y-label)
stad_panels = (
    (axarr[0, 0],
     (n1_staddon_out_Ta, n1_staddon_out_Tb, n1_staddon_out_Tc),
     (n1_peaks_1, n1_peaks_2, n1_peaks_3),
     n1_arr_x_1.shape[1], r'$A - x$'),
    (axarr[1, 0],
     (nMulti_staddon_out_Ta, nMulti_staddon_out_Tb, nMulti_staddon_out_Tc),
     (nMulti_peaks_1, nMulti_peaks_2, nMulti_peaks_3),
     nMulti_arr_x_1.shape[1], r'$A - \sum x_i$'),
)

# Sanity check kept from the original cell; it is loop-invariant, so run it once.
assert period_S1_A == 1.0

for stad_ax, stad_outs, stad_peaks, stad_K, stad_ylabel in stad_panels:
    # Continuous traces. %g prints integer-valued periods as before ("T=1") but
    # does not truncate a non-integer period the way %d would.
    for stad_t, stad_y, stad_T, stad_c in zip(stad_times, stad_outs, stad_periods, colors_purp):
        stad_ax.plot(stad_t, stad_y, '-', c=stad_c, markersize=ms,
                     label=r'$T=%g$' % stad_T)

    # Peak markers. The marker is given by keyword only: combining it with an
    # 'o' fmt string makes matplotlib warn about a redundantly defined marker.
    for stad_t, stad_y, stad_pk, stad_c, stad_mk in zip(
            stad_times, stad_outs, stad_peaks, colors_purp, stad_peak_markers):
        stad_ax.plot(stad_t[stad_pk], stad_y[stad_pk],
                     linestyle='', c=stad_c, markersize=4, marker=stad_mk)

    stad_ax.set_title(r'$K = %d$' % stad_K)
    stad_ax.set_ylabel(stad_ylabel)
    stad_ax.legend()

    # Guides: zero line, amplitude line, and a vertical line at (n_pulse_A + 0.25) periods.
    stad_ax.axhline(0, linestyle='-', color='grey', linewidth=1, zorder=1)
    stad_ax.axhline(amp_A, linestyle='--', color='grey', linewidth=1, zorder=1)
    stad_ax.axvline((n_pulse_A+0.25)*period_S1_A, linestyle='--', color='k', zorder=10)

axarr[-1,0].set_xlabel(r'$t$')
plt.xlim(-1,100)

fpath = NB_OUTPUT + os.sep + 'staddon-like_H4iii_K1_vs_Kmulti_A-SumX'
plt.savefig(fpath+'.png')
plt.savefig(fpath+'.svg')
plt.show()

# =====================================================================================================

# Same Staddon-like outputs as the previous figure, now plotted against t/T so
# the x-axis is in units of each stimulus' own period.
fig, axarr = plt.subplots(2, 1, sharex=True, figsize=(10,5), squeeze=False)

# Per-stimulus data shared by both panels (order: A, B, C).
tdisc_times = (t_A, t_B, t_C)
tdisc_periods = (period_S1_A, period_S1_B, period_S1_C)

# One entry per panel: (axes, outputs per stimulus, K, y-label)
tdisc_panels = (
    (axarr[0, 0],
     (n1_staddon_out_Ta, n1_staddon_out_Tb, n1_staddon_out_Tc),
     n1_arr_x_1.shape[1], r'$A - x$'),
    (axarr[1, 0],
     (nMulti_staddon_out_Ta, nMulti_staddon_out_Tb, nMulti_staddon_out_Tc),
     nMulti_arr_x_1.shape[1], r'$A - \sum x_i$'),
)

for tdisc_ax, tdisc_outs, tdisc_K, tdisc_ylabel in tdisc_panels:
    # One legend format for both panels (previously '%.2f' on top and '%d' below);
    # %g shows integer-valued periods as "T=1" and keeps non-integer ones intact.
    for tdisc_t, tdisc_y, tdisc_T, tdisc_c in zip(tdisc_times, tdisc_outs, tdisc_periods, colors_purp):
        tdisc_ax.plot(tdisc_t / tdisc_T, tdisc_y, '-', c=tdisc_c, markersize=ms,
                      label=r'$T=%g$' % tdisc_T)
    tdisc_ax.set_title(r'$K = %d$' % tdisc_K)
    tdisc_ax.set_ylabel(tdisc_ylabel)
    tdisc_ax.axhline(0, linestyle='--', color='grey')
    tdisc_ax.legend()

axarr[-1,0].set_xlabel(r'$t/T$ (# pulses)')

fpath = NB_OUTPUT + os.sep + 'staddon-like_H4iii_K1_vs_Kmulti_A-SumX_tDiscrete'
plt.savefig(fpath + '.png')
plt.savefig(fpath + '.svg')
plt.show()

### Reflex V2 $y_{pre} = \prod \sigma(x_i)$ - Test again with slightly modified "Reflex function"

In [None]:
# Two stacked panels sharing the time axis:
#   top    - run made with only the first (alpha, beta) pair  -> A * sigma(x)
#   bottom - run made with the full alphas/betas lists        -> A * prod_i sigma(x_i)
fig, axarr = plt.subplots(2, 1, sharex=True, figsize=(8, 5), squeeze=False)

ms = 0.5

# "Reflex V2" outputs: amplitude times the product over axis 1 of
# sigma_of_x_main(x, N=2) (presumably one column per unit, since shape[1] is
# what the panel titles report as K).
n1_staddon_out_Ta = amp_A * np.prod(sigma_of_x_main(n1_arr_x_1, N=2), axis=1)
n1_staddon_out_Tb = amp_A * np.prod(sigma_of_x_main(n1_arr_x_2, N=2), axis=1)
n1_staddon_out_Tc = amp_A * np.prod(sigma_of_x_main(n1_arr_x_3, N=2), axis=1)

nMulti_staddon_out_Ta = amp_A * np.prod(sigma_of_x_main(nMulti_arr_x_1, N=2), axis=1)
nMulti_staddon_out_Tb = amp_A * np.prod(sigma_of_x_main(nMulti_arr_x_2, N=2), axis=1)
nMulti_staddon_out_Tc = amp_A * np.prod(sigma_of_x_main(nMulti_arr_x_3, N=2), axis=1)

# Per-stimulus data shared by both panels (order: A, B, C).
prod_times = (t_A, t_B, t_C)
prod_periods = (period_S1_A, period_S1_B, period_S1_C)
prod_peak_markers = ('s', '^', 'o')

# One entry per panel: (axes, outputs per stimulus, peak indices per stimulus, K, y-label)
prod_panels = (
    (axarr[0, 0],
     (n1_staddon_out_Ta, n1_staddon_out_Tb, n1_staddon_out_Tc),
     (n1_peaks_1, n1_peaks_2, n1_peaks_3),
     n1_arr_x_1.shape[1], r'$A \sigma(x)$'),
    (axarr[1, 0],
     (nMulti_staddon_out_Ta, nMulti_staddon_out_Tb, nMulti_staddon_out_Tc),
     (nMulti_peaks_1, nMulti_peaks_2, nMulti_peaks_3),
     nMulti_arr_x_1.shape[1], r'$A \prod \sigma(x_i)$'),
)

# Sanity check kept from the original cell; it is loop-invariant, so run it once.
assert period_S1_A == 1.0

for prod_ax, prod_outs, prod_peaks, prod_K, prod_ylabel in prod_panels:
    # Continuous traces. %g prints integer-valued periods as before ("T=1") but
    # does not truncate a non-integer period the way %d would.
    for prod_t, prod_y, prod_T, prod_c in zip(prod_times, prod_outs, prod_periods, colors_purp):
        prod_ax.plot(prod_t, prod_y, '-', c=prod_c, markersize=ms,
                     label=r'$T=%g$' % prod_T)

    # Peak markers. The marker is given by keyword only: combining it with an
    # 'o' fmt string makes matplotlib warn about a redundantly defined marker.
    for prod_t, prod_y, prod_pk, prod_c, prod_mk in zip(
            prod_times, prod_outs, prod_peaks, colors_purp, prod_peak_markers):
        prod_ax.plot(prod_t[prod_pk], prod_y[prod_pk],
                     linestyle='', c=prod_c, markersize=4, marker=prod_mk)

    prod_ax.set_title(r'$K = %d$' % prod_K)
    prod_ax.set_ylabel(prod_ylabel)
    prod_ax.legend()

    # Guides: zero line, amplitude line, and a vertical line at (n_pulse_A + 0.25) periods.
    prod_ax.axhline(0, linestyle='-', color='grey', linewidth=1, zorder=1)
    prod_ax.axhline(amp_A, linestyle='--', color='grey', linewidth=1, zorder=1)
    prod_ax.axvline((n_pulse_A+0.25)*period_S1_A, linestyle='--', color='k', zorder=10)

axarr[-1,0].set_xlabel(r'$t$')
plt.xlim(-1,100)

fpath = NB_OUTPUT + os.sep + 'staddon-like_H4iii_K1_vs_Kmulti_ProdSigma'
plt.savefig(fpath+'.png')
plt.savefig(fpath+'.svg')
plt.show()

# =====================================================================================================

# Same product-of-sigma outputs as the previous figure, now plotted against t/T
# so the x-axis is in units of each stimulus' own period.
fig, axarr = plt.subplots(2, 1, sharex=True, figsize=(10,5), squeeze=False)

# Per-stimulus data shared by both panels (order: A, B, C).
prod_tdisc_times = (t_A, t_B, t_C)
prod_tdisc_periods = (period_S1_A, period_S1_B, period_S1_C)

# One entry per panel: (axes, outputs per stimulus, K, y-label)
prod_tdisc_panels = (
    (axarr[0, 0],
     (n1_staddon_out_Ta, n1_staddon_out_Tb, n1_staddon_out_Tc),
     n1_arr_x_1.shape[1], r'$A \sigma(x)$'),
    (axarr[1, 0],
     (nMulti_staddon_out_Ta, nMulti_staddon_out_Tb, nMulti_staddon_out_Tc),
     nMulti_arr_x_1.shape[1], r'$A \prod \sigma(x_i)$'),
)

for prod_tdisc_ax, prod_tdisc_outs, prod_tdisc_K, prod_tdisc_ylabel in prod_tdisc_panels:
    # One legend format for both panels (previously '%.2f' on top and '%d' below);
    # %g shows integer-valued periods as "T=1" and keeps non-integer ones intact.
    for prod_tdisc_t, prod_tdisc_y, prod_tdisc_T, prod_tdisc_c in zip(
            prod_tdisc_times, prod_tdisc_outs, prod_tdisc_periods, colors_purp):
        prod_tdisc_ax.plot(prod_tdisc_t / prod_tdisc_T, prod_tdisc_y, '-',
                           c=prod_tdisc_c, markersize=ms,
                           label=r'$T=%g$' % prod_tdisc_T)
    prod_tdisc_ax.set_title(r'$K = %d$' % prod_tdisc_K)
    prod_tdisc_ax.set_ylabel(prod_tdisc_ylabel)
    prod_tdisc_ax.axhline(0, linestyle='--', color='grey')
    prod_tdisc_ax.legend()

axarr[-1,0].set_xlabel(r'$t/T$ (# pulses)')

fpath = NB_OUTPUT + os.sep + 'staddon-like_H4iii_K1_vs_Kmulti_ProdSigma_tDiscrete'
plt.savefig(fpath + '.png')
plt.savefig(fpath + '.svg')
plt.show()