In [None]:
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pandas as pd
import pingouin as pg
from sklearn import decomposition, manifold
from matplotlib import cm
import scipy.io
from tensorflow.keras.layers import TimeDistributed,Dense,LSTM, SimpleRNN
from tensorflow.keras.models import Sequential
import os
from datetime import datetime
import pytz

from tensorflow.python.ops.rnn_cell_impl import RNNCell
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from scipy.stats import ortho_group

In [None]:
# !conda install matplotlib -y
# !conda install pandas -y
# !conda install scikit-learn -y
# !conda install pingouin -y

In [None]:
np.__version__

In [None]:
tf.__version__

In [None]:
def add_x_loc(x_loc, n_eachring = 32):
    """Ring-unit activity evoked by a stimulus at angle ``x_loc``.

    Args:
        x_loc: stimulus location in radians.
        n_eachring: number of units on the ring; their preferred locations
            are evenly spaced over [0, 2*pi).

    Returns:
        Array of ``n_eachring`` activities: a Gaussian bump of height 0.8 and
        width pi/8 over the circular distance to each preferred location.
    """
    full_circle = 2*np.pi
    preferred = np.arange(0, full_circle, full_circle/n_eachring)
    # Circular (periodic-boundary) distance, expressed in units of pi/8
    scaled = get_dist(x_loc - preferred) / (np.pi/8)
    return 0.8*np.exp(-scaled**2/2)

def add_x_noise(x, seed=0,sigma_x = 0.01):
    """Add seeded zero-mean Gaussian noise to ``x`` in place and return it.

    Args:
        x: numpy float array of activities; it is modified in place.
        seed: seed for the private ``RandomState`` that draws the noise, so
            the same seed always yields the same noise pattern.
        sigma_x: standard deviation of the noise.

    Returns:
        The same array object ``x`` with the noise added.
    """
    rng = np.random.RandomState(seed)
    # Size the noise from the argument itself rather than from any
    # notebook-level variable, so the function works for every caller.
    x += rng.randn(*np.shape(x))*sigma_x
    return x

def get_dist(original_dist):
    '''Get the distance in periodic boundary conditions.

    Args:
        original_dist: signed angular difference(s) in radians (scalar or
            array); any magnitude is accepted.

    Returns:
        The shortest distance around the circle, always in [0, pi].
    '''
    # Wrap onto [0, 2*pi) first so differences larger than one full turn do
    # not come out negative; for |d| <= 2*pi this changes nothing.
    wrapped = np.mod(np.abs(original_dist), 2*np.pi)
    return np.minimum(wrapped, 2*np.pi - wrapped)

# Full grid of (first stimulus, second stimulus) location pairs.
n_loc = 128
stim_loc_shape = (n_loc, n_loc, 1)
n_stim_loc1, n_stim_loc2, repeat = stim_loc_shape
stim_loc_size = np.prod(stim_loc_shape)

# Index of each stimulus (and repeat) for every flattened trial number
ind_stim_loc1, ind_stim_loc2, ind_repeat = np.unravel_index(
    np.arange(stim_loc_size), stim_loc_shape)

# Convert location indices to angles in [0, 2*pi)
stim1_locs = ind_stim_loc1*(2*np.pi)/n_stim_loc1
stim2_locs = ind_stim_loc2*(2*np.pi)/n_stim_loc2

In [None]:
def _encode_pairs(locs1, locs2):
    """Encode paired stimulus locations as noise-free ring activity.

    Row k is the ring response to ``locs1[k]`` followed by the ring response
    to ``locs2[k]`` (both produced by ``add_x_loc``).
    """
    return np.array([np.append(add_x_loc(loc1), add_x_loc(loc2))
                     for loc1, loc2 in zip(locs1, locs2)])


# Noisy set: 10 noisy copies (seeds 0-9) of every location pair, in grid order.
# NOTE(review): both stimuli of a trial are noised with the same seed, so they
# receive the identical noise vector -- confirm this is intended.
X_n = []
for i in range(stim_loc_size):
    for seed in range(10):
        x1 = add_x_loc(stim1_locs[i])
        x2 = add_x_loc(stim2_locs[i])
        x1 = add_x_noise(x1, seed=seed)
        x2 = add_x_noise(x2, seed=seed)
        X_n.append(np.append(x1,x2))
X_n = np.array(X_n)

# Noise-free set in grid order (rows line up with ind_stim_loc1 / ind_stim_loc2)
X = _encode_pairs(stim1_locs, stim2_locs)

# Training / validation sets: independently shuffled location lists, which pairs
# the two stimuli at random. Shuffled *copies* are used so that stim1_locs and
# stim2_locs keep their grid order and this cell can be re-run safely.
seed = 0
rng1 = np.random.RandomState(seed)
train_locs1 = rng1.permutation(stim1_locs)
train_locs2 = rng1.permutation(stim2_locs)
X_train = _encode_pairs(train_locs1, train_locs2)

test_locs1 = rng1.permutation(train_locs1)
test_locs2 = rng1.permutation(train_locs2)
X_test = _encode_pairs(test_locs1, test_locs2)

In [None]:
def _make_trial_sequences(X, n_ring=32):
    """Turn flat stimulus pairs into 6-step input and target sequences.

    Args:
        X: array of shape (trials, 2*n_ring); the first ``n_ring`` columns are
            stimulus 1, the remaining columns stimulus 2.
        n_ring: number of ring units per stimulus.

    Returns:
        (inputs, targets), each of shape (trials, 6, n_ring). Inputs show
        stimulus 1 at step 0 and stimulus 2 at step 2, zeros elsewhere;
        targets are zero until steps 4 and 5, which hold stimulus 1 and
        stimulus 2 respectively.
    """
    stim1, stim2 = X[:, :n_ring], X[:, n_ring:]
    blank = np.zeros(stim1.shape)
    seq_in = np.stack([stim1, blank, stim2, blank, blank, blank], axis=1)
    seq_out = np.stack([blank, blank, blank, blank, stim1, stim2], axis=1)
    return seq_in, seq_out


inputs_n, outputs_n = _make_trial_sequences(X_n)
inputs, outputs = _make_trial_sequences(X)
inputs_train, outputs_train = _make_trial_sequences(X_train)
inputs_test, outputs_test = _make_trial_sequences(X_test)

# Scratch names the cell has always left behind (slices of X_test); kept so
# that any other cell relying on them still finds them.
X1, X2 = X_test[:,:32], X_test[:,32:]
Zeros = np.zeros(X1.shape)

In [None]:
# One RGB colour per stimulus location. Each colormap is resampled to n_loc+15
# levels and only the first n_loc are used.
# plt.get_cmap is used instead of cm.get_cmap, which is deprecated and has been
# removed from recent Matplotlib releases.
autumn_cmap = plt.get_cmap('autumn', n_loc+15)
palette1 = [autumn_cmap(i)[:3] for i in range(n_loc)]
color1 = np.array(palette1)[ind_stim_loc1]  # per-trial colour by first stimulus

summer_cmap = plt.get_cmap('summer', n_loc+15)
palette2 = [summer_cmap(i)[:3] for i in range(n_loc)]
color2 = np.array(palette2)[ind_stim_loc2]  # per-trial colour by second stimulus


    
def fit_isomap(data_to_use, n_neighbors = 15, target_dim = 3):
    """Embed ``data_to_use`` with Isomap and return the projected points.

    Args:
        data_to_use: samples to embed, passed straight to ``fit_transform``.
        n_neighbors: neighbourhood size handed to ``manifold.Isomap``.
        target_dim: number of embedding dimensions (``n_components``).
    """
    embedder = manifold.Isomap(n_components=target_dim, n_neighbors=n_neighbors)
    return embedder.fit_transform(data_to_use)

def set_axes_equal(ax):
    '''Give the x, y and z axes of a 3D plot the same scale.

    Each axis keeps its current midpoint but is widened to the largest of the
    three current ranges, so spheres look like spheres and cubes like cubes.
    This works around ax.set_aspect('equal') / ax.axis('equal') not working
    for 3D axes.

    Input
      ax: a matplotlib 3D axis, e.g., as output from plt.gca().
    '''
    limits = (ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d())
    centres = [np.mean(pair) for pair in limits]

    # The bounding box is a "sphere" in the infinity norm: half of the widest
    # range is used as the radius along every axis.
    half_span = 0.5*max(abs(hi - lo) for lo, hi in limits)

    setters = (ax.set_xlim3d, ax.set_ylim3d, ax.set_zlim3d)
    for apply_limits, centre in zip(setters, centres):
        apply_limits([centre - half_span, centre + half_span])

def plot_isomap(data_plot, color, annotate=False):
    """Scatter the first three columns of ``data_plot`` in a new 3D figure.

    Args:
        data_plot: array whose columns 0, 1, 2 are the x, y, z coordinates.
        color: per-point colours passed to ``scatter`` as ``c``.
        annotate: when True, use smaller markers and write each point's row
            index next to it.

    Returns:
        (fig, ax): the new figure and its 3D axes, with the grid hidden and
        the panes unfilled with white edges.
    """
    fig = plt.figure(figsize=(16,16),dpi=200)
    ax = fig.add_subplot(111, projection='3d')

    xs, ys, zs = data_plot[:,0], data_plot[:,1], data_plot[:,2]
    ax.scatter(xs, ys, zs, s=5 if annotate else 20, alpha=1, edgecolor='face', c=color)
    if annotate:
        for label, (x, y, z) in enumerate(zip(xs, ys, zs)):
            ax.text(x, y, z, '%s' % (label), size=5, zorder=1, color='k')

    ax.grid(False)
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.pane.fill = False
        axis.pane.set_edgecolor('w')
    return fig, ax

def plot_single_distractor_or_target(palette, xlim, ylim, zlim, label_plot, proj_plot, annotate=False, filename=''):
    """Plot a subset of projected points with fixed axis limits.

    Args:
        palette: sequence of colours, indexed by ``label_plot``.
        xlim, ylim, zlim: axis limits applied to the 3D axes.
        label_plot: integer label of every point (index into ``palette``).
        proj_plot: projected points; columns 0-2 are plotted.
        annotate: forwarded to ``plot_isomap``.
        filename: path to save the figure to; '' (the default) or None
            disables saving.
    """
    color = np.array(palette)[label_plot]

    fig, ax = plot_isomap(data_plot=proj_plot, color=color, annotate=annotate)
    plt.setp(ax, xlim=xlim, ylim=ylim, zlim=zlim)
    fig.tight_layout()
    # An empty filename means "do not save"; passing '' to savefig would fail.
    if filename:
        fig.savefig(filename)
    plt.show()
    plt.close(fig)


def plot_all_isomap_figures(proj,filename=''):
    """Draw the four standard views of a projection of the full stimulus grid.

    Two plots show every point, coloured by the first stimulus (``color1``)
    and by the second stimulus (``color2``). Two more show only the trials
    whose first (resp. second) stimulus index is 0, coloured by the other
    stimulus and drawn with the axis limits of the full plot.

    Args:
        proj: projected points, one row per trial of the stimulus grid.
        filename: suffix appended to 'target_isomap_', 'distractor_isomap_',
            'single_target_' and 'single_distractor_' to name the saved
            figures. Pass None to skip saving altogether.
    """
    save = filename is not None

    for prefix, color in (('target_isomap_', color1), ('distractor_isomap_', color2)):
        fig, ax = plot_isomap(data_plot=proj, color=color)
        set_axes_equal(ax)
        fig.tight_layout()
        if save:
            fig.savefig(prefix+filename)
        # Record the limits while the figure is still open; the values from
        # the last full plot are reused for the single-stimulus plots below.
        xlim, ylim, zlim = ax.get_xlim(), ax.get_ylim(), ax.get_zlim()
        plt.show()
        plt.close(fig)

    num = 0  # stimulus index that is held fixed in the single-stimulus plots
    views = (
        (ind_stim_loc1, ind_stim_loc2, palette2, 'single_target_'),
        (ind_stim_loc2, ind_stim_loc1, palette1, 'single_distractor_'),
    )
    for fixed_idx, varied_idx, palette, prefix in views:
        indices = fixed_idx == num
        plot_single_distractor_or_target(
            palette=palette, xlim=xlim, ylim=ylim, zlim=zlim,
            label_plot=varied_idx[indices], proj_plot=proj[indices,:],
            filename=prefix+filename if save else None)

