In [None]:
import numpy as np
import librosa
import matplotlib.pyplot as plt
import pretty_midi
import midi
import matplotlib.patches
import matplotlib
import djitw
import scipy.spatial
import librosa
import sys
import os
import csv
import glob
import tabulate
import ujson as json
import collections
%matplotlib inline
import seaborn as sns
matplotlib.rc('font',**{'size':16, 'family':'Open Sans'})

In [None]:
# Color palette and default figure sizes shared by every figure below,
# followed by a swatch plot previewing the palette.

BLUE = '#1a6396'
GREEN = '#59dd97'
ORANGE = '#E8B71A'
GREY = '#DFDFDF'
RED = '#DB3340'
TAN = '#F7EAC8'
# FIGSIZE for regular plots; FIGSIZE_FLAT for short, wide strips (used for
# the CQT displays in Chapter 3).
FIGSIZE = (9, 6)
FIGSIZE_FLAT = (9, 2)

# One unit square per palette color, drawn over a TAN background band.
plt.figure(figsize=(8, 3))
plt.gca().add_patch(plt.Rectangle((-1, -1), 8, 1.5, fc=TAN, lw=0))
for offset, color in enumerate([BLUE, GREEN, GREY, RED, ORANGE, TAN]):
    plt.gca().add_patch(plt.Rectangle((offset, 0), 1, 1, fc=color, lw=0))
plt.xlim([-1, 7])
plt.ylim([-1, 2])
plt.axis('off')

# Chapter 1

In [None]:
# Figure: nine time steps t_0..t_8 rendered as several kinds of sequence data,
# one row each (words, samples, spectra, DNA bases, video frames).
plt.figure(figsize=(FIGSIZE[0], FIGSIZE[1]*2))
ax = plt.gca()
t = np.linspace(.1, .9, 9)
# Faint dashed guide line through each time step
plt.vlines(t, 0, 1., linestyles='dashed', alpha=.3, zorder=-1)

# Row at y=.9: one word per time step
vc = .9
words = 'The quick brown fox jumps over the lazy dog'.split(' ')
for x, word in zip(t, words):
    ax.text(x, vc, word, {'family': 'monospace', 'size': 12}, ha='center', va='center')

# Row at y=.7: a sampled sinusoid drawn as stems from the row's center line
vc = .7
signal = .08*np.sin(2.4*t*np.pi) + vc
plt.plot(t, signal, 'k.', ms=10)
plt.vlines(t, np.maximum(signal, vc), np.minimum(signal, vc), lw=1.2)

