In [7]:
import utils as utils
import importlib
import numpy as np
import data
import os
import pickle
import json
import torch as th
import torch.nn as nn
import torch.nn.functional as F

from scipy.signal import find_peaks

importlib.reload(utils)
importlib.reload(data)

train_dataset_path = './data/train'
test_dataset_path = './data/test'

device = 'cuda' if th.cuda.is_available() else 'cpu'

In [8]:
# Model descrbied in the paper plus droput
class OnsetDetectionCNN(nn.Module):
    def __init__(self):
        super(OnsetDetectionCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 10, kernel_size=(3, 7))
        self.pool1 = nn.MaxPool2d(kernel_size=(3, 1))
        self.conv2 = nn.Conv2d(10, 20, kernel_size=(3, 3))
        self.pool2 = nn.MaxPool2d(kernel_size=(3, 1))
        self.fc1 = nn.Linear(20 * 7 * 8, 256)
        self.fc2 = nn.Linear(256, 1)
        self.dropout = nn.Dropout(0.5)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(-1, 20 * 7 * 8)
        x = self.dropout(F.relu(self.fc1(x)))  # Apply dropout
        x = self.sigmoid(self.fc2(x))
        return x

# Initialize the model
model = OnsetDetectionCNN()

In [9]:
model.load_state_dict(th.load('best_model.pth'))
with open('mean_std.pkl', 'rb') as file:
    data = pickle.load(file)
mean = data['mean']
std = data['std']

# This is the almost the same as the prediction function in the onset detection but here we just use the onset signal 
# and no prediction
def raw_onset_signal(model, x, mean=mean, std=std, frame_size=15):
    model = model.to(device)
    model.eval()
    x = x.to(device)
    mean = mean.to(device)
    std = std.to(device)
    x = (x - mean) / std

    half_frame_size = frame_size // 2
    num_frames = x.shape[2]
    onset_predictions = []

    with th.no_grad():
        for j in range(half_frame_size, num_frames - half_frame_size):
            start = j - half_frame_size
            end = j + half_frame_size + 1
            input_frame = x[:, :, start:end].unsqueeze(0).float()
            output = model(input_frame).squeeze().cpu().item()
            onset_predictions.append(output)
    onset_predictions = np.array(onset_predictions)
    onset_signal = np.convolve(onset_predictions, np.hamming(10), mode='same')
    return onset_signal

def autocorrelate(signal, lag):
    r = np.zeros(len(signal) - lag)
    for t in range(len(signal) - lag):
        r[t] = signal[t + lag] * signal[t]
    return np.sum(r)

def to_bpm(max_r):
    return 60 * utils.SAMPLING_RATE / utils.HOP_LENGTH / (max_r + 25)

def autocorrelate_tao(signal, min_tao=25, max_tao=87):
    return np.array([autocorrelate(signal, tao) for tao in range(min_tao, max_tao)])

def get_tempo(model, x, top_n=2):
    onset_signal_res = raw_onset_signal(model, x)
    taos = autocorrelate_tao(onset_signal_res)
    peaks = find_peaks(taos)[0]
    highest_peaks = np.argsort(-taos[peaks])[:top_n]

    return list(reversed([to_bpm(r) for r in peaks[highest_peaks]]))

In [10]:
# Load the paths and then split them into train and test set (validation set in our case for now).
wav_files_paths_train, beat_files_paths_train, _, _ = utils.load_dataset_paths(train_dataset_path, is_train_dataset=True)

In [11]:
class BeatAgent:
    def __init__(self, id, start_time, tempo_hypothesis, initial_tempo, inner_window, outer_window, parent_agent=None):
        self.id = str(id)
        self.start_time = start_time
        self.initial_tempo = initial_tempo
        self.tempo_hypothesis = tempo_hypothesis
        self.beat_interval = 60.0 / tempo_hypothesis
        self.next_prediction = start_time + self.beat_interval
        self.accepted_events = [start_time]
        self.inner_window = inner_window
        self.outer_window = outer_window
        self.moved = False
        self.score = 0
    
        # this means it is a sub agebt that is created from an outer_window prediction
        if parent_agent is not None:
            self.score = parent_agent.score
            self.moved = True
            self.next_prediction = parent_agent.next_prediction + self.beat_interval
            self.accepted_events = parent_agent.accepted_events.copy()
            self.accepted_events.append(parent_agent.next_prediction)
    
    def process_event(self, event):
        event_diff = abs(event - self.next_prediction)
       
        if event == self.next_prediction: # the event predicted
            self.accepted_events.append(event)
            self.score += self.outer_window
            self.next_prediction = event + self.beat_interval
            self.moved = True
            return None
        elif event_diff <= self.inner_window: # the event is inside the inner window
            self.accepted_events.append(event)
            self.__update_tempo_hypothesis(event)
            self.score += (self.outer_window - event_diff)
            self.next_prediction = event + self.beat_interval
            self.moved = True
            return None
        elif event_diff <= self.outer_window: # the event is inside the outer window
            new_agent = BeatAgent(self.id + "a", self.start_time, self. tempo_hypothesis, 
                                  self.initial_tempo, self.inner_window, self.outer_window, self)

            self.accepted_events.append(event)
            self.__update_tempo_hypothesis(event)
            self.score += (self.outer_window - event_diff)
            self.next_prediction = event + self.beat_interval
            self.moved = True
            return new_agent # return a new agent
        elif event > self.next_prediction + self.outer_window:
            # interpolated beat, grants no score
            self.accepted_events.append(event)
            self.next_prediction = event + self.beat_interval
            self.moved = True
            return None
            
    def __update_tempo_hypothesis(self, event):
        self.tempo_hypothesis = self.tempo_hypothesis + (event - self.next_prediction)
        self.beat_interval = 60.0 / self.tempo_hypothesis
    
    def __str__(self):
        return "Start time: " + str(self.start_time) + ", Tempo:" + str(self.initial_tempo)

