In [None]:
import numpy as np
from sklearn.metrics import roc_curve, auc
import models
import torch

from tqdm import tqdm
import torch.nn.functional as F
import os
from matplotlib import pyplot as plt
from sklearn.metrics import confusion_matrix,ConfusionMatrixDisplay
from pthflops import count_ops

import pandas as pd
from ECG import ECG_loader
def get_model(depth = [1,2,2,3,3,3,3],channels = [32,16,24,40,80,112,192,320,1280],dilation = 1,stride = 2,expansion = 6,additional_inputs=0):
    """Build a ``models.EffNet`` and return it switched to evaluation mode.

    All arguments are forwarded unchanged to the ``models.EffNet`` constructor;
    ``additional_inputs`` is passed as its ``num_additional_features`` keyword.
    The number of trainable parameters is printed as a side effect.

    NOTE(review): ``depth`` and ``channels`` default to list objects that are
    shared between calls; that is harmless only as long as ``EffNet`` does not
    mutate them -- confirm.
    """
    net = models.EffNet(
        depth=depth,
        channels=channels,
        dilation=dilation,
        stride=stride,
        expansion=expansion,
        num_additional_features=additional_inputs,
    )

    # Count only the parameters that would receive gradients.
    trainable = 0
    for param in net.parameters():
        if param.requires_grad:
            trainable += param.numel()
    print('parameters: ' + str(trainable))

    net.eval()
    return net

def point(y_total,yhat_total,t):
    """Pick the threshold in ``t`` with the largest sensitivity + specificity - 1.

    For every candidate threshold the scores are binarised with a strict
    ``score > threshold`` rule and a confusion matrix is computed against
    ``y_total``.  The entry of ``t`` whose confusion matrix maximises
    sensitivity + specificity - 1 is returned (the first one on ties, as
    ``argmax`` does).
    """
    spec_values = []
    sens_values = []
    for cutoff in tqdm(t):
        predicted_positive = yhat_total > cutoff
        tn, fp, fn, tp = confusion_matrix(y_total, predicted_positive).ravel()
        spec_values.append(tn / (tn + fp))
        sens_values.append(tp / (tp + fn))

    youden = np.array(spec_values) + np.array(sens_values) - 1
    best_index = youden.argmax()
    return t[best_index]
def thresholded_output_transform(yhat,y):
    """Convert raw model outputs to probabilities and pass the targets through.

    Returns a ``(probabilities, y)`` pair where ``probabilities`` is
    ``torch.sigmoid(yhat)`` and ``y`` is the very object that was passed in.
    Despite the name, no thresholding is applied here.
    """
    probabilities = torch.sigmoid(yhat)
    return probabilities, y
