In [9]:
import logging
from preprocessing.covid_dataset import CovidDataset
from monai.data import DataLoader
import monai
from config.constants import (ZENODO_COVID_CASES_PATH, ZENODO_INFECTION_MASKS_PATH, LOWER_BOUND_WINDOW_HRCT,
                              UPPER_BOUND_WINDOW_HRCT,
                              SPATIAL_SIZE, NUM_RAND_PATCHES, LOWER_BOUND_WINDOW_CBCT, UPPER_BOUND_WINDOW_CBCT,
                              COVID_PREPROCESSED_CASES_PATH, INFECTION_PREPROCESSED_MASKS_PATH, SEED
                              )
import torch
import numpy as np
import matplotlib.pyplot as plt
from monai.visualize import plot_2d_or_3d_image, matshow3d, blend_images
from pathlib import Path
from utils.helpers import load_images_from_path

In [10]:
def get_hrct_transforms():
    """Training pipeline for HRCT volumes.

    Steps:
    - Load image and mask channel-first and reorient both to "PLI".
    - Clip the image to [LOWER_BOUND_WINDOW_HRCT, UPPER_BOUND_WINDOW_HRCT] and min-max scale it to [0, 1].
    - Draw NUM_RAND_PATCHES positive/negative-balanced patches of SPATIAL_SIZE.
    - Apply random flips and a random zoom.

    Returns:
        monai.transforms.Compose operating on dicts with keys "img" and "mask".
    """
    return monai.transforms.Compose(
        [
            monai.transforms.LoadImaged(keys=('img', 'mask'), image_only=True, ensure_channel_first=True),
            monai.transforms.Orientationd(keys=('img', 'mask'), axcodes="PLI"),
            # Windowing: ThresholdIntensityd keeps values on the `above` side of the threshold and
            # replaces the rest with cval, so these two steps clip the image to the HRCT window.
            monai.transforms.ThresholdIntensityd(keys=("img",), threshold=LOWER_BOUND_WINDOW_HRCT, above=True, cval=LOWER_BOUND_WINDOW_HRCT),
            monai.transforms.ThresholdIntensityd(keys=("img",), threshold=UPPER_BOUND_WINDOW_HRCT, above=False, cval=UPPER_BOUND_WINDOW_HRCT),
            monai.transforms.ScaleIntensityd(keys='img', minv=0.0, maxv=1.0),
            # allow_smaller=True: in dimensions where the volume is smaller than SPATIAL_SIZE the
            # patch is smaller too (padding with SpatialPadd is currently disabled).
            monai.transforms.RandCropByPosNegLabeld(keys=('img', 'mask'), label_key="mask",
                                                    spatial_size=SPATIAL_SIZE, pos=1, neg=1,
                                                    num_samples=NUM_RAND_PATCHES, allow_smaller=True),

            monai.transforms.RandFlipd(keys=('img', 'mask'), prob=0.2, spatial_axis=0),
            monai.transforms.RandFlipd(keys=('img', 'mask'), prob=0.2, spatial_axis=1),
            monai.transforms.RandFlipd(keys=('img', 'mask'), prob=0.2, spatial_axis=2),
            # Per-key interpolation: "area" (MONAI's default) for the image, "nearest" for the mask
            # so zooming does not create fractional label values at region boundaries.
            monai.transforms.RandZoomd(keys=('img', 'mask'), prob=0.4, min_zoom=0.9, max_zoom=1.1,
                                       mode=("area", "nearest")),

            monai.transforms.ToTensord(keys=("img", "mask")),
        ]
    )


def get_cbct_transforms():
    """Training pipeline for CBCT volumes.

    Steps:
    - Load image and mask channel-first and reorient both to "ALI".
    - Min-max scale the image to [0, 1].
    - Draw NUM_RAND_PATCHES positive/negative-balanced patches of SPATIAL_SIZE.
    - Apply random flips and a random zoom.

    Returns:
        monai.transforms.Compose operating on dicts with keys "img" and "mask".
    """
    # NOTE(review): unlike the HRCT pipeline, no intensity window is applied here even though
    # LOWER/UPPER_BOUND_WINDOW_CBCT are imported — confirm this is intentional.
    return monai.transforms.Compose(
        [
            monai.transforms.LoadImaged(keys=('img', 'mask'), image_only=True, ensure_channel_first=True),
            monai.transforms.Orientationd(keys=('img', 'mask'), axcodes="ALI"),
            monai.transforms.ScaleIntensityd(keys='img', minv=0.0, maxv=1.0),
            # allow_smaller=True: in dimensions where the volume is smaller than SPATIAL_SIZE the
            # patch is smaller too (padding with SpatialPadd is currently disabled).
            monai.transforms.RandCropByPosNegLabeld(keys=('img', 'mask'), label_key="mask",
                                                    spatial_size=SPATIAL_SIZE, pos=1, neg=1,
                                                    num_samples=NUM_RAND_PATCHES, allow_smaller=True),

            monai.transforms.RandFlipd(keys=('img', 'mask'), prob=0.2, spatial_axis=0),
            monai.transforms.RandFlipd(keys=('img', 'mask'), prob=0.2, spatial_axis=1),
            monai.transforms.RandFlipd(keys=('img', 'mask'), prob=0.2, spatial_axis=2),
            # Per-key interpolation: "area" (MONAI's default) for the image, "nearest" for the mask
            # so zooming does not create fractional label values at region boundaries.
            monai.transforms.RandZoomd(keys=('img', 'mask'), prob=0.4, min_zoom=0.9, max_zoom=1.1,
                                       mode=("area", "nearest")),

            monai.transforms.ToTensord(keys=("img", "mask")),
        ]
    )


