## Train Infection Classifier

In this lesson we will train a neural network that classifies cells as infected or non-infected. We will use a [ResNet](https://arxiv.org/abs/1512.03385) for this task. ResNets are among the most widely used architectures for image classification.

Here, we will choose the following three channels as input for the ResNet:
- The `marker` channel that indicates the presence of viral RNA in a cell. This is the most important signal for determining if a cell is infected or not.
- The `nucleus` channel, so that the position of the cell nucleus is also given to the network.
- The `cell-mask` channel, which corresponds to the binary mask for the current cell.

We will now go through the steps to create this input and then train the network.

In [None]:
# General imports.

import itertools
import json
import os
import sys
import time
from glob import glob

import imageio.v3 as imageio
import matplotlib.pyplot as plt
import napari
import numpy as np

sys.path.append("..")
import utils

In [None]:
# This function will download and unpack the data and do some further data preparation.
# It will only be executed if the data has not been downloaded yet.
data_dir = "../data"
if os.path.exists(data_dir):
    print("The data is downloaded already.")
else:
    utils.prepare_data(data_dir)

In [None]:
# The data has been downloaded and separated into folders for the train, validation and test split already.
# We first create a dictionary with the location of the three different split folders.
data_dirs = {
    "train": os.path.join(data_dir, "train"),
    "val": os.path.join(data_dir, "val"),
    "test": os.path.join(data_dir, "test")
}

# And check the content for one of the samples.
# The printed list should contain the images (marker, nuclei, serum), the labels (cells and nuclei) and a json file.
train_sample0 = os.path.join(data_dirs["train"], "gt_image_000")
print(os.listdir(train_sample0))

### 1. Inspect Training Data

We check the training data for one of the images. Our goal is to classify the cells into being infected / non-infected.
To this end we cut out small images containing only the image data around a given cell.
We then use three channels as input to the neural network for classification:
- The `marker` channel, which indicates the infection.
- The `nucleus` channel, to locate the signal w.r.t. the nucleus.
- The `cell-mask`, which corresponds to the binary mask for the given cell. We will create it from the cell segmentation. It is added to delineate the cell within the local image.

We will also remove the signal outside of the cell, so that the background and neighboring cells do not affect the classification result.
Note that we DO NOT use the `serum` channel as input to the model.

All the information needed to extract these images and labels is already present. We now check it for the first training sample.

In [None]:
# The file paths for the images of the first training sample.
marker_path = os.path.join(train_sample0, "gt_image_000_marker_image.tif")
nucleus_path = os.path.join(train_sample0, "gt_image_000_nucleus_image.tif")
cell_segmentation_path = os.path.join(train_sample0, "gt_image_000_cell_labels.tif")

In [None]:
# Load the image and segmentation data for the first sample.
marker = imageio.imread(marker_path)
nuclei = imageio.imread(nucleus_path)
cells = imageio.imread(cell_segmentation_path)

In [None]:
# Load the classification data from the json file.
# This file was created by the data preparation script and already contains the information needed
# for classifying the cells: its "cells" entry holds a list with the following three attributes for each cell:
# - cell_id: the pixel value for this cell in the cell segmentation.
# - infected_label: the classification label for this cell. This label has 4 possible values:
#  - 0: the cell has not been labeled
#  - 1: the cell is infected
#  - 2: the cell is not infected
#  - 3: the infection status of the cell is unclear
# -  bbox: the coordinates for the local window in the image that contains the cell
#    (bbox is short for bounding_box)
classification_label_path = os.path.join(train_sample0, "labels.json")
with open(classification_label_path, "r") as f:
    classification_label_data = json.load(f)

In [None]:
# Let's look at the content of the classification label data:
classification_label_data
# You should see a dictionary whose "cells" entry lists the values for each cell as described above.
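
To get a first impression of the class balance, the label entries can be tallied. The sketch below assumes the JSON structure described above (a `"cells"` list whose entries carry `cell_id`, `infected_label` and `bbox`); the helper `count_infection_labels` and the dictionary `example_label_data` are made up for illustration.

```python
from collections import Counter

# Meaning of the infected_label values, as described in the comment above.
LABEL_NAMES = {0: "not labeled", 1: "infected", 2: "not infected", 3: "unclear"}


def count_infection_labels(label_data):
    """Count how often each infected_label value occurs, skipping cells without a bbox."""
    counts = Counter(
        cell["infected_label"] for cell in label_data["cells"] if cell["bbox"] is not None
    )
    return {name: counts.get(label, 0) for label, name in LABEL_NAMES.items()}


# Hypothetical example data in the same format as labels.json.
example_label_data = {
    "cells": [
        {"cell_id": 1, "infected_label": 1, "bbox": [0, 0, 40, 50]},
        {"cell_id": 2, "infected_label": 2, "bbox": [10, 60, 55, 110]},
        {"cell_id": 3, "infected_label": 1, "bbox": [70, 5, 120, 48]},
        {"cell_id": 4, "infected_label": 3, "bbox": None},  # skipped: no bounding box
    ]
}
print(count_infection_labels(example_label_data))
```

In the notebook, `count_infection_labels(classification_label_data)` would give the counts for the first training sample.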

In [None]:
# We now want to visualize the classification data together with the image data
# to verify that we understand the labels and their meaning correctly.

# We will use napari for this visualization and so we have to process the
# classification data a bit to bring it in a suitable format for napari.
# For this we extract both the classification labels and the bounding boxes.
classification_labels = []
bounding_boxes = []
for cell_data in classification_label_data["cells"]:
    bbox = cell_data["bbox"]
    # We skip cells for which we don't have bounding box information.
    if bbox is None:
        continue
    classification_labels.append(cell_data["infected_label"])
    # We need to change the bounding box format to display them correctly in napari.
    bounding_boxes.append([bbox[:2], bbox[2:]])
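
The same bounding box is used in two ways in this lesson: as two corner points of a napari rectangle, and as a slice for cutting the cell out of the image. The sketch below assumes, as the indexing code in the next section implies, that `bbox` is stored as `[min_row, min_col, max_row, max_col]`; the two helper names are hypothetical.

```python
import numpy as np


def bbox_to_napari_rectangle(bbox):
    """Return the two opposite corners [[min_row, min_col], [max_row, max_col]]."""
    return [bbox[:2], bbox[2:]]


def bbox_to_slice(bbox):
    """Return a tuple of slices that cuts the box out of a 2D image."""
    return np.s_[bbox[0]:bbox[2], bbox[1]:bbox[3]]


image = np.arange(100).reshape(10, 10)
bbox = [2, 3, 5, 7]  # rows 2..4 and columns 3..6; the max values are exclusive when slicing
print(bbox_to_napari_rectangle(bbox))  # [[2, 3], [5, 7]]
print(image[bbox_to_slice(bbox)].shape)  # (3, 4)
```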

In [None]:
# We now visualize the image data and classification labels with napari:
viewer = napari.Viewer()
# Add the image data: the marker channel (red channel) and nucleus channel (blue channel).
viewer.add_image(marker, colormap="red", blending="additive")
viewer.add_image(nuclei, colormap="blue", blending="additive")
# Add another layer for visualizing the classification labels:
# we use a napari shapes layer for this, which can overlay windows on top of the image.
# The 'bounding_boxes' list we just extracted provides the coordinates for these windows,
# so each window shows the cutout for one of the cells.
# The outline color of each window depends on the classification label:
# red for infected cells, blue for non-infected cells and grey for cells with unclear infection status.
viewer.add_shapes(
    bounding_boxes, shape_type="rectangle", face_color="transparent", edge_width=2,
    properties={"label": classification_labels},
    edge_color="label", edge_color_cycle=["red", "blue", "grey"],
)

After running the code above, napari will open. After adjusting the contrast limits of the nucleus and marker channels, you should see that the cells with a red outline have a high intensity and a patterned expression in the marker channel, whereas the cells with a blue outline have a low signal, similar to the screenshot below.

![image.png](attachment:dfc7691a-cb63-49ef-91fb-6973c40b1a5a.png)

### 2. Prepare the Data for Training

We now have to prepare this training data so that it can be used in PyTorch to train a neural network for classifying the cells.
For this step we first cut out the small window around each cell as a separate image. We do this for every image and store the small cell images for the training, validation and test set separately.
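
For a single cell, the extraction boils down to three numpy operations: slicing the window, zeroing out pixels that do not belong to the cell, and stacking the channels. A minimal sketch on made-up toy arrays (all names and values here are hypothetical, and the bbox layout `[min_row, min_col, max_row, max_col]` is assumed):

```python
import numpy as np

# Toy data: a 6x6 "marker" image, a "nucleus" image and a label image
# in which cell 1 occupies the left part and cell 2 the right part.
marker = np.linspace(0.0, 1.0, 36, dtype="float32").reshape(6, 6)
nucleus = np.full((6, 6), 0.5, dtype="float32")
cells = np.zeros((6, 6), dtype="uint16")
cells[1:5, 0:3] = 1
cells[1:5, 3:6] = 2

# Bounding box of cell 1, chosen one column wider so that it overlaps cell 2.
bbox = [1, 0, 5, 4]
window = np.s_[bbox[0]:bbox[2], bbox[1]:bbox[3]]

# Copy the crops so that masking does not modify the full images.
marker_crop = marker[window].copy()
nucleus_crop = nucleus[window].copy()
mask = cells[window] == 1

# Zero out everything outside of cell 1 (including the pixels of cell 2).
marker_crop[~mask] = 0
nucleus_crop[~mask] = 0

# Stack into a (channels, height, width) array: marker, nucleus, binary mask.
cell_image = np.stack([marker_crop, nucleus_crop, mask.astype("float32")])
print(cell_image.shape)  # (3, 4, 4)
```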

In [None]:
# Here we define the function that cuts out the small images around the cells for one image.

# This is a helper function for normalizing an image to the range [0, 1].
# Data normalization is important when training neural networks for image data
# to make sure all the inputs are in the same data range.
def normalize(image):
    image = image.astype("float32")
    image -= image.min()
    image /= (image.max() + 1e-7)
    return image


# This is the main function for extracting the small images.
# We give it the path to the folder containing the data for a sample as input.
# Remember that this contains the images, segmentations and classification data
# in individual files for each sample. We have explored some of this data above.
def extract_images_and_labels(sample_folder):
    # First we load the classification data from the json file.
    classification_label_path = os.path.join(sample_folder, "labels.json")
    with open(classification_label_path, "r") as f:
        classification_data = json.load(f)

    # Then we create the file paths for the image and segmentation data we will load.
    sample_name = os.path.basename(sample_folder)  # We can derive the sample name from the folder name.
    marker_path = os.path.join(sample_folder, f"{sample_name}_marker_image.tif")
    nucleus_path = os.path.join(sample_folder, f"{sample_name}_nucleus_image.tif")
    cell_segmentation_path = os.path.join(sample_folder, f"{sample_name}_cell_labels.tif")

    # And load the two images (marker and nuclei) and the cell segmentation.
    # We normalize the marker and nucleus image so that their data range is in [0, 1].
    marker = normalize(imageio.imread(marker_path))
    nuclei = normalize(imageio.imread(nucleus_path))
    # Note that we must not normalize the segmentation!
    cells = imageio.imread(cell_segmentation_path)

    # Now we iterate over the classification data and cut out the small
    # image with marker, nucleus channel and binary mask from the window containing the cell.
    images, labels = [], []
    for cell_data in classification_data["cells"]:
        label, bbox = cell_data["infected_label"], cell_data["bbox"]
        # We only consider data which has either the classification label 1 (cell is infected)
        # or label 2 (cell is not infected). We skip cells with a different label.
        if label not in (1, 2):
            continue

        # Convert the bounding box to a format that can be used to index the image.
        bbox = np.s_[bbox[0]:bbox[2], bbox[1]:bbox[3]]

        # Extract the small image for the marker and nucleus.
        # Note: it is important to copy the data here. Slicing returns a view, so without the copy
        # the masking below would write zeros into the full image and corrupt the crops of neighboring cells.
        marker_im = marker[bbox].copy()
        nuc_im = nuclei[bbox].copy()
        
        # Extract the binary mask for the current cell. 
        # We create the mask by setting the pixels which have the id of our cell to 1
        # (which is done by the '==' below.)
        mask = cells[bbox] == cell_data["cell_id"]

        # Now we black out all data that is not in the cell mask. We can do this
        # by indexing the image with the inverse of the mask (`~mask`) and setting all
        # the values in the inverse mask to zero.
        # This removes the signal from the surroundings of the cell (background and
        # neighboring cells), so that only the pixels of the cell itself are kept.
        marker_im[~mask] = 0
        nuc_im[~mask] = 0

        # Now we combine the marker, nucleus and binary mask into an image with three channels
        # and then append the small image and the corresponding label to the list of images and labels.
        image = np.stack([marker_im, nuc_im, mask.astype("float32")])
        images.append(image)
        labels.append(label)

    # We check that we have the same number of small images and labels and then return them.
    assert len(images) == len(labels)
    return images, labels

In [None]:
# This function applies the data extraction function we just defined to all samples
# for a split (training, validation or test) and extracts all the corresponding small images
# and classification labels.
def prepare_split(split):
    # Get all the folders for the samples of this split.
    # We sort them so that the order of the samples is reproducible.
    split_folder = data_dirs[split]
    samples = sorted(glob(os.path.join(split_folder, "gt*")))
    # Iterate over all the samples and extract the images and labels from them.
    images, labels = [], []
    for sample in samples:
        sample_images, sample_labels = extract_images_and_labels(sample)
        images.extend(sample_images)
        labels.extend(sample_labels)
    return images, labels

In [None]:
# Now we apply the functions for the training, validation and test split
# and check how many samples we have for each split.

train_images, train_labels = prepare_split("train")
print("We have", len(train_images), "training samples.")

In [None]:
val_images, val_labels = prepare_split("val")
print("We have", len(val_images), "validation samples")

In [None]:
test_images, test_labels = prepare_split("test")
print("We have", len(test_images), "test samples")

In [None]:
# Finally, we visualize some of the extracted images with napari.
# For this we first randomly select a number of samples from the training data.
n_samples_for_visualization = 25
indices_for_visualization = np.random.choice(
    len(train_images), n_samples_for_visualization, replace=False
)  # This function randomly selects indices.

# Now we add those images to napari.
viewer = napari.Viewer()
for index in indices_for_visualization:
    image, label = train_images[index], train_labels[index]
    # Change the image name depending on its infection status.
    label_name = "infected" if label == 1 else "not-infected"
    
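    # We modify the image a bit to visualize it correctly in napari: move the channel axis
    # to the end (RGB layout), show the nucleus in blue instead of the mask and blank out green.
    # We copy the image first: transpose returns a view, and we must not modify the training data.
    image_for_visualization = image.transpose((1, 2, 0)).copy()
    image_for_visualization[..., -1] = image_for_visualization[..., 1]
    image_for_visualization[..., 1] *= 0.0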
    viewer.add_image(image_for_visualization, name=f"sample{index}-{label_name}")

# This enables a grid-view in napari, so that you will see all images side by side.
viewer.grid.enabled = True
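
Numpy operations such as `transpose` and basic slicing return views that share memory with the original array, so writing into them also changes the source data. This matters whenever an image is modified only for display. A small sketch with a made-up three-channel array:

```python
import numpy as np

# A hypothetical (channels, height, width) image whose channels hold 1, 2 and 3.
image = np.stack([np.full((2, 2), value, dtype="float32") for value in (1.0, 2.0, 3.0)])

# transpose returns a view: it shares memory with `image`.
view = image.transpose((1, 2, 0))
print(np.shares_memory(view, image))  # True

# .copy() creates an independent array, so edits do not leak back.
safe = image.transpose((1, 2, 0)).copy()
safe[..., 1] = 0.0
print(image[1].max())  # 2.0, the original is unchanged

# Writing into the view overwrites channel 1 of the original image.
view[..., 1] = 0.0
print(image[1].max())  # 0.0
```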

You should now see a grid of the small images we have extracted around the cells, as in the screenshot below.
These are the images that our neural network gets as input. The label of each image is part of its layer name, which is shown in the layer list on the left-hand side of the viewer (not visible in the screenshot). You should see both infected and not infected cells: the former have a spotted intensity pattern in the red channel, while the latter only have a low intensity in that channel.

![image.png](attachment:30cc7829-960f-4212-ad36-9c0ce70f4fbe.png)

In general, it is important to visualize the images you feed into the neural network, to catch issues that may occur during data processing.

### 3. Train the Infection Classifier

We now have the data in a format we can use to train a classification network with PyTorch.
For training with PyTorch we need to define a few more steps and functions that convert the data to
the PyTorch data formats, and then define the training procedure.

If you have not used PyTorch before you can find a quick introduction to how it works [here](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html).

In [None]:
# Import the PyTorch functionality we need.
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.models.resnet import resnet18

# Additional imports for evaluation and image transformations.
from sklearn import metrics
from skimage.transform import resize

In [None]:
# We first check if we have access to a GPU.
# The model training will be much faster if we can use a GPU.
if torch.cuda.is_available():
    print("GPU is available")
    device = torch.device("cuda")
else:
    print("GPU is NOT available. The training will be very slow!")
    device = torch.device("cpu")

In [None]:
# In order to use the image data in PyTorch we need to bring all small images to a common size.
# This will enable combining multiple images in one batch and is essential for efficient training (more details below).
# So as a first step we determine the average size of all images in our training and validation set.
shapes = np.stack([np.array(image.shape) for image in (train_images + val_images)])
mean_shape = np.mean(shapes, axis=0)
print("Mean image shape:", mean_shape)

You should see that the average image shape is roughly 52 x 52 pixels. We round this up to the next multiple of 16 and use 64 x 64 as the common shape.
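
Rounding up and rounding to the nearest multiple give different results here: 52 rounded up to a multiple of 16 is 64, while the nearest multiple of 16 is 48. A small sketch of both rules (the helper names are hypothetical):

```python
import math


def next_multiple(value, base=16):
    """Smallest multiple of `base` that is greater than or equal to `value`."""
    return int(math.ceil(value / base) * base)


def closest_multiple(value, base=16):
    """Multiple of `base` nearest to `value` (Python's round uses round-half-to-even on exact ties)."""
    return int(round(value / base) * base)


print(next_multiple(52))     # 64
print(closest_multiple(52))  # 48
```

Applied to the notebook, `next_multiple(mean_shape[1])` would give the target height under the rounding-up rule.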

In [None]:
# To provide the data for training we need to create a PyTorch Dataset.
# Datasets provide a single example (= small image + label) for training.
# They can also be used to process the data further. Here, we resize the images
# to the common shape within the dataset.

# You can find more information on datasets here:
# https://pytorch.org/tutorials/beginner/basics/data_tutorial.html

# To create a dataset for our task we create a class that inherits
# from the PyTorch dataset.
class CustomDataset(Dataset):
    # Here we define the data for creating the dataset:
    # The small images for this dataset, the labels and the size for reshaping the images.
    def __init__(self, images, labels, target_size):
        self.images = images
        self.labels = labels
        self.target_size = target_size

    # A dataset needs a __len__ method that returns how many samples are in the dataset.
    # Here, the number of samples corresponds to the number of small images.
    def __len__(self):
        return len(self.images)

    # The __getitem__ method returns the image data and labels for a given sample index.
    def __getitem__(self, index):
        # Load the image data and label for this sample index.
        image = self.images[index]

        # We have to subtract 1 from the labels, because PyTorch expects the classification labels
        # to start from 0. Hence we change the labels as follows:
        # 1 -> 0 = cell is infected
        # 2 -> 1 = cell is not infected.
        label = self.labels[index] - 1
        
        # Resize the image to the common size.
        # Note: there are different possible strategies for resizing the image.
        # Here, we chose to resize it (corresponding to "zooming in" if the image is too small
        # and "zooming out" otherwise). Alternatively, one could for example pad the image.
        # We cast the result to float32, the data type of the model weights.
        resized_image = resize(image, self.target_size, preserve_range=True).astype("float32")
        return resized_image, label
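
The comment above mentions padding as an alternative to resizing. A minimal sketch of a center pad/crop follows; `pad_or_crop_to_shape` is a hypothetical helper, not part of the lesson's code, that could stand in for the `resize` call in `__getitem__` if one wanted to keep the original pixel scale.

```python
import numpy as np


def pad_or_crop_to_shape(image, target_hw):
    """Center-pad with zeros or center-crop a (channels, height, width) image.

    Unlike resizing, this keeps the original pixel size, so cells keep their
    physical scale; cells larger than the target are cut off at the borders.
    """
    out = image
    for axis, target in zip((1, 2), target_hw):
        size = out.shape[axis]
        if size < target:
            before = (target - size) // 2
            pad = [(0, 0)] * out.ndim
            pad[axis] = (before, target - size - before)
            out = np.pad(out, pad, mode="constant")
        elif size > target:
            start = (size - target) // 2
            out = np.take(out, np.arange(start, start + target), axis=axis)
    return out


small = np.ones((3, 40, 70), dtype="float32")
result = pad_or_crop_to_shape(small, (64, 64))
print(result.shape)  # (3, 64, 64)
```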

In [None]:
# Now we create the datasets for the training and validation set.

# We also need to define the common shape for the images here.
# Remember that we have determined a good image shape as 64 x 64.
# The 3 is for the number of channels, which is the same for each image.
image_shape = (3, 64, 64)

train_dataset = CustomDataset(train_images, train_labels, target_size=image_shape)
val_dataset = CustomDataset(val_images, val_labels, target_size=image_shape)

In [None]:
# We now need to create the dataloaders that use the datasets we have just defined in order
# to provide the training data for PyTorch. Dataloaders do the following:
# - They choose indices for a set number of samples ("batch_size") that will be used for each training iteration.
# - Then they fetch the corresponding samples from the dataset.
# - They convert the image and label data to the correct datatype for PyTorch.
# - Finally, they combine the image data into a single input, by stacking all the small images.
#   And do the same for the labels.
#    - This is also the reason why we need to bring all images to the same shape, otherwise they could
#      not be combined into a single input.

# This is the number of samples that will be used for one iteration.
# This means the data loader will return 32 images and labels, combined into a single input tensor and label vector.
batch_size = 32

# Number of workers for parallel loading (This is a detail you can ignore for now.)
num_workers = 4 if torch.cuda.is_available() else 1

# Create the dataloaders for training and validation.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
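
The stacking step is why all images need a common shape. PyTorch's default collate function stacks the per-sample tensors with `torch.stack`, which, like `np.stack` used here to keep the sketch dependency-free, only accepts equally shaped inputs. The arrays below are made up:

```python
import numpy as np

# Two hypothetical cell images with the same shape can be stacked into one batch ...
same_a = np.zeros((3, 64, 64), dtype="float32")
same_b = np.ones((3, 64, 64), dtype="float32")
batch = np.stack([same_a, same_b])
print(batch.shape)  # (2, 3, 64, 64)

# ... but images with different shapes cannot.
try:
    np.stack([np.zeros((3, 50, 48)), np.zeros((3, 56, 60))])
    stacking_failed = False
except ValueError:
    stacking_failed = True
print(stacking_failed)  # True
```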

In [None]:
# Let's check one sample from the training loader and see that this matches what we just discussed.
# x is the input data (images for one batch) and y the label data (labels for one batch).
x, y = next(iter(train_loader))

# Let's print the shapes.
print("Input shape:", x.shape)
print("Label shape:", y.shape)
print("Label values:", torch.unique(y))

# You should see that these shapes match what we discussed above:
# - x contains 32 small images, all resized to the common image shape
# - y contains 32 labels (each is either 0 or 1)

In [None]:
# This function trains the model for one "epoch".
# An epoch corresponds to iterating over all available training data once.
def train_epoch(model, loader, loss_function, optimizer):

    # Set the model to training mode.
    model.train()

    # Create a list to store the loss values throughout training.
    # We do this to later plot the evolution of the loss over time.
    loss_values = []

    # Iterate over the data loader. It will return the input data 'x'
    # and the target / label 'y'. The data is already grouped into batches.
    for x, y in loader:
        # Move data to the device (the GPU, if available).
        x, y = x.to(device), y.to(device)

        # Zero out the gradients. PyTorch accumulates gradients across backward passes,
        # so without this the gradients from past iterations would be added to the current ones,
        # which leads to wrong weight updates.
        optimizer.zero_grad()

        # Forward pass: run prediction with the model.
        prediction = model(x)

        # Compute the loss.
        loss = loss_function(prediction, y)

        # Backward pass: compute the gradients.
        loss.backward()

        # Update model parameters with the optimizer.
        optimizer.step()

        # Store the loss for the current batch.
        loss_values.append(loss.item())

    # Return the list of training loss values for the epoch.
    return loss_values
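
The effect of `optimizer.zero_grad()` can be seen on a single tensor: every call to `backward()` adds to the `.grad` attribute instead of overwriting it. A minimal sketch with a made-up scalar parameter:

```python
import torch

# A single made-up parameter and a simple loss, loss = 3 * w, whose gradient is 3.
w = torch.tensor(2.0, requires_grad=True)
recorded_grads = []

(3 * w).backward()
recorded_grads.append(w.grad.item())  # 3.0

# A second backward pass without zeroing adds to the stored gradient.
(3 * w).backward()
recorded_grads.append(w.grad.item())  # 6.0

# Resetting the gradient, which is what optimizer.zero_grad() does for every
# model parameter (recent PyTorch versions set .grad to None instead of zeros).
w.grad.zero_()
(3 * w).backward()
recorded_grads.append(w.grad.item())  # 3.0

print(recorded_grads)  # [3.0, 6.0, 3.0]
```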

In [None]:
# This function evaluates the model.
# During evaluation we do not update the model weights, but only compute metrics,
# which are the evaluation measures for our classification problem.
def evaluate_model(model, loader, loss):
    # Set the model to evaluation mode.
    model.eval()

    # Lists to store predictions and actual labels.
    # We need this to later compute the metrics (f1-score etc.)
    predictions, labels, loss_values = [], [], []

    # Disable gradient computation: we don't update the model parameters here, so we don't need gradients.
    with torch.no_grad():
        for x, y in loader:
            # Move data to the device (the GPU, if available).
            x, y = x.to(device), y.to(device)

            # Run prediction.
            prediction = model(x)
            # Calculate loss
            loss_value = loss(prediction, y).item()
            loss_values.append(loss_value)
            # And compute the most likely class, which has the highest value in the prediction.
            # The code below will find the position of this highest score and return the
            # corresponding class label.
            # For example, if the prediction was
            # [[0.1, 0.9],
            #  [0.7, 0.3]]
            # we would get
            # [1, 0]
            # afterwards, corresponding to the predicted class labels.
            _, prediction = torch.max(prediction, 1)

            # Extend the lists with predicted and actual labels.
            predictions.extend(prediction.cpu().numpy())
            labels.extend(y.cpu().numpy())

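    # Compute the evaluation metrics.
    # Note: f1-score, precision and recall are computed for label 1 (not infected), the default positive class.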
    f1 = metrics.f1_score(labels, predictions)
    precision = metrics.precision_score(labels, predictions)
    recall = metrics.recall_score(labels, predictions)
    confusion_matrix = metrics.confusion_matrix(labels, predictions)

    # Return evaluation metrics.
    return loss_values, f1, precision, recall, confusion_matrix
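
`torch.max(prediction, 1)` returns both the maximum values and their indices; only the indices (the class labels) are kept. The ResNet outputs unnormalized scores (logits) rather than probabilities, but this does not change the predicted class. A numpy sketch of the same step with made-up logits:

```python
import numpy as np

# Hypothetical model outputs (logits) for a batch of two cells and two classes.
logits = np.array([[0.1, 0.9],
                   [0.7, 0.3]])

# Softmax turns the logits into class probabilities that sum to one per row.
exp = np.exp(logits - logits.max(axis=1, keepdims=True))
probabilities = exp / exp.sum(axis=1, keepdims=True)

# The predicted class is the index of the largest value in each row;
# softmax is monotonic, so the argmax is the same for logits and probabilities.
print(np.argmax(logits, axis=1))         # [1 0]
print(np.argmax(probabilities, axis=1))  # [1 0]
```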

In [None]:
# Define model, optimizer and loss function.
# Here, we use a ResNet18, the smallest of the standard ResNet architectures.
model = resnet18(num_classes=2)
model.to(device)

# Define the learning rate and choose the optimizer.
# We use Adam, a variant of stochastic gradient descent that adapts the step size for each parameter
# and is usually less sensitive to the learning rate than plain stochastic gradient descent.
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Define the loss function. We use the cross entropy, which is the standard loss function
# for classification tasks.
loss_function = torch.nn.CrossEntropyLoss()
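
`CrossEntropyLoss` takes the raw logits (it applies log-softmax internally) together with integer class labels, which is why the dataset shifts the labels to 0 and 1. A numpy sketch of the default computation (mean reduction, no class weights); the `cross_entropy` helper is hypothetical:

```python
import numpy as np


def cross_entropy(logits, labels):
    """Mean cross-entropy of integer class labels given logits of shape (N, C).

    Mirrors the default behavior of torch.nn.CrossEntropyLoss (reduction="mean",
    no class weights): log-softmax followed by the negative log-likelihood.
    """
    shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()


# Equal logits correspond to a 50/50 guess, so the loss is log(2).
print(cross_entropy(np.array([[0.0, 0.0]]), np.array([0])))
# A confident, correct prediction gives a small loss ...
print(cross_entropy(np.array([[5.0, -5.0]]), np.array([0])))
# ... and a confident, wrong prediction a large one.
print(cross_entropy(np.array([[5.0, -5.0]]), np.array([1])))
```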

In [None]:
# The lists to keep track of the values we monitor during training and evaluation:
# the loss and different metrics
train_loss_history = []
val_loss_history = []
avg_val_losses = []
f1_history = []
precision_history = []
recall_history = []
confusion_matrix_history = []

In [None]:
# Now we run the training!

# We train our model for 15 epochs.
num_epochs = 15

# The code below calls all the functions we defined beforehand to run
# training and validation for each epoch and to save the loss and metric values.
for epoch in range(num_epochs):
    t_start = time.time()
    train_losses = train_epoch(model, train_loader, loss_function, optimizer)
    val_losses, f1, precision, recall, confusion_matrix = evaluate_model(model, val_loader, loss_function)

    train_loss_history.extend(train_losses)
    val_loss_history.extend(val_losses)
    f1_history.append(f1)
    precision_history.append(precision)
    recall_history.append(recall)
    confusion_matrix_history.append(confusion_matrix)

    avg_loss = np.mean(train_losses)
    avg_val_loss = np.mean(val_losses)
    avg_val_losses.append(avg_val_loss)
    t_epoch = time.time() - t_start
    print("Epoch", epoch, "training loss:", avg_loss, "val_loss", avg_val_loss, "validation f1-score:", f1, "ran for", t_epoch, "s")

Now we plot the loss and metric values over the course of the training. The plotting functions are implemented in the cell below. You don't need to look at their details: they only plot curves, which takes quite a lot of code.

In [None]:
# Functions to plot the metrics
def simple_exponential_smoothing(series, alpha):
    forecast = [series[0]]  # Initial forecast is the first observation in the series

    for t in range(1, len(series)):
        forecast_t = alpha * series[t] + (1 - alpha) * forecast[t - 1]
        forecast.append(forecast_t)

    return forecast
    
def plot_confusion_matrix(conf_matrix):
    classes = np.arange(conf_matrix.shape[0])
    plt.imshow(conf_matrix, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.colorbar()

    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes)
    plt.yticks(tick_marks, classes)

    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    fmt = 'd'
    thresh = conf_matrix.max() / 2.

    for i, j in itertools.product(range(conf_matrix.shape[0]), range(conf_matrix.shape[1])):
        plt.text(j, i, format(conf_matrix[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if conf_matrix[i, j] > thresh else "black")

def plot_training_curves():
    iterations = range(1, len(train_loss_history) + 1)
    epochs = range(1, len(f1_history) + 1)

    # Smooth the training curve
    alpha = 0.01
    smooth_train_loss_history = simple_exponential_smoothing(train_loss_history, alpha)
    # Plotting train loss
    plt.figure(figsize=(16, 9))
    plt.subplot(2, 3, 1)
    plt.plot(iterations, smooth_train_loss_history, label='Training Loss')
    plt.title('Training Loss')
    plt.xlabel('Iterations')
    plt.ylabel('Loss')
    plt.legend()

    # Plotting validation loss
    plt.subplot(2, 3, 2)
    plt.plot(epochs, avg_val_losses, label='Validation Loss')
    plt.title('Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()

    # Plotting F1-Score
    plt.subplot(2, 3, 3)
    plt.plot(epochs, f1_history, label='F1-Score')
    plt.title('Validation F1-Score')
    plt.xlabel('Epochs')
    plt.ylabel('F1-Score')
    plt.legend()

    # Plotting precision
    plt.subplot(2, 3, 4)
    plt.plot(epochs, precision_history, label='Precision')
    plt.title('Validation Precision')
    plt.xlabel('Epochs')
    plt.ylabel('Precision')
    plt.legend()

    # Plotting recall
    plt.subplot(2, 3, 5)
    plt.plot(epochs, recall_history, label='Recall')
    plt.title('Validation Recall')
    plt.xlabel('Epochs')
    plt.ylabel('Recall')
    plt.legend()

    # Plotting confusion matrix
    plt.subplot(2, 3, 6)
    conf_matrix = confusion_matrix_history[-1]
    plot_confusion_matrix(conf_matrix)

    plt.tight_layout()
    plt.show()

In [None]:
plot_training_curves()

#### Interpret the plots!

Can you interpret what you see in the plots?
- How does the loss behave over the course of the training? Does it decrease the whole time?
- How do the metrics behave? Do they improve the whole time?
- We monitor different metrics to see what kinds of errors our model makes:
    - A high precision means that most cells predicted with label 1 (not infected) are actually not infected.
    - A high recall means that most cells that actually have label 1 (not infected) are also found (i.e. they are predicted as not infected).
    - A high F1-Score means that both precision and recall are good.
- Given this knowledge, can you identify any problems in the prediction?
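
Precision, recall and F1-Score can be computed directly from the counts of true positives, false positives and false negatives. The sketch below (with a hypothetical helper `binary_metrics` and made-up labels) treats label 1 as the positive class, which matches the scikit-learn default used in `evaluate_model`. It also shows how a model that always predicts the same class can still reach a perfect recall.

```python
import numpy as np


def binary_metrics(labels, predictions, positive=1):
    """Precision, recall and F1 for the class `positive`, computed from counts."""
    labels, predictions = np.asarray(labels), np.asarray(predictions)
    tp = np.sum((predictions == positive) & (labels == positive))
    fp = np.sum((predictions == positive) & (labels != positive))
    fn = np.sum((predictions != positive) & (labels == positive))
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0
    return precision, recall, f1


# Hypothetical example: a model that predicts label 1 (not infected) for every cell.
labels = [0, 0, 0, 0, 0, 0, 1, 1]
predictions = [1] * 8
p, r, f = binary_metrics(labels, predictions)
print(f"precision={p:.2f} recall={r:.2f} f1={f:.2f}")  # precision=0.25 recall=1.00 f1=0.40
```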

Hints:
<details>
  <summary>Click to expand!</summary>

You will most likely see over-fitting: the training loss keeps decreasing, but the validation loss increases.
Also, precision, recall and f1-score may show problems due to the class imbalance (more infected than non-infected cells).
</details>

### 4. Apply the classifier to the test set

After we have trained the model, we now apply it to the test set, in order to estimate its performance for unseen data.

In [None]:
test_dataset = CustomDataset(test_images, test_labels, target_size=image_shape)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [None]:
# We can re-use the evaluate model function from earlier.
_, f1_score, precision, recall, confusion_matrix = evaluate_model(model, test_loader, loss_function)

# Print the scores and plot the confusion matrix.
print("F1-Score:", f1_score)
print("Precision:", precision)
print("Recall:", recall)

plt.figure(figsize=(12, 7))
plot_confusion_matrix(confusion_matrix)
plt.show()

## Exercises

To further understand some of the most important aspects of training neural networks, please work on the following exercises:

- Explore the influence of the learning rate on training. We have used the learning rate 0.0001 (1e-4) for now. Repeat the training with the learning rates 0.01 (1e-2), 0.001 (1e-3) and 0.00001 (1e-5). Plot the training and metric curves for each learning rate and compare the results among them. Describe what you observe and decide on the best learning rate among the ones tested.

- Train a [ResNet 34](https://pytorch.org/vision/main/models/generated/torchvision.models.resnet34.html) instead of the ResNet 18. Does it perform better on the test set? What do you think would be required to train even larger models?

- Balance the classes in your loss function. You may have observed earlier that the imbalanced classes lead to problems during training and that in some cases only the majority class is predicted. You can reweight the classes with the `weight` parameter of the `CrossEntropyLoss` (a higher weight gives a class more influence on the loss), see also [the documentation](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html). What would be a good weight value? Do you see the desired effect after training with it?
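
One common heuristic for the class weights, and only one of several possible answers, is inverse class frequency. A minimal sketch with a hypothetical helper and made-up labels:

```python
import numpy as np


def inverse_frequency_weights(labels, num_classes=2):
    """Weights proportional to 1 / class frequency, scaled as n_samples / (n_classes * n_c),
    so that a perfectly balanced dataset gets weight 1 for every class."""
    counts = np.bincount(np.asarray(labels), minlength=num_classes)
    return len(labels) / (num_classes * counts)


# Hypothetical labels after the shift in the dataset (0 = infected, 1 = not infected).
labels = [0, 0, 0, 1]
weights = inverse_frequency_weights(labels)
print(weights)  # the rarer class 1 gets the larger weight
```

In the notebook, the weights could be computed from the shifted labels `np.array(train_labels) - 1` and passed as `torch.nn.CrossEntropyLoss(weight=torch.tensor(weights, dtype=torch.float32, device=device))`. Note that a class missing from the labels would get an infinite weight.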

**Tip:** Make a copy of the notebook for each of these exercises and then solve the exercise in that copy.

## SOLUTION: Learning Rates - Metrics
### LR 1e-2
![infection_train_LR_1e-2.png](attachment:7486fb18-1ffa-4635-88c7-eac8e56191c6.png)
### LR 1e-3
![infection_train_LR_1e-3.png](attachment:41688932-2183-467f-b864-c809cbac20f9.png)
### LR 1e-5
![infection_train_LR_1e-5.png](attachment:e7684831-1d68-43d5-bef7-d99b3b83f8fd.png)

## SOLUTION: Learning Rates - Confusion Matrices
### LR 1e-2
![infection_train_LR_1e-2_confusion.png](attachment:425e5d06-fb52-4d93-b5cb-0b2292e97e26.png)
### LR 1e-3
![infection_train_LR_1e-3_confusion.png](attachment:702132ea-f21f-4ba7-b94b-6a98a4ac61c0.png)
### LR 1e-5
![infection_train_LR_1e-5_confusion.png](attachment:ddc06ea6-f8a9-405d-90e7-c373e40d0812.png)
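The behaviour seen in these curves can be reproduced in miniature. The sketch below runs plain gradient descent on the one-dimensional quadratic f(x) = x², whose gradient is 2x, so every step multiplies x by (1 - 2 · lr). This is a toy illustration, not the ResNet training, and the learning rates are chosen for this toy function only: a learning rate that is too large makes the iterates oscillate and diverge, while one that is too small barely moves in the same number of steps.

```python
def gradient_descent_quadratic(learning_rate, x0=1.0, num_steps=20):
    """Minimize f(x) = x**2 with plain gradient descent and return the final x."""
    x = x0
    for _ in range(num_steps):
        grad = 2.0 * x  # derivative of x**2
        x = x - learning_rate * grad
    return x


# Too large, reasonable and too small for this particular function.
for lr in [1.1, 0.1, 1e-3]:
    x_final = gradient_descent_quadratic(lr)
    print(f"lr={lr:g}: final |x| = {abs(x_final):.4g} (the minimum is at x = 0)")
```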

In [None]:
## SOLUTION: ResNet 34
from torchvision.models.resnet import resnet34
model = resnet34(num_classes=2)
model.to(device)

learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

loss_function = torch.nn.CrossEntropyLoss()

In [None]:
# The lists to keep track of the values we monitor during training and evaluation:
# the loss and different metrics.
train_loss_history = []
val_loss_history = []
avg_val_losses = []
f1_history = []
precision_history = []
recall_history = []
confusion_matrix_history = []

In [None]:
num_epochs = 15

# The code below calls all the functions we defined beforehand to run
# training and validation for each epoch and to save the loss and metric values.
for epoch in range(num_epochs):
    t_start = time.time()
    train_losses = train_epoch(model, train_loader, loss_function, optimizer)
    val_losses, f1, precision, recall, confusion_matrix = evaluate_model(model, val_loader, loss_function)

    train_loss_history.extend(train_losses)
    val_loss_history.extend(val_losses)
    f1_history.append(f1)
    precision_history.append(precision)
    recall_history.append(recall)
    confusion_matrix_history.append(confusion_matrix)

    avg_loss = np.mean(train_losses)
    avg_val_loss = np.mean(val_losses)
    avg_val_losses.append(avg_val_loss)
    t_epoch = time.time() - t_start
    print("Epoch", epoch, "training loss:", avg_loss, "val_loss", avg_val_loss, "validation f1-score:", f1, "ran for", t_epoch, "s")

In [None]:
plot_training_curves()

In [None]:
# We can re-use the `evaluate_model` function from earlier.
_, f1_score, precision, recall, confusion_matrix = evaluate_model(model, test_loader, loss_function)

# Print the scores and plot the confusion matrix.
print("F1-Score:", f1_score)
print("Precision:", precision)
print("Recall:", recall)

plt.figure(figsize=(12, 7))
plot_confusion_matrix(confusion_matrix)
plt.show()
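The precision, recall and F1-score printed above can all be read off a binary confusion matrix. The sketch below assumes the scikit-learn layout (rows = true class, columns = predicted class) and treats class 1 as the positive class; both are assumptions, so check how `evaluate_model` builds its matrix before comparing numbers. The function name and the counts are hypothetical.

```python
import numpy as np


def scores_from_confusion_matrix(cm, positive=1):
    """Precision, recall and F1 for one class of a confusion matrix.

    Assumes cm[i, j] = number of samples with true class i predicted as class j.
    """
    cm = np.asarray(cm, dtype=float)
    tp = cm[positive, positive]
    fp = cm[:, positive].sum() - tp  # predicted positive, but actually negative
    fn = cm[positive, :].sum() - tp  # actually positive, but predicted negative
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0
    return precision, recall, f1


# Hypothetical counts: 90 TN, 10 FP (row 0), 20 FN, 30 TP (row 1).
toy_cm = np.array([[90, 10],
                   [20, 30]])
toy_precision, toy_recall, toy_f1 = scores_from_confusion_matrix(toy_cm)
print(f"precision={toy_precision:.3f}, recall={toy_recall:.3f}, f1={toy_f1:.3f}")
```

A model that only predicts the majority class, e.g. `[[100, 0], [50, 0]]`, gets precision, recall and F1 of 0 for the positive class, which is why these scores expose the class-imbalance problem that plain accuracy can hide.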

In [None]:
## SOLUTION: Balance classes

model = resnet18(num_classes=2)
model.to(device)

learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Now we determine the class distribution of the data.
infected_labels = []
for entry in classification_label_data["cells"]:
    infected_labels.append(entry["infected_label"])
unique_labels, counts = np.unique(infected_labels, return_counts=True)
# Map each label value to its count: `counts` is indexed by position in
# `unique_labels`, not by label value, so counts[2] would fail (or pick the
# wrong class) if not all label values are present.
label_counts = dict(zip(unique_labels, counts))
total_samples = label_counts[1] + label_counts[2]

# To calculate the balancing weights, use one of these approaches:
# 1. Proportional weight calculation:
#    - weight_for_class_i = total_samples / (num_samples_in_class_i * num_classes)
#    - see https://scikit-learn.org/stable/modules/generated/sklearn.utils.class_weight.compute_class_weight.html
#    - the weights are not restricted to [0, 1]!

# 2. Inverse class frequency:
#    - weight_for_class_i = 1 / num_samples_in_class_i
#    - may produce very small weights, which could lead to numerical instability
#      or vanishing gradients

# 3. Inverse class frequency with normalization:
#    - normed_weight_for_class_i = weight_for_class_i / sum_of_all_weights
#    - code example:
#      weight_infected = 1 / label_counts[1]
#      weight_not_infected = 1 / label_counts[2]
#      total_weight = weight_infected + weight_not_infected
#      normalized_weight_infected = weight_infected / total_weight
#      normalized_weight_not_infected = weight_not_infected / total_weight

# 4. Use a reference class:
#    - set the weight of one class to 1.0 and calculate the others as follows
#    - weight_for_class_i = num_samples_in_reference_class / num_samples_in_class_i
# Here we use option 4 with "not infected" as the reference class.
weight_infected = label_counts[2] / label_counts[1]
weight_not_infected = 1.0

# The order of the weights must match the class indices of the targets.
weights = [weight_infected, weight_not_infected]
print(weights)
class_weights = torch.tensor(weights, dtype=torch.float32).to(device)
loss_function = torch.nn.CrossEntropyLoss(weight=class_weights)
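The four weighting schemes listed in the comments can be compared side by side. The sketch below uses hypothetical counts (100 infected and 400 not-infected cells), made up for illustration. All four schemes give the infected class the same weight relative to the not-infected class (here 4:1) and differ only in the absolute scale.

```python
import numpy as np

# Hypothetical class counts: [infected, not infected], for illustration only.
toy_counts = np.array([100.0, 400.0])
toy_total = toy_counts.sum()
num_classes = len(toy_counts)

# 1. Proportional weights (the "balanced" mode of sklearn's compute_class_weight).
balanced = toy_total / (toy_counts * num_classes)
# 2. Inverse class frequency.
inverse = 1.0 / toy_counts
# 3. Inverse class frequency, normalized to sum to 1.
inverse_normed = inverse / inverse.sum()
# 4. Reference class: "not infected" (index 1) gets weight 1.0.
reference = toy_counts[1] / toy_counts

schemes = {"balanced": balanced, "inverse": inverse,
           "inverse_normed": inverse_normed, "reference": reference}
for name, w in schemes.items():
    print(f"{name:>15}: {w}, ratio infected / not infected = {w[0] / w[1]:g}")
```

Because `CrossEntropyLoss` with its default `reduction="mean"` divides by the sum of the applied weights, multiplying all weights by the same constant leaves the loss unchanged; the absolute scale matters more with `reduction="sum"`.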

In [None]:
# The lists to keep track of the values we monitor during training and evaluation:
# the loss and different metrics.
train_loss_history = []
val_loss_history = []
avg_val_losses = []
f1_history = []
precision_history = []
recall_history = []
confusion_matrix_history = []

In [None]:
num_epochs = 15

# The code below calls all the functions we defined beforehand to run
# training and validation for each epoch and to save the loss and metric values.
for epoch in range(num_epochs):
    t_start = time.time()
    train_losses = train_epoch(model, train_loader, loss_function, optimizer)
    val_losses, f1, precision, recall, confusion_matrix = evaluate_model(model, val_loader, loss_function)

    train_loss_history.extend(train_losses)
    val_loss_history.extend(val_losses)
    f1_history.append(f1)
    precision_history.append(precision)
    recall_history.append(recall)
    confusion_matrix_history.append(confusion_matrix)

    avg_loss = np.mean(train_losses)
    avg_val_loss = np.mean(val_losses)
    avg_val_losses.append(avg_val_loss)
    t_epoch = time.time() - t_start
    print("Epoch", epoch, "training loss:", avg_loss, "val_loss", avg_val_loss, "validation f1-score:", f1, "ran for", t_epoch, "s")

In [None]:
plot_training_curves()

In [None]:
# We can re-use the `evaluate_model` function from earlier.
_, f1_score, precision, recall, confusion_matrix = evaluate_model(model, test_loader, loss_function)

# Print the scores and plot the confusion matrix.
print("F1-Score:", f1_score)
print("Precision:", precision)
print("Recall:", recall)

plt.figure(figsize=(12, 7))
plot_confusion_matrix(confusion_matrix)
plt.show()

## Advanced Exercises

If you want to dive deeper into building your own networks with PyTorch or adapting the training procedure, you can work on these advanced exercises. They require more coding than the previous ones.

- Build your own classification network. Use the functionality from `torch.nn` to build a small convolutional neural network for the classification task. Train it and compare its performance to the ResNets you have trained before. Check out the [PyTorch tutorial](https://pytorch.org/tutorials/beginner/basics/buildmodel_tutorial.html) to get started on building your own network.

- Implement data augmentation. Data augmentation is a common technique to effectively increase the training data by applying transformations that change the inputs but not their meaning. For example, flipping an image usually does not change its meaning, so flipped copies can be used to enlarge the pool of available training data. You can integrate the augmentations directly into the `Dataset` and use [torchvision.transforms](https://pytorch.org/vision/stable/transforms.html) to implement them. Does training with your augmentations improve the network?

- Implement early stopping: the idea behind early stopping is to stop training the network as soon as it no longer improves on the validation set, i.e. as soon as the validation loss plateaus or starts to increase. This indicates that the network starts to overfit, so we do not want to train it any longer. You need to update the training functions to implement this; search for `Early Stopping` to find examples of how to implement it.
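A minimal, framework-independent sketch of patience-based early stopping follows. The class name `EarlyStopping` and its `patience` and `min_delta` parameters are our own choices for illustration, not a PyTorch API. It tracks the best validation loss and signals a stop once the loss has not improved by more than `min_delta` for `patience` consecutive epochs. In the notebook's epoch loop you could call `stopper.step(epoch, avg_val_loss)`, save the model weights whenever `improved` is true, and `break` when `should_stop` is true.

```python
class EarlyStopping:
    """Signal a stop when the validation loss stops improving.

    patience: number of epochs without improvement that are tolerated.
    min_delta: minimum decrease of the loss that counts as an improvement.
    """

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.best_epoch = -1
        self.epochs_without_improvement = 0

    def step(self, epoch, val_loss):
        """Record the validation loss of one epoch; return (improved, should_stop)."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.epochs_without_improvement = 0
            return True, False
        self.epochs_without_improvement += 1
        return False, self.epochs_without_improvement >= self.patience


# Made-up validation losses: improving at first, then plateauing / increasing.
toy_val_losses = [0.9, 0.7, 0.6, 0.61, 0.62, 0.65, 0.5]
stopper = EarlyStopping(patience=3)
stopped_at = None
for ep, loss in enumerate(toy_val_losses):
    improved, should_stop = stopper.step(ep, loss)
    if should_stop:
        stopped_at = ep
        break
print("stopped at epoch", stopped_at, "| best epoch", stopper.best_epoch, "| best loss", stopper.best_loss)
```

The toy run stops at epoch 5 even though epoch 6 would have improved again: a larger `patience` trades extra training time for robustness against noisy plateaus.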