In [9]:
%matplotlib qt


import sys
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from rastermap import Rastermap
from scipy.stats import zscore

# spks is neurons by time
# sys.path.insert(0, '/Users/josefbitzenhofer/Documents/code/viral/data/cached_for_rastermap')
spks_noITI = np.load('/Users/josefbitzenhofer/Documents/code/viral/data/cached_for_rastermap/JB027_2025-02-24_corridor_neur_noITI.npz')["spks"]
spks = np.load('/Users/josefbitzenhofer/Documents/code/viral/data/cached_for_rastermap/JB027_2025-02-24_corridor_neur.npz')["spks"]
behaviour = np.load('/Users/josefbitzenhofer/Documents/code/viral/data/cached_for_rastermap/JB027_2025-02-24_corridor_behavior.npz')

In [10]:
# like Stringer et al. 2024 Fig 2

# fit rastermap
# n_PCs = 200
# n_clusters = 100
# locality = 0.75
# time_lag_windows = 5
# , time_bin=5
model = Rastermap(n_PCs=200, n_clusters=100, 
                  locality=1, time_lag_window=5).fit(spks_noITI)
# y = model.embedding # neurons x 1
isort = model.isort
cc_nodes = model.cc # sorted asymmetric similarity matrix

X = spks[isort]


2025-03-07 10:10:33,897 [INFO] normalizing data across axis=1
2025-03-07 10:10:33,962 [INFO] projecting out mean along axis=0
2025-03-07 10:10:34,010 [INFO] data normalized, 0.11sec
2025-03-07 10:10:34,010 [INFO] sorting activity: 642 valid samples by 21268 timepoints
2025-03-07 10:10:34,447 [INFO] n_PCs = 200 computed, 0.55sec
2025-03-07 10:10:34,527 [INFO] 99 clusters computed, time 0.63sec




2025-03-07 10:11:05,888 [INFO] clusters sorted, time 31.99sec
2025-03-07 10:11:05,951 [INFO] clusters upsampled, time 32.05sec
2025-03-07 10:11:06,093 [INFO] rastermap complete, time 32.20sec


In [11]:
# from Stringer et al. 2024
def bin1d(X, bin_size, axis=0):
    """ mean bin over axis of data with bin bin_size """
    if bin_size > 0:
        size = list(X.shape)
        Xb = X.swapaxes(0, axis)
        size_new = Xb.shape
        Xb = Xb[:size[axis]//bin_size*bin_size].reshape((size[axis]//bin_size, bin_size, *size_new[1:])).mean(axis=1)
        Xb = Xb.swapaxes(axis, 0)
        return Xb
    else:
        return X

In [12]:
# from Stringer et al. 2024
kp_colors = np.array([[0.55,0.55,0.55],
                      [0.,0.,1],
                      [0.8,0,0],
                      [1.,0.4,0.2],
                      [0,0.6,0.4],
                      [0.2,1,0.5],
                      ])

In [13]:
# print(X.shape)
# X_zscored = zscore(X)

In [14]:
# TODO: bin over neurons

bin_size = 2

### ----------- bin across embedding --------------------------------------- ###
        # if data is not None and compute_X_embedding:
        #     bin_size=self.bin_size
        #     if (bin_size==0 or n_samples < bin_size or 
        #         (bin_size == 50 and n_samples < 1000)):
        #         bin_size = max(1, n_samples // 500)
        #     self.X_embedding = zscore(bin1d(X[self.isort], bin_size, axis=0), axis=1)

X_zs = zscore(bin1d(X, bin_size, axis=0), axis=1)
# X_zs = zscore(X, axis=1)

In [15]:
corridor_starts = behaviour["corridor_starts"]
corridor_widths = behaviour["corridor_widths"]
# corridor_imgs = behaviour["corridor_imgs"]
VRpos = behaviour["VRpos"]
reward_idx = behaviour["reward_inds"]
lick_idx = behaviour["lick_inds"]
run = behaviour["run"]

In [16]:
sampling_rate = 30 # fps

In [None]:
n_trials_plot = 20
plot_start = int(corridor_starts[-n_trials_plot, 0])
to_plot = X_zs[:, plot_start:]
nTimepoints = to_plot.shape[1]
xticks = np.arange(0, nTimepoints, sampling_rate * 60) # every minute
xticks_labels = np.arange(0, len(xticks), 1)

lick_idx_plot = lick_idx[lick_idx > plot_start] - plot_start
reward_idx_plot = reward_idx[reward_idx > plot_start] - plot_start

n_features = 4

# Scaling of the rasterplot
vmin = 0.75
vmax = 2

plt.rcParams["font.family"] = "Arial"
plt.rcParams["font.size"] = 14


fig, ax = plt.subplots(nrows=n_features, ncols=1, figsize=(24, 10), gridspec_kw={"height_ratios": [25, 1, 1, 4]}, sharex=True)

# NEURONS RASTERPLOT
ax1 = ax[0]
ax1.imshow(to_plot, vmin=vmin, vmax=vmax, cmap="gray_r", aspect="auto")
ax1.set_xticks(xticks)
ax1.set_xticklabels(xticks_labels)
ax1.set_xlabel("Time (minutes)")
ax1.set_ylabel("Neurons")

# TODO: y-axis is weird

for (start, reward_condition), width in zip(corridor_starts[-n_trials_plot:, :], corridor_widths[-n_trials_plot:]):
    ax1.fill_betweenx(
        y=[0, to_plot.shape[0]],
        x1=start - plot_start,
        x2=start + width - plot_start,  
        color="forestgreen" if reward_condition == 1 else "violet",
        alpha=0.3
    )


# REWARDS
ax2 = ax[1]
ax2.scatter(
    reward_idx_plot,  # Align with the time window
    np.ones(len(reward_idx_plot)), 
    color="g",
    marker="^", 
    s=30
)
ax2.axis("off")
ax2.legend(["Reward"], loc="center left", frameon=False, handletextpad=0.2)

# LICKING
ax3 = ax[2]
ax3.scatter(
    lick_idx_plot,  # Align with the time window
    np.ones(len(lick_idx_plot)), 
    color=[1.0,0.3,0.3],
    marker=".", 
    s=30
)
ax3.axis("off")
ax3.legend(["Licks"], loc="center left", frameon=False, handletextpad=0.2)


# RUNNING SPEED
ax4 = ax[3]
ax4.fill_between(np.arange(nTimepoints), run[plot_start:], color=kp_colors[0]) #alpha=0.5)
ax4.axis("off")
ax4.legend([matplotlib.lines.Line2D([0], [0], color="none")], ["Running speed"], loc="upper left", frameon=False, handlelength=0, bbox_to_anchor=(0, 1.5))



ax1.set_xlim([0, nTimepoints])
ax2.set_xlim([0, nTimepoints])
ax3.set_xlim([0, nTimepoints])
ax4.set_xlim([0, nTimepoints])

plt.show()

In [18]:
print(len(to_plot[1]))
print(len(run[plot_start:]))

14413
14413


In [19]:
print("Reward indices:", reward_idx_plot)
print("Lick indices:", lick_idx_plot)

Reward indices: [ 2419  2420  2421  2422  2423  2424  2425  2426  2427  3733  3734  3735
  3736  3737  3738  3739  3740  3741  4985  4986  4987  4988  4989  4990
  4991  4992  4993  7992  7993  7994  7995  7996  7997  7998  7999  8000
 10777 10778 10779 10780 10781 10782 10783 10784 10785]
Lick indices: [ 2414  2415  2416  2417  2418  2419  2420  2421  2422  2423  2424  2425
  2426  2427  2428  2429  2430  2431  2432  2433  2434  2435  2436  2437
  2438  2439  2440  2441  2442  2443  2444  2445  2446  2447  2448  2449
  2450  2451  2452  2453  2454  2455  2456  2457  2458  2459  2460  2461
  2462  2463  2464  2465  2466  2467  2468  2469  2470  2471  2472  2473
  2474  2475  2476  2477  2478  2479  2480  2481  2482  2483  2484  2485
  2486  2487  2488  2489  2490  2491  2492  2493  2494  2495  2496  2497
  2498  2499  2500  2501  2502  2503  2504  2505  2506  2507  2508  2509
  2510  2511  2512  2513  2514  2515  2516  2517  2518  2519  2520  2521
  2522  2523  2524  2525  2526  2527  

In [20]:
# TODO: Add corridor stuff
# TODO: Add licks, rewards, running speed