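# Setup

Fine-tune `ResnetBaseline`, initialised from the pretrained backbone checkpoint `output/backbone/backbone.pt`, on the downstream dataset selected by `model_label`.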

In [1]:
import os

# When the kernel is started inside `notebooks/`, move up one level so that the
# project-local packages imported below (models, runners, utils, dataloaders)
# and relative paths such as 'output/backbone/backbone.pt' resolve from the
# repository root (assumed layout).
# os.path.basename is separator-aware, unlike splitting on '/', so this check
# also works with Windows paths.
if os.path.basename(os.getcwd()) == 'notebooks':
    os.chdir('..')

In [2]:
import torch
import torch.nn as nn

from models.baseline import ResnetBaseline
from runners.train import Runner
from utils import load_backbone

In [3]:
import argparse

# A notebook has no real command line, so the label is handed to parse_args
# explicitly. Edit the list below to fine-tune on a different dataset;
# `choices` rejects unknown labels up front.
parser = argparse.ArgumentParser()
parser.add_argument(
    'model_label',
    type = str,
    choices = ['fn_code', 'fn_cpsc2018', 'fn_ptbxl', 'fn_ningbo'],
)
args = parser.parse_args(['fn_cpsc2018'])

In [4]:
# Per-dataset configuration, keyed on the parsed label.
# - Each branch imports its dataloader lazily, so only the selected dataset's
#   module has to be importable.
# - `DS` is the dataset class instantiated in the init cell.
# - `DSsplit` is handed to the Runner as `split`.
# - `epochs` is passed to `runner.train`.
# - `n_classes` sizes the `ResnetBaseline` classifier.
if args.model_label == 'fn_code':
    from dataloaders.code import CODE as DS
    from dataloaders.code import CODEsplit as DSsplit

    epochs = 5
    n_classes = 6

elif args.model_label == 'fn_cpsc2018':
    from dataloaders.cpsc2018 import CPSC2018 as DS
    from dataloaders.cpsc2018 import CPSC2018split as DSsplit

    epochs = 225
    n_classes = 8

elif args.model_label == 'fn_ptbxl':
    from dataloaders.ptbxl import PTBXL as DS
    from dataloaders.ptbxl import PTBXLsplit as DSsplit

    epochs = 80
    n_classes = 5

elif args.model_label == 'fn_ningbo':
    from dataloaders.ningbo import NINGBO as DS
    from dataloaders.ningbo import NINGBOsplit as DSsplit

    epochs = 50
    n_classes = 9

else:
    # argparse `choices` should make this unreachable. If a label is added
    # there without a branch here, fail immediately with a clear message
    # rather than with a NameError on `DS` in a later cell.
    raise ValueError(f'unsupported model_label: {args.model_label!r}')

# Initialise dataset, model and device

In [7]:
# Dataset object for the label selected above.
database = DS()

# Classifier with one output per class of the selected dataset, initialised
# from the pretrained backbone checkpoint. The value returned by
# `load_backbone` is indexed with 'model' to get the network to fine-tune.
model = load_backbone(ResnetBaseline(n_classes = n_classes), 'output/backbone/backbone.pt')['model']

# Prefer the GPU whenever torch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

0it [00:00, ?it/s]

787it [00:00, 7865.94it/s]

checking exam_id consistency in idx dict


5845it [00:00, 8521.15it/s]
687it [00:00, 8700.09it/s]
345it [00:00, 8624.03it/s]


checking exam_id consistency in idx dict
checking exam_id consistency in idx dict


In [9]:
# Bundle model, data and device into the project's training/evaluation driver.
# `model_label` is forwarded as-is from the parsed arguments.
# NOTE(review): `split` receives the DSsplit class itself, not an instance;
# presumably Runner instantiates it. Confirm in runners.train.
runner = Runner(
    model = model,
    database = database,
    split = DSsplit,
    device = device,
    model_label = args.model_label,
)

# Train and evaluate

In [10]:
# Fine-tune for `epochs` epochs (set per dataset in the configuration cell).
# The log below shows, per epoch, two progress bars (46 and 6 batches,
# presumably train and validation) and a "new checkpoint with val loss" line
# on some epochs. The logged losses only decrease, which is consistent with
# checkpointing on improvement. The log ends with "exporting last model".
# NOTE(review): the recorded fn_cpsc2018 run ends after epoch 29 even though
# epochs = 225. This is presumably early stopping inside Runner.train; confirm.
runner.train(epochs)

  0%|          | 0/46 [00:00<?, ?it/s]

-- epoch 0


100%|██████████| 46/46 [00:44<00:00,  1.03it/s]
100%|██████████| 6/6 [00:03<00:00,  1.81it/s]
  0%|          | 0/46 [00:00<?, ?it/s]

new checkpoint with val loss: 0.5333446562290192
exporting partial model at epoch 0
-- epoch 1


100%|██████████| 46/46 [00:07<00:00,  6.42it/s]
100%|██████████| 6/6 [00:00<00:00,  7.18it/s]
  0%|          | 0/46 [00:00<?, ?it/s]

new checkpoint with val loss: 0.3007611284653346
exporting partial model at epoch 1
-- epoch 2


100%|██████████| 46/46 [00:04<00:00,  9.76it/s]
100%|██████████| 6/6 [00:00<00:00,  9.23it/s]
  0%|          | 0/46 [00:00<?, ?it/s]

new checkpoint with val loss: 0.27785633007685345
exporting partial model at epoch 2
-- epoch 3


100%|██████████| 46/46 [00:04<00:00,  9.88it/s]
100%|██████████| 6/6 [00:00<00:00,  8.89it/s]
  0%|          | 0/46 [00:00<?, ?it/s]

new checkpoint with val loss: 0.26162130385637283
exporting partial model at epoch 3
-- epoch 4


100%|██████████| 46/46 [00:04<00:00,  9.81it/s]
100%|██████████| 6/6 [00:00<00:00,  8.86it/s]
  0%|          | 0/46 [00:00<?, ?it/s]

new checkpoint with val loss: 0.2471661145488421
exporting partial model at epoch 4
-- epoch 5


100%|██████████| 46/46 [00:04<00:00,  9.85it/s]
100%|██████████| 6/6 [00:00<00:00,  9.05it/s]
  0%|          | 0/46 [00:00<?, ?it/s]

new checkpoint with val loss: 0.23025084286928177
exporting partial model at epoch 5
-- epoch 6


100%|██████████| 46/46 [00:04<00:00,  9.85it/s]
100%|██████████| 6/6 [00:00<00:00,  9.37it/s]
  0%|          | 0/46 [00:00<?, ?it/s]

new checkpoint with val loss: 0.21737365672985712
exporting partial model at epoch 6
-- epoch 7


100%|██████████| 46/46 [00:04<00:00,  9.84it/s]
100%|██████████| 6/6 [00:00<00:00,  9.26it/s]
  0%|          | 0/46 [00:00<?, ?it/s]

new checkpoint with val loss: 0.20773082971572876
exporting partial model at epoch 7
-- epoch 8


100%|██████████| 46/46 [00:04<00:00,  9.75it/s]
100%|██████████| 6/6 [00:00<00:00,  9.30it/s]
  0%|          | 0/46 [00:00<?, ?it/s]

new checkpoint with val loss: 0.19871420909961066
exporting partial model at epoch 8
-- epoch 9


100%|██████████| 46/46 [00:04<00:00,  9.58it/s]
100%|██████████| 6/6 [00:00<00:00,  8.95it/s]
  0%|          | 0/46 [00:00<?, ?it/s]

new checkpoint with val loss: 0.19061376651128134
exporting partial model at epoch 9
-- epoch 10


100%|██████████| 46/46 [00:04<00:00,  9.87it/s]
100%|██████████| 6/6 [00:00<00:00,  6.92it/s]
  0%|          | 0/46 [00:00<?, ?it/s]

new checkpoint with val loss: 0.17787447075049082
exporting partial model at epoch 10
-- epoch 11


100%|██████████| 46/46 [00:04<00:00,  9.82it/s]
100%|██████████| 6/6 [00:00<00:00,  9.18it/s]
  0%|          | 0/46 [00:00<?, ?it/s]

new checkpoint with val loss: 0.16725360602140427
exporting partial model at epoch 11
-- epoch 12


100%|██████████| 46/46 [00:04<00:00,  9.76it/s]
100%|██████████| 6/6 [00:00<00:00,  9.01it/s]
  0%|          | 0/46 [00:00<?, ?it/s]

new checkpoint with val loss: 0.16241253912448883
exporting partial model at epoch 12
-- epoch 13


100%|██████████| 46/46 [00:04<00:00,  9.92it/s]
100%|██████████| 6/6 [00:00<00:00,  9.08it/s]
  0%|          | 0/46 [00:00<?, ?it/s]

new checkpoint with val loss: 0.15849869946638742
exporting partial model at epoch 13
-- epoch 14


100%|██████████| 46/46 [00:04<00:00,  9.76it/s]
100%|██████████| 6/6 [00:00<00:00,  9.23it/s]
  0%|          | 0/46 [00:00<?, ?it/s]

new checkpoint with val loss: 0.15342521419127783
exporting partial model at epoch 14
-- epoch 15


100%|██████████| 46/46 [00:04<00:00,  9.83it/s]
100%|██████████| 6/6 [00:00<00:00,  9.07it/s]
  0%|          | 0/46 [00:00<?, ?it/s]

new checkpoint with val loss: 0.15074478586514792
exporting partial model at epoch 15
-- epoch 16


100%|██████████| 46/46 [00:04<00:00,  9.89it/s]
100%|██████████| 6/6 [00:00<00:00,  9.26it/s]
  0%|          | 0/46 [00:00<?, ?it/s]

-- epoch 17


100%|██████████| 46/46 [00:04<00:00,  9.82it/s]
100%|██████████| 6/6 [00:00<00:00,  9.42it/s]
  0%|          | 0/46 [00:00<?, ?it/s]

-- epoch 18


100%|██████████| 46/46 [00:04<00:00,  9.83it/s]
100%|██████████| 6/6 [00:00<00:00,  9.12it/s]
  0%|          | 0/46 [00:00<?, ?it/s]

-- epoch 19


100%|██████████| 46/46 [00:04<00:00,  9.85it/s]
100%|██████████| 6/6 [00:00<00:00,  9.67it/s]
  0%|          | 0/46 [00:00<?, ?it/s]

-- epoch 20


100%|██████████| 46/46 [00:04<00:00,  9.42it/s]
100%|██████████| 6/6 [00:00<00:00,  6.86it/s]
  0%|          | 0/46 [00:00<?, ?it/s]

-- epoch 21


100%|██████████| 46/46 [00:04<00:00,  9.80it/s]
100%|██████████| 6/6 [00:00<00:00,  9.47it/s]
  0%|          | 0/46 [00:00<?, ?it/s]

new checkpoint with val loss: 0.14240149408578873
exporting partial model at epoch 21
-- epoch 22


100%|██████████| 46/46 [00:04<00:00,  9.79it/s]
100%|██████████| 6/6 [00:00<00:00,  9.16it/s]
  0%|          | 0/46 [00:00<?, ?it/s]

new checkpoint with val loss: 0.14177675048510233
exporting partial model at epoch 22
-- epoch 23


100%|██████████| 46/46 [00:04<00:00,  9.74it/s]
100%|██████████| 6/6 [00:00<00:00,  9.24it/s]
  0%|          | 0/46 [00:00<?, ?it/s]

-- epoch 24


100%|██████████| 46/46 [00:04<00:00,  9.82it/s]
100%|██████████| 6/6 [00:00<00:00,  8.97it/s]
  0%|          | 0/46 [00:00<?, ?it/s]

new checkpoint with val loss: 0.14020604391892752
exporting partial model at epoch 24
-- epoch 25


100%|██████████| 46/46 [00:04<00:00,  9.77it/s]
100%|██████████| 6/6 [00:00<00:00,  9.12it/s]
  0%|          | 0/46 [00:00<?, ?it/s]

-- epoch 26


100%|██████████| 46/46 [00:04<00:00,  9.78it/s]
100%|██████████| 6/6 [00:00<00:00,  9.13it/s]
  0%|          | 0/46 [00:00<?, ?it/s]

-- epoch 27


100%|██████████| 46/46 [00:04<00:00,  9.85it/s]
100%|██████████| 6/6 [00:00<00:00,  9.16it/s]
  0%|          | 0/46 [00:00<?, ?it/s]

-- epoch 28


100%|██████████| 46/46 [00:04<00:00,  9.92it/s]
100%|██████████| 6/6 [00:00<00:00,  9.19it/s]
  0%|          | 0/46 [00:00<?, ?it/s]

-- epoch 29


100%|██████████| 46/46 [00:04<00:00,  9.79it/s]
100%|██████████| 6/6 [00:00<00:00,  8.98it/s]


exporting last model


In [11]:
# Evaluate the fine-tuned model. The two progress bars below (6 and 3 batches)
# presumably cover the validation and test splits; confirm in Runner.eval.
# NOTE(review): no result is displayed in this cell. It is not visible here
# whether the best checkpoint or the last model is evaluated, or where the
# metrics are written.
runner.eval()

100%|██████████| 6/6 [00:01<00:00,  5.42it/s]
100%|██████████| 3/3 [00:02<00:00,  1.48it/s]
