In [1]:
# TODO: adjust later so the path works both on Colab and locally

data_path = "./data"

In [2]:
from pytorch_lightning import LightningDataModule
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import os

class ConfigurableDataModule(LightningDataModule):
    """Lightning data module wrapping an ``ImageFolder`` dataset with a swappable transform.

    The dataset found under ``data_dir`` is split 80/20 into a training and a
    validation subset. The split is driven by a seeded generator so that it is
    identical on every ``setup()`` call and across kernel restarts.

    Args:
        data_dir: Root directory passed to ``ImageFolder``.
        batch_size: Batch size used by both data loaders.
        transform: Transform applied to every image (training and validation).
        seed: Seed for the train/validation split.
    """

    def __init__(self, data_dir: str, batch_size: int, transform, seed: int = 42):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transform
        self.seed = seed
        # os.cpu_count() may return None; fall back to loading in the main process.
        self.num_workers = os.cpu_count() or 0

    def setup(self, stage=None):
        """Build the dataset and split it into training and validation subsets."""
        # Local import: this cell's import block does not import ``torch`` itself.
        from torch import Generator

        # Create the dataset as an ImageFolder instance
        full_dataset = ImageFolder(root=self.data_dir, transform=self.transform)
        # Sizes of the training / validation subsets (80 % / remainder)
        train_size = int(0.8 * len(full_dataset))
        val_size = len(full_dataset) - train_size
        # Random but reproducible split: without a seeded generator each call
        # would produce a different partition.
        split_generator = Generator().manual_seed(self.seed)
        self.train_dataset, self.val_dataset = random_split(
            full_dataset, [train_size, val_size], generator=split_generator
        )

    def train_dataloader(self):
        """Return the shuffled loader over the training subset."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        """Return the (unshuffled) loader over the validation subset."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

In [3]:
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
import torchmetrics
from torchmetrics import MeanMetric
import torchvision
import torchvision.transforms.functional as F
import torchvision.utils as vutils
from torchvision.transforms.functional import to_pil_image, to_tensor
import random
import time
import shutil
from pathlib import Path

