In [1]:
#| default_exp evaluation

In [15]:
#|export
from sklearn.metrics import accuracy_score, confusion_matrix
import math
import matplotlib.pyplot as plt
import seaborn as sns

In [16]:
#|export

def calculate_accuracy(y_test, y_preds):
    """Score predictions against ground truth with scikit-learn's accuracy.

    Thin wrapper around ``sklearn.metrics.accuracy_score`` using its default
    options, i.e. the fraction of positions where ``y_preds`` equals ``y_test``.

    Args:
        y_test: Ground-truth labels.
        y_preds: Predicted labels, aligned with ``y_test``.

    Returns:
        The accuracy value returned by ``accuracy_score``.
    """
    return accuracy_score(y_test, y_preds)

In [17]:
#|export

def _safe_divide(numerator, denominator):
    """Return ``numerator / denominator``, or ``0.0`` when the denominator is zero.

    Undefined ratios (e.g. precision for a class that is never predicted) are
    reported as 0 rather than NaN, similar to scikit-learn's ``zero_division=0``.
    Checking first also avoids NumPy's divide-by-zero ``RuntimeWarning``.
    """
    return numerator / denominator if denominator else 0.0


def _f1(precision, recall):
    """Harmonic mean of precision and recall; ``0.0`` when both are zero."""
    return _safe_divide(2 * precision * recall, precision + recall)


def evaluate(y_test, y_preds, print_cnfm=False, labels=None):
    """Compute accuracy, per-class recall/precision/F1 and macro F1 for crop vs. weed.

    The function assumes a binary problem. Confusion-matrix index 0 is treated as
    "crop" and index 1 as "weed", matching the heatmap tick labels.

    Args:
        y_test: Ground-truth labels.
        y_preds: Predicted labels, aligned with ``y_test``.
        print_cnfm: If truthy, draw the confusion matrix as a seaborn heatmap.
        labels: Optional ``[crop_label, weed_label]`` list forwarded to
            ``sklearn.metrics.confusion_matrix``. Pass it to guarantee a 2x2
            matrix even when one class is absent from both ``y_test`` and
            ``y_preds``. With ``None`` (the default), scikit-learn uses the
            sorted observed labels, so the smallest label is treated as crop.

    Returns:
        dict with keys ``"Accuracy sklearn"``, ``"Accuracy custom"``,
        ``"Conf Matrix"``, ``"Recall crop"``, ``"Recall weed"``,
        ``"Precision crop"``, ``"Precision weed"``, ``"F1 crop"``,
        ``"F1 weed"`` and ``"F1 macro"``. Ratios with a zero denominator are
        reported as ``0.0``.
    """
    accuracy_sklearn = calculate_accuracy(y_test, y_preds)

    conf_matrix = confusion_matrix(y_test, y_preds, labels=labels)

    # Independent cross-check of the sklearn accuracy: correct predictions
    # are exactly the diagonal of the confusion matrix.
    accuracy_custom = _safe_divide(conf_matrix.trace(), conf_matrix.sum())

    # sklearn convention: rows are true labels, columns are predicted labels.
    true_totals = conf_matrix.sum(axis=1)
    pred_totals = conf_matrix.sum(axis=0)

    # Class recall: correct / all samples truly of that class.
    recall_crop = _safe_divide(conf_matrix[0, 0], true_totals[0])
    recall_weed = _safe_divide(conf_matrix[1, 1], true_totals[1])

    # Class precision: correct / all samples predicted as that class.
    precision_crop = _safe_divide(conf_matrix[0, 0], pred_totals[0])
    precision_weed = _safe_divide(conf_matrix[1, 1], pred_totals[1])

    # Class F1 scores, guarded identically for both classes.
    f1_score_crop = _f1(precision_crop, recall_crop)
    f1_score_weed = _f1(precision_weed, recall_weed)

    # Unweighted mean of the per-class F1 scores.
    f1_macro = (f1_score_crop + f1_score_weed) / 2

    if print_cnfm:
        fig, ax = plt.subplots(figsize=(8, 6))
        sns.heatmap(conf_matrix, annot=True, fmt='g', cmap='Blues', cbar=False,
                    xticklabels=["crop", "weed"], yticklabels=["crop", "weed"],
                    ax=ax)
        ax.set_title('Confusion Matrix Heatmap')
        ax.set_xlabel('Predicted Labels')
        ax.set_ylabel('True Labels')
        plt.show()

    return {
        "Accuracy sklearn": accuracy_sklearn,
        "Accuracy custom": accuracy_custom,
        "Conf Matrix": conf_matrix,
        "Recall crop": recall_crop,
        "Recall weed": recall_weed,
        "Precision crop": precision_crop,
        "Precision weed": precision_weed,
        "F1 crop": f1_score_crop,
        "F1 weed": f1_score_weed,
        "F1 macro": f1_macro,
    }