## Train Infection Classifier

In [None]:
import os
from glob import glob

import h5py
import napari
import numpy as np

from skimage.measure import regionprops

In [None]:
# Root folder of the dataset; the train / validation splits live in
# sub-folders directly below it.
data_folder = "../data"
train_data_folder = f"{data_folder}/train"
val_data_folder = f"{data_folder}/val"

### 1. Inspect Training Data

In [None]:
# Read one example file from the training folder: the two raw channels,
# the cell label image and the infection annotation.
image_path = os.path.join(train_data_folder, "gt_image_000.h5")
with h5py.File(image_path, mode="r") as f:
    infected_labels = f["labels/infected/nuclei/s0"][:]
    cells = f["labels/cells/s0"][:]
    nucleus_image = f["raw/nuclei/s0"][:]
    marker = f["raw/marker/s0"][:]

In [None]:
# Open a napari viewer for visual inspection of the example image.
viewer = napari.Viewer()
# Overlay both raw channels with additive blending: marker in red, nuclei in blue.
viewer.add_image(marker, colormap="red", blending="additive")
viewer.add_image(nucleus_image, colormap="blue", blending="additive")
# Add the cell label image and the infection annotation as label layers on top.
viewer.add_labels(cells)
viewer.add_labels(infected_labels)

In [None]:
def extract_labels_for_cells(cells, infected_labels):
    """Derive one infection class per segmented cell.

    For every non-zero id in ``cells`` the non-zero values of
    ``infected_labels`` that fall inside that cell are collected. The
    annotation values are 1 (infected) and 2 (not infected); they are mapped
    to the zero-based class indices 0 (infected) and 1 (not infected),
    because pytorch / torch_em expects zero-based class indices. If a cell
    contains both annotation values, the one covering more pixels wins
    (a tie is resolved as "not infected").

    Args:
        cells: Integer label image of the cells; 0 is treated as background.
        infected_labels: Annotation image that is indexed with a boolean mask
            derived from ``cells``; 0 means "no annotation".

    Returns:
        Dict mapping each cell id to 0 (infected), 1 (not infected) or
        ``None`` if the cell contains no annotation at all.

    Raises:
        AssertionError: If a cell contains annotation values other than 1 or 2.
    """
    cell_ids = np.unique(cells)
    # Remove the background id explicitly instead of dropping the first
    # entry: an image without any background pixel would otherwise lose
    # its first real cell.
    cell_ids = cell_ids[cell_ids != 0]

    cell_labels = {}
    for cell_id in cell_ids:
        annotation = infected_labels[cells == cell_id]
        annotation = annotation[annotation != 0]
        if annotation.size == 0:
            # No annotated pixel inside this cell.
            cell_labels[cell_id] = None
            continue

        label_ids, counts = np.unique(annotation, return_counts=True)
        if len(label_ids) == 1:
            assert label_ids[0] in (1, 2)
            label = int(label_ids[0]) - 1
        else:
            # Both annotation values are present; np.unique returns them
            # sorted, so counts[0] belongs to 1 (infected) and counts[1]
            # to 2 (not infected). Pick the more prevalent one.
            assert label_ids.tolist() == [1, 2], str(label_ids)
            label = 0 if counts[0] > counts[1] else 1
        cell_labels[cell_id] = label
    return cell_labels

In [None]:
# Visualize the per-cell infection label as one colored point per cell.
cell_infection_labels = extract_labels_for_cells(cells, infected_labels)
props = regionprops(cells)

# Only show cells that actually carry an annotation: cells without one have
# the label None and must not be displayed as "not-infected". The label is
# looked up via the region's own id so that points and labels cannot get
# out of sync.
labeled_props = [prop for prop in props if cell_infection_labels.get(prop.label) is not None]
points = [prop.centroid for prop in labeled_props]
infected_points = [
    "infected" if cell_infection_labels[prop.label] == 0 else "not-infected" for prop in labeled_props
]

viewer = napari.Viewer()
viewer.add_image(marker, colormap="red", blending="additive")
viewer.add_image(nucleus_image, colormap="blue", blending="additive")
point_layer = viewer.add_points(
    points, properties={"infected": infected_points}, face_color="infected", face_color_cycle=["cyan", "orange"],
)
point_layer.face_color_mode = "cycle"

