[MRG] Predict trials #312

Merged
merged 9 commits into from
Aug 9, 2021


Collaborator

@gemeinl gemeinl commented Jul 26, 2021

I am not sure I fully understood the issue. Is this what you had in mind, @robintibor?
What would be the best location for predict_trials?

@gemeinl gemeinl changed the title Predict trials [WIP] Predict trials Jul 26, 2021
@@ -268,3 +270,17 @@ def predict(self, X):

        """
        return self.predict_proba(X).argmax(1)

    def predict_trials(self, X):
        """Return trialwise predictions and targets.
Contributor

Great. Can you add one sentence on the expected output shapes of trial_predictions and trial_labels, including the meaning of the dimensions of trial_predictions? That should be helpful.

        -------
        trial_predictions, trial_labels: tuple(np.ndarray, np.ndarray)
        """
        return predict_trials(self.module, X)
Contributor

Maybe also check whether self.cropped is True, and raise an error if not?

Collaborator Author

Well, should there be an error? We always have trials, regardless of whether we do cropped decoding or not.

Contributor

I think:
a) Right now the code would fail without cropped decoding? There are a lot of assumptions inside that come from cropped decoding, no? Or would it just run?
b) For trialwise decoding, the existing predict method would already give you trialwise predictions, right?

If that is correct, either we assert that we are in cropped mode (and maybe also reflect this in the name, like predict_trials_from_cropped), or we just call the regular predict function if not in cropped mode. We should check what currently happens when this function is called, e.g. at the end of the trialwise decoding example.

Collaborator Author

I am not familiar enough with the internals of cropped decoding to know about these assumptions.
In my opinion, predict_trials should always work, for both cropped and trialwise decoding, since we always have trials.
For trialwise decoding the output will then just be the same as calling predict. For cropped decoding it will be different.

        -------
        trial_predictions, trial_labels: tuple(np.ndarray, np.ndarray)
        """
        return predict_trials(self.module, X)
Contributor

Same here regarding the comments above. Maybe we could also consider a superclass EEGNeuralNet for both, to avoid duplication?
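A minimal sketch of the deduplication idea, assuming a mixin rather than a single EEGNeuralNet superclass, since EEGClassifier and EEGRegressor already derive from different skorch bases (NeuralNetClassifier and NeuralNetRegressor). All names here (`TrialPredictionMixin`, `standalone_predict_trials`, `ToyClassifier`, `ToyRegressor`) are hypothetical, and the stand-in function only runs a callable over (X, y) pairs; it does not reproduce braindecode's trial stitching.

```python
import numpy as np


def standalone_predict_trials(module, dataset, return_targets=True):
    """Hypothetical stand-in for a module-level predict_trials function.

    Runs `module` (any callable) on every input in `dataset`, an iterable of
    (X, y) pairs, and stacks the outputs. No window stitching is done here.
    """
    preds, targets = [], []
    for X, y in dataset:
        preds.append(np.asarray(module(X)))
        targets.append(y)
    preds = np.stack(preds)
    if return_targets:
        return preds, np.asarray(targets)
    return preds


class TrialPredictionMixin:
    """Methods shared by both estimators, defined once instead of twice."""

    def predict_trials(self, X, return_targets=True):
        # Delegate to the module-level function, so users without skorch
        # can call the same logic directly on a plain module.
        return standalone_predict_trials(
            self.module, X, return_targets=return_targets)


class ToyClassifier(TrialPredictionMixin):
    def __init__(self, module):
        self.module = module


class ToyRegressor(TrialPredictionMixin):
    def __init__(self, module):
        self.module = module
```

In the real estimators the mixin would presumably be listed before the skorch base, e.g. `class EEGClassifier(TrialPredictionMixin, NeuralNetClassifier)`, so that its methods come first in the MRO.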

Collaborator Author

Yes, definitely. I just added one more duplicate to the large list of code duplicates in those classes...

@robintibor
Contributor

Thanks, great, yes, that is what I had in mind. I was wondering whether the function name should somehow indicate that trial labels are returned as well, since otherwise it might be surprising. Or maybe just make return_labels: bool a parameter of the function that controls whether labels are returned? Even with a default of True, this would indicate more explicitly what is happening.

And of course it needs tests and a whats_new entry.

@gemeinl
Collaborator Author

gemeinl commented Jul 27, 2021

Yes, it is surprising, but it is what you requested.
How about we rename to get_trial_preds_and_labels?
Or do you like the return_labels flag better?

@robintibor
Contributor

I think the return_labels flag is better; predict_trials is quite a nice name.

@@ -30,6 +30,7 @@ Enhancements
- Adding Mixup augmentation :class:`braindecode.augmentation.Mixup` (:gh:`254` by `Simon Brandt`_)
- Adding saving of preprocessing and windowing choices in :func:`braindecode.preprocessing.preprocess`, :func:`braindecode.preprocessing.create_windows_from_events` and :func:`braindecode.preprocessing.create_fixed_length_windows` to datasets to facilitate reproducibility (:gh:`287` by `Lukas Gemein`_)
- Adding :func:`braindecode.models.util.aggregate_probas` to perform self-ensembling of predictions with sequence-to-sequence models (:gh:`294` by `Hubert Banville`_)
- Adding :func:`braindecode.training.scoring.predict_trials` to generate trialwise predictions (:gh:`312` by `Lukas Gemein`_)
Contributor

Suggested change
- Adding :func:`braindecode.training.scoring.predict_trials` to generate trialwise predictions (:gh:`312` by `Lukas Gemein`_)
- Adding :func:`braindecode.training.scoring.predict_trials` to generate trialwise predictions after cropped training (:gh:`312` by `Lukas Gemein`_)

To make it even clearer.

Collaborator Author

I still disagree. For me, this is not specific to cropped decoding. See the comment above.

Collaborator Author

@robintibor In braindecode.training.scoring.predict_trials we cannot know whether the model was trained in a cropped fashion; we need an EEGClassifier / EEGRegressor for this. Is it safe to add this function anyway? Or should we remove it and only have EEGClassifier/EEGRegressor.predict_trials(), which would then check self.cropped?

Collaborator Author

What about this, @robintibor?

Contributor

I think we should have it so people can use this if they are not using skorch.


Returns
-------
trial_predictions, trial_labels: tuple(np.ndarray, np.ndarray)
Collaborator

Please specify the dimensions of each array. This tuple(np.ndarray, np.ndarray) syntax looks odd to me.

Collaborator Author

Collaborator

Yes. I wouldn't expect dimensions in the summary line of a docstring; they belong in the Returns section.

You should skim through https://www.python.org/dev/peps/pep-0257/

We should activate https://github.com/PyCQA/pydocstyle on the repo.
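For illustration, a numpydoc-style docstring along the lines requested might look like this: a one-line summary without shapes (per PEP 257) and a Returns section that lists each returned array with its shape. The function name and the parameter types are assumptions for this sketch; the dimension names come from this thread.

```python
def predict_trials_docstring_sketch(module, dataset, return_targets=True):
    """Return trialwise predictions and, optionally, trialwise targets.

    Parameters
    ----------
    module : torch.nn.Module
        A trained network that outputs several predictions per window.
    dataset : braindecode.datasets.BaseConcatDataset
        Windowed dataset whose compute windows are stitched back into trials.
    return_targets : bool
        Whether to also return the trialwise targets.

    Returns
    -------
    trial_predictions : np.ndarray, shape (n_trials, n_classes, n_predictions)
        Predictions per trial; overlapping predictions of neighboring
        compute windows are removed.
    trial_labels : np.ndarray, shape (n_trials, n_targets)
        Targets per trial. Only returned if ``return_targets`` is True.
    """
    # Body intentionally omitted: this sketch only shows the docstring layout.
    raise NotImplementedError
```

Listing each returned value under its own name replaces the `tuple(np.ndarray, np.ndarray)` notation, and pydocstyle can be configured to enforce this convention.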

Collaborator Author

I see, thank you. I support the idea for more automated checks to ensure we follow conventions.

@@ -295,3 +297,47 @@ def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs):
            cached_net, dataset_train, self.y_trues_
        )
        self._record_score(net.history, current_score)


def predict_trials(module, dataset, return_targets=True):
Collaborator

Why add two ways of doing the same thing, i.e. the method and this public function?

Collaborator Author

I guess braindecode.training.scoring.predict_trials exists because you don't need an EEGClassifier / EEGRegressor to make predictions. It is sufficient to have any kind of model given as a PyTorch module, together with a braindecode dataset.

However, you would expect your estimator to provide predict..., right?


def predict_trials(module, dataset, return_targets=True):
    """Create trialwise predictions (n_trials x n_classes x n_predictions),
    and optionally also return trialwise labels (n_trials x n_targets) from
Collaborator

Are n_targets and n_predictions conceptually different things? Sorry, but I get confused by n_trials, which I would call n_crops etc.

Collaborator Author

n_predictions is in the time domain and depends on window_size and on the size of the receptive field of the net.
n_targets can be very different. It can be a single value in classification / regression tasks, multiple values in multiple discrete target classification as introduced with #267, or a sequence as in #261.

Why do you get confused by n_trials? That is the point: we do not have crops / compute windows at this point. These are actual trials. If you want to generate crop / compute window predictions, you would call .predict() instead. In predict_trials, we invert the creation of compute windows (removing potentially overlapping predictions) to obtain trial predictions.
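To make the inversion concrete, here is a minimal NumPy sketch under stated assumptions: a dense-prediction net with stride 1, so a window of window_size samples yields window_size - receptive_field_size + 1 predictions (the valid outputs of the net); the last prediction of each window belongs to the window's last sample (i_stop - 1); and windows of a trial are ordered by start sample. The helper `stitch_trial_predictions` and its arguments are hypothetical, not braindecode's implementation. Trials are returned as a list because variable-length trials yield different n_predictions.

```python
import numpy as np


def stitch_trial_predictions(window_preds, i_window_in_trial, i_stop_in_trial,
                             window_targets=None):
    """Hypothetical helper: merge per-window predictions into per-trial ones.

    window_preds: array of shape (n_windows, n_classes, n_preds_per_window).
    i_window_in_trial: index of each window within its trial (0 = first).
    i_stop_in_trial: stop sample (exclusive) of each window within its trial.
    window_targets: optional per-window targets; the target of each trial's
        first window is used as the trial target.
    """
    trial_preds, trial_targets = [], []
    prev_stop = None
    for i, (preds, i_win, i_stop) in enumerate(
            zip(window_preds, i_window_in_trial, i_stop_in_trial)):
        if i_win == 0:
            # First window of a new trial: keep all of its predictions.
            trial_preds.append([preds])
            if window_targets is not None:
                trial_targets.append(window_targets[i])
        else:
            if prev_stop is None:
                raise ValueError("first window must have i_window_in_trial == 0")
            # Only predictions for samples after the previous window's stop
            # are new; the earlier ones overlap and are dropped.
            n_new = i_stop - prev_stop
            if not 0 < n_new <= preds.shape[1]:
                raise ValueError("windows must be ordered and leave no gaps")
            trial_preds[-1].append(preds[:, preds.shape[1] - n_new:])
        prev_stop = i_stop
    stitched = [np.concatenate(parts, axis=1) for parts in trial_preds]
    if window_targets is not None:
        return stitched, np.asarray(trial_targets)
    return stitched
```

With one window and one prediction per trial (trialwise decoding), each trial simply keeps its window's prediction, consistent with the point that the output then equals calling predict. If all trials have equal length, `np.stack(stitched)` gives the (n_trials, n_classes, n_predictions) array.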

Collaborator

OK, much clearer 🙏. Maybe it's worth adding a glossary, as we do with MNE (https://mne.tools/stable/glossary.html)?

@gemeinl
Collaborator Author

gemeinl commented Jul 30, 2021

Why does test_variable_length_trials_decoding fail?

@codecov

codecov bot commented Jul 30, 2021

Codecov Report

Merging #312 (8855c7a) into master (a436f25) will increase coverage by 0.11%.
The diff coverage is 90.47%.

@@            Coverage Diff             @@
##           master     #312      +/-   ##
==========================================
+ Coverage   80.27%   80.38%   +0.11%     
==========================================
  Files          49       49              
  Lines        3047     3085      +38     
==========================================
+ Hits         2446     2480      +34     
- Misses        601      605       +4     

@robintibor
Contributor

It worked on rerun, so we probably just need to increase the tolerance even further.

@gemeinl gemeinl changed the title [WIP] Predict trials [MRG] Predict trials Aug 9, 2021
@gemeinl
Collaborator Author

gemeinl commented Aug 9, 2021

Done from my side, unless @robintibor disagrees with the current implementation regarding the cropped / trialwise stuff.

Comment on lines 328 to 331
    cropped_data = sum(dataset.get_metadata()['i_window_in_trial'] != 0) > 0
    if not cropped_data:
        raise ValueError('This function was designed to predict trials from '
                         'cropped datasets. This is a trialwise dataset.')
Contributor

Suggested change
    cropped_data = sum(dataset.get_metadata()['i_window_in_trial'] != 0) > 0
    if not cropped_data:
        raise ValueError('This function was designed to predict trials from '
                         'cropped datasets. This is a trialwise dataset.')
    more_than_one_window = sum(dataset.get_metadata()['i_window_in_trial'] != 0) > 0
    if not more_than_one_window:
        warnings.warn('This function was designed to predict trials from '
                      'cropped datasets, which typically have multiple compute '
                      'windows per trial. The given dataset has exactly one '
                      'window per trial.')

Is it certain that this must be the case? I think this is not necessarily true. Maybe at least downgrade it to a warning instead of an error?
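A minimal sketch of the suggested warning-based check, assuming the metadata is a pandas DataFrame with an i_window_in_trial column, as returned by get_metadata() in the snippet above. The helper name `check_multiple_windows_per_trial` is hypothetical; `.any()` gives the same result as `sum(...) > 0`.

```python
import warnings

import pandas as pd


def check_multiple_windows_per_trial(metadata: pd.DataFrame) -> bool:
    """Return True if any trial has more than one compute window, else warn."""
    # i_window_in_trial counts windows within each trial starting at 0, so any
    # non-zero entry means at least one trial was split into several windows.
    more_than_one_window = bool((metadata["i_window_in_trial"] != 0).any())
    if not more_than_one_window:
        # A single window per trial may still be cropped decoding if the net
        # emits several predictions per window, so warn instead of raising.
        warnings.warn(
            "This function was designed to predict trials from cropped "
            "datasets, which typically have multiple compute windows per "
            "trial. The given dataset has exactly one window per trial."
        )
    return more_than_one_window
```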

Collaborator Author

I could not come up with a case where it does not hold. Isn't it the definition of trialwise decoding to have one window per trial?

Contributor

@robintibor robintibor Aug 9, 2021

One can consider (and, as far as I recall, that is at least what I discussed with Tonio) that trialwise decoding implies a single window and a single prediction per trial, whereas if you have a single window but multiple predictions, you can still consider it cropped decoding, and all the existing cropped decoding code should run fine as well.

Collaborator Author

OK. Then I will update it according to your suggestions.

@robintibor
Contributor

Made one comment. What do you think, @gemeinl?

@robintibor robintibor merged commit 5b826c4 into braindecode:master Aug 9, 2021
@robintibor
Contributor

Great stuff, merged now.

Development

Successfully merging this pull request may close these issues.

Utility functions for computing trial predictions after cropped training
3 participants