def get_val_hrct_transforms():
    """Deterministic HRCT pipeline for validation and testing.

    Loads image and mask channel-first and reorients both to "PLI". It then clips the image to
    the HRCT intensity window and min-max scales it to [0, 1]. No cropping or augmentation is done.
    """
    t = monai.transforms
    both = ('img', 'mask')
    steps = [
        t.LoadImaged(keys=both, image_only=True, ensure_channel_first=True),
        t.Orientationd(keys=both, axcodes="PLI"),
        # Clip from below: values not above the lower bound become the lower bound.
        t.ThresholdIntensityd(keys=("img",), threshold=LOWER_BOUND_WINDOW_HRCT, above=True, cval=LOWER_BOUND_WINDOW_HRCT),
        # Clip from above: values not below the upper bound become the upper bound.
        t.ThresholdIntensityd(keys=("img",), threshold=UPPER_BOUND_WINDOW_HRCT, above=False, cval=UPPER_BOUND_WINDOW_HRCT),
        t.ScaleIntensityd(keys='img', minv=0.0, maxv=1.0),
        t.ToTensord(keys=both),
    ]
    return t.Compose(steps)


def get_val_cbct_transforms():
    """Deterministic CBCT pipeline for validation and testing.

    Loads image and mask channel-first, reorients both to "ALI" and min-max scales the
    image to [0, 1]. No windowing, cropping or augmentation is done.
    """
    t = monai.transforms
    both = ('img', 'mask')
    return t.Compose([
        t.LoadImaged(keys=both, image_only=True, ensure_channel_first=True),
        t.Orientationd(keys=both, axcodes="ALI"),
        t.ScaleIntensityd(keys='img', minv=0.0, maxv=1.0),
        t.ToTensord(keys=both),
    ])


# Load image and mask file paths.
# NOTE(review): images and masks are paired purely by position below, so this assumes
# load_images_from_path returns both lists in matching order — confirm.
images = load_images_from_path("../" + COVID_PREPROCESSED_CASES_PATH)
labels = load_images_from_path("../" + INFECTION_PREPROCESSED_MASKS_PATH)

# Volumes whose path contains one of these substrings are held out for testing.
TEST_SOURCES = ("radiopaedia", "coronacases")


def _is_test_case(path):
    """Return True when `path` belongs to one of the held-out TEST_SOURCES."""
    return any(source in path for source in TEST_SOURCES)


# Everything that is not a held-out test case is used for training/validation.
train_images = [image for image in images if not _is_test_case(image)]
train_labels = [label for label in labels if not _is_test_case(label)]
# zip() silently truncates on a length mismatch, which would misalign images and masks.
assert len(train_images) == len(train_labels), (
    f"{len(train_images)} training images but {len(train_labels)} training masks"
)

# Pair each image with its mask as {"img": ..., "mask": ...} dictionaries.
data_train_dicts = np.array([{"img": img, "mask": mask} for img, mask in zip(train_images, train_labels)])

# Shuffle reproducibly with the legacy RandomState so the split matches earlier runs.
shuffler = np.random.RandomState(SEED)
shuffler.shuffle(data_train_dicts)
data_train_dicts = list(data_train_dicts)

# Split the training data: first 20% for validation, the rest for training.
val_split = int(len(data_train_dicts) * 0.2)

train_paths = data_train_dicts[val_split:]
val_paths = data_train_dicts[:val_split]

