<a href="https://colab.research.google.com/github/cianadeveau/NeuroRNN/blob/main/RNN_seqModel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Colab setup: mount Google Drive and expose the Drive-hosted source folder.
from google.colab import drive
import sys

drive.mount('/content/drive')
# Put the Drive folder first on the import path; presumably this is where the
# local helper module imported later in the notebook lives (verify the path).
sys.path.insert(0,'/content/drive/My Drive/src')

In [None]:
"""
Requires PyTorch -> conda install pytorch::pytorch torchvision torchaudio -c pytorch
"""
# setup
%matplotlib inline
%reload_ext autoreload
%autoreload 2

# for system
import numpy as np
from scipy.special import softmax
import matplotlib.pyplot as plt
import torch
import seaborn as sns
import matplotlib
from pathlib import Path

import modules_simpleRNN as mm

In [None]:
matplotlib.rc('pdf', fonttype=42)
plt.rcParams['figure.dpi'] = 300
plt.rcParams['axes.labelsize'] = 7
plt.rcParams['xtick.labelsize'] = 7
plt.rcParams['ytick.labelsize'] = 7
plt.rcParams['legend.fontsize'] = 7
plt.rcParams['axes.titlesize'] = 7
!apt-get -qq install fonts-noto-cjk

import matplotlib.pyplot as plt
plt.rcParams['font.family'] = 'DejaVu Sans'

In [None]:
# Folder on the mounted Drive used for saving plots etc.
datadir = (Path('/content/drive') / 'My Drive' / 'Colab Notebooks').expanduser()

In [None]:
#%%
# =============================================================================
# Parameters
# =============================================================================
mm.set_plot()

# --- Network size -----------------------------------------------------------
N = 500                        # number of neurons
hidden_size = output_size = N  # hidden layer and readout both have one unit per neuron

# --- Dynamics ---------------------------------------------------------------
Tau = 60           # membrane time constant (ms)
dta = 1            # integration timestep (ms)
alpha = dta/Tau    # timestep / time-constant ratio handed to the RNN

# --- Trial timing -----------------------------------------------------------
Tmax = 1000                   # length of trial (ms)
ts = np.arange(0, Tmax, dta)  # time axis
Nt = ts.size                  # number of timepoints

trials = 100   # number of trials

s_Inp = 0.1    # scale of input weights

# --- Stimuli ----------------------------------------------------------------
seq_length = 5   # number of patterns in the sequence
n_stim = 20      # number of cells assigned to each input pattern
num_seq = 32     # total number of sequences (rows of the training arrays)
num_inputs = 18  # number of distinct input patterns

input_size = seq_length  # n input per sequence

# --- Training ---------------------------------------------------------------
lr = 0.01    # learning rate
n_ep = 100   # number of training epochs

# --- Connectivity gains -----------------------------------------------------
g0 = 0.8     # initial recurrent gain
g_post = 2   # gain of responses after adaptation

In [None]:
#%%
# =============================================================================
#   Generate connectivity, input and output patterns
# =============================================================================

# Readout weights: identity matrix over the hidden units.
wo_init = np.eye(hidden_size)

# Input weights: one row per input pattern (num_inputs rows, hidden_size cells).
# Each pattern drives a block of n_stim consecutive cells; neighbouring blocks
# share n_overlap cells (0 here, so the blocks are disjoint).
wi_init = np.zeros((num_inputs, hidden_size))

n_overlap=0

cell0, celln = 0, n_stim
for inp in range(num_inputs):
    # Exponentially distributed weights, stored strongest-first within the block.
    block_weights = np.random.exponential(scale=0.05, size=n_stim)
    block_weights.sort()
    wi_init[inp, cell0:celln] = block_weights[::-1]
    # Slide the window to the next block of cells.
    cell0 = celln - n_overlap
    celln = cell0 + n_stim

# Recurrent weights: standard-normal entries scaled by g0/sqrt(N).
rec_scale = g0*(1/np.sqrt(N))
wrec_init = rec_scale*np.random.randn(N, N)