def prune_similiar_agents(agents): # 
    agents_tempo_map = {}
    agents_to_keep = []

    for agent in agents:
        th = round(agent.tempo_hypothesis, 3)
        if not agent.moved:
            agents_to_keep.append(agent)
        elif th not in agents_tempo_map:
            agents_tempo_map[th] = agent
        elif agent.score > agents_tempo_map[th].score:
                agents_tempo_map[th] = agent

    for th in agents_tempo_map:
        agents_to_keep.append(agents_tempo_map[th])
    
    return agents_to_keep

def multiple_agent_beat_tracking(onsets, tempo_estimations):
    id = 0
    agents = []
    for tempo in tempo_estimations:
        for idx in range(0, min(10, len(onsets))):
            agents.append(BeatAgent(id, onsets[idx], tempo, tempo, 0.05, 0.1))
            id += 1
    
    for event_time in onsets:
        next_agents = []
        for agent in agents:
            if agent.start_time < event_time:
                new_agent = agent.process_event(event_time)
                if new_agent is not None:
                    next_agents.append(new_agent)
        
        
        
        agents += next_agents
        agents = prune_similiar_agents(agents)

    best_agent: BeatAgent = max(agents, key=lambda agent: agent.score)

    return np.array(best_agent.accepted_events)

# wav_file = "data/train/Media-105810(5.0-15.0).wav"
# # wav_file = "data/test/test48.wav"
# ft_, sr_ = utils.preprocess_audio([wav_file])
# ft_ = ft_[0]
# sr_ = sr_[0]
# onset_ = raw_onset_signal(model, ft_, mean, std)

# import scipy
# peaks, _ = scipy.signal.find_peaks(onset_, height=np.max(onset_) * 0.5)

# onset_times = librosa.frames_to_time(peaks, sr=utils.SAMPLING_RATE, hop_length=utils.HOP_LENGTH)

###### used the ground truth onsets in order to not blame outside factors
# f=open("data/train/Media-105810(5.0-15.0).onsets.gt", "r")
# f=open("data/train/train20.onsets.gt", "r")
# lines=f.readlines()
# true_onsets=[]
# for x in lines:
#     true_onsets.append(float(x.split('\t')[0]))
# f.close()
# ######

# tempo_ = get_tempo(model, ft_)
# beats_ = multiple_agent_beat_tracking(true_onsets, tempo_)
# beats_

In [13]:
wav_files_paths_test = []
for filename in os.listdir(test_dataset_path):
    f = os.path.join(test_dataset_path, filename)
    wav_files_paths_test.append(f)
features_test, sample_rates_test = utils.preprocess_audio(wav_files_paths_test)

100%|██████████| 50/50 [00:01<00:00, 33.61it/s]


In [14]:
beatsT = {}
tempos = {}
for _, (ft, wav_file) in enumerate(zip(features_test, wav_files_paths_test)):
    # print(wav_file)
    filename = wav_file.split('/')[-1].replace('.wav', '')

    f=open("onsets_data.json", "r")
    j = json.load(f)
    true_onsets = j[filename]["onsets"]
    f.close()

    temp = get_tempo(model, ft)
    beatsT[filename] = multiple_agent_beat_tracking(true_onsets, temp)

In [15]:
pred = {}

for filename in wav_files_paths_test:
    filename = filename.split('/')[-1].replace('.wav', '')
    pred[filename] = {'beats': beatsT[filename].tolist()}
pred

{'test48': {'beats': [0.603718820861678,
   1.195827664399093,
   1.822766439909297,
   2.391655328798186,
   2.995374149659864,
   3.5874829931972787,
   4.20281179138322,
   4.794920634920635,
   5.3986394557823125,
   5.990748299319728,
   6.594467120181406,
   7.186575963718821,
   7.7090249433106575,
   8.289523809523809,
   9.775600907029478,
   10.402539682539683,
   10.994648526077098,
   11.598367346938776,
   12.202086167800454,
   12.805804988662132,
   13.374693877551021,
   13.978412698412699,
   14.570521541950113,
   15.14929284766366,
   15.777959183673469,
   16.39328798185941,
   16.973786848072564,
   17.46140589569161,
   18.07673469387755,
   19.06358276643991,
   19.957551020408165,
   20.59609977324263,
   21.199818594104308,
   21.803537414965987,
   22.3956462585034,
   23.010975056689343,
   23.56825396825397,
   24.18358276643991,
   24.775691609977326,
   25.402630385487527,
   25.994739229024944,
   26.610068027210886,
   27.213786848072562,
   27.790563537

In [11]:
import json
file_path = 'beats.json'

# Open the file in write mode and save the dictionary
with open(file_path, 'w') as f:
    json.dump(pred, f, indent=4)