In [5]:
import SimpleITK as sitk
import numpy as np
import nibabel as nib
from scipy.ndimage import affine_transform
from scipy import ndimage
from pathlib import Path

# Input label volumes for the single-patient experiment:
#   fixed  -> anchor volume that is assumed to be correctly aligned
#   moving -> distorted volume whose label is realigned onto the anchor
fixed_image_path, moving_image_path = (
    "../data/segthor_train/train/Patient_27/GT2.nii.gz",
    "../data/segthor_train/train/Patient_27/GT.nii.gz",
)

# Label value that is extracted as a binary mask, transformed and written back.
replace_index = 2

In [124]:
def _binary_label_mask(image, label):
    """
    Return a float32 SimpleITK image that is 1 where ``image == label`` and 0 elsewhere.

    The mask is rebuilt with ``sitk.GetImageFromArray``, so it does NOT inherit the
    spacing, origin or direction of ``image`` (SimpleITK defaults are used instead).
    """
    voxels = sitk.GetArrayFromImage(image)
    return sitk.GetImageFromArray((voxels == label).astype(np.float32))


def register_images(
    fixed_img_file, moving_img_file, output_transform_file, replace_index=None
):
    """
    Affinely register a moving image onto a fixed image and save the transform.

    Args:
        fixed_img_file: Path to the fixed (anchor) image, read as float32.
        moving_img_file: Path to the moving image, read as float32.
        output_transform_file: Destination path passed to ``sitk.WriteTransform``.
        replace_index: Optional label value. When truthy, both images are reduced
            to binary float32 masks (``image == replace_index``) before registering.
            Note that ``0`` is treated the same as ``None`` (no masking).

    Returns:
        The transform returned by ``ImageRegistrationMethod.Execute``. Following
        ITK's convention, it maps points of the fixed image space into the
        moving image space.

    Note:
        When ``replace_index`` is set, the masks carry default geometry (see
        ``_binary_label_mask``). The registration, and the saved transform, are
        therefore expressed in voxel-index units rather than in the physical
        coordinates stored in the files.
    """
    fixed_img = sitk.ReadImage(fixed_img_file, sitk.sitkFloat32)
    moving_img = sitk.ReadImage(moving_img_file, sitk.sitkFloat32)

    if replace_index:
        fixed_img = _binary_label_mask(fixed_img, replace_index)
        moving_img = _binary_label_mask(moving_img, replace_index)

    registration = sitk.ImageRegistrationMethod()
    registration.SetMetricAsMattesMutualInformation()
    # Nearest-neighbour sampling keeps moving-image values binary for mask inputs.
    registration.SetInterpolator(sitk.sitkNearestNeighbor)
    registration.SetOptimizerAsGradientDescent(
        learningRate=2,
        numberOfIterations=200,
        convergenceMinimumValue=1e-8,
        convergenceWindowSize=1,
    )
    # Per-parameter optimizer scales are estimated from physical shift (set once).
    registration.SetOptimizerScalesFromPhysicalShift()
    # Start from an affine transform initialised to align the geometric centres
    # of the two images.
    registration.SetInitialTransform(
        sitk.CenteredTransformInitializer(
            fixed_img,
            moving_img,
            sitk.AffineTransform(3),
            sitk.CenteredTransformInitializerFilter.GEOMETRY,
        )
    )
    # Optional settings that are currently disabled:
    #   registration.SetMetricSamplingStrategy(registration.RANDOM)
    #   registration.SetMetricSamplingPercentage(0.02)
    #   registration.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1])
    #   registration.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0])
    #   registration.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    transform = registration.Execute(fixed_img, moving_img)

    sitk.WriteTransform(transform, output_transform_file)
    return transform

