# setup

In [1]:
import os

# When launched from the notebooks/ directory, move up to the project root so
# the project-local imports (hparams, models, ...) and the relative checkpoint
# path used below resolve. os.path.basename is separator-agnostic, unlike
# splitting on '/', so this also works on Windows.
if os.path.basename(os.getcwd()) == 'notebooks':
    os.chdir('..')

In [2]:
import torch
import torch.nn as nn
import neurokit2 as nk
import numpy as np

from tqdm import tqdm

from hparams import BATCH_SIZE, NUM_WORKERS
from models.baseline import ResnetBaseline
from runners.train import Runner

In [3]:
from dataloaders.code_draft import CODE as DS
from dataloaders.code_draft import CODEsplit as DSsplit

# init

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

database = DS()
# The checkpoint appears to be a fully pickled nn.Module (it is used directly
# as `model` below), so constructing a fresh ResnetBaseline first was dead work.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
# On torch >= 2.6 a full-module checkpoint may also need weights_only=False.
model = torch.load('output/code/code.pt', map_location = device)

In [5]:
# Project runner: exposes the validation/test datasets (runner.val_ds,
# runner.tst_ds) used in the eval section.
runner = Runner(
    device = device,
    model = model,
    database = database,
    split = DSsplit,
    model_label = 'code',
)

using test ds, H is treated as X


# utils

In [6]:
# NOTE(review): these two constants are not referenced in this notebook
# section; kept in case other cells/modules rely on them.
SIGNAL_CROP_LEN = 2560
SIGNAL_NON_ZERO_START = 571

def get_inputs(batch, device = "cuda", fs = 400):
    """Clean every lead of a batch of raw ECG signals and move it to `device`.

    Each (sample, lead) trace is passed through ``nk.ecg_clean`` at sampling
    rate `fs`. The result is written into a new float32 tensor, so `batch` is
    never modified and integer inputs are not truncated.

    Args:
        batch: 3-D tensor, either (B, C, L) or (B, L, C). If axis 1 is longer
            than axis 2 it is assumed to be (B, L, C) and permuted to (B, C, L).
        device: device the returned tensor is moved to.
        fs: sampling rate in Hz passed to ``nk.ecg_clean`` (default 400).

    Returns:
        float32 tensor of shape (B, C, L) on `device`.
    """
    # Layout heuristic: assumes there are fewer leads than time samples.
    if batch.shape[1] > batch.shape[2]:
        batch = batch.permute(0, 2, 1)

    signals = batch.detach().cpu().numpy()
    cleaned = torch.empty(signals.shape, dtype = torch.float32)
    for i in range(signals.shape[0]):
        for j in range(signals.shape[1]):
            trace = nk.ecg_clean(signals[i, j], sampling_rate = fs)
            # ascontiguousarray: torch cannot wrap arrays with negative strides
            cleaned[i, j] = torch.as_tensor(np.ascontiguousarray(trace))

    return cleaned.to(device)

In [7]:
from sklearn.metrics import f1_score

def find_best_thresholds(predictions, true_labels_dict, thresholds):
    """Pick, for each class, the candidate threshold with the highest F1 score.

    Args:
        predictions: mapping ``thresh -> per-class list`` where
            ``predictions[thresh][c]`` holds the binary predictions for class
            `c` at threshold `thresh`.
        true_labels_dict: per-class sequence of ground-truth labels; despite
            its name it is indexed by class position.
        thresholds: iterable of candidate thresholds; each must be a key of
            `predictions`.

    Returns:
        (best_f1s, best_thresholds): two lists with one entry per class. A
        class whose F1 is 0 at every threshold keeps F1 0.0 and the fallback
        threshold 0.5. On ties the earliest threshold in `thresholds` wins.
    """
    # Count classes from the labels: indexing `predictions[0]` raises KeyError
    # whenever 0 is not one of the candidate thresholds.
    num_classes = len(true_labels_dict)
    best_thresholds = [0.5] * num_classes  # fallback if no threshold beats F1 = 0
    best_f1s = [0.0] * num_classes

    for class_idx in range(num_classes):
        for thresh in thresholds:
            f1 = f1_score(
                true_labels_dict[class_idx],
                predictions[thresh][class_idx],
                zero_division=0,
            )
            # strict '>' keeps the earliest threshold on ties
            if f1 > best_f1s[class_idx]:
                best_f1s[class_idx] = f1
                best_thresholds[class_idx] = thresh

    return best_f1s, best_thresholds

