# SpikeDataset Demo

Quick walkthrough of how to load and explore the hippocampal spike data.

In [5]:
from dataset import SpikeDataset

## 1. Create the dataset loader

In [6]:
ds = SpikeDataset("dataset", mouse="M1199_PAG", window_size=108)
print(ds.summary())

Mouse:        M1199_PAG
Window:       108 ms
Stride:       4
Groups:       4
Channels:     [6, 4, 6, 4]
TFRec file:   M1199_PAG_stride4_win108_test.tfrec
Parquet file: M1199_PAG_stride4_win108_test.parquet


## 2. Explore the data (lazy — no memory crash)

In [7]:
df = ds.load_parquet()  # instant — reads only metadata, zero data in memory
print(f"Shape: {df.shape}")
df

Shape: (62257, 18)


Unnamed: 0,groups,indexInDat,length,pos,pos_index,...,indices0,zeroForGather,indices1,indices2,indices3
0,"[1, 0, 2, 2, 0, 1, 0, 3, 1, 3, 0, 3, 2, 1, 3, 2, 1, 3, 1, 0, 2, 0, 3, 0, 2, 3, 1, 3, 2, 0, 2, 3, 3, 1, 0, 3, 3, 1, 3, 3, 3, 1, 0, 2, 1, 0, 1, 3, 1, 1, 0, 2, 2, 3, 1, 2, 1, 0, 2, 2, 2, 0, 1, 2, 1, 2, 0, 1, 3, 1, 2, 1, 3]","[216059597, 216059611, 216059622, 216059703, 216059773, 216059776, 216059814, 216059823, 216059836, 216059864, 216059866, 216059892, 216059912, 216059915, 216059944, 216059958, 216060024, 216060036, 216060042, 216060049, 216060051, 216060113, 216060113, 216060202, 216060229, 216060238, 216060245, 216060280, 216060293, 216060331, 216060361, 216060368, 216060406, 216060414, 216060466, 216060500, 216060526, 216060593, 216060620, 216060704, 216061042, 216061051, 216061065, 216061069, 216061083, 216061110, 216061117, 216061129, 216061144, 216061174, 216061225, 216061265, 216061293, 216061308, 216061318, 216061375, 216061440, 216061449, 216061460, 216061481, 216061513, 216061517, 216061570, 216061576, 216061613, 216061648, 216061660, 216061663, 216061671, 216061701, 216061702, 216061747, 216061754]",[73],"[0.2380991101561953, 0.8049437338425552, -2.8922485869289143, 0.11309178790299763]",[156622],...,"[0, 1, 0, 0, 2, 0, 3, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 6, 0, 7, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 9, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 11, 0, 0, 0, 0, 12, 0, 0, 0, 0, 0, 0, 13, 0, 0, 0, 14, 0, 0, 0, 0, 15, 0, 0, 0, 0, 0, 0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[1, 0, 0, 0, 0, 2, 0, 0, 3, 0, 0, 0, 0, 4, 0, 0, 5, 0, 6, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 9, 0, 0, 0, 10, 0, 0, 11, 0, 12, 0, 13, 14, 0, 0, 0, 0, 15, 0, 16, 0, 0, 0, 0, 0, 17, 0, 18, 0, 0, 19, 0, 20, 0, 21, 0]","[0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 4, 0, 0, 0, 0, 5, 0, 0, 0, 6, 0, 0, 0, 7, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 0, 0, 0, 0, 0, 0, 0, 10, 11, 0, 0, 12, 0, 0, 13, 14, 15, 0, 0, 16, 0, 17, 0, 0, 0, 0, 18, 0, 0]","[0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 3, 0, 0, 4, 0, 0, 5, 0, 0, 0, 0, 6, 0, 0, 7, 0, 8, 0, 0, 0, 9, 10, 0, 0, 11, 12, 0, 13, 14, 15, 0, 0, 0, 0, 0, 0, 16, 0, 0, 0, 0, 0, 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18, 0, 0, 0, 19]"
1,"[0, 2, 3, 3, 1, 0, 3, 3, 1, 3, 3, 3, 1, 0, 2, 1, 0, 1, 3, 1, 1, 0, 2, 2, 3, 1, 2, 1, 0, 2, 2, 2, 0, 1, 2, 1, 2, 0, 1, 3, 1, 2, 1, 3, 0, 3, 1, 2, 1, 0, 2, 0, 0, 2, 3, 1, 0, 0, 3, 1, 2, 3, 3, 2, 1, 0]","[216060331, 216060361, 216060368, 216060406, 216060414, 216060466, 216060500, 216060526, 216060593, 216060620, 216060704, 216061042, 216061051, 216061065, 216061069, 216061083, 216061110, 216061117, 216061129, 216061144, 216061174, 216061225, 216061265, 216061293, 216061308, 216061318, 216061375, 216061440, 216061449, 216061460, 216061481, 216061513, 216061517, 216061570, 216061576, 216061613, 216061648, 216061660, 216061663, 216061671, 216061701, 216061702, 216061747, 216061754, 216061789, 216061794, 216061818, 216061830, 216061836, 216061840, 216062034, 216062048, 216062068, 216062080, 216062084, 216062105, 216062129, 216062187, 216062204, 216062256, 216062276, 216062278, 216062301, 216062318, 216062324, 216062402]",[66],"[0.2380991101561953, 0.8049437338425552, -2.8922485869289143, 0.11309178790299763]",[156622],...,"[1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 4, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 7, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 9, 0, 0, 0, 0, 10, 0, 11, 12, 0, 0, 0, 13, 14, 0, 0, 0, 0, 0, 0, 0, 15]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 4, 0, 5, 0, 6, 7, 0, 0, 0, 0, 8, 0, 9, 0, 0, 0, 0, 0, 10, 0, 11, 0, 0, 12, 0, 13, 0, 14, 0, 0, 0, 15, 0, 16, 0, 0, 0, 0, 0, 0, 17, 0, 0, 0, 18, 0, 0, 0, 0, 19, 0]","[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 4, 0, 0, 5, 0, 0, 6, 7, 8, 0, 0, 9, 0, 10, 0, 0, 0, 0, 11, 0, 0, 0, 0, 0, 12, 0, 0, 13, 0, 0, 14, 0, 0, 0, 0, 0, 0, 15, 0, 0, 16, 0, 0]","[0, 0, 1, 2, 0, 0, 3, 4, 0, 5, 6, 7, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 11, 0, 12, 0, 0, 0, 0, 0, 0, 0, 0, 13, 0, 0, 0, 14, 0, 0, 15, 16, 0, 0, 0]"
2,"[1, 0, 2, 1, 0, 1, 3, 1, 1, 0, 2, 2, 3, 1, 2, 1, 0, 2, 2, 2, 0, 1, 2, 1, 2, 0, 1, 3, 1, 2, 1, 3, 0, 3, 1, 2, 1, 0, 2, 0, 0, 2, 3, 1, 0, 0, 3, 1, 2, 3, 3, 2, 1, 0, 1, 0, 3, 0, 3, 2, 3, 1, 3, 2, 0, 3, 2, 0, 3, 3, 2, 1, 2, 1, 2, 2, 0, 0]","[216061051, 216061065, 216061069, 216061083, 216061110, 216061117, 216061129, 216061144, 216061174, 216061225, 216061265, 216061293, 216061308, 216061318, 216061375, 216061440, 216061449, 216061460, 216061481, 216061513, 216061517, 216061570, 216061576, 216061613, 216061648, 216061660, 216061663, 216061671, 216061701, 216061702, 216061747, 216061754, 216061789, 216061794, 216061818, 216061830, 216061836, 216061840, 216062034, 216062048, 216062068, 216062080, 216062084, 216062105, 216062129, 216062187, 216062204, 216062256, 216062276, 216062278, 216062301, 216062318, 216062324, 216062402, 216062496, 216062512, 216062533, 216062582, 216062597, 216062621, 216062626, 216062627, 216062645, 216062661, 216062664, 216062684, 216062686, 216062690, 216062717, 216062742, 216062759, 216062774, 216062800, 216062812, 216062827, 216062963, 216063082, 216063182]",[78],"[0.22347367133862373, 0.800746161144787, -3.030559578649006, 0.12442088996867252]",[156623],...,"[0, 1, 0, 0, 2, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 8, 0, 9, 10, 0, 0, 0, 11, 12, 0, 0, 0, 0, 0, 0, 0, 13, 0, 14, 0, 15, 0, 0, 0, 0, 0, 0, 16, 0, 0, 17, 0, 0, 0, 0, 0, 0, 0, 0, 18, 19]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[1, 0, 0, 2, 0, 3, 0, 4, 5, 0, 0, 0, 0, 6, 0, 7, 0, 0, 0, 0, 0, 8, 0, 9, 0, 0, 10, 0, 11, 0, 12, 0, 0, 0, 13, 0, 14, 0, 0, 0, 0, 0, 0, 15, 0, 0, 0, 16, 0, 0, 0, 0, 17, 0, 18, 0, 0, 0, 0, 0, 0, 19, 0, 0, 0, 0, 0, 0, 0, 0, 0, 20, 0, 21, 0, 0, 0, 0]","[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 3, 0, 0, 4, 0, 0, 5, 6, 7, 0, 0, 8, 0, 9, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 11, 0, 0, 12, 0, 0, 13, 0, 0, 0, 0, 0, 0, 14, 0, 0, 15, 0, 0, 0, 0, 0, 0, 0, 16, 0, 0, 0, 17, 0, 0, 18, 0, 0, 0, 19, 0, 20, 0, 21, 22, 0, 0]","[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 4, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 7, 0, 0, 8, 9, 0, 0, 0, 0, 0, 10, 0, 11, 0, 12, 0, 13, 0, 0, 14, 0, 0, 15, 16, 0, 0, 0, 0, 0, 0, 0, 0]"
3,"[0, 3, 1, 2, 1, 0, 2, 0, 0, 2, 3, 1, 0, 0, 3, 1, 2, 3, 3, 2, 1, 0, 1, 0, 3, 0, 3, 2, 3, 1, 3, 2, 0, 3, 2, 0, 3, 3, 2, 1, 2, 1, 2, 2, 0, 0, 0, 1, 3, 0, 1, 3, 3, 3, 2, 1, 3, 1, 0, 2, 3, 1, 0, 3, 2, 1, 2, 0]","[216061789, 216061794, 216061818, 216061830, 216061836, 216061840, 216062034, 216062048, 216062068, 216062080, 216062084, 216062105, 216062129, 216062187, 216062204, 216062256, 216062276, 216062278, 216062301, 216062318, 216062324, 216062402, 216062496, 216062512, 216062533, 216062582, 216062597, 216062621, 216062626, 216062627, 216062645, 216062661, 216062664, 216062684, 216062686, 216062690, 216062717, 216062742, 216062759, 216062774, 216062800, 216062812, 216062827, 216062963, 216063082, 216063182, 216063389, 216063402, 216063441, 216063496, 216063515, 216063535, 216063588, 216063649, 216063668, 216063676, 216063683, 216063709, 216063719, 216063755, 216063810, 216063814, 216063817, 216063831, 216063842, 216063889, 216063891, 216063940]",[68],"[0.22347367133862373, 0.800746161144787, -3.030559578649006, 0.12442088996867252]",[156623],...,"[1, 0, 0, 0, 0, 2, 0, 3, 4, 0, 0, 0, 5, 6, 0, 0, 0, 0, 0, 0, 0, 7, 0, 8, 0, 9, 0, 0, 0, 0, 0, 0, 10, 0, 0, 11, 0, 0, 0, 0, 0, 0, 0, 0, 12, 13, 14, 0, 0, 15, 0, 0, 0, 0, 0, 0, 0, 0, 16, 0, 0, 0, 17, 0, 0, 0, 0, 18]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0, 0, 1, 0, 2, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0, 0, 5, 0, 6, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 9, 0, 0, 0, 0, 0, 10, 0, 0, 11, 0, 0, 0, 0, 12, 0, 13, 0, 0, 0, 14, 0, 0, 0, 15, 0, 0]","[0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 0, 0, 0, 0, 0, 0, 4, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 7, 0, 0, 8, 0, 0, 0, 9, 0, 10, 0, 11, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 0, 0, 0, 0, 14, 0, 0, 0, 0, 15, 0, 16, 0]","[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 4, 5, 0, 0, 0, 0, 0, 6, 0, 7, 0, 8, 0, 9, 0, 0, 10, 0, 0, 11, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 0, 0, 14, 15, 16, 0, 0, 17, 0, 0, 0, 18, 0, 0, 19, 0, 0, 0, 0]"
4,"[0, 3, 0, 3, 2, 3, 1, 3, 2, 0, 3, 2, 0, 3, 3, 2, 1, 2, 1, 2, 2, 0, 0, 0, 1, 3, 0, 1, 3, 3, 3, 2, 1, 3, 1, 0, 2, 3, 1, 0, 3, 2, 1, 2, 0, 0, 2, 2, 0, 3, 2, 2, 3, 0, 2, 3, 2, 3, 3, 2, 3, 2, 0, 3]","[216062512, 216062533, 216062582, 216062597, 216062621, 216062626, 216062627, 216062645, 216062661, 216062664, 216062684, 216062686, 216062690, 216062717, 216062742, 216062759, 216062774, 216062800, 216062812, 216062827, 216062963, 216063082, 216063182, 216063389, 216063402, 216063441, 216063496, 216063515, 216063535, 216063588, 216063649, 216063668, 216063676, 216063683, 216063709, 216063719, 216063755, 216063810, 216063814, 216063817, 216063831, 216063842, 216063889, 216063891, 216063940, 216063969, 216063970, 216064001, 216064043, 216064059, 216064114, 216064188, 216064226, 216064240, 216064300, 216064312, 216064331, 216064373, 216064478, 216064488, 216064574, 216064591, 216064602, 216064665]",[64],"[0.2061004002545895, 0.8016716472082892, -3.089633952301482, 0.1408563180950702]",[156624],...,"[1, 0, 2, 0, 0, 0, 0, 0, 0, 3, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 7, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 9, 0, 0, 0, 10, 0, 0, 0, 0, 11, 12, 0, 0, 13, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 0, 0, 0, 15, 0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 3, 0, 0, 0, 0, 0, 4, 0, 0, 5, 0, 0, 0, 0, 6, 0, 7, 0, 0, 0, 8, 0, 0, 0, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 3, 0, 0, 0, 4, 0, 5, 0, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 9, 0, 0, 0, 0, 10, 0, 11, 0, 0, 12, 13, 0, 0, 14, 15, 0, 0, 16, 0, 17, 0, 0, 18, 0, 19, 0, 0]","[0, 1, 0, 2, 0, 3, 0, 4, 0, 0, 5, 0, 0, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 9, 10, 11, 0, 0, 12, 0, 0, 0, 13, 0, 0, 14, 0, 0, 0, 0, 0, 0, 0, 0, 15, 0, 0, 16, 0, 0, 17, 0, 18, 19, 0, 20, 0, 0, 21]"


