In [None]:
%load_ext autoreload
%autoreload 2
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import IPython.display as ipd
import pyximport
pyximport.install()
%load_ext Cython
import sigkernel as ksig
from utils.data import *
from model.generators import *

In [None]:
# --- Data / model shape ---
hist_len = 10        # leading steps of each sample; excluded from the MMD loss (X[:, hist_len:])
sample_len = 30      # sequence length passed to the dataset and generator
noise_dim = 2        # size of the last dimension of the generator's noise input
seq_dim = 3          # per-step feature dimension passed to TransInc
scale = 1.           # passed through to GapDurationDeltaPitchDataset
stride = 800         # placeholder: overwritten once the longest sequence length is known
dpitch_range = 24    # passed to TransInc
key = 'single_key'   # sub-directory under ./data/dataframes/ to load

# --- Signature kernel ---
sigma = 1.0               # static RationalQuadraticKernel parameter
kernel_type = 'truncated' # 'truncated' or 'pde'
dyadic_order = 3          # only used when kernel_type == 'pde'
n_levels = 5              # only used when kernel_type == 'truncated'
order = 1                 # only used when kernel_type == 'truncated'

# --- Training / architecture ---
batch_size = 64
activation = 'Tanh'
hidden_size = 256
conv_kernel_size = 4
conv_stride = 1
n_transformer_layers = 1
n_head = 4
n_channels = 32

# Seed every RNG the notebook relies on (data shuffling, noise sampling, weight init)
# so Restart & Run All is reproducible. torch.manual_seed also seeds all CUDA devices.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Load the pre-computed per-song records for this key.
# NOTE(review): records appear to hold a dataframe at [0], title fields at [1:4] and the
# cluster label at [-1] (see how later cells index them) — confirm against the producer.
# NOTE(review): pickle.load can execute arbitrary code; only load files you generated yourself.
clusters_path = f'./data/dataframes/{key}/df_titles_cluster_30_min_notes_pitch_range_5_24.pkl'
with open(clusters_path, 'rb') as pkl_file:
    df_titles_clusters = pickle.load(pkl_file)
len(df_titles_clusters)

In [None]:
# The cluster label is the last field of each record; count records per label.
cluster_labels = [record[-1] for record in df_titles_clusters]
unique_labels, counts = np.unique(cluster_labels, return_counts=True)
(unique_labels.shape, counts)

In [None]:
# Group records by cluster: df_clusters[k] holds every record whose label equals k.
# NOTE(review): this assumes the labels are exactly 0..K-1; any other label value
# (e.g. -1) would not be assigned to any group.
df_clusters = [
    [record for record in df_titles_clusters if record[-1] == label_idx]
    for label_idx in range(len(unique_labels))
]
for label_idx, members in enumerate(df_clusters):
    print(label_idx, len(members))

In [None]:
# Length of each record's dataframe; the dataset stride is set just past the longest one.
lens = [len(record[0]) for record in df_titles_clusters]
longest = max(lens)
print('Max length:', longest)
stride = longest + 1  # ensures no sampling from middle of song

In [None]:
# Alternative (disabled): train on every cluster and pass the cluster labels to the dataset.
# gap_dur_dpitch_dfs = gap_duration_deltapitch_transform([item[0] for item in df_titles_clusters])
# clusters = [item[4] for item in df_titles_clusters]
# dataset = GapDurationDeltaPitchDataset(gap_dur_dpitch_dfs, sample_len=sample_len, scale=scale, stride=stride, clusters=clusters)

# Train on a single cluster: df_clusters[k] holds the records whose cluster label is k.
train_cluster = 1
gap_dur_dpitch_dfs = gap_duration_deltapitch_transform([item[0] for item in df_clusters[train_cluster]])
dataset = GapDurationDeltaPitchDataset(gap_dur_dpitch_dfs, sample_len=sample_len, scale=scale, stride=stride)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

# With drop_last=True a dataset smaller than one batch yields zero batches, so every
# training epoch would average an empty loss list. Fail early with a clear message instead.
assert len(dataloader) > 0, (
    f'cluster {train_cluster} gives {len(dataset)} samples, fewer than batch_size={batch_size}'
)
len(dataset), len(dataloader)

In [None]:
# Build the signature kernel used by the MMD loss; `kernel_type` selects the backend.
if kernel_type == 'truncated':
    # Uses n_levels / order from the config cell.
    static_kernel = ksig.static.kernels.RationalQuadraticKernel(sigma=sigma)
    kernel = ksig.kernels.SignatureKernel(n_levels=n_levels, order=order, normalization=0, static_kernel=static_kernel, device_ids=None)
elif kernel_type == 'pde':
    # Uses dyadic_order from the config cell.
    static_kernel = ksig.sigkernelpde.RationalQuadraticKernel(sigma=sigma, alpha=1.0)
    kernel = ksig.sigkernelpde.SigKernelPDE(static_kernel, dyadic_order)