In [8]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_auc_score

def _binary_scores(y_true, y_pred):
    """Score one binary task; returns a dict keyed by report column name."""
    return {
        "Accuracy": accuracy_score(y_true, y_pred),
        "Precision": precision_score(y_true, y_pred, zero_division=0),
        "Recall": recall_score(y_true, y_pred, zero_division=0),
        "F1 Score": f1_score(y_true, y_pred, zero_division=0),
        # roc_auc_score raises ValueError when y_true holds a single class
        # (e.g. a rare label absent from the split); report NaN instead.
        "AUC ROC": (
            roc_auc_score(y_true, y_pred)
            if np.unique(y_true).size > 1
            else float("nan")
        ),
    }


def metrics_table(all_binary_results, all_true_labels,
                  report=("Accuracy", "F1 Score", "AUC ROC")):
    """Per-class and "normal" binary metrics for thresholded multi-label output.

    Args:
        all_binary_results: (N, C) tensor of binary (0/1) predictions.
        all_true_labels: (N, C) tensor of ground-truth labels (expected 0/1).
        report: metric names to include, in output order; any of
            "Accuracy", "Precision", "Recall", "F1 Score", "AUC ROC".

    Returns:
        dict mapping each name in `report` to a list of C + 1 scores: one per
        class (column order), followed by a derived "normal" task in which a
        sample is positive when none of its C columns is set.

    Note:
        "AUC ROC" is computed from the hard 0/1 predictions, not from
        probabilities, so it reflects a single operating point.
    """
    num_classes = all_binary_results.shape[-1]
    tasks = [
        (all_true_labels[:, class_idx], all_binary_results[:, class_idx])
        for class_idx in range(num_classes)
    ]
    # "normal": positive iff the row sums to zero, i.e. no class flagged
    tasks.append((
        (~torch.sum(all_true_labels, dim=1).bool()).int(),
        (~torch.sum(all_binary_results, dim=1).bool()).int(),
    ))

    per_task = [
        _binary_scores(y_true.cpu().numpy(), y_pred.cpu().numpy())
        for y_true, y_pred in tasks
    ]
    return {name: [scores[name] for scores in per_task] for name in report}

