<a href="https://colab.research.google.com/github/divyanshgupt/travelling-wave-mec/blob/main/Raster_Plot_Theta_Cycle_1D.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm

Assumptions:
1. Each module contains the same number of cells
2. Self-similarity (the grid fields become smaller by a constant scaling ratio $s$ from one module to the next)
3. All cells within the same module have the same spatial period (the definition of a module)
 

In [7]:
# Simulation parameters shared by every function in this notebook.
args = dict(
    lam_0=10,               # spatial period of the largest module
    K=50,                   # gain inside the exponent of the von Mises spatial tuning
    n_max=20,               # prefactor (peak value) of the von Mises spatial tuning
    nb_cells=10,            # no. of cells per module
    nb_modules=4,
    s=1.5,                  # scaling ratio
    b_theta=0,              # intercept of the phase precession line
    m_theta=2*np.pi/10,     # rate of theta phase precession
    k_theta=1,              # gain inside the exponent of the theta phase modulation
    x_0=0,                  # origin for phase precession
    animal_velocity=0.2,    # m/sec
    step_size=0.01,         # sec
    nb_steps=1000,
    theta_offset=40,        # degrees; total theta phase offset from the dorsal to the ventral end, defining the travelling wave
    theta_freq=8,           # Hz; frequency of theta oscillations
)

### 1-D Von Mises Function (for Encoding)

$$ \Omega_j(x) = n_{max} e^{\kappa(\cos(\omega (x - c_j)) - 1)}, \qquad \omega = \frac{2\pi}{\lambda} $$

In [8]:
def von_mises(position, lam, phase, args):
  """
  1-D von Mises (periodic bump) spatial tuning of a grid cell.
  inputs:
    position - x-coordinate at which the tuning is evaluated
    lam - the spatial period of the grid cell
    phase - the preferred relative phase of the grid cell (scalar or array)
    args: ['K', 'n_max']
  returns:
    n_max * exp(K * (cos(2*pi/lam * (position - phase)) - 1)), i.e. the
    von Mises value at the given position coordinate
  """
  concentration = args['K']
  peak = args['n_max']

  # angular spatial frequency corresponding to the period `lam`
  angular_freq = 2*np.pi / lam
  displacement = position - phase

  return peak * np.exp(concentration*(np.cos(angular_freq*displacement) - 1))

In [28]:
def poisson_cells(rate, args):
  """
  Draws, for one time-bin, whether each cell emits a spike or not.
  inputs:
    rate - average rate of firing of every cell, shape: (nb_cells)
    args: ['nb_cells', 'step_size']
  returns:
    array of shape (nb_cells) holding 1. for cells that spiked in this
    time-bin and 0. for the others
  """
  nb_cells = args['nb_cells']
  bin_width = args['step_size']

  draws = np.random.uniform(size=(nb_cells))
  spikes = np.zeros(nb_cells)

  # A cell fires when its uniform draw does not exceed rate*dt,
  # the expected spike count of the bin.
  fired = (draws - (rate*bin_width)) <= 0 # shape: nb_cells
  spikes[fired] = 1

  return spikes

In [11]:
def grid_module(lam, args):
  """
  distributes the preferred relative phases of the cells of one module
  evenly over one spatial period
  inputs:
    lam - spatial period of the module
    args: ['nb_cells']
  returns:
    array of shape (nb_cells) with the x-coordinates 0, lam/nb_cells, ...,
    lam*(nb_cells-1)/nb_cells denoting the relative phases of subsequent cells
  """

  nb_cells = args['nb_cells']

  # linspace (endpoint excluded) always yields exactly nb_cells samples;
  # np.arange with a non-integer step can return one element too many
  # because of floating-point round-off.
  phases = np.linspace(0, lam, nb_cells, endpoint=False)

  return phases

In [12]:
def multiple_modules(args):
  """
  generates the preferred phases for multiple modules based on given parameters
  inputs:
    args - ['nb_modules', 'lam_0', 's'] (plus whatever grid_module reads)
  returns:
    phases - list of numpy arrays containing phases for individual modules;
             module k has spatial period lam_0 / s**(nb_modules-k-1), so the
             last module has period lam_0
  """

  module_count = args['nb_modules']
  largest_period = args['lam_0']
  ratio = args['s']

  return [
      grid_module(largest_period/(ratio**(module_count-k-1)), args)
      for k in range(module_count)
  ]

