Skip to content

Commit

Permalink
Merge pull request #155 from michael-kuhlmann/master
Browse files Browse the repository at this point in the history
contrib: Updates to time-frequency feature extractors
  • Loading branch information
michael-kuhlmann committed Apr 3, 2024
2 parents ec3d912 + e47356a commit d7b977a
Showing 1 changed file with 95 additions and 38 deletions.
133 changes: 95 additions & 38 deletions padertorch/contrib/mk/modules/features/timefreq.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@
from padertorch.contrib.mk.typing import TSeqLen, TSeqReturn

__all__ = [
'Sequential',
'Identity',
'Logarithm',
'STFT',
'MelTransform',
'MFCC',
]


class Sequential(pt.Module):
class Identity(pt.Module):
def forward(
self, x: Tensor, sequence_lengths: TSeqLen = None
) -> TSeqReturn:
Expand Down Expand Up @@ -71,6 +71,7 @@ def forward(self, x: Tensor) -> Tensor:
x = self.log_fn(torch.maximum(
torch.tensor(self.eps).to(x.device), x
))
return x

def inverse(self, x: torch.Tensor) -> torch.Tensor:
    """Map `x` through `self.power_fn` and return the result.

    `forward` applies `self.log_fn`; this method applies `self.power_fn`,
    presumably the matching exponentiation (confirm against the part of
    the class that sets both callables).

    Args:
        x (torch.Tensor): Input handed to `self.power_fn` unchanged.

    Returns:
        torch.Tensor: Whatever `self.power_fn(x)` returns.
    """
    restored = self.power_fn(x)
    return restored
Expand All @@ -88,6 +89,8 @@ class STFT(_STFT):
pad (bool): See paderbox.transform.module_stft.stft.
symmetric_window (bool): See paderbox.transform.module_stft.stft.
complex_representation (str): See padertorch.ops._stft.STFT.
preemphasis (float, optional): If not None, apply pre-emphasis with
this value to the input signals. Defaults to None.
spectrogram (bool): If True, return the magnitude spectrogram. Defaults
to False.
power (float): If `spectrogram` is True, raise magnitude to `power`.
Expand All @@ -97,6 +100,8 @@ class STFT(_STFT):
paderbox.transform.module_fbank.fbank. Defaults to False.
log_base (str, int, float, bool, optional): See Logarithm. Defaults to
False.
sequence_last (bool): If True, move the sequence axis to the last
position. Defaults to True.
normalization (InputNormalization, optional): InputNormalization
instance to perform z-normalization. Defaults to None.
"""
Expand All @@ -111,10 +116,12 @@ def __init__(
pad: bool = True,
symmetric_window: bool = False,
complex_representation: str = 'complex',
preemphasis: tp.Optional[float] = None,
spectrogram: bool = False,
power: float = 1.,
scale_spec: bool = False,
log_base: tp.Union[None, str, int, float, bool] = False,
sequence_last: bool = True,
normalization: tp.Union[InputNormalization, None] = None,
):
if not spectrogram and log_base:
Expand All @@ -134,12 +141,48 @@ def __init__(
# Keep references to window and symmetric_window
self.window = window
self.symmetric_window = symmetric_window

if preemphasis is not None:
try:
from torchaudio.transforms import Preemphasis
except ImportError as e:
try:
import torchaudio
raise ImportError(
f"Your torchaudio version ({torchaudio.__version__}) "
"does not support pre-emphasis. If you want to use "
"pre-emphasis, install torchaudio>=2.0.1."
) from e
except ImportError as e2:
raise ImportError(
"You need to install torchaudio>=2.0.1 to use "
"pre-emphasis."
) from e2
self.preemphasis = Preemphasis(preemphasis)
else:
self.preemphasis = None
self.spectrogram = spectrogram
self.power = power
self.scale_spec = scale_spec
self.log = Logarithm(log_base=log_base)
self.sequence_last = sequence_last
self.normalization = normalization

def to_spectrogram(self, stft_signal: Tensor) -> Tensor:
    """Convert an STFT tensor into a (scaled) magnitude spectrogram.

    How the magnitude is computed depends on `self.complex_representation`:

    - 'complex': `torch.abs` of the input.
    - 'stacked': square root of the sum of squares over the last axis.
    - anything else: the last axis is split into two halves that are
      treated as real and imaginary part.

    The magnitude is then raised to `self.power` and, if `self.scale_spec`
    is set, divided by `self.size`.

    Args:
        stft_signal (Tensor): STFT in the layout matching
            `self.complex_representation`.

    Returns:
        Tensor: The magnitude raised to `self.power`, optionally scaled.
    """
    representation = self.complex_representation
    if representation == 'complex':
        magnitude = torch.abs(stft_signal)
    elif representation == 'stacked':
        # Real and imaginary part live on the last axis.
        magnitude = stft_signal.pow(2).sum(-1).sqrt()
    else:
        # First half of the last axis is the real part, second half the
        # imaginary part.
        half = stft_signal.shape[-1] // 2
        real, imag = torch.split(stft_signal, half, dim=-1)
        magnitude = (real.pow(2) + imag.pow(2)).sqrt()

    magnitude = magnitude ** self.power
    if self.scale_spec:
        magnitude /= self.size
    return magnitude

def __call__(
self,
inputs: Tensor,
Expand All @@ -154,34 +197,35 @@ def __call__(
time signals in `inputs`. Defaults to None.
Returns:
encoded (Tensor): Spectrogram of shape (batch, time, bins) if
encoded (Tensor): Spectrogram of shape (batch, bins, time) if
`spectrogram` is True else STFT of shape
- (batch, time, bins) if `complex_representation` is 'complex',
- (batch, time, bins, 2) if `complex_representation` is 'stacked', or
- (batch, channels, time, 2*bins) if `complex_representation` is 'concat'.
- (batch, bins, time) if `complex_representation` is 'complex',
- (batch, bins, time, 2) if `complex_representation` is 'stacked', or
- (batch, channels, 2*bins, time) if `complex_representation` is 'concat'.
If `sequence_last` is False, the time and bins axes are swapped.
sequence_lengths (list, optional): List of number of frames of
spectrograms in `encoded` if input `sequence_lengths` is not
None.
"""
if self.preemphasis is not None:
inputs = self.preemphasis(inputs)