# Take the coronacases and radiopaedia volumes for testing.
test_images = [image for image in images if _is_test_case(image)]
test_labels = [label for label in labels if _is_test_case(label)]
assert len(test_images) == len(test_labels), (
    f"{len(test_images)} test images but {len(test_labels)} test masks"
)
# A plain list, like train_paths and val_paths.
test_paths = [{"img": img, "mask": mask} for img, mask in zip(test_images, test_labels)]

In [11]:
# Every loader below serves one volume per batch using two worker processes.
_loader_kwargs = dict(batch_size=1, num_workers=2)

# Training split: random patch sampling and augmentation.
train_dataset = CovidDataset(
    volumes=train_paths,
    hrct_transform=get_hrct_transforms(),
    cbct_transform=get_cbct_transforms(),
)
train_loader = DataLoader(train_dataset, **_loader_kwargs)

# Validation split: deterministic whole-volume preprocessing.
val_dataset = CovidDataset(
    volumes=val_paths,
    hrct_transform=get_val_hrct_transforms(),
    cbct_transform=get_val_cbct_transforms(),
)
val_loader = DataLoader(val_dataset, **_loader_kwargs)

# Test split: same deterministic preprocessing as validation.
test_dataset = CovidDataset(
    volumes=test_paths,
    hrct_transform=get_val_hrct_transforms(),
    cbct_transform=get_val_cbct_transforms(),
)
test_loader = DataLoader(test_dataset, **_loader_kwargs)

# SHOW DATASET

In [7]:
def show_train_post_transforms_with_mask(dataloader, dataset, root="images/train_post_transforms_with_mask"):
    """Save a montage of every sampled training subvolume with its mask blended on top.

    For each volume, one PNG per subvolume is written to <root>/<volume>/<j>.png (j from 1).
    <volume> is the image file name up to its first dot.

    Args:
        dataloader: loader over `dataset`. It must use batch_size=1 and no shuffling so that
            loader index i corresponds to dataset.volumes[i].
        dataset: dataset exposing `volumes`, a sequence of {"img": path, "mask": path} dicts.
        root: output directory (default keeps the original location).
    """
    for i, data in enumerate(dataloader):
        volume = Path(dataset.volumes[i]['img']).name.split('.')[0]
        Path(f"{root}/{volume}").mkdir(parents=True, exist_ok=True)
        print(f"Processing volume {i+1}: {volume}")
        # Dim 0 of the batch holds the sampled subvolumes (random patches) of this volume.
        for j in range(data["img"].shape[0]):
            print(f"Processing subvolume {j+1}")

            blended_img = blend_images(data["img"][j], data["mask"][j], alpha=0.5, cmap="hsv", rescale_arrays=True)
            lo, hi = blended_img.min(), blended_img.max()
            blended_img = blended_img - lo
            if hi > lo:  # a constant patch would otherwise divide by zero
                blended_img = blended_img / (hi - lo)

            n_slices = blended_img.shape[-1]
            # Keep the 16x16 layout, but grow the grid when there are more than 256 slices;
            # add_subplot raises ValueError once the index exceeds rows * cols.
            grid = max(16, int(np.ceil(np.sqrt(n_slices))))
            fig = plt.figure(figsize=(8, 8))
            # suptitle instead of plt.title: plt.title would create a full-figure Axes
            # (with frame and ticks) underneath the thumbnails.
            fig.suptitle(f"subvolume{j+1}")
            for k in range(n_slices):
                ax = fig.add_subplot(grid, grid, k + 1)
                # Channel-first (C, H, W) slice -> (H, W, C) for imshow.
                ax.imshow(torch.moveaxis(blended_img[:, :, :, k], 0, -1))
                ax.axis('off')
            fig.savefig(f"{root}/{volume}/{j+1}.png")
            plt.close(fig)

def show_validation_post_transforms_with_mask(dataloader, dataset):
    """Save one montage per validation volume with the mask blended over every slice.

    Each montage is written to images/validation_post_transforms_with_mask/<volume>.png.
    <volume> is the image file name up to its first dot.

    Args:
        dataloader: loader over `dataset`. It must use batch_size=1 and no shuffling so that
            loader index i corresponds to dataset.volumes[i].
        dataset: dataset exposing `volumes`, a sequence of {"img": path, "mask": path} dicts.
    """
    root = "images/validation_post_transforms_with_mask"
    # Montages are saved directly under root, so only root needs to exist. The previous
    # per-volume directories were created but never written to.
    Path(root).mkdir(parents=True, exist_ok=True)
    for i, data in enumerate(dataloader):
        volume = Path(dataset.volumes[i]['img']).name.split('.')[0]
        print(f"Processing volume {i}: {volume}")
        blended_img = blend_images(data["img"][0], data["mask"][0], alpha=0.5, cmap="hsv", rescale_arrays=True)
        lo, hi = blended_img.min(), blended_img.max()
        blended_img = blended_img - lo
        if hi > lo:  # a constant volume would otherwise divide by zero
            blended_img = blended_img / (hi - lo)

        n_slices = blended_img.shape[-1]
        # Keep the 25x25 layout, but grow the grid when there are more than 625 slices;
        # add_subplot raises ValueError once the index exceeds rows * cols.
        grid = max(25, int(np.ceil(np.sqrt(n_slices))))
        fig = plt.figure(figsize=(8, 8))
        for k in range(n_slices):
            ax = fig.add_subplot(grid, grid, k + 1)
            # Channel-first (C, H, W) slice -> (H, W, C) for imshow.
            ax.imshow(torch.moveaxis(blended_img[:, :, :, k], 0, -1))
            ax.axis('off')
        fig.savefig(f"{root}/{volume}.png")
        plt.close(fig)

