In [1]:
%load_ext autoreload
%autoreload 2
#%matplotlib notebook
%matplotlib inline

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch

# custom packages
from ratsimulator import Agent, trajectory_generator, batch_trajectory_generator
from ratsimulator.Environment import Rectangle

import sys
sys.path.append("../src") if "../src" not in sys.path else None # avoid adding multiple relave paths to sys.path

from Brain import Brain
from Models import UnitPathIntegrator
from methods import *

### Set parameters and initialise

In [3]:
"""
# Sorscher params
options.save_dir = '/mnt/fs2/bsorsch/grid_cells/models/'
options.n_steps = 100000      # number of training steps
options.batch_size = 200      # number of trajectories per batch
options.sequence_length = 20  # number of steps in trajectory
options.learning_rate = 1e-4  # gradient descent learning rate
options.Np = 512              # number of place cells
options.Ng = 4096             # number of grid cells
options.place_cell_rf = 0.12  # width of place cell center tuning curve (m)
options.surround_scale = 2    # if DoG, ratio of sigma2^2 to sigma1^2
options.RNN_type = 'RNN'      # RNN or LSTM
options.activation = 'relu'   # recurrent nonlinearity
options.weight_decay = 1e-4   # strength of weight decay on recurrent weights
options.DoG = True            # use difference of gaussians tuning curves
options.periodic = False      # trajectories with periodic boundary conditions
options.box_width = 2.2       # width of training environment
options.box_height = 2.2      # height of training environment
"""

# Environment params
boxsize = (2.2, 2.2)
origo = (0,0)
soft_boundary = 0.2
# Brain params
npcs = 512 # as used in Sorscher model
sigma = 0.12
# Training data (Agent) params
batch_size = 64
seq_len = 1
angle0 = None # random
p0 = None     # random
# Agent/random walk parameters
dt = 0.02
sigma = 5.76 * 2
b = 0.13 * 2 * np.pi
mu = 0
# Model params
Ng=4096
Np=npcs # defined for Brain already
weight_decay=1e-4
lr=1e-4 # 1e-3 is default for Adam()

# can also select cuda:0 or cuda:1 f.eks when multiple gpus
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
nsteps = 100 # number of mini batches in an epoch
num_workers = 16 # number of workers to load data (in parallel)

In [4]:
# Init Environment
env = Rectangle(boxsize=boxsize, soft_boundary=soft_boundary)
# Init brain
brain = Brain(env, npcs, sigma)
# Init training data
dataset = Dataset(brain=brain, batch_size=batch_size, nsteps=nsteps, environment=env, seq_len=seq_len, \
                  angle0=angle0, p0=p0, dt=dt, sigma=sigma, b=b, mu=mu)
dataloader = torch.utils.data.DataLoader(dataset,  batch_size=batch_size, num_workers=num_workers)
# Init model
model = UnitPathIntegrator(Ng,Np)
model.to(device)

Singular matrix
Singular matrix


UnitPathIntegrator(
  (velocity_encoder): Linear(in_features=2, out_features=4096, bias=False)
  (init_position_encoder): Linear(in_features=512, out_features=4096, bias=False)
  (recurrence): Linear(in_features=4096, out_features=4096, bias=False)
  (decoder): Linear(in_features=4096, out_features=512, bias=False)
)

In [10]:
optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=weight_decay, amsgrad=False)
model.train(trainloader=dataloader, optimizer = optimizer, weight_decay=weight_decay, nepochs=10, device=device)

Epoch=10, loss=6.238444905281067: 100%|█████████| 10/10 [00:18<00:00,  1.82s/it]


In [11]:
model

UnitPathIntegrator(
  (velocity_encoder): Linear(in_features=2, out_features=4096, bias=False)
  (init_position_encoder): Linear(in_features=512, out_features=4096, bias=False)
  (recurrence): Linear(in_features=4096, out_features=4096, bias=False)
  (decoder): Linear(in_features=4096, out_features=512, bias=False)
)

In [35]:
Wg = model.recurrence.weight

In [36]:
Wg = Wg.to('cpu').detach().numpy()

In [37]:
np.max(Wg), np.min(Wg), np.mean(Wg), np.std(Wg)

(3.7126825e-05, -3.6143654e-05, -2.574247e-10, 1.1633853e-06)

In [41]:
W = torch.nn.Linear(4096,4096)
W = W.weight.detach().numpy()
np.max(W), np.min(W), np.mean(W), np.std(W)

(0.015624998, -0.015624998, -5.161618e-07, 0.0090188775)