In [1]:
%reload_ext autoreload
%autoreload 2

# %matplotlib widget

%matplotlib qt

# %gui qt

import time
import random
from copy import copy
from tqdm import tqdm
from pathlib import Path
import numpy as np
import numba as nb
import scipy as sp
from sklearn.decomposition import PCA
import torch
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
pd.options.display.width = 1000

import faststats as fs

from vrAnalysis import session
from vrAnalysis import registration
from vrAnalysis import functions
from vrAnalysis import analysis
from vrAnalysis import helpers
from vrAnalysis import fileManagement as fm
from vrAnalysis import database
from vrAnalysis import tracking

from vrAnalysis.uiDatabase import addEntryGUI
from vrAnalysis.redgui import redCellGUI as rgui

# Handles to the project's 'vrSessions' and 'vrMice' databases.
sessiondb = database.vrDatabase('vrSessions')
mousedb = database.vrDatabase('vrMice')

# Let DataFrame displays show up to 100 rows.
pd.options.display.max_rows = 100

# Run torch work on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

Using device: cuda


In [2]:
from vrAnalysis.simulations.simulator import Simulator

In [75]:
box_length = 80

# Keyword configuration for the Simulator, gathered in one dict so every
# setting can be read (or tweaked) in a single place before construction.
sim_params = dict(
    spacing=1,
    dt=0.1,
    speed_mean=0.12,
    num_place_cells=50,
    num_grid_cells=0,  # NOTE(review): presumably disables grid cells; later cells must not sample from an empty grid library
    place_width_mean=20,
    place_width_std=10.0,
    num_grid_modules=4,
    grid_expansion=1.3,
    base_grid_spacing=39.8,
    base_grid_width=27.4,
    g_noise_amp=1e-5,  # identical float to 1 / 100000
)
sim = Simulator(box_length, **sim_params)

In [33]:
# Show Examples: random place-cell maps (top row) and grid-cell maps (bottom row)
numExamplesEach = 6
figdim = 2
# squeeze=False keeps `ax` a 2-D (rows, columns) array even if numExamplesEach is set to 1.
f, ax = plt.subplots(2, numExamplesEach, figsize=(figdim*numExamplesEach, figdim*2), layout='constrained', squeeze=False)
for n in range(numExamplesEach):
    # np.random.randint(0) raises ValueError, so only sample from a population that actually
    # has cells (e.g. the Simulator may be built with num_grid_cells=0); otherwise the panel stays blank.
    if sim.num_place_cells > 0:
        ax[0, n].imshow(sim.place_library[np.random.randint(sim.num_place_cells)])
    if sim.num_grid_cells > 0:
        ax[1, n].imshow(sim.grid_library[np.random.randint(sim.num_grid_cells)])
    ax[0, n].axis('off')
    ax[1, n].axis('off')

In [81]:
# Histogram of every entry of sim.grid_library (flattened).
# A dedicated figure is created because, with the interactive qt backend, a bare plt.hist
# draws into whatever axes is still current (e.g. the last panel of an earlier cell's figure).
fig_hist, ax_hist = plt.subplots()
ax_hist.hist(np.ravel(sim.grid_library))
ax_hist.set(xlabel='grid_library value', ylabel='count', title='Distribution of grid_library values')
plt.show()

In [78]:
# Simulate a 10000-step trajectory: time, positions, and the index pairs used to
# look up each cell's map at the visited locations.
t, pos, posidx = sim.run_simulation(10000)

# Sample every cell's library map at the visited index pairs, then transpose so rows
# are trajectory samples and columns are cells; place-cell columns come first.
visited = (slice(None), posidx[:, 0], posidx[:, 1])
place_signal = sim.place_library[visited].T
grid_signal = sim.grid_library[visited].T
hippo_signal = np.concatenate((place_signal, grid_signal), axis=1)

# Additive Gaussian noise: each column's noise std is gNoiseAmp times that column's signal std.
# place_activity and hippo_activity get independent draws, so hippo_activity's place-cell
# columns do NOT carry the same noise as place_activity.
gNoiseAmp = 1 / 5
gNoisePlace, gNoiseHippo = (gNoiseAmp * np.std(signal, axis=0) for signal in (place_signal, hippo_signal))
place_activity = place_signal + np.random.normal(0, gNoisePlace, place_signal.shape)
hippo_activity = hippo_signal + np.random.normal(0, gNoiseHippo, hippo_signal.shape)

In [79]:
# Plot Trajectory and plot examples of place cell and grid cell activity overlaid on trajectory
numExamplesEach = 2
# np.random.randint raises ValueError when low >= high, so only sample from populations that
# actually contain cells (e.g. the Simulator may be built with num_grid_cells=0).
pexidx = np.random.randint(0, sim.num_place_cells, numExamplesEach) if sim.num_place_cells > 0 else np.array([], dtype=int)
gexidx = np.random.randint(0, sim.num_grid_cells, numExamplesEach) if sim.num_grid_cells > 0 else np.array([], dtype=int)

figdim = 3
# One trajectory panel plus one panel per example that was actually drawn.
numPanels = 1 + len(pexidx) + len(gexidx)
# squeeze=False keeps `ax` an array of panels even when only the trajectory panel exists.
f, ax = plt.subplots(1, numPanels, figsize=(numPanels * figdim, figdim), squeeze=False)
ax = ax[0]

# Colormaps are passed per call instead of via plt.set_cmap, which would silently change
# the default colormap for every later figure in the session.
ax[0].scatter(pos[:, 0], pos[:, 1], s=15, c=range(pos.shape[0]), alpha=0.3, cmap='cool')
ax[0].set_title('trajectory (color = sample index)')
for n, cellidx in enumerate(pexidx):
    ax[1 + n].scatter(pos[:, 0], pos[:, 1], s=5, c=place_activity[:, cellidx], cmap='jet')
    ax[1 + n].set_title(f'place cell {cellidx}')
# In hippo_activity the grid-cell columns follow the sim.num_place_cells place-cell columns.
for n, cellidx in enumerate(gexidx):
    panel = ax[1 + len(pexidx) + n]
    panel.scatter(pos[:, 0], pos[:, 1], s=5, c=hippo_activity[:, sim.num_place_cells + cellidx], cmap='jet')
    panel.set_title(f'grid cell {cellidx}')