In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [1]:
'''
Training script for ecg classification
'''
from __future__ import print_function

import os
import cv2
import json
import time
import torch
import random
import shutil
import argparse
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import sklearn.metrics as skm
import torch.utils.data as data
import torch.backends.cudnn as cudnn
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import roc_curve, auc
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
# import models as models
from utils import Logger, AverageMeter, mkdir_p, savefig



In [2]:
torch.cuda.is_available()

True

In [3]:

parser = argparse.ArgumentParser(description='PyTorch ECG LSTM MITBIH Training')
# Datasets
parser.add_argument('-dt', '--dataset', default='ecg', type=str)
parser.add_argument('-ft', '--transformation', default='stft', type=str)
parser.add_argument('-d', '--data', default='./generate_features_3/mitbih_rl', type=str)
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
# Optimization options
parser.add_argument('--epochs', default=25, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--train-batch', default=8, type=int, metavar='N',
                    help='train batchsize')
parser.add_argument('--test-batch', default=8, type=int, metavar='N',
                    help='test batchsize')
parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--drop', '--dropout', default=0, type=float,
                    metavar='Dropout', help='Dropout ratio')
parser.add_argument('--schedule', type=int, nargs='+', default=[150, 225],
                        help='Decrease learning rate at these epochs.')
parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
# Checkpoints
parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH',
                    help='path to save checkpoint (default: checkpoint)')

parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')

# Architecture
parser.add_argument('--depth', type=int, default=110, help='Model depth.')
parser.add_argument('--block-name', type=str, default='BasicBlock',
                    help='the building block for Resnet and Preresnet: BasicBlock, Bottleneck (default: '
                         'Basicblock for ecg)')
parser.add_argument('--cardinality', type=int, default=8, help='Model cardinality (group).')
parser.add_argument('--widen-factor', type=int, default=4, help='Widen factor. 4 -> 64, 8 -> 128, ...')
parser.add_argument('--growthRate', type=int, default=12, help='Growth rate for DenseNet.')
parser.add_argument('--compressionRate', type=int, default=2, help='Compression Rate (theta) for DenseNet.')
# Miscs
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', default=False,
                    help='evaluate model on validation set')
# Device options
parser.add_argument('--gpu-id', default='0', type=str,
                    help='id(s) for CUDA_VISIBLE_DEVICES')

args = parser.parse_args("")
state = {k: v for k, v in args._get_kwargs()}




# Use CUDA
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
use_cuda = torch.cuda.is_available()

# Random seed
if args.manualSeed is None:
    args.manualSeed = random.randint(1, 10000)
random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)
if use_cuda:
    torch.cuda.manual_seed_all(args.manualSeed)

best_acc = 0  # best test accuracy


In [4]:

class Ecg_loader(torch.utils.data.Dataset):
    def __init__(self, path, transform):
        super(Ecg_loader, self).__init__()
        self.male_vec = np.load('male_vec.npy')
        self.female_vec = np.load('female_vec.npy')

        with open(os.path.join(path, 'ecg_labels.json')) as j_file:
            json_data = json.load(j_file)
        self.idx2name = json_data['labels']
        data = json_data['data']
        self.inputs = []
        self.labels = []
        self.gender = []
        self.inputs_full = []
        self.whole_ecg = []
        self.ecg = []
        self.age = []
        for i in tqdm(data):
            subject_img = []
            subject_ecg = []
            a = np.zeros((100))
            for i_name, w_name in zip(i['images'], i['ecg']):
                img = cv2.imread(os.path.join(path, 'images', transform, i_name))
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                img = cv2.resize(img, (90, 90))
                ecg = np.load(os.path.join(path, 'ecg', w_name))
                subject_img.append(np.expand_dims(img.transpose((2, 0, 1)), axis=0))
                subject_ecg.append(np.expand_dims(np.expand_dims(ecg, axis=0), axis=0))
            img_full = cv2.imread(os.path.join(path, 'images_full', transform, i['images_full']))
            img_full = cv2.cvtColor(img_full, cv2.COLOR_BGR2RGB)
            l = i['label']
            a[int(i['age']*100)] = 1
            if i['gender'] == [0, 1]:
                g = self.male_vec
            elif i['gender'] == [1, 0]:
                g = self.female_vec
            self.inputs_full.append(img_full.transpose((2, 0, 1)))
            self.inputs.append(np.concatenate(subject_img, axis=0))
            self.ecg.append(np.concatenate(subject_ecg, axis=0))
            self.whole_ecg.append(np.concatenate(subject_ecg, axis=2))
            counts = np.bincount(l)
            ind = np.argmax(counts)
            encoded_array = np.zeros(6, dtype=np.float64)
            encoded_array[ind] = 1.0 
            # print(encoded_array)
            encoded_array = encoded_array.astype('float64')
            self.labels.append(encoded_array)
            self.gender.append(g)
            self.age.append(a)
        print(len(self.whole_ecg))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        x = torch.from_numpy(self.inputs[idx]).float()
        # y = torch.from_numpy(np.array(self.labels[idx])).long()
        y = torch.from_numpy(np.array(self.labels[idx]))
        a = torch.from_numpy(np.array(self.age[idx])).float()
        g = torch.from_numpy(np.array(self.gender[idx])).float()
        w = torch.from_numpy(self.ecg[idx]).float()
        return (x, a, g, w), y

