# Music Co-creation Tutorial Part 1: Training a generative model of music
### [Chris Donahue](https://chrisdonahue.com), [Anna Huang](https://research.google/people/105787/), [Jon Gillick](https://www.jongillick.com/)

This is the first part of a two-part tutorial entitled [*Interactive music co-creation with PyTorch and TensorFlow.js*](https://github.com/chrisdonahue/music-cocreation-tutorial/), prepared as part of the ISMIR 2021 tutorial *Designing generative models for interactive co-creation*. This part of the tutorial will demonstrate how to **train a generative model of music in PyTorch**, and **port its weights to TensorFlow.js** format for interaction. [See our GitHub repo](https://github.com/chrisdonahue/music-cocreation-tutorial/) for part 2.

## Primer on Piano Genie

The generative model we will train is called [Piano Genie](https://magenta.tensorflow.org/pianogenie) (Donahue et al. 2019). Piano Genie is a system which maps amateur improvisations on a miniature 8-button keyboard ([video](https://www.youtube.com/watch?v=YRb0XAnUpIk), [demo](https://piano-genie.glitch.me)) into realistic performances on a full 88-key piano.

To achieve this, Piano Genie adopts an _autoencoder_ approach. First, an _encoder_ maps professional piano performances into this 8-button space. Then, a _decoder_ attempts to reconstruct the original piano performance from the 8-button version. The entire system is trained end-to-end to minimize the decoder's reconstruction error. At performance time, we replace the encoder with a user improvising on an 8-button controller, and use the pre-trained decoder to generate a corresponding piano performance.

<center><img src="https://i.imgur.com/pmYajEg.png" width=600px/></center>

At a low level, both the encoder and the decoder for Piano Genie are lightweight recurrent neural networks (RNNs), which are fast enough for real-time performance even on mobile CPUs. The discrete bottleneck between them (i.e., the 8-button space) is achieved using a technique called _integer-quantized autoencoding_ (IQAE), which was also proposed in the Piano Genie paper.
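For intuition, here is a minimal pure-Python sketch of the integer-quantization arithmetic. It assumes, as the `IntegerQuantizer` defined in Step 2 does, that the encoder emits one real-valued scalar per note and that the 8 buttons correspond to 8 evenly spaced values in [-1, 1]. The function names are hypothetical, and the straight-through gradient trick used during training is omitted.

```python
def real_to_button(x: float, num_buttons: int = 8) -> int:
    """Map a real value to a button index in {0, ..., num_buttons - 1}."""
    # Rescale [-1, 1] to [0, 1], clamp out-of-range values, then scale to [0, num_buttons - 1]
    x = min(max((x + 1) / 2, 0.0), 1.0) * (num_buttons - 1)
    return int(round(x))


def button_to_real(b: int, num_buttons: int = 8) -> float:
    """Map a button index back to its quantized real value in [-1, 1]."""
    return (b / (num_buttons - 1)) * 2 - 1


# The 8 quantization levels, from the lowest button to the highest
levels = [button_to_real(b) for b in range(8)]
print([round(v, 3) for v in levels])

# Values outside [-1, 1] are clamped to the nearest extreme button
print(real_to_button(-1.0), real_to_button(0.95), real_to_button(3.0))
```

Because the levels are evenly spaced, nearby encoder outputs map to nearby buttons, which is what lets a contour in button space correspond to a contour on the piano.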

In [None]:
#@title **(Step 1)** Parse MIDI piano performances into simple lists of notes

USE_PRECACHED = True  # @param{type:"boolean"}

# @markdown To train Piano Genie, we will use a dataset of professional piano performances called [MAESTRO](https://magenta.tensorflow.org/datasets/maestro) (Hawthorne et al. 2019).
# @markdown Each performance in this dataset was captured by a Disklavier, a computerized piano which can record human performances in MIDI format, i.e., as timestamped sequences of notes.

PIANO_LOWEST_KEY_MIDI_PITCH = 21
PIANO_NUM_KEYS = 88

import gzip
import json
from collections import defaultdict

from tqdm.notebook import tqdm


def download_and_parse_maestro():
    # Install pretty_midi
    !!pip install pretty_midi
    import pretty_midi

    # Download MAESTRO dataset (Hawthorne et al. 2019)
    !!wget -nc https://storage.googleapis.com/magentadata/datasets/maestro/v2.0.0/maestro-v2.0.0-midi.zip
    !!unzip maestro-v2.0.0-midi.zip

    # Parse MAESTRO dataset
    dataset = defaultdict(list)
    with open("maestro-v2.0.0/maestro-v2.0.0.json", "r") as f:
        for attrs in tqdm(json.load(f)):
            split = attrs["split"]
            midi = pretty_midi.PrettyMIDI("maestro-v2.0.0/" + attrs["midi_filename"])
            assert len(midi.instruments) == 1
            # @markdown Formally, a piano performance is a sequence of notes: $\mathbf{x} = (x_1, \ldots, x_N)$, where each $x_i = (t_i, x^e_i, k_i, x^v_i)$, signifying:
            notes = [
                (
                    # @markdown 1. (When the key was pressed) An _onset_ time $t_i \in \mathbb{T}$, where $\mathbb{T} = \{ t \in \mathbb{R} \mid 0 \leq t \leq T \}$ 
                    float(n.start),
                    # @markdown 2. (When the key was released) An _offset_ time $x^e_i \in \mathbb{T}$
                    float(n.end),
                    # @markdown 3. (Which key was pressed) A _key_ index $k_i \in \mathbb{K}$, where $\mathbb{K} = \{\text{A0}, \ldots, \text{C8}\}$ and $|\mathbb{K}| = 88$
                    int(n.pitch - PIANO_LOWEST_KEY_MIDI_PITCH),
                    # @markdown 4. (How hard the key was pressed) A _velocity_ $x^v_i \in \mathbb{V}$, where $\mathbb{V} = \{1, \ldots, 127\}$
                    int(n.velocity),
                )
                for n in midi.instruments[0].notes
            ]

            # This list is in sorted order of onset time (ties broken by key index), i.e., $t_{i-1} \leq t_i ~\forall~i \in \{2, \ldots, N\}$.
            notes = sorted(notes, key=lambda n: (n[0], n[2]))
            assert all(
                [
                    all(
                        [
                            # Start times should be non-negative
                            n[0] >= 0,
                            # Note durations should be strictly positive, i.e., $t_i < x^e_i$
                            n[0] < n[1],
                            # Key index should be in range of the piano
                            0 <= n[2] and n[2] < PIANO_NUM_KEYS,
                            # Velocity should be valid
                            1 <= n[3] and n[3] < 128,
                        ]
                    )
                    for n in notes
                ]
            )
            dataset[split].append(notes)

        return dataset


if USE_PRECACHED:
    !!wget -nc https://github.com/chrisdonahue/music-cocreation-tutorial/raw/main/part-1-py-training/data/maestro-v2.0.0-simple.json.gz
    with gzip.open("maestro-v2.0.0-simple.json.gz", "rb") as f:
        DATASET = json.load(f)
else:
    DATASET = download_and_parse_maestro()
    with gzip.open("maestro-v2.0.0-simple.json.gz", "w") as f:
        f.write(json.dumps(DATASET).encode("utf-8"))

print([(s, len(DATASET[s])) for s in ["train", "validation", "test"]])
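As a small worked example of the note representation above, the sketch below converts MIDI pitches to key indices (A0, MIDI pitch 21, maps to key 0; C8, MIDI pitch 108, maps to key 87) and checks the same invariants that Step 1 asserts. The toy performance is made up for illustration and is not taken from MAESTRO.

```python
PIANO_LOWEST_KEY_MIDI_PITCH = 21
PIANO_NUM_KEYS = 88


def midi_pitch_to_key(pitch: int) -> int:
    """Convert a MIDI pitch (21 = A0, ..., 108 = C8) to a key index in [0, 88)."""
    return pitch - PIANO_LOWEST_KEY_MIDI_PITCH


def is_valid_note(note) -> bool:
    """Check the (onset, offset, key, velocity) invariants asserted in Step 1."""
    onset, offset, key, velocity = note
    return 0 <= onset < offset and 0 <= key < PIANO_NUM_KEYS and 1 <= velocity <= 127


# Hypothetical toy performance: a C major triad (C4, E4, G4) followed by a repeated G4.
# Each tuple is (onset seconds, offset seconds, key index, velocity).
performance = [
    (0.50, 1.00, midi_pitch_to_key(67), 70),  # G4, listed out of order on purpose
    (0.00, 0.40, midi_pitch_to_key(60), 80),  # C4 (middle C)
    (0.00, 0.40, midi_pitch_to_key(64), 75),  # E4
    (0.00, 0.40, midi_pitch_to_key(67), 72),  # G4
]
# Sort by onset time, breaking ties by key index, as in Step 1
performance = sorted(performance, key=lambda n: (n[0], n[2]))
assert all(is_valid_note(n) for n in performance)
print(performance)
```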

In [None]:
# @title **(Step 2)** Define Piano Genie autoencoder

# @markdown Our intended interaction for Piano Genie is to have users perform on a miniature 8-button keyboard.
# @markdown Similarly to how we formalized professional piano performances, we will represent these "button performances" as sequences of "notes", where we replace piano keys $k_i$ with buttons $b_i$, and we remove velocity since our controllers are not velocity-sensitive. So, to summarize, piano performances $\mathbf{x}$ and button performances $\mathbf{c}$ are defined as follows:

# @markdown - $\mathbf{x} = (x_1, \ldots, x_N)$, where $x_i = (t_i \in \mathbb{T}, x^e_i \in \mathbb{T}, k_i \in \mathbb{K}, x^v_i \in \mathbb{V})$, i.e., (onsets, offsets, keys, velocities)

# @markdown - $\mathbf{c} = (c_1, \ldots, c_M)$, where $c_i = (c^s_i \in \mathbb{T}, c^e_i \in \mathbb{T}, b_i \in \mathbb{B})$, i.e., (onsets, offsets, buttons), and $\mathbb{B} = \{ \color{#EE2B29}\blacksquare, \color{#ff9800}\blacksquare, \color{#ffff00}\blacksquare, \color{#c6ff00}\blacksquare, \color{#00e5ff}\blacksquare, \color{#2979ff}\blacksquare, \color{#651fff}\blacksquare, \color{#d500f9}\blacksquare \}$

# @markdown To map button performances into piano performances, we will train a generative model $P(\mathbf{x} \mid \mathbf{c})$.
# @markdown In practice, we will factorize this joint distribution over note sequences $\mathbf{x}$ into the product of conditional probabilities of individual notes: $P(\mathbf{x} \mid \mathbf{c}) = \prod_{i=1}^{N} P(x_i \mid \mathbf{x}_{< i}, \mathbf{c})$. 

# @markdown Hence, our **overall goal is to learn** $P(x_i \mid \mathbf{x}_{< i}, \mathbf{c})$, 
# @markdown which we will **approximate by modeling**:

# @markdown <center>$P(k_i \mid \mathbf{k}_{<i}, \mathbf{t}_{\leq i}, \mathbf{b}_{\leq i})$.</center>

# @markdown We arrived at this approximation by working through constraints imposed by the interaction (details at the end).
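The factorization above turns the likelihood of a whole sequence into a sum of per-note terms, $-\log P(\mathbf{k}) = -\sum_{i} \log P(k_i \mid \cdot)$. The sketch below computes this negative log-likelihood from per-step logits, assuming we already have the decoder's output logits at every step; the logits here are made up and use a 3-key vocabulary instead of 88. Averaged over notes, this is the quantity that Step 3 minimizes with `F.cross_entropy`.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def sequence_nll(step_logits: List[List[float]], targets: List[int]) -> float:
    """Sum over steps of -log P(target key at step i | everything before it)."""
    return -sum(log_softmax(logits)[k] for logits, k in zip(step_logits, targets))


# Hypothetical logits for 3 timesteps over a 3-key vocabulary
step_logits = [[2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]
targets = [0, 1, 2]
nll = sequence_nll(step_logits, targets)
print(nll, nll / len(targets))  # Total NLL, then mean per-note cross-entropy
```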

In [None]:
# @markdown #### Decoder

# @markdown <center><img src="https://i.imgur.com/phEiaJZ.png" width=600px/></center>

# @markdown The approximation $P(k_i \mid \mathbf{k}_{<i}, \mathbf{t}_{\leq i}, \mathbf{b}_{\leq i})$ constitutes the decoder of Piano Genie, which we will parameterize using an RNN.
# @markdown To achieve our real-time interaction, we will compute and sample from this RNN at the instant the user presses a button, passing as input the key from the previous timestep, the current time, the button the user pressed, and a vector which summarizes the ongoing history.

# At each timestep, the RNN receives the key from the previous timestep, the time elapsed since the previous onset, and the button pressed.

import torch
import torch.nn as nn
import torch.nn.functional as F

SOS = PIANO_NUM_KEYS

class PianoGenieDecoder(nn.Module):
    def __init__(
        self,
        # Previous key one-hot (88 keys + <S>) + delta time + button
        input_dim=PIANO_NUM_KEYS + 3,
        rnn_dim=128,
        rnn_num_layers=2,
    ):
        super().__init__()
        self.rnn_dim = rnn_dim
        self.rnn_num_layers = rnn_num_layers
        self.input = nn.Linear(input_dim, rnn_dim)
        self.lstm = nn.LSTM(
            rnn_dim,
            rnn_dim,
            rnn_num_layers,
            bias=True,
            batch_first=True,
            bidirectional=False,
        )
        self.output = nn.Linear(rnn_dim, PIANO_NUM_KEYS)

    # @markdown Formally, the decoder is a function:
    # @markdown $D_{\theta}: k_{i-1}, t_i, b_i, \mathbf{h}_{i-1} \mapsto \mathbf{\hat{k}}_i, \mathbf{h}_i$, where:
    def forward(self, k_im1, t_i, b_i=None, h_im1=None):
        # NOTE: For stability, t_i is the delta time since the previous onset
        # (clipped to a maximum in performances_to_batch), not an absolute time

        # @markdown - $\mathbf{h}_i$ is a vector summarizing timesteps $1, \ldots, i$

        # @markdown - $\mathbf{h}_0$ is some initial value (zeros) for that vector
        # NOTE: PyTorch's LSTM uses zeros automatically if h_im1 is None

        # @markdown - $k_0$ is a special start-of-sequence token $<\text{S}>$
        # NOTE: The caller is responsible for prepending SOS to k_im1

        inputs = [
            # k_{i-1}: one-hot over 88 keys + <S>
            F.one_hot(k_im1, PIANO_NUM_KEYS + 1).float(),
            # t_i: onset delta time
            t_i.unsqueeze(dim=2),
        ]
        if b_i is not None:
            # b_i: quantized button value in [-1, 1]
            inputs.append(b_i.unsqueeze(dim=2))
        x = self.input(torch.cat(inputs, dim=2))
        x, h_i = self.lstm(x, h_im1)
        # @markdown - $\mathbf{\hat{k}}_i \in \mathbb{R}^{88}$ are the output logits for timestep $i$
        k_hat_i = self.output(x)

        return k_hat_i, h_i

In [None]:
# @markdown #### Encoder

# @markdown Because we lack examples of human button performances, we use an encoder to automatically learn to map piano performances into synthetic button performances.
# @markdown Our encoder is also an RNN, though unlike the decoder, it is bidirectional.
# @markdown This allows it to observe the entire piano performance before compressing it into buttons.

# @markdown Formally, the encoder is a function: $E_{\varphi} : k_i, t_i, \mathbf{h^f}_{i-1}, \mathbf{h^b}_{i+1} \mapsto b_i$, where $\mathbf{h^f}_{i-1}$ summarizes timesteps $1, \ldots, i-1$ in the forward direction and $\mathbf{h^b}_{i+1}$ summarizes timesteps $i+1, \ldots, N$ in the backward direction.

class PianoGenieEncoder(nn.Module):
    def __init__(
        self,
        input_dim,
        rnn_dim=128,
        rnn_num_layers=2,
    ):
        super().__init__()
        self.rnn_dim = rnn_dim
        self.rnn_num_layers = rnn_num_layers
        self.input = nn.Linear(input_dim, rnn_dim)
        self.lstm = nn.LSTM(
            rnn_dim,
            rnn_dim,
            rnn_num_layers,
            bias=True,
            batch_first=True,
            bidirectional=True,
        )
        self.output = nn.Linear(rnn_dim * 2, 1)

    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        x = self.input(x)
        h = (
            torch.zeros(
                2 * self.rnn_num_layers, batch_size, self.rnn_dim, device=x.device
            ),
            torch.zeros(
                2 * self.rnn_num_layers, batch_size, self.rnn_dim, device=x.device
            ),
        )
        x, h = self.lstm(x, h)
        x = self.output(x)
        return x[:, :, 0]


class IntegerQuantizer(nn.Module):
    """
    IntegerQuantizer independently quantizes each scalar to the nearest of
    vocab_size evenly spaced values in [-1, 1].
    """
    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size

    def real_to_discrete(self, x, eps=1e-6):
        x = (x + 1) / 2
        x = torch.clamp(x, 0, 1)
        x *= self.vocab_size - 1
        x = (torch.round(x) + eps).long()
        return x

    def discrete_to_real(self, x):
        x = x.float()
        x /= self.vocab_size - 1
        x = (x * 2) - 1
        return x

    def forward(self, x, output_discrete=False):
        # Quantize and compute delta (used for straight-through estimator)
        with torch.no_grad():
            x_disc = self.real_to_discrete(x)
            x_quant = self.discrete_to_real(x_disc)
            x_quant_delta = x_quant - x

        # Quantize w/ straight-through estimator
        x = x + x_quant_delta

        result = x
        if output_discrete:
            result = (x, x_disc)
        return result


class PianoGenieAutoencoder(nn.Module):
    """
    PianoGenieAutoencoder composes encoder, quantizer, decoder.
    """
    def __init__(self, cfg):
        super().__init__()
        self.enc = None
        if cfg["model_enc"]:
            self.enc = PianoGenieEncoder(
                # Delta time + key one-hot
                1 + PIANO_NUM_KEYS,
                rnn_dim=cfg["model_rnn_dim"],
                rnn_num_layers=cfg["model_rnn_num_layers"],
            )
            self.quant = IntegerQuantizer(cfg["num_buttons"])
        self.dec = PianoGenieDecoder(
            # Key one-hot (incl. <SOS>) + delta time + button (if encoder enabled)
            input_dim=PIANO_NUM_KEYS + 1 + 1 + int(cfg["model_enc"]),
            rnn_dim=cfg["model_rnn_dim"],
            rnn_num_layers=cfg["model_rnn_num_layers"],
        )

    def forward(self, onset_dts, onset_keys):
        enc = None
        enc_quant = None
        enc_quant_disc = None
        if self.enc is not None:
            # Run encoder
            enc = self.enc(
                torch.cat(
                    [
                        # Feature: onset delta time
                        onset_dts.unsqueeze(dim=2),
                        # Feature: onset key index
                        F.one_hot(onset_keys, PIANO_NUM_KEYS).float(),
                    ],
                    dim=2,
                )
            )

            # Quantize
            enc_quant, enc_quant_disc = self.quant(enc, output_discrete=True)

        # Run decoder on the *last* onset key index (prepended with <SOS>)
        onset_keys_im1 = torch.cat(
            [torch.full_like(onset_keys[:, :1], SOS), onset_keys[:, :-1]], dim=1
        )
        onset_key_logits, _ = self.dec(onset_keys_im1, onset_dts, enc_quant)

        return onset_key_logits, enc, enc_quant_disc
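During training, the decoder runs over the whole sequence at once with *teacher forcing*: its key input at step $i$ is the ground-truth key $k_{i-1}$, with $<\text{S}>$ (index 88) at the first step. The pure-Python sketch below shows this shift, which `PianoGenieAutoencoder.forward` performs with `torch.cat` and which Step 5 reproduces when building `input_keys`. The helper name is hypothetical.

```python
from typing import List

PIANO_NUM_KEYS = 88
SOS = PIANO_NUM_KEYS  # Index 88 is reserved for the start-of-sequence token


def shift_right_with_sos(keys: List[int]) -> List[int]:
    """Decoder inputs k_0, ..., k_{N-1} for targets k_1, ..., k_N, where k_0 = SOS."""
    return [SOS] + keys[:-1]


targets = [39, 43, 46, 51]
inputs = shift_right_with_sos(targets)
print(list(zip(inputs, targets)))  # (input k_{i-1}, target k_i) pairs
```

At performance time there is no ground truth, so the key the decoder *sampled* at step $i-1$ takes the place of $k_{i-1}$.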




In [None]:
# @markdown #### Approximating $P(x_i \mid \mathbf{x}_{< i}, \mathbf{c})$

# @markdown This section walks through how we designed an approximation to $P(x_i \mid \mathbf{x}_{< i}, \mathbf{c})$ which would be appropriate for our intended interaction. You probably don't need to understand this, but some may find it helpful as an illustration of how to design a generative model around constraints imposed by interaction.

# @markdown First, we expand the terms:

# @markdown <center>$P(x_i \mid \mathbf{x}_{< i}, \mathbf{c}) = P(t_i, x^e_i, k_i, x^v_i \mid \mathbf{t}_{<i}, \mathbf{x^e}_{<i}, \mathbf{k}_{<i}, \mathbf{x^v}_{<i}, \mathbf{c^s}, \mathbf{c^e}, \mathbf{b})$</center>

# @markdown We think it might be intuitive for the miniature piano to behave like a real piano: pressing a button causes a note to sound, which is held until released. Hence, $N = M$, $t_i = c^s_i$, and $x^e_i = c^e_i$, so we can remove some redundant terms:

# @markdown <center>$= P(k_i, x^v_i \mid \mathbf{k}_{<i}, \mathbf{x^v}_{<i}, \mathbf{t}, \mathbf{c^e}, \mathbf{b})$</center>

# @markdown Because we want this interaction to be real-time, we must remove any term that might not be available at time $t_i$, which includes future onsets $\mathbf{t}_{>i}$, future buttons $\mathbf{b}_{>i}$, and all offsets $\mathbf{c^e}$, since notes can be held indefinitely:

# @markdown <center>$\approx P(k_i, x^v_i \mid \mathbf{k}_{<i}, \mathbf{x^v}_{<i}, \mathbf{t}_{\leq i}, \mathbf{b}_{\leq i})$</center>

# @markdown Finally, we anticipate that it will be frustrating for users if the model predicts dynamics on their behalf, so we remove velocity terms $\mathbf{x^v}$:

# @markdown <center>$\approx P(k_i \mid \mathbf{k}_{<i}, \mathbf{t}_{\leq i}, \mathbf{b}_{\leq i})$</center>
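To show how these real-time constraints shape inference, here is a minimal sketch of the interactive loop. It assumes a hypothetical `toy_decoder_step` as a stand-in for one step of $D_\theta$ (it is *not* the trained LSTM). On each button press, the loop computes the clipped delta time since the previous press, feeds $k_{i-1}$, $t_i$, $b_i$, and the carried state $\mathbf{h}_{i-1}$, then samples $k_i$ from the output logits. Only past and present inputs are used, and offsets never enter the computation.

```python
import math
import random
from typing import List, Optional, Tuple

NUM_KEYS = 88
SOS = NUM_KEYS  # Start-of-sequence token index
DELTA_TIME_MAX = 1.0  # Assumed to mirror CFG["data_delta_time_max"]


def toy_decoder_step(
    k_im1: int, dt: float, b: float, h: Optional[float]
) -> Tuple[List[float], float]:
    """Hypothetical stand-in for one decoder step: returns (logits, new state).

    The state is a running average of button values; the trained decoder
    instead carries LSTM state and also conditions on k_im1 and dt, which
    this toy ignores.
    """
    h = b if h is None else 0.5 * h + 0.5 * b
    # Favor keys whose position, rescaled to [-1, 1], is close to the state
    logits = [-10.0 * abs(2 * k / (NUM_KEYS - 1) - 1 - h) for k in range(NUM_KEYS)]
    return logits, h


def sample(logits: List[float], rng: random.Random) -> int:
    """Sample an index from softmax(logits)."""
    m = max(logits)
    weights = [math.exp(x - m) for x in logits]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


def perform(presses: List[Tuple[float, float]], seed: int = 0) -> List[int]:
    """Map (onset seconds, button value in [-1, 1]) presses to piano keys."""
    rng = random.Random(seed)
    k_im1, h, last_onset, keys = SOS, None, None, []
    for onset, b in presses:
        if last_onset is None:
            dt = DELTA_TIME_MAX  # No previous onset, so use the clipped maximum
        else:
            dt = min(onset - last_onset, DELTA_TIME_MAX)
        logits, h = toy_decoder_step(k_im1, dt, b, h)
        k_im1 = sample(logits, rng)  # k_i becomes k_{i-1} for the next press
        keys.append(k_im1)
        last_onset = onset
    return keys


print(perform([(0.0, -1.0), (0.25, 0.0), (0.5, 1.0)]))
```

With the real model, `toy_decoder_step` would be replaced by a single-timestep call to the decoder that passes the LSTM state returned by the previous call.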

In [None]:
# @title **(Step 3)** Train Piano Genie

USE_WANDB = False  # @param{type:"boolean"}

CFG = {
    "seed": 0,
    # Number of buttons in interface
    "num_buttons": 8,
    # Onset delta times will be clipped to this maximum
    "data_delta_time_max": 1.0,
    # Max time stretch for data augmentation (+- 5%)
    "data_augment_time_stretch_max": 0.05,
    # Max transposition for data augmentation (+- tritone)
    "data_augment_transpose_max": 6,
    # Enables encoder
    "model_enc": True,
    # RNN dimensionality
    "model_rnn_dim": 128,
    # RNN num layers
    "model_rnn_num_layers": 2,
    # Training hyperparameters
    "batch_size": 32,
    "seq_len": 128,
    "lr": 3e-4,
    "loss_margin_multiplier": 1.0,
    "loss_contour_multiplier": 1.0,
    "summarize_frequency": 128,
    "eval_frequency": 128,
}

import pathlib
import random

import numpy as np

if USE_WANDB:
    try:
        import wandb
    except ModuleNotFoundError:
        !!pip install wandb
        import wandb

# Init
run_dir = pathlib.Path("piano_genie")
run_dir.mkdir(exist_ok=True)
with open(pathlib.Path(run_dir, "cfg.json"), "w") as f:
    f.write(json.dumps(CFG, indent=2))
if USE_WANDB:
    wandb.init(project="piano-genie", name="tutorial", config=CFG, reinit=True)

# Set seed
if CFG["seed"] is not None:
    random.seed(CFG["seed"])
    np.random.seed(CFG["seed"])
    torch.manual_seed(CFG["seed"])
    torch.cuda.manual_seed_all(CFG["seed"])

# Create models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PianoGenieAutoencoder(CFG)
model.train()
model.to(device)
print("-" * 80)
for n, p in model.named_parameters():
    print(f"{n}, {p.shape}")

# Create optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=CFG["lr"])

# Subsamples performances to create a minibatch
def performances_to_batch(performances, device, train=True):
    batch_onset_dt = []
    batch_onset_keys = []
    for p in performances:
        # Subsample seq_len notes from performance
        assert len(p) >= CFG["seq_len"]
        if train:
            # randint is inclusive, so every valid window (including the last) can be chosen
            subsample_offset = random.randint(0, len(p) - CFG["seq_len"])
        else:
            subsample_offset = 0
        subsample = p[subsample_offset : subsample_offset + CFG["seq_len"]]
        assert len(subsample) == CFG["seq_len"]

        # Data augmentation
        if train:
            stretch_factor = (
                1
                + (random.random() * CFG["data_augment_time_stretch_max"] * 2)
                - CFG["data_augment_time_stretch_max"]
            )
            transposition_factor = random.randint(
                -CFG["data_augment_transpose_max"], CFG["data_augment_transpose_max"]
            )
            subsample = [
                (
                    n[0] * stretch_factor,
                    n[1] * stretch_factor,
                    max(0, min(n[2] + transposition_factor, PIANO_NUM_KEYS - 1)),
                    n[3],
                )
                for n in subsample
            ]

        # Compute onset features
        onset_dt = np.diff([n[0] for n in subsample])
        onset_dt = np.concatenate([[1e8], onset_dt])
        onset_dt = np.clip(onset_dt, 0, CFG["data_delta_time_max"])
        batch_onset_dt.append(onset_dt)

        # Compute onset keys
        batch_onset_keys.append([n[2] for n in subsample])

    return (
        torch.tensor(np.array(batch_onset_dt)).float(),
        torch.tensor(np.array(batch_onset_keys)).long(),
    )


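# Train indefinitely; interrupt this cell when satisfied. The checkpoint with the
# best validation reconstruction loss is saved to piano_genie/model.pt.
step = 0
best_eval_loss_recons = float("inf")
while True: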
    if step % CFG["eval_frequency"] == 0:
        model.eval()

        with torch.no_grad():
            eval_metrics = defaultdict(list)
            for i in range(0, len(DATASET["validation"]), CFG["batch_size"]):
                eval_batch = performances_to_batch(
                    DATASET["validation"][i : i + CFG["batch_size"]],
                    device,
                    train=False,
                )
                eval_onset_dts, eval_onset_keys = tuple(
                    tensor.to(device) for tensor in eval_batch
                )
                eval_onset_key_logits, _, _ = model(eval_onset_dts, eval_onset_keys)
                eval_loss_recons = F.cross_entropy(
                    eval_onset_key_logits.view(-1, PIANO_NUM_KEYS),
                    eval_onset_keys.view(-1),
                    reduction="none",
                )
                eval_metrics["loss_recons"].extend(
                    eval_loss_recons.cpu().numpy().tolist()
                )

            eval_loss_recons = np.mean(eval_metrics["loss_recons"])
            if eval_loss_recons < best_eval_loss_recons:
                torch.save(model.state_dict(), pathlib.Path(run_dir, "model.pt"))
                best_eval_loss_recons = eval_loss_recons

        eval_metrics = {"eval_loss_recons": eval_loss_recons}
        if USE_WANDB:
            wandb.log(eval_metrics, step=step)
        print(step, "eval", eval_metrics)

        model.train()

    # Create minibatch
    batch = performances_to_batch(
        random.sample(DATASET["train"], CFG["batch_size"]), device, train=True
    )
    onset_dts, onset_keys = tuple(tensor.to(device) for tensor in batch)

    # Run model
    optimizer.zero_grad()
    onset_key_logits, onset_enc, _ = model(onset_dts, onset_keys)

    # Compute losses and update params
    loss_recons = F.cross_entropy(
        onset_key_logits.view(-1, PIANO_NUM_KEYS), onset_keys.view(-1)
    )
    loss_margin = torch.square(
        torch.maximum(torch.abs(onset_enc) - 1, torch.zeros_like(onset_enc))
    ).mean()
    loss_contour = torch.square(
        torch.maximum(
            1 - torch.diff(onset_keys, dim=1) * torch.diff(onset_enc, dim=1),
            torch.zeros_like(onset_enc[:, 1:]),
        )
    ).mean()
    # Use out-of-place adds so that loss_recons (logged below) is not modified
    loss = loss_recons
    if CFG["loss_margin_multiplier"] > 0:
        loss = loss + CFG["loss_margin_multiplier"] * loss_margin
    if CFG["loss_contour_multiplier"] > 0:
        loss = loss + CFG["loss_contour_multiplier"] * loss_contour
    loss.backward()
    optimizer.step()
    step += 1

    if step % CFG["summarize_frequency"] == 0:
        metrics = {
            "loss_recons": loss_recons.item(),
            "loss_margin": loss_margin.item(),
            "loss_contour": loss_contour.item(),
            "loss": loss.item(),
        }
        if USE_WANDB:
            wandb.log(metrics, step=step)
        print(step, "train", metrics)
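To make the two auxiliary losses concrete, here is a pure-Python sketch on a single toy sequence, mirroring `loss_margin` and `loss_contour` above. The *margin* loss penalizes encoder outputs outside [-1, 1], where quantization would clip them. The *contour* loss is a squared hinge that is zero only when consecutive encoder outputs move in the same direction as consecutive keys, with a product of deltas of at least 1. The keys and encoder values are made up.

```python
from typing import List


def margin_loss(enc: List[float]) -> float:
    """Mean squared excess magnitude beyond 1."""
    return sum(max(abs(e) - 1.0, 0.0) ** 2 for e in enc) / len(enc)


def contour_loss(keys: List[int], enc: List[float]) -> float:
    """Mean squared hinge on the product of consecutive key and encoder deltas."""
    terms = [
        max(1.0 - (keys[i] - keys[i - 1]) * (enc[i] - enc[i - 1]), 0.0) ** 2
        for i in range(1, len(keys))
    ]
    return sum(terms) / len(terms)


keys = [39, 43, 46, 41]            # Up, up, down
enc_good = [-0.5, 0.0, 0.5, -0.5]  # Moves with the melody
enc_bad = [0.5, 0.0, -0.5, 1.5]    # Moves against it; last value out of range
print(contour_loss(keys, enc_good), contour_loss(keys, enc_bad))
print(margin_loss(enc_good), margin_loss(enc_bad))
```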

In [None]:
# @title **(Step 4)** Port trained decoder parameters to TensorFlow.js format

!!pip install tensorflowjs

from tensorflowjs.write_weights import write_weights

# Load saved model dict
d = torch.load("piano_genie/model.pt", map_location=torch.device("cpu"))
d = {k: v.numpy() for k, v in d.items()}

# Convert to tensorflow-js format
pathlib.Path("piano_genie/dec_tfjs").mkdir(exist_ok=True)
write_weights(
    [[{"name": k, "data": v} for k, v in d.items() if k.startswith("dec")]],
    "piano_genie/dec_tfjs",
)
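As a sanity check on the export, here is a sketch that reads a TensorFlow.js weights directory back and unpacks its float32 values. It assumes the standard layout that `write_weights` produces: a `weights_manifest.json` listing groups with `paths` and `weights` entries, whose binary shards concatenate into consecutive little-endian float32 arrays. To stay self-contained, it builds a tiny fake export in a temporary directory instead of reading `piano_genie/dec_tfjs`; the helper name and the fake values are hypothetical.

```python
import json
import pathlib
import struct
import tempfile
from math import prod


def read_tfjs_weights(weights_dir):
    """Return {name: (shape, flat list of floats)} from a TF.js manifest (float32 only)."""
    weights_dir = pathlib.Path(weights_dir)
    manifest = json.loads((weights_dir / "weights_manifest.json").read_text())
    result = {}
    for group in manifest:
        # The shards of a group concatenate into one contiguous buffer
        buf = b"".join((weights_dir / p).read_bytes() for p in group["paths"])
        offset = 0
        for w in group["weights"]:
            assert w["dtype"] == "float32", "this sketch only handles float32"
            n = prod(w["shape"])
            values = struct.unpack_from(f"<{n}f", buf, offset)
            result[w["name"]] = (w["shape"], list(values))
            offset += 4 * n
    return result


# Build a tiny fake export so the sketch runs without a trained model
with tempfile.TemporaryDirectory() as tmp_dir:
    weights = {"dec.output.bias": ([2], [0.5, -1.0]), "dec.output.weight": ([2, 1], [1.0, 2.0])}
    data = b"".join(struct.pack(f"<{len(v)}f", *v) for _, v in weights.values())
    (pathlib.Path(tmp_dir) / "group1-shard1of1.bin").write_bytes(data)
    manifest = [{
        "paths": ["group1-shard1of1.bin"],
        "weights": [{"name": k, "shape": s, "dtype": "float32"} for k, (s, _) in weights.items()],
    }]
    (pathlib.Path(tmp_dir) / "weights_manifest.json").write_text(json.dumps(manifest))
    loaded = read_tfjs_weights(tmp_dir)
print(loaded)
```

Under the same assumptions, calling `read_tfjs_weights("piano_genie/dec_tfjs")` on the real export should return one entry per decoder parameter, whose values can be compared against the PyTorch state dict.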

In [None]:
# @title **(Step 5)** Create test fixtures to check correctness of the JavaScript port

# Restore model from saved checkpoint
device = torch.device("cpu")
model = PianoGenieAutoencoder(CFG)
model.load_state_dict(torch.load("piano_genie/model.pt", map_location=device))
model.eval()
model.to(device)

# Serialize a batch of inputs/outputs as JSON
with torch.no_grad():
    input_dts, ground_truth_keys = performances_to_batch(
        [DATASET["validation"][0]], device, train=False
    )
    output_logits, _, input_buttons = model(input_dts, ground_truth_keys)

    input_dts = input_dts[0].cpu().numpy().tolist()
    ground_truth_keys = ground_truth_keys[0].cpu().numpy().tolist()
    input_keys = [PIANO_NUM_KEYS] + ground_truth_keys[:-1]
    input_buttons = input_buttons[0].cpu().numpy().tolist()
    output_logits = output_logits[0].cpu().numpy().tolist()

    fixtures = {
        "input_dts": input_dts,
        "input_keys": input_keys,
        "input_buttons": input_buttons,
        "output_logits": output_logits,
    }
    with open(pathlib.Path("piano_genie", "fixtures.json"), "w") as f:
        f.write(json.dumps(fixtures))
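One natural way to use these fixtures is to run the ported decoder on `input_dts`, `input_keys`, and `input_buttons`, then compare its logits against `output_logits`. The comparison is sketched below in Python for consistency with this notebook. The tolerance is an assumption: float32 arithmetic in a different runtime will generally not match bit-for-bit, so an exact equality check would be too strict. The logits are made-up stand-ins.

```python
from typing import List


def max_abs_diff(a: List[List[float]], b: List[List[float]]) -> float:
    """Largest elementwise absolute difference between two equally shaped 2-D lists."""
    assert len(a) == len(b) and all(len(x) == len(y) for x, y in zip(a, b))
    return max(abs(x - y) for row_a, row_b in zip(a, b) for x, y in zip(row_a, row_b))


def fixtures_match(expected, actual, atol: float = 1e-4) -> bool:
    """True if actual logits match expected logits within an absolute tolerance."""
    return max_abs_diff(expected, actual) <= atol


expected = [[0.10, -2.30], [1.50, 0.00]]        # Stand-in for fixtures["output_logits"]
actual = [[0.10002, -2.30001], [1.49999, 0.0]]  # Stand-in for logits from the port
print(max_abs_diff(expected, actual), fixtures_match(expected, actual))
```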