In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import json
import csv

import pandas as pd
import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt

from pitch_tracker.utils import files, dataset
from pitch_tracker.utils.constants import SAMPLE_RATE, PATCH_SIZE, RANDOM_STATE
from pitch_tracker.ml.measures import melody_evaluate
import medleydb


In [3]:
import librosa
import mir_eval
import numpy as np

from pitch_tracker.utils.constants import MIDI_START, MIDI_END, F_MIN
F_MAX = librosa.midi_to_hz(MIDI_END)
def rand_one_hot_encode(n, m):
    """
    Generate a one-hot encoded matrix with n samples and m classes using PyTorch.
    :param n: int, number of samples
    :param m: int, number of classes
    :return: torch tensor of shape (n, m)
    """
    one_hot = torch.zeros(n, m)
    for i in range(n):
        one_hot[i][torch.randint(0, m, (1,))] = 1
    return one_hot

def midi_to_hz(midi_value:float):
    return 440.0 * (2.0 ** ((torch.as_tensor(midi_value) - 69.0) / 12.0))


def class_to_frequency(class_inputs:torch.Tensor, midi_start=MIDI_START, n_classes=89):
    # n_classes also includes non-melody pitches
    pre_midi_start = midi_start - 1
    voiced_mask = class_inputs != 0

    midi_values = class_inputs + pre_midi_start
    output_frequencies = midi_to_hz(midi_values)
    output_frequencies = output_frequencies * voiced_mask

    return output_frequencies

In [4]:
y_true = rand_one_hot_encode(210, 89).unsqueeze(0)
y_pred = torch.softmax(torch.rand(y_true.size()),dim=1)

In [5]:
y_pred

tensor([[[0.0034, 0.0034, 0.0066,  ..., 0.0066, 0.0032, 0.0055],
         [0.0044, 0.0063, 0.0037,  ..., 0.0062, 0.0064, 0.0049],
         [0.0028, 0.0041, 0.0048,  ..., 0.0040, 0.0071, 0.0042],
         ...,
         [0.0029, 0.0029, 0.0038,  ..., 0.0054, 0.0043, 0.0041],
         [0.0041, 0.0045, 0.0066,  ..., 0.0035, 0.0054, 0.0048],
         [0.0042, 0.0038, 0.0059,  ..., 0.0051, 0.0074, 0.0055]]])

In [6]:
melody_evaluate(y_true, y_pred).keys()

odict_keys(['Voicing Recall', 'Voicing False Alarm', 'Raw Pitch Accuracy', 'Raw Chroma Accuracy', 'Overall Accuracy'])

In [7]:
xssl = torch.argmax(y_true, dim=-1)
yssl = torch.argmax(y_pred, dim=-1)
display(xssl)
display(yssl)

tensor([[80, 11, 21, 61, 40, 60, 75, 20, 62, 72, 32, 30, 49,  4, 24,  4, 88, 43,
         61, 64, 63, 27, 79, 22,  2, 15, 17,  7, 32, 46, 65, 62, 36,  5, 64, 45,
         80, 88, 59, 11,  8,  8, 32, 32, 80,  2, 50, 58, 74, 27, 64, 13, 72, 12,
         35, 52, 55, 34, 46, 40, 10, 22, 46, 27,  3, 39, 43, 59,  8, 84, 66, 82,
          9, 64, 71,  9, 53, 19, 85, 83, 59, 55, 21, 72, 73, 67, 27, 44, 79, 85,
         85, 28,  0, 69, 25, 23,  8, 73, 61, 36, 84, 35, 46, 80, 38, 39, 66, 59,
         69, 63,  9, 68, 82,  7, 82, 74, 51, 78,  4, 80, 84, 87,  6,  2, 41, 73,
         87, 60, 80, 80, 21, 25, 59, 42, 86, 23, 18, 58, 55, 53, 40, 28, 12, 39,
         25, 40, 59,  0, 69, 67, 57, 23, 33, 62, 21, 29, 56, 62, 87, 10, 38, 87,
         18,  3, 48, 32, 22, 55,  2, 82, 20, 36,  7,  1, 29, 36, 76,  3, 85, 55,
          9, 19, 54, 44,  3, 66, 84, 37, 16, 58, 75, 26, 41, 30, 61,  8, 32, 62,
         74, 34, 36, 36, 62, 38,  7, 37, 48, 53, 55,  3]])

