
[MRG] Adding SequenceDataset class to train on sequences of windows #263

Merged 10 commits into braindecode:master on Jun 25, 2021

Conversation

hubertjb (Collaborator)

This PR adds a SequenceDataset object that returns sequences of consecutive windows instead of single windows. This is helpful e.g. in sleep staging tasks where we might want to do sequence-to-label or even sequence-to-sequence prediction.

I considered different approaches for returning sequences of windows (including using samplers), but making a separate dataset class made the most sense to me. Other PyTorch projects working on videos (i.e. sequences of images) have used a similar approach, e.g. https://github.com/RaivoKoot/Video-Dataset-Loading-Pytorch/blob/main/video_dataset.py. This means, though, that some additional work is required to ensure BaseConcatDataset functionalities like get_metadata() and serialization keep working as intended.

I created an updated version of the sleep staging example to show how a model can be trained on sequences of windows, as that was the setting of the example's reference paper (Chambon et al. 2018). For simplicity, though, I've started with end-to-end training of the model; in the paper, the feature extractor is first pretrained and frozen before a classifier is trained, which adds quite a bit of logic.
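To make the idea concrete, here is a minimal sketch of what such a dataset could look like. This is not the actual SequenceDataset added in this PR; the class name, defaults and the metadata 'target' column are assumptions for illustration only.

from torch.utils.data import Dataset


class ToySequenceDataset(Dataset):
    """Return sequences of consecutive windows from an mne.Epochs object (illustrative only)."""

    def __init__(self, windows, seq_len=3, step_len=1):
        self.windows = windows  # mne.Epochs, assumed to carry a 'target' column in its metadata
        self.seq_len = seq_len
        self.step_len = step_len
        # Start index of each sequence of consecutive windows
        self.start_inds = list(range(0, len(windows) - seq_len + 1, step_len))

    def __len__(self):
        return len(self.start_inds)

    def __getitem__(self, i):
        start = self.start_inds[i]
        inds = list(range(start, start + self.seq_len))
        # Fetch all windows of the sequence in one call: (seq_len, n_channels, n_times)
        X = self.windows.get_data(item=inds).astype('float32')
        y = self.windows.metadata['target'].iloc[inds].to_numpy()
        return X, y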

Remaining tasks:

  • Adapt get_metadata() function for a BaseConcatDataset of WindowsDatasets?
  • Update serialization function to be able to reload SequenceDatasets too
  • Pick hyperparameters so that it runs fast
  • Documentation etc.

@hubertjb (Collaborator Author)

Also, I was not sure whether we would want to keep both sleep staging tutorials - maybe just this one would be enough? Although I like that the old one is fairly straightforward.

On-the-fly transform applied to a window before it is returned.
seq_len : int
Number of consecutive windows in a sequence.
step_len : int
@sliwy (Collaborator) commented Jun 22, 2021:

Does the step_len parameter correspond to stride in the create_windows functions? Maybe stride would be a better name then.

Contributor:

Maybe we could name them n_windows_seq and n_windows_stride?

Collaborator:

I would vote for n_windows and n_windows_stride

Contributor:

Also fine for me

seq, ys = list(), list()
# The following cannot be put into a list comprehension
# because of scope requirements for `super()`
for i in range(start_ind, start_ind + self.seq_len):
Collaborator:

Is this fast enough for our application? Can't we use __getitem__ to obtain multiple windows?
I see that this line does not allow it (X = self.windows.get_data(item=index)[0].astype('float32')), but maybe we can modify WindowsDataset to make it possible? I think you could add a method that returns multiple windows and use it both here and in WindowsDataset.__getitem__, where only the first epoch would be selected.
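For reference, mne.Epochs.get_data accepts a list of indices through its item argument, so a whole sequence could be fetched in one call along these lines (a sketch using the PR's attribute names, not the current implementation):

indices = list(range(start_ind, start_ind + self.seq_len))
# One call instead of seq_len calls to the single-window __getitem__;
# the result has shape (seq_len, n_channels, n_times).
X = self.windows.get_data(item=indices).astype('float32')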

Contributor:

Indeed, some performance measurements may be helpful. In case it is fast enough, it does at least seem nicely decoupled this way.

@@ -291,3 +375,27 @@ def save(self, path, overwrite=False):
if concat_of_raws:
json.dump({'target_name': target_name}, open(target_file_name, 'w'))
self.description.to_json(description_file_name)


def get_sequence_dataset(concat_ds, seq_len, step_len=1, label_transform=None):
Collaborator:

Maybe, to be more consistent in naming, it would be better to use create_sequence_dataset?

Collaborator:

To me it looks like this function is quite similar to the create_windows functions: it takes a concat_ds and returns a BaseConcatDataset. Maybe we should move it to preprocessing/windowers.py? I feel like datasets/base.py is not the best place for this function, though preprocessing/windowers.py may not be perfect either.

Contributor:

I definitely also feel create_sequence_dataset would be nicer.

Collaborator:

+1

n_channels, n_times).
"""
feats = [self.feat_extractor.embed(x[:, i]) for i in range(x.shape[1])]
feats = torch.stack(feats, dim=0).transpose(0, 1).flatten(start_dim=1)
Collaborator:

So the batch dimension is not returned as the first dimension? Why do we need this transpose before flatten?

Collaborator Author:

I just needed to change the stack to dim=1 and I could remove the transpose.
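Concretely, stacking along dim=1 keeps the batch dimension first and gives the same result as the original stack + transpose (a small check with made-up sizes):

import torch

batch_size, seq_len, feat_dim = 4, 3, 8
feats = [torch.randn(batch_size, feat_dim) for _ in range(seq_len)]

# (batch_size, seq_len, feat_dim) -> (batch_size, seq_len * feat_dim)
flat = torch.stack(feats, dim=1).flatten(start_dim=1)
# Same values as the original stack(dim=0) + transpose(0, 1)
alt = torch.stack(feats, dim=0).transpose(0, 1).flatten(start_dim=1)
assert torch.equal(flat, alt)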

@@ -142,6 +144,88 @@ def transform(self, value):
self._transform = value


class SequenceDataset(WindowsDataset):
Collaborator:

@hubertjb @robintibor I still wonder whether the functionalities of this class couldn't be added to WindowsDataset. It would need parameters in WindowsDataset.__init__ specifying seq_len=1 and step_len=1 as defaults, and the rest of the behavior would stay similar (we would have to keep self.start_inds in WindowsDataset, but with the default parameters it should behave the same as before). I would say the behavior of this class is not so different from the base class. Also, label_transform may be useful in the standard single-window case.

Collaborator:

In any case maybe a better name would be SequenceWindowsDataset?

Collaborator:

Yes I agree that SequenceWindowsDataset is maybe more explicit

if 'ignore' it will proceed silently.
"""
def __init__(self, windows, description=None, transform=None, seq_len=21,
step_len=1, label_transform=None, on_missing='raise'):
Collaborator:

how about target_transform instead of label_transform?

Collaborator:

are we consistent about this naming in braindecode?

Collaborator Author:

I think we mostly use "target", so target_transform makes more sense.

Collaborator:

It's also how they name it in torchvision ;)

@robintibor (Contributor)

What was the problem when trying to do this via pytorch loader/sampler instead of a new dataset?

import pandas as pd
from mne.utils.check import _on_missing
Collaborator:

I would be careful using private mne functions in downstream packages
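If avoiding the private helper is preferred, a small local check along these lines would cover the same cases (a sketch, not what the PR currently does; the function name is made up):

import warnings


def _check_seq_len(seq_len, n_windows, on_missing='raise'):
    # Local stand-in for mne's private _on_missing helper.
    if seq_len <= n_windows:
        return
    msg = (f'Sequence length is larger than number of windows '
           f'({seq_len} > {n_windows}).')
    if on_missing == 'raise':
        raise ValueError(msg)
    elif on_missing == 'warn':
        warnings.warn(msg)
    # 'ignore': proceed silently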


Parameters
----------
windows : mne.Epochs
Collaborator:

would it make sense to rather construct it from a WindowsDataset ?

Collaborator Author:

Good point, yes I think this might be better. If we decide to stick with this approach (see #263 (comment)) then I can do the change.

Holds additional info about the windows.
transform : callable | None
On-the-fly transform applied to a window before it is returned.
seq_len : int
Collaborator:

n_windows?

if seq_len > len(windows):
    msg = ('Sequence length is larger than number of windows '
           f'({seq_len} > {len(windows)}).')
    _on_missing(on_missing, msg)
Collaborator:

Do you foresee a case where a user would not use on_missing='raise' here?

Collaborator Author:

I was thinking that if you're working with a large uncurated dataset, you might want to just ignore files that are not long enough for the sequence length you're interested in. But I guess right now this would break anyway when calling __getitem__. Maybe I should just remove the on_missing logic?

Contributor:

Agreed, with the __getitem__ approach it would crash anyway, hence on_missing could not really be used...

Returns
-------
np.ndarray :
Sequence of windows, of shape (seq_len, n_chs, *).
Collaborator:

n_chs -> n_channels


######################################################################
# We extract 30-s windows to be used in the classification task.

from braindecode.preprocessing.windowers import create_windows_from_events
Collaborator:

Suggested change
from braindecode.preprocessing.windowers import create_windows_from_events
from braindecode.preprocessing import create_windows_from_events

input_size_s=input_size_samples / sfreq
)

model = TimeDistributedNet(
Collaborator:

Would it make sense to move TimeDistributedNet into the main library?

@robintibor @sliwy @gemeinl thoughts?

Collaborator Author:

I think so. It would make sense to include it in the same script as SleepStagerChambon2018, for instance.

@robintibor (Contributor) commented Jun 22, 2021:

Yes, I agree it would definitely make sense. It would need more documentation, though; for example, it seems that at the moment len_last_layer needs to exist in the supplied feature extractor. I think that to keep overhead low it is fine to add models even when they are not fully polished (like here; one could also try to determine the output length for a given input size, etc.) and polish them if more people use them or report problems. But of course these intricacies like len_last_layer should then at least be mentioned in the docstring.
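For context, a rough sketch of what such a time-distributed wrapper could look like, assuming (as in the example) a feature extractor that exposes an embed() method and a len_last_layer attribute; the class and argument names here are illustrative, not the example's actual code:

import torch
from torch import nn


class ToyTimeDistributedNet(nn.Module):
    """Apply a per-window feature extractor, then classify the whole sequence."""

    def __init__(self, feat_extractor, n_windows, n_classes):
        super().__init__()
        self.feat_extractor = feat_extractor
        # Assumes the feature extractor documents its embedding size
        self.clf = nn.Linear(feat_extractor.len_last_layer * n_windows, n_classes)

    def forward(self, x):
        # x: (batch_size, n_windows, n_channels, n_times)
        feats = [self.feat_extractor.embed(x[:, i]) for i in range(x.shape[1])]
        feats = torch.stack(feats, dim=1).flatten(start_dim=1)
        return self.clf(feats)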

@hubertjb (Collaborator Author)

Thanks @sliwy @robintibor @agramfort for the comments!

Before going into the details, here's a summary of our options for feeding sequences of windows to a model:

  1. Use a dedicated sequence dataset (current implementation in this PR)
  2. Adapt the WindowsDataset so that it can return sequences (suggested by @sliwy)
  3. Use a Sampler
  4. Use a dedicated collate_fn in the DataLoader

I think 1-3 can all work, but 4 feels like a reach (the relevant information for creating sequences is lost by the time we get to collate_fn, and it might interfere with batch-level data augmentation in #254).

What I like about 1 and 2 is that they are conceptually simple. I think 2 is nicer because the reading of epochs might be more efficient as @sliwy suggested. Also this would avoid creating a new Dataset class.

What I like about 3 is that more complicated sampling logic can be more easily implemented, e.g. first sampling a recording, then a specific label, and then the sequence around it (like in the USleep paper). Also this removes any need to introduce a new Dataset class or to add logic to WindowsDataset. The more I think about it, the more I like this approach.

So, after some more thought, my ranking is now:
(1) Use a Sampler
(2) Adapt the WindowsDataset
(3) Dedicated SequenceWindowsDataset

I'm curious to hear what you think!
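To make the sampler option concrete, here is a minimal sketch of a sampler that yields lists of consecutive window indices, grouped by recording. The class name and the 'rec_id' column are hypothetical; the SequenceSampler later added in this PR differs in its details.

from torch.utils.data import Sampler


class ToySequenceSampler(Sampler):
    """Yield lists of consecutive window indices, one list per sequence (illustrative only)."""

    def __init__(self, metadata, n_windows, n_windows_stride=1):
        # `metadata` is a pandas DataFrame with one row per window and a column
        # (hypothetically called 'rec_id') identifying the recording each window
        # belongs to, so that sequences never span two recordings.
        self.sequences = []
        for _, rec in metadata.groupby('rec_id'):
            inds = rec.index.to_numpy()
            for start in range(0, len(inds) - n_windows + 1, n_windows_stride):
                self.sequences.append(inds[start:start + n_windows].tolist())

    def __len__(self):
        return len(self.sequences)

    def __iter__(self):
        # Each yielded item is a list of indices, i.e. one sequence.
        yield from self.sequences

Passed as the sampler argument of a DataLoader, each yielded list becomes one sample, so the dataset's __getitem__ has to accept a list of indices (which is what the later commits in this PR add).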

@robintibor (Contributor)

Yes I agree that Sampler makes most sense from a PyTorch-logic kind of way. Since I have not implemented custom samplers myself, I don't know if there is something that would not be possible in this way... but if you also like this approach then definitely have a try and see if it works for your cases.

codecov bot commented Jun 23, 2021:

Codecov Report

Merging #263 (7b0991b) into master (8f6f140) will not change coverage.
The diff coverage is n/a.

❗ Current head 7b0991b differs from pull request most recent head 1e4d46f. Consider uploading reports for the commit 1e4d46f to get more accurate results


@hubertjb (Collaborator Author)

The new version uses a Sampler instead of a dedicated SequenceDataset. I think this might be more flexible in the long run. Let me know what you guys think.

I've left the SequenceWindowsDataset stuff in for now, I'll remove it if you agree with this Sampler-based approach.

Left to do:

  • Add ability to sample multiple windows and concatenate them inside WindowsDataset - this way we wouldn't need to create a dedicated BaseConcatDataset as in the updated example.
  • Also, add target_transform to WindowsDataset?
  • Make sure the model learns and pick hyperparameters
  • Resolve conflicts
  • Documentation

@gemeinl (Collaborator) commented Jun 23, 2021:

* Add ability to sample multiple windows and concatenate them inside WindowsDataset - this way we wouldn't need to create a dedicated BaseConcatDataset as in the updated example.

What kind of concatenation is needed here? I guess, it is different from batch concatenation. Would it be possible to implement this through collate_fn with the new SequenceSampler?

@gemeinl (Collaborator) commented Jun 23, 2021:

Also, you can create subsets of windows with torch.utils.data.Subset by giving it the window ids that you receive from the SequenceSampler. Is this what you are looking for?

@robintibor (Contributor)

I am a bit confused now, so with the sampler-based approach you could remove SequenceWindowsDataset? Like currently in the example you are using both? Would you also have to write a loader in order to be able to transform targets? In that case, would you still feel better with Loader/Sampler or do you feel SequenceWindowsDataset would then be simpler?

@hubertjb (Collaborator Author)

What kind of concatenation is needed here? I guess, it is different from batch concatenation. Would it be possible to implement this through collate_fn with the new SequenceSampler?

We need to stack multiple windows along a new axis. Since we still want to use minibatches (of sequences), this means the models should expect tensors of shape (batch_size, n_windows, n_channels, n_times) or something like that.

I don't think collate_fn would be a good solution here. From what I understand collate_fn would receive a list of windows, and then would have to know which ones to aggregate together to make sequences. However I think the relevant information (which windows are consecutive) would be lost at that point. Also, I'm not sure this would be compatible with the batch-level data augmentation that is being worked on.
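For illustration, the shapes involved when sequences are stacked into a minibatch (sizes are made up):

import torch

n_windows, n_channels, n_times = 3, 2, 3000
# Each dataset item is one sequence of consecutive windows
seq = torch.randn(n_windows, n_channels, n_times)
# The DataLoader then stacks sequences into a minibatch
batch = torch.stack([seq, seq, seq, seq], dim=0)
print(batch.shape)  # torch.Size([4, 3, 2, 3000]) -> (batch_size, n_windows, n_channels, n_times)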

@hubertjb (Collaborator Author)

Also, you can create subsets of windows with pytorch.utils.data.Subset by giving it the windows ids that you will receive from the SequenceSampler. Is this what you are looking for?

I think this would be equivalent to creating a new WindowsDataset that can only return a subset of the original windows. What we need is to return sequences of windows, so I don't think this would work.

@hubertjb (Collaborator Author)

I am a bit confused now, so with the sampler-based approach you could remove SequenceWindowsDataset? Like currently in the example you are using both? Would you also have to write a loader in order to be able to transform targets? In that case, would you still feel better with Loader/Sampler or do you feel SequenceWindowsDataset would then be simpler?

Sorry for the confusion @robintibor, I should have called the dataset in the example something else! The reason I had to create a SequenceWindowsDataset in the example (which is not the same as the previous SequenceWindowsDataset!) is that the SequenceSampler returns a list of indices instead of a single index - since WindowsDataset expects a single index, I had to write a new __getitem__ that can handle a list of indices. To show the concept quickly I did that by creating a SequenceWindowsDataset, however down the line I think it would make sense to move this logic to the __getitem__ of WindowsDataset.

I think the target transform should also be implemented in WindowsDataset, as this might be a more generally applicable feature anyway.

So in the end, there would be no need for a SequenceWindowsDataset. Sampling many consecutive windows and target transformation would both be handled by WindowsDataset.

@robintibor (Contributor) commented Jun 23, 2021:

The reason I had to create a SequenceWindowsDataset in the example (which is not the same as the previous SequenceWindowsDataset!) is that the SequenceSampler returns a list of indices instead of a single index - since WindowsDataset expects a single index, I had to write a new __getitem__ that can handle a list of indices. To show the concept quickly I did that by creating a SequenceWindowsDataset, however down the line I think it would make sense to move this logic to the __getitem__ of WindowsDataset.

I think the target transform should also be implemented in WindowsDataset, as this might be a more generally applicable feature anyway.

So in the end, there would be no need for a SequenceWindowsDataset. Sampling many consecutive windows and target transformation would both be handled by WindowsDataset.

As discussed this sounds very reasonable to me

@hubertjb (Collaborator Author)

This latest version gets rid of the first SequenceWindowsDataset implementation in favour of the SequenceSampler approach and refines the sleep staging on sequences example.

In the end I decided against implementing the logic for grabbing multiple windows in WindowsDataset, and instead implemented it in BaseConcatDataset. The reason is that the pytorch ConcatDataset only accepts integers as input to __getitem__, therefore passing a list of indices doesn't work. One option would have been to overload __getitem__ in BaseConcatDataset to accept sequences. The two issues I had with that solution are: (1) __getitem__ seems like a pretty critical method of ConcatDataset, and I wasn't sure overloading it was a good idea from a maintainability perspective, and (2) if the grabbing of multiple windows is implemented in WindowsDataset, it would not be possible to sample sequences of windows across recordings (which is useful in self-supervised learning tasks, e.g. CPC).

The second conceptual issue I encountered was about target_transform. There are actually two levels of transforms when working with sequences: transforms for individual windows/targets, and transforms for sequences of windows/targets. Here we need the latter. I added a property seq_target_transform to BaseConcatDataset to take care of that. If we agree on this pattern, we might want to also add a seq_transform for transforming sequences of windows at some point.
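As a usage sketch of what is described above (assuming a concatenated windows dataset like the concat_windows_dataset used in the tests, and that seq_target_transform has a setter like transform does):

# Indexing with a list of indices returns a sequence of windows and targets
X, y = concat_windows_dataset[[0, 1, 2]]
# X has shape (3, n_channels, n_times), y has shape (3,)

# A sequence-level target transform, e.g. keeping only the middle target
concat_windows_dataset.seq_target_transform = lambda targets: targets[len(targets) // 2]
_, y_mid = concat_windows_dataset[[0, 1, 2]]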

Comment on lines 124 to 163
"""Get a window and its target, or a list of windows and targets.

Parameters
----------
index : int
Index to the window (and target) to return.

Returns
-------
np.ndarray
Window of shape (n_channels, n_times).
int | np.ndarray
Target(s) for the windows.
np.ndarray
Crop indices.
"""
Contributor:

Hm, but now this only works for BaseConcatDataset, not for BaseDataset, right? Like there is no code change in the function?

Collaborator Author:

My bad, this is an artifact of a previous version. I'll clean up the docstring.

Comment on lines +187 to +260
def _get_sequence(self, indices):
    X, y = list(), list()
    for ind in indices:
        out_i = super().__getitem__(ind)
        X.append(out_i[0])
        y.append(out_i[1])

    X = np.stack(X, axis=0)
    y = np.array(y)
    if self.seq_target_transform is not None:
        y = self.seq_target_transform(y)

    return X, y

def __getitem__(self, idx):
    """
    Parameters
    ----------
    idx : int | list
        Index of window and target to return. If provided as a list of
        ints, multiple windows and targets will be extracted and
        concatenated.
    """
    if isinstance(idx, Iterable):  # Sample multiple windows
        return self._get_sequence(idx)
    else:
        return super().__getitem__(idx)
Contributor:

This looks fairly clean?

@robintibor (Contributor)

Second conceptual issue I encountered was about target_transform. There are actually two levels of transforms when working with sequences: transform for individual windows/targets, and transform for sequences of windows/targets. Here we need the latter. I added a property seq_target_transform to BaseConcatDataset to take care of that. If we agree on this pattern, we might want to also add a seq_transform for transforming sequences of windows at some point

This I don't understand. Can't we just have a single target_transform? I mean, the user should know whether they are getting sequences or single elements out of their pipeline? Or did you feel it should work in either case? At least then I think we should implement both already, and not only the seq_ one.

…taset` to return sequences of windows

- adding modified sleep staging examples that work on sequences of EEG windows
- adding a method `embed` to the `SleepStagerChambon2018` class to reuse it as a feature extractor in the sleep staging on sequences examples
- adding tests
… test

- updating sleep_staging_sequences example to use sampler instead of SequenceDataset
- adding seq_target_transform property to BaseConcatDataset to transform the targets when extracting sequences of SequenceWindowsDataset
- adding functionality in BaseConcatDataset to index with a list of indices, such that a sequence of windows and targets can be returned
- fixing SleepPhysionet bug where cropping would cause an error with some recordings
- changing hyperparameters of sleep staging example to improve performance in a few epochs
- adding more representative splitting functionality in sleep staging example
- using class weights for imbalanced classes in sleep staging example
- fixing formatting in sleep staging example
- adding SequenceSampler to docs
@hubertjb hubertjb changed the title [WIP] Adding SequenceDataset class to train on sequences of windows [MRG] Adding SequenceDataset class to train on sequences of windows Jun 24, 2021
indices = range(100)
y = concat_windows_dataset[indices][1]

transform = lambda x: sum(x) # noqa: E731
@robintibor (Contributor) commented Jun 24, 2021:

a small thing, but why not just:
transform=sum ?

Suggested change
transform = lambda x: sum(x) # noqa: E731
transform = sum

Collaborator Author:

Oops I simplified a more complicated expression and didn't realize what it had become :P

- adding whatsnew
@robintibor robintibor merged commit bbf8924 into braindecode:master Jun 25, 2021
@robintibor (Contributor)

Amazing work! Example looks great to me as well! Merged!


@hubertjb hubertjb deleted the sequence-dataset branch June 28, 2021 17:43
@gemeinl gemeinl linked an issue Jul 14, 2021 that may be closed by this pull request
Successfully merging this pull request may close these issues: Preprocessing / scaling of targets.