-
Notifications
You must be signed in to change notification settings - Fork 2
/
loss.py
37 lines (29 loc) · 1.5 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
import torch.nn as nn
import hparams as hp
class FastSpeech2Loss(nn.Module):
    """FastSpeech2 loss: masked mel, post-net mel, duration, pitch and energy terms.

    Mel terms use mean squared error; duration, pitch and energy terms use
    mean absolute error. Every term is computed only over the positions
    where the corresponding mask is True.
    """

    def __init__(self):
        super().__init__()
        self.mse_loss = nn.MSELoss()
        self.mae_loss = nn.L1Loss()

    @staticmethod
    def _select(tensor, mask):
        """Return a 1-D tensor of the entries of ``tensor`` where ``mask`` is True."""
        return tensor.masked_select(mask)

    def forward(self, log_d_predicted, log_d_target, p_predicted, p_target,
                e_predicted, e_target, mel, mel_postnet, mel_target,
                src_mask, mel_mask):
        """Compute the five loss terms over the unmasked positions.

        Args:
            log_d_predicted, log_d_target: predicted / reference durations,
                broadcastable with ``src_mask``. The "log" scale is implied
                only by the parameter names -- TODO confirm with caller.
            p_predicted, p_target: predicted / reference pitch,
                broadcastable with ``src_mask``.
            e_predicted, e_target: predicted / reference energy,
                broadcastable with ``src_mask``.
            mel, mel_postnet, mel_target: spectrogram tensors with one more
                trailing dimension than ``mel_mask``; presumably
                (batch, frames, n_mels) -- TODO confirm with caller.
            src_mask: boolean mask, True at positions that contribute to the
                duration, pitch and energy losses.
            mel_mask: boolean mask, True at positions that contribute to the
                mel losses.

        Returns:
            Tuple ``(mel_loss, mel_postnet_loss, d_loss, p_loss, e_loss)`` of
            scalar tensors.
        """
        # Targets must not receive gradients. ``detach()`` is used instead of
        # assigning ``requires_grad = False`` because that assignment raises a
        # RuntimeError on non-leaf tensors and mutates the caller's tensors.
        log_d_target = log_d_target.detach()
        p_target = p_target.detach()
        e_target = e_target.detach()
        mel_target = mel_target.detach()

        # The mel mask lacks the trailing dimension of the mel tensors; add a
        # singleton axis so it broadcasts across that dimension.
        frame_mask = mel_mask.unsqueeze(-1)

        mel_target = self._select(mel_target, frame_mask)
        mel_loss = self.mse_loss(self._select(mel, frame_mask), mel_target)
        mel_postnet_loss = self.mse_loss(
            self._select(mel_postnet, frame_mask), mel_target
        )

        d_loss = self.mae_loss(
            self._select(log_d_predicted, src_mask),
            self._select(log_d_target, src_mask),
        )
        p_loss = self.mae_loss(
            self._select(p_predicted, src_mask),
            self._select(p_target, src_mask),
        )
        e_loss = self.mae_loss(
            self._select(e_predicted, src_mask),
            self._select(e_target, src_mask),
        )

        return mel_loss, mel_postnet_loss, d_loss, p_loss, e_loss