In [1]:
%load_ext autoreload
%autoreload 2
#%matplotlib notebook
%matplotlib inline

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

# custom packages
from ratsimulator import Agent, trajectory_generator, batch_trajectory_generator
from ratsimulator.Environment import Rectangle
from ctimeit import ctimeit # for timing

import sys
sys.path.append("../src") if "../src" not in sys.path else None # avoid adding multiple relave paths to sys.path

from Brain import Brain
from Models import SorscherRNN

### Set parameters and initialise

In [3]:
# Environment params
boxsize = (2.2, 2.2)
origo = (0,0)
soft_boundary = 0.2
# Brain params
npcs = 512 # as used in Sorscher model
sigma = 0.12
# Training data (Agent) params
batch_size = 64
seq_len = 20
angle0 = None # random
p0 = None     # random
# Agent/random walk parameters
dt = 0.02
sigma = 5.76 * 2
b = 0.13 * 2 * np.pi
mu = 0
# Model params
Ng=4096
Np=npcs # defined for Brain already
weight_decay=1e-4
activation="relu"
lr=1e-4 # 1e-3 is default for Adam()

In [4]:
# Init Environment
env = Rectangle(boxsize=boxsize, soft_boundary=soft_boundary)
# Init brain
brain = Brain(env, npcs, sigma)
# Init training data
btg = batch_trajectory_generator(batch_size, env, seq_len, angle0, p0, dt=dt, sigma=sigma, b=b, mu=mu)
# model init
model = SorscherRNN(Ng,Np,weight_decay,activation)

Singular matrix
Singular matrix


In [5]:
"""
# Sorscher params
options.save_dir = '/mnt/fs2/bsorsch/grid_cells/models/'
options.n_steps = 100000      # number of training steps
options.batch_size = 200      # number of trajectories per batch
options.sequence_length = 20  # number of steps in trajectory
options.learning_rate = 1e-4  # gradient descent learning rate
options.Np = 512              # number of place cells
options.Ng = 4096             # number of grid cells
options.place_cell_rf = 0.12  # width of place cell center tuning curve (m)
options.surround_scale = 2    # if DoG, ratio of sigma2^2 to sigma1^2
options.RNN_type = 'RNN'      # RNN or LSTM
options.activation = 'relu'   # recurrent nonlinearity
options.weight_decay = 1e-4   # strength of weight decay on recurrent weights
options.DoG = True            # use difference of gaussians tuning curves
options.periodic = False      # trajectories with periodic boundary conditions
options.box_width = 2.2       # width of training environment
options.box_height = 2.2      # height of training environment
"""

"\n# Sorscher params\noptions.save_dir = '/mnt/fs2/bsorsch/grid_cells/models/'\noptions.n_steps = 100000      # number of training steps\noptions.batch_size = 200      # number of trajectories per batch\noptions.sequence_length = 20  # number of steps in trajectory\noptions.learning_rate = 1e-4  # gradient descent learning rate\noptions.Np = 512              # number of place cells\noptions.Ng = 4096             # number of grid cells\noptions.place_cell_rf = 0.12  # width of place cell center tuning curve (m)\noptions.surround_scale = 2    # if DoG, ratio of sigma2^2 to sigma1^2\noptions.RNN_type = 'RNN'      # RNN or LSTM\noptions.activation = 'relu'   # recurrent nonlinearity\noptions.weight_decay = 1e-4   # strength of weight decay on recurrent weights\noptions.DoG = True            # use difference of gaussians tuning curves\noptions.periodic = False      # trajectories with periodic boundary conditions\noptions.box_width = 2.2       # width of training environment\noptions.box_he

In [6]:
pos, vel = next(btg)

In [10]:
def to_one_hot(x):
    """OBS! if max(x,axis=-1) is not unique, then return result is not one-hot, but k-hot"""
    return np.where(x == np.max(x,axis=-1,keepdims=True),1,0)
    
def data_generator(btg, brain):
    
    while True:
        pos, vel = next(btg)
        
        #labels = brain(pos) # from euclidean positions to place-cell positions
        labels = brain.softmax_response(pos)
        init_pos, labels = labels[:,0], labels[:,1:] # "next" pos is label
        
        vel = vel[:,1:] # discard first velocity -> always 0 (initial vel)
        # labels = to_one_hot(labels) 
        
        yield (vel,init_pos), labels 

# Initialise data generator
dg = data_generator(btg,brain)

# Specify and instantiate Model

In [11]:


model.compile(
    optimizer=tf.keras.optimizers.Adam(lr=lr),
    loss=tf.nn.softmax_cross_entropy_with_logits,
    #metrics=[tf.keras.metrics.SparseCategoricalAccuracy()], # requires dims: (batch_size,N)
)

# Build (by specifying input_shape) and summarize model
input_shape = [(batch_size, seq_len, 2), (batch_size, Np)] # velocity-input UNION initial-state
model.build(input_shape)
model.summary()

Model: "sorscher_rnn"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
encoder (Dense)              multiple                  2097152   
_________________________________________________________________
RNN (SimpleRNN)              multiple                  16785408  
_________________________________________________________________
decoder (Dense)              multiple                  2097152   
Total params: 20,979,712
Trainable params: 20,979,712
Non-trainable params: 0
_________________________________________________________________


# Train model

In [12]:
# train model
epochs = 10
steps_per_epoch = 200
model.fit(x=dg, epochs=epochs, steps_per_epoch=steps_per_epoch)

Epoch 1/10
 27/200 [===>..........................] - ETA: 3:33 - loss: 6.6287

KeyboardInterrupt: 

In [13]:
model.RNN.weights

[<tf.Variable 'RNN/simple_rnn_cell/kernel:0' shape=(2, 4096) dtype=float32, numpy=
 array([[-0.00920061,  0.02811981, -0.00011324, ...,  0.0294923 ,
         -0.03277944,  0.01796926],
        [-0.01600602, -0.01291346,  0.03635537, ..., -0.03057673,
          0.02890574,  0.00356031]], dtype=float32)>,
 <tf.Variable 'RNN/simple_rnn_cell/recurrent_kernel:0' shape=(4096, 4096) dtype=float32, numpy=
 array([[-0.02158454, -0.01334286, -0.01302407, ..., -0.01205708,
          0.00833452, -0.01296994],
        [ 0.00657571,  0.02461227, -0.02314649, ..., -0.01977941,
          0.01084158, -0.00099443],
        [ 0.00673047,  0.01362112,  0.0139323 , ..., -0.02046178,
         -0.00570212, -0.0063231 ],
        ...,
        [-0.01000874, -0.00921379, -0.00883043, ...,  0.02316147,
          0.01265423, -0.02066893],
        [-0.01533048,  0.01487139, -0.02406915, ...,  0.0194156 ,
          0.00078582, -0.01597998],
        [-0.0034672 , -0.00508838,  0.01132517, ..., -0.0002491 ,
         -