In [1]:
import copy
import datetime
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics
from torch.autograd import Variable
import librosa

from frontend import Frontend_mine
from backend import Backend
from data_loader import get_DataLoader
from main import AssembleModel

# define here all the parameters
main_dict = dict(
    # parallel per-layer lists for the frontend (6 entries each, one per layer)
    frontend_dict=dict(
        list_out_channels=[128, 128, 256, 256, 256, 256],
        list_kernel_sizes=[(3, 3)] * 6,
        list_pool_sizes=[(3, 2), (2, 2), (2, 2), (2, 1), (2, 1), (2, 1)],
        list_avgpool_flags=[False] * 5 + [True],
    ),
    backend_dict=dict(
        n_class=50,
        bert_config=None,
        recurrent_units=2,  # pass recurrent_units = None to deactivate
        bidirectional=True,
    ),
    training_dict=dict(
        dataset='msd',
        architecture='without_seq2seq_5s',
        n_epochs=1000,
        learning_rate=1e-4,
    ),
    data_loader_dict=dict(
        # NOTE(review): '~' is only expanded if the consumer calls os.path.expanduser
        path_to_repo='~/dl4am/',
        batch_size=128,
        input_length=5,  # [s]
        spec_path='/import/c4dm-datasets/rmri_self_att/msd',
        audio_path='/import/c4dm-03/Databases/songs/',
        mode='train',
        num_workers=20,
    ),
)

In [2]:
def compute_melspectrogram(audio_fn, sr=16000, n_fft=512, hop_length=256, n_mels=96):
    """Load an audio file and return its dB-scaled mel spectrogram.

    The audio is resampled to ``sr`` on load, converted to an ``n_mels``-band
    mel spectrogram and mapped to decibels. The defaults are the settings the
    rest of this notebook assumes (e.g. 5 s of audio = 5 * 16000 // 256 frames).

    Args:
        audio_fn: path to an audio file readable by ``librosa.load``.
        sr: target sample rate in Hz.
        n_fft: FFT window length in samples.
        hop_length: hop between successive frames in samples.
        n_mels: number of mel bands (rows of the result).

    Returns:
        np.ndarray of shape (n_mels, n_frames).
    """
    with warnings.catch_warnings():
        # silence warnings raised while loading/resampling/converting;
        # the filter is scoped to this block only
        warnings.simplefilter("ignore")
        audio, _ = librosa.load(audio_fn, sr=sr, res_type='kaiser_fast')
        # `y` must be passed by keyword: librosa >= 0.10 made it keyword-only
        mel = librosa.feature.melspectrogram(y=audio, sr=sr, n_fft=n_fft,
                                             hop_length=hop_length, n_mels=n_mels)
        # NOTE(review): melspectrogram returns a *power* spectrogram by default
        # (power=2.0), for which librosa.power_to_db is the matching conversion.
        # amplitude_to_db is kept because the trained models presumably saw
        # features computed this way — confirm before changing.
        spec = librosa.amplitude_to_db(mel)
    return spec

In [3]:
def load_parameters(model, filename, map_location="cpu"):
    """Wrap ``model`` in ``torch.nn.DataParallel`` and load a saved state dict into it.

    The wrapper is applied before loading, so the checkpoint keys must match the
    wrapped model's (NOTE(review): presumably ``module.``-prefixed because training
    also used DataParallel — confirm against the training script).

    Args:
        model: the bare ``nn.Module`` whose architecture matches the checkpoint.
        filename: path to a file written with ``torch.save(state_dict)``.
        map_location: device that ``torch.load`` places the loaded tensors on
            before they are copied into the model's parameters. Defaults to
            ``"cpu"`` so GPU-saved checkpoints also load on CPU-only machines.

    Returns:
        The DataParallel-wrapped model with the checkpoint weights loaded.
    """
    wrapped = torch.nn.DataParallel(model)
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    state_dict = torch.load(filename, map_location=map_location)
    wrapped.load_state_dict(state_dict)
    return wrapped

