Skip to content

Commit

Permalink
reaching 99% coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
faroit committed Jun 3, 2016
1 parent 35e5445 commit 72e7aac
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 48 deletions.
35 changes: 8 additions & 27 deletions dsdtools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def load_dsd_tracks(self, subsets=None, ids=None):
# add tracks to components
target_sources.append(sources[source])
# add sources to target
if not target_sources:
if target_sources:
targets[name] = Target(sources=target_sources)
# add targets to track
track.targets = targets
Expand Down Expand Up @@ -284,21 +284,23 @@ def test(self, user_function):

return True

def evaluate(self, user_function=None, estimates_dir=None):
def evaluate(
self, user_function=None, estimates_dir=None, *args, **kwargs
):
"""Run the dsdtools evaluation
shortcut to
``run(
user_function=None,
estimates_dir=estimates_dir,
save=False,
evaluate=True
)``
"""
return self.run(
user_function=user_function,
estimates_dir=estimates_dir,
evaluate=True
evaluate=True,
*args, **kwargs
)

def _process_function(self, track, user_function, estimates_dir, evaluate):
Expand Down Expand Up @@ -384,6 +386,7 @@ def run(
# list of tracks to be processed
tracks = self.load_dsd_tracks(subsets=subsets, ids=ids)

success = False
if parallel:
pool = multiprocessing.Pool(cpus, initializer=init_worker)
success = list(
Expand Down Expand Up @@ -421,7 +424,7 @@ def run(
total=len(tracks)
)
)
return success
return success


def process_function_alias(obj, *args, **kwargs):
Expand All @@ -430,25 +433,3 @@ def process_function_alias(obj, *args, **kwargs):

def init_worker():
signal.signal(signal.SIGINT, signal.SIG_IGN)


if __name__ == '__main__':
def my_function(dsd_track):
print(dsd_track.name)
for i in range(1000000):
i * i + i

estimates = {
'vocals': dsd_track.audio,
'accompaniment': dsd_track.audio
}
return estimates

dsd = DB()

# Test my_function
if dsd.test(my_function):
print("success")

# Run my_function and save the results to disk
dsd.run(my_function)
24 changes: 19 additions & 5 deletions dsdtools/audio_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def audio(self):
"""array_like: [shape=(num_samples, num_channels)]
"""

# return cached audio it explicitly set bet setter
# return cached audio if explicitly set by setter
if self._audio is not None:
return self._audio
# read from disk to save RAM otherwise
Expand All @@ -39,9 +39,9 @@ def audio(self):
self._rate = rate
return audio
else:
print("Oops! %s cannot be loaded" % self.path)
self._rate = None
self._audio = None
raise ValueError("Oops! %s cannot be loaded" % self.path)

@property
def rate(self):
Expand All @@ -50,7 +50,14 @@ def rate(self):

# load audio to set rate
if self._rate is None:
self.audio()
if os.path.exists(self.path):
audio, rate = sf.read(self.path, always_2d=True)
self._rate = rate
return rate
else:
self._rate = None
self._audio = None
raise ValueError("Oops! %s cannot be loaded" % self.path)
return self._rate

@audio.setter
Expand Down Expand Up @@ -147,9 +154,9 @@ def audio(self):
self._rate = rate
return audio
else:
print("Oops! %s cannot be loaded" % self.path)
self._rate = None
self._audio = None
raise ValueError("Oops! %s cannot be loaded" % self.path)

@property
def rate(self):
Expand All @@ -158,7 +165,14 @@ def rate(self):

# load audio to set rate
if self._rate is None:
self.audio()
if os.path.exists(self.path):
audio, rate = sf.read(self.path, always_2d=True)
self._rate = rate
return rate
else:
self._rate = None
self._audio = None
raise ValueError("Oops! %s cannot be loaded" % self.path)
return self._rate

@audio.setter
Expand Down
4 changes: 2 additions & 2 deletions examples/run_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def my_function(track):

# initiate dsdtools
dsd = dsdtools.DB(
root_dir="../data/dsdtoolssubset",
root_dir="../data/DSD100subset",
)

# verify if my_function works correctly
Expand All @@ -32,5 +32,5 @@ def my_function(track):

dsd.run(
my_function,
estimates_dir='./Estimates'
estimates_dir='./Estimates',
)
69 changes: 69 additions & 0 deletions tests/test_audio_classes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import pytest
import dsdtools.audio_classes as ac
import dsdtools
import numpy as np


@pytest.fixture(params=['data/DSD100subset'])
def dsd(request):
return dsdtools.DB(root_dir=request.param)


def test_targets(dsd):

tracks = dsd.load_dsd_tracks(ids=1)

for track in tracks:
for key, target in track.targets.items():
print target
assert target.audio.shape > 0


def test_rates(dsd):

tracks = dsd.load_dsd_tracks(ids=1)

for track in tracks:
assert track.rate == 44100
assert track.audio.shape > 0
for key, source in track.sources.items():
assert source.rate == 44100
assert source.audio.shape > 0


def test_source(dsd):

with pytest.raises(ValueError):
source = ac.Source(name="test", path="None")
source.audio

with pytest.raises(ValueError):
source = ac.Source(name="test", path="None")
source.rate

source.audio = np.zeros((2, 44100))
assert source.audio.shape == (2, 44100)

source.rate = 44100
assert source.rate == 44100

print source


def test_track(dsd):

with pytest.raises(ValueError):
track = ac.Track(name="test", path="None")
track.audio

with pytest.raises(ValueError):
track = ac.Track(name="test", path="None")
track.rate

track.audio = np.zeros((2, 44100))
assert track.audio.shape == (2, 44100)

track.rate = 44100
assert track.rate == 44100

print track
31 changes: 31 additions & 0 deletions tests/test_evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest
import dsdtools


def user_function1(track):
'''Pass'''

# return any number of targets
estimates = {
'vocals': track.audio,
'accompaniment': track.audio,
}
return estimates


@pytest.mark.parametrize(
"method",
[
'mir_eval',
pytest.mark.xfail('not_a_function', raises=ValueError)
]
)
def test_evaluate(method):

dsd = dsdtools.DB(root_dir="data/DSD100subset", evaluation=method)

# process dsd but do not save the results
assert dsd.evaluate(
user_function=user_function1,
ids=1
)
64 changes: 50 additions & 14 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import pytest
import dsdtools
import numpy as np
Expand Down Expand Up @@ -46,27 +47,70 @@ def user_function4(track):
return estimates


def test_fileloading():
def user_function5(track):
'''fails because output is not a dict'''

# return any number of targets
return track.audio


def test_file_loading():
# initiate dsdtools

dsd = dsdtools.DB(root_dir="data/DSD100subset")
tracks = dsd.load_dsd_tracks()

assert len(tracks) == 4

for track in tracks:
assert track.audio.shape[1] > 0

# load only the dev set
tracks = dsd.load_dsd_tracks(subsets='dev')

assert len(tracks) == 2

# load only the dev set
tracks = dsd.load_dsd_tracks(subsets=['dev', 'test'])

assert len(tracks) == 4

# load only a single id
tracks = dsd.load_dsd_tracks(ids=1)

assert len(tracks) == 1


@pytest.fixture(params=['data/DSD100subset'])
def dsd(request):
return dsdtools.DB(root_dir=request.param)


@pytest.mark.parametrize(
"path",
[
pytest.mark.xfail(None, raises=RuntimeError),
pytest.mark.xfail("wrong/path", raises=IOError),
"data/DSD100subset",
]
)
def test_env(path):

if path is not None:
os.environ["DSD_PATH"] = path

assert dsdtools.DB()


@pytest.mark.parametrize(
"func",
[
user_function1,
pytest.mark.xfail(user_function2, raises=ValueError),
pytest.mark.xfail(user_function3, raises=ValueError),
pytest.mark.xfail(user_function4, raises=ValueError),
pytest.mark.xfail(user_function5, raises=ValueError),
pytest.mark.xfail("not_a_function", raises=TypeError),
]
)
def test_user_functions_test(func, dsd):
Expand Down Expand Up @@ -97,18 +141,10 @@ def test_run(func, dsd):
dsd.run(estimates_dir='./Estimates')


@pytest.mark.parametrize(
"method",
[
'mir_eval',
pytest.mark.xfail('not_a_function', raises=ValueError)
]
)
def test_evaluate(method):
def test_parallel(dsd):

dsd = dsdtools.DB(root_dir='data/DSD100subset', evaluation=method)

# process dsd but do not save the results
assert dsd.evaluate(
user_function=user_function1
assert dsd.run(
user_function=user_function1,
parallel=True,
cpus=1
)

0 comments on commit 72e7aac

Please sign in to comment.