# Aligning

In [None]:
"""Image orientation alignment based on keypoint distribution."""

import io
from typing import List, Optional, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from sklearn.decomposition import PCA


class OrientationAligner:
    """
    Class for aligning image orientation based on keypoint distribution.

    This class detects keypoints and rotates the image so that the principal
    axis of keypoint distribution aligns with the X-axis (horizontal).

    Images are expected in OpenCV channel order (grayscale, BGR or BGRA).
    All angles are expressed in degrees in image coordinates (x to the right,
    y downwards) and are folded into (-90, 90], because an axis has no
    direction: rotating by ``a`` or ``a + 180`` aligns it equally well, and the
    smaller rotation avoids turning the image upside-down.
    """

    # Orientation estimators accepted by ``align_image_orientation``.
    _METHODS = ("pca", "moments", "bbox")

    def __init__(self, detector_type: str = "SIFT", max_features: int = 5000):
        """
        Initialize the orientation aligner.

        Args:
            detector_type (str): Feature detector type ("SIFT", "ORB", "AKAZE").
            max_features (int): Maximum number of features to detect.

        Raises:
            ValueError: If ``detector_type`` is not one of the supported names.
        """
        self.detector_type = detector_type
        self.max_features = max_features
        self.detector = self._create_detector()

    def _create_detector(self):
        """Create feature detector based on specified type."""
        if self.detector_type == "SIFT":
            return cv2.SIFT_create(nfeatures=self.max_features)
        elif self.detector_type == "ORB":
            return cv2.ORB_create(nfeatures=self.max_features)
        elif self.detector_type == "AKAZE":
            # AKAZE has no feature-count parameter; the cap is applied in
            # ``detect_keypoints`` instead.
            return cv2.AKAZE_create()
        else:
            raise ValueError(f"Unsupported detector type: {self.detector_type}")

    @staticmethod
    def _normalize_axis_angle(angle_deg: float) -> float:
        """Fold an axis angle in degrees into the interval (-90, 90]."""
        wrapped = (float(angle_deg) + 90.0) % 180.0 - 90.0
        # The modulo yields [-90, 90); move the lower bound to the upper one.
        return 90.0 if wrapped == -90.0 else wrapped

    @staticmethod
    def _fallback_centroid(points: np.ndarray) -> np.ndarray:
        """Return the mean of ``points``, or the origin when there are none."""
        if len(points) > 0:
            return np.mean(points, axis=0)
        return np.zeros(2)

    @staticmethod
    def _to_gray(image: np.ndarray) -> np.ndarray:
        """Convert a grayscale/BGR/BGRA image to a single-channel image."""
        if image.ndim != 3:
            return image
        channels = image.shape[2]
        if channels == 1:
            # (H, W, 1) is already grayscale; cvtColor would reject it.
            return np.ascontiguousarray(image[:, :, 0])
        if channels == 4:
            return cv2.cvtColor(image, cv2.COLOR_BGRA2GRAY)
        return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    @staticmethod
    def _to_rgb(image: np.ndarray) -> np.ndarray:
        """Convert a grayscale/BGR/BGRA image to RGB for Matplotlib."""
        if image.ndim == 2:
            return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        channels = image.shape[2]
        if channels == 1:
            return cv2.cvtColor(np.ascontiguousarray(image[:, :, 0]), cv2.COLOR_GRAY2RGB)
        if channels == 4:
            return cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def detect_keypoints(self, image: np.ndarray) -> Tuple[List, np.ndarray]:
        """
        Detect keypoints in an image.

        Args:
            image (np.ndarray): Input image (grayscale, BGR or BGRA).

        Returns:
            Tuple[List, np.ndarray]: Keypoints and their (x, y) coordinates as
                an array of shape (N, 2); the array is empty with shape (0, 2)
                when nothing is detected.
        """
        gray = self._to_gray(image)
        keypoints = self.detector.detect(gray, None)
        if len(keypoints) == 0:
            return keypoints, np.empty((0, 2))
        # Detectors without a built-in limit (AKAZE) may exceed max_features;
        # keep the strongest responses so the documented cap always holds.
        if self.max_features > 0 and len(keypoints) > self.max_features:
            keypoints = sorted(keypoints, key=lambda kp: kp.response, reverse=True)[: self.max_features]
        coords = np.array([kp.pt for kp in keypoints], dtype=np.float64)
        return keypoints, coords

    def compute_principal_axis_pca(self, points: np.ndarray) -> Tuple[float, np.ndarray, np.ndarray]:
        """
        Compute principal axis of point distribution using PCA.

        Args:
            points (np.ndarray): Array of 2D points with shape (N, 2).

        Returns:
            Tuple[float, np.ndarray, np.ndarray]: Rotation angle in degrees
                (folded into (-90, 90]), principal components as returned by
                PCA (rows are components), and the centroid. With fewer than
                two points the result is ``(0.0, identity, centroid-or-origin)``.
        """
        if len(points) < 2:
            return 0.0, np.eye(2), self._fallback_centroid(points)
        centroid = np.mean(points, axis=0)
        pca = PCA(n_components=2)
        pca.fit(points - centroid)
        # First row is the direction of largest variance. Its sign is
        # arbitrary, hence the fold into (-90, 90] below.
        principal_component = pca.components_[0]
        angle_deg = np.degrees(np.arctan2(principal_component[1], principal_component[0]))
        return self._normalize_axis_angle(angle_deg), pca.components_, centroid

    def compute_principal_axis_moments(self, points: np.ndarray) -> Tuple[float, np.ndarray]:
        """
        Compute principal axis using second-order central moments.

        Args:
            points (np.ndarray): Array of 2D points with shape (N, 2).

        Returns:
            Tuple[float, np.ndarray]: Rotation angle in degrees in (-90, 90]
                and the centroid. With fewer than two points the angle is 0.0.
        """
        if len(points) < 2:
            return 0.0, self._fallback_centroid(points)
        centroid = np.mean(points, axis=0)
        x = points[:, 0] - centroid[0]
        y = points[:, 1] - centroid[1]
        count = len(points)
        mu20 = np.sum(x**2) / count
        mu02 = np.sum(y**2) / count
        mu11 = np.sum(x * y) / count
        # arctan2 copes with mu20 == mu02 on its own: it gives +/-45 degrees
        # for a diagonal cloud (mu11 != 0) and 0 for an isotropic one, so no
        # special case is needed (a special case returning 0 would be wrong
        # for the diagonal cloud).
        angle_rad = 0.5 * np.arctan2(2 * mu11, mu20 - mu02)
        return float(np.degrees(angle_rad)), centroid

    def compute_bounding_box_orientation(self, points: np.ndarray) -> Tuple[float, np.ndarray]:
        """
        Compute orientation based on minimum area bounding box.

        Args:
            points (np.ndarray): Array of 2D points with shape (N, 2).

        Returns:
            Tuple[float, np.ndarray]: Rotation angle in degrees (folded into
                (-90, 90]) and the rectangle center. With fewer than three
                points the angle is 0.0 and the center is the centroid.
        """
        if len(points) < 3:
            return 0.0, self._fallback_centroid(points)
        # float32 keeps sub-pixel keypoint positions (int32 would truncate).
        rect = cv2.minAreaRect(points.astype(np.float32))
        center, (width, height), angle = rect
        # The reported angle belongs to the "width" side; switch to the longer
        # side when that is the "height" one.
        if width < height:
            angle += 90
        return self._normalize_axis_angle(angle), np.array(center)

    def rotate_image(
        self, image: np.ndarray, angle: float, center: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Rotate image by specified angle around center point.

        The output canvas is enlarged to the bounding box of the rotated image
        and the result is centred on it, whichever pivot is used.

        Args:
            image (np.ndarray): Input image.
            angle (float): Rotation angle in degrees (positive = counterclockwise).
            center (Optional[np.ndarray]): Rotation center. If None, uses image center.

        Returns:
            Tuple[np.ndarray, np.ndarray]: Rotated image and the 2x3 affine
                transformation matrix that was applied.
        """
        height, width = image.shape[:2]
        image_center = (width / 2, height / 2)
        if center is None:
            pivot = image_center
        else:
            # Plain Python floats: OpenCV rejects some numpy scalar types.
            pivot = (float(center[0]), float(center[1]))
        rotation_matrix = cv2.getRotationMatrix2D(pivot, float(angle), 1.0)
        # Bounding box of the rotated image.
        cos_angle = abs(rotation_matrix[0, 0])
        sin_angle = abs(rotation_matrix[0, 1])
        new_width = int((height * sin_angle) + (width * cos_angle))
        new_height = int((height * cos_angle) + (width * sin_angle))
        # Shift so that the *image* centre lands on the canvas centre. Using
        # the pivot here would crop the result for any off-centre pivot.
        mapped_center = rotation_matrix @ np.array([image_center[0], image_center[1], 1.0])
        rotation_matrix[0, 2] += (new_width / 2) - mapped_center[0]
        rotation_matrix[1, 2] += (new_height / 2) - mapped_center[1]
        rotated_image = cv2.warpAffine(
            image,
            rotation_matrix,
            (new_width, new_height),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=0,
        )
        return rotated_image, rotation_matrix

    def align_image_orientation(
        self, image: np.ndarray, method: str = "pca", visualize: bool = False
    ) -> Tuple[np.ndarray, float, dict]:
        """
        Align image orientation based on keypoint distribution.

        Args:
            image (np.ndarray): Input image (grayscale, BGR or BGRA).
            method (str): Method for computing orientation ("pca", "moments", "bbox").
            visualize (bool): Whether to create visualization.

        Returns:
            Tuple[np.ndarray, float, dict]: Aligned image, rotation angle, and info dict.
                The info dict holds "keypoints", "coords", "rotation_angle" and
                "method"; when keypoints were found it also holds
                "rotation_matrix", "center", "components" (PCA only) and
                "visualization" (an RGB array, only if ``visualize``).

        Raises:
            ValueError: If ``method`` is not supported.
        """
        # Validate up front so a bad method name is reported even when the
        # image yields no keypoints.
        if method not in self._METHODS:
            raise ValueError(f"Unsupported method: {method}")
        keypoints, coords = self.detect_keypoints(image)
        if len(coords) == 0:
            print("Warning: No keypoints detected")
            return (
                image.copy(),
                0.0,
                {"keypoints": [], "coords": coords, "rotation_angle": 0.0, "method": method},
            )
        print(f"Detected {len(keypoints)} keypoints")
        if method == "pca":
            angle, components, center = self.compute_principal_axis_pca(coords)
            info_extra = {"components": components, "center": center}
        elif method == "moments":
            angle, center = self.compute_principal_axis_moments(coords)
            info_extra = {"center": center}
        else:  # "bbox"
            angle, center = self.compute_bounding_box_orientation(coords)
            info_extra = {"center": center}
        print(f"Computed rotation angle: {angle:.2f} degrees")
        # Rotate about the image centre; the keypoint centre is only reported
        # in the info dict and drawn in the visualization.
        aligned_image, rotation_matrix = self.rotate_image(image, angle, None)
        info = {
            "keypoints": keypoints,
            "coords": coords,
            "rotation_angle": angle,
            "rotation_matrix": rotation_matrix,
            "method": method,
            **info_extra,
        }
        if visualize:
            info["visualization"] = self._create_visualization(image, aligned_image, coords, angle, center, method)
        return aligned_image, angle, info

    def _create_visualization(
        self,
        original: np.ndarray,
        aligned: np.ndarray,
        coords: np.ndarray,
        angle: float,
        center: np.ndarray,
        method: str,
    ) -> np.ndarray:
        """
        Create visualization of the alignment process.

        Draws a 2x2 figure: original image with keypoints, keypoint scatter
        with the principal axes (PCA only), aligned image, and histograms of
        the keypoint coordinates.

        Args:
            original (np.ndarray): Input image (grayscale, BGR or BGRA).
            aligned (np.ndarray): Rotated image in the same channel layout.
            coords (np.ndarray): Keypoint coordinates with shape (N, 2).
            angle (float): Rotation angle in degrees.
            center (np.ndarray): Center of the keypoint distribution (x, y).
            method (str): Orientation method that produced ``angle``.

        Returns:
            np.ndarray: Rendered figure as an RGB image array.
        """
        has_points = len(coords) > 0
        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        try:
            # Original image with keypoints
            ax = axes[0, 0]
            ax.imshow(self._to_rgb(original))
            if has_points:
                ax.scatter(coords[:, 0], coords[:, 1], c="red", s=2, alpha=0.7)
                ax.plot(center[0], center[1], "bo", markersize=8, label="Center")
            ax.set_title("Original Image with Keypoints")
            # legend() warns when nothing is labelled, so only call it if needed.
            if ax.get_legend_handles_labels()[0]:
                ax.legend()
            ax.axis("off")

            # Keypoint distribution
            ax = axes[0, 1]
            if has_points:
                ax.scatter(coords[:, 0], coords[:, 1], c="red", s=10, alpha=0.7)
                ax.plot(center[0], center[1], "bo", markersize=8)
                if method == "pca":
                    # Draw both principal components as arrows from the center.
                    scale = 100
                    arrows = (
                        (angle, "blue", "1st Principal Component"),
                        (angle + 90, "green", "2nd Principal Component"),
                    )
                    for direction_deg, color, label in arrows:
                        theta = np.radians(direction_deg)
                        ax.arrow(
                            center[0],
                            center[1],
                            scale * np.cos(theta),
                            scale * np.sin(theta),
                            head_width=10,
                            head_length=15,
                            fc=color,
                            ec=color,
                            label=label,
                        )
            ax.set_title(f"Keypoint Distribution\nRotation: {angle:.1f}°")
            if ax.get_legend_handles_labels()[0]:
                ax.legend()
            ax.set_aspect("equal")
            # Coordinates are in image space (y grows downwards); flip the axis
            # so this panel has the same orientation as the image panel.
            ax.invert_yaxis()
            ax.grid(True, alpha=0.3)

            # Aligned image
            ax = axes[1, 0]
            ax.imshow(self._to_rgb(aligned))
            ax.set_title("Aligned Image")
            ax.axis("off")

            # Histogram of keypoint coordinates
            if has_points:
                ax = axes[1, 1]
                ax.hist(coords[:, 0], bins=30, alpha=0.7, label="X coordinates", color="red")
                ax.hist(coords[:, 1], bins=30, alpha=0.7, label="Y coordinates", color="blue")
                ax.set_title("Keypoint Coordinate Distribution")
                ax.legend()
                ax.grid(True, alpha=0.3)

            fig.tight_layout()
            # Render through an in-memory PNG: works with any Matplotlib backend.
            with io.BytesIO() as buf:
                fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0.1)
                buf.seek(0)
                # np.array forces PIL to decode while the buffer is still open.
                vis_image = np.array(Image.open(buf).convert("RGB"))
        finally:
            # Always release the figure, even if rendering fails.
            plt.close(fig)
        return vis_image

In [None]:
# Demo cell: align one photograph using SIFT keypoints (at most 100) and PCA.
aligner = OrientationAligner("SIFT", 100)

image = Image.open("/home/akobylin/datasets/lct_2025/7_shernitsa/DSCN0314.JPG").convert("RGB")
# Keep only the central 4/6 of the frame in each dimension (drop a 1/6 border).
w, h = image.size
x0 = w // 6
y0 = h // 6
x1 = x0 + w * 4 // 6
y1 = y0 + h * 4 // 6
image = image.crop((x0, y0, x1, y1))
# PIL delivers RGB, whereas OrientationAligner follows the OpenCV BGR
# convention (it converts with COLOR_BGR2GRAY / COLOR_BGR2RGB). Convert first,
# otherwise red and blue are swapped in the visualization.
image_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
# NOTE: aligned_image is therefore BGR; the visualization array is RGB.
aligned_image, angle, info = aligner.align_image_orientation(image_bgr, method="pca", visualize=True)
Image.fromarray(info["visualization"])