In [8]:
# preview first rows (heavy spike waveform columns excluded for speed)
df.head()

Unnamed: 0,groups,indexInDat,length,pos,pos_index,time,time_behavior,speedMask,indexInDat_raw,indices0,zeroForGather,indices1,indices2,indices3
0,"[1, 0, 2, 2, 0, 1, 0, 3, 1, 3, 0, 3, 2, 1, 3, ...","[216059597, 216059611, 216059622, 216059703, 2...",[73],"[0.2380991101561953, 0.8049437338425552, -2.89...",[156622],[10803.034],[10803.116],[False],"[216059597, 216059611, 216059622, 216059703, 2...","[0, 1, 0, 0, 2, 0, 3, 0, 0, 0, 4, 0, 0, 0, 0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[1, 0, 0, 0, 0, 2, 0, 0, 3, 0, 0, 0, 0, 4, 0, ...","[0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 3, 0, 0, 4, ..."
1,"[0, 2, 3, 3, 1, 0, 3, 3, 1, 3, 3, 3, 1, 0, 2, ...","[216060331, 216060361, 216060368, 216060406, 2...",[66],"[0.2380991101561953, 0.8049437338425552, -2.89...",[156622],[10803.073],[10803.116],[False],"[216060331, 216060361, 216060368, 216060406, 2...","[1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, ...","[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, ...","[0, 0, 1, 2, 0, 0, 3, 4, 0, 5, 6, 7, 0, 0, 0, ..."
2,"[1, 0, 2, 1, 0, 1, 3, 1, 1, 0, 2, 2, 3, 1, 2, ...","[216061051, 216061065, 216061069, 216061083, 2...",[78],"[0.22347367133862373, 0.800746161144787, -3.03...",[156623],[10803.101],[10803.186],[True],"[216061051, 216061065, 216061069, 216061083, 2...","[0, 1, 0, 0, 2, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[1, 0, 0, 2, 0, 3, 0, 4, 5, 0, 0, 0, 0, 6, 0, ...","[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 3, 0, 0, 4, ...","[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0, ..."
3,"[0, 3, 1, 2, 1, 0, 2, 0, 0, 2, 3, 1, 0, 0, 3, ...","[216061789, 216061794, 216061818, 216061830, 2...",[68],"[0.22347367133862373, 0.800746161144787, -3.03...",[156623],[10803.142],[10803.186],[True],"[216061789, 216061794, 216061818, 216061830, 2...","[1, 0, 0, 0, 0, 2, 0, 3, 4, 0, 0, 0, 5, 6, 0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0, 0, 1, 0, 2, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, ...","[0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 0, 0, 0, 0, 0, ...","[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 3, ..."
4,"[0, 3, 0, 3, 2, 3, 1, 3, 2, 0, 3, 2, 0, 3, 3, ...","[216062512, 216062533, 216062582, 216062597, 2...",[64],"[0.2061004002545895, 0.8016716472082892, -3.08...",[156624],[10803.176],[10803.256],[True],"[216062512, 216062533, 216062582, 216062597, 2...","[1, 0, 2, 0, 0, 0, 0, 0, 0, 3, 0, 0, 4, 0, 0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 3, 0, 0, 0, ...","[0, 1, 0, 2, 0, 3, 0, 4, 0, 0, 5, 0, 0, 6, 7, ..."


