In [None]:
%load_ext autoreload
%autoreload 2
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import IPython.display as ipd
import pyximport
pyximport.install()
%load_ext Cython
import sigkernel as ksig
from utils.midi import *
from utils.data import *
from model.generators import *

In [None]:
# --- Data / sequence settings ---
hist_len = 5        # leading steps (axis 1) fed to the generator as conditioning history
sample_len = 35     # sample length passed to NoteDurationDataset and LSTMusic
seq_dim = 2         # feature dimension passed to LSTMusic (presumably pitch + duration; TODO confirm)
scale = 1.          # passed to NoteDurationDataset
stride = 40         # passed to NoteDurationDataset
min_notes = sample_len #NOTE: length of tensor might be longer than min_notes due to rectilinear transformation
min_gap = 0.
max_pitch = 38      # passed to LSTMusic
pitch_offset = 47   # passed to tensor_to_df when decoding tensors back into DataFrames

# --- Model / training settings ---
batch_size = 50
rectilinear = True  # whether the training loop applies batch_rectilinear_transform
activation = 'GELU'
hidden_size = 128
n_layers = 1
# NOTE(review): min_notes, min_gap, n_head and n_channels are not referenced by
# any cell visible in this notebook; confirm they are still needed.
n_head = 4
n_channels = 16

# --- Reproducibility ---
# Shuffling in the DataLoader and weight initialisation are random; seed them so
# Restart & Run All reproduces the same run (torch.manual_seed seeds all devices,
# though cuDNN kernels may still be non-deterministic).
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Load the pickled DataFrames (`dfs`) consumed by pitch_range and NoteDurationDataset.
# NOTE(review): the directory name encodes min_note_50, whereas the config cell
# sets min_notes = sample_len (35); confirm this is the intended dataset.
# SECURITY: pickle.load can execute arbitrary code -- only load files you trust.
dfs_path = './data/dataframes/min_note_50_min_gap_0/dfs_note_dur_offset_47.pkl'
with open(dfs_path, mode='rb') as pkl_file:
    dfs = pickle.load(pkl_file)

In [None]:
# Summarise the pitch range of the loaded data (presumably to sanity-check
# max_pitch / pitch_offset in the config cell; TODO confirm).
pitch_summary = pitch_range(dfs)
pitch_summary

In [None]:
# Build the training dataset from the loaded DataFrames; windowing is done inside
# NoteDurationDataset (presumably sample_len-long windows every `stride` notes; TODO confirm).
dataset = NoteDurationDataset(dfs, sample_len=sample_len, scale=scale, stride=stride)
# drop_last=True keeps every batch at exactly batch_size, but leaves the loader
# empty when the dataset holds fewer than batch_size samples.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)
# Fail fast: an empty loader would make the training loop average an empty loss
# list and step the LR scheduler on NaN.
assert len(dataloader) > 0, (
    f'dataset has {len(dataset)} samples, fewer than batch_size={batch_size}; '
    'no full batch can be formed with drop_last=True'
)

In [None]:
# Use the GPU when available, otherwise fall back to CPU instead of crashing on
# a hard-coded .cuda(); this matches the device used by the training loop.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = LSTMusic(seq_dim, sample_len, max_pitch, hidden_size, n_layers, activation).to(device)
optimizer = torch.optim.Adam(generator.parameters(), lr=0.001)
# Multiply the LR by 0.5 once the value passed to scheduler.step() (min mode by
# default) has not improved for more than `patience` consecutive steps.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.5, verbose=True)

In [None]:
# Base (static) kernel handed to the signature kernel below.
static_kernel = ksig.static.kernels.RationalQuadraticKernel(sigma=0.1)

# Signature kernel used by the MMD loss in the training loop.
kernel = ksig.kernels.SignatureKernel(
    n_levels=5,
    order=5,
    normalization=0,
    static_kernel=static_kernel,
    device_ids=None,
)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
n_epochs = 30


def to_path(batch):
    """Apply the rectilinear transform only when `rectilinear` is enabled in the config cell."""
    return batch_rectilinear_transform(batch) if rectilinear else batch


generator.train()  # put the module in training mode (a later cell may have switched it to eval)
for epoch in range(n_epochs):
    losses = []
    for X in tqdm(dataloader):
        X = X.to(device)
        # Axis 1 is split at hist_len: the leading steps condition the generator,
        # the remainder is the real continuation the generated one is compared to.
        history, target = X[:, :hist_len, :], X[:, hist_len:, :]
        sample = to_path(target)

        optimizer.zero_grad()
        output = to_path(generator(history))

        # MMD between real and generated continuations under the signature kernel.
        loss = ksig.tests.mmd_loss_no_compile(sample, output, kernel)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    # An empty epoch would otherwise yield a NaN average and feed it to the scheduler.
    if not losses:
        raise RuntimeError('dataloader yielded no batches; cannot compute epoch loss')
    epoch_loss = np.average(losses)  # average batch mmd for epoch
    scheduler.step(epoch_loss)
    print(f'Epoch {epoch}, loss: {epoch_loss}')

In [None]:
# Take one (shuffled) batch and let the generator continue each sequence from its
# first hist_len steps. Inference only: eval mode and no autograd graph, so the
# tensors handed to tensor_to_df below carry no gradient history.
generator.eval()
with torch.no_grad():
    x = next(iter(dataloader)).to(device)
    history = x[:, :hist_len, :]
    # Prepend the conditioning history so it precedes the generated continuation.
    output = torch.cat((history, generator(history)), dim=1)

In [None]:
# Decode the generated batch (history + continuation) into per-sample DataFrames,
# indexed by batch position below. pitch_offset is forwarded, presumably to map
# pitch indices back to MIDI pitch numbers (TODO confirm against tensor_to_df).
out_dfs = tensor_to_df(output, pitch_offset)

In [None]:
# Decode the real batch `x` the same way, for side-by-side comparison with out_dfs.
in_dfs = tensor_to_df(x, pitch_offset)

In [None]:
sample_idx = 1  # which element of the batch to listen to

# Render that batch element twice: the real data and its history + generated counterpart.
input_midi, output_midi = (
    df_to_midi(batch_dfs[sample_idx]) for batch_dfs in (in_dfs, out_dfs)
)

In [None]:
Fs = 22050  # audio sample rate (Hz) used for synthesis and playback

# Synthesize the real sample and embed an audio player.
audio_data = input_midi.synthesize(fs=Fs)
ipd.Audio(data=audio_data, rate=Fs)

In [None]:
# Synthesize the history + generated sample at the same rate for comparison.
audio_data = output_midi.synthesize(fs=Fs)
ipd.Audio(data=audio_data, rate=Fs)