In [None]:
import numpy as np
import aubio
import matplotlib.pyplot as plt
import pandas as pd
import glob

from IPython import display
from pathlib import PurePath

In [None]:
DATASET_PREFIX = '../data/beatboxset1/'

class Track:
    """
    In-memory mono audio track decoded through aubio.

    Only mono tracks are supported currently.

    Attributes:
        n_samples: total number of samples, taken from ``source.duration``.
        rate: sample rate, taken from ``source.samplerate``.
        duration: track length in seconds (``n_samples / rate``).
        hop_size: number of samples per read, taken from ``source.hop_size``.
        wave: ``aubio.fvec`` of length ``n_samples`` holding every sample.
    """
    def __init__(self, source):
        """
        Read ``source`` completely into ``self.wave``.

        Args:
            source: either a path (``str``), which is opened with
                ``aubio.source``, or an already opened source-like object
                exposing ``channels``, ``duration``, ``samplerate``,
                ``hop_size``, ``uri`` and iteration over sample frames.

        Raises:
            Exception: if the source has more than one channel.
        """
        # Only a source opened here is ours to close; a source handed in by
        # the caller stays open so the caller keeps control of its lifetime.
        owns_source = isinstance(source, str)
        if owns_source:
            source = aubio.source(source)

        try:
            if source.channels > 1:
                raise Exception('File {} has {} channels instead of 1'.format(source.uri, source.channels))

            self.n_samples = source.duration
            self.rate = source.samplerate
            self.duration = self.n_samples / self.rate
            self.hop_size = source.hop_size

            # Frames arrive hop by hop; copy each one to its position in the
            # preallocated buffer. The frame length is read from the frame
            # itself rather than assumed to equal hop_size.
            self.wave = aubio.fvec(self.n_samples)
            write_pos = 0
            for frame in source:
                frame_len = frame.shape[0]
                self.wave[write_pos:write_pos + frame_len] = frame
                write_pos += frame_len
        finally:
            # Release the file handle even when the mono check above raises.
            if owns_source:
                source.close()

    def segment(self, start, duration):
        """
        Return the samples of the time window ``[start, start + duration)``.

        Args:
            start: window start, in seconds.
            duration: window length, in seconds.

        Returns:
            A slice of ``self.wave``; both bounds are converted to sample
            indices by multiplying with ``self.rate`` and truncating to int.
        """
        return self.wave[int(start*self.rate):int((start+duration)*self.rate)]
    
def read_annotation(path):
    """
    Load an onset annotation CSV into a DataFrame.

    The column names ``time`` and ``class`` are supplied explicitly to
    ``pd.read_csv`` instead of being taken from the file.

    Args:
        path: anything ``pd.read_csv`` accepts as its first argument.

    Returns:
        The parsed DataFrame.
    """
    column_names = ['time', 'class']
    annotation = pd.read_csv(path, names=column_names)
    return annotation
    
def load_track_with_onsets(path):
    """
    Load one dataset track together with both of its onset annotations.

    Args:
        path: track name relative to ``DATASET_PREFIX``, without extension.

    Returns:
        Tuple ``(track, onsets_DR, onsets_HT)``: the ``Track`` built from
        ``<path>.wav`` and the annotation frames read from the
        ``Annotations_DR/`` and ``Annotations_HT/`` subdirectories.
    """
    wav_file = DATASET_PREFIX + path + '.wav'
    annotation_file = path + '.csv'

    # Audio first, then the two annotation variants, in that order.
    track = Track(wav_file)
    onsets_DR = read_annotation(DATASET_PREFIX + 'Annotations_DR/' + annotation_file)
    onsets_HT = read_annotation(DATASET_PREFIX + 'Annotations_HT/' + annotation_file)
    return track, onsets_DR, onsets_HT
    
