In [None]:
import os
import torch
import numpy as np
import pandas as pd
from data_utils import data_loading as data_load
from data_utils import classes_labels as cl
from custom_nets.cnn import CNNAll, CNN3Classes, CNN10Classes
from custom_nets.lstm_simple_all_classes import LSTMSimpleAll
from custom_nets.resnet_all_classes import ResBlock, ResNetAll
from custom_nets.lstm_cnn_all_classes import LSTMCNNAll
from custom_nets.lstm_cnn_pqd_all_classes import LSTMCNNAllpqd
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
from IPython.core.pylabtools import figsize
import ntpath

In [None]:
class ConfusionMatrixAll:
    """Evaluate a single network on the validation split and save its confusion matrix as CSV.

    The validation data (loaded once in ``__init__`` from ``'../data'``) is used as the
    test set. Each call to :meth:`save_confusion_matrix` writes
    ``../confusion_matrices/<checkpoint file stem>.csv``.
    """

    def __init__(self, batch_size, device):
        """Load the validation loader and make sure the output directory exists.

        Args:
            batch_size: batch size forwarded to ``data_load.load_validation_as_test_data``.
            device: device used for inference; anything accepted by ``Tensor.to`` and by
                ``torch.load(map_location=...)``, e.g. ``torch.device('cpu')`` or ``'cuda:0'``.
        """
        self.batch_size = batch_size
        self.device = device
        self.testloader, self.number_of_audio = data_load.load_validation_as_test_data(self.batch_size, '../data')
        os.makedirs('../confusion_matrices', exist_ok=True)

    def _load_network(self, network, networkstate_path):
        """Load weights from ``networkstate_path`` into ``network``, move it to the device, set eval mode."""
        # map_location accepts both str and torch.device. The former `self.device == 'cpu'`
        # check is False for torch.device('cpu'), so GPU-saved checkpoints failed to load
        # on CPU-only machines.
        network.load_state_dict(torch.load(networkstate_path, map_location=self.device))
        network.to(self.device)
        # Evaluation must not run with training-mode dropout / batch-norm behaviour.
        network.eval()
        return network

    def save_confusion_matrix(self, network, best_networkstate_path):
        """Run ``network`` over the validation set and write its confusion matrix to CSV.

        Args:
            network: model instance whose architecture matches the checkpoint.
            best_networkstate_path: path to the saved ``state_dict``; its file stem
                (name without directory and extension) names the output CSV.

        The CSV has true classes as rows and predicted classes as columns, both ordered
        as ``cl.label_number_to_class2.values()``.
        """
        self._load_network(network, best_networkstate_path)

        predicted_classes = []
        true_classes = []
        with torch.no_grad():
            for data in self.testloader:
                images, file_names = data[0].to(self.device), data[1]
                predicted = network(images).argmax(dim=1).tolist()
                for file_name, predicted_label in zip(file_names, predicted):
                    predicted_classes.append(cl.label_number_to_class[predicted_label])
                    # Ground truth is taken from the first '/'-separated component of the
                    # file name (presumably the class directory — confirm against the loader).
                    true_label = cl.class_to_label_number[file_name.split('/')[0]]
                    true_classes.append(cl.label_number_to_class[true_label])

        # NOTE(review): predictions are named via label_number_to_class while the matrix axes
        # use label_number_to_class2; presumably both contain the same class names — confirm.
        labels = list(cl.label_number_to_class2.values())
        conf_mat = confusion_matrix(true_classes, predicted_classes, labels=labels)
        stem = ntpath.splitext(ntpath.basename(best_networkstate_path))[0]
        pd.DataFrame(conf_mat, index=labels, columns=labels).to_csv(f'../confusion_matrices/{stem}.csv')

