## Imports and Configuration

In [1]:
import os
import datetime
import wandb

import torch

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.utilities.model_summary import ModelSummary

from config.load_configuration import load_configuration
from data.datamodule import ECG_DataModule
from model.model import UNET_1D

#### Loading Configuration

The `load_configuration` function from the `config.load_configuration` module loads the settings for this run and returns them as `config`. Its output reports the PC name and the YAML file that was loaded.

In [2]:
config = load_configuration()

PC Name: DESKTOP-LUKAS
Loaded configuration from config/config_lukas.yaml
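
The later cells read a fixed set of keys from `config`, so a quick check right after loading can surface a missing key before training starts. The following is a minimal sketch under assumptions: `REQUIRED_KEYS` and `missing_keys` are hypothetical names that are not part of the project, the key list is collected from the cells of this notebook, and `config` is assumed to behave like a dictionary.

```python
# Hypothetical helper, not part of the project's config module.
# The key list is collected from the cells of this notebook.
REQUIRED_KEYS = (
    "seed",
    "path_to_data",
    "batch_size",
    "feature_list",
    "wandb_project_name",
    "wandb_experiment_name",
    "max_epochs",
    "learning_rate",
    "path_to_models",
)


def missing_keys(config, required=REQUIRED_KEYS):
    """Return the required keys that are absent from `config`, in order."""
    return [key for key in required if key not in config]


# Example with a deliberately incomplete, made-up configuration.
example = {"seed": 42, "batch_size": 32, "max_epochs": 10}
print(missing_keys(example))
```

In the notebook this could be called as `missing_keys(config)`; an empty list means every key used below is present.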


#### Logging in to Weights & Biases (wandb)

Before starting any experiment tracking, ensure you are logged in to your Weights & Biases (wandb) account. This enables automatic logging of metrics, model checkpoints, and experiment configurations. The following code logs you in to wandb:

```python
wandb.login()
```
If you are running this for the first time, you may be prompted to enter your API key.

In [3]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mlukas-pelz[0m ([33mHKA-EKG-Signalverarbeitung[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

#### Setting Seeds for Reproducibility

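To make runs comparable and reproducible, we call PyTorch Lightning's `seed_everything`, which seeds the Python, NumPy, and PyTorch random number generators with `config['seed']`. The cell also sets two environment variables; `CUDA_LAUNCH_BLOCKING=1` makes CUDA kernel launches synchronous, which simplifies debugging at the cost of speed.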

In [4]:
pl.seed_everything(config['seed'])
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"   # TensorFlow flag; only matters if TensorFlow gets imported, PyTorch ignores it
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"    # synchronous CUDA kernel launches: easier debugging, slower training

Seed set to 42
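
The effect of a fixed seed can be illustrated with the standard library alone. The sketch below uses only Python's `random` module; `draw` is a hypothetical helper introduced for this illustration and is not part of the project.

```python
import random


def draw(seed, n=3):
    """Seed Python's RNG, then draw n pseudo-random floats."""
    random.seed(seed)
    return [random.random() for _ in range(n)]


# The same seed reproduces the same sequence; a different seed gives another one.
print(draw(42) == draw(42))
print(draw(42) == draw(43))
```

Seeding alone does not make every GPU kernel deterministic. PyTorch offers `torch.use_deterministic_algorithms(True)` for that, usually at some cost in speed.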


#### Checking for GPU Devices

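This cell selects the GPU if CUDA is available and falls back to the CPU otherwise. It then prints the PyTorch version, the selected device and, on a GPU, the CUDA version, device name, and current memory usage. On CUDA it also sets the float32 matmul precision to `'high'`, which lets PyTorch trade a small amount of precision for faster matrix multiplications on supported GPUs.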

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("="*40)
print(f"Torch Version      : {torch.__version__}")
print(f"Selected Device    : {device}")
if device.type == 'cuda':
    print(f"CUDA Version       : {torch.version.cuda}")
    print(f"Device Name        : {torch.cuda.get_device_name(0)}")
    allocated = torch.cuda.memory_allocated(0) / 1024**3
    reserved = torch.cuda.memory_reserved(0) / 1024**3
    print(f"Memory Usage       : Allocated: {allocated:.2f} GB | Reserved: {reserved:.2f} GB")
    torch.set_float32_matmul_precision('high')
else:
    print("CUDA not available, using CPU.")
print("="*40)

Torch Version      : 2.7.0+cu128
Selected Device    : cuda
CUDA Version       : 12.8
Device Name        : NVIDIA GeForce RTX 5060 Ti
Memory Usage       : Allocated: 0.00 GB | Reserved: 0.00 GB


#### Initializing the Data Module

The `ECG_DataModule` is initialized using the data path, batch size, and feature list from the configuration. This prepares the data for training and validation.

In [6]:
dm = ECG_DataModule(data_dir=config['path_to_data'], batch_size=config['batch_size'], feature_list=config['feature_list'])

#### Creating the Model

Here we instantiate the `UNET_1D` model and print its summary with the `ModelSummary` utility from PyTorch Lightning. Setting `max_depth=-1` lists every nested submodule, so the summary shows each layer with its type, parameter count, and input and output sizes.

In [7]:
model = UNET_1D(in_channels=1, layer_n=512, out_channels=6, kernel_size=5)
print(ModelSummary(model, max_depth=-1))  
print(type(model).__name__)

    | Name                                            | Type                  | Params | Mode  | In sizes      | Out sizes    
------------------------------------------------------------------------------------------------------------------------------------
0   | criterion                                       | BCEWithLogitsLoss     | 0      | train | ?             | ?            
1   | train_jaccard                                   | BinaryJaccardIndex    | 0      | train | ?             | ?            
2   | val_jaccard                                     | BinaryJaccardIndex    | 0      | train | ?             | ?            
3   | test_jaccard                                    | BinaryJaccardIndex    | 0      | train | ?             | ?            
4   | multi_tolerance_metrics                         | MultiToleranceWrapper | 0      | train | ?             | ?            
5   | multi_tolerance_metrics.metrics                 | ModuleDict            | 0      | train | ?       
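
The summary lists `BCEWithLogitsLoss` as the criterion and `BinaryJaccardIndex` as the train, validation, and test metric. As a rough plain-Python sketch of what these two quantities compute for single values and short binary masks (the function names are hypothetical, and returning 0.0 for an empty union is a choice made in this sketch, not necessarily the library's behaviour):

```python
import math


def bce_with_logits(logit, target):
    """Binary cross-entropy on a raw logit, in a numerically stable form."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def jaccard_index(pred, target):
    """Intersection over union of two equal-length binary (0/1) sequences."""
    intersection = sum(1 for p, t in zip(pred, target) if p == 1 and t == 1)
    union = sum(1 for p, t in zip(pred, target) if p == 1 or t == 1)
    return intersection / union if union else 0.0


# A logit of 0 means probability 0.5, so the loss is ln(2) for either label.
print(round(bce_with_logits(0.0, 1.0), 4))        # 0.6931
# One shared positive out of three positions marked by either mask.
print(jaccard_index([1, 1, 0, 0], [1, 0, 1, 0]))  # 0.3333333333333333
```

Working on logits rather than probabilities avoids computing `log(sigmoid(x))` directly, which can lose precision for large negative `x`.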

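#### Training the Model

The run is logged to Weights & Biases through a `WandbLogger`, and the `Trainer` stops early once `val_loss` has not improved for 5 consecutive validation checks. After training, the final state is saved as a checkpoint under `config['path_to_models']`.

In [None]: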
current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")
print(f"Current Time       : {current_time}")

# Initialize wandb logger (https://wandb.ai/HKA-EKG-Signalverarbeitung)
wandb_logger = WandbLogger(
    project=config['wandb_project_name'],
    name=f"{config['wandb_experiment_name']}_{type(model).__name__}_{current_time}",
    config={
        'model': type(model).__name__,
        'dataset': type(dm).__name__,
        'batch_size': config['batch_size'],
        'max_epochs': config['max_epochs'],
        'learning_rate': config['learning_rate']
    }
)

# Initialize Trainer with wandb logger, using early stopping callback (https://lightning.ai/docs/pytorch/stable/common/early_stopping.html)
trainer = Trainer(
    max_epochs=config['max_epochs'],
    default_root_dir='model/checkpoint/',
    accelerator="auto",
    devices="auto",
    strategy="auto",
    callbacks=[EarlyStopping(monitor='val_loss', patience=5, mode='min')],
    logger=wandb_logger,
)

trainer.fit(model=model, datamodule=dm)

# Finish wandb
wandb.finish()

# Create a filename with date identifier
model_filename = f"{config['wandb_experiment_name']}_{type(model).__name__}_{current_time}.ckpt"

# Save a full Lightning checkpoint (model weights plus trainer and optimizer state) to the path specified in config
save_path = os.path.join(config['path_to_models'], model_filename)
trainer.save_checkpoint(save_path)
print(f"Model checkpoint saved as {save_path}")

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Current Time       : 2025-09-02_14-45


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                    | Type                  | Params | Mode  | In sizes      | Out sizes    
-----------------------------------------------------------------------------------------------------------
0  | criterion               | BCEWithLogitsLoss     | 0      | train | ?             | ?            
1  | train_jaccard           | BinaryJaccardIndex    | 0      | train | ?             | ?            
2  | val_jaccard             | BinaryJaccardIndex    | 0      | train | ?             | ?            
3  | test_jaccard            | BinaryJaccardIndex    | 0      | train | ?             | ?            
4  | multi_tolerance_metrics | MultiToleranceWrapper | 0      | train | ?             | ?            
5  | AvgPool1D1              | AvgPool1d             | 0      | train | [1, 64, 512]  | [1, 64, 256] 
6  | AvgPool1D2              | AvgPool1d             | 0      | train | [1, 128, 256] | [1, 128, 128]
7  | AvgPool1D3              | Av

Epoch 1:  76%|███████▌  | 2556/3359 [01:13<00:23, 34.77it/s, v_num=ihpm, train_loss_step=0.0936, val_loss_step=0.104, val_loss_epoch=0.122, train_loss_epoch=0.228]
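
The training cell uses `EarlyStopping(monitor='val_loss', patience=5, mode='min')`. The patience logic behind such a callback can be sketched in plain Python. This is an illustration of the idea, not Lightning's implementation: `early_stop_epoch` is a hypothetical function and the loss values are made up.

```python
def early_stop_epoch(val_losses, patience=5, min_delta=0.0):
    """Return the index of the epoch at which training would stop, or None.

    Follows the idea of monitoring a loss in 'min' mode: a value counts as an
    improvement only if it is lower than the best value so far by more than
    min_delta.
    """
    best = float("inf")
    waited = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss
            waited = 0
        else:
            waited += 1
            if waited >= patience:
                return epoch
    return None


# Made-up validation losses: the best value is reached at epoch 1,
# followed by five epochs without improvement.
losses = [1.00, 0.90, 0.95, 0.96, 0.97, 0.98, 0.99, 0.50]
print(early_stop_epoch(losses))  # 6
```

In this example the low value at the last position is never reached, which shows the trade-off of a small `patience`: it saves compute but can stop before a late improvement.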