In [1]:
import os
import sys
import math
import json
import numpy as np
from sklearn.metrics import roc_auc_score
from tensorflow import keras
sys.path.append('C:\\Users\\Dell\\Desktop\\CV Projects\\ecg')

import load
import network
import utilities

In [2]:
# Project root. Set the ECG_BASE_DIR environment variable to relocate the
# project instead of editing this absolute path; the default is the previously
# hard-coded location.
# NOTE(review): the sys.path.append call in the first cell still hard-codes the
# same directory, so keep the two in sync when relocating the project.
base_dir = os.environ.get('ECG_BASE_DIR', r'C:\Users\Dell\Desktop\CV Projects\ecg')

In [3]:
# Checkpoint to evaluate. Its directory is also what the next cell passes to
# utilities.load() to restore the preprocessor.
# Each path component is joined separately so the separators are consistent
# on every OS (no '/' embedded inside a single join argument).
model_path = os.path.join(base_dir, 'saved', 'ecg_experiment_01', '090607',
                          '0.534-0.954-011-0.430-0.966.hdf5')
data_json = os.path.join(base_dir, 'data', 'val.json')
config_file = os.path.join(base_dir, 'config.json')

In [4]:
# Restore the preprocessor stored alongside the checkpoint, then run the
# validation records (data/val.json) through it to get model inputs and labels.
model_dir = os.path.dirname(model_path)
preprocessor = utilities.load(model_dir)
dataset = load.load_dataset(data_json)
ecgs, labels = preprocessor.process(*dataset)

100%|████████████████████████████████████████████████████████████████████████████████| 852/852 [00:24<00:00, 34.58it/s]


In [5]:
# Network hyper-parameters. The context manager closes the config file right
# away instead of leaking the handle until garbage collection.
with open(config_file, 'r') as config_fh:
    params = json.load(config_fh)

In [6]:
# Override the input shape and class count for this model. [None, 1] presumably
# means variable-length, single-channel input -- TODO confirm against network.py.
params.update(input_shape=[None, 1],
              num_categories=len(preprocessor.classes))

In [7]:
# Rebuild the architecture from the (updated) config, then load the trained
# checkpoint weights from model_path into it.
model = network.build_network(**params)
model.load_weights(model_path)

In [8]:
# Model outputs for every validation record (class scores on the last axis).
probs = model.predict(ecgs, verbose=0)
# stats()/roc_auc() below flatten predictions and labels and compare them
# element-wise, so the record and per-record axes must line up.
assert probs.shape[:2] == labels.shape[:2], (probs.shape, labels.shape)
probs.shape

(852, 70, 4)

In [9]:
# Disabled inspection of the last record's full prediction matrix; uncomment
# to view it (the dump is long, one row per output position).
#probs[-1]

