In [1]:
"""
Tests of neo.rawio.spikeglxrawio
"""

import unittest

from neo.rawio.spikeglxrawio import SpikeGLXRawIO
from neo.test.rawiotest.common_rawio_test import BaseTestRawIO
import numpy as np


In [2]:


class TestSpikeGLXRawIO(BaseTestRawIO, unittest.TestCase):
    rawioclass = SpikeGLXRawIO
    entities_to_download = ["spikeglx"]
    entities_to_test = [
        "spikeglx/Noise4Sam_g0",
        "spikeglx/TEST_20210920_0_g0",
        # this is only g0 multi index
        "spikeglx/multi_trigger_multi_gate/SpikeGLX/5-19-2022-CI0/5-19-2022-CI0_g0",
        # this is only g1 multi index
        "spikeglx/multi_trigger_multi_gate/SpikeGLX/5-19-2022-CI0/5-19-2022-CI0_g1",
        # this mix both multi gate and multi trigger (and also multi probe)
        "spikeglx/multi_trigger_multi_gate/SpikeGLX/5-19-2022-CI0",
        "spikeglx/multi_trigger_multi_gate/SpikeGLX/5-19-2022-CI1",
        "spikeglx/multi_trigger_multi_gate/SpikeGLX/5-19-2022-CI2",
        "spikeglx/multi_trigger_multi_gate/SpikeGLX/5-19-2022-CI3",
        "spikeglx/multi_trigger_multi_gate/SpikeGLX/5-19-2022-CI4",
        "spikeglx/multi_trigger_multi_gate/SpikeGLX/5-19-2022-CI5",
        # different sync/sybset options with commercial NP2
        "spikeglx/NP2_with_sync",
        "spikeglx/NP2_no_sync",
        "spikeglx/NP2_subset_with_sync",
        # NP-ultra
        "spikeglx/np_ultra_stub",
        # Filename changed by the user, multi-dock
        "spikeglx/multi_probe_multi_dock_multi_shank_filename_without_info",
        # CatGT
        "spikeglx/multi_trigger_multi_gate/CatGT/CatGT-A",
        "spikeglx/multi_trigger_multi_gate/CatGT/CatGT-B",
        "spikeglx/multi_trigger_multi_gate/CatGT/CatGT-C",
        "spikeglx/multi_trigger_multi_gate/CatGT/CatGT-D",
        "spikeglx/multi_trigger_multi_gate/CatGT/CatGT-E",
        "spikeglx/multi_trigger_multi_gate/CatGT/Supercat-A",
    ]

    def test_with_location(self):
        rawio = SpikeGLXRawIO(self.get_local_path("spikeglx/Noise4Sam_g0"), load_channel_location=True)
        rawio.parse_header()
        # one of the stream have channel location
        have_location = []
        for sig_anotations in rawio.raw_annotations["blocks"][0]["segments"][0]["signals"]:
            have_location.append("channel_location_0" in sig_anotations["__array_annotations__"])
        assert any(have_location)

    def test_sync(self):
        rawio_with_sync = SpikeGLXRawIO(self.get_local_path("spikeglx/NP2_with_sync"), load_sync_channel=True)
        rawio_with_sync.parse_header()
        stream_index = list(rawio_with_sync.header["signal_streams"]["name"]).index("imec0.ap")

        # AP stream has 385 channels
        chunk = rawio_with_sync.get_analogsignal_chunk(
            block_index=0, seg_index=0, i_start=0, i_stop=100, stream_index=stream_index
        )
        assert chunk.shape[1] == 385

        rawio_no_sync = SpikeGLXRawIO(self.get_local_path("spikeglx/NP2_with_sync"), load_sync_channel=False)
        rawio_no_sync.parse_header()

        # AP stream has 384 channels
        chunk = rawio_no_sync.get_analogsignal_chunk(
            block_index=0, seg_index=0, i_start=0, i_stop=100, stream_index=stream_index
        )
        assert chunk.shape[1] == 384

    def test_no_sync(self):
        # requesting sync channel when there is none raises an error
        with self.assertRaises(ValueError):
            rawio_no_sync = SpikeGLXRawIO(self.get_local_path("spikeglx/NP2_no_sync"), load_sync_channel=True)
            rawio_no_sync.parse_header()

    def test_subset_with_sync(self):
        rawio_sub = SpikeGLXRawIO(self.get_local_path("spikeglx/NP2_subset_with_sync"), load_sync_channel=True)
        rawio_sub.parse_header()
        stream_index = list(rawio_sub.header["signal_streams"]["name"]).index("imec0.ap")

        # AP stream has 121 channels
        chunk = rawio_sub.get_analogsignal_chunk(
            block_index=0, seg_index=0, i_start=0, i_stop=100, stream_index=stream_index
        )
        assert chunk.shape[1] == 121

        rawio_sub_no_sync = SpikeGLXRawIO(self.get_local_path("spikeglx/NP2_subset_with_sync"), load_sync_channel=False)
        rawio_sub_no_sync.parse_header()
        # AP stream has 120 channels
        chunk = rawio_sub_no_sync.get_analogsignal_chunk(
            block_index=0, seg_index=0, i_start=0, i_stop=100, stream_index=stream_index
        )
        assert chunk.shape[1] == 120

    def test_nidq_digital_channel(self):
        rawio_digital = SpikeGLXRawIO(self.get_local_path("spikeglx/DigitalChannelTest_g0"))
        rawio_digital.parse_header()
        # This data should have 8 event channels
        assert np.shape(rawio_digital.header["event_channels"])[0] == 8

        # Channel 0 in this data will have sync pulses at 1 Hz, let's confirm that
        all_events = rawio_digital.get_event_timestamps(0, 0, 0)
        on_events = np.where(all_events[2] == "XD0 ON")
        on_ts = all_events[0][on_events]
        on_ts_scaled = rawio_digital.rescale_event_timestamp(on_ts)
        on_diff = np.diff(on_ts_scaled)
        atol = 0.001
        assert np.allclose(on_diff, 1, atol=atol)

    def test_t_start_reading(self):
        """Test that t_start values are correctly read for all streams and segments."""

        # Expected t_start values for each stream and segment
        expected_t_starts = {
            "imec0.ap": {0: 15.319535472007237, 1: 15.339535431281986, 2: 21.284723325294053, 3: 21.3047232845688},
            "imec1.ap": {0: 15.319554693264516, 1: 15.339521518106308, 2: 21.284735282142822, 3: 21.304702106984614},
            "imec0.lf": {0: 15.3191688060872, 1: 15.339168765361949, 2: 21.284356659374016, 3: 21.304356618648765},
            "imec1.lf": {0: 15.319321358082725, 1: 15.339321516521915, 2: 21.284568614155827, 3: 21.30456877259502},
        }

        # Initialize the RawIO
        rawio = SpikeGLXRawIO(self.get_local_path("spikeglx/multi_trigger_multi_gate/SpikeGLX/5-19-2022-CI4"))
        rawio.parse_header()

        # Get list of stream names
        stream_names = rawio.header["signal_streams"]["name"]

        # Test t_start for each stream and segment
        for stream_name, expected_values in expected_t_starts.items():
            # Get stream index
            stream_index = list(stream_names).index(stream_name)

            # Check each segment
            for seg_index, expected_t_start in expected_values.items():
                actual_t_start = rawio.get_signal_t_start(block_index=0, seg_index=seg_index, stream_index=stream_index)

                # Use numpy.testing for proper float comparison
                np.testing.assert_allclose(
                    actual_t_start,
                    expected_t_start,
                    rtol=1e-9,
                    atol=1e-9,
                    err_msg=f"Mismatch in t_start for stream '{stream_name}', segment {seg_index}",
                )


