## Import

In [1]:
import torch
import numpy as np
import torch.nn as nn
from torch.nn import functional as F

## MusicVAE Model

In [3]:
class MusicVAE(nn.Module):
    """Recurrent VAE with a bidirectional LSTM encoder and a two-level LSTM decoder.

    Data flow (all LSTMs are ``batch_first``):

    1. ``encoder``: 2-layer bidirectional LSTM over the input sequence.
    2. ``mu_linear`` / ``sigma_linear``: applied to every encoder output step
       to parameterize a diagonal Gaussian; a latent ``z`` is sampled per step.
    3. ``z_latent_linear``: linear + tanh applied to ``z``.
    4. ``conductor``: 2-layer LSTM (with output projection) over the latent
       sequence, followed by ``c_linear`` (linear + tanh).
    5. ``bottom_level_decoder``: 2-layer LSTM over the conductor embeddings,
       followed by ``classifier`` and a softmax over the feature axis.

    Args:
        input_size: Number of features per time step of the input; also the
            number of output classes per time step.
        encoder_hidden_size: Hidden size of each direction of the encoder LSTM.
        decoder_hidden_size: Hidden size of the conductor and bottom-level
            decoder LSTMs. Must be strictly greater than ``latent_size``
            because the conductor uses ``proj_size=latent_size``.
        latent_size: Dimensionality of the latent code and of the conductor
            projection. Defaults to 512.
    """

    def __init__(self, input_size: int, encoder_hidden_size: int,
                 decoder_hidden_size: int, latent_size: int = 512) -> None:
        super().__init__()

        self.encoder = nn.LSTM(
            batch_first=True,
            input_size=input_size,
            hidden_size=encoder_hidden_size,
            num_layers=2,
            bidirectional=True,
        )
        # A bidirectional LSTM concatenates both directions on the last axis.
        encoder_output_size = encoder_hidden_size * 2
        self.mu_linear = nn.Linear(encoder_output_size, latent_size)
        self.sigma_linear = nn.Linear(encoder_output_size, latent_size)
        self.z_latent_linear = nn.Sequential(
            nn.Linear(latent_size, latent_size),
            nn.Tanh(),
        )

        self.conductor = nn.LSTM(
            batch_first=True,
            input_size=latent_size,
            hidden_size=decoder_hidden_size,
            proj_size=latent_size,
            num_layers=2,
            bidirectional=False,
        )
        # Lift the conductor output to the width expected by the
        # bottom-level decoder (its input_size is decoder_hidden_size).
        self.c_linear = nn.Sequential(
            nn.Linear(latent_size, decoder_hidden_size),
            nn.Tanh(),
        )

        self.bottom_level_decoder = nn.LSTM(
            batch_first=True,
            input_size=decoder_hidden_size,
            hidden_size=decoder_hidden_size,
            num_layers=2,
            bidirectional=False,
        )
        # Sized from the constructor arguments rather than a module-level
        # global, so the class works for any input/decoder width.
        self.classifier = nn.Linear(decoder_hidden_size, input_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x``, sample a latent sequence, and decode it.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_size)``.

        Returns:
            Tensor of shape ``(batch, seq_len, input_size)`` holding a softmax
            distribution over the last axis for every time step.
        """
        z = self.encode(x)
        z = self.z_latent_linear(z)
        return self.decode(z)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Sample a latent code for every time step of ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_size)``.

        Returns:
            Tensor of shape ``(batch, seq_len, latent_size)`` drawn as
            ``mu + sigma * eps`` with ``eps ~ N(0, I)``.
        """
        encoded, _ = self.encoder(x)
        mu = self.mu_linear(encoded)
        # softplus already yields a strictly positive standard deviation.
        # Stacking exp(0.5 * .) on top would force sigma >= 1, so the
        # posterior could never become narrower than a unit Gaussian.
        sigma = F.softplus(self.sigma_linear(encoded))
        eps = torch.randn_like(sigma)
        return mu + sigma * eps

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Decode a latent sequence into per-step output distributions.

        Args:
            x: Tensor of shape ``(batch, seq_len, latent_size)``.

        Returns:
            Tensor of shape ``(batch, seq_len, input_size)``; each vector along
            the last axis is a softmax distribution.
        """
        c_vector, _ = self.conductor(x)
        c_vector = self.c_linear(c_vector)

        # One call over the whole sequence: the LSTM starts from zero (h, c)
        # with the right shape/device/dtype and carries the state from step
        # to step internally, for any batch size and sequence length.
        hidden_seq, _ = self.bottom_level_decoder(c_vector)
        logits = self.classifier(hidden_seq)
        return F.softmax(logits, dim=-1)

        
        
# TEST
# Smoke test: build the model with the sizes below, print its layer summary,
# and push one random batch of shape (BATCH_SIZE, 16, INPUT_SIZE) through it.
INPUT_SIZE, BATCH_SIZE = 7, 2
ENCODER_HIDDEN_SIZE, DECODER_HIDDEN_SIZE = 2048, 1024

model = MusicVAE(
    input_size=INPUT_SIZE,
    encoder_hidden_size=ENCODER_HIDDEN_SIZE,
    decoder_hidden_size=DECODER_HIDDEN_SIZE,
)
print(model)
model(torch.randn((BATCH_SIZE, 16, INPUT_SIZE)))

MusicVAE(
  (encoder): LSTM(7, 2048, num_layers=2, batch_first=True, bidirectional=True)
  (mu_linear): Linear(in_features=4096, out_features=512, bias=True)
  (sigma_linear): Linear(in_features=4096, out_features=512, bias=True)
  (z_latent_linear): Sequential(
    (0): Linear(in_features=512, out_features=512, bias=True)
    (1): Tanh()
  )
  (conductor): LSTM(512, 1024, proj_size=512, num_layers=2, batch_first=True)
  (c_linear): Sequential(
    (0): Linear(in_features=512, out_features=1024, bias=True)
    (1): Tanh()
  )
  (bottom_level_decoder): LSTM(1024, 1024, num_layers=2, batch_first=True)
  (classifier): Linear(in_features=1024, out_features=7, bias=True)
)
z: torch.Size([2, 16, 512])
c torch.Size([2, 16, 512])