tensor([[79, 35, 66, 15, 57, 73, 81, 69, 82, 75, 64, 84, 86, 13, 72, 72, 10, 35,
         71, 78, 80, 24, 35, 38, 14, 17, 86, 60, 16, 27, 22, 46, 48, 85, 36, 18,
         83,  4, 57, 73, 54, 14, 70, 27, 67, 51, 40, 83, 14, 13, 14, 66, 33, 35,
         81, 12, 53, 84,  5, 39,  3, 65, 27, 56, 80, 64, 81, 60, 73,  7, 71, 41,
         77,  6, 71, 38, 35, 81, 29,  4, 82, 52, 17, 82, 11, 63, 44, 60, 72, 12,
         80,  1, 41, 29, 48, 65, 82, 79, 75, 38, 75, 73, 73, 81, 48, 83, 75, 73,
         86, 81, 22,  7, 59, 29, 13, 66, 81, 27,  3,  4,  3, 29, 88, 78, 78, 44,
         63, 54, 16, 71, 70, 64, 74, 46, 85, 57, 31, 83, 13, 82, 17, 64,  2, 40,
         79, 33, 39, 69, 53, 28, 21, 79, 35, 79, 63, 82,  9, 10, 75, 40, 73, 63,
          1, 54, 80, 13, 65,  7, 35, 48, 40, 46, 38, 81,  7, 35, 54, 54, 83,  0,
          5, 14,  2, 42, 60, 64, 16, 61, 67, 56,  6, 36, 84,  2, 46, 59, 56, 77,
         57, 17, 44, 73, 54, 38, 83, 75, 33, 22, 77, 66]])

In [8]:
yssl.flatten()

tensor([79, 35, 66, 15, 57, 73, 81, 69, 82, 75, 64, 84, 86, 13, 72, 72, 10, 35,
        71, 78, 80, 24, 35, 38, 14, 17, 86, 60, 16, 27, 22, 46, 48, 85, 36, 18,
        83,  4, 57, 73, 54, 14, 70, 27, 67, 51, 40, 83, 14, 13, 14, 66, 33, 35,
        81, 12, 53, 84,  5, 39,  3, 65, 27, 56, 80, 64, 81, 60, 73,  7, 71, 41,
        77,  6, 71, 38, 35, 81, 29,  4, 82, 52, 17, 82, 11, 63, 44, 60, 72, 12,
        80,  1, 41, 29, 48, 65, 82, 79, 75, 38, 75, 73, 73, 81, 48, 83, 75, 73,
        86, 81, 22,  7, 59, 29, 13, 66, 81, 27,  3,  4,  3, 29, 88, 78, 78, 44,
        63, 54, 16, 71, 70, 64, 74, 46, 85, 57, 31, 83, 13, 82, 17, 64,  2, 40,
        79, 33, 39, 69, 53, 28, 21, 79, 35, 79, 63, 82,  9, 10, 75, 40, 73, 63,
         1, 54, 80, 13, 65,  7, 35, 48, 40, 46, 38, 81,  7, 35, 54, 54, 83,  0,
         5, 14,  2, 42, 60, 64, 16, 61, 67, 56,  6, 36, 84,  2, 46, 59, 56, 77,
        57, 17, 44, 73, 54, 38, 83, 75, 33, 22, 77, 66])

In [9]:
y_true_frequencies = class_to_frequency(xssl)
y_pred_frequencies = class_to_frequency(yssl)

In [10]:
y_true_frequencies

