## Apply Infection Classifier

Finally, we apply the trained infection classifier to the test data, using the cell segmentation we predicted instead of the ground-truth. We also evaluate the accuracy of the predictions.

In [None]:
# General imports.
import os
from glob import glob

import h5py
import napari
import numpy as np
from skimage.measure import regionprops

In [None]:
# Paths to the data and to the predictions (the predicted cell segmentations are read from 'output_folder').
# Adapt 'data_folder' if the data is stored in a different location.
data_folder = "../data"
output_folder = os.path.join(data_folder, "predictions")

### 1. Test Data Extraction

We first extract the input patches and labels for the test images. We copy these functions from the previous notebook, with one difference: cells that could not be assigned a label are not skipped here, but get the label -1 instead.

In [None]:
# Function to extract the label (infected vs. not infected) for each cell in an image.
# Unlike the training version, cells without any infection annotation are kept and get the label -1.
def extract_labels_for_cells(cells, infected_labels):
    """Map every cell id of a segmentation to its infection label.

    Args:
        cells: cell segmentation; 0 is background, every other value is a cell id.
        infected_labels: infection annotation with the same shape as `cells`;
            0 = no annotation, 1 = infected, 2 = not infected.

    Returns:
        dict mapping cell id -> 0 (infected), 1 (not infected) or
        -1 (no infection annotation inside the cell).
    """
    # Get all non-background cell ids. Filter out 0 explicitly instead of dropping the
    # first unique value, so no cell is lost if the image has no background pixels.
    cell_ids = np.unique(cells)
    cell_ids = cell_ids[cell_ids != 0]
    cell_labels = {}

    for cell_id in cell_ids:
        # Collect the infection annotations inside this cell; zero means no infection label.
        infected_labels_cell = infected_labels[cells == cell_id]
        infected_labels_cell = infected_labels_cell[infected_labels_cell != 0]

        # Only zeros inside the cell: keep the cell, but mark its label as unknown (-1).
        if infected_labels_cell.size == 0:
            cell_labels[cell_id] = -1
            continue

        # The label values mean: 1 = infected, 2 = not infected. They are mapped to
        # 0 / 1 because pytorch / torch_em expects zero-based class indices.
        label_ids, counts = np.unique(infected_labels_cell, return_counts=True)
        if len(label_ids) == 1:
            assert label_ids[0] in (1, 2)
            label = label_ids[0] - 1
        else:
            # Both values occur inside the cell: use the more prevalent one
            # (a tie counts as not infected).
            assert label_ids.tolist() == [1, 2], str(label_ids)
            label = 0 if counts[0] > counts[1] else 1
        cell_labels[cell_id] = label

    return cell_labels

In [None]:
# Function to extract the classification patches and labels for all cells of one image.
def image_to_training_data(cells, marker, nucleus_image, infected_labels, apply_cell_mask=True):
    """Cut out one 3-channel patch per cell and pair it with the cell's infection label.

    Cells without an infection annotation are NOT skipped here; they keep the label -1
    assigned by `extract_labels_for_cells`.

    Args:
        cells: cell segmentation (0 = background).
        marker: virus marker image.
        nucleus_image: nucleus image.
        infected_labels: infection annotation, forwarded to `extract_labels_for_cells`.
        apply_cell_mask: if True, set the image values outside of the cell to 0 in each patch.

    Returns:
        Two lists of equal length in `regionprops` order: the patches (nucleus, marker and
        cell mask stacked along a new first axis, all float32) and the matching labels.
    """
    labels_per_cell = extract_labels_for_cells(cells, infected_labels)

    patches, patch_labels = [], []
    for region in regionprops(cells):
        seg_id = region.label
        box = region.bbox
        roi = np.s_[box[0]:box[2], box[1]:box[3]]

        # Mask of this cell inside its bounding box. astype() returns copies, so
        # zeroing the outside below does not modify the full images.
        mask = cells[roi] == seg_id
        nucleus_patch = nucleus_image[roi].astype("float32")
        marker_patch = marker[roi].astype("float32")
        if apply_cell_mask:
            outside = ~mask
            nucleus_patch[outside] = 0.0
            marker_patch[outside] = 0.0

        patches.append(np.stack([nucleus_patch, marker_patch, mask.astype("float32")]))
        patch_labels.append(labels_per_cell[seg_id])

    return patches, patch_labels

In [None]:
# Collect the test images and their segmentation predictions. Both lists are sorted,
# assuming that images and predictions correspond in sorted file-name order.
test_images = sorted(glob(os.path.join(data_folder, "test", "*.h5")))
test_predictions = sorted(glob(os.path.join(output_folder, "*.h5")))
assert len(test_images) == len(test_predictions)

In [None]:
# Extract the classification inputs and labels of every test image, using the
# predicted (not the ground-truth) cell segmentation.
classification_inputs, classification_labels = [], []
for image_path, prediction_path in zip(test_images, test_predictions):
    with h5py.File(image_path, "r") as f:
        marker = f["raw/marker/s0"][:]
        nucleus_image = f["raw/nuclei/s0"][:]
        infected_labels = f["labels/infected/nuclei/s0"][:]
    with h5py.File(prediction_path, "r") as f:
        cells = f["segmentations/cells/watershed_based"][:]

    image_inputs, image_labels = image_to_training_data(cells, marker, nucleus_image, infected_labels)
    classification_inputs.append(image_inputs)
    classification_labels.append(image_labels)