def test(model, data_loader):
    """Evaluate ``model`` on ``data_loader`` and report macro ROC-AUC and PR-AUC.

    Each item yielded by ``data_loader`` is one track: ``x`` has shape
    (1, n_chunks, F, T) and ``y`` holds that track's labels. All chunks of a
    track are scored in one forward pass and their predictions are averaged
    into a single track-level score vector before the metrics are computed.

    Args:
        model: module called on a (n_chunks, 1, F, T) tensor; its output is
            presumably (n_chunks, n_class) scores — confirm against the backend.
        data_loader: iterable of ``(x, y)`` tensor pairs as described above;
            ``len(data_loader)`` is used in the progress messages.

    Returns:
        Tuple ``(roc_auc, pr_auc)`` of macro-averaged scores (also printed).

    Raises:
        ValueError: if ``data_loader`` yields no items.
    """
    model.eval()
    y_score = []
    y_true = []
    # inference only: disabling autograd avoids building a graph per forward pass
    with torch.no_grad():
        for ctr, (x, y) in enumerate(data_loader, start=1):
            # NB: in validation/test mode the DataLoader yields x with shape
            # (1, n_chunks, F, T), where n_chunks = total time frames // input_length.
            # Permuting here treats n_chunks as the batch size.
            x = x.permute(1, 0, 2, 3)

            out = model(x).detach().cpu()
            # average chunk-level predictions into one track-level score vector
            y_score.append(out.numpy().mean(axis=0))
            y_true.append(y.detach().cpu().numpy())

            if ctr % 1000 == 0:
                print("[%s] Valid Iter [%d/%d] " %
                      (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
                       ctr, len(data_loader)))

    if not y_score:
        raise ValueError("data_loader yielded no items; cannot compute metrics")

    # one row per track; reshape (not squeeze) so a single-track loader stays 2-D
    y_score = np.asarray(y_score).reshape(len(y_score), -1)
    y_true = np.asarray(y_true).reshape(len(y_true), -1).astype(int)

    roc_auc = metrics.roc_auc_score(y_true, y_score, average='macro')
    pr_auc = metrics.average_precision_score(y_true, y_score, average='macro')
    print('roc_auc: %.4f' % roc_auc)
    print('pr_auc: %.4f' % pr_auc)
    return roc_auc, pr_auc

In [4]:
# Evaluation loader: one track per batch (test() expects x of shape (1, n_chunks, F, T)).
# The chunk length comes from the shared config so it cannot drift from training.
test_data = get_DataLoader(batch_size=1,
                           input_length=main_dict["data_loader_dict"]["input_length"],
                           mode='test',
                           num_workers=10)
len(test_data)

28435

In [4]:
bidirectional_recurrent_self_att = load_parameters(AssembleModel(main_dict),"../models/best_model_bidirectional.pth")

In [5]:
main_dict["backend_dict"]["bidirectional"] = False
recurrent_self_att = load_parameters(AssembleModel(main_dict),"../models/best_model_sacrnn_5s.pth")

In [6]:
main_dict["backend_dict"]["recurrent_units"] = None
only_self_att = load_parameters(AssembleModel(main_dict),"../models/best_model_no_recurrent.pth")

In [8]:
test(model=bidirectional_recurrent_self_att, data_loader=test_data)

