### Imports

In [1]:
%load_ext autoreload
%autoreload 2
#%matplotlib notebook
%matplotlib inline

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch

# custom packages
from ratsimulator import trajectory_generator
from ratsimulator.Environment import Rectangle

import datetime
import sys
# avoid adding the same relative path to sys.path more than once
sys.path.append("../src") if "../src" not in sys.path else None 

from PlaceCells import PlaceCells
from Models import SorscherRNN
from methods import *

ModuleNotFoundError: No module named 'torch'

### Set parameters and initialise

In [None]:
"""
# Sorscher params
options.save_dir = '/mnt/fs2/bsorsch/grid_cells/models/'
options.n_steps = 100000      # number of training steps
options.batch_size = 200      # number of trajectories per batch
options.sequence_length = 20  # number of steps in trajectory
options.learning_rate = 1e-4  # gradient descent learning rate
options.Np = 512              # number of place cells
options.Ng = 4096             # number of grid cells
options.place_cell_rf = 0.12  # width of place cell center tuning curve (m)
options.surround_scale = 2    # if DoG, ratio of sigma2^2 to sigma1^2
options.RNN_type = 'RNN'      # RNN or LSTM
options.activation = 'relu'   # recurrent nonlinearity
options.weight_decay = 1e-4   # strength of weight decay on recurrent weights
options.DoG = True            # use difference of gaussians tuning curves
options.periodic = False      # trajectories with periodic boundary conditions
options.box_width = 2.2       # width of training environment
options.box_height = 2.2      # height of training environment
"""

In [None]:
# Single configuration dictionary for the notebook. It is unpacked into
# PlaceCells(...) and Dataset(...) and also handed to model.train(...) in later cells.
params = {
    # Environment params
    'boxsize': (2.2, 2.2),
    'origo': (0, 0),
    'soft_boundary': 0.03,  # Sorscher uses 0.03, I used to have 0.2
    # Place Cells params
    'npcs': 512,  # as used in Sorscher model
    'pc_width': 0.12,
    'DoG': True,
    'seed': 0,  # place-cell center seed
    # Training data (Agent) params
    'batch_size': 200,
    'seq_len': 20,
    'angle0': None,  # random
    'p0': None,      # random
    # Agent/random walk parameters
    'dt': 0.02,
    'turn_angle': 5.76 * 2,
    'b': 0.13 * 2 * np.pi,
    'mu': 0,
}
# Second group is added separately because 'Np' is derived from 'npcs' above.
params.update({
    # Model params
    'Ng': 4096,
    'Np': params['npcs'],  # defined for Brain already
    'weight_decay': 1e-4,
    'lr': 1e-4,  # 1e-3 is default for Adam()
    'nsteps': 100,    # number of mini batches in an epoch
    'nepochs': 2000,  # number of epochs
    # stuff
    'tag': 'default',
    'save_model': True,
    'save_freq': 1,
    'date': datetime.datetime.now(),
})

num_workers = 16
# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"{device=}")

In [None]:
# Init Environment
environment = Rectangle(
    boxsize=params['boxsize'],
    soft_boundary=params['soft_boundary'],
)
# Recorded before params is unpacked below, so it travels with the other settings.
params['environment_name'] = type(environment).__name__

# Init brain
place_cells = PlaceCells(environment=environment, **params)

# Init training data
dataset = Dataset(environment=environment, place_cells=place_cells, **params)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=params['batch_size'],
    num_workers=num_workers,
)

# Init model
model = SorscherRNN(Ng=params['Ng'], Np=params['Np'])
model.to(device)
print(model)

# Init optimizer (use custom weight decay, rather than torch optim decay)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=params['lr'],
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=0.0,
    amsgrad=False,
)

### Train Model

In [None]:
checkpoint_path = '../checkpoints/'
loss_history = []
# Bound up front so the training call below cannot hit a NameError when
# checkpoint loading is switched off.
# NOTE(review): the checkpoint stores a dict under 'training_metrics'; confirm
# that SorscherRNN.train accepts an empty dict for a fresh run.
training_metrics = {}

# whether to restore model, optimizer and histories from the checkpoint file
loaded_model = True
if loaded_model:
    model_name = type(model).__name__
    # map_location remaps stored tensors onto the active device, so a checkpoint
    # written on a GPU machine still loads when only the CPU is available.
    checkpoints = torch.load(
        f'{checkpoint_path}{model_name}_{params["tag"]}',
        map_location=device,
    )
    model.load_state_dict(checkpoints['model_state_dict'])
    optimizer.load_state_dict(checkpoints['optimizer_state_dict'])
    loss_history = checkpoints['loss_history']
    training_metrics = checkpoints['training_metrics']
    print("Loaded weights")

# whether to train
train = False
if train:
    loss_history = model.train(
        trainloader=dataloader,
        optimizer=optimizer,
        weight_decay=params['weight_decay'],
        nepochs=params['nepochs'],
        loaded_model=loaded_model,
        save_model=params['save_model'],
        save_freq=params['save_freq'],
        loss_history=loss_history,
        training_metrics=training_metrics,
        tag=params['tag'],
        params=params,
    )

