# Importing Libraries and Modules

In [None]:
import os
import glob

import pandas as pd
import numpy as np

from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, accuracy_score

import matplotlib.pyplot as plt

# Defining Utility Functions

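This cell defines the directory paths used by the notebook (PATH_MODEL_LABELS, PATH4IMAGES) and several utility functions used later on: computing the per-class TP/FP/FN/TN counts (get_metrics_for_class), plotting a confusion matrix (plot_confusion_matrix), plotting a bar chart of metrics per model (plot_metrics), and plotting precision, recall, and F1 score for every class (plot_metrics_for_all_classes).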

In [19]:
# Directory that is searched for the per-model label CSV files.
PATH_MODEL_LABELS = '../models_labels/'
# Path prefix under which every figure produced by this notebook is saved.
PATH4IMAGES = '../../2_training/data/images/'

def get_metrics_for_class(y_true, y_pred, target_class):
    """
    Computes the True Positive, False Positive, False Negative, and True Negative
    counts for a given class in a classification task (one-vs-rest).

    The counts are obtained by comparing label values directly, so
    target_class is always interpreted as a label value and never as a
    positional index. The result therefore stays correct when the labels
    that occur in the data are not the contiguous range 0..n-1.

    Parameters:
        y_true (array-like): The true labels of the data.
        y_pred (array-like): The predicted labels of the data.
        target_class (int): The class for which the metrics are computed.

    Returns:
        tp (int): Number of True Positives.
        fp (int): Number of False Positives.
        fn (int): Number of False Negatives.
        tn (int): Number of True Negatives.

    Raises:
        ValueError: If y_true and y_pred differ in length, or if the
            target_class occurs in neither y_true nor y_pred.
    """

    true_arr = np.asarray(y_true).ravel()
    pred_arr = np.asarray(y_pred).ravel()
    if true_arr.shape != pred_arr.shape:
        raise ValueError("y_true and y_pred must have the same length.")

    # One-vs-rest view of the problem: "is this sample the target class?"
    is_true = true_arr == target_class
    is_pred = pred_arr == target_class

    tp = int(np.count_nonzero(is_true & is_pred))
    fp = int(np.count_nonzero(is_pred)) - tp
    fn = int(np.count_nonzero(is_true)) - tp

    # A class that appears in neither array has no row/column in a
    # confusion matrix; report it instead of returning meaningless counts.
    if tp + fp + fn == 0:
        raise ValueError("Target class is out of range.")

    tn = int(true_arr.size) - tp - fp - fn

    return tp, fp, fn, tn

def plot_confusion_matrix(data, ax, title):
    """
    Draws a 2x2 confusion matrix on the supplied axis and titles it.

    Parameters:
        data (array-like): 2D array holding the confusion matrix, indexed as data[row][col].
        ax (matplotlib.axes._axes.Axes): Axis object that receives the image and the cell texts.
        title (str): Title shown above the matrix.

    Notes:
        - Both axis tick sets are removed.
        - Cells are captioned, row by row, as TP, FP, FN, TN followed by the cell value.
        - A cell caption is white when the cell value exceeds half of
          (top-left value + bottom-right value), and black otherwise.
    """

    cell_names = (('TP', 'FP'), ('FN', 'TN'))

    ax.imshow(data, interpolation='nearest', cmap=plt.cm.Blues)
    ax.set_title(title, fontsize=20)
    ax.set_xticks([])
    ax.set_yticks([])

    # Same threshold for every cell: half the sum of the diagonal entries.
    white_above = (data[0][0] + data[1][1]) / 2.

    for row, row_names in enumerate(cell_names):
        for col, name in enumerate(row_names):
            value = data[row][col]
            if value > white_above:
                text_color = "white"
            else:
                text_color = "black"
            ax.text(col, row, f"{name}: {value}",
                    ha="center", va="center", fontsize=16,
                    color=text_color)

            
def plot_metrics(models, model_names, metrics, label):
    """
    Plots grouped bars representing the scores of various models for different metrics.

    Parameters:
        models (list): List of lists; models[i] holds the scores of model i,
            one score per entry of metrics and in the same order.
        model_names (list): List of strings representing the names of the models.
        metrics (list): List of strings representing the names of the metrics.
        label (str): Label used for saving the plot (file name without extension).

    Notes:
        - Each bar in the plot is color-coded and labeled according to the model it represents.
        - Colors are reused cyclically when there are more models than colors.
        - Each metric name is placed under the centre of its group of bars.
        - The plot is saved as a JPG file under PATH4IMAGES with dpi=350.
    """

    colors = ['blue', 'green', 'red', 'cyan', 'magenta', 'yellow', 'black', 'purple', 'grey']
    n_models = len(models)

    # 0.15 is used whenever the whole group fits; with many models the bars
    # are narrowed so that one group never spills into the next one.
    barWidth = min(0.15, 0.8 / n_models) if n_models else 0.15
    positions = list(range(len(metrics)))

    # A fresh figure is created on every call, so nothing has to be cleared.
    plt.figure(figsize=(8, 4))
    for i, model in enumerate(models):
        plt.bar([x + barWidth * i for x in positions], model, width=barWidth,
                label=model_names[i], color=colors[i % len(colors)])

    plt.ylabel('Score')
    plt.xlabel('Metrics')
    plt.grid(True, ls='--', color='gray', alpha=0.4)

    # The bars of a group start at x and are barWidth apart, so the centre of
    # the group lies half of (n_models - 1) bar widths to the right of x.
    centre_offset = barWidth * max(n_models - 1, 0) / 2
    plt.xticks([x + centre_offset for x in positions], metrics)

    if n_models:
        plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.tight_layout()
    plt.savefig(f'{PATH4IMAGES}{label}.jpg', format='jpg', dpi=350)
    

def plot_metrics_for_all_classes(y_true, y_pred, model_name, labels):
    """
    Plots precision, recall, and F1 score for all classes.

    Parameters:
        y_true (array-like): The true labels of the data.
        y_pred (array-like): The predicted labels of the data.
        model_name (str): The name of the model used for prediction.
        labels (iterable): The class labels to plot, one subplot per label.

    Notes:
        - The metrics are calculated separately for each class and drawn in separate subplots.
        - Subplots are arranged four per row; the number of rows grows with the
          number of labels and unused subplots are hidden.
        - The title of each subplot includes the name of the class (looked up in the
          module-level replacement_dict, falling back to the raw label) and the number
          of times the class occurs in the true labels (0 if it never occurs).
        - The plot is saved as a JPG file under PATH4IMAGES with dpi=350.
    """
    metrics_names = ['Precision', 'Recall', 'F1 Score']

    # replacement_dict maps class name -> id; the titles need id -> class name.
    replacement_dict_inverted = {v: k for k, v in replacement_dict.items()}

    class_labels = list(labels)
    n_cols = 4
    # Ceiling division without importing math; always at least one row.
    n_rows = max(1, -(-len(class_labels) // n_cols))

    accuracy = accuracy_score(y_true, y_pred)
    # Wrapping in a Series makes plain lists / numpy arrays work as well.
    value_counts = pd.Series(y_true).value_counts()

    # squeeze=False keeps axs two-dimensional even for a single row.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
    fig.suptitle(f'{model_name}, accuracy: {accuracy:.2f}', fontsize=16)
    flat_axes = axs.ravel()

    for ax, class_label in zip(flat_axes, class_labels):
        # Restricting `labels` to a single class makes the macro average equal
        # to the score of exactly that class.
        precision = precision_score(y_true, y_pred, labels=[class_label], average='macro', zero_division=0)
        recall = recall_score(y_true, y_pred, labels=[class_label], average='macro', zero_division=0)
        f1 = f1_score(y_true, y_pred, labels=[class_label], average='macro', zero_division=0)

        metrics_scores = [precision, recall, f1]

        ax.bar(metrics_names, metrics_scores, color='b', alpha=0.7, label=f'Class {class_label}')
        ax.set_ylim([0, 1.1])
        # Print the numeric score just above each bar.
        for i, v in enumerate(metrics_scores):
            ax.text(i, v + 0.02, f"{v:.2f}", color='black', ha='center')

        class_name = replacement_dict_inverted.get(class_label, class_label)
        class_count = value_counts.get(class_label, 0)
        ax.set_title(f'Metrics for {class_name} ({class_count})')
        ax.set_ylabel('Score')
        ax.grid(True, ls='--', color='gray', alpha=0.6)

    # Hide the grid cells that did not receive a class.
    for ax in flat_axes[len(class_labels):]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.savefig(f'{PATH4IMAGES}{model_name}_metrics_across_classes.jpg', format='jpg', dpi=350)

# Path Initialization and File Retrieval

This cell retrieves a sorted list of all CSV files located in the model labels directory (PATH_MODEL_LABELS, defined in the utility cell above) and displays it.

In [None]:
# Gather every CSV file of the model-labels directory and order the paths
# alphabetically so that the models are always processed in the same order.
csv_pattern = f'{PATH_MODEL_LABELS}*.csv'
list_files = glob.glob(csv_pattern)
list_files.sort()
list_files

# Loading Actual Labels

This cell loads the actual labels of the test dataset from a .npy file located at a specified path.

In [None]:
# Location of the saved labels of the test dataset.
path = '../../2_training/data/datasets/test_dataset_labels.npy'
# NOTE(review): presumably one class-name string per test sample (the values
# are mapped to integer ids further down) - confirm against the file.
labels = np.load(path)

# DataFrame Creation for Actual Labels

This cell creates a pandas DataFrame from the loaded labels, naming the column as actual_label.

In [None]:
# Wrap the loaded labels in a single-column DataFrame; its default row index
# is what the model predictions are later merged on.
actual_labels = pd.DataFrame(labels, columns=['actual_label'])

# Label Replacement Dictionary Initialization

This cell initializes a dictionary for replacing string class labels with corresponding numerical identifiers.

In [None]:
# Class name -> numerical identifier. The names are listed in identifier
# order, so the position of a name in the list is the id it receives.
replacement_dict = dict(zip(
    [
        'None',
        'Pulse',
        'BBRFI',
        'NBRFI',
        'Pulse+BBRFI',
        'Pulse+NBRFI',
        'NBRFI+BBRFI',
        'Pulse+NBRFI+BBRFI',
    ],
    range(8),
))

# Replacing String Labels with Numerical Identifiers

This cell replaces the string labels in the actual_label column of the DataFrame with the numerical identifiers using the previously initialized dictionary.

In [None]:
# Swap every class-name string in 'actual_label' for its numerical identifier.
# NOTE(review): values that are not keys of replacement_dict are presumably
# left untouched by Series.replace - verify that every label in the file is
# covered, otherwise strings and integers end up mixed in this column.
actual_labels['actual_label'] = actual_labels['actual_label'].replace(replacement_dict)

# Model Evaluation and Visualization

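This cell evaluates each model by comparing its predicted labels with the actual labels. For every model it computes the per-class TP/FP/FN/TN counts, the accuracy, and the micro- and macro-averaged precision, recall, and F1 score. The key metrics are stored in the key_metrics_across_models dictionary (keyed by model name), and a figure with one confusion matrix per class is saved for each model in the images directory.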

In [None]:
# Subplot titles in class-id order, laid out like the 2x4 grid of axes:
# the title at position (i, j) belongs to class id i*4 + j.
title_list = [
    ['None', 'Pulse', 'BBRFI', 'NBRFI'],
    ['Pulse+BBRFI', 'Pulse+NBRFI', 'BBRFI+NBRFI', 'Pulse+BBRFI+NBRFI']
]

# model name -> [[accuracy, micro precision, micro recall, micro F1],
#                [macro precision, macro recall, macro F1]]
key_metrics_across_models = {}

for file in list_files:
    model_lables = pd.read_csv(file)[['models_label']]
    # Model name: the file name up to the first '-', starting at the 'p'
    # that follows the first '_p' occurrence.
    model_name = os.path.basename(file).split('-')[0]
    model_name = model_name[model_name.index('_p')+1:]

    # Align predictions with the actual labels on the row index.
    # NOTE(review): this is an inner join, so rows present in only one of the
    # two tables are presumably dropped silently - confirm the lengths match.
    table_for_comparison = actual_labels.merge(model_lables, left_index=True, right_index=True)

    # The same two columns are used for every class and for the aggregate
    # scores below, so they are extracted once per model.
    y_true = table_for_comparison.actual_label
    y_pred = table_for_comparison.models_label

    # One 2x2 matrix [[TP, FP], [FN, TN]] per class id.
    data = []
    for i in range(len(replacement_dict)):
        tp, fp, fn, tn = get_metrics_for_class(y_true, y_pred, i)
        data.append([
            [tp, fp],
            [fn, tn]
        ])

    accuracy = accuracy_score(y_true, y_pred)

    precision_micro = precision_score(y_true, y_pred, average='micro', zero_division=0)
    recall_micro = recall_score(y_true, y_pred, average='micro', zero_division=0)
    F1_score_micro = f1_score(y_true, y_pred, average='micro', zero_division=0)

    precision_macro = precision_score(y_true, y_pred, average='macro', zero_division=0)
    recall_macro = recall_score(y_true, y_pred, average='macro', zero_division=0)
    F1_score_macro = f1_score(y_true, y_pred, average='macro', zero_division=0)

    key_metrics_across_models[model_name] = [
        [accuracy, precision_micro, recall_micro, F1_score_micro],
        [precision_macro, recall_macro, F1_score_macro]
    ]

    fig, axs = plt.subplots(2, 4, figsize=(20, 10))
    fig.suptitle(model_name, fontsize=24, y=1.)
    # Walk axes, per-class matrices and titles together in row-major order.
    flat_titles = [title for title_row in title_list for title in title_row]
    for ax, class_data, title in zip(axs.ravel(), data, flat_titles):
        plot_confusion_matrix(class_data, ax, title)

    plt.tight_layout()
    plt.savefig(f'{PATH4IMAGES}{model_name}_confusions_metrics.jpg', format='jpg', dpi=350)

# Micro-average Metrics Visualization

This cell visualizes and saves the bar plot of micro-average metrics like accuracy, recall, precision, and F1 score for each model using the plot_metrics function.



In [None]:
# The tick labels must follow the order in which the micro-average scores are
# stored per model: [accuracy, precision, recall, F1].
metrics = ['Accuracy', 'Precision', 'Recall', 'F1 score']

plot_metrics(
    [scores[0] for scores in key_metrics_across_models.values()],
    list(key_metrics_across_models.keys()),
    metrics,
    'key_metrics_micro',
)

# Macro-average Metrics Visualization

This cell visualizes and saves the bar plot of macro-average metrics like recall, precision, and F1 score for each model using the plot_metrics function.

In [None]:
# The tick labels must follow the order in which the macro-average scores are
# stored per model: [precision, recall, F1].
metrics = ['Precision', 'Recall', 'F1 score']

plot_metrics(
    [scores[1] for scores in key_metrics_across_models.values()],
    list(key_metrics_across_models.keys()),
    metrics,
    'key_metrics_macro',
)

# Metrics Calculation and Visualization for all Classes

This cell calculates and plots precision, recall, and F1 score for all classes for each model, utilizing the plot_metrics_for_all_classes function. The plots are saved with specified filenames.

In [None]:
# Every model is plotted for the same class ids: 0 .. number of classes - 1.
class_ids = range(len(replacement_dict))

for file in list_files:
    model_lables = pd.read_csv(file).loc[:, ['models_label']]

    # Model name: the file name up to the first '-', starting at 'prot'.
    file_stem = os.path.basename(file).partition('-')[0]
    model_name = file_stem[file_stem.index('prot'):]

    # Pair each prediction with its actual label via the shared row index.
    table_for_comparison = pd.merge(
        actual_labels, model_lables, left_index=True, right_index=True
    )

    y_true = table_for_comparison['actual_label']
    y_pred = table_for_comparison['models_label']

    plot_metrics_for_all_classes(y_true, y_pred, model_name, class_ids)