In [None]:
def popvec(y):
    """Decode a location from ring activity with a population vector.

    The last axis is treated as the ring of units and is collapsed.

    Args:
        y: population output on a ring network. Numpy array (..., Units)

    Returns:
        Readout locations in [0, 2*pi): Numpy array with the last axis removed
    """
    n_units = y.shape[-1]
    preferred = np.arange(0, 2*np.pi, 2*np.pi/n_units)  # preferred locations
    total = y.sum(axis=-1)
    # Activity-weighted mean of the unit vectors at the preferred locations
    cos_mean = (y*np.cos(preferred)).sum(axis=-1)/total
    sin_mean = (y*np.sin(preferred)).sum(axis=-1)/total
    return np.mod(np.arctan2(sin_mean, cos_mean), 2*np.pi)

def get_model_performance(model):
    """Fraction of correctly reported stimuli on the notebook-level ``inputs``.

    The last two time steps of the prediction are decoded with ``popvec`` and
    compared with the decoded targets in ``outputs``. A report counts as
    correct when its circular distance to the target is below 2*pi/128; the
    two time steps are weighted 0.5 each.

    Args:
        model: object with a ``predict`` method returning (trials, time, units).

    Returns:
        Score between 0 and 1.
    """
    y_hat = model.predict(inputs)
    tolerance = 2*np.pi/128

    hits = []
    for step in (-2, -1):
        decoded = popvec(y_hat[:,step,:])
        target = popvec(outputs[:,step,:])
        gap = target - decoded
        circular_gap = np.minimum(abs(gap), 2*np.pi-abs(gap))
        hits.append(circular_gap < tolerance)

    return np.sum(hits[0]*0.5+hits[1]*0.5)/len(inputs)


def plot_loss_over_epochs(history, foldername=''):
    """Plot training and validation loss per epoch and save the figure.

    Args:
        history: object whose ``history`` dict has 'loss' and 'val_loss' lists.
        foldername: prefix prepended to 'loss_over_epochs.png' for the output
            path (include a trailing slash to save inside a folder).
    """
    plt.figure(figsize=(10,8))
    curves = (('loss', "Training set loss"), ('val_loss', "Validation set loss"))
    for key, label in curves:
        plt.plot(history.history[key], label=label)
    plt.xlabel('epochs')
    plt.ylabel('loss')
    plt.legend()
    plt.savefig('%sloss_over_epochs.png'%foldername)

def get_anova_stats(foldername, rnn_layer):
    """Two-way ANOVA of unit 0's delay activity on the noisy input set.

    ``rnn_layer`` is applied to the notebook-level ``inputs_n``; the activity
    of unit 0 at time steps 1 and 3 is analysed with first and second
    stimulus index as between factors. Each table is printed and written to
    ``foldername + 'delay1_anova.csv'`` / ``foldername + 'delay2_anova.csv'``.

    Args:
        foldername: path prefix for the two CSV files.
        rnn_layer: callable mapping ``inputs_n`` to (trials, time, units).
    """
    hidden = rnn_layer(inputs_n)
    neuron = 0

    # inputs_n holds 10 consecutive noisy copies of every grid trial, so each
    # stimulus index is repeated 10 times to label the rows.
    df = pd.DataFrame({
        'first_stim': np.repeat(ind_stim_loc1, 10),
        'second_stim': np.repeat(ind_stim_loc2, 10),
        'delay1_activity': hidden[:,1,:][:,neuron],
        'delay2_activity': hidden[:,3,:][:,neuron],
    })

    analyses = (('delay1_activity', "delay1_anova.csv"),
                ('delay2_activity', "delay2_anova.csv"))
    for dv, csv_name in analyses:
        aov = pg.anova(dv=dv, between=['first_stim', 'second_stim'], data=df,
                       detailed=True)
        print(aov)
        aov.to_csv(foldername+csv_name)

In [None]:
def _plot_tuning_panels(Z, fixed_idx, varied_idx, palette, legend_title, xlabel, save_path):
    """Draw one figure of tuning curves: activity vs. ``varied_idx``.

    One panel per sampled neuron (at most 10, evenly strided over the columns
    of ``Z``); within a panel, one scatter series per sampled value of
    ``fixed_idx`` (every 10th location), coloured from ``palette``.
    The figure is saved to ``save_path`` unless it is None.
    """
    n_neurons = Z.shape[1]
    n_panels = min(n_neurons, 10)
    stride = 1 if n_neurons < 10 else n_neurons // 10
    n_locs = len(palette)

    # squeeze=False keeps `axes` indexable even when there is a single panel
    fig, axes = plt.subplots(n_panels, 1, figsize=(5, n_panels*3), squeeze=False)
    axes = axes[:, 0]
    for i in range(n_panels):
        neuron = i*stride
        for loc in range(0, 10*(n_locs // 10), 10):
            mask = fixed_idx == loc
            axes[i].scatter(varied_idx[mask], Z[mask, neuron], s=1,
                            color=palette[loc], label=loc)
        axes[i].set_ylabel('%dth neuron'%neuron, fontsize=13)
    axes[0].legend(loc='upper left', bbox_to_anchor= (1.05, 1.05), title=legend_title)
    axes[0].set_xlabel(xlabel, fontsize=13)
    axes[0].xaxis.set_label_position('top')
    if save_path is not None:
        fig.tight_layout()
        fig.savefig(save_path)
    plt.show()
    plt.close(fig)


def plot_tuning_curves(Z,foldername='',filename=''):
    """Plot tuning curves of up to 10 units against each stimulus.

    Two figures are produced: activity as a function of the second stimulus
    (one colour per sampled first stimulus), then activity as a function of
    the first stimulus (one colour per sampled second stimulus).

    Args:
        Z: activity of shape (trials, units), rows in the order of
            ``ind_stim_loc1`` / ``ind_stim_loc2``; anything ``np.asarray``
            accepts (e.g. an eager tensor).
        foldername: path prefix for saved figures.
        filename: file stem; figures are saved as
            ``foldername + filename + '_stim2'`` and ``... + '_stim1'``.
            With the default '' nothing is saved.
    """
    Z = np.asarray(Z)
    save = filename != ''
    _plot_tuning_panels(Z, fixed_idx=ind_stim_loc1, varied_idx=ind_stim_loc2,
                        palette=palette1, legend_title='Stim 1', xlabel='Stim 2',
                        save_path=foldername+filename+'_stim2' if save else None)
    _plot_tuning_panels(Z, fixed_idx=ind_stim_loc2, varied_idx=ind_stim_loc1,
                        palette=palette2, legend_title='Stim 2', xlabel='Stim 1',
                        save_path=foldername+filename+'_stim1' if save else None)

In [None]:
## Create folders to save model and training info
n_hidden = 8

# These names are needed by this cell and by the training / evaluation cells
# that follow, so they are defined here rather than left commented out.
model_folder = 'simple_rnn_%s_n_hidden_%s'%(datetime.now(pytz.timezone('Asia/Singapore')).strftime("%d_%m_%Y_%H_%M_%S"),n_hidden)
checkpoint_path = model_folder+"/cp.ckpt"
saved_model_folder = model_folder+'/saved_model'
loss_curve_folder = anova_folder = tuning_curve_folder = model_folder+'/'

os.makedirs(model_folder, exist_ok=True)

### Create the model
model = Sequential()
simplernn = SimpleRNN(n_hidden, input_shape=(inputs_train.shape[1:]), return_sequences=True)
model.add(simplernn)
model.add(TimeDistributed(Dense(32)))

### Check performance and selectivity before training
print("n_hidden %s untrained performance: %s"%(n_hidden, get_model_performance(model)))

hidden = simplernn(inputs)
delay1_hidden = hidden[:,1,:]
delay2_hidden = hidden[:,3,:]
# plot_tuning_curves takes (Z, foldername, filename): pass the folder first so
# the figures are actually saved.
plot_tuning_curves(delay1_hidden,tuning_curve_folder,'untrained_delay1_hidden')
plot_tuning_curves(delay2_hidden,tuning_curve_folder,'untrained_delay2_hidden')

get_anova_stats(anova_folder+'untrained_', simplernn)

In [None]:
### Compile the model, then train it with checkpointing and early stopping
model.compile(
    loss=tf.keras.losses.mse,
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
)

# Weights are written to `checkpoint_path`; early stopping uses a patience of
# 10 epochs and restores the best weights.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path, save_weights_only=True, verbose=1)
es_callback = tf.keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)

history = model.fit(
    inputs_train, outputs_train,
    validation_data=(inputs_test,outputs_test),
    epochs=500, batch_size=64, verbose=2,
    callbacks=[cp_callback, es_callback],
)

# Keep the loss curve and the trained model next to the checkpoints
plot_loss_over_epochs(history, foldername=loss_curve_folder)
model.save(saved_model_folder)

In [None]:
model = tf.keras.models.load_model(saved_model_folder)

### Check performance and selectivity after training
print("n_hidden %s trained performance: %s"%(n_hidden, get_model_performance(model)))

# Stand-alone recurrent layer carrying the trained weights, used to read out
# the hidden activity directly.
simplernn = SimpleRNN(n_hidden, input_shape=(inputs_train.shape[1:]), return_sequences=True, weights=model.layers[0].get_weights())
hidden = simplernn(inputs)

delay1_hidden = hidden[:,1,:]
delay2_hidden = hidden[:,3,:]
# plot_tuning_curves takes (Z, foldername, filename): pass the output folder
# first so the figures are actually saved (with only one string it is taken as
# the folder and nothing is written).
plot_tuning_curves(delay1_hidden,anova_folder,'trained_delay1_hidden')
plot_tuning_curves(delay2_hidden,anova_folder,'trained_delay2_hidden')

get_anova_stats(anova_folder, simplernn)

In [None]:
class MinimalRNNCell(tf.keras.layers.Layer):
    """Linear recurrent cell without bias or non-linearity.

    The new state is ``inputs @ kernel + previous_state @ recurrent_kernel``
    and is returned both as the step output and as the next state.
    """

    def __init__(self, units, **kwargs):
        # `state_size` tells the wrapping RNN layer how large the state is
        self.units = units
        self.state_size = units
        super().__init__(**kwargs)

    def build(self, input_shape):
        """Create the input and recurrent weight matrices."""
        n_in = input_shape[-1]
        self.kernel = self.add_weight(
            name='kernel',
            shape=(n_in, self.units),
            initializer='uniform')
        self.recurrent_kernel = self.add_weight(
            name='recurrent_kernel',
            shape=(self.units, self.units),
            initializer='uniform')
        self.built = True

    def call(self, inputs, states):
        """Advance one time step; returns (output, [new_state])."""
        dot = tf.keras.backend.dot
        prev_output = states[0]
        output = dot(inputs, self.kernel) + dot(prev_output, self.recurrent_kernel)
        return output, [output]


In [None]:
## Output locations and recurrent layer for the minimal (linear) RNN
n_hidden = 8

# Timestamped folder (Asia/Singapore time) so that every run gets its own directory
model_folder = 'minimal_rnn_%s_n_hidden_%s'%(
    datetime.now(pytz.timezone('Asia/Singapore')).strftime("%d_%m_%Y_%H_%M_%S"),
    n_hidden,
)
loss_curve_folder = model_folder+'/'
anova_folder = model_folder+'/'
saved_model_folder = model_folder+'/saved_model'
checkpoint_path = model_folder+"/cp.ckpt"

os.makedirs(model_folder)

cell = MinimalRNNCell(n_hidden)
rnn_layer = tf.keras.layers.RNN(
    cell,
    return_sequences=True,
    input_shape=(inputs_train.shape[1:]),
)

In [None]:
hidden = rnn_layer(inputs)
delay1_hidden = hidden[:,1,:]
delay2_hidden = hidden[:,3,:]
# plot_tuning_curves takes (Z, foldername, filename): pass the output folder
# first so the figures are actually saved (with only one string it is taken as
# the folder and nothing is written).
plot_tuning_curves(delay1_hidden,anova_folder,'untrained_delay1_hidden')
plot_tuning_curves(delay2_hidden,anova_folder,'untrained_delay2_hidden')

get_anova_stats(anova_folder+'untrained_', rnn_layer)

model = Sequential()
model.add(rnn_layer)
model.add(TimeDistributed(Dense(32)))
model.summary()

### Check performance and selectivity before training
print("n_hidden %s untrained performance: %s"%(n_hidden, get_model_performance(model)))


In [None]:
### Compile the model, then train it with checkpointing and early stopping
model.compile(
    loss=tf.keras.losses.mse,
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
)

# Weights are written to `checkpoint_path`; early stopping uses a patience of
# 10 epochs and restores the best weights.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path, save_weights_only=True, verbose=1)
es_callback = tf.keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)

