In [1]:
import logging
from preprocessing.covid_dataset import CovidDataset
from monai.data import DataLoader
import monai
from config.constants import SEED, ZENODO_COVID_CASES_PATH, ZENODO_LUNG_MASKS_PATH, ZENODO_INFECTION_MASKS_PATH, COVID_PREPROCESSED_CASES_PATH, INFECTION_PREPROCESSED_MASKS_PATH
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from preprocessing.transforms import get_cbct_transforms, get_hrct_transforms, get_val_cbct_transforms, get_val_hrct_transforms

2024-07-16 12:20:24.426217: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
def load_images_from_path(path: str) -> list[str]:
    logging.debug(f"Loading images from {path}")
    return sorted([str(f) for f in Path(path).iterdir() if f.is_file() and f.suffix == '.gz'])


In [12]:
# Load images and masks
logging.info(f"Loading images from {ZENODO_COVID_CASES_PATH}")
images = load_images_from_path("../" + ZENODO_COVID_CASES_PATH)
labels = load_images_from_path("../" + ZENODO_INFECTION_MASKS_PATH)

# Convert images and masks to a list of dictionaries with keys "img" and "mask"
data_dicts = np.array([{"img": img, "mask": mask} for img, mask in zip(images, labels)])
logging.debug(data_dicts)

print(data_dicts)
shuffler = np.random.RandomState(33)
shuffler.shuffle(data_dicts)
data_dicts = list(data_dicts)
print(type(data_dicts))
print(data_dicts)
print(SEED)

# Split the data into training (70%), validation (20%), and test sets (10%)
test_split = int(len(data_dicts) * 0.1)
val_split = int(len(data_dicts) * 0.2)

train_paths = data_dicts[test_split + val_split:]
val_paths = data_dicts[test_split:test_split + val_split]
test_paths = data_dicts[:test_split]

