In [None]:
import copy
import os
import numpy as np
import sys
import time
import types
from scipy import signal
from scipy.interpolate import CubicSpline
from scipy.stats import norm  # for u(t) as gaussians
from scipy.integrate import solve_ivp

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

In [None]:
# Matplotlib backend: interactive ipympl figures (requires the ipympl package).
# `widget` is an alias for the same backend; `inline` would give static images.
#%matplotlib widget
%matplotlib ipympl
#%matplotlib inline

In [None]:
# Serif text everywhere, with mathtext rendered in the matching DejaVu serif.
plt.rcParams.update({
    "font.family": "serif",
    "mathtext.fontset": "dejavuserif",
})

# Notebook setup (path trick) and local import

In [None]:
# SRC_ROOT is the parent of the notebook's working directory and PACKAGE_ROOT
# its grandparent; both go on sys.path so local packages (e.g. `src`) import.
# NOTE(review): layout inferred from the dirname() calls - confirm in the repo.
SRC_ROOT = os.path.dirname(os.path.abspath(''))
PACKAGE_ROOT = os.path.dirname(SRC_ROOT)

# Guard each append so re-running this cell does not keep growing sys.path.
print('appending to path SRC_ROOT...', SRC_ROOT)
if SRC_ROOT not in sys.path:
    sys.path.append(SRC_ROOT)

print('appending to path PACKAGE_ROOT...', PACKAGE_ROOT)
if PACKAGE_ROOT not in sys.path:
    sys.path.append(PACKAGE_ROOT)

# Output directory for this notebook; exist_ok keeps the cell idempotent and
# avoids the exists()/makedirs() race.
NB_OUTPUT = os.path.join(SRC_ROOT, 'output')
os.makedirs(NB_OUTPUT, exist_ok=True)

In [None]:
from src.defined_ode_fn import *

# Recreate Fig. 1 of the discrete model of Staddon & Higa (1996)

In [None]:
Tfast = 2
Tslow = 8

u_Tfast = [1 if ((a % Tfast == 0) and a < 100) else 0 for a in range(200)]
u_Tslow = [1 if ((a % Tslow == 0) and a < 100) else 0 for a in range(200)]

a1 = 0.8
a2 = 0.95
a3 = 0.99

b1, b2, b3 = 0.2, 0.2, 0.2

def simple_relu(z):
    return np.where(z > 0, z, 0)

def sim_staddon1996_response(u):
    """
    Staddon, Higa 1996 Fig. 1:  Feedforward, 3 units
        a_1=0.8, a_2=0.95, a_3=0.99; b_k=0.2 for all
        ISI=2 and ISI=8 mean in our language T=3 and T=9 ("units" of time)
    
    Here, the input u goes into the first unit only
    """
    nn = len(u)
    
    arr_x = np.zeros((nn, 3))
    arr_y_strength = np.zeros((nn, 3))
    arr_y_out = np.zeros((nn, 3))
    final_output = np.zeros(nn)
    
    # init cond 
    arr_x[0, :] = [0, 0, 0]
    arr_y_strength[0, :] = [u[0], u[0], u[0]]
    arr_y_out[0, :] = simple_relu(arr_y_strength[0, :])
    
    for i in range(1, nn):
        
        # BLOCK 1
        # =================================================
        arr_x[i, 0] = a1 * arr_x[i-1, 0] + b1 * u[i-1]
        
        # block 1 output ----> block 2 input
        arr_y_strength[i, 0] = u[i-1] - arr_x[i, 0]  # TODO note that this u[i-1], previous timestep, to recreate Staddon-like-plot
        arr_y_out[i, 0] = simple_relu(arr_y_strength[i, 0])
        
        # BLOCK 2
        # =================================================
        arr_x[i, 1] = a2 * arr_x[i-1, 1] + b2 * arr_y_out[i, 0]
        
        # block 2 output ----> block 3 input
        arr_y_strength[i, 1] = arr_y_out[i, 0] - arr_x[i, 1]
        arr_y_out[i, 1] = simple_relu(arr_y_strength[i, 1])

        # BLOCK 3
        # =================================================
        arr_x[i, 2] = a3 * arr_x[i-1, 2] + b3 * arr_y_out[i, 1]        
        
        # block 3 output ----> system output
        arr_y_strength[i, 2] = arr_y_out[i, 1] - arr_x[i, 2]
        arr_y_out[i, 2] = simple_relu(arr_y_strength[i, 2])
        
    final_output[:] = arr_y_out[:, 2]
    
    return final_output, arr_x, arr_y_strength, arr_y_out


