In [1]:
import os
import glob

# Found file corresponding to this W&B run with: `grep "3pzwny4n" outputs/*/args.json`
model_folder = '../outputs/crepe-20211129-122548'

model_paths = glob.glob(os.path.join(model_folder, '*'))

In [2]:
import sys
sys.path.append('/home/jxm3/research/transcription/contrastive-pitch-detection')

In [3]:
from models.bytedance import Bytedance_Regress_pedal_Notes
from models.contrastive import ContrastiveModel

min_midi = 21
max_midi = 108
def get_model():
    num_output_nodes = 256 # contrastive embedding dim
    out_activation = None
    
    model = Bytedance_Regress_pedal_Notes(
        num_output_nodes, out_activation, tiny=False
    )
    
    return ContrastiveModel(model, min_midi, max_midi, num_output_nodes)

In [4]:
import glob
import natsort
import os
import torch

model_paths = glob.glob(os.path.join(model_folder, '*'))

model_path = natsort.natsorted(model_paths)[-2]
print('loaded model from:', model_path)

model = get_model()
model.load_state_dict(torch.load(model_path)['model'])

loaded model from: ../outputs/crepe-20211129-122548/84_epochs.pth


<All keys matched successfully>

In [5]:
from generator import AudioDataGenerator
g = AudioDataGenerator(
        [], 16000, float('inf'),
        randomize_train_frame_offsets=True,
        batch_size=256,
        augmenter=None,
        normalize_audio=False,
        label_format='categorical',
        min_midi=21, max_midi=108,
        sample_rate=16000,
        batch_by_track=False,
        num_fake_nsynth_chords=1000,
    )

x, y = g[0]

Replacing 0 tracks with 1000 fake NSynth chords
--> MusicDataLoader loading dataset nsynth_keyboard_train


Resampling tracks: 100%|██████████| 51821/51821 [00:00<00:00, 1190447.13it/s]


TrackFrameSampler loaded 4000 frames


# Inference (val set) - calculating p(chord|audio)

Beam search!

In [15]:
cos_sim = torch.nn.CosineSimilarity(1)
def find_best_chord(audio, eps=0.0):
    audio_encoding = model(audio[None])
    best_labels = torch.zeros((88))
    zero_label_encoding = model.encode_note_labels(best_labels[None])
    best_overall_sim = cos_sim(audio_encoding.squeeze(), zero_label_encoding).item()
    for _ in range(6):
        new_labels = best_labels.repeat((88,1))
        new_notes = torch.eye(88)
        new_labels = torch.maximum(new_notes, new_labels) # 88 tensors, each one has a new 1 at a different position
        label_encodings = model.encode_note_labels(new_labels)
        cos_sims = cos_sim(audio_encoding, label_encodings)
        best_idx = cos_sims.argmax()
        best_sim = cos_sims[best_idx].item()
        
        if best_sim - best_overall_sim > eps:
            #print('choosing note', note)
            best_overall_sim = best_sim
            best_labels = new_labels[best_idx]
        else:
            #print(f'breaking after {_} steps')
            break
        # print('**'*40)
    return best_sim, best_labels

In [16]:
import numpy as np
import tqdm

y_pred_sim = []
y_true_sim = []
y_rand_sim = []
y_rand_sim_2 = []

y_pred = []
#
for audio, y_true in tqdm.tqdm(zip(x, y), desc='Finding best chord', total=len(x)):
    sim, label = find_best_chord(audio)
    y_pred_sim.append(sim)
    y_pred.append(label)
    #
    y_true_enc = model.encode_note_labels(y_true[None])
    audio_enc = model(audio[None]).squeeze()
    y_true_sim.append(cos_sim(y_true_enc, audio_enc).item())
    # a chord with each note a random 0 or 1
    rand_labels = torch.rand((1,88)).round()
    rand_enc = model.encode_note_labels(rand_labels)
    y_rand_sim.append(cos_sim(rand_enc, audio_enc).item())
    # a chord with 1-6 random notes
    n = np.random.choice([1,2,3,4,5,6], p=[0.5, 0.25, 0.125, 0.0625, 0.03125, 0.03125])
    rand_idxs = np.random.choice(88, size=n)
    rand_labels_2 = np.zeros(88)
    rand_labels_2[rand_idxs] = 1
    rand_labels_2 = torch.tensor(rand_labels_2[None], dtype=torch.float32)
    rand_enc_2 = model.encode_note_labels(rand_labels_2)
    y_rand_sim_2.append(cos_sim(rand_enc_2, audio_enc).item())
y_pred_sim = torch.tensor(np.array(y_pred_sim))
y_pred = torch.stack(y_pred)
y_true_sim = torch.tensor(np.array(y_true_sim))
y_rand_sim = torch.tensor(np.array(y_rand_sim))
y_rand_sim_2 = torch.tensor(np.array(y_rand_sim_2))

Finding best chord: 100%|██████████| 256/256 [00:46<00:00,  5.48it/s]


In [8]:
y_true_note_vals = [str((label.nonzero() + min_midi).flatten().tolist()) for label in y]
y_pred_note_vals = [str((label.nonzero() + min_midi).flatten().tolist()) for label in y_pred]

import pandas as pd
pred_df = pd.DataFrame({
    'y_pred': y_pred_note_vals,
    'y_pred_sim': y_pred_sim,
    'y_true': y_true_note_vals,
    'y_true_sim': y_true_sim,
    'y_rand_sim': y_rand_sim, # a chord with each note a random 0 or 1
    'y_rand_sim_2': y_rand_sim_2
})
pred_df.head(n=20)

Unnamed: 0,y_pred,y_pred_sim,y_true,y_true_sim,y_rand_sim,y_rand_sim_2
0,"[34, 44]",0.873851,"[24, 40, 43, 61, 91]",0.873187,0.928209,0.663487
1,[],0.854334,[87],0.789047,0.767654,0.822788
2,"[38, 59]",0.888909,"[55, 62, 64, 69, 82, 95]",0.892222,0.914088,0.711038
3,"[50, 65]",0.897734,[59],0.77269,0.878599,0.736955
4,"[33, 63]",0.90342,"[57, 75]",0.78332,0.915009,0.707385
5,"[45, 59]",0.886491,[41],0.774952,0.890849,0.784111
6,[68],0.895762,[68],0.830476,0.861077,0.759806
7,"[41, 46]",0.889462,"[34, 59, 98]",0.871097,0.899202,0.829218
8,"[40, 57]",0.896978,"[70, 84, 104]",0.780489,0.85971,0.737549
9,"[33, 47]",0.908943,"[53, 90]",0.772499,0.881247,0.753615


In [17]:
from metrics import (
    categorical_accuracy, pitch_number_acc, NStringChordAccuracy,
    precision, recall, f1
)

f1(y_pred, y)

tensor(0.1353)

In [18]:
recall(y_pred, y)

tensor(0.2690)

In [19]:
precision(y_pred, y)

tensor(0.0904)