In [15]:
import logging
from monai.data import DataLoader
import monai
from config.constants import (ZENODO_COVID_CASES_PATH, ZENODO_INFECTION_MASKS_PATH)
import torch
import numpy as np
import matplotlib.pyplot as plt
from monai.visualize import plot_2d_or_3d_image, matshow3d, blend_images
from pathlib import Path
from utils.helpers import load_images_from_path

In [16]:
# Patch sampling settings consumed by the random-crop transforms defined below.
SPATIAL_SIZE = (32, 32, 32)
NUM_RAND_PATCHES = 4

# HRCT intensity window, given as centre (LEVEL) and full width (WIDTH).
# The bounds sit half a width below and above the centre.
# NOTE(review): presumably Hounsfield units -- confirm against the data.
LEVEL = -650
WIDTH = 1500
LOWER_BOUND_WINDOW_HRCT = LEVEL - WIDTH // 2
UPPER_BOUND_WINDOW_HRCT = LEVEL + WIDTH // 2

# Intensity bounds declared for the CBCT-style volumes.
LOWER_BOUND_WINDOW_CBCT, UPPER_BOUND_WINDOW_CBCT = 0, 255

SEED = 33

# Echo the derived HRCT window so the cell output records it.
print(LOWER_BOUND_WINDOW_HRCT, UPPER_BOUND_WINDOW_HRCT, sep="\n")
def get_hrct_transforms():
    """Build the training pipeline for HRCT samples.

    The pipeline loads "img" and "mask" channel-first, reorients both to
    "PLI", clamps the image to the HRCT window, rescales that window to
    [0, 1] and finally draws NUM_RAND_PATCHES random crops of SPATIAL_SIZE
    per volume (pos=1, neg=1: foreground- and background-centred crops are
    equally weighted).

    Returns:
        monai.transforms.Compose: dictionary transform for samples holding
        "img" and "mask" entries.
    """
    return monai.transforms.Compose(
        [
            monai.transforms.LoadImaged(keys=('img', 'mask'), image_only=True, ensure_channel_first=True),
            monai.transforms.Orientationd(keys=('img', 'mask'), axcodes="PLI"),
            # ThresholdIntensityd keeps the voxels on the `above` side of the
            # threshold and overwrites the rest with `cval`. The lower clamp
            # therefore needs above=True and the upper clamp above=False, as in
            # get_val_hrct_transforms. The flags used to be swapped, which
            # flattened the whole image to a single value.
            monai.transforms.ThresholdIntensityd(keys=("img",), threshold=LOWER_BOUND_WINDOW_HRCT, above=True, cval=LOWER_BOUND_WINDOW_HRCT),
            monai.transforms.ThresholdIntensityd(keys=("img",), threshold=UPPER_BOUND_WINDOW_HRCT, above=False, cval=UPPER_BOUND_WINDOW_HRCT),
            # Map the clamped window linearly onto [0, 1].
            monai.transforms.ScaleIntensityRanged(keys=('img',), a_min=LOWER_BOUND_WINDOW_HRCT,
                                                  a_max=UPPER_BOUND_WINDOW_HRCT, b_min=0.0, b_max=1.0, clip=True, allow_missing_keys=True),
            monai.transforms.RandCropByPosNegLabeld(keys=('img', 'mask'), label_key="mask",
                                                    spatial_size=SPATIAL_SIZE, pos=1, neg=1,
                                                    num_samples=NUM_RAND_PATCHES, allow_smaller=True),
            monai.transforms.ToTensord(keys=("img", "mask")),
        ]
    )


def get_cbct_transforms():
    """Build the training pipeline for the non-HRCT ("CBCT") samples.

    Steps, in order: load "img" and "mask" channel-first, reorient both to
    "ALI", draw NUM_RAND_PATCHES random crops of SPATIAL_SIZE guided by the
    mask, rescale the image to [0, 1] and convert both entries to tensors.

    NOTE(review): the intensity scaling runs after the crop, so it is applied
    per patch rather than per volume -- confirm this is intended.

    Returns:
        monai.transforms.Compose: dictionary transform for samples holding
        "img" and "mask" entries.
    """
    pair_keys = ('img', 'mask')
    steps = [
        monai.transforms.LoadImaged(keys=pair_keys, image_only=True, ensure_channel_first=True),
        monai.transforms.Orientationd(keys=pair_keys, axcodes="ALI"),
        monai.transforms.RandCropByPosNegLabeld(
            keys=pair_keys,
            label_key="mask",
            spatial_size=SPATIAL_SIZE,
            pos=1,
            neg=1,
            num_samples=NUM_RAND_PATCHES,
            allow_smaller=True,
        ),
        monai.transforms.ScaleIntensityd(keys='img', minv=0.0, maxv=1.0),
        monai.transforms.ToTensord(keys=("img", "mask")),
    ]
    return monai.transforms.Compose(steps)

