In [None]:
# @title **(1)** Parse MIDI piano performances into simple lists of notes

# If True, download a pre-parsed copy of the dataset instead of parsing the MAESTRO MIDI files locally
USE_PRECACHED = True  # @param{type:"boolean"}
# MIDI pitch of the lowest piano key; subtracted from note pitches to get 0-based key indices
PIANO_LOWEST_KEY_MIDI_PITCH = 21
# Number of piano keys; valid key indices are 0 .. PIANO_NUM_KEYS - 1
PIANO_NUM_KEYS = 88

import gzip
import json
from collections import defaultdict

from tqdm.notebook import tqdm


def download_and_parse_maestro():
    # Install pretty_midi
    !!pip install pretty_midi
    import pretty_midi

    # Download MAESTRO dataset (Hawthorne+ 2018)
    !!wget -nc https://storage.googleapis.com/magentadata/datasets/maestro/v2.0.0/maestro-v2.0.0-midi.zip
    !!unzip maestro-v2.0.0-midi.zip

    # Parse MAESTRO dataset
    dataset = defaultdict(list)
    with open("maestro-v2.0.0/maestro-v2.0.0.json", "r") as f:
        for attrs in tqdm(json.load(f)):
            split = attrs["split"]
            midi = pretty_midi.PrettyMIDI("maestro-v2.0.0/" + attrs["midi_filename"])
            assert len(midi.instruments) == 1
            notes = [
                (
                    float(n.start),
                    float(n.end),
                    int(n.pitch - PIANO_LOWEST_KEY_MIDI_PITCH),
                    int(n.velocity),
                )
                for n in midi.instruments[0].notes
            ]
            notes = sorted(notes, key=lambda n: (n[0], n[2]))
            assert all(
                [
                    all(
                        [
                            # Start times should be non-negative
                            n[0] >= 0,
                            # Note duration should be positive
                            n[0] < n[1],
                            # Key index should be in range of the piano
                            0 <= n[2] and n[2] < PIANO_NUM_KEYS,
                            # Velocity should be valid
                            1 <= n[3] and n[3] < 128,
                        ]
                    )
                    for n in notes
                ]
            )
            dataset[split].append(notes)

        return dataset


if USE_PRECACHED:
    !!wget -nc https://github.com/chrisdonahue/music-cocreation-tutorial/raw/main/part-1-py-training/data/maestro-v2.0.0-simple.json.gz
    with gzip.open("maestro-v2.0.0-simple.json.gz", "rb") as f:
        DATASET = json.load(f)
else:
    DATASET = download_and_parse_maestro()
    with gzip.open("maestro-v2.0.0-simple.json.gz", "w") as f:
        f.write(json.dumps(DATASET).encode("utf-8"))

print([(s, len(DATASET[s])) for s in ["train", "validation", "test"]])

In [None]:
# @title **(2)** Define Piano Genie modules

import torch
import torch.nn as nn
import torch.nn.functional as F


class PianoGenieEncoder(nn.Module):
    """
    PianoGenieEncoder maps each performance onset (D features) to a single scalar.

    Features are projected to `rnn_dim`, run through a bidirectional LSTM (so each output
    can depend on the whole sequence), and projected to one scalar per onset.
    """

    def __init__(self, input_dim, rnn_dim=128, rnn_num_layers=2):
        super().__init__()
        self.rnn_dim = rnn_dim
        self.rnn_num_layers = rnn_num_layers
        # Submodule names (input / lstm / output) form the checkpoint keys; keep them stable.
        self.input = nn.Linear(input_dim, rnn_dim)
        self.lstm = nn.LSTM(
            rnn_dim,
            rnn_dim,
            rnn_num_layers,
            bias=True,
            batch_first=True,
            bidirectional=True,
        )
        # Forward and backward hidden states are concatenated -> 2 * rnn_dim features
        self.output = nn.Linear(rnn_dim * 2, 1)

    def forward(self, x):
        """Map a (batch, seq_len, input_dim) tensor to (batch, seq_len) scalars."""
        batch_size, _, _ = x.shape
        projected = self.input(x)
        # Explicit zero initial (h, c) for both directions of every layer
        state_shape = (2 * self.rnn_num_layers, batch_size, self.rnn_dim)
        initial_state = (
            torch.zeros(state_shape, device=projected.device),
            torch.zeros(state_shape, device=projected.device),
        )
        outputs, _ = self.lstm(projected, initial_state)
        return self.output(outputs).squeeze(-1)