In [5]:

start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

if not os.path.isdir(args.checkpoint):
    mkdir_p(args.checkpoint)

# Data
print('==> Preparing dataset %s' % args.dataset)

dataloader = Ecg_loader
train_path = args.data

traindir = os.path.join(train_path, 'train')
valdir = os.path.join(train_path, 'val')
if not args.evaluate:
    trainset = dataloader(traindir, transform=args.transformation)
testset = dataloader(valdir, transform=args.transformation)

idx2name = testset.idx2name
label_names = []
for i in range(0, len(idx2name.keys())):
    label_names.append(idx2name[str(i)])
num_classes = len(label_names)

if not args.evaluate:
    trainloader = data.DataLoader(trainset, batch_size=args.train_batch, shuffle=True, num_workers=args.workers)
testloader = data.DataLoader(testset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers)


==> Preparing dataset ecg


100%|██████████| 1782/1782 [00:26<00:00, 66.72it/s]


1782


100%|██████████| 198/198 [00:05<00:00, 39.46it/s]

198





In [6]:
from collections import defaultdict
import csv
import json
import numpy as np
import os
import sys

from sklearn.metrics import roc_curve, auc
from tqdm import tqdm



def all_metrics(yhat, y, k=8, yhat_raw=None, calc_auc=True):
    """
        Inputs:
            yhat: binary predictions matrix 
            y: binary ground truth matrix
            k: for @k metrics
            yhat_raw: prediction scores matrix (floats)
        Outputs:
            dict holding relevant metrics
    """
    names = ["acc", "prec", "rec", "f1"]

    #macro
    macro = all_macro(yhat, y)

    #micro
    ymic = y.ravel()
    yhatmic = yhat.ravel()
    micro = all_micro(yhatmic, ymic)

    metrics = {names[i] + "_macro": macro[i] for i in range(len(macro))}
    metrics.update({names[i] + "_micro": micro[i] for i in range(len(micro))})

    #AUC and @k
    if yhat_raw is not None and calc_auc:
        #allow k to be passed as int or list
        if type(k) != list:
            k = [k]
        for k_i in k:
            rec_at_k = recall_at_k(yhat_raw, y, k_i)
            metrics['rec_at_%d' % k_i] = rec_at_k
            prec_at_k = precision_at_k(yhat_raw, y, k_i)
            metrics['prec_at_%d' % k_i] = prec_at_k
            metrics['f1_at_%d' % k_i] = 2*(prec_at_k*rec_at_k)/(prec_at_k+rec_at_k)

        roc_auc = auc_metrics(yhat_raw, y, ymic)
        metrics.update(roc_auc)

    return metrics

def all_macro(yhat, y):
    return macro_accuracy(yhat, y), macro_precision(yhat, y), macro_recall(yhat, y), macro_f1(yhat, y)

def all_micro(yhatmic, ymic):
    return micro_accuracy(yhatmic, ymic), micro_precision(yhatmic, ymic), micro_recall(yhatmic, ymic), micro_f1(yhatmic, ymic)

#########################################################################
#MACRO METRICS: calculate metric for each label and average across labels
#########################################################################

def macro_accuracy(yhat, y):
    num = intersect_size(yhat, y, 0) / (union_size(yhat, y, 0) + 1e-10)
    return np.mean(num)

def macro_precision(yhat, y):
    num = intersect_size(yhat, y, 0) / (yhat.sum(axis=0) + 1e-10)
    return np.mean(num)

def macro_recall(yhat, y):
    num = intersect_size(yhat, y, 0) / (y.sum(axis=0) + 1e-10)
    return np.mean(num)

def macro_f1(yhat, y):
    prec = macro_precision(yhat, y)
    rec = macro_recall(yhat, y)
    if prec + rec == 0:
        f1 = 0.
    else:
        f1 = 2*(prec*rec)/(prec+rec)
    return f1

