Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improving feature names #42

Merged
merged 6 commits into from
Sep 28, 2018
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/plot_mne_sample_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
# Prepare for the classification task:

pipe = Pipeline([('scaler', StandardScaler()),
('lr', LogisticRegression(random_state=42))])
('lr', LogisticRegression(random_state=42, solver='lbfgs'))])
y = labels

###############################################################################
Expand Down
2 changes: 1 addition & 1 deletion examples/plot_mne_sample_gridsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
selected_funcs=['app_entropy',
'mean'])),
('scaler', StandardScaler()),
('clf', LogisticRegression(random_state=42))])
('clf', LogisticRegression(random_state=42, solver='lbfgs'))])
skf = StratifiedKFold(n_splits=3, random_state=42)
y = labels

Expand Down
2 changes: 1 addition & 1 deletion examples/plot_user_defined_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def compute_medfilt(arr):
pipe = Pipeline([('fe', FeatureExtractor(sfreq=raw.info['sfreq'],
selected_funcs=selected_funcs)),
('scaler', StandardScaler()),
('clf', LogisticRegression(random_state=42))])
('clf', LogisticRegression(random_state=42, solver='lbfgs'))])
skf = StratifiedKFold(n_splits=3, random_state=42)
y = labels

Expand Down
18 changes: 8 additions & 10 deletions mne_features/feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from .bivariate import get_bivariate_funcs
from .univariate import get_univariate_funcs
from .utils import _get_python_func


class FeatureFunctionTransformer(FunctionTransformer):
Expand Down Expand Up @@ -66,13 +67,19 @@ def transform(self, X, y='deprecated'):
details.
"""
X_out = super(FeatureFunctionTransformer, self).transform(X, y)
_feature_func = _get_python_func(self.func)
_params = self.get_params()
if hasattr(_feature_func, 'get_feature_names'):
self.feature_names = _feature_func.get_feature_names(X, **_params)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like that calling `transform` affects the state of the object; only `fit` is allowed to do this. Can you see a way out? Also, any attribute that is data-dependent should end with `_`.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Then, what about

    def fit(self, X, y=None):
        """Fit the FeatureFunctionTransformer (does not extract features).
        
        Parameters
        ----------
        X : ndarray, shape (n_channels, n_times)
        
        y : ignored
        
        Returns
        -------
        self
        """
        self._check_input(X)
        _feature_func = _get_python_func(self.func)
        _params = self.get_params()
        if hasattr(_feature_func, 'get_feature_names'):
            self.feature_names_ = _feature_func.get_feature_names(X, **_params)
        return self

in FeatureFunctionTransformer?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is OK in terms of API, but `_params = self.get_params()` could be moved inside the `if` block.

self.output_shape_ = X_out.shape[0]
return X_out

def get_feature_names(self):
    """Mapping of the feature indices to feature names.

    Returns
    -------
    feature_names
        The ``feature_names`` attribute when the instance has one;
        otherwise the indices ``0 .. output_shape_ - 1`` as an array of
        strings.

    Raises
    ------
    ValueError
        If the instance has no ``output_shape_`` attribute yet.
    """
    # Guard: `output_shape_` must have been set before names can be built.
    if not hasattr(self, 'output_shape_'):
        raise ValueError('Call `transform` or `fit_transform` first.')

    # Explicit names, when available, take precedence over generic indices.
    if hasattr(self, 'feature_names'):
        return self.feature_names

    generic_names = np.arange(self.output_shape_).astype(str)
    return generic_names

Expand All @@ -85,16 +92,7 @@ def get_params(self, deep=True):
If True, the method will get the parameters of the transformer.
(See :class:`~sklearn.preprocessing.FunctionTransformer`).
"""
_params = super(FeatureFunctionTransformer, self).get_params(deep=deep)
if hasattr(_params['func'], 'func'):
# If `_params['func'] is of type `functools.partial`
func_to_inspect = _params['func'].func
elif hasattr(_params['func'], 'py_func'):
# If `_params['func'] is a jitted Python function
func_to_inspect = _params['func'].py_func
else:
# If `_params['func'] is an actual Python function
func_to_inspect = _params['func']
func_to_inspect = _get_python_func(self.func)
# Get code object from the function
if hasattr(func_to_inspect, 'func_code'):
func_code = func_to_inspect.func_code
Expand Down
2 changes: 1 addition & 1 deletion mne_features/tests/test_feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_gridsearch_feature_extractor():
('clf', CheckingClassifier(
check_X=lambda arr: arr.shape[1:] == (X.shape[1],)))])
params_grid = {'FE__higuchi_fd__kmax': [5, 10]}
gs = GridSearchCV(estimator=pipe, param_grid=params_grid)
gs = GridSearchCV(estimator=pipe, param_grid=params_grid, cv=3)
gs.fit(X, y)
assert_equal(hasattr(gs, 'cv_results_'), True)

