From f4da1cff884bf2871f10805034272984b4f79a66 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 18 May 2023 14:25:53 -0700 Subject: [PATCH] ability to generate in seconds --- README.md | 5 +++-- setup.py | 2 +- soundstorm_pytorch/soundstorm.py | 14 +++++++++++--- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 62914f7..c22d59b 100644 --- a/README.md +++ b/README.md @@ -95,14 +95,15 @@ loss.backward() # and now you can generate state-of-the-art speech -generated_audio = model.generate(1024, batch_size = 2) # generated audio is also a raw wave if soundstream is present +generated_audio = model.generate(seconds = 30, batch_size = 2) # generate 30 seconds of audio (it will calculate the length in seconds based off the sampling frequency and cumulative downsamples in the soundstream passed in above) ``` ## Todo - [x] integrate soundstream +- [x] when generating, length can be defined in seconds (takes into account sampling freq etc) -- [ ] when generating, make sure it can return audio file, and length can be defined in seconds (takes into sampling freq etc) +- [ ] option to return list of audio files when generating - [ ] turn it into a command line tool - [ ] add cross attention and adaptive layernorm conditioning (just copy paste in the entire conformer repository, if conditioning adds too much cruft to the other repo) - [ ] make sure grouped rvq is supported. 
concat embeddings rather than sum across group dimension diff --git a/setup.py b/setup.py index cfaa66a..8eb689a 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'soundstorm-pytorch', packages = find_packages(exclude=[]), - version = '0.0.6', + version = '0.0.7', license='MIT', description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch', author = 'Phil Wang', diff --git a/soundstorm_pytorch/soundstorm.py b/soundstorm_pytorch/soundstorm.py index fed6bf8..d552523 100644 --- a/soundstorm_pytorch/soundstorm.py +++ b/soundstorm_pytorch/soundstorm.py @@ -276,13 +276,21 @@ def __init__( @torch.no_grad() def generate( self, - max_seq_len, + seq_len = None, + *, + seconds = None, batch_size = None, start_temperature = 1., filter_thres = 0.7, noise_level_scale = 1., **kwargs ): + assert exists(seq_len) ^ exists(seconds) + + if not exists(seq_len): + assert exists(self.soundstream), 'soundstream must be passed in to generate in seconds' + seq_len = (seconds * self.soundstream.target_sample_hz) // self.soundstream.seq_len_multiple_of + sample_one = not exists(batch_size) batch_size = default(batch_size, 1) @@ -295,14 +303,14 @@ def generate( # sequence starts off as all masked - shape = (batch_size, max_seq_len) + shape = (batch_size, seq_len) seq = torch.full(shape, self.mask_id, device = device) mask = torch.full(shape, True, device = device) # slowly demask - all_mask_num_tokens = (self.schedule_fn(times[1:]) * max_seq_len).long() + all_mask_num_tokens = (self.schedule_fn(times[1:]) * seq_len).long() # self conditioning