Add Plethora dataset (#263)
Add support for tcia dataloader, update transform resizing and add plethora dataset

---------

Co-authored-by: Constantin Pape <c.pape@gmx.net>
anwai98 and constantinpape committed May 24, 2024
1 parent e973f05 commit e99d01a
Showing 5 changed files with 342 additions and 6 deletions.
24 changes: 24 additions & 0 deletions scripts/datasets/medical/check_plethora.py
@@ -0,0 +1,24 @@
from torch_em.data import MinInstanceSampler
from torch_em.util.debug import check_loader
from torch_em.data.datasets.medical import get_plethora_loader


ROOT = "/media/anwai/ANWAI/data/plethora"


def check_plethora():
    loader = get_plethora_loader(
        path=ROOT,
        task="thoracic",
        patch_shape=(1, 512, 512),
        batch_size=2,
        resize_inputs=True,
        download=True,
        sampler=MinInstanceSampler(),
    )

    check_loader(loader, 8)


if __name__ == "__main__":
    check_plethora()
77 changes: 77 additions & 0 deletions scripts/datasets/medical/check_tcia.py
@@ -0,0 +1,77 @@
import os
import requests
from glob import glob
from natsort import natsorted

import numpy as np
import pandas as pd
import nibabel as nib
import pydicom as dicom

from tcia_utils import nbia


ROOT = "/media/anwai/ANWAI/data/tmp/"

TCIA_URL = "https://wiki.cancerimagingarchive.net/download/attachments/68551327/NSCLC-Radiomics-OriginalCTs.tcia"


def check_tcia(download):
    trg_path = os.path.join(ROOT, os.path.split(TCIA_URL)[-1])
    if download:
        # output = nbia.getSeries(collection="LIDC-IDRI")
        # nbia.downloadSeries(output, number=3, path=ROOT)

        manifest = requests.get(TCIA_URL)
        with open(trg_path, 'wb') as f:
            f.write(manifest.content)

        nbia.downloadSeries(
            series_data=trg_path, input_type="manifest", number=3, path=ROOT, csv_filename="save"
        )

    df = pd.read_csv("save.csv")

    all_patient_dirs = glob(os.path.join(ROOT, "*"))
    for patient_dir in all_patient_dirs:
        patient_id = os.path.split(patient_dir)[-1]
        if not patient_id.startswith("1.3"):
            continue

        subject_id = pd.Series.to_string(df.loc[df["Series UID"] == patient_id]["Subject ID"])[-9:]
        seg_path = glob(os.path.join(ROOT, "Thoracic_Cavities", subject_id, "*_primary_reviewer.nii.gz"))[0]
        gt = nib.load(seg_path)
        gt = gt.get_fdata()
        gt = gt.transpose(2, 1, 0)
        gt = np.flip(gt, axis=(0, 1))

        all_dicom_files = natsorted(glob(os.path.join(patient_dir, "*.dcm")))
        samples = []
        for dcm_fpath in all_dicom_files:
            file = dicom.dcmread(dcm_fpath)
            img = file.pixel_array
            samples.append(img)

        samples = np.stack(samples)

        import napari

        v = napari.Viewer()
        v.add_image(samples)
        v.add_labels(gt.astype("uint64"))
        napari.run()


def _test_me():
    data = nbia.getSeries(collection="Soft-tissue-Sarcoma")
    print(data)

    nbia.downloadSeries(data, number=3)

    seriesUid = "1.3.6.1.4.1.14519.5.2.1.5168.1900.104193299251798317056218297018"
    nbia.viewSeries(seriesUid)


if __name__ == "__main__":
    # _test_me()
    check_tcia(download=True)
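
The manifest flow prototyped above is wrapped by the new download_source_tcia helper added to torch_em/data/datasets/util.py (see the diff further below). A minimal sketch of calling it directly, with placeholder paths and reusing TCIA_URL from the script above:

from torch_em.data.datasets import util

# Placeholder paths: the manifest is written to `path` and the dicom series
# are downloaded into `dst`, with the series metadata saved as a csv file.
util.download_source_tcia(
    path="/path/to/NSCLC-Radiomics-OriginalCTs.tcia",
    url=TCIA_URL,
    dst="/path/to/images",
    csv_filename="/path/to/plethora_images",
    download=True,
)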
1 change: 1 addition & 0 deletions torch_em/data/datasets/medical/__init__.py
@@ -4,4 +4,5 @@
from .camus import get_camus_dataset, get_camus_loader
from .drive import get_drive_dataset, get_drive_loader
from .papila import get_papila_dataset, get_papila_loader
from .plethora import get_plethora_dataset, get_plethora_loader
from .siim_acr import get_siim_acr_dataset, get_siim_acr_loader
182 changes: 182 additions & 0 deletions torch_em/data/datasets/medical/plethora.py
@@ -0,0 +1,182 @@
import os
from glob import glob
from tqdm import tqdm
from pathlib import Path
from natsort import natsorted
from typing import Union, Tuple
from urllib.parse import urljoin

import numpy as np
import pandas as pd
import nibabel as nib
import pydicom as dicom

import torch_em

from .. import util


BASE_URL = "https://wiki.cancerimagingarchive.net/download/attachments/68551327/"


URL = {
    "image": urljoin(BASE_URL, "NSCLC-Radiomics-OriginalCTs.tcia"),
    "gt": {
        "thoracic": urljoin(
            BASE_URL, "PleThora%20Thoracic_Cavities%20June%202020.zip?version=1&modificationDate=1593202695428&api=v2"
        ),
        "pleural_effusion": urljoin(
            BASE_URL, "PleThora%20Effusions%20June%202020.zip?version=1&modificationDate=1593202778373&api=v2"
        )
    }
}


CHECKSUMS = {
    "image": None,
    "gt": {
        "thoracic": "6dfcb60e46c7b0ccf240bc5d13acb1c45c8d2f4922223f7b2fbd5e37acff2be0",
        "pleural_effusion": "5dd07c327fb5723c5bbb48f2a02d7f365513d3ad136811fbe4def330ef2d7f6a"
    }
}


ZIPFILES = {
    "thoracic": "thoracic.zip",
    "pleural_effusion": "pleural_effusion.zip"
}


def get_plethora_data(path, task, download):
    os.makedirs(path, exist_ok=True)

    image_dir = os.path.join(path, "data", "images")
    gt_dir = os.path.join(path, "data", "gt", "Thoracic_Cavities" if task == "thoracic" else "Effusions")
    csv_path = os.path.join(path, "plethora_images")
    if os.path.exists(image_dir) and os.path.exists(gt_dir):
        return image_dir, gt_dir, Path(csv_path).with_suffix(".csv")

    # let's download the dicom files from the tcia manifest
    tcia_path = os.path.join(path, "NSCLC-Radiomics-OriginalCTs.tcia")
    util.download_source_tcia(path=tcia_path, url=URL["image"], dst=image_dir, csv_filename=csv_path, download=download)

    # let's download the segmentations from the zipfiles
    zip_path = os.path.join(path, ZIPFILES[task])
    util.download_source(
        path=zip_path, url=URL["gt"][task], download=download, checksum=CHECKSUMS["gt"][task]
    )
    util.unzip(zip_path=zip_path, dst=os.path.join(path, "data", "gt"))

    return image_dir, gt_dir, Path(csv_path).with_suffix(".csv")


def _assort_plethora_inputs(image_dir, gt_dir, task, csv_path):
    df = pd.read_csv(csv_path)

    task_gt_dir = gt_dir

    os.makedirs(os.path.join(image_dir, "preprocessed"), exist_ok=True)
    os.makedirs(os.path.join(task_gt_dir, "preprocessed"), exist_ok=True)

    # let's get the series uid of each downloaded volume and look up its allocated subject id
    all_series_uid_dirs = glob(os.path.join(image_dir, "1.3*"))
    image_paths, gt_paths = [], []
    for series_uid_dir in tqdm(all_series_uid_dirs):
        series_uid = os.path.split(series_uid_dir)[-1]
        subject_id = pd.Series.to_string(df.loc[df["Series UID"] == series_uid]["Subject ID"])[-9:]

        try:
            gt_path = glob(os.path.join(task_gt_dir, subject_id, "*.nii.gz"))[0]
        except IndexError:
            # some patients do not have a "Thoracic_Cavities" segmentation
            print(f"The ground truth is missing for subject '{subject_id}'")
            continue

        assert os.path.exists(gt_path)

        vol_path = os.path.join(image_dir, "preprocessed", f"{subject_id}.nii.gz")
        neu_gt_path = os.path.join(task_gt_dir, "preprocessed", os.path.split(gt_path)[-1])

        image_paths.append(vol_path)
        gt_paths.append(neu_gt_path)
        if os.path.exists(vol_path) and os.path.exists(neu_gt_path):
            continue

        # the individual slices for the inputs need to be merged into one volume
        if not os.path.exists(vol_path):
            all_dcm_slices = natsorted(glob(os.path.join(series_uid_dir, "*.dcm")))
            all_slices = []
            for dcm_path in all_dcm_slices:
                dcmfile = dicom.dcmread(dcm_path)
                img = dcmfile.pixel_array
                all_slices.append(img)

            volume = np.stack(all_slices)
            volume = volume.transpose(1, 2, 0)
            nii_vol = nib.Nifti1Image(volume, np.eye(4))
            nii_vol.header.get_xyzt_units()
            nii_vol.to_filename(vol_path)

        # the ground truth needs to be aligned with the inputs, let's take care of that
        gt = nib.load(gt_path)
        gt = gt.get_fdata()
        gt = gt.transpose(2, 1, 0)  # aligning w.r.t. the inputs
        gt = np.flip(gt, axis=(0, 1))

        gt = gt.transpose(1, 2, 0)
        gt_nii_vol = nib.Nifti1Image(gt, np.eye(4))
        gt_nii_vol.header.get_xyzt_units()
        gt_nii_vol.to_filename(neu_gt_path)

    return image_paths, gt_paths


def _get_plethora_paths(path, task, download):
    image_dir, gt_dir, csv_path = get_plethora_data(path=path, task=task, download=download)
    image_paths, gt_paths = _assort_plethora_inputs(image_dir=image_dir, gt_dir=gt_dir, task=task, csv_path=csv_path)
    return image_paths, gt_paths


def get_plethora_dataset(
    path: Union[os.PathLike, str],
    task: str,
    patch_shape: Tuple[int, ...],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
):
    image_paths, gt_paths = _get_plethora_paths(path=path, task=task, download=download)

    if resize_inputs:
        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
        )

    dataset = torch_em.default_segmentation_dataset(
        raw_paths=image_paths,
        raw_key="data",
        label_paths=gt_paths,
        label_key="data",
        patch_shape=patch_shape,
        **kwargs
    )

    return dataset


def get_plethora_loader(
    path: Union[os.PathLike, str],
    task: str,
    patch_shape: Tuple[int, ...],
    batch_size: int,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
):
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_plethora_dataset(
        path=path, task=task, patch_shape=patch_shape, resize_inputs=resize_inputs, download=download, **ds_kwargs
    )
    loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
    return loader
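
For reference, a minimal sketch of using the new dataset function directly instead of the loader. The path is a placeholder, and the assumption here is that torch_em's default segmentation dataset yields raw/label pairs on indexing:

from torch_em.data.datasets.medical import get_plethora_dataset

dataset = get_plethora_dataset(
    path="/path/to/plethora",  # placeholder download directory
    task="pleural_effusion",   # or "thoracic"
    patch_shape=(1, 512, 512),
    resize_inputs=True,
    download=True,
)
x, y = dataset[0]  # one raw patch and the corresponding labels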
64 changes: 58 additions & 6 deletions torch_em/data/datasets/util.py
@@ -1,26 +1,34 @@
import os
import hashlib
import inspect
import requests
from tqdm import tqdm
from warnings import warn
from subprocess import run
from packaging import version
from shutil import copyfileobj, which

import zipfile
import numpy as np
from xml.dom import minidom
from skimage.draw import polygon

import torch

import torch_em
from torch_em.transform import get_raw_transform
from torch_em.transform.generic import ResizeInputs, Compose

try:
    import gdown
except ImportError:
    gdown = None

try:
    from tcia_utils import nbia
except ModuleNotFoundError:
    nbia = None


BIOIMAGEIO_IDS = {
"covid_if": "ilastik/covid_if_training_data",
Expand Down Expand Up @@ -163,6 +171,23 @@ def download_source_kaggle(path, dataset_name, download):
api.dataset_download_files(dataset=dataset_name, path=path, quiet=False)


def download_source_tcia(path, url, dst, csv_filename, download):
    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")

    assert url.endswith(".tcia"), f"{url} is not a TCIA manifest."

    # download the manifest file from the collection page
    manifest = requests.get(url=url)
    with open(path, "wb") as f:
        f.write(manifest.content)

    # extract the UIDs from the manifest and download the corresponding series
    nbia.downloadSeries(
        series_data=path, input_type="manifest", path=dst, csv_filename=csv_filename,
    )


def update_kwargs(kwargs, key, value, msg=None):
    if key in kwargs:
        msg = f"{key} will be over-ridden in loader kwargs." if msg is None else msg
@@ -225,6 +250,33 @@ def add_instance_label_transform(
    return kwargs, label_dtype


def update_kwargs_for_resize_trafo(kwargs, patch_shape, resize_inputs, resize_kwargs=None):
    """Checks whether 'raw_transform' and/or 'label_transform' were passed in the kwargs.
    If so, the respective resize transform is merged with the incoming transform,
    so that both are applied together.
    """
    if resize_inputs:
        assert isinstance(resize_kwargs, dict)
        patch_shape = None

        raw_trafo = ResizeInputs(target_shape=resize_kwargs["patch_shape"], is_rgb=resize_kwargs["is_rgb"])
        label_trafo = ResizeInputs(target_shape=resize_kwargs["patch_shape"], is_label=True)

        if "raw_transform" in kwargs:
            trafo = Compose(raw_trafo, kwargs["raw_transform"])
            kwargs["raw_transform"] = trafo
        else:
            kwargs["raw_transform"] = Compose(raw_trafo, get_raw_transform())

        if "label_transform" in kwargs:
            trafo = Compose(label_trafo, kwargs["label_transform"])
            kwargs["label_transform"] = trafo
        else:
            kwargs["label_transform"] = label_trafo

    return kwargs, patch_shape


def generate_labeled_array_from_xml(shape, xml_file):
"""Function taken from: https://github.com/rshwndsz/hover-net/blob/master/lightning_hovernet.ipynb
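
To make the merging behavior of update_kwargs_for_resize_trafo concrete, here is a short sketch. my_raw_transform is a hypothetical user transform, the helper names follow the diff above, and the comment assumes Compose applies its transforms in the given order:

def my_raw_transform(raw):
    # hypothetical intensity normalization supplied by the user
    return (raw - raw.mean()) / (raw.std() + 1e-7)

kwargs, patch_shape = update_kwargs_for_resize_trafo(
    kwargs={"raw_transform": my_raw_transform},
    patch_shape=(1, 512, 512),
    resize_inputs=True,
    resize_kwargs={"patch_shape": (1, 512, 512), "is_rgb": False},
)
# patch_shape comes back as None and kwargs["raw_transform"] is now a Compose
# that resizes the input and then applies my_raw_transform.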
