In [1]:
%matplotlib notebook
import matplotlib.pyplot as plt
import numpy as np
from copy import deepcopy

In [2]:
# Grid definition: `input_dim` input locations spread evenly over [-5, 5], and
# an axis `t` sampled from -10 (inclusive) to 10 (exclusive) in steps of `t_res`.
input_dim = 50
t_res = 0.01
input_locs = np.linspace(-5., 5., input_dim)
t = np.arange(-10., 10., t_res)

# Each input is active for exactly one sample: the first bin of `t` that is at
# or beyond the input's own location.
pre_activity = np.zeros((input_dim, len(t)))
for i, loc in enumerate(input_locs):
    index = np.where(t >= loc)[0][0]
    pre_activity[i, index] = 1.

plt.figure()
# The row extent follows input_dim instead of a literal 50, so the y-axis stays
# correct when the number of inputs is changed above.
plt.imshow(pre_activity, aspect='auto', extent=(-10., 10., input_dim, 0))
plt.xlim(-5., 5.)

<IPython.core.display.Javascript object>

(-5.0, 5.0)

In [3]:
# Probe signals: the activity of the first input, and a plateau that is 1 for
# 0 <= t < 0.3 and 0 elsewhere.
test = pre_activity[0,:]
plateau = np.zeros_like(t)
plateau_indexes = np.where((t >= 0.) & (t < 0.3))[0]
plateau[plateau_indexes] = 1.
tau_ET = 1.7
tau_IS = 0.9

# Kernel lags start at zero. Building the kernels from `t` directly (which
# starts at -10) multiplies every tap by exp(10 / tau); that constant cancels in
# the unit-sum normalization but overflows to inf/NaN for small tau.
filter_lags = t - t[0]


def exp_kernel(lags, tau):
    """Return an exponential decay exp(-lags / tau) normalized to unit sum."""
    kernel = np.exp(-lags / tau)
    return kernel / np.sum(kernel)


def causal_filter(signal, kernel):
    """Convolve `signal` with `kernel` and keep only the first len(signal) samples."""
    return np.convolve(signal, kernel)[:len(signal)]


# Slow traces. ET_norm_factor (the peak response to a single active sample) is
# kept because a later cell reuses it to scale every input's trace.
filter_ET = exp_kernel(filter_lags, tau_ET)
this_ET = causal_filter(test, filter_ET)
ET_norm_factor = np.max(this_ET)
this_ET /= ET_norm_factor

filter_IS = exp_kernel(filter_lags, tau_IS)
IS = causal_filter(plateau, filter_IS)
IS /= np.max(IS)

# Same construction with much shorter time constants (the "Ca" variants).
Ca_tau_ET = 0.05
Ca_tau_IS = 0.05

Ca_filter_ET = exp_kernel(filter_lags, Ca_tau_ET)
this_Ca_ET = causal_filter(test, Ca_filter_ET)
Ca_ET_norm_factor = np.max(this_Ca_ET)
this_Ca_ET /= Ca_ET_norm_factor

Ca_filter_IS = exp_kernel(filter_lags, Ca_tau_IS)
Ca_IS = causal_filter(plateau, Ca_filter_IS)
Ca_IS /= np.max(Ca_IS)

# Figure 1: probe input with its slow and fast filtered traces.
plt.figure()
plt.plot(t, test, t, this_ET)
plt.plot(t, this_Ca_ET)
# Figure 2: plateau with its slow and fast filtered traces.
plt.figure()
plt.plot(t, plateau, t, IS)
plt.plot(t, Ca_IS)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x7fdd982d3650>]

In [4]:
def convolve_rows(activity, kernel, norm_factor):
    """
    Filter every row of a 2D array with the same kernel.
    Each row is convolved with `kernel`, truncated to the row length, and divided by `norm_factor`.
    :param activity: 2D array (rows x samples)
    :param kernel: 1D array
    :param norm_factor: float
    :return: 2D array with the same shape as activity
    """
    filtered = np.empty_like(activity)
    n_samples = activity.shape[1]
    for row_index, row in enumerate(activity):
        filtered[row_index, :] = np.convolve(row, kernel)[:n_samples] / norm_factor
    return filtered


