Skip to content

Commit

Permalink
Merge pull request #8 from faroit/devset
Browse files Browse the repository at this point in the history
Support Validation dataset
  • Loading branch information
faroit committed Jul 8, 2016
2 parents 9154044 + a0861d2 commit 18c7c93
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 6 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ dsd.run(my_training_function, subsets="Dev")
dsd.run(my_test_function, subsets="Test")
```

If you want to exclude tracks from the training you can specify track ids as the `dsdtools.DB(..., valid_ids=[1, 2]`) object. Those tracks are then not included in `Dev` but are returned for `subsets="Valid"`.


#### Processing single or multiple DSD100 tracks

```python
Expand Down
4 changes: 2 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def __getattr__(cls, name):
copyright = u'2016, Fabian-Robert Stöter'
author = u'Fabian-Robert Stöter'

version = u'0.1.1'
release = u'0.1.1'
version = u'0.1.2'
release = u'0.1.2'

language = None

Expand Down
5 changes: 5 additions & 0 deletions docs/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ training subset and then apply the algorithm on the test data:
dsd.run(my_training_function, subsets="Dev")
dsd.run(my_test_function, subsets="Test")
If you want to exclude tracks from the training you can specify track ids as
``dsdtools.DB(..., valid_ids=[1, 2]`` object. Those tracks are then not
included in ``Dev`` but are returned for ``subsets="Valid"``.


Processing single or multiple DSD100 tracks
'''''''''''''''''''''''''''''''''''''''''''

Expand Down
33 changes: 31 additions & 2 deletions dsdtools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ class DB(object):
evaluation : str, {None, 'bss_eval', 'mir_eval'}
Setup evaluation module and starts matlab if bsseval is enabled
valid_ids : list[int] or int, optional
select single or multiple _dsdtools_ items by ID that will be used
for validation data (ie not included in the `Dev` set)
Attributes
----------
setup_file : str
Expand Down Expand Up @@ -74,7 +78,8 @@ def __init__(
self,
root_dir=None,
setup_file=None,
evaluation=None
evaluation=None,
valid_ids=None,
):
if root_dir is None:
if "DSD_PATH" in os.environ:
Expand All @@ -101,6 +106,12 @@ def __init__(
self.root_dir, "Sources"
)

if valid_ids is not None:
if not isinstance(valid_ids, collections.Sequence):
valid_ids = [valid_ids]

self.valid_ids = valid_ids

self.sources_names = list(self.setup['sources'].keys())
self.targets_names = list(self.setup['targets'].keys())

Expand Down Expand Up @@ -132,13 +143,23 @@ def load_dsd_tracks(self, subsets=None, ids=None):
subsets = [subsets]
else:
subsets = subsets
if all(x in ['Valid', 'Dev'] for x in subsets):
raise ValueError(
"Cannot load Valid and Dev at the same time"
)
else:
subsets = ['Dev', 'Test']

tracks = []
if op.isdir(self.mixtures_dir):
for subset in subsets:
subset_folder = op.join(self.mixtures_dir, subset)

# For validation use Dev set and filter by ids later
if subset == 'Valid':
subset_folder = op.join(self.mixtures_dir, 'Dev')
else:
subset_folder = op.join(self.mixtures_dir, subset)

for _, track_folders, _ in os.walk(subset_folder):
for track_filename in track_folders:

Expand Down Expand Up @@ -194,6 +215,14 @@ def load_dsd_tracks(self, subsets=None, ids=None):
# add track to list of tracks
tracks.append(track)

# Filter tracks by valid_ids
if self.valid_ids is not None:
if subset == 'Dev':
tracks = [t for t in tracks
if t.id not in self.valid_ids]
if subset == 'Valid':
tracks = [t for t in tracks if t.id in self.valid_ids]

if ids is not None:
return [t for t in tracks if t.id in ids]
else:
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
universal = 1

[pytest]
norecursedirs = .env\* .cache .git examples
norecursedirs = .env\* .cache .git examples docs
addopts = --doctest-modules --cov-report term-missing --cov dsdtools --pep8

[metadata]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
name='dsdtools',

# Version
version="0.1.1",
version="0.1.2",

# Description
description='Python tools for the Demixing Secrets Dataset (DSD)',
Expand Down
17 changes: 17 additions & 0 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,23 @@ def test_file_loading():
assert len(tracks) == 1


def test_file_loading_valid():
# initiate dsdtools

dsd = dsdtools.DB(root_dir="data/DSD100subset", valid_ids=55)
tracks = dsd.load_dsd_tracks(subsets='Dev')
# from two tracks there is only one track left (id=81)
assert len(tracks) == 1
assert tracks[0].id == 81

tracks = dsd.load_dsd_tracks(subsets='Valid')
assert len(tracks) == 1
assert tracks[0].id == 55

with pytest.raises(ValueError):
tracks = dsd.load_dsd_tracks(subsets=['Dev', 'Valid'])


@pytest.fixture(params=['data/DSD100subset'])
def dsd(request):
return dsdtools.DB(root_dir=request.param)
Expand Down

0 comments on commit 18c7c93

Please sign in to comment.