In [61]:
# --- Evaluation inputs (edit these for your environment) ---

# Root location of the .npy files; passed to the dataset as `root`.
root = '/path/to/npy/file/root'

# CSV holding the filenames and their labels; passed to the dataset as `csv`.
test_csv = '/path/to/filename/and/labels/csv'

# Name of the CSV column that stores the binary CKD / no-CKD label.
target = 'label'

In [57]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn
import sklearn.metrics
import torch
import wandb
from tqdm import tqdm

import models
from constants import MODEL_NAME, OPTIMIZERS, MAX_EPOCHS, METRIC, MIN_LR, MAX_LR, PATIENCE, SCHEDULER, STEP, GAMMA, PRUNER, NUM_TRIALS, DIRECTION
from ECG import EchoECG
from paths import FilePaths
from tuningfunctions import get_data_loaders, get_criterion, Objective, run_trials

In [58]:
# Ask PyTorch's CUDA caching allocator to release cached, currently unused
# GPU memory (useful when re-running the notebook in a long-lived kernel).
torch.cuda.empty_cache()

In [59]:
def collate_fn(batch):
    """Collate a batch after discarding samples that are ``None``.

    Args:
        batch: Sequence of samples produced by the dataset. Entries that are
            ``None`` (presumably samples the dataset could not load -- confirm
            against the dataset implementation) are dropped.

    Returns:
        The result of PyTorch's ``default_collate`` applied to the remaining
        samples. If every entry is ``None``, an empty list is handed to
        ``default_collate`` unchanged.
    """
    # Use an identity check: `sample != None` is an equality comparison, which
    # array-like samples (e.g. NumPy arrays) evaluate elementwise, making the
    # filter predicate ambiguous or wrong.
    kept = [sample for sample in batch if sample is not None]
    return torch.utils.data.dataloader.default_collate(kept)

In [62]:
# Instantiate the EffNet architecture with the configuration expected by the
# saved weights loaded below. `start_channels=12` presumably corresponds to a
# 12-lead ECG input -- confirm against models.EffNet.
model = models.EffNet(
    channels=[32, 16, 24, 40, 80, 112, 192, 320, 1280, 1],
    dilation=2,
    stride=8,
    reg=False,
    start_channels=12,
)

In [63]:
# Load the trained weights into the model. `map_location='cpu'` lets a
# checkpoint that was saved from GPU tensors load on a machine without CUDA;
# load_state_dict then copies the values into the model's own parameters.
model.load_state_dict(torch.load('twelve_lead_weights.pt', map_location='cpu'))

In [64]:
# Switch the model to evaluation mode before running inference
# (presumably a torch.nn.Module, where this disables dropout and makes
# batch-norm use its running statistics -- confirm against models.EffNet).
model.eval()

In [67]:

# Build the evaluation dataset from the configured root directory and CSV.
test_ds = EchoECG(
    root=root,
    csv=test_csv,
    model='RCRI_Net',
    rolling=0,
    downsample=1,
    target=target,
    one_lead=False,
    return_filename=False,
)
print(len(test_ds))

# Batch size and the loader that feeds the inference loop. The final partial
# batch is kept (drop_last=False) so every sample is evaluated, and
# collate_fn drops samples that the dataset returned as None.
bs = 2000
test_dataloader = torch.utils.data.DataLoader(
    test_ds,
    batch_size=bs,
    num_workers=40,
    drop_last=False,
    collate_fn=collate_fn,
)

In [68]:
# Run the model over the whole test set and gather scores and labels.
# Each batch is flattened to 1-D on the CPU and everything is concatenated
# once at the end, so `all_preds` / `all_labels` are plain NumPy arrays
# rather than Python lists of individual tensor objects.
pred_chunks = []
label_chunks = []
with torch.no_grad():  # inference only: no autograd bookkeeping needed
    for ecg, labels in tqdm(test_dataloader):
        pred_chunks.append(model(ecg).detach().cpu().reshape(-1))
        label_chunks.append(torch.as_tensor(labels).detach().cpu().reshape(-1))

# Guard against an empty loader: torch.cat cannot concatenate an empty list.
all_preds = torch.cat(pred_chunks).numpy() if pred_chunks else np.empty(0)
all_labels = torch.cat(label_chunks).numpy() if label_chunks else np.empty(0)
        

In [70]:
# Convert labels and scores to flat NumPy arrays before handing them to
# scikit-learn. float() accepts 0-d / one-element tensors as well as NumPy
# scalars, so this works whether the inputs are lists of tensors or arrays.
y_true = np.asarray([float(v) for v in all_labels])
y_score = np.asarray([float(v) for v in all_preds])

# Hard predictions: a score strictly above 0.5 counts as the positive class.
y_pred = (y_score > 0.5).astype(int)

cm = sklearn.metrics.confusion_matrix(y_true, y_pred)
cm_display = sklearn.metrics.ConfusionMatrixDisplay(cm).plot()
cm_display.ax_.set_title('Confusion matrix (threshold 0.5)')

fpr, tpr, thresholds = sklearn.metrics.roc_curve(y_true, y_score)
roc_auc = sklearn.metrics.auc(fpr, tpr)
# Passing roc_auc lets the display show the AUC in the plot legend.
roc_display = sklearn.metrics.RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc).plot()
roc_display.ax_.set_title('ROC curve')
print(roc_auc)