In [129]:
import importlib
import json
import pickle
from pathlib import Path

import librosa
import mir_eval
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from scipy.signal import find_peaks

import data
import utils as utils

# Re-import the project modules so that edits made to them on disk are picked up
# without restarting the kernel.
importlib.reload(utils)
importlib.reload(data)

# Dataset root directories, relative to the notebook's working directory.
train_dataset_path = './data/train'
test_dataset_path = './data/test'

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if th.cuda.is_available() else 'cpu'

In [130]:
# Model described in the paper, plus dropout
class OnsetDetectionCNN(nn.Module):
    """Small CNN that maps one spectrogram excerpt to a single value in (0, 1).

    Architecture: two conv -> ReLU -> max-pool stages, then a 256-unit dense
    layer with dropout and a one-unit sigmoid output.

    The input must have 3 channels (first ``Conv2d``), and its spatial size must
    make the second pooling stage produce 20 * 7 * 8 values per sample, because
    that is the fixed input width of ``fc1``.
    # NOTE(review): callers in this notebook pass 15 time frames per excerpt;
    # the matching number of frequency bands is presumably 80 - TODO confirm.
    """

    # Number of values per sample handed from the conv stack to the dense layers.
    _FLAT_FEATURES = 20 * 7 * 8

    def __init__(self):
        super().__init__()
        # The attribute names below are the keys of the saved state dict,
        # so they must not be renamed.
        self.conv1 = nn.Conv2d(3, 10, kernel_size=(3, 7))
        self.pool1 = nn.MaxPool2d(kernel_size=(3, 1))
        self.conv2 = nn.Conv2d(10, 20, kernel_size=(3, 3))
        self.pool2 = nn.MaxPool2d(kernel_size=(3, 1))
        self.fc1 = nn.Linear(self._FLAT_FEATURES, 256)
        self.fc2 = nn.Linear(256, 1)
        self.dropout = nn.Dropout(0.5)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of sigmoid activations for ``x``."""
        features = self.conv1(x)
        features = self.pool1(F.relu(features))
        features = self.conv2(features)
        features = self.pool2(F.relu(features))
        # Collapse everything except the batch axis for the dense layers.
        flat = features.view(-1, self._FLAT_FEATURES)
        hidden = F.relu(self.fc1(flat))
        hidden = self.dropout(hidden)  # only active in training mode
        return self.sigmoid(self.fc2(hidden))

# Initialize the model; its weights keep their constructor-time initial values until a checkpoint is loaded
model = OnsetDetectionCNN()

In [138]:
# Restore the trained weights. `map_location` remaps tensors that were saved on a
# GPU, so the checkpoint also loads on a machine where `device` is 'cpu'.
model.load_state_dict(th.load('best_model.pth', map_location=device))

# Load the normalisation statistics that are applied to every input spectrogram.
# SECURITY: pickle.load can execute arbitrary code stored in the file - only open
# a 'mean_std.pkl' that you produced yourself.
with open('mean_std.pkl', 'rb') as file:
    # Not called `data`, so the imported `data` module stays reachable.
    norm_stats = pickle.load(file)
mean = norm_stats['mean']
std = norm_stats['std']

# This is almost the same as the prediction function used for onset detection, but here we only
# return the raw onset signal and do not turn it into onset predictions.
def raw_onset_signal(model, x, mean=mean, std=std, frame_size=15, batch_size=1024):
    """Slide the onset CNN over a spectrogram and return its smoothed activation.

    Every window of ``2 * (frame_size // 2) + 1`` consecutive frames along the
    last axis of ``x`` is normalised and fed to ``model``; the per-window outputs
    are then smoothed with a 10-point Hamming window.

    Windows are evaluated in batches of ``batch_size`` instead of one forward
    pass (and one device-to-host copy) per frame. Results match the frame-by-frame
    evaluation up to floating-point rounding.

    Args:
        model: network called on tensors of shape (batch, *x.shape[:2], window);
            it is moved to ``device`` and switched to eval mode.
        x: 3-D tensor whose last axis (axis 2) is the frame/time axis.
        mean, std: tensors that broadcast against ``x``; used as ``(x - mean) / std``.
        frame_size: nominal window width in frames (context is ``frame_size // 2``
            frames on each side of the centre frame).
        batch_size: number of windows per forward pass; must be positive.

    Returns:
        1-D float64 NumPy array. Before smoothing, entry ``k`` belongs to the
        window centred on input frame ``k + frame_size // 2`` (the signal is not
        padded back to the input length). An empty array is returned when ``x``
        has fewer frames than one window. Because ``np.convolve(mode='same')``
        returns ``max(len(a), len(v))`` samples, inputs yielding fewer than 10
        windows still produce 10 output samples.
    """
    model = model.to(device)
    model.eval()
    x = x.to(device)
    mean = mean.to(device)
    std = std.to(device)
    x = ((x - mean) / std).float()

    half_frame_size = frame_size // 2
    window_size = 2 * half_frame_size + 1
    num_frames = x.shape[2]

    # Too short for even one full window: nothing to predict (np.convolve would
    # otherwise raise on the empty prediction list).
    if num_frames < window_size:
        return np.zeros(0)

    # unfold gives a view of shape (C, F, num_windows, window_size); move the
    # window index to the front so it becomes the batch axis.
    windows = x.unfold(2, window_size, 1).permute(2, 0, 1, 3)

    chunks = []
    with th.no_grad():
        for start in range(0, windows.shape[0], batch_size):
            batch = windows[start:start + batch_size].contiguous()
            chunks.append(model(batch).reshape(-1).cpu())

    onset_predictions = th.cat(chunks).numpy().astype(np.float64)
    onset_signal = np.convolve(onset_predictions, np.hamming(10), mode='same')
    return onset_signal

def autocorrelate(signal, lag):
    """Return the unnormalised autocorrelation of a 1-D signal at a single lag.

    Computes ``sum(signal[t] * signal[t + lag])`` over all ``t`` for which both
    samples exist.

    Args:
        signal: 1-D sequence of numbers (list or NumPy array).
        lag: non-negative shift in samples.

    Returns:
        The sum as a float; ``0.0`` when ``lag >= len(signal)`` (no overlap).

    Raises:
        ValueError: if ``lag`` is negative.
    """
    if lag < 0:
        raise ValueError(f'lag must be non-negative, got {lag}')
    values = np.asarray(signal, dtype=float)
    overlap = len(values) - lag
    # No overlapping samples: the sum over an empty range is zero.
    if overlap <= 0:
        return 0.0
    # values[:overlap] rather than values[:-lag], which would be empty for lag == 0.
    return float(np.dot(values[lag:], values[:overlap]))

def to_bpm(max_r, min_tao=25):
    """Convert an index into an ``autocorrelate_tao`` result to beats per minute.

    Index ``max_r`` of that array corresponds to a lag of ``max_r + min_tao``
    frames, and one frame lasts ``utils.HOP_LENGTH / utils.SAMPLING_RATE`` seconds.

    Args:
        max_r: index of the autocorrelation peak within the lag array.
        min_tao: smallest lag the array was computed for; must equal the
            ``min_tao`` that was passed to ``autocorrelate_tao`` (default 25,
            the same default that function uses).

    Returns:
        Tempo in BPM whose beat period equals the peak's lag.
    """
    lag_in_frames = max_r + min_tao
    return 60 * utils.SAMPLING_RATE / utils.HOP_LENGTH / lag_in_frames

def autocorrelate_tao(signal, min_tao=25, max_tao=87):
    """Evaluate ``autocorrelate`` for every lag in ``range(min_tao, max_tao)``.

    Returns a 1-D float array in which index ``i`` holds the autocorrelation at
    lag ``min_tao + i``; ``max_tao`` itself is excluded.
    """
    lags = range(min_tao, max_tao)
    per_lag = (autocorrelate(signal, lag) for lag in lags)
    return np.fromiter(per_lag, dtype=float, count=len(lags))

def get_tempo(model, x, top_n=2, onset_signal=None):
    """Estimate up to ``top_n`` tempo candidates from the onset autocorrelation.

    The onset signal is autocorrelated over the default lag range of
    ``autocorrelate_tao``; the ``top_n`` highest local maxima are converted to
    BPM with ``to_bpm``.

    Args:
        model: onset network, forwarded to ``raw_onset_signal``.
        x: input spectrogram, forwarded to ``raw_onset_signal``.
        top_n: maximum number of candidates to return.
        onset_signal: optional precomputed result of ``raw_onset_signal`` for
            ``x``. When given, ``model`` and ``x`` are not used and the CNN
            inference is skipped.

    Returns:
        List of BPM values ordered from the weakest to the strongest selected
        peak (so ``result[-1]`` is the strongest). Empty when the
        autocorrelation has no local maximum.
    """
    if onset_signal is None:
        onset_signal = raw_onset_signal(model, x)
    taos = autocorrelate_tao(onset_signal)
    peaks = find_peaks(taos)[0]

    # Sort peaks by descending height and keep the best top_n.
    strongest_first = peaks[np.argsort(-taos[peaks])[:top_n]]
    return [to_bpm(r) for r in strongest_first[::-1]]

In [132]:
def estimate_beats(onset_signal, tempo_bpm):
    """Pick beat times from an onset signal given a tempo estimate.

    Local maxima of ``onset_signal`` that are at least half a beat period apart
    are scanned from left to right; a peak is accepted as a beat when it lies at
    or after the previous beat plus one full beat period.

    Args:
        onset_signal: 1-D onset activation, one value per frame.
        tempo_bpm: tempo in beats per minute; must be positive.

    Returns:
        1-D float array of beat times in seconds (frame index times
        ``utils.HOP_LENGTH / utils.SAMPLING_RATE``). Empty when the signal has
        no peak.

    Raises:
        ValueError: if ``tempo_bpm`` is not positive.

    NOTE(review): when ``onset_signal`` comes from ``raw_onset_signal``, index
    ``k`` belongs to input frame ``k + frame_size // 2``, so the returned times
    may be early by that many hops - confirm whether the offset should be added.
    """
    if tempo_bpm <= 0:
        raise ValueError(f'tempo_bpm must be positive, got {tempo_bpm}')

    tempo_period = 60 / tempo_bpm * utils.SAMPLING_RATE / utils.HOP_LENGTH  # in frames
    # find_peaks requires distance >= 1, which very fast tempi would violate.
    min_distance = max(1, tempo_period // 2)
    peak_indices = find_peaks(onset_signal, distance=min_distance)[0]

    # A flat or monotonic signal has no peaks, hence no beats.
    if len(peak_indices) == 0:
        return np.array([], dtype=float)

    # Greedy selection: accept a peak, then skip ahead one beat period.
    beat_positions = []
    next_allowed = peak_indices[0]
    for peak in peak_indices:
        if peak >= next_allowed:
            beat_positions.append(peak)
            next_allowed = peak + tempo_period

    return np.array(beat_positions) * utils.HOP_LENGTH / utils.SAMPLING_RATE  # Convert to seconds


In [77]:
def custom_beat_tracking(spec, sample_rate, tempo_estimations):
    """Lay out regular beat grids over a track, anchored at autocorrelation peaks.

    Uses the module-level ``model`` to compute the onset signal of ``spec``.
    For every local maximum of that signal's autocorrelation, a grid with the
    given tempo is generated from the peak's time up to the track duration.

    Args:
        spec: input spectrogram whose last axis is the frame axis.
        sample_rate: audio sample rate in Hz, used to convert frames to seconds.
        tempo_estimations: a single tempo in BPM, or a list/array of tempi whose
            mean is used.

    Returns:
        1-D array of beat times in seconds, sorted in increasing order (grids
        from different anchors are merged; coinciding times are not removed).

    Raises:
        ValueError: if the resulting tempo is not a positive finite number.
    """
    # Calculate the average tempo if multiple tempo estimations are provided
    if isinstance(tempo_estimations, (list, np.ndarray)):
        tempo = np.mean(tempo_estimations)
    else:
        tempo = tempo_estimations

    # Checked before the expensive CNN pass: a zero tempo divides by zero, and a
    # negative or infinite one makes the grid step <= 0, so the loop below
    # would never reach `duration`.
    if not np.isfinite(tempo) or tempo <= 0:
        raise ValueError(f'tempo must be a positive finite BPM value, got {tempo}')
    beat_interval = 60.0 / tempo

    num_frames = spec.shape[-1]
    duration = num_frames * utils.HOP_LENGTH / sample_rate

    onset_signal = raw_onset_signal(model, spec)

    taos = autocorrelate_tao(onset_signal)
    peaks = find_peaks(taos)[0]

    # Convert peak indices to time
    # NOTE(review): `peaks` are indices into the autocorrelation (lags relative
    # to autocorrelate_tao's min_tao), not frame positions in the track, yet they
    # are used as grid start times here - confirm this is intended.
    peak_times = librosa.frames_to_time(peaks, sr=sample_rate, hop_length=utils.HOP_LENGTH)

    # Align detected peaks with expected beat intervals
    beat_times = []
    for peak_time in peak_times:
        current_beat_time = peak_time
        while current_beat_time < duration:
            beat_times.append(current_beat_time)
            current_beat_time += beat_interval

    # Grids from different anchors interleave; beat evaluation tools such as
    # mir_eval expect increasing event times.
    return np.sort(np.array(beat_times))


In [134]:
# Load the dataset paths for the training set. Only the first return value (wav files) and the
# last one (tempo files, judging by its name) are kept; no train/test split happens here.
wav_files_paths_train, _, _, tempo_files_paths_train = utils.load_dataset_paths(train_dataset_path, is_train_dataset=True)

In [135]:
# Prepare train data: preprocess_audio returns the features and the sample rates;
# each entry of features_train is later passed to raw_onset_signal as one input.
features_train, sample_rates_train = utils.preprocess_audio(wav_files_paths_train)

# tempo_train = utils.load_tempo_annotations_from_files(y_train_paths)

100%|██████████| 127/127 [00:04<00:00, 29.19it/s]


In [76]:
# len(beat0)

648

In [136]:
# Smoothed CNN onset signal for every training item (one CNN pass per frame window).
onsets = [raw_onset_signal(model, x, mean, std) for x in features_train]

In [None]:
# Prepare test data (validation data in our case for now)
# features_test, sample_rates_test = utils.preprocess_audio(X_test_paths)
# tempo_test = utils.load_tempo_annotations_from_files(y_train_paths)

In [139]:
# Tempo candidates (BPM) for every training item. get_tempo computes the onset signal
# itself, so the CNN runs again here even though `onsets` may already hold the result.
tempos = [get_tempo(model, x) for x in features_train]

In [159]:
# Test-set pipeline: paths -> features -> onset signals -> tempo candidates -> beats.
wav_files_paths_test, _, _, _ = utils.load_dataset_paths(test_dataset_path, is_train_dataset=False)
features_test, sample_rates_test = utils.preprocess_audio(wav_files_paths_test)
onsets_test = [raw_onset_signal(model, x, mean, std) for x in features_test]
tempos_tests = [get_tempo(model, x) for x in features_test]

# Keep exactly one entry per audio file so that `beats` stays index-aligned with
# `wav_files_paths_test`. Skipping a file here would shift every later beat list
# onto the wrong filename when the two lists are zipped.
beats = []
for onset_signal, tempo_candidates in zip(onsets_test, tempos_tests):
    if len(tempo_candidates) != 0:
        # get_tempo orders candidates from weakest to strongest peak, so [-1]
        # is the strongest one.
        beats.append(estimate_beats(onset_signal, tempo_candidates[-1]))
    else:
        # No tempo found: record "no beats" for this file instead of dropping it.
        beats.append(np.array([]))


100%|██████████| 50/50 [00:01<00:00, 35.13it/s]


In [160]:
# Map each test file's base name (without directory and extension) to its beat times.
# zip() silently stops at the shorter input, so a length mismatch would attach
# beats to the wrong files; fail loudly instead.
if len(beats) != len(wav_files_paths_test):
    raise ValueError(
        f'{len(beats)} beat lists for {len(wav_files_paths_test)} audio files: '
        'predictions are not aligned with the file list'
    )

pred = {
    Path(wav_path).stem: {'beats': beat.tolist()}
    for beat, wav_path in zip(beats, wav_files_paths_test)
}


In [161]:
# Destination of the beat predictions (`json` is imported with the other
# modules at the top of the notebook).
file_path = 'beats.json'

# Open the file in write mode and save the dictionary
with open(file_path, 'w') as f:
    json.dump(pred, f, indent=4)

In [190]:
def _reference_tempi_and_weight(annotations):
    """Turn one target's tempo annotation into ``(reference_tempi, reference_weight)``."""
    if len(annotations) == 1:
        # Single annotated tempo: pair it with its half and give the half zero weight.
        tempo = annotations[0]
        return np.array([tempo / 2., tempo], dtype=float), 0
    if len(annotations) == 3:
        # Two tempi followed by the weight of the first one.
        return np.array(annotations[0:2], dtype=float), float(annotations[2])
    raise RuntimeError(f'tempo annotations are weird "{annotations}"')


