condition latent features with aligned conditions prior to wavenet stack
lucidrains committed Sep 18, 2023
1 parent e84d0a1 commit 091e603
Showing 2 changed files with 55 additions and 5 deletions.
58 changes: 54 additions & 4 deletions naturalspeech2_pytorch/naturalspeech2_pytorch.py
@@ -67,6 +67,15 @@ def has_int_squareroot(num):

# tensor helpers

def pad_or_curtail_to_length(t, length):
if t.shape[-1] == length:
return t

if t.shape[-1] > length:
return t[..., :length]

return F.pad(t, (0, length - t.shape[-1]))

def prob_mask_like(shape, prob, device):
if prob == 1:
return torch.ones(shape, device = device, dtype = torch.bool)
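
A minimal usage sketch of the new pad_or_curtail_to_length helper (shapes here are illustrative, and it assumes the helper above plus the module's existing torch.nn.functional as F import): it right-trims or zero-pads the last dimension, which is what later lets the aligned condition be conformed to the latent length.

import torch

cond = torch.randn(2, 512, 98)                       # (batch, channels, length)

out = pad_or_curtail_to_length(cond, 100)
assert out.shape == (2, 512, 100)                     # zero-padded on the right

out = pad_or_curtail_to_length(cond, 96)
assert out.shape == (2, 512, 96)                      # curtailed to the target length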
@@ -834,6 +843,7 @@ def __init__(
)

# prompt condition

self.cond_drop_prob = cond_drop_prob # for classifier free guidance
self.condition_on_prompt = condition_on_prompt
self.to_prompt_cond = None
@@ -861,6 +871,15 @@ def __init__(
use_flash_attn = use_flash_attn
)

# aligned conditioning from aligner + duration module

self.null_cond = None
self.cond_to_model_dim = None

if self.condition_on_prompt:
self.cond_to_model_dim = nn.Conv1d(dim_prompt, dim, 1)
self.null_cond = nn.Parameter(torch.zeros(dim, 1))

# conditioning includes time and optionally prompt

dim_cond_mult = dim_cond_mult * (2 if condition_on_prompt else 1)
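
A standalone sketch of the aligned-conditioning setup added in __init__ (the dimensions are illustrative, not the library's defaults): a pointwise Conv1d projects the condition from dim_prompt channels to the model dimension, and null_cond is a learned per-channel vector that stands in when the condition is dropped for classifier-free guidance.

import torch
import torch.nn as nn

dim_prompt, dim = 512, 128

cond_to_model_dim = nn.Conv1d(dim_prompt, dim, 1)    # 1x1 conv, channel-wise projection
null_cond = nn.Parameter(torch.zeros(dim, 1))        # learned "no condition" embedding

cond = torch.randn(2, dim_prompt, 100)               # aligned condition, channel first
cond = cond_to_model_dim(cond)                       # -> (2, dim, 100)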
@@ -913,23 +932,27 @@ def forward(
times,
prompt = None,
prompt_mask = None,
- cond= None
+ cond = None,
+ cond_drop_prob = None
):
b = x.shape[0]
cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)

- drop_mask = prob_mask_like((b,), cond_drop_prob, self.device)
+ # prepare prompt condition
+ # prob should remove going forward

t = self.to_time_cond(times)
c = None

if exists(self.to_prompt_cond):
assert exists(prompt)

prompt_cond_drop_mask = prob_mask_like((b,), cond_drop_prob, self.device)

prompt_cond = self.to_prompt_cond(prompt)

prompt_cond = torch.where(
- rearrange(drop_mask, 'b -> b 1'),
+ rearrange(prompt_cond_drop_mask, 'b -> b 1'),
self.null_prompt_cond,
prompt_cond,
)
@@ -939,12 +962,37 @@ def forward(
resampled_prompt_tokens = self.perceiver_resampler(prompt, mask = prompt_mask)

c = torch.where(
- rearrange(drop_mask, 'b -> b 1 1'),
+ rearrange(prompt_cond_drop_mask, 'b -> b 1 1'),
self.null_prompt_tokens,
resampled_prompt_tokens
)

# rearrange to channel first

x = rearrange(x, 'b n d -> b d n')

# sum aligned condition to input sequence

if exists(self.cond_to_model_dim):
assert exists(cond)
cond = self.cond_to_model_dim(cond)

cond_drop_mask = prob_mask_like((b,), cond_drop_prob, self.device)

cond = torch.where(
rearrange(cond_drop_mask, 'b -> b 1 1'),
self.null_cond,
cond
)

# for now, conform the condition to the length of the latent features

cond = pad_or_curtail_to_length(cond, x.shape[-1])

x = x + cond

# main wavenet body

x = self.wavenet(x, t)
x = rearrange(x, 'b d n -> b n d')
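
A self-contained sketch of the new conditioning path in forward, reusing prob_mask_like and pad_or_curtail_to_length from above (the batch size, dimensions, and drop probability are illustrative): a per-sample boolean mask chooses which samples have their aligned condition swapped for the learned null_cond (classifier-free guidance), the condition is conformed to the latent length, and the result is summed onto the latents before the wavenet body.

import torch
from einops import rearrange

b, dim, n = 2, 128, 100
x = torch.randn(b, dim, n)                           # latent features, channel first
cond = torch.randn(b, dim, n + 3)                    # projected condition, mismatched length
null_cond = torch.zeros(dim, 1)

cond_drop_mask = prob_mask_like((b,), 0.25, x.device)   # True -> drop this sample's condition

cond = torch.where(
    rearrange(cond_drop_mask, 'b -> b 1 1'),
    null_cond,                                       # broadcasts over batch and length
    cond
)

cond = pad_or_curtail_to_length(cond, x.shape[-1])   # conform condition length to the latent
x = x + cond                                         # condition prior to the wavenet stack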

@@ -1527,6 +1575,7 @@ def forward(
duration_pred, pitch_pred = self.duration_pitch(phoneme_enc, prompt_enc)

pitch = average_over_durations(pitch, aln_hard)

cond = self.expand_encodings(rearrange(phoneme_enc, 'b n d -> b d n'), rearrange(aln_mask, 'b n c -> b 1 n c'), pitch)

# pitch and duration loss
@@ -1536,6 +1585,7 @@ def forward(
pitch = rearrange(pitch, 'b 1 d -> b d')
pitch_loss = F.l1_loss(pitch, pitch_pred)
align_loss = self.aligner_loss(aln_log , text_lens, mel_lens)

# weigh the losses

aux_loss = (duration_loss * self.duration_loss_weight) \
2 changes: 1 addition & 1 deletion naturalspeech2_pytorch/version.py
@@ -1 +1 @@
- __version__ = '0.1.6'
+ __version__ = '0.1.7'
