### Test whether audio files are valid by listening to the beats

In [None]:
# Load one recording at its native sampling rate (sr=None) for a listening test
a, fs = librosa.load("data/wav/10CD1_-_The_Beatles/CD1_-_06_-_The_Continuing_Story_of_Bungalow_Bill.wav", sr=None)

In [None]:
# Load the Isophonics annotation for the same track; the first returned
# array (beats) is used below as the click times
beats, positions = mir_eval.io.load_time_series('data/isophonics/10CD1_-_The_Beatles/CD1_-_06_-_The_Continuing_Story_of_Bungalow_Bill.txt')

In [None]:
# Synthesize a click at every annotated beat, at the audio's sampling rate
# and with the same number of samples as the recording
clicks = mir_eval.sonify.clicks(beats, fs=fs, length=a.shape[0])

In [None]:
# Render an audio player with the recording and the click track so the
# annotation can be checked by ear
IPython.display.Audio([a, clicks], rate=fs)

### Create whoosh index for Isophonics .wav files

In [None]:
# One metadata entry is collected per Isophonics .wav file
beatles_list = []
for n, wav_file in enumerate(glob.glob(os.path.join('data', 'wav', '*', '*.wav'))):
    album_dir, wav_name = os.path.split(wav_file)
    # File name without its extension
    filename = os.path.splitext(wav_name)[0]
    # Drop the leading track number (e.g. 01_-_) and turn underscores
    # back into spaces to recover the title
    title = re.split('[0-9][0-9]_-_', filename)[-1].replace('_', ' ')
    # Path prefix is <album directory>/<file name without extension>
    path = os.path.join(os.path.basename(album_dir), filename)
    entry = {'id': unicode(n), 'artist': u"The Beatles",
             'title': unicode(title), 'path': unicode(path)}
    beatles_list.append(entry)
# Write all entries to the whoosh index under data/index
whoosh_search.create_index(
    os.path.join('data', 'index'), beatles_list)

In [None]:
# Titles which are spelled differently in the two collections, mapped to
# the spelling to search for
FIXES = {
    "When I'm 64": "When I'm Sixty-Four",
    "I Want You (She's So Heavy)": "I Want You",
    "Blackbird": "Black Bird",
}
def fix(s):
    '''Return the replacement title for ``s`` if one is listed in FIXES,
    otherwise return ``s`` unchanged.'''
    return FIXES.get(s, s)

In [None]:
# Location of the clean_midi dataset
CLEAN_MIDI_PATH = '/home/craffel/projects/midi-dataset/data/clean_midi/'
# Read every metadata entry out of the clean_midi whoosh index
index = whoosh_search.get_whoosh_index(os.path.join(CLEAN_MIDI_PATH, 'index'))
with index.searcher() as searcher:
    midi_list = list(searcher.documents())
# Open the Beatles index so each MIDI entry can be looked up in it
beatles_index = whoosh_search.get_whoosh_index(
    os.path.join('data', 'index'))
with beatles_index.searcher() as searcher:
    # All Beatles entries; used to map a matched id back to its path
    beatles_list = list(searcher.documents())
    for entry in midi_list:
        # Anything which is not a Beatles track is ignored
        if entry['artist'] != 'The Beatles':
            continue
        # Look this track up in the Beatles index
        matches = whoosh_search.search(
            searcher, beatles_index.schema, entry['artist'],
            fix(entry['title']), 9)
        # No hit means we do not have audio for this track
        if len(matches) == 0:
            continue
        # Base path of the best match
        path = [e['path'] for e in beatles_list
                if e['id'] == matches[0][0]][0]
        # The first unused numeric suffix tells us how many versions of
        # this song have been copied already
        n = 0
        while os.path.exists(
                os.path.join('data/mid', path + '.mid{}'.format(n))):
            n += 1
        # Where the MIDI file lives in the clean_midi dataset
        orig_path = os.path.join(
            CLEAN_MIDI_PATH, 'mid', entry['path'] + '.mid')
        # Only copy files which pretty_midi is able to parse
        try:
            pretty_midi.PrettyMIDI(orig_path)
        except Exception as e:
            print("{}".format(e))
            continue
        # Destination, including the version suffix
        output_path = os.path.join('data', 'mid', path + '.mid{}'.format(n))
        output_dir = os.path.split(output_path)[0]
        # Make the album directory on first use
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        shutil.copy(orig_path, output_path)

### Align all pairs

In [None]:
# The frame resolution used in align_text_matches is 1024 samples.
# At 22.05 kHz this corresponds to about 46 ms, which is around the
# same temporal tolerance as beat tracking eval.  So, divide by 2
# to make the temporal resolution finer.
align_text_matches.feature_extraction.AUDIO_HOP = 512
align_text_matches.feature_extraction.MIDI_HOP = 256
# The audio hop also needs to be changed on the feature_extraction module
# that this notebook uses directly.
# NOTE(review): MIDI_HOP is not overridden on feature_extraction here --
# confirm whether that is intentional.
feature_extraction.AUDIO_HOP = 512

In [None]:
# Both output directories must exist before the alignment jobs run
for output_subdir in ('mid_aligned', 'diagnostics'):
    if not os.path.exists(os.path.join('data', output_subdir)):
        os.makedirs(os.path.join('data', output_subdir))