###################
# INSTANCE-AVERAGED
###################

def inst_precision(yhat, y):
    num = intersect_size(yhat, y, 1) / yhat.sum(axis=1)
    #correct for divide-by-zeros
    num[np.isnan(num)] = 0.
    return np.mean(num)

def inst_recall(yhat, y):
    num = intersect_size(yhat, y, 1) / y.sum(axis=1)
    #correct for divide-by-zeros
    num[np.isnan(num)] = 0.
    return np.mean(num)

def inst_f1(yhat, y):
    prec = inst_precision(yhat, y)
    rec = inst_recall(yhat, y)
    f1 = 2*(prec*rec)/(prec+rec)
    return f1

##############
# AT-K
##############

def recall_at_k(yhat_raw, y, k):
    #num true labels in top k predictions / num true labels
    sortd = np.argsort(yhat_raw)[:,::-1]
    topk = sortd[:,:k]

    #get recall at k for each example
    vals = []
    for i, tk in enumerate(topk):
        num_true_in_top_k = y[i,tk].sum()
        denom = y[i,:].sum()
        vals.append(num_true_in_top_k / float(denom))

    vals = np.array(vals)
    vals[np.isnan(vals)] = 0.

    return np.mean(vals)

def precision_at_k(yhat_raw, y, k):
    #num true labels in top k predictions / k
    sortd = np.argsort(yhat_raw)[:,::-1]
    topk = sortd[:,:k]

    #get precision at k for each example
    vals = []
    for i, tk in enumerate(topk):
        if len(tk) > 0:
            num_true_in_top_k = y[i,tk].sum()
            denom = len(tk)
            vals.append(num_true_in_top_k / float(denom))

    return np.mean(vals)

##########################################################################
#MICRO METRICS: treat every prediction as an individual binary prediction
##########################################################################

def micro_accuracy(yhatmic, ymic):
    return intersect_size(yhatmic, ymic, 0) / union_size(yhatmic, ymic, 0)

def micro_precision(yhatmic, ymic):
    return intersect_size(yhatmic, ymic, 0) / yhatmic.sum(axis=0)

def micro_recall(yhatmic, ymic):
    return intersect_size(yhatmic, ymic, 0) / ymic.sum(axis=0)

def micro_f1(yhatmic, ymic):
    prec = micro_precision(yhatmic, ymic)
    rec = micro_recall(yhatmic, ymic)
    if prec + rec == 0:
        f1 = 0.
    else:
        f1 = 2*(prec*rec)/(prec+rec)
    return f1

def auc_metrics(yhat_raw, y, ymic):
    if yhat_raw.shape[0] <= 1:
        return
    fpr = {}
    tpr = {}
    roc_auc = {}
    #get AUC for each label individually
    relevant_labels = []
    auc_labels = {}
    for i in range(y.shape[1]):
        #only if there are true positives for this label
        if y[:,i].sum() > 0:
            fpr[i], tpr[i], _ = roc_curve(y[:,i], yhat_raw[:,i])
            if len(fpr[i]) > 1 and len(tpr[i]) > 1:
                auc_score = auc(fpr[i], tpr[i])
                if not np.isnan(auc_score): 
                    auc_labels["auc_%d" % i] = auc_score
                    relevant_labels.append(i)

    #macro-AUC: just average the auc scores
    aucs = []
    for i in relevant_labels:
        aucs.append(auc_labels['auc_%d' % i])
    roc_auc['auc_macro'] = np.mean(aucs)

    #micro-AUC: just look at each individual prediction
    yhatmic = yhat_raw.ravel()
    fpr["micro"], tpr["micro"], _ = roc_curve(ymic, yhatmic) 
    roc_auc["auc_micro"] = auc(fpr["micro"], tpr["micro"])

    return roc_auc

########################
# METRICS BY CODE TYPE
########################

