Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make livecell label extraction more robust against duplicate per obje… #127

Merged
merged 1 commit into from
May 9, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 40 additions & 14 deletions torch_em/data/datasets/livecell.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,18 @@
import imageio
import numpy as np
import requests
import vigra
from tqdm import tqdm

import torch_em
import torch.utils.data
from .util import download_source, unzip, update_kwargs

try:
from pycocotools.coco import COCO
except ImportError:
COCO = None

URLS = {
"images": "http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/images.zip",
"train": ("http://livecell-dataset.s3.eu-central-1.amazonaws.com/LIVECell_dataset_2021/annotations/"
Expand Down Expand Up @@ -49,9 +55,39 @@ def _download_annotation_file(path, split, download):
return annotation_file


def _annotations_to_instances(coco, image_metadata, category_ids):
# create and save the segmentation
annotation_ids = coco.getAnnIds(imgIds=image_metadata["id"], catIds=category_ids)
annotations = coco.loadAnns(annotation_ids)
assert len(annotations) <= np.iinfo("uint16").max
shape = (image_metadata["height"], image_metadata["width"])
seg = np.zeros(shape, dtype="uint32")

# sort annotations by size, except for iscrowd which go first
# we do this to minimize small noise from overlapping multi annotations
# (see below)
sizes = [ann["area"] if ann["iscrowd"] == 0 else 1 for ann in annotations]
sorting = np.argsort(sizes)
annotations = [annotations[i] for i in sorting]

for seg_id, annotation in enumerate(annotations, 1):
mask = coco.annToMask(annotation).astype("bool")
assert mask.shape == seg.shape
seg[mask] = seg_id

# some images have multiple masks per object with slightly different foreground
# this causes small noise objects we need to filter
min_size = 50
seg_ids, sizes = np.unique(seg, return_counts=True)
seg[np.isin(seg, seg_ids[sizes < min_size])] = 0

vigra.analysis.relabelConsecutive(seg, out=seg)

return seg.astype("uint16")


def _create_segmentations_from_annotations(annotation_file, image_folder, seg_folder, cell_types):
# TODO try except to explain how to install this
from pycocotools.coco import COCO
assert COCO is not None, "pycocotools is required for processing the LiveCELL ground-truth."

coco = COCO(annotation_file)
category_ids = coco.getCatIds(catNms=["cell"])
Expand All @@ -69,7 +105,7 @@ def _create_segmentations_from_annotations(annotation_file, image_folder, seg_fo

sub_folder = file_name.split("_")[0]
image_path = os.path.join(image_folder, sub_folder, file_name)
# something changed in the image layout? we keep the old version around in case this chagnes back...
# something changed in the image layout? we keep the old version around in case this changes back...
if not os.path.exists(image_path):
image_path = os.path.join(image_folder, file_name)
assert os.path.exists(image_path), image_path
Expand All @@ -83,17 +119,7 @@ def _create_segmentations_from_annotations(annotation_file, image_folder, seg_fo
if os.path.exists(seg_path):
continue

# create and save the segmentation
annotation_ids = coco.getAnnIds(imgIds=image_metadata["id"], catIds=category_ids)
annotations = coco.loadAnns(annotation_ids)
assert len(annotations) <= np.iinfo("uint16").max
shape = (image_metadata["height"], image_metadata["width"])
seg = np.zeros(shape, dtype="uint16")
for seg_id, annotation in enumerate(annotations, 1):
mask = coco.annToMask(annotation).astype("bool")
assert mask.shape == seg.shape
seg[mask] = seg_id

seg = _annotations_to_instances(coco, image_metadata, category_ids)
imageio.imwrite(seg_path, seg)

assert len(image_paths) == len(seg_paths)
Expand Down
Loading