### Analyse Model

In [None]:
# One panel per training metric stored in the loaded checkpoint.
fig, ax = plt.subplots(ncols=4, figsize=(18, 5))
for col, (name, series) in enumerate(checkpoints['training_metrics'].items()):
    panel = ax[col]
    if name == 'KL':
        # dotted zero line as a reference for the KL panel
        panel.axhline(0, ls=":")
    panel.plot(series)
    panel.set_title(name)

# maximum labelled distribution entropy (uniform labelled distribution)
n = 512
px = np.ones(n) / n  # uniform


def entropy(x):
    """Return -sum(x * log(x)) for the array x."""
    return -np.sum(x * np.log(x))


entropy(px)

In [None]:
# The first five entries of the loss history are left out of the plot.
# NOTE(review): presumably to keep large initial losses from dominating the y-axis — confirm.
plt.plot(loss_history[5:])
plt.gca().set(title='Total-loss history', xlabel='epoch')
plt.ylabel('loss')

### Grid cells

In [None]:
# Name under which rate maps are stored: "<model class>_<tag>"
save_name = '_'.join((type(model).__name__, params['tag']))

# Toggle between recomputing (and pickling) the rate maps or loading stored ones.
save_ratemaps = False
load_ratemaps = True

if save_ratemaps:
    idxs = slice(0, params['Ng'], 1)
    res = np.array([32, 32])
    board, ratemaps, response_maps, count_maps = rate_map(
        model=model.g,
        environment=environment,
        dataset=dataset,
        seq_len=params['seq_len'],
        res=res,
        idxs=idxs,
        num_samples=1,
    )
    # pickle rate_maps
    save_obj(ratemaps, save_name + "/ratemaps")
elif load_ratemaps:
    ratemaps = load_obj(save_name + "/ratemaps")

In [None]:
#fig, ax = multicontourf(*board.T, rate_maps)
# Show a window of num_ratemaps consecutive rate maps beginning at start_idx.
start_idx = 512
num_ratemaps = 256
fig, ax = multiimshow(ratemaps[slice(start_idx, start_idx + num_ratemaps)])

### Predicted place cells

In [None]:
"""
board, rate_maps, _, _ = rate_map(model=model, environment=environment, dataset=dataset, seq_len=params['seq_len'], \
              res=res, idxs=idxs, num_samples=1)
#fig, ax = multicontourf(*board.T, rate_maps)
fig, ax = multiimshow(rate_maps)
"""

### Predicted Place cells - SOFTMAX

In [None]:
"""
forward_with_softmax = lambda x: torch.exp(model(x, log_softmax=True))
board, rate_maps, _, _ = rate_map(model=forward_with_softmax, environment=environment, dataset=dataset, seq_len=params['seq_len'], \
              res=res, idxs=idxs, num_samples=1)
#fig, ax = multicontourf(*board.T, rate_maps)
fig, ax = multiimshow(rate_maps)
"""

### Labelled place cells

In [None]:
"""
board, rate_maps, _, _ = rate_map(model='labels', environment=environment, dataset=dataset, seq_len=params['seq_len'], \
              res=res, idxs=idxs, num_samples=1)
#fig, ax = multicontourf(*board.T, rate_maps)
fig, ax = multiimshow(rate_maps)
"""

