In [1]:
%reset -f
%matplotlib inline
import pickle
import scipy as sc
import scipy.io
import numpy as np
import pandas as pd
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from tensorflow.python.framework import ops
import statsmodels.formula.api as smf
from matplotlib import pyplot as plt
from sklearn.decomposition import PCA
from mpl_toolkits.mplot3d import Axes3D 
from sklearn.metrics import euclidean_distances

Instructions for updating:
non-resource variables are not supported in the long term


In [2]:
# Default canvas for every figure drawn in this notebook: wide and short.
plt.rcParams.update({'figure.figsize': [20, 5]})

In [3]:
# Load the six arrays that are later used as constant initialisers for the
# network variables.  The file is opened in a context manager so the handle is
# closed as soon as unpickling finishes (the bare open() call leaked it).
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open("./files/RNN_0.p", "rb") as rnn_file:
    x0_val, W0_in_val, W0_rec_val, W0_out_val, C_da_val, b_rec_val = pickle.load(rnn_file)

In [4]:
#-----------------------------------------------------------------------------------------
# Reward probability of each of the 27 stimuli, stored as (3, 3, 3) arrays
#-----------------------------------------------------------------------------------------
# mdprl schedule: the three (3, 3) tables below become prob_mdprl[:, :, 0..2]
# (the first and the last table are identical).
prob_mdprl = np.stack(
    [
        [[0.92, 0.75, 0.43], [0.50, 0.50, 0.50], [0.57, 0.25, 0.08]],
        [[0.16, 0.75, 0.98], [0.50, 0.50, 0.50], [0.02, 0.25, 0.84]],
        [[0.92, 0.75, 0.43], [0.50, 0.50, 0.50], [0.57, 0.25, 0.08]],
    ],
    axis=2,
)

# random schedule: independent uniform draws in [0, 1)
prob_rand = np.random.uniform(low=0, high=1, size=(3, 3, 3))

# generalizable schedule: the probability depends on the last axis only
prob_gen = np.tile([0.9, 0.5, 0.1], (3, 3, 1))

# uninformative schedule: every stimulus is rewarded with probability 0.5
prob_noinf = np.full((3, 3, 3), 0.5)

#-----------------------------------------------------------------------------------------
# Trial timing: units, integration constants, time axis and epoch masks
#-----------------------------------------------------------------------------------------
# units (everything is expressed in seconds)
s              = 1
ms             = 10**-3

# integration step and time constant of the state variable
dt             = 20*ms
tauX           = 100*ms
alphaX         = dt/tauX

# time axis of a single trial, from -0.5 s to 1.5 s in steps of dt
T              = np.linspace(-0.5*s, 1.5*s, 1+2*int(s/dt))

# boolean masks over the time axis (left-open, right-closed windows)
T_s            = np.logical_and(T > 0.0*s, T <= 1.2*s)    # stimulus is on the screen
T_da           = np.logical_and(T > 1.0*s, T <= 1.2*s)    # dopamine is released
T_ch           = np.logical_and(T > 0.7*s, T <= 1.0*s)    # choice is read (only used for making the target)

# weighting of the output error: 1.5 before stimulus onset, 1 during the
# choice window, 0 elsewhere (used for training the network)
T_sch          = T_ch + 1.5*(T < 0.0*s)

In [5]:
# ! /usr/bin/env python
"""
Generates input

Notes
-----
Generates the activity of the input populations for 27*N_s trials

"""

