"""Sanity-check script for the ToothFairy dataloader.

Builds a 2d loader over the ToothFairy CBCT volumes and visualizes a few
sampled batches, so the raw data and dense labels can be inspected by eye.
"""
from torch_em.util.debug import check_loader
from torch_em.data import MinInstanceSampler
from torch_em.data.datasets.medical import get_toothfairy_loader


# Root folder where the (manually downloaded) ToothFairy data is stored.
ROOT = "/scratch/share/cidas/cca/data/toothfairy/"


def check_toothfairy():
    """Create the ToothFairy loader and plot 8 sample batches to a png file."""
    loader = get_toothfairy_loader(
        path=ROOT,
        patch_shape=(1, 512, 512),  # single-slice patches -> 2d training setup
        ndim=2,
        batch_size=2,
        sampler=MinInstanceSampler(),  # skip patches without any foreground instance
    )
    check_loader(loader, 8, plt=True, save_path="./toothfairy.png")


# Guard the call so importing this module has no side effects.
if __name__ == "__main__":
    check_toothfairy()
import os
from glob import glob
from tqdm import tqdm
from natsort import natsorted

import numpy as np
import nibabel as nib

import torch_em

from .. import util


def get_toothfairy_data(path, download):
    """Return the folder containing the raw ToothFairy data.

    Automatic download is not possible: the data must be obtained manually
    from the challenge website (https://toothfairy.grand-challenge.org/,
    registration required) and extracted under `path`.

    Args:
        path: Root folder where the manually downloaded data is located.
        download: Must be False; automatic download is not supported.

    Raises:
        NotImplementedError: If `download` is set to True.
    """
    if download:
        raise NotImplementedError(
            "Automatic download is not supported for the ToothFairy dataset. "
            "Please download the data manually and place it under the given path."
        )
    return os.path.join(path, "ToothFairy_Dataset", "Dataset")


def _get_toothfairy_paths(path, download):
    """Preprocess the raw data and return sorted image and label file paths.

    Each patient folder holds npy volumes. Patients with dense annotations
    ("gt_alpha.npy") are converted to nifti once and cached under
    `path`/data/{images,dense_labels}; subsequent calls reuse the cache.
    """
    data_dir = get_toothfairy_data(path, download)

    images_dir = os.path.join(path, "data", "images")
    gt_dir = os.path.join(path, "data", "dense_labels")
    if os.path.exists(images_dir) and os.path.exists(gt_dir):
        image_paths = natsorted(glob(os.path.join(images_dir, "*.nii.gz")))
        gt_paths = natsorted(glob(os.path.join(gt_dir, "*.nii.gz")))
        # Guard against an interrupted previous conversion leaving the cache
        # with mismatched image / label counts.
        assert len(image_paths) == len(gt_paths), \
            "The cached images and labels are out of sync. Remove the 'data' folder and rerun."
        return image_paths, gt_paths

    os.makedirs(images_dir, exist_ok=True)
    os.makedirs(gt_dir, exist_ok=True)

    image_paths, gt_paths = [], []
    for patient_dir in tqdm(glob(os.path.join(data_dir, "P*"))):
        patient_id = os.path.split(patient_dir)[-1]

        # Only a subset of the patients has dense (voxel-level) annotations;
        # the others are skipped.
        dense_anns_path = os.path.join(patient_dir, "gt_alpha.npy")
        if not os.path.exists(dense_anns_path):
            continue

        image_path = os.path.join(patient_dir, "data.npy")

        image = np.load(image_path)
        gt = np.load(dense_anns_path)

        # Wrap the raw arrays into nifti images (identity affine) so the
        # segmentation dataset can read them directly from file.
        image_nifti = nib.Nifti2Image(image, np.eye(4))
        gt_nifti = nib.Nifti2Image(gt, np.eye(4))

        trg_image_path = os.path.join(images_dir, f"{patient_id}.nii.gz")
        trg_gt_path = os.path.join(gt_dir, f"{patient_id}.nii.gz")

        nib.save(image_nifti, trg_image_path)
        nib.save(gt_nifti, trg_gt_path)

        image_paths.append(trg_image_path)
        gt_paths.append(trg_gt_path)

    return image_paths, gt_paths


def get_toothfairy_dataset(path, patch_shape, download=False, **kwargs):
    """Dataset for canal segmentation in CBCT scans.

    ToothFairy challenge: https://toothfairy.grand-challenge.org/

    Args:
        path: Root folder where the manually downloaded data is located.
        patch_shape: The patch shape to use for training.
        download: Not supported; the data must be downloaded manually.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
    """
    image_paths, gt_paths = _get_toothfairy_paths(path, download)

    return torch_em.default_segmentation_dataset(
        raw_paths=image_paths,
        raw_key="data",
        label_paths=gt_paths,
        label_key="data",
        is_seg_dataset=True,
        patch_shape=patch_shape,
        **kwargs,
    )


def get_toothfairy_loader(path, patch_shape, batch_size, download=False, **kwargs):
    """Dataloader for canal segmentation in CBCT scans.

    See `get_toothfairy_dataset` for details on the data.

    Args:
        path: Root folder where the manually downloaded data is located.
        patch_shape: The patch shape to use for training.
        batch_size: The batch size for the dataloader.
        download: Not supported; the data must be downloaded manually.
        kwargs: Additional keyword arguments for the dataset or the dataloader.

    Returns:
        The dataloader.
    """
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_toothfairy_dataset(path, patch_shape, download, **ds_kwargs)
    return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)