
Commit

add masking, but give each quantizer position its own positional masking token
lucidrains committed May 17, 2023
1 parent 7e85f47 commit dde9ec9
Showing 2 changed files with 20 additions and 3 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'soundstorm-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.2',
+  version = '0.0.3',
   license='MIT',
   description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch',
   author = 'Phil Wang',
21 changes: 19 additions & 2 deletions soundstorm_pytorch/soundstorm.py
@@ -38,16 +38,33 @@ def __init__(
         if isinstance(conformer, dict):
             self.conformer = Conformer(**self.conformer)

+        dim = self.conformer.dim
+
+        self.mask_tokens = nn.Parameter(torch.randn(num_tokens_reduce, dim))
+
         self.num_tokens_reduce = num_tokens_reduce
         self.num_tokens_per_head = default(num_tokens_per_head, num_tokens_reduce)

-        dim = self.conformer.dim
-
         self.heads = nn.Sequential(
             nn.Linear(dim, dim * self.num_tokens_per_head),
             Rearrange('b n (h d) -> b (n h) d', h = self.num_tokens_per_head)
         )

+    def add_mask_tokens(
+        self,
+        x,
+        mask
+    ):
+        h = self.num_tokens_reduce
+
+        x = torch.where(
+            rearrange(mask, 'b (n h) -> b n h 1', h = h),
+            rearrange(x, 'b (n h) d -> b n h d', h = h),
+            self.mask_tokens,
+        )
+
+        return rearrange(x, 'b n h d -> b (n h) d')
+
     def forward(
         self,
         x,
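For reference, a minimal standalone sketch of the per-quantizer masking this commit introduces. Shapes and variable names (b, n, q, keep) are illustrative rather than from the repository; the "True means keep the original embedding" convention follows the torch.where call in add_mask_tokens above.

import torch
import torch.nn as nn
from einops import rearrange

# Illustrative sizes (hypothetical): batch, timesteps, quantizers per timestep, model dim
b, n, q, dim = 2, 8, 4, 512

# One learned mask token per quantizer position, analogous to self.mask_tokens above
mask_tokens = nn.Parameter(torch.randn(q, dim))

x = torch.randn(b, n * q, dim)        # token embeddings, flattened over (timestep, quantizer)
keep = torch.rand(b, n * q) > 0.5     # boolean mask; True = keep the original embedding

out = torch.where(
    rearrange(keep, 'b (n q) -> b n q 1', q = q),   # condition broadcast over the feature dim
    rearrange(x, 'b (n q) d -> b n q d', q = q),    # original embeddings
    mask_tokens,                                    # (q, dim) broadcasts over batch and timestep
)
out = rearrange(out, 'b n q d -> b (n q) d')        # back to the flattened layout

assert out.shape == x.shape

Because mask_tokens has shape (q, dim), every masked position is filled with the token belonging to its own quantizer index, rather than one shared mask embedding.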