[2021-03-14 11:14:43] Valid Iter [1000/28435] 
[2021-03-14 11:14:59] Valid Iter [2000/28435] 
[2021-03-14 11:15:15] Valid Iter [3000/28435] 
[2021-03-14 11:15:30] Valid Iter [4000/28435] 
[2021-03-14 11:15:46] Valid Iter [5000/28435] 
[2021-03-14 11:16:02] Valid Iter [6000/28435] 
[2021-03-14 11:16:17] Valid Iter [7000/28435] 
[2021-03-14 11:16:33] Valid Iter [8000/28435] 
[2021-03-14 11:16:48] Valid Iter [9000/28435] 
[2021-03-14 11:17:04] Valid Iter [10000/28435] 
[2021-03-14 11:17:20] Valid Iter [11000/28435] 
[2021-03-14 11:17:35] Valid Iter [12000/28435] 
[2021-03-14 11:17:51] Valid Iter [13000/28435] 
[2021-03-14 11:18:07] Valid Iter [14000/28435] 
[2021-03-14 11:18:22] Valid Iter [15000/28435] 
[2021-03-14 11:18:37] Valid Iter [16000/28435] 
[2021-03-14 11:18:52] Valid Iter [17000/28435] 
[2021-03-14 11:19:07] Valid Iter [18000/28435] 
[2021-03-14 11:19:22] Valid Iter [19000/28435] 
[2021-03-14 11:19:38] Valid Iter [20000/28435] 
[2021-03-14 11:19:53] Valid Iter [21000/28435] 
[

In [5]:
test(model=recurrent_self_att, data_loader=test_data)

[2021-03-12 16:18:31] Valid Iter [1000/28435] 
[2021-03-12 16:18:45] Valid Iter [2000/28435] 
[2021-03-12 16:18:59] Valid Iter [3000/28435] 
[2021-03-12 16:19:13] Valid Iter [4000/28435] 
[2021-03-12 16:19:28] Valid Iter [5000/28435] 
[2021-03-12 16:19:42] Valid Iter [6000/28435] 
[2021-03-12 16:19:56] Valid Iter [7000/28435] 
[2021-03-12 16:20:10] Valid Iter [8000/28435] 
[2021-03-12 16:20:25] Valid Iter [9000/28435] 
[2021-03-12 16:20:39] Valid Iter [10000/28435] 
[2021-03-12 16:20:53] Valid Iter [11000/28435] 
[2021-03-12 16:21:08] Valid Iter [12000/28435] 
[2021-03-12 16:21:22] Valid Iter [13000/28435] 
[2021-03-12 16:21:36] Valid Iter [14000/28435] 
[2021-03-12 16:21:50] Valid Iter [15000/28435] 
[2021-03-12 16:22:05] Valid Iter [16000/28435] 
[2021-03-12 16:22:19] Valid Iter [17000/28435] 
[2021-03-12 16:22:33] Valid Iter [18000/28435] 
[2021-03-12 16:22:47] Valid Iter [19000/28435] 
[2021-03-12 16:23:02] Valid Iter [20000/28435] 
[2021-03-12 16:23:16] Valid Iter [21000/28435] 
[

In [6]:
test(model=only_self_att, data_loader=test_data)

[2021-03-12 16:25:18] Valid Iter [1000/28435] 
[2021-03-12 16:25:31] Valid Iter [2000/28435] 
[2021-03-12 16:25:45] Valid Iter [3000/28435] 
[2021-03-12 16:25:59] Valid Iter [4000/28435] 
[2021-03-12 16:26:13] Valid Iter [5000/28435] 
[2021-03-12 16:26:26] Valid Iter [6000/28435] 
[2021-03-12 16:26:40] Valid Iter [7000/28435] 
[2021-03-12 16:26:54] Valid Iter [8000/28435] 
[2021-03-12 16:27:08] Valid Iter [9000/28435] 
[2021-03-12 16:27:22] Valid Iter [10000/28435] 
[2021-03-12 16:27:35] Valid Iter [11000/28435] 
[2021-03-12 16:27:49] Valid Iter [12000/28435] 
[2021-03-12 16:28:02] Valid Iter [13000/28435] 
[2021-03-12 16:28:16] Valid Iter [14000/28435] 
[2021-03-12 16:28:29] Valid Iter [15000/28435] 
[2021-03-12 16:28:43] Valid Iter [16000/28435] 
[2021-03-12 16:28:56] Valid Iter [17000/28435] 
[2021-03-12 16:29:09] Valid Iter [18000/28435] 
[2021-03-12 16:29:23] Valid Iter [19000/28435] 
[2021-03-12 16:29:36] Valid Iter [20000/28435] 
[2021-03-12 16:29:49] Valid Iter [21000/28435] 
[

In [8]:
# total parameter count of each model, one per line
print(*(sum(param.numel() for param in net.parameters())
        for net in (recurrent_self_att, only_self_att)),
      sep="\n")

5715124
4925620


In [21]:
# One tag per line; entry i is the label first_10_tags attaches to model output i.
with open('msd_metadata/50tagList.txt', encoding='utf-8') as f:
    tagList = [line.rstrip('\n') for line in f]

# Predictions are mapped to tags by position, so the vocabulary must match the output size.
assert len(tagList) == main_dict["backend_dict"]["n_class"], (
    f"expected {main_dict['backend_dict']['n_class']} tags, got {len(tagList)}")

In [15]:
def first_10_tags(filename, model=None, top_k=10):
    """Print the ``top_k`` highest-scoring tags (and their scores) for one audio file.

    The file's spectrogram from ``compute_melspectrogram`` is cut into
    non-overlapping 5 s chunks (a trailing partial chunk is dropped). All chunks
    are scored in one batch, and the chunk scores are averaged into one
    track-level prediction.

    Args:
        filename: path to an audio file readable by ``compute_melspectrogram``.
        model: model to query; defaults to the notebook-level ``only_self_att``.
        top_k: number of tags to print (default 10, as the name says).

    Raises:
        ValueError: if the audio is shorter than one 5 s chunk.
    """
    if model is None:
        model = only_self_att

    # frames per chunk: 5 s at sr=16000 with hop_length=256; must match the
    # settings compute_melspectrogram uses
    input_length = 5 * 16000 // 256

    whole_spec = compute_melspectrogram(filename)
    n_mels = whole_spec.shape[0]
    n_chunks = whole_spec.shape[1] // input_length
    if n_chunks == 0:
        raise ValueError("%s is shorter than one %d-frame chunk" % (filename, input_length))

    # (n_mels, n_chunks * L) -> (n_chunks, 1, n_mels, L): stack of chunks with a channel axis
    chunks = whole_spec[:, :n_chunks * input_length].reshape(n_mels, n_chunks, input_length)
    chunks = chunks.transpose(1, 0, 2)[:, np.newaxis, :, :]

    # force inference mode for any dropout/batch-norm layers; otherwise the result
    # depends on whichever mode the model was last left in
    model.eval()
    with torch.no_grad():
        out = model(torch.as_tensor(chunks, dtype=torch.float32))
    y_pred = out.detach().cpu().numpy().mean(axis=0)

    top = np.argsort(y_pred)[::-1][:top_k]
    print(np.array(tagList)[top].tolist())
    print(y_pred[top].tolist())

In [16]:
first_10_tags(filename="track_rock.mp3")

['rock', 'indie', 'alternative', 'hard rock', 'pop', 'indie rock', '80s', 'alternative rock', 'classic rock', 'electronic']
[0.37707647681236267, 0.1710204780101776, 0.12224441766738892, 0.10062480717897415, 0.09020650386810303, 0.0784156396985054, 0.07574423402547836, 0.0631411075592041, 0.057250697165727615, 0.056344885379076004]


In [20]:
# source: https://www.youtube.com/watch?v=EmSC6ZsxH10
first_10_tags(filename="track_pop.mp3")

['pop', 'female vocalists', 'electronic', 'rock', 'indie', 'alternative', 'dance', 'soul', 'jazz', 'chillout']
[0.21844559907913208, 0.2004319131374359, 0.16902250051498413, 0.13392417132854462, 0.10215485095977783, 0.09331976622343063, 0.07454238831996918, 0.05603437125682831, 0.052841540426015854, 0.05281108617782593]


In [22]:
# full tag vocabulary; entry i is the label first_10_tags attaches to model output i
tagList

['rock',
 'pop',
 'alternative',
 'indie',
 'electronic',
 'female vocalists',
 'dance',
 '00s',
 'alternative rock',
 'jazz',
 'beautiful',
 'metal',
 'chillout',
 'male vocalists',
 'classic rock',
 'soul',
 'indie rock',
 'Mellow',
 'electronica',
 '80s',
 'folk',
 '90s',
 'chill',
 'instrumental',
 'punk',
 'oldies',
 'blues',
 'hard rock',
 'ambient',
 'acoustic',
 'experimental',
 'female vocalist',
 'guitar',
 'Hip-Hop',
 '70s',
 'party',
 'country',
 'easy listening',
 'sexy',
 'catchy',
 'funk',
 'electro',
 'heavy metal',
 'Progressive rock',
 '60s',
 'rnb',
 'indie pop',
 'sad',
 'House',
 'happy']