In [None]:
import json
import math
import matplotlib.pyplot as plt
import numpy as np
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from pytorch_model_summary import summary
from sklearn.model_selection import train_test_split
from tqdm import tqdm


def load_data(data_path, test_size, validation_size):
    """Read the dataset from a json file and split it into train/validation/test sets.

        :param data_path (str): Path to json file containing data
        :param test_size: Forwarded to ``train_test_split`` as ``test_size`` for the first
            (train+validation vs. test) split
        :param validation_size: Forwarded to ``train_test_split`` as ``test_size`` for the
            second split, i.e. it is applied to what is left after the test set is removed
        :return: X_train, X_validation, X_test, y_train, y_validation, y_test
    """

    print("Loading data...")

    with open(data_path, "r") as json_file:
        raw = json.load(json_file)

    # inputs come from the "mfcc" entry, targets from the "labels" entry
    features = torch.Tensor(raw["mfcc"])
    # argmax along axis 1 reduces every label row to a single class index
    # (assumes the rows are one-hot / score vectors -- TODO confirm against the data file)
    label_rows = torch.Tensor(raw["labels"])
    targets = label_rows.argmax(axis=1).type(torch.LongTensor)

    # first carve out the test set, then split the remainder into train and validation
    X_rest, X_test, y_rest, y_test = train_test_split(features, targets, test_size=test_size)
    X_train, X_validation, y_train, y_validation = train_test_split(
        X_rest, y_rest, test_size=validation_size
    )

    print("Data succesfully loaded!")

    return X_train, X_validation, X_test, y_train, y_validation, y_test

    
class Model(nn.Module):
    """LSTM sequence classifier with a small Keras-like ``compile``/``fit``/``test`` API.

    A stacked LSTM runs over the input sequence; its per-timestep outputs are
    flattened and passed through a hidden linear layer (ReLU + dropout) and a
    linear output layer.

    With the default arguments the network is the same as the previously
    hard-coded one: inputs of shape ``(batch, 130, 13)`` and 10 output classes.

        :param input_size: Number of features per timestep
        :param hidden_size: LSTM hidden size
        :param num_layers: Number of stacked LSTM layers
        :param seq_len: Number of timesteps per input; fixes the size of the flattened LSTM output
        :param num_classes: Number of output classes
        :param dropout: Dropout probability applied after the hidden linear layer
    """

    def __init__(self, input_size=13, hidden_size=64, num_layers=2,
                 seq_len=130, num_classes=10, dropout=0.2):

        super(Model, self).__init__()
        self.lstm1 = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                             num_layers=num_layers, batch_first=True)
        # fc1 consumes the LSTM output of every timestep, flattened per sample
        self.fc1 = nn.Linear(seq_len * hidden_size, 64)
        self.fc2 = nn.Linear(64, num_classes)
        self.dropout = nn.Dropout(dropout)

    def _device(self):
        """Returns the device the model's parameters live on."""
        return next(self.parameters()).device

    def forward(self, x):
        """Computes class scores for a batch of sequences.
            :param x: Tensor of shape (batch, seq_len, input_size)
            :return: Raw, unnormalised class scores (logits) of shape (batch, num_classes).
                     Apply ``F.softmax(..., dim=1)`` to obtain probabilities.
        """

        # lstm: no initial state is passed, nn.LSTM then starts from zeros on x's device.
        # The first return value holds the outputs of the last layer for every timestep.
        lstm_out, _ = self.lstm1(x)

        # 1st linear
        x = F.relu(self.fc1(torch.flatten(lstm_out, 1, -1)))
        x = self.dropout(x)

        # 2nd linear = output layer
        # Logits are returned on purpose: nn.CrossEntropyLoss applies log-softmax itself,
        # so a softmax here would be applied twice and distort loss and gradients.
        return self.fc2(x)

    def compile(self, optimizer, loss_function):
        """Attaches optimizer and loss function to the model and prints the architecture.
            :param optimizer: Optimizer whose ``step()`` is called after every batch in ``fit``
            :param loss_function: Callable ``loss_function(outputs, targets)``; it receives the
                                  logits returned by ``forward``
        """
        # NOTE(review): recent PyTorch releases also define nn.Module.compile(); this
        # override shadows it -- confirm that is acceptable for this project.

        self.model_name = f"model-{int(time.time())}"
        self.optimizer = optimizer
        self.loss_function = loss_function
        print(self)

    def fit(self, X_train, y_train, validation_data, epochs, batch_size, log=True):
        """Trains the model with mini-batches taken in order from X_train/y_train.
            :param X_train: Training inputs
            :param y_train: Training targets (class indices)
            :param validation_data: Pair (X_validation, y_validation) evaluated after each epoch
            :param epochs: Number of passes over the training data
            :param batch_size: Number of samples per batch
            :param log: If True, collect training/validation metrics per epoch
            :return: dict with keys 'acc', 'loss', 'val_acc', 'val_loss'. Each list starts with
                     a chance-level baseline entry, followed by one entry per epoch (only
                     when ``log`` is True).
        """

        # declare history with chance-level values before training
        # (0.1 and ln(10) = 2.3026 for the default 10 classes)
        num_classes = self.fc2.out_features
        chance_acc = round(1.0 / num_classes, 4)
        chance_loss = round(math.log(num_classes), 4)
        history = {'acc': [chance_acc], 'loss': [chance_loss],
                   'val_acc': [chance_acc], 'val_loss': [chance_loss]}

        num_samples = len(X_train)
        num_batches = math.ceil(num_samples / batch_size)
        device = self._device()

        # iteration over epochs
        for epoch in tqdm(range(epochs)):
            # test() leaves the model in eval mode, so training mode has to be restored
            # every epoch, independent of whether metrics are logged
            self.train()

            # reset metrics after each epoch
            running_loss = 0.0
            correct = 0

            # iteration over batches
            for start in range(0, num_samples, batch_size):
                # batch data and load to device
                X_batch = X_train[start:start + batch_size].to(device)
                y_batch = y_train[start:start + batch_size].to(device)

                # forward and backward pass
                self.zero_grad()
                outputs = self(X_batch)
                loss = self.loss_function(outputs, y_batch)
                loss.backward()
                self.optimizer.step()

                # calculate running statistics
                if log:
                    correct += (outputs.argmax(dim=1) == y_batch).sum().item()
                    running_loss += loss.item()

            # add stats to history
            if log:
                val_acc, val_loss = self.test(validation_data[0], validation_data[1])
                # accuracy is taken over all samples so a short last batch is not over-weighted
                history['acc'].append(round(correct / num_samples, 4))
                history['loss'].append(round(running_loss / num_batches, 4))
                history['val_acc'].append(round(float(val_acc), 4))
                history['val_loss'].append(round(float(val_loss), 4))

        return history

    def test(self, X, y, out=False):
        """Evaluates the model on (X, y) in a single forward pass.

        Switches the model to eval mode and leaves it there.

            :param X: Inputs
            :param y: Targets (class indices)
            :param out: If True, print accuracy and loss
            :return: (acc, loss), both rounded to 4 decimals
        """

        self.eval()
        device = self._device()

        with torch.no_grad():
            targets = y.to(device)
            outputs = self(X.to(device))
            correct = (outputs.argmax(dim=1) == targets).sum().item()
            acc = round(correct / len(targets), 4)
            loss = round(float(self.loss_function(outputs, targets)), 4)

        if out:
            print(f"Test acc: {acc} Test loss: {loss}")

        return acc, loss


