Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SIIM ACR Pneumothorax dataset #256

Merged
merged 6 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions scripts/datasets/check_siim_acr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from torch_em.util.debug import check_loader
from torch_em.data import MinInstanceSampler
from torch_em.data.datasets.medical import get_siim_acr_loader


ROOT = "/media/anwai/ANWAI/data/siim_acr"


def check_siim_acr():
loader = get_siim_acr_loader(
path=ROOT,
split="train",
patch_shape=(512, 512),
batch_size=2,
download=True,
resize_inputs=True,
sampler=MinInstanceSampler()
)
check_loader(loader, 8)


if __name__ == "__main__":
check_siim_acr()
1 change: 1 addition & 0 deletions torch_em/data/datasets/medical/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .autopet import get_autopet_loader
from .btcv import get_btcv_dataset, get_btcv_loader
from .siim_acr import get_siim_acr_dataset, get_siim_acr_loader
99 changes: 99 additions & 0 deletions torch_em/data/datasets/medical/siim_acr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import os
from glob import glob
from typing import Union, Tuple

import torch_em
from torch_em.transform.generic import ResizeInputs

from .. import util
from ... import ImageCollectionDataset


KAGGLE_DATASET_NAME = "vbookshelf/pneumothorax-chest-xray-images-and-masks"
CHECKSUM = "1ade68d31adb996c531bb686fb9d02fe11876ddf6f25594ab725e18c69d81538"


def get_siim_acr_data(path, download):
os.makedirs(path, exist_ok=True)

data_dir = os.path.join(path, "siim-acr-pneumothorax")
if os.path.exists(data_dir):
return data_dir

util.download_source_kaggle(path=path, dataset_name=KAGGLE_DATASET_NAME, download=download)

zip_path = os.path.join(path, "pneumothorax-chest-xray-images-and-masks.zip")
util._check_checksum(path=zip_path, checksum=CHECKSUM)
util.unzip(zip_path=zip_path, dst=path)

return data_dir


def _get_siim_acr_paths(path, split, download):
data_dir = get_siim_acr_data(path=path, download=download)

assert split in ["train", "test"], f"'{split}' is not a valid split."

image_paths = sorted(glob(os.path.join(data_dir, "png_images", f"*_{split}_*.png")))
gt_paths = sorted(glob(os.path.join(data_dir, "png_masks", f"*_{split}_*.png")))

return image_paths, gt_paths


def get_siim_acr_dataset(
path: Union[os.PathLike, str],
split: str,
patch_shape: Tuple[int, int],
download: bool = False,
resize_inputs: bool = False,
**kwargs
):
"""Dataset for pneumothorax segmentation in CXR.

The database is located at https://www.kaggle.com/datasets/vbookshelf/pneumothorax-chest-xray-images-and-masks/data

This dataset is from the "SIIM-ACR Pneumothorax Segmentation" competition:
https://kaggle.com/competitions/siim-acr-pneumothorax-segmentation

Please cite it if you use this dataset for a publication.
"""
image_paths, gt_paths = _get_siim_acr_paths(path=path, split=split, download=download)

if resize_inputs:
raw_trafo = ResizeInputs(target_shape=patch_shape, is_label=False)
label_trafo = ResizeInputs(target_shape=patch_shape, is_label=True)
patch_shape = None
else:
patch_shape = patch_shape
raw_trafo, label_trafo = None, None

dataset = ImageCollectionDataset(
raw_image_paths=image_paths,
label_image_paths=gt_paths,
patch_shape=patch_shape,
raw_transform=raw_trafo,
label_transform=label_trafo,
**kwargs
)
dataset.max_sampling_attempts = 5000

return dataset


def get_siim_acr_loader(
path: Union[os.PathLike, str],
split: str,
patch_shape: Tuple[int, int],
batch_size: int,
download: bool = False,
resize_inputs: bool = False,
**kwargs
):
"""Dataloader for pneumothorax segmentation in CXR. See `get_siim_acr_dataset` for details.
"""
ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
dataset = get_siim_acr_dataset(
path=path, split=split, patch_shape=patch_shape, download=download, resize_inputs=resize_inputs, **ds_kwargs
)
loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
return loader
17 changes: 17 additions & 0 deletions torch_em/data/datasets/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,23 @@ def download_source_empiar(path, access_id, download):
return download_path


def download_source_kaggle(path, dataset_name, download):
if not download:
raise RuntimeError(f"Cannot fine the data at {path}, but download was set to False.")

try:
from kaggle.api.kaggle_api_extended import KaggleApi
except ModuleNotFoundError:
msg = "Please install the Kaggle API. You can do this using 'pip install kaggle'. "
msg += "After you have installed kaggle, you would need an API token. "
msg += "Follow the instructions at https://www.kaggle.com/docs/api."
raise ModuleNotFoundError(msg)

api = KaggleApi()
api.authenticate()
api.dataset_download_files(dataset=dataset_name, path=path, quiet=False)


def update_kwargs(kwargs, key, value, msg=None):
if key in kwargs:
msg = f"{key} will be over-ridden in loader kwargs." if msg is None else msg
Expand Down
15 changes: 11 additions & 4 deletions torch_em/data/image_collection_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ def __init__(
self.label_images = label_image_paths
self._ndim = 2

assert len(patch_shape) == self._ndim
if patch_shape is not None:
assert len(patch_shape) == self._ndim
self.patch_shape = patch_shape

self.raw_transform = raw_transform
Expand All @@ -95,11 +96,16 @@ def ndim(self):
return self._ndim

def _sample_bounding_box(self, shape):
if self.patch_shape is None:
patch_shape_for_bb = shape
else:
patch_shape_for_bb = self.patch_shape
anwai98 marked this conversation as resolved.
Show resolved Hide resolved

bb_start = [
np.random.randint(0, sh - psh) if sh - psh > 0 else 0
for sh, psh in zip(shape, self.patch_shape)
for sh, psh in zip(shape, patch_shape_for_bb)
]
return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.patch_shape))
return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape_for_bb))

def _ensure_patch_shape(self, raw, labels, have_raw_channels, have_label_channels, channel_first):
shape = raw.shape
Expand Down Expand Up @@ -137,7 +143,8 @@ def _load_data(self, raw_path, label_path):
if have_raw_channels:
channel_first = raw.shape[-1] > 16

raw, label = self._ensure_patch_shape(raw, label, have_raw_channels, have_label_channels, channel_first)
if self.patch_shape is not None:
raw, label = self._ensure_patch_shape(raw, label, have_raw_channels, have_label_channels, channel_first)
shape = raw.shape

prefix_box = tuple()
Expand Down
24 changes: 23 additions & 1 deletion torch_em/transform/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import torch

from skimage.transform import rescale
from skimage.transform import rescale, resize


class Tile(torch.nn.Module):
Expand Down Expand Up @@ -72,6 +72,28 @@ def __call__(self, *inputs):
return outputs


class ResizeInputs:
def __init__(self, target_shape, is_label=False):
self.target_shape = target_shape
self.is_label = is_label

def __call__(self, inputs):
if self.is_label:
anwai98 marked this conversation as resolved.
Show resolved Hide resolved
anti_aliasing = True
else:
anti_aliasing = False

inputs = resize(
image=inputs,
output_shape=self.target_shape,
order=3,
anti_aliasing=anti_aliasing,
preserve_range=True,
)

return inputs


class PadIfNecessary:
def __init__(self, shape):
self.shape = tuple(shape)
Expand Down