Import helper functions

In [None]:
from utils_new import *
    

Get data from Waaga et al (2022) and load

In [None]:
# NOTE(review): machine-specific absolute path -- point this at the local copy of the data.
data_dir = '/Users/erihe/OneDrive - NTNU/Prosjekt/Main/waaga'
# glob.glob returns paths in arbitrary, filesystem-dependent order. Later cells
# index `fnames` by position (fnames[taskid]), so sort the list and drop the
# readme here; otherwise the selected recording can differ between machines or
# even be the readme itself.
fnames = sorted(ff for ff in glob.glob(data_dir + '/*') if ff.find('readme') == -1)
for ff in fnames:
    # allow_pickle=True unpickles arbitrary objects: only load trusted files.
    f = np.load(ff, allow_pickle = True)[()]
    # Print the file description and the number of units per module id.
    print(f['description'])
    print(np.bincount(f['module_id'].astype(int)))
    print('')


Extract data from mouse and session

In [None]:
# Select one recording file; -1 picks the last entry of `fnames`.
taskid = -1
f = np.load(fnames[taskid], allow_pickle = True)[()]
sess_names = [f['task'][task]['name'] for task in f['task']]
print(f['description'])
print(np.bincount(f['module_id'].astype(int)))
print('')

# Per-unit identifiers and module labels shared by both tasks.
unit_id = f['unit_id']
module_id = f['module_id']

# Task 0 is kept under the *_dark names.
spike_timestamp_dark = f['task'][0]['spike_timestamp']
spikes_cluster_id_dark = f['task'][0]['spike_cluster_id']
trk = f['task'][0]['tracking']
t_dark, x_dark, y_dark, z_dark, hd_dark = (trk[key] for key in ('t', 'x', 'y', 'z', 'hd'))

# Task 1 is kept under the *_light names.
spike_timestamp_light = f['task'][1]['spike_timestamp']
spikes_cluster_id_light = f['task'][1]['spike_cluster_id']
trk = f['task'][1]['tracking']
t_light, x_light, y_light, z_light, hd_light = (trk[key] for key in ('t', 'x', 'y', 'z', 'hd'))

# Split the pooled spike times into one array per unit, keyed by the unit's
# position in `unit_id` (0, 1, 2, ...), not by the unit id itself.
spikes_dark = {idx: spike_timestamp_dark[spikes_cluster_id_dark == unit]
               for idx, unit in enumerate(unit_id)}
spikes_light = {idx: spike_timestamp_light[spikes_cluster_id_light == unit]
                for idx, unit in enumerate(unit_id)}

# Drop the raw containers; only the per-unit dicts and tracking arrays are kept.
del f, trk, spike_timestamp_dark, spike_timestamp_light, spikes_cluster_id_dark, spikes_cluster_id_light

Compute mean firing rate

In [None]:
# Session boundaries are the first and last tracking timestamps.
min_dark_1, max_dark_1 = t_dark[0], t_dark[-1]
min_light_1, max_light_1 = t_light[0], t_light[-1]

# Mean rate per unit in the dark session: number of spikes strictly inside
# the tracked interval divided by the interval length.
num_neurons = len(spikes_dark)
spike_counts = [np.sum((spikes_dark[i] > min_dark_1) & (spikes_dark[i] < max_dark_1))
                for i in range(num_neurons)]
meanRate_dark = np.array(spike_counts, dtype = float) / (max_dark_1 - min_dark_1)


Sample spatial positions

In [None]:
# Each session is sampled twice with get_pos: once with dt_orig = 0.1 and once
# with dt_orig = 0.01 (the variables carrying a "1" in their name).
coarse = dict(dt_orig = 0.1, res = 10000)
fine = dict(dt_orig = 0.01, res = 100000)
dark_window = dict(min_time = min_dark_1, max_time = max_dark_1)
light_window = dict(min_time = min_light_1, max_time = max_light_1)

tt_dark, xx_dark, yy_dark, speed_dark, hd_of_dark = get_pos(
    x_dark, y_dark, t_dark, hd_dark, **dark_window, **coarse)

tt1_dark, xx1_dark, yy1_dark, speed1_dark, hd_of_dark1 = get_pos(
    x_dark, y_dark, t_dark, hd_dark, **dark_window, **fine)

tt_light, xx_light, yy_light, speed_light, hd_of_light = get_pos(
    x_light, y_light, t_light, hd_light, **light_window, **coarse)

tt_light1, xx_light1, yy_light1, speed_light1, hd_of_light1 = get_pos(
    x_light, y_light, t_light, hd_light, **light_window, **fine)


Compute cross correlations

In [None]:
#### 'firing_rate' is more temporally precise, but might be slower
# Shared smoothing / binning settings for both sessions.
rate_kwargs = dict(sigma = 5, dt_orig = 0.1, res = 10000)

# Dark session: square-rooted smoothed rates, then pairwise cross-correlations.
# The elapsed time of each step is printed.
t0 = time.time()
sspikes_dark = firing_rate(spikes_dark, min_time = min_dark_1, max_time = max_dark_1,
                           **rate_kwargs)[0]
print(time.time()- t0)
sspikes_dark = np.sqrt(sspikes_dark)
t0 = time.time()
Xcorr_dark = cross_corr_dist(sspikes_dark, lencorr = 30)
print(time.time()-t0)

# Light session: same pipeline.
t0 = time.time()
sspikes_light = firing_rate(spikes_light, min_time = min_light_1, max_time = max_light_1,
                            **rate_kwargs)[0]
print(time.time()- t0)
sspikes_light = np.sqrt(sspikes_light)
t0 = time.time()
Xcorr_light = cross_corr_dist(sspikes_light, lencorr = 30)
print(time.time()-t0)


Save for backup

In [None]:
# Build the output name from the selected data file: remove the directory
# prefix and the '.npy' extension, then prepend 'waaga_Xcorrs'.
fname = fnames[taskid]
for fragment in (data_dir + '/', '.npy'):
    fname = fname.replace(fragment, '')
np.savez('waaga_Xcorrs' + fname, **{'Xcorr_dark': Xcorr_dark, 'Xcorr_light': Xcorr_light})


In [None]:
# Reload the cross-correlations from the first backup file matching the pattern.
ff = glob.glob('waaga_Xcorrs*')
with np.load(ff[0]) as f:
    Xcorr_dark = f['Xcorr_dark'][()]
    Xcorr_light = f['Xcorr_light'][()]

