In [79]:
import os,sys,time
import numpy as np
from scipy.io import wavfile
import sklearn

from IPython.display import Audio

import musiclib, database

%load_ext cython

The cython extension is already loaded. To reload it, use:
  %reload_ext cython


# DTW

In [80]:
%%cython
import numpy as np
cimport numpy as np
cimport cython
from libc.math cimport sqrt

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef align(float[:,:] sig1,float[:,:] sig2):
    cdef int d = sig1.shape[1]
    cdef int len1 = sig1.shape[0]
    cdef int len2 = sig2.shape[0]
    cdef np.ndarray[np.float32_t, ndim=2] npL = np.empty((len1,len2), dtype=np.float32)
    cdef np.ndarray[np.float32_t, ndim=2] npP = np.empty((len1,len2), dtype=np.float32)
    
    cdef float[:,:] L = npL
    cdef float[:,:] P = npP
    
    cdef float cost,tmp
    cdef int j,k,i
    for j in range(0,len1):
        for k in range(0,len2):
            cost = 0
            for i in range(d):
                tmp = sig1[j,i] - sig2[k,i]
                cost += tmp * tmp
            cost = sqrt(cost)
            
            if j == 0 and k == 0:
                L[j,k] = cost
                P[j,k] = 3
            elif k == 0:
                L[j,k] = cost + L[j-1,k]
                P[j,k] = 2
            elif j == 0:
                L[j,k] = cost + L[j,k-1]
                P[j,k] = 1
            else: # j, k > 0
                if L[j-1,k] < L[j,k-1] and L[j-1,k] < L[j-1,k-1]: # insertion (up)
                    P[j,k] = 1
                    L[j,k] = cost + L[j-1,k]
                elif L[j,k-1] < L[j-1,k-1]: # deletion (left)
                    P[j,k] = 2
                    L[j,k] = cost + L[j,k-1]
                else: # match (up left)
                    P[j,k] = 3
                    L[j,k] = cost + L[j-1,k-1]
    
    return npL,npP

In [81]:
%%cython
import numpy as np
cimport numpy as np
cimport cython
from libc.math cimport sqrt

def traceback_loss(float[:,:] sig1,float[:,:] sig2, float[:,:] L):
    sig12 = np.zeros(sig2.shape) # align 1 onto 2
    cdef int j = sig1.shape[0]-1
    cdef int k = sig2.shape[0]-1
    A = []
    C = []
    cdef float cost,tmp
    while True:
        if j == 0 and k == 0:
            A.append((0,0))
            C.append(L[0,0])
            break # got back to the beginning
        
        cost = 0
        for i in range(sig1.shape[1]):
            tmp = sig1[j,i] - sig2[k,i]
            cost += tmp * tmp
        cost = sqrt(cost)
        
        if j>0 and k>0 and L[j,k] == L[j-1,k-1] + cost: # progress
            A.append((j,k))
            C.append(L[j,k])
            j -= 1
            k -= 1
        elif k>0 and L[j,k] == L[j,k-1] + cost: # stay sig2
            A.append((j,k))
            C.append(L[j,k])
            k -= 1
        elif j>0 and L[j,k] == L[j-1,k] + cost: # stay sig1
            A.append((j,k))
            C.append(L[j,k])
            j -= 1
        else: 
#             print 'j',j
#             print 'k',k
#             print 'cost',cost
#             print 'L[j,k]',L[j,k]
#             print 'L[j,k]-cost',L[j,k]-cost
#             print 'L[j,k-1]',L[j,k-1]
#             print 'L[j-1,k]',L[j-1,k]
#             print 'L[j-1,k-1]',L[j-1,k-1]
            assert False
    
    return list(reversed(A)),list(reversed(C))

In [82]:
left_pad = 1*fs
right_pad = 0*fs
window_size=2048
stride=512
cutoff=int(50*(window_size/2048.))

In [83]:
record = 'MIDI-Unprocessed_R1_D2-13-20_mid--AUDIO-from_mp3_16_R1_2015_wav--1.wav'
synth = 'wtc1p19.wav'

fs, data1 = wavfile.read(record)
data1 = data1[0:int(9.75*fs)]
data1 = np.concatenate((np.zeros((left_pad,2)),data1),axis=0)
frep1 = database.featurize(data1,fs,musiclib.feature,window_size,stride=stride,normalize=False)

fs, data2 = wavfile.read(synth)
data2 = data2[0:fs*10]
data2 = np.concatenate((np.zeros((left_pad,2)),data2),axis=0)
frep2 = database.featurize(data2,fs,musiclib.feature,window_size,stride=stride,normalize=False)



In [84]:
print(data1.shape)
print(data2.shape)

(474075, 2)
(485100, 2)


In [85]:
Audio(data1[:,0],rate=fs)

In [86]:
Audio(data2[:,0],rate=fs)

