From bead0a38472438929428465f61abd3ae54165def Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Fri, 28 Apr 2023 08:55:59 -0700
Subject: [PATCH] add ability to train the soundstream to be denoising, as in the paper, may be needed for naturalspeech2

---
 audiolm_pytorch/soundstream.py | 53 +++++++++++++++++++++++++++++-----
 audiolm_pytorch/version.py     |  2 +-
 2 files changed, 46 insertions(+), 9 deletions(-)

diff --git a/audiolm_pytorch/soundstream.py b/audiolm_pytorch/soundstream.py
index ff9fc31..7c0e7d8 100644
--- a/audiolm_pytorch/soundstream.py
+++ b/audiolm_pytorch/soundstream.py
@@ -404,6 +404,15 @@ def forward(self, x):
 
         return x
 
+class FiLM(nn.Module):
+    def __init__(self, dim, dim_cond):
+        super().__init__()
+        self.to_cond = nn.Linear(dim_cond, dim * 2)
+
+    def forward(self, x, cond):
+        gamma, beta = self.to_cond(cond).chunk(2, dim = -1)
+        return x * gamma + beta
+
 class SoundStream(nn.Module):
     def __init__(
         self,
@@ -487,6 +496,8 @@ def __init__(
 
         self.encoder_attn = LocalTransformer(**attn_kwargs) if use_local_attn else None
 
+        self.encoder_film = FiLM(codebook_dim, dim_cond = 2)
+
         self.num_quantizers = rq_num_quantizers
 
         self.codebook_dim = codebook_dim
@@ -504,6 +515,8 @@ def __init__(
             quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
         )
 
+        self.decoder_film = FiLM(codebook_dim, dim_cond = 2)
+
         self.decoder_attn = LocalTransformer(**attn_kwargs) if use_local_attn else None
 
         decoder_blocks = []
@@ -570,6 +583,10 @@ def __init__(
 
         self.register_buffer('zero', torch.tensor([0.]), persistent = False)
 
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
     @property
     def configs(self):
         return pickle.loads(self._configs)
@@ -641,17 +658,33 @@ def non_discr_parameters(self):
             *self.encoder.parameters(),
             *self.decoder.parameters(),
-            *(self.encoder_attn.parameters() if exists(self.encoder_attn) else []),
-            *(self.decoder_attn.parameters() if exists(self.decoder_attn) else [])
+            *(self.encoder_attn.parameters() if exists(self.encoder_attn) else []),
+            *(self.decoder_attn.parameters() if exists(self.decoder_attn) else []),
+            *self.encoder_film.parameters(),
+            *self.decoder_film.parameters()
         ]
 
     @property
     def seq_len_multiple_of(self):
         return functools.reduce(lambda x, y: x * y, self.strides)
 
+    def process_input(self, x, input_sample_hz = None):
+        x, ps = pack([x], '* n')
+
+        if exists(input_sample_hz):
+            x = resample(x, input_sample_hz, self.target_sample_hz)
+
+        x = curtail_to_multiple(x, self.seq_len_multiple_of)
+
+        if x.ndim == 2:
+            x = rearrange(x, 'b n -> b 1 n')
+
+        return x, ps
+
     def forward(
         self,
         x,
         target = None,
+        is_denoising = None, # if you want to learn film conditioners that teach the soundstream to denoise - target would need to be passed in above
         return_encoded = False,
         return_discr_loss = False,
         return_discr_losses_separately = False,
@@ -660,15 +693,12 @@ def forward(
         input_sample_hz = None,
         apply_grad_penalty = False
     ):
-        x, ps = pack([x], '* n')
+        assert not (exists(is_denoising) and not exists(target))
 
-        if exists(input_sample_hz):
-            x = resample(x, input_sample_hz, self.target_sample_hz)
+        x, ps = self.process_input(x, input_sample_hz = input_sample_hz)
 
-        x = curtail_to_multiple(x, self.seq_len_multiple_of)
-
-        if x.ndim == 2:
-            x = rearrange(x, 'b n -> b 1 n')
+        if exists(target):
+            target, _ = self.process_input(target, input_sample_hz = input_sample_hz)
 
         orig_x = x.clone()
 
@@ -679,11 +709,18 @@ def forward(
         if exists(self.encoder_attn):
             x = self.encoder_attn(x)
 
+        if exists(is_denoising):
+            denoise_input = torch.tensor([is_denoising, not is_denoising], dtype = x.dtype, device = self.device) # [1, 0] for denoise, [0, 1] for not denoising
+            x = self.encoder_film(x, denoise_input)
+
         x, indices, commit_loss = self.rq(x)
 
         if return_encoded:
             return x, indices, commit_loss
 
+        if exists(is_denoising):
+            x = self.decoder_film(x, denoise_input)
+
         if exists(self.decoder_attn):
             x = self.decoder_attn(x)
 
diff --git a/audiolm_pytorch/version.py b/audiolm_pytorch/version.py
index 873054e..9093e4e 100644
--- a/audiolm_pytorch/version.py
+++ b/audiolm_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.28.2'
+__version__ = '0.29.0'
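
Usage note (not part of the patch): the sketch below shows, under stated assumptions, how the new `is_denoising` flag could be combined with the existing `target` argument to train the soundstream as a denoiser. The constructor arguments, tensor shapes, and synthetic noise are illustrative guesses rather than requirements from the diff, and it is assumed the default forward call returns the training loss; the only behavior taken from the patch is that `is_denoising` requires `target`, and that the flag selects the [1, 0] / [0, 1] FiLM condition applied before and after the residual VQ.

import torch
from audiolm_pytorch import SoundStream

# hypothetical config - any valid SoundStream setup should work the same way
soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8
)

clean_audio = torch.randn(2, 32000)                              # clean reference waveforms (batch, samples)
noisy_audio = clean_audio + 0.1 * torch.randn_like(clean_audio)  # synthetic corruption, for illustration only

# denoising step: noisy input, clean target, FiLM layers see the [1, 0] condition
loss = soundstream(noisy_audio, target = clean_audio, is_denoising = True)
loss.backward()

# plain reconstruction step: same waveform as input and target, FiLM layers see [0, 1]
loss = soundstream(clean_audio, target = clean_audio, is_denoising = False)
loss.backward()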