## Train Infection Classifier

In the previous lessons we built a method for cell instance segmentation and applied it to our dataset. Now we turn to classifying the cells as infected or non-infected, based on the virus marker channel, the nucleus image channel and the segmentation mask of each individual cell. We will use a ResNet for this task.

The goal of this lesson is to learn how to train a classification model with `torch_em`.

In [None]:
# General imports.
import torch_em

import os
from glob import glob

import h5py
import napari
import numpy as np

from skimage.measure import regionprops

In [None]:
# Define the paths to folders with the data and train/val splits.
# If you store the data somewhere else just change the 'data_folder' variable.

data_folder = "../data"
train_data_folder = os.path.join(data_folder, "train")
val_data_folder = os.path.join(data_folder, "val")

### 1. Inspect Training Data

First, we visually check all the relevant training data. We will use it to construct image patches for training the classification model as follows:
- Compute the bounding box around each cell.
- Cut out the nucleus image, virus marker and segmentation mask for the bounding box.
- Set all values outside the mask to zero.
- Derive the label (infected or not infected) for the given patch from the infection label image.
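
The four steps above can be illustrated on a toy example. This is a minimal sketch with made-up arrays. Unlike the notebook cells below, it derives the bounding box with `np.nonzero` rather than `regionprops`. All array contents are assumptions for illustration only.

```python
import numpy as np

# Toy 6x6 segmentation with a single cell (id 1); 0 is background.
cells = np.zeros((6, 6), dtype="uint32")
cells[1:4, 2:5] = 1
cells[1, 2] = 0  # Make the cell non-rectangular so that masking has an effect.

# Toy intensity image and infection labels (1 = infected, 2 = not infected, 0 = no label).
marker = np.arange(36, dtype="float32").reshape(6, 6)
infected_labels = np.zeros((6, 6), dtype="uint8")
infected_labels[2, 3] = 1

# Step 1: compute the bounding box of the cell from the coordinates of its mask.
rows, cols = np.nonzero(cells == 1)
bbox = np.s_[rows.min():rows.max() + 1, cols.min():cols.max() + 1]

# Step 2: cut out the image and the mask for the bounding box.
cell_mask = cells[bbox] == 1
marker_patch = marker[bbox].copy()

# Step 3: set all values outside of the cell mask to zero.
marker_patch[~cell_mask] = 0.0

# Step 4: collect the non-zero infection label values inside the cell.
label_values = infected_labels[bbox][cell_mask]
label_values = label_values[label_values != 0]

print(marker_patch.shape, label_values)
```

The patch is a 3 x 3 crop in which the top-left pixel lies outside the cell and is therefore zeroed. The only label value found inside the cell is 1 (infected).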

In [None]:
# Load all necessary data for one training image.
image_path = os.path.join(train_data_folder, "gt_image_000.h5")
with h5py.File(image_path, "r") as f:
    marker = f["raw/marker/s0"][:]
    nucleus_image = f["raw/nuclei/s0"][:]
    cells = f["labels/cells/s0"][:]
    infected_labels = f["labels/infected/nuclei/s0"][:]

In [None]:
# Check it visually.
viewer = napari.Viewer()
viewer.add_image(marker, colormap="red", blending="additive")
viewer.add_image(nucleus_image, colormap="blue", blending="additive")
viewer.add_labels(cells)
viewer.add_labels(infected_labels)

In [None]:
# Function to extract the label (infected vs. not infected) for each cell in an image.
def extract_labels_for_cells(cells, infected_labels):
    # First we get all non-background cell ids for this image.
    cell_ids = np.unique(cells)[1:]
    cell_labels = {}
    
    # We iterate over the ids.
    for cell_id in cell_ids:
        # Compute the cell mask and get the infection labels inside of it
        cell_mask = cells == cell_id
        infected_labels_cell = infected_labels[cell_mask]
        # Zero means no infection label.
        infected_labels_cell = infected_labels_cell[infected_labels_cell != 0]
        # If we only have zeros then skip this cell.
        if infected_labels_cell.size == 0:
            cell_labels[cell_id] = None
            continue
    
        # The label values mean the following: 1 = infected, 2 = not infected.
        # If there is more than one label we need to check which of the two is more prevalent.
        label_ids, counts = np.unique(infected_labels_cell, return_counts=True)
        # We map the label id to 0, 1 (infected, not infected) because pytorch / torch_em expects zero-based indexing.
        if len(label_ids) == 1:
            assert label_ids[0] in (1, 2)
            label = label_ids[0] - 1
        else:
            assert label_ids.tolist() == [1, 2], str(label_ids)
            label = 0 if counts[0] > counts[1] else 1
        cell_labels[cell_id] = label
    return cell_labels
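
The majority vote and the shift to zero-based class ids can be checked in isolation. This is a minimal sketch. `majority_label` is a hypothetical helper that is not part of the notebook or of `torch_em`. It resolves ties in favor of the lower label id (infected), which is a choice made for this sketch only.

```python
import numpy as np


def majority_label(label_values):
    """Hypothetical helper: map label values 1 (infected) / 2 (not infected)
    to the class ids 0 / 1 by majority vote."""
    label_ids, counts = np.unique(label_values, return_counts=True)
    # Pick the label id with the highest pixel count and shift it to zero-based indexing.
    return int(label_ids[np.argmax(counts)]) - 1


mostly_infected = majority_label(np.array([1, 1, 2]))
mostly_not_infected = majority_label(np.array([1, 2, 2, 2]))
only_not_infected = majority_label(np.array([2, 2]))
print(mostly_infected, mostly_not_infected, only_not_infected)
```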

In [None]:
# We apply the function to get the infection labels for the cells in our current image.
cell_infection_labels = extract_labels_for_cells(cells, infected_labels)

# And use skimage regionprops to compute other properties for all cells in the image.
props = regionprops(cells)

# Now we visualize the infection labels as points: we put a point at each cell centroid and color it
# according to its label using a napari points layer (see below). Cells without a label are skipped.
labeled_props = [prop for prop in props if cell_infection_labels[prop.label] is not None]
points = [prop.centroid for prop in labeled_props]
infected_points = ["infected" if cell_infection_labels[prop.label] == 0 else "not-infected" for prop in labeled_props]

viewer = napari.Viewer()
viewer.add_image(marker, colormap="red", blending="additive")
viewer.add_image(nucleus_image, colormap="blue", blending="additive")
point_layer = viewer.add_points(
    points, properties={"infected": infected_points}, face_color="infected", face_color_cycle=["orange", "cyan"],
)
point_layer.face_color_mode = "cycle"

In [None]:
# Function to extract the training patches and labels for one image.
def image_to_training_data(cells, marker, nucleus_image, infected_labels, apply_cell_mask=True):
    # Compute the infection labels with the previously defined function and the region properties.
    cell_infection_labels = extract_labels_for_cells(cells, infected_labels)
    props = regionprops(cells)
    
    # Iterate over all cells in the image and extract the training patch.
    train_image_data, train_labels = [], []
    for prop in props:
        cell_id = prop.label
        
        # Get the infection label and skip the cell if it doesn't have one.
        label = cell_infection_labels[cell_id]
        if label is None:
            continue
        
        # Get the bounding box from the properties for this cell.
        bbox = prop.bbox
        bbox = np.s_[bbox[0]:bbox[2], bbox[1]:bbox[3]]
        
        # Cut out mask, nucleus image and virus marker for this cell.
        cell_mask = cells[bbox] == cell_id
        nuc_im = nucleus_image[bbox].astype("float32")
        marker_im = marker[bbox].astype("float32")
        # And set the image values outside of the cell to 0.
        if apply_cell_mask:
            nuc_im[~cell_mask] = 0.0
            marker_im[~cell_mask] = 0.0
        
        # Stack the 3 channels into one image and append to the training patches and labels.
        image_data = np.stack([nuc_im, marker_im, cell_mask.astype("float32")])
        train_image_data.append(image_data)
        train_labels.append(label)
        
    return train_image_data, train_labels

In [None]:
# Apply the function to our current image.
train_image_data, train_labels = image_to_training_data(cells, marker, nucleus_image, infected_labels)

In [None]:
# Visualize 5 of the training patches.
for i in range(25, 30):
    im_data = train_image_data[i]
    label = train_labels[i]
    viewer = napari.Viewer()
    viewer.add_image(im_data[0], name="nucleus-channel", colormap="blue", blending="additive")   
    viewer.add_image(im_data[1], name="marker-channel", colormap="red", blending="additive")
    viewer.add_labels(im_data[2].astype("uint8"), name="cell-mask")
    viewer.title = f"Label: {label}"

### 2. Prepare Training Data

Now we apply the function we just defined to all training and validation data to build the training and validation sets for our classification model.

In [None]:
from tqdm import tqdm

# Function that extracts the patches and labels for all images in a folder.
def prepare_classification_data(root):
    images = glob(os.path.join(root, "*.h5"))
    images.sort()

    image_data, labels = [], []
    for path in tqdm(images, desc="Prepare classification data"):
        with h5py.File(path, "r") as f:
            marker = f["raw/marker/s0"][:]
            nucleus_image = f["raw/nuclei/s0"][:]
            cells = f["labels/cells/s0"][:]
            infected_labels = f["labels/infected/nuclei/s0"][:]
            
        this_data, this_labels = image_to_training_data(cells, marker, nucleus_image, infected_labels)
        image_data.extend(this_data)
        labels.extend(this_labels)
        
    assert len(image_data) == len(labels)
    return image_data, labels

In [None]:
# Build the training and validation set.
train_data, train_labels = prepare_classification_data(train_data_folder)
print("We have", len(train_data), "samples for training")

val_data, val_labels = prepare_classification_data(val_data_folder)
print("We have", len(val_data), "samples for validation")

### 3. Train the Infection Classifier

Now we use the training and validation sets to train a ResNet34 for infection classification, using the classification functionality from `torch_em`.

In [None]:
# Import classification functionality.
import torch
import torch.nn as nn
from torch_em.classification import default_classification_loader, default_classification_trainer
from torchvision.models.resnet import resnet34
from sklearn.metrics import accuracy_score

In [None]:
# Find the mean shape of all training and validation patches.
shapes = np.stack([np.array(im.shape[1:]) for im in (train_data + val_data)])
mean_shape = np.mean(shapes, axis=0)
print("Mean image shape:", mean_shape)

You should see that the mean image shape is roughly 52 x 52 pixels. We compute it to choose a common shape that all patches will be resized to for training the model. This is necessary to stack the patches along the batch dimension and train the model with a batch size larger than 1.
We round the mean shape up to the next multiple of 16 and use the result, 64 x 64, as the common patch shape.
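
Rounding a shape up to a multiple of 16 can be sketched as follows. `round_up_to_multiple` is a hypothetical helper for illustration and is not part of `torch_em`. The input values stand in for the mean shape printed above.

```python
import numpy as np


def round_up_to_multiple(shape, factor=16):
    """Hypothetical helper: round each entry of shape up to the next multiple of factor."""
    return tuple(int(np.ceil(s / factor) * factor) for s in shape)


# Example values close to the mean shape reported above (assumed for illustration).
common_shape = round_up_to_multiple((52.3, 51.8))
print(common_shape)
```

For a mean shape of roughly 52 x 52 this gives `(64, 64)`, the value used for `image_shape` below.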

In [None]:
# Set the parameters for the data loaders.
batch_size = 32  # The batch size used for training.
image_shape = (64, 64)  # The common shape all patches will be resized to before stacking them in a batch.
num_workers = 4 if torch.cuda.is_available() else 1
# Build the training and validation loader.
train_loader = default_classification_loader(
    train_data, train_labels, batch_size=batch_size, image_shape=image_shape, num_workers=num_workers,
)
val_loader = default_classification_loader(
    val_data, val_labels, batch_size=batch_size, image_shape=image_shape, num_workers=num_workers,
)

In [None]:
# Define the model (a ResNet34 with two output classes).
model = resnet34(num_classes=2)
# And build the trainer class. Here, we use the cross entropy as loss function and the accuracy error (1 - accuracy) as metric.
trainer = default_classification_trainer(
    name="infection-classifier", model=model,
    train_loader=train_loader, val_loader=val_loader,
    loss=nn.CrossEntropyLoss(),
    metric=lambda a, b: 1.0 - accuracy_score(a, b),
    compile_model=False,
)
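
The metric passed to the trainer is written as an error, `1.0 - accuracy_score(a, b)`, so that lower values are better. The following minimal sketch reproduces this computation with plain numpy on made-up labels. `accuracy_error` is a hypothetical helper and is not part of `torch_em` or scikit-learn.

```python
import numpy as np


def accuracy_error(y_true, y_pred):
    """Hypothetical numpy version of 1.0 - accuracy_score(y_true, y_pred)."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    # Fraction of samples where the predicted class matches the true class.
    accuracy = float(np.mean(y_true == y_pred))
    return 1.0 - accuracy


# Four made-up samples, three of which are classified correctly.
err = accuracy_error([0, 1, 1, 0], [0, 1, 0, 0])
print(err)
```

With three of four samples classified correctly, the accuracy is 0.75 and the error is 0.25.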

In [None]:
# Train the model for 10,000 iterations.
trainer.fit(10000)

As before, you can open TensorBoard to monitor the progress during training via
```
tensorboard --logdir=logs
```
See `2_cell_segmentation/torchem-train-cell-membrane-segmentation` for details.

### Exercises

Train different architectures for this task, for example a `resnet18` and a `resnet50`. Also export these models to the bioimage.io format. Make sure to choose different file paths for each export so that you do not overwrite previously exported models.
You can also compare this to training the network using only PyTorch in the `pytorch_train-infection-classifier` notebook (work in progress).

### What's next?

Now we can apply the trained classification model to the test images in `apply_infection_classifier`.





**This is not working yet!**

**Skip the cells below!**

#### Export the model to bioimage.io

Now we also export the model to the bioimage.io format to import it in other tools that support this format.
See the notebook `2_cell_segmentation/torchem-train-cell-membrane-segmentation` for details.

In [None]:
import h5py
from torch_em.util.modelzoo import export_bioimageio_model

In [None]:
model_root = os.path.join(data_folder, "trained_models")
model_folder = os.path.join(model_root, "infection-classification")
os.makedirs(model_folder, exist_ok=True)

In [None]:
input_, _ = next(iter(val_loader))
input_ = input_[0:1].detach().cpu().numpy()

In [None]:
doc = """# ResNet for Covid Cell Infection Classification

A model for classifying cells into infected vs. non-infected.
"""

citations = [{"text": "Pape et al.", "doi": "https://doi.org/10.1002/bies.202000257"}]

In [None]:
export_bioimageio_model(
    checkpoint="checkpoints/infection-classifier",
    export_folder=model_folder,
    input_data=input_,
    name="infection_classification_model",
    authors=[{"name": "Your Name", "affiliation": "Your Affiliation"}],
    tags=["resnet", "cells", "2d", "immunofluorescence", "classification"],
    license="CC-BY-4.0",
    documentation=doc,
    description="Classify cells as infected or non-infected in IF images",
    cite=citations,
    input_optional_parameters=False,
    maintainers=[{"github_user": "Your Github Handle"}]  # alternatively you can also give your mail address
)