Compute rate maps, spatial autocorrelograms and temporal autocorrelations for all non-mua neurons

In [None]:
sigma, dt_curr, time_resolution = 1, 0.1, 10000

# Spike trains restricted to the first 1000 time units of the light session
# (used for the temporal autocorrelations below).
spikes1 = {}
for i, train in enumerate(spikes_light.values()):
    in_window = (train > min_light_1) & (train < min_light_1 + 1000)
    spikes1[i] = train[in_window]

# Recompute the light-session rates with the settings above; this overwrites
# any earlier `sspikes_light`.
sspikes_light = np.sqrt(firing_rate(spikes_light, sigma = sigma, min_time = min_light_1,
                                    max_time = max_light_1, dt_orig = dt_curr,
                                    res = time_resolution)[0])

# 25x25 rate map and its 2-D autocorrelogram for every neuron.
num_neurons = sspikes_light.shape[1]
rmap = np.zeros((num_neurons, 25,25))
acorr = np.zeros_like(rmap)
for i in range(num_neurons):
    currmap = binned_statistic_2d(xx_light, yy_light, sspikes_light[:,i],
                                  statistic = 'mean', bins = 25)[0]
    # Unvisited bins are NaN: fill them with the mean of the visited bins before smoothing.
    nans = np.isnan(currmap)
    currmap[nans] = np.mean(currmap[~nans])
    currmap = gaussian_filter(currmap, 1.)
    rmap[i] = currmap
    acorr[i] = pearson_correlate2d(currmap, currmap)
# Flattened copy of the autocorrelograms, one row per neuron.
acorr_tmp = acorr.reshape(num_neurons, 25**2).copy()

# Temporal autocorrelations: keep columns from index 200 on, normalise each row
# by its first kept value, zero that first column, then smooth along the lag axis.
t_acorr = get_temporal_acorr(spikes1)[:, 200:].astype(float)
t_acorr = t_acorr / t_acorr[:, :1]
t_acorr[:,0] = 0
t_acorr = gaussian_filter1d(t_acorr, sigma = 2, axis = 1)

# Head-direction information score from a 30-bin tuning curve per neuron.
hd_info = np.zeros(num_neurons)
for i in range(num_neurons):
    rate_i = sspikes_light[:,i]
    mtot, __, circ = binned_statistic(hd_of_light, rate_i, statistic = 'mean', bins = 30)
    mu = np.mean(rate_i)
    hd_info[i] = information_score_1d(mtot, circ-1, mu)

# Bundle of per-neuron scores handed to scores_cluster; note that 'sum' holds
# the dark-session mean rate while the other entries come from the light session.
scores_waaga = (('rmap', rmap),
                ('acorr2d', acorr),
                ('tacorrs', t_acorr),
                ('sum',meanRate_dark),
                ('hd_info', hd_info))

Cluster neurons using agglomerative clustering with average linkage and a given distance threshold.

In [None]:
# Distance between neurons: correlation distance between the rows of the
# squared dark-session cross-correlation matrix.
x1 = Xcorr_dark
# NaNs are replaced by 1 in place, so Xcorr_dark itself is modified.
np.putmask(x1, np.isnan(x1), 1)  # alternative: np.median(x1[~np.isnan(x1)])
dd1 = np.zeros_like(Xcorr_dark)
np.add(dd1, squareform(pdist(np.square(x1), 'correlation')), out = dd1)
# Upper-triangular (off-diagonal) distances as a flat vector.
dvals = dd1[np.triu_indices_from(dd1, k = 1)]

# Cut the average-linkage tree at this threshold to obtain cluster labels.
thr = 0.89
print(thr)
plt.viridis()
ind1 = get_ind(dd1, thr, linkage = 'average')


In [None]:
fig = plt.figure(dpi = 300)
# Reorder rows and columns of the distance matrix so neurons with the same
# cluster label are adjacent.
order = np.argsort(ind1)
d2 = dd1[order][:, order]

# Colour limits are clipped to the 5th-95th percentile of the distances.
lo, hi = np.percentile(d2.ravel(), [5, 95])
plt.imshow(d2, vmin = lo, vmax = hi)
plt.axis('off')

# Report the ten largest clusters: their sizes and their labels.
bin_ind = np.bincount(ind1)
numneuronsind = np.argsort(bin_ind)[::-1]
print('num: ', bin_ind[numneuronsind[:10]])
print('ind: ', numneuronsind[:10])
for ext in ('png', 'pdf'):
    plt.savefig('Figures/ExtFig1d_waaga_cross.' + ext, transparent = True,
                bbox_inches='tight', pad_inches=0.2)

# Export the sorted matrix as source data, one spreadsheet column per matrix column.
data = [pd.Series(d2[:,i]) for i in range(len(d2))]
data_names = ['col_' + str(i) for i in range(len(d2))]
df = pd.concat(data, ignore_index=True, axis=1)
df.columns = data_names
df.to_excel("Source_data/ExtFig1d_waaga_cross.xlsx", sheet_name='crossmat')


In [None]:
# Overlay the module labels read from the data file (module_id) and the
# cluster labels computed above (ind1), both plotted against unit index.
plt.plot(module_id)
plt.plot(ind1)

In [None]:


# For every cluster with at least 15 neurons: print its label, size and member
# indices, then pass its smoothed light-session activity to scores_cluster.
sig = 2
for i in np.unique(ind1):
    mod_ind1s = np.where(ind1==i)[0]
    if len(mod_ind1s)>=15:
        print('Mod ', i)
        
        print('num_neurons ', len(mod_ind1s))
        print(mod_ind1s)
        sspk2 = sspikes_light[:,mod_ind1s]
        # NOTE(review): sspikes_light was already square-rooted when it was
        # computed, so this applies a second sqrt -- confirm this is intended.
        sspk2 = np.sqrt(sspk2)
        # Smooth along the time axis (axis 0).
        sspk2 = gaussian_filter1d(sspk2,sigma = sig, axis = 0)

        scores_cluster(sspk2,scores_waaga, mod_ind1s,xx_light,yy_light,  spk2 = [], num_example = 6, dim = 10, bUMAP = False)
        plt.show()

# Opens and shows one more, empty figure after the loop.
plt.figure()
plt.show()



