Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MoNuSAC Dataset #158

Merged
merged 10 commits into from
Oct 23, 2023
33 changes: 33 additions & 0 deletions scripts/datasets/check_monusac.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from torch_em.util.debug import check_loader
from torch_em.data.datasets import get_monusac_loader


# Local root folder where the MoNuSAC data is (or will be) downloaded and stored.
MONUSAC_ROOT = "/scratch/usr/nimanwai/data/monusac"


def check_monusac():
    """Visual sanity check for the MoNuSAC train and test loaders.

    Builds a loader per split (restricted to two organ types each), prints its
    length and saves a grid visualization of a few samples to a png file.
    """
    # per-split loader settings: (batch size, organ selection)
    split_settings = {
        "train": (2, ["breast", "lung"]),
        "test": (1, ["breast", "prostate"]),
    }

    for split, (batch_size, organs) in split_settings.items():
        loader = get_monusac_loader(
            path=MONUSAC_ROOT,
            download=True,
            patch_shape=(512, 512),
            batch_size=batch_size,
            split=split,
            organ_type=organs,
        )
        print(f"Length of {split} loader: ", len(loader))
        check_loader(
            loader, 8, instance_labels=True, rgb=True, plt=True,
            save_path=f"./monusac_{split}.png",
        )


if __name__ == "__main__":
    check_monusac()
