# Malaria cells complete example

**Note:** to use this notebook in Google Colab, create a new cell with
the following lines and run it.

```shell
!git clone https://gitlab.in2p3.fr/jbarnier/ateliers_deep_learning.git
%cd ateliers_deep_learning
!pip install .
```


In [None]:
import random
import time
from datetime import timedelta
from pathlib import Path

import matplotlib.pyplot as plt
import plotnine as pn
import polars as pl
import torch
from PIL import Image
from sklearn.metrics import ConfusionMatrixDisplay
from torch import nn
from torch.utils.data import DataLoader, Dataset, random_split
from torchmetrics import MetricCollection
from torchmetrics.classification import (
    MulticlassAccuracy,
    MulticlassConfusionMatrix,
    MulticlassPrecision,
    MulticlassRecall,
)
from torchvision.transforms import v2


In this notebook we will implement a complete training process using the
[Malaria cells
image](https://www.kaggle.com/datasets/iarunava/cell-images-for-detecting-malaria)
dataset, which has been built from the [Malaria screener
project](https://lhncbc.nlm.nih.gov/LHC-research/LHC-projects/image-processing/malaria-project.html).
It will be quite close to the previous sign language notebook, the main
difference being that input data is provided as image files instead of a
precomputed data frame.

The dataset contains about 27,000 cell images, half of which are
parasitized (infected by malaria). Images have been reduced to 28x28
pixels and converted to grayscale. The image files are available in the
`data/malaria_cells/parasitized` and `data/malaria_cells/uninfected`
directories.
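Since each class lives in its own directory, a quick way to check the class balance is to count the files in each one. The sketch below is only an illustration: `count_images` is a hypothetical helper (not part of the workshop code), and it assumes every image has a `.jpg` extension.

```python
from pathlib import Path


def count_images(root, pattern="*.jpg"):
    """Hypothetical helper: number of files matching `pattern` in each subdirectory of `root`."""
    root = Path(root)
    return {
        subdir.name: sum(1 for _ in subdir.glob(pattern))
        for subdir in sorted(root.iterdir())
        if subdir.is_dir()
    }


# Usage, assuming the directory layout described above:
# print(count_images("data/malaria_cells"))
```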

We can display a randomly chosen image from each class.


In [None]:
uninfected_imgs = list(Path("data/malaria_cells/uninfected").glob("*.jpg"))
parasitized_imgs = list(Path("data/malaria_cells/parasitized").glob("*.jpg"))

u_img = random.choice(uninfected_imgs)
p_img = random.choice(parasitized_imgs)
fig, ax = plt.subplots(ncols=2)
ax[0].imshow(Image.open(u_img), cmap="gray")
ax[0].set_title(f"{u_img.stem} - Uninfected")
ax[1].imshow(Image.open(p_img), cmap="gray")
ax[1].set_title(f"{p_img.stem} - Parasitized")
plt.show()


## Datasets and dataloaders

The first thing to do is to handle data loading and mini-batch
creation. Here the input data are image files, and their labels are
implicitly given by their directory (`parasitized` or `uninfected`).

To manage our input data we create a `CellsDataset` class. Its
constructor loads all image files, converts them to tensors and scales
their pixel values to the [0, 1] range using `torchvision`. Image tensors
are stored in an `images` dictionary with image filenames (without
extension) as keys, and the labels are stored in a corresponding
`labels` dictionary: 0 for uninfected cells and 1 for parasitized ones.


In [None]:
def process_image(img_file):
    """
    Load and process an image file.

    Parameters
    ----------
    img_file : str or pathlib.Path
        Path to the image file to be processed.

    Returns
    -------
    torch.Tensor
        Processed image as a float32 tensor of shape (1, height, width),
        with pixel values scaled to [0, 1].
    """
    # Define transformation pipeline: convert to a torchvision image tensor
    # (uint8, channels first), then to float32 scaled to [0, 1]
    transform = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
    # Read image (closing the file afterwards), force a single grayscale
    # channel, and apply the pipeline
    with Image.open(img_file) as img:
        return transform(img.convert("L"))


class CellsDataset(Dataset):
    def __init__(self, uninfected_path, parasitized_path) -> None:
        u_imgs = Path(uninfected_path).glob("*.jpg")
        p_imgs = Path(parasitized_path).glob("*.jpg")

        # Initialize images and labels dictionaries
        self.images = {}
        self.labels = {}

        # Process uninfected images (label 0)
        for img_file in u_imgs:
            filename = img_file.stem
            img = process_image(img_file)
            self.images[filename] = img
            self.labels[filename] = torch.tensor(0, dtype=torch.long)

        # Process parasitized images (label 1)
        for img_file in p_imgs:
            filename = img_file.stem
            img = process_image(img_file)
            self.images[filename] = img
            self.labels[filename] = torch.tensor(1, dtype=torch.long)

        # Store the keys once: rebuilding the list at every __getitem__ call
        # would be very slow with tens of thousands of images
        self.keys = list(self.labels.keys())

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, index):
        key = self.keys[index]
        # Returns the (image, label) pair of the given index
        return self.images[key], self.labels[key]

We use our class to create a new dataset object.


In [None]:
dataset = CellsDataset(
    uninfected_path="data/malaria_cells/uninfected",
    parasitized_path="data/malaria_cells/parasitized",
)

We then split this dataset into training, validation, and test sets
(70%, 20%, and 10% of the images, respectively).


In [None]:
train_data, valid_data, test_data = random_split(dataset, [0.7, 0.2, 0.1])
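This split is different each time the notebook runs. If a reproducible split is needed, `random_split` also accepts a `generator` argument. The sketch below shows this on a toy `TensorDataset` standing in for `CellsDataset`; the seed value 42 is arbitrary, and passing fractions instead of lengths assumes a recent PyTorch version.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Hypothetical toy dataset of 10 items standing in for CellsDataset
toy = TensorDataset(torch.arange(10))

# Fractions are converted to integer lengths (here 7, 2 and 1);
# a seeded generator makes the shuffling reproducible between runs
generator = torch.Generator().manual_seed(42)
train, valid, test = random_split(toy, [0.7, 0.2, 0.1], generator=generator)
print(len(train), len(valid), len(test))  # 7 2 1
```

Each returned object is a `Subset` whose `indices` attribute lists the positions of its items in the original dataset, so the three subsets never overlap.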

We can confirm that our data is in good shape by plotting an image and
its label at a given index.


In [None]:
def show_image(data, index):
    img, label = data[index]
    plt.imshow(img.reshape(28, 28), cmap="gray")
    plt.title(f"Label: {label.item()}")
    plt.show()


# Display a random image (randrange excludes its upper bound, unlike
# randint, so the index is always valid)
index = random.randrange(len(train_data))
show_image(train_data, index)

Finally, we define a small function which will create three data loaders
for our training, validation and test datasets for a given batch size.
The training set is shuffled at each epoch.


In [None]:
def generate_loaders(batch_size):
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)
    return train_loader, valid_loader, test_loader
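A `DataLoader` stacks individual `(image, label)` pairs into batches along a new first dimension. The sketch below uses hypothetical random tensors (300 single-channel 28x28 images) instead of the real dataset to show the resulting shapes with `batch_size=128`; the last batch simply holds the remaining samples.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in data: 300 single-channel 28x28 images with 0/1 labels
images = torch.rand(300, 1, 28, 28)
labels = torch.randint(0, 2, (300,))
loader = DataLoader(TensorDataset(images, labels), batch_size=128, shuffle=True)

for batch_images, batch_labels in loader:
    print(batch_images.shape, batch_labels.shape)
# torch.Size([128, 1, 28, 28]) torch.Size([128])
# torch.Size([128, 1, 28, 28]) torch.Size([128])
# torch.Size([44, 1, 28, 28]) torch.Size([44])
```

Because the last batch is smaller, dividing the sum of batch losses by `len(loader)` (the number of batches), as the training code below does, gives a mean of batch means that can differ slightly from the exact per-sample mean; the difference is usually negligible.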

## Training code

We now detect what device the training process will run on. It will be
either `"cuda"` if a GPU is available, or `"cpu"` otherwise.


In [None]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on: {DEVICE}")

There are many ways to organize the training process and its code in
PyTorch; the one presented here is just one of them. Feel free to check
the many tutorials available online to see which kind of code
organization suits you best.

Here we organize the different functions and attributes in a custom
`Training` class:

-   The class constructor takes the model and some training
    hyperparameters as arguments and stores them as object attributes.
    It also defines and stores the loss function, the optimizer and the
    metrics used to evaluate the results.
-   The `train_loop` method runs the training for one epoch using a
    training dataloader.
-   The `eval_loop` method computes loss and metrics on a given loader
    in evaluation mode (without gradient computation or
    backpropagation). It is used to evaluate the model on the validation
    dataset during training (or on the test dataset after training).
-   The main `train` method runs `train_loop` on the training loader for
    a given number of epochs. At the end of each epoch it also calls
    `eval_loop` on the validation loader.
-   We use `torchmetrics` to accumulate the metrics across batches
    easily.
-   During training, the best model (the one with the lowest validation
    loss) is saved to disk with the `save_model` method. At the end of
    training, the best model is loaded back into the `model` attribute
    with `load_best_model`, so that it can easily be used in the rest of
    the notebook.


In [None]:
class Training:
    def __init__(
        self, model: nn.Module, device: str, learning_rate: float, model_path: str
    ) -> None:
        """
        Training process class.

        Args:
            model (nn.Module): model to be trained.
            device (str): device to be used ("cpu" or "cuda").
            learning_rate (float): learning rate.
            model_path (str): path where the best model will be saved during training.
        """
        self.device = device
        # Save the model as attribute and send it to device
        self.model = model.to(self.device)
        # Loss function
        self.loss_fn = nn.CrossEntropyLoss()

        # Optimizer
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate)  # type: ignore
        self.n_classes = 2

        # Directory where the best model is saved (created if needed)
        best_model_dir = Path(model_path)
        best_model_dir.mkdir(parents=True, exist_ok=True)
        # File to save best model
        self.best_model_path = best_model_dir / "model_best.pt"

        # Attribute to store loss values during training process
        self.train_history = None

        # Metrics
        self.metrics = MetricCollection(
            {
                # Global accuracy
                "accuracy": MulticlassAccuracy(
                    num_classes=self.n_classes, average="micro"
                ),
                # Confusion matrix
                "confusion_matrix": MulticlassConfusionMatrix(
                    num_classes=self.n_classes
                ),
                # Precision for each label
                "precision": MulticlassPrecision(
                    num_classes=self.n_classes, average=None
                ),
                # Recall for each label
                "recall": MulticlassRecall(num_classes=self.n_classes, average=None),
            }
        )

    def train_loop(self, loader: DataLoader) -> float:
        """
        Training loop for one epoch.

        Args:
            loader (DataLoader): loader to use for the training loop.

        Returns:
            float: average loss value.
        """
        # Set the model to training mode
        self.model.train()
        loss = 0
        # Iterate through all batches
        for input, target in loader:
            input = input.to(self.device)
            target = target.to(self.device)
            # Reset gradients
            self.optimizer.zero_grad()
            # Forward pass: compute predicted values and loss
            pred = self.model(input)
            batch_loss = self.loss_fn(pred, target)
            # Backpropagation
            batch_loss.backward()
            self.optimizer.step()
            # Store loss
            loss_value = batch_loss.item()
            loss += loss_value
        # Compute average loss
        loss /= len(loader)  # type: ignore
        return loss

    @torch.no_grad()
    def eval_loop(self, loader: DataLoader) -> tuple:
        """
        Compute loss and metrics (accuracy, confusion matrix, precision, recall)
        for the current model and the given loader.

        Args:
            loader (DataLoader): loader to use to evaluate the model.

        Returns:
            tuple: tuple of loss value, metrics dictionary.
        """
        eval_loss = 0
        # Set model to evaluation mode
        self.model.eval()
        # Iterate through batches
        for input, target in loader:
            input = input.to(self.device)
            target = target.to(self.device)
            # Compute batch loss
            pred = self.model(input)
            batch_loss = self.loss_fn(pred, target).item()
            eval_loss += batch_loss
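            # Update metrics (metric states live on the CPU, so move predictions
            # and targets there to avoid a device mismatch when running on GPU)
            self.metrics.update(pred.cpu(), target.cpu())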
        # Compute average loss
        eval_loss /= len(loader)
        # Compute and reset metrics
        metrics = self.metrics.compute()
        self.metrics.reset()

        return eval_loss, metrics  # type: ignore

    def train(
        self, train_loader: DataLoader, valid_loader: DataLoader, epochs: int
    ) -> None:
        """
        Main training method. Run the training loop for a given number of epochs.

        Args:
            train_loader (DataLoader): train data loader.
            valid_loader (DataLoader): validation data loader.
            epochs (int): number of epochs to run.
        """
        # Reset train attributes
        self.train_history = []
        best_loss = float("inf")
        best_epoch = None

        # Record start time
        start_time = time.time()

        print("Epoch   Train loss   Valid loss    Valid acc    Best")
        print("----------------------------------------------------")

        # For each epoch
        for epoch in range(epochs):
            # Run the training loop on train data
            loss = self.train_loop(train_loader)
            # Evaluate metrics on validation data
            valid_loss, metrics = self.eval_loop(valid_loader)
            # Check if the current model is the best one
            best = valid_loss < best_loss
            if best:
                best_epoch = epoch + 1
                best_loss = valid_loss
                self.save_model()

            # Display results
            valid_acc = metrics["accuracy"].item()
            print(
                f"{epoch + 1:5}   {loss:10.3f}   {valid_loss:10.3f}   {valid_acc:10.3f}"
                f"    {'*' if best else '':>4}"
            )
            # Store losses in train history (epochs numbered from 1, as in the
            # printed table)
            self.train_history.append(
                {"epoch": epoch + 1, "type": "train", "loss": loss}
            )
            self.train_history.append(
                {"epoch": epoch + 1, "type": "valid", "loss": valid_loss}
            )

        # Display training time
        training_time = time.time() - start_time
        print(f"\nTraining time: {timedelta(seconds=round(training_time))}")

        # Load and evaluate best model and display results
        self.load_best_model()
        valid_loss, valid_metrics = self.eval_loop(valid_loader)
        valid_acc = valid_metrics["accuracy"].item()
        print(
            f"Best model at epoch {best_epoch}, valid_loss: {valid_loss:5.3f},"
            f" valid_accuracy: {valid_acc:5.3f}"
        )

    def save_model(self) -> None:
        """
        Save the current model state to the best model path.
        """
        torch.save(self.model.state_dict(), self.best_model_path)

    def load_best_model(self) -> None:
        """
        Restore the current model from the best model path.
        """
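        # The checkpoint only contains tensors, so weights_only=True is enough
        # (and safer); map_location loads it directly onto the current device
        state_dict = torch.load(
            self.best_model_path, map_location=self.device, weights_only=True
        )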
        self.model.load_state_dict(state_dict)
        self.model.eval()
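`save_model` and `load_best_model` rely on PyTorch's `state_dict` mechanism: only the parameter tensors are written to disk, so an instance of the same architecture must exist before they can be loaded back. A minimal round-trip sketch, using a hypothetical `nn.Linear` model as a stand-in and a temporary directory:

```python
import tempfile
from pathlib import Path

import torch
from torch import nn

# Hypothetical tiny model standing in for the network being trained
model = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "model_best.pt"
    # Save only the parameters (state_dict), as save_model does
    torch.save(model.state_dict(), path)

    # Restore them into a fresh instance of the same architecture;
    # weights_only=True restricts unpickling to tensors and basic containers
    restored = nn.Linear(4, 2)
    restored.load_state_dict(torch.load(path, weights_only=True))

x = torch.randn(3, 4)
print(torch.equal(model(x), restored(x)))  # True
```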


## Dense network training

For this example we will use a simple dense network. It first flattens
each 1x28x28 image into a single tensor of 784 values, then passes it
through three linear layers (with ReLU activations in between) that
bring it down to 2 values, _i.e._ one per class we want to predict.


In [None]:
class DenseNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 2),
        )

    def forward(self, x):
        return self.network(x)
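To check the shapes described above and get an idea of the model size, we can push a dummy batch through the same stack of layers and count its parameters. This standalone sketch rebuilds the layers rather than reusing `DenseNetwork`, and the batch of 4 random images is hypothetical:

```python
import torch
from torch import nn

# Same layer stack as DenseNetwork, rebuilt here so the snippet is self-contained
network = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 512),
    nn.ReLU(),
    nn.Linear(512, 128),
    nn.ReLU(),
    nn.Linear(128, 2),
)

x = torch.rand(4, 1, 28, 28)  # hypothetical batch of 4 grayscale images
print(network(x).shape)  # torch.Size([4, 2]): one score per class

# Each Linear(n_in, n_out) layer has n_in * n_out weights plus n_out biases
n_params = sum(p.numel() for p in network.parameters())
print(n_params)  # 467842
```

The outputs are raw scores (logits): no softmax layer is needed because `nn.CrossEntropyLoss` applies log-softmax internally. Most parameters (about 400,000) sit in the first layer.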


We define our training hyperparameters such as batch size and learning
rate, and we create a new `Training` object:


In [None]:
# Training hyperparameters
batch_size = 128
learning_rate = 0.001
epochs = 20
best_model_path = "out_malaria_cells_dense"

# Generate DataLoaders with the given batch size
train_loader, valid_loader, test_loader = generate_loaders(batch_size=batch_size)
# Create a Training object
training = Training(
    model=DenseNetwork(),
    device=DEVICE,
    learning_rate=learning_rate,
    model_path=best_model_path,
)

We can then launch our network training process:


In [None]:
# Run training
training.train(train_loader, valid_loader, epochs=epochs)

### Results

Once training has finished, we can visualize and evaluate the model
results.

First we can plot the loss values at each epoch:


In [None]:
(
    pn.ggplot(
        pl.DataFrame(training.train_history), pn.aes(x="epoch", y="loss", color="type")
    )
    + pn.geom_line()
    + pn.scale_y_continuous(limits=(0, None))  # type: ignore
    + pn.labs(color="")
)

We may also want to evaluate our model on the test set:


In [None]:
loss, metrics = training.eval_loop(test_loader)
print(f"Test loss: {loss:.3f}")
print(f"Test accuracy: {metrics['accuracy']:.3f}")

We can plot the confusion matrix for our results:


In [None]:
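cm = metrics["confusion_matrix"].cpu().numpy()
# Rows are true labels, columns are predicted labels
ConfusionMatrixDisplay(cm, display_labels=["uninfected", "parasitized"]).plot()
plt.show()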

We can display precision and recall values (in %) for each label:


In [None]:
d = pl.DataFrame(
    {
        "label": [0, 1],
        "precision": (metrics["precision"] * 100).cpu().numpy(),
        "recall": (metrics["recall"] * 100).cpu().numpy(),
    }
)
d.head()
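Precision and recall can both be read off the confusion matrix: for label `k`, precision is the diagonal cell divided by the sum of column `k` (everything predicted as `k`), and recall is the diagonal cell divided by the sum of row `k` (everything that truly is `k`). The sketch below does this by hand on a hypothetical matrix with the same orientation as `MulticlassConfusionMatrix` and scikit-learn (rows are true labels); `per_class_precision_recall` is an illustrative helper, not a library function.

```python
def per_class_precision_recall(cm):
    """Return a list of (precision, recall) pairs, one per class.

    `cm[i][j]` counts samples whose true label is i and predicted label is j.
    """
    n_classes = len(cm)
    results = []
    for k in range(n_classes):
        true_positives = cm[k][k]
        predicted_k = sum(cm[i][k] for i in range(n_classes))  # column k: predicted as k
        actual_k = sum(cm[k])  # row k: truly k
        results.append((true_positives / predicted_k, true_positives / actual_k))
    return results


# Hypothetical confusion matrix (rows: true label, columns: predicted label)
toy_cm = [
    [50, 10],  # 60 uninfected cells: 50 correctly classified, 10 predicted parasitized
    [5, 35],   # 40 parasitized cells: 5 predicted uninfected, 35 correctly classified
]
for label, (precision, recall) in enumerate(per_class_precision_recall(toy_cm)):
    print(f"label {label}: precision={precision:.3f}, recall={recall:.3f}")
# label 0: precision=0.909, recall=0.833
# label 1: precision=0.778, recall=0.875
```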

## Exercise: CNN training

In this notebook we used a dense feed-forward neural network to classify
images, but other network architectures are much better suited to this
task. In particular, Convolutional Neural Networks (CNNs) are very good
at detecting local patterns in images, which can then be used for tasks
such as classification.

Here is an example of a CNN architecture that could be used for our
problem:

```py
self.network = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, stride=1),
    nn.ReLU(),
    nn.Conv2d(in_channels=8, out_channels=4, kernel_size=3, stride=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(576, 128),
    nn.ReLU(),
    nn.Linear(128, 2),
)
```
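The `576` input features of the first `Linear` layer follow from the output-size rule for convolution and pooling layers without padding: output size = (input size - kernel size) // stride + 1. The sketch below applies this rule step by step and cross-checks it with a dummy forward pass through the convolutional part only; `conv_output_size` is a hypothetical helper written for this illustration.

```python
import torch
from torch import nn


def conv_output_size(size, kernel_size, stride=1, padding=0):
    """Spatial output size of a Conv2d / MaxPool2d layer (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel_size) // stride + 1


size = 28
size = conv_output_size(size, kernel_size=3)            # first Conv2d: 28 -> 26
size = conv_output_size(size, kernel_size=3)            # second Conv2d: 26 -> 24
size = conv_output_size(size, kernel_size=2, stride=2)  # MaxPool2d: 24 -> 12
print(4 * size * size)  # 4 channels * 12 * 12 = 576

# Cross-check with a dummy forward pass through the convolutional part only
features = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, stride=1),
    nn.ReLU(),
    nn.Conv2d(in_channels=8, out_channels=4, kernel_size=3, stride=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
)
print(features(torch.zeros(1, 1, 28, 28)).shape)  # torch.Size([1, 576])
```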

Create a new `CNNCellsNetwork` class that implements this architecture,
and train a new model from this class for 20 epochs. After that,
evaluate the trained model on the test dataset.
