Skip to content

Commit

Permalink
Merge pull request #3 from openai/primed_sampling_option
Browse files Browse the repository at this point in the history
Add primed sampling option
  • Loading branch information
heewooj committed Apr 30, 2020
2 parents 8081004 + 75f8e4e commit 393d02b
Showing 1 changed file with 29 additions and 6 deletions.
35 changes: 29 additions & 6 deletions jukebox/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from jukebox.hparams import Hyperparams
from jukebox.utils.torch_utils import empty_cache
from jukebox.utils.io import save_wav
from jukebox.utils.io import save_wav, load_audio
from jukebox.make_models import make_model
from jukebox.align import get_alignment
from jukebox.save_html import save_html
Expand Down Expand Up @@ -128,12 +128,19 @@ def upsample(zs, labels, sampling_kwargs, priors, hps):
# Primed sample
def primed_sample(x, labels, sampling_kwargs, priors, hps):
sample_levels = list(range(len(priors)))
# TODO: Confirm if vqvae encode supports shorter length segments
zs = priors[-1].encode(x, start_level=0, end_level=len(priors))
zs = _sample(zs, labels, sampling_kwargs, priors, sample_levels, hps)
return zs

def save_samples(model, device, hps):
def load_prompt(audio_files, sr, duration):
xs = []
for audio_file in audio_files:
x, _ = load_audio(audio_file, sr=sr, duration=duration, offset=0.0)
x = x.T.mean(1, keepdims=True)
xs.append(x)
return xs

def save_samples(model, device, hps, sample_hps):
print(hps)
from jukebox.lyricdict import poems, gpt_2_lyrics
vqvae, priors = make_model(model, device, hps)
Expand Down Expand Up @@ -193,16 +200,32 @@ def save_samples(model, device, hps):
dict(temp=0.99, fp16=True, chunk_size=lower_level_chunk_size, max_batch_size=lower_level_max_batch_size),
dict(temp=0.99, fp16=True, chunk_size=chunk_size, max_batch_size=max_batch_size)]

ancestral_sample(labels, sampling_kwargs, priors, hps)
if sample_hps.mode == 'ancestral':
ancestral_sample(labels, sampling_kwargs, priors, hps)
elif sample_hps.mode == 'primed':
top_hop_length = vqvae.hop_lengths[-1]
duration = int(sample_hps.prompt_length_in_seconds * hps.sr) // top_hop_length * top_hop_length
assert sample_hps.audio_file is not None
audio_files = sample_hps.audio_file.split(',')
xs = load_prompt(audio_files, hps.sr, duration)
while len(xs) < hps.n_samples:
xs.extend(xs)
xs = xs[:hps.n_samples]
x = t.stack([t.from_numpy(x) for x in xs])
x = x.to('cuda', non_blocking=True)
primed_sample(x, labels, sampling_kwargs, priors, hps)
else:
raise ValueError(f'Unknown sample mode {mode}.')


def run(model, port=29500, **kwargs):
def run(model, mode='ancestral', audio_file=None, prompt_length_in_seconds=12.0, port=29500, **kwargs):
from jukebox.utils.dist_utils import setup_dist_from_mpi
rank, local_rank, device = setup_dist_from_mpi(port=port)
hps = Hyperparams(**kwargs)
sample_hps = Hyperparams(dict(mode=mode, audio_file=audio_file, prompt_length_in_seconds=prompt_length_in_seconds))

with t.no_grad():
save_samples(model, device, hps)
save_samples(model, device, hps, sample_hps)

if __name__ == '__main__':
fire.Fire(run)

0 comments on commit 393d02b

Please sign in to comment.