else:
    # Without this branch `kernel` would be undefined — or silently stale from an earlier run.
    raise ValueError(f"Unknown kernel_type {kernel_type!r}; expected 'truncated' or 'pde'")

In [None]:
# Generator plus optimizer and LR schedule; argument groups mirror the config cell.
generator = TransInc(noise_dim, seq_dim, sample_len, hist_len, dpitch_range, # data related
                             kernel_size=conv_kernel_size, stride=conv_stride, n_channels=n_channels, # conv layers
                             n_head=n_head, n_transformer_layers=n_transformer_layers, # transformer layers
                             hidden_size=hidden_size, activation=activation)
# Move to the configured device rather than hard-coding .cuda(), so the notebook also runs on CPU.
generator = generator.to(device)
optimizer = torch.optim.Adam(generator.parameters(), lr=0.001)
# Stepped once per epoch with the epoch-average loss: halves the LR after 5 epochs without improvement.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5, verbose=True)

In [None]:
n_epochs = 10

# Ensure training mode even if the model was switched to eval mode earlier in the session
# (matters for layers such as dropout, if TransInc uses them).
generator.train()
for epoch in range(n_epochs):
    losses = []
    for items in tqdm(dataloader):
        # X, title, cluster = items
        # cluster = cluster.to(device).unsqueeze(-1)
        X, title = items

        X = X.to(device)
        # Only the steps after the first `hist_len` enter the loss.
        X_rect = batch_rectilinear_with_gap_transform(X[:, hist_len:, :])

        # Sample noise directly on the target device (avoids a host -> device copy).
        noise = torch.randn(X.shape[0], X.shape[1] - 1, noise_dim, device=device)
        Y = generator(X, noise)
        Y_rect = batch_rectilinear_with_gap_transform(Y[:, hist_len:, :])

        # compute loss
        optimizer.zero_grad()
        loss = ksig.tests.mmd_loss_no_compile(X_rect, Y_rect, kernel)
        losses.append(loss.item())

        # backpropagate and update weights
        loss.backward()
        optimizer.step()

    if not losses:
        raise RuntimeError('dataloader produced no batches; dataset is smaller than batch_size')

    # log epoch loss (average batch MMD for the epoch) and update the LR schedule
    epoch_loss = np.mean(losses)
    scheduler.step(epoch_loss)
    print(f'Epoch {epoch+1}, loss: {epoch_loss}')

# Evaluation

In [None]:
# Optional: load previously saved weights instead of training.
# NOTE(review): this path references `n_layers`, which is not defined anywhere in this notebook,
# and its `..._lstm.pt` suffix suggests an older LSTM generator — update before uncommenting.
# generator.load_state_dict(torch.load(f'./data/weights/gapdurdpitch_{noise_dim}z_{sample_len}l_{hist_len}h_{key}_{n_levels}m_{order}o_{hidden_size}u_{n_layers}lstm.pt'))

In [None]:
# Legacy evaluation batch for the cluster-conditioned variant (kept for reference only).
# NOTE(review): this unpacks (X, title, cluster) and calls generator(noise, cluster, history, ...),
# whereas the training cell uses `X, title = items` and `generator(X, noise)`; it will not run
# as-is against the current dataset/generator.
# for items in dataloader:
#     X, title, cluster = items
#     X = X.to(device)
#     cluster = cluster.to(device).unsqueeze(-1)
#     X_rect = batch_rectilinear_with_gap_transform(X[:, hist_len:, :])

#     noise = torch.randn(X.shape[0], X.shape[1]-1, noise_dim).to(device)
#     Y = generator(noise, cluster, X[:, :hist_len, :], X[:, hist_len:, :2])

#     Y_rect = batch_rectilinear_with_gap_transform(Y[:, hist_len:, :])
#     break

### Sample and play MIDI

In [None]:
# Evaluate on a fresh batch, without gradient tracking, instead of reusing the
# `X`, `Y`, `title` left over from the last training iteration.
generator.eval()
with torch.no_grad():
    X, title = next(iter(dataloader))
    X = X.to(device)
    noise = torch.randn(X.shape[0], X.shape[1] - 1, noise_dim, device=device)
    Y = generator(X, noise)

# `dataset` was built from df_clusters[1] (keep in sync with the dataset cell), so the
# `title` indices refer to positions in that list, not in the full df_titles_clusters.
# NOTE(review): assumes GapDurationDeltaPitchDataset yields the integer position of each
# source dataframe as `title` — confirm.
train_records = df_clusters[1]
X_titles = [
    (train_records[arg.item()][1], train_records[arg.item()][2], train_records[arg.item()][3])
    for arg in title
]
start_pitch = 60  # presumably the absolute pitch every reconstructed sample starts from
X_dfs = batch_gap_duration_pitch_to_df(X, start_pitch=start_pitch)
Y_dfs = batch_gap_duration_pitch_to_df(Y, start_pitch=start_pitch)