class IntegerQuantizer(nn.Module):
    """
    IntegerQuantizer independently quantizes scalar values to K values between [-1, 1].

    Inputs are clamped to [-1, 1] and snapped to the nearest of `vocab_size` evenly spaced
    levels. Gradients pass through the rounding unchanged (straight-through estimator).
    """

    def __init__(self, vocab_size):
        super().__init__()
        if vocab_size < 2:
            # discrete_to_real divides by vocab_size - 1; a single level would produce NaN
            raise ValueError(f"vocab_size must be at least 2, got {vocab_size}")
        self.vocab_size = vocab_size

    def real_to_discrete(self, x, eps=1e-6):
        """Map reals to integer levels in [0, vocab_size - 1] (inputs outside [-1, 1] are clamped)."""
        x = torch.clamp((x + 1) / 2, 0, 1) * (self.vocab_size - 1)
        # eps guards the float -> long truncation after rounding
        return (torch.round(x) + eps).long()

    def discrete_to_real(self, x):
        """Map integer levels in [0, vocab_size - 1] to evenly spaced reals in [-1, 1].

        Uses only out-of-place ops: `x.float()` returns `x` itself when it is already
        float32, so an in-place division would overwrite the caller's tensor.
        """
        x = x.float() / (self.vocab_size - 1)
        return x * 2 - 1

    def forward(self, x, output_discrete=False):
        """Quantize `x` with a straight-through gradient.

        Returns the quantized reals, or `(quantized_reals, discrete_levels)` when
        `output_discrete` is True.
        """
        # Quantize and compute delta (used for straight-through estimator)
        with torch.no_grad():
            x_disc = self.real_to_discrete(x)
            x_quant_delta = self.discrete_to_real(x_disc) - x

        # Forward value equals the quantized value; the delta is a constant, so d(out)/dx = 1
        x = x + x_quant_delta

        if output_discrete:
            return x, x_disc
        return x


class PianoGenieDecoder(nn.Module):
    """
    PianoGenieDecoder maps buttons (quantized scalars) and onset features to key logits.

    Uses a unidirectional LSTM; `forward` accepts and returns the recurrent state so the
    decoder can be run incrementally by feeding the returned state back in.
    """

    def __init__(self, input_dim, output_dim, rnn_dim=128, rnn_num_layers=2):
        super().__init__()
        self.rnn_dim = rnn_dim
        self.rnn_num_layers = rnn_num_layers
        # Submodule names (input / lstm / output) form the checkpoint keys; keep them stable.
        self.input = nn.Linear(input_dim, rnn_dim)
        self.lstm = nn.LSTM(
            rnn_dim,
            rnn_dim,
            rnn_num_layers,
            bias=True,
            batch_first=True,
            bidirectional=False,
        )
        self.output = nn.Linear(rnn_dim, output_dim)

    def init_hidden(self, batch_size, device=None):
        """Return an all-zero (h, c) LSTM state of shape (rnn_num_layers, batch_size, rnn_dim)."""
        shape = (self.rnn_num_layers, batch_size, self.rnn_dim)
        return tuple(torch.zeros(shape, device=device) for _ in range(2))

    def forward(self, x, h=None):
        """Return (logits, new_state); `h=None` lets the LSTM start from a zero state."""
        features, next_state = self.lstm(self.input(x), h)
        return self.output(features), next_state


