In [None]:
import numpy as np
import matplotlib.pyplot as plt
from dipy.align.imaffine import AffineRegistration, MutualInformationMetric
from dipy.align.transforms import TranslationTransform2D, RigidTransform2D, AffineTransform2D
from dipy.align.metrics import SSDMetric
import imageio

In [None]:
# Input image paths; change these to register a different image pair
STATIC_PATH = 'brain.png'
MOVING_PATH = 'distorted.png'

# Load the static (fixed) image as single-channel ('L') data, converted to float32
static_image = imageio.imread(STATIC_PATH, mode='L').astype(np.float32)

# Load the moving image that will be aligned to the static image
moving_image = imageio.imread(MOVING_PATH, mode='L').astype(np.float32)

# Fail fast if the two images are not the same size; report both shapes so the
# mismatch can be diagnosed without re-inspecting the files.
if static_image.shape != moving_image.shape:
    raise ValueError(
        "The static and moving images must have the same dimensions, "
        f"got {static_image.shape} and {moving_image.shape}."
    )

In [None]:
def mse(arr1, arr2):
    """
    Compute the mean squared error between two arrays.

    If the arrays differ in shape, both are cropped to their common
    overlapping region (the smaller extent along every axis, starting at
    index 0) before they are compared.

    Args:
    arr1: First input array.
    arr2: Second input array. Must have the same number of dimensions as arr1.

    Returns:
    Mean squared error as a NumPy float64 scalar.

    Raises:
    ValueError: If the arrays have a different number of dimensions.
    """
    arr1 = np.asarray(arr1)
    arr2 = np.asarray(arr2)
    if arr1.ndim != arr2.ndim:
        raise ValueError(
            f"Arrays must have the same number of dimensions, got {arr1.ndim} and {arr2.ndim}."
        )

    # Crop both arrays to the overlapping region along every axis (not just the first two).
    common = tuple(slice(0, min(a, b)) for a, b in zip(arr1.shape, arr2.shape))

    # Promote to float64 before subtracting: with integer dtypes such as uint8,
    # subtraction wraps around (e.g. 0 - 255 == 1) and the MSE would be wrong.
    diff = arr1[common].astype(np.float64) - arr2[common].astype(np.float64)
    return np.mean(diff ** 2)

In [None]:
def plot_three_images(img1, img2, img3, title1='Static Image', title2='Original Moving Image', title3='Transformed Image'):
    """
    Display three images side by side in a single 1x3 grayscale figure.

    Args:
    img1, img2, img3: Images to show, from left to right.
    title1, title2, title3: Titles for the corresponding panels.
    """
    panels = ((img1, title1), (img2, title2), (img3, title3))

    plt.figure(figsize=(18, 6))
    for position, (image, caption) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(image, cmap='gray')
        plt.title(caption)
        plt.axis('off')  # pixel coordinates add no information here
    plt.show()

In [None]:
# Define the registration parameters
level_iters = [1000, 100, 10]  # Number of iterations at each resolution level (coarse to fine)
sigmas = [3.0, 1.0, 0.0]  # Gaussian smoothing at each level
factors = [4, 2, 1]  # Pyramid decimation factors

# AffineRegistration needs a metric from dipy.align.imaffine. The original
# dipy.align.metrics.SSDMetric is a diffeomorphic (SyN) metric with a different
# interface, so AffineRegistration.optimize cannot use it.
metric = MutualInformationMetric(nbins=32, sampling_proportion=None)  # None = use all pixels
affreg = AffineRegistration(metric=metric, level_iters=level_iters, sigmas=sigmas, factors=factors)


In [None]:
# Stage 1: translation only. This is a cheap, coarse alignment that gives the
# later stages a better starting point than the identity.
translation = TranslationTransform2D()
translation_map = affreg.optimize(static_image, moving_image, translation, params0=None)

# Stage 2: rigid (rotation + translation), seeded with the translation estimate
rigid = RigidTransform2D()
rigid_map = affreg.optimize(static_image, moving_image, rigid, params0=None,
                            starting_affine=translation_map.affine)

# Stage 3: full affine, seeded with the rigid estimate
affine = AffineTransform2D()
affine_map = affreg.optimize(static_image, moving_image, affine, params0=None,
                             starting_affine=rigid_map.affine)

# Apply the final affine to the moving image
transformed_image = affine_map.transform(moving_image)

In [None]:
# Show the estimated affine matrix so the recovered transform can be inspected
print("Affine Transformation Matrix:")
print(affine_map.affine)

# Similarity to the static image before registration (lower MSE = more similar)
initial_mse = mse(static_image, moving_image)
print(f"Initial Mean Squared Error: {initial_mse}")

# Similarity to the static image after registration; should drop if alignment worked
final_mse = mse(static_image, transformed_image)
print(f"Mean Squared Error after registration: {final_mse}")

In [None]:
# Visual check: fixed image, unregistered input, and registered result side by side
plot_three_images(
    static_image,
    moving_image,
    transformed_image,
    title1='Static Image',
    title2='Original Moving Image',
    title3='Transformed Image',
)