# Each entry holds the six file names passed to align_one_file
pairs = []
for midi_filename in glob.glob('data/mid/*/*.mid*'):
    path, midi_filename_only = os.path.split(midi_filename)
    # <album directory>/<file name>, shared by all derived paths
    midi_path = os.path.join(os.path.split(path)[1], midi_filename_only)
    # Audio files share the MIDI file's name, minus its extension
    audio_base = os.path.join(
        'data', 'wav', os.path.splitext(midi_path)[0])
    audio_filename = audio_base + '.wav'
    audio_features_filename = audio_base + '.h5'
    # Feature/diagnostics files swap '.mid' for '.h5' in the MIDI path
    h5_path = midi_path.replace('.mid', '.h5')
    midi_features_filename = os.path.join('data', 'mid', h5_path)
    output_midi_filename = os.path.join('data', 'mid_aligned', midi_path)
    output_diagnostics_filename = os.path.join(
        'data', 'diagnostics', h5_path)
    pairs.append((audio_filename, midi_filename, audio_features_filename,
                  midi_features_filename, output_midi_filename,
                  output_diagnostics_filename))

# Align every pair, using 10 parallel jobs
_ = joblib.Parallel(n_jobs=10, verbose=10)(
    joblib.delayed(align_text_matches.align_one_file)(*args)
    for args in pairs)

### Extract ground truth

In [None]:
def interpolate_times(times, old_timebase, new_timebase, labels=None,
                      shift_start=False):
    '''
    Linearly interpolate a set of times (and optionally labels) to a new
    timebase.  All returned times will fall within the range of
    ``new_timebase``, and only times which fall within ``old_timebase`` will be
    interpolated.

    Parameters
    ----------
    - times : np.ndarray or list
        Times of some events to be interpolated.
    - old_timebase : np.ndarray or list
        The original timebase of ``times``.
    - new_timebase : np.ndarray or list
        The new timebase to resample ``times`` to.
    - labels : list or NoneType
        Labels of the events in ``times``; if ``None``, no interpolated labels
        will be generated.
    - shift_start : bool
        Whether to create an additional interpolated event with time
        ``new_timebase[0]`` when any entry of ``times`` is before
        ``old_timebase[0]`` and ``new_timebase[0]``, and no entry of ``times``
        is exactly ``old_timebase[0]``.

    Returns
    -------
    - interpolated_times : np.ndarray
        Interpolated times.
    - interpolated_labels : list
        Interpolated labels.  Only returned when ``labels`` is not ``None``.
    '''
    # Coerce to arrays so the element-wise comparisons below also work when
    # plain lists are passed in
    times = np.asarray(times)
    old_timebase = np.asarray(old_timebase)
    new_timebase = np.asarray(new_timebase)
    # Mark the times which fall inside the range of the original timebase;
    # the same mask selects both the times and (optionally) the labels
    in_range = (times >= old_timebase[0]) & (times <= old_timebase[-1])
    if labels is not None:
        valid_labels = [label for (keep, label) in zip(in_range, labels)
                        if keep]
    # Linearly interpolate the retained times to the new timebase
    interped_times = np.interp(times[in_range], old_timebase, new_timebase)
    # If we have been told to add a time when an event falls before the
    # timebases...
    if (shift_start and np.any(times < new_timebase[0])
            and np.any(times < old_timebase[0])
            and not np.any(times == old_timebase[0])):
        # Add an event at the beginning of the new timebase
        interped_times = np.append(new_timebase[0], interped_times)
        # If labels were provided, the added event takes the label of the
        # event just before the first time that is not before the old
        # timebase (argmin finds the first False in the comparison)
        if labels is not None:
            first_label = np.argmin(times < old_timebase[0]) - 1
            valid_labels = [labels[first_label]] + valid_labels
    # When labels were not provided, just return interpolated times
    if labels is None:
        return interped_times
    # When labels were provided, return interpolated times and labels
    return interped_times, valid_labels

In [None]:
# Use the same (finer) audio hop as the alignment experiments
feature_extraction.AUDIO_HOP = 512
def get_error():
    '''
    Measure the timing offset of the CQT for one random impulse.

    A signal of 1 to 4 seconds of silence is created with a single unit
    impulse at a random sample, and its constant-Q spectrogram is computed
    with the settings in ``feature_extraction``.

    Returns
    -------
    - error : float
        True time of the impulse minus the time of the CQT frame with the
        largest total energy, in seconds.
    '''
    fs = feature_extraction.AUDIO_FS
    s = np.zeros((fs*np.random.randint(1, 5)))
    place = np.random.randint(0, s.size)
    s[place] = 1
    gram = librosa.cqt(s, sr=fs, hop_length=feature_extraction.AUDIO_HOP,
                       fmin=librosa.midi_to_hz(feature_extraction.NOTE_START),
                       n_bins=feature_extraction.N_NOTES).T
    peak_frame = np.argmax(gram.sum(axis=1))
    # atleast_1d makes the indexing work whether frames_to_time returns a
    # scalar or a one-element array for a scalar frame index
    peak_time = np.atleast_1d(librosa.frames_to_time(
        peak_frame, hop_length=feature_extraction.AUDIO_HOP, sr=fs))[0]
    # Convert the impulse position with the actual sampling rate rather
    # than a hard-coded 22050
    return place/float(fs) - peak_time
# Repeat the measurement 1000 times and summarize the offsets
errors = joblib.Parallel(n_jobs=10, verbose=0)(
    joblib.delayed(get_error)() for _ in range(1000))
_ = plt.hist(errors, bins=20)
print('{} {} {}'.format(
    np.mean(errors), np.median(errors), np.max(errors) - np.min(errors)))

In [None]:
# Root directory for the extracted beat files
extracted_root = os.path.join('data', 'extracted')
if not os.path.exists(extracted_root):
    os.makedirs(extracted_root)

