Skip to content

Commit

Permalink
Merge pull request #494 from sarthakpati/493_specialized_post_process
Browse files Browse the repository at this point in the history
Added ability to call specific post processing algorithms after reverse one-hot encoding
  • Loading branch information
sarthakpati committed Sep 12, 2022
2 parents be60609 + 08cca1a commit 3d00791
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 64 deletions.
35 changes: 22 additions & 13 deletions GANDLF/compute/forward_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,12 @@ def validate_network(
print("== Current subject:", subject["subject_id"], flush=True)

# ensure spacing is always present in params and is always subject-specific
params["subject_spacing"] = None
if "spacing" in subject:
params["subject_spacing"] = subject["spacing"]
else:
params["subject_spacing"] = None

# constructing a new dict because torchio.GridSampler requires torchio.Subject, which requires torchio.Image to be present in initial dict, which the loader does not provide
# constructing a new dict because torchio.GridSampler requires torchio.Subject,
# which requires torchio.Image to be present in initial dict, which the loader does not provide
subject_dict = {}
label_ground_truth = None
label_present = False
Expand Down Expand Up @@ -306,26 +306,35 @@ def validate_network(
tensor=subject["1"]["data"].squeeze(0),
affine=subject["1"]["affine"].squeeze(0),
).as_sitk()
ext = get_filename_extension_sanitized(subject["1"]["path"][0])
jpg_detected = False
if ext in [".jpg", ".jpeg"]:
jpg_detected = True
pred_mask = output_prediction.numpy()
# perform postprocessing before reverse one-hot encoding here
for postprocessor in params["data_postprocessing"]:
for _class in range(0, params["model"]["num_classes"]):
pred_mask[0, _class, ...] = global_postprocessing_dict[
postprocessor
](pred_mask[0, _class, ...], params)
# '0' because validation/testing dataloader always has batch size of '1'
pred_mask = reverse_one_hot(
pred_mask[0], params["model"]["class_list"]
)
pred_mask = np.swapaxes(pred_mask, 0, 2)

# perform numpy-specific postprocessing here
for postprocessor in params["data_postprocessing"]:
# perform postprocessing after reverse one-hot encoding here
for postprocessor in params[
"data_postprocessing_after_reverse_one_hot_encoding"
]:
pred_mask = global_postprocessing_dict[postprocessor](
pred_mask, params
).numpy()
if jpg_detected:
)

                        # if jpg/jpeg/png detected, convert to 8-bit arrays
ext = get_filename_extension_sanitized(subject["1"]["path"][0])
if ext in [
".jpg",
".jpeg",
".png",
]:
pred_mask = pred_mask.astype(np.uint8)
else:
pred_mask = pred_mask.astype(np.uint16)

## special case for 2D
if image.shape[-1] > 1:
Expand Down
3 changes: 2 additions & 1 deletion GANDLF/data/post_process/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@
- Define a new submodule under `GANDLF.data.post_process`, or update an existing one.
- Add the algorithm's identifier to `GANDLF.data.post_process.__init__.global_postprocessing_dict` as appropriate.
- Call the new algorithm from the config using the `data_postprocessing` key.
- Care should be taken that the post-processing steps should only be called during the `"save_output"` routine of `GANDLF.compute.forward_pass`, so that validation results do not get tainted by any post-processing.
- Care should be taken that the post-processing steps should only be called during the `"save_output"` routine of `GANDLF.compute.forward_pass`, so that validation results do not get tainted by any post-processing.
- If the new algorithm is to be applied after reverse one-hot encoding, then append the key to `GANDLF.data.post_process.__init__.postprocessing_after_reverse_one_hot_encoding`. Ensure that these algorithms return `numpy.ndarray`.
9 changes: 8 additions & 1 deletion GANDLF/data/post_process/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,11 @@
from .tensor import get_mapped_label