def _estimated_tempi(estimations):
    """Turn one or two estimated tempi into the two-element array mir_eval is given."""
    if len(estimations) == 2:
        # All fine
        return np.array(estimations, dtype=float)
    if len(estimations) == 1:
        # If there's only one estimated tempo, prepend its half
        tempo = estimations[0]
        return np.array([tempo / 2., tempo], dtype=float)
    raise RuntimeError(f'tempo estimations are weird "{estimations}"')


def evaluate_tempo(predictions, targets):
    """Return the mean tempo P-score of ``predictions`` over all ``targets``.

    Args:
        predictions: dict mapping a key to ``{'tempo': [...]}`` with one or more
            estimated tempi; only the first two are used.
        targets: dict mapping a key to ``{'tempo': [...]}`` holding either one
            tempo, or two tempi followed by the weight of the first.

    Returns:
        Sum of ``mir_eval.tempo.detection`` P-scores (tolerance 0.08) divided by
        the number of targets. Targets without a prediction contribute 0.
        ``0.0`` when ``targets`` is empty.

    Raises:
        RuntimeError: if an annotation does not have 1 or 3 entries, or a
            prediction for an annotated key has no tempo at all.
    """
    # Nothing to evaluate: avoid dividing by zero below.
    if len(targets) == 0:
        return 0.

    sum_p_score = 0.
    for target_key, target_value in targets.items():
        if target_key not in predictions:
            # A missing prediction scores 0 and is not validated further.
            continue

        reference_tempi, reference_weight = _reference_tempi_and_weight(target_value['tempo'])
        # Ignore whatever comes after the first two estimated values
        estimated_tempi = _estimated_tempi(predictions[target_key]['tempo'][0:2])

        p_score, _, _ = mir_eval.tempo.detection(
            reference_tempi,
            reference_weight,
            estimated_tempi,
            tol=0.08
        )
        sum_p_score += p_score

    return sum_p_score / len(targets)