# LIDC-IDRI Lung Segmentation


In [None]:
import cv2
import keras
import datasets
import numpy as np
import tensorflow as tf
import albumentations as A
import matplotlib.pyplot as plt

from PIL import Image
from dataclasses import dataclass

from models.unet import UNet

Load the dataset


In [None]:
# Load the LIDC-IDRI segmentation dataset, take its "train" split and carve a
# seeded (reproducible) 80/20 train/test partition out of it.
full_split = datasets.load_dataset("jmanuelc87/lidc-idri-segmentation")["train"]  # type: ignore
lung_dataset = full_split.train_test_split(train_size=0.8, seed=42)
lung_dataset

## Configuration


In [None]:
@dataclass
class TrainingConfig:
    """Hyper-parameters shared by the data pipeline and the training loop.

    The rest of the notebook reads these values directly from the class
    (e.g. ``TrainingConfig.BATCH_SIZE``) instead of creating an instance.
    """

    # Number of epochs passed to `model.fit`.
    EPOCHS: int = 50
    # Value forwarded to `UNet(num_classes=...)`.
    NUM_CLASSES: int = 1
    # Batch size used when converting each split with `to_tf_dataset`.
    BATCH_SIZE: int = 32
    # Target size (in pixels) that the albumentations `Resize` step produces.
    IMG_WIDTH: int = 192
    IMG_HEIGHT: int = 192

## Visualization

Exploration of some samples of the dataset and their masks


In [None]:
def num_to_rgb(mask):
    """Turn a single-channel mask into an RGB image for display.

    Every position whose mask value equals 255 is painted pure red
    ``(255, 0, 0)``; everything else stays black.

    Args:
        mask: Array-like mask (anything ``numpy`` can convert to an array).

    Returns:
        A ``uint8`` array with the mask's leading (up to two) dimensions plus
        a trailing colour axis of length 3.
    """
    foreground = np.asarray(mask) == 255
    rgb = np.zeros((*foreground.shape[:2], 3), dtype=np.uint8)
    rgb[foreground] = (255, 0, 0)
    return rgb

In [None]:
def image_overlay(image, segmented_image):
    """Blend a colour-coded segmentation map on top of an image.

    The result is ``image * 1.0 + segmented_image * 0.7`` computed by
    ``cv2.addWeighted``. The blend is element-wise, so it does not depend on
    the channel order and no RGB<->BGR round trip is needed.

    Args:
        image: Image array. It is left unmodified.
        segmented_image: Colour segmentation map; ``cv2.addWeighted``
            requires it to have the same size and channel count as ``image``.

    Returns:
        A new array holding the blended image, in the same channel order as
        the inputs.
    """
    image_weight = 1.0  # Weight of the original image in the blend.
    mask_weight = 0.7  # Weight of the segmentation map in the blend.
    offset = 0.0  # Scalar added to each weighted sum.

    # No `dst` argument: let OpenCV allocate the output so the caller's
    # `image` array is never written to.
    return cv2.addWeighted(image, image_weight, segmented_image, mask_weight, offset)

In [None]:
def plot(dataset: datasets.Dataset, qty=10):
    """Show the first ``qty`` samples as rows of patch / mask / overlay.

    Each row contains, left to right: the raw patch, its mask, and the patch
    with the mask painted over it (via ``num_to_rgb`` and ``image_overlay``).
    If the dataset holds fewer than ``qty`` items the remaining rows stay
    empty.

    Args:
        dataset: Iterable of items exposing ``"patch"`` and ``"patch_mask"``.
            # NOTE(review): the overlay step assumes 2-D patches whose dtype
            # matches the uint8 mask image - confirm against the dataset.
        qty: Number of rows (samples) to draw.
    """
    # squeeze=False keeps `ax` two-dimensional even when qty == 1 (otherwise
    # matplotlib returns a 1-D array and row indexing breaks). The height
    # scales with the number of rows instead of being fixed.
    _fig, ax = plt.subplots(qty, 3, figsize=(12, 4 * qty), squeeze=False)

    for panel in ax.ravel():
        panel.axis("off")

    # zip stops after `qty` rows or when the dataset runs out, whichever
    # comes first.
    for (patch_ax, mask_ax, overlay_ax), item in zip(ax, dataset):
        patch_ax.imshow(item["patch"], cmap="gray")  # type: ignore
        mask_ax.imshow(item["patch_mask"], cmap="gray")  # type: ignore

        # Replicate the single-channel patch into three identical channels so
        # it can be blended with the RGB mask.
        patch = np.array(item["patch"])  # type: ignore
        patch_rgb = np.stack([patch, patch, patch], axis=-1)

        overlay = image_overlay(patch_rgb, num_to_rgb(item["patch_mask"]))  # type: ignore
        overlay_ax.imshow(overlay)


plot(lung_dataset["train"], qty=7)  # type: ignore

## Data Preparation

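Preparation of the dataset using the albumentations library. The pipelines defined below currently apply only two transformations:

- Resize
- Normalize

The augmentations originally listed for this section (RandomCrop, CenterCrop, SquareSymmetry, GaussNoise) and the ToTensor conversion are not applied by the code in this notebook.

This section also creates the train, validation and test splits and converts each of them into a batched `tf.data` dataset.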


In [None]:
def _resize_and_normalize():
    """Build the preprocessing pipeline shared by every split.

    Steps:
      1. Resize to ``IMG_HEIGHT`` x ``IMG_WIDTH`` with nearest-neighbour
         interpolation.
      2. Normalize with mean 0, std 1 and ``max_pixel_value=255``.
    """
    return A.Compose(
        [
            A.Resize(
                height=TrainingConfig.IMG_HEIGHT,
                width=TrainingConfig.IMG_WIDTH,
                interpolation=cv2.INTER_NEAREST,
                p=1.0,
            ),
            A.Normalize(
                mean=(0.0, 0.0, 0.0),
                std=(1.0, 1.0, 1.0),
                max_pixel_value=255.0,
                p=1.0,
            ),
        ]
    )


# Validation/test inputs must be preprocessed exactly like the training
# inputs, so both pipelines come from the same factory (previously only the
# training pipeline requested nearest-neighbour resizing).
train_transforms = _resize_and_normalize()

valid_transforms = _resize_and_normalize()

In [None]:
def map_image_transforms(transformations):
    """Adapt a per-sample transform to a column-oriented batch.

    Args:
        transformations: Callable invoked as ``transformations(**sample)``
            where ``sample`` maps each column name to a numpy array. It must
            return a mapping containing every one of those column names.

    Returns:
        A function that takes a batch (dict of column name -> list of
        values), runs ``transformations`` on each sample, appends a trailing
        axis of length 1 to every result and writes the new lists back into
        the same batch dict, which it returns.
    """

    def wrapper(row):
        columns = list(row)

        # Re-assemble the column lists into per-sample dicts and transform
        # each sample on its own.
        results = [
            transformations(
                **{name: np.array(value) for name, value in zip(columns, sample)}
            )
            for sample in zip(*row.values())
        ]

        for name in columns:
            row[name] = [np.expand_dims(result[name], axis=-1) for result in results]

        return row

    return wrapper

In [None]:
# Drop the columns that are not used below and rename the patch columns to
# `image` / `mask`, the keyword names the albumentations pipelines are called with.
unused_columns = ["image", "image_mask", "malignancy", "cancer"]
new_lung_dataset = lung_dataset.remove_columns(unused_columns)
new_lung_dataset = new_lung_dataset.rename_column("patch", "image")
new_lung_dataset = new_lung_dataset.rename_column("patch_mask", "mask")

In [None]:
def create_datasets(lung_dataset, seed=42):
    """Build batched TensorFlow datasets for training, validation and test.

    The ``"train"`` split is used for training as-is. The ``"test"`` split is
    divided 50/50 into a validation half and a test half.

    Args:
        lung_dataset: Mapping with ``"train"`` and ``"test"`` splits whose
            rows expose ``image`` and ``mask`` columns.
        seed: Seed for the validation/test split, so the same samples land in
            each half on every run (the earlier train/test split is seeded
            too).

    Returns:
        Tuple ``(train, valid, test)`` of datasets yielding
        ``(image, mask)`` batches of ``TrainingConfig.BATCH_SIZE`` samples.
    """

    def to_tf(split, transforms):
        # Same conversion chain for every split; only the transform differs.
        return (
            split.with_format("tensorflow")
            .with_transform(map_image_transforms(transforms))
            .to_tf_dataset(
                columns="image",
                label_cols="mask",
                batch_size=TrainingConfig.BATCH_SIZE,
            )
        )

    holdout = lung_dataset["test"].train_test_split(train_size=0.5, seed=seed)

    return (
        to_tf(lung_dataset["train"], train_transforms),
        to_tf(holdout["train"], valid_transforms),
        to_tf(holdout["test"], valid_transforms),
    )

In [None]:
# Materialise the three batched tf.data datasets used for training and evaluation.
train_dataset, valid_dataset, test_dataset = create_datasets(new_lung_dataset)

## Modeling

Creation of a UNet model with Keras (compiled and trained through the Keras `compile` / `fit` API)


In [None]:
model = UNet(num_classes=TrainingConfig.NUM_CLASSES)

# CosineDecay anneals the learning rate from `initial_learning_rate` down to
# zero over `decay_steps` optimizer steps and keeps it there afterwards. The
# schedule therefore has to span the whole run (batches per epoch x epochs);
# with a fixed 100 steps the model would stop learning long before the last
# epoch.
train_rows = new_lung_dataset["train"].num_rows
# Ceiling division: the final, partially filled batch still counts as a step.
steps_per_epoch = -(-train_rows // TrainingConfig.BATCH_SIZE)

cosine_decay = keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=1e-3,
    decay_steps=steps_per_epoch * TrainingConfig.EPOCHS,
)

# NOTE(review): mean squared error and "accuracy" are kept unchanged. The
# masks appear to hold 0/255 values and are not rescaled by the image
# normalisation - confirm that the targets match the range of the UNet output.
model.compile(
    loss=keras.losses.mean_squared_error,
    optimizer=keras.optimizers.SGD(learning_rate=cosine_decay),  # type: ignore
    metrics=["accuracy"],
)

In [None]:
# Train for the configured number of epochs, with the validation split
# supplied as validation data; the returned value is kept in `history`.
history = model.fit(
    train_dataset,
    epochs=TrainingConfig.EPOCHS,
    validation_data=valid_dataset,
)

In [None]:
def inference(model, dataset, num_batches_to_process=1, threshold=0.5):
    """Plot input, ground-truth mask and predicted mask for a few batches.

    For every sample of the first ``num_batches_to_process`` batches a figure
    with three panels is shown: the input frame, the ground-truth mask and
    the predicted mask.

    The predicted mask is derived from the model output as follows:
      * one output channel: the channel is thresholded at ``threshold``
        (an argmax over a single channel would always be 0, i.e. an empty
        mask);
      * several output channels: argmax over the channel axis.

    Args:
        model: Object with a ``predict(batch)`` method returning an array
            whose last axis is the class/channel axis.
        dataset: Iterable yielding ``(images, masks)`` batches.
        num_batches_to_process: How many batches to visualise.
        threshold: Cut-off for single-channel outputs.
            # NOTE(review): assumes the single channel is a probability-like
            # score in [0, 1] - confirm against the UNet output activation.

    The batches are converted to numpy arrays explicitly, so the function does
    not need to switch on TensorFlow's global numpy-behaviour flag.
    """
    panel_titles = ("Actual frame", "Ground truth labels", "Predicted labels")

    for batch_idx, data in enumerate(dataset):
        if batch_idx >= num_batches_to_process:
            break

        batch_img, batch_mask = data[0], data[1]

        scores = np.asarray(model.predict(batch_img), dtype="float32")
        if scores.shape[-1] == 1:
            pred_masks = (scores[..., 0] > threshold).astype("uint8")
        else:
            pred_masks = scores.argmax(-1)

        # Inputs were scaled by 1/255 during preprocessing; bring them back to
        # 8-bit intensities for display.
        frames = (np.asarray(batch_img) * 255).astype("uint8")
        true_masks = np.asarray(batch_mask)

        for frame, true_mask, pred_mask in zip(frames, true_masks, pred_masks):
            _fig, axes = plt.subplots(1, 3, figsize=(20, 8))

            for ax, title, picture in zip(
                axes, panel_titles, (frame, true_mask, pred_mask)
            ):
                ax.set_title(title)
                ax.imshow(picture, cmap="gray")
                ax.axis("off")

            plt.show()

In [None]:
# Visualise predictions of the trained model on the held-out test split.
inference(model=model, dataset=test_dataset)