# Music Co-creation Tutorial Part 1: Training a generative model of music
### [Chris Donahue](https://chrisdonahue.com), [Anna Huang](https://research.google/people/105787/), [Jon Gillick](https://www.jongillick.com/)

This is the first part of a two-part tutorial entitled [*Interactive music co-creation with PyTorch and TensorFlow.js*](https://github.com/chrisdonahue/music-cocreation-tutorial/), prepared as part of the ISMIR 2021 tutorial *Designing generative models for interactive co-creation*. This part of the tutorial will demonstrate how to **train a generative model of music in PyTorch**, and **port its weights to TensorFlow.js** format for interaction. The [final result is here](https://chrisdonahue.com/music-cocreation-tutorial)—see our [GitHub repo](https://github.com/chrisdonahue/music-cocreation-tutorial/) for part 2.

## Primer on Piano Genie

The generative model we will train is called [Piano Genie](https://magenta.tensorflow.org/pianogenie) (Donahue et al. 2019). Piano Genie is a system which maps amateur improvisations on a miniature 8-button keyboard ([video](https://www.youtube.com/watch?v=YRb0XAnUpIk), [demo](https://piano-genie.glitch.me)) into realistic performances on a full 88-key piano.

To achieve this, Piano Genie adopts an _autoencoder_ approach. First, an _encoder_ maps professional piano performances into this 8-button space. Then, a _decoder_ attempts to reconstruct the original piano performance from the 8-button version. The entire system is trained end-to-end to minimize the decoder's reconstruction error. At performance time, we replace the encoder with a user improvising on an 8-button controller, and use the pre-trained decoder to generate a corresponding piano performance.

<center><img src="https://i.imgur.com/pmYajEg.png" width=600px/></center>

At a low level, both the encoder and the decoder for Piano Genie are lightweight recurrent neural networks, which are suitable for real-time performance even on mobile CPUs. The discrete bottleneck is achieved using a technique called _integer-quantized autoencoding_ (IQAE), which was also proposed in the Piano Genie paper.

In [None]:
# @title **(Step 1)** Parse MIDI piano performances into simple lists of notes

# @markdown *Note*: Check this box to rebuild the dataset from scratch.
REBUILD_DATASET = False  # @param{type:"boolean"}

# @markdown To train Piano Genie, we will use a dataset of professional piano performances called [MAESTRO](https://magenta.tensorflow.org/datasets/maestro) (Hawthorne et al. 2019).
# @markdown Each performance in this dataset was captured by a Disklavier, a computerized piano which can record human performances in MIDI format, i.e., as timestamped sequences of notes.

PIANO_LOWEST_KEY_MIDI_PITCH = 21
PIANO_NUM_KEYS = 88

import gzip
import json
from collections import defaultdict

from tqdm.notebook import tqdm


def download_and_parse_maestro():
    # Install pretty_midi
    !!pip install pretty_midi
    import pretty_midi

    # Download MAESTRO dataset (Hawthorne et al. 2019)
    !!wget -nc https://storage.googleapis.com/magentadata/datasets/maestro/v2.0.0/maestro-v2.0.0-midi.zip
    !!unzip maestro-v2.0.0-midi.zip

    # Parse MAESTRO dataset
    dataset = defaultdict(list)
    with open("maestro-v2.0.0/maestro-v2.0.0.json", "r") as f:
        for attrs in tqdm(json.load(f)):
            split = attrs["split"]
            midi = pretty_midi.PrettyMIDI("maestro-v2.0.0/" + attrs["midi_filename"])
            assert len(midi.instruments) == 1
            # @markdown Formally, a piano performance is a sequence of notes: $\mathbf{x} = (x_1, \ldots, x_N)$, where each $x_i = (t_i, x^e_i, k_i, x^v_i)$, signifying:
            notes = [
                (
                    # @markdown 1. (When the key was pressed) An _onset_ time $t_i \in \mathbb{T}$, where $\mathbb{T} = \{ t \in \mathbb{R} \mid 0 \leq t \leq T \}$ 
                    float(n.start),
                    # @markdown 2. (When the key was released) An _offset_ time $x^e_i \in \mathbb{T}$
                    float(n.end),
                    # @markdown 3. (Which key was pressed) A _key_ index $k_i \in \mathbb{K}$, where $\mathbb{K} = \{\text{A0}, \ldots, \text{C8}\}$ and $|\mathbb{K}| = 88$
                    int(n.pitch - PIANO_LOWEST_KEY_MIDI_PITCH),
                    # @markdown 4. (How hard the key was pressed) A _velocity_ $x^v_i \in \mathbb{V}$, where $\mathbb{V} = \{1, \ldots, 127\}$
                    int(n.velocity),
                )
                for n in midi.instruments[0].notes
            ]

            # This list is in sorted order of onset time, i.e., $t_{i-1} \leq t_i ~\forall~i \in \{2, \ldots, N\}$.
            notes = sorted(notes, key=lambda n: (n[0], n[2]))
            assert all(
                [
                    all(
                        [
                            # Start times should be non-negative
                            n[0] >= 0,
                            # Note durations should be strictly positive, i.e., $t_i < x^e_i$
                            n[0] < n[1],
                            # Key index should be in range of the piano
                            0 <= n[2] and n[2] < PIANO_NUM_KEYS,
                            # Velocity should be valid
                            1 <= n[3] and n[3] < 128,
                        ]
                    )
                    for n in notes
                ]
            )
            dataset[split].append(notes)

        return dataset


if REBUILD_DATASET:
    DATASET = download_and_parse_maestro()
    with gzip.open("maestro-v2.0.0-simple.json.gz", "w") as f:
        f.write(json.dumps(DATASET).encode("utf-8"))
else:
    !!wget -nc https://github.com/chrisdonahue/music-cocreation-tutorial/raw/main/part-1-py-training/data/maestro-v2.0.0-simple.json.gz
    with gzip.open("maestro-v2.0.0-simple.json.gz", "rb") as f:
        DATASET = json.load(f)

print([(s, len(DATASET[s])) for s in ["train", "validation", "test"]])
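
To make the note representation from Step 1 concrete, here is a minimal sketch in plain Python that needs no MIDI parsing. The helper names `midi_pitch_to_key_index` and `is_valid_note` are hypothetical and not part of the tutorial code; they only mirror the pitch-to-key conversion and the assertions used above, and the two notes are made up for illustration.

```python
PIANO_LOWEST_KEY_MIDI_PITCH = 21  # MIDI pitch of A0, as in Step 1
PIANO_NUM_KEYS = 88


def midi_pitch_to_key_index(pitch):
    """Map a MIDI pitch to a 0-based piano key index (hypothetical helper)."""
    return pitch - PIANO_LOWEST_KEY_MIDI_PITCH


def is_valid_note(note):
    """Check the same conditions that Step 1 asserts for each note tuple."""
    onset, offset, key, velocity = note
    return (
        onset >= 0
        and onset < offset
        and 0 <= key < PIANO_NUM_KEYS
        and 1 <= velocity < 128
    )


# Two made-up notes: (onset seconds, offset seconds, key index, velocity)
notes = [
    (0.50, 0.90, midi_pitch_to_key_index(64), 72),  # E4
    (0.00, 0.45, midi_pitch_to_key_index(60), 80),  # middle C
]

# Sort by onset time, breaking ties by key index, as in Step 1
notes = sorted(notes, key=lambda n: (n[0], n[2]))

print(notes[0][2])  # key index of middle C
print(all(is_valid_note(n) for n in notes))
```

Each parsed performance in `DATASET` is a list of such 4-tuples (stored as lists once round-tripped through JSON).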

In [None]:
# @title **(Step 2)** Define Piano Genie autoencoder

# @markdown Our intended interaction for Piano Genie is to have users perform on a miniature 8-button keyboard.
# @markdown Similarly to how we formalized professional piano performances, we will represent these "button performances" as sequences of "notes", where we replace piano keys $k_i$ with buttons $b_i$, and we remove velocity since our controllers are not velocity-sensitive. So, to summarize, piano performances $\mathbf{x}$ and button performances $\mathbf{c}$ are defined as follows:

# @markdown - $\mathbf{x} = (x_1, \ldots, x_N)$, where $x_i = (t_i \in \mathbb{T}, x^e_i \in \mathbb{T}, k_i \in \mathbb{K}, x^v_i \in \mathbb{V})$, i.e., (onsets, offsets, keys, velocities)

# @markdown - $\mathbf{c} = (c_1, \ldots, c_M)$, where $c_i = (c^s_i \in \mathbb{T}, c^e_i \in \mathbb{T}, b_i \in \mathbb{B})$, i.e., (onsets, offsets, buttons), and $\mathbb{B} = \{ \color{#EE2B29}\blacksquare, \color{#ff9800}\blacksquare, \color{#ffff00}\blacksquare, \color{#c6ff00}\blacksquare, \color{#00e5ff}\blacksquare, \color{#2979ff}\blacksquare, \color{#651fff}\blacksquare, \color{#d500f9}\blacksquare \}$

# @markdown To map button performances into piano performances, we will train a generative model $P(\mathbf{x} \mid \mathbf{c})$.
# @markdown In practice, we will factorize this joint distribution over note sequences $\mathbf{x}$ into the product of conditional probabilities of individual notes: $P(\mathbf{x} \mid \mathbf{c}) = \prod_{i=1}^{N} P(x_i \mid \mathbf{x}_{< i}, \mathbf{c})$. 

# @markdown Hence, our **overall goal is to learn** $P(x_i \mid \mathbf{x}_{< i}, \mathbf{c})$, 
# @markdown which we will **approximate by modeling**:

# @markdown <center>$P(k_i \mid \mathbf{k}_{<i}, \mathbf{t}_{\leq i}, \mathbf{b}_{\leq i})$.</center>

# @markdown We arrived at this approximation by working through constraints imposed by the interaction (details at the end).

# @markdown #### **Decoder**

# @markdown <center><img src="https://i.imgur.com/K7ks74K.png" width=600px/></center>
# @markdown <center><b>Piano Genie decoder processing $N=4$ notes</b></center>

# @markdown The approximation $P(k_i \mid \mathbf{k}_{<i}, \mathbf{t}_{\leq i}, \mathbf{b}_{\leq i})$ constitutes the decoder of Piano Genie, which we will parameterize using an RNN.
# @markdown This is the portion of the model that users will interact with. 
# @markdown To achieve our intended real-time interaction, we will compute and sample from this RNN at the instant the user presses a button, passing as input the key from the previous timestep, the current time, the button the user pressed, and a vector which summarizes the ongoing history.

import torch
import torch.nn as nn
import torch.nn.functional as F

SOS = PIANO_NUM_KEYS

class PianoGenieDecoder(nn.Module):
    def __init__(
        self,
        rnn_dim=128,
        rnn_num_layers=2,
    ):
        super().__init__()
        self.rnn_dim = rnn_dim
        self.rnn_num_layers = rnn_num_layers
        self.input = nn.Linear(PIANO_NUM_KEYS + 3, rnn_dim)
        self.lstm = nn.LSTM(
            rnn_dim,
            rnn_dim,
            rnn_num_layers,
            bias=True,
            batch_first=True,
            bidirectional=False,
        )
        self.output = nn.Linear(rnn_dim, PIANO_NUM_KEYS)

    # @markdown Formally, the decoder is a function:
    # @markdown $D_{\theta}: k_{i-1}, t_i, b_i, \mathbf{h}_{i-1} \mapsto \mathbf{\hat{k}}_i, \mathbf{h}_i$, where:
    def forward(self, k, t, b):
        # @markdown - $k_0$ is a special start-of-sequence token $<\text{S}>$
        k_m1 = torch.cat([torch.full_like(k[:, :1], SOS), k[:, :-1]], dim=1)

        # @markdown - $\mathbf{h}_i$ is a vector summarizing timesteps $1, \ldots, i$

        # @markdown - $\mathbf{h}_0$ is some initial value (zeros) for that vector

        inputs = [
            F.one_hot(k_m1, PIANO_NUM_KEYS + 1),
            t.unsqueeze(dim=2),
            b.unsqueeze(dim=2),
        ]
        x = self.input(torch.cat(inputs, dim=2))
        # NOTE: PyTorch uses zeros automatically if h is None
        x, _ = self.lstm(x, None)

        # @markdown - $\mathbf{\hat{k}}_i \in \mathbb{R}^{88}$ are the output logits for timestep $i$
        hat_k = self.output(x)

        return hat_k


# @markdown #### **Encoder**

# @markdown <center><img src="https://i.imgur.com/ikNuSEH.png" width=600px/></center>
# @markdown <center><b>Piano Genie encoder processing $N=4$ notes</b></center>

# @markdown Because we lack examples of human button performances, we use an encoder to automatically learn to map piano performances into synthetic button performances.
# @markdown Our encoder is also an RNN, though it is bidirectional unlike the decoder. 
# @markdown This allows it to observe the entire piano performance before compressing it into buttons.

class PianoGenieEncoder(nn.Module):
    def __init__(
        self,
        rnn_dim=128,
        rnn_num_layers=2,
    ):
        super().__init__()
        self.rnn_dim = rnn_dim
        self.rnn_num_layers = rnn_num_layers
        self.input = nn.Linear(PIANO_NUM_KEYS + 1, rnn_dim)
        self.lstm = nn.LSTM(
            rnn_dim,
            rnn_dim,
            rnn_num_layers,
            bias=True,
            batch_first=True,
            bidirectional=True,
        )
        self.output = nn.Linear(rnn_dim * 2, 1)

    # @markdown Formally, the encoder is a function: $E_{\varphi} : k_i, t_i, \mathbf{h^f}_{i-1}, \mathbf{h^b}_{N-i} \mapsto b_i, \mathbf{h^f}_i, \mathbf{h^b}_{N - i + 1}$, where $\mathbf{h^f}$ / $\mathbf{h^b}$ are summary vectors in the forwards and backwards directions respectively, and $\mathbf{h^f}_0$ / $\mathbf{h^b}_0$ are their respective initial values.
    def forward(self, k, t):
        inputs = [
            F.one_hot(k, PIANO_NUM_KEYS),
            t.unsqueeze(dim=2),
        ]
        x = self.input(torch.cat(inputs, dim=2))
        # NOTE: PyTorch uses zeros automatically if h is None
        x, _ = self.lstm(x, None)
        x = self.output(x)
        return x[:, :, 0]


# @markdown #### **Quantizing encoder output to discrete buttons**

# @markdown <center><img src="https://i.imgur.com/pWuVHhf.png" width=600px/></center>
# @markdown <center><b>Quantizing continuous encoder output (grey line) to eight discrete values (colorful line segments)</b></center>

# @markdown You may have noticed that the encoder outputs a real-valued scalar (let's call it $e_i \in \mathbb{R}$) at each timestep, but our goal is to output one of eight discrete buttons, i.e., $b_i \in \mathbb{B}$. 
# @markdown To achieve this, we will quantize this real-valued scalar to the centroid of the nearest of eight bins spanning $[-1, 1]$ (see figure above):

# @markdown <center>$b_i = 2 \cdot \frac{\tilde{b}_i - 1}{B - 1} - 1$, where $\tilde{b}_i = \text{round} \left( 1 + (B - 1) \cdot \min \left( \max \left( \frac{e_i  + 1}{2}, 0 \right), 1 \right) \right)$</center>

class IntegerQuantizer(nn.Module):
    def __init__(self, num_bins):
        super().__init__()
        self.num_bins = num_bins

    def real_to_discrete(self, x, eps=1e-6):
        x = (x + 1) / 2
        x = torch.clamp(x, 0, 1)
        x *= self.num_bins - 1
        x = (torch.round(x) + eps).long()
        return x

    def discrete_to_real(self, x):
        x = x.float()
        x /= self.num_bins - 1
        x = (x * 2) - 1
        return x

    def forward(self, x):
        # Quantize and compute delta (used for straight-through estimator)
        with torch.no_grad():
            x_disc = self.real_to_discrete(x)
            x_quant = self.discrete_to_real(x_disc)
            x_quant_delta = x_quant - x

        # @markdown In the backwards pass, we will use the straight-through estimator (Bengio et al. 2013), i.e., pretend that this discretization did not happen when computing gradients.
        # Quantize w/ straight-through estimator
        x = x + x_quant_delta

        return x


# @markdown #### **Defining the autoencoder**

# @markdown Finally, the Piano Genie autoencoder is simply the composition of the encoder, quantizer, and decoder.

class PianoGenieAutoencoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.enc = PianoGenieEncoder(
            rnn_dim=cfg["model_rnn_dim"],
            rnn_num_layers=cfg["model_rnn_num_layers"],
        )
        self.quant = IntegerQuantizer(cfg["num_buttons"])
        self.dec = PianoGenieDecoder(
            rnn_dim=cfg["model_rnn_dim"],
            rnn_num_layers=cfg["model_rnn_num_layers"],
        )

    def forward(self, k, t):
        e = self.enc(k, t)
        b = self.quant(e)
        hat_k = self.dec(k, t, b)
        return hat_k, e


# @markdown #### **Approximating $P(x_i \mid \mathbf{x}_{< i}, \mathbf{c})$**

# @markdown This section walks through how we designed an approximation to $P(x_i \mid \mathbf{x}_{< i}, \mathbf{c})$ which would be appropriate for our intended interaction. You probably don't need to understand this, but some may find it helpful as an illustration of how to design a generative model around constraints imposed by interaction.

# @markdown First, we expand the terms:

# @markdown <center>$P(x_i \mid \mathbf{x}_{< i}, \mathbf{c}) = P(t_i, x^e_i, k_i, x^v_i \mid \mathbf{t}_{<i}, \mathbf{x^e}_{<i}, \mathbf{k}_{<i}, \mathbf{x^v}_{<i}, \mathbf{c^s}, \mathbf{c^e}, \mathbf{b})$</center>

# @markdown We think it might be intuitive for the miniature piano to behave like a real piano: pressing a button causes a note to sound, which is held until released. Hence, $N = M$, $t_i = c^s_i$, and $x^e_i = c^e_i$, so we can remove some redundant terms:

# @markdown <center>$= P(k_i, x^v_i \mid \mathbf{k}_{<i}, \mathbf{x^v}_{<i}, \mathbf{t}, \mathbf{c^e}, \mathbf{b})$</center>

# @markdown Because we want this interaction to be real-time, we must remove any term that might not be available at $t_i$, which includes future onsets $\mathbf{t}_{>i}$, future buttons $\mathbf{b}_{>i}$, and all offsets $\mathbf{c^e}$, since notes can be held indefinitely:

# @markdown <center>$\approx P(k_i, x^v_i \mid \mathbf{k}_{<i}, \mathbf{x^v}_{<i}, \mathbf{t}_{\leq i}, \mathbf{b}_{\leq i})$</center>

# @markdown Finally, we anticipate that it will be frustrating for users if the model predicts dynamics on their behalf, so we remove velocity terms $\mathbf{x^v}$:

# @markdown <center>$\approx P(k_i \mid \mathbf{k}_{<i}, \mathbf{t}_{\leq i}, \mathbf{b}_{\leq i})$</center>
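
The quantization rule from the "Quantizing encoder output to discrete buttons" part of Step 2 can be checked by hand with a scalar version in plain Python. This is only a sketch: the function names echo `IntegerQuantizer`, but it works on single floats rather than tensors and leaves out the straight-through gradient. Note that the formula in the text numbers buttons $1, \ldots, B$, while this sketch (like the PyTorch code) uses 0-based indices $0, \ldots, B - 1$.

```python
def real_to_discrete(e, num_bins=8):
    """Map a real-valued encoder output to a 0-based button index."""
    x = (e + 1) / 2            # rescale [-1, 1] to [0, 1]
    x = min(max(x, 0.0), 1.0)  # clamp out-of-range values
    return int(round(x * (num_bins - 1)))


def discrete_to_real(index, num_bins=8):
    """Map a 0-based button index back to its centroid in [-1, 1]."""
    return 2 * index / (num_bins - 1) - 1


# Illustrative encoder outputs, including one outside [-1, 1]
for e in [-1.7, -1.0, 0.1, 0.9, 1.0]:
    idx = real_to_discrete(e)
    print(e, idx, round(discrete_to_real(idx), 4))
```

Values outside $[-1, 1]$ collapse onto the first or last button, which is why the margin loss in Step 3 is useful.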

In [None]:
# @title **(Step 3)** Train Piano Genie

# @markdown *Note*: Check this box to log training curves to [Weights & Biases](https://wandb.ai/) (which will prompt you to log in).
USE_WANDB = False  # @param{type:"boolean"}

# @markdown Now that we've defined the autoencoder, we need to train it.
# @markdown We will train the entire autoencoder end-to-end to minimize the reconstruction loss of the decoder.

# @markdown <center>$\mathcal{L}_{\text{recons}} = \frac{1}{N} \sum_{i=1}^{N} \text{CrossEntropy}(\text{Softmax}(\mathbf{\hat{k}}_i), k_i)$</center>

# @markdown This loss alone does not encourage the encoder to produce button sequences with any particular structure, so the behavior of the decoder will likely be fairly unpredictable at interaction time.
# @markdown We think it might be intuitive to users if the decoder respected the _contour_ of their performance, i.e., if higher buttons produced higher notes and lower buttons produced lower notes.
# @markdown Hence, we include a loss term which encourages the encoder to produce button sequences which align with the contour of the piano key sequences.

# @markdown <center>$\mathcal{L}_{\text{contour}} = \frac{1}{N - 1} \sum_{i=2}^{N} \max (0, 1 - (k_i - k_{i-1}) \cdot (e_i - e_{i-1}))^2$</center>

# @markdown Finally, we find empirically that the encoder often outputs values outside of the $[-1, 1]$ range used for discretization. 
# @markdown Hence, we add a loss term which explicitly discourages this behavior:

# @markdown <center>$\mathcal{L}_{\text{margin}} = \frac{1}{N} \sum_{i=1}^{N} \max(0, |e_i| - 1)^2$</center>

# @markdown Thus, our final loss function is:
# @markdown <center>$\mathcal{L} = \mathcal{L}_{\text{recons}} + \mathcal{L}_{\text{contour}} + \mathcal{L}_{\text{margin}}$</center>


CFG = {
    "seed": 0,
    # Number of buttons in interface
    "num_buttons": 8,
    # Onset delta times will be clipped to this maximum
    "data_delta_time_max": 1.0,
    # Max time stretch for data augmentation (+- 5%)
    "data_augment_time_stretch_max": 0.05,
    # Max transposition for data augmentation (+- tritone)
    "data_augment_transpose_max": 6,
    # RNN dimensionality
    "model_rnn_dim": 128,
    # RNN num layers
    "model_rnn_num_layers": 2,
    # Training hyperparameters
    "batch_size": 32,
    "seq_len": 128,
    "lr": 3e-4,
    "loss_margin_multiplier": 1.0,
    "loss_contour_multiplier": 1.0,
    "summarize_frequency": 128,
    "eval_frequency": 128,
    "max_num_steps": 50000
}

import pathlib
import random

import numpy as np

if USE_WANDB:
    try:
        import wandb
    except ModuleNotFoundError:
        !!pip install wandb
        import wandb

# Init
run_dir = pathlib.Path("piano_genie")
run_dir.mkdir(exist_ok=True)
with open(pathlib.Path(run_dir, "cfg.json"), "w") as f:
    f.write(json.dumps(CFG, indent=2))
if USE_WANDB:
    wandb.init(project="music-cocreation-tutorial", config=CFG, reinit=True)

# Set seed
if CFG["seed"] is not None:
    random.seed(CFG["seed"])
    np.random.seed(CFG["seed"])
    torch.manual_seed(CFG["seed"])
    torch.cuda.manual_seed_all(CFG["seed"])

# Create model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PianoGenieAutoencoder(CFG)
model.train()
model.to(device)
print("-" * 80)
for n, p in model.named_parameters():
    print(f"{n}, {p.shape}")

# Create optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=CFG["lr"])

# Subsamples performances to create a minibatch
def performances_to_batch(performances, device, train=True):
    batch_k = []
    batch_t = []
    for p in performances:
        # Subsample seq_len notes from performance
        assert len(p) >= CFG["seq_len"]
        if train:
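            # NOTE: +1 makes the last valid offset reachable and avoids an
            # empty range when len(p) == seq_len
            subsample_offset = random.randrange(0, len(p) - CFG["seq_len"] + 1)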
        else:
            subsample_offset = 0
        subsample = p[subsample_offset : subsample_offset + CFG["seq_len"]]
        assert len(subsample) == CFG["seq_len"]

        # Data augmentation
        if train:
            stretch_factor = random.random() * CFG["data_augment_time_stretch_max"] * 2
            stretch_factor += 1 - CFG["data_augment_time_stretch_max"]
            transposition_factor = random.randint(
                -CFG["data_augment_transpose_max"], CFG["data_augment_transpose_max"]
            )
            subsample = [
                (
                    n[0] * stretch_factor,
                    n[1] * stretch_factor,
                    max(0, min(n[2] + transposition_factor, PIANO_NUM_KEYS - 1)),
                    n[3],
                )
                for n in subsample
            ]
        
        # Key features
        batch_k.append([n[2] for n in subsample])

        # Onset features
        # NOTE: For stability, we pass delta time to Piano Genie instead of time.
        t = np.diff([n[0] for n in subsample])
        t = np.concatenate([[1e8], t])
        t = np.clip(t, 0, CFG["data_delta_time_max"])
        batch_t.append(t)

    return (torch.tensor(batch_k).long(), torch.tensor(batch_t).float())


# Train
step = 0
best_eval_loss = float("inf")
while CFG["max_num_steps"] is None or step < CFG["max_num_steps"]:
    if step % CFG["eval_frequency"] == 0:
        model.eval()

        with torch.no_grad():
            eval_losses = []
            for i in range(0, len(DATASET["validation"]), CFG["batch_size"]):
                eval_batch = performances_to_batch(
                    DATASET["validation"][i : i + CFG["batch_size"]],
                    device,
                    train=False,
                )
                eval_k, eval_t = tuple(t.to(device) for t in eval_batch)
                eval_hat_k, _ = model(eval_k, eval_t)
                eval_loss = F.cross_entropy(
                    eval_hat_k.view(-1, PIANO_NUM_KEYS),
                    eval_k.view(-1),
                    reduction="none",
                )
                eval_losses.extend(eval_loss.cpu().numpy().tolist())

            eval_loss = np.mean(eval_losses)
            if eval_loss < best_eval_loss:
                torch.save(model.state_dict(), pathlib.Path(run_dir, "model.pt"))
                best_eval_loss = eval_loss

        eval_metrics = {"eval_loss_recons": eval_loss}
        if USE_WANDB:
            wandb.log(eval_metrics, step=step)
        print(step, "eval", eval_metrics)

        model.train()

    # Create minibatch
    batch = performances_to_batch(
        random.sample(DATASET["train"], CFG["batch_size"]), device, train=True
    )
    k, t = tuple(t.to(device) for t in batch)

    # Run model
    optimizer.zero_grad()
    k_hat, e = model(k, t)

    # Compute losses and update params
    loss_recons = F.cross_entropy(k_hat.view(-1, PIANO_NUM_KEYS), k.view(-1))
    loss_margin = torch.square(
        torch.maximum(torch.abs(e) - 1, torch.zeros_like(e))
    ).mean()
    loss_contour = torch.square(
        torch.maximum(
            1 - torch.diff(k, dim=1) * torch.diff(e, dim=1),
            torch.zeros_like(e[:, 1:]),
        )
    ).mean()
    loss = loss_recons
    if CFG["loss_margin_multiplier"] > 0:
        loss += CFG["loss_margin_multiplier"] * loss_margin
    if CFG["loss_contour_multiplier"] > 0:
        loss += CFG["loss_contour_multiplier"] * loss_contour
    loss.backward()
    optimizer.step()
    step += 1

    if step % CFG["summarize_frequency"] == 0:
        metrics = {
            "loss_recons": loss_recons.item(),
            "loss_margin": loss_margin.item(),
            "loss_contour": loss_contour.item(),
            "loss": loss.item(),
        }
        if USE_WANDB:
            wandb.log(metrics, step=step)
        print(step, "train", metrics)

# Download the trained model so we don't lose it!
from google.colab import files

files.download('piano_genie/model.pt')
files.download('piano_genie/cfg.json')
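
As a sanity check on the auxiliary losses and the delta-time features from Step 3, the sketch below recomputes them for a single short sequence in plain Python. The function names are hypothetical; the training loop above does the same arithmetic with tensors and averages over the whole batch rather than one sequence.

```python
def contour_loss(keys, enc_outputs):
    """Mean over i >= 2 of max(0, 1 - (k_i - k_{i-1}) * (e_i - e_{i-1}))^2."""
    terms = []
    for i in range(1, len(keys)):
        dk = keys[i] - keys[i - 1]
        de = enc_outputs[i] - enc_outputs[i - 1]
        terms.append(max(0.0, 1.0 - dk * de) ** 2)
    return sum(terms) / len(terms)


def margin_loss(enc_outputs):
    """Mean of max(0, |e_i| - 1)^2, penalizing outputs outside [-1, 1]."""
    terms = [max(0.0, abs(e) - 1.0) ** 2 for e in enc_outputs]
    return sum(terms) / len(terms)


def onset_deltas(onsets, delta_max=1.0):
    """Delta-time features: the first step gets delta_max, later gaps are clipped."""
    deltas = [delta_max]
    for i in range(1, len(onsets)):
        gap = onsets[i] - onsets[i - 1]
        deltas.append(min(max(gap, 0.0), delta_max))
    return deltas


# Keys go up then down, but the encoder output keeps rising:
# only the second step violates the contour and is penalized.
print(contour_loss([40, 45, 43], [-0.5, 0.0, 0.5]))
print(margin_loss([-1.5, 0.2, 1.0]))
print(onset_deltas([0.0, 0.5, 0.75, 3.0]))
```

In the first call, the upward step contributes nothing because the key difference and encoder difference agree in sign with a product above 1, while the downward step contributes $(1 + 1)^2 = 4$.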

In [None]:
# @title **(Step 4)** Port trained decoder parameters to TensorFlow.js format

# @markdown In this step, we will use the TensorFlow.js Python library to export our model's parameters in a binary format, to be loaded later by the JavaScript client.

!!pip install tensorflowjs

from tensorflowjs.write_weights import write_weights

# Load saved model dict
d = torch.load("piano_genie/model.pt", map_location=torch.device("cpu"))
d = {k: v.numpy() for k, v in d.items()}

# Convert to tensorflow-js format
pathlib.Path("piano_genie/dec_tfjs").mkdir(exist_ok=True)
write_weights(
    [[{"name": k, "data": v} for k, v in d.items() if k.startswith("dec")]],
    "piano_genie/dec_tfjs",
)
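
One common source of porting bugs is parameter layout. PyTorch's `nn.Linear` stores its weight with shape `(out_features, in_features)` and computes `x @ W.T + b`, whereas a Keras-style dense kernel has shape `(in_features, out_features)` and computes `x @ K + b`. The sketch below uses plain Python lists and hypothetical helper names to show that the two agree only when the kernel is the transpose of the PyTorch weight. It illustrates the general issue and is not a description of how the tutorial's JavaScript code handles the exported weights.

```python
def transpose(matrix):
    """Swap rows and columns of a nested-list matrix."""
    return [list(row) for row in zip(*matrix)]


def linear_out_in(weight, x, bias):
    """PyTorch-style layout: weight is (out, in)."""
    return [
        sum(w * xi for w, xi in zip(row, x)) + b
        for row, b in zip(weight, bias)
    ]


def linear_in_out(kernel, x, bias):
    """Keras-style layout: kernel is (in, out)."""
    out_dim = len(kernel[0])
    return [
        sum(x[i] * kernel[i][o] for i in range(len(x))) + bias[o]
        for o in range(out_dim)
    ]


weight = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # (out=2, in=3), made-up values
bias = [0.5, -0.5]
x = [1.0, 0.0, -1.0]

kernel = transpose(weight)  # (in=3, out=2)
print(linear_out_in(weight, x, bias))
print(linear_in_out(kernel, x, bias))
```

Both calls print the same vector; feeding the untransposed weight to the second function would either fail or silently give wrong results when the matrix happens to be square.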

In [None]:
# @title **(Step 5)** Create test fixtures to check correctness of the JavaScript port

# @markdown Finally, we will serialize a sequence of inputs to and outputs from our trained model to create a test case for our JavaScript reimplementation.
# @markdown This is critically important—I have ported many models from Python to JavaScript and have yet to get it right on the first try.
# @markdown Porting models from PyTorch to TensorFlow.js is additionally tricky because parameters of the same shape are often used differently by the two APIs.

# Restore model from saved checkpoint
device = torch.device("cpu")
with open("piano_genie/cfg.json", "r") as f:
    cfg = json.load(f)
model = PianoGenieAutoencoder(cfg)
model.load_state_dict(torch.load("piano_genie/model.pt", map_location=device))
model.eval()
model.to(device)

# Serialize a batch of inputs/outputs as JSON
with torch.no_grad():
    # NOTE: performances_to_batch returns (keys, delta times), in that order
    ground_truth_keys, input_dts = performances_to_batch(
        [DATASET["validation"][0]], device, train=False
    )
    # The autoencoder returns (logits, continuous encoder output)
    output_logits, e = model(ground_truth_keys, input_dts)
    # Re-quantize the encoder output to recover the button values the decoder saw
    input_buttons = model.quant(e)

    input_dts = input_dts[0].cpu().numpy().tolist()
    ground_truth_keys = ground_truth_keys[0].cpu().numpy().tolist()
    # The decoder input is the previous key, starting with the SOS token
    input_keys = [PIANO_NUM_KEYS] + ground_truth_keys[:-1]
    input_buttons = input_buttons[0].cpu().numpy().tolist()
    output_logits = output_logits[0].cpu().numpy().tolist()

    fixtures = {
        "input_dts": input_dts,
        "input_keys": input_keys,
        "input_buttons": input_buttons,
        "output_logits": output_logits,
    }
    with open(pathlib.Path("piano_genie", "fixtures.json"), "w") as f:
        f.write(json.dumps(fixtures))
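
The fixtures are meant to be compared against the output of the ported model. The real check lives in the JavaScript client (part 2), but the idea can be sketched in Python: feed the serialized inputs to the reimplementation and require its logits to match the serialized ones within a tolerance. Everything below is illustrative; `max_abs_diff`, `logits_match`, the tolerance of `1e-4`, and the stand-in numbers are assumptions, not part of the tutorial.

```python
import json


def max_abs_diff(a, b):
    """Largest absolute elementwise difference between equally shaped nested lists."""
    if isinstance(a, list):
        assert len(a) == len(b), "shape mismatch"
        return max((max_abs_diff(x, y) for x, y in zip(a, b)), default=0.0)
    return abs(a - b)


def logits_match(expected, actual, atol=1e-4):
    """True if every logit agrees within the absolute tolerance `atol`."""
    return max_abs_diff(expected, actual) <= atol


# Stand-in for json.load() on fixtures.json; real logits have 88 values per step
fixtures = json.loads('{"output_logits": [[0.1, -2.0], [3.5, 0.0]]}')
expected = fixtures["output_logits"]

# Stand-in for the logits produced by a ported model on the same inputs
ported = [[0.10001, -2.0], [3.5, 0.00002]]

print(logits_match(expected, ported))
```

Small discrepancies are expected from differing floating-point implementations, so an exact equality check would be too strict.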