Skip to content

Commit

Permalink
Merge 34cd11f into e9739b7
Browse files Browse the repository at this point in the history
  • Loading branch information
uzaymacar committed Jan 15, 2023
2 parents e9739b7 + 34cd11f commit 42c6741
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 30 deletions.
25 changes: 17 additions & 8 deletions docs/source/configuration_file.rst
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,11 @@ See details in both ``train_validation`` and ``test`` for the contrasts that are
{
"$schema": "http://json-schema.org/draft-04/schema#",
"title": "slice_filter_params",
"description": "Discard a slice from the dataset if it meets a condition, see below.",
"$$description": [
"Discard a slice from the dataset if it meets a condition, defined below.\n",
"A slice is an entire 2D image taken from a 3D volume (e.g. an image of size 128x128 taken from a volume of size 128x128x16).\n",
"Therefore, the parameter ``slice_filter_params`` is applicable for 2D models only.",
],
"type": "dict",
"options": {
"filter_empty_input": {
Expand Down Expand Up @@ -611,25 +615,30 @@ See details in both ``train_validation`` and ``test`` for the contrasts that are
"$schema": "http://json-schema.org/draft-04/schema#",
"title": "patch_filter_params",
"$$description": [
"Discard a 2D patch from the dataset if it meets a condition at training time, defined below.\n",
"Contrary to the field ``slice_filter_params`` which applies at training and testing time, ",
"this parameter only applies during training time."
"Discard a 2D or 3D patch from the dataset if it meets a condition at training time, defined below.\n",
"A 2D patch is a portion of a 2D image (e.g. a patch of size 32x32 taken inside an image of size 128x128).\n",
"A 3D patch is a portion of a 3D volume (e.g. a patch of size 32x32x16 from a volume of size 128x128x16).\n",
"Therefore, the parameter ``patch_filter_params`` is applicable for 2D or 3D models.\n",
"In addition, contrary to ``slice_filter_params`` which applies at training and testing time, ``patch_filter_params``\n",
"is applied only at training time. This is because the reconstruction algorithm for predictions from patches\n",
"needs to have the predictions for all patches at testing time."
],
"type": "dict",
"options": {
"filter_empty_input": {
"type": "boolean",
"description": "Discard 2D patches where all voxel intensities are zeros. Default: ``False``."
"description": "Discard 2D or 3D patches where all voxel intensities are zeros. Default: ``False``."
},
"filter_empty_mask": {
"type": "boolean",
"description": "Discard 2D patches where all voxel labels are zeros. Default: ``False``."
"description": "Discard 2D or 3D patches where all voxel labels are zeros. Default: ``False``."
},
"filter_absent_class": {
"type": "boolean",
"$$description": [
"Discard 2D patches where all voxel labels are zero for one or more classes\n",
"(this is most relevant for multi-class models that need GT for all classes). Default: ``False``."
"Discard 2D or 3D patches where all voxel labels are zero for one or more classes\n",
"(this is most relevant for multi-class models that need GT for all classes).\n",
"Default: ``False``."
]
}
}
Expand Down
10 changes: 10 additions & 0 deletions ivadomed/loader/bids3d_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
if typing.TYPE_CHECKING:
from typing import List, Optional
from ivadomed.loader.bids_dataframe import BidsDataframe
from ivadomed.loader.patch_filter import PatchFilter


class Bids3DDataset(MRI3DSubVolumeSegmentationDataset):
Expand All @@ -27,9 +28,15 @@ class Bids3DDataset(MRI3DSubVolumeSegmentationDataset):
to apply during training (Compose).
metadata_choice: Choice between "mri_params", "contrasts", None or False, related to FiLM.
roi_params (dict): Dictionary containing parameters related to ROI image processing.
subvolume_filter_fn (PatchFilter): Class that filters subvolumes according to their content.
multichannel (bool): If True, the input contrasts are combined as input channels for the model. Otherwise, each
contrast is processed individually (ie different sample / tensor).
object_detection_params (dict): Object detection parameters.
task (str): Choice between segmentation or classification. If classification: GT is discrete values, \
If segmentation: GT is binary mask.
soft_gt (bool): If True, ground truths are not binarized before being fed to the network. Otherwise, ground
truths are thresholded (0.5) after the data augmentation operations.
is_input_dropout (bool): Return input with missing modalities.
"""

Expand All @@ -44,6 +51,7 @@ def __init__(self,
transform: List[Optional[Compose]] = None,
metadata_choice: str | bool = False,
roi_params: dict = None,
subvolume_filter_fn: PatchFilter = None,
multichannel: bool = False,
object_detection_params: dict = None,
task: str = "segmentation",
Expand All @@ -56,6 +64,7 @@ def __init__(self,
roi_params=roi_params,
contrast_params=contrast_params,
model_params=model_params,
patch_filter_fn=subvolume_filter_fn,
metadata_choice=metadata_choice,
slice_axis=slice_axis,
transform=transform,
Expand All @@ -68,6 +77,7 @@ def __init__(self,
stride=model_params[ModelParamsKW.STRIDE_3D],
transform=transform,
slice_axis=slice_axis,
subvolume_filter_fn=subvolume_filter_fn,
task=task,
soft_gt=soft_gt,
is_input_dropout=is_input_dropout)
2 changes: 1 addition & 1 deletion ivadomed/loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def load_dataset(bids_df: BidsDataframe,
Note: For more details on the parameters transform_params, target_suffix, roi_params, contrast_params,
slice_filter_params, patch_filter_params and object_detection_params see :doc:`configuration_file`.
"""

# Compose transforms
tranform_lst, _ = imed_transforms.prepare_transforms(copy.deepcopy(transforms_params), requires_undo)

Expand All @@ -86,6 +85,7 @@ def load_dataset(bids_df: BidsDataframe,
slice_axis=imed_utils.AXIS_DCT[slice_axis],
transform=tranform_lst,
multichannel=multichannel,
subvolume_filter_fn=PatchFilter(**patch_filter_params, is_train=False if dataset_type == "testing" else True),
model_params=model_params,
object_detection_params=object_detection_params,
soft_gt=soft_gt,
Expand Down
41 changes: 31 additions & 10 deletions ivadomed/loader/mri3d_subvolume_segmentation_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ivadomed.loader.utils import dropout_input, create_temp_directory, get_obj_size
from ivadomed.loader.segmentation_pair import SegmentationPair
from ivadomed.object_detection import utils as imed_obj_detect
from ivadomed.loader.patch_filter import PatchFilter
from ivadomed.keywords import MetadataKW, SegmentationDatasetKW, SegmentationPairKW
from ivadomed.utils import get_timestamp, get_system_memory
from torchvision.transforms import Compose
Expand All @@ -31,10 +32,11 @@ class MRI3DSubVolumeSegmentationDataset(Dataset):
Args:
filename_pairs (list): A list of tuples in the format (input filename, ground truth filename).
transform (Compose): Transformations to apply.
length (tuple): Size of each dimensions of the subvolumes, length equals 3.
stride (tuple): Size of the overlapping per subvolume and dimensions, length equals 3.
slice_axis (int): Indicates the axis used to extract slices: "axial": 2, "sagittal": 0, "coronal": 1.
subvolume_filter_fn (PatchFilter): PatchFilter object containing subvolume filter parameters.
soft_gt (bool): If True, ground truths are not binarized before being fed to the network. Otherwise, ground
truths are thresholded (0.5) after the data augmentation operations.
is_input_dropout (bool): Return input with missing modalities.
Expand All @@ -48,10 +50,12 @@ def __init__(self,
length: tuple = (64, 64, 64),
stride: tuple = (0, 0, 0),
slice_axis: int = 0,
subvolume_filter_fn: PatchFilter = None,
task: str = "segmentation",
soft_gt: bool = False,
is_input_dropout: bool = False,
disk_cache: bool=True):

self.filename_pairs = filename_pairs

# could be a list of tuple of objects OR path objects to the actual disk equivalent.
Expand All @@ -63,6 +67,7 @@ def __init__(self,
self.stride = stride
self.prepro_transforms, self.transform = transform
self.slice_axis = slice_axis
self.subvolume_filter_fn = subvolume_filter_fn
self.has_bounding_box: bool = True
self.task = task
self.soft_gt = soft_gt
Expand Down Expand Up @@ -90,9 +95,11 @@ def _load_filenames(self) -> None:

self.has_bounding_box = imed_obj_detect.verify_metadata(seg_pair, self.has_bounding_box)
if self.has_bounding_box:
self.prepro_transforms = imed_obj_detect.adjust_transforms(self.prepro_transforms, seg_pair,
self.prepro_transforms = imed_obj_detect.adjust_transforms(self.prepro_transforms,
seg_pair,
length=self.length,
stride=self.stride)

seg_pair, roi_pair = imed_transforms.apply_preprocessing_transforms(self.prepro_transforms,
seg_pair=seg_pair)

Expand Down Expand Up @@ -128,7 +135,8 @@ def _prepare_indices(self):
else:
segpair = self.handlers[i][0]

input_img = segpair.get('input')
input_img, gt = segpair.get('input'), segpair.get('gt')

shape = input_img[0].shape

if ((shape[0] - self.length[0]) % self.stride[0]) != 0 or self.length[0] % 16 != 0 or shape[0] < \
Expand All @@ -143,14 +151,27 @@ def _prepare_indices(self):
for x in range(0, (shape[0] - self.length[0]) + 1, self.stride[0]):
for y in range(0, (shape[1] - self.length[1]) + 1, self.stride[1]):
for z in range(0, (shape[2] - self.length[2]) + 1, self.stride[2]):
x_min, x_max = x, x + self.length[0]
y_min, y_max = y, y + self.length[1]
z_min, z_max = z, z + self.length[2]

subvolume = {
'input': list(np.asarray(input_img)[:, x_min:x_max, y_min:y_max, z_min:z_max]),
'gt': list(np.asarray(gt)[:, x_min:x_max, y_min:y_max, z_min:z_max] if gt else []),
}

if self.subvolume_filter_fn and not self.subvolume_filter_fn(subvolume):
continue

self.indexes.append({
'x_min': x,
'x_max': x + self.length[0],
'y_min': y,
'y_max': y + self.length[1],
'z_min': z,
'z_max': z + self.length[2],
'handler_index': i})
'x_min': x_min,
'x_max': x_max,
'y_min': y_min,
'y_max': y_max,
'z_min': z_min,
'z_max': z_max,
'handler_index': i,
})

def __len__(self) -> int:
"""Return the dataset size. The number of subvolumes."""
Expand Down
20 changes: 10 additions & 10 deletions ivadomed/loader/patch_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,24 @@


class PatchFilter(object):
"""Filter 2D patches from dataset.
"""Filter 2D or 3D patches from dataset.
If a patch does not meet certain conditions, it is discarded from the dataset at training time.
Args:
filter_empty_mask (bool): If True, 2D patches where all voxel labels are zeros are discarded at training time.
filter_absent_class (bool): If True, 2D patches where all voxel labels are zero for one or more classes are
filter_empty_mask (bool): If True, 2D or 3D patches where all voxel labels are zeros are discarded at training time.
filter_absent_class (bool): If True, 2D or 3D patches where all voxel labels are zero for one or more classes are
discarded at training time.
filter_empty_input (bool): If True, 2D patches where all voxel intensities are zeros are discarded
filter_empty_input (bool): If True, 2D or 3D patches where all voxel intensities are zeros are discarded
at training time.
is_train (bool): Indicates if at training time.
Attributes:
filter_empty_mask (bool): If True, 2D patches where all voxel labels are zeros are discarded at training time.
filter_empty_mask (bool): If True, 2D or 3D patches where all voxel labels are zeros are discarded at training time.
Default: False.
filter_absent_class (bool): If True, 2D patches where all voxel labels are zero for one or more classes are
filter_absent_class (bool): If True, 2D or 3D patches where all voxel labels are zero for one or more classes are
discarded at training time. Default: False.
filter_empty_input (bool): If True, 2D patches where all voxel intensities are zeros are discarded
filter_empty_input (bool): If True, 2D or 3D patches where all voxel intensities are zeros are discarded
at training time. Default: False.
is_train (bool): Indicates if at training time.
Expand All @@ -43,15 +43,15 @@ def __call__(self, sample: dict) -> bool:

if self.is_train:
if self.filter_empty_mask:
# Discard 2D patches that do not have ANY ground truth (i.e. all masks are empty) at training time
# Discard 2D or 3D patches that do not have ANY ground truth (i.e. all masks are empty) at training time
if not np.any(gt_data):
return False
if self.filter_absent_class:
# Discard 2D patches that have absent classes (i.e. one or more masks are empty) at training time
# Discard 2D or 3D patches that have absent classes (i.e. one or more masks are empty) at training time
if not np.all([np.any(mask) for mask in gt_data]):
return False
if self.filter_empty_input:
# Discard set of 2D patches if one of them is empty or filled with constant value
# Discard set of 2D or 3D patches if one of them is empty or filled with constant value
# (i.e. std == 0) at training time
if np.any([img.std() == 0 for img in input_data]):
return False
Expand Down
54 changes: 53 additions & 1 deletion testing/unit_tests/test_patch_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _cmpt_slice(ds_loader):
{"filter_empty_mask": False, "filter_empty_input": True},
{"filter_empty_mask": True, "filter_empty_input": True}])
@pytest.mark.parametrize('dataset_type', ["training", "testing"])
def test_patch_filter(download_data_testing_test_files, transforms_dict, train_lst, target_lst, patch_filter_params,
def test_patch_filter_2d(download_data_testing_test_files, transforms_dict, train_lst, target_lst, patch_filter_params,
dataset_type):

cuda_available, device = imed_utils.define_device(GPU_ID)
Expand Down Expand Up @@ -86,6 +86,58 @@ def test_patch_filter(download_data_testing_test_files, transforms_dict, train_l
# We verify if there are still some negative patches (they are removed with our filter)
assert cmpt_neg != 0 and cmpt_pos != 0

@pytest.mark.parametrize('transforms_dict', [{"CenterCrop": {"size": [128, 128, 128], "applied_to": ["im", "gt"]}}])
@pytest.mark.parametrize('train_lst', [['sub-unf01_T2w.nii.gz']])
@pytest.mark.parametrize('target_lst', [["_seg-manual"]])
@pytest.mark.parametrize('patch_filter_params', [
    {"filter_empty_mask": False, "filter_empty_input": True},
    {"filter_empty_mask": True, "filter_empty_input": True}])
@pytest.mark.parametrize('dataset_type', ["training", "testing"])
def test_patch_filter_3d(download_data_testing_test_files, transforms_dict, train_lst, target_lst, patch_filter_params,
                         dataset_type):
    """Check that ``patch_filter_params`` is honoured when loading a 3D (subvolume) dataset.

    The dataset is built with a 3D model configuration (``is_2d`` False, 32x32x32 subvolumes with a stride of 32)
    for each parametrized ``dataset_type``. The subvolumes are then counted as negative / positive by
    ``_cmpt_slice`` and the counts are checked against the filter settings:

    - ``filter_empty_mask`` True and ``dataset_type == "training"``: no negative subvolume may remain.
    - ``filter_empty_mask`` True and ``dataset_type == "testing"``: every subvolume of the dataset is counted.
    - ``filter_empty_mask`` False: both negative and positive subvolumes must be present.

    Args:
        download_data_testing_test_files: Fixture, not referenced in the body.
        transforms_dict (dict): Transformation parameters forwarded as ``transforms_params``.
        train_lst (list): File names forwarded as ``data_list``.
        target_lst (list): Target suffixes forwarded as ``target_suffix``.
        patch_filter_params (dict): Patch filter configuration under test.
        dataset_type (str): Either "training" or "testing".
    """
    # The returned values are not used below; the call is kept as in test_patch_filter_2d.
    cuda_available, device = imed_utils.define_device(GPU_ID)

    loader_params = {
        "transforms_params": transforms_dict,
        "data_list": train_lst,
        "dataset_type": dataset_type,
        "requires_undo": False,
        "contrast_params": {"contrast_lst": ['T2w'], "balance": {}},
        "path_data": [os.path.join(__data_testing_dir__)],
        "target_suffix": target_lst,
        "extensions": [".nii.gz"],
        "roi_params": {"suffix": None, "slice_filter_roi": None},
        "model_params": {"name": "Unet", "is_2d": False, "length_3D": [32, 32, 32], "stride_3D": [32, 32, 32]},
        "slice_filter_params": {"filter_empty_mask": False, "filter_empty_input": False},
        "patch_filter_params": patch_filter_params,
        "slice_axis": "axial",
        "multichannel": False
    }
    # Build the dataset for the parametrized dataset_type ("training" or "testing")
    bids_df = BidsDataframe(loader_params, __tmp_dir__, derivatives=True)
    ds = imed_loader.load_dataset(bids_df, **loader_params)

    logger.info(f"\tNumber of loaded subvolumes: {len(ds)}")

    loader = DataLoader(ds, batch_size=BATCH_SIZE,
                        shuffle=True, pin_memory=True,
                        collate_fn=imed_loader_utils.imed_collate,
                        num_workers=0)
    logger.info("\tNumber of Neg/Pos subvolumes in GT.")
    cmpt_neg, cmpt_pos = _cmpt_slice(loader)
    if patch_filter_params["filter_empty_mask"]:
        if dataset_type == "testing":
            # Filters on patches are not applied at testing time
            assert cmpt_neg + cmpt_pos == len(ds)
        else:
            # Filters on patches are applied at training time
            assert cmpt_neg == 0
            assert cmpt_pos != 0
    else:
        # The empty-mask filter is disabled: negative subvolumes must still be present alongside positive ones
        assert cmpt_neg != 0 and cmpt_pos != 0


def teardown_function():
remove_tmp_dir()

0 comments on commit 42c6741

Please sign in to comment.