history = model.fit(
    inputs_train, outputs_train,
    validation_data=(inputs_test,outputs_test),
    epochs=500, batch_size=64, verbose=2,
    callbacks=[cp_callback, es_callback],
)

# Keep the loss curve and the trained model next to the checkpoints
plot_loss_over_epochs(history, foldername=loss_curve_folder)
model.save(saved_model_folder)

In [None]:
### Check performance and selectivity after training
# The in-memory `model` from the training cell is evaluated; it is not reloaded
# from `saved_model_folder` here.
print("n_hidden %s trained performance: %s"%(n_hidden, get_model_performance(model)))

# Stand-alone RNN layer around the same cell, given the trained layer's weights
rnn_layer = tf.keras.layers.RNN(cell,input_shape=(inputs_train.shape[1:]),
                                return_sequences=True, weights=model.layers[0].get_weights())
hidden = rnn_layer(inputs)

delay1_hidden = hidden[:,1,:]
delay2_hidden = hidden[:,3,:]
# plot_tuning_curves takes (Z, foldername, filename): pass the output folder
# first so the figures are actually saved (with only one string it is taken as
# the folder and nothing is written).
plot_tuning_curves(delay1_hidden,anova_folder,'trained_delay1_hidden')
plot_tuning_curves(delay2_hidden,anova_folder,'trained_delay2_hidden')

get_anova_stats(anova_folder, rnn_layer)

In [None]:
## Output locations and recurrent layer for the leaky RNN
# NOTE(review): LeakyRNNCell2 is defined in a later cell of this notebook, so
# that cell has to be executed before this one on a fresh kernel.
n_hidden = 8

# Timestamped folder (Asia/Singapore time) so that every run gets its own directory
model_folder = 'leaky2_rnn_%s_n_hidden_%s'%(
    datetime.now(pytz.timezone('Asia/Singapore')).strftime("%d_%m_%Y_%H_%M_%S"),
    n_hidden,
)
# Loss curves, ANOVA tables and tuning curves all go to the model folder
loss_curve_folder = anova_folder = tuning_curve_folder = model_folder+'/'
saved_model_folder = model_folder+'/saved_model'
checkpoint_path = model_folder+"/cp.ckpt"

os.makedirs(model_folder)

cell = LeakyRNNCell2(n_hidden)
rnn_layer = tf.keras.layers.RNN(
    cell,
    return_sequences=True,
    input_shape=(inputs_train.shape[1:]),
)

In [None]:
### Selectivity of the untrained recurrent layer
hidden = rnn_layer(inputs)
# Time steps 1 and 3 are the zero-input steps after the first and second stimulus
delay1_hidden, delay2_hidden = hidden[:,1,:], hidden[:,3,:]
plot_tuning_curves(delay1_hidden, foldername=tuning_curve_folder, filename='untrained_delay1_hidden')
plot_tuning_curves(delay2_hidden, foldername=tuning_curve_folder, filename='untrained_delay2_hidden')

get_anova_stats(anova_folder+'untrained_', rnn_layer)

### Full model: recurrent layer followed by a per-time-step linear read-out
model = Sequential([rnn_layer, TimeDistributed(Dense(32))])
model.summary()

### Check performance before training
print("n_hidden %s untrained performance: %s"%(n_hidden, get_model_performance(model)))


In [None]:
### Compile the model, then train it with checkpointing and early stopping
model.compile(
    loss=tf.keras.losses.mse,
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
)

# Weights are written to `checkpoint_path`; early stopping uses a patience of
# 10 epochs and restores the best weights.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path, save_weights_only=True, verbose=1)
es_callback = tf.keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)

history = model.fit(
    inputs_train, outputs_train,
    validation_data=(inputs_test,outputs_test),
    epochs=500, batch_size=64, verbose=2,
    callbacks=[cp_callback, es_callback],
)

# Keep the loss curve and the trained model next to the checkpoints
plot_loss_over_epochs(history, foldername=loss_curve_folder)
model.save(saved_model_folder)

In [None]:
### Check performance and selectivity after training
# The in-memory `model` from the training cell is evaluated; it is not reloaded
# from `saved_model_folder` here.
print("n_hidden %s trained performance: %s"%(n_hidden, get_model_performance(model)))

# Stand-alone RNN layer around the same cell, given the trained layer's weights
rnn_layer = tf.keras.layers.RNN(
    cell,
    return_sequences=True,
    input_shape=(inputs_train.shape[1:]),
    weights=model.layers[0].get_weights(),
)
hidden = rnn_layer(inputs)

delay1_hidden, delay2_hidden = hidden[:,1,:], hidden[:,3,:]
plot_tuning_curves(delay1_hidden, foldername=tuning_curve_folder, filename='trained_delay1_hidden')
plot_tuning_curves(delay2_hidden, foldername=tuning_curve_folder, filename='trained_delay2_hidden')

get_anova_stats(anova_folder, rnn_layer)