array([[1.7419802e-02, 2.4686202e-02, 9.5464623e-01, 3.2478312e-03],
       [2.0965284e-02, 3.8706653e-02, 9.3614364e-01, 4.1843918e-03],
       [3.1738374e-02, 3.7591409e-02, 9.2480069e-01, 5.8695506e-03],
       [3.5588272e-02, 3.9788108e-02, 9.1911137e-01, 5.5123335e-03],
       [2.2712620e-02, 3.8828801e-02, 9.3385828e-01, 4.6002222e-03],
       [5.7084411e-02, 4.9791113e-02, 8.8885355e-01, 4.2709075e-03],
       [4.0995199e-02, 6.9723375e-02, 8.8680220e-01, 2.4792536e-03],
       [8.6958311e-02, 5.5776220e-02, 8.5447103e-01, 2.7944162e-03],
       [1.5502419e-02, 3.3474285e-02, 9.5022768e-01, 7.9564919e-04],
       [3.2465514e-02, 4.6434678e-02, 9.1966367e-01, 1.4361503e-03],
       [4.1637812e-02, 3.5995558e-02, 9.2202687e-01, 3.3973585e-04],
       [9.7986721e-03, 2.8186454e-02, 9.6184003e-01, 1.7487805e-04],
       [1.0411953e-02, 2.3781091e-02, 9.6549082e-01, 3.1617726e-04],
       [1.4473261e-02, 3.3052765e-02, 9.5227742e-01, 1.9654511e-04],
       [1.7856143e-02, 1.0482879e-

In [58]:
# Human-readable names for the single-character class codes used by the model.
_CLASS_FULL_NAMES = {
    'A': 'Atrial Fibrillation',
    'N': 'Normal Sinus Rhythm',
    'O': 'Other Rhythm',
    '~': 'Noise',
}


def class_full_name(cname):
    """Return the human-readable name for a class code.

    Args:
        cname: class code, e.g. 'A', 'N', 'O' or '~'.

    Returns:
        The full class name. For an unknown code, returns ``str(cname)``
        instead of ``None``, so the table formatting never receives ``None``
        (``'{:<16}'.format(None)`` raises TypeError).
    """
    return _CLASS_FULL_NAMES.get(cname, str(cname))


def stats(ground_truth, preds):
    """Per-class confusion counts over every flattened output position.

    Both arrays are reduced with argmax over axis 2 and flattened, so each
    position contributes exactly one true class and one predicted class.

    Returns:
        dict mapping class index (0 .. ground_truth.shape[2]-1) to the tuple
        (tp, fp, fn, tn) of position counts.
    """
    true_cls = np.argmax(ground_truth, axis=2).ravel()
    pred_cls = np.argmax(preds, axis=2).ravel()
    counts = {}
    for cls in range(ground_truth.shape[2]):
        is_true = true_cls == cls
        is_pred = pred_cls == cls
        tp = np.sum(is_true & is_pred)
        fp = np.sum(is_pred & ~is_true)
        # Everything truly `cls` that was not a hit is a miss; everything else
        # that was not a false alarm is a correct rejection.
        fn = np.sum(is_true) - tp
        tn = np.sum(~is_true) - fp
        counts[cls] = (tp, fp, fn, tn)
    return counts


def to_set(preds):
    """Collapse each record's argmax class sequence into its distinct classes.

    Args:
        preds: array reduced with argmax over axis 2; the first axis indexes
            records.

    Returns:
        list with one entry per record: a list of the unique class indices in
        that record (order follows Python set iteration).
    """
    per_position = np.argmax(preds, axis=2)
    distinct = []
    for row in per_position:
        distinct.append(list(set(row)))
    return distinct


def set_stats(ground_truth, preds):
    """Per-class confusion counts at record ("set") level.

    Each record is reduced with ``to_set`` to the classes appearing anywhere in
    its argmax sequence; a record is positive for a class when that class is in
    its set.

    Returns:
        dict mapping class index to (tp, fp, fn, tn), each a count of records.
    """
    num_classes = ground_truth.shape[2]
    true_sets = to_set(ground_truth)
    pred_sets = to_set(preds)
    counts = {}
    for cls in range(num_classes):
        tp = fp = fn = tn = 0
        for truth, guess in zip(true_sets, pred_sets):
            actual = cls in truth
            predicted = cls in guess
            if actual and predicted:
                tp += 1
            elif predicted:
                fp += 1
            elif actual:
                fn += 1
            else:
                tn += 1
        counts[cls] = (tp, fp, fn, tn)
    return counts


def compute_f1(tp, fp, fn, tn):
    """F1 score and support for one class from its confusion counts.

    Args:
        tp, fp, fn, tn: true/false positive/negative counts. ``tn`` is accepted
            so the (tp, fp, fn, tn) tuples from ``stats``/``set_stats`` can be
            splatted in directly; it does not affect F1.

    Returns:
        ``(f1, support)`` where ``support = tp + fn`` is the number of actual
        positives. When F1 is undefined (nothing predicted, nothing present, or
        no hits) it is reported as 0.0. The previous version raised
        ZeroDivisionError or produced nan/inf in those cases, and also in cases
        that only affected the unused specificity/NPV values.
    """
    predicted_pos = tp + fp
    actual_pos = tp + fn
    precision = tp / float(predicted_pos) if predicted_pos else 0.0
    recall = tp / float(actual_pos) if actual_pos else 0.0
    if precision + recall == 0:
        return 0.0, actual_pos
    f1 = 2 * precision * recall / (precision + recall)
    return f1, actual_pos


def print_results(seq_stat, set_stat, classes=None):
    """Print per-class and support-weighted average F1 at sequence and set level.

    Args:
        seq_stat: dict class index -> (tp, fp, fn, tn) over output positions,
            as returned by ``stats``.
        set_stat: dict with the same keys holding record-level counts, as
            returned by ``set_stats``.
        classes: indexable of class codes used to label rows; defaults to the
            notebook-global ``preprocessor.classes``.

    The "Average" row weights each class's F1 by its support (tp + fn). The
    name column is sized to the longest label, so long names such as
    "Atrial Fibrillation" no longer push the numbers out of alignment.
    """
    if classes is None:
        classes = preprocessor.classes

    rows = []
    seq_weighted = set_weighted = 0.0
    seq_support = set_support = 0
    for k, counts in seq_stat.items():
        seq_f1, seq_n = compute_f1(*counts)
        set_f1, set_n = compute_f1(*set_stat[k])
        seq_weighted += seq_n * seq_f1
        seq_support += seq_n
        set_weighted += set_n * set_f1
        set_support += set_n
        # Fall back to the raw code if no full name is known.
        name = class_full_name(classes[k]) or str(classes[k])
        rows.append((name, seq_f1, set_f1))

    # Avoid dividing by zero on an empty / zero-support table.
    seq_avg = seq_weighted / float(seq_support) if seq_support else 0.0
    set_avg = set_weighted / float(set_support) if set_support else 0.0
    rows.append(('Average', seq_avg, set_avg))

    width = max(len(name) for name, _, _ in rows)
    print('{:<{w}} {:>10} {:>10}'.format('', 'Seq F1', 'Set F1', w=width))
    for name, seq_f1, set_f1 in rows:
        print('{:<{w}} {:10.3f} {:10.3f}'.format(name, seq_f1, set_f1, w=width))
    
    
def c_statistic_with_95p_confidence_interval(cstat, num_positives, num_negatives, z_alpha_2=1.96):
    """
    Half-width of a normal-approximation confidence interval for an AUC.

    Uses the Hanley & McNeil standard error of the area under the ROC curve,
    as described under "Confidence Interval for AUC" in:
      https://ncss-wpengine.netdna-ssl.com/wp-content/themes/ncss/pdf/Procedures/PASS/Confidence_Intervals_for_the_Area_Under_an_ROC_Curve.pdf

        Q1 = A / (2 - A)
        Q2 = 2 A^2 / (1 + A)
        SE = sqrt((A(1 - A) + (n_pos - 1)(Q1 - A^2) + (n_neg - 1)(Q2 - A^2)) / (n_pos * n_neg))

    Args:
        cstat: the c-statistic A (equivalent to the area under the ROC curve).
        num_positives: number of positive examples, n_pos.
        num_negatives: number of negative examples, n_neg.
        z_alpha_2 (optional): two-sided critical value; 1.96 gives a 95%
            interval, 2.326 a 98% and 2.576 a 99% interval.
    Returns:
        z_alpha_2 * SE, i.e. the Y in "A ± Y". The bounds A ± Y are not
        clipped to [0, 1] here.
    """
    area = cstat
    area_sq = area**2
    q_pos = area / (2 - area)
    q_neg = 2 * area_sq / (1 + area)

    spread = area * (1 - area)
    pos_term = (num_positives - 1) * (q_pos - area_sq)
    neg_term = (num_negatives - 1) * (q_neg - area_sq)
    variance = (spread + pos_term + neg_term) / (num_positives * num_negatives)
    return z_alpha_2 * math.sqrt(variance)


def roc_auc(ground_truth, probs, index):
    """Sequence-level one-vs-rest ROC AUC for class ``index``.

    Every output position is one sample. Its label is 1 when the argmax (over
    axis 2) of ``ground_truth`` equals ``index`` and 0 otherwise; its score is
    ``probs[..., index]``.

    Returns:
        (num_pos, num_neg, auc): the counts of positive and negative positions
        and the ``roc_auc_score`` of the flattened labels and scores.
    """
    true_cls = np.argmax(ground_truth, axis=2)
    binary_truth = (true_cls == index).astype(true_cls.dtype).ravel()
    class_scores = np.ravel(probs[..., index])

    num_pos = np.sum(binary_truth == 1)
    num_neg = binary_truth.size - num_pos
    return num_pos, num_neg, roc_auc_score(binary_truth, class_scores)
    
    
def roc_auc_set(ground_truth, probs, index):
    """Record-level one-vs-rest ROC AUC for class ``index``.

    A record is positive when any of its positions has ground-truth argmax
    (over axis 2) equal to ``index``. Its score is the maximum of
    ``probs[..., index]`` over axis 1.

    Returns:
        (pos, neg, auc): positive/negative record counts and the
        ``roc_auc_score`` of the record labels against the record scores.
    """
    true_cls = np.argmax(ground_truth, axis=2)
    has_class = (true_cls == index).any(axis=1)
    peak_score = probs[..., index].max(axis=1)

    pos = has_class.sum()
    neg = has_class.size - pos
    return pos, neg, roc_auc_score(has_class, peak_score)


def _format_auc_with_ci(auc, half_width):
    """Format ``auc`` with its confidence interval, clipping the bounds to [0, 1]."""
    # A normal-approximation interval can spill past the valid AUC range
    # (e.g. an upper bound of 1.003); clip for display only.
    lower = max(0.0, auc - half_width)
    upper = min(1.0, auc + half_width)
    return '{:.3f} ({:.3f}-{:.3f})'.format(auc, lower, upper)


def print_aucs(ground_truth, probs, int_to_class=None):
    """Print sequence- and set-level one-vs-rest ROC AUC with confidence intervals.

    Args:
        ground_truth: label array, passed to ``roc_auc`` / ``roc_auc_set``.
        probs: predicted scores with the class on the last axis.
        int_to_class: mapping class index -> class code that selects and labels
            the rows; defaults to the notebook-global
            ``preprocessor.int_to_class``.

    Each row shows "AUC (lower-upper)" for the sequence-level (every output
    position) and the set-level (per record) AUC. The intervals come from
    ``c_statistic_with_95p_confidence_interval`` (z = 1.96) and are clipped to
    [0, 1]. The "Average" row weights each class's AUC by its number of
    positives.
    """
    if int_to_class is None:
        int_to_class = preprocessor.int_to_class

    rows = []
    seq_weighted = seq_support = 0.0
    set_weighted = set_support = 0.0
    for idx, code in int_to_class.items():
        # Fall back to the raw code if no full name is known.
        name = class_full_name(code) or str(code)

        pos, neg, seq_auc = roc_auc(ground_truth, probs, idx)
        seq_weighted += pos * seq_auc
        seq_support += pos
        seq_conf = c_statistic_with_95p_confidence_interval(seq_auc, pos, neg)

        pos, neg, set_auc = roc_auc_set(ground_truth, probs, idx)
        set_weighted += pos * set_auc
        set_support += pos
        set_conf = c_statistic_with_95p_confidence_interval(set_auc, pos, neg)

        rows.append((name,
                     _format_auc_with_ci(seq_auc, seq_conf),
                     _format_auc_with_ci(set_auc, set_conf)))

    # Avoid dividing by zero when no class has any positives.
    seq_avg = seq_weighted / seq_support if seq_support else 0.0
    set_avg = set_weighted / set_support if set_support else 0.0
    rows.append(('Average', '{:.3f}'.format(seq_avg), '{:.3f}'.format(set_avg)))

    # Label both AUC columns and size every column to its widest entry so long
    # class names do not break the alignment.
    header = ('', 'Seq AUC (95% CI)', 'Set AUC (95% CI)')
    name_w = max(len(row[0]) for row in rows)
    cell_w = max(len(cell) for row in [header] + rows for cell in row[1:])
    line = '{:<{nw}}  {:<{cw}}  {:<{cw}}'
    for row in [header] + rows:
        print(line.format(*row, nw=name_w, cw=cell_w).rstrip())

In [59]:
# F1 tables (sequence vs. record level), then the matching AUC table.
f1_tables = (stats(labels, probs), set_stats(labels, probs))
print_results(*f1_tables)
print('\n')
print_aucs(labels, probs)

		 Seq F1    Set F1
Atrial Fibrillation      0.741      0.764
Normal Sinus Rythym      0.789      0.853
Other Rythym          0.600      0.515
Noise                 0.863      0.994
Average               0.795      0.869


			        AUC
Atrial Fibrillation	0.986 (0.983-0.989)	0.981 (0.959-1.003
Normal Sinus Rythym	0.928 (0.925-0.930)	0.891 (0.869-0.912
Other Rythym    	0.864 (0.859-0.869)	0.875 (0.844-0.905
Noise           	0.986 (0.985-0.987)	0.999 (0.998-1.001
Average			0.949			0.945