def apply_affine_transform(image: np.ndarray, transform):
    """
    Resample a NumPy volume through a SimpleITK transform.

    The array is wrapped with ``sitk.GetImageFromArray``, which uses default
    geometry, so coordinates are voxel indices. It is resampled onto its own grid
    with nearest-neighbour interpolation, using 0.0 for points that fall outside
    the input, and returned as a NumPy array.

    NOTE: the next notebook cell defines another ``apply_affine_transform``
    (SciPy-based) with the same name. Whichever cell executes last wins.
    """
    sitk_volume = sitk.GetImageFromArray(image)
    resampled_volume = sitk.Resample(
        sitk_volume, transform, sitk.sitkNearestNeighbor, 0.0
    )
    return sitk.GetArrayFromImage(resampled_volume)

In [119]:
# Step 2: Apply affine transformation using SciPy
# Affine parameters in ITK order: 3x3 matrix (row-major), then the translation.
# Hardcoded from a previously saved transform file. The matrix part is the
# identity, so only the translation has an effect.
_DEFAULT_TRANSFORM_PARAMETERS = (
    1, 0, 0,
    0, 1, 0,
    0, 0, 1,
    -66.58184178354128,
    -38.56866728429671,
    -14.999585668903961,
)
# Hardcoded voxel offset used to move the rotation pivot away from the array centre.
_DEFAULT_PIVOT_SHIFT = (-30.68731694, 21.74786194, 20.2083591)
# Rotation angle (degrees) applied in the plane of array axes 0 and 1.
_DEFAULT_ROTATION_DEG = 27


def apply_affine_transform(
    moving_image,
    transform_map=None,
    transform_parameters=_DEFAULT_TRANSFORM_PARAMETERS,
    pivot_shift=_DEFAULT_PIVOT_SHIFT,
    rotation_deg=_DEFAULT_ROTATION_DEG,
):
    """
    Apply the hardcoded affine + rotation correction to a 3-D volume with SciPy.

    Steps (all with nearest-neighbour interpolation, ``order=0``, so label values
    are preserved; voxels sampled from outside the array become 0):
      1. ``scipy.ndimage.affine_transform`` with the 3x3 matrix and translation
         from ``transform_parameters``. SciPy's pull convention applies:
         output voxel ``o`` samples the input at ``matrix @ o + translation``.
      2. Shift by ``pivot_shift``, rotate by ``rotation_deg`` in the plane of
         axes (0, 1) about the array centre, then shift back. The net effect is
         a rotation about ``centre - pivot_shift``, except that content pushed
         out of bounds by the intermediate shift is lost.

    Args:
        moving_image: Array-like 3-D volume (e.g. a binary mask).
        transform_map: Ignored. Accepted so calls shaped like the SimpleITK
            variant ``apply_affine_transform(image, transform)`` still work.
        transform_parameters: 12 numbers in ITK affine order (9 matrix entries,
            row-major, followed by 3 translation components).
        pivot_shift: 3 voxel offsets used to relocate the rotation pivot.
        rotation_deg: Rotation angle in degrees.

    Returns:
        np.ndarray with the same shape (and dtype) as ``np.array(moving_image)``.

    Raises:
        ValueError: If ``transform_parameters`` does not contain 12 values.
    """
    moving_image_data = np.array(moving_image)

    params = tuple(transform_parameters)
    if len(params) != 12:
        raise ValueError(
            f"transform_parameters must have 12 values, got {len(params)}"
        )
    rotation_scale = np.array(params[:9], dtype=float).reshape(3, 3)
    translation = np.array(params[9:], dtype=float)

    transformed_image_data = affine_transform(
        moving_image_data, rotation_scale, offset=translation, order=0
    )

    # Rotate about an off-centre pivot: shift, rotate about the centre, shift back.
    shift = np.asarray(pivot_shift, dtype=float)
    shifted_volume = ndimage.shift(transformed_image_data, shift, order=0)
    # NOTE(review): for a nibabel-loaded array, axes (0, 1) are the first two voxel
    # axes; whether that is the anatomical axial plane depends on orientation.
    rotated_volume = ndimage.rotate(
        shifted_volume, rotation_deg, axes=(0, 1), reshape=False, order=0
    )
    return ndimage.shift(rotated_volume, -shift, order=0)