In [None]:
class LeakyRNNCell2(tf.keras.layers.Layer):
    """Leaky ReLU recurrent cell with separate input and recurrent kernels.

    One step computes::

        h_new = (1 - alpha) * h_prev
                + alpha * relu(x @ kernel + h_prev @ recurrent_kernel + bias + noise)

    where ``noise`` is Gaussian with standard deviation
    ``sqrt(2 / alpha) * sigma_rec``.

    Initialisation: the input kernel is standard normal scaled by
    1/sqrt(n_inputs); the recurrent kernel is 0.5 times a random orthogonal
    matrix; the bias starts at zero.

    Args:
        units: number of recurrent units (also the state size).
        alpha: leak rate of the state update.
        sigma_rec: recurrent noise level; 0 disables the noise.
        seed: seed for the numpy generator used to draw the initial weights;
            None leaves the initialisation unseeded.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, units, alpha=0.2, sigma_rec=0, seed=None, **kwargs):
        self.units = units
        self.state_size = units
        self._activation = tf.nn.relu
        self._w_in_start = 1.0   # scale of the initial input kernel
        self._w_rec_start = 0.5  # scale of the initial recurrent kernel
        self._seed = seed
        self.rng = np.random.RandomState(seed)
        self._alpha = alpha
        self._sigma_rec = sigma_rec
        self._sigma = np.sqrt(2 / alpha) * sigma_rec
        super().__init__(**kwargs)

    def build(self, input_shape):
        """Create kernel, recurrent kernel and bias with their initial values."""
        n_in = input_shape[-1]

        w_in0 = (self.rng.randn(n_in, self.units) /
                 np.sqrt(n_in) * self._w_in_start)
        self.kernel = self.add_weight(shape=(n_in, self.units), dtype=tf.float32,
                                      initializer=tf.constant_initializer(w_in0),
                                      name='kernel')

        w_rec0 = self._w_rec_start*ortho_group.rvs(dim=self.units, random_state=self.rng)
        self.recurrent_kernel = self.add_weight(shape=(self.units, self.units), dtype=tf.float32,
                                                initializer=tf.constant_initializer(w_rec0),
                                                name='recurrent_kernel')

        self._bias = self.add_weight(
                name='bias',
                shape=[self.units],
                dtype=tf.float32,
                initializer=tf.zeros_initializer())

        self.built = True

    def call(self, inputs, states):
        """Advance one time step; returns (output, [new_state])."""
        prev_output = states[0]
        h = tf.keras.backend.dot(inputs, self.kernel)
        h = h + tf.keras.backend.dot(prev_output, self.recurrent_kernel)
        h = tf.nn.bias_add(h, self._bias)
        # With zero noise level the noise term is identically zero, so the
        # random op is skipped.
        if self._sigma > 0:
            h = h + tf.random.normal(tf.shape(prev_output), mean=0, stddev=self._sigma)
        output = self._activation(h)
        # Leaky integration of the previous state and the new activation
        output = (1-self._alpha) * prev_output + self._alpha * output
        return output, [output]

    def get_config(self):
        """Return the constructor arguments so the cell can be serialised."""
        config = super().get_config()
        config.update({
            'units': self.units,
            'alpha': self._alpha,
            'sigma_rec': self._sigma_rec,
            'seed': self._seed,
        })
        return config

In [None]:
class Model(object):
    """The model."""

    def __init__(self,
                 model_dir,
                 hp=None,
                 sigma_rec=None,
                 dt=None):
        """
        Initializing the model with information from hp

        Args:
            model_dir: string, directory of the model
            hp: a dictionary or None. When None it is loaded with
                tools.load_hp(model_dir). The dictionary is updated in place
                ('alpha' is always set; 'sigma_rec' and 'dt' when overridden)
                and stored as self.hp.
            sigma_rec: if not None, overwrite the sigma_rec passed by hp
            dt: if not None, overwrite the dt passed by hp

        Raises:
            ValueError: if no hp is given and none can be loaded, or if
                hp['in_type'] is not 'normal'.
        """

        # Reset tensorflow graphs
        tf.reset_default_graph()  # must be in the beginning

        if hp is None:
            hp = tools.load_hp(model_dir)
            if hp is None:
                raise ValueError(
                    'No hp found for model_dir {:s}'.format(model_dir))

        # Seed both the TensorFlow graph and the numpy generator from hp
        tf.set_random_seed(hp['seed'])
        self.rng = np.random.RandomState(hp['seed'])

        if sigma_rec is not None:
            print('Overwrite sigma_rec with {:0.3f}'.format(sigma_rec))
            hp['sigma_rec'] = sigma_rec

        if dt is not None:
            print('Overwrite original dt with {:0.1f}'.format(dt))
            hp['dt'] = dt

        # Leak rate of the recurrent units: time step over time constant
        hp['alpha'] = 1.0*hp['dt']/hp['tau']

        # Only the 'normal' input type is implemented; report the value that
        # was actually rejected.
        if hp['in_type'] != 'normal':
            raise ValueError("Only support in_type 'normal', but got "
                             + str(hp['in_type']))

        self._build(hp)

        self.model_dir = model_dir
        self.hp = hp

    def _build(self, hp):
        """Build the graph, regularisation cost, optimizer and saver.

        Args:
            hp: hyper-parameter dictionary. This method reads the optional
                keys 'use_separate_input' and 'optimizer' ('adam' when absent)
                and the keys 'l1_h', 'l2_h', 'l1_weight', 'l2_weight' and
                'learning_rate'.

        Raises:
            ValueError: if hp['optimizer'] is neither 'adam' nor 'sgd'.
        """
        separate_input = 'use_separate_input' in hp and hp['use_separate_input']

        if separate_input:
            self._build_seperate(hp)
        else:
            self._build_fused(hp)

        self.var_list = tf.trainable_variables()
        self.weight_list = [v for v in self.var_list if is_weight(v)]

        if separate_input:
            self._set_weights_separate(hp)
        else:
            self._set_weights_fused(hp)

        # Regularization terms on the hidden activity and on the weights
        self.cost_reg = tf.constant(0.)
        if hp['l1_h'] > 0:
            self.cost_reg += tf.reduce_mean(tf.abs(self.h)) * hp['l1_h']
        if hp['l2_h'] > 0:
            self.cost_reg += tf.nn.l2_loss(self.h) * hp['l2_h']

        if hp['l1_weight'] > 0:
            self.cost_reg += hp['l1_weight'] * tf.add_n(
                [tf.reduce_mean(tf.abs(v)) for v in self.weight_list])
        if hp['l2_weight'] > 0:
            self.cost_reg += hp['l2_weight'] * tf.add_n(
                [tf.nn.l2_loss(v) for v in self.weight_list])

        # Create an optimizer.
        optimizer_name = hp['optimizer'] if 'optimizer' in hp else 'adam'
        if optimizer_name == 'adam':
            self.opt = tf.train.AdamOptimizer(
                learning_rate=hp['learning_rate'])
        elif optimizer_name == 'sgd':
            self.opt = tf.train.GradientDescentOptimizer(
                learning_rate=hp['learning_rate'])
        else:
            # Fail here rather than leaving self.opt unset for later code.
            raise ValueError(
                "optimizer must be 'adam' or 'sgd', but got "
                + str(optimizer_name))
        # Set cost
        self.set_optimizer()

        # Variable saver
        self.saver = tf.train.Saver()

    def _build_fused(self, hp):
        """Build the graph in which ``self.x`` is fed directly to the RNN cell.

        Creates the placeholders ``self.x``, ``self.y`` and ``self.c_mask``,
        runs a time-major ``dynamic_rnn`` to obtain ``self.h``, applies a
        linear read-out, and stores the loss in ``self.cost_lsq`` and the
        predictions in ``self.y_hat`` and ``self.y_hat_loc``.

        Args:
            hp: hyper-parameter dictionary. Keys read here: 'n_input',
                'n_rnn', 'n_output', 'loss_type', 'activation', 'rnn_type',
                and, depending on the cell type, 'alpha', 'sigma_rec' and
                'w_rec_init'.

        Raises:
            NotImplementedError: if hp['rnn_type'] matches none of the
                branches below.
        """
        # NOTE(review): `rnn`, `LeakyRNNCell`, `LeakyGRUCell` and `tf_popvec`
        # are not defined in the visible part of this notebook, and
        # tf.placeholder / tf.contrib look like TensorFlow 1.x APIs -- confirm
        # the environment this class is meant to run in.
        n_input = hp['n_input']
        n_rnn = hp['n_rnn']
        n_output = hp['n_output']

        # Inputs and targets; the leading two dimensions are left unspecified
        self.x = tf.placeholder("float", [None, None, n_input])
        self.y = tf.placeholder("float", [None, None, n_output])
        if hp['loss_type'] == 'lsq':
            # Mask with one value per output unit
            self.c_mask = tf.placeholder("float", [None, n_output])
        else:
            # Mask on time
            self.c_mask = tf.placeholder("float", [None])

        # Activation functions
        if hp['activation'] == 'power':
            f_act = lambda x: tf.square(tf.nn.relu(x))
        elif hp['activation'] == 'retanh':
            f_act = lambda x: tf.tanh(tf.nn.relu(x))
        elif hp['activation'] == 'relu+':
            f_act = lambda x: tf.nn.relu(x + tf.constant(1.))
        else:
            # Any other name is looked up on tf.nn
            f_act = getattr(tf.nn, hp['activation'])

        # Recurrent activity
        # Note: the LeakyRNN branch passes the activation *name*
        # hp['activation']; only the other cell types receive f_act.
        if hp['rnn_type'] == 'LeakyRNN':
            n_in_rnn = self.x.get_shape().as_list()[-1]
            cell = LeakyRNNCell(n_rnn, 
                                n_in_rnn,
                                hp['alpha'],
                                sigma_rec=hp['sigma_rec'],
                                activation=hp['activation'],
                                w_rec_init=hp['w_rec_init'],
                                rng=self.rng)
        elif hp['rnn_type'] == 'LeakyGRU':
            cell = LeakyGRUCell(
                n_rnn, hp['alpha'],
                sigma_rec=hp['sigma_rec'], activation=f_act)
        elif hp['rnn_type'] == 'LSTM':
            cell = tf.contrib.rnn.LSTMCell(n_rnn, activation=f_act)

        elif hp['rnn_type'] == 'GRU':
            cell = tf.contrib.rnn.GRUCell(n_rnn, activation=f_act)
        else:
            # NOTE(review): the message lists EILeakyGRU, but no branch above
            # handles that type.
            raise NotImplementedError("""rnn_type must be one of LeakyRNN,
                    LeakyGRU, EILeakyGRU, LSTM, GRU
                    """)

        # Dynamic rnn with time major
        self.h, states = rnn.dynamic_rnn(
            cell, self.x, dtype=tf.float32, time_major=True)

        # Output
        with tf.variable_scope("output"):
            # Using default initialization `glorot_uniform_initializer`
            w_out = tf.get_variable(
                'weights',
                [n_rnn, n_output],
                dtype=tf.float32
            )
            b_out = tf.get_variable(
                'biases',
                [n_output],
                dtype=tf.float32,
                initializer=tf.constant_initializer(0.0, dtype=tf.float32)
            )

        # Flatten the leading dimensions so the read-out is one matmul
        h_shaped = tf.reshape(self.h, (-1, n_rnn))
        y_shaped = tf.reshape(self.y, (-1, n_output))
        # y_hat_ shape (n_time*n_batch, n_unit)
        y_hat_ = tf.matmul(h_shaped, w_out) + b_out
        if hp['loss_type'] == 'lsq':
            # Least-square loss
            y_hat = tf.sigmoid(y_hat_)
            self.cost_lsq = tf.reduce_mean(
                tf.square((y_shaped - y_hat) * self.c_mask))
        else:
            y_hat = tf.nn.softmax(y_hat_)
            # Cross-entropy loss
            self.cost_lsq = tf.reduce_mean(
                self.c_mask * tf.nn.softmax_cross_entropy_with_logits(
                    labels=y_shaped, logits=y_hat_))

        # Restore the leading dimensions of the prediction
        self.y_hat = tf.reshape(y_hat,
                                (-1, tf.shape(self.h)[1], n_output))
        # The first output unit is split off; the remaining n_output - 1 units
        # are passed to tf_popvec.
        y_hat_fix, y_hat_ring = tf.split(
            self.y_hat, [1, n_output - 1], axis=-1)
        self.y_hat_loc = tf_popvec(y_hat_ring)

    def _set_weights_fused(self, hp):
        """Set model attributes for several weight variables."""
        n_input = hp['n_input']
        n_rnn = hp['n_rnn']
        n_output = hp['n_output']

        for v in self.var_list:
            if 'rnn' in v.name:
                if 'kernel' in v.name or 'weight' in v.name:
                    # TODO(gryang): For GRU, fix
                    self.w_rec = v[n_input:, :]
                    self.w_in = v[:n_input, :]
                else:
                    self.b_rec = v
            else:
                assert 'output' in v.name
                if 'kernel' in v.name or 'weight' in v.name:
                    self.w_out = v
                else:
                    self.b_out = v

        # check if the recurrent and output connection has the correct shape
        if self.w_out.shape != (n_rnn, n_output):
            raise ValueError('Shape of w_out should be ' +
                             str((n_rnn, n_output)) + ', but found ' +
                             str(self.w_out.shape))
        if self.w_rec.shape != (n_rnn, n_rnn):
            raise ValueError('Shape of w_rec should be ' +
                             str((n_rnn, n_rnn)) + ', but found ' +
                             str(self.w_rec.shape))
        if self.w_in.shape != (n_input, n_rnn):
            raise ValueError('Shape of w_in should be ' +
                             str((n_input, n_rnn)) + ', but found ' +
                             str(self.w_in.shape))

    def _build_seperate(self, hp):
        """Build the graph with separate sensory and rule input projections.

        ``self.x`` is split into sensory and rule parts; each part goes
        through its own dense layer, the two projections are summed and fed
        to a ``LeakyRNNCellSeparateInput`` cell. A sigmoid read-out and a
        masked least-square loss follow. Sets ``self.x``, ``self.y``,
        ``self.c_mask``, ``self.h``, ``self.cost_lsq``, ``self.y_hat`` and
        ``self.y_hat_loc``.

        Args:
            hp: hyper-parameter dictionary. Keys read here: 'n_input',
                'n_rnn', 'n_output', 'rule_start', 'n_rule', 'alpha',
                'sigma_rec', 'activation', 'w_rec_init' and the optional
                'mix_rule'.
        """
        # NOTE(review): `rnn`, `LeakyRNNCellSeparateInput` and `tf_popvec` are
        # not defined in the visible part of this notebook, and
        # tf.placeholder / tf.layers look like TensorFlow 1.x APIs -- confirm
        # the environment this class is meant to run in.
        # Input, target output, and cost mask
        # Shape: [Time, Batch, Num_units]
        n_input = hp['n_input']
        n_rnn = hp['n_rnn']
        n_output = hp['n_output']

        self.x = tf.placeholder("float", [None, None, n_input])
        self.y = tf.placeholder("float", [None, None, n_output])
        self.c_mask = tf.placeholder("float", [None, n_output])

        # The first hp['rule_start'] input units are sensory, the following
        # hp['n_rule'] units are rule inputs
        sensory_inputs, rule_inputs = tf.split(
            self.x, [hp['rule_start'], hp['n_rule']], axis=-1)

        sensory_rnn_inputs = tf.layers.dense(sensory_inputs, n_rnn, name='sen_input')

        if 'mix_rule' in hp and hp['mix_rule'] is True:
            # rotate rule matrix
            # (fixed, non-trainable orthogonal mixing without bias)
            kernel_initializer = tf.orthogonal_initializer()
            rule_inputs = tf.layers.dense(
                rule_inputs, hp['n_rule'], name='mix_rule',
                use_bias=False, trainable=False,
                kernel_initializer=kernel_initializer)

        rule_rnn_inputs = tf.layers.dense(rule_inputs, n_rnn, name='rule_input', use_bias=False)

        # Both projections have n_rnn units and are summed before the RNN
        rnn_inputs = sensory_rnn_inputs + rule_rnn_inputs

        # Recurrent activity
        cell = LeakyRNNCellSeparateInput(
            n_rnn, hp['alpha'],
            sigma_rec=hp['sigma_rec'],
            activation=hp['activation'],
            w_rec_init=hp['w_rec_init'],
            rng=self.rng)

        # Dynamic rnn with time major
        self.h, states = rnn.dynamic_rnn(
            cell, rnn_inputs, dtype=tf.float32, time_major=True)

        # Output
        # Flatten the leading dimensions so the read-out is one dense layer
        h_shaped = tf.reshape(self.h, (-1, n_rnn))
        y_shaped = tf.reshape(self.y, (-1, n_output))
        # y_hat shape (n_time*n_batch, n_unit)
        y_hat = tf.layers.dense(
            h_shaped, n_output, activation=tf.nn.sigmoid, name='output')
        # Least-square loss
        self.cost_lsq = tf.reduce_mean(
            tf.square((y_shaped - y_hat) * self.c_mask))

        # Restore the leading dimensions of the prediction
        self.y_hat = tf.reshape(y_hat,
                                (-1, tf.shape(self.h)[1], n_output))
        # The first output unit is split off; the remaining n_output - 1 units
        # are passed to tf_popvec.
        y_hat_fix, y_hat_ring = tf.split(
            self.y_hat, [1, n_output - 1], axis=-1)
        self.y_hat_loc = tf_popvec(y_hat_ring)

    def _set_weights_separate(self, hp):
        """Set model attributes for several weight variables."""
        n_input = hp['n_input']
        n_rnn = hp['n_rnn']
        n_output = hp['n_output']

        for v in self.var_list:
            if 'rnn' in v.name:
                if 'kernel' in v.name or 'weight' in v.name:
                    self.w_rec = v
                else:
                    self.b_rec = v
            elif 'sen_input' in v.name:
                if 'kernel' in v.name or 'weight' in v.name:
                    self.w_sen_in = v
                else:
                    self.b_in = v
            elif 'rule_input' in v.name:
                self.w_rule = v
            else:
                assert 'output' in v.name
                if 'kernel' in v.name or 'weight' in v.name:
                    self.w_out = v
                else:
                    self.b_out = v

        # check if the recurrent and output connection has the correct shape
        if self.w_out.shape != (n_rnn, n_output):
            raise ValueError('Shape of w_out should be ' +
                             str((n_rnn, n_output)) + ', but found ' +
                             str(self.w_out.shape))
        if self.w_rec.shape != (n_rnn, n_rnn):
            raise ValueError('Shape of w_rec should be ' +
                             str((n_rnn, n_rnn)) + ', but found ' +
                             str(self.w_rec.shape))
        if self.w_sen_in.shape != (hp['rule_start'], n_rnn):
            raise ValueError('Shape of w_sen_in should be ' +
                             str((hp['rule_start'], n_rnn)) + ', but found ' +
                             str(self.w_sen_in.shape))
        if self.w_rule.shape != (hp['n_rule'], n_rnn):
            raise ValueError('Shape of w_in should be ' +
                             str((hp['n_rule'], n_rnn)) + ', but found ' +
                             str(self.w_rule.shape))

    def initialize(self):
        """Run TensorFlow's global variable initializer in the default session."""
        session = tf.get_default_session()
        init_op = tf.global_variables_initializer()
        session.run(init_op)

    def restore(self, load_dir=None):
        """Restore the model's variables from ``<load_dir>/model.ckpt``.

        The restore is first attempted with ``self.saver``. If that fails,
        ``self.saver`` is replaced by a saver over ``self.var_list`` only and
        the restore is retried once; a failure of the retry propagates.

        Args:
            load_dir: directory holding the checkpoint. Defaults to
                ``self.model_dir`` when ``None``.
        """
        sess = tf.get_default_session()
        if load_dir is None:
            load_dir = self.model_dir
        save_path = os.path.join(load_dir, 'model.ckpt')
        try:
            self.saver.restore(sess, save_path)
        except Exception:
            # Some earlier checkpoints only stored trainable variables.
            # Only ordinary errors trigger this fallback: KeyboardInterrupt and
            # SystemExit are deliberately not caught, so interrupting a slow
            # restore does not silently swap the saver and start a second one.
            self.saver = tf.train.Saver(self.var_list)
            self.saver.restore(sess, save_path)
        print("Model restored from file: %s" % save_path)

    def save(self):
        """Write the model's variables to ``<model_dir>/model.ckpt`` via ``self.saver``."""
        session = tf.get_default_session()
        ckpt_path = os.path.join(self.model_dir, 'model.ckpt')
        self.saver.save(session, ckpt_path)
        print("Model saved in file: %s" % ckpt_path)

    def set_optimizer(self, extra_cost=None, var_list=None):
        """Recompute the optimizer to reflect the latest cost function.

        This is useful when the cost function is modified throughout training.
        The total cost is ``self.cost_lsq + self.cost_reg`` (plus
        ``extra_cost`` when given). Gradients are computed with ``self.opt``,
        clipped element-wise to [-1, 1], and the resulting update op is stored
        in ``self.train_step``. The unclipped pairs are kept in
        ``self.grads_and_vars``.

        Args:
            extra_cost: optional tensor added to the lsq and regularization
                cost.
            var_list: variables to optimize. Defaults to ``self.var_list``
                when ``None``.
        """
        cost = self.cost_lsq + self.cost_reg
        if extra_cost is not None:
            cost += extra_cost

        if var_list is None:
            var_list = self.var_list

        print('Variables being optimized:')
        for v in var_list:
            print(v)

        self.grads_and_vars = self.opt.compute_gradients(cost, var_list)
        # gradient clipping. compute_gradients reports a variable the cost does
        # not depend on as (None, var); clip_by_value cannot take None, so such
        # pairs are left out of the update instead of crashing here.
        capped_gvs = [(tf.clip_by_value(grad, -1., 1.), var)
                      for grad, var in self.grads_and_vars
                      if grad is not None]
        self.train_step = self.opt.apply_gradients(capped_gvs)

    def lesion_units(self, sess, units, verbose=False):
        """Zero out the weight rows belonging to the given units.

        For every variable in ``self.var_list`` whose name contains 'kernel'
        or 'weight', the current value is fetched, rows are zeroed, and the
        value is assigned back through ``sess``:

        * names containing 'output': rows ``units``
        * otherwise, names containing 'rnn': rows ``n_input + units``
        * any other weight variable is written back unchanged

        Args:
            sess: tensorflow session
            units : can be None (nothing is done), an integer index, or a list
                of integer indices
            verbose: when true, print the lesioned unit indices afterwards
        """
        if units is None:
            return

        # Normalise to an index array; a bare scalar becomes a length-1 array
        units = np.array(units if hasattr(units, '__iter__') else [units])

        # This lesioning will work for both RNN and GRU
        n_input = self.hp['n_input']
        weight_vars = (var for var in self.var_list
                       if 'kernel' in var.name or 'weight' in var.name)
        for var in weight_vars:
            values = sess.run(var)
            if 'output' in var.name:
                # output weights
                values[units, :] = 0
            elif 'rnn' in var.name:
                # recurrent weights: rows are offset by the number of inputs
                values[n_input + units, :] = 0
            sess.run(var.assign(values))

        if verbose:
            print('Lesioned units:')
            print(units)

