Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add splits for BCSS dataset #164

Merged
merged 3 commits into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion scripts/datasets/check_bcss.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def check_bcss():
patch_shape=(512, 512),
batch_size=2,
download=False,
label_transform=BCSSLabelTrafo(label_choices=[0, 1, 2])
label_transform=BCSSLabelTrafo(label_choices=[0, 1, 2]),
split="train"
)
check_loader(chosen_label_loader, 8, instance_labels=True, rgb=True, plt=True, save_path="./bcss.png")

Expand Down
98 changes: 81 additions & 17 deletions torch_em/data/datasets/bcss.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import os
import shutil
from glob import glob
from pathlib import Path

import torch
from sklearn.model_selection import train_test_split

import torch
import torch_em
from torch_em.data.datasets import util
from torch_em.data import ImageCollectionDataset
Expand All @@ -15,6 +18,18 @@
CHECKSUM = None


# File stems (file names without extension) of the images that are assigned to the test split.
TEST_LIST = [
    "TCGA-A2-A0SX-DX1_xmin53791_ymin56683_MPP-0.2500",
    "TCGA-BH-A0BG-DX1_xmin64019_ymin24975_MPP-0.2500",
    "TCGA-AR-A1AI-DX1_xmin38671_ymin10616_MPP-0.2500",
    "TCGA-E2-A574-DX1_xmin54962_ymin47475_MPP-0.2500",
    "TCGA-GM-A3XL-DX1_xmin29910_ymin15820_MPP-0.2500",
    "TCGA-E2-A14X-DX1_xmin88836_ymin66393_MPP-0.2500",
    "TCGA-A2-A04P-DX1_xmin104246_ymin48517_MPP-0.2500",
    "TCGA-E2-A14N-DX1_xmin21383_ymin66838_MPP-0.2500",
    "TCGA-EW-A1OV-DX1_xmin126026_ymin65132_MPP-0.2500",
    "TCGA-S3-AA15-DX1_xmin55486_ymin28926_MPP-0.2500",
    "TCGA-LL-A5YO-DX1_xmin36631_ymin44396_MPP-0.2500",
    "TCGA-GI-A2C9-DX1_xmin20882_ymin11843_MPP-0.2500",
    "TCGA-BH-A0BW-DX1_xmin42346_ymin30843_MPP-0.2500",
    "TCGA-E2-A1B6-DX1_xmin16266_ymin50634_MPP-0.2500",
    "TCGA-AO-A0J2-DX1_xmin33561_ymin14515_MPP-0.2500",
]


def _download_bcss_dataset(path, download):
"""Current recommendation:
- download the folder from URL manually
Expand All @@ -28,7 +43,52 @@ def _download_bcss_dataset(path, download):
util.download_source_gdrive(path=path, url=URL, download=download, checksum=CHECKSUM, download_type="folder")


def get_bcss_dataset(path, patch_shape, download=False, label_dtype=torch.int64, **kwargs):
def _get_image_and_label_paths(path):
    """Return the sorted image paths and sorted mask paths found below `path`.

    Two folder layouts are recognised, checked in this order:
    - files obtained via the git repo's command line feature: images under `images`
    - files downloaded from `URL`: images under
      `0_Public-data-Amgad2019_0.25MPP/rgbs_colorNormalized`
    In both layouts the masks are read from the sibling `masks` folder.

    Raises:
        ValueError: If neither of the two image folders exists.
    """
    gdrive_root = os.path.join(path, "0_Public-data-Amgad2019_0.25MPP")
    # (image folder, mask folder) candidates, in order of priority
    layouts = (
        (os.path.join(path, "images"), os.path.join(path, "masks")),
        (os.path.join(gdrive_root, "rgbs_colorNormalized"), os.path.join(gdrive_root, "masks")),
    )

    for image_dir, label_dir in layouts:
        if not os.path.exists(image_dir):
            continue

        image_paths = glob(os.path.join(image_dir, "*"))
        image_paths.sort()
        label_paths = glob(os.path.join(label_dir, "*"))
        label_paths.sort()
        return image_paths, label_paths

    raise ValueError("Please check the image directory. If downloaded from gdrive, it's named \"rgbs_colorNormalized\", if from github it's named \"images\"")


def _assort_bcss_data(path, download):
    """Sort the BCSS images and masks into `train` and `test` subfolders of `path`.

    Samples whose image file stem is listed in `TEST_LIST` are copied to
    `<path>/test/images` and `<path>/test/masks`; all other samples are copied to
    `<path>/train/images` and `<path>/train/masks`. The source files are copied,
    not moved, so the original folders stay untouched.

    Nothing is done if both `<path>/train` and `<path>/test` already exist.

    Args:
        path: Root folder of the dataset.
        download: Whether to call `_download_bcss_dataset` before sorting the data.

    Raises:
        ValueError: If the number of images and masks differs. Images and masks are
            paired by their position in the sorted file lists, so a mismatch would
            silently misalign the pairs.
    """
    if download:
        _download_bcss_dataset(path, download)

    # NOTE: only the existence of the split folders is checked here;
    # a previously interrupted copy is not detected.
    if os.path.exists(os.path.join(path, "train")) and os.path.exists(os.path.join(path, "test")):
        return

    all_image_paths, all_label_paths = _get_image_and_label_paths(path)

    # Validate before creating any folder, so that a failure leaves no partial split behind.
    if len(all_image_paths) != len(all_label_paths):
        raise ValueError(
            f"Found {len(all_image_paths)} images but {len(all_label_paths)} masks in '{path}'; "
            "expected exactly one mask per image."
        )

    image_dirs, label_dirs = {}, {}
    for split_name in ("train", "test"):
        image_dirs[split_name] = os.path.join(path, split_name, "images")
        label_dirs[split_name] = os.path.join(path, split_name, "masks")
        os.makedirs(image_dirs[split_name], exist_ok=True)
        os.makedirs(label_dirs[split_name], exist_ok=True)

    # Build the lookup once instead of scanning the list for every image.
    test_stems = set(TEST_LIST)

    for image_path, label_path in zip(all_image_paths, all_label_paths):
        split_name = "test" if Path(image_path).stem in test_stems else "train"
        # copy (not move) the image and its mask into the folders of the chosen split
        shutil.copy(src=image_path, dst=os.path.join(image_dirs[split_name], os.path.basename(image_path)))
        shutil.copy(src=label_path, dst=os.path.join(label_dirs[split_name], os.path.basename(label_path)))


def get_bcss_dataset(path, patch_shape, split, val_fraction=0.2, download=False, label_dtype=torch.int64, **kwargs):
"""Dataset for breast cancer tissue segmentation in histopathology.