def show_raster(values):
    """
    Draw a 2D array as an image with a colorbar in a new figure, zoomed to -5..5 on the x-axis.
    The row extent is taken from values.shape[0] rather than a hard-coded count.
    :param values: 2D array (rows x samples)
    :return: the x-limits returned by plt.xlim
    """
    plt.figure()
    plt.imshow(values, aspect='auto', extent=(-10., 10., values.shape[0], 0))
    plt.colorbar()
    return plt.xlim(-5., 5.)


# Slow trace of every input, scaled by the single-input peak, then gated by IS.
ET = convolve_rows(pre_activity, filter_ET, ET_norm_factor)
show_raster(ET)

ET_IS = ET * IS
show_raster(ET_IS)

# Same pipeline using the fast ("Ca") kernels.
Ca_ET = convolve_rows(pre_activity, Ca_filter_ET, Ca_ET_norm_factor)
show_raster(Ca_ET)

Ca_ET_IS = Ca_ET * Ca_IS
show_raster(Ca_ET_IS)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

(-5.0, 5.0)

In [5]:
dt = t[1] - t[0]
# np.trapz was renamed np.trapezoid in NumPy 2.0 and the old name deprecated;
# use whichever one this NumPy installation provides.
integrate = getattr(np, 'trapezoid', None) or np.trapz
# Integrate each input's row along the last axis, giving one area per input.
area_ET_IS = integrate(ET_IS, dx=dt)
Ca_area_ET_IS = integrate(Ca_ET_IS, dx=dt)
plt.figure()
plt.plot(input_locs, area_ET_IS)
plt.plot(input_locs, Ca_area_ET_IS)

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x7fdd580749d0>]

In [6]:
def scaled_single_sigmoid(th, peak, x=None, ylim=None):
    """
    Build a logistic curve that is shifted and stretched so that it equals ylim[0] at x[0] and ylim[1] at x[-1].
    The underlying logistic is centered on th and has slope 2 / (peak - th).
    :param th: float; must lie within [x[0], x[-1]]
    :param peak: float; must differ from th
    :param x: sequence; only its first and last elements are used. Defaults to (0., 1.)
    :param ylim: pair of float; output values at x[0] and x[-1]. Defaults to (0., 1.)
    :return: callable; maps a scalar or array to the rescaled sigmoid value(s)
    :raises ValueError: if th is outside the x range, or if peak equals th
    """
    x = (0., 1.) if x is None else x
    ylim = (0., 1.) if ylim is None else ylim
    x_lo, x_hi = x[0], x[-1]
    if th < x_lo or th > x_hi:
        raise ValueError('scaled_single_sigmoid: th: %.2E is out of range for xlim: [%.2E, %.2E]' % (th, x_lo, x_hi))
    if peak == th:
        raise ValueError('scaled_single_sigmoid: peak and th: %.2E cannot be equal' % th)
    slope = 2. / (peak - th)

    def logistic(value):
        # Plain logistic centered on th, before any rescaling.
        return 1. / (1. + np.exp(-slope * (value - th)))

    start_val = logistic(x_lo)
    amp = logistic(x_hi) - start_val
    target_amp = ylim[1] - ylim[0]

    def scaled(xi):
        # Subtract the value at x[0], stretch to the target amplitude, then lift to ylim[0].
        return (target_amp / amp) * (logistic(xi) - start_val) + ylim[0]

    return scaled


