In [50]:
import logging
from preprocessing.covid_dataset import CovidDataset
from monai.data import DataLoader
import monai
from config.constants import (COVID_CASES_PATH, INFECTION_MASKS_PATH)
import torch
import numpy as np
import matplotlib.pyplot as plt
from monai.visualize import plot_2d_or_3d_image, matshow3d, blend_images
from pathlib import Path
from utils.helpers import load_images_from_path

In [57]:
# --- Patch sampling ---------------------------------------------------------
# Passed as `spatial_size` to the random crop and the padding transforms.
SPATIAL_SIZE = (32,) * 3
# Passed as `num_samples` to the random crop (patches drawn per volume).
NUM_RAND_PATCHES = 4

# --- HRCT intensity window --------------------------------------------------
# Window described by its centre (LEVEL) and total width (WIDTH); the bounds
# below sit half a width on either side of the centre.
# NOTE(review): values are presumably Hounsfield units — confirm.
LEVEL = -650
WIDTH = 1500
LOWER_BOUND_WINDOW_HRCT = LEVEL - WIDTH // 2
UPPER_BOUND_WINDOW_HRCT = LEVEL + WIDTH // 2

# --- CBCT intensity window --------------------------------------------------
# Not referenced by any transform pipeline visible in this notebook.
LOWER_BOUND_WINDOW_CBCT, UPPER_BOUND_WINDOW_CBCT = 0, 255

# Seed for the RandomState that shuffles the volumes before splitting.
SEED = 33
# Target voxel spacing; only referenced by the commented-out Spacingd steps.
SPACING = (0.7, 0.7, 1.0)

# Directories (relative to the repository root) with the preprocessed volumes
# and their infection masks.
COVID_PREPROCESSED_CASES_PATH = "Datasets/preprocessed/images/"
INFECTION_PREPROCESSED_MASKS_PATH = "Datasets/preprocessed/labels/"

def get_hrct_transforms():
    """Build the training pipeline (preprocessing + augmentation) for HRCT volumes.

    Each sample is a dict with an ``'img'`` and a ``'mask'`` entry. The pipeline
    loads both, reorients them to ``"PLI"``, clips the image to the HRCT window
    ``[LOWER_BOUND_WINDOW_HRCT, UPPER_BOUND_WINDOW_HRCT]``, rescales it to
    ``[0, 1]``, draws ``NUM_RAND_PATCHES`` random patches of ``SPATIAL_SIZE``
    and applies random flips and a random zoom.

    Returns:
        monai.transforms.Compose: the composed dictionary transform.
    """
    keys = ('img', 'mask')
    return monai.transforms.Compose(
        [
            monai.transforms.LoadImaged(keys=keys, image_only=True, ensure_channel_first=True),
            monai.transforms.Orientationd(keys=keys, axcodes="PLI"),
            # Resampling to a common spacing is currently disabled:
            # monai.transforms.Spacingd(keys=('img', 'mask'), pixdim=SPACING, mode=("bilinear", "nearest")),
            # The two thresholds together clip the image to the HRCT window:
            # values below the lower bound become the lower bound, values
            # above the upper bound become the upper bound.
            monai.transforms.ThresholdIntensityd(keys=("img",), threshold=LOWER_BOUND_WINDOW_HRCT, above=True, cval=LOWER_BOUND_WINDOW_HRCT),
            monai.transforms.ThresholdIntensityd(keys=("img",), threshold=UPPER_BOUND_WINDOW_HRCT, above=False, cval=UPPER_BOUND_WINDOW_HRCT),
            monai.transforms.ScaleIntensityd(keys='img', minv=0.0, maxv=1.0),
            # pos=1 / neg=1: patch centres are drawn with equal probability
            # from mask foreground and background.
            monai.transforms.RandCropByPosNegLabeld(keys=keys, label_key="mask",
                                                    spatial_size=SPATIAL_SIZE, pos=1, neg=1,
                                                    num_samples=NUM_RAND_PATCHES, allow_smaller=True),
            # allow_smaller=True can yield patches smaller than SPATIAL_SIZE,
            # so pad them back up to the full patch size.
            monai.transforms.SpatialPadd(keys=keys, spatial_size=SPATIAL_SIZE, method='symmetric'),

            monai.transforms.RandFlipd(keys=keys, prob=0.2, spatial_axis=0),
            monai.transforms.RandFlipd(keys=keys, prob=0.2, spatial_axis=1),
            monai.transforms.RandFlipd(keys=keys, prob=0.2, spatial_axis=2),
            # Per-key interpolation: keep the default "area" for the image but
            # use "nearest" for the mask so zooming cannot create fractional
            # (non-label) values in the segmentation.
            monai.transforms.RandZoomd(keys=keys, prob=0.4, min_zoom=0.9, max_zoom=1.1,
                                       mode=("area", "nearest")),

            monai.transforms.ToTensord(keys=("img", "mask")),
        ]
    )


def get_cbct_transforms():
    """Build the training pipeline (preprocessing + augmentation) for CBCT volumes.

    Each sample is a dict with an ``'img'`` and a ``'mask'`` entry. The pipeline
    loads both, reorients them to ``"ALI"``, rescales the image to ``[0, 1]``
    (no intensity windowing is applied here), draws ``NUM_RAND_PATCHES`` random
    patches of ``SPATIAL_SIZE`` and applies random flips and a random zoom.

    Returns:
        monai.transforms.Compose: the composed dictionary transform.
    """
    keys = ('img', 'mask')
    return monai.transforms.Compose(
        [
            monai.transforms.LoadImaged(keys=keys, image_only=True, ensure_channel_first=True),
            monai.transforms.Orientationd(keys=keys, axcodes="ALI"),
            # Resampling to a common spacing is currently disabled:
            # monai.transforms.Spacingd(keys=('img', 'mask'), pixdim=SPACING, mode=("bilinear", "nearest")),
            monai.transforms.ScaleIntensityd(keys='img', minv=0.0, maxv=1.0),
            # pos=1 / neg=1: patch centres are drawn with equal probability
            # from mask foreground and background.
            monai.transforms.RandCropByPosNegLabeld(keys=keys, label_key="mask",
                                                    spatial_size=SPATIAL_SIZE, pos=1, neg=1,
                                                    num_samples=NUM_RAND_PATCHES, allow_smaller=True),
            # allow_smaller=True can yield patches smaller than SPATIAL_SIZE,
            # so pad them back up to the full patch size.
            monai.transforms.SpatialPadd(keys=keys, spatial_size=SPATIAL_SIZE, method='symmetric'),

            monai.transforms.RandFlipd(keys=keys, prob=0.2, spatial_axis=0),
            monai.transforms.RandFlipd(keys=keys, prob=0.2, spatial_axis=1),
            monai.transforms.RandFlipd(keys=keys, prob=0.2, spatial_axis=2),
            # Per-key interpolation: keep the default "area" for the image but
            # use "nearest" for the mask so zooming cannot create fractional
            # (non-label) values in the segmentation.
            monai.transforms.RandZoomd(keys=keys, prob=0.4, min_zoom=0.9, max_zoom=1.1,
                                       mode=("area", "nearest")),

            monai.transforms.ToTensord(keys=("img", "mask")),
        ]
    )


def get_val_hrct_transforms():
    """Return the augmentation-free preprocessing pipeline for HRCT volumes.

    Steps: load ``'img'`` and ``'mask'`` channel-first, reorient both to
    ``"PLI"``, clip the image to the HRCT window, rescale it to ``[0, 1]`` and
    convert both entries to tensors. No random transform is involved.
    """
    tfm = monai.transforms
    pair = ('img', 'mask')

    steps = [
        tfm.LoadImaged(keys=pair, image_only=True, ensure_channel_first=True),
        tfm.Orientationd(keys=pair, axcodes="PLI"),
        # Resampling step is currently disabled:
        # monai.transforms.Spacingd(keys=('img', 'mask'), pixdim=SPACING, mode=("bilinear", "nearest")),
    ]

    # Clip the image to [LOWER_BOUND_WINDOW_HRCT, UPPER_BOUND_WINDOW_HRCT].
    steps.append(
        tfm.ThresholdIntensityd(
            keys=("img",),
            threshold=LOWER_BOUND_WINDOW_HRCT,
            above=True,
            cval=LOWER_BOUND_WINDOW_HRCT,
        )
    )
    steps.append(
        tfm.ThresholdIntensityd(
            keys=("img",),
            threshold=UPPER_BOUND_WINDOW_HRCT,
            above=False,
            cval=UPPER_BOUND_WINDOW_HRCT,
        )
    )

    steps.append(tfm.ScaleIntensityd(keys='img', minv=0.0, maxv=1.0))
    steps.append(tfm.ToTensord(keys=("img", "mask")))
    return tfm.Compose(steps)


def get_val_cbct_transforms():
    """Return the augmentation-free preprocessing pipeline for CBCT volumes.

    Steps: load ``'img'`` and ``'mask'`` channel-first, reorient both to
    ``"ALI"``, rescale the image to ``[0, 1]`` and convert both entries to
    tensors. No random transform is involved.
    """
    tfm = monai.transforms
    pair = ('img', 'mask')

    steps = [
        tfm.LoadImaged(keys=pair, image_only=True, ensure_channel_first=True),
        tfm.Orientationd(keys=pair, axcodes="ALI"),
        # Resampling step is currently disabled:
        # monai.transforms.Spacingd(keys=('img', 'mask'), pixdim=SPACING, mode=("bilinear", "nearest")),
        tfm.ScaleIntensityd(keys='img', minv=0.0, maxv=1.0),
        tfm.ToTensord(keys=("img", "mask")),
    ]
    return tfm.Compose(steps)

def load_radiopaedia_from_path(path: str) -> list[str]:
    """Collect the radiopaedia ``.gz`` files located directly inside ``path``.

    Args:
        path: Directory to scan (not recursive).

    Returns:
        Sorted list of path strings for every regular file whose suffix is
        ``.gz`` and whose stem contains ``'radiopaedia'``.
    """
    logging.debug(f"Loading images from {path}")
    selected = []
    for entry in Path(path).iterdir():
        if entry.suffix != '.gz':
            continue
        if not entry.is_file():
            continue
        if 'radiopaedia' not in entry.stem:
            continue
        selected.append(str(entry))
    selected.sort()
    return selected

# Load images and masks
logging.info(f"Loading images from {COVID_PREPROCESSED_CASES_PATH}")
images = load_radiopaedia_from_path("../" + COVID_PREPROCESSED_CASES_PATH)
labels = load_radiopaedia_from_path("../" + INFECTION_PREPROCESSED_MASKS_PATH)

# zip() below would silently drop unmatched files, so fail loudly instead.
if len(images) != len(labels):
    raise ValueError(
        f"Found {len(images)} images but {len(labels)} masks; every image needs exactly one mask"
    )

# Convert images and masks to a list of dictionaries with keys "img" and "mask".
# NOTE(review): pairing relies on both directories sorting into the same order
# (i.e. matching file names) — confirm.
data_dicts = np.array([{"img": img, "mask": mask} for img, mask in zip(images, labels)])
logging.debug(data_dicts)

# Seeded shuffle so the split below is reproducible across runs.
shuffler = np.random.RandomState(SEED)
shuffler.shuffle(data_dicts)
data_dicts = list(data_dicts)

# Split the shuffled data: the first 10% is the test set, the next 20% is the
# validation set and the remainder (~70%) is the training set.
test_split = int(len(data_dicts) * 0.1)
val_split = int(len(data_dicts) * 0.2)

train_paths = data_dicts[test_split + val_split:]
val_paths = data_dicts[test_split:test_split + val_split]
test_paths = data_dicts[:test_split]

In [58]:
# Every radiopaedia volume (all of data_dicts), using the augmentation-free transforms.
radiopaedia_dataset = CovidDataset(
    volumes=data_dicts,
    hrct_transform=get_val_hrct_transforms(),
    cbct_transform=get_val_cbct_transforms(),
)
radiopaedia_loader = DataLoader(radiopaedia_dataset, num_workers=2, batch_size=1)

# Training split: random patch sampling and augmentation.
train_dataset = CovidDataset(
    volumes=train_paths,
    hrct_transform=get_hrct_transforms(),
    cbct_transform=get_cbct_transforms(),
)
train_loader = DataLoader(train_dataset, num_workers=2, batch_size=1)

# Validation split: augmentation-free transforms.
val_dataset = CovidDataset(
    volumes=val_paths,
    hrct_transform=get_val_hrct_transforms(),
    cbct_transform=get_val_cbct_transforms(),
)
val_loader = DataLoader(val_dataset, num_workers=2, batch_size=1)

# Test split: same augmentation-free transforms as validation.
test_dataset = CovidDataset(
    volumes=test_paths,
    hrct_transform=get_val_hrct_transforms(),
    cbct_transform=get_val_cbct_transforms(),
)
test_loader = DataLoader(test_dataset, num_workers=2, batch_size=1)

# SHOW DATASET

In [59]:
def show_train_post_transforms_with_mask(dataloader, dataset):
    """Save one tiled PNG per training patch with the mask blended over the image.

    For batch ``i`` the output directory is named after
    ``dataset.volumes[i]['img']`` (file name up to the first dot), so
    ``dataloader`` must yield samples in the same order as ``dataset.volumes``
    (i.e. no shuffling). Every entry ``j`` along the first batch axis is
    written to ``images/train_post_transforms_with_mask/<volume>/<j+1>.png``,
    one tile per index of the last axis.

    Args:
        dataloader: Iterable of dict batches with ``"img"`` and ``"mask"``.
        dataset: Object exposing ``volumes``, a sequence of dicts whose
            ``'img'`` entry is a '/'-separated path string.
    """
    root = "images/train_post_transforms_with_mask"
    for i, data in enumerate(dataloader):
        volume = dataset.volumes[i]['img'].split('/')[-1].split('.')[0]
        Path(f"{root}/{volume}").mkdir(parents=True, exist_ok=True)
        print(f"Processing volume {i+1}: {volume}")
        for j in range(data["img"].shape[0]):
            print(f"Processing subvolume {j+1}")

            img = data["img"][j]
            seg = data["mask"][j]
            blended_img = blend_images(img, seg, alpha=0.5, cmap="hsv", rescale_arrays=True)
            # Min-max normalise so imshow receives values in [0, 1].
            blended_img = (blended_img - blended_img.min()) / (blended_img.max() - blended_img.min())
            print(f"blended_img.shape: {blended_img.shape}, last dim: {blended_img.shape[-1]}")

            # Smallest square grid that holds every slice, instead of a fixed
            # layout that is unrelated to the actual slice count.
            n_slices = blended_img.shape[-1]
            grid = max(1, int(np.ceil(np.sqrt(n_slices))))

            fig = plt.figure(figsize=(16, 16))
            # suptitle labels the figure itself; plt.title would first create
            # a full-size axes that stays visible behind the tiles.
            fig.suptitle(f"subvolume{j+1}")
            for k in range(n_slices):
                ax = fig.add_subplot(grid, grid, k + 1)
                # Move the colour channel last, as imshow expects (H, W, 3).
                ax.imshow(torch.moveaxis(blended_img[:, :, :, k], 0, -1))
                ax.axis('off')
            fig.savefig(f"{root}/{volume}/{j+1}.png")
            plt.close(fig)
             
def show_validation_post_transforms_no_mask(dataloader, dataset):
    """Save a grayscale slice montage (no mask overlay) for every batch.

    Batch ``i`` is named after ``dataset.volumes[i]['img']`` (file name up to
    the first dot), so ``dataloader`` must iterate in the same order as
    ``dataset.volumes``. The montage of ``data["img"][0][0]`` along its last
    axis is written to
    ``images/validation_post_transforms_no_mask/<volume>/<volume>.png``.
    """
    root = "images/validation_post_transforms_no_mask"
    for idx, batch in enumerate(dataloader):
        file_name = dataset.volumes[idx]['img'].split('/')[-1]
        volume = file_name.split('.')[0]
        out_dir = Path(f"{root}/{volume}")
        out_dir.mkdir(parents=True, exist_ok=True)
        print(f"Processing volume {idx}: {volume}")

        # First sample, first channel of the batch.
        first_channel = batch["img"][0][0]
        fig = plt.figure()
        matshow3d(first_channel, fig=fig, figsize=(20, 20), frame_dim=-1, cmap="gray")
        fig.savefig(f"{root}/{volume}/{volume}.png")
        plt.close(fig)
        
def show_validation_post_transforms_with_mask(dataloader, dataset):
    """Save one tiled PNG per volume with the mask blended over the image.

    Batch ``i`` is named after ``dataset.volumes[i]['img']`` (file name up to
    the first dot), so ``dataloader`` must iterate in the same order as
    ``dataset.volumes``. Output goes to
    ``images/validation_post_transforms_with_mask/<volume>/<volume>.png``,
    one tile per index of the last axis.

    Args:
        dataloader: Iterable of dict batches with ``"img"`` and ``"mask"``;
            only index 0 of the batch axis is drawn.
        dataset: Object exposing ``volumes``, a sequence of dicts whose
            ``'img'`` entry is a '/'-separated path string.
    """
    root = "images/validation_post_transforms_with_mask"
    for i, data in enumerate(dataloader):
        volume = dataset.volumes[i]['img'].split('/')[-1].split('.')[0]
        Path(f"{root}/{volume}").mkdir(parents=True, exist_ok=True)
        print(f"Processing volume {i}: {volume}")
        # The whole batch is blended at once, so the result keeps the batch
        # axis right after the colour axis; it is indexed with 0 below.
        img = data["img"]
        seg = data["mask"]
        blended_img = blend_images(img, seg, alpha=0.5, cmap="hsv", rescale_arrays=True)
        # Min-max normalise so imshow receives values in [0, 1].
        blended_img = (blended_img - blended_img.min()) / (blended_img.max() - blended_img.min())
        print(f"blended_img.shape: {blended_img.shape}, last dim: {blended_img.shape[-1]}")

        # Smallest square grid that holds every slice; a fixed 22x22 layout
        # raises ValueError as soon as a volume has more than 484 slices.
        n_slices = blended_img.shape[-1]
        grid = max(1, int(np.ceil(np.sqrt(n_slices))))

        fig = plt.figure(figsize=(22, 22))
        for k in range(n_slices):
            ax = fig.add_subplot(grid, grid, k + 1)
            # Move the colour channel last, as imshow expects (H, W, 3).
            ax.imshow(torch.moveaxis(blended_img[:, 0, :, :, k], 0, -1))
            ax.axis('off')
        fig.savefig(f"{root}/{volume}/{volume}.png")
        plt.close(fig)
    

In [60]:
# Render every radiopaedia volume (validation-style transforms) with its mask overlaid
show_validation_post_transforms_with_mask(
    dataloader=radiopaedia_loader,
    dataset=radiopaedia_dataset,
)

Processing volume 0: radiopaedia_4_85506_1
blended_img.shape: torch.Size([3, 1, 479, 479, 254]), last dim: 254
Processing volume 1: radiopaedia_27_86410_0
blended_img.shape: torch.Size([3, 1, 479, 479, 434]), last dim: 434
Processing volume 2: radiopaedia_29_86491_1
blended_img.shape: torch.Size([3, 1, 479, 479, 274]), last dim: 274
Processing volume 3: radiopaedia_7_85703_0
blended_img.shape: torch.Size([3, 1, 479, 479, 294]), last dim: 294
Processing volume 4: radiopaedia_10_85902_3
blended_img.shape: torch.Size([3, 1, 479, 479, 464]), last dim: 464
Processing volume 5: radiopaedia_36_86526_0
blended_img.shape: torch.Size([3, 1, 479, 479, 294]), last dim: 294
Processing volume 6: radiopaedia_14_85914_0
blended_img.shape: torch.Size([3, 1, 305, 479, 243]), last dim: 243
Processing volume 7: radiopaedia_10_85902_1
blended_img.shape: torch.Size([3, 1, 479, 479, 254]), last dim: 254
Processing volume 8: radiopaedia_40_86625_0
blended_img.shape: torch.Size([3, 1, 479, 479, 410]), last dim

In [39]:
# Print the image tensor shape of every batch produced by the radiopaedia loader.
for batch in radiopaedia_loader:
    print(batch["img"].shape)

torch.Size([1, 1, 500, 500, 256])
torch.Size([1, 1, 500, 500, 270])
torch.Size([1, 1, 556, 556, 319])
torch.Size([1, 1, 500, 500, 301])
torch.Size([1, 1, 500, 500, 300])
torch.Size([1, 1, 521, 521, 249])
torch.Size([1, 1, 539, 539, 300])
torch.Size([1, 1, 593, 593, 301])
torch.Size([1, 1, 530, 530, 301])
torch.Size([1, 1, 500, 500, 290])


In [8]:
# Next, visualize the raw volumes. To do so, create data loaders whose transforms only load and reorient the data.
def get_raw_transforms():
    """Return a minimal pipeline that leaves intensities untouched.

    It only loads ``'img'`` and ``'mask'`` channel-first, reorients both to
    ``"PLI"`` and converts them to tensors.
    """
    tfm = monai.transforms
    pair = ('img', 'mask')
    steps = [
        tfm.LoadImaged(keys=pair, image_only=True, ensure_channel_first=True),
        tfm.Orientationd(keys=pair, axcodes="PLI"),
        tfm.ToTensord(keys=("img", "mask")),
    ]
    return tfm.Compose(steps)

# Validation split with the raw (load + reorient only) pipeline for both modalities.
raw_val_dataset = CovidDataset(
    volumes=val_paths,
    hrct_transform=get_raw_transforms(),
    cbct_transform=get_raw_transforms(),
)
raw_val_loader = DataLoader(raw_val_dataset, num_workers=2, batch_size=1)

# Training split with the same raw pipeline.
raw_train_dataset = CovidDataset(
    volumes=train_paths,
    hrct_transform=get_raw_transforms(),
    cbct_transform=get_raw_transforms(),
)
raw_train_loader = DataLoader(raw_train_dataset, num_workers=2, batch_size=1)

In [12]:
def show_raw_with_mask(dataloader, dataset):
    """Save one tiled PNG per raw volume with the mask blended over the image.

    Batch ``i`` is named after ``dataset.volumes[i]['img']`` (file name up to
    the first dot), so ``dataloader`` must iterate in the same order as
    ``dataset.volumes``. Every entry ``j`` along the first batch axis is
    written to ``images/raw_with_mask/<volume>/<j+1>.png``, one tile per index
    of the last axis.

    Args:
        dataloader: Iterable of dict batches with ``"img"`` and ``"mask"``.
        dataset: Object exposing ``volumes``, a sequence of dicts whose
            ``'img'`` entry is a '/'-separated path string.
    """
    root = "images/raw_with_mask"
    for i, data in enumerate(dataloader):
        volume = dataset.volumes[i]['img'].split('/')[-1].split('.')[0]
        Path(f"{root}/{volume}").mkdir(parents=True, exist_ok=True)
        print(f"Processing volume {i+1}: {volume}")
        for j in range(data["img"].shape[0]):
            print(f"Processing subvolume {j+1}")

            img = data["img"][j]
            seg = data["mask"][j]
            blended_img = blend_images(img, seg, alpha=0.5, cmap="hsv", rescale_arrays=True)
            # Min-max normalise so imshow receives values in [0, 1].
            blended_img = (blended_img - blended_img.min()) / (blended_img.max() - blended_img.min())

            # Smallest square grid that holds every slice; a fixed 16x16
            # layout raises ValueError for any volume with more than 256
            # slices.
            n_slices = blended_img.shape[-1]
            grid = max(1, int(np.ceil(np.sqrt(n_slices))))

            fig = plt.figure(figsize=(8, 8))
            # suptitle labels the figure itself; plt.title would first create
            # a full-size axes that stays visible behind the tiles.
            fig.suptitle(f"subvolume{j+1}")
            for k in range(n_slices):
                ax = fig.add_subplot(grid, grid, k + 1)
                # Move the colour channel last, as imshow expects (H, W, 3).
                ax.imshow(torch.moveaxis(blended_img[:, :, :, k], 0, -1))
                ax.axis('off')
            fig.savefig(f"{root}/{volume}/{j+1}.png")
            plt.close(fig)

In [None]:
# Render the raw training volumes with their masks overlaid.
show_raw_with_mask(
    dataloader=raw_train_loader,
    dataset=raw_train_dataset,
)

[PosixPath('../Datasets/Zenodo/COVID-19-CT-Seg_20cases/radiopaedia_27_86410_0.nii.gz'),
 PosixPath('../Datasets/Zenodo/COVID-19-CT-Seg_20cases/radiopaedia_29_86491_1.nii.gz'),
 PosixPath('../Datasets/Zenodo/COVID-19-CT-Seg_20cases/radiopaedia_14_85914_0.nii.gz'),
 PosixPath('../Datasets/Zenodo/COVID-19-CT-Seg_20cases/radiopaedia_10_85902_3.nii.gz'),
 PosixPath('../Datasets/Zenodo/COVID-19-CT-Seg_20cases/radiopaedia_4_85506_1.nii.gz'),
 PosixPath('../Datasets/Zenodo/COVID-19-CT-Seg_20cases/radiopaedia_36_86526_0.nii.gz'),
 PosixPath('../Datasets/Zenodo/COVID-19-CT-Seg_20cases/radiopaedia_40_86625_0.nii.gz'),
 PosixPath('../Datasets/Zenodo/COVID-19-CT-Seg_20cases/radiopaedia_10_85902_1.nii.gz'),
 PosixPath('../Datasets/Zenodo/COVID-19-CT-Seg_20cases/radiopaedia_7_85703_0.nii.gz'),
 PosixPath('../Datasets/Zenodo/COVID-19-CT-Seg_20cases/radiopaedia_29_86490_1.nii.gz')]