In [7]:
import numpy as np
import scipy.ndimage as ndimage
from scipy import optimize
import matplotlib.pyplot as plt
import cv2

In [8]:
def mse(arr1, arr2):
    """
    Compute the mean squared error between two arrays.

    Args:
    arr1: First input array.
    arr2: Second input array.

    Returns:
    Mean squared error.
    """
    # Ensure the arrays are of the same shape
    x = min(arr1.shape[0], arr2.shape[0])
    y = min(arr1.shape[1], arr2.shape[1])

    arr1_cropped = arr1[:x, :y]
    arr2_cropped = arr2[:x, :y]

    # Compute the mean squared error
    mse_value = np.mean((arr1_cropped - arr2_cropped) ** 2)

    return mse_value


In [9]:
def apply_rotation(angle, img):
    # Rotate the image around its center
    rotated_img = ndimage.rotate(img, angle, reshape=False, order=1, mode='constant', cval=0.0)
    return rotated_img


In [10]:
def cost_mse(param, reference_image, target_image):
    transformed = apply_rotation(param, target_image)
    cost = mse(reference_image, transformed)  # We want to minimize MSE
    print(f"Param: {param}, Cost: {cost}")
    return cost


In [11]:
# Load images using OpenCV
static = cv2.imread('brain.png', cv2.IMREAD_GRAYSCALE)
moving = cv2.imread('distorted.png', cv2.IMREAD_GRAYSCALE)

# Ensure the images are loaded properly
if static is None or moving is None:
    raise ValueError("One or both images could not be loaded.")

In [12]:
initial_angle = np.array([0])  # Initial guess for the rotation angle in degrees

def trying_params(params):
    """ Callback function """
    print(f"Trying params: {params}")

best_angle = optimize.fmin_powell(cost_mse, initial_angle, args=(static, moving), callback=trying_params)
print(f"Best angle: {best_angle}")


ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 2 is different from 1)

In [None]:
# Apply the best rotation angle to the moving image
transformed_image = apply_rotation(best_angle, moving)

# Function to plot three images side by side
def plot_three_images(img1, img2, img3, title1='Static Image', title2='Original Moving Image', title3='Transformed Image'):
    plt.figure(figsize=(18, 6))
    plt.subplot(1, 3, 1)
    plt.imshow(img1, cmap='gray')
    plt.title(title1)
    plt.axis('off')
    plt.subplot(1, 3, 2)
    plt.imshow(img2, cmap='gray')
    plt.title(title2)
    plt.axis('off')
    plt.subplot(1, 3, 3)
    plt.imshow(img3, cmap='gray')
    plt.title(title3)
    plt.axis('off')
    plt.show()

# Plot the static, original moving, and transformed images
plot_three_images(static, moving, transformed_image, title1='Static Image', title2='Original Moving Image', title3='Transformed Image')