def produce_df(test_loader, checkpoint,folder = 'Test Results', depth = [1,2,2,3,3,3,3],channels = [32,16,24,40,80,112,192,320,1280],dilation = 1,stride = 2,expansion = 6,additional_inputs = 0,ds = None):
    """Evaluate a saved checkpoint on ``test_loader`` (CPU) and plot the results.

    Steps:
      1. build the network with ``get_model`` and load the state dict stored
         at ``checkpoint``;
      2. if ``ds`` is given, run ``count_ops`` on one sample drawn from it;
      3. run every ``(x, y)`` batch of ``test_loader`` through the model and
         turn the outputs into probabilities with a sigmoid;
      4. compute the ROC curve, choose a threshold with ``point``, draw the
         confusion matrix with ``cm`` and run ``bootstrap`` on the predictions.

    Parameters
    ----------
    test_loader : iterable yielding ``(x, y)`` batches.
    checkpoint : path (or file-like object) accepted by ``torch.load`` that
        holds a state dict for the model.
    folder : forwarded to ``cm``.
    depth, channels, dilation, stride, expansion, additional_inputs :
        architecture arguments forwarded to ``get_model``.
    ds : optional dataset used only for the FLOP count.

    Returns
    -------
    tuple
        ``(yhat_total, y_total, fpr, tpr, t)``: concatenated probabilities,
        concatenated labels, and the three arrays returned by ``roc_curve``.
    """
    # load model and checkpoint
    model = get_model(depth = depth,channels = channels,dilation = dilation,stride = stride,expansion = expansion, additional_inputs=additional_inputs)
    # map_location keeps a checkpoint that was saved from a GPU loadable on a
    # CPU-only machine; inference below runs on the CPU anyway.
    state_dict = torch.load(checkpoint, map_location='cpu')
    model.load_state_dict(state_dict)
    model.to('cpu')

    # Calculate FLOPs
    if ds is not None:
        dl = torch.utils.data.DataLoader(ds, batch_size=1,num_workers=10,drop_last=False)
        for x,y in dl:
            # a single sample is enough for the operation count
            count_ops(model,x)
            break

    # produce tensors
    y_total = torch.Tensor()
    yhat_total = torch.Tensor()
    # Inference only: without no_grad every batch would build (and keep alive)
    # an autograd graph that is thrown away by the detach() below.
    with torch.no_grad():
        for x,y in tqdm(test_loader):
            yhat = model(x)
            yhat,y = thresholded_output_transform(yhat,y)
            y_total = torch.cat((y_total,y),0)
            yhat_total = torch.cat((yhat_total,yhat.detach().cpu()),0)

    # Save predictions
    df = pd.DataFrame({'labels':y_total.flatten().tolist(),'Prediction':yhat_total.flatten().tolist()})

    # Produce ROC and CM
    fpr,tpr, t = roc_curve(y_total,yhat_total)
    thresh = point(y_total,yhat_total,t)
    cm(y_total.flatten(),yhat_total.flatten(),folder=folder,threshold = thresh)
    bootstrap(df)
    return yhat_total, y_total, fpr, tpr, t
def bootstrap(df, n_boot = 10000, sample_frac = 0.5, seed = None):
    """Bootstrap the ROC AUC and plot its distribution plus the ROC curves.

    Two figures are shown: a histogram of the resampled AUCs with the 2.5th
    and 97.5th percentiles marked, and the ROC curve of the full data drawn
    over (at most 1000 of) the resampled ROC curves.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``'labels'`` and ``'Prediction'``.
    n_boot : int, default 10000
        Number of resamples.
    sample_frac : float, default 0.5
        Each resample draws ``int(len(df) * sample_frac)`` rows with
        replacement.  NOTE(review): the default keeps the original behaviour
        of resampling half of the rows; a conventional bootstrap uses 1.0 and
        gives a narrower interval.
    seed : int or None, default None
        ``None`` draws from NumPy's global random state (as before, so a
        preceding ``np.random.seed`` call still applies); an integer makes
        the call reproducible on its own.

    Raises
    ------
    ValueError
        If the resample size is zero or no resample yields a defined AUC.
    """
    # Work on positional arrays: indexing a Series with an integer array is
    # label based and only works when the frame has a default RangeIndex.
    y_total = df['labels'].to_numpy()
    yhat_total = df['Prediction'].to_numpy()

    n_rows = len(yhat_total)
    sample_size = int(n_rows * sample_frac)
    if sample_size < 1:
        raise ValueError('bootstrap needs at least one row per resample')

    rng = np.random if seed is None else np.random.RandomState(seed)

    fpr_boot = []
    tpr_boot = []
    aucs = []

    # bootstrap for confidence interval
    for _ in tqdm(range(n_boot)):
        choices = rng.choice(n_rows, sample_size)
        fpr, tpr, _thresholds = roc_curve(y_total[choices], yhat_total[choices])
        fpr_boot.append(fpr)
        tpr_boot.append(tpr)
        aucs.append(auc(fpr, tpr))

    # A resample that contains a single class has no defined AUC (NaN).  Drop
    # those once so the mean, the percentiles and the histogram all agree.
    aucs = np.asarray(aucs, dtype=float)
    valid_aucs = aucs[np.isfinite(aucs)]
    if valid_aucs.size == 0:
        raise ValueError('no bootstrap resample produced a defined AUC')

    lower, upper = np.percentile(valid_aucs, [2.5, 97.5])
    lower_point = round(lower, 2)
    higher_point = round(upper, 2)
    mean_point = round(valid_aucs.mean(), 2)

    # Start a fresh figure so the histogram never lands on a previous plot.
    plt.figure()
    counts, _edges, _patches = plt.hist(valid_aucs, bins = 50, label = 'mean: '+str(mean_point))
    peak = max(counts)
    plt.plot([lower, lower], [0, peak], label = 'lower interval: '+str(lower_point))
    plt.plot([upper, upper], [0, peak], label = 'higher interval: '+str(higher_point))
    plt.title("AUC Histogram")
    plt.xlabel("AUC")
    plt.legend()
    plt.show()

    plt.figure()
    lw = 2
    # Draw at most 1000 resampled curves; fewer when n_boot is smaller.
    for i in range(min(1000, len(fpr_boot))):
        plt.plot(fpr_boot[i], tpr_boot[i], color='lightblue',
                 lw=lw)
    fpr, tpr, _thresholds = roc_curve(y_total, yhat_total)
    plt.plot(fpr, tpr, color='darkorange',
             lw=lw, label='ROC curve (area = %0.2f)' % auc(fpr, tpr))
    plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve')
    plt.legend(loc="lower right")
    plt.show()
