Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Misc changes WIP #110

Merged
merged 11 commits into from
Mar 11, 2023
1 change: 0 additions & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
---
name: test

on: [push, pull_request]
Expand Down
66 changes: 66 additions & 0 deletions experiments/livecell/check_cell_type_performance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import os
from glob import glob

import numpy as np
import pandas as pd
try:
import imageio.v2 as imageio
except ImportError:
import imageio

from tqdm import tqdm
from xarray import DataArray
from elf.evaluation import dice_score


def run_prediction(input_folder, output_folder):
import bioimageio.core
os.makedirs(output_folder, exist_ok=True)

inputs = glob(os.path.join(input_folder, "*.tif"))
model = bioimageio.core.load_resource_description("10.5281/zenodo.5869899")

with bioimageio.core.create_prediction_pipeline(model) as pp:
for inp in tqdm(inputs):
fname = os.path.basename(inp)
out_path = os.path.join(output_folder, fname)
image = imageio.v2.imread(inp)
input_ = DataArray(image[None, None], dims=tuple("bcyx"))
pred = bioimageio.core.predict_with_padding(pp, input_)[0].values.squeeze()
imageio.volwrite(out_path, pred)


def evaluate(label_folder, output_folder):
cell_types = ["A172", "BT474", "BV2", "Huh7",
"MCF7", "SHSY5Y", "SkBr3", "SKOV3"]
grid = pd.DataFrame(columns=["Cell_types"] + cell_types)
row = ["all"]
for i in cell_types:
label_files = glob(os.path.join(label_folder, i, "*.tif"))
this_scores = []
for label_file in label_files:
fname = os.path.basename(label_file)
pred_file = os.path.join(output_folder, fname)
label = imageio.imread(label_file)
pred = imageio.volread(pred_file)[0]
score = dice_score(pred, label != 0, threshold_gt=None, threshold_seg=None)

this_scores.append(score)
row.append(np.mean(this_scores))

grid.loc[len(grid)] = row

print("Cell type results:")
print(grid)


def main():
# input_folder = "/home/pape/Work/data/incu_cyte/livecell/images/livecell_test_images"
output_folder = "./predictions"
# run_prediction(input_folder, output_folder)
label_folder = "/home/pape/Work/data/incu_cyte/livecell/annotations/livecell_test_images"
evaluate(label_folder, output_folder)


if __name__ == "__main__":
main()
46 changes: 38 additions & 8 deletions experiments/livecell/train_boundaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,22 @@ def train_boundaries(args):
patch_shape = (512, 512)
train_loader = get_livecell_loader(
args.input, patch_shape, "train",
download=True, boundaries=True, batch_size=args.batch_size
download=True, boundaries=True, batch_size=args.batch_size,
cell_types=None if args.cell_type is None else [args.cell_type]
)
val_loader = get_livecell_loader(
args.input, patch_shape, "val",
boundaries=True, batch_size=args.batch_size
boundaries=True, batch_size=args.batch_size,
cell_types=None if args.cell_type is None else [args.cell_type]
)
loss = torch_em.loss.DiceLoss()

cell_type = args.cell_type
name = "livecell-boundary-model"
if cell_type is not None:
name = f"{name}-{cell_type}"
trainer = torch_em.default_segmentation_trainer(
name="livecell-boundary-model",
name=name,
model=model,
train_loader=train_loader,
val_loader=val_loader,
Expand All @@ -30,14 +36,38 @@ def train_boundaries(args):
learning_rate=1e-4,
device=torch.device("cuda"),
mixed_precision=True,
log_image_interval=50
log_image_interval=50,
)
trainer.fit(iterations=args.n_iterations)


def check_loader(args, train=True, val=True, n_images=5):
from torch_em.util.debug import check_loader
patch_shape = (512, 512)
if train:
print("Check train loader")
loader = get_livecell_loader(
args.input, patch_shape, "train",
download=True, boundaries=True, batch_size=1,
cell_types=None if args.cell_type is None else [args.cell_type]
)
check_loader(loader, n_images)
if val:
print("Check val loader")
loader = get_livecell_loader(
args.input, patch_shape, "val",
download=True, boundaries=True, batch_size=1,
cell_types=None if args.cell_type is None else [args.cell_type]
)
check_loader(loader, n_images)


if __name__ == '__main__':
parser = torch_em.util.parser_helper(
default_batch_size=8
)
parser = torch_em.util.parser_helper(default_batch_size=8)
parser.add_argument("--cell_type", default=None)
args = parser.parse_args()
train_boundaries(args)

