In [1]:
import logging
from preprocessing.covid_dataset import CovidDataset
from monai.data import DataLoader
from preprocessing.transforms import get_hrct_transforms, get_cbct_transforms, \
    get_val_hrct_transforms, get_val_cbct_transforms
from utils.helpers import load_images_from_path
from config.constants import (COVID_CASES_PATH, INFECTION_MASKS_PATH, SEED)
import torch
import numpy as np
import matplotlib.pyplot as plt
from monai.visualize import plot_2d_or_3d_image, matshow3d, blend_images
from pathlib import Path

2024-06-06 12:31:03.740689: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-06-06 12:31:04.083821: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:

# Load images and masks
logging.info(f"Loading images from {COVID_CASES_PATH}")
images = load_images_from_path("../" + COVID_CASES_PATH)
labels = load_images_from_path("../" + INFECTION_MASKS_PATH)

# Convert images and masks to a list of dictionaries with keys "img" and "mask"
data_dicts = np.array([{"img": img, "mask": mask} for img, mask in zip(images, labels)])
logging.debug(data_dicts)

print(data_dicts)
shuffler = np.random.RandomState(420)
shuffler.shuffle(data_dicts)
data_dicts = list(data_dicts)
print()
print(type(data_dicts))
print()
print(data_dicts)
print(SEED)

# Split the data into training (70%), validation (10%), and test sets (20%)
test_split = int(len(data_dicts) * 0.2)
val_split = int(len(data_dicts) * 0.1)

train_paths = data_dicts[test_split + val_split:]
val_paths = data_dicts[test_split:test_split + val_split]
test_paths = data_dicts[:test_split]

[{'img': '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_001.nii.gz', 'mask': '../Datasets/Zenodo/Infection_Mask/coronacases_001.nii.gz'}
 {'img': '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_002.nii.gz', 'mask': '../Datasets/Zenodo/Infection_Mask/coronacases_002.nii.gz'}
 {'img': '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_003.nii.gz', 'mask': '../Datasets/Zenodo/Infection_Mask/coronacases_003.nii.gz'}
 {'img': '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_004.nii.gz', 'mask': '../Datasets/Zenodo/Infection_Mask/coronacases_004.nii.gz'}
 {'img': '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_005.nii.gz', 'mask': '../Datasets/Zenodo/Infection_Mask/coronacases_005.nii.gz'}
 {'img': '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_006.nii.gz', 'mask': '../Datasets/Zenodo/Infection_Mask/coronacases_006.nii.gz'}
 {'img': '../Datasets/Zenodo/COVID-19-CT-Seg_20cases/coronacases_007.nii.gz', 'mask': '../Datasets/Zenodo/Infection_Mask/cor

In [3]:
# Create a dataset and a dataloader for the training set
train_dataset = CovidDataset(volumes=train_paths, hrct_transform=get_hrct_transforms(), cbct_transform=get_cbct_transforms())
train_loader = DataLoader(train_dataset, batch_size=1, num_workers=2)

# Create a dataset and a dataloader for the validation set
val_dataset = CovidDataset(volumes=val_paths, hrct_transform=get_val_hrct_transforms(), cbct_transform=get_val_cbct_transforms())
val_loader = DataLoader(val_dataset, batch_size=1, num_workers=2)

# Create a dataset and a dataloader for the test set
test_dataset = CovidDataset(volumes=test_paths, hrct_transform=get_val_hrct_transforms(), cbct_transform=get_val_cbct_transforms())
test_loader = DataLoader(test_dataset, batch_size=1, num_workers=2)

In [4]:
# create a funcion that takes a dataset an proves if it has only zeros in the mask. In that case throw an error. Also check for radiopaedia and coronacases in the path of the dataset. If there is one that it isn't throw an error
def check_dataset(dataset):
    radiopaedia = False
    coronacases = False
    for i, data in enumerate(dataset):
        if np.unique(data["mask"]).size == 1 and np.unique(data["mask"])[0] == 0:
            raise ValueError("The mask has only zeros")

        if "radiopaedia" in dataset.volumes[i]["img"]:
            radiopaedia = True
        if "coronacases" in dataset.volumes[i]["img"]:
            coronacases = True

    if not radiopaedia or not coronacases:
        raise ValueError("There are no radiopaedia or coronacases in the dataset")

    return

In [5]:
# test validation test
check_dataset(val_dataset)

In [9]:
print(val_dataset[0]["img"].shape)
print(val_dataset[1]["img"].shape)

torch.Size([1, 630, 630, 418])
torch.Size([1, 512, 512, 213])


# SHOW DATASET

In [None]:
def draw_volumes(dataloader, dataset):
     for i, data in enumerate(dataloader):
        volume = dataset.volumes[i]['img'].split('/')[-1].split('.')[0]
        Path(f"images/{volume}").mkdir(parents=True, exist_ok=True)
        print(f"Processing volume {i}: {volume}")
        print(f"img shape: {data['img'].shape}")
        for j in range(data["img"].shape[0]):
            print(f"Processing subvolume {j+1}")
            
            img = data["img"][j]
            seg = data["mask"][j]
            print(f"img shape: {img.shape}")
            blended_img = blend_images(img, seg, alpha=0.5, cmap="hsv", rescale_arrays=True)
            blended_img = (blended_img - blended_img.min()) / (blended_img.max() - blended_img.min())
            print(f"blended_img shape: {blended_img.shape}")
            fig = plt.figure(figsize=(30, 30))
            plt.title(f"subvolume{j}")
            print(blended_img.shape[-1])
            for k in range(blended_img.shape[-1]):
                print(f"Processing slice {k+1}")
                ax = fig.add_subplot(16, 16, k+1)
                ax.imshow(torch.moveaxis(blended_img[:, :, :, k], 0, -1))
                ax.axis('off')
            plt.savefig(f"images/{volume}/{j+1}.png")
            plt.close()

In [11]:
# Show validation volumes
draw_volumes(val_loader, val_dataset)

Processing volume 0: coronacases_003
img shape: torch.Size([1, 1, 512, 512, 200])
Processing subvolume 1
img shape: torch.Size([1, 512, 512, 200])
blended_img shape: torch.Size([3, 512, 512, 200])
200
Processing slice 1
Processing slice 2
Processing slice 3
Processing slice 4
Processing slice 5
Processing slice 6
Processing slice 7
Processing slice 8
Processing slice 9
Processing slice 10
Processing slice 11
Processing slice 12
Processing slice 13
Processing slice 14
Processing slice 15
Processing slice 16
Processing slice 17
Processing slice 18
Processing slice 19
Processing slice 20
Processing slice 21
Processing slice 22
Processing slice 23
Processing slice 24
Processing slice 25
Processing slice 26
Processing slice 27
Processing slice 28
Processing slice 29
Processing slice 30
Processing slice 31
Processing slice 32
Processing slice 33
Processing slice 34
Processing slice 35
Processing slice 36
Processing slice 37
Processing slice 38
Processing slice 39
Processing slice 40
Processi