# MLP Training Module

In [1]:
import os

# Work from the project's source directory so the later `from model import ...`
# / `from features import ...` cells and the relative paths used below
# (`../models`, `logs/`) resolve — presumably those modules live there.
# The guard keeps this cell idempotent: re-running it in a live kernel must not
# try to descend into the directory a second time.
SRC_DIR = 'genre_classification_289a/src'
if not os.getcwd().replace(os.sep, '/').endswith(SRC_DIR):
    os.chdir(SRC_DIR)

In [2]:
import torch
import torch.nn as nn
import os
from model import STN, MLP
import torch.optim as optim

In [3]:
import numpy as np

In [4]:
# Use the first CUDA GPU when PyTorch can see one, otherwise fall back to CPU.
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(device)

cuda:0


In [5]:
from features import get_data_loaders, FramedFeatureDataset, FeatureDataset, DatasetSettings

### MLP Training

In [6]:
# Dataset to train on.
dataset_name = 'fma_medium'

# The MLP target is always genre.
genre = True  # forwarded to FramedFeatureDataset(genre=...); author's note: False if not genre STN
agfs = []     # forwarded to FramedFeatureDataset(agfs=...); author's candidates: 'subgenre', 'mfcc'

In [7]:
# Build the settings object for the selected dataset.
# NOTE(review): 'fma_metadata' is presumably the metadata directory/name —
# confirm against DatasetSettings.
settings = DatasetSettings(dataset_name, 'fma_metadata')
# Framed feature dataset, configured with the `agfs` list and `genre` flag.
dataset = FramedFeatureDataset(settings,  agfs=agfs, genre=genre)
# Report the number of genre classes and the per-genre counts.
print("Num genres: ", settings.num_genres)
print(settings.genre_counts)

Num genres:  16
Rock                   6911
Electronic             6110
Experimental           2207
Hip-Hop                2109
Folk                   1477
Instrumental           1280
Pop                    1129
International          1004
Classical               598
Old-Time / Historic     510
Jazz                    380
Country                 178
Soul-RnB                154
Spoken                  118
Blues                    72
Easy Listening           21
Name: genre_top, dtype: int64


In [8]:
def get_stn_path(dataset, target):
    """Return the relative path of the saved DCNN model for `dataset` and `target`."""
    return f'../models/DCNN_{dataset}_{target}'

# Load the pre-trained STNs whose intermediate activations feed the MLP.
# `map_location=device` lets checkpoints that were saved on a GPU load on a
# CPU-only machine, matching the cuda/cpu fallback used to pick `device`.
# NOTE(review): torch.load unpickles whole model objects here — only load
# checkpoints from a trusted source.
targets = ['subgenres']#, 'mfcc', 'genre']
stns = [
    torch.load(get_stn_path(dataset_name, target), map_location=device).to(device)
    for target in targets
]
# Feature width per STN layer, indexed by layer number (index 0 is a placeholder).
stn_layer_dims = [None, 16, 32, 64, 64, 128, 256, 256]

# Which STN layer to extract features from.
layer = 7

# Set up the MLP on the selected device. Its input is the concatenation of the
# chosen layer's features from every loaded STN; its output is one score per genre.
mlp_input_size = len(targets) * stn_layer_dims[layer]
mlp_output_size = settings.num_genres
mlp = MLP(mlp_input_size, mlp_output_size)
mlp.to(device)
mlp = nn.DataParallel(mlp)

## Training Parameters
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(mlp.parameters(), lr=0.001)
epochs = 5
batch_size = 64
valid_split = 0.2

trainloader, validloader = get_data_loaders(dataset, batch_size, valid_split)

In [9]:
import torch.nn.functional as F
from sklearn.metrics import f1_score

def validate():
    """Score the MLP on the validation set and return its micro-averaged F1.

    Every batch from `validloader` is passed through each STN in `stns` via
    `forward_intermediate(inputs, layer)`; the resulting features are
    concatenated along dim 1 and classified by `mlp`. The predicted class is
    the argmax over the MLP outputs.

    Side effect: the STNs and the MLP are switched to eval mode and are left
    that way. Call `mlp.train()` again before resuming training.

    Returns:
        The micro-averaged F1 score (sklearn `f1_score`) over the whole
        validation loader. NOTE(review): assuming each example carries exactly
        one genre label, micro-F1 is the same number as plain accuracy.
    """
    for stn in stns:
        stn.eval()
    mlp.eval()

    all_pred = []
    all_true = []

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for data in validloader:
            inputs, labels = data[0].to(device), data[1]['genre'].to(device)

            out_intermediate = [stn.module.forward_intermediate(inputs, layer) for stn in stns]
            out = mlp(torch.cat(out_intermediate, dim=1))

            all_pred.append(out.argmax(dim=1))
            all_true.append(labels)

    all_pred = torch.cat(all_pred)
    all_true = torch.cat(all_true)

    # sklearn works on CPU data, so move both tensors off the device first.
    return f1_score(all_true.cpu(), all_pred.cpu(), average='micro')