def sim_staddon1996_guessing_response(u):
    """
    Staddon, Higa 1996 Fig. 1:  Feedforward, 3 units
        a_1=0.8, a_2=0.95, a_3=0.99; b_k=0.2 for all
        ISI=2 and ISI=8 mean in our language T=3 and T=9 ("units" of time)
        
    Here, the input u goes into all the units...
    """
    nn = len(u)
    
    arr_x = np.zeros((nn, 3))
    arr_y_strength = np.zeros((nn, 3))
    arr_y_out = np.zeros((nn, 3))
    final_output = np.zeros(nn)
    
    # init cond 
    arr_x[0, :] = [0, 0, 0]
    arr_y_strength[0, :] = [u[0], u[0], u[0]]
    arr_y_out[0, :] = simple_relu(arr_y_strength[0, :])
    
    for i in range(1, nn):
        
        # BLOCK 1
        # =================================================
        arr_x[i, 0] = a1 * arr_x[i-1, 0] + b1 * u[i-1]
        
        # block 1 output ----> block 2 input
        arr_y_strength[i, 0] = u[i-1] - arr_x[i, 0]
        arr_y_out[i, 0] = simple_relu(arr_y_strength[i, 0])
        
        # BLOCK 2
        # =================================================
        arr_x[i, 1] = a2 * arr_x[i-1, 1] + b2 * arr_y_out[i, 0]
        
        # block 2 output ----> block 3 input
        arr_y_strength[i, 1] = arr_y_out[i, 0] - arr_x[i, 1]
        arr_y_out[i, 1] = simple_relu(arr_y_strength[i, 1])

        # BLOCK 3
        # =================================================
        arr_x[i, 2] = a3 * arr_x[i-1, 2] + b3 * arr_y_out[i, 1]        
        
        # block 3 output ----> system output
        arr_y_strength[i, 2] = arr_y_out[i, 1] - arr_x[i, 2]
        arr_y_out[i, 2] = simple_relu(arr_y_strength[i, 2])
        
    final_output[:] = arr_y_out[:, 2]
    
    return final_output, arr_x, arr_y_strength, arr_y_out


def plot_staddon_chain(u, x, y, yrelu):
    nn = len(u)
    t_discrete = np.arange(nn)
    
    n_chain = x.shape[1]
    
    fig, axarr = plt.subplots(5, n_chain+1, sharey=True, figsize=(10,5))
    axarr[0,0].plot(t_discrete, u, c='blue')
    for idx in range(n_chain):
        axarr[0, idx+1].set_title('Unit %d' % (idx+1))
        axarr[0, idx+1].plot(t_discrete, x[:, idx])
        axarr[0, idx+1].set_ylabel(r'state $x$')
        axarr[1, idx+1].plot(t_discrete, y[:, idx])
        axarr[1, idx+1].set_ylabel(r'$u - x$')
        axarr[2, idx+1].plot(t_discrete, yrelu[:, idx])
        axarr[2, idx+1].set_ylabel(r'ReLu$(u - x)$')
        axarr[3, idx+1].plot(t_discrete, 1 - x[:, idx])
        axarr[3, idx+1].set_ylabel(r'$1 - x_k$')
        axarr[4, idx+1].plot(t_discrete, 1 - np.sum(x[:, 0:idx+1], axis=1))
        axarr[4, idx+1].set_ylabel(r'RS hypo: $1 - \Sigma x_k$')
        
    plt.tight_layout()
    return 

