Skip to content

Commit

Permalink
if one uses -1 for padding
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Aug 24, 2023
1 parent 93a825d commit 082c184
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'soundstorm-pytorch',
packages = find_packages(exclude=[]),
version = '0.1.1',
version = '0.1.2',
license='MIT',
description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch',
author = 'Phil Wang',
Expand Down
12 changes: 9 additions & 3 deletions soundstorm_pytorch/soundstorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,11 +1052,17 @@ def forward(

seq_mask = mask

if not exists(seq_mask) and exists(self.pad_id):
seq_mask = (x != self.pad_id).any(dim = -1)
elif not exists(seq_mask):
if not exists(seq_mask):
seq_mask = torch.ones((b, n), device = device, dtype = torch.bool)

if exists(self.pad_id):
pad_mask = (x == self.pad_id).any(dim = -1)
seq_mask = seq_mask & ~pad_mask

if self.pad_id < 0:
# if using say -1 for padding
x = torch.where(rearrange(pad_mask, 'b n -> b n 1'), 0, x)

# maybe condition

cond_tokens = self.maybe_get_condition(cond_semantic_token_ids, length = x.shape[-2])
Expand Down

0 comments on commit 082c184

Please sign in to comment.