In [69]:
import sys
sys.path.append("../src")

In [77]:
from config import Config

cfg = Config()

In [237]:
import joblib

In [238]:
features = joblib.load("../data/feats/benzaiten_feats.pkl")
data_all = features["data"]
label_all = features["label"]

In [246]:
note_seq = data_all[:, :, :49]
chord_seq = data_all[:, :, -12:]

## Train

### Define model

In [247]:
chord = "Cm7"

In [249]:
import torch
import torch.nn as nn

In [465]:
# TODO: 
# - sample の inputを作る or 作られてるfeatsを使ってデモinputを作る
# - 入力から出力まで通るようにモデル実装
# - train.py のモデルをこれに差し替える
# - メロディ生成してみる


class LSTMEncoder(nn.Module):
    def __init__(
        self, input_dim: int, latent_dim: int, hidden_dim: int, n_layers: int = 1
    ) -> None:
        super(LSTMEncoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        # NOTE: 双方向LSTMにすると隠れ層の次元が２倍になるため、後続の入力次元を倍にするか、平均化するかが必要
        self.lstm = nn.LSTM(input_dim, hidden_dim, n_layers, batch_first=True)
        self.relu = nn.ReLU()

        self._to_mean = nn.Linear(hidden_dim, latent_dim)
        self._to_lnvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        _, (hidden, cell) = self.lstm(x)
        z = self.relu(hidden)
        z = z.transpose(0, 1)
        mean  = self._to_mean(z)
        logvar =  self._to_lnvar(z)
        return mean, logvar


class LSTMDecoder(nn.Module):
    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int, n_layers: int = 1) -> None:
        super(LSTMDecoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.lstm = nn.LSTM(
            input_dim, hidden_dim, num_layers=n_layers, batch_first=True
        )
        self.fc_out = nn.Linear(hidden_dim, output_dim)

    def forward(self, latent, condition) -> torch.Tensor:
        inputs = torch.cat([latent, condition], dim=2)
        output, (hidden, cell) = self.lstm(inputs)
        x_hat = self.fc_out(output)
        return x_hat


class Chord2Melody(nn.Module):
    """CVAE model to generate melody from chord progression.
    """
    def __init__(self, input_dim: int, latent_dim: int, condition_dim: int) -> None:
        super(Chord2Melody, self).__init__()
        self.encoder = LSTMEncoder(input_dim, latent_dim, 128, 1)
        self.decoder = LSTMDecoder((latent_dim + condition_dim), input_dim, 128, 2)

    def encode(self, x):
        mean, logvar = self.encoder(x)
        latent = self.reparameterization(mean, logvar)
        return latent

    def reparameterization(
        self, mean: torch.Tensor, logvar: torch.Tensor
    ) -> torch.Tensor:
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        latent = mean + eps * std
        return latent
    
    def decode(self, latent, condition):
        x_hat = self.decoder(latent, condition)
        return x_hat

    
    def forward(self, melody, chord_prog):
        latent = self.encode(melody)
        latent = latent.repeat(1, 64, 1)
        x_hat = self.decode(latent, condition=chord_prog)
        return x_hat

In [466]:
model = Chord2Melody(49, 128, 12)

In [467]:
batch_size = 10


melody = torch.from_numpy(note_seq[:batch_size]).to(torch.float32)
condition = torch.from_numpy(chord_seq[:batch_size]).to(torch.float32)
output = model(melody, condition)

In [468]:
output.argmax(dim=2)[:, -4:]

tensor([[32, 32, 32, 32],
        [23, 23, 23, 23],
        [16, 16, 16, 16],
        [32, 32, 32, 32],
        [32, 32, 32, 32],
        [32, 32, 32, 32],
        [16, 16, 16, 16],
        [32, 32, 32, 32],
        [13, 13, 13, 13],
        [44, 44, 44, 44]])

In [374]:
# Train model

## Inference

In [49]:
batch_size = 4
generated_melody = []

for i in range(120):
    if len(generated_melody) == 0:
        latent_dim = 64

        latent = torch.rand((batch_size, latent_dim))
        condition = torch.tile(torch.tensor([1], dtype=torch.int32), dims=(batch_size, 1))
        output = model.decode(latent, condition)
    else:
        condition = torch.tile(torch.tensor([1], dtype=torch.int32), dims=(batch_size, 1))
        output = model(generated_melody, condition)

    generated_melody.append(output)

tensor([11])

## Generate midi

In [None]:
# TODO: generated_melody to midi file