class PianoGenieAutoencoder(nn.Module):
    """
    PianoGenieAutoencoder composes encoder, quantizer, decoder.

    The encoder (optional, enabled by cfg["model_enc"]) maps each onset to a scalar that is
    quantized to one of cfg["num_buttons"] levels. The decoder predicts each key from the
    onset delta time, the previous key, and (if enabled) the quantized button value.
    """

    def __init__(self, cfg):
        super().__init__()
        use_encoder = cfg["model_enc"]
        rnn_kwargs = dict(
            rnn_dim=cfg["model_rnn_dim"],
            rnn_num_layers=cfg["model_rnn_num_layers"],
        )
        self.enc = None
        if use_encoder:
            # Encoder features: delta time + key one-hot
            self.enc = PianoGenieEncoder(1 + PIANO_NUM_KEYS, **rnn_kwargs)
            self.quant = IntegerQuantizer(cfg["num_buttons"])
        # Decoder features: delta time + previous-key one-hot (keys + <SOS>) + optional button
        self.dec = PianoGenieDecoder(
            1 + PIANO_NUM_KEYS + 1 + int(use_encoder),
            PIANO_NUM_KEYS,
            **rnn_kwargs,
        )

    def forward(self, onset_dts, onset_keys):
        """Return (key_logits, encoder_outputs, discrete_buttons).

        The last two are None when the encoder is disabled.
        """
        dts_feature = onset_dts.unsqueeze(dim=2)

        enc = enc_quant = enc_quant_disc = None
        if self.enc is not None:
            enc_inputs = torch.cat(
                [dts_feature, F.one_hot(onset_keys, PIANO_NUM_KEYS)], dim=2
            )
            enc = self.enc(enc_inputs)
            enc_quant, enc_quant_disc = self.quant(enc, output_discrete=True)

        # The decoder sees the ground-truth *previous* key; index PIANO_NUM_KEYS is <SOS> at t=0
        sos = torch.full_like(onset_keys[:, :1], PIANO_NUM_KEYS)
        prev_keys = torch.cat([sos, onset_keys[:, :-1]], dim=1)
        dec_inputs = [dts_feature, F.one_hot(prev_keys, PIANO_NUM_KEYS + 1)]
        if enc_quant is not None:
            # Feature: button from encoder
            dec_inputs.append(enc_quant.unsqueeze(dim=2))
        onset_key_logits, _ = self.dec(torch.cat(dec_inputs, dim=2))

        return onset_key_logits, enc, enc_quant_disc

In [None]:
# @title **(3)** Train Piano Genie

USE_WANDB = False  # @param{type:"boolean"}

# Training configuration (also written to cfg.json in the run directory)
CFG = dict(
    # Seed for random / numpy / torch; None skips seeding
    seed=0,
    # Number of buttons in interface (quantizer vocabulary size)
    num_buttons=8,
    # Onset delta times will be clipped to this maximum
    data_delta_time_max=1.0,
    # Max time stretch for data augmentation (+- 5%)
    data_augment_time_stretch_max=0.05,
    # Max transposition for data augmentation (+- tritone), in keys
    data_augment_transpose_max=6,
    # Enables encoder (when False, the decoder gets no button input)
    model_enc=True,
    # RNN dimensionality
    model_rnn_dim=128,
    # RNN num layers
    model_rnn_num_layers=2,
    # Training hyperparameters
    batch_size=32,
    # Notes per training / evaluation example
    seq_len=128,
    # Adam learning rate
    lr=3e-4,
    # Weights of the auxiliary encoder losses (a value <= 0 disables the term)
    loss_margin_multiplier=1.0,
    loss_contour_multiplier=1.0,
    # Log training metrics / run validation every N steps
    summarize_frequency=128,
    eval_frequency=128,
)

import pathlib
import random

import numpy as np

if USE_WANDB:
    try:
        import wandb
    except ModuleNotFoundError:
        !!pip install wandb
        import wandb

# Init
run_dir = pathlib.Path("piano_genie")
run_dir.mkdir(exist_ok=True)
with open(pathlib.Path(run_dir, "cfg.json"), "w") as f:
    f.write(json.dumps(CFG, indent=2))
if USE_WANDB:
    wandb.init(project="piano-genie", name="tutorial", config=CFG, reinit=True)

# Set seed
if CFG["seed"] is not None:
    random.seed(CFG["seed"])
    np.random.seed(CFG["seed"])
    torch.manual_seed(CFG["seed"])
    torch.cuda.manual_seed_all(CFG["seed"])

# Create models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PianoGenieAutoencoder(CFG)
model.train()
model.to(device)
print("-" * 80)
for n, p in model.named_parameters():
    print(f"{n}, {p.shape}")

# Create optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=CFG["lr"])