if args.check:
check_loader(args)
else:
train_boundaries(args)
2 changes: 1 addition & 1 deletion test/util/test_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_predict_with_padding_and_channels(self):
shapes = [(3, 128, 128), (3, 133, 33), (3, 64, 49), (3, 27, 97)]
for shape in shapes:
input_ = np.random.rand(*shape).astype("float32")
out = predict_with_padding(model, input_, min_divisible=(1, 8, 8), device="cpu", with_channels=True)
out = predict_with_padding(model, input_, min_divisible=(8, 8), device="cpu", with_channels=True)
self.assertEqual(out.shape[1:], shape)


Expand Down
18 changes: 10 additions & 8 deletions torch_em/data/datasets/livecell.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,11 @@ def _create_segmentations_from_annotations(annotation_file, image_folder, seg_fo
# get the path for the image data and make sure the corresponding image exists
image_metadata = coco.loadImgs(image_id)[0]
file_name = image_metadata["file_name"]

# if cell_type names are given we only select file names that match a cell_type
if cell_types is not None and (not any([cell_type in file_name for cell_type in cell_types])):
continue

sub_folder = file_name.split("_")[0]
image_path = os.path.join(image_folder, sub_folder, file_name)
# something changed in the image layout? we keep the old version around in case this chagnes back...
Expand Down Expand Up @@ -97,7 +97,8 @@ def _create_segmentations_from_annotations(annotation_file, image_folder, seg_fo
imageio.imwrite(seg_path, seg)

assert len(image_paths) == len(seg_paths)
assert len(image_paths) > 0, f"No matching image paths were found. Did you pass invalid cell type naems ({cell_types})?"
assert len(image_paths) > 0,\
f"No matching image paths were found. Did you pass invalid cell type naems ({cell_types})?"
return image_paths, seg_paths


Expand All @@ -107,9 +108,10 @@ def _download_livecell_annotations(path, split, download, cell_types, label_path
split_name = "livecell_test_images"
else:
split_name = "livecell_train_val_images"

image_folder = os.path.join(path, "images", split_name)
seg_folder = os.path.join(path, "annotations", split_name) if label_path is None else os.path.join(label_path, "annotations", split_name)
seg_folder = os.path.join(path, "annotations", split_name) if label_path is None\
else os.path.join(label_path, "annotations", split_name)

assert os.path.exists(image_folder), image_folder

Expand Down Expand Up @@ -145,7 +147,7 @@ def _livecell_segmentation_loader(
label_dtype=label_dtype,
transform=transform,
n_samples=n_samples)

return torch_em.segmentation.get_data_loader(ds, batch_size, **loader_kwargs)


Expand All @@ -156,7 +158,7 @@ def get_livecell_loader(path, patch_shape, split, download=False,
if cell_types is not None:
assert isinstance(cell_types, (list, tuple)),\
f"cell_types must be passed as a list or tuple instead of {cell_types}"

_download_livecell_images(path, download)
image_paths, seg_paths = _download_livecell_annotations(path, split, download, cell_types, label_path)

Expand All @@ -181,4 +183,4 @@ def get_livecell_loader(path, patch_shape, split, download=False,
label_dtype = torch.float32

kwargs.update({"patch_shape": patch_shape})
return _livecell_segmentation_loader(image_paths, seg_paths, label_dtype=label_dtype, **kwargs)
return _livecell_segmentation_loader(image_paths, seg_paths, label_dtype=label_dtype, **kwargs)
3 changes: 2 additions & 1 deletion torch_em/loss/dice.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import torch
import torch.nn as nn


Expand Down Expand Up @@ -98,7 +99,7 @@ def __init__(self, alpha=1., beta=1., channelwise=True, eps=1e-7):

def forward(self, input_, target):
loss_dice = dice_score(
nn.functional.sigmoid(input_),
torch.sigmoid(input_),
target,
invert=True,
channelwise=self.channelwise,
Expand Down
15 changes: 9 additions & 6 deletions torch_em/model/resnet3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import torch.nn as nn
from torch import Tensor

from torchvision.models._api import WeightsEnum
from torchvision.models._utils import _ovewrite_named_param
from torchvision.utils import _log_api_usage_once
# from torchvision.models._api import WeightsEnum
# from torchvision.models._utils import _ovewrite_named_param
# from torchvision.utils import _log_api_usage_once


__all__ = [
Expand Down Expand Up @@ -165,9 +165,10 @@ def __init__(
width_per_group: int = 64,
replace_stride_with_dilation: Optional[List[bool]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
stride_conv1: bool = True,
) -> None:
super().__init__()
_log_api_usage_once(self)
# _log_api_usage_once(self)
if norm_layer is None:
norm_layer = nn.BatchNorm3d
self._norm_layer = norm_layer
Expand All @@ -188,7 +189,9 @@ def __init__(
)
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv3d(in_channels, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
self.conv1 = nn.Conv3d(
in_channels, self.inplanes, kernel_size=7, stride=2 if stride_conv1 else 1, padding=3, bias=False
)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
Expand Down Expand Up @@ -282,7 +285,7 @@ def forward(self, x: Tensor) -> Tensor:
def _resnet(
block: Type[Union[BasicBlock, Bottleneck]],
layers: List[int],
weights: Optional[WeightsEnum],
weights: Any,
progress: bool,
**kwargs: Any,
) -> ResNet3d:
Expand Down
31 changes: 31 additions & 0 deletions torch_em/transform/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,37 @@ def __call__(self, labels):
return target


# TODO smoothing
class BoundaryTransformWithIgnoreLabel:
def __init__(self, ignore_label=-1, mode="thick", add_binary_target=False, ndim=None):
self.ignore_label = ignore_label
self.mode = mode
self.ndim = ndim
self.add_binary_target = add_binary_target

def __call__(self, labels):
labels = ensure_array(labels) if self.ndim is None else ensure_spatial_array(labels, self.ndim)
# calculate the normal boundaries
boundaries = skimage.segmentation.find_boundaries(labels, mode=self.mode)[None]

# calculate the boundaries for the ignore label
labels_ignore = (labels == self.ignore_label)
to_ignore_boundaries = skimage.segmentation.find_boundaries(labels_ignore, mode=self.mode)[None]

# mask the to-background-boundaries
boundaries = boundaries.astype(np.int8)
boundaries[to_ignore_boundaries] = self.ignore_label

if self.add_binary_target:
binary = labels_to_binary(labels).astype(boundaries.dtype)
binary[labels == self.ignore_label] = self.ignore_label
target = np.concatenate([binary[None], boundaries], axis=0)
else:
target = boundaries

return target


# TODO affinity smoothing
class AffinityTransform:
def __init__(self, offsets,
Expand Down
26 changes: 16 additions & 10 deletions torch_em/transform/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ def cast(inpt, typestring):

def _normalize_torch(tensor, minval=None, maxval=None, axis=None, eps=1e-7):
if axis: # torch returns torch.return_types.min or torch.return_types.max
minval = tensor.min(dim=axis, keepdim=True).values if minval is None else minval
minval = torch.amin(tensor, dim=axis, keepdim=True) if minval is None else minval
tensor -= minval

maxval = tensor.max(dim=axis, keepdim=True).values if maxval is None else maxval
maxval = torch.amax(tensor, dim=axis, keepdim=True) if maxval is None else maxval
tensor /= (maxval + eps)

return tensor
Expand Down Expand Up @@ -83,7 +83,6 @@ def normalize_percentile(raw, lower=1.0, upper=99.0, axis=None, eps=1e-7):
return normalize(raw, v_lower, v_upper, eps=eps)


# TODO
#
# intensity augmentations / noise augmentations
#
Expand Down Expand Up @@ -173,7 +172,10 @@ def __call__(self, img):
kernel_size = 2 * (np.random.randint(self.kernel_size[0], self.kernel_size[1]) // 2) + 1
# switch boundaries to make sure 0 is excluded from sampling
sigma = np.random.uniform(self.sigma[1], self.sigma[0])
return transforms.GaussianBlur(kernel_size, sigma=sigma)(img)
if isinstance(img, np.ndarray):
img = torch.from_numpy(img)
out = transforms.GaussianBlur(kernel_size, sigma=sigma)(img)
return out


#
Expand Down Expand Up @@ -205,13 +207,17 @@ def get_raw_transform(normalizer=standardize, augmentation1=None, augmentation2=
# The default values are made for an image with pixel values in
# range [0, 1]. That the image is in this range is ensured by an
# initial normalizations step.
def get_default_mean_teacher_augmentations(p=0.3):
norm = normalize
def get_default_mean_teacher_augmentations(
p=0.3, norm=None,
blur_kwargs=None, poisson_kwargs=None, gaussian_kwargs=None
):
if norm is None:
norm = normalize
aug1 = transforms.Compose([
normalize,
transforms.RandomApply([GaussianBlur()], p=p),
transforms.RandomApply([PoissonNoise()], p=p/2),
transforms.RandomApply([AdditiveGaussianNoise()], p=p/2),
norm,
transforms.RandomApply([GaussianBlur(**({} if blur_kwargs is None else blur_kwargs))], p=p),
transforms.RandomApply([PoissonNoise(**({} if poisson_kwargs is None else poisson_kwargs))], p=p/2),
transforms.RandomApply([AdditiveGaussianNoise(**({} if gaussian_kwargs is None else gaussian_kwargs))], p=p/2),
])
aug2 = transforms.RandomApply(
[RandomContrast(clip_kwargs={"a_min": 0, "a_max": 1})], p=p
Expand Down
Loading