In [1]:
import logging
from preprocessing.covid_dataset import CovidDataset
from monai.data import DataLoader
import monai
from config.constants import (COVID_CASES_PATH, INFECTION_MASKS_PATH, SEED)
import torch
import numpy as np
import matplotlib.pyplot as plt
from monai.visualize import plot_2d_or_3d_image, matshow3d, blend_images
from pathlib import Path

2024-06-18 13:39:58.165493: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-06-18 13:39:58.209186: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
def load_images_from_path(path: str) -> list[str]:
    """Return the sorted paths of the ``.gz`` files located directly in ``path``.

    Only regular files whose last suffix is exactly ``.gz`` are kept (for
    example ``scan.nii.gz``); sub-directories are not descended into.

    Args:
        path: Directory to scan.

    Returns:
        The matching file paths as strings, sorted in ascending order.
    """
    logging.debug(f"Loading images from {path}")
    found = []
    for entry in Path(path).iterdir():
        if not entry.is_file():
            continue
        if entry.suffix == '.gz':
            found.append(str(entry))
    found.sort()
    return found

# List the CT volume files found in the cases folder (path is relative to this notebook).
images = load_images_from_path(f"../{COVID_CASES_PATH}")
print(images)

['../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_001.nii.gz', '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_002.nii.gz', '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_003.nii.gz', '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_004.nii.gz', '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_005.nii.gz', '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_006.nii.gz', '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_007.nii.gz', '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_008.nii.gz', '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_009.nii.gz', '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_010.nii.gz', '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/radiopaedia_10_85902_1.nii.gz', '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/radiopaedia_10_85902_3.nii.gz', '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/radiopaedia_14_85914_0.nii.gz', '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/radiopaedia_27_86410_0.nii.gz', '../D

In [3]:
# Random-patch sampling: size of every crop and number of crops drawn per volume.
SPATIAL_SIZE = (32,) * 3
NUM_RAND_PATCHES = 4

# HRCT intensity window, expressed as its centre (LEVEL) and total WIDTH;
# the bounds are the centre minus/plus half of the width.
LEVEL, WIDTH = -650, 1500
LOWER_BOUND_WINDOW_HRCT, UPPER_BOUND_WINDOW_HRCT = LEVEL - WIDTH // 2, LEVEL + WIDTH // 2

# Target range used when rescaling CBCT intensities.
LOWER_BOUND_WINDOW_CBCT, UPPER_BOUND_WINDOW_CBCT = 0, 255


def get_hrct_transforms():
    """Build the training pipeline for HRCT volumes.

    Steps, in order:
      1. load image and mask with a leading channel axis,
      2. reorient both to "PLI",
      3. map image intensities from
         [LOWER_BOUND_WINDOW_HRCT, UPPER_BOUND_WINDOW_HRCT] to [0, 1],
         clipping values outside the window,
      4. draw NUM_RAND_PATCHES random crops of SPATIAL_SIZE, with positive
         and negative mask voxels weighted equally (pos=1, neg=1),
      5. convert both entries to tensors.

    Returns:
        monai.transforms.Compose: the composed dictionary transform.
    """
    T = monai.transforms
    load = T.LoadImaged(keys=('img', 'mask'), image_only=True, ensure_channel_first=True)
    orient = T.Orientationd(keys=('img', 'mask'), axcodes="PLI")
    window = T.ScaleIntensityRanged(
        keys=('img',),
        a_min=LOWER_BOUND_WINDOW_HRCT,
        a_max=UPPER_BOUND_WINDOW_HRCT,
        b_min=0.0,
        b_max=1.0,
        clip=True,
    )
    crop = T.RandCropByPosNegLabeld(
        keys=('img', 'mask'),
        label_key="mask",
        spatial_size=SPATIAL_SIZE,
        pos=1,
        neg=1,
        num_samples=NUM_RAND_PATCHES,
        allow_smaller=True,
    )
    to_tensor = T.ToTensord(keys=("img", "mask"))
    return T.Compose([load, orient, window, crop, to_tensor])


def get_cbct_transforms():
    """Build the training pipeline for CBCT volumes.

    Steps, in order:
      1. load image and mask with a leading channel axis,
      2. reorient both to "ALI",
      3. rescale the image intensities of the whole volume to
         [LOWER_BOUND_WINDOW_CBCT, UPPER_BOUND_WINDOW_CBCT],
      4. draw NUM_RAND_PATCHES random crops of SPATIAL_SIZE, with positive
         and negative mask voxels weighted equally (pos=1, neg=1),
      5. convert both entries to tensors.

    Returns:
        monai.transforms.Compose: the composed dictionary transform.
    """
    keys = ('img', 'mask')
    return monai.transforms.Compose(
        [
            monai.transforms.LoadImaged(keys=keys, image_only=True, ensure_channel_first=True),
            monai.transforms.Orientationd(keys=keys, axcodes="ALI"),
            # ScaleIntensityd rescales using the min/max of the array it receives,
            # so it must see the whole volume. Running it after the crop stretched
            # every patch by its own min/max, which made patch intensities
            # incomparable and disagreed with the whole-volume scaling used by the
            # validation CBCT pipeline.
            monai.transforms.ScaleIntensityd(keys='img', minv=LOWER_BOUND_WINDOW_CBCT, maxv=UPPER_BOUND_WINDOW_CBCT),
            monai.transforms.RandCropByPosNegLabeld(keys=keys, label_key="mask",
                                                    spatial_size=SPATIAL_SIZE, pos=1, neg=1,
                                                    num_samples=NUM_RAND_PATCHES, allow_smaller=True),
            monai.transforms.ToTensord(keys=keys),
        ]
    )

def get_val_hrct_transforms():
    """Build the validation/test pipeline for HRCT volumes (no random cropping).

    Loads image and mask with a leading channel axis, reorients both to
    "PLI", maps image intensities from
    [LOWER_BOUND_WINDOW_HRCT, UPPER_BOUND_WINDOW_HRCT] to [0, 1] with
    clipping, and converts both entries to tensors.

    Returns:
        monai.transforms.Compose: the composed dictionary transform.
    """
    T = monai.transforms
    steps = []
    steps.append(T.LoadImaged(keys=('img', 'mask'), image_only=True, ensure_channel_first=True))
    steps.append(T.Orientationd(keys=('img', 'mask'), axcodes="PLI"))
    steps.append(
        T.ScaleIntensityRanged(
            keys=('img',),
            a_min=LOWER_BOUND_WINDOW_HRCT,
            a_max=UPPER_BOUND_WINDOW_HRCT,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        )
    )
    steps.append(T.ToTensord(keys=("img", "mask")))
    return T.Compose(steps)


def get_val_cbct_transforms():
    """Build the validation/test pipeline for CBCT volumes (no random cropping).

    Loads image and mask with a leading channel axis, reorients both to
    "ALI", rescales the image intensities to
    [LOWER_BOUND_WINDOW_CBCT, UPPER_BOUND_WINDOW_CBCT], and converts both
    entries to tensors.

    Returns:
        monai.transforms.Compose: the composed dictionary transform.
    """
    T = monai.transforms
    load = T.LoadImaged(keys=('img', 'mask'), image_only=True, ensure_channel_first=True)
    orient = T.Orientationd(keys=('img', 'mask'), axcodes="ALI")
    rescale = T.ScaleIntensityd(keys='img', minv=LOWER_BOUND_WINDOW_CBCT, maxv=UPPER_BOUND_WINDOW_CBCT)
    to_tensor = T.ToTensord(keys=("img", "mask"))
    return T.Compose([load, orient, rescale, to_tensor])


# Load the CT volumes and their infection masks; both lists come back sorted by path.
logging.info(f"Loading images from {COVID_CASES_PATH}")
images = load_images_from_path("../" + COVID_CASES_PATH)
labels = load_images_from_path("../" + INFECTION_MASKS_PATH)

# zip() silently stops at the shorter list, which would drop volumes without
# any warning if a file is missing from one of the folders. Fail loudly instead.
if len(images) != len(labels):
    raise ValueError(
        f"Found {len(images)} images but {len(labels)} masks; every image needs exactly one mask"
    )

# Pair every image with its mask as a dictionary with keys "img" and "mask"
data_dicts = np.array([{"img": img, "mask": mask} for img, mask in zip(images, labels)])
logging.debug(data_dicts)

print(data_dicts)
# Shuffle in place with a fixed seed so the split below is reproducible.
# NOTE(review): SEED is imported and printed below, but the shuffle is seeded
# with the literal 1 — confirm which value is intended before changing it,
# because a different seed produces a different train/val/test split.
shuffler = np.random.RandomState(1)
shuffler.shuffle(data_dicts)
data_dicts = list(data_dicts)
print()
print(type(data_dicts))
print()
print(data_dicts)
print(SEED)

# Split the data into training (70%), validation (10%), and test sets (20%)
test_split = int(len(data_dicts) * 0.2)
val_split = int(len(data_dicts) * 0.1)

# The shuffled list is laid out as [test | validation | train].
train_paths = data_dicts[test_split + val_split:]
val_paths = data_dicts[test_split:test_split + val_split]
test_paths = data_dicts[:test_split]

[{'img': '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_001.nii.gz', 'mask': '../Datasets/Zenodo/Infection_Mask/coronacases_001.nii.gz'}
 {'img': '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_002.nii.gz', 'mask': '../Datasets/Zenodo/Infection_Mask/coronacases_002.nii.gz'}
 {'img': '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_003.nii.gz', 'mask': '../Datasets/Zenodo/Infection_Mask/coronacases_003.nii.gz'}
 {'img': '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_004.nii.gz', 'mask': '../Datasets/Zenodo/Infection_Mask/coronacases_004.nii.gz'}
 {'img': '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_005.nii.gz', 'mask': '../Datasets/Zenodo/Infection_Mask/coronacases_005.nii.gz'}
 {'img': '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_006.nii.gz', 'mask': '../Datasets/Zenodo/Infection_Mask/coronacases_006.nii.gz'}
 {'img': '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_007.nii.gz', 'mask': '../Datasets/Zenodo/Infection_Mask/cor

In [4]:
# Loader settings shared by the three splits.
loader_kwargs = dict(batch_size=1, num_workers=2)

# Training split: uses the random-patch pipelines.
train_dataset = CovidDataset(
    volumes=train_paths,
    hrct_transform=get_hrct_transforms(),
    cbct_transform=get_cbct_transforms(),
)
train_loader = DataLoader(train_dataset, **loader_kwargs)

# Validation split: uses the pipelines without random cropping.
val_dataset = CovidDataset(
    volumes=val_paths,
    hrct_transform=get_val_hrct_transforms(),
    cbct_transform=get_val_cbct_transforms(),
)
val_loader = DataLoader(val_dataset, **loader_kwargs)

# Test split: same pipelines as validation.
test_dataset = CovidDataset(
    volumes=test_paths,
    hrct_transform=get_val_hrct_transforms(),
    cbct_transform=get_val_cbct_transforms(),
)
test_loader = DataLoader(test_dataset, **loader_kwargs)

In [None]:
def check_dataset(dataset):
    """Validate that a dataset has usable masks and covers both data sources.

    Every sample is loaded once, so the dataset's transforms are executed for
    each volume.

    Args:
        dataset: Indexable dataset whose samples are dicts with a "mask"
            entry, and which exposes ``volumes``: a sequence of dicts whose
            "img" value is the path of the corresponding volume.

    Raises:
        ValueError: If a non-empty mask contains only zeros, or if no volume
            path contains "radiopaedia" or none contains "coronacases".
    """
    found = {"radiopaedia": False, "coronacases": False}
    for i, data in enumerate(dataset):
        img_path = dataset.volumes[i]["img"]

        # One pass over the mask instead of sorting it twice with np.unique.
        mask = np.asarray(data["mask"])
        if mask.size > 0 and not mask.any():
            raise ValueError(f"The mask has only zeros (volume {i}: {img_path})")

        for source in found:
            if source in img_path:
                found[source] = True

    # Report exactly which source is absent rather than a generic message.
    missing = [source for source, present in found.items() if not present]
    if missing:
        raise ValueError(f"The dataset contains no volumes from: {', '.join(missing)}")

In [None]:
# Sanity-check the validation set; check_dataset raises ValueError if it finds a problem.
check_dataset(val_dataset)

In [None]:
# Print the image shape of the first two validation samples.
for sample_idx in range(2):
    print(val_dataset[sample_idx]["img"].shape)

In [None]:
# Display the volume entries (image/mask paths) backing the validation dataset.
val_dataset.volumes

# SHOW DATASET

In [6]:
def draw_volumes(dataloader, dataset):
    """Save a mask-overlay contact sheet for every sub-volume of every batch.

    For batch ``i`` the output folder ``images/<volume>/`` is named after the
    file name of ``dataset.volumes[i]['img']`` (up to its first dot), so the
    loader must yield volumes in the same order as ``dataset.volumes``.
    Each entry ``j`` along the first axis of ``data["img"]`` is blended with
    the matching entry of ``data["mask"]``; every slice along the last axis is
    drawn in one tile of a square grid and the figure is saved as
    ``images/<volume>/<j+1>.png``.

    Args:
        dataloader: Iterable of dicts with "img" and "mask" entries.
            Assumes tensors laid out as (sub-volume, channel, H, W, D) —
            TODO confirm against CovidDataset.
        dataset: Object exposing ``volumes``, a sequence of dicts whose "img"
            value is a '/'-separated path.
    """
    # 16x16 is the layout this function has always used; it is only enlarged
    # when a volume has more than 256 slices, which previously raised a
    # ValueError in add_subplot.
    min_grid = 16
    for i, data in enumerate(dataloader):
        volume = dataset.volumes[i]['img'].split('/')[-1].split('.')[0]
        Path(f"images/{volume}").mkdir(parents=True, exist_ok=True)
        print(f"Processing volume {i}: {volume}")
        print(f"img shape: {data['img'].shape}")
        for j in range(data["img"].shape[0]):
            print(f"Processing subvolume {j+1}")

            img = data["img"][j]
            seg = data["mask"][j]
            print(f"img shape: {img.shape}")
            blended_img = blend_images(img, seg, alpha=0.5, cmap="hsv", rescale_arrays=True)

            # Min-max normalise for imshow. A constant image has max == min,
            # so skip the division there instead of producing NaNs (0 / 0).
            lowest, highest = blended_img.min(), blended_img.max()
            blended_img = blended_img - lowest
            if highest > lowest:
                blended_img = blended_img / (highest - lowest)
            print(f"blended_img shape: {blended_img.shape}")

            num_slices = blended_img.shape[-1]
            grid = max(min_grid, int(np.ceil(np.sqrt(num_slices))))

            fig = plt.figure(figsize=(30, 30))
            # suptitle labels the figure itself; plt.title would first create a
            # stray full-figure axes that stays visible behind the slice tiles.
            # The index is 1-based to match the log line and the file name.
            fig.suptitle(f"subvolume{j+1}")
            for k in range(num_slices):
                ax = fig.add_subplot(grid, grid, k + 1)
                # imshow expects the colour channel last.
                ax.imshow(torch.moveaxis(blended_img[:, :, :, k], 0, -1))
                ax.axis('off')
            fig.savefig(f"images/{volume}/{j+1}.png")
            # Close this figure explicitly so figures do not pile up in memory.
            plt.close(fig)
             
def draw_volumes2(dataloader, dataset):
    """Save a grayscale slice montage for the first image of every batch.

    For batch ``i`` the output folder ``images/<volume>/`` is named after the
    file name of ``dataset.volumes[i]['img']`` (up to its first dot). The
    array ``batch["img"][0][0]`` is rendered with ``matshow3d`` using its
    last axis as the frame axis, and saved as ``images/<volume>/<i>.png``.

    Args:
        dataloader: Iterable of dicts with an "img" entry.
        dataset: Object exposing ``volumes``, a sequence of dicts whose "img"
            value is a '/'-separated path.
    """
    for i, batch in enumerate(dataloader):
        file_name = dataset.volumes[i]['img'].split('/')[-1]
        volume = file_name.split('.')[0]
        out_dir = Path(f"images/{volume}")
        out_dir.mkdir(parents=True, exist_ok=True)

        print(f"Processing volume {i}: {volume}")
        print(f"img shape: {batch['img'].shape}")

        fig = plt.figure()
        image_batch = batch["img"]
        print(image_batch.shape)
        matshow3d(image_batch[0][0], fig=fig, figsize=(20, 20), frame_dim=-1, cmap="gray")

        plt.savefig(f"images/{volume}/{i}.png")
        plt.close()

In [7]:
# Save a slice montage for every validation volume under images/<volume>/.
draw_volumes2(val_loader, val_dataset)

suuuusuuuu2

volume: {'img': '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/radiopaedia_29_86490_1.nii.gz', 'mask': '../Datasets/Zenodo/Infection_Mask/radiopaedia_29_86490_1.nii.gz'}volume: {'img': '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_003.nii.gz', 'mask': '../Datasets/Zenodo/Infection_Mask/coronacases_003.nii.gz'}

Processing volume 0: coronacases_003
img shape: torch.Size([1, 1, 512, 512, 200])
torch.Size([1, 1, 512, 512, 200])
Processing volume 1: radiopaedia_29_86490_1
img shape: torch.Size([1, 1, 630, 630, 42])
torch.Size([1, 1, 630, 630, 42])


In [None]:
# Environment check: display the installed matplotlib version.
import matplotlib
matplotlib.__version__

In [5]:
# Print the image and mask shapes of the first validation batch only.
for batch in val_loader:
    for key in ("img", "mask"):
        print(batch[key].shape)
    break

os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.


suuuu
volume: {'img': '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_003.nii.gz', 'mask': '../Datasets/Zenodo/Infection_Mask/coronacases_003.nii.gz'}
suuuu2
volume: {'img': '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/radiopaedia_29_86490_1.nii.gz', 'mask': '../Datasets/Zenodo/Infection_Mask/radiopaedia_29_86490_1.nii.gz'}


os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.


torch.Size([1, 1, 512, 512, 200])
torch.Size([1, 1, 512, 512, 200])