def get_val_hrct_transforms():
    """Build the validation pipeline for HRCT samples (whole volume, no cropping).

    The image is clamped to the HRCT window before being rescaled to [0, 1];
    the mask is only loaded, reoriented and converted to a tensor.

    Returns:
        monai.transforms.Compose: dictionary transform for samples holding
        "img" and "mask" entries.
    """
    both = ('img', 'mask')
    load = monai.transforms.LoadImaged(keys=both, image_only=True, ensure_channel_first=True)
    orient = monai.transforms.Orientationd(keys=both, axcodes="PLI")
    # above=True keeps voxels above the lower bound and sets the rest to it.
    clamp_low = monai.transforms.ThresholdIntensityd(
        keys=("img",), threshold=LOWER_BOUND_WINDOW_HRCT, above=True, cval=LOWER_BOUND_WINDOW_HRCT
    )
    # above=False keeps voxels below the upper bound and sets the rest to it.
    clamp_high = monai.transforms.ThresholdIntensityd(
        keys=("img",), threshold=UPPER_BOUND_WINDOW_HRCT, above=False, cval=UPPER_BOUND_WINDOW_HRCT
    )
    rescale = monai.transforms.ScaleIntensityd(keys='img', minv=0.0, maxv=1.0)
    to_tensor = monai.transforms.ToTensord(keys=("img", "mask"))
    return monai.transforms.Compose([load, orient, clamp_low, clamp_high, rescale, to_tensor])


def get_val_cbct_transforms():
    """Build the validation pipeline for the non-HRCT ("CBCT") samples.

    Loads "img" and "mask" channel-first, reorients both to "ALI", rescales
    the image to [0, 1] and converts both entries to tensors. No cropping is
    applied.

    Returns:
        monai.transforms.Compose: dictionary transform for samples holding
        "img" and "mask" entries.
    """
    # The stray debug print that used to run on every call has been removed.
    return monai.transforms.Compose(
        [
            monai.transforms.LoadImaged(keys=('img', 'mask'), image_only=True, ensure_channel_first=True),
            monai.transforms.Orientationd(keys=('img', 'mask'), axcodes="ALI"),
            monai.transforms.ScaleIntensityd(keys='img', minv=0.0, maxv=1.0),
            monai.transforms.ToTensord(keys=("img", "mask")),
        ]
    )

def get_raw_transforms():
    """Build a pipeline that only loads "img" and "mask", channel-first.

    image_only=False is passed so the loader also returns metadata; the
    exploration cells read it through the "img_meta_dict" entry.

    Returns:
        monai.transforms.Compose: dictionary transform containing the single
        loading step.
    """
    loader = monai.transforms.LoadImaged(
        keys=('img', 'mask'),
        image_only=False,
        ensure_channel_first=True,
    )
    return monai.transforms.Compose([loader])

def load_radiopaedia_from_path(path: str) -> list[str]:
    """Return the sorted paths of the regular ``.gz`` files directly inside ``path``.

    Despite the name, no filter on the file name is applied: every file whose
    last suffix is ``.gz`` is returned. Sub-directories are not descended into.

    Args:
        path: directory to list.

    Returns:
        The matching file paths as strings, in ascending lexicographic order.
    """
    gz_files = []
    for entry in Path(path).iterdir():
        # Check the suffix first so is_file() is only called on ".gz" entries.
        if entry.suffix != '.gz' or not entry.is_file():
            continue
        gz_files.append(str(entry))
    gz_files.sort()
    return gz_files

