In [5]:
import tensorflow as tf
import keras
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, precision_recall_curve, auc
from sklearn.model_selection import train_test_split
from training.train import load_tfrecord_dataset
import os
import glob
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt

In [6]:
CHECKPOINT_PATH = "/home/phatdat/Desktop/Sleep-Apnea-Detection/model/checkpoints/bootstrap"
DATA_PATH = "/mnt/dat/prepped/apnea_sp02_pr_bootstrap"

In [9]:
checkpoints = glob.glob(os.path.join(CHECKPOINT_PATH, "*.keras"))
bootstrap_part = glob.glob(os.path.join(DATA_PATH, "train*.tfrecord"))

len(bootstrap_part), len(checkpoints)

(41, 41)

In [7]:
def get_true_pred(model, dataset, batch_size, verbose=True):
    y_pred_prob = model.predict(dataset, batch_size=batch_size, verbose=verbose)
    y_true = []
    for _, y in dataset:
        y_true.append(y.numpy())
    y_true = np.vstack(y_true)
    return y_true, y_pred_prob

In [8]:
def sen_spec(y_true, y_pred_probs, threshold=None):
    """
    Calculate the geometric mean (G-mean) for a given threshold.
    """
    # Convert predicted probabilities to binary predictions based on threshold
    if threshold is None:
        y_pred = y_pred_probs
    else:
        y_pred = (y_pred_probs >= threshold).astype(int)
        
    # Compute confusion matrix
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()

    # Compute sensitivity and specificity
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

    return sensitivity, specificity

def optimize_threshold(y_true, y_pred_probs):
    """
    Find the threshold that maximizes the G-mean for binary classification.
    """
    best_threshold = -1
    best_gmean = -1
    thresholds = np.linspace(0, 0.1, 101)

    for threshold in thresholds:
        sensitivity, specificity = sen_spec(y_true, y_pred_probs, threshold)
        gmean = np.sqrt(sensitivity * specificity)

        if gmean > best_gmean:
            best_gmean = gmean
            best_threshold = threshold

    return best_threshold, best_gmean

In [16]:
test1_path = os.path.join(DATA_PATH, "test_1.tfrecord")
test2_path = os.path.join(DATA_PATH, "test_2.tfrecord")


test1_set, _ = load_tfrecord_dataset(test1_path, batch_size=1024, shuffle=False)
test2_set, _ = load_tfrecord_dataset(test2_path, batch_size=1024, shuffle=False)

In [17]:
shhs1_res = []
shhs2_res = []

for cp_path in tqdm(checkpoints, total=len(checkpoints)):
    model = keras.models.load_model(cp_path)
    
    y_test1_true, y_test1_pred_prob = get_true_pred(model, test1_set, batch_size=1024, verbose=False)
    y_test2_true, y_test2_pred_prob = get_true_pred(model, test2_set, batch_size=1024, verbose=False)
    
    shhs1_res.append((y_test1_true, y_test1_pred_prob))
    shhs2_res.append((y_test2_true, y_test2_pred_prob))

shhs1_res = np.array(shhs1_res)
shhs2_res = np.array(shhs2_res)

shhs1_res.shape, shhs2_res.shape

  0%|          | 0/41 [00:00<?, ?it/s]2024-12-06 09:44:31.871433: I tensorflow/core/kernels/data/tf_record_dataset_op.cc:370] TFRecordDataset `buffer_size` is unspecified, default to 262144
I0000 00:00:1733453071.936681    7968 cuda_dnn.cc:529] Loaded cuDNN version 90300
2024-12-06 09:45:03.723041: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
  self.gen.throw(typ, value, traceback)
2024-12-06 09:45:35.565177: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-12-06 09:45:46.060462: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
  2%|▏         | 1/41 [01:14<49:32, 74.30s/it]2024-12-06 09:46:58.488345: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
  7%|▋         | 3/41 [03:39<4

((41, 2, 5344888, 1), (41, 2, 875086, 1))

In [None]:
shhs1_res = np.mean(shhs1_res, axis=0)
shhs2_res = np.mean(shhs2_res, axis=0)


In [20]:
shhs1_res = np.squeeze(shhs1_res)
shhs2_res = np.squeeze(shhs2_res)
shhs1_res.shape, shhs2_res.shape

((2, 5344888), (2, 875086))

In [32]:
y_test1_true = shhs1_res[0].astype(int)
y_test1_pred_prob = shhs1_res[1]
y_test1_pred = np.round(y_test1_pred_prob).astype(int)

y_test2_true = shhs2_res[0].astype(int)
y_test2_pred_prob = shhs2_res[1]
y_test2_pred = np.round(y_test2_pred_prob).astype(int)

y_test1_true.shape, y_test1_pred.shape, y_test2_true.shape, y_test2_pred.shape

((5344888,), (5344888,), (875086,), (875086,))

In [33]:
from collections import defaultdict
results = defaultdict(list)

sen1, spec1 = sen_spec(y_test1_true, y_test1_pred)
sen2, spec2 = sen_spec(y_test2_true, y_test2_pred)

metrics1 = classification_report(y_test1_true, y_test1_pred, output_dict=True)
metrics2 = classification_report(y_test2_true, y_test2_pred, output_dict=True)

rocauc1 = roc_auc_score(y_test1_true, y_test1_pred)
rocauc2 = roc_auc_score(y_test2_true, y_test2_pred)

precision1, recall1, _ = precision_recall_curve(y_test1_true, y_test1_pred_prob)
pr_auc1 = auc(recall1, precision1)
precision2, recall2, _ = precision_recall_curve(y_test2_true, y_test2_pred_prob)
pr_auc2 = auc(recall2, precision2)


results['set'].append("SHHS2 Test")
results['acc'].append(metrics2['accuracy'])
results['sensitivity'].append(sen2)
results['specificity'].append(spec2)
results['f1-score'].append(metrics2['1']['f1-score'])
results['precision (PPV)'].append(metrics2['1']['precision'])
results['prauc'].append(pr_auc2)
results['rocauc'].append(rocauc2)

results['set'].append("SHHS1 Test")
results['acc'].append(metrics1['accuracy'])
results['sensitivity'].append(sen1)
results['specificity'].append(spec1)
results['f1-score'].append(metrics1['1']['f1-score'])
results['precision (PPV)'].append(metrics1['1']['precision'])
results['prauc'].append(pr_auc1)
results['rocauc'].append(rocauc1)

results = pd.DataFrame(results)
results.T

Unnamed: 0,0,1
set,SHHS2 Test,SHHS1 Test
acc,0.882886,0.910864
sensitivity,0.848186,0.685779
specificity,0.883814,0.918008
f1-score,0.273913,0.321267
precision (PPV),0.163329,0.209769
prauc,0.398854,0.365969
rocauc,0.866,0.801893
