Skip to content

Commit

Permalink
Add Toothfairy dataset (#313)
Browse files Browse the repository at this point in the history
* Add scripts to check toothfairy1 dataset

* Finalize toothfairy dataset
  • Loading branch information
anwai98 committed Jun 28, 2024
1 parent 499bfd2 commit aae57ce
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 0 deletions.
21 changes: 21 additions & 0 deletions scripts/datasets/medical/check_toothfairy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from torch_em.util.debug import check_loader
from torch_em.data import MinInstanceSampler
from torch_em.data.datasets.medical import get_toothfairy_loader


ROOT = "/scratch/share/cidas/cca/data/toothfairy/"


def check_toothfairy():
    """Spot-check the ToothFairy loader by plotting a few sampled batches."""
    # Require at least one instance per sampled patch so empty slices are skipped.
    sampler = MinInstanceSampler()

    loader = get_toothfairy_loader(
        path=ROOT,
        patch_shape=(1, 512, 512),
        ndim=2,
        batch_size=2,
        sampler=sampler,
    )

    # Visualize 8 batches and save the figure for offline inspection.
    check_loader(loader, 8, plt=True, save_path="./toothfairy.png")


check_toothfairy()
1 change: 1 addition & 0 deletions torch_em/data/datasets/medical/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@
from .sa_med2d import get_sa_med2d_dataset, get_sa_med2d_loader
from .sega import get_sega_dataset, get_sega_loader
from .siim_acr import get_siim_acr_dataset, get_siim_acr_loader
from .toothfairy import get_toothfairy_dataset, get_toothfairy_loader
from .uwaterloo_skin import get_uwaterloo_skin_dataset, get_uwaterloo_skin_loader
88 changes: 88 additions & 0 deletions torch_em/data/datasets/medical/toothfairy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import os
from glob import glob
from tqdm import tqdm
from natsort import natsorted

import numpy as np
import nibabel as nib

import torch_em

from .. import util


def get_toothfairy_data(path, download):
    """Return the directory that contains the ToothFairy dataset.

    Automatic download is not possible: the data must be obtained manually
    (ToothFairy challenge, https://toothfairy.grand-challenge.org/) and
    extracted under ``path``.

    Args:
        path: Root folder under which the manually downloaded data was extracted.
        download: Must be ``False``; automatic download is not supported.

    Returns:
        Path to the ``ToothFairy_Dataset/Dataset`` folder inside ``path``.

    Raises:
        NotImplementedError: If ``download`` is ``True``.
    """
    if download:
        # Explain the manual-download requirement instead of failing silently.
        raise NotImplementedError(
            "Automatic download is not possible for the ToothFairy dataset. "
            "Please download it from https://toothfairy.grand-challenge.org/ "
            "and extract it under the given path."
        )

    data_dir = os.path.join(path, "ToothFairy_Dataset", "Dataset")
    return data_dir


def _get_toothfairy_paths(path, download):
    """Convert the raw ToothFairy numpy volumes to NIfTI and return file paths.

    Only patients with dense annotations ("gt_alpha.npy") are converted. The
    converted volumes are cached under "<path>/data"; if the cache folders
    already exist, the cached files are returned directly.
    """
    data_dir = get_toothfairy_data(path, download)

    images_dir = os.path.join(path, "data", "images")
    gt_dir = os.path.join(path, "data", "dense_labels")

    # Reuse previously converted volumes if the cache directories exist.
    if os.path.exists(images_dir) and os.path.exists(gt_dir):
        cached_images = natsorted(glob(os.path.join(images_dir, "*.nii.gz")))
        cached_gt = natsorted(glob(os.path.join(gt_dir, "*.nii.gz")))
        return cached_images, cached_gt

    os.makedirs(images_dir, exist_ok=True)
    os.makedirs(gt_dir, exist_ok=True)

    image_paths, gt_paths = [], []
    for patient_dir in tqdm(glob(os.path.join(data_dir, "P*"))):
        patient_id = os.path.split(patient_dir)[-1]

        gt_src = os.path.join(patient_dir, "gt_alpha.npy")
        # Skip patients that lack dense annotations.
        if not os.path.exists(gt_src):
            continue

        raw_src = os.path.join(patient_dir, "data.npy")
        raw_vol = np.load(raw_src)
        gt_vol = np.load(gt_src)

        trg_image_path = os.path.join(images_dir, f"{patient_id}.nii.gz")
        trg_gt_path = os.path.join(gt_dir, f"{patient_id}.nii.gz")

        # Wrap the arrays as NIfTI-2 with an identity affine and save them.
        nib.save(nib.Nifti2Image(raw_vol, np.eye(4)), trg_image_path)
        nib.save(nib.Nifti2Image(gt_vol, np.eye(4)), trg_gt_path)

        image_paths.append(trg_image_path)
        gt_paths.append(trg_gt_path)

    return image_paths, gt_paths


def get_toothfairy_dataset(path, patch_shape, download=False, **kwargs):
    """Dataset for canal segmentation in CBCT scans.

    ToothFairy challenge: https://toothfairy.grand-challenge.org/
    """
    image_paths, gt_paths = _get_toothfairy_paths(path, download)

    # Both the raw volumes and the labels live under the "data" key of the
    # converted NIfTI files.
    return torch_em.default_segmentation_dataset(
        raw_paths=image_paths,
        raw_key="data",
        label_paths=gt_paths,
        label_key="data",
        is_seg_dataset=True,
        patch_shape=patch_shape,
        **kwargs,
    )


def get_toothfairy_loader(path, patch_shape, batch_size, download=False, **kwargs):
    """Dataloader for canal segmentation in CBCT scans.

    See `get_toothfairy_dataset` for details on the underlying dataset;
    extra `kwargs` are split between the dataset and the data loader.
    """
    # Route dataset kwargs to the dataset and the rest (e.g. num_workers,
    # shuffle) to the data loader.
    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
    dataset = get_toothfairy_dataset(path, patch_shape, download, **ds_kwargs)
    loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
    return loader

0 comments on commit aae57ce

Please sign in to comment.