In [None]:
class ConfusionMatrixTwoNets:
    """Evaluate a two-stage classifier on the validation split and save its confusion matrix as CSV.

    Stage one (``network``) predicts a label from ``cl.label_number_to_class1``; when that
    label is ``'valid'`` the final class is taken from stage two (``network2``) via
    ``cl.label_number_to_class2``. Results go to
    ``../confusion_matrices/<stem1>_<stem2>.csv``.
    """

    def __init__(self, batch_size, device):
        """Load the validation loader and make sure the output directory exists.

        Args:
            batch_size: batch size forwarded to ``data_load.load_validation_as_test_data``.
            device: device used for inference; anything accepted by ``Tensor.to`` and by
                ``torch.load(map_location=...)``, e.g. ``torch.device('cpu')`` or ``'cuda:0'``.
        """
        self.batch_size = batch_size
        self.device = device
        self.testloader, self.number_of_audio = data_load.load_validation_as_test_data(self.batch_size, '../data')
        os.makedirs('../confusion_matrices', exist_ok=True)

    def _load_network(self, network, networkstate_path):
        """Load weights from ``networkstate_path`` into ``network``, move it to the device, set eval mode."""
        # map_location accepts both str and torch.device. The former `self.device == 'cpu'`
        # check is False for torch.device('cpu'), so GPU-saved checkpoints failed to load
        # on CPU-only machines.
        network.load_state_dict(torch.load(networkstate_path, map_location=self.device))
        network.to(self.device)
        # Evaluation must not run with training-mode dropout / batch-norm behaviour.
        network.eval()
        return network

    def save_confusion_matrix(self, network, network2, best_networkstate_path, best_networkstate_path2):
        """Run the two-stage classifier over the validation set and write its confusion matrix.

        Args:
            network: first-stage model matching the checkpoint at ``best_networkstate_path``.
            network2: second-stage model matching the checkpoint at ``best_networkstate_path2``.
            best_networkstate_path: saved ``state_dict`` of the first stage.
            best_networkstate_path2: saved ``state_dict`` of the second stage.

        The CSV has true classes as rows and predicted classes as columns, both ordered
        as ``cl.label_number_to_class2.values()``.
        """
        self._load_network(network, best_networkstate_path)
        self._load_network(network2, best_networkstate_path2)

        predicted_classes = []
        true_classes = []
        with torch.no_grad():
            for data in self.testloader:
                images, file_names = data[0].to(self.device), data[1]
                predicted = network(images).argmax(dim=1).tolist()
                predicted2 = network2(images).argmax(dim=1).tolist()

                for file_name, first_label, second_label in zip(file_names, predicted, predicted2):
                    class_name = cl.label_number_to_class1[first_label]
                    if class_name == 'valid':
                        # network2's output index is shifted by 2 to obtain the key in
                        # label_number_to_class2 (presumably keys 0 and 1 belong to the
                        # classes decided by stage one — confirm).
                        class_name = cl.label_number_to_class2[second_label + 2]
                    predicted_classes.append(class_name)
                    # Ground truth is taken from the first '/'-separated component of the
                    # file name (presumably the class directory — confirm against the loader).
                    true_label = cl.class_to_label_number[file_name.split('/')[0]]
                    true_classes.append(cl.label_number_to_class[true_label])

        labels = list(cl.label_number_to_class2.values())
        conf_mat = confusion_matrix(true_classes, predicted_classes, labels=labels)
        stem = ntpath.splitext(ntpath.basename(best_networkstate_path))[0]
        stem2 = ntpath.splitext(ntpath.basename(best_networkstate_path2))[0]
        pd.DataFrame(conf_mat, index=labels, columns=labels).to_csv(f'../confusion_matrices/{stem}_{stem2}.csv')

In [None]:
# Evaluation settings shared by the confusion-matrix helpers.
batch_size = 32
# Use the first CUDA GPU when one is visible; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
class ConfusionMatrixVisualizer:
    """Render a confusion-matrix CSV as an annotated heatmap and save it as a PDF."""

    @staticmethod
    def visualize(confusion_matrix_path, vmax=200, text_threshold=None):
        """Plot the confusion matrix stored at ``confusion_matrix_path``.

        The CSV is read with its first column as the row index; rows are plotted on the
        y axis ("true") and columns on the x axis ("predicted"), matching the layout
        written by the ``ConfusionMatrix*`` classes. The figure is saved (overwriting any
        existing file) to ``../confusion_matrices_pdfs/<csv file stem>.pdf`` and then shown.

        Args:
            confusion_matrix_path: path to the confusion-matrix CSV.
            vmax: upper limit of the colour scale; larger counts saturate (default 200).
            text_threshold: counts strictly above this value are annotated in white,
                the rest in black. Defaults to ``vmax / 2`` (100 for the default ``vmax``).
        """
        if text_threshold is None:
            text_threshold = vmax / 2
        output_dir = '../confusion_matrices_pdfs'
        os.makedirs(output_dir, exist_ok=True)

        matrix = pd.read_csv(confusion_matrix_path, index_col=0)
        true_labels = list(matrix.index)
        predicted_labels = list(matrix.columns)
        conf_data = matrix.to_numpy()

        # Pass the size per figure instead of mutating the global rcParams.
        fig, ax = plt.subplots(figsize=(7, 7))
        # plt.cm.get_cmap (matplotlib.cm.get_cmap) was removed in Matplotlib 3.9;
        # pyplot.get_cmap(name, lut) is the supported equivalent.
        cmap = plt.get_cmap("PuBuGn", 3000)
        ax.imshow(conf_data, cmap=cmap, vmin=0, vmax=vmax)

        # imshow puts columns on the x axis and rows on the y axis.
        ax.set_xticks(np.arange(len(predicted_labels)), labels=predicted_labels, fontsize=12)
        ax.set_yticks(np.arange(len(true_labels)), labels=true_labels, fontsize=12)
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right', rotation_mode='anchor')

        # Annotate every cell; i indexes rows (y), j indexes columns (x).
        for i in range(len(true_labels)):
            for j in range(len(predicted_labels)):
                value = conf_data[i, j]
                text_color = 'white' if value > text_threshold else 'black'
                ax.text(j, i, value, ha='center', va='center', color=text_color, fontsize=12)

        ax.set_xlabel('predicted', fontsize=12)
        ax.set_ylabel('true', fontsize=12)
        fig.tight_layout()

        stem = ntpath.splitext(ntpath.basename(confusion_matrix_path))[0]
        fig_pdf_path = os.path.join(output_dir, stem + '.pdf')
        # savefig overwrites an existing file, so no explicit removal is needed.
        fig.savefig(fig_pdf_path)

        plt.show()