From d85b18a10c5340f2a9dc49638430b7eeea61fc56 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Fri, 15 Jul 2022 22:00:30 +0200 Subject: [PATCH 01/31] Enable anisotropic shallow2deep training --- ...in_mitoem_direct.py => train_direct_2d.py} | 0 .../{train_mitoem_2d.py => train_mito_2d.py} | 0 ...mitoem_3d.py => train_mito_anisotropic.py} | 0 .../em-mitochondria/train_vnc_2d.py | 98 ------------------- .../em-mitochondria/train_vnc_direct.py | 38 ------- torch_em/shallow2deep/shallow2deep_dataset.py | 51 ++++++++-- 6 files changed, 41 insertions(+), 146 deletions(-) rename experiments/shallow2deep/em-mitochondria/{train_mitoem_direct.py => train_direct_2d.py} (100%) rename experiments/shallow2deep/em-mitochondria/{train_mitoem_2d.py => train_mito_2d.py} (100%) rename experiments/shallow2deep/em-mitochondria/{train_mitoem_3d.py => train_mito_anisotropic.py} (100%) delete mode 100644 experiments/shallow2deep/em-mitochondria/train_vnc_2d.py delete mode 100644 experiments/shallow2deep/em-mitochondria/train_vnc_direct.py diff --git a/experiments/shallow2deep/em-mitochondria/train_mitoem_direct.py b/experiments/shallow2deep/em-mitochondria/train_direct_2d.py similarity index 100% rename from experiments/shallow2deep/em-mitochondria/train_mitoem_direct.py rename to experiments/shallow2deep/em-mitochondria/train_direct_2d.py diff --git a/experiments/shallow2deep/em-mitochondria/train_mitoem_2d.py b/experiments/shallow2deep/em-mitochondria/train_mito_2d.py similarity index 100% rename from experiments/shallow2deep/em-mitochondria/train_mitoem_2d.py rename to experiments/shallow2deep/em-mitochondria/train_mito_2d.py diff --git a/experiments/shallow2deep/em-mitochondria/train_mitoem_3d.py b/experiments/shallow2deep/em-mitochondria/train_mito_anisotropic.py similarity index 100% rename from experiments/shallow2deep/em-mitochondria/train_mitoem_3d.py rename to experiments/shallow2deep/em-mitochondria/train_mito_anisotropic.py diff --git 
a/experiments/shallow2deep/em-mitochondria/train_vnc_2d.py b/experiments/shallow2deep/em-mitochondria/train_vnc_2d.py deleted file mode 100644 index 02cc05d8..00000000 --- a/experiments/shallow2deep/em-mitochondria/train_vnc_2d.py +++ /dev/null @@ -1,98 +0,0 @@ -import os -from glob import glob - -import numpy as np -import torch_em -import torch_em.shallow2deep as shallow2deep -from torch_em.model import UNet2d -from torch_em.data.datasets.vnc import _get_vnc_data - - -def prepare_shallow2deep(args, out_folder): - patch_shape_min = [1, 256, 256] - patch_shape_max = [1, 512, 512] - - raw_transform = torch_em.transform.raw.normalize - label_transform = shallow2deep.ForegroundTransform(ndim=2) - - path = os.path.join(args.input, "vnc_train.h5") - raw_key = "raw" - label_key = "labels/mitochondria" - - if args.train_advanced: - shallow2deep.prepare_shallow2deep_advanced( - raw_paths=path, raw_key=raw_key, label_paths=path, label_key=label_key, - patch_shape_min=patch_shape_min, patch_shape_max=patch_shape_max, - n_forests=args.n_rfs, n_threads=args.n_threads, - forests_per_stage=25, sample_fraction_per_stage=0.05, - output_folder=out_folder, ndim=2, - raw_transform=raw_transform, label_transform=label_transform, - is_seg_dataset=True, - ) - else: - shallow2deep.prepare_shallow2deep( - raw_paths=path, raw_key=raw_key, label_paths=path, label_key=label_key, - patch_shape_min=patch_shape_min, patch_shape_max=patch_shape_max, - n_forests=args.n_rfs, n_threads=args.n_threads, - output_folder=out_folder, ndim=2, - raw_transform=raw_transform, label_transform=label_transform, - is_seg_dataset=True, - ) - - -def get_loader(args, split, rf_folder): - rf_paths = glob(os.path.join(rf_folder, "*.pkl")) - rf_paths.sort() - patch_shape = (1, 512, 512) - - path = os.path.join(args.input, "vnc_train.h5") - roi = np.s_[:18, :, :] if split == "train" else np.s_[18:, :, :] - n_samples = 500 if split == "train" else 25 - - raw_transform = torch_em.transform.raw.normalize - 
label_transform = torch_em.transform.BoundaryTransform(ndim=2, add_binary_target=True) - loader = shallow2deep.get_shallow2deep_loader( - raw_paths=path, raw_key="raw", - label_paths=path, label_key="labels/mitochondria", - rf_paths=rf_paths, - batch_size=args.batch_size, patch_shape=patch_shape, - raw_transform=raw_transform, label_transform=label_transform, - n_samples=n_samples, ndim=2, is_seg_dataset=True, shuffle=True, - num_workers=12, rois=roi - ) - return loader - - -def train_shallow2deep(args): - name = "shallow2deep-em-mitochondria" - if args.train_advanced: - name += "-advanced" - _get_vnc_data(args.input, download=True) - - # check if we need to train the rfs for preparation - rf_folder = os.path.join("checkpoints", name, "rfs") - have_rfs = len(glob(os.path.join(rf_folder, "*.pkl"))) == args.n_rfs - if not have_rfs: - prepare_shallow2deep(args, rf_folder) - assert os.path.exists(rf_folder) - - model = UNet2d(in_channels=1, out_channels=2, final_activation="Sigmoid", - depth=4, initial_features=64) - - train_loader = get_loader(args, "train", rf_folder) - val_loader = get_loader(args, "val", rf_folder) - - trainer = torch_em.default_segmentation_trainer( - name, model, train_loader, val_loader, learning_rate=1.0e-4, - device=args.device, log_image_interval=50 - ) - trainer.fit(args.n_iterations) - - -if __name__ == "__main__": - parser = torch_em.util.parser_helper() - parser.add_argument("--train_advanced", "-a", type=int, default=0) - parser.add_argument("--n_rfs", type=int, default=500) - parser.add_argument("--n_threads", type=int, default=32) - args = parser.parse_args() - train_shallow2deep(args) diff --git a/experiments/shallow2deep/em-mitochondria/train_vnc_direct.py b/experiments/shallow2deep/em-mitochondria/train_vnc_direct.py deleted file mode 100644 index 351b82d8..00000000 --- a/experiments/shallow2deep/em-mitochondria/train_vnc_direct.py +++ /dev/null @@ -1,38 +0,0 @@ -import numpy as np -import torch_em -from torch_em.model import UNet2d 
-from torch_em.data.datasets import get_vnc_mito_loader - - -def get_loader(args, split): - patch_shape = (1, 512, 512) - - roi = np.s_[:18, :, :] if split == "train" else np.s_[18:, :, :] - n_samples = 500 if split == "train" else 25 - - loader = get_vnc_mito_loader( - args.input, boundaries=True, - batch_size=args.batch_size, patch_shape=patch_shape, - n_samples=n_samples, ndim=2, shuffle=True, - num_workers=12, rois=roi - ) - return loader - - -def train_direct(args): - name = "em-mitochondria" - model = UNet2d(in_channels=1, out_channels=2, final_activation="Sigmoid", depth=4, initial_features=64) - - train_loader = get_loader(args, "train") - val_loader = get_loader(args, "val") - - trainer = torch_em.default_segmentation_trainer( - name, model, train_loader, val_loader, learning_rate=1.0e-4, device=args.device, log_image_interval=50 - ) - trainer.fit(args.n_iterations) - - -if __name__ == "__main__": - parser = torch_em.util.parser_helper() - args = parser.parse_args() - train_direct(args) diff --git a/torch_em/shallow2deep/shallow2deep_dataset.py b/torch_em/shallow2deep/shallow2deep_dataset.py index 8801d75b..3a56747e 100644 --- a/torch_em/shallow2deep/shallow2deep_dataset.py +++ b/torch_em/shallow2deep/shallow2deep_dataset.py @@ -1,4 +1,5 @@ import pickle +import warnings import numpy as np import torch @@ -42,12 +43,7 @@ def rf_channels(self, value): assert isinstance(value, tuple) self._rf_channels = value - def _predict_rf(self, raw): - n_rfs = len(self._rf_paths) - rf_path = self._rf_paths[np.random.randint(0, n_rfs)] - with open(rf_path, "rb") as f: - rf = pickle.load(f) - filters_and_sigmas = _get_filters(self.ndim, self._filter_config) + def _predict(self, raw, rf, filters_and_sigmas): features = _apply_filters(raw, filters_and_sigmas) assert rf.n_features_in_ == features.shape[1], f"{rf.n_features_in_}, {features.shape[1]}" @@ -56,7 +52,7 @@ def _predict_rf(self, raw): assert pred_.shape[1] > max(self.rf_channels), f"{pred_.shape}, 
{self.rf_channels}" pred_ = pred_[:, self.rf_channels] except IndexError: - print("Prediction failed:", features.shape) + warnings.warn(f"Random forest prediction failed for input features of shape: {features.shape}") pred_shape = (len(features), len(self.rf_channels)) pred_ = np.zeros(pred_shape, dtype="float32") @@ -68,6 +64,29 @@ def _predict_rf(self, raw): return prediction + def _predict_rf(self, raw): + n_rfs = len(self._rf_paths) + rf_path = self._rf_paths[np.random.randint(0, n_rfs)] + with open(rf_path, "rb") as f: + rf = pickle.load(f) + filters_and_sigmas = _get_filters(self.ndim, self._filter_config) + return self._predict(raw, rf, filters_and_sigmas) + + def _predict_rf_anisotropic(self, raw): + n_rfs = len(self._rf_paths) + rf_path = self._rf_paths[np.random.randint(0, n_rfs)] + with open(rf_path, "rb") as f: + rf = pickle.load(f) + filters_and_sigmas = _get_filters(2, self._filter_config) + + n_channels = len(self.rf_channels) + prediction = np.zeros((n_channels,) + raw.shape, dtype="float32") + for z in range(raw.shape[0]): + pred = self._predict(raw[z], rf, filters_and_sigmas) + prediction[:, z] = pred + + return prediction + def __getitem__(self, index): assert self._rf_paths is not None raw, labels = self._get_sample(index) @@ -97,15 +116,24 @@ def __getitem__(self, index): ) # NOTE we assume single channel raw data here; this needs to be changed for multi-channel - prediction = self._predict_rf(raw[0].numpy()) + if getattr(self, "is_anisotropic", False): + prediction = self._predict_rf_anisotropic(raw[0].numpy()) + else: + prediction = self._predict_rf(raw[0].numpy()) prediction = ensure_tensor_with_channels(prediction, ndim=self._ndim, dtype=self.dtype) labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype) return prediction, labels -def _load_shallow2deep_dataset(raw_paths, raw_key, label_paths, label_key, rf_paths, rf_channels, **kwargs): +def _load_shallow2deep_dataset(raw_paths, raw_key, label_paths, 
label_key, rf_paths, rf_channels, ndim, **kwargs): rois = kwargs.pop("rois", None) filter_config = kwargs.pop("filter_config", None) + if ndim == "anisotropic": + ndim = 3 + is_anisotropic = True + else: + is_anisotropic = False + if isinstance(raw_paths, str): if rois is not None: assert len(rois) == 3 and all(isinstance(roi, slice) for roi in rois) @@ -113,6 +141,7 @@ def _load_shallow2deep_dataset(raw_paths, raw_key, label_paths, label_key, rf_pa ds.rf_paths = rf_paths ds.filter_config = filter_config ds.rf_channels = rf_channels + ds.is_anisotropic = is_anisotropic else: assert len(raw_paths) > 0 if rois is not None: @@ -132,6 +161,7 @@ def _load_shallow2deep_dataset(raw_paths, raw_key, label_paths, label_key, rf_pa dset.rf_paths = rf_paths dset.filter_config = filter_config dset.rf_channels = rf_channels + dset.is_anisotropic = is_anisotropic ds.append(dset) ds = ConcatDataset(*ds) return ds @@ -169,7 +199,8 @@ def get_shallow2deep_dataset( # we always use augmentations in the convenience function if transform is None: transform = _get_default_transform( - raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key, is_seg_dataset, ndim + raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key, is_seg_dataset, + 3 if ndim == "anisotropic" else ndim ) if is_seg_dataset: From a96b87f80050ada9917a7bff59a3b2c18beb139e Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Fri, 15 Jul 2022 22:32:16 +0200 Subject: [PATCH 02/31] Update shallow2deep mito training --- .../em-mitochondria/train_mito_2d.py | 144 +++++++++++------ .../em-mitochondria/train_mito_anisotropic.py | 145 +++++++++++------- torch_em/shallow2deep/shallow2deep_dataset.py | 4 +- 3 files changed, 193 insertions(+), 100 deletions(-) diff --git a/experiments/shallow2deep/em-mitochondria/train_mito_2d.py b/experiments/shallow2deep/em-mitochondria/train_mito_2d.py index 51be4129..13addd82 100644 --- a/experiments/shallow2deep/em-mitochondria/train_mito_2d.py +++ 
b/experiments/shallow2deep/em-mitochondria/train_mito_2d.py @@ -1,88 +1,139 @@ import os from glob import glob +import torch import torch_em import torch_em.shallow2deep as shallow2deep +from torch_em.data.datasets.mitoem import _require_mitoem_sample +from torch_em.data.datasets.vnc import _get_vnc_data from torch_em.model import UNet2d -def prepare_shallow2deep(args, out_folder): - patch_shape_min = [1, 256, 256] - patch_shape_max = [1, 512, 512] +DATA_ROOT = "/scratch/pape/s2d-mitochondria" +DATASETS = ["mitoem", "vnc"] + + +def normalize_datasets(datasets): + wrong_ds = list(set(datasets) - set(DATASETS)) + if wrong_ds: + raise ValueError(f"Unkown datasets: {wrong_ds}. Only {DATASETS} are supported") + datasets = list(sorted(datasets)) + return datasets + + +def require_ds(dataset): + os.makedirs(DATA_ROOT, exist_ok=True) + data_path = os.path.join(DATA_ROOT, dataset) + if dataset == "mitoem": + if not os.path.exists(data_path): + _require_mitoem_sample(data_path, sample="human", download=True) + _require_mitoem_sample(data_path, sample="rat", download=True) + paths = [ + os.path.join(data_path, "human_train.n5"), + os.path.join(data_path, "rat_train.n5"), + ] + assert all(os.path.exists(pp) for pp in paths) + raw_key, label_key = "raw", "labels" + elif dataset == "vnc": + _get_vnc_data(data_path, True) + paths = [os.path.join(data_path, "vnc_train.h5")] + raw_key, label_key = "raw", "labels/mitochondria" + return paths, raw_key, label_key + + +def require_rfs_ds(dataset, n_rfs, sampling_strategy): + if sampling_strategy is None: + out_folder = os.path.join(DATA_ROOT, "rfs2d", dataset) + else: + out_folder = os.path.join(DATA_ROOT, f"rfs2d-{sampling_strategy}", dataset) + os.makedirs(out_folder, exist_ok=True) + if len(glob(os.path.join(out_folder, "*.pkl"))) == n_rfs: + return + + patch_shape_min = [1, 128, 128] + patch_shape_max = [1, 256, 256] raw_transform = torch_em.transform.raw.normalize label_transform = shallow2deep.ForegroundTransform(ndim=2) - paths 
= [ - os.path.join(args.input, "human_train.n5"), - os.path.join(args.input, "rat_train.n5") - ] - raw_key = "raw" - label_key = "labels" + paths, raw_key, label_key = require_ds(dataset) - if args.train_advanced: - shallow2deep.prepare_shallow2deep_advanced( + if sampling_strategy == "vanilla": + shallow2deep.prepare_shallow2deep( raw_paths=paths, raw_key=raw_key, label_paths=paths, label_key=label_key, patch_shape_min=patch_shape_min, patch_shape_max=patch_shape_max, n_forests=args.n_rfs, n_threads=args.n_threads, - forests_per_stage=25, sample_fraction_per_stage=0.05, output_folder=out_folder, ndim=2, raw_transform=raw_transform, label_transform=label_transform, is_seg_dataset=True, ) else: - shallow2deep.prepare_shallow2deep( + shallow2deep.prepare_shallow2deep_advanced( raw_paths=paths, raw_key=raw_key, label_paths=paths, label_key=label_key, patch_shape_min=patch_shape_min, patch_shape_max=patch_shape_max, n_forests=args.n_rfs, n_threads=args.n_threads, + forests_per_stage=25, sample_fraction_per_stage=0.10, output_folder=out_folder, ndim=2, raw_transform=raw_transform, label_transform=label_transform, - is_seg_dataset=True, + is_seg_dataset=True, sampling_strategy=sampling_strategy, ) -def get_loader(args, split, rf_folder): - rf_paths = glob(os.path.join(rf_folder, "*.pkl")) - rf_paths.sort() - patch_shape = (1, 512, 512) +def require_rfs(datasets, n_rfs, sampling_strategy): + for ds in datasets: + require_rfs_ds(ds, n_rfs, sampling_strategy) - paths = [ - os.path.join(args.input, f"human_{split}.n5"), - os.path.join(args.input, f"rat_{split}.n5") - ] - n_samples = 500 if split == "train" else 25 - raw_transform = torch_em.transform.raw.normalize +def get_ds(file_pattern, rf_pattern, n_samples, label_key): label_transform = torch_em.transform.BoundaryTransform(ndim=2, add_binary_target=True) - loader = shallow2deep.get_shallow2deep_loader( - raw_paths=paths, raw_key="raw", - label_paths=paths, label_key="labels", - rf_paths=rf_paths, - 
batch_size=args.batch_size, patch_shape=patch_shape, - raw_transform=raw_transform, label_transform=label_transform, - n_samples=n_samples, ndim=2, is_seg_dataset=True, shuffle=True, - num_workers=12 + patch_shape = (1, 512, 512) + paths = glob(file_pattern) + paths.sort() + assert len(paths) > 0 + rf_paths = glob(rf_pattern) + rf_paths.sort() + assert len(rf_paths) > 0 + raw_key = "raw" + return shallow2deep.shallow2deep_dataset.get_shallow2deep_dataset( + paths, raw_key, paths, label_key, rf_paths, + patch_shape=patch_shape, label_transform=label_transform, + n_samples=n_samples, ndim=2, ) + + +def get_loader(args, split, dataset_names): + datasets = [] + n_samples = 500 if split == "train" else 25 + if "mitoem" in dataset_names: + ds_name = "mitoem" + file_pattern = os.path.join(DATA_ROOT, ds_name, f"*_{split}.n5") + rf_pattern = os.path.join(DATA_ROOT, "rfs2d", ds_name, "*.pkl") + datasets.append(get_ds(file_pattern, rf_pattern, n_samples, label_key="labels")) + if "vnc" in dataset_names and split == "train": + ds_name = "vnc" + file_pattern = os.path.join(DATA_ROOT, ds_name, f"vnc_{split}.h5") + rf_pattern = os.path.join(DATA_ROOT, "rfs2d", ds_name, "*.pkl") + datasets.append(get_ds(file_pattern, rf_pattern, n_samples, label_key="labels/mitochondria")) + ds = torch_em.data.concat_dataset.ConcatDataset(*datasets) if len(datasets) > 1 else datasets[0] + loader = torch.utils.data.DataLoader( + ds, shuffle=True, batch_size=args.batch_size, num_workers=12 + ) + loader.shuffle = True return loader def train_shallow2deep(args): - name = "shallow2deep-mitoem2d" - if args.train_advanced: - name += "-advanced" - - # check if we need to train the rfs for preparation - rf_folder = os.path.join("checkpoints", name, "rfs") - have_rfs = len(glob(os.path.join(rf_folder, "*.pkl"))) == args.n_rfs - if not have_rfs: - prepare_shallow2deep(args, rf_folder) - assert os.path.exists(rf_folder) + datasets = normalize_datasets(args.datasets) + name = 
f"s2d-em-mitos-{'_'.join(datasets)}-2d" + if args.sampling_strategy is not None: + name += f"-{args.sampling_strategy}" + require_rfs(datasets, args.n_rfs, args.sampling_strategy) model = UNet2d(in_channels=1, out_channels=2, final_activation="Sigmoid", depth=4, initial_features=64) - train_loader = get_loader(args, "train", rf_folder) - val_loader = get_loader(args, "val", rf_folder) + train_loader = get_loader(args, "train", datasets) + val_loader = get_loader(args, "val", datasets) trainer = torch_em.default_segmentation_trainer( name, model, train_loader, val_loader, learning_rate=1.0e-4, @@ -92,9 +143,10 @@ def train_shallow2deep(args): if __name__ == "__main__": - parser = torch_em.util.parser_helper() - parser.add_argument("--train_advanced", "-a", type=int, default=0) + parser = torch_em.util.parser_helper(require_input=False, default_batch_size=4) + parser.add_argument("--datasets", "-d", nargs="+", default=DATASETS) parser.add_argument("--n_rfs", type=int, default=500) parser.add_argument("--n_threads", type=int, default=32) + parser.add_argument("--sampling_strategy", "-s", default=None) args = parser.parse_args() train_shallow2deep(args) diff --git a/experiments/shallow2deep/em-mitochondria/train_mito_anisotropic.py b/experiments/shallow2deep/em-mitochondria/train_mito_anisotropic.py index 7a39aef4..b80747a5 100644 --- a/experiments/shallow2deep/em-mitochondria/train_mito_anisotropic.py +++ b/experiments/shallow2deep/em-mitochondria/train_mito_anisotropic.py @@ -1,89 +1,129 @@ import os from glob import glob +import torch import torch_em import torch_em.shallow2deep as shallow2deep from torch_em.model import AnisotropicUNet +from torch_em.data.datasets.mitoem import _require_mitoem_sample -def prepare_shallow2deep(args, out_folder): - patch_shape_min = [24, 128, 128] - patch_shape_max = [32, 256, 256] +DATA_ROOT = "/scratch/pape/s2d-mitochondria" +DATASETS = ["mitoem"] - raw_transform = torch_em.transform.raw.normalize - label_transform = 
shallow2deep.ForegroundTransform(ndim=3) - paths = [ - os.path.join(args.input, "human_train.n5"), - os.path.join(args.input, "rat_train.n5") - ] - raw_key = "raw" - label_key = "labels" +def normalize_datasets(datasets): + wrong_ds = list(set(datasets) - set(DATASETS)) + if wrong_ds: + raise ValueError(f"Unkown datasets: {wrong_ds}. Only {DATASETS} are supported") + datasets = list(sorted(datasets)) + return datasets - if args.train_advanced: - shallow2deep.prepare_shallow2deep_advanced( + +def require_ds(dataset): + os.makedirs(DATA_ROOT, exist_ok=True) + data_path = os.path.join(DATA_ROOT, dataset) + if dataset == "mitoem": + if not os.path.exists(data_path): + _require_mitoem_sample(data_path, sample="human", download=True) + _require_mitoem_sample(data_path, sample="rat", download=True) + paths = [ + os.path.join(data_path, "human_train.n5"), + os.path.join(data_path, "rat_train.n5"), + ] + assert all(os.path.exists(pp) for pp in paths) + raw_key, label_key = "raw", "labels" + return paths, raw_key, label_key + + +def require_rfs_ds(dataset, n_rfs, sampling_strategy): + if sampling_strategy is None: + out_folder = os.path.join(DATA_ROOT, "rfs2d", dataset) + else: + out_folder = os.path.join(DATA_ROOT, f"rfs2d-{sampling_strategy}", dataset) + os.makedirs(out_folder, exist_ok=True) + if len(glob(os.path.join(out_folder, "*.pkl"))) == n_rfs: + return + + patch_shape_min = [1, 128, 128] + patch_shape_max = [1, 256, 256] + + raw_transform = torch_em.transform.raw.normalize + label_transform = shallow2deep.ForegroundTransform(ndim=2) + + paths, raw_key, label_key = require_ds(dataset) + if sampling_strategy == "vanilla": + shallow2deep.prepare_shallow2deep( raw_paths=paths, raw_key=raw_key, label_paths=paths, label_key=label_key, patch_shape_min=patch_shape_min, patch_shape_max=patch_shape_max, n_forests=args.n_rfs, n_threads=args.n_threads, - forests_per_stage=25, sample_fraction_per_stage=0.025, - output_folder=out_folder, ndim=3, + output_folder=out_folder, 
ndim=2, raw_transform=raw_transform, label_transform=label_transform, is_seg_dataset=True, ) else: - shallow2deep.prepare_shallow2deep( + shallow2deep.prepare_shallow2deep_advanced( raw_paths=paths, raw_key=raw_key, label_paths=paths, label_key=label_key, patch_shape_min=patch_shape_min, patch_shape_max=patch_shape_max, n_forests=args.n_rfs, n_threads=args.n_threads, - output_folder=out_folder, ndim=3, + forests_per_stage=25, sample_fraction_per_stage=0.10, + output_folder=out_folder, ndim=2, raw_transform=raw_transform, label_transform=label_transform, - is_seg_dataset=True, + is_seg_dataset=True, sampling_strategy=sampling_strategy, ) -def get_loader(args, split, rf_folder): - rf_paths = glob(os.path.join(rf_folder, "*.pkl")) - rf_paths.sort() - patch_shape = (32, 256, 256) +def require_rfs(datasets, n_rfs, sampling_strategy): + for ds in datasets: + require_rfs_ds(ds, n_rfs, sampling_strategy) - paths = [ - os.path.join(args.input, f"human_{split}.n5"), - os.path.join(args.input, f"rat_{split}.n5") - ] - n_samples = 500 if split == "train" else 25 - raw_transform = torch_em.transform.raw.normalize +def get_ds(file_pattern, rf_pattern, n_samples, label_key): label_transform = torch_em.transform.BoundaryTransform(ndim=3, add_binary_target=True) - loader = shallow2deep.get_shallow2deep_loader( - raw_paths=paths, raw_key="raw", - label_paths=paths, label_key="labels", - rf_paths=rf_paths, - batch_size=args.batch_size, patch_shape=patch_shape, - raw_transform=raw_transform, label_transform=label_transform, - n_samples=n_samples, ndim=3, is_seg_dataset=True, shuffle=True, - num_workers=12 + patch_shape = (32, 256, 256) + paths = glob(file_pattern) + paths.sort() + assert len(paths) > 0 + rf_paths = glob(rf_pattern) + rf_paths.sort() + assert len(rf_paths) > 0 + raw_key = "raw" + return shallow2deep.shallow2deep_dataset.get_shallow2deep_dataset( + paths, raw_key, paths, label_key, rf_paths, + patch_shape=patch_shape, label_transform=label_transform, + 
n_samples=n_samples, ndim="anisotropic", + ) + + +def get_loader(args, split, dataset_names): + datasets = [] + n_samples = 500 if split == "train" else 25 + if "mitoem" in dataset_names: + ds_name = "mitoem" + file_pattern = os.path.join(DATA_ROOT, ds_name, f"*_{split}.n5") + rf_pattern = os.path.join(DATA_ROOT, "rfs2d", ds_name, "*.pkl") + datasets.append(get_ds(file_pattern, rf_pattern, n_samples, label_key="labels")) + ds = torch_em.data.concat_dataset.ConcatDataset(*datasets) if len(datasets) > 1 else datasets[0] + loader = torch.utils.data.DataLoader( + ds, shuffle=True, batch_size=args.batch_size, num_workers=12 ) + loader.shuffle = True return loader def train_shallow2deep(args): - name = "shallow2deep-mitoem3d" - if args.train_advanced: - name += "-advanced" - - # check if we need to train the rfs for preparation - rf_folder = os.path.join("checkpoints", name, "rfs") - have_rfs = len(glob(os.path.join(rf_folder, "*.pkl"))) == args.n_rfs - if not have_rfs: - prepare_shallow2deep(args, rf_folder) - assert os.path.exists(rf_folder) - - scale_factors = [[1, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]] + datasets = normalize_datasets(args.datasets) + name = f"s2d-em-mitos-{'_'.join(datasets)}-anisotropic" + if args.sampling_strategy is not None: + name += f"-{args.sampling_strategy}" + require_rfs(datasets, args.n_rfs, args.sampling_strategy) + + scale_factors = [[1, 2, 2], [1, 2, 2], [2, 2, 2], [2, 2, 2]] model = AnisotropicUNet(in_channels=1, out_channels=2, final_activation="Sigmoid", scale_factors=scale_factors, initial_features=32) - train_loader = get_loader(args, "train", rf_folder) - val_loader = get_loader(args, "val", rf_folder) + train_loader = get_loader(args, "train", datasets) + val_loader = get_loader(args, "val", datasets) trainer = torch_em.default_segmentation_trainer( name, model, train_loader, val_loader, learning_rate=1.0e-4, @@ -93,9 +133,10 @@ def train_shallow2deep(args): if __name__ == "__main__": - parser = torch_em.util.parser_helper() - 
parser.add_argument("--train_advanced", "-a", type=int, default=0) + parser = torch_em.util.parser_helper(require_input=False) + parser.add_argument("--datasets", "-d", nargs="+", default=DATASETS) parser.add_argument("--n_rfs", type=int, default=500) parser.add_argument("--n_threads", type=int, default=32) + parser.add_argument("--sampling_strategy", "-s", default=None) args = parser.parse_args() train_shallow2deep(args) diff --git a/torch_em/shallow2deep/shallow2deep_dataset.py b/torch_em/shallow2deep/shallow2deep_dataset.py index 3a56747e..d7a3e6e3 100644 --- a/torch_em/shallow2deep/shallow2deep_dataset.py +++ b/torch_em/shallow2deep/shallow2deep_dataset.py @@ -137,7 +137,7 @@ def _load_shallow2deep_dataset(raw_paths, raw_key, label_paths, label_key, rf_pa if isinstance(raw_paths, str): if rois is not None: assert len(rois) == 3 and all(isinstance(roi, slice) for roi in rois) - ds = Shallow2DeepDataset(raw_paths, raw_key, label_paths, label_key, roi=rois, **kwargs) + ds = Shallow2DeepDataset(raw_paths, raw_key, label_paths, label_key, roi=rois, ndim=ndim, **kwargs) ds.rf_paths = rf_paths ds.filter_config = filter_config ds.rf_channels = rf_channels @@ -156,7 +156,7 @@ def _load_shallow2deep_dataset(raw_paths, raw_key, label_paths, label_key, rf_pa for i, (raw_path, label_path) in enumerate(zip(raw_paths, label_paths)): roi = None if rois is None else rois[i] dset = Shallow2DeepDataset( - raw_path, raw_key, label_path, label_key, roi=roi, n_samples=samples_per_ds[i], **kwargs + raw_path, raw_key, label_path, label_key, roi=roi, n_samples=samples_per_ds[i], ndim=ndim, **kwargs ) dset.rf_paths = rf_paths dset.filter_config = filter_config From adbb4774b0f88895c6883aa904f458ca12c4cdd9 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Fri, 15 Jul 2022 23:41:30 +0200 Subject: [PATCH 03/31] Implement more s2d sampling strategies --- .../em-mitochondria/train_mito_2d.py | 1 + .../em-mitochondria/train_mito_anisotropic.py | 1 + 
torch_em/shallow2deep/prepare_shallow2deep.py | 79 ++++++++++++++++--- 3 files changed, 71 insertions(+), 10 deletions(-) diff --git a/experiments/shallow2deep/em-mitochondria/train_mito_2d.py b/experiments/shallow2deep/em-mitochondria/train_mito_2d.py index 13addd82..37e3443b 100644 --- a/experiments/shallow2deep/em-mitochondria/train_mito_2d.py +++ b/experiments/shallow2deep/em-mitochondria/train_mito_2d.py @@ -68,6 +68,7 @@ def require_rfs_ds(dataset, n_rfs, sampling_strategy): is_seg_dataset=True, ) else: + sampling_strategy = "worst_points" if sampling_strategy is None else sampling_strategy shallow2deep.prepare_shallow2deep_advanced( raw_paths=paths, raw_key=raw_key, label_paths=paths, label_key=label_key, patch_shape_min=patch_shape_min, patch_shape_max=patch_shape_max, diff --git a/experiments/shallow2deep/em-mitochondria/train_mito_anisotropic.py b/experiments/shallow2deep/em-mitochondria/train_mito_anisotropic.py index b80747a5..9619f767 100644 --- a/experiments/shallow2deep/em-mitochondria/train_mito_anisotropic.py +++ b/experiments/shallow2deep/em-mitochondria/train_mito_anisotropic.py @@ -62,6 +62,7 @@ def require_rfs_ds(dataset, n_rfs, sampling_strategy): is_seg_dataset=True, ) else: + sampling_strategy = "worst_points" if sampling_strategy is None else sampling_strategy shallow2deep.prepare_shallow2deep_advanced( raw_paths=paths, raw_key=raw_key, label_paths=paths, label_key=label_key, patch_shape_min=patch_shape_min, patch_shape_max=patch_shape_max, diff --git a/torch_em/shallow2deep/prepare_shallow2deep.py b/torch_em/shallow2deep/prepare_shallow2deep.py index d85d6b85..b51e20e8 100644 --- a/torch_em/shallow2deep/prepare_shallow2deep.py +++ b/torch_em/shallow2deep/prepare_shallow2deep.py @@ -375,23 +375,20 @@ def _train_rf(rf_id): list(tqdm(tp.map(_train_rf, range(n_forests)), desc="Train RFs", total=n_forests)) -def worst_points( +def _score_based_points( + score_function, features, labels, rf_id, forests, forests_per_stage, 
sample_fraction_per_stage, - accumulate_samples=True, + accumulate_samples, ): # get the corresponding random forest from the last stage # and predict with it last_forest = forests[rf_id - forests_per_stage] pred = last_forest.predict_proba(features) - # labels to one-hot encoding - unique, inverse = np.unique(labels, return_inverse=True) - onehot = np.eye(unique.shape[0])[inverse] - # compute the difference between labels and prediction - diff = np.abs(onehot - pred).sum(axis=1) - assert len(diff) == len(features) + score = score_function(pred, labels) + assert len(score) == len(features) # get training samples based on the label-prediction diff samples = [] @@ -400,7 +397,7 @@ def worst_points( n_samples = int(sample_fraction_per_stage * len(features)) n_samples_class = n_samples // nc for class_id in range(nc): - this_samples = np.argsort(diff[labels == class_id])[::-1][:n_samples_class] + this_samples = np.argsort(score[labels == class_id])[::-1][:n_samples_class] samples.append(this_samples) samples = np.concatenate(samples) @@ -413,6 +410,66 @@ def worst_points( return features, labels +def worst_points( + features, labels, rf_id, + forests, forests_per_stage, + sample_fraction_per_stage, + accumulate_samples=True, +): + def score(pred, labels): + # labels to one-hot encoding + unique, inverse = np.unique(labels, return_inverse=True) + onehot = np.eye(unique.shape[0])[inverse] + # compute the difference between labels and prediction + return np.abs(onehot - pred).sum(axis=1) + + return _score_based_points( + score, features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples + ) + + +def uncertain_points( + features, labels, rf_id, + forests, forests_per_stage, + sample_fraction_per_stage, + accumulate_samples=True, +): + def score(pred, labels): + assert pred.ndim == 2 + channel_sorted = np.sort(pred, axis=1) + uncertainty = channel_sorted[:, -1] - channel_sorted[:, -2] + return uncertainty + + return 
_score_based_points( + score, features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples + ) + + +def uncertain_worst_points( + features, labels, rf_id, + forests, forests_per_stage, + sample_fraction_per_stage, + accumulate_samples=True, + alpha=0.5, +): + def score(pred, labels): + assert pred.ndim == 2 + + # labels to one-hot encoding + unique, inverse = np.unique(labels, return_inverse=True) + onehot = np.eye(unique.shape[0])[inverse] + # compute the difference between labels and prediction + diff = np.abs(onehot - pred).sum(axis=1) + + channel_sorted = np.sort(pred, axis=1) + uncertainty = channel_sorted[:, -1] - channel_sorted[:, -2] + return alpha * diff + (1.0 - alpha) * uncertainty + + return _score_based_points( + score, features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples + ) + + def random_points( features, labels, rf_id, forests, forests_per_stage, @@ -434,8 +491,10 @@ def random_points( SAMPLING_STRATEGIES = { - "worst_points": worst_points, "random_points": random_points, + "uncertain_points": uncertain_points, + "uncertain_worst_points": uncertain_worst_points, + "worst_points": worst_points, } From b99b89401df243da81e203c39f59ec91aca35e94 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Mon, 18 Jul 2022 21:52:59 +0200 Subject: [PATCH 04/31] Add modelzoo config functionality WIP --- torch_em/util/__init__.py | 2 ++ torch_em/util/modelzoo_configs.py | 21 +++++++++++++++++++++ 2 files changed, 23 insertions(+) create mode 100644 torch_em/util/modelzoo_configs.py diff --git a/torch_em/util/__init__.py b/torch_em/util/__init__.py index 962ada4d..81ab5fb3 100644 --- a/torch_em/util/__init__.py +++ b/torch_em/util/__init__.py @@ -6,6 +6,8 @@ export_parser_helper, get_default_citations, import_bioimageio_model) +from .modelzoo_configs import (get_mws_config, + get_shallow2deep_config) from .reporting import get_training_summary from .training import parser_helper from 
.util import (ensure_array, ensure_spatial_array, diff --git a/torch_em/util/modelzoo_configs.py b/torch_em/util/modelzoo_configs.py new file mode 100644 index 00000000..6b34e0cd --- /dev/null +++ b/torch_em/util/modelzoo_configs.py @@ -0,0 +1,21 @@ + + +def get_mws_config(offsets, config=None): + mws_config = {"offsets": offsets} + if config is None: + config = {"mws": mws_config} + else: + assert isinstance(config, dict) + config["mws"] = mws_config + return config + + +def get_shallow2deep_config(config=None): + # TODO + shallow2deep_config = {} + if config is None: + config = {"shallow2deep": shallow2deep_config} + else: + assert isinstance(config, dict) + config["shallow2deep"] = shallow2deep_config + return config From 76f8ca9311594a1b2e5c7ffd7683f243419a44c1 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Mon, 18 Jul 2022 23:33:46 +0200 Subject: [PATCH 05/31] Implement shallow2deep modelzoo config --- .../em-mitochondria/export_enhancer.py | 98 +++++++++++++++---- torch_em/shallow2deep/prepare_shallow2deep.py | 18 ++++ torch_em/util/modelzoo_configs.py | 16 ++- 3 files changed, 110 insertions(+), 22 deletions(-) diff --git a/experiments/shallow2deep/em-mitochondria/export_enhancer.py b/experiments/shallow2deep/em-mitochondria/export_enhancer.py index 76c3a601..244d68ad 100644 --- a/experiments/shallow2deep/em-mitochondria/export_enhancer.py +++ b/experiments/shallow2deep/em-mitochondria/export_enhancer.py @@ -2,17 +2,20 @@ import os from glob import glob +import numpy as np from elf.io import open_file from torch_em.data.datasets import get_bioimageio_dataset_id from torch_em.util import (add_weight_formats, export_bioimageio_model, get_default_citations, + get_shallow2deep_config, get_training_summary) from torch_em.shallow2deep.shallow2deep_model import RFWithFilters, _get_filters -def _get_name_and_description(is3d): - name = "EnhancerMitochondriaEM3D" if is3d else "EnhancerMitochondriaEM2D" +def _get_name_and_description(is3d, name): + if name is 
None: + name = "EnhancerMitochondriaEM3D" if is3d else "EnhancerMitochondriaEM2D" description = "Prediction enhancer for segmenting mitochondria in EM images." return name, description @@ -56,10 +59,9 @@ def _get_doc(ckpt, name, is3d): return doc -def create_input_2d(input_path, checkpoint): +def create_input_2d(input_path, rf_path): with open_file(input_path, "r") as f: data = f["raw"][-1, :512, :512] - rf_path = glob(os.path.join(checkpoint, "rfs/*.pkl"))[-1] assert os.path.exists(rf_path), rf_path filter_config = _get_filters(2, None) rf = RFWithFilters(rf_path, ndim=2, filter_config=filter_config, output_channel=1) @@ -67,27 +69,50 @@ def create_input_2d(input_path, checkpoint): return pred[None] -def create_input_3d(input_path, checkpoint): +def create_input_anisotropic(input_path, rf_path): with open_file(input_path, "r") as f: data = f["raw"][:32, :256, :256] - rf_path = glob(os.path.join(checkpoint, "rfs/*.pkl"))[-1] assert os.path.exists(rf_path), rf_path filter_config = _get_filters(2, None) + rf = RFWithFilters(rf_path, ndim=2, filter_config=filter_config, output_channel=1) + pred = np.zeros(data.shape, dtype="float32") + for z in range(data.shape[0]): + pred[z] = rf(data[z]) + return pred[None] + + +def create_input_3d(input_path, rf_path): + with open_file(input_path, "r") as f: + data = f["raw"][:32, :256, :256] + assert os.path.exists(rf_path), rf_path + filter_config = _get_filters(3, None) rf = RFWithFilters(rf_path, ndim=3, filter_config=filter_config, output_channel=1) pred = rf(data) return pred[None] -def export_enhancer(input_, train_advanced, is3d): +def export_enhancer(input_, is3d, checkpoint=None, version=None, name=None): - checkpoint = "./checkpoints/shallow2deep-mitoem3d" if is3d else\ - "./checkpoints/shallow2deep-mitoem2d" - - if train_advanced: - checkpoint += "-advanced" - input_data = create_input_3d(input_, checkpoint) if is3d else create_input_2d(input_, checkpoint) + if checkpoint is None: + checkpoint = 
"./checkpoints/shallow2deep-mitoem3d" if is3d else\ + "./checkpoints/shallow2deep-mitoem2d" + out_folder = "./bio-models" + else: + assert version is not None + out_folder = f"./bio-models/v{version}" + + if is3d == "anisotropic": + rf_path = "/scratch/pape/s2d-mitochondria/rfs2d/mitoem/rf_0499.pkl" + input_data = create_input_anisotropic(input_, rf_path) + is3d = True + elif is3d: + assert False, "Currently don't have 3d rfs for mitos" + input_data = create_input_3d(input_, rf_path) + else: + rf_path = "/scratch/pape/s2d-mitochondria/rfs2d/mitoem/rf_0499.pkl" + input_data = create_input_2d(input_, rf_path) - name, description = _get_name_and_description(is3d) + name, description = _get_name_and_description(is3d, name) tags = ["unet", "mitochondria", "electron-microscopy", "instance-segmentation", "shallow2deep"] tags += ["3d"] if is3d else ["2d"] @@ -100,9 +125,8 @@ def export_enhancer(input_, train_advanced, is3d): doc = _get_doc(checkpoint, name, is3d) additional_formats = ["torchscript"] - out_folder = "./bio-models" os.makedirs(out_folder, exist_ok=True) - output = os.path.join(out_folder, f"{name}-advanced-traing" if train_advanced else name) + output = os.path.join(out_folder, name) if is3d: min_shape = [16, 128, 128] @@ -111,6 +135,7 @@ def export_enhancer(input_, train_advanced, is3d): min_shape = [256, 256] halo = [32, 32] + config = get_shallow2deep_config(rf_path) export_bioimageio_model( checkpoint, output, input_data=input_data, @@ -129,17 +154,52 @@ def export_enhancer(input_, train_advanced, is3d): maintainers=[{"github_user": "constantinpape"}], min_shape=min_shape, halo=halo, + config=config, ) add_weight_formats(output, additional_formats) +def export_version(args): + + def _get_ndim(x): + if x == "2d": + return False + elif x == "anisotropic": + return x + elif x == "3d": + return True + return None + + checkpoints = glob("./checkpoints/s2d-em-mitos-*") + for ckpt in checkpoints: + name = os.path.basename(ckpt) + parts = name.split("-") + is3d = 
_get_ndim(parts[-1]) + if is3d is None: + is3d = _get_ndim(parts[-2]) + else: + name += "-worst_points" + assert is3d is not None + print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + print("Exporting:", ckpt) + print(name) + print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + export_enhancer(args.input, is3d, checkpoint=ckpt, version=args.version, name=name) + + def main(): parser = argparse.ArgumentParser() - parser.add_argument("-i", "--input") - parser.add_argument("-a", "--train_advanced", default=0) + parser.add_argument("-i", "--input", required=True) parser.add_argument("-d", "--is3d", default=0) + parser.add_argument("-v", "--version", default=None, type=int) args = parser.parse_args() - export_enhancer(args.input, args.train_advanced, args.is3d) + if args.version is None: + export_enhancer(args.input, args.is3d) + else: + # export all currently trained checkpoints as one version + export_version(args) if __name__ == "__main__": diff --git a/torch_em/shallow2deep/prepare_shallow2deep.py b/torch_em/shallow2deep/prepare_shallow2deep.py index b51e20e8..0c123c5b 100644 --- a/torch_em/shallow2deep/prepare_shallow2deep.py +++ b/torch_em/shallow2deep/prepare_shallow2deep.py @@ -329,6 +329,15 @@ def _prepare_shallow2deep( return ds, filters_and_sigmas +def _serialize_feature_config(filters_and_sigmas): + feature_config = [ + (filt.func.__name__ if isinstance(filt, partial) else filt.__name__, + sigma) + for filt, sigma in filters_and_sigmas + ] + return feature_config + + def prepare_shallow2deep( raw_paths, raw_key, @@ -356,6 +365,7 @@ def prepare_shallow2deep( raw_transform, label_transform, rois, is_seg_dataset, filter_config, sampler, ) + serialized_feature_config = _serialize_feature_config(filters_and_sigmas) def _train_rf(rf_id): # sample random patch with dataset @@ -367,6 +377,9 @@ def _train_rf(rf_id): 
features, labels = _get_features_and_labels(raw, labels, filters_and_sigmas, balance_labels) rf = RandomForestClassifier(**rf_kwargs) rf.fit(features, labels) + # monkey patch these so that we know the feature config and dimensionality + rf.feature_ndim = ndim + rf.feature_config = serialized_feature_config out_path = os.path.join(output_folder, f"rf_{rf_id:04d}.pkl") with open(out_path, "wb") as f: pickle.dump(rf, f) @@ -539,6 +552,8 @@ def prepare_shallow2deep_advanced( raw_transform, label_transform, rois, is_seg_dataset, filter_config, sampler, ) + serialized_feature_config = _serialize_feature_config(filters_and_sigmas) + forests = [] n_stages = n_forests // forests_per_stage if n_forests % forests_per_stage == 0 else\ n_forests // forests_per_stage + 1 @@ -581,6 +596,9 @@ def _train_rf(rf_id): assert len(features) == len(labels) rf = RandomForestClassifier(**rf_kwargs) rf.fit(features, labels) + # monkey patch these so that we know the feature config and dimensionality + rf.feature_ndim = ndim + rf.feature_config = serialized_feature_config # save the random forest, update pbar, return it out_path = os.path.join(output_folder, f"rf_{rf_id:04d}.pkl") diff --git a/torch_em/util/modelzoo_configs.py b/torch_em/util/modelzoo_configs.py index 6b34e0cd..2bed3af7 100644 --- a/torch_em/util/modelzoo_configs.py +++ b/torch_em/util/modelzoo_configs.py @@ -1,3 +1,6 @@ +import os +import pickle +from glob import glob def get_mws_config(offsets, config=None): @@ -10,9 +13,16 @@ def get_mws_config(offsets, config=None): return config -def get_shallow2deep_config(config=None): - # TODO - shallow2deep_config = {} +def get_shallow2deep_config(rf_path, config=None): + if os.path.isdir(rf_path): + rf_path = glob(os.path.join(rf_path, "*.pkl"))[0] + assert os.path.exists(rf_path), rf_path + with open(rf_path, "rb") as f: + rf = pickle.load(f) + shallow2deep_config = { + "ndim": rf.feature_ndim, + "features": rf.feature_config, + } if config is None: config = {"shallow2deep": 
shallow2deep_config} else: From b3e7792de2bc09e54cb02b1e72e35830c8bc217a Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Tue, 19 Jul 2022 12:41:15 +0200 Subject: [PATCH 06/31] Update mito experiments --- .../shallow2deep/em-mitochondria/README.md | 24 +- .../em-mitochondria/evaluation.py | 392 ++++++++++-------- .../em-mitochondria/old_evaluation.py | 238 +++++++++++ .../em-mitochondria/train_mito_2d.py | 18 +- .../em-mitochondria/visualize_rfs.py | 30 +- torch_em/shallow2deep/shallow2deep_eval.py | 9 +- torch_em/util/debug.py | 6 +- 7 files changed, 521 insertions(+), 196 deletions(-) create mode 100644 experiments/shallow2deep/em-mitochondria/old_evaluation.py diff --git a/experiments/shallow2deep/em-mitochondria/README.md b/experiments/shallow2deep/em-mitochondria/README.md index d8e5b996..e5db5d5e 100644 --- a/experiments/shallow2deep/em-mitochondria/README.md +++ b/experiments/shallow2deep/em-mitochondria/README.md @@ -1,8 +1,28 @@ -# Shallow2Deep for mitochondria +# Shallow2Deep for Mitochondria in EM ## Evaluation -Evaluation of different shallow2deep setups on EM-Mitochondria. All scores are measured with a soft dice score. +Evaluation of different shallow2deep setups for mitochondria segmentation in EM. +The enhancers are (potentially) trained on multiple datasets, evaluation is always on the EPFL dataset (which is ofc not part of the training set). +All scores are measured with a soft dice score. + + +### V4 + +- 2d enhancer: trained on mito-em and vnc +- anisotropic enhancer: random forests are trained in 2d, enhancer trained in 3d, trained on mito-em +- direct-nets: 2d and 3d networks trained on mito-em +- different strategies for training the initial rfs: + - `vanilla`: random forests are trained on randomly sampled dense patches + - `worst_points`: initial stage of forests (25 forests) are trained on random samples, forests in the next stages add worst predictions from prev. 
stage to their training set + - `uncertain_worst_points`: same as `worst_points`, but points are selected based on linear combination of uncertainty and worst predictions + +a + + +## Old evaluation + +Evaluation of older set-ups. ### V1 diff --git a/experiments/shallow2deep/em-mitochondria/evaluation.py b/experiments/shallow2deep/em-mitochondria/evaluation.py index 55be6e23..094bc1a1 100644 --- a/experiments/shallow2deep/em-mitochondria/evaluation.py +++ b/experiments/shallow2deep/em-mitochondria/evaluation.py @@ -1,219 +1,251 @@ +import json import os +from glob import glob import bioimageio.core import numpy as np +import pandas as pd + from elf.io import open_file from elf.evaluation import dice_score -from sklearn.metrics import f1_score -from torch_em.shallow2deep import evaluate_enhancers - - -# make cut-outs from mito-em for ilastik training and evaluation -def prepare_eval_v1(): - out_folder = "/g/kreshuk/pape/Work/data/mito_em/data/crops" - os.makedirs(out_folder, exist_ok=True) - - train_bb = np.s_[:50, :1024, :1024] - test_bb = np.s_[50:, -1024:, -1024:] - - input_path = "/scratch/pape/mito-em/human_val.n5" - with open_file(input_path, "r") as f: - dsr = f["raw"] - dsr.n_threads = 8 - raw_train, raw_test = dsr[train_bb], dsr[test_bb] - - dsl = f["labels"] - dsl.n_threads = 8 - labels_train, labels_test = dsl[train_bb], dsl[test_bb] +from ilastik.experimental.api import from_project_file +from tqdm import trange, tqdm +from xarray import DataArray - with open_file(os.path.join(out_folder, "crop_train.h5"), "a") as f: - f.create_dataset("raw", data=raw_train, compression="gzip") - f.create_dataset("labels", data=labels_train, compression="gzip") - with open_file(os.path.join(out_folder, "crop_test.h5"), "a") as f: - f.create_dataset("raw", data=raw_test, compression="gzip") - f.create_dataset("labels", data=labels_test, compression="gzip") - - -def prepare_eval_v2(): - in_path = "/g/kreshuk/data/VNC/data_labeled_mito.h5" - out_path = 
"/g/kreshuk/pape/Work/data/isbi/vnc-mitos.h5" - with open_file(in_path, "r") as f: +def prepare_eval_v4(): + import napari + path = "/g/kreshuk/data/epfl/testing.h5" + with open_file(path, "r") as f: raw = f["raw"][:] - labels = f["label"][:] - raw = raw.astype("float32") / 255.0 - with open_file(out_path, "a") as f: - f.create_dataset("raw", data=raw, compression="gzip") - f.create_dataset("labels", data=labels, compression="gzip") + label = f["label"][:] + v = napari.Viewer() + v.add_image(raw) + v.add_labels(label) + napari.run() def dice_metric(pred, label): - assert pred.shape[2:] == label.shape - return dice_score(pred[0, 0], label, threshold_seg=None) - - -def f1_metric(pred, label): - assert pred.shape[2:] == label.shape - return f1_score(label.ravel() > 0, pred[0, 0].ravel() > 0.5) - - -def _evaluation( - data_path, rfs, enhancers, rf_channel, save_path, metric=dice_metric, raw_key="raw", label_key="labels", is2d=True -): - with open_file(data_path, "r") as f: - raw = f[raw_key][:] - labels = f[label_key][:] - if is2d: - prediction_function = None + if pred.ndim == 4: + pred = pred[0] + assert pred.shape == label.shape + return dice_score(pred, label, threshold_seg=None) + + +def require_rfs(data_path, rfs, save_path): + # check if we need to run any of the predictions + with open_file(save_path, "a") as f_save: + if all(name in f_save for name in rfs): + return + + with open_file(data_path, "r") as f: + data = f["raw"][:] + data = DataArray(data, dims=tuple("zyx")) + + for name, ilp_path in rfs.items(): + if name in f_save: + continue + print("Run prediction for ILP", name, ":", ilp_path, "...") + assert os.path.exists(ilp_path) + ilp = from_project_file(ilp_path) + pred = ilp.predict(data).values[..., 1] + assert pred.shape == data.shape + f_save.create_dataset(name, data=pred, compression="gzip") + + +def require_enhancers_2d(rfs, enhancers, save_path): + with open_file(save_path, "a") as f: + rf_data = {} + for enhancer_name, enhancer_path in 
enhancers.items(): + save_names = [f"{enhancer_name}-{rf_name}" for rf_name in rfs] + if all(name in f for name in save_names): + continue + enhancer = bioimageio.core.load_resource_description(enhancer_path) + with bioimageio.core.create_prediction_pipeline(enhancer) as pp: + for rf_name in rfs: + save_name = f"{enhancer_name}-{rf_name}" + if save_name in f: + continue + if rf_name not in rf_data: + rf_data[rf_name] = f[rf_name][:] + rf_pred = rf_data[rf_name] + pred = np.zeros((2,) + rf_pred.shape, dtype="float32") + for z in trange(rf_pred.shape[0], desc=f"Run prediction for {enhancer_name}-{rf_name}"): + inp = DataArray(rf_pred[z][None, None], dims=tuple("bcyx")) + predz = pp(inp)[0].values[0] + pred[:, z] = predz + f.create_dataset(save_name, data=pred, compression="gzip") + + +def require_enhancers_3d(rfs, enhancers, save_path): + tiling = { + "tile": {"z": 32, "y": 256, "x": 256}, + "halo": {"z": 4, "y": 32, "x": 32} + } + with open_file(save_path, "a") as f: + rf_data = {} + for enhancer_name, enhancer_path in enhancers.items(): + save_names = [f"{enhancer_name}-{rf_name}" for rf_name in rfs] + if all(name in f for name in save_names): + continue + enhancer = bioimageio.core.load_resource_description(enhancer_path) + with bioimageio.core.create_prediction_pipeline(enhancer) as pp: + for rf_name in rfs: + save_name = f"{enhancer_name}-{rf_name}" + if save_name in f: + continue + if rf_name not in rf_data: + rf_data[rf_name] = f[rf_name][:] + rf_pred = rf_data[rf_name] + inp = DataArray(rf_pred[None, None], dims=tuple("bczyx")) + pred = bioimageio.core.predict_with_tiling(pp, inp, tiling=tiling, verbose=True)[0].values[0] + f.create_dataset(save_name, data=pred, compression="gzip") + + +def require_net_2d(data_path, model_path, model_name, save_path): + with open_file(save_path, "a") as f_save: + if model_name in f_save: + return + model = bioimageio.core.load_resource_description(model_path) + with open_file(data_path, "r") as f: + raw = f["raw"][:] + + pred 
= np.zeros((2,) + raw.shape, dtype="float32") + with bioimageio.core.create_prediction_pipeline(model) as pp: + for z in trange(raw.shape[0], desc=f"Run prediction for model {model_name}"): + inp = DataArray(raw[z][None, None], dims=tuple("bcyx")) + pred[:, z] = pp(inp)[0].values[0] + f_save.create_dataset(model_name, data=pred, compression="gzip") + + +def require_net_3d(data_path, model_path, model_name, save_path): + tiling = { + "tile": {"z": 32, "y": 256, "x": 256}, + "halo": {"z": 4, "y": 32, "x": 32} + } + with open_file(save_path, "a") as f_save: + if model_name in f_save: + return + model = bioimageio.core.load_resource_description(model_path) + with open_file(data_path, "r") as f: + raw = f["raw"][:] + + pred = np.zeros((2,) + raw.shape, dtype="float32") + with bioimageio.core.create_prediction_pipeline(model) as pp: + inp = DataArray(raw[None, None], dims=tuple("bczyx")) + pred = bioimageio.core.predict_with_tiling(pp, inp, tiling=tiling, verbose=True)[0].values[0] + f_save.create_dataset(model_name, data=pred, compression="gzip") + + +def get_enhancers(root): + names = [os.path.basename(path) for path in glob(os.path.join(root, "s2d-em*"))] + enhancers_2d, enhancers_anisotropic = {}, {} + for name in names: + parts = name.split("-") + dim = parts[-2] + rf = parts[-1] + path = os.path.join(root, name, f"{name}.zip") + assert os.path.exists(path) + if dim == "anisotropic": + enhancers_anisotropic[f"{dim}-{rf}"] = path + elif dim == "2d": + enhancers_2d[f"{dim}-{rf}"] = path + assert len(enhancers_2d) > 0 + assert len(enhancers_anisotropic) > 0 + return enhancers_2d, enhancers_anisotropic + + +def run_evaluation(data_path, save_path, eval_path): + if os.path.exists(eval_path): + with open(save_path, "r") as f: + scores = json.load(f) else: - prediction_function = bioimageio.core.predict_with_tiling - scores = evaluate_enhancers( - raw, labels, enhancers, rfs, - metric=metric, is2d=is2d, rf_channel=rf_channel, save_path=save_path, - 
prediction_function=prediction_function - ) - return scores - - -def _direct_evaluation(data_path, model_path, save_path, raw_key="raw", label_key="labels", metric=dice_metric): - import bioimageio.core - import xarray - from tqdm import trange + scores = {} - model = bioimageio.core.load_resource_description(model_path) with open_file(data_path, "r") as f: - raw, labels = f[raw_key][:], f[label_key][:] - scores = [] - - save_key = "direct_predictions" - with open_file(save_path, "a") as f: - if save_key in f: - pred = f[save_key][:] - else: - with bioimageio.core.create_prediction_pipeline(model) as pp: - pred = [] - for z in trange(raw.shape[0]): - inp = xarray.DataArray(raw[z][None, None], dims=tuple("bcyx")) - predz = pp(inp)[0].values - pred.append(predz[None]) - pred = np.concatenate(pred) - f.create_dataset(save_key, data=pred, compression="gzip") + labels = f["label"][:] - for z in range(raw.shape[0]): - scores.append(metric(pred[z], labels[z])) + with open_file(save_path, "r") as f: + for name, ds in tqdm(f.items(), total=len(f), desc="Run evaluation"): + if name in scores: + continue + pred = ds[:] + score = dice_metric(pred, labels) + scores[name] = float(score) + return scores - return np.mean(scores) +def to_table(scores): + pass -def _direct_evaluation3d(data_path, model_path, save_path, raw_key="raw", label_key="labels", metric=dice_metric): - import xarray - model = bioimageio.core.load_resource_description(model_path) - with open_file(data_path, "r") as f: - raw, labels = f[raw_key][:], f[label_key][:] +def evaluation_v4(): + data_path = "/g/kreshuk/pape/Work/data/group_data/epfl/testing.h5" + rf_folder = "/g/kreshuk/pape/Work/data/epfl/ilastik-projects" + save_path = "./bio-models/v4/prediction.h5" - save_key = "direct_predictions" - with open_file(save_path, "a") as f: - if save_key in f: - pred = f[save_key][:] - else: - with bioimageio.core.create_prediction_pipeline(model) as pp: - inp = xarray.DataArray(raw[None, None], dims=tuple("bczyx")) - 
pred = bioimageio.core.predict_with_tiling(pp, inp, verbose=True) - pred = pred[0].values - f.create_dataset(save_key, data=pred, compression="gzip") - - score = metric(pred, labels) - return score - - -def evaluation_v1(): - data_root = "/g/kreshuk/pape/Work/data/mito_em/data/crops" - data_path = os.path.join(data_root, "crop_test.h5") - rfs = { - "few-labels": os.path.join(data_root, "rfs", "rf1.ilp"), - "many-labels": os.path.join(data_root, "rfs", "rf3.ilp"), - } - enhancers = { - "vanilla-enhancer": "./bio-models/v1/EnhancerMitochondriaEM2D/EnhancerMitochondriaEM2D.zip", - "advanced-enhancer": "./bio-models/v1/EnhancerMitochondriaEM2D-advanced-traing/EnhancerMitochondriaEM2D.zip", - } - save_path = "./bio-models/v1/prediction.h5" - scores = _evaluation(data_path, rfs, enhancers, rf_channel=1, save_path=save_path) - - model_path = "./bio-models/v1/DirectModel/mitchondriaemsegmentation2d_pytorch_state_dict.zip" - score_raw = _direct_evaluation(data_path, model_path, save_path) + # rfs = { + # "few-labels": os.path.join(rf_folder, "2d-1.ilp"), + # "medium-labels": os.path.join(rf_folder, "2d-2.ilp"), + # "many-labels": os.path.join(rf_folder, "2d-3.ilp"), + # } + # require_rfs(data_path, rfs, save_path) - enhancers = { - "direct-net": "./bio-models/v1/DirectModel/mitchondriaemsegmentation2d_pytorch_state_dict.zip", - } - save_path = "./bio-models/v2/prediction-direct.h5" - scores_direct = _evaluation(data_path, rfs, enhancers, rf_channel=0, save_path=save_path) - scores = scores.append(scores_direct.iloc[0]) - - print("Evaluation results:") - print(scores.to_markdown()) - print("Raw net evaluation:", score_raw) + # enhancers_2d, enhancers_anisotropic = get_enhancers("./bio-models/v4") + # require_enhancers_2d(rfs, enhancers_2d, save_path) + # require_enhancers_3d(rfs, enhancers_anisotropic, save_path) + # net2d = "./bio-models/v2/DirectModel/MitchondriaEMSegmentation2D.zip" + # require_net_2d(data_path, net2d, "direct2d", save_path) + # net3d = 
"./bio-models/v3/DirectModel/mitochondriaemsegmentationboundarymodel_pytorch_state_dict.zip" + # require_net_3d(data_path, net3d, "direct3d", save_path) -def evaluation_v2(): - data_path = "/g/kreshuk/pape/Work/data/isbi/vnc-mitos.h5" - rf_folder = "/g/kreshuk/pape/Work/data/vnc/ilps" + # for debugging rfs = { - "few-labels": os.path.join(rf_folder, "vnc-mito1.ilp"), - "medium-labels": os.path.join(rf_folder, "vnc-mito3.ilp"), - "many-labels": os.path.join(rf_folder, "vnc-mito6.ilp"), - } - enhancers = { - "vanilla-enhancer": "./bio-models/v2/EnhancerMitochondriaEM2D/EnhancerMitochondriaEM2D.zip", - "advanced-enhancer": "./bio-models/v2/EnhancerMitochondriaEM2D-advanced-traing/EnhancerMitochondriaEM2D.zip", + "many-labels": os.path.join(rf_folder, "2d-3.ilp"), } - save_path = "./bio-models/v2/prediction.h5" - scores = _evaluation(data_path, rfs, enhancers, rf_channel=1, save_path=save_path) - - model_path = "./bio-models/v2/DirectModel/MitchondriaEMSegmentation2D.zip" - score_raw = _direct_evaluation(data_path, model_path, save_path) + require_rfs(data_path, rfs, save_path) - enhancers = { - "direct-net": "./bio-models/v2/DirectModel/MitchondriaEMSegmentation2D.zip", + enhancers_2d = { + "enhancer": "./bio-models/v2/EnhancerMitochondriaEM2D-advanced-traing/EnhancerMitochondriaEM2D.zip" } - save_path = "./bio-models/v2/prediction-direct.h5" - scores_direct = _evaluation(data_path, rfs, enhancers, rf_channel=0, save_path=save_path) - scores = scores.append(scores_direct.iloc[0]) + require_enhancers_2d(rfs, enhancers_2d, save_path) + return + eval_path = "./bio-models/v4/eval.json" + scores = run_evaluation(data_path, save_path, eval_path) + scores = to_table(scores) print("Evaluation results:") print(scores.to_markdown()) - print("Raw net evaluation:", score_raw) -def evaluation_v3(): - data_path = "/g/kreshuk/pape/Work/data/isbi/vnc-mitos.h5" - rf_folder = "/g/kreshuk/pape/Work/data/vnc/ilps3d" +def debug_v4(): + import napari + data_path = 
"/g/kreshuk/pape/Work/data/group_data/epfl/testing.h5" + save_path = "./bio-models/v4/prediction.h5" - rfs = { - "few-labels": os.path.join(rf_folder, "vnc-mito1.ilp"), - "medium-labels": os.path.join(rf_folder, "vnc-mito2.ilp"), - "many-labels": os.path.join(rf_folder, "vnc-mito3.ilp"), - } - enhancers = { - "vanilla-enhancer": "./bio-models/v3/EnhancerMitochondriaEM3D/EnhancerMitochondriaEM3D.zip", - "advanced-enhancer": "./bio-models/v3/EnhancerMitochondriaEM3D-advanced-traing/EnhancerMitochondriaEM3D.zip", - } - save_path = "./bio-models/v3/prediction.h5" - scores = _evaluation(data_path, rfs, enhancers, rf_channel=1, save_path=save_path, is2d=False) + print("Load data") + with open_file(data_path, "r") as f: + data = f["raw"][:] + labels = f["label"][:] - model_path = "./bio-models/v3/DirectModel/mitochondriaemsegmentationboundarymodel_pytorch_state_dict.zip" - score_raw = _direct_evaluation3d(data_path, model_path, save_path) + with open_file(save_path, "r") as f: + preds = {} + for name, ds in tqdm(f.items(), total=len(f)): + if ("labels" in name) and ("many" not in name): + continue + preds[name] = ds[:] - print("Evaluation results:") - print(scores.to_markdown()) - print("Raw net evaluation:", score_raw) + print("Start viewer") + v = napari.Viewer() + v.add_image(data) + v.add_labels(labels) + for name, pred in preds.items(): + v.add_image(pred, name=name) + napari.run() if __name__ == "__main__": - # prepare_eval_v1() - # prepare_eval_v2() - - # evaluation_v1() - # evaluation_v2() - evaluation_v3() + # prepare_eval_v4() + # evaluation_v4() + debug_v4() diff --git a/experiments/shallow2deep/em-mitochondria/old_evaluation.py b/experiments/shallow2deep/em-mitochondria/old_evaluation.py new file mode 100644 index 00000000..7d9572dc --- /dev/null +++ b/experiments/shallow2deep/em-mitochondria/old_evaluation.py @@ -0,0 +1,238 @@ +import os + +import bioimageio.core +import numpy as np +import pandas as pd +from elf.io import open_file +from elf.evaluation 
import dice_score +from sklearn.metrics import f1_score +from torch_em.shallow2deep import evaluate_enhancers + + +# make cut-outs from mito-em for ilastik training and evaluation +def prepare_eval_v1(): + out_folder = "/g/kreshuk/pape/Work/data/mito_em/data/crops" + os.makedirs(out_folder, exist_ok=True) + + train_bb = np.s_[:50, :1024, :1024] + test_bb = np.s_[50:, -1024:, -1024:] + + input_path = "/scratch/pape/mito-em/human_val.n5" + with open_file(input_path, "r") as f: + dsr = f["raw"] + dsr.n_threads = 8 + raw_train, raw_test = dsr[train_bb], dsr[test_bb] + + dsl = f["labels"] + dsl.n_threads = 8 + labels_train, labels_test = dsl[train_bb], dsl[test_bb] + + with open_file(os.path.join(out_folder, "crop_train.h5"), "a") as f: + f.create_dataset("raw", data=raw_train, compression="gzip") + f.create_dataset("labels", data=labels_train, compression="gzip") + + with open_file(os.path.join(out_folder, "crop_test.h5"), "a") as f: + f.create_dataset("raw", data=raw_test, compression="gzip") + f.create_dataset("labels", data=labels_test, compression="gzip") + + +def prepare_eval_v2(): + in_path = "/g/kreshuk/data/VNC/data_labeled_mito.h5" + out_path = "/g/kreshuk/pape/Work/data/isbi/vnc-mitos.h5" + with open_file(in_path, "r") as f: + raw = f["raw"][:] + labels = f["label"][:] + raw = raw.astype("float32") / 255.0 + with open_file(out_path, "a") as f: + f.create_dataset("raw", data=raw, compression="gzip") + f.create_dataset("labels", data=labels, compression="gzip") + + +def dice_metric(pred, label): + assert pred.shape[2:] == label.shape + return dice_score(pred[0, 0], label, threshold_seg=None) + + +def f1_metric(pred, label): + assert pred.shape[2:] == label.shape + return f1_score(label.ravel() > 0, pred[0, 0].ravel() > 0.5) + + +def _evaluation( + data_path, rfs, enhancers, rf_channel, save_path, metric=dice_metric, raw_key="raw", label_key="labels", is2d=True +): + with open_file(data_path, "r") as f: + raw = f[raw_key][:] + labels = f[label_key][:] + if is2d: 
+ prediction_function = None + else: + prediction_function = bioimageio.core.predict_with_tiling + scores = evaluate_enhancers( + raw, labels, enhancers, rfs, + metric=metric, is2d=is2d, rf_channel=rf_channel, save_path=save_path, + prediction_function=prediction_function + ) + return scores + + +def _direct_evaluation(data_path, model_path, save_path, raw_key="raw", label_key="labels", metric=dice_metric): + import bioimageio.core + import xarray + from tqdm import trange + + model = bioimageio.core.load_resource_description(model_path) + with open_file(data_path, "r") as f: + raw, labels = f[raw_key][:], f[label_key][:] + scores = [] + + save_key = "direct_predictions" + with open_file(save_path, "a") as f: + if save_key in f: + pred = f[save_key][:] + else: + with bioimageio.core.create_prediction_pipeline(model) as pp: + pred = [] + for z in trange(raw.shape[0]): + inp = xarray.DataArray(raw[z][None, None], dims=tuple("bcyx")) + predz = pp(inp)[0].values + pred.append(predz[None]) + pred = np.concatenate(pred) + f.create_dataset(save_key, data=pred, compression="gzip") + + for z in range(raw.shape[0]): + scores.append(metric(pred[z], labels[z])) + + return np.mean(scores) + + +def _direct_evaluation3d(data_path, model_path, save_path, raw_key="raw", label_key="labels", metric=dice_metric): + import xarray + + model = bioimageio.core.load_resource_description(model_path) + with open_file(data_path, "r") as f: + raw, labels = f[raw_key][:], f[label_key][:] + + save_key = "direct_predictions" + with open_file(save_path, "a") as f: + if save_key in f: + pred = f[save_key][:] + else: + with bioimageio.core.create_prediction_pipeline(model) as pp: + inp = xarray.DataArray(raw[None, None], dims=tuple("bczyx")) + pred = bioimageio.core.predict_with_tiling(pp, inp, verbose=True) + pred = pred[0].values + f.create_dataset(save_key, data=pred, compression="gzip") + + score = metric(pred, labels) + return score + + +def evaluation_v1(): + data_root = 
"/g/kreshuk/pape/Work/data/mito_em/data/crops" + data_path = os.path.join(data_root, "crop_test.h5") + rfs = { + "few-labels": os.path.join(data_root, "rfs", "rf1.ilp"), + "many-labels": os.path.join(data_root, "rfs", "rf3.ilp"), + } + enhancers = { + "vanilla-enhancer": "./bio-models/v1/EnhancerMitochondriaEM2D/EnhancerMitochondriaEM2D.zip", + "advanced-enhancer": "./bio-models/v1/EnhancerMitochondriaEM2D-advanced-traing/EnhancerMitochondriaEM2D.zip", + } + save_path = "./bio-models/v1/prediction.h5" + scores = _evaluation(data_path, rfs, enhancers, rf_channel=1, save_path=save_path) + + model_path = "./bio-models/v1/DirectModel/mitchondriaemsegmentation2d_pytorch_state_dict.zip" + score_raw = _direct_evaluation(data_path, model_path, save_path) + + enhancers = { + "direct-net": "./bio-models/v1/DirectModel/mitchondriaemsegmentation2d_pytorch_state_dict.zip", + } + save_path = "./bio-models/v2/prediction-direct.h5" + scores_direct = _evaluation(data_path, rfs, enhancers, rf_channel=0, save_path=save_path) + scores = scores.append(scores_direct.iloc[0]) + + print("Evaluation results:") + print(scores.to_markdown()) + print("Raw net evaluation:", score_raw) + + +def evaluation_v2(): + data_path = "/g/kreshuk/pape/Work/data/isbi/vnc-mitos.h5" + rf_folder = "/g/kreshuk/pape/Work/data/vnc/ilps" + rfs = { + "few-labels": os.path.join(rf_folder, "vnc-mito1.ilp"), + "medium-labels": os.path.join(rf_folder, "vnc-mito3.ilp"), + "many-labels": os.path.join(rf_folder, "vnc-mito6.ilp"), + } + enhancers = { + "vanilla-enhancer": "./bio-models/v2/EnhancerMitochondriaEM2D/EnhancerMitochondriaEM2D.zip", + "advanced-enhancer": "./bio-models/v2/EnhancerMitochondriaEM2D-advanced-traing/EnhancerMitochondriaEM2D.zip", + } + save_path = "./bio-models/v2/prediction.h5" + scores = _evaluation(data_path, rfs, enhancers, rf_channel=1, save_path=save_path) + + model_path = "./bio-models/v2/DirectModel/MitchondriaEMSegmentation2D.zip" + score_raw = _direct_evaluation(data_path, model_path, 
save_path) + + enhancers = { + "direct-net": "./bio-models/v2/DirectModel/MitchondriaEMSegmentation2D.zip", + } + save_path = "./bio-models/v2/prediction-direct.h5" + scores_direct = _evaluation(data_path, rfs, enhancers, rf_channel=0, save_path=save_path) + scores = scores.append(scores_direct.iloc[0]) + + print("Evaluation results:") + print(scores.to_markdown()) + print("Raw net evaluation:", score_raw) + + +def evaluation_v3(): + data_path = "/g/kreshuk/pape/Work/data/isbi/vnc-mitos.h5" + rf_folder = "/g/kreshuk/pape/Work/data/vnc/ilps3d" + + rfs = { + "few-labels": os.path.join(rf_folder, "vnc-mito1.ilp"), + "medium-labels": os.path.join(rf_folder, "vnc-mito2.ilp"), + "many-labels": os.path.join(rf_folder, "vnc-mito3.ilp"), + } + enhancers = { + "vanilla-enhancer": "./bio-models/v3/EnhancerMitochondriaEM3D/EnhancerMitochondriaEM3D.zip", + "advanced-enhancer": "./bio-models/v3/EnhancerMitochondriaEM3D-advanced-traing/EnhancerMitochondriaEM3D.zip", + } + save_path = "./bio-models/v3/prediction.h5" + scores = _evaluation(data_path, rfs, enhancers, rf_channel=1, save_path=save_path, is2d=False) + + model_path = "./bio-models/v3/DirectModel/mitochondriaemsegmentationboundarymodel_pytorch_state_dict.zip" + score_raw = _direct_evaluation3d(data_path, model_path, save_path) + + print("Evaluation results:") + print(scores.to_markdown()) + print("Raw net evaluation:", score_raw) + + +def get_enhancers(root): + names = os.listdir(root) + enhancers_2d, enhancers_anisotropic = {}, {} + for name in names: + parts = name.split("-") + dim = parts[-2] + rf = parts[-1] + path = os.path.join(root, name, f"{name}.zip") + assert os.path.exists(path) + if dim == "anisotropic": + enhancers_anisotropic[f"{dim}-{rf}"] = path + elif dim == "2d": + enhancers_2d[f"{dim}-{rf}"] = path + assert len(enhancers_2d) > 0 + assert len(enhancers_anisotropic) > 0 + return enhancers_2d, enhancers_anisotropic + + +if __name__ == "__main__": + # prepare_eval_v1() + # prepare_eval_v2() + + # 
evaluation_v1() + # evaluation_v2() + evaluation_v3() diff --git a/experiments/shallow2deep/em-mitochondria/train_mito_2d.py b/experiments/shallow2deep/em-mitochondria/train_mito_2d.py index 37e3443b..6533d111 100644 --- a/experiments/shallow2deep/em-mitochondria/train_mito_2d.py +++ b/experiments/shallow2deep/em-mitochondria/train_mito_2d.py @@ -143,6 +143,19 @@ def train_shallow2deep(args): trainer.fit(args.n_iterations) +def check(args, train=True, val=True, n_images=2): + from torch_em.util.debug import check_loader + datasets = normalize_datasets(args.datasets) + if train: + print("Check train loader") + loader = get_loader(args, "train", datasets) + check_loader(loader, n_images) + if val: + print("Check val loader") + loader = get_loader(args, "val", datasets) + check_loader(loader, n_images) + + if __name__ == "__main__": parser = torch_em.util.parser_helper(require_input=False, default_batch_size=4) parser.add_argument("--datasets", "-d", nargs="+", default=DATASETS) @@ -150,4 +163,7 @@ def train_shallow2deep(args): parser.add_argument("--n_threads", type=int, default=32) parser.add_argument("--sampling_strategy", "-s", default=None) args = parser.parse_args() - train_shallow2deep(args) + if args.check: + check(args, n_images=3) + else: + train_shallow2deep(args) diff --git a/experiments/shallow2deep/em-mitochondria/visualize_rfs.py b/experiments/shallow2deep/em-mitochondria/visualize_rfs.py index 67654073..2e7b998e 100644 --- a/experiments/shallow2deep/em-mitochondria/visualize_rfs.py +++ b/experiments/shallow2deep/em-mitochondria/visualize_rfs.py @@ -1,21 +1,37 @@ import argparse -import h5py +import os + +from elf.io import open_file from torch_em.shallow2deep import visualize_pretrained_rfs from torch_em.transform.raw import normalize +ROOT = "/scratch/pape/s2d-mitochondria" + + def visualize_rfs(): parser = argparse.ArgumentParser() - parser.add_argument("-c", "--checkpoint", required=True) + parser.add_argument("-d", "--dataset", default="mitoem") + 
parser.add_argument("-n", "--rf_name", default="rfs2d") args = parser.parse_args() - n_forests = 24 - raw_path = "/scratch/pape/vnc/vnc_test.h5" - with h5py.File(raw_path, "r") as f: - raw = f["raw"][10] + dataset = args.dataset + assert dataset in ("mitoem", "vnc") + if dataset == "mitoem": + raw_path = os.path.join(ROOT, dataset, "human_test.n5") + else: + raw_path = os.path.join(ROOT, dataset, "vnc_test.n5") + assert os.path.exists(raw_path), raw_path + + rf_folder = os.path.join(ROOT, args.rf_name, dataset) + assert os.path.exists(rf_folder), rf_folder + + with open_file(raw_path, "r") as f: + raw = f["raw"][0, :1024, :1024] raw = normalize(raw) - visualize_pretrained_rfs(args.checkpoint, raw, n_forests, n_threads=8) + n_forests = 24 + visualize_pretrained_rfs(rf_folder, raw, n_forests, n_threads=8) if __name__ == "__main__": diff --git a/torch_em/shallow2deep/shallow2deep_eval.py b/torch_em/shallow2deep/shallow2deep_eval.py index b472c564..c76e85d8 100644 --- a/torch_em/shallow2deep/shallow2deep_eval.py +++ b/torch_em/shallow2deep/shallow2deep_eval.py @@ -27,9 +27,12 @@ def visualize_pretrained_rfs(checkpoint, raw, n_forests, """ import napari - rf_folder = os.path.join(checkpoint, "rfs") - assert os.path.exists(rf_folder), rf_folder - rf_paths = glob(os.path.join(rf_folder, "*.pkl")) + rf_paths = glob(os.path.join(checkpoint, "*.pkl")) + if len(rf_paths) == 0: + rf_folder = os.path.join(checkpoint, "rfs") + assert os.path.exists(rf_folder), rf_folder + rf_paths = glob(os.path.join(rf_folder, "*.pkl")) + assert len(rf_paths) > 0 rf_paths.sort() if sample_random: rf_paths = np.random.choice(rf_paths, size=n_forests) diff --git a/torch_em/util/debug.py b/torch_em/util/debug.py index 9b192572..61eda404 100644 --- a/torch_em/util/debug.py +++ b/torch_em/util/debug.py @@ -79,10 +79,10 @@ def _check_napari(loader, n_samples, instance_labels, model=None, device=None): pred = None else: pred = model(x if device is None else x.to(device)) - pred = 
ensure_array(pred).squeeze(0) + pred = ensure_array(pred)[0] - x = ensure_array(x).squeeze(0) - y = ensure_array(y).squeeze(0) + x = ensure_array(x)[0] + y = ensure_array(y)[0] v = napari.Viewer() v.add_image(x) From a9933e9ced9c00972e098ad6dfa20dbc16fcc9d6 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Tue, 19 Jul 2022 13:09:14 +0200 Subject: [PATCH 07/31] Fix several issues in new mitochondria s2d training --- .../em-mitochondria/evaluation.py | 42 +++++++------------ .../em-mitochondria/train_mito_2d.py | 25 +++++------ .../em-mitochondria/train_mito_anisotropic.py | 37 ++++++++++------ 3 files changed, 51 insertions(+), 53 deletions(-) diff --git a/experiments/shallow2deep/em-mitochondria/evaluation.py b/experiments/shallow2deep/em-mitochondria/evaluation.py index 094bc1a1..724717ca 100644 --- a/experiments/shallow2deep/em-mitochondria/evaluation.py +++ b/experiments/shallow2deep/em-mitochondria/evaluation.py @@ -142,14 +142,13 @@ def get_enhancers(root): enhancers_2d, enhancers_anisotropic = {}, {} for name in names: parts = name.split("-") - dim = parts[-2] - rf = parts[-1] + sampling_strategy, dim = parts[-1], parts[-2] path = os.path.join(root, name, f"{name}.zip") assert os.path.exists(path) if dim == "anisotropic": - enhancers_anisotropic[f"{dim}-{rf}"] = path + enhancers_anisotropic[f"{dim}-{sampling_strategy}"] = path elif dim == "2d": - enhancers_2d[f"{dim}-{rf}"] = path + enhancers_2d[f"{dim}-{sampling_strategy}"] = path assert len(enhancers_2d) > 0 assert len(enhancers_anisotropic) > 0 return enhancers_2d, enhancers_anisotropic @@ -175,6 +174,7 @@ def run_evaluation(data_path, save_path, eval_path): return scores +# TODO def to_table(scores): pass @@ -184,33 +184,21 @@ def evaluation_v4(): rf_folder = "/g/kreshuk/pape/Work/data/epfl/ilastik-projects" save_path = "./bio-models/v4/prediction.h5" - # rfs = { - # "few-labels": os.path.join(rf_folder, "2d-1.ilp"), - # "medium-labels": os.path.join(rf_folder, "2d-2.ilp"), - # "many-labels": 
os.path.join(rf_folder, "2d-3.ilp"), - # } - # require_rfs(data_path, rfs, save_path) - - # enhancers_2d, enhancers_anisotropic = get_enhancers("./bio-models/v4") - # require_enhancers_2d(rfs, enhancers_2d, save_path) - # require_enhancers_3d(rfs, enhancers_anisotropic, save_path) - - # net2d = "./bio-models/v2/DirectModel/MitchondriaEMSegmentation2D.zip" - # require_net_2d(data_path, net2d, "direct2d", save_path) - # net3d = "./bio-models/v3/DirectModel/mitochondriaemsegmentationboundarymodel_pytorch_state_dict.zip" - # require_net_3d(data_path, net3d, "direct3d", save_path) - - # for debugging rfs = { + "few-labels": os.path.join(rf_folder, "2d-1.ilp"), + "medium-labels": os.path.join(rf_folder, "2d-2.ilp"), "many-labels": os.path.join(rf_folder, "2d-3.ilp"), } require_rfs(data_path, rfs, save_path) - enhancers_2d = { - "enhancer": "./bio-models/v2/EnhancerMitochondriaEM2D-advanced-traing/EnhancerMitochondriaEM2D.zip" - } + enhancers_2d, enhancers_anisotropic = get_enhancers("./bio-models/v4") require_enhancers_2d(rfs, enhancers_2d, save_path) - return + require_enhancers_3d(rfs, enhancers_anisotropic, save_path) + + net2d = "./bio-models/v2/DirectModel/MitchondriaEMSegmentation2D.zip" + require_net_2d(data_path, net2d, "direct2d", save_path) + net3d = "./bio-models/v3/DirectModel/mitochondriaemsegmentationboundarymodel_pytorch_state_dict.zip" + require_net_3d(data_path, net3d, "direct3d", save_path) eval_path = "./bio-models/v4/eval.json" scores = run_evaluation(data_path, save_path, eval_path) @@ -247,5 +235,5 @@ def debug_v4(): if __name__ == "__main__": # prepare_eval_v4() - # evaluation_v4() - debug_v4() + evaluation_v4() + # debug_v4() diff --git a/experiments/shallow2deep/em-mitochondria/train_mito_2d.py b/experiments/shallow2deep/em-mitochondria/train_mito_2d.py index 6533d111..d134a6c2 100644 --- a/experiments/shallow2deep/em-mitochondria/train_mito_2d.py +++ b/experiments/shallow2deep/em-mitochondria/train_mito_2d.py @@ -42,10 +42,7 @@ def 
require_ds(dataset): def require_rfs_ds(dataset, n_rfs, sampling_strategy): - if sampling_strategy is None: - out_folder = os.path.join(DATA_ROOT, "rfs2d", dataset) - else: - out_folder = os.path.join(DATA_ROOT, f"rfs2d-{sampling_strategy}", dataset) + out_folder = os.path.join(DATA_ROOT, f"rfs2d-{sampling_strategy}", dataset) os.makedirs(out_folder, exist_ok=True) if len(glob(os.path.join(out_folder, "*.pkl"))) == n_rfs: return @@ -68,7 +65,6 @@ def require_rfs_ds(dataset, n_rfs, sampling_strategy): is_seg_dataset=True, ) else: - sampling_strategy = "worst_points" if sampling_strategy is None else sampling_strategy shallow2deep.prepare_shallow2deep_advanced( raw_paths=paths, raw_key=raw_key, label_paths=paths, label_key=label_key, patch_shape_min=patch_shape_min, patch_shape_max=patch_shape_max, @@ -86,6 +82,7 @@ def require_rfs(datasets, n_rfs, sampling_strategy): def get_ds(file_pattern, rf_pattern, n_samples, label_key): + raw_transform = torch_em.transform.raw.normalize label_transform = torch_em.transform.BoundaryTransform(ndim=2, add_binary_target=True) patch_shape = (1, 512, 512) paths = glob(file_pattern) @@ -97,7 +94,9 @@ def get_ds(file_pattern, rf_pattern, n_samples, label_key): raw_key = "raw" return shallow2deep.shallow2deep_dataset.get_shallow2deep_dataset( paths, raw_key, paths, label_key, rf_paths, - patch_shape=patch_shape, label_transform=label_transform, + patch_shape=patch_shape, + raw_transform=raw_transform, + label_transform=label_transform, n_samples=n_samples, ndim=2, ) @@ -108,12 +107,12 @@ def get_loader(args, split, dataset_names): if "mitoem" in dataset_names: ds_name = "mitoem" file_pattern = os.path.join(DATA_ROOT, ds_name, f"*_{split}.n5") - rf_pattern = os.path.join(DATA_ROOT, "rfs2d", ds_name, "*.pkl") + rf_pattern = os.path.join(DATA_ROOT, f"rfs2d-{args.sampling_strategy}", ds_name, "*.pkl") datasets.append(get_ds(file_pattern, rf_pattern, n_samples, label_key="labels")) if "vnc" in dataset_names and split == "train": ds_name = 
"vnc" - file_pattern = os.path.join(DATA_ROOT, ds_name, f"vnc_{split}.h5") - rf_pattern = os.path.join(DATA_ROOT, "rfs2d", ds_name, "*.pkl") + file_pattern = os.path.join(DATA_ROOT, ds_name, f"*_{split}.h5") + rf_pattern = os.path.join(DATA_ROOT, f"rfs2d-{args.sampling_strategy}", ds_name, "*.pkl") datasets.append(get_ds(file_pattern, rf_pattern, n_samples, label_key="labels/mitochondria")) ds = torch_em.data.concat_dataset.ConcatDataset(*datasets) if len(datasets) > 1 else datasets[0] loader = torch.utils.data.DataLoader( @@ -125,9 +124,7 @@ def get_loader(args, split, dataset_names): def train_shallow2deep(args): datasets = normalize_datasets(args.datasets) - name = f"s2d-em-mitos-{'_'.join(datasets)}-2d" - if args.sampling_strategy is not None: - name += f"-{args.sampling_strategy}" + name = f"s2d-em-mitos-{'_'.join(datasets)}-2d-{args.sampling_strategy}" require_rfs(datasets, args.n_rfs, args.sampling_strategy) model = UNet2d(in_channels=1, out_channels=2, final_activation="Sigmoid", @@ -161,9 +158,9 @@ def check(args, train=True, val=True, n_images=2): parser.add_argument("--datasets", "-d", nargs="+", default=DATASETS) parser.add_argument("--n_rfs", type=int, default=500) parser.add_argument("--n_threads", type=int, default=32) - parser.add_argument("--sampling_strategy", "-s", default=None) + parser.add_argument("--sampling_strategy", "-s", default="worst_points") args = parser.parse_args() if args.check: - check(args, n_images=3) + check(args, n_images=5, val=False) else: train_shallow2deep(args) diff --git a/experiments/shallow2deep/em-mitochondria/train_mito_anisotropic.py b/experiments/shallow2deep/em-mitochondria/train_mito_anisotropic.py index 9619f767..5a854544 100644 --- a/experiments/shallow2deep/em-mitochondria/train_mito_anisotropic.py +++ b/experiments/shallow2deep/em-mitochondria/train_mito_anisotropic.py @@ -37,10 +37,7 @@ def require_ds(dataset): def require_rfs_ds(dataset, n_rfs, sampling_strategy): - if sampling_strategy is None: - 
out_folder = os.path.join(DATA_ROOT, "rfs2d", dataset) - else: - out_folder = os.path.join(DATA_ROOT, f"rfs2d-{sampling_strategy}", dataset) + out_folder = os.path.join(DATA_ROOT, f"rfs2d-{sampling_strategy}", dataset) os.makedirs(out_folder, exist_ok=True) if len(glob(os.path.join(out_folder, "*.pkl"))) == n_rfs: return @@ -62,7 +59,6 @@ def require_rfs_ds(dataset, n_rfs, sampling_strategy): is_seg_dataset=True, ) else: - sampling_strategy = "worst_points" if sampling_strategy is None else sampling_strategy shallow2deep.prepare_shallow2deep_advanced( raw_paths=paths, raw_key=raw_key, label_paths=paths, label_key=label_key, patch_shape_min=patch_shape_min, patch_shape_max=patch_shape_max, @@ -80,6 +76,7 @@ def require_rfs(datasets, n_rfs, sampling_strategy): def get_ds(file_pattern, rf_pattern, n_samples, label_key): + raw_transform = torch_em.transform.raw.normalize label_transform = torch_em.transform.BoundaryTransform(ndim=3, add_binary_target=True) patch_shape = (32, 256, 256) paths = glob(file_pattern) @@ -91,7 +88,9 @@ def get_ds(file_pattern, rf_pattern, n_samples, label_key): raw_key = "raw" return shallow2deep.shallow2deep_dataset.get_shallow2deep_dataset( paths, raw_key, paths, label_key, rf_paths, - patch_shape=patch_shape, label_transform=label_transform, + patch_shape=patch_shape, + raw_transform=raw_transform, + label_transform=label_transform, n_samples=n_samples, ndim="anisotropic", ) @@ -102,7 +101,7 @@ def get_loader(args, split, dataset_names): if "mitoem" in dataset_names: ds_name = "mitoem" file_pattern = os.path.join(DATA_ROOT, ds_name, f"*_{split}.n5") - rf_pattern = os.path.join(DATA_ROOT, "rfs2d", ds_name, "*.pkl") + rf_pattern = os.path.join(DATA_ROOT, f"rfs2d-{args.sampling_strategy}", ds_name, "*.pkl") datasets.append(get_ds(file_pattern, rf_pattern, n_samples, label_key="labels")) ds = torch_em.data.concat_dataset.ConcatDataset(*datasets) if len(datasets) > 1 else datasets[0] loader = torch.utils.data.DataLoader( @@ -114,9 +113,7 @@ def 
get_loader(args, split, dataset_names): def train_shallow2deep(args): datasets = normalize_datasets(args.datasets) - name = f"s2d-em-mitos-{'_'.join(datasets)}-anisotropic" - if args.sampling_strategy is not None: - name += f"-{args.sampling_strategy}" + name = f"s2d-em-mitos-{'_'.join(datasets)}-anisotropic-{args.sampling_strategy}" require_rfs(datasets, args.n_rfs, args.sampling_strategy) scale_factors = [[1, 2, 2], [1, 2, 2], [2, 2, 2], [2, 2, 2]] @@ -133,11 +130,27 @@ def train_shallow2deep(args): trainer.fit(args.n_iterations) +def check(args, train=True, val=True, n_images=2): + from torch_em.util.debug import check_loader + datasets = normalize_datasets(args.datasets) + if train: + print("Check train loader") + loader = get_loader(args, "train", datasets) + check_loader(loader, n_images) + if val: + print("Check val loader") + loader = get_loader(args, "val", datasets) + check_loader(loader, n_images) + + if __name__ == "__main__": parser = torch_em.util.parser_helper(require_input=False) parser.add_argument("--datasets", "-d", nargs="+", default=DATASETS) parser.add_argument("--n_rfs", type=int, default=500) parser.add_argument("--n_threads", type=int, default=32) - parser.add_argument("--sampling_strategy", "-s", default=None) + parser.add_argument("--sampling_strategy", "-s", default="worst_points") args = parser.parse_args() - train_shallow2deep(args) + if args.check: + check(args, n_images=5, val=False) + else: + train_shallow2deep(args) From 5d84059584d7d12cbccf20149e7eb474689f3fc0 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Tue, 19 Jul 2022 15:05:32 +0200 Subject: [PATCH 08/31] Fix score based sampling --- torch_em/shallow2deep/prepare_shallow2deep.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_em/shallow2deep/prepare_shallow2deep.py b/torch_em/shallow2deep/prepare_shallow2deep.py index 0c123c5b..a86be04a 100644 --- a/torch_em/shallow2deep/prepare_shallow2deep.py +++ 
b/torch_em/shallow2deep/prepare_shallow2deep.py @@ -410,7 +410,8 @@ def _score_based_points( n_samples = int(sample_fraction_per_stage * len(features)) n_samples_class = n_samples // nc for class_id in range(nc): - this_samples = np.argsort(score[labels == class_id])[::-1][:n_samples_class] + class_indices = np.where(labels == class_id)[0] + this_samples = class_indices[np.argsort(score[class_indices])[::-1][:n_samples_class]] samples.append(this_samples) samples = np.concatenate(samples) From 92ac2a5e6ada9110b2eba2cf876c7b9294db75b7 Mon Sep 17 00:00:00 2001 From: JonasHell Date: Wed, 20 Jul 2022 14:47:21 +0200 Subject: [PATCH 09/31] add worst_tiles sampling --- torch_em/shallow2deep/prepare_shallow2deep.py | 95 +++++++++++++++++++ 1 file changed, 95 insertions(+) diff --git a/torch_em/shallow2deep/prepare_shallow2deep.py b/torch_em/shallow2deep/prepare_shallow2deep.py index a86be04a..d334f8e0 100644 --- a/torch_em/shallow2deep/prepare_shallow2deep.py +++ b/torch_em/shallow2deep/prepare_shallow2deep.py @@ -1,10 +1,12 @@ import os +import copy import pickle from concurrent import futures from glob import glob from functools import partial import numpy as np +import torch import torch_em from sklearn.ensemble import RandomForestClassifier from torch_em.segmentation import check_paths, is_segmentation_dataset, samples_to_datasets @@ -429,6 +431,7 @@ def worst_points( forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples=True, + sampling_kwargs={}, ): def score(pred, labels): # labels to one-hot encoding @@ -447,6 +450,7 @@ def uncertain_points( forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples=True, + sampling_kwargs={}, ): def score(pred, labels): assert pred.ndim == 2 @@ -465,6 +469,7 @@ def uncertain_worst_points( sample_fraction_per_stage, accumulate_samples=True, alpha=0.5, + sampling_kwargs={}, ): def score(pred, labels): assert pred.ndim == 2 @@ -504,11 +509,94 @@ def random_points( return features[samples], 
labels[samples] +def worst_tiles( + features, labels, rf_id, + forests, forests_per_stage, + sample_fraction_per_stage, + sampling_kwargs, + accumulate_samples=True, +): + # check inputs + img_shape = sampling_kwargs.get('img_shape', None) + assert len(img_shape) in [2, 3], img_shape + tiles_shape = sampling_kwargs.get('tiles_shape', None) + assert len(tiles_shape) in [2, 3], tiles_shape + + # get the corresponding random forest from the last stage + # and predict with it + last_forest = forests[rf_id - forests_per_stage] + pred = last_forest.predict_proba(features) + + # labels to one-hot encoding + unique, inverse = np.unique(labels, return_inverse=True) + onehot = np.eye(unique.shape[0])[inverse] + + # compute the difference between labels and prediction + diff = np.abs(onehot - pred).sum(axis=1) + assert len(diff) == len(features) + + # reshape diff to image shape and apply convolution with 1-kernel + diff_img = torch.Tensor(diff.reshape(img_shape)[None, None]) + kernel = torch.Tensor(np.ones(tiles_shape)[None, None]) + diff_img_smooth = torch.nn.functional.conv2d(diff_img, weight=kernel) if len(img_shape)==2 \ + else torch.nn.functional.conv3d(diff_img, weight=kernel) + diff_img_smooth = diff_img_smooth.detach().numpy().squeeze() + diff_img_shape = diff_img_smooth.shape + diff_img_smooth = diff_img_smooth.flatten() + + # define lambda functions for better readability + # need to add tiles_shape[i] // 2 to get indices in original img + start = lambda i: 0 # -(tiles_shape[i] // 2) + end = lambda i: tiles_shape[i] # tiles_shape[i] // 2 + 1 + + # get training samples based on tiles around maxima + # of the label-prediction diff + nc = len(np.unique(labels)) + n_samples_class = int(sample_fraction_per_stage * len(features)) // nc + + # sample in a class balanced way + samples_per_class = [[]]*nc + indices_per_class = [np.where(labels == class_id)[0] for class_id in range(nc)] + + # get maxima of the label-prediction diff + max_centers = 
np.argsort(diff_img_smooth)[::-1] + max_centers = np.unravel_index(max_centers, diff_img_shape) + max_centers = np.array(max_centers).swapaxes(0, 1) + for center in max_centers: + # get tile for each maximum + samples_in_tile = [(center[0]+y, center[1]+x) for y in range(start(0), end(0)) for x in range(start(1), end(1))] if len(center) == 2 \ + else [(center[0]+z, center[1]+y, center[2]+x) for z in range(start(0), end(0)) for y in range(start(1), end(1)) for x in range(start(2), end(2))] + samples_in_tile = np.array(samples_in_tile).swapaxes(0, 1) + samples_in_tile = np.ravel_multi_index(samples_in_tile, img_shape) + + for samples_this_class, indices_this_class in zip(samples_per_class, indices_per_class): + # check if current class is already full + if len(samples_this_class) < n_samples_class: + intersect = np.intersect1d(indices_this_class, samples_in_tile) + # make sure to not use duplicates + samples_this_class.extend(intersect.tolist()) + samples_this_class = list(dict.fromkeys(samples_this_class)) + + # stop when there are enough samples in each class + if all([len(samples_this_class) >= n_samples_class for samples_this_class in samples_per_class]): + break + samples = np.concatenate(samples_per_class) + + # get the features and labels, add from previous rf if specified + features, labels = features[samples], labels[samples] + if accumulate_samples: + features = np.concatenate([last_forest.train_features, features], axis=0) + labels = np.concatenate([last_forest.train_labels, labels], axis=0) + + return features, labels + + SAMPLING_STRATEGIES = { "random_points": random_points, "uncertain_points": uncertain_points, "uncertain_worst_points": uncertain_worst_points, "worst_points": worst_points, + "worst_tiles": worst_tiles, } @@ -526,6 +614,7 @@ def prepare_shallow2deep_advanced( forests_per_stage, sample_fraction_per_stage, sampling_strategy="worst_points", + sampling_kwargs={}, raw_transform=None, label_transform=None, rois=None, @@ -576,6 +665,11 @@ def 
_train_rf(rf_id): raw, labels = raw.numpy().squeeze(), labels.numpy().astype("int8").squeeze() assert raw.ndim == labels.ndim == ndim, f"{raw.ndim}, {labels.ndim}, {ndim}" + # monkey patch original shape to sampling_kwargs + # deepcopy needed due to multiprocessing + current_kwargs = copy.deepcopy(sampling_kwargs) + current_kwargs['img_shape'] = raw.shape + # only balance samples for the first (densely trained) rfs features, labels = _get_features_and_labels( raw, labels, filters_and_sigmas, balance_labels=False @@ -585,6 +679,7 @@ def _train_rf(rf_id): features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, + sampling_kwargs=current_kwargs, ) else: # sample randomly features, labels = random_points( From aa4e1db0b695852bddd13b0e88865fc1fa8358a4 Mon Sep 17 00:00:00 2001 From: JonasHell Date: Thu, 21 Jul 2022 15:26:38 +0200 Subject: [PATCH 10/31] make worst_tiles easier to read, use local maxima --- torch_em/shallow2deep/prepare_shallow2deep.py | 94 +++++++++---------- 1 file changed, 42 insertions(+), 52 deletions(-) diff --git a/torch_em/shallow2deep/prepare_shallow2deep.py b/torch_em/shallow2deep/prepare_shallow2deep.py index d334f8e0..d44336f3 100644 --- a/torch_em/shallow2deep/prepare_shallow2deep.py +++ b/torch_em/shallow2deep/prepare_shallow2deep.py @@ -6,8 +6,9 @@ from functools import partial import numpy as np -import torch import torch_em +from scipy.ndimage import gaussian_filter, convolve +from skimage.feature import peak_local_max from sklearn.ensemble import RandomForestClassifier from torch_em.segmentation import check_paths, is_segmentation_dataset, samples_to_datasets from tqdm import tqdm @@ -431,7 +432,6 @@ def worst_points( forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples=True, - sampling_kwargs={}, ): def score(pred, labels): # labels to one-hot encoding @@ -450,7 +450,6 @@ def uncertain_points( forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples=True, - sampling_kwargs={}, 
): def score(pred, labels): assert pred.ndim == 2 @@ -469,7 +468,6 @@ def uncertain_worst_points( sample_fraction_per_stage, accumulate_samples=True, alpha=0.5, - sampling_kwargs={}, ): def score(pred, labels): assert pred.ndim == 2 @@ -513,14 +511,15 @@ def worst_tiles( features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, - sampling_kwargs, + img_shape, + tiles_shape=[51, 51], + smoothing_sigma=None, accumulate_samples=True, ): # check inputs - img_shape = sampling_kwargs.get('img_shape', None) - assert len(img_shape) in [2, 3], img_shape - tiles_shape = sampling_kwargs.get('tiles_shape', None) - assert len(tiles_shape) in [2, 3], tiles_shape + ndim = len(img_shape) + assert ndim in [2, 3], img_shape + assert len(tiles_shape) == ndim, tiles_shape # get the corresponding random forest from the last stage # and predict with it @@ -535,52 +534,43 @@ def worst_tiles( diff = np.abs(onehot - pred).sum(axis=1) assert len(diff) == len(features) - # reshape diff to image shape and apply convolution with 1-kernel - diff_img = torch.Tensor(diff.reshape(img_shape)[None, None]) - kernel = torch.Tensor(np.ones(tiles_shape)[None, None]) - diff_img_smooth = torch.nn.functional.conv2d(diff_img, weight=kernel) if len(img_shape)==2 \ - else torch.nn.functional.conv3d(diff_img, weight=kernel) - diff_img_smooth = diff_img_smooth.detach().numpy().squeeze() - diff_img_shape = diff_img_smooth.shape - diff_img_smooth = diff_img_smooth.flatten() - - # define lambda functions for better readability - # need to add tiles_shape[i] // 2 to get indices in original img - start = lambda i: 0 # -(tiles_shape[i] // 2) - end = lambda i: tiles_shape[i] # tiles_shape[i] // 2 + 1 - - # get training samples based on tiles around maxima - # of the label-prediction diff - nc = len(np.unique(labels)) - n_samples_class = int(sample_fraction_per_stage * len(features)) // nc + # reshape diff to image shape + diff_img = diff.reshape(img_shape) - # sample in a class balanced way - 
samples_per_class = [[]]*nc - indices_per_class = [np.where(labels == class_id)[0] for class_id in range(nc)] + # smooth either with gaussian or 1-kernel + if smoothing_sigma: + diff_img_smooth = gaussian_filter(diff_img, smoothing_sigma, mode='constant') + else: + kernel = np.ones(tiles_shape) + diff_img_smooth = convolve(diff_img, kernel, mode='constant') + + # get training samples based on tiles around maxima of the label-prediction diff + # get maxima of the label-prediction diff (they seem to be sorted already) + max_centers = peak_local_max( + diff_img_smooth, + min_distance=max(tiles_shape), + exclude_border=tuple([s // 2 for s in tiles_shape]) + ) - # get maxima of the label-prediction diff - max_centers = np.argsort(diff_img_smooth)[::-1] - max_centers = np.unravel_index(max_centers, diff_img_shape) - max_centers = np.array(max_centers).swapaxes(0, 1) + # get indices of tiles around maxima + tiles = [] for center in max_centers: - # get tile for each maximum - samples_in_tile = [(center[0]+y, center[1]+x) for y in range(start(0), end(0)) for x in range(start(1), end(1))] if len(center) == 2 \ - else [(center[0]+z, center[1]+y, center[2]+x) for z in range(start(0), end(0)) for y in range(start(1), end(1)) for x in range(start(2), end(2))] - samples_in_tile = np.array(samples_in_tile).swapaxes(0, 1) + tile_slice = tuple([slice(center[d]-tiles_shape[d]//2, + center[d]+tiles_shape[d]//2 + 1, None) for d in range(ndim)]) + grid = np.mgrid[tile_slice] + samples_in_tile = grid.reshape(ndim, -1) samples_in_tile = np.ravel_multi_index(samples_in_tile, img_shape) + tiles.append(samples_in_tile) + tiles = np.concatenate(tiles) - for samples_this_class, indices_this_class in zip(samples_per_class, indices_per_class): - # check if current class is already full - if len(samples_this_class) < n_samples_class: - intersect = np.intersect1d(indices_this_class, samples_in_tile) - # make sure to not use duplicates - samples_this_class.extend(intersect.tolist()) - 
samples_this_class = list(dict.fromkeys(samples_this_class)) - - # stop when there are enough samples in each class - if all([len(samples_this_class) >= n_samples_class for samples_this_class in samples_per_class]): - break - samples = np.concatenate(samples_per_class) + # sample in a class balanced way + nc = len(np.unique(labels)) + n_samples_class = int(sample_fraction_per_stage * len(features)) // nc + samples = [] + for class_id in range(nc): + this_samples = tiles[labels[tiles] == class_id][:n_samples_class] + samples.append(this_samples) + samples = np.concatenate(samples) # get the features and labels, add from previous rf if specified features, labels = features[samples], labels[samples] @@ -666,7 +656,7 @@ def _train_rf(rf_id): assert raw.ndim == labels.ndim == ndim, f"{raw.ndim}, {labels.ndim}, {ndim}" # monkey patch original shape to sampling_kwargs - # deepcopy needed due to multiprocessing + # deepcopy needed due to multithreading current_kwargs = copy.deepcopy(sampling_kwargs) current_kwargs['img_shape'] = raw.shape @@ -679,7 +669,7 @@ def _train_rf(rf_id): features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, - sampling_kwargs=current_kwargs, + **current_kwargs, ) else: # sample randomly features, labels = random_points( From 6c6c2bd415d842d03c4dcb657cb00957d043cf02 Mon Sep 17 00:00:00 2001 From: JonasHell Date: Thu, 21 Jul 2022 17:50:32 +0200 Subject: [PATCH 11/31] find worst tiles per class --- torch_em/shallow2deep/prepare_shallow2deep.py | 60 ++++++++++--------- 1 file changed, 31 insertions(+), 29 deletions(-) diff --git a/torch_em/shallow2deep/prepare_shallow2deep.py b/torch_em/shallow2deep/prepare_shallow2deep.py index d44336f3..671b13b7 100644 --- a/torch_em/shallow2deep/prepare_shallow2deep.py +++ b/torch_em/shallow2deep/prepare_shallow2deep.py @@ -512,7 +512,7 @@ def worst_tiles( forests, forests_per_stage, sample_fraction_per_stage, img_shape, - tiles_shape=[51, 51], + tiles_shape=[25, 25], smoothing_sigma=None, 
accumulate_samples=True, ): @@ -531,43 +531,45 @@ def worst_tiles( onehot = np.eye(unique.shape[0])[inverse] # compute the difference between labels and prediction - diff = np.abs(onehot - pred).sum(axis=1) + diff = np.abs(onehot - pred) assert len(diff) == len(features) # reshape diff to image shape - diff_img = diff.reshape(img_shape) - - # smooth either with gaussian or 1-kernel - if smoothing_sigma: - diff_img_smooth = gaussian_filter(diff_img, smoothing_sigma, mode='constant') - else: - kernel = np.ones(tiles_shape) - diff_img_smooth = convolve(diff_img, kernel, mode='constant') - - # get training samples based on tiles around maxima of the label-prediction diff - # get maxima of the label-prediction diff (they seem to be sorted already) - max_centers = peak_local_max( - diff_img_smooth, - min_distance=max(tiles_shape), - exclude_border=tuple([s // 2 for s in tiles_shape]) - ) - - # get indices of tiles around maxima - tiles = [] - for center in max_centers: - tile_slice = tuple([slice(center[d]-tiles_shape[d]//2, - center[d]+tiles_shape[d]//2 + 1, None) for d in range(ndim)]) - grid = np.mgrid[tile_slice] - samples_in_tile = grid.reshape(ndim, -1) - samples_in_tile = np.ravel_multi_index(samples_in_tile, img_shape) - tiles.append(samples_in_tile) - tiles = np.concatenate(tiles) + diff_img = diff.reshape(img_shape + (-1,)) # sample in a class balanced way nc = len(np.unique(labels)) n_samples_class = int(sample_fraction_per_stage * len(features)) // nc samples = [] for class_id in range(nc): + # smooth either with gaussian or 1-kernel + if smoothing_sigma: + diff_img_smooth = gaussian_filter(diff_img[..., class_id], smoothing_sigma, mode='constant') + else: + kernel = np.ones(tiles_shape) + diff_img_smooth = convolve(diff_img[..., class_id], kernel, mode='constant') + + # get training samples based on tiles around maxima of the label-prediction diff + # do this in a class-specific way to ensure that each class is sampled + # get maxima of the label-prediction 
diff (they seem to be sorted already) + max_centers = peak_local_max( + diff_img_smooth, + min_distance=max(tiles_shape), + exclude_border=tuple([s // 2 for s in tiles_shape]) + ) + + # get indices of tiles around maxima + tiles = [] + for center in max_centers: + tile_slice = tuple([slice(center[d]-tiles_shape[d]//2, + center[d]+tiles_shape[d]//2 + 1, None) for d in range(ndim)]) + grid = np.mgrid[tile_slice] + samples_in_tile = grid.reshape(ndim, -1) + samples_in_tile = np.ravel_multi_index(samples_in_tile, img_shape) + tiles.append(samples_in_tile) + tiles = np.concatenate(tiles) + + # take samples that belong to the current class this_samples = tiles[labels[tiles] == class_id][:n_samples_class] samples.append(this_samples) samples = np.concatenate(samples) From 36de2c9b085fb0776a4c42cfa9ba2394a0d1514e Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Thu, 21 Jul 2022 20:57:45 +0200 Subject: [PATCH 12/31] Add urocell dataset --- experiments/uro_cell/check_urocell_loader.py | 12 ++ torch_em/data/datasets/__init__.py | 1 + torch_em/data/datasets/uro_cell.py | 121 +++++++++++++++++++ torch_em/data/datasets/util.py | 5 +- 4 files changed, 137 insertions(+), 2 deletions(-) create mode 100644 experiments/uro_cell/check_urocell_loader.py create mode 100644 torch_em/data/datasets/uro_cell.py diff --git a/experiments/uro_cell/check_urocell_loader.py b/experiments/uro_cell/check_urocell_loader.py new file mode 100644 index 00000000..fec60ae6 --- /dev/null +++ b/experiments/uro_cell/check_urocell_loader.py @@ -0,0 +1,12 @@ +from torch_em.data.datasets import get_uro_cell_loader +from torch_em.util.debug import check_loader + + +def check_uro_cell_loader(target): + loader = get_uro_cell_loader("./data", target=target, download=True, + batch_size=1, patch_shape=(32, 128, 128)) + check_loader(loader, n_samples=5, instance_labels=True) + + +if __name__ == "__main__": + check_uro_cell_loader(target="mito") diff --git a/torch_em/data/datasets/__init__.py 
b/torch_em/data/datasets/__init__.py index 59a6da36..9f17996b 100644 --- a/torch_em/data/datasets/__init__.py +++ b/torch_em/data/datasets/__init__.py @@ -14,3 +14,4 @@ from .snemi import get_snemi_loader from .util import get_bioimageio_dataset_id from .vnc import get_vnc_mito_loader +from .uro_cell import get_uro_cell_loader diff --git a/torch_em/data/datasets/uro_cell.py b/torch_em/data/datasets/uro_cell.py new file mode 100644 index 00000000..7d72c706 --- /dev/null +++ b/torch_em/data/datasets/uro_cell.py @@ -0,0 +1,121 @@ +import os +import warnings +from glob import glob +from shutil import rmtree + +import h5py +import torch_em +from .util import download_source, unzip, update_kwargs + + +URL = "https://github.com/MancaZerovnikMekuc/UroCell/archive/refs/heads/master.zip" +CHECKSUM = "1cfc83792c7ec2d201b95b6b919d119d594f453822de3ad24486019979387d1d" + + +def _require_urocell_data(path, download): + # download and unzip the data + if os.path.exists(path): + return path + + # add nifti file format support in elf by wrapping nibabel? 
+ import nibabel as nib + + os.makedirs(path) + tmp_path = os.path.join(path, "uro_cell.zip") + download_source(tmp_path, URL, download, checksum=CHECKSUM) + unzip(tmp_path, path, remove=True) + + root = os.path.join(path, "UroCell-master") + + files = glob(os.path.join(root, "data", "*.nii.gz")) + files.sort() + for data_path in files: + fname = os.path.basename(data_path) + data = nib.load(data_path).get_fdata() + + out_path = os.path.join(path, fname.replace("nii.gz", "h5")) + with h5py.File(out_path, "w") as f: + f.create_dataset("raw", data=data, compression="gzip") + + # check if we have any of the organelle labels for this volume + # and also copy them if yes + fv_path = os.path.join(root, "fv", "instance", fname) + if os.path.exists(fv_path): + fv = nib.load(fv_path).get_fdata().astype("uint32") + assert fv.shape == data.shape + f.create_dataset("labels/fv", data=fv, compression="gzip") + + golgi_path = os.path.join(root, "golgi", "precise", fname) + if os.path.exists(golgi_path): + golgi = nib.load(golgi_path).get_fdata().astype("uint32") + assert golgi.shape == data.shape + f.create_dataset("labels/golgi", data=golgi, compression="gzip") + + lyso_path = os.path.join(root, "lyso", "instance", fname) + if os.path.exists(lyso_path): + lyso = nib.load(lyso_path).get_fdata().astype("uint32") + assert lyso.shape == data.shape + f.create_dataset("labels/lyso", data=lyso, compression="gzip") + + mito_path = os.path.join(root, "mito", "instance", fname) + if os.path.exists(mito_path): + mito = nib.load(mito_path).get_fdata().astype("uint32") + assert mito.shape == data.shape + f.create_dataset("labels/mito", data=mito, compression="gzip") + + # clean up + rmtree(root) + + +def _get_paths(path, target): + label_key = f"labels/{target}" + all_paths = glob(os.path.join(path, "*.h5")) + all_paths.sort() + paths = [path for path in all_paths if label_key in h5py.File(path, "r")] + return paths, label_key + + +def get_uro_cell_loader( + path, + target, + download=False, 
+ offsets=None, + boundaries=False, + binary=False, + ndim=3, + **kwargs +): + assert target in ("fv", "golgi", "lyso", "mito") + _require_urocell_data(path, download) + paths, label_key = _get_paths(path, target) + + assert sum((offsets is not None, boundaries, binary)) <= 1, f"{offsets}, {boundaries}, {binary}" + if offsets is not None: + if target in ("lyso", "golgi"): + warnings.warn( + f"{target} does not have instance labels, affinities will be computed based on binary segmentation." + ) + # we add a binary target channel for foreground background segmentation + label_transform = torch_em.transform.label.AffinityTransform(offsets=offsets, + ignore_label=None, + add_binary_target=True, + add_mask=True) + msg = "Offsets are passed, but 'label_transform2' is in the kwargs. It will be over-ridden." + kwargs = update_kwargs(kwargs, 'label_transform2', label_transform, msg=msg) + elif boundaries: + if target in ("lyso", "golgi"): + warnings.warn( + f"{target} does not have instance labels, boundaries will be computed based on binary segmentation." + ) + label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=True) + msg = "Boundaries is set to true, but 'label_transform' is in the kwargs. It will be over-ridden." + kwargs = update_kwargs(kwargs, 'label_transform', label_transform, msg=msg) + elif binary: + label_transform = torch_em.transform.label.labels_to_binary + msg = "Binary is set to true, but 'label_transform' is in the kwargs. It will be over-ridden." 
+ kwargs = update_kwargs(kwargs, 'label_transform', label_transform, msg=msg) + + raw_key = "raw" + return torch_em.default_segmentation_loader( + paths, raw_key, paths, label_key, ndim=ndim, is_seg_dataset=True, **kwargs + ) diff --git a/torch_em/data/datasets/util.py b/torch_em/data/datasets/util.py index 06ad006d..ae474ebb 100644 --- a/torch_em/data/datasets/util.py +++ b/torch_em/data/datasets/util.py @@ -17,11 +17,12 @@ "mitoem": "ilastik/mitoem_segmentation_challenge", "monuseg": "deepimagej/monuseg_digital_pathology_miccai2018", "ovules": "", # not on bioimageio yet + "plantseg_root": "ilastik/plantseg_root", + "plantseg_ovules": "ilastik/plantseg_ovules", "platynereis": "ilastik/platynereis_em_training_data", "snemi": "", # not on bioimagegio yet + "uro_cell": "", # not on bioimageio yet: https://doi.org/10.1016/j.compbiomed.2020.103693 "vnc": "ilastik/vnc", - "plantseg_root": "ilastik/plantseg_root", - "plantseg_ovules": "ilastik/plantseg_ovules", } From 7762327c7b4342657296dd5dee839d1982403904 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Thu, 21 Jul 2022 21:00:02 +0200 Subject: [PATCH 13/31] Accumulate labels also in raw s2d rf sampling scheme --- torch_em/shallow2deep/prepare_shallow2deep.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/torch_em/shallow2deep/prepare_shallow2deep.py b/torch_em/shallow2deep/prepare_shallow2deep.py index a86be04a..055a841a 100644 --- a/torch_em/shallow2deep/prepare_shallow2deep.py +++ b/torch_em/shallow2deep/prepare_shallow2deep.py @@ -488,6 +488,7 @@ def random_points( features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, + accumulate_samples=True, ): samples = [] nc = len(np.unique(labels)) @@ -501,7 +502,14 @@ def random_points( ) samples.append(this_samples) samples = np.concatenate(samples) - return features[samples], labels[samples] + features, labels = features[samples], labels[samples] + + if accumulate_samples and rf_id >= forests_per_stage: + 
last_forest = forests[rf_id - forests_per_stage] + features = np.concatenate([last_forest.train_features, features], axis=0) + labels = np.concatenate([last_forest.train_labels, labels], axis=0) + + return features, labels SAMPLING_STRATEGIES = { From f1828ad41d24740013abd6331005c1019e640ef2 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Thu, 21 Jul 2022 21:02:02 +0200 Subject: [PATCH 14/31] Fix small issues in mito s2d experiments --- .../shallow2deep/em-mitochondria/evaluation.py | 5 ++++- .../shallow2deep/em-mitochondria/export_enhancer.py | 12 ++++-------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/experiments/shallow2deep/em-mitochondria/evaluation.py b/experiments/shallow2deep/em-mitochondria/evaluation.py index 724717ca..088d255b 100644 --- a/experiments/shallow2deep/em-mitochondria/evaluation.py +++ b/experiments/shallow2deep/em-mitochondria/evaluation.py @@ -171,12 +171,15 @@ def run_evaluation(data_path, save_path, eval_path): pred = ds[:] score = dice_metric(pred, labels) scores[name] = float(score) + + with open(save_path, "w") as f: + json.dump(scores, f) return scores # TODO def to_table(scores): - pass + breakpoint() def evaluation_v4(): diff --git a/experiments/shallow2deep/em-mitochondria/export_enhancer.py b/experiments/shallow2deep/em-mitochondria/export_enhancer.py index 244d68ad..e4379fff 100644 --- a/experiments/shallow2deep/em-mitochondria/export_enhancer.py +++ b/experiments/shallow2deep/em-mitochondria/export_enhancer.py @@ -102,14 +102,14 @@ def export_enhancer(input_, is3d, checkpoint=None, version=None, name=None): out_folder = f"./bio-models/v{version}" if is3d == "anisotropic": - rf_path = "/scratch/pape/s2d-mitochondria/rfs2d/mitoem/rf_0499.pkl" + rf_path = "/scratch/pape/s2d-mitochondria/rfs2d-worst_points/mitoem/rf_0499.pkl" input_data = create_input_anisotropic(input_, rf_path) is3d = True elif is3d: assert False, "Currently don't have 3d rfs for mitos" input_data = create_input_3d(input_, rf_path) else: - 
rf_path = "/scratch/pape/s2d-mitochondria/rfs2d/mitoem/rf_0499.pkl" + rf_path = "/scratch/pape/s2d-mitochondria/rfs2d-worst_points/mitoem/rf_0499.pkl" input_data = create_input_2d(input_, rf_path) name, description = _get_name_and_description(is3d, name) @@ -168,17 +168,13 @@ def _get_ndim(x): return x elif x == "3d": return True - return None + raise ValueError(x) checkpoints = glob("./checkpoints/s2d-em-mitos-*") for ckpt in checkpoints: name = os.path.basename(ckpt) parts = name.split("-") - is3d = _get_ndim(parts[-1]) - if is3d is None: - is3d = _get_ndim(parts[-2]) - else: - name += "-worst_points" + is3d = _get_ndim(parts[-2]) assert is3d is not None print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") From b4f17c117370e900193686b674c0c3a532e00faf Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Fri, 22 Jul 2022 12:51:27 +0200 Subject: [PATCH 15/31] Add more mito datasets WIP --- torch_em/data/datasets/__init__.py | 2 ++ torch_em/data/datasets/kasthuri.py | 6 ++++ torch_em/data/datasets/lucchi.py | 32 ++++++++++++++++++++++ torch_em/data/datasets/pnas_arabidopsis.py | 4 +++ torch_em/data/datasets/uro_cell.py | 2 +- 5 files changed, 45 insertions(+), 1 deletion(-) create mode 100644 torch_em/data/datasets/kasthuri.py create mode 100644 torch_em/data/datasets/lucchi.py create mode 100644 torch_em/data/datasets/pnas_arabidopsis.py diff --git a/torch_em/data/datasets/__init__.py b/torch_em/data/datasets/__init__.py index 9f17996b..d6a33c41 100644 --- a/torch_em/data/datasets/__init__.py +++ b/torch_em/data/datasets/__init__.py @@ -4,7 +4,9 @@ from .dsb import get_dsb_loader from .hpa import get_hpa_segmentation_loader from .isbi2012 import get_isbi_loader +from .kasthuri import get_kasthuri_loader from .livecell import get_livecell_loader +from .lucchi import get_lucchi_loader from .mitoem import get_mitoem_loader from .monuseg import get_monuseg_loader from .mouse_embryo import 
get_mouse_embryo_loader diff --git a/torch_em/data/datasets/kasthuri.py b/torch_em/data/datasets/kasthuri.py new file mode 100644 index 00000000..8ddc5f65 --- /dev/null +++ b/torch_em/data/datasets/kasthuri.py @@ -0,0 +1,6 @@ +URL = "" +CHECKSUM = "" + + +def get_kasthuri_loader(): + pass diff --git a/torch_em/data/datasets/lucchi.py b/torch_em/data/datasets/lucchi.py new file mode 100644 index 00000000..6d998da5 --- /dev/null +++ b/torch_em/data/datasets/lucchi.py @@ -0,0 +1,32 @@ +import os +from shutil import rmtree + +from .util import download_source, unzip, update_kwargs + +# TODO find a source for this! +URL = "" +CHECKSUM = "" + + +def _require_lucchi_data(path, download): + expected_paths = [ + os.path.join(path, "epfl_train.h5"), + os.path.join(path, "epfl_val.h5"), + os.path.join(path, "epfl_test.h5"), + ] + # download and unzip the data + if os.path.exists(path): + assert all(os.path.exists(pp) for pp in expected_paths) + return path + + os.makedirs(path) + tmp_path = os.path.join(path, "epfl.zip") + download_source(tmp_path, URL, download, checksum=CHECKSUM) + unzip(tmp_path, path, remove=True) + rmtree(tmp_path) + + assert all(os.path.exists(pp) for pp in expected_paths) + + +def get_lucchi_loader(): + pass diff --git a/torch_em/data/datasets/pnas_arabidopsis.py b/torch_em/data/datasets/pnas_arabidopsis.py new file mode 100644 index 00000000..1b7fdf1d --- /dev/null +++ b/torch_em/data/datasets/pnas_arabidopsis.py @@ -0,0 +1,4 @@ +# TODO + +URL = "https://www.repository.cam.ac.uk/bitstream/handle/1810/262530/PNAS.zip?sequence=4&isAllowed=y" +CHECKSUM = "" diff --git a/torch_em/data/datasets/uro_cell.py b/torch_em/data/datasets/uro_cell.py index 7d72c706..6c9b55ba 100644 --- a/torch_em/data/datasets/uro_cell.py +++ b/torch_em/data/datasets/uro_cell.py @@ -13,13 +13,13 @@ def _require_urocell_data(path, download): - # download and unzip the data if os.path.exists(path): return path # add nifti file format support in elf by wrapping nibabel? 
import nibabel as nib + # download and unzip the data os.makedirs(path) tmp_path = os.path.join(path, "uro_cell.zip") download_source(tmp_path, URL, download, checksum=CHECKSUM) From 3a315a1f915ece4b8a8a6b24f23a5c827e853cf7 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Fri, 22 Jul 2022 14:32:33 +0200 Subject: [PATCH 16/31] Refactor em mitochondria experiments, add lucchi and kasthuri datasets --- .../kasthuri/check_kasthuri_loader.py | 12 +++ .../lucchi/check_lucchi_loader.py | 12 +++ .../mito-em/.gitignore | 0 .../mito-em/README.md | 0 .../mito-em/challenge/README.md | 0 .../mito-em/challenge/check_result.py | 0 .../mito-em/challenge/checkpoint_to_mobie.py | 0 .../mito-em/challenge/create_mobie_project.py | 0 .../mito-em/challenge/embeddings/predict.py | 0 .../challenge/embeddings/train_embeddings.py | 0 .../mito-em/challenge/prepare_train_data.py | 0 .../mito-em/challenge/pretty_print_table.py | 0 .../mito-em/challenge/segment_and_submit.py | 0 .../mito-em/challenge/segment_and_validate.py | 0 .../mito-em/challenge/segmentation_impl.py | 0 .../mito-em/export_bioimageio_model.py | 0 .../mito-em/train_affinities.py | 0 .../mito-em/train_boundaries.py | 0 .../mito-em/validate_model.py | 0 .../uro_cell/check_urocell_loader.py | 0 torch_em/data/datasets/kasthuri.py | 92 ++++++++++++++++++- torch_em/data/datasets/lucchi.py | 85 ++++++++++++++--- torch_em/data/datasets/util.py | 2 + 23 files changed, 184 insertions(+), 19 deletions(-) create mode 100644 experiments/mitochondria-segmentation/kasthuri/check_kasthuri_loader.py create mode 100644 experiments/mitochondria-segmentation/lucchi/check_lucchi_loader.py rename experiments/{ => mitochondria-segmentation}/mito-em/.gitignore (100%) rename experiments/{ => mitochondria-segmentation}/mito-em/README.md (100%) rename experiments/{ => mitochondria-segmentation}/mito-em/challenge/README.md (100%) rename experiments/{ => mitochondria-segmentation}/mito-em/challenge/check_result.py (100%) rename experiments/{ => 
mitochondria-segmentation}/mito-em/challenge/checkpoint_to_mobie.py (100%) rename experiments/{ => mitochondria-segmentation}/mito-em/challenge/create_mobie_project.py (100%) rename experiments/{ => mitochondria-segmentation}/mito-em/challenge/embeddings/predict.py (100%) rename experiments/{ => mitochondria-segmentation}/mito-em/challenge/embeddings/train_embeddings.py (100%) rename experiments/{ => mitochondria-segmentation}/mito-em/challenge/prepare_train_data.py (100%) rename experiments/{ => mitochondria-segmentation}/mito-em/challenge/pretty_print_table.py (100%) rename experiments/{ => mitochondria-segmentation}/mito-em/challenge/segment_and_submit.py (100%) rename experiments/{ => mitochondria-segmentation}/mito-em/challenge/segment_and_validate.py (100%) rename experiments/{ => mitochondria-segmentation}/mito-em/challenge/segmentation_impl.py (100%) rename experiments/{ => mitochondria-segmentation}/mito-em/export_bioimageio_model.py (100%) rename experiments/{ => mitochondria-segmentation}/mito-em/train_affinities.py (100%) rename experiments/{ => mitochondria-segmentation}/mito-em/train_boundaries.py (100%) rename experiments/{ => mitochondria-segmentation}/mito-em/validate_model.py (100%) rename experiments/{ => mitochondria-segmentation}/uro_cell/check_urocell_loader.py (100%) diff --git a/experiments/mitochondria-segmentation/kasthuri/check_kasthuri_loader.py b/experiments/mitochondria-segmentation/kasthuri/check_kasthuri_loader.py new file mode 100644 index 00000000..dba28268 --- /dev/null +++ b/experiments/mitochondria-segmentation/kasthuri/check_kasthuri_loader.py @@ -0,0 +1,12 @@ +from torch_em.data.datasets import get_kasthuri_loader +from torch_em.util.debug import check_loader + + +def check_kasthuri_loader(split): + loader = get_kasthuri_loader("./data", split=split, download=True, batch_size=1, patch_shape=(64, 256, 256)) + check_loader(loader, n_samples=4, instance_labels=True) + + +if __name__ == "__main__": + 
check_kasthuri_loader(split="train") + check_kasthuri_loader(split="test") diff --git a/experiments/mitochondria-segmentation/lucchi/check_lucchi_loader.py b/experiments/mitochondria-segmentation/lucchi/check_lucchi_loader.py new file mode 100644 index 00000000..eb527381 --- /dev/null +++ b/experiments/mitochondria-segmentation/lucchi/check_lucchi_loader.py @@ -0,0 +1,12 @@ +from torch_em.data.datasets import get_lucchi_loader +from torch_em.util.debug import check_loader + + +def check_lucchi_loader(split): + loader = get_lucchi_loader("./data", split=split, download=True, batch_size=1, patch_shape=(64, 256, 256)) + check_loader(loader, n_samples=4, instance_labels=True) + + +if __name__ == "__main__": + check_lucchi_loader(split="train") + check_lucchi_loader(split="test") diff --git a/experiments/mito-em/.gitignore b/experiments/mitochondria-segmentation/mito-em/.gitignore similarity index 100% rename from experiments/mito-em/.gitignore rename to experiments/mitochondria-segmentation/mito-em/.gitignore diff --git a/experiments/mito-em/README.md b/experiments/mitochondria-segmentation/mito-em/README.md similarity index 100% rename from experiments/mito-em/README.md rename to experiments/mitochondria-segmentation/mito-em/README.md diff --git a/experiments/mito-em/challenge/README.md b/experiments/mitochondria-segmentation/mito-em/challenge/README.md similarity index 100% rename from experiments/mito-em/challenge/README.md rename to experiments/mitochondria-segmentation/mito-em/challenge/README.md diff --git a/experiments/mito-em/challenge/check_result.py b/experiments/mitochondria-segmentation/mito-em/challenge/check_result.py similarity index 100% rename from experiments/mito-em/challenge/check_result.py rename to experiments/mitochondria-segmentation/mito-em/challenge/check_result.py diff --git a/experiments/mito-em/challenge/checkpoint_to_mobie.py b/experiments/mitochondria-segmentation/mito-em/challenge/checkpoint_to_mobie.py similarity index 100% rename from 
experiments/mito-em/challenge/checkpoint_to_mobie.py rename to experiments/mitochondria-segmentation/mito-em/challenge/checkpoint_to_mobie.py diff --git a/experiments/mito-em/challenge/create_mobie_project.py b/experiments/mitochondria-segmentation/mito-em/challenge/create_mobie_project.py similarity index 100% rename from experiments/mito-em/challenge/create_mobie_project.py rename to experiments/mitochondria-segmentation/mito-em/challenge/create_mobie_project.py diff --git a/experiments/mito-em/challenge/embeddings/predict.py b/experiments/mitochondria-segmentation/mito-em/challenge/embeddings/predict.py similarity index 100% rename from experiments/mito-em/challenge/embeddings/predict.py rename to experiments/mitochondria-segmentation/mito-em/challenge/embeddings/predict.py diff --git a/experiments/mito-em/challenge/embeddings/train_embeddings.py b/experiments/mitochondria-segmentation/mito-em/challenge/embeddings/train_embeddings.py similarity index 100% rename from experiments/mito-em/challenge/embeddings/train_embeddings.py rename to experiments/mitochondria-segmentation/mito-em/challenge/embeddings/train_embeddings.py diff --git a/experiments/mito-em/challenge/prepare_train_data.py b/experiments/mitochondria-segmentation/mito-em/challenge/prepare_train_data.py similarity index 100% rename from experiments/mito-em/challenge/prepare_train_data.py rename to experiments/mitochondria-segmentation/mito-em/challenge/prepare_train_data.py diff --git a/experiments/mito-em/challenge/pretty_print_table.py b/experiments/mitochondria-segmentation/mito-em/challenge/pretty_print_table.py similarity index 100% rename from experiments/mito-em/challenge/pretty_print_table.py rename to experiments/mitochondria-segmentation/mito-em/challenge/pretty_print_table.py diff --git a/experiments/mito-em/challenge/segment_and_submit.py b/experiments/mitochondria-segmentation/mito-em/challenge/segment_and_submit.py similarity index 100% rename from 
experiments/mito-em/challenge/segment_and_submit.py rename to experiments/mitochondria-segmentation/mito-em/challenge/segment_and_submit.py diff --git a/experiments/mito-em/challenge/segment_and_validate.py b/experiments/mitochondria-segmentation/mito-em/challenge/segment_and_validate.py similarity index 100% rename from experiments/mito-em/challenge/segment_and_validate.py rename to experiments/mitochondria-segmentation/mito-em/challenge/segment_and_validate.py diff --git a/experiments/mito-em/challenge/segmentation_impl.py b/experiments/mitochondria-segmentation/mito-em/challenge/segmentation_impl.py similarity index 100% rename from experiments/mito-em/challenge/segmentation_impl.py rename to experiments/mitochondria-segmentation/mito-em/challenge/segmentation_impl.py diff --git a/experiments/mito-em/export_bioimageio_model.py b/experiments/mitochondria-segmentation/mito-em/export_bioimageio_model.py similarity index 100% rename from experiments/mito-em/export_bioimageio_model.py rename to experiments/mitochondria-segmentation/mito-em/export_bioimageio_model.py diff --git a/experiments/mito-em/train_affinities.py b/experiments/mitochondria-segmentation/mito-em/train_affinities.py similarity index 100% rename from experiments/mito-em/train_affinities.py rename to experiments/mitochondria-segmentation/mito-em/train_affinities.py diff --git a/experiments/mito-em/train_boundaries.py b/experiments/mitochondria-segmentation/mito-em/train_boundaries.py similarity index 100% rename from experiments/mito-em/train_boundaries.py rename to experiments/mitochondria-segmentation/mito-em/train_boundaries.py diff --git a/experiments/mito-em/validate_model.py b/experiments/mitochondria-segmentation/mito-em/validate_model.py similarity index 100% rename from experiments/mito-em/validate_model.py rename to experiments/mitochondria-segmentation/mito-em/validate_model.py diff --git a/experiments/uro_cell/check_urocell_loader.py 
b/experiments/mitochondria-segmentation/uro_cell/check_urocell_loader.py similarity index 100% rename from experiments/uro_cell/check_urocell_loader.py rename to experiments/mitochondria-segmentation/uro_cell/check_urocell_loader.py diff --git a/torch_em/data/datasets/kasthuri.py b/torch_em/data/datasets/kasthuri.py index 8ddc5f65..d4971ceb 100644 --- a/torch_em/data/datasets/kasthuri.py +++ b/torch_em/data/datasets/kasthuri.py @@ -1,6 +1,90 @@ -URL = "" -CHECKSUM = "" +import os +from concurrent import futures +from glob import glob +from shutil import rmtree +import imageio +import h5py +import numpy as np +import torch_em +from tqdm import tqdm +from .util import download_source, unzip -def get_kasthuri_loader(): - pass +URL = "http://www.casser.io/files/kasthuri_pp.zip " +CHECKSUM = "bbb78fd205ec9b57feb8f93ebbdf1666261cbc3e0305e7f11583ab5157a3d792" + +# data from: https://sites.google.com/view/connectomics/ +# TODO: add sampler for foreground (-1 is empty area) +# TODO and masking for the empty space + + +def _load_volume(path): + files = glob(os.path.join(path, "*.png")) + files.sort() + nz = len(files) + + im0 = imageio.imread(files[0]) + out = np.zeros((nz,) + im0.shape, dtype=im0.dtype) + out[0] = im0 + + def _loadz(z): + im = imageio.imread(files[z]) + out[z] = im + + n_threads = 8 + with futures.ThreadPoolExecutor(n_threads) as tp: + list(tqdm( + tp.map(_loadz, range(1, nz)), desc="Load volume", total=nz-1 + )) + + return out + + +def _create_data(root, inputs, out_path): + raw = _load_volume(os.path.join(root, inputs[0])) + labels_argb = _load_volume(os.path.join(root, inputs[1])) + assert labels_argb.ndim == 4 + labels = np.zeros(raw.shape, dtype="int8") + + fg_mask = (labels_argb == np.array([255, 255, 255])[None, None, None]).all(axis=-1) + labels[fg_mask] = 1 + bg_mask = (labels_argb == np.array([2, 2, 2])[None, None, None]).all(axis=-1) + labels[bg_mask] = -1 + assert (np.unique(labels) == np.array([-1, 0, 1])).all() + assert raw.shape == 
labels.shape, f"{raw.shape}, {labels.shape}" + with h5py.File(out_path, "w") as f: + f.create_dataset("raw", data=raw, compression="gzip") + f.create_dataset("labels", data=labels.astype("uint8"), compression="gzip") + + +def _require_kasthuri_data(path, download): + # download and unzip the data + if os.path.exists(path): + return path + + os.makedirs(path) + tmp_path = os.path.join(path, "kasthuri.zip") + download_source(tmp_path, URL, download, checksum=CHECKSUM) + unzip(tmp_path, path, remove=True) + + root = os.path.join(path, "Kasthuri++") + assert os.path.exists(root), root + + inputs = [["Test_In", "Test_Out"], ["Train_In", "Train_Out"]] + outputs = ["kasthuri_train.h5", "kasthuri_test.h5"] + for inp, out in zip(inputs, outputs): + out_path = os.path.join(path, out) + _create_data(root, inp, out_path) + + rmtree(root) + + +def get_kasthuri_loader(path, split, download=False, ndim=3, **kwargs): + assert split in ("train", "test") + _require_kasthuri_data(path, download) + data_path = os.path.join(path, f"kasthuri_{split}.h5") + assert os.path.exists(data_path), data_path + raw_key, label_key = "raw", "labels" + return torch_em.default_segmentation_loader( + data_path, raw_key, data_path, label_key, ndim=ndim, **kwargs + ) diff --git a/torch_em/data/datasets/lucchi.py b/torch_em/data/datasets/lucchi.py index 6d998da5..baad6534 100644 --- a/torch_em/data/datasets/lucchi.py +++ b/torch_em/data/datasets/lucchi.py @@ -1,32 +1,87 @@ import os +from concurrent import futures +from glob import glob from shutil import rmtree -from .util import download_source, unzip, update_kwargs +import imageio +import h5py +import numpy as np +import torch_em +from tqdm import tqdm +from .util import download_source, unzip -# TODO find a source for this! 
-URL = "" -CHECKSUM = "" +URL = "http://www.casser.io/files/lucchi_pp.zip" +CHECKSUM = "770ce9e98fc6f29c1b1a250c637e6c5125f2b5f1260e5a7687b55a79e2e8844d" + +# data from: https://sites.google.com/view/connectomics/ +# TODO: add sampler for foreground to avoid empty batches + + +def _load_volume(path, pattern): + nz = len(glob(os.path.join(path, "*.png"))) + im0 = imageio.imread(os.path.join(path, pattern % 0)) + out = np.zeros((nz,) + im0.shape, dtype=im0.dtype) + out[0] = im0 + + def _loadz(z): + im = imageio.imread(os.path.join(path, pattern % z)) + out[z] = im + + n_threads = 8 + with futures.ThreadPoolExecutor(n_threads) as tp: + list(tqdm( + tp.map(_loadz, range(1, nz)), desc="Load volume", total=nz-1 + )) + + return out + + +def _create_data(root, inputs, out_path): + raw = _load_volume(os.path.join(root, inputs[0]), pattern="mask%04i.png") + labels_argb = _load_volume(os.path.join(root, inputs[1]), pattern="%i.png") + if labels_argb.ndim == 4: + labels = np.zeros(raw.shape, dtype="uint8") + fg_mask = (labels_argb == np.array([255, 255, 255, 255])[None, None, None]).all(axis=-1) + labels[fg_mask] = 1 + else: + assert labels_argb.ndim == 3 + labels = labels_argb + labels[labels == 255] = 1 + assert (np.unique(labels) == np.array([0, 1])).all() + assert raw.shape == labels.shape, f"{raw.shape}, {labels.shape}" + with h5py.File(out_path, "w") as f: + f.create_dataset("raw", data=raw, compression="gzip") + f.create_dataset("labels", data=labels.astype("uint8"), compression="gzip") def _require_lucchi_data(path, download): - expected_paths = [ - os.path.join(path, "epfl_train.h5"), - os.path.join(path, "epfl_val.h5"), - os.path.join(path, "epfl_test.h5"), - ] # download and unzip the data if os.path.exists(path): - assert all(os.path.exists(pp) for pp in expected_paths) return path os.makedirs(path) - tmp_path = os.path.join(path, "epfl.zip") + tmp_path = os.path.join(path, "lucchi.zip") download_source(tmp_path, URL, download, checksum=CHECKSUM) unzip(tmp_path, 
path, remove=True) - rmtree(tmp_path) - assert all(os.path.exists(pp) for pp in expected_paths) + root = os.path.join(path, "Lucchi++") + assert os.path.exists(root), root + + inputs = [["Test_In", "Test_Out"], ["Train_In", "Train_Out"]] + outputs = ["lucchi_train.h5", "lucchi_test.h5"] + for inp, out in zip(inputs, outputs): + out_path = os.path.join(path, out) + _create_data(root, inp, out_path) + + rmtree(root) -def get_lucchi_loader(): - pass +def get_lucchi_loader(path, split, download=False, ndim=3, **kwargs): + assert split in ("train", "test") + _require_lucchi_data(path, download) + data_path = os.path.join(path, f"lucchi_{split}.h5") + assert os.path.exists(data_path), data_path + raw_key, label_key = "raw", "labels" + return torch_em.default_segmentation_loader( + data_path, raw_key, data_path, label_key, ndim=ndim, **kwargs + ) diff --git a/torch_em/data/datasets/util.py b/torch_em/data/datasets/util.py index ae474ebb..17a9cbdc 100644 --- a/torch_em/data/datasets/util.py +++ b/torch_em/data/datasets/util.py @@ -13,7 +13,9 @@ "dsb": "ilastik/stardist_dsb_training_data", "hpa": "", # not on bioimageio yet "isbi2012": "ilastik/isbi2012_neuron_segmentation_challenge", + "kasthuri": "", # not on bioimageio yet: "livecell": "ilastik/livecell_dataset", + "lucchi": "", # not on bioimageio yet: "mitoem": "ilastik/mitoem_segmentation_challenge", "monuseg": "deepimagej/monuseg_digital_pathology_miccai2018", "ovules": "", # not on bioimageio yet From 1ce36f7eff5a15ee61b6cbe61f7b5c054fac7acd Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Fri, 22 Jul 2022 14:49:32 +0200 Subject: [PATCH 17/31] Fix issue in prepare s2d --- torch_em/shallow2deep/prepare_shallow2deep.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/torch_em/shallow2deep/prepare_shallow2deep.py b/torch_em/shallow2deep/prepare_shallow2deep.py index 2c9ed020..a21f3c8d 100644 --- a/torch_em/shallow2deep/prepare_shallow2deep.py +++ 
b/torch_em/shallow2deep/prepare_shallow2deep.py @@ -432,6 +432,7 @@ def worst_points( forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples=True, + **kwargs, ): def score(pred, labels): # labels to one-hot encoding @@ -450,6 +451,7 @@ def uncertain_points( forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples=True, + **kwargs, ): def score(pred, labels): assert pred.ndim == 2 @@ -468,6 +470,7 @@ def uncertain_worst_points( sample_fraction_per_stage, accumulate_samples=True, alpha=0.5, + **kwargs, ): def score(pred, labels): assert pred.ndim == 2 @@ -492,6 +495,7 @@ def random_points( forests, forests_per_stage, sample_fraction_per_stage, accumulate_samples=True, + **kwargs, ): samples = [] nc = len(np.unique(labels)) @@ -523,6 +527,7 @@ def worst_tiles( tiles_shape=[25, 25], smoothing_sigma=None, accumulate_samples=True, + **kwargs, ): # check inputs ndim = len(img_shape) @@ -569,8 +574,13 @@ def worst_tiles( # get indices of tiles around maxima tiles = [] for center in max_centers: - tile_slice = tuple([slice(center[d]-tiles_shape[d]//2, - center[d]+tiles_shape[d]//2 + 1, None) for d in range(ndim)]) + tile_slice = tuple( + slice( + center[d]-tiles_shape[d]//2, + center[d]+tiles_shape[d]//2 + 1, + None + ) for d in range(ndim) + ) grid = np.mgrid[tile_slice] samples_in_tile = grid.reshape(ndim, -1) samples_in_tile = np.ravel_multi_index(samples_in_tile, img_shape) @@ -668,7 +678,7 @@ def _train_rf(rf_id): # monkey patch original shape to sampling_kwargs # deepcopy needed due to multithreading current_kwargs = copy.deepcopy(sampling_kwargs) - current_kwargs['img_shape'] = raw.shape + current_kwargs["img_shape"] = raw.shape # only balance samples for the first (densely trained) rfs features, labels = _get_features_and_labels( From 664467cd3e2f6dcfa2c7d2b9e6fcac0475cc5c32 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Fri, 22 Jul 2022 22:18:55 +0200 Subject: [PATCH 18/31] Fix issue in kasthuri data loader --- 
torch_em/data/datasets/kasthuri.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_em/data/datasets/kasthuri.py b/torch_em/data/datasets/kasthuri.py index d4971ceb..ba717b75 100644 --- a/torch_em/data/datasets/kasthuri.py +++ b/torch_em/data/datasets/kasthuri.py @@ -54,7 +54,7 @@ def _create_data(root, inputs, out_path): assert raw.shape == labels.shape, f"{raw.shape}, {labels.shape}" with h5py.File(out_path, "w") as f: f.create_dataset("raw", data=raw, compression="gzip") - f.create_dataset("labels", data=labels.astype("uint8"), compression="gzip") + f.create_dataset("labels", data=labels, compression="gzip") def _require_kasthuri_data(path, download): From 4c47e80fae2aa02cbdfc288a3464e1e1b0c67dd7 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sat, 23 Jul 2022 10:02:14 +0200 Subject: [PATCH 19/31] Update min foreground sampler to enable multiple background values --- .../shallow2deep/em-mitochondria/README.md | 31 ++- .../em-mitochondria/evaluation.py | 176 ++++++++++++------ .../em-mitochondria/export_enhancer.py | 3 + .../em-mitochondria/train_mito_2d.py | 2 +- .../em-mitochondria/train_mito_anisotropic.py | 2 +- torch_em/data/sampler.py | 7 +- 6 files changed, 161 insertions(+), 60 deletions(-) diff --git a/experiments/shallow2deep/em-mitochondria/README.md b/experiments/shallow2deep/em-mitochondria/README.md index e5db5d5e..6769d394 100644 --- a/experiments/shallow2deep/em-mitochondria/README.md +++ b/experiments/shallow2deep/em-mitochondria/README.md @@ -3,21 +3,42 @@ ## Evaluation Evaluation of different shallow2deep setups for mitochondria segmentation in EM. -The enhancers are (potentially) trained on multiple datasets, evaluation is always on the EPFL dataset (which is ofc not part of the training set). 
+The enhancers are (potentially) trained on multiple datasets, evaluation is done on the Kasthuri dataset (which is not part of the training set except for one last version that will be the (for now) final one to be uploaded to bioimage.io). All scores are measured with a soft dice score. +## Datasets + +- Mito-EM +- VNC +- Lucchi +- UroCell +- Kasthuri + ### V4 -- 2d enhancer: trained on mito-em and vnc -- anisotropic enhancer: random forests are trained in 2d, enhancer trained in 3d, trained on mito-em -- direct-nets: 2d and 3d networks trained on mito-em +- 2d enhancer: trained on Mito-EM and VNC +- anisotropic enhancer: random forests are trained in 2d, enhancer trained in 3d, trained on Mito-EM +- 3d enhancer: random forests trained in 3d, enhancer trained in 3d, trained on Kasthuri +- direct-nets: 2d and anisotropic networks trained on Mito-EM, 3d network trained on Kasthuri - different strategies for training the initial rfs: - `vanilla`: random forests are trained on randomly sampled dense patches - `worst_points`: initial stage of forests (25 forests) are trained on random samples, forests in the next stages add worst predictions from prev. 
stage to their training set - `uncertain_worst_points`: same as `worst_points`, but points are selected based on linear combination of uncertainty and worst predictions + - `random_points`: random points sampled in each stage, points are accumulated over the stages + - `worst_tiles`: training samples are taken from worst tile predictions + + +### V5 + +TODO: (only best sampling from V4) +- train 2d on Mito-EM, VNC, Kasthuri and UroCell +- train anisotropic on Mito-EM, Kasthuri and UroCell +- train 3d on Kasthuri and UroCell + +## V6 -a +TODO same as V5, but train everything on Lucchi as well and upload the one with best sampling strategy to bioimage.io ## Old evaluation diff --git a/experiments/shallow2deep/em-mitochondria/evaluation.py b/experiments/shallow2deep/em-mitochondria/evaluation.py index 088d255b..65bff4f2 100644 --- a/experiments/shallow2deep/em-mitochondria/evaluation.py +++ b/experiments/shallow2deep/em-mitochondria/evaluation.py @@ -14,21 +14,38 @@ def prepare_eval_v4(): - import napari - path = "/g/kreshuk/data/epfl/testing.h5" + path = "/g/kreshuk/pape/Work/data/kasthuri/kasthuri_test.h5" with open_file(path, "r") as f: - raw = f["raw"][:] - label = f["label"][:] - v = napari.Viewer() - v.add_image(raw) - v.add_labels(label) - napari.run() - - -def dice_metric(pred, label): + # raw = f["raw"][:] + label = f["labels"][:] + print(label.shape) + fg = np.concatenate( + [(label != 255).all(axis=0)[None]] * label.shape[0], + axis=0 + ) + + # import napari + # v = napari.Viewer() + # v.add_labels(label) + # v.add_labels(fg) + # napari.run() + + fg = np.where(fg) + fg_bb = tuple( + slice(int(gg.min()), int(gg.max()) + 1) for gg in fg + ) + label = label[fg_bb] + print(label.shape) + + +def dice_metric(pred, label, mask=None): if pred.ndim == 4: pred = pred[0] assert pred.shape == label.shape + # deal with potential ignore label + if mask is not None: + pred, label = pred[mask], label[mask] + assert pred.shape == label.shape return dice_score(pred, label, 
threshold_seg=None) @@ -72,7 +89,7 @@ def require_enhancers_2d(rfs, enhancers, save_path): pred = np.zeros((2,) + rf_pred.shape, dtype="float32") for z in trange(rf_pred.shape[0], desc=f"Run prediction for {enhancer_name}-{rf_name}"): inp = DataArray(rf_pred[z][None, None], dims=tuple("bcyx")) - predz = pp(inp)[0].values[0] + predz = bioimageio.core.predict_with_padding(pp, inp)[0].values[0] pred[:, z] = predz f.create_dataset(save_name, data=pred, compression="gzip") @@ -114,15 +131,11 @@ def require_net_2d(data_path, model_path, model_name, save_path): with bioimageio.core.create_prediction_pipeline(model) as pp: for z in trange(raw.shape[0], desc=f"Run prediction for model {model_name}"): inp = DataArray(raw[z][None, None], dims=tuple("bcyx")) - pred[:, z] = pp(inp)[0].values[0] + pred[:, z] = bioimageio.core.predict_with_padding(pp, inp)[0].values[0] f_save.create_dataset(model_name, data=pred, compression="gzip") -def require_net_3d(data_path, model_path, model_name, save_path): - tiling = { - "tile": {"z": 32, "y": 256, "x": 256}, - "halo": {"z": 4, "y": 32, "x": 32} - } +def require_net_3d(data_path, model_path, model_name, save_path, tiling): with open_file(save_path, "a") as f_save: if model_name in f_save: return @@ -154,89 +167,148 @@ def get_enhancers(root): return enhancers_2d, enhancers_anisotropic -def run_evaluation(data_path, save_path, eval_path): +def run_evaluation(data_path, save_path, eval_path, label_key="label"): if os.path.exists(eval_path): - with open(save_path, "r") as f: + with open(eval_path, "r") as f: scores = json.load(f) else: scores = {} - with open_file(data_path, "r") as f: - labels = f["label"][:] + def load_labels(): + with open_file(data_path, "r") as f: + labels = f[label_key][:] + if 255 in labels: + mask = labels != 255 + print("Have mask!!!!!") + # getting rid of boundary artifacts + print("Pix in mask before:", mask.sum()) + mask = np.concatenate( + [mask.all(axis=0)[None]] * mask.shape[0], + axis=0 + ) + print("Pix in 
mask afer:", mask.sum()) + else: + mask = None + return labels, mask with open_file(save_path, "r") as f: - for name, ds in tqdm(f.items(), total=len(f), desc="Run evaluation"): - if name in scores: - continue - pred = ds[:] - score = dice_metric(pred, labels) + missing_names = list( + set(f.keys()) - set(scores.keys()) + ) + if missing_names: + labels, mask = load_labels() + + for name in tqdm(missing_names, desc="Run evaluation"): + pred = f[name][:] + score = dice_metric(pred, labels, mask) scores[name] = float(score) - with open(save_path, "w") as f: + with open(eval_path, "w") as f: json.dump(scores, f) return scores -# TODO def to_table(scores): - breakpoint() + # sort the results into enhancers / rfs with few, medium and many labels + cols = {"few-labels": {}, "medium-labels": {}, "many-labels": {}} + for name, score in scores.items(): + for col in cols: + is_enhancer = False + if col in name: + # TODO need to adapt this once we also have a 3d rf + save_name = "rf3d" if col == name else name.replace(f"-{col}", "") + cols[col][save_name] = score + is_enhancer = True + break + # direct prediction: don't fit into the categories here, we just put it in the first col (few) + if not is_enhancer: + cols["few-labels"][name] = score + + # sort descending after 2d, 3d, anisotropic (alphabetically) + name_col = list(cols["few-labels"].keys()) + name_col.sort() + data = [] + for ndim in ("2d", "anisotropic", "3d"): + for name in name_col: + if ndim not in name: + # print("Skipping", ndim, name) + continue + row = [name] + [col[name] if name in col else None for col in cols.values()] + if name == "rf3d": + data = [row] + data + else: + data.append(row) -def evaluation_v4(): - data_path = "/g/kreshuk/pape/Work/data/group_data/epfl/testing.h5" - rf_folder = "/g/kreshuk/pape/Work/data/epfl/ilastik-projects" - save_path = "./bio-models/v4/prediction.h5" + df = pd.DataFrame(data, columns=["method"] + list(cols.keys())) + return df + + +def evaluate_lucchi(version): + 
data_path = "/g/kreshuk/pape/Work/data/lucchi/lucchi_test.h5" + rf_folder = "/g/kreshuk/pape/Work/data/lucchi/ilp3d" + save_path = f"./bio-models/v{version}/prediction_lucchi.h5" rfs = { - "few-labels": os.path.join(rf_folder, "2d-1.ilp"), - "medium-labels": os.path.join(rf_folder, "2d-2.ilp"), - "many-labels": os.path.join(rf_folder, "2d-3.ilp"), + "few-labels": os.path.join(rf_folder, "1.ilp"), + "medium-labels": os.path.join(rf_folder, "2.ilp"), + "many-labels": os.path.join(rf_folder, "3.ilp"), } + enhancers_2d, enhancers_anisotropic = get_enhancers(f"./bio-models/v{version}") + net2d = "./bio-models/v2/DirectModel/MitchondriaEMSegmentation2D.zip" + net_aniso = "./bio-models/v3/DirectModel/mitochondriaemsegmentationboundarymodel_pytorch_state_dict.zip" + require_rfs(data_path, rfs, save_path) - enhancers_2d, enhancers_anisotropic = get_enhancers("./bio-models/v4") require_enhancers_2d(rfs, enhancers_2d, save_path) require_enhancers_3d(rfs, enhancers_anisotropic, save_path) + # TODO add the 3d enhancers - net2d = "./bio-models/v2/DirectModel/MitchondriaEMSegmentation2D.zip" - require_net_2d(data_path, net2d, "direct2d", save_path) - net3d = "./bio-models/v3/DirectModel/mitochondriaemsegmentationboundarymodel_pytorch_state_dict.zip" - require_net_3d(data_path, net3d, "direct3d", save_path) + require_net_2d(data_path, net2d, "direct_2d", save_path) + tiling_aniso = { + "tile": {"z": 32, "y": 256, "x": 256}, + "halo": {"z": 4, "y": 32, "x": 32} + } + require_net_3d(data_path, net_aniso, "direct_anisotropic", save_path, tiling_aniso) + # TODO train and add the 3d network - eval_path = "./bio-models/v4/eval.json" - scores = run_evaluation(data_path, save_path, eval_path) + eval_path = f"./bio-models/v{version}/lucchi.json" + scores = run_evaluation(data_path, save_path, eval_path, label_key="labels") scores = to_table(scores) print("Evaluation results:") - print(scores.to_markdown()) + print(scores.to_markdown(floatfmt=".03f")) -def debug_v4(): +def 
debug_v4(pred_filter=None): import napari - data_path = "/g/kreshuk/pape/Work/data/group_data/epfl/testing.h5" - save_path = "./bio-models/v4/prediction.h5" + data_path = "/g/kreshuk/pape/Work/data/kasthuri/kasthuri_test.h5" + save_path = "./bio-models/v4/prediction_kasthuri.h5" print("Load data") with open_file(data_path, "r") as f: data = f["raw"][:] - labels = f["label"][:] + labels = f["labels"][:].astype("uint32") with open_file(save_path, "r") as f: preds = {} for name, ds in tqdm(f.items(), total=len(f)): - if ("labels" in name) and ("many" not in name): + if pred_filter is not None and pred_filter not in name: continue preds[name] = ds[:] print("Start viewer") v = napari.Viewer() v.add_image(data) - v.add_labels(labels) for name, pred in preds.items(): v.add_image(pred, name=name) + v.add_labels(labels) napari.run() if __name__ == "__main__": # prepare_eval_v4() - evaluation_v4() + + # debug_v4(pred_filter="few-labels") # debug_v4() + + evaluate_lucchi(version=4) diff --git a/experiments/shallow2deep/em-mitochondria/export_enhancer.py b/experiments/shallow2deep/em-mitochondria/export_enhancer.py index e4379fff..6f03277c 100644 --- a/experiments/shallow2deep/em-mitochondria/export_enhancer.py +++ b/experiments/shallow2deep/em-mitochondria/export_enhancer.py @@ -171,8 +171,11 @@ def _get_ndim(x): raise ValueError(x) checkpoints = glob("./checkpoints/s2d-em-mitos-*") + out_folder = f"./bio-models/v{args.version}" for ckpt in checkpoints: name = os.path.basename(ckpt) + if(os.path.exists(os.path.join(out_folder, name))): + continue parts = name.split("-") is3d = _get_ndim(parts[-2]) assert is3d is not None diff --git a/experiments/shallow2deep/em-mitochondria/train_mito_2d.py b/experiments/shallow2deep/em-mitochondria/train_mito_2d.py index d134a6c2..5c9de470 100644 --- a/experiments/shallow2deep/em-mitochondria/train_mito_2d.py +++ b/experiments/shallow2deep/em-mitochondria/train_mito_2d.py @@ -158,7 +158,7 @@ def check(args, train=True, val=True, n_images=2): 
parser.add_argument("--datasets", "-d", nargs="+", default=DATASETS) parser.add_argument("--n_rfs", type=int, default=500) parser.add_argument("--n_threads", type=int, default=32) - parser.add_argument("--sampling_strategy", "-s", default="worst_points") + parser.add_argument("--sampling_strategy", "-s", default="worst_tiles") args = parser.parse_args() if args.check: check(args, n_images=5, val=False) diff --git a/experiments/shallow2deep/em-mitochondria/train_mito_anisotropic.py b/experiments/shallow2deep/em-mitochondria/train_mito_anisotropic.py index 5a854544..c8ce22bf 100644 --- a/experiments/shallow2deep/em-mitochondria/train_mito_anisotropic.py +++ b/experiments/shallow2deep/em-mitochondria/train_mito_anisotropic.py @@ -148,7 +148,7 @@ def check(args, train=True, val=True, n_images=2): parser.add_argument("--datasets", "-d", nargs="+", default=DATASETS) parser.add_argument("--n_rfs", type=int, default=500) parser.add_argument("--n_threads", type=int, default=32) - parser.add_argument("--sampling_strategy", "-s", default="worst_points") + parser.add_argument("--sampling_strategy", "-s", default="worst_tiles") args = parser.parse_args() if args.check: check(args, n_images=5, val=False) diff --git a/torch_em/data/sampler.py b/torch_em/data/sampler.py index 2c991a74..9d7163dc 100644 --- a/torch_em/data/sampler.py +++ b/torch_em/data/sampler.py @@ -9,7 +9,12 @@ def __init__(self, min_fraction, background_id=0, p_reject=1.0): def __call__(self, x, y): size = float(y.size) - foreground_fraction = np.sum(y != self.background_id) / size + if isinstance(self.background_id, int): + foreground_fraction = np.sum(y != self.background_id) / size + else: + foreground_fraction = np.sum( + np.logical_not(np.isin(y, self.background_id)) + ) / size if foreground_fraction > self.min_fraction: return True else: From d488dca290e6609957a24573102fa3bc43e4a905 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sat, 23 Jul 2022 10:02:52 +0200 Subject: [PATCH 20/31] Fix issue in 
worst_tile sampling --- torch_em/shallow2deep/prepare_shallow2deep.py | 27 +++++++++++-------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/torch_em/shallow2deep/prepare_shallow2deep.py b/torch_em/shallow2deep/prepare_shallow2deep.py index a21f3c8d..c3e67d39 100644 --- a/torch_em/shallow2deep/prepare_shallow2deep.py +++ b/torch_em/shallow2deep/prepare_shallow2deep.py @@ -524,7 +524,7 @@ def worst_tiles( forests, forests_per_stage, sample_fraction_per_stage, img_shape, - tiles_shape=[25, 25], + tile_shape=[25, 25], smoothing_sigma=None, accumulate_samples=True, **kwargs, @@ -532,7 +532,7 @@ def worst_tiles( # check inputs ndim = len(img_shape) assert ndim in [2, 3], img_shape - assert len(tiles_shape) == ndim, tiles_shape + assert len(tile_shape) == ndim, tile_shape # get the corresponding random forest from the last stage # and predict with it @@ -559,7 +559,7 @@ def worst_tiles( if smoothing_sigma: diff_img_smooth = gaussian_filter(diff_img[..., class_id], smoothing_sigma, mode='constant') else: - kernel = np.ones(tiles_shape) + kernel = np.ones(tile_shape) diff_img_smooth = convolve(diff_img[..., class_id], kernel, mode='constant') # get training samples based on tiles around maxima of the label-prediction diff @@ -567,8 +567,8 @@ def worst_tiles( # get maxima of the label-prediction diff (they seem to be sorted already) max_centers = peak_local_max( diff_img_smooth, - min_distance=max(tiles_shape), - exclude_border=tuple([s // 2 for s in tiles_shape]) + min_distance=max(tile_shape), + exclude_border=tuple([s // 2 for s in tile_shape]) ) # get indices of tiles around maxima @@ -576,8 +576,8 @@ def worst_tiles( for center in max_centers: tile_slice = tuple( slice( - center[d]-tiles_shape[d]//2, - center[d]+tiles_shape[d]//2 + 1, + center[d]-tile_shape[d]//2, + center[d]+tile_shape[d]//2 + 1, None ) for d in range(ndim) ) @@ -585,11 +585,16 @@ def worst_tiles( samples_in_tile = grid.reshape(ndim, -1) samples_in_tile = 
np.ravel_multi_index(samples_in_tile, img_shape) tiles.append(samples_in_tile) - tiles = np.concatenate(tiles) - # take samples that belong to the current class - this_samples = tiles[labels[tiles] == class_id][:n_samples_class] - samples.append(this_samples) + # this (very rarely) fails due to empty tile list. Since we usually + # accumulate the features this doesn't hurt much and we can continue + try: + tiles = np.concatenate(tiles) + # take samples that belong to the current class + this_samples = tiles[labels[tiles] == class_id][:n_samples_class] + samples.append(this_samples) + except ValueError: + pass samples = np.concatenate(samples) # get the features and labels, add from previous rf if specified From e550ac6f5e50bdb9b70585056581a8f0f313e8f1 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sat, 23 Jul 2022 10:10:03 +0200 Subject: [PATCH 21/31] Add binary target option to boundary transform with background --- torch_em/transform/label.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/torch_em/transform/label.py b/torch_em/transform/label.py index b43f9216..877d8c72 100644 --- a/torch_em/transform/label.py +++ b/torch_em/transform/label.py @@ -34,7 +34,7 @@ def label_consecutive(labels, with_background=True): return seg -# TODO ignore label + mask, smoothing +# TODO smoothing class BoundaryTransform: def __init__(self, mode="thick", add_binary_target=False, ndim=None): self.mode = mode @@ -52,12 +52,14 @@ def __call__(self, labels): return target +# TODO smoothing class NoToBackgroundBoundaryTransform: - def __init__(self, bg_label=0, mask_label=-1, mode="thick", ndim=None): + def __init__(self, bg_label=0, mask_label=-1, mode="thick", add_binary_target=False, ndim=None): self.bg_label = bg_label self.mask_label = mask_label self.mode = mode self.ndim = ndim + self.add_binary_target = add_binary_target def __call__(self, labels): labels = ensure_array(labels) if self.ndim is None else ensure_spatial_array(labels, 
self.ndim) @@ -71,7 +73,15 @@ def __call__(self, labels): # mask the to-background-boundaries boundaries = boundaries.astype(np.int8) boundaries[to_bg_boundaries] = self.mask_label - return boundaries + + if self.add_binary_target: + binary = labels_to_binary(labels, self.bg_label).astype(boundaries.dtype) + binary[labels == self.mask_label] = self.mask_label + target = np.concatenate([binary[None], boundaries], axis=0) + else: + target = boundaries + + return target # TODO affinity smoothing From cc7f39fbc54beed074de46b678b6832cb1058fe7 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sat, 23 Jul 2022 10:10:51 +0200 Subject: [PATCH 22/31] Add 3d mito s2d training --- .../em-mitochondria/train_mito_3d.py | 188 ++++++++++++++++++ 1 file changed, 188 insertions(+) create mode 100644 experiments/shallow2deep/em-mitochondria/train_mito_3d.py diff --git a/experiments/shallow2deep/em-mitochondria/train_mito_3d.py b/experiments/shallow2deep/em-mitochondria/train_mito_3d.py new file mode 100644 index 00000000..b4770ba0 --- /dev/null +++ b/experiments/shallow2deep/em-mitochondria/train_mito_3d.py @@ -0,0 +1,188 @@ +import os +from glob import glob + +import torch +import torch_em +import torch_em.shallow2deep as shallow2deep +from torch_em.model import UNet3d +from torch_em.data.datasets.lucchi import _require_lucchi_data +from torch_em.data.datasets.kasthuri import _require_kasthuri_data + + +DATA_ROOT = "/scratch/pape/s2d-mitochondria" +DATASETS = ["kasthuri", "lucchi"] + + +def normalize_datasets(datasets): + wrong_ds = list(set(datasets) - set(DATASETS)) + if wrong_ds: + raise ValueError(f"Unkown datasets: {wrong_ds}. 
Only {DATASETS} are supported") + datasets = list(sorted(datasets)) + return datasets + + +def require_ds(dataset): + os.makedirs(DATA_ROOT, exist_ok=True) + data_path = os.path.join(DATA_ROOT, dataset) + if dataset == "kasthuri": + _require_kasthuri_data(data_path, download=True) + paths = [ + os.path.join(data_path, "kasthuri_train.h5"), + ] + assert all(os.path.exists(pp) for pp in paths), f"{paths}" + raw_key, label_key = "raw", "labels" + elif dataset == "lucchi": + _require_lucchi_data(data_path, download=True) + paths = [ + os.path.join(data_path, "lucchi_train.h5"), + ] + assert all(os.path.exists(pp) for pp in paths), f"{paths}" + raw_key, label_key = "raw", "labels" + return paths, raw_key, label_key + + +def require_rfs_ds(dataset, n_rfs, sampling_strategy): + out_folder = os.path.join(DATA_ROOT, f"rfs3d-{sampling_strategy}", dataset) + os.makedirs(out_folder, exist_ok=True) + if len(glob(os.path.join(out_folder, "*.pkl"))) == n_rfs: + return + + patch_shape_min = [64, 64, 64] + patch_shape_max = [96, 128, 128] + + raw_transform = torch_em.transform.raw.normalize + label_transform = shallow2deep.ForegroundTransform(ndim=3) + + if dataset == "kasthuri": + sampler = torch_em.data.sampler.MinForegroundSampler(min_fraction=0.05, background_id=[-1, 0]) + else: + sampler = None + + paths, raw_key, label_key = require_ds(dataset) + if sampling_strategy == "vanilla": + shallow2deep.prepare_shallow2deep( + raw_paths=paths, raw_key=raw_key, label_paths=paths, label_key=label_key, + patch_shape_min=patch_shape_min, patch_shape_max=patch_shape_max, + n_forests=args.n_rfs, n_threads=args.n_threads, + output_folder=out_folder, ndim=3, + raw_transform=raw_transform, label_transform=label_transform, + is_seg_dataset=True, sampler=sampler + ) + else: + sampling_kwargs = {} + if sampling_strategy == "worst_tiles": + sampling_kwargs["tile_shape"] = [16, 16, 16] + shallow2deep.prepare_shallow2deep_advanced( + raw_paths=paths, raw_key=raw_key, label_paths=paths, 
label_key=label_key, + patch_shape_min=patch_shape_min, patch_shape_max=patch_shape_max, + n_forests=args.n_rfs, n_threads=args.n_threads, + forests_per_stage=25, sample_fraction_per_stage=0.05, + output_folder=out_folder, ndim=3, + raw_transform=raw_transform, label_transform=label_transform, + is_seg_dataset=True, sampling_strategy=sampling_strategy, + sampler=sampler, sampling_kwargs=sampling_kwargs, + ) + + +def require_rfs(datasets, n_rfs, sampling_strategy): + for ds in datasets: + require_rfs_ds(ds, n_rfs, sampling_strategy) + + +def get_ds(file_pattern, rf_pattern, n_samples, label_key="labels", with_ignore=False): + raw_transform = torch_em.transform.raw.normalize + label_transform = torch_em.transform.BoundaryTransform(ndim=3, add_binary_target=True) + if with_ignore: + sampler = torch_em.data.sampler.MinForegroundSampler(min_fraction=0.05, background_id=[-1, 0]) + else: + sampler = None + patch_shape = (64, 256, 256) + paths = glob(file_pattern) + paths.sort() + assert len(paths) > 0 + rf_paths = glob(rf_pattern) + rf_paths.sort() + assert len(rf_paths) > 0 + raw_key = "raw" + return shallow2deep.shallow2deep_dataset.get_shallow2deep_dataset( + paths, raw_key, paths, label_key, rf_paths, + patch_shape=patch_shape, + raw_transform=raw_transform, + label_transform=label_transform, + n_samples=n_samples, ndim=3, + sampler=sampler, + ) + + +def get_loader(args, split, dataset_names): + datasets = [] + n_samples = 500 if split == "train" else 25 + if "kasthuri" in dataset_names: + ds_name = "kasthuri" + # we need to use the test split here, because val is too small in z + split_ = "test" if split == "val" else split + file_pattern = os.path.join(DATA_ROOT, ds_name, f"*_{split_}.h5") + rf_pattern = os.path.join(DATA_ROOT, f"rfs3d-{args.sampling_strategy}", ds_name, "*.pkl") + datasets.append(get_ds(file_pattern, rf_pattern, n_samples, with_ignore=True)) + if "lucchi" in dataset_names: + ds_name = "lucchi" + # we need to use the test split here, because val 
is too small in z + split_ = "test" if split == "val" else split + file_pattern = os.path.join(DATA_ROOT, ds_name, f"*_{split_}.h5") + rf_pattern = os.path.join(DATA_ROOT, f"rfs3d-{args.sampling_strategy}", ds_name, "*.pkl") + datasets.append(get_ds(file_pattern, rf_pattern, n_samples)) + ds = torch_em.data.concat_dataset.ConcatDataset(*datasets) if len(datasets) > 1 else datasets[0] + loader = torch.utils.data.DataLoader( + ds, shuffle=True, batch_size=args.batch_size, num_workers=12 + ) + loader.shuffle = True + return loader + + +def train_shallow2deep(args): + datasets = normalize_datasets(args.datasets) + name = f"s2d-em-mitos-{'_'.join(datasets)}-3d-{args.sampling_strategy}" + require_rfs(datasets, args.n_rfs, args.sampling_strategy) + + model = UNet3d(in_channels=1, out_channels=2, final_activation="Sigmoid", + depth=4, initial_features=32) + + train_loader = get_loader(args, "train", datasets) + val_loader = get_loader(args, "val", datasets) + loss = torch_em.loss.DiceLoss() + if "kasthuri" in datasets: + loss = torch_em.loss.wrapper.LossWrapper( + loss, torch_em.loss.wrapper.MaskIgnoreLabel() + ) + + trainer = torch_em.default_segmentation_trainer( + name, model, train_loader, val_loader, learning_rate=1.0e-4, + loss=loss, device=args.device, log_image_interval=50 + ) + trainer.fit(args.n_iterations) + + +def check(args, train=True, val=True, n_images=2): + from torch_em.util.debug import check_loader + datasets = normalize_datasets(args.datasets) + if train: + print("Check train loader") + loader = get_loader(args, "train", datasets) + check_loader(loader, n_images) + if val: + print("Check val loader") + loader = get_loader(args, "val", datasets) + check_loader(loader, n_images) + + +if __name__ == "__main__": + parser = torch_em.util.parser_helper(require_input=False) + parser.add_argument("--datasets", "-d", nargs="+", default=DATASETS) + parser.add_argument("--n_rfs", type=int, default=500) + parser.add_argument("--n_threads", type=int, default=32) + 
parser.add_argument("--sampling_strategy", "-s", default="worst_tiles") + args = parser.parse_args() + if args.check: + check(args, n_images=5, val=False) + else: + train_shallow2deep(args) From e3ee157cd5598aa6ad26bd1056c79fea605e3a8c Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sat, 23 Jul 2022 10:11:32 +0200 Subject: [PATCH 23/31] Add kasthuri 3d model training --- .../kasthuri/train_3d.py | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 experiments/mitochondria-segmentation/kasthuri/train_3d.py diff --git a/experiments/mitochondria-segmentation/kasthuri/train_3d.py b/experiments/mitochondria-segmentation/kasthuri/train_3d.py new file mode 100644 index 00000000..4b6a0366 --- /dev/null +++ b/experiments/mitochondria-segmentation/kasthuri/train_3d.py @@ -0,0 +1,42 @@ +import torch_em +from torch_em.model import UNet3d +from torch_em.data.datasets import get_kasthuri_loader + + +def get_loader(args, split): + patch_shape = (64, 256, 256) + + n_samples = 500 if split == "train" else 25 + sampler = torch_em.data.sampler.MinForegroundSampler(min_fraction=0.05, background_id=[-1, 0]) + label_transform = torch_em.transform.label.NoToBackgroundBoundaryTransform(ndim=3, add_binary_target=True) + loader = get_kasthuri_loader( + args.input, split=split, label_transform=label_transform, + batch_size=args.batch_size, patch_shape=patch_shape, + n_samples=n_samples, ndim=3, shuffle=True, + num_workers=12, sampler=sampler + ) + return loader + + +def train_direct(args): + name = "kasthuri-mito-3d" + model = UNet3d(in_channels=1, out_channels=2, final_activation="Sigmoid", depth=4, initial_features=32) + + train_loader = get_loader(args, "train") + val_loader = get_loader(args, "test") + loss = torch_em.loss.DiceLoss() + loss = torch_em.loss.wrapper.LossWrapper( + loss, torch_em.loss.wrapper.MaskIgnoreLabel() + ) + + trainer = torch_em.default_segmentation_trainer( + name, model, train_loader, val_loader, + loss=loss, learning_rate=3.0e-4, 
device=args.device, log_image_interval=50 + ) + trainer.fit(args.n_iterations) + + +if __name__ == "__main__": + parser = torch_em.util.parser_helper() + args = parser.parse_args() + train_direct(args) From ea310243d5d60857705e474417f6410e19dfd5c2 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sat, 23 Jul 2022 10:15:44 +0200 Subject: [PATCH 24/31] Fix more issues in worst_tile sampling --- torch_em/shallow2deep/prepare_shallow2deep.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/torch_em/shallow2deep/prepare_shallow2deep.py b/torch_em/shallow2deep/prepare_shallow2deep.py index c3e67d39..9293cef6 100644 --- a/torch_em/shallow2deep/prepare_shallow2deep.py +++ b/torch_em/shallow2deep/prepare_shallow2deep.py @@ -1,6 +1,7 @@ import os import copy import pickle +import warnings from concurrent import futures from glob import glob from functools import partial @@ -595,13 +596,20 @@ def worst_tiles( samples.append(this_samples) except ValueError: pass - samples = np.concatenate(samples) - # get the features and labels, add from previous rf if specified - features, labels = features[samples], labels[samples] - if accumulate_samples: - features = np.concatenate([last_forest.train_features, features], axis=0) - labels = np.concatenate([last_forest.train_labels, labels], axis=0) + try: + samples = np.concatenate(samples) + features, labels = features[samples], labels[samples] + + # get the features and labels, add from previous rf if specified + if accumulate_samples: + features = np.concatenate([last_forest.train_features, features], axis=0) + labels = np.concatenate([last_forest.train_labels, labels], axis=0) + except ValueError: + features, labels = last_forest.train_features, last_forest.train_labels + warnings.warn( + f"No features were sampled for forest {rf_id} using features of forest {rf_id - forests_per_stage}" + ) return features, labels From 44e7597d1b18e7e8fbc64b53922a089beb3932e0 Mon Sep 17 00:00:00 2001 From: 
Constantin Pape Date: Sat, 23 Jul 2022 16:14:06 +0200 Subject: [PATCH 25/31] Add tests for prepare shallow2deep --- .../shallow2deep/test_prepare_shallow2deep.py | 106 ++++++++++++++++++ torch_em/shallow2deep/prepare_shallow2deep.py | 2 + 2 files changed, 108 insertions(+) create mode 100644 test/shallow2deep/test_prepare_shallow2deep.py diff --git a/test/shallow2deep/test_prepare_shallow2deep.py b/test/shallow2deep/test_prepare_shallow2deep.py new file mode 100644 index 00000000..9f0445e2 --- /dev/null +++ b/test/shallow2deep/test_prepare_shallow2deep.py @@ -0,0 +1,106 @@ +import os +import unittest +from glob import glob +from shutil import rmtree + +import imageio +import h5py +import numpy as np + + +class TestPrepareShallow2Deep(unittest.TestCase): + tmp_folder = "./tmp" + rf_folder = "./tmp/rfs" + + def setUp(self): + os.makedirs(self.tmp_folder, exist_ok=True) + + def tearDown(self): + try: + rmtree(self.tmp_folder) + except OSError: + pass + + def _create_seg_dataset(self): + path = os.path.join(self.tmp_folder, "data.h5") + raw_key = "raw" + label_key = "label" + with h5py.File(path, "w") as f: + f.create_dataset(raw_key, data=np.random.rand(128, 128).astype("float32")) + f.create_dataset(label_key, data=(np.random.rand(128, 128) > 0.5).astype("uint8")) + return path, raw_key, label_key + + def test_prepare_shallow2deep_seg_dataset(self): + from torch_em.shallow2deep import prepare_shallow2deep + path, raw_key, label_key = self._create_seg_dataset() + patch_shape_min = (48, 48) + patch_shape_max = (64, 64) + n_forests = 12 + n_threads = 6 + prepare_shallow2deep( + path, raw_key, path, label_key, patch_shape_min, patch_shape_max, + n_forests, n_threads, self.rf_folder, ndim=2, is_seg_dataset=True + ) + self.assertTrue(os.path.exists(self.rf_folder)) + n_rfs = len(glob(os.path.join(self.rf_folder, "*.pkl"))) + self.assertEqual(n_rfs, n_forests) + + def _create_collection_dataset(self): + n_images = 4 + + im_folder = os.path.join(self.tmp_folder, "images") + 
os.makedirs(im_folder) + im_paths = [] + for i in range(n_images): + path = os.path.join(im_folder, f"{i}.png") + imageio.imwrite(path, np.random.randint(0, 255, size=(96, 96)).astype("uint8")) + im_paths.append(path) + + label_folder = os.path.join(self.tmp_folder, "labels") + os.makedirs(label_folder) + label_paths = [] + for i in range(n_images): + path = os.path.join(label_folder, f"{i}.png") + imageio.imwrite(path, (np.random.rand(96, 96) > 0.5).astype("uint8")) + label_paths.append(path) + + return im_folder, label_folder + + def test_prepare_shallow2deep_image_dataset(self): + from torch_em.shallow2deep import prepare_shallow2deep + im_folder, label_folder = self._create_collection_dataset() + patch_shape_min = (48, 48) + patch_shape_max = (64, 64) + n_forests = 12 + n_threads = 6 + prepare_shallow2deep( + im_folder, "*.png", label_folder, "*.png", patch_shape_min, patch_shape_max, + n_forests, n_threads, self.rf_folder, ndim=2, is_seg_dataset=False + ) + self.assertTrue(os.path.exists(self.rf_folder)) + n_rfs = len(glob(os.path.join(self.rf_folder, "*.pkl"))) + self.assertEqual(n_rfs, n_forests) + + def test_prepare_shallow2deep_advanced(self): + from torch_em.shallow2deep import prepare_shallow2deep_advanced + from torch_em.shallow2deep.prepare_shallow2deep import SAMPLING_STRATEGIES + path, raw_key, label_key = self._create_seg_dataset() + patch_shape_min = (48, 48) + patch_shape_max = (64, 64) + n_forests = 12 + n_threads = 6 + for sampling_strategy in SAMPLING_STRATEGIES: + rf_folder = os.path.join(self.tmp_folder, f"rfs-{sampling_strategy}") + prepare_shallow2deep_advanced( + path, raw_key, path, label_key, patch_shape_min, patch_shape_max, + n_forests, n_threads, rf_folder, + forests_per_stage=4, sample_fraction_per_stage=0.10, + ndim=2, is_seg_dataset=True + ) + self.assertTrue(os.path.exists(rf_folder)) + n_rfs = len(glob(os.path.join(rf_folder, "*.pkl"))) + self.assertEqual(n_rfs, n_forests) + + +if __name__ == "__main__": + unittest.main() diff 
--git a/torch_em/shallow2deep/prepare_shallow2deep.py b/torch_em/shallow2deep/prepare_shallow2deep.py index 9293cef6..6698e731 100644 --- a/torch_em/shallow2deep/prepare_shallow2deep.py +++ b/torch_em/shallow2deep/prepare_shallow2deep.py @@ -158,6 +158,8 @@ def _check_patch(patch_shape): if isinstance(raw_paths, str): raw_paths, label_paths = _get_paths(raw_paths, raw_key, label_paths, label_key, roi) ds = RFImageCollectionDataset(raw_paths, label_paths, patch_shape=patch_shape_min, **kwargs) + ds.patch_shape_min = patch_shape_min + ds.patch_shape_max = patch_shape_max elif raw_key is None: assert label_key is None assert isinstance(raw_paths, (list, tuple)) and isinstance(label_paths, (list, tuple)) From 8c71a160d471beb60bd6b21fce697bbfcb86d475 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sat, 23 Jul 2022 19:59:49 +0200 Subject: [PATCH 26/31] Add more s2d tests, implement s2d training with image collectiond dataset --- .../test_shallow2deep_training.py | 112 ++++++++++++++++++ torch_em/shallow2deep/shallow2deep_dataset.py | 83 ++++++++++++- 2 files changed, 190 insertions(+), 5 deletions(-) create mode 100644 test/shallow2deep/test_shallow2deep_training.py diff --git a/test/shallow2deep/test_shallow2deep_training.py b/test/shallow2deep/test_shallow2deep_training.py new file mode 100644 index 00000000..f3bf507a --- /dev/null +++ b/test/shallow2deep/test_shallow2deep_training.py @@ -0,0 +1,112 @@ +import os +import unittest +from glob import glob +from shutil import rmtree + +import imageio +import h5py +import numpy as np +import torch_em +from torch_em.model import UNet2d + + +class TestShallow2DeepTraining(unittest.TestCase): + tmp_folder = "./tmp" + + def tearDown(self): + try: + rmtree(self.tmp_folder) + except OSError: + pass + try: + rmtree("./checkpoints") + except OSError: + pass + try: + rmtree("./logs") + except OSError: + pass + + def _create_seg_dataset(self): + os.makedirs(self.tmp_folder, exist_ok=True) + path = os.path.join(self.tmp_folder, 
"data.h5") + raw_key = "raw" + label_key = "label" + with h5py.File(path, "w") as f: + f.create_dataset(raw_key, data=np.random.rand(128, 128).astype("float32")) + f.create_dataset(label_key, data=(np.random.rand(128, 128) > 0.5).astype("uint8")) + return path, raw_key, label_key + + def _create_collection_dataset(self): + n_images = 4 + + im_folder = os.path.join(self.tmp_folder, "images") + os.makedirs(im_folder) + im_paths = [] + for i in range(n_images): + path = os.path.join(im_folder, f"{i}.png") + imageio.imwrite(path, np.random.randint(0, 255, size=(96, 96)).astype("uint8")) + im_paths.append(path) + + label_folder = os.path.join(self.tmp_folder, "labels") + os.makedirs(label_folder) + label_paths = [] + for i in range(n_images): + path = os.path.join(label_folder, f"{i}.png") + imageio.imwrite(path, (np.random.rand(96, 96) > 0.5).astype("uint8")) + label_paths.append(path) + + return im_folder, label_folder + + def test_shallow2deep_training_seg_ds(self): + from torch_em.shallow2deep import prepare_shallow2deep, get_shallow2deep_loader + path, raw_key, label_key = self._create_seg_dataset() + name = "s2d-seg" + rf_folder = os.path.join(self.tmp_folder, "rfs") + prepare_shallow2deep(path, raw_key, path, label_key, + patch_shape_min=(48, 48), + patch_shape_max=(96, 96), + n_forests=12, n_threads=6, + output_folder=rf_folder, ndim=2) + rf_paths = glob(os.path.join(rf_folder, "*.pkl")) + loader = get_shallow2deep_loader(path, raw_key, path, label_key, + rf_paths, batch_size=1, patch_shape=(64, 64), + n_samples=20) + net = UNet2d( + in_channels=1, out_channels=1, initial_features=4, gain=2, depth=2, + final_activation="Sigmoid" + ) + trainer = torch_em.default_segmentation_trainer(name, net, loader, loader) + trainer.fit(40) + self.assertTrue(os.path.exists(os.path.join( + "./checkpoints", name, "latest.pt" + ))) + + def test_shallow2deep_training_image_ds(self): + from torch_em.shallow2deep import prepare_shallow2deep, get_shallow2deep_loader + im_folder, 
label_folder = self._create_collection_dataset() + name = "s2d-im" + rf_folder = os.path.join(self.tmp_folder, "rfs") + prepare_shallow2deep(im_folder, "*.png", label_folder, "*.png", + patch_shape_min=(48, 48), + patch_shape_max=(96, 96), + n_forests=12, n_threads=6, + output_folder=rf_folder, ndim=2, + is_seg_dataset=False) + rf_paths = glob(os.path.join(rf_folder, "*.pkl")) + loader = get_shallow2deep_loader(im_folder, "*.png", label_folder, "*.png", + rf_paths, batch_size=1, patch_shape=(64, 64), + n_samples=20, is_seg_dataset=False) + net = UNet2d( + in_channels=1, out_channels=1, initial_features=4, gain=2, depth=2, + final_activation="Sigmoid" + ) + trainer = torch_em.default_segmentation_trainer(name, net, loader, loader) + trainer.fit(40) + self.assertTrue(os.path.exists(os.path.join( + "./checkpoints", name, "latest.pt" + ))) + + +if __name__ == "__main__": + unittest.main() diff --git a/torch_em/shallow2deep/shallow2deep_dataset.py b/torch_em/shallow2deep/shallow2deep_dataset.py index d7a3e6e3..42a10681 100644 --- a/torch_em/shallow2deep/shallow2deep_dataset.py +++ b/torch_em/shallow2deep/shallow2deep_dataset.py @@ -1,17 +1,19 @@ +import os import pickle import warnings +from glob import glob import numpy as np import torch from torch_em.segmentation import (check_paths, is_segmentation_dataset, get_data_loader, get_raw_transform, samples_to_datasets, _get_default_transform) -from torch_em.data import ConcatDataset, SegmentationDataset +from torch_em.data import ConcatDataset, ImageCollectionDataset, SegmentationDataset from .prepare_shallow2deep import _get_filters, _apply_filters from ..util import ensure_tensor_with_channels, ensure_spatial_array -class Shallow2DeepDataset(SegmentationDataset): +class _Shallow2DeepBase: _rf_paths = None _filter_config = None @@ -87,6 +89,8 @@ def _predict_rf_anisotropic(self, raw): return prediction + +class Shallow2DeepDataset(SegmentationDataset, _Shallow2DeepBase): def __getitem__(self, index): assert 
self._rf_paths is not None raw, labels = self._get_sample(index) @@ -125,7 +129,44 @@ def __getitem__(self, index): return prediction, labels -def _load_shallow2deep_dataset(raw_paths, raw_key, label_paths, label_key, rf_paths, rf_channels, ndim, **kwargs): +class Shallow2DeepImageCollectionDataset(ImageCollectionDataset, _Shallow2DeepBase): + def __getitem__(self, index): + raw, labels = self._get_sample(index) + initial_label_dtype = labels.dtype + + if self.raw_transform is not None: + raw = self.raw_transform(raw) + + if self.label_transform is not None: + labels = self.label_transform(labels) + + if self.transform is not None: + raw, labels = self.transform(raw, labels) + + # support enlarging bounding box here as well (for affinity transform) ? + if self.label_transform2 is not None: + labels = ensure_spatial_array(labels, self.ndim, dtype=initial_label_dtype) + labels = self.label_transform2(labels) + + if isinstance(raw, (list, tuple)): # this can be a list or tuple due to transforms + assert len(raw) == 1 + raw = raw[0] + raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype) + if raw.shape[0] > 1: + raise NotImplementedError( + f"Shallow2Deep training not implemented for multi-channel input yet; got {raw.shape[0]} channels" + ) + + # NOTE we assume single channel raw data here; this needs to be changed for multi-channel + prediction = self._predict_rf(raw[0].numpy()) + prediction = ensure_tensor_with_channels(prediction, ndim=self._ndim, dtype=self.dtype) + labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype) + return prediction, labels + + +def _load_shallow2deep_segmentation_dataset( + raw_paths, raw_key, label_paths, label_key, rf_paths, rf_channels, ndim, **kwargs +): rois = kwargs.pop("rois", None) filter_config = kwargs.pop("filter_config", None) if ndim == "anisotropic": @@ -167,6 +208,32 @@ def _load_shallow2deep_dataset(raw_paths, raw_key, label_paths, label_key, rf_pa return ds +def 
_load_shallow2deep_image_collection_dataset( + raw_paths, raw_key, label_paths, label_key, rf_paths, rf_channels, patch_shape, **kwargs +): + if isinstance(raw_paths, str): + assert isinstance(label_paths, str) + raw_file_paths = glob(os.path.join(raw_paths, raw_key)) + raw_file_paths.sort() + label_file_paths = glob(os.path.join(label_paths, label_key)) + label_file_paths.sort() + ds = Shallow2DeepImageCollectionDataset(raw_file_paths, label_file_paths, patch_shape, **kwargs) + elif isinstance(raw_paths, list) and raw_key is None: + assert isinstance(label_paths, list) + assert label_key is None + assert all(os.path.exists(pp) for pp in raw_paths) + assert all(os.path.exists(pp) for pp in label_paths) + ds = Shallow2DeepDataset(raw_paths, label_paths, patch_shape, **kwargs) + else: + raise NotImplementedError + + filter_config = kwargs.pop("filter_config", None) + ds.rf_paths = rf_paths + ds.filter_config = filter_config + ds.rf_channels = rf_channels + return ds + + def get_shallow2deep_dataset( raw_paths, raw_key, @@ -204,7 +271,7 @@ def get_shallow2deep_dataset( ) if is_seg_dataset: - ds = _load_shallow2deep_dataset( + ds = _load_shallow2deep_segmentation_dataset( raw_paths, raw_key, label_paths, @@ -224,7 +291,13 @@ def get_shallow2deep_dataset( rf_channels=rf_channels, ) else: - raise NotImplementedError("Image collection dataset for shallow2deep not implemented yet.") + if rois is not None: + raise NotImplementedError + ds = _load_shallow2deep_image_collection_dataset( + raw_paths, raw_key, label_paths, label_key, rf_paths, rf_channels, patch_shape, + raw_transform=raw_transform, label_transform=label_transform, + transform=transform, dtype=dtype, n_samples=n_samples, + ) return ds From 2bbafa899e2adeeccbdfdf0fdc962dd890501e2e Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sun, 24 Jul 2022 13:33:03 +0200 Subject: [PATCH 27/31] Update training scripts for 2d lm membranes --- .../lm-membrane/check_rf_loaders.py | 81 +++++++++++++ 
.../lm-membrane/train_membrane_2d.py | 107 +++++++++++++----- 2 files changed, 160 insertions(+), 28 deletions(-) create mode 100644 experiments/shallow2deep/lm-membrane/check_rf_loaders.py diff --git a/experiments/shallow2deep/lm-membrane/check_rf_loaders.py b/experiments/shallow2deep/lm-membrane/check_rf_loaders.py new file mode 100644 index 00000000..e3fcf955 --- /dev/null +++ b/experiments/shallow2deep/lm-membrane/check_rf_loaders.py @@ -0,0 +1,81 @@ +import argparse +import os +from glob import glob + +import torch_em +import torch_em.shallow2deep as shallow2deep +from torch_em.data.datasets.mouse_embryo import _require_embryo_data +from torch_em.data.datasets.plantseg import _require_plantseg_data +from torch_em.shallow2deep.prepare_shallow2deep import _prepare_shallow2deep + +DATA_ROOT = "/scratch/pape/s2d-lm-boundaries" + + +# check the rf loader to see if samplers for the complicated 3d datasets work +# - mouse-embryo: +# - 2d: default sampler is fine +# - 3d: TODO +# - ovules: +# - 2d: default sampler is fine +# - 3d: TODO +# - root: +# - 2d: default sampler is fine +# - 3d: TODO +# + + +def require_ds(dataset): + os.makedirs(DATA_ROOT, exist_ok=True) + data_path = os.path.join(DATA_ROOT, dataset) + if dataset == "mouse-embryo": + _require_embryo_data(data_path, True) + paths = glob(os.path.join(data_path, "Membrane", "train", "*.h5")) + raw_key, label_key = "raw", "label" + elif dataset == "ovules": + _require_plantseg_data(data_path, True, "ovules", "train") + paths = glob(os.path.join(data_path, "ovules_train", "*.h5")) + raw_key, label_key = "raw", "label" + elif dataset == "root": + _require_plantseg_data(data_path, True, "root", "train") + paths = glob(os.path.join(data_path, "root_train", "*.h5")) + raw_key, label_key = "raw", "label" + return paths, raw_key, label_key + + +def check_rf_loader(dataset, ndim): + assert dataset in ("mouse-embryo", "ovules", "root") + paths, raw_key, label_key = require_ds(dataset) + n_images = 16 + if ndim == 2: + 
patch_shape_min = [1, 248, 248] + patch_shape_max = [1, 256, 256] + else: + pass # TODO + # TODO sampler + raw_transform = torch_em.transform.raw.normalize + label_transform = shallow2deep.BoundaryTransform(ndim=ndim) + ds, _ = _prepare_shallow2deep( + paths, raw_key, paths, label_key, + patch_shape_min, patch_shape_max, + n_forests=n_images, ndim=ndim, + raw_transform=raw_transform, label_transform=label_transform, + rois=None, filter_config=None, sampler=None, + is_seg_dataset=True, + ) + + print("Start viewer") + import napari + for i in range(len(ds)): + x, y = ds[i] + v = napari.Viewer() + v.add_image(x.numpy().squeeze(), name="data") + v.add_image(y.numpy().squeeze(), name="target") + napari.run() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--dataset", "-d", required=True) + parser.add_argument("-n", "--ndim", default=2, type=int) + args = parser.parse_args() + check_rf_loader(args.dataset, args.ndim) diff --git a/experiments/shallow2deep/lm-membrane/train_membrane_2d.py b/experiments/shallow2deep/lm-membrane/train_membrane_2d.py index 8f6210a7..810d36ad 100644 --- a/experiments/shallow2deep/lm-membrane/train_membrane_2d.py +++ b/experiments/shallow2deep/lm-membrane/train_membrane_2d.py @@ -9,6 +9,7 @@ from torch_em.data.datasets.covid_if import _download_covid_if from torch_em.data.datasets.mouse_embryo import _require_embryo_data from torch_em.data.datasets.plantseg import _require_plantseg_data +from torch_em.data.datasets.livecell import _download_livecell_images, _download_livecell_annotations # any more publicly available datasets? 
@@ -24,10 +25,10 @@ def normalize_datasets(datasets): return datasets -# def _require_livecell_data(data_path, split): -# _download_livecell_images(data_path, True) -# image_paths, label_paths = _download_livecell_annotations(daa_path, split, True) -# return image_paths, label_paths +def _require_livecell_data(data_path, split): + _download_livecell_images(data_path, True) + image_paths, label_paths = _download_livecell_annotations(data_path, split, True) + return image_paths, label_paths def require_ds(dataset): @@ -38,26 +39,31 @@ def require_ds(dataset): paths = glob(os.path.join(data_path, "*.h5")) paths.sort() paths = paths[:-5] + label_paths = paths raw_key, label_key = "raw/serum_IgG/s0", "labels/cells/s0" elif dataset == "livecell": - raise NotImplementedError + paths, label_paths = _require_livecell_data(data_path, "train") + raw_key, label_key = None, None elif dataset == "mouse-embryo": _require_embryo_data(data_path, True) paths = glob(os.path.join(data_path, "Membrane", "train", "*.h5")) + label_paths = paths raw_key, label_key = "raw", "label" elif dataset == "ovules": _require_plantseg_data(data_path, True, "ovules", "train") paths = glob(os.path.join(data_path, "ovules_train", "*.h5")) + label_paths = paths raw_key, label_key = "raw", "label" elif dataset == "root": _require_plantseg_data(data_path, True, "root", "train") paths = glob(os.path.join(data_path, "root_train", "*.h5")) + label_paths = paths raw_key, label_key = "raw", "label" - return paths, raw_key, label_key + return paths, label_paths, raw_key, label_key -def require_rfs_ds(dataset, n_rfs): - out_folder = os.path.join(DATA_ROOT, "rfs2d", dataset) +def require_rfs_ds(dataset, n_rfs, sampling_strategy): + out_folder = os.path.join(DATA_ROOT, f"rfs2d-{sampling_strategy}", dataset) os.makedirs(out_folder, exist_ok=True) if len(glob(os.path.join(out_folder, "*.pkl"))) == n_rfs: return @@ -73,26 +79,33 @@ def require_rfs_ds(dataset, n_rfs): raw_transform = torch_em.transform.raw.normalize 
label_transform = shallow2deep.BoundaryTransform(ndim=2) - paths, raw_key, label_key = require_ds(dataset) + paths, label_paths, raw_key, label_key = require_ds(dataset) + is_seg_dataset = True + if dataset == "livecell": + is_seg_dataset = False sampler = torch_em.data.MinForegroundSampler(min_fraction=0.05, background_id=1 if dataset == "root" else 0) shallow2deep.prepare_shallow2deep_advanced( - raw_paths=paths, raw_key=raw_key, label_paths=paths, label_key=label_key, + raw_paths=paths, raw_key=raw_key, label_paths=label_paths, label_key=label_key, patch_shape_min=patch_shape_min, patch_shape_max=patch_shape_max, n_forests=args.n_rfs, n_threads=args.n_threads, forests_per_stage=25, sample_fraction_per_stage=0.10, output_folder=out_folder, ndim=2, raw_transform=raw_transform, label_transform=label_transform, - is_seg_dataset=True, sampler=sampler + is_seg_dataset=is_seg_dataset, sampler=sampler, sampling_strategy=sampling_strategy, ) -def require_rfs(datasets, n_rfs): +def require_rfs(datasets, n_rfs, sampling_strategy): for ds in datasets: - require_rfs_ds(ds, n_rfs) + require_rfs_ds(ds, n_rfs, sampling_strategy) -def get_ds(file_pattern, rf_pattern, n_samples, is3d_data=True, path_selection=None, raw_key="raw", label_key="label"): +def get_ds( + file_pattern, rf_pattern, n_samples, is3d_data=True, path_selection=None, + raw_key="raw", label_key="label", sampler=None +): + raw_transform = torch_em.transform.raw.normalize label_transform = torch_em.transform.BoundaryTransform(ndim=2, add_binary_target=False) patch_shape = [1, 256, 256] if is3d_data else [256, 256] paths = glob(file_pattern) @@ -105,8 +118,11 @@ def get_ds(file_pattern, rf_pattern, n_samples, is3d_data=True, path_selection=N assert len(rf_paths) > 0 return shallow2deep.shallow2deep_dataset.get_shallow2deep_dataset( paths, raw_key, paths, label_key, rf_paths, - patch_shape=patch_shape, label_transform=label_transform, - n_samples=n_samples, ndim=2 + patch_shape=patch_shape, + 
raw_transform=raw_transform, + label_transform=label_transform, + n_samples=n_samples, ndim=2, + sampler=sampler, ) @@ -116,31 +132,43 @@ def get_loader(args, split, dataset_names): if "covid-if" in dataset_names: ds_name = "covid-if" file_pattern = os.path.join(DATA_ROOT, ds_name, "*.h5") - rf_pattern = os.path.join(DATA_ROOT, "rfs2d", ds_name, "*.pkl") + rf_pattern = os.path.join(DATA_ROOT, f"rfs2d-{args.sampling_strategy}", ds_name, "*.pkl") raw_key, label_key = "raw/serum_IgG/s0", "labels/cells/s0" datasets.append(get_ds( file_pattern, rf_pattern, n_samples, is3d_data=False, raw_key=raw_key, label_key=label_key, path_selection=np.s_[:-5] if split == "train" else np.s_[-5:] )) if "livecell" in dataset_names: - raise NotImplementedError + image_paths, label_paths = _require_livecell_data(os.path.join(DATA_ROOT, "livecell"), split) + rf_paths = glob(os.path.join(DATA_ROOT, f"rfs2d-{args.sampling_strategy}", "livecell", "*.pkl")) + ds = shallow2deep.shallow2deep_dataset.get_shallow2deep_dataset( + image_paths, None, label_paths, None, rf_paths, + patch_shape=[256, 256], is_seg_dataset=False, + label_transform=torch_em.transform.BoundaryTransform(ndim=2), + raw_transform=torch_em.transform.raw.normalize, + ) + datasets.append(ds) if "mouse-embryo" in dataset_names: ds_name = "mouse-embryo" file_pattern = os.path.join(DATA_ROOT, ds_name, "Membrane", split, "*.h5") - rf_pattern = os.path.join(DATA_ROOT, "rfs2d", ds_name, "*.pkl") - datasets.append(get_ds(file_pattern, rf_pattern, n_samples)) + rf_pattern = os.path.join(DATA_ROOT, f"rfs2d-{args.sampling_strategy}", ds_name, "*.pkl") + sampler = torch_em.data.MinForegroundSampler(min_fraction=0.4, background_id=0) + datasets.append(get_ds(file_pattern, rf_pattern, n_samples, sampler=sampler)) if "ovules" in dataset_names: ds_name = "ovules" _require_plantseg_data(os.path.join(DATA_ROOT, ds_name), True, ds_name, split) file_pattern = os.path.join(DATA_ROOT, ds_name, f"ovules_{split}", "*.h5") - rf_pattern = 
os.path.join(DATA_ROOT, "rfs2d", ds_name, "*.pkl") - datasets.append(get_ds(file_pattern, rf_pattern, n_samples)) + rf_pattern = os.path.join(DATA_ROOT, f"rfs2d-{args.sampling_strategy}", ds_name, "*.pkl") + sampler = torch_em.data.MinForegroundSampler(min_fraction=0.4, background_id=0) + datasets.append(get_ds(file_pattern, rf_pattern, n_samples, sampler=sampler)) if "root" in dataset_names: ds_name = "root" _require_plantseg_data(os.path.join(DATA_ROOT, ds_name), True, ds_name, split) file_pattern = os.path.join(DATA_ROOT, ds_name, f"root_{split}", "*.h5") - rf_pattern = os.path.join(DATA_ROOT, "rfs2d", ds_name, "*.pkl") - datasets.append(get_ds(file_pattern, rf_pattern, n_samples)) + rf_pattern = os.path.join(DATA_ROOT, f"rfs2d-{args.sampling_strategy}", ds_name, "*.pkl") + sampler = torch_em.data.MinForegroundSampler(min_fraction=0.4, background_id=1) + datasets.append(get_ds(file_pattern, rf_pattern, n_samples, sampler=sampler)) + assert len(datasets) > 0 ds = torch_em.data.concat_dataset.ConcatDataset(*datasets) loader = torch.utils.data.DataLoader( ds, shuffle=True, batch_size=args.batch_size, num_workers=12 @@ -151,8 +179,8 @@ def get_loader(args, split, dataset_names): def train_shallow2deep(args): datasets = normalize_datasets(args.datasets) - name = f"s2d-lm-membrane-{'_'.join(datasets)}-2d" - require_rfs(datasets, args.n_rfs) + name = f"s2d-lm-membrane-{'_'.join(datasets)}-2d-{args.sampling_strategy}" + require_rfs(datasets, args.n_rfs, args.sampling_strategy) train_loader = get_loader(args, "train", datasets) val_loader = get_loader(args, "val", datasets) @@ -166,10 +194,33 @@ def train_shallow2deep(args): trainer.fit(args.n_iterations) +# after looking at samples from initial random forests: +# - covid-if: random forests look horrible (worst_tiles)! 
-> can't use it yet +# - livecell: random forests only partially ok -> don't use it yet +# - mouse-embryo: looks ok (worst_tiles) +# - ovules: looks ok (worst_tiles) +# - root: looks ok (worst_tiles) +def check(args, train=True, val=True, n_images=2): + from torch_em.util.debug import check_loader + datasets = normalize_datasets(args.datasets) + if train: + print("Check train loader") + loader = get_loader(args, "train", datasets) + check_loader(loader, n_images) + if val: + print("Check val loader") + loader = get_loader(args, "val", datasets) + check_loader(loader, n_images) + + if __name__ == "__main__": parser = torch_em.util.parser_helper(require_input=False, default_batch_size=4) parser.add_argument("--datasets", "-d", nargs="+", default=DATASETS) - parser.add_argument("--n_rfs", type=int, default=150, help="Number of foersts per dataset") + parser.add_argument("--n_rfs", type=int, default=500, help="Number of forests per dataset") parser.add_argument("--n_threads", type=int, default=32) + parser.add_argument("--sampling_strategy", "-s", default="worst_tiles") args = parser.parse_args() - train_shallow2deep(args) + if args.check: + check(args, n_images=8) + else: + train_shallow2deep(args) From 4a96fa030b32db8925911b246aea7cb9d9d0d30e Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sun, 24 Jul 2022 13:33:57 +0200 Subject: [PATCH 28/31] Fix minor issues in prefab datasets --- torch_em/data/datasets/livecell.py | 3 +++ torch_em/data/datasets/mouse_embryo.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/torch_em/data/datasets/livecell.py b/torch_em/data/datasets/livecell.py index ad8ca3d2..e605cf4b 100644 --- a/torch_em/data/datasets/livecell.py +++ b/torch_em/data/datasets/livecell.py @@ -64,6 +64,9 @@ def _create_segmentations_from_annotations(annotation_file, image_folder, seg_fo file_name = image_metadata["file_name"] sub_folder = file_name.split("_")[0] image_path = os.path.join(image_folder, sub_folder, file_name) + # something 
changed in the image layout? we keep the old version around in case this chagnes back... + if not os.path.exists(image_path): + image_path = os.path.join(image_folder, file_name) assert os.path.exists(image_path), image_path image_paths.append(image_path) diff --git a/torch_em/data/datasets/mouse_embryo.py b/torch_em/data/datasets/mouse_embryo.py index 09f4bbcb..1b113723 100644 --- a/torch_em/data/datasets/mouse_embryo.py +++ b/torch_em/data/datasets/mouse_embryo.py @@ -16,7 +16,7 @@ def _require_embryo_data(path, download): download_source(tmp_path, URL, download, CHECKSUM) unzip(tmp_path, path, remove=True) # remove empty volume - os.remove(os.path.join(path, "Membrane", "fused_paral_stack0_chan2_tp00073_raw_crop_bg_noise.h5")) + os.remove(os.path.join(path, "Membrane", "train", "fused_paral_stack0_chan2_tp00073_raw_crop_bg_noise.h5")) def get_mouse_embryo_loader( From 03fe64068a9fd58026adcd31fc6451335482e7e6 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sun, 24 Jul 2022 13:34:33 +0200 Subject: [PATCH 29/31] Enable worst_tiles sampling with ignore label --- torch_em/shallow2deep/prepare_shallow2deep.py | 39 +++++++++++++++---- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/torch_em/shallow2deep/prepare_shallow2deep.py b/torch_em/shallow2deep/prepare_shallow2deep.py index 6698e731..637c1e62 100644 --- a/torch_em/shallow2deep/prepare_shallow2deep.py +++ b/torch_em/shallow2deep/prepare_shallow2deep.py @@ -281,7 +281,7 @@ def _balance_labels(labels, mask): return mask -def _get_features_and_labels(raw, labels, filters_and_sigmas, balance_labels): +def _get_features_and_labels(raw, labels, filters_and_sigmas, balance_labels, return_mask=False): # find the mask for where we compute filters and labels # by default we exclude everything that has label -1 assert labels.shape == raw.shape @@ -293,7 +293,10 @@ def _get_features_and_labels(raw, labels, filters_and_sigmas, balance_labels): features = _apply_filters_with_mask(raw, filters_and_sigmas, 
mask) assert features.ndim == 2 assert len(features) == len(labels) - return features, labels + if return_mask: + return features, labels, mask + else: + return features, labels def _prepare_shallow2deep( @@ -527,6 +530,7 @@ def worst_tiles( forests, forests_per_stage, sample_fraction_per_stage, img_shape, + mask, tile_shape=[25, 25], smoothing_sigma=None, accumulate_samples=True, @@ -551,19 +555,37 @@ def worst_tiles( assert len(diff) == len(features) # reshape diff to image shape - diff_img = diff.reshape(img_shape + (-1,)) + # we need to also take into account the mask here, and if we apply any masking + # because we can't directly reshape if we have it + if mask.sum() != mask.size: + # get the diff image + diff_img = np.zeros(img_shape + diff.shape[-1:], dtype=diff.dtype) + diff_img[mask] = diff + # inflate the features + full_features = np.zeros((mask.size,) + features.shape[-1:], dtype=features.dtype) + full_features[mask.ravel()] = features + features = full_features + # inflate the labels (with -1 so this will not be sampled) + full_labels = np.full(mask.size, -1, dtype="int8") + full_labels[mask.ravel()] = labels + labels = full_labels + else: + diff_img = diff.reshape(img_shape + (-1,)) + + # get the number of classes (not counting ignore label) + class_ids = np.unique(labels) + nc = len(class_ids) - 1 if -1 in class_ids else len(class_ids) # sample in a class balanced way - nc = len(np.unique(labels)) n_samples_class = int(sample_fraction_per_stage * len(features)) // nc samples = [] for class_id in range(nc): # smooth either with gaussian or 1-kernel if smoothing_sigma: - diff_img_smooth = gaussian_filter(diff_img[..., class_id], smoothing_sigma, mode='constant') + diff_img_smooth = gaussian_filter(diff_img[..., class_id], smoothing_sigma, mode="constant") else: kernel = np.ones(tile_shape) - diff_img_smooth = convolve(diff_img[..., class_id], kernel, mode='constant') + diff_img_smooth = convolve(diff_img[..., class_id], kernel, mode="constant") # get 
training samples based on tiles around maxima of the label-prediction diff # do this in a class-specific way to ensure that each class is sampled @@ -696,14 +718,15 @@ def _train_rf(rf_id): current_kwargs["img_shape"] = raw.shape # only balance samples for the first (densely trained) rfs - features, labels = _get_features_and_labels( - raw, labels, filters_and_sigmas, balance_labels=False + features, labels, mask = _get_features_and_labels( + raw, labels, filters_and_sigmas, balance_labels=False, return_mask=True ) if forests: # we have forests: apply the sampling strategy features, labels = sampling_strategy( features, labels, rf_id, forests, forests_per_stage, sample_fraction_per_stage, + mask=mask, **current_kwargs, ) else: # sample randomly From c14e5b214a0cb4985a1ccfb2570fbebe43a8dcaf Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sun, 24 Jul 2022 13:34:57 +0200 Subject: [PATCH 30/31] Fix typo --- torch_em/shallow2deep/shallow2deep_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_em/shallow2deep/shallow2deep_dataset.py b/torch_em/shallow2deep/shallow2deep_dataset.py index 42a10681..f165b2e5 100644 --- a/torch_em/shallow2deep/shallow2deep_dataset.py +++ b/torch_em/shallow2deep/shallow2deep_dataset.py @@ -223,7 +223,7 @@ def _load_shallow2deep_image_collection_dataset( assert label_key is None assert all(os.path.exists(pp) for pp in raw_paths) assert all(os.path.exists(pp) for pp in label_paths) - ds = Shallow2DeepDataset(raw_paths, label_paths, patch_shape, **kwargs) + ds = Shallow2DeepImageCollectionDataset(raw_paths, label_paths, patch_shape, **kwargs) else: raise NotImplementedError From 4b6598119c618a9b0d88162b7cf381aa6bddbb0b Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sun, 24 Jul 2022 21:54:28 +0200 Subject: [PATCH 31/31] Update s2d mito experiments --- .../shallow2deep/em-mitochondria/README.md | 18 ++++++++++++++++++ .../shallow2deep/em-mitochondria/evaluation.py | 11 +++++++---- 
.../em-mitochondria/export_enhancer.py | 3 ++- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/experiments/shallow2deep/em-mitochondria/README.md b/experiments/shallow2deep/em-mitochondria/README.md index 6769d394..02ca1617 100644 --- a/experiments/shallow2deep/em-mitochondria/README.md +++ b/experiments/shallow2deep/em-mitochondria/README.md @@ -28,6 +28,24 @@ All scores are measured with a soft dice score. - `random_points`: random points sampled in each stage, points are accumulated over the stages - `worst_tiles`: training samples are taken from worst tile predictions +| method | few-labels | medium-labels | many-labels | +|:-----------------------------------|-------------:|----------------:|--------------:| +| rf3d | 0.326 | 0.328 | 0.385 | +| 2d-random_points | 0.593 | 0.693 | 0.782 | +| 2d-uncertain_worst_points | 0.613 | 0.777 | 0.794 | +| 2d-vanilla | 0.639 | 0.717 | 0.764 | +| 2d-worst_points | 0.549 | 0.711 | 0.730 | +| 2d-worst_tiles | 0.661 | 0.796 | 0.828 | +| direct_2d | 0.849 | nan | nan | +| anisotropic-random_points | 0.521 | 0.566 | 0.671 | +| anisotropic-uncertain_worst_points | 0.530 | 0.616 | 0.711 | +| anisotropic-vanilla | 0.576 | 0.660 | 0.749 | +| anisotropic-worst_points | 0.458 | 0.568 | 0.600 | +| anisotropic-worst_tiles | 0.614 | 0.728 | 0.788 | +| direct_anisotropic | 0.467 | nan | nan | +| 3d-random_points | 0.344 | 0.381 | 0.353 | +| 3d-worst_tiles | 0.385 | 0.472 | 0.504 | + ### V5 diff --git a/experiments/shallow2deep/em-mitochondria/evaluation.py b/experiments/shallow2deep/em-mitochondria/evaluation.py index 65bff4f2..1b5bf4b1 100644 --- a/experiments/shallow2deep/em-mitochondria/evaluation.py +++ b/experiments/shallow2deep/em-mitochondria/evaluation.py @@ -152,7 +152,7 @@ def require_net_3d(data_path, model_path, model_name, save_path, tiling): def get_enhancers(root): names = [os.path.basename(path) for path in glob(os.path.join(root, "s2d-em*"))] - enhancers_2d, enhancers_anisotropic = {}, {} + enhancers_2d, 
enhancers_anisotropic, enhancers_3d = {}, {}, {} for name in names: parts = name.split("-") sampling_strategy, dim = parts[-1], parts[-2] @@ -162,9 +162,12 @@ def get_enhancers(root): enhancers_anisotropic[f"{dim}-{sampling_strategy}"] = path elif dim == "2d": enhancers_2d[f"{dim}-{sampling_strategy}"] = path + elif dim == "3d": + enhancers_3d[f"{dim}-{sampling_strategy}"] = path assert len(enhancers_2d) > 0 assert len(enhancers_anisotropic) > 0 - return enhancers_2d, enhancers_anisotropic + assert len(enhancers_3d) > 0 + return enhancers_2d, enhancers_anisotropic, enhancers_3d def run_evaluation(data_path, save_path, eval_path, label_key="label"): @@ -254,7 +257,7 @@ def evaluate_lucchi(version): "medium-labels": os.path.join(rf_folder, "2.ilp"), "many-labels": os.path.join(rf_folder, "3.ilp"), } - enhancers_2d, enhancers_anisotropic = get_enhancers(f"./bio-models/v{version}") + enhancers_2d, enhancers_anisotropic, enhancers_3d = get_enhancers(f"./bio-models/v{version}") net2d = "./bio-models/v2/DirectModel/MitchondriaEMSegmentation2D.zip" net_aniso = "./bio-models/v3/DirectModel/mitochondriaemsegmentationboundarymodel_pytorch_state_dict.zip" @@ -262,7 +265,7 @@ def evaluate_lucchi(version): require_enhancers_2d(rfs, enhancers_2d, save_path) require_enhancers_3d(rfs, enhancers_anisotropic, save_path) - # TODO add the 3d enhancers + require_enhancers_3d(rfs, enhancers_3d, save_path) require_net_2d(data_path, net2d, "direct_2d", save_path) tiling_aniso = { diff --git a/experiments/shallow2deep/em-mitochondria/export_enhancer.py b/experiments/shallow2deep/em-mitochondria/export_enhancer.py index 6f03277c..5e7a254c 100644 --- a/experiments/shallow2deep/em-mitochondria/export_enhancer.py +++ b/experiments/shallow2deep/em-mitochondria/export_enhancer.py @@ -106,7 +106,7 @@ def export_enhancer(input_, is3d, checkpoint=None, version=None, name=None): input_data = create_input_anisotropic(input_, rf_path) is3d = True elif is3d: - assert False, "Currently don't have 3d rfs 
for mitos" + rf_path = "/scratch/pape/s2d-mitochondria/rfs3d-worst_tiles/kasthuri/rf_0499.pkl" input_data = create_input_3d(input_, rf_path) else: rf_path = "/scratch/pape/s2d-mitochondria/rfs2d-worst_points/mitoem/rf_0499.pkl" @@ -175,6 +175,7 @@ def _get_ndim(x): for ckpt in checkpoints: name = os.path.basename(ckpt) if(os.path.exists(os.path.join(out_folder, name))): + print("Already exported::", name) continue parts = name.split("-") is3d = _get_ndim(parts[-2])