# Training MusicLM model
### Robert Chen, Ahmadsho Akdodshoev, Philip Timofeev

## 0. Imports

In [71]:
!pip install musiclm-pytorch



In [72]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchaudio
from musiclm_pytorch import MuLaN, MuLaNEmbedQuantizer, \
                            AudioSpectrogramTransformer, TextTransformer, MusicLM
from audiolm_pytorch import SemanticTransformer, SemanticTransformerTrainer, \
                            CoarseTransformer, CoarseTransformerTrainer, \
                            FineTransformer, FineTransformerTrainer, \
                            AudioLM, HubertWithKmeans, MusicLMSoundStream, \
                            SoundStreamTrainer, SoundStream 
import os
from scipy.io.wavfile import read as read_wav
import urllib.request
import pandas as pd
import numpy as np
import audio2numpy as a2n
from x_clip.tokenizer import tokenizer

## 1. Creating dataloaders and downloading Hubert K-means checkpoints

Creating the dataset

In [73]:
# Folder that the 'filename' column of the dataset index is relative to.
audio_path = '/kaggle/input/musiclm-test/music-lm/data/'
# Tab-separated index file read by MusicLMDataset.
dataset_path = '/kaggle/input/musiclm-test/music-lm/data/dataset.tsv'


class MusicLMDataset(Dataset):
    """Audio/text pairs described by a tab-separated index file.

    The index must provide the columns ``filename``, ``author`` and ``year``.
    Every item is a ``(waveform, token_ids)`` tuple.
    """

    def __init__(self, path: str, audio_root=None):
        """
        Args:
            path: location of the tab-separated index file.
            audio_root: folder the ``filename`` column is relative to. When
                omitted, the notebook-level ``audio_path`` is used, so existing
                ``MusicLMDataset(path)`` calls keep working.
        """
        if audio_root is None:
            audio_root = audio_path
        self.df = pd.read_csv(path, sep='\t')
        # os.path.join works whether or not the root ends with a separator
        self.filenames = [os.path.join(audio_root, name) for name in self.df['filename']]
        self.authors = self.df['author']
        self.years = self.df['year']

    def __getitem__(self, idx):
        """Load the audio file at position ``idx`` and tokenize its author and year."""
        # element 0 of audio_from_file's result is taken to be the sample array — TODO confirm
        samples = a2n.audio_from_file(self.filenames[idx])[0]
        # float32 matches the dtype DATASET_FIELD_TYPE_CONFIG demands for `wavs`
        wave = torch.tensor(samples, dtype=torch.float32)
        # .iloc keeps the lookup positional; str() guards against pandas having
        # parsed a column (e.g. a purely numeric year) as a number
        captions = [str(self.authors.iloc[idx]), str(self.years.iloc[idx])]
        # both captions are flattened into a single 1-D token sequence
        tokens = tokenizer.tokenize(captions).reshape(-1)
        return wave, tokens

    def __len__(self):
        """Number of rows in the index file."""
        return len(self.filenames)
    
# Build the dataset and pull a single batch as a smoke test of the loading code.
train_dataset = MusicLMDataset(path=dataset_path)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=3)
tmp = next(iter(train_dataloader))
# tmp[1] is the collated token tensor; show its batch shape
print(tmp[1].shape)

torch.Size([3, 10])


Downloading Hubert checkpoints

In [74]:
hubert_ckpt = 'hubert/hubert_base_ls960.pt'
hubert_quantizer = 'hubert/hubert_base_ls960_L9_km500.bin'
soundstream_ckpt = './results/soundstream.pt'
mulan_ckpt = './results/mulan.pt'

# Each HuBERT file is fetched from the remote host under the same relative path
# it is stored at locally, so the local path doubles as the URL suffix.
os.makedirs('hubert', exist_ok=True)
for hubert_file in (hubert_ckpt, hubert_quantizer):
    if os.path.isfile(hubert_file):
        continue  # already downloaded
    # Download under a temporary name and rename afterwards: an interrupted
    # transfer must not leave a truncated file that passes the isfile() check
    # on the next run.
    partial_file = f"./{hubert_file}.part"
    urllib.request.urlretrieve(f"https://dl.fbaipublicfiles.com/{hubert_file}", partial_file)
    os.replace(partial_file, f"./{hubert_file}")

## 2. Training MuLaN

The hyperparameters of every module are collected in dedicated dictionaries so that they can be tuned in one place.

