In [1]:
import logging
from preprocessing.covid_dataset import CovidDataset
from monai.data import DataLoader
import monai
from config.constants import (COVID_CASES_PATH, INFECTION_MASKS_PATH, SEED)
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from preprocessing.transforms import get_cbct_transforms, get_hrct_transforms, get_val_cbct_transforms, get_val_hrct_transforms

2024-07-04 19:03:48.881428: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
def load_images_from_path(path: str) -> list[str]:
    logging.debug(f"Loading images from {path}")
    return sorted([str(f) for f in Path(path).iterdir() if f.is_file() and f.suffix == '.gz'])

images = load_images_from_path("../" + COVID_CASES_PATH)
print(images)

['../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_001.nii.gz', '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_002.nii.gz', '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_003.nii.gz', '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_004.nii.gz', '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_005.nii.gz', '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_006.nii.gz', '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_007.nii.gz', '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_008.nii.gz', '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_009.nii.gz', '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_010.nii.gz', '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/radiopaedia_10_85902_1.nii.gz', '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/radiopaedia_10_85902_3.nii.gz', '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/radiopaedia_14_85914_0.nii.gz', '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/radiopaedia_27_86410_0.nii.gz', '../D

In [29]:
# Load images and masks
logging.info(f"Loading images from {COVID_CASES_PATH}")
images = load_images_from_path("../" + COVID_CASES_PATH)
labels = load_images_from_path("../" + INFECTION_MASKS_PATH)

# Convert images and masks to a list of dictionaries with keys "img" and "mask"
data_dicts = np.array([{"img": img, "mask": mask} for img, mask in zip(images, labels)])
logging.debug(data_dicts)

print(data_dicts)
shuffler = np.random.RandomState(33)
shuffler.shuffle(data_dicts)
data_dicts = list(data_dicts)
print(type(data_dicts))
print(data_dicts)
print(SEED)

# Split the data into training (70%), validation (20%), and test sets (10%)
test_split = int(len(data_dicts) * 0.1)
val_split = int(len(data_dicts) * 0.2)

train_paths = data_dicts[test_split + val_split:]
val_paths = data_dicts[test_split:test_split + val_split]
test_paths = data_dicts[:test_split]

[{'img': '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_001.nii.gz', 'mask': '../Datasets/Zenodo/Infection_Mask/coronacases_001.nii.gz'}
 {'img': '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_002.nii.gz', 'mask': '../Datasets/Zenodo/Infection_Mask/coronacases_002.nii.gz'}
 {'img': '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_003.nii.gz', 'mask': '../Datasets/Zenodo/Infection_Mask/coronacases_003.nii.gz'}
 {'img': '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_004.nii.gz', 'mask': '../Datasets/Zenodo/Infection_Mask/coronacases_004.nii.gz'}
 {'img': '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_005.nii.gz', 'mask': '../Datasets/Zenodo/Infection_Mask/coronacases_005.nii.gz'}
 {'img': '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_006.nii.gz', 'mask': '../Datasets/Zenodo/Infection_Mask/coronacases_006.nii.gz'}
 {'img': '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_007.nii.gz', 'mask': '../Datasets/Zenodo/Infection_Mask/cor

In [30]:
# Create a dataset and a dataloader for the training set
train_dataset = CovidDataset(volumes=train_paths, hrct_transform=get_hrct_transforms(), cbct_transform=get_cbct_transforms())
train_loader = DataLoader(train_dataset, batch_size=1, num_workers=2)

# Create a dataset and a dataloader for the validation set
val_dataset = CovidDataset(volumes=val_paths, hrct_transform=get_val_hrct_transforms(), cbct_transform=get_val_cbct_transforms())
val_loader = DataLoader(val_dataset, batch_size=1, num_workers=2)

# Create a dataset and a dataloader for the test set
test_dataset = CovidDataset(volumes=test_paths, hrct_transform=get_val_hrct_transforms(), cbct_transform=get_val_cbct_transforms())
test_loader = DataLoader(test_dataset, batch_size=1, num_workers=2)

In [37]:
# create a funcion that takes a dataset an proves if it has only zeros in the mask. In that case throw an error. Also check for radiopaedia and coronacases in the path of the dataset. If there is one that it isn't throw an error
def check_dataset(dataset):
    radiopaedia = False
    coronacases = False
    for i, data in enumerate(dataset):
        if np.unique(data["mask"]).size == 1 and np.unique(data["mask"])[0] == 0:
            raise ValueError("The mask has only zeros")

        if "radiopaedia" in dataset.volumes[i]["img"]:
            radiopaedia = True
        if "coronacases" in dataset.volumes[i]["img"]:
            coronacases = True

    if not radiopaedia or not coronacases:
        raise ValueError("There are no radiopaedia or coronacases in the dataset")

    return

In [38]:
# test validation test
check_dataset(val_dataset)

In [18]:
print(val_dataset[0]["img"].shape)
print(val_dataset[1]["img"].shape)
print(val_dataset[2]["img"].shape)
print(val_dataset[3]["img"].shape)

torch.Size([1, 401, 630, 110])
torch.Size([1, 512, 512, 200])
torch.Size([1, 630, 630, 418])
torch.Size([1, 512, 512, 213])


In [43]:
print(val_dataset)
for i, data in enumerate(val_dataset):
    print(data["mask"].shape)
    print(np.unique(data["mask"]).size == 1)
    print(np.unique(data["mask"])[0] == 0)


<preprocessing.covid_dataset.CovidDataset object at 0x72e12b361410>
torch.Size([1, 630, 630, 42])
False
True
torch.Size([1, 630, 630, 39])
False
True
torch.Size([1, 512, 512, 213])
False
True
torch.Size([1, 512, 512, 301])
False
True