In [None]:
sample_idx = 0  # index within the evaluated batch to inspect and play back
selected_title = X_titles[sample_idx]
print(selected_title)

In [None]:
# Convert the selected real (input) and generated samples to MIDI objects for playback.
input_midi, output_midi = (df_to_midi(sample_dfs[sample_idx]) for sample_dfs in (X_dfs, Y_dfs))

In [None]:
# Row-aligned view of the input sample's timing and pitch next to the generated pitch.
# Explicit suffixes replace pandas' ambiguous default `_x` / `_y` for the shared 'Pitch' column.
pd.merge(
    X_dfs[sample_idx][['Start', 'End', 'Pitch']],
    Y_dfs[sample_idx][['Pitch']],
    left_index=True,
    right_index=True,
    suffixes=('_input', '_generated'),
)

In [None]:
# Render the real (input) sample to audio with the piano soundfont.
fs = 44100
soundfont_path = './data/soundfonts/Steinway_Grand_Piano_1.2.sf2'
audio_data = input_midi.fluidsynth(fs=fs, sf2_path=soundfont_path)
ipd.Audio(audio_data, rate=fs)

In [None]:
# Render the generated sample to audio with the same piano soundfont.
fs = 44100
soundfont_path = './data/soundfonts/Steinway_Grand_Piano_1.2.sf2'
audio_data = output_midi.fluidsynth(fs=fs, sf2_path=soundfont_path)
ipd.Audio(audio_data, rate=fs)

### Evaluate the percentage of notes within each major key

In [None]:
# Pitch class (0 = C ... 11 = B) -> note name, and the notes of the major scale rooted at each pitch class.
note_names = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
key_to_note = dict(enumerate(note_names))
key_notes = [get_notes_from_major_scale(root) for root in range(12)]

In [None]:
# For each of the 12 major scales (key_notes[i] is the scale rooted at pitch class i),
# record the fraction of each sample's notes lying in that scale — for real (X) and generated (Y) samples.
def key_scale_percentages(sample_dfs, scales):
    """Return {scale index: [fraction of each sample's pitches contained in that scale]}.

    Args:
        sample_dfs: iterable of DataFrames with a 'Pitch' column.
        scales: sequence of note collections, one per scale (e.g. `key_notes`).

    Samples with no notes are skipped instead of raising ZeroDivisionError.
    """
    percentages = {i: [] for i in range(len(scales))}
    for df in sample_dfs:
        pitches = df['Pitch'].values
        if len(pitches) == 0:
            continue
        for i, scale in enumerate(scales):
            in_scale = sum(1 for note in pitches if note in scale)
            percentages[i].append(in_scale / len(pitches))
    return percentages


X_key_percentages = key_scale_percentages(X_dfs, key_notes)
Y_key_percentages = key_scale_percentages(Y_dfs, key_notes)

In [None]:
# Histograms of the per-sample fraction of INPUT (real) notes in each major scale.
fig, axes = plt.subplots(4, 3, figsize=(15, 10))
fig.suptitle('Input samples: fraction of notes in each major scale')
# Reconstruction used `start_pitch`, so report its pitch class as the nominal key.
print(f'Key: {key_to_note[start_pitch % 12]}')
for i in range(len(key_notes)):
    print(f'Mean percentage of notes in key {key_to_note[i]}: {np.mean(X_key_percentages[i])}')
    axis = axes[i // 3, i % 3]
    # Fixed [0, 1] range keeps the bins comparable across panels and between figures.
    axis.hist(X_key_percentages[i], bins=100, range=(0, 1))
    axis.set_title(f'{key_to_note[i]} major scale')
    axis.set_xlabel('Fraction of notes in scale')
    axis.set_ylabel('Samples')
fig.tight_layout()

In [None]:
# Histograms of the per-sample fraction of GENERATED notes in each major scale.
fig, axes = plt.subplots(4, 3, figsize=(15, 10))
fig.suptitle('Generated samples: fraction of notes in each major scale')
# Reconstruction used `start_pitch`, so report its pitch class as the nominal key.
print(f'Key: {key_to_note[start_pitch % 12]}')
for i in range(len(key_notes)):
    print(f'Mean percentage of notes in key {key_to_note[i]}: {np.mean(Y_key_percentages[i])}')
    axis = axes[i // 3, i % 3]
    # Fixed [0, 1] range keeps the bins comparable across panels and between figures.
    axis.hist(Y_key_percentages[i], bins=100, range=(0, 1))
    axis.set_title(f'{key_to_note[i]} major scale')
    axis.set_xlabel('Fraction of notes in scale')
    axis.set_ylabel('Samples')
fig.tight_layout()