In [1]:
from pathlib import Path

import torch
import pytorch_lightning as pl
import wandb

from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor

from src.lit_models.ptbxl_model import ECGClassifier
from src.models.resnet1d import resnet1d_wang
# from src.models.xresnet1d import xresned1d101
from pytorch_lightning.loggers import WandbLogger

from src.data.ptb_xl_multiclass_datamodule import PTB_XL_Datamodule
from torchmetrics.classification import MulticlassAccuracy

import os
from datetime import datetime

ImportError: cannot import name 'xresned1d101' from 'src.models.xresnet1d' (c:\Users\arekp\OneDrive\Desktop\ecg_benchmarking_lit\src\models\xresnet1d.py)

In [None]:
def get_model_registry():
    """Return the mapping from model name to model factory.

    ``"resnet1d_wang"`` is always registered. ``"xresnet1d101"`` is registered
    only when its factory can be imported from ``src.models.xresnet1d``.

    The module-level import of the xresnet factory is commented out, so
    referencing that name unconditionally raised ``NameError`` on every call
    and made even ``"resnet1d_wang"`` unusable. The import is therefore
    attempted lazily here and skipped when it fails.

    Returns:
        dict: model name (str) -> callable that builds the model.
    """
    registry = {"resnet1d_wang": resnet1d_wang}

    try:
        # NOTE(review): factory name assumed from the registry key
        # ("xresnet1d101"); the previously attempted name "xresned1d101" does
        # not exist in that module. Confirm against src/models/xresnet1d.py.
        from src.models.xresnet1d import xresnet1d101
    except ImportError:
        # Deliberate best-effort: the optional architecture is simply not
        # offered when its factory is unavailable.
        return registry

    registry["xresnet1d101"] = xresnet1d101
    return registry

In [None]:
def create_directory_with_timestamp(path, prefix):
    """Create a timestamped sub-directory and return its path.

    The directory is named ``<prefix>_<YYYYmmdd_HHMMSS>`` using the current
    local time and is created below ``path`` (missing parents included).
    An already existing directory is not an error.

    Args:
        path: Parent directory.
        prefix: Leading part of the new directory's name.

    Returns:
        str: ``path`` joined with the generated directory name.
    """
    target_dir = os.path.join(path, f"{prefix}_{datetime.now():%Y%m%d_%H%M%S}")
    os.makedirs(target_dir, exist_ok=True)
    return target_dir

In [None]:
def get_datamodule(run, FILTER_FOR_SINGLELABEL, BATCH_SIZE):
    """Download the PTB-XL split artifact and build a ready-to-use datamodule.

    Args:
        run: Active wandb run; used to resolve and download the
            ``ptbxl_split:latest`` artifact.
        FILTER_FOR_SINGLELABEL: Forwarded to the datamodule as
            ``filter_for_singlelabel``.
        BATCH_SIZE: Forwarded to the datamodule as ``batch_size``.

    Returns:
        PTB_XL_Datamodule: datamodule on which ``prepare_data()`` and
        ``setup()`` have already been called.
    """
    split_dir = Path(run.use_artifact("ptbxl_split:latest").download())

    datamodule = PTB_XL_Datamodule(
        split_dir,
        filter_for_singlelabel=FILTER_FOR_SINGLELABEL,
        batch_size=BATCH_SIZE,
    )
    datamodule.prepare_data()
    datamodule.setup()
    return datamodule

In [None]:
def get_model(total_optimizer_steps, model_config, model_name="resnet1d_wang", loss=None):
    """Build the backbone named ``model_name`` and wrap it in an ``ECGClassifier``.

    Args:
        total_optimizer_steps: Forwarded to ``ECGClassifier`` as
            ``total_optimizer_steps``.
        model_config: Mapping of keyword arguments unpacked into the model
            factory.
        model_name: Key into the registry returned by ``get_model_registry()``.
        loss: Loss module handed to ``ECGClassifier``. When ``None`` (the
            default) a fresh ``torch.nn.BCEWithLogitsLoss`` is created for
            this call.

    Returns:
        ECGClassifier: the Lightning module wrapping the backbone.

    Raises:
        KeyError: if ``model_name`` is not present in the registry.
    """
    # A module instance as default argument is created once at definition time
    # and then shared by every model built without an explicit loss; build it
    # per call instead.
    if loss is None:
        loss = torch.nn.BCEWithLogitsLoss()

    registry = get_model_registry()
    if model_name not in registry:
        raise KeyError(
            f"Unknown model_name {model_name!r}; available models: {sorted(registry)}"
        )
    backbone = registry[model_name](**model_config)

    # NOTE(review): the positional 5 and 0.01 are presumably the number of
    # classes and the learning rate -- confirm against ECGClassifier's
    # signature before parameterizing them.
    return ECGClassifier(
        backbone, 5, loss, 0.01, wd=0.01, total_optimizer_steps=total_optimizer_steps
    )

