In [16]:
# Optional and currently disabled: registers an IPython kernel spec named
# "musiclm" for the current user. Left commented out, so it never runs.
# %%python3 -m ipykernel install --user --name=musiclm

# Install musiclm-pytorch plus torch 1.13.1 and its companion packages.
# torch and torchvision are requested as "+cu117" builds, which are served from
# the extra PyTorch wheel index given at the end of the line.
# NOTE(review): musiclm-pytorch is the only requirement without a version pin,
# so a fresh install may resolve to a different release than the one this
# notebook was originally run against.
%pip install musiclm-pytorch torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchtext==0.14.1 torchaudio==0.13.1 torchdata==0.5.1 --extra-index-url https://download.pytorch.org/whl/cu117


Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu117
Collecting torch==1.13.1+cu117
  Downloading https://download.pytorch.org/whl/cu117/torch-1.13.1%2Bcu117-cp310-cp310-linux_x86_64.whl (1801.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 GB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:02[0m
[?25hCollecting torchvision==0.14.1+cu117
  Downloading https://download.pytorch.org/whl/cu117/torchvision-0.14.1%2Bcu117-cp310-cp310-linux_x86_64.whl (24.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.3/24.3 MB[0m [31m27.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting torchtext==0.14.1
  Downloading torchtext-0.14.1-cp310-cp310-manylinux1_x86_64.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m13.3 MB/s[0m eta [36m0:00:00[0m00:01[0m0:01[0m
Collecting torchdata==0.5.1
  Downloading torchdata-0.5.1-cp310-cp310-manylinux_2_17_x8

In [2]:
import torch
from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer

# Seed the global torch RNG first: the model weights and the dummy batch below
# are all drawn from it, so without a seed every "Restart Kernel -> Run All"
# produces a different loss and different embeddings.
torch.manual_seed(0)

# Audio encoder of the MuLaN pair. It works on a spectrogram of the input; with
# these spectrogram settings the library reports cropping the spectrogram
# before patchifying it (see the message printed under this cell).
audio_transformer = AudioSpectrogramTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64,
    spec_n_fft = 128,
    spec_win_length = 24,
    spec_aug_stretch_factor = 0.8
)

# Text encoder of the MuLaN pair; same width/depth/heads as the audio encoder.
text_transformer = TextTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64,
)

# MuLaN ties the two encoders together and is trained on (audio, text) pairs.
mulan = MuLaN(
    audio_transformer = audio_transformer,
    text_transformer = text_transformer,
)

# Dummy batch used only to smoke-test the model:
#   wavs  - float tensor of shape (2, 1024), random stand-in for two waveforms
#   texts - integer tensor of shape (2, 256) with ids drawn from [0, 20000)
wavs = torch.randn(2,1024)
texts = torch.randint(0, 20000, (2, 256))

# Calling the model on an (audio, text) batch returns the training loss.
loss = mulan(wavs, texts)
loss.backward()

# Audio-side embeddings; audio is what is available during training.
embeds=mulan.get_audio_latents(wavs)
# At inference only a text prompt is available, so the text-side method is the
# one to call (the audio method must not be fed token ids):
# embeds=mulan.get_text_latents(texts)

spectrogram yielded shape of (65, 86), but had to be cropped to (64, 80) to be patchified for transformer


In [4]:
from musiclm_pytorch import MuLaNEmbedQuantizer

# One conditioning namespace per downstream transformer, each paired with a
# conditioning width of 1024.
_namespaces = ('semantics', 'coarse', 'fine')
_cond_dims = (1024,) * len(_namespaces)

# Wraps the MuLaN model built earlier so that its embeddings can be handed to
# the downstream transformers as conditioning.
quantizer = MuLaNEmbedQuantizer(
    mulan=mulan,
    conditioning_dims=_cond_dims,
    namespaces=_namespaces,
)

# Example: fetch the conditioning embeddings meant for the semantic transformer,
# using a fresh random (2, 1024) batch in place of real audio.
wavs = torch.randn(2, 1024)
conds = quantizer(
    wavs=wavs,
    namespace=_namespaces[0],
)  # (2,8,1024) - 8 is number of quantizers

In [5]:
import torch
from audiolm_pytorch import HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer

# Use the GPU when one is visible and fall back to the CPU otherwise. A bare
# `.cuda()` raises on machines without CUDA, which made this cell unusable
# there; with a GPU present the behaviour is the same as before.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# HuBERT checkpoint plus k-means model, both loaded from local files that must
# already exist under ./hubert/ relative to the notebook's working directory.
wav2vec = HubertWithKmeans(
    checkpoint_path = './hubert/hubert_base_ls960.pt',
    kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin'
)

# dim = 1024 matches the 1024-wide conditioning configured for the 'semantics'
# namespace of the quantizer.
semantic_transformer = SemanticTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    dim = 1024,
    depth = 6,
    audio_text_condition = True      # this must be set to True (same for CoarseTransformer and FineTransformers)
).to(device)

# NOTE(review): `folder` is still the placeholder path and has to be replaced
# with a real directory of audio files before this cell can train.
# NOTE(review): 320 * 32 presumably means 32 HuBERT frames at 320 samples each;
# confirm against the audiolm-pytorch documentation.
trainer = SemanticTransformerTrainer(
    transformer = semantic_transformer,
    wav2vec = wav2vec,
    audio_conditioner = quantizer,   # pass in the MulanEmbedQuantizer instance above
    folder ='/path/to/audio/files',
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 1              # single step: smoke test only, not a real training run
)

trainer.train()