Skip to content

Commit

Permalink
at inference time, the alignment mask is derived from the duration. i…
Browse files Browse the repository at this point in the history
…mprovise a generate_mask_from_lengths function, consult with someone in the field later
  • Loading branch information
lucidrains committed Aug 31, 2023
1 parent 4ff58db commit 44d1a1f
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 114 deletions.
11 changes: 0 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,23 +107,15 @@ diffusion = NaturalSpeech2(
raw_audio = torch.randn(4, 327680)
prompt = torch.randn(4, 32768) # they randomly excised a range on the audio for the prompt during training, eventually will take care of this auto-magically

mel_lens = torch.tensor([120, 60 , 80, 70])
mel = torch.randn((4, 80, 120))

text = torch.randint(0, 100, (4, 100))
text_lens = torch.tensor([100, 50 , 80, 120])

pitch = torch.randn(4, 1, 120)

# forwards and backwards

loss = diffusion(
audio = raw_audio,
text = text,
text_lens = text_lens,
mel = mel,
mel_lens = mel_lens,
pitch = pitch,
prompt = prompt
)

Expand All @@ -134,9 +126,6 @@ loss.backward()
generated_audio = diffusion.sample(
length = 1024,
text = text,
mel = mel,
mel_lens = mel_lens,
pitch = pitch,
prompt = prompt
) # (1, 327680)
```
Expand Down
17 changes: 9 additions & 8 deletions naturalspeech2_pytorch/aligner.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,17 +171,18 @@ def forward(
y,
y_mask
):
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
alignment_soft, alignment_logprob = self.aligner(y, rearrange(x, 'b d t -> b t d'), x_mask)
alignment_soft = rearrange(alignment_soft, 'b 1 c t -> b t c')

alignment_mas = maximum_path(
alignment_soft.contiguous(),
rearrange(attn_mask, 'b 1 c t -> b c t').contiguous()
)
x_mask = rearrange(x_mask, '... i -> ... i 1')
y_mask = rearrange(y_mask, '... j -> ... 1 j')
attn_mask = x_mask * y_mask
attn_mask = rearrange(attn_mask, 'b 1 i j -> b i j')

alignment_soft = rearrange(alignment_soft, 'b 1 c t -> b t c')
alignment_mask = maximum_path(alignment_soft, attn_mask)

alignment_hard = torch.sum(alignment_mas, -1).int()
return alignment_hard, alignment_soft, alignment_logprob, alignment_mas
alignment_hard = torch.sum(alignment_mask, -1).int()
return alignment_hard, alignment_soft, alignment_logprob, alignment_mask

if __name__ == '__main__':
batch_size = 10
Expand Down
177 changes: 83 additions & 94 deletions naturalspeech2_pytorch/naturalspeech2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,23 @@ def prob_mask_like(shape, prob, device):
else:
return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob

def generate_mask_from_lengths(lengths):
    """Build a boolean alignment mask from per-position durations.

    Each entry along the last dimension of `lengths` is treated as the number
    of consecutive target frames owned by that source position. Row `i` of the
    result is True exactly on the half-open frame range
    `[sum(lengths[..., :i]), sum(lengths[..., :i + 1]))`.

    Args:
        lengths: tensor of shape (..., n) holding durations. Values are
            truncated to integers; negative values are treated as zero.

    Returns:
        Bool tensor of shape (..., n, t), where t is the largest total
        duration found across the leading dimensions.
    """
    # durations may be negative (e.g. raw regression outputs); without the clamp
    # the cumulative sum would run backwards, making later spans overlap earlier
    # ones, and a negative total would make torch.arange raise
    durations = lengths.int().clamp(min = 0)

    total_frames = durations.sum(dim = -1).amax().item()

    span_end = durations.cumsum(dim = -1)
    # exclusive cumsum: shift right by one position and drop the last entry
    span_start = F.pad(span_end, (1, -1), value = 0)

    frame_idx = torch.arange(total_frames, device = durations.device)

    # (..., n, 1) compared against (t,) broadcasts to (..., n, t),
    # so no explicit repeat of the frame index is needed
    span_start = span_start.unsqueeze(-1)
    span_end = span_end.unsqueeze(-1)

    return (frame_idx >= span_start) & (frame_idx < span_end)

# sinusoidal positional embeds

class LearnedSinusoidalPosEmb(nn.Module):
Expand Down Expand Up @@ -1344,76 +1361,6 @@ def process_prompt(self, prompt = None):

return prompt

def process_conditioning(
    self,
    *,
    prompt,
    audio = None,
    pitch = None,
    text = None,
    text_lens = None,
    mel = None,
    mel_lens = None
):
    """Encode prompt and text, align text to mel, and build the conditioning.

    Pitch and mel are computed from `audio` when they are not passed in. The
    aligner output is used both to expand the phoneme encodings into the
    conditioning tensor and as the target for the duration / pitch predictor.

    Args:
        prompt: prompt input; passed through `process_prompt` and `prompt_enc`.
        audio: raw audio, required (as a 2-dim tensor) whenever `pitch` or
            `mel` is missing.
        pitch: optional precomputed pitch. NOTE(review): appears to be
            expected as (batch, 1, frames), matching the rearrange applied to
            the computed pitch - confirm with callers.
        text: token ids, required.
        text_lens: optional per-sample text lengths; defaults to the full
            padded length for every sample.
        mel: optional precomputed mel spectrogram.
        mel_lens: optional per-sample mel lengths; defaults to the full
            padded length for every sample.

    Returns:
        Tuple `(prompt_enc, cond, aux_loss)`, where `aux_loss` is the weighted
        sum of the duration and pitch L1 losses.
    """
    batch = prompt.shape[0]

    assert exists(text)
    text_max_length = text.shape[-1]

    if not exists(text_lens):
        text_lens = torch.full((batch,), text_max_length, device = self.device, dtype = torch.long)

    text_mask = rearrange(create_mask(text_lens, text_max_length), 'b n -> b 1 n')

    prompt = self.process_prompt(prompt)
    prompt_enc = self.prompt_enc(prompt)
    phoneme_enc = self.phoneme_enc(text)

    # process pitch

    if not exists(pitch):
        assert exists(audio) and audio.ndim == 2
        assert exists(self.target_sample_hz)

        pitch = compute_pitch_pytorch(audio, self.target_sample_hz)
        pitch = rearrange(pitch, 'b n -> b 1 n')

    # process mel

    if not exists(mel):
        assert exists(audio) and audio.ndim == 2

        mel = self.audio_to_mel(audio)
        # crop to the pitch frame count, not the text length: pitch is later
        # averaged over durations that the aligner derives from the mel, so
        # the two must share the same frame axis
        mel = mel[..., :pitch.shape[-1]]

    mel_max_length = mel.shape[-1]

    if not exists(mel_lens):
        mel_lens = torch.full((batch,), mel_max_length, device = self.device, dtype = torch.long)

    mel_mask = rearrange(create_mask(mel_lens, mel_max_length), 'b n -> b 1 n')

    # alignment

    aln_hard, aln_soft, aln_log, aln_mas = self.aligner(phoneme_enc, text_mask, mel, mel_mask)
    duration_pred, pitch_pred = self.duration_pitch(phoneme_enc, prompt_enc)

    pitch = average_over_durations(pitch, aln_hard)
    cond = self.expand_encodings(rearrange(phoneme_enc, 'b n d -> b d n'), rearrange(aln_mas, 'b n c -> b 1 n c'), pitch)

    # pitch and duration loss

    duration_loss = F.l1_loss(aln_hard, duration_pred)

    pitch = rearrange(pitch, 'b 1 d -> b d')
    pitch_loss = F.l1_loss(pitch, pitch_pred)

    # weigh the losses
    # each loss is scaled by its own weight; the pitch weight was previously
    # added as a constant, which left the pitch loss unweighted

    aux_loss = duration_loss * self.duration_loss_weight + pitch_loss * self.pitch_loss_weight

    return prompt_enc, cond, aux_loss

def expand_encodings(self, phoneme_enc, attn, pitch):
expanded_dur = einsum('k l m n, k j m -> k j n', attn, phoneme_enc)
pitch_emb = self.pitch_emb(rearrange(f0_to_coarse(pitch), 'b 1 t -> b t'))
Expand All @@ -1430,29 +1377,25 @@ def sample(
prompt = None,
batch_size = 1,
cond_scale = 1.,
pitch = None,
text = None,
text_lens = None,
mel = None,
mel_lens = None,
):
sample_fn = self.ddpm_sample if not self.use_ddim else self.ddim_sample

prompt = self.process_prompt(prompt)

prompt_enc = cond = None

if self.conditional:
assert exists(mel)

prompt_enc, cond, _ = self.process_conditioning(
prompt = prompt,
text = text,
pitch = pitch,
mel = mel,
text_lens = text_lens,
mel_lens = mel_lens
)
assert exists(prompt) and exists(text)
prompt = self.process_prompt(prompt)
prompt_enc = self.prompt_enc(prompt)
phoneme_enc = self.phoneme_enc(text)

duration, pitch = self.duration_pitch(phoneme_enc, prompt_enc)
pitch = rearrange(pitch, 'b n -> b 1 n')

aln_mask = generate_mask_from_lengths(duration).float()

cond = self.expand_encodings(rearrange(phoneme_enc, 'b n d -> b d n'), rearrange(aln_mask, 'b n c -> b 1 n c'), pitch)

if exists(prompt):
batch_size = prompt.shape[0]
Expand Down Expand Up @@ -1494,15 +1437,61 @@ def forward(
duration_pitch_loss = 0.

if self.conditional:
prompt_enc, cond, duration_pitch_loss = self.process_conditioning(
audio = audio,
prompt = prompt,
text = text,
pitch = pitch,
mel = mel,
text_lens = text_lens,
mel_lens = mel_lens
)
batch = prompt.shape[0]

assert exists(text)
text_max_length = text.shape[-1]

if not exists(text_lens):
text_lens = torch.full((batch,), text_max_length, device = self.device, dtype = torch.long)

text_mask = rearrange(create_mask(text_lens, text_max_length), 'b n -> b 1 n')

prompt = self.process_prompt(prompt)
prompt_enc = self.prompt_enc(prompt)
phoneme_enc = self.phoneme_enc(text)

# process pitch

if not exists(pitch):
assert exists(audio) and audio.ndim == 2
assert exists(self.target_sample_hz)

pitch = compute_pitch_pytorch(audio, self.target_sample_hz)
pitch = rearrange(pitch, 'b n -> b 1 n')

# process mel

if not exists(mel):
assert exists(audio) and audio.ndim == 2
mel = self.audio_to_mel(audio)
mel = mel[..., :pitch.shape[-1]]

mel_max_length = mel.shape[-1]

if not exists(mel_lens):
mel_lens = torch.full((batch,), mel_max_length, device = self.device, dtype = torch.long)

mel_mask = rearrange(create_mask(mel_lens, mel_max_length), 'b n -> b 1 n')

# alignment

aln_hard, aln_soft, aln_log, aln_mas = self.aligner(phoneme_enc, text_mask, mel, mel_mask)
duration_pred, pitch_pred = self.duration_pitch(phoneme_enc, prompt_enc)

pitch = average_over_durations(pitch, aln_hard)
cond = self.expand_encodings(rearrange(phoneme_enc, 'b n d -> b d n'), rearrange(aln_mas, 'b n c -> b 1 n c'), pitch)

# pitch and duration loss

duration_loss = F.l1_loss(aln_hard, duration_pred)

pitch = rearrange(pitch, 'b 1 d -> b d')
pitch_loss = F.l1_loss(pitch, pitch_pred)

# weigh the losses

aux_loss = duration_loss * self.duration_loss_weight + pitch_loss + self.pitch_loss_weight

# automatically encode raw audio to residual vq with codec

Expand Down
2 changes: 1 addition & 1 deletion naturalspeech2_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.0.48'
__version__ = '0.0.49'

0 comments on commit 44d1a1f

Please sign in to comment.