In [3]:
testclass = TestSpikeGLXRawIO()

### NP2 with sync

In [4]:
# rawio = SpikeGLXRawIO(testclass.get_local_path("spikeglx/NP2_with_sync"), load_sync_channel=True)
# /home/dgupta/data/20250305-133614
rawio = SpikeGLXRawIO(testclass.get_local_path("/home/dgupta/data/20250305-133614"), load_sync_channel=True)

rawio.parse_header()
stream_index = list(rawio.header["signal_streams"]["name"]).index("imec0.ap")



In [6]:
rawio.header['signal_streams']

array([('nidq', 'nidq', 'nidq'), ('imec0.ap', 'imec0.ap', 'imec0.ap')],
      dtype=[('name', '<U64'), ('id', '<U64'), ('buffer_id', '<U64')])

In [7]:
rawio.signals_info_dict.keys()


dict_keys([(0, 'nidq'), (0, 'imec0.ap')])

In [8]:
rawio.signals_info_dict.keys()

dict_keys([(0, 'nidq'), (0, 'imec0.ap')])

In [9]:
#finding sync channel in imec0.ap

for key in rawio.signals_info_dict[(0,'imec0.ap')]['meta']:
    if 'sync' in key:
        print(key, rawio.signals_info_dict[(0,'imec0.ap')]['meta'][key])

