Add OSIC Pulmonary Fibrosis dataset (#281)

Co-authored-by: Constantin Pape <constantin.pape@embl.de>
anwai98 and constantinpape committed Jun 4, 2024
1 parent 2a2f9ec commit 901fe7d
Showing 4 changed files with 208 additions and 2 deletions.
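As a quick orientation before the file-by-file diff: a minimal sketch of how the new loader is meant to be used (the path below is hypothetical; `download=True` assumes a configured Kaggle API token and accepted competition rules).

from torch_em.data.datasets.medical import get_osic_pulmofib_loader

# hypothetical local target directory for the data
loader = get_osic_pulmofib_loader(
    path="/data/osic_pulmofib",
    patch_shape=(1, 512, 512),
    batch_size=2,
    download=True,
)
for x, y in loader:
    print(x.shape, y.shape)  # raw patch and semantic label patch
    break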
38 changes: 38 additions & 0 deletions scripts/datasets/medical/check_osic_pulmofib.py
@@ -0,0 +1,38 @@
from torch_em.util.debug import check_loader
from torch_em.data.datasets.medical import get_osic_pulmofib_loader


ROOT = "/media/anwai/ANWAI/data/osic_pulmofib"


def check_osic_pulmofib():
    loader = get_osic_pulmofib_loader(
        path=ROOT,
        patch_shape=(1, 512, 512),
        batch_size=2,
        resize_inputs=False,
        download=False,
    )
    check_loader(loader, 8)


def visualize_data():
    import os
    from glob import glob

    import nrrd
    import napari

    all_volume_paths = sorted(glob(os.path.join(ROOT, "nrrd_heart", "*", "*")))
    for vol_path in all_volume_paths:
        vol, _ = nrrd.read(vol_path)

        v = napari.Viewer()
        # move the z-axis to the front so napari shows the volume as a stack of slices
        v.add_image(vol.transpose(2, 0, 1))
        napari.run()


if __name__ == "__main__":
    # visualize_data()
    check_osic_pulmofib()
1 change: 1 addition & 0 deletions torch_em/data/datasets/medical/__init__.py
@@ -4,6 +4,7 @@
from .camus import get_camus_dataset, get_camus_loader
from .drive import get_drive_dataset, get_drive_loader
from .msd import get_msd_dataset, get_msd_loader
from .osic_pulmofib import get_osic_pulmofib_dataset, get_osic_pulmofib_loader
from .papila import get_papila_dataset, get_papila_loader
from .plethora import get_plethora_dataset, get_plethora_loader
from .sa_med2d import get_sa_med2d_dataset, get_sa_med2d_loader
163 changes: 163 additions & 0 deletions torch_em/data/datasets/medical/osic_pulmofib.py
@@ -0,0 +1,163 @@
import os
from glob import glob
from tqdm import tqdm
from pathlib import Path
from natsort import natsorted
from typing import Union, Tuple

import json
import nrrd
import numpy as np
import nibabel as nib
import pydicom as dicom

import torch_em

from .. import util


ORGAN_IDS = {"heart": 1, "lung": 2, "trachea": 3}


def get_osic_pulmofib_data(path, download):
    os.makedirs(path, exist_ok=True)

    data_dir = os.path.join(path, "data")
    if os.path.exists(data_dir):
        return data_dir

    # download the image data first
    zip_path = os.path.join(path, "osic-pulmonary-fibrosis-progression.zip")
    util.download_source_kaggle(
        path=path, dataset_name="osic-pulmonary-fibrosis-progression", download=download, competition=True
    )
    util.unzip(zip_path=zip_path, dst=data_dir, remove=False)

    # next, download the ground truth
    zip_path = os.path.join(path, "ct-lung-heart-trachea-segmentation.zip")
    util.download_source_kaggle(
        path=path, dataset_name="sandorkonya/ct-lung-heart-trachea-segmentation", download=download
    )
    util.unzip(zip_path=zip_path, dst=data_dir)

    return data_dir


def _get_osic_pulmofib_paths(path, download):
    data_dir = get_osic_pulmofib_data(path=path, download=download)

    image_dir = os.path.join(data_dir, "preprocessed", "images")
    gt_dir = os.path.join(data_dir, "preprocessed", "ground_truth")

    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(gt_dir, exist_ok=True)

    cpath = os.path.join(data_dir, "preprocessed", "confirmer.json")
    _completed_preproc = os.path.exists(cpath)

    image_paths, gt_paths = [], []
    uid_paths = natsorted(glob(os.path.join(data_dir, "train", "*")))
    for uid_path in tqdm(uid_paths):
        uid = uid_path.split("/")[-1]

        image_path = os.path.join(image_dir, f"{uid}.nii.gz")
        gt_path = os.path.join(gt_dir, f"{uid}.nii.gz")

        if _completed_preproc:
            if os.path.exists(image_path) and os.path.exists(gt_path):
                image_paths.append(image_path)
                gt_paths.append(gt_path)

            continue

        # create the volume from the individual dicom slices
        all_slices = []
        for slice_path in natsorted(glob(os.path.join(uid_path, "*.dcm"))):
            per_slice = dicom.dcmread(slice_path)
            per_slice = per_slice.pixel_array
            all_slices.append(per_slice)
        all_slices = np.stack(all_slices).transpose(1, 2, 0)

        # next, combine the per-organ annotations into one ground-truth volume with semantic labels
        all_gt = np.zeros(all_slices.shape, dtype="uint8")
        for ann_path in glob(os.path.join(data_dir, "*", "*", f"{uid}_*.nrrd")):
            ann_organ = Path(ann_path).stem.split("_")[-1]
            if ann_organ == "noisy":
                continue

            per_gt, _ = nrrd.read(ann_path)
            per_gt = per_gt.transpose(1, 0, 2)

            # some organ annotations have a dimension mismatch with the volume; we skip them for simplicity
            if per_gt.shape == all_slices.shape:
                all_gt[per_gt > 0] = ORGAN_IDS[ann_organ]

        # we save the raw and gt volumes only if the volume has any labels (some volumes have no segmentation)
        if len(np.unique(all_gt)) > 1:
            all_gt = np.flip(all_gt, axis=2)

            image_nifti = nib.Nifti2Image(all_slices, np.eye(4))
            gt_nifti = nib.Nifti2Image(all_gt, np.eye(4))

            nib.save(image_nifti, image_path)
            nib.save(gt_nifti, gt_path)

            image_paths.append(image_path)
            gt_paths.append(gt_path)

    if not _completed_preproc:
        # since not all volumes have segmentations, we store a file which confirms that preprocessing is complete
        confirm_msg = "The dataset has been preprocessed. "
        confirm_msg += f"It has {len(image_paths)} volumes and {len(gt_paths)} corresponding ground-truths."
        print(confirm_msg)

        with open(cpath, "w") as f:
            json.dump(confirm_msg, f)

    return image_paths, gt_paths


def get_osic_pulmofib_dataset(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
):
    """Dataset for segmentation of lung, heart and trachea in CT scans.

    This dataset is from the OSIC Pulmonary Fibrosis Progression challenge:
    - https://www.kaggle.com/c/osic-pulmonary-fibrosis-progression/data (dataset source)
    - https://www.kaggle.com/datasets/sandorkonya/ct-lung-heart-trachea-segmentation (segmentation source)

    Please cite it if you use this dataset for a publication.
    """
    image_paths, gt_paths = _get_osic_pulmofib_paths(path=path, download=download)

    dataset = torch_em.default_segmentation_dataset(
        raw_paths=image_paths,
        raw_key="data",
        label_paths=gt_paths,
        label_key="data",
        patch_shape=patch_shape,
        **kwargs
    )

    return dataset


def get_osic_pulmofib_loader(
    path: Union[os.PathLike, str],
    patch_shape: Tuple[int, ...],
    batch_size: int,
    resize_inputs: bool = False,
    download: bool = False,
    **kwargs
):
    """Dataloader for segmentation of lung, heart and trachea in CT scans. See `get_osic_pulmofib_dataset` for details.
    """
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_osic_pulmofib_dataset(
        path=path, patch_shape=patch_shape, resize_inputs=resize_inputs, download=download, **ds_kwargs
    )
    loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
    return loader
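To make the preprocessing in `_get_osic_pulmofib_paths` concrete, here is a small self-contained sketch of the label-combination step with dummy arrays (not the real data): per-organ binary masks are merged into one semantic ground-truth volume via `ORGAN_IDS`.

import numpy as np

ORGAN_IDS = {"heart": 1, "lung": 2, "trachea": 3}

# two fake binary organ masks on a dummy volume shape
shape = (8, 8, 4)
heart_mask = np.zeros(shape, dtype=bool)
heart_mask[2:4, 2:4, :] = True
lung_mask = np.zeros(shape, dtype=bool)
lung_mask[5:7, 5:7, :] = True

# merge the binary masks into one semantic volume, as done per .nrrd annotation above
all_gt = np.zeros(shape, dtype="uint8")
for organ, mask in [("heart", heart_mask), ("lung", lung_mask)]:
    all_gt[mask] = ORGAN_IDS[organ]

print(np.unique(all_gt))  # [0 1 2]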
8 changes: 6 additions & 2 deletions torch_em/data/datasets/util.py
@@ -154,7 +154,7 @@ def download_source_empiar(path, access_id, download):
    return download_path


-def download_source_kaggle(path, dataset_name, download):
+def download_source_kaggle(path, dataset_name, download, competition=False):
    if not download:
        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")

@@ -168,7 +168,11 @@ def download_source_kaggle(path, dataset_name, download):

    api = KaggleApi()
    api.authenticate()
-    api.dataset_download_files(dataset=dataset_name, path=path, quiet=False)
+
+    if competition:
+        api.competition_download_files(competition=dataset_name, path=path, quiet=False)
+    else:
+        api.dataset_download_files(dataset=dataset_name, path=path, quiet=False)


def download_source_tcia(path, url, dst, csv_filename, download):
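For reference, a minimal sketch of how the extended helper can be called for both download modes (paths are hypothetical; both calls assume a configured Kaggle API token, and competition downloads additionally require accepting the competition rules on Kaggle):

from torch_em.data.datasets import util

# competition data (as used above for the OSIC CT scans)
util.download_source_kaggle(
    path="/tmp/osic", dataset_name="osic-pulmonary-fibrosis-progression",
    download=True, competition=True,
)

# regular dataset download (default behaviour, unchanged)
util.download_source_kaggle(
    path="/tmp/organ_gt", dataset_name="sandorkonya/ct-lung-heart-trachea-segmentation",
    download=True,
)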