In [None]:
def train_model(model_lit, data_module, config):
    """Fit ``model_lit`` on ``data_module`` and return the trainer.

    Logging goes to Weights & Biases (gradients/parameters are watched and
    all checkpoints are logged). Training stops early when ``val_loss`` has
    not improved for 30 validation checks, and the learning rate and
    momentum are logged every step.

    Args:
        model_lit: Lightning module to train.
        data_module: Lightning datamodule providing the dataloaders.
        config: Object exposing ``ACCUMULATE_GRADIENT_STEPS`` and ``EPOCHS``
            as attributes.

    Returns:
        pl.Trainer: the trainer after ``fit`` has finished.
    """
    wandb_logger = WandbLogger(log_model="all")
    wandb_logger.watch(model_lit, log="all")

    dir_model = create_directory_with_timestamp("./models", "resnet1d_wang")

    # The timestamped directory used to be created and then ignored. Point an
    # explicit checkpoint callback (default settings otherwise) at it so local
    # checkpoints actually land there.
    checkpoint_callback = pl.callbacks.ModelCheckpoint(dirpath=dir_model)
    early_stop_callback = EarlyStopping(
        monitor="val_loss", min_delta=0.00, patience=30, verbose=False, mode="min"
    )
    learning_rate_monitor = LearningRateMonitor(logging_interval="step", log_momentum=True)

    trainer = pl.Trainer(
        accumulate_grad_batches=config.ACCUMULATE_GRADIENT_STEPS,
        log_every_n_steps=1,
        max_epochs=config.EPOCHS,
        logger=wandb_logger,
        callbacks=[checkpoint_callback, early_stop_callback, learning_rate_monitor],
    )

    trainer.fit(model_lit, datamodule=data_module)

    return trainer

In [None]:
def validate_model(trainer, data_module, metrics={}):
    res = trainer.predict(dataloaders=data_module.test_dataloader())

    y_hat, y = torch.concatenate([x[0] for x in res]), torch.concatenate([x[1] for x in res])

    y_hat = torch.nn.functional.sigmoid(y_hat)

    metrics  = {
        'multiclass_accuracy': MulticlassAccuracy(num_classes=y_hat.size(1), average='weighted')
    }

    target = torch.argmax(y, axis=-1)
    preds = torch.argmax(y_hat, axis=-1)



    return {
        k: v(preds, target) for k, v in metrics.items()
    }


In [None]:

def train_model_with_validation(config, project="ecg_benchmarking_lit", name="test_run", entity="phd-dk"):
    """Run one full experiment: wandb run, data, model, training and testing.

    Args:
        config: Mapping passed to ``wandb.init``; the keys read back from the
            run are ``BATCH_SIZE``, ``FILTER_FOR_SINGLELABEL``, ``EPOCHS``,
            ``ACCUMULATE_GRADIENT_STEPS``, ``model_config`` and ``model_name``.
        project: wandb project name.
        name: wandb run name.
        entity: wandb entity.

    Returns:
        None. Test metrics are reported by ``trainer.test``.
    """
    run = wandb.init(project=project, name=name, entity=entity, config=config)

    # Always close the wandb run, even when training fails; otherwise the
    # next call silently reuses the still-open run.
    try:
        BATCH_SIZE = run.config.BATCH_SIZE
        FILTER_FOR_SINGLELABEL = run.config.FILTER_FOR_SINGLELABEL

        # Single-label data uses cross entropy; otherwise an independent
        # per-class binary loss on the logits.
        loss = torch.nn.BCEWithLogitsLoss() if not FILTER_FOR_SINGLELABEL else torch.nn.CrossEntropyLoss()

        data_module = get_datamodule(run, FILTER_FOR_SINGLELABEL, BATCH_SIZE)
        print(len(data_module.val_dataset))

        # An optimizer step happens once per accumulated group of *batches*,
        # not once per sample. Counting samples overestimated the total by
        # roughly the batch size.
        batches_per_epoch = len(data_module.train_dataloader())
        accumulate = run.config.ACCUMULATE_GRADIENT_STEPS
        steps_per_epoch = -(-batches_per_epoch // accumulate)  # ceiling division
        total_optimizer_steps = steps_per_epoch * run.config.EPOCHS

        model_lit = get_model(total_optimizer_steps, run.config.model_config, run.config.model_name, loss)

        trainer = train_model(model_lit, data_module, run.config)

        trainer.test(model=trainer.model, dataloaders=data_module.test_dataloader())

        # TODO: optionally log extra metrics from validate_model(trainer, data_module)
        # under "test/<metric_name>" via run.log.
    finally:
        run.finish()


In [None]:

# Keyword arguments unpacked into the model factory selected by "model_name".
model_config = {
    "num_classes": 5,
    "input_channels": 12,
    "kernel_size": 5,
    "ps_head": 0.5,
    "lin_ftrs_head": [128],
}

# Run-level settings; handed to wandb.init and read back from run.config.
config = dict(
    BATCH_SIZE=128,
    EPOCHS=50,
    ACCUMULATE_GRADIENT_STEPS=1,
    FILTER_FOR_SINGLELABEL=False,
    model_config=model_config,
    model_name="resnet1d_wang",
)

In [None]:
# Launch one training + test run with the `config` mapping; project, run name
# and entity fall back to the function's defaults.
train_model_with_validation(config)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33markadiusz-czerwinski[0m ([33mphd-dk[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Downloading large artifact ptbxl_split:latest, 1800.85MB. 9 files... 
[34m[1mwandb[0m:   9 of 9 files downloaded.  
Done. 0:0:2.6
c:\Users\arekp\anaconda3\envs\phd\lib\site-packages\pytorch_lightning\loggers\wandb.py:389: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


14903


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.
c:\Users\arekp\anaconda3\envs\phd\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.

   | Name                   | Type               

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

c:\Users\arekp\anaconda3\envs\phd\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:01<00:00,  1.12it/s]



Epoch 39: 100%|██████████| 117/117 [00:11<00:00,  9.93it/s, v_num=vilb, train_loss_step=0.194, val_loss_step=0.478, val_loss_epoch=0.236, train_loss_epoch=0.191]

`Trainer.fit` stopped: `max_epochs=40` reached.


Epoch 39: 100%|██████████| 117/117 [00:12<00:00,  9.48it/s, v_num=vilb, train_loss_step=0.194, val_loss_step=0.478, val_loss_epoch=0.236, train_loss_epoch=0.191]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\Users\arekp\anaconda3\envs\phd\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing DataLoader 0:   3%|▎         | 3/118 [00:00<00:11, 10.18it/s]



Testing DataLoader 0: 100%|██████████| 118/118 [00:04<00:00, 24.52it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        Test metric               DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
    test_accuracy_epoch        0.6889894604682922
     test_auroc_epoch          0.8707942962646484
test_averageprecision_epoch    0.5935240983963013
    test_f1score_epoch         0.6889894604682922
test_matthewscorrcoef_epoch    0.5619130730628967
   test_precision_epoch        0.6889894604682922
     test_recall_epoch         0.6889894604682922
  test_specificity_epoch       0.9222473502159119
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
lr-AdamW,▁▁▂▂▃▄▅▆▆▇███████▇▇▇▇▆▆▅▅▅▄▄▄▃▃▂▂▂▂▁▁▁▁▁
lr-AdamW-momentum,██▇▇▆▅▄▃▃▂▁▁▁▁▁▁▁▂▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇▇█████
test_accuracy_epoch,▁
test_accuracy_step,▄▆▇▄▄▆▇▆▅▁▂▅▄▅▃▅▂▆█▃▄▆▆▂▄▄▅▂▅▆▅▅▇▇▇▅▆▇▄▄
test_auroc_epoch,▁
test_auroc_step,█▅▆█▇▆▆█▆▆▇▅█▃▆▆▇▅▆▃▆▃▅▆▄▅▃▅█▃▅▅▆▅▃█▇▅▅▁
test_averageprecision_epoch,▁
test_averageprecision_step,▅▆▆▄▆▅▆▅▇▂▃▄▅▄▂▂▄▃▇▃▃▆▅▃▃▆▆▁▆▄▅▅▇▄█▆▄▅▄█
test_f1score_epoch,▁

0,1
epoch,40.0
lr-AdamW,0.0
lr-AdamW-momentum,0.95
test_accuracy_epoch,0.68899
test_accuracy_step,0.63043
test_auroc_epoch,0.87079
test_auroc_step,0.4
test_averageprecision_epoch,0.59352
test_averageprecision_step,1.0
test_f1score_epoch,0.68899
