In [1]:
import numpy as np
import matplotlib.patches as patches
import matplotlib
import matplotlib.pyplot as plt
from numba import njit
# %matplotlib inline
# Global figure styling, applied once in a single rcParams update.
# fonttype 42 embeds TrueType fonts so text stays editable in exported PDF/PS.
# NOTE(review): 'font.family' names a concrete font ('Helvetica'), so the
# 'font.sans-serif' preference list is presumably never consulted -- confirm
# which of the two fonts is actually intended.
plt.rcParams.update({
    'font.sans-serif': ['Verdana'],
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'xtick.labelsize': 'x-large',
    'ytick.labelsize': 'x-large',
    'font.size': 10,
    'font.family': 'Helvetica',
})

In [2]:
# --- Integration -----------------------------------------------------------
dt = 1e-4        # Euler integration step used by run_simple
tau = 0.01       # time constant dividing the state derivative in run_simple

# --- Phase durations (same time unit as dt; converted to steps further down) -
t1 = 2           # initial run with the untrained readout weights
t3 = 2           # run with the trained weights, before any perturbation
t4 = 1           # first part of the perturbed run
t5 = 1           # second part of the perturbed run

# --- Perturbation settings -------------------------------------------------
p_kill = 0.05        # not referenced by the code visible in this notebook
p_synapses = 0.025   # base of the multiplicative weight jitter (ptype 'syna')
noise_a = 0.1        # not referenced by the code visible in this notebook
noise_b = 1          # amplitude of the noise added to x (ptype 'noise')

# --- Delays (converted to steps further down) -------------------------------
delay = 0.02
delay_r = 0.02

# --- Network construction ---------------------------------------------------
pGG = 0.1        # probability that an entry of the recurrent matrix is non-zero
pZ = 1           # probability that an entry of a readout vector is non-zero
gGG = 1.5        # gain applied to the recurrent input
gGZ = 1          # gain applied to the readout feedback
alpha = 1        # the RLS matrices start as identity / alpha

In [3]:
NG = 1000   # number of units in the recurrent network (JGG is NG x NG)
seed = 0    # seed passed to np.random.seed here and inside run_simple

In [4]:
# Durations converted to whole numbers of integration steps.
# round() is applied before int() because a float quotient can land just below
# the intended integer (e.g. 0.3 / 0.1 == 2.9999999999999996), and int() alone
# would then silently drop one step.
T1 = int(round(t1 / dt))
T3 = int(round(t3 / dt))
T4 = int(round(t4 / dt))
T5 = int(round(t5 / dt))

# Delays in steps; delay_r_dt is read by run_simple as a module-level value.
delay_dt = int(round(delay / dt))
delay_r_dt = int(round(delay_r / dt))

DT = 50  # number of dt's inside each weight update cycle

