In [None]:
import sys
dir1 = '/Users/erihe/OneDrive - NTNU/'
if not dir1 in sys.path: sys.path.append(dir1)
from utils_new import *

In [None]:
data_dir = '/Users/erihe/OneDrive - NTNU/Prosjekt/Main'
# Intended to plot the sheet of activity on every livePlot'th step when > 0.
# NOTE(review): live plotting is not implemented in this cell, so livePlot is currently unused.
livePlot = 100

# if =0, move with constant velocity for simdur ms. if =1, load trajectory from disk
useRealTrajectory = 1
constantVelocity = 1*[.0005, 0*0.0005] # per-ms displacement, same units as `vels` below

# Seed for the random initial activity; None leaves the global NumPy RNG untouched.
rng_seed = None

## Network/Weight matrix parameters
Nx = 10 # number of cells in x direction
Ny = 10 # number of cells in y direction
ncells = Nx*Ny # total number of cells in network
# grid spacing is approx 1.02 - 0.48*log2(alpha), pg 236
alpha = 30 # input gain, unitless
beta = 0 # input direction bias (i.e. grid orientation), rad
sigma = 0.24 # exponential weight std. deviation
I = 0.3 # peak synaptic strength
T = 0.05 # shift so tail of exponential weights turn inhibitory
tau = 0.8 # relative weight of normalized vs. full-strength synaptic inputs

## Simulation parameters
dt = 20 # time step, ms
simdur = 5*59000 # total simulation time, ms (used in constant-velocity mode)
stabilizationTime = 80 # no-velocity time for pattern to form, ms
tind = -1 # time step number for indexing
t = 0 # simulation time variable, ms
if rng_seed is not None:
    np.random.seed(rng_seed)
A = np.random.rand(ncells)/np.sqrt(ncells) # activation of each cell
R = np.linalg.inv(np.array([[np.cos(0), np.cos(np.pi/3)], [np.sin(0), np.sin(np.pi/3)]]))

# Cell centres on the unit square: (i + 0.5)/N for 0-based i. The previous "- 0.5" shifted every
# centre by one cell (harmless for the pairwise differences below, but wrong as coordinates).
x = (np.arange(Nx) + 0.5)/Nx
y = (np.arange(Ny) + 0.5)/Ny
X, Y = np.meshgrid(x, y)
x = np.vstack((X.ravel(), Y.ravel()))   # shape (2, ncells): row 0 = x, row 1 = y

[jx, ix] = np.meshgrid(x[0, :], x[0, :])
[jy, iy] = np.meshgrid(x[1, :], x[1, :])
jx = jx.ravel()
ix = ix.ravel()
jy = jy.ravel()
iy = iy.ravel()
# Pairwise coordinate differences do not depend on velocity: compute them once, not every step.
dx_pairs = ix - jx
dy_pairs = iy - jy

if useRealTrajectory:
    pos = sio.loadmat(data_dir + '/HaftingTraj_centimeters_seconds.mat')['pos']
    pos[2, :] *= 1e3  # seconds -> ms
    # resample x, y and time onto a regular dt grid
    t_grid = np.arange(0, pos[2, -1], dt)
    pos = np.vstack([np.interp(t_grid, pos[2, :], pos[row, :]) for row in range(3)])
    pos[:2, :] /= 100  # cm -> m