def show_validation_post_transforms_with_mask_one_by_one(dataloader, dataset):
    """Save every slice of each validation volume, with the mask blended on top, as its own PNG.

    Files are written to
    images/show_validation_post_transforms_with_mask_one_by_one/<volume>/<k>.png (k from 1).

    Args:
        dataloader: loader over `dataset`. It must use batch_size=1 and no shuffling so that
            loader index i corresponds to dataset.volumes[i].
        dataset: dataset exposing `volumes`, a sequence of {"img": path, "mask": path} dicts.
    """
    root = "images/show_validation_post_transforms_with_mask_one_by_one"
    for i, data in enumerate(dataloader):
        volume = Path(dataset.volumes[i]['img']).name.split('.')[0]
        Path(f"{root}/{volume}").mkdir(parents=True, exist_ok=True)
        print(f"Processing volume {i}: {volume}")
        blended_img = blend_images(data["img"][0], data["mask"][0], alpha=0.5, cmap="hsv", rescale_arrays=True)
        lo, hi = blended_img.min(), blended_img.max()
        blended_img = blended_img - lo
        if hi > lo:  # a constant volume would otherwise divide by zero
            blended_img = blended_img / (hi - lo)

        for k in range(blended_img.shape[-1]):
            fig, ax = plt.subplots()
            ax.set_title(f"{volume}_slice_{k+1}")
            # Channel-first (C, H, W) slice -> (H, W, C) for imshow.
            ax.imshow(torch.moveaxis(blended_img[:, :, :, k], 0, -1))
            ax.axis('off')
            fig.savefig(f"{root}/{volume}/{k+1}.png")
            # Close explicitly: one figure per slice would otherwise accumulate.
            plt.close(fig)
      
def show_validation_post_transforms_no_mask(dataloader, dataset):
    """Save a grayscale montage of every validation image, without a mask overlay.

    Output path: images/validation_post_transforms_no_mask/<volume>/<volume>.png.
    Loader index i is assumed to match dataset.volumes[i] (batch_size=1, no shuffling).
    """
    root = "images/validation_post_transforms_no_mask"
    for i, batch in enumerate(dataloader):
        # File name after the last '/', cut at its first '.'.
        volume = dataset.volumes[i]['img'].rpartition('/')[2].partition('.')[0]
        Path(f"{root}/{volume}").mkdir(parents=True, exist_ok=True)
        print(f"Processing volume {i}: {volume}")
        figure = plt.figure()
        # First batch item, first channel; frames are taken along the last axis.
        first_image = batch["img"][0][0]
        matshow3d(first_image, fig=figure, figsize=(20, 20), frame_dim=-1, cmap="gray")
        figure.savefig(f"{root}/{volume}/{volume}.png")
        plt.close(figure)

In [8]:
# Write one blended PNG per slice for every validation volume.
show_validation_post_transforms_with_mask_one_by_one(dataloader=val_loader, dataset=val_dataset)

Processing volume 0: study_0299
Processing volume 1: study_0258
Processing volume 2: study_0274
Processing volume 3: study_0282
Processing volume 4: study_0295
Processing volume 5: study_0256
Processing volume 6: study_0302
Processing volume 7: study_0283
Processing volume 8: study_0289
Processing volume 9: study_0262


In [7]:
show_train_post_transforms_with_mask(dataloader=train_loader, dataset=train_dataset)

