# In[ ]:
import itertools
import json
import librosa.display
import math
import matplotlib.pyplot as plt
import numpy as np
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.nn import ReLU
from torch.autograd import Variable
from tqdm import tqdm

# Select the execution device once, at import time: the first CUDA GPU when
# CUDA is available, otherwise the CPU. The chosen device is reported on stdout.
if not torch.cuda.is_available():
    device = torch.device("cpu")
    print("Running on the CPU")
else:
    device = torch.device("cuda:0")
    print("Running on the GPU")
    

def load_data(data_path, test_size, validation_size, scale=False):
    """Load a spectrogram dataset from a pickled .npy file and split it.

    The file must contain a pickled dict (read via ``np.load(...).item()``)
    with the keys ``"spectrograms"`` and ``"labels"``. The data is first split
    into train/test, then the train part is split again into train/validation.

    :param data_path (str): Path to the .npy file containing the data.
    :param test_size: Held-out test fraction or count; passed straight to
        ``sklearn.model_selection.train_test_split``.
    :param validation_size: Validation fraction or count of the *remaining*
        training samples; also passed to ``train_test_split``.
    :param scale (bool): If True, min/max-scale the inputs with
        ``scale_input``. The scaling statistics are computed on the training
        split only and then reused for the validation and test splits.
    :return: ``(X_train, X_validation, X_test, y_train, y_validation, y_test)``,
        followed by ``scale_min, scale_max`` when ``scale`` is True. Each X
        gets a channel axis inserted at dim 1 (``(N, H, W)`` -> ``(N, 1, H, W)``).
    """
    print("Loading data...")

    # NOTE(security): allow_pickle=True can execute arbitrary code from the
    # file; only load trusted datasets.
    data = np.load(data_path, allow_pickle=True).item()

    # convert lists to float32 tensors
    X = torch.Tensor(data["spectrograms"])
    y = torch.Tensor(data["labels"])

    # create train/test split, then carve validation out of the train part
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size)
    X_train, X_validation, y_train, y_validation = train_test_split(
        X_train, y_train, test_size=validation_size)

    if scale:
        # Fit the scaling on the training split only, so validation/test
        # statistics do not leak into preprocessing; reuse the same values
        # for the held-out splits.
        X_train, scale_min, scale_max = scale_input(X_train)
        X_validation, _, _ = scale_input(X_validation, scale_min, scale_max)
        X_test, _, _ = scale_input(X_test, scale_min, scale_max)

    # add a single input channel for Conv2d: (N, H, W) -> (N, 1, H, W)
    X_train = X_train.unsqueeze(1)
    X_validation = X_validation.unsqueeze(1)
    X_test = X_test.unsqueeze(1)

    print("Data succesfully loaded!")

    if scale:
        return X_train, X_validation, X_test, y_train, y_validation, y_test, scale_min, scale_max
    return X_train, X_validation, X_test, y_train, y_validation, y_test


def scale_input(x, scale_min=None, scale_max=None):
    """Linearly map ``x`` so that ``scale_min`` -> -1 and ``scale_min + scale_max`` -> 1.

    ``scale_max`` is the *range* of the shifted data (maximum after
    subtracting ``scale_min``), not the raw maximum of ``x``. If the data is
    constant and ``scale_max`` is not given, this divides by zero.

    :param x: Tensor or array to scale.
    :param scale_min: Offset to subtract; defaults to ``x.min()``.
    :param scale_max: Divisor applied after the shift; defaults to the maximum
        of ``x - scale_min``.
    :return: ``(scaled_x, scale_min, scale_max)`` so the same transform can be
        reapplied to other data or inverted with ``unscale_input``.
    """
    offset = x.min() if scale_min is None else scale_min
    shifted = x - offset
    span = shifted.max() if scale_max is None else scale_max
    # [0, 1] -> [-0.5, 0.5] -> [-1, 1]
    scaled = (shifted / span - 0.5) * 2
    return scaled, offset, span

def unscale_input(x, scale_min, scale_max):
    """Invert ``scale_input``: map [-1, 1] back to [scale_min, scale_min + scale_max].

    :param x: Scaled tensor/array (or scalar).
    :param scale_min: Offset returned by ``scale_input``.
    :param scale_max: Range returned by ``scale_input``.
    :return: ``x`` in the original units.
    """
    # [-1, 1] -> [0, 1], then stretch by the range and shift by the offset
    return (x / 2 + 0.5) * scale_max + scale_min

    