def process_one_file(diagnostics_file):
    '''
    Map a MIDI file's beat times onto its aligned audio and save them.

    Parameters
    ----------
    - diagnostics_file : str
        Path to an alignment diagnostics .h5 file.  The beats are written
        with ``np.savetxt`` to the same path with 'diagnostics' replaced by
        'extracted' and '.h5' replaced by '.txt'.
    '''
    diagnostics = deepdish.io.load(diagnostics_file)
    # Features which were extracted for the MIDI file and for the audio
    midi_features = deepdish.io.load(diagnostics['midi_features_filename'])
    audio_features = deepdish.io.load(
        diagnostics['audio_features_filename'])
    # The original (unaligned) MIDI file
    midi_object = pretty_midi.PrettyMIDI(str(diagnostics['midi_filename']))
    # Frame times along the alignment path, for MIDI and for audio; these
    # are the old and new timebases for the interpolation
    all_midi_times = feature_extraction.frame_times(midi_features['gram'])
    midi_frame_times = all_midi_times[diagnostics['aligned_midi_indices']]
    all_audio_times = feature_extraction.frame_times(audio_features['gram'])
    audio_frame_times = all_audio_times[diagnostics['aligned_audio_indices']]
    adjusted_beats = interpolate_times(
        midi_object.get_beats(), midi_frame_times, audio_frame_times)
    output_file = diagnostics_file.replace(
        'diagnostics', 'extracted').replace('.h5', '.txt')
    output_dir = os.path.dirname(output_file)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    np.savetxt(output_file, adjusted_beats)

# Process every diagnostics file, using 10 parallel jobs
_ = joblib.Parallel(n_jobs=10, verbose=10)(
    joblib.delayed(process_one_file)(diagnostics_file)
    for diagnostics_file in glob.glob(
        os.path.join('data', 'diagnostics', '*', '*.h5*')))

### Evaluate extracted ground truth

In [None]:
# Count the Isophonics beat annotation files (one level of album directories)
len(glob.glob('data/isophonics/*/*.txt'))

In [None]:
# Test a single example: score one extracted beat file (the '.txt0' version)
# against the Isophonics annotation of the same track
ref_beats, ref_labels = mir_eval.io.load_time_series('data/isophonics/01_-_Please_Please_Me/01_-_I_Saw_Her_Standing_There.txt')
est_beats = mir_eval.io.load_events("data/extracted/01_-_Please_Please_Me/01_-_I_Saw_Her_Standing_There.txt0")
# Compute all of mir_eval's beat metrics for this pair
mir_eval.beat.evaluate(ref_beats, est_beats)

In [None]:
def _subdivide_beats(beats, factor):
    '''Linearly interpolate ``factor - 1`` extra beats between each pair of
    consecutive beats, keeping the first and last beat unchanged.'''
    n_beats = beats.shape[0]
    # Nothing to interpolate (np.interp rejects empty sample points)
    if n_beats == 0:
        return beats
    # linspace ends exactly on the last beat index; a float-step arange can
    # run past it, which makes np.interp repeat the final beat
    fine_indices = np.linspace(0, n_beats - 1, factor*(n_beats - 1) + 1)
    return np.interp(fine_indices, np.arange(n_beats), beats)


def get_reference_beat_variations(reference_beats):
    '''
    Create metrical-level variations of a beat annotation.

    Parameters
    ----------
    - reference_beats : np.ndarray
        Annotated beat times.

    Returns
    -------
    - variations : tuple of np.ndarray
        In order: the original beats, the off-beats (midpoints between
        beats), double tempo, half tempo starting on the first beat, half
        tempo starting on the second beat, triple tempo, and every third
        beat.
    '''
    reference_beats = np.asarray(reference_beats)
    # Create annotations at twice and three times the metric level
    double_reference_beats = _subdivide_beats(reference_beats, 2)
    triple_reference_beats = _subdivide_beats(reference_beats, 3)
    return (reference_beats,
            double_reference_beats[1::2],
            double_reference_beats,
            reference_beats[::2],
            reference_beats[1::2],
            triple_reference_beats,
            reference_beats[::3])

In [None]:
def get_scores(estimated_beats_file):
    '''
    Score one extracted beat file against its Isophonics annotation.

    Parameters
    ----------
    - estimated_beats_file : str
        Path to a beat file under the 'extracted' directory.

    Returns
    -------
    - confidence : object or NoneType
        The 'score' entry of the matching diagnostics file.
    - scores : object or NoneType
        Result of ``mir_eval.beat.evaluate``.  Both values are ``None``
        when the beat file contains no events.
    '''
    est_beats = mir_eval.io.load_events(estimated_beats_file)
    # Nothing to evaluate when no beats were extracted
    if est_beats.size == 0:
        return None, None
    # The reference has the same relative path, with a plain .txt extension
    reference_stem = os.path.splitext(
        estimated_beats_file.replace('extracted', 'isophonics'))[0]
    ground_truth_beats_file = reference_stem + '.txt'
    ref_beats = mir_eval.io.load_labeled_events(ground_truth_beats_file)[0]
    # The diagnostics file holds the alignment confidence score
    diagnostics_file = estimated_beats_file.replace(
        'extracted', 'diagnostics').replace('.txt', '.h5')
    d = deepdish.io.load(diagnostics_file)
    confidence = d['score']
    return confidence, mir_eval.beat.evaluate(ref_beats, est_beats)
# Score every extracted beat file, using 10 parallel jobs
confidence_and_evaluation_scores = joblib.Parallel(n_jobs=10, verbose=10)(
    joblib.delayed(get_scores)(estimated_beats_file)
    for estimated_beats_file in glob.glob(
        os.path.join('data', 'extracted', '*', '*.txt*')))
