From fc2aa6dd51befd0b1f59e54f0451844b25a1a477 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Fri, 29 Dec 2023 15:30:02 +0100 Subject: [PATCH] Update CEM dataloaders --- scripts/datasets/check_cem.py | 103 +++++++++++++ torch_em/data/datasets/__init__.py | 2 +- torch_em/data/datasets/cem.py | 232 ++++++++++++++++++++++++----- torch_em/data/datasets/util.py | 46 +++++- torch_em/util/modelzoo.py | 10 +- torch_em/util/util.py | 4 + 6 files changed, 351 insertions(+), 46 deletions(-) create mode 100644 scripts/datasets/check_cem.py diff --git a/scripts/datasets/check_cem.py b/scripts/datasets/check_cem.py new file mode 100644 index 00000000..7f25a505 --- /dev/null +++ b/scripts/datasets/check_cem.py @@ -0,0 +1,103 @@ +import os +import imageio.v3 as imageio +from glob import glob + +import numpy as np +import torch_em +from torch_em.data.datasets import cem +from torch_em.util.debug import check_loader + + +def get_all_shapes(): + # Get the shape for the 3d datasets (id: 1-6) + data_root = "./data/10982/data/mito_benchmarks" + i = 1 + for root, dirs, files in os.walk(data_root): + dirs.sort() + for ff in files: + if ff.endswith("em.tif"): + shape = imageio.imread(os.path.join(root, ff)).shape + print(i, ":", ff, ":", shape) + i += 1 + + # Get the shape for the 2d dataset (id: 7) + data_root = "./data/10982/data/tem_benchmark/images" + + shapes_2d = [] + for image in glob(os.path.join(data_root, "*.tiff")): + shape = imageio.imread(image).shape + shapes_2d.append(shape) + print(i, ":", set(shapes_2d)) + + +def check_benchmark_loaders(): + for dataset_id in range(1, 8): + print("Check benchmark dataset", dataset_id) + full_shape = cem.BENCHMARK_SHAPES[dataset_id] + if dataset_id == 7: + patch_shape = full_shape + else: + patch_shape = (1,) + full_shape[1:] + loader = cem.get_benchmark_loader( + "./data", dataset_id=dataset_id, batch_size=1, patch_shape=patch_shape, ndim=2 + ) + check_loader(loader, 4, instance_labels=True) + + +def check_mitolab_loader(): + 
val_fraction = 0.1 + train_loader = cem.get_mitolab_loader( + "./data", split="train", batch_size=1, shuffle=True, + sampler=torch_em.data.sampler.MinInstanceSampler(), + val_fraction=val_fraction, + ) + print("Checking train loader ...") + check_loader(train_loader, 8, instance_labels=True) + print("... done") + + val_loader = cem.get_mitolab_loader( + "./data", split="val", batch_size=1, shuffle=True, + sampler=torch_em.data.sampler.MinInstanceSampler(), + val_fraction=val_fraction, + ) + print("Checking val loader ...") + check_loader(val_loader, 8, instance_labels=True) + print("... done") + + +def analyse_mitolab(): + data_root = "data/11037/cem_mitolab" + folders = glob(os.path.join(data_root, "*")) + + n_datasets = len(folders) + + n_images = 0 + n_images_with_labels = 0 + + for folder in folders: + assert os.path.isdir(folder) + images = sorted(glob(os.path.join(folder, "images", "*.tiff"))) + labels = sorted(glob(os.path.join(folder, "masks", "*.tiff"))) + + n_images += len(images) + n_labels = [len(np.unique(imageio.imread(lab))) for lab in labels] + n_images_with_labels += sum([n_lab > 1 for n_lab in n_labels]) + + # print(folder) + # this_shapes = [imageio.imread(im).shape for im in images] + # print(set(this_shapes)) + + print(n_datasets) + print(n_images) + print(n_images_with_labels) + + +def main(): + # get_all_shapes() + # check_benchmark_loaders() + check_mitolab_loader() + # analyse_mitolab() + + +if __name__ == "__main__": + main() diff --git a/torch_em/data/datasets/__init__.py b/torch_em/data/datasets/__init__.py index c698a701..166e19fe 100644 --- a/torch_em/data/datasets/__init__.py +++ b/torch_em/data/datasets/__init__.py @@ -1,6 +1,6 @@ from .axondeepseg import get_axondeepseg_loader, get_axondeepseg_dataset from .bcss import get_bcss_loader, get_bcss_dataset -from .cem import get_cem_mitolab_loader +from .cem import get_mitolab_loader from .covid_if import get_covid_if_loader, get_covid_if_dataset from .cremi import get_cremi_loader, 
get_cremi_dataset from .deepbacs import get_deepbacs_loader, get_deepbacs_dataset diff --git a/torch_em/data/datasets/cem.py b/torch_em/data/datasets/cem.py index c226fcc6..7de3ed49 100644 --- a/torch_em/data/datasets/cem.py +++ b/torch_em/data/datasets/cem.py @@ -1,44 +1,128 @@ -# Data loaders for the CEM datasets: -# - CEM-MitoLab: annotated data for training mitochondria segmentation models -# - https://www.ebi.ac.uk/empiar/EMPIAR-11037/ -# - CEM-1.5M: unlabeled EM images for pretraining: -# - https://www.ebi.ac.uk/empiar/EMPIAR-11035/ -# - CEM-Mito-Benchmark: 7 Benchmark datasets for mitochondria segmentation -# - https://www.ebi.ac.uk/empiar/EMPIAR-10982/ - +"""Contains datasets and dataloader for the CEM data: +- CEM-MitoLab: annotated 2d data for training mitochondria segmentation models + - https://www.ebi.ac.uk/empiar/EMPIAR-11037/ +- CEM-Mito-Benchmark: 7 Benchmark datasets for mitochondria segmentation + - https://www.ebi.ac.uk/empiar/EMPIAR-10982/ +- CEM-1.5M: unlabeled EM images for pretraining: (Not yet implemented) + - https://www.ebi.ac.uk/empiar/EMPIAR-11035/ + +The data itself can be downloaded from EMPIAR via aspera. +- You can install aspera via mamba. I recommend to do this in a separate environment + to avoid dependency issues: + - `$ mamba create -c conda-forge -c hcc -n aspera aspera-cli` +- After this you can run `$ mamba activate aspera` to have an environment with aspera installed. +- You can then download the data for one of the three datasets like this: + - ascp -QT -l 200m -P33001 -i <prefix>/etc/asperaweb_id_dsa.openssh emp_ext2@fasp.ebi.ac.uk:/<empiar_id> <path> + - Where <prefix> is the path to the mamba environment, <empiar_id> the id of one of the three datasets + and <path> where you want to download the data. +- After this you can use the functions in this file if you use <path>/<empiar_id> as location for the data.
+ +Note that I have implemented automatic download, but this leads to issues with +mamba for me so I recommend to download the data manually and then run the loaders +with the correct path. +""" + +import json import os from glob import glob +import imageio.v3 as imageio +import numpy as np import torch_em from sklearn.model_selection import train_test_split +from . import util -# TODO -def _download_cem_mitolab(path): - # os.makedirs(path, exist_ok=True) - raise NotImplementedError("Data download is not implemented yet for CEM data.") +BENCHMARK_DATASETS = { + 1: "mito_benchmarks/c_elegans", + 2: "mito_benchmarks/fly_brain", + 3: "mito_benchmarks/glycolytic_muscle", + 4: "mito_benchmarks/hela_cell", + 5: "mito_benchmarks/lucchi_pp", + 6: "mito_benchmarks/salivary_gland", + 7: "tem_benchmark", +} +BENCHMARK_SHAPES = { + 1: (256, 256, 256), + 2: (256, 255, 255), + 3: (302, 383, 765), + 4: (256, 256, 256), + 5: (165, 768, 1024), + 6: (1260, 1081, 1200), + 7: (224, 224), # NOTE: this is the minimal square shape that fits +} -def _get_cem_mitolab_paths(path, split, val_fraction, download): - folders = glob(os.path.join(path, "*")) - assert all(os.path.isdir(folder) for folder in folders) +def _get_mitolab_data(path, download): + access_id = "11037" + data_path = util.download_source_empiar(path, access_id, download) - if len(folders) == 0 and download: - _download_cem_mitolab(path) - elif len(folders) == 0: - raise RuntimeError(f"The CEM Mitolab data is not available at {path}, but download was set to False.") + zip_path = os.path.join(data_path, "data/cem_mitolab.zip") + if os.path.exists(zip_path): + util.unzip(zip_path, data_path, remove=True) - raw_paths, label_paths = [], [] + data_root = os.path.join(data_path, "cem_mitolab") + assert os.path.exists(data_root) + return data_root + + +def _get_all_images(path): + raw_paths, label_paths = [], [] + folders = glob(os.path.join(path, "*")) + assert all(os.path.isdir(folder) for folder in folders) for folder in 
folders: - images = glob(os.path.join(folder, "images", "*.tiff")) - images.sort() + images = sorted(glob(os.path.join(folder, "images", "*.tiff"))) assert len(images) > 0 - labels = glob(os.path.join(folder, "masks", "*.tiff")) - labels.sort() + labels = sorted(glob(os.path.join(folder, "masks", "*.tiff"))) assert len(images) == len(labels) raw_paths.extend(images) label_paths.extend(labels) + return raw_paths, label_paths + + +def _get_non_empty_images(path): + save_path = os.path.join(path, "non_empty_images.json") + + if os.path.exists(save_path): + with open(save_path, "r") as f: + saved_images = json.load(f) + raw_paths, label_paths = saved_images["images"], saved_images["labels"] + raw_paths = [os.path.join(path, rp) for rp in raw_paths] + label_paths = [os.path.join(path, lp) for lp in label_paths] + return raw_paths, label_paths + + folders = glob(os.path.join(path, "*")) + assert all(os.path.isdir(folder) for folder in folders) + + raw_paths, label_paths = [], [] + for folder in folders: + labels = sorted(glob(os.path.join(folder, "masks", "*.tiff"))) + images = sorted(glob(os.path.join(folder, "images", "*.tiff"))) + assert len(images) > 0 + assert len(images) == len(labels) + + for im, lab in zip(images, labels): + n_labels = len(np.unique(imageio.imread(lab))) + if n_labels > 1: + raw_paths.append(im) + label_paths.append(lab) + + raw_paths_rel = [os.path.relpath(rp, path) for rp in raw_paths] + label_paths_rel = [os.path.relpath(lp, path) for lp in label_paths] + + with open(save_path, "w") as f: + json.dump({"images": raw_paths_rel, "labels": label_paths_rel}, f) + + return raw_paths, label_paths + + +def _get_mitolab_paths(path, split, val_fraction, download, discard_empty_images): + data_path = _get_mitolab_data(path, download) + if discard_empty_images: + raw_paths, label_paths = _get_non_empty_images(data_path) + else: + raw_paths, label_paths = _get_all_images(data_path) if split is not None: raw_train, raw_val, labels_train, labels_val = 
train_test_split( @@ -54,24 +138,100 @@ def _get_cem_mitolab_paths(path, split, val_fraction, download): return raw_paths, label_paths -def get_cem_mitolab_loader( - path, split, batch_size, patch_shape=(224, 224), val_fraction=0.05, download=False, **kwargs +def _get_benchmark_data(path, dataset_id, download): + access_id = "10982" + data_path = util.download_source_empiar(path, access_id, download) + dataset_path = os.path.join(data_path, "data", BENCHMARK_DATASETS[dataset_id]) + + # these are the 3d datasets + if dataset_id in range(1, 7): + dataset_name = os.path.basename(dataset_path) + raw_paths = os.path.join(dataset_path, f"{dataset_name}_em.tif") + label_paths = os.path.join(dataset_path, f"{dataset_name}_mito.tif") + raw_key, label_key = None, None + is_seg_dataset = True + + # this is the 2d dataset + else: + raw_paths = os.path.join(dataset_path, "images") + label_paths = os.path.join(dataset_path, "masks") + raw_key, label_key = "*.tiff", "*.tiff" + is_seg_dataset = False + + return raw_paths, label_paths, raw_key, label_key, is_seg_dataset + + +# +# data sets +# + + +def get_mitolab_dataset( + path, split, patch_shape=(224, 224), val_fraction=0.05, download=False, + discard_empty_images=True, **kwargs ): assert split in ("train", "val", None) assert os.path.exists(path) - raw_paths, label_paths = _get_cem_mitolab_paths(path, split, val_fraction, download) - return torch_em.default_segmentation_loader( - raw_paths=raw_paths, raw_key=None, label_paths=label_paths, label_key=None, - batch_size=batch_size, patch_shape=patch_shape, - is_seg_dataset=False, ndim=2, **kwargs + raw_paths, label_paths = _get_mitolab_paths(path, split, val_fraction, download, discard_empty_images) + return torch_em.default_segmentation_dataset( + raw_paths=raw_paths, raw_key=None, + label_paths=label_paths, label_key=None, + patch_shape=patch_shape, is_seg_dataset=False, ndim=2, **kwargs ) -# TODO +def get_cem15m_dataset(path): + raise NotImplementedError + + +def 
get_benchmark_dataset( + path, dataset_id, patch_shape, download=False, **kwargs, +): + """ + ascp -QT -l 200m -P33001 -i /etc/asperaweb_id_dsa.openssh emp_ext2@fasp.ebi.ac.uk:/10982 + """ + if dataset_id not in range(1, 8): + raise ValueError + raw_paths, label_paths, raw_key, label_key, is_seg_dataset = _get_benchmark_data(path, dataset_id, download) + return torch_em.default_segmentation_dataset( + raw_paths=raw_paths, raw_key=raw_key, + label_paths=label_paths, label_key=label_key, + patch_shape=patch_shape, + is_seg_dataset=is_seg_dataset, **kwargs, + ) + + +# +# data loaders +# + + +def get_mitolab_loader( + path, split, batch_size, patch_shape=(224, 224), + discard_empty_images=True, + val_fraction=0.05, download=False, **kwargs +): + ds_kwargs, loader_kwargs = util.split_kwargs( + torch_em.default_segmentation_dataset, **kwargs + ) + dataset = get_mitolab_dataset( + path, split, patch_shape, download=download, discard_empty_images=discard_empty_images, **ds_kwargs + ) + loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) + return loader + + def get_cem15m_loader(path): - pass + raise NotImplementedError -# TODO -def get_cem_mito_benchmark_loader(path): - pass +def get_benchmark_loader(path, dataset_id, batch_size, patch_shape, download=False, **kwargs): + ds_kwargs, loader_kwargs = util.split_kwargs( + torch_em.default_segmentation_dataset, **kwargs + ) + dataset = get_benchmark_dataset( + path, dataset_id, + patch_shape=patch_shape, download=download, **ds_kwargs + ) + loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) + return loader diff --git a/torch_em/data/datasets/util.py b/torch_em/data/datasets/util.py index 8ae38b71..27d260d5 100644 --- a/torch_em/data/datasets/util.py +++ b/torch_em/data/datasets/util.py @@ -6,7 +6,8 @@ from tqdm import tqdm from warnings import warn from xml.dom import minidom -from shutil import copyfileobj +from shutil import copyfileobj, which +from subprocess import run from 
skimage.draw import polygon @@ -87,17 +88,17 @@ def download_source(path, url, download, checksum=None, verify=True): def download_source_gdrive(path, url, download, checksum=None, download_type="zip"): - if gdown is None: - raise RuntimeError( - "Need gdown library to download data from google drive." - "Please isntall gdown and then rerun." - ) - if os.path.exists(path): return if not download: raise RuntimeError(f"Cannot find the data at {path}, but download was set to False") + if gdown is None: + raise RuntimeError( + "Need gdown library to download data from google drive." + "Please install gdown and then rerun." + ) + if download_type == "zip": gdown.download(url, path, quiet=False) elif download_type == "folder": @@ -110,6 +111,37 @@ def download_source_gdrive(path, url, download, checksum=None, download_type="zi _check_checksum(path, checksum) +def download_source_empiar(path, access_id, download): + download_path = os.path.join(path, access_id) + + if os.path.exists(download_path): + return download_path + if not download: + raise RuntimeError(f"Cannot find the data at {path}, but download was set to False") + + if which("ascp") is None: + raise RuntimeError( + "Need aspera-cli to download data from empiar." + "You can install it via 'mamba install -c hcc aspera-cli'." + ) + + key_file = os.path.expanduser("~/.aspera/cli/etc/asperaweb_id_dsa.openssh") + if not os.path.exists(key_file): + conda_root = os.environ["CONDA_PREFIX"] + key_file = os.path.join(conda_root, "etc/asperaweb_id_dsa.openssh") + + if not os.path.exists(key_file): + raise RuntimeError("Could not find the aspera ssh keyfile") + + cmd = [ + "ascp", "-QT", "-l", "200M", "-P33001", + "-i", key_file, f"emp_ext2@fasp.ebi.ac.uk:/{access_id}", path + ] + run(cmd) + + return download_path + + def update_kwargs(kwargs, key, value, msg=None): if key in kwargs: msg = f"{key} will be over-ridden in loader kwargs." 
if msg is None else msg diff --git a/torch_em/util/modelzoo.py b/torch_em/util/modelzoo.py index e3777753..7affa122 100644 --- a/torch_em/util/modelzoo.py +++ b/torch_em/util/modelzoo.py @@ -107,7 +107,10 @@ def _get_model(trainer, postprocessing): def _pad(input_data, trainer): try: - ndim = trainer.train_loader.dataset.ndim + if isinstance(trainer.train_loader.dataset, torch.utils.data.dataset.Subset): + ndim = trainer.train_loader.dataset.dataset.ndim + else: + ndim = trainer.train_loader.dataset.ndim except AttributeError: ndim = trainer.train_loader.dataset.datasets[0].ndim target_dims = ndim + 2 @@ -305,7 +308,10 @@ def _write_weights(model, export_folder): def _get_preprocessing(trainer): try: - ndim = trainer.train_loader.dataset.ndim + if isinstance(trainer.train_loader.dataset, torch.utils.data.dataset.Subset): + ndim = trainer.train_loader.dataset.dataset.ndim + else: + ndim = trainer.train_loader.dataset.ndim except AttributeError: ndim = trainer.train_loader.dataset.datasets[0].ndim normalizer = get_normalizer(trainer) diff --git a/torch_em/util/util.py b/torch_em/util/util.py index a74a141a..4ecbf31c 100644 --- a/torch_em/util/util.py +++ b/torch_em/util/util.py @@ -196,6 +196,10 @@ def get_normalizer(trainer): isinstance(dataset, torch.utils.data.dataset.ConcatDataset) ): dataset = dataset.datasets[0] + + if isinstance(dataset, torch.utils.data.dataset.Subset): + dataset = dataset.dataset + preprocessor = dataset.raw_transform if hasattr(preprocessor, "normalizer"):