Skip to content

Commit

Permalink
Sort labels in _populate_label_list, remove list sorting from tests
Browse files Browse the repository at this point in the history
Ensures label lists are sorted upon creation, allowing stricter and simpler regression tests
Also make atol and rtol arugments in all functions where np.allclose() is used
  • Loading branch information
justinsalamon committed Jul 21, 2018
1 parent c5dbe31 commit 1076169
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 22 deletions.
2 changes: 2 additions & 0 deletions scaper/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ def _populate_label_list(folder_path, label_list):
if (os.path.isdir(os.path.join(folder_path, fname)) and
fname[0] != '.'):
label_list.append(fname)
# ensure consistent ordering of labels
label_list.sort()


def _trunc_norm(mu, sigma, trunc_min, trunc_max):
Expand Down
39 changes: 17 additions & 22 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
BG_LABELS = ['park', 'restaurant', 'street']


def test_generate_from_jams():
def test_generate_from_jams(atol=1e-8, rtol=1e-8):

# Test for invalid jams: no annotations
tmpfiles = []
Expand Down Expand Up @@ -100,7 +100,7 @@ def test_generate_from_jams():
# validate audio
orig_wav, sr = soundfile.read(orig_wav_file.name)
gen_wav, sr = soundfile.read(gen_wav_file.name)
assert np.allclose(gen_wav, orig_wav, atol=1e-8, rtol=1e-8)
assert np.allclose(gen_wav, orig_wav, atol=atol, rtol=rtol)

# Now add in trimming!
for _ in range(5):
Expand All @@ -114,7 +114,7 @@ def test_generate_from_jams():
# validate audio
orig_wav, sr = soundfile.read(orig_wav_file.name)
gen_wav, sr = soundfile.read(gen_wav_file.name)
assert np.allclose(gen_wav, orig_wav, atol=1e-8, rtol=1e-8)
assert np.allclose(gen_wav, orig_wav, atol=atol, rtol=rtol)

# Double trimming
for _ in range(2):
Expand Down Expand Up @@ -146,7 +146,7 @@ def test_generate_from_jams():
# validate audio
orig_wav, sr = soundfile.read(orig_wav_file.name)
gen_wav, sr = soundfile.read(gen_wav_file.name)
assert np.allclose(gen_wav, orig_wav, atol=1e-8, rtol=1e-8)
assert np.allclose(gen_wav, orig_wav, atol=atol, rtol=rtol)

# Test with new FG and BG paths
for _ in range(5):
Expand All @@ -157,7 +157,7 @@ def test_generate_from_jams():
# validate audio
orig_wav, sr = soundfile.read(orig_wav_file.name)
gen_wav, sr = soundfile.read(gen_wav_file.name)
assert np.allclose(gen_wav, orig_wav, atol=1e-8, rtol=1e-8)
assert np.allclose(gen_wav, orig_wav, atol=atol, rtol=rtol)

