In [3]:
from pathlib import Path

import torch
import pytorch_lightning as pl
import wandb

from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor

from src.lit_models.ptbxl_model import ECGClassifier
from src.models.resnet1d import resnet1d_wang
from src.models.conv_transformer import conv_transformer
from pytorch_lightning.loggers import WandbLogger

from src.data.ptb_xl_multiclass_datamodule import PTB_XL_Datamodule
from torchmetrics.classification import MulticlassAccuracy

import os
from datetime import datetime

In [4]:
def get_model_registry():
    return {
        "resnet1d_wang": resnet1d_wang,
        "conv_transformer": conv_transformer
    }

In [5]:
def create_directory_with_timestamp(path, prefix):
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    dir_name = f"{prefix}_{timestamp}"
    full_path = os.path.join(path, dir_name)
    os.makedirs(full_path, exist_ok=True)

    return full_path

In [6]:
def get_datamodule(run, FILTER_FOR_SINGLELABEL, BATCH_SIZE):
    artifact = run.use_artifact(f"{'ptbxl_split'}:latest")

    datadir = artifact.download()

    data_module = PTB_XL_Datamodule(Path(datadir), filter_for_singlelabel=FILTER_FOR_SINGLELABEL, batch_size=BATCH_SIZE)

    data_module.prepare_data()
    data_module.setup()

    return data_module

In [7]:
def get_model(total_optimizer_steps, model_config, model_name="resnet1d_wang", task='multilabel', loss=torch.nn.BCEWithLogitsLoss()):
    model = get_model_registry()[model_name](
    **model_config
)

    model_lit = ECGClassifier(
        model, 5, loss, 0.01, wd=0.01, total_optimizer_steps=total_optimizer_steps, task=task)
    
    return model_lit

In [8]:
def train_model(model_lit, data_module, config):
    wandb_logger = WandbLogger(log_model="all")
    wandb_logger.watch(model_lit, log="all")

    dir_model = create_directory_with_timestamp("./models", "resnet1d_wang")

    early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=0.00, patience=30, verbose=False, mode="min")
    learning_rate_monitor = LearningRateMonitor(logging_interval="step", log_momentum=True)

    # Create the Learner
    trainer = pl.Trainer(
        accumulate_grad_batches=config.ACCUMULATE_GRADIENT_STEPS,
        log_every_n_steps=1,
        max_epochs=config.EPOCHS,
        logger=wandb_logger,
        callbacks=[early_stop_callback, learning_rate_monitor],
    )

    trainer.fit(model_lit, datamodule=data_module)

    return trainer

In [9]:
def validate_model(trainer, data_module, metrics={}):
    res = trainer.predict(dataloaders=data_module.test_dataloader())

    y_hat, y = torch.concatenate([x[0] for x in res]), torch.concatenate([x[1] for x in res])

    y_hat = torch.nn.functional.sigmoid(y_hat)

    metrics  = {
        'multiclass_accuracy': MulticlassAccuracy(num_classes=y_hat.size(1), average='weighted')
    }

    target = torch.argmax(y, axis=-1)
    preds = torch.argmax(y_hat, axis=-1)



    return {
        k: v(preds, target) for k, v in metrics.items()
    }


In [10]:

def train_model_with_validation(config, project="ecg_benchmarking_lit", name="test_run", entity="phd-dk"):

    run = wandb.init(project=project, name=name, entity=entity, config=config)

    BATCH_SIZE = run.config.BATCH_SIZE
    FILTER_FOR_SINGLELABEL = run.config.FILTER_FOR_SINGLELABEL

    loss = torch.nn.BCEWithLogitsLoss() if not FILTER_FOR_SINGLELABEL else torch.nn.CrossEntropyLoss()
    task = "multilabel" if not FILTER_FOR_SINGLELABEL else "multiclass"

    data_module = get_datamodule(run, FILTER_FOR_SINGLELABEL, BATCH_SIZE)
    print(len(data_module.val_dataset))

    total_optimizer_steps = int(len(data_module.train_dataset) * run.config.EPOCHS / run.config.ACCUMULATE_GRADIENT_STEPS)

    model_lit = get_model(total_optimizer_steps, run.config.model_config, run.config.model_name, task, loss)

    trainer = train_model(model_lit, data_module, run.config)


    trainer.test(model=trainer.model, dataloaders=data_module.test_dataloader())



    # results = validate_model(trainer=trainer, data_module=data_module)

    # wandb_code = run.log({
    #     f"test/{metric_name}": metric_value for metric_name, metric_value in results.items()
    # })

    # print(wandb_code, {
    #     f"test/{metric_name}": metric_value for metric_name, metric_value in results.items()
    # })

    run.finish()

    return trainer, data_module, model_lit


In [9]:

model_config = dict(
    k = 12,
    headers = 10,
    depth = 5,
    seq_length= 128
)