In [None]:
# Build the task model and run the training loop inside a fresh TF session.
# NOTE(review): this cell relies on names that are not defined in it (`Model`,
# `model_dir`, `hp`, `log`, `t_start`, `max_steps`, `display_step`,
# `rich_output`, `do_eval`, `generate_trials`, `display_rich_output`, `tools`).
# `time` is used below as well but is not among the imports in the notebook's
# first cell - confirm it is imported before this cell runs.
model = Model(model_dir, hp=hp)

with tf.Session() as sess:
#     if load_model:
#         model.restore(model_dir)  # complete restore
#     else:
#         # Assume everything is restored
    sess.run(tf.global_variables_initializer())

    # Set trainable parameters
    # NOTE(review): var_list is assigned but not used further in this cell,
    # because every model.set_optimizer(...) call below is commented out.
#     if trainables is None or trainables == 'all':
    var_list = model.var_list  # train everything
#     elif trainables == 'input':
#         # train all nputs
#         var_list = [v for v in model.var_list
#                     if ('input' in v.name) and ('rnn' not in v.name)]
#     elif trainables == 'rule':
#         # train rule inputs only
#         var_list = [v for v in model.var_list if 'rule_input' in v.name]
#     else:
#         raise ValueError('Unknown trainables')
#     model.set_optimizer(var_list=var_list)

#     # penalty on deviation from initial weight
#     if hp['l2_weight_init'] > 0:
#         anchor_ws = sess.run(model.weight_list)
#         for w, w_val in zip(model.weight_list, anchor_ws):
#             model.cost_reg += (hp['l2_weight_init'] *
#                                tf.nn.l2_loss(w - w_val))

#         model.set_optimizer(var_list=var_list)

#     # partial weight training
#     if ('p_weight_train' in hp and
#         (hp['p_weight_train'] is not None) and
#         hp['p_weight_train'] < 1.0):
#         for w in model.weight_list:
#             w_val = sess.run(w)
#             w_size = sess.run(tf.size(w))
#             w_mask_tmp = np.linspace(0, 1, w_size)
#             hp['rng'].shuffle(w_mask_tmp)
#             ind_fix = w_mask_tmp > hp['p_weight_train']
#             w_mask = np.zeros(w_size, dtype=np.float32)
#             w_mask[ind_fix] = 1e-1  # will be squared in l2_loss
#             w_mask = tf.constant(w_mask)
#             w_mask = tf.reshape(w_mask, w.shape)
#             model.cost_reg += tf.nn.l2_loss((w - w_val) * w_mask)
#         model.set_optimizer(var_list=var_list)

    # The loop budget is counted in trials (step * batch size), not in steps.
    step = 0
    while step * hp['batch_size_train'] < max_steps:
        try:
            # Validation: every display_step steps, log progress and evaluate
            if step % display_step == 0:
                log['trials'].append(step * hp['batch_size_train'])
                log['times'].append(time.time()-t_start)
                log = do_eval(sess, model, log, hp['rule_trains'],n_loc=n_loc,step=step)
                #if log['perf_avg'][-1] > model.hp['target_perf']:
                #check if minimum performance is above target    
                # if log['perf_min'][-1] >= model.hp['target_perf']:
                # Stop once every one of the last (at most five) logged
                # perf_min values has reached the target. With fewer than five
                # evaluations logged, the slice simply covers all of them.
                if all(elem >= model.hp['target_perf'] for elem in log['perf_min'][-5:]):
                    print('Perf reached the target: {:0.2f}'.format(
                        hp['target_perf']))
                    break

                if rich_output:
                    display_rich_output(model, sess, step, log, model_dir)

            # Training: draw the rule for this batch according to hp['rule_probs']
            rule_train_now = hp['rng'].choice(hp['rule_trains'],
                                              p=hp['rule_probs'])
            # Generate a random batch of trials.
            # Each batch has the same trial length
            trial = generate_trials(
                    rule_train_now, hp, 'random',
                    batch_size=hp['batch_size_train'],step=step,log=log,display_step=display_step,n_loc=n_loc)

            # Generating feed_dict.
            feed_dict = tools.gen_feed_dict(model, trial, hp)
            sess.run(model.train_step, feed_dict=feed_dict)

            step += 1

        except KeyboardInterrupt:
            # Interrupting the kernel ends training cleanly instead of raising
            print("Optimization interrupted by user")
            break

    print("Optimization finished!")