def results_by_type(Y, mdir, version='mimic3'):
    d2ind = {}
    p2ind = {}

    #get predictions for diagnoses and procedures
    diag_preds = defaultdict(lambda: set([]))
    proc_preds = defaultdict(lambda: set([]))
    preds = defaultdict(lambda: set())
    with open('%s/preds_test.psv' % mdir, 'r') as f:
        r = csv.reader(f, delimiter='|')
        for row in r:
            if len(row) > 1:
                for code in row[1:]:
                    preds[row[0]].add(code)
                    if code != '':
                        try:
                            pos = code.index('.')
                            if pos == 3 or (code[0] == 'E' and pos == 4):
                                if code not in d2ind:
                                    d2ind[code] = len(d2ind)
                                diag_preds[row[0]].add(code)
                            elif pos == 2:
                                if code not in p2ind:
                                    p2ind[code] = len(p2ind)
                                proc_preds[row[0]].add(code)
                        except:
                            if len(code) == 3 or (code[0] == 'E' and len(code) == 4):
                                if code not in d2ind:
                                    d2ind[code] = len(d2ind)
                                diag_preds[row[0]].add(code)
    #get ground truth for diagnoses and procedures
    diag_golds = defaultdict(lambda: set([]))
    proc_golds = defaultdict(lambda: set([]))
    golds = defaultdict(lambda: set())
    test_file = '%s/test_%s.csv' % (MIMIC_3_DIR, str(Y)) if version == 'mimic3' else '%s/test.csv' % MIMIC_2_DIR
    with open(test_file, 'r') as f:
        r = csv.reader(f)
        #header
        next(r)
        for row in r:
            codes = set([c for c in row[3].split(';')])
            for code in codes:
                golds[row[1]].add(code)
                try:
                    pos = code.index('.')
                    if pos == 3:
                        if code not in d2ind:
                            d2ind[code] = len(d2ind)
                        diag_golds[row[1]].add(code)
                    elif pos == 2:
                        if code not in p2ind:
                            p2ind[code] = len(p2ind)
                        proc_golds[row[1]].add(code)
                except:
                    if len(code) == 3 or (code[0] == 'E' and len(code) == 4):
                        if code not in d2ind:
                            d2ind[code] = len(d2ind)
                        diag_golds[row[1]].add(code)

    hadm_ids = sorted(set(diag_golds.keys()).intersection(set(diag_preds.keys())))

    ind2d = {i:d for d,i in d2ind.items()}
    ind2p = {i:p for p,i in p2ind.items()}
    type_dicts = (ind2d, ind2p)
    return diag_preds, diag_golds, proc_preds, proc_golds, golds, preds, hadm_ids, type_dicts


def diag_f1(diag_preds, diag_golds, ind2d, hadm_ids):
    num_labels = len(ind2d)
    yhat_diag = np.zeros((len(hadm_ids), num_labels))
    y_diag = np.zeros((len(hadm_ids), num_labels))
    for i,hadm_id in tqdm(enumerate(hadm_ids)):
        yhat_diag_inds = [1 if ind2d[j] in diag_preds[hadm_id] else 0 for j in range(num_labels)]
        gold_diag_inds = [1 if ind2d[j] in diag_golds[hadm_id] else 0 for j in range(num_labels)]
        yhat_diag[i] = yhat_diag_inds
        y_diag[i] = gold_diag_inds
    return micro_f1(yhat_diag.ravel(), y_diag.ravel())

def proc_f1(proc_preds, proc_golds, ind2p, hadm_ids):
    num_labels = len(ind2p)
    yhat_proc = np.zeros((len(hadm_ids), num_labels))
    y_proc = np.zeros((len(hadm_ids), num_labels))
    for i,hadm_id in tqdm(enumerate(hadm_ids)):
        yhat_proc_inds = [1 if ind2p[j] in proc_preds[hadm_id] else 0 for j in range(num_labels)]
        gold_proc_inds = [1 if ind2p[j] in proc_golds[hadm_id] else 0 for j in range(num_labels)]
        yhat_proc[i] = yhat_proc_inds
        y_proc[i] = gold_proc_inds
    return micro_f1(yhat_proc.ravel(), y_proc.ravel())

def metrics_from_dicts(preds, golds, mdir, ind2c):
    with open('%s/pred_100_scores_test.json' % mdir, 'r') as f:
        scors = json.load(f)

    hadm_ids = sorted(set(golds.keys()).intersection(set(preds.keys())))
    num_labels = len(ind2c)
    yhat = np.zeros((len(hadm_ids), num_labels))
    yhat_raw = np.zeros((len(hadm_ids), num_labels))
    y = np.zeros((len(hadm_ids), num_labels))
    for i,hadm_id in tqdm(enumerate(hadm_ids)):
        yhat_inds = [1 if ind2c[j] in preds[hadm_id] else 0 for j in range(num_labels)]
        yhat_raw_inds = [scors[hadm_id][ind2c[j]] if ind2c[j] in scors[hadm_id] else 0 for j in range(num_labels)]
        gold_inds = [1 if ind2c[j] in golds[hadm_id] else 0 for j in range(num_labels)]
        yhat[i] = yhat_inds
        yhat_raw[i] = yhat_raw_inds
        y[i] = gold_inds
    return yhat, yhat_raw, y, all_metrics(yhat, y, yhat_raw=yhat_raw, calc_auc=False)


