In [None]:
class SimpleRNN_pos_loss(keras.layers.Layer):
    """Recurrent cell wrapping a linear ``SimpleRNNCell`` with topographic losses.

    Each step:
      1. runs the wrapped cell (created with ``activation=None``),
      2. adds uniform noise and applies ``activation``,
      3. registers a lateral-effect loss (weighted by ``args.lambda_l``) and an
         activation-magnitude loss (weighted by ``activation_weight``),
      4. optionally forces units near a stimulation electrode to 1.

    NOTE(review): this class reads the module-level globals ``args``
    (``latent_shape``, ``sigma``, ``lateral``, ``noise_val``, ``lambda_l``),
    ``stim_params`` and ``pairwise_distances``; they must be defined before the
    layer is built/called.
    NOTE(review): ``self.counter`` and the stimulation branch are plain Python
    state, so they only behave per-step when the cell runs eagerly.
    """

    def __init__(self, units, conn_prob=1, activation="tanh", kernel_regularizer=None, activation_weight=0, **kwargs):
        super().__init__(**kwargs)
        self.state_size = units
        self.output_size = units
        self.kernel_regularizer = kernel_regularizer
        self.activation_weight = activation_weight
        # The wrapped cell is linear; the nonlinearity is applied in call()
        # after noise has been injected.
        self.simple_rnn_cell = keras.layers.SimpleRNNCell(
            units,
            activation=None,
            kernel_initializer='glorot_uniform',
            kernel_regularizer=self.kernel_regularizer,
            recurrent_initializer=tf.keras.initializers.glorot_normal,
        )

        self.activation = keras.activations.get(activation)
        # Grid coordinates of every unit and the (n, n) lateral-interaction matrix.
        # Presumably units == latent_shape[0] * latent_shape[1] — TODO confirm.
        self.pos = locmap(args)
        self.lateral_effect = tf.convert_to_tensor(lateral_effect(args), dtype="float32")
        self.counter = -1

        self.conn_prob = conn_prob
        # NOTE(review): the connection mask is built but not currently applied to
        # the recurrent kernel anywhere in this class.
        self.conn_mat = self.get_connection_mask(units)

    def call(self, inputs, states):
        """Run one recurrent step and return ``(outputs, [outputs])``."""
        # Linear recurrent step (activation is applied below).
        outputs, self.states = self.simple_rnn_cell(inputs, states)

        # Uniform noise in [-noise_val/2, noise_val/2). Drawn with TF ops so a new
        # sample is produced on every step even when traced into a graph, and so
        # an unknown (None) batch dimension is supported.
        half_noise = args.noise_val / 2
        noise = tf.random.uniform(tf.shape(outputs), minval=-half_noise, maxval=half_noise, dtype=outputs.dtype)
        outputs = self.activation(outputs + noise)

        # Lateral loss: -mean over the batch of the per-sample quadratic form
        # h L h^T, i.e. the diagonal of outputs @ L @ outputs^T, computed without
        # materialising the (batch, batch) product.
        lateral_drive = tf.matmul(outputs, self.lateral_effect)
        lateral_loss = -tf.reduce_mean(tf.reduce_sum(lateral_drive * outputs, axis=-1))
        self.add_loss(args.lambda_l * lateral_loss)

        # Loss on the mean activation.
        self.add_loss(tf.reduce_mean(outputs) * self.activation_weight)

        # Stimulation: inside the configured time window, force units near the
        # electrode to 1 with probability exp(-distance / stim_dist_tau).
        # Condition order is kept so stim_time is only read when stimulation is on.
        if (
            stim_params.is_stim
            and stim_params.stim_dist_tau > 0
            and stim_params.stim_time[0] <= self.counter <= stim_params.stim_time[1]
        ):
            stim_dists = pairwise_distances(self.pos, np.expand_dims(stim_params.stim_pos, axis=0))
            prob_act = np.exp(-1 * stim_dists / stim_params.stim_dist_tau)
            # Row indices (k, 1) of the units that fire on this step.
            cell_activated = np.argwhere(np.random.random(prob_act.shape) < prob_act)[:, :1]

            # Scatter along the unit axis: transpose to (units, batch), overwrite
            # the selected rows with ones, then transpose back.
            outputs_t = tf.transpose(outputs)
            outputs_t = tf.tensor_scatter_nd_update(
                outputs_t,
                tf.convert_to_tensor(cell_activated),
                tf.ones((cell_activated.shape[0], outputs_t.shape[1])),
            )
            outputs = tf.transpose(outputs_t)

        self.counter = self.counter + 1
        return outputs, [outputs]

    def reset_counter(self):
        """Reset the step counter used for the stimulation time window."""
        self.counter = -1

    def get_connection_mask(self, units):
        """Return a random (units, units) float32 0/1 connectivity mask tensor.

        Each entry is 1 when a uniform [0, 1) draw is below ``self.conn_prob``;
        the diagonal is forced to 1 so every unit stays connected to itself.
        """
        conn_mat = (np.random.random((units, units)) < self.conn_prob).astype(np.float32)
        np.fill_diagonal(conn_mat, 1.0)
        return tf.convert_to_tensor(conn_mat)

    
def locmap(args):
    '''
    Grid coordinates of every neuron in the latent sheet.

    :param args: object with a two-element ``latent_shape`` (x extent, y extent)
    :return: float32 array of shape (latent_shape[0] * latent_shape[1], 2);
             row k is (x, y), with x varying fastest
    '''
    n_x, n_y = args.latent_shape[0], args.latent_shape[1]
    xs = np.arange(0, n_x, dtype=np.float32)
    ys = np.arange(0, n_y, dtype=np.float32)
    # x cycles through its full range once per y value; each y repeats len(xs) times.
    col_x = np.tile(xs, len(ys))
    col_y = np.repeat(ys, len(xs))
    return np.column_stack((col_x, col_y))

def lateral_effect(args):
    '''
    Build the (n, n) lateral-interaction matrix for the neuron grid.

    ``n`` is the number of positions returned by ``locmap(args)``. Pairwise
    Euclidean distances are divided by ``args.sigma`` (call the result d) and
    passed through the kernel selected by ``args.lateral``:

    * ``'mexican'``: Mexican-hat profile ``(1 - d**2/2) * exp(-d**2/2)``
    * ``'rbf'``: Gaussian ``exp(-d**2/2)``

    Both kernels equal 1 at d = 0; the identity is subtracted so a unit has no
    lateral effect on itself. Any other ``args.lateral`` prints a message and
    returns an all-zero float32 matrix.

    :return: functions of lateral effect, as an (n, n) array
    '''
    locations = locmap(args)
    weighted_distance_matrix = euclidean_distances(locations, locations) / args.sigma
    squared = np.square(weighted_distance_matrix)

    # Compare by value, not identity: `is` only matches interned strings, so a
    # value read from argparse or a config file would silently fall through to
    # the zero matrix below.
    if args.lateral == 'mexican':
        S = (1.0 - 0.5 * squared) * np.exp(-0.5 * squared)
        return S - np.eye(len(locations))

    if args.lateral == 'rbf':
        S = np.exp(-0.5 * squared)
        return S - np.eye(len(locations))

    print('no lateral effect is chosen')
    return np.zeros(weighted_distance_matrix.shape, dtype=np.float32)