In [None]:
# Train one leaky RNN (LeakyRNNCell2) per hidden-layer size, recording tuning
# curves and ANOVA statistics of the hidden units before and after training.
# NOTE(review): `inputs`, `inputs_train`, `outputs_train`, `inputs_test`,
# `outputs_test`, `LeakyRNNCell2`, `plot_tuning_curves`, `get_anova_stats`,
# `get_model_performance` and `plot_loss_over_epochs` are not defined in this
# cell; they must already exist when it runs.
for n_hidden in [8, 16, 32, 64, 128, 256]:
    ## Create folders to save model and training info
    # Folder name carries a Singapore-time timestamp plus the hidden size
    model_folder = 'leaky2_rnn_%s_n_hidden_%s'%(datetime.now(pytz.timezone('Asia/Singapore')).strftime("%d_%m_%Y_%H_%M_%S"),n_hidden)
    checkpoint_path = model_folder+"/cp.ckpt"
    saved_model_folder = model_folder+'/saved_model'
    # Loss curves, ANOVA output and tuning curves all go to the same folder
    loss_curve_folder = anova_folder = tuning_curve_folder = model_folder+'/'

    os.makedirs(model_folder)

    cell = LeakyRNNCell2(n_hidden)
    rnn_layer = tf.keras.layers.RNN(cell,input_shape=(inputs_train.shape[1:]),
                                    return_sequences=True)


    # Hidden activity of the untrained layer, taken at time steps 1 and 3
    # (named delay1/delay2 - presumably the two delay periods; confirm against
    # how `inputs` is laid out in time).
    hidden = rnn_layer(inputs)
    delay1_hidden = hidden[:,1,:]
    delay2_hidden = hidden[:,3,:]
    plot_tuning_curves(delay1_hidden,tuning_curve_folder,'untrained_delay1_hidden')
    plot_tuning_curves(delay2_hidden,tuning_curve_folder,'untrained_delay2_hidden')

    get_anova_stats(anova_folder+'untrained_', rnn_layer)

    # Full model: the recurrent layer followed by a 32-unit read-out applied
    # independently at every time step
    model = Sequential()
    model.add(rnn_layer)
    model.add(TimeDistributed(Dense(32)))
    model.summary()

    ### Check performance and selectivity before training
    print("n_hidden %s untrained performance: %s"%(n_hidden, get_model_performance(model)))




    ### Callbacks: weight checkpointing and early stopping (patience 10, best weights restored)
    cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                    save_weights_only=True,
                                                    verbose=1)
    es_callback = tf.keras.callbacks.EarlyStopping(restore_best_weights=True, patience=10)
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),loss=tf.keras.losses.mse)

    ### Training the model
    history = model.fit(inputs_train, outputs_train, validation_data=(inputs_test,outputs_test), batch_size=64, epochs=500, verbose=2,
              callbacks=[cp_callback, es_callback])
    plot_loss_over_epochs(history, foldername=loss_curve_folder)
    model.save(saved_model_folder)



    # model = tf.keras.models.load_model(saved_model_folder)

    ### Check performance and selectivity after training
    print("n_hidden %s trained performance: %s"%(n_hidden, get_model_performance(model)))

    # simplernn = SimpleRNN(n_hidden, input_shape=(inputs_train.shape[1:]), return_sequences=True, weights=model.layers[0].get_weights())
    # A second RNN layer is wrapped around the same `cell` object and given the
    # trained layer's weights, so the trained hidden activity can be read out.
    rnn_layer = tf.keras.layers.RNN(cell,input_shape=(inputs_train.shape[1:]),
                                    return_sequences=True, weights=model.layers[0].get_weights())
    hidden = rnn_layer(inputs)

    delay1_hidden = hidden[:,1,:]
    delay2_hidden = hidden[:,3,:]
    plot_tuning_curves(delay1_hidden,tuning_curve_folder,'trained_delay1_hidden')
    plot_tuning_curves(delay2_hidden,tuning_curve_folder,'trained_delay2_hidden')

    # Trained-model ANOVA output uses the bare folder prefix (no 'trained_' tag)
    get_anova_stats(anova_folder, rnn_layer)

In [None]:
class LeakyRNNCell_softplus(tf.keras.layers.Layer):
    """Leaky RNN cell with a softplus non-linearity, for use in ``tf.keras.layers.RNN``.

    One step computes::

        pre   = inputs @ kernel + prev_state @ recurrent_kernel + bias + noise
        state = (1 - alpha) * prev_state + alpha * softplus(pre)

    where ``noise`` is zero-mean Gaussian with standard deviation
    ``sqrt(2 / alpha) * sigma_rec``.

    Args:
        units: number of recurrent units; also used as ``state_size``.
        alpha: weight of the new activation when it is mixed with the previous
            state. Defaults to 0.2, the value that used to be hard-coded.
        sigma_rec: recurrent noise level. Defaults to 0 (no noise), the value
            that used to be hard-coded.
        seed: optional seed for the NumPy ``RandomState`` that draws the
            initial weights. ``None`` (default) keeps the previous unseeded
            behaviour.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, units, alpha=0.2, sigma_rec=0, seed=None, **kwargs):
        self.units = units
        self.state_size = units
        self._activation = tf.nn.softplus
        # Scale factors applied to the initial input and recurrent weights
        self._w_in_start = 1.0
        self._w_rec_start = 0.5
        self._init_seed = seed
        self.rng = np.random.RandomState(seed)
        self._alpha = alpha
        self._sigma_rec = sigma_rec
        self._sigma = np.sqrt(2 / alpha) * sigma_rec
        super(LeakyRNNCell_softplus, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the input kernel, recurrent kernel and bias.

        The input kernel starts as Gaussian entries divided by the square root
        of the input size; the recurrent kernel starts as ``_w_rec_start``
        times a random orthogonal matrix; the bias starts at zero. Both random
        draws come from ``self.rng``, input kernel first.
        """
        n_in = input_shape[-1]
        w_in0 = (self.rng.randn(n_in, self.units) /
                 np.sqrt(n_in) * self._w_in_start)
        self.kernel = self.add_weight(shape=(n_in, self.units), dtype=tf.float32,
                                      initializer=tf.constant_initializer(w_in0),
                                      name='kernel')
        w_rec0 = self._w_rec_start*ortho_group.rvs(dim=self.units, random_state=self.rng)
        self.recurrent_kernel = self.add_weight(shape=(self.units, self.units), dtype=tf.float32,
                                                initializer=tf.constant_initializer(w_rec0),
                                                name='recurrent_kernel')

        self._bias = self.add_weight(
                name='bias',
                shape=[self.units],
                dtype=tf.float32,
                initializer=tf.zeros_initializer())

        self.built = True

    def call(self, inputs, states):
        """Advance the cell by one step.

        Args:
            inputs: input for the current step.
            states: sequence whose first element is the previous state.

        Returns:
            ``(output, [output])``: the new state, both as the step output and
            as the state list for the next step.
        """
        prev_output = states[0]
        h = tf.keras.backend.dot(inputs, self.kernel)
        h = h + tf.keras.backend.dot(prev_output, self.recurrent_kernel)
        h = tf.nn.bias_add(h, self._bias)
        # Noise is added to the pre-activation; with sigma_rec == 0 it is all zeros
        noise = tf.random.normal(tf.shape(prev_output), mean=0, stddev=self._sigma)
        h = h + noise
        output = self._activation(h)
        # Leaky update: blend the previous state with the new activation
        output = (1-self._alpha) * prev_output + self._alpha * output
        return output, [output]

    def get_config(self):
        """Return the constructor arguments so Keras can serialise the cell."""
        config = super(LeakyRNNCell_softplus, self).get_config()
        config.update({
            'units': self.units,
            'alpha': self._alpha,
            'sigma_rec': self._sigma_rec,
            'seed': self._init_seed,
        })
        return config

In [None]:
# Same sweep as the LeakyRNNCell2 cell, but with LeakyRNNCell_softplus: train
# one RNN per hidden-layer size and record tuning curves and ANOVA statistics
# of the hidden units before and after training.
# NOTE(review): `inputs`, `inputs_train`, `outputs_train`, `inputs_test`,
# `outputs_test`, `plot_tuning_curves`, `get_anova_stats`,
# `get_model_performance` and `plot_loss_over_epochs` are not defined in this
# cell; they must already exist when it runs.
for n_hidden in [8, 16, 32, 64, 128, 256]:
    ## Create folders to save model and training info
    # Folder name carries a Singapore-time timestamp plus the hidden size
    model_folder = 'leaky_softplus_rnn_%s_n_hidden_%s'%(datetime.now(pytz.timezone('Asia/Singapore')).strftime("%d_%m_%Y_%H_%M_%S"),n_hidden)
    checkpoint_path = model_folder+"/cp.ckpt"
    saved_model_folder = model_folder+'/saved_model'
    # Loss curves, ANOVA output and tuning curves all go to the same folder
    loss_curve_folder = anova_folder = tuning_curve_folder = model_folder+'/'

    os.makedirs(model_folder)

    cell = LeakyRNNCell_softplus(n_hidden)
    rnn_layer = tf.keras.layers.RNN(cell,input_shape=(inputs_train.shape[1:]),
                                    return_sequences=True)


    # Hidden activity of the untrained layer, taken at time steps 1 and 3
    # (named delay1/delay2 - presumably the two delay periods; confirm against
    # how `inputs` is laid out in time).
    hidden = rnn_layer(inputs)
    delay1_hidden = hidden[:,1,:]
    delay2_hidden = hidden[:,3,:]
    plot_tuning_curves(delay1_hidden,tuning_curve_folder,'untrained_delay1_hidden')
    plot_tuning_curves(delay2_hidden,tuning_curve_folder,'untrained_delay2_hidden')

    get_anova_stats(anova_folder+'untrained_', rnn_layer)

    # Full model: the recurrent layer followed by a 32-unit read-out applied
    # independently at every time step
    model = Sequential()
    model.add(rnn_layer)
    model.add(TimeDistributed(Dense(32)))
    model.summary()

    ### Check performance and selectivity before training
    print("n_hidden %s untrained performance: %s"%(n_hidden, get_model_performance(model)))




    ### Callbacks: weight checkpointing and early stopping (patience 10, best weights restored)
    cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                    save_weights_only=True,
                                                    verbose=1)
    es_callback = tf.keras.callbacks.EarlyStopping(restore_best_weights=True, patience=10)
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),loss=tf.keras.losses.mse)

    ### Training the model
    history = model.fit(inputs_train, outputs_train, validation_data=(inputs_test,outputs_test), batch_size=64, epochs=500, verbose=2,
              callbacks=[cp_callback, es_callback])
    plot_loss_over_epochs(history, foldername=loss_curve_folder)
    model.save(saved_model_folder)



    # model = tf.keras.models.load_model(saved_model_folder)

    ### Check performance and selectivity after training
    print("n_hidden %s trained performance: %s"%(n_hidden, get_model_performance(model)))

    # simplernn = SimpleRNN(n_hidden, input_shape=(inputs_train.shape[1:]), return_sequences=True, weights=model.layers[0].get_weights())
    # A second RNN layer is wrapped around the same `cell` object and given the
    # trained layer's weights, so the trained hidden activity can be read out.
    rnn_layer = tf.keras.layers.RNN(cell,input_shape=(inputs_train.shape[1:]),
                                    return_sequences=True, weights=model.layers[0].get_weights())
    hidden = rnn_layer(inputs)

    delay1_hidden = hidden[:,1,:]
    delay2_hidden = hidden[:,3,:]
    plot_tuning_curves(delay1_hidden,tuning_curve_folder,'trained_delay1_hidden')
    plot_tuning_curves(delay2_hidden,tuning_curve_folder,'trained_delay2_hidden')

    # Trained-model ANOVA output uses the bare folder prefix (no 'trained_' tag)
    get_anova_stats(anova_folder, rnn_layer)

### creating mixed selectivity problem

In [None]:
# Target location for every stimulus pair: the product of the two stimulus
# angles, wrapped back into [0, 2*pi)
stim3_locs = (stim1_locs*stim2_locs)%(2*np.pi)