# Ensure jam file saved correctly
scaper.generate_from_jams(orig_jam_file.name, gen_wav_file.name,
Expand All @@ -167,7 +167,7 @@ def test_generate_from_jams():
assert orig_jam == gen_jam


def test_trim():
def test_trim(atol=1e-8, rtol=1e-8):

# Things we want to test:
# 1. Jam trimmed correctly (mainly handled by jams.slice)
Expand Down Expand Up @@ -259,7 +259,7 @@ def test_trim():
# validate audio
orig_wav, sr = soundfile.read(orig_wav_file.name)
trim_wav, sr = soundfile.read(trim_wav_file.name)
assert np.allclose(trim_wav, orig_wav[3*sr:7*sr], atol=1e-8, rtol=1e-8)
assert np.allclose(trim_wav, orig_wav[3*sr:7*sr], atol=atol, rtol=rtol)


def test_get_value_from_dist():
Expand Down Expand Up @@ -589,8 +589,8 @@ def test_scaper_init():

# ensure fg_labels and bg_labels populated properly
sc = scaper.Scaper(10.0, FG_PATH, BG_PATH)
assert sorted(sc.fg_labels) == sorted(FB_LABELS)
assert sorted(sc.bg_labels) == sorted(BG_LABELS)
assert sc.fg_labels == FB_LABELS
assert sc.bg_labels == BG_LABELS

# ensure default values have been set
assert sc.sr == 44100
Expand Down Expand Up @@ -879,12 +879,7 @@ def test_scaper_instantiate():
sorted(regann.sandbox.scaper.keys())):
assert k == kreg
if k not in ['bg_spec', 'fg_spec']:
# Lists might not be ordered same way, must account for this
if type(ann.sandbox.scaper[k]) is list:
assert (sorted(ann.sandbox.scaper[k]) ==
sorted(regann.sandbox.scaper[kreg]))
else:
assert ann.sandbox.scaper[k] == regann.sandbox.scaper[kreg]
assert ann.sandbox.scaper[k] == regann.sandbox.scaper[kreg]

# to compare specs need to covert raw specs to list of lists
assert (
Expand All @@ -904,7 +899,7 @@ def test_scaper_instantiate():
(ann.data == regann.data).all().all()


def test_generate_audio():
def test_generate_audio(atol=1e-8, rtol=1e-8):

# Regression test: same spec, same audio (not this will fail if we update
# any of the audio processing techniques used (e.g. change time stretching
Expand Down Expand Up @@ -970,22 +965,22 @@ def test_generate_audio():
# validate audio
wav, sr = soundfile.read(wav_file.name)
regwav, sr = soundfile.read(REG_WAV_PATH)
assert np.allclose(wav, regwav, atol=1e-8, rtol=1e-8)
assert np.allclose(wav, regwav, atol=atol, rtol=rtol)

# with reverb
sc._generate_audio(wav_file.name, jam.annotations[0], reverb=0.2)
# validate audio
wav, sr = soundfile.read(wav_file.name)
regwav, sr = soundfile.read(REG_REVERB_WAV_PATH)
assert np.allclose(wav, regwav, atol=1e-8, rtol=1e-8)
assert np.allclose(wav, regwav, atol=atol, rtol=rtol)

# Don't disable sox warnings (just to cover line)
sc._generate_audio(wav_file.name, jam.annotations[0],
disable_sox_warnings=False)
# validate audio
wav, sr = soundfile.read(wav_file.name)
regwav, sr = soundfile.read(REG_WAV_PATH)
assert np.allclose(wav, regwav, atol=1e-8, rtol=1e-8)
assert np.allclose(wav, regwav, atol=atol, rtol=rtol)

# namespace must be sound_event
jam.annotations[0].namespace = 'tag_open'
Expand Down Expand Up @@ -1020,10 +1015,10 @@ def test_generate_audio():
# validate audio
wav, sr = soundfile.read(wav_file.name)
regwav, sr = soundfile.read(REG_BGONLY_WAV_PATH)
assert np.allclose(wav, regwav, atol=1e-8, rtol=1e-8)
assert np.allclose(wav, regwav, atol=atol, rtol=rtol)


def test_generate():
def test_generate(atol=1e-8, rtol=1e-8):

# Final regression test on all files
sc = scaper.Scaper(10.0, fg_path=FG_PATH, bg_path=BG_PATH)
Expand Down Expand Up @@ -1090,7 +1085,7 @@ def test_generate():
# validate audio
wav, sr = soundfile.read(wav_file.name)
regwav, sr = soundfile.read(REG_WAV_PATH)
assert np.allclose(wav, regwav, atol=1e-8, rtol=1e-8)
assert np.allclose(wav, regwav, atol=atol, rtol=rtol)

# validate jams
jam = jams.load(jam_file.name)
Expand Down

0 comments on commit 1076169

Please sign in to comment.