### 2. Prediction and Visualization for a Test Image

We run prediction for one of the test images and visualize the results in napari.

In [None]:
# torch and model imports
import torch
from torch_em.classification import default_classification_loader
from torchvision.models.resnet import resnet34

In [None]:
# Run on the GPU when one is available and fall back to the CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Load the model from the best checkpoint.
model_path = "checkpoints/infection-classifier/best.pt"
model = resnet34(num_classes=2)
# map_location maps the stored tensors onto the selected device, so a checkpoint that was
# saved during GPU training can also be loaded on a CPU-only machine.
model_state = torch.load(model_path, map_location=device)["model_state"]
model.load_state_dict(model_state)
model.eval()
model = model.to(device)

In [None]:
# Function to run prediction and to return the corresponding labels in a format that can be
# evaluated by sklearn.metrics (see below).
def predict_infection(model, inputs, labels, batch_size=128, device=None):
    """Classify cell patches with `model`.

    Args:
        model: classifier whose output has the class scores along axis 1.
        inputs: list of cell patches (as produced by `image_to_training_data`).
        labels: list of labels aligned with `inputs` (-1 = unknown).
        batch_size: number of patches per forward pass.
        device: device to run the prediction on; defaults to the device of the
            model's parameters, so the inputs always end up next to the model.

    Returns:
        (y_pred, y_true): 1D numpy arrays with the predicted class and the given label
        per patch, in loader order (the same as `inputs`, assuming the loader does
        not shuffle — the visualization below relies on this).
    """
    if device is None:
        device = next(model.parameters()).device

    loader = default_classification_loader(
        inputs, labels, batch_size=batch_size, image_shape=(64, 64),
    )
    y_pred, y_true = [], []

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            pred = model(x).cpu().numpy()
            class_pred = np.argmax(pred, axis=1)
            y_pred.append(class_pred)
            # reshape(-1) instead of squeeze(): squeeze turns a batch with a single sample
            # into a 0-d array, which np.concatenate cannot handle.
            y_true.append(y.numpy().reshape(-1))

    # No cells at all: return empty arrays instead of failing in np.concatenate.
    if not y_pred:
        return np.array([], dtype="int64"), np.array([], dtype="int64")

    y_pred = np.concatenate(y_pred)
    y_true = np.concatenate(y_true)
    return y_pred, y_true

In [None]:
# Predict the infection state of every cell in the first test image; only the
# predictions (first return value) are needed for the visualization.
infection_predictions = predict_infection(model, classification_inputs[0], classification_labels[0])[0]

In [None]:
# Re-load the raw data, the infection annotation and the predicted cell segmentation
# of the first test image.
with h5py.File(test_images[0], "r") as f:
    marker, nucleus_image, infected_labels = (
        f[key][:] for key in ("raw/marker/s0", "raw/nuclei/s0", "labels/infected/nuclei/s0")
    )

with h5py.File(test_predictions[0], "r") as f:
    cells = f["segmentations/cells/watershed_based"][:]

In [None]:
# Visualize the predictions in napari: one point per cell (at its centroid), colored by the
# predicted class. The points follow the regionprops order, which is the order in which
# image_to_training_data extracted the classifier inputs.
props = regionprops(cells)
points = [region.centroid for region in props]
# Class 0 is "infected" and class 1 is "not-infected" (label mapping of extract_labels_for_cells).
infected_points = [("not-infected" if pred != 0 else "infected") for pred in infection_predictions]

# The image layers are added with their variable names, which napari uses as layer names.
viewer = napari.Viewer()
viewer.add_image(marker, colormap="red", blending="additive")
viewer.add_image(nucleus_image, colormap="blue", blending="additive")
point_layer = viewer.add_points(
    points,
    properties={"infected": infected_points},
    face_color="infected",
    face_color_cycle=["orange", "cyan"],
)
point_layer.face_color_mode = "cycle"

### 3. Prediction and Evaluation for the Test Set

Run prediction for all test images and evaluate the accuracy of the result.

In [None]:
from sklearn.metrics import accuracy_score
from tqdm import tqdm

In [None]:
# Run the classifier on every test image, then merge the per-image results into flat arrays.
per_image_results = [
    predict_infection(model, inputs, labels)
    for inputs, labels in tqdm(zip(classification_inputs, classification_labels), total=len(classification_inputs))
]
y_pred = np.concatenate([pred for pred, _ in per_image_results])
y_true = np.concatenate([true for _, true in per_image_results])

In [None]:
# Drop the cells whose label is -1 (no infection annotation could be assigned), both
# from the labels and from the predictions.
keep = y_true != -1
y_true = y_true[keep]
y_pred = y_pred[keep]

In [None]:
# Fraction of cells whose predicted class matches the label.
accuracy = accuracy_score(y_true=y_true, y_pred=y_pred)
print("The overall accuracy is:", accuracy)

### Exercises

- If you have trained any other models in the previous notebook then evaluate them as well and compare the performance between the different models.
- Use other metrics from [sklearn.metrics](https://scikit-learn.org/stable/modules/model_evaluation.html) to evaluate other aspects of the results. In particular, check whether there are differences between precision and recall, and think about what this implies experimentally.
- Check if there are any systematic differences in the scores between the different test images. If yes, check the corresponding image data and see if you can find a reason for this visually.