# Demo : using UNET to segment chest X-ray images (JSRT segmentation02 dataset)

In this notebook, **we're going to build a simple UNET model** to segment chest X-ray images using 247 images from the "Standard Digital Image Database" of the Japanese Society of Radiological Technology ([JSRT](http://imgcom.jsrt.or.jp/minijsrtdb/)).

It is a 4-label segmentation problem in which each pixel belongs to one of four regions. Each image is associated with a ground-truth mask containing 255 for pixels in the lung field, 85 for pixels in the heart, 170 for pixels outside the lung field, and 0 for pixels outside the body.

Note: the label data has not been medically supervised, so there is no medical basis for the definition and delineation of the lung area.
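
Segmentation losses such as cross-entropy usually expect class indices rather than gray levels, so the four mask values listed above have to be remapped at some point. The snippet below is a minimal sketch of one way to do it with NumPy; the function name `mask_to_class_indices` and the class order are assumptions made for illustration, not necessarily what `JsrtSegmentation02Dataset` does internally.

```python
import numpy as np

# Gray levels used in the ground-truth masks, as listed above.
# The class order (0..3) is an arbitrary choice made for this sketch.
MASK_GRAY_LEVELS = (0, 85, 170, 255)


def mask_to_class_indices(mask: np.ndarray) -> np.ndarray:
    """Map a gray-level mask (values 0/85/170/255) to class indices 0..3."""
    # Lookup table indexed by gray level; -1 flags unexpected values.
    lut = np.full(256, -1, dtype=np.int64)
    for class_index, gray_level in enumerate(MASK_GRAY_LEVELS):
        lut[gray_level] = class_index
    indices = lut[mask]
    if (indices < 0).any():
        raise ValueError("mask contains an unexpected gray level")
    return indices


# Tiny hand-made mask, one pixel per gray level.
demo_mask = np.array([[0, 85], [170, 255]], dtype=np.uint8)
class_map = mask_to_class_indices(demo_mask)
print(class_map)
```

A lookup table is used instead of an integer division by 85 so that a corrupted mask value raises an error rather than being silently mapped to a neighbouring class.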

We will:
* `instantiate the datasets` - during this step image and mask data will be downloaded if missing
* `train the model` - fit the model on the training images
* `test and evaluate` - test model prediction and compute segmentation scores

A custom data augmenter (`AugmentedSemanticSegmentationDataset`) generates additional samples from the existing ones
during the training process. It applies RANDOM:
- colorimetric transformations to the image ONLY
- spatial transformations to the image AND its associated mask
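
The reason for this asymmetry is that a spatial transformation moves pixels, so the mask must move with them, whereas a colorimetric transformation changes intensities, which would corrupt label values. The sketch below illustrates the idea with NumPy only; `augment_pair`, the horizontal flip and the brightness gain range are hypothetical choices and do not reproduce the actual behaviour of `SpatialAugmentationParams` or `ColorimetricAugmentationParams`.

```python
import numpy as np


def augment_pair(image, mask, rng):
    """Apply one random augmentation to an (image, mask) pair."""
    # Spatial transformation: the SAME flip is applied to image and mask
    # so that every pixel keeps its label.
    if rng.random() < 0.5:
        image = np.fliplr(image)
        mask = np.fliplr(mask)

    # Colorimetric transformation: applied to the image ONLY, because
    # changing mask values would corrupt the labels.
    gain = rng.uniform(0.8, 1.2)  # assumed brightness range
    image = np.clip(image * gain, 0.0, 1.0)
    return image, mask


rng = np.random.default_rng(seed=0)
image = rng.random((4, 4))              # fake grayscale image in [0, 1]
mask = rng.integers(0, 4, size=(4, 4))  # fake label map with 4 classes
aug_image, aug_mask = augment_pair(image, mask, rng)
```

Because a new random draw is made each time a sample is requested, the model rarely sees exactly the same image twice across epochs.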

This simple use case can be adapted to quickly check the implementation of a custom model architecture.


In [1]:
import logging
import os
from pathlib import Path 

import torch

from src.model.unet import UNetModel
from src.dataset.jsrt_dataset import JsrtSegmentation02Dataset
from src.dataset.data_augmentation.color_augmentation import ColorimetricAugmentationParams
from src.dataset.data_augmentation.spatial_augmentation import SpatialAugmentationParams
from src.dataset.data_augmentation.augmented_semantic_segmentation import (
    AugmentedSemanticSegmentationDataset,
)
from src.benchmark.train import train, HyperParameters, EarlyStoppingParams, TrainParameters


# algorithm hyperparameters
# -------------------------
# early stopping is deactivated in this run (None is passed to HyperParameters),
# so training runs for the full NB_EPOCHS epochs
NB_EPOCHS = 20
BATCH_SIZE = 8
LEARNING_RATE = 0.0005

# model hyperparameters
# ---------------------
# for the simple task at hand the base number (20) used to compute the number of feature maps
# in each model layer can be low. In the original paper this value is 64
BASE_FM_NUMBER = 20
IMG_NB_CHANNELS = 1 # grayscale images
NB_LABELS = 4       # pixels can belong to one of 4 labels
INIT_WEIGHTS_PATH = None

MODEL_PATH = "./square_model_dict.pth"
DATASET_FOLDER_PATH = Path("./jsrt") 

# deactivate PIL debug info polluting the output
pil_logger = logging.getLogger("PIL")
pil_logger.setLevel(logging.INFO)

logging.basicConfig(filename=None, filemode="w", level=logging.DEBUG)

# select device for computation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Selected device : {device}")


Selected device : cuda


In [2]:
""" Train the model on the JSRT dataset composed of chest X-ray images (4-label problem) """

# build model architecture
model = UNetModel(IMG_NB_CHANNELS, NB_LABELS, BASE_FM_NUMBER)

# create train and validation datasets
# Note that the first run may take longer (around 2 minutes) because
# the dataset needs to be downloaded
current_dir = Path(os.getcwd())
train_dataset = JsrtSegmentation02Dataset(DATASET_FOLDER_PATH, download=True, image_set="train")
validation_dataset = None

# Add data augmentation
spatial_params = SpatialAugmentationParams()
colorimetric_params = ColorimetricAugmentationParams()
train_dataset = AugmentedSemanticSegmentationDataset(
    train_dataset, spatial_params, colorimetric_params
)

if validation_dataset:
    validation_dataset = AugmentedSemanticSegmentationDataset(
        validation_dataset, spatial_params, colorimetric_params
    )

# set training hyperparameters
hyperparameters = HyperParameters(
    NB_EPOCHS, BATCH_SIZE, LEARNING_RATE, None  # early stopping off; e.g. EarlyStoppingParams(3, 0.001) to enable
)

# run the training process
train(
    model,
    train_dataset,
    validation_dataset,
    TrainParameters(hyperparameters, INIT_WEIGHTS_PATH, device),
    MODEL_PATH,
)


INFO:root:Early stopping DEACTIVATED


INFO:root:Init model weights using the default initialization strategy
INFO:root:initialize loss and optimizer for a model predicting 4 labels
INFO:root:selected loss function: CrossEntropyLoss()
INFO:root:Epoch [1/20],Loss: 1.3781, Val Loss: 0.0000
INFO:root:Epoch 0, Learning Rate: 0.000500
INFO:root:Epoch [2/20],Loss: 0.6482, Val Loss: 0.0000
INFO:root:Epoch 1, Learning Rate: 0.000500
INFO:root:Epoch [3/20],Loss: 0.3954, Val Loss: 0.0000
INFO:root:Epoch 2, Learning Rate: 0.000500
INFO:root:Epoch [4/20],Loss: 0.2851, Val Loss: 0.0000
INFO:root:Epoch 3, Learning Rate: 0.000500
INFO:root:Epoch [5/20],Loss: 0.2083, Val Loss: 0.0000
INFO:root:Epoch 4, Learning Rate: 0.000500
INFO:root:Epoch [6/20],Loss: 0.1801, Val Loss: 0.0000
INFO:root:Epoch 5, Learning Rate: 0.000500
INFO:root:Epoch [7/20],Loss: 0.1531, Val Loss: 0.0000
INFO:root:Epoch 6, Learning Rate: 0.000500
INFO:root:Epoch [8/20],Loss: 0.1440, Val Loss: 0.0000
INFO:root:Epoch 7, Learning Rate: 0.000500
INFO:root:Epoch [9/20],Loss:
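
Early stopping is deactivated in the run above, since `None` is passed in place of `EarlyStoppingParams(3, 0.001)`. For readers unfamiliar with the technique, here is a minimal pure-Python sketch of a common patience-based variant. It assumes the two arguments are a patience (number of epochs without improvement) and a minimum improvement threshold; this is a guess at their meaning, and the `EarlyStopper` class is hypothetical rather than part of `src.benchmark.train`. The loss values are made up for the demonstration.

```python
class EarlyStopper:
    """Hypothetical helper, not the project's EarlyStoppingParams."""

    def __init__(self, patience: int, min_delta: float) -> None:
        self.patience = patience      # epochs tolerated without improvement
        self.min_delta = min_delta    # smallest decrease counted as progress
        self.best_loss = float("inf")
        self.stale_epochs = 0

    def should_stop(self, loss: float) -> bool:
        """Record the latest loss and tell whether training should stop."""
        if loss < self.best_loss - self.min_delta:
            self.best_loss = loss
            self.stale_epochs = 0
        else:
            self.stale_epochs += 1
        return self.stale_epochs >= self.patience


stopper = EarlyStopper(patience=3, min_delta=0.001)
made_up_losses = [1.0, 0.5, 0.4, 0.4, 0.3995, 0.3999, 0.2]
stopped_at = None
for epoch, loss in enumerate(made_up_losses):
    if stopper.should_stop(loss):
        stopped_at = epoch
        break
print(f"stopped at epoch index {stopped_at}")
```

Early stopping is normally driven by the validation loss, so it only makes sense once a validation dataset is provided; in the run above `validation_dataset` is `None`.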

In [None]:
""" test the model on the jsrt dataset composed of Chest X-ray images (4-labels problem)  """
from src.benchmark.test import test, EvalMetrics, TestReport

# build model architecture and load weights generated during the training process
model = UNetModel(IMG_NB_CHANNELS, NB_LABELS, BASE_FM_NUMBER).to(device)
model.load_state_dict(torch.load(MODEL_PATH, map_location=device, weights_only=True))

# get the "test" dataset
current_dir = Path(os.getcwd())
test_dataset = JsrtSegmentation02Dataset(
    DATASET_FOLDER_PATH, download=True, image_set="test"
)

# Configure test report : micro-averaged AND per-label metrics
test_report = TestReport(
    current_dir.joinpath("report.json"),
    [
        EvalMetrics.ACCURACY,
        EvalMetrics.RECALL_PER_LABEL,
        EvalMetrics.PRECISION_PER_LABEL,
        EvalMetrics.F1_SCORE_PER_LABEL,
        EvalMetrics.IOU_PER_LABEL,
    ],
)

# run evaluation (with segmentation image display deactivated: verbose = 0)
test(
    device,
    model,
    test_dataset,
    test_report,
    verbose=0,
)
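
The per-label metrics requested in the report can all be derived from a confusion matrix. The sketch below shows one standard way to compute them with NumPy on a tiny hand-made example; `per_label_metrics` is a hypothetical helper written for illustration, and the way `src.benchmark.test` actually computes its scores may differ.

```python
import numpy as np


def per_label_metrics(y_true, y_pred, nb_labels):
    """Compute global accuracy and per-label precision/recall/F1/IoU."""
    # confusion[i, j] = number of pixels of true label i predicted as label j
    confusion = np.zeros((nb_labels, nb_labels), dtype=np.int64)
    np.add.at(confusion, (y_true.ravel(), y_pred.ravel()), 1)

    tp = np.diag(confusion).astype(float)  # true positives per label
    fp = confusion.sum(axis=0) - tp        # false positives per label
    fn = confusion.sum(axis=1) - tp        # false negatives per label

    def safe_div(num, den):
        # Return 0 where the denominator is 0 (label absent from the data).
        return np.divide(num, den, out=np.zeros_like(num), where=den > 0)

    return {
        "accuracy": tp.sum() / confusion.sum(),
        "precision": safe_div(tp, tp + fp),
        "recall": safe_div(tp, tp + fn),
        "f1": safe_div(2 * tp, 2 * tp + fp + fn),
        "iou": safe_div(tp, tp + fp + fn),
    }


# Toy 2x3 label maps with 4 possible labels.
y_true = np.array([[0, 0, 1], [1, 2, 3]])
y_pred = np.array([[0, 1, 1], [1, 2, 2]])
scores = per_label_metrics(y_true, y_pred, nb_labels=4)
print(scores)
```

Per-label scores are worth reading alongside the global accuracy: a large, easy region can dominate the pixel count and hide poor results on a small one.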


