Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added ability to call specific post processing algorithms after reverse one-hot encoding #494

Merged
merged 22 commits into from
Sep 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 22 additions & 13 deletions GANDLF/compute/forward_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,12 @@ def validate_network(
print("== Current subject:", subject["subject_id"], flush=True)

# ensure spacing is always present in params and is always subject-specific
params["subject_spacing"] = None
if "spacing" in subject:
params["subject_spacing"] = subject["spacing"]
else:
params["subject_spacing"] = None

# constructing a new dict because torchio.GridSampler requires torchio.Subject, which requires torchio.Image to be present in initial dict, which the loader does not provide
# constructing a new dict because torchio.GridSampler requires torchio.Subject,
# which requires torchio.Image to be present in initial dict, which the loader does not provide
subject_dict = {}
label_ground_truth = None
label_present = False
Expand Down Expand Up @@ -306,26 +306,35 @@ def validate_network(
tensor=subject["1"]["data"].squeeze(0),
affine=subject["1"]["affine"].squeeze(0),
).as_sitk()
ext = get_filename_extension_sanitized(subject["1"]["path"][0])
jpg_detected = False
if ext in [".jpg", ".jpeg"]:
jpg_detected = True
pred_mask = output_prediction.numpy()
# perform postprocessing before reverse one-hot encoding here
for postprocessor in params["data_postprocessing"]:
for _class in range(0, params["model"]["num_classes"]):
pred_mask[0, _class, ...] = global_postprocessing_dict[
postprocessor
](pred_mask[0, _class, ...], params)
# '0' because validation/testing dataloader always has batch size of '1'
pred_mask = reverse_one_hot(
pred_mask[0], params["model"]["class_list"]
)
pred_mask = np.swapaxes(pred_mask, 0, 2)

# perform numpy-specific postprocessing here
for postprocessor in params["data_postprocessing"]:
# perform postprocessing after reverse one-hot encoding here
for postprocessor in params[
"data_postprocessing_after_reverse_one_hot_encoding"
]:
pred_mask = global_postprocessing_dict[postprocessor](
pred_mask, params
).numpy()
if jpg_detected:
)

# if jpg detected, convert to 8-bit arrays
ext = get_filename_extension_sanitized(subject["1"]["path"][0])
if ext in [
".jpg",
".jpeg",
".png",
]:
pred_mask = pred_mask.astype(np.uint8)
else:
pred_mask = pred_mask.astype(np.uint16)

## special case for 2D
if image.shape[-1] > 1:
Expand Down
3 changes: 2 additions & 1 deletion GANDLF/data/post_process/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@
- Define a new submodule under `GANDLF.data.post_process`, or update an existing one.
- Add the algorithm's identifier to `GANDLF.data.post_process.__init__.global_postprocessing_dict` as appropriate.
- Call the new algorithm from the config using the `data_postprocessing` key.
- Care should be taken that the post-processing steps should only be called during the `"save_output"` routine of `GANDLF.compute.forward_pass`, so that validation results do not get tainted by any post-processing.
- Care should be taken that the post-processing steps should only be called during the `"save_output"` routine of `GANDLF.compute.forward_pass`, so that validation results do not get tainted by any post-processing.
- If the new algorithm is to be applied after reverse one-hot encoding, then append the key to `GANDLF.data.post_process.__init__.postprocessing_after_reverse_one_hot_encoding`. Ensure that these algorithms return `numpy.ndarray`.
9 changes: 8 additions & 1 deletion GANDLF/data/post_process/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,11 @@
from .tensor import get_mapped_label


# registry mapping config-file identifiers to post-processing callables
global_postprocessing_dict = {
    "fill_holes": fill_holes,
    "mapping": get_mapped_label,
    "morphology": torch_morphological,
}

# append post-processing functions that are to be applied after reverse one-hot encoding;
# these must return numpy.ndarray (see GANDLF/data/post_process/README.md)
postprocessing_after_reverse_one_hot_encoding = ["mapping"]
18 changes: 9 additions & 9 deletions GANDLF/data/post_process/tensor.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
import torch
import numpy as np
from GANDLF.utils.generic import get_array_from_image_or_tensor


def get_mapped_label(input_tensor, params):
    """
    This function maps the input label to the output label.

    Args:
        input_tensor (Union[torch.Tensor, sitk.Image]): The input label.
        params (dict): The parameters dict; the mapping is read from
            params["data_postprocessing"]["mapping"] (old_label -> new_label).

    Returns:
        np.ndarray: The remapped label array; if no mapping is configured,
            the input converted to a numpy array is returned unchanged.
    """
    input_image_array = get_array_from_image_or_tensor(input_tensor)
    # no post-processing configured at all - return the array unchanged
    if "data_postprocessing" not in params:
        return input_image_array
    # post-processing configured, but no label mapping requested
    if "mapping" not in params["data_postprocessing"]:
        return input_image_array

    mapping = params["data_postprocessing"]["mapping"]
    output = np.zeros(input_image_array.shape)

    for key, value in mapping.items():
        # index with the numpy array (not the raw tensor) so boolean
        # masking stays within numpy regardless of the input type
        output[input_image_array == key] = value

    return output
