In [1]:
import tensorflow as tf
import keras
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, precision_recall_curve, auc
from sklearn.model_selection import train_test_split
from training.train import load_tfrecord_dataset
import os
import glob
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt

2024-12-04 15:27:46.159022: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-12-04 15:27:46.166873: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1733300866.176159   17853 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1733300866.178763   17853 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-04 15:27:46.188021: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [2]:
CHECKPOINT_PATH = "/home/phatdat/Desktop/Sleep-Apnea-Detection/model/checkpoints"
RECORD_PATH = "/mnt/dat/prepped/apnea_sp02_pr_by_record"
DATASET1 = "/mnt/dat/databases/shhs/datasets/shhs1-dataset-0.21.0.csv"
DATASET2 = "/mnt/dat/databases/shhs/datasets/shhs2-dataset-0.21.0.csv"

In [3]:
trained_cps = glob.glob(os.path.join(CHECKPOINT_PATH, "test*.keras"))
trained_cps = [(int(path[path.rfind("/")+len("test_")+1:path.rfind(".keras")]), path) for path in trained_cps if "_last.keras" not in path]
trained_cps

[(255,
  '/home/phatdat/Desktop/Sleep-Apnea-Detection/model/checkpoints/test_255.keras'),
 (445,
  '/home/phatdat/Desktop/Sleep-Apnea-Detection/model/checkpoints/test_445.keras'),
 (411,
  '/home/phatdat/Desktop/Sleep-Apnea-Detection/model/checkpoints/test_411.keras'),
 (818,
  '/home/phatdat/Desktop/Sleep-Apnea-Detection/model/checkpoints/test_818.keras'),
 (698,
  '/home/phatdat/Desktop/Sleep-Apnea-Detection/model/checkpoints/test_698.keras'),
 (529,
  '/home/phatdat/Desktop/Sleep-Apnea-Detection/model/checkpoints/test_529.keras'),
 (442,
  '/home/phatdat/Desktop/Sleep-Apnea-Detection/model/checkpoints/test_442.keras'),
 (968,
  '/home/phatdat/Desktop/Sleep-Apnea-Detection/model/checkpoints/test_968.keras'),
 (743,
  '/home/phatdat/Desktop/Sleep-Apnea-Detection/model/checkpoints/test_743.keras'),
 (578,
  '/home/phatdat/Desktop/Sleep-Apnea-Detection/model/checkpoints/test_578.keras'),
 (311,
  '/home/phatdat/Desktop/Sleep-Apnea-Detection/model/checkpoints/test_311.keras'),
 (986,
  '

In [4]:
# stratify split
shhs1_csv = pd.read_csv(DATASET1, usecols=['nsrrid', 'pptid','ahi_a0h3a'])
shhs1_csv['nsrrid'] = "shhs1-" + shhs1_csv['nsrrid'].astype('str')
shhs2_csv = pd.read_csv(DATASET2, usecols=['nsrrid', 'pptid','ahi_a0h3a'], encoding_errors='replace')
shhs2_csv['nsrrid'] = "shhs2-" + shhs2_csv['nsrrid'].astype('str')

csv_df = pd.concat([shhs1_csv, shhs2_csv], ignore_index=True)
csv_df.rename(columns={'nsrrid': 'Record'}, inplace=True)

bins = [-float('inf'), 5, 15, 30, float('inf')]  # Define bins for ranges
labels = ['none', 'mild', 'moderate', 'severe']  # Corresponding labels

csv_df['ahi_label'] = pd.cut(csv_df['ahi_a0h3a'], bins=bins, labels=labels, right=False)

all_record = glob.glob(os.path.join(RECORD_PATH, "*.tfrecord"))
all_record = pd.DataFrame({"Record": [name[name.rfind("/")+1:name.rfind(".tfrecord")] for name in all_record],
                            "Path": all_record})
all_record = pd.merge(all_record, csv_df, how='left', on='Record')
all_record

Unnamed: 0,Record,Path,pptid,ahi_a0h3a,ahi_label
0,shhs2-203332,/mnt/dat/prepped/apnea_sp02_pr_by_record/shhs2...,3352.0,4.897959,none
1,shhs1-203561,/mnt/dat/prepped/apnea_sp02_pr_by_record/shhs1...,3581.0,5.115207,mild
2,shhs1-202057,/mnt/dat/prepped/apnea_sp02_pr_by_record/shhs1...,2064.0,6.486486,mild
3,shhs1-203763,/mnt/dat/prepped/apnea_sp02_pr_by_record/shhs1...,3785.0,49.793323,severe
4,shhs1-203230,/mnt/dat/prepped/apnea_sp02_pr_by_record/shhs1...,3249.0,3.916449,none
...,...,...,...,...,...
8070,shhs1-205797,/mnt/dat/prepped/apnea_sp02_pr_by_record/shhs1...,5832.0,2.099596,none
8071,shhs1-203026,/mnt/dat/prepped/apnea_sp02_pr_by_record/shhs1...,3045.0,6.241135,mild
8072,shhs1-204347,/mnt/dat/prepped/apnea_sp02_pr_by_record/shhs1...,4381.0,10.458716,mild
8073,shhs1-205387,/mnt/dat/prepped/apnea_sp02_pr_by_record/shhs1...,5422.0,44.426559,severe


In [5]:
def get_true_pred(model, dataset, batch_size, verbose=True):
    y_pred_prob = model.predict(dataset, batch_size=batch_size, verbose=verbose)
    y_true = []
    for _, y in dataset:
        y_true.append(y.numpy())
    y_true = np.vstack(y_true)
    return y_true, y_pred_prob

In [6]:
def sen_spec(y_true, y_pred_probs, threshold=None):
    """
    Calculate the geometric mean (G-mean) for a given threshold.
    """
    # Convert predicted probabilities to binary predictions based on threshold
    if threshold is None:
        y_pred = y_pred_probs
    else:
        y_pred = (y_pred_probs >= threshold).astype(int)
        
    # Compute confusion matrix
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()

    # Compute sensitivity and specificity
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

    return sensitivity, specificity

def optimize_threshold(y_true, y_pred_probs):
    """
    Find the threshold that maximizes the G-mean for binary classification.
    """
    best_threshold = -1
    best_gmean = -1
    thresholds = np.linspace(0, 0.1, 101)

    for threshold in thresholds:
        sensitivity, specificity = sen_spec(y_true, y_pred_probs, threshold)
        gmean = np.sqrt(sensitivity * specificity)

        if gmean > best_gmean:
            best_gmean = gmean
            best_threshold = threshold

    return best_threshold, best_gmean

In [7]:
from collections import defaultdict

results = defaultdict(list)

for state, cp_path in tqdm(trained_cps, total=len(trained_cps)):
    shhs2_records = all_record[all_record['Record'].str.startswith('shhs2')]

    test1_records = all_record[all_record['Record'].str.startswith('shhs1')]
    train_records, test2_records = train_test_split(shhs2_records, test_size=0.3, 
                                                    random_state=state, 
                                                    stratify=shhs2_records['ahi_label']) # should use AHI

    # train_records, validation_records = train_test_split(train_records, test_size=0.2, 
    #                                                 random_state=state, 
    #                                                 stratify=train_records['ahi_label']) # should use 
    
    train_set, _ = load_tfrecord_dataset(train_records['Path'].tolist(), batch_size=1024, shuffle=False)
    test1_set, _ = load_tfrecord_dataset(test1_records['Path'].tolist(), batch_size=1024, shuffle=False)
    test2_set, _ = load_tfrecord_dataset(test2_records['Path'].tolist(), batch_size=1024, shuffle=False)
    
    model = keras.models.load_model(cp_path)
    
    y_train_true, y_train_pred_prob = get_true_pred(model, train_set, batch_size=1024, verbose=False)
    y_test1_true, y_test1_pred_prob = get_true_pred(model, test1_set, batch_size=1024, verbose=False)
    y_test2_true, y_test2_pred_prob = get_true_pred(model, test2_set, batch_size=1024, verbose=False)
    
    # best_threshold, best_gmean = optimize_threshold(y_train_true.flatten(), y_train_pred_prob.flatten())
    best_threshold = 0.02
    y_train_pred = (y_train_pred_prob >= best_threshold).astype(int)
    y_test1_pred = (y_test1_pred_prob >= best_threshold).astype(int)
    y_test2_pred = (y_test2_pred_prob >= best_threshold).astype(int)
    
    sen_train, spec_train = sen_spec(y_train_true, y_train_pred)
    sen1, spec1 = sen_spec(y_test1_true, y_test1_pred)
    sen2, spec2 = sen_spec(y_test2_true, y_test2_pred)
    
    metrics_train = classification_report(y_train_true, y_train_pred, output_dict=True)
    metrics1 = classification_report(y_test1_true, y_test1_pred, output_dict=True)
    metrics2 = classification_report(y_test2_true, y_test2_pred, output_dict=True)
    
    rocauc_train = roc_auc_score(y_train_true, y_train_pred)
    rocauc1 = roc_auc_score(y_test1_true, y_test1_pred)
    rocauc2 = roc_auc_score(y_test2_true, y_test2_pred)
    
    precision_train, recall_train, _ = precision_recall_curve(y_train_true, y_train_pred_prob)
    pr_auc_train = auc(recall_train, precision_train)
    precision1, recall1, _ = precision_recall_curve(y_test1_true, y_test1_pred_prob)
    pr_auc1 = auc(recall1, precision1)
    precision2, recall2, _ = precision_recall_curve(y_test2_true, y_test2_pred_prob)
    pr_auc2 = auc(recall2, precision2)
    
    results['set'].append("SHHS2 Train")
    results['state'].append(state)
    results['threshold'].append(best_threshold)
    results['acc'].append(metrics_train['accuracy'])
    results['sensitivity'].append(sen_train)
    results['specificity'].append(spec_train)
    results['f1-score'].append(metrics_train['1']['f1-score'])
    results['precision (PPV)'].append(metrics_train['1']['precision'])
    results['prauc'].append(pr_auc_train)
    results['rocauc'].append(rocauc_train)
    
    results['set'].append("SHHS2 Test")
    results['state'].append(state)
    results['threshold'].append(best_threshold)
    results['acc'].append(metrics2['accuracy'])
    results['sensitivity'].append(sen2)
    results['specificity'].append(spec2)
    results['f1-score'].append(metrics2['1']['f1-score'])
    results['precision (PPV)'].append(metrics2['1']['precision'])
    results['prauc'].append(pr_auc2)
    results['rocauc'].append(rocauc2)
    
    results['set'].append("SHHS1 Test")
    results['state'].append(state)
    results['threshold'].append(best_threshold)
    results['acc'].append(metrics1['accuracy'])
    results['sensitivity'].append(sen1)
    results['specificity'].append(spec1)
    results['f1-score'].append(metrics1['1']['f1-score'])
    results['precision (PPV)'].append(metrics1['1']['precision'])
    results['prauc'].append(pr_auc1)
    results['rocauc'].append(rocauc1)

    break # one run only

results = pd.DataFrame(results)
results

  0%|          | 0/25 [00:00<?, ?it/s]I0000 00:00:1733300867.594321   17853 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 6808 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 2070 SUPER, pci bus id: 0000:01:00.0, compute capability: 7.5
2024-12-04 15:27:48.087486: I tensorflow/core/kernels/data/tf_record_dataset_op.cc:370] TFRecordDataset `buffer_size` is unspecified, default to 262144
I0000 00:00:1733300868.115458   17853 cuda_dnn.cc:529] Loaded cuDNN version 90300
2024-12-04 15:27:59.935084: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
  self.gen.throw(typ, value, traceback)
2024-12-04 15:28:11.692572: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-12-04 15:29:12.626938: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End 

Unnamed: 0,set,state,threshold,acc,sensitivity,specificity,f1-score,precision (PPV),prauc,rocauc
0,SHHS2 Train,255,0.02,0.842126,0.901001,0.840625,0.221002,0.125947,0.395291,0.870813
1,SHHS2 Test,255,0.02,0.845971,0.893352,0.844833,0.213852,0.121464,0.372727,0.869093
2,SHHS1 Test,255,0.02,0.881439,0.738227,0.885984,0.276974,0.170465,0.370917,0.812105


In [8]:
results.T

Unnamed: 0,0,1,2
set,SHHS2 Train,SHHS2 Test,SHHS1 Test
state,255,255,255
threshold,0.02,0.02,0.02
acc,0.842126,0.845971,0.881439
sensitivity,0.901001,0.893352,0.738227
specificity,0.840625,0.844833,0.885984
f1-score,0.221002,0.213852,0.276974
precision (PPV),0.125947,0.121464,0.170465
prauc,0.395291,0.372727,0.370917
rocauc,0.870813,0.869093,0.812105
