## CAS Deep Learning - Computer Vision with Deep Learning (Part 1)

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

# Image Classification - Project

## Learning Goals

- Learn how to model an image classification task
- Learn how to systematically implement and check each step from data prep to model selection and evaluation
- Learn how to incorporate libraries that provide boilerplate code, such as [torchvision](https://pytorch.org/vision/0.9/index.html) and [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/)

## Setup

We set up our environment and the paths for saving and loading data.

In [None]:
import os
from pathlib import Path

from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
import torch
from tqdm.notebook import tqdm

Mount your Google Drive to store data and results.

In [None]:
try:
    import google.colab

    IN_COLAB = True
except ImportError:
    IN_COLAB = False

print(f"In colab: {IN_COLAB}")

In [None]:
if IN_COLAB:
    from google.colab import drive

    drive.mount("/content/drive")

Modify the following paths if necessary.

In [None]:
if IN_COLAB:
    DATA_PATH = Path("/content/drive/MyDrive/cas-dl-module-compvis-part1")
else:
    DATA_PATH = Path("../data")

Install the packages you need that are not part of the base Colab environment.

In [None]:
if IN_COLAB:
    os.system("pip install torchshow torchinfo pytorch-lightning torchmetrics")

## Project Selection

Choose one of the following projects to work on.

Choose Cats vs Dogs if you want to benefit from the example code and follow along.

### Cats vs Dogs

**Goal**: Develop a model to classify images of cats and dogs. The dataset is designed to facilitate the identification of these animals from images.

**Approach**: Create a Convolutional Neural Network (CNN) to classify the images into two categories: cats and dogs. Experiment with various CNN architectures and techniques to determine the most effective method. Use data augmentation techniques to handle variations in pose, lighting, and background. Ensure the model generalizes well by using cross-validation and monitoring for overfitting.

**Dataset**: The dataset contains 25,000 images, with approximately 12,500 images per class (cats and dogs). Each image varies in size and resolution. The data is provided by Microsoft as part of their Kaggle competition.

[Source](https://www.microsoft.com/en-us/download/details.aspx?id=54765)

![Dog](dog.jpg)
![Cat](cat.jpg)


### Concrete Crack Detection

**Goal**: Develop a model to classify concrete images as having cracks or not. The dataset is designed to facilitate the identification of structural issues in concrete buildings.

**Approach**: Create a Convolutional Neural Network (CNN) to classify the images into negative (no crack) and positive (crack) categories. Experiment with various CNN architectures and techniques to determine the most effective method. Use image processing techniques to handle variations in surface finish and illumination. Ensure the model generalizes well by using cross-validation and monitoring for overfitting.

**Dataset**: The dataset contains 40,000 images, with 20,000 images per class (negative and positive). Each image is 227 x 227 pixels with RGB channels. The data is collected from 458 high-resolution images (4032 x 3024 pixels) from various METU Campus Buildings. No data augmentation such as random rotation or flipping is applied.

[Source](https://data.mendeley.com/datasets/5y9wdsg2zt/2)

![Crack](crack_example.jpg)
![No Crack](crack_negative.jpg)


### Scene Classification

**Goal**: Develop a model to classify natural scene images into one of six categories. The dataset aims to facilitate the recognition of various natural scenes from around the world.

**Approach**: Design a Convolutional Neural Network (CNN) to classify images into six categories: buildings, forest, glacier, mountain, sea, and street. Test different CNN architectures to find the best performing model. Apply data augmentation techniques to improve generalization. Separate the data into training, testing, and prediction sets to evaluate model performance effectively.

**Dataset**: The dataset contains around 25,000 images of size 150 x 150 pixels, distributed across six categories. The data is separated into training (14,000 images), testing (3,000 images), and prediction (7,000 images) sets.

[Source](https://www.kaggle.com/datasets/puneet6060/intel-image-classification?resource=download)


![Buildings](natural_scenes_buildings.jpg)
![Forest](natural_scenes_forest.jpg)
![Glacier](natural_scenes_glacier.jpg)



### Satellite Land Cover Classification

**Goal**: Develop a model to classify satellite images into different land cover types. The dataset contains images of 10 different classes and aims to support land use and land cover classification tasks.

**Approach**: Develop Convolutional Neural Networks (CNNs) to model the satellite image data. Experiment with different CNN architectures to identify the best performing model. Compare pre-trained models with those trained from scratch. Use data augmentation techniques to enhance model generalization. Given the relatively small dataset, pay attention to overfitting and compare models robustly.

**Dataset**: The dataset consists of 27,000 RGB images categorized into 10 classes. The dataset is available in two formats: one in RGB and another with 13 spectral bands. Use the RGB dataset for this project.

[Source](https://github.com/phelber/eurosat)

![Crop](sat_crop.jpg)
![Forest](sat_forest.jpg)
![Highway](sat_highway.jpg)


### Choose your own dataset!

Feel free to choose your own dataset.


# Overall Approach

Inspired by [A Recipe for Training Neural Networks by Andrej Karpathy](https://karpathy.github.io/2019/04/25/recipe/)

For your chosen dataset, do the following:


## 1) Data

- Download the data
- Inspect the data formats
- Build a `torch.utils.data.Dataset`
    - define training, validation and test sets
- Implement a `torch.utils.data.DataLoader`
- Inspect the data:
    - Look at samples
    - Inspect the label distribution

## 2) Baselines

- Implement a small CNN
- Learn an input-independent baseline (keep the labels but provide random noise as input)
- Overfit the CNN on one batch
- Inspect pre-processing
  
## 3) (Over)fit
- Build a large(er) architecture (pre-trained or self-implemented)
- Train a high-performing model with respect to training set
  
## 4) Regularize
- Is it beneficial to collect more data?
- Data Augmentation
- Early Stopping on Validation Set
- Weight Decay

## 5) Hyper-Parameter Tuning
- Define HPs and parameterise architecture
- Do a grid or random search over the HP space

## 6) Squeeze out the juice
- Ensembling
- Longer training
- Special techniques: AdamW optimizer, fancy data augmentation, label smoothing, stochastic depth

# 1) Data

- Download the data
- Inspect the data formats and file organization
- Remove corrupt data
- Build a `torch.utils.data.Dataset`
    - define training, validation and test sets
- Implement a `torch.utils.data.DataLoader`
- Inspect the data:
    - Look at samples
    - Inspect the label distribution
    

## Download the Data

The following code snippets help you get the data quickly, although the download itself may take a while.

In [None]:
def download_and_extract_zip(url: str, save_path: Path, extract_path: Path):
    import os
    import requests
    import zipfile
    
    # Make sure the directory exists
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    
    if not save_path.exists():
        # Download the file; raise on HTTP errors instead of silently saving an error page
        response = requests.get(url, stream=True)
        response.raise_for_status()
        with open(save_path, 'wb') as file:
            for chunk in response.iter_content(chunk_size=8192):
                _ = file.write(chunk)
    
        print(f"File downloaded and saved to {save_path}")
    
    if not extract_path.exists():
        # Unzip the file
        with zipfile.ZipFile(save_path, 'r') as zip_ref:
            zip_ref.extractall(extract_path)
        
        print(f"File extracted to {extract_path}")


def download_from_gdrive_and_extract_zip(file_id: str, save_path: Path, extract_path: Path):
    import os
    import zipfile

    import gdown

    # Make sure the directory exists
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    url = f"https://drive.google.com/uc?id={file_id}"
    if not save_path.exists():
        gdown.download(url, str(save_path), quiet=False)
        print(f"File downloaded and saved to {save_path}")
    
    if not extract_path.exists():
        # Unzip the file
        with zipfile.ZipFile(save_path, 'r') as zip_ref:
            zip_ref.extractall(extract_path)
        
        print(f"File extracted to {extract_path}")


def delete_bad_file(file_path: Path):
    # Check if file exists before trying to delete it
    if os.path.exists(file_path):
        os.remove(file_path)
        print(f"{file_path} has been deleted")
    else:
        print(f"{file_path} does not exist")
    

Cats vs Dogs

In [None]:
download_and_extract_zip(
    url = "https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_5340.zip",
    save_path = DATA_PATH.joinpath("cats_vs_dogs.zip"),
    extract_path = DATA_PATH.joinpath("cats_vs_dogs/")
)

bad_files = [
    DATA_PATH.joinpath("cats_vs_dogs") / "PetImages" / "Cat" / "666.jpg",
    DATA_PATH.joinpath("cats_vs_dogs") / "PetImages" / "Dog" / "11702.jpg"
]


for bad_file in bad_files:
    delete_bad_file(bad_file)

If you chose the EuroSAT data:

In [None]:
download_and_extract_zip(
    url = "https://zenodo.org/records/7711810/files/EuroSAT_RGB.zip?download=1",
    save_path = DATA_PATH.joinpath("EuroSAT_RGB.zip"),
    extract_path = DATA_PATH.joinpath("EuroSAT_RGB/"))

Concrete Data:

In [None]:
download_and_extract_zip(
    url = "https://prod-dcd-datasets-cache-zipfiles.s3.eu-west-1.amazonaws.com/5y9wdsg2zt-2.zip",
    save_path = DATA_PATH.joinpath("concrete.zip"),
    extract_path = DATA_PATH.joinpath("concrete/")
)

Scene classification

In [None]:
download_from_gdrive_and_extract_zip(
    file_id = "1Bx3R56VBONS-x91wCDU6KX3xqPoJoH9P",
    save_path = DATA_PATH / "scene_classification.zip",
    extract_path = DATA_PATH.joinpath("scene_classification/")
)

## Inspect Data Format & Organization, Build Dataset and Loader

We need to figure out how the data is organized, in particular how it is labelled, so that we can define a `torch.utils.data.Dataset` correctly.

First you should look at the data / folder structure of the downloaded data.
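A minimal sketch of such an inspection using only the Python standard library (the helper name `summarize_folder` is hypothetical): for each sub-directory of a root folder it counts the file extensions it contains. This quickly shows whether class labels are encoded as folder names and whether non-image files are mixed in.

```python
from collections import Counter
from pathlib import Path


def summarize_folder(root):
    """Map each immediate sub-directory of `root` to a count of its file extensions."""
    summary = {}
    for sub in sorted(p for p in Path(root).iterdir() if p.is_dir()):
        # Count extensions case-insensitively, including files in nested folders
        suffixes = Counter(f.suffix.lower() for f in sub.rglob("*") if f.is_file())
        summary[sub.name] = dict(suffixes)
    return summary
```

For Cats vs Dogs you could call `summarize_folder(DATA_PATH / "cats_vs_dogs" / "PetImages")`; for the other datasets, point it at the folder that contains the class directories.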

Once you have figured out how the data is organized, we can build a `Dataset`. A `Dataset` gives indexed access to the samples and returns tuples of images and labels.

**We create the training, validation and test datasets right away.**

Adapt the following code if necessary:

In [None]:
from typing import Callable, List, Tuple

from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, Subset


def load_image_paths_and_labels(image_dir: str) -> Tuple[List[str], List[str]]:
    """
    Load image paths and corresponding labels.

    Args:
        image_dir: Directory with all the images.

    Returns:
        A tuple of (image paths, labels).
    """
    image_paths = []
    labels = []
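    # Each sub-directory is one class; skip stray files such as .DS_Store
    classes = sorted(
        d for d in os.listdir(image_dir) if os.path.isdir(os.path.join(image_dir, d))
    )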
    image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".gif"}

    for label in classes:
        class_dir = os.path.join(image_dir, label)
        for img_name in os.listdir(class_dir):
            img_path = os.path.join(class_dir, img_name)
            if any(img_path.lower().endswith(ext) for ext in image_extensions):
                image_paths.append(img_path)
                labels.append(label)

    return image_paths, labels


def create_train_test_split(
        image_paths: List[str],
        labels: List[str], 
        test_size: float=0.2,
        random_state: int | None = None) -> Tuple[List[str], List[str], List[str], List[str]]:
    """
    Create stratified train and test splits.

    Args:
        image_paths: List of image paths.
        labels: List of labels.
        test_size: The proportion of the dataset to include in the test split.
        random_state: Controls the shuffling applied to the data before applying the split.

    Returns:
        train_image_paths, test_image_paths, train_labels, test_labels
    """
    train_image_paths, test_image_paths, train_labels, test_labels = train_test_split(
        image_paths, labels, stratify=labels, test_size=test_size, random_state=random_state)
    
    return train_image_paths, test_image_paths, train_labels, test_labels



class ImageDataset(Dataset):
    
    def __init__(self, image_paths: List[str], labels: List[str], transform: Callable | None = None, classes: List[str] | None = None):
        """
        Args:
            image_paths: List of image paths.
            labels: List of labels.
            transform: Optional transform to be applied on a sample.
            classes: List of class names.
        """
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        self.classes = classes if classes is not None else sorted(set(labels))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx: int):
        """
        Args:
            idx: Index
        
        Returns:
            tuple: (image, label) where label is the image classification.
        """
        try:
            image_path = self.image_paths[idx]
            image = Image.open(image_path).convert("RGB")
            label = self.labels[idx]
            label_num = self.classes.index(label)

            if self.transform:
                image = self.transform(image)
            return image, label_num
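        except (OSError, ValueError) as e:
            # PIL raises OSError for unreadable files. Returning None would crash the
            # DataLoader's default collate function, so re-raise with context instead.
            raise RuntimeError(f"Error loading image at index {idx}: {e}") from e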


image_root_path = DATA_PATH.joinpath("cats_vs_dogs/PetImages")
image_paths, labels = load_image_paths_and_labels(image_root_path)

# Create Train, Validation and Test Splits
train_image_paths, test_image_paths, train_labels, test_labels = create_train_test_split(image_paths, labels, test_size=0.2, random_state=123)
train_image_paths, validation_image_paths, train_labels, validation_labels = create_train_test_split(train_image_paths, train_labels, test_size=0.1, random_state=123)

# Specify transformations
train_transform = None  
test_transform = None 
validation_transform = None

ds_train = ImageDataset(train_image_paths, train_labels, transform=train_transform)
ds_validation = ImageDataset(validation_image_paths, validation_labels, transform=validation_transform)
ds_test = ImageDataset(test_image_paths, test_labels, transform=test_transform)
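The two calls to `create_train_test_split` above first hold out 20% of all images as the test set and then 10% of the *remaining* images as the validation set. A small worked calculation of the resulting fractions (the helper `split_fractions` is hypothetical):

```python
def split_fractions(test_size, val_size):
    """Fractions of the full dataset in train / validation / test when `val_size`
    is applied to what remains after the test split."""
    test = test_size
    val = val_size * (1 - test_size)
    train = 1 - test - val
    return train, val, test


train_frac, val_frac, test_frac = split_fractions(0.2, 0.1)
print(f"train={train_frac:.2f}, validation={val_frac:.2f}, test={test_frac:.2f}")
# train=0.72, validation=0.08, test=0.20
```

`train_test_split` rounds the split sizes to whole images, so the actual counts can differ slightly from these fractions.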

**Question**: What is the role of: `label_num = self.classes.index(label)`?

**Question**: Why do we (often) need a training, a validation and a test set?

YOUR ANSWER HERE

Now we take a look at an example from the `Dataset` to test it.

In [None]:
import torchshow as ts
image, label = ds_train[0]
ts.show(image)

For model training we need to batch examples. That's why we need to define a `torch.utils.data.DataLoader`.

We also need to convert the images to Tensors. We can use the `transform` parameter to specify transformations using `torchvision.transforms`.
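`transforms.ToTensor()` converts a PIL image (or an 8-bit `H x W x C` NumPy array) into a float tensor of shape `C x H x W` with values scaled from [0, 255] to [0.0, 1.0]. The NumPy function below only illustrates this conversion for an 8-bit RGB image; it is not torchvision's implementation. The `DataLoader` then stacks such tensors into batches of shape `(batch_size, C, H, W)`, which requires all images in a batch to have the same size, hence the `Resize((64, 64))`.

```python
import numpy as np


def to_tensor_like(image_hwc_uint8):
    """Illustration of ToTensor for 8-bit RGB: HWC uint8 [0, 255] -> CHW float32 [0, 1]."""
    return image_hwc_uint8.transpose(2, 0, 1).astype(np.float32) / 255.0


dummy_image = np.zeros((64, 48, 3), dtype=np.uint8)
dummy_image[..., 0] = 255  # pure red: only the first channel is set
chw = to_tensor_like(dummy_image)
print(chw.shape)  # (3, 64, 48)
```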

In [None]:
from torch.utils.data import DataLoader
from torchvision import transforms

torch.manual_seed(123)

# Define a simple transformation
train_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])

# Create the dataset and dataloader
ds_train = ImageDataset(train_image_paths, train_labels, transform=train_transform)
dataloader_train = DataLoader(ds_train, batch_size=16, shuffle=True)

images, batch_labels = next(iter(dataloader_train))

ts.show(images)
batch_labels


**Question**: What does `shuffle=True` achieve? Why is it recommended?

**Question**: Why do we use `torch.manual_seed(123)`?

YOUR ANSWER HERE

## Inspect Samples

Now you can use the `Dataset` or `DataLoader` objects to inspect the dataset. Look for the following:

- what is the class distribution? See [numpy.unique](https://numpy.org/doc/stable/reference/generated/numpy.unique.html).
- how difficult is the problem?
- are there any obvious issues with the data?
- are the labels accurate?
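A minimal sketch for the class distribution using `numpy.unique` with `return_counts=True` (the helper `class_distribution` is hypothetical):

```python
import numpy as np


def class_distribution(labels):
    """Map each class label to (absolute count, relative frequency)."""
    classes, counts = np.unique(labels, return_counts=True)
    total = counts.sum()
    return {str(c): (int(n), float(n / total)) for c, n in zip(classes, counts)}


print(class_distribution(["Dog", "Cat", "Dog", "Dog"]))
# {'Cat': (1, 0.25), 'Dog': (3, 0.75)}
```

Apply it to `train_labels`, `validation_labels` and `test_labels`; thanks to the stratified split, the relative frequencies should be nearly identical across the three splits.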

In [None]:
# YOUR CODE HERE
raise NotImplementedError()

# 2) Baselines

- Implement a small CNN
- Fix the random seed
- Learn an input-independent baseline (keep the labels but provide random noise as input)
- Overfit the CNN on one batch
- Inspect pre-processing

## Implement a small CNN

For example you could implement the following architecture.

- Input Shape: (3, 32, 32)
- Convolution: 16 Filters, Kernel-Size 5x5
- Pooling: Stride 2, Kernel-Size 2
- Convolution: 32 Filters, Kernel-Size 5x5
- Global Average Pooling
- FC: 2 neurons (number of classes)

Use `ReLU` activation after each convolution.

Define a class which inherits from `torch.nn.Module`.
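Before writing `forward`, it helps to work out the feature-map sizes. For a convolution or pooling layer with kernel size $k$, stride $s$ and padding $p$, the output size along each spatial axis is $\lfloor (n + 2p - k)/s \rfloor + 1$. The sketch below applies this to the suggested architecture (no padding, dilation 1) and also counts the parameters (weights plus biases). The helper names are hypothetical.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0):
    """Output size of a Conv2d/MaxPool2d layer along one spatial axis (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def conv_params(in_channels, out_channels, kernel_size):
    """Number of weights plus biases of a Conv2d layer with a square kernel."""
    return in_channels * out_channels * kernel_size ** 2 + out_channels


size = 32
size = conv_output_size(size, 5)            # conv1 (5x5): 32 -> 28
size = conv_output_size(size, 2, stride=2)  # max pool (2x2, stride 2): 28 -> 14
size = conv_output_size(size, 5)            # conv2 (5x5): 14 -> 10
print(size)  # 10, global average pooling then reduces 10x10 to 1x1

fc_params = 32 * 2 + 2  # nn.Linear(32, 2): weights plus biases
total = conv_params(3, 16, 5) + conv_params(16, 32, 5) + fc_params
print(total)  # 14114
```

For the 64x64 images produced by the `DataLoader` above the sizes are 60, 30 and 26. Global average pooling reduces any spatial size to 1x1, which is why the final `nn.Linear(32, 2)` works for both input sizes; the total should match the parameter count reported by `torchinfo.summary`.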


In [None]:
import torchinfo
import torch.nn as nn
import torch.nn.functional as F



torch.manual_seed(123)

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, (5, 5))
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(32, 2)

    def forward(self, x):
        # YOUR CODE HERE
        raise NotImplementedError()

net = Net()

print(net)
print(torchinfo.summary(net, input_size=(1, 3, 32, 32)))
    


## Define a training Loop

We use PyTorch Lightning, which greatly simplifies implementing boilerplate code such as training loops.

Tutorial here: https://lightning.ai/pages/community/tutorial/step-by-step-walk-through-of-pytorch-lightning/

We also use metrics from [torchmetrics](https://lightning.ai/docs/torchmetrics/stable/) to easily calculate and log accuracy. Adapt `task` if necessary: for more than two classes use, e.g., `torchmetrics.Accuracy(task="multiclass", num_classes=...)`.

In [None]:
import pytorch_lightning as pl
import torchmetrics

class Classifier(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.loss_fn = nn.CrossEntropyLoss()
        self.accuracy = torchmetrics.Accuracy(task="binary")

    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        preds = torch.argmax(logits, dim=1)
        
        # Update accuracy metric
        acc = self.accuracy(preds, y)
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_acc', acc, prog_bar=True)
        
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)
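The `Classifier` above passes the raw network outputs (logits) and integer class indices to `nn.CrossEntropyLoss`, which applies a log-softmax internally and averages over the batch by default. The network should therefore not end with a softmax. A NumPy sketch of the same computation, for illustration only (the helper name is hypothetical):

```python
import numpy as np


def cross_entropy_from_logits(logits, targets):
    """Mean cross-entropy of raw logits (N, C) against integer targets (N,)."""
    logits = np.asarray(logits, dtype=np.float64)
    # Numerically stable log-softmax: subtract the row maximum before exponentiating
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the correct class for every sample
    return float(-log_probs[np.arange(len(targets)), targets].mean())


toy_loss = cross_entropy_from_logits([[2.0, 0.0], [0.0, 2.0]], np.array([0, 0]))
print(round(toy_loss, 4))  # 1.1269
```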


Change the following parameters according to your hardware (e.g. `accelerator="gpu"`). As you can see, this makes switching hardware very easy.

We only want to perform a functional check here, so we train the model for just 10 steps.

In [None]:
trainer = pl.Trainer(
    devices=1,
    accelerator="cpu",
    precision="32",
    max_steps=10,
    enable_checkpointing=False,
    logger=False,
    default_root_dir=DATA_PATH.joinpath("lightning_logs")
)

net = Net()
model = Classifier(net)
trainer.fit(model, train_dataloaders=dataloader_train)

Now we train the model for longer to get a sense of its performance. Increase `max_steps` in the following code accordingly:

In [None]:
trainer = pl.Trainer(
    devices=1,
    accelerator="cpu",
    precision="32",
    max_steps=1,  # increase for a longer run
    enable_checkpointing=False,
    logger=False,
    default_root_dir=DATA_PATH.joinpath("lightning_logs")
)

net = Net()
model = Classifier(net)
trainer.fit(model, train_dataloaders=dataloader_train)

In [None]:
print(f"Metrics:  {trainer.logged_metrics}")

## Learn an Input-Independent Model


Modify the `Dataset` class such that it returns random images (e.g. white noise) instead of the real ones. The labels remain unchanged. Then train a model.

**Question**: What kind of loss do you expect if the model works?

YOUR ANSWER HERE
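To check your answer against the logged `train_loss`: a model that cannot use its input can at best predict the class frequencies, and the cross-entropy of that constant prediction equals the entropy of the label distribution. A small sketch (hypothetical helper `prior_cross_entropy`, using the natural logarithm as `nn.CrossEntropyLoss` does):

```python
import math
from collections import Counter


def prior_cross_entropy(labels):
    """Cross-entropy (in nats) of always predicting the empirical class frequencies."""
    counts = Counter(labels)
    n = sum(counts.values())
    return -sum((c / n) * math.log(c / n) for c in counts.values())


print(round(prior_cross_entropy(["Cat", "Dog"] * 500), 4))  # 0.6931, i.e. ln(2)
```

`prior_cross_entropy(train_labels)` gives the reference value for your own split.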

In [None]:
class ImageDatasetRandom(ImageDataset):
    
    def __getitem__(self, idx: int):
        """
        Args:
            idx: Index
        
        Returns:
            tuple: (image, label) where label is the image classification.
        """
        try:
            image_path = self.image_paths[idx]
            # Only the size is needed; PIL reads it from the file header without decoding
            with Image.open(image_path) as img:
                width, height = img.size
            random_array = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
            # Convert to PIL so that torchvision transforms such as Resize accept it
            random_image = Image.fromarray(random_array)

            label = self.labels[idx]
            label_num = self.classes.index(label)

            if self.transform:
                random_image = self.transform(random_image)
            return random_image, label_num
        except (OSError, ValueError) as e:
            raise RuntimeError(f"Error loading image at index {idx}: {e}") from e

# Create the dataset and dataloader
ds_train_random = ImageDatasetRandom(train_image_paths, train_labels, transform=train_transform)
dataloader_random = DataLoader(ds_train_random, batch_size=64, shuffle=True)

Verify your work!

In [None]:
image_random, label = ds_train_random[0]
ts.show(image_random)

Now train your model.

In [None]:
trainer = pl.Trainer(
    devices=1,
    accelerator="cpu",
    precision="32",
    max_steps=100,
    enable_checkpointing=False,
    logger=False,
    default_root_dir=DATA_PATH.joinpath("lightning_logs")
)

net = Net()
model = Classifier(net)
trainer.fit(model, train_dataloaders=dataloader_random)

In [None]:
print(f"Metrics:  {trainer.logged_metrics}")

## Overfit on One Batch of Data

**Question**: What do you expect?

YOUR ANSWER HERE

In [None]:
trainer = pl.Trainer(
    devices=1,
    accelerator="cpu",
    precision="32",
    max_steps=100,
    enable_checkpointing=False,
    logger=False,
    default_root_dir=DATA_PATH.joinpath("lightning_logs"),
    limit_train_batches=1  # int: 1 batch per epoch (the float 1.0 would mean 100% of batches)
)

net = Net()
model = Classifier(net)

ds_train = ImageDataset(train_image_paths, train_labels, transform=train_transform)
# shuffle=False so that every epoch sees the same single batch
dataloader_train = DataLoader(ds_train, batch_size=64, shuffle=False)
trainer.fit(model, train_dataloaders=dataloader_train)

In [None]:
print(f"Metrics:  {trainer.logged_metrics}")
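The same sanity check can be written as a plain PyTorch loop, without Lightning. In this sketch a small MLP on random tensors stands in for the CNN and the real data (an assumption to keep it fast and self-contained); the point is only that a correctly wired model and optimizer drive the loss on a single fixed batch close to zero. If your CNN cannot do this, look for bugs before scaling up.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# One fixed batch: 16 random "images" with random binary labels
x = torch.randn(16, 3, 32, 32)
y = torch.randint(0, 2, (16,))

# Small MLP standing in for the CNN (assumption, keeps the sketch fast)
toy_model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(3 * 32 * 32, 64),
    nn.ReLU(),
    nn.Linear(64, 2),
)
toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

for step in range(200):
    toy_optimizer.zero_grad()
    loss = loss_fn(toy_model(x), y)
    loss.backward()
    toy_optimizer.step()
    if step % 50 == 0:
        print(f"step {step:3d}: loss {loss.item():.4f}")

print(f"final loss on the single batch: {loss.item():.4f}")
```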

# 3) (Over)Fit

In this step we try to drive the training loss as low as possible.

You can do the following:
- implement your own model
- use a pre-defined model
- use a pre-trained model

## Pre-Trained Model

In the following we will use a pre-trained model and adapt it to our dataset (transfer-learning).

### Load Model

Here we use a pre-trained model. Read the documentation here: [https://pytorch.org/vision/0.8/models.html](https://pytorch.org/vision/0.8/models.html).

**It is important to read how the data is pre-processed for a given pre-trained model. This should be consistent with how you pre-process the data.**


In [None]:
# YOUR CODE HERE
raise NotImplementedError()

Now we adapt the output layer to match our dataset.

In [None]:
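# Replace the ImageNet head (1000 classes) with one for our 2 classes.
# Assumes a ResNet-style model whose classifier is `net.fc` (512 features for ResNet-18/34).
net.fc = nn.Linear(net.fc.in_features, 2)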

We can now train the model.

In [None]:
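from torchvision import transforms as T

trainer = pl.Trainer(
    devices=1,
    accelerator="cpu",
    precision="32",
    max_steps=100,
    enable_checkpointing=False,
    logger=False,
    default_root_dir=DATA_PATH.joinpath("lightning_logs")
)

model = Classifier(net)

# Pre-process like the ImageNet pre-training data (check your model's documentation)
pretrained_transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

ds_train = ImageDataset(train_image_paths, train_labels, transform=pretrained_transform)
dataloader_train = DataLoader(ds_train, batch_size=64, shuffle=True)
trainer.fit(model, train_dataloaders=dataloader_train)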

In [None]:
print(f"Metrics:  {trainer.logged_metrics}")

# 4) Regularization

Regularization is a process to deliberately limit a model's capacity in order to reduce overfitting and to improve generalization.

There are different techniques:

- Weight Decay
- Data Augmentation
- Early Stopping on Validation Set


## Weight Decay

Weight decay reduces model complexity by penalizing the magnitude of the weights. In its simplest form it shrinks the weights towards 0 at each gradient descent step. In `torch.optim.Adam`, the `weight_decay` argument adds this L2 penalty to the gradient, whereas `torch.optim.AdamW` applies the decay directly to the weights (decoupled weight decay).

Read the following documentation and add weight decay to your model: [torch.optim.Adam](https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam)

Weight decay is implemented in the optimizer.

Make it configurable.
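For plain SGD, adding the L2 penalty $\frac{\lambda}{2}\lVert w \rVert^2$ to the loss is the same as shrinking the weights by the factor $(1 - \eta\lambda)$ before the gradient step. The sketch below isolates this effect with a zero-gradient loss, using `torch.optim.SGD` for clarity (with `Adam`, the penalty is rescaled by the adaptive step size, so the shrinkage is not exactly this factor):

```python
import torch

w = torch.nn.Parameter(torch.tensor([1.0, -2.0]))
demo_optimizer = torch.optim.SGD([w], lr=0.1, weight_decay=0.5)

# A loss whose gradient is zero isolates the effect of weight decay
loss = (0.0 * w).sum()
loss.backward()
demo_optimizer.step()

# Expected: w * (1 - lr * weight_decay) = w * 0.95
print(w.detach())
```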

In [None]:
import pytorch_lightning as pl
import torchmetrics

class Classifier(pl.LightningModule):
    def __init__(self, model, weight_decay=0):
        super().__init__()
        self.model = model
        self.loss_fn = nn.CrossEntropyLoss()
        self.accuracy = torchmetrics.Accuracy(task="binary")
        self.weight_decay = weight_decay

    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        preds = torch.argmax(logits, dim=1)
        
        # Update accuracy metric
        acc = self.accuracy(preds, y)
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_acc', acc, prog_bar=True)
        
        return loss
    

    def configure_optimizers(self):
        # YOUR CODE HERE
        raise NotImplementedError()


## Data Augmentation

Data augmentation is the process of applying random transformations to the input data before it is processed by the model. This increases the robustness of the model and improves its generalization capabilities.

In [None]:
import torchvision.transforms as transforms

transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(128),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
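    'val': transforms.Compose([
        # Resize the shorter side, then crop so that every image is 128x128 and can be batched
        transforms.Resize(128),
        transforms.CenterCrop(128),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),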
}

# Example usage with ImageDataset class
ds_train = ImageDataset(train_image_paths, train_labels, transform=transforms['train'])
ds_val = ImageDataset(validation_image_paths, validation_labels, transform=transforms['val'])

dataloader_train = DataLoader(ds_train, batch_size=64, shuffle=True)
dataloader_val = DataLoader(ds_val, batch_size=64, shuffle=False)

# Create model instance
model = Classifier(net)

# Create trainer
trainer = pl.Trainer(
    devices=1,
    accelerator="cpu",
    precision="32",
    max_steps=100,
    enable_checkpointing=False,
    logger=False,
    default_root_dir=DATA_PATH.joinpath("lightning_logs")
)

# Train the model
trainer.fit(model, train_dataloaders=dataloader_train, val_dataloaders=dataloader_val)


## Early Stopping

Early stopping monitors the training process on a separate validation set to determine the optimal point regarding when to stop training (when validation loss / metric is at the best level).
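A simplified, framework-free sketch of the patience logic (hypothetical helper `early_stopping_index`; an illustration, not Lightning's implementation): training stops once the monitored metric has failed to improve by more than `min_delta` for `patience` consecutive validation checks.

```python
def early_stopping_index(scores, patience=3, mode="max", min_delta=0.0):
    """Return the index of the validation check at which training would stop,
    or None if the patience is never exhausted."""
    best = None
    wait = 0
    for i, score in enumerate(scores):
        if mode == "max":
            improved = best is None or score > best + min_delta
        else:
            improved = best is None or score < best - min_delta
        if improved:
            best, wait = score, 0  # new best value resets the counter
        else:
            wait += 1
            if wait >= patience:
                return i
    return None


print(early_stopping_index([0.50, 0.60, 0.65, 0.64, 0.63, 0.62], patience=3))  # 5
```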

PyTorch Lightning provides such functionality out of the box: [pytorch_lightning.callbacks.early_stopping.EarlyStopping](https://lightning.ai/docs/pytorch/stable/common/early_stopping.html)

**Make sure to let the model run enough steps such that early stopping is actually stopping the training!**

Implement a metric for early stopping to monitor. It must be calculated on the validation set; the callback below monitors `validation_acc`.


Inspect the `Trainer` class and set more appropriate values (e.g. for `val_check_interval` and `max_steps`).

In [None]:
from pytorch_lightning.callbacks.early_stopping import EarlyStopping


class Classifier(pl.LightningModule):
    def __init__(self, model, weight_decay=0):
        super().__init__()
        self.model = model
        self.loss_fn = nn.CrossEntropyLoss()
        self.train_accuracy = torchmetrics.Accuracy(task="binary")
        self.validation_accuracy = torchmetrics.Accuracy(task="binary")
        self.test_accuracy = torchmetrics.Accuracy(task="binary")
        self.weight_decay = weight_decay

    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        preds = torch.argmax(logits, dim=1)
        
        # Update accuracy metric
        acc = self.train_accuracy(preds, y)
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_acc', acc, prog_bar=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        # YOUR CODE HERE
        raise NotImplementedError()

    def test_step(self, batch, batch_idx):
        # YOUR CODE HERE
        raise NotImplementedError()
        
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001, weight_decay=self.weight_decay)

# Define early stopping callback
early_stopping = EarlyStopping(
    monitor='validation_acc', 
    min_delta=0.00,
    patience=3, 
    mode='max',
    verbose=True
)

# Create model instance
model = Classifier(net, weight_decay=1e-4)

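# Prepare data loaders: early stopping needs a validation set
ds_train = ImageDataset(train_image_paths, train_labels, transform=transforms['train'])
ds_val = ImageDataset(validation_image_paths, validation_labels, transform=transforms['val'])
dataloader_train = DataLoader(ds_train, batch_size=64, shuffle=True)
dataloader_val = DataLoader(ds_val, batch_size=64, shuffle=False)

# Create trainer
trainer = pl.Trainer(
    devices=1,
    accelerator="cpu",
    precision="32",
    max_steps=10,
    val_check_interval=5,
    enable_checkpointing=False,
    logger=False,
    callbacks=[early_stopping],  # Add the early stopping callback here
    default_root_dir=DATA_PATH.joinpath("lightning_logs")
)

# Train the model; validation_step runs on dataloader_val every `val_check_interval` steps
trainer.fit(model, train_dataloaders=dataloader_train, val_dataloaders=dataloader_val)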

# 5) Hyper-Parameter Optimization

To optimize hyper parameters we need to consider the following:
- parameterize the training process (architecture and pre-processing)
- experiment tracking software
- evaluation procedures (such as cross-validation for smaller datasets)
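For smaller datasets, a single validation split can give noisy comparisons; k-fold cross-validation trains one model per fold and averages the validation metric. A sketch of the index bookkeeping with scikit-learn's `StratifiedKFold`, using stand-in labels (in the notebook you would use `train_labels` and wrap the indices with `torch.utils.data.Subset` or build one `ImageDataset` per fold):

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

demo_labels = np.array(["Cat"] * 50 + ["Dog"] * 50)  # stand-in for train_labels
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=123)

# split() only needs the number of samples from X, so a placeholder array suffices
folds = list(skf.split(np.zeros(len(demo_labels)), demo_labels))
for fold, (train_idx, val_idx) in enumerate(folds):
    n_cat = int((demo_labels[val_idx] == "Cat").sum())
    print(f"fold {fold}: train={len(train_idx)}, val={len(val_idx)}, cats in val={n_cat}")
```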


**Hyper-Parameter Tuning can be time consuming!**

Ideally one uses special libraries such as [RayTune](https://docs.ray.io/en/latest/tune/index.html).

You can implement a hyper-opt loop if you like. You could test different `weight_decay` values or different model architectures.
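A minimal sketch of how configurations for grid and random search can be generated with the standard library. The search space is hypothetical: `Classifier` only accepts `weight_decay`, so a learning rate would first have to be made configurable in the same way. Random search over discrete lists is used here for simplicity; in practice continuous ranges (e.g. log-uniform for the learning rate) are common.

```python
import itertools
import random

search_space = {
    "weight_decay": [0.0, 1e-4, 1e-3],
    "lr": [1e-3, 3e-4],  # hypothetical: requires an `lr` argument in Classifier
}


def grid_search_configs(space):
    """All combinations of the listed values (Cartesian product)."""
    keys = list(space)
    return [dict(zip(keys, values)) for values in itertools.product(*space.values())]


def random_search_configs(space, n_trials, seed=123):
    """`n_trials` configurations, each value drawn independently at random."""
    rng = random.Random(seed)
    return [{k: rng.choice(v) for k, v in space.items()} for _ in range(n_trials)]


for config in grid_search_configs(search_space):
    print(config)
```

For each configuration, train a model, record its validation metric, and pick the best one; keep the test set out of this loop.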

In [None]:
# YOUR CODE HERE
raise NotImplementedError()

# 6) Squeeze out the Juice!

You can try the following techniques to get even further:

- advanced data augmentation. For example: https://pytorch.org/vision/main/auto_examples/transforms/plot_cutmix_mixup.html#sphx-glr-auto-examples-transforms-plot-cutmix-mixup-py
- model ensembling. Train multiple models and combine their predictions.
- advanced techniques: AdamW Optimizer, Stochastic Depth Regularization
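A common way to ensemble classifiers is to average the predicted class probabilities of several models and take the argmax. A NumPy sketch with made-up probabilities (for the models in this notebook you would obtain them with `torch.softmax(logits, dim=1)`):

```python
import numpy as np


def ensemble_predict(prob_list):
    """Average (N, C) class probabilities from several models and take the argmax."""
    mean_probs = np.mean(np.stack(prob_list), axis=0)
    return mean_probs.argmax(axis=1), mean_probs


probs_a = np.array([[0.9, 0.1], [0.4, 0.6], [0.3, 0.7]])
probs_b = np.array([[0.6, 0.4], [0.7, 0.3], [0.2, 0.8]])
preds, mean_probs = ensemble_predict([probs_a, probs_b])
print(preds)  # [0 0 1]
```

The second sample shows the effect: the models disagree, and the averaged probabilities (0.55 vs. 0.45) decide in favour of the more confident one.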

In [None]:
# YOUR CODE HERE
raise NotImplementedError()

# Evaluate your model

We may want to evaluate our model in more detail. In particular, we want to know where the model works well and where it fails. This might give us additional insight into the data and its difficulties.

In [None]:
# Prepare data loaders
ds_test = ImageDataset(test_image_paths, test_labels, transform=transforms['val'])
dataloader_test = DataLoader(ds_test, batch_size=64, shuffle=False)

trainer.test(model, dataloaders=dataloader_test)
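For a detailed analysis, e.g. the confusion matrix, you need the true and predicted labels of the whole test set. A sketch of a helper that collects them (the name `collect_predictions` is hypothetical); it is demonstrated on a toy linear model and random data so that it runs on its own:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def collect_predictions(model, dataloader):
    """Return (true labels, predicted labels) as NumPy arrays for a whole dataloader."""
    model.eval()
    device = next(model.parameters()).device
    y_true, y_pred = [], []
    for x, y in dataloader:
        logits = model(x.to(device))
        y_pred.append(logits.argmax(dim=1).cpu())
        y_true.append(y.cpu())
    return torch.cat(y_true).numpy(), torch.cat(y_pred).numpy()


# Toy usage: a random linear "classifier" on random features
toy_model = torch.nn.Linear(4, 2)
toy_loader = DataLoader(TensorDataset(torch.randn(10, 4), torch.randint(0, 2, (10,))), batch_size=4)
toy_true, toy_pred = collect_predictions(toy_model, toy_loader)
print(toy_true.shape, toy_pred.shape)  # (10,) (10,)
```

In the notebook this would be `y_true, y_pred = collect_predictions(model, dataloader_test)`.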


### Confusion-Matrix

Plot a _confusion matrix_. Use

- [confusion_matrix](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html)
- [ConfusionMatrixDisplay](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.ConfusionMatrixDisplay.html#sklearn.metrics.ConfusionMatrixDisplay)
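A sketch of the two calls with toy labels; the class names are an assumption and must follow the index order used by the dataset (`ds_test.classes`). The rows of the matrix correspond to the true classes and the columns to the predicted classes:

```python
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

# Toy labels for illustration; use the true and predicted test-set labels in practice
toy_true = [0, 1, 1, 0]
toy_pred = [0, 1, 0, 0]
class_names = ["Cat", "Dog"]  # assumed order, must match ds_test.classes

cm = confusion_matrix(toy_true, toy_pred)  # rows: true class, columns: predicted class
print(cm)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
disp.plot(cmap="Blues")
plt.show()
```

Here one "Dog" image was predicted as "Cat", which appears in the bottom-left cell.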

In [None]:
# YOUR CODE HERE
raise NotImplementedError()

**Question:** Which classes are confused how?