syncImInputSlot 2
syncSourceIdx 3
syncSourcePeriod 1


In [10]:
rawio._buffer_descriptions

{0: {0: {'nidq': {'type': 'raw',
    'file_path': '/home/dgupta/data/20250305-133614/sglx_20250305-133614.nidq.bin',
    'dtype': 'int16',
    'order': 'C',
    'file_offset': 0,
    'shape': (8022991, 2)},
   'imec0.ap': {'type': 'raw',
    'file_path': '/home/dgupta/data/20250305-133614/sglx_20250305-133614.imec0.ap.bin',
    'dtype': 'int16',
    'order': 'C',
    'file_offset': 0,
    'shape': (4813845, 385)}}}}

In [11]:
im_dict = rawio.signals_info_dict[(0,'imec0.ap')]

In [12]:
#finding sync channel in nidq

for key in rawio.signals_info_dict[(0,'nidq')]['meta']:
    if 'sync' in key:
        print(key, rawio.signals_info_dict[(0,'nidq')]['meta'][key])

syncNiChan 2
syncNiChanType 0
syncNiThresh 1.1
syncSourceIdx 3
syncSourcePeriod 1


In [13]:
rawio.signals_info_dict[(0,'nidq')]['analog_channels']

['XA0']

In [5]:
timestamps, durations, labels = rawio._get_sync_events(stream_index=0)

timestamps_sec = rawio._rescale_event_timestamp(timestamps, stream_index= 0)

In [6]:
timestamps_sec