In [None]:
# Persistence analysis, one pass per cluster label in [0,1,2]:
#   1. compute square-rooted firing rates of the cluster's neurons for the
#      light and the dark session (stored in sspk_l / sspk_d),
#   2. PCA on the dark-session rates at time bins with speed1_dark > 5,
#   3. downsample the PCA point cloud, compute persistent homology with ripser,
#   4. derive circular coordinates and store them with the selected indices.
# NOTE(review): `sp` is not used anywhere in this cell; the speed filter
# actually applied below is the literal `speed1_dark>5`.
dim = 6         # number of principal components
n_points = 2000   # number of downsampled points for persistence analysis 
k = 1000          # number of neighbours for downsampling
maxdim = 1        # dimension of homology - often just do 1 as it could be expensive (depends on number of points and neighbours)
metric = 'cosine' # what metric to use for persistence
sp = 50           # speed1 threshold
eps = 0.5        # radial distance downsampling
sigma = 10         # time bins
dt_curr = 0.01    # time bin interval
time_resolution = 100000 
# Per-cluster results, keyed by the cluster label ii.
sspk_d = {}
sspk_l = {}
coords_ds_all = {}
movetimes0_all = {}
indstemp_all = {}
for ii in [0,1,2]:
    print('Cluster ' + str(ii))
    mod_ind1s = np.where(ind1 == ii)[0]
    # Collect the light-session spike trains of this cluster, re-keyed 0..count-1.
    spktimes_tmp = {}
    count = 0
    for i, spk in enumerate(spikes_light):
        if i in mod_ind1s:
            spktimes_tmp[count] = spikes_light[spk]
            count += 1
    print('num_neurons in light cluster ', count)
    sspikes_light1 = firing_rate(spktimes_tmp, sigma = sigma, min_time = min_light_1, max_time = max_light_1, 
                              dt_orig = dt_curr, res = time_resolution)[0]
    sspikes_light1 = np.sqrt(sspikes_light1)
    sspk_l[ii] = sspikes_light1.copy()
    
    # Same selection for the dark session.
    spktimes_tmp = {}
    count = 0
    for i, spk in enumerate(spikes_dark):
        if i in mod_ind1s:
            spktimes_tmp[count] = spikes_dark[spk]
            count += 1
    print('num_neurons in dark cluster ', count)
    sspikes_dark1 = firing_rate(spktimes_tmp, sigma = sigma, min_time = min_dark_1, max_time = max_dark_1, 
                              dt_orig = dt_curr, res = time_resolution)[0]
    sspikes_dark1 = np.sqrt(sspikes_dark1)
    sspk_d[ii] = sspikes_dark1


    # Indices of dark-session time bins where the speed exceeds 5.
    movetimes0 = np.arange(0,len(sspikes_dark1),1)
    movetimes0 = movetimes0[speed1_dark>5]
#    np.savez('waaga_darkcluster_' + str(ii), sspikes_of = sspikes_dark1[movetimes0])
    # Standardise each neuron (column) and project onto `dim` principal components.
    spk1 = preprocessing.scale(sspikes_dark1[movetimes0],axis = 0)
    dim_red_spikes_move_scaled, e1, e2, var_exp = pca(spk1, dim = dim)
    # Plot the first 15 entries of var_exp returned by pca.
    fig, ax = plt.subplots(1,1)
    ax.plot(var_exp[:15])
    ax.set_aspect(1/ax.get_data_ratio())
    # Spatial map (30x30 bins) of each principal component over dark-session positions.
    fig, axs = plt.subplots(1,dim, figsize= (10,5), dpi = 120)
    for c in range(dim):
        mtot, __, __, circ  = binned_statistic_2d(xx1_dark[movetimes0],
                                                  yy1_dark[movetimes0],
                                                  dim_red_spikes_move_scaled[:,c], 
                                                  statistic = 'mean', 
                                                  bins = 30,
                                                  expand_binnumbers = True)

        # Fill empty bins with the mean for smoothing, then mask them again for display.
        nans = np.isnan(mtot)
        mtot[nans] = np.mean(mtot[~nans])
        mtot = gaussian_filter(mtot, 1)
        plt.viridis()
        # Colour limits at the 5% / 95% positions of the sorted unique values.
        vals = np.unique(mtot)
        mtot[nans] = np.nan
        axs[c].imshow(mtot,vmin = vals[int(0.05*len(vals))], vmax = vals[int(0.95*len(vals))])
        axs[c].axis('off')
        axs[c].set_aspect(1/axs[c].get_data_ratio())


    plt.show()
    # Rescale each component by the square root of the matching entry of e2.
    dim_red_spikes_move_scaled /= np.sqrt(e2[:dim])
    # Start the downsampling from the point with the largest summed absolute coordinates.
    startindex = np.argmax(np.sum(np.abs(dim_red_spikes_move_scaled),1))
    movetimes1 = radial_downsampling(dim_red_spikes_move_scaled,  epsilon = eps, 
        startindex = startindex)
    indstemp  = sample_denoising(dim_red_spikes_move_scaled[movetimes1,:],  k, 
                                       n_points, 1, metric)[0]
    # Map the selected indices back to rows of the full (moving) point cloud.
    indstemp = movetimes1[indstemp]
    dim_red_spikes_move_scaled = dim_red_spikes_move_scaled[indstemp,:]

    # Keep at most n_points points.
    indstemp = indstemp[:n_points]
    dim_red_spikes_move_scaled = dim_red_spikes_move_scaled[:n_points,:]

    # Pairwise distance matrix; the largest finite distance caps the filtration.
    d = squareform(pdist(dim_red_spikes_move_scaled[:,:], metric))
    thresh = np.max(d[~np.isinf(d)])
    if maxdim > 1:
        # Diagrams come from giotto's VietorisRipsPersistence; `persistence`
        # (used for the coordinates below) is still computed by ripser with maxdim=1.
        hom_dims = list(range(maxdim+1))
        VR = VietorisRipsPersistence(
        homology_dimensions=hom_dims,
        metric='precomputed',
        coeff=47,
        max_edge_length= thresh,
        collapse_edges=False,  # True faster?
        n_jobs=None  # -1 faster?
        )
        diagrams = VR.fit_transform([d])
        dgms = from_giotto_to_ripser(diagrams[0])
        persistence = ripser(d, maxdim=1, coeff=47, do_cocycles= True, distance_matrix = True, thresh = thresh)    
    else:
        persistence = ripser(d, maxdim=1, coeff=47, do_cocycles= True, distance_matrix = True, thresh = thresh)    
        dgms = persistence['dgms'] 
    plt.figure()
    plot_diagrams(dgms, list(np.arange(maxdim+1)), lifetime = True)
    plt.show()
    plot_barcode(dgms)

    # Circular coordinates for the downsampled points, in a raw and a "consistent" variant.
    coords_ds, coords_ds_consistent = get_coords_consistent(persistence, coeff = 47, ph_classes = [0,1,], bConsistent = True)
    # Diagnostics for the raw coordinates: each coordinate sorted, plus a scatter of the first two.
    fig, ax = plt.subplots(1,3, figsize = (10,5), dpi = 120)
    for i in range(len(coords_ds)):
        ax[i].plot(coords_ds[i,np.argsort(coords_ds[i,:])])
    ax[2].scatter(*coords_ds[:2,:], s = 100)
    for i in range(3):
        ax[i].set_aspect(1/ax[i].get_data_ratio())


    # Same diagnostics for the consistent coordinates.
    fig, ax = plt.subplots(1,3, figsize = (10,5), dpi = 120)
    for i in range(len(coords_ds)):
        ax[i].plot(coords_ds_consistent[i,np.argsort(coords_ds_consistent[i,:])])
    ax[2].scatter(*coords_ds_consistent[:2,:], s = 100)
    for i in range(3):
        ax[i].set_aspect(1/ax[i].get_data_ratio())

    # Store what the later cells need; only the consistent coordinates are kept.
    coords_ds_all[ii] = coords_ds_consistent.copy()
    movetimes0_all[ii] = movetimes0.copy()
    indstemp_all[ii] = indstemp.copy()


