In [None]:
from tqdm import tqdm
from collections import OrderedDict

import os, pickle
import argparse
import json
import numpy as np
import pandas as pd

import librosa
import scipy
import torch
import torch.nn as nn
from torch.utils import data

import sys

# Directory that holds the RawNet2 reference code; it has to be on sys.path so
# that the `models.RawNet2` import right below can be resolved.
DIR = '/CSC413/RawNet/python/RawNet2/Pre-trained_model'
# Guarded so that re-running this cell does not keep appending duplicates.
if DIR not in sys.path:
    sys.path.append(DIR)
from models.RawNet2 import RawNet as RawNet2

from argparse import Namespace
import glob, json, argparse

# Number of waveform samples per clip (59049 == 3**10). The same value is used
# for `nb_samp` in the config and as the slice length in the datasets, and it
# doubles as the RNG seed here.
MAGIC_NUMBER = 59049
# Seed both RNGs this notebook draws from: NumPy (train/val split) and PyTorch
# (DataLoader shuffling, initialisation of newly created layers).
np.random.seed(MAGIC_NUMBER)
torch.manual_seed(MAGIC_NUMBER);

In [None]:
# Run configuration. The notebook does not parse a command line; the values are
# fixed here and exposed through an argparse-style Namespace.
parser = argparse.ArgumentParser()
args = Namespace(
    # training / bookkeeping options
    bs=64,
    lr=0.001,
    nb_samp=MAGIC_NUMBER,
    name='fma-trial2',
    save_dir='DNNs/',
    DB='/',
    window_size=0,
    wd=0.001,
    epoch=60,
    optimizer='Adam',
    nb_worker=4,
    temp=.5,
    seed=12315,
    load_model_dir='/h/marko/CSC413/RawNet/python/RawNet2/Pre-trained_model/rawnet2_best_weights.pt',
    # model hyper-parameters (every key prefixed with `m_` is copied to args.model)
    m_first_conv=251,
    m_in_channels=1,
    m_filts=[128, [128,128], [128,256], [256,256]],
    m_blocks=[2, 4],
    m_nb_fc_att_node=[1],
    m_nb_fc_node=1024,
    m_gru_node=1024,
    m_nb_gru_layer=1,
    m_nb_samp=MAGIC_NUMBER,
    # flags
    amsgrad=True,
    make_val_trial=False,
    debug=False,
    comet_disable=False,
    save_best_only=False,
    mg=False,
    load_model=True,
    reproducible=True,
)

# Collect the `m_*` entries, with the prefix stripped, into the dict that is
# handed to the RawNet2 constructor.
args.model = {
    key[2:]: value
    for key, value in vars(args).items()
    if key.startswith('m_')
}
# Size of the network's output layer at construction time.
args.model['nb_classes'] = 6112

### Build the train and validation sets

In [None]:
# Series of labels keyed by the CSV's first column. The datasets below index it
# with the integer parsed from each audio file name.
# `read_csv(squeeze=True)` was deprecated in pandas 1.4 and removed in 2.0;
# squeezing the single-column frame afterwards is the supported equivalent.
label_map = pd.read_csv('/CSC413/labels.csv', index_col=0).squeeze('columns')

In [None]:
# Every pre-extracted waveform (.npy) in the data directory.
data_dir = '/CSC413/fma/fma_npy'
# glob returns paths in arbitrary, filesystem-dependent order. Sorting makes the
# index -> file mapping stable, so the seeded train/val split drawn from these
# indices picks the same files on every run and machine.
all_files = np.array(sorted(glob.glob(data_dir+'/*.npy')))

