In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import json
import pickle
import sparse
import random
# import librosa
import mir_eval
import fluidsynth
import pretty_midi
import numpy as np
from time import time

import IPython.display
import matplotlib.pyplot as plt

import torch
import torch.optim as optim
import torch.nn.functional as F

from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import RandomSampler
from torch.utils.data.distributed import DistributedSampler

from utils.drum_utils import *
from utils.common_utils import *
from utils.train import *
from utils.test import *
from utils.layers import *
from utils.loss import *

%load_ext autoreload
%autoreload 2

In [None]:
# Reproducibility: seed every RNG this notebook relies on (Python `random`,
# NumPy, PyTorch — torch.manual_seed also seeds all CUDA devices) before any
# model is built, any DataLoader shuffles, or any sampling happens.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# initialize model with GPU when one is visible, else fall back to CPU
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

In [None]:
# Display the torch.device selected above ('cuda' if available, else 'cpu')
device

### Load data

In [None]:
# Load the pickled corpus of sparse drum samples.
# NOTE: pickle.load can execute arbitrary code — only open trusted files.
path = 'midi_data.pkl'
with open(path, mode='rb') as pkl_file:
    data = pickle.load(pkl_file)

print('The number of data : %d' % len(data))

In [None]:
# play example: densify one stored sample, render it with drum_play (utils)
# at fs=7, then synthesize 16 kHz audio for the inline player.
fs = 7
example_roll = data[6].todense()
pm = drum_play(example_roll, fs)
IPython.display.Audio(pm.fluidsynth(fs=16000), rate=16000)

In [None]:
# Split the corpus 70 / 10 / 20 into train / validation / test.
# Shuffling is opt-in: with SHUFFLE_SPLIT = False the split follows file order.
# When enabled, a *copy* is shuffled with a fixed seed, so `data` (indexed by
# the example cell above) is left untouched and the split is reproducible.
SHUFFLE_SPLIT = False
SPLIT_SEED = 0

num_data = len(data)
if SHUFFLE_SPLIT:
    data_order = list(data)
    random.Random(SPLIT_SEED).shuffle(data_order)
else:
    data_order = data

num_train = int(num_data * 0.7)
num_val = int(num_data * 0.1)

train_data = data_order[:num_train]
val_data = data_order[num_train:num_train+num_val]
test_data = data_order[num_train+num_val:]

# every sample lands in exactly one split; the int() truncation remainder goes to test
assert len(train_data) + len(val_data) + len(test_data) == num_data

print('The number of train: %d' % len(train_data))
print('The number of validation: %d' % len(val_data))
print('The number of test: %d' % len(test_data))

In [None]:
# custom dataloader
class DatasetSampler(Dataset):
    """Map-style Dataset over a collection of sparse samples.

    Each item is densified on demand (one per ``__getitem__`` call) and cast
    to float32 so the DataLoader can collate it into a float tensor.
    """

    def __init__(self, x):
        # x: indexable collection whose elements expose ``.todense()``
        # (presumably sparse.COO arrays, given the `sparse` import — confirm)
        self.x = x

    def __len__(self):
        """Number of stored samples."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return sample ``idx`` densified and cast to float32."""
        dense = self.x[idx].todense()
        return dense.astype('float32')

In [None]:
# dataloader
# Only the training loader shuffles. Validation/test batches come in a fixed
# order, so y_true[idx] / y_pred[idx] in the listening cells refer to the same
# sample on every run. Pinned host memory only helps host->GPU copies, so it
# is enabled only when CUDA is in use.
params = {'batch_size': 512,
          'shuffle': True,
          'pin_memory': use_cuda,
          'num_workers': 1}
eval_params = {**params, 'shuffle': False}

train_set = DataLoader(DatasetSampler(train_data), **params)
val_set = DataLoader(DatasetSampler(val_data), **eval_params)
test_set = DataLoader(DatasetSampler(test_data), **eval_params)

In [None]:
# init model: encoder -> conductor -> hierarchical decoder, each moved to
# `device`. Modules are built in this exact order so parameter initialisation
# consumes the RNG identically.

# encoder: input width matches the 512-column one-hot rows fed to it
enc_input_size = 512
enc_latent_dim = 256
enc_hidden_size = 512
encoder = Encoder(enc_input_size, enc_hidden_size, enc_latent_dim).to(device)

