Skip to content

Commit

Permalink
Merge pull request #142 from gemeinl/concat_concat_ds
Browse files Browse the repository at this point in the history
Concat Concat of Datasets
  • Loading branch information
robintibor committed Aug 4, 2020
2 parents 82e19d6 + aff7e46 commit da78213
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 10 deletions.
21 changes: 11 additions & 10 deletions braindecode/datasets/base.py
Expand Up @@ -99,40 +99,41 @@ class BaseConcatDataset(ConcatDataset):
"""A base class for concatenated datasets. Holds either mne.Raw or
mne.Epoch in self.datasets and has a pandas DataFrame with additional
description.
Parameters
----------
list_of_ds: list
    list of BaseDataset, BaseConcatDataset or WindowsDataset to be
    concatenated.
"""
def __init__(self, list_of_ds):
    """Concatenate the given datasets.

    Parameters
    ----------
    list_of_ds: list
        list of BaseDataset, BaseConcatDataset or WindowsDataset to be
        concatenated. Any BaseConcatDataset entry is flattened so that
        ``self.datasets`` only ever holds leaf datasets.
    """
    # Flatten per element: checking only ``list_of_ds[0]`` (as before)
    # silently mis-handles lists that mix BaseConcatDataset with plain
    # datasets — either nothing is flattened, or every element is assumed
    # to have a ``.datasets`` attribute.
    flat_list_of_ds = []
    for ds in list_of_ds:
        if isinstance(ds, BaseConcatDataset):
            flat_list_of_ds.extend(ds.datasets)
        else:
            flat_list_of_ds.append(ds)
    super().__init__(flat_list_of_ds)
    # One description row per leaf dataset, with a fresh 0..n-1 index.
    self.description = pd.DataFrame(
        [ds.description for ds in flat_list_of_ds])
    self.description.reset_index(inplace=True, drop=True)

def split(self, some_property=None, split_ids=None):
def split(self, property=None, split_ids=None):
"""Split the dataset based on some property listed in its description
DataFrame or based on indices.
Parameters
----------
some_property: str
property: str
some property which is listed in info DataFrame
split_ids: list(int)
list of indices to be combined in a subset
Returns
-------
splits: dict{split_name: BaseConcatDataset}
mapping of split name based on property or index based on split_ids
to subset of the data
"""
if split_ids is None and some_property is None:
if split_ids is None and property is None:
raise ValueError('Splitting requires defining ids or a property.')
if split_ids is None:
if some_property not in self.description:
raise ValueError(f'{some_property} not found in self.description')
if property not in self.description:
raise ValueError(f'{property} not found in self.description')
split_ids = {k: list(v) for k, v in self.description.groupby(
some_property).groups.items()}
property).groups.items()}
else:
split_ids = {split_i: split
for split_i, split in enumerate(split_ids)}
Expand Down
21 changes: 21 additions & 0 deletions test/unit_tests/datasets/test_dataset.py
Expand Up @@ -105,3 +105,24 @@ def test_split_concat_dataset(concat_ds_targets):
assert isinstance(v, BaseConcatDataset)

assert len(concat_ds) == sum([len(v) for v in splits.values()])


def test_concat_concat_dataset(concat_ds_targets):
    """Concatenating BaseConcatDatasets must flatten into a single dataset
    whose length, cumulative sizes and description match the parts."""
    concat_ds, _ = concat_ds_targets
    parts = [
        BaseConcatDataset(concat_ds.datasets[:2]),
        BaseConcatDataset(concat_ds.datasets[2:]),
    ]
    expected_description = pd.concat(
        [part.description for part in parts]).reset_index(drop=True)
    # Offsets by which each part's cumulative sizes shift in the concat.
    offsets = [0] + [len(part) for part in parts]
    expected_cumsums = []
    for part_i, part in enumerate(parts):
        expected_cumsums.extend(
            np.array(part.cumulative_sizes) + offsets[part_i])
    nested = BaseConcatDataset(parts)
    assert len(nested) == sum(offsets)
    assert len(nested) == nested.cumulative_sizes[-1]
    assert len(nested.datasets) == len(expected_description)
    assert len(nested.description) == len(expected_description)
    np.testing.assert_array_equal(expected_cumsums, nested.cumulative_sizes)
    pd.testing.assert_frame_equal(expected_description, nested.description)

0 comments on commit da78213

Please sign in to comment.