In [2]:
import utils as utils
import importlib
import numpy as np
import data
import pickle

import torch as th
import torch.nn as nn
import torch.nn.functional as F

from scipy.signal import find_peaks

# Re-execute the `utils` and `data` modules so that edits made to their source
# are picked up without restarting the kernel.
importlib.reload(utils)
importlib.reload(data)

# Dataset directories, given relative to the current working directory.
train_dataset_path = './data/onset/train'
test_dataset_path = './data/onset/test'

In [10]:
# Model described in the paper, plus dropout
class OnsetDetectionCNN(nn.Module):
    """Small two-stage CNN that emits one sigmoid score per input window.

    Pipeline: conv(3x7) -> ReLU -> max-pool(3x1) -> conv(3x3) -> ReLU ->
    max-pool(3x1) -> flatten -> dense(256) -> ReLU -> dropout(0.5) ->
    dense(1) -> sigmoid.

    The flatten step is hard-wired to 20 * 7 * 8 = 1120 values per sample, so
    the input is expected to shrink to exactly that size after the second
    pooling layer (a 3 x 80 x 15 window does).
    """

    # Number of values fed to the first dense layer for each sample.
    _FLAT_FEATURES = 20 * 7 * 8

    def __init__(self):
        super().__init__()
        # Attribute names are also the state_dict keys, so they must stay as-is
        # for saved checkpoints to load.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=(3, 7))
        self.pool1 = nn.MaxPool2d(kernel_size=(3, 1))
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=(3, 3))
        self.pool2 = nn.MaxPool2d(kernel_size=(3, 1))
        self.fc1 = nn.Linear(in_features=self._FLAT_FEATURES, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=1)
        self.dropout = nn.Dropout(p=0.5)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a ``(N, 1)`` tensor of scores in (0, 1) for a batch of windows."""
        # Convolutional feature extractor: two conv/ReLU/pool stages.
        stage_one = self.pool1(F.relu(self.conv1(x)))
        stage_two = self.pool2(F.relu(self.conv2(stage_one)))

        # Collapse channel and spatial axes into one feature vector per sample.
        flat = stage_two.view(-1, self._FLAT_FEATURES)

        # Dense head; dropout is only active while the module is in train mode.
        hidden = self.dropout(F.relu(self.fc1(flat)))
        return self.sigmoid(self.fc2(hidden))

# Instantiate the network. At this point it only holds the default
# initialisation produced by the layer constructors, not trained weights.
model = OnsetDetectionCNN()

In [6]:
# Collect the file paths found under the test dataset directory; the three
# other values returned by `load_dataset_paths` are discarded here.
# NOTE(review): no train/validation split happens in this cell -- only the
# test directory is read (is_train_dataset=False).
wav_files_paths_test, _, _, _ = utils.load_dataset_paths(test_dataset_path, is_train_dataset=False)
# Run the audio preprocessing over those paths to obtain the features and
# sample rates (presumably one entry per file -- the features are later zipped
# with the paths).
features_test, sample_rates_test = utils.preprocess_audio(wav_files_paths_test)

100%|██████████| 50/50 [00:01<00:00, 34.75it/s]


In [7]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if th.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

cpu


In [19]:
def manual_evaluate_test(model, feature, threshold, frame_size=15, mean=None, std=None):
    """Detect onsets in one pre-computed feature tensor with a sliding window.

    The feature is normalised with ``mean``/``std``; a window of
    ``2 * (frame_size // 2) + 1`` frames is then slid along axis 2 one frame at
    a time and the model's scalar output is recorded for every window. The
    resulting curve is smoothed with a 10-tap Hamming window, and each local
    maximum whose smoothed value is at least ``threshold`` is reported.

    Args:
        model: Network that maps one window (with a leading batch axis of 1)
            to a single scalar score.
        feature: Tensor whose third axis (``feature.shape[2]``) indexes frames.
        threshold: Minimum smoothed value a peak must reach. The smoothing
            kernel is not normalised, so smoothed values can exceed 1.
        frame_size: Nominal window width in frames; ``frame_size // 2`` frames
            are taken on each side of the centre frame.
        mean: Tensor subtracted from ``feature`` (must broadcast against it).
        std: Tensor the centred feature is divided by.

    Returns:
        list: Onset times in seconds, computed as
        ``peak_index * utils.HOP_LENGTH / utils.SAMPLING_RATE``. Empty when the
        feature is too short to hold a single window or no peak qualifies.

    Raises:
        ValueError: If ``mean`` or ``std`` is not supplied.
    """
    if mean is None or std is None:
        raise ValueError("Mean and std must be provided for normalization.")

    mean = mean.to(device)
    std = std.to(device)

    model = model.to(device)
    model.eval()

    half_frame_size = frame_size // 2

    with th.no_grad():
        num_frames = feature.shape[2]
        # Normalise and cast once up front instead of once per window.
        normalized = ((feature.to(device) - mean) / std).float()

        # One forward pass per centre frame that has a full window around it.
        predictions = []
        for center in range(half_frame_size, num_frames - half_frame_size):
            window = normalized[:, :, center - half_frame_size:center + half_frame_size + 1]
            # unsqueeze(0) adds the batch dimension expected by the model.
            predictions.append(model(window.unsqueeze(0)).squeeze().item())

    # A feature shorter than one window yields no predictions; np.convolve
    # raises ValueError on an empty input, so report "no onsets" instead.
    if not predictions:
        return []

    # Smooth the activation curve; the 10-tap Hamming window was chosen by
    # trial and error.
    # NOTE(review): np.convolve defaults to mode='full', so the smoothed curve
    # is 9 samples longer than `predictions` and its peaks lag by ~4.5 frames,
    # while prediction 0 belongs to frame `half_frame_size`. Neither offset is
    # applied to the reported times. Left unchanged because the threshold and
    # window were tuned against this behaviour -- confirm against the label
    # alignment before correcting.
    smoothed = np.convolve(predictions, np.hamming(10))

    peak_indices = find_peaks(smoothed)[0]
    return [
        idx * utils.HOP_LENGTH / utils.SAMPLING_RATE
        for idx in peak_indices
        if smoothed[idx] >= threshold
    ]

In [20]:
# Restore the trained weights. `map_location=device` remaps the stored tensors
# onto the device selected earlier, so a checkpoint that was saved on a GPU
# still loads on a CPU-only machine.
# NOTE(review): th.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(th.load('./best_model.pth', map_location=device))
# Switch to inference mode (this disables dropout). As the last expression of
# the cell, it also displays the module structure.
model.eval()

OnsetDetectionCNN(
  (conv1): Conv2d(3, 10, kernel_size=(3, 7), stride=(1, 1))
  (pool1): MaxPool2d(kernel_size=(3, 1), stride=(3, 1), padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(10, 20, kernel_size=(3, 3), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=(3, 1), stride=(3, 1), padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=1120, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=1, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
  (sigmoid): Sigmoid()
)

In [21]:
# Load the mean and std used to normalise the features from file.
# The unpickled payload is bound to `norm_stats` rather than `data`: `data` is
# the module imported at the top of the notebook, and rebinding that name
# would silently shadow the module for every later cell.
# NOTE(review): pickle.load can execute arbitrary code -- only open files from
# a trusted source.
with open('./mean_std.pkl', 'rb') as stats_file:
    norm_stats = pickle.load(stats_file)

mean = norm_stats['mean']
std = norm_stats['std']

In [25]:
# Build {clip name: {'onsets': [times in seconds]}} for every test recording.
pred_dict = {}
for filename, feature in zip(wav_files_paths_test, features_test):
    detected_onsets = manual_evaluate_test(model, feature, threshold=0.95, mean=mean, std=std)
    # Key each entry by the last path component with '.wav' removed.
    clip_name = filename.split('/')[-1].replace('.wav', '')
    pred_dict[clip_name] = {'onsets': detected_onsets}

In [26]:
# Display the full prediction mapping (clip name -> {'onsets': [...]}).
pred_dict

{'test48': {'onsets': [0.603718820861678,
   1.0100680272108844,
   1.195827664399093,
   1.822766439909297,
   2.391655328798186,
   2.995374149659864,
   3.5874829931972787,
   4.20281179138322,
   4.794920634920635,
   5.3986394557823125,
   5.990748299319728,
   6.594467120181406,
   7.186575963718821,
   7.581315192743764,
   7.650975056689342,
   7.7090249433106575,
   7.790294784580499,
   8.289523809523809,
   8.486893424036282,
   9.775600907029478,
   10.402539682539683,
   10.994648526077098,
   11.598367346938776,
   12.202086167800454,
   12.805804988662132,
   13.374693877551021,
   13.978412698412699,
   14.570521541950113,
   15.069750566893424,
   15.17424036281179,
   15.777959183673469,
   16.39328798185941,
   16.973786848072564,
   17.368526077097506,
   17.46140589569161,
   17.589115646258502,
   18.07673469387755,
   18.285714285714285,
   19.06358276643991,
   19.957551020408165,
   20.59609977324263,
   21.199818594104308,
   21.571337868480725,
   21.66421768