In [None]:
def create_sequences(base_array, num_inputs, window_size=5):
    """Return every sliding window of `window_size` items taken from `base_array`.

    Windows start at positions 0, 1, ..., num_inputs - window_size, so there are
    num_inputs - window_size + 1 of them (none when that count is not positive).
    Each window is the slice base_array[start:start + window_size].
    """
    last_start = num_inputs - window_size
    windows = []
    for start in range(last_start + 1):
        windows.append(base_array[start:start + window_size])
    return windows

In [None]:
# Sequence Training Protocol
# "Natural" sequences: every sliding window of seq_length consecutive patterns.
# The window size is passed explicitly so these sequences always have the same
# length as the random ones appended below.
natural = np.arange(0, num_inputs)
seqs = create_sequences(natural, num_inputs, window_size=seq_length)
natural_seq = len(seqs)  # number of natural sequences; they occupy the first rows

# Fill the list up to num_seq with random sequences of seq_length distinct patterns.
for i in range(num_seq-natural_seq):
    seqs.append(np.random.choice(np.arange(0,num_inputs), seq_length, replace=False))

In [None]:
# Amplitude profile applied to the cells of each pattern: the first 7 entries
# are scaled by 2 and the remaining 13 by 0.5 (20 entries in total).
amp_mask = np.repeat([2, 0.5], [7, 13])
wi_init_masks = np.zeros((num_inputs, hidden_size))

for pat in range(num_inputs):
    # Cells that carry a positive input weight for this pattern.
    active_cells = np.nonzero(wi_init[pat] > 0)[0]
    wi_init_masks[pat, active_cells] = wi_init[pat, active_cells] * amp_mask

In [None]:
#%%
# =============================================================================
#  Generate temporal profile of input and output
# =============================================================================
# Training arrays, indexed as (sequence, time, channel).
input_train = np.zeros((num_seq, Nt, num_inputs))
output_train = np.zeros((num_seq, Nt, output_size))
mask_train = np.ones_like(output_train)  # all ones: every time point is included in the loss


I_length = 120                    # input duration (ms)
I_length_int = int(I_length/dta)  # input duration in timesteps
min_sil = 50                      # minimum silence before/after each input
max_sil = 400                     # maximum silence before/after each input

tSt = []   # onset time (ms) of the first input, one entry per sequence
iT0s = []  # onset index of every pulse written for the natural sequences
for tr in range(num_seq):
    pat_list = seqs[tr]
    is_natural = tr < natural_seq
    # Natural sequences get unit-amplitude pulses; the remaining sequences are
    # written with 0, which leaves their rows of input_train at zero.
    level = 1. if is_natural else 0
    T0 = 200  # fixed onset of the first input (ms)
    tSt.append(T0)
    iT0 = int(T0/dta)  # first input time index
    for pat in pat_list:
        if is_natural:
            iT0s.append(iT0)
        input_train[tr, iT0:iT0 + I_length_int, pat] = level
        # Pulses are presented back to back.
        iT0 += I_length_int


In [None]:
# Input time course used to build the target outputs, indexed as
# (sequence, time, input pattern).
input_target = np.zeros((num_seq, Nt, num_inputs))

I_length = 120                    # input duration (ms)
I_length_int = int(I_length/dta)  # input duration in timesteps
min_sil = 50                      # minimum silence before/after each input
max_sil = 400                     # maximum silence before/after each input

tSt = []  # onset time (ms) of the first input, one entry per sequence
for tr in range(num_seq):
    pat_list = seqs[tr]
    T0 = 200  # fixed onset of the first input (ms)
    tSt.append(T0)
    iT0 = int(T0/dta)  # first input time index
    # Pulses of 1 for natural sequences; 0 (i.e. no change) for the others.
    fill_value = 1 if tr < natural_seq else 0
    for pat in pat_list:
        pulse_end = iT0 + I_length_int
        input_target[tr, iT0:pulse_end, pat] = fill_value
        iT0 = pulse_end  # next pulse starts where this one ends
        # print(iT0, Nt)