In [None]:
def load_barcode_diagrams(prefix, num_shuffles = 100):
    """Load the diagrams stored under ``prefix + '<i>.npz'``.

    File 0 holds the diagrams shown in the barcode; files 1..num_shuffles hold
    the diagrams passed to plot_barcode as ``diagrams_roll``.

    Parameters
    ----------
    prefix : str
        Path prefix, e.g. 'Waaga/mod0_dgms'.
    num_shuffles : int
        Highest file index to look for.

    Returns
    -------
    dgms : object
        First entry of the 'dgms' array in file 0.
    diagrams_roll : dict
        First entry of 'dgms' for each additional file that exists, keyed
        consecutively from 0.

    Raises
    ------
    FileNotFoundError
        If file 0 is missing. Missing files with index >= 1 are skipped; any
        other error (corrupt file, missing 'dgms' key) is raised instead of
        being silently swallowed.
    """
    # The context manager closes the archive even if reading fails.
    # allow_pickle=True unpickles arbitrary objects: only load trusted files.
    with np.load(prefix + '0.npz', allow_pickle = True) as f:
        dgms = f['dgms'][()][0]
    diagrams_roll = {}
    for i in range(1, num_shuffles + 1):
        try:
            with np.load(prefix + str(i) + '.npz', allow_pickle = True) as f:
                dgmstmp = f['dgms'][()]
        except FileNotFoundError:
            # Not every index needs to exist on disk.
            continue
        diagrams_roll[len(diagrams_roll)] = dgmstmp[0]
    return dgms, diagrams_roll


xmax = 1
ff1 = ['Waaga',]

# One barcode figure per module, saved as PNG and PDF.
for mod_idx in [0, 1, 2]:
    dgms, diagrams_roll = load_barcode_diagrams('Waaga/mod' + str(mod_idx) + '_dgms')
    # Index of the last loaded diagram (number loaded minus one), as printed before.
    count = len(diagrams_roll) - 1
    print(count)

    barcode_name = 'Waaga_mod' + str(mod_idx) + '_barcode'
    plot_barcode(dgms, diagrams_roll = diagrams_roll, percshuf = 99, dpi = 300, SaveSourceDataName = barcode_name)
    plt.savefig('Figures/' + barcode_name + '.png', transparent = True,  bbox_inches='tight', pad_inches=0.2)
    plt.savefig('Figures/' + barcode_name + '.pdf', transparent = True, bbox_inches='tight', pad_inches=0.2)
    plt.show()


In [None]:
# Decode toroidal coordinates for every dark-session time bin of each cluster
# and show how each coordinate varies with position.
coords_darks = {}
for ii in [0,1,2]:
    spk = sspk_d[ii]
    spk2 = sspk_l[ii]
    coords_ds = coords_ds_all[ii]
    movetimes0 = movetimes0_all[ii]
    indstemp = indstemp_all[ii]

    coords_mod1 = get_coords_all(spk, coords_ds, movetimes0, indstemp, dim = dim, bPCA = True)
    coords_mod1 = np.mod(coords_mod1, 2*np.pi)
    # Cluster-specific re-parameterisation of the two angles.
    if ii == 1:
        summed = (coords_mod1[:,1] + coords_mod1[:,0])%(2*np.pi)
        coords_mod1[:,1] = 2*np.pi - summed
    elif ii == 2:
        coords_mod1[:,1] = (coords_mod1[:,1] - coords_mod1[:,0])%(2*np.pi)
        coords_mod1 = np.flip(coords_mod1, 1)

    fig, axs = plt.subplots(1,2)
    for c, ax_c in enumerate(axs):
        # Circular mean of coordinate c in each of 50x50 spatial bins.
        mtot, __, __, circ  = binned_statistic_2d(xx1_dark, yy1_dark, coords_mod1[:,c],
                                                  statistic = circmean, bins = 50,
                                                  expand_binnumbers = True)

        # Smooth the angle map through its sine and cosine; empty bins are
        # filled with the mean before filtering and masked again afterwards.
        nans = np.isnan(mtot)
        sintot, costot = np.sin(mtot), np.cos(mtot)
        sintot[nans] = np.mean(sintot[~nans])
        costot[nans] = np.mean(costot[~nans])
        sintot, costot = gaussian_filter(sintot,1), gaussian_filter(costot,1)
        mtot = np.arctan2(sintot, costot)
        plt.viridis()
        mtot[nans] = np.nan
        ax_c.imshow(mtot)
        ax_c.axis('off')
        ax_c.set_aspect(1/ax_c.get_data_ratio())

    coords_darks[ii] = coords_mod1.copy()


## Align and compare 

