Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test set for CTC dataset #216

Merged
merged 5 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions scripts/datasets/check_ctc.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,29 @@
from torch_em.data.datasets.ctc import get_ctc_segmentation_loader, CTC_CHECKSUMS
from torch_em.util.debug import check_loader
from torch_em.data.sampler import MinInstanceSampler

# Local root folder where the CTC data is (or will be) stored.
ROOT = "/home/anwai/data/ctc/"


# Some of the datasets have partial sparse labels:
# - Fluo-N2DH-GOWT1
# - Fluo-N2DL-HeLa
# Maybe depends on the split?!
def check_ctc_segmentation(split):
    """Visually check the CTC segmentation loaders for all datasets of one split.

    Downloads each dataset if necessary, then plots 8 sample batches per
    dataset via ``check_loader``.

    Args:
        split: The data split to check, e.g. "train" or "test".
    """
    # The dataset names are the same for all splits, so the "train"
    # checksum keys serve as the canonical list of datasets.
    ctc_dataset_names = list(CTC_CHECKSUMS["train"].keys())
    for name in ctc_dataset_names:
        print("Checking dataset", name)
        loader = get_ctc_segmentation_loader(
            path=ROOT,
            dataset_name=name,
            patch_shape=(1, 512, 512),
            batch_size=1,
            download=True,
            split=split,
            # Skip patches without any instances so each plotted sample is informative.
            sampler=MinInstanceSampler(),
        )
        check_loader(loader, 8, plt=True)


if __name__ == "__main__":
    check_ctc_segmentation("train")
109 changes: 63 additions & 46 deletions torch_em/data/datasets/ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,66 +6,78 @@
from . import util


# SHA256 checksums for the zip archives of the Cell Tracking Challenge
# datasets, keyed first by split ("train" / "test") and then by dataset name.
# The download URLs are derived from the dataset name and split, see
# `get_ctc_url_and_checksum`.
CTC_CHECKSUMS = {
    "train": {
        "BF-C2DL-HSC": "0aa68ec37a9b06e72a5dfa07d809f56e1775157fb674bb75ff904936149657b1",
        "BF-C2DL-MuSC": "ca72b59042809120578a198ba236e5ed3504dd6a122ef969428b7c64f0a5e67d",
        "DIC-C2DH-HeLa": "832fed2d05bb7488cf9c51a2994b75f8f3f53b3c3098856211f2d39023c34e1a",
        "Fluo-C2DL-Huh7": "1912658c1b3d8b38b314eb658b559e7b39c256917150e9b3dd8bfdc77347617d",
        "Fluo-C2DL-MSC": "a083521f0cb673ae02d4957c5e6580c2e021943ef88101f6a2f61b944d671af2",
        "Fluo-N2DH-GOWT1": "1a7bd9a7d1d10c4122c7782427b437246fb69cc3322a975485c04e206f64fc2c",
        "Fluo-N2DH-SIM+": "3e809148c87ace80c72f563b56c35e0d9448dcdeb461a09c83f61e93f5e40ec8",
        "Fluo-N2DL-HeLa": "35dd99d58e071aba0b03880128d920bd1c063783cc280f9531fbdc5be614c82e",
        "PhC-C2DH-U373": "b18185c18fce54e8eeb93e4bbb9b201d757add9409bbf2283b8114185a11bc9e",
        "PhC-C2DL-PSC": "9d54bb8febc8798934a21bf92e05d92f5e8557c87e28834b2832591cdda78422",
    },
    "test": {
        "BF-C2DL-HSC": "fd1c05ec625fd0526c8369d1139babe137e885457eee98c10d957da578d0d5bc",
        "BF-C2DL-MuSC": "c5cae259e6090e82a2596967fb54c8a768717c1772398f8546ad1c8df0820450",
        "DIC-C2DH-HeLa": "5e5d5f2aa90aef99d750cf03f5c12d799d50b892f98c86950e07a2c5955ac01f",
        "Fluo-C2DL-Huh7": "cc7359f8fb6b0c43995365e83ce0116d32f477ac644b2ca02b98bc253e2bcbbe",
        "Fluo-C2DL-MSC": "c90b13e603dde52f17801d4f0cadde04ed7f21cc05296b1f0957d92dbfc8ffa6",
        "Fluo-N2DH-GOWT1": "c6893ec2d63459de49d4dc21009b04275573403c62cc02e6ee8d0cb1a5068add",
        "Fluo-N2DH-SIM+": "c4f257add739b284d02176057814de345dee2ac1a7438e360ccd2df73618db68",
        "Fluo-N2DL-HeLa": "45cf3daf05e8495aa2ce0febacca4cf0928fab808c0b14ed2eb7289a819e6bb8",
        "PhC-C2DH-U373": "7aa3162e4363a416b259149adc13c9b09cb8aecfe8165eb1428dd534b66bec8a",
        "PhC-C2DL-PSC": "8c98ac6203e7490157ceb6aa1131d60a3863001b61fb75e784bc49d47ee264d5",
    }
}


