In [1]:
import numpy as np


def create_image(size=(200, 200), pattern="random_shapes"):
    """Create a synthetic RGB image with highly asymmetric patterns."""
    width, height = size
    image = np.zeros((height, width, 3))

    if pattern == "random_shapes":
        num_shapes = np.random.randint(5, 15)
        for _ in range(num_shapes):
            shape_type = np.random.choice(["circle", "rectangle", "triangle"])
            color = np.random.rand(3)
            if shape_type == "circle":
                radius = np.random.randint(5, min(width, height) // 5)
                center_x = np.random.randint(radius, width - radius)
                center_y = np.random.randint(radius, height - radius)
                rr, cc = np.ogrid[:height, :width]
                mask = (rr - center_y) ** 2 + (cc - center_x) ** 2 <= radius**2
                image[mask] = color
            elif shape_type == "rectangle":
                rect_width = np.random.randint(5, width // 5)
                rect_height = np.random.randint(5, height // 5)
                start_x = np.random.randint(0, width - rect_width)
                start_y = np.random.randint(0, height - rect_height)
                image[start_y : start_y + rect_height, start_x : start_x + rect_width] = color
            elif shape_type == "triangle":
                vertices = np.random.randint(0, min(width, height), (3, 2))
                rr, cc = np.ogrid[:height, :width]
                mask = np.zeros((height, width), dtype=bool)
                for i in range(height):
                    for j in range(width):
                        v0 = vertices[0]
                        v1 = vertices[1]
                        v2 = vertices[2]
                        a = 0.5 * (-v1[1] * v2[0] + v0[1] * (-v1[0] + v2[0]) + v0[0] * (v1[1] - v2[1]) + v1[0] * v2[1])
                        sign = -1 if a < 0 else 1
                        s = (v0[1] * v2[0] - v0[0] * v2[1] + (v2[1] - v0[1]) * i + (v0[0] - v2[0]) * j) * sign
                        t = (v0[0] * v1[1] - v0[1] * v1[0] + (v0[1] - v1[1]) * i + (v1[0] - v0[0]) * j) * sign
                        if s >= 0 and t >= 0 and (s + t) <= 2 * a * sign:
                            mask[i, j] = True
                image[mask] = color
    else:
        raise ValueError("Unsupported pattern type. Choose 'random_shapes'.")

    # Add padding to ensure the image is fully visible when rotated
    pad_width = width // 2
    pad_height = height // 2
    padded_image = np.pad(
        image,
        pad_width=((pad_height, pad_height), (pad_width, pad_width), (0, 0)),
        mode="constant",
        constant_values=0,
    )
    return padded_image

In [2]:
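# IMPORTANT background reading: the scikit-image example on recovering rotation with polar transforms and phase cross-correlation: https://scikit-image.org/docs/stable/auto_examples/registration/plot_register_rotation.html#sphx-glr-auto-examples-registration-plot-register-rotation-py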


import numpy as np
import matplotlib.pyplot as plt
from skimage.registration import phase_cross_correlation
from skimage import transform
from skimage.color import rgb2gray


def _estimate_rotation_deg(ref_polar, moving_polar):
    """Return the rotation angle in degrees, in [0, 360), between two polar-warped images.

    Both inputs come from ``transform.warp_polar`` with the angle on axis 0.
    A rotation of the source image is then a circular shift along that axis,
    measured here with real-space phase cross-correlation. Following the
    scikit-image rotation-registration example, the shift is read as a
    counter-clockwise angle.
    """
    shifts, _, _ = phase_cross_correlation(ref_polar, moving_polar, space="real")
    return (shifts[0] / ref_polar.shape[0] * 360) % 360


def process_image(angle, size: tuple[int, int], radius=None, num_angles: int = 720, grayscale: bool = False):
    """Rotate a synthetic image by ``angle`` and recover the rotation via polar phase correlation.

    The expected angle, the recovered angle, and the residual angle after
    de-rotation are printed. The residual is wrapped to [-180, 180) so a small
    negative error reads as e.g. -0.5 rather than 359.5.

    Args:
        angle: Counter-clockwise rotation applied to the synthetic image, in degrees.
        size: ``(width, height)`` of the drawing area passed to ``create_image``.
        radius: Polar-warp radius in pixels; defaults to half the drawing's diagonal.
        num_angles: Number of angular samples. The angular resolution is
            ``360 / num_angles`` degrees.
        grayscale: Convert both images to grayscale before registration.

    Returns:
        ``(img_orig, img_rot, img_rot_fixed, img_error)``. ``img_error`` is the
        absolute difference between the de-rotated and original images,
        amplified 10x and clipped to [0, 1] for visibility.
    """
    w, h = size
    channel_axis = -1

    img_orig = create_image(size, pattern="random_shapes")
    img_rot = transform.rotate(img_orig, angle)

    if grayscale:
        img_orig = rgb2gray(img_orig)
        img_rot = rgb2gray(img_rot)
        channel_axis = None

    if radius is None:
        # Half the drawing's diagonal, so that (with warp_polar's default
        # image-centre origin) the polar disc covers the whole drawing area.
        radius = np.sqrt(w**2 + h**2) / 2

    polar_shape = (num_angles, int(np.ceil(radius)))

    def to_polar(img):
        # Same warp settings for every image so angular shifts are comparable.
        return transform.warp_polar(img, radius=radius, output_shape=polar_shape, channel_axis=channel_axis)

    orig_polar = to_polar(img_orig)
    pred_deg = _estimate_rotation_deg(orig_polar, to_polar(img_rot))

    print(f"Expected value for counterclockwise rotation in degrees: {angle}")
    print(f"Recovered value for counterclockwise rotation: {pred_deg}")

    img_rot_fixed = transform.rotate(img_rot, -pred_deg)
    fixed_deg = _estimate_rotation_deg(orig_polar, to_polar(img_rot_fixed))
    residual_deg = (fixed_deg + 180) % 360 - 180

    img_error = np.clip(np.abs(img_rot_fixed - img_orig) * 10, 0, 1)

    print("Rotation Error after Fix: ", residual_deg)

    return img_orig, img_rot, img_rot_fixed, img_error

In [3]:
def plot_images(img_orig, img_rot, img_rot_fixed, img_pred_error):
    """Plot the original, rotated, fixed, and error images side by side.

    RGB arrays are drawn as-is. 2-D (grayscale) arrays, e.g. from
    ``process_image(..., grayscale=True)``, are drawn with a gray colormap on a
    fixed [0, 1] scale. This avoids false colour and keeps the error panel's
    contrast from being auto-stretched.
    """
    panels = [
        (img_orig, "Original Image"),
        (img_rot, "Rotated Image"),
        (img_rot_fixed, "Fixed Image"),
        (img_pred_error, "Error Image"),
    ]
    fig, axs = plt.subplots(1, len(panels), figsize=(15, 5))
    for ax, (img, title) in zip(axs, panels):
        # cmap/vmin/vmax only matter for 2-D arrays; this notebook's images are floats in [0, 1].
        display_kwargs = {"cmap": "gray", "vmin": 0.0, "vmax": 1.0} if img.ndim == 2 else {}
        ax.imshow(img, **display_kwargs)
        ax.set_title(title)
    plt.show()

In [None]:
N = 10  # number of random trials
SEED = 0  # fixed seed so Restart & Run All reproduces the same images and angles

# create_image and this loop both draw from the global np.random state.
np.random.seed(SEED)

for _ in range(N):
    # Randomize the drawing size and the rotation applied to it.
    width = np.random.randint(100, 300)
    height = np.random.randint(100, 300)
    rot_angle = np.random.uniform(0, 360)

    img_orig, img_rot, img_rot_fixed, img_error = process_image(rot_angle, size=(width, height), grayscale=False)
    plot_images(img_orig, img_rot, img_rot_fixed, img_error)