def detect_onsets(track, method):
    """
    Run an aubio onset detector over a whole track.

    Args:
        track: object exposing ``wave``, ``hop_size``, ``n_samples`` and
            ``rate`` (e.g. a ``Track``).
        method: onset detection method name passed to ``aubio.onset``,
            e.g. ``'hfc'`` or ``'complex'``.

    Returns:
        DataFrame with one row per detected onset and the columns
        ``time`` (``get_last_s()`` of the detector), ``frame``
        (``get_last()`` of the detector) and ``class`` (always ``'x'``).
    """
    hs = track.hop_size
    # Configure the detector from the track instead of relying on aubio's
    # defaults (hop 512, 44100 Hz): the detector only accepts chunks of its
    # own hop size, and get_last_s() converts with its own sample rate, so a
    # mismatch either fails or silently yields wrong onset times.
    # buf_size stays at the default 1024 unless the hop is too large for it.
    onset_detector = aubio.onset(
        method=method,
        buf_size=max(1024, 2 * hs),
        hop_size=hs,
        samplerate=track.rate,
    )

    # Only whole hops are analysed; trailing samples that do not fill a
    # complete hop are skipped.
    n_hops = track.n_samples // hs
    onsets_sec = []
    onsets = []
    for i in range(n_hops):
        chunk = track.wave[i*hs:(i+1)*hs]
        if onset_detector(chunk):
            onsets_sec.append(onset_detector.get_last_s())
            onsets.append(onset_detector.get_last())

    # Detections carry no class information, so every row gets the
    # placeholder class 'x'.
    return pd.DataFrame({
        'time': np.array(onsets_sec),
        'frame': np.array(onsets),
        'class': ['x'] * len(onsets_sec),
    })

def onsets_F1_score(pred, target, ms_threshold=10, prec_rec=False):
    """
    Score predicted onset times against reference onset times.

    Both sequences are walked in time order. A prediction that lies within
    ``ms_threshold`` milliseconds of the current reference onset counts as a
    true positive and consumes both; a prediction that is too early counts
    as a false positive; a reference onset that is passed without a match
    counts as a false negative.

    Args:
        pred: predicted onset times in seconds (list, array or Series).
        target: reference onset times in seconds (list, array or Series).
        ms_threshold: matching tolerance in milliseconds.
        prec_rec: when true, also return precision and recall.

    Returns:
        ``f1`` or, if ``prec_rec`` is true, the tuple
        ``(f1, precision, recall)``. A ratio whose denominator is zero
        (no predictions, no targets, or no matches) is reported as 0.0
        instead of raising ``ZeroDivisionError``.
    """
    thr = ms_threshold / 1000

    # Work on plain sorted arrays: the sweep below requires time order, and
    # positional indexing must not depend on a pandas index.
    pred = np.sort(np.asarray(pred, dtype=float).ravel())
    target = np.sort(np.asarray(target, dtype=float).ravel())

    i = j = 0
    tp = fp = fn = 0
    while i < len(pred) and j < len(target):
        if target[j] - thr <= pred[i] <= target[j] + thr:
            tp += 1
            i += 1
            j += 1
        elif pred[i] < target[j] - thr:
            # Prediction precedes the current reference by more than thr.
            fp += 1
            i += 1
        else:
            # Reference onset was passed without any prediction close enough.
            fn += 1
            j += 1

    # Whatever is left on either side after the sweep is unmatched.
    fp += len(pred) - i
    fn += len(target) - j

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.0

    if prec_rec:
        return f1, precision, recall
    return f1

# Use a larger default font for every figure drawn from here on.
plt.rcParams['font.size'] = 16

def plot_track(track, onsets=None, event_type=None, title=None, return_events=False):
    """
    Draw a track's waveform and optionally mark onsets on top of it.

    Args:
        track: object exposing ``duration``, ``n_samples`` and ``wave``.
        onsets: optional DataFrame with a ``time`` column (and a ``class``
            column when ``event_type`` is given); each selected row is drawn
            as a red vertical line.
        event_type: when not None, only rows whose ``class`` equals this
            value are marked.
        title: optional figure title.
        return_events: when true, return the list of marked onset times.

    Returns:
        The list of marked times if ``return_events`` is true, else None.
    """
    plt.figure(figsize=(20, 5))
    time_axis = np.linspace(0, track.duration, track.n_samples)
    plt.plot(time_axis, track.wave)

    events = []
    if onsets is not None:
        # Lazily pick the rows to mark; with no event_type every row passes.
        marked_rows = (
            row for _index, row in onsets.iterrows()
            if event_type is None or row['class'] == event_type
        )
        for row in marked_rows:
            onset_time = row['time']
            events.append(onset_time)
            plt.axvline(x=onset_time, color='r')

    plt.ylim((-1.5, 1.5))
    plt.xlabel('Time, seconds')
    plt.ylabel('Signal amplitude')
    if title is not None:
        plt.title(title)
    plt.show()

    return events if return_events else None