class Model(nn.Module):
    """Four-stage CNN that classifies single-channel spectrograms into 10 classes.

    Each stage is Conv2d -> ReLU -> MaxPool2d -> BatchNorm2d, followed by two
    fully connected layers. ``forward`` returns softmax probabilities. ``fit``
    and ``test`` compute the loss on the raw logits instead, because
    ``nn.CrossEntropyLoss`` applies log-softmax internally. Feeding it
    probabilities would apply softmax twice and weaken the gradients.

    Call ``compile`` before ``fit``, ``test`` or ``get_predictions``. It moves
    the model to the module-level ``device`` and creates ``optimizer`` and
    ``loss_function``.
    """

    def __init__(self):
        super(Model, self).__init__()
        # NOTE: keep conv1 registered first. Code that walks self._modules and
        # hooks the "first layer" relies on this order.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3), padding=0)
        self.norm1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), padding=0)
        self.norm2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(2, 2), padding=0)
        self.norm3 = nn.BatchNorm2d(32)
        self.conv4 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(2, 2), padding=0)
        self.norm4 = nn.BatchNorm2d(32)
        # 32 channels x a 32 x 8 spatial map after the last stage. For example,
        # a 513 x 130 input produces 32 x 8; other input sizes need a different
        # in_features here.
        self.fc1 = nn.Linear(32 * 32 * 8, 128)
        self.fc2 = nn.Linear(128, 10)
        self.max_pool_1 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=1)
        self.max_pool_2 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=1)
        # a single ReLU module is reused after every layer
        self.act_func = nn.ReLU()
        self.dropout = nn.Dropout(0.2)

    def _logits(self, x):
        """Return the unnormalised class scores (logits) for a batch ``x``."""
        # 1st convolutional
        x = self.act_func(self.conv1(x))
        x = self.max_pool_1(x)
        x = self.norm1(x)

        # 2nd convolutional
        x = self.act_func(self.conv2(x))
        x = self.max_pool_1(x)
        x = self.norm2(x)

        # 3rd convolutional
        x = self.act_func(self.conv3(x))
        x = self.max_pool_2(x)
        x = self.norm3(x)

        # 4th convolutional
        x = self.act_func(self.conv4(x))
        x = self.max_pool_2(x)
        x = self.norm4(x)

        # 1st linear
        x = self.act_func(self.fc1(torch.flatten(x, 1, -1)))
        x = self.dropout(x)

        # 2nd linear = output layer (no softmax here)
        return self.fc2(x)

    def forward(self, x):
        """Return class probabilities (softmax over dim 1) for a batch ``x``."""
        return F.softmax(self._logits(x), dim=1)

    def compile(self, model_name=None, model_state=None, optimizer_state=None):
        """Move the model to ``device`` and create its optimizer and loss.

        :param model_name: Name stored on the model (``save_model`` uses it as
            the file name); defaults to ``"model-<unix timestamp>"``.
        :param model_state: Optional ``state_dict`` to load into the model.
        :param optimizer_state: Optional optimizer ``state_dict`` to restore.
        """
        self.to(device)

        self.model_name = model_name if model_name else f"model-{int(time.time())}"

        if model_state:
            self.load_state_dict(model_state)

        # Adam with per-layer weight decay; BatchNorm parameters are not decayed.
        self.optimizer = optim.Adam([
                    {'params': self.conv1.parameters(), 'weight_decay': 1e-4},
                    {'params': self.norm1.parameters()},
                    {'params': self.conv2.parameters(), 'weight_decay': 1e-4},
                    {'params': self.norm2.parameters()},
                    {'params': self.conv3.parameters(), 'weight_decay': 1e-3},
                    {'params': self.norm3.parameters()},
                    {'params': self.conv4.parameters(), 'weight_decay': 1e-3},
                    {'params': self.norm4.parameters()},
                    {'params': self.fc1.parameters(), 'weight_decay': 1e-4},
                    {'params': self.fc2.parameters(), 'weight_decay': 1e-3}
                ], lr=1e-3)

        if optimizer_state:
            self.optimizer.load_state_dict(optimizer_state)

        # expects raw logits (it applies log-softmax itself)
        self.loss_function = nn.CrossEntropyLoss()

    def fit(self, X_train, y_train, validation_data, epochs, batch_size, log=True, history=None):
        """Train for ``epochs`` epochs on mini-batches of ``batch_size``.

        Batches are taken in the given order; the data is not shuffled here.

        :param X_train: Input tensor, batched along dim 0.
        :param y_train: Class labels (any numeric dtype; cast to long).
        :param validation_data: ``(X_val, y_val)``, evaluated with ``test``
            after each epoch when ``log`` is True.
        :param epochs: Number of passes over the data.
        :param batch_size: Samples per optimisation step.
        :param log: When True, per-epoch train/validation accuracy and loss
            (sample-weighted means) are appended to ``history``.
        :param history: History dict to extend in place. When None, a new one
            is created, seeded with 0.1 accuracy and 2.3026 (= ln 10) loss.
        :return: The history dict.
        """
        if history is None:
            history = {'acc': [0.1], 'loss': [2.3026], 'val_acc': [0.1], 'val_loss': [2.3026]}
        num_samples = len(X_train)
        denom = max(num_samples, 1)

        for _ in range(epochs):
            # test() switches to eval mode, so re-enable training mode each epoch
            self.train()
            running_loss = 0.0
            running_correct = 0

            for i in tqdm(range(0, num_samples, batch_size)):
                # batch data and load to device
                X_batch = X_train[i:i + batch_size].to(device)
                y_batch = y_train[i:i + batch_size].long().to(device)

                # forward and backward pass; the loss is computed on logits
                self.optimizer.zero_grad()
                logits = self._logits(X_batch)
                loss = self.loss_function(logits, y_batch)
                loss.backward()
                self.optimizer.step()

                if log:
                    # weight by batch size so a short final batch is not over-counted
                    running_correct += (logits.argmax(dim=1) == y_batch).sum().item()
                    running_loss += loss.item() * len(y_batch)

            if log:
                running_acc = running_correct / denom
                running_loss /= denom
                running_val_acc, running_val_loss = self.test(validation_data[0], validation_data[1])
                history['acc'].append(round(float(running_acc), 4))
                history['loss'].append(round(float(running_loss), 4))
                history['val_acc'].append(round(float(running_val_acc), 4))
                history['val_loss'].append(round(float(running_val_loss), 4))

        print("Training finished...")

        return history

    def test(self, X, y, batch_size=32, out=False):
        """Evaluate accuracy and mean loss on ``(X, y)`` without gradients.

        Leaves the model in eval mode.

        :param X: Input tensor, batched along dim 0.
        :param y: Class labels (any numeric dtype; cast to long).
        :param batch_size: Evaluation batch size.
        :param out: When True, print the results.
        :return: ``(accuracy, loss)``, each a sample-weighted mean rounded to 4 decimals
            (``(0.0, 0.0)`` for empty input).
        """
        num_samples = len(X)
        total_loss = 0.0
        total_correct = 0
        self.eval()

        with torch.no_grad():
            for start in range(0, num_samples, batch_size):
                X_batch = X[start:start + batch_size].to(device)
                y_batch = y[start:start + batch_size].long().to(device)

                logits = self._logits(X_batch)
                total_loss += self.loss_function(logits, y_batch).item() * len(y_batch)
                total_correct += (logits.argmax(dim=1) == y_batch).sum().item()

        denom = max(num_samples, 1)
        t_acc = round(float(total_correct / denom), 4)
        t_loss = round(float(total_loss / denom), 4)

        if out:
            print(f"Test acc: {t_acc} Test loss: {t_loss}")

        return t_acc, t_loss

    def get_predictions(self, X, batch_size=32):
        """Return softmax class probabilities for every sample of ``X`` on the CPU.

        :param X: Input tensor, batched along dim 0.
        :param batch_size: Inference batch size.
        :return: Tensor of shape ``(len(X), 10)``, or an empty tensor when ``X`` is empty.
        """
        self.eval()
        chunks = []

        with torch.no_grad():
            for start in range(0, len(X), batch_size):
                X_batch = X[start:start + batch_size].to(device)
                # move each chunk to the CPU right away; concatenate once at the end
                chunks.append(self(X_batch).cpu())

        if not chunks:
            return torch.Tensor()
        return torch.cat(chunks, dim=0)


