Skip to content

Commit

Permalink
Merge pull request #53 from pseeth/source_time_resampling
Browse files Browse the repository at this point in the history
Modify source time tuple so that it stays within bounds before sampling from distribution
  • Loading branch information
justinsalamon committed Jan 30, 2020
2 parents ec5a5f6 + eada1c8 commit 803a26b
Show file tree
Hide file tree
Showing 6 changed files with 241 additions and 26 deletions.
6 changes: 6 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ before_install:
- deps='pip numpy scipy'
- conda create -q -n test-environment "python=$TRAVIS_PYTHON_VERSION" $deps
- source activate test-environment
- if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then
pip install pytest==4.6.0 PyYAML==4.2b4 pytest-faulthandler==1.6.0;
fi
- if [[ "$TRAVIS_PYTHON_VERSION" == "3.4" ]]; then
pip install pytest==4.6.0 PyYAML==4.2b4 pytest-faulthandler==1.6.0;
fi
- conda install -c conda-forge ffmpeg
- conda install -c conda-forge sox
- pip install python-coveralls
Expand Down
5 changes: 5 additions & 0 deletions docs/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
Changelog
---------

v1.1.0
~~~~~~
- Added functionality which modifies a source_time distribution tuple according to the duration of the source and the duration of the event.
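- This release alters the behavior of Scaper compared to earlier versions: the source_time distribution tuple is now adjusted to fit within the source file before sampling, rather than only clamping the sampled value afterwards.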

v1.0.3
~~~~~~
- Fix bug where temp files might not be closed if an error is raised
Expand Down
144 changes: 125 additions & 19 deletions scaper/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpy as np
import shutil
import csv
from copy import deepcopy
from .scaper_exceptions import ScaperError
from .scaper_warnings import ScaperWarning
from .util import _close_temp_files
Expand Down Expand Up @@ -358,6 +359,86 @@ def _validate_distribution(dist_tuple):
'number that is equal to or greater than trunc_min.')


def _ensure_satisfiable_source_time_tuple(source_time, source_duration,
                                          event_duration):
    '''
    Modify a source_time distribution tuple according to the duration of the
    source and the duration of the event. This allows you to sample from
    anywhere in a source file without knowing the exact duration of every
    source file.

    Any parameter of the distribution (constant value, choices, bounds, mean)
    that would place the event past the end of the source is replaced by
    ``max(0, source_duration - event_duration)``. The input tuple, including
    a list nested inside a 'choose' tuple, is never modified in place.

    Parameters
    ----------
    source_time : tuple
        Tuple specifying a distribution to sample from. See Scaper.add_event
        for details about the expected format of the tuple and allowed values.
    source_duration : float
        Duration of the source audio file.
    event_duration : float
        Duration of the event to be extracted from the source file.

    Returns
    -------
    source_time : tuple
        The (possibly modified) distribution tuple. A 'uniform' or 'truncnorm'
        tuple whose bounds end up equal is returned as a 'const' tuple.
    warn : bool
        True if the returned tuple differs from the one passed in, so that
        the caller can issue a warning.

    See Also
    --------
    Scaper.add_event : Add a foreground sound event to the foreground
    specification.
    '''
    _validate_distribution(source_time)
    old_source_time = deepcopy(source_time)
    source_time = list(source_time)

    # Latest start time at which the whole event still fits in the source.
    max_start = max(0, source_duration - event_duration)

    def _clamp(value):
        # Pull a start time back in bounds if the event would overrun.
        if value + event_duration > source_duration:
            return max_start
        return value

    # If it's a constant distribution, just make sure it's within bounds.
    if source_time[0] == 'const':
        source_time[1] = _clamp(source_time[1])

    # If it's a choose, clamp every option. A new list is built so that the
    # caller's list is left untouched, and duplicates are dropped while
    # keeping first-seen order so the clamped value is not added more than
    # once and an unchanged list does not compare as different.
    elif source_time[0] == 'choose':
        seen = set()
        choices = []
        for t in source_time[1]:
            t = _clamp(t)
            if t not in seen:
                seen.add(t)
                choices.append(t)
        source_time[1] = choices

    # If it's a uniform distribution, clamp both the min and the max. If
    # they end up equal there is nothing left to sample, so switch to const.
    elif source_time[0] == 'uniform':
        source_time[1] = _clamp(source_time[1])
        source_time[2] = _clamp(source_time[2])
        if source_time[1] == source_time[2]:
            source_time = ['const', source_time[1]]

    # If it's a normal distribution, clamp the mean. Individual draws may
    # still fall out of bounds; the caller has to handle that.
    elif source_time[0] == 'normal':
        source_time[1] = _clamp(source_time[1])

    # If it's a truncated normal distribution, clamp the mean as for normal
    # and clamp the min (4th item) and max (5th item) as for uniform.
    elif source_time[0] == 'truncnorm':
        source_time[1] = _clamp(source_time[1])
        source_time[3] = _clamp(source_time[3])
        source_time[4] = _clamp(source_time[4])
        if source_time[3] == source_time[4]:
            # The only value inside the truncation range is the bound itself.
            # Using the mean here could yield a constant outside the range
            # when the mean lies outside [trunc_min, trunc_max].
            source_time = ['const', source_time[3]]

    source_time = tuple(source_time)
    # Flag whether anything changed; the warning itself is raised by the
    # caller after this function returns.
    warn = (source_time != tuple(old_source_time))

    return source_time, warn


def _validate_label(label, allowed_labels):
'''
Validate that a label tuple is in the right format and that its values
Expand Down Expand Up @@ -1211,25 +1292,50 @@ def _instantiate_event(self, event, isbackground=False,
event_duration),
ScaperWarning)

# determine source time
source_time = -np.Inf
while source_time < 0:
source_time = _get_value_from_dist(event.source_time)

# Make sure source time + event duration is not greater than the
# source duration, if it is, adjust the source time (i.e. duration
# takes precedences over start time).
if source_time + event_duration > source_duration:
old_source_time = source_time
source_time = source_duration - event_duration
if not disable_instantiation_warnings:
warnings.warn(
'{:s} source time ({:.2f}) is too great given event '
'duration ({:.2f}) and source duration ({:.2f}), changed '
'to {:.2f}.'.format(
label, old_source_time, event_duration,
source_duration, source_time),
ScaperWarning)
# Modify event.source_time so that sampling from the source time distribution
# stays within the bounds of the audio file - event_duration. This allows users
# to sample from anywhere in a source file without knowing the exact duration
# of every source file. Only modify if label is not in protected labels.
if label not in self.protected_labels:
tuple_still_invalid = False
modified_source_time, warn = _ensure_satisfiable_source_time_tuple(
event.source_time, source_duration, event_duration
)

# determine source time and also check again just in case (for normal dist).
# if it happens again, just use the old method.
source_time = -np.Inf
while source_time < 0:
source_time = _get_value_from_dist(modified_source_time)
if source_time + event_duration > source_duration:
source_time = source_duration - event_duration
warn = True
tuple_still_invalid = True

if warn and not disable_instantiation_warnings:
old_source_time = ', '.join(map(str, event.source_time))
new_source_time = ', '.join(map(str, modified_source_time))
if not tuple_still_invalid:
warnings.warn(
"{:s} source time tuple ({:s}) could not be satisfied given "
"source duration ({:.2f}) and event duration ({:.2f}), "
"source time tuple changed to ({:s})".format(
label, old_source_time, source_duration,
event_duration, new_source_time),
ScaperWarning)
else:
warnings.warn(
"{:s} source time tuple ({:s}) could not be satisfied given "
"source duration ({:.2f}) and event duration ({:.2f}), "
"source time tuple changed to ({:s}) but was still not "
"satisfiable, likely due to using 'normal' distribution with "
"bounds too close to the start or end of the audio file".format(
label, old_source_time, source_duration,
event_duration, new_source_time),
ScaperWarning)
else:
source_time = 0.0


# determine event time
# for background events the event time is fixed to 0, but for
Expand Down
4 changes: 2 additions & 2 deletions scaper/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
# -*- coding: utf-8 -*-
"""Version info"""

short_version = '1.0'
version = '1.0.3'
short_version = '1.1'
version = '1.1.0'
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@
],
install_requires=[
'sox>=1.3.3',
'jams>=0.3.2'
'jams>=0.3.2',
'numpy>=1.13.3'
],
extras_require={
'docs': [
Expand Down
105 changes: 101 additions & 4 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def test_generate_from_jams(atol=1e-5, rtol=1e-8):
np.random.uniform(0, 2), np.random.uniform(4, 6))
scaper.generate_from_jams(orig_jam_file.name, gen_wav_file.name)

# Tripple trimming
# Triple trimming
for _ in range(2):
sc.generate(orig_wav_file.name, orig_jam_file.name,
disable_instantiation_warnings=True)
Expand Down Expand Up @@ -425,6 +425,60 @@ def __test_bad_tuple_list(tuple_list):
__test_bad_tuple_list(badargs)


def test_ensure_satisfiable_source_time_tuple():
    '''
    Document the expected behavior of _ensure_satisfiable_source_time_tuple
    for a 10 second source and a 5 second event.
    '''
    src_dur, evt_dur = 10, 5
    tail = slice(1, None)

    # Each case: (input distribution, index or slice applied to the adjusted
    # tuple, values expected at that index or slice).
    cases = [
        (('uniform', 4, 10), tail, (4, 5)),
        (('truncnorm', 8, 1, 4, 10), tail, (5, 1, 4, 5)),
        (('const', 6), tail, (5)),
        (('uniform', 1, 10), tail, (1, 5)),
        (('truncnorm', 4, 1, 1, 10), tail, (4, 1, 1, 5)),
        (('uniform', 6, 10), 1, (5)),
        (('truncnorm', 8, 1, 6, 10), 1, (5)),
        (('choose', [0, 1, 2, 10, 12, 15, 20]), 1, [0, 1, 2, 5]),
    ]

    for dist, picker, expected in cases:
        adjusted, warn = scaper.core._ensure_satisfiable_source_time_tuple(
            dist, src_dur, evt_dur)
        # Every case above needs adjusting, so a warning is always flagged.
        assert (warn)
        assert np.allclose(adjusted[picker], expected)


def test_validate_distribution():

def __test_bad_tuple_list(tuple_list):
Expand Down Expand Up @@ -859,11 +913,54 @@ def test_scaper_instantiate_event():
time_stretch=('const', 2))
pytest.warns(ScaperWarning, sc._instantiate_event, fg_event4)

# source_time + event_duration > source_duration: warning
fg_event5 = fg_event._replace(event_time=('const', 0),
# 'const' source_time + event_duration > source_duration: warning
fg_event5a = fg_event._replace(event_time=('const', 0),
event_duration=('const', 8),
source_time=('const', 20))
pytest.warns(ScaperWarning, sc._instantiate_event, fg_event5)
pytest.warns(ScaperWarning, sc._instantiate_event, fg_event5a)

# 'choose' source_time + event_duration > source_duration: warning
fg_event5b = fg_event._replace(event_time=('const', 0),
event_duration=('const', 8),
source_time=('choose', [20, 20]))
pytest.warns(ScaperWarning, sc._instantiate_event, fg_event5b)

# 'uniform' source_time + event_duration > source_duration: warning
fg_event5c = fg_event._replace(event_time=('const', 0),
event_duration=('const', 8),
source_time=('uniform', 20, 25))
pytest.warns(ScaperWarning, sc._instantiate_event, fg_event5c)

# 'normal' source_time + event_duration > source_duration: warning
fg_event5d = fg_event._replace(event_time=('const', 0),
event_duration=('const', 8),
source_time=('normal', 20, 2))
pytest.warns(ScaperWarning, sc._instantiate_event, fg_event5d)

# 'truncnorm' source_time + event_duration > source_duration: warning
fg_event5e = fg_event._replace(event_time=('const', 0),
event_duration=('const', 8),
source_time=('truncnorm', 20, 2, 20, 20))
pytest.warns(ScaperWarning, sc._instantiate_event, fg_event5e)

# 'normal' random draw above mean with mean = source_duration - event_duration
# source_time + event_duration > source_duration: warning
fg_event5f = fg_event._replace(event_time=('const', 0),
event_duration=('const', 8),
source_time=('normal', 18.25, 2))

def _repeat_instantiation(event):
# keep going till we hit a draw that covers when the draw exceeds
# source_duration - event_duration (18.25). Use max_draws
# just in case so that testing is guaranteed to terminate.
source_time = 0
num_draws = 0
max_draws = 1000
while source_time < 18.25 and num_draws < max_draws:
instantiated_event = sc._instantiate_event(event)
source_time = instantiated_event.source_time
num_draws += 1
pytest.warns(ScaperWarning, _repeat_instantiation, fg_event5f)

# event_time + event_duration > soundscape duration: warning
fg_event6 = fg_event._replace(event_time=('const', 8),
Expand Down

0 comments on commit 803a26b

Please sign in to comment.