In [1]:
import logging
from preprocessing.covid_dataset import CovidDataset
from monai.data import DataLoader
import monai
from config.constants import (COVID_CASES_PATH, INFECTION_MASKS_PATH, SEED)
import torch
import numpy as np
import matplotlib.pyplot as plt
from monai.visualize import plot_2d_or_3d_image, matshow3d, blend_images
from pathlib import Path
from utils.helpers import load_images_from_path

2024-06-25 12:16:18.826718: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-06-25 12:16:19.158634: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Geometry of the random training patches and how many are drawn per volume.
SPATIAL_SIZE = (32, 32, 32)
NUM_RAND_PATCHES = 4

# HRCT intensity window centred on LEVEL with total WIDTH
# (presumably Hounsfield units, a lung window — TODO confirm).
LEVEL = -650
WIDTH = 1500
LOWER_BOUND_WINDOW_HRCT, UPPER_BOUND_WINDOW_HRCT = (
    LEVEL - WIDTH // 2,
    LEVEL + WIDTH // 2,
)

# Target output range passed as minv/maxv when rescaling CBCT intensities.
LOWER_BOUND_WINDOW_CBCT, UPPER_BOUND_WINDOW_CBCT = 0, 255


def get_hrct_transforms():
    """Build the training-time preprocessing pipeline for HRCT volumes.

    Steps, in order:
      1. load image and mask channel-first,
      2. reorient both to "PLI",
      3. clip image intensities to [LOWER_BOUND_WINDOW_HRCT, UPPER_BOUND_WINDOW_HRCT]
         and map them linearly to [0, 1],
      4. draw NUM_RAND_PATCHES random patches of SPATIAL_SIZE from image and mask,
         with pos=1 / neg=1 relative to the mask,
      5. convert image and mask to tensors.

    Returns:
        monai.transforms.Compose operating on dicts with keys "img" and "mask".
    """
    both = ('img', 'mask')
    load = monai.transforms.LoadImaged(keys=both, image_only=True, ensure_channel_first=True)
    reorient = monai.transforms.Orientationd(keys=both, axcodes="PLI")
    window = monai.transforms.ScaleIntensityRanged(
        keys=('img',),
        a_min=LOWER_BOUND_WINDOW_HRCT,
        a_max=UPPER_BOUND_WINDOW_HRCT,
        b_min=0.0,
        b_max=1.0,
        clip=True,
    )
    crop = monai.transforms.RandCropByPosNegLabeld(
        keys=both,
        label_key="mask",
        spatial_size=SPATIAL_SIZE,
        pos=1,
        neg=1,
        num_samples=NUM_RAND_PATCHES,
        allow_smaller=True,
    )
    to_tensor = monai.transforms.ToTensord(keys=("img", "mask"))
    return monai.transforms.Compose([load, reorient, window, crop, to_tensor])


def get_cbct_transforms():
    """Build the training-time preprocessing pipeline for CBCT volumes.

    Steps, in order:
      1. load image and mask channel-first,
      2. reorient both to "ALI",
      3. linearly rescale the image intensities to
         [LOWER_BOUND_WINDOW_CBCT, UPPER_BOUND_WINDOW_CBCT] using the min/max of
         the whole volume,
      4. draw NUM_RAND_PATCHES random patches of SPATIAL_SIZE from image and mask,
         with pos=1 / neg=1 relative to the mask,
      5. convert image and mask to tensors.

    The intensity rescale runs BEFORE the crop so that it uses whole-volume
    statistics, matching get_val_cbct_transforms (and the HRCT pipeline, which
    windows before cropping). Rescaling after the crop normalizes every patch
    by its own min/max, so identical tissue ends up with different values in
    different patches and a different scale than at validation time.

    Returns:
        monai.transforms.Compose operating on dicts with keys "img" and "mask".
    """
    return monai.transforms.Compose(
        [
            monai.transforms.LoadImaged(keys=('img', 'mask'), image_only=True, ensure_channel_first=True),
            monai.transforms.Orientationd(keys=('img', 'mask'), axcodes="ALI"),
            # Whole-volume rescale (same as validation), done before cropping.
            monai.transforms.ScaleIntensityd(keys='img', minv=LOWER_BOUND_WINDOW_CBCT, maxv=UPPER_BOUND_WINDOW_CBCT),
            monai.transforms.RandCropByPosNegLabeld(keys=('img', 'mask'), label_key="mask",
                                                    spatial_size=SPATIAL_SIZE, pos=1, neg=1,
                                                    num_samples=NUM_RAND_PATCHES, allow_smaller=True),
            monai.transforms.ToTensord(keys=("img", "mask")),
        ]
    )