def save_model(model, path):
    """Write a checkpoint of ``model`` to ``<path>/<model.model_name>.pth``.

    The checkpoint is a dict with the keys ``'name'``, ``'model'`` (the model
    ``state_dict``) and ``'optimizer'`` (the optimizer ``state_dict``), which
    is the layout ``load_model`` reads back.

    :param model: A compiled model (needs ``model_name`` and ``optimizer``).
    :param path: Target directory.
    """
    checkpoint = dict(
        name=model.model_name,
        model=model.state_dict(),
        optimizer=model.optimizer.state_dict(),
    )
    torch.save(checkpoint, f"{path}/{model.model_name}.pth")
    
    
def load_model(path):
    """Rebuild a compiled ``Model`` from a checkpoint written by ``save_model``.

    :param path: Path to the ``.pth`` checkpoint file.
    :return: A ``Model`` on ``device`` with its name, weights and optimizer state restored.
    """
    model = Model()
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only
    # machine (and vice versa); compile() then keeps everything on `device`.
    checkpoint = torch.load(path, map_location=device)
    model.compile(checkpoint['name'], checkpoint['model'], checkpoint['optimizer'])

    return model


def new_model():
    """Create a fresh ``Model`` compiled with a generated name and no saved state."""
    fresh = Model()
    fresh.compile()  # default name, randomly initialised weights, new optimizer
    return fresh