In [87]:
# Align signals
start = time.time()
L,P = align(frep1[0:cutoff].T.astype(np.float32),frep2[0:cutoff].T.astype(np.float32))
end = time.time()
print('Elapsed time: ' + str(end - start))

Elapsed time: 0.05181884765625


In [88]:
# Find optimal path
path,costs = traceback_loss(frep1[0:cutoff].T.astype(np.float32),frep2[0:cutoff].T.astype(np.float32),L)

In [89]:
path1 = np.array([x[0] for x in path])
path2 = np.array([x[1] for x in path])

In [108]:
# Find corresponding onsets on the performance
onsets,notes = musiclib.load_midi('wtc1p19.mid')
onsets_sig2 = (onsets*fs + left_pad)/stride
onsets_sig1 = []
for onset in onsets_sig2:
    if onset > path2[-1] - right_pad/stride: # if we reached the end of the clip
        break

    sig2_index = np.argmax(path2>=onset)
    onsets_sig1.append(path1[sig2_index])
onsets_sig1 = np.array(onsets_sig1)

length of midi file69.2797896166662


In [109]:
# Find the correct onsets from the maestro midi file
onsets_correct, notes_correct = musiclib.load_midi('MIDI-Unprocessed_R1_D2-13-20_mid--AUDIO-from_mp3_16_R1_2015_wav--1.midi')
onsets_correct = (onsets_correct*fs + left_pad)/stride
onsets_correct_temp = []
for onset in onsets_correct:
    if onset > path2[-1] - right_pad/stride: # if we reached the end of the clip
        break
    onsets_correct_temp.append(onset)
onsets_correct = np.array(onsets_correct_temp)

length of midi file212.3843750000017


In [92]:
# Check our alignment
out2 = musiclib.mark_notes(data1[:,0], onsets_sig1*stride, notes)
wavfile.write('test.wav',fs,(.3*out2 + .7*data1[:,0]).astype(np.int16))
Audio(.4*out2 + .6*data1[:,0],rate=fs)

In [93]:
# Check perfect alignment from maestro midi
out2 = musiclib.mark_notes(data1[:,0], onsets_correct*stride, notes_correct)
wavfile.write('test.wav',fs,(.3*out2 + .7*data1[:,0]).astype(np.int16))
Audio(.4*out2 + .6*data1[:,0],rate=fs)

In [94]:
onsets,notes = musiclib.load_midi('wtc1p19.mid')

length of midi file69.2797896166662


In [95]:
# Note how there is some junk data in the maestro midi file, explaining
# why the maestro midi file has more onsets. 
print(np.shape(onsets_correct))
print(np.shape(onsets_sig1))
print(onsets_correct)
print(onsets_sig1)

(84,)
(53,)
[173.52172852 174.68811035 185.72387695 198.73352051 210.66650391
 224.66308594 225.2911377  253.1048584  282.08496094 294.73571777
 308.64257812 336.63574219 350.99121094 364.44946289 364.71862793
 377.45910645 391.72485352 391.90429688 405.99060059 418.91052246
 431.83044434 445.82702637 446.63452148 458.4777832  472.29492187
 473.10241699 485.12512207 498.94226074 499.74975586 512.4005127
 526.48681641 537.70202637 552.95471191 553.67248535 565.78491211
 578.70483398 579.24316406 592.88085937 607.23632812 608.40270996
 621.0534668  635.22949219 636.0369873  646.98303223 648.4185791
 659.09545898 660.88989258 661.96655273 665.1965332  673.1817627
 674.34814453 687.7166748  699.11132812 700.90576172 714.54345703
 714.90234375 716.15844727 727.9119873  741.28051758 741.6394043
 752.94433594 767.29980469 767.65869141 779.14306641 793.85742187
 806.41845703 820.86364746 821.40197754 833.15551758 834.32189941
 847.33154297 847.95959473 860.25146484 873.88916016 875.32470703
 8

In [96]:
512*499/44100

5.793378684807256

In [97]:
50/1000 * (44100/512)

4.306640625

50 millis in 4.3 strides

In [98]:
5 * 512 /44100

0.058049886621315196

In [106]:
# onsets_correct : The onsets from the maestro midi file
# onsets_predicted : The onsets generated from alignment
def evaluate_alignment(onsets_correct, onsets_predicted):
    score = 0
    for correct_onset in onsets_correct:
        found = False
        for predicted_onset in onsets_predicted:
            diff = abs(predicted_onset - correct_onset)
            if diff < 4:
                score = score + 1
                found = True
                break
        if found == False:
            print('incorrect onset : ' + str(correct_onset))
    return score

In [107]:
evaluate_alignment(onsets_correct, onsets_sig1)

incorrect onset : 253.1048583984375
incorrect onset : 928.7091064453118
incorrect onset : 929.3371582031243
incorrect onset : 929.8754882812493
incorrect onset : 941.9879150390619
incorrect onset : 942.3468017578119


78