-
Notifications
You must be signed in to change notification settings - Fork 217
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Augmentation refactoring and torchaudio SoX effects support #124
Changes from all commits
f3700df
4de6057
c626357
50e9b61
4d45676
fa15e0a
4a54cee
ac426ad
2e1c39a
ba156ca
28bc3f8
b99ecc7
6aa7b0f
7234dab
1b0ebf1
62988a0
ba685cd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
from .audio import RecordingSet, Recording, AudioSource | ||
from .augmentation import WavAugmenter | ||
from .cut import CutSet, Cut | ||
from .audio import AudioSource, Recording, RecordingSet | ||
from .augmentation import * | ||
from .cut import Cut, CutSet | ||
from .features import * | ||
from .kaldi import load_kaldi_data_dir | ||
from .manipulation import load_manifest | ||
from .supervision import SupervisionSet, SupervisionSegment | ||
from .supervision import SupervisionSegment, SupervisionSet |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .common import AugmentFn | ||
from .torchaudio import * | ||
from .wavaugment import WavAugmenter, is_wav_augment_available |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from typing import Callable | ||
|
||
import numpy as np | ||
|
||
# def augment_fn(audio: np.ndarray, sampling_rate: int) -> np.ndarray | ||
AugmentFn = Callable[[np.ndarray, int], np.ndarray] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
import warnings | ||
from dataclasses import dataclass | ||
from typing import List, Union | ||
|
||
import numpy as np | ||
import torch | ||
import torchaudio | ||
from packaging.version import parse as _version | ||
|
||
from lhotse.utils import during_docs_build | ||
|
||
if not during_docs_build() and _version(torchaudio.__version__) < _version('0.7'): | ||
warnings.warn('Torchaudio SoX effects chains are only introduced in version 0.7 - ' | ||
'please upgrade your PyTorch to 1.7+ and torchaudio to 0.7+ to use them.') | ||
|
||
|
||
@dataclass | ||
class RandomValue: | ||
""" | ||
Represents a uniform distribution in the range [start, end]. | ||
""" | ||
start: Union[int, float] | ||
end: Union[int, float] | ||
|
||
def sample(self): | ||
return np.random.uniform(self.start, self.end) | ||
|
||
|
||
# Input to the SoxEffectTransform class - the values are either effect names, | ||
# numeric parameters, or uniform distribution over possible values. | ||
EffectsList = List[List[Union[str, int, float, RandomValue]]] | ||
|
||
|
||
class SoxEffectTransform: | ||
""" | ||
Class-style wrapper for torchaudio SoX effect chains. | ||
It should be initialized with a config-like list of items that define SoX effect to be applied. | ||
It supports sampling randomized values for effect parameters through the ``RandomValue`` wrapper. | ||
|
||
Example: | ||
>>> audio = np.random.rand(16000) | ||
>>> augment_fn = SoxEffectTransform(effects=[ | ||
>>> ['reverb', 50, 50, RandomValue(0, 100)], | ||
>>> ['speed', RandomValue(0.9, 1.1)], | ||
>>> ['rate', 16000], | ||
>>> ]) | ||
>>> augmented = augment_fn(audio, 16000) | ||
|
||
See SoX manual or ``torchaudio.sox_effects.effect_names()`` for the list of possible effects. | ||
The parameters and the meaning of the values are explained in SoX manual/help. | ||
""" | ||
|
||
def __init__(self, effects: EffectsList): | ||
super().__init__() | ||
self.effects = effects | ||
|
||
def __call__(self, tensor: Union[torch.Tensor, np.ndarray], sampling_rate: int): | ||
if isinstance(tensor, np.ndarray): | ||
tensor = torch.from_numpy(tensor) | ||
effects = self.sample_effects() | ||
augmented, new_sampling_rate = torchaudio.sox_effects.apply_effects_tensor(tensor, sampling_rate, effects) | ||
assert augmented.shape[0] == tensor.shape[0], "Lhotse does not support modifying the number " \ | ||
"of channels during data augmentation." | ||
assert sampling_rate == new_sampling_rate, \ | ||
f"Lhotse does not support changing the sampling rate during data augmentation. " \ | ||
f"The original SR was '{sampling_rate}', after augmentation it's '{new_sampling_rate}'." | ||
# Matching shapes after augmentation -> early return. | ||
if augmented.shape[1] == tensor.shape[1]: | ||
return augmented | ||
# We will truncate/zero-pad the signal if the number of samples has changed to mimic | ||
# the WavAugment behavior that we relied upon so far. | ||
resized = torch.zeros_like(tensor) | ||
if augmented.shape[1] > tensor.shape[1]: | ||
resized = augmented[:, :tensor.shape[1]] | ||
else: | ||
resized[:, :augmented.shape[1]] = augmented | ||
return resized | ||
|
||
def sample_effects(self) -> List[List[str]]: | ||
""" | ||
Resolve a list of effects, replacing random distributions with samples from them. | ||
It converts every number to string to match the expectations of torchaudio. | ||
""" | ||
return [ | ||
[ | ||
str(item.sample() if isinstance(item, RandomValue) else item) | ||
for item in effect | ||
] | ||
for effect in self.effects | ||
] | ||
|
||
|
||
def speed(sampling_rate: int) -> List[List[str]]: | ||
return [ | ||
# Random speed perturbation factor between 0.9x and 1.1x the original speed | ||
['speed', RandomValue(0.9, 1.1)], | ||
['rate', sampling_rate], # Resample back to the original sampling rate (speed changes it) | ||
] | ||
|
||
|
||
def reverb(sampling_rate: int) -> List[List[str]]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it possible to make such functions more general, in that they can accept more arguments, e.g., the lower/up bound of room sizes can been passed into this function? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, I'll change that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually I'd rather make a follow-up PR later on, as I'm not sure which parameters it makes sense to tweak and how general they should be. If we want to tweak everything it's simpler to just write your own chain... (I'm open to suggestions) |
||
return [ | ||
['reverb', 50, 50, RandomValue(0, 100)], | ||
['remix', '-'], # Merge all channels (reverb changes mono to stereo) | ||
] | ||
|
||
|
||
def pitch(sampling_rate: int) -> List[List[str]]: | ||
return [ | ||
# The returned values are 1/100ths of a semitone, meaning the default is up to a minor third shift up or down. | ||
['pitch', '-q', RandomValue(-300, 300)], | ||
['rate', sampling_rate] # Resample back to the original sampling rate (pitch changes it) | ||
] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like this line makes the running hang. It works without this line.
edit: actually not hang. It terminated with
"File "/export/fs04/a07/ywang/fairseq4/espresso/tools/lhotse/lhotse/cut.py", line 1311, in compute_and_store_features
executor.submit(
File "/export/b03/ywang/anaconda3/lib/python3.8/concurrent/futures/process.py", line 629, in submit
raise BrokenProcessPool(self._broken)
concurrent.futures.process.BrokenProcessPool: A child process terminated abruptly, the process pool is not usable anymore"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
😫
I'll have a look
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was able to replicate the issue and add a unit test that causes it. I submitted the issue to torchaudio here: pytorch/audio#1021
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BTW, if replacing RandomValue() above with a function
_get_value(factor)
wherefactor
is simply returned, the running hangs as well. Do you have any clue of the cause?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the function defined as a closure (i.e. within another function) and captures some variable outside of its scope? That could explain it... Otherwise, I don't know.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@freewym some good news, if you create executor like:
ProcessPoolExecutor(..., mp_context=multiprocessing.get_context("spawn"))
it solves the segfault/hanging problem. Could you try? If it works I'll go on and merge thisThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Credits to @mthrok for suggesting this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(to make it clear: it works for me on the grid, on my mac, and in GitHub CI, so it should be okay)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, it works!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🎉