1 change: 1 addition & 0 deletions torch_em/data/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .lucchi import get_lucchi_loader, get_lucchi_dataset
from .mitoem import get_mitoem_loader, get_mitoem_dataset
from .monuseg import get_monuseg_loader, get_monuseg_dataset
from .monusac import get_monusac_loader, get_monusac_dataset
from .mouse_embryo import get_mouse_embryo_loader, get_mouse_embryo_dataset
from .neurips_cell_seg import (
get_neurips_cellseg_supervised_loader, get_neurips_cellseg_supervised_dataset,
Expand Down
180 changes: 180 additions & 0 deletions torch_em/data/datasets/monusac.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
import os
import shutil
from glob import glob
from tqdm import tqdm
from pathlib import Path
from typing import Optional, List

import imageio.v3 as imageio

import torch_em
from . import util


# Google Drive download links for the train and test archives of the dataset.
URL = {
    "train": "https://drive.google.com/uc?export=download&id=1lxMZaAPSpEHLSxGA9KKMt_r-4S8dwLhq",
    "test": "https://drive.google.com/uc?export=download&id=1G54vsOdxWY1hG7dzmkeK3r0xz9s-heyQ"
}


# SHA256 checksums used to verify the downloaded zip archives.
CHECKSUM = {
    "train": "5b7cbeb34817a8f880d3fddc28391e48d3329a91bf3adcbd131ea149a725cd92",
    "test": "bcbc38f6bf8b149230c90c29f3428cc7b2b76f8acd7766ce9fc908fc896c2674"
}

# here's the description: https://drive.google.com/file/d/1kdOl3s6uQBRv0nToSIf1dPuceZunzL4N/view
# Mapping from organ name to the TCGA patient ids of that organ, per split.
# Used to filter the dataset down to a subset of organ types.
ORGAN_SPLITS = {
    "train": {
        "lung": ["TCGA-55-1594", "TCGA-69-7760", "TCGA-69-A59K", "TCGA-73-4668", "TCGA-78-7220",
                 "TCGA-86-7713", "TCGA-86-8672", "TCGA-L4-A4E5", "TCGA-MP-A4SY", "TCGA-MP-A4T7"],
        "kidney": ["TCGA-5P-A9K0", "TCGA-B9-A44B", "TCGA-B9-A8YI", "TCGA-DW-7841", "TCGA-EV-5903", "TCGA-F9-A97G",
                   "TCGA-G7-A8LD", "TCGA-MH-A560", "TCGA-P4-AAVK", "TCGA-SX-A7SR", "TCGA-UZ-A9PO", "TCGA-UZ-A9PU"],
        "breast": ["TCGA-A2-A0CV", "TCGA-A2-A0ES", "TCGA-B6-A0WZ", "TCGA-BH-A18T", "TCGA-D8-A1X5",
                   "TCGA-E2-A154", "TCGA-E9-A22B", "TCGA-E9-A22G", "TCGA-EW-A6SD", "TCGA-S3-AA11"],
        "prostate": ["TCGA-EJ-5495", "TCGA-EJ-5505", "TCGA-EJ-5517", "TCGA-G9-6342", "TCGA-G9-6499",
                     "TCGA-J4-A67Q", "TCGA-J4-A67T", "TCGA-KK-A59X", "TCGA-KK-A6E0", "TCGA-KK-A7AW",
                     "TCGA-V1-A8WL", "TCGA-V1-A9O9", "TCGA-X4-A8KQ", "TCGA-YL-A9WY"]
    },
    "test": {
        "lung": ["TCGA-49-6743", "TCGA-50-6591", "TCGA-55-7570", "TCGA-55-7573",
                 "TCGA-73-4662", "TCGA-78-7152", "TCGA-MP-A4T7"],
        "kidney": ["TCGA-2Z-A9JG", "TCGA-2Z-A9JN", "TCGA-DW-7838", "TCGA-DW-7963",
                   "TCGA-F9-A8NY", "TCGA-IZ-A6M9", "TCGA-MH-A55W"],
        "breast": ["TCGA-A2-A04X", "TCGA-A2-A0ES", "TCGA-D8-A3Z6", "TCGA-E2-A108", "TCGA-EW-A6SB"],
        "prostate": ["TCGA-G9-6356", "TCGA-G9-6367", "TCGA-VP-A87E", "TCGA-VP-A87H", "TCGA-X4-A8KS", "TCGA-YL-A9WL"]
    },
}


def _download_monusac(path, download, split):
    """Download the MoNuSAC archive for the given split and prepare the inputs.

    This is a no-op when the extracted image and label folders already exist.
    """
    assert split in ["train", "test"], "Please choose from train/test"

    # check if we have extracted the images and labels already
    if all(os.path.exists(os.path.join(path, folder, split)) for folder in ("images", "labels")):
        return

    os.makedirs(path, exist_ok=True)
    zip_path = os.path.join(path, f"monusac_{split}.zip")
    util.download_source_gdrive(zip_path, URL[split], download=download, checksum=CHECKSUM[split])

    # unpack the archive and convert the inputs into the expected layout
    _process_monusac(path, split)
    # make sure all images end up with plain RGB channels
    _check_channel_consistency(path, split)


def _check_channel_consistency(path, split):
    """Ensure all tif images are plain RGB.

    The provided tif images have RGBA channels; here we drop the alpha channel.
    Images converted from svs scans by `_convert_missing_tif_from_svs` are
    already RGB (3 channels) and are left untouched — the original code
    asserted 4 channels for every image, which would fail for those.
    """
    all_image_path = glob(os.path.join(path, "images", split, "*.tif"))
    for image_path in all_image_path:
        image = imageio.imread(image_path)
        if image.ndim == 3 and image.shape[-1] == 3:
            # already RGB (e.g. converted from an svs scan), nothing to do
            continue
        assert image.ndim == 3 and image.shape[-1] == 4, f"Image has an unexpected shape: {image.shape}"
        rgb_image = image[..., :-1]  # get rid of the alpha channel
        imageio.imwrite(image_path, rgb_image)


def _process_monusac(path, split):
    """Unpack the downloaded archive and convert the inputs to the expected layout.

    Moves each patient's tif images to ``<path>/images/<split>`` and rasterizes
    the corresponding xml annotations into label tifs at ``<path>/labels/<split>``.
    The extracted archive folder is removed at the end.
    """
    util.unzip(os.path.join(path, f"monusac_{split}.zip"), path)

    # assorting the images into expected dir;
    # converting the label xml files to numpy arrays (of same dimension as input images) in the expected dir
    root_img_save_dir = os.path.join(path, "images", split)
    root_label_save_dir = os.path.join(path, "labels", split)

    os.makedirs(root_img_save_dir, exist_ok=True)
    os.makedirs(root_label_save_dir, exist_ok=True)

    # one directory per patient inside the extracted archive
    all_patient_dir = sorted(glob(os.path.join(path, "MoNuSAC*", "*")))

    for patient_dir in tqdm(all_patient_dir, desc=f"Converting {split} inputs for all patients"):
        all_img_dir = sorted(glob(os.path.join(patient_dir, "*.tif")))
        all_xml_label_dir = sorted(glob(os.path.join(patient_dir, "*.xml")))

        # some tifs are missing in the archive (observed in the test split);
        # recover them from the original svs scans and re-glob
        if len(all_img_dir) != len(all_xml_label_dir):
            _convert_missing_tif_from_svs(patient_dir)
            all_img_dir = sorted(glob(os.path.join(patient_dir, "*.tif")))

        assert len(all_img_dir) == len(all_xml_label_dir)

        # NOTE(review): pairing relies on sorted tif and xml names lining up
        # one-to-one — assumed to hold for this dataset's naming scheme.
        for img_path, xml_label_path in zip(all_img_dir, all_xml_label_dir):
            # label array must match the spatial shape of the image (channels dropped)
            desired_label_shape = imageio.imread(img_path).shape[:-1]

            img_id = os.path.split(img_path)[-1]
            dst = os.path.join(root_img_save_dir, img_id)
            shutil.move(src=img_path, dst=dst)

            _label = util.generate_labeled_array_from_xml(shape=desired_label_shape, xml_file=xml_label_path)
            _fileid = img_id.split(".")[0]
            imageio.imwrite(os.path.join(root_label_save_dir, f"{_fileid}.tif"), _label)

    # clean up the extracted archive folder
    # NOTE(review): this removes only the first glob match — assumes a single
    # "MoNuSAC*" folder exists per processed split; verify if both splits are
    # ever extracted simultaneously.
    shutil.rmtree(glob(os.path.join(path, "MoNuSAC*"))[0])


def _convert_missing_tif_from_svs(patient_dir):
    """Convert svs scans to tif for inputs whose tif version is missing.

    Cause: Happens only in the test split, maybe while converting the data, some were missed
    Fix: We have the original svs scans. We convert the svs scans to tiff
    """
    for svs_path in sorted(glob(os.path.join(patient_dir, "*.svs"))):
        save_tif_path = os.path.splitext(svs_path)[0] + ".tif"
        if os.path.exists(save_tif_path):
            continue  # the tif version already exists for this scan
        img_array = util.convert_svs_to_array(svs_path)
        # the array from svs scans are supposed to be RGB images
        assert img_array.shape[-1] == 3
        imageio.imwrite(save_tif_path, img_array)


def get_patient_id(path, split_wrt="-01Z-00-"):
    """Gets us the patient id in the expected format

    Input Names: "TCGA-<XX>-<XXXX>-01Z-00-DX<X>-(<X>, <00X>).tif" (example: TCGA-2Z-A9JG-01Z-00-DX1_1.tif)
    Expected: "TCGA-<XX>-<XXXX>" (example: TCGA-2Z-A9JG)
    """
    # drop directories and extension, then keep everything before the marker
    file_stem = Path(path).stem
    return file_stem.split(split_wrt)[0]


def get_monusac_dataset(
    path, patch_shape, split, organ_type: Optional[List[str]] = None, download=False,
    offsets=None, boundaries=False, binary=False, **kwargs
):
    """Dataset from https://monusac-2020.grand-challenge.org/Data/

    Args:
        path: root folder where the data is downloaded / stored
        patch_shape: the patch shape of the returned samples
        split: "train" or "test"
        organ_type: optional list of organ names to restrict the patients
            (see `ORGAN_SPLITS` for the available organs per split)
        download: whether to download the data if it is not present
        offsets, boundaries, binary: options for the instance label transform
    """
    _download_monusac(path, download, split)

    image_paths = sorted(glob(os.path.join(path, "images", split, "*")))
    label_paths = sorted(glob(os.path.join(path, "labels", split, "*")))

    if organ_type is not None:
        # fail early on unknown organ names instead of silently returning an empty selection
        invalid = [o for o in organ_type if o not in ORGAN_SPLITS[split]]
        assert not invalid, f"Invalid organ type(s): {invalid}. Choose from {list(ORGAN_SPLITS[split])}"

        # get all patients for multiple organ selection
        all_organ_splits = sum([ORGAN_SPLITS[split][o] for o in organ_type], [])

        image_paths = [_path for _path in image_paths if get_patient_id(_path) in all_organ_splits]
        label_paths = [_path for _path in label_paths if get_patient_id(_path) in all_organ_splits]

    assert len(image_paths) == len(label_paths)
    assert len(image_paths) > 0, f"No images found for split '{split}' (organ_type={organ_type})"

    kwargs, _ = util.add_instance_label_transform(
        kwargs, add_binary_target=True, binary=binary, boundaries=boundaries, offsets=offsets
    )
    return torch_em.default_segmentation_dataset(
        image_paths, None, label_paths, None, patch_shape, is_seg_dataset=False, **kwargs
    )


def get_monusac_loader(
    path, patch_shape, split, batch_size, organ_type=None, download=False,
    offsets=None, boundaries=False, binary=False, **kwargs
):
    """DataLoader wrapper around `get_monusac_dataset`.

    Splits the given keyword arguments into dataset and loader arguments,
    builds the dataset and wraps it in a torch_em data loader.
    """
    # separate the keyword arguments meant for the dataset from the loader ones
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    ds = get_monusac_dataset(
        path, patch_shape, split,
        organ_type=organ_type,
        download=download,
        offsets=offsets,
        boundaries=boundaries,
        binary=binary,
        **ds_kwargs,
    )
    return torch_em.get_data_loader(ds, batch_size, **loader_kwargs)
29 changes: 29 additions & 0 deletions torch_em/data/datasets/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,32 @@ def generate_labeled_array_from_xml(shape, xml_file):
r, c = polygon(np.array(contour)[:, 1], np.array(contour)[:, 0], shape=shape)
mask[r, c] = i
return mask


def convert_svs_to_array(path, location=(0, 0), level=0, img_size=None):
    """Converts .svs files to numpy array format

    Argument:
        - path: [str] - Path to the svs file
        (below mentioned arguments are used for multi-resolution images)
        - location: tuple[int, int] - pixel location (x, y) in level 0 of the image (default: (0, 0))
        - level: [int] - target level used to read the image (default: 0)
        - img_size: tuple[int, int] - expected size of the image (default: None -> obtains the original shape at the expected level)

    Returns:
        the image as numpy array

    TODO: it can be extended to convert WSIs (or modalities with multiple resolutions)
    """
    assert path.endswith(".svs"), f"The provided file ({path}) isn't in svs format"

    # imported lazily so that tiffslide is only required when svs conversion is used
    from tiffslide import TiffSlide

    _slide = TiffSlide(path)

    if img_size is None:
        # use the dimensions of the requested level (the original code always used
        # level 0 here, which gives the wrong size whenever level != 0; in
        # read_region the size is expected in target-level pixels)
        img_size = _slide.level_dimensions[level]

    img_arr = _slide.read_region(location=location, level=level, size=img_size, as_array=True)

    return img_arr
15 changes: 11 additions & 4 deletions torch_em/data/image_collection_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,18 @@ def _ensure_patch_shape(self, raw, labels, have_raw_channels, have_label_channel
if have_raw_channels and channel_first:
shape = shape[1:]
if any(sh < psh for sh, psh in zip(shape, self.patch_shape)):
if have_raw_channels or have_label_channels:
raise NotImplementedError("Padding is not implemented for data with channels")
assert len(shape) == len(self.patch_shape)
pw = [(0, max(0, psh - sh)) for sh, psh in zip(shape, self.patch_shape)]
raw, labels = np.pad(raw, pw), np.pad(labels, pw)

if have_raw_channels and channel_first:
pw_raw = [(0, 0), *pw]
elif have_raw_channels and not channel_first:
pw_raw = [*pw, (0, 0)]
else:
pw_raw = pw

# TODO: ensure padding for labels with channels, when supported (see `_get_sample` below)

raw, labels = np.pad(raw, pw_raw), np.pad(labels, pw)
return raw, labels

def _get_sample(self, index):
Expand Down
27 changes: 21 additions & 6 deletions torch_em/data/raw_image_collection_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,21 @@ def _sample_bounding_box(self, shape):
]
return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.patch_shape))

