In [None]:
# hide
%load_ext nb_black

The nb_black extension is already loaded. To reload it, use:
  %reload_ext nb_black


<IPython.core.display.Javascript object>

In [None]:
# export
from typing import Collection

import numpy as np
from omegaconf import OmegaConf
import torch
from torch import Tensor
import torch.nn as nn
from fastai2.text.models import RNNDropout

from neuralmusic.midi import Triplet
from neuralmusic.data.preprocessing import Vocab

<IPython.core.display.Javascript object>

In [None]:
# default_exp model

<IPython.core.display.Javascript object>

# Model

> Learning melody and rhythm at the same time

We're attempting to build a model that can effectively learn two parallel signals (pitch and duration) at the same time, with a single loss function.

But first, we'll borrow a piece from fastai v1 that I couldn't find in fastai2: the linear decoder.

In [None]:
# export


class LinearDecoder(nn.Module):
    """
    A Linear Decoder from fastai v1.

    Applies `RNNDropout` to its input, then projects the last dimension from
    `n_hid` to `n_out` with a linear layer. All leading dimensions are flattened,
    so an input of shape `(d0, d1, n_hid)` yields an output of shape
    `(d0 * d1, n_out)`.

    Args:
        n_out: size of the output (e.g. the vocabulary size).
        n_hid: size of the last dimension of the input.
        output_p: dropout probability passed to `RNNDropout`.
        tie_encoder: optional module whose `weight` is shared with the decoder
            (weight tying). Its weight must have shape `(n_out, n_hid)`.
        bias: whether the linear layer has a bias (initialised to zero).
    """

    initrange = 0.1

    def __init__(
        self,
        n_out: int,
        n_hid: int,
        output_p: float,
        tie_encoder=None,
        bias: bool = True,
    ):
        super().__init__()
        self.decoder = nn.Linear(n_hid, n_out, bias=bias)
        self.decoder.weight.data.uniform_(-self.initrange, self.initrange)
        self.output_dp = RNNDropout(output_p)
        if bias:
            self.decoder.bias.data.zero_()
        # Compare against None explicitly. Modules that define __len__ (e.g. an
        # empty container) are falsy and would silently skip the tying.
        if tie_encoder is not None:
            self.decoder.weight = tie_encoder.weight

    def forward(self, input: Tensor) -> Tensor:
        """Dropout, flatten all leading dims, then project to `n_out`."""
        output = self.output_dp(input)
        # Unlike `view`, `reshape` also accepts non-contiguous inputs (e.g. after a
        # transpose). It is identical to the old behaviour for contiguous 3-D input.
        return self.decoder(output.reshape(-1, output.size(-1)))

<IPython.core.display.Javascript object>

And finally, the model:

In [None]:
# export


class TheModel(nn.Module):
    """
    A model that learns pitch and duration through separate RNNs, merging them at
    the end to 'compare notes', and outputting separate predictions for each aspect.

    `forward` takes an integer tensor of shape `(batch, 2, seq_len)`: row 0 holds
    pitch indices and row 1 holds duration indices (the layout built by
    `triplets_to_input`). It returns `(pitch_out, duration_out)` with shapes
    `(batch, seq_len, pitch_len)` and `(batch, seq_len, duration_len)`.

    Hidden states persist between calls and are detached after each one
    (truncated BPTT). Call `reset()` to start a fresh sequence.
    """

    def __init__(
        self,
        pitch_len,
        duration_len,
        kind,
        emb_size=1000,
        rnn_size=1200,
        rnn_layers=3,
        dropout=0.0,
    ):
        super().__init__()

        self.kind = kind  # TODO: use

        # NOTE(review): padding_idx=1 presumably matches the position of fastai's
        # `xxpad` token in the vocab. Confirm against the vocab builder.
        self.pitch_emb = nn.Embedding(
            num_embeddings=pitch_len, embedding_dim=emb_size, padding_idx=1
        )
        self.duration_emb = nn.Embedding(
            num_embeddings=duration_len, embedding_dim=emb_size, padding_idx=1
        )

        self.hidden_size = rnn_size
        self.layers = rnn_layers
        self.pitch_rnn = nn.GRU(
            input_size=emb_size,
            hidden_size=rnn_size,
            num_layers=rnn_layers,
            batch_first=False,
            bidirectional=False,
        )
        self.duration_rnn = nn.GRU(
            input_size=emb_size,
            hidden_size=rnn_size,
            num_layers=rnn_layers,
            batch_first=False,
            bidirectional=False,
        )

        # Both decoders see the two RNN outputs concatenated, hence 2 * rnn_size.
        self.pitch_dec = LinearDecoder(
            n_out=pitch_len, n_hid=rnn_size + rnn_size, output_p=dropout
        )
        self.duration_dec = LinearDecoder(
            n_out=duration_len, n_hid=rnn_size + rnn_size, output_p=dropout
        )

        self.reset()

    def forward(self, x):
        """Predict pitch and duration logits for every step of `x` (see class doc)."""
        pitches, durations = x.transpose(0, 1)  # each (batch, seq_len)
        batch_size, seq_len = pitches.size(0), pitches.size(1)

        # The GRUs are seq-first (batch_first=False): embeddings -> (seq, batch, emb).
        pitch_emb = self.pitch_emb(pitches).transpose(0, 1)
        duration_emb = self.duration_emb(durations).transpose(0, 1)

        # (Re)create hidden states when missing, or when the batch size changed
        # since the previous call (e.g. a smaller final batch).
        if self.pitch_hid is None or self.pitch_hid.size(1) != batch_size:
            self.pitch_hid = self.init_hidden(self.layers, batch_size, self.hidden_size)
        if self.duration_hid is None or self.duration_hid.size(1) != batch_size:
            self.duration_hid = self.init_hidden(
                self.layers, batch_size, self.hidden_size
            )

        pitch, pitch_hid = self.pitch_rnn(pitch_emb, self.pitch_hid)
        duration, duration_hid = self.duration_rnn(duration_emb, self.duration_hid)

        # Keep the state for the next call, but cut the graph (truncated BPTT).
        self.pitch_hid = pitch_hid.detach()
        self.duration_hid = duration_hid.detach()

        # Go batch-first before decoding. The decoders flatten the leading dims, so
        # rows are then batch-major and the `view(batch, seq, -1)` below is correct.
        # Decoding seq-first data and viewing it as (batch, seq) would scramble
        # outputs whenever batch > 1.
        # NOTE(review): fastai's RNNDropout presumably also expects batch-first
        # input for its per-sequence mask. Confirm against fastai2.
        together = torch.cat([pitch, duration], dim=2).transpose(0, 1).contiguous()

        pitch_decoded = self.pitch_dec(together)
        duration_decoded = self.duration_dec(together)

        pitch_out = pitch_decoded.view(batch_size, seq_len, -1)
        duration_out = duration_decoded.view(batch_size, seq_len, -1)

        return pitch_out, duration_out

    def reset(self):
        """Forget the hidden states; the next `forward` starts from zeros."""
        self.pitch_hid = None
        self.duration_hid = None

    def init_hidden(self, layers, batch_size, hidden_size):
        """Zero hidden state `(layers, batch_size, hidden_size)` on the params' device/dtype."""
        return next(self.parameters()).new_zeros(layers, batch_size, hidden_size)

<IPython.core.display.Javascript object>

In [None]:
# export


def triplets_to_input(
    triplets: Collection["Triplet"], pitch_vocab, duration_vocab
) -> torch.Tensor:
    """
    Formats a sequence of triplets as an input to the model.

    Each triplet is unpacked as `(pitch, duration, _)`; the third element is
    ignored. Pitches are looked up in `pitch_vocab` as-is. Durations are converted
    with `str()` first, then looked up in `duration_vocab`. Both lookups use
    `.index()`.

    Returns:
        A `torch.long` tensor of shape `(1, 2, len(triplets))`: a batch of one,
        with pitch indices in row 0 and duration indices in row 1.

    Raises:
        ValueError: if a value is missing from its vocab (raised by `.index` when
            the vocab is a list).
    """
    pitch_ids = []
    duration_ids = []
    for pitch, duration, _ in triplets:
        pitch_ids.append(pitch_vocab.index(pitch))
        duration_ids.append(duration_vocab.index(str(duration)))
    # Make the dtype explicit: an empty prompt would otherwise produce a float
    # tensor, which nn.Embedding rejects.
    return torch.tensor([[pitch_ids, duration_ids]], dtype=torch.long)

<IPython.core.display.Javascript object>

In [None]:
# test
# Smoke test: build vocabs from a real MIDI file, run a 10-note prompt through an
# untrained model and check the output shapes.
from fastai2.text.data import make_vocab

from testing import test_eq, path

from neuralmusic.midi import parse_midi_file, row_to_triplets
from neuralmusic.data.preprocessing import preprocess

raw_df, notes = parse_midi_file(path("data/ff4-airship.mid"))
df, pitch_count, duration_count = preprocess(raw_df)

# Row 0 of the preprocessed frame, as a sequence of triplets.
song = row_to_triplets(df, 0)

# NOTE(review): with batch_size == 1 this test cannot detect batch/time ordering
# mistakes in the model's output reshaping; consider also testing a batch of 2+.
batch_size = 1
seq_len = 10
prompt = song[0:seq_len]

# min_freq=1 keeps every pitch/duration that appears in the counts.
pitch_vocab = make_vocab(pitch_count, min_freq=1)
duration_vocab = make_vocab(duration_count, min_freq=1)

model = TheModel(
    pitch_len=len(pitch_vocab),
    duration_len=len(duration_vocab),
    kind="dual",
    emb_size=1000,
    rnn_size=1200,
    rnn_layers=2,
)

pitch_out, duration_out = model(triplets_to_input(prompt, pitch_vocab, duration_vocab))

# Outputs are (batch, seq_len, vocab_size) for each head.
test_eq(torch.Size([batch_size, seq_len, len(pitch_vocab)]), pitch_out.shape)
test_eq(torch.Size([batch_size, seq_len, len(duration_vocab)]), duration_out.shape)

# Last expression: display the model's module tree.
model

TheModel(
  (pitch_emb): Embedding(56, 1000, padding_idx=1)
  (duration_emb): Embedding(24, 1000, padding_idx=1)
  (pitch_rnn): GRU(1000, 1200, num_layers=2)
  (duration_rnn): GRU(1000, 1200, num_layers=2)
  (pitch_dec): LinearDecoder(
    (decoder): Linear(in_features=2400, out_features=56, bias=True)
    (output_dp): RNNDropout()
  )
  (duration_dec): LinearDecoder(
    (decoder): Linear(in_features=2400, out_features=24, bias=True)
    (output_dp): RNNDropout()
  )
)

<IPython.core.display.Javascript object>

## Prediction

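To predict notes from a prompt (a sequence of triplets that primes the model), we need two more functions: `choose`, which samples one category from the model's output, and `predict`, which generates notes one at a time.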

In [None]:
# export


def choose(top_k, logits, vocab):
    """
    Chooses between the top K probabilities, and returns a single random choice.

    Only the last time step is considered: `logits` is squeezed on dim 0 and its
    last row is taken, so a `(1, seq_len, vocab_size)` tensor (batch of one) is
    expected. One of the `top_k` highest-scoring indices is drawn uniformly with
    `np.random.choice`; their scores are not used as sampling weights.

    Args:
        top_k: number of candidates to sample from. It is clamped to the number
            of logits, so oversized values do not crash `torch.topk`.
        logits: model output for a batch of one.
        vocab: indexable mapping from an index to its category.

    Returns:
        `(index, category)`, where `index` is a Python int and
        `category == vocab[index]`.
    """
    last_logits = logits.squeeze(0)[-1]
    # torch.topk raises when k exceeds the number of candidates.
    k = min(top_k, last_logits.size(-1))
    _, top_ix = torch.topk(last_logits, k=k)
    choice = int(np.random.choice(top_ix.tolist()))
    return choice, vocab[choice]


def predict(device, model, prompt, pitch_vocab, duration_vocab, top_k=5, n_notes=4):
    """
    Predicts the next n notes given a model and a prompt.

    The model's hidden state is reset and the whole prompt is fed in once. After
    that, each sampled note is fed back as a single step, relying on the model to
    carry its hidden state between calls. Sampling uses `choose`, i.e. a uniform
    pick among the `top_k` highest-scoring candidates.

    Inference runs under `torch.no_grad()`. Before returning, the model's previous
    train/eval mode is restored and its hidden state is reset, so a following
    training step is unaffected.

    Returns:
        A list of `n_notes` `(pitch, duration)` category pairs.
    """
    was_training = model.training
    model.reset()
    model.eval()
    notes = []
    try:
        with torch.no_grad():
            step_input = triplets_to_input(prompt, pitch_vocab, duration_vocab).to(
                device
            )
            for _ in range(n_notes):
                pitch_out, duration_out = model(step_input)
                pitch_encoded, pitch = choose(top_k, pitch_out, pitch_vocab)
                duration_encoded, duration = choose(
                    top_k, duration_out, duration_vocab
                )
                # Shape (1, 2, 1): a batch of one holding a single (pitch, duration) step.
                step_input = torch.tensor(
                    [[[pitch_encoded], [duration_encoded]]], device=device
                )
                notes.append((pitch, duration))
    finally:
        # Do not leak the batch-of-one hidden state or eval mode into later use.
        model.reset()
        model.train(was_training)

    return notes

<IPython.core.display.Javascript object>

In [None]:
# test
# Sample 5 notes after the prompt. top_k=1 always picks the highest-scoring
# candidate. Show the first predicted note alongside its vocab indices.
predicted = predict(
    torch.device("cpu"), model, prompt, pitch_vocab, duration_vocab, top_k=1, n_notes=5
)
pitch, duration = predicted[0]

pitch, duration, pitch_vocab.index(pitch), duration_vocab.index(duration)

('xxfake', 'xxfake', 52, 17)

<IPython.core.display.Javascript object>

In [None]:
# export


def get_model(cfg: OmegaConf, pitch_vocab: Vocab, duration_vocab: Vocab) -> TheModel:
    """
    Build a `TheModel` whose output heads match the given vocabularies.

    `cfg.name` becomes the model `kind`. `cfg.emb_size`, `cfg.rnn_size` and
    `cfg.rnn_layers` size the embeddings and GRUs. Dropout is not taken from the
    config; the model's own default is used.
    """
    n_pitches = len(pitch_vocab)
    n_durations = len(duration_vocab)
    model = TheModel(
        pitch_len=n_pitches,
        duration_len=n_durations,
        kind=cfg.name,
        emb_size=cfg.emb_size,
        rnn_size=cfg.rnn_size,
        rnn_layers=cfg.rnn_layers,
    )
    return model

<IPython.core.display.Javascript object>