In [1]:
import os
import glob

# Found file corresponding to this W&B run with: `grep "3pzwny4n" outputs/*/args.json`
model_folder = '../outputs/bytedance-20211130-181619/'

model_paths = glob.glob(os.path.join(model_folder, '*'))

In [2]:
import sys
sys.path.append('/home/jxm3/research/transcription/contrastive-pitch-detection')

In [3]:
from models.bytedance import Bytedance_Regress_pedal_Notes
from models.contrastive import ContrastiveModel

min_midi = 21
max_midi = 108
def get_model():
    num_output_nodes = 256 # contrastive embedding dim
    out_activation = None
    
    model = Bytedance_Regress_pedal_Notes(
        num_output_nodes, out_activation, tiny=False
    )
    
    return ContrastiveModel(model, min_midi, max_midi, num_output_nodes)

In [4]:
import glob
import natsort
import os
import torch

model_paths = glob.glob(os.path.join(model_folder, '*'))

model_path = natsort.natsorted(model_paths)[-2]
print('loaded model from:', model_path)

model = get_model()
model.load_state_dict(torch.load(model_path)['model'])

loaded model from: ../outputs/bytedance-20211130-181619/59_epochs.pth


<All keys matched successfully>

In [5]:
# from dataloader.nsynth import load_nsynth
# dataset = load_nsynth('test', 'keyboard')

from dataloader.nsynth_chords import load_nsynth_chords
dataset = load_nsynth_chords('test')

print('loaded', len(dataset), 'tracks')

import random
random.shuffle(dataset)

import numpy as np
from utils.misc import midi_vals_to_categorical, hz_to_midi_v

batch_size = 256

min_midi = 21
max_midi = 108

x = []
y = []
all_midis = []

for i in range(batch_size):
    track = dataset[i]
    start_idx = 0
    end_idx = 16_000
    #
    audio = torch.tensor(track.waveform[start_idx : end_idx], dtype=torch.float32)
    x.append(audio)
    #
    frequencies = track.get_frequencies_from_offset(start_idx, end_idx)
    midis = np.rint(hz_to_midi_v(frequencies))
    all_midis.append(list(midis))
    categorical = midi_vals_to_categorical(midis, min_midi, max_midi)
    y.append(torch.tensor(categorical, dtype=torch.float32))
x = torch.stack(x)
y = torch.stack(y)
print('loaded audio batch of shape:', x.shape, 'with labels', y.shape)

loaded 993 tracks
loaded audio batch of shape: torch.Size([256, 16000]) with labels torch.Size([256, 88])


# Inference (val set) - calculating p(chord|audio)

Beam search!

In [6]:
cos_sim = torch.nn.CosineSimilarity(1)
def find_best_chord(audio, eps=0.05):
    audio_encoding = model(audio[None])
    best_labels = torch.zeros((88))
    zero_label_encoding = model.encode_note_labels(best_labels[None])
    best_overall_sim = cos_sim(audio_encoding.squeeze(), zero_label_encoding).item()
    for _ in range(6):
        new_labels = best_labels.repeat((88,1))
        new_notes = torch.eye(88)
        new_labels = torch.maximum(new_notes, new_labels) # 88 tensors, each one has a new 1 at a different position
        label_encodings = model.encode_note_labels(new_labels)
        cos_sims = cos_sim(audio_encoding, label_encodings)
        best_idx = cos_sims.argmax()
        best_sim = cos_sims[best_idx].item()
        
        if best_sim - best_overall_sim > eps:
            #print('choosing note', note)
            best_overall_sim = best_sim
            best_labels = new_labels[best_idx]
        else:
            #print(f'breaking after {_} steps')
            break
        # print('**'*40)
    return best_sim, best_labels

In [7]:
import tqdm

y_pred_sim = []
y_true_sim = []
y_pred = []
#
for audio, y_true in tqdm.tqdm(zip(x, y), desc='Finding best chord', total=len(x)):
    sim, label = find_best_chord(audio)
    y_pred_sim.append(sim)
    y_pred.append(label)
    #
    y_true_enc = model.encode_note_labels(y_true[None])
    audio_enc = model(audio[None]).squeeze()
    y_true_sim.append(cos_sim(y_true_enc, audio_enc).item())
y_pred_sim = torch.tensor(np.array(y_pred_sim))
y_pred = torch.stack(y_pred)
y_true_sim = torch.tensor(np.array(y_true_sim))


Finding best chord: 100%|██████████| 256/256 [00:39<00:00,  6.48it/s]


In [10]:
from metrics import (
    categorical_accuracy, pitch_number_acc, NStringChordAccuracy,
    precision, recall, f1
)

categorical_accuracy(y_pred, y)

tensor(0.0195)

In [11]:
f1(y_pred, y)

tensor(0.1325)

In [13]:
precision(y_pred, y)

tensor(0.1116)

In [12]:
recall(y_pred, y)

tensor(0.1633)

In [9]:
y_true_note_vals = [str((label.nonzero() + min_midi).flatten().tolist()) for label in y]
y_pred_note_vals = [str((label.nonzero() + min_midi).flatten().tolist()) for label in y_pred]

import pandas as pd
pred_df = pd.DataFrame({
    'y_pred': y_pred_note_vals,
    'y_pred_sim': y_pred_sim,
    'y_true': y_true_note_vals,
    'y_true_sim': y_true_sim,
})
pred_df.head()

Unnamed: 0,y_pred,y_pred_sim,y_true,y_true_sim
0,[65],0.915799,[65],0.860138
1,"[25, 50]",0.923529,[62],0.723862
2,"[23, 46]",0.940088,[37],0.724195
3,"[63, 77]",0.938952,[63],0.803312
4,"[23, 58]",0.930739,[57],0.83598
