In [30]:
import importlib

import numpy as np
from numpy.random import RandomState
from numpy.testing import assert_allclose, assert_array_equal
from scipy import sparse
from scipy.sparse import csr_matrix, csc_matrix, coo_matrix
from scipy.stats import norm

from matplotlib import pyplot as plt
%matplotlib tk

import network as N

rng_seed = 1234
rng = np.random.RandomState(rng_seed)

In [22]:
# new, useful code!

from itertools import combinations

def assert_allcombs_equal(iter_arr, ck_fun=assert_allclose):
    combs = combinations(iter_arr, 2)
    for comb in combs:
        ck_fun(*comb)

def imshow_cb(a, ax):
    
    i = ax.imshow(a, cmap='RdBu_r')
    cb = plt.colorbar(i)

    lims = cb.get_clim()
    maxabs = np.fabs(lims).max()
    cb.set_clim(-maxabs, maxabs)

In [2]:
importlib.reload(N)

<module 'network' from '/Users/michaelseay/Code/py-rrn/network.py'>

In [178]:
network_params = {'n_units': 800,
                 'p_plastic': 0.6,
                 'p_connect': 0.1,
                 'syn_strength': 1.5,
                 'tau_ms': 10,
                 'sigmoid': np.tanh,
                 'noise_amp': 0.001}

trial_params = {'length_ms': 1000,
                'spacing': 2,
                'time_step': 1,
                'start_train_ms': 250,
                'end_train_ms': 1400,}

input_params = {'n_units': 1,
                'value': 5,
                'start_ms': 200,
                'duration_ms': 50}

output_params = {'n_units': 1,
                 'value': 1,
                'center_ms': 1250,
                'width_ms': 30,
                'baseline_val': 0.2}

train_params = {'n_trials_recurrent': 20,
                'n_trials_readout': 10,
                'n_trials_test': 10}

Net = N.Network(**network_params)
Tryal = N.Trial(**trial_params)
In = N.Input(Tryal, **input_params)
Out = N.Output(Tryal, **output_params)
Train = N.Trainer(Net, In, Out, Tryal, **train_params)

In [175]:
## connectivity matrices

# "generator" network recurrent weight matrix (WXX)
# indices are define as WXX[postsyn, presyn]

# logical mask for non-zero connections
WXX_mask = np.random.rand(Net.n_units, Net.n_units)  # uniform distribution!
WXX_mask[WXX_mask <= Net.p_connect] = 1
WXX_mask[WXX_mask < 1] = 0

# connection weights
WXX_vals = np.random.normal(scale=Net.scale_recurr, size=(Net.n_units, Net.n_units))

# create non-sparse version of WXX and set self-connections (diagonal elements) to 0
WXX_nonsparse = WXX_vals * WXX_mask
np.fill_diagonal(WXX_nonsparse, 0)

# convert to be sparse
WXX = csr_matrix(WXX_nonsparse)
WXX_c = csc_matrix(WXX_nonsparse) # for testing
WXX_o = coo_matrix(WXX_nonsparse) # for testing

# make a copy
WXX_ini = WXX.copy()

# input => generator weights
WInputX = np.random.normal(scale=1, size=(Net.n_units, In.n_units))

# generator weights => output
WXOut = np.random.normal(scale=1/np.sqrt(Net.n_units), size=(Out.n_units, Net.n_units))

# make a copy
WXOut_ini = WXOut.copy()

In [180]:
# check type size and range of these mats

ck_mats = (WXX_nonsparse, WXX, WXX_c, WXX_o, WInputX, WXOut)

for cm in ck_mats:
    print(type(cm), cm.shape, np.min(cm), np.max(cm))

<class 'numpy.ndarray'> (800, 800) -0.647480618689 0.677844715749
<class 'scipy.sparse.csr.csr_matrix'> (800, 800) -0.647480618689 0.677844715749
<class 'scipy.sparse.csc.csc_matrix'> (800, 800) -0.647480618689 0.677844715749
<class 'scipy.sparse.coo.coo_matrix'> (800, 800) -0.647480618689 0.677844715749
<class 'numpy.ndarray'> (800, 1) -3.58742954076 2.9278030976
<class 'numpy.ndarray'> (1, 800) -0.132783321659 0.105244799403


