In [1]:
import tensorflow as tf

In [2]:
from etl.load_dataset import DatasetProcessor

# Evaluation configuration.
target_dir = '../test_files/EGGIMazing/Dataset'
batch_size = 16
num_epochs = 10
learning_rate = 1e-4

# Run the project's DatasetProcessor over target_dir; the result is indexed by
# column name below.
dp = DatasetProcessor(target_dir)
df = dp.process()
# Optional NaN-row filter, currently disabled:
# df = df[~df.isna().any(axis=1)].reset_index(drop=True)

# Columns handed to the stratified splitter as inputs (X) and targets (y).
X = df['image_directory']
y = df['eggim_square']

In [4]:
from custom_models.cnns import simple_cnn
from etl.load_dataset import get_tf_eggim_patch_dataset

# Number of target classes. Defined once and passed to both the dataset builders
# and the model so the label width and the model output width cannot drift apart.
n_classes = 3

# Take the first (and, with k=1, only) stratified 80/10/10 train/val/test split.
# random_state is fixed so the partition is reproducible.
split = dp.stratified_k_splits(X, y, k=1, train_size=0.8, val_size=0.1, test_size=0.1, random_state=42)
train_idx, val_idx, test_idx = next(split)

# Build one batched dataset per partition from the matching rows of df.
tf_train_df = get_tf_eggim_patch_dataset(df.loc[train_idx], num_classes=n_classes).batch(batch_size)
tf_val_df = get_tf_eggim_patch_dataset(df.loc[val_idx], num_classes=n_classes).batch(batch_size)
tf_test_df = get_tf_eggim_patch_dataset(df.loc[test_idx], num_classes=n_classes).batch(batch_size)

# NOTE(review): input_shape is hard-coded to 224x224 RGB — confirm it matches
# the patches produced by get_tf_eggim_patch_dataset.
model = simple_cnn(input_shape=(224, 224, 3), n_classes=n_classes)

In [50]:
# Directory of the training run whose weights are evaluated in this notebook.
# NOTE(review): hard-coded, timestamped run directory — update when evaluating another run.
checkpoint_dir = '../test_scripts/test_simple_cnn_20240812-143310'
# Load the saved weights file from that directory into `model`.
model.load_weights(f'{checkpoint_dir}/weights.h5')


In [51]:
import numpy as np
# Gather the label part of every (inputs, labels) test batch into one array,
# then run the model over the same dataset to get the matching predictions.
# NOTE(review): this iterates tf_test_df twice; it assumes both passes yield
# samples in the same order (no reshuffling between iterations) — confirm.
y_true = np.concatenate([batch_labels for (_, batch_labels) in tf_test_df])
y_pred = model.predict(tf_test_df)



In [52]:
# Class index per sample = position of the largest entry along the last axis,
# e.g. one-hot label [0 0 1] -> 2.
y_true_ordinal = np.argmax(y_true, axis=-1)
y_pred_ordinal = np.argmax(y_pred, axis=-1)

# Hard one-hot version of the predictions, e.g. [0.2, 0.2, 0.6] -> [0, 0, 1]:
# start from zeros and set the winning column of every row to 1.
y_pred_one_hot = np.zeros_like(y_pred)
y_pred_one_hot[np.arange(len(y_pred)), y_pred_ordinal] = 1

In [53]:
from sklearn.metrics import precision_score, recall_score, roc_auc_score, confusion_matrix


def categorical_accuracy(y_true, y_pred):
    """Return the fraction of rows in which y_true and y_pred agree on every column."""
    elementwise_match = y_true == y_pred
    row_match = np.all(elementwise_match, axis=1)
    return np.mean(row_match)

# Exact-match accuracy between the one-hot labels and the hard one-hot predictions.
accuracy = categorical_accuracy(y_true, y_pred_one_hot)
print(f'Categorical Accuracy: {accuracy:.4f}')


# Precision and Recall on class indices; average='weighted' weights each
# class's score by its number of true samples.
precision = precision_score(y_true_ordinal, y_pred_ordinal, average='weighted')
recall = recall_score(y_true_ordinal, y_pred_ordinal, average='weighted')

print(f'Precision: {precision:.4f}')
print(f'Recall: {recall:.4f}')