In [75]:
# Keyword arguments for AudioSpectrogramTransformer
AUDIO_KWARGS = {
    'dim': 512,
    'depth': 6,
    'heads': 8,
    'accept_spec': True,
    'dim_head': 64,
    'spec_n_fft': 128,
    'spec_win_length': 24,
    'spec_aug_stretch_factor': 0.8,
    'patch_dropout_prob': 0.
}

# Keyword arguments for TextTransformer
TEXT_KWARGS = {
    'dim': 512,
    'depth': 6,
    'heads': 8,
    'dim_head': 64
}

# Keyword arguments for MuLaNTrainer (the model itself is passed positionally)
MULAN_KWARGS = {
    'dataset': train_dataset,
    'num_train_steps': 10,
    'batch_size': 16,
    'force_clear_prev_results': False,
    'save_model_every': 5
}

# Keyword arguments for MuLaNEmbedQuantizer: one conditioning dim per namespace
MULAN_QUANTIZER_KWARGS = {
    'conditioning_dims': (1024, 1024, 1024),
    'namespaces': ('semantic', 'coarse', 'fine')
}

# Keyword arguments for HubertWithKmeans
HUBERT_KWARGS = {
    'checkpoint_path': hubert_ckpt,
    'kmeans_path': hubert_quantizer
}

# Keyword arguments for SoundStreamTrainer.
# 'force_clear_prev_results' mirrors MULAN_KWARGS: ./results already holds the
# MuLaN checkpoint at this point, so the trainer must neither block on the
# interactive "clear previous results?" prompt nor delete that folder.
# NOTE(review): assumes the audiolm_pytorch trainers accept this flag with the
# same meaning as MuLaNTrainer in this notebook — confirm against the library.
SOUNDSTREAM_TRAINER_KWARGS = {
    'folder': audio_path,
    'num_train_steps': 20,
    'save_model_every': 2,
    'batch_size': 4,
    'data_max_length_seconds': 60,
    'force_clear_prev_results': False
}

# Keyword arguments for SemanticTransformer
SEMANTIC_KWARGS = {
    'dim': 1024,
    'depth': 6,
    'audio_text_condition': True
}

# Keyword arguments for CoarseTransformer
COARSE_KWARGS = {
    'codebook_size': 1024,
    'num_coarse_quantizers': 4,
    'dim': 1024,
    'depth': 6,
    'audio_text_condition': True
}

# Keyword arguments for FineTransformer
FINE_KWARGS = {
    'codebook_size': 1024,
    'num_coarse_quantizers': 4,
    'num_fine_quantizers': 8,
    'dim': 1024,
    'depth': 6,
    'audio_text_condition': True
}

# Keyword arguments shared by the semantic / coarse / fine transformer trainers.
# 'force_clear_prev_results' is disabled for the same reason as for SoundStream.
TRANSFORMER_TRAINER_KWARGS = {
    'folder': audio_path,
    'num_train_steps': 10,
    'save_model_every': 2,
    'batch_size': 4,
    'data_max_length': 320 * 32,
    'force_clear_prev_results': False
}

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

Training MuLaN

In [76]:
import copy
from math import sqrt
from random import choice
from pathlib import Path
from shutil import rmtree
from functools import wraps, partial

from typing_extensions import Annotated

from beartype import beartype
from beartype.door import is_bearable
from beartype.vale import Is
from beartype.typing import Union, List, Optional, Tuple, Callable

from torch import nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence

from lion_pytorch import Lion

from musiclm_pytorch import MuLaN

from einops import rearrange

from accelerate import Accelerator

# for automatically routing data emitted from a dataset to keywords of the transformer wrappers

# Keyword name -> type that identifies a datum belonging to that keyword.
DATASET_FIELD_TYPE_CONFIG = {
    # float tensors with 2 or 3 dimensions
    'wavs': Annotated[
        torch.Tensor,
        Is[lambda t: t.dtype == torch.float and t.ndim in {2, 3}]
    ],
    # untokenized text, one string per sample
    'raw_texts': List[str],
    # long tensors with exactly 2 dimensions
    'texts': Annotated[
        torch.Tensor,
        Is[lambda t: t.dtype == torch.long and t.ndim == 2]
    ],
}

# helpers

def exists(val):
    """Return True unless ``val`` is None."""
    return not (val is None)

def default(*args):
    """Return the first argument that is not None, or None if there is no such argument."""
    return next(filter(exists, args), None)

def noop(*args, **kwargs):
    """Accept any arguments, do nothing and return None."""
    return None