In [None]:
def image_to_training_data(cells, marker, nucleus_image, infected_labels, apply_cell_mask=True):
    """Cut one training sample per annotated cell out of an image.

    Each sample is the bounding-box crop of a cell, stacked as three
    float32 channels: nucleus image, marker image and the binary cell mask.
    Cells for which ``extract_labels_for_cells`` yields ``None`` are skipped.

    Args:
        cells: Label image of the cells.
        marker: Marker channel, cropped with the same bounding box as ``cells``.
        nucleus_image: Nucleus channel, cropped the same way.
        infected_labels: Annotation image passed on to ``extract_labels_for_cells``.
        apply_cell_mask: If True, pixels of the two raw channels that lie
            outside of the cell are set to zero.

    Returns:
        Tuple ``(samples, labels)`` of two lists with matching order.
    """
    labels_per_cell = extract_labels_for_cells(cells, infected_labels)

    samples, sample_labels = [], []
    for region in regionprops(cells):
        this_label = labels_per_cell[region.label]
        if this_label is None:
            continue

        # Bounding box as a pair of slices (rows, columns).
        box = region.bbox
        window = (slice(box[0], box[2]), slice(box[1], box[3]))

        inside_cell = cells[window] == region.label
        # astype returns a copy, so masking below does not touch the inputs.
        channels = [raw[window].astype("float32") for raw in (nucleus_image, marker)]
        if apply_cell_mask:
            outside_cell = ~inside_cell
            for channel in channels:
                channel[outside_cell] = 0.0

        channels.append(inside_cell.astype("float32"))
        samples.append(np.stack(channels))
        sample_labels.append(this_label)

    return samples, sample_labels

In [None]:
train_image_data, train_labels = image_to_training_data(cells, marker, nucleus_image, infected_labels)

In [None]:
# Look at five training samples (indices 25 to 29); one viewer is opened per sample.
for i in range(25, 30):
    im_data = train_image_data[i]
    label = train_labels[i]
    viewer = napari.Viewer()
    # Channel order of a sample: 0 = nucleus, 1 = marker, 2 = cell mask.
    viewer.add_image(im_data[0], name="nucleus-channel", colormap="blue", blending="additive")   
    viewer.add_image(im_data[1], name="marker-channel", colormap="red", blending="additive")
    # The mask is stored as float32 in the sample, so cast it to an integer type for the labels layer.
    viewer.add_labels(im_data[2].astype("uint8"), name="cell-mask")
    # Show the class label of the sample in the window title.
    viewer.title = f"Label: {label}"

### 2. Prepare Training Data

In [None]:
from tqdm import tqdm

def prepare_classification_data(root):
    """Collect the per-cell training samples of all h5 files in a folder.

    Every ``*.h5`` file directly inside ``root`` is read (in sorted order)
    and converted with ``image_to_training_data``.

    Args:
        root: Folder that contains the h5 files.

    Returns:
        Tuple ``(image_data, labels)`` of two equally long lists.
    """
    h5_paths = sorted(glob(os.path.join(root, "*.h5")))

    all_samples, all_labels = [], []
    for h5_path in tqdm(h5_paths, desc="Prepare classification data"):
        with h5py.File(h5_path, "r") as f:
            infected_labels = f["labels/infected/nuclei/s0"][:]
            cells = f["labels/cells/s0"][:]
            nucleus_image = f["raw/nuclei/s0"][:]
            marker = f["raw/marker/s0"][:]

        samples, sample_labels = image_to_training_data(cells, marker, nucleus_image, infected_labels)
        all_samples += samples
        all_labels += sample_labels

    # Sanity check: exactly one label per sample.
    assert len(all_samples) == len(all_labels)
    return all_samples, all_labels

In [None]:
# Build the sample lists for the training and the validation split
# and report how many samples each of them contains.
train_data, train_labels = prepare_classification_data(train_data_folder)
print(f"We have {len(train_data)} samples for training")

val_data, val_labels = prepare_classification_data(val_data_folder)
print(f"We have {len(val_data)} samples for validation")

### 3. Train the Infection Classifier

In [None]:
from torch_em.classification import default_classification_loader, default_classification_trainer
from torchvision.models.resnet import resnet34

In [None]:
# Find the mean spatial shape over all training and validation samples
# (the leading channel axis is excluded) and print it.
all_samples = train_data + val_data
shapes = np.stack([np.asarray(sample.shape[1:]) for sample in all_samples])
mean_shape = shapes.mean(axis=0)
print("Mean image shape:", mean_shape)

In [None]:
batch_size = 32
image_shape = (64, 64)

# The training and validation loader share the same batch size and image shape.
loader_settings = {"batch_size": batch_size, "image_shape": image_shape}
train_loader = default_classification_loader(train_data, train_labels, **loader_settings)
val_loader = default_classification_loader(val_data, val_labels, **loader_settings)

In [None]:
# Two output classes: 0 = infected, 1 = not infected.
model = resnet34(num_classes=2)

trainer = default_classification_trainer(
    name="infection-classifier",
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    # TODO loss and metric
)

In [None]:
# Training budget handed to the trainer
# (presumably the number of iterations — confirm against the torch_em docs).
n_iterations = 50000
trainer.fit(n_iterations)

### 4. Export

### Exercises

**What's next?**