def union_size(yhat, y, axis):
    #axis=0 for label-level union (macro). axis=1 for instance-level
    return np.logical_or(yhat, y).sum(axis=axis).astype(float)

def intersect_size(yhat, y, axis):
    #axis=0 for label-level union (macro). axis=1 for instance-level
    return np.logical_and(yhat, y).sum(axis=axis).astype(float)

def print_metrics(metrics):
    print()
    if "auc_macro" in metrics.keys():
        print("[MACRO] accuracy, precision, recall, f-measure, AUC")
        print("%.4f, %.4f, %.4f, %.4f, %.4f" % (metrics["acc_macro"], metrics["prec_macro"], metrics["rec_macro"], metrics["f1_macro"], metrics["auc_macro"]))
    else:
        print("[MACRO] accuracy, precision, recall, f-measure")
        print("%.4f, %.4f, %.4f, %.4f" % (metrics["acc_macro"], metrics["prec_macro"], metrics["rec_macro"], metrics["f1_macro"]))

    if "auc_micro" in metrics.keys():
        print("[MICRO] accuracy, precision, recall, f-measure, AUC")
        print("%.4f, %.4f, %.4f, %.4f, %.4f" % (metrics["acc_micro"], metrics["prec_micro"], metrics["rec_micro"], metrics["f1_micro"], metrics["auc_micro"]))
    else:
        print("[MICRO] accuracy, precision, recall, f-measure")
        print("%.4f, %.4f, %.4f, %.4f" % (metrics["acc_micro"], metrics["prec_micro"], metrics["rec_micro"], metrics["f1_micro"]))
    for metric, val in metrics.items():
        if metric.find("rec_at") != -1:
            print("%s: %.4f" % (metric, val))
    print()

In [27]:


# def evaluate(outputs, labels, label_names=None):
#     gt = torch.cat(labels, dim=0)
#     pred = torch.cat(outputs, dim=0)
#     probs = pred
#     pred = torch.argmax(pred, dim=1)
#     acc = torch.div(100*torch.sum((gt == pred).float()), gt.shape[0])
#     name_dict = {0: 'Normal beat (N)', 1: 'Left bundle branch block beat (L)', 2: 'Right bundle branch block beat (R)', 3:
#         'Premature ventricular contraction (V)', 4: 'Atrial premature beat (A)', 5: 'Non classified (~)'}

#     print('accuracy :', acc)

#     gt = gt.cpu().tolist()
#     pred = pred.cpu().tolist()

#     report = skm.classification_report(
#         gt, pred,
#         target_names=[name_dict[i] for i in np.unique(gt)],
#         digits=3)
#     scores = skm.precision_recall_fscore_support(
#         gt,
#         pred,
#         average=None)
#     print(report)
#     print("F1 Average {:3f}".format(np.mean(scores[2][:3])))

#     fpr = dict()
#     tpr = dict()
#     roc_auc = dict()
#     n_classes = np.unique(gt).shape[0]
#     oh_gt = np.zeros((len(gt), n_classes))
#     plt.figure()
#     colors = ['b', 'g', 'r', 'c', 'm', 'y']

#     for i in range(n_classes):
#         oh_gt[:, gt == i] = 1
#         fpr[i], tpr[i], _ = roc_curve(gt, probs[:, i].cpu(), pos_label=i)
#         roc_auc[i] = auc(fpr[i], tpr[i])
#         lw = 2
#         plt.plot(fpr[i], tpr[i], color=colors[i],
#                  lw=lw, label=name_dict[i] +' : %0.4f' % roc_auc[i])
#     plt.xlim([0.0, 1.0])
#     plt.ylim([0.0, 1.05])
#     plt.xlabel('False Positive Rate')
#     plt.ylabel('True Positive Rate')
#     plt.title('Class-Wise AUC and ROC curve')
#     plt.legend(loc="lower right")
#     plt.savefig(os.path.join(args.checkpoint, 'roc.png'))
#     return 0