In [9]:
# see per-column sizes — heavy columns (group0-3) are the spike waveforms
df.column_info()

Unnamed: 0_level_0,type,compressed_mb,heavy
column,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
group0,list<element: float>,560.9,True
group1,list<element: float>,341.6,True
group2,list<element: float>,737.2,True
group3,list<element: float>,316.3,True
groups,list<element: int64>,0.9,False
indexInDat,list<element: int64>,6.5,False
length,list<element: int64>,0.1,False
pos,list<element: double>,1.5,False
pos_index,list<element: int64>,0.3,False
time,list<element: float>,0.4,False


In [10]:
# select specific columns — stays lazy until you call .head() or .to_pandas()
df[["time", "pos", "speedMask"]].head(10)

Unnamed: 0,time,pos,speedMask
0,[10803.034],"[0.2380991101561953, 0.8049437338425552, -2.89...",[False]
1,[10803.073],"[0.2380991101561953, 0.8049437338425552, -2.89...",[False]
2,[10803.101],"[0.22347367133862373, 0.800746161144787, -3.03...",[True]
3,[10803.142],"[0.22347367133862373, 0.800746161144787, -3.03...",[True]
4,[10803.176],"[0.2061004002545895, 0.8016716472082892, -3.08...",[True]
5,[10803.212],"[0.2061004002545895, 0.8016716472082892, -3.08...",[True]
6,[10803.265],"[0.18725677468503407, 0.8041594979880731, -3.0...",[True]
7,[10803.306],"[0.18725677468503407, 0.8041594979880731, -3.0...",[True]
8,[10803.342],"[0.17120980423031826, 0.8085996553933661, -3.0...",[True]
9,[10803.368],"[0.17120980423031826, 0.8085996553933661, -3.0...",[True]