# Step 3: Save the result using Nibabel
def save_image_with_nibabel(
    transformed_image_data, reference_image_path, output_image_path, replace_index=None
):
    """
    Write a volume to NIfTI, reusing the affine and header of a reference file.

    Args:
        transformed_image_data: Array to save. When ``replace_index`` is truthy it
            is treated as a mask: voxels equal to 1 mark the new location of the
            ``replace_index`` label.
        reference_image_path: NIfTI file whose affine and header are reused. When
            ``replace_index`` is truthy, its data also supplies every other label.
        output_image_path: Destination path passed to ``nib.save``.
        replace_index: Optional label value to relocate. ``None`` or ``0`` saves
            ``transformed_image_data`` as-is.

    The saved array is cast to ``np.uint8``.
    """
    reference = nib.load(reference_image_path)

    data_to_save = transformed_image_data
    if replace_index:
        merged_labels = reference.get_fdata()
        # Erase the label from its old position, then paint it where the mask is 1.
        merged_labels[merged_labels == replace_index] = 0
        merged_labels[transformed_image_data == 1] = replace_index
        data_to_save = merged_labels

    nib.save(
        nib.Nifti1Image(
            data_to_save.astype(np.uint8), reference.affine, reference.header
        ),
        output_image_path,
    )

In [125]:
# Binary float32 masks of the `replace_index` label in the anchor (fixed) and
# distorted (moving) volumes, loaded in that order.
fixed_image, moving_image = (
    (nib.load(path).get_fdata() == replace_index).astype(np.float32)
    for path in (fixed_image_path, moving_image_path)
)

output_image_path = "transformed_image.nii.gz"

# Registration is disabled. With transform_params = None, this cell relies on the
# SciPy-based apply_affine_transform (the later definition in document order),
# which ignores its second argument and uses hardcoded parameters.
# To re-derive a transform instead:
#   transform_params = register_images(fixed_image_path, moving_image_path, "transform.tfm")
#   transform_params = sitk.ReadTransform("transform.tfm")
transform_params = None

print("Applying affine transformation...")
transformed_image_data = apply_affine_transform(moving_image, transform_params)

print("Saving the transformed image...")
save_image_with_nibabel(
    transformed_image_data.round(),
    moving_image_path,
    output_image_path,
    replace_index,
)

# Fraction of the fixed (anchor) mask covered by the transformed mask.
print(
    f"Final overlap ratio: {(transformed_image_data * fixed_image).sum() / fixed_image.sum():.4f}"
)

Registering images...


In [None]:
# Apply the correction to every GT.nii.gz found under ../data.
#
# WARNING: with OVERWRITE_GROUND_TRUTH = True (the original behaviour), each
# GT.nii.gz is overwritten in place. This cell is therefore NOT idempotent:
# re-running it transforms already-corrected files a second time.
# Set it to False to write "transformed.nii.gz" next to each GT file instead.
OVERWRITE_GROUND_TRUTH = True

# NOTE: `transform_params` is defined in the previous cell (None there).
# The file list is materialized and sorted up front, so files written inside the
# loop cannot interfere with the directory walk and the order is deterministic.
for img in sorted(Path.cwd().parent.rglob("data/**/GT.nii.gz")):
    moving_image = nib.load(img).get_fdata()
    if replace_index:
        moving_image = (moving_image == replace_index).astype(np.float32)

    print("Applying affine transformation...")
    transformed_image_data = apply_affine_transform(moving_image, transform_params)
    print("Saving the transformed image...")
    output_image_path = (
        img if OVERWRITE_GROUND_TRUTH else img.with_name("transformed.nii.gz")
    )
    save_image_with_nibabel(
        transformed_image_data.round(),
        img,
        output_image_path,
        replace_index=replace_index,
    )
    print(f"Saved to {output_image_path}")
