In [None]:
%load_ext autoreload
%autoreload 2
import pickle

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import IPython.display as ipd
import pyximport
pyximport.install()
%load_ext Cython
import sigkernel as ksig
from utils.data import *
from model.generators import *

# Setup

In [None]:
# Data / sequence parameters
hist_len = 16  # leading steps of real history fed to the generator; excluded from the loss
sample_len = 32 #NOTE it includes the hist_len
noise_dim = 1
seq_dim = 3 # (gap, duration, pitch)
scale = 1.  # the pitch (last) channel is divided by this when converting back to notes
dpitch_range = 12
stride = 800  # placeholder: recomputed from the longest song before the dataset is built
folder = 'theorytab'
assert hist_len < sample_len, 'sample_len must leave room for a continuation after hist_len'

# Signature kernel parameters
sigma = 1.0
kernel_type = 'truncated'  # 'truncated' or 'pde'
dyadic_order = 3  # only used by the 'pde' kernel
n_levels = 5  # only used by the 'truncated' kernel
order = 1  # only used by the 'truncated' kernel

# Generator / optimisation parameters
batch_size = 16
activation = 'Tanh'
hidden_size = 64
n_layers = 1

epochs = 50
patience = 5  # ReduceLROnPlateau patience (epochs without improvement)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed the torch and numpy RNGs (noise sampling and DataLoader shuffling draw from torch)
# so that training runs are reproducible.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Song records: item[0] is the melody passed to the gap/duration transform and
# item[-1] is its cluster label.
# Other pickles in the same folder that were tried:
#   melodies_beats_min_5_unique_max_range_24.pkl
#   all_melodies_within_key_beats_aligned_min_5_unique_max_range_21.pkl
# NOTE: pickle.load can execute arbitrary code; only load files produced by this project.
songs_path = f'./data/dataframes/{folder}/melodies_beats_min_5_unique_max_range_24_spec_cluster_12.pkl'
with open(songs_path, 'rb') as songs_file:
    songs = pickle.load(songs_file)
len(songs)

In [None]:
first_song = songs[0]
first_song[1:]  # every field of the first record after the melody itself (item[0])

In [None]:
# The cluster label of each song is stored as the last field of its record.
cluster_labels = [record[-1] for record in songs]
unique_labels, counts = np.unique(cluster_labels, return_counts=True)
(unique_labels.shape, counts)

In [None]:
# Group the song records by cluster label (the last field of each record).
# Iterate over the labels that actually occur rather than range(len(unique_labels)):
# comparing against a positional index mis-groups songs whenever the labels are not
# exactly 0..K-1. df_clusters[k] holds the songs whose label is unique_labels[k].
df_clusters = []
for label in unique_labels:
    df_clusters.append([record for record in songs if record[-1] == label])
    print(label, len(df_clusters[-1]))

In [None]:
song_lengths = [len(record[0]) for record in songs]
longest_song = max(song_lengths)
print('Max length:', longest_song)
# A stride longer than every song ensures no sampling from the middle of a song.
stride = longest_song + 1

In [None]:
# Variants not active here:
#   - cluster-conditioned data: pass clusters=[item[4] for item in songs] to the dataset
#   - single-cluster training: transform [item[0] for item in df_clusters[cluster_idx]] instead
melodies = [record[0] for record in songs]
gap_dur_dpitch_dfs = gap_duration_deltapitch_transform(melodies)

dataset = GapDurationDeltaPitchDataset(
    gap_dur_dpitch_dfs,
    sample_len=sample_len,
    scale=scale,
    stride=stride,
)
# drop_last=True keeps every batch at exactly batch_size samples.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
(len(dataset), len(dataloader))

In [None]:
# Build the signature kernel used by the MMD training loss.
if kernel_type == 'truncated':
    static_kernel = ksig.static.kernels.RationalQuadraticKernel(sigma=sigma)
    # static_kernel = ksig.static.kernels.LinearKernel()
    kernel = ksig.kernels.SignatureKernel(n_levels=n_levels, order=order, normalization=0, static_kernel=static_kernel, device_ids=None)
elif kernel_type == 'pde':
    static_kernel = ksig.sigkernelpde.RationalQuadraticKernel(sigma=sigma, alpha=1.0)
    kernel = ksig.sigkernelpde.SigKernelPDE(static_kernel, dyadic_order)
else:
    # Fail here rather than with a NameError on `kernel` inside the training loop.
    raise ValueError(f"Unknown kernel_type {kernel_type!r}; expected 'truncated' or 'pde'")

In [None]:
# Alternative generators:
# generator = LSTMgate(noise_dim, seq_dim, sample_len, hidden_size, n_layers, activation)
# generator = LSTMinc(noise_dim, seq_dim, sample_len, dpitch_range, scale, hidden_size, n_layers, activation)
generator = LSTMinc_v2(noise_dim, seq_dim, sample_len, dpitch_range, scale, hidden_size, n_layers, activation)
# Use the configured device instead of a hard-coded .cuda(), so the notebook also runs on
# CPU-only machines (config falls back to 'cpu' when CUDA is unavailable).
generator = generator.to(device)
optimizer = torch.optim.Adam(generator.parameters(), lr=0.001)
# Halve the learning rate when the epoch loss has not improved for `patience` epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=patience, factor=0.5, verbose=True)