In [5]:
@njit
def run_simple(seed, NG, tau, gGG, JGG, gGZ, JGZ1, JGZ2, w1, w2, dt, T1, T2, T3, T4, T5, DT, alpha, ptype = 'normal', nper = np.array([1, 2, 3]), mg = 0, mg_delay = 0):
    """Simulate a tanh rate network with two fed-back linear readouts.

    The run is split into consecutive phases, measured in integration steps:

    * ``[1, T1)``: run with the initial readout weights ``w1`` / ``w2``.
    * ``[T1, T1 + T2)``: same dynamics; every ``DT``-th step both readout
      weight vectors are updated with a recursive-least-squares rule against
      the targets ``f1`` / ``f2``.
    * ``[T1 + T2, T1 + T2 + T3)``: run with the trained weights frozen.
    * the remaining ``T4 + T5`` steps: run with the perturbation selected by
      ``ptype`` applied.

    Parameters
    ----------
    seed : value passed to ``np.random.seed`` at the start of the run.
    NG : used to size the random delay arrays; afterwards it is overwritten
        with ``JGG.shape[0]``.
    tau : time constant dividing the state derivative.
    gGG, JGG : gain and weight matrix of the recurrent input ``JGG . r``.
    gGZ, JGZ1, JGZ2 : gain and weight vectors feeding ``z1`` / ``z2`` back.
    w1, w2 : initial readout weight vectors.
    dt : Euler integration step.
    T1, T2, T3, T4, T5 : phase lengths in steps (see above).
    DT : number of steps between two readout weight updates.
    alpha : the matrices ``P1`` / ``P2`` start as ``identity / alpha``.
    ptype : ``'syna'`` jitters the recurrent weights, ``'noise'`` adds noise
        to ``x``, ``'kill'`` clamps the units listed in ``nper`` to zero;
        any other value leaves the last phase unperturbed.
    nper : indices of the units clamped when ``ptype == 'kill'``.
    mg, mg_delay : accepted but not used in the body.

    Returns
    -------
    x, r : arrays of shape (total steps, NG) with the states and ``tanh(x)``.
    z1, z2 : readout values per step.
    wh1, wh2 : readout weight vectors per step.

    Notes
    -----
    The body reads the module-level names ``delay_r_dt``, ``f1``, ``f2``,
    ``p_synapses`` and ``noise_b``; they are not parameters.
    NOTE(review): under ``@njit`` such globals are presumably captured when
    the function is compiled, so later changes to them may not be seen --
    confirm against the Numba documentation.
    """
    np.random.seed(seed)
    # NOTE(review): none of the five delay arrays below is read again in this
    # function. They still draw random numbers before x[0, :] is initialised,
    # so dropping them would presumably change the initial state -- confirm
    # before removing.
    Delay_r_dt = np.random.randint(1, delay_r_dt, (NG, NG))
    Delay_r_z1_dt = np.random.randint(1, delay_r_dt, NG)
    Delay_r_z2_dt = np.random.randint(1, delay_r_dt, NG)
    Delay_r_f1_dt = np.random.randint(1, delay_r_dt, NG)
    Delay_r_f2_dt = np.random.randint(1, delay_r_dt, NG)
    # Cumulative phase boundaries, in steps.
    T1 = T1
    TA = T1 + T2
    TB = TA + T3
    TC = TB + T4
    TD = TC + T5
    noiter = TD
    # From here on the network size comes from the weight matrix, not the argument.
    NG = JGG.shape[0]
    x = np.zeros((noiter, NG))
    r = np.zeros((noiter, NG))
    z1 = np.zeros(noiter)
    z2 = np.zeros(noiter)
    # wh1 / wh2 record the readout weights at every step; wc1 / wc2 are the
    # current weights (rebound, never modified in place, so w1 / w2 are untouched).
    wh1 = np.zeros((noiter, w1.shape[0]))
    wc1 = w1
    wh1[:T1, :] = wc1
    wh2 = np.zeros((noiter, w2.shape[0]))
    wc2 = w2
    wh2[:T1, :] = wc2
    x[0, :] = np.random.randn(NG)
    # NOTE(review): P1 and P2 start equal and receive the same update from the
    # same r below, so they appear to stay identical throughout.
    P1 = np.eye(NG) / alpha
    P2 = np.eye(NG) / alpha
    # The three arrays below are not read again in this function.
    recurrency = np.zeros(NG)
    feedback1 = np.zeros(NG)
    feedback2 = np.zeros(NG)
    # Phase 1: forward-Euler integration with the initial readout weights.
    for i in range(1, T1):
        x[i, :] = x[i - 1, :] + dt * (-x[i - 1, :] + gGG * np.dot(JGG, r[i - 1, :]) + gGZ * JGZ1 * z1[i - 1] + gGZ * JGZ2 * z2[i - 1]) / tau
        r[i, :] = np.tanh(x[i, :])
        z1[i] = np.dot(wc1, r[i, :])
        z2[i] = np.dot(wc2, r[i, :])
    # Phase 2: same dynamics, plus a weight update on every DT-th step
    # (j counts steps since the last update).
    j = 1
    for i in range(T1, TA):
        x[i, :] = x[i - 1, :] + dt * (-x[i - 1, :] + gGG * np.dot(JGG, r[i - 1, :]) + gGZ * JGZ1 * z1[i - 1] + gGZ * JGZ2 * z2[i - 1]) / tau
        r[i, :] = np.tanh(x[i, :])
        z1[i] = np.dot(wc1, r[i, :])
        z2[i] = np.dot(wc2, r[i, :])
        if j == DT:
            # Error is taken with the pre-update weights; P is updated first,
            # then the weights move along P . r scaled by that error.
            eminus1 = z1[i] - f1[i]
            P1 = P1 - (np.dot(np.dot(P1, r[i, :].reshape((NG, 1))), np.dot(r[i, :].reshape((1, NG)), P1))) / (1 + np.dot(r[i, :], np.dot(P1, r[i, :])))
            wc1 = wc1 - eminus1 * np.dot(P1, r[i, :])
            eminus2 = z2[i] - f2[i]
            P2 = P2 - (np.dot(np.dot(P2, r[i, :].reshape((NG, 1))), np.dot(r[i, :].reshape((1, NG)), P2))) / (1 + np.dot(r[i, :], np.dot(P2, r[i, :])))
            wc2 = wc2 - eminus2 * np.dot(P2, r[i, :])
            j = 0
        j += 1
        wh1[i, :] = wc1
        wh2[i, :] = wc2
    # Weights are frozen from here on; fill the rest of the history with them.
    wh1[TA:, :] = wc1
    wh2[TA:, :] = wc2
    # Phase 3: run with the trained weights, no perturbation.
    for i in range(TA, TB):
        x[i, :] = x[i - 1, :] + dt * (-x[i - 1, :] + gGG * np.dot(JGG, r[i - 1, :]) + gGZ * JGZ1 * z1[i - 1] + gGZ * JGZ2 * z2[i - 1]) / tau
        r[i, :] = np.tanh(x[i, :])
        z1[i] = np.dot(wc1, r[i, :])
        z2[i] = np.dot(wc2, r[i, :])
    # 'syna': scale every recurrent weight by (1 - p_synapses) ** u with u drawn
    # uniformly from [-1, 1).
    if ptype == 'syna':
        JGG_p = JGG * (1 - p_synapses) ** (2 * np.random.rand(NG, NG) - 1)
    else:
        JGG_p = JGG
    # 'noise': enable the additive noise term on x; otherwise its amplitude is 0.
    if ptype == 'noise':
        mg_noise_b = noise_b
    else:
        mg_noise_b = 0
    # Amplitude of the noise on r is fixed at 0, so that term contributes nothing
    # (random numbers are still drawn for it on every step).
    mg_noise_a = 0
    # Final phase: perturbed run over the last T4 + T5 steps.
    for i in range(TB, TD):
        x[i, :] = x[i - 1, :] + dt * (-x[i - 1, :] + gGG * np.dot(JGG_p, r[i - 1, :]) + gGZ * JGZ1 * z1[i - 1] + gGZ * JGZ2 * z2[i - 1]) / tau + np.sqrt(dt) * mg_noise_b * np.random.randn(NG)
        if ptype == 'kill':
            # Clamp the selected units after the Euler step, so their rate is tanh(0) = 0.
            for nper_i in nper:
                x[i, nper_i] = 0
        r[i, :] = np.tanh(x[i, :]) + mg_noise_a * np.random.randn(NG)
        z1[i] = np.dot(wc1, r[i, :])
        z2[i] = np.dot(wc2, r[i, :])
    return x, r, z1, z2, wh1, wh2