class CovidDataset(torch.utils.data.Dataset):
    """Dataset that routes each volume to an HRCT or a CBCT transform.

    A volume is a mapping with at least an "img" entry. When the substring
    "coronacases" occurs in ``volume["img"]`` the HRCT transform is used,
    otherwise the CBCT transform.

    Args:
        volumes: indexable collection of sample mappings.
        hrct_transform: callable applied to "coronacases" samples, or None.
        cbct_transform: callable applied to all other samples, or None.
    """

    def __init__(self, volumes, hrct_transform=None, cbct_transform=None):
        self.volumes = volumes
        self.hrct_transform = hrct_transform
        self.cbct_transform = cbct_transform

    def __len__(self):
        """Return the number of volumes."""
        return len(self.volumes)

    def __getitem__(self, index):
        """Return the transformed sample at ``index``.

        If the selected transform is None (the constructor default), the
        sample is returned unchanged instead of failing with a TypeError.
        """
        volume = self.volumes[index]

        is_hrct = "coronacases" in volume["img"]
        transform = self.hrct_transform if is_hrct else self.cbct_transform
        if transform is None:
            return volume
        return transform(volume)
        

-1400
100


In [17]:
# Collect the Zenodo image and infection-mask file paths (same call order as before:
# images first, then masks).
logging.info(f"Loading images from {ZENODO_COVID_CASES_PATH}")
images, labels = (
    load_radiopaedia_from_path("../" + directory)
    for directory in (ZENODO_COVID_CASES_PATH, ZENODO_INFECTION_MASKS_PATH)
)

# Pair the two sorted listings positionally into {"img": ..., "mask": ...} records,
# kept in a NumPy array of dicts.
data_dicts = np.array(
    [{"img": image_path, "mask": mask_path} for image_path, mask_path in zip(images, labels)]
)
logging.debug(data_dicts)

# Shuffling the records with np.random.RandomState(SEED) is currently disabled.

In [18]:
# Load images and masks
mask_range = ("0255", "0304")
images_path = "../Datasets/COVID19_1110/studies/"
labels_path = "../Datasets/COVID19_1110/masks/"

# Files are named like study_0255.nii.gz. From every sub-directory of images_path,
# keep the studies whose number lies within mask_range. The comparison is done on
# strings, which is only valid because the numbers are zero-padded to one width.
images = []
for subdir in Path(images_path).iterdir():
    if not subdir.is_dir():
        continue
    for f in subdir.iterdir():
        name_parts = f.stem.split("_")
        # Skip non-files and names without an "_<number>" part instead of
        # failing with an IndexError on them.
        if not f.is_file() or len(name_parts) < 2:
            continue
        study_id = name_parts[1].split(".")[0]
        if mask_range[0] <= study_id <= mask_range[1]:
            images.append(str(f))

# iterdir() yields entries in arbitrary order; sort by file name so the positional
# pairing with the masks below is deterministic.
# NOTE(review): assumes load_images_from_path returns the masks sorted the same way -- confirm.
images.sort(key=lambda image_path: Path(image_path).name)

labels = load_images_from_path(labels_path)

# zip() silently truncates to the shorter list, so make a count mismatch visible.
if len(images) != len(labels):
    logging.warning(f"Found {len(images)} images but {len(labels)} masks; pairs beyond the shorter list are dropped")

# Convert images and masks to a list of dictionaries with keys "img" and "mask"
data_dicts1 = np.array([{"img": img, "mask": mask} for img, mask in zip(images, labels)])

# Raw (load-only) transforms on both branches: this dataset is only used for inspection.
dataset = CovidDataset(volumes=data_dicts1, hrct_transform=get_raw_transforms(), cbct_transform=get_raw_transforms())
dataloader = DataLoader(dataset, batch_size=1, num_workers=2)


In [19]:
# Wrap the Zenodo records (data_dicts) in a dataset that only loads the volumes:
# both transform branches use the raw, load-only pipeline.
radiopaedia_dataset = CovidDataset(
    volumes=data_dicts,
    hrct_transform=get_raw_transforms(),
    cbct_transform=get_raw_transforms(),
)
# One volume per batch, loaded by two worker processes.
radiopaedia_loader = DataLoader(
    radiopaedia_dataset,
    batch_size=1,
    num_workers=2,
)

In [20]:
# Report the raw intensity range (min, max) of every COVID19_1110 volume.
for i, data in enumerate(dataloader):
    volume = data["img"]
    print(volume.min(), volume.max())