def generateinput(N_s, prob_index, T_s, T_da, T_ch):
    """Build one block of trials: input activity, reward signal and choice target.

    Each of the 27 stimuli (3 patterns x 3 shapes x 3 colours) is presented
    ``N_s`` times in a row, so the block holds ``27*N_s`` trials ordered by
    stimulus.  All array sizes are derived from the arguments; the function
    no longer reads the notebook-level time axis ``T``.

    Parameters
    ----------
    N_s : int
        Number of presentations of every stimulus.
    prob_index : array_like, shape (3, 3, 3)
        Reward probability of every stimulus, indexed [pattern, shape, colour].
    T_s, T_da, T_ch : 1-D bool ndarray
        Masks over the time steps of a trial marking, respectively, when the
        stimulus is shown, when the reward signal is delivered and when the
        choice is read out.  All three are expected to have the same length.

    Returns
    -------
    DA_s : ndarray, shape (n_steps, 27*N_s)
        +1 (rewarded) or -1 (unrewarded) inside ``T_da``, 0 elsewhere.  The
        outcome of each trial is one Bernoulli draw from ``np.random``.
    ch_s : ndarray, shape (n_steps, 27*N_s)
        Reward probability of the trial's stimulus inside ``T_ch``, 0 elsewhere.
    pop_s : ndarray, shape (n_steps, 27*N_s, 63)
        Input activity, 1 inside ``T_s`` for the seven units coding the
        stimulus.  Columns: 0-2 shape, 3-5 pattern, 6-8 colour,
        9-17 shape x pattern, 18-26 pattern x colour, 27-35 shape x colour,
        36-62 one unit per stimulus.
    pop_o : ndarray, shape (n_steps, 27*N_s, 27)
        One-hot stimulus identity, 1 inside ``T_s``.
    """
    #-----------------------------------------------------------------------------------------
    # time filters
    #-----------------------------------------------------------------------------------------
    filter_s     = T_s.astype(int)
    filter_da    = T_da.astype(int).reshape((-1, 1))
    filter_ch    = T_ch.astype(int).reshape((-1, 1))

    #-----------------------------------------------------------------------------------------
    # indexing features
    #-----------------------------------------------------------------------------------------
    # Feature value of every stimulus, in the same C-order flattening that is
    # applied to prob_index: axis 0 -> pattern, axis 1 -> shape, axis 2 -> colour.
    index_pttrn, index_shp, index_clr = (grid.flatten() for grid in np.indices((3, 3, 3)))
    index_shppttrn  = 3*index_shp   + index_pttrn
    index_pttrnclr  = 3*index_pttrn + index_clr
    index_shpclr    = 3*index_shp   + index_clr
    prob_index      = np.asarray(prob_index).flatten()

    #-----------------------------------------------------------------------------------------
    # generate population activity
    #-----------------------------------------------------------------------------------------
    index_s      = np.repeat(np.arange(27), N_s)      # stimulus shown on every trial
    n_trials     = len(index_s)
    trial        = np.arange(n_trials)

    # (trial, unit) table of the units that are active on each trial
    active       = np.zeros((n_trials, 63))
    feature_maps = ((0,  index_shp),
                    (3,  index_pttrn),
                    (6,  index_clr),
                    (9,  index_shppttrn),
                    (18, index_pttrnclr),
                    (27, index_shpclr),
                    (36, np.arange(27)))
    for offset, feature in feature_maps:
        active[trial, offset + feature[index_s]] = 1

    identity     = np.zeros((n_trials, 27))
    identity[trial, index_s] = 1

    # gate the static tables with the stimulus window -> (time, trial, unit)
    pop_s        = filter_s[:, None, None]*active[None, :, :]
    pop_o        = filter_s[:, None, None]*identity[None, :, :]

    #-----------------------------------------------------------------------------------------
    # reward outcome and choice target
    #-----------------------------------------------------------------------------------------
    R            = np.random.binomial(1, prob_index[index_s])
    ch_s         = filter_ch*prob_index[index_s]
    DA_s         = filter_da*R - filter_da*(1-R)

    # output
    return DA_s, ch_s, pop_s, pop_o

In [6]:
#-----------------------------------------------------------------------------------------
# Network size
#-----------------------------------------------------------------------------------------
N_s            = 10                # presentations of each of the 27 stimuli per batch
idxNin         = np.arange(63)     # input units in use; subsets: np.arange(0,9,1), np.arange(9,36,1), np.arange(36,63,1)
Nin            = idxNin.size
Nrec           = 120
Nout           = 1
batch_size     = 27*N_s
init_stddev    = 0.01

#-----------------------------------------------------------------------------------------
# Excitatory / inhibitory split: the first Nexc recurrent units are excitatory,
# the remaining Ninh are inhibitory
#-----------------------------------------------------------------------------------------
pE             = 0.8
Nexc           = int(pE*Nrec)
Ninh           = Nrec - Nexc
idx            = range(Nrec)
EXC, INH       = idx[:Nexc], idx[Nexc:]

