In [1]:
import os
import math
import torch
import wandb
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.utilities.model_summary import ModelSummary
from torchvision.transforms import v2

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, roc_curve, auc
import matplotlib.pyplot as plt
import numpy as np
from datetime import datetime
from models.model_cnn import KaninchenModel, KaninchenModelResidual, KaninchenModel_v3, KaninchenModel_v11, KaninchenModel_v12, KaninchenModel_v13, KaninchenModel_v14, KaninchenModel_v15, KaninchenModel_v16, KaninchenModel_v17, KaninchenModel_v18, KaninchenModel_v19, KaninchenModel_v20, KaninchenModel_v21
from data.datamodule import BinaryImageDataModule, ReducedSizeBinaryImageDataModule

import optuna
from training.hyperparameter_tuning import CnnOptunaTrainer

  from .autonotebook import tqdm as notebook_tqdm


### Loading Configuration

In the following steps, we will load the configuration settings using the `load_configuration` function. The configuration is stored in the `config` variable which will be used throughout the script.

In [2]:
# Load the project settings once; the cells below read them by key (e.g. config['seed'], config['image_size'])
# NOTE(review): judging by the printed output, the loader picks a machine-specific YAML file — confirm in config/load_configuration.py
from config.load_configuration import load_configuration
config = load_configuration()

PC Name: DESKTOP-LUKAS
Loaded configuration from config/config_lukas.yaml


### Logging in to Weights & Biases (wandb)

Before starting any experiment tracking, ensure you are logged in to your Weights & Biases (wandb) account. This enables automatic logging of metrics, model checkpoints, and experiment configurations. The following code logs you in to wandb:

```python
wandb.login()
```
If you are running this for the first time, you may be prompted to enter your API key.

