In [7]:
import SimpleITK as sitk
from pathlib import Path
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

### Define all necessary paths

In [2]:
# Project root = parent of the notebook's working directory; datasets live in <root>/data
base_path = Path('.').resolve().parent
data_path = base_path.joinpath('data')

### Generate brain masks

In [None]:
# Build a binary brain mask (<case>_brain_mask.nii.gz) for every case of every partition.
# Any non-zero voxel is treated as brain, which presumes the volumes have a
# zero-valued background (NOTE(review): confirm this holds for every dataset).
for partition in ['train', 'validation', 'test']:
    partition_path = data_path / f'{partition}_set'
    # Only sub-directories are cases; stray files (e.g. .DS_Store) would make ReadImage fail
    case_paths = sorted(p for p in partition_path.iterdir() if p.is_dir())
    for case_path in tqdm(case_paths, desc=f'{partition} masks'):
        vol_path = case_path / f'{case_path.name}.nii.gz'
        # Read in the native pixel type: forcing UInt8 on read casts the intensities,
        # which wraps integers (e.g. 256 -> 0) and truncates floats in (0, 1) to 0,
        # turning genuine brain voxels into background.
        vol = sitk.ReadImage(str(vol_path))

        # 1 where the volume is non-zero, 0 on the background
        mask_array = (sitk.GetArrayFromImage(vol) != 0).astype(np.uint8)

        # Copy origin, spacing and direction so the mask overlays the volume exactly
        brain_mask = sitk.GetImageFromArray(mask_array)
        brain_mask.CopyInformation(vol)
        brain_mask_path = case_path / f'{case_path.name}_brain_mask.nii.gz'
        sitk.WriteImage(brain_mask, str(brain_mask_path))

### Perform N4 bias field correction

In [None]:
# N4 configuration (same knobs as the SimpleITK N4 example):
#   shrink_factor: values > 1 fit the bias field on a down-sampled copy to save time;
#       the fitted field is still evaluated on the full-resolution grid below.
#   number_fitting_levels: only used to size the per-level iteration list when
#       maximum_number_of_iterations is set.
#   maximum_number_of_iterations: None keeps the filter's default iterations.
shrink_factor = 1
number_fitting_levels = 4
maximum_number_of_iterations = None

for partition in ['train', 'validation', 'test']:
    partition_path = data_path / f'{partition}_set'
    # List the partition once and keep only case directories
    case_paths = sorted(p for p in partition_path.iterdir() if p.is_dir())
    for case_path in tqdm(case_paths, desc=f'{partition} N4'):

        # Input and output paths for this case
        vol_path = case_path / f'{case_path.name}.nii.gz'
        mask_path = case_path / f'{case_path.name}_brain_mask.nii.gz'
        n4_path = case_path / f'{case_path.name}_n4.nii.gz'
        # Written only when shrinking; it receives the corrected image at the *shrunk*
        # resolution (NOTE(review): the "_fr" suffix may suggest otherwise — confirm intent)
        n4_path_ = case_path / f'{case_path.name}_n4_fr.nii.gz'
        # Receives the LOG bias field (see GetLogBiasFieldAsImage below), not the field itself
        bias_field_path = case_path / f'{case_path.name}_bias_field.nii.gz'

        # N4 works on real-valued images
        input_image = sitk.ReadImage(str(vol_path), sitk.sitkFloat32)
        image = input_image

        # Use the stored brain mask if present, otherwise derive one with Otsu
        # thresholding as in the SimpleITK N4 example. Without the fallback,
        # `mask_image` is undefined for the first case (NameError) or silently
        # left over from the previous case.
        if mask_path.exists():
            mask_image = sitk.ReadImage(str(mask_path), sitk.sitkUInt8)
        else:
            mask_image = sitk.OtsuThreshold(input_image, 0, 1, 200)

        # Optionally fit on a down-sampled copy of both image and mask
        if shrink_factor > 1:
            image = sitk.Shrink(input_image, [shrink_factor] * input_image.GetDimension())
            mask_image = sitk.Shrink(mask_image, [shrink_factor] * input_image.GetDimension())

        corrector = sitk.N4BiasFieldCorrectionImageFilter()

        # The iteration list length sets the number of fitting levels
        if maximum_number_of_iterations is not None:
            corrector.SetMaximumNumberOfIterations(
                [int(maximum_number_of_iterations)] * number_fitting_levels)

        # Fit N4, then evaluate the log bias field on the full-resolution grid so the
        # correction applies to the original image even when the fit was done shrunk
        corrected_image = corrector.Execute(image, mask_image)
        log_bias_field = corrector.GetLogBiasFieldAsImage(input_image)
        corrected_image_full_resolution = input_image / sitk.Exp(log_bias_field)

        # Store the results
        sitk.WriteImage(corrected_image_full_resolution, str(n4_path))
        sitk.WriteImage(log_bias_field, str(bias_field_path))
        if shrink_factor > 1:
            sitk.WriteImage(corrected_image, str(n4_path_))