This dataset is from https://bcsegmentation.grand-challenge.org/BCSS/.
Expand Down Expand Up @@ -58,21 +118,25 @@ def get_bcss_dataset(path, patch_shape, download=False, label_dtype=torch.int64,
- 20: dcis
- 21: other
"""
if download:
_download_bcss_dataset(path, download)
assert split in ["train", "val", "test"], "Please choose from the available `train` / `val` / `test` splits"

# when downloading the files from `URL`, the input images are stored under `rgbs_colorNormalized`
# when getting the files from the git repo's command line feature, the input images are stored under `images`
if os.path.exists(os.path.join(path, "images")):
image_paths = sorted(glob(os.path.join(path, "images", "*")))
label_paths = sorted(glob(os.path.join(path, "masks", "*")))
elif os.path.exists(os.path.join(path, "0_Public-data-Amgad2019_0.25MPP", "rgbs_colorNormalized")):
image_paths = sorted(glob(os.path.join(path, "0_Public-data-Amgad2019_0.25MPP", "rgbs_colorNormalized", "*")))
label_paths = sorted(glob(os.path.join(path, "0_Public-data-Amgad2019_0.25MPP", "masks", "*")))
_assort_bcss_data(path, download)

if split == "test":
image_paths = sorted(glob(os.path.join(path, "test", "images", "*")))
label_paths = sorted(glob(os.path.join(path, "test", "masks", "*")))
else:
raise ValueError(
"Please check the image directory. If downloaded from gdrive, it's named \"rgbs_colorNormalized\", if from github it's named \"images\""
)
image_paths = sorted(glob(os.path.join(path, "train", "images", "*")))
label_paths = sorted(glob(os.path.join(path, "train", "masks", "*")))

(train_image_paths, val_image_paths,
train_label_paths, val_label_paths) = train_test_split(
image_paths, label_paths, test_size=val_fraction, random_state=42
)

image_paths = train_image_paths if split == "train" else val_image_paths
label_paths = train_label_paths if split == "train" else val_label_paths

assert len(image_paths) == len(label_paths)

dataset = ImageCollectionDataset(
Expand All @@ -82,12 +146,12 @@ def get_bcss_dataset(path, patch_shape, download=False, label_dtype=torch.int64,


def get_bcss_loader(
path, patch_shape, batch_size, download=False, label_dtype=torch.int64, **kwargs
path, patch_shape, batch_size, split, val_fraction=0.2, download=False, label_dtype=torch.int64, **kwargs
anwai98 marked this conversation as resolved.
Show resolved Hide resolved
):
"""Dataloader for breast cancer tissue segmentation in histopathology. See `get_bcss_dataset` for details."""
ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
dataset = get_bcss_dataset(
path, patch_shape, download=download, label_dtype=label_dtype, **ds_kwargs
path, patch_shape, split, val_fraction, download=download, label_dtype=label_dtype, **ds_kwargs
)
loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
return loader