# Split the pairs into two lists, dropping files which had no beats
confidence_scores = [
    e[0] for e in confidence_and_evaluation_scores if e[0] is not None]
evaluation_scores = [
    e[1] for e in confidence_and_evaluation_scores if e[1] is not None]

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
sns.set_style('darkgrid')
#import mpld3
#import mpld3.plugins
#mpld3.enable_notebook()

In [None]:
# Scatter each beat metric (except 'Goto') against the confidence score, with
# the mean of dbn_scores for that metric drawn as a dotted line
fig, ax = plt.subplots(3, 3, sharex="col", sharey="row", figsize=(12, 12))
for n, metric in enumerate([key for key in evaluation_scores[0].keys() if key != 'Goto']):
    # Floor division keeps the row index an integer on Python 3 as well
    axis = ax[n // 3, n % 3]
    points = axis.scatter(
        confidence_scores,
        np.array([s[metric] for s in evaluation_scores]),
        c='#3778bf',
        alpha=.3)
    dbn_mean = np.mean([s[metric] for s in dbn_scores])
    # The line extends past both x limits so it spans the whole axis
    axis.plot([-1, 2.], [dbn_mean, dbn_mean], 'k:', lw=4)
    axis.set_xlim([-.05, 1.05])
    axis.set_ylim([-.05, 1.05])
    axis.set_title(metric)
#mpld3.plugins.connect(fig, mpld3.plugins.LinkedBrush(points))

In [None]:
# Scatter three selected metrics against the confidence score, with the
# mean of dbn_scores drawn as a dotted line, and save the figure
fig, ax = plt.subplots(1, 3, sharex="col", sharey="row", figsize=(12, 3))
for n, metric in enumerate(['F-measure', 'Any Metric Level Total', 'Information gain']):
    axis = ax[n]
    points = axis.scatter(
        confidence_scores,
        np.array([s[metric] for s in evaluation_scores]),
        c='#3778bf',
        alpha=.3)
    dbn_mean = np.mean([s[metric] for s in dbn_scores])
    # The line extends past both x limits so it spans the whole axis
    axis.plot([-1, 2.], [dbn_mean, dbn_mean], 'k:', lw=4)
    axis.set_xlim([-.05, 1.05])
    axis.set_ylim([-.05, 1.05])
    axis.set_title(metric)
    # Only the left panel gets a y label and only the middle one an x label
    if n == 0:
        axis.set_ylabel('Metric score')
    elif n == 1:
        axis.set_xlabel('Confidence score')
plt.savefig('beat_scores.pdf', bbox_inches='tight', pad_inches=0.1)

The figure saved by the previous cell can be viewed [here](beat_scores.pdf).

### Compare to beat tracker

To produce beat annotations using [madmom](https://github.com/CPJKU/madmom)'s "DBNBeatTracker" (a general-purpose algorithm, state of the art in 2014 and not trained on the Beatles), run the following from the data folder:

`mkdir dbn_annotations; DBNBeatTracker batch wav/*/*.wav -o dbn_annotations/`

In [None]:
# DBNBeatTracker flattens directory structure and names things ".beats.txt", so let's fix that
for f in glob.glob('data/isophonics/*/*.txt'):
    # Target path mirrors the Isophonics album/file layout
    new_filename = f.replace('isophonics', 'dbn_annotations')
    # Flat file which the tracker wrote for this track
    old_filename = os.path.join(
        'data', 'dbn_annotations', os.path.splitext(os.path.split(f)[1])[0] + '.beats.txt')
    # Skip tracks with nothing to move (already moved by an earlier run of
    # this cell, or no tracker output) instead of crashing part-way through
    if not os.path.exists(old_filename):
        continue
    if not os.path.exists(os.path.split(new_filename)[0]):
        os.makedirs(os.path.split(new_filename)[0])
    shutil.move(old_filename, new_filename)

In [None]:
# Extract beats with librosa
for f in glob.glob('data/wav/*/*.wav'):
    # Keep the sampling rate so frame indices are converted consistently
    audio_data, fs = librosa.load(f)
    beat_frames = librosa.beat.beat_track(audio_data, sr=fs)[1]
    beats = librosa.frames_to_time(beat_frames, sr=fs)
    # Build data/librosa_annotations/<album>/<track>.txt from path components
    # so that a "wav" substring inside an album or track name is left alone
    album_dir, wav_name = os.path.split(f)
    output_beats_file = os.path.join(
        'data', 'librosa_annotations', os.path.basename(album_dir),
        os.path.splitext(wav_name)[0] + '.txt')
    if not os.path.exists(os.path.split(output_beats_file)[0]):
        os.makedirs(os.path.split(output_beats_file)[0])
    np.savetxt(output_beats_file, beats)

In [None]:
# Directory under data/ holding the beat estimates to score against the
# Isophonics annotations in the next cell
compare_dir = 'dbn_annotations'

In [None]:
# Beat metrics for every baseline estimate which has a reference annotation
dbn_scores = []
estimate_pattern = os.path.join('data', compare_dir, '*', '*.txt')
for estimated_beats_file in glob.glob(estimate_pattern):
    # The reference annotation has the same relative path under isophonics
    ground_truth_beats_file = estimated_beats_file.replace(
        compare_dir, 'isophonics')
    if os.path.exists(ground_truth_beats_file):
        ref_beats, _ = mir_eval.io.load_labeled_events(
            ground_truth_beats_file)
        est_beats = mir_eval.io.load_events(estimated_beats_file)
        dbn_scores.append(mir_eval.beat.evaluate(ref_beats, est_beats))

In [None]:
# For each confidence threshold, plot the median and interquartile range of
# each metric over the files whose confidence exceeds the threshold
thresholds = np.linspace(0, .95, 95, endpoint=False)
# Explicit array so the comparison with a threshold is element-wise
confidences = np.asarray(confidence_scores)
fig, ax = plt.subplots(3, 3, sharex="col", figsize=(12, 12))
for n, metric in enumerate([key for key in evaluation_scores[0].keys() if key != 'Goto']):
    metric_scores = np.array([s[metric] for s in evaluation_scores])
    # 25th, 50th and 75th percentile per threshold; rows stay NaN when no
    # file exceeds the threshold
    quartiles = np.full((len(thresholds), 3), np.nan)
    for row, t in enumerate(thresholds):
        selected = metric_scores[confidences > t]
        if selected.size > 0:
            quartiles[row] = np.percentile(selected, [25, 50, 75])
    p25_acc, median_acc, p75_acc = quartiles.T
    # Floor division keeps the row index an integer on Python 3 as well
    axis = ax[n // 3, n % 3]
    axis.plot(thresholds, median_acc, c='#3778bf', lw=2)
    axis.fill_between(thresholds, p25_acc, p75_acc, facecolor='#35ad6b', alpha=.2)
    axis.set_title(metric)
    dbn_median = np.median([s[metric] for s in dbn_scores])
    axis.plot([0, 1.], [dbn_median, dbn_median], 'k:', lw=4)
    if metric == 'Information gain':
        axis.set_ylim(0, .65)
    else:
        axis.set_ylim(.5, 1)
# Set common labels
fig.text(0.5, 0.08, 'Confidence score', ha='center')
fig.text(0.08, 0.5, 'Metric score', va='center', rotation='vertical')

In [None]:
plt.violinplot?

In [None]:
# For each metric, bin the files by metric score (bins of width 0.05) and
# draw the distribution of confidence scores within each bin as a violin
spacing = np.linspace(0, 1, 21)
fig, ax = plt.subplots(3, 3, sharex=True, figsize=(12, 12))
for n, metric in enumerate([key for key in evaluation_scores[0].keys() if key != 'Goto']):
    binned_confidences = []
    my_spacing = []
    for start, end in zip(spacing[:-1], spacing[1:]):
        # The last bin is closed on the right so that a perfect score of
        # exactly 1.0 is not dropped from every bin
        is_last_bin = end == spacing[-1]
        scores_in_range = [
            confidence
            for metrics, confidence in zip(evaluation_scores, confidence_scores)
            if metrics[metric] >= start
            and (metrics[metric] < end
                 or (is_last_bin and metrics[metric] == end))]
        # Bins with fewer than two files are not drawn
        if len(scores_in_range) > 1:
            binned_confidences.append(scores_in_range)
            my_spacing.append(start)
    # Floor division keeps the row index an integer on Python 3 as well
    plt.sca(ax[n // 3, n % 3])
    # violinplot cannot be called without any data
    if binned_confidences:
        plt.violinplot(binned_confidences, my_spacing,
                       widths=np.diff(spacing)[0], showextrema=False)
    plt.title(metric)

In [None]:
# Horizontal histogram of each metric, with the median of dbn_scores for
# that metric drawn as a dotted line
fig, ax = plt.subplots(3, 3, figsize=(12, 12))
for n, metric in enumerate([key for key in evaluation_scores[0].keys() if key != 'Goto']):
    metric_scores = np.array([s[metric] for s in evaluation_scores])
    # Floor division keeps the row index an integer on Python 3 as well
    plt.sca(ax[n // 3, n % 3])
    # Bin counts get their own name instead of rebinding the loop index n
    counts, _, _ = plt.hist(metric_scores, np.linspace(0, 1, 21), facecolor='#35ad6b', alpha=.5, orientation='horizontal')
    dbn_median = np.median([s[metric] for s in dbn_scores])
    # The line is drawn past the x limit set below so it spans the axis
    plt.plot([0, 2*max(counts)], [dbn_median, dbn_median], 'k:', lw=2)
    plt.title(metric)
    plt.xlim([0, max(counts)*1.1])

### Study particularly bad alignments

In [None]:
# Inspect one extracted beat file: load the estimate, the Isophonics
# reference for the same track, and the alignment diagnostics
estimated_beats_file = 'data/extracted/07_-_Revolver/02_-_Eleanor_Rigby.txt4'
est_beats = mir_eval.io.load_events(estimated_beats_file)
# The reference has the same relative path, with a plain .txt extension
ground_truth_beats_file = os.path.splitext(estimated_beats_file.replace('extracted', 'isophonics'))[0] + '.txt'
ref_beats = mir_eval.io.load_labeled_events(ground_truth_beats_file)[0]
d = deepdish.io.load(estimated_beats_file.replace('extracted', 'diagnostics').replace('.txt', '.h5'))
# Print the beat metrics next to the alignment confidence score
print mir_eval.beat.evaluate(ref_beats, est_beats)
print d['score']

# Optional listening test (disabled): synthesize the aligned MIDI and play
# the audio together with clicks at the estimated beats
#midi_object = pretty_midi.PrettyMIDI(str(d['output_midi_filename']))
#m = midi_object.fluidsynth(22050)
#a, fs = librosa.load(str(d['audio_filename']))
#IPython.display.Audio([a, mir_eval.sonify.clicks(est_beats, 22050, length=a.shape[0])], rate=22050)
# Plot the step size along the alignment path for the MIDI indices...
plt.figure()
plt.plot(np.diff(d['aligned_midi_indices']))
plt.ylim([-.1, 1.1])
# ...and for the audio indices
plt.figure()
plt.plot(np.diff(d['aligned_audio_indices']))
plt.ylim([-.1, 1.1])

In [None]:
# Find the first file with a high confidence score (> .6) but a very low
# Information gain (< .05), play it with reference and estimated clicks, and
# plot the estimated beats colored by distance to the nearest reference beat.
# NOTE(review): `diagnostics` is not defined anywhere in the visible
# notebook, and `ref_beats`/`est_beats` hold the beats of a single file when
# set by the earlier cells -- this cell presumably expects per-file lists
# from a previous version of the notebook; confirm before running.
for d, metrics, ref_b, est_b in zip(diagnostics, evaluation_scores, ref_beats, est_beats):
    if d['score'] > .6 and metrics['Information gain'] < .05:
        audio_filename = d['audio_filename']
        audio, fs = librosa.load(d['audio_filename'])
        # Click tracks for the reference and the estimated beats
        ref_clicks = mir_eval.sonify.clicks(
            ref_b, fs, length=audio.shape[0])
        est_clicks = mir_eval.sonify.clicks(
            est_b, fs, length=audio.shape[0])
        IPython.display.display(IPython.display.Audio([ref_clicks + audio, est_clicks + audio], rate=fs))
        fig = plt.figure(figsize=(15, .5))
        # Distance from each estimated beat to its closest reference beat
        dist = np.min(np.abs(np.subtract.outer(est_b, ref_b)), axis=1)
        plt.scatter(est_b, np.zeros(len(est_b)), c=dist, vmin=0, vmax=.25, alpha=.5, lw=0, cmap=plt.cm.jet)
        # Only the first matching file is shown
        break

In [None]:
import alignment_analysis
# Load the aligned MIDI file of the entry `d` left over from a previous cell
midi_object = pretty_midi.PrettyMIDI(str(d['output_midi_filename']))
# Play the result of synthesize_aligned_midi for that audio and MIDI file
IPython.display.Audio(alignment_analysis.synthesize_aligned_midi(audio, fs, midi_object), rate=fs)

In [None]:
import djitw

In [None]:
# Re-run DTW for the entry `d` to inspect its score.  The MIDI features are
# loaded the same way as the audio features (the original line was garbled
# and did not parse).
audio_gram = deepdish.io.load(str(d['audio_features_filename']))['gram']
midi_gram = deepdish.io.load(str(d['midi_features_filename']))['gram']
# Distance matrix between all MIDI frames and all audio frames
D = 1 - np.dot(midi_gram, audio_gram.T)
p, q, score = djitw.dtw(D, .96, np.median(D), inplace=False)
mask = np.zeros_like(D)
djitw.band_mask(.1, mask)
# Score normalized by the path length
print(score/len(p))
# ...and additionally by the mean distance within the mask
print(score/len(p)/(np.sum(D*mask)/mask.sum()))

# Key detection experiments

In [None]:
import mir_eval
import glob
import pretty_midi
import os
import numpy as np
import collections
import librosa
import vamp
import deepdish

### Key loading/computing functions

In [None]:
def load_midi_key(filename):
    ''' Load in key labels from a MIDI file '''
    # Parse the MIDI file; its key signature change events carry the keys
    pm = pretty_midi.PrettyMIDI(filename)
    key_names = []
    for key_change in pm.key_signature_changes:
        # Turn the key number into a name (like 'C Major') and lowercase it,
        # for mir_eval's sake
        key_name = pretty_midi.key_number_to_key_name(key_change.key_number)
        key_names.append(key_name.lower())
    return key_names

In [None]:
def load_isophonics_key(filename):
    ''' Read in key labels from an isophonics lab file '''
    # The three columns of a key lab file are start time, end time, label
    start, end, labels = mir_eval.io.load_delimited(
        filename, [float, float, str])
    # Key entries have labels of the form 'Key\tC' or 'Key\tC:minor';
    # keep only the part after the tab
    raw_keys = [label.split('\t')[1] for label in labels if 'Key' in label]
    keys = []
    for raw_key in raw_keys:
        # 'C:minor' becomes 'c minor'; anything else, e.g. 'C', gets
        # ' major' appended and becomes 'c major'
        if 'minor' in raw_key:
            key = raw_key.replace(':', ' ').lower()
        else:
            key = raw_key.lower() + ' major'
        # Fail here, rather than at scoring time, on an invalid key
        mir_eval.key.validate_key(key)
        keys.append(key)
    return keys

In [None]:
def load_vamp_key(filename):
    ''' Estimate the key from an audio file using QM key detector '''
    # Read the audio without resampling
    audio_data, fs = librosa.load(filename, sr=None)
    # Run the key detector plugin, requesting only its 'key' output
    key_generator = vamp.process_audio_multiple_outputs(
        audio_data, fs, 'qm-vamp-plugins:qm-keydetector', ['key'])
    vamp_output = [out['key'] for out in key_generator]
    keys = [segment['label'] for segment in vamp_output]
    # Each key lasts from its timestamp until the next one; the last key
    # lasts until the end of the audio
    boundaries = [float(segment['timestamp']) for segment in vamp_output]
    boundaries.append(librosa.get_duration(audio_data, fs))
    durations = np.diff(boundaries)
    # Total time spent in each distinct key
    unique_keys = list(set(keys))
    key_durations = []
    for key in unique_keys:
        key_durations.append(
            sum(d for k, d in zip(keys, durations) if k == key))
    # The estimate is the key which covers the most time
    most_common_key = unique_keys[np.argmax(key_durations)]
    # Labels like 'Eb / D# minor' name the same key twice; use the part
    # after the separator ('D# minor')
    if ' / ' in most_common_key:
        return most_common_key.split(' / ')[1]
    return most_common_key

### Get the accuracy of different key estimations/annotations compared to isophonics

In [None]:
# Keep track of the number of files skipped for different reasons
n_skipped = collections.defaultdict(int)
# Keep track of the weighted accuracy for each file for each source
scores = collections.defaultdict(list)
# Keep track of whether each MIDI estimated key is C major
c_majors = []
for lab_filename in glob.glob('data/isophonics_key/*/*.lab'):
    # Load Isophonics key from .lab file
    try:
        isophonics_keys = load_isophonics_key(lab_filename)
    except Exception as e:
        # Keep track of how many isophonics files which have invalid keys
        print 'Error for {}: {}'.format(lab_filename, e)
        n_skipped['isophonics_invalid'] += 1
        continue
    # If there are more than 1 Isophonics keys, skip
    if len(isophonics_keys) > 1:
        n_skipped['>1_isophonics_keys'] += 1
        continue
    isophonics_key = isophonics_keys[0]
    
    # Loop over all possible MIDI files for this key
    midi_glob = lab_filename.replace('isophonics_key', 'mid').replace('.lab', '.mid*')
    for midi_filename in glob.glob(midi_glob):
        # Get keys from MIDI file
        try:
            midi_keys = load_midi_key(midi_filename)
        except Exception as e:
            print 'Error for {}: {}'.format(midi_filename, e)
            n_skipped['midi_exceptions'] += 1
            continue
        # If there's no key change event, skip
        if len(midi_keys) == 0:
            n_skipped['no_midi_keys'] += 1
            continue
        # If there's multiple key change events, skip
        if len(midi_keys) > 1:
            n_skipped['>1_midi_keys'] += len(midi_keys) > 1
            continue
        midi_key = midi_keys[0]
        # Keep track of whether the estimated key was a C major
        c_majors.append(midi_keys[0] == 'c major')
        # Compute and store score for this MIDI file
        scores['midi'].append(mir_eval.key.weighted_score(isophonics_key, midi_key))

    # Construct .wav filename from .lab filename
    audio_filename = lab_filename.replace('isophonics_key', 'wav').replace('.lab', '.wav')
    # Estimate the key using vamp QM key detector plugin
    try:
        vamp_key = load_vamp_key(audio_filename)
    except Exception as e:
        print 'Error for {}: {}'.format(audio_filename, e)
        n_skipped['audio_exceptions'] += 1
        continue
    scores['vamp'].append(mir_eval.key.weighted_score(isophonics_key, vamp_key))

    # Construct whatkeyisitin text filename from .lab filename
    whatkeyisitin_filename = lab_filename.replace('isophonics_key', 'whatkeyisitin_key').replace('.lab', '.txt')
    if not os.path.exists(whatkeyisitin_filename):
        # Keep track of how many are skipped due to missing wkiii annotation
        n_skipped['no_wkiii_file'] += 1
        continue
    with open(whatkeyisitin_filename) as f:
        whatkeyisitin_key = f.read()
    scores['wkiii'].append(mir_eval.key.weighted_score(isophonics_key, whatkeyisitin_key))

In [None]:
import tabulate

In [None]:
# Report how many files were skipped, and why
for key, value in n_skipped.items():
    print('{} skipped because {}'.format(value, key))
print('Total isophonics .lab files: {}'.format(len(glob.glob('data/isophonics_key/*/*.lab'))))
print('')
# MIDI scores split by whether the MIDI key was C major
midi_scores = scores['midi']
c_major_scores = [s for c, s in zip(c_majors, midi_scores) if c]
other_scores = [s for c, s in zip(c_majors, midi_scores) if not c]
# Mean weighted score per source
mean_scores = collections.OrderedDict([
    ('MIDI, all keys', np.mean(midi_scores)),
    ('MIDI, C major only', np.mean(c_major_scores)),
    ('MIDI, non-C major', np.mean(other_scores)),
    ('QM Key Detector', np.mean(scores['vamp'])),
    ('whatkeyisitin.com', np.mean(scores['wkiii']))])
# Number of comparisons behind each mean, in the same order
n_comparisons = collections.OrderedDict([
    ('MIDI, all keys', len(scores['midi'])),
    ('MIDI, C major only', sum(c_majors)),
    ('MIDI, non-C major', sum(1 for c in c_majors if not c)),
    ('QM Key Detector', len(scores['vamp'])),
    ('whatkeyisitin.com', len(scores['wkiii']))])
# One table row per source: name, mean score, number of comparisons
rows = list(zip(mean_scores.keys(), mean_scores.values(),
                n_comparisons.values()))
print(tabulate.tabulate(
    rows, ['Source', 'Mean score', '# of comparisons']))

In [None]:
# Total number of 'key_numbers' entries over all of `statistics`.
# NOTE(review): `statistics` is not defined in the visible notebook --
# presumably left over from an earlier version; confirm before running.
sum(len(s['key_numbers']) for s in statistics)

In [None]:
# Number of 'key_numbers' entries equal to 0 over all of `statistics`.
# NOTE(review): `statistics` is not defined in the visible notebook --
# presumably left over from an earlier version; confirm before running.
len([k for k in sum([s['key_numbers'] for s in statistics], []) if k == 0])

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
# Mean score of the non-C-major entries whose alignment score exceeds each
# threshold in ts.
# NOTE(review): `alignment_scores` is not defined in the visible notebook,
# and `scores` is a dict keyed by source in the evaluation cell above but a
# flat list in the cells below -- this cell presumably relies on an earlier
# notebook state; confirm before running.
ts = np.linspace(0, 1, 101)
a = [np.mean([s for c, s, a in zip(c_majors, scores, alignment_scores) if not c and a > t]) for t in ts]
plt.plot(ts, a)

### Get the accuracy of vamp keys compared to isophonics

In [None]:
# Keep track of the number of files skipped for different reasons
n_skipped = collections.defaultdict(int)
# Keep track of the weighted accuracy for each file
scores = []
for audio_filename in glob.glob('data/wav/*/*.wav'):
    # Construct .lab file path from .wav file path
    base_path, filename = os.path.split(audio_filename)
    lab_filename = os.path.join(base_path.replace('wav', 'isophonics_key'),
                                os.path.splitext(filename)[0] + '.lab')
    if not os.path.exists(lab_filename):
        n_skipped['no_lab_file'] += 1
        continue
    # Load in Isophonics keys from .lab file
    try:
        isophonics_keys = load_isophonics_key(lab_filename)
    except Exception as e:
        print('Error for {}: {}'.format(lab_filename, e))
        n_skipped['isophonics_exceptions'] += 1
        continue
    # A .lab file without any key entry gives nothing to compare against
    if len(isophonics_keys) == 0:
        n_skipped['no_isophonics_keys'] += 1
        continue
    # If there are more than 1 Isophonics keys, skip
    if len(isophonics_keys) > 1:
        n_skipped['>1_isophonics_keys'] += 1
        continue
    # Estimate the key of THIS file with the vamp QM key detector; without
    # this step every file would be scored against whatever `vamp_key` was
    # left over from another cell
    try:
        vamp_key = load_vamp_key(audio_filename)
    except Exception as e:
        print('Error for {}: {}'.format(audio_filename, e))
        n_skipped['audio_exceptions'] += 1
        continue
    # Compute and store score for this wav file
    scores.append(mir_eval.key.weighted_score(isophonics_keys[0], vamp_key))

In [None]:
# Each summary line is a format template paired with the value it reports
summary_lines = [
    ('Total possible: {}', len(glob.glob('data/wav/*/*.wav'))),
    ('# of audio exceptions: {}', n_skipped['audio_exceptions']),
    ('# of Isophonics exceptions: {}', n_skipped['isophonics_exceptions']),
    ('# of Isophonics with >1 keys: {}', n_skipped['>1_isophonics_keys']),
    ('# of valid comparisons: {}', len(scores)),
    ('Accuracy all: {:.3f}', np.mean(scores)),
]
for template, reported_value in summary_lines:
    print(template.format(reported_value))

In [None]:
# Each line of the list is a tab-separated (title, key) pair
with open('data/whatkeyisitin.txt') as f:
    whatkeyisitin_list = [l.strip().split('\t') for l in f]
# Open the Beatles index so each title can be looked up in it
beatles_index = whoosh_search.get_whoosh_index(
    os.path.join('data', 'index'))
with beatles_index.searcher() as searcher:
    # All Beatles entries; used to map a matched id back to its path
    beatles_list = list(searcher.documents())
    for (title, key) in whatkeyisitin_list:
        # Look this title up in the Beatles index
        matches = whoosh_search.search(
            searcher, beatles_index.schema, 'The Beatles', fix(title), 9)
        # Titles without a match in the index are ignored
        if len(matches) == 0:
            continue
        # Base path of the best match
        path = [e['path'] for e in beatles_list
                if e['id'] == matches[0][0]][0]
        output_filename = os.path.join(
            'data', 'whatkeyisitin_key', path + '.txt')
        output_dir = os.path.split(output_filename)[0]
        # Make the album directory on first use
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        # The file holds nothing but the key
        with open(output_filename, 'wb') as f:
            f.write(key)

In [None]:
# Keep track of the number of files skipped for different reasons
n_skipped = collections.defaultdict(int)
# Keep track of the weighted accuracy for each file
scores = []
for whatkeyisitin_file in glob.glob('data/whatkeyisitin_key/*/*.txt'):
    with open(whatkeyisitin_file) as f:
        whatkeyisitin_key = f.read()
    # The Isophonics .lab file has the same relative path
    isophonics_file = whatkeyisitin_file.replace('whatkeyisitin_key', 'isophonics_key').replace('.txt', '.lab')
    try:
        isophonics_keys = load_isophonics_key(isophonics_file)
    except Exception as e:
        # Report the file which actually failed (this loop never sets
        # lab_filename)
        print('Error for {}: {}'.format(isophonics_file, e))
        n_skipped['isophonics_exceptions'] += 1
        continue
    # A .lab file without any key entry gives nothing to compare against
    if len(isophonics_keys) == 0:
        n_skipped['no_isophonics_keys'] += 1
        continue
    # If there are more than 1 Isophonics keys, skip
    if len(isophonics_keys) > 1:
        n_skipped['>1_isophonics_keys'] += 1
        continue
    scores.append(mir_eval.key.weighted_score(isophonics_keys[0], whatkeyisitin_key))

In [None]:
# Each summary line is a format template paired with the value it reports
summary_lines = [
    ('Total possible: {}', len(glob.glob('data/whatkeyisitin_key/*/*.txt'))),
    ('# of Isophonics exceptions: {}', n_skipped['isophonics_exceptions']),
    ('# of Isophonics with >1 keys: {}', n_skipped['>1_isophonics_keys']),
    ('# of valid comparisons: {}', len(scores)),
    ('Accuracy all: {:.3f}', np.mean(scores)),
]
for template, reported_value in summary_lines:
    print(template.format(reported_value))