## Training Module

In [1]:
import os
os.chdir('genre_classification_289a/src')

In [2]:
import torch
import torch.nn as nn
import os
from model import STN, MLP
import torch.optim as optim

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [4]:
from features import get_data_loaders, genre_counts, FramedFeatureDataset, FeatureDataset

Pop              1000
International    1000
Instrumental     1000
Folk             1000
Rock              999
Experimental      999
Electronic        999
Hip-Hop           997
Name: genre_top, dtype: int64
{'Pop': 0, 'International': 1, 'Instrumental': 2, 'Folk': 3, 'Rock': 4, 'Experimental': 5, 'Electronic': 6, 'Hip-Hop': 7}


## DCNN Training

In [5]:
num_genres = len(genre_counts)

In [6]:
#STN targets
agfs = []
genre = True

#dataset
dataset_name = 'fma_small'

In [8]:
stn = STN(len(genre_counts))
stn.to(device)
stn = nn.DataParallel(stn)

In [9]:
## Training Parameters
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(stn.parameters(), lr=0.001)
epochs = 30
batch_size = 64
valid_split = 0.2

In [10]:
dataset = FramedFeatureDataset(agfs=agfs, genre=genre)

In [11]:
trainloader, validloader = get_data_loaders(dataset, batch_size, valid_split)

In [12]:
import torch.nn.functional as F
from sklearn.metrics import f1_score

def validate(stn, label_name):
    
    with torch.no_grad():
        stn.eval()
        
        all_pred = []
        all_true = []
        
        for i, data in enumerate(validloader, 0):
            inputs, labels = data[0].to(device), data[1][label_name].to(device)
            
            out = stn(inputs)
            loss = F.cross_entropy(out, labels)
            
            all_pred.append(out.argmax(dim=1))
            all_true.append(labels)
            
        all_pred = torch.cat(all_pred)
        all_true = torch.cat(all_true)
        
        curr_f1 = f1_score(all_true.cpu(), all_pred.cpu(), average='micro')
        
        print('Validation f1 score: {}'.format(curr_f1))

In [13]:
#Train it
for epoch in range(epochs):  # loop over the dataset multiple times
    
    print('Starting epoch', epoch + 1)
    
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data[0].to(device), data[1]['genre'].to(device)
        
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = stn(inputs)
        
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        

        # print statistics
        running_loss += loss.item()
        if i % 50 == 49:    # print every 30 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 30))
            running_loss = 0.0
            
    if epoch % 5 == 4:
        validate(stn, 'genre')
        stn.train()

print('Finished Training')

Starting epoch 1
[1,    50] loss: 3.127
[1,   100] loss: 2.726
[1,   150] loss: 2.606
[1,   200] loss: 2.527
[1,   250] loss: 2.459
[1,   300] loss: 2.350
[1,   350] loss: 2.364
[1,   400] loss: 2.305
[1,   450] loss: 2.217
[1,   500] loss: 2.240
[1,   550] loss: 2.188
[1,   600] loss: 2.243
[1,   650] loss: 2.220
[1,   700] loss: 2.179
Starting epoch 2
[2,    50] loss: 2.090
[2,   100] loss: 2.071
[2,   150] loss: 2.088
[2,   200] loss: 2.020
[2,   250] loss: 2.041
[2,   300] loss: 2.024
[2,   350] loss: 1.974
[2,   400] loss: 2.039
[2,   450] loss: 1.999
[2,   500] loss: 1.987
[2,   550] loss: 1.917
[2,   600] loss: 1.976
[2,   650] loss: 1.997
[2,   700] loss: 1.866
Starting epoch 3
[3,    50] loss: 1.891
[3,   100] loss: 1.842
[3,   150] loss: 1.937
[3,   200] loss: 1.826
[3,   250] loss: 1.921
[3,   300] loss: 1.831
[3,   350] loss: 1.818
[3,   400] loss: 1.844
[3,   450] loss: 1.817
[3,   500] loss: 1.808
[3,   550] loss: 1.837
[3,   600] loss: 1.796
[3,   650] loss: 1.737
[3,   

## Load Module

In [15]:
model_file = '../models/DCNN_{}_{}'.format(dataset_name, 'genre')
torch.save(stn, model_file)

In [16]:
model_file = '../models/DCNN_{}_{}'.format(dataset_name, 'genre')
model = torch.load(model_file)
model

DataParallel(
  (module): STN(
    (layer1): Sequential(
      (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (1): ELU(alpha=1.0)
      (2): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)
    )
    (layer2): Sequential(
      (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ELU(alpha=1.0)
      (3): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
      (4): Dropout(p=0.1, inplace=False)
    )
    (layer3): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ELU(alpha=1.0)
      (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    )
    (layer4): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, mome

In [17]:
validate(model, 'genre')

Validation f1 score: 0.8896434634974533