def get_linear_dW(dt, ET_IS, w, learning_rate, dep_ratio, max_weight):
    """
    Integrate a weight-dependent plasticity rule along the last axis of ET_IS and return the bounded weight change
    for each input.
    With normalized weights w0 = w / max_weight, the integrated rate is
    learning_rate * ((1 - w0) * ET_IS - dep_ratio * w0 * ET_IS).
    :param dt: float; sample spacing used for the trapezoidal integration
    :param ET_IS: 2D array (inputs x samples)
    :param w: array-like of float; one current weight per input
    :param learning_rate: float
    :param dep_ratio: float; scales the depression term relative to the potentiation term
    :param max_weight: float
    :return: array of float; one weight change per input, on the same scale as w
    """
    # np.trapz was renamed np.trapezoid in NumPy 2.0 and the old name deprecated; prefer the new name when present.
    integrate = getattr(np, 'trapezoid', None) or np.trapz
    # np.asarray lets callers pass a list or tuple; the [:, None] indexing below needs an ndarray.
    w0 = np.asarray(w) / max_weight
    dWdt = learning_rate * ((1. - w0[:, None]) * ET_IS - dep_ratio * w0[:, None] * ET_IS)
    dW = integrate(dWdt, dx=dt)
    # Bound the step so the updated normalized weight w0 + dW stays within [0, 1].
    dW = np.maximum(-w0, np.minimum(1. - w0, dW))
    return dW * max_weight


def get_linear_W_eq(dt, ET_IS, learning_rate, dep_ratio, max_weight):
    """
    Return the constant weight max_weight / (1 + dep_ratio), repeated once per row of ET_IS.
    dt and learning_rate are accepted but not used; only the first dimension of ET_IS is read.
    :param dt: float (unused)
    :param ET_IS: 2D array; only shape[0] is used
    :param learning_rate: float (unused)
    :param dep_ratio: float
    :param max_weight: float
    :return: array of float with length ET_IS.shape[0]
    """
    n_inputs = ET_IS.shape[0]
    eq_weight = 1. / (1. + dep_ratio) * max_weight
    return np.full(n_inputs, eq_weight, dtype=float)


def get_lin_pot_sig_dep_dW(dt, ET_IS, w, dep_f, learning_rate, dep_ratio, max_weight):
    """
    Integrate a plasticity rule with a linear potentiation term and a dep_f-transformed depression term, and return
    the bounded weight change for each input.
    With normalized weights w0 = w / max_weight, the integrated rate is
    learning_rate * ((1 - w0) * ET_IS - dep_ratio * w0 * dep_f(ET_IS)).
    :param dt: float; sample spacing used for the trapezoidal integration
    :param ET_IS: 2D array (inputs x samples)
    :param w: array-like of float; one current weight per input
    :param dep_f: callable; applied elementwise to ET_IS for the depression term
    :param learning_rate: float
    :param dep_ratio: float; scales the depression term relative to the potentiation term
    :param max_weight: float
    :return: array of float; one weight change per input, on the same scale as w
    """
    # np.trapz was renamed np.trapezoid in NumPy 2.0 and the old name deprecated; prefer the new name when present.
    integrate = getattr(np, 'trapezoid', None) or np.trapz
    # np.asarray lets callers pass a list or tuple; the [:, None] indexing below needs an ndarray.
    w0 = np.asarray(w) / max_weight
    dWdt = learning_rate * ((1. - w0[:, None]) * ET_IS - dep_ratio * w0[:, None] * dep_f(ET_IS))
    dW = integrate(dWdt, dx=dt)
    # Bound the step so the updated normalized weight w0 + dW stays within [0, 1].
    dW = np.maximum(-w0, np.minimum(1. - w0, dW))
    return dW * max_weight


