Skip to content

Commit

Permalink
Kapre 0.3.7 (#137)
Browse files Browse the repository at this point in the history
* Feature/add spec augment layer (#135)

* Add SpecAugment Layer

* Extend docstrings for methods in the main class

* Add __all__ to import augmentation techniques

* Improve docstrings and add exceptions handling

* Improve tests for SpecAugment augmentation layer

* Uncomment tests. Fix setup.py

* Solve saving issue with tf format / Refactor code

* Add Jupyter Notebook and apply black for code reformatting

* add version printing; add 5 figures

* bump version; 0.3.6 -> 0.3.7

Co-authored-by: Keunwoo Choi <keunwoo.choi@bytedance.com>

* add release note

Co-authored-by: Miguel Otero Pedrido <32078719+MichaelisTrofficus@users.noreply.github.com>
  • Loading branch information
keunwoochoi and MichaelisTrofficus committed Jan 21, 2022
1 parent f25fdab commit 0b368b1
Show file tree
Hide file tree
Showing 8 changed files with 695 additions and 6 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Expand Up @@ -8,4 +8,5 @@ srcs/
_build
_static
_templates

.idea/
build/
4 changes: 4 additions & 0 deletions docs/release_note.rst
@@ -1,6 +1,10 @@
Release Note
^^^^^^^^^^^^

* 21 Jan 2022
- 0.3.7
- Add `SpecAugment <https://github.com/keunwoochoi/kapre/pull/135>`_ layer

* 13 Nov 2021
- 0.3.6
- bugfix/pad end tflite #131
Expand Down
325 changes: 325 additions & 0 deletions examples/using-SpecAugment-for-data-augmentation.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion kapre/__init__.py
@@ -1,4 +1,4 @@
__version__ = '0.3.6'
__version__ = '0.3.7'
VERSION = __version__

from . import composed
Expand Down
211 changes: 211 additions & 0 deletions kapre/augmentation.py
Expand Up @@ -10,6 +10,8 @@
from .backend import _CH_FIRST_STR, _CH_LAST_STR, _CH_DEFAULT_STR
import numpy as np

# Names exported by `from kapre.augmentation import *`.
__all__ = ['SpecAugment', 'ChannelSwap']


class ChannelSwap(Layer):
"""
Expand Down Expand Up @@ -101,3 +103,212 @@ def get_config(self):
}
)
return config


class SpecAugment(Layer):
    """
    Apply SpecAugment to a Spectrogram. For more info, check the original paper at:
    https://arxiv.org/abs/1904.08779

    Masking is only applied when the layer is called with `training=True`; when
    `training` is `None` or `False` the input is returned unchanged.

    Args:
        freq_mask_param (`int`): Frequency Mask Parameter (F in the paper)
        time_mask_param (`int`): Time Mask Parameter (T in the paper)
        n_freq_masks (`int`): Number of frequency masks to apply (mF in the paper). By default is 1.
        n_time_masks (`int`): Number of time masks to apply (mT in the paper). By default is 1.
        mask_value (`float`): Value of the applied masks. By default is 0.
        data_format (`str`): specifies the data format of batch input/output
        **kwargs: Keyword args for the parent keras layer (e.g., `name`)

    Raises:
        RuntimeError: if `freq_mask_param` or `time_mask_param` is falsy (e.g. `None` or 0).

    Example:
        ::

            input_shape = (2048, 1)  # mono signal

            # We compute the Mel Spectrogram of the input signal
            melgram = kapre.composed.get_melspectrogram_layer(input_shape=input_shape,
                                                              n_fft=1024,
                                                              return_decibel=True,
                                                              n_mels=256,
                                                              input_data_format='channels_last',
                                                              output_data_format='channels_last')

            # Now we define the SpecAugment layer. It will apply 5 masks in the frequency axis,
            # 3 masks in the time axis. The frequency mask param is 5 and the time mask param
            # is 10.
            spec_augment = SpecAugment(freq_mask_param=5,
                                       time_mask_param=10,
                                       n_freq_masks=5,
                                       n_time_masks=3)

            model = Sequential()
            model.add(melgram)

            # Add the spec_augment layer for augmentation
            model.add(spec_augment)

    """

    def __init__(
        self,
        freq_mask_param,
        time_mask_param,
        n_freq_masks=1,
        n_time_masks=1,
        mask_value=0.0,
        data_format='default',
        **kwargs,
    ):

        backend.validate_data_format_str(data_format)

        super(SpecAugment, self).__init__(**kwargs)

        self.freq_mask_param = freq_mask_param
        self.time_mask_param = time_mask_param
        self.n_freq_masks = n_freq_masks
        self.n_time_masks = n_time_masks
        self.mask_value = mask_value

        # Both mask params are mandatory and must be non-zero.
        if not self.freq_mask_param or not self.time_mask_param:
            raise RuntimeError(
                "Both freq_mask_param and time_mask_param must be defined and different "
                "than zero"
            )

        # Resolve 'default' to the globally configured Keras image data format.
        self.data_format = K.image_data_format() if data_format == _CH_DEFAULT_STR else data_format

    @staticmethod
    def _generate_axis_mask(inputs):
        """
        Generate one random mask along a single axis.

        Args:
            inputs (`tuple`): A 4-tuple with the following structure:
                inputs[0] (float `Tensor`): A spectrogram. It is part of the tuple that
                    `tf.map_fn` passes in, but it is not used to build the mask.
                inputs[1] (int): The axis limit. If mask will be applied to time axis it will be `time`,
                    if it will be applied to frequency axis, then it will be `freq`
                inputs[2] (int `Tensor`): The axis indices, already reshaped so that they broadcast
                    against the spectrogram. They indicate where to apply the mask.
                inputs[3] (int): The mask param as defined in the original paper, which is the max
                    width of the mask applied.

        Returns:
            (bool `Tensor`): A boolean tensor with the same shape as inputs[2]; `True` marks the
                positions along the axis that must be masked.
        """
        _, axis_limit, axis_indices, mask_param = inputs

        # `maxval` is exclusive for integer dtypes, so mask_width is in [0, mask_param - 1]
        # and mask_start is in [0, axis_limit - mask_width - 1].
        mask_width = tf.random.uniform(shape=(), maxval=mask_param, dtype=tf.int32)
        mask_start = tf.random.uniform(shape=(), maxval=axis_limit - mask_width, dtype=tf.int32)

        # Both bounds are inclusive: mask_width + 1 consecutive positions are selected.
        return tf.logical_and(axis_indices >= mask_start, axis_indices <= mask_start + mask_width)

    def _apply_masks_to_axis(self, x, axis, mask_param, n_masks):
        """
        Applies a number of masks (defined by the parameter n_masks) to the spectrogram
        by the axis provided.

        Args:
            x (float `Tensor`): A spectrogram. Its shape is (time, freq, ch) or (ch, time, freq)
                depending on data_format.
            axis (int): The axis where the masks will be applied
            mask_param (int): The mask param as defined in the original paper, which is the max width
                of the mask applied to the specified axis.
            n_masks (int): The number of masks to be applied

        Returns:
            (float `Tensor`): The masked spectrogram. Its shape and dtype are the same as `x`.

        Raises:
            NotImplementedError: if `axis` is not 0, 1 or 2.
            ValueError: if the size of `x` along `axis` is smaller than `mask_param`.
        """
        axis_limit = K.int_shape(x)[axis]
        axis_indices = tf.range(axis_limit)

        # Reshape the indices so that they broadcast against `x` along `axis` only.
        if axis == 0:
            axis_indices = tf.reshape(axis_indices, (-1, 1, 1))
        elif axis == 1:
            axis_indices = tf.reshape(axis_indices, (1, -1, 1))
        elif axis == 2:
            axis_indices = tf.reshape(axis_indices, (1, 1, -1))
        else:
            raise NotImplementedError("Axis parameter must be one of the following: 0, 1, 2")

        # Check if mask_width is greater than axis_limit
        if axis_limit < mask_param:
            raise ValueError(
                "Time and freq axis shapes must be greater than time_mask_param "
                "and freq_mask_param respectively"
            )

        # `tf.map_fn` iterates over the leading axis of its elems, so every argument is
        # repeated n_masks times in order to draw n_masks independent masks.
        x_repeated = tf.repeat(tf.expand_dims(x, 0), n_masks, axis=0)
        axis_limit_repeated = tf.repeat(axis_limit, n_masks, axis=0)
        axis_indices_repeated = tf.repeat(tf.expand_dims(axis_indices, 0), n_masks, axis=0)
        mask_param_repeated = tf.repeat(mask_param, n_masks, axis=0)

        # Only `fn_output_signature` is passed: the deprecated `dtype` argument describes the
        # same thing (the output of `fn`) and is ignored when both are given.
        masks = tf.map_fn(
            elems=(x_repeated, axis_limit_repeated, axis_indices_repeated, mask_param_repeated),
            fn=self._generate_axis_mask,
            fn_output_signature=tf.bool,
        )

        # A position is masked as soon as any of the n_masks masks covers it.
        mask = tf.math.reduce_any(masks, 0)

        # Cast the fill value so that non-float32 spectrograms keep their dtype.
        mask_value = tf.cast(self.mask_value, x.dtype)
        return tf.where(mask, mask_value, x)

    def _apply_spec_augment(self, x):
        """
        Main method that applies SpecAugment technique by both frequency and
        time axis.

        Args:
            x (float `Tensor`) : A spectrogram. Its shape is (time, freq, ch) or (ch, time, freq)
                depending on data_format.

        Returns:
            (float `Tensor`): The spectrogram masked by time and frequency axis. Its shape is
                (time, freq, ch) or (ch, time, freq) depending on x shape (that is, the input
                spectrogram).
        """
        if self.data_format == _CH_LAST_STR:
            time_axis, freq_axis = 0, 1
        else:
            time_axis, freq_axis = 1, 2

        if self.n_time_masks >= 1:
            x = self._apply_masks_to_axis(
                x, axis=time_axis, mask_param=self.time_mask_param, n_masks=self.n_time_masks
            )
        if self.n_freq_masks >= 1:
            x = self._apply_masks_to_axis(
                x, axis=freq_axis, mask_param=self.freq_mask_param, n_masks=self.n_freq_masks
            )
        return x

    def call(self, x, training=None, **kwargs):
        """
        Apply SpecAugment to a batch of single-channel spectrograms.

        Args:
            x (float `Tensor`): A batch of spectrograms with 4 dimensions; the channel axis
                (axis 1 for `'channels_first'`, axis 3 otherwise) must have size 1.
            training (`bool` or `None`): The input is returned untouched when this is
                `None` or `False`.

        Returns:
            (float `Tensor`): The (possibly masked) batch, with the same dtype as `x`.

        Raises:
            ValueError: if `x` does not have 4 dimensions.
            RuntimeError: if the channel axis of `x` is not of size 1.
        """
        if training in (None, False):
            return x

        if K.ndim(x) != 4:
            raise ValueError(
                'ndim of input tensor x should be 4 (batch spectrogram), '
                'but it is %d' % K.ndim(x)
            )

        ch_axis = 1 if self.data_format == 'channels_first' else 3

        if K.int_shape(x)[ch_axis] != 1:
            raise RuntimeError(
                'SpecAugment does not support spectrograms with depth greater than 1'
            )

        # Every spectrogram of the batch gets its own independently drawn masks. The output
        # signature follows the input dtype instead of being hard-coded to float32.
        return tf.map_fn(elems=x, fn=self._apply_spec_augment, fn_output_signature=x.dtype)

    def get_config(self):
        """Return the layer config, including all the SpecAugment constructor arguments."""
        config = super(SpecAugment, self).get_config()
        config.update(
            {
                'freq_mask_param': self.freq_mask_param,
                'time_mask_param': self.time_mask_param,
                'n_freq_masks': self.n_freq_masks,
                'n_time_masks': self.n_time_masks,
                'mask_value': self.mask_value,
                'data_format': self.data_format,
            }
        )
        return config
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -2,7 +2,7 @@

setup(
name='kapre',
version='0.3.6',
version='0.3.7',
description='Kapre: Keras Audio Preprocessors. Tensorflow.Keras layers for audio pre-processing in deep learning',
author='Keunwoo Choi',
url='http://github.com/keunwoochoi/kapre/',
Expand Down

0 comments on commit 0b368b1

Please sign in to comment.