In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import IPython.display as ipd
import pyximport
pyximport.install()
%load_ext Cython

import sigkernel as ksig
from utils.midi import *
from utils.data import *
from model.generators import *

In [2]:
hist_len = 50
sample_len = 100
seq_dim = 4
scale = 1.

In [3]:
path = './data/maestro-v3.0.0_midi/2004/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi'
midi_data = pretty_midi.PrettyMIDI(path)
df = midi_to_df(midi_data)
dataset = DFDataset(df, sample_len, scale, stride=20, col_idx=list(range(4)))
dataloader = DataLoader(dataset, batch_size=50, shuffle=True, num_workers=0)

In [4]:
for x in dataloader:
    print(x)
    break

tensor([[[0.0000e+00, 1.4169e-01, 5.0000e+01, 7.0000e+01],
         [1.8646e-01, 2.1875e-01, 5.7000e+01, 7.0000e+01],
         [1.9583e-01, 3.5834e-01, 7.9000e+01, 7.1000e+01],
         ...,
         [9.0844e+00, 9.2167e+00, 4.7000e+01, 6.7000e+01],
         [9.2365e+00, 9.4115e+00, 4.0000e+01, 6.6000e+01],
         [9.2510e+00, 9.2906e+00, 7.1000e+01, 7.0000e+01]],

        [[0.0000e+00, 1.6250e-01, 7.6000e+01, 7.9000e+01],
         [2.0104e-01, 4.6458e-01, 7.9000e+01, 7.6000e+01],
         [4.1979e-01, 4.8021e-01, 7.2000e+01, 7.0000e+01],
         ...,
         [1.2309e+01, 1.2561e+01, 7.5000e+01, 6.9000e+01],
         [1.2516e+01, 1.3089e+01, 5.2000e+01, 6.1000e+01],
         [1.2522e+01, 1.4464e+01, 7.6000e+01, 6.5000e+01]],

        [[0.0000e+00, 1.3541e-01, 7.9000e+01, 8.2000e+01],
         [2.0752e-03, 1.5938e-01, 5.2000e+01, 7.6000e+01],
         [1.2604e-01, 2.1875e-01, 8.1000e+01, 8.0000e+01],
         ...,
         [6.7469e+00, 6.8292e+00, 7.8000e+01, 6.2000e+01],
         [

In [5]:
generator = TransformerMusic(seq_dim, sample_len, hist_len, scale=scale,
                             kernel_size=5, stride=1, n_channels=16, n_head=4, n_transformer_layers=1, hidden_size=128, activation='GELU')
generator = generator.cuda()
optimizer = torch.optim.Adam(generator.parameters(), lr=0.0005)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, factor=0.5, verbose=True)

In [6]:
static_kernel = ksig.static.kernels.RationalQuadraticKernel(sigma=0.1)
kernel = ksig.kernels.SignatureKernel(n_levels=5, order=5, normalization=0, static_kernel=static_kernel, device_ids=None)

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
for epoch in range(100):
    losses = [] # due to legacy code, losses is actually the mmd
    for batch_num, X in enumerate(tqdm(dataloader)):
        X = X.to(device)

        output = generator(X)
        X_wo_hist = X[:, hist_len:, :]

        # compute loss
        optimizer.zero_grad()
        loss = ksig.tests.mmd_loss_no_compile(X_wo_hist, output, kernel)
        losses.append(loss.item())

        # backpropagate and update weights
        loss.backward()
        optimizer.step()

    # log epoch loss and plot generated samples
    epoch_loss = np.average(losses) # average batch mmd for epoch
    scheduler.step(epoch_loss)
    print(f'Epoch {epoch}, loss: {epoch_loss}')

100%|██████████| 8/8 [00:12<00:00,  1.59s/it]


Epoch 0, loss: 0.00034290552139282227


100%|██████████| 8/8 [00:12<00:00,  1.53s/it]


Epoch 1, loss: 0.0002301335334777832


100%|██████████| 8/8 [00:12<00:00,  1.52s/it]


Epoch 2, loss: 0.00027191638946533203


100%|██████████| 8/8 [00:12<00:00,  1.52s/it]


Epoch 3, loss: 0.0002333521842956543


100%|██████████| 8/8 [00:12<00:00,  1.54s/it]


Epoch 4, loss: 0.0003839433193206787


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 5, loss: 0.0003235340118408203


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 6, loss: 0.0005273222923278809


100%|██████████| 8/8 [00:12<00:00,  1.54s/it]


Epoch 7, loss: 0.0003798156976699829


100%|██████████| 8/8 [00:12<00:00,  1.54s/it]


Epoch 8, loss: 0.0003165900707244873


100%|██████████| 8/8 [00:12<00:00,  1.54s/it]


Epoch 9, loss: 0.00035956501960754395


100%|██████████| 8/8 [00:12<00:00,  1.54s/it]


Epoch 10, loss: 0.00029271841049194336


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 11, loss: 0.0007114112377166748


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 00013: reducing learning rate of group 0 to 2.5000e-04.
Epoch 12, loss: 0.00038629770278930664


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 13, loss: 0.00026279687881469727


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 14, loss: 0.000368267297744751


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 15, loss: 0.0005691349506378174


100%|██████████| 8/8 [00:12<00:00,  1.56s/it]


Epoch 16, loss: 0.0001999884843826294


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 17, loss: 0.0004608035087585449


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 18, loss: 0.00015558302402496338


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 19, loss: 0.00038686394691467285


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 20, loss: 0.0003968477249145508


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 21, loss: 0.0005238354206085205


100%|██████████| 8/8 [00:12<00:00,  1.56s/it]


Epoch 22, loss: 0.00024968385696411133


100%|██████████| 8/8 [00:12<00:00,  1.51s/it]


Epoch 23, loss: 0.000435560941696167


100%|██████████| 8/8 [00:12<00:00,  1.53s/it]


Epoch 24, loss: 0.00025266408920288086


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 25, loss: 0.00040775537490844727


100%|██████████| 8/8 [00:12<00:00,  1.56s/it]


Epoch 26, loss: 0.00021901726722717285


100%|██████████| 8/8 [00:12<00:00,  1.54s/it]


Epoch 27, loss: 0.0003939718008041382


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 28, loss: 0.0003119558095932007


100%|██████████| 8/8 [00:12<00:00,  1.54s/it]


Epoch 00030: reducing learning rate of group 0 to 1.2500e-04.
Epoch 29, loss: 0.00040653347969055176


100%|██████████| 8/8 [00:12<00:00,  1.54s/it]


Epoch 30, loss: 0.00042259693145751953


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 31, loss: 0.000431939959526062


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 32, loss: 0.00029665231704711914


100%|██████████| 8/8 [00:12<00:00,  1.57s/it]


Epoch 33, loss: 0.000544816255569458


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 34, loss: 0.0003367960453033447


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 35, loss: 0.00028520822525024414


100%|██████████| 8/8 [00:12<00:00,  1.54s/it]


Epoch 36, loss: 0.00044989585876464844


100%|██████████| 8/8 [00:12<00:00,  1.56s/it]


Epoch 37, loss: 0.0003837496042251587


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 38, loss: 0.0002539306879043579


100%|██████████| 8/8 [00:12<00:00,  1.56s/it]


Epoch 39, loss: 0.000138893723487854


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 40, loss: 0.0003598034381866455


100%|██████████| 8/8 [00:12<00:00,  1.56s/it]


Epoch 41, loss: 0.00043645501136779785


100%|██████████| 8/8 [00:12<00:00,  1.52s/it]


Epoch 42, loss: 0.00036288797855377197


100%|██████████| 8/8 [00:12<00:00,  1.57s/it]


Epoch 43, loss: 0.0005569905042648315


100%|██████████| 8/8 [00:12<00:00,  1.56s/it]


Epoch 44, loss: 0.0004912614822387695


100%|██████████| 8/8 [00:12<00:00,  1.56s/it]


Epoch 45, loss: 0.0003471076488494873


100%|██████████| 8/8 [00:12<00:00,  1.56s/it]


Epoch 46, loss: 0.00036209821701049805


100%|██████████| 8/8 [00:12<00:00,  1.56s/it]


Epoch 47, loss: 0.00025963783264160156


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 48, loss: 0.00038358569145202637


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 49, loss: 0.0002868473529815674


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 00051: reducing learning rate of group 0 to 6.2500e-05.
Epoch 50, loss: 0.0005068480968475342


100%|██████████| 8/8 [00:12<00:00,  1.57s/it]


Epoch 51, loss: 0.0005899369716644287


100%|██████████| 8/8 [00:12<00:00,  1.56s/it]


Epoch 52, loss: 0.0005483031272888184


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 53, loss: 0.00035144388675689697


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 54, loss: 0.00036522746086120605


100%|██████████| 8/8 [00:12<00:00,  1.56s/it]


Epoch 55, loss: 0.00029599666595458984


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 56, loss: 0.0005617588758468628


100%|██████████| 8/8 [00:12<00:00,  1.51s/it]


Epoch 57, loss: 0.0005692839622497559


100%|██████████| 8/8 [00:12<00:00,  1.50s/it]


Epoch 58, loss: 0.0003655552864074707


100%|██████████| 8/8 [00:12<00:00,  1.51s/it]


Epoch 59, loss: 0.0005972981452941895


100%|██████████| 8/8 [00:12<00:00,  1.53s/it]


Epoch 60, loss: 0.0003979504108428955


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 00062: reducing learning rate of group 0 to 3.1250e-05.
Epoch 61, loss: 0.0002792775630950928


100%|██████████| 8/8 [00:12<00:00,  1.56s/it]


Epoch 62, loss: 0.0005373656749725342


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 63, loss: 0.0002422630786895752


100%|██████████| 8/8 [00:12<00:00,  1.54s/it]


Epoch 64, loss: 0.0002784132957458496


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 65, loss: 0.00029721856117248535


100%|██████████| 8/8 [00:12<00:00,  1.56s/it]


Epoch 66, loss: 0.0002823173999786377


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 67, loss: 0.0005319267511367798


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 68, loss: 0.0005060732364654541


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 69, loss: 0.0003104954957962036


100%|██████████| 8/8 [00:12<00:00,  1.55s/it]


Epoch 70, loss: 0.0006821900606155396


100%|██████████| 8/8 [00:12<00:00,  1.53s/it]


Epoch 71, loss: 0.0004452913999557495


100%|██████████| 8/8 [00:12<00:00,  1.51s/it]


Epoch 00073: reducing learning rate of group 0 to 1.5625e-05.
Epoch 72, loss: 0.0002434849739074707


 50%|█████     | 4/8 [00:07<00:07,  1.95s/it]


KeyboardInterrupt: 

In [8]:
for x in dataloader:
    x = x.to(device)
    output = generator(x)
    break

In [9]:
i = 0
x_np = x[i].cpu().detach().numpy()
x_np[:, 2:] = np.round(x_np[:, 2:] * scale)
output_np = output[i].cpu().detach().numpy()
output_np[:, 2:] = np.round(output_np[:, 2:] * scale)
df_x = pd.DataFrame(x_np, columns=['Start', 'End', 'Pitch', 'Velocity'])
df_output = pd.DataFrame(output_np, columns=['Start', 'End', 'Pitch', 'Velocity'])

In [10]:
df_input = df_x.iloc[hist_len:]
df_input.iloc[:,:2] = df_input.iloc[:,:2] - df_input.iloc[0,0]
df_input

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_input.iloc[:,:2] = df_input.iloc[:,:2] - df_input.iloc[0,0]


Unnamed: 0,Start,End,Pitch,Velocity
50,0.0,0.088562,62.0,78.0
51,0.020874,0.153137,79.0,64.0
52,0.037537,0.09375,77.0,70.0
53,0.136475,0.188538,67.0,74.0
54,0.146912,0.185425,84.0,61.0
55,0.1521,0.185425,83.0,67.0
56,0.276062,0.427124,60.0,76.0
57,0.2771,0.369812,76.0,83.0
58,0.414612,0.491699,64.0,66.0
59,0.43335,0.485413,79.0,57.0


In [11]:
df_output.iloc[:,:2] = df_output.iloc[:,:2] - df_output.iloc[0,0]
df_output

Unnamed: 0,Start,End,Pitch,Velocity
0,0.0,0.785002,0.0,1.0
1,0.016623,0.172844,61.0,36.0
2,0.060726,1.117059,113.0,115.0
3,0.073077,2.074643,30.0,22.0
4,0.084586,1.184121,73.0,75.0
5,0.086188,0.33914,42.0,121.0
6,0.095097,2.512566,61.0,4.0
7,0.11956,0.799353,84.0,17.0
8,0.144665,1.263191,89.0,105.0
9,0.184938,0.58349,77.0,23.0


In [12]:
input_midi = df_to_midi(df_input)
output_midi = df_to_midi(df_output)

In [15]:
Fs = 22050
audio_data = input_midi.synthesize(fs=Fs)
ipd.Audio(audio_data, rate=Fs)

In [16]:
audio_data = output_midi.synthesize(fs=Fs)
ipd.Audio(audio_data, rate=Fs)