Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add append capability and events storage in connectivities #49

Merged
merged 7 commits into from
Oct 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
'n_node_names', 'n_tapers', 'n_signals', 'n_step', 'n_freqs',
'epochs', 'freqs', 'times', 'arrays', 'lists', 'func', 'n_nodes',
'n_estimated_nodes', 'n_samples', 'n_channels', 'Renderer',
'n_ytimes', 'n_ychannels',
'n_ytimes', 'n_ychannels', 'n_events'
}
numpydoc_xref_aliases = {
# Python
Expand Down
182 changes: 148 additions & 34 deletions mne_connectivity/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from copy import copy
from copy import copy, deepcopy

import numpy as np
import xarray as xr
from mne.utils import (_check_combine, _check_option, _validate_type,
copy_function_doc_to_method_doc, object_size,
sizeof_fmt, check_random_state)
sizeof_fmt, _check_event_id, _ensure_events,
_on_missing, warn, check_random_state)

from mne_connectivity.utils import fill_doc
from mne_connectivity.viz import plot_connectivity_circle
Expand Down Expand Up @@ -35,6 +36,84 @@ def times(self):


class EpochMixin:
def _init_epochs(self, events, event_id, on_missing='warn') -> None:
if events is not None: # Epochs can have events
events = _ensure_events(events)
else:
events = np.empty((0, 3))

event_id = _check_event_id(event_id, events)
self.event_id = event_id
self.events = events

# see BaseEpochs init in MNE-Python
if events is not None:
for key, val in self.event_id.items():
if val not in events[:, 2]:
msg = ('No matching events found for %s '
'(event id %i)' % (key, val))
_on_missing(on_missing, msg)

# ensure metadata matches original events size
self.selection = np.arange(len(events))
self.events = events
del events

values = list(self.event_id.values())
selected = np.where(np.in1d(self.events[:, 2], values))[0]

self.events = self.events[selected]

def append(self, epoch_conn):
"""Append another connectivity structure.

Parameters
----------
epoch_conn : instance of Connectivity
The Epoched Connectivity class to append.

Returns
-------
self : instance of Connectivity
The altered Epoched Connectivity class.
"""
if type(self) != type(epoch_conn):
raise ValueError(f'The type of the epoch connectivity to append '
f'is {type(epoch_conn)}, which does not match '
f'{type(self)}.')
if hasattr(self, 'times'):
if not np.allclose(self.times, epoch_conn.times):
raise ValueError('Epochs must have same times')
if hasattr(self, 'freqs'):
if not np.allclose(self.freqs, epoch_conn.freqs):
raise ValueError('Epochs must have same frequencies')

events = list(deepcopy(self.events))
event_id = deepcopy(self.event_id)

# compare event_id
common_keys = list(set(event_id).intersection(
set(epoch_conn.event_id)))
for key in common_keys:
if not event_id[key] == epoch_conn.event_id[key]:
msg = ('event_id values must be the same for identical keys '
'for all concatenated epochs. Key "{}" maps to {} in '
'some epochs and to {} in others.')
raise ValueError(msg.format(key, event_id[key],
epoch_conn.event_id[key]))

evs = epoch_conn.events.copy()
if epoch_conn.n_epochs == 0:
warn('Epoch Connectivity object to append was empty.')
event_id.update(epoch_conn.event_id)
events = np.concatenate((events, evs), axis=0)

# now combine the xarray data, altered events and event ID
self._obj = xr.concat([self.xarray, epoch_conn.xarray], dim='epochs')
self.events = events
self.event_id = event_id
return self

def combine(self, combine='mean'):
"""Combine connectivity data over epochs.

Expand Down Expand Up @@ -308,6 +387,8 @@ def __init__(self, data, names, indices, method,
def __repr__(self) -> str:
r = f'<{self.__class__.__name__} | '

if self.n_epochs is not None:
r += f"n_epochs : {self.n_epochs}, "
if 'freqs' in self.dims:
r += "freq : [%f, %f], " % (self.freqs[0], self.freqs[-1])
if 'times' in self.dims:
Expand All @@ -324,9 +405,7 @@ def _get_num_connections(self, data):
# account for epoch data structures
if self.is_epoched:
start_idx = 1
self.n_epochs = data.shape[0]
else:
self.n_epochs = None
start_idx = 0
self.n_estimated_nodes = data.shape[start_idx]

Expand Down Expand Up @@ -436,6 +515,18 @@ def _check_data_consistency(self, data, indices, n_nodes):
f'But there should be {expected_len} '
f'estimated connections.')

def copy(self):
return deepcopy(self)

@property
def n_epochs(self):
"""The number of epochs the connectivity data varies over."""
if self.is_epoched:
n_epochs = self._data.shape[0]
else:
n_epochs = None
return n_epochs

@property
def _data(self):
"""Numpy array of connectivity data."""
Expand Down Expand Up @@ -507,7 +598,9 @@ def xarray(self):
def n_epochs_used(self):
"""Number of epochs used in computation of connectivity.

