In [None]:
# Install the 'lightning' package that is used to train the model (pip runs in quiet mode)
%pip install lightning -qq

In [None]:
# Importiert die notwendigen Bibliotheken für Datenverarbeitung, Modellierung und Training
import os
import random
import numpy as np

import torch
from torchvision import transforms
from timm import create_model
import torch.nn.functional as F
from datasets import load_dataset
import lightning as L
from torchmetrics.utilities.data import to_categorical
import matplotlib.pyplot as plt
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

In [None]:
# Load the "chihuahua-muffin" dataset from the Hugging Face Hub and split its
# 'train' part into a training subset (75 %) and a test subset (25 %).
# https://huggingface.co/datasets/sasha/chihuahua-muffin
ds = load_dataset("sasha/chihuahua-muffin")
print(ds.keys())
# A fixed seed makes the shuffled split identical on every
# "Restart Kernel -> Run All", so the figures and losses further down
# stay comparable between runs.
split_dataset = ds['train'].train_test_split(test_size=0.25, seed=42)

In [None]:
# Default preprocessing pipeline applied to every image:
#   1. convert the image to a tensor,
#   2. resize it to 64 x 64,
#   3. normalize each of the three channels with a fixed mean / std
#      (the values match the commonly used ImageNet channel statistics).
DEFAULT_TRANSFORM = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((64, 64)),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
# Apply the preprocessing pipeline to every example in split_dataset.
# NOTE: this cell overwrites split_dataset with its transformed version, so it
# is meant to be run exactly once after the split cell.
def _transform_example(example):
    """Return the replacement 'image' value for a single dataset row."""
    return {'image': DEFAULT_TRANSFORM(example['image'])}


split_dataset = split_dataset.map(_transform_example, num_proc=4)

In [None]:
# Build the data loaders for training and validation.
# Rows are returned as torch tensors, restricted to the two columns the model needs.
split_dataset.set_format(type="torch", columns=["image", "label"])
batch_sizes = (12, 4)  # (training batch size, validation batch size)
train_batch_size, val_batch_size = batch_sizes

# Training batches are reshuffled; validation keeps the dataset order.
train_dataloader = torch.utils.data.DataLoader(
    dataset=split_dataset['train'],
    batch_size=train_batch_size,
    shuffle=True,
    num_workers=4,
)
val_dataloader = torch.utils.data.DataLoader(
    dataset=split_dataset['test'],
    batch_size=val_batch_size,
    num_workers=4,
)

In [None]:
# Show 16 randomly chosen images from the original (untransformed) dataset
fig = plt.figure(figsize=(10,10))
random.seed(123)  # fixed seed so the same images are drawn on every run
n_images = len(ds["train"])
random_integers = [random.randint(0, n_images - 1) for _ in range(16)]
for z, i in enumerate(random_integers):
    # Fetch the row once instead of once for the image and once for the label
    sample = ds['train'][i]
    ax = fig.add_subplot(4, 4, z + 1)
    ax.imshow(sample["image"])
    # NOTE(review): assumes label 0 means 'Muffin' — verify against
    # ds['train'].features['label'] before trusting the captions.
    label = 'Muffin' if sample["label"] == 0 else 'Chihuahua'
    ax.set_xlabel(f'{label}')
    ax.set_xticks([])
    ax.set_yticks([])
# Lay out the grid once, after all subplots exist, rather than once per image
fig.tight_layout()

In [None]:
# Show the first 16 training images from "split_dataset". DEFAULT_TRANSFORM has
# already been applied to them, so they are 64 x 64 and normalized. Normalized
# values fall outside [0, 1] and are clipped by imshow, which is why the colors
# look distorted compared to the originals.
fig = plt.figure(figsize=(10,10))
for i in range(16):
    # Fetch the row once instead of once for the image and once for the label
    sample = split_dataset['train'][i]
    ax = fig.add_subplot(4, 4, i + 1)
    # Move the channel axis to the end: (C, H, W) -> (H, W, C), as imshow expects
    ax.imshow(np.transpose(sample["image"], (1, 2, 0)))
    # NOTE(review): assumes label 0 means 'Muffin' — verify against the
    # dataset's label feature before trusting the captions.
    label = 'Muffin' if sample["label"] == 0 else 'Chihuahua'
    ax.set_xlabel(f'{label}')
    ax.set_xticks([])
    ax.set_yticks([])