tensor([[2637.0203,   48.9994,   87.3071,  880.0000,  261.6255,  830.6094,
         1975.5334,   82.4069,  932.3276, 1661.2188,  164.8138,  146.8324,
          440.0000,   32.7032,  103.8262,   32.7032, 4186.0088,  311.1270,
          880.0000, 1046.5022,  987.7666,  123.4708, 2489.0159,   92.4986,
           29.1352,   61.7354,   69.2957,   38.8909,  164.8138,  369.9944,
         1108.7306,  932.3276,  207.6523,   34.6478, 1046.5022,  349.2282,
         2637.0203, 4186.0088,  783.9908,   48.9994,   41.2034,   41.2034,
          164.8138,  164.8138, 2637.0203,   29.1352,  466.1638,  739.9888,
         1864.6549,  123.4708, 1046.5022,   55.0000, 1661.2188,   51.9131,
          195.9977,  523.2511,  622.2540,  184.9972,  369.9944,  261.6255,
           46.2493,   92.4986,  369.9944,  123.4708,   30.8677,  246.9417,
          311.1270,  783.9908,   41.2034, 3322.4377, 1174.6591, 2959.9553,
           43.6535, 1046.5022, 1567.9818,   43.6535,  554.3653,   77.7817,
         3520.0000, 3135.

In [11]:
import librosa

In [12]:
librosa.note_to_hz('C10')

16744.036179238312

In [13]:
y = [
    [1,0,0,0],
    [0,1,0,0],
    [0,0,1,0],
    [0,0,1,0],
    [0,0,0,1],
    [0,0,0,1],
    ]
y = torch.Tensor(y)
y.shape

torch.Size([6, 4])

In [14]:
# x = rand_one_hot_encode(10, 4).unsqueeze(0)

ys = y.unsqueeze(0)
xs = torch.softmax(torch.rand(ys.size()),dim=1)

xss = torch.cat([xs]*2,0)
yss = torch.cat([ys]*2,0)

xssl = torch.argmax(xss, dim=-1).flatten()
yssl = torch.argmax(yss, dim=-1).flatten()

est_freq = class_to_frequency(xssl)
ref_freq = class_to_frequency(yssl)
time_1d = torch.arange(0, ref_freq.numel(),1)


In [15]:
xss.shape

torch.Size([2, 6, 4])

In [16]:
scores = mir_eval.melody.evaluate(
        ref_time=time_1d.numpy(),
        ref_freq=ref_freq.numpy(),
        est_time=time_1d.numpy(),
        est_freq=est_freq.numpy(),
        cent_tolerance=51)
scores

OrderedDict([('Voicing Recall', 0.8),
             ('Voicing False Alarm', 0.0),
             ('Raw Pitch Accuracy', 0.4),
             ('Raw Chroma Accuracy', 0.4),
             ('Overall Accuracy', 0.5)])

In [17]:
scores = melody_evaluate(yssl,xssl)
scores

OrderedDict([('Voicing Recall', 0.8),
             ('Voicing False Alarm', 0.0),
             ('Raw Pitch Accuracy', 0.4),
             ('Raw Chroma Accuracy', 0.4),
             ('Overall Accuracy', 0.5)])

In [18]:
xssl, yssl

(tensor([0, 0, 2, 3, 1, 3, 0, 0, 2, 3, 1, 3]),
 tensor([0, 1, 2, 2, 3, 3, 0, 1, 2, 2, 3, 3]))

In [19]:
len(yssl.shape)

1

In [20]:
ref_freq, est_freq

(tensor([ 0.0000, 27.5000, 29.1352, 29.1352, 30.8677, 30.8677,  0.0000, 27.5000,
         29.1352, 29.1352, 30.8677, 30.8677]),
 tensor([ 0.0000,  0.0000, 29.1352, 30.8677, 27.5000, 30.8677,  0.0000,  0.0000,
         29.1352, 30.8677, 27.5000, 30.8677]))

In [21]:
xs.shape, ys.shape, xss.shape, yss.shape

(torch.Size([1, 6, 4]),
 torch.Size([1, 6, 4]),
 torch.Size([2, 6, 4]),
 torch.Size([2, 6, 4]))

In [22]:
import torch.nn.functional as F
def my_cross_entropy(x, y):
    log_prob = -1.0 * F.log_softmax(x, 1)
    loss = log_prob.gather(1, y.unsqueeze(1))
    loss = loss.mean()
    return loss

criterion = torch.nn.CrossEntropyLoss()

loss_reference = criterion(xs, ys)
loss = my_cross_entropy(xs, ys)

# print(loss_reference - loss)


RuntimeError: gather(): Expected dtype int64 for index

In [134]:
xss.shape, yss.shape

(torch.Size([2, 6, 4]), torch.Size([2, 6, 4]))

In [99]:
y

tensor([0, 5, 8, 4, 9])

In [83]:
sample_weight

tensor([0.3762, 0.8364, 0.8474, 0.0816, 0.1550, 0.7574, 0.7805, 0.3203, 0.3081,
        0.1468])

In [86]:
criterion.reduction

'none'