def train(trainloader, model, criterion, optimizer, epoch, use_cuda):
    # switch to train mode
    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()


    for batch_idx, (inputs, targets) in tqdm(enumerate(trainloader)):
        # measure data loading time
        data_time.update(time.time() - end)

        if use_cuda:
            inputs, targets = (inputs[0].cuda(), inputs[1].cuda(), inputs[2].cuda(),
                               inputs[3].cuda()), targets.cuda()
        inputs, targets = (torch.autograd.Variable(inputs[0]), torch.autograd.Variable(inputs[1]),
                           torch.autograd.Variable(inputs[2]),
                           torch.autograd.Variable(inputs[3])), torch.autograd.Variable(targets)
        optimizer.zero_grad()
        outputs = model(inputs)
        # print(outputs)
        # print(inputs)
        # print("*******************")
        # print(outputs.size() , targets.size())
        loss = criterion(outputs, targets)

        # print(outputs.size(), targets.size())
        # print(outputs, targets)
        # prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 6))
        output_np = outputs.data.cpu().numpy()
        output_np = np.nan_to_num(output_np)
        yhat =  (output_np > 0.5)
        met = all_metrics(yhat, targets.data.cpu().numpy(), k=3, yhat_raw=output_np, calc_auc=True)
        if float(torch.__version__[:3]) < 0.5:
            losses.update(loss.data[0], inputs[0].size(0))
            # top1.update(prec1[0], inputs[0].size(0))
            # top5.update(prec5[0], inputs[0].size(0))
        else:
            losses.update(loss.data, inputs[0].size(0))
            # top1.update(prec1, inputs[0].size(0))
            # top5.update(prec5, inputs[0].size(0))

        # compute gradient and do Adam step
        
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
    # evaluate(pred, gt)
    return losses.avg,  met


def test(testloader, model, criterion, epoch, use_cuda, label_names=None):
    global best_acc

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()

    all_preds = []
    all_preds_raw = []
    all_labels = []
    for batch_idx, (inputs, targets) in tqdm(enumerate(testloader)):
        # measure data loading time
        data_time.update(time.time() - end)

        if use_cuda:
            inputs, targets = (inputs[0].cuda(), inputs[1].cuda(), inputs[2].cuda(),
                               inputs[3].cuda()), targets.cuda()
        inputs, targets = (torch.autograd.Variable(inputs[0]), torch.autograd.Variable(inputs[1]),
                           torch.autograd.Variable(inputs[2]),
                           torch.autograd.Variable(inputs[3])), torch.autograd.Variable(targets)

        # compute output
        st = time.time()
        # print(inputs.shape())
        with torch.no_grad():
          outputs = model(inputs)
        # print(time.time()-st)
        loss = criterion(outputs, targets)
        targets = targets.data.cpu().numpy()
        outputs = outputs.data.cpu().numpy()
        output_np = np.nan_to_num(outputs)
        yhat =  (output_np > 0.5)
        
        # measure accuracy and record loss
        
        # gt.append(targets.tolist())
        # pred.append(output_np.tolist())
        all_preds_raw.extend(list(output_np))
        all_preds.extend(list(yhat))
        all_labels.extend(list(targets))
        # print(len(all_preds_raw))
        # print(pred.shape)
        # print(y_hat.shape)

        # prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 6))
        if float(torch.__version__[:3]) < 0.5:
            losses.update(loss.data[0], inputs[0].size(0))
            # top1.update(prec1[0], inputs[0].size(0))
            # top5.update(prec5[0], inputs[0].size(0))
        else:
            losses.update(loss.data, inputs[0].size(0))
            # top1.update(prec1, inputs[0].size(0))
            # top5.update(prec5, inputs[0].size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
    # print(pred)
    # gt_array = np.array(gt)
    # pred_array = np.array(pred)
    # print(pred_array.shape)
    
    # yhat =  (pred_array > 0.5)
    all_preds_raw = np.stack(all_preds_raw)
    all_preds = np.stack(all_preds)
    all_labels = np.stack(all_labels)
    met = all_metrics(yhat=all_preds, y=all_labels, k = 3,  yhat_raw=all_preds_raw)
    # met = all_metrics(y_hat, gt, k=3, yhat_raw=pred, calc_auc=True)
    # evaluate(pred, gt, label_names=label_names)
    return losses.avg, met


def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar'))


def adjust_learning_rate(optimizer, epoch):
    global state
    if epoch in args.schedule:
        state['lr'] *= args.gamma
        for param_group in optimizer.param_groups:
            param_group['lr'] = state['lr']



In [None]:
import warnings
warnings.filterwarnings('ignore')
# Model
print("==> creating model ResNet{}".format(args.depth))
import model
num_classes = 6

model = model.__dict__['resnet_lstm_mitbih'](
            num_classes=num_classes,
            depth=args.depth,
            block_name=args.block_name,
        )

model = model.cuda()
cudnn.benchmark = True
print('    Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))
criterion = nn.CrossEntropyLoss()
# optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

# Resume
title = 'ecg-lstm-resnet' + str(args.depth)
if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isfile(args.resume), 'Error: no checkpoint directory found!'
    args.checkpoint = os.path.dirname(args.resume)
    checkpoint = torch.load(args.resume)
    best_acc = checkpoint['best_acc']
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True)
else:
    logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
    logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'])