Expand Down
22 changes: 22 additions & 0 deletions mne_features/univariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,28 @@ def compute_pow_freq_bands(sfreq, data, freq_bands=np.array([0.5, 4., 8., 13.,
return band_ratios.ravel()


def _compute_pow_freq_bands_feat_names(data, freq_bands, normalize, ratios):
    """Utility function to create feature names compatible with the output
    of :func:`compute_pow_freq_bands`.

    Parameters
    ----------
    data : ndarray
        Only ``data.shape[0]`` is read; it is used as the number of channels.

    freq_bands : ndarray
        If 1d, the number of bands is ``freq_bands.shape[0] - 1``; otherwise
        it is ``freq_bands.shape[0]``.

    normalize : bool
        Not used to build the names. NOTE(review): presumably kept so that
        the signature accepts the same parameters as
        :func:`compute_pow_freq_bands` -- confirm against the caller.

    ratios : str or None
        If None, only the power names are returned. If ``'only'``, only the
        ratio names are returned. For any other value, the power names
        followed by the ratio names are returned.

    Returns
    -------
    list of str
        Power names are formatted as ``'ch<channel>_<band>'`` and ratio names
        as ``'ch<channel>_<band_i>_<band_j>'``.
    """
    n_channels = data.shape[0]
    if freq_bands.ndim == 1:
        n_freq_bands = freq_bands.shape[0] - 1
    else:
        n_freq_bands = freq_bands.shape[0]

    # Power names are only needed when the output is not restricted to ratios.
    pow_names = []
    if ratios != 'only':
        pow_names = ['ch%s_%s' % (ch_num, i)
                     for ch_num in range(n_channels)
                     for i in range(n_freq_bands)]
    if ratios is None:
        return pow_names

    # The pairs of band indices do not depend on the channel: compute them
    # once instead of re-running the iterator for every channel.
    band_pairs = [(i, j) for _, i, j in _idxiter(n_freq_bands, triu=False)]
    ratios_names = ['ch%s_%s_%s' % (ch_num, i, j)
                    for ch_num in range(n_channels)
                    for i, j in band_pairs]
    return pow_names + ratios_names


# Attach the naming helper to the feature function itself, so that code
# holding the function can discover it with
# ``hasattr(func, 'get_feature_names')``.
compute_pow_freq_bands.get_feature_names = _compute_pow_freq_bands_feat_names


def compute_hjorth_mobility_spect(sfreq, data, normalize=False):
"""Hjorth mobility (per channel).

Expand Down
23 changes: 23 additions & 0 deletions mne_features/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,29 @@ def _get_feature_funcs(sfreq, module_name):
return feature_funcs


def _get_python_func(func):
    """Get the Python function underlying partial or jitted functions.

    Parameters
    ----------
    func : function or instance of `functools.partial` or jitted function.
        Transformed feature function.

    Returns
    -------
    function
        ``func.func`` if `func` has a ``func`` attribute (e.g. a
        `functools.partial`), else ``func.py_func`` if it has a ``py_func``
        attribute (e.g. a jitted function), else `func` itself. Only one
        level of wrapping is removed.
    """
    # Order matters: the `partial`-style attribute is looked up before the
    # jitted-function attribute.
    for wrapped_attr in ('func', 'py_func'):
        if hasattr(func, wrapped_attr):
            return getattr(func, wrapped_attr)
    # Neither attribute is present: `func` is an actual Python function.
    return func


def _wavelet_coefs(data, wavelet_name='db4'):
"""Compute Discrete Wavelet Transform coefficients.

Expand Down