class BaseWasteClassifier(pl.LightningModule):
    """Shared training/validation scaffolding for the waste image classifiers.

    Subclasses assign their network to ``self.model`` and implement ``forward``.
    This base class provides the training and validation steps (cross-entropy
    loss plus accuracy/precision/recall/F1), per-epoch average loss tracking,
    a loss-curve plot, and random dumps of annotated validation images.

    Everything is written below ``<results_dir>/<subclass name>/`` in the
    sub-directories ``models``, ``images`` and ``plots``.

    Args:
        num_classes: Number of target classes, passed to the metrics.
        results_dir: Root directory for all artefacts of this model.
    """

    # Display names used to annotate validation images, indexed by label id.
    # NOTE(review): presumably matches the class order ImageFolder derives
    # from the dataset folder names; confirm against the data set.
    CLASS_NAMES = ['Cardboard', 'Food Organics', 'Glass', 'Metal', 'Miscellaneous Trash', 'Paper', 'Plastic', 'Textile Trash', 'Vegetation']

    def __init__(self, num_classes: int, results_dir="results"):
        super().__init__()
        self.num_classes = num_classes
        # Artefacts are grouped by the concrete (sub)class name.
        model_class_name = self.__class__.__name__
        self.results_dir = Path(results_dir) / model_class_name
        self.models_dir = self.results_dir / "models"
        self.images_dir = self.results_dir / "images"
        self.plots_dir = self.results_dir / "plots"

        # Create directories if they don't exist
        for directory in (self.results_dir, self.models_dir, self.images_dir, self.plots_dir):
            directory.mkdir(parents=True, exist_ok=True)

        # Placeholder for the actual model; subclasses must assign it.
        self.model = None

        # Classification metrics (shared by the training and validation steps).
        self.accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes, average='macro')
        self.precision = torchmetrics.Precision(task='multiclass', num_classes=num_classes, average='weighted')
        self.recall = torchmetrics.Recall(task='multiclass', num_classes=num_classes, average='weighted')
        self.f1_score = torchmetrics.F1Score(task='multiclass', num_classes=num_classes, average='weighted')

        # Running means of the loss within the current epoch.
        self.train_loss_metric = torchmetrics.MeanMetric()
        self.val_loss_metric = torchmetrics.MeanMetric()
        # Per-epoch average losses, kept for plotting.
        self.avg_train_losses = []
        self.avg_val_losses = []

    def forward(self, x):
        """Compute class logits for a batch; must be implemented by subclasses."""
        raise NotImplementedError("This method should be overridden by subclasses.")

    def training_step(self, batch, batch_idx):
        """Run one training batch, log loss and metrics, and return the loss."""
        x, y = batch
        logits = self(x)
        loss = torch.nn.functional.cross_entropy(logits, y)
        # Convert logits to class indices once and reuse them for every metric.
        preds = torch.argmax(logits, dim=1)

        step_values = {
            'train_loss': loss,
            'train_acc': self.accuracy(preds, y),
            'train_precision': self.precision(preds, y),
            'train_recall': self.recall(preds, y),
            'train_f1': self.f1_score(preds, y),
        }
        for name, value in step_values.items():
            self.log(name, value, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        self.train_loss_metric(loss.detach())
        return loss

    def on_train_epoch_end(self):
        """Record the epoch's average training loss and refresh the loss plot."""
        avg_train_loss = self.train_loss_metric.compute()
        self.log('epoch_avg_train_loss', avg_train_loss, on_epoch=True, prog_bar=True, logger=True)
        self.avg_train_losses.append(avg_train_loss.item())
        self.train_loss_metric.reset()
        # Lightning runs the validation hooks before this hook, so the plot
        # written there does not yet contain this epoch's training loss.
        self.plot_and_save_loss_curves()

    def on_validation_start(self):
        """Remove the images saved by the previous validation run."""
        self.clear_images_directory()

    def validation_step(self, batch, batch_idx):
        """Run one validation batch, log loss, metrics and forward-pass time."""
        x, y = batch
        on_cuda = x.device.type == "cuda"

        # Time only the forward pass. CUDA kernels are launched asynchronously,
        # so synchronize around the call to measure the real execution time.
        if on_cuda:
            torch.cuda.synchronize(x.device)
        start_time = time.perf_counter()
        logits = self(x)
        if on_cuda:
            torch.cuda.synchronize(x.device)
        inference_time = time.perf_counter() - start_time

        loss = torch.nn.functional.cross_entropy(logits, y)
        # Convert logits to predicted class indices once and reuse them.
        predictions = torch.argmax(logits, dim=1)
        acc = self.accuracy(predictions, y)
        precision = self.precision(predictions, y)
        recall = self.recall(predictions, y)
        f1 = self.f1_score(predictions, y)

        self.log('val_inference_time', inference_time, prog_bar=True, logger=True)
        # 'val_loss' is logged exactly once per step (the original logged it twice).
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_acc', acc, prog_bar=True)
        self.log('val_precision', precision, prog_bar=True)
        self.log('val_recall', recall, prog_bar=True)
        self.log('val_f1', f1, prog_bar=True)

        if random.random() < 0.1:  # Save/log images for roughly 10 % of the batches
            self.log_images_with_labels(x, y, predictions, batch_idx)

        self.val_loss_metric(loss.detach())
        return loss

    def on_validation_epoch_end(self):
        """Record the epoch's average validation loss and refresh the loss plot."""
        avg_val_loss = self.val_loss_metric.compute()
        # Reset the running mean for the next validation run.
        self.val_loss_metric.reset()

        # The pre-training sanity check is not a real epoch: recording it would
        # shift the validation curve by one entry relative to the training curve.
        if self.trainer.sanity_checking:
            return

        self.log('epoch_avg_val_loss', avg_val_loss, on_epoch=True, prog_bar=True, logger=True)
        self.avg_val_losses.append(avg_val_loss.item())
        self.plot_and_save_loss_curves()

    def log_images_with_labels(self, images, labels, predictions, batch_idx):
        """Save a batch of images annotated with actual and predicted labels.

        Each image is written to ``images_dir/<actual class name>/`` and the
        whole batch is additionally sent as an image grid to the logger when
        the logger's experiment object offers ``add_image``.
        """
        annotated_images = []
        for i, (image, label, prediction) in enumerate(zip(images, labels, predictions)):
            # Undo the normalization and clamp to [0, 1]: to_pil_image scales
            # floats by 255 and casts to bytes, so out-of-range values would wrap.
            image = self.unnormalize(image).clamp(0, 1)
            actual_class_name = self.CLASS_NAMES[label.item()]
            predicted_class_name = self.CLASS_NAMES[prediction.item()]

            # Images are grouped by their actual class.
            image_dir = self.images_dir / actual_class_name
            image_dir.mkdir(parents=True, exist_ok=True)

            # Draw the annotation onto a PIL copy of the image.
            pil_img = F.to_pil_image(image)
            draw = ImageDraw.Draw(pil_img)
            annotation_text = f'Actual: {actual_class_name}, Predicted: {predicted_class_name}'
            draw.text((10, 10), annotation_text, fill="white")

            file_path = image_dir / f"epoch_{self.current_epoch}_batch_{batch_idx}_image_{i}.png"
            pil_img.save(file_path)

            # Keep the annotated version as a tensor for the grid.
            annotated_images.append(to_tensor(pil_img))

        if not annotated_images:
            return

        img_grid = torchvision.utils.make_grid(torch.stack(annotated_images), nrow=4)

        # Only log the grid when a logger with ``add_image`` (e.g. TensorBoard)
        # is attached; training without a logger must not crash here.
        experiment = getattr(self.logger, "experiment", None)
        if hasattr(experiment, "add_image"):
            experiment.add_image(f'Validation Images, Batch {batch_idx}', img_grid, self.current_epoch)

    def clear_images_directory(self):
        """Delete every class sub-directory (and its contents) below ``images_dir``."""
        if self.images_dir.exists() and self.images_dir.is_dir():
            for class_dir in self.images_dir.iterdir():
                if class_dir.is_dir():  # Leave stray files untouched
                    shutil.rmtree(class_dir)

    def plot_and_save_loss_curves(self):
        """Plot the per-epoch average losses and save them to ``plots_dir``."""
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.plot(range(len(self.avg_train_losses)), self.avg_train_losses, label='Average Training Loss')
        ax.plot(range(len(self.avg_val_losses)), self.avg_val_losses, label='Average Validation Loss')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_title('Average Training and Validation Loss Over Epochs')
        ax.legend()
        fig.tight_layout()

        plot_path = self.plots_dir / "average_loss_curves.png"
        fig.savefig(plot_path)
        # Close explicitly: this runs every epoch and open figures accumulate.
        plt.close(fig)

    def unnormalize(self, image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        """Revert a per-channel normalization of an image tensor.

        Iterates over the first dimension of ``image`` together with ``mean``
        and ``std`` and computes ``channel * std + mean``. The input is not
        modified; a new tensor is returned.
        """
        image = image.clone()  # The in-place ops below must not touch the caller's tensor
        for channel, m, s in zip(image, mean, std):
            channel.mul_(s).add_(m)
        return image

    def configure_optimizers(self):
        """Return the default optimizer (SGD with momentum) over ``self.model``.

        Subclasses can override this if needed.
        """
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
        return optimizer


In [4]:
import os
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger


def train_model(model, data_module, log_dir="tb_logs", max_epochs=50, logger_name="model_logs", callbacks=None):
    """Fit ``model`` on ``data_module`` with checkpointing and TensorBoard logging.

    Args:
        model: Module to train; its ``models_dir`` attribute is used as the
            checkpoint directory.
        data_module: Data module passed to ``Trainer.fit``.
        log_dir: Root directory for the TensorBoard logs.
        max_epochs: Upper bound on the number of training epochs.
        logger_name: Sub-directory name of this run inside ``log_dir``.
        callbacks: Optional extra trainer callbacks. The list is not modified.
    """
    # Keep the two checkpoints with the lowest validation loss.
    checkpoint_callback = ModelCheckpoint(
        dirpath=model.models_dir,
        filename='{epoch}-{val_loss:.2f}',
        save_top_k=2,  # Keep the two best checkpoints
        verbose=True,
        monitor='val_loss',  # Monitor validation loss (change to val_acc or any other metric as needed)
        mode='min',  # 'min' for loss (use 'max' for accuracy)
    )
    # Build a fresh list on every call: appending to a shared default list (or
    # to the caller's list) would pile up one more checkpoint callback each time
    # the function is re-run.
    all_callbacks = [*(callbacks or []), checkpoint_callback]

    # Start the training process
    logger = TensorBoardLogger(log_dir, name=logger_name)
    trainer = Trainer(max_epochs=max_epochs, logger=logger, callbacks=all_callbacks)
    trainer.fit(model, datamodule=data_module)

## SimpleCNN

In [5]:
import torch.nn as nn

class SimpleCNN(BaseWasteClassifier):
    """Small three-stage convolutional classifier trained with Adam.

    Args:
        num_classes: Number of output classes.
        lr: Learning rate used by the Adam optimizer.
    """

    def __init__(self, num_classes=9, lr=1e-3):
        # Initialise the base module first, then store attributes on it.
        super().__init__(num_classes)
        # Use the learning rate that was passed in (it was previously ignored
        # in favour of a hard-coded 1e-3).
        self.lr = lr
        self.model = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            # Three 2x poolings reduce a 224x224 input to 28x28 with 64 channels.
            nn.Linear(64 * 28 * 28, 512),
            nn.ReLU(),
            nn.Dropout(0.5),  # Reduce overfitting
            nn.Linear(512, num_classes)  # Final linear layer mapping to the classes
        )

    def forward(self, x):
        """Return the class logits for the input batch ``x``."""
        return self.model(x)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate ``self.lr``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

### Train SimpleCNN

In [6]:
# (Re)load the TensorBoard notebook extension and open it on the tb_logs/ directory.
%reload_ext tensorboard
%tensorboard --logdir=tb_logs/

In [7]:
from torchvision import transforms
# Define the transformation pipeline
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize every image to 224x224
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [8]:
from pytorch_lightning.callbacks import EarlyStopping

# Seed the Python, NumPy and PyTorch random number generators so that weight
# initialisation and batch shuffling are repeatable between runs.
pl.seed_everything(42, workers=True)

data_module = ConfigurableDataModule(data_dir=data_path, batch_size=32, transform=transform)

callbacks = [
    # Stop training once the monitored validation metric stops improving
    EarlyStopping(
        monitor="val_acc",  # Metric that is observed
        mode="max",  # Higher accuracy is better
        patience=1,  # Tolerate one validation check without improvement
    )
]
model = SimpleCNN(num_classes=9, lr=1e-3)
train_model(model, data_module, log_dir="tb_logs", max_epochs=50, logger_name="simple_CNN", callbacks=callbacks)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
2024-02-24 01:30:13.029517: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-24 01:30:13.029563: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-24 01:30:13.030208: E external/local_xla/xla/stream_executor/c

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 119: 'val_loss' reached 1.33447 (best 1.33447), saving model to '/home/jovyan/work/Sonstiges/Module/Machine_Learning/RealWaste/results/SimpleCNN/models/epoch=0-val_loss=1.33.ckpt' as top 2


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 1, global step 238: 'val_loss' reached 1.12750 (best 1.12750), saving model to '/home/jovyan/work/Sonstiges/Module/Machine_Learning/RealWaste/results/SimpleCNN/models/epoch=1-val_loss=1.13.ckpt' as top 2


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 2, global step 357: 'val_loss' reached 1.04771 (best 1.04771), saving model to '/home/jovyan/work/Sonstiges/Module/Machine_Learning/RealWaste/results/SimpleCNN/models/epoch=2-val_loss=1.05.ckpt' as top 2


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 3, global step 476: 'val_loss' reached 0.95307 (best 0.95307), saving model to '/home/jovyan/work/Sonstiges/Module/Machine_Learning/RealWaste/results/SimpleCNN/models/epoch=3-val_loss=0.95.ckpt' as top 2


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4, global step 595: 'val_loss' reached 0.90840 (best 0.90840), saving model to '/home/jovyan/work/Sonstiges/Module/Machine_Learning/RealWaste/results/SimpleCNN/models/epoch=4-val_loss=0.91.ckpt' as top 2


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5, global step 714: 'val_loss' reached 0.93166 (best 0.90840), saving model to '/home/jovyan/work/Sonstiges/Module/Machine_Learning/RealWaste/results/SimpleCNN/models/epoch=5-val_loss=0.93.ckpt' as top 2


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 6, global step 833: 'val_loss' was not in top 2


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 7, global step 952: 'val_loss' was not in top 2