Processing volume 1: study_0272
Processing subvolume 1
Processing subvolume 2
Processing subvolume 3
Processing subvolume 4
Processing subvolume 5
Processing subvolume 6
Processing subvolume 7
Processing subvolume 8
Processing subvolume 9
Processing subvolume 10
Processing subvolume 11
Processing subvolume 12
Processing subvolume 13
Processing subvolume 14
Processing subvolume 15
Processing subvolume 16
Processing volume 2: study_0271
Processing subvolume 1
Processing subvolume 2
Processing subvolume 3
Processing subvolume 4
Processing subvolume 5
Processing subvolume 6
Processing subvolume 7
Processing subvolume 8
Processing subvolume 9
Processing subvolume 10
Processing subvolume 11
Processing subvolume 12
Processing subvolume 13
Processing subvolume 14
Processing subvolume 15
Processing subvolume 16
Processing volume 3: study_0291
Processing subvolume 1
Processing subvolume 2
Processing subvolume 3
Processing subvolume 4
Processing subvolume 5
Processing subvolume 6
Processing subvo

In [8]:
# Next, visualize the raw volumes: build data loaders that only load and reorient the data (no windowing, scaling or augmentation).
def get_raw_transforms(axcodes="PLI"):
    """Minimal pipeline for inspecting volumes: load, reorient, convert to tensor.

    No windowing, intensity scaling, cropping or augmentation is applied.

    Args:
        axcodes: target orientation for image and mask. The default "PLI" matches the HRCT
            pipelines; pass "ALI" to match the CBCT pipelines.

    Returns:
        monai.transforms.Compose operating on dicts with keys "img" and "mask".
    """
    return monai.transforms.Compose(
        [
            monai.transforms.LoadImaged(keys=('img', 'mask'), image_only=True, ensure_channel_first=True),
            monai.transforms.Orientationd(keys=('img', 'mask'), axcodes=axcodes),
            monai.transforms.ToTensord(keys=("img", "mask")),
        ]
    )


# CBCT volumes are oriented "ALI" here, as in get_cbct_transforms / get_val_cbct_transforms,
# so raw views are directly comparable with the post-transform views.
raw_val_dataset = CovidDataset(volumes=val_paths, hrct_transform=get_raw_transforms(), cbct_transform=get_raw_transforms(axcodes="ALI"))
raw_val_loader = DataLoader(raw_val_dataset, batch_size=1, num_workers=2)

raw_train_dataset = CovidDataset(volumes=train_paths, hrct_transform=get_raw_transforms(), cbct_transform=get_raw_transforms(axcodes="ALI"))
raw_train_loader = DataLoader(raw_train_dataset, batch_size=1, num_workers=2)

In [12]:
def show_raw_with_mask(dataloader, dataset):
    """Save a montage of every raw volume in the batch with its mask blended on top.

    One PNG per batch item is written to images/raw_with_mask/<volume>/<j>.png (j from 1).
    <volume> is the image file name up to its first dot.

    Args:
        dataloader: loader over `dataset`. It must use batch_size=1 and no shuffling so that
            loader index i corresponds to dataset.volumes[i].
        dataset: dataset exposing `volumes`, a sequence of {"img": path, "mask": path} dicts.
    """
    root = "images/raw_with_mask"
    for i, data in enumerate(dataloader):
        volume = Path(dataset.volumes[i]['img']).name.split('.')[0]
        Path(f"{root}/{volume}").mkdir(parents=True, exist_ok=True)
        print(f"Processing volume {i+1}: {volume}")
        for j in range(data["img"].shape[0]):
            print(f"Processing subvolume {j+1}")

            blended_img = blend_images(data["img"][j], data["mask"][j], alpha=0.5, cmap="hsv", rescale_arrays=True)
            lo, hi = blended_img.min(), blended_img.max()
            blended_img = blended_img - lo
            if hi > lo:  # a constant volume would otherwise divide by zero
                blended_img = blended_img / (hi - lo)

            n_slices = blended_img.shape[-1]
            # Raw volumes can have many slices: keep the 16x16 layout, but grow the grid when
            # there are more than 256; add_subplot raises ValueError once the index exceeds rows * cols.
            grid = max(16, int(np.ceil(np.sqrt(n_slices))))
            fig = plt.figure(figsize=(8, 8))
            # suptitle instead of plt.title: plt.title would create a full-figure Axes
            # (with frame and ticks) underneath the thumbnails.
            fig.suptitle(f"subvolume{j+1}")
            for k in range(n_slices):
                ax = fig.add_subplot(grid, grid, k + 1)
                # Channel-first (C, H, W) slice -> (H, W, C) for imshow.
                ax.imshow(torch.moveaxis(blended_img[:, :, :, k], 0, -1))
                ax.axis('off')
            fig.savefig(f"{root}/{volume}/{j+1}.png")
            plt.close(fig)

In [None]:
show_raw_with_mask(dataloader=raw_train_loader, dataset=raw_train_dataset)