Skip to content

Commit

Permalink
address #29
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Sep 24, 2023
1 parent 091e603 commit 659bec7
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 5 deletions.
19 changes: 18 additions & 1 deletion naturalspeech2_pytorch/aligner.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ def forward(self, attn_logprob, key_lens, query_lens):

# Convert to log probabilities
# Note: Mask out probs beyond key_len
attn_logprob.masked_fill_(torch.arange(max_key_len + 1, device=device, dtype=torch.long).view(1, 1, -1) > key_lens.view(1, -1, 1), -1e15)
mask_value = -torch.finfo(attn_logprob.dtype).max
attn_logprob.masked_fill_(torch.arange(max_key_len + 1, device=device, dtype=torch.long).view(1, 1, -1) > key_lens.view(1, -1, 1), mask_value)

attn_logprob = attn_logprob.log_softmax(dim = -1)

Expand All @@ -159,6 +160,22 @@ def forward(self, attn_logprob, key_lens, query_lens):

return cost

class BinLoss(Module):
    """Alignment binarization loss.

    Negative log-likelihood of the hard (binarized) alignment under the
    soft alignment distribution — encourages the soft attention to
    concentrate onto the monotonic hard path (the attention binarization
    loss used by NVIDIA's RAD-TTS / FastPitch aligner).
    """

    def forward(self, attn_hard, attn_logprob, key_lens):
        # attn_logprob : (batch, 1, c, t) unnormalized alignment scores; last dim t is
        #                the key (text) axis, masked below by key_lens
        # attn_hard    : 3-d 0/1 hard alignment whose axes line up with attn_logprob
        #                after the permutes below — assumed from the Aligner; TODO confirm
        # key_lens     : (batch,) number of valid key positions per sample
        # returns a scalar loss (mean over the batch)

        batch, device = attn_logprob.shape[0], attn_logprob.device
        max_key_len = attn_logprob.size(-1)

        # reorder so the key axis is last for both tensors
        # (equivalent to einops rearrange 'b 1 c t -> c b t' and 'b t c -> c b t')
        attn_logprob = attn_logprob.squeeze(1).permute(1, 0, 2)
        attn_hard = attn_hard.permute(2, 0, 1)

        mask_value = -torch.finfo(attn_logprob.dtype).max

        # mask out padded key positions before normalizing so they get ~zero probability.
        # two fixes vs. the original:
        #   * `>=` instead of `>`: there is no blank offset here (unlike ForwardSumLoss,
        #     which pads a CTC blank and offsets by one), so index key_len itself is
        #     already invalid and must be masked
        #   * out-of-place masked_fill: the in-place variant wrote through the permuted
        #     view and clobbered the caller's tensor
        invalid = torch.arange(max_key_len, device = device, dtype = torch.long).view(1, 1, -1) >= key_lens.view(1, -1, 1)
        attn_logprob = attn_logprob.masked_fill(invalid, mask_value)

        attn_logprob = attn_logprob.log_softmax(dim = -1)

        # negated log-likelihood of the hard alignment, averaged over the batch.
        # the original returned the raw (negative-valued) log-likelihood, which
        # decreases without bound as the probabilities shrink — negating makes it
        # a proper loss to minimize, matching the reference formulation
        return -(attn_hard * attn_logprob).sum() / batch

class Aligner(Module):
def __init__(
self,
Expand Down
15 changes: 12 additions & 3 deletions naturalspeech2_pytorch/naturalspeech2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from beartype.door import is_bearable

from naturalspeech2_pytorch.attend import Attend
from naturalspeech2_pytorch.aligner import Aligner, ForwardSumLoss
from naturalspeech2_pytorch.aligner import Aligner, ForwardSumLoss, BinLoss
from naturalspeech2_pytorch.utils.tokenizer import Tokenizer, ESpeak
from naturalspeech2_pytorch.utils.utils import average_over_durations, create_mask
from naturalspeech2_pytorch.version import __version__
Expand Down Expand Up @@ -1192,7 +1192,8 @@ def __init__(
scale = 1., # this will be set to < 1. for better convergence when training on higher resolution images
duration_loss_weight = 1.,
pitch_loss_weight = 1.,
aligner_loss_weight = 1.
aligner_loss_weight = 1.,
aligner_bin_loss_weight = 0.
):
super().__init__()

Expand Down Expand Up @@ -1233,7 +1234,10 @@ def __init__(
self.duration_pitch = DurationPitchPredictor(dim=duration_pitch_dim)
self.aligner = Aligner(dim_in=aligner_dim_in, dim_hidden=aligner_dim_hidden, attn_channels=aligner_attn_channels)
self.pitch_emb = nn.Embedding(pitch_emb_dim, pitch_emb_pp_hidden_dim)

self.aligner_loss = ForwardSumLoss()
self.bin_loss = BinLoss()
self.aligner_bin_loss_weight = aligner_bin_loss_weight

# rest of ddpm

Expand Down Expand Up @@ -1584,7 +1588,12 @@ def forward(

pitch = rearrange(pitch, 'b 1 d -> b d')
pitch_loss = F.l1_loss(pitch, pitch_pred)
align_loss = self.aligner_loss(aln_log , text_lens, mel_lens)

align_loss = self.aligner_loss(aln_log, text_lens, mel_lens)

if self.aligner_bin_loss_weight > 0.:
align_bin_loss = self.bin_loss(aln_mask, aln_log, text_lens) * self.aligner_bin_loss_weight
align_loss = align_loss + align_bin_loss

# weigh the losses

Expand Down
2 changes: 1 addition & 1 deletion naturalspeech2_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.1.7'
__version__ = '0.1.8'

0 comments on commit 659bec7

Please sign in to comment.