def get_val_hrct_transforms():
    """Build the deterministic (no cropping) pipeline for HRCT validation/test volumes.

    Loads image and mask channel-first and reorients both to "PLI". It then clips
    image intensities to [LOWER_BOUND_WINDOW_HRCT, UPPER_BOUND_WINDOW_HRCT],
    maps them linearly to [0, 1], and converts image and mask to tensors.

    Returns:
        monai.transforms.Compose operating on dicts with keys "img" and "mask".
    """
    pair = ('img', 'mask')
    steps = [
        monai.transforms.LoadImaged(keys=pair, image_only=True, ensure_channel_first=True),
        monai.transforms.Orientationd(keys=pair, axcodes="PLI"),
        monai.transforms.ScaleIntensityRanged(
            keys=('img',),
            a_min=LOWER_BOUND_WINDOW_HRCT,
            a_max=UPPER_BOUND_WINDOW_HRCT,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        monai.transforms.ToTensord(keys=("img", "mask")),
    ]
    return monai.transforms.Compose(steps)


def get_val_cbct_transforms():
    """Build the deterministic (no cropping) pipeline for CBCT validation/test volumes.

    Loads image and mask channel-first and reorients both to "ALI". It then
    linearly rescales the image intensities to
    [LOWER_BOUND_WINDOW_CBCT, UPPER_BOUND_WINDOW_CBCT] and converts image and
    mask to tensors.

    Returns:
        monai.transforms.Compose operating on dicts with keys "img" and "mask".
    """
    pair = ('img', 'mask')
    load = monai.transforms.LoadImaged(keys=pair, image_only=True, ensure_channel_first=True)
    reorient = monai.transforms.Orientationd(keys=pair, axcodes="ALI")
    rescale = monai.transforms.ScaleIntensityd(
        keys='img',
        minv=LOWER_BOUND_WINDOW_CBCT,
        maxv=UPPER_BOUND_WINDOW_CBCT,
    )
    to_tensor = monai.transforms.ToTensord(keys=("img", "mask"))
    return monai.transforms.Compose([load, reorient, rescale, to_tensor])


# Load images and masks. They are paired purely by position below, so
# load_images_from_path is assumed to return both lists in matching order
# (TODO confirm it sorts filenames).
logging.info(f"Loading images from {COVID_CASES_PATH}")
images = load_images_from_path("../" + COVID_CASES_PATH)
labels = load_images_from_path("../" + INFECTION_MASKS_PATH)

# zip() would silently drop unpaired trailing entries; fail loudly instead.
if len(images) != len(labels):
    raise ValueError(
        f"Found {len(images)} images but {len(labels)} masks; "
        "every volume needs exactly one infection mask."
    )

# Convert images and masks to a list of dictionaries with keys "img" and "mask"
data_dicts = np.array([{"img": img, "mask": mask} for img, mask in zip(images, labels)])
logging.debug(data_dicts)

# Fixed seed so train/val/test membership is identical on every run.
# NOTE(review): this is hard-coded to 1 rather than the imported SEED; the value is
# kept to preserve the existing split — TODO decide whether to switch to SEED.
SPLIT_SEED = 1
shuffler = np.random.RandomState(SPLIT_SEED)
shuffler.shuffle(data_dicts)
data_dicts = list(data_dicts)

# Seed torch / numpy / random through MONAI so the random patch crops used by the
# training pipelines (and shown in the visualizations below) repeat on a fresh kernel.
monai.utils.set_determinism(seed=SEED)

# Split the data into training (70%), validation (10%), and test sets (20%)
test_split = int(len(data_dicts) * 0.2)
val_split = int(len(data_dicts) * 0.1)

train_paths = data_dicts[test_split + val_split:]
val_paths = data_dicts[test_split:test_split + val_split]
test_paths = data_dicts[:test_split]

# int() truncation empties the val/test splits for small datasets (< 10 volumes for val).
for split_name, split_paths in (("train", train_paths), ("validation", val_paths), ("test", test_paths)):
    if not split_paths:
        logging.warning(f"The {split_name} split is empty ({len(data_dicts)} volumes in total).")

In [3]:
# Training volumes use the random-patch pipelines; validation and test volumes
# use the deterministic whole-volume pipelines. Every loader yields one volume
# per batch and loads with two worker processes.
train_dataset = CovidDataset(
    volumes=train_paths,
    hrct_transform=get_hrct_transforms(),
    cbct_transform=get_cbct_transforms(),
)
train_loader = DataLoader(train_dataset, num_workers=2, batch_size=1)

val_dataset = CovidDataset(
    volumes=val_paths,
    hrct_transform=get_val_hrct_transforms(),
    cbct_transform=get_val_cbct_transforms(),
)
val_loader = DataLoader(val_dataset, num_workers=2, batch_size=1)

test_dataset = CovidDataset(
    volumes=test_paths,
    hrct_transform=get_val_hrct_transforms(),
    cbct_transform=get_val_cbct_transforms(),
)
test_loader = DataLoader(test_dataset, num_workers=2, batch_size=1)

# SHOW DATASET

In [9]:
def show_train_post_transforms_with_mask(dataloader, dataset, root="images/train_post_transforms_with_mask"):
    """Save one PNG per training patch, showing every slice with the mask blended on top.

    Batch ``i`` of ``dataloader`` is named after ``dataset.volumes[i]["img"]``. This
    assumes a non-shuffling loader with ``batch_size=1``, as built in this notebook.
    Each entry ``j`` along the batch axis (one random patch per entry for the
    training loader) is blended with its mask via ``blend_images``. It is then drawn
    as a grid of slices along the last axis and written to
    ``<root>/<volume>/<j+1>.png``.

    Args:
        dataloader: iterable of dicts with "img" and "mask" batches; each entry is
            indexed as (C, H, W, D) below — TODO confirm C == 1.
        dataset: object exposing ``volumes[i]["img"]``, the path of volume ``i``.
        root: output directory (defaults to the original location).
    """
    for i, data in enumerate(dataloader):
        # Path(...).name handles str/Path inputs and any OS separator; split('.')[0]
        # strips multi-part suffixes such as ".nii.gz".
        volume = Path(dataset.volumes[i]['img']).name.split('.')[0]
        Path(f"{root}/{volume}").mkdir(parents=True, exist_ok=True)
        print(f"Processing volume {i+1}: {volume}")
        for j in range(data["img"].shape[0]):
            print(f"Processing subvolume {j+1}")

            img = data["img"][j]
            seg = data["mask"][j]
            blended_img = blend_images(img, seg, alpha=0.5, cmap="hsv", rescale_arrays=True)
            # Min-max stretch to [0, 1]. A constant patch has zero range; dividing by
            # it produced NaNs ("invalid value encountered in cast"), so skip the division.
            lo = blended_img.min()
            span = blended_img.max() - lo
            blended_img = (blended_img - lo) / span if span > 0 else blended_img - lo

            n_slices = blended_img.shape[-1]
            # Near-square grid sized to the slice count. A fixed 16x16 grid raises for
            # > 256 slices and leaves most of the figure empty for small patches.
            n_cols = max(1, int(np.ceil(np.sqrt(n_slices))))
            n_rows = max(1, int(np.ceil(n_slices / n_cols)))
            fig = plt.figure(figsize=(8, 8))
            # suptitle instead of plt.title: plt.title creates an extra full-figure
            # axes (with ticks) underneath the slice grid.
            fig.suptitle(f"subvolume{j+1}")
            for k in range(n_slices):
                ax = fig.add_subplot(n_rows, n_cols, k + 1)
                # (3, H, W) RGB slice -> (H, W, 3) for imshow.
                ax.imshow(torch.moveaxis(blended_img[:, :, :, k], 0, -1))
                ax.axis('off')
            fig.savefig(f"{root}/{volume}/{j+1}.png")
            plt.close(fig)
             
def show_validation_post_transforms_no_mask(dataloader, dataset, root="images/validation_post_transforms_no_mask"):
    """Save a grayscale slice montage of each volume (image only, no mask overlay).

    Batch ``i`` of ``dataloader`` is named after ``dataset.volumes[i]["img"]``. This
    assumes a non-shuffling loader with ``batch_size=1``, as built in this notebook.
    The first batch entry's first channel is rendered with ``matshow3d``, taking
    frames along the last axis. It is written to ``<root>/<volume>/<volume>.png``.

    Args:
        dataloader: iterable of dicts with an "img" batch.
        dataset: object exposing ``volumes[i]["img"]``, the path of volume ``i``.
        root: output directory (defaults to the original location).
    """
    for i, data in enumerate(dataloader):
        # Path(...).name handles str/Path inputs and any OS separator; split('.')[0]
        # strips multi-part suffixes such as ".nii.gz".
        volume = Path(dataset.volumes[i]['img']).name.split('.')[0]
        Path(f"{root}/{volume}").mkdir(parents=True, exist_ok=True)
        # 1-based, consistent with the other visualization helpers.
        print(f"Processing volume {i+1}: {volume}")
        fig = plt.figure()
        matshow3d(
            data["img"][0][0],  # first batch entry, first channel
            fig=fig,
            figsize=(20, 20),
            frame_dim=-1,
            cmap="gray",
        )
        fig.savefig(f"{root}/{volume}/{volume}.png")
        plt.close(fig)

In [5]:
# Render each validation volume (after the validation transforms) as a grayscale slice montage.
show_validation_post_transforms_no_mask(dataloader=val_loader, dataset=val_dataset)

os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.


Processing volume 0: coronacases_003
Processing volume 1: radiopaedia_29_86490_1


os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.


In [11]:
show_train_post_transforms_with_mask(dataloader=train_loader, dataset=train_dataset)  # mask-blended montages of each random training patch

Processing volume 1: coronacases_005
Processing subvolume 1
Processing subvolume 2
Processing subvolume 3
Processing subvolume 4
Processing volume 2: radiopaedia_40_86625_0
Processing subvolume 1
Processing subvolume 2
Processing subvolume 3
Processing subvolume 4
Processing volume 3: coronacases_008
Processing subvolume 1
Processing subvolume 2
Processing subvolume 3
Processing subvolume 4
Processing volume 4: coronacases_002
Processing subvolume 1
Processing subvolume 2
Processing subvolume 3
Processing subvolume 4
Processing volume 5: radiopaedia_27_86410_0
Processing subvolume 1
Processing subvolume 2
Processing subvolume 3
Processing subvolume 4
Processing volume 6: coronacases_001
Processing subvolume 1
Processing subvolume 2
Processing subvolume 3
Processing subvolume 4
Processing volume 7: radiopaedia_7_85703_0
Processing subvolume 1
Processing subvolume 2
Processing subvolume 3
Processing subvolume 4


invalid value encountered in cast


Processing volume 8: radiopaedia_4_85506_1
Processing subvolume 1
Processing subvolume 2
Processing subvolume 3
Processing subvolume 4
Processing volume 9: coronacases_010
Processing subvolume 1
Processing subvolume 2
Processing subvolume 3
Processing subvolume 4
Processing volume 10: radiopaedia_29_86491_1
Processing subvolume 1
Processing subvolume 2
Processing subvolume 3
Processing subvolume 4
Processing volume 11: coronacases_009
Processing subvolume 1
Processing subvolume 2
Processing subvolume 3
Processing subvolume 4
Processing volume 12: radiopaedia_14_85914_0
Processing subvolume 1
Processing subvolume 2
Processing subvolume 3
Processing subvolume 4
Processing volume 13: radiopaedia_10_85902_3
Processing subvolume 1
Processing subvolume 2
Processing subvolume 3
Processing subvolume 4
Processing volume 14: coronacases_006
Processing subvolume 1
Processing subvolume 2
Processing subvolume 3
Processing subvolume 4


In [8]:
# Loaders for visualizing the volumes "raw": images and masks are only loaded and
# reoriented, with no intensity windowing/rescaling and no cropping.
def get_raw_transforms(axcodes="PLI"):
    """Build a minimal pipeline: load channel-first, reorient, convert to tensors.

    Args:
        axcodes: target orientation passed to ``Orientationd``. The default "PLI"
            matches the HRCT pipelines; pass "ALI" for CBCT so raw views line up
            with get_cbct_transforms / get_val_cbct_transforms.

    Returns:
        monai.transforms.Compose operating on dicts with keys "img" and "mask".
    """
    return monai.transforms.Compose(
        [
            monai.transforms.LoadImaged(keys=('img', 'mask'), image_only=True, ensure_channel_first=True),
            monai.transforms.Orientationd(keys=('img', 'mask'), axcodes=axcodes),
            monai.transforms.ToTensord(keys=("img", "mask")),
        ]
    )

# CBCT volumes previously also used "PLI" here, i.e. a different orientation than
# every other CBCT pipeline in this notebook; use "ALI" for them.
raw_val_dataset = CovidDataset(volumes=val_paths, hrct_transform=get_raw_transforms("PLI"), cbct_transform=get_raw_transforms("ALI"))
raw_val_loader = DataLoader(raw_val_dataset, batch_size=1, num_workers=2)

raw_train_dataset = CovidDataset(volumes=train_paths, hrct_transform=get_raw_transforms("PLI"), cbct_transform=get_raw_transforms("ALI"))
raw_train_loader = DataLoader(raw_train_dataset, batch_size=1, num_workers=2)

In [12]:
def show_raw_with_mask(dataloader, dataset, root="images/raw_with_mask"):
    """Save one PNG per batch entry, showing every slice with the mask blended on top.

    Intended for the raw loaders (``batch_size=1``, no cropping), where each batch
    holds a single full volume. Batch ``i`` is named after
    ``dataset.volumes[i]["img"]``, which assumes a non-shuffling loader. Each entry
    ``j`` is blended with its mask via ``blend_images``, drawn as a grid of slices
    along the last axis, and written to ``<root>/<volume>/<j+1>.png``.

    Args:
        dataloader: iterable of dicts with "img" and "mask" batches; each entry is
            indexed as (C, H, W, D) below — TODO confirm C == 1.
        dataset: object exposing ``volumes[i]["img"]``, the path of volume ``i``.
        root: output directory (defaults to the original location).
    """
    for i, data in enumerate(dataloader):
        # Path(...).name handles str/Path inputs and any OS separator; split('.')[0]
        # strips multi-part suffixes such as ".nii.gz".
        volume = Path(dataset.volumes[i]['img']).name.split('.')[0]
        Path(f"{root}/{volume}").mkdir(parents=True, exist_ok=True)
        print(f"Processing volume {i+1}: {volume}")
        for j in range(data["img"].shape[0]):
            print(f"Processing subvolume {j+1}")

            img = data["img"][j]
            seg = data["mask"][j]
            blended_img = blend_images(img, seg, alpha=0.5, cmap="hsv", rescale_arrays=True)
            # Min-max stretch to [0, 1], skipping the division when the range is
            # zero (it would otherwise produce NaNs).
            lo = blended_img.min()
            span = blended_img.max() - lo
            blended_img = (blended_img - lo) / span if span > 0 else blended_img - lo

            n_slices = blended_img.shape[-1]
            # Near-square grid sized to the slice count. Full raw volumes can have
            # more than 256 slices, which a fixed 16x16 grid cannot hold
            # (add_subplot raises).
            n_cols = max(1, int(np.ceil(np.sqrt(n_slices))))
            n_rows = max(1, int(np.ceil(n_slices / n_cols)))
            fig = plt.figure(figsize=(8, 8))
            # suptitle instead of plt.title: plt.title creates an extra full-figure
            # axes (with ticks) underneath the slice grid.
            fig.suptitle(f"subvolume{j+1}")
            for k in range(n_slices):
                ax = fig.add_subplot(n_rows, n_cols, k + 1)
                # (3, H, W) RGB slice -> (H, W, 3) for imshow.
                ax.imshow(torch.moveaxis(blended_img[:, :, :, k], 0, -1))
                ax.axis('off')
            fig.savefig(f"{root}/{volume}/{j+1}.png")
            plt.close(fig)

In [None]:
show_raw_with_mask(dataloader=raw_train_loader, dataset=raw_train_dataset)  # mask-blended montages of the untransformed training volumes