In [53]:
import torch
from torch import nn
import numpy as np
import pandas as pd
from tqdm import tqdm_notebook
from copy import deepcopy

In [5]:
from model import AudioCNN

In [125]:
checkpoint_path = 'model_checkpoints/edm_genre_epoch_22_acc_0.9480_vs_0.7909_loss_5.4965_vs_6.3733.pth'
checkpoint_path = 'model_checkpoints/edm_genre_epoch_61_acc_0.9919_vs_0.7955_loss_3.2539_vs_4.3752.pth'
checkpoint_path = 'model_checkpoints/edm_genre_epoch_59_acc_0.8197_vs_0.7273_loss_1.8854_vs_2.2722.pth'
checkpoint_path = 'model_checkpoints/edm_genre_epoch_79_acc_0.9379_vs_0.7045_loss_1.4581_vs_2.2103.pth'
checkpoint_path = 'model_checkpoints/edm_genre_epoch_16_acc_0.9232_vs_0.8455_loss_6.3483_vs_6.8122.pth'

checkpoint = torch.load(checkpoint_path)
model = AudioCNN()
model.load_state_dict(checkpoint)
model.eval();

embedder = deepcopy(model)
embedder.model.fc = nn.Sequential()
embedder.eval();

In [126]:
idxs_path = 'valid_index.json'
with open(idxs_path, 'r') as file:
    valid_idxs = file.read().split('\n')[:-1]
    valid_idxs = [int(x) for x in valid_idxs]
    
class_map_path = 'class_map.json'
with open(class_map_path, 'r') as file:
    class_map = file.read().split('\n')[:-1]

In [127]:
import torchvision.transforms as T

data = np.load('/root/konst/melspects_general.npz', mmap_mode='r', allow_pickle=True)

transform = T.Compose([
    #T.ToTensor(),
    T.Lambda(lambda x: torch.FloatTensor(x / 255.0)),
])

labels = []
preds = []
filenames = []
embeds = []
for idx in tqdm_notebook(valid_idxs):
    img = data['x'][idx]
    label = data['y'][idx]
    label = class_map.index(label)
    img = transform(img)
    img = img.unsqueeze(0)
    output = model(img)
    _, pred = torch.max(output, 1)
    labels.append(label)
    preds.append(int(pred))
    filenames.append(data['filenames'][idx])
    
    embed = embedder(img)
    embeds.append(embed)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


  0%|          | 0/220 [00:00<?, ?it/s]

In [128]:
from sklearn.metrics import confusion_matrix, accuracy_score
print(class_map)
print('Acc: ', accuracy_score(labels, preds))
confusion_matrix(labels, preds)

['House', 'Trap', 'Trance']
Acc:  0.8454545454545455


array([[135,   9,   7],
       [  9,  27,   1],
       [  7,   1,  24]])

In [129]:
df = pd.DataFrame(zip(labels, preds, filenames), columns=['label', 'pred', 'filename'])
df['label'] = df.label.apply(lambda x: class_map[x])
df['pred'] = df.pred.apply(lambda x: class_map[x])

In [130]:
(3 * 60 + 42) * 0.7

155.39999999999998

In [176]:
df[(df.label == 'House') & (df.pred == 'Trap')]

Unnamed: 0,label,pred,filename,delta
39,House,Trap,Bodybangers Feat. Victoria Ker - - Stars In M...,128.947647
68,House,Trap,Lutez_Smash (Original Mix) [DutchHM] [2017] [v...,91.678131
70,House,Trap,Gama_KillKid_Tribu_Otiginal_Mix_(vmuzice.com).mp3,154.827148
74,House,Trap,Zendex - Diablo (Original Mix).mp3,102.844803
77,House,Trap,PERCY_-HARD_PEACE__KOE_2016___https___vk.com_k...,122.446426
81,House,Trap,Zendex - Asylum (Club Mix).mp3,211.05455
92,House,Trap,Tiesto & MOTi-Blow Your Mind.mp3,146.606934
152,House,Trap,R3hab David Solano - Do It (Life In Color Anth...,84.545845
185,House,Trap,Pegboard Nerds - Melodymania VIP.mp3,123.7062


In [None]:
# ArcFace

In [134]:
embedder(torch.randn(1, 2, 10, 10)).shape

torch.Size([1, 1024])

In [121]:
from pytorch_metric_learning.losses import ArcFaceLoss

loss_func = ArcFaceLoss(num_classes=3, embedding_size=1024, margin=28.6, scale=64)
loss_func(torch.randn(2, 100), torch.LongTensor([0, 1]))
loss_func

ArcFaceLoss(
  (distance): CosineSimilarity()
  (reducer): MeanReducer()
  (cross_entropy): CrossEntropyLoss()
)

In [184]:
from scipy.spatial.distance import cosine
filename = 'Skrillex & Diplo & Snails - Holla Out (feat  Taranchyla).mp3'
#filename = 'Twoloud  -  Big Bang (Original Mix).mp3'
#filename = 'Hardwell-Encoded (Radio Edit).mp3'
filename = 'Sunny_Lax_-_So_Long (www.mp3zv.me).mp3'
#filename = 'Greg House & Golden Fingers - Weekend Bangers (Radio Edit).mp3'

def rel_metric(arr1, arr2):
    arr1 = list(arr1[0])
    arr2 = list(arr2[0])
    return sum([((a1 - a2) / max(a1, a2))**2 if max(a1, a2) > 0 else 0 for a1, a2 in zip(arr1, arr2)])

embed = embeds[filenames.index(filename)]
deltas = [cosine(embed.detach().numpy(), e.detach().numpy()) for e in embeds]
#deltas = [float((embed - e).abs().sum()) for e in embeds]
#deltas = [rel_metric(embed.detach().numpy(), e.detach().numpy()) for e in embeds]

df['delta'] = deltas
df.sort_values('delta')[:30]

Unnamed: 0,label,pred,filename,delta
207,Trance,Trance,Sunny_Lax_-_So_Long (www.mp3zv.me).mp3,0.0
203,Trance,Trance,Sunny_Lax_-_Adapt_Or_Die (www.mp3zv.me).mp3,0.004402
196,Trance,Trance,audien_-_hindsight_original_mix_(zaycev.net).mp3,0.004579
211,Trance,Trance,Cosmic_Gate_Super8_Tab_-_Noom_Album_Mix.mp3,0.004782
216,Trance,Trance,Farius_-_Echo_Chamber.mp3,0.005563
205,Trance,Trance,cressida_-_6_am_kyau_and_albert_remix_(zaycev....,0.006088
188,Trance,Trance,dns_project_feat._madelin_zero_-_another_day_o...,0.007417
190,Trance,Trance,Mysterious_Movement_-_On_The_Edge_Original_Mix...,0.008173
210,Trance,Trance,Sunny_Lax_-_Bad_Bye_Extended_Mix (www.mp3zv.me...,0.008216
195,Trance,Trance,ben-gold-omnia-the-conquest.mp3,0.008644