def plot_output_variations(u, x, y, yrelu):
    
    nn = len(u)
    t_discrete = np.arange(nn)
    
    n_chain = x.shape[1]
    assert n_chain == 3
    
    plt.figure(figsize=(6,5))
    plt.axhline(0)
    
    hypothetical_RS = simple_relu(
            simple_relu( 
                simple_relu(1 - x[:, 0]) 
                - x[:, 1])
            - x[:, 2])

    plt.plot(t_discrete, yrelu[:, -1], label='output of last unit')
    
    plt.plot(t_discrete, 1 - np.sum(x, axis=1), label='RS from paper')
    plt.axhline(0, linestyle='--')
    plt.axhline(-0.5, linestyle=':', c='k', alpha=0.5)

    plt.plot(t_discrete, hypothetical_RS, label='Instantaneous RS of u=1 only @ t')
    plt.title('Recreate Staddon Higa 1996 Fig. 1 - Reflex Strength options')

    plt.legend()
    plt.tight_layout()
    return


In [None]:
out_Tfast, x_Tfast, y_Tfast, yrelu_Tfast = sim_staddon1996_response(u_Tfast)
out_Tslow, x_Tslow, y_Tslow, yrelu_Tslow = sim_staddon1996_response(u_Tslow)

plot_staddon_chain(u_Tfast, x_Tfast, y_Tfast, yrelu_Tfast)
plot_output_variations(u_Tfast, x_Tfast, y_Tfast, yrelu_Tfast)

plot_staddon_chain(u_Tslow, x_Tslow, y_Tslow, yrelu_Tslow)
plot_output_variations(u_Tslow, x_Tslow, y_Tslow, yrelu_Tslow)


In [None]:
t_discrete = np.arange(len(u_Tfast))

fig, axarr = plt.subplots(10, 2, squeeze=False, sharex=True,figsize=(8, 10))  #, 

axarr[0, 0].plot(t_discrete, u_Tfast)
axarr[0, 0].set_title('T=%.2f' % Tfast)
axarr[0, 0].set_ylabel(r'$u$')
axarr[0, 1].plot(t_discrete, u_Tslow)
axarr[0, 1].set_title('T=%.2f' % Tslow)
axarr[0, 1].set_ylabel(r'$u$')

for idx in range(3):
    
    loc = 1 + idx*3
    
    axarr[loc, 0].plot(t_discrete, x_Tfast[:, idx])
    axarr[loc, 1].plot(t_discrete, x_Tslow[:, idx])
    axarr[loc, 0].set_ylabel('x')
    
    axarr[loc+1, 0].plot(t_discrete, y_Tfast[:, idx])
    axarr[loc+1, 1].plot(t_discrete, y_Tslow[:, idx])
    axarr[loc+1, 0].set_ylabel('input - x')
    
    axarr[loc+2, 0].plot(t_discrete, yrelu_Tfast[:, idx])
    axarr[loc+2, 1].plot(t_discrete, yrelu_Tslow[:, idx])
    axarr[loc+2, 0].set_ylabel('ReLu(input - x)')

plt.suptitle('StaddonHiga 1996 Fig 1 recreate')
plt.tight_layout()
plt.show()

In [None]:

plt.close('all')
plt.plot(t_discrete, out_Tfast)
plt.plot(t_discrete, out_Tslow)
plt.show()


plt.close('all')
plt.figure(figsize=(4,4))
plt.suptitle('StaddonHiga 1996 Fig 1 recreate')
plt.plot(t_discrete, (yrelu_Tfast[:, 1] - x_Tfast[:, 2]), label='ISI low (rapid)')
plt.plot(t_discrete, (yrelu_Tslow[:, 1] - x_Tslow[:, 2]), label='ISI high (slow)')
plt.axhline(0, linestyle='--', c='k')
plt.axhline(0.5, linestyle='--', c='k')
plt.legend()
plt.ylabel('Reflex strength of block 3')
plt.xlabel('timestep')
plt.tight_layout()
plt.show()

'''plt.close()
plt.plot(u_T3)
plt.plot(u_T9)
plt.xlim(0,20)
plt.show()'''

In [None]:

plt.close('all')
plt.plot(t_discrete / Tfast, out_Tfast, '-ok')
plt.plot(t_discrete / Tslow, out_Tslow)
plt.xlim(0, 10)

plt.show()

plt.close('all')
plt.figure(figsize=(4,4))
plt.suptitle('StaddonHiga 1996 Fig 1 recreate')
plt.plot(t_discrete, u_Tfast, alpha=0.5)
plt.plot(t_discrete / Tfast, (yrelu_Tfast[:, 1] - x_Tfast[:, 2]), label='ISI low (rapid)')
plt.plot(t_discrete / Tslow, (yrelu_Tslow[:, 1] - x_Tslow[:, 2]), label='ISI high (slow)')
plt.axhline(0, linestyle='--', c='k')
plt.axhline(0.5, linestyle='--', c='k')
plt.legend()
plt.ylabel('Reflex strength of block 3')
plt.xlabel('timestep')
plt.tight_layout()