# conductor: consumes the encoder latent
con_input_size = enc_latent_dim
con_hidden_size = 256
conductor = Conductor(con_input_size, con_hidden_size, device).to(device)

# decoder: consumes conductor outputs, emits 512-way class scores
dec_input_size = con_hidden_size
dec_hidden_size = 256
dec_output_size = 512
decoder = Hierarchical_Decoder(dec_input_size, dec_hidden_size, dec_output_size).to(device)

model = [encoder, conductor, decoder]

In [None]:
# optimizer: one Adam (lr=1e-3) per sub-module, created in
# encoder / conductor / decoder order to line up with `model`
enc_optimizer, con_optimizer, dec_optimizer = (
    optim.Adam(module.parameters(), lr=1e-3)
    for module in (encoder, conductor, decoder)
)

optimizer = [enc_optimizer, con_optimizer, dec_optimizer]

### Train

In [None]:
# Train encoder, conductor and decoder together via hierarchical_train (utils.train);
# its return value is kept as `history`.
history = hierarchical_train(
    device, vae_loss,
    train_set, val_set,
    model, optimizer,
    bar_units=16,
    epochs=1000,
)

In [None]:
# save model: one state_dict per sub-module under ./model, named
# <module>_<hidden size>_<month>_<day>_<hour>_<minute>_<second>.
# The minute is part of the stamp so two saves in the same hour cannot collide.
from time import localtime, time

model_dir = './model'
os.makedirs(model_dir, exist_ok=True)  # torch.save does not create missing directories

tm = localtime(time())
for module_name, module, hidden in zip(('encoder', 'conductor', 'decoder'),
                                       model,
                                       (enc_hidden_size, con_hidden_size, dec_hidden_size)):
    save_path = '{}/{}_{:d}_{:d}_{:d}_{:d}_{:d}_{:d}'.format(
        model_dir, module_name, hidden,
        tm.tm_mon, tm.tm_mday, tm.tm_hour, tm.tm_min, tm.tm_sec)
    torch.save(module.state_dict(), save_path)

### Test

In [None]:
# Evaluate on the held-out split. The result goes to `test_history` so the
# training `history` from the Train section is not overwritten.
# 'full_sampling' is forwarded to hierarchical_test (utils.test) — presumably
# the decoder feeds back its own samples; TODO confirm.
test_history, y_true, y_pred = hierarchical_test(
    device, vae_loss, test_set, model,
    bar_units=16,
    options='full_sampling',
)

In [None]:
# input: listen to ground-truth test sample `idx` at fs=7
fs = 7
idx = 10
true_roll = y_true[idx]
pm = drum_play(true_roll, fs)
IPython.display.Audio(pm.fluidsynth(fs=16000), rate=16000)

In [None]:
# reconstruct sampled from categorical distribution: prob_soft_label (utils)
# turns the predicted probabilities for sample `idx` into a playable roll
pred_roll = prob_soft_label(y_pred[idx])
pm = drum_play(pred_roll, fs)
# pm.write('output.mid')  # uncomment to export the reconstruction
IPython.display.Audio(pm.fluidsynth(fs=16000), rate=16000)

### Generate custom inputs and predict