Can be 'None', if there was no epochs used.
Can be 'None', if there was no epochs used. This is
equivalent to the number of epochs, if there is no
combining of epochs.
"""
return self.attrs.get('n_epochs_used')

Expand Down Expand Up @@ -723,10 +816,11 @@ class SpectralConnectivity(_Connectivity, SpectralMixin):
def __init__(self, data, freqs, n_nodes, names=None,
indices='all', method=None, spec_method=None,
n_epochs_used=None, **kwargs):
super().__init__(data, names=names, method=method,
indices=indices, n_nodes=n_nodes,
freqs=freqs, spec_method=spec_method,
n_epochs_used=n_epochs_used, **kwargs)
super(SpectralConnectivity, self).__init__(
data, names=names, method=method,
indices=indices, n_nodes=n_nodes,
freqs=freqs, spec_method=spec_method,
n_epochs_used=n_epochs_used, **kwargs)


@fill_doc
Expand Down Expand Up @@ -763,10 +857,11 @@ class to this one. However, that describes one connectivity snapshot

def __init__(self, data, times, n_nodes, names=None, indices='all',
method=None, n_epochs_used=None, **kwargs):
super().__init__(data, names=names, method=method,
n_nodes=n_nodes, indices=indices,
times=times, n_epochs_used=n_epochs_used,
**kwargs)
super(TemporalConnectivity, self).__init__(
data, names=names, method=method,
n_nodes=n_nodes, indices=indices,
times=times, n_epochs_used=n_epochs_used,
**kwargs)


@fill_doc
Expand Down Expand Up @@ -798,10 +893,11 @@ class SpectroTemporalConnectivity(_Connectivity, SpectralMixin, TimeMixin):
def __init__(self, data, freqs, times, n_nodes, names=None,
indices='all', method=None,
spec_method=None, n_epochs_used=None, **kwargs):
super().__init__(data, names=names, method=method, indices=indices,
n_nodes=n_nodes, freqs=freqs,
spec_method=spec_method, times=times,
n_epochs_used=n_epochs_used, **kwargs)
super(SpectroTemporalConnectivity, self).__init__(
data, names=names, method=method, indices=indices,
n_nodes=n_nodes, freqs=freqs,
spec_method=spec_method, times=times,
n_epochs_used=n_epochs_used, **kwargs)


@fill_doc
Expand All @@ -821,18 +917,22 @@ class EpochSpectralConnectivity(SpectralConnectivity, EpochMixin):
%(indices)s
%(method)s
%(spec_method)s
%(events)s
%(event_id)s
%(connectivity_kwargs)s
"""
# whether or not the connectivity occurs over epochs
is_epoched = True

def __init__(self, data, freqs, n_nodes, names=None,
indices='all', method=None,
spec_method=None, **kwargs):
super().__init__(
spec_method=None, events=None,
event_id=None, **kwargs):
super(EpochSpectralConnectivity, self).__init__(
data, freqs=freqs, names=names, indices=indices,
n_nodes=n_nodes, method=method,
spec_method=spec_method, **kwargs)
self._init_epochs(events, event_id, on_missing='warn')


