make sure forward sum loss actually runs, cleanup
lucidrains committed Sep 1, 2023
1 parent 5cbf6e9 commit fb9e1d5
Showing 4 changed files with 41 additions and 22 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -108,7 +108,7 @@ raw_audio = torch.randn(4, 327680)
prompt = torch.randn(4, 32768) # they randomly excised a range on the audio for the prompt during training, eventually will take care of this auto-magically

text = torch.randint(0, 100, (4, 100))
-text_lens = torch.tensor([100, 50 , 80, 120])
+text_lens = torch.tensor([100, 50 , 80, 100])

# forwards and backwards

55 changes: 35 additions & 20 deletions naturalspeech2_pytorch/aligner.py
@@ -2,11 +2,14 @@
import numpy as np

import torch
-from torch import nn
+from torch import nn, Tensor
+from torch.nn import Module
import torch.nn.functional as F

-from einops import rearrange
+from einops import rearrange, repeat

+from beartype import beartype
+from beartype.typing import Optional

def exists(val):
    return val is not None
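
The new `repeat` import replaces an `unsqueeze().repeat()` pattern further down in this file. A quick, purely illustrative equivalence check:

```python
import torch
from einops import repeat

t = torch.arange(1, 6)               # shape (5,)
a = t.unsqueeze(0).repeat(3, 1)      # old style: shape (3, 5)
b = repeat(t, 'n -> b n', b = 3)     # einops equivalent, self-documenting
assert torch.equal(a, b)
```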
@@ -22,7 +25,6 @@ def __init__(
    ):
        super().__init__()
        self.temperature = temperature
-        self.softmax = torch.nn.Softmax(dim=3)

        self.key_layers = nn.ModuleList([
            nn.Conv1d(
@@ -50,7 +52,13 @@
            nn.Conv1d(dim_in, attn_channels, kernel_size=1, padding=0, bias=True)
        ])

-    def forward(self, queries: torch.Tensor, keys: torch.Tensor, mask: torch.Tensor = None):
+    @beartype
+    def forward(
+        self,
+        queries: Tensor,
+        keys: Tensor,
+        mask: Optional[Tensor] = None
+    ):
        key_out = keys
        for layer in self.key_layers:
            key_out = layer(key_out)
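
With `@beartype`, these annotations are enforced at call time instead of serving as documentation only. A standalone illustration; the function name and shapes here are made up:

```python
import torch
from torch import Tensor
from beartype import beartype
from beartype.typing import Optional

@beartype
def attend(queries: Tensor, keys: Tensor, mask: Optional[Tensor] = None):
    return queries.shape, keys.shape, mask

attend(torch.randn(2, 8, 16), torch.randn(2, 8, 16))  # ok, mask defaults to None
# attend([1, 2, 3], torch.randn(2, 8, 16))            # raises a beartype violation
```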
@@ -61,12 +69,15 @@ def forward(self, queries: torch.Tensor, keys: torch.Tensor, mask: torch.Tensor = None):

        key_out = rearrange(key_out, 'b c t -> b t c')
        query_out = rearrange(query_out, 'b c t -> b t c')
-        attn_logp = torch.cdist(query_out, key_out).unsqueeze(1)

+        attn_logp = torch.cdist(query_out, key_out)
+        attn_logp = rearrange(attn_logp, 'b ... -> b 1 ...')

        if exists(mask):
-            attn_logp.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf"))
+            mask = rearrange(mask.bool(), '... c -> ... 1 c')
+            attn_logp.data.masked_fill_(~mask, -torch.finfo(attn_logp.dtype).max)

-        attn = self.softmax(attn_logp)
+        attn = attn_logp.softmax(dim = -1)
        return attn, attn_logp
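
A self-contained sketch of the new masking path. The incoming `mask` is assumed to arrive in the `(b, 1, n)` layout that the masks in naturalspeech2_pytorch.py are built with; filling with the dtype's lowest finite value rather than `-inf` keeps the softmax NaN-free even when a row is entirely masked:

```python
import torch
from einops import rearrange

b, t_q, t_k, c = 2, 6, 5, 8
query_out = torch.randn(b, t_q, c)
key_out = torch.randn(b, t_k, c)
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]).bool()
mask = rearrange(mask, 'b c -> b 1 c')                # (b, 1, t_k), the 'b 1 n' layout

attn_logp = torch.cdist(query_out, key_out)           # (b, t_q, t_k) pairwise distances
attn_logp = rearrange(attn_logp, 'b ... -> b 1 ...')  # (b, 1, t_q, t_k)

mask = rearrange(mask.bool(), '... c -> ... 1 c')     # (b, 1, 1, t_k), broadcasts over t_q
attn_logp.data.masked_fill_(~mask, -torch.finfo(attn_logp.dtype).max)

attn = attn_logp.softmax(dim = -1)
print(attn[0, ..., 3:].sum())                         # ~0: masked keys get no weight
```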

def pad_tensor(input, pad, value=0):
@@ -110,34 +121,38 @@ def maximum_path(value, mask, const=None):
    path = path.to(dtype=dtype)
    return path

-class ForwardSumLoss():
-    def __init__(self, blank_logprob=-1):
+class ForwardSumLoss(Module):
+    def __init__(
+        self,
+        blank_logprob = -1
+    ):
        super().__init__()
-        self.log_softmax = torch.nn.LogSoftmax(dim=-1)
-        self.ctc_loss = torch.nn.CTCLoss(zero_infinity=True)
        self.blank_logprob = blank_logprob

-    def forward(self, attn_logprob, in_lens, out_lens):
-        key_lens = in_lens
-        query_lens = out_lens
+        self.ctc_loss = torch.nn.CTCLoss(
+            blank = 0, # check this value
+            zero_infinity = True
+        )

+    def forward(self, attn_logprob, key_lens, query_lens):
+        device, blank_logprob = attn_logprob.device, self.blank_logprob
        max_key_len = attn_logprob.size(-1)

        # Reorder input to [query_len, batch_size, key_len]
-        attn_logprob = rearrange(attn_logprob, 'b c t -> c b t')
+        attn_logprob = rearrange(attn_logprob, 'b 1 c t -> c b t')

        # Add blank label
-        attn_logprob = F.pad(attn_logprob, (1, 0, 0, 0, 0, 0), self.blank_logprob)
+        attn_logprob = F.pad(attn_logprob, (1, 0, 0, 0, 0, 0), value = blank_logprob)

        # Convert to log probabilities
        # Note: Mask out probs beyond key_len
-        device = attn_logprob.device
        attn_logprob.masked_fill_(torch.arange(max_key_len + 1, device=device, dtype=torch.long).view(1, 1, -1) > key_lens.view(1, -1, 1), -1e15)

-        attn_logprob = self.log_softmax(attn_logprob)
+        attn_logprob = attn_logprob.log_softmax(dim = -1)

        # Target sequences
-        target_seqs = torch.arange(1, max_key_len + 1, device=device, dtype=torch.long).unsqueeze(0)
-        target_seqs = target_seqs.repeat(key_lens.numel(), 1)
+        target_seqs = torch.arange(1, max_key_len + 1, device=device, dtype=torch.long)
+        target_seqs = repeat(target_seqs, 'n -> b n', b = key_lens.numel())

        # Evaluate CTC loss
        cost = self.ctc_loss(attn_logprob, target_seqs, query_lens, key_lens)
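
With `ForwardSumLoss` now a `Module` (so `__call__` dispatches to `forward`) and owning a correctly configured `CTCLoss` (`blank = 0` is consistent with target indices starting at 1, since the blank column is the one added by `F.pad`), the loss can actually be instantiated and run, per the commit title. A usage sketch with illustrative shapes; the input layout `(b, 1, query_len, key_len)` follows the `'b 1 c t'` rearrange above:

```python
import torch
from naturalspeech2_pytorch.aligner import ForwardSumLoss

batch, query_len, key_len = 4, 32, 12        # e.g. mel frames vs. phoneme tokens

# attention log-similarities in the 'b 1 c t' layout forward() expects
attn_logprob = torch.randn(batch, 1, query_len, key_len, requires_grad = True)

key_lens = torch.tensor([12, 9, 7, 12])      # phoneme lengths per sample
query_lens = torch.tensor([32, 20, 15, 32])  # mel lengths (CTC wants query >= key)

loss_fn = ForwardSumLoss()
loss = loss_fn(attn_logprob, key_lens, query_lens)
loss.backward()
```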
4 changes: 4 additions & 0 deletions naturalspeech2_pytorch/naturalspeech2_pytorch.py
@@ -1448,6 +1448,8 @@ def forward(
        if not exists(text_lens):
            text_lens = torch.full((batch,), text_max_length, device = self.device, dtype = torch.long)

+        text_lens.clamp_(max = text_max_length)

        text_mask = rearrange(create_mask(text_lens, text_max_length), 'b n -> b 1 n')

        prompt = self.process_prompt(prompt)
@@ -1475,6 +1477,8 @@ def forward(
        if not exists(mel_lens):
            mel_lens = torch.full((batch,), mel_max_length, device = self.device, dtype = torch.long)

+        mel_lens.clamp_(max = mel_max_length)

        mel_mask = rearrange(create_mask(mel_lens, mel_max_length), 'b n -> b 1 n')

        # alignment
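
Clamping in place keeps user-supplied lengths consistent with the padded tensor sizes that `create_mask` and the aligner's CTC loss later rely on; this mirrors the README fix above, where a `text_lens` entry of 120 exceeded the padded length of 100. A minimal sketch, assuming `create_mask` follows the usual `arange` comparison pattern (the repo's helper may differ in detail):

```python
import torch

def create_mask(lens, max_length):
    # assumed behavior of the repo's create_mask helper
    seq = torch.arange(max_length, device = lens.device)
    return seq[None, :] < lens[:, None]

text_max_length = 100
text_lens = torch.tensor([100, 50, 80, 120])  # 120 exceeds the padded length

text_lens.clamp_(max = text_max_length)       # in-place guard, as in the diff
mask = create_mask(text_lens, text_max_length)
print(text_lens)                              # tensor([100,  50,  80, 100])
print(mask.sum(dim = -1))                     # matches the clamped lengths
```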
2 changes: 1 addition & 1 deletion naturalspeech2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.1.0'
+__version__ = '0.1.1'
