[MRG] Adding SequenceDataset class to train on sequences of windows #263
Conversation
Also, I was not sure whether we would want to keep both sleep staging tutorials - maybe just this one would be enough? Although I like that the old one is fairly straightforward.
braindecode/datasets/base.py
Outdated
    On-the-fly transform applied to a window before it is returned.
seq_len : int
    Number of consecutive windows in a sequence.
step_len : int
Does the step_len parameter correspond to stride in the create_windows functions? Maybe stride would be a better name then.
Maybe we could name them n_windows_seq and n_windows_stride?
I would vote for n_windows and n_windows_stride
Also fine for me
braindecode/datasets/base.py
Outdated
seq, ys = list(), list()
# The following cannot be put into a list comprehension
# because of scope requirements for `super()`
for i in range(start_ind, start_ind + self.seq_len):
is this fast enough for our application? Can't we use __getitem__ to obtain multiple windows? I see that this line does not allow it: X = self.windows.get_data(item=index)[0].astype('float32'), but maybe we can modify WindowsDataset to make it possible? I think that you can create an additional method that returns multiple windows and use it both here and in WindowsDataset.__getitem__, where you would select only the first epoch.
indeed some performance measurements may be helpful. in case it is fast enough, it does seem nicely decoupled at least this way.
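For reference, mne.Epochs.get_data takes an item argument, so reading several consecutive windows in one call could look roughly like the sketch below. The function name, variables and the 'target' metadata column are illustrative assumptions, not the PR's actual code.

def get_sequence(windows, start_ind, seq_len):
    """Sketch: read ``seq_len`` consecutive windows from an mne.Epochs
    object in a single get_data call instead of one call per window."""
    indices = list(range(start_ind, start_ind + seq_len))
    # ``item`` selects epochs, so this returns an array of shape
    # (seq_len, n_channels, n_times).
    X = windows.get_data(item=indices).astype('float32')
    # Assumes the windows metadata has a 'target' column, as in braindecode.
    y = windows.metadata.iloc[indices]['target'].to_numpy()
    return X, y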
braindecode/datasets/base.py
Outdated
@@ -291,3 +375,27 @@ def save(self, path, overwrite=False):
    if concat_of_raws:
        json.dump({'target_name': target_name}, open(target_file_name, 'w'))
    self.description.to_json(description_file_name)

def get_sequence_dataset(concat_ds, seq_len, step_len=1, label_transform=None):
Maybe to be more consistent with the naming it would be better to use create_sequence_dataset?
To me it looks like this function is quite similar to the create_windows functions: it takes a concat_ds and returns a BaseConcatDataset. Maybe we should move it to preprocessing/windowers.py? I feel like datasets/base.py is not the best place for this function, though preprocessing/windowers.py may not be perfect either.
definitely also feel create_sequence_dataset would be nicer
+1
    n_channels, n_times).
    """
    feats = [self.feat_extractor.embed(x[:, i]) for i in range(x.shape[1])]
    feats = torch.stack(feats, dim=0).transpose(0, 1).flatten(start_dim=1)
So the batch dimension is not returned as the first dimension? Why do we need this transpose before flatten?
I just needed to change to dim=1 and I could remove the transpose.
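For illustration, here is a toy check of the shapes involved (not the model code itself): stacking the per-window embeddings along dim=1 keeps the batch dimension first, so no transpose is needed before flattening.

import torch

batch_size, seq_len, feat_dim = 4, 21, 64
# One embedding per window in the sequence, each of shape (batch_size, feat_dim).
feats = [torch.randn(batch_size, feat_dim) for _ in range(seq_len)]

stacked = torch.stack(feats, dim=1)       # (batch_size, seq_len, feat_dim)
flattened = stacked.flatten(start_dim=1)  # (batch_size, seq_len * feat_dim)
print(flattened.shape)                    # torch.Size([4, 1344])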
braindecode/datasets/base.py
Outdated
@@ -142,6 +144,88 @@ def transform(self, value):
    self._transform = value

class SequenceDataset(WindowsDataset):
@hubertjb @robintibor I still wonder whether the functionality of this class couldn't be added to WindowsDataset. It would need parameters in WindowsDataset.__init__ that specify seq_len=1 and step_len=1 with default values, and the rest of the behavior would stay similar (we would have to keep self.start_inds for WindowsDataset, but with default parameters it should behave the same as before). I would say that the behavior of this class is not so different from the base class. label_transform could also be used in the standard case without multiple windows.
In any case, maybe a better name would be SequenceWindowsDataset?
Yes I agree that SequenceWindowsDataset is maybe more explicit
braindecode/datasets/base.py
Outdated
    if 'ignore' it will proceed silently.
    """
    def __init__(self, windows, description=None, transform=None, seq_len=21,
                 step_len=1, label_transform=None, on_missing='raise'):
how about target_transform instead of label_transform?
are we consistent about this naming in braindecode?
I think we mostly use "target", so target_transform makes more sense.
It's also how they name it in torchvision ;)
What was the problem when trying to do this via pytorch loader/sampler instead of a new dataset?
braindecode/datasets/base.py
Outdated
import pandas as pd
from mne.utils.check import _on_missing
I would be careful using private mne functions in downstream packages
braindecode/datasets/base.py
Outdated
Parameters
----------
windows : mne.Epochs
would it make sense to rather construct it from a WindowsDataset?
Good point, yes I think this might be better. If we decide to stick with this approach (see #263 (comment)) then I can do the change.
braindecode/datasets/base.py
Outdated
    Holds additional info about the windows.
transform : callable | None
    On-the-fly transform applied to a window before it is returned.
seq_len : int
n_windows?
braindecode/datasets/base.py
Outdated
if seq_len > len(windows):
    msg = ('Sequence length is larger than number of windows '
           f'({seq_len} > {len(windows)}).')
    _on_missing(on_missing, msg)
do you foresee a case where a user would not use on_missing='raise' here?
I was thinking if you're working with a large uncurated dataset you might want to just ignore files that are not long enough for the sequence length you're interested in. But I guess right now this would break anyway when calling getitem. Maybe I should just remove the on_missing logic?
agreed, in the __getitem__ way it would crash anyway, hence on_missing could not be used...
braindecode/datasets/base.py
Outdated
Returns
-------
np.ndarray :
    Sequence of windows, of shape (seq_len, n_chs, *).
n_chs -> n_channels
######################################################################
# We extract 30-s windows to be used in the classification task.

from braindecode.preprocessing.windowers import create_windows_from_events
Suggested change:
- from braindecode.preprocessing.windowers import create_windows_from_events
+ from braindecode.preprocessing import create_windows_from_events
    input_size_s=input_size_samples / sfreq
)

model = TimeDistributedNet(
would it make sense to move the TimeDistributedNet into the main library? @robintibor @sliwy @gemeinl thoughts?
I think so. It would make sense to include it in the same script as SleepStagerChambon2018, for instance.
yes I agree it would definitely make sense, but it would need more documentation - for example, it seems that atm len_last_layer needs to exist in the supplied feature extractor. I think to keep overhead low, it is fine to add models even if not fully polished (like here; one could also try to determine, given an input size, how long the output is, etc.) and polish them if more people use them/report problems. But of course these intricacies like len_last_layer should then be in the docstring at least.
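For context, the overall pattern is compact: apply the feature extractor to each window of the sequence, then classify the concatenated embeddings. Below is a rough sketch of such a module, assuming an embed method and a len_last_layer attribute on the feature extractor as discussed above; the example's actual TimeDistributedNet may differ.

import torch
from torch import nn

class TimeDistributedNet(nn.Module):
    """Sketch: apply ``feat_extractor`` to every window of a sequence, then
    classify the flattened sequence of embeddings."""
    def __init__(self, feat_extractor, len_last_layer, n_windows, n_classes,
                 dropout=0.5):
        super().__init__()
        self.feat_extractor = feat_extractor
        self.clf = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(len_last_layer * n_windows, n_classes),
        )

    def forward(self, x):
        # x has shape (batch_size, n_windows, n_channels, n_times).
        feats = [self.feat_extractor.embed(x[:, i]) for i in range(x.shape[1])]
        feats = torch.stack(feats, dim=1).flatten(start_dim=1)
        return self.clf(feats)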
Thanks @sliwy @robintibor @agramfort for the comments! Before going into the details, here's a summary of our options for feeding sequences of windows to a model:
I think 1-3 can all work, but 4 feels like a reach (the relevant information for creating sequences is lost by the time we get to collate_fn, and it might interfere with batch-level data augmentation in #254). What I like about 1 and 2 is that they are conceptually simple. I think 2 is nicer because the reading of epochs might be more efficient, as @sliwy suggested. Also, this would avoid creating a new Dataset class. What I like about 3 is that more complicated sampling logic can be more easily implemented, e.g. first sampling a recording, then a specific label, and then the sequence around it (like in the USleep paper). Also, this removes any need to introduce a new Dataset class or to add logic to WindowsDataset. The more I think about it, the more I like this approach. So, after some more thought, my ranking is now: … I'm curious to hear what you think!
Yes I agree that Sampler makes the most sense from a PyTorch-logic kind of way. Since I have not implemented custom samplers myself, I don't know if there is something that would not be possible in this way... but if you also like this approach then definitely have a try and see if it works for your cases.
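For concreteness, the sampler-based option can stay quite small: a Sampler that yields lists of consecutive window indices, which the dataset then turns into a stacked sequence. The following is a minimal, purely illustrative sketch; the SequenceSampler eventually added in this PR may differ, and a real version would also have to respect recording boundaries so sequences do not cross recordings.

import numpy as np
from torch.utils.data import Sampler

class SequenceSampler(Sampler):
    """Sketch: yield lists of ``n_windows`` consecutive window indices,
    moving by ``n_windows_stride`` windows between sequences."""
    def __init__(self, n_total_windows, n_windows, n_windows_stride):
        self.starts = np.arange(
            0, n_total_windows - n_windows + 1, n_windows_stride)
        self.n_windows = n_windows

    def __iter__(self):
        for start in self.starts:
            yield list(range(start, start + self.n_windows))

    def __len__(self):
        return len(self.starts)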
Codecov Report
Coverage diff: master vs. #263
The new version uses a Sampler instead of a dedicated SequenceDataset. I think this might be more flexible in the long run. Let me know what you guys think. I've left the … Left to do:
What kind of concatenation is needed here? I guess it is different from batch concatenation. Would it be possible to implement this through …?
Also, you can create subsets of windows with …
I am a bit confused now, so with the sampler-based approach you could remove SequenceWindowsDataset? Like currently in the example you are using both? Would you also have to write a loader in order to be able to transform targets? In that case, would you still feel better with Loader/Sampler or do you feel SequenceWindowsDataset would then be simpler?
We need to stack multiple windows along a new axis. Since we still want to use minibatches (of sequences), this means the models should expect tensors of shape (batch_size, n_windows, n_channels, n_times) or something like that. I don't think …
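To make those shapes concrete, here is a toy illustration of how one sequence and then a minibatch of sequences would be assembled (the sizes are assumptions, not values from the PR):

import numpy as np
import torch

n_channels, n_times, n_windows, batch_size = 2, 3000, 3, 4

# One sequence: stack consecutive windows along a new first axis.
windows = [np.zeros((n_channels, n_times), dtype='float32')
           for _ in range(n_windows)]
seq = np.stack(windows, axis=0)  # (n_windows, n_channels, n_times)

# A minibatch of sequences: stacking (as the default collate does)
# adds the batch dimension in front.
batch = torch.stack([torch.from_numpy(seq) for _ in range(batch_size)])
print(batch.shape)  # torch.Size([4, 3, 2, 3000])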
I think this would be equivalent to creating a new WindowsDataset that can only return a subset of the original windows. What we need is to return sequences of windows, so I don't think this would work.
Sorry for the confusion @robintibor, I should have called the dataset in the example something else! The reason I had to create a SequenceWindowsDataset in the example (which is not the same as the previous SequenceWindowsDataset!) is that the … I think the target transform should also be implemented in WindowsDataset, as this might be a more generally applicable feature anyway. So in the end, there would be no need for a SequenceWindowsDataset. Sampling many consecutive windows and target transformation would both be handled by WindowsDataset.
As discussed this sounds very reasonable to me
This latest version gets rid of the first … In the end I decided against implementing the logic for grabbing multiple windows in WindowsDataset, and instead implemented it in BaseConcatDataset. The reason is that the pytorch … Second conceptual issue I encountered was about …
braindecode/datasets/base.py
Outdated
"""Get a window and its target, or a list of windows and targets. | ||
|
||
Parameters | ||
---------- | ||
index : int | ||
Index to the window (and target) to return. | ||
|
||
Returns | ||
------- | ||
np.ndarray | ||
Window of shape (n_channels, n_times). | ||
int | np.ndarray | ||
Target(s) for the windows. | ||
np.ndarray | ||
Crop indices. | ||
""" |
hm but now this only works for BaseConcatDataset, not for BaseDataset, right? Like there is no code change in the function?
My bad, this is an artifact of a previous version. I'll clean up the docstring.
def _get_sequence(self, indices):
    X, y = list(), list()
    for ind in indices:
        out_i = super().__getitem__(ind)
        X.append(out_i[0])
        y.append(out_i[1])

    X = np.stack(X, axis=0)
    y = np.array(y)
    if self.seq_target_transform is not None:
        y = self.seq_target_transform(y)

    return X, y

def __getitem__(self, idx):
    """
    Parameters
    ----------
    idx : int | list
        Index of window and target to return. If provided as a list of
        ints, multiple windows and targets will be extracted and
        concatenated.
    """
    if isinstance(idx, Iterable):  # Sample multiple windows
        return self._get_sequence(idx)
    else:
        return super().__getitem__(idx)
This looks fairly clean?
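As a usage sketch of the indexing behaviour above (assuming concat_ds is a BaseConcatDataset of WindowsDatasets; names are illustrative):

def demo_sequence_indexing(concat_ds):
    """Sketch of the two indexing modes discussed above."""
    # A single int returns one window, its target and the crop indices.
    X, y, crop_inds = concat_ds[0]

    # A list of ints returns a stacked sequence of windows and targets,
    # with seq_target_transform applied to the targets if it is set.
    X_seq, y_seq = concat_ds[[0, 1, 2, 3, 4]]
    return X_seq.shape  # (5, n_channels, n_times)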
This I don't understand. Can't we just have a single target_transform? I mean the user should know whether they are getting sequences or elements out of their pipeline? hmhm or did you feel it should work in any case? At least then I think we should implement both already and not only the …
…taset` to return sequences of windows - adding modified sleep staging examples that work on sequences of EEG windows - adding a method `embed` to `SleepStagerChambon2018` class to reuse it as a feature extractor in the sleep staging on sequences examples - adding tests
… test - updating sleep_staging_sequences example to use sampler instead of SequenceDataset
- adding seq_target_transform property to BaseConcatDataset to transform the targets when extracting sequences of SequenceWindowsDataset - adding functionality in BaseConcatDataset to index with a list of indices, such that a sequence of windows and targets can be returned - fixing SleepPhysionet bug where cropping would cause an error with some recordings - changing hyperparameters of sleep staging example to improve performance in a few epochs - adding more representative splitting functionality in sleep staging example - using class weights for imbalanced classes in sleep staging example
- fixing formatting in sleep staging example - adding SequenceSampler to docs
008ae93 to e8d126e
indices = range(100)
y = concat_windows_dataset[indices][1]

transform = lambda x: sum(x)  # noqa: E731
a small thing, but why not just transform = sum?
Suggested change:
- transform = lambda x: sum(x)  # noqa: E731
+ transform = sum
Oops I simplified a more complicated expression and didn't realize what it had become :P
- adding whatsnew
Amazing work! Example looks great to me as well! Merged!
This PR adds a SequenceDataset object that returns sequences of consecutive windows instead of single windows. This is helpful e.g. in sleep staging tasks where we might want to do sequence-to-label or even sequence-to-sequence prediction.

I considered different approaches for returning sequences of windows (including using samplers) but making a separate dataset class made the most sense to me. Other pytorch projects working on videos (i.e. sequences of images) have used a similar approach too, e.g. https://github.com/RaivoKoot/Video-Dataset-Loading-Pytorch/blob/main/video_dataset.py. This means though that some additional work would be required to ensure BaseConcatDataset functionalities like get_metadata() and serialization work as intended.

I created an updated version of the sleep staging example to show how a model can be trained on sequences of windows, as that was the case for the example's reference paper (Chambon et al. 2018). For simplicity though I've started with end-to-end training of the model - in the paper the feature extractor is first pretrained, then frozen, before training a classifier, which adds quite a bit of logic.
Remaining tasks: