## Arousal style transfer code demo
This notebook demonstrates how to perform arousal style transfer.
The pre-trained model used here was trained on single-bar segments (3-6 seconds long).

In [1]:
import json
import torch
from gmm_model import *
import os
from sklearn.model_selection import train_test_split
from ptb_v2 import *
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pretty_midi
from IPython.display import Audio
from tqdm import tqdm
from polyphonic_event_based_v2 import *
from collections import Counter
import matplotlib.pyplot as plt
from polyphonic_event_based_v2 import parse_pretty_midi

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


### Load dataset and models

In [2]:
# initialization
from datetime import datetime

# Hyper-parameters (hidden_dim, z_dim, time_step, batch_size, name, ...) are
# read from the JSON config that sits next to this notebook.
with open('gmm_model_config.json') as config_file:
    args = json.load(config_file)

# Make sure the output folders exist before anything tries to write to them.
os.makedirs('log', exist_ok=True)
os.makedirs('params', exist_ok=True)

# Timestamped checkpoint path.
# NOTE(review): save_path_timing is not read by any cell visible in this
# notebook -- it looks like a leftover from the training script.
timestamp = str(datetime.now())
save_path_timing = 'params/' + args['name'] + "_" + timestamp + '.pt'

# Model dimensions.
# NOTE(review): presumably these have to match the values the checkpoint was
# trained with, otherwise load_state_dict() will report a size mismatch.
EVENT_DIMS = 342
RHYTHM_DIMS = 3
NOTE_DIMS = 16
CHROMA_DIMS = 24

# n_component=2: the shift-vector cell further down indexes components 0 and 1.
model = MusicAttrRegGMVAE(roll_dims=EVENT_DIMS, rhythm_dims=RHYTHM_DIMS, note_dims=NOTE_DIMS,
                        chroma_dims=CHROMA_DIMS,
                        hidden_dims=args['hidden_dim'], z_dims=args['z_dim'],
                        n_step=args['time_step'],
                        n_component=2)

checkpoint_path = "params/music_attr_vae_reg_gmm.pt"
print("Loading {}...".format(checkpoint_path))
# map_location="cpu" lets a checkpoint that was saved from a GPU be loaded on a
# CPU-only machine; the weights are moved to the GPU below when one exists.
# torch.load unpickles the file, so only load checkpoints you trust.
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))

if torch.cuda.is_available():
    print('Using: ', torch.cuda.get_device_name(torch.cuda.current_device()))
    model.cuda()
else:
    print('CPU mode')

# step / pre_epoch are training counters; they are kept so that any code
# expecting them in the namespace still finds them.
step, pre_epoch = 0, 0
batch_size = args["batch_size"]
print(batch_size)

# Yamaha dataloaders
print("Loading Yamaha...")
is_shuffle = False  # keep a fixed order so dataset indices used later stay meaningful
data_lst, rhythm_lst, note_density_lst, chroma_lst = get_classic_piano()

# 80% / 90% cut points of the full list.
# NOTE(review): tlen / vlen are not used by any visible cell.
n_segments = len(data_lst)
tlen = int(0.8 * n_segments)
vlen = int(0.9 * n_segments)

# The three splits are built from the same arrays and loader settings; only
# `mode` differs.
yamaha_arrays = (data_lst, rhythm_lst, note_density_lst, chroma_lst)
yamaha_loader_opts = dict(batch_size=batch_size, shuffle=is_shuffle, num_workers=0)

train_ds_dist = YamahaDataset(*yamaha_arrays, mode="train")
train_dl_dist = DataLoader(train_ds_dist, **yamaha_loader_opts)
val_ds_dist = YamahaDataset(*yamaha_arrays, mode="val")
val_dl_dist = DataLoader(val_ds_dist, **yamaha_loader_opts)
test_ds_dist = YamahaDataset(*yamaha_arrays, mode="test")
test_dl_dist = DataLoader(test_ds_dist, **yamaha_loader_opts)

dl = train_dl_dist
print(len(train_ds_dist), len(val_ds_dist), len(test_ds_dist))

# VGMIDI dataloaders
print("Loading VGMIDI...")
# Note: this rebinds data_lst / rhythm_lst / note_density_lst / chroma_lst,
# which held the Yamaha arrays up to this point.
(data_lst, rhythm_lst, note_density_lst, chroma_lst,
 arousal_lst, valence_lst) = get_vgmidi()

# Same arrays and loader settings for every split; only `mode` differs.
vgm_arrays = (data_lst, rhythm_lst, note_density_lst,
              chroma_lst, arousal_lst, valence_lst)
vgm_loader_opts = dict(batch_size=32, shuffle=is_shuffle, num_workers=0)

vgm_train_ds_dist = VGMIDIDataset(*vgm_arrays, mode="train")
vgm_train_dl_dist = DataLoader(vgm_train_ds_dist, **vgm_loader_opts)
vgm_val_ds_dist = VGMIDIDataset(*vgm_arrays, mode="val")
vgm_val_dl_dist = DataLoader(vgm_val_ds_dist, **vgm_loader_opts)
vgm_test_ds_dist = VGMIDIDataset(*vgm_arrays, mode="test")
vgm_test_dl_dist = DataLoader(vgm_test_ds_dist, **vgm_loader_opts)

print(len(vgm_train_ds_dist), len(vgm_val_ds_dist), len(vgm_test_ds_dist))
print()

Loading params/music_attr_vae_reg_gmm.pt...
Using:  Tesla V100-DGXS-32GB


  9%|▉         | 9782/103998 [00:00<00:00, 97813.56it/s]

128
Loading Yamaha...
Dataset length: 1703


100%|██████████| 103998/103998 [00:01<00:00, 80319.81it/s]


Shapes for: Data, Rhythm Density, Note Density, Chroma
(103934, 100) (103934, 16) (103934, 16) (103934, 24)
83147 10393 10394
Loading VGMIDI...
Shapes for: Data, Rhythm Density, Note Density, Chroma
(1013,) (1013,) (1013,) (1013, 24)
Shapes for: Arousal, Valence
(1013,) (1013,)
911 51 51



In [3]:
def convert_to_one_hot(input, dims):
    """One-hot encode an integer index tensor along a new trailing axis.

    Args:
        input: integer (long) tensor of indices, each in ``[0, dims)``. Any
            rank is accepted; the original 1-D and 2-D cases behave as before.
        dims: size of the one-hot axis.

    Returns:
        A float tensor of shape ``input.shape + (dims,)`` on the same device
        as ``input``, with 1.0 at each index position and 0.0 elsewhere.
    """
    # Allocate on the input's device instead of forcing CUDA, so the helper
    # also works in the notebook's CPU mode.
    input_oh = torch.zeros(tuple(input.shape) + (dims,), device=input.device)
    return input_oh.scatter_(-1, input.unsqueeze(-1), 1.)

def clean_output(out):
    """Turn decoder scores into a trimmed sequence of token ids.

    Takes the argmax over the last axis of ``out``, squeezes singleton axes,
    strips leading and trailing zeros, and, if token ``1`` occurs, keeps only
    the tokens before its first occurrence.

    Returns:
        A NumPy integer array of token ids.
    """
    token_ids = out.argmax(dim=-1).cpu().detach().numpy().squeeze()
    token_ids = np.trim_zeros(token_ids)

    # Everything from the first token 1 onwards is discarded.
    stop_positions = np.flatnonzero(token_ids == 1)
    if stop_positions.size > 0:
        token_ids = token_ids[:stop_positions[0]]
    return token_ids

def repar(mu, stddev, sigma=1):
    """Draw ``z = mu + stddev * eps`` with ``eps ~ Normal(0, sigma)``.

    Args:
        mu: mean tensor.
        stddev: standard-deviation tensor; its size sets the size of the noise.
        sigma: scale of the noise distribution (must be positive).

    Returns:
        The sampled tensor, on the device of ``mu`` / ``stddev``.
    """
    # Move the noise to stddev's device rather than hard-coding CUDA, so the
    # helper also runs in CPU mode.
    eps = Normal(0, sigma).sample(sample_shape=stddev.size()).to(stddev.device)
    return mu + stddev * eps  # reparameterization trick
    

### Obtain "shifting vectors"

In [4]:
# Run the lookups on whichever device the model lives on, so this cell works
# in both GPU and CPU mode.
device = next(model.parameters()).device

# Per-component means and variances for the rhythm (r) and note (n) latents.
mu_r_lst = []
var_r_lst = []
mu_n_lst = []
var_n_lst = []
for k_i in torch.arange(0, 2, device=device):  # the model has two components
    mu_r_lst.append(model.mu_r_lookup(k_i).detach().cpu())
    # Out-of-place exp(): the lookup result is never modified in place.
    var_r_lst.append(model.logvar_r_lookup(k_i).detach().exp().cpu())

    mu_n_lst.append(model.mu_n_lookup(k_i).detach().cpu())
    var_n_lst.append(model.logvar_n_lookup(k_i).detach().exp().cpu())

# "Shifting vectors": the difference between the two component means.
# NOTE(review): the names assume component 0 is the low-arousal cluster and
# component 1 the high-arousal one -- confirm against the training setup.
r_low_to_high = mu_r_lst[1] - mu_r_lst[0]
r_high_to_low = mu_r_lst[0] - mu_r_lst[1]
n_low_to_high = mu_n_lst[1] - mu_n_lst[0]
n_high_to_low = mu_n_lst[0] - mu_n_lst[1]

