In [40]:
import argparse
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np

from dataset import ECGDataset
from resnet import resnet34, resnet50
from utils import cal_f1s, cal_aucs, split_data

In [41]:

data_dir = 'data\ptb-xl'
database = os.path.basename(data_dir)


In [42]:
model_path = 'models\resnet50_ptb-xl_all_42.pth'
device = 'cpu'


leads = 'all'
nleads = 12

label_csv = os.path.join(data_dir, 'labels.csv') #Modified. Long. 23.Mar.24, original: os.path.join(data_dir, 'labelx.csv')

train_folds, val_folds, test_folds = split_data(seed=42)
train_dataset = ECGDataset('train', data_dir, label_csv, train_folds, leads)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4, pin_memory=True)
#shape train_loader data
for data in train_loader:
    inputs, targets = data
    print(inputs.shape)
    print(targets.shape)
    break

val_dataset = ECGDataset('val', data_dir, label_csv, val_folds, leads)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4, pin_memory=True)
test_dataset = ECGDataset('test', data_dir, label_csv, test_folds, leads)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4, pin_memory=True)
net = resnet34(input_channels=nleads).to(device) #Modified. Long. 23.Mar.24, original: resnet50
optimizer = torch.optim.Adam(net.parameters(), lr=0.0001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 10, gamma=0.1)

criterion = nn.BCEWithLogitsLoss()

#print net output shape
for data in train_loader:
    inputs, targets = data
    outputs = net(inputs.to(device))
    print(outputs.shape)
    print(targets.shape)
    break
    


torch.Size([16, 12, 15000])
torch.Size([16, 37])
torch.Size([16, 9])
torch.Size([16, 37])


In [36]:
def train(dataloader, net, args, criterion, epoch, scheduler, optimizer, device):
    print('Training epoch %d:' % epoch)
    net.train()
    running_loss = 0
    output_list, labels_list = [], []
    for _, (data, labels) in enumerate(tqdm(dataloader)):
        data, labels = data.to(device), labels.to(device)
        output = net(data)
        # print(output)
        # print(labels)
        # print(criterion(output, labels))
        print(data.shape)
        print(output.shape)
        print(labels.shape)
        loss = criterion(output, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        output_list.append(output.data.cpu().numpy())
        labels_list.append(labels.data.cpu().numpy())
    # scheduler.step()
    print('Loss: %.4f' % running_loss)
    

def evaluate(dataloader, net, args, criterion, device):
    best_metric = 0
    print('Validating...')
    net.eval()
    running_loss = 0
    output_list, labels_list = [], []
    for _, (data, labels) in enumerate(tqdm(dataloader)):
        data, labels = data.to(device), labels.to(device)
        output = net(data)
        # print(output)
        # print(labels)
        # print(criterion(output, labels))
        print(output.shape)
        print(labels.shape)
        loss = criterion(output, labels)
        running_loss += loss.item()
        output = torch.sigmoid(output)
        output_list.append(output.data.cpu().numpy())
        labels_list.append(labels.data.cpu().numpy())
    print('Loss: %.4f' % running_loss)
    y_trues = np.vstack(labels_list)
    y_scores = np.vstack(output_list)
    f1s = cal_f1s(y_trues, y_scores)
    avg_f1 = np.mean(f1s)
    print('F1s:', f1s)
    print('Avg F1: %.4f' % avg_f1)
    if avg_f1 > best_metric:
        best_metric = avg_f1
        # torch.save(net.state_dict(), args.model_path)
    else:
        aucs = cal_aucs(y_trues, y_scores)
        avg_auc = np.mean(aucs)
        print('AUCs:', aucs)
        print('Avg AUC: %.4f' % avg_auc)

In [37]:
#train phase

# net.load_state_dict(torch.load(model_path, map_location=device))
for epoch in range(2):
    train(train_loader, net, '', criterion, epoch, scheduler, optimizer, device)
    evaluate(val_loader, net, '', criterion, device)


Training epoch 0:


  0%|          | 0/11 [00:00<?, ?it/s]

torch.Size([16, 12, 15000])
torch.Size([16, 9])
torch.Size([16, 37])


  0%|          | 0/11 [00:06<?, ?it/s]


ValueError: Target size (torch.Size([16, 37])) must be the same as input size (torch.Size([16, 9]))