14 changes: 14 additions & 0 deletions GANDLF/parseConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from copy import deepcopy

from .utils import version_check
from GANDLF.data.post_process import postprocessing_after_reverse_one_hot_encoding

## dictionary to define defaults for appropriate options, which are evaluated
parameter_defaults = {
Expand Down Expand Up @@ -478,6 +479,19 @@ def parseConfig(config_file_path, version_check_flag=True):
"target": "adaptive"
}

# this is NOT a required parameter - a user should be able to train with NO built-in post-processing
params = initialize_key(params, "data_postprocessing", {})
params = initialize_key(
params, "data_postprocessing_after_reverse_one_hot_encoding", {}
)
temp_dict = deepcopy(params["data_postprocessing"])
for key in temp_dict:
if key in postprocessing_after_reverse_one_hot_encoding:
params["data_postprocessing_after_reverse_one_hot_encoding"][key] = params[
"data_postprocessing"
][key]
params["data_postprocessing"].pop(key)

if "model" in params:

if not (isinstance(params["model"], dict)):
Expand Down
2 changes: 1 addition & 1 deletion GANDLF/utils/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def reverse_one_hot(predmask_tensor, class_list):
if case in _class: # check if any of the special cases are present
special_case_detected = True

final_mask = np.zeros(predmask_array[0, ...].shape).astype(np.int8)
final_mask = np.zeros(predmask_array[0, ...].shape).astype(np.int16)
predmask_array_bool = predmask_array >= 0.5

# in case special case is detected, if 0 is absent from
Expand Down
76 changes: 37 additions & 39 deletions testing/test_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,26 @@ def test_generic_constructTrainingCSV():
i += 1


# these are helper functions to be used in other tests
def sanitize_outputDir():
    """Remove any previous results in outputDir and recreate it empty."""
    print("02_1: Sanitizing outputDir")
    if os.path.isdir(outputDir):
        shutil.rmtree(outputDir)  # overwrite previous results
    Path(outputDir).mkdir(parents=True, exist_ok=True)


def get_temp_config_path():
    """Return the path of the temporary config file, deleting any stale copy."""
    print("02_2: Creating path for temporary config file")
    path_to_config = os.path.join(outputDir, "config_temp.yaml")
    # a leftover file from a previous run would taint this run - discard it
    if os.path.exists(path_to_config):
        os.remove(path_to_config)
    return path_to_config


# these are helper functions to be used in other tests


def test_train_segmentation_rad_2d(device):
print("03: Starting 2D Rad segmentation tests")
# read and parse csv
Expand Down Expand Up @@ -814,7 +827,7 @@ def test_train_scheduler_classification_rad_2d(device):
parameters["nested_training"]["validation"] = -5
sanitize_outputDir()
## ensure parameters are parsed every single time
file_config_temp = os.path.join(outputDir, "config_segmentation_temp.yaml")
file_config_temp = get_temp_config_path()

with open(file_config_temp, "w") as file:
yaml.dump(parameters, file)
Expand Down Expand Up @@ -958,20 +971,27 @@ def test_train_metrics_segmentation_rad_2d(device):
parameters = parseConfig(
testingDir + "/config_segmentation.yaml", version_check_flag=False
)
training_data, parameters["headers"] = parseTrainingCSV(
inputDir + "/train_2d_rad_segmentation.csv"
)
parameters["modality"] = "rad"
parameters["patch_size"] = patch_size["2D"]
parameters["model"]["dimension"] = 2
parameters["model"]["class_list"] = [0, 255]
parameters["data_postprocessing"] = {"mapping": {0: 0, 255: 1}}
parameters["model"]["amp"] = True
parameters["save_output"] = True
parameters["model"]["num_channels"] = 3
parameters["metrics"] = ["dice", "hausdorff", "hausdorff95"]
parameters["model"]["architecture"] = "resunet"
parameters["model"]["onnx_export"] = False
parameters["model"]["print_summary"] = False
file_config_temp = get_temp_config_path()

with open(file_config_temp, "w") as file:
yaml.dump(parameters, file)