# One row of add_x_loc activity per stimulus pair
Y = np.array([add_x_loc(stim3_locs[i]) for i in range(stim_loc_size)])

In [None]:
# Sanity check: display the dimensions of the target matrix Y
Y.shape

In [None]:
# Sweep the hidden-layer width of a two-layer ReLU network trained to map X to Y,
# recording hidden-unit tuning curves and ANOVA statistics before and after training.
for n_hidden in [8, 16, 32, 64, 128, 256]:

    print("n_hidden: %s"%n_hidden)

    ## Output locations for this run, named after a Singapore-time timestamp
    timestamp = datetime.now(pytz.timezone('Asia/Singapore')).strftime("%d_%m_%Y_%H_%M_%S")
    model_folder = 'linear_%s_n_hidden_%s'%(timestamp, n_hidden)
    checkpoint_path = model_folder+"/cp.ckpt"
    saved_model_folder = model_folder+'/saved_model'
    tuning_curve_folder = model_folder+'/'
    anova_folder = tuning_curve_folder
    loss_curve_folder = tuning_curve_folder
    os.makedirs(model_folder)

    ## 64 -> n_hidden -> 32 network, ReLU on both layers
    model = Sequential([
        tf.keras.Input(shape=(64,)),
        tf.keras.layers.Dense(n_hidden, activation='relu'),
        tf.keras.layers.Dense(32, activation='relu'),
    ])
    model.summary()

    ## Probe network: the hidden layer alone, carrying the untrained weights
    hidden_weights = model.layers[0].get_weights()
    model2 = Sequential([
        tf.keras.Input(shape=(64,)),
        tf.keras.layers.Dense(n_hidden, activation='relu', weights=hidden_weights),
    ])
    hidden = model2(X)
    plot_tuning_curves(hidden,tuning_curve_folder,'untrained_hidden')
    get_anova_stats_linear(anova_folder+'untrained_', model2)

    ## Weight checkpointing and early stopping (patience 10, best weights restored)
    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_path, save_weights_only=True, verbose=1)
    es_callback = tf.keras.callbacks.EarlyStopping(restore_best_weights=True, patience=10)
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
                  loss=tf.keras.losses.mse)

    ## Train; the training set doubles as the validation set
    history = model.fit(
        X, Y,
        validation_data=(X,Y),
        batch_size=64, epochs=500, verbose=2,
        callbacks=[cp_callback, es_callback])
    plot_loss_over_epochs(history, foldername=loss_curve_folder)
    model.save(saved_model_folder)

    ## Same probe again, now carrying the trained hidden-layer weights
    hidden_weights = model.layers[0].get_weights()
    model2 = Sequential([
        tf.keras.Input(shape=(64,)),
        tf.keras.layers.Dense(n_hidden, activation='relu', weights=hidden_weights),
    ])
    hidden = model2(X)
    plot_tuning_curves(hidden,tuning_curve_folder,'trained_hidden')
    get_anova_stats_linear(anova_folder+'trained_', model2)

In [None]:
# Sweep the hidden-layer width of a two-layer ReLU network whose kernels use the
# RandomNormal initializer, recording hidden-unit tuning curves and ANOVA
# statistics before and after training. The saved-model export is disabled here.
activation = 'relu'
kernel_initializer=tf.keras.initializers.RandomNormal

for n_hidden in [8, 16, 32, 64, 128, 256]:

    print("n_hidden: %s"%n_hidden)

    ## Output locations for this run, named after activation and a Singapore-time timestamp
    timestamp = datetime.now(pytz.timezone('Asia/Singapore')).strftime("%d_%m_%Y_%H_%M_%S")
    model_folder = 'linear_%s_%s_n_hidden_%s'%(activation, timestamp, n_hidden)
    checkpoint_path = model_folder+"/cp.ckpt"
    saved_model_folder = model_folder+'/saved_model'
    tuning_curve_folder = model_folder+'/'
    anova_folder = tuning_curve_folder
    loss_curve_folder = tuning_curve_folder
    os.makedirs(model_folder)

    ## 64 -> n_hidden -> 32 network, same activation and initializer on both layers
    model = Sequential([
        tf.keras.Input(shape=(64,)),
        tf.keras.layers.Dense(n_hidden, activation=activation, kernel_initializer=kernel_initializer),
        tf.keras.layers.Dense(32, activation=activation, kernel_initializer=kernel_initializer),
    ])
    model.summary()

    ## Probe network: the hidden layer alone, carrying the untrained weights
    hidden_weights = model.layers[0].get_weights()
    model2 = Sequential([
        tf.keras.Input(shape=(64,)),
        tf.keras.layers.Dense(n_hidden, activation=activation,
                              kernel_initializer=kernel_initializer,
                              weights=hidden_weights),
    ])
    hidden = model2(X)
    plot_tuning_curves(hidden,tuning_curve_folder,'untrained_hidden')
    get_anova_stats_linear(anova_folder+'untrained_', model2)

    ## Weight checkpointing and early stopping (patience 10, best weights restored)
    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_path, save_weights_only=True, verbose=1)
    es_callback = tf.keras.callbacks.EarlyStopping(restore_best_weights=True, patience=10)
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
                  loss=tf.keras.losses.mse)

    ## Train; the training set doubles as the validation set
    history = model.fit(
        X, Y,
        validation_data=(X,Y),
        batch_size=64, epochs=500, verbose=2,
        callbacks=[cp_callback, es_callback])
    plot_loss_over_epochs(history, foldername=loss_curve_folder)
#     model.save(saved_model_folder)

    ## Same probe again, now carrying the trained hidden-layer weights
    hidden_weights = model.layers[0].get_weights()
    model2 = Sequential([
        tf.keras.Input(shape=(64,)),
        tf.keras.layers.Dense(n_hidden, activation=activation,
                              kernel_initializer=kernel_initializer,
                              weights=hidden_weights),
    ])
    hidden = model2(X)
    plot_tuning_curves(hidden,tuning_curve_folder,'trained_hidden')
    get_anova_stats_linear(anova_folder+'trained_', model2)

In [None]:
# Sweep the hidden-layer width of a two-layer softplus network whose kernels use
# the RandomNormal initializer, recording hidden-unit tuning curves and ANOVA
# statistics before and after training. The saved-model export is disabled here.
activation = 'softplus'
kernel_initializer=tf.keras.initializers.RandomNormal

for n_hidden in [8, 16, 32, 64, 128, 256]:

    print("n_hidden: %s"%n_hidden)

    ## Output locations for this run, named after activation and a Singapore-time timestamp
    timestamp = datetime.now(pytz.timezone('Asia/Singapore')).strftime("%d_%m_%Y_%H_%M_%S")
    model_folder = 'linear_%s_%s_n_hidden_%s'%(activation, timestamp, n_hidden)
    checkpoint_path = model_folder+"/cp.ckpt"
    saved_model_folder = model_folder+'/saved_model'
    tuning_curve_folder = model_folder+'/'
    anova_folder = tuning_curve_folder
    loss_curve_folder = tuning_curve_folder
    os.makedirs(model_folder)

    ## 64 -> n_hidden -> 32 network, same activation and initializer on both layers
    model = Sequential([
        tf.keras.Input(shape=(64,)),
        tf.keras.layers.Dense(n_hidden, activation=activation, kernel_initializer=kernel_initializer),
        tf.keras.layers.Dense(32, activation=activation, kernel_initializer=kernel_initializer),
    ])
    model.summary()

    ## Probe network: the hidden layer alone, carrying the untrained weights
    hidden_weights = model.layers[0].get_weights()
    model2 = Sequential([
        tf.keras.Input(shape=(64,)),
        tf.keras.layers.Dense(n_hidden, activation=activation,
                              kernel_initializer=kernel_initializer,
                              weights=hidden_weights),
    ])
    hidden = model2(X)
    plot_tuning_curves(hidden,tuning_curve_folder,'untrained_hidden')
    get_anova_stats_linear(anova_folder+'untrained_', model2)

    ## Weight checkpointing and early stopping (patience 10, best weights restored)
    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_path, save_weights_only=True, verbose=1)
    es_callback = tf.keras.callbacks.EarlyStopping(restore_best_weights=True, patience=10)
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
                  loss=tf.keras.losses.mse)

    ## Train; the training set doubles as the validation set
    history = model.fit(
        X, Y,
        validation_data=(X,Y),
        batch_size=64, epochs=500, verbose=2,
        callbacks=[cp_callback, es_callback])
    plot_loss_over_epochs(history, foldername=loss_curve_folder)
#     model.save(saved_model_folder)

    ## Same probe again, now carrying the trained hidden-layer weights
    hidden_weights = model.layers[0].get_weights()
    model2 = Sequential([
        tf.keras.Input(shape=(64,)),
        tf.keras.layers.Dense(n_hidden, activation=activation,
                              kernel_initializer=kernel_initializer,
                              weights=hidden_weights),
    ])
    hidden = model2(X)
    plot_tuning_curves(hidden,tuning_curve_folder,'trained_hidden')
    get_anova_stats_linear(anova_folder+'trained_', model2)

In [None]:
# Responses of the current probe network (`model2`) to the noisy stimuli
hidden = model2(X_n)

# Factor labels: X_n holds 10 noisy copies per stimulus pair, so every stimulus
# index is repeated 10 times in a row
first_stim_labels = np.repeat(ind_stim_loc1, 10)
second_stim_labels = np.repeat(ind_stim_loc2, 10)

# Look at no more than ten units, spread evenly across the layer
stride = 1 if n_hidden<10 else int(n_hidden/10)
for i in range(min(n_hidden,10)):
    neuron = i*stride

    df = pd.DataFrame({'first_stim':first_stim_labels,
                       'second_stim':second_stim_labels,
                       'activity':hidden[:,neuron]})
    aov = pg.anova(dv='activity', between=['first_stim', 'second_stim'],
                   data=df, detailed=True)
    # First three rows of the uncorrected p-value column
    p_values = aov['p-unc']
    print('neuron %s:'%neuron, p_values[0],p_values[1],p_values[2])

In [None]:
# Display the current value of n_hidden (when the notebook is run top to bottom,
# this is whatever the preceding cells last assigned to it)
n_hidden

In [None]:
# Single run (256 hidden units, softplus): print per-unit ANOVA p-values on the
# noisy stimuli before and after training the two-layer network.
n_hidden = 256
activation = 'softplus'
kernel_initializer=tf.keras.initializers.RandomNormal

print("n_hidden: %s"%n_hidden)

## Output locations for this run, named after activation and a Singapore-time timestamp
timestamp = datetime.now(pytz.timezone('Asia/Singapore')).strftime("%d_%m_%Y_%H_%M_%S")
model_folder = 'linear_%s_%s_n_hidden_%s'%(activation, timestamp, n_hidden)
checkpoint_path = model_folder+"/cp.ckpt"
saved_model_folder = model_folder+'/saved_model'
tuning_curve_folder = model_folder+'/'
anova_folder = tuning_curve_folder
loss_curve_folder = tuning_curve_folder
os.makedirs(model_folder)

## 64 -> n_hidden -> 32 network, same activation and initializer on both layers
model = Sequential([
    tf.keras.Input(shape=(64,)),
    tf.keras.layers.Dense(n_hidden, activation=activation, kernel_initializer=kernel_initializer),
    tf.keras.layers.Dense(32, activation=activation, kernel_initializer=kernel_initializer),
])
model.summary()

## Probe network: the hidden layer alone, carrying the untrained weights
hidden_weights = model.layers[0].get_weights()
model2 = Sequential([
    tf.keras.Input(shape=(64,)),
    tf.keras.layers.Dense(n_hidden, activation=activation,
                          kernel_initializer=kernel_initializer,
                          weights=hidden_weights),
])
hidden = model2(X)
plot_tuning_curves(hidden,'')
# get_anova_stats_linear(anova_folder+'untrained_', model2)