def create_confusion_matrix(X, y, num_classes=10):
    """Count (true label, predicted label) pairs into a square confusion matrix.

    :param X: Per-sample class scores or probabilities, shape ``(N, C)``; the
        predicted class is ``X.argmax(dim=1)``.
    :param y: True labels, shape ``(N,)`` (any numeric dtype; cast to long).
    :param num_classes: Side length of the matrix (default 10).
    :return: LongTensor of shape ``(num_classes, num_classes)`` where
        ``matrix[target, predicted]`` is the number of matching samples.
    :raises ValueError: if ``y`` and the predictions differ in shape.
    :raises IndexError: if a label or prediction lies outside ``[0, num_classes)``
        (negative labels previously wrapped around silently).
    """
    targets = y.type(torch.LongTensor)
    predicted = X.argmax(dim=1).type(torch.LongTensor)
    if targets.shape != predicted.shape:
        raise ValueError(f"label shape {tuple(targets.shape)} does not match prediction shape {tuple(predicted.shape)}")

    for what, values in (("label", targets), ("prediction", predicted)):
        if values.numel() and (values.min() < 0 or values.max() >= num_classes):
            raise IndexError(f"{what} out of range [0, {num_classes})")

    # encode each (target, predicted) pair as one flat index and count them all at once
    flat = targets * num_classes + predicted
    counts = torch.bincount(flat, minlength=num_classes * num_classes)
    return counts.view(num_classes, num_classes)