In [None]:
# Decode toroidal coordinates for the light session: get_coords_all receives
# the dark-session rates and coordinates, plus the standardised light-session
# rates via `spk2`.
coords_lights = {}
for ii in [0,1,2]:
    spk = sspk_d[ii]
    spk2 = sspk_l[ii]
    coords_ds = coords_ds_all[ii]
    movetimes0 = movetimes0_all[ii]
    indstemp = indstemp_all[ii]

    spk2_scaled = preprocessing.scale(spk2, axis = 0)
    coords_mod1 = get_coords_all(spk, coords_ds, movetimes0, indstemp,
                                 dim = dim, bPCA = True, spk2 = spk2_scaled)
    coords_mod1 = np.mod(coords_mod1, 2*np.pi)
    # Same cluster-specific re-parameterisation as used for the dark session.
    if ii == 1:
        summed = (coords_mod1[:,1] + coords_mod1[:,0])%(2*np.pi)
        coords_mod1[:,1] = 2*np.pi - summed
    elif ii == 2:
        coords_mod1[:,1] = (coords_mod1[:,1] - coords_mod1[:,0])%(2*np.pi)
        coords_mod1 = np.flip(coords_mod1, 1)

    fig, axs = plt.subplots(1,2)
    for c, ax_c in enumerate(axs):
        # Circular mean of coordinate c in each of 50x50 spatial bins.
        mtot, __, __, circ  = binned_statistic_2d(xx_light1, yy_light1, coords_mod1[:,c],
                                                  statistic = circmean, bins = 50,
                                                  expand_binnumbers = True)

        # Smooth the angle map through its sine and cosine; empty bins are
        # filled with the mean before filtering and masked again afterwards.
        nans = np.isnan(mtot)
        sintot, costot = np.sin(mtot), np.cos(mtot)
        sintot[nans] = np.mean(sintot[~nans])
        costot[nans] = np.mean(costot[~nans])
        sintot, costot = gaussian_filter(sintot,1), gaussian_filter(costot,1)
        mtot = np.arctan2(sintot, costot)
        plt.viridis()
        mtot[nans] = np.nan
        ax_c.imshow(mtot)
        ax_c.axis('off')
        ax_c.set_aspect(1/ax_c.get_data_ratio())

    coords_lights[ii] = coords_mod1.copy()


In [None]:
# Spatial tuning of the decoded light-session coordinates, restricted to time
# bins with speed > 5. For each cluster a two-panel figure (one panel per
# coordinate) is drawn and saved, and each panel's map is exported to Excel.
#
# Output names include the cluster index: previously every cluster wrote to the
# same files (and the figure name used the stale inner-loop value of `c`), so
# only the last cluster's output survived.
moving_light = speed_light1 > 5
for ii in [0,1,2]:
    fig, axs = plt.subplots(1,2)
    for c in [0,1]:
        # Circular mean of coordinate c in each of 50x50 spatial bins.
        mtot, __, __, circ  = binned_statistic_2d(xx_light1[moving_light],
                                                  yy_light1[moving_light],
                                                  coords_lights[ii][moving_light,c],
                                                  statistic = circmean,
                                                  bins = 50,
                                                  expand_binnumbers = True)

        # Smooth the angle map through its sine and cosine; empty bins are
        # filled with the mean before filtering and masked again afterwards.
        nans = np.isnan(mtot)
        sintot = np.sin(mtot)
        costot = np.cos(mtot)
        sintot[nans] = np.mean(sintot[~nans])
        costot[nans] = np.mean(costot[~nans])
        sintot = gaussian_filter(sintot,1)
        costot = gaussian_filter(costot,1)
        # The cosine of the smoothed angle is what gets displayed and exported.
        mtot = np.cos(np.arctan2(sintot, costot))
        plt.viridis()
        mtot[nans] = np.nan
        axs[c].imshow(mtot)
        axs[c].axis('off')
        axs[c].set_aspect(1/axs[c].get_data_ratio())

        # Source data: one spreadsheet column per map column.
        data = []
        data_names = []
        for i in range(len(mtot)):
            data.append(pd.Series(mtot[:,i]))
            data_names.extend(['col' + str(i)])
        df = pd.concat(data, ignore_index=True, axis=1)
        df.columns = data_names
        df.to_excel('Source_data/ExtFig1d_spatial_tuning_mod' + str(ii) + '_circ' + str(c) + '.xlsx',
                    sheet_name='waaga_spatial_tuning_circ'+str(c))
    # One figure per cluster containing both coordinate panels.
    plt.savefig('Figures/waaga_light_spatial_tuning_mod' + str(ii) + '.png', transparent = True, bbox_inches='tight', pad_inches=0.2)
    plt.savefig('Figures/waaga_light_spatial_tuning_mod' + str(ii) + '.pdf', transparent = True, bbox_inches='tight', pad_inches=0.2)

    plt.show()
    


In [None]:
# Spatial tuning of the decoded dark-session coordinates, restricted to time
# bins with speed > 5. For each cluster a two-panel figure (one panel per
# coordinate) is drawn and saved, and each panel's map is exported to Excel.
#
# Output names include the cluster index: previously every cluster wrote to the
# same files (and the figure name used the stale inner-loop value of `c`), so
# only the last cluster's output survived.
moving_dark = speed1_dark > 5
for ii in [0,1,2]:
    fig, axs = plt.subplots(1,2)
    for c in [0,1]:
        # Circular mean of coordinate c in each of 50x50 spatial bins.
        mtot, __, __, circ  = binned_statistic_2d(xx1_dark[moving_dark],
                                                  yy1_dark[moving_dark],
                                                  coords_darks[ii][moving_dark,c],
                                                  statistic = circmean,
                                                  bins = 50,
                                                  expand_binnumbers = True)

        # Smooth the angle map through its sine and cosine; empty bins are
        # filled with the mean before filtering and masked again afterwards.
        nans = np.isnan(mtot)
        sintot = np.sin(mtot)
        costot = np.cos(mtot)
        sintot[nans] = np.mean(sintot[~nans])
        costot[nans] = np.mean(costot[~nans])
        sintot = gaussian_filter(sintot,1)
        costot = gaussian_filter(costot,1)
        # The cosine of the smoothed angle is what gets displayed and exported.
        mtot = np.cos(np.arctan2(sintot, costot))
        plt.viridis()
        mtot[nans] = np.nan
        axs[c].imshow(mtot)
        axs[c].axis('off')
        axs[c].set_aspect(1/axs[c].get_data_ratio())

        # Source data: one spreadsheet column per map column.
        data = []
        data_names = []
        for i in range(len(mtot)):
            data.append(pd.Series(mtot[:,i]))
            data_names.extend(['col' + str(i)])
        df = pd.concat(data, ignore_index=True, axis=1)
        df.columns = data_names
        df.to_excel('Source_data/ExtFig1d_spatial_tuning_dark_mod' + str(ii) + '_circ' + str(c) + '.xlsx',
                    sheet_name='waaga_dark_spatial_tuning_circ'+str(c))
    # One figure per cluster containing both coordinate panels.
    plt.savefig('Figures/waaga_dark_spatial_tuning_mod' + str(ii) + '.png', transparent = True, bbox_inches='tight', pad_inches=0.2)
    plt.savefig('Figures/waaga_dark_spatial_tuning_mod' + str(ii) + '.pdf', transparent = True, bbox_inches='tight', pad_inches=0.2)

    plt.show()

