In [None]:
!pip install numpy scipy matplotlib

In [None]:
import os
import sys
import array

import numpy as np

import matplotlib
import matplotlib.pyplot as plt


In [None]:
def read_s16_le_audio(name='violin.raw'):
    """Load a headerless file of signed 16-bit little-endian samples.

    The number of samples is derived from the file size (two bytes each); a
    one-line summary is printed before the file is read.

    Returns an ``array.array`` with typecode ``'h'`` holding the samples in
    native byte order.
    """
    byte_count = os.stat(name).st_size
    n_samples = byte_count // 2     # S16: two bytes per sample
    samples = array.array('h')
    print('reading {} samples as {} bytes from: {}'.format(n_samples, byte_count, name))
    with open(name, 'rb') as raw_file:
        samples.fromfile(raw_file, n_samples)
    # The file is little-endian; swap only when the host is not.
    if sys.byteorder != 'little':
        samples.byteswap()
    return samples


In [None]:
# Load the raw recording and de-interleave it into two channels.
data = read_s16_le_audio()
ch1 = data[::2]     # even-indexed samples
ch2 = data[1::2]    # odd-indexed samples
# Both channels must have the same number of samples.
assert len(ch1) == len(ch2)
len(ch1)

In [None]:
# Overview: plot every sample of the first channel on a wide figure.
plt.figure(figsize=(16,4))
plt.plot(ch1)

