In [12]:
import utils as utils
import importlib
import numpy as np
import data
import os
import pickle
import json
import mir_eval
import torch as th
import torch.nn as nn
import torch.nn.functional as F

from scipy.signal import find_peaks

importlib.reload(utils)
importlib.reload(data)

train_dataset_path = './data/train'
test_dataset_path = './data/test'

device = 'cuda' if th.cuda.is_available() else 'cpu'

In [2]:
# Model descrbied in the paper plus droput
class OnsetDetectionCNN(nn.Module):
    def __init__(self):
        super(OnsetDetectionCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 10, kernel_size=(3, 7))
        self.pool1 = nn.MaxPool2d(kernel_size=(3, 1))
        self.conv2 = nn.Conv2d(10, 20, kernel_size=(3, 3))
        self.pool2 = nn.MaxPool2d(kernel_size=(3, 1))
        self.fc1 = nn.Linear(20 * 7 * 8, 256)
        self.fc2 = nn.Linear(256, 1)
        self.dropout = nn.Dropout(0.5)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(-1, 20 * 7 * 8)
        x = self.dropout(F.relu(self.fc1(x)))  # Apply dropout
        x = self.sigmoid(self.fc2(x))
        return x

# Initialize the model
model = OnsetDetectionCNN()

In [3]:
model.load_state_dict(th.load('best_model.pth'))
with open('mean_std.pkl', 'rb') as file:
    data = pickle.load(file)
mean = data['mean']
std = data['std']

# This is the almost the same as the prediction function in the onset detection but here we just use the onset signal 
# and no prediction
def raw_onset_signal(model, x, mean=mean, std=std, frame_size=15):
    model = model.to(device)
    model.eval()
    x = x.to(device)
    mean = mean.to(device)
    std = std.to(device)
    x = (x - mean) / std

    half_frame_size = frame_size // 2
    num_frames = x.shape[2]
    onset_predictions = []

    with th.no_grad():
        for j in range(half_frame_size, num_frames - half_frame_size):
            start = j - half_frame_size
            end = j + half_frame_size + 1
            input_frame = x[:, :, start:end].unsqueeze(0).float()
            output = model(input_frame).squeeze().cpu().item()
            onset_predictions.append(output)
    onset_predictions = np.array(onset_predictions)
    onset_signal = np.convolve(onset_predictions, np.hamming(10), mode='same')
    return onset_signal

def autocorrelate(signal, lag):
    r = np.zeros(len(signal) - lag)
    for t in range(len(signal) - lag):
        r[t] = signal[t + lag] * signal[t]
    return np.sum(r)

def to_bpm(max_r):
    return 60 * utils.SAMPLING_RATE / utils.HOP_LENGTH / (max_r + 25)

def autocorrelate_tao(signal, min_tao=25, max_tao=87):
    return np.array([autocorrelate(signal, tao) for tao in range(min_tao, max_tao)])

def get_tempo(model, x, top_n=2):
    onset_signal_res = raw_onset_signal(model, x)
    taos = autocorrelate_tao(onset_signal_res)
    peaks = find_peaks(taos)[0]
    highest_peaks = np.argsort(-taos[peaks])[:top_n]

    return list(reversed([to_bpm(r) for r in peaks[highest_peaks]]))

In [4]:
# Load the paths and then split them into train and test set (validation set in our case for now).
wav_files_paths_train, beat_files_paths_train, _, _ = utils.load_dataset_paths(train_dataset_path, is_train_dataset=True)

In [5]:
class BeatAgent:
    def __init__(self, id, start_time, tempo_hypothesis, initial_tempo, inner_window, outer_window, parent_agent=None):
        self.id = str(id)
        self.start_time = start_time
        self.initial_tempo = initial_tempo
        self.tempo_hypothesis = tempo_hypothesis
        self.beat_interval = 60.0 / tempo_hypothesis
        self.next_prediction = start_time + self.beat_interval
        self.accepted_events = [start_time]
        self.inner_window = inner_window
        self.outer_window = outer_window
        self.moved = False
        self.score = 0
    
        # this means it is a sub agebt that is created from an outer_window prediction
        if parent_agent is not None:
            self.score = parent_agent.score
            self.moved = True
            self.next_prediction = parent_agent.next_prediction + self.beat_interval
            self.accepted_events = parent_agent.accepted_events.copy()
            self.accepted_events.append(parent_agent.next_prediction)
    
    def process_event(self, event):
        event_diff = abs(event - self.next_prediction)
       
        if event == self.next_prediction: # the event predicted
            self.accepted_events.append(event)
            self.score += self.outer_window
            self.next_prediction = event + self.beat_interval
            self.moved = True
            return None
        elif event_diff <= self.inner_window: # the event is inside the inner window
            self.accepted_events.append(event)
            self.__update_tempo_hypothesis(event)
            self.score += (self.outer_window - event_diff)
            self.next_prediction = event + self.beat_interval
            self.moved = True
            return None
        elif event_diff <= self.outer_window: # the event is inside the outer window
            new_agent = BeatAgent(self.id + "a", self.start_time, self. tempo_hypothesis, 
                                  self.initial_tempo, self.inner_window, self.outer_window, self)

            self.accepted_events.append(event)
            self.__update_tempo_hypothesis(event)
            self.score += (self.outer_window - event_diff)
            self.next_prediction = event + self.beat_interval
            self.moved = True
            return new_agent # return a new agent
        elif event > self.next_prediction + self.outer_window:
            # interpolated beat, grants no score
            self.accepted_events.append(event)
            self.next_prediction = event + self.beat_interval
            self.moved = True
            return None
            
    def __update_tempo_hypothesis(self, event):
        self.tempo_hypothesis = self.tempo_hypothesis + (event - self.next_prediction)
        self.beat_interval = 60.0 / self.tempo_hypothesis
    
    def __str__(self):
        return "Start time: " + str(self.start_time) + ", Tempo:" + str(self.initial_tempo)

