
Commit

Merge branch 'master' into kk/integrate_mypy
kanishk16 committed Jul 4, 2022
2 parents 7207db9 + a588984 commit 131e5e8
Showing 6 changed files with 191 additions and 39 deletions.
19 changes: 18 additions & 1 deletion ivadomed/keywords.py
@@ -255,4 +255,21 @@ class MetadataParamsKW:
 @dataclass
 class MetadataChoiceKW:
     MRI_PARAMS = "mri_params"
-    CONTRASTS = "contrasts"
+    CONTRASTS = "contrasts"
+
+@dataclass
+class SegmentationDatasetKW:
+    X_MIN: str = 'x_min'
+    X_MAX: str = 'x_max'
+    Y_MIN: str = 'y_min'
+    Y_MAX: str = 'y_max'
+    Z_MIN: str = 'z_min'
+    Z_MAX: str = 'z_max'
+    HANDLER_INDEX: str = 'handler_index'
+
+@dataclass
+class SegmentationPairKW:
+    GT_METADATA = "gt_metadata"
+    INPUT_METADATA = "input_metadata"
+    GT = "gt"
+    INPUT = "input"
98 changes: 62 additions & 36 deletions ivadomed/loader/mri3d_subvolume_segmentation_dataset.py
@@ -14,7 +14,7 @@
 from ivadomed.loader.utils import dropout_input, create_temp_directory, get_obj_size
 from ivadomed.loader.segmentation_pair import SegmentationPair
 from ivadomed.object_detection import utils as imed_obj_detect
-from ivadomed.keywords import MetadataKW
+from ivadomed.keywords import MetadataKW, SegmentationDatasetKW, SegmentationPairKW
 from ivadomed.utils import get_timestamp, get_system_memory


@@ -148,86 +148,112 @@ def __len__(self):
"""Return the dataset size. The number of subvolumes."""
return len(self.indexes)

def __getitem__(self, index):
def __getitem__(self, subvolume_index: int):
"""Return the specific index pair subvolume (input, ground truth).
Args:
index (int): Subvolume index.
subvolume_index (int): Subvolume index.
"""

# copy.deepcopy is used to have different coordinates for reconstruction for a given handler,
# to allow a different rater at each iteration of training, and to clean transforms params from previous
# transforms i.e. remove params from previous iterations so that the coming transforms are different
coord = self.indexes[index]

tuple_seg_roi_pair: tuple = self.handlers[coord['handler_index']]
# Get the tuple that defines the boundaries for the subsample
coord: dict = self.indexes[subvolume_index]
x_min = coord.get(SegmentationDatasetKW.X_MIN)
x_max = coord.get(SegmentationDatasetKW.X_MAX)
y_min = coord.get(SegmentationDatasetKW.Y_MIN)
y_max = coord.get(SegmentationDatasetKW.Y_MAX)
z_min = coord.get(SegmentationDatasetKW.Z_MIN)
z_max = coord.get(SegmentationDatasetKW.Z_MAX)

# Obtain tuple reference to the pairs of file references
tuple_seg_roi_pair: tuple = self.handlers[coord.get(SegmentationDatasetKW.HANDLER_INDEX)]

# Disk Cache handling, either, load the seg_pair, not using ROI pair here.
if self.disk_cache:
with tuple_seg_roi_pair[0].open(mode='rb') as f:
seg_pair = pickle.load(f)
else:
seg_pair, _ = copy.deepcopy(tuple_seg_roi_pair)

# In case multiple raters
if seg_pair['gt'] and isinstance(seg_pair['gt'][0], list):
if seg_pair[SegmentationPairKW.GT] and isinstance(seg_pair[SegmentationPairKW.GT][0], list):
# Randomly pick a rater
idx_rater = random.randint(0, len(seg_pair['gt'][0]) - 1)
idx_rater = random.randint(0, len(seg_pair[SegmentationPairKW.GT][0]) - 1)
# Use it as ground truth for this iteration
# Note: in case of multi-class: the same rater is used across classes
for idx_class in range(len(seg_pair['gt'])):
seg_pair['gt'][idx_class] = seg_pair['gt'][idx_class][idx_rater]
seg_pair['gt_metadata'][idx_class] = seg_pair['gt_metadata'][idx_class][idx_rater]
for idx_class in range(len(seg_pair[SegmentationPairKW.GT])):
seg_pair[SegmentationPairKW.GT][idx_class] = seg_pair[SegmentationPairKW.GT][idx_class][idx_rater]
seg_pair[SegmentationPairKW.GT_METADATA][idx_class] = seg_pair[SegmentationPairKW.GT_METADATA][idx_class][idx_rater]

if seg_pair[SegmentationPairKW.INPUT_METADATA]:
metadata_input = seg_pair[SegmentationPairKW.INPUT_METADATA]
else:
metadata_input = []

metadata_input = seg_pair['input_metadata'] if seg_pair['input_metadata'] is not None else []
metadata_gt = seg_pair['gt_metadata'] if seg_pair['gt_metadata'] is not None else []
if seg_pair[SegmentationPairKW.GT_METADATA]:
metadata_gt = seg_pair[SegmentationPairKW.GT_METADATA]
else:
metadata_gt = []

# Extract image and gt slices or patches from coordinates
stack_input = np.asarray(seg_pair[SegmentationPairKW.INPUT])[
:,
x_min:x_max,
y_min:y_max,
z_min:z_max
]

if seg_pair[SegmentationPairKW.GT]:
stack_gt = np.asarray(seg_pair[SegmentationPairKW.GT])[
:,
x_min:x_max,
y_min:y_max,
z_min:z_max
]
else:
stack_gt = []

# Run transforms on images
stack_input, metadata_input = self.transform(sample=seg_pair['input'],
# Run transforms on image slices
stack_input, metadata_input = self.transform(sample=stack_input,
metadata=metadata_input,
data_type="im")
# Update metadata_gt with metadata_input
metadata_gt = imed_loader_utils.update_metadata(metadata_input, metadata_gt)

# Run transforms on images
stack_gt, metadata_gt = self.transform(sample=seg_pair['gt'],
# Run transforms on gt slices
stack_gt, metadata_gt = self.transform(sample=stack_gt,
metadata=metadata_gt,
data_type="gt")
# Make sure stack_gt is binarized
if stack_gt is not None and not self.soft_gt:
stack_gt = imed_postpro.threshold_predictions(stack_gt, thr=0.5).astype(np.uint8)

shape_x = coord["x_max"] - coord["x_min"]
shape_y = coord["y_max"] - coord["y_min"]
shape_z = coord["z_max"] - coord["z_min"]
shape_x = x_max - x_min
shape_y = y_max - y_min
shape_z = z_max - z_min

# add coordinates to metadata to reconstruct volume
for metadata in metadata_input:
metadata[MetadataKW.COORD] = [coord["x_min"], coord["x_max"], coord["y_min"], coord["y_max"], coord["z_min"],
coord["z_max"]]
metadata[MetadataKW.COORD] = [
x_min, x_max,
y_min, y_max,
z_min, z_max,
]

subvolumes = {
'input': torch.zeros(stack_input.shape[0], shape_x, shape_y, shape_z),
'gt': torch.zeros(stack_gt.shape[0], shape_x, shape_y, shape_z) if stack_gt is not None else None,
SegmentationPairKW.INPUT: stack_input,
SegmentationPairKW.GT: stack_gt,
MetadataKW.INPUT_METADATA: metadata_input,
MetadataKW.GT_METADATA: metadata_gt
}

for _ in range(len(stack_input)):
subvolumes['input'] = stack_input[:,
coord['x_min']:coord['x_max'],
coord['y_min']:coord['y_max'],
coord['z_min']:coord['z_max']]

# Input-level dropout to train with missing modalities
if self.is_input_dropout:
subvolumes = dropout_input(subvolumes)

if stack_gt is not None:
for _ in range(len(stack_gt)):
subvolumes['gt'] = stack_gt[:,
coord['x_min']:coord['x_max'],
coord['y_min']:coord['y_max'],
coord['z_min']:coord['z_max']]

return subvolumes

def determine_cache_need(self, seg_pair: dict, roi_pair: dict):
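The net effect of the `__getitem__` rewrite is that the subvolume is now cropped out of the full (channels, X, Y, Z) stack with plain NumPy slicing before the transforms run, instead of transforming the whole volume and copying a crop into a pre-allocated tensor afterwards. A minimal, self-contained sketch of that cropping step; the keyword constants mirror SegmentationDatasetKW, and the array shape and coordinates are made up for illustration:

```python
import numpy as np

# Hypothetical keyword values mirroring ivadomed.keywords.SegmentationDatasetKW
X_MIN, X_MAX = 'x_min', 'x_max'
Y_MIN, Y_MAX = 'y_min', 'y_max'
Z_MIN, Z_MAX = 'z_min', 'z_max'


def crop_subvolume(stack: np.ndarray, coord: dict) -> np.ndarray:
    """Crop a (channels, X, Y, Z) stack to the subvolume described by coord."""
    return stack[
        :,
        coord[X_MIN]:coord[X_MAX],
        coord[Y_MIN]:coord[Y_MAX],
        coord[Z_MIN]:coord[Z_MAX],
    ]


volume = np.zeros((1, 64, 64, 32))  # one-channel dummy volume
coord = {X_MIN: 0, X_MAX: 32, Y_MIN: 0, Y_MAX: 32, Z_MIN: 0, Z_MAX: 16}
print(crop_subvolume(volume, coord).shape)  # (1, 32, 32, 16)
```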
2 changes: 1 addition & 1 deletion ivadomed/main.py
@@ -98,7 +98,7 @@ def check_multiple_raters(is_train, loader_params):
         if not is_train:
             logger.error(
                 "Please provide only one annotation per class in 'target_suffix' when not training a model.\n")
-            exit()
+            sys.exit()
 
 
 def film_normalize_data(context, model_params, ds_train, ds_valid, path_output):
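Switching from the built-in `exit()` to `sys.exit()` keeps the behaviour testable and safe outside interactive sessions: `sys.exit()` is always available and raises `SystemExit`, which is exactly what the new unit test in `testing/unit_tests/test_main.py` asserts with `pytest.raises(SystemExit)`, whereas the built-in `exit()` is injected by the `site` module and is intended for the interactive interpreter. A small, self-contained illustration of the pattern, with a made-up function name standing in for `check_multiple_raters`:

```python
import sys

import pytest


def abort_on_invalid_config():
    # Stand-in for check_multiple_raters(): report the problem, then terminate via SystemExit
    sys.exit("Please provide only one annotation per class in 'target_suffix'.")


def test_abort_on_invalid_config():
    with pytest.raises(SystemExit):
        abort_on_invalid_config()
```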
2 changes: 1 addition & 1 deletion ivadomed/testing.py
@@ -223,7 +223,7 @@ def run_inference(test_loader, model, model_params, testing_params, ofolder, cuda_available,
             # save the completely processed file as a NifTI file
             if ofolder:
                 fname_pred = str(Path(ofolder, Path(fname_ref).name))
-                fname_pred = fname_pred.rsplit("_", 1)[0] + '_pred.nii.gz'
+                fname_pred = fname_pred.split(testing_params['target_suffix'][0])[0] + '_pred.nii.gz'
                 # If Uncertainty running, then we save each simulation result
                 if testing_params['uncertainty']['applied']:
                     fname_pred = fname_pred.split('.nii.gz')[0] + '_' + str(i_monte_carlo).zfill(2) + '.nii.gz'
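The old `rsplit("_", 1)` cut the filename at its last underscore, which breaks when the target suffix itself contains an underscore (e.g. `_seg-axon_manual`); splitting on `target_suffix[0]` keeps the subject and contrast part intact. A quick illustration on a made-up ground-truth-style filename, with suffix values mirroring the new test added below:

```python
fname_pred = "sub-rat3_ses-01_sample-data9_SEM_seg-axon_manual.nii.gz"
target_suffix = ["_seg-axon_manual", "_seg-myelin_manual"]

old = fname_pred.rsplit("_", 1)[0] + "_pred.nii.gz"
new = fname_pred.split(target_suffix[0])[0] + "_pred.nii.gz"

print(old)  # sub-rat3_ses-01_sample-data9_SEM_seg-axon_pred.nii.gz  (suffix leaks into the name)
print(new)  # sub-rat3_ses-01_sample-data9_SEM_pred.nii.gz           (clean prediction name)
```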
14 changes: 14 additions & 0 deletions testing/unit_tests/test_main.py
@@ -0,0 +1,14 @@
+import pytest
+
+from ivadomed.main import check_multiple_raters
+
+@pytest.mark.parametrize(
+    'is_train, loader_params', [
+        (False, {"target_suffix":
+            [["_seg-axon-manual1", "_seg-axon-manual2"],
+             ["_seg-myelin-manual1", "_seg-myelin-manual2"]]
+         })
+    ])
+def test_check_multiple_raters(is_train, loader_params):
+    with pytest.raises(SystemExit):
+        check_multiple_raters(is_train, loader_params)
95 changes: 95 additions & 0 deletions testing/unit_tests/test_testing.py
@@ -211,5 +211,100 @@ def test_inference_2d_microscopy(download_data_testing_test_files, transforms_dict,
     assert len([x for x in __output_dir__.iterdir() if x.name.endswith(".png")]) == 2*len(test_lst)
 
 
+@pytest.mark.parametrize('transforms_dict', [{
+    "CenterCrop": {
+        "size": [128, 128]
+    },
+    "NormalizeInstance": {"applied_to": ["im"]}
+}])
+@pytest.mark.parametrize('test_lst',
+                         [['sub-rat3_ses-01_sample-data9_SEM.png', 'sub-rat3_ses-02_sample-data10_SEM.png']])
+@pytest.mark.parametrize('target_lst', [["_seg-axon_manual", "_seg-myelin_manual"]])
+@pytest.mark.parametrize('roi_params', [{"suffix": None, "slice_filter_roi": None}])
+@pytest.mark.parametrize('testing_params', [{
+    "binarize_maxpooling": {},
+    "uncertainty": {
+        "applied": False,
+        "epistemic": False,
+        "aleatoric": False,
+        "n_it": 0
+    }}])
+def test_inference_target_suffix(download_data_testing_test_files, transforms_dict, test_lst, target_lst, roi_params,
+                                 testing_params):
+    """
+    This test checks if the filename(s) of the prediction(s) saved as NifTI file(s) in the pred_masks
+    dir conform to the target_suffix or not. Thus, independent of underscore(s) in the target_suffix. As a result,
+    _seg-axon-manual or _seg-axon_manual should yield the same filename(s).
+    (c.f: https://github.com/ivadomed/ivadomed/issues/1135)
+    """
+    cuda_available, device = imed_utils.define_device(GPU_ID)
+
+    model_params = {"name": "Unet", "is_2d": True, "out_channel": 3}
+    loader_params = {
+        "transforms_params": transforms_dict,
+        "data_list": test_lst,
+        "dataset_type": "testing",
+        "requires_undo": True,
+        "contrast_params": {"contrast_lst": ['SEM'], "balance": {}},
+        "path_data": [str(Path(__data_testing_dir__, "microscopy_png"))],
+        "bids_config": f"{path_repo_root}/ivadomed/config/config_bids.json",
+        "target_suffix": target_lst,
+        "extensions": [".png"],
+        "roi_params": roi_params,
+        "slice_filter_params": {"filter_empty_mask": False, "filter_empty_input": True},
+        "patch_filter_params": {"filter_empty_mask": False, "filter_empty_input": False},
+        "slice_axis": SLICE_AXIS,
+        "multichannel": False
+    }
+    loader_params.update({"model_params": model_params})
+
+    # restructuring the dataset
+    gt_path = f'{loader_params["path_data"][0]}/derivatives/labels/'
+    for file_path in Path(gt_path).rglob('*.png'):
+        src_filename = file_path.resolve()
+        dst_filename = '_'.join(str(src_filename).rsplit('-', 1))
+        src_filename.rename(Path(dst_filename))
+
+    bids_df = BidsDataframe(loader_params, __tmp_dir__, derivatives=True)
+
+    ds_test = imed_loader.load_dataset(bids_df, **loader_params)
+    test_loader = DataLoader(ds_test, batch_size=BATCH_SIZE,
+                             shuffle=False, pin_memory=True,
+                             collate_fn=imed_loader_utils.imed_collate,
+                             num_workers=0)
+
+    # Undo transform
+    val_undo_transform = imed_transforms.UndoCompose(imed_transforms.Compose(transforms_dict))
+
+    # Update testing_params
+    testing_params.update({
+        "slice_axis": loader_params["slice_axis"],
+        "target_suffix": loader_params["target_suffix"],
+        "undo_transforms": val_undo_transform
+    })
+
+    # Model
+    model = imed_models.Unet(out_channel=model_params['out_channel'])
+
+    if cuda_available:
+        model.cuda()
+    model.eval()
+
+    if not __output_dir__.is_dir():
+        __output_dir__.mkdir(parents=True, exist_ok=True)
+
+    preds_npy, gt_npy = imed_testing.run_inference(test_loader=test_loader,
+                                                   model=model,
+                                                   model_params=model_params,
+                                                   testing_params=testing_params,
+                                                   ofolder=str(__output_dir__),
+                                                   cuda_available=cuda_available)
+
+    for x in __output_dir__.iterdir():
+        if x.name.endswith('_pred.nii.gz'):
+            assert x.name.rsplit('_', 1)[0].endswith(loader_params['contrast_params']['contrast_lst'][-1]), (
+                'Incompatible filename(s) of the prediction(s) saved as NifTI file(s)!'
+            )
+
 def teardown_function():
     remove_tmp_dir()

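For context, the "restructuring the dataset" step in the test above renames the downloaded ground-truth files so that the last hyphen of their suffix becomes an underscore, matching the parametrized `target_lst`. A self-contained illustration of that rename expression on a made-up path:

```python
src_filename = "/tmp/data/derivatives/labels/sub-rat3/micr/sub-rat3_ses-01_sample-data9_SEM_seg-axon-manual.png"

# '_'.join(str(...).rsplit('-', 1)) replaces the LAST hyphen in the path with an underscore
dst_filename = '_'.join(src_filename.rsplit('-', 1))
print(dst_filename)
# /tmp/data/derivatives/labels/sub-rat3/micr/sub-rat3_ses-01_sample-data9_SEM_seg-axon_manual.png
```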