def get_lin_pot_sig_dep_W_eq(dt, ET_IS, dep_f, learning_rate, dep_ratio, max_weight):
    """
    Compute, for each input, max_weight * A / (A + dep_ratio * B), where A is the integral of ET_IS and B the
    integral of dep_f(ET_IS), both taken along the last axis.
    Inputs whose denominator is below 0.0005 are left as NaN instead of being divided.
    learning_rate is accepted but not used.
    :param dt: float; sample spacing used for the trapezoidal integration
    :param ET_IS: 2D array (inputs x samples)
    :param dep_f: callable; applied elementwise to ET_IS for the depression term
    :param learning_rate: float (unused)
    :param dep_ratio: float
    :param max_weight: float
    :return: array of float with length ET_IS.shape[0]; NaN where the denominator is too small
    """
    # np.trapz was renamed np.trapezoid in NumPy 2.0 and the old name deprecated; prefer the new name when present.
    integrate = getattr(np, 'trapezoid', None) or np.trapz
    numer = integrate(ET_IS, dx=dt)
    denom = integrate(ET_IS + dep_ratio * dep_f(ET_IS), dx=dt)
    W_eq = np.full(ET_IS.shape[0], np.nan)
    # Only divide where the denominator is safely away from zero.
    valid = denom >= 0.0005
    W_eq[valid] = numer[valid] / denom[valid] * max_weight
    return W_eq


def get_sig_dW(dt, ET_IS, w, pot_f, dep_f, learning_rate, dep_ratio, max_weight):
    """
    Integrate a plasticity rule whose potentiation and depression terms are pot_f(ET_IS) and dep_f(ET_IS), and
    return the bounded weight change for each input.
    With normalized weights w0 = w / max_weight, the integrated rate is
    learning_rate * ((1 - w0) * pot_f(ET_IS) - dep_ratio * w0 * dep_f(ET_IS)).
    :param dt: float; sample spacing used for the trapezoidal integration
    :param ET_IS: 2D array (inputs x samples)
    :param w: array-like of float; one current weight per input
    :param pot_f: callable; applied elementwise to ET_IS for the potentiation term
    :param dep_f: callable; applied elementwise to ET_IS for the depression term
    :param learning_rate: float
    :param dep_ratio: float; scales the depression term relative to the potentiation term
    :param max_weight: float
    :return: array of float; one weight change per input, on the same scale as w
    """
    # np.trapz was renamed np.trapezoid in NumPy 2.0 and the old name deprecated; prefer the new name when present.
    integrate = getattr(np, 'trapezoid', None) or np.trapz
    # np.asarray lets callers pass a list or tuple; the [:, None] indexing below needs an ndarray.
    w0 = np.asarray(w) / max_weight
    dWdt = learning_rate * ((1. - w0[:, None]) * pot_f(ET_IS) - dep_ratio * w0[:, None] * dep_f(ET_IS))
    dW = integrate(dWdt, dx=dt)
    # Bound the step so the updated normalized weight w0 + dW stays within [0, 1].
    dW = np.maximum(-w0, np.minimum(1. - w0, dW))
    return dW * max_weight


def get_sig_W_eq(dt, ET_IS, pot_f, dep_f, learning_rate, dep_ratio, max_weight):
    """
    Compute, for each input, max_weight * A / (A + dep_ratio * B), where A is the integral of pot_f(ET_IS) and B
    the integral of dep_f(ET_IS), both taken along the last axis.
    Inputs whose denominator is below 0.0005 are left as NaN instead of being divided.
    learning_rate is accepted but not used.
    :param dt: float; sample spacing used for the trapezoidal integration
    :param ET_IS: 2D array (inputs x samples)
    :param pot_f: callable; applied elementwise to ET_IS for the potentiation term
    :param dep_f: callable; applied elementwise to ET_IS for the depression term
    :param learning_rate: float (unused)
    :param dep_ratio: float
    :param max_weight: float
    :return: array of float with length ET_IS.shape[0]; NaN where the denominator is too small
    """
    # np.trapz was renamed np.trapezoid in NumPy 2.0 and the old name deprecated; prefer the new name when present.
    integrate = getattr(np, 'trapezoid', None) or np.trapz
    # Evaluate the potentiation transform once; it feeds both numerator and denominator.
    pot = pot_f(ET_IS)
    numer = integrate(pot, dx=dt)
    denom = integrate(pot + dep_ratio * dep_f(ET_IS), dx=dt)
    W_eq = np.full(ET_IS.shape[0], np.nan)
    # Only divide where the denominator is safely away from zero.
    valid = denom >= 0.0005
    W_eq[valid] = numer[valid] / denom[valid] * max_weight
    return W_eq