In [3]:
# Authenticate this session with Weights & Biases; the WandbLogger itself is created later in the training cells
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mlukas-pelz[0m ([33mHKA-EKG-Signalverarbeitung[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

### Setting Seeds for Reproducibility

To ensure comparable and reproducible results, we set the random seed using the `seed_everything` function from PyTorch Lightning. This helps in achieving consistent behavior across multiple runs of the notebook.

In [4]:
# Seed the random number generators via PyTorch Lightning, using the seed stored in the configuration
pl.seed_everything(config['seed'])
# NOTE(review): TF_ENABLE_ONEDNN_OPTS looks like a TensorFlow setting and no TensorFlow import is visible in this notebook — confirm it is still needed
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"   # disable oneDNN optimizations for reproducibility

Seed set to 42


### Checking for GPU Devices

In this step, we check for the availability of GPU devices and print the device currently being used by PyTorch. This ensures that the computations are performed on the most efficient hardware available.

In [5]:
# Pick the GPU when CUDA is usable, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

print('Torch Version: ', torch.__version__)
print('Using device: ', device)
if device.type == 'cuda':
    bytes_per_gb = 1024**3
    print('Cuda Version: ', torch.version.cuda)
    print(torch.cuda.get_device_name(0))
    # Report the memory figures of GPU 0, rounded to one decimal place
    print('Memory Usage:')
    for label, query in (('Allocated:', torch.cuda.memory_allocated),
                         ('Cached:   ', torch.cuda.memory_reserved)):
        print(label, round(query(0)/bytes_per_gb, 1), 'GB')
    # Only changed on GPU runs; CPU runs keep PyTorch's default matmul precision
    torch.set_float32_matmul_precision('high')

Torch Version:  2.7.0+cu128
Using device:  cuda
Cuda Version:  12.8
NVIDIA GeForce RTX 5060 Ti
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


### Defining Transformations and Instantiating DataModule

In this step, we will define the necessary data transformations and initialize the `BinaryImageDataModule` with the provided configuration.

In [6]:
# Build the image preprocessing pipeline and the data module that the training cells consume
import os

import data.custom_transforms as custom_transforms

# Target edge length; images are resized to (size, size)
size = config['image_size']

preprocessing_steps = [
    custom_transforms.CenterCropSquare(),   # project-local transform — presumably crops the centred square; confirm in data/custom_transforms.py
    v2.Resize((size, size)),
    v2.ToTensor(),
    # v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]
transform = v2.Compose(preprocessing_steps)

# Alternative data module, kept for reference:
# dm = ReducedSizeBinaryImageDataModule(data_dir=config['path_to_split_aug_pics'], transform=transform, batch_size=config['batch_size'], num_workers=6) #, persistent_workers=True)
dm = BinaryImageDataModule(
    data_dir=config['path_to_split_aug_pics'],
    transform=transform,
    batch_size=config['batch_size'],
    num_workers=6,
    persistent_workers=True,
)
dm.setup()

# Optional visual sanity check of the loaders (disabled):
# train_loader = dm.train_dataloader()
# val_loader = dm.val_dataloader()
# test_loader = dm.test_dataloader()

# # Show a few images from the training set
# from torchvision.utils import make_grid
# def show_images(loader):
#     images, labels = next(iter(loader))
#     images = images[:16]  # Show only the first 16 images
#     labels = labels[:16]
#     grid = make_grid(images, nrow=4, padding=2)
#     plt.figure(figsize=(10, 10))
#     plt.imshow(grid.permute(1, 2, 0).numpy())
#     plt.title('Sample Images')
#     plt.axis('off')
#     plt.show()
# show_images(train_loader)

# print('Train dataset size:', len(dm.train_dataset))
# print('Validation dataset size:', len(dm.val_dataset))
# print('Test dataset size:', len(dm.test_dataset))




### Creating the Model

In this step, we will define the model architecture and print its summary using the `ModelSummary` utility from PyTorch Lightning. This provides an overview of the model's layers, parameters, and structure.

In [7]:
# Architecture used by the single-run training cell below; swap in another imported
# KaninchenModel* class here to train a different variant (e.g. KaninchenModel()).
model = KaninchenModel_v3()

# Print the layer/parameter table first, then the class name that later labels the run
print(ModelSummary(model, max_depth=-1))
print(model.__class__.__name__)

   | Name                | Type              | Params | Mode 
-------------------------------------------------------------------
0  | criterion           | BCEWithLogitsLoss | 0      | train
1  | sigmoid             | Sigmoid           | 0      | train
2  | model               | Sequential        | 1.1 M  | train
3  | model.0             | Conv2d            | 896    | train
4  | model.1             | BatchNorm2d       | 64     | train
5  | model.2             | ReLU              | 0      | train
6  | model.3             | MaxPool2d         | 0      | train
7  | model.4             | Conv2d            | 18.5 K | train
8  | model.5             | BatchNorm2d       | 128    | train
9  | model.6             | ReLU              | 0      | train
10 | model.7             | MaxPool2d         | 0      | train
11 | model.8             | Conv2d            | 73.9 K | train
12 | model.9             | BatchNorm2d       | 256    | train
13 | model.10            | ReLU              | 0      | train
14

### Training the Model and Logging with Weights & Biases

In this step, we initialize the Wandb logger and configure the experiment name to include a timestamp for better tracking. The `Trainer` from PyTorch Lightning is set up with the Wandb logger and an early stopping callback to monitor validation loss and prevent overfitting. After training, the Wandb run is finished, and the trained model checkpoint is saved with a unique filename containing the current date and time.

In [None]:
# Train the `model` created above once, track the run in Weights & Biases,
# and save a checkpoint whose filename carries the run's timestamp.

# The first cell bound `datetime` to the class (`from datetime import datetime`);
# importing the module here rebinds the name so `datetime.datetime.now()` resolves.
import datetime
now = datetime.datetime.now()
current_time = now.strftime("%Y-%m-%d_%H-%M-%S")

# Initialize wandb logger; the timestamp keeps run names unique across repeated executions
wandb_logger = WandbLogger(
    project=config['wandb_project_name'],
    name=f"{config['wandb_experiment_name']}_{type(model).__name__}_{current_time}",
    config={
        'model': type(model).__name__,
        'dataset': 'DwarfRabbits-binary',
        'batch_size': config['batch_size'],
        'max_epochs': config['max_epochs'],
        'learning_rate': config['learning_rate'],
        'image_size': config['image_size'],
    }
)

# Initialize Trainer with wandb logger, using early stopping callback (https://lightning.ai/docs/pytorch/stable/common/early_stopping.html)
trainer = Trainer(
    max_epochs=config['max_epochs'],
    default_root_dir='model/checkpoint/', #data_directory,
    accelerator="auto",
    devices="auto",
    strategy="auto",
    # Stop once val_loss has not improved for 5 consecutive validation checks
    callbacks=[EarlyStopping(monitor='val_loss', patience=5, mode='min')],
    logger=wandb_logger)

# Training of the model
trainer.fit(model=model, datamodule=dm)

# Finish wandb
wandb.finish()

# Create a filename with date identifier
model_filename = f"{config['wandb_experiment_name']}_{type(model).__name__}_{current_time}.ckpt"

# Save a Lightning checkpoint next to config['path_to_models'].
# os.path.dirname() drops the last path component of that setting before the filename is appended.
save_path = os.path.join(os.path.dirname(config['path_to_models']), model_filename)
trainer.save_checkpoint(save_path)
print(f"Model checkpoint saved as {save_path}")

# Remember the checkpoint so the evaluation cell at the end of the notebook can load it
config['path_to_model'] = save_path

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs



   | Name           | Type              | Params | Mode 
--------------------------------------------------------------
0  | criterion      | BCEWithLogitsLoss | 0      | train
1  | sigmoid        | Sigmoid           | 0      | train
2  | model          | Sequential        | 1.1 M  | train
3  | train_accuracy | BinaryAccuracy    | 0      | train
4  | val_accuracy   | BinaryAccuracy    | 0      | train
5  | val_precision  | BinaryPrecision   | 0      | train
6  | val_recall     | BinaryRecall      | 0      | train
7  | test_accuracy  | BinaryAccuracy    | 0      | train
8  | init_conv      | Sequential        | 9.6 K  | train
9  | layer1         | Sequential        | 1.2 M  | train
10 | layer2         | Sequential        | 1.6 M  | train
11 | layer3         | Sequential        | 4.8 M  | train
12 | layer4         | ResidualBlock     | 1.2 M  | train
13 | pool           | AdaptiveAvgPool2d | 0      | train
14 | flatten        | Flatten           | 0      | train
15 | fc1            | Li

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/opt/anaconda3/envs/VDKI-Projekt/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:420: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.


Processing file: kaninchen_0474_ok_v4
Processing file: katze_0148_nok_v3
Processing file: hase_0400_nok_v5
Processing file: hund_0123_nok_v3
Processing file: hamster_0068_nok_v0
Processing file: hund_0123_nok_v4
Processing file: kaninchen_0533_ok_v2
Processing file: ratte_0095_nok_v5
Processing file: hase_0356_nok_v5
Processing file: hund_0080_nok_v1
Processing file: hamster_0063_nok_v4
Processing file: kaninchen_0573_ok_v3
Processing file: igel_0132_nok_v3
Processing file: hase_0158_nok_v5
Processing file: kaninchen_1037_ok_v1
Processing file: kaninchen_0841_ok_v2
Processing file: schlange_0090_nok_v3
Processing file: hase_0182_nok_v0
Processing file: hase_0215_nok_v0
Processing file: kaninchen_0894_ok_v0
Processing file: hamster_0075_nok_v3
Processing file: kaninchen_0396_ok_v2
Processing file: hund_0135_nok_v1
Processing file: hase_0018_nok_v0
Processing file: hase_0259_nok_v1
Processing file: waschbaer_0036_nok_v3
Processing file: igel_0089_nok_v3
Processing file: schlange_0049_nok

/opt/anaconda3/envs/VDKI-Projekt/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:420: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Epoch 0:   0%|          | 0/211 [00:00<?, ?it/s] Processing file: kaninchen_0238_ok_v5
Processing file: siebenschlaefer_0004_nok_v0
Processing file: kaninchen_0812_ok_v4
Processing file: hase_0062_nok_v2
Processing file: kaninchen_0947_ok_v0
Processing file: katze_0055_nok_v1
Processing file: kaninchen_1015_ok_v1
Processing file: igel_0027_nok_v2
Processing file: siebenschlaefer_0044_nok_v2
Processing file: kaninchen_0016_ok_v0
Processing file: kaninchen_0209_ok_v4
Processing file: kaninchen_0251_ok_v3
Processing file: kaninchen_0598_ok_v5
Processing file: schlange_0034_nok_v0
Processing file: kaninchen_0726_ok_v3
Processing file: katze_0036_nok_v0
Processing file: hase_0269_nok_v3
Processing file: waschbaer_0003_nok_v0
Processing file: igel_0092_nok_v0
Processing file: schlange_0115_nok_v5
Processing file: hase_0268_nok_v0
Processing file: hase_0169_nok_v0
Processing file: hase_0145_nok_v1
Processing file: kaninchen_0335_ok_v0
Processing file: kaninchen_0877_ok_v4
Processing file: kan


Detected KeyboardInterrupt, attempting graceful shutdown ...


In [None]:
# Architectures to train back-to-back; each one gets its own wandb run and its own checkpoint
model_classes = [KaninchenModel_v3, KaninchenModel_v19]

import datetime
for model_class in model_classes:
    # Fresh, untrained instance for this iteration
    model = model_class()
    model_name = type(model).__name__
    print(ModelSummary(model, max_depth=-1))
    print(model_name)

    # One timestamp per model, shared by the wandb run name and the checkpoint filename
    current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    experiment_name = config['wandb_experiment_name']

    # Hyperparameters recorded alongside the run
    run_config = {
        'model': model_name,
        'dataset': 'DwarfRabbits-binary',
        'batch_size': config['batch_size'],
        'max_epochs': config['max_epochs'],
        'learning_rate': config['learning_rate'],
        'image_size': config['image_size'],
    }
    wandb_logger = WandbLogger(
        project=config['wandb_project_name'],
        name=f"{experiment_name}_{model_name}_img{config['image_size']}_{current_time}",
        config=run_config,
    )

    # Early stopping watches val_loss and gives up after 5 checks without improvement
    early_stopping = EarlyStopping(monitor='val_loss', patience=5, mode='min')
    trainer = Trainer(
        max_epochs=config['max_epochs'],
        default_root_dir='model/checkpoint/',
        accelerator="auto",
        devices="auto",
        strategy="auto",
        callbacks=[early_stopping],
        logger=wandb_logger,
    )

    trainer.fit(model=model, datamodule=dm)

    # Close the wandb run before the next model opens a new one
    wandb.finish()

    # Checkpoint goes into the directory obtained by dropping the last component of config['path_to_models']
    save_path = os.path.join(
        os.path.dirname(config['path_to_models']),
        f"{experiment_name}_{model_name}_{current_time}.ckpt",
    )
    trainer.save_checkpoint(save_path)
    print(f"Model checkpoint saved as {save_path}")

    # After the loop this holds the checkpoint of the last model trained
    config['path_to_model'] = save_path

    print(f"Completed training for {model_class.__name__}")

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


   | Name                | Type              | Params | Mode 
-------------------------------------------------------------------
0  | criterion           | BCEWithLogitsLoss | 0      | train
1  | sigmoid             | Sigmoid           | 0      | train
2  | model               | Sequential        | 1.1 M  | train
3  | model.0             | Conv2d            | 896    | train
4  | model.1             | BatchNorm2d       | 64     | train
5  | model.2             | ReLU              | 0      | train
6  | model.3             | MaxPool2d         | 0      | train
7  | model.4             | Conv2d            | 18.5 K | train
8  | model.5             | BatchNorm2d       | 128    | train
9  | model.6             | ReLU              | 0      | train
10 | model.7             | MaxPool2d         | 0      | train
11 | model.8             | Conv2d            | 73.9 K | train
12 | model.9             | BatchNorm2d       | 256    | train
13 | model.10            | ReLU              | 0      | train
14

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name           | Type              | Params | Mode 
--------------------------------------------------------------
0  | criterion      | BCEWithLogitsLoss | 0      | train
1  | sigmoid        | Sigmoid           | 0      | train
2  | model          | Sequential        | 1.1 M  | train
3  | train_accuracy | BinaryAccuracy    | 0      | train
4  | val_accuracy   | BinaryAccuracy    | 0      | train
5  | val_precision  | BinaryPrecision   | 0      | train
6  | val_recall     | BinaryRecall      | 0      | train
7  | test_accuracy  | BinaryAccuracy    | 0      | train
8  | init_conv      | Sequential        | 1.9 K  | train
9  | layer1         | Sequential        | 221 K  | train
10 | layer2         | Sequential        | 295 K  | train
11 | layer3         | Sequential        | 886 K  | train
12 | layer4         | ResidualBlock     | 1.2 M  | train
13 | flatten        | Flatten           | 0      | train
14 | fc1            | Linear           

Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:00<00:00,  3.71it/s]



                                                                           



Epoch 10: 100%|██████████| 301/301 [01:21<00:00,  3.71it/s, v_num=iu9x, val_loss=0.521]


0,1
Validation Data ROC AUC,▁▂▃▄▃▇▄▇▆▇█
epoch,▁▁▁▁▁▂▂▂▂▂▂▂▂▃▃▃▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇██
train_acc,▁▃▃▄▄▅▅▆▆▇█
train_loss,█▇▆██▆▆▆▇▇▄▅▅▆▆▄▄▅█▆▆▄▄▄▃▃▅▄▆▄▅▂▁▃▂▄▃▂▂▁
trainer/global_step,▁▁▂▂▂▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇████
val_acc,▃▃▃▁▃▆▄█▆██
val_loss,▆▅▅█▇▁▆▂▅▄▆
val_precision,▆▅▃▁▆▇▆█▆▇▆
val_recall,▁▂▅█▁▃▃▃▄▄▅

0,1
Validation Data ROC AUC,0.88861
epoch,10.0
train_acc,0.93738
train_loss,0.10829
trainer/global_step,3310.0
val_acc,0.79995
val_loss,0.52128
val_precision,0.78916
val_recall,0.82359


Model checkpoint saved as C:\Users\lukas\SynologyDrive_IMS/SS25_MSYS_KAER-AI-PoseAct/21_Test_Data/Models/CNN\binaryClassification_CNN_KaninchenModel_v21_2025-06-10_19-05-17.ckpt
Completed training for KaninchenModel_v21


### Hyperparameter Optimization with Optuna

This section performs automated hyperparameter tuning using Optuna to find the optimal model configuration. The optimization process searches for hyperparameters that minimize validation loss across multiple trials, helping to improve model performance beyond manual tuning.

In [None]:
# Hyperparameter optimization
# `datetime` must name the module here; the import rebinds it in case it still refers to the class imported in the first cell
import datetime
# Minute-resolution timestamp stored under config['sweep_id'] — presumably read by CnnOptunaTrainer to group the trials; TODO confirm
config['sweep_id'] = datetime.datetime.now().strftime("%Y%m%d_%H%M")

def objective(trial):
    """Optuna objective: run one training trial and return its score.

    Args:
        trial: The Optuna trial object handed in by ``study.optimize``.

    Returns:
        Whatever ``CnnOptunaTrainer.run_training`` returns for this trial
        (the study below treats it as a value to minimize).
    """
    # The class itself is passed, not an instance; swap in another model class to tune it instead
    model_cls = KaninchenModel
    optuna_trainer = CnnOptunaTrainer(
        model=model_cls,
        config=config,
        normalize_mean=None,    # e.g. [0.485, 0.456, 0.406]
        normalize_std=None,     # e.g. [0.229, 0.224, 0.225]
        dataset_name="DwarfRabbits-binary",
    )
    return optuna_trainer.run_training(trial)

# Set verbosity to WARNING before the study exists, so that the study-creation
# message is suppressed together with the per-trial INFO messages
optuna.logging.set_verbosity(optuna.logging.WARNING)

# Create an Optuna study
study = optuna.create_study(direction="minimize")  # because we minimize val_loss

# Start the hyperparameter optimization; the number of trials comes from the configuration
study.optimize(objective, n_trials=config['number_of_trials'])
# study.optimize(objective, n_trials=3)   # quick smoke test

# Best result
print("Best trial:")
print(study.best_trial.params)
print("Best value (val_loss):", study.best_value)

In [None]:
# Show Optuna's optimization-history figure for the finished study as the cell output
optuna.visualization.plot_optimization_history(study)

In [None]:
# Show Optuna's parameter-importance figure for the finished study as the cell output
optuna.visualization.plot_param_importances(study)

# Predict with the Model


In [None]:
# from PIL import Image
# import torch
# # Load the saved model weights from the path specified in config

# def predict_image(path, model):
#     transform = transforms.Compose([
#         transforms.Resize((150, 150)),
#         transforms.ToTensor(),
#         transforms.Normalize([0.5]*3, [0.5]*3)
#     ])

#     img = Image.open(path).convert('RGB')
#     img = transform(img).unsqueeze(0)  # Add batch dimension

#     model.eval()
#     with torch.no_grad():
#         pred = model(img)
#         result = "Dog" if pred.item() > 0.5 else "Cat"
#     print(f"Prediction: {result}")


### Loading and Evaluating the Trained Model

The trained model is loaded from the checkpoint specified in the configuration. If the checkpoint exists, the model weights are restored and the model is set to evaluation mode. PyTorch Lightning's `Trainer` is then used to evaluate the model on the test dataset, providing a streamlined way to assess model performance after training.

In [None]:
# Restore the checkpoint recorded in the configuration (the training cells store the
# most recent one under config['path_to_model']) and evaluate it on the test split.
model_path = config['path_to_model']
if model_path and os.path.exists(model_path):
    # NOTE(review): the training cells above save checkpoints of other classes
    # (e.g. KaninchenModel_v3); loading them through KaninchenModel may fail if the
    # architectures differ — confirm the class matches the checkpoint being loaded.
    model = KaninchenModel.load_from_checkpoint(model_path, map_location=device)
    print(f"Loaded model weights from {model_path}")
else:
    # Execution continues after this message: the `model` object currently held
    # in the kernel is the one that gets evaluated below.
    print("Model path not found or not specified in config.")

# Ensure model is in eval mode
model.eval()

# Pytorch Lightning's Trainer can be used to test the model.
# The test loader is taken from the data module that was set up earlier in the notebook.
trainer = Trainer()
trainer.test(model=model, dataloaders=dm.test_dataloader())