In [None]:
# For each cluster and each session: estimate per-neuron phases from the
# decoded coordinates, try the mirrored first coordinate (2*pi - coord) as an
# alternative and keep whichever gives the higher median match_phases score,
# then simulate activity from the chosen coordinates and correlate it with the
# recorded activity, neuron by neuron.
sig_sim = 0.1
pcorr_all_dark = {}
pcorr_all_light = {}
pcorr1_all_dark = {}
pcorr1_all_light = {}

phases_all_dark = {}
phases_all_light = {}
spk_sim_dark = {}
spk_sim_light = {}

for ii in [0,1,2]:
    # coords_mod_dark / coords_mod_light are the arrays stored in coords_darks /
    # coords_lights (not copies), so the in-place mirroring below also updates
    # those dictionaries.
    spk, spk2, coords_ds, movetimes0, indstemp, coords_mod_dark, coords_mod_light = (sspk_d[ii], sspk_l[ii], coords_ds_all[ii], 
                                              movetimes0_all[ii], indstemp_all[ii], coords_darks[ii], coords_lights[ii])

    ############ Dark ############
    num_times_all, num_neurons = np.shape(spk)    
    # match_phases is evaluated on every 20th time bin.
    times = np.arange(0, num_times_all,20)
    movetimes00 = np.arange(0, num_times_all,10)
    coords_mod2 = coords_mod_dark.copy()        
    inds, inds_label =  get_coord_distribution(coords_mod2, numbins = 50,epsilon = 0.1, metric = 'euclidean', startindex = -1)
    phases_1 = get_phases(spk, coords_mod2, inds, inds_label)  
    pcorr1 = match_phases(coords_mod2, spk, phases_1, times = times)
    # Alternative orientation: mirror the first coordinate.
    coords_mod1 = coords_mod2.copy()
    coords_mod1[:,0] = 2*np.pi - coords_mod1[:,0]
    phases_2 = get_phases(spk, coords_mod1, inds, inds_label)  
    pcorr2 = match_phases(coords_mod1, spk, phases_2, times = times)
    if np.median(pcorr2)> np.median(pcorr1):
        coords_mod_dark[:,0] = 2*np.pi - coords_mod_dark[:,0]
        pcorr1 = pcorr2
        phases_1 = phases_2

    spk_sim = simulate_spk_hex(coords_mod_dark, phases_1, t = sig_sim, nums = 1)
    spk_sim_dark[ii] = spk_sim.copy()
    # Pearson correlation between simulated and recorded activity per neuron.
    pcorr = np.zeros(num_neurons)
    for i in range(num_neurons):
        pcorr[i] = pearsonr(spk_sim[:,i], spk[:,i])[0]    
    print('pcorr', np.median(pcorr))
    print('pcorr1', np.median(pcorr1))
    pcorr_all_dark[ii] = pcorr.copy()
    pcorr1_all_dark[ii] = pcorr1.copy()
    phases_all_dark[ii] = phases_1.copy()

    ############ Light ############
    num_times_all, num_neurons = np.shape(spk2)    
    times = np.arange(0, num_times_all,20)
    movetimes00 = np.arange(0, num_times_all,10)
    coords_mod2 = coords_mod_light.copy()        
    inds, inds_label =  get_coord_distribution(coords_mod2, numbins = 50,epsilon = 0.1, metric = 'euclidean', startindex = -1)
    phases_1 = get_phases(spk2, coords_mod2, inds, inds_label)  
    pcorr1 = match_phases(coords_mod2, spk2, phases_1, times = times)
    # Alternative orientation: mirror the first coordinate.
    coords_mod1 = coords_mod2.copy()
    coords_mod1[:,0] = 2*np.pi - coords_mod1[:,0]
    phases_2 = get_phases(spk2, coords_mod1, inds, inds_label)  
    pcorr2 = match_phases(coords_mod1, spk2, phases_2, times = times)
    if np.median(pcorr2)> np.median(pcorr1):
        # This statement was split across two lines (a syntax error); restored
        # to mirror the dark-session branch above.
        coords_mod_light[:,0] = 2*np.pi - coords_mod_light[:,0]
        pcorr1 = pcorr2
        phases_1 = phases_2

    spk_sim = simulate_spk_hex(coords_mod_light, phases_1, t = sig_sim, nums = 1)
    spk_sim_light[ii] = spk_sim.copy()

    # Pearson correlation between simulated and recorded activity per neuron.
    pcorr = np.zeros(num_neurons)
    for i in range(num_neurons):
        pcorr[i] = pearsonr(spk_sim[:,i], spk2[:,i])[0]    
    print('pcorr', np.median(pcorr))
    print('pcorr1', np.median(pcorr1))
    pcorr_all_light[ii] = pcorr.copy()
    pcorr1_all_light[ii] = pcorr1.copy()
    phases_all_light[ii] = phases_1.copy()


