From 82d2899d1f32c0bf53723f26f418b81fffa58007 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Mon, 27 Nov 2023 20:33:50 +0100 Subject: [PATCH 1/5] Add nnUNet raw transform --- torch_em/transform/raw.py | 80 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/torch_em/transform/raw.py b/torch_em/transform/raw.py index cd43c379..9e6c85da 100644 --- a/torch_em/transform/raw.py +++ b/torch_em/transform/raw.py @@ -1,4 +1,6 @@ +import json import numpy as np + import torch from torchvision import transforms @@ -227,3 +229,81 @@ def get_default_mean_teacher_augmentations( augmentation1=aug1, augmentation2=aug2 ) + + +class nnUNetRawTransformBase: + """nnUNetRawTransformBase is an interface to implement specific raw transforms for nnUNet. + + Adapted from: https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/preprocessing/normalization + """ + def __init__( + self, + plans_file: str, + expected_dtype: type = np.float32, + tolerance: float = 1e-8 + ): + self.expected_dtype = expected_dtype + self.tolerance = tolerance + + self.intensity_properties = self.load_json(plans_file) + self.intensity_properties = self.intensity_properties["foreground_intensity_properties_per_channel"] + + def load_json(self, _file: str): + # credits: `batchgenerators.utilities.file_and_folder_operations` + with open(_file, 'r') as f: + a = json.load(f) + return a + + def __call__( + self, + raw: np.ndarray, + modality: str + ) -> np.ndarray: # the transformed raw inputs + """Returns the raw inputs after applying the pre-processing from nnUNet. + + Args: + raw: The raw array inputs + Expectd a float array of shape H * W * C + + Returns: + The transformed raw inputs (the same shape as inputs) + """ + raise NotImplementedError("It's a class template for raw transforms from nnUNet. \ + Use a child class that implements the expected raw transform instead") + + +class nnUNet_CT_RawTransform(nnUNetRawTransformBase): + """Apply transformation on the raw inputs (adapted from nnUNetv2's `CTNormalization`) + + You can use this class to apply the necessary raw transformations on CT and PET volume channels. + + Here's an example for how to use this class: + ```python + # Initialize the raw transform. + raw_transform = nnUNet_CT_RawTransform(plans_file="...nnUNetplans.json") + + # Apply transformation on the inputs. + ct_raw = raw_transform(ct_volume) + pet_raw = raw_transform(pet_volume) + ``` + """ + def __call__( + self, + raw: np.ndarray, + modality_index: str + ) -> np.ndarray: + assert self.intensity_properties is not None, \ + "Intensity properties are required here. Please make sure that you pass the `nnUNetplans.json correctly." + + raw = raw.astype(self.expected_dtype) + + # intensity properties for the respective modality + props = self.intensity_properties[modality_index] + + mean = props['mean'] + std = props['std'] + lower_bound = props['percentile_00_5'] + upper_bound = props['percentile_99_5'] + raw = np.clip(raw, lower_bound, upper_bound) + raw = (raw - mean) / max(std, self.tolerance) + return raw From 52d15c8ade00fb60e1c5115c9a4c6cbdca1a4f79 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Tue, 28 Nov 2023 15:13:50 +0100 Subject: [PATCH 2/5] Update raw trafo for expected inputs --- torch_em/transform/nnunet_raw.py | 88 ++++++++++++++++++++++++++++++++ torch_em/transform/raw.py | 79 ---------------------------- 2 files changed, 88 insertions(+), 79 deletions(-) create mode 100644 torch_em/transform/nnunet_raw.py diff --git a/torch_em/transform/nnunet_raw.py b/torch_em/transform/nnunet_raw.py new file mode 100644 index 00000000..721f031a --- /dev/null +++ b/torch_em/transform/nnunet_raw.py @@ -0,0 +1,88 @@ +import json +import numpy as np + + +class nnUNetRawTransformBase: + """nnUNetRawTransformBase is an interface to implement specific raw transforms for nnUNet. + + Adapted from: https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/preprocessing/normalization + """ + def __init__( + self, + plans_file: str, + expected_dtype: type = np.float32, + tolerance: float = 1e-8 + ): + self.expected_dtype = expected_dtype + self.tolerance = tolerance + + self.intensity_properties = self.load_json(plans_file) + self.intensity_properties = self.intensity_properties["foreground_intensity_properties_per_channel"] + + def load_json(self, _file: str): + # credits: `batchgenerators.utilities.file_and_folder_operations` + with open(_file, 'r') as f: + a = json.load(f) + return a + + def __call__( + self, + raw: np.ndarray + ) -> np.ndarray: # the transformed raw inputs + """Returns the raw inputs after applying the pre-processing from nnUNet. + + Args: + raw: The raw array inputs + Expectd a float array of shape M * (H * W * D) (where, M is the number of modalities) + Returns: + The transformed raw inputs (the same shape as inputs) + """ + raise NotImplementedError("It's a class template for raw transforms from nnUNet. \ + Use a child class that implements the expected raw transform instead") + + +class nnUNetCTRawTransform(nnUNetRawTransformBase): + """Apply transformation on the raw inputs for CT + PET channels (adapted from nnUNetv2's `CTNormalization`) + + You can use this class to apply the necessary raw transformations on CT and PET volume channels. + Expectation: The inputs should be of dimension 2 * (H * W * D). + - The first channel should be CT volume + - The second channel should be PET volume + + Here's an example for how to use this class: + ```python + # Initialize the raw transform. + raw_transform = nnUNetCTRawTransform(plans_file="...nnUNetplans.json") + + # Apply transformation on the inputs. + patient_vol = np.concatenate(ct_vol, pet_vol) + patient_transformed = raw_transform(patient_vol) + ``` + """ + def __call__( + self, + raw: np.ndarray + ) -> np.ndarray: + assert raw.shape[0] == 2, "The current expectation is channels (modality) first. The fn currently supports for two modalities, namely CT and PET-CT (in the mentioned order)" + + assert self.intensity_properties is not None, \ + "Intensity properties are required here. Please make sure that you pass the `nnUNetplans.json correctly." + + raw = raw.astype(self.expected_dtype) + + transformed_raw = [] + # intensity properties for the respective modalities + for idx in range(raw.shape[0]): + props = self.intensity_properties[str(idx)] + + mean = props['mean'] + std = props['std'] + lower_bound = props['percentile_00_5'] + upper_bound = props['percentile_99_5'] + + modality = np.clip(raw[idx, ...], lower_bound, upper_bound) + modality = (modality - mean) / max(std, self.tolerance) + transformed_raw.append(modality) + + transformed_raw = np.stack(transformed_raw) + return transformed_raw diff --git a/torch_em/transform/raw.py b/torch_em/transform/raw.py index 9e6c85da..7fcd96db 100644 --- a/torch_em/transform/raw.py +++ b/torch_em/transform/raw.py @@ -1,4 +1,3 @@ -import json import numpy as np import torch @@ -229,81 +228,3 @@ def get_default_mean_teacher_augmentations( augmentation1=aug1, augmentation2=aug2 ) - - -class nnUNetRawTransformBase: - """nnUNetRawTransformBase is an interface to implement specific raw transforms for nnUNet. - - Adapted from: https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/preprocessing/normalization - """ - def __init__( - self, - plans_file: str, - expected_dtype: type = np.float32, - tolerance: float = 1e-8 - ): - self.expected_dtype = expected_dtype - self.tolerance = tolerance - - self.intensity_properties = self.load_json(plans_file) - self.intensity_properties = self.intensity_properties["foreground_intensity_properties_per_channel"] - - def load_json(self, _file: str): - # credits: `batchgenerators.utilities.file_and_folder_operations` - with open(_file, 'r') as f: - a = json.load(f) - return a - - def __call__( - self, - raw: np.ndarray, - modality: str - ) -> np.ndarray: # the transformed raw inputs - """Returns the raw inputs after applying the pre-processing from nnUNet. - - Args: - raw: The raw array inputs - Expectd a float array of shape H * W * C - - Returns: - The transformed raw inputs (the same shape as inputs) - """ - raise NotImplementedError("It's a class template for raw transforms from nnUNet. \ - Use a child class that implements the expected raw transform instead") - - -class nnUNet_CT_RawTransform(nnUNetRawTransformBase): - """Apply transformation on the raw inputs (adapted from nnUNetv2's `CTNormalization`) - - You can use this class to apply the necessary raw transformations on CT and PET volume channels. - - Here's an example for how to use this class: - ```python - # Initialize the raw transform. - raw_transform = nnUNet_CT_RawTransform(plans_file="...nnUNetplans.json") - - # Apply transformation on the inputs. - ct_raw = raw_transform(ct_volume) - pet_raw = raw_transform(pet_volume) - ``` - """ - def __call__( - self, - raw: np.ndarray, - modality_index: str - ) -> np.ndarray: - assert self.intensity_properties is not None, \ - "Intensity properties are required here. Please make sure that you pass the `nnUNetplans.json correctly." - - raw = raw.astype(self.expected_dtype) - - # intensity properties for the respective modality - props = self.intensity_properties[modality_index] - - mean = props['mean'] - std = props['std'] - lower_bound = props['percentile_00_5'] - upper_bound = props['percentile_99_5'] - raw = np.clip(raw, lower_bound, upper_bound) - raw = (raw - mean) / max(std, self.tolerance) - return raw From dab5e3d6b215357ec05e194a9157bb7ba221ca2c Mon Sep 17 00:00:00 2001 From: anwai98 Date: Tue, 28 Nov 2023 15:14:42 +0100 Subject: [PATCH 3/5] Restore raw.py --- torch_em/transform/raw.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torch_em/transform/raw.py b/torch_em/transform/raw.py index 7fcd96db..cd43c379 100644 --- a/torch_em/transform/raw.py +++ b/torch_em/transform/raw.py @@ -1,5 +1,4 @@ import numpy as np - import torch from torchvision import transforms From c503cb8e919b6929ba944aaf78e14a04ceae4f00 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Tue, 28 Nov 2023 18:38:06 +0100 Subject: [PATCH 4/5] Update nnunet raw trafo - single class --- torch_em/transform/nnunet_raw.py | 102 +++++++++++++++---------------- 1 file changed, 50 insertions(+), 52 deletions(-) diff --git a/torch_em/transform/nnunet_raw.py b/torch_em/transform/nnunet_raw.py index 721f031a..ecdae1a4 100644 --- a/torch_em/transform/nnunet_raw.py +++ b/torch_em/transform/nnunet_raw.py @@ -2,29 +2,57 @@ import numpy as np -class nnUNetRawTransformBase: - """nnUNetRawTransformBase is an interface to implement specific raw transforms for nnUNet. +class nnUNetRawTransform: + """Apply transformation on the raw inputs. + Adapted from nnUNetv2's `ImageNormalization`: + - https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/preprocessing/normalization - Adapted from: https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/preprocessing/normalization + You can use this class to apply the necessary raw transformations on input modalities. + + (Current Support - CT and PET): The inputs should be of dimension 2 * (H * W * D). + - The first channel should be CT volume + - The second channel should be PET volume + + Here's an example for how to use this class: + ```python + # Initialize the raw transform. + raw_transform = nnUNetRawTransform(plans_file=".../nnUNetPlans.json") + + # Apply transformation on the inputs. + patient_vol = np.concatenate(ct_vol, pet_vol) + patient_transformed = raw_transform(patient_vol) + ``` """ def __init__( self, plans_file: str, expected_dtype: type = np.float32, - tolerance: float = 1e-8 + tolerance: float = 1e-8, + model_name: str = "3d_fullres" ): self.expected_dtype = expected_dtype self.tolerance = tolerance - self.intensity_properties = self.load_json(plans_file) - self.intensity_properties = self.intensity_properties["foreground_intensity_properties_per_channel"] + json_file = self.load_json(plans_file) + self.intensity_properties = json_file["foreground_intensity_properties_per_channel"] + self.per_channel_scheme = json_file["configurations"][model_name]["normalization_schemes"] def load_json(self, _file: str): - # credits: `batchgenerators.utilities.file_and_folder_operations` + # source: `batchgenerators.utilities.file_and_folder_operations` with open(_file, 'r') as f: a = json.load(f) return a + def ct_transform(self, channel, properties): + mean = properties['mean'] + std = properties['std'] + lower_bound = properties['percentile_00_5'] + upper_bound = properties['percentile_99_5'] + + transformed_channel = np.clip(channel, lower_bound, upper_bound) + transformed_channel = (transformed_channel - mean) / max(std, self.tolerance) + return transformed_channel + def __call__( self, raw: np.ndarray @@ -37,52 +65,22 @@ def __call__( Returns: The transformed raw inputs (the same shape as inputs) """ - raise NotImplementedError("It's a class template for raw transforms from nnUNet. \ - Use a child class that implements the expected raw transform instead") - - -class nnUNetCTRawTransform(nnUNetRawTransformBase): - """Apply transformation on the raw inputs for CT + PET channels (adapted from nnUNetv2's `CTNormalization`) - - You can use this class to apply the necessary raw transformations on CT and PET volume channels. - Expectation: The inputs should be of dimension 2 * (H * W * D). - - The first channel should be CT volume - - The second channel should be PET volume - - Here's an example for how to use this class: - ```python - # Initialize the raw transform. - raw_transform = nnUNetCTRawTransform(plans_file="...nnUNetplans.json") - - # Apply transformation on the inputs. - patient_vol = np.concatenate(ct_vol, pet_vol) - patient_transformed = raw_transform(patient_vol) - ``` - """ - def __call__( - self, - raw: np.ndarray - ) -> np.ndarray: - assert raw.shape[0] == 2, "The current expectation is channels (modality) first. The fn currently supports for two modalities, namely CT and PET-CT (in the mentioned order)" - - assert self.intensity_properties is not None, \ - "Intensity properties are required here. Please make sure that you pass the `nnUNetplans.json correctly." - - raw = raw.astype(self.expected_dtype) + assert raw.shape[0] == len(self.per_channel_scheme), "Number of channels & transforms from data plan must match" - transformed_raw = [] - # intensity properties for the respective modalities - for idx in range(raw.shape[0]): - props = self.intensity_properties[str(idx)] + normalized_channels = [] + for idxx, (channel_transform, channel) in enumerate(zip(self.per_channel_scheme, raw)): + properties = self.intensity_properties[str(idxx)] - mean = props['mean'] - std = props['std'] - lower_bound = props['percentile_00_5'] - upper_bound = props['percentile_99_5'] + # get the correct transformation function, this can for example be a method of this class + if channel_transform == "CTNormalization": + channel = self.ct_transform(channel, properties) + elif channel_transform in [ + "ZScoreNormalization", "NoNormalization", "RescaleTo01Normalization", "RGBTo01Normalization" + ]: + raise NotImplementedError(f"{channel_transform} is not supported by nnUNetRawTransform yet.") + else: + raise ValueError(f"Transform is not known: {channel_transform}.") - modality = np.clip(raw[idx, ...], lower_bound, upper_bound) - modality = (modality - mean) / max(std, self.tolerance) - transformed_raw.append(modality) + normalized_channels.append(channel) - transformed_raw = np.stack(transformed_raw) - return transformed_raw + return np.stack(normalized_channels) From 83a908605ac4b83f17197dad5c0461b3e824525f Mon Sep 17 00:00:00 2001 From: anwai98 Date: Tue, 28 Nov 2023 18:47:01 +0100 Subject: [PATCH 5/5] Fix casting inputs --- torch_em/transform/nnunet_raw.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch_em/transform/nnunet_raw.py b/torch_em/transform/nnunet_raw.py index ecdae1a4..eaf60b06 100644 --- a/torch_em/transform/nnunet_raw.py +++ b/torch_em/transform/nnunet_raw.py @@ -67,6 +67,8 @@ def __call__( """ assert raw.shape[0] == len(self.per_channel_scheme), "Number of channels & transforms from data plan must match" + raw = raw.astype(self.expected_dtype) + normalized_channels = [] for idxx, (channel_transform, channel) in enumerate(zip(self.per_channel_scheme, raw)): properties = self.intensity_properties[str(idxx)]