In [None]:
n = len(all_files)
# Hold out 10% of the files for validation.
# A dedicated generator seeded with MAGIC_NUMBER makes this cell idempotent:
# drawing from the global NumPy RNG would give a different split each time the
# cell is re-run without re-seeding, mixing earlier training files into the
# validation set. Seeded identically to the global RNG, it gives the same draw
# as a fresh Run All, as long as nothing else consumed the global RNG first.
split_rng = np.random.RandomState(MAGIC_NUMBER)
subset_indices = split_rng.choice(n, n//10, replace=False)

In [None]:
# Training files are every index that was not drawn for validation; the mask
# keeps them in their original (ascending index) order.
is_train = np.ones(n, dtype=bool)
is_train[subset_indices] = False
train_set = all_files[is_train]
val_set = all_files[subset_indices]

In [None]:
class FMADataset(torch.utils.data.Dataset):
    """Training dataset yielding fixed-length clips cut from ``.npy`` waveforms.

    Each file contributes ``segments_per_track`` clips of ``MAGIC_NUMBER``
    samples, taken at evenly spaced offsets. Indices are laid out segment-major:
    the first ``n`` indices are segment 0 of every file, the next ``n`` are
    segment 1, and so on.

    Args:
        audio_list: sequence of paths to ``.npy`` files. The first six
            characters of each file name are parsed as an integer and used to
            look the label up in the module-level ``label_map``.
        segments_per_track: number of clips taken from every file (default 19,
            the value that used to be hard-coded).

    Raises:
        ValueError: if ``segments_per_track`` is smaller than 1.
    """

    def __init__(self, audio_list, segments_per_track=19):
        if segments_per_track < 1:
            raise ValueError(f'segments_per_track must be >= 1, got {segments_per_track}')
        self.audio_list = audio_list
        self.n = len(self.audio_list)
        self.segments_per_track = segments_per_track

    def __len__(self):
        return self.n * self.segments_per_track

    def __getitem__(self, idx):
        """Return ``(clip, label)`` for the flat index ``idx``.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
            ValueError: if the file holds fewer than ``MAGIC_NUMBER`` samples.
        """
        total = len(self)
        if idx < 0:
            idx += total
        # Without this check, out-of-range indices silently produced short or
        # empty clips and plain iteration over the dataset never terminated.
        if not 0 <= idx < total:
            raise IndexError(f'index {idx} out of range for dataset of size {total}')

        audio_fn = self.audio_list[idx % self.n]
        audio = np.load(audio_fn)

        slack = audio.shape[0] - MAGIC_NUMBER
        if slack < 0:
            # A negative slack would turn into negative offsets and clips of
            # the wrong length, which only fail later when batches are collated.
            raise ValueError(
                f'{audio_fn} has {audio.shape[0]} samples, fewer than the required {MAGIC_NUMBER}')

        # Segment k starts at k * stride; the last segment still fits because
        # (segments_per_track - 1) * stride <= slack.
        offset = (idx // self.n) * (slack // self.segments_per_track)
        label = label_map[int(audio_fn.split('/')[-1][:6])]
        return audio[offset:offset+MAGIC_NUMBER], label
    
class FMADataset_val(torch.utils.data.Dataset):
    """Validation dataset: one clip per file, cut from the start of the waveform.

    ``audio_list`` is a sequence of paths to ``.npy`` files. The first six
    characters of each file name are parsed as an integer and used as the key
    into the module-level ``label_map``.
    """

    def __init__(self, audio_list):
        self.audio_list = audio_list
        self.n = len(audio_list)

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        """Return ``(first MAGIC_NUMBER samples, label)`` for file ``idx``."""
        path = self.audio_list[idx]
        waveform = np.load(path)
        file_name = path.rsplit('/', 1)[-1]
        target = label_map[int(file_name[:6])]
        return waveform[:MAGIC_NUMBER], target

In [None]:
# Options shared by both loaders.
# NOTE(review): the validation loader also shuffles and drops the last partial
# batch, so up to 19 validation clips are skipped on every pass.
loader_options = dict(
    shuffle=True,
    drop_last=True,
    num_workers=args.nb_worker,
)

trainset = FMADataset(train_set)
trainset_gen = data.DataLoader(trainset, batch_size=args.bs, **loader_options)

valset = FMADataset_val(val_set)
valset_gen = data.DataLoader(valset, batch_size=20, **loader_options)

In [None]:
# Output location for this run: <save_dir>/<name>/ with `results/` and
# `models/` sub-directories, created if they do not exist yet.
save_dir = f'{args.save_dir}{args.name}/'
for subdir in ('', 'results/', 'models/'):
    os.makedirs(save_dir + subdir, exist_ok=True)

In [None]:
# Build RawNet2 from the `m_*` configuration and place it on the GPU.
model = RawNet2(args.model, 'cuda')
model = model.to('cuda')

# Optionally start from the weights stored at `args.load_model_dir`.
if args.load_model:
    pretrained_state = torch.load(args.load_model_dir)
    model.load_state_dict(pretrained_state)

# Total number of scalar parameters in the network.
nb_params = sum(p.numel() for p in model.parameters())

In [None]:
# Replace the network's `fc2_gru` layer with a freshly initialised 16-way
# linear layer (presumably one output per target class - confirm against
# labels.csv), then move the model, including the new layer, to the GPU.
new_head = nn.Linear(args.model['nb_fc_node'], 16, bias=True)
model.fc2_gru = new_head
model.cuda();

In [None]:
# set objective functions
criterion = {'cce': nn.CrossEntropyLoss()}

# set optimizer parameter groups: parameters whose name contains 'bn' are
# exempt from weight decay, all others use the optimizer's default.
decay_params, bn_params = [], []
for name, param in model.named_parameters():
    if 'bn' in name:
        bn_params.append(param)
    else:
        decay_params.append(param)

params = [
    {'params': decay_params},
    {'params': bn_params, 'weight_decay': 0},
]

In [None]:
# Adam over the two parameter groups. `weight_decay` given here is the default
# that applies to every group that does not override it.
adam_options = dict(lr=args.lr, weight_decay=args.wd, amsgrad=args.amsgrad)
optimizer = torch.optim.Adam(params, **adam_options)

In [None]:
def keras_lr_decay(step, decay = 0.0001):
    """Return the learning-rate multiplier ``1 / (1 + decay * step)``.

    Args:
        step: number of scheduler steps taken so far.
        decay: per-step decay coefficient.

    Returns:
        The factor by which the base learning rate is scaled at ``step``.
    """
    denominator = 1. + decay * step
    return 1. / denominator

# Scale the base learning rate by keras_lr_decay(step) at every scheduler step
# (the default decay of 0.0001 is used).
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=keras_lr_decay)

### Evaluate

In [None]:
def evaluate(model, loader=None):
    """Compute, print and return the classification accuracy of ``model``.

    Args:
        model: network called as ``model(batch, labels)`` and returning
            per-class scores of shape (batch, classes).
        loader: iterable of ``(batch, labels)`` pairs; defaults to the
            module-level ``valset_gen``.

    Returns:
        Accuracy in percent over the samples that were actually evaluated
        (``nan`` if the loader yields nothing).
    """
    if loader is None:
        loader = valset_gen

    # Run on whatever device the model lives on; fall back to CUDA as before.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    corr = 0
    seen = 0
    was_training = model.training
    # eval() stops layers such as batch norm and dropout from behaving as in
    # training (and from updating running statistics) while we only measure.
    model.eval()
    try:
        # No gradients are needed for evaluation; skip building the graph.
        with torch.no_grad():
            for m_batch, m_label in tqdm(loader):
                m_batch, m_label = m_batch.to(device), m_label.to(device)
                output = model(m_batch, m_label)
                _, pred = torch.max(output, 1)
                corr += (m_label == pred).sum().item()
                # Count what was really evaluated: a loader with drop_last=True
                # yields fewer samples than len(dataset), so dividing by the
                # dataset size would under-report the accuracy.
                seen += m_label.size(0)
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)

    accuracy = corr * 100 / seen if seen else float('nan')
    print(f'accuracy: {accuracy:.2f}')
    return accuracy