def plot_confusion_matrix(matrix, title='Confusion matrix', cmap=plt.cm.Blues, classes=None):
    """Draw ``matrix`` as an annotated heat map on the current matplotlib figure.

    :param matrix: Square confusion matrix (tensor or array-like); rows are
        true labels, columns are predicted labels.
    :param title: Plot title.
    :param cmap: Matplotlib colormap.
    :param classes: Tick labels for rows/columns; defaults to the ten genre
        names used by this project.
    """
    if classes is None:
        classes = ["Blues", "Classical", "Country", "Disco", "Hip Hop", "Jazz", "Metal", "Pop", "Reggae", "Rock"]

    # convert once so indexing/formatting below works on plain numpy scalars
    matrix = np.asarray(matrix)

    plt.imshow(matrix, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # counts print as integers; normalised (float) matrices get two decimals
    cell_format = '.2f' if np.issubdtype(matrix.dtype, np.floating) else 'd'
    # white text on dark cells, black on light cells
    thresh = matrix.max() / 2.
    for i, j in itertools.product(range(matrix.shape[0]), range(matrix.shape[1])):
        plt.text(j, i, format(matrix[i, j], cell_format), ha="center", va="center",
                 color="white" if matrix[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    
    
def plot_history(history):
    """Plot train/validation accuracy and loss per epoch on two stacked axes, then show the figure.

    :param history: Dict with the lists ``"acc"``, ``"val_acc"``, ``"loss"``
        and ``"val_loss"`` (the layout produced by ``Model.fit``).
    """
    fig, (acc_ax, err_ax) = plt.subplots(2)

    # top panel: accuracy
    for key, label in (("acc", "train accuracy"), ("val_acc", "val accuracy")):
        acc_ax.plot(history[key], label=label)
    acc_ax.set_ylabel("Accuracy")
    acc_ax.legend(loc="lower right")
    acc_ax.set_title("")

    # bottom panel: loss, labelled "error"
    for key, label in (("loss", "train error"), ("val_loss", "val error")):
        err_ax.plot(history[key], label=label)
    err_ax.set_ylabel("Error")
    err_ax.set_xlabel("Epoch")
    err_ax.legend(loc="upper right")
    err_ax.set_title("")

    plt.show()
    

def plot_spectrogram(Y, sr, hop_length, y_axis="linear", size=None):
    """Display spectrogram ``Y`` with a time x-axis and a colour bar.

    :param Y: Spectrogram array handed to ``librosa.display.specshow``
        (presumably frequency x time; confirm against callers).
    :param sr: Sample rate passed to ``specshow``.
    :param hop_length: Hop length passed to ``specshow``.
    :param y_axis: ``specshow`` y-axis type (default ``"linear"``).
    :param size: Optional figure size; when given, a new figure of that size
        is created first, otherwise the current figure is used.
    """
    if size:
        plt.figure(figsize=size)
    display_kwargs = dict(sr=sr, hop_length=hop_length, x_axis="time", y_axis=y_axis)
    librosa.display.specshow(Y, **display_kwargs)
    plt.colorbar(format="%+2.f")

    
class GuidedBackprop():
    """Guided-backpropagation saliency for a model whose ReLUs are direct sub-modules.

    On construction, backward hooks are registered on:
    (a) every ``ReLU`` among ``model._modules`` (direct children only), so that
        on the backward pass only positive gradients flowing through positive
        activations are propagated; and
    (b) the model's first registered sub-module, to capture the gradient with
        respect to its input.

    NOTE: the hooks stay attached and also alter ordinary backward passes
    (e.g. training) until ``remove_hooks()`` is called.
    """

    def __init__(self, model):
        self.model = model
        # gradient w.r.t. the first layer's input, captured on the backward pass
        self.gradients = None
        # ReLU outputs recorded on the forward pass, consumed in reverse on backward
        self.forward_relu_outputs = []
        # handles of every registered hook, so they can be removed later
        self._hook_handles = []
        # update layers
        self.update_relus()
        self.hook_layers()

    def hook_layers(self):
        """Store the gradient flowing into the first registered sub-module's input."""
        def hook_function(module, grad_in, grad_out):
            # grad_in[0] is the gradient w.r.t. the layer's first input
            self.gradients = grad_in[0]

        first_layer = next(iter(self.model._modules.values()))
        self._hook_handles.append(first_layer.register_backward_hook(hook_function))

    def update_relus(self):
        """Hook every direct-child ReLU for guided backpropagation."""
        def relu_backward_hook_function(module, grad_in, grad_out):
            # hooks fire in reverse forward order, so the newest stored output matches
            forward_output = self.forward_relu_outputs.pop()
            # 1 where the activation was positive, 0 elsewhere. This is built as
            # a new tensor so the activation saved in the graph is not modified in place.
            mask = (forward_output > 0).to(grad_in[0].dtype)
            # zero the gradient where it is negative or the activation was zero
            return (mask * torch.clamp(grad_in[0], min=0.0),)

        def relu_forward_hook_function(module, ten_in, ten_out):
            # no backward pass can follow a no-grad forward; storing would only leak memory
            if torch.is_grad_enabled():
                self.forward_relu_outputs.append(ten_out)

        for module in self.model._modules.values():
            if isinstance(module, ReLU):
                self._hook_handles.append(module.register_backward_hook(relu_backward_hook_function))
                self._hook_handles.append(module.register_forward_hook(relu_forward_hook_function))

    def remove_hooks(self):
        """Detach every hook this object registered on the model."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()
        self.forward_relu_outputs.clear()

    def generate_gradients(self, input_sample, target_class):
        """Return the guided gradient of ``target_class`` w.r.t. ``input_sample``.

        :param input_sample: Leaf input tensor on the model's device; the first
            batch element is used.
        :param target_class: Index of the output unit to explain.
        :return: numpy array of the first sample's input gradient with
            singleton dimensions squeezed out.
        """
        # discard activations from earlier forward passes that never ran backward
        self.forward_relu_outputs.clear()
        input_sample.requires_grad = True
        self.model.eval()
        try:
            model_output = self.model(input_sample)
            self.model.zero_grad()

            # one-hot seed on the same device/dtype as the output
            one_hot_output = torch.zeros_like(model_output)
            one_hot_output[0][target_class] = 1

            model_output.backward(gradient=one_hot_output)
        finally:
            input_sample.requires_grad = False

        return self.gradients.detach().cpu().numpy()[0].squeeze()
    
    
def get_positive_negative_saliency(gradient):
    """Split a gradient map into positive and negative saliency maps scaled to [0, 1].

    :param gradient: Array-like gradient, e.g. the output of
        ``GuidedBackprop.generate_gradients``.
    :return: ``(pos_saliency, neg_saliency)``. The first is the positive part
        divided by its maximum; the second is the magnitude of the negative part
        divided by its maximum. A part with no nonzero entries (or an empty
        input) is returned as zeros instead of dividing by zero or by a
        wrong-signed value.
    """
    gradient = np.asarray(gradient)

    def _normalise(part):
        # max(initial=0) also handles empty arrays; dividing by 1 keeps the same
        # result dtype as the normal (true-division) branch
        peak = part.max(initial=0)
        return part / (peak if peak > 0 else 1)

    pos_saliency = _normalise(np.maximum(0, gradient))
    neg_saliency = _normalise(np.maximum(0, -gradient))

    return pos_saliency, neg_saliency