In [None]:
%load_ext autoreload
%autoreload 2
import pickle
import numpy as np
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import IPython.display as ipd
import pyximport
pyximport.install()
%load_ext Cython
import sigkernel as ksig
from utils.midi import *
from utils.data import *
from model.generators import *

In [None]:
# ---- data ----
hist_len = 10        # leading steps (axis 1 of each batch) given to the generator as history; sliced off the training target
sample_len = 30      # window length passed to both NoteDurationDataset and the generator
seq_dim = 2          # features per step (presumably note + duration, per NoteDurationDataset — confirm)
scale = 1.           # scaling factor passed to both the dataset and the generator
stride = 10          # dataset window stride (unrelated to the conv `stride=1` in the generator cell)
max_pitch = 32       # passed to the generator
pitch_offset = 47    # used by tensor_to_df to map model pitches back; matches the `offset_47` data file

# ---- signature kernel ----
sigma = 1.0               # static (rational-quadratic) kernel parameter
kernel_type = 'truncated' # 'truncated' or 'pde'
dyadic_order = 3          # only used when kernel_type == 'pde'
n_levels = 5              # only used when kernel_type == 'truncated'

# ---- model / training ----
batch_size = 64
rectilinear = True        # NOTE(review): not referenced anywhere else in this notebook
activation = 'GELU'
hidden_size = 256
n_transformer_layers = 1
n_head = 4
n_channels = 32

# ---- reproducibility ----
# Training and sampling are stochastic (weight init, shuffled DataLoader, random sample pick),
# so seed the global RNGs once here to make Restart & Run All repeatable.
seed = 0
np.random.seed(seed)      # drives the `sample_idx` pick
torch.manual_seed(seed)   # seeds CPU and CUDA RNGs; DataLoader shuffling draws from it

In [None]:
# Alternative dataset, swap in if needed:
#   './data/dataframes/min_note_50_min_gap_0/dfs_note_dur_offset_47.pkl'
# NOTE: pickle.load can execute arbitrary code — only load files you produced or trust.
data_path = './data/dataframes/pop/dfs_note_dur_offset_47.pkl'
with open(data_path, 'rb') as pkl_file:
    dfs = pickle.load(pkl_file)

In [None]:
# Inspect the loaded data's pitch span; presumably what motivated `max_pitch` / `pitch_offset` in the config cell — confirm.
pitch_range(dfs)

In [None]:
dataset = NoteDurationDataset(dfs, sample_len=sample_len, scale=scale, stride=stride)
# drop_last=True keeps every batch at exactly `batch_size` samples.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)
# With drop_last=True, fewer than `batch_size` samples means zero batches: training would average an
# empty loss list (NaN) and the sampling cell would have nothing to draw. Fail early instead.
assert len(dataset) >= batch_size, f'dataset has {len(dataset)} samples, fewer than batch_size={batch_size}'

In [None]:
# Pick the device here instead of hard-coding .cuda(), so the notebook also runs on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Note: `stride=1` below is the conv-layer stride, unrelated to the dataset window `stride` in the config cell.
generator = TransformerMusic(seq_dim, sample_len, max_pitch, hist_len, scale,  # data related
                             kernel_size=5, stride=1, n_channels=n_channels,  # conv layers
                             n_head=n_head, n_transformer_layers=n_transformer_layers,  # transformer layers
                             hidden_size=hidden_size, activation=activation)
# Move the model before constructing the optimizer so Adam holds the on-device parameters.
generator = generator.to(device)
optimizer = torch.optim.Adam(generator.parameters(), lr=0.001)
# Halve the LR once the epoch loss passed to scheduler.step() has not improved for more than 3 epochs.
# NOTE: `verbose` is deprecated in recent PyTorch releases; kept so LR reductions are still printed.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.5, verbose=True)