In [None]:
def _plot_toroidal_map(ax, coords, rates, numbins, sig):
    """Draw the mean of `rates` binned over the two columns of `coords` as a skewed map on `ax`."""
    mtot_tmp, __, __, circ  = binned_statistic_2d(coords[:,0],
                                                  coords[:,1],
                                                  rates,
                                                  statistic = 'mean',
                                                  bins = numbins,
                                                  expand_binnumbers = True)
    # Empty bins: fill with the mean for smoothing, then mark them with -inf.
    nans = np.isnan(mtot_tmp)
    mtot_tmp[nans] = np.mean(mtot_tmp[~nans])
    mtot_tmp = smooth_tuning_map(np.rot90(mtot_tmp,0), numbins+1, sig, bClose = False)
    mtot_tmp[nans] = -np.inf

    # Upper colour limit: value at the 97.5% position of the sorted map.
    maxtot = np.sort(mtot_tmp.flatten())
    maxtot = maxtot[int(0.975*len(maxtot))]
    ax.imshow(mtot_tmp, origin = 'lower', extent = [0,2*np.pi,0, 2*np.pi], vmin = 0, vmax = maxtot)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect(1/ax.get_data_ratio())

    # Skew the square map by 15 degrees along both axes.
    r_box = transforms.Affine2D().skew_deg(15,15)
    for x in ax.images + ax.lines + ax.collections:
        trans = x.get_transform()
        x.set_transform(r_box+trans)
        if isinstance(x, PathCollection):
            transoff = x.get_offset_transform()
            x._transOffset = r_box+transoff
    # Widen the limits so the skewed image stays inside the axes.
    ax.set_xlim(0, 2*np.pi + 3*np.pi/5)
    ax.set_ylim(0, 2*np.pi + 3*np.pi/5)
    ax.set_aspect('equal', 'box')
    ax.axis('off')


def _plot_spatial_map(ax, xs, ys, rates, numbins, sig, vmin_from_data):
    """Draw the Gaussian-smoothed mean of `rates` binned over positions (xs, ys) on `ax`.

    When `vmin_from_data` is true the lower colour limit is the value at the
    2.5% position of the sorted map, otherwise it is 0.
    """
    mtot_tmp, __, __, circ  = binned_statistic_2d(xs,
                                                  ys,
                                                  rates,
                                                  statistic = 'mean',
                                                  bins = numbins,
                                                  expand_binnumbers = True)
    nans = np.isnan(mtot_tmp)
    mtot_tmp[nans] = np.mean(mtot_tmp[~nans])
    mtot_tmp = gaussian_filter(mtot_tmp, sigma = sig)
    # Colour limits are taken before the empty bins are marked with -inf.
    sorted_vals = np.sort(mtot_tmp.flatten())
    mintot = sorted_vals[int(0.025*len(sorted_vals))]
    maxtot = sorted_vals[int(0.975*len(sorted_vals))]
    mtot_tmp[nans] = -np.inf
    ax.imshow(mtot_tmp, origin = 'lower', extent = [0,2*np.pi,0, 2*np.pi],
              vmin = mintot if vmin_from_data else 0, vmax = maxtot)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.axis('off')


# For each cluster, show the three neurons with the highest dark-session
# simulated-vs-recorded correlation. Each neuron gets one figure of six panels.
numh = 5
for ii in [0,1,2]:
    (spk, spk2, coords_ds, movetimes0, 
     indstemp, coords_mod_dark, coords_mod_light,
     spk_sim_d, spk_sim_l) = (sspk_d[ii], sspk_l[ii], coords_ds_all[ii], 
                              movetimes0_all[ii], indstemp_all[ii], 
                              coords_darks[ii], coords_lights[ii],
                              spk_sim_dark[ii], spk_sim_light[ii])
    num_neurons = len(spk[0,:])

    nw = 0
    numbins1 = 30
    sig1 = 1
    torsort = np.flip(np.argsort(pcorr_all_dark[ii]))[:3]
    for nn, n in enumerate(torsort):
        fig = plt.figure(dpi = 120)

        nnn = nn%numh

        ############ Dark toroidal #############
        _plot_toroidal_map(plt.subplot(161), coords_mod_dark, spk[:,n], numbins1, sig1)

        ############ Light toroidal #############
        # Uses the light-session coordinates and rates; this panel previously
        # repeated the dark-session data.
        _plot_toroidal_map(plt.subplot(162), coords_mod_light, spk2[:,n], numbins1, sig1)

        ############ Dark reconstructed #############
        _plot_spatial_map(plt.subplot(163), xx1_dark, yy1_dark, spk_sim_d[:,n],
                          numbins1, sig1, vmin_from_data = True)

        ############ Dark spatial #############
        _plot_spatial_map(plt.subplot(164), xx1_dark, yy1_dark, spk[:,n],
                          numbins1, sig1, vmin_from_data = True)

        ############ Light reconstructed #############
        # The light panels keep a lower colour limit of 0, as before.
        _plot_spatial_map(plt.subplot(165), xx_light1, yy_light1, spk_sim_l[:,n],
                          numbins1, sig1, vmin_from_data = False)

        ############ Light spatial #############
        _plot_spatial_map(plt.subplot(166), xx_light1, yy_light1, spk2[:,n],
                          numbins1, sig1, vmin_from_data = False)

        plt.show()

In [None]:
# Single-cell tuning figure with source-data export.
#
# For each module index in the loop below, the cell ranked `mod_cells[ii]` by
# descending `pcorr_all_dark[ii]` is drawn in one row of four rate maps:
#     toroidal (dark) | toroidal (light) | spatial (dark) | spatial (light)
# Every plotted map is also written, column by column, to one Excel file per
# module, and the assembled figure is saved as PNG and PDF.


def _mean_rate_map(x, y, activity, bins):
    """Bin `activity` over (x, y) and return ``(rate_map, empty)``.

    `rate_map` is the per-bin mean on a `bins` x `bins` grid.  Bins without
    samples come back as NaN from `binned_statistic_2d`; they are recorded in
    the boolean mask `empty` and then filled with the mean of the occupied
    bins so that subsequent smoothing does not spread NaNs.
    """
    rate_map, _, _, _ = binned_statistic_2d(x, y, activity,
                                            statistic='mean',
                                            bins=bins,
                                            expand_binnumbers=True)
    empty = np.isnan(rate_map)
    rate_map[empty] = np.mean(rate_map[~empty])
    return rate_map, empty


