# SpikeDataset Demo

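A quick walkthrough of loading and exploring the hippocampal spike data with `SpikeDataset`: first as a pandas DataFrame for inspection, then as TensorFlow and PyTorch batches ready for training.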

In [1]:
from dataset import SpikeDataset

## 1. Create the dataset loader

In [2]:
# Walkthrough configuration: change these to point at a different mouse or
# window length. DATASET_PATH is passed as the first positional argument of
# SpikeDataset (presumably the data directory — confirm against dataset.py).
DATASET_PATH = "dataset"
MOUSE = "M1199_PAG"
WINDOW_SIZE_MS = 108  # the loader summary reports the window in ms

ds = SpikeDataset(DATASET_PATH, mouse=MOUSE, window_size=WINDOW_SIZE_MS)
print(ds.summary())

Mouse:        M1199_PAG
Window:       108 ms
Stride:       4
Groups:       4
Channels:     [6, 4, 6, 4]
TFRec file:   M1199_PAG_stride4_win108_test.tfrec
Parquet file: M1199_PAG_stride4_win108_test.parquet


## 2. Load as a pandas DataFrame (for exploration)

In [3]:
# Pull the Parquet export into pandas, report its dimensions, and preview it.
df = ds.load_parquet()
print("Shape:", df.shape)
df.head()

Shape: (62257, 18)


Unnamed: 0,group0,group1,group2,group3,groups,indexInDat,length,pos,pos_index,time,time_behavior,speedMask,indexInDat_raw,indices0,zeroForGather,indices1,indices2,indices3
0,"[4.4905214, 4.289115, 2.397555, 4.9237323, 1.9...","[-2.7625754, -0.041320268, 1.1797833, 0.106313...","[-0.3823218, 0.26588744, -0.06600025, -2.07445...","[-11.746682, -9.458845, -10.755431, -9.787609,...","[1, 0, 2, 2, 0, 1, 0, 3, 1, 3, 0, 3, 2, 1, 3, ...","[216059597, 216059611, 216059622, 216059703, 2...",[73],"[0.2380991101561953, 0.8049437338425552, -2.89...",[156622],[10803.034],[10803.116],[False],"[216059597, 216059611, 216059622, 216059703, 2...","[0, 1, 0, 0, 2, 0, 3, 0, 0, 0, 4, 0, 0, 0, 0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[1, 0, 0, 0, 0, 2, 0, 0, 3, 0, 0, 0, 0, 4, 0, ...","[0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 3, 0, 0, 4, ..."
1,"[4.744651, -0.46449965, -4.9791293, -1.4566978...","[-6.5401173, -9.58557, -5.05113, -4.1449285, -...","[1.3171366, 3.385435, 5.466522, 3.262233, 2.73...","[-6.449205, -3.7243226, -0.94095474, 4.024223,...","[0, 2, 3, 3, 1, 0, 3, 3, 1, 3, 3, 3, 1, 0, 2, ...","[216060331, 216060361, 216060368, 216060406, 2...",[66],"[0.2380991101561953, 0.8049437338425552, -2.89...",[156622],[10803.073],[10803.116],[False],"[216060331, 216060361, 216060368, 216060406, 2...","[1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, ...","[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, ...","[0, 0, 1, 2, 0, 0, 3, 4, 0, 5, 6, 7, 0, 0, 0, ..."
2,"[-4.867984, -2.0415728, -2.2092433, -2.701678,...","[-4.509027, -2.4044425, -4.283478, -8.628607, ...","[-3.866517, -3.463913, -1.9244556, -2.5225954,...","[-8.362156, -7.7195864, -5.204652, -9.035585, ...","[1, 0, 2, 1, 0, 1, 3, 1, 1, 0, 2, 2, 3, 1, 2, ...","[216061051, 216061065, 216061069, 216061083, 2...",[78],"[0.22347367133862373, 0.800746161144787, -3.03...",[156623],[10803.101],[10803.186],[True],"[216061051, 216061065, 216061069, 216061083, 2...","[0, 1, 0, 0, 2, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[1, 0, 0, 2, 0, 3, 0, 4, 5, 0, 0, 0, 0, 6, 0, ...","[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 3, 0, 0, 4, ...","[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0, ..."
3,"[-7.8074355, -7.983128, -7.1518755, -12.03488,...","[-8.327696, -10.350467, -11.630261, -10.267146...","[1.3819051, 1.580239, 3.6971939, 0.8405953, -1...","[-9.0769, -12.542657, -14.468764, -10.566613, ...","[0, 3, 1, 2, 1, 0, 2, 0, 0, 2, 3, 1, 0, 0, 3, ...","[216061789, 216061794, 216061818, 216061830, 2...",[68],"[0.22347367133862373, 0.800746161144787, -3.03...",[156623],[10803.142],[10803.186],[True],"[216061789, 216061794, 216061818, 216061830, 2...","[1, 0, 0, 0, 0, 2, 0, 3, 4, 0, 0, 0, 5, 6, 0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0, 0, 1, 0, 2, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, ...","[0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 0, 0, 0, 0, 0, ...","[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 3, ..."
4,"[-5.3024197, -4.027823, -3.3042214, -7.5992155...","[7.775497, 2.7831135, 2.6834428, 1.9097018, 4....","[-3.9249418, -4.999227, -1.9690266, -3.817349,...","[6.3386106, 1.3817701, -4.808083, -1.6837108, ...","[0, 3, 0, 3, 2, 3, 1, 3, 2, 0, 3, 2, 0, 3, 3, ...","[216062512, 216062533, 216062582, 216062597, 2...",[64],"[0.2061004002545895, 0.8016716472082892, -3.08...",[156624],[10803.176],[10803.256],[True],"[216062512, 216062533, 216062582, 216062597, 2...","[1, 0, 2, 0, 0, 0, 0, 0, 0, 3, 0, 0, 4, 0, 0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 3, 0, 0, 0, ...","[0, 1, 0, 2, 0, 3, 0, 4, 0, 0, 5, 0, 0, 6, 7, ..."


In [4]:
list(df.columns)  # every column name in the exported DataFrame

['group0',
 'group1',
 'group2',
 'group3',
 'groups',
 'indexInDat',
 'length',
 'pos',
 'pos_index',
 'time',
 'time_behavior',
 'speedMask',
 'indexInDat_raw',
 'indices0',
 'zeroForGather',
 'indices1',
 'indices2',
 'indices3']

## 3. Load as a raw TensorFlow dataset (unbatched records)

In [5]:
raw_tf = ds.load_raw_tf(use_speed_mask=True)

# Inspect the first unbatched record: one line per feature with its shape
# and dtype.
# NOTE(review): speedMask comes back as a string tensor in the output below;
# confirm that is what the loader intends.
sample = next(iter(raw_tf))
for feature_name, tensor in sample.items():
    print(f"{feature_name:20s} shape={tensor.shape}  dtype={tensor.dtype}")

group0               shape=(19, 6, 32)  dtype=<dtype: 'float32'>
group1               shape=(21, 4, 32)  dtype=<dtype: 'float32'>
group2               shape=(22, 6, 32)  dtype=<dtype: 'float32'>
group3               shape=(16, 4, 32)  dtype=<dtype: 'float32'>
groups               shape=(78,)  dtype=<dtype: 'int64'>
indexInDat           shape=(78,)  dtype=<dtype: 'int64'>
indices0             shape=(78,)  dtype=<dtype: 'int32'>
indices1             shape=(78,)  dtype=<dtype: 'int32'>
indices2             shape=(78,)  dtype=<dtype: 'int32'>
indices3             shape=(78,)  dtype=<dtype: 'int32'>
pos                  shape=(4,)  dtype=<dtype: 'float32'>
speedMask            shape=(1,)  dtype=<dtype: 'string'>
length               shape=()  dtype=<dtype: 'int64'>
pos_index            shape=()  dtype=<dtype: 'int64'>
time                 shape=()  dtype=<dtype: 'float32'>
time_behavior        shape=()  dtype=<dtype: 'float32'>


2026-02-25 22:53:35.188612: I tensorflow/core/kernels/data/tf_record_dataset_op.cc:370] TFRecordDataset `buffer_size` is unspecified, default to 262144


## 4. Get train/val splits ready for training

In [6]:
train_ds, val_ds = ds.get_tf_dataset(batch_size=256, val_split=0.2)

# Inspect a single training batch. The tf_ prefix keeps these TensorFlow
# tensors distinct from the PyTorch batch loaded in the PyTorch section.
# NOTE(review): ds.summary() lists a *_test TFRecord file; confirm these
# train/val splits are meant to be drawn from it.
tf_inputs, tf_targets = next(iter(train_ds))
print("Targets (mouse position) shape:", tf_targets.shape)
print("\nInput keys:")
for feature_name, batch_tensor in tf_inputs.items():
    print(f"  {feature_name:20s} shape={batch_tensor.shape}")

2026-02-25 22:53:43.267118: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Targets (mouse position) shape: (256, 2)

Input keys:
  group0               shape=(256, 29, 6, 32)
  group1               shape=(256, 29, 4, 32)
  group2               shape=(256, 35, 6, 32)
  group3               shape=(256, 25, 4, 32)
  groups               shape=(256, 107)
  indexInDat           shape=(256, 107)
  indices0             shape=(256, 107)
  indices1             shape=(256, 107)
  indices2             shape=(256, 107)
  indices3             shape=(256, 107)
  length               shape=(256,)
  pos_index            shape=(256,)
  time                 shape=(256,)
  time_behavior        shape=(256,)


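## 5. Using with PyTorch

`get_pytorch_batches` returns a training and a validation batch source; calling one gives an iterator over `(inputs, targets)` batches of torch tensors.
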
In [8]:
train_gen, val_gen = ds.get_pytorch_batches(batch_size=256, val_split=0.2)

# Calling train_gen() yields a fresh batch iterator; take only its first
# (inputs, targets) pair and report shapes and dtypes of the torch tensors.
inputs, targets = next(train_gen())
print(f"targets: {targets.shape}, dtype={targets.dtype}")
print("\ninput keys:")
for feature_name, tensor in inputs.items():
    print(f"  {feature_name:20s} shape={tensor.shape}  dtype={tensor.dtype}")

2026-02-25 22:54:49.389663: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


targets: torch.Size([256, 2]), dtype=torch.float32

input keys:
  group0               shape=torch.Size([256, 29, 6, 32])  dtype=torch.float32
  group1               shape=torch.Size([256, 27, 4, 32])  dtype=torch.float32
  group2               shape=torch.Size([256, 37, 6, 32])  dtype=torch.float32
  group3               shape=torch.Size([256, 27, 4, 32])  dtype=torch.float32
  groups               shape=torch.Size([256, 108])  dtype=torch.int64
  indexInDat           shape=torch.Size([256, 108])  dtype=torch.int64
  indices0             shape=torch.Size([256, 108])  dtype=torch.int32
  indices1             shape=torch.Size([256, 108])  dtype=torch.int32
  indices2             shape=torch.Size([256, 108])  dtype=torch.int32
  indices3             shape=torch.Size([256, 108])  dtype=torch.int32
  length               shape=torch.Size([256])  dtype=torch.int64
  pos_index            shape=torch.Size([256])  dtype=torch.int64
  time                 shape=torch.Size([256])  dtype=torch.fl