metatensor(-2048.) metatensor(5351.)
metatensor(-2048.) metatensor(1754.)
metatensor(-2048.) metatensor(1831.)
metatensor(-2828.) metatensor(16608.)
metatensor(-2048.) metatensor(1815.)
metatensor(-2048.) metatensor(1916.)
metatensor(-2048.) metatensor(1704.)
metatensor(-4535.) metatensor(21549.)
metatensor(-2048.) metatensor(1734.)
metatensor(-2048.) metatensor(1623.)
metatensor(-2048.) metatensor(1732.)
metatensor(-2048.) metatensor(4740.)
metatensor(-2048.) metatensor(1722.)
metatensor(-2048.) metatensor(1703.)
metatensor(-2048.) metatensor(1766.)
metatensor(-2048.) metatensor(1788.)
metatensor(-2048.) metatensor(1692.)
metatensor(-2048.) metatensor(1901.)
metatensor(-2048.) metatensor(1863.)
metatensor(-2048.) metatensor(1619.)
metatensor(-2048.) metatensor(2545.)
metatensor(-2048.) metatensor(1798.)
metatensor(-2048.) metatensor(2056.)
metatensor(-2048.) metatensor(1782.)
metatensor(-2048.) metatensor(1685.)
metatensor(-2048.) metatensor(1770.)
metatensor(-2048.) metatensor(1592.)

In [41]:
def t():
    """Build a small test pipeline for already-loaded "img"/"mask" arrays.

    No loading step is included: the pipeline reorients both entries to "PLI",
    maps the HRCT window of the image linearly onto [0, 1] with clipping and
    converts both entries to tensors.

    Returns:
        monai.transforms.Compose: dictionary transform for samples holding
        "img" and "mask" entries.
    """
    reorient = monai.transforms.Orientationd(keys=('img', 'mask'), axcodes="PLI")
    window_to_unit = monai.transforms.ScaleIntensityRanged(
        keys=('img',),
        a_min=LOWER_BOUND_WINDOW_HRCT,
        a_max=UPPER_BOUND_WINDOW_HRCT,
        b_min=0.0,
        b_max=1.0,
        clip=True,
    )
    as_tensor = monai.transforms.ToTensord(keys=("img", "mask"))
    return monai.transforms.Compose([reorient, window_to_unit, as_tensor])
# Seeded generator so the synthetic sample, and therefore the statistics printed
# below, are reproducible across kernel restarts.
rng = np.random.default_rng(SEED)
sample_data = {
    'img': rng.integers(-2000, 2000, size=(2, 128, 128, 128)),  # Example volume, values in [-2000, 2000)
    'mask': rng.integers(0, 2, size=(2, 128, 128, 128))         # Example binary mask, values in {0, 1}
}

transforms_pipeline = t()
transformed_data = transforms_pipeline(sample_data)

# Print some statistics of the normalized volume to verify
print(f"Min value (img): {transformed_data['img'].min()}")
print(f"Max value (img): {transformed_data['img'].max()}")
print(f"Mean value (img): {transformed_data['img'].mean()}")

Min value (img): 0.0
Max value (img): 1.0
Mean value (img): 0.6622612476348877


In [22]:
# import nibabel and load an image from data_paths and explore the min and max value of each image
import nibabel as nib
import os
import numpy as np

# Per axis, track the smallest absolute diagonal entry of the NIfTI sform rows
# (srow_x[0], srow_y[1], srow_z[2]) over all Zenodo volumes in data_dicts.
# NOTE(review): this equals the voxel spacing only for axis-aligned affines -- confirm.
# The accumulator used to be called `min`, which shadowed the builtin for every
# cell run afterwards; it also mixed two different sentinel values.
min_abs_srow = [float("inf")] * 3
for sample in data_dicts:
    header = nib.load(sample["img"]).header
    diagonal = (header["srow_x"][0], header["srow_y"][1], header["srow_z"][2])
    for axis, entry in enumerate(diagonal):
        magnitude = abs(entry)
        if magnitude < min_abs_srow[axis]:
            min_abs_srow[axis] = magnitude
print(min_abs_srow)

