
MIDI Language Model

Generative modeling of MIDI files

When it comes to modeling music, there are two general approaches:

  • modeling the raw audio signal directly
  • modeling a symbolic representation of the music (e.g. MIDI files)

I'm mostly interested in modeling symbolic representations because they preserve editability post-generation. As a composer, you could go back and make small edits (e.g. shift the pitch of an individual note), repeat certain sections (e.g. copy and paste a set of notes), or easily change the instrument for a given track (e.g. turn the piano track into a trumpet track).

Another useful aspect of symbolic representations is that they're very data-efficient. For example, if we looked at the MAESTRO dataset, the raw audio representation would take 120GB of storage whereas the symbolic representation (MIDI) only takes 81MB of storage (perhaps the better comparison would be the number of tokens needed to represent each type of sequence, but this can vary depending on your tokenization strategy).

Project structure

This project is set up to encourage easy experimentation by providing a consistent training structure alongside interfaces for each of the axes along which I might want to experiment. The main interfaces are defined as:

  • dataset: a PyTorch Lightning data module which encapsulates training/validation datasets and dataloaders
  • network: the underlying torch.nn.Module architecture
  • model: a PyTorch Lightning module which defines training and validation steps for a given network
  • tokenizer: an object which can convert muspy.Music objects into a dictionary of tensors and vice versa
  • transforms: a set of functions which can modify a muspy.Music object to do things like cropping and transposing

These various objects are assembled in midi_lm/train.py according to the configuration composed by Hydra from a command-line invocation.
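To make the assembly concrete, here is a rough sketch of how such a Hydra entry point could instantiate the pieces. The function and config names below are illustrative only, not the actual contents of midi_lm/train.py.

import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig

@hydra.main(config_path="config", config_name="train", version_base=None)
def main(cfg: DictConfig) -> None:
    # each config group resolves to a dataclass that Hydra can instantiate
    tokenizer = instantiate(cfg.tokenizer)
    dataset = instantiate(cfg.dataset, tokenizer=tokenizer)   # LightningDataModule
    network = instantiate(cfg.network)                        # torch.nn.Module
    model = instantiate(cfg.model, network=network)           # LightningModule
    trainer = instantiate(cfg.trainer)                        # Lightning Trainer
    trainer.fit(model, datamodule=dataset)

if __name__ == "__main__":
    main()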

Supported datasets

I've added multiple different MIDI datasets of varying complexity (from basic scales to full orchestral pieces) along with networks of varying capacity so I can do some exploration of ideas at a small scale before ramping up the compute cost to train larger models on bigger datasets.

| Dataset | Description | Train (file count) | Validation (file count) |
| --- | --- | --- | --- |
| Eighth Notes | One measure of eighth notes for 12 different pitches | 6 | 6 |
| Scales | 12 major and 12 minor scales | 20 | 4 |
| JSB Chorales | A collection of 382 four-part chorales by Johann Sebastian Bach | 229 | 76 |
| NES | Songs from the soundtracks of 397 NES games | 4441* | 395* |
| MAESTRO | MIDI recordings from ten years of International Piano-e-Competition | 962 | 137 |
| SymphonyNet | A collection of classical and contemporary symphonic compositions | 37088 | 9272 |
| GiantMIDI Piano | A classical piano MIDI dataset of 2,786 different composers | 8682* | 2170* |

*Count after filtering out files from the original dataset which don't meet data quality thresholds for a minimum number of beats, tracks, etc.
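The kind of quality filter described above might look something like the sketch below. The thresholds and helper name are hypothetical, not the values used in this repo.

import muspy

# hypothetical thresholds, for illustration only
MIN_TRACKS = 1
MIN_BEATS = 16

def passes_quality_check(path: str) -> bool:
    music = muspy.read_midi(path)
    approx_beats = music.get_end_time() / music.resolution  # end time measured in beats
    return len(music.tracks) >= MIN_TRACKS and approx_beats >= MIN_BEATS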

You can see more information about these datasets in this Weights and Biases report.

Supported model architectures

I aim to support a mixture of reference implementations from the literature alongside various ideas that I'm exploring.

  • Multitrack Music Transformer (paper, code): a standard Transformer architecture which represents MIDI files as a sequence of (event_type, beat, position, pitch, duration, instrument) tokens. This model has 6 input/output heads, one for each different token type.
  • Multi-head transformer: a slightly more generic architecture which supports an arbitrary set of input/output heads that are merged and passed into a decoder-only Transformer model. This architecture can be used for other tokenization strategies, such as the (time-shift, pitch, duration) tokenizer for single-track MIDI files.
  • Structured transformer: a standard decoder-only Transformer architecture (only one input/output head) which enforces an explicit sampling structure during inference; the model expects tokens to appear in repeating (pitch, velocity, duration, time_shift) groups.
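One way to picture the structured sampling constraint is to mask the logits at each decoding step so that only tokens from the field expected at that position can be emitted. This is an illustrative sketch only; the vocabulary ranges below are made up and are not the repo's actual token layout.

import torch

FIELDS = ("pitch", "velocity", "duration", "time_shift")
# hypothetical vocabulary ranges per field, for illustration only
FIELD_RANGES = {
    "pitch": (0, 128),
    "velocity": (128, 160),
    "duration": (160, 260),
    "time_shift": (260, 360),
}

def mask_logits(logits: torch.Tensor, step: int) -> torch.Tensor:
    # keep only the logits for the field expected at this step; everything else gets -inf
    lo, hi = FIELD_RANGES[FIELDS[step % len(FIELDS)]]
    masked = torch.full_like(logits, float("-inf"))
    masked[..., lo:hi] = logits[..., lo:hi]
    return masked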

Local setup

Initialize the environment

make bootstrap

Install the requirements

make install

Training a model

Model training can be kicked off from the command line using Hydra, which provides a lot of control over how you compose configurations. These configurations are defined in midi_lm/config/ as dataclass objects, and each dataclass is given a short name in the "config store" defined in midi_lm/config/__init__.py.
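Registering a config option in Hydra's config store looks roughly like the following; the dataclass here is a stand-in for illustration, not the repo's actual optimizer config.

from dataclasses import dataclass
from hydra.core.config_store import ConfigStore

@dataclass
class AdamConfig:
    _target_: str = "torch.optim.Adam"  # class that Hydra will instantiate
    lr: float = 3e-4

cs = ConfigStore.instance()
cs.store(group="optimizer", name="adam", node=AdamConfig)  # selectable as optimizer=adam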

Example local runs:

train compute=local logger=wandb-test trainer=mps \
    tokenizer=mmt model=mmt network=mmt-1m \
    dataset=scales transforms=crop \
    trainer.max_epochs=20
train compute=local logger=wandb-test trainer=mps \
    tokenizer=mmt model=mmt network=mmt-1m \
    dataset=bach transforms=crop-transpose
train compute=local logger=wandb-test trainer=mps \
    tokenizer=mmt model=mmt network=mmt-1m \
    dataset=nes transforms=crop-transpose

Example remote runs:

train compute=a10g logger=wandb trainer=gpu \
    tokenizer=mmt model=mmt network=mmt-7m \
    dataset=maestro transforms=crop-transpose
train compute=a10g logger=wandb-test trainer=gpu \
    tokenizer=mmt model=mmt network=mmt-7m \
    dataset=nes transforms=crop-transpose \
    +trainer.profiler="simple" +trainer.max_steps=10

Hydra override syntax examples:

train optimizer.lr=0.0012                          # override the default for a value in the config dataclass
train optimizer=adamw +optimizer.amsgrad=True      # add a new value that wasn't tracked in the config dataclass

If you want to debug the configuration, add -c job to the end of your command. It will print the Hydra config instead of running the job.

You can see the available config options to choose from by running train --help. The config options are also shown below for convenience.

== Configuration groups ==
Compose your configuration from those groups (group=option)

collator: multi-seq-dict
compute: a100, a10g, cpu, h100, local
dataset: bach, giantmidi, maestro, nes, scales, symphony-net
logger: tensorboard, wandb, wandb-test
lr_scheduler: cosine, plateau
model: mmt, multihead-transformer, structured
network: mmt-1m, mmt-20m, mmt-7m, structured-1m, structured-20m, structured-7m, tpd-19m, tpd-1m, tpd-6m
optimizer: adam, adamw, sgd
tokenizer: mmt, structured, tpd
trainer: cpu, gpu, mps, smoke-test
transforms: crop, crop-transpose

Resuming a training run

You can also continue training from a checkpoint saved in Weights and Biases using the resume command. This command assumes that you have a checkpoint saved as a Weights and Biases artifact. It will load the configuration from the run associated with the provided artifact, and you can optionally provide a set of overrides from the command line. Anything after the -- (end-of-options delimiter) is treated as a Hydra override. You can use this to change certain configurations from the original run, such as training for more epochs.

resume user/project/model:tag -- trainer.max_epochs=20 dataset.batch_size=32

By default, all of the model and training states will be loaded from the checkpoint. This includes states such as the optimizer and learning rate schedulers, as well as the current epoch and global step for the trainer. If you wish to resume training with a new learning rate, for example, you must include the --clean flag. This will load the model weights from the specified checkpoint file, but all other states will be reset.

resume user/project/model:tag --clean -- trainer.max_epochs=20
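In PyTorch Lightning terms, the difference between a default resume and a --clean resume is roughly the sketch below. This illustrates the behavior described above and is not the repo's actual resume implementation.

import torch
from pytorch_lightning import Trainer, LightningModule, LightningDataModule

def resume_training(model: LightningModule, datamodule: LightningDataModule,
                    ckpt_path: str, clean: bool = False) -> None:
    trainer = Trainer(max_epochs=20)
    if clean:
        # --clean: load only the model weights; optimizer, scheduler, epoch, and step start fresh
        state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
        model.load_state_dict(state_dict)
        trainer.fit(model, datamodule=datamodule)
    else:
        # default: restore the full training state from the checkpoint
        trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)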