# sign of each unit's outgoing weights: +1 excitatory, -1 inhibitory
F_rec          = np.where(np.arange(Nrec) < Nexc, 1.0, -1.0).reshape((1, Nrec))
# 1 for excitatory units, 0 for inhibitory units
extinh         = (np.arange(Nrec) < Nexc).astype(float).reshape((1, Nrec))

# only excitatory units project to the output
F_out          = np.zeros((Nrec, Nout))
F_out[:Nexc]   = 1

#-----------------------------------------------------------------------------------------
# Dopamine modulation masks (1 = the connection onto that unit is plastic)
#-----------------------------------------------------------------------------------------
# input weights: first half of the excitatory and last half of the inhibitory units
mda_in0        = np.zeros((Nin, Nrec))
mda_in0[:, :Nexc//2]       = 1
mda_in0[:, -(Ninh//2):]    = 1

# recurrent weights: middle half of the excitatory and of the inhibitory units
mda_rec0       = np.zeros((Nrec, Nrec))
mda_rec0[:, Nexc//4:3*Nexc//4]            = 1
mda_rec0[:, -(3*Ninh//4):-(Ninh//4)]      = 1

# all-ones mask; zero entries here to lesion recurrent connections
mda_lsn0       = np.ones((Nrec, Nrec))

#-----------------------------------------------------------------------------------------
# Noise: one Gaussian sample per (trial, time step) row and per unit
#-----------------------------------------------------------------------------------------
# Input noise -- base variance 0.001**2, rescaled by 2*tauX/dt
var_in         = 2*tauX/dt*(0.001**2)
noise_in       = (np.random.normal(size=(len(T)*batch_size, Nin))*np.sqrt(var_in)
                  if np.any(var_in > 0)
                  else np.zeros((len(T)*batch_size, Nin)))

# Recurrent noise -- base variance 0.015**2, rescaled by 2/dt.
# noise_rec covers a full batch, noise_rectest a single pass over the 27 stimuli.
var_rec        = 2/dt*(0.015**2)
if not np.any(var_rec > 0):
    noise_rec          = np.zeros((len(T)*batch_size, Nrec))
    noise_rectest      = np.zeros((len(T)*27, Nrec))
else:
    noise_rec          = np.random.normal(size=(len(T)*batch_size, Nrec))*np.sqrt(var_rec)
    noise_rectest      = np.random.normal(size=(len(T)*27, Nrec))*np.sqrt(var_rec)

#---------------------------------------------------------------------------------
# Constant initializers for the graph variables
#---------------------------------------------------------------------------------
# values loaded from the pickle file
x_init         = tf.constant_initializer(value=x0_val)
Win_init       = tf.constant_initializer(value=W0_in_val)
Wrec_init      = tf.constant_initializer(value=W0_rec_val)
Wout_init      = tf.constant_initializer(value=W0_out_val)
brec_init      = tf.constant_initializer(value=b_rec_val)
C_dainit       = tf.constant_initializer(value=C_da_val)
# output bias starts at zero
bout_init      = tf.constant_initializer(value=0)
# plasticity and lesion masks built in this notebook
mda_ininit     = tf.constant_initializer(value=mda_in0)
mda_recinit    = tf.constant_initializer(value=mda_rec0)
mda_lsninit    = tf.constant_initializer(value=mda_lsn0)

#---------------------------------------------------------------------------------
# Activation functions: hidden units, output unit and weights all use ReLU
# (on the weights it imposes positivity)
#---------------------------------------------------------------------------------
f_hidden = f_output = f_weight = tf.nn.relu

#-----------------------------------------------------------------------------------------
# Optimization parameters
#-----------------------------------------------------------------------------------------
lr             = 0.001
optimizer      = tf.train.AdamOptimizer

#-----------------------------------------------------------------------------------------
# Etc.
#-----------------------------------------------------------------------------------------



In [7]:
def f_relu(x):
    """Rectify ``x``: positive values pass through, negative values become zero.

    Implemented with ``abs``, addition and division only, so it accepts any
    operand that supports those operators (plain numbers, arrays, ...).
    """
    twice_positive_part = x + abs(x)
    return twice_positive_part / 2

In [8]:
#---------------------------------------------------------------------------------
# Network structure: start from an empty graph, then declare the placeholders
#---------------------------------------------------------------------------------
ops.reset_default_graph()

# Per-row signals; the leading dimension is len(T)*batch_size
T_schRNN       = tf.placeholder(dtype=tf.float32, shape=(None, 1))     # to calculate the output loss
T_daRNN        = tf.placeholder(dtype=tf.float32, shape=(None, 1))     # to implement reward learning
T_sRNN         = tf.placeholder(dtype=tf.float32, shape=(None, 1))     # to calculate loss of middle layer activity
T0_RNN         = tf.placeholder(dtype=tf.float32, shape=(None, 1))     # to implement resetting of x to x0

# Input, recurrent noise, hidden state and output target
u_RNN          = tf.placeholder(dtype=tf.float32, shape=(None, Nin))
nrec_RNN       = tf.placeholder(dtype=tf.float32, shape=(None, Nrec))
x_RNN          = tf.placeholder(dtype=tf.float32, shape=(None, Nrec))
z_RNN          = tf.placeholder(dtype=tf.float32, shape=(None, Nout))

# Weight trajectories (one matrix per row)
W_in           = tf.placeholder(dtype=tf.float32, shape=(None, Nin, Nrec))
W_rec          = tf.placeholder(dtype=tf.float32, shape=(None, Nrec, Nrec))

#---------------------------------------------------------------------------------
# Variables: trainable initial state and weights
#---------------------------------------------------------------------------------
x0_RNN         = tf.get_variable(name='x0_RNN', shape=(1, Nrec),    initializer=x_init,    trainable=True)
W0_in          = tf.get_variable(name='W0_in',  shape=(Nin, Nrec),  initializer=Win_init,  trainable=True)
W0_rec         = tf.get_variable(name='W0_rec', shape=(Nrec, Nrec), initializer=Wrec_init, trainable=True)
W0_out         = tf.get_variable(name='W0_out', shape=(Nrec, Nout), initializer=Wout_init, trainable=True)

# Biases: the recurrent bias is trained, the output bias is held fixed
b_rec          = tf.get_variable(name='b_rec',  shape=(1, Nrec),    initializer=brec_init, trainable=True)
b_out          = tf.get_variable(name='b_out',  shape=(1, Nout),    initializer=bout_init, trainable=False)

# Six trainable plasticity coefficients and the fixed modulation / lesion masks
C_da           = tf.get_variable(name='C_da',    shape=(6, 1),       initializer=C_dainit,    trainable=True)
mda_in         = tf.get_variable(name='mda_in',  shape=(Nin, Nrec),  initializer=mda_ininit,  trainable=False)
mda_rec        = tf.get_variable(name='mda_rec', shape=(Nrec, Nrec), initializer=mda_recinit, trainable=False)
mda_lsn        = tf.get_variable(name='mda_lsn', shape=(Nrec, Nrec), initializer=mda_lsninit, trainable=False)

In [9]:
def rnn(x, u):
    """One ``tf.scan`` step: advance the hidden state and update the weights.

    Parameters
    ----------
    x : Tensor, shape (1 + Nin + Nrec, Nrec)
        Previous scan state: the hidden state (1 row), the input weights
        (Nin rows) and the recurrent weights (Nrec rows) stacked on axis 0.
    u : tuple (u_in, n_rec, T0, Tda)
        Current-step slices of the scanned inputs: input activity, recurrent
        noise, reset flag and reward signal.

    Returns
    -------
    Tensor with the same layout as ``x`` holding the new hidden state and the
    weights after the reward-gated update.
    """
    # Unpack the stacked state and the per-step inputs.
    x_RNN, W_in, W_rec = tf.split(x,[1, Nin, Nrec],0)
    u_in, n_rec, T0, Tda = u
    # Rectify the weights so their magnitudes are non-negative.
    W_in       = f_weight(W_in)
    W_rec      = f_weight(W_rec)
    # Lesion mask (all ones unless mda_lsn has been changed).
    W_rec      = tf.multiply(W_rec, mda_lsn)
    # Zero the diagonal: tf.ones(Nrec) broadcasts against the identity matrix.
    W_rec      = tf.multiply(W_rec, tf.ones(Nrec)-tf.eye(Nrec))
    
    # Rates are computed from the state *before* this step's update; the same
    # rates drive both the dynamics and the plasticity terms below.
    r_RNN      = f_hidden(x_RNN)
    # Sign the rates: F_rec is +1 for excitatory and -1 for inhibitory units.
    rF_RNN     = tf.multiply(r_RNN, F_rec)
    
    # Leaky integration; when T0 is 1 the state is replaced by x0_RNN instead.
    x_RNN      = (1-T0)*((1 - alphaX)*x_RNN                                # Leak
                          + alphaX*(tf.matmul(rF_RNN, W_rec)               # Recurrent
                          + b_rec                                          # Bias
                          + tf.matmul(tf.reshape(u_in, (1, -1)), W_in)     # Input
                          + tf.reshape(n_rec, (1,-1)))                     # Recurrent noise
                          ) + T0*x0_RNN
    
    # Input-weight change, gated by Tda: a presynaptic-only term (C_da[0]),
    # a postsynaptic-only term (C_da[1]) and their outer product (C_da[2]).
    deltaW_in  = dt*( (Tda*C_da[0])*(tf.matmul(tf.reshape(u_in, (-1, 1)), tf.ones((1,Nrec))))
                    + (Tda*C_da[1])*(tf.matmul(tf.ones((Nin,1)), tf.reshape(r_RNN, (1, -1))))
                    + (Tda*C_da[2])*(tf.matmul(tf.reshape(u_in, (-1, 1)), tf.reshape(r_RNN, (1, -1)))))
    # Only connections selected by mda_in are modified.
    W_in       = W_in + tf.multiply(deltaW_in, mda_in)
    
    # Recurrent-weight change: same three terms with the rates on both sides
    # (coefficients C_da[3], C_da[4], C_da[5]).
    deltaW_rec = dt*( (Tda*C_da[3])*(tf.matmul(tf.reshape(r_RNN, (-1, 1)), tf.ones((1,Nrec))))
                    + (Tda*C_da[4])*(tf.matmul(tf.ones((Nrec,1)), tf.reshape(r_RNN, (1, -1))))
                    + (Tda*C_da[5])*(tf.matmul(tf.reshape(r_RNN, (-1, 1)), tf.reshape(r_RNN, (1, -1)))))
    # Only connections selected by mda_rec are modified.
    W_rec      = W_rec + tf.multiply(deltaW_rec, mda_rec)
    
    # Re-impose non-negativity and the zero diagonal after the update.
    W_in       = f_weight(W_in)
    W_rec      = f_weight(W_rec)
    W_rec      = tf.multiply(W_rec, tf.ones(Nrec)-tf.eye(Nrec))
    # Re-stack in the layout expected by the next scan step.
    return tf.concat([x_RNN, W_in, W_rec],0)

# Run rnn() over the leading axis of the four input placeholders.  The scan
# state starts from the trainable initial state and weights stacked on axis 0.
x              = tf.scan(rnn, (u_RNN, nrec_RNN, T0_RNN, T_daRNN), 
                                    initializer=tf.concat([x0_RNN, W0_in, W0_rec], 0))
# tf.scan stacks the per-step states on a new leading axis, so the state is
# split on axis 1 here.  This rebinds x_RNN, W_in and W_rec, which previously
# named placeholders.
x_RNN, W_in, W_rec = tf.split(x,[1, Nin, Nrec],1)
x_RNN          = tf.reshape(x_RNN , [-1, Nrec])
r              = f_hidden(x_RNN)
# Rectified output weights, masked by F_out so that only excitatory units
# contribute to the output.
WF_out         = tf.multiply(f_weight(W0_out), F_out)
z              = f_output(tf.matmul(r, WF_out) + b_out)
# Output loss: mean squared difference between the target z_RNN and the
# output weighted by T_schRNN.
lossZ          = tf.reduce_mean((z_RNN - tf.multiply(z,T_schRNN))**2)
# Activity penalty: mean squared rate, added with weight 0.001.
lossR          = tf.reduce_mean(r**2)
loss           = (lossZ) + 0.001*lossR

# Training op: the optimizer chosen above minimises the total loss.
train_op       = optimizer(learning_rate=lr).minimize(loss)

In [10]:
# Repeat the single-trial time masks once per trial in the batch and turn each
# into a column vector with one row per (trial, time step), trial after trial.
T_schbatch     = np.tile(T_sch, batch_size)[:, np.newaxis]
T_chbatch      = np.tile(T_ch, batch_size)[:, np.newaxis]
T_dabatch      = np.tile(T_da, batch_size)[:, np.newaxis]

# Reset flag: 1 on the first time step of every trial (T == -0.5 s), else 0.
T0             = (T == -0.5*s).astype(int)
T0_batch       = np.tile(T0, batch_size)[:, np.newaxis]

In [11]:
# Number of independent repetitions of the whole batch of trials.
nrep            = 100
nbouts          = N_s

# generateinput() orders trials by stimulus (each stimulus N_s times in a row).
# sdshffle re-orders them so that every consecutive run of 27 trials contains
# each stimulus exactly once.
sdshffle = np.arange(27*N_s).reshape((27, N_s)).T.flatten()

# Row boundaries of the nbouts runs of 27 trials.
bouts = np.linspace(0, N_s*len(T)*27, nbouts+1).astype(int)
# Rates of every (trial, time step) row, unit and repetition.
r_vals = np.empty((N_s*len(T)*27, Nrec, nrep))

# Build the initialisation op once.  Calling tf.global_variables_initializer()
# inside the loop would add a new op to the default graph on every repetition.
init_op = tf.global_variables_initializer()

for rep in range(nrep):
    with tf.Session() as sess:
        # Every repetition starts from the same initial variable values.
        sess.run(init_op)

        # Draw a fresh reward sequence; the input activity itself is the same
        # on every repetition.
        prob_index     = prob_mdprl
        DA_s, ch_s, pop_s, pop_o    = generateinput(N_s, prob_index, T_s, T_da, T_ch)
        pop_s          = pop_s[:,:,idxNin]

        # Concatenate the trials along time, in the sdshffle order.
        pop_oRNN       = np.concatenate([pop_o[:,t,:] for t in sdshffle], axis=0)
        pop_sRNN       = np.concatenate([pop_s[:,t,:] for t in sdshffle], axis=0)
        DA_sRNN        = np.concatenate([DA_s[:,t] for t in sdshffle], axis=0).reshape((-1,1))
        ch_sRNN        = np.concatenate([ch_s[:,t] for t in sdshffle], axis=0).reshape((-1,1))

        # NOTE(review): only r_val is used below; W_in and W_rec are large
        # per-step weight trajectories.  They are still fetched so that the
        # names stay available to any later cell -- drop them if unused.
        # NOTE(review): noise_rec is the same array on every repetition, so
        # repetitions differ only through the reward draw -- confirm intended.
        z_val, r_val, x_val, W_in_val, W_rec_val, W0_out_val, C_da_val = sess.run([
                z, r, x_RNN, W_in, W_rec, W0_out, C_da],
                                            feed_dict={ u_RNN: pop_sRNN, nrec_RNN: noise_rec,
                                                        z_RNN: ch_sRNN,    T0_RNN: T0_batch,
                                                     T_schRNN: T_schbatch,T_daRNN: DA_sRNN})

        r_vals[:,:,rep] = r_val.reshape((-1,Nrec))

In [12]:
# Full time course: every (trial, time step) row, unit and repetition.
sc.io.savemat(file_name='./files/r_vals_PCAtime.mat', mdict={'r_vals': r_vals})

# Reduced copy: only the rows that fall inside the choice window.
choice_rows = np.where(T_chbatch)[0]
sc.io.savemat(file_name='./files/r_vals_PCA.mat', mdict={'r_vals': r_vals[choice_rows]})