0.671
0.595
0.708
0.647
0.782
0.801
0.782
0.705
0.668
0.751
0.782
0.782
0.782
0.782
0.781
0.686
0.647
0.698
0.724
0.629
0.705
0.671
0.598
0.65
0.782
0.694
0.698
0.582
0.812
0.736
0.782
0.782
0.677
0.708
0.801
0.668
0.637
0.702
0.781
0.705
0.753
0.662
0.677
1.152
0.631
0.923
0.729
0.74
0.613
0.846
[0.582, 0.582, 8.0]


In [13]:
# For each Zenodo volume print "<file name>_<pixdim[1:4]>" followed by the batch shape.
# Indexing volumes with the batch counter relies on batch_size=1 and an unshuffled loader.
for i, data in enumerate(radiopaedia_loader):
    file_name = radiopaedia_dataset.volumes[i]["img"].split("/")[-1]
    spacing = data["img_meta_dict"]["pixdim"][0][1:4]
    print(f'{file_name}_{spacing}')
    print(f'{data["img"].shape}')

coronacases_001.nii.gz_tensor([0.8105, 0.8105, 1.0000])
torch.Size([1, 1, 512, 512, 301])
coronacases_002.nii.gz_tensor([0.6836, 0.6836, 1.5000])
torch.Size([1, 1, 512, 512, 200])
coronacases_003.nii.gz_tensor([0.7363, 0.7363, 1.5000])
torch.Size([1, 1, 512, 512, 200])
coronacases_004.nii.gz_tensor([0.6836, 0.6836, 1.0000])
torch.Size([1, 1, 512, 512, 270])
coronacases_005.nii.gz_tensor([0.6836, 0.6836, 1.0000])
torch.Size([1, 1, 512, 512, 290])
coronacases_006.nii.gz_tensor([0.7598, 0.7598, 1.5000])
torch.Size([1, 1, 512, 512, 213])
coronacases_007.nii.gz_tensor([0.7129, 0.7129, 1.0000])
torch.Size([1, 1, 512, 512, 249])
coronacases_008.nii.gz_tensor([0.7246, 0.7246, 1.0000])
torch.Size([1, 1, 512, 512, 301])
coronacases_009.nii.gz_tensor([0.6836, 0.6836, 1.0000])
torch.Size([1, 1, 512, 512, 256])
coronacases_010.nii.gz_tensor([0.6836, 0.6836, 1.0000])
torch.Size([1, 1, 512, 512, 301])
radiopaedia_10_85902_1.nii.gz_tensor([0.6836, 0.6836, 6.0000])
torch.Size([1, 1, 630, 630, 39])


KeyboardInterrupt: 

In [45]:
# Print entries 1..3 of the first item's "pixdim" metadata for every COVID19_1110 batch.
for i, data in enumerate(dataloader):
    pixdim = data["img_meta_dict"]["pixdim"][0]
    print(pixdim[1:4])

