TTS fine-tuning for SpeechT5 (#21824)
* wrong argument name

* append eos_token_id

* all tokenizers need mask and ctc_blank tokens

* remove reduction factor from feature extractor

* add proper TTS loss

* did shifting the wrong way around

* mask out padded portions

* remove logits again (don't really need it)

* fix unit tests

* fixup

* pad also returns the decoder attention mask, since that's useful to have

* clean up feature extractor logic

* pad can handle TTS task too

* remove stop_labels from loss calculation

* simplify logic

* fixup

* do -100 masking properly

* small STFT optimization (calculate mel filterbanks only once)

* replace torchaudio fbanks with audio_utils

* remove torchaudio dependency

* simplify & speed up the STFT

* don't serialize window and mel filters

* output cross attentions when generating speech

* add guided attention loss

* fix failing test

* Update src/transformers/models/speecht5/feature_extraction_speecht5.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/speecht5/modeling_speecht5.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* change type annotation of attention_mask to LongTensor

* extract loss into class

* remove unused frame_signal_scale argument

* use config object in loss class

* fix type annotations in doc comments

* change optional to just bool

* implement missing tokenizer method

* add deprecation warning

* Update src/transformers/models/speecht5/feature_extraction_speecht5.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/speecht5/feature_extraction_speecht5.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* add deprecation warning for stop_labels

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
3 people committed Apr 18, 2023
1 parent dacd345 commit ac2bc50
Showing 10 changed files with 448 additions and 234 deletions.
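Taken together, the changes below let `SpeechT5ForTextToSpeech` compute a training loss directly from processor output. A minimal fine-tuning sketch, assuming the released `microsoft/speecht5_tts` checkpoint and the post-merge processor API; in particular, the `audio_target` keyword and the `labels`/`decoder_attention_mask` batch keys are assumptions about the surrounding API, not part of this diff:

```python
import numpy as np
from transformers import SpeechT5ForTextToSpeech, SpeechT5Processor

processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")

waveform = np.zeros(16000, dtype=np.float32)  # stand-in for a real 1 s training utterance
batch = processor(
    text="hello world",
    audio_target=waveform,  # log-mel labels are extracted from this waveform
    sampling_rate=16000,
    return_tensors="pt",
)

outputs = model(
    input_ids=batch["input_ids"],
    labels=batch["labels"],                                  # (1, frames, 80) log-mel target
    decoder_attention_mask=batch["decoder_attention_mask"],  # now also produced by pad()
)
outputs.loss.backward()
```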
21 changes: 20 additions & 1 deletion src/transformers/models/speecht5/configuration_speecht5.py
@@ -161,13 +161,22 @@ class SpeechT5Config(PretrainedConfig):
         speech_decoder_postnet_dropout (`float`, *optional*, defaults to 0.5):
             The dropout probability for the speech decoder post-net layers.
         reduction_factor (`int`, *optional*, defaults to 2):
-            Spectrogram length reduction factor for the speech decoder post-net.
+            Spectrogram length reduction factor for the speech decoder inputs.
         max_speech_positions (`int`, *optional*, defaults to 4000):
             The maximum sequence length of speech features that this model might ever be used with.
         max_text_positions (`int`, *optional*, defaults to 450):
             The maximum sequence length of text features that this model might ever be used with.
         encoder_max_relative_position (`int`, *optional*, defaults to 160):
             Maximum distance for relative position embedding in the encoder.
+        use_guided_attention_loss (`bool`, *optional*, defaults to `True`):
+            Whether to apply guided attention loss while training the TTS model.
+        guided_attention_loss_num_heads (`int`, *optional*, defaults to 2):
+            Number of attention heads the guided attention loss will be applied to. Use -1 to apply this loss to all
+            attention heads.
+        guided_attention_loss_sigma (`float`, *optional*, defaults to 0.4):
+            Standard deviation for guided attention loss.
+        guided_attention_loss_scale (`float`, *optional*, defaults to 10.0):
+            Scaling coefficient for guided attention loss (also known as lambda).
         use_cache (`bool`, *optional*, defaults to `True`):
             Whether or not the model should return the last key/values attentions (not used by all models).
@@ -241,6 +250,10 @@ def __init__(
         max_speech_positions=4000,
         max_text_positions=450,
         encoder_max_relative_position=160,
+        use_guided_attention_loss=True,
+        guided_attention_loss_num_heads=2,
+        guided_attention_loss_sigma=0.4,
+        guided_attention_loss_scale=10.0,
         use_cache=True,
         is_encoder_decoder=True,
         **kwargs,
@@ -311,6 +324,12 @@ def __init__(
         self.max_speech_positions = max_speech_positions
         self.max_text_positions = max_text_positions
         self.encoder_max_relative_position = encoder_max_relative_position
+
+        self.use_guided_attention_loss = use_guided_attention_loss
+        self.guided_attention_loss_num_heads = guided_attention_loss_num_heads
+        self.guided_attention_loss_sigma = guided_attention_loss_sigma
+        self.guided_attention_loss_scale = guided_attention_loss_scale
+
         self.use_cache = use_cache
         self.is_encoder_decoder = is_encoder_decoder

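The four new options above configure a guided attention loss in the style of Tachibana et al. (2017), which pushes decoder-to-encoder cross-attention toward the diagonal early in training. A minimal sketch of that formulation, not the implementation from modeling_speecht5.py (which this page does not show):

```python
import torch

def guided_attention_loss(attentions, sigma=0.4, scale=10.0):
    """Penalize cross-attention mass that lies far from the diagonal (Tachibana et al., 2017)."""
    num_heads, output_len, input_len = attentions.shape
    t = torch.arange(output_len).float().unsqueeze(1) / output_len  # decoder positions in [0, 1)
    s = torch.arange(input_len).float().unsqueeze(0) / input_len    # encoder positions in [0, 1)
    soft_mask = 1.0 - torch.exp(-((s - t) ** 2) / (2 * sigma**2))   # zero on the diagonal
    return scale * (attentions * soft_mask).mean()

# toy check: a perfectly diagonal alignment incurs zero penalty
attn = torch.eye(50).unsqueeze(0)  # one head, 50 decoder steps over 50 encoder tokens
print(guided_attention_loss(attn))  # tensor(0.)
```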
@@ -351,12 +351,11 @@ def convert_speecht5_checkpoint(
     if vocab_path:
         tokenizer = SpeechT5Tokenizer(vocab_path, model_max_length=config.max_text_positions)

-        if task == "pretrain":
-            # Mask token behaves like a normal word, i.e. include the space before it
-            mask_token = AddedToken("<mask>", lstrip=True, rstrip=False)
-            tokenizer.mask_token = mask_token
-            tokenizer.add_special_tokens({"mask_token": mask_token})
-            tokenizer.add_tokens(["<ctc_blank>"])
+        # Mask token behaves like a normal word, i.e. include the space before it
+        mask_token = AddedToken("<mask>", lstrip=True, rstrip=False)
+        tokenizer.mask_token = mask_token
+        tokenizer.add_special_tokens({"mask_token": mask_token})
+        tokenizer.add_tokens(["<ctc_blank>"])

     feature_extractor = SpeechT5FeatureExtractor()
     processor = SpeechT5Processor(tokenizer=tokenizer, feature_extractor=feature_extractor)
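With the `task == "pretrain"` conditional removed, every converted tokenizer carries the `<mask>` and `<ctc_blank>` entries, matching the commit note that all tokenizers need them. A quick check against a released checkpoint (a sketch; assumes `sentencepiece` is installed):

```python
from transformers import SpeechT5Tokenizer

tokenizer = SpeechT5Tokenizer.from_pretrained("microsoft/speecht5_tts")
print(tokenizer.mask_token)                            # "<mask>"
print(tokenizer.convert_tokens_to_ids("<ctc_blank>"))  # id of the added token
```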
172 changes: 68 additions & 104 deletions src/transformers/models/speecht5/feature_extraction_speecht5.py
@@ -14,12 +14,13 @@
 # limitations under the License.
 """Feature extractor class for SpeechT5."""

-from typing import List, Optional, Union
+import warnings
+from typing import Any, Dict, List, Optional, Union

 import numpy as np
 import torch
-import torchaudio

+from ...audio_utils import get_mel_filter_banks
 from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
 from ...feature_extraction_utils import BatchFeature
 from ...utils import PaddingStrategy, TensorType, logging
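`get_mel_filter_banks` is the in-library replacement for `torchaudio.functional.melscale_fbanks`, which is what lets the torchaudio dependency go. A sketch comparing the two, using the keyword names from this diff (in later Transformers releases the helper was renamed `mel_filter_bank`, so treat the import as version-dependent):

```python
import numpy as np
import torchaudio
from transformers.audio_utils import get_mel_filter_banks  # renamed mel_filter_bank later

ours = get_mel_filter_banks(
    nb_frequency_bins=257, nb_mel_filters=80, frequency_min=80.0, frequency_max=7600.0,
    sample_rate=16000, norm="slaney", mel_scale="slaney",
)
theirs = torchaudio.functional.melscale_fbanks(
    n_freqs=257, f_min=80.0, f_max=7600.0, n_mels=80, sample_rate=16000,
    norm="slaney", mel_scale="slaney",
).numpy()
print(np.allclose(ours, theirs, atol=1e-6))  # expected: True
```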
@@ -60,15 +61,15 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
         win_function (`str`, *optional*, defaults to `"hann_window"`):
             Name for the window function used for windowing, must be accessible via `torch.{win_function}`
         frame_signal_scale (`float`, *optional*, defaults to 1.0):
-            Constant multiplied in creating the frames before applying DFT.
+            Constant multiplied in creating the frames before applying DFT. This argument is deprecated.
         fmin (`float`, *optional*, defaults to 80):
             Minimum mel frequency in Hz.
         fmax (`float`, *optional*, defaults to 7600):
             Maximum mel frequency in Hz.
         mel_floor (`float`, *optional*, defaults to 1e-10):
             Minimum value of mel frequency banks.
         reduction_factor (`int`, *optional*, defaults to 2):
-            Spectrogram length reduction factor.
+            Spectrogram length reduction factor. This argument is deprecated.
         return_attention_mask (`bool`, *optional*, defaults to `True`):
             Whether or not [`~SpeechT5FeatureExtractor.__call__`] should return `attention_mask`.
     """
@@ -109,10 +110,33 @@ def __init__(

         self.sample_size = win_length * sampling_rate // 1000
         self.sample_stride = hop_length * sampling_rate // 1000

         self.n_fft = 2 ** int(np.ceil(np.log2(self.sample_size)))
         self.n_freqs = (self.n_fft // 2) + 1

+        window = getattr(torch, self.win_function)(window_length=self.sample_size, periodic=True)
+        self.window = window.numpy().astype(np.float64)
+
+        self.mel_filters = get_mel_filter_banks(
+            nb_frequency_bins=self.n_freqs,
+            nb_mel_filters=self.num_mel_bins,
+            frequency_min=self.fmin,
+            frequency_max=self.fmax,
+            sample_rate=self.sampling_rate,
+            norm="slaney",
+            mel_scale="slaney",
+        )
+
+        if frame_signal_scale != 1.0:
+            warnings.warn(
+                "The argument `frame_signal_scale` is deprecated and will be removed in version 4.30.0 of Transformers",
+                FutureWarning,
+            )
+        if reduction_factor != 2.0:
+            warnings.warn(
+                "The argument `reduction_factor` is deprecated and will be removed in version 4.30.0 of Transformers",
+                FutureWarning,
+            )
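As a sanity check on the derived sizes above, here is the arithmetic for what I assume are the extractor defaults (16 kHz audio, 32 ms windows, 16 ms hops):

```python
import numpy as np

sampling_rate, win_length, hop_length = 16000, 32, 16  # assumed defaults, window/hop in ms
sample_size = win_length * sampling_rate // 1000       # 512 samples per frame
sample_stride = hop_length * sampling_rate // 1000     # 256 samples per hop
n_fft = 2 ** int(np.ceil(np.log2(sample_size)))        # 512, already a power of two
n_freqs = (n_fft // 2) + 1                             # 257 magnitude bins per frame
print(sample_size, sample_stride, n_fft, n_freqs)      # 512 256 512 257
```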

    @staticmethod
    # Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm
    def zero_mean_unit_var_norm(
@@ -137,99 +161,45 @@ def zero_mean_unit_var_norm(
         return normed_input_values

-    @staticmethod
-    def _center_pad(one_waveform, n_fft, pad_mode):
-        padding = [(int(n_fft // 2), int(n_fft // 2))]
-        return np.pad(one_waveform, padding, mode=pad_mode)
-
-    @staticmethod
-    # Copied from transformers.models.mctct.feature_extraction_mctct.MCTCTFeatureExtractor._num_frames_calc
-    def _num_frames_calc(in_size, frame_size, frame_stride):
-        return int(1 + np.floor((in_size - frame_size) * 1 / frame_stride))
-
-    @staticmethod
-    # Copied from transformers.models.mctct.feature_extraction_mctct.MCTCTFeatureExtractor._frame_signal
-    def _frame_signal(one_waveform, n_frames, frame_signal_scale, window_length, sample_stride):
-        scale = frame_signal_scale
-        frames = np.zeros(n_frames * window_length)
-        for frame_idx in range(n_frames):
-            start = frame_idx * window_length
-            end = (frame_idx + 1) * window_length
-            wave_start = frame_idx * sample_stride
-            wave_end = frame_idx * sample_stride + window_length
-            frames[start:end] = scale * one_waveform[wave_start:wave_end]
-
-        return frames
-
-    @staticmethod
-    # Copied from transformers.models.mctct.feature_extraction_mctct.MCTCTFeatureExtractor._windowing
-    def _windowing(frames, window_length, window):
-        if frames.size % window_length != 0:
-            raise ValueError(
-                f"`frames` is supposed to have length divisible by `window_length`, but is {frames.size} with"
-                f" window_length={window_length}."
-            )
-
-        shaped = frames.reshape(-1, window_length)
-        shaped = window * shaped
-        return shaped
-
-    @staticmethod
-    # Copied from transformers.models.mctct.feature_extraction_mctct.MCTCTFeatureExtractor._dft
-    def _dft(frames, K, n_frames, n_samples, n_fft):
-        dft = np.zeros([n_frames, K])
-
-        for frame in range(n_frames):
-            begin = frame * n_samples
-
-            inwards_buffer = frames[begin : begin + n_samples]
-            inwards_buffer = np.pad(inwards_buffer, (0, n_fft - n_samples), "constant")
-            out = np.fft.rfft(inwards_buffer)
-
-            dft[frame] = np.abs(out[:K])
-
-        return dft
+    @staticmethod
+    def _stft(waveform: np.ndarray, fft_length: int, hop_length: int, window: np.ndarray) -> np.ndarray:
+        """
+        Calculates the magnitude spectrogram over one waveform array.
+        """
+        # center pad the waveform
+        padding = [(int(fft_length // 2), int(fft_length // 2))]
+        waveform = np.pad(waveform, padding, mode="reflect")
+        waveform_size = waveform.size
+
+        # promote to float64, since np.fft uses float64 internally
+        waveform = waveform.astype(np.float64)
+
+        num_frames = int(1 + np.floor((waveform_size - fft_length) / hop_length))
+        num_frequency_bins = (fft_length // 2) + 1
+        spectrogram = np.empty((num_frames, num_frequency_bins))
+
+        start = 0
+        for frame_idx in range(num_frames):
+            frame = waveform[start : start + fft_length] * window
+            spectrogram[frame_idx] = np.abs(np.fft.rfft(frame))
+            start += hop_length
+
+        return spectrogram
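The new `_stft` can be exercised in isolation. A standalone sketch that mirrors its framing, using `np.hanning` as a symmetric stand-in for the periodic torch window built in `__init__`:

```python
import numpy as np

def stft_magnitude(waveform, fft_length=512, hop_length=256):
    window = np.hanning(fft_length)  # stand-in for torch.hann_window(periodic=True)
    padding = [(fft_length // 2, fft_length // 2)]
    waveform = np.pad(waveform, padding, mode="reflect").astype(np.float64)
    num_frames = 1 + (len(waveform) - fft_length) // hop_length
    frames = np.stack(
        [waveform[i * hop_length : i * hop_length + fft_length] for i in range(num_frames)]
    )
    return np.abs(np.fft.rfft(frames * window, axis=-1))

spec = stft_magnitude(np.random.randn(16000))
print(spec.shape)  # (63, 257): 63 frames of 257 magnitude bins for 1 s of 16 kHz audio
```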

-    def _extract_fbank_features(
+    def _extract_mel_features(
         self,
         one_waveform: np.ndarray,
     ) -> np.ndarray:
         """
-        Extracts log-mel filterbank features for one waveform vector (unbatched). Adapted from Flashlight's C++ MFSC
-        code and librosa.
+        Extracts log-mel filterbank features for one waveform array (unbatched).
         """
-        one_waveform = self._center_pad(one_waveform, self.n_fft, "reflect")
-
-        n_frames = self._num_frames_calc(one_waveform.size, self.sample_size, self.sample_stride)
-
-        frames = self._frame_signal(
-            one_waveform, n_frames, self.frame_signal_scale, self.sample_size, self.sample_stride
-        )
-
-        window = getattr(torch, self.win_function)(window_length=self.sample_size, periodic=True)
-        window = window.numpy()
-
-        frames = self._windowing(frames, self.sample_size, window)
-
-        dft_out = self._dft(frames.flatten(), self.n_freqs, n_frames, self.sample_size, self.n_fft)
-
-        fbanks = torchaudio.functional.melscale_fbanks(
-            n_freqs=self.n_freqs,
-            f_min=self.fmin,
-            f_max=self.fmax,
-            n_mels=self.num_mel_bins,
-            sample_rate=self.sampling_rate,
-            norm="slaney",
-            mel_scale="slaney",
-        )
-        fbanks = fbanks.numpy()
-
-        return np.log10(np.maximum(self.mel_floor, np.dot(dft_out, fbanks)))
-
-    def _reduce(self, inputs):
-        reduced = []
-        for i in range(len(inputs)):
-            reduced.append(inputs[i][self.reduction_factor - 1 :: self.reduction_factor])
-        return reduced
+        if self.n_fft != self.sample_size:
+            raise NotImplementedError(
+                f"Currently the STFT frame size must be a power of two, but got {self.sample_size} for a window"
+                f" length of {self.win_length} and sampling rate of {self.sampling_rate}. Ensure"
+                f" `win_length * sampling_rate // 1000` is divisible by two."
+            )
+
+        stft_out = self._stft(one_waveform, self.n_fft, self.sample_stride, self.window)
+
+        return np.log10(np.maximum(self.mel_floor, np.dot(stft_out, self.mel_filters)))
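The mel projection itself is now just a matrix product against the filter bank precomputed in `__init__`. A sketch with placeholder data (shapes assume the 80-bin default):

```python
import numpy as np

n_freqs, num_mel_bins, mel_floor = 257, 80, 1e-10
stft_out = np.abs(np.random.randn(63, n_freqs))      # magnitude spectrogram, (frames, bins)
mel_filters = np.random.rand(n_freqs, num_mel_bins)  # placeholder for get_mel_filter_banks(...)
log_mel = np.log10(np.maximum(mel_floor, np.dot(stft_out, mel_filters)))
print(log_mel.shape)  # (63, 80): one 80-dimensional log-mel vector per frame
```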

    def __call__(
        self,
@@ -341,7 +311,6 @@ def __call__(
             return inputs_target
         else:
             inputs["labels"] = inputs_target["input_values"]
-            inputs["stop_labels"] = inputs_target["stop_labels"]
             decoder_attention_mask = inputs_target.get("attention_mask")
             if decoder_attention_mask is not None:
                 inputs["decoder_attention_mask"] = decoder_attention_mask
@@ -381,8 +350,7 @@ def _process_audio(

         # convert into correct format for padding
         if is_target:
-            features = [self._extract_fbank_features(waveform) for waveform in speech]
-            fbank_sizes = [len(x) for x in features]
+            features = [self._extract_mel_features(waveform) for waveform in speech]
             encoded_inputs = BatchFeature({"input_values": features})
             self.feature_size = self.num_mel_bins
         else:
@@ -429,22 +397,18 @@ def _process_audio(
                 padded_inputs["input_values"], attention_mask=attention_mask, padding_value=self.padding_value
             )

-        if is_target:
-            # make labels for stop prediction
-            stop_labels = []
-            for i, l in enumerate(fbank_sizes):
-                labels = np.zeros(len(padded_inputs["input_values"][i]))
-                labels[l - 1 :] = 1.0
-                stop_labels.append(labels)
-            padded_inputs["stop_labels"] = stop_labels
-
-            # thin out frames for reduction factor
-            if self.reduction_factor > 1:
-                padded_inputs["input_values"] = self._reduce(padded_inputs["input_values"])
-                if attention_mask is not None:
-                    padded_inputs["attention_mask"] = self._reduce(padded_inputs["attention_mask"])
-
         if return_tensors is not None:
             padded_inputs = padded_inputs.convert_to_tensors(return_tensors)

         return padded_inputs
+
+    def to_dict(self) -> Dict[str, Any]:
+        output = super().to_dict()
+
+        # Don't serialize these as they are derived from the other properties.
+        names = ["window", "mel_filters", "sample_size", "sample_stride", "n_fft", "n_freqs"]
+        for name in names:
+            if name in output:
+                del output[name]
+
+        return output
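Since `window` and `mel_filters` are rebuilt in `__init__`, dropping them in `to_dict` keeps the serialized `preprocessor_config.json` free of large derived arrays. A round-trip sketch (assumes a writable temporary directory):

```python
from transformers import SpeechT5FeatureExtractor

extractor = SpeechT5FeatureExtractor()
config = extractor.to_dict()
assert "window" not in config and "mel_filters" not in config  # derived, not serialized

extractor.save_pretrained("/tmp/speecht5_fe")  # writes preprocessor_config.json
reloaded = SpeechT5FeatureExtractor.from_pretrained("/tmp/speecht5_fe")
assert reloaded.mel_filters.shape == extractor.mel_filters.shape  # rebuilt on load
```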
[Diffs for the remaining changed files, including src/transformers/models/speecht5/modeling_speecht5.py, are not shown.]