In [9]:
def synthesis(model, device, loader, best_thresholds = None, thresholds = None):
    """Run `model` over `loader`, then either tune or apply per-class thresholds.

    Two modes, selected by `best_thresholds`:

    * ``best_thresholds is None`` (tuning): sweep every value in `thresholds`
      for each class and return ``(best_f1s, best_thresholds)`` from
      `find_best_thresholds`.
    * otherwise (evaluation): binarise sigmoid probabilities with the given
      per-class thresholds and return
      ``(all_binary_results, all_true_labels, metrics_table(...))``.

    Args:
        model: network mapping the output of `get_inputs` to per-class logits.
        device: torch device inputs and labels are moved to.
        loader: iterable of dict batches with keys 'X' (raw signal) and 'y'
            (labels).
        best_thresholds: optional sequence of per-class thresholds. Classes
            beyond its length are predicted 0.
        thresholds: candidate thresholds for the tuning sweep; defaults to
            ``np.arange(0, 1.01, 0.01)``. Ignored in evaluation mode.

    Raises:
        ValueError: if `loader` yields no batches.
    """
    all_probs = []
    all_true_labels = []

    model.eval()
    with torch.no_grad():
        for batch in tqdm(loader):
            ecg = get_inputs(batch['X'], device = device)
            all_true_labels.append(batch['y'].to(device).float())
            all_probs.append(torch.sigmoid(model(ecg)))

    if not all_probs:
        raise ValueError("synthesis: loader yielded no batches")
    # Collect probabilities once; binarising afterwards avoids storing one
    # prediction copy per candidate threshold during the loop.
    probs = torch.cat(all_probs, dim=0)
    all_true_labels = torch.cat(all_true_labels, dim=0)

    if best_thresholds is None:
        if thresholds is None:
            thresholds = np.arange(0, 1.01, 0.01)  # 0.00, 0.01, ..., 1.00
        probs_cpu = probs.cpu()
        labels_cpu = all_true_labels.cpu()
        num_classes = probs_cpu.shape[1]  # follows the model output width
        predictions = {
            thresh: [
                (probs_cpu[:, class_idx] >= thresh).float().numpy()
                for class_idx in range(num_classes)
            ]
            for thresh in thresholds
        }
        true_labels_dict = [
            labels_cpu[:, class_idx].numpy() for class_idx in range(num_classes)
        ]
        best_f1s, best_thresholds = find_best_thresholds(predictions, true_labels_dict, thresholds)
        return best_f1s, best_thresholds

    # Evaluation mode: one vectorised comparison against per-class cutoffs.
    cutoffs = torch.as_tensor(best_thresholds, dtype = probs.dtype, device = probs.device)
    n_cut = cutoffs.shape[0]
    all_binary_results = torch.zeros_like(probs)
    all_binary_results[:, :n_cut] = (probs[:, :n_cut] >= cutoffs).float()
    return all_binary_results, all_true_labels, metrics_table(all_binary_results, all_true_labels)

# eval

In [10]:
# Per-class F1 scores and the thresholds that achieved them, presumably saved
# from an earlier `synthesis(model, device, val_dl)` tuning run — TODO confirm.
# Values are kept verbatim (full float repr) so reused thresholds are
# bit-identical to the tuned ones.
old_f1s = [
    0.5957446808510638, 0.8368794326241135, 0.8463611859838275,
    0.7044025157232704, 0.8435754189944135, 0.7948717948717948,
]
old_thresholds = [
    0.15, 0.47000000000000003, 0.6900000000000001,
    0.39, 0.44, 0.42,
]

In [11]:
# Set to True to re-tune per-class thresholds on the validation split;
# otherwise the stored `old_thresholds` are applied to the test split.
TUNE_THRESHOLDS = False

model = model.to(device)
val_dl = torch.utils.data.DataLoader(runner.val_ds, batch_size = BATCH_SIZE,
                                     shuffle = False, num_workers = NUM_WORKERS)
tst_dl = torch.utils.data.DataLoader(runner.tst_ds, batch_size = BATCH_SIZE,
                                     shuffle = False, num_workers = NUM_WORKERS)

if TUNE_THRESHOLDS:
    best_f1s, best_thresholds = synthesis(model, device, val_dl, best_thresholds = None)
else:
    best_thresholds = old_thresholds

all_binary_results, all_true_labels, metrics_dict = synthesis(model, device, tst_dl, best_thresholds)

100%|██████████| 7/7 [00:17<00:00,  2.56s/it]


In [12]:
# Each metric lists one entry per class (label-column order), followed by the
# derived "normal" entry (samples with no class flagged).
metrics_dict

{'Accuracy': [0.9939540507859734,
  0.9939540507859734,
  0.9987908101571947,
  0.992744860943168,
  0.9951632406287787,
  0.992744860943168,
  0.9758162031438936],
 'F1 Score': [0.9090909090909091,
  0.9275362318840579,
  0.983050847457627,
  0.8421052631578948,
  0.8333333333333333,
  0.9166666666666667,
  0.9853157121879589],
 'AUC ROC': [0.945177006973002,
  0.9686966842222386,
  0.9833333333333334,
  0.9963008631319359,
  0.8840011340011339,
  0.9446801231611358,
  0.9584112807515136]}