**This notebook is important :)**

Takes a simple (thresholded) argmax over the Melody2 CNN outputs to produce Melody2 estimates.
Scores the estimates against the ground-truth annotations.
The bottom of the notebook shows a qualitative example, which is one of the figures in the paper.

In [None]:
import motif
import motif.plot
import numpy as np
import mir_eval
import os
import medleydb as mdb
import seaborn
import csv
import glob
import json
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Load the track-id splits; later cells read the 'validate' and 'test' keys
# and pass each id to mdb.MultiTrack.
with open("../outputs/data_splits.json") as splits_file:
    dat_dict = json.load(splits_file)

In [None]:
cd ../deepsalience/

In [None]:
import compute_training_data as C
import evaluate

In [None]:
# Sweep candidate voicing thresholds on the validation split and record the
# per-track mir_eval Overall Accuracy for each threshold.
# np.round strips float noise from arange (e.g. 0.30000000000000004) so the
# thresholds print and serve as dict keys cleanly.
thresh_vals = np.round(np.arange(0, 1, 0.1), 1)
mel_accuracy = {v: [] for v in thresh_vals}

for trackid in dat_dict['validate']:

    mtrack = mdb.MultiTrack(trackid)
    # Only MedleyDB V1 tracks are used for threshold selection.
    if mtrack.dataset_version != 'V1':
        continue

    pred_path = "../experiment_output_submission/Model_11b_mel2_outputs/{}_prediction.npy".format(trackid)
    # Skip (instead of crashing on) tracks that have no saved model output or
    # no melody2 annotation -- the same guard the test-set loop applies.
    if not os.path.exists(pred_path) or not os.path.exists(mtrack.melody2_fpath):
        print("skipping {}: missing prediction or melody2 annotation".format(trackid))
        continue

    print(trackid)
    pitch_activations = np.load(pred_path)
    # Annotation rows are (time, freq); transpose to get the two columns.
    mel2 = np.array(mtrack.melody2_annotation).T
    ref_times, ref_freqs = mel2[0], mel2[1]

    for thresh in thresh_vals:
        est_times, est_freqs = evaluate.pitch_activations_to_singlef0(pitch_activations, thresh)
        mel_scores = mir_eval.melody.evaluate(ref_times, ref_freqs, est_times, est_freqs)
        mel_accuracy[thresh].append(mel_scores['Overall Accuracy'])

In [None]:
# Average each threshold's per-track Overall Accuracy over the validation
# tracks, then keep the threshold with the highest mean (np.argmax returns the
# first index on ties).
accuracy_vals = [np.mean(mel_accuracy[t]) for t in thresh_vals]
best_thresh_idx = np.argmax(accuracy_vals)
best_thresh = thresh_vals[best_thresh_idx]

print(
    "Best threshold is {} with an OA of {}".format(
        best_thresh,
        accuracy_vals[best_thresh_idx],
    )
)

In [None]:
# Hard-coded override: this replaces the best_thresh computed by the
# validation sweep above, presumably with the value a previous sweep selected,
# so the test cells can run without re-running validation. Remove this line to
# use the freshly computed threshold instead.
best_thresh = 0.3

In [None]:
def save_mel_prediction(est_times, est_freqs, fpath):
    """Write a melody estimate to `fpath` as two-column CSV.

    Each row is ``time,frequency``, taken pairwise from `est_times` and
    `est_freqs`; if the sequences differ in length, ``zip`` stops at the
    shorter one. No header row is written.

    Parameters
    ----------
    est_times : iterable of float
        Frame timestamps.
    est_freqs : iterable of float
        Frequency estimate for each timestamp.
    fpath : str
        Output path; an existing file is overwritten.
    """
    # newline='' is required by the csv module: without it, the writer's
    # '\r\n' row terminator becomes '\r\r\n' on platforms that translate
    # newlines (e.g. Windows), producing blank lines between rows.
    with open(fpath, 'w', newline='') as fhandle:
        writer = csv.writer(fhandle, delimiter=',')
        writer.writerows(zip(est_times, est_freqs))