for idx in range(10):
    plt.axvline(idx*2)
    
plt.xlim(0, 10)
plt.show()


In [None]:
plt.close('all')
plt.figure(figsize=(8,6))
plt.suptitle('StaddonHiga 1996 Fig 1 recreate')
plt.plot(t_discrete, 1 - (x_Tfast[:, 0] + x_Tfast[:, 1] + x_Tfast[:, 2]), label='ISI low (rapid)')
plt.plot(t_discrete, 1 - (x_Tslow[:, 0] + x_Tslow[:, 1] + x_Tslow[:, 2]), label='ISI high (slow)')
plt.axhline(0, linestyle='--', c='k')
plt.axhline(0.5, linestyle=':', c='k', alpha=0.5)
plt.axhline(-0.5, linestyle=':', c='k', alpha=0.5)
plt.legend()
plt.ylabel('Reflex strength GUESS: 1 - (x1+x2+x3)')
plt.xlabel('timestep')
#plt.xlim(0, 20)
plt.tight_layout()
plt.show()

In [None]:
 plt.close('all')
plt.figure(figsize=(8,6))
plt.suptitle('StaddonHiga 1996 Fig 1 recreate')
plt.plot(t_discrete, 1 - (x_Tfast[:, 2]), label='ISI low (rapid)')
plt.plot(t_discrete, 1 - (x_Tslow[:, 2]), label='ISI high (slow)')
plt.axhline(0, linestyle='--', c='k')
plt.axhline(0.5, linestyle='--', c='k')
plt.legend()
plt.ylabel('Reflex strength GUESS: 1 - x3')
plt.xlabel('timestep')
#plt.xlim(0, 20)
plt.tight_layout()
plt.show()

In [None]:
plt.close('all')
plt.figure(figsize=(8,6))
plt.suptitle('StaddonHiga 1996 Fig 1 recreate')
plt.plot(t_discrete, y_Tfast[:, 2], label='ISI low (rapid)')
plt.plot(t_discrete, y_Tslow[:, 2], label='ISI high (slow)')
plt.axhline(0, linestyle='--', c='k')
plt.axhline(0.5, linestyle='--', c='k')
plt.legend()
plt.ylabel('Reflex strength GUESS: 1 - x3')
plt.xlabel('timestep')
#plt.xlim(0, 20)
plt.tight_layout()
plt.show()

# Try assuming different dynamics for the chain to recreate plots...

In [None]:
out_Tfast, x_Tfast, y_Tfast, yrelu_Tfast = sim_staddon1996_guessing_response(u_Tfast)
out_Tslow, x_Tslow, y_Tslow, yrelu_Tslow = sim_staddon1996_guessing_response(u_Tslow)

plot_staddon_chain(u_Tfast, x_Tfast, y_Tfast, yrelu_Tfast)
plot_output_variations(u_Tfast, x_Tfast, y_Tfast, yrelu_Tfast)

plot_staddon_chain(u_Tslow, x_Tslow, y_Tslow, yrelu_Tslow)
plot_output_variations(u_Tslow, x_Tslow, y_Tslow, yrelu_Tslow)

In [None]:
plt.close('all')
plt.figure(figsize=(8,6))
plt.suptitle('StaddonHiga 1996 Fig 1 recreate')
plt.plot(t_discrete, 1 - (x_Tfast[:, 0] + x_Tfast[:, 1] + x_Tfast[:, 2]), label='ISI low (rapid)')
plt.plot(t_discrete, 1 - (x_Tslow[:, 0] + x_Tslow[:, 1] + x_Tslow[:, 2]), label='ISI high (slow)')
plt.axhline(0, linestyle='--', c='k')
plt.axhline(0.5, linestyle='--', c='k')
plt.legend()
plt.ylabel('Reflex strength GUESS: 1 - (x1+x2+x3)')
plt.xlabel('timestep')
#plt.xlim(0, 20)
plt.tight_layout()
plt.show()