def cycle(dl):
    """Yield the items of ``dl`` endlessly, iterating it again each time it is exhausted."""
    while True:
        yield from dl

def cast_tuple(t):
    """Return tuples and lists unchanged; wrap any other value in a 1-tuple."""
    if isinstance(t, (tuple, list)):
        return t
    return (t,)

def yes_or_no(question):
    """Prompt with ``question`` and return True when the reply is 'y' or 'yes' (case-insensitive)."""
    reply = input(f'{question} (y/n) ').lower()
    return reply == 'yes' or reply == 'y'

def accum_log(log, new_logs):
    """Add every value of ``new_logs`` onto ``log`` in place (missing keys start at 0.) and return ``log``."""
    for name in new_logs:
        log[name] = log.get(name, 0.) + new_logs[name]
    return log

# auto data to module keyword argument routing functions

def has_duplicates(tup):
    """Return True when any (hashable) element occurs more than once in ``tup``."""
    occurrences = {}
    for el in tup:
        occurrences[el] = occurrences.get(el, 0) + 1
    return any(count > 1 for count in occurrences.values())

def determine_types(data, config):
    """Name each element of ``data`` after the first type in ``config`` it satisfies.

    Args:
        data: iterable of values to classify.
        config: mapping of name -> type, tried in insertion order.

    Returns:
        Tuple of matched names, one per element of ``data``.

    Raises:
        TypeError: if an element satisfies none of the configured types.
    """
    unmatched = object()
    names = []
    for el in data:
        # the generator stops at the first bearable type, like a loop with break
        first_match = next(
            (name for name, data_type in config.items() if is_bearable(el, data_type)),
            unmatched
        )
        if first_match is unmatched:
            raise TypeError(f'unable to determine type of {data}')
        names.append(first_match)

    return tuple(names)

# optimizer functions

def separate_weight_decayable_params(params):
    """Split ``params`` by rank.

    Returns:
        ``(wd_params, no_wd_params)``: parameters with at least two dimensions,
        and parameters with fewer than two, each in input order.
    """
    wd_params = []
    no_wd_params = []
    for param in params:
        if param.ndim >= 2:
            wd_params.append(param)
        else:
            no_wd_params.append(param)
    return wd_params, no_wd_params

# dataloader functions

def collate_one_or_multiple_tensors(fn):
    """Decorator that lifts a single-field collate function to whole batches.

    The wrapped function receives the list of samples of a batch:
    - when the samples are not tuples, they are stacked with ``torch.stack``
      and returned as a 1-tuple (``fn`` is not called);
    - otherwise the samples are transposed into fields; a field made only of
      strings becomes a list, every other field is collated with ``fn``.
    """
    def collate_field(field):
        if is_bearable(field, Tuple[str, ...]):
            return list(field)
        return fn(field)

    @wraps(fn)
    def inner(data):
        if not isinstance(data[0], tuple):
            return (torch.stack(data),)

        return tuple(collate_field(field) for field in zip(*data))

    return inner

@collate_one_or_multiple_tensors
def curtail_to_shortest_collate(data):
    """Cut every tensor in ``data`` to the length of the shortest one and stack them.

    Length is measured along the first dimension of each tensor.
    """
    # min() over a generator also works for a single-sample batch, whereas
    # min(*[n]) would be called as min(n) and raise TypeError.
    min_len = min(datum.shape[0] for datum in data)
    data = [datum[:min_len] for datum in data]
    return torch.stack(data)

@collate_one_or_multiple_tensors
def pad_to_longest_fn(data):
    """Pad every sequence in ``data`` to the longest one, with the batch dimension first."""
    padded = pad_sequence(data, batch_first = True)
    return padded

def get_dataloader(ds, pad_to_longest = True, **kwargs):
    """Build a DataLoader over ``ds`` with one of the two collate strategies.

    Args:
        ds: dataset to load from.
        pad_to_longest: pad each batch to its longest sample when True,
            otherwise curtail each batch to its shortest sample.
        **kwargs: forwarded to ``DataLoader`` unchanged.
    """
    if pad_to_longest:
        collate_fn = pad_to_longest_fn
    else:
        collate_fn = curtail_to_shortest_collate
    return DataLoader(ds, collate_fn = collate_fn, **kwargs)

# MuLaN trainer

