In [87]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn
import sklearn.metrics
import torch
import wandb
from tqdm import tqdm

import models
from constants import MODEL_NAME, OPTIMIZERS, MAX_EPOCHS, METRIC, MIN_LR, MAX_LR, PATIENCE, SCHEDULER, STEP, GAMMA, PRUNER, NUM_TRIALS, DIRECTION
from ECG import EchoECG
from paths import FilePaths
from tuningfunctions import get_data_loaders, get_criterion, Objective, run_trials

In [88]:
# Ask PyTorch's CUDA caching allocator to release cached memory it is not using.
# NOTE(review): no visible cell moves the model or the batches to a GPU —
# confirm this call is still needed.
torch.cuda.empty_cache()

In [89]:
def collate_fn(batch):
    """Collate a batch after discarding samples that are ``None``.

    Args:
        batch: list of samples handed over by the DataLoader; any entry that
            is ``None`` is skipped.

    Returns:
        The result of ``default_collate`` applied to the remaining samples.

    Note:
        If every sample in ``batch`` is ``None`` the list passed to
        ``default_collate`` is empty; that case is not handled here.
    """
    # Identity test instead of ``!= None``: ``!=`` is dispatched to the sample's
    # own ``__ne__``, which for array-like samples is element-wise and has no
    # single truth value, so the filter itself could raise.
    kept = [sample for sample in batch if sample is not None]
    return torch.utils.data.dataloader.default_collate(kept)

In [90]:
    # Commented-out name -> directory entries for other dataset variants.
    # Kept for reference only: nothing in the visible cells reads them.
    #'1 heartbeat': '/workspace/data/drives/Internal_SSD/sdd/data/Limited Datasets/2008-2020 EKG QRS Numpy',
    #'P Include': '/workspace/data/drives/Internal_SSD/sdd/data/Limited Datasets/P_include',
    #'P Exclude':'/workspace/data/drives/Internal_SSD/sdd/data/Limited Datasets/P_exclude',
    #'QRS Include':'/workspace/data/drives/Internal_SSD/sdd/data/Limited Datasets/QRS_include',
    #'QRS Exclude':'/workspace/data/drives/Internal_SSD/sdd/data/Limited Datasets/QRS_exclude',
    #'T Include':'/workspace/data/drives/Internal_SSD/sdb/data/Limited Datasets/T_include',
    #'T Exclude':'/workspace/data/drives/Internal_SSD/sdb/data/Limited Datasets/T_exclude'

In [91]:
# Settings forwarded to the EchoECG dataset constructor further down.
# Directory handed to EchoECG as `root`.
root = '/workspace/data/drives/Local_SSD/sdd/data/Remade with New Coefficents'
# Value handed to EchoECG as `model`.
model_origin = 'RCRI_Net'
# Value handed to EchoECG as `target`.
target = 'label'

In [92]:
# Directory prefix for the test CSVs; the trailing slash matters because the
# dataset cell builds its path by plain string concatenation.
root_dir = '/workspace/data/drives/Local_SSD/sdc/kidney_disease/DefinitiveAllStagesData/'

# Test CSV file names: one young/old pair per stage, then the two combined sets.
# NOTE(review): `extensions` is not referenced by any visible cell.
extensions = [
    f'{stage_name}_test_{age_group}.csv'
    for stage_name in ('Mild', 'Moderate', 'ESRD')
    for age_group in ('young', 'old')
] + [
    'all_stages_test.csv',
    'under_60_years_old_subset_test.csv',
]

In [93]:
# Build the EffNet instance that the checkpoint-loading cell fills with weights.
effnet_channels = [32, 16, 24, 40, 80, 112, 192, 320, 1280, 1]
model = models.EffNet(
    channels=effnet_channels,
    dilation=2,
    stride=8,
    reg=False,
    start_channels=1,
    # num_additional_features=2,  # left disabled, as in the original call
)

In [94]:
# Restore the saved weights into `model`; the file is passed straight to
# load_state_dict, so it is expected to contain a state dict.
checkpoint_path = '/workspace/kai/phecode/Training/CLASSIFIER_DEFINITIVE_ONE_LEAD/ONE_LEAD/best_roc_model_55_val_roc=0.7121.pt'

# map_location='cpu': without it torch.load tries to put every tensor back on
# the device it was saved from, which fails on a machine lacking that device
# and needlessly occupies GPU memory when `model` itself lives on the CPU.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
state_dict = torch.load(checkpoint_path, map_location='cpu')

# Kept as the last expression so the cell displays load_state_dict's report of
# missing / unexpected keys.
model.load_state_dict(state_dict)

In [95]:
# Put the network into evaluation mode before any inference cell runs.
model.eval()

# NOTE(review): `test_csv` and `expand_icd_codes` are not defined in any
# visible cell — they must already exist in the kernel. Confirm their origin,
# otherwise this cell fails after Restart & Run All.
test_ = pd.read_csv(test_csv)

# One ICD-9 / ICD-10 code group per output file; names[i] labels stages[i].
mild_codes = ['585.1', '585.2', 'N18.1', 'N18.2']
moderate_codes = ['585.3', '585.4', '585.5', 'N18.3', 'N18.4', 'N18.5']
esrd_codes = ['585.6', 'N18.6']
stages = [
    expand_icd_codes(mild_codes),
    expand_icd_codes(moderate_codes),
    expand_icd_codes(esrd_codes),
]
names = ['mild', 'moderate', 'ESRD']