array([  0.20088169,   1.2008901 ,   2.20087852,   3.20088693,
         4.20087534,   5.20088376,   6.20089217,   7.20088058,
         8.200889  ,   9.20087741,  10.20088582,  11.20089424,
        12.20088265,  13.20089106,  14.20087948,  15.20088789,
        16.2008963 ,  17.20088472,  18.20089313,  19.20088154,
        20.20088996,  21.20089837,  22.20088678,  23.2008952 ,
        24.20088361,  25.20089203,  26.20090044,  27.20088885,
        28.20089727,  29.20088568,  30.20089409,  31.20090251,
        32.20089092,  33.20089933,  34.20088775,  35.20089616,
        36.20090457,  37.20089299,  38.2009014 ,  39.20088981,
        40.20089823,  41.20090664,  42.20089505,  43.20090347,
        44.20089188,  45.20090029,  46.20090871,  47.20089712,
        48.20090553,  49.20089395,  50.20090236,  51.20091077,
        52.20089919,  53.2009076 ,  54.20089601,  55.20090443,
        56.20091284,  57.20090125,  58.20090967,  59.20089808,
        60.20090649,  61.20089491,  62.20090332,  63.20

In [7]:
timestamps, durations, labels = rawio._get_sync_events(stream_index=1)

In [8]:
timestamps_sec = rawio._rescale_event_timestamp(timestamps, stream_index= 1)

In [9]:
timestamps_sec

array([  0.2009    ,   1.20093333,   2.20093333,   3.20093333,
         4.20093333,   5.20093333,   6.20093333,   7.20093333,
         8.20093333,   9.20093333,  10.20093333,  11.20093333,
        12.20093333,  13.20093333,  14.20093333,  15.20093333,
        16.20096667,  17.20096667,  18.20096667,  19.20096667,
        20.20096667,  21.20096667,  22.20096667,  23.20096667,
        24.20096667,  25.20096667,  26.20096667,  27.20096667,
        28.20096667,  29.201     ,  30.201     ,  31.201     ,
        32.201     ,  33.201     ,  34.201     ,  35.201     ,
        36.201     ,  37.201     ,  38.201     ,  39.201     ,
        40.201     ,  41.201     ,  42.201     ,  43.201     ,
        44.20103333,  45.20103333,  46.20103333,  47.20103333,
        48.20103333,  49.20103333,  50.20103333,  51.20103333,
        52.20103333,  53.20103333,  54.20103333,  55.20103333,
        56.20103333,  57.20106667,  58.20106667,  59.20106667,
        60.20106667,  61.20106667,  62.20106667,  63.20

In [None]:
labels

In [None]:
# Channel 0 in this data will have sync pulses at 1 Hz, let's confirm that
syncNiChan = int(rawio.signals_info_dict[(0,'nidq')]['meta']['syncNiChan'])
all_events = rawio.get_event_timestamps(0, 0, syncNiChan)
on_events = np.where(all_events[2] == f"XD{syncNiChan} ON")
on_ts = all_events[0][on_events]
on_ts_scaled = rawio.rescale_event_timestamp(on_ts)
on_diff = np.diff(on_ts_scaled)

In [None]:
on_diff

In [None]:
rawio.signals_info_dict[0, "imec0.ap"]['channel_names']

In [None]:
channel = 'SY0'
sync_data = rawio.get_analogsignal_chunk(channel_names = [channel], stream_index = stream_index)

In [None]:
# Convert the uint16 array to uint8
sync_data_uint8 = sync_data.view(np.uint8)

unpacked_sync_data = np.unpackbits(sync_data_uint8, axis=1)

In [None]:
this_stream = unpacked_sync_data[:,1]

In [None]:
timestamps, durations, labels = [], None, []

this_rising = np.where(np.diff(this_stream) == 1)[0] + 1
this_falling = (
    np.where(np.diff(this_stream) == 255)[0] + 1
)  # because the data is in unsigned 8 bit, -1 = 255!
if len(this_rising) > 0:
    timestamps.extend(this_rising)
    labels.extend([f"{channel} ON"] * len(this_rising))
if len(this_falling) > 0:
    timestamps.extend(this_falling)
    labels.extend([f"{channel} OFF"] * len(this_falling))
timestamps = np.asarray(timestamps)
if len(labels) == 0:
    labels = np.asarray(labels, dtype="U1")
else:
    labels = np.asarray(labels)

In [None]:
on_events = np.where(labels == f"{channel} ON")
on_ts = timestamps[on_events]
on_ts_scaled = on_ts / float(rawio.signals_info_dict[0, 'imec0.ap']['sampling_rate'])
on_diff = np.diff(on_ts_scaled)

In [None]:
on_diff

### DigitalChannelTest_g0

In [None]:
rawio_digital = SpikeGLXRawIO(testclass.get_local_path("spikeglx/DigitalChannelTest_g0"), load_sync_channel=True)
rawio_digital.parse_header()


In [None]:
for key in rawio_digital.signals_info_dict[(0,'nidq')]['meta']:
    if 'sync' in key:
        print(key, rawio_digital.signals_info_dict[(0,'nidq')]['meta'][key])

In [None]:
# Channel 0 in this data will have sync pulses at 1 Hz, let's confirm that
all_events = rawio_digital.get_event_timestamps(0, 0, 0)
on_events = np.where(all_events[2] == "XD0 ON")
on_ts = all_events[0][on_events]
on_ts_scaled = rawio_digital.rescale_event_timestamp(on_ts)
on_diff = np.diff(on_ts_scaled)

In [None]:
on_diff