[{'img': '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_001.nii.gz', 'mask': '../Datasets/Zenodo/Infection_Mask/coronacases_001.nii.gz'}
 {'img': '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_002.nii.gz', 'mask': '../Datasets/Zenodo/Infection_Mask/coronacases_002.nii.gz'}
 {'img': '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_003.nii.gz', 'mask': '../Datasets/Zenodo/Infection_Mask/coronacases_003.nii.gz'}
 {'img': '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_004.nii.gz', 'mask': '../Datasets/Zenodo/Infection_Mask/coronacases_004.nii.gz'}
 {'img': '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_005.nii.gz', 'mask': '../Datasets/Zenodo/Infection_Mask/coronacases_005.nii.gz'}
 {'img': '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_006.nii.gz', 'mask': '../Datasets/Zenodo/Infection_Mask/coronacases_006.nii.gz'}
 {'img': '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_007.nii.gz', 'mask': '../Datasets/Zenodo/Infection_Mask/cor

In [3]:
# Load images and masks
images = load_images_from_path("../" + COVID_PREPROCESSED_CASES_PATH)
labels = load_images_from_path("../" + INFECTION_PREPROCESSED_MASKS_PATH)

# Take only the images that are from Mosmed
train_images = [image for image in images if "radiopaedia" not in image and "coronacases" not in image]
train_labels = [label for label in labels if "radiopaedia" not in label and "coronacases" not in label]

# Convert images and masks to a list of dictionaries with keys "img" and "mask"
data_train_dicts = np.array([{"img": img, "mask": mask} for img, mask in zip(train_images, train_labels)])
logging.debug(data_train_dicts)

# Shuffle the data
shuffler = np.random.RandomState(SEED)
shuffler.shuffle(data_train_dicts)
data_train_dicts = list(data_train_dicts)

# Split the training data into training and validation
val_split = int(len(data_train_dicts) * 0.2)

train_paths = data_train_dicts[val_split:]
val_paths = data_train_dicts[:val_split]

# Take coronacases and radiopeadia images for testing
test_paths = [image for image in images if "radiopaedia" in image or "coronacases" in image]

for i in range(len(test_paths)):
    print(test_paths[i])


../Datasets/preprocessed/images/coronacases_001.nii.gz
../Datasets/preprocessed/images/coronacases_002.nii.gz
../Datasets/preprocessed/images/coronacases_003.nii.gz
../Datasets/preprocessed/images/coronacases_004.nii.gz
../Datasets/preprocessed/images/coronacases_005.nii.gz
../Datasets/preprocessed/images/coronacases_006.nii.gz
../Datasets/preprocessed/images/coronacases_007.nii.gz
../Datasets/preprocessed/images/coronacases_008.nii.gz
../Datasets/preprocessed/images/coronacases_009.nii.gz
../Datasets/preprocessed/images/coronacases_010.nii.gz
../Datasets/preprocessed/images/radiopaedia_10_85902_1.nii.gz
../Datasets/preprocessed/images/radiopaedia_10_85902_3.nii.gz
../Datasets/preprocessed/images/radiopaedia_14_85914_0.nii.gz
../Datasets/preprocessed/images/radiopaedia_27_86410_0.nii.gz
../Datasets/preprocessed/images/radiopaedia_29_86490_1.nii.gz
../Datasets/preprocessed/images/radiopaedia_29_86491_1.nii.gz
../Datasets/preprocessed/images/radiopaedia_36_86526_0.nii.gz
../Datasets/prep

In [4]:
# Create a dataset and a dataloader for the training set
train_dataset = CovidDataset(volumes=train_paths, hrct_transform=get_hrct_transforms(), cbct_transform=get_cbct_transforms())
train_loader = DataLoader(train_dataset, batch_size=1, num_workers=2)

# Create a dataset and a dataloader for the validation set
val_dataset = CovidDataset(volumes=val_paths, hrct_transform=get_val_hrct_transforms(), cbct_transform=get_val_cbct_transforms())
val_loader = DataLoader(val_dataset, batch_size=1, num_workers=2)

# Create a dataset and a dataloader for the test set
test_dataset = CovidDataset(volumes=test_paths, hrct_transform=get_val_hrct_transforms(), cbct_transform=get_val_cbct_transforms())
test_loader = DataLoader(test_dataset, batch_size=1, num_workers=2)

In [6]:
# create a funcion that takes a dataset an proves if it has only zeros in the mask. In that case throw an error. Also check for radiopaedia and coronacases in the path of the dataset. If there is one that it isn't throw an error
def check_dataset(dataset):
    radiopaedia = False
    coronacases = False
    for i, data in enumerate(dataset):
        if np.unique(data["mask"]).size == 1 and np.unique(data["mask"])[0] == 0:
            raise ValueError("The mask has only zeros")

        if "radiopaedia" in dataset.volumes[i]["img"]:
            radiopaedia = True
        if "coronacases" in dataset.volumes[i]["img"]:
            coronacases = True

    if not radiopaedia or not coronacases:
        raise ValueError("There are no radiopaedia or coronacases in the dataset")

    return

def check_dataset2(dataset):
    for i, data in enumerate(dataset):
        if np.unique(data["mask"]).size == 1 and np.unique(data["mask"])[0] == 0:
            raise ValueError("The mask has only zeros")

    return

In [7]:
# test validation test
check_dataset2(val_dataset)

In [None]:
print(val_dataset[0]["img"].shape)
print(val_dataset[1]["img"].shape)
print(val_dataset[2]["img"].shape)
print(val_dataset[3]["img"].shape)

In [None]:
print(val_dataset)
for i, data in enumerate(val_dataset):
    print(data["mask"].shape)
    print(np.unique(data["mask"]).size == 1)
    print(np.unique(data["mask"])[0] == 0)


In [7]:
for i, data in enumerate(val_dataset):
    # check unique values
    print(np.unique(data["mask"]))
    break

HRCT
[0. 1.]


{'epoch': 1, 'train_loss': 0.1, 'val_loss': 0.3}