def prune_similiar_agents(agents): # 
    agents_tempo_map = {}
    agents_to_keep = []

    for agent in agents:
        th = round(agent.tempo_hypothesis, 3)
        if not agent.moved:
            agents_to_keep.append(agent)
        elif th not in agents_tempo_map:
            agents_tempo_map[th] = agent
        elif agent.score > agents_tempo_map[th].score:
                agents_tempo_map[th] = agent

    for th in agents_tempo_map:
        agents_to_keep.append(agents_tempo_map[th])
    
    return agents_to_keep

def multiple_agent_beat_tracking(onsets, tempo_estimations):
    id = 0
    agents = []
    for tempo in tempo_estimations:
        for idx in range(0, min(10, len(onsets))):
            agents.append(BeatAgent(id, onsets[idx], tempo, tempo, 0.05, 0.1))
            id += 1
    
    for event_time in onsets:
        next_agents = []
        for agent in agents:
            if agent.start_time < event_time:
                new_agent = agent.process_event(event_time)
                if new_agent is not None:
                    next_agents.append(new_agent)
        
        
        
        agents += next_agents
        agents = prune_similiar_agents(agents)

    best_agent: BeatAgent = max(agents, key=lambda agent: agent.score)

    return np.array(best_agent.accepted_events)

# wav_file = "data/train/Media-105810(5.0-15.0).wav"
# # wav_file = "data/test/test48.wav"
# ft_, sr_ = utils.preprocess_audio([wav_file])
# ft_ = ft_[0]
# sr_ = sr_[0]
# onset_ = raw_onset_signal(model, ft_, mean, std)

# import scipy
# peaks, _ = scipy.signal.find_peaks(onset_, height=np.max(onset_) * 0.5)

# onset_times = librosa.frames_to_time(peaks, sr=utils.SAMPLING_RATE, hop_length=utils.HOP_LENGTH)

###### used the ground truth onsets in order to not blame outside factors
# f=open("data/train/Media-105810(5.0-15.0).onsets.gt", "r")
# f=open("data/train/train20.onsets.gt", "r")
# lines=f.readlines()
# true_onsets=[]
# for x in lines:
#     true_onsets.append(float(x.split('\t')[0]))
# f.close()
# ######

# tempo_ = get_tempo(model, ft_)
# beats_ = multiple_agent_beat_tracking(true_onsets, tempo_)
# beats_

In [6]:
features_train, sample_rates_train = utils.preprocess_audio(wav_files_paths_train)

100%|██████████| 127/127 [00:04<00:00, 26.10it/s]


In [7]:
beatsT = {}
for _, (ft, wav_file) in enumerate(zip(features_train, wav_files_paths_train)):
    # print(wav_file)
    filename = wav_file.split('/')[-1].replace('.wav', '')
    onset_file_name = wav_file.replace(".wav", ".onsets.gt")

    true_onsets = []

    with open(onset_file_name) as topo_file:
        for line in topo_file:
            true_onsets.append(float(line))

    temp = get_tempo(model, ft)
    beatsT[filename] = multiple_agent_beat_tracking(true_onsets, temp)

In [8]:
pred = {}

for filename in wav_files_paths_train:
    filename = filename.split('/')[-1].replace('.wav', '')
    pred[filename] = {'beats': beatsT[filename].tolist()}
pred

{'Media-105810(5.0-15.0)': {'beats': [0.208979591,
   0.544217687,
   0.891065759,
   1.21324263,
   1.573151927,
   1.922448979,
   2.265306122,
   2.763265306,
   3.004081632,
   3.302040816,
   3.663265306,
   4.006122448,
   4.334875283,
   4.680272108,
   5.002040816,
   5.371065759,
   5.702040816,
   6.069387755,
   6.402902494,
   6.640816326,
   6.905034013,
   7.203061224,
   7.462244897,
   7.804081632,
   8.112471655,
   8.480612244,
   8.822131519,
   9.159401278913123,
   9.519387755,
   9.865306122]},
 'ff123_DaFunk': {'beats': [0.11,
   0.63,
   1.17,
   1.716825396,
   2.25,
   2.79,
   3.33,
   3.88,
   4.407437641,
   4.948752834,
   5.49,
   6.04,
   6.57,
   7.11,
   7.65,
   8.189387755,
   8.724897959,
   9.263310657,
   9.8,
   10.351746031,
   10.89,
   11.424217687,
   11.966984126,
   12.51,
   13.040907029,
   13.583673469,
   14.126439909,
   14.666303854,
   15.194557823,
   15.74,
   16.274285714,
   16.822857142,
   17.354013605,
   17.895328798,
   18.4

In [10]:
target = {}

for filename in wav_files_paths_train:
    wav_f = filename.split('/')[-1].replace('.wav', '')
    beat_filename = filename.replace(".wav", ".beats.gt")
    f=open(beat_filename, "r")
    lines=f.readlines()
    result=[]
    for x in lines:
        result.append(float(x.split('\t')[0]))
    f.close()
    target[wav_f] = {'beats': result}

In [13]:
def evaluate_loop(submission, target):
    sum_f = 0.
    for target_key, target_value in target.items():
        if target_key in submission:
            reference = target_value['beats']
            estimated = submission[target_key]['beats']
            f = mir_eval.beat.f_measure(
                np.array(reference),
                np.array(estimated),
                f_measure_threshold=0.07  # 70 [ms]
            )
        else:
            f = 0.

        sum_f += f
    return sum_f / len(target)


print(evaluate_loop(pred, target))

0.48642307012113245