def performances_to_batch(performances, device, train=True):
    """Subsample fixed-length note windows from performances and featurize them as a minibatch.

    Args:
        performances: Sequence of performances, each a list of (start, end, key, velocity)
            note tuples with at least CFG["seq_len"] notes.
        device: Device the returned tensors are moved to (None leaves them on the CPU).
        train: If True, draw a random window and apply random time-stretch / transposition
            augmentation; otherwise take the first CFG["seq_len"] notes unchanged.

    Returns:
        (onset_dts, onset_keys): a float tensor of onset delta times clipped to
        [0, CFG["data_delta_time_max"]] (the first onset of each window gets the maximum),
        and a long tensor of key indices; both have one row of CFG["seq_len"] values per
        performance.
    """
    seq_len = CFG["seq_len"]
    batch_onset_dt = []
    batch_onset_keys = []
    for p in performances:
        # Subsample seq_len notes from performance
        assert len(p) >= seq_len
        if train:
            # randint is inclusive, so every valid window start -- including the last one,
            # and the only one when len(p) == seq_len -- can be drawn. (An exclusive
            # randrange(0, len(p) - seq_len) skipped it and raised ValueError at equality.)
            subsample_offset = random.randint(0, len(p) - seq_len)
        else:
            subsample_offset = 0
        subsample = p[subsample_offset : subsample_offset + seq_len]
        assert len(subsample) == seq_len

        # Data augmentation
        if train:
            stretch_factor = (
                1
                + (random.random() * CFG["data_augment_time_stretch_max"] * 2)
                - CFG["data_augment_time_stretch_max"]
            )
            transposition_factor = random.randint(
                -CFG["data_augment_transpose_max"], CFG["data_augment_transpose_max"]
            )
            subsample = [
                (
                    n[0] * stretch_factor,
                    n[1] * stretch_factor,
                    # Transposed keys that fall off the keyboard are clamped to the edge keys
                    max(0, min(n[2] + transposition_factor, PIANO_NUM_KEYS - 1)),
                    n[3],
                )
                for n in subsample
            ]

        # Compute onset features; the huge first delta is clipped to data_delta_time_max
        onset_dt = np.diff([n[0] for n in subsample])
        onset_dt = np.concatenate([[1e8], onset_dt])
        onset_dt = np.clip(onset_dt, 0, CFG["data_delta_time_max"])
        batch_onset_dt.append(onset_dt)

        # Compute onset keys
        batch_onset_keys.append([n[2] for n in subsample])

    onset_dts = torch.tensor(np.array(batch_onset_dt)).float()
    onset_keys = torch.tensor(np.array(batch_onset_keys)).long()
    if device is not None:
        onset_dts = onset_dts.to(device)
        onset_keys = onset_keys.to(device)
    return onset_dts, onset_keys