In [7]:
# Parameters of the linear rule, packed in the positional order that the
# get_linear_* helpers expect after their leading arguments.
lin_learning_rate = 1.02
lin_dep_ratio = 0.5
lin_max_weight = 3.
linear_dW_params = [lin_learning_rate, lin_dep_ratio, lin_max_weight]

linear_W_eq = get_linear_W_eq(dt, ET_IS, *linear_dW_params)

# One row per initial weight between 1 and 2.4; within a row every input starts
# from that same weight, and the columns hold the resulting per-input dW.
initial_weights = np.linspace(1., 2.4, input_dim)
linear_dW = np.vstack([
    get_linear_dW(dt, ET_IS, np.full_like(input_locs, initial_weight), *linear_dW_params)
    for initial_weight in initial_weights
])

# Symmetric color limits so that zero change sits at the center of the diverging colormap.
vmax = np.max(np.abs(linear_dW))
fig, ax = plt.subplots()
# Rows are flipped so that the largest initial weight is drawn at the top of the image.
image = ax.imshow(np.flipud(linear_dW), extent=(-5., 5., 1., 2.4), aspect='auto', cmap='RdBu_r',
                  vmin=-vmax, vmax=vmax)
fig.colorbar(image, ax=ax)
ax.plot(input_locs, linear_W_eq, '--', c='grey')

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x7fdd6a39d450>]

In [8]:
# Parameters of the rule with linear potentiation and sigmoidal depression.
lin_pot_sig_dep_learning_rate = 1.01
lin_pot_sig_dep_dep_ratio = 0.15
lin_pot_sig_dep_max_weight = 3.
lin_pot_sig_dep_dep_th = 0.007
lin_pot_sig_dep_dep_width = 0.07
lin_pot_sig_dep_dW_params = [lin_pot_sig_dep_learning_rate, lin_pot_sig_dep_dep_ratio, lin_pot_sig_dep_max_weight]
# Depression nonlinearity: sigmoid with threshold dep_th and "peak" at dep_th + dep_width.
lin_pot_sig_dep_dep_f = scaled_single_sigmoid(lin_pot_sig_dep_dep_th,
                                              lin_pot_sig_dep_dep_th + lin_pot_sig_dep_dep_width)

lin_pot_sig_dep_W_eq = get_lin_pot_sig_dep_W_eq(dt, ET_IS, lin_pot_sig_dep_dep_f, *lin_pot_sig_dep_dW_params)

# One row per initial weight between 1 and 2.4; within a row every input starts
# from that same weight, and the columns hold the resulting per-input dW.
initial_weights = np.linspace(1., 2.4, input_dim)
lin_pot_sig_dep_dW = np.vstack([
    get_lin_pot_sig_dep_dW(dt, ET_IS, np.full_like(input_locs, initial_weight), lin_pot_sig_dep_dep_f,
                           *lin_pot_sig_dep_dW_params)
    for initial_weight in initial_weights
])

# Symmetric color limits so that zero change sits at the center of the diverging colormap.
vmax = np.max(np.abs(lin_pot_sig_dep_dW))
fig, ax = plt.subplots()
# Rows are flipped so that the largest initial weight is drawn at the top of the image.
image = ax.imshow(np.flipud(lin_pot_sig_dep_dW), extent=(-5., 5., 1., 2.4), aspect='auto', cmap='RdBu_r',
                  vmin=-vmax, vmax=vmax)
fig.colorbar(image, ax=ax)
ax.plot(input_locs, lin_pot_sig_dep_W_eq, '--', c='grey')
ax.set_ylim(1., 2.4)

<IPython.core.display.Javascript object>

(1.0, 2.4)