## 3. Load as a TensorFlow dataset (for training)

In [11]:
raw_tf = ds.load_raw_tf(use_speed_mask=True)

# grab one sample to see what the data looks like
sample = next(iter(raw_tf))
for key, val in sample.items():
    print(f"{key:20s} shape={val.shape}  dtype={val.dtype}")

group0               shape=(19, 6, 32)  dtype=<dtype: 'float32'>
group1               shape=(21, 4, 32)  dtype=<dtype: 'float32'>
group2               shape=(22, 6, 32)  dtype=<dtype: 'float32'>
group3               shape=(16, 4, 32)  dtype=<dtype: 'float32'>
groups               shape=(78,)  dtype=<dtype: 'int64'>
indexInDat           shape=(78,)  dtype=<dtype: 'int64'>
indices0             shape=(78,)  dtype=<dtype: 'int32'>
indices1             shape=(78,)  dtype=<dtype: 'int32'>
indices2             shape=(78,)  dtype=<dtype: 'int32'>
indices3             shape=(78,)  dtype=<dtype: 'int32'>
pos                  shape=(4,)  dtype=<dtype: 'float32'>
speedMask            shape=(1,)  dtype=<dtype: 'string'>
length               shape=()  dtype=<dtype: 'int64'>
pos_index            shape=()  dtype=<dtype: 'int64'>
time                 shape=()  dtype=<dtype: 'float32'>
time_behavior        shape=()  dtype=<dtype: 'float32'>