else:
    # Constant-velocity mode: advertised by useRealTrajectory = 0 but previously not implemented
    # (pos/vels were undefined and the simulation crashed).
    n_steps = int(simdur // dt)
    vel_const = np.asarray(constantVelocity, dtype=float)
    xy = np.vstack((np.zeros((1, 2)),
                    np.cumsum(np.tile(vel_const*dt, (n_steps, 1)), axis=0))).T
    pos = np.vstack((xy, np.arange(n_steps + 1)*dt))
# displacement per ms between consecutive samples, shape (n_samples - 1, 2)
vels = np.diff(pos[:2, :], axis=1).T/dt


def _recurrent_weights(v):
    """Return the velocity-shifted weight matrix for velocity v = (vx, vy).

    Each pairwise offset is shifted by alpha*v and compared with its 9 periodic images on the
    unit square; the smallest squared distance d2 gives the weight I*exp(-d2/sigma**2) - T.
    """
    ex = dx_pairs + alpha*v[0]
    ey = dy_pairs + alpha*v[1]
    # the min over the 3x3 image shifts separates into independent minima along x and y
    d2 = (np.min([np.square(ex + s) for s in (-1, 0, 1)], axis=0)
          + np.min([np.square(ey + s) for s in (-1, 0, 1)], axis=0))
    return I*np.exp(-d2.reshape(ncells, ncells).T/sigma**2) - T


def _network_step(A, W):
    """Advance activity A by one time step with weights W; returns a new, rectified array."""
    B = np.matmul(A, W.T)
    total = np.sum(A)
    # mix full-strength and normalised input; a silent network (total == 0) stays silent, not NaN
    normalised = B/total if total != 0 else np.zeros_like(B)
    A = (1-tau)*B + tau*normalised
    A[A < 0] = 0
    return A


# Let the pattern form at zero velocity before following the trajectory
# (stabilizationTime was previously defined but never used).
W = _recurrent_weights((0.0, 0.0))
for stab_step in range(int(stabilizationTime // dt)):
    A = _network_step(A, W)

## Simulation
# Row k of `spikes` is the activity after integrating the movement up to pos[:, k], so every row is
# aligned with the position it is later binned against and no row is left as all-zero padding.
spikes = np.zeros((len(pos[0, :]), ncells))
spikes[0, :] = A
for tind in range(len(pos[0, :]) - 1):
    v = vels[tind, :]
    W = _recurrent_weights(v)
    A = _network_step(A, W)
    # Save firing field information
    spikes[tind + 1, :] = A

In [None]:
# One figure per cell (first three cells): trajectory coloured by that cell's activity.
traj_x, traj_y = pos[0, :tind], pos[1, :tind]
for j in range(3):
    plt.figure()
    ax = plt.axes()
    ax.scatter(traj_x, traj_y, c=spikes[:tind, j], s=10)
    ax.set_aspect(1/ax.get_data_ratio())

In [None]:
# --- settings shared by the PCA / persistent-cohomology cells ---
dim = 8            # number of principal components
k = 1000           # number of neighbours used by sample_denoising
maxdim = 1         # highest homology dimension plotted (H0 and H1)
metric = 'cosine'  # point-cloud metric for persistence
eps = 0.2          # radius for radial downsampling
n_points = 1000    # points kept for the persistence computation

spk1 = preprocessing.scale(spikes, axis=0)   # standardise each column (cell) of spikes
dim_red_spikes_move_scaled, e1, e2, var_exp = pca(spk1, dim=dim)

# explained-variance curve of the leading components
fig, ax = plt.subplots(1, 1)
ax.plot(var_exp[:15])
ax.set_aspect(1/ax.get_data_ratio())

# spatial map of every component; colour limits at the 5th / 95th position of its unique values
fig, axs = plt.subplots(1, dim, figsize=(10, 5), dpi=120)
for c in range(dim):
    mtot, x_edges, y_edges, circ = binned_statistic_2d(
        pos[0, :], pos[1, :], dim_red_spikes_move_scaled[:, c],
        statistic='mean', bins=30, expand_binnumbers=True)
    empty = np.isnan(mtot)
    mtot[empty] = np.mean(mtot[~empty])
    mtot = gaussian_filter(mtot, 1)
    plt.viridis()
    levels = np.unique(mtot)
    lo = levels[int(0.05*len(levels))]
    hi = levels[int(0.95*len(levels))]
    mtot[empty] = np.nan
    panel = axs[c]
    panel.imshow(mtot, vmin=lo, vmax=hi)
    panel.axis('off')
    panel.set_aspect(1/panel.get_data_ratio())



In [None]:
ph_coeff = 47  # prime field coefficient for ripser and for the circular coordinates

startindex = np.argmax(np.sum(np.abs(dim_red_spikes_move_scaled), 1))
movetimes1 = radial_downsampling(dim_red_spikes_move_scaled, epsilon=eps,
                                 startindex=startindex)

indstemp = sample_denoising(dim_red_spikes_move_scaled[movetimes1, :], k,
                            n_points, 1, metric)[0]
# Map back to time indices and keep at most n_points. The downsampled cloud gets its own name:
# overwriting dim_red_spikes_move_scaled made re-running this cell downsample an already
# downsampled array.
indstemp = movetimes1[indstemp][:n_points]
dim_red_ds = dim_red_spikes_move_scaled[indstemp, :]

d = squareform(pdist(dim_red_ds, metric))
thresh = np.max(d[~np.isinf(d)])
# use the configured maxdim so the computation matches the diagrams plotted below
persistence = ripser(d, maxdim=maxdim, coeff=ph_coeff, do_cocycles=True,
                     distance_matrix=True, thresh=thresh)
dgms = persistence['dgms']
plt.figure()
plot_diagrams(dgms, list(np.arange(maxdim+1)), lifetime=True)
plt.show()
plot_barcode(dgms)

coords_ds, coords_ds_consistent = get_coords_consistent(persistence, coeff=ph_coeff,
                                                        ph_classes=[0, 1,], bConsistent=True)
# sorted coordinate profiles and their 2-D scatter, for the raw and the consistent coordinates
for coords in (coords_ds, coords_ds_consistent):
    fig, ax = plt.subplots(1, 3, figsize=(10, 5), dpi=120)
    for i in range(len(coords)):
        ax[i].plot(coords[i, np.argsort(coords[i, :])])
    ax[2].scatter(*coords[:2, :], s=100)
    for i in range(3):
        ax[i].set_aspect(1/ax[i].get_data_ratio())

In [None]:
plt.viridis()

numbins = 12
coords1 = get_coords_all(spikes, coords_ds_consistent[:, :],
                         np.arange(len(spikes)),
                         indstemp, bPCA=False)
mcstemp, mtot_all = get_ratemaps_center(coords1, spikes, numbins=numbins,)
num_neurons = len(spikes[0, :])
spk_sim_hex = get_sim(coords1, mcstemp, numbins=numbins, simtype='hex', t=0.2)
spk_sim_sqr = get_sim(coords1, mcstemp, numbins=numbins, simtype='sqr', t=0.2)

# Four distinct random neurons. The neuron index is part of each file name; previously every
# iteration overwrote the same three files, so only the last neuron's maps were kept.
for i in np.random.choice(num_neurons, size=min(4, num_neurons), replace=False):
    print(i)
    for values, kind in ((spikes[:, i], 'simulated'),
                         (spk_sim_sqr[:, i], 'sqr'),
                         (spk_sim_hex[:, i], 'hex')):
        mtot = binned_statistic_2d(pos[0, :], pos[1, :], values, bins=25, statistic='mean')[0]
        fig, ax = plt.subplots(1, 1)
        ax.imshow(mtot, vmin=0, vmax=mtot[~np.isnan(mtot)].max())
        ax.axis('off')
        fig.tight_layout()
        fig.savefig(f'ratemap_{kind}_square_{i}.png', transparent=True, pad_inches=0.1)
    

In [None]:
    
# Stack every neuron's ratemap around its centre on the torus; save as PNG and PDF.
mcstemp, mtot_all = get_ratemaps_center(coords1, spikes, numbins=numbins,)
fig, ax = plt.subplots(1, 1)
plot_centered_ratemaps(coords1, spikes, mcstemp, numbins, ax)
for fname in ('stacked_ratemap_square.png', 'stacked_ratemap_square.pdf'):
    fig.savefig(fname, transparent=True, pad_inches=0.1)


In [None]:


# Colour every time bin by its smoothed circular coordinates, looked up in a 2-D colour image.
im = plt.imread('image044.png')
a1 = np.rot90(np.array(im), 1)
sig = 1

# smooth the angles over time via sin/cos, wrapped back to [0, 2*pi)
cc = np.arctan2(gaussian_filter1d(np.sin(coords1), sigma=sig, axis=0),
                gaussian_filter1d(np.cos(coords1), sigma=sig, axis=0)) % (2*np.pi)
bCos = True
if bCos:
    # Small padding so cos == +/-1 lands inside the outer bins. Named bin_pad (not eps) so it no
    # longer overwrites the radial-downsampling eps used by the persistence cell.
    bin_pad = 0.0001
    edges = np.linspace(-1-bin_pad, 1+bin_pad, len(a1)+1)
    digitized = np.stack((np.digitize(np.cos(cc[:, 0]), edges),
                          np.digitize(np.cos(cc[:, 1]), edges)), axis=1)
else:
    edges = np.linspace(0, 2*np.pi, len(a1)+1)
    digitized = np.stack((np.digitize(cc[:, 0], edges),
                          np.digitize(cc[:, 1], edges)), axis=1)
# vectorised colour lookup: image row from the 2nd coordinate, column from the 1st
cc1 = a1[digitized[:, 1]-1, digitized[:, 0]-1]

fig = plt.figure(figsize=(10, 5), dpi=200)
plt.axis('off')
plt.hsv()
ax1 = fig.add_subplot(111)
ax1.axis('off')
ax1.scatter(pos[0, :], pos[1, :], s=50, c=cc1, alpha=0.7)
ax1.set_aspect(1/ax1.get_data_ratio())

plt.savefig('square_OF_2dcoords', transparent=True, pad_inches=0.1)
plt.show()




In [None]:
plt.viridis()
pop_vector = spikes[150, :].reshape(10, 10)  # time step 150 shown as a 10x10 sheet
plt.imshow(pop_vector)
plt.axis('off')
for fname in ('Pop_vect_square.png', 'Pop_vect_square.pdf'):
    plt.savefig(fname, transparent=True, pad_inches=0.1)


### Guanella hexagonal

In [None]:
### Guanella 

%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import scipy.signal as signal
import scipy.ndimage.filters as filt
# Square arena [0, arena_size] x [0, arena_size].
arena_size = 50
arenaX = [0, arena_size]
arenaY = [0, arena_size]

## Initial position: the walk starts in the centre; random_navigation appends to these lists.
Txx, Tyy = [arenaX[1]/2], [arenaY[1]/2]

def conv(ang):
    """Return the (cos, sin) components of the unit vector at angle `ang` given in degrees."""
    rad = np.deg2rad(ang)
    return np.cos(rad), np.sin(rad)

def random_navigation(length, xs=None, ys=None, arena=None, plot=True):
    """Extend a random walk that turns back from the walls of a square arena.

    Each step moves one unit along the current heading `theta` (degrees) plus uniform noise
    in [-0.5, 0.5) on each axis. When the walker is within 2 units of a wall, a new heading
    pointing back into the arena is drawn.

    Parameters
    ----------
    length : int
        Number of steps to append.
    xs, ys : list, optional
        Position lists that are extended in place (default: the module-level Txx, Tyy).
    arena : float, optional
        Side length of the arena (default: the module-level arena_size).
    plot : bool
        If True, plot the whole trajectory when finished.

    Returns
    -------
    list of [heading, cos(heading), sin(heading), path_length]
        One entry per straight segment, recorded when the heading changes and at the last
        step. (This list used to be built and then discarded.)
    """
    if xs is None:
        xs = Txx
    if ys is None:
        ys = Tyy
    if arena is None:
        arena = arena_size

    thetaList = []
    theta = 90
    counter = 0  # path length walked since the last recorded segment
    for i in range(length):
        prevTheta = theta  # ints are immutable; no copy needed

        if xs[-1] < 2: theta = np.random.randint(-85, 85)
        if xs[-1] > arena-2: theta = np.random.randint(95, 260)
        if ys[-1] < 2: theta = np.random.randint(10, 170)
        if ys[-1] > arena-2: theta = np.random.randint(190, 350)

        rad = np.radians(theta)
        xs.append(xs[-1] + np.cos(rad) + np.random.uniform(-0.5, 0.5))
        ys.append(ys[-1] + np.sin(rad) + np.random.uniform(-0.5, 0.5))

        counter += np.sqrt((xs[-1] - xs[-2])**2 + (ys[-1] - ys[-2])**2)

        if theta != prevTheta or i == length-1:
            prev_rad = np.radians(prevTheta)
            thetaList.append([prevTheta, np.cos(prev_rad), np.sin(prev_rad), counter])
            counter = 0

    if plot:
        plt.plot(xs, ys, '-')
        plt.show()
    return thetaList


random_navigation(5000)

# Freeze the generated trajectory into arrays for the network simulation below.
Txx, Tyy = np.array(Txx), np.array(Tyy)

class Grid():
    """Continuous-attractor sheet(s) of mm x nn cells with velocity-shifted recurrent weights.

    One independent layer is kept per entry of `gain`. The sheet is wrapped using the six
    hexagonal neighbour offsets in buildTopology, with the y axis scaled by sqrt(3)/2.
    """

    def __init__(self, gain):
        """Create the layers with random initial activity.

        gain: sequence of velocity gains, one per layer.
        """
        self.mm = 20
        self.nn = 20
        self.TAO = 0.9          # weight of the normalised term in update()
        self.II = 0.3           # peak synaptic strength
        self.SIGMA = 0.24       # width of the Gaussian weight profile
        self.SIGMA2 = self.SIGMA**2
        self.TT = 0.05          # offset subtracted so that distant weights are inhibitory
        self.grid_gain = gain
        self.grid_layers = len(self.grid_gain)
        self.grid_activity = np.random.uniform(0, 1, (self.mm, self.nn, self.grid_layers))
        self.distTri = self.buildTopology(self.mm, self.nn)

    def update(self, speedVector):
        """Advance every layer by one time step.

        speedVector: displacement for this step as a complex number (dx + 1j*dy).
        Afterwards each layer is rescaled to [0, 30]. A layer whose activity has collapsed
        (zero mean, or all values equal) is handled without dividing by zero; previously
        this produced NaNs that then propagated through the rest of the simulation.
        """
        self.speedVector = speedVector

        for jj in range(self.grid_layers):
            rrr = self.grid_gain[jj]*np.exp(1j*0)
            matWeights = self.updateWeight(self.distTri, rrr)
            activityVect = self.Bfunc(np.ravel(self.grid_activity[:, :, jj]), matWeights)
            activityTemp = activityVect.reshape(self.mm, self.nn)
            mean_act = np.mean(activityTemp)
            if mean_act != 0:
                activityTemp += self.TAO*(activityTemp/mean_act - activityTemp)
            activityTemp[activityTemp < 0] = 0

            span = np.max(activityTemp) - np.min(activityTemp)
            if span > 0:
                self.grid_activity[:, :, jj] = (activityTemp - np.min(activityTemp))/span * 30  ##Eq 2
            else:
                # constant activity: nothing to rescale
                self.grid_activity[:, :, jj] = 0

    def buildTopology(self, mm, nn):  # Build connectivity matrix     ### Eq 4
        """Return the (mm*nn, mm*nn) matrix of wrapped complex displacements between cells.

        For every pair, the displacement is replaced by whichever shifted copy (by one of the
        offsets in Sdist) has the smallest magnitude.
        """
        # Constant offsets in the coordinates cancel in the pairwise differences used below.
        mmm = (np.arange(mm)+(0.5/mm))/mm
        nnn = ((np.arange(nn)+(0.5/nn))/nn)*np.sqrt(3)/2
        xx, yy = np.meshgrid(mmm, nnn)
        posv = xx+1j * yy
        Sdist = [0+1j*0, -0.5+1j*np.sqrt(3)/2, -0.5+1j*(-np.sqrt(3)/2), 0.5+1j*np.sqrt(3)/2,
                 0.5+1j*(-np.sqrt(3)/2), -1+1j*0, 1+1j*0]
        xx, yy = np.meshgrid(np.ravel(posv), np.ravel(posv))
        distmat = xx-yy
        for shift in Sdist:
            aaa1 = abs(distmat)
            rrr = xx-yy + shift
            aaa2 = abs(rrr)
            iii = np.where(aaa2 < aaa1)
            distmat[iii] = rrr[iii]
        return distmat.transpose()

    def updateWeight(self, topology, rrr):  # Slight update on weights based on speed vector.
        """Return II*exp(-|topology - rrr*speedVector|^2 / SIGMA^2) - TT (reads self.speedVector)."""
        matWeights = self.II * np.exp((-abs(topology-rrr*self.speedVector)**2)/self.SIGMA2) - self.TT   ## Eq 3
        return matWeights

    def Bfunc(self, activity, matWeights):  ## Eq 1
        """Return activity + activity @ matWeights as a new array.

        The input is no longer modified in place; with a single layer, np.ravel returns a view
        of grid_activity, so the old in-place update silently aliased the stored state.
        """
        return activity + np.dot(activity, matWeights)
# One module at gain 0.06; list several gains to simulate several modules with different scales.
scale = [0.06,]
grid = Grid(scale)

log_grid_cells = []
for step in range(1, Txx.size):
    # displacement since the previous sample, encoded as a complex number
    speedVector = (Txx[step]-Txx[step-1]) + 1j*(Tyy[step]-Tyy[step-1])
    grid.update(speedVector)
    log_grid_cells.append(grid.grid_activity.flatten())
log_grid_cells = np.array(log_grid_cells)

# positions matching each row of log_grid_cells
xx = np.copy(Txx[1:])
yy = np.copy(Tyy[1:])

dv_levels = 10   # number of example cells to plot
dv_start = 25    # index of the first example cell
for cell_num in np.arange(dv_levels):
    celula = log_grid_cells[:, dv_start + cell_num]
    # samples where this cell exceeds 90 % of its peak activity
    pos_spike_idx = np.flatnonzero(celula > celula.max()*.9)

    plt.figure()
    plt.scatter(xx, yy)
    plt.scatter(xx[pos_spike_idx], yy[pos_spike_idx], c='r')

In [None]:
# --- settings shared by the PCA / persistent-cohomology cells ---
dim = 6            # number of principal components
k = 1000           # number of neighbours used by sample_denoising
maxdim = 1         # highest homology dimension plotted (H0 and H1)
metric = 'cosine'  # point-cloud metric for persistence
eps = 0.2          # radius for radial downsampling
n_points = 1000    # points kept for the persistence computation

spk1 = preprocessing.scale(log_grid_cells, axis=0)   # standardise each column (cell)
dim_red_spikes_move_scaled, e1, e2, var_exp = pca(spk1, dim=dim)

# explained-variance curve of the leading components
fig, ax = plt.subplots(1, 1)
ax.plot(var_exp[:15])
ax.set_aspect(1/ax.get_data_ratio())

# spatial map of every component; colour limits at the 5th / 95th position of its unique values
fig, axs = plt.subplots(1, dim, figsize=(10, 5), dpi=120)
for c in range(dim):
    mtot, x_edges, y_edges, circ = binned_statistic_2d(
        xx, yy, dim_red_spikes_move_scaled[:, c],
        statistic='mean', bins=30, expand_binnumbers=True)
    empty = np.isnan(mtot)
    mtot[empty] = np.mean(mtot[~empty])
    mtot = gaussian_filter(mtot, 1)
    plt.viridis()
    levels = np.unique(mtot)
    lo = levels[int(0.05*len(levels))]
    hi = levels[int(0.95*len(levels))]
    mtot[empty] = np.nan
    panel = axs[c]
    panel.imshow(mtot, vmin=lo, vmax=hi)
    panel.axis('off')
    panel.set_aspect(1/panel.get_data_ratio())



In [None]:
ph_coeff = 47  # prime field coefficient for ripser and for the circular coordinates

startindex = np.argmax(np.sum(np.abs(dim_red_spikes_move_scaled), 1))
movetimes1 = radial_downsampling(dim_red_spikes_move_scaled, epsilon=eps,
                                 startindex=startindex)

indstemp = sample_denoising(dim_red_spikes_move_scaled[movetimes1, :], k,
                            n_points, 1, metric)[0]
# Map back to time indices and keep at most n_points. The downsampled cloud gets its own name:
# overwriting dim_red_spikes_move_scaled made re-running this cell downsample an already
# downsampled array.
indstemp = movetimes1[indstemp][:n_points]
dim_red_ds = dim_red_spikes_move_scaled[indstemp, :]

d = squareform(pdist(dim_red_ds, metric))
thresh = np.max(d[~np.isinf(d)])
# use the configured maxdim so the computation matches the diagrams plotted below
persistence = ripser(d, maxdim=maxdim, coeff=ph_coeff, do_cocycles=True,
                     distance_matrix=True, thresh=thresh)
dgms = persistence['dgms']
plt.figure()
plot_diagrams(dgms, list(np.arange(maxdim+1)), lifetime=True)
plt.show()
plot_barcode(dgms)

coords_ds, coords_ds_consistent = get_coords_consistent(persistence, coeff=ph_coeff,
                                                        ph_classes=[0, 1,], bConsistent=True)
# sorted coordinate profiles and their 2-D scatter, for the raw and the consistent coordinates
for coords in (coords_ds, coords_ds_consistent):
    fig, ax = plt.subplots(1, 3, figsize=(10, 5), dpi=120)
    for i in range(len(coords)):
        ax[i].plot(coords[i, np.argsort(coords[i, :])])
    ax[2].scatter(*coords[:2, :], s=100)
    for i in range(3):
        ax[i].set_aspect(1/ax[i].get_data_ratio())

In [None]:
# Circular coordinates for every time sample; get_coords_all presumably extends the coordinates
# found for the downsampled points `indstemp` to all samples (bPCA=False).
coords1 = get_coords_all(
    log_grid_cells,
    coords_ds_consistent[:, :],
    np.arange(log_grid_cells.shape[0]),
    indstemp,
    bPCA=False,
)


In [None]:
# Spatial map of each circular coordinate: circular mean per bin, smoothed on the circle.
for c in (0, 1):
    fig, axs = plt.subplots(1, 1)
    mtot, x_edges, y_edges, circ = binned_statistic_2d(
        xx[:], yy[:], coords1[:, c],
        statistic=circmean, bins=50, expand_binnumbers=True)

    nans = np.isnan(mtot)
    # fill empty bins and smooth sin and cos separately, then recombine into an angle
    smoothed = []
    for comp in (np.sin(mtot), np.cos(mtot)):
        comp[nans] = np.mean(comp[~nans])
        smoothed.append(gaussian_filter(comp, 1))
    sintot, costot = smoothed
    mtot = gaussian_filter(np.cos(np.arctan2(sintot, costot)), 1)
    plt.viridis()
    mtot[nans] = np.nan
    axs.imshow(mtot)
    axs.axis('off')
    axs.set_aspect(1/axs.get_data_ratio())


In [None]:
plt.viridis()

# Defined here rather than inherited from the square-model section.
numbins = 12
mcstemp, mtot_all = get_ratemaps_center(coords1, log_grid_cells, numbins=numbins,)
num_neurons = len(log_grid_cells[0, :])
spk_sim_hex = get_sim(coords1, mcstemp, numbins=numbins, simtype='hex', t=0.2)
spk_sim_sqr = get_sim(coords1, mcstemp, numbins=numbins, simtype='sqr', t=0.2)

# Four distinct random neurons. Files are now named after this model and the neuron; the old
# copy-pasted names overwrote the square model's ratemaps and each other on every iteration.
for i in np.random.choice(num_neurons, size=min(4, num_neurons), replace=False):
    print(i)
    for values, kind in ((log_grid_cells[:, i], 'simulated'),
                         (spk_sim_sqr[:, i], 'sqr'),
                         (spk_sim_hex[:, i], 'hex')):
        mtot = binned_statistic_2d(xx, yy, values, bins=25, statistic='mean')[0]
        fig, ax = plt.subplots(1, 1)
        ax.imshow(mtot, vmin=0, vmax=mtot[~np.isnan(mtot)].max())
        ax.axis('off')
        fig.tight_layout()
        fig.savefig(f'ratemap_{kind}_guanella_{i}.png', transparent=True, pad_inches=0.1)
    

In [None]:

# Stack every neuron's ratemap around its centre on the torus; save as PNG and PDF.
mcstemp, mtot_all = get_ratemaps_center(coords1, log_grid_cells, numbins=numbins,)
fig, ax = plt.subplots(1, 1)
plot_centered_ratemaps(coords1, log_grid_cells, mcstemp, numbins, ax)
for fname in ('stacked_ratemap_guanella.png', 'stacked_ratemap_guanella.pdf'):
    fig.savefig(fname, transparent=True, pad_inches=0.1)


In [None]:

# Colour every time bin by its smoothed circular coordinates, looked up in a 2-D colour image.
im = plt.imread('image044.png')
a1 = np.rot90(np.array(im), 1)
sig = 1

# smooth the angles over time via sin/cos, wrapped back to [0, 2*pi)
cc = np.arctan2(gaussian_filter1d(np.sin(coords1), sigma=sig, axis=0),
                gaussian_filter1d(np.cos(coords1), sigma=sig, axis=0)) % (2*np.pi)
bCos = True
if bCos:
    # Small padding so cos == +/-1 lands inside the outer bins. Named bin_pad (not eps) so it no
    # longer overwrites the radial-downsampling eps used by the persistence cell.
    bin_pad = 0.0001
    edges = np.linspace(-1-bin_pad, 1+bin_pad, len(a1)+1)
    digitized = np.stack((np.digitize(np.cos(cc[:, 0]), edges),
                          np.digitize(np.cos(cc[:, 1]), edges)), axis=1)
else:
    edges = np.linspace(0, 2*np.pi, len(a1)+1)
    digitized = np.stack((np.digitize(cc[:, 0], edges),
                          np.digitize(cc[:, 1], edges)), axis=1)
# vectorised colour lookup: image row from the 2nd coordinate, column from the 1st
cc1 = a1[digitized[:, 1]-1, digitized[:, 0]-1]

fig = plt.figure(figsize=(10, 5), dpi=200)
plt.axis('off')
plt.hsv()
ax1 = fig.add_subplot(111)
ax1.axis('off')
ax1.scatter(xx, yy, s=50, c=cc1, alpha=0.7)
ax1.set_aspect(1/ax1.get_data_ratio())

plt.savefig('Guanella_OF_2dcoords', transparent=True, pad_inches=0.1)
plt.show()



In [None]:

plt.viridis()
# Each row of log_grid_cells is grid.grid_activity.flatten(), i.e. shape (mm, nn, layers) in
# C order. Reshape with the grid's own dimensions instead of a hard-coded 20x20, and show layer 0
# so the cell also works when several gains (layers) are simulated.
pop_vector = log_grid_cells[350, :].reshape(grid.mm, grid.nn, grid.grid_layers)[:, :, 0]
plt.imshow(pop_vector)
plt.axis('off')
plt.savefig('Pop_vect_guanella.png', transparent=True, pad_inches=0.1)
plt.savefig('Pop_vect_guanella.pdf', transparent=True, pad_inches=0.1)


## Couey 

In [None]:
from utils_torus import *
import numpy as np
#f = np.load('couey_300random_10ds.npz', allow_pickle = True)
#sspikes1  = f['sspikes'].T
#f.close()

In [None]:
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can execute arbitrary
# code; keep it only if this data file is trusted and actually needs it.
# The context manager closes the file even if one of the keys is missing.
with np.load('/Users/erihe/OneDrive - NTNU/Prosjekt/Toroidal_topology_grid_cell_data/rat_r_day1_grid_modules_1_2_3.npz',
             allow_pickle=True) as f:
    t = f['t']
    x = f['x']
    y = f['y']
    aa = f['azimuth']

min_of_1, max_of_1 = 7457, 14778
tt, xx, yy, speed, aa = get_pos(x, y, t, aa,
                                min_time=min_of_1, max_time=max_of_1, dt_orig=0.01, res=100000)

# Every 2nd of the first 180000 samples; time = sample index / 100, scaled by 1000 below
# (presumably seconds -> ms). Samples where either coordinate is non-finite are dropped.
ds_idx = np.arange(0, 180000, 2)
posx = xx[ds_idx]
posy = yy[ds_idx]
post = ds_idx/100
finite = np.isfinite(posx) & np.isfinite(posy)
post, posx, posy = post[finite], posx[finite], posy[finite]
post *= 1000

# scale so the larger side has length 1, then shift to start at 0
side = max(max(posx)-min(posx), max(posy)-min(posy))
posx *= 1./side
posy *= 1./side
posx -= min(posx)
posy -= min(posy)

# resample onto a regular grid of step 1
tnew = np.arange(0, 599000, 1)
posx = np.interp(tnew, post, posx)
posy = np.interp(tnew, post, posy)
post = tnew

#Get angles and velocities
angs = np.zeros(len(post))
angs[:-1] = np.arctan2(posy[1:]-posy[:-1], posx[1:]-posx[:-1])
angs[-1] = angs[-2]
speeds = 1000.*np.sqrt((posx[1:]-posx[:-1])**2+(posy[1:]-posy[:-1])**2)
nums = len(speeds)

numbumps = 4
# parameters of the model
extinp = 1.
alpha = 0.15
ell = 2.
inh = -0.02
R = 15.
Nx = 28
Ny = 44

if(numbumps==4):
    Nx*=2
if(numbumps==8):
    Nx*=2
    Ny*=2
NG=Nx*Ny

### MAKE CONNECTIVITY WITH AN OFFSET RELATIVE TO PREFERRED DIRECTION
# preferred direction (multiples of pi/2) alternates on a 2x2 tiling of the sheet
theta = np.zeros([Nx,Ny])
theta[0:Nx:2,0:Ny:2] = 0
theta[1:Nx:2,0:Ny:2] = 1
theta[0:Nx:2,1:Ny:2] = 2
theta[1:Nx:2,1:Ny:2] = 3
theta = 0.5*np.pi*theta
theta = np.ravel(theta)

# sheet coordinates of every cell, flattened so cell index n = xi*Ny + yi
# (loop variables no longer clobber the loaded x / y arrays)
xes, yes = np.meshgrid(np.arange(Nx, dtype=float), np.arange(Ny, dtype=float), indexing='ij')
xes = np.ravel(xes)
yes = np.ravel(yes)
Rsqrd = R**2

# squared y-offsets depend only on yi: compute them once instead of once per xi
ydiff_sq = []
for yi in range(Ny):
    ydiffA = abs(yes-yi-ell*np.sin(theta))
    ydiff_sq.append((ydiffA**2, (Ny-ydiffA)**2))

W = np.zeros([NG,NG])
for xi in range(Nx):
    xdiffA = abs(xes-xi-ell*np.cos(theta))
    xdiffB = Nx-xdiffA
    xdiffA = xdiffA**2
    xdiffB = xdiffB**2
    for yi in range(Ny):
        n = xi*Ny+yi
        ydiffA, ydiffB = ydiff_sq[yi]
        # add inhibition for each of the four periodic wrap combinations
        for d in (xdiffA+ydiffA, xdiffB+ydiffA, xdiffA+ydiffB, xdiffB+ydiffB):
            W[d<Rsqrd,n] += inh

# trajectory at the 10-step sampling used for the recorded activity
posx = posx[np.arange(0, len(posx),10)]
posy = posy[np.arange(0, len(posy),10)]
xx, yy = posx, posy

In [None]:
# release the large coordinate arrays; they are not needed any more
xes=0
yes=0
N = NG

# random initial state; let the pattern form without velocity input
S = (np.random.rand(Nx*Ny) > 0.1)*1.0
for t in range(2000):
    S = S + 0.1*(-S + np.maximum((extinp+np.matmul(S,W)),0.))
    S[S<0.00001] = 0.

maxS = max(np.ravel(S))
minx = min([min(posy),min(posx)])
maxx = max([max(posy),max(posx)])
whichn = np.arange(len(S))

# Every 10th step is recorded. Samples are taken at t = 0, 10, 20, ..., so there are
# ceil(nums/10) of them, while the array holds nums//10. The old code wrote one column past the
# end whenever nums % 10 != 0 (IndexError); the extra trailing sample is now skipped, which keeps
# nodes1 at the width the downstream cells expect.
n_samples = nums // 10
nodes1 = np.zeros([len(whichn), n_samples])

for t in range(0, nums):
    S = S + 0.1*(-S + np.maximum((extinp+np.matmul(S,W)+alpha*speeds[t]*np.cos(angs[t]-theta)),0.))
    if t % 10 == 0:
        S[S<0.0001] = 0.
        sample = t // 10
        if sample < n_samples:
            nodes1[:, sample] = S
    if(np.mod(t,5000)==0 and t>2):
        print('%2.2f percent done'%(float(t)*100/float(nums)))


In [None]:
from scipy.stats import binned_statistic_2d, pearsonr
plt.viridis()
num_neurons = len(nodes1[:,0])
inds = np.arange(num_neurons)
np.random.shuffle(inds)

# Match the trajectory length to the recorded samples directly. posx[:int(t/10)] relied on the
# loop variable t and was one short of nodes1 whenever nums % 10 == 0.
n_samples = nodes1.shape[1]
example_neuron = 80
mtot, x_edge, y_edge, circ = binned_statistic_2d(posx[:n_samples],
                                                 posy[:n_samples],
                                                 nodes1[example_neuron, :],
                                                 statistic='mean', bins=50, range=None,
                                                 expand_binnumbers=True)
plt.figure()
plt.imshow(mtot)

In [None]:
# Cache the simulated population activity together with the matching downsampled trajectory.
np.savez(
    'Couey_simulation',
    nodes1=nodes1,
    posx=posx,
    posy=posy,
)

In [None]:
# Reload the cached Couey simulation written by the np.savez cell above.
# The file holds plain numeric arrays, so pickle support is unnecessary (allow_pickle would
# allow code execution from a tampered file). The `with` block closes the file even if a
# key is missing.
with np.load('Couey_simulation.npz') as f:
    nodes1 = f['nodes1']
    posx = f['posx']
    posy = f['posy']

In [None]:
# --- settings shared by the PCA / persistent-cohomology cells ---
dim = 6            # number of principal components
k = 1000           # number of neighbours used by sample_denoising
maxdim = 1         # highest homology dimension plotted (H0 and H1)
metric = 'cosine'  # point-cloud metric for persistence
eps = 0.2          # radius for radial downsampling
n_points = 1000    # points kept for the persistence computation

spk1 = preprocessing.scale(nodes1.T, axis=0)   # time x cells, each column standardised
dim_red_spikes_move_scaled, e1, e2, var_exp = pca(spk1, dim=dim)

# keep the real part and divide each component by sqrt(e2) (presumably its eigenvalue)
dim_red_spikes_move_scaled = np.real(dim_red_spikes_move_scaled)
dim_red_spikes_move_scaled /= np.sqrt(np.real(e2[:dim]))

# explained-variance curve of the leading components
fig, ax = plt.subplots(1, 1)
ax.plot(var_exp[:15])
ax.set_aspect(1/ax.get_data_ratio())

# spatial map of every component; colour limits at the 5th / 95th position of its unique values
n_time = len(spk1)
fig, axs = plt.subplots(1, dim, figsize=(10, 5), dpi=120)
for c in range(dim):
    mtot, x_edges, y_edges, circ = binned_statistic_2d(
        posx[:n_time], posy[:n_time], dim_red_spikes_move_scaled[:, c],
        statistic='mean', bins=30, expand_binnumbers=True)
    empty = np.isnan(mtot)
    mtot[empty] = np.mean(mtot[~empty])
    mtot = gaussian_filter(mtot, 1)
    plt.viridis()
    levels = np.unique(mtot)
    lo = levels[int(0.05*len(levels))]
    hi = levels[int(0.95*len(levels))]
    mtot[empty] = np.nan
    panel = axs[c]
    panel.imshow(mtot, vmin=lo, vmax=hi)
    panel.axis('off')
    panel.set_aspect(1/panel.get_data_ratio())



In [None]:
ph_coeff = 47  # prime field coefficient for ripser and for the circular coordinates

startindex = np.argmax(np.sum(np.abs(dim_red_spikes_move_scaled), 1))
movetimes1 = radial_downsampling(dim_red_spikes_move_scaled, epsilon=eps,
                                 startindex=startindex)

indstemp = sample_denoising(dim_red_spikes_move_scaled[movetimes1, :], k,
                            n_points, 1, metric)[0]
# Map back to time indices and keep at most n_points. The downsampled cloud gets its own name:
# overwriting dim_red_spikes_move_scaled made re-running this cell downsample an already
# downsampled array.
indstemp = movetimes1[indstemp][:n_points]
dim_red_ds = dim_red_spikes_move_scaled[indstemp, :]

d = squareform(pdist(dim_red_ds, metric))
thresh = np.max(d[~np.isinf(d)])
# use the configured maxdim so the computation matches the diagrams plotted below
persistence = ripser(d, maxdim=maxdim, coeff=ph_coeff, do_cocycles=True,
                     distance_matrix=True, thresh=thresh)
dgms = persistence['dgms']
plt.figure()
plot_diagrams(dgms, list(np.arange(maxdim+1)), lifetime=True)
plt.show()
plot_barcode(dgms)

coords_ds, coords_ds_consistent = get_coords_consistent(persistence, coeff=ph_coeff,
                                                        ph_classes=[0, 1,], bConsistent=True)
# sorted coordinate profiles and their 2-D scatter, for the raw and the consistent coordinates
for coords in (coords_ds, coords_ds_consistent):
    fig, ax = plt.subplots(1, 3, figsize=(10, 5), dpi=120)
    for i in range(len(coords)):
        ax[i].plot(coords[i, np.argsort(coords[i, :])])
    ax[2].scatter(*coords[:2, :], s=100)
    for i in range(3):
        ax[i].set_aspect(1/ax[i].get_data_ratio())

In [None]:
# Circular coordinates for every time sample; get_coords_all presumably extends the coordinates
# found for the downsampled points `indstemp` to all samples (bPCA=False).
coords1 = get_coords_all(
    nodes1.T,
    coords_ds_consistent[:, :],
    np.arange(nodes1.shape[1]),
    indstemp,
    bPCA=False,
)


In [None]:
# Spatial map of each circular coordinate: circular mean per bin, smoothed on the circle.
n_time = len(spk1)
for c in (0, 1):
    fig, axs = plt.subplots(1, 1)
    mtot, x_edges, y_edges, circ = binned_statistic_2d(
        posx[:n_time], posy[:n_time], coords1[:, c],
        statistic=circmean, bins=50, expand_binnumbers=True)

    nans = np.isnan(mtot)
    # fill empty bins and smooth sin and cos separately, then recombine into an angle
    smoothed = []
    for comp in (np.sin(mtot), np.cos(mtot)):
        comp[nans] = np.mean(comp[~nans])
        smoothed.append(gaussian_filter(comp, 1))
    sintot, costot = smoothed
    mtot = gaussian_filter(np.cos(np.arctan2(sintot, costot)), 1)
    plt.viridis()
    mtot[nans] = np.nan
    axs.imshow(mtot)
    axs.axis('off')
    axs.set_aspect(1/axs.get_data_ratio())


In [None]:
# trajectory trimmed to the number of activity samples
xx = posx[:len(spk1)]
yy = posy[:len(spk1)]


In [None]:
plt.viridis()

numbins = 12
mcstemp, mtot_all = get_ratemaps_center(coords1, nodes1.T, numbins=numbins,)
num_neurons = len(nodes1)
spk_sim_hex = get_sim(coords1, mcstemp, numbins=numbins, simtype='hex', t=0.2)
spk_sim_sqr = get_sim(coords1, mcstemp, numbins=numbins, simtype='sqr', t=0.2)

# Four distinct random neurons. Files are now named after this model and the neuron; the old
# copy-pasted names overwrote the square model's ratemaps and each other on every iteration.
for i in np.random.choice(num_neurons, size=min(4, num_neurons), replace=False):
    print(i)
    for values, kind in ((nodes1[i, :], 'simulated'),
                         (spk_sim_sqr[:, i], 'sqr'),
                         (spk_sim_hex[:, i], 'hex')):
        mtot = binned_statistic_2d(xx, yy, values, bins=25, statistic='mean')[0]
        fig, ax = plt.subplots(1, 1)
        ax.imshow(mtot, vmin=0, vmax=mtot[~np.isnan(mtot)].max())
        ax.axis('off')
        fig.tight_layout()
        fig.savefig(f'ratemap_{kind}_couey_{i}.png', transparent=True, pad_inches=0.1)
    

In [None]:
# Stack every neuron's ratemap around its centre on the torus; save as PNG and PDF.
fig, ax = plt.subplots(1, 1)
plot_centered_ratemaps(coords1, nodes1.T, mcstemp, numbins, ax)
for fname in ('stacked_ratemap_couey.png', 'stacked_ratemap_couey.pdf'):
    fig.savefig(fname, transparent=True, pad_inches=0.1)


In [None]:

# Colour every time bin by its smoothed circular coordinates, looked up in a 2-D colour image.
im = plt.imread('image044.png')
a1 = np.rot90(np.array(im), 1)
sig = 1

# smooth the angles over time via sin/cos, wrapped back to [0, 2*pi)
cc = np.arctan2(gaussian_filter1d(np.sin(coords1), sigma=sig, axis=0),
                gaussian_filter1d(np.cos(coords1), sigma=sig, axis=0)) % (2*np.pi)
bCos = True
if bCos:
    # Small padding so cos == +/-1 lands inside the outer bins. Named bin_pad (not eps) so it no
    # longer overwrites the radial-downsampling eps used by the persistence cell.
    bin_pad = 0.0001
    edges = np.linspace(-1-bin_pad, 1+bin_pad, len(a1)+1)
    digitized = np.stack((np.digitize(np.cos(cc[:, 0]), edges),
                          np.digitize(np.cos(cc[:, 1]), edges)), axis=1)
else:
    edges = np.linspace(0, 2*np.pi, len(a1)+1)
    digitized = np.stack((np.digitize(cc[:, 0], edges),
                          np.digitize(cc[:, 1], edges)), axis=1)
# vectorised colour lookup: image row from the 2nd coordinate, column from the 1st
cc1 = a1[digitized[:, 1]-1, digitized[:, 0]-1]

fig = plt.figure(figsize=(10, 5), dpi=200)
plt.axis('off')
plt.hsv()
ax1 = fig.add_subplot(111)
ax1.axis('off')
ax1.scatter(xx, yy, s=50, c=cc1, alpha=0.7)
ax1.set_aspect(1/ax1.get_data_ratio())

plt.savefig('Couey_OF_2dcoords', transparent=True, pad_inches=0.1)
plt.show()



In [None]:
# (leftover scratch evaluation removed; it did not feed into the analysis)

In [None]:

plt.viridis()
# Cell index n = xi*Ny + y in the connectivity layout, so a population vector reshapes to
# (Nx, Ny). Using the variables instead of a hard-coded (56, 44) keeps this valid when numbumps
# changes the sheet size.
plt.imshow(nodes1[:, 350].reshape(Nx, Ny))
plt.axis('off')
plt.savefig('Pop_vect_couey.png', transparent=True, pad_inches=0.1)
plt.savefig('Pop_vect_couey.pdf', transparent=True, pad_inches=0.1)