@fill_doc
Expand All @@ -851,16 +951,21 @@ class EpochTemporalConnectivity(TemporalConnectivity, EpochMixin):
%(names)s
%(indices)s
%(method)s
%(events)s
%(event_id)s
%(connectivity_kwargs)s
"""
# whether or not the connectivity occurs over epochs
is_epoched = True

def __init__(self, data, times, n_nodes, names=None,
indices='all', method=None, **kwargs):
super().__init__(data, times=times, names=names,
indices=indices, n_nodes=n_nodes,
method=method, **kwargs)
indices='all', method=None, events=None,
event_id=None, **kwargs):
super(EpochTemporalConnectivity, self).__init__(
data, times=times, names=names,
indices=indices, n_nodes=n_nodes,
method=method, **kwargs)
self._init_epochs(events, event_id, on_missing='warn')


@fill_doc
Expand All @@ -883,18 +988,22 @@ class EpochSpectroTemporalConnectivity(
%(indices)s
%(method)s
%(spec_method)s
%(events)s
%(event_id)s
%(connectivity_kwargs)s
"""
# whether or not the connectivity occurs over epochs
is_epoched = True

def __init__(self, data, freqs, times, n_nodes,
names=None, indices='all', method=None,
spec_method=None, **kwargs):
super().__init__(
spec_method=None, events=None, event_id=None,
**kwargs):
super(EpochSpectroTemporalConnectivity, self).__init__(
data, names=names, freqs=freqs, times=times, indices=indices,
n_nodes=n_nodes, method=method, spec_method=spec_method,
**kwargs)
self._init_epochs(events, event_id, on_missing='warn')


@fill_doc
Expand Down Expand Up @@ -923,10 +1032,10 @@ class Connectivity(_Connectivity, EpochMixin):

def __init__(self, data, n_nodes, names=None, indices='all',
method=None, n_epochs_used=None, **kwargs):
super().__init__(data, names=names, method=method,
n_nodes=n_nodes, indices=indices,
n_epochs_used=n_epochs_used,
**kwargs)
super(Connectivity, self).__init__(data, names=names, method=method,
n_nodes=n_nodes, indices=indices,
n_epochs_used=n_epochs_used,
**kwargs)


@fill_doc
Expand All @@ -945,6 +1054,8 @@ class EpochConnectivity(_Connectivity, EpochMixin):
%(indices)s
%(method)s
%(n_epochs_used)s
%(events)s
%(event_id)s
%(connectivity_kwargs)s

See Also
Expand All @@ -957,8 +1068,11 @@ class EpochConnectivity(_Connectivity, EpochMixin):
is_epoched = True

def __init__(self, data, n_nodes, names=None, indices='all',
method=None, n_epochs_used=None, **kwargs):
super().__init__(data, names=names, method=method,
n_nodes=n_nodes, indices=indices,
n_epochs_used=n_epochs_used,
**kwargs)
method=None, n_epochs_used=None, events=None,
event_id=None, **kwargs):
super(EpochConnectivity, self).__init__(
data, names=names, method=method,
n_nodes=n_nodes, indices=indices,
n_epochs_used=n_epochs_used,
**kwargs)
self._init_epochs(events, event_id, on_missing='warn')
17 changes: 15 additions & 2 deletions mne_connectivity/envelope.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# License: BSD (3-clause)

import numpy as np
from mne import Epochs
from mne.filter import next_fast_len
from mne.source_estimate import _BaseSourceEstimate
from mne.utils import (_check_option, verbose, logger, _validate_type, warn,
Expand All @@ -22,7 +23,7 @@ def envelope_correlation(data, names=None,

Parameters
----------
data : array-like, shape=(n_epochs, n_signals, n_times) | generator
data : array-like, shape=(n_epochs, n_signals, n_times) | Epochs | generator
The data from which to compute connectivity.
The array-like object can also be a list/generator of array,
each with shape (n_signals, n_times), or a :class:`~mne.SourceEstimate`
Expand Down Expand Up @@ -71,14 +72,23 @@ def envelope_correlation(data, names=None,
References
----------
.. footbibliography::
"""
""" # noqa
_check_option('orthogonalize', orthogonalize, (False, 'pairwise'))
from scipy.signal import hilbert

corrs = list()

n_nodes = None

events = None
event_id = None
metadata = None
if isinstance(data, Epochs):
events = data.events
event_id = data.event_id
metadata = data.metadata
data = data.get_data()

# Note: This is embarassingly parallel, but the overhead of sending
# the data to different workers is roughly the same as the gain of
# using multiple CPUs. And we require too much GIL for prefer='threading'
Expand Down Expand Up @@ -182,6 +192,9 @@ def envelope_correlation(data, names=None,
indices='symmetric',
n_epochs_used=n_epochs,
n_nodes=n_nodes,
events=events,
event_id=event_id,
metadata=metadata
)
return conn

Expand Down
Loading