# 3D CNN Model Metrics

In [None]:
import os
import random
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath("__file__")), '..')))

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.metrics import roc_curve, roc_auc_score, accuracy_score, precision_score, recall_score, f1_score, confusion_matrix

from datasets import DukeDataGenerator, DukeDataset
from models import CNN3D
from settings import Settings

%matplotlib inline


In [None]:
settings = Settings()

# Seed every random number generator used by the notebook (NumPy, the
# standard library and TensorFlow) with the same configured seed.
for _set_seed in (np.random.seed, random.seed, tf.random.set_seed):
    _set_seed(settings.RANDOM_SEED)

# Job whose checkpoint is evaluated, and the index into settings.PHENOTYPES
# selecting which phenotype's weights are loaded.
JOB_ID = "23187"
PHENOTYPE = 0

# <BASE_DATA_DIR>/jobs/<JOB_ID>/<phenotype name>/checkpoints/weights.h5
WEIGHTS_PATH = os.path.join(
    settings.BASE_DATA_DIR,
    "jobs",
    JOB_ID,
    settings.PHENOTYPES[PHENOTYPE],
    "checkpoints",
    "weights.h5",
)

### Load dataset 

In [None]:
dataset = DukeDataset(settings.DATASET_DIR, crop_size=settings.INPUT_SIZE)

# Both generators are configured identically apart from `stage`, so the
# shared keyword arguments are collected once.
_generator_kwargs = {
    "dataset": dataset,
    "dim": settings.INPUT_SIZE,
    "batch_size": settings.BATCH_SIZE,
    "positive_class": PHENOTYPE,
    "autoencoder": False,
}

train_generator = DukeDataGenerator(
    settings.DATASET_DIR, stage="train", **_generator_kwargs
)
test_generator = DukeDataGenerator(
    settings.DATASET_DIR, stage="test", **_generator_kwargs
)

### Model preparation

In [None]:
# Build the 3D CNN, passing the first three entries of settings.INPUT_SIZE
# as depth, width and height respectively.
input_size = settings.INPUT_SIZE
model = CNN3D(depth=input_size[0], width=input_size[1], height=input_size[2])
# Flag the model as built by hand before loading the checkpoint.
# NOTE(review): presumably this lets load_weights() run without a prior
# build()/forward pass on the model - confirm against the CNN3D definition.
model.built = True
model.load_weights(WEIGHTS_PATH)

### Get model predictions on test dataset

In [None]:
# Walk the test generator batch by batch, collecting the ground-truth labels
# and the flattened model outputs for every sample.
y_true, y_pred = [], []

for batch_index in range(len(test_generator)):
    batch_inputs, batch_labels = test_generator[batch_index]
    y_true.extend(batch_labels)

    batch_scores = model.predict(batch_inputs)
    y_pred.extend(batch_scores.ravel())

# Convert the accumulated Python lists into arrays for the metric functions.
y_true = np.array(y_true)
y_pred = np.array(y_pred)

### Calculate and show metrics 

In [None]:
def _print_metric(title, value):
    """Print one metric as ``<title>: <value>`` with four decimal places."""
    print(f"{title}: {value:.4f}")


# AUC is computed on the raw scores; the other metrics use the scores
# thresholded at 0.5.
roc_auc = roc_auc_score(y_true, y_pred)
predicted_positive = y_pred > 0.5

print(f"Phenotype: {settings.PHENOTYPES[PHENOTYPE]} ({PHENOTYPE})\n")
_print_metric("Accuracy", accuracy_score(y_true, predicted_positive))
_print_metric("Precision", precision_score(y_true, predicted_positive))
_print_metric("Recall", recall_score(y_true, predicted_positive))
_print_metric("AUC", roc_auc)
_print_metric("F1 Score", f1_score(y_true, predicted_positive))

def calculate_binary_class_weights(labels):
    """
    Calculate class weights for binary classification.

    Each class is weighted by the inverse of its frequency, scaled so that
    a perfectly balanced label set yields a weight of 1.0 for both classes:
    ``weight = (1 / n_class) * (n_samples / 2)``.

    Parameters:
        labels (array-like): Array of shape (n_samples,) containing the class labels
                             for the training data. Must contain exactly two unique
                             classes.

    Returns:
        dict: A dictionary where keys are the class labels (as returned by
              ``np.unique``) and values are the corresponding class weights.

    Raises:
        AssertionError: If the number of unique classes in labels is not equal to 2.
    """
    # Normalise lists, tuples, Series, ... to a single array representation.
    labels = np.asarray(labels)
    total_samples = len(labels)

    # One pass yields both the distinct classes and how often each occurs.
    unique_classes, class_counts = np.unique(labels, return_counts=True)

    # Raised explicitly instead of using an `assert` statement so the check
    # is still enforced when Python runs with optimisations (-O), while the
    # exception type seen by callers stays the same.
    if unique_classes.shape[0] != 2:
        raise AssertionError("Only binary classification is supported")

    return {
        cls: (1 / count) * (total_samples / 2)
        for cls, count in zip(unique_classes, class_counts)
    }

# Weight every sample by the weight of its class and recompute the AUC with
# those per-sample weights.
class_weights = calculate_binary_class_weights(y_true)
sample_weights = np.array([class_weights[label] for label in y_true])
weighted_auc = roc_auc_score(y_true, y_pred, sample_weight=sample_weights)

print(f"Weighted AUC: {weighted_auc:.4f}")

# Confusion matrix of the scores thresholded at 0.5, with rows and columns
# labelled by the class values found in y_true.
class_labels = np.unique(y_true)
conf_matrix = confusion_matrix(y_true, y_pred > 0.5)
conf_matrix_df = pd.DataFrame(
    conf_matrix,
    index=class_labels,
    columns=class_labels,
)

print("\nConfusion Matrix:")
print(conf_matrix_df)

### Plot ROC Curve 

In [None]:
# ROC curve of the raw scores; the returned thresholds are not needed.
fpr, tpr, _ = roc_curve(y_true, y_pred)

fig, ax = plt.subplots()
ax.plot(fpr, tpr, color="blue", lw=2, label=f"ROC curve (AUC = {roc_auc:.4f})")
# Dashed diagonal from (0, 0) to (1, 1) as a visual reference line.
ax.plot([0, 1], [0, 1], color="black", lw=2, linestyle="--")
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.legend(loc="lower right")
plt.show()