Skip to content

Commit

Permalink
add ability to train the soundstream to be denoising, as in the paper…
Browse files Browse the repository at this point in the history
…, may be needed for naturalspeech2
  • Loading branch information
lucidrains committed Apr 28, 2023
1 parent d05c020 commit bead0a3
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 9 deletions.
53 changes: 45 additions & 8 deletions audiolm_pytorch/soundstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,15 @@ def forward(self, x):

return x

class FiLM(nn.Module):
def __init__(self, dim, dim_cond):
super().__init__()
self.to_cond = nn.Linear(dim_cond, dim * 2)

def forward(self, x, cond):
gamma, beta = self.to_cond(cond).chunk(2, dim = -1)
return x * gamma + beta

class SoundStream(nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -487,6 +496,8 @@ def __init__(

self.encoder_attn = LocalTransformer(**attn_kwargs) if use_local_attn else None

self.encoder_film = FiLM(codebook_dim, dim_cond = 2)

self.num_quantizers = rq_num_quantizers

self.codebook_dim = codebook_dim
Expand All @@ -504,6 +515,8 @@ def __init__(
quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
)

self.decoder_film = FiLM(codebook_dim, dim_cond = 2)

self.decoder_attn = LocalTransformer(**attn_kwargs) if use_local_attn else None

decoder_blocks = []
Expand Down Expand Up @@ -570,6 +583,10 @@ def __init__(

self.register_buffer('zero', torch.tensor([0.]), persistent = False)

@property
def device(self):
return next(self.parameters()).device

@property
def configs(self):
return pickle.loads(self._configs)
Expand Down Expand Up @@ -641,17 +658,33 @@ def non_discr_parameters(self):
*self.encoder.parameters(),
*self.decoder.parameters(),
*(self.encoder_attn.parameters() if exists(self.encoder_attn) else []),
*(self.decoder_attn.parameters() if exists(self.decoder_attn) else [])
*(self.decoder_attn.parameters() if exists(self.decoder_attn) else []),
*self.encoder_film.parameters(),
*self.decoder_film.parameters()
]

@property
def seq_len_multiple_of(self):
return functools.reduce(lambda x, y: x * y, self.strides)

def process_input(self, x, input_sample_hz = None):
x, ps = pack([x], '* n')

if exists(input_sample_hz):
x = resample(x, input_sample_hz, self.target_sample_hz)

x = curtail_to_multiple(x, self.seq_len_multiple_of)

if x.ndim == 2:
x = rearrange(x, 'b n -> b 1 n')

return x, ps

def forward(
self,
x,
target = None,
is_denoising = None, # if you want to learn film conditioners that teach the soundstream to denoise - target would need to be passed in above
return_encoded = False,
return_discr_loss = False,
return_discr_losses_separately = False,
Expand All @@ -660,15 +693,12 @@ def forward(
input_sample_hz = None,
apply_grad_penalty = False
):
x, ps = pack([x], '* n')
assert not (exists(is_denoising) and not exists(target))

if exists(input_sample_hz):
x = resample(x, input_sample_hz, self.target_sample_hz)
x, ps = self.process_input(x, input_sample_hz = input_sample_hz)

x = curtail_to_multiple(x, self.seq_len_multiple_of)

if x.ndim == 2:
x = rearrange(x, 'b n -> b 1 n')
if exists(target):
target, _ = self.process_input(target, input_sample_hz = input_sample_hz)

orig_x = x.clone()

Expand All @@ -679,11 +709,18 @@ def forward(
if exists(self.encoder_attn):
x = self.encoder_attn(x)

if exists(is_denoising):
denoise_input = torch.tensor([is_denoising, not is_denoising], dtype = x.dtype, device = self.device) # [1, 0] for denoise, [0, 1] for not denoising
x = self.encoder_film(x, denoise_input)

x, indices, commit_loss = self.rq(x)

if return_encoded:
return x, indices, commit_loss

if exists(is_denoising):
x = self.decoder_film(x, denoise_input)

if exists(self.decoder_attn):
x = self.decoder_attn(x)

Expand Down
2 changes: 1 addition & 1 deletion audiolm_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.28.2'
__version__ = '0.29.0'

0 comments on commit bead0a3

Please sign in to comment.