2026-02-26 20:29:47.285726: I tensorflow/core/kernels/data/tf_record_dataset_op.cc:370] TFRecordDataset `buffer_size` is unspecified, default to 262144


## 4. Get train/val splits ready for training

In [12]:
train_ds, val_ds = ds.get_tf_dataset(batch_size=256, val_split=0.2)

# peek at one batch
inputs, targets = next(iter(train_ds))
print("Targets (mouse position) shape:", targets.shape)
print("\nInput keys:")
for key, val in inputs.items():
    print(f"  {key:20s} shape={val.shape}")

2026-02-26 20:30:49.274922: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Targets (mouse position) shape: (256, 2)

Input keys:
  group0               shape=(256, 33, 6, 32)
  group1               shape=(256, 33, 4, 32)
  group2               shape=(256, 43, 6, 32)
  group3               shape=(256, 26, 4, 32)
  groups               shape=(256, 124)
  indexInDat           shape=(256, 124)
  indices0             shape=(256, 124)
  indices1             shape=(256, 124)
  indices2             shape=(256, 124)
  indices3             shape=(256, 124)
  length               shape=(256,)
  pos_index            shape=(256,)
  time                 shape=(256,)
  time_behavior        shape=(256,)


## 5. Fast PyTorch loading (precomputed cache)