### Position Theta Phase Model
(Adapted from McClain et al., 2019)
**Spatial Input Equation**: (Gaussian to model place cell field)
$$ f = e^{A_x}e^{\frac{-(x-x_0)^2}{2\sigma^2_x}}  $$

**Phase modulation equation**:
$$ g(\theta,x) = e^{k_\theta(cos(\theta - \theta_0(x)) - 1)} $$

**Phase precession equation**:
$$ \theta_0(x) = b_\theta + m_\theta(x - x_0) $$

$ b_\theta = 0 $: preferred phase at the center of the place field

$ m_\theta = 1 $: rate of phase precession

**Rate** (modelled as a product of the spatial input and phase modulation equations):
$$ r(x, \theta) = f(x)\bullet g(\theta, x) $$

_______________________

**Spatial Input Equation** (von Mises tuning for grid fields):
$$ \Omega_j(\vec{x}) =  n_{max}e^{\frac{\kappa}{3} \Sigma_{l=1}^3(cos(\omega \vec{k_l}\bullet (\vec{x} - \vec{c_j})) - 1)}$$

**Phase modulation equation**:
$$ g(\theta, \vec{x}) = e^{k_\theta(cos(\theta - \theta_0(\vec{x}))-1)} $$

**Phase precession (2-D)**:
$$ \theta_0(\vec{x}) = b_\theta + m_\theta(|\vec{x} - \vec{x_0}|)  $$


**Rate**:
$$ r_j(\vec{x}, \theta) = \Omega_j(\vec{x})\bullet g(\theta, \vec{x})$$


------------------------------------



In [13]:
def phase_precession(position, args):
  """
  Preferred theta phase as a linear function of position.
  inputs:
    position - x-coordinate (in absolute frame)
    args: ['b_theta', 'm_theta', 'x_0'] - intercept, slope and origin
          x-coordinate of the precession line
  returns:
    theta_0 = b_theta + m_theta*(position - x_0)
  """
  intercept = args['b_theta']
  slope = args['m_theta']
  origin = args['x_0']

  return intercept + slope*(position - origin)

In [14]:
def phase_modulation(theta, theta_0, position, args):
  """
  Theta phase modulation factor
      g = exp(k_theta * (cos(theta - theta_0) - 1)),
  which equals 1 at the preferred phase and exp(-2*k_theta) opposite to it.
  inputs:
    theta - current theta phase (radians), scalar or array
    theta_0 - preferred theta phase at the current position; pass None to
              have it computed from `position` with phase_precession
    position - x-coordinate, only used when theta_0 is None
    args: ['k_theta'] (plus the phase_precession keys when theta_0 is None)
  returns:
    phase modulation factor value, same shape as theta
  """
  k_theta = args['k_theta']

  # Honour the value supplied by the caller instead of silently overwriting it.
  if theta_0 is None:
    theta_0 = phase_precession(position, args)

  # The "- 1" matches the documented model equation and keeps the factor <= 1.
  modulation_factor = np.exp(k_theta * (np.cos(theta - theta_0) - 1))

  return modulation_factor

In [20]:
def theta_phase(args):
  """
  Initial theta phase of every cell, increasing linearly across all cells
  of all modules from 0 up to (but excluding) the total offset.
  inputs:
    args: ['theta_offset' (degrees), 'nb_cells', 'nb_modules']
  returns:
    theta phase value (radians) across all cells,
    shape: (nb_cells*nb_modules)
  """
  offset = (np.pi/180)*args['theta_offset']
  nb_cells = args['nb_cells']
  nb_modules = args['nb_modules']

  # linspace guarantees exactly nb_cells*nb_modules samples and also works
  # for theta_offset == 0, where np.arange would be handed a zero step.
  theta = np.linspace(0, offset, nb_cells*nb_modules, endpoint=False)

  return theta