global_postprocessing_dict = {"fill_holes": fill_holes, "mapping": get_mapped_label}
global_postprocessing_dict = {
"fill_holes": fill_holes,
"mapping": get_mapped_label,
"morphology": torch_morphological,
}

# append post_processing functions that are to be applied after reverse one-hot encoding
postprocessing_after_reverse_one_hot_encoding = ["mapping"]
18 changes: 9 additions & 9 deletions GANDLF/data/post_process/tensor.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
import torch
import numpy as np
from GANDLF.utils.generic import get_array_from_image_or_tensor


def get_mapped_label(input_tensor, params):
"""
This function maps the input label to the output label.
Args:
input_tensor (torch.Tensor): The input label.
input_tensor (Union[torch.Tensor, sitk.Image]): The input label.
params (dict): The parameters dict.
Returns:
torch.Tensor: The output image after morphological operations.
np.ndarray: The output image after morphological operations.
"""
input_image_array = get_array_from_image_or_tensor(input_tensor)
if "data_postprocessing" not in params:
return input_tensor
return input_image_array
if "mapping" not in params["data_postprocessing"]:
return input_tensor
return input_image_array

mapping = params["data_postprocessing"]["mapping"]
output = np.zeros(input_image_array.shape)

output = torch.zeros(input_tensor.shape)

for key, value in mapping.items():
for key, value in params["data_postprocessing"]["mapping"].items():
output[input_tensor == key] = value

return output
14 changes: 14 additions & 0 deletions GANDLF/parseConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from copy import deepcopy

from .utils import version_check
from GANDLF.data.post_process import postprocessing_after_reverse_one_hot_encoding

## dictionary to define defaults for appropriate options, which are evaluated
parameter_defaults = {
Expand Down Expand Up @@ -478,6 +479,19 @@ def parseConfig(config_file_path, version_check_flag=True):
"target": "adaptive"
}

# this is NOT a required parameter - a user should be able to train with NO built-in post-processing
params = initialize_key(params, "data_postprocessing", {})
params = initialize_key(
params, "data_postprocessing_after_reverse_one_hot_encoding", {}
)
temp_dict = deepcopy(params["data_postprocessing"])
for key in temp_dict:
if key in postprocessing_after_reverse_one_hot_encoding:
params["data_postprocessing_after_reverse_one_hot_encoding"][key] = params[
"data_postprocessing"
][key]
params["data_postprocessing"].pop(key)

if "model" in params:

if not (isinstance(params["model"], dict)):
Expand Down
2 changes: 1 addition & 1 deletion GANDLF/utils/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def reverse_one_hot(predmask_tensor, class_list):
if case in _class: # check if any of the special cases are present
special_case_detected = True

final_mask = np.zeros(predmask_array[0, ...].shape).astype(np.int8)
final_mask = np.zeros(predmask_array[0, ...].shape).astype(np.int16)
predmask_array_bool = predmask_array >= 0.5

# in case special case is detected, if 0 is absent from
Expand Down
76 changes: 37 additions & 39 deletions testing/test_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,26 @@ def test_generic_constructTrainingCSV():
i += 1


# these are helper functions to be used in other tests
def sanitize_outputDir():
print("02: Sanitizing outputDir")
print("02_1: Sanitizing outputDir")
if os.path.isdir(outputDir):
shutil.rmtree(outputDir) # overwrite previous results
Path(outputDir).mkdir(parents=True, exist_ok=True)


def get_temp_config_path():
    """Return the path for the temporary YAML config inside outputDir.

    Any stale file left behind by a previous test run is deleted first,
    so callers always start from a clean path.
    """
    print("02_2: Creating path for temporary config file")
    path_to_config = os.path.join(outputDir, "config_temp.yaml")
    # discard a leftover copy from an earlier run, if one exists
    if os.path.exists(path_to_config):
        os.remove(path_to_config)
    return path_to_config


# these are helper functions to be used in other tests