# Lay out the grid once, after all subplots exist, rather than once per image
fig.tight_layout()

In [None]:
# Neural network module
class MuffinChihuahuaClassificator(L.LightningModule):
    """Binary image classifier wrapping a pretrained timm ResNet-18.

    The wrapped model is created with ``num_classes=1``, i.e. it emits one
    logit per image, and is trained with binary cross-entropy on that logit.

    Two kinds of scalars are logged:

    * ``train_loss`` / ``valid_loss`` -- the loss of every single batch,
    * ``loss/train`` / ``loss/valid`` -- the batch-size-weighted mean loss of
      a whole epoch, computed in the ``on_*_epoch_end`` hooks.
    """

    def __init__(self):
        super().__init__()
        self.model = create_model('resnet18', pretrained=True, num_classes=1)
        # Per-batch records of the running epoch, one dict per batch with the
        # keys "batch_size" and "loss". Filled by _shared_step and emptied
        # again by _shared_epoch_end.
        self.train_step_outputs = []
        self.valid_step_outputs = []

    def forward(self, x):
        """Return the raw (pre-sigmoid) output of the wrapped model for ``x``."""
        return self.model(x)

    @staticmethod
    def loss(preds, targets):
        """Compute the mean binary cross-entropy between logits and labels.

        Args:
            preds: Logits produced by the model, e.g. of shape ``(batch, 1)``.
            targets: Labels with the same number of elements as ``preds``;
                they are cast to the dtype and device of ``preds``.

        Returns:
            A scalar loss tensor.
        """
        # Flatten instead of squeeze(): squeeze() would turn a (1, 1) batch
        # into a 0-d tensor, which no longer matches a target of shape (1,)
        # and makes binary_cross_entropy_with_logits raise for a batch of one.
        logits = preds.reshape(-1)
        labels = targets.reshape(-1).type_as(logits)
        return F.binary_cross_entropy_with_logits(input=logits, target=labels)

    def training_step(self, batch):
        """Run one training batch and return its loss for backpropagation."""
        return self._shared_step(batch, prefix="train")

    def on_train_epoch_end(self):
        """Log the weighted mean training loss of the finished epoch."""
        # https://github.com/Lightning-AI/pytorch-lightning/releases/tag/2.0.0#bc-changes-pytorch
        self._shared_epoch_end(prefix="train")

    def validation_step(self, batch):
        """Run one validation batch and return its loss."""
        return self._shared_step(batch, prefix="valid")

    def on_validation_epoch_end(self):
        """Log the weighted mean validation loss of the finished epoch."""
        self._shared_epoch_end(prefix="valid")

    def _outputs_for(self, prefix):
        """Return the per-batch record list for ``prefix`` ("train" or "valid")."""
        # Plain attribute lookup; no need to eval() a source string for this.
        return getattr(self, f"{prefix}_step_outputs")

    def _shared_step(self, batch, prefix):
        """Forward one batch, log its loss and remember it for the epoch mean.

        Args:
            batch: Mapping with the keys ``"image"`` and ``"label"``.
            prefix: ``"train"`` or ``"valid"``; selects the log name and the
                list the batch record is appended to.

        Returns:
            The loss tensor of this batch (still attached to the graph).
        """
        x = batch["image"]
        y = batch["label"]

        # apply model to data
        y_hat = self(x)

        # loss() casts the labels to the dtype/device of the logits itself, so
        # they are passed through without a detour over a CPU FloatTensor.
        loss = self.loss(preds=y_hat, targets=y)
        self.log(
            name=f"{prefix}_loss",
            value=loss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=False
        )
        # Keep only a detached copy: the epoch statistics need the value, not
        # the autograd history of every batch of the epoch.
        self._outputs_for(prefix).append(
            {"batch_size": len(x), "loss": loss.detach()}
        )
        return loss

    def _shared_epoch_end(self, prefix):
        """Log the batch-size-weighted mean loss of the epoch as ``loss/<prefix>``.

        Does nothing when no batch was recorded for ``prefix``.
        """
        outputs = self._outputs_for(prefix)
        if not outputs:
            # Nothing recorded (e.g. an empty loader): stacking would fail.
            return

        # concat losses
        losses = torch.stack([record["loss"] for record in outputs])

        # concat batch sizes, matching the losses in dtype and device
        batch_sizes = torch.tensor(
            [record["batch_size"] for record in outputs],
            dtype=losses.dtype,
            device=losses.device,
        )

        # clear outputs so the next epoch starts from an empty list
        outputs.clear()

        # Weighted mean: a smaller final batch must not count as much as a
        # full one.
        avg_loss = torch.sum(losses * batch_sizes) / torch.sum(batch_sizes)

        self.log(
            name=f"loss/{prefix}",
            value=avg_loss,
            prog_bar=True,
            logger=True,
            on_step=False,
            on_epoch=True
        )

    def configure_optimizers(self):
        """Return an AdamW optimizer over the model parameters (lr=0.005)."""
        return torch.optim.AdamW(self.model.parameters(), lr=0.005)