@beartype
class MuLaNTrainer(nn.Module):
    """Training loop for a MuLaN model, driven through an ``Accelerator``.

    Splits ``dataset`` into train and validation parts, builds dataloaders that
    curtail every batch to its shortest sample, optimises ``mulan`` with Adam
    (or Lion) and periodically writes checkpoints into ``results_folder``.
    """

    def __init__(
        self,
        mulan: MuLaN,
        dataset: Dataset,
        *,
        num_train_steps = None,
        batch_size,
        data_max_length = None,
        folder = None,
        lr = 3e-4,
        grad_accum_every = 1,
        betas = (0.9, 0.99),
        max_grad_norm = 0.5,
        valid_frac = 0.05,
        random_split_seed = 42,
        save_model_every = 1000,
        results_folder = './results',
        accelerate_kwargs: dict = dict(),
        use_lion = False,
        force_clear_prev_results = None  # set to True | False to skip the prompt
    ):
        """
        Args:
            mulan: model to train; calling it with ``wavs`` and ``texts`` must return the loss.
            dataset: dataset whose items are ``(wav, tokens)`` pairs (see ``data_tuple_to_kwargs``).
            num_train_steps: number of optimiser steps; defaults to ``len(dataset)``.
            batch_size: samples per batch, must be greater than 1.
            data_max_length: stored and reported to the trackers only.
            folder: accepted but not used by this trainer.
            lr, betas: optimiser hyperparameters.
            grad_accum_every: number of batches accumulated per optimiser step.
            max_grad_norm: gradient clipping norm; ``None`` disables clipping.
            valid_frac: fraction of the dataset held out for validation; 0 shares the full set.
            random_split_seed: seed of the train/validation split.
            save_model_every: a checkpoint is written whenever the step count is a multiple of this.
            results_folder: directory for checkpoints.
            accelerate_kwargs: forwarded to ``Accelerator`` (only read, never mutated).
            use_lion: use the Lion optimiser instead of Adam.
            force_clear_prev_results: True clears ``results_folder``, False keeps it,
                None asks interactively when the folder is not empty.
        """
        super().__init__()
        assert batch_size > 1, 'batch size must be greater than 1 for contrastive learning (but ideally as large as possible)'

        self.accelerator = Accelerator(**accelerate_kwargs)

        self.mulan = mulan

        self.register_buffer('steps', torch.Tensor([0]))

        # falls back to one optimiser step per dataset sample
        self.num_train_steps = default(num_train_steps, len(dataset))
        self.batch_size = batch_size
        self.grad_accum_every = grad_accum_every

        # optimizers

        optim_klass = Lion if use_lion else Adam
        self.optim = optim_klass(mulan.parameters(), lr = lr, betas = betas)

        # max grad norm

        self.max_grad_norm = max_grad_norm

        self.data_max_length = data_max_length

        # create dataset

        self.ds = dataset
        self.ds_fields = None

        # split for validation

        if valid_frac > 0:
            train_size = int((1 - valid_frac) * len(self.ds))
            valid_size = len(self.ds) - train_size
            self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
            self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples')
        else:
            self.valid_ds = self.ds
            self.print(f'training with shared training and valid dataset of {len(self.ds)} samples')

        # dataloader

        self.dl = get_dataloader(self.ds, batch_size = batch_size, shuffle = True, pad_to_longest = False, drop_last = True)

        self.valid_dl = get_dataloader(self.valid_ds, batch_size = batch_size, shuffle = True, pad_to_longest = False, drop_last = True)

        # prepare with accelerator

        (
            self.mulan,
            self.optim,
            self.dl,
            self.valid_dl
        ) = self.accelerator.prepare(
            self.mulan,
            self.optim,
            self.dl,
            self.valid_dl
        )

        # dataloader iterators

        self.dl_iter = cycle(self.dl)
        self.valid_dl_iter = cycle(self.valid_dl)

        self.save_model_every = save_model_every

        # report the resolved step count rather than a possibly-None argument
        hps = dict(
            num_train_steps = self.num_train_steps,
            data_max_length = data_max_length,
            learning_rate = lr
        )

        self.accelerator.init_trackers("mulan", config = hps)

        # results folder

        self.results_folder = Path(results_folder)

        if force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?')):
            rmtree(str(self.results_folder))

        self.results_folder.mkdir(parents = True, exist_ok = True)

        # to device

        self.mulan.to(self.device)

    def save(self, path):
        """Write the model and optimiser state dicts to ``path``."""
        pkg = dict(
            model = self.accelerator.get_state_dict(self.mulan),
            optim = self.optim.state_dict()
        )
        torch.save(pkg, path)

    def load(self, path):
        """Restore model and optimiser state from a checkpoint written by ``save``."""
        path = Path(path)
        assert path.exists()
        # NOTE: torch.load unpickles the file; only load checkpoints you trust
        pkg = torch.load(str(path), map_location = 'cpu')

        mulan = self.accelerator.unwrap_model(self.mulan)
        mulan.load_state_dict(pkg['model'])
        self.optim.load_state_dict(pkg['optim'])

    def print(self, msg):
        """Print ``msg`` through the accelerator."""
        self.accelerator.print(msg)

    @property
    def device(self):
        return self.accelerator.device

    @property
    def is_distributed(self):
        # imported locally: the notebook's import cells only bring in Accelerator,
        # so the bare name DistributedType would raise NameError here
        from accelerate import DistributedType
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    @property
    def is_main(self):
        return self.accelerator.is_main_process

    @property
    def is_local_main(self):
        return self.accelerator.is_local_main_process

    def data_tuple_to_kwargs(self, data):
        """Map a collated batch ``(wavs, texts, ...)`` to the keyword arguments of the model."""
        data_kwargs = {'wavs': data[0], 'texts': data[1]}

        return data_kwargs

    def train_step(self):
        """Run one optimiser step (``grad_accum_every`` batches) and return its logs.

        Returns:
            dict with key ``'loss'``: the loss averaged over the accumulated batches.
        """
        steps = int(self.steps.item())

        self.mulan.train()

        # logs

        logs = {}

        # accumulate gradients over `grad_accum_every` batches

        for _ in range(self.grad_accum_every):
            data_kwargs = self.data_tuple_to_kwargs(next(self.dl_iter))

            loss = self.mulan(**data_kwargs)

            self.accelerator.backward(loss / self.grad_accum_every)

            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

        if exists(self.max_grad_norm):
            self.accelerator.clip_grad_norm_(self.mulan.parameters(), self.max_grad_norm)

        self.optim.step()
        self.optim.zero_grad()

        # log

        self.print(f"{steps}: loss: {logs['loss']}")
        self.accelerator.log({"train_loss": logs['loss']}, step = steps)

        # save model every so often (step 0 included, since 0 % n == 0)

        if self.is_main and not (steps % self.save_model_every):
            model_path = str(self.results_folder / f'mulan.{steps}.pt')
            self.save(model_path)

            self.print(f'{steps}: saving model to {str(self.results_folder)}')

        self.steps += 1
        return logs

    def train(self, log_fn: Callable = noop):
        """Call ``train_step`` until ``num_train_steps`` is reached, passing each step's logs to ``log_fn``."""

        while self.steps < self.num_train_steps:
            logs = self.train_step()
            log_fn(logs)

        self.print('training complete')