In [37]:
def simulate(args):
  """
  Runs the whole simulation and returns grid cell activity over time as the animal moves in space

  inputs:
    args: ['nb_steps', 'step_size', 'animal_velocity', 'nb_cells',
           'nb_modules', 'theta_freq', 'lam_0', 's'] plus the keys read by
          the helper functions called below
  returns:
    activity - array of shape (nb_cells, nb_steps, nb_modules) holding 1.
               where a cell spiked in a time-bin and 0. elsewhere
  """

  nb_steps = args['nb_steps']
  dt = args['step_size']
  v = args['animal_velocity'] # assuming constant velocity over the duration
  nb_cells = args['nb_cells']
  nb_modules = args['nb_modules']
  freq = args['theta_freq']
  lam_0 = args['lam_0'] # spatial period of the largest module
  s = args['s']
  omega = 2*np.pi*freq

  animal_position = 0 # starts at origin

  module_spatial_phases = multiple_modules(args) # initializes the preferred phases
  theta_phases = theta_phase(args) # initializes the theta phases linearly for all cells

  # Spatial period of every module. The exponent must be the one used in
  # multiple_modules, otherwise the preferred phases (spread over one period)
  # and the tuning period disagree; the last module has period lam_0.
  module_periods = [lam_0 / (s**(nb_modules-i-1)) for i in range(nb_modules)]

  activity = np.zeros((nb_cells, nb_steps, nb_modules))

  for t in tqdm(range(nb_steps)):

    animal_position += dt*v

    theta_phases += omega*dt

    # depends only on the position, hence identical for every module at this step
    theta_0 = phase_precession(animal_position, args) # scalar

    for i in range(nb_modules): # simulate for modules going from dorsal to ventral

      theta = theta_phases[i*nb_cells:(i+1)*nb_cells] # shape: (nb_cells)

      # modulation factors for all cells in the module, shape: (nb_cells)
      modulation_factors = phase_modulation(theta, theta_0, animal_position, args)
      spatial_input = von_mises(animal_position, module_periods[i], module_spatial_phases[i], args) # shape: (nb_cells)

      rates = modulation_factors*spatial_input # shape: (nb_cells,)
      activity[:, t, i] = poisson_cells(rates, args)

  return activity

In [38]:
# Run the full simulation with the parameters collected in `args`.
activity = simulate(args)

100%|██████████| 1000/1000 [00:00<00:00, 5557.61it/s]


In [39]:
# Sanity check: simulate allocates activity as (nb_cells, nb_steps, nb_modules).
print(activity.shape)

(10, 1000, 4)


In [40]:
# Spike train over all time steps of the cell at index 2 in the module at index 1.
print(activity[2 ,:, 1])

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.

In [5]:
# Scratch check: indexing a (3, 1, 5) array with [:, 0, k] gives a 1-D array of shape (3,).
a = np.array([
    3, 3, 3, 3, 2123,
    12, 1243, 23132542, 3452, 34523,
    4523, 4523, 4523, 45234, 5,
])
b = a.reshape(3, 1, 5)
print(b)
c = b[:, 0, 1]
print(c.shape)

[[[       3        3        3        3     2123]]

 [[      12     1243 23132542     3452    34523]]

 [[    4523     4523     4523    45234        5]]]
(3,)


In [None]:
def raster_plot_module(activity, args, module=0):
  """
  plots a raster plot for cells in a module
  inputs:
    activity - spike array with 1 where a cell spiked; either of shape
               (nb_cells, nb_steps, nb_modules) as returned by simulate,
               or of shape (nb_cells, nb_steps) for a single module
    args: ['step_size'] - duration of one time-bin, used for the time axis
    module - index of the module to plot when activity is 3-D (default: 0)
  returns:
    None (the figure is displayed with plt.show())
  raises:
    ValueError if activity is neither 2-D nor 3-D
  """
  activity = np.asarray(activity)
  if activity.ndim == 3:
    activity = activity[:, :, module]
  elif activity.ndim != 2:
    raise ValueError("activity must have shape (nb_cells, nb_steps) or (nb_cells, nb_steps, nb_modules)")

  dt = args['step_size']

  # eventplot expects one sequence of event positions per row: convert the
  # indices of the non-zero bins of every cell into spike times.
  spike_times = [np.flatnonzero(cell_train) * dt for cell_train in activity]

  fig, ax = plt.subplots()
  ax.eventplot(spike_times, lineoffsets=np.arange(len(spike_times)), linelengths=0.8)
  ax.set_xlabel('time (s)')
  ax.set_ylabel('cell index')
  ax.set_title('Raster plot')
  plt.show()