From d8d1726045832befdbca9a40d3ca4ab7f13a6216 Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Thu, 16 May 2024 16:56:09 +0200 Subject: [PATCH 01/11] Test tcia archives --- scripts/datasets/check_tcia.py | 64 ++++++++++++++++++++++++++++++++++ torch_em/util/image.py | 1 + 2 files changed, 65 insertions(+) create mode 100644 scripts/datasets/check_tcia.py diff --git a/scripts/datasets/check_tcia.py b/scripts/datasets/check_tcia.py new file mode 100644 index 00000000..aeb3a743 --- /dev/null +++ b/scripts/datasets/check_tcia.py @@ -0,0 +1,64 @@ +import os +import requests +from glob import glob +from natsort import natsorted + +import numpy as np +import pydicom as dicom + +from tcia_utils import nbia + + +ROOT = "/media/anwai/ANWAI/data/tmp/" + +TCIA_URL = "https://wiki.cancerimagingarchive.net/download/attachments/68551327/NSCLC-Radiomics-OriginalCTs.tcia" + + +def check_tcia(download): + trg_path = os.path.join(ROOT, os.path.split(TCIA_URL)[-1]) + if download: + # output = nbia.getSeries(collection="LIDC-IDRI") + # nbia.downloadSeries(output, number=3, path=ROOT) + + manifest = requests.get(TCIA_URL) + with open(trg_path, 'wb') as f: + f.write(manifest.content) + + df = nbia.downloadSeries(trg_path, input_type="manifest", number=3, format="df", path=ROOT) + + breakpoint() + + all_patient_dirs = glob(os.path.join(ROOT, "*")) + for patient_dir in all_patient_dirs: + if not os.path.split(patient_dir)[-1].startswith("1.3"): + continue + + all_dicom_files = natsorted(glob(os.path.join(patient_dir, "*.dcm"))) + samples = [] + for dcm_fpath in all_dicom_files: + file = dicom.dcmread(dcm_fpath) + img = file.pixel_array + samples.append(img) + + samples = np.stack(samples) + + import napari + + v = napari.Viewer() + v.add_image(samples) + napari.run() + + +def _test_me(): + data = nbia.getSeries(collection="Soft-tissue-Sarcoma") + print(data) + + nbia.downloadSeries(data, number=3) + + seriesUid = "1.3.6.1.4.1.14519.5.2.1.5168.1900.104193299251798317056218297018" + nbia.viewSeries(seriesUid) + + +if __name__ == "__main__": + # _test_me() + check_tcia(download=False) diff --git a/torch_em/util/image.py b/torch_em/util/image.py index 17e34cfc..d69a218d 100644 --- a/torch_em/util/image.py +++ b/torch_em/util/image.py @@ -31,6 +31,7 @@ def supports_memmap(image_path): def load_image(image_path, memmap=True): + # if image_path.endswith(".dcm"): ... if supports_memmap(image_path) and memmap: return tifffile.memmap(image_path, mode="r") elif tifffile is not None and os.path.splitext(image_path)[1].lower() in (".tiff", ".tif"): From 08539ca152b73cc5898e18209813a509b2c8813e Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Thu, 16 May 2024 17:39:32 +0200 Subject: [PATCH 02/11] Confirm TCIA workflow --- scripts/datasets/check_tcia.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/scripts/datasets/check_tcia.py b/scripts/datasets/check_tcia.py index aeb3a743..b90157b6 100644 --- a/scripts/datasets/check_tcia.py +++ b/scripts/datasets/check_tcia.py @@ -4,6 +4,8 @@ from natsort import natsorted import numpy as np +import pandas as pd +import nibabel as nib import pydicom as dicom from tcia_utils import nbia @@ -24,15 +26,27 @@ def check_tcia(download): with open(trg_path, 'wb') as f: f.write(manifest.content) - df = nbia.downloadSeries(trg_path, input_type="manifest", number=3, format="df", path=ROOT) + nbia.downloadSeries( + series_data=trg_path, input_type="manifest", number=3, path=ROOT, csv_filename="save" + ) - breakpoint() + df = pd.read_csv("save.csv") all_patient_dirs = glob(os.path.join(ROOT, "*")) for patient_dir in all_patient_dirs: - if not os.path.split(patient_dir)[-1].startswith("1.3"): + patient_id = os.path.split(patient_dir)[-1] + if not patient_id.startswith("1.3"): continue + breakpoint() + + subject_id = pd.Series.to_string(df.loc[df["Series UID"] == patient_id]["Subject ID"])[-9:] + seg_path = glob(os.path.join(ROOT, "Thoracic_Cavities", subject_id, "*.nii.gz"))[0] + gt = nib.load(seg_path) + gt = gt.get_fdata() + gt = gt.transpose(2, 1, 0) + gt = np.flip(gt, axis=1) + all_dicom_files = natsorted(glob(os.path.join(patient_dir, "*.dcm"))) samples = [] for dcm_fpath in all_dicom_files: @@ -40,12 +54,13 @@ def check_tcia(download): img = file.pixel_array samples.append(img) - samples = np.stack(samples) + samples = np.stack(samples[::-1]) import napari v = napari.Viewer() v.add_image(samples) + v.add_labels(gt.astype("uint64")) napari.run() From b52c87382eed33e56dbc5181769bf9470bb1a483 Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Thu, 16 May 2024 21:05:59 +0200 Subject: [PATCH 03/11] Add plethora dataset --- scripts/datasets/check_plethora.py | 13 ++++ torch_em/data/datasets/medical/__init__.py | 1 + torch_em/data/datasets/medical/plethora.py | 90 ++++++++++++++++++++++ torch_em/util/image.py | 1 - 4 files changed, 104 insertions(+), 1 deletion(-) create mode 100644 scripts/datasets/check_plethora.py create mode 100644 torch_em/data/datasets/medical/plethora.py diff --git a/scripts/datasets/check_plethora.py b/scripts/datasets/check_plethora.py new file mode 100644 index 00000000..76830566 --- /dev/null +++ b/scripts/datasets/check_plethora.py @@ -0,0 +1,13 @@ +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_plethora_loader + + +ROOT = "/media/anwai/ANWAI/data/plethora" + + +def check_plethora(): + ... + + +if __name__ == "__main__": + check_plethora() diff --git a/torch_em/data/datasets/medical/__init__.py b/torch_em/data/datasets/medical/__init__.py index 86f40713..98c64625 100644 --- a/torch_em/data/datasets/medical/__init__.py +++ b/torch_em/data/datasets/medical/__init__.py @@ -3,4 +3,5 @@ from .camus import get_camus_dataset, get_camus_loader from .drive import get_drive_dataset, get_drive_loader from .papila import get_papila_dataset, get_papila_loader +from .plethora import get_plethora_dataset, get_plethora_loader from .siim_acr import get_siim_acr_dataset, get_siim_acr_loader diff --git a/torch_em/data/datasets/medical/plethora.py b/torch_em/data/datasets/medical/plethora.py new file mode 100644 index 00000000..cc33f43c --- /dev/null +++ b/torch_em/data/datasets/medical/plethora.py @@ -0,0 +1,90 @@ +import os +import requests +from urllib.parse import urljoin +from typing import Union, Tuple + +import pandas as pd + +from .. import util + +from tcia_utils import nbia + + +BASE_URL = "https://wiki.cancerimagingarchive.net/download/attachments/68551327/" + + +URL = { + "image": urljoin(BASE_URL, "NSCLC-Radiomics-OriginalCTs.tcia"), + "gt": { + "thoracic": urljoin( + BASE_URL, "PleThora%20Thoracic_Cavities%20June%202020.zip?version=1&modificationDate=1593202695428&api=v2" + ), + "pleural_effusion": urljoin( + BASE_URL, "PleThora%20Effusions%20June%202020.zip?version=1&modificationDate=1593202778373&api=v2" + ) + } +} + + +CHECKSUMS = { + "image": None, + "gt": { + "thoracic": None, + "pleural_effusion": None + } +} + + +def download_source_tcia(path, url, dst, csv_filename): + assert url.endswith(".tcia") + + manifest = requests.get(url=url) + with open(path, "wb") as f: + f.write(manifest.content) + + if os.path.exists(csv_filename): + prev_df = pd.read_csv(csv_filename) + + df = nbia.downloadSeries( + series_data=path, input_type="manifest", path=dst, csv_filename=csv_filename + ) + + neu_df = pd.concat(prev_df, df) + neu_df.to_csv(csv_filename) + + +def get_plethora_data(path, download): + os.makedirs(path, exist_ok=True) + + image_dir = os.path.join(path, "data", "images") + gt_dir = os.path.join(path, "data", "gt") + if os.path.exists(image_dir) and os.path.exists(gt_dir): + return image_dir, gt_dir + + tcia_path = os.path.join(path, "NSCLC-Radiomics-OriginalCTs.tcia") + + download_source_tcia( + path=tcia_path, url=URL, dst=image_dir, csv_filename=os.path.join(path, "plethora_images") + ) + + +def _get_plethora_paths(path, download): + data_dir = get_plethora_data() + + breakpoint() + + +def get_plethora_dataset( + path: Union[os.PathLike, str], + download: bool = False, + **kwargs +): + + + +def get_plethora_loader( + path: Union[os.PathLike, str], + download: bool = False, + **kwargs +): + ... diff --git a/torch_em/util/image.py b/torch_em/util/image.py index d69a218d..17e34cfc 100644 --- a/torch_em/util/image.py +++ b/torch_em/util/image.py @@ -31,7 +31,6 @@ def supports_memmap(image_path): def load_image(image_path, memmap=True): - # if image_path.endswith(".dcm"): ... if supports_memmap(image_path) and memmap: return tifffile.memmap(image_path, mode="r") elif tifffile is not None and os.path.splitext(image_path)[1].lower() in (".tiff", ".tif"): From fd8c9d80c9b6bea6ee93c2a24c9d9aa09541b52b Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Fri, 17 May 2024 11:24:44 +0200 Subject: [PATCH 04/11] Update plethora dataset - create tcia download fn --- scripts/datasets/check_tcia.py | 8 +- torch_em/data/datasets/medical/plethora.py | 100 ++++++++++++++++++--- 2 files changed, 93 insertions(+), 15 deletions(-) diff --git a/scripts/datasets/check_tcia.py b/scripts/datasets/check_tcia.py index b90157b6..7866ef78 100644 --- a/scripts/datasets/check_tcia.py +++ b/scripts/datasets/check_tcia.py @@ -38,14 +38,12 @@ def check_tcia(download): if not patient_id.startswith("1.3"): continue - breakpoint() - subject_id = pd.Series.to_string(df.loc[df["Series UID"] == patient_id]["Subject ID"])[-9:] - seg_path = glob(os.path.join(ROOT, "Thoracic_Cavities", subject_id, "*.nii.gz"))[0] + seg_path = glob(os.path.join(ROOT, "Thoracic_Cavities", subject_id, "*_primary_reviewer.nii.gz"))[0] gt = nib.load(seg_path) gt = gt.get_fdata() gt = gt.transpose(2, 1, 0) - gt = np.flip(gt, axis=1) + gt = np.flip(gt, axis=(0, 1)) all_dicom_files = natsorted(glob(os.path.join(patient_dir, "*.dcm"))) samples = [] @@ -54,7 +52,7 @@ def check_tcia(download): img = file.pixel_array samples.append(img) - samples = np.stack(samples[::-1]) + samples = np.stack(samples) import napari diff --git a/torch_em/data/datasets/medical/plethora.py b/torch_em/data/datasets/medical/plethora.py index cc33f43c..7603ddf2 100644 --- a/torch_em/data/datasets/medical/plethora.py +++ b/torch_em/data/datasets/medical/plethora.py @@ -1,9 +1,17 @@ import os import requests -from urllib.parse import urljoin +from glob import glob +from pathlib import Path +from natsort import natsorted from typing import Union, Tuple +from urllib.parse import urljoin +import numpy as np import pandas as pd +import nibabel as nib +import pydicom as dicom + +import torch_em from .. import util @@ -35,6 +43,12 @@ } +ZIPFILES = { + "thoracic": "thoracic.zip", + "pleural_effusion": "pleural_effusion.zip" +} + + def download_source_tcia(path, url, dst, csv_filename): assert url.endswith(".tcia") @@ -53,38 +67,104 @@ def download_source_tcia(path, url, dst, csv_filename): neu_df.to_csv(csv_filename) -def get_plethora_data(path, download): +def get_plethora_data(path, task, download): os.makedirs(path, exist_ok=True) image_dir = os.path.join(path, "data", "images") gt_dir = os.path.join(path, "data", "gt") + csv_path = os.path.join(path, "plethora_images") if os.path.exists(image_dir) and os.path.exists(gt_dir): - return image_dir, gt_dir + return image_dir, gt_dir, Path(csv_path).with_suffix(".csv") + # let's download dicom files from the tcia manifest tcia_path = os.path.join(path, "NSCLC-Radiomics-OriginalCTs.tcia") + download_source_tcia(path=tcia_path, url=URL, dst=image_dir, csv_filename=csv_path) - download_source_tcia( - path=tcia_path, url=URL, dst=image_dir, csv_filename=os.path.join(path, "plethora_images") + # let's download the segmentations from zipfiles + zip_path = os.path.join(path, ZIPFILES[task]) + util.download_source( + path=zip_path, url=URL["gt"][task], download=download, checksum=CHECKSUMS["gt"][task] ) + util.unzip(zip_path=zip_path, dst=gt_dir) + + return image_dir, gt_dir, Path(csv_path).with_suffix(".csv") + + +def _assort_plethora_inputs(image_dir, gt_dir, task, csv_path): + df = pd.read_csv(csv_path) + + task_gt_dir = os.path.join(gt_dir, "Thoracic_Cavity" if task == "thoracic" else "Pleural_Effusion") + + # let's get all the series uid of the volumes downloaded and spot their allocated subject id + all_series_uid_dirs = glob(os.path.join(image_dir, "1.3*")) + for series_uid_dir in all_series_uid_dirs: + series_uid = os.path.split(series_uid_dir)[-1] + subject_id = pd.Series.to_string(df.loc[df["Series UID"] == series_uid]["Subject ID"])[-9:] + vol_path = os.path.join(image_dir, f"{subject_id}.nii.gz") + if os.path.exists(vol_path): + continue -def _get_plethora_paths(path, download): - data_dir = get_plethora_data() + # TODO: there are multiple raters, check it out if there is can be some consistency + gt_path = glob(os.path.join(task_gt_dir, subject_id, "*_primary_reviewer.nii.gz"))[0] - breakpoint() + # the ground truth needs to be aligned as the inputs, let's take care of that. + gt = nib.load(gt_path) + gt = gt.get_fdata() + gt = gt.transpose(2, 1, 0) # aligning w.r.t the inputs + gt = np.flip(gt, axis=(0, 1)) + + all_dcm_slices = natsorted(glob(os.path.join(series_uid_dir, "*.dcm"))) + all_slices = [] + for dcm_path in all_dcm_slices: + dcmfile = dicom.dcmread(dcm_path) + img = dcmfile.pixel_array + all_slices.append(img) + + volume = np.stack(all_slices) + nii_vol = nib.Nifti1Image(volume, np.eye(4)) + nii_vol.header.get_xyzt_units() + nii_vol.to_filename(vol_path) + + +def _get_plethora_paths(path, task, download): + image_dir, gt_dir, csv_path = get_plethora_data(path=path, task=task, download=download) + + _assort_plethora_inputs(image_dir=image_dir, gt_dir=gt_dir, task=task, csv_path=csv_path) + + image_paths = ... + gt_paths = ... + + return image_paths, gt_paths def get_plethora_dataset( path: Union[os.PathLike, str], + task: str, + patch_shape: Tuple[int, ...], + resize_inputs: bool = False, download: bool = False, **kwargs ): - + image_paths, gt_paths = _get_plethora_paths(path=path, task=task, download=download) + + dataset = ... + + return dataset def get_plethora_loader( path: Union[os.PathLike, str], + task: str, + patch_shape: Tuple[int, ...], + batch_size: int, + resize_inputs: bool = False, download: bool = False, **kwargs ): - ... + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_plethora_dataset( + path=path, task=task, patch_shape=patch_shape, resize_inputs=resize_inputs, download=download, **ds_kwargs + ) + loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) + return loader From ef16e9683a08b43a7a96e8b04a82948a873cc3ce Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Fri, 17 May 2024 15:08:41 +0200 Subject: [PATCH 05/11] Refactor tcia download scripts --- scripts/datasets/check_plethora.py | 11 ++++++- torch_em/data/datasets/medical/plethora.py | 23 +------------- torch_em/data/datasets/util.py | 37 +++++++++++++++++++--- 3 files changed, 43 insertions(+), 28 deletions(-) diff --git a/scripts/datasets/check_plethora.py b/scripts/datasets/check_plethora.py index 76830566..fe5f8e76 100644 --- a/scripts/datasets/check_plethora.py +++ b/scripts/datasets/check_plethora.py @@ -6,7 +6,16 @@ def check_plethora(): - ... + loader = get_plethora_loader( + path=ROOT, + task="thoracic", + patch_shape=(1, 512, 512), + batch_size=2, + resize_inputs=False, + download=True, + ) + + check_loader(loader, 8) if __name__ == "__main__": diff --git a/torch_em/data/datasets/medical/plethora.py b/torch_em/data/datasets/medical/plethora.py index 7603ddf2..84dc1f2c 100644 --- a/torch_em/data/datasets/medical/plethora.py +++ b/torch_em/data/datasets/medical/plethora.py @@ -1,5 +1,4 @@ import os -import requests from glob import glob from pathlib import Path from natsort import natsorted @@ -15,8 +14,6 @@ from .. import util -from tcia_utils import nbia - BASE_URL = "https://wiki.cancerimagingarchive.net/download/attachments/68551327/" @@ -49,24 +46,6 @@ } -def download_source_tcia(path, url, dst, csv_filename): - assert url.endswith(".tcia") - - manifest = requests.get(url=url) - with open(path, "wb") as f: - f.write(manifest.content) - - if os.path.exists(csv_filename): - prev_df = pd.read_csv(csv_filename) - - df = nbia.downloadSeries( - series_data=path, input_type="manifest", path=dst, csv_filename=csv_filename - ) - - neu_df = pd.concat(prev_df, df) - neu_df.to_csv(csv_filename) - - def get_plethora_data(path, task, download): os.makedirs(path, exist_ok=True) @@ -78,7 +57,7 @@ def get_plethora_data(path, task, download): # let's download dicom files from the tcia manifest tcia_path = os.path.join(path, "NSCLC-Radiomics-OriginalCTs.tcia") - download_source_tcia(path=tcia_path, url=URL, dst=image_dir, csv_filename=csv_path) + util.download_source_tcia(path=tcia_path, url=URL, dst=image_dir, csv_filename=csv_path) # let's download the segmentations from zipfiles zip_path = os.path.join(path, ZIPFILES[task]) diff --git a/torch_em/data/datasets/util.py b/torch_em/data/datasets/util.py index 3fd416e4..fb4ca94b 100644 --- a/torch_em/data/datasets/util.py +++ b/torch_em/data/datasets/util.py @@ -1,26 +1,32 @@ -import inspect import os import hashlib -import zipfile +import inspect +import requests import numpy as np from tqdm import tqdm from warnings import warn -from xml.dom import minidom -from shutil import copyfileobj, which from subprocess import run from packaging import version +from shutil import copyfileobj, which +import zipfile +import pandas as pd +from xml.dom import minidom from skimage.draw import polygon import torch import torch_em -import requests try: import gdown except ImportError: gdown = None +try: + from tcia_utils import nbia +except ModuleNotFoundError: + nbia = None + BIOIMAGEIO_IDS = { "covid_if": "ilastik/covid_if_training_data", @@ -163,6 +169,27 @@ def download_source_kaggle(path, dataset_name, download): api.dataset_download_files(dataset=dataset_name, path=path, quiet=False) +def download_source_tcia(path, url, dst, csv_filename, download): + if not download: + raise RuntimeError(f"Cannot fine the data at {path}, but download was set to False.") + + assert url.endswith(".tcia"), f"{path} is not a TCIA Manifest." + + # downloads the manifest file from the collection page + manifest = requests.get(url=url) + with open(path, "wb") as f: + f.write(manifest.content) + + if os.path.exists(csv_filename): + prev_df = pd.read_csv(csv_filename) + + # this part extracts the UIDs from the manigests and downloads them. + df = nbia.downloadSeries(series_data=path, input_type="manifest", path=dst, csv_filename=csv_filename) + + neu_df = pd.concat(prev_df, df) + neu_df.to_csv(csv_filename) + + def update_kwargs(kwargs, key, value, msg=None): if key in kwargs: msg = f"{key} will be over-ridden in loader kwargs." if msg is None else msg From 54c959cf049bfb32f742b37a7700d0d47578e246 Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Fri, 17 May 2024 23:06:10 +0200 Subject: [PATCH 06/11] Simplify storing metadata --- scripts/datasets/check_tcia.py | 2 +- torch_em/data/datasets/medical/plethora.py | 4 +++- torch_em/data/datasets/util.py | 11 +++-------- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/scripts/datasets/check_tcia.py b/scripts/datasets/check_tcia.py index 7866ef78..466773a1 100644 --- a/scripts/datasets/check_tcia.py +++ b/scripts/datasets/check_tcia.py @@ -74,4 +74,4 @@ def _test_me(): if __name__ == "__main__": # _test_me() - check_tcia(download=False) + check_tcia(download=True) diff --git a/torch_em/data/datasets/medical/plethora.py b/torch_em/data/datasets/medical/plethora.py index 84dc1f2c..1512974f 100644 --- a/torch_em/data/datasets/medical/plethora.py +++ b/torch_em/data/datasets/medical/plethora.py @@ -57,7 +57,9 @@ def get_plethora_data(path, task, download): # let's download dicom files from the tcia manifest tcia_path = os.path.join(path, "NSCLC-Radiomics-OriginalCTs.tcia") - util.download_source_tcia(path=tcia_path, url=URL, dst=image_dir, csv_filename=csv_path) + util.download_source_tcia(path=tcia_path, url=URL["image"], dst=image_dir, csv_filename=csv_path, download=download) + + breakpoint() # let's download the segmentations from zipfiles zip_path = os.path.join(path, ZIPFILES[task]) diff --git a/torch_em/data/datasets/util.py b/torch_em/data/datasets/util.py index fb4ca94b..b08713c3 100644 --- a/torch_em/data/datasets/util.py +++ b/torch_em/data/datasets/util.py @@ -10,7 +10,6 @@ from shutil import copyfileobj, which import zipfile -import pandas as pd from xml.dom import minidom from skimage.draw import polygon @@ -180,14 +179,10 @@ def download_source_tcia(path, url, dst, csv_filename, download): with open(path, "wb") as f: f.write(manifest.content) - if os.path.exists(csv_filename): - prev_df = pd.read_csv(csv_filename) - # this part extracts the UIDs from the manigests and downloads them. - df = nbia.downloadSeries(series_data=path, input_type="manifest", path=dst, csv_filename=csv_filename) - - neu_df = pd.concat(prev_df, df) - neu_df.to_csv(csv_filename) + nbia.downloadSeries( + series_data=path, input_type="manifest", path=dst, csv_filename=csv_filename, + ) def update_kwargs(kwargs, key, value, msg=None): From b91e7104431377d5130ebbb29b4bc892f5255342 Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Sat, 18 May 2024 23:24:54 +0200 Subject: [PATCH 07/11] Minor fix to saving nifti images --- torch_em/data/datasets/medical/plethora.py | 71 +++++++++++++++------- torch_em/data/datasets/util.py | 2 +- 2 files changed, 50 insertions(+), 23 deletions(-) diff --git a/torch_em/data/datasets/medical/plethora.py b/torch_em/data/datasets/medical/plethora.py index 1512974f..a9c0f1c2 100644 --- a/torch_em/data/datasets/medical/plethora.py +++ b/torch_em/data/datasets/medical/plethora.py @@ -1,5 +1,6 @@ import os from glob import glob +from tqdm import tqdm from pathlib import Path from natsort import natsorted from typing import Union, Tuple @@ -11,6 +12,7 @@ import pydicom as dicom import torch_em +from torch_em.transform.generic import ResizeInputs from .. import util @@ -34,8 +36,8 @@ CHECKSUMS = { "image": None, "gt": { - "thoracic": None, - "pleural_effusion": None + "thoracic": "6dfcb60e46c7b0ccf240bc5d13acb1c45c8d2f4922223f7b2fbd5e37acff2be0", + "pleural_effusion": "5dd07c327fb5723c5bbb48f2a02d7f365513d3ad136811fbe4def330ef2d7f6a" } } @@ -59,8 +61,6 @@ def get_plethora_data(path, task, download): tcia_path = os.path.join(path, "NSCLC-Radiomics-OriginalCTs.tcia") util.download_source_tcia(path=tcia_path, url=URL["image"], dst=image_dir, csv_filename=csv_path, download=download) - breakpoint() - # let's download the segmentations from zipfiles zip_path = os.path.join(path, ZIPFILES[task]) util.download_source( @@ -74,27 +74,30 @@ def get_plethora_data(path, task, download): def _assort_plethora_inputs(image_dir, gt_dir, task, csv_path): df = pd.read_csv(csv_path) - task_gt_dir = os.path.join(gt_dir, "Thoracic_Cavity" if task == "thoracic" else "Pleural_Effusion") + task_gt_dir = os.path.join(gt_dir, "Thoracic_Cavities" if task == "thoracic" else "Pleural_Effusion") + + os.makedirs(os.path.join(image_dir, "preprocessed"), exist_ok=True) + os.makedirs(os.path.join(task_gt_dir, "preprocessed"), exist_ok=True) # let's get all the series uid of the volumes downloaded and spot their allocated subject id all_series_uid_dirs = glob(os.path.join(image_dir, "1.3*")) - for series_uid_dir in all_series_uid_dirs: + image_paths, gt_paths = [], [] + for series_uid_dir in tqdm(all_series_uid_dirs): series_uid = os.path.split(series_uid_dir)[-1] subject_id = pd.Series.to_string(df.loc[df["Series UID"] == series_uid]["Subject ID"])[-9:] - vol_path = os.path.join(image_dir, f"{subject_id}.nii.gz") - if os.path.exists(vol_path): - continue - - # TODO: there are multiple raters, check it out if there is can be some consistency gt_path = glob(os.path.join(task_gt_dir, subject_id, "*_primary_reviewer.nii.gz"))[0] + assert os.path.exists(gt_path) - # the ground truth needs to be aligned as the inputs, let's take care of that. - gt = nib.load(gt_path) - gt = gt.get_fdata() - gt = gt.transpose(2, 1, 0) # aligning w.r.t the inputs - gt = np.flip(gt, axis=(0, 1)) + vol_path = os.path.join(image_dir, "preprocessed", f"{subject_id}.nii.gz") + neu_gt_path = os.path.join(task_gt_dir, "preprocessed", os.path.split(gt_path)[-1]) + + image_paths.append(vol_path) + gt_paths.append(neu_gt_path) + if os.path.exists(vol_path) and os.path.exists(neu_gt_path): + continue + # the individual slices for the inputs need to be merged into one volume. all_dcm_slices = natsorted(glob(os.path.join(series_uid_dir, "*.dcm"))) all_slices = [] for dcm_path in all_dcm_slices: @@ -107,15 +110,22 @@ def _assort_plethora_inputs(image_dir, gt_dir, task, csv_path): nii_vol.header.get_xyzt_units() nii_vol.to_filename(vol_path) + # the ground truth needs to be aligned as the inputs, let's take care of that. + gt = nib.load(gt_path) + gt = gt.get_fdata() + gt = gt.transpose(2, 1, 0) # aligning w.r.t the inputs + gt = np.flip(gt, axis=(0, 1)) -def _get_plethora_paths(path, task, download): - image_dir, gt_dir, csv_path = get_plethora_data(path=path, task=task, download=download) + gt_nii_vol = nib.Nifti1Image(gt, np.eye(4)) + gt_nii_vol.header.get_xyzt_units() + gt_nii_vol.to_filename(neu_gt_path) - _assort_plethora_inputs(image_dir=image_dir, gt_dir=gt_dir, task=task, csv_path=csv_path) + return image_paths, gt_paths - image_paths = ... - gt_paths = ... +def _get_plethora_paths(path, task, download): + image_dir, gt_dir, csv_path = get_plethora_data(path=path, task=task, download=download) + image_paths, gt_paths = _assort_plethora_inputs(image_dir=image_dir, gt_dir=gt_dir, task=task, csv_path=csv_path) return image_paths, gt_paths @@ -129,7 +139,24 @@ def get_plethora_dataset( ): image_paths, gt_paths = _get_plethora_paths(path=path, task=task, download=download) - dataset = ... + if resize_inputs: + raw_trafo = ResizeInputs(target_shape=patch_shape, is_label=False) + label_trafo = ResizeInputs(target_shape=patch_shape, is_label=True) + patch_shape = None + else: + patch_shape = patch_shape + raw_trafo, label_trafo = None, None + + dataset = torch_em.default_segmentation_dataset( + raw_paths=image_paths, + raw_key="data", + label_paths=gt_paths, + label_key="data", + patch_shape=patch_shape, + raw_transform=raw_trafo, + label_transform=label_trafo, + **kwargs + ) return dataset diff --git a/torch_em/data/datasets/util.py b/torch_em/data/datasets/util.py index b08713c3..d4ccc601 100644 --- a/torch_em/data/datasets/util.py +++ b/torch_em/data/datasets/util.py @@ -2,7 +2,6 @@ import hashlib import inspect import requests -import numpy as np from tqdm import tqdm from warnings import warn from subprocess import run @@ -10,6 +9,7 @@ from shutil import copyfileobj, which import zipfile +import numpy as np from xml.dom import minidom from skimage.draw import polygon From f0d807aa5f440945e0ebf1e673c6b9edcb39fcbb Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Sun, 19 May 2024 08:14:02 +0200 Subject: [PATCH 08/11] Fix pleural effusion paths --- scripts/datasets/check_plethora.py | 2 +- torch_em/data/datasets/medical/plethora.py | 39 +++++++++++++--------- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/scripts/datasets/check_plethora.py b/scripts/datasets/check_plethora.py index fe5f8e76..11fa65d5 100644 --- a/scripts/datasets/check_plethora.py +++ b/scripts/datasets/check_plethora.py @@ -8,7 +8,7 @@ def check_plethora(): loader = get_plethora_loader( path=ROOT, - task="thoracic", + task="pleural_effusion", patch_shape=(1, 512, 512), batch_size=2, resize_inputs=False, diff --git a/torch_em/data/datasets/medical/plethora.py b/torch_em/data/datasets/medical/plethora.py index a9c0f1c2..7ea4b30f 100644 --- a/torch_em/data/datasets/medical/plethora.py +++ b/torch_em/data/datasets/medical/plethora.py @@ -52,7 +52,7 @@ def get_plethora_data(path, task, download): os.makedirs(path, exist_ok=True) image_dir = os.path.join(path, "data", "images") - gt_dir = os.path.join(path, "data", "gt") + gt_dir = os.path.join(path, "data", "gt", "Thoracic_Cavities" if task == "thoracic" else "Effusions") csv_path = os.path.join(path, "plethora_images") if os.path.exists(image_dir) and os.path.exists(gt_dir): return image_dir, gt_dir, Path(csv_path).with_suffix(".csv") @@ -66,7 +66,7 @@ def get_plethora_data(path, task, download): util.download_source( path=zip_path, url=URL["gt"][task], download=download, checksum=CHECKSUMS["gt"][task] ) - util.unzip(zip_path=zip_path, dst=gt_dir) + util.unzip(zip_path=zip_path, dst=os.path.join(path, "data", "gt")) return image_dir, gt_dir, Path(csv_path).with_suffix(".csv") @@ -74,7 +74,7 @@ def get_plethora_data(path, task, download): def _assort_plethora_inputs(image_dir, gt_dir, task, csv_path): df = pd.read_csv(csv_path) - task_gt_dir = os.path.join(gt_dir, "Thoracic_Cavities" if task == "thoracic" else "Pleural_Effusion") + task_gt_dir = os.path.join(gt_dir, ) os.makedirs(os.path.join(image_dir, "preprocessed"), exist_ok=True) os.makedirs(os.path.join(task_gt_dir, "preprocessed"), exist_ok=True) @@ -86,7 +86,13 @@ def _assort_plethora_inputs(image_dir, gt_dir, task, csv_path): series_uid = os.path.split(series_uid_dir)[-1] subject_id = pd.Series.to_string(df.loc[df["Series UID"] == series_uid]["Subject ID"])[-9:] - gt_path = glob(os.path.join(task_gt_dir, subject_id, "*_primary_reviewer.nii.gz"))[0] + try: + gt_path = glob(os.path.join(task_gt_dir, subject_id, "*.nii.gz"))[0] + except IndexError: + # - some patients do not have "Thoracic_Cavities" segmentation + print(f"The ground truth is missing for subject '{subject_id}'") + continue + assert os.path.exists(gt_path) vol_path = os.path.join(image_dir, "preprocessed", f"{subject_id}.nii.gz") @@ -98,17 +104,19 @@ def _assort_plethora_inputs(image_dir, gt_dir, task, csv_path): continue # the individual slices for the inputs need to be merged into one volume. - all_dcm_slices = natsorted(glob(os.path.join(series_uid_dir, "*.dcm"))) - all_slices = [] - for dcm_path in all_dcm_slices: - dcmfile = dicom.dcmread(dcm_path) - img = dcmfile.pixel_array - all_slices.append(img) - - volume = np.stack(all_slices) - nii_vol = nib.Nifti1Image(volume, np.eye(4)) - nii_vol.header.get_xyzt_units() - nii_vol.to_filename(vol_path) + if not os.path.exists(vol_path): + all_dcm_slices = natsorted(glob(os.path.join(series_uid_dir, "*.dcm"))) + all_slices = [] + for dcm_path in all_dcm_slices: + dcmfile = dicom.dcmread(dcm_path) + img = dcmfile.pixel_array + all_slices.append(img) + + volume = np.stack(all_slices) + volume = volume.transpose(1, 2, 0) + nii_vol = nib.Nifti1Image(volume, np.eye(4)) + nii_vol.header.get_xyzt_units() + nii_vol.to_filename(vol_path) # the ground truth needs to be aligned as the inputs, let's take care of that. gt = nib.load(gt_path) @@ -116,6 +124,7 @@ def _assort_plethora_inputs(image_dir, gt_dir, task, csv_path): gt = gt.transpose(2, 1, 0) # aligning w.r.t the inputs gt = np.flip(gt, axis=(0, 1)) + gt = gt.transpose(1, 2, 0) gt_nii_vol = nib.Nifti1Image(gt, np.eye(4)) gt_nii_vol.header.get_xyzt_units() gt_nii_vol.to_filename(neu_gt_path) From 39954db0940c92eec01704c6f84b5384ee063959 Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Thu, 23 May 2024 10:33:08 +0200 Subject: [PATCH 09/11] Add feature to apply incoming transforms as well --- .../datasets/{ => medical}/check_plethora.py | 6 ++-- scripts/datasets/{ => medical}/check_tcia.py | 0 torch_em/data/datasets/medical/plethora.py | 13 +++----- torch_em/data/datasets/util.py | 30 +++++++++++++++++++ torch_em/transform/raw.py | 16 ++++++++++ 5 files changed, 54 insertions(+), 11 deletions(-) rename scripts/datasets/{ => medical}/check_plethora.py (74%) rename scripts/datasets/{ => medical}/check_tcia.py (100%) diff --git a/scripts/datasets/check_plethora.py b/scripts/datasets/medical/check_plethora.py similarity index 74% rename from scripts/datasets/check_plethora.py rename to scripts/datasets/medical/check_plethora.py index 11fa65d5..53e4c154 100644 --- a/scripts/datasets/check_plethora.py +++ b/scripts/datasets/medical/check_plethora.py @@ -1,3 +1,4 @@ +from torch_em.data import MinInstanceSampler from torch_em.util.debug import check_loader from torch_em.data.datasets.medical import get_plethora_loader @@ -8,11 +9,12 @@ def check_plethora(): loader = get_plethora_loader( path=ROOT, - task="pleural_effusion", + task="thoracic", patch_shape=(1, 512, 512), batch_size=2, - resize_inputs=False, + resize_inputs=True, download=True, + sampler=MinInstanceSampler(), ) check_loader(loader, 8) diff --git a/scripts/datasets/check_tcia.py b/scripts/datasets/medical/check_tcia.py similarity index 100% rename from scripts/datasets/check_tcia.py rename to scripts/datasets/medical/check_tcia.py diff --git a/torch_em/data/datasets/medical/plethora.py b/torch_em/data/datasets/medical/plethora.py index 7ea4b30f..34308bb5 100644 --- a/torch_em/data/datasets/medical/plethora.py +++ b/torch_em/data/datasets/medical/plethora.py @@ -12,7 +12,6 @@ import pydicom as dicom import torch_em -from torch_em.transform.generic import ResizeInputs from .. import util @@ -149,12 +148,10 @@ def get_plethora_dataset( image_paths, gt_paths = _get_plethora_paths(path=path, task=task, download=download) if resize_inputs: - raw_trafo = ResizeInputs(target_shape=patch_shape, is_label=False) - label_trafo = ResizeInputs(target_shape=patch_shape, is_label=True) - patch_shape = None - else: - patch_shape = patch_shape - raw_trafo, label_trafo = None, None + resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False} + kwargs, patch_shape = util.update_kwargs_for_resize_trafo( + kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs + ) dataset = torch_em.default_segmentation_dataset( raw_paths=image_paths, @@ -162,8 +159,6 @@ def get_plethora_dataset( label_paths=gt_paths, label_key="data", patch_shape=patch_shape, - raw_transform=raw_trafo, - label_transform=label_trafo, **kwargs ) diff --git a/torch_em/data/datasets/util.py b/torch_em/data/datasets/util.py index d4ccc601..b70de65a 100644 --- a/torch_em/data/datasets/util.py +++ b/torch_em/data/datasets/util.py @@ -14,7 +14,10 @@ from skimage.draw import polygon import torch + import torch_em +from torch_em.transform.raw import ConcatTransforms +from torch_em.transform.generic import ResizeInputs try: import gdown @@ -247,6 +250,33 @@ def add_instance_label_transform( return kwargs, label_dtype +def update_kwargs_for_resize_trafo(kwargs, patch_shape, resize_inputs, resize_kwargs=None): + """ + Checks for raw_transform and label_transform incoming values. + If yes, it will automatically merge these two transforms to apply them together. + """ + if resize_inputs: + assert isinstance(resize_kwargs, dict) + patch_shape = None + + raw_trafo = ResizeInputs(target_shape=resize_kwargs["patch_shape"], is_rgb=resize_kwargs["is_rgb"]) + label_trafo = ResizeInputs(target_shape=resize_kwargs["patch_shape"], is_label=True) + + if "raw_transform" in kwargs: + trafo = ConcatTransforms(transform1=kwargs["raw_transform"], transform2=raw_trafo) + kwargs["raw_transform"] = trafo + else: + kwargs["raw_transform"] = raw_trafo + + if "label_transform" in kwargs: + trafo = ConcatTransforms(transform1=kwargs["raw_transform"], transform2=label_trafo) + kwargs["label_transform"] = trafo + else: + kwargs["label_transform"] = label_trafo + + return kwargs, patch_shape + + def generate_labeled_array_from_xml(shape, xml_file): """Function taken from: https://github.com/rshwndsz/hover-net/blob/master/lightning_hovernet.ipynb diff --git a/torch_em/transform/raw.py b/torch_em/transform/raw.py index cd43c379..fb548c80 100644 --- a/torch_em/transform/raw.py +++ b/torch_em/transform/raw.py @@ -227,3 +227,19 @@ def get_default_mean_teacher_augmentations( augmentation1=aug1, augmentation2=aug2 ) + + +# The functionality below merges incoming and requested transforms to work together. +class ConcatTransforms: + def __init__(self, transform1=None, transform2=None): + self.transform1 = transform1 + self.transform2 = transform2 + + def __call__(self, inputs): + if self.transform1 is not None: + inputs = self.transform1(inputs) + + if self.transform2 is not None: + inputs = self.transform2(inputs) + + return inputs From d91eeb27698673e865a0794892f5f855f665a18a Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Thu, 23 May 2024 17:23:24 +0200 Subject: [PATCH 10/11] Replace with existing functionality to use multiple trafo --- torch_em/data/datasets/util.py | 10 +++++----- torch_em/transform/raw.py | 16 ---------------- 2 files changed, 5 insertions(+), 21 deletions(-) diff --git a/torch_em/data/datasets/util.py b/torch_em/data/datasets/util.py index b70de65a..4ee52bc2 100644 --- a/torch_em/data/datasets/util.py +++ b/torch_em/data/datasets/util.py @@ -16,8 +16,8 @@ import torch import torch_em -from torch_em.transform.raw import ConcatTransforms -from torch_em.transform.generic import ResizeInputs +from torch_em.transform import get_raw_transform +from torch_em.transform.generic import ResizeInputs, Compose try: import gdown @@ -263,13 +263,13 @@ def update_kwargs_for_resize_trafo(kwargs, patch_shape, resize_inputs, resize_kw label_trafo = ResizeInputs(target_shape=resize_kwargs["patch_shape"], is_label=True) if "raw_transform" in kwargs: - trafo = ConcatTransforms(transform1=kwargs["raw_transform"], transform2=raw_trafo) + trafo = Compose([kwargs["raw_transform"], raw_trafo]) kwargs["raw_transform"] = trafo else: - kwargs["raw_transform"] = raw_trafo + kwargs["raw_transform"] = Compose([get_raw_transform(), raw_trafo]) if "label_transform" in kwargs: - trafo = ConcatTransforms(transform1=kwargs["raw_transform"], transform2=label_trafo) + trafo = Compose(transform1=kwargs["raw_transform"], transform2=label_trafo) kwargs["label_transform"] = trafo else: kwargs["label_transform"] = label_trafo diff --git a/torch_em/transform/raw.py b/torch_em/transform/raw.py index fb548c80..cd43c379 100644 --- a/torch_em/transform/raw.py +++ b/torch_em/transform/raw.py @@ -227,19 +227,3 @@ def get_default_mean_teacher_augmentations( augmentation1=aug1, augmentation2=aug2 ) - - -# The functionality below merges incoming and requested transforms to work together. -class ConcatTransforms: - def __init__(self, transform1=None, transform2=None): - self.transform1 = transform1 - self.transform2 = transform2 - - def __call__(self, inputs): - if self.transform1 is not None: - inputs = self.transform1(inputs) - - if self.transform2 is not None: - inputs = self.transform2(inputs) - - return inputs From e45dfbd0020e46699891d6847b4ef0a95de6dad6 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Fri, 24 May 2024 20:38:10 +0200 Subject: [PATCH 11/11] Apply suggestions from code review --- torch_em/data/datasets/util.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_em/data/datasets/util.py b/torch_em/data/datasets/util.py index 4ee52bc2..e0f8887a 100644 --- a/torch_em/data/datasets/util.py +++ b/torch_em/data/datasets/util.py @@ -263,13 +263,13 @@ def update_kwargs_for_resize_trafo(kwargs, patch_shape, resize_inputs, resize_kw label_trafo = ResizeInputs(target_shape=resize_kwargs["patch_shape"], is_label=True) if "raw_transform" in kwargs: - trafo = Compose([kwargs["raw_transform"], raw_trafo]) + trafo = Compose(raw_trafo, kwargs["raw_transform"]) kwargs["raw_transform"] = trafo else: - kwargs["raw_transform"] = Compose([get_raw_transform(), raw_trafo]) + kwargs["raw_transform"] = Compose(raw_trafo, get_raw_transform()) if "label_transform" in kwargs: - trafo = Compose(transform1=kwargs["raw_transform"], transform2=label_trafo) + trafo = Compose(label_trafo, kwargs["label_transform"]) kwargs["label_transform"] = trafo else: kwargs["label_transform"] = label_trafo