encoded = super().__call__(inputs)
if self.spectrogram:
if self.complex_representation == 'complex':
encoded = torch.abs(encoded)
elif self.complex_representation == 'stacked':
encoded = encoded.pow(2).sum(-1).sqrt()
else:
real, imag = torch.split(
encoded, encoded.shape[-1] // 2, dim=-1
)
encoded = (real.pow(2) + imag.pow(2)).sqrt()
encoded = encoded ** self.power
if self.scale_spec:
encoded /= self.size
encoded = self.to_spectrogram(encoded)
encoded = self.log(encoded)
if sequence_lengths is not None:
sequence_lengths = self.samples_to_frames(
np.asarray(sequence_lengths)
)
if self.sequence_last:
if (
self.complex_representation == 'stacked'
and not self.spectrogram
):
encoded = encoded.transpose(-2, -3)
else:
encoded = encoded.transpose(-2, -1) # (..., bins, time)
if self.normalization is not None:
encoded = self.normalization(
encoded, sequence_lengths=sequence_lengths
Expand Down Expand Up @@ -235,6 +279,8 @@ class MelTransform(pt.Module):
filter banks are used.
squeeze_channel_axis (bool): If True, squeeze the channel axis and
always return a 3D tensor. Defaults to False.
sequence_last (bool): If True, move the sequence axis to the last
position. Defaults to True.
normalization (InputNormalization, optional): InputNormalization
instance to perform z-normalization. Defaults to None.
"""
Expand All @@ -254,17 +300,14 @@ def __init__(
warping_fn=None,
independent_axis: tp.Union[int, tp.Sequence[int]] = 0,
squeeze_channel_axis: bool = False,
sequence_last: bool = True,
normalization: tp.Union[InputNormalization, None] = None,
):
super().__init__()
self.sampling_rate = sampling_rate
self.stft_size = stft_size

self.stft = stft
if self.stft is not None and not self.stft.spectrogram:
raise ValueError(
f'stft.spectrogram must be True but is {stft.spectrogram}'
)

self.number_of_filters = number_of_filters
self.lowest_frequency = lowest_frequency
Expand Down Expand Up @@ -313,6 +356,7 @@ def __init__(
)

self.squeeze_channel_axis = squeeze_channel_axis
self.sequence_last = sequence_last
self.normalization = normalization

@classmethod
Expand All @@ -322,6 +366,7 @@ def finalize_dogmatic_config(cls, config):
'window': 'hann',
'spectrogram': True,
'size': config['stft_size'],
'sequence_last': False,
}

def _normalize(self, mel_basis):
Expand Down Expand Up @@ -365,14 +410,19 @@ def forward(
spectrograms or number of samples in `x`.
Returns:
x (Tensor): Mel spectrogram of shape
(batch, ..., time, number_of_filters).
(batch, ..., number_of_filters, time). If `sequence_last` is
False, the time and number_of_filters axes are swapped.
sequence_lengths (list, optional): List of number of frames of
mel spectrograms in `x` if input `sequence_lengths` is not None.
"""
x = x.float()

if self.stft is not None:
x, sequence_lengths = self.stft(x, sequence_lengths)
if not self.stft.spectrogram:
x = self.stft.to_spectrogram(x)
if self.stft.sequence_last:
x = x.transpose(-2, -1)

if not self.training or self.warping_fn is None:
x = torch.matmul(x, self.mel_basis.to(x.device))
Expand Down Expand Up @@ -410,6 +460,9 @@ def forward(
if x.ndim == 4 and self.squeeze_channel_axis:
x = x.squeeze(1)

if self.sequence_last:
x = x.transpose(-2, -1) # (..., bins, time)

if self.normalization is not None:
x = self.normalization(x, sequence_lengths=sequence_lengths)

Expand All @@ -427,7 +480,7 @@ def inverse(
class MFCC(pt.Module):
def __init__(
self,
number_of_filters: int,
number_of_bins: int,
transform: tp.Optional[tp.Union[MelTransform, STFT]] = None,
axis: int = -1,
channel_axis: int = 1,
Expand All @@ -439,24 +492,28 @@ def __init__(
"""Extract mel-cepstral coefficients from audio.
Args:
number_of_filters: Number of filters in the filterbank.
mel_transform: Optional `MelTransform` instance. If not None,
expect time signal as input and compute the log (mel)
number_of_bins (int): Number of frequency bins in the time-frequency
representation.
transform: Optional `MelTransform` or `STFT` instance. If not
None, expect time signal as input and compute the log (mel)
spectrogram before extracting the cepstral coefficients.
axis: Position of the frequency axis.
channel_axis: Position of the channel axis. Can be set to None if
the input has no channel axis.
num_cep: Number of cepstral coefficients to keep. If None, all
coefficients are kept.
low_pass: If True and `num_cep` is not None, keep the lowest
axis (int): Position of the frequency axis. Defaults to -1.
channel_axis (int): Position of the channel axis. Can be set to
None if the input has no channel axis. Defaults to 1.
num_cep (int, optional): Number of cepstral coefficients to keep.
If None, all coefficients are kept. Defaults to None.
low_pass (bool): If True and `num_cep` is not None, keep the lowest
`num_cep` coefficients and discard the rest (default behavior).
If False, keep the highest `number_of_bins-num_cep`
coefficients (high-pass behavior).
lifter_coeff: Liftering in the cepstral domain. See
coefficients (high-pass behavior). Defaults to True.
lifter_coeff (int): Liftering in the cepstral domain. See
`paderbox.transform.module_mfcc`. If 0, no liftering is applied.
Defaults to 0.
normalization (InputNormalization, optional): InputNormalization
instance to perform z-normalization. Defaults to None.
"""
super().__init__()
self.number_of_filters = number_of_filters
self.number_of_bins = number_of_bins
self.transform = transform
self.axis = axis
self.channel_axis = channel_axis
Expand Down Expand Up @@ -543,7 +600,7 @@ def inverse(self, x_mfcc: Tensor) -> Tensor:
if self.num_cep is not None:
shape = list(x_mfcc.shape)
if self.low_pass:
shape[self.axis] = self.number_of_filters - self.num_cep
shape[self.axis] = self.number_of_bins - self.num_cep
x_mfcc = torch.cat(
(x_mfcc, torch.zeros(shape).to(x_mfcc.device)),
dim=self.axis
Expand All @@ -556,7 +613,7 @@ def inverse(self, x_mfcc: Tensor) -> Tensor:
)
spect = torch.index_select(
torch.fft.irfft(x_mfcc, axis=self.axis, norm='ortho'),
self.axis, torch.arange(self.number_of_filters).to(x_mfcc.device)
self.axis, torch.arange(self.number_of_bins).to(x_mfcc.device)
)
return spect

Expand Down

0 comments on commit d7b977a

Please sign in to comment.