Skip to content

Commit

Permalink
Merge pull request NeuralEnsemble#159 from jpgill86/neo-backwards-com…
Browse files Browse the repository at this point in the history
…patibility

Make Neo RawIO sources compatible with Neo 0.6-0.10
  • Loading branch information
jpgill86 committed Sep 9, 2021
2 parents 5b2b051 + 283c8ec commit 7c82f8a
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 147 deletions.
230 changes: 90 additions & 140 deletions ephyviewer/datasource/neosource.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,88 +110,70 @@ def get_sources_from_neo_segment(neo_seg):

## neo.rawio stuff

# this can be remove when neo version 0.10 will be out
class AnalogSignalFromNeoRawIOSource_until_v9(BaseAnalogSignalSource):
def __init__(self, neorawio, channel_indexes=None):

BaseAnalogSignalSource.__init__(self)
self.with_scatter = False

self.neorawio =neorawio
if channel_indexes is None:
channel_indexes = slice(None)
self.channel_indexes = channel_indexes
self.channels = self.neorawio.header['signal_channels'][channel_indexes]
self.sample_rate = self.neorawio.get_signal_sampling_rate(channel_indexes=self.channel_indexes)

#TODO: something for multi segment
self.block_index = 0
self.seg_index = 0

@property
def nb_channel(self):
return len(self.channels)

def get_channel_name(self, chan=0):
return self.channels[chan]['name']

@property
def t_start(self):
t_start = self.neorawio.get_signal_t_start(self.block_index, self.seg_index,
channel_indexes=self.channel_indexes)
return t_start

@property
def t_stop(self):
t_stop = self.t_start + self.get_length()/self.sample_rate
return t_stop

def get_length(self):
length = self.neorawio.get_signal_size(self.block_index, self.seg_index,
channel_indexes=self.channel_indexes)
return length

def get_gains(self):
return self.neorawio.header['signal_channels']['gain'][self.channel_indexes]

def get_offsets(self):
return self.neorawio.header['signal_channels']['offset'][self.channel_indexes]

def get_shape(self):
return (self.get_length(), self.nb_channel)

def get_chunk(self, i_start=None, i_stop=None):
sigs = self.neorawio.get_analogsignal_chunk(block_index=self.block_index, seg_index=self.seg_index,
i_start=i_start, i_stop=i_stop, channel_indexes=self.channel_indexes)
return sigs


# this fit the neo API >= 0.10 (with streams concept)
class AnalogSignalFromNeoRawIOSource(BaseAnalogSignalSource):
def __init__(self, neorawio, channel_indexes=None, stream_index=None):
"""
Create an analog signal source from a Neo RawIO.
Parameters
----------
neorawio : subclass of neo.rawio.BaseRawIO
Neo RawIO reader from which to load signals.
channel_indexes : list of ints
Indexes of signals to use. Note that for Neo>=0.10, channels within
a signal stream are indexed independently of channels in other
streams; for Neo<0.10, channels are indexed globally, regardless of
signal group membership. If None is passed, uses all channels within
a stream (or all channels globally for Neo<0.10).
stream_index : int
Index of signal stream to use. If only one signal stream exists,
this parameter is not required. For Neo<0.10, signal streams do not
exist and this parameter must be None.
"""

BaseAnalogSignalSource.__init__(self)
self.with_scatter = False

self.neorawio = neorawio

if stream_index is not None:
self.stream_index = stream_index
elif self.neorawio.signal_streams_count() == 1:
self.stream_index = 0
else:
raise ValueError(f'Because the Neo RawIO source contains multiple signal streams ({self.neorawio.signal_streams_count()}), stream_index must be provided')
if self.neorawio.header is None:
self.neorawio.parse_header()

if channel_indexes is None:
channel_indexes = slice(None)
self.channel_indexes = channel_indexes

self.stream_id = self.neorawio.header['signal_streams'][self.stream_index]['id']
signal_channels = self.neorawio.header['signal_channels']
mask = signal_channels['stream_id'] == self.stream_id
self.channels = signal_channels[mask][self.channel_indexes]
if V(neo.__version__)>='0.10.0':
# Neo >= 0.10
# - versions 0.10+ index channels within a stream
if stream_index is not None:
self.stream_index = stream_index
elif self.neorawio.signal_streams_count() == 1:
self.stream_index = 0
else:
raise ValueError(f'Because the Neo RawIO source contains multiple signal streams ({self.neorawio.signal_streams_count()}), stream_index must be provided')
self.stream_id = self.neorawio.header['signal_streams'][self.stream_index]['id']
signal_channels = self.neorawio.header['signal_channels']
mask = signal_channels['stream_id'] == self.stream_id
self.channels = signal_channels[mask][self.channel_indexes]
else:
# Neo < 0.10
# - versions 0.6-0.9 index channels globally (ignoring signal group)
assert stream_index is None, f'Neo version {neo.__version__} is installed, but only Neo>=0.10 uses stream_index'
self.channels = self.neorawio.header['signal_channels'][self.channel_indexes]