# Training

In [None]:
# Variants not active here: LSTMgate takes an extra `cluster` conditioning argument
# (requires a dataset built with clusters), and an earlier experiment shrank sigma by 0.7
# and rebuilt the kernel/scheduler whenever the epoch loss went negative.
for epoch in range(epochs):
    losses = []
    for X, title in tqdm(dataloader):
        # `title` is left bound after the loop; the sampling cells below read the last batch.
        X = X.to(device)
        history = X[:, :hist_len, :]
        continuation = X[:, hist_len:, :]
        X_rect = batch_rectilinear_with_gap_transform(continuation)

        # Noise has one fewer time step than X.
        noise = torch.randn(X.shape[0], X.shape[1] - 1, noise_dim).to(device)
        Y = generator(noise, history, continuation[:, :, :2])
        Y_rect = batch_rectilinear_with_gap_transform(Y[:, hist_len:, :])

        # MMD loss between real and generated continuations under the configured kernel.
        optimizer.zero_grad()
        loss = ksig.tests.mmd_loss_no_compile(X_rect, Y_rect, kernel)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()

    # The LR scheduler tracks the mean batch MMD of the epoch.
    epoch_loss = np.average(losses)
    scheduler.step(epoch_loss)
    print(f'Epoch {epoch+1}, loss: {epoch_loss}')

In [None]:
# NOTE(review): `key` is not defined in any visible cell, so this line raises NameError if
# uncommented as-is; define `key` (or drop it from the filename) first.
# torch.save(generator.state_dict(), f'./data/weights/gapdurdpitch_{noise_dim}z_{sample_len}l_{hist_len}h_{key}_{n_levels}m_{order}o_{hidden_size}u_{n_layers}lstm.pt')

# Evaluation

In [None]:
# NOTE(review): `key` is not defined in any visible cell, so this line raises NameError if
# uncommented as-is; define `key` (or drop it from the filename) first.
# generator.load_state_dict(torch.load(f'./data/weights/gapdurdpitch_{noise_dim}z_{sample_len}l_{hist_len}h_{key}_{n_levels}m_{order}o_{hidden_size}u_{n_layers}lstm.pt'))

### Sample and play MIDI

In [None]:
# Labels for the last training batch: `title` holds indices into `songs`;
# keep fields 1-3 of each referenced record.
X_titles = [
    tuple(record[field] for field in (1, 2, 3))
    for record in (songs[idx.item()] for idx in title)
]
start_pitch = 60  # pitch the converted sequences start from (60 = middle C in MIDI numbering)

# Undo the pitch scaling on the last channel. NOTE: this divides X and Y in place, so
# re-running the cell divides again (harmless only while scale == 1.).
for tensor in (X, Y):
    tensor[:, :, -1] /= scale

X_dfs, Y_dfs = (batch_gap_duration_pitch_to_df(tensor, start_pitch=start_pitch) for tensor in (X, Y))

In [None]:
# Plot real (X) vs generated (Y) rectilinear paths for the last batch, 4 panels per row.
np_X_rect = batch_rectilinear_with_gap_transform(X).cpu().numpy()
np_Y_rect = batch_rectilinear_with_gap_transform(Y).detach().cpu().numpy()