def test_train_segmentation_rad_2d(device):
print("03: Starting 2D Rad segmentation tests")
# read and parse csv
Expand Down Expand Up @@ -814,7 +827,7 @@ def test_train_scheduler_classification_rad_2d(device):
parameters["nested_training"]["validation"] = -5
sanitize_outputDir()
## ensure parameters are parsed every single time
file_config_temp = os.path.join(outputDir, "config_segmentation_temp.yaml")
file_config_temp = get_temp_config_path()

with open(file_config_temp, "w") as file:
yaml.dump(parameters, file)
Expand Down Expand Up @@ -958,20 +971,27 @@ def test_train_metrics_segmentation_rad_2d(device):
parameters = parseConfig(
testingDir + "/config_segmentation.yaml", version_check_flag=False
)
training_data, parameters["headers"] = parseTrainingCSV(
inputDir + "/train_2d_rad_segmentation.csv"
)
parameters["modality"] = "rad"
parameters["patch_size"] = patch_size["2D"]
parameters["model"]["dimension"] = 2
parameters["model"]["class_list"] = [0, 255]
parameters["data_postprocessing"] = {"mapping": {0: 0, 255: 1}}
parameters["model"]["amp"] = True
parameters["save_output"] = True
parameters["model"]["num_channels"] = 3
parameters["metrics"] = ["dice", "hausdorff", "hausdorff95"]
parameters["model"]["architecture"] = "resunet"
parameters["model"]["onnx_export"] = False
parameters["model"]["print_summary"] = False
file_config_temp = get_temp_config_path()

with open(file_config_temp, "w") as file:
yaml.dump(parameters, file)

parameters = parseConfig(file_config_temp, version_check_flag=False)
training_data, parameters["headers"] = parseTrainingCSV(
inputDir + "/train_2d_rad_segmentation.csv"
)
parameters = populate_header_in_parameters(parameters, parameters["headers"])
sanitize_outputDir()
TrainingManager(
Expand Down Expand Up @@ -1060,12 +1080,7 @@ def test_train_losses_segmentation_rad_2d(device):

def test_generic_config_read():
print("24: Starting testing reading configuration")
# read and parse csv
file_config_temp = os.path.join(outputDir, "config_segmentation_temp.yaml")
# if found in previous run, discard.
if os.path.exists(file_config_temp):
os.remove(file_config_temp)

file_config_temp = get_temp_config_path()
parameters = parseConfig(
os.path.join(baseConfigDir, "config_all_options.yaml"),
version_check_flag=False,
Expand All @@ -1075,6 +1090,7 @@ def test_generic_config_read():
with open(file_config_temp, "w") as file:
yaml.dump(parameters, file)

# read and parse csv
parameters = parseConfig(file_config_temp, version_check_flag=True)

training_data, parameters["headers"] = parseTrainingCSV(
Expand Down Expand Up @@ -1149,7 +1165,7 @@ def test_generic_cli_function_preprocess():
print("25: Starting testing cli function preprocess")
file_config = os.path.join(testingDir, "config_segmentation.yaml")
sanitize_outputDir()
file_config_temp = os.path.join(outputDir, "config_segmentation_temp.yaml")
file_config_temp = get_temp_config_path()
file_data = os.path.join(inputDir, "train_2d_rad_segmentation.csv")

parameters = parseConfig(file_config)
Expand Down Expand Up @@ -1189,10 +1205,7 @@ def test_generic_cli_function_mainrun(device):
parameters = parseConfig(
testingDir + "/config_segmentation.yaml", version_check_flag=False
)
file_config_temp = os.path.join(outputDir, "config_segmentation_temp.yaml")
# if found in previous run, discard.
if os.path.exists(file_config_temp):
os.remove(file_config_temp)
file_config_temp = get_temp_config_path()

parameters["modality"] = "rad"
parameters["patch_size"] = patch_size["2D"]
Expand Down Expand Up @@ -1610,17 +1623,17 @@ def test_generic_one_hot_logic():
parameters = {"data_postprocessing": {}}
mapped_output = get_mapped_label(
torch.from_numpy(img_tensor_oh_rev_array), parameters
).numpy()
)

parameters = {}
mapped_output = get_mapped_label(
torch.from_numpy(img_tensor_oh_rev_array), parameters
).numpy()
)

parameters = {"data_postprocessing": {"mapping": {0: 0, 1: 1, 2: 5}}}
mapped_output = get_mapped_label(
torch.from_numpy(img_tensor_oh_rev_array), parameters
).numpy()
)