In [None]:
# Load the fine-tuned checkpoint into the EXISTING `model` in place.
# Re-creating the network here (`model = RawNet2(...)`) would leave the
# optimizer built earlier pointing at the old network's parameters, so the
# training cell below would compute gradients on the new model but never update
# it. `load_state_dict` copies into the current parameters, so the optimizer
# stays linked.
# Requires the model-construction and head-replacement cells above to have run.
# NOTE(review): the training loop only writes checkpoints every 5 epochs
# (TA_5, TA_10, ...), so TA_26.pt must come from an earlier run - confirm the path.
checkpoint_path = save_dir + 'models/TA_26.pt'
model.load_state_dict(torch.load(checkpoint_path))
model.cuda();

In [None]:
# Print the validation accuracy of the checkpoint loaded in the previous cell.
evaluate(model)

### Train

In [None]:
# Fine-tuning loop: one pass over `trainset_gen` per epoch, cross-entropy loss,
# learning rate decayed after every optimizer step, checkpoint every 5 epochs.
device = 'cuda'
for epoch in range(args.epoch):
    model.train()
    corr = 0   # correctly classified samples this epoch
    seen = 0   # samples processed this epoch
    n_batches = len(trainset_gen)
    # total has one extra tick for the end-of-epoch summary update below
    with tqdm(total = n_batches+1, leave=True) as pbar:
        epoch_loss = 0
        for m_batch, m_label in trainset_gen:
            m_batch, m_label = m_batch.to(device), m_label.to(device)
            output = model(m_batch, m_label)
            cce_loss = criterion['cce'](output, m_label)
            loss = cce_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # stepped per batch, so the decay schedule counts optimizer steps
            lr_scheduler.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss
            _, pred = torch.max(output, 1)
            corr += (m_label == pred).sum().item()
            seen += m_label.size(0)

            pbar.set_description(f'epoch: {epoch+1}, cce: {batch_loss:.3f}')
            pbar.update(1)

        # max(..., 1) keeps the summary from raising on an empty loader
        epoch_loss /= max(n_batches, 1)
        # Accuracy in percent over the samples seen. The previous code divided
        # the number of correct samples by the number of *batches*, which is
        # not an accuracy.
        epoch_acc = 100 * corr / max(seen, 1)
        pbar.set_description(f'epoch: {epoch+1}, avg loss: {epoch_loss:.3f}, acc: {epoch_acc:.2f}')
        pbar.update(1)

    if (epoch+1) % 5 == 0:
        torch.save(model.state_dict(), save_dir +  f'models/TA_{epoch+1}.pt')
        # the optimizer state is overwritten each time; only the latest is kept
        torch.save(optimizer.state_dict(), save_dir + 'models/best_opt_eval.pt')

