-
Notifications
You must be signed in to change notification settings - Fork 181
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] Self-ensembling for sequence-to-sequence models #294
Conversation
Codecov Report
@@ Coverage Diff @@
## master #294 +/- ##
==========================================
+ Coverage 80.10% 80.19% +0.08%
==========================================
Files 49 49
Lines 2931 2944 +13
==========================================
+ Hits 2348 2361 +13
Misses 583 583 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
have you tried this on a real dataset already @hubertjb ?
braindecode/models/util.py
Outdated
------- | ||
np.ndarray : | ||
Array of shape (n_rows, n_classes, (n_rows - 1) * stride + n_windows) | ||
where each row is obtained by zero-padding the corresponding row in `x` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
where each row is obtained by zero-padding the corresponding row in `x` | |
where each row is obtained by zero-padding the corresponding row in ``x`` |
braindecode/models/util.py
Outdated
""" | ||
if x.ndim != 3: | ||
raise NotImplementedError( | ||
f'x must be of shape (n_rows, n_clases, n_windows), got {x.shape}') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
f'x must be of shape (n_rows, n_clases, n_windows), got {x.shape}') | |
f'x must be of shape (n_rows, n_classes, n_windows), got {x.shape}') |
test/unit_tests/models/test_util.py
Outdated
n_outputs = (n_sequences - 1) * stride + n_windows | ||
shifted_y = np.concatenate([ | ||
np.concatenate(( | ||
np.zeros((1, n_classes, i * stride)), dense_y[[i]], | ||
np.zeros((1, n_classes, n_outputs - n_windows - i * stride))), | ||
axis=2) | ||
for i in range(n_sequences)], axis=0) | ||
shifted_y2 = _pad_shift_array(dense_y, stride=stride) | ||
|
||
assert (shifted_y == shifted_y2).all() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
are you just testing zero arrays here? might be good to have somethign that tests actual values as well, like 1-2 more explicit tests with predefined values
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is testing against actual values - the concatenation only adds zeros around the values in dense_y
.
test/unit_tests/models/test_util.py
Outdated
def test_aggregate_probas(n_sequences, n_classes, n_windows, stride): | ||
n_outputs = (n_sequences - 1) * stride + n_windows | ||
y_true = np.arange(n_outputs) % n_classes | ||
logits = OneHotEncoder(sparse=False).fit_transform(y_true.reshape(-1, 1)) | ||
logits = np.lib.stride_tricks.sliding_window_view( | ||
logits, n_windows, axis=0)[::stride] | ||
|
||
y_pred_probas = aggregate_probas(logits, n_windows_stride=stride) | ||
|
||
assert y_pred_probas.ndim == 2 | ||
assert y_pred_probas.shape[0] == n_outputs | ||
assert y_pred_probas.shape[1] == n_classes | ||
assert (y_true == y_pred_probas.argmax(axis=1)).all() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you describe a bit more what you are testing here? like write a comment what is being tested? are values actually being tested? or just shapes?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll add a few comments to make this clearer. But yes, this tests both shapes and values!
maybe write a short line in whats new @hubertjb ? Then I can merge |
bd74581
to
8963567
Compare
Great, merged |
This PR introduces a utility function
aggregate_probas
to perform self-ensembling as described in Phan et al. (2018). This works by aggregating the window-wise predictions of a seq2seq model when fed with overlapping sequences. This will be useful e.g. in #282.To simplify its use, I also added an attribute
file_ids
toSequenceSampler
so we can easily see which file each sequence comes from.In the end, the function can be used in the following way given a seq2seq model
clf
(e.g. anEEGClassifier
that wrapsUSleep
):@tgnassou @agramfort