if args.evaluate:
    print('\nEvaluation only')
    test_loss, test_acc= test(testloader, model, criterion, start_epoch, use_cuda, label_names=label_names)
    print(' Test Loss:  %.8f, Test Acc:  %.2f' % (test_loss, test_acc))


# Train and val
for epoch in range(start_epoch, args.epochs):
    adjust_learning_rate(optimizer, epoch)
    print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, state['lr']))
    train_loss, train_acc = train(trainloader, model, criterion, optimizer, epoch, use_cuda)
    test_loss, test_acc = test(testloader, model, criterion, epoch, use_cuda, label_names=label_names)

    print("Train accuracy ", train_acc)
    print("test accuracy ", test_acc)

    # append logger file
    logger.append([state['lr'], train_loss, test_loss, train_acc['acc_micro'], test_acc['acc_micro']])

    # save model
    is_best = test_acc['acc_micro'] > best_acc
    best_acc = max(test_acc['acc_micro'], best_acc)
    save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'acc': test_acc['acc_micro'],
            'best_acc': best_acc,
            'optimizer': optimizer.state_dict(),
        }, is_best, checkpoint=args.checkpoint)

logger.close()
logger.plot()
savefig(os.path.join(args.checkpoint, 'log.eps'))

print('Best acc:')
print(best_acc)


==> creating model ResNet110
    Total params: 4.10M

Epoch: [1 | 25] LR: 0.001000


223it [02:34,  1.45it/s]
25it [00:04,  5.66it/s]


Train accuracy  {'acc_macro': 0.0, 'prec_macro': 0.0, 'rec_macro': 0.0, 'f1_macro': 0.0, 'acc_micro': 0.0, 'prec_micro': nan, 'rec_micro': 0.0, 'f1_micro': nan, 'rec_at_3': 0.16666666666666666, 'prec_at_3': 0.05555555555555555, 'f1_at_3': 0.08333333333333333, 'auc_macro': 0.5, 'auc_micro': 0.5}
test accuracy  {'acc_macro': 0.0, 'prec_macro': 0.0, 'rec_macro': 0.0, 'f1_macro': 0.0, 'acc_micro': 0.0, 'prec_micro': nan, 'rec_micro': 0.0, 'f1_micro': nan, 'rec_at_3': 0.2727272727272727, 'prec_at_3': 0.09090909090909091, 'f1_at_3': 0.13636363636363635, 'auc_macro': 0.5, 'auc_micro': 0.5}

Epoch: [2 | 25] LR: 0.001000


223it [02:36,  1.42it/s]
25it [00:04,  5.58it/s]


Train accuracy  {'acc_macro': 0.0, 'prec_macro': 0.0, 'rec_macro': 0.0, 'f1_macro': 0.0, 'acc_micro': 0.0, 'prec_micro': nan, 'rec_micro': 0.0, 'f1_micro': nan, 'rec_at_3': 0.16666666666666666, 'prec_at_3': 0.05555555555555555, 'f1_at_3': 0.08333333333333333, 'auc_macro': 0.5, 'auc_micro': 0.5}
test accuracy  {'acc_macro': 0.0, 'prec_macro': 0.0, 'rec_macro': 0.0, 'f1_macro': 0.0, 'acc_micro': 0.0, 'prec_micro': nan, 'rec_micro': 0.0, 'f1_micro': nan, 'rec_at_3': 0.2727272727272727, 'prec_at_3': 0.09090909090909091, 'f1_at_3': 0.13636363636363635, 'auc_macro': 0.5, 'auc_micro': 0.5}

Epoch: [3 | 25] LR: 0.001000


223it [02:37,  1.42it/s]
25it [00:04,  5.42it/s]


Train accuracy  {'acc_macro': 0.0, 'prec_macro': 0.0, 'rec_macro': 0.0, 'f1_macro': 0.0, 'acc_micro': 0.0, 'prec_micro': nan, 'rec_micro': 0.0, 'f1_micro': nan, 'rec_at_3': 0.5, 'prec_at_3': 0.16666666666666666, 'f1_at_3': 0.25, 'auc_macro': 0.5, 'auc_micro': 0.5}
test accuracy  {'acc_macro': 0.0, 'prec_macro': 0.0, 'rec_macro': 0.0, 'f1_macro': 0.0, 'acc_micro': 0.0, 'prec_micro': nan, 'rec_micro': 0.0, 'f1_micro': nan, 'rec_at_3': 0.2727272727272727, 'prec_at_3': 0.09090909090909091, 'f1_at_3': 0.13636363636363635, 'auc_macro': 0.5, 'auc_micro': 0.5}