In [None]:
def acf_fft(x):
    """Autocorrelation of ``x`` for lags ``0 .. len(x)-1``, computed with the FFT.

    ``x`` is zero-padded on the right by ``len(x)`` rounded down to an even
    count, so the circular correlation that the FFT yields does not wrap
    around inside the returned lags.

    Returns a real array of ``len(x)`` values; for real-valued input, element
    ``k`` is ``sum(x[n] * x[n+k])``.
    """
    n = len(x)
    n_zeros = (n // 2) * 2
    padded = np.pad(x, pad_width=(0, n_zeros), mode='constant')
    spectrum = np.fft.fft(padded)
    # The inverse transform of the power spectrum is the autocorrelation.
    power = spectrum * np.conj(spectrum)
    return np.real(np.fft.ifft(power)[:n])

def crosscorr_fft(a, b):
    """Cross-correlation of two equal-length signals for lags ``0 .. len(a)-1``.

    Both inputs are zero-padded on the right by ``len(a)`` rounded down to an
    even count before transforming, so the circular result does not wrap
    around inside the returned lags.

    Returns a real array of ``len(a)`` values; for real-valued input, element
    ``k`` is ``sum(a[n+k] * b[n])``. An ``AssertionError`` is raised when the
    lengths differ.
    """
    assert len(a) == len(b)
    n = len(a)
    tail_padding = (0, (n // 2) * 2)

    spec_a = np.fft.fft(np.pad(a, pad_width=tail_padding, mode='constant'))
    spec_b = np.fft.fft(np.pad(b, pad_width=tail_padding, mode='constant'))

    # Multiplying by the conjugate spectrum correlates instead of convolving.
    correlation = np.fft.ifft(spec_a * np.conj(spec_b))
    return np.real(correlation)[:n]
    
def wsnac_fft(x, w):
    """Windowed, normalised autocorrelation of frame ``x`` under window ``w``.

    Numerator: twice the autocorrelation of the windowed frame ``x*w``.
    Denominator: the sum of the two cross-correlations between ``x*x*w`` and
    ``w`` (one in each argument order).

    Returns one value per lag, ``len(x)`` values in total.
    """
    windowed = x * w
    weighted_energy = x * x * w
    numerator = 2 * acf_fft(windowed)
    denominator = crosscorr_fft(weighted_energy, w) + crosscorr_fft(w, weighted_energy)
    # The small offset avoids dividing by a denominator that is exactly zero.
    return numerator / (denominator + 1e-12)

def sin_window(length):
    """Return a half-period sine window with ``length`` samples.

    Sample ``n`` is ``sin(pi * n / (length - 1))``, so the window starts and
    ends at (numerically) zero and rises to its maximum at the centre.

    Edge cases:
      * ``length <= 0`` gives an empty array.
      * ``length == 1`` gives ``[1.0]``; the formula itself would evaluate
        0/0 and produce NaN.
    """
    if length == 1:
        # No span to normalise by: return a neutral single-point window
        # (same convention as NumPy's window functions for M == 1).
        return np.ones(1)
    return np.sin(np.pi*np.arange(length)/(length-1))


def f0_from_wsnac(xx, rate, slope_thresh_pos=0.8, slope_thresh_neg=-0.1):
    """Find the first strong peak of a normalised autocorrelation curve and
    convert its lag to a frequency.

    The curve is scanned from lag 0 with a two-threshold hysteresis: the state
    turns "low" once a value drops below ``slope_thresh_neg`` and "high" again
    once a value reaches ``slope_thresh_pos``. The first low-to-high
    transition (at lag >= 3) opens a peak region; the largest value seen until
    the state next turns low is the peak candidate. A parabola through the
    candidate and its two neighbours refines the peak to sub-sample precision.

    Parameters
    ----------
    xx : sequence of float
        Curve indexed by lag (e.g. the output of ``wsnac_fft``).
    rate : float
        Sample rate; the returned frequency is ``rate / refined_peak_lag``.
    slope_thresh_pos : float, optional
        Value the curve must reach to enter the "high" state (default 0.8).
        Expected to be positive: the running maximum starts at 0, so
        non-positive values are never recorded as a peak.
    slope_thresh_neg : float, optional
        Value the curve must drop below to enter the "low" state
        (default -0.1).

    Returns
    -------
    (pfreq, fmax)
        ``pfreq`` is the estimated frequency and ``fmax`` the value of the
        fitted parabola at its vertex. ``(0, 0)`` is returned when no peak
        region was opened and closed before the end of the curve, or when the
        candidate lag is below 4 or above ``len(xx)//3``.
    """
    above = False   # True while scanning inside the peak region
    pos = True      # hysteresis state; the scan starts in the "high" state

    imax = 0        # lag of the largest value seen inside the region
    vmax = 0        # that largest value
    amax = 0        # accepted peak lag (0 = none found)

    for i in range(len(xx)):
        lpos = pos  # state before this sample, used for edge detection
        x = xx[i]

        # hysteresis: update the high/low state from the current sample
        if pos:
            if x < slope_thresh_neg:
                pos = False
        elif x >= slope_thresh_pos:
            pos = True

        if i < 3:
            continue    # ignore the first few lags (state is still updated)

        if pos and not lpos:    # low -> high: the peak region starts here
            above = True

        if above:
            if x > vmax:
                imax = i
                vmax = x
            if lpos and not pos:    # high -> low: the peak region is complete
                above = False
                amax = imax
                break

    # Reject peaks at very short lags or beyond a third of the curve.
    if amax < 4 or amax > len(xx)//3:
        return 0, 0

    # Parabolic interpolation through the peak sample and its two neighbours.
    pol = np.polyfit([amax-1, amax, amax+1], xx[amax-1:amax+2], 2)
    pmax = -pol[1]/pol[0]/2         # vertex position (fractional lag)
    fmax = np.poly1d(pol)(pmax)     # vertex height

    pfreq = 1./(pmax / rate)        # period in samples -> frequency

    return pfreq, fmax
    
    
def midi_tone_frames_wsnac(x, rate=None, window=None, frame_size=1024):
    """Estimate pitch frame by frame over the signal ``x``.

    ``x`` is cut into consecutive, non-overlapping frames of ``frame_size``
    samples (a trailing partial frame is dropped). Each frame is divided by
    ``2**15 - 1``, has its mean removed, and is analysed with ``wsnac_fft``
    followed by ``f0_from_wsnac``.

    Parameters
    ----------
    x : sequence of numbers
        Samples to analyse. The scaling by ``2**15 - 1`` assumes 16-bit
        signed sample values.
    rate : float, optional
        Sample rate passed to ``f0_from_wsnac``. Defaults to the
        notebook-level ``samp_rate``.
    window : array-like, optional
        Analysis window with ``frame_size`` samples. Defaults to the
        notebook-level ``w``.
    frame_size : int, optional
        Samples per frame (default 1024).

    Returns
    -------
    (tones, clars, freqs)
        ``tones``: list with one fractional MIDI note number per frame, or 0
        where the frame is not accepted (frequency <= 100 or clarity <= 0.8).
        ``clars``: NumPy array with the clarity value of every frame.
        ``freqs``: list with the frequency estimate of every frame.
    """
    # Fall back to the notebook globals so existing one-argument calls keep working.
    if rate is None:
        rate = samp_rate
    if window is None:
        window = w

    tones = []
    clars = []
    freqs = []
    usable = (len(x)//frame_size)*frame_size    # whole frames only
    for start in range(0, usable, frame_size):
        # Analyse the samples that were passed in, not a notebook global.
        frame = np.array(x[start:start+frame_size])/(2**15-1)
        frame = frame - frame.mean()
        wsnac = wsnac_fft(frame, window)

        f0, clar = f0_from_wsnac(wsnac, rate)
        if f0 > 100 and clar > 0.8:
            # MIDI note number: 69 is A4 = 440 Hz, 12 semitones per octave.
            midiTone = 69 + 12*np.log(f0/440.)/np.log(2)
        else:
            midiTone = 0

        tones.append(midiTone)
        clars.append(clar)
        freqs.append(f0)

    return tones, np.array(clars), freqs

In [None]:
# Analysis settings (notebook-level names that other cells rely on).
samp_rate = 48000
frame_size = 1024
NOTE_NAMES=['A','A#','B','C','C#','D','D#','E','F','F#','G','G#']

start_ms =  3000 #0
end_ms   = 12000 #12000
mpos = 332      # index of the frame highlighted by the vertical marker lines

# Convert the millisecond range to sample indices, trimmed to whole frames.
samp_start = int(start_ms * samp_rate/1000)
samp_end   = int(end_ms * samp_rate/1000)
samp_end = samp_start + ((samp_end-samp_start)//frame_size)*frame_size

data = ch1[samp_start:samp_end]
print('using range [{}:{}] of {:.2f}*1024 samples'.format(samp_start, samp_end, len(data)/frame_size))

w = sin_window(frame_size)

plt.figure(figsize=(16,12))

# Panel 1: waveform of the selected excerpt, marker at the highlighted frame.
plt.subplot(311)
plt.plot(data)
plt.axvline(mpos*frame_size)

tones, clars, freqs = midi_tone_frames_wsnac(data)

# One time stamp (in seconds) per analysed frame, shared by panels 2 and 3.
frame_times = np.linspace(0, frame_size*len(freqs)/samp_rate, len(freqs))
mark_time = mpos*frame_size/samp_rate

# Panel 2: raw frequency estimate of every frame.
plt.subplot(312)
plt.plot(frame_times, freqs)
plt.grid()
plt.axvline(mark_time)

# Panel 3: the same estimates, zeroed wherever the clarity is not above 0.9.
plt.subplot(313)
plt.plot(frame_times, (clars>0.9)*freqs)
plt.grid()
plt.axvline(mark_time)

# Inert string literal holding a disabled print snippet; the trailing ';'
# keeps the cell from echoing it as output.
'''
    print('{:6.1f} Hz at {:2.1f} |{} {}| {:.1f} cents'.format(f0, clar, 
                    NOTE_NAMES[(8*12+tone)%12],
                    (int)(4+int((9+tone)/12)),
                    (toneA440-tone)*100
                   ))
''';

In [None]:
# Inspect one frame in detail: frame number `mpos` of the excerpt in `data`.
frame_start = mpos*1024
x = np.array(data[frame_start:frame_start+1024])/(2**15-1)
x = x - x.mean()    # remove the frame's mean before analysis
wsnac = wsnac_fft(x, w)
f0, clar = f0_from_wsnac(wsnac, samp_rate)

print(f0,clar)

plt.figure(figsize=(16,12))

# Top panel: the scaled, mean-free frame.
plt.subplot(311)
plt.plot(x)

# Middle panel: the frame's wsnac curve over all lags.
plt.subplot(312)
plt.plot(wsnac)
plt.grid()
