In [1]:
import importlib

import numpy as np
from scipy.sparse import csr_matrix
from scipy.stats import norm

from matplotlib import pyplot as plt
%matplotlib tk

import network as N

In [2]:
# Re-execute network.py in place so edits made to that module are picked up
# without restarting the kernel (development convenience; the cell displays
# the returned module object).
importlib.reload(N)

<module 'network' from '/Users/michaelseay/Code/py-rrn/network.py'>

In [2]:
# Model, trial, input/output and training configuration.
#
# Seed NumPy's legacy global RNG (used by np.random.rand / np.random.normal in
# the connectivity cell, and possibly inside network.py -- unverified) so that
# Restart Kernel -> Run All reproduces the same network and weights.
SEED = 42
np.random.seed(SEED)

network_params = {'n_units': 800,
                  'p_plastic': 0.6,
                  'p_connect': 0.1,
                  'syn_strength': 1.5,
                  'tau_ms': 10,
                  'sigmoid': np.tanh,
                  'noise_amp': 0.001}

# NOTE(review): end_train_ms (1400) and the output center_ms (1250) are both
# larger than length_ms (1000), yet the trial time axis has 1600 samples --
# confirm how network.Trial derives the trial duration from these values.
trial_params = {'length_ms': 1000,
                'spacing': 2,
                'time_step': 1,
                'start_train_ms': 250,
                'end_train_ms': 1400}

input_params = {'n_units': 1,
                'value': 5,
                'start_ms': 200,
                'duration_ms': 50}

output_params = {'n_units': 1,
                 'value': 1,
                 'center_ms': 1250,
                 'width_ms': 30,
                 'baseline_val': 0.2}

train_params = {'n_trials_recurrent': 20,
                'n_trials_readout': 10,
                'n_trials_test': 10}

In [3]:
# Build the model objects from the parameter dicts.
# Input and Output are each constructed against the Trial (passed as their
# first argument); the Trainer receives the network, input, output and trial.
# NOTE(review): `In` and `Out` shadow IPython's built-in input/output history
# variables (In[n] / Out[n]); renaming them would also require updating the
# later plotting and weight-construction cells that reference them.
Net = N.Network(**network_params)
Tryal = N.Trial(**trial_params)
In = N.Input(Tryal, **input_params)
Out = N.Output(Tryal, **output_params)
Train = N.Trainer(Net, In, Out, Tryal, **train_params)

In [4]:
# Check: plot the input pattern over the trial time axis.

print(In.series.shape)

f, ax = plt.subplots()
# Transpose so time runs along the first axis; matplotlib then draws one line
# per remaining column (presumably one per input unit -- series printed as (1, 1600)).
ax.plot(Tryal.time_ms, In.series.T)
ax.set(title='Input pattern', xlabel='time (ms)', ylabel='input value');

(1, 1600)


[<matplotlib.lines.Line2D at 0x10fc0b8d0>]

In [5]:
# Check: plot the output pattern over the trial time axis.

print(Out.series.shape)

f, ax = plt.subplots()
# .T is kept for symmetry with the input plot; for a 1-D series it is a no-op.
ax.plot(Tryal.time_ms, Out.series.T)
ax.set(title='Output pattern', xlabel='time (ms)', ylabel='output value');

(1600,)


[<matplotlib.lines.Line2D at 0x1161e6978>]

In [5]:
## connectivity matrices
#
# "Generator" network recurrent weight matrix WXX, indexed as WXX[postsyn, presyn].

# Connection mask: an entry is 1.0 when its uniform [0, 1) draw is <= p_connect,
# i.e. each connection exists independently with probability Net.p_connect.
WXX_mask = (np.random.rand(Net.n_units, Net.n_units) <= Net.p_connect).astype(np.float64)

# Gaussian connection weights with standard deviation Net.scale_recurr.
WXX_vals = np.random.normal(scale=Net.scale_recurr, size=(Net.n_units, Net.n_units))

# Dense masked weights with self-connections (the diagonal) removed.
WXX_nonsparse = WXX_vals * WXX_mask
np.fill_diagonal(WXX_nonsparse, 0)

# Sparse (CSR) recurrent matrix, plus an untouched snapshot of its initial state.
WXX = csr_matrix(WXX_nonsparse)
WXX_ini = WXX.copy()

# Input => generator weights: standard normal, shape (n_units, n_inputs).
WInputX = np.random.normal(scale=1, size=(Net.n_units, In.n_units))

# Generator => output weights: std 1/sqrt(n_units), shape (n_outputs, n_units),
# plus a snapshot of the initial readout.
WXOut = np.random.normal(scale=1/np.sqrt(Net.n_units), size=(Out.n_units, Net.n_units))
WXOut_ini = WXOut.copy()

In [6]:
# Checks: report the type, shape and value range of each weight matrix, then
# assert the shapes and the zero diagonal the connectivity cell is meant to
# produce, so a construction mistake fails loudly instead of just printing.

ck_mats = (WXX, WInputX, WXOut)

for cm in ck_mats:
    print(type(cm), cm.shape, np.min(cm), np.max(cm))

assert WXX.shape == (Net.n_units, Net.n_units), "WXX must be (n_units, n_units)"
assert WInputX.shape == (Net.n_units, In.n_units), "WInputX must be (n_units, n_inputs)"
assert WXOut.shape == (Out.n_units, Net.n_units), "WXOut must be (n_outputs, n_units)"
# No self-connections in the recurrent matrix.
assert not WXX.diagonal().any(), "WXX has non-zero self-connections"

<class 'scipy.sparse.csr.csr_matrix'> (800, 800) -0.684088636613 0.699169817345
<class 'numpy.ndarray'> (800, 1) -3.31463468482 2.7316477414
<class 'numpy.ndarray'> (1, 800) -0.100152488589 0.0990310082078


In [None]:
# Convert the training window from ms to a number of time steps. The window
# bounds are read from trial_params, where they are defined; the result is an
# int so it can be used directly as an array index or slice bound.
# NOTE(review): assumes Tryal.time_ms starts at 0, so ms / time_step is a
# sample index -- confirm against network.Trial.
start_train_n = int(np.round(trial_params['start_train_ms'] / Tryal.time_step))
end_train_n = int(np.round(trial_params['end_train_ms'] / Tryal.time_step))