Generative modeling of MIDI files
When it comes to modeling music, there are two general approaches:
- modeling the raw audio waveforms (AudioLM, MusicGen)
- modeling the symbolic representations (Music Transformer, Multitrack Music Transformer, SymphonyNet, MusicVAE)
I'm mostly interested in modeling symbolic representations because doing so preserves editability post-generation. As a composer, you could go back and make small edits (e.g. shift the pitch of an individual note), repeat certain sections (e.g. copy and paste a set of notes), or easily modify the instrument for a given track (e.g. turn the piano track into a trumpet track).
Another useful aspect of symbolic representations is that they're very data-efficient. For example, if we looked at the MAESTRO dataset, the raw audio representation would take 120GB of storage whereas the symbolic representation (MIDI) only takes 81MB of storage (perhaps the better comparison would be the number of tokens needed to represent each type of sequence, but this can vary depending on your tokenization strategy).
This project is set up to encourage easy experimentation by providing a consistent training structure alongside interfaces for each of the axes along which I might want to experiment. The main interfaces are defined as:
- dataset: a PyTorch Lightning data module which encapsulates training/validation datasets and dataloaders
- network: the underlying `torch.nn.Module` architecture
- model: a PyTorch Lightning module which defines training and validation steps for a given network
- tokenizer: an object which can convert `muspy.Music` objects into a dictionary of tensors and vice versa
- transforms: a set of functions which can modify a `muspy.Music` object to do things like cropping and transposing

These various objects are assembled in `midi_lm/train.py` according to the configuration passed in by Hydra from a command line invocation.
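To make the tokenizer interface concrete, here is a minimal sketch of the shape such an object could take, using a single-track (time-shift, pitch, duration) scheme. The class and method names are illustrative assumptions, not the project's actual implementation:

```python
# Hypothetical sketch of the tokenizer interface; the class and method names
# are assumptions, not the actual midi_lm implementation.
import muspy
import torch


class TimeShiftPitchDurationTokenizer:
    """Convert a single-track muspy.Music to a dict of tensors and back."""

    def encode(self, music: muspy.Music) -> dict[str, torch.Tensor]:
        notes = sorted(music.tracks[0].notes, key=lambda note: note.time)
        time_shifts, pitches, durations = [], [], []
        previous_time = 0
        for note in notes:
            time_shifts.append(note.time - previous_time)  # delta from last onset
            pitches.append(note.pitch)
            durations.append(note.duration)
            previous_time = note.time
        return {
            "time_shift": torch.tensor(time_shifts, dtype=torch.long),
            "pitch": torch.tensor(pitches, dtype=torch.long),
            "duration": torch.tensor(durations, dtype=torch.long),
        }

    def decode(self, tokens: dict[str, torch.Tensor]) -> muspy.Music:
        notes, time = [], 0
        for shift, pitch, duration in zip(
            tokens["time_shift"].tolist(),
            tokens["pitch"].tolist(),
            tokens["duration"].tolist(),
        ):
            time += shift  # accumulate time shifts back into absolute onsets
            notes.append(muspy.Note(time=time, pitch=pitch, duration=duration, velocity=64))
        return muspy.Music(tracks=[muspy.Track(notes=notes)])
```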
I've added multiple different MIDI datasets of varying complexity (from basic scales to full orchestral pieces) along with networks of varying capacity so I can do some exploration of ideas at a small scale before ramping up the compute cost to train larger models on bigger datasets.
| Dataset | Description | Train (file count) | Validation (file count) |
| --- | --- | --- | --- |
| Eighth Notes | One measure of eighth notes for 12 different pitches | 6 | 6 |
| Scales | 12 major and 12 minor scales | 20 | 4 |
| JSB Chorales | A collection of 382 four-part chorales by Johann Sebastian Bach | 229 | 76 |
| NES | Songs from the soundtracks of 397 NES games | 4441* | 395* |
| MAESTRO | MIDI recordings from ten years of International Piano-e-Competition | 962 | 137 |
| SymphonyNet | A collection of classical and contemporary symphonic compositions | 37088 | 9272 |
| GiantMIDI Piano | A classical piano MIDI dataset of 2,786 different composers | 8682* | 2170* |
*Count after filtering out files from the original dataset which don't meet data quality thresholds for a minimum number of beats, tracks, etc.
You can see more information about these datasets in this Weights and Biases report.
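As a rough illustration of what that filtering could look like, the sketch below checks a file against minimum-size thresholds. The threshold values and the function name are invented for this example; the project's actual checks may differ:

```python
# Illustrative per-file quality filter; the thresholds here are assumptions,
# not the values the project actually uses.
import muspy

MIN_BEATS = 16   # assumed minimum length in beats
MIN_TRACKS = 1   # assumed minimum number of tracks
MIN_NOTES = 8    # assumed minimum number of notes


def passes_quality_thresholds(music: muspy.Music) -> bool:
    """Return True if the file is long and rich enough to train on."""
    n_beats = music.get_end_time() / music.resolution  # ticks / ticks-per-beat
    n_tracks = len(music.tracks)
    n_notes = sum(len(track.notes) for track in music.tracks)
    return n_beats >= MIN_BEATS and n_tracks >= MIN_TRACKS and n_notes >= MIN_NOTES
```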
I aim to support a mixture of reference implementations from the literature and various ideas that I'm exploring.
- Multitrack Music Transformer (paper, code): a standard Transformer architecture which represents MIDI files as a sequence of (event_type, beat, position, pitch, duration, instrument) tokens. This model has 6 input/output heads, one for each different token type.
- Multi-head transformer: a slightly more generic architecture which supports arbitrary input/output heads that are merged and passed into a decoder-only Transformer model. This architecture can be used for other tokenization strategies, such as the (time-shift, pitch, duration) tokenizer for single-track MIDI files.
- Structured transformer: a standard decoder-only Transformer architecture (only one input/output head) which enforces an explicit sampling structure during inference; the model expects tokens to appear in repeating sets of (pitch, velocity, duration, time_shift) tokens, as sketched below.
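To illustrate the structured transformer's sampling constraint, here is a hedged sketch of a decoding loop that masks the logits so only the expected token type can be sampled at each step. The contiguous vocabulary ranges and helper names are assumptions, not the project's actual layout:

```python
# Sketch of structure-enforced sampling: tokens must appear in repeating
# (pitch, velocity, duration, time_shift) groups, so at each step we mask out
# every vocabulary entry that is not of the expected type. The vocabulary
# partitioning below is an assumption, not the project's actual layout.
import torch

TYPE_RANGES = {
    "pitch": (0, 128),
    "velocity": (128, 256),
    "duration": (256, 512),
    "time_shift": (512, 768),
}
TYPE_ORDER = ["pitch", "velocity", "duration", "time_shift"]
VOCAB_SIZE = 768


def sample_structured(logits_fn, steps: int) -> list[int]:
    """Sample `steps` tokens, forcing step t to emit type TYPE_ORDER[t % 4]."""
    tokens: list[int] = []
    for step in range(steps):
        logits = logits_fn(tokens)  # (VOCAB_SIZE,) logits for the next token
        low, high = TYPE_RANGES[TYPE_ORDER[step % len(TYPE_ORDER)]]
        mask = torch.full_like(logits, float("-inf"))
        mask[low:high] = 0.0  # only the expected token type stays sampleable
        probs = torch.softmax(logits + mask, dim=-1)
        tokens.append(int(torch.multinomial(probs, num_samples=1)))
    return tokens


if __name__ == "__main__":
    # Stand-in for a trained network: random logits over the vocabulary.
    print(sample_structured(lambda tokens: torch.randn(VOCAB_SIZE), steps=8))
```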
Initialize the environment:

```shell
make bootstrap
```

Install the requirements:

```shell
make install
```
Model training can be kicked off from the command line using Hydra, which provides a lot of control over how you compose configurations. These configurations are defined in `midi_lm/config/` as dataclass objects, which are given short names in the "config store" defined in `midi_lm/config/__init__.py`.
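For concreteness, registering a dataclass under a short name in the config store looks roughly like the sketch below. The `AdamWConfig` fields and the group/name strings are illustrative assumptions; the real definitions live in `midi_lm/config/`:

```python
# Illustrative sketch of registering a config dataclass with Hydra's config
# store; the dataclass fields and group/name strings here are assumptions.
from dataclasses import dataclass

from hydra.core.config_store import ConfigStore


@dataclass
class AdamWConfig:
    _target_: str = "torch.optim.AdamW"
    lr: float = 3e-4
    weight_decay: float = 0.01


cs = ConfigStore.instance()
# "optimizer" is the config group and "adamw" the short name used on the CLI,
# e.g. `train optimizer=adamw optimizer.lr=0.0012`.
cs.store(group="optimizer", name="adamw", node=AdamWConfig)
```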
Example local runs:

```shell
train compute=local logger=wandb-test trainer=mps \
  tokenizer=mmt model=mmt network=mmt-1m \
  dataset=scales transforms=crop \
  trainer.max_epochs=20

train compute=local logger=wandb-test trainer=mps \
  tokenizer=mmt model=mmt network=mmt-1m \
  dataset=bach transforms=crop-transpose

train compute=local logger=wandb-test trainer=mps \
  tokenizer=mmt model=mmt network=mmt-1m \
  dataset=nes transforms=crop-transpose
```
Example remote runs:

```shell
train compute=a10g logger=wandb trainer=gpu \
  tokenizer=mmt model=mmt network=mmt-7m \
  dataset=maestro transforms=crop-transpose

train compute=a10g logger=wandb-test trainer=gpu \
  tokenizer=mmt model=mmt network=mmt-7m \
  dataset=nes transforms=crop-transpose \
  +trainer.profiler="simple" +trainer.max_steps=10
```
Hydra override syntax examples:

```shell
train optimizer.lr=0.0012  # override the default for a value in the config dataclass
train optimizer=adamw +optimizer.amsgrad=True  # add a new value that wasn't tracked in the config dataclass
```
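To connect the override syntax to the training code: a `_target_`-style config like the one sketched earlier can be turned into a live object with `hydra.utils.instantiate`, which is why dotted overrides such as `optimizer.lr=0.0012` flow through to the constructed optimizer. Whether midi_lm wires things exactly this way is an assumption:

```python
# Illustrative: turning a (possibly overridden) config into a live object.
# Whether midi_lm uses hydra.utils.instantiate this way is an assumption.
import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Stands in for the config Hydra composes after `optimizer=adamw optimizer.lr=0.0012`.
cfg = OmegaConf.create({"_target_": "torch.optim.AdamW", "lr": 0.0012})

params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in for model.parameters()
optimizer = instantiate(cfg, params=params)
print(type(optimizer).__name__, optimizer.defaults["lr"])  # AdamW 0.0012
```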
If you want to debug the configuration, add `-c job` to the end of your command. It will print the Hydra config instead of running the job.

You can see the available config options to choose from by running `train --help`. The config options are also shown below for convenience.
```
== Configuration groups ==
Compose your configuration from those groups (group=option)

collator: multi-seq-dict
compute: a100, a10g, cpu, h100, local
dataset: bach, giantmidi, maestro, nes, scales, symphony-net
logger: tensorboard, wandb, wandb-test
lr_scheduler: cosine, plateau
model: mmt, multihead-transformer, structured
network: mmt-1m, mmt-20m, mmt-7m, structured-1m, structured-20m, structured-7m, tpd-19m, tpd-1m, tpd-6m
optimizer: adam, adamw, sgd
tokenizer: mmt, structured, tpd
trainer: cpu, gpu, mps, smoke-test
transforms: crop, crop-transpose
```
You can also continue training from a checkpoint saved in Weights and Biases using the `resume` command. This command assumes that you have a checkpoint saved as a Weights and Biases artifact. It loads the configuration from the run associated with the provided artifact and then optionally applies a set of overrides passed from the command line. Anything after the `--` (end-of-options delimiter) is treated as a Hydra override. You can use this to change certain configurations from the original run, such as training for more epochs.
```shell
resume user/project/model:tag -- trainer.max_epochs=20 dataset.batch_size=32
```
By default, all of the model and training states will be loaded from the checkpoint. This includes states such as the optimizer and learning rate schedulers, as well as the current epoch and global step for the trainer. If you wish to resume training with a new learning rate, for example, you must include the `--clean` flag. This will load the model weights from the specified checkpoint file, but all other states will be reset.
```shell
resume user/project/model:tag --clean -- trainer.max_epochs=20
```
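In PyTorch Lightning terms, the difference between the two modes is roughly the following. This is a schematic sketch with a toy module; the checkpoint path is a placeholder, and the real `resume` command also handles the Weights and Biases artifact download and Hydra config plumbing:

```python
# Schematic sketch of the two resume modes in PyTorch Lightning terms; the toy
# module and checkpoint path are placeholders.
import lightning.pytorch as pl  # or `import pytorch_lightning as pl` on older versions
import torch
from torch.utils.data import DataLoader, TensorDataset


class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


loader = DataLoader(TensorDataset(torch.randn(32, 4), torch.randn(32, 1)), batch_size=8)

# Default resume: restores weights AND optimizer/scheduler state, plus the
# trainer's current epoch and global step.
pl.Trainer(max_epochs=20).fit(TinyModule(), loader, ckpt_path="model.ckpt")

# `--clean` resume: load only the model weights; everything else starts fresh.
model = TinyModule.load_from_checkpoint("model.ckpt")
pl.Trainer(max_epochs=20).fit(model, loader)
```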