In [60]:
import os
import sys
import wave
import math
import glob
import time
import random
import feather
import librosa
import librosa.display
import pretty_midi
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(font_scale=5)
from matplotlib.mlab import find
from sklearn.decomposition import PCA
from scipy import signal

def parabolic(f, x):
    """Quadratic (three-point) interpolation of a sampled peak.

    Fits a parabola through ``f[x-1]``, ``f[x]`` and ``f[x+1]`` and returns
    the coordinates of its vertex.

    Parameters
    ----------
    f : indexable sequence of numbers (e.g. a numpy array)
        Sampled values; indices ``x - 1``, ``x`` and ``x + 1`` are read.
    x : int
        Index of the sample to refine.

    Returns
    -------
    tuple
        ``(xv, yv)``: the fractional position of the vertex and the
        interpolated value there.
    """
    left, centre, right = f[x - 1], f[x], f[x + 1]
    spread = left - right
    xv = 0.5 * spread / (left - 2 * centre + right) + x
    # (xv - x) rather than the raw offset, to match the original arithmetic exactly
    yv = centre - 0.25 * spread * (xv - x)
    return (xv, yv)

def freq_from_autocorr(audio_signal, sr, fmin=100, fmax=400):
    """Estimate frequency using autocorrelation.

    The signal is mean-centred and autocorrelated via FFT convolution. The
    highest autocorrelation peak after the first local minimum (i.e. excluding
    the lag-0 peak) is refined with ``parabolic`` interpolation, and its lag is
    taken as the period.

    Parameters
    ----------
    audio_signal : 1-D array_like
        Samples to analyse. The input is NOT modified.
    sr : number
        Sample rate of ``audio_signal`` in Hz.
    fmin, fmax : number, optional
        Accepted frequency band in Hz (defaults 100 and 400). Estimates
        outside ``[fmin, fmax]`` are reported as NaN.

    Returns
    -------
    float
        Estimated frequency in Hz, or NaN when the signal is empty, no usable
        peak is found, or the estimate falls outside ``[fmin, fmax]``.
    """
    samples = np.asarray(audio_signal)
    if samples.size == 0:
        return float('nan')

    # Remove the DC offset into a NEW array rather than in place. Callers may
    # pass views into a larger buffer (analyse_pair passes overlapping slices
    # of the whole recording), and those must not be modified.
    samples = samples - np.mean(samples)

    # Autocorrelation is convolution with the time-reversed signal. Keep only
    # the non-negative lags (index 0 is lag 0). Integer division keeps the
    # slice index an int on Python 3.
    corr = signal.fftconvolve(samples, samples[::-1], mode='full')
    corr = corr[len(corr) // 2:]

    # Find the first low point: the first lag at which the autocorrelation
    # starts rising. (np.flatnonzero replaces matplotlib.mlab.find, which was
    # removed in matplotlib 3.1.)
    rising = np.flatnonzero(np.diff(corr) > 0)
    if rising.size == 0:
        # no low point could be found: report "no pitch" as NaN
        return float('nan')
    start = rising[0]

    # Find the next peak after the low point (other than 0 lag). This bit is
    # not reliable for long signals, due to the desired peak occurring between
    # samples, and other peaks appearing higher.
    i_peak = np.argmax(corr[start:]) + start
    # parabolic() reads corr[i_peak - 1] and corr[i_peak + 1]. A peak on either
    # edge cannot be interpolated, and index -1 would silently wrap around.
    if i_peak < 1 or i_peak >= len(corr) - 1:
        return float('nan')
    i_interp = parabolic(corr, i_peak)[0]
    freq = sr / i_interp

    # Reference ranges: the voiced speech of a typical adult male has a
    # fundamental frequency from 85 to 180 Hz, that of a typical adult female
    # from 165 to 255 Hz; singing spans roughly bass E2 (82.41 Hz) to soprano
    # C6 (1046.50 Hz). The default band can be tuned with insight from the
    # full dataset. The negated range test also maps a NaN/inf estimate to NaN.
    if not (fmin <= freq <= fmax):
        return float('nan')
    return float(freq)

def load(filename, sr=8000):
    """
    Read an audio file with ``librosa.load`` and return
    ``(samples, sample_rate, channels)``.

    The file is loaded at sample rate ``sr``; ``channels`` is always
    reported as 1. Accepts the formats librosa's audio backend can read,
    e.g. .wav or .flac.
    """
    samples, rate = librosa.load(filename, sr=sr)
    return samples, rate, 1

def _read_pitch_labels(pv_path):
    """Read a .pv pitch-label file: one float per line, returned as a list."""
    with open(pv_path, 'r') as pv_file:
        return [float(line) for line in pv_file]


def _estimate_frame_freqs(audio_signal, fs, frame_size, num_windows):
    """Estimate one frequency (Hz) per non-overlapping frame of audio_signal.

    Each frame is analysed with ``num_windows`` windows of ``frame_size``
    samples whose starts step by ``frame_size // num_windows`` samples, so
    consecutive windows overlap. Windows near the end of the signal are
    truncated by slicing. The frame estimate is the plain mean of the window
    estimates, so it is NaN if any window yields NaN.
    """
    num_frames = len(audio_signal) // frame_size
    hop = frame_size // num_windows
    frame_freqs = []
    for frame_start in range(0, num_frames * frame_size, frame_size):
        window_freqs = []
        for i in range(num_windows):
            start = frame_start + i * hop
            # Copy so the estimator cannot modify audio_signal (which
            # analyse_pair returns) through a slice view.
            window = audio_signal[start:start + frame_size].copy()
            window_freqs.append(freq_from_autocorr(window, fs))
        frame_freqs.append(np.mean(window_freqs))
    return frame_freqs


def analyse_pair(head, midi_filename, num_windows=10, fs=8000, frame_size=256,
                 midi_root=None):
    """Analyse a single input example from the MIR-QBSH dataset.

    Loads ``<head>/<midi_filename>.wav`` (the recording),
    ``<head>/<midi_filename>.pv`` (its manually labelled pitch track) and
    ``<midi_root><midi_filename>.mid`` (the reference MIDI). It then estimates
    one pitch per frame with ``freq_from_autocorr`` and scores the estimate
    against the labels.

    Parameters
    ----------
    head : str
        Directory containing the .wav/.pv pair.
    midi_filename : str
        File stem (no extension) shared by the .wav, .pv and .mid files.
    num_windows : int, optional
        Number of overlapping analysis windows per frame. The frame estimate
        is the mean of the window estimates.
    fs : int, optional
        Sample rate (Hz) the audio is loaded at; 8000 matches the dataset.
    frame_size : int, optional
        Frame length in samples; the .pv labels use 256 with no overlap.
    midi_root : str, optional
        Prefix prepended to ``midi_filename + '.mid'``. Defaults to the
        notebook-level ``midiroot`` global, which must then be defined.

    Returns
    -------
    tuple
        ``(audio_signal, midi_data, y, y_hat, y_hat_freq, y_hat_freq_no_nan,
        squared_error, mse, mae)``, where:

        - ``y``: the labelled pitches (list of float).
        - ``y_hat``: the estimated MIDI pitches.
        - ``y_hat_freq``: the estimated frequencies in Hz (NaN where no
          estimate).
        - ``squared_error``: the per-frame squared error over the frames
          present in both ``y`` and ``y_hat``.
        - ``mse``/``mae``: the NaN-ignoring mean squared / absolute errors.
    """
    if midi_root is None:
        midi_root = midiroot  # notebook-level global from the data-loading cell
    fileroot = os.path.join(head, midi_filename)

    # Reference MIDI. NOTE: alignment to midi_data.instruments[0].notes is
    # currently NOT handled.
    midi_data = pretty_midi.PrettyMIDI(midi_root + midi_filename + '.mid')

    audio_signal, _, _ = load(fileroot + '.wav', fs)

    # Manually labelled pitch values (frame size 256, overlap 0).
    y = _read_pitch_labels(fileroot + '.pv')

    # Estimate one frequency per frame, then convert Hz to MIDI note numbers.
    y_hat_freq = _estimate_frame_freqs(audio_signal, fs, frame_size, num_windows)
    y_hat = librosa.hz_to_midi(y_hat_freq)

    # Score only the frames present in both tracks. The label count need not
    # equal len(audio_signal) // frame_size, and unequal lengths would
    # otherwise fail to broadcast.
    n_common = min(len(y), len(y_hat))
    diff = (np.asarray(y[:n_common], dtype=float)
            - np.asarray(y_hat[:n_common], dtype=float))
    squared_error = diff ** 2
    mse = np.nanmean(squared_error)
    mae = np.nanmean(np.abs(diff))

    # frequency estimates with the NaNs (no pitch / out of band) removed
    y_hat_freq_no_nan = [value for value in y_hat_freq if not math.isnan(value)]

    return audio_signal, midi_data, y, y_hat, y_hat_freq, y_hat_freq_no_nan, squared_error, mse, mae

In [None]:
# 1. Load the data into training samples X and labels y
# 2. Split the data into training, testing and validation sets
# 3. Test the naive model (autocorrelation) on the test set; this is the baseline to beat
# 3a. The evaluation metric is the MSE of the pitch values compared to the ground truth
# 4. Train a new model and compare it with the naive model
# Data: Roger Jang's MIR-QBSH corpus, which comprises 4431 queries sampled at
# 8 kHz along with 48 ground-truth MIDI files.
# All queries are from the beginning of the reference songs.
# A manually labelled pitch track is available for each recording.
# Matching the hand labels is more important than matching the ground-truth MIDI.
# Framed as a multiclass classification problem.

# build the filenames incrementally
wavroot = 'MIR-QBSH-corpus/waveFile/'
midiroot = 'MIR-QBSH-corpus/midiFile/'

# Every leaf directory under wavroot holds one subject's recordings.
# os.walk and glob return entries in an arbitrary, filesystem-dependent order.
# Sort them so the processing order, the row order of the error table and the
# "last sample" shown by the plotting cell are reproducible.
subjects = sorted(
    str(dirpath) + '/'
    for dirpath, dirnames, _ in os.walk(wavroot)
    if not dirnames
)

# per-file error metrics, keyed by wav path
errors = {}

# estimated frequencies (NaNs removed) pooled over all samples, used to
# visualize the frequency distribution
y_hat_freq_no_nan_all = []

for subject in tqdm(subjects):
    for wav_path in tqdm(sorted(glob.glob(subject + "*.wav"))):
        head, tail = os.path.split(wav_path)
        # the wav stem is also the stem of the matching .pv and .mid files
        midi_filename = os.path.splitext(tail)[0]
        # NOTE: these names intentionally remain in the notebook namespace;
        # the plotting cell visualizes the last analysed sample.
        (audio_signal, midi_data, y, y_hat, y_hat_freq, y_hat_freq_no_nan,
         squared_error, mse, mae) = analyse_pair(head, midi_filename,
                                                 num_windows=10, fs=8000, frame_size=256)

        y_hat_freq_no_nan_all.extend(y_hat_freq_no_nan)

        errors[wav_path] = {
            'mse': mse,
            'mae': mae,
            # last three path components, '/'-joined, independent of os.sep
            'filename': '/'.join(os.path.normpath(wav_path).split(os.sep)[-3:]),
        }

  0%|          | 0/195 [00:00<?, ?it/s]
  0%|          | 0/21 [00:00<?, ?it/s][A
  5%|▍         | 1/21 [00:00<00:15,  1.29it/s][A
 10%|▉         | 2/21 [00:01<00:14,  1.31it/s][A
 14%|█▍        | 3/21 [00:02<00:13,  1.32it/s][A
 19%|█▉        | 4/21 [00:03<00:15,  1.09it/s][A
 24%|██▍       | 5/21 [00:04<00:16,  1.00s/it][A
 29%|██▊       | 6/21 [00:05<00:15,  1.04s/it][A
 33%|███▎      | 7/21 [00:07<00:15,  1.11s/it][A
 38%|███▊      | 8/21 [00:08<00:15,  1.21s/it][A
 43%|████▎     | 9/21 [00:09<00:14,  1.24s/it][A
 48%|████▊     | 10/21 [00:11<00:13,  1.23s/it][A
 52%|█████▏    | 11/21 [00:12<00:12,  1.22s/it][A
 57%|█████▋    | 12/21 [00:13<00:10,  1.20s/it][A
 62%|██████▏   | 13/21 [00:14<00:09,  1.18s/it][A
 67%|██████▋   | 14/21 [00:15<00:08,  1.17s/it][A
 71%|███████▏  | 15/21 [00:16<00:06,  1.14s/it][A
 76%|███████▌  | 16/21 [00:17<00:05,  1.03s/it][A
 81%|████████  | 17/21 [00:18<00:03,  1.05it/s][A
 86%|████████▌ | 18/21 [00:19<00:02,  1.12it/s][A
 90%|████

In [None]:
# Summary statistics of the per-file errors.
# NOTE: the errors are in MIDI note numbers (semitones), NOT Hz. Both the .pv
# labels (y) and the estimates (y_hat = librosa.hz_to_midi(...)) are MIDI
# pitch values. The original goal of a mean absolute error below 1 Hz (a
# perceptibility threshold, see
# http://hyperphysics.phy-astr.gsu.edu/hbase/Sound/earsens.html) must be
# converted to semitones before it is compared with these numbers.
df = pd.DataFrame.from_dict(errors, orient='index')

# Store this experiment's results under a millisecond timestamp.
# The frame is written with a default index; each row is still identified by
# its 'filename' column.
current_milli_time = int(round(time.time() * 1000))
if not os.path.isdir('analysis'):
    os.makedirs('analysis')
path = 'analysis/' + str(current_milli_time) + '.feather'
feather.write_dataframe(df.reset_index(drop=True), path)

# last expression, so the summary table is actually displayed
df.describe()

In [None]:
# Visualize the errors between the estimated and ground-truth pitch for the
# most recently analysed sample. y, y_hat, y_hat_freq, squared_error and
# y_hat_freq_no_nan are left over from the last iteration of the analysis loop.
# The x axis is the frame index: 256-sample frames at 8 kHz, i.e. 32 ms each.
# The axes are built from the data lengths rather than assuming 250 frames,
# which only holds for 64000-sample recordings.
frames_y = np.arange(len(y))
frames_y_hat = np.arange(len(y_hat))
fig, (ax1, ax2) = plt.subplots(2, figsize=(20, 20))

# plot actual and estimated pitch
ax1.tick_params(axis='both', which='major')
ax1.plot(frames_y, y, 'r--')
ax1.plot(frames_y_hat, y_hat, 'bs')
ax1.set_title('Actual and Estimated Values')
ax1.set(xlabel='Frame', ylabel='MIDI Pitch')
ax1.legend(["actual", "estimated"], loc=2, bbox_to_anchor=(1.05, 1), frameon=True)

# plot squared errors (squared MIDI semitones)
ax2.plot(squared_error)
ax2.set_title('Squared Error')
ax2.set(xlabel='Frame', ylabel='Squared Error (semitones^2, log scale)')
ax2.set_yscale('log')

# tighten up and show
plt.tight_layout()
plt.show()

# Visualize the estimated frequencies, and the frequency histogram
fig, (ax1, ax2) = plt.subplots(2, figsize=(20, 20))

# plot estimated frequencies (one per frame, NaN where no estimate)
ax1.tick_params(axis='both', which='major')
ax1.plot(frames_y_hat, y_hat_freq, 'r--')
ax1.set_title('Estimated frequency')
ax1.set(xlabel='Frame', ylabel='Frequency (Hz)')
ax1.legend(["estimated"], loc=2, bbox_to_anchor=(1.05, 1), frameon=True)

# histogram of the non-NaN frequency estimates
ax2.set_title('Frequency distribution')
sns.distplot(y_hat_freq_no_nan, ax=ax2)
ax2.set(xlabel='Frequency (Hz)')

# tighten up and show
plt.tight_layout()
plt.show()

In [None]:
# Re-load the saved error table to check the feather round trip.
df = feather.read_dataframe(path)
assert len(df) == len(errors), "feather round trip lost rows"

# Visualize the full experimental set statistics
fig, (ax1, ax2) = plt.subplots(2, figsize=(20, 20))

# per-sample mean absolute error; units are MIDI semitones, not Hz
ax1.tick_params(axis='both', which='major')
df['mae'].plot(ax=ax1)
ax1.set_title('Mean absolute error for each sample')
ax1.set(xlabel='Sample', ylabel='Mean Absolute Error (semitones)')

# distribution of the non-NaN frequency estimates pooled over all samples
ax2.set_title('Frequency distribution')
sns.distplot(y_hat_freq_no_nan_all, ax=ax2)
ax2.set(xlabel='Frequency (Hz)')

# tighten up and show
plt.tight_layout()
plt.show()