Skip to content

Commit

Permalink
Merge pull request #189 from anwai98/u-deepbacs
Browse files Browse the repository at this point in the history
Update DeepBacs Loader
  • Loading branch information
constantinpape committed Dec 29, 2023
2 parents 020f3fb + a14a3d1 commit a2e530f
Showing 1 changed file with 53 additions and 7 deletions.
60 changes: 53 additions & 7 deletions torch_em/data/datasets/deepbacs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import os
import shutil
import numpy as np
from glob import glob

import torch_em
from . import util
Expand All @@ -25,18 +28,61 @@ def _require_deebacs_dataset(path, bac_type, download):
util.download_source(zip_path, URLS[bac_type], download, checksum=CHECKSUMS[bac_type])
util.unzip(zip_path, os.path.join(path, bac_type))

# let's get a val split for the expected bacteria type
_assort_val_set(path, bac_type)


def _assort_val_set(path, bac_type):
image_paths = glob(os.path.join(path, bac_type, "training", "source", "*"))
image_paths = [os.path.split(_path)[-1] for _path in image_paths]

val_partition = 0.2
# let's get a balanced set of bacterias, if bac_type is mixed
if bac_type == "mixed":
_sort_1, _sort_2, _sort_3 = [], [], []
for _path in image_paths:
if _path.startswith("JE2"):
_sort_1.append(_path)
elif _path.startswith("pos"):
_sort_2.append(_path)
elif _path.startswith("train_"):
_sort_3.append(_path)

val_image_paths = [
*np.random.choice(_sort_1, size=int(val_partition * len(_sort_1)), replace=False),
*np.random.choice(_sort_2, size=int(val_partition * len(_sort_2)), replace=False),
*np.random.choice(_sort_3, size=int(val_partition * len(_sort_3)), replace=False)
]
else:
val_image_paths = np.random.choice(image_paths, size=int(val_partition * len(image_paths)), replace=False)

val_image_dir = os.path.join(path, bac_type, "val", "source")
val_label_dir = os.path.join(path, bac_type, "val", "target")
os.makedirs(val_image_dir, exist_ok=True)
os.makedirs(val_label_dir, exist_ok=True)

for sample_id in val_image_paths:
src_val_image_path = os.path.join(path, bac_type, "training", "source", sample_id)
dst_val_image_path = os.path.join(val_image_dir, sample_id)
shutil.move(src_val_image_path, dst_val_image_path)

src_val_label_path = os.path.join(path, bac_type, "training", "target", sample_id)
dst_val_label_path = os.path.join(val_label_dir, sample_id)
shutil.move(src_val_label_path, dst_val_label_path)


def _get_paths(path, bac_type, split):
# the bacteria types other than mixed are a bit more complicated so we don't have the dataloaders for them yet
# mixed is the combination of all other types
if split == "train":
dir_choice = "training"
else:
dir_choice = split

if bac_type != "mixed":
raise NotImplementedError(f"Currently only the bacteria type 'mixed' is supported, not {bac_type}")
image_folder = os.path.join(
path, bac_type, "training" if split == "train" else "test", "source"
)
label_folder = os.path.join(
path, bac_type, "training" if split == "train" else "test", "target"
)
image_folder = os.path.join(path, bac_type, dir_choice, "source")
label_folder = os.path.join(path, bac_type, dir_choice, "target")
return image_folder, label_folder


Expand All @@ -48,7 +94,7 @@ def get_deepbacs_dataset(
This dataset is from the publication https://doi.org/10.1038/s42003-022-03634-z.
Please cite it if you use this dataset for a publication.
"""
assert split in ("train", "test")
assert split in ("train", "val", "test")
bac_types = list(URLS.keys())
assert bac_type in bac_types, f"{bac_type} is not in expected bacteria types: {bac_types}"

Expand Down

0 comments on commit a2e530f

Please sign in to comment.