In [9]:
# Parameters of the rule with sigmoidal potentiation and sigmoidal depression,
# evaluated on ET_IS.
sig_learning_rate = 0.58
sig_dep_ratio = 0.22
sig_max_weight = 3.
sig_dep_th = 0.006
sig_dep_width = 0.039
sig_pot_th = 0.01
sig_pot_width = 0.4
sig_dW_params = [sig_learning_rate, sig_dep_ratio, sig_max_weight]
# Each nonlinearity is a sigmoid with threshold *_th and "peak" at *_th + *_width.
sig_dep_f = scaled_single_sigmoid(sig_dep_th, sig_dep_th + sig_dep_width)
sig_pot_f = scaled_single_sigmoid(sig_pot_th, sig_pot_th + sig_pot_width)

sig_W_eq = get_sig_W_eq(dt, ET_IS, sig_pot_f, sig_dep_f, *sig_dW_params)

# One row per initial weight between 1 and 2.4; within a row every input starts
# from that same weight, and the columns hold the resulting per-input dW.
initial_weights = np.linspace(1., 2.4, input_dim)
sig_dW = np.vstack([
    get_sig_dW(dt, ET_IS, np.full_like(input_locs, initial_weight), sig_pot_f, sig_dep_f, *sig_dW_params)
    for initial_weight in initial_weights
])

# Symmetric color limits so that zero change sits at the center of the diverging colormap.
vmax = np.max(np.abs(sig_dW))
fig, ax = plt.subplots()
# Rows are flipped so that the largest initial weight is drawn at the top of the image.
image = ax.imshow(np.flipud(sig_dW), extent=(-5., 5., 1., 2.4), aspect='auto', cmap='RdBu_r',
                  vmin=-vmax, vmax=vmax)
fig.colorbar(image, ax=ax)
ax.plot(input_locs, sig_W_eq, '--', c='grey')
ax.set_ylim(1., 2.4)

<IPython.core.display.Javascript object>

(1.0, 2.4)

In [10]:
# Same sigmoidal rule as the previous cell, but evaluated on Ca_ET_IS and with a
# different learning rate. NOTE: this rebinds the sig_* names set by that cell.
sig_learning_rate = 7.4
sig_dep_ratio = 0.22
sig_max_weight = 3.
sig_dep_th = 0.006
sig_dep_width = 0.039
sig_pot_th = 0.01
sig_pot_width = 0.4
sig_dW_params = [sig_learning_rate, sig_dep_ratio, sig_max_weight]
# Each nonlinearity is a sigmoid with threshold *_th and "peak" at *_th + *_width.
sig_dep_f = scaled_single_sigmoid(sig_dep_th, sig_dep_th + sig_dep_width)
sig_pot_f = scaled_single_sigmoid(sig_pot_th, sig_pot_th + sig_pot_width)

sig_W_eq = get_sig_W_eq(dt, Ca_ET_IS, sig_pot_f, sig_dep_f, *sig_dW_params)

# One row per initial weight between 1 and 2.4; within a row every input starts
# from that same weight, and the columns hold the resulting per-input dW.
initial_weights = np.linspace(1., 2.4, input_dim)
sig_dW = np.vstack([
    get_sig_dW(dt, Ca_ET_IS, np.full_like(input_locs, initial_weight), sig_pot_f, sig_dep_f, *sig_dW_params)
    for initial_weight in initial_weights
])

# Symmetric color limits so that zero change sits at the center of the diverging colormap.
vmax = np.max(np.abs(sig_dW))
fig, ax = plt.subplots()
# Rows are flipped so that the largest initial weight is drawn at the top of the image.
image = ax.imshow(np.flipud(sig_dW), extent=(-5., 5., 1., 2.4), aspect='auto', cmap='RdBu_r',
                  vmin=-vmax, vmax=vmax)
fig.colorbar(image, ax=ax)
ax.plot(input_locs, sig_W_eq, '--', c='grey')
ax.set_ylim(1., 2.4)

<IPython.core.display.Javascript object>

(1.0, 2.4)