# MLP Testing Module (evaluated on the XXSMALL test set: 166 tracks that are not in fma_medium)

In [1]:
import os
# Move into the project's src/ directory so the local modules (model, features)
# import and the relative '../models/...' paths resolve.
# The guard keeps this cell idempotent: running the bare chdir a second time
# (when the kernel is already inside src/) would raise FileNotFoundError.
SRC_DIR = os.path.join('genre_classification_289a', 'src')
if not os.path.normpath(os.getcwd()).endswith(os.path.normpath(SRC_DIR)):
    os.chdir(SRC_DIR)

In [2]:
import torch
import torch.nn as nn
import os
from model import STN, MLP
import torch.optim as optim

In [3]:
import numpy as np

In [4]:
# Prefer the first CUDA GPU when one is visible, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [5]:
from features import get_data_loaders, FramedFeatureDataset, FeatureDataset, DatasetSettings

### MLP Testing

In [9]:
# Name of the *training* dataset; it is only used below to build the file names
# of the saved STN and MLP models (the test set is named by `test_dataset_name`).
train_dataset_name = 'fma_medium'

def get_stn_path(dataset, target):
    """Return the relative path of the saved DCNN model for `dataset` and `target`.

    Args:
        dataset: dataset name embedded in the file name (e.g. 'fma_medium').
        target: name of the target the network was trained on (e.g. 'genre').

    Returns:
        A path string of the form '../models/DCNN_<dataset>_<target>'.
    """
    path_template = '../models/DCNN_{}_{}'
    return path_template.format(dataset, target)

targets = ['subgenres', 'mfcc', 'genre'] # STNs that this MLP was trained with
layer = 7  # layer index passed to each STN's forward_intermediate() during evaluation

# load STNs
# map_location=device lets checkpoints that were saved from a GPU be restored on
# a CPU-only machine, so the cuda/cpu fallback used to pick `device` really works.
# NOTE: these files are whole pickled model objects -- only load trusted files.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True;
# if loading fails there, pass weights_only=False explicitly -- confirm against
# the installed version.
stns = [
    torch.load(get_stn_path(train_dataset_name, target), map_location=device).to(device)
    for target in targets
]
stn_layer_dims = [None, 16, 32, 64, 64, 128, 256, 256]

# load the MLP (and place it on the same device as the inputs it will receive)
model_file = f'../models/MLP_{train_dataset_name}_stn_{"_".join(targets)}_layer_{layer}'
mlp = torch.load(model_file, map_location=device).to(device)

test_dataset_name = 'fma_medium_testset'

In [33]:
# Describe the held-out test set ('fma_metadata' is presumably the metadata
# location -- TODO confirm against DatasetSettings).
settings = DatasetSettings(test_dataset_name, 'fma_metadata')
# extract data for genre targets, no AGF targets (those only used for training AGF STNs)
dataset = FramedFeatureDataset(settings,  agfs=[], genre=True)
# Print the number of genres, the per-genre counts and the genre -> integer code
# mapping of this dataset; the mapping is needed to interpret predictions later.
print("Num genres: ", settings.num_genres)
print(settings.genre_counts)
print("genre coding: ", settings.coded_genres)

# Reference table kept for comparison (presumably the genre counts/encoding of
# the 16-genre training set -- TODO confirm).
# "a->b" means class index a in the 16-genre encoding corresponds to index b in
# this test set's 8-genre encoding (the same mapping genre_16_encoding_to_genre_8
# applies); rows without "->" are genres that do not occur in the test set.
# Num genres:  16
# 0->7 Rock                   6911
# 1->6 Electronic             6110
# 2->5 Experimental           2207
# 3->3 Hip-Hop                2109
# 4->1 Folk                   1477
# 5->2 Instrumental           1280
# 6->0 Pop                    1129
# 7->4 International          1004
# 8 Classical               598
# 9 Old-Time / Historic     510
# 10 Jazz                    380
# 11 Country                 178
# 12 Soul-RnB                154
# 13 Spoken                  118
# 14 Blues                    72
# 15 Easy Listening           21

Num genres:  8
Pop              45
Folk             27
Instrumental     26
Hip-Hop          25
International    14
Experimental     12
Electronic        9
Rock              8
Name: genre_top, dtype: int64
genre coding:  {'Pop': 0, 'Folk': 1, 'Instrumental': 2, 'Hip-Hop': 3, 'International': 4, 'Experimental': 5, 'Electronic': 6, 'Rock': 7}


In [19]:
## Data loading
# get_data_loaders() hands back a (train, valid) pair; sending virtually all of
# the data to the validation side is a hack so other code doesn't need refactoring.
batch_size = 64
valid_split = 0.98

## Training Parameters (@TODO remove these)
# Leftovers from the training notebook; nothing is trained in this notebook.
epochs = 5
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(mlp.parameters(), lr=0.001)

trainloader, validloader = get_data_loaders(dataset, batch_size, valid_split)

In [40]:
import torch.nn.functional as F
from sklearn.metrics import f1_score