In [None]:
# Evaluate on the test split with the selected threshold: save each track's
# (time, frequency) estimate, plot it over the melody2 reference, and collect
# the mir_eval melody scores into a DataFrame.
all_mel_scores = []
for trackid in dat_dict['test']:
    print(trackid)
    mtrack = mdb.MultiTrack(trackid)

    pred_path = "../experiment_output_submission/Model_11b_mel2_outputs/{}_prediction.npy".format(trackid)
    if not os.path.exists(pred_path) or not os.path.exists(mtrack.melody2_fpath):
        # Tracks without a saved prediction or melody2 annotation are left
        # out, so the resulting table covers only part of the test split.
        print("skipping {}: missing prediction or melody2 annotation".format(trackid))
        continue

    pitch_activations = np.load(pred_path)
    # Decode with best_thresh. The bare name `thresh` would be the leftover
    # loop variable from the validation sweep (its final value, about 0.9).
    est_times, est_freqs = evaluate.pitch_activations_to_singlef0(pitch_activations, best_thresh)
    save_mel_prediction(
        est_times, est_freqs,
        "../experiment_output_submission/Model_11b_mel2_outputs/{}_mel2_prediction.txt".format(trackid))
    # Annotation rows are (time, freq); transpose to get the two columns.
    mel2 = np.array(mtrack.melody2_annotation).T
    ref_times, ref_freqs = mel2[0], mel2[1]

    # Visual sanity check: reference in black, estimate in red.
    plt.figure(figsize=(15, 7))
    plt.title(trackid)
    plt.plot(ref_times, ref_freqs, '.k', markersize=8)
    plt.plot(est_times, est_freqs, '.r', markersize=3)
    plt.show()

    mel_scores = mir_eval.melody.evaluate(ref_times, ref_freqs, est_times, est_freqs)
    all_mel_scores.append(mel_scores)

mel_scores_df_partial = pd.DataFrame(all_mel_scores)

In [None]:
# Summary statistics of the per-track mir_eval melody metrics. "partial"
# because test tracks lacking a prediction or melody2 annotation were skipped.
mel_scores_df_partial.describe()

In [None]:
# Persist the per-track test scores (one row per evaluated track).
mel_scores_df_partial.to_csv("../outputs/CNNmel2_argmax_scores.csv")

In [None]:
# Qualitative example for the paper: the left panel overlays the thresholded
# argmax melody estimate on the melody2 reference; the right panel shows the
# raw model output for the same excerpt.
import seaborn as sns  # the top-level import binds `seaborn`, not `sns`

sns.set(font_scale=1.5)
sns.set_style('white')

trackid = 'MusicDelta_SwingJazz'
mtrack = mdb.MultiTrack(trackid)

pred_path = "../experiment_output_submission/Model_11b_mel2_outputs/{}_prediction.npy".format(trackid)
Y = np.load(pred_path)

# `get_mel_prediction` is not defined in this notebook; decode the loaded
# activations with the same helper and threshold the evaluation cells use.
est_times, est_freqs = evaluate.pitch_activations_to_singlef0(Y, best_thresh)
mel2 = np.array(mtrack.melody2_annotation).T
ref_times, ref_freqs = mel2[0], mel2[1]

# Set any negative frequencies (mir_eval's convention for unvoiced frames) to
# 0 so they fall below the log-scale y-limits. This must be an assignment; a
# `==` comparison here would build a boolean array and discard it.
est_freqs[est_freqs < 0] = 0

plt.figure(figsize=(10, 5))

# Left: reference (black) vs. estimate (red) on a log2 frequency axis.
plt.subplot(1, 2, 1)
# NOTE: matplotlib 3.3 deprecated `basey` in favor of `base` (removed in 3.5);
# switch to base=2 when running on a newer matplotlib.
plt.semilogy(ref_times, ref_freqs, '.k', basey=2, markersize=8)
plt.semilogy(est_times, est_freqs, '.', color="#CF6766", basey=2, markersize=4)

# 32.7 Hz * 2**6 ~= 2093 Hz, so every ytick from 128 to 2048 Hz is visible
# (32.7 * 6 ~= 196 Hz would cut the axis off just above the first tick).
plt.ylim([2**7, 32.7 * 2**6])
plt.xlim([0, 25])
plt.yticks([2**i for i in range(7, 12)], [2**i for i in range(7, 12)])
plt.ylabel('Frequency (Hz)')
plt.xlabel('Time (sec)')

# Right: raw model output, rows from index 120 up and the first 2150 frames;
# the x ticks label those frames as 0-25 s.
# NOTE(review): assuming 60 bins/octave starting at 32.7 Hz, row 120 is
# ~131 Hz, matching the left panel's lower limit -- confirm against the model.
plt.subplot(1, 2, 2)
plt.imshow(Y[120:, :2150], origin='lower', cmap='RdBu_r', vmin=-1, vmax=1)
plt.yticks([])
plt.xticks(np.linspace(0, 2150, 6), [0, 5, 10, 15, 20, 25])
plt.xlabel('Time (sec)')
plt.axis('auto')

plt.tight_layout()
plt.savefig('../paper-figs/mel_qualatative.pdf', format='pdf', bbox_inches='tight')