# AUC computed from the one-hot labels and the raw (un-thresholded) model outputs.
# NOTE(review): with a 2-D indicator y_true scikit-learn may treat this as a
# multilabel problem, in which case multi_class is not consulted — confirm.
auc = roc_auc_score(y_true, y_pred, multi_class='ovr')
print(f'AUC: {auc:.4f}')

# Confusion Matrix on class indices: rows are true classes, columns are
# predicted classes (scikit-learn convention).
conf_matrix = confusion_matrix(y_true_ordinal, y_pred_ordinal)
print('Confusion Matrix:')
print(conf_matrix)

Categorical Accuracy: 0.8095
Precision: 0.8363
Recall: 0.8095
AUC: 0.8180
Confusion Matrix:
[[13  0  0]
 [ 2  2  1]
 [ 1  0  2]]


In [54]:
def categorical_accuracy(y_true, y_pred):
    """Return the fraction of rows in which y_true and y_pred agree on every column.

    Note: an identical definition appears in an earlier cell; this one rebinds
    the same name.
    """
    elementwise_match = y_true == y_pred
    row_match = np.all(elementwise_match, axis=1)
    return np.mean(row_match)

def specificity_per_class(conf_matrix):
    """Compute one-vs-rest specificity (true-negative rate) for every class.

    Args:
        conf_matrix: Square confusion matrix (array-like) in which entry
            [i, j] counts samples of true class i predicted as class j.

    Returns:
        Tuple ``(per_class, macro_mean)``. ``per_class`` is a float array
        holding TN / (TN + FP) for each class, with NaN where a class has no
        negative samples (TN + FP == 0). ``macro_mean`` is the mean over the
        classes whose specificity is defined, or NaN if none is defined.
    """
    cm = np.asarray(conf_matrix)

    true_pos = np.diag(cm)
    false_pos = cm.sum(axis=0) - true_pos   # column total minus the diagonal
    false_neg = cm.sum(axis=1) - true_pos   # row total minus the diagonal
    true_neg = cm.sum() - (true_pos + false_pos + false_neg)

    # Divide only where the denominator is non-zero so that degenerate classes
    # yield NaN without raising a divide-by-zero RuntimeWarning.
    negatives = true_neg + false_pos
    defined = negatives != 0
    per_class = np.full(true_pos.shape, np.nan, dtype=float)
    np.divide(true_neg, negatives, out=per_class, where=defined)

    # Average over defined classes only; a single undefined class must not
    # turn the summary value into NaN.
    macro_mean = np.nanmean(per_class) if defined.any() else np.nan
    return per_class, macro_mean

In [55]:
def sensitivity_per_class(conf_matrix):
    """Compute one-vs-rest sensitivity (recall, true-positive rate) for every class.

    Args:
        conf_matrix: Square confusion matrix (array-like) in which entry
            [i, j] counts samples of true class i predicted as class j.

    Returns:
        Tuple ``(per_class, macro_mean)``. ``per_class`` is a float array
        holding TP / (TP + FN) for each class, with NaN where a class has no
        true samples (its row sums to zero). ``macro_mean`` is the mean over
        the classes whose sensitivity is defined, or NaN if none is defined.
    """
    cm = np.asarray(conf_matrix)

    true_pos = np.diag(cm)
    # TP + FN for class i is simply the total of row i.
    actual_count = cm.sum(axis=1)

    # Divide only where the class actually occurs so that absent classes yield
    # NaN without raising a divide-by-zero RuntimeWarning.
    defined = actual_count != 0
    per_class = np.full(true_pos.shape, np.nan, dtype=float)
    np.divide(true_pos, actual_count, out=per_class, where=defined)

    # Average over defined classes only; a single undefined class must not
    # turn the summary value into NaN.
    macro_mean = np.nanmean(per_class) if defined.any() else np.nan
    return per_class, macro_mean

In [56]:
# Display ((per-class specificity, mean), (per-class sensitivity, mean)) for conf_matrix.
specificity_per_class(conf_matrix), sensitivity_per_class(conf_matrix)

((array([0.625     , 1.        , 0.94444444]), 0.8564814814814815),
 (array([1.        , 0.4       , 0.66666667]), 0.6888888888888888))