In [None]:
# Data module
class ClassificationData(L.LightningDataModule):
    """Lightning data module that hands out the notebook's prebuilt data loaders."""

    def train_dataloader(self):
        """Return the module-level ``train_dataloader`` object."""
        return train_dataloader

    def val_dataloader(self):
        """Return the module-level ``val_dataloader`` object."""
        return val_dataloader

In [None]:
# Train the model
n_epochs = 5  # 5 epochs for a quick test; adjust to your needs / available resources


if __name__ == "__main__":
    # Seed Python, NumPy and torch (plus the DataLoader workers) so that the
    # random parts of a run -- such as batch shuffling -- repeat after a
    # kernel restart and training results stay comparable.
    L.seed_everything(42, workers=True)
    model = MuffinChihuahuaClassificator()
    data = ClassificationData()
    trainer = L.Trainer(max_epochs=n_epochs)
    trainer.fit(model, data)

In [None]:
# Directory in which the trainer's logger stored the TensorBoard logs of this run
# (by default Lightning writes them below 'lightning_logs')
log_dir = trainer.logger.log_dir

# Only try to read the logs if the directory actually exists
if not os.path.exists(log_dir):
    print(f"Log-Verzeichnis nicht gefunden unter {log_dir}")
else:
    # Read all event files of the run
    event_acc = EventAccumulator(log_dir)
    event_acc.Reload()

    # Tags of the epoch-level training and validation loss
    train_loss_tag = "loss/train"
    val_loss_tag = "loss/valid"

    # Look the available scalar tags up once and check that both are present
    scalar_tags = event_acc.Tags()['scalars']
    if train_loss_tag in scalar_tags and val_loss_tag in scalar_tags:
        train_loss_events = event_acc.Scalars(train_loss_tag)
        val_loss_events = event_acc.Scalars(val_loss_tag)

        train_values = [event.value for event in train_loss_events]
        val_values = [event.value for event in val_loss_events]

        # Lightning records these scalars against its global step counter, so
        # event.step is NOT an epoch number. Both tags are logged once per
        # epoch by the model's epoch-end hooks, hence the position of an event
        # is its epoch; this keeps the x-axis consistent with its label.
        train_epochs = list(range(1, len(train_values) + 1))
        val_epochs = list(range(1, len(val_values) + 1))

        # Plot both loss curves
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.plot(train_epochs, train_values, label="Training-loss")
        ax.plot(val_epochs, val_values, label="Validation-loss")
        ax.set_xlabel("Epoche")
        ax.set_ylabel("Verlust")
        ax.set_title("Trainings- und Validierungs-Loss-Kurven")
        ax.legend()
        ax.grid(True)
        plt.show()
    else:
        print(f"Loss-Tags '{train_loss_tag}' oder '{val_loss_tag}' nicht in den Logs gefunden.")