def _require_ctc_dataset(path, dataset_name, download):
dataset_names = list(CTC_URLS.keys())
def get_ctc_url_and_checksum(dataset_name, split):
    """Return the download URL and sha256 checksum for a CTC dataset split.

    Args:
        dataset_name: Name of the Cell Tracking Challenge dataset,
            e.g. "DIC-C2DH-HeLa".
        split: The data split, one of "train" or "test".

    Returns:
        Tuple of (url, checksum) for the corresponding zip archive.

    Raises:
        ValueError: If `split` is not "train" or "test".
        KeyError: If `dataset_name` is not a known CTC dataset.
    """
    # Fail early on a bad split instead of silently mapping any
    # non-"train" value to the test URL and raising an opaque KeyError below.
    if split not in ("train", "test"):
        raise ValueError(f"Invalid split: {split}, choose one of ['train', 'test'].")

    # The challenge server hosts the archives under split-specific folders.
    _link_to_split = "training-datasets" if split == "train" else "test-datasets"

    url = f"http://data.celltrackingchallenge.net/{_link_to_split}/{dataset_name}.zip"
    checksum = CTC_CHECKSUMS[split][dataset_name]
    return url, checksum


def _require_ctc_dataset(path, dataset_name, download, split):
    """Ensure the requested CTC dataset split is present locally.

    Downloads and unpacks the zip archive if the data folder does not
    exist yet.

    Args:
        path: Root folder for the CTC data.
        dataset_name: Name of the Cell Tracking Challenge dataset.
        split: The data split, "train" or "test".
        download: Whether downloading the data is allowed.

    Returns:
        The path to the dataset folder, i.e. ``<path>/<split>/<dataset_name>``.

    Raises:
        ValueError: If `dataset_name` is not a known CTC dataset.
    """
    # The dataset names are identical across splits; use the "train"
    # checksum keys as the canonical list.
    dataset_names = list(CTC_CHECKSUMS["train"].keys())
    if dataset_name not in dataset_names:
        raise ValueError(f"Invalid dataset: {dataset_name}, choose one of {dataset_names}.")

    data_path = os.path.join(path, split, dataset_name)

    # Data was already downloaded and unpacked; nothing to do.
    if os.path.exists(data_path):
        return data_path

    os.makedirs(data_path)
    url, checksum = get_ctc_url_and_checksum(dataset_name, split)
    zip_path = os.path.join(path, f"{dataset_name}.zip")
    util.download_source(zip_path, url, download, checksum=checksum)
    # Unzip into the split folder so the archive's top-level folder
    # ends up at <path>/<split>/<dataset_name>.
    util.unzip(zip_path, os.path.join(path, split), remove=True)

    return data_path


def _require_gt_images(data_path, splits):
def _require_gt_images(data_path, vol_ids):
image_paths, label_paths = [], []

if isinstance(splits, str):
splits = [splits]
if isinstance(vol_ids, str):
vol_ids = [vol_ids]

for split in splits:
image_folder = os.path.join(data_path, split)
assert os.path.join(image_folder), f"Cannot find split, {split} in {data_path}."
for vol_id in vol_ids:
image_folder = os.path.join(data_path, vol_id)
assert os.path.join(image_folder), f"Cannot find volume id, {vol_id} in {data_path}."

label_folder = os.path.join(data_path, f"{split}_GT", "SEG")
label_folder = os.path.join(data_path, f"{vol_id}_GT", "SEG")

# copy over the images corresponding to the labeled frames
label_image_folder = os.path.join(data_path, f"{split}_GT", "IM")
label_image_folder = os.path.join(data_path, f"{vol_id}_GT", "IM")
os.makedirs(label_image_folder, exist_ok=True)

this_label_paths = glob(os.path.join(label_folder, "*.tif"))
Expand All @@ -88,7 +100,8 @@ def get_ctc_segmentation_dataset(
path,
dataset_name,
patch_shape,
split=None,
split,
anwai98 marked this conversation as resolved.
Show resolved Hide resolved
vol_id=None,
download=False,
**kwargs,
):
Expand All @@ -98,16 +111,18 @@ def get_ctc_segmentation_dataset(
cell tracking challenge. If you use this data in your research please cite
https://doi.org/10.1038/nmeth.4473
"""
data_path = _require_ctc_dataset(path, dataset_name, download)
assert split in ["train"]

if split is None:
splits = glob(os.path.join(data_path, "*_GT"))
splits = [os.path.basename(split) for split in splits]
splits = [split.rstrip("_GT") for split in splits]
data_path = _require_ctc_dataset(path, dataset_name, download, split)

if vol_id is None:
vol_ids = glob(os.path.join(data_path, "*_GT"))
vol_ids = [os.path.basename(vol_id) for vol_id in vol_ids]
vol_ids = [vol_id.rstrip("_GT") for vol_id in vol_ids]
else:
splits = split
vol_ids = vol_id

image_path, label_path = _require_gt_images(data_path, splits)
image_path, label_path = _require_gt_images(data_path, vol_ids)

kwargs = util.update_kwargs(kwargs, "ndim", 2)
return torch_em.default_segmentation_dataset(
Expand All @@ -120,7 +135,8 @@ def get_ctc_segmentation_loader(
dataset_name,
patch_shape,
batch_size,
split=None,
split,
anwai98 marked this conversation as resolved.
Show resolved Hide resolved
vol_id=None,
download=False,
**kwargs,
):
Expand All @@ -131,7 +147,8 @@ def get_ctc_segmentation_loader(
torch_em.default_segmentation_dataset, **kwargs
)
dataset = get_ctc_segmentation_dataset(
path, dataset_name, patch_shape, split=split, download=download, **ds_kwargs,
path, dataset_name, patch_shape, split=split, vol_id=vol_id, download=download, **ds_kwargs,
)

loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
return loader