### Hyperparameter Search

In [None]:
def evaluate(parametrization=None):
    """Score one BM25F parametrization and return the evaluation result.

    Args:
        parametrization: optional mapping with keys ``'B'`` and ``'K1'``;
            missing keys fall back to 0.524 and 3 respectively.

    Returns:
        Whatever ``pyTrecEval`` returns for the given settings.

    NOTE(review): this redefinition replaces the ``evaluate(model)`` helper
    defined earlier in this notebook. ``open_dir``, ``QueryParser``,
    ``qparser``, ``BM25F``, ``pyTrecEval``, ``TOPIC_FILE`` and ``QRELS_FILE``
    are not defined or imported anywhere in the visible notebook - this cell
    looks pasted from a different project; confirm it belongs here.
    """
    # None instead of a shared mutable `{}` default
    args = parametrization or {}
    myIndex = open_dir('index')
    myQueryParser = QueryParser("file_content", schema=myIndex.schema, group=qparser.OrGroup)
    mySearcher = myIndex.searcher(weighting=BM25F(B=args.get('B', 0.524), K1=args.get('K1', 3)))
    res = pyTrecEval(TOPIC_FILE, QRELS_FILE, myQueryParser, mySearcher)
    # The original returned `acc`, a name never assigned in this function, and
    # discarded `res`. NOTE(review): confirm `res` is the metric the search expects.
    return res