-
Notifications
You must be signed in to change notification settings - Fork 25.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
TTS fine-tuning for SpeechT5 #21824
TTS fine-tuning for SpeechT5 #21824
Conversation
The documentation is not available anymore as the PR was closed or merged. |
2227cd5
to
52d0c2f
Compare
32e0d6d
to
d66db17
Compare
d66db17
to
f6a626a
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice PR @hollance! The custom NumPy STFT implementation in the feature extractor looks great - what kind of speed-up do you get with this STFT improvement?
The BCE + Guided Attention Loss is clean 👍 Are we good setting the loss to an unweighted average of the three loss terms?
Otherwise think the PR is good to go!
src/transformers/models/speecht5/convert_speecht5_original_pytorch_checkpoint_to_pytorch.py
Show resolved
Hide resolved
src/transformers/models/speecht5/feature_extraction_speecht5.py
Outdated
Show resolved
Hide resolved
# Copied from transformers.models.mctct.feature_extraction_mctct.MCTCTFeatureExtractor._dft | ||
def _dft(frames, K, n_frames, n_samples, n_fft): | ||
dft = np.zeros([n_frames, K]) | ||
def _stft(waveform: np.ndarray, fft_length: int, hop_length: int, window: np.ndarray) -> np.ndarray: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Long term, would it make sense for an stft function to go in audio utils?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes absolutely. And that would also remove the "must be a power of two" limitation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also we should be able to batch the stft (long term goal)
for i in range(len(inputs)): | ||
reduced.append(inputs[i][self.reduction_factor - 1 :: self.reduction_factor]) | ||
return reduced | ||
return np.log10(np.maximum(self.mel_floor, np.dot(stft_out, self.mel_filters))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
bce_criterion = BCEWithLogitsLoss(pos_weight=torch.tensor(5.0)) | ||
bce_loss = bce_criterion(logits, stop_labels) | ||
|
||
loss = l1_loss + bce_loss |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are all the loss terms always weighted equally?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a weighting term but it's always 1 in the original code, so I didn't bother including it. So yes in practice, the loss terms (including guided attention) are weighted equally.
@@ -2628,6 +2713,54 @@ def forward( | |||
encoder_attentions=outputs.encoder_attentions, | |||
) | |||
|
|||
def _compute_loss( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wondering whether it makes sense to register a new module for the loss (since we init 3 different loss modules in this _compute_loss
method)?
We can register the three losses in the init, and call _compute_loss
in the forward
. Something along the lines of:
class SpeechT5SpectrogramLoss(nn.Module):
def __int__(self):
super.__init__()
self.bce_criterion =
...
def forward(self, attention_mask, outputs_before_postnet, ...):
# The inner workings of _compute_loss goes here
And then we just call this module to compute the loss:
if labels is not None:
loss = self.loss_module(...)
@@ -2637,7 +2770,8 @@ def generate_speech( | |||
minlenratio: float = 0.0, | |||
maxlenratio: float = 20.0, | |||
vocoder: Optional[nn.Module] = None, | |||
) -> torch.FloatTensor: | |||
output_cross_attentions: Optional[bool] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The cross attentions are useful for viewing the text-speech alignment?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes exactly, that's why I added them.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Think we can add output_cross_attention
to the config, like we do for use_cache
or output_attention
targets = self.tokenizer.pad(labels, **kwargs) | ||
labels = targets["input_ids"] | ||
else: | ||
feature_size_hack = self.feature_extractor.feature_size |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fine for me since this is the same 'hack' we employ in the feature extractor:
transformers/src/transformers/models/speecht5/feature_extraction_speecht5.py
Lines 379 to 380 in ff20f9c
# needed to make pad() work on spectrogram inputs | |
feature_size_hack = self.feature_size |
-3.2397, -3.2053, -2.9151, -2.7921, -2.9403, -2.7411, -3.0654, -2.8314, | ||
-3.0026, -2.9797, -3.1314, -2.9939, -2.6748, -2.7725, -2.8563, -2.9462, | ||
-3.2623, -3.3044, -3.1318, -3.2672, -3.4030, -3.1988] | ||
[-2.6870, -3.0104, -3.1356, -3.5352, -3.0044, -3.0353, -3.4719, -3.6777, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like the values have changed quite a bit going from torchaudio
-> custom numpy
no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's not the reason for the change. ;-) SpeechT5 uses something called a "reduction factor", which is 2. I misunderstood this to mean that the target lengths would be reduced by 2x, which happened in the feature extractor. That was wrong: the targets keep their original size, but the input to the decoder is reduced by 2x. So previously the feature extractor was doing the wrong thing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see! Great we've fixed it now!
Requesting review from @ArthurZucker for the custom STFT / log-Mel feature extraction components ( |
a1d0701
to
823df93
Compare
Gently pinging @ArthurZucker :) |
Will review in 1h! Sorry for the delay |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool work ! 🤗 REALLY like the torch audio dependency being removed!
Left a few nits here and there 😉
@@ -175,7 +190,18 @@ def test_tokenizer_integration(self): | |||
] | |||
|
|||
# fmt: off | |||
expected_encoding = {'input_ids': [[4, 32, 13, 7, 9, 12, 19, 8, 13, 18, 5, 13, 12, 4, 64, 19, 8, 13, 18, 5, 13, 15, 22, 4, 28, 9, 8, 20, 9, 4, 7, 12, 4, 24, 22, 6, 8, 13, 17, 11, 39, 6, 13, 7, 9, 12, 19, 8, 13, 18, 5, 13, 12, 4, 7, 9, 14, 4, 24, 22, 6, 8, 13, 17, 11, 39, 24, 13, 5, 6, 13, 7, 10, 9, 5, 14, 39, 25, 5, 13, 6, 63, 4, 24, 13, 8, 27, 10, 14, 5, 12, 4, 21, 5, 9, 5, 13, 7, 15, 39, 24, 16, 13, 24, 8, 12, 5, 4, 7, 13, 17, 11, 10, 6, 5, 17, 6, 16, 13, 5, 12, 4, 64, 40, 47, 54, 32, 23, 4, 53, 49, 32, 23, 4, 54, 8, 40, 47, 54, 32, 7, 23, 4, 69, 52, 43, 23, 4, 51, 10, 12, 6, 10, 15, 40, 5, 13, 6, 23, 4, 69, 52, 48, 5, 6, 26, 26, 26, 63, 4, 19, 8, 13, 4, 48, 7, 6, 16, 13, 7, 15, 4, 52, 7, 9, 21, 16, 7, 21, 5, 4, 61, 9, 14, 5, 13, 12, 6, 7, 9, 14, 10, 9, 21, 4, 64, 48, 52, 61, 63, 4, 7, 9, 14, 4, 48, 7, 6, 16, 13, 7, 15, 4, 52, 7, 9, 21, 16, 7, 21, 5, 4, 53, 5, 9, 5, 13, 7, 6, 10, 8, 9, 4, 64, 48, 52, 53, 63, 4, 20, 10, 6, 11, 4, 8, 27, 5, 13, 4, 6, 11, 10, 13, 6, 22, 39, 6, 20, 8, 4, 24, 13, 5, 6, 13, 7, 10, 9, 5, 14, 4, 18, 8, 14, 5, 15, 12, 4, 10, 9, 4, 8, 9, 5, 4, 11, 16, 9, 14, 13, 5, 14, 4, 24, 15, 16, 12, 4, 15, 7, 9, 21, 16, 7, 21, 5, 12, 4, 7, 9, 14, 4, 14, 5, 5, 24, 4, 10, 9, 6, 5, 13, 8, 24, 5, 13, 7, 25, 10, 15, 10, 6, 22, 4, 25, 5, 6, 20, 5, 5, 9, 4, 58, 7, 37, 23, 4, 49, 22, 32, 8, 13, 17, 11, 4, 7, 9, 14, 4, 32, 5, 9, 12, 8, 13, 55, 15, 8, 20, 26], [4, 40, 47, 54, 32, 4, 10, 12, 4, 14, 5, 12, 10, 21, 9, 5, 14, 4, 6, 8, 4, 24, 13, 5, 39, 6, 13, 7, 10, 9, 4, 14, 5, 5, 24, 4, 25, 10, 14, 10, 13, 5, 17, 6, 10, 8, 9, 7, 15, 4, 13, 5, 24, 13, 5, 12, 5, 9, 6, 7, 6, 10, 8, 9, 12, 4, 19, 13, 8, 18, 4, 16, 9, 15, 7, 25, 5, 15, 5, 14, 4, 6, 5, 37, 6, 4, 25, 22, 4, 46, 8, 10, 9, 6, 15, 22, 4, 17, 8, 9, 14, 10, 6, 10, 8, 9, 10, 9, 21, 4, 8, 9, 4, 25, 8, 6, 11, 4, 15, 5, 19, 6, 4, 7, 9, 14, 4, 13, 10, 21, 11, 6, 4, 17, 8, 9, 6, 5, 37, 6, 4, 10, 9, 4, 7, 15, 15, 4, 15, 7, 22, 5, 13, 12, 26, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [4, 32, 11, 5, 4, 45, 16, 10, 17, 28, 4, 25, 13, 8, 20, 9, 4, 19, 8, 37, 4, 46, 16, 18, 24, 12, 4, 8, 27, 5, 13, 4, 6, 11, 5, 4, 15, 7, 57, 22, 4, 14, 8, 21, 26, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]} | |||
expected_encoding = { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can probably fit all of this in a single line since no one is going to look at it 😉
self.mel_filters = get_mel_filter_banks( | ||
nb_frequency_bins=self.n_freqs, | ||
nb_mel_filters=self.num_mel_bins, | ||
frequency_min=self.fmin, | ||
frequency_max=self.fmax, | ||
sample_rate=self.sampling_rate, | ||
norm="slaney", | ||
mel_scale="slaney", | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🔥 Kudos for using the audio utils! Simplifies a lot
# Copied from transformers.models.mctct.feature_extraction_mctct.MCTCTFeatureExtractor._dft | ||
def _dft(frames, K, n_frames, n_samples, n_fft): | ||
dft = np.zeros([n_frames, K]) | ||
def _stft(waveform: np.ndarray, fft_length: int, hop_length: int, window: np.ndarray) -> np.ndarray: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also we should be able to batch the stft (long term goal)
def to_dict(self) -> Dict[str, Any]: | ||
output = super().to_dict() | ||
|
||
# Don't serialize these as they are derived from the other properties. | ||
names = ["window", "mel_filters", "sample_size", "sample_stride", "n_fft", "n_freqs"] | ||
for name in names: | ||
if name in output: | ||
del output[name] | ||
|
||
return output |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice 😉
if token_ids_1 is None: | ||
return token_ids_0 + [self.eos_token_id] | ||
# We don't expect to process pairs, but leave the pair logic for API consistency | ||
return token_ids_0 + token_ids_1 + [self.eos_token_id] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return token_ids_0 + token_ids_1 + [self.eos_token_id] | |
return token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id] |
The eos should be added in between no? (Not sure !)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I copied this from elsewhere and everyone does it this way. 🤷♂️
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Haha no, some models don't always add the eos so they have a flags, but most of the models also copied from somewhere. Well, I doubt this function will be used (should be used only for sequence classification)
@@ -2247,9 +2374,9 @@ def set_output_embeddings(self, new_embeddings): | |||
def forward( | |||
self, | |||
input_values: Optional[torch.Tensor] = None, | |||
attention_mask: Optional[torch.Tensor] = None, | |||
attention_mask: Optional[torch.LongTensor] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's not necessary (if it is not long, it is be casted I think)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(small nit but valide for these changes to attention mask types_
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Surely the point of type annotations is to be as specific as possible? ;-)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The goal is not to be exact (any kind of tensor is accepted) but to be good documentation, so in this case, I agree with @hollance
outputs = spectrogram | ||
|
||
if output_cross_attentions: | ||
cross_attentions = torch.cat(cross_attentions, dim=2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am pretty sure we usually return the cross attention as a list, would be good to keep that expected behaviour (unless it is specific / required by the model)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not exactly the same thing: normally this returns a list of (batch, heads, out len, in len)
tensors, with one tensor per layer. But here, it returns one tensor of shape (layers, heads, out len, in len)
. There is no batch dimension.
We could change it to a list of (1, heads, out len, in len)
tensors to be consistent with how it normally happens, I suppose. But currently generate_speech()
does not handle batches anyway.
|
||
loss = None | ||
if labels is not None: | ||
criterion = SpeechT5SpectrogramLoss(self.config) | ||
loss = criterion( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
regarding my comment I gues it would mean concatenating the cross attentions in the criterion
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Those cross attentions aren't used in the loss function. They're only provided to let the user visualize how well the input sequence maps to the output sequence (if the model works well we'd expect to see a diagonal line in the cross attentions).
@@ -2637,7 +2770,8 @@ def generate_speech( | |||
minlenratio: float = 0.0, | |||
maxlenratio: float = 20.0, | |||
vocoder: Optional[nn.Module] = None, | |||
) -> torch.FloatTensor: | |||
output_cross_attentions: Optional[bool] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Think we can add output_cross_attention
to the config, like we do for use_cache
or output_attention
'input_ids': [ | ||
[4, 32, 13, 7, 9, 12, 19, 8, 13, 18, 5, 13, 12, 4, 64, 19, 8, 13, 18, 5, 13, 15, 22, 4, 28, 9, 8, 20, 9, 4, 7, 12, 4, 24, 22, 6, 8, 13, 17, 11, 39, 6, 13, 7, 9, 12, 19, 8, 13, 18, 5, 13, 12, 4, 7, 9, 14, 4, 24, 22, 6, 8, 13, 17, 11, 39, 24, 13, 5, 6, 13, 7, 10, 9, 5, 14, 39, 25, 5, 13, 6, 63, 4, 24, 13, 8, 27, 10, 14, 5, 12, 4, 21, 5, 9, 5, 13, 7, 15, 39, 24, 16, 13, 24, 8, 12, 5, 4, 7, 13, 17, 11, 10, 6, 5, 17, 6, 16, 13, 5, 12, 4, 64, 40, 47, 54, 32, 23, 4, 53, 49, 32, 23, 4, 54, 8, 40, 47, 54, 32, 7, 23, 4, 69, 52, 43, 23, 4, 51, 10, 12, 6, 10, 15, 40, 5, 13, 6, 23, 4, 69, 52, 48, 5, 6, 26, 26, 26, 63, 4, 19, 8, 13, 4, 48, 7, 6, 16, 13, 7, 15, 4, 52, 7, 9, 21, 16, 7, 21, 5, 4, 61, 9, 14, 5, 13, 12, 6, 7, 9, 14, 10, 9, 21, 4, 64, 48, 52, 61, 63, 4, 7, 9, 14, 4, 48, 7, 6, 16, 13, 7, 15, 4, 52, 7, 9, 21, 16, 7, 21, 5, 4, 53, 5, 9, 5, 13, 7, 6, 10, 8, 9, 4, 64, 48, 52, 53, 63, 4, 20, 10, 6, 11, 4, 8, 27, 5, 13, 4, 6, 11, 10, 13, 6, 22, 39, 6, 20, 8, 4, 24, 13, 5, 6, 13, 7, 10, 9, 5, 14, 4, 18, 8, 14, 5, 15, 12, 4, 10, 9, 4, 8, 9, 5, 4, 11, 16, 9, 14, 13, 5, 14, 4, 24, 15, 16, 12, 4, 15, 7, 9, 21, 16, 7, 21, 5, 12, 4, 7, 9, 14, 4, 14, 5, 5, 24, 4, 10, 9, 6, 5, 13, 8, 24, 5, 13, 7, 25, 10, 15, 10, 6, 22, 4, 25, 5, 6, 20, 5, 5, 9, 4, 58, 7, 37, 23, 4, 49, 22, 32, 8, 13, 17, 11, 4, 7, 9, 14, 4, 32, 5, 9, 12, 8, 13, 55, 15, 8, 20, 26, 2], | ||
[4, 40, 47, 54, 32, 4, 10, 12, 4, 14, 5, 12, 10, 21, 9, 5, 14, 4, 6, 8, 4, 24, 13, 5, 39, 6, 13, 7, 10, 9, 4, 14, 5, 5, 24, 4, 25, 10, 14, 10, 13, 5, 17, 6, 10, 8, 9, 7, 15, 4, 13, 5, 24, 13, 5, 12, 5, 9, 6, 7, 6, 10, 8, 9, 12, 4, 19, 13, 8, 18, 4, 16, 9, 15, 7, 25, 5, 15, 5, 14, 4, 6, 5, 37, 6, 4, 25, 22, 4, 46, 8, 10, 9, 6, 15, 22, 4, 17, 8, 9, 14, 10, 6, 10, 8, 9, 10, 9, 21, 4, 8, 9, 4, 25, 8, 6, 11, 4, 15, 5, 19, 6, 4, 7, 9, 14, 4, 13, 10, 21, 11, 6, 4, 17, 8, 9, 6, 5, 37, 6, 4, 10, 9, 4, 7, 15, 15, 4, 15, 7, 22, 5, 13, 12, 26, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], | ||
[4, 32, 11, 5, 4, 45, 16, 10, 17, 28, 4, 25, 13, 8, 20, 9, 4, 19, 8, 37, 4, 46, 16, 18, 24, 12, 4, 8, 27, 5, 13, 4, 6, 11, 5, 4, 15, 7, 57, 22, 4, 14, 8, 21, 26, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], | ||
], | ||
'attention_mask': [ | ||
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], | ||
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], | ||
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], | ||
] | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fits in one line
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes but I like it better as 3 lines, since they are 3 separate examples.
216f9ff
to
7c09a1b
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice PR! Thanks for adding and for reworking parts of the processing code, it's all v. clean :D
There's just two questions / comments I have relating to backwards compatibility before giving the 👍
- Have the slow integration tests for the SpeechT5 models been run to check outputs are the same with the processing updates?
- Am I right in understanding
stop_labels
were never used (and so removal doesn't affect things?) - With
reduction_factor
being moved toshift_spectrograms_right
, does this effectively mean theinput_values
output from the processor has changed for the same config?
def test_special_tokens_mask(self): | ||
pass | ||
|
||
def test_special_tokens_mask_input_pairs(self): | ||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add unittest.skip
decorators here, with a message about why they're skipped?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Turns out these shouldn't have been skipped and the tokenizer was missing a method. Good catch!
|
||
decoder_attention_mask = targets.get("attention_mask") | ||
if decoder_attention_mask is not None: | ||
inputs["decoder_attention_mask"] = decoder_attention_mask |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice - this is a lot cleaner 🔥
@@ -2414,7 +2542,8 @@ def _generate_speech( | |||
minlenratio: float = 0.0, | |||
maxlenratio: float = 20.0, | |||
vocoder: Optional[nn.Module] = None, | |||
) -> torch.FloatTensor: | |||
output_cross_attentions: Optional[bool] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does it mean to have a value of None
for this param? Often for e.g. output_attentions
it's used to take the default config values. As far as I can tell, it's only ever used as a bool
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
padded_inputs["stop_labels"] = stop_labels | ||
|
||
# thin out frames for reduction factor | ||
if self.reduction_factor > 1: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Am I correct in understanding that this reduction now takes place in shift_spectrograms_right
in the modelling file?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Correct, I mistakenly thought it applied to the labels but it applies to the input that the decoder sees.
The outputs are not the same because the processing of the labels changed. But that's OK since the labels weren't used up to this point anyway.
Correct.
It didn't affect the |
7c09a1b
to
3cfc6f2
Compare
@amyeroberts If you're OK with the changes, I think this can be merged now. The failing tests seem unrelated to SpeechT5. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM ❤️
I'd just like to get a second opinion from @sgugger, in particular regarding three potential breaking changes:
- The removal of
frame_signal_scale
andreduction_factor
as attributes from the feature extractor. I would potentially add them as a property with a deprecation warning, as users sometimes access them in their pipelines e.g. here for max_size. "stop_labels"
not being returned from the feature extractor. They weren't used in the model, but potentially used by users elsewhere? Is this something we guarantee?stop_labels
no longer accepted as an input to the model. I realised this has no affect on the output, and is in line with the feature extractor. Do we typically have a deprecation cycle for model inputs?
I'm pretty sure no one was using any of these properties before, since we only released SpeechT5 very recently and no one would have used it for training yet. Adding deprecation warnings seems excessive to me in this case. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for working on this!
Regarding the breaking changes, even while keeping in mind this is a fairly recent model, I think we can make a bit of an effort regarding backward compatibility (remember Transformers promises no breaking changes between minor releases), especially since this behavior will have been present in two releases (4.27.0 and 4.28.0 since the branch is already cut).
The removal of frame_signal_scale and reduction_factor as attributes from the feature extractor. I would potentially add them as a property with a deprecation warning, as users sometimes access them in their pipelines e.g. here for max_size.
Here this is easy to do to avoid a breaking change.
"stop_labels" not being returned from the feature extractor. They weren't used in the model, but potentially used by users elsewhere? Is this something we guarantee?
This one we can remove probably and wait to see if users complain. We can add an additional argument to return those stop labels if they are requested.
stop_labels no longer accepted as an input to the model. I realized this has no affect on the output, and is in line with the feature extractor. Do we typically have a deprecation cycle for model inputs?
Typically yes. And this is very easy to add so I don't see any reason not to do it.
In both cases, we can probably say it will be removed in two minor versions (so 4.30.0).
@@ -2247,9 +2374,9 @@ def set_output_embeddings(self, new_embeddings): | |||
def forward( | |||
self, | |||
input_values: Optional[torch.Tensor] = None, | |||
attention_mask: Optional[torch.Tensor] = None, | |||
attention_mask: Optional[torch.LongTensor] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The goal is not to be exact (any kind of tensor is accepted) but to be good documentation, so in this case, I agree with @hollance
3cfc6f2
to
5a7e21f
Compare
OK, put frame_signal_scale and reduction_factor back and added a deprecation warning. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. There is one deprecation warning for stop_labels
in the model code as well.
src/transformers/models/speecht5/feature_extraction_speecht5.py
Outdated
Show resolved
Hide resolved
src/transformers/models/speecht5/feature_extraction_speecht5.py
Outdated
Show resolved
Hide resolved
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
3f8f934
to
671c44b
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for iterating! Super nice PR :)
If you're all happy with it, feel free to merge (I don't have rights for that). 😃 |
@hollance - sorry, my bad, I thought you did! |
* wrong argument name * append eos_token_id * all tokenizers need mask and ctc_blank tokens * remove reduction factor from feature extractor * add proper TTS loss * did shifting the wrong way around * mask out padded portions * remove logits again (don't really need it) * fix unit tests * fixup * pad also returns the decoder attention mask, since that's useful to have * clean up feature extractor logic * pad can handle TTS task too * remove stop_labels from loss calculation * simplify logic * fixup * do -100 masking properly * small STFT optimization (calculate mel filterbanks only once) * replace torchaudio fbanks with audio_utils * remove torchaudio dependency * simplify & speed up the STFT * don't serialize window and mel filters * output cross attentions when generating speech * add guided attention loss * fix failing test * Update src/transformers/models/speecht5/feature_extraction_speecht5.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update src/transformers/models/speecht5/modeling_speecht5.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * change type annotation of attention_mask to LongTensor * extract loss into class * remove unused frame_signal_scale argument * use config object in loss class * fix type annotations in doc comments * change optional to just bool * implement missing tokenizer method * add deprecation warning * Update src/transformers/models/speecht5/feature_extraction_speecht5.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/speecht5/feature_extraction_speecht5.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * add deprecation warning for stop_labels --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
What does this PR do?
Adds fine-tuning support for SpeechT5, in particular the TTS model.
The loss function is a combination of L1 loss for the mel-spectrograms, BCE for the stop token prediction, and (optionally) guided attention loss to persuade the cross-attentions to be diagonal.
The STFT feature extraction has been sped up, which also means it currently assumes the frame size is a power of two and throws an error otherwise.
The feature extractor no longer outputs a
stop_labels
target. Padded areas in the spectrogram target are assumed to have the value -100 during training; from this the stop labels are computed automatically.Various other small fixes to the tokenizer, processor, etc to support fine-tuning.
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.