In [10]:
#Train it
%time
losses = []
accs = []
for stn in stns:
    stn.eval()


#f.write('Initial Validation F1: %.6f' % validate())

mlp.train()

for epoch in range(epochs):  # loop over the dataset multiple times
    
    print('Starting epoch %d' % (epoch+1))
    
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data[0].to(device), data[1]['genre'].to(device)  #data[1]['{argument for agf being trained}']
        
        input_mlp = None
        with torch.no_grad():
            out_intermediates = [stn.module.forward_intermediate(inputs, layer) for stn in stns]
            input_mlp = torch.cat(out_intermediates, dim=1)
        
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = mlp(input_mlp)
        
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 30 == 29:    # print every 30 mini-batches
            avg_loss = running_loss / 30
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, avg_loss))
            losses.append(avg_loss)
            running_loss = 0.0

print('Finished Training')

final_f1 = validate()
np.array(losses).tofile(f'logs/losses_MLP_{dataset_name}_stn_{"_".join(targets)}_layer_{layer}')
np.array(final_f1).tofile(f'logs/final_MLP_{dataset_name}_stn_{"_".join(targets)}_layer_{layer}')

CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 5.72 µs
Starting epoch 1
[1,    30] loss: 1.709
[1,    60] loss: 1.073
[1,    90] loss: 0.868
[1,   120] loss: 0.804
[1,   150] loss: 0.779
[1,   180] loss: 0.721
[1,   210] loss: 0.712
[1,   240] loss: 0.698
[1,   270] loss: 0.644
[1,   300] loss: 0.682
[1,   330] loss: 0.656
[1,   360] loss: 0.616
[1,   390] loss: 0.593
[1,   420] loss: 0.637
[1,   450] loss: 0.648
[1,   480] loss: 0.587
[1,   510] loss: 0.595
[1,   540] loss: 0.630
[1,   570] loss: 0.592
[1,   600] loss: 0.604
[1,   630] loss: 0.664
[1,   660] loss: 0.589
[1,   690] loss: 0.653
[1,   720] loss: 0.586
[1,   750] loss: 0.545
[1,   780] loss: 0.593
[1,   810] loss: 0.648
[1,   840] loss: 0.559
[1,   870] loss: 0.555
[1,   900] loss: 0.563
[1,   930] loss: 0.583
[1,   960] loss: 0.579
[1,   990] loss: 0.591
[1,  1020] loss: 0.601
[1,  1050] loss: 0.640
[1,  1080] loss: 0.546
[1,  1110] loss: 0.522
[1,  1140] loss: 0.599
[1,  1170] loss: 0.562
[1,  1200] loss: 0.557

### Save Model

In [11]:
# Save the whole (DataParallel-wrapped) MLP object, named after the dataset,
# the STN targets and the feature layer it was trained on.
stn_tag = "_".join(targets)
model_file = f'../models/MLP_{dataset_name}_stn_{stn_tag}_layer_{layer}'
torch.save(mlp, model_file)

## Load & Eval Model

In [12]:
# Reload the saved MLP. `map_location=device` lets a model that was saved on a
# GPU load on a CPU-only machine, matching the cuda/cpu fallback used to pick
# `device`.
# NOTE(review): torch.load unpickles a whole model object — only load files
# from a trusted source.
model_file = f'../models/MLP_{dataset_name}_stn_{"_".join(targets)}_layer_{layer}'
mlp = torch.load(model_file, map_location=device)

Load the logged training losses and the final validation F1 score:

In [13]:
# Read back the raw arrays that the training cell wrote to `logs/`.
log_tag = f'MLP_{dataset_name}_stn_{"_".join(targets)}_layer_{layer}'
losses = np.fromfile(f'logs/losses_{log_tag}')
final_f1 = np.fromfile(f'logs/final_{log_tag}')

In [14]:
# Display the final validation F1 reloaded from disk (np.fromfile returns it as an array).
final_f1

array([0.86016313])

## MLP Performance History
* SGM layer 4: 0.67209446