In [176]:
# trial training time indices

start_train_n = np.round(Tryal.start_train_ms/Tryal.time_step)
end_train_n = np.round(Tryal.end_train_ms/Tryal.time_step)

In [177]:
# main loop

# vectors representing the activity of the RRN units and ouputs over time
X_history = np.zeros((Net.n_units, Tryal.n_steps))
Out_history = np.zeros((Out.n_units, Tryal.n_steps))

# ???
WXOut_len = np.zeros((1, Tryal.n_steps))
WXX_len = np.zeros((1, Tryal.n_steps))
dW_readout_len = np.zeros((1, Tryal.n_steps))
dW_recurr_len = np.zeros((1, Tryal.n_steps))
train_window = 0

# initial conditions

# initial Xv is random uniform distribution from -1 to +1
# this represents an analog firing rate
Xv = 2 * np.random.rand(Net.n_units, 1) - 1

# X is the sigmoid (tanh) of Xv, which will be bound from -0.76 to +0.76
# which represents a membrane potential
# as firing rate increases,
# membrane potential increases less quickly than linearly
X = Net.sigmoid(Xv)

O = np.zeros((Out.n_units,1))

In [194]:
# what does the sigmoid do?

s = np.linspace(-1, 1, 100)
h = Net.sigmoid(s)

f, ax = plt.subplots()

ax.plot(s, h)
# ax.plot(s, s, 'k--')
ax.set_xlabel('Firing Rate')
ax.set_ylabel('Membrane Potential (uV)')

ax.legend()



In [189]:
# check type size and range of these mats

ck_mats = (X_history, Out_history, WXOut_len, WXX_len,
           dW_readout_len, dW_recurr_len, Xv, X, O)

for cm in ck_mats:
    print(type(cm), cm.shape, np.min(cm), np.max(cm))

<class 'numpy.ndarray'> (800, 1600) 0.0 0.0
<class 'numpy.ndarray'> (1, 1600) 0.0 0.0
<class 'numpy.ndarray'> (1, 1600) 0.0 0.0
<class 'numpy.ndarray'> (1, 1600) 0.0 0.0
<class 'numpy.ndarray'> (1, 1600) 0.0 0.0
<class 'numpy.ndarray'> (1, 1600) 0.0 0.0
<class 'numpy.ndarray'> (800, 1) -0.999716542914 0.998805284375
<class 'numpy.ndarray'> (800, 1) -0.761475085551 0.761091949335
<class 'numpy.ndarray'> (1, 1) 0.0 0.0


In [None]:
# integration loop

# constant value by which the update to Xv based on the summation
# of recurrent generator network inputs AND external input inputs
# AND noise, are divided...
# this simulates a neural time constant?
time_div = Net.tau_ms / Tryal.time_step

for i in range(Tryal.n_steps):

    # update units
    
    # (In.n_units, 1)
    in_vec = input_pattern[:, i]
    
    # (Net.n_units, 1)
    noise = use_noiseamp * np.random.normal(scale=np.sqrt(Tryal.time_step), size=(Net.n_units,1))
    
    
