In [1]:
import numpy as np
import pandas as pd
import wfdb
import ast
import glob
from sklearn.preprocessing import StandardScaler, MultiLabelBinarizer
import os

from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
import math
import tensorflow as tf
from tensorflow.keras import initializers, regularizers, constraints
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler,normalize, MinMaxScaler
from scipy.signal import spectrogram, resample
from matplotlib.collections import LineCollection

import os
import wandb
from sklearn.metrics import roc_auc_score, classification_report, accuracy_score
from wandb.keras import WandbCallback
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.autograd import Variable 
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.nn.parameter import Parameter

import random
from tqdm import tqdm
import wandb
from sklearn.metrics import cohen_kappa_score, accuracy_score,f1_score
from sklearn.utils import shuffle

In [3]:
### Preprocessing 
#   Using the super classes, multi label classification, excluding samples with no labels and considering atleast one label

path = 'C:/ptb/'
Y = pd.read_csv(path+ 'ptbxl_database.csv', index_col = 'ecg_id')



data = np.array([wfdb.rdsamp(path+f)[0] for f in Y.filename_lr])
Y.scp_codes = Y.scp_codes.apply(lambda x: ast.literal_eval(x))
    
agg_df = pd.read_csv(path+ 'scp_statements.csv', index_col = 0)

agg_df = agg_df[agg_df.diagnostic == 1]

def agg(y_dic):
    temp =[]
    
    for key in y_dic.keys():
        if key in agg_df.index:
            c = agg_df.loc[key].diagnostic_class
            if str(c) != 'nan':
                temp.append(c)
    return list(set(temp))

Y['diagnostic_superclass'] = Y.scp_codes.apply(agg)
Y['superdiagnostic_len'] = Y['diagnostic_superclass'].apply(lambda x: len(x))


#########

counts = pd.Series(np.concatenate(Y.diagnostic_superclass.values)).value_counts()

Y['diagnostic_superclass'] = Y['diagnostic_superclass'].apply(lambda x: list(set(x).intersection(set(counts.index.values))))

X_data = data[Y['superdiagnostic_len'] >= 1]
Y_data = Y[Y['superdiagnostic_len'] >= 1]

mlb = MultiLabelBinarizer()
mlb.fit(Y_data['diagnostic_superclass'])
y = mlb.transform(Y_data['diagnostic_superclass'].values)

########

## Stratify split

X_train = X_data[Y_data.strat_fold < 9]
y_train = y[Y_data.strat_fold < 9]

X_val = X_data[Y_data.strat_fold == 9]
y_val = y[Y_data.strat_fold == 9]

X_test = X_data[Y_data.strat_fold == 10]
y_test = y[Y_data.strat_fold == 10]

del X_data, Y_data, y

In [4]:
# Standardizing

def apply_scaler(X, scaler):
    X_tmp = []
    for x in X:
        x_shape = x.shape
        X_tmp.append(scaler.transform(x.flatten()[:,np.newaxis]).reshape(x_shape))
    X_tmp = np.array(X_tmp)
    return X_tmp


scaler = StandardScaler()

scaler.fit(np.vstack(X_train).flatten()[:,np.newaxis].astype(float))

X_train_scale = apply_scaler(X_train, scaler)
X_test_scale = apply_scaler(X_test, scaler)
X_val_scale = apply_scaler(X_val, scaler)

del X_train, X_test, X_val


In [5]:
class DataGen(Dataset):
    def __init__(self, X, y,batch_size = 16):
        self.batch_size = batch_size
        self.X = X
        self.y = y
        
    def __len__(self):
        return math.ceil(len(self.X) / self.batch_size)
    
    def __getitem__(self,idx):
        
        batch_x = self.X[idx * self.batch_size:(idx + 1) *self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) *self.batch_size]
        
        return torch.tensor(batch_x, dtype = torch.float32), torch.tensor(batch_y, dtype = torch.float32)
    
## Params

batch_size = 16
    
train_gen = DataGen(X_train_scale, y_train, batch_size = batch_size)
test_gen = DataGen(X_test_scale, y_test, batch_size = batch_size)