for i, stage in enumerate(stages):
    # A row is kept when its ICD-9 or its ICD-10 field, rendered as text, is
    # one of this stage's codes or the literal 'nan' (presumably the text that
    # astype(str) produces for missing entries — TODO confirm).
    icd9_match = test_.ICD9_CD_LIST.astype(str).isin([*stage, 'nan'])
    icd10_match = test_.ICD10_CD_LIST.astype(str).isin([*stage, 'nan'])
    temp_df = test_[icd9_match | icd10_match].copy()
    temp_df.to_csv(f'/workspace/data/drives/Local_SSD/sdc/kidney_disease/ckd_data_negative_sensitivity/only_{names[i]}_test.csv')

In [96]:

# Dataset over the all-stages test CSV, configured for single-lead input and
# asked to return each sample's file name along with the signal and label.
dataset_options = dict(
    root=root,
    csv=root_dir + 'all_stages_test.csv',
    model=model_origin,
    rolling=0,
    downsample=1,
    target=target,
    one_lead=True,
    return_filename=True,
)
test_ds = EchoECG(**dataset_options)
print(len(test_ds))

# Batches are built in dataset order (no shuffle argument is given), the last
# partial batch is kept, and collate_fn drops samples that came back as None.
bs = 2000
test_dataloader = torch.utils.data.DataLoader(
    test_ds,
    batch_size=bs,
    num_workers=40,
    drop_last=False,
    collate_fn=collate_fn,
)

In [97]:
# Accumulators with one entry per evaluated sample, in dataloader order.
all_labels, all_preds, all_fnames = [], [], []

# Gradients are not needed for evaluation.
with torch.no_grad():
    for ecg, labels, fnames in tqdm(test_dataloader):
        # extend() iterates each batch along its first axis, so every sample
        # contributes its own element to the running lists.
        all_preds.extend(model(ecg))
        all_labels.extend(labels)
        all_fnames.extend(fnames)

In [98]:
# Sanity check: the three accumulators should report the same length
# (printed one per line: labels, predictions, file names).
print(len(all_labels), len(all_preds), len(all_fnames), sep='\n')

In [99]:
# Flatten the per-sample tensors collected by the inference loop into 1-D NumPy
# arrays before handing them to scikit-learn, instead of relying on its
# implicit (and slow) conversion of Python lists of tensors.
y_true = torch.stack(list(all_labels)).reshape(-1).cpu().numpy()
y_score = torch.stack(list(all_preds)).reshape(-1).cpu().numpy()

# Confusion matrix at a fixed 0.5 cut-off.
# NOTE(review): assumes the model output is a probability — confirm, since the
# visible cells do not show whether a sigmoid is applied.
cm = sklearn.metrics.confusion_matrix(y_true, y_score > 0.5)
sklearn.metrics.ConfusionMatrixDisplay(cm).plot()

# ROC curve over the raw scores, and the area under it.
fpr, tpr, thresholds = sklearn.metrics.roc_curve(y_true, y_score)
sklearn.metrics.RocCurveDisplay(fpr=fpr, tpr=tpr).plot()
print(sklearn.metrics.auc(fpr, tpr))

In [107]:
# One row per evaluated sample: file name, model output, and label.
# The tensors are flattened and converted to NumPy explicitly so the column
# dtypes are numeric and do not depend on how pandas coerces torch tensors.
one_lead_scores = torch.stack(list(all_preds)).reshape(-1).detach().cpu().numpy()
ckd_labels = torch.stack(list(all_labels)).reshape(-1).detach().cpu().numpy()

all_predictions_test = pd.DataFrame({
    'Filename': all_fnames,
    'one_lead_prediction': one_lead_scores,
    'has_ckd': ckd_labels,
})

In [108]:
# Load the CSV that the one-lead predictions get attached to
# (presumably the 12-lead model's test-set results, going by the file name).
twelve_lead_csv = '/workspace/data/drives/Local_SSD/sdc/kidney_disease/DefinitiveAllStagesData/definitive_12_lead_eval_test.csv'
twelve_lead_preds = pd.read_csv(twelve_lead_csv)

In [109]:
# Display the frame as loaded from CSV, before the one-lead column is attached.
twelve_lead_preds

In [110]:
# Display the one-lead predictions table (Filename, one_lead_prediction, has_ckd).
all_predictions_test

In [111]:
# Assigning a Series to a column aligns on index labels: rows whose label is
# missing from the Series become NaN and surplus Series rows are dropped, all
# without any error. Both frames carry a default positional index here, so a
# row-count mismatch (e.g. samples dropped by collate_fn) would silently
# corrupt the pairing — fail loudly instead.
n_twelve_lead, n_one_lead = len(twelve_lead_preds), len(all_predictions_test)
assert n_twelve_lead == n_one_lead, (
    f'row count mismatch: twelve_lead_preds has {n_twelve_lead} rows, '
    f'all_predictions_test has {n_one_lead}'
)
# NOTE(review): equal length does not prove equal order; if twelve_lead_preds
# also has a file-name column, merging on it would be safer — confirm schema.
twelve_lead_preds['one_lead_prediction'] = all_predictions_test.one_lead_prediction

In [112]:
# Display the frame again, now including the one_lead_prediction column.
twelve_lead_preds

In [114]:
# Persist the combined predictions. The index is written as the first column
# because to_csv is called with its default settings.
predictions_csv = '/workspace/data/drives/Local_SSD/sdc/kidney_disease/DefinitiveAllStagesData/definitive_model_predictions.csv'
twelve_lead_preds.to_csv(predictions_csv)