#     Xv_current = \
#         WXX * X \ # (Net.n_units, Net.n_units) * (Net.n_units, 1) => (Net.n_units, 1)
#         + \
#         WInputX * in_vec \ # (Net.n_units, In.n_units) * (In.n_units, 1) => (Net.n_units, 1)
#         + \
#         noise # (Net.n_units, 1)
    Xv_current = WXX * X + WInputX * in_vec + noise
    
    Xv += \
        (-Xv + Xv_current) \
        / \
        time_div
    Xv += (-Xv + Xv_current) / time_div
    X = Net.sigmoid(Xv)
    Out = WXOut * X

    # start-end training window
    if (i == start_train_n)
        train_window = 1
    end
    if (i == end_train_n)
        train_window = 0
    end

    # training
    if (train_window == 1 && rem(i,learn_every) == 0)

        if TRAIN_RECURR == 1
            # train recurrent
            error = X - Target_innate_X(:,i)
            for plas = 1:numplastic_Units
                X_pre_plastic = X(pre_plastic_units(plas).inds)
                P_recurr_old = P_recurr(plas).P
                P_recurr_old_X = P_recurr_old*X_pre_plastic
                den_recurr = 1 + X_pre_plastic'*P_recurr_old_X
                P_recurr(plas).P = P_recurr_old - (P_recurr_old_X*P_recurr_old_X')/den_recurr
                # update network matrix
                dW_recurr = -error(plas)*(P_recurr_old_X/den_recurr)'
                WXX(plas,pre_plastic_units(plas).inds) = WXX(plas,pre_plastic_units(plas).inds) + dW_recurr
                # store change in weights
                dW_recurr_len(i) = dW_recurr_len(i) + np.sqrt(dW_recurr*dW_recurr')
            end
        end

        if TRAIN_READOUT == 1
            # update inverse correlation matrix (using property P' = P)
            P_readout_old = P_readout
            P_readout_old_X = P_readout_old*X
            den_readout = 1 + X'*P_readout_old_X
            P_readout = P_readout_old - (P_readout_old_X*P_readout_old_X')/den_readout
            # update error
            error = Out - target_Out(i)
            # update output weights
            dW_readout = -error*(P_readout_old_X/den_readout)'
            WXOut = WXOut + dW_readout
            # store change in weights
            dW_readout_len(i) = np.sqrt(dW_readout*dW_readout')
        end

    end
    # store output
    Out_history(:,i) = Out
    X_history(:,i) = X
    WXOut_len(i) = np.sqrt(sum(reshape(WXOut.^2,numOut*numUnits,1)))
    WXX_len(i) = np.sqrt(sum(reshape(WXX.^2,numUnits^2,1)))
end

In [None]:
# raw matlab code

WXOut_len = np.zeros((1, Tryal.n_steps))
WXX_len = np.zeros((1, Tryal.n_steps))
dW_readout_len = np.zeros((1, Tryal.n_steps))
dW_recurr_len = np.zeros((1, Tryal.n_steps))
train_window = 0

# initial conditions
Xv = 1*(2*np.random.rand(numUnits,1)-1)
X = Net.sigmoid(Xv)
Out = np.zeros(numOut,1)


# integration loop
for i = 1:n_steps

    if rem(i,round(n_steps/10)) == 0 && (TRAIN_RECURR == 1 || TRAIN_READOUT == 1)
        fprintf('.')
    end

    in_vec= input_pattern(:,i)

    # update units
    noise = use_noiseamp*np.random.normal(numUnits,1)*np.sqrt(Tryal.time_step)
    Xv_current = WXX*X + WInputX*in_vec+ noise
    Xv = Xv + ((-Xv + Xv_current)./tau)*Tryal.time_step
    X = Net.sigmoid(Xv)
    Out = WXOut*X

    # start-end training window
    if (i == start_train_n)
        train_window = 1
    end
    if (i == end_train_n)
        train_window = 0
    end

    # training
    if (train_window == 1 && rem(i,learn_every) == 0)

        if TRAIN_RECURR == 1
            # train recurrent
            error = X - Target_innate_X(:,i)
            for plas = 1:numplastic_Units
                X_pre_plastic = X(pre_plastic_units(plas).inds)
                P_recurr_old = P_recurr(plas).P
                P_recurr_old_X = P_recurr_old*X_pre_plastic
                den_recurr = 1 + X_pre_plastic'*P_recurr_old_X
                P_recurr(plas).P = P_recurr_old - (P_recurr_old_X*P_recurr_old_X')/den_recurr
                # update network matrix
                dW_recurr = -error(plas)*(P_recurr_old_X/den_recurr)'
                WXX(plas,pre_plastic_units(plas).inds) = WXX(plas,pre_plastic_units(plas).inds) + dW_recurr
                # store change in weights
                dW_recurr_len(i) = dW_recurr_len(i) + np.sqrt(dW_recurr*dW_recurr')
            end
        end

        if TRAIN_READOUT == 1
            # update inverse correlation matrix (using property P' = P)
            P_readout_old = P_readout
            P_readout_old_X = P_readout_old*X
            den_readout = 1 + X'*P_readout_old_X
            P_readout = P_readout_old - (P_readout_old_X*P_readout_old_X')/den_readout
            # update error
            error = Out - target_Out(i)
            # update output weights
            dW_readout = -error*(P_readout_old_X/den_readout)'
            WXOut = WXOut + dW_readout
            # store change in weights
            dW_readout_len(i) = np.sqrt(dW_readout*dW_readout')
        end

    end
    # store output
    Out_history(:,i) = Out
    X_history(:,i) = X
    WXOut_len(i) = np.sqrt(sum(reshape(WXOut.^2,numOut*numUnits,1)))
    WXX_len(i) = np.sqrt(sum(reshape(WXX.^2,numUnits^2,1)))
end

# testing 1

In [4]:
# check: plot input and output patterns

print(Tryal.time_ms.shape)
print(In.series.shape)
print(Out.series.shape)

f, ax = plt.subplots()
ax.plot(Tryal.time_ms, In.series.T, label='input')
ax.plot(Tryal.time_ms, Out.series.T, label='output')

ax.legend()

(1600,)
(1, 1600)
(1, 1600)


<matplotlib.legend.Legend at 0x114726550>

In [28]:
# quick verification that all of the following ways of creating
# a matrix containing random normal numbers
# are equivalent

prng = RandomState(1234) # note one needs to re-call this to get same results every time
test1 = prng.normal(size=(Net.n_units, Net.n_units)) * Net.scale_recurr

prng = RandomState(1234)
test2 = prng.normal(scale=Net.scale_recurr, size=(Net.n_units, Net.n_units))

prng = RandomState(1234)
test3 = norm.rvs(scale=Net.scale_recurr, size=(Net.n_units, Net.n_units), random_state=prng)

In [29]:
assert_allcombs_equal((test1, test2, test3), ck_fun=assert_array_equal)

In [None]:
# testing equality / speed of generating random sparse matrices

# method 1

In [169]:
# %%timeit

# uniform distribution for mask
prng = RandomState(1234)
WXX_mask = prng.rand(Net.n_units, Net.n_units)
WXX_mask[WXX_mask <= Net.p_connect] = 1
WXX_mask[WXX_mask < 1] = 0

# normal distribution for vals
prng = RandomState(1234)
WXX_vals = prng.normal(scale=Net.scale_recurr, size=(Net.n_units, Net.n_units))

# create non-sparse version of WXX and set self-connections (diagonal elements) to 0
WXX_nonsparse = WXX_vals * WXX_mask
np.fill_diagonal(WXX_nonsparse, 0)
# WXX_nonsparse[np.diag_indices_from(WXX_nonsparse)] = 0

# convert to be sparse
# WXX = coo_matrix(WXX_nonsparse)
WXX = csr_matrix(WXX_nonsparse)

In [170]:
# method 2

In [171]:
rvs = norm(scale=Net.scale_recurr).rvs

In [172]:
# %%timeit

# random variable generator object
prng = RandomState(1234)
WXX_s = random(Net.n_units, Net.n_units, density=Net.p_connect, random_state=prng, data_rvs=rvs,
               format='csr',)
#                format='csr',)
# WXX_s[np.diag_indices_from(WXX_s)] = 0
WXX_s.setdiag(0)



In [173]:
cm = WXX_nonsparse
print(cm.shape, np.count_nonzero(cm), cm.min(), cm.max())

(800, 800) 63797 -0.775266142687 0.726642624784


In [174]:
ck_mats = (WXX, WXX_s)

for cm in ck_mats:
    print(cm.shape, cm.nnz, cm.count_nonzero(), cm.min(), cm.max())

(800, 800) 63797 63797 -0.775266142687 0.726642624784
(800, 800) 64723 63923 -0.742525396081 0.73124373277


In [96]:
f, ax = plt.subplots()
imshow_cb(WXX.toarray(), ax)

In [95]:
f, ax = plt.subplots()
imshow_cb(WXX_s.toarray(), ax)

In [None]:
# second method is slower :(

# testing 2

In [None]:
ck_mats = (WXX_nonsparse, WXX, WXX_c, WXX_o, WInputX, WXOut)

In [None]:
# testing speed of *

In [109]:
%%timeit
test1_ns = WXX_nonsparse*X # SLOW, WRONG SHAPE

1000 loops, best of 3: 700 µs per loop


In [110]:
%%timeit
test1 = WXX*X

The slowest run took 4.41 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 60.6 µs per loop


In [111]:
%%timeit
test1_c = WXX_c*X # NO DIFFERENCE WITH ROW / COLUMN SPARSE MATRICES

10000 loops, best of 3: 62.5 µs per loop


In [112]:
%%timeit
test1_o = WXX_o*X # COO sparse matrix is slower

10000 loops, best of 3: 159 µs per loop


In [133]:
test1_ns = WXX_nonsparse*X # SLOW, WRONG SHAPE
test1 = WXX*X
test1_c = WXX_c*X # NO DIFFERENCE WITH ROW / COLUMN SPARSE MATRICES
test1_o = WXX_o*X # COO sparse matrix is slower

print(test1_ns.shape)
print(test1.shape)
print(test1_c.shape)
print(test1_o.shape)

(800, 800)
(800, 1)
(800, 1)
(800, 1)


In [116]:
ck_mats = (test1, test1_c, test1_o)

In [118]:
assert_allcombs_equal(ck_mats, assert_array_equal)

In [119]:
# testing speed of dot

In [126]:
%%timeit
test2_ns = WXX_nonsparse.dot(X) # ONLY VERY SLIGHTLY SLOWER THAN
# SPARSE, CORRECT

The slowest run took 5.99 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 69.7 µs per loop


In [127]:
%%timeit
test2 = WXX.dot(X)

10000 loops, best of 3: 63.1 µs per loop


In [129]:
%%timeit
test2_c = WXX_c.dot(X)

The slowest run took 4.52 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 62.4 µs per loop


In [128]:
%%timeit
test2_o = WXX_o.dot(X)

10000 loops, best of 3: 160 µs per loop


In [134]:
test2_ns = WXX_nonsparse.dot(X) # ONLY VERY SLIGHTLY SLOWER THAN
# SPARSE, CORRECT BUT NUMERICALLY SLIGHTLY DIFFERENT?
test2 = WXX.dot(X)
test2_c = WXX_c.dot(X)
test2_o = WXX_o.dot(X)

In [135]:
ck_mats = (test1, test1_c, test1_o, test2, test2_c, test2_o, test2_ns)

In [140]:
ck_mats = (test2, test2_c, test2_o,)

In [136]:
assert_allcombs_equal(ck_mats)

In [141]:
assert_allcombs_equal(ck_mats, ck_fun=assert_array_equal)

In [None]:
# testing speed of matmul

In [145]:
%%timeit
test3_ns = np.matmul(WXX_nonsparse, X)

The slowest run took 7.55 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 69.5 µs per loop


In [148]:
%%timeit
test3 = np.matmul(WXX, X)

TypeError: Object arrays are not currently supported

In [147]:
test3_ns.shape

(800, 1)

In [None]:
# conclusion, for multiplying the sparse matrix by a vector, it's fastest
# for a csr_matrix and just use the * operator, it will do the matrix
# multiplication in about 60 us