def _ensure_patch_shape(self, raw, have_raw_channels):
def _ensure_patch_shape(self, raw, have_raw_channels, channel_first):
shape = raw.shape
if have_raw_channels and channel_first:
shape = shape[1:]
if any(sh < psh for sh, psh in zip(shape, self.patch_shape)):
anwai98 marked this conversation as resolved.
Show resolved Hide resolved
if have_raw_channels:
raise NotImplementedError("Padding is not implemented for data with channels")
assert len(shape) == len(self.patch_shape)
pw = [(0, max(0, psh - sh)) for sh, psh in zip(shape, self.patch_shape)]
raw = np.pad(raw, pw)

if have_raw_channels and channel_first:
pw_raw = [(0, 0), *pw]
elif have_raw_channels and not channel_first:
pw_raw = [*pw, (0, 0)]
else:
pw_raw = pw

raw = np.pad(raw, pw_raw)
return raw

def _get_sample(self, index):
Expand All @@ -94,7 +101,15 @@ def _get_sample(self, index):
raw = load_image(self.raw_images[index])
have_raw_channels = raw.ndim == 3

raw = self._ensure_patch_shape(raw, have_raw_channels)
# We determine if the image has channels as the first or last axis based on the array shape.
# This will work only for images with less than 16 channels!
# If the last axis has a length smaller than 16 we assume that it is the channel axis,
# otherwise we assume it is a spatial axis and that the first axis is the channel axis.
channel_first = None
if have_raw_channels:
channel_first = raw.shape[-1] > 16

raw = self._ensure_patch_shape(raw, have_raw_channels, channel_first)

shape = raw.shape
# we assume images are loaded with channel last!
Expand Down