## Factor labels for the ANOVA: X_n holds 10 noisy copies per stimulus pair, so
## every stimulus index is repeated 10 times in a row
first_stim_labels = np.repeat(ind_stim_loc1, 10)
second_stim_labels = np.repeat(ind_stim_loc2, 10)
# Look at no more than ten units, spread evenly across the layer
stride = 1 if n_hidden<10 else int(n_hidden/10)

## ANOVA on the untrained hidden layer
hidden = model2(X_n)
for i in range(min(n_hidden,10)):
    neuron = i*stride

    df = pd.DataFrame({'first_stim':first_stim_labels,
                       'second_stim':second_stim_labels,
                       'activity':hidden[:,neuron]})
    aov = pg.anova(dv='activity', between=['first_stim', 'second_stim'],
                   data=df, detailed=True)
    p_values = aov['p-unc']
    print('neuron %s:'%neuron, p_values[0],p_values[1],p_values[2])

## Weight checkpointing and early stopping (patience 10, best weights restored)
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path, save_weights_only=True, verbose=1)
es_callback = tf.keras.callbacks.EarlyStopping(restore_best_weights=True, patience=10)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
              loss=tf.keras.losses.mse)

## Train; the training set doubles as the validation set
history = model.fit(
    X, Y,
    validation_data=(X,Y),
    batch_size=64, epochs=500, verbose=2,
    callbacks=[cp_callback, es_callback])
plot_loss_over_epochs(history, foldername=loss_curve_folder)
#     model.save(saved_model_folder)

## Same probe again, now carrying the trained hidden-layer weights
hidden_weights = model.layers[0].get_weights()
model2 = Sequential([
    tf.keras.Input(shape=(64,)),
    tf.keras.layers.Dense(n_hidden, activation=activation,
                          kernel_initializer=kernel_initializer,
                          weights=hidden_weights),
])

## ANOVA on the trained hidden layer
hidden = model2(X_n)
for i in range(min(n_hidden,10)):
    neuron = i*stride

    df = pd.DataFrame({'first_stim':first_stim_labels,
                       'second_stim':second_stim_labels,
                       'activity':hidden[:,neuron]})
    aov = pg.anova(dv='activity', between=['first_stim', 'second_stim'],
                   data=df, detailed=True)
    p_values = aov['p-unc']
    print('neuron %s:'%neuron, p_values[0],p_values[1],p_values[2])
    


In [None]:
def get_anova_stats_linear(foldername, model):
    """Print two-way ANOVA p-values for up to ten output units of ``model``.

    ``model`` is applied to the noisy stimulus set ``X_n``. For at most ten
    units, spread evenly over the output width, a two-way between-subjects
    ANOVA of the unit's activity against the first and second stimulus index
    is run, and the first three entries of its ``p-unc`` column are printed.

    The number of units is taken from the model output itself rather than
    from the notebook-level ``n_hidden``, so the function also works on a
    model whose width differs from the last loop value. The number of noisy
    copies per stimulus pair is likewise derived from the output length.

    Args:
        foldername: currently unused; kept so existing calls keep working
            (the CSV export it used to prefix is disabled).
        model: callable that maps ``X_n`` to a 2-D array-like of shape
            (samples, units).
    """
    activity = np.asarray(model(X_n))
    n_samples, n_units = activity.shape

    # Each stimulus pair occurs n_repeat times in a row in X_n, so every
    # stimulus index is repeated that many times. Built once: the factor
    # columns are the same for every unit.
    n_repeat = n_samples // len(ind_stim_loc1)
    first_stim = np.repeat(ind_stim_loc1, n_repeat)
    second_stim = np.repeat(ind_stim_loc2, n_repeat)

    # Narrow layers: every unit. Wide layers: every (n_units // 10)-th unit.
    stride = 1 if n_units < 10 else n_units // 10
    for i in range(min(n_units, 10)):
        neuron = i * stride

        df = pd.DataFrame({'first_stim': first_stim,
                           'second_stim': second_stim,
                           'activity': activity[:, neuron]})
        aov = pg.anova(dv='activity', between=['first_stim', 'second_stim'], data=df,
                       detailed=True)
        print('neuron %s:'%neuron, aov['p-unc'][0],aov['p-unc'][1],aov['p-unc'][2])
        
def plot_tuning_curves(model,foldername='',filename=''):
    """Plot tuning curves of up to 10 output units of ``model``.

    Two figures are produced, each with one row per sampled unit:

    * activity against the second stimulus index, one colour per reference
      first-stimulus location (saved as ``foldername+filename+'_stim2'``);
    * activity against the first stimulus index, one colour per reference
      second-stimulus location (saved as ``foldername+filename+'_stim1'``).

    Every row is annotated with the first three uncorrected p-values of a
    two-way ANOVA run on the model's responses to the global ``X_n``.

    Parameters
    ----------
    model : callable
        Maps the globals ``X`` and ``X_n`` to 2-D array-likes
        (n_samples, n_units).
    foldername : str
        Prefix for the saved figure paths.
    filename : str
        Base name of the saved figures; when empty, nothing is saved.
    """
    Z = np.asarray(model(X))
    Z_n = np.asarray(model(X_n))

    # Every unit when there are fewer than 10, otherwise a stride of n_units // 10.
    n_units = Z.shape[1]
    stride = 1 if n_units < 10 else n_units // 10
    neurons = [row*stride for row in range(min(n_units, 10))]

    # The ANOVA does not depend on which figure is drawn, so run it once per
    # unit and reuse the annotation in both figures.
    p_labels = [_anova_p_label(Z_n[:, neuron]) for neuron in neurons]

    save = filename != ''
    _plot_tuning_figure(Z, neurons, p_labels,
                        group_idx=ind_stim_loc1, x_idx=ind_stim_loc2, palette=palette1,
                        legend_title='Stim 1', xlabel='Stim 2',
                        save_path=foldername+filename+'_stim2' if save else None)
    _plot_tuning_figure(Z, neurons, p_labels,
                        group_idx=ind_stim_loc2, x_idx=ind_stim_loc1, palette=palette2,
                        legend_title='Stim 2', xlabel='Stim 1',
                        save_path=foldername+filename+'_stim1' if save else None)


def _anova_p_label(activity, n_repeats=10):
    """Return the 'anova p: (...)' annotation for one unit's responses to X_n."""
    # Assumes X_n holds `n_repeats` consecutive rows per stimulus pair --
    # TODO confirm against how X_n is built.
    df_n = pd.DataFrame({'first_stim': np.repeat(ind_stim_loc1, n_repeats),
                         'second_stim': np.repeat(ind_stim_loc2, n_repeats),
                         'activity': activity})
    aov = pg.anova(dv='activity', between=['first_stim', 'second_stim'], data=df_n, detailed=True)
    p_unc = aov['p-unc']
    return 'anova p: (%s,%s,%s)'%(round(p_unc[0],2),round(p_unc[1],2),round(p_unc[2],2))


def _plot_tuning_figure(Z, neurons, p_labels, group_idx, x_idx, palette,
                        legend_title, xlabel, save_path=None):
    """Draw one tuning-curve figure: one row per unit, one colour per reference location."""
    n_rows = len(neurons)
    # squeeze=False keeps a 2-D axes array even for a single row, so indexing
    # also works for a one-unit model.
    fig, axes = plt.subplots(n_rows, 1, figsize=(5, n_rows*3), squeeze=False)
    axes = axes[:, 0]

    # Reference locations 0, 10, ..., 110. The 128 is hard-coded; presumably it
    # mirrors the number of stimulus locations -- TODO derive it from the grid.
    ref_locs = [j*10 for j in range(int(128/10))]

    for ax, neuron, p_label in zip(axes, neurons, p_labels):
        activity = Z[:, neuron]
        for loc in ref_locs:
            mask = group_idx == loc
            ax.scatter(x_idx[mask], activity[mask], s=1, color=palette[loc], label=loc)
        ax.text(1, 0, p_label, fontsize = 8)
        ax.set_ylabel('%dth neuron'%neuron, fontsize=13)

    axes[0].legend(loc='upper left', bbox_to_anchor= (1.05, 1.05), title=legend_title)
    axes[0].set_xlabel(xlabel, fontsize=13)
    axes[0].xaxis.set_label_position('top')
    if save_path is not None:
        fig.tight_layout()
        fig.savefig(save_path)
    plt.show()
    plt.close(fig)

In [None]:
# Experiment: 64 inputs -> n_hidden ReLU units -> 32 ReLU outputs.
n_hidden = 256

activation = 'relu'
kernel_initializer=tf.keras.initializers.RandomNormal

print("n_hidden: %s"%n_hidden)
## Create folders to save model and training info.
# These paths are defined in this cell so it does not depend on folder
# variables left over from a previously executed cell (which would also send
# this run's checkpoints, figures and saved model into that older folder).
model_folder = 'linear_%s_%s_n_hidden_%s'%(activation, datetime.now(pytz.timezone('Asia/Singapore')).strftime("%d_%m_%Y_%H_%M_%S"),n_hidden)
checkpoint_path = model_folder+"/cp.ckpt"
saved_model_folder = model_folder+'/saved_model'
loss_curve_folder = anova_folder = tuning_curve_folder = model_folder+'/'

os.makedirs(model_folder, exist_ok=True)

model = Sequential()
model.add(tf.keras.Input(shape=(64,)))
model.add(tf.keras.layers.Dense(n_hidden, activation=activation, kernel_initializer=kernel_initializer))
model.add(tf.keras.layers.Dense(32, activation=activation, kernel_initializer=kernel_initializer))
model.summary()

# One-layer copy of the (still untrained) first Dense layer, used to read out
# hidden activity.
model2 = Sequential()
model2.add(tf.keras.Input(shape=(64,)))
model2.add(tf.keras.layers.Dense(n_hidden, activation=activation, kernel_initializer=kernel_initializer, weights=model.layers[0].get_weights()))

print("untrained hidden:")
plot_tuning_curves(model2,tuning_curve_folder,'untrained_hidden')
# get_anova_stats_linear(anova_folder+'untrained_hidden_', model2)

print("untrained output:")
plot_tuning_curves(model,tuning_curve_folder,'untrained_output')
# get_anova_stats_linear(anova_folder+'untrained_output_', model)

### Create callbacks to save the model's weights and to stop early
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                save_weights_only=True,
                                                verbose=1)
es_callback = tf.keras.callbacks.EarlyStopping(restore_best_weights=True, patience=10)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),loss=tf.keras.losses.mse)

### Training the model
# The same (X, Y) arrays are passed as validation data, so the validation
# metrics reported during training are measured on the training set.
history = model.fit(X, Y, validation_data=(X,Y), batch_size=64, epochs=500, verbose=2,
          callbacks=[cp_callback, es_callback])
plot_loss_over_epochs(history, foldername=loss_curve_folder)
model.save(saved_model_folder)

# Rebuild the hidden-layer read-out from the trained first-layer weights.
model2 = Sequential()
model2.add(tf.keras.Input(shape=(64,)))
model2.add(tf.keras.layers.Dense(n_hidden, activation=activation, kernel_initializer=kernel_initializer, weights=model.layers[0].get_weights()))

print("trained hidden:")
plot_tuning_curves(model2,tuning_curve_folder,'trained_hidden')
# get_anova_stats_linear(anova_folder+'trained_hidden_', model2)

print("trained output:")
plot_tuning_curves(model,tuning_curve_folder,'trained_output')
# get_anova_stats_linear(anova_folder+'trained_output_', model)