Epoch: [4 | 25] LR: 0.001000


223it [02:37,  1.42it/s]
25it [00:04,  5.43it/s]


Train accuracy  {'acc_macro': 0.0, 'prec_macro': 0.0, 'rec_macro': 0.0, 'f1_macro': 0.0, 'acc_micro': 0.0, 'prec_micro': nan, 'rec_micro': 0.0, 'f1_micro': nan, 'rec_at_3': 0.0, 'prec_at_3': 0.0, 'f1_at_3': nan, 'auc_macro': 0.5, 'auc_micro': 0.5}
test accuracy  {'acc_macro': 0.0, 'prec_macro': 0.0, 'rec_macro': 0.0, 'f1_macro': 0.0, 'acc_micro': 0.0, 'prec_micro': nan, 'rec_micro': 0.0, 'f1_micro': nan, 'rec_at_3': 0.2727272727272727, 'prec_at_3': 0.09090909090909091, 'f1_at_3': 0.13636363636363635, 'auc_macro': 0.5, 'auc_micro': 0.5}

Epoch: [5 | 25] LR: 0.001000


223it [02:37,  1.41it/s]
25it [00:04,  5.57it/s]


Train accuracy  {'acc_macro': 0.0, 'prec_macro': 0.0, 'rec_macro': 0.0, 'f1_macro': 0.0, 'acc_micro': 0.0, 'prec_micro': nan, 'rec_micro': 0.0, 'f1_micro': nan, 'rec_at_3': 0.16666666666666666, 'prec_at_3': 0.05555555555555555, 'f1_at_3': 0.08333333333333333, 'auc_macro': 0.5, 'auc_micro': 0.5}
test accuracy  {'acc_macro': 0.0, 'prec_macro': 0.0, 'rec_macro': 0.0, 'f1_macro': 0.0, 'acc_micro': 0.0, 'prec_micro': nan, 'rec_micro': 0.0, 'f1_micro': nan, 'rec_at_3': 0.2727272727272727, 'prec_at_3': 0.09090909090909091, 'f1_at_3': 0.13636363636363635, 'auc_macro': 0.5, 'auc_micro': 0.5}

Epoch: [6 | 25] LR: 0.001000


223it [02:37,  1.42it/s]
25it [00:04,  5.64it/s]


Train accuracy  {'acc_macro': 0.0, 'prec_macro': 0.0, 'rec_macro': 0.0, 'f1_macro': 0.0, 'acc_micro': 0.0, 'prec_micro': nan, 'rec_micro': 0.0, 'f1_micro': nan, 'rec_at_3': 0.0, 'prec_at_3': 0.0, 'f1_at_3': nan, 'auc_macro': nan, 'auc_micro': 0.5}
test accuracy  {'acc_macro': 0.0, 'prec_macro': 0.0, 'rec_macro': 0.0, 'f1_macro': 0.0, 'acc_micro': 0.0, 'prec_micro': nan, 'rec_micro': 0.0, 'f1_micro': nan, 'rec_at_3': 0.2727272727272727, 'prec_at_3': 0.09090909090909091, 'f1_at_3': 0.13636363636363635, 'auc_macro': 0.5, 'auc_micro': 0.5}

Epoch: [7 | 25] LR: 0.001000


104it [01:13,  1.39it/s]

In [163]:
import gc
torch.cuda.empty_cache()
gc.collect()

6591

In [10]:
yhat =  (x > 0.5)
yhat

array([[False, False, False, False, False, False],
       [False, False, False, False, False,  True]])

In [8]:
x = np.array([[ 0.4306,  0.1597, -0.0818, -0.1469,  0.2693, -0.0962],[ 0.4306,  0.1597, -0.0818, -0.1469,  0.4693, 0.5534]])
        
y = np.array([[1., 0., 0., 0., 0., 0.],[0., 0., 0., 0., 0., 1.]])
yhat =  (x > 0.5)
all_metrics(yhat, y, k=5, yhat_raw=x, calc_auc=True)

{'acc_macro': 0.16666666665,
 'prec_macro': 0.16666666665,
 'rec_macro': 0.16666666665,
 'f1_macro': 0.16666666665,
 'acc_micro': 0.5,
 'prec_micro': 1.0,
 'rec_micro': 0.5,
 'f1_micro': 0.6666666666666666,
 'rec_at_5': 1.0,
 'prec_at_5': 0.2,
 'f1_at_5': 0.33333333333333337,
 'auc_macro': 0.75,
 'auc_micro': 0.925}

In [None]:
|