Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Concat Concat of Datasets #142

Merged
merged 1 commit into from Aug 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 11 additions & 10 deletions braindecode/datasets/base.py
Expand Up @@ -99,40 +99,41 @@ class BaseConcatDataset(ConcatDataset):
"""A base class for concatenated datasets. Holds either mne.Raw or
mne.Epoch in self.datasets and has a pandas DataFrame with additional
description.

Parameters
----------
list_of_ds: list
list of BaseDataset of WindowsDataset to be concatenated.
list of BaseDataset, BaseConcatDataset or WindowsDataset
"""
def __init__(self, list_of_ds):
    """Concatenate datasets and gather their descriptions in a DataFrame.

    Parameters
    ----------
    list_of_ds: list
        list of BaseDataset, BaseConcatDataset or WindowsDataset. Every
        BaseConcatDataset found in the list is replaced by the individual
        datasets it holds before concatenation.
    """
    # Check every element rather than only the first one, so that lists
    # mixing BaseConcatDataset with plain datasets are flattened as well.
    flat_ds = []
    for ds in list_of_ds:
        if isinstance(ds, BaseConcatDataset):
            flat_ds.extend(ds.datasets)
        else:
            flat_ds.append(ds)
    super().__init__(flat_ds)
    # Build the description table from each dataset's own description and
    # renumber the rows from 0, dropping the original index.
    self.description = pd.DataFrame([ds.description for ds in flat_ds])
    self.description.reset_index(inplace=True, drop=True)
Copy link
Collaborator Author

@gemeinl gemeinl Jul 21, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason not to reset the index of the new description DataFrame, i.e. to keep the original index instead?


def split(self, some_property=None, split_ids=None):
def split(self, property=None, split_ids=None):
"""Split the dataset based on some property listed in its description
DataFrame or based on indices.

Parameters
----------
some_property: str
property: str
some property which is listed in info DataFrame
split_ids: list(int)
list of indices to be combined in a subset

Returns
-------
splits: dict{split_name: BaseConcatDataset}
mapping of split name based on property or index based on split_ids
to subset of the data
"""
if split_ids is None and some_property is None:
if split_ids is None and property is None:
raise ValueError('Splitting requires defining ids or a property.')
if split_ids is None:
if some_property not in self.description:
raise ValueError(f'{some_property} not found in self.description')
if property not in self.description:
raise ValueError(f'{property} not found in self.description')
split_ids = {k: list(v) for k, v in self.description.groupby(
some_property).groups.items()}
property).groups.items()}
else:
split_ids = {split_i: split
for split_i, split in enumerate(split_ids)}
Expand Down
21 changes: 21 additions & 0 deletions test/unit_tests/datasets/test_dataset.py
Expand Up @@ -105,3 +105,24 @@ def test_split_concat_dataset(concat_ds_targets):
assert isinstance(v, BaseConcatDataset)

assert len(concat_ds) == sum([len(v) for v in splits.values()])


def test_concat_concat_dataset(concat_ds_targets):
    """Check that a BaseConcatDataset built from BaseConcatDatasets is flat.

    The fixture dataset is split into two BaseConcatDatasets which are then
    concatenated again. The result is compared against the expected length,
    cumulative sizes and description computed from the two parts.
    """
    concat_ds, _ = concat_ds_targets
    list_of_concat_ds = [
        BaseConcatDataset(concat_ds.datasets[:2]),
        BaseConcatDataset(concat_ds.datasets[2:]),
    ]

    # Expected description: the parts' descriptions stacked, rows renumbered.
    descriptions = pd.concat([ds.description for ds in list_of_concat_ds])
    descriptions.reset_index(inplace=True, drop=True)

    lens = [len(ds) for ds in list_of_concat_ds]
    # offsets[i] is the number of samples in all parts before part i. Using
    # the running total (not just the previous part's length) keeps the
    # expectation valid if more than two parts are ever concatenated here.
    offsets = np.cumsum([0] + lens[:-1])
    cumsums = [
        size
        for offset, ds in zip(offsets, list_of_concat_ds)
        for size in np.array(ds.cumulative_sizes) + offset
    ]

    concat_concat_ds = BaseConcatDataset(list_of_concat_ds)
    assert len(concat_concat_ds) == sum(lens)
    assert len(concat_concat_ds) == concat_concat_ds.cumulative_sizes[-1]
    assert len(concat_concat_ds.datasets) == len(descriptions)
    assert len(concat_concat_ds.description) == len(descriptions)
    np.testing.assert_array_equal(cumsums, concat_concat_ds.cumulative_sizes)
    pd.testing.assert_frame_equal(descriptions, concat_concat_ds.description)