tensor([0.6710, 0.6710, 8.0000])
tensor([0.5950, 0.5950, 8.0000])
tensor([0.7080, 0.7080, 8.0000])
tensor([0.6470, 0.6470, 8.0000])
tensor([0.7820, 0.7820, 8.0000])
tensor([0.8010, 0.8010, 8.0000])
tensor([0.7820, 0.7820, 8.0000])
tensor([0.7050, 0.7050, 8.0000])
tensor([0.6680, 0.6680, 8.0000])
tensor([0.7510, 0.7510, 8.0000])
tensor([0.7820, 0.7820, 8.0000])
tensor([0.7820, 0.7820, 8.0000])
tensor([0.7820, 0.7820, 8.0000])
tensor([0.7820, 0.7820, 8.0000])
tensor([0.7810, 0.7810, 8.0000])
tensor([0.6860, 0.6860, 8.0000])
tensor([0.6470, 0.6470, 8.0000])
tensor([0.6980, 0.6980, 8.0000])
tensor([0.7240, 0.7240, 8.0000])
tensor([0.6290, 0.6290, 8.0000])
tensor([0.7050, 0.7050, 8.0000])
tensor([0.6710, 0.6710, 8.0000])
tensor([0.5980, 0.5980, 8.0000])
tensor([0.6500, 0.6500, 8.0000])
tensor([0.7820, 0.7820, 8.0000])
tensor([0.6940, 0.6940, 8.0000])
tensor([0.6980, 0.6980, 8.0000])
tensor([0.5820, 0.5820, 8.0000])
tensor([0.8120, 0.8120, 8.0000])
tensor([0.7360, 0.7360, 8.0000])
tensor([0.

In [20]:
# Gather the pixdim[1:4] triple of every volume from both datasets (COVID19_1110
# first, then Zenodo) so the per-dimension means can be computed afterwards.
values = []
for loader in (dataloader, radiopaedia_loader):
    for i, data in enumerate(loader):
        spacing = data["img_meta_dict"]["pixdim"][0][1:4]
        values.append(list(spacing))


In [21]:
# Show every collected triple, one per line.
for spacing in values:
    print(spacing)
# Mean over axis 0, i.e. one mean value per dimension.
print(np.mean(values, axis=0))

    

[tensor(0.6710), tensor(0.6710), tensor(8.)]
[tensor(0.5950), tensor(0.5950), tensor(8.)]
[tensor(0.7080), tensor(0.7080), tensor(8.)]
[tensor(0.6470), tensor(0.6470), tensor(8.)]
[tensor(0.7820), tensor(0.7820), tensor(8.)]
[tensor(0.8010), tensor(0.8010), tensor(8.)]
[tensor(0.7820), tensor(0.7820), tensor(8.)]
[tensor(0.7050), tensor(0.7050), tensor(8.)]
[tensor(0.6680), tensor(0.6680), tensor(8.)]
[tensor(0.7510), tensor(0.7510), tensor(8.)]
[tensor(0.7820), tensor(0.7820), tensor(8.)]
[tensor(0.7820), tensor(0.7820), tensor(8.)]
[tensor(0.7820), tensor(0.7820), tensor(8.)]
[tensor(0.7820), tensor(0.7820), tensor(8.)]
[tensor(0.7810), tensor(0.7810), tensor(8.)]
[tensor(0.6860), tensor(0.6860), tensor(8.)]
[tensor(0.6470), tensor(0.6470), tensor(8.)]
[tensor(0.6980), tensor(0.6980), tensor(8.)]
[tensor(0.7240), tensor(0.7240), tensor(8.)]
[tensor(0.6290), tensor(0.6290), tensor(8.)]
[tensor(0.7050), tensor(0.7050), tensor(8.)]
[tensor(0.6710), tensor(0.6710), tensor(8.)]
[tensor(0.

In [25]:
# Print the tensor shape of every COVID19_1110 batch.
for i, data in enumerate(dataloader):
    batch_shape = data["img"].shape
    print(batch_shape)
# TODO: the note that followed this loop was left unfinished ("take the mean values for ...").

torch.Size([1, 1, 512, 512, 39])
torch.Size([1, 1, 512, 512, 33])
torch.Size([1, 1, 512, 512, 41])
torch.Size([1, 1, 512, 512, 41])
torch.Size([1, 1, 512, 512, 46])
torch.Size([1, 1, 512, 512, 39])
torch.Size([1, 1, 512, 512, 41])
torch.Size([1, 1, 512, 512, 42])
torch.Size([1, 1, 512, 512, 38])
torch.Size([1, 1, 512, 512, 43])
torch.Size([1, 1, 512, 512, 41])
torch.Size([1, 1, 512, 512, 37])
torch.Size([1, 1, 512, 512, 39])
torch.Size([1, 1, 512, 512, 45])
torch.Size([1, 1, 512, 512, 45])
torch.Size([1, 1, 512, 512, 38])
torch.Size([1, 1, 512, 512, 33])
torch.Size([1, 1, 512, 512, 48])
torch.Size([1, 1, 512, 512, 43])
torch.Size([1, 1, 512, 512, 41])
torch.Size([1, 1, 512, 512, 41])
torch.Size([1, 1, 512, 512, 41])
torch.Size([1, 1, 512, 512, 41])
torch.Size([1, 1, 512, 512, 41])
torch.Size([1, 1, 512, 512, 37])
torch.Size([1, 1, 512, 512, 42])
torch.Size([1, 1, 512, 512, 43])
torch.Size([1, 1, 512, 512, 43])
torch.Size([1, 1, 512, 512, 39])
torch.Size([1, 1, 512, 512, 40])
torch.Size

In [42]:
# Number of image/mask records in data_dicts; left as a bare expression so the notebook displays it.
len(data_dicts)

20