In [1]:
import os
import sys
# Expose the project root (the directory one level above this notebook) on the
# import path so the local `src.dataset` / `src.model` modules can be imported.
# The membership check keeps re-runs of this cell from adding duplicate entries.
module_path = os.path.abspath(os.pardir)
already_importable = module_path in sys.path
if not already_importable:
    sys.path.append(module_path)

from sklearn.model_selection import train_test_split

import pandas as pd
from pytorch_lightning import Trainer, seed_everything

from src.dataset import MITDataModule
from src.model import CNNResidual

In [2]:
# Select the benchmark to train on:
#   "ptbdb"  -> binary task (nr_classes = 1): the normal and abnormal CSVs are
#               concatenated and split 80/20 into train/test, stratified on column 187.
#   "mitbih" -> 5-class task: the provided train/test CSVs are used as-is.
dataset = "mitbih"

DATA_DIR = os.path.join("..", "data")

if dataset == "ptbdb":
    nr_classes = 1
    df_ptbdb_normal = pd.read_csv(os.path.join(DATA_DIR, "ptbdb_normal.csv"), header=None)
    df_ptbdb_abnormal = pd.read_csv(os.path.join(DATA_DIR, "ptbdb_abnormal.csv"), header=None)
    df = pd.concat([df_ptbdb_normal, df_ptbdb_abnormal])

    # Stratify on column 187 (presumably the class label -- confirm against the data)
    # so train and test keep the same class ratio; fixed random_state for reproducibility.
    df_train, df_test = train_test_split(df, test_size=0.2, random_state=1337, stratify=df[187])
    # The concatenated frame has duplicate index labels; give each split a clean 0..n-1 index.
    df_train, df_test = df_train.reset_index(drop=True), df_test.reset_index(drop=True)
elif dataset == "mitbih":
    nr_classes = 5
    df_train = pd.read_csv(os.path.join(DATA_DIR, "mitbih_train.csv"), header=None)
    df_test = pd.read_csv(os.path.join(DATA_DIR, "mitbih_test.csv"), header=None)
else:
    # Fail loudly instead of silently falling back to mitbih on a typo.
    raise ValueError(f"Unknown dataset {dataset!r}; expected 'ptbdb' or 'mitbih'")

In [3]:
# Train the residual CNN on the selected dataset, then report validation and
# test metrics. `nr_classes` comes from the data cell (1: binary, 5: non-binary task).
#
# NOTE(review): the saved outputs of this cell report 'test roc' / 'test prc'
# metrics, which may come from a binary (ptbdb) run rather than the "mitbih"
# setting currently selected in the data cell -- Restart & Run All before
# relying on these numbers.
SEED = 1234
MAX_EPOCHS = 25

# Seed before building the model so weight initialisation is reproducible.
seed_everything(SEED)

model = CNNResidual(nr_classes=nr_classes)
trainer = Trainer(max_epochs=MAX_EPOCHS)
mit = MITDataModule(df_train, df_test)

trainer.fit(model, datamodule=mit)
trainer.validate(model, datamodule=mit)

# Last expression: its returned metrics list is rendered as the cell output.
trainer.test(model, datamodule=mit)

Global seed set to 1234
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

   | Name      | Type                 | Params
----------------------------------------------------
0  | c1        | Conv1d               | 48    
1  | c2        | Conv1d               | 615   
2  | c3        | Conv1d               | 2.6 K 
3  | fc1       | Linear               | 113 K 
4  | fc2       | Linear               | 6.0 K 
5  | fc3       | Linear               | 33    
6  | fc4       | Linear               | 8.1 K 
7  | maxP      | MaxPool1d            | 0     
8  | dropout   | Dropout              | 0     
9  | flatten   | Flatten              | 0     
10 | flatten0  | Flatten              | 0     
11 | test_acc  | Accuracy             | 0     
12 | train_acc | Accuracy             | 0     
13 | valid_acc | Accuracy             | 0     
14 | test_f1   | F1Score              | 0     
15 | train_f1  | F1Score              | 0     
16 | valid_f

                                                              

Global seed set to 1234


Epoch 24: 100%|██████████| 364/364 [00:06<00:00, 55.41it/s, loss=0.456, v_num=70]
Validating: 100%|██████████| 73/73 [00:00<00:00, 85.94it/s] --------------------------------------------------------------------------------
DATALOADER:0 VALIDATE RESULTS
{'val_acc': 0.7221983671188354,
 'val_acc_epoch': 0.7221983671188354,
 'valid_loss': 0.4595416188240051}
--------------------------------------------------------------------------------
Validating: 100%|██████████| 73/73 [00:00<00:00, 86.74it/s]
Testing: 100%|██████████| 91/91 [00:01<00:00, 98.60it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test prc': 0.9728344678878784,
 'test roc': 0.929253339767456,
 'test_acc': 0.722088634967804,
 'test_f1': 0.8366101980209351,
 'test_loss': 0.4598807096481323}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 91/91 [00:01<00:00, 83.15it/s]


[{'test_loss': 0.4598807096481323,
  'test_f1': 0.8366101980209351,
  'test_acc': 0.722088634967804,
  'test roc': 0.929253339767456,
  'test prc': 0.9728344678878784}]