def cm(y_total,yhat_total,Project_name = None,folder=None,threshold = 0.5):
    """Print summary metrics and show the confusion matrix at ``threshold``.

    Scores strictly greater than ``threshold`` count as positive.  The
    threshold, then PPV, NPV, specificity and sensitivity are printed, and a
    ``ConfusionMatrixDisplay`` labelled 'No Event' / 'Adverse Event' is shown.
    ``Project_name`` and ``folder`` are accepted but not used by this function.
    """
    print(threshold)

    predicted_positive = yhat_total>threshold
    matrix = confusion_matrix(y_total,predicted_positive)
    tn, fp, fn, tp = matrix.ravel()

    specificity = tn / (tn + fp)
    sensitivity = tp / (tp + fn)
    ppv = round(tp / (tp + fp), 2)
    npv = round(tn / (tn + fn), 2)
    print('Positive Predictive Value',ppv,'Negative Predictive Value', npv, ' Specificty ', specificity, 'Sensitivity ', sensitivity)

    display = ConfusionMatrixDisplay(
        confusion_matrix=matrix,
        display_labels=['No Event','Adverse Event'],
    )
    display.plot()
    plt.title("Confusion Matrix")
    plt.show()

In [None]:
# Locations of the evaluation data, passed to ECG_loader in the evaluation cell.
# NOTE(review): these are absolute paths tied to one machine; consider making
# them configurable before sharing the notebook.
# Val_root is the directory handed to ECG_loader as `root`.
Val_root = '/workspace/John/IntroECG-main/IntroECG-main/data/Definitely Not A Mistake/'
# Val_csv is the CSV handed to ECG_loader as `csv`; despite the `Val_` prefix the
# file is named Test_rcri_outcome.csv.
Val_csv = '/workspace/John/IntroECG-main/IntroECG-main/data/Test_rcri_outcome.csv'


# RCRINet on RCRI Outcome

In [None]:
# Evaluation settings: batch size, checkpoint file and the extra tabular
# columns that are fed to the network alongside the ECG.
bs = 2000
checkpoint = 'best_roc_model_Mortality_with_RCRI_Features.pt'
additional_inputs = [
    'CrGreaterThan2',
    'is_risk',
    'insulin',
    'cad',
    'chf',
    'stroke',
]

# Dataset and loader over the evaluation split.
test_ds = ECG_loader(
    root=Val_root,
    csv=Val_csv,
    sliding=False,
    downsample=1,
    additional_inputs=additional_inputs,
)
val_dataloader = torch.utils.data.DataLoader(
    test_ds,
    batch_size=bs,
    num_workers=10,
    drop_last=False,
)

# x receives the tuple returned by produce_df:
# (yhat_total, y_total, fpr, tpr, t).
x = produce_df(
    val_dataloader,
    checkpoint,
    folder='Spare',
    stride=8,
    dilation=2,
    additional_inputs=len(additional_inputs),
)