In [6]:
class ResBlock(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, downsample=None):
        super(ResBlock, self).__init__()
        self.bn1 = nn.BatchNorm1d(num_features=in_channels)
        self.relu = nn.ReLU(inplace=False)
        self.dropout = nn.Dropout(p=0.1, inplace=False)
        self.conv1 = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                               stride=stride, padding=padding, bias=False)
        self.bn2 = nn.BatchNorm1d(num_features=out_channels)
        self.conv2 = nn.Conv1d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size,
                               stride=stride, padding=padding, bias=False)
        self.maxpool = nn.MaxPool1d(kernel_size=2, stride=2, padding=0)
        self.downsample = downsample




    def forward(self, x):
        identity = x

        out = self.bn1(x)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.conv2(out)

        if self.downsample is not None:
            out = self.maxpool(out)
            identity = self.downsample(x)

        out += identity
        # print(out.shape)

        return out




class ECGNet(nn.Module):

    def __init__(self, struct=[15, 17, 19, 21], in_channels=12, fixed_kernel_size=17, num_classes=5):
        super(ECGNet, self).__init__()
        self.struct = struct
        self.planes = 16
        self.parallel_conv = nn.ModuleList()

        for i, kernel_size in enumerate(struct):
            sep_conv = nn.Conv1d(in_channels=in_channels, out_channels=self.planes, kernel_size=kernel_size,
                               stride=1, padding=0, bias=False)
            self.parallel_conv.append(sep_conv)
        # self.parallel_conv.append(nn.Sequential(
        #     nn.MaxPool1d(kernel_size=2, stride=2, padding=0),
        #     nn.Conv1d(in_channels=1, out_channels=self.planes, kernel_size=1,
        #                        stride=1, padding=0, bias=False)
        # ))

        self.bn1 = nn.BatchNorm1d(num_features=self.planes)
        self.relu = nn.ReLU(inplace=False)
        self.conv1 = nn.Conv1d(in_channels=self.planes, out_channels=self.planes, kernel_size=fixed_kernel_size,
                               stride=2, padding=2, bias=False)
        self.block = self._make_layer(kernel_size=fixed_kernel_size, stride=1, padding=8)
        self.bn2 = nn.BatchNorm1d(num_features=self.planes)
        self.avgpool = nn.AvgPool1d(kernel_size=8, stride=8, padding=2)
        self.rnn = nn.LSTM(input_size=12, hidden_size=40, num_layers=1, bidirectional=False)
        self.fc = nn.Linear(in_features=168, out_features=num_classes)


    def _make_layer(self, kernel_size, stride, blocks=15, padding=0):
        layers = []
        downsample = None
        base_width = self.planes

        for i in range(blocks):
            if (i + 1) % 4 == 0:
                downsample = nn.Sequential(
                    nn.Conv1d(in_channels=self.planes, out_channels=self.planes + base_width, kernel_size=1,
                               stride=1, padding=0, bias=False),
                    nn.MaxPool1d(kernel_size=2, stride=2, padding=0)
                )
                layers.append(ResBlock(in_channels=self.planes, out_channels=self.planes + base_width, kernel_size=kernel_size,
                                       stride=stride, padding=padding, downsample=downsample))
                self.planes += base_width
            elif (i + 1) % 2 == 0:
                downsample = nn.Sequential(
                    nn.MaxPool1d(kernel_size=2, stride=2, padding=0)
                )
                layers.append(ResBlock(in_channels=self.planes, out_channels=self.planes, kernel_size=kernel_size,
                                       stride=stride, padding=padding, downsample=downsample))
            else:
                downsample = None
                layers.append(ResBlock(in_channels=self.planes, out_channels=self.planes, kernel_size=kernel_size,
                                       stride=stride, padding=padding, downsample=downsample))

        return nn.Sequential(*layers)



    def forward(self, x):
        out_sep = []

        for i in range(len(self.struct)):
            sep = self.parallel_conv[i](x)
            out_sep.append(sep)

        out = torch.cat(out_sep, dim=2)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv1(out)  # out => [b, 16, 9960]

        out = self.block(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.avgpool(out)  # out => [b, 64, 10]
        out = out.reshape(out.shape[0], -1)  # out => [b, 640]

        rnn_out, (rnn_h, rnn_c) = self.rnn(x.permute(2, 0, 1))
        new_rnn_h = rnn_h[-1, :, :]  # rnn_h => [b, 40]

        new_out = torch.cat([out, new_rnn_h], dim=1)  # out => [b, 680]
        
        result = self.fc(new_out)  # out => [b, 20]

        # print(out.shape)

        return result

model = ECGNet()

In [6]:
if not os.path.exists('ECGNet_saves'):
    os.mkdir('ECGNet_saves')

In [7]:
def metrics(y_true, y_scores):
    y_pred = y_scores >= 0.5
    acc = np.zeros(y_pred.shape[-1])
    
    roc_auc = roc_auc_score(y_true, y_scores, average = 'macro')
    
    for i in range(y_pred.shape[-1]):
        acc[i] = accuracy_score(y_true[:,i], y_pred[:,i])
    return acc, np.mean(acc), roc_auc

wandb.init(project = 'BaseECG', name = 'ECGNet')

Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable
[34m[1mwandb[0m: Currently logged in as: [33mlikith012[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.22 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [8]:
def train_model(model, optimizer, loss_func, dataset, epoch):

    model.train()
    
    pred_all = []
    loss_all = []
    gt_all = []
    
    for batch_step in tqdm(range(len(dataset)) , desc="train"):
        batch_x, batch_y = dataset[batch_step]    
        batch_x = batch_x.cuda()
        batch_x = batch_x.permute(0,2,1)
        batch_y = batch_y.cuda()

        pred = model(batch_x)
        pred_all.append(pred.cpu().detach().numpy())
        
        loss = loss_func(pred, batch_y)
        loss_all.append(loss.cpu().detach().item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        gt_all.extend(batch_y.cpu().detach().numpy())

    print('epoch {0} '.format(epoch))
    print('train_loss ', np.mean(loss_all))
    pred_all = np.concatenate(pred_all, axis=0)

    _, mean_acc, roc_score = metrics(np.array(gt_all), pred_all )
    wandb.log({'train_mean_accuracy' : mean_acc, 'epoch':epoch})
    wandb.log({'train_roc_score' : roc_score, 'epoch':epoch})
    wandb.log({'train_loss' : np.mean(loss_all) , 'epoch':epoch})

    return np.mean(loss_all)



def test_model(model, loss_func, dataset):

    model.eval()
    
    pred_all = []
    loss_all = []
    gt_all = []
    
    for batch_step in tqdm(range(len(dataset)) , desc="test"):
        batch_x, batch_y = dataset[batch_step]
        batch_x = batch_x.cuda()
        batch_x = batch_x.permute(0,2,1)
        batch_y = batch_y.cuda()
        
        pred = model(batch_x)
        pred_all.append(pred.cpu().detach().numpy())
       
        loss = loss_func(pred, batch_y)
        loss_all.append(loss.cpu().detach().numpy())
        gt_all.extend(batch_y.cpu().detach().numpy())

    print('test_loss ', np.mean(loss_all))
    pred_all = np.concatenate(pred_all, axis=0)

    _, mean_acc, roc_score = metrics(np.array(gt_all), pred_all )
    wandb.log({'test_mean_accuracy' : mean_acc, 'epoch':epoch})
    wandb.log({'test_roc_score' : roc_score, 'epoch':epoch})
    wandb.log({'test_loss' : np.mean(loss_all) , 'epoch':epoch})

    return np.mean(loss_all), mean_acc, roc_score

In [None]:
lr = 0.001
epochs = 60

model.cuda()
# wandb.watch(model)

optimizer = optim.Adam(model.parameters(), lr = lr)
loss_func = torch.nn.BCEWithLogitsLoss()


for epoch in range(epochs):
    train_step = train_model(model, optimizer, loss_func, train_gen, epoch)
    test_step = test_model(model, loss_func, test_gen)

    if epoch > 5 :
        torch.save(model.state_dict(), f'ECGNet_saves/{epoch}__{test_step[2]:.4f}.pt')


train: 100%|██████████| 1070/1070 [02:02<00:00,  8.72it/s]
test:   3%|▎         | 4/136 [00:00<00:03, 37.35it/s]

epoch 0 
train_loss  0.36884335230722604


test: 100%|██████████| 136/136 [00:03<00:00, 40.85it/s]
train:   0%|          | 1/1070 [00:00<01:50,  9.66it/s]

test_loss  0.359931


train: 100%|██████████| 1070/1070 [02:03<00:00,  8.65it/s]
test:   2%|▏         | 3/136 [00:00<00:04, 27.92it/s]

epoch 1 
train_loss  0.31217912286520005


test: 100%|██████████| 136/136 [00:03<00:00, 34.86it/s]
train:   0%|          | 1/1070 [00:00<01:53,  9.39it/s]

test_loss  0.33083642


train: 100%|██████████| 1070/1070 [02:04<00:00,  8.62it/s]
test:   2%|▏         | 3/136 [00:00<00:04, 29.25it/s]

epoch 2 
train_loss  0.29330473386934985


test: 100%|██████████| 136/136 [00:03<00:00, 40.61it/s]
train:   0%|          | 1/1070 [00:00<01:50,  9.70it/s]

test_loss  0.31379893


train: 100%|██████████| 1070/1070 [02:03<00:00,  8.68it/s]
test:   2%|▏         | 3/136 [00:00<00:04, 29.88it/s]

epoch 3 
train_loss  0.28176224891827484


test: 100%|██████████| 136/136 [00:03<00:00, 37.76it/s]
train:   0%|          | 1/1070 [00:00<01:49,  9.73it/s]

test_loss  0.31501555


train: 100%|██████████| 1070/1070 [01:53<00:00,  9.40it/s]
test:   4%|▎         | 5/136 [00:00<00:03, 41.95it/s]

epoch 4 
train_loss  0.27314382198954296


test: 100%|██████████| 136/136 [00:03<00:00, 41.87it/s]
train:   0%|          | 2/1070 [00:00<01:21, 13.15it/s]

test_loss  0.31343472


train: 100%|██████████| 1070/1070 [01:20<00:00, 13.37it/s]
test:   4%|▎         | 5/136 [00:00<00:03, 40.69it/s]

epoch 5 
train_loss  0.2647550921330107


test: 100%|██████████| 136/136 [00:03<00:00, 40.75it/s]
train:   0%|          | 2/1070 [00:00<01:19, 13.46it/s]

test_loss  0.3132583


train: 100%|██████████| 1070/1070 [01:23<00:00, 12.75it/s]
test:   4%|▎         | 5/136 [00:00<00:03, 41.81it/s]

epoch 6 
train_loss  0.2579409536566133


test: 100%|██████████| 136/136 [00:03<00:00, 41.98it/s]


test_loss  0.30885902


train: 100%|██████████| 1070/1070 [01:23<00:00, 12.89it/s]
test:   4%|▎         | 5/136 [00:00<00:03, 40.84it/s]

epoch 7 
train_loss  0.2504917144287969


test: 100%|██████████| 136/136 [00:03<00:00, 41.74it/s]


test_loss  0.31293997


train: 100%|██████████| 1070/1070 [01:35<00:00, 11.23it/s]
test:   3%|▎         | 4/136 [00:00<00:03, 38.99it/s]

epoch 8 
train_loss  0.2443331797615947


test: 100%|██████████| 136/136 [00:03<00:00, 41.85it/s]


test_loss  0.31405863


train: 100%|██████████| 1070/1070 [01:27<00:00, 12.19it/s]
test:   4%|▎         | 5/136 [00:00<00:03, 41.99it/s]

epoch 9 
train_loss  0.23834283063275236


test: 100%|██████████| 136/136 [00:03<00:00, 41.10it/s]


test_loss  0.31960666


train: 100%|██████████| 1070/1070 [01:33<00:00, 11.41it/s]
test:   4%|▎         | 5/136 [00:00<00:03, 41.98it/s]

epoch 10 
train_loss  0.23136993729135144


test: 100%|██████████| 136/136 [00:03<00:00, 41.68it/s]


test_loss  0.3308096


train: 100%|██████████| 1070/1070 [01:31<00:00, 11.69it/s]
test:   3%|▎         | 4/136 [00:00<00:03, 38.90it/s]

epoch 11 
train_loss  0.2244844285719027


test: 100%|██████████| 136/136 [00:03<00:00, 40.97it/s]


test_loss  0.33628672


train: 100%|██████████| 1070/1070 [01:32<00:00, 11.53it/s]
test:   4%|▎         | 5/136 [00:00<00:03, 41.62it/s]

epoch 12 
train_loss  0.21883703653013037


test: 100%|██████████| 136/136 [00:03<00:00, 40.76it/s]


test_loss  0.347952


train: 100%|██████████| 1070/1070 [01:26<00:00, 12.43it/s]
test:   4%|▎         | 5/136 [00:00<00:03, 41.22it/s]

epoch 13 
train_loss  0.21207367255587445


test: 100%|██████████| 136/136 [00:03<00:00, 41.76it/s]


test_loss  0.3545371


train: 100%|██████████| 1070/1070 [01:26<00:00, 12.33it/s]
test:   3%|▎         | 4/136 [00:00<00:03, 39.50it/s]

epoch 14 
train_loss  0.2033620270056145


test: 100%|██████████| 136/136 [00:03<00:00, 40.81it/s]


test_loss  0.368325


train: 100%|██████████| 1070/1070 [01:26<00:00, 12.31it/s]
test:   2%|▏         | 3/136 [00:00<00:04, 26.98it/s]

epoch 15 
train_loss  0.19422555699178548


test: 100%|██████████| 136/136 [00:03<00:00, 35.84it/s]


test_loss  0.3670307


train: 100%|██████████| 1070/1070 [01:27<00:00, 12.22it/s]
test:   3%|▎         | 4/136 [00:00<00:03, 39.26it/s]

epoch 16 
train_loss  0.18699470627008477


test: 100%|██████████| 136/136 [00:03<00:00, 35.98it/s]


test_loss  0.3813329


train: 100%|██████████| 1070/1070 [01:35<00:00, 11.21it/s]
test:   4%|▎         | 5/136 [00:00<00:03, 41.34it/s]

epoch 17 
train_loss  0.17815066495579537


test: 100%|██████████| 136/136 [00:03<00:00, 41.83it/s]


test_loss  0.39931965


train: 100%|██████████| 1070/1070 [01:29<00:00, 11.91it/s]
test:   2%|▏         | 3/136 [00:00<00:04, 29.04it/s]

epoch 18 
train_loss  0.17024776575582048


test: 100%|██████████| 136/136 [00:03<00:00, 39.59it/s]


test_loss  0.41593003


train: 100%|██████████| 1070/1070 [01:31<00:00, 11.71it/s]
test:   4%|▎         | 5/136 [00:00<00:03, 41.98it/s]

epoch 19 
train_loss  0.16174233258815013


test: 100%|██████████| 136/136 [00:03<00:00, 41.06it/s]


test_loss  0.4332475


train: 100%|██████████| 1070/1070 [01:25<00:00, 12.52it/s]
test:   3%|▎         | 4/136 [00:00<00:03, 39.13it/s]

epoch 20 
train_loss  0.15276617591974334


test: 100%|██████████| 136/136 [00:03<00:00, 41.11it/s]


test_loss  0.4364877


train: 100%|██████████| 1070/1070 [01:32<00:00, 11.60it/s]
test:   4%|▎         | 5/136 [00:00<00:03, 42.01it/s]

epoch 21 
train_loss  0.1465741006013389


test: 100%|██████████| 136/136 [00:03<00:00, 41.76it/s]


test_loss  0.45080858


train: 100%|██████████| 1070/1070 [01:19<00:00, 13.38it/s]
test:   4%|▎         | 5/136 [00:00<00:03, 41.51it/s]

epoch 22 
train_loss  0.13748424958402866


test: 100%|██████████| 136/136 [00:03<00:00, 41.66it/s]


test_loss  0.5201003


train: 100%|██████████| 1070/1070 [01:32<00:00, 11.54it/s]
test:   4%|▎         | 5/136 [00:00<00:03, 41.94it/s]

epoch 23 
train_loss  0.1327014625995098


test: 100%|██████████| 136/136 [00:03<00:00, 41.99it/s]


test_loss  0.5026201


train: 100%|██████████| 1070/1070 [01:24<00:00, 12.68it/s]
test:   4%|▎         | 5/136 [00:00<00:03, 41.49it/s]

epoch 24 
train_loss  0.1259953607517843


test: 100%|██████████| 136/136 [00:03<00:00, 41.29it/s]


test_loss  0.50263786


train: 100%|██████████| 1070/1070 [01:30<00:00, 11.84it/s]
test:   3%|▎         | 4/136 [00:00<00:03, 39.13it/s]

epoch 25 
train_loss  0.1195528215058496


test: 100%|██████████| 136/136 [00:03<00:00, 41.76it/s]


test_loss  0.54139996


train: 100%|██████████| 1070/1070 [01:25<00:00, 12.56it/s]
test:   4%|▎         | 5/136 [00:00<00:03, 41.62it/s]

epoch 26 
train_loss  0.11388346465113007


test: 100%|██████████| 136/136 [00:03<00:00, 41.90it/s]


test_loss  0.5564948


train: 100%|██████████| 1070/1070 [01:21<00:00, 13.21it/s]
test:   4%|▎         | 5/136 [00:00<00:03, 41.31it/s]

epoch 27 
train_loss  0.11041517015372482


test: 100%|██████████| 136/136 [00:03<00:00, 41.86it/s]


test_loss  0.5459528


train: 100%|██████████| 1070/1070 [01:30<00:00, 11.82it/s]
test:   3%|▎         | 4/136 [00:00<00:03, 38.57it/s]

epoch 28 
train_loss  0.1024366324140786


test: 100%|██████████| 136/136 [00:03<00:00, 41.62it/s]


test_loss  0.5629619


train: 100%|██████████| 1070/1070 [01:30<00:00, 11.88it/s]
test:   4%|▎         | 5/136 [00:00<00:03, 41.95it/s]

epoch 29 
train_loss  0.0983938837524016


test: 100%|██████████| 136/136 [00:03<00:00, 41.70it/s]


test_loss  0.60391575


train: 100%|██████████| 1070/1070 [01:29<00:00, 12.01it/s]
test:   4%|▎         | 5/136 [00:00<00:03, 41.80it/s]

epoch 30 
train_loss  0.09616881861421062


test: 100%|██████████| 136/136 [00:03<00:00, 39.28it/s]


test_loss  0.5599215


train: 100%|██████████| 1070/1070 [01:28<00:00, 12.14it/s]
test:   4%|▎         | 5/136 [00:00<00:03, 41.57it/s]

epoch 31 
train_loss  0.08973731704700355


test: 100%|██████████| 136/136 [00:03<00:00, 40.91it/s]


test_loss  0.6520616


train: 100%|██████████| 1070/1070 [01:29<00:00, 12.01it/s]
test:   4%|▎         | 5/136 [00:00<00:03, 41.79it/s]

epoch 32 
train_loss  0.08482296666147832


test: 100%|██████████| 136/136 [00:03<00:00, 41.94it/s]


test_loss  0.62545764


train: 100%|██████████| 1070/1070 [01:34<00:00, 11.27it/s]
test:   2%|▏         | 3/136 [00:00<00:04, 27.42it/s]

epoch 33 
train_loss  0.08203077191338988


test: 100%|██████████| 136/136 [00:03<00:00, 40.90it/s]


test_loss  0.650081


train: 100%|██████████| 1070/1070 [01:21<00:00, 13.06it/s]
test:   4%|▎         | 5/136 [00:00<00:03, 41.20it/s]

epoch 34 
train_loss  0.07539276468054017


test: 100%|██████████| 136/136 [00:03<00:00, 40.84it/s]


test_loss  0.6485173


train: 100%|██████████| 1070/1070 [01:28<00:00, 12.08it/s]
test:   4%|▎         | 5/136 [00:00<00:03, 42.00it/s]

epoch 35 
train_loss  0.07368728988934482


test: 100%|██████████| 136/136 [00:03<00:00, 41.97it/s]


test_loss  0.6730396


train: 100%|██████████| 1070/1070 [01:23<00:00, 12.83it/s]
test:   4%|▎         | 5/136 [00:00<00:03, 41.41it/s]

epoch 36 
train_loss  0.07017894171865953


test: 100%|██████████| 136/136 [00:03<00:00, 41.16it/s]


test_loss  0.7095008


train: 100%|██████████| 1070/1070 [01:26<00:00, 12.31it/s]
test:   4%|▎         | 5/136 [00:00<00:03, 41.26it/s]

epoch 37 
train_loss  0.06741760024043653


test: 100%|██████████| 136/136 [00:03<00:00, 41.29it/s]


test_loss  0.66834205


train: 100%|██████████| 1070/1070 [01:24<00:00, 12.71it/s]
test:   4%|▎         | 5/136 [00:00<00:03, 41.66it/s]

epoch 38 
train_loss  0.06698919117938135


test: 100%|██████████| 136/136 [00:03<00:00, 41.74it/s]


test_loss  0.702291


train: 100%|██████████| 1070/1070 [01:30<00:00, 11.86it/s]
test:   4%|▎         | 5/136 [00:00<00:03, 41.97it/s]

epoch 39 
train_loss  0.059254360105407154


test: 100%|██████████| 136/136 [00:03<00:00, 41.43it/s]


test_loss  0.6618781


train: 100%|██████████| 1070/1070 [01:28<00:00, 12.07it/s]
test:   4%|▎         | 5/136 [00:00<00:03, 41.54it/s]

epoch 40 
train_loss  0.0597133028340552


test: 100%|██████████| 136/136 [00:03<00:00, 40.76it/s]


test_loss  0.72156996


train: 100%|██████████| 1070/1070 [01:24<00:00, 12.65it/s]
test:   3%|▎         | 4/136 [00:00<00:03, 39.94it/s]

epoch 41 
train_loss  0.057202074084121944


test: 100%|██████████| 136/136 [00:03<00:00, 40.42it/s]


test_loss  0.7991146


train:  52%|█████▏    | 558/1070 [00:45<00:40, 12.77it/s]

In [7]:
PATH = r'C:\Users\likit\OneDrive\Desktop\Cardio-Viz\Code\BaseECG\ECGNet\11__0.9102.pt'

model.load_state_dict(torch.load(PATH))
model.eval()

ECGNet(
  (parallel_conv): ModuleList(
    (0): Conv1d(12, 16, kernel_size=(15,), stride=(1,), bias=False)
    (1): Conv1d(12, 16, kernel_size=(17,), stride=(1,), bias=False)
    (2): Conv1d(12, 16, kernel_size=(19,), stride=(1,), bias=False)
    (3): Conv1d(12, 16, kernel_size=(21,), stride=(1,), bias=False)
  )
  (bn1): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
  (conv1): Conv1d(16, 16, kernel_size=(17,), stride=(2,), padding=(2,), bias=False)
  (block): Sequential(
    (0): ResBlock(
      (bn1): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
      (dropout): Dropout(p=0.1, inplace=False)
      (conv1): Conv1d(16, 16, kernel_size=(17,), stride=(1,), padding=(8,), bias=False)
      (bn2): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv1d(16, 16, kernel_size=(17,), stride=(1,), padding=(8,), bias=False)
      (maxpool): MaxPoo

In [8]:
test_gen = DataGen(X_test_scale, y_test, batch_size=len(X_test_scale))

batch = test_gen[0][0].permute(0,2,1)
pred = model(batch).detach()

roc_auc_score(y_test, pred, average='macro')

0.9101563184785769

In [9]:
### Class wise AUC

roc_score = roc_auc_score(y_test, pred, average='macro')
print(f'roc_score : {roc_score}')

def AUC(y_true: np.ndarray, y_pred: np.ndarray, verbose=False) -> float:
    """Computes the macro-average AUC score.

    Args:
        y_true (np.ndarray): list of labels
        y_pred (np.ndarray): list of predicted probabilities

    Returns:
        float: macro-average AUC score.
    """
    aucs = []
    assert len(y_true.shape) == 2 and len(y_pred.shape) == 2, 'Predictions and labels must be 2D.'
    for col in range(y_true.shape[1]):
        try:
            aucs.append(roc_auc_score(y_true[:, col], y_pred[:, col]))
        except ValueError as e:
            if verbose:
                print(
                    f'Value error encountered for label {col}, likely due to using mixup or '
                    f'lack of full label presence. Setting AUC to accuracy. '
                    f'Original error was: {str(e)}.'
                )
            aucs.append((y_pred == y_true).sum() / len(y_pred))
    return np.array(aucs)

class_auc = AUC(y_test, pred)
print(f'class wise AUC : {class_auc}')



roc_score : 0.9101563184785769
class wise AUC : [0.91022348 0.88426456 0.9021273  0.93360477 0.92056149]


In [10]:
### Accuracy metric

def metrics(y_true, y_scores):
    y_pred = y_scores >= 0.5
    acc = np.zeros(y_pred.shape[-1])
    
    for i in range(y_pred.shape[-1]):
        acc[i] = accuracy_score(y_true[:,i], y_pred[:,i])
    return acc, np.mean(acc)

acc, mean_acc = metrics(y_test, pred)
print(f'class wise accuracy: {acc}')
print(f'accuracy: {mean_acc}')

class wise accuracy: [0.88719371 0.90661119 0.84604716 0.85899214 0.86870088]
accuracy: 0.8735090152565881


In [11]:
pred_values = pred >= 0.5

report = classification_report(y_test, pred_values, target_names = mlb.classes_)
print(report)

              precision    recall  f1-score   support

          CD       0.86      0.61      0.71       498
         HYP       0.77      0.33      0.47       263
          MI       0.80      0.53      0.64       553
        NORM       0.79      0.92      0.85       964
        STTC       0.83      0.58      0.68       523

   micro avg       0.81      0.67      0.73      2801
   macro avg       0.81      0.59      0.67      2801
weighted avg       0.81      0.67      0.72      2801
 samples avg       0.75      0.70      0.71      2801

  _warn_prf(average, modifier, msg_start, len(result))


In [12]:
def multi_threshold_precision_recall(y_true: np.ndarray, y_pred: np.ndarray, thresholds: np.ndarray) :
    
    # Expand analysis to number of thresholds
    y_pred_bin = np.repeat(y_pred[None, :, :], len(thresholds), axis=0) >= thresholds[:, None, None]
    y_true_bin = np.repeat(y_true[None, :, :], len(thresholds), axis=0)

    # Compute true positives
    TP = np.sum(np.logical_and(y_true, y_pred_bin), axis=2)

    # Compute macro-average precision handling all warnings
    with np.errstate(divide='ignore', invalid='ignore'):
        den = np.sum(y_pred_bin, axis=2)
        precision = TP / den
        precision[den == 0] = np.nan
        with warnings.catch_warnings():  # for nan slices
            warnings.simplefilter("ignore", category=RuntimeWarning)
            av_precision = np.nanmean(precision, axis=1)

    # Compute macro-average recall
    recall = TP / np.sum(y_true_bin, axis=2)
    av_recall = np.mean(recall, axis=1)

    return av_precision, av_recall


def metric_summary(y_true: np.ndarray, y_pred: np.ndarray, num_thresholds: int = 10) :
    
    thresholds = np.arange(0.00, 1.01, 1. / (num_thresholds - 1), float)
    average_precisions, average_recalls = multi_threshold_precision_recall(
        y_true, y_pred, thresholds
    )
    f_scores = 2 * (average_precisions * average_recalls) / (average_precisions + average_recalls)
    auc = np.array(AUC(y_true, y_pred, verbose=True)).mean()
    return (
        f_scores[np.nanargmax(f_scores)],
        auc,
        f_scores,
        average_precisions,
        average_recalls,
        thresholds
    )

metric_summary(y_test, pred.numpy())


(0.7712464813278368,
 0.9101563184785769,
 array([0.77124648, 0.76814814, 0.76475761, 0.76157367, 0.75672646,
        0.75207953, 0.74914798, 0.74659073, 0.7404506 , 0.73561017]),
 array([0.7926559 , 0.79706679, 0.8023798 , 0.80757946, 0.81168563,
        0.81620603, 0.82248596, 0.82982319, 0.83320099, 0.8367813 ]),
 array([0.75096317, 0.74125443, 0.73050547, 0.72052705, 0.70873786,
        0.69729542, 0.68781785, 0.6785329 , 0.6662814 , 0.65626445]),
 array([0.        , 0.11111111, 0.22222222, 0.33333333, 0.44444444,
        0.55555556, 0.66666667, 0.77777778, 0.88888889, 1.        ]))