diff --git a/scripts/datasets/medical/check_plethora.py b/scripts/datasets/medical/check_plethora.py
new file mode 100644
index 00000000..53e4c154
--- /dev/null
+++ b/scripts/datasets/medical/check_plethora.py
@@ -0,0 +1,24 @@
+from torch_em.data import MinInstanceSampler
+from torch_em.util.debug import check_loader
+from torch_em.data.datasets.medical import get_plethora_loader
+
+
+ROOT = "/media/anwai/ANWAI/data/plethora"
+
+
+def check_plethora():
+    loader = get_plethora_loader(
+        path=ROOT,
+        task="thoracic",
+        patch_shape=(1, 512, 512),
+        batch_size=2,
+        resize_inputs=True,
+        download=True,
+        sampler=MinInstanceSampler(),
+    )
+
+    check_loader(loader, 8)
+
+
+if __name__ == "__main__":
+    check_plethora()
diff --git a/scripts/datasets/medical/check_tcia.py b/scripts/datasets/medical/check_tcia.py
new file mode 100644
index 00000000..466773a1
--- /dev/null
+++ b/scripts/datasets/medical/check_tcia.py
@@ -0,0 +1,77 @@
+import os
+import requests
+from glob import glob
+from natsort import natsorted
+
+import numpy as np
+import pandas as pd
+import nibabel as nib
+import pydicom as dicom
+
+from tcia_utils import nbia
+
+
+ROOT = "/media/anwai/ANWAI/data/tmp/"
+
+TCIA_URL = "https://wiki.cancerimagingarchive.net/download/attachments/68551327/NSCLC-Radiomics-OriginalCTs.tcia"
+
+
+def check_tcia(download):
+    trg_path = os.path.join(ROOT, os.path.split(TCIA_URL)[-1])
+    if download:
+        # output = nbia.getSeries(collection="LIDC-IDRI")
+        # nbia.downloadSeries(output, number=3, path=ROOT)
+
+        manifest = requests.get(TCIA_URL)
+        with open(trg_path, 'wb') as f:
+            f.write(manifest.content)
+
+        nbia.downloadSeries(
+            series_data=trg_path, input_type="manifest", number=3, path=ROOT, csv_filename="save"
+        )
+
+    df = pd.read_csv("save.csv")
+
+    all_patient_dirs = glob(os.path.join(ROOT, "*"))
+    for patient_dir in all_patient_dirs:
+        patient_id = os.path.split(patient_dir)[-1]
+        if not patient_id.startswith("1.3"):
+            continue
+
+        subject_id = pd.Series.to_string(df.loc[df["Series UID"] == patient_id]["Subject ID"])[-9:]
+        seg_path = glob(os.path.join(ROOT, "Thoracic_Cavities", subject_id, "*_primary_reviewer.nii.gz"))[0]
+        gt = nib.load(seg_path)
+        gt = gt.get_fdata()
+        gt = gt.transpose(2, 1, 0)
+        gt = np.flip(gt, axis=(0, 1))
+
+        all_dicom_files = natsorted(glob(os.path.join(patient_dir, "*.dcm")))
+        samples = []
+        for dcm_fpath in all_dicom_files:
+            file = dicom.dcmread(dcm_fpath)
+            img = file.pixel_array
+            samples.append(img)
+
+        samples = np.stack(samples)
+
+        import napari
+
+        v = napari.Viewer()
+        v.add_image(samples)
+        v.add_labels(gt.astype("uint64"))
+        napari.run()
+
+
+def _test_me():
+    data = nbia.getSeries(collection="Soft-tissue-Sarcoma")
+    print(data)
+
+    nbia.downloadSeries(data, number=3)
+
+    seriesUid = "1.3.6.1.4.1.14519.5.2.1.5168.1900.104193299251798317056218297018"
+    nbia.viewSeries(seriesUid)
+
+
+if __name__ == "__main__":
+    # _test_me()
+    check_tcia(download=True)
diff --git a/torch_em/data/datasets/medical/__init__.py b/torch_em/data/datasets/medical/__init__.py
index 981f8f50..88c0ffa9 100644
--- a/torch_em/data/datasets/medical/__init__.py
+++ b/torch_em/data/datasets/medical/__init__.py
@@ -4,4 +4,5 @@
 from .camus import get_camus_dataset, get_camus_loader
 from .drive import get_drive_dataset, get_drive_loader
 from .papila import get_papila_dataset, get_papila_loader
+from .plethora import get_plethora_dataset, get_plethora_loader
 from .siim_acr import get_siim_acr_dataset, get_siim_acr_loader
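For reference, the manifest-based TCIA download that the check script above and the new download_source_tcia helper (further below) rely on can also be exercised on its own. A minimal sketch using only requests and tcia_utils; the output folder, the number of series, and the csv name are placeholders, not values required by the new code:

    import os
    import requests
    from tcia_utils import nbia

    MANIFEST_URL = "https://wiki.cancerimagingarchive.net/download/attachments/68551327/NSCLC-Radiomics-OriginalCTs.tcia"
    OUT_DIR = "./tcia_test"  # placeholder output folder

    os.makedirs(OUT_DIR, exist_ok=True)
    manifest_path = os.path.join(OUT_DIR, os.path.split(MANIFEST_URL)[-1])

    # fetch the manifest file from the collection page and store it locally
    with open(manifest_path, "wb") as f:
        f.write(requests.get(MANIFEST_URL).content)

    # download a few of the series listed in the manifest and keep the series-to-subject mapping as a csv
    nbia.downloadSeries(series_data=manifest_path, input_type="manifest", number=3, path=OUT_DIR, csv_filename="series")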
diff --git a/torch_em/data/datasets/medical/plethora.py b/torch_em/data/datasets/medical/plethora.py
new file mode 100644
index 00000000..34308bb5
--- /dev/null
+++ b/torch_em/data/datasets/medical/plethora.py
@@ -0,0 +1,182 @@
+import os
+from glob import glob
+from tqdm import tqdm
+from pathlib import Path
+from natsort import natsorted
+from typing import Union, Tuple
+from urllib.parse import urljoin
+
+import numpy as np
+import pandas as pd
+import nibabel as nib
+import pydicom as dicom
+
+import torch_em
+
+from .. import util
+
+
+BASE_URL = "https://wiki.cancerimagingarchive.net/download/attachments/68551327/"
+
+
+URL = {
+    "image": urljoin(BASE_URL, "NSCLC-Radiomics-OriginalCTs.tcia"),
+    "gt": {
+        "thoracic": urljoin(
+            BASE_URL, "PleThora%20Thoracic_Cavities%20June%202020.zip?version=1&modificationDate=1593202695428&api=v2"
+        ),
+        "pleural_effusion": urljoin(
+            BASE_URL, "PleThora%20Effusions%20June%202020.zip?version=1&modificationDate=1593202778373&api=v2"
+        )
+    }
+}
+
+
+CHECKSUMS = {
+    "image": None,
+    "gt": {
+        "thoracic": "6dfcb60e46c7b0ccf240bc5d13acb1c45c8d2f4922223f7b2fbd5e37acff2be0",
+        "pleural_effusion": "5dd07c327fb5723c5bbb48f2a02d7f365513d3ad136811fbe4def330ef2d7f6a"
+    }
+}
+
+
+ZIPFILES = {
+    "thoracic": "thoracic.zip",
+    "pleural_effusion": "pleural_effusion.zip"
+}
+
+
+def get_plethora_data(path, task, download):
+    os.makedirs(path, exist_ok=True)
+
+    image_dir = os.path.join(path, "data", "images")
+    gt_dir = os.path.join(path, "data", "gt", "Thoracic_Cavities" if task == "thoracic" else "Effusions")
+    csv_path = os.path.join(path, "plethora_images")
+    if os.path.exists(image_dir) and os.path.exists(gt_dir):
+        return image_dir, gt_dir, Path(csv_path).with_suffix(".csv")
+
+    # let's download the DICOM files from the TCIA manifest
+    tcia_path = os.path.join(path, "NSCLC-Radiomics-OriginalCTs.tcia")
+    util.download_source_tcia(path=tcia_path, url=URL["image"], dst=image_dir, csv_filename=csv_path, download=download)
+
+    # let's download the segmentations from the zip files
+    zip_path = os.path.join(path, ZIPFILES[task])
+    util.download_source(
+        path=zip_path, url=URL["gt"][task], download=download, checksum=CHECKSUMS["gt"][task]
+    )
+    util.unzip(zip_path=zip_path, dst=os.path.join(path, "data", "gt"))
+
+    return image_dir, gt_dir, Path(csv_path).with_suffix(".csv")
+
+
+def _assort_plethora_inputs(image_dir, gt_dir, task, csv_path):
+    df = pd.read_csv(csv_path)
+
+    task_gt_dir = gt_dir
+
+    os.makedirs(os.path.join(image_dir, "preprocessed"), exist_ok=True)
+    os.makedirs(os.path.join(task_gt_dir, "preprocessed"), exist_ok=True)
+
+    # get the series uid of each downloaded volume and look up the subject id allocated to it
+    all_series_uid_dirs = glob(os.path.join(image_dir, "1.3*"))
+    image_paths, gt_paths = [], []
+    for series_uid_dir in tqdm(all_series_uid_dirs):
+        series_uid = os.path.split(series_uid_dir)[-1]
+        subject_id = pd.Series.to_string(df.loc[df["Series UID"] == series_uid]["Subject ID"])[-9:]
+
+        try:
+            gt_path = glob(os.path.join(task_gt_dir, subject_id, "*.nii.gz"))[0]
+        except IndexError:
+            # - some patients do not have "Thoracic_Cavities" segmentation
+            print(f"The ground truth is missing for subject '{subject_id}'")
+            continue
+
+        assert os.path.exists(gt_path)
+
+        vol_path = os.path.join(image_dir, "preprocessed", f"{subject_id}.nii.gz")
+        neu_gt_path = os.path.join(task_gt_dir, "preprocessed", os.path.split(gt_path)[-1])
+
+        image_paths.append(vol_path)
+        gt_paths.append(neu_gt_path)
+        if os.path.exists(vol_path) and os.path.exists(neu_gt_path):
+            continue
+
+        # the individual slices for the inputs need to be merged into one volume.
+        if not os.path.exists(vol_path):
+            all_dcm_slices = natsorted(glob(os.path.join(series_uid_dir, "*.dcm")))
+            all_slices = []
+            for dcm_path in all_dcm_slices:
+                dcmfile = dicom.dcmread(dcm_path)
+                img = dcmfile.pixel_array
+                all_slices.append(img)
+
+            volume = np.stack(all_slices)
+            volume = volume.transpose(1, 2, 0)
+            nii_vol = nib.Nifti1Image(volume, np.eye(4))
+            nii_vol.header.get_xyzt_units()
+            nii_vol.to_filename(vol_path)
+
+        # the ground truth needs to be aligned with the inputs, let's take care of that.
+        gt = nib.load(gt_path)
+        gt = gt.get_fdata()
+        gt = gt.transpose(2, 1, 0)  # aligning w.r.t. the inputs
+        gt = np.flip(gt, axis=(0, 1))
+
+        gt = gt.transpose(1, 2, 0)
+        gt_nii_vol = nib.Nifti1Image(gt, np.eye(4))
+        gt_nii_vol.header.get_xyzt_units()
+        gt_nii_vol.to_filename(neu_gt_path)
+
+    return image_paths, gt_paths
+
+
+def _get_plethora_paths(path, task, download):
+    image_dir, gt_dir, csv_path = get_plethora_data(path=path, task=task, download=download)
+    image_paths, gt_paths = _assort_plethora_inputs(image_dir=image_dir, gt_dir=gt_dir, task=task, csv_path=csv_path)
+    return image_paths, gt_paths
+
+
+def get_plethora_dataset(
+    path: Union[os.PathLike, str],
+    task: str,
+    patch_shape: Tuple[int, ...],
+    resize_inputs: bool = False,
+    download: bool = False,
+    **kwargs
+):
+    image_paths, gt_paths = _get_plethora_paths(path=path, task=task, download=download)
+
+    if resize_inputs:
+        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
+        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
+            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
+        )
+
+    dataset = torch_em.default_segmentation_dataset(
+        raw_paths=image_paths,
+        raw_key="data",
+        label_paths=gt_paths,
+        label_key="data",
+        patch_shape=patch_shape,
+        **kwargs
+    )
+
+    return dataset
+
+
+def get_plethora_loader(
+    path: Union[os.PathLike, str],
+    task: str,
+    patch_shape: Tuple[int, ...],
+    batch_size: int,
+    resize_inputs: bool = False,
+    download: bool = False,
+    **kwargs
+):
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    dataset = get_plethora_dataset(
+        path=path, task=task, patch_shape=patch_shape, resize_inputs=resize_inputs, download=download, **ds_kwargs
+    )
+    loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
+    return loader
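For orientation, a minimal usage sketch of the two new entry points defined above; the data root is a placeholder and the parameters mirror the check script at the top of this patch:

    from torch_em.data.datasets.medical import get_plethora_dataset, get_plethora_loader

    DATA_ROOT = "/path/to/plethora"  # placeholder

    # dataset only, e.g. to inspect samples or to plug in a custom sampler / loader
    dataset = get_plethora_dataset(
        path=DATA_ROOT, task="thoracic", patch_shape=(1, 512, 512), resize_inputs=True, download=True
    )

    # or directly as a data loader for training
    loader = get_plethora_loader(
        path=DATA_ROOT, task="thoracic", patch_shape=(1, 512, 512), batch_size=2, resize_inputs=True, download=True
    )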
diff --git a/torch_em/data/datasets/util.py b/torch_em/data/datasets/util.py
index 3fd416e4..e0f8887a 100644
--- a/torch_em/data/datasets/util.py
+++ b/torch_em/data/datasets/util.py
@@ -1,26 +1,34 @@
-import inspect
 import os
 import hashlib
-import zipfile
-import numpy as np
+import inspect
+import requests
 from tqdm import tqdm
 from warnings import warn
-from xml.dom import minidom
-from shutil import copyfileobj, which
 from subprocess import run
 from packaging import version
+from shutil import copyfileobj, which
 
+import zipfile
+import numpy as np
+from xml.dom import minidom
 from skimage.draw import polygon
 
 import torch
+
 import torch_em
-import requests
+from torch_em.transform import get_raw_transform
+from torch_em.transform.generic import ResizeInputs, Compose
 
 try:
     import gdown
 except ImportError:
     gdown = None
 
+try:
+    from tcia_utils import nbia
+except ModuleNotFoundError:
+    nbia = None
+
 
 BIOIMAGEIO_IDS = {
     "covid_if": "ilastik/covid_if_training_data",
@@ -163,6 +171,23 @@ def download_source_kaggle(path, dataset_name, download):
     api.dataset_download_files(dataset=dataset_name, path=path, quiet=False)
 
 
+def download_source_tcia(path, url, dst, csv_filename, download):
+    if not download:
+        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")
+
+    assert url.endswith(".tcia"), f"{url} is not a TCIA manifest."
+
+    # downloads the manifest file from the collection page
+    manifest = requests.get(url=url)
+    with open(path, "wb") as f:
+        f.write(manifest.content)
+
+    # this part extracts the UIDs from the manifest and downloads them.
+    nbia.downloadSeries(
+        series_data=path, input_type="manifest", path=dst, csv_filename=csv_filename,
+    )
+
+
 def update_kwargs(kwargs, key, value, msg=None):
     if key in kwargs:
         msg = f"{key} will be over-ridden in loader kwargs." if msg is None else msg
@@ -225,6 +250,33 @@ def add_instance_label_transform(
     return kwargs, label_dtype
 
 
+def update_kwargs_for_resize_trafo(kwargs, patch_shape, resize_inputs, resize_kwargs=None):
+    """
+    Checks whether "raw_transform" and / or "label_transform" were passed in the kwargs.
+    If so, the resize transform is composed with them so that both are applied together.
+    """
+    if resize_inputs:
+        assert isinstance(resize_kwargs, dict)
+        patch_shape = None
+
+        raw_trafo = ResizeInputs(target_shape=resize_kwargs["patch_shape"], is_rgb=resize_kwargs["is_rgb"])
+        label_trafo = ResizeInputs(target_shape=resize_kwargs["patch_shape"], is_label=True)
+
+        if "raw_transform" in kwargs:
+            trafo = Compose(raw_trafo, kwargs["raw_transform"])
+            kwargs["raw_transform"] = trafo
+        else:
+            kwargs["raw_transform"] = Compose(raw_trafo, get_raw_transform())
+
+        if "label_transform" in kwargs:
+            trafo = Compose(label_trafo, kwargs["label_transform"])
+            kwargs["label_transform"] = trafo
+        else:
+            kwargs["label_transform"] = label_trafo
+
+    return kwargs, patch_shape
+
+
 def generate_labeled_array_from_xml(shape, xml_file):
     """Function taken from: https://github.com/rshwndsz/hover-net/blob/master/lightning_hovernet.ipynb
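To make the contract of update_kwargs_for_resize_trafo explicit: when resize_inputs is set, the requested patch shape is handed over to ResizeInputs transforms (composed with any user-supplied transforms) and the dataset is then built with patch_shape=None. A minimal sketch of a direct call, with placeholder values:

    from torch_em.data.datasets import util

    kwargs = {}  # no user-supplied raw_transform / label_transform
    resize_kwargs = {"patch_shape": (1, 512, 512), "is_rgb": False}

    kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
        kwargs=kwargs, patch_shape=(1, 512, 512), resize_inputs=True, resize_kwargs=resize_kwargs
    )
    # kwargs now carries a "raw_transform" and a "label_transform" that resize to (1, 512, 512),
    # and patch_shape is None, i.e. the resize transforms determine the sample shape.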