diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index 6e9a8614..7a1c8b2e 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -1,4 +1,3 @@
----
 name: test
 on: [push, pull_request]
diff --git a/experiments/livecell/check_cell_type_performance.py b/experiments/livecell/check_cell_type_performance.py
new file mode 100644
index 00000000..3b7a1233
--- /dev/null
+++ b/experiments/livecell/check_cell_type_performance.py
@@ -0,0 +1,66 @@
+import os
+from glob import glob
+
+import numpy as np
+import pandas as pd
+try:
+    import imageio.v2 as imageio
+except ImportError:
+    import imageio
+
+from tqdm import tqdm
+from xarray import DataArray
+from elf.evaluation import dice_score
+
+
+def run_prediction(input_folder, output_folder):
+    import bioimageio.core
+    os.makedirs(output_folder, exist_ok=True)
+
+    inputs = glob(os.path.join(input_folder, "*.tif"))
+    model = bioimageio.core.load_resource_description("10.5281/zenodo.5869899")
+
+    with bioimageio.core.create_prediction_pipeline(model) as pp:
+        for inp in tqdm(inputs):
+            fname = os.path.basename(inp)
+            out_path = os.path.join(output_folder, fname)
+            image = imageio.imread(inp)
+            input_ = DataArray(image[None, None], dims=tuple("bcyx"))
+            pred = bioimageio.core.predict_with_padding(pp, input_)[0].values.squeeze()
+            imageio.volwrite(out_path, pred)
+
+
+def evaluate(label_folder, output_folder):
+    cell_types = ["A172", "BT474", "BV2", "Huh7",
+                  "MCF7", "SHSY5Y", "SkBr3", "SKOV3"]
+    grid = pd.DataFrame(columns=["Cell_types"] + cell_types)
+    row = ["all"]
+    for cell_type in cell_types:
+        label_files = glob(os.path.join(label_folder, cell_type, "*.tif"))
+        this_scores = []
+        for label_file in label_files:
+            fname = os.path.basename(label_file)
+            pred_file = os.path.join(output_folder, fname)
+            label = imageio.imread(label_file)
+            pred = imageio.volread(pred_file)[0]
+            score = dice_score(pred, label != 0, threshold_gt=None, threshold_seg=None)
+
+            this_scores.append(score)
+        row.append(np.mean(this_scores))
+
+    grid.loc[len(grid)] = row
+
+    print("Cell type results:")
+    print(grid)
+
+
+def main():
+    # input_folder = "/home/pape/Work/data/incu_cyte/livecell/images/livecell_test_images"
+    output_folder = "./predictions"
+    # run_prediction(input_folder, output_folder)
+    label_folder = "/home/pape/Work/data/incu_cyte/livecell/annotations/livecell_test_images"
+    evaluate(label_folder, output_folder)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/experiments/livecell/train_boundaries.py b/experiments/livecell/train_boundaries.py
index d57c072b..ab5d2427 100644
--- a/experiments/livecell/train_boundaries.py
+++ b/experiments/livecell/train_boundaries.py
@@ -12,16 +12,22 @@ def train_boundaries(args):
     patch_shape = (512, 512)
     train_loader = get_livecell_loader(
         args.input, patch_shape, "train",
-        download=True, boundaries=True, batch_size=args.batch_size
+        download=True, boundaries=True, batch_size=args.batch_size,
+        cell_types=None if args.cell_type is None else [args.cell_type]
     )
     val_loader = get_livecell_loader(
         args.input, patch_shape, "val",
-        boundaries=True, batch_size=args.batch_size
+        boundaries=True, batch_size=args.batch_size,
+        cell_types=None if args.cell_type is None else [args.cell_type]
     )
     loss = torch_em.loss.DiceLoss()
 
+    cell_type = args.cell_type
+    name = "livecell-boundary-model"
+    if cell_type is not None:
+        name = f"{name}-{cell_type}"
     trainer = torch_em.default_segmentation_trainer(
-        name="livecell-boundary-model",
+        name=name,
         model=model,
         train_loader=train_loader,
         val_loader=val_loader,
@@ -30,14 +36,38 @@ def train_boundaries(args):
         learning_rate=1e-4,
         device=torch.device("cuda"),
         mixed_precision=True,
-        log_image_interval=50
+        log_image_interval=50,
     )
     trainer.fit(iterations=args.n_iterations)
 
 
+def check_loader(args, train=True, val=True, n_images=5):
+    from torch_em.util.debug import check_loader
+    patch_shape = (512, 512)
+    if train:
+        print("Check train loader")
+        loader = get_livecell_loader(
+            args.input, patch_shape, "train",
+            download=True, boundaries=True, batch_size=1,
+            cell_types=None if args.cell_type is None else [args.cell_type]
+        )
+        check_loader(loader, n_images)
+    if val:
+        print("Check val loader")
+        loader = get_livecell_loader(
+            args.input, patch_shape, "val",
+            download=True, boundaries=True, batch_size=1,
+            cell_types=None if args.cell_type is None else [args.cell_type]
+        )
+        check_loader(loader, n_images)
+
+
 if __name__ == '__main__':
-    parser = torch_em.util.parser_helper(
-        default_batch_size=8
-    )
+    parser = torch_em.util.parser_helper(default_batch_size=8)
+    parser.add_argument("--cell_type", default=None)
     args = parser.parse_args()
-    train_boundaries(args)
+
+    if args.check:
+        check_loader(args)
+    else:
+        train_boundaries(args)
diff --git a/test/util/test_prediction.py b/test/util/test_prediction.py
index afee71ff..69747148 100644
--- a/test/util/test_prediction.py
+++ b/test/util/test_prediction.py
@@ -82,7 +82,7 @@ def test_predict_with_padding_and_channels(self):
         shapes = [(3, 128, 128), (3, 133, 33), (3, 64, 49), (3, 27, 97)]
         for shape in shapes:
             input_ = np.random.rand(*shape).astype("float32")
-            out = predict_with_padding(model, input_, min_divisible=(1, 8, 8), device="cpu", with_channels=True)
+            out = predict_with_padding(model, input_, min_divisible=(8, 8), device="cpu", with_channels=True)
             self.assertEqual(out.shape[1:], shape)
diff --git a/torch_em/data/datasets/livecell.py b/torch_em/data/datasets/livecell.py
index ef8b1fe0..9249b6f6 100644
--- a/torch_em/data/datasets/livecell.py
+++ b/torch_em/data/datasets/livecell.py
@@ -62,11 +62,11 @@ def _create_segmentations_from_annotations(annotation_file, image_folder, seg_fo
         # get the path for the image data and make sure the corresponding image exists
         image_metadata = coco.loadImgs(image_id)[0]
         file_name = image_metadata["file_name"]
-
+
         # if cell_type names are given we only select file names that match a cell_type
         if cell_types is not None and (not any([cell_type in file_name for cell_type in cell_types])):
             continue
-
+
         sub_folder = file_name.split("_")[0]
         image_path = os.path.join(image_folder, sub_folder, file_name)
         # something changed in the image layout? we keep the old version around in case this changes back...
@@ -97,7 +98,8 @@ def _create_segmentations_from_annotations(annotation_file, image_folder, seg_fo
             imageio.imwrite(seg_path, seg)
 
     assert len(image_paths) == len(seg_paths)
-    assert len(image_paths) > 0, f"No matching image paths were found. Did you pass invalid cell type naems ({cell_types})?"
+    assert len(image_paths) > 0,\
+        f"No matching image paths were found. Did you pass invalid cell type names ({cell_types})?"
 
     return image_paths, seg_paths
 
@@ -107,9 +108,10 @@ def _download_livecell_annotations(path, split, download, cell_types, label_path
         split_name = "livecell_test_images"
     else:
         split_name = "livecell_train_val_images"
-
+
     image_folder = os.path.join(path, "images", split_name)
-    seg_folder = os.path.join(path, "annotations", split_name) if label_path is None else os.path.join(label_path, "annotations", split_name)
+    seg_folder = os.path.join(path, "annotations", split_name) if label_path is None\
+        else os.path.join(label_path, "annotations", split_name)
 
     assert os.path.exists(image_folder), image_folder
 
@@ -145,7 +147,7 @@ def _livecell_segmentation_loader(
                                 label_dtype=label_dtype,
                                 transform=transform,
                                 n_samples=n_samples)
-
+
     return torch_em.segmentation.get_data_loader(ds, batch_size, **loader_kwargs)
 
 
@@ -156,7 +158,7 @@ def get_livecell_loader(path, patch_shape, split, download=False,
     if cell_types is not None:
         assert isinstance(cell_types, (list, tuple)),\
             f"cell_types must be passed as a list or tuple instead of {cell_types}"
-
+
     _download_livecell_images(path, download)
     image_paths, seg_paths = _download_livecell_annotations(path, split, download, cell_types, label_path)
 
@@ -181,4 +183,4 @@ def get_livecell_loader(path, patch_shape, split, download=False,
         label_dtype = torch.float32
 
     kwargs.update({"patch_shape": patch_shape})
-    return _livecell_segmentation_loader(image_paths, seg_paths, label_dtype=label_dtype, **kwargs)
\ No newline at end of file
+    return _livecell_segmentation_loader(image_paths, seg_paths, label_dtype=label_dtype, **kwargs)
diff --git a/torch_em/loss/dice.py b/torch_em/loss/dice.py
index 3ec2346b..f7ccbe78 100644
--- a/torch_em/loss/dice.py
+++ b/torch_em/loss/dice.py
@@ -1,3 +1,4 @@
+import torch
 import torch.nn as nn
 
 
@@ -98,7 +99,7 @@ def __init__(self, alpha=1., beta=1., channelwise=True, eps=1e-7):
 
     def forward(self, input_, target):
         loss_dice = dice_score(
-            nn.functional.sigmoid(input_),
+            torch.sigmoid(input_),
             target,
             invert=True,
             channelwise=self.channelwise,
diff --git a/torch_em/model/resnet3d.py b/torch_em/model/resnet3d.py
index 1618755e..6f8feda8 100644
--- a/torch_em/model/resnet3d.py
+++ b/torch_em/model/resnet3d.py
@@ -7,9 +7,9 @@
 import torch.nn as nn
 from torch import Tensor
 
-from torchvision.models._api import WeightsEnum
-from torchvision.models._utils import _ovewrite_named_param
-from torchvision.utils import _log_api_usage_once
+# from torchvision.models._api import WeightsEnum
+# from torchvision.models._utils import _ovewrite_named_param
+# from torchvision.utils import _log_api_usage_once
 
 
 __all__ = [
@@ -165,9 +165,10 @@ def __init__(
         width_per_group: int = 64,
         replace_stride_with_dilation: Optional[List[bool]] = None,
         norm_layer: Optional[Callable[..., nn.Module]] = None,
+        stride_conv1: bool = True,
    ) -> None:
         super().__init__()
-        _log_api_usage_once(self)
+        # _log_api_usage_once(self)
         if norm_layer is None:
             norm_layer = nn.BatchNorm3d
         self._norm_layer = norm_layer
@@ -188,7 +189,9 @@ def __init__(
         )
         self.groups = groups
         self.base_width = width_per_group
-        self.conv1 = nn.Conv3d(in_channels, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
+        self.conv1 = nn.Conv3d(
+            in_channels, self.inplanes, kernel_size=7, stride=2 if stride_conv1 else 1, padding=3, bias=False
+        )
         self.bn1 = norm_layer(self.inplanes)
         self.relu = nn.ReLU(inplace=True)
         self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
@@ -282,7 +285,7 @@ def forward(self, x: Tensor) -> Tensor:
 def _resnet(
     block: Type[Union[BasicBlock, Bottleneck]],
     layers: List[int],
-    weights: Optional[WeightsEnum],
+    weights: Any,
     progress: bool,
     **kwargs: Any,
 ) -> ResNet3d:
diff --git a/torch_em/transform/label.py b/torch_em/transform/label.py
index 143af453..42de2d7e 100644
--- a/torch_em/transform/label.py
+++ b/torch_em/transform/label.py
@@ -85,6 +85,37 @@ def __call__(self, labels):
         return target
 
 
+# TODO smoothing
+class BoundaryTransformWithIgnoreLabel:
+    def __init__(self, ignore_label=-1, mode="thick", add_binary_target=False, ndim=None):
+        self.ignore_label = ignore_label
+        self.mode = mode
+        self.ndim = ndim
+        self.add_binary_target = add_binary_target
+
+    def __call__(self, labels):
+        labels = ensure_array(labels) if self.ndim is None else ensure_spatial_array(labels, self.ndim)
+        # calculate the normal boundaries
+        boundaries = skimage.segmentation.find_boundaries(labels, mode=self.mode)[None]
+
+        # calculate the boundaries for the ignore label
+        labels_ignore = (labels == self.ignore_label)
+        to_ignore_boundaries = skimage.segmentation.find_boundaries(labels_ignore, mode=self.mode)[None]
+
+        # set the boundaries that touch the ignore region to the ignore label
+        boundaries = boundaries.astype(np.int8)
+        boundaries[to_ignore_boundaries] = self.ignore_label
+
+        if self.add_binary_target:
+            binary = labels_to_binary(labels).astype(boundaries.dtype)
+            binary[labels == self.ignore_label] = self.ignore_label
+            target = np.concatenate([binary[None], boundaries], axis=0)
+        else:
+            target = boundaries
+
+        return target
+
+
 # TODO affinity smoothing
 class AffinityTransform:
     def __init__(self, offsets,
diff --git a/torch_em/transform/raw.py b/torch_em/transform/raw.py
index 3446eb95..e20344c1 100644
--- a/torch_em/transform/raw.py
+++ b/torch_em/transform/raw.py
@@ -44,10 +44,10 @@ def cast(inpt, typestring):
 def _normalize_torch(tensor, minval=None, maxval=None, axis=None, eps=1e-7):
     if axis:  # torch returns torch.return_types.min or torch.return_types.max
-        minval = tensor.min(dim=axis, keepdim=True).values if minval is None else minval
+        minval = torch.amin(tensor, dim=axis, keepdim=True) if minval is None else minval
         tensor -= minval
 
-        maxval = tensor.max(dim=axis, keepdim=True).values if maxval is None else maxval
+        maxval = torch.amax(tensor, dim=axis, keepdim=True) if maxval is None else maxval
         tensor /= (maxval + eps)
 
         return tensor
@@ -83,7 +83,6 @@ def normalize_percentile(raw, lower=1.0, upper=99.0, axis=None, eps=1e-7):
     return normalize(raw, v_lower, v_upper, eps=eps)
 
 
-# TODO
 #
 # intensity augmentations / noise augmentations
 #
@@ -173,7 +172,10 @@ def __call__(self, img):
         kernel_size = 2 * (np.random.randint(self.kernel_size[0], self.kernel_size[1]) // 2) + 1
         # switch boundaries to make sure 0 is excluded from sampling
         sigma = np.random.uniform(self.sigma[1], self.sigma[0])
-        return transforms.GaussianBlur(kernel_size, sigma=sigma)(img)
+        if isinstance(img, np.ndarray):
+            img = torch.from_numpy(img)
+        out = transforms.GaussianBlur(kernel_size, sigma=sigma)(img)
+        return out
 
 
 #
@@ -205,13 +207,17 @@ def get_raw_transform(normalizer=standardize, augmentation1=None, augmentation2=
 # The default values are made for an image with pixel values in
 # range [0, 1]. That the image is in this range is ensured by an
 # initial normalization step.
-def get_default_mean_teacher_augmentations(p=0.3):
-    norm = normalize
+def get_default_mean_teacher_augmentations(
+    p=0.3, norm=None,
+    blur_kwargs=None, poisson_kwargs=None, gaussian_kwargs=None
+):
+    if norm is None:
+        norm = normalize
     aug1 = transforms.Compose([
-        normalize,
-        transforms.RandomApply([GaussianBlur()], p=p),
-        transforms.RandomApply([PoissonNoise()], p=p/2),
-        transforms.RandomApply([AdditiveGaussianNoise()], p=p/2),
+        norm,
+        transforms.RandomApply([GaussianBlur(**({} if blur_kwargs is None else blur_kwargs))], p=p),
+        transforms.RandomApply([PoissonNoise(**({} if poisson_kwargs is None else poisson_kwargs))], p=p/2),
+        transforms.RandomApply([AdditiveGaussianNoise(**({} if gaussian_kwargs is None else gaussian_kwargs))], p=p/2),
     ])
     aug2 = transforms.RandomApply(
         [RandomContrast(clip_kwargs={"a_min": 0, "a_max": 1})], p=p
diff --git a/torch_em/util/prediction.py b/torch_em/util/prediction.py
index b66cb90c..b473af3c 100644
--- a/torch_em/util/prediction.py
+++ b/torch_em/util/prediction.py
@@ -117,6 +117,7 @@ def predict_with_halo(
     postprocess=None,
     with_channels=False,
     skip_block=None,
+    mask=None,
     disable_tqdm=False,
     tqdm_desc="predict with halo",
     prediction_function=None,
@@ -137,6 +138,7 @@ def predict_with_halo(
     postprocess [callable] - function to postprocess the network predictions (default: None)
     with_channels [bool] - whether the input has a channel axis (default: False)
     skip_block [callable] - function to evaluate whether a given input block should be skipped (default: None)
+    mask [arraylike] - elements outside the mask will be ignored in the prediction (default: None)
     disable_tqdm [bool] - flag that allows to disable tqdm output (e.g. if function is called multiple times)
     tqdm_desc [str] - description shown by the tqdm output
     prediction_function [callable] - A wrapper function for prediction to enable custom prediction procedures
@@ -167,6 +169,14 @@ def predict_block(block_id):
         with torch.no_grad():
             block = blocking.getBlock(block_id)
             offset = [beg for beg in block.begin]
+            inner_bb = tuple(slice(ha, ha + bs) for ha, bs in zip(halo, block.shape))
+
+            if mask is not None:
+                mask_block, _ = _load_block(mask, offset, block_shape, halo, with_channels=False)
+                mask_block = mask_block[inner_bb]
+                if mask_block.sum() == 0:
+                    return
+
             inp, _ = _load_block(input_, offset, block_shape, halo, with_channels=with_channels)
 
             if skip_block is not None and skip_block(inp):
@@ -190,11 +200,15 @@ def predict_block(block_id):
             if postprocess is not None:
                 prediction = postprocess(prediction)
 
-            inner_bb = tuple(slice(ha, ha + bs) for ha, bs in zip(halo, block.shape))
             if prediction.ndim == ndim + 1:
                 inner_bb = (slice(None),) + inner_bb
             prediction = prediction[inner_bb]
 
+            if mask is not None:
+                if prediction.ndim == ndim + 1:
+                    mask_block = np.concatenate(prediction.shape[0] * [mask_block[None]], axis=0)
+                prediction[~mask_block] = 0
+
             bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
             if isinstance(output, list):  # we have multiple outputs and split the prediction channels
                 for out, channel_slice in output:
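
Usage note (not part of the diff): with the new --cell_type argument, train_boundaries.py can train a boundary model for a single LIVECell cell type. The commands below are a sketch; they assume the --input and --check options that torch_em.util.parser_helper already provides (the script reads args.input and args.check), and the data path is a placeholder:

    # train on the A172 cell type only
    python train_boundaries.py --input /path/to/livecell --cell_type A172
    # visualize a few loader samples instead of training
    python train_boundaries.py --input /path/to/livecell --cell_type A172 --check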
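Usage note (not part of the diff): the new BoundaryTransformWithIgnoreLabel computes boundary targets in which transitions to the ignore label are marked with the ignore label itself, so that they can be excluded from the loss. A minimal sketch with a toy label image:

    import numpy as np
    from torch_em.transform.label import BoundaryTransformWithIgnoreLabel

    # toy instance segmentation: two objects, background 0 and an ignore strip of -1
    labels = np.zeros((64, 64), dtype="int64")
    labels[8:24, 8:24] = 1
    labels[32:48, 32:48] = 2
    labels[:, :4] = -1

    trafo = BoundaryTransformWithIgnoreLabel(ignore_label=-1, mode="thick")
    target = trafo(labels)
    # target has shape (1, 64, 64); boundary pixels are 1 and the pixels
    # on the transition to the ignore region are set to -1
    print(target.shape, target.min(), target.max())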
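Usage note (not part of the diff): the new mask argument of predict_with_halo skips blocks that do not overlap with the mask and zeroes predictions outside of it. A minimal sketch; the positional arguments (input, model, gpu_ids, block_shape, halo) are assumed to follow the existing signature, and the model and random data are only for illustration:

    import numpy as np
    from torch_em.model import UNet2d
    from torch_em.util.prediction import predict_with_halo

    model = UNet2d(in_channels=1, out_channels=1)
    data = np.random.rand(512, 512).astype("float32")

    # boolean mask with the same spatial shape as the input
    mask = np.zeros((512, 512), dtype="bool")
    mask[128:384, 128:384] = True

    prediction = predict_with_halo(
        data, model, gpu_ids=["cpu"],
        block_shape=(256, 256), halo=(32, 32), mask=mask
    )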