In [None]:
# hide
%load_ext nb_black

The nb_black extension is already loaded. To reload it, use:
  %reload_ext nb_black


<IPython.core.display.Javascript object>

In [None]:
# default_exp model

<IPython.core.display.Javascript object>

In [None]:
# export
import torch
from torch import Tensor
import torch.nn as nn
from fastai2.text.models import RNNDropout

<IPython.core.display.Javascript object>

# Model

We're attempting to build a model that can effectively learn two parallel signals (pitch and duration) at the same time, with a single loss function.

But first, we'll port a piece of fastai v1 that I couldn't find in fastai2: the `LinearDecoder`.

In [None]:
# export
class LinearDecoder(nn.Module):
    """Dropout followed by a linear projection from RNN features to output logits.

    Ported from fastai v1. Optionally ties the projection weight to an encoder
    module's weight.
    """

    initrange = 0.1  # half-width of the uniform range used to initialise decoder weights

    def __init__(
        self,
        n_out: int,
        n_hid: int,
        output_p: float,
        tie_encoder=None,
        bias: bool = True,
    ):
        """
        Args:
            n_out: number of output logits per position.
            n_hid: size of the last dimension of the input features.
            output_p: dropout probability passed to `RNNDropout`.
            tie_encoder: optional module whose `.weight` is shared with the decoder
                (weight tying). Presumably an `nn.Embedding` of shape (n_out, n_hid);
                TODO confirm at call sites.
            bias: whether the linear layer has a bias (initialised to zero).
        """
        super().__init__()
        self.decoder = nn.Linear(n_hid, n_out, bias=bias)
        nn.init.uniform_(self.decoder.weight, -self.initrange, self.initrange)
        self.output_dp = RNNDropout(output_p)
        if bias:
            nn.init.zeros_(self.decoder.bias)
        # `is not None` rather than truthiness: modules that define __len__
        # (e.g. an empty nn.Sequential) would otherwise be treated as "no encoder".
        if tie_encoder is not None:
            self.decoder.weight = tie_encoder.weight

    def forward(self, input: Tensor) -> Tensor:
        """Apply dropout, flatten all leading dims, and project to `n_out` logits.

        Args:
            input: tensor whose last dimension is `n_hid`, e.g. (seq, batch, n_hid).

        Returns:
            2D tensor of shape (product of leading dims, n_out). Rows follow the
            row-major order of the input's leading dimensions.
        """
        output = self.output_dp(input)
        # reshape instead of view: dropout may hand back its input unchanged
        # (e.g. when inactive), and that input is not guaranteed to be contiguous.
        return self.decoder(output.reshape(-1, output.size(-1)))

<IPython.core.display.Javascript object>

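Finally, the model itself. Pitches and durations each get their own embedding and GRU stack, and the concatenated GRU outputs feed one decoder per signal: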

In [None]:
# export
class TheModel(nn.Module):
    """Two-stream GRU model that jointly predicts pitch and duration tokens.

    Pitches and durations each get their own embedding and GRU stack. The two
    GRU outputs are concatenated so that both decoders see both streams. Hidden
    states are carried from one `forward` call to the next (detached from the
    graph) until `reset` is called.
    """

    def __init__(
        self,
        pitch_len,
        duration_len,
        emb_size=1000,
        rnn_size=1200,
        rnn_layers=3,
        dropout=0.0,
    ):
        """
        Args:
            pitch_len: pitch vocabulary size (embedding rows and pitch-decoder outputs).
            duration_len: duration vocabulary size (embedding rows and duration-decoder outputs).
            emb_size: embedding dimension used by both streams.
            rnn_size: hidden size of each GRU.
            rnn_layers: number of stacked layers in each GRU.
            dropout: dropout probability applied to the decoders' inputs.
        """
        super().__init__()

        # padding_idx=1 — presumably fastai's padding token index; confirm against the vocab.
        self.pitch_emb = nn.Embedding(
            num_embeddings=pitch_len, embedding_dim=emb_size, padding_idx=1
        )
        self.duration_emb = nn.Embedding(
            num_embeddings=duration_len, embedding_dim=emb_size, padding_idx=1
        )

        self.hidden_size = rnn_size
        self.layers = rnn_layers
        self.pitch_rnn = nn.GRU(
            input_size=emb_size,
            hidden_size=rnn_size,
            num_layers=rnn_layers,
            batch_first=False,
            bidirectional=False,
        )
        self.duration_rnn = nn.GRU(
            input_size=emb_size,
            hidden_size=rnn_size,
            num_layers=rnn_layers,
            batch_first=False,
            bidirectional=False,
        )

        # Each decoder reads the concatenation of both GRU outputs.
        self.pitch_dec = LinearDecoder(
            n_out=pitch_len, n_hid=rnn_size + rnn_size, output_p=dropout
        )
        self.duration_dec = LinearDecoder(
            n_out=duration_len, n_hid=rnn_size + rnn_size, output_p=dropout
        )

        self.reset()

    def forward(self, x):
        """Predict pitch and duration logits for every position.

        Args:
            x: integer tensor that `x.transpose(0, 1)` unpacks into
                (pitches, durations), each treated as (batch, seq). Presumably
                x is (batch, 2, seq); TODO confirm against the dataloader.

        Returns:
            (pitch_out, duration_out) with shapes (batch, seq, pitch_len) and
            (batch, seq, duration_len).
        """
        pitches, durations = x.transpose(0, 1)

        # nn.GRU is built with batch_first=False, so move time to dim 0: (seq, batch, emb).
        pitch_emb = self.pitch_emb(pitches).transpose(0, 1)
        duration_emb = self.duration_emb(durations).transpose(0, 1)

        # (Re)create the carried hidden state on the first call, or when the batch
        # size changed (e.g. a smaller final batch); a stale state would make the GRU raise.
        if self.pitch_hid is None or self.pitch_hid.size(1) != pitches.size(0):
            self.pitch_hid = self.init_hidden(
                self.layers, pitches.size(0), self.hidden_size
            )
        if self.duration_hid is None or self.duration_hid.size(1) != durations.size(0):
            self.duration_hid = self.init_hidden(
                self.layers, durations.size(0), self.hidden_size
            )

        pitch, pitch_hid = self.pitch_rnn(pitch_emb, self.pitch_hid)
        duration, duration_hid = self.duration_rnn(duration_emb, self.duration_hid)

        # Keep the state for the next call, but cut the graph so backprop stops here.
        self.pitch_hid = pitch_hid.detach()
        self.duration_hid = duration_hid.detach()

        together = torch.cat([pitch, duration], dim=2)  # (seq, batch, 2 * rnn_size)
        seq_len, batch_size = together.shape[:2]

        # The decoders flatten (seq, batch) in seq-major order. Restore that layout
        # first, then swap to batch-first. Viewing straight to (batch, seq) would
        # scramble positions.
        pitch_out = (
            self.pitch_dec(together)
            .view(seq_len, batch_size, -1)
            .transpose(0, 1)
            .contiguous()
        )
        duration_out = (
            self.duration_dec(together)
            .view(seq_len, batch_size, -1)
            .transpose(0, 1)
            .contiguous()
        )

        return pitch_out, duration_out

    def reset(self):
        """Drop the carried hidden states; the next `forward` starts from zeros."""
        self.pitch_hid = None
        self.duration_hid = None

    def init_hidden(self, layers, batch_size, hidden_size):
        """Return a zero hidden state of shape (layers, batch_size, hidden_size).

        The tensor is created with the same dtype and device as the model's
        first parameter.
        """
        return next(self.parameters()).new_zeros(layers, batch_size, hidden_size)

<IPython.core.display.Javascript object>