In [None]:
# Put the upper half of the unit indices into the prune mask. The bounds come
# from params['Ng'] (the size the model was built with) instead of a literal 4096.
# NOTE(review): presumably prune_mask lists units to silence — confirm in Models.SorscherRNN.
model.prune_mask = list(range(params['Ng'] // 2, params['Ng']))

In [None]:
# Clear the prune mask again (empty list).
model.prune_mask = list()

In [None]:
# Display the current value of the prune mask as the cell output.
model.prune_mask

In [None]:
# Rate maps for indices 0..255 (16**2), shown as an image grid.
# NOTE(review): res is presumably the number of spatial bins per axis — confirm in methods.rate_map.
idxs = slice(0, 16**2, 1)
res = np.array([32, 32])
board, rate_maps, response_maps, count_maps = rate_map(
    model=model.g,
    environment=environment,
    dataset=dataset,
    seq_len=params['seq_len'],
    res=res,
    idxs=idxs,
    num_samples=1,
)
#fig, ax = multicontourf(*board.T, rate_maps)
fig, ax = multiimshow(rate_maps)

### Decoding labels and predictions to cartesian

In [None]:
# Temporarily ask the dataset to return the cartesian positions as a third item.
dataset.return_cartesian = True
try:
    [[vel, init_pos], labels, true_cartesian_pos] = dataset[0]
finally:
    # Always switch the flag back, even if sampling raises; otherwise the
    # dataset (and the dataloader built on it) would stay in cartesian mode
    # for every later cell.
    dataset.return_cartesian = False

# Decode the label sequence, prefixed with the initial position, to positions.
true_decoded_pos = place_cells.to_euclid(torch.cat([init_pos[None], labels]))
# Model prediction for the same sample; [0] takes the first entry of the output.
# NOTE(review): presumably the model adds a leading batch axis — confirm.
pc_preds = model([vel, init_pos]).detach().cpu()[0]
predicted_decoded_pos = place_cells.to_euclid(torch.cat([init_pos[None], pc_preds]))

In [None]:
# Overlay the three versions of the same trajectory in one figure.
for path, legend_label in (
    (true_decoded_pos, 'true_decoded_pos'),
    (true_cartesian_pos, 'true_cartesian_pos'),
    (predicted_decoded_pos, 'predicted_decoded_pos'),
):
    plt.plot(*path.T, label=legend_label)

# Axis limits run from the environment origin to its box size.
lower, upper = environment.origo, environment.boxsize
plt.xlim(lower[0], upper[0])
plt.ylim(lower[1], upper[1])
plt.legend()

### Plot all place cell centers and some with tuning curves

In [None]:
fig, ax = plt.subplots()
x, y = place_cells.pcs.T

# all place-cell centres
ax.plot(x, y, "+")
# highlight the first five centres and draw a circle of radius pc_width around each
for i in range(5):
    centre = (x[i], y[i])
    ax.plot(*centre, "r+")
    ax.add_artist(plt.Circle(centre, params['pc_width'], fill=False, color=(1, 0, 0, 0.5)))

ax.set_title("Spatial plot of place cell locations")
ax.set_xlabel("X")
ax.set_ylabel("Y")
plt.show()

### Calculate grid scores using different implementations of the metric

In [None]:
# get grid cells
idxs=slice(0, 128, 1)
board, rate_maps, response_maps, count_maps = rate_map(model=model.g, environment=environment, dataset=dataset, seq_len=params['seq_len'], \
              res=res, idxs=idxs, num_samples=1)

In [None]:
# Score the same rate map (index 1) with three grid-score implementations.

# Custom grid score
print("CUSTOM:", grid_score(rate_maps[1]))

# CINPLA grid score
import spatial_maps as sm
print("CINPLA:", sm.gridness(rate_maps[1]))

# BANINO (and Sorscher) grid scoring
from scores import GridScorer
# One difference from custom and CINPLA grid scores:
# 1. Uses average difference between phase60 and phase30 correlations
starts = [0.2] * 10
ends = np.linspace(0.4, 1.0, num=10)
# NOTE(review): coord_range is not used anywhere below; the scorer receives the
# centred coords_range instead. Kept only in case other cells refer to it.
coord_range = ((0, environment.boxsize[0]), (0, environment.boxsize[1]))
# Box dimensions come from the shared configuration rather than repeated literals.
box_width, box_height = params['boxsize']
coords_range = ((-box_width/2, box_width/2), (-box_height/2, box_height/2))
# Materialised as a list: a bare zip() is exhausted after one pass, so building
# a second GridScorer from the same name would otherwise get no masks at all.
mask_parameters = list(zip(starts, ends.tolist()))
scorer = GridScorer(nbins=res[0], coords_range=coords_range, mask_parameters=mask_parameters)

#score_60, score_90, max_60_mask, max_90_mask, sac, max_60_ind = zip(
#      *[scorer.get_scores(rm.reshape(res, res)) for rm in tqdm(rate_map_lores)])
score_60, score_90, max_60_mask, max_90_mask, sac = scorer.get_scores(rate_maps[1])
print("BANINO/SORSCHER:", score_60)

In [None]:
# choose grid scoring function to use, e.g: grid_score, sm.gridness or scorer.get_scores
# for scorer.get_scores use: < (lambda rm: scorer.get_scores(rm)[0])(rate_map) >
def grid_scoring_fn(rate_map):
    """Score one rate map with the selected implementation (currently sm.gridness)."""
    return sm.gridness(rate_map)


#map(grid_scoring_fn, *rate_maps)
grid_scoring_fn(rate_maps[1])

### Small analysis / checks / tests etc

In [None]:
# Wr = model.recurrence.weight.detach().cpu().numpy()
# Hidden-to-hidden weights of the RNN as a NumPy array.
Wr = model.RNN.weight_hh_l0.detach().cpu().numpy()


def stats(W):
    """Print min, max, min of |W|, mean, std and sum of squares of W."""
    print(f"{np.min(W)=}, {np.max(W)=}, {np.min(abs(W))=}, {np.mean(W)=}, {np.std(W)=}, {np.sum(W**2)=}")


stats(Wr)
# Only the top-left 25x25 corner is shown.
plt.imshow(Wr[:25, :25])
plt.colorbar()

In [None]:
# Weights of the initial-position encoder as a NumPy array.
Wp = model.init_position_encoder.weight.detach().cpu().numpy()


def stats(W):
    """Print min, max, min of |W|, mean, std and sum of squares of W."""
    print(f"{np.min(W)=}, {np.max(W)=}, {np.min(abs(W))=}, {np.mean(W)=}, {np.std(W)=}, {np.sum(W**2)=}")


stats(Wp)
plt.imshow(Wp)
plt.colorbar()