In [None]:
# Load the 'snare_hex' recording together with its two onset annotation
# sets (read from Annotations_DR/ and Annotations_HT/).
track_sh, onsets_DR, onsets_HT = load_track_with_onsets('snare_hex')

In [None]:
# Plot the waveform with only the DR onsets of class 'k' marked, and keep the
# times of those onsets for the cells below.
kick_events = plot_track(track_sh, onsets=onsets_DR, event_type='k', title='Ground truth (variant #1)', return_events=True)

In [None]:
# Cut a 2-second excerpt that starts at the third 'k' onset (index 2).
segm = track_sh.segment(kick_events[2], 2)

In [None]:
# Waveform of the excerpt, plotted against the sample index.
fig = plt.figure(figsize=(5,2))
plt.plot(np.arange(len(segm)), segm)

In [None]:
# Spectrogram of the excerpt. Fs is the track's real sample rate so that the
# axes are in seconds and Hz; Fs=2 is only matplotlib's placeholder default.
_ = plt.specgram(segm, Fs=track_sh.rate)

In [None]:
# Spectrogram of the whole track. Fs is the track's real sample rate so that
# the axes are in seconds and Hz; Fs=2 is only matplotlib's placeholder default.
fig = plt.figure(figsize=(17,4))
_ = plt.specgram(track_sh.wave, Fs=track_sh.rate)

In [None]:
# Same waveform, this time with every onset of the HT annotation marked.
plot_track(track_sh, onsets=onsets_HT, title='Ground truth (variant #2)')

In [None]:
# Detect onsets with aubio's 'hfc' method and overlay them on the waveform.
onsets_detected_hfc = detect_onsets(track_sh, method='hfc')
plot_track(track_sh, onsets=onsets_detected_hfc, title='Detected onsets (method=HFC)')

In [None]:
# Detect onsets with aubio's 'complex' method and overlay them on the waveform.
onsets_detected_cp = detect_onsets(track_sh, method='complex')
plot_track(track_sh, onsets=onsets_detected_cp, title='Detected onsets (method=Complex)')

In [None]:
# F1 of the HFC detections against the DR annotation with a 10 ms tolerance.
# Detections go first (pred) and the annotation second (target), both as plain
# arrays, matching every other call in this notebook.
onsets_F1_score(onsets_detected_hfc['time'].values, onsets_DR['time'].values, ms_threshold=10)

In [None]:
# Tracks left out of the evaluation.
# NOTE(review): the reason for excluding them is not recorded here - confirm.
excluded_tracks = {
    'putfile_dbztenkaichi',
    'callout_Pneumatic',
    'putfile_vonny',
    'putfile_pepouni',
}

# glob returns files in arbitrary order, so sort for a reproducible track list.
# Filtering (rather than list.remove) also works when an excluded file is
# missing from the local copy of the dataset.
bbs_files = sorted(
    stem
    for stem in (PurePath(path).stem for path in glob.glob(DATASET_PREFIX + '*.wav'))
    if stem not in excluded_tracks
)
bbs_files

In [None]:
# Evaluate both detectors against both annotation sets for every track.
# Column layout: 'track', then F1 / precision / recall for
# DR_HFC, DR_Complex, HT_HFC, HT_Complex (in that order).
score_columns = ['track'] + [
    '{}_{}_{}'.format(annotator, method, metric)
    for annotator in ('DR', 'HT')
    for method in ('HFC', 'Complex')
    for metric in ('F1', 'prec', 'rec')
]

