ENH: decoding API #2290

Closed
trachelr opened this issue Jul 14, 2015 · 23 comments


@trachelr
Contributor

I suggest enhancing the decoding API. Currently there's a classifier.py module that contains a set of transformers (Scaler, ConcatenateChannels, FilterEstimator, PSDEstimator).
I think we should rename this module to transformer.py and create two new modules, one for classification and one for regression, containing classes like LinearClassifier / LinearRegressor.
The idea is to have a bunch of transformers that could be chained into a sklearn pipeline with a LinearClassifier (or Regressor). It would look like this:

import mne
from sklearn.pipeline import Pipeline
# proposed modules: mne.decoding.transformer / mne.decoding.classifier
from mne.decoding.transformer import Scaler, FilterEstimator, PSDEstimator, ConcatenateChannels
from mne.decoding.classifier import LinearClassifier

# get some epoch data (raw, events, event_id, tmin, tmax and picks defined beforehand)
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True,
                    picks=picks, baseline=None, preload=True)
epochs_data = epochs.get_data()
labels = epochs.events[:, -1]

# create pipeline instances
sc = Scaler(epochs.info)
psd = PSDEstimator(sfreq=epochs.info['sfreq'], fmin=3, fmax=25)
cat = ConcatenateChannels(epochs.info)

# create a linear classifier from a sklearn.linear_model instance
from sklearn.linear_model import RidgeClassifier
ridge = RidgeClassifier()
lin = LinearClassifier(ridge)

# Pipeline takes a list of (name, estimator) tuples
pipe = Pipeline([('power', psd), ('concat', cat), ('scale', sc), ('linear', lin)])
pipe.fit(epochs_data, labels)

# look at the linear model's patterns and/or filters
# (each step is a (name, estimator) tuple, so index 1 is the estimator)
pipe.steps[-1][1].plot_patterns()
pipe.steps[-1][1].plot_filters()

The LinearClassifier would be a simple wrapper around a sklearn.linear_model estimator that benefits from the compute_pattern function.
The idea is to have mne.decoding tools (transformers, classifiers, regressors) that can be easily chained into a pipeline while keeping info throughout the pipeline.
What do you think about this?

@kingjr
Member

kingjr commented Jul 14, 2015

Yes, that could be useful to have: clf = LinearClassifier(clf), so as to have clf.patterns_. However, I wouldn't put any plotting functions there; if people want to use these things for time-frequency data, source space, etc., you won't be able to support this.
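A minimal sketch of the wrapper idea, assuming a hypothetical `LinearClassifier` (the proposal's name, not an existing API in this thread) that wraps any scikit-learn linear model exposing `coef_`. Following the suggestion to keep plotting out, it only stores `filters_` (the model weights) and `patterns_`, computed with the Haufe et al. (2014) transformation A = cov(X) W cov(XW)^-1:

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, clone
from sklearn.linear_model import RidgeClassifier


class LinearClassifier(BaseEstimator, ClassifierMixin):
    """Hypothetical wrapper exposing filters_ and patterns_ of a linear model."""

    def __init__(self, model=None):
        self.model = model

    def fit(self, X, y):
        model = RidgeClassifier() if self.model is None else self.model
        # Fit a copy so the estimator passed to __init__ is left untouched
        self.model_ = clone(model).fit(X, y)
        self.classes_ = self.model_.classes_
        W = np.atleast_2d(self.model_.coef_).T  # (n_features, n_outputs)
        cov_X = np.cov(X, rowvar=False)  # covariance of the input features
        # covariance of the decoded signal(s), kept 2D for the single-output case
        cov_S = np.atleast_2d(np.cov(X @ W, rowvar=False))
        self.filters_ = W.T
        # patterns: A = cov(X) W cov(XW)^-1, stored as (n_outputs, n_features)
        self.patterns_ = (cov_X @ W @ np.linalg.pinv(cov_S)).T
        return self

    def predict(self, X):
        return self.model_.predict(X)
```

Because the model is cloned in `fit`, the wrapper can sit in a `Pipeline` like any other scikit-learn estimator, while plotting the patterns is left to channel-, time-frequency- or source-aware code elsewhere.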

@choldgraf
Contributor

I was just looking into TFR decompositions in MNE - sorry if I am late to the party here, but I'm wondering why functions like PSDEstimator exist in a different place than the time_frequency module. Does it do anything different from the functions in time_frequency (beyond being a class instead of function)?

@jasmainak
Member

I think the idea was to be able to use these objects in an sklearn pipeline. If they don't follow sklearn API, you can't do that.
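To illustrate what following the scikit-learn API means in practice: an object only needs `fit` (returning `self`) and `transform` to be usable as a `Pipeline` step. The sketch below uses a hypothetical `ConcatenateChannelsSketch`, not MNE's actual class, that flattens epochs data before a classifier; the data are synthetic:

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline


class ConcatenateChannelsSketch(BaseEstimator, TransformerMixin):
    """Hypothetical stand-in for ConcatenateChannels: flattens epochs data."""

    def fit(self, X, y=None):
        # Stateless: nothing to learn, but fit must exist and return self
        return self

    def transform(self, X):
        # (n_epochs, n_channels, n_times) -> (n_epochs, n_channels * n_times)
        X = np.asarray(X)
        return X.reshape(X.shape[0], -1)


# Synthetic epochs: class 1 has an offset on channel 0
rng = np.random.RandomState(42)
epochs_data = rng.randn(40, 3, 10)
labels = np.repeat([0, 1], 20)
epochs_data[labels == 1, 0, :] += 1.0

pipe = Pipeline([("concat", ConcatenateChannelsSketch()),
                 ("clf", LogisticRegression())])
pipe.fit(epochs_data, labels)
```

Inheriting from `BaseEstimator` gives `get_params`/`set_params`, and `TransformerMixin` adds `fit_transform`, which is what lets such a class be cloned and cross-validated inside a pipeline.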

@choldgraf
Contributor

Yeah, that makes sense to me. I could just see people getting confused as to why there's this other function here that seems to replicate one in the time_frequency section. Probably not a big deal, but it took me a few passes through the code to realize it was just a convenient wrapper to structure things with sklearn.

@choldgraf
Contributor

And I should mention that this also implements a helpful new addition: running a multitaper PSD on epochs objects. Right now multitaper_psd expects an n_signals x n_times array, whereas this one lets you pass a 3D array and takes care of the reshaping etc. under the hood.
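The reshaping described here can be sketched as follows, assuming a simple DPSS multitaper estimator (`multitaper_psd_2d`, hypothetical and written for illustration, not MNE's `multitaper_psd`) that only accepts 2D input. The epochs-level wrapper flattens epochs and channels into one signal axis, calls the 2D function, and restores the original shape:

```python
import numpy as np
from scipy.signal.windows import dpss


def multitaper_psd_2d(x, sfreq, bandwidth=4.0):
    """Hypothetical multitaper PSD of a 2D array shaped (n_signals, n_times)."""
    n_times = x.shape[-1]
    # Time-half-bandwidth product NW for a full bandwidth of `bandwidth` Hz
    half_nbw = bandwidth * n_times / (2.0 * sfreq)
    n_tapers = max(int(2 * half_nbw) - 1, 1)
    tapers = dpss(n_times, half_nbw, Kmax=n_tapers)  # (n_tapers, n_times), unit energy
    # Taper every signal: (n_signals, n_tapers, n_times) -> FFT along time
    spectra = np.fft.rfft(x[:, np.newaxis, :] * tapers, axis=-1)
    psd = (np.abs(spectra) ** 2).mean(axis=1) / sfreq  # average over tapers
    psd[:, 1:] *= 2.0  # one-sided spectrum: fold in negative frequencies
    if n_times % 2 == 0:
        psd[:, -1] /= 2.0  # the Nyquist bin has no negative counterpart
    freqs = np.fft.rfftfreq(n_times, d=1.0 / sfreq)
    return psd, freqs


def multitaper_psd_epochs(epochs_data, sfreq, bandwidth=4.0):
    """Apply the 2D estimator to (n_epochs, n_channels, n_times) data."""
    n_epochs, n_channels, n_times = epochs_data.shape
    flat = epochs_data.reshape(n_epochs * n_channels, n_times)
    psd, freqs = multitaper_psd_2d(flat, sfreq, bandwidth)
    return psd.reshape(n_epochs, n_channels, -1), freqs
```

A fuller implementation would also offer adaptive weighting and dropping of poorly concentrated tapers, which this sketch omits.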

@jasmainak
Member

Ah, I think this would be a good candidate for improved documentation. Relates to #1495. If you feel like making a contribution, it would be very welcome and much appreciated. Maybe it should go into a new section on Decoding in the manual.

@choldgraf
Contributor

Sounds good. Do you think it's better to include this in the docstring of one of these functions, or as a separate tutorial? It would be useful to have a single notebook on "estimating the PSD with MNE-python".

@kingjr
Member

kingjr commented Dec 15, 2015

I'm +1 on this.

@choldgraf
Contributor

On that note, would people be +1 on this function adding a 'freqs' attribute when it estimates the PSD, rather than just throwing it away? And, potentially, adding a 'data' attribute once it's run as well, so that the result is contained within the object? I see that this is starting to stray from the original intended usage, but I think this is a useful structure for estimating the PSD in general.

@choldgraf
Contributor

See #2710 for a first pass

@jasmainak
Member

You might want to update the "See Also" section of the docstring. I'm not so sure about retaining attributes. They are usually retained in the fit method when you do decoding. Isn't there an equivalent function for epochs? Maybe this one?

def compute_epochs_psd(epochs, picks=None, fmin=0, fmax=np.inf, tmin=None,

+1 on a simple tutorial.

@choldgraf
Contributor

There is a similar function for epochs, but it hasn't been updated to include the multitaper PSD. That's why I assumed that this class was actually an extension to the methods in time_frequency.

@jasmainak
Member

umm ... no, I think this class is meant for a different purpose: decoding. I guess you could use it for now if it does the job. But the cleaner approach would be to update the function for epochs. @agramfort might be able to advise better.

@jasmainak
Member

Hey @choldgraf : were you looking for this? http://martinos.org/mne/stable/generated/mne.time_frequency.tfr_multitaper.html?highlight=tfr_multitaper#mne.time_frequency.tfr_multitaper

update: I actually realized that this is for TFR ... so yeah, we are still missing something for PSD

@choldgraf
Contributor

I think the confusion comes from the fact that this class adds something different that doesn't exist in time_frequency.

To my knowledge, in time_frequency you have a few options:

  1. multitaper_psd - which does multitaper PSD but accepts n_signals x n_times, not epochs
  2. compute_epochs_psd - which accepts n_trials x n_signals x n_times but uses a Welch PSD
  3. tfr_XXX - which does a time-frequency decomp and not a PSD

So this is the only thing that both allows you to use an epochs-shaped array and lets you use a multitaper method for calculating a PSD.

I agree that the decoding stuff should be in decoding, and the time-frequency stuff should be in time_frequency. That's why I was confused about this class in the first place.

@jasmainak
Member

yeah, I agree that it's confusing. If you can help make it consistent and/or improve the documentation via manual / tutorial pages, that's more than welcome :)

@choldgraf
Contributor

Well, I guess the question is whether we:

A. Make it clearer with documentation etc., or
B. Make a function in time_frequency that lets you do epochs PSD estimation with the multitaper method (i.e., one that mimics the functionality of this class), and then just have this class call that function rather than doing the reshaping on its own

I'm partial to B...
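Option B could look roughly like this, assuming a hypothetical function `psd_multitaper_epochs` living in time_frequency and a thin transformer (`PSDEstimatorSketch`, also hypothetical) that delegates to it. The transformer stores `freqs_` in `fit`, in line with the scikit-learn convention that learned attributes are set there. The multitaper estimate is deliberately simplified (no adaptive weighting, no one-sided scaling):

```python
import numpy as np
from scipy.signal.windows import dpss
from sklearn.base import BaseEstimator, TransformerMixin


def psd_multitaper_epochs(data, sfreq, fmin=0.0, fmax=np.inf, bandwidth=4.0):
    """Hypothetical time_frequency-style function: multitaper PSD on the last axis.

    Works for (n_signals, n_times) and (n_epochs, n_channels, n_times) alike.
    """
    data = np.asarray(data, dtype=float)
    n_times = data.shape[-1]
    half_nbw = bandwidth * n_times / (2.0 * sfreq)  # time-half-bandwidth product
    tapers = dpss(n_times, half_nbw, Kmax=max(int(2 * half_nbw) - 1, 1))
    # Broadcast tapers over all leading axes: (..., n_tapers, n_times)
    spectra = np.fft.rfft(data[..., np.newaxis, :] * tapers, axis=-1)
    # Power per frequency averaged over tapers (no one-sided x2 correction)
    psd = (np.abs(spectra) ** 2).mean(axis=-2) / sfreq
    freqs = np.fft.rfftfreq(n_times, d=1.0 / sfreq)
    mask = (freqs >= fmin) & (freqs <= fmax)
    return psd[..., mask], freqs[mask]


class PSDEstimatorSketch(BaseEstimator, TransformerMixin):
    """Hypothetical thin sklearn wrapper delegating to psd_multitaper_epochs."""

    def __init__(self, sfreq, fmin=0.0, fmax=np.inf, bandwidth=4.0):
        self.sfreq = sfreq
        self.fmin = fmin
        self.fmax = fmax
        self.bandwidth = bandwidth

    def fit(self, X, y=None):
        # Learned attributes are set in fit; one epoch is enough to get freqs_
        _, self.freqs_ = psd_multitaper_epochs(
            np.asarray(X)[:1], self.sfreq, self.fmin, self.fmax, self.bandwidth)
        return self

    def transform(self, X):
        psd, _ = psd_multitaper_epochs(
            X, self.sfreq, self.fmin, self.fmax, self.bandwidth)
        return psd
```

Keeping the numerical work in the function means the same code path would serve both the time_frequency API and the decoding pipeline.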

@jasmainak
Member

yeah, I'm good with B. +1 from my end. Any other opinion?

@jasmainak
Member

Maybe @trachelr could comment on this ...

@choldgraf
Contributor

Cool - I can make a PR out of it if the addition isn't too complex. I'll wait to see if others have thoughts.

@agramfort
Member

agramfort commented Dec 16, 2015 via email

@trachelr
Contributor Author

+1 for option B, and for making tutorials for decoding

@larsoner changed the title from ENH: decoding API label:"mne sprint" to ENH: decoding API on Jan 6, 2017
@agramfort
Member

we've made progress now. Closing

7 participants