In [77]:
# Audio and text encoders, configured from the dictionaries defined above
audio_transformer = AudioSpectrogramTransformer(**AUDIO_KWARGS)

text_transformer = TextTransformer(**TEXT_KWARGS)

# `mulan` stays in scope: the MuLaNEmbedQuantizer cell further down reuses it
mulan = MuLaN(audio_transformer, text_transformer)

mulan_trainer = MuLaNTrainer(mulan, **MULAN_KWARGS)

# NOTE(review): the run recorded below this cell printed `loss: nan` on every
# step, so the checkpoint saved here should not be trusted until the cause is found.
mulan_trainer.train()

mulan_trainer.save(mulan_ckpt)

# drop the trainer (and its dataloaders / optimiser) once the weights are saved
del mulan_trainer

training with dataset of 7214 samples and validating with randomly splitted 380 samples
0: loss: nan
0: saving model to results
1: loss: nan
2: loss: nan
3: loss: nan
4: loss: nan
5: loss: nan
5: saving model to results
6: loss: nan
7: loss: nan
8: loss: nan
9: loss: nan
training complete


## 3. Training SoundStream

In [78]:
soundstream = MusicLMSoundStream()
# NOTE(review): the recorded run of this cell failed with
# "AssertionError: only one Trainer can be instantiated at a time for training".
# It looks like the library allows a single trainer per process, which would
# also affect the trainer cells below — confirm, and restart the kernel between
# training stages if so.
soundstream_trainer = SoundStreamTrainer(
    soundstream,
    **SOUNDSTREAM_TRAINER_KWARGS
)

soundstream_trainer.train()

# checkpoint that the coarse and fine cells load back
soundstream_trainer.save(soundstream_ckpt)

del soundstream_trainer

AssertionError: only one Trainer can be instantiated at a time for training

