Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates to MoNuSeg #157

Merged
merged 9 commits into from
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions scripts/datasets/check_monuseg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from torch_em.util.debug import check_loader
from torch_em.data.datasets import get_monuseg_loader


MONUSEG_ROOT = "/scratch/usr/nimanwai/data/monuseg"


def check_monuseg():
train_loader = get_monuseg_loader(
path=MONUSEG_ROOT,
patch_shape=(512, 512),
batch_size=2,
split="train",
download=True,
organ_type=["colon", "breast"]
)
check_loader(train_loader, 8, instance_labels=True, rgb=True, plt=True, save_path="./monuseg_train.png")

test_loader = get_monuseg_loader(
path=MONUSEG_ROOT,
patch_shape=(512, 512),
batch_size=1,
split="test",
download=True
)
check_loader(test_loader, 8, instance_labels=True, rgb=True, plt=True, save_path="./monuseg_test.png")


if __name__ == "__main__":
check_monuseg()
3 changes: 1 addition & 2 deletions torch_em/data/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from .lizard import get_lizard_loader, get_lizard_dataset
from .lucchi import get_lucchi_loader, get_lucchi_dataset
from .mitoem import get_mitoem_loader, get_mitoem_dataset
# monuseg is only partially implemented
# from .monuseg import get_monuseg_loader, get_monuseg_dataset
from .monuseg import get_monuseg_loader, get_monuseg_dataset
from .mouse_embryo import get_mouse_embryo_loader, get_mouse_embryo_dataset
from .neurips_cell_seg import (
get_neurips_cellseg_supervised_loader, get_neurips_cellseg_supervised_dataset,
Expand Down
122 changes: 100 additions & 22 deletions torch_em/data/datasets/monuseg.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,135 @@
import os
import shutil
from tqdm import tqdm
from glob import glob
from pathlib import Path
from typing import List, Optional

import imageio.v3 as imageio

import torch_em
from torch_em.data.datasets import util


URL = {
"train": "https://drive.google.com/uc?export=download&id=1ZgqFJomqQGNnsx7w7QBzQQMVA16lbVCA",
"test": "https://drive.google.com/uc?export=download&id=1NKkSQ5T0ZNQ8aUhh0a8Dt2YKYCQXIViw"
}

CHECKSUM = {
"train": "25d3d3185bb2970b397cafa72eb664c9b4d24294aee382e7e3df9885affce742",
"test": "13e522387ae8b1bcc0530e13ff9c7b4d91ec74959ef6f6e57747368d7ee6f88a"
}

from . import util
# here's the description: https://drive.google.com/file/d/1xYyQ31CHFRnvTCTuuHdconlJCMk2SK7Z/view?usp=sharing
ORGAN_SPLITS = {
"breast": ["TCGA-A7-A13E-01Z-00-DX1", "TCGA-A7-A13F-01Z-00-DX1", "TCGA-AR-A1AK-01Z-00-DX1",
"TCGA-AR-A1AS-01Z-00-DX1", "TCGA-E2-A1B5-01Z-00-DX1", "TCGA-E2-A14V-01Z-00-DX1"],
"kidney": ["TCGA-B0-5711-01Z-00-DX1", "TCGA-HE-7128-01Z-00-DX1", "TCGA-HE-7129-01Z-00-DX1",
"TCGA-HE-7130-01Z-00-DX1", "TCGA-B0-5710-01Z-00-DX1", "TCGA-B0-5698-01Z-00-DX1"],
"liver": ["TCGA-18-5592-01Z-00-DX1", "TCGA-38-6178-01Z-00-DX1", "TCGA-49-4488-01Z-00-DX1",
"TCGA-50-5931-01Z-00-DX1", "TCGA-21-5784-01Z-00-DX1", "TCGA-21-5786-01Z-00-DX1"],
"prostate": ["TCGA-G9-6336-01Z-00-DX1", "TCGA-G9-6348-01Z-00-DX1", "TCGA-G9-6356-01Z-00-DX1",
"TCGA-G9-6363-01Z-00-DX1", "TCGA-CH-5767-01Z-00-DX1", "TCGA-G9-6362-01Z-00-DX1"],
"bladder": ["TCGA-DK-A2I6-01A-01-TS1", "TCGA-G2-A2EK-01A-02-TSB"],
"colon": ["TCGA-AY-A8YK-01A-01-TS1", "TCGA-NH-A8F7-01A-01-TS1"],
"stomach": ["TCGA-KB-A93J-01A-01-TS1", "TCGA-RD-A8N9-01A-01-TS1"]
}

URL = "https://drive.google.com/uc?export=download&id=1ZgqFJomqQGNnsx7w7QBzQQMVA16lbVCA"
CHECKSUM = ""

def _download_monuseg(path, download, split):
assert split in ["train", "test"], "The split choices in MoNuSeg datset are train/test, please choose from them"

# TODO separate via organ
def _download_monuseg(path, download):
# check if we have extracted the images and labels already
im_path = os.path.join(path, "images")
label_path = os.path.join(path, "labels")
im_path = os.path.join(path, "images", split)
label_path = os.path.join(path, "labels", split)
if os.path.exists(im_path) and os.path.exists(label_path):
return

raise NotImplementedError("Download and post-processing for the monuseg data is not yet implemented.")

os.makedirs(path, exist_ok=True)
zip_path = os.path.join(path, "monuseg.zip")
util.download_source_gdrive(zip_path, URL, download=download, checksum=CHECKSUM)
zip_path = os.path.join(path, f"monuseg_{split}.zip")
util.download_source_gdrive(zip_path, URL[split], download=download, checksum=CHECKSUM[split])

_process_monuseg(path, split)


def _process_monuseg(path, split):
util.unzip(os.path.join(path, f"monuseg_{split}.zip"), path)

# assorting the images into expected dir;
# converting the label xml files to numpy arrays (of same dimension as input images) in the expected dir
root_img_save_dir = os.path.join(path, "images", split)
root_label_save_dir = os.path.join(path, "labels", split)

# TODO
def _process_monuseg():
pass
os.makedirs(root_img_save_dir, exist_ok=True)
os.makedirs(root_label_save_dir, exist_ok=True)

if split == "train":
all_img_dir = sorted(glob(os.path.join(path, "*", "Tissue*", "*")))
all_xml_label_dir = sorted(glob(os.path.join(path, "*", "Annotations", "*")))
else:
all_img_dir = sorted(glob(os.path.join(path, "MoNuSegTestData", "*.tif")))
all_xml_label_dir = sorted(glob(os.path.join(path, "MoNuSegTestData", "*.xml")))

assert len(all_img_dir) == len(all_xml_label_dir)

for img_path, xml_label_path in tqdm(zip(all_img_dir, all_xml_label_dir),
desc=f"Converting {split} split to the expected format",
total=len(all_img_dir)):
desired_label_shape = imageio.imread(img_path).shape[:-1]

img_id = os.path.split(img_path)[-1]
dst = os.path.join(root_img_save_dir, img_id)
shutil.move(src=img_path, dst=dst)

_label = util.generate_labeled_array_from_xml(shape=desired_label_shape, xml_file=xml_label_path)
_fileid = img_id.split(".")[0]
imageio.imwrite(os.path.join(root_label_save_dir, f"{_fileid}.tif"), _label)

shutil.rmtree(glob(os.path.join(path, "MoNuSeg*"))[0])
if split == "train":
shutil.rmtree(glob(os.path.join(path, "__MACOSX"))[0])


def get_monuseg_dataset(
path, patch_shape, download=False, offsets=None, boundaries=False, binary=False, **kwargs
path, patch_shape, split, organ_type: Optional[List[str]] = None, download=False,
offsets=None, boundaries=False, binary=False, **kwargs
):
_download_monuseg(path, download)
"""Dataset from https://monuseg.grand-challenge.org/Data/
"""
_download_monuseg(path, download, split)

image_paths = sorted(glob(os.path.join(path, "images", split, "*")))
label_paths = sorted(glob(os.path.join(path, "labels", split, "*")))

if split == "train" and organ_type is not None:
anwai98 marked this conversation as resolved.
Show resolved Hide resolved
# get all patients for multiple organ selection
all_organ_splits = sum([ORGAN_SPLITS[_o] for _o in organ_type], [])

image_paths = [_path for _path in image_paths if Path(_path).stem in all_organ_splits]
label_paths = [_path for _path in label_paths if Path(_path).stem in all_organ_splits]

image_path = os.path.join(path, "images")
label_path = os.path.join(path, "labels")
elif split == "test" and organ_type is not None:
# we don't have organ splits in the test dataset
raise ValueError("The test split does not have any organ informations, please pass `organ_type=None`")

kwargs, _ = util.add_instance_label_transform(
kwargs, add_binary_target=True, binary=binary, boundaries=boundaries, offsets=offsets
)
return torch_em.default_segmentation_dataset(
image_path, "*.tif", label_path, "*.tif", patch_shape, is_seg_dataset=False, **kwargs
image_paths, None, label_paths, None, patch_shape, is_seg_dataset=False, **kwargs
)


# TODO implement selecting organ
def get_monuseg_loader(
path, patch_shape, batch_size, download=False, offsets=None, boundaries=False, binary=False, **kwargs
path, patch_shape, batch_size, split, organ_type=None, download=False, offsets=None, boundaries=False, binary=False,
**kwargs
):
ds_kwargs, loader_kwargs = util.split_kwargs(
torch_em.default_segmentation_dataset, **kwargs
)
dataset = get_monuseg_dataset(
path, patch_shape, download=download,
path, patch_shape, split, organ_type=organ_type, download=download,
offsets=offsets, boundaries=boundaries, binary=binary, **ds_kwargs
)
loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
Expand Down
50 changes: 48 additions & 2 deletions torch_em/data/datasets/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@
import os
import hashlib
import zipfile
from shutil import copyfileobj
import numpy as np
from tqdm import tqdm
from warnings import warn
from xml.dom import minidom
from shutil import copyfileobj

from skimage.draw import polygon

import torch
import torch_em
Expand All @@ -14,7 +19,6 @@
except ImportError:
gdown = None

from tqdm import tqdm

BIOIMAGEIO_IDS = {
"covid_if": "ilastik/covid_if_training_data",
Expand Down Expand Up @@ -158,3 +162,45 @@ def add_instance_label_transform(
kwargs = update_kwargs(kwargs, "label_transform", label_transform, msg=msg)
label_dtype = torch.float32
return kwargs, label_dtype


def generate_labeled_array_from_xml(shape, xml_file):
"""Function taken from: https://github.com/rshwndsz/hover-net/blob/master/lightning_hovernet.ipynb

Given image shape and path to annotations (xml file), generatebit mask with the region inside a contour being white
shape: The image shape on which bit mask will be made
xml_file: path relative to the current working directory where the xml file is present

Returns:
An image of given shape with region inside contour being white..
"""
# DOM object created by the minidom parser
xDoc = minidom.parse(xml_file)

# List of all Region tags
regions = xDoc.getElementsByTagName('Region')

# List which will store the vertices for each region
xy = []
for region in regions:
# Loading all the vertices in the region
vertices = region.getElementsByTagName('Vertex')

# The vertices of a region will be stored in a array
vw = np.zeros((len(vertices), 2))

for index, vertex in enumerate(vertices):
# Storing the values of x and y coordinate after conversion
vw[index][0] = float(vertex.getAttribute('X'))
vw[index][1] = float(vertex.getAttribute('Y'))

# Append the vertices of a region
xy.append(np.int32(vw))

# Creating a completely black image
mask = np.zeros(shape, np.float32)

for i, contour in enumerate(xy):
r, c = polygon(np.array(contour)[:, 1], np.array(contour)[:, 0], shape=shape)
mask[r, c] = i
return mask
Loading