def genre_16_encoding_to_genre_8(predictions):
    """Translate class indices of the 16-genre model into the 8-genre dataset's codes.

    Args:
        predictions: iterable of class indices (0..15) in the 16-genre encoding.

    Returns:
        A list with one entry per prediction: the matching 8-genre code (0..7),
        or 8 for any of the genres 8..15. Those genres do not occur in the
        8-genre dataset, so 8 always counts as a misclassification.
    """
    no_match = 8
    shared_genres = [7, 6, 5, 3, 1, 2, 0, 4]  # 16-genre index -> 8-genre code
    lookup = shared_genres + [no_match] * 8

    converted = []
    for code in predictions:
        converted.append(lookup[code])
    return converted

def validate(show_predictions=True):
    """Evaluate the STN + MLP stack on `validloader` and report genre metrics.

    Reads the notebook globals `stns`, `mlp`, `validloader`, `layer`, `device`
    and `settings`. Each batch is passed through every STN up to `layer`, the
    intermediate outputs are concatenated along dim 1 and classified by the MLP.
    The MLP's predictions (16-genre encoding) are converted to the dataset's
    8-genre encoding before being compared with the labels.

    No loss is computed: the logits use the 16-genre encoding while the labels
    use the 8-genre encoding, so a cross-entropy between them is meaningless.

    Args:
        show_predictions: when True (the default), print one
            "pred: <16-code>-><8-code>, true: <label>" line per sample.

    Returns:
        The micro-averaged F1 score between the labels and the converted
        predictions.

    Raises:
        ValueError: if `validloader` yields no batches.
    """
    num_classes = len(settings.coded_genres)

    for stn in stns:
        stn.eval()
    mlp.eval()

    batch_preds = []
    batch_labels = []
    with torch.no_grad():
        for data in validloader:
            inputs = data[0].to(device)

            out_intermediate = [stn.module.forward_intermediate(inputs, layer) for stn in stns]
            out = mlp(torch.cat(out_intermediate, dim=1))

            batch_preds.append(out.argmax(dim=1).cpu())
            batch_labels.append(data[1]['genre'].cpu())

    if not batch_preds:
        raise ValueError("validloader produced no batches; nothing to evaluate")

    # Plain Python ints from here on: cheaper to iterate than 0-d tensors.
    all_pred = torch.cat(batch_preds).tolist()
    all_true = torch.cat(batch_labels).tolist()
    all_pred_converted = genre_16_encoding_to_genre_8(all_pred)

    # Per-class accuracy: the share of samples of each true genre that were
    # predicted correctly.
    correct = [0] * num_classes
    total = [0] * num_classes
    for true_label, pred_label in zip(all_true, all_pred_converted):
        total[true_label] += 1
        correct[true_label] += int(true_label == pred_label)
    # A genre without any sample gets NaN instead of raising ZeroDivisionError.
    accuracies = [c / n if n else float('nan') for c, n in zip(correct, total)]

    print("Genre lookup:", settings.coded_genres)
    # These values are per-genre accuracies, not F1 scores.
    print("Accuracy by genre:", accuracies)

    if show_predictions:
        for raw, converted, true_label in zip(all_pred, all_pred_converted, all_true):
            print(f"pred: {raw}->{converted}, true: {true_label}")

    return f1_score(all_true, all_pred_converted, average='micro')

# Run the evaluation once and report the overall (micro-averaged) F1 score.
validation_f1 = validate()
print('Validation f1 score: {}'.format(validation_f1))

Genre lookup: {'Pop': 0, 'Folk': 1, 'Instrumental': 2, 'Hip-Hop': 3, 'International': 4, 'Experimental': 5, 'Electronic': 6, 'Rock': 7}
F1s by genre: [0.11326860841423948, 0.4891304347826087, 0.2388888888888889, 0.672514619883041, 0.2708333333333333, 0.5185185185185185, 0.6190476190476191, 0.48148148148148145]
pred: 1->6, true: 0
pred: 1->6, true: 2
pred: 3->3, true: 2
pred: 0->7, true: 0
pred: 1->6, true: 3
pred: 1->6, true: 3
pred: 6->0, true: 0
pred: 0->7, true: 0
pred: 0->7, true: 6
pred: 3->3, true: 3
pred: 6->0, true: 0
pred: 1->6, true: 4
pred: 1->6, true: 4
pred: 3->3, true: 3
pred: 1->6, true: 7
pred: 0->7, true: 0
pred: 3->3, true: 3
pred: 4->1, true: 1
pred: 7->4, true: 4
pred: 0->7, true: 2
pred: 6->0, true: 0
pred: 3->3, true: 0
pred: 14->8, true: 0
pred: 7->4, true: 4
pred: 5->2, true: 1
pred: 1->6, true: 2
pred: 0->7, true: 0
pred: 1->6, true: 1
pred: 0->7, true: 0
pred: 1->6, true: 2
pred: 7->4, true: 4
pred: 6->0, true: 0
pred: 7->4, true: 0
pred: 0->7, true: 5
pred: 2