n_cols = 4
n_samples = np_X_rect.shape[0]
# Ceil-divide so a partial last row is still drawn (the old batch_size//4 dropped it and
# produced a 0-row grid, i.e. a crash, for batch_size < 4).
n_rows = max(1, (n_samples + n_cols - 1) // n_cols)
# squeeze=False keeps `ax` 2-D even for a single row, so one indexing path suffices.
fig, ax = plt.subplots(n_rows, n_cols, figsize=(16, 2 * n_rows), squeeze=False)
for k, axis in enumerate(ax.flat):
    if k >= n_samples:
        axis.axis('off')  # unused panel in the last row
        continue
    axis.plot(np_X_rect[k, :, 0], np_X_rect[k, :, 1] / scale, label='real')
    axis.plot(np_Y_rect[k, :, 0], np_Y_rect[k, :, 1] / scale, label='generated')
    axis.set_title((f'{k} {X_titles[k][0]} {X_titles[k][1]}')[:50])
ax[0, 0].legend()
plt.tight_layout()

In [None]:
sample_idx = 0  # which sample of the last batch to inspect and play back
sample_title = X_titles[sample_idx]
print(sample_title)

In [None]:
# Convert the chosen real (input) and generated (output) note tables with df_to_midi.
input_midi, output_midi = (df_to_midi(dfs[sample_idx]) for dfs in (X_dfs, Y_dfs))

In [None]:
# Side-by-side table joined on row index: real Start/End/Pitch (Pitch_x) next to the
# generated Pitch (Pitch_y), using merge's default suffixes.
X_dfs[sample_idx][['Start', 'End', 'Pitch']].merge(
    Y_dfs[sample_idx][['Pitch']], left_index=True, right_index=True
)

In [None]:
fs = 44100  # audio sample rate used for both synthesis and playback
soundfont_path = './data/soundfonts/Steinway_Grand_Piano_1.2.sf2'
audio_data = input_midi.fluidsynth(fs=fs, sf2_path=soundfont_path)
ipd.Audio(audio_data, rate=fs)

In [None]:
fs = 44100  # audio sample rate used for both synthesis and playback
soundfont_path = './data/soundfonts/Steinway_Grand_Piano_1.2.sf2'
audio_data = output_midi.fluidsynth(fs=fs, sf2_path=soundfont_path)
ipd.Audio(audio_data, rate=fs)

### Evaluate the percentage of notes within each major key

In [None]:
# Generate a continuation for every batch of the dataloader and convert both the real and
# generated sequences back to note dataframes for the key analysis below.
Xs = []
Ys = []
# Inference only: without no_grad every batch would keep its autograd graph alive.
# Local loop names avoid clobbering the X / Y / title shown in the sampling cells above.
with torch.no_grad():
    for X_batch, _ in dataloader:
        Xs.append(X_batch)
        X_batch = X_batch.to(device)
        noise = torch.randn(X_batch.shape[0], X_batch.shape[1] - 1, noise_dim).to(device)
        Y_batch = generator(noise, X_batch[:, :hist_len, :], X_batch[:, hist_len:, :2])
        Ys.append(Y_batch.detach().cpu())

Xs = torch.cat(Xs, dim=0)
Ys = torch.cat(Ys, dim=0)
# Undo the pitch scaling before conversion, as the sampling cell does for X / Y.
# In-place is safe here because Xs / Ys were freshly concatenated in this cell.
Xs[:, :, -1] /= scale
Ys[:, :, -1] /= scale
print(Xs.shape, Ys.shape)
all_X_dfs = batch_gap_duration_pitch_to_df(Xs, start_pitch=start_pitch)
all_Y_dfs = batch_gap_duration_pitch_to_df(Ys, start_pitch=start_pitch)

In [None]:
note_names = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
key_to_note = dict(enumerate(note_names))  # pitch class 0-11 -> note name
# key_notes[root] = whatever get_notes_from_major_scale returns for that root (0-11).
key_notes = [get_notes_from_major_scale(root) for root in range(12)]

In [None]:
# Fraction of each melody's notes that fall in each of the 12 major scales (not just C major):
# dict keys are the scale roots that index `key_notes` / `key_to_note`.
def _key_percentages(dfs, scales):
    """Return {root: [fraction of each df's 'Pitch' values that are in scales[root]]}.

    Melodies with no notes are skipped, since they would otherwise divide by zero.
    """
    percentages = {root: [] for root in range(len(scales))}
    for df in dfs:
        pitches = df['Pitch'].values
        if len(pitches) == 0:
            continue
        for root, scale_notes in enumerate(scales):
            n_in_scale = sum(1 for note in pitches if note in scale_notes)
            percentages[root].append(n_in_scale / len(pitches))
    return percentages


X_key_percentages = _key_percentages(all_X_dfs, key_notes)
Y_key_percentages = _key_percentages(all_Y_dfs, key_notes)

In [None]:
# Per-scale histograms for the REAL melodies; the suptitle distinguishes this figure from
# the generated one, and the taller figure keeps the 4 rows readable.
fig, ax = plt.subplots(4, 3, figsize=(15, 10))
fig.suptitle('Real melodies (X): fraction of notes in each major scale')
print(f'Key: {key_to_note[start_pitch % 12]}')
for root in range(len(key_notes)):
    print(f'Mean percentage of notes in key {key_to_note[root]}: {np.mean(X_key_percentages[root])}')
    panel = ax[root // 3, root % 3]
    panel.hist(X_key_percentages[root], bins=100)
    panel.set_title(f'{key_to_note[root]} major scale')
plt.tight_layout()

In [None]:
# Per-scale histograms for the GENERATED melodies; the suptitle distinguishes this figure
# from the real one, and the taller figure keeps the 4 rows readable.
fig, ax = plt.subplots(4, 3, figsize=(15, 10))
fig.suptitle('Generated melodies (Y): fraction of notes in each major scale')
print(f'Key: {key_to_note[start_pitch % 12]}')
for root in range(len(key_notes)):
    print(f'Mean percentage of notes in key {key_to_note[root]}: {np.mean(Y_key_percentages[root])}')
    panel = ax[root // 3, root % 3]
    panel.hist(Y_key_percentages[root], bins=100)
    panel.set_title(f'{key_to_note[root]} major scale')
plt.tight_layout()