# setup

In [1]:
import os

if os.getcwd().split('/')[-1] == 'notebooks':
    os.chdir('..')

In [2]:
import torch
import torch.nn as nn

from models.baseline import ResnetBaseline
from runners.train import Runner

In [3]:
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('model_label', type = str, choices = ['code', 'cpsc2018', 'ptbxl', 'ningbo'])
args = parser.parse_args(args = ['ningbo'])

In [4]:
if args.model_label == 'code':
    from dataloaders.code import CODE as DS
    from dataloaders.code import CODEsplit as DSsplit

    epochs = 5
    n_classes = 6

if args.model_label == 'cpsc2018':
    from dataloaders.cpsc2018 import CPSC2018 as DS
    from dataloaders.cpsc2018 import CPSC2018split as DSsplit

    epochs = 30
    n_classes = 8

if args.model_label == 'ptbxl':
    from dataloaders.ptbxl import PTBXL as DS
    from dataloaders.ptbxl import PTBXLsplit as DSsplit

    epochs = 80
    n_classes = 5

if args.model_label == 'ningbo':
    from dataloaders.ningbo import NINGBO as DS
    from dataloaders.ningbo import NINGBOsplit as DSsplit

    epochs = 50
    n_classes = 9

# init

In [7]:
database = DS()
model = ResnetBaseline(n_classes = n_classes)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

828it [00:00, 8279.63it/s]

checking exam_id consistency in idx dict


38379it [00:04, 8305.64it/s]
838it [00:00, 8374.08it/s]

checking exam_id consistency in idx dict


4515it [00:00, 8366.39it/s]
1675it [00:00, 8375.90it/s]

checking exam_id consistency in idx dict


2258it [00:00, 8371.07it/s]


In [8]:
runner = Runner(device = device, model = model, database = database, split = DSsplit, model_label = args.model_label)

# run

In [None]:
runner.train(epochs)

In [None]:
runner.eval()

# draft

In [36]:
from tqdm import tqdm

from utils import get_inputs

In [16]:
model = model.cuda()
trn_ds = DSsplit(database, database.trn_idx_dict)
trn_dl = torch.utils.data.DataLoader(trn_ds, batch_size = 128, shuffle = True, num_workers = 1)
criterion = nn.BCEWithLogitsLoss()

In [17]:
with torch.no_grad():
    for batch in tqdm(trn_dl):
        raw = batch['X']
        label = batch['y']
        ecg = get_inputs(raw, device = device)
        label = label.to(device).float()
        
        logits = model.forward(ecg)
        loss = criterion(logits, label)
        print(loss)

  0%|          | 1/300 [00:02<13:05,  2.63s/it]

tensor(0.7047, device='cuda:0')


  1%|          | 2/300 [00:04<12:06,  2.44s/it]

tensor(nan, device='cuda:0')


  1%|          | 2/300 [00:05<14:34,  2.93s/it]


KeyboardInterrupt: 

In [19]:
logits = model.forward(ecg)

In [28]:
ecg[0:1, :, :].shape

torch.Size([1, 12, 4096])

In [42]:
for i in range(128):
    logits = model.forward(ecg[i:i+1, :, :])
    assert not torch.isnan(logits).any()

AssertionError: 

In [43]:
i, logits

(42,
 tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan]], device='cuda:0',
        grad_fn=<AddmmBackward>))

In [44]:
ecg[42, :, :]

tensor([[-0.2304, -0.2595, -0.2554,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.1070,  0.1207,  0.1188,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.3374,  0.3802,  0.3742,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-2.3355, -2.6839, -2.5479,  ...,  0.0000,  0.0000,  0.0000],
        [    nan,     nan,     nan,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.1938,  0.2194,  0.2144,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')

In [20]:
criterion(logits, label.float())

tensor(nan, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)