MRG: deprecated ConcatenateChannels and renamed to EpochsVectorizer #2361

Merged (10 commits, Jul 30, 2015)
4 changes: 3 additions & 1 deletion doc/whats_new.rst
@@ -45,7 +45,9 @@ Changelog

- Add support to append new channels to an object from a list of other objects by `Chris Holdgraf`_

- Deprecated `lws` and renamed `ledoit_wolf` for the `reg` argument in :class:`mne.decoding.csp.CSP` and :class:`mne.preprocessing.Xdawn` by `Romain Trachel`_
- Deprecated :class:`mne.decoding.transformer.ConcatenateChannels` and replaced it with :class:`mne.decoding.transformer.EpochsVectorizer` by `Romain Trachel`_

- Deprecated `lws` and renamed `ledoit_wolf` for the ``reg`` argument in :class:`mne.decoding.csp.CSP` and :class:`mne.preprocessing.Xdawn` by `Romain Trachel`_


BUG
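A minimal migration sketch for the rename described in the changelog entry above, assuming behaviour is unchanged apart from the class name:

# Hedged migration sketch: the old name still imports during the deprecation
# period but warns; new code should use EpochsVectorizer instead.
from mne.decoding import ConcatenateChannels  # deprecated alias, warns on use
from mne.decoding import EpochsVectorizer     # preferred replacement

vectorizer = EpochsVectorizer()               # drop-in replacement for ConcatenateChannels()
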
4 changes: 2 additions & 2 deletions examples/decoding/plot_decoding_xdawn_eeg.py
@@ -34,7 +34,7 @@
from mne import io, pick_types, read_events, Epochs
from mne.datasets import sample
from mne.preprocessing import Xdawn
from mne.decoding import ConcatenateChannels
from mne.decoding import EpochsVectorizer
from mne.viz import tight_layout


@@ -63,7 +63,7 @@

# Create classification pipeline
clf = make_pipeline(Xdawn(n_components=3),
ConcatenateChannels(),
EpochsVectorizer(),
MinMaxScaler(),
LogisticRegression(penalty='l1'))

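A hedged usage sketch for the pipeline defined in this example: `clf` and `epochs` are assumed to be the objects built earlier in the script, and the simple half split below is illustrative only.

# Hedged sketch, assuming `clf` and `epochs` from this example script.
# Xdawn accepts the Epochs directly and hands arrays to EpochsVectorizer,
# which flattens them for the scaler and the logistic regression.
labels = epochs.events[:, -1]                  # event ids as classification targets
n_train = len(labels) // 2                     # illustrative split, not the script's CV

clf.fit(epochs[:n_train], labels[:n_train])    # Xdawn -> vectorize -> scale -> logreg
preds = clf.predict(epochs[n_train:])          # predicted event ids for held-out epochs
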
6 changes: 3 additions & 3 deletions examples/realtime/plot_compute_rt_decoder.py
@@ -55,17 +55,17 @@
from sklearn.svm import SVC # noqa
from sklearn.pipeline import Pipeline # noqa
from sklearn.cross_validation import cross_val_score, ShuffleSplit # noqa
from mne.decoding import ConcatenateChannels, FilterEstimator # noqa
from mne.decoding import EpochsVectorizer, FilterEstimator # noqa


scores_x, scores, std_scores = [], [], []

filt = FilterEstimator(rt_epochs.info, 1, 40)
scaler = preprocessing.StandardScaler()
concatenator = ConcatenateChannels()
vectorizer = EpochsVectorizer()
clf = SVC(C=1, kernel='linear')

concat_classifier = Pipeline([('filter', filt), ('concat', concatenator),
concat_classifier = Pipeline([('filter', filt), ('vector', vectorizer),
('scaler', scaler), ('svm', clf)])

data_picks = mne.pick_types(rt_epochs.info, meg='grad', eeg=False, eog=True,
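A reduced, hedged sketch of evaluating this kind of pipeline offline with cross-validation: the FilterEstimator step is left out so the snippet runs on synthetic data without real measurement info, and the sklearn.cross_validation import matches this example (that module later moved to sklearn.model_selection).

# Hedged sketch on synthetic data; FilterEstimator is omitted on purpose.
import numpy as np
from sklearn import preprocessing
from sklearn.svm import SVC
from sklearn.pipeline import Pipeline
from sklearn.cross_validation import cross_val_score
from mne.decoding import EpochsVectorizer

rng = np.random.RandomState(0)
X = rng.randn(40, 6, 100)                      # (n_epochs, n_channels, n_times)
y = rng.randint(0, 2, 40)                      # random labels -> expect chance-level scores

clf = Pipeline([('vector', EpochsVectorizer()),
                ('scaler', preprocessing.StandardScaler()),
                ('svm', SVC(C=1, kernel='linear'))])
scores = cross_val_score(clf, X, y, cv=3)      # vectorizer flattens each epoch to 600 features
print(scores.mean())
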
6 changes: 3 additions & 3 deletions examples/realtime/rt_feedback_server.py
@@ -43,7 +43,7 @@
from mne.datasets import sample
from mne.realtime import StimServer
from mne.realtime import MockRtClient
from mne.decoding import ConcatenateChannels, FilterEstimator
from mne.decoding import EpochsVectorizer, FilterEstimator

print(__doc__)

@@ -66,10 +66,10 @@
# Constructing the pipeline for classification
filt = FilterEstimator(raw.info, 1, 40)
scaler = preprocessing.StandardScaler()
concatenator = ConcatenateChannels()
vectorizer = EpochsVectorizer()
clf = SVC(C=1, kernel='linear')

concat_classifier = Pipeline([('filter', filt), ('concat', concatenator),
concat_classifier = Pipeline([('filter', filt), ('vector', vectorizer),
('scaler', scaler), ('svm', clf)])

stim_server.start(verbose=True)
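One practical consequence of this diff is that the pipeline step key changes from 'concat' to 'vector', so any code that looks the step up by name needs the new key; a one-line sketch using the `concat_classifier` pipeline from this script:

# Look up the renamed step on the pipeline by its new key.
vectorizer_step = concat_classifier.named_steps['vector']  # the EpochsVectorizer instance
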
2 changes: 1 addition & 1 deletion mne/decoding/__init__.py
@@ -1,5 +1,5 @@
from .transformer import Scaler, FilterEstimator
from .transformer import PSDEstimator, ConcatenateChannels
from .transformer import PSDEstimator, EpochsVectorizer, ConcatenateChannels
from .mixin import TransformerMixin
from .base import BaseEstimator, LinearModel
from .csp import CSP
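The package __init__ keeps exporting both names during the deprecation window. A hedged sketch of exercising the old name; the exact warning category is an assumption based on the @deprecated decorator applied in mne/decoding/transformer.py below:

# Hedged sketch: instantiating the deprecated class should warn and point
# at EpochsVectorizer (warning category assumed to be DeprecationWarning).
import warnings
from mne.decoding import ConcatenateChannels   # still importable while deprecated

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    ConcatenateChannels()                      # triggers the deprecation message
assert any('EpochsVectorizer' in str(w.message) for w in caught)
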
18 changes: 9 additions & 9 deletions mne/decoding/tests/test_transformer.py
@@ -12,7 +12,7 @@

from mne import io, read_events, Epochs, pick_types
from mne.decoding import Scaler, FilterEstimator
from mne.decoding import PSDEstimator, ConcatenateChannels
from mne.decoding import PSDEstimator, EpochsVectorizer

warnings.simplefilter('always') # enable b/c these tests throw warnings

@@ -119,8 +119,8 @@ def test_psdestimator():
assert_raises(ValueError, psd.transform, epochs, y)


def test_concatenatechannels():
"""Test methods of ConcatenateChannels
def test_epochs_vectorizer():
"""Test methods of EpochsVectorizer
"""
raw = io.Raw(raw_fname, preload=False)
events = read_events(event_name)
@@ -131,26 +131,26 @@ def test_concatenatechannels():
epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
baseline=(None, 0), preload=True)
epochs_data = epochs.get_data()
concat = ConcatenateChannels(epochs.info)
vector = EpochsVectorizer(epochs.info)
y = epochs.events[:, -1]
X = concat.fit_transform(epochs_data, y)
X = vector.fit_transform(epochs_data, y)

# Check data dimensions
assert_true(X.shape[0] == epochs_data.shape[0])
assert_true(X.shape[1] == epochs_data.shape[1] * epochs_data.shape[2])

assert_array_equal(concat.fit(epochs_data, y).transform(epochs_data), X)
assert_array_equal(vector.fit(epochs_data, y).transform(epochs_data), X)

# Check if data is preserved
n_times = epochs_data.shape[2]
assert_array_equal(epochs_data[0, 0, 0:n_times], X[0, 0:n_times])

# Check inverse transform
Xi = concat.inverse_transform(X, y)
Xi = vector.inverse_transform(X, y)
assert_true(Xi.shape[0] == epochs_data.shape[0])
assert_true(Xi.shape[1] == epochs_data.shape[1])
assert_array_equal(epochs_data[0, 0, 0:n_times], Xi[0, 0, 0:n_times])

# Test init exception
assert_raises(ValueError, concat.fit, epochs, y)
assert_raises(ValueError, concat.transform, epochs, y)
assert_raises(ValueError, vector.fit, epochs, y)
assert_raises(ValueError, vector.transform, epochs, y)
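The round-trip property this updated test checks can be stated without MNE at all, since vectorizing is a reshape to (n_epochs, n_channels * n_times) and inverse_transform is its exact inverse; a standalone numpy sketch:

# Standalone numpy sketch of the transform / inverse_transform round trip.
import numpy as np

epochs_data = np.random.RandomState(42).randn(10, 5, 20)   # (n_epochs, n_channels, n_times)
n_epochs, n_channels, n_times = epochs_data.shape

X = epochs_data.reshape(n_epochs, n_channels * n_times)    # what transform() produces
restored = X.reshape(n_epochs, n_channels, n_times)        # what inverse_transform() produces

np.testing.assert_array_equal(restored, epochs_data)       # lossless round trip
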
107 changes: 104 additions & 3 deletions mne/decoding/transformer.py
@@ -13,7 +13,7 @@
band_stop_filter)
from ..time_frequency import multitaper_psd
from ..externals import six
from ..utils import _check_type_picks
from ..utils import _check_type_picks, deprecated


class Scaler(TransformerMixin):
@@ -147,13 +147,13 @@ def inverse_transform(self, epochs_data, y=None):
return X


@deprecated("Class 'ConcatenateChannels' has been renamed to "
"'EpochsVectorizer' and will be removed in release 0.11.")
class ConcatenateChannels(TransformerMixin):
Review comment (Member):
can't you just do:

@deprecated("Class 'ConcatenateChannels' has been renamed to  ...")
class ConcatenateChannels(EpochsVectorizer):
    pass

in order to remove the old code of ConcatenateChannels?

Review comment (Contributor Author):
Yep, @Eric89GXL, fix removed.
It seems to pass the docstring test.
Thanks @agramfort!

"""Concatenates data from different channels into a single feature vector

Parameters
----------
Review comment (Contributor Author):
It didn't pass the doc tests because of the call to the deprecated function.
I found that removing the parameters doc fixes this issue.

Review comment (Member):
Don't fix it this way; fix it by adding the function or class name here:
https://github.com/mne-tools/mne-python/blob/master/mne/tests/test_docstring_parameters.py#L64

info : instance of Info
The measurement info.

Attributes
----------
@@ -247,6 +247,107 @@ def inverse_transform(self, X, y=None):
return X.reshape(self.n_epochs, self.n_channels, self.n_times)


class EpochsVectorizer(TransformerMixin):
"""EpochsVectorizer transforms epoch data to fit into a scikit-learn pipeline.

Parameters
----------
info : instance of Info
The measurement info.

Attributes
----------
n_epochs : int
The number of epochs.
n_channels : int
The number of channels.
n_times : int
The number of time points.

"""
def __init__(self, info=None):
self.info = info
self.n_epochs = None
self.n_channels = None
self.n_times = None

def fit(self, epochs_data, y):
"""For each epoch, concatenate data from different channels into a single
feature vector.

Parameters
----------
epochs_data : array, shape (n_epochs, n_channels, n_times)
The data to concatenate channels.
y : array, shape (n_epochs,)
The label for each epoch.

Returns
-------
self : instance of EpochsVectorizer
    Returns the modified instance.
"""
if not isinstance(epochs_data, np.ndarray):
raise ValueError("epochs_data should be of type ndarray (got %s)."
% type(epochs_data))

return self

def transform(self, epochs_data, y=None):
"""For each epoch, concatenate data from different channels into a single
feature vector.

Parameters
----------
epochs_data : array, shape (n_epochs, n_channels, n_times)
The data.
y : None | array, shape (n_epochs,)
The label for each epoch.
If None not used. Defaults to None.

Returns
-------
X : array, shape (n_epochs, n_channels * n_times)
The data concatenated over channels
"""
if not isinstance(epochs_data, np.ndarray):
raise ValueError("epochs_data should be of type ndarray (got %s)."
% type(epochs_data))

epochs_data = np.atleast_3d(epochs_data)

n_epochs, n_channels, n_times = epochs_data.shape
X = epochs_data.reshape(n_epochs, n_channels * n_times)
# save attributes for inverse_transform
self.n_epochs = n_epochs
self.n_channels = n_channels
self.n_times = n_times

return X

def inverse_transform(self, X, y=None):
"""For each epoch, reshape a feature vector into the original data shape

Parameters
----------
X : array, shape (n_epochs, n_channels * n_times)
The feature vector concatenated over channels
y : None | array, shape (n_epochs,)
The label for each epoch.
If None not used. Defaults to None.

Returns
-------
epochs_data : array, shape (n_epochs, n_channels, n_times)
The original data
"""
if not isinstance(X, np.ndarray):
raise ValueError("epochs_data should be of type ndarray (got %s)."
% type(X))

return X.reshape(self.n_epochs, self.n_channels, self.n_times)


class PSDEstimator(TransformerMixin):
"""Compute power spectrum density (PSD) using a multi-taper method

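Finally, a hedged usage sketch of the class added above: transform() records n_epochs, n_channels and n_times so that inverse_transform() can restore the original 3D shape; synthetic arrays stand in for real epoch data.

# Hedged sketch using synthetic data instead of real epochs.
import numpy as np
from mne.decoding import EpochsVectorizer

epochs_data = np.zeros((8, 4, 50))           # (n_epochs, n_channels, n_times)
y = np.ones(8)                               # dummy labels; fit() only validates the input

vec = EpochsVectorizer()
X = vec.fit_transform(epochs_data, y)        # shape (8, 4 * 50) = (8, 200)
print(X.shape, vec.n_channels, vec.n_times)
restored = vec.inverse_transform(X)          # back to shape (8, 4, 50)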