In [None]:
def predict(feat, decoder, bar_units=16, seq_len=64, temp=1):
    """Autoregressively decode a sequence from conductor features ("full sampling").

    At every step the decoder receives the one-hot of the label it produced at
    the previous step (all zeros at step 0). At the first step of each bar the
    decoder state (h, c) is re-initialised by tiling that bar's embedding.

    Args:
        feat: conductor output indexed as ``feat[:, bar, :]``; assumed shape
            (batch, num_bars, embed_dim) — TODO confirm against Conductor.
        decoder: module exposing ``hidden_size``, ``output_size`` and
            ``num_hidden``, called as ``decoder(inputs, h, c, z, temp=temp)``
            and returning ``(label, prob, h, c)``.
        bar_units: decoder steps per bar.
        seq_len: total number of steps to decode.
        temp: sampling temperature forwarded to the decoder.

    Returns:
        Tensor of shape (batch, seq_len, decoder.output_size) with the
        per-step probabilities, on the same device as ``feat`` and detached
        from autograd.

    Raises:
        ValueError: if ``decoder.hidden_size`` is not a positive multiple of the
            embedding width, or ``feat`` has too few bars for ``seq_len``.
    """
    batch_size = feat.shape[0]
    num_bars = feat.shape[1]
    embed_dim = feat.shape[-1]

    hidden_size = decoder.hidden_size
    output_size = decoder.output_size
    num_hidden = decoder.num_hidden

    # h/c are built by tiling the bar embedding along the feature axis, so the
    # hidden size must be an exact multiple of it (int() used to truncate silently).
    if embed_dim == 0 or hidden_size % embed_dim != 0 or hidden_size < embed_dim:
        raise ValueError('decoder.hidden_size (%d) must be a positive multiple of the '
                         'feature width (%d)' % (hidden_size, embed_dim))
    bars_needed = -(-seq_len // bar_units)  # ceil(seq_len / bar_units)
    if num_bars < bars_needed:
        raise ValueError('feat has %d bars but seq_len=%d with bar_units=%d needs %d'
                         % (num_bars, seq_len, bar_units, bars_needed))
    reps = hidden_size // embed_dim

    # allocate on the input's device instead of relying on a notebook global
    dev = feat.device
    inputs = torch.zeros((batch_size, 1, output_size), device=dev)
    outputs = torch.zeros((batch_size, seq_len, output_size), device=dev)

    # inference only: no autograd graph across the whole unrolled sequence
    with torch.no_grad():
        for step in range(seq_len):
            bar_idx, pos_in_bar = divmod(step, bar_units)
            z = feat[:, bar_idx, :]

            if pos_in_bar == 0:
                h = z.repeat(num_hidden, 1, reps)
                c = z.repeat(num_hidden, 1, reps)

            label, prob, h, c = decoder(inputs, h, c, z, temp=temp)
            outputs[:, step, :] = prob.squeeze()

            # F.one_hot yields int64; keep the feedback input in the same float
            # dtype as the zero input used at step 0
            inputs = F.one_hot(label, num_classes=output_size).to(outputs.dtype)

    return outputs

In [None]:
# custom input: one entry per step, each listing the drum-class indices struck
# on that step (a step whose first element is -1 is treated as a rest when encoded)
sequence = [
    [0, 3], [3], [3],
    [0, 3], [3], [3],
    [1, 3], [3],
]

In [None]:
# generate new sample: encode `sequence` as the one-hot roll the encoder consumes.
# Each step becomes `steps_per_hit` rows: the hit pattern followed by rest rows
# (class 0). A step that is empty or starts with -1 is a rest on *all* its rows
# (previously only the first row was set, leaving an all-zero row).
# The hit pattern is a 9-entry 0/1 mask converted by bin_to_dec (utils) to a
# class index — presumably in [0, 2**9) = [0, 512); confirm. The pattern is
# then tiled `num_repeats` times.
num_drums = 9
dim = 2 ** num_drums  # 512 classes
steps_per_hit = 2
num_repeats = 4

hot_encoded = np.zeros((steps_per_hit * len(sequence), dim), dtype='float32')

for hit_idx, hits in enumerate(sequence):
    row = steps_per_hit * hit_idx
    if not hits or hits[0] == -1:
        # rest step: every row of this step is the rest class
        hot_encoded[row:row + steps_per_hit, 0] = 1
        continue

    hit_mask = np.zeros(num_drums)
    hit_mask[hits] = 1
    hot_encoded[row, bin_to_dec(hit_mask)] = 1
    hot_encoded[row + 1:row + steps_per_hit, 0] = 1  # rest

hot_encoded = np.tile(hot_encoded, (num_repeats, 1))
print('input shape :', hot_encoded.shape)

In [None]:
# play input: listen to the hand-written pattern before running the model on it
fs = 7
pm = drum_play(hot_encoded, fs)
input_audio = pm.fluidsynth(fs=16000)
IPython.display.Audio(input_audio, rate=16000)

In [None]:
# MusicVAE inference on the custom pattern: encode -> conductor -> decode by
# sampling at `temperature`, then listen to the result. No gradients are needed,
# so the forward passes run under torch.no_grad().
fs = 7
temperature = 3

x_custom = torch.from_numpy(hot_encoded).to(device).unsqueeze(0)  # add batch dim

with torch.no_grad():
    z, mu, std = encoder(x_custom)
    feat = conductor(z)
    pred = np.squeeze(predict(feat, decoder, temp=temperature).detach().cpu().numpy())

pm = drum_play(prob_soft_label(pred), fs=fs)
IPython.display.Audio(pm.fluidsynth(fs=16000), rate=16000)