# Row at y=.5: magnitude spectrum of a window of audio around each time step
vc = .5
a, _ = librosa.load('data/1_A.wav')
N = a.shape[0]
for x in t:
    # Slice bounds must be integers; float indices raise in modern numpy
    frame = a[int((x - .1)*N):int((x + .1)*N)]
    spectrum = np.abs(np.fft.rfft(frame))
    # Keep the lowest third of the frequency bins (// keeps the bound integral)
    spectrum = spectrum[:spectrum.shape[0] // 3]
    spectrum = spectrum/3000.
    plt.plot(x + spectrum, np.linspace(vc - .08, vc + .08, spectrum.shape[0]), 'k')

# Row at y=.3: a random DNA base per time step, drawn as a one-hot 4x1 image.
# imshow changes the axis limits, so they are saved here and restored below.
axis = plt.axis()
vc = .3
w = .02
h = .08
dna_names = ['T', 'A', 'C', 'G']
# Dedicated seeded generator: the figure is reproducible and the global RNG
# stream is left untouched
dna_rng = np.random.RandomState(0)
for x in t:
    dna = np.zeros((4, 1))
    n = dna_rng.randint(0, 4)
    dna[n] = 1
    plt.imshow(dna, interpolation='nearest', extent=(x - w, x + w, vc - h/2, vc + h), cmap=plt.cm.gray)
    plt.plot((x - w, x - w, x + w, x + w, x - w), (vc - h/2, vc + h, vc + h, vc - h/2, vc - h/2), 'k')
    ax.text(x, vc - h/2 - .01, dna_names[n], {'family': 'monospace', 'size': 16}, ha='center', va='top')
plt.axis(axis)

# Row at y=.1: one video frame per time step (data/1_video/1.png .. 9.png)
vc = .1
for n, x in enumerate(t):
    im = plt.imread('data/1_video/{}.png'.format(n + 1))
    plt.imshow(im, interpolation='nearest', extent=(x - .03, x + .03, vc - .08, vc + .08))

plt.xlim([0.05, 0.95])
plt.ylim([0, 1])
plt.yticks([])
plt.axis('off')

# Bottom row: time-step labels; braces keep multi-digit subscripts intact
vc = 0.
for n, x in enumerate(t):
    ax.text(x, vc, '$t_{{{}}}$'.format(n), {'family': 'monospace', 'size': 16}, ha='center', va='top')

plt.savefig('1-example_sequences.pdf', transparent=True, bbox_inches='tight', pad_inches=0.)

In [None]:
# Seed the global RNG so the synthetic DTW example below is reproducible
np.random.seed(7)
match_length = 100
# Total number of query samples trimmed from both ends combined; integer
# division keeps it usable as a slice bound under Python 3 as well.
crop = match_length // 5
def random_walk(N):
    """Return a length-N random walk with steps drawn uniformly from {-1, 0, 1}.

    The cumulative sum is divided by log(N), so N must be >= 2
    (log(1) == 0 would divide by zero).

    Uses the global numpy RNG.
    """
    # randint's upper bound is exclusive: randint(-1, 2) is the documented
    # replacement for the deprecated random_integers(-1, 1).
    return np.cumsum(np.random.randint(-1, 2, N))/np.log(N)
def random_sine(N):
    """Return N samples of a sinusoid with a random frequency and phase.

    The base span is 5*pi radians (2.5 periods), scaled by a factor drawn
    from uniform(.9, 1.1); the phase is drawn from uniform(0, 2*pi). Both
    draws come from the global numpy RNG, frequency scale first.
    """
    freq_scale = np.random.uniform(.9, 1.1)
    phase = np.random.uniform(0, 2*np.pi)
    return np.sin(phase + freq_scale*np.linspace(0, 5*np.pi, N))

# Build a synthetic pair of sequences: both start from one noisy sinusoid of
# length match_length + pad; `match` drops the first `pad` samples and gets
# extra noise, `query` is built from the noisy sinusoid via np.interp.
# NOTE(review): np.interp evaluated at the integer points 0..match_length-1
# just returns the first match_length samples (no time stretching), so the
# "warp" is effectively a shift of `pad` samples -- confirm this is intended.
pad = match_length // 10
match = random_sine(match_length + pad)
query = np.interp(np.arange(match_length),
                  np.arange(match_length + pad),
                  match + .5*random_walk(match_length + pad))
match = match[pad:] + .5*random_walk(match_length)
# Z-normalize both sequences
match = (match - match.mean())/match.std()
query = (query - query.mean())/query.std()

# Squared-difference cost matrix of match vs. the query with half_crop
# samples trimmed from each end; int() keeps the slice bounds integral.
half_crop = int(crop) // 2
D = np.subtract.outer(match, query[half_crop:-half_crop])**2
p, q, score = djitw.dtw(D, inplace=False)

In [None]:
# Draw only every ds-th correspondence line to keep the figures legible
ds = 3
# Integer half-crop so it works as a slice bound / x offset under Python 3
half_crop = int(crop) // 2
query_cropped = query[half_crop:-half_crop]

# First figure: naive index-to-index correspondence. match (green) is shifted
# to lie above zero, query (blue) below.
plt.figure(figsize=FIGSIZE)
plt.plot(match - match.min(), GREEN, lw=2)
plt.plot(query - query.max(), BLUE, lw=2)

# list(...) because range() is a lazy sequence in Python 3
for n in list(range(0, match_length, ds)) + [match_length - 1]:
    plt.plot([n, n], [match[n] - match.min(), query[n] - query.max()], 'k:', lw=2)

plt.xlim(-1, plt.axis()[1])
plt.axis('off')
plt.savefig('1-example_distance_unwarped.pdf', transparent=True, bbox_inches='tight', pad_inches=0.1)

# Second figure: correspondences from the DTW path; p indexes match and q
# indexes the cropped query (drawn offset by half_crop on the x-axis).
plt.figure(figsize=FIGSIZE)
plt.plot(match - match.min(), GREEN, lw=2)
plt.plot(np.arange(half_crop, match_length - half_crop),
         query_cropped - query_cropped.max(), BLUE, lw=2)

# Every ds-th path step, plus the final step so the path's end is always shown
path_steps = list(zip(p[::ds], q[::ds])) + [(p[-1], q[-1])]
for p_n, q_n in path_steps:
    plt.plot([p_n, q_n + half_crop],
             [match[p_n] - match.min(), query_cropped[q_n] - query_cropped.max()],
             'k:', lw=2)

plt.xlim(-1, plt.axis()[1])
plt.axis('off')
plt.savefig('1-example_distance_warped.pdf', transparent=True, bbox_inches='tight', pad_inches=0.1)

# Chapter 2

In [None]:
# Two short example signals and a hand-specified warping path between them:
# at path step i, p[i] indexes signal1 and q[i] indexes signal2.
signal1 = np.array([509, 113, -229, 253, -96, -195, 180, -303, -361, 17,
                    -13, 242, 14, -230, 300, 89, -112, -236, -298])
signal2 = np.array([543, 401, 122, -288, 62, 259, 180, -72, -336, 10,
                    223, 263, 35, -345, 68, 400, 38, -109, -301])
q = [0, 1, 2, 3, 4, 5, 6, 7, 7, 7, 8, 8, 9, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 18]
p = [0, 0, 1, 2, 3, 3, 3, 4, 5, 6, 7, 8, 9, 10, 11, 11, 12, 13, 14, 14, 15, 16, 17, 18]

# First figure: signal1 (blue) with signal2 (green) raised by 1000 above it,
# plus a dotted line joining each pair of path-aligned samples.
plt.figure(figsize=FIGSIZE)
plt.plot(signal1, lw=2, c=BLUE)
plt.plot(signal2 + 1000, lw=2, c=GREEN)
for i, j in zip(p, q):
    plt.plot([i, j], [signal1[i], signal2[j] + 1000], 'k:', lw=2)
ax = plt.gca()
ax.yaxis.set_visible(False)
sns.despine(left=True)
# Ticks every third sample, labeled 1-based
tick_positions = range(0, 19, 3)
tick_labels = range(1, 20, 3)
plt.xticks(tick_positions, tick_labels)
plt.xlim([-.1, 18.1])
plt.savefig('2-example_dtw_sequences.pdf', transparent=True, bbox_inches='tight', pad_inches=0.1)

# Second figure: pairwise distance matrix (rows: signal1, columns: signal2)
# with every cell on the path outlined in white.
plt.figure(figsize=FIGSIZE)
pairwise_dist = scipy.spatial.distance.cdist(signal1[:, np.newaxis], signal2[:, np.newaxis])
plt.imshow(pairwise_dist, cmap=plt.cm.hot, interpolation='nearest')
saved_limits = plt.axis()
for row, col in zip(p, q):
    x_lo, x_hi = col - .5, col + .5
    y_lo, y_hi = row - .5, row + .5
    plt.plot([x_lo, x_lo, x_hi, x_hi, x_lo], [y_lo, y_hi, y_hi, y_lo, y_lo], 'w')
plt.axis(saved_limits)
plt.xticks(tick_positions, tick_labels)
plt.yticks(tick_positions, tick_labels)
plt.savefig('2-example_dtw_matrix.pdf', transparent=True, bbox_inches='tight', pad_inches=0.1)

# Chapter 3

In [None]:
# Checkout of the alignment-search project, which provides the modules
# imported below and the data/ and results/ files read in this chapter.
# Override with the ALIGNMENT_SEARCH_PATH environment variable; the default
# is the original author's local checkout.
ALIGNMENT_SEARCH_PATH = os.environ.get(
    'ALIGNMENT_SEARCH_PATH', '/Users/craffel/Documents/projects/alignment-search/')
# Avoid appending a duplicate entry each time this cell is re-run
if ALIGNMENT_SEARCH_PATH not in sys.path:
    sys.path.append(ALIGNMENT_SEARCH_PATH)
import corrupt_midi
import create_data
import find_best_aligners
import db_utils

In [None]:
# Some utility functions
def compute_cqt(audio_data):
    """ Compute the log-magnitude L2 normalized CQT

    Parameters
    ----------
    audio_data : np.ndarray
        Audio samples, passed unchanged to ``create_data.extract_cqt``
        (sample-rate assumptions live there -- TODO confirm).

    Returns
    -------
    cqt : np.ndarray
        CQT in dB relative to its own maximum, with each frame L2-normalized.
        It is transposed so that, presumably, rows are frames and columns are
        frequency bins.
    times : np.ndarray
        Frame times exactly as returned by ``create_data.extract_cqt``.
    """
    cqt, times = create_data.extract_cqt(audio_data)
    # librosa.logamplitude was removed in later librosa releases; power_to_db
    # is its replacement with the same 10*log10(S/ref) scaling.
    if hasattr(librosa, 'power_to_db'):
        cqt = librosa.power_to_db(cqt, ref=cqt.max())
    else:
        cqt = librosa.logamplitude(cqt, ref_power=cqt.max())
    # norm is keyword-only in recent librosa, so pass it by name
    return librosa.util.normalize(cqt, norm=2).T, times
def display_cqt(cqt):
    """Show a CQT (frames along axis 0, bins along axis 1) as a heatmap.

    Uses the hot colormap with low bins at the bottom. The color range is
    clipped to the 1st-99th percentile of the values. The y-axis is labeled
    with a note name every 12 bins starting at MIDI note 36. This assumes the
    CQT starts at MIDI 36 with 12 bins per octave -- TODO confirm against
    create_data.extract_cqt.
    """
    color_min = np.percentile(cqt, 1)
    color_max = np.percentile(cqt, 99)
    plt.imshow(cqt.T, aspect='auto', interpolation='nearest', origin='lower',
               cmap=plt.cm.hot, vmin=color_min, vmax=color_max)
    first_note = 36
    note_numbers = range(first_note, first_note + 48, 12)
    plt.yticks([note - first_note for note in note_numbers],
               [librosa.midi_to_note(note) for note in note_numbers])

In [None]:
np.random.seed(2)

# Grab a MIDI file from the clean MIDIs we used in this experiment
midi_file = os.path.join(ALIGNMENT_SEARCH_PATH, 'data/mid/Come as You Are.mid')
# Parse the MIDI file with pretty_midi
midi_object = pretty_midi.PrettyMIDI(midi_file)

# For illustration, we'll plot a CQT of the MIDI object
# before and after corruptions.
plt.figure(figsize=FIGSIZE_FLAT)
# Synthesize the (still clean) MIDI at a 22050 Hz sample rate and take its CQT
original_cqt, original_times = compute_cqt(midi_object.fluidsynth(22050))
display_cqt(original_cqt)
#plt.title('Original MIDI CQT')
plt.savefig('3-original_cqt.pdf', transparent=True, bbox_inches='tight', pad_inches=0.1)

# This is the wrapper function to apply all of the corruptions
# defined in corrupt_midi.
# NOTE(review): it appears to modify midi_object in place -- the fluidsynth()
# call below re-synthesizes the same object and expects the corrupted
# version. It presumably draws from the global numpy RNG seeded above, so
# the order of these statements matters -- confirm in corrupt_midi.
# `adjusted_times` is compared against original_times below, presumably
# original_times after the applied warp; `diagnostics` is unused here.
adjusted_times,  diagnostics = corrupt_midi.corrupt_midi(
    midi_object, original_times,
    # This defines the extent to which time will be warped
    warp_std=20,
    # These define how likely we are to crop out sections
    # We'll set them to 1 and 0 here for illustration; in the
    # paper they are adjusted according to the desired corruption level
    start_crop_prob=0., end_crop_prob=0., middle_crop_prob=1.,
    # The likelihood that each instrument is removed
    remove_inst_prob=.5,
    # The likelihood that an instrument's program number is changed
    change_inst_prob=1.,
    # The standard deviation of velocity adjustment
    velocity_std=1.)

# Now, we can plot the CQT after corruptions.
plt.figure(figsize=FIGSIZE_FLAT)
corrupted_cqt, corrupted_times = compute_cqt(midi_object.fluidsynth(22050))
display_cqt(corrupted_cqt)
plt.xlabel('Frame')
#plt.title('After corruption')
plt.savefig('3-corrupted_cqt.pdf', transparent=True, bbox_inches='tight', pad_inches=0.1)

# We can also plot the timing offset, which we will try to reverse:
# original frame time minus its time after corruption
plt.figure(figsize=FIGSIZE)
plt.plot(original_times, original_times - adjusted_times, BLUE, lw=2)
plt.xlim([0, original_times.max()])
plt.xlabel('Original time')
plt.ylabel('Offset from original time')
plt.savefig('3-warping.pdf', transparent=True, bbox_inches='tight', pad_inches=0.1)

In [None]:
# Compute a pairwise squared-Euclidean distance matrix: rows are frames of
# the original CQT, columns are frames of the corrupted CQT
distance_matrix = scipy.spatial.distance.cdist(
    original_cqt, corrupted_cqt, 'sqeuclidean')
# Compute the lowest-cost path via DTW with "golden standard" parameters.
# NOTE(review): .96 and the median distance are passed positionally,
# presumably as the gully and additive penalty -- confirm against djitw.dtw.
p, q, score = djitw.dtw(
    distance_matrix, .96, np.median(distance_matrix), inplace=False)

# Original frame time at each step of the path (p indexes original frames,
# q indexes corrupted frames)
original_path_times = original_times[p]

# Plot the aligned corrupted times and ground-truth times
plt.figure(figsize=FIGSIZE)
plt.plot(original_path_times, original_path_times - adjusted_times[p], BLUE, lw=2)
plt.plot(original_path_times, original_path_times - corrupted_times[q], GREEN, lw=2)
plt.xlim([0, original_times.max()])
plt.legend(['Ground-truth offset', 'Fixed corrupted offset'], loc='upper left')
plt.xlabel('Original time')
plt.ylabel('Offset from original time')
plt.savefig('3-warping_corrected.pdf', transparent=True, bbox_inches='tight', pad_inches=0.1)

# Absolute error between where the path maps each original frame and where
# it truly went, clipped to at most .5 (seconds, per the original comment)
plt.figure(figsize=FIGSIZE)
error = np.abs(np.clip(
    corrupted_times[q] - adjusted_times[p], -.5, .5))
plt.plot(original_path_times, error, BLUE, lw=2)
# x-axis is original time, labeled consistently with the plots above
plt.xlabel('Original time')
plt.ylabel('Correction error')
plt.xlim([0, original_times.max()])
plt.savefig('3-correction_error.pdf', transparent=True, bbox_inches='tight', pad_inches=0.1)

In [None]:
# Load in the results from the parameter search experiment
params, objectives = db_utils.get_experiment_results(
    os.path.join(ALIGNMENT_SEARCH_PATH, 'results/parameter_experiment_gp/*.json'))
# Keep only the 10 settings with the lowest objective (mean error), best first
good = np.argsort(objectives)[:10]
params = [params[n] for n in good]
objectives = [objectives[n] for n in good]
# Attach each objective to its parameter dict so it becomes a table column
for param, objective in zip(params, objectives):
    param['objective'] = objective
# Column key -> LaTeX column header; OrderedDict fixes the column order.
# Raw string so '\p' is not parsed as an (invalid) escape sequence.
header_names = collections.OrderedDict([
    ('add_pen', r'$\phi$ Median Scale'),
    ('standardize', 'Standardize?'),
    ('gully', 'Gully $g$'),
    ('objective', 'Mean Error')])
def yes_no(x):
    """Format a table cell: Python bools become 'Yes'/'No'.

    Any non-bool value (including ints and numpy bools) is returned unchanged.
    """
    if not isinstance(x, bool):
        return x
    return 'Yes' if x else 'No'
# Print the top parameter settings as a LaTeX booktabs table, one row per
# setting, with booleans rendered as Yes/No. print(...) with a single argument
# behaves the same on Python 2 and 3; the loop variable is named `param`
# so it does not overwrite the global DTW path `p` on Python 2.
print(tabulate.tabulate([collections.OrderedDict([(k, yes_no(param[k])) for k in header_names])
                         for param in params],
                        headers=header_names, tablefmt='latex_booktabs'))

In [None]:
# Load in all confidence reporting experiment trials. Sorted so that the
# trial selected below does not depend on filesystem listing order.
trials = []
for trial_file in sorted(glob.glob(os.path.join(
        ALIGNMENT_SEARCH_PATH, 'results/confidence_experiment/*.json'))):
    with open(trial_file) as f:
        trials.append(json.load(f))
# Retrieve the lowest-achieved mean absolute error (objectives was sorted
# ascending when the parameter table was built)
best_easy_error = objectives[0]
# Retrieve the confidence reporting trial whose mean easy error matches it
best_trial = next(
    (trial for trial in trials
     if np.allclose(np.mean(trial['results']['easy_errors']), best_easy_error)),
    None)
if best_trial is None:
    raise ValueError(
        'No confidence experiment trial has mean easy error {}'.format(best_easy_error))
# Retrieve the results from this trial
best_result = best_trial['results']

# Scatter alignment error against normalized DTW distance, hard and easy
# examples combined (presumably one entry per aligned example)
errors = np.array(best_result['hard_errors'] + best_result['easy_errors'])
scores = np.array(best_result['hard_penalty_len_norm_mean_norm_scores'] +
                  best_result['easy_penalty_len_norm_mean_norm_scores'])
plt.figure(figsize=FIGSIZE)
plt.scatter(errors, scores, marker='+', c='black', alpha=.3, s=40)
plt.gca().set_xscale('log')
plt.ylim(0., 1.1)
plt.xlim(.9*np.min(errors), np.max(errors)*1.1)
plt.xlabel('Alignment error')
plt.ylabel('Normalized DTW distance')
plt.savefig('3-correlation.pdf', transparent=True, bbox_inches='tight', pad_inches=0.1)

In [None]:
# Load the hand ratings of real alignments; each CSV row is
# (alignment_id, rating, score, note)
with open(os.path.join(ALIGNMENT_SEARCH_PATH, 'results/alignment_ratings.csv')) as f:
    reader = csv.reader(f)
    # Cast each entry in each row to the correct type; the raw score is
    # mapped to a confidence via clip(2*(1 - score), 0, 1)
    ratings = [[int(alignment_id), int(rating), np.clip(2*(1 - float(score)), 0, 1), note]
               for alignment_id, rating, score, note in reader]

# We made notes about each alignment, too.
# Here are all the alignments where a transcription was matched to a remix,
# ordered from highest to lowest confidence score
remixes = [r[1:] for r in ratings if ('remix' in r[-1].lower())]
remixes.sort(key=lambda x: -x[1])
# print(...) with one argument behaves the same on Python 2 and 3
print(tabulate.tabulate(remixes, headers=['Rating', 'Confidence Score', 'Note'], tablefmt='latex_booktabs'))

In [None]:
# Violin plot of confidence score for each human rating (1-5). Violin widths
# are proportional to how many alignments received each rating, and a narrow
# box plot (box + median, no whiskers) is overlaid.
# Rating scale: 1=Wrong song, 2=Bad alignment, 3=Sloppy, 4=Embellishments,
# 5=Perfect
plt.figure(figsize=FIGSIZE)
data = [np.array([r[2] for r in ratings if r[1] == n]) for n in [1, 2, 3, 4, 5]]
# Largest group size, computed once instead of once per group
largest_group = max(len(d) for d in data)
violins = plt.violinplot(
    data, showextrema=False, showmeans=False,
    widths=[float(len(d))/largest_group for d in data])
patches = plt.boxplot(data, showmeans=False, showcaps=False, showfliers=False,
                      patch_artist=True, widths=.1)
for line in patches['whiskers']:
    line.set_visible(False)
for box in patches['boxes']:
    box.set_alpha(.5)
    box.set_joinstyle('round')
    box.set_facecolor('w')
    box.set_edgecolor('k')
for line in patches['medians']:
    line.set_color('black')
for body in violins['bodies']:
    body.set_alpha(.8)
# Ratings 1-2 count as incorrect alignments (blue), 3-5 as correct (green)
for n in [0, 1]:
    violins['bodies'][n].set_facecolor(BLUE)
for n in [2, 3, 4]:
    violins['bodies'][n].set_facecolor(GREEN)
plt.xticks(
    [1, 2, 3, 4, 5],
    [1, 2, 3, 4, 5])
plt.xlim([.5, 5.5])
plt.xlabel('Rating')
plt.ylabel('Confidence score')
plt.legend(handles=[matplotlib.patches.Patch(color=BLUE, label='Incorrect'),
                    matplotlib.patches.Patch(color=GREEN, label='Correct')],
           loc='upper left')
plt.ylim(-.03, 1.03)
plt.savefig('3-violin.pdf', transparent=True, bbox_inches='tight', pad_inches=0.1)