# Train. Evaluates on the validation split every CFG["eval_frequency"] steps (including step 0)
# and saves the checkpoint with the lowest validation reconstruction loss to run_dir/model.pt.
# Runs until interrupted, unless CFG contains an optional "max_steps" limit.
step = 0
best_eval_loss_recons = float("inf")
max_steps = CFG.get("max_steps")
while max_steps is None or step < max_steps:
    if step % CFG["eval_frequency"] == 0:
        model.eval()

        with torch.no_grad():
            eval_metrics = defaultdict(list)
            for i in range(0, len(DATASET["validation"]), CFG["batch_size"]):
                eval_batch = performances_to_batch(
                    DATASET["validation"][i : i + CFG["batch_size"]],
                    device,
                    train=False,
                )
                eval_onset_dts, eval_onset_keys = tuple(
                    tensor.to(device) for tensor in eval_batch
                )
                eval_onset_key_logits, _, _ = model(eval_onset_dts, eval_onset_keys)
                # Per-note losses, so the mean below weights every note equally even
                # when the last batch is smaller
                eval_loss_recons = F.cross_entropy(
                    eval_onset_key_logits.reshape(-1, PIANO_NUM_KEYS),
                    eval_onset_keys.reshape(-1),
                    reduction="none",
                )
                eval_metrics["loss_recons"].extend(
                    eval_loss_recons.cpu().numpy().tolist()
                )

            eval_loss_recons = float(np.mean(eval_metrics["loss_recons"]))
            if eval_loss_recons < best_eval_loss_recons:
                torch.save(model.state_dict(), pathlib.Path(run_dir, "model.pt"))
                best_eval_loss_recons = eval_loss_recons

        eval_metrics = {"eval_loss_recons": eval_loss_recons}
        if USE_WANDB:
            wandb.log(eval_metrics, step=step)
        print(step, "eval", eval_metrics)

        model.train()

    # Create minibatch
    batch = performances_to_batch(
        random.sample(DATASET["train"], CFG["batch_size"]), device, train=True
    )
    onset_dts, onset_keys = tuple(tensor.to(device) for tensor in batch)

    # Run model
    optimizer.zero_grad()
    onset_key_logits, onset_enc, _ = model(onset_dts, onset_keys)

    # Compute losses and update params
    loss_recons = F.cross_entropy(
        onset_key_logits.reshape(-1, PIANO_NUM_KEYS), onset_keys.reshape(-1)
    )
    loss_terms = {"loss_recons": loss_recons}
    # Build the total out-of-place: an in-place `loss += ...` on this alias would also
    # modify `loss_recons`, so its logged value would silently become the total loss.
    loss = loss_recons
    if onset_enc is not None:  # encoder losses only exist when CFG["model_enc"] is True
        # Margin: penalize encoder outputs outside [-1, 1], the range the quantizer clamps to
        loss_margin = torch.square(
            torch.maximum(torch.abs(onset_enc) - 1, torch.zeros_like(onset_enc))
        ).mean()
        # Contour: squared hinge on 1 - (key step * encoder step), pushing the encoder output
        # to move in the same direction as consecutive keys
        loss_contour = torch.square(
            torch.maximum(
                1 - torch.diff(onset_keys, dim=1) * torch.diff(onset_enc, dim=1),
                torch.zeros_like(onset_enc[:, 1:]),
            )
        ).mean()
        loss_terms["loss_margin"] = loss_margin
        loss_terms["loss_contour"] = loss_contour
        if CFG["loss_margin_multiplier"] > 0:
            loss = loss + CFG["loss_margin_multiplier"] * loss_margin
        if CFG["loss_contour_multiplier"] > 0:
            loss = loss + CFG["loss_contour_multiplier"] * loss_contour
    loss.backward()
    optimizer.step()
    step += 1

    if step % CFG["summarize_frequency"] == 0:
        metrics = {name: value.item() for name, value in loss_terms.items()}
        metrics["loss"] = loss.item()
        if USE_WANDB:
            wandb.log(metrics, step=step)
        print(step, "train", metrics)

In [None]:
# @title **(3)** Port trained decoder parameters to Tensorflow.js format

!!pip install tensorflowjs

from tensorflowjs.write_weights import write_weights

# Load saved model dict
d = torch.load("piano_genie/model.pt", map_location=torch.device("cpu"))
d = {k: v.numpy() for k, v in d.items()}

# Convert to tensorflow-js format
pathlib.Path("piano_genie/dec_tfjs").mkdir(exist_ok=True)
write_weights(
    [[{"name": k, "data": v} for k, v in d.items() if k.startswith("dec")]],
    "piano_genie/dec_tfjs",
)

In [None]:
# @title **(4)** Create test fixtures check correctness of JavaScript port

# Restore model from saved checkpoint
device = torch.device("cpu")
model = PianoGenieAutoencoder(CFG)
model.load_state_dict(torch.load("piano_genie/model.pt", map_location=device))
model.eval()
model.to(device)

# Serialize a batch of inputs/outputs as JSON
with torch.no_grad():
    input_dts, ground_truth_keys = performances_to_batch(
        [DATASET["validation"][0]], device, train=False
    )
    output_logits, _, input_buttons = model(input_dts, ground_truth_keys)

    input_dts = input_dts[0].cpu().numpy().tolist()
    ground_truth_keys = ground_truth_keys[0].cpu().numpy().tolist()
    input_keys = [PIANO_NUM_KEYS] + ground_truth_keys[:-1]
    input_buttons = input_buttons[0].cpu().numpy().tolist()
    output_logits = output_logits[0].cpu().numpy().tolist()

    fixtures = {
        n: eval(n)
        for n in ["input_dts", "input_keys", "input_buttons", "output_logits"]
    }
    with open(pathlib.Path("piano_genie", "fixtures.json"), "w") as f:
        f.write(json.dumps(fixtures))