Skip to content

Commit

Permalink
Add Seg-A dataset (#276)
Browse files Browse the repository at this point in the history
Add Seg-A dataset
  • Loading branch information
anwai98 committed Jun 5, 2024
1 parent 871f7db commit 8f40ebe
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 0 deletions.
25 changes: 25 additions & 0 deletions scripts/datasets/medical/check_sega.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from torch_em.data import MinInstanceSampler
from torch_em.util.debug import check_loader
from torch_em.data.datasets.medical import get_sega_loader


ROOT = "/media/anwai/ANWAI/data/sega"


def check_sega():
loader = get_sega_loader(
path=ROOT,
patch_shape=(1, 512, 512),
batch_size=2,
ndim=2,
data_choice="KiTS",
resize_inputs=True,
download=True,
sampler=MinInstanceSampler(),
)

check_loader(loader, 8)


if __name__ == "__main__":
check_sega()
1 change: 1 addition & 0 deletions torch_em/data/datasets/medical/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@
from .papila import get_papila_dataset, get_papila_loader
from .plethora import get_plethora_dataset, get_plethora_loader
from .sa_med2d import get_sa_med2d_dataset, get_sa_med2d_loader
from .sega import get_sega_dataset, get_sega_loader
from .siim_acr import get_siim_acr_dataset, get_siim_acr_loader
from .uwaterloo_skin import get_uwaterloo_skin_dataset, get_uwaterloo_skin_loader
126 changes: 126 additions & 0 deletions torch_em/data/datasets/medical/sega.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import os
from glob import glob
from pathlib import Path
from natsort import natsorted
from typing import Union, Tuple, Optional, Literal

import torch_em

from .. import util


URL = {
"kits": "https://figshare.com/ndownloader/files/30950821",
"rider": "https://figshare.com/ndownloader/files/30950914",
"dongyang": "https://figshare.com/ndownloader/files/30950971"
}

CHECKSUMS = {
"kits": "6c9c2ea31e5998348acf1c4f6683ae07041bd6c8caf309dd049adc7f222de26e",
"rider": "7244038a6a4f70ae70b9288a2ce874d32128181de2177c63a7612d9ab3c4f5fa",
"dongyang": "0187e90038cba0564e6304ef0182969ff57a31b42c5969d2b9188a27219da541"
}

ZIPFILES = {
"kits": "KiTS.zip",
"rider": "Rider.zip",
"dongyang": "Dongyang.zip"
}


def get_sega_data(path, data_choice, download):
os.makedirs(path, exist_ok=True)

data_choice = data_choice.lower()

zip_fid = ZIPFILES[data_choice]

data_dir = os.path.join(path, Path(zip_fid).stem)
if os.path.exists(data_dir):
return data_dir

zip_path = os.path.join(path, zip_fid)
util.download_source(
path=zip_path, url=URL[data_choice], download=download, checksum=CHECKSUMS[data_choice],
)
util.unzip(zip_path=zip_path, dst=path)

return data_dir


def _get_sega_paths(path, data_choice, download):
if data_choice is None:
data_choices = URL.keys()
else:
if isinstance(data_choice, str):
data_choices = [data_choice]

data_dirs = [get_sega_data(path=path, data_choice=data_choice, download=download) for data_choice in data_choices]

image_paths, gt_paths = [], []
for data_dir in data_dirs:
all_volumes_paths = glob(os.path.join(data_dir, "*", "*.nrrd"))
for volume_path in all_volumes_paths:
if volume_path.endswith(".seg.nrrd"):
gt_paths.append(volume_path)
else:
image_paths.append(volume_path)

return natsorted(image_paths), natsorted(gt_paths)


def get_sega_dataset(
path: Union[os.PathLike, str],
patch_shape: Tuple[int, ...],
data_choice: Optional[Literal["KiTS", "Rider", "Dongyang"]] = None,
resize_inputs: bool = False,
download: bool = False,
**kwargs
):
"""Dataset for segmentation of aorta in computed tomography angiography (CTA) scans.
This dataset is from Pepe et al. - https://doi.org/10.1007/978-3-031-53241-2
Please cite it if you use this dataset for a publication.
"""
image_paths, gt_paths = _get_sega_paths(path=path, data_choice=data_choice, download=download)

if resize_inputs:
resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs,
)

dataset = torch_em.default_segmentation_dataset(
raw_paths=image_paths,
raw_key=None,
label_paths=gt_paths,
label_key=None,
patch_shape=patch_shape,
**kwargs
)

return dataset


def get_sega_loader(
path: Union[os.PathLike, str],
patch_shape: Tuple[int, ...],
batch_size: int,
data_choice: Optional[Literal["KiTS", "Rider", "Dongyang"]] = None,
resize_inputs: bool = False,
download: bool = False,
**kwargs
):
"""Dataloader for segmentation of aorta in CTA scans. See `get_sega_dataset` for details.
"""
ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
dataset = get_sega_dataset(
path=path,
patch_shape=patch_shape,
data_choice=data_choice,
resize_inputs=resize_inputs,
download=download,
**ds_kwargs
)
loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
return loader
5 changes: 5 additions & 0 deletions torch_em/util/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np

from elf.io import open_file

try:
import imageio.v3 as imageio
except ImportError:
Expand All @@ -14,6 +15,7 @@
except ImportError:
tifffile = None


TIF_EXTS = (".tif", ".tiff")


Expand All @@ -35,6 +37,9 @@ def load_image(image_path, memmap=True):
return tifffile.memmap(image_path, mode="r")
elif tifffile is not None and os.path.splitext(image_path)[1].lower() in (".tiff", ".tif"):
return tifffile.imread(image_path)
elif os.path.splitext(image_path)[1].lower() == ".nrrd":
import nrrd
return nrrd.read(image_path)[0]
elif os.path.splitext(image_path)[1].lower() == ".mha":
import SimpleITK as sitk
image = sitk.ReadImage(image_path)
Expand Down

0 comments on commit 8f40ebe

Please sign in to comment.