if V(neo.__version__)>='0.10.0':
# Neo >= 0.10
# - versions 0.10+ use stream_index as an argument often,
# but also require channel_indexes for get_chunk
self.signal_indexing_kwarg = {'stream_index': self.stream_index}
self.get_chunk_kwargs = {'stream_index': self.stream_index, 'channel_indexes': self.channel_indexes}
else:
# Neo < 0.10
# - versions 0.6-0.9 use channel_indexes as an argument often
self.signal_indexing_kwarg = {'channel_indexes': self.channel_indexes}
self.get_chunk_kwargs = {'channel_indexes': self.channel_indexes}

self.sample_rate = self.neorawio.get_signal_sampling_rate(stream_index=self.stream_index)
self.sample_rate = self.neorawio.get_signal_sampling_rate(**self.signal_indexing_kwarg)

#TODO: something for multi segment
self.block_index = 0
Expand All @@ -207,7 +189,7 @@ def get_channel_name(self, chan=0):
@property
def t_start(self):
t_start = self.neorawio.get_signal_t_start(self.block_index, self.seg_index,
stream_index=self.stream_index)
**self.signal_indexing_kwarg)
return t_start

@property
Expand All @@ -217,76 +199,43 @@ def t_stop(self):

def get_length(self):
length = self.neorawio.get_signal_size(self.block_index, self.seg_index,
stream_index=self.stream_index)
**self.signal_indexing_kwarg)
return length

def get_gains(self):
return self.channels['gain']
return self.channels['gain']

def get_offsets(self):
return self.channels['offset']
return self.channels['offset']

def get_shape(self):
return (self.get_length(), self.nb_channel)

def get_chunk(self, i_start=None, i_stop=None):
sigs = self.neorawio.get_analogsignal_chunk(block_index=self.block_index, seg_index=self.seg_index,
i_start=i_start, i_stop=i_stop, stream_index=self.stream_index,
channel_indexes=self.channel_indexes)
i_start=i_start, i_stop=i_stop, **self.get_chunk_kwargs)
return sigs


# handle old neo API <0.10
class SpikeFromNeoRawIOSource_until_v9(BaseSpikeSource):
def __init__(self, neorawio, channel_indexes=None):
self.neorawio =neorawio
if channel_indexes is None:
channel_indexes = slice(None)
self.channel_indexes = channel_indexes

self.channels = self.neorawio.header['unit_channels'][channel_indexes]

#TODO: something for multi segment
self.block_index = 0
self.seg_index = 0

@property
def nb_channel(self):
return len(self.channels)

def get_channel_name(self, chan=0):
return self.channels[chan]['name']

@property
def t_start(self):
t_start = self.neorawio.segment_t_start(self.block_index, self.seg_index)
return t_start

@property
def t_stop(self):
t_stop = self.neorawio.segment_t_stop(self.block_index, self.seg_index)
return t_stop

def get_chunk(self, chan=0, i_start=None, i_stop=None):
raise(NotImplementedError)

def get_chunk_by_time(self, chan=0, t_start=None, t_stop=None):
spike_timestamp = self.neorawio.get_spike_timestamps(block_index=self.block_index,
seg_index=self.seg_index, unit_index=chan, t_start=t_start, t_stop=t_stop)

spike_times = self.neorawio.rescale_spike_timestamp(spike_timestamp, dtype='float64')

return spike_times

# this fit the new neo rawio API >=0.10
class SpikeFromNeoRawIOSource(BaseSpikeSource):
def __init__(self, neorawio, channel_indexes=None):
self.neorawio =neorawio
if channel_indexes is None:
channel_indexes = slice(None)
self.channel_indexes = channel_indexes

self.channels = self.neorawio.header['spike_channels'][channel_indexes]
if V(neo.__version__)>='0.10.0':
# Neo >= 0.10
# - versions 0.10+ have spike_channels
self.channels = self.neorawio.header['spike_channels'][channel_indexes]
self.get_chunk_kwarg = 'spike_channel_index'
else:
# Neo < 0.10
# - versions 0.6-0.9 have unit_channels
self.channels = self.neorawio.header['unit_channels'][channel_indexes]
self.get_chunk_kwarg = 'unit_index'

#TODO: something for multi segment
self.block_index = 0
Expand Down Expand Up @@ -314,7 +263,7 @@ def get_chunk(self, chan=0, i_start=None, i_stop=None):