for key, value in parameters["data_postprocessing"]["mapping"].items():
comparison = (img_tensor_oh_rev_array == key) == (mapped_output == value)
Expand Down Expand Up @@ -1712,12 +1725,7 @@ def test_train_inference_segmentation_histology_2d(device):
Path(output_dir_patches).mkdir(parents=True, exist_ok=True)
output_dir_patches_output = os.path.join(output_dir_patches, "histo_patches_output")
Path(output_dir_patches_output).mkdir(parents=True, exist_ok=True)
file_config_temp = os.path.join(
output_dir_patches, "config_patch-extraction_temp.yaml"
)
# if found in previous run, discard.
if os.path.exists(file_config_temp):
os.remove(file_config_temp)
file_config_temp = get_temp_config_path()

parameters_patch = {}
# extracting minimal number of patches to ensure that the test does not take too long
Expand Down Expand Up @@ -1790,12 +1798,7 @@ def test_train_inference_classification_histology_large_2d(device):
Path(output_dir_patches).mkdir(parents=True, exist_ok=True)
output_dir_patches_output = os.path.join(output_dir_patches, "histo_patches_output")
Path(output_dir_patches_output).mkdir(parents=True, exist_ok=True)
file_config_temp = os.path.join(
output_dir_patches, "config_patch-extraction_temp.yaml"
)
# if found in previous run, discard.
if os.path.exists(file_config_temp):
os.remove(file_config_temp)
file_config_temp = get_temp_config_path()

for sub in ["1", "2"]:
file_to_check = os.path.join(
Expand Down Expand Up @@ -1867,7 +1870,7 @@ def test_train_inference_classification_histology_large_2d(device):
)
parameters["modality"] = "histo"
parameters["patch_size"] = parameters_patch["patch_size"][0]
file_config_temp = os.path.join(outputDir, "config_classification_temp.yaml")
file_config_temp = get_temp_config_path()
with open(file_config_temp, "w") as file:
yaml.dump(parameters, file)
parameters = parseConfig(file_config_temp, version_check_flag=False)
Expand Down Expand Up @@ -1936,12 +1939,7 @@ def test_train_inference_classification_histology_2d(device):
Path(output_dir_patches).mkdir(parents=True, exist_ok=True)
output_dir_patches_output = os.path.join(output_dir_patches, "histo_patches_output")
Path(output_dir_patches_output).mkdir(parents=True, exist_ok=True)
file_config_temp = os.path.join(
output_dir_patches, "config_patch-extraction_temp.yaml"
)
# if found in previous run, discard.
if os.path.exists(file_config_temp):
os.remove(file_config_temp)
file_config_temp = get_temp_config_path()

parameters_patch = {}
# extracting minimal number of patches to ensure that the test does not take too long
Expand All @@ -1968,7 +1966,7 @@ def test_train_inference_classification_histology_2d(device):
)
parameters["modality"] = "histo"
parameters["patch_size"] = 128
file_config_temp = os.path.join(outputDir, "config_classification_temp.yaml")
file_config_temp = get_temp_config_path()
with open(file_config_temp, "w") as file:
yaml.dump(parameters, file)
parameters = parseConfig(file_config_temp, version_check_flag=False)
Expand Down

0 comments on commit 3d00791

Please sign in to comment.