Skip to content

Commit

Permalink
fix failing test
Browse files Browse the repository at this point in the history
  • Loading branch information
hollance committed Mar 16, 2023
1 parent dff7e37 commit d66db17
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/transformers/models/speecht5/modeling_speecht5.py
Original file line number Diff line number Diff line change
Expand Up @@ -2665,6 +2665,8 @@ def forward(
if labels is not None:
if decoder_input_values is None:
decoder_input_values = shift_spectrograms_right(labels, self.config.reduction_factor)
if self.config.use_guided_attention_loss:
output_attentions = True

outputs = self.speecht5(
input_values=input_ids,
Expand All @@ -2678,7 +2680,7 @@ def forward(
past_key_values=past_key_values,
use_cache=use_cache,
speaker_embeddings=speaker_embeddings,
output_attentions=output_attentions or labels is not None,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
)
Expand All @@ -2692,8 +2694,8 @@ def forward(
outputs_before_postnet,
outputs_after_postnet,
logits,
outputs.cross_attentions,
labels,
outputs.cross_attentions,
)

if not return_dict:
Expand All @@ -2718,8 +2720,8 @@ def _compute_loss(
outputs_before_postnet: torch.FloatTensor,
outputs_after_postnet: torch.FloatTensor,
logits: torch.FloatTensor,
cross_attentions: torch.FloatTensor,
labels: torch.FloatTensor,
cross_attentions: Optional[torch.FloatTensor] = None,
):
padding_mask = labels != -100.0

Expand Down

0 comments on commit d66db17

Please sign in to comment.