def get_chunk_by_time(self, chan=0, t_start=None, t_stop=None):
spike_timestamp = self.neorawio.get_spike_timestamps(block_index=self.block_index,
seg_index=self.seg_index, spike_channel_index=chan, t_start=t_start, t_stop=t_stop)
seg_index=self.seg_index, **{self.get_chunk_kwarg: chan}, t_start=t_start, t_stop=t_stop)

spike_times = self.neorawio.rescale_spike_timestamp(spike_timestamp, dtype='float64')

Expand Down Expand Up @@ -397,39 +346,40 @@ def get_sources_from_neo_rawio(neorawio):
sources = {'signal':[], 'epoch':[], 'spike':[]}


# handle of neo version
# this will be simplified in a while
if hasattr(neorawio, 'get_group_signal_channel_indexes'):
# Neo >= 0.9.0 and < 0.10
if hasattr(neorawio, 'signal_streams_count'):
# Neo >= 0.10.0
# - version 0.10 replaced signal groups with signal streams
for stream_index in range(neorawio.signal_streams_count()):
# one source per signal stream
sources['signal'].append(AnalogSignalFromNeoRawIOSource(neorawio, stream_index=stream_index))
elif hasattr(neorawio, 'get_group_signal_channel_indexes'):
# Neo >= 0.9.0 and < 0.10
# - version 0.9 renamed BaseRawIO.get_group_channel_indexes() to BaseRawIO.get_group_signal_channel_indexes()
if neorawio.signal_channels_count() > 0:
channel_indexes_list = neorawio.get_group_signal_channel_indexes()
for channel_indexes in channel_indexes_list:
#one soure by channel group
sources['signal'].append(AnalogSignalFromNeoRawIOSource_until_v9(neorawio, channel_indexes=channel_indexes))
# one source per channel group
sources['signal'].append(AnalogSignalFromNeoRawIOSource(neorawio, channel_indexes=channel_indexes))
elif hasattr(neorawio, 'get_group_channel_indexes'):
# Neo < 0.9.0
# - versions 0.6-0.8 have BaseRawIO.get_group_channel_indexes()
if neorawio.signal_channels_count() > 0:
channel_indexes_list = neorawio.get_group_signal_channel_indexes()
channel_indexes_list = neorawio.get_group_channel_indexes()
for channel_indexes in channel_indexes_list:
#one soure by channel group
sources['signal'].append(AnalogSignalFromNeoRawIOSource_until_v9(neorawio, channel_indexes=channel_indexes))
elif hasattr(neorawio, 'signal_streams_count'):
# Neo >= 0.10.0
num_streams = neorawio.signal_streams_count()
for stream_index in range(num_streams):
#one soure by stream
sources['signal'].append(AnalogSignalFromNeoRawIOSource(neorawio, stream_index=stream_index))

# one source per channel group
sources['signal'].append(AnalogSignalFromNeoRawIOSource(neorawio, channel_indexes=channel_indexes))


if hasattr(neorawio, 'unit_channels_count'):
# Neo < 0.10
if neorawio.unit_channels_count()>0:
sources['spike'].append(SpikeFromNeoRawIOSource_until_v9(neorawio, None))
elif hasattr(neorawio, 'spike_channels_count'):
# neo >= 0.10
if hasattr(neorawio, 'spike_channels_count'):
# Neo >= 0.10
# - version 0.10 renamed BaseRawIO.unit_channels_count() to BaseRawIO.spike_channels_count()
if neorawio.spike_channels_count()>0:
sources['spike'].append(SpikeFromNeoRawIOSource(neorawio, None))
elif hasattr(neorawio, 'unit_channels_count'):
# Neo < 0.10
# - versions 0.6-0.9 have BaseRawIO.unit_channels_count()
if neorawio.unit_channels_count()>0:
sources['spike'].append(SpikeFromNeoRawIOSource(neorawio, None))


if neorawio.event_channels_count()>0:
Expand Down
14 changes: 7 additions & 7 deletions ephyviewer/tests/test_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,11 @@ def test_spikeinterface_sources():


if __name__=='__main__':
#~ test_InMemoryAnalogSignalSource()
#~ test_VideoMultiFileSource()
#~ test_InMemoryEventSource()
#~ test_InMemoryEpochSource()
#~ test_spikesource()
test_InMemoryAnalogSignalSource()
test_VideoMultiFileSource()
test_InMemoryEventSource()
test_InMemoryEpochSource()
test_spikesource()
test_neo_rawio_sources()
#~ test_neo_object_sources()
#~ test_spikeinterface_sources()
test_neo_object_sources()
test_spikeinterface_sources()

0 comments on commit 7c82f8a

Please sign in to comment.