In [6]:
def plotFORCE(z1, z2, r, wh1, wh2, T1, T2, T3, T4, T5, dt, type_p, NG, seed, xlim='Partial'):
    """Plot the two readout traces and save the figure as ``FIG1C_<type_p>.pdf``.

    A vertical line marks step ``T1 + T2 + T3``, i.e. the end of the third
    phase. The x axis is in integration steps, while the two tick labels give
    the time elapsed since step ``T1 + T2`` (``steps * dt``).

    Parameters
    ----------
    z1, z2 : readout values per integration step.
    T1, T2, T3, T4, T5 : phase lengths in integration steps.
    dt : integration step, used to convert tick positions into time labels.
    type_p : label inserted into the output file name.
    xlim : ``'Full'`` shows everything from step ``T1 + T2`` onwards;
        ``'Partial'`` (default) skips the first 40 % of the ``T3`` phase.
        Any other value leaves matplotlib's automatic x limits in place.
    r, wh1, wh2, NG, seed : accepted for call compatibility; not used.

    The figure is written to the current working directory and then closed;
    nothing is returned.
    """
    # Fraction of the T3 phase hidden at the left edge of the plot.
    p_xlim = 0.4 if xlim == 'Partial' else 0

    # Phase boundaries, in steps.
    a = T1 + T2
    b = a + T3
    d = b + T4 + T5

    no_plots = 1
    axis_lw = 3
    signal_lw = 4
    fig, ax = plt.subplots(no_plots, 1, figsize=(3 * (1.25 * (2 - p_xlim)), 3 * no_plots))

    ax_signal = ax

    ax_signal.axvline(b, color='black', alpha=0.5, linewidth=5)

    ax_signal.plot(z1, alpha=0.7, linewidth=signal_lw)
    ax_signal.plot(z2, alpha=0.7, linewidth=signal_lw)
    ax_signal.set_yticks([-1, 0, 1])
    ax_signal.set_ylim([-1.5, 1.5])

    ax_signal.spines['bottom'].set_linewidth(axis_lw)
    ax_signal.spines['left'].set_linewidth(axis_lw)

    if xlim in ('Full', 'Partial'):
        # Left edge sits p_xlim of the way from a to b (exactly a for 'Full').
        left = a * (1 - p_xlim) + b * p_xlim
        ax_signal.set_xlim(left, d)
        ax_signal.set_xticks([left, b])
        # Labels follow from the actual phase length instead of assuming that
        # T3 * dt equals 2; '{:g}' renders 2.0 as '2' and 0.8 as '0.8'.
        ax_signal.set_xticklabels(['{:g}'.format(p_xlim * T3 * dt),
                                   '{:g}'.format(T3 * dt)])
    ax_signal.spines['right'].set_visible(False)
    ax_signal.spines['top'].set_visible(False)

    fig.savefig('FIG1C_{}.pdf'.format(type_p), transparent=True)
    plt.close(fig)