def plot_history(history):
    """Draws two stacked panels: accuracy on top, error below, each with train and validation curves.
        :param history: Training history of model; the keys "acc", "val_acc", "loss" and "val_loss" are read
        :return:
    """

    figure, (acc_ax, err_ax) = plt.subplots(2)

    # one row per panel: axis, train key, validation key, train label, validation label, y label, legend position
    panels = (
        (acc_ax, "acc", "val_acc", "train accuracy", "val accuracy", "Accuracy", "lower right"),
        (err_ax, "loss", "val_loss", "train error", "val error", "Error", "upper right"),
    )

    for ax, train_key, val_key, train_label, val_label, y_label, legend_loc in panels:
        ax.plot(history[train_key], label=train_label)
        ax.plot(history[val_key], label=val_label)
        ax.set_ylabel(y_label)
        ax.legend(loc=legend_loc)
        ax.set_title("")

    # only the bottom panel carries the x-axis label
    err_ax.set_xlabel("Epoch")

    plt.show()
    
    
    
if __name__ == "__main__":
    # path to json file that stores MFCCs and genre labels for each processed segment
    DATA_PATH = "data/data_20.json"

    # pick the GPU when one is available, otherwise fall back to the CPU
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        print("Running on the GPU")

    else:
        device = torch.device("cpu")
        print("Running on the CPU")

    # Only load the dataset if it is not already in the kernel, so re-running this
    # cell is cheap. Consequence: a re-run keeps the previous random split; restart
    # the kernel (or `del X_train`) to draw a new one.
    try:
        X_train

    except NameError:
        X_train, X_validation, X_test, y_train, y_validation, y_test = load_data(
            DATA_PATH,
            test_size=0.10,
            validation_size=0.10)

    model = Model().to(device)

    # One parameter group per layer; only the output layer gets weight decay.
    # (Model has no `lstm2` attribute, so no group is created for it.)
    optimizer = optim.Adam([
        {'params': model.lstm1.parameters()},
        {'params': model.fc1.parameters()},
        {'params': model.fc2.parameters(), 'weight_decay': 1e-3}
    ], lr=1e-4)
    model.compile(optimizer, nn.CrossEntropyLoss())

    # layer-by-layer summary for a dummy batch of 5 inputs with 130 timesteps x 13 features
    print(summary(model, torch.rand(5, 130, 13).to(device)), sep='\n')

In [None]:
# Train the compiled model and record per-epoch metrics on the training and validation sets
fit_options = dict(
    validation_data=(X_validation, y_validation),
    epochs=50,
    batch_size=64,
    log=True,
)
history = model.fit(X_train, y_train, **fit_options)

# Final check on the held-out test set, then plot the learning curves
model.test(X_test, y_test, out=True)
plot_history(history)

# Best values seen during training (each history list also contains its pre-training start value)
val_acc_curve, train_acc_curve = history['val_acc'], history['acc']
val_loss_curve, train_loss_curve = history['val_loss'], history['loss']
print("Max_val_acc:", max(val_acc_curve), "  Max_train_acc:", max(train_acc_curve))
print("Min_val_loss:", min(val_loss_curve), "  Min_train_loss:", min(train_loss_curve))