Skip to content

Commit

Permalink
ability to generate in seconds
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 18, 2023
1 parent d033fcc commit f4da1cf
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 6 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,15 @@ loss.backward()

# and now you can generate state-of-the-art speech

generated_audio = model.generate(1024, batch_size = 2) # generated audio is also a raw wave if soundstream is present
generated_audio = model.generate(seconds = 30, batch_size = 2) # generate 30 seconds of audio (it will calculate the length in frames based on the sampling frequency and cumulative downsampling of the soundstream passed in above)
```

## Todo

- [x] integrate soundstream
- [x] when generating, length can be defined in seconds (takes into account sampling freq etc)

- [ ] when generating, make sure it can return audio file, and length can be defined in seconds (takes into account sampling freq etc)
- [ ] option to return list of audio files when generating
- [ ] turn it into a command line tool
- [ ] add cross attention and adaptive layernorm conditioning (just copy paste in the entire conformer repository, if conditioning adds too much cruft to the other repo)
- [ ] make sure grouped rvq is supported. concat embeddings rather than sum across group dimension
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'soundstorm-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.6',
version = '0.0.7',
license='MIT',
description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch',
author = 'Phil Wang',
Expand Down
14 changes: 11 additions & 3 deletions soundstorm_pytorch/soundstorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,13 +276,21 @@ def __init__(
@torch.no_grad()
def generate(
self,
max_seq_len,
seq_len = None,
*,
seconds = None,
batch_size = None,
start_temperature = 1.,
filter_thres = 0.7,
noise_level_scale = 1.,
**kwargs
):
assert exists(seq_len) ^ exists(seconds)

if not exists(seq_len):
assert exists(self.soundstream), 'soundstream must be passed in to generate in seconds'
seq_len = (seconds * self.soundstream.target_sample_hz) // self.soundstream.seq_len_multiple_of

sample_one = not exists(batch_size)
batch_size = default(batch_size, 1)

Expand All @@ -295,14 +303,14 @@ def generate(

# sequence starts off as all masked

shape = (batch_size, max_seq_len)
shape = (batch_size, seq_len)

seq = torch.full(shape, self.mask_id, device = device)
mask = torch.full(shape, True, device = device)

# slowly demask

all_mask_num_tokens = (self.schedule_fn(times[1:]) * max_seq_len).long()
all_mask_num_tokens = (self.schedule_fn(times[1:]) * seq_len).long()

# self conditioning

Expand Down

0 comments on commit f4da1cf

Please sign in to comment.