## 4. Training conditioning embeddings

Defining the MuLaN Embed Quantizer and Hubert K-means Embedder

In [None]:
# Wraps the MuLaN model trained above; passed as `audio_conditioner` to the
# transformer trainers and as `mulan_embed_quantizer` to MusicLM below.
quantizer = MuLaNEmbedQuantizer(
    mulan=mulan,
    **MULAN_QUANTIZER_KWARGS
)

# HuBERT + k-means model built from the checkpoints downloaded in section 1
wav2vec = HubertWithKmeans(
    **HUBERT_KWARGS
)

Training Semantic Transformer

In [None]:
# The semantic vocabulary size is taken from the wav2vec codebook
semantic_transformer = SemanticTransformer(
   num_semantic_tokens=wav2vec.codebook_size,
   **SEMANTIC_KWARGS
).to(DEVICE)

# NOTE(review): the SoundStream cell above recorded "only one Trainer can be
# instantiated at a time"; this trainer may hit the same assertion when run in
# the same kernel — confirm.
semantic_trainer = SemanticTransformerTrainer(
    wav2vec,
    semantic_transformer,
    audio_conditioner=quantizer,
    **TRANSFORMER_TRAINER_KWARGS
)

semantic_trainer.train()

# `semantic_transformer` itself is kept for the AudioLM cell below
del semantic_trainer

Training Coarse Transformer

In [None]:
# Re-create the codec and restore the weights trained in section 3
soundstream = MusicLMSoundStream()

soundstream.load(soundstream_ckpt)

coarse_transformer = CoarseTransformer(
    num_semantic_tokens=wav2vec.codebook_size,
    **COARSE_KWARGS
).to(DEVICE)

# The trainer must receive the coarse transformer built above; the previous
# positional call passed `semantic_transformer`, leaving `coarse_transformer`
# untrained. Keyword arguments make the binding explicit.
coarse_trainer = CoarseTransformerTrainer(
    transformer=coarse_transformer,
    codec=soundstream,
    wav2vec=wav2vec,
    audio_conditioner=quantizer,
    **TRANSFORMER_TRAINER_KWARGS
)

coarse_trainer.train()

# `coarse_transformer` itself is kept for the AudioLM cell below
del coarse_trainer

Training Fine Transformer

In [None]:
# Re-create the codec and restore the weights trained in section 3
soundstream = MusicLMSoundStream()

soundstream.load(soundstream_ckpt)

# FINE_KWARGS already carries 'codebook_size'; passing it a second time as an
# explicit keyword raises "got multiple values for keyword argument".
fine_transformer = FineTransformer(
    **FINE_KWARGS
).to(DEVICE)

# The trainer must receive the fine transformer built above (the previous call
# passed wav2vec and `semantic_transformer`, and was missing a comma after
# `codec=soundstream`). Keyword arguments make the binding explicit.
fine_trainer = FineTransformerTrainer(
    transformer=fine_transformer,
    codec=soundstream,
    audio_conditioner=quantizer,
    **TRANSFORMER_TRAINER_KWARGS
)

fine_trainer.train()

# `fine_transformer` itself is kept for the AudioLM cell below
del fine_trainer

## 5. Combining AudioLM and MusicLM

In [None]:
# Assemble the three transformers, the wav2vec model and the codec into one AudioLM
audio_lm = AudioLM(
    wav2vec=wav2vec,
    codec=soundstream,
    semantic_transformer=semantic_transformer,
    coarse_transformer=coarse_transformer,
    fine_transformer=fine_transformer
)

In [None]:
# MusicLM combines the AudioLM stack with the MuLaN embedding quantizer
music_lm = MusicLM(
    audio_lm=audio_lm,
    mulan_embed_quantizer=quantizer
)

# Generate from a text prompt; `music` is saved and exported in the cells below
music = music_lm('Café Au Lait 1970s', num_samples=3)

In [None]:
# Keep the raw generated tensor so it can be reloaded without regenerating
torch.save(music, 'generated_music.pt')

In [None]:
output_path = "out.wav"
# Write the file at the rate the codec produces audio at; a hard-coded 44100
# would change pitch and duration whenever the codec runs at a different rate.
# NOTE(review): assumes SoundStream exposes `target_sample_hz` (24 kHz for
# MusicLMSoundStream) — confirm against the installed audiolm_pytorch version.
sample_rate = soundstream.target_sample_hz
torchaudio.save(output_path, music.cpu(), sample_rate)