def _plot_toroidal_map(ax, coords, activity, bins, sigma):
    """Draw the smoothed toroidal rate map on `ax` with a 15-degree skew.

    `coords[:, 0]` and `coords[:, 1]` are the two coordinates that `activity`
    is binned over.  Returns the map that was plotted (originally-empty bins
    are set to -inf) so the caller can export it.
    """
    rate_map, empty = _mean_rate_map(coords[:, 0], coords[:, 1], activity, bins)
    rate_map = smooth_tuning_map(np.rot90(rate_map, 0), bins + 1, sigma, bClose=False)
    rate_map[empty] = -np.inf

    # Upper colour limit: entry at the 97.5 % position of the sorted map
    # (taken after the -inf flags have been written, as in the original code).
    ranked = np.sort(rate_map.flatten())
    vmax = ranked[int(0.975 * len(ranked))]
    ax.imshow(rate_map, origin='lower', extent=[0, 2*np.pi, 0, 2*np.pi], vmin=0, vmax=vmax)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect(1 / ax.get_data_ratio())

    # Prepend a skew to every artist's transform so the square map is drawn
    # as a parallelogram.
    skew = transforms.Affine2D().skew_deg(15, 15)
    for artist in ax.images + ax.lines + ax.collections:
        artist.set_transform(skew + artist.get_transform())
        if isinstance(artist, PathCollection):
            artist._transOffset = skew + artist.get_offset_transform()
    # Limits extend past 2*pi -- presumably to leave room for the sheared image.
    ax.set_xlim(0, 2*np.pi + 3*np.pi/5)
    ax.set_ylim(0, 2*np.pi + 3*np.pi/5)
    ax.set_aspect('equal', 'box')
    ax.axis('off')
    return rate_map


def _plot_spatial_map(ax, x, y, activity, bins, sigma, floor_at_percentile=False):
    """Draw the Gaussian-smoothed spatial rate map on `ax`; return the map.

    The upper colour limit is the entry at the 97.5 % position of the sorted,
    smoothed map.  The lower limit is 0, or the entry at the 2.5 % position
    when `floor_at_percentile` is true.  Originally-empty bins are set to
    -inf after the limits have been computed.
    """
    rate_map, empty = _mean_rate_map(x, y, activity, bins)
    rate_map = gaussian_filter(rate_map, sigma=sigma)
    ranked = np.sort(rate_map.flatten())
    vmin = ranked[int(0.025 * len(ranked))] if floor_at_percentile else 0
    vmax = ranked[int(0.975 * len(ranked))]
    rate_map[empty] = -np.inf
    ax.imshow(rate_map, origin='lower', extent=[0, 2*np.pi, 0, 2*np.pi], vmin=vmin, vmax=vmax)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.axis('off')
    return rate_map


def _map_to_columns(rate_map, prefix):
    """Split `rate_map` into one Series per column, named `prefix` + column index."""
    num_cols = rate_map.shape[1]
    columns = [pd.Series(rate_map[:, i]) for i in range(num_cols)]
    names = [prefix + str(i) for i in range(num_cols)]
    return columns, names


numfigs = 4          # panels per row
numw = 1             # number of panel groups side by side
num_neurons = len(conj_cells)
# NOTE(review): the grid gets one row per entry of `conj_cells`, although the
# loop below fills only one row per module (row index = ii % numh) -- confirm
# that this row count is intended.
numh = int(np.ceil(num_neurons/numw))
outer1 = gridspec.GridSpec(1, numw)
fig = plt.figure(figsize=(np.ceil((numw*numfigs+numw-1)*1.05), np.ceil(numh*1.1)), dpi = 300)
nw = 0
numbins1 = 30        # bins per axis of every rate map
sig1 = 1             # smoothing width handed to both smoothers

# Rank (0 = highest pcorr_all_dark) of the example cell to show for each module.
mod_cells = [0, 0, 0]

outer2 = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec = outer1[nw], wspace = .4)
gs2 = gridspec.GridSpecFromSubplotSpec(numh, numfigs, subplot_spec = outer2[0],
                                       hspace = 0.2, wspace = .2)

for ii in [1, 0, 2]:
    spk, spk2 = sspk_d[ii], sspk_l[ii]
    coords_mod_dark, coords_mod_light = coords_darks[ii], coords_lights[ii]

    # Column index of the selected cell within this module's activity matrices.
    n = np.flip(np.argsort(pcorr_all_dark[ii]))[mod_cells[ii]]
    row = ii % numh

    data = []
    data_names = []

    ############ Dark toroidal #############
    ax = plt.subplot(gs2[row, 0])
    tuning = _plot_toroidal_map(ax, coords_mod_dark, spk[:, n], numbins1, sig1)
    cols, names = _map_to_columns(tuning, 'toroidal_tuning_dark_')
    data.extend(cols)
    data_names.extend(names)

    ############ Light toroidal #############
    # Fixed: this panel used to re-bin the dark data (coords_mod_dark / spk),
    # so both the plot and the exported 'toroidal_tuning_light_*' columns were
    # copies of the dark panel.  Assumes coords_mod_light is sample-aligned
    # with spk2 -- TODO confirm against the cell that builds coords_lights.
    ax = plt.subplot(gs2[row, 1])
    tuning = _plot_toroidal_map(ax, coords_mod_light, spk2[:, n], numbins1, sig1)
    cols, names = _map_to_columns(tuning, 'toroidal_tuning_light_')
    data.extend(cols)
    data_names.extend(names)

    ############ Dark spatial #############
    # NOTE(review): this is the only panel whose lower colour limit is the
    # 2.5 % entry instead of 0; kept as in the original -- confirm intent.
    ax = plt.subplot(gs2[row, 2])
    tuning = _plot_spatial_map(ax, xx1_dark, yy1_dark, spk[:, n], numbins1, sig1,
                               floor_at_percentile=True)
    cols, names = _map_to_columns(tuning, 'spatial_tuning_dark_')
    data.extend(cols)
    data_names.extend(names)

    ############ Light spatial #############
    ax = plt.subplot(gs2[row, 3])
    tuning = _plot_spatial_map(ax, xx_light1, yy_light1, spk2[:, n], numbins1, sig1)
    cols, names = _map_to_columns(tuning, 'spatial_tuning_light_')
    data.extend(cols)
    data_names.extend(names)

    # One Excel file per module holding every plotted map, column by column.
    df = pd.concat(data, ignore_index=True, axis=1)
    df.columns = data_names
    df.to_excel('Source_data/ExtFig1d_waaga_single_tuning' + str(ii) + '.xlsx', sheet_name='waaga_single_tuning' + str(ii))

plt.savefig('Figures/waaga_single_tuning_single_tuning.png', transparent = True, bbox_inches='tight', pad_inches=0.2)
plt.savefig('Figures/waaga_single_tuning_single_tuning.pdf', transparent = True, bbox_inches='tight', pad_inches=0.2)