# Rows are collected in a list and turned into a DataFrame once at the end;
# this keeps numeric dtypes and avoids growing the frame row by row.
score_rows = []
for trackname in bbs_files:
    track, onsets_DR, onsets_HT = load_track_with_onsets(trackname)
    # Same order as the column layout: HFC first, then Complex.
    predictions = [detect_onsets(track, method=method) for method in ('hfc', 'complex')]

    row = [trackname]
    for annotation in (onsets_DR, onsets_HT):
        for onsets_pred in predictions:
            # onsets_F1_score returns (f1, precision, recall) with prec_rec=True.
            row.extend(onsets_F1_score(
                onsets_pred['time'].values, annotation['time'].values, prec_rec=True))
    score_rows.append(row)

scores = pd.DataFrame(score_rows, columns=score_columns)

In [None]:
# Display the per-track score table.
scores

In [None]:
# Copy the track names and the four F1 columns to the system clipboard.
# NOTE(review): presumably requires a clipboard backend on the machine running
# the kernel and may fail on a headless server - confirm before Run All.
scores[['track', 'DR_HFC_F1', 'DR_Complex_F1', 'HT_HFC_F1', 'HT_Complex_F1']].to_clipboard()

In [None]:
# Mean of every score column. The string column 'track' is dropped first:
# recent pandas versions raise a TypeError when asked to average it.
# astype(float) makes this work even if the score columns are object-typed.
scores.drop(columns='track').astype(float).mean()

In [None]:
# Standard deviation of every score column. The string column 'track' is
# dropped first: recent pandas versions raise a TypeError when it is included.
# astype(float) makes this work even if the score columns are object-typed.
scores.drop(columns='track').astype(float).std()

In [None]:
# Grouped bar chart: mean precision and recall for each detector/annotation pair.
labels = ['HFC/DR', 'Complex/DR', 'HFC/HT', 'Complex/HT']
# Average numeric columns only: the string column 'track' cannot be averaged
# and makes scores.mean() raise on recent pandas versions.
mean_scores = scores.drop(columns='track').astype(float).mean()
# Column order below must match the order of `labels`.
precisions = mean_scores[['DR_HFC_prec', 'DR_Complex_prec', 'HT_HFC_prec', 'HT_Complex_prec']]
recalls = mean_scores[['DR_HFC_rec', 'DR_Complex_rec', 'HT_HFC_rec', 'HT_Complex_rec']]

x = np.arange(len(labels))
width = 0.35

fig, ax = plt.subplots(figsize=(10, 5))
# Two bars per label, shifted half a bar width to either side of the tick.
rects1 = ax.bar(x - width/2, precisions, width, label='Precision')
rects2 = ax.bar(x + width/2, recalls, width, label='Recall')

ax.set_ylabel('Scores')
ax.set_title('Precision and recall on beatboxset1')
ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.legend()

fig.tight_layout()
plt.show()

In [None]:
# Every .wav file two directory levels below the AVP dataset root.
# glob returns files in arbitrary order; sort so the evaluation rows come out
# in the same order on every run and every machine.
avp_files_all = sorted(PurePath(path) for path in glob.glob('../data/AVP_Dataset/*/*/*.wav'))
len(avp_files_all)

In [None]:
# Evaluate both detectors on every AVP recording. The annotation is read from
# the .csv file that sits next to each .wav file.
# Rows are collected in a list and turned into a DataFrame once at the end;
# this keeps numeric dtypes and avoids growing the frame row by row.
avp_rows = []
for filepath in avp_files_all:
    track = Track(str(filepath))
    annotation = read_annotation(str(filepath.with_suffix('.csv')))
    target_times = annotation['time'].values

    row = [filepath.stem]
    # Same order as the columns below: HFC first, then Complex.
    for method in ('hfc', 'complex'):
        onsets_pred = detect_onsets(track, method=method)
        # onsets_F1_score returns (f1, precision, recall) with prec_rec=True.
        row.extend(onsets_F1_score(onsets_pred['time'].values, target_times, prec_rec=True))
    avp_rows.append(row)

avp_scores = pd.DataFrame(avp_rows, columns=[
    'track',
    'HFC_F1', 'HFC_prec', 'HFC_rec',
    'Complex_F1', 'Complex_prec', 'Complex_rec'
])