### Example on low arousal -> high arousal

In [5]:
# Use the model's device instead of assuming CUDA is available.
device = next(model.parameters()).device

# Source segment for the low -> high arousal example.
d, r, n, c, r_density, n_density = test_ds_dist[404]
c = torch.as_tensor(c, dtype=torch.float, device=device).unsqueeze(0)

# Work on a copy so that zeroing token 1 cannot modify the dataset's own
# storage if __getitem__ handed back a view.
d = np.array(d)
d[d == 1] = 0
# Build the index tensor as long directly; no float round-trip is needed.
d_oh = convert_to_one_hot(torch.as_tensor(d, dtype=torch.long, device=device), EVENT_DIMS)

# Render the original segment to audio.
# NOTE(review): 44100 is assumed to match the sample rate fluidsynth() renders at.
pm = magenta_decode_midi(np.trim_zeros(d))
a_1 = pm.fluidsynth()
Audio(a_1, rate=44100)

In [7]:
model.eval()
# Use the model's device instead of assuming CUDA is available.
device = next(model.parameters()).device

# Inference only: no autograd graph is needed for encoding and decoding.
# The latents are sampled (rsample) without a fixed seed, so the generated
# tokens differ from run to run.
with torch.no_grad():
    dis_r, dis_n = model.encode(d_oh.unsqueeze(0))
    z_r = dis_r.rsample()
    z_n = dis_n.rsample()

    # lmbda scales the shift; 1 moves the latent by the full difference
    # between the two component means.
    lmbda = 1
    z_r_new = z_r + lmbda * torch.as_tensor(r_low_to_high, device=device)
    z_n_new = z_n + lmbda * torch.as_tensor(n_low_to_high, device=device)

    z = torch.cat([z_r_new, z_n_new, c], dim=1)
    out = model.global_decoder(z, steps=300)

# Clean the decoder output once and reuse it for printing and synthesis.
tokens = clean_output(out)
print("Tokens:", tokens)

pm = magenta_decode_midi(tokens)
a_1 = pm.fluidsynth()
Audio(a_1, rate=44100)

Tokens: [185 301  49 181 295  33 182 290  37 190 121 188 125 179 305  51 193 294
  33 188 137 180 300  37 181 121 187 294  42 188 125 187 298  30 180 130
 193  33 187 302  27 182 281  48 181 118 180 139 183 299  37 185 121 187
 115 306  26 193 121 188 114 296  52 193 136 179 125 179 301  32 297  49
 183 140 120 137]



### Example on high arousal -> low arousal

In [9]:
# Use the model's device instead of assuming CUDA is available.
device = next(model.parameters()).device

# Source segment for the high -> low arousal example.
d, r, n, c, r_density, n_density = test_ds_dist[100]
c = torch.as_tensor(c, dtype=torch.float, device=device).unsqueeze(0)

# Work on a copy so that zeroing token 1 cannot modify the dataset's own
# storage if __getitem__ handed back a view.
d = np.array(d)
d[d == 1] = 0
# Build the index tensor as long directly; no float round-trip is needed.
d_oh = convert_to_one_hot(torch.as_tensor(d, dtype=torch.long, device=device), EVENT_DIMS)

# Render the original segment to audio.
# NOTE(review): 44100 is assumed to match the sample rate fluidsynth() renders at.
pm = magenta_decode_midi(np.trim_zeros(d))
a_1 = pm.fluidsynth()
Audio(a_1, rate=44100)

In [11]:
model.eval()
# Use the model's device instead of assuming CUDA is available.
device = next(model.parameters()).device

# Inference only: no autograd graph is needed for encoding and decoding.
# The latents are sampled (rsample) without a fixed seed, so the generated
# tokens differ from run to run.
with torch.no_grad():
    dis_r, dis_n = model.encode(d_oh.unsqueeze(0))
    z_r = dis_r.rsample()
    z_n = dis_n.rsample()

    # lmbda scales the shift; 1 moves the latent by the full difference
    # between the two component means.
    lmbda = 1
    z_r_new = z_r + lmbda * torch.as_tensor(r_high_to_low, device=device)
    z_n_new = z_n + lmbda * torch.as_tensor(n_high_to_low, device=device)

    z = torch.cat([z_r_new, z_n_new, c], dim=1)
    out = model.global_decoder(z, steps=300)

# Clean the decoder output once and reuse it for printing and synthesis.
tokens = clean_output(out)
print("Tokens:", tokens)

pm = magenta_decode_midi(tokens)
a_1 = pm.fluidsynth()
Audio(a_1, rate=44100)

Tokens: [312  44 179 306  35 182 132 188 309  51 178 314  63 182 123 200  44 183
 298  48 207 136 181 292  36 186 312  58 182 301  53 187 132 139 203 288
  46 212 124 191 134 191 146 141 151]
