Augmentation refactoring and torchaudio SoX effects support #124

Merged (17 commits) on Nov 13, 2020
11 changes: 6 additions & 5 deletions README.md
@@ -68,7 +68,7 @@ A short snippet to show how Lhotse can make audio data preparation quick and easy

```python

from lhotse import CutSet, Fbank
from lhotse import CutSet, Fbank, LilcomFilesWriter
from lhotse.dataset import VadDataset
from lhotse.recipes import prepare_switchboard

@@ -91,10 +91,11 @@ cuts = CutSet.from_manifests(
# Then, we pad the cuts to 5 seconds to ensure all cuts are of equal length,
# as the last window in each recording might have a shorter duration.
# The padding will be performed once the features are loaded into memory.
cuts = cuts.compute_and_store_features(
extractor=Fbank(),
output_dir='make_feats'
).pad(duration=5.0)
with LilcomFilesWriter('feats') as storage:
cuts = cuts.compute_and_store_features(
extractor=Fbank(),
storage=storage,
).pad(duration=5.0)

# Construct a Pytorch Dataset class for Voice Activity Detection task:
dataset = VadDataset(cuts)
25 changes: 24 additions & 1 deletion docs/augmentation.rst
@@ -1,13 +1,35 @@
Augmentation
============

We support time-domain data augmentation via `WavAugment`_ library. ``WavAugment`` combines libsox and its own implementations to provide a range of augmentations. Since ``WavAugment`` depends on libsox, it is an optional dependency for Lhotse, which can be installed using ``tools/install_wavaugment.sh`` (for convenience, on Mac OS X the script will also compile libsox from source - though note that the ``WavAugment`` authors warn their library is untested on Mac).
We support time-domain data augmentation via `WavAugment`_ and `torchaudio`_ libraries.
They both leverage libsox to provide about 50 different audio effects like reverb, speed perturbation, pitch, etc.

Since ``WavAugment`` depends on libsox, it is an optional dependency for Lhotse, which can be installed using ``tools/install_wavaugment.sh`` (for convenience, the script will also compile libsox from source - note that the ``WavAugment`` authors warn their library is untested on Mac).

Torchaudio also depends on libsox, but seems to provide it when installed via anaconda.
This functionality is only available with PyTorch 1.7+ and torchaudio 0.7+.

Using Lhotse's Python API, you can compose an arbitrary effect chain. On the other hand, for the CLI we provide a small number of predefined effect chains, such as ``pitch`` (pitch shifting), ``reverb`` (reverberation), and ``pitch_reverb_tdrop`` (pitch shift + reverberation + time dropout of a 50ms chunk).

Python usage
************

.. warning::
When using WavAugment or torchaudio data augmentation together with a multiprocessing executor (i.e. ``ProcessPoolExecutor``), it is necessary to start it using the "spawn" context. Otherwise the process may hang (or terminate) on some systems due to libsox internals not handling forking well. Use: ``ProcessPoolExecutor(..., mp_context=multiprocessing.get_context("spawn"))``.
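For example, a minimal sketch of creating such an executor (the worker count here is arbitrary):

.. code-block::

    import multiprocessing
    from concurrent.futures import ProcessPoolExecutor

    # Start worker processes with "spawn" so that libsox state is not inherited via fork.
    executor = ProcessPoolExecutor(
        max_workers=4,
        mp_context=multiprocessing.get_context("spawn")
    )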

Lhotse's ``FeatureExtractor`` and ``Cut`` offer convenience functions for feature extraction with data augmentation
performed beforehand. These functions expose an optional parameter called ``augment_fn`` with a signature like:

.. code-block::

def augment_fn(audio: Union[np.ndarray, torch.Tensor], sampling_rate: int) -> np.ndarray: ...

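As an illustration, a minimal hand-written ``augment_fn`` might look like this (a sketch; the gain range is arbitrary and this helper is not part of Lhotse):

.. code-block::

    import numpy as np

    def gain_augment_fn(audio: np.ndarray, sampling_rate: int) -> np.ndarray:
        # Scale the waveform by a random gain; the shape and sampling rate stay unchanged.
        gain = np.random.uniform(0.5, 1.0)
        return audio * gain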
For ``torchaudio`` we define a ``SoxEffectTransform`` class:

.. autoclass:: lhotse.augmentation.SoxEffectTransform
:members:
:noindex:

We define a ``WavAugmenter`` class that is a thin wrapper over ``WavAugment``. It can be created either with a predefined or a user-supplied effect chain.

.. autoclass:: lhotse.augmentation.WavAugmenter
@@ -33,3 +55,4 @@ You can create a dataset with both clean and augmented features by combining dif
lhotse yaml combine {clean,pitch,reverb}_feats/feature_manifest.yml.gz combined_feats.yml

.. _WavAugment: https://github.com/facebookresearch/WavAugment
.. _torchaudio: https://pytorch.org/audio/stable/index.html
11 changes: 6 additions & 5 deletions docs/getting-started.rst
@@ -65,7 +65,7 @@ A short snippet to show how Lhotse can make audio data preparation quick and easy

.. code-block::

from lhotse import CutSet, Fbank
from lhotse import CutSet, Fbank, LilcomFilesWriter
from lhotse.dataset import VadDataset
from lhotse.recipes import prepare_switchboard

@@ -88,10 +88,11 @@ A short snippet to show how Lhotse can make audio data preparation quick and easy
# Then, we pad the cuts to 5 seconds to ensure all cuts are of equal length,
# as the last window in each recording might have a shorter duration.
# The padding will be performed once the features are loaded into memory.
cuts = cuts.compute_and_store_features(
extractor=Fbank(),
output_dir='make_feats'
).pad(duration=5.0)
with LilcomFilesWriter('feats') as storage:
cuts = cuts.compute_and_store_features(
extractor=Fbank(),
storage=storage,
).pad(duration=5.0)

# Construct a Pytorch Dataset class for Voice Activity Detection task:
dataset = VadDataset(cuts)
8 changes: 4 additions & 4 deletions lhotse/__init__.py
@@ -1,7 +1,7 @@
from .audio import RecordingSet, Recording, AudioSource
from .augmentation import WavAugmenter
from .cut import CutSet, Cut
from .audio import AudioSource, Recording, RecordingSet
from .augmentation import *
from .cut import Cut, CutSet
from .features import *
from .kaldi import load_kaldi_data_dir
from .manipulation import load_manifest
from .supervision import SupervisionSet, SupervisionSegment
from .supervision import SupervisionSegment, SupervisionSet
3 changes: 3 additions & 0 deletions lhotse/augmentation/__init__.py
@@ -0,0 +1,3 @@
from .common import AugmentFn
from .torchaudio import *
from .wavaugment import WavAugmenter, is_wav_augment_available
6 changes: 6 additions & 0 deletions lhotse/augmentation/common.py
@@ -0,0 +1,6 @@
from typing import Callable

import numpy as np

# def augment_fn(audio: np.ndarray, sampling_rate: int) -> np.ndarray
AugmentFn = Callable[[np.ndarray, int], np.ndarray]
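Any callable matching this signature can be passed wherever Lhotse expects an ``augment_fn``. A trivial, purely illustrative example that type-checks against the alias:

```python
import numpy as np

from lhotse.augmentation import AugmentFn


# A no-op transform conforming to AugmentFn (illustrative only).
def identity_augment(audio: np.ndarray, sampling_rate: int) -> np.ndarray:
    return audio


fn: AugmentFn = identity_augment
```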
113 changes: 113 additions & 0 deletions lhotse/augmentation/torchaudio.py
@@ -0,0 +1,113 @@
import warnings
from dataclasses import dataclass
from typing import List, Union

import numpy as np
import torch
import torchaudio
from packaging.version import parse as _version

from lhotse.utils import during_docs_build

if not during_docs_build() and _version(torchaudio.__version__) < _version('0.7'):
warnings.warn('Torchaudio SoX effect chains were only introduced in version 0.7 - '
'please upgrade your PyTorch to 1.7+ and torchaudio to 0.7+ to use them.')


@dataclass
class RandomValue:
"""
Represents a uniform distribution in the range [start, end].
"""
start: Union[int, float]
end: Union[int, float]

def sample(self):
return np.random.uniform(self.start, self.end)


# Input to the SoxEffectTransform class - the values are either effect names,
# numeric parameters, or uniform distribution over possible values.
EffectsList = List[List[Union[str, int, float, RandomValue]]]


class SoxEffectTransform:
"""
Class-style wrapper for torchaudio SoX effect chains.
It should be initialized with a config-like list of items that define the SoX effects to be applied.
It supports sampling randomized values for effect parameters through the ``RandomValue`` wrapper.

Example:
>>> audio = np.random.rand(16000)
>>> augment_fn = SoxEffectTransform(effects=[
>>> ['reverb', 50, 50, RandomValue(0, 100)],
>>> ['speed', RandomValue(0.9, 1.1)],
>>> ['rate', 16000],
>>> ])
>>> augmented = augment_fn(audio, 16000)

See SoX manual or ``torchaudio.sox_effects.effect_names()`` for the list of possible effects.
The parameters and the meaning of the values are explained in SoX manual/help.
"""

def __init__(self, effects: EffectsList):
super().__init__()
self.effects = effects

def __call__(self, tensor: Union[torch.Tensor, np.ndarray], sampling_rate: int):
if isinstance(tensor, np.ndarray):
tensor = torch.from_numpy(tensor)
effects = self.sample_effects()
augmented, new_sampling_rate = torchaudio.sox_effects.apply_effects_tensor(tensor, sampling_rate, effects)
assert augmented.shape[0] == tensor.shape[0], "Lhotse does not support modifying the number " \
"of channels during data augmentation."
assert sampling_rate == new_sampling_rate, \
f"Lhotse does not support changing the sampling rate during data augmentation. " \
f"The original SR was '{sampling_rate}', after augmentation it's '{new_sampling_rate}'."
# Matching shapes after augmentation -> early return.
if augmented.shape[1] == tensor.shape[1]:
return augmented
# We will truncate/zero-pad the signal if the number of samples has changed to mimic
# the WavAugment behavior that we relied upon so far.
resized = torch.zeros_like(tensor)
if augmented.shape[1] > tensor.shape[1]:
resized = augmented[:, :tensor.shape[1]]
else:
resized[:, :augmented.shape[1]] = augmented
return resized

def sample_effects(self) -> List[List[str]]:
"""
Resolve a list of effects, replacing random distributions with samples from them.
It converts every number to string to match the expectations of torchaudio.
"""
return [
[
str(item.sample() if isinstance(item, RandomValue) else item)
for item in effect
]
for effect in self.effects
]


def speed(sampling_rate: int) -> List[List[str]]:
return [
# Random speed perturbation factor between 0.9x and 1.1x the original speed
['speed', RandomValue(0.9, 1.1)],
['rate', sampling_rate], # Resample back to the original sampling rate (speed changes it)
Review thread on this line:

freewym (Contributor) commented on Nov 12, 2020:
Looks like this line makes the run hang. It works without this line.
Edit: it did not actually hang; it terminated with:
  File "/export/fs04/a07/ywang/fairseq4/espresso/tools/lhotse/lhotse/cut.py", line 1311, in compute_and_store_features
    executor.submit(
  File "/export/b03/ywang/anaconda3/lib/python3.8/concurrent/futures/process.py", line 629, in submit
    raise BrokenProcessPool(self._broken)
  concurrent.futures.process.BrokenProcessPool: A child process terminated abruptly, the process pool is not usable anymore

pzelasko (Collaborator, Author):
😫 I'll have a look.

pzelasko (Collaborator, Author):
I was able to replicate the issue and add a unit test that causes it. I submitted the issue to torchaudio here: pytorch/audio#1021

freewym (Contributor) commented on Nov 12, 2020:
BTW, if I replace RandomValue() above with a function _get_value(factor) that simply returns factor, the run hangs as well. Do you have any clue about the cause?

pzelasko (Collaborator, Author):
Is the function defined as a closure (i.e. within another function) that captures some variable outside of its scope? That could explain it... Otherwise, I don't know.

pzelasko (Collaborator, Author):
@freewym some good news: if you create the executor like ProcessPoolExecutor(..., mp_context=multiprocessing.get_context("spawn")), it solves the segfault/hanging problem. Could you try? If it works, I'll go on and merge this.

pzelasko (Collaborator, Author):
Credits to @mthrok for suggesting this.

pzelasko (Collaborator, Author):
(To make it clear: it works for me on the grid, on my Mac, and in GitHub CI, so it should be okay.)

freewym (Contributor):
Yeah, it works!

pzelasko (Collaborator, Author):
🎉

]


def reverb(sampling_rate: int) -> List[List[str]]:
Review thread on this function:

Contributor:
Is it possible to make such functions more general, so that they can accept more arguments, e.g. the lower/upper bounds of the room sizes could be passed into this function?

pzelasko (Collaborator, Author):
Sure, I'll change that.

pzelasko (Collaborator, Author) commented on Nov 12, 2020:
Actually, I'd rather make a follow-up PR later on, as I'm not sure which parameters it makes sense to tweak and how general they should be. If we want to tweak everything, it's simpler to just write your own chain...
(I'm open to suggestions)
return [
['reverb', 50, 50, RandomValue(0, 100)],
['remix', '-'], # Merge all channels (reverb changes mono to stereo)
]


def pitch(sampling_rate: int) -> List[List[str]]:
return [
# The returned values are 1/100ths of a semitone, meaning the default is up to a minor third shift up or down.
['pitch', '-q', RandomValue(-300, 300)],
['rate', sampling_rate] # Resample back to the original sampling rate (pitch changes it)
]
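As a usage sketch, one of these predefined effect lists could be plugged into feature extraction roughly like this (assuming ``cuts`` and ``storage`` exist as in the README snippet above, and that ``compute_and_store_features`` forwards ``augment_fn`` to the extractor; the import path for ``speed`` is the module shown in this diff):

```python
from lhotse import Fbank
from lhotse.augmentation.torchaudio import SoxEffectTransform, speed

# Random speed perturbation (0.9x-1.1x) followed by resampling back to the original rate.
augment_fn = SoxEffectTransform(effects=speed(sampling_rate=16000))

cuts = cuts.compute_and_store_features(
    extractor=Fbank(),
    storage=storage,
    augment_fn=augment_fn,
)
```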
51 changes: 44 additions & 7 deletions lhotse/augmentation.py → lhotse/augmentation/wavaugment.py
@@ -1,9 +1,12 @@
import warnings
from typing import Union, List
from typing import List, Union

import numpy as np
import torch

__all__ = ['is_wav_augment_available', 'WavAugmenter', 'available_wav_augmentations', 'register_wav_augmentation',
'pitch', 'reverb', 'pitch_reverb_tdrop']


def is_wav_augment_available() -> bool:
"""Returns a boolean indicating if WavAugment is both installed and possible to import."""
@@ -18,14 +21,21 @@ class WavAugmenter:
"""
A wrapper class for WavAugment's effect chain.
You should construct the ``augment.EffectChain`` beforehand and pass it on to this class.

This class is only available when WavAugment is installed, as it is an optional dependency for Lhotse.
It can be installed using the script in "<main-repo-directory>/tools/install_wavaugment.sh"

For more details on how to augment, see https://github.com/facebookresearch/WavAugment
"""

def __init__(self, effect_chain, sampling_rate: int):
def __init__(self, effect_chain):
warnings.warn('WavAugment support is deprecated and it will eventually be removed from Lhotse. '
'For similar functionality, please use torchaudio based augmentation in '
'"lhotse.augmentation.torchaudio". It requires PyTorch 1.7+ and torchaudio 0.7+.',
category=DeprecationWarning)
# A local import so that ``augment`` can be optional.
import augment
self.chain: augment.EffectChain = effect_chain
self.sampling_rate = sampling_rate

@staticmethod
def create_predefined(name: str, sampling_rate: int, **kwargs) -> 'WavAugmenter':
@@ -38,14 +48,21 @@ def create_predefined(name: str, sampling_rate: int, **kwargs) -> 'WavAugmenter'
"""
return WavAugmenter(
effect_chain=_DEFAULT_AUGMENTATIONS[name](sampling_rate=sampling_rate, **kwargs),
sampling_rate=sampling_rate
)

def apply(self, audio: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
def __call__(
self,
audio: Union[torch.Tensor, np.ndarray],
sampling_rate: int
) -> np.ndarray:
"""
Apply the effect chain on the ``audio`` tensor.

:param audio: a (num_channels, num_samples) shaped tensor placed on the CPU.
:param sampling_rate: The input and output sampling rate (has to be the same).
:return a numpy ndarray with the augmented audio signal.
In case SoX returned Nan or Inf for some sample, fall back to returning the non-augmented
signal instead.
"""
if isinstance(audio, np.ndarray):
audio = torch.from_numpy(audio)
@@ -57,12 +74,12 @@ def apply(self, audio: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
src_info={
'channels': audio.shape[0],
'length': audio.shape[1],
'rate': self.sampling_rate,
'rate': sampling_rate,
},
target_info={
'channels': 1,
'length': audio.shape[1],
'rate': self.sampling_rate,
'rate': sampling_rate,
}
)

@@ -107,6 +124,21 @@ def pitch(sampling_rate: int):
return effect_chain


@register_wav_augmentation
def speed(sampling_rate: int):
"""
Returns a speed perturbation effect for wav augmentation.

:param sampling_rate: a sampling rate value for which the effect will be created (resampling is needed for speed).
"""
import augment
effect_chain = augment.EffectChain()
# The speed effect changes the effective sampling rate; we have to compensate for that.
# Here, we specify the 'quick' option on the rate effect to speed things up
effect_chain.speed(_random_speed_perturb).rate("-q", sampling_rate)
return effect_chain


@register_wav_augmentation
def reverb(*args, **kwargs):
"""
@@ -143,6 +175,11 @@ def pitch_reverb_tdrop(sampling_rate: int):
return effect_chain


def _random_speed_perturb() -> float:
"""The returned values are speed perturbation factors (0.9x - 1.1x the original speed)."""
return np.random.uniform(0.9, 1.1)


def _random_pitch_shift() -> int:
"""The returned values are 1/100ths of a semitone, meaning the default is up to a minor third shift up or down."""
return np.random.randint(-300, 300)
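For reference, a usage sketch of the (now deprecated) WavAugment path, assuming the ``augment`` package is installed:

```python
import numpy as np

from lhotse.augmentation import WavAugmenter

# Build a predefined effect chain by name; the instance is callable like an augment_fn.
augment_fn = WavAugmenter.create_predefined(name='pitch', sampling_rate=16000)

audio = np.random.rand(1, 16000).astype(np.float32)  # (num_channels, num_samples)
augmented = augment_fn(audio, sampling_rate=16000)
```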
10 changes: 5 additions & 5 deletions lhotse/bin/modes/features.py
@@ -4,9 +4,9 @@
import click

from lhotse.audio import RecordingSet
from lhotse.augmentation import available_wav_augmentations, WavAugmenter
from lhotse.augmentation import WavAugmenter, available_wav_augmentations
from lhotse.bin.modes.cli_base import cli
from lhotse.features import FeatureExtractor, FeatureSetBuilder, create_default_feature_extractor, Fbank
from lhotse.features import Fbank, FeatureExtractor, FeatureSetBuilder, create_default_feature_extractor
from lhotse.features.io import available_storage_backends, get_writer
from lhotse.utils import Pathlike

@@ -67,18 +67,18 @@ def extract(
output_dir.mkdir(exist_ok=True, parents=True)
storage_path = output_dir / 'feats.h5' if 'hdf5' in storage_type else output_dir / 'storage'

augmenter = None
augment_fn = None
if augmentation is not None:
sampling_rate = next(iter(recordings)).sampling_rate
assert all(rec.sampling_rate == sampling_rate for rec in recordings), \
"Wav augmentation effect chains expect all the recordings to have the same sampling rate at this time."
augmenter = WavAugmenter.create_predefined(name=augmentation, sampling_rate=sampling_rate)
augment_fn = WavAugmenter.create_predefined(name=augmentation, sampling_rate=sampling_rate)

with get_writer(storage_type)(storage_path, tick_power=lilcom_tick_power) as storage:
feature_set_builder = FeatureSetBuilder(
feature_extractor=feature_extractor,
storage=storage,
augmenter=augmenter
augment_fn=augment_fn
)
feature_set_builder.process_and_store_recordings(
recordings=recordings,