In [7]:
t2 = 2
# Length of the training phase in steps. round() before int() so that a float
# quotient landing just below an integer is not truncated by one step.
T2 = int(round(t2 / dt))
np.random.seed(seed)

# Sparse random recurrent weights: entries are kept with probability pGG and
# scaled by 1 / sqrt(NG * pGG).
JGG = ((np.random.rand(NG, NG) < pGG) * 1) * np.random.randn(NG, NG) / np.sqrt(NG * pGG)
# Feedback weights, uniform in [-1, 1).
JGZ1 = np.random.rand(NG) * 2 - 1
JGZ2 = np.random.rand(NG) * 2 - 1
# Initial readout weights, built the same way as JGG with probability pZ.
w1 = ((np.random.rand(NG) < pZ) * 1) * np.random.randn(NG) / np.sqrt(NG * pZ)
w2 = ((np.random.rand(NG) < pZ) * 1) * np.random.randn(NG) / np.sqrt(NG * pZ)

# Target signals, one sample per integration step of the whole run.
# run_simple reads f1 and f2 as module-level names, so they must exist before
# its first call.
time = np.arange(0, T1 + T2 + T3 + T4 + T5) * dt
f1 = np.sin(time * 5)
f2 = 0.8 * np.cos((time + 0.785) * 5)

# Indices of the units clamped in the 'kill' run.
# NOTE(review): the fraction here is the literal 0.1, while the config cell
# defines p_kill = 0.05 and never uses it -- confirm which value is intended.
nper = np.nonzero(np.random.rand(NG) < 0.1)[0]

# One simulation and one figure per perturbation type. Each entry is
# (ptype for run_simple, label for the file name, extra run_simple kwargs,
# extra plotFORCE kwargs); the calls match the three hand-written ones exactly.
perturbation_runs = (
    ('kill', 'Kill', {'nper': nper}, {'xlim': 'Full'}),
    ('syna', 'Synapses', {}, {}),
    ('noise', 'Noise', {}, {}),
)
for pert_type, fig_label, run_kwargs, plot_kwargs in perturbation_runs:
    x, r, z1, z2, wh1, wh2 = run_simple(seed, NG, tau, gGG, JGG, gGZ, JGZ1, JGZ2, w1, w2, dt, T1, T2, T3, T4, T5, DT, alpha, ptype=pert_type, **run_kwargs)
    plotFORCE(z1, z2, r, wh1, wh2, T1, T2, T3, T4, T5, dt, fig_label, NG, seed, **plot_kwargs)