parameters = parseConfig(file_config_temp, version_check_flag=False)
training_data, parameters["headers"] = parseTrainingCSV(
inputDir + "/train_2d_rad_segmentation.csv"
)
parameters = populate_header_in_parameters(parameters, parameters["headers"])
sanitize_outputDir()
TrainingManager(
Expand Down Expand Up @@ -1060,12 +1080,7 @@ def test_train_losses_segmentation_rad_2d(device):

def test_generic_config_read():
print("24: Starting testing reading configuration")
# read and parse csv
file_config_temp = os.path.join(outputDir, "config_segmentation_temp.yaml")
# if found in previous run, discard.
if os.path.exists(file_config_temp):
os.remove(file_config_temp)

file_config_temp = get_temp_config_path()
parameters = parseConfig(
os.path.join(baseConfigDir, "config_all_options.yaml"),
version_check_flag=False,
Expand All @@ -1075,6 +1090,7 @@ def test_generic_config_read():
with open(file_config_temp, "w") as file:
yaml.dump(parameters, file)

# read and parse csv
parameters = parseConfig(file_config_temp, version_check_flag=True)

training_data, parameters["headers"] = parseTrainingCSV(
Expand Down Expand Up @@ -1149,7 +1165,7 @@ def test_generic_cli_function_preprocess():
print("25: Starting testing cli function preprocess")
file_config = os.path.join(testingDir, "config_segmentation.yaml")
sanitize_outputDir()
file_config_temp = os.path.join(outputDir, "config_segmentation_temp.yaml")
file_config_temp = get_temp_config_path()
file_data = os.path.join(inputDir, "train_2d_rad_segmentation.csv")

parameters = parseConfig(file_config)
Expand Down Expand Up @@ -1189,10 +1205,7 @@ def test_generic_cli_function_mainrun(device):
parameters = parseConfig(
testingDir + "/config_segmentation.yaml", version_check_flag=False
)
file_config_temp = os.path.join(outputDir, "config_segmentation_temp.yaml")
# if found in previous run, discard.
if os.path.exists(file_config_temp):
os.remove(file_config_temp)
file_config_temp = get_temp_config_path()

parameters["modality"] = "rad"
parameters["patch_size"] = patch_size["2D"]
Expand Down Expand Up @@ -1610,17 +1623,17 @@ def test_generic_one_hot_logic():
parameters = {"data_postprocessing": {}}
mapped_output = get_mapped_label(
torch.from_numpy(img_tensor_oh_rev_array), parameters
).numpy()
)

parameters = {}
mapped_output = get_mapped_label(
torch.from_numpy(img_tensor_oh_rev_array), parameters
).numpy()
)

parameters = {"data_postprocessing": {"mapping": {0: 0, 1: 1, 2: 5}}}
mapped_output = get_mapped_label(
torch.from_numpy(img_tensor_oh_rev_array), parameters
).numpy()
)

for key, value in parameters["data_postprocessing"]["mapping"].items():
comparison = (img_tensor_oh_rev_array == key) == (mapped_output == value)
Expand Down Expand Up @@ -1712,12 +1725,7 @@ def test_train_inference_segmentation_histology_2d(device):
Path(output_dir_patches).mkdir(parents=True, exist_ok=True)
output_dir_patches_output = os.path.join(output_dir_patches, "histo_patches_output")
Path(output_dir_patches_output).mkdir(parents=True, exist_ok=True)
file_config_temp = os.path.join(
output_dir_patches, "config_patch-extraction_temp.yaml"
)
# if found in previous run, discard.
if os.path.exists(file_config_temp):
os.remove(file_config_temp)
file_config_temp = get_temp_config_path()

parameters_patch = {}
# extracting minimal number of patches to ensure that the test does not take too long
Expand Down Expand Up @@ -1790,12 +1798,7 @@ def test_train_inference_classification_histology_large_2d(device):
Path(output_dir_patches).mkdir(parents=True, exist_ok=True)
output_dir_patches_output = os.path.join(output_dir_patches, "histo_patches_output")
Path(output_dir_patches_output).mkdir(parents=True, exist_ok=True)
file_config_temp = os.path.join(
output_dir_patches, "config_patch-extraction_temp.yaml"
)
# if found in previous run, discard.
if os.path.exists(file_config_temp):
os.remove(file_config_temp)
file_config_temp = get_temp_config_path()

for sub in ["1", "2"]:
file_to_check = os.path.join(
Expand Down Expand Up @@ -1867,7 +1870,7 @@ def test_train_inference_classification_histology_large_2d(device):
)
parameters["modality"] = "histo"
parameters["patch_size"] = parameters_patch["patch_size"][0]
file_config_temp = os.path.join(outputDir, "config_classification_temp.yaml")
file_config_temp = get_temp_config_path()
with open(file_config_temp, "w") as file:
yaml.dump(parameters, file)
parameters = parseConfig(file_config_temp, version_check_flag=False)
Expand Down Expand Up @@ -1936,12 +1939,7 @@ def test_train_inference_classification_histology_2d(device):
Path(output_dir_patches).mkdir(parents=True, exist_ok=True)
output_dir_patches_output = os.path.join(output_dir_patches, "histo_patches_output")
Path(output_dir_patches_output).mkdir(parents=True, exist_ok=True)
file_config_temp = os.path.join(
output_dir_patches, "config_patch-extraction_temp.yaml"
)
# if found in previous run, discard.
if os.path.exists(file_config_temp):
os.remove(file_config_temp)
file_config_temp = get_temp_config_path()

parameters_patch = {}
# extracting minimal number of patches to ensure that the test does not take too long
Expand All @@ -1968,7 +1966,7 @@ def test_train_inference_classification_histology_2d(device):
)
parameters["modality"] = "histo"
parameters["patch_size"] = 128
file_config_temp = os.path.join(outputDir, "config_classification_temp.yaml")
file_config_temp = get_temp_config_path()
with open(file_config_temp, "w") as file:
yaml.dump(parameters, file)
parameters = parseConfig(file_config_temp, version_check_flag=False)
Expand Down