diff --git a/torch_em/data/datasets/deepbacs.py b/torch_em/data/datasets/deepbacs.py index 32bd582b..4149128e 100644 --- a/torch_em/data/datasets/deepbacs.py +++ b/torch_em/data/datasets/deepbacs.py @@ -1,4 +1,7 @@ import os +import shutil +import numpy as np +from glob import glob import torch_em from . import util @@ -25,18 +28,61 @@ def _require_deebacs_dataset(path, bac_type, download): util.download_source(zip_path, URLS[bac_type], download, checksum=CHECKSUMS[bac_type]) util.unzip(zip_path, os.path.join(path, bac_type)) + # let's get a val split for the expected bacteria type + _assort_val_set(path, bac_type) + + +def _assort_val_set(path, bac_type): + image_paths = glob(os.path.join(path, bac_type, "training", "source", "*")) + image_paths = [os.path.split(_path)[-1] for _path in image_paths] + + val_partition = 0.2 + # let's get a balanced set of bacterias, if bac_type is mixed + if bac_type == "mixed": + _sort_1, _sort_2, _sort_3 = [], [], [] + for _path in image_paths: + if _path.startswith("JE2"): + _sort_1.append(_path) + elif _path.startswith("pos"): + _sort_2.append(_path) + elif _path.startswith("train_"): + _sort_3.append(_path) + + val_image_paths = [ + *np.random.choice(_sort_1, size=int(val_partition * len(_sort_1)), replace=False), + *np.random.choice(_sort_2, size=int(val_partition * len(_sort_2)), replace=False), + *np.random.choice(_sort_3, size=int(val_partition * len(_sort_3)), replace=False) + ] + else: + val_image_paths = np.random.choice(image_paths, size=int(val_partition * len(image_paths)), replace=False) + + val_image_dir = os.path.join(path, bac_type, "val", "source") + val_label_dir = os.path.join(path, bac_type, "val", "target") + os.makedirs(val_image_dir, exist_ok=True) + os.makedirs(val_label_dir, exist_ok=True) + + for sample_id in val_image_paths: + src_val_image_path = os.path.join(path, bac_type, "training", "source", sample_id) + dst_val_image_path = os.path.join(val_image_dir, sample_id) + shutil.move(src_val_image_path, dst_val_image_path) + + src_val_label_path = os.path.join(path, bac_type, "training", "target", sample_id) + dst_val_label_path = os.path.join(val_label_dir, sample_id) + shutil.move(src_val_label_path, dst_val_label_path) + def _get_paths(path, bac_type, split): # the bacteria types other than mixed are a bit more complicated so we don't have the dataloaders for them yet # mixed is the combination of all other types + if split == "train": + dir_choice = "training" + else: + dir_choice = split + if bac_type != "mixed": raise NotImplementedError(f"Currently only the bacteria type 'mixed' is supported, not {bac_type}") - image_folder = os.path.join( - path, bac_type, "training" if split == "train" else "test", "source" - ) - label_folder = os.path.join( - path, bac_type, "training" if split == "train" else "test", "target" - ) + image_folder = os.path.join(path, bac_type, dir_choice, "source") + label_folder = os.path.join(path, bac_type, dir_choice, "target") return image_folder, label_folder @@ -48,7 +94,7 @@ def get_deepbacs_dataset( This dataset is from the publication https://doi.org/10.1038/s42003-022-03634-z. Please cite it if you use this dataset for a publication. """ - assert split in ("train", "test") + assert split in ("train", "val", "test") bac_types = list(URLS.keys()) assert bac_type in bac_types, f"{bac_type} is not in expected bacteria types: {bac_types}"