In [None]:
from torch.utils.data import DataLoader

# One-time preprocessing (parses TFRecords, saves mmap cache — skips if cache exists)
# Pass force=True to rebuild if the cache is stale (e.g. missing speed_mask.npy)
ds.preprocess()

# Load cached data — speed mask is ON by default (only moving samples)
train_ds, val_ds = ds.get_pytorch_dataset(val_split=0.2)
print(f"Train: {len(train_ds)}, Val: {len(val_ds)}")

# Create DataLoader with padding collation
train_loader = DataLoader(train_ds, batch_size=256, shuffle=True,
                          collate_fn=ds.collate_fn, num_workers=0)

inputs, targets = next(iter(train_loader))
print(f"\ntargets: {targets.shape}, dtype={targets.dtype}")
print("\ninput keys:")
for key, val in inputs.items():
    print(f"  {key:20s} shape={val.shape}  dtype={val.dtype}")

## 6. Simple training loop

In [None]:
import torch
import torch.nn as nn

# --- tiny MLP baseline: mean-pool each group's waveforms, concat, predict (x, y) ---
class SimpleDecoder(nn.Module):
    def __init__(self, n_groups, n_channels_per_group, waveform_len=32, hidden=128):
        super().__init__()
        input_dim = sum(ch * waveform_len for ch in n_channels_per_group)
        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 2),
        )
        self.n_groups = n_groups

    def forward(self, inputs):
        parts = []
        for g in range(self.n_groups):
            wv = inputs[f"group{g}"]           # (B, max_spikes, n_ch, 32)
            mask = (wv != -1).any(-1).any(-1)  # (B, max_spikes)
            wv = wv.clamp(min=0)
            counts = mask.sum(dim=1, keepdim=True).clamp(min=1).unsqueeze(-1)
            pooled = wv.sum(dim=1) / counts    # (B, n_ch, 32)
            parts.append(pooled.flatten(1))
        return self.fc(torch.cat(parts, dim=1))

device = "mps" if torch.backends.mps.is_available() else "cpu"
model = SimpleDecoder(ds.n_groups, ds.n_channels_per_group).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

# quick smoke test: 10 batches, 1 epoch
train_loader = DataLoader(train_ds, batch_size=256, shuffle=True,
                          collate_fn=ds.collate_fn, num_workers=0)

model.train()
for i, (inputs, targets) in enumerate(train_loader):
    if i >= 10:
        break
    inputs = {k: v.to(device) for k, v in inputs.items()}
    targets = targets.to(device)

    loss = loss_fn(model(inputs), targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if i % 5 == 0:
        print(f"batch {i}: loss={loss.item():.4f}")

print("Pipeline works!")