config = {
    "BATCH_SIZE": 10,
    "EPOCHS": 50,
    "ACCUMULATE_GRADIENT_STEPS": 1,
    "FILTER_FOR_SINGLELABEL" : False,
    "model_config": model_config,
    "model_name": "conv_transformer"
}

In [10]:
trainer, data_module, model = train_model_with_validation(config)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33markadiusz-czerwinski[0m ([33mphd-dk[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Downloading large artifact ptbxl_split:latest, 1800.85MB. 9 files... 
[34m[1mwandb[0m:   9 of 9 files downloaded.  
Done. 0:0:3.5
c:\Users\arekp\anaconda3\envs\phd\lib\site-packages\pytorch_lightning\loggers\wandb.py:389: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


14903


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.
c:\Users\arekp\anaconda3\envs\phd\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.

   | Name                   | Type               

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

c:\Users\arekp\anaconda3\envs\phd\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:04<00:00,  0.50it/s]



Epoch 34:   1%|          | 1/117 [00:00<00:08, 13.00it/s, v_num=xyvr, train_loss_step=0.214, val_loss_step=0.466, val_loss_epoch=0.241, train_loss_epoch=0.226]  

c:\Users\arekp\anaconda3\envs\phd\lib\site-packages\pytorch_lightning\trainer\call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\Users\arekp\anaconda3\envs\phd\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing DataLoader 0:   9%|▉         | 11/118 [00:00<00:06, 15.62it/s]



Testing DataLoader 0: 100%|██████████| 118/118 [00:05<00:00, 21.40it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        Test metric               DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
    test_accuracy_epoch        0.8827852606773376
     test_auroc_epoch          0.9295147061347961
test_averageprecision_epoch    0.8131654858589172
    test_f1score_epoch         0.7669419646263123
test_matthewscorrcoef_epoch    0.6894055604934692
   test_precision_epoch        0.7927109599113464
     test_recall_epoch         0.7427955865859985
  test_specificity_epoch       0.9318802952766418
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


In [None]:
import wandb
run = wandb.init()
artifact = run.use_artifact('phd-dk/ecg_benchmarking_lit/model-ied6sv54:v4', type='model')
artifact_dir = artifact.download()

[34m[1mwandb[0m: Currently logged in as: [33markadiusz-czerwinski[0m ([33mavatar2pjm[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [12]:
import torchmetrics
torchmetrics.__version__

'1.2.0'

In [None]:
artifact_dir

'.\\artifacts\\model-ied6sv54-v4'

In [None]:
trainer

<pytorch_lightning.trainer.trainer.Trainer at 0x24f62db07f0>

In [None]:
data_module

<src.data.ptb_xl_multiclass_datamodule.PTB_XL_Datamodule at 0x24f60287850>

In [None]:
from torchmetrics import F1Score
from torchmetrics.classification import MulticlassF1Score, MultilabelF1Score
from sklearn.metrics import f1_score

In [None]:
def validate_model(trainer, loader, metrics={}):
    res = trainer.predict(dataloaders=loader)

    y_hat, y = torch.concatenate([x[0] for x in res]), torch.concatenate([x[1] for x in res])

    y_hat = torch.nn.functional.sigmoid(y_hat)

    metrics  = {
        'multiclass_accuracy': MulticlassAccuracy(num_classes=y_hat.size(1), average='weighted')
    }

    # metric = F1Score(num_labels=5, task='multilabel')

    metric = f1_score

    result = metric(y_hat, y, average='micro')



    return result


In [None]:
validate_model(trainer, data_module.test_dataloader())

Restoring states from the checkpoint path at .\lightning_logs\phun1fnc\checkpoints\epoch=49-step=5850.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at .\lightning_logs\phun1fnc\checkpoints\epoch=49-step=5850.ckpt
c:\Users\arekp\anaconda3\envs\phd\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|██████████| 118/118 [00:01<00:00, 59.83it/s]


ValueError: Classification metrics can't handle a mix of continuous-multioutput and multilabel-indicator targets

In [None]:
validate_model(trainer, data_module.train_dataloader())

Restoring states from the checkpoint path at .\lightning_logs\phun1fnc\checkpoints\epoch=49-step=5850.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at .\lightning_logs\phun1fnc\checkpoints\epoch=49-step=5850.ckpt
c:\Users\arekp\anaconda3\envs\phd\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:492: Your `predict_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
c:\Users\arekp\anaconda3\envs\phd\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|██████████| 117/117 [00:01<00:00, 59.36it/s]


tensor(0.8790)

In [None]:
validate_model(trainer, data_module.val_dataloader())

Restoring states from the checkpoint path at .\lightning_logs\phun1fnc\checkpoints\epoch=49-step=5850.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at .\lightning_logs\phun1fnc\checkpoints\epoch=49-step=5850.ckpt
c:\Users\arekp\anaconda3\envs\phd\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|██████████| 117/117 [00:01<00:00, 62.26it/s]


tensor(0.8039)