TTS fine-tuning for SpeechT5 #21824

Merged · 38 commits · Apr 18, 2023

Commits (changes from all commits)
5b15e26  wrong argument name (hollance, Feb 22, 2023)
6452ec3  append eos_token_id (hollance, Feb 23, 2023)
2cc9c89  all tokenizers need mask and ctc_blank tokens (hollance, Feb 23, 2023)
9bf24d1  remove reduction factor from feature extractor (hollance, Feb 27, 2023)
651be98  add proper TTS loss (hollance, Feb 28, 2023)
6e3c3f2  did shifting the wrong way around (hollance, Mar 1, 2023)
4acaf1b  mask out padded portions (hollance, Mar 1, 2023)
75b13fd  remove logits again (don't really need it) (hollance, Mar 1, 2023)
9473f93  fix unit tests (hollance, Mar 1, 2023)
3c0b202  fixup (hollance, Mar 1, 2023)
aee0272  pad also returns the decoder attention mask, since that's useful to have (hollance, Mar 2, 2023)
193e6d1  clean up feature extractor logic (hollance, Mar 2, 2023)
2a81ec4  pad can handle TTS task too (hollance, Mar 2, 2023)
218830c  remove stop_labels from loss calculation (hollance, Mar 2, 2023)
a05b2ea  simplify logic (hollance, Mar 2, 2023)
7d803ae  fixup (hollance, Mar 2, 2023)
1be2f8b  do -100 masking properly (hollance, Mar 6, 2023)
40e7057  small STFT optimization (calculate mel filterbanks only once) (hollance, Mar 6, 2023)
92f79a6  replace torchaudio fbanks with audio_utils (hollance, Mar 6, 2023)
f4edf81  remove torchaudio dependency (hollance, Mar 6, 2023)
7c2a44c  simplify & speed up the STFT (hollance, Mar 6, 2023)
0dea7f5  don't serialize window and mel filters (hollance, Mar 7, 2023)
81668a2  output cross attentions when generating speech (hollance, Mar 9, 2023)
ff83735  add guided attention loss (hollance, Mar 13, 2023)
9a02068  fix failing test (hollance, Mar 16, 2023)
cf67785  Update src/transformers/models/speecht5/feature_extraction_speecht5.py (hollance, Mar 23, 2023)
0ac8589  Update src/transformers/models/speecht5/modeling_speecht5.py (hollance, Mar 23, 2023)
03bc390  change type annotation of attention_mask to LongTensor (hollance, Mar 23, 2023)
0b2785e  extract loss into class (hollance, Mar 23, 2023)
b0558ac  remove unused frame_signal_scale argument (hollance, Apr 3, 2023)
94b61bc  use config object in loss class (hollance, Apr 11, 2023)
5d2f774  fix type annotations in doc comments (hollance, Apr 12, 2023)
39f0f8b  change optional to just bool (hollance, Apr 12, 2023)
ab501a4  implement missing tokenizer method (hollance, Apr 12, 2023)
630f97a  add deprecation warning (hollance, Apr 13, 2023)
614f7aa  Update src/transformers/models/speecht5/feature_extraction_speecht5.py (hollance, Apr 13, 2023)
a74e414  Update src/transformers/models/speecht5/feature_extraction_speecht5.py (hollance, Apr 13, 2023)
671c44b  add deprecation warning for stop_labels (hollance, Apr 13, 2023)
21 changes: 20 additions & 1 deletion src/transformers/models/speecht5/configuration_speecht5.py
@@ -161,13 +161,22 @@ class SpeechT5Config(PretrainedConfig):
speech_decoder_postnet_dropout (`float`, *optional*, defaults to 0.5):
The dropout probability for the speech decoder post-net layers.
reduction_factor (`int`, *optional*, defaults to 2):
-    Spectrogram length reduction factor for the speech decoder post-net.
+    Spectrogram length reduction factor for the speech decoder inputs.
max_speech_positions (`int`, *optional*, defaults to 4000):
The maximum sequence length of speech features that this model might ever be used with.
max_text_positions (`int`, *optional*, defaults to 450):
The maximum sequence length of text features that this model might ever be used with.
encoder_max_relative_position (`int`, *optional*, defaults to 160):
Maximum distance for relative position embedding in the encoder.
use_guided_attention_loss (`bool`, *optional*, defaults to `True`):
Whether to apply guided attention loss while training the TTS model.
guided_attention_loss_num_heads (`int`, *optional*, defaults to 2):
Number of attention heads the guided attention loss will be applied to. Use -1 to apply this loss to all
attention heads.
guided_attention_loss_sigma (`float`, *optional*, defaults to 0.4):
Standard deviation for guided attention loss.
guided_attention_loss_scale (`float`, *optional*, defaults to 10.0):
Scaling coefficient for guided attention loss (also known as lambda).
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models).

@@ -241,6 +250,10 @@ def __init__(
max_speech_positions=4000,
max_text_positions=450,
encoder_max_relative_position=160,
use_guided_attention_loss=True,
guided_attention_loss_num_heads=2,
guided_attention_loss_sigma=0.4,
guided_attention_loss_scale=10.0,
use_cache=True,
is_encoder_decoder=True,
**kwargs,
@@ -311,6 +324,12 @@ def __init__(
self.max_speech_positions = max_speech_positions
self.max_text_positions = max_text_positions
self.encoder_max_relative_position = encoder_max_relative_position

self.use_guided_attention_loss = use_guided_attention_loss
self.guided_attention_loss_num_heads = guided_attention_loss_num_heads
self.guided_attention_loss_sigma = guided_attention_loss_sigma
self.guided_attention_loss_scale = guided_attention_loss_scale

self.use_cache = use_cache
self.is_encoder_decoder = is_encoder_decoder
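The four guided-attention options above deserve a brief note. Guided attention loss (Tachibana et al., 2017) encourages the decoder's cross-attention to follow the roughly diagonal alignment between input text and output speech. Below is a minimal sketch of the idea in Python; it is an illustration only, not the implementation this PR adds to modeling_speecht5.py (batching of variable lengths, padding masks, and head selection are omitted):

import torch

def guided_attention_loss(attentions, sigma=0.4, scale=10.0):
    # attentions: (batch, heads, output_length, input_length) cross-attention weights
    out_len, in_len = attentions.shape[-2:]
    t = torch.arange(out_len, dtype=torch.float32) / out_len  # normalized decoder positions
    n = torch.arange(in_len, dtype=torch.float32) / in_len    # normalized encoder positions
    # soft diagonal mask: zero on the diagonal, approaching one far off the diagonal
    weights = 1.0 - torch.exp(-((n[None, :] - t[:, None]) ** 2) / (2.0 * sigma**2))
    return scale * (attentions * weights).mean()

Here `sigma` plays the role of `guided_attention_loss_sigma`, and `scale` the role of `guided_attention_loss_scale` (the lambda mentioned in the docstring).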

@@ -351,12 +351,11 @@ def convert_speecht5_checkpoint(
if vocab_path:
tokenizer = SpeechT5Tokenizer(vocab_path, model_max_length=config.max_text_positions)

if task == "pretrain":
# Mask token behaves like a normal word, i.e. include the space before it
mask_token = AddedToken("<mask>", lstrip=True, rstrip=False)
tokenizer.mask_token = mask_token
tokenizer.add_special_tokens({"mask_token": mask_token})
tokenizer.add_tokens(["<ctc_blank>"])
# Mask token behaves like a normal word, i.e. include the space before it
mask_token = AddedToken("<mask>", lstrip=True, rstrip=False)
sanchit-gandhi marked this conversation as resolved.
Show resolved Hide resolved
tokenizer.mask_token = mask_token
tokenizer.add_special_tokens({"mask_token": mask_token})
tokenizer.add_tokens(["<ctc_blank>"])

feature_extractor = SpeechT5FeatureExtractor()
processor = SpeechT5Processor(tokenizer=tokenizer, feature_extractor=feature_extractor)
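Since the mask token is created with `lstrip=True`, it absorbs the whitespace before it, so `"speech <mask>"` and `"speech<mask>"` should tokenize identically. A hypothetical quick check (the checkpoint name is an assumption, not taken from this diff):

from transformers import SpeechT5Tokenizer

tokenizer = SpeechT5Tokenizer.from_pretrained("microsoft/speecht5_tts")  # assumed checkpoint name
assert tokenizer.tokenize("speech <mask>") == tokenizer.tokenize("speech<mask>")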
172 changes: 68 additions & 104 deletions src/transformers/models/speecht5/feature_extraction_speecht5.py
@@ -14,12 +14,13 @@
# limitations under the License.
"""Feature extractor class for SpeechT5."""

-from typing import List, Optional, Union
+import warnings
+from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
-import torchaudio

+from ...audio_utils import get_mel_filter_banks
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...utils import PaddingStrategy, TensorType, logging
@@ -60,15 +61,15 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
win_function (`str`, *optional*, defaults to `"hann_window"`):
Name for the window function used for windowing, must be accessible via `torch.{win_function}`
frame_signal_scale (`float`, *optional*, defaults to 1.0):
-    Constant multiplied in creating the frames before applying DFT.
+    Constant multiplied in creating the frames before applying DFT. This argument is deprecated.
fmin (`float`, *optional*, defaults to 80):
Minimum mel frequency in Hz.
fmax (`float`, *optional*, defaults to 7600):
Maximum mel frequency in Hz.
mel_floor (`float`, *optional*, defaults to 1e-10):
Minimum value of mel frequency banks.
reduction_factor (`int`, *optional*, defaults to 2):
-    Spectrogram length reduction factor.
+    Spectrogram length reduction factor. This argument is deprecated.
return_attention_mask (`bool`, *optional*, defaults to `True`):
Whether or not [`~SpeechT5FeatureExtractor.__call__`] should return `attention_mask`.
"""
@@ -109,10 +110,33 @@ def __init__(

self.sample_size = win_length * sampling_rate // 1000
self.sample_stride = hop_length * sampling_rate // 1000

self.n_fft = 2 ** int(np.ceil(np.log2(self.sample_size)))
self.n_freqs = (self.n_fft // 2) + 1

window = getattr(torch, self.win_function)(window_length=self.sample_size, periodic=True)
self.window = window.numpy().astype(np.float64)

self.mel_filters = get_mel_filter_banks(
nb_frequency_bins=self.n_freqs,
nb_mel_filters=self.num_mel_bins,
frequency_min=self.fmin,
frequency_max=self.fmax,
sample_rate=self.sampling_rate,
norm="slaney",
mel_scale="slaney",
)
Review comment (Collaborator) on lines +119 to +127: 🔥 Kudos for using the audio utils! Simplifies a lot


if frame_signal_scale != 1.0:
warnings.warn(
"The argument `frame_signal_scale` is deprecated and will be removed in version 4.30.0 of Transformers",
FutureWarning,
)
if reduction_factor != 2.0:
warnings.warn(
"The argument `reduction_factor` is deprecated and will be removed in version 4.30.0 of Transformers",
FutureWarning,
)
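For intuition, here is how the derived quantities above work out with typical values (the window and hop sizes are illustrative assumptions, not taken from this diff):

import numpy as np

sampling_rate, win_length, hop_length = 16000, 64, 16  # Hz, ms, ms (assumed values)

sample_size = win_length * sampling_rate // 1000    # 1024 samples per analysis window
sample_stride = hop_length * sampling_rate // 1000  # 256 samples between frame starts
n_fft = 2 ** int(np.ceil(np.log2(sample_size)))     # 1024, already a power of two
n_freqs = (n_fft // 2) + 1                          # 513 bins returned by np.fft.rfft

With a 64 ms window at 16 kHz the frame size is already a power of two, so the `n_fft != sample_size` check in `_extract_mel_features` below passes.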

@staticmethod
# Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm
def zero_mean_unit_var_norm(
@@ -137,99 +161,45 @@ def zero_mean_unit_var_norm(
return normed_input_values

-    @staticmethod
-    def _center_pad(one_waveform, n_fft, pad_mode):
-        padding = [(int(n_fft // 2), int(n_fft // 2))]
-        return np.pad(one_waveform, padding, mode=pad_mode)
-
-    @staticmethod
-    # Copied from transformers.models.mctct.feature_extraction_mctct.MCTCTFeatureExtractor._num_frames_calc
-    def _num_frames_calc(in_size, frame_size, frame_stride):
-        return int(1 + np.floor((in_size - frame_size) * 1 / frame_stride))
-
-    @staticmethod
-    # Copied from transformers.models.mctct.feature_extraction_mctct.MCTCTFeatureExtractor._frame_signal
-    def _frame_signal(one_waveform, n_frames, frame_signal_scale, window_length, sample_stride):
-        scale = frame_signal_scale
-        frames = np.zeros(n_frames * window_length)
-        for frame_idx in range(n_frames):
-            start = frame_idx * window_length
-            end = (frame_idx + 1) * window_length
-            wave_start = frame_idx * sample_stride
-            wave_end = frame_idx * sample_stride + window_length
-            frames[start:end] = scale * one_waveform[wave_start:wave_end]
-
-        return frames
-
-    @staticmethod
-    # Copied from transformers.models.mctct.feature_extraction_mctct.MCTCTFeatureExtractor._windowing
-    def _windowing(frames, window_length, window):
-        if frames.size % window_length != 0:
-            raise ValueError(
-                f"`frames` is supposed to have length divisble by `window_length`, but is {frames.size} with"
-                f" window_length={window_length}."
-            )
-
-        shaped = frames.reshape(-1, window_length)
-        shaped = window * shaped
-        return shaped
-
-    @staticmethod
-    # Copied from transformers.models.mctct.feature_extraction_mctct.MCTCTFeatureExtractor._dft
-    def _dft(frames, K, n_frames, n_samples, n_fft):
-        dft = np.zeros([n_frames, K])
-
-        for frame in range(n_frames):
-            begin = frame * n_samples
-
-            inwards_buffer = frames[begin : begin + n_samples]
-            inwards_buffer = np.pad(inwards_buffer, (0, n_fft - n_samples), "constant")
-            out = np.fft.rfft(inwards_buffer)
-
-            dft[frame] = np.abs(out[:K])
-
-        return dft
+    @staticmethod
+    def _stft(waveform: np.ndarray, fft_length: int, hop_length: int, window: np.ndarray) -> np.ndarray:
+        """
+        Calculates the magnitude spectrogram over one waveform array.
+        """
+        # center pad the waveform
+        padding = [(int(fft_length // 2), int(fft_length // 2))]
+        waveform = np.pad(waveform, padding, mode="reflect")
+        waveform_size = waveform.size
+
+        # promote to float64, since np.fft uses float64 internally
+        waveform = waveform.astype(np.float64)
+
+        num_frames = int(1 + np.floor((waveform_size - fft_length) / hop_length))
+        num_frequency_bins = (fft_length // 2) + 1
+        spectrogram = np.empty((num_frames, num_frequency_bins))
+
+        start = 0
+        for frame_idx in range(num_frames):
+            frame = waveform[start : start + fft_length] * window
+            spectrogram[frame_idx] = np.abs(np.fft.rfft(frame))
+            start += hop_length
+
+        return spectrogram

Review comment (Contributor): Long term, would it make sense for an stft function to go in audio utils?
Reply (Contributor Author): Yes absolutely. And that would also remove the "must be a power of two" limitation.
Reply (Collaborator): Also we should be able to batch the stft (long term goal)
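As a quick sanity check of the framing arithmetic in `_stft`, a standalone sketch with illustrative values (`np.hanning` stands in for the periodic torch window used by the feature extractor):

import numpy as np

fft_length, hop = 1024, 256
window = np.hanning(fft_length)                       # stand-in for torch.hann_window(periodic=True)
waveform = np.random.randn(16000).astype(np.float64)  # one second of noise at 16 kHz

padded = np.pad(waveform, (fft_length // 2, fft_length // 2), mode="reflect")
num_frames = int(1 + np.floor((padded.size - fft_length) / hop))
frames = np.stack([padded[i * hop : i * hop + fft_length] * window for i in range(num_frames)])
magnitudes = np.abs(np.fft.rfft(frames, axis=-1))
print(magnitudes.shape)  # (63, 513): num_frames x (fft_length // 2 + 1)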

-    def _extract_fbank_features(
-        self,
-        one_waveform: np.ndarray,
-    ) -> np.ndarray:
-        """
-        Extracts log-mel filterbank features for one waveform vector (unbatched). Adapted from Flashlight's C++ MFSC
-        code and librosa.
-        """
-        one_waveform = self._center_pad(one_waveform, self.n_fft, "reflect")
-
-        n_frames = self._num_frames_calc(one_waveform.size, self.sample_size, self.sample_stride)
-
-        frames = self._frame_signal(
-            one_waveform, n_frames, self.frame_signal_scale, self.sample_size, self.sample_stride
-        )
-
-        window = getattr(torch, self.win_function)(window_length=self.sample_size, periodic=True)
-        window = window.numpy()
-
-        frames = self._windowing(frames, self.sample_size, window)
-
-        dft_out = self._dft(frames.flatten(), self.n_freqs, n_frames, self.sample_size, self.n_fft)
-
-        fbanks = torchaudio.functional.melscale_fbanks(
-            n_freqs=self.n_freqs,
-            f_min=self.fmin,
-            f_max=self.fmax,
-            n_mels=self.num_mel_bins,
-            sample_rate=self.sampling_rate,
-            norm="slaney",
-            mel_scale="slaney",
-        )
-        fbanks = fbanks.numpy()
-
-        return np.log10(np.maximum(self.mel_floor, np.dot(dft_out, fbanks)))
-
-    def _reduce(self, inputs):
-        reduced = []
-        for i in range(len(inputs)):
-            reduced.append(inputs[i][self.reduction_factor - 1 :: self.reduction_factor])
-        return reduced
+    def _extract_mel_features(
+        self,
+        one_waveform: np.ndarray,
+    ) -> np.ndarray:
+        """
+        Extracts log-mel filterbank features for one waveform array (unbatched).
+        """
+        if self.n_fft != self.sample_size:
+            raise NotImplementedError(
+                f"Currently the STFT frame size must be a power of two, but got {self.sample_size} for a window"
+                f" length of {self.win_length} and sampling rate of {self.sampling_rate}. Ensure"
+                " `win_length * sampling_rate // 1000` is a power of two."
+            )
+
+        stft_out = self._stft(one_waveform, self.n_fft, self.sample_stride, self.window)
+
+        return np.log10(np.maximum(self.mel_floor, np.dot(stft_out, self.mel_filters)))
Review comment (Contributor): Nice!


def __call__(
self,
@@ -341,7 +311,6 @@ def __call__(
return inputs_target
else:
inputs["labels"] = inputs_target["input_values"]
inputs["stop_labels"] = inputs_target["stop_labels"]
decoder_attention_mask = inputs_target.get("attention_mask")
if decoder_attention_mask is not None:
inputs["decoder_attention_mask"] = decoder_attention_mask
@@ -381,8 +350,7 @@ def _process_audio(

# convert into correct format for padding
if is_target:
-            features = [self._extract_fbank_features(waveform) for waveform in speech]
-            fbank_sizes = [len(x) for x in features]
+            features = [self._extract_mel_features(waveform) for waveform in speech]
encoded_inputs = BatchFeature({"input_values": features})
self.feature_size = self.num_mel_bins
else:
@@ -429,22 +397,18 @@ def _process_audio(
padded_inputs["input_values"], attention_mask=attention_mask, padding_value=self.padding_value
)

-        if is_target:
-            # make labels for stop prediction
-            stop_labels = []
-            for i, l in enumerate(fbank_sizes):
-                labels = np.zeros(len(padded_inputs["input_values"][i]))
-                labels[l - 1 :] = 1.0
-                stop_labels.append(labels)
-            padded_inputs["stop_labels"] = stop_labels
-
-            # thin out frames for reduction factor
-            if self.reduction_factor > 1:
-                padded_inputs["input_values"] = self._reduce(padded_inputs["input_values"])
-                if attention_mask is not None:
-                    padded_inputs["attention_mask"] = self._reduce(padded_inputs["attention_mask"])

Review comment (Collaborator): Am I correct in understanding that this reduction now takes place in shift_spectrograms_right in the modelling file?
Reply (Contributor Author): Correct, I mistakenly thought it applied to the labels but it applies to the input that the decoder sees.

if return_tensors is not None:
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)

return padded_inputs
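As the review exchange above notes, the frame reduction now happens on the decoder inputs in the modeling code. Below is a minimal sketch in the spirit of shift_spectrograms_right, assuming (batch, frames, mel_bins) tensors; it paraphrases rather than quotes the modeling file:

import torch

def shift_spectrograms_right(input_values: torch.Tensor, reduction_factor: int = 2) -> torch.Tensor:
    # Keep every `reduction_factor`-th frame: the decoder predicts that many frames per step.
    if reduction_factor > 1:
        input_values = input_values[:, reduction_factor - 1 :: reduction_factor]
    # Shift right so the decoder sees the previous target frame at each step,
    # with an all-zero frame as the start-of-sequence input.
    shifted = input_values.new_zeros(input_values.shape)
    shifted[:, 1:] = input_values[:, :-1].clone()
    return shifted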

def to_dict(self) -> Dict[str, Any]:
output = super().to_dict()

# Don't serialize these as they are derived from the other properties.
names = ["window", "mel_filters", "sample_size", "sample_stride", "n_fft", "n_freqs"]
for name in names:
if name in output:
del output[name]

return output
Review comment (Collaborator) on lines +405 to +414: Nice 😉

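Because `window`, `mel_filters`, and the other listed attributes are recomputed in `__init__` from the serialized constructor arguments, dropping them from `to_dict` keeps the saved preprocessor_config.json compact without losing any information on reload.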