Skip to content

Commit

Permalink
support sampling with --labels=False
Browse files Browse the repository at this point in the history
  • Loading branch information
heewooj committed Jun 10, 2020
1 parent 7939619 commit c54d631
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 2 deletions.
40 changes: 40 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,46 @@ you changed any hps directly in the command line script (eg:`heads`), make sure
that `make_models` restores our checkpoint correctly.
- Run sample.py as outlined in the sampling section, but now with `--model=my_model`

For example, let's say we trained `small_vqvae`, `small_prior`, and `small_upsampler` under `/path/to/jukebox/logs`. In `make_models.py`, we are going to declare a tuple of the new models as `my_model`.
```
MODELS = {
'5b': ("vqvae", "upsampler_level_0", "upsampler_level_1", "prior_5b"),
'5b_lyrics': ("vqvae", "upsampler_level_0", "upsampler_level_1", "prior_5b_lyrics"),
'1b_lyrics': ("vqvae", "upsampler_level_0", "upsampler_level_1", "prior_1b_lyrics"),
'my_model': ("my_small_vqvae", "my_small_upsampler", "my_small_prior"),
}
```

Next, in `hparams.py`, we add them to the registry with the corresponding `restore_*` paths and any other command line options used during training. Another important note is that for top-level priors with lyric conditioning, we have to locate a self-attention layer that shows alignment between the lyric and music tokens. Look for layers where `prior.prior.transformer._attn_mods[layer].attn_func` is either 6 or 7. If your model is starting to sing along lyrics, it means some layer, head pair has learned alignment. Congrats!
```
my_small_vqvae = Hyperparams(
restore_vqvae='/path/to/jukebox/logs/small_vqvae/checkpoint_some_step.pth.tar',
)
my_small_vqvae.update(small_vqvae)
HPARAMS_REGISTRY["my_small_vqvae"] = my_small_vqvae
my_small_prior = Hyperparams(
restore_prior='/path/to/jukebox/logs/small_prior/checkpoint_latest.pth.tar',
level=1,
labels=False,
# TODO For the two lines below, if `--labels` was used and the model is
# trained with lyrics, find and enter the layer, head pair that has learned
# alignment.
alignment_layer=47,
alignment_head=0,
)
my_small_prior.update(small_prior)
HPARAMS_REGISTRY["my_small_prior"] = my_small_prior
my_small_upsampler = Hyperparams(
restore_prior='/path/to/jukebox/logs/small_upsampler/checkpoint_latest.pth.tar',
level=0,
labels=False,
)
my_small_upsampler.update(small_upsampler)
HPARAMS_REGISTRY["my_small_upsampler"] = my_small_upsampler
```

#### Train with labels
To train with your own metadata for your audio files, implement `get_metadata` in `data/files_dataset.py` to return the
`artist`, `genre` and `lyrics` for a given audio file. For now, you can pass `''` for lyrics to not use any lyrics.
Expand Down
19 changes: 19 additions & 0 deletions jukebox/data/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,25 @@ def get_relevant_lyric_tokens(full_tokens, n_tokens, total_length, offset, durat
assert tokens == [full_tokens[index] if index != -1 else 0 for index in indices]
return tokens, indices

class EmptyLabeller():
def get_label(self, artist=None, genre=None, lyrics=None, total_length=None, offset=None):
y = np.array([0], dtype=np.int64)
info = dict(artist="n/a", genre="n/a", lyrics=[], full_tokens=[])
return dict(y=y, info=info)

def get_batch_labels(self, metas, device='cpu'):
ys, infos = [], []
for meta in metas:
label = self.get_label()
y, info = label['y'], label['info']
ys.append(y)
infos.append(info)

ys = t.stack([t.from_numpy(y) for y in ys], dim=0).to(device).long()
assert ys.shape[0] == len(metas)
assert len(infos) == len(metas)
return dict(y=ys, info=infos)

class Labeller():
def __init__(self, max_genre_words, n_tokens, sample_length, v3=False):
self.ag_processor = ArtistGenreProcessor(v3)
Expand Down
6 changes: 5 additions & 1 deletion jukebox/prior/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from jukebox.transformer.ops import LayerNorm
from jukebox.prior.autoregressive import ConditionalAutoregressive2D
from jukebox.prior.conditioners import Conditioner, LabelConditioner
from jukebox.data.labels import Labeller
from jukebox.data.labels import EmptyLabeller, Labeller

from jukebox.utils.torch_utils import assert_shape
from jukebox.utils.dist_utils import print_once
Expand Down Expand Up @@ -131,11 +131,15 @@ def __init__(self, z_shapes, l_bins, encoder, decoder, level,
if labels:
self.labels_v3 = labels_v3
self.labeller = Labeller(self.y_emb.max_bow_genre_size, self.n_tokens, self.sample_length, v3=self.labels_v3)
else:
self.labeller = EmptyLabeller()

print(f"Level:{level}, Cond downsample:{self.cond_downsample}, Raw to tokens:{self.raw_to_tokens}, Sample length:{self.sample_length}")


def get_y(self, labels, start, get_indices=False):
if isinstance(self.labeller, EmptyLabeller):
return None
y = labels['y'].clone()

# Set sample_length to match this level
Expand Down
3 changes: 2 additions & 1 deletion jukebox/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import jukebox.utils.dist_adapter as dist

from jukebox.hparams import Hyperparams
from jukebox.data.labels import EmptyLabeller
from jukebox.utils.torch_utils import empty_cache
from jukebox.utils.audio_utils import save_wav, load_audio
from jukebox.make_models import make_model
Expand Down Expand Up @@ -114,7 +115,7 @@ def _sample(zs, labels, sampling_kwargs, priors, sample_levels, hps):
os.makedirs(logdir)
t.save(dict(zs=zs, labels=labels, sampling_kwargs=sampling_kwargs, x=x), f"{logdir}/data.pth.tar")
save_wav(logdir, x, hps.sr)
if alignments is None and priors[-1] is not None and priors[-1].n_tokens > 0:
if alignments is None and priors[-1] is not None and priors[-1].n_tokens > 0 and not isinstance(priors[-1].labeller, EmptyLabeller):
alignments = get_alignment(x, zs, labels[-1], priors[-1], sampling_kwargs[-1]['fp16'], hps)
save_html(logdir, x, zs, labels[-1], alignments, hps)
return zs
Expand Down

0 comments on commit c54d631

Please sign in to comment.