In [None]:
# Target output: the input time course projected through the amplitude-masked
# input weights and then the readout weights. From the second sample on, the
# target at a given step is computed from the input one step earlier; the
# first sample uses the input at step 0.
output_train[:, 0, :] = (input_target[:, 0, :] @ wi_init_masks) @ wo_init
for step in range(1, Nt):
    output_train[:, step, :] = (input_target[:, step - 1, :] @ wi_init_masks) @ wo_init

In [None]:
#%%
# =============================================================================
#   Convert numpy variables into PyTorch tensors
# =============================================================================
dtype = torch.FloatTensor


def _as_cuda_tensor(array):
    """Wrap a numpy array as a tensor of type `dtype` and move it to the GPU."""
    return torch.from_numpy(array).type(dtype).to('cuda')


# Training data
Output_train = _as_cuda_tensor(output_train)
Input_train = _as_cuda_tensor(input_train)
Mask_train = _as_cuda_tensor(mask_train)

# Initial weights (input, readout, recurrent)
Wi_init = _as_cuda_tensor(wi_init)
Wo_init = _as_cuda_tensor(wo_init)
Wrec_init = _as_cuda_tensor(wrec_init)

In [None]:
#%%
# =============================================================================
#   Initialize and train networks
# =============================================================================
num_epochs = 500  # number of epochs passed to mm.train below

# Initialize naive network from the initial input, readout and recurrent weights.
Net_temp = mm.RNN(num_inputs, hidden_size, output_size, Wi_init, Wo_init, Wrec_init, alpha=alpha)

# Train on the GPU with Adam and gradient clipping at 2; the learning curve and
# gradient are plotted, and loss/parameters are saved by the training routine.
# NOTE(review): the return values are presumably the loss history and the
# recurrent weights — confirm against mm.train.
loss_temp, wrec_temp = mm.train(Net_temp, Input_train, Output_train, Mask_train, n_epochs=num_epochs, plot_learning_curve=True, plot_gradient=True,
                              lr=lr, clip_gradient = 2.,  cuda=True, save_loss=True, save_params=True, adam=True)

In [None]:
# Run the trained network on the training inputs, then detach the result from
# the graph and copy it to host memory as a numpy array.
Output = Net_temp.forward(Input_train)
output = Output.detach().cpu().numpy()

## First Panel

In [None]:
# Example sequences for this panel: first a trained one, then an untrained one.
sub_inps = [list(range(1, 6)), [13, 5, 16, 14, 2]]

In [None]:
def _stimulated_cells(patterns):
    """Concatenate, in order, the indices of cells with positive input weight for each pattern."""
    return np.concatenate([np.where(wi_init[p] > 0)[0] for p in patterns], axis=0)


# Cells driven by the patterns of the two example sequences (same pattern
# lists as the two entries of sub_inps for this panel).
stim_cells_amp = _stimulated_cells([1, 2, 3, 4, 5])
stim_cells_supp = _stimulated_cells([13, 5, 16, 14, 2])

In [None]:
## Make the test input
# One row per test sequence; each pattern is a unit pulse of I_length_int
# steps, presented back to back starting at T0.
test_input = np.zeros((len(sub_inps), Nt, num_inputs))
for tr, patterns in enumerate(sub_inps):
    T0 = 200
    iT0 = int(T0/dta)
    for pat in patterns:
        test_input[tr, iT0:iT0 + I_length_int, pat] = 1.
        iT0 += I_length_int

In [None]:
# Move the test input to the GPU, run the network on it, and bring the
# result back to host memory as a numpy array.
Test_input = torch.from_numpy(test_input).type(dtype).to('cuda')
Output_test = Net_temp.forward(Test_input)
output_test = Output_test.detach().cpu().numpy()

In [None]:
# =============================================================================
#   Plot results
# =============================================================================

# Mean network output over the stimulated cells for the two test sequences:
# red = first test sequence, blue = second test sequence.
iTr = 1 # sample index (not used in this cell)
fig = plt.figure(figsize=[2, 1.5])
ax1 = plt.gca()
for axis in ['bottom','left']:
    ax1.spines[axis].set_linewidth(0.25)
ax1.spines['right'].set_visible(False)
ax1.spines['top'].set_visible(False)

# Time axis built from the simulation grid: one point per output sample,
# spaced by dta ms. (A hard-coded linspace(0, 1000, 1000) has spacing
# 1000/999 and does not follow changes to Tmax or dta.)
x1 = np.arange(output_test.shape[1]) * dta

# output_test[k, :, cells] has shape (n_cells, n_time); transpose so that the
# mean over axis 1 averages across cells at each time point.
plt.plot(x1, output_test[0,:,stim_cells_amp].T.mean(axis=1), c='red', alpha=0.5, lw=0.5)
plt.plot(x1, output_test[1,:,stim_cells_supp].T.mean(axis=1), c='blue', alpha=0.5, lw=0.5)


plt.xlabel('time (ms)', fontsize=7)
plt.ylabel('average rate', fontsize=7)
plt.xticks(fontsize=7)
plt.yticks(fontsize=7)

plt.show()

## Second Panel

In [None]:
# Two-pattern test sequences that both end on pattern 6:
# matched context (5 then 6) vs. unmatched context (2 then 6).
sub_inps = [[first, 6] for first in (5, 2)]

In [None]:
## Make the test input
# Shape: (test sequence, time, input pattern). Patterns of a sequence are
# unit pulses of I_length_int steps, one directly after the other.
n_test = len(sub_inps)
test_input = np.zeros((n_test, Nt, num_inputs))
for tr in range(n_test):
    T0 = 200           # onset of the first pulse (ms)
    iT0 = int(T0/dta)  # onset as a time index
    for pat in sub_inps[tr]:
        pulse_end = iT0 + I_length_int
        test_input[tr, iT0:pulse_end, pat] = 1.
        iT0 = pulse_end

In [None]:
# GPU copy of the test input, network response to it, and that response
# detached and copied back to the host as a numpy array.
Test_input = torch.from_numpy(test_input).type(dtype).to('cuda')
Output_test = Net_temp.forward(Test_input)
output_test = Output_test.detach().cpu().numpy()

In [None]:
# =============================================================================
#   Plot results
# =============================================================================

# Mean network output over the cells driven by pattern 6, one trace per test
# sequence (colors taken in order from `Colors`).
iTr = 1 # sample index (not used in this cell)
fig = plt.figure(figsize=[2, 1.5])
ax1 = plt.gca()
for axis in ['bottom','left']:
    ax1.spines[axis].set_linewidth(0.25)
ax1.spines['right'].set_visible(False)
ax1.spines['top'].set_visible(False)

# Time axis built from the simulation grid: one point per output sample,
# spaced by dta ms. (A hard-coded linspace(0, 1000, 1000) has spacing
# 1000/999 and does not follow changes to Tmax or dta.)
x1 = np.arange(output_test.shape[1]) * dta

# Cells with a positive input weight for pattern 6; computed once, outside the loop.
target_cells = np.where(wi_init[6]>0)[0]

Colors = ['blue', 'green','yellow','green', 'blue', 'indigo']
for i in range(output_test.shape[0]):
    # output_test[i, :, cells] has shape (n_cells, n_time); transpose so that
    # the mean over axis 1 averages across cells at each time point.
    plt.plot(x1, output_test[i,:,target_cells].T.mean(axis=1), color=Colors[i], lw=0.5)

plt.xlabel('time (ms)', fontsize=7)
plt.ylabel('average rate', fontsize=7)
plt.xticks(fontsize=7)
plt.yticks(fontsize=7)
# plt.savefig(datadir/f'matchedvunmatched_control.pdf')
plt.show()