In [None]:
# Build the signature kernel used by the MMD training loss, selected by `kernel_type`.
if kernel_type == 'truncated':
    # Truncated signature kernel (n_levels levels) over a rational-quadratic static kernel.
    static_kernel = ksig.static.kernels.RationalQuadraticKernel(sigma=sigma)
    kernel = ksig.kernels.SignatureKernel(n_levels=n_levels, order=n_levels, normalization=0, static_kernel=static_kernel, device_ids=None)
elif kernel_type == 'pde':
    # PDE-based signature kernel; `dyadic_order` is only consumed on this branch.
    static_kernel = ksig.sigkernelpde.RationalQuadraticKernel(sigma=sigma, alpha=1.0)
    kernel = ksig.sigkernelpde.SigKernelPDE(static_kernel, dyadic_order)
else:
    # Fail loudly: otherwise `kernel` would be undefined, or silently keep a stale value from an earlier run.
    raise ValueError(f"kernel_type must be 'truncated' or 'pde', got {kernel_type!r}")

In [None]:
# Train the generator by minimising the signature-kernel MMD between real and generated sequences.
n_epochs = 30
# Send data to wherever the generator's weights live, so inputs and model can never be on different devices.
device = next(generator.parameters()).device
generator.train()
for epoch in range(n_epochs):
    losses = []  # due to legacy code, `losses` actually holds the per-batch MMD values
    for X in tqdm(dataloader, desc=f'Epoch {epoch+1}'):
        X = X.to(device)

        optimizer.zero_grad()
        output = generator(X)
        # Target: the real sequence with its conditioning history (first `hist_len` steps) removed.
        X_wo_hist = X[:, hist_len:, :]

        # compute loss
        loss = ksig.tests.mmd_loss_no_compile(X_wo_hist, output, kernel)
        losses.append(loss.item())

        # backpropagate and update weights
        loss.backward()
        optimizer.step()

    if not losses:
        # An empty loader (e.g. dataset smaller than one batch with drop_last=True) would otherwise
        # produce a NaN epoch loss and silently feed it to the scheduler.
        raise RuntimeError('dataloader produced no batches; check len(dataset) against batch_size')

    # log epoch loss
    epoch_loss = float(np.mean(losses))  # average batch MMD for the epoch
    scheduler.step(epoch_loss)
    print(f'Epoch {epoch+1}, loss: {epoch_loss}')

In [None]:
# Draw one (shuffled) real batch and generate continuations for it.
x = next(iter(dataloader)).to(device)
# Sampling only: skip building an autograd graph, so `output` does not require grad downstream.
# NOTE(review): the generator is left in train mode; if it contains dropout, consider generator.eval().
with torch.no_grad():
    output = generator(x)

In [None]:
# Random batch element to inspect and play back; every batch has exactly `batch_size` items (drop_last=True).
sample_idx = np.random.randint(low=0, high=batch_size)

In [None]:
# Convert the generated batch back to DataFrames (presumably one per batch element) and show the chosen one.
out_dfs = tensor_to_df(output, pitch_offset)
chosen_out_df = out_dfs[sample_idx]
chosen_out_df

In [None]:
# Same conversion for the real batch `x`; this presumably still includes the `hist_len` history steps,
# whereas training compares `output` against x[:, hist_len:].
in_dfs = tensor_to_df(x, pitch_offset)
chosen_in_df = in_dfs[sample_idx]
chosen_in_df

In [None]:
# Render the chosen real sequence, then the chosen generated sequence, to MIDI objects.
input_midi, output_midi = (df_to_midi(frames[sample_idx]) for frames in (in_dfs, out_dfs))

In [None]:
# Sample rate (Hz) shared by synthesis and playback; the generated-audio cell reuses it.
Fs = 22050
audio_data = input_midi.synthesize(fs=Fs)  # real sequence
ipd.Audio(data=audio_data, rate=Fs)

In [None]:
# Generated sequence, played at the same sample rate as the real one for a fair comparison.
audio_data = output_midi.synthesize(fs=Fs)
ipd.Audio(data=audio_data, rate=Fs)