In [1]:
import numpy as np
import sys
import os

# Append the current working directory to the module search path so that
# modules located next to the notebook can be imported.
# NOTE(review): no local module is imported in the visible cells; confirm this is still needed.
sys.path.append(os.getcwd())
import matplotlib.pyplot as plt
import scipy.io as sio

# Based on the code from: https://github.com/zhd96/pi-vae/blob/main/examples/pi-vae_rat_data.ipynb
# Zhou, D., Wei, X.
# Learning identifiable and interpretable latent models of high-dimensional neural activity using pi-VAE.
# NeurIPS 2020. https://arxiv.org/abs/2011.04798

# The preprocessed .mat file can be downloaded from the pi-VAE authors:
# https://drive.google.com/drive/folders/1lUVX1IvKZmw-uL2UWLxgx4NJ62YbCwMo?usp=sharing

In [2]:
# Location of the preprocessed hc-11 recording (see the download link in the first cell).
dataset_name = "../../data_untracked/Achilles_data.mat"

# Seed NumPy's legacy global RNG so any stochastic step is reproducible.
# NOTE(review): nothing in the visible cells draws random numbers.
np.random.seed(666)

# Read the MATLAB file into a dict mapping variable names to arrays.
rat_data = sio.loadmat(file_name=dataset_name)

In [3]:
## load trial information
# Presumably trial boundaries given as sample indices (the split cell below uses
# an entry of this array to slice spikes/locs). ``loadmat`` returns MATLAB vectors
# as 2-D arrays; ``np.ravel`` flattens them and, unlike ``[0]``, also handles an
# (N, 1) column vector (for a (1, N) row vector the result is identical).
idx_split = np.ravel(rat_data["trial"])
## load spike data
# Rows are time samples, columns are neurons (see how the later cells index it).
spikes = rat_data["spikes"]
## load locations
locs = np.ravel(rat_data["loc"])

# spikes and locs are sliced with the same sample index below, so they must align.
assert spikes.shape[0] == locs.shape[0], (
    f"spikes has {spikes.shape[0]} time samples but loc has {locs.shape[0]}"
)

In [4]:
# We do not split the recording into trials: the model is trained on this one
# long trajectory rather than on trials presented one by one.

# Note, however, that it is entirely possible to train the model on trials instead,
# or on data in which all spikes recorded during the maze epoch are used,
# rather than sampling spikes only at time points where location data also exist, as is done here.
# Sampling at the location time points effectively gives fs of approximately 40 Hz,
# the sampling rate of the location data,
# as given in https://crcns.org/data-sets/hc/hc-11/about-hc-11

# Discard silent and very-low-activity neurons: keep only the columns (neurons)
# whose total spike count over the whole recording exceeds MIN_TOTAL_SPIKES.
# Re-running this cell is harmless: every kept column already passes the threshold.
MIN_TOTAL_SPIKES = 230
spikes = spikes[:, spikes.sum(axis=0) > MIN_TOTAL_SPIKES]
assert spikes.shape[1] > 0, f"no neuron has more than {MIN_TOTAL_SPIKES} spikes"

In [None]:
# (time samples, retained neurons) after the filtering above.
np.shape(spikes)

In [None]:
# Trial indices the train/test split below is based on (bare expression for rich display).
idx_split

In [None]:
# Split into train and test sets at an entry of idx_split (presumably a trial
# boundary): the split point is the entry located TRAIN_FRACTION of the way
# through idx_split, so roughly the first 71% of trials go to training.
TRAIN_FRACTION = 0.71
OUTPUT_DIR = "../../data_untracked"

# int() so slicing works even if the .mat stored the indices as doubles.
train_size = int(idx_split[int(idx_split.shape[0] * TRAIN_FRACTION)])
train_spikes = spikes[:train_size]
test_spikes = spikes[train_size:]
train_locs = locs[:train_size]
test_locs = locs[train_size:]

# Both parts must be non-empty for training and evaluation to make sense.
assert 0 < train_size < spikes.shape[0], f"degenerate split point {train_size}"

# Spike arrays are saved transposed, i.e. as (neurons, time samples).
np.save(f"{OUTPUT_DIR}/train_spikes_hpc11.npy", train_spikes.T)
np.save(f"{OUTPUT_DIR}/test_spikes_hpc11.npy", test_spikes.T)

# NOTE(review): the location files lack the "_hpc11" suffix used for the spike
# files; names kept unchanged because downstream code may load them by name.
np.save(f"{OUTPUT_DIR}/locs.npy", locs)
np.save(f"{OUTPUT_DIR}/train_locs.npy", train_locs)
np.save(f"{OUTPUT_DIR}/test_locs.npy", test_locs)