# In [None]:
import numpy as np

from aspire.homog_utils import vec_unit

from scipy.spatial import KDTree  

def get_cube( sideLen, homog ):
    """ Return a 36x3 array of cube triangle vertices (6 faces x 2 triangles x 3 vertices), transformed by `homog` """
    # NOTE: The cube is centered on the origin before `homog` is applied
    # NOTE: `homog` is applied as a 4x4 homogeneous transform; only the first three output rows are kept
    half = sideLen/2.0
    # Corner signs of every triangle vertex, one face (two triangles) per group
    signs = np.array([
        # Front
        [-1,-1,+1], [+1,-1,+1], [+1,+1,+1],
        [+1,+1,+1], [-1,+1,+1], [-1,-1,+1],
        # Back
        [+1,-1,-1], [-1,-1,-1], [-1,+1,-1],
        [-1,+1,-1], [+1,+1,-1], [+1,-1,-1],
        # Right
        [+1,-1,+1], [+1,-1,-1], [+1,+1,-1],
        [+1,+1,-1], [+1,+1,+1], [+1,-1,+1],
        # Left
        [-1,-1,-1], [-1,-1,+1], [-1,+1,+1],
        [-1,+1,+1], [-1,+1,-1], [-1,-1,-1],
        # Top
        [-1,+1,+1], [+1,+1,+1], [+1,+1,-1],
        [+1,+1,-1], [-1,+1,-1], [-1,+1,+1],
        # Bottom
        [-1,-1,-1], [+1,-1,-1], [+1,-1,+1],
        [+1,-1,+1], [-1,-1,+1], [-1,-1,-1],
    ], dtype = float )
    # Scale the signs to the half side length and append the homogeneous coordinate
    homogVerts = np.hstack( ( signs * half, np.ones( (len( signs ), 1) ) ) )
    xformed    = np.dot( homog, homogVerts.transpose() )
    return xformed[:3,:].transpose()


def get_normals_diameters_and_centers( vertTriples : np.ndarray ):
    """ Return per-vertex normals, "diameters", and centroids, treating every three consecutive rows as one triangle """
    # NOTE: The "diameter" is the longest edge of the triangle, it will be used as a distance heuristic
    # NOTE: Each triangle's values are written to all 3 of its rows in order to make the Numpy operations nice
    # NOTE: Outputs start as ones; only the first 3 columns are overwritten, so any extra column stays at 1.0
    nVerts  = len( vertTriples )
    normals = np.ones( vertTriples.shape )
    centers = np.ones( vertTriples.shape )
    diams   = np.ones( nVerts )
    for bgn in range( 0, nVerts, 3 ):
        end = bgn + 3
        a = vertTriples[bgn  ,:3]
        b = vertTriples[bgn+1,:3]
        c = vertTriples[bgn+2,:3]
        # Edges, walked head-to-tail around the triangle
        eAB = b - a
        eBC = c - b
        eCA = a - c
        longest = np.max( [np.linalg.norm( eAB ), np.linalg.norm( eBC ), np.linalg.norm( eCA ),] )
        # Normal direction follows the winding order; `vec_unit` presumably normalizes it to unit length
        unitN  = vec_unit( np.cross( eAB, eBC ) )
        centrd = ( a + b + c ) / 3.0
        normals[ bgn:end, :3 ] = [unitN] * 3
        centers[ bgn:end, :3 ] = [centrd] * 3
        diams[ bgn:end ]       = [longest] * 3
    return normals, diams, centers


def min_dist_to_mesh( q, verts, norms, diams, cntrs ):
    """ Given a list of triangle verts, Return the least (approximate) distance from `q` to the mesh, capped at 1e9

    HACK: We don't actually care if it's very accurate!

    Args:
        q: Query point, broadcastable against one row of `verts`.
        verts: One row per triangle vertex, three consecutive rows per triangle.
        norms: Per-vertex triangle normal, same shape as `verts`.
        diams: Per-vertex triangle "diameter" (one value per row of `verts`).
        cntrs: Per-vertex triangle centroid, same shape as `verts`.
    Returns:
        float: Least heuristic distance over all vertices, or 1e9 for an empty mesh.
    """
    # HACK: This function uses a distance heuristic to determine if the point distance or plane distance is correct
    # HACK: This function does NOT take into account the distance to the triangle edge in the case that it is the least distance
    # NOTE: This function computes some distances 3 times in order to make the Numpy operations nice
    factr = 1.5 # Bubble factor
    dCap  = 1e9 # Returned when the mesh is empty; also the upper bound on the result
    if len( verts ) == 0:
        return dCap
    diffs = np.subtract( verts, q )
    cenDs = np.linalg.norm( np.subtract( cntrs, q ), axis = 1 )
    # All three arrays are 1D (one entry per vertex) so they can be compared and indexed consistently
    plnDs = np.abs( np.sum( norms * diffs, axis = 1 ) ) # https://stackoverflow.com/q/62500584
    pntDs = np.linalg.norm( diffs, axis = 1 ) # https://stackoverflow.com/a/7741976
    # HACK: CHOOSE POINT DISTANCE IF OUTSIDE OF BUBBLE, CHOOSE PLANE DISTANCE IF INSIDE BUBBLE
    outside = cenDs > ( np.asarray( diams ) * factr )
    dists   = np.where( outside, pntDs, plnDs )
    return min( dCap, float( np.min( dists ) ) )


def ransac(points, sideLen, max_iterations=1000, inlier_threshold=0.05, seed=None):
    """
    RANSAC algorithm to estimate the pose of a cube given noisy point cloud data.

    Each iteration pairs 3 random data points with 3 random cube corners, fits a
    rigid transform to that pairing, and scores the pose by how many data points
    lie within `inlier_threshold` of a transformed cube corner.

    NOTE: Scoring is against the 8 cube corners (nearest-neighbour distance), not
    the full cube surface, so the returned pose is a coarse hypothesis at best.

    Args:
        points (np.ndarray): Nx3 array of point cloud data.
        sideLen (float): Length of the cube's sides.
        max_iterations (int): Number of RANSAC iterations.
        inlier_threshold (float): Distance threshold to consider a point as an inlier.
        seed (int, optional): Random seed for reproducibility (seeds the global NumPy RNG).
    Returns:
        best_pose (np.ndarray): 4x4 transformation matrix of the estimated cube pose,
            or None if no sampled pose had any inliers.
    Raises:
        ValueError: If fewer than 3 points are supplied.
    """
    points = np.asarray(points)
    if len(points) < 3:
        raise ValueError("ransac needs at least 3 points to sample a pose hypothesis")

    if seed is not None:
        np.random.seed(seed)

    # get_cube returns 36 triangle vertices (36x3); collapse them to the 8 distinct corners
    canonical_cube_points = np.unique(get_cube(sideLen, np.eye(4)), axis=0)  # 8x3 array

    best_pose = None
    max_inliers = 0

    for _ in range(max_iterations):
        # Randomly sample 3 points from the data
        sample_indices = np.random.choice(len(points), size=3, replace=False)
        sampled_points = points[sample_indices]

        # Colinearity check: the two offsets from the first sample span a plane
        # (rank 2) unless the three points are colinear or coincident (rank < 2)
        if np.linalg.matrix_rank(sampled_points[1:] - sampled_points[0]) < 2:
            continue

        # Corresponding points on the cube model (3 distinct corners)
        model_indices = np.random.choice(len(canonical_cube_points), size=3, replace=False)
        model_points = canonical_cube_points[model_indices]

        # Estimate pose using the Kabsch algorithm
        try:
            R, t = estimate_rigid_transform(model_points, sampled_points)
        except np.linalg.LinAlgError:
            continue  # Skip if SVD fails

        # Construct the transformation matrix
        pose = np.eye(4)
        pose[:3, :3] = R
        pose[:3, 3] = t

        # Transform the canonical cube corners to the estimated pose
        transformed_model_points = (R @ canonical_cube_points.T + t[:, np.newaxis]).T

        # Distance from every data point to its nearest transformed corner
        distances = compute_point_to_model_distances(points, transformed_model_points, sideLen)

        # Count inliers
        inliers = np.sum(distances < inlier_threshold)

        # Update best pose if current pose has more inliers
        if inliers > max_inliers:
            max_inliers = inliers
            best_pose = pose

    return best_pose

def icp(points, sideLen, initial_pose=None, max_iterations=50, convergence_threshold=1e-4,
        max_correspondence_distance=None):
    """
    ICP algorithm to refine the pose of a cube given an initial estimate.

    Model points are sampled on the cube surface, moved by the current pose, and
    matched to their nearest data points; a rigid correction is then fitted to
    the matches and composed onto the pose.

    Args:
        points (np.ndarray): Nx3 array of point cloud data.
        sideLen (float): Length of the cube's sides.
        initial_pose (np.ndarray): 4x4 initial transformation matrix (identity if None).
        max_iterations (int): Maximum number of ICP iterations.
        convergence_threshold (float): Stop once the mean match distance drops below this.
        max_correspondence_distance (float, optional): Matches farther apart than this are
            rejected as outliers. Defaults to ``convergence_threshold * 10``.
    Returns:
        pose (np.ndarray): Refined 4x4 transformation matrix of the cube pose.
    """
    points = np.asarray(points)

    # Initialize the pose
    pose = np.eye(4) if initial_pose is None else np.asarray(initial_pose, dtype=float)

    if max_correspondence_distance is None:
        # Arbitrary placeholder factor kept as the default outlier gate
        max_correspondence_distance = convergence_threshold * 10

    # Generate dense points on the cube surface for better correspondences
    model_points = generate_cube_surface_points(sideLen, num_points=1000)  # Nx3 array

    # The data points never change, so the search tree is built once
    kdtree = KDTree(points)

    for _ in range(max_iterations):
        # Transform model points to the current pose
        transformed_model_points = (pose[:3, :3] @ model_points.T + pose[:3, 3][:, np.newaxis]).T

        # Find closest data point for each model point
        distances, indices = kdtree.query(transformed_model_points)
        closest_data_points = points[indices]

        # Outlier rejection: keep correspondences within the distance gate
        valid_indices = distances < max_correspondence_distance
        if np.count_nonzero(valid_indices) < 3:
            break  # Too few correspondences to determine a rigid transform

        # Select valid correspondences
        src_points = transformed_model_points[valid_indices]
        dst_points = closest_data_points[valid_indices]

        # Estimate transformation between src_points and dst_points
        R, t = estimate_rigid_transform(src_points, dst_points)

        # Compose the correction onto the current pose
        delta_pose = np.eye(4)
        delta_pose[:3, :3] = R
        delta_pose[:3, 3] = t
        pose = delta_pose @ pose

        # Check convergence (error measured before this iteration's correction)
        mean_error = np.mean(distances[valid_indices])
        if mean_error < convergence_threshold:
            break

    return pose

def estimate_rigid_transform(A, B):
    """
    Fit the least-squares rigid motion (Kabsch algorithm) taking A onto B.

    Rows of A and B are treated as corresponding points: row i of A maps to row i of B.

    Args:
        A (np.ndarray): Nx3 array of source points.
        B (np.ndarray): Nx3 array of destination points.
    Returns:
        R (np.ndarray): 3x3 rotation matrix.
        t (np.ndarray): Length-3 translation vector, such that B ~= (R @ A.T).T + t.
    """
    src_center = np.mean(A, axis=0)
    dst_center = np.mean(B, axis=0)

    # Cross-covariance of the centered point sets
    cross_cov = (A - src_center).T @ (B - dst_center)

    U, _, Vt = np.linalg.svd(cross_cov)
    R = Vt.T @ U.T

    # A negative determinant means the fit is a reflection; flip the last
    # right-singular vector to get a proper rotation instead
    if np.linalg.det(R) < 0:
        Vt[-1, :] *= -1
        R = Vt.T @ U.T

    # Translation carries the rotated source centroid onto the destination centroid
    return R, dst_center - R @ src_center

def compute_point_to_model_distances(points, model_points, sideLen):
    """
    Distance from each data point to its nearest model point.

    This is a nearest-neighbour lookup against `model_points` only; it measures
    distance to the cube surface only as well as `model_points` samples it.

    Args:
        points (np.ndarray): Nx3 array of data points.
        model_points (np.ndarray): Mx3 array of transformed model points.
        sideLen (float): Length of the cube's sides (currently unused).
    Returns:
        distances (np.ndarray): Length-N array of nearest-neighbour distances.
    """
    # Index the model once, then query every data point against it
    nearest_dist, _nearest_idx = KDTree(model_points).query(points)
    return nearest_dist

def generate_cube_surface_points(sideLen, num_points=1000):
    """
    Generate random points on the surface of an origin-centered, axis-aligned cube.

    Points are drawn uniformly on each face with the global NumPy RNG. The
    requested total is split as evenly as possible over the six faces; when it
    is not a multiple of 6, the first ``num_points % 6`` faces get one extra point.

    Args:
        sideLen (float): Length of the cube's sides.
        num_points (int): Total number of points to generate.
    Returns:
        points (np.ndarray): num_points x 3 array of points on the cube surface,
            grouped by face (-X, +X, -Y, +Y, -Z, +Z).
    Raises:
        ValueError: If `num_points` is negative.
    """
    if num_points < 0:
        raise ValueError("num_points must be non-negative")

    hl = sideLen / 2.0
    per_face, leftover = divmod(num_points, 6)

    face_points = []
    face_number = 0
    for axis in range(3):
        other_axes = [i for i in range(3) if i != axis]
        for sign in (-1, 1):
            # Hand out the remainder one point at a time so the total is exact
            count = per_face + (1 if face_number < leftover else 0)
            face_number += 1

            coords = np.random.uniform(-hl, hl, size=(count, 2))
            face = np.zeros((count, 3))
            face[:, axis] = sign * hl       # Pin the face's own axis to the cube boundary
            face[:, other_axes] = coords    # Free coordinates span the face
            face_points.append(face)
    return np.vstack(face_points)

def estimate_cube_pose(points, sideLen, method="RANSAC", **kwargs):
    """
    Dispatch cube pose estimation to the requested solver.

    Args:
        points (np.ndarray): Nx3 array of point cloud data.
        sideLen (float): Length of the cube's sides.
        method (str): Estimation method, 'RANSAC' or 'ICP' (case-insensitive).
        kwargs: Extra keyword arguments forwarded unchanged to the chosen solver.
    Returns:
        pose (np.ndarray): Whatever the chosen solver returns (a 4x4 pose matrix).
    Raises:
        ValueError: If `method` names neither solver.
    """
    selected = method.upper()
    if selected not in ("RANSAC", "ICP"):
        raise ValueError(f"Unknown method '{method}'. Use 'RANSAC' or 'ICP'.")

    # Resolve the solver only after validation so a bad name fails fast
    solver = ransac if selected == "RANSAC" else icp
    return solver(points, sideLen, **kwargs)

# Demo: estimate the pose of a synthetic noisy cube with RANSAC, then refine it with ICP
if __name__ == "__main__":
    edge_len = 1.0

    # Ground-truth pose: a quarter turn about Z followed by a translation
    truth = np.eye(4)
    truth[:3, :3] = np.array([
        [0, -1, 0],
        [1,  0, 0],
        [0,  0, 1]
    ])
    truth[:3, 3] = [0.5, 0.5, 0.5]
    rot, trans = truth[:3, :3], truth[:3, 3]

    # Sample the cube surface, move the samples to the true pose, then add Gaussian noise
    surface = generate_cube_surface_points(edge_len, num_points=2000)
    surface = (rot @ surface.T + trans[:, np.newaxis]).T
    cloud = surface + np.random.normal(0, 0.01, surface.shape)

    # Coarse estimate from RANSAC
    coarse = estimate_cube_pose(cloud, edge_len, method="RANSAC", max_iterations=1000, inlier_threshold=0.05)

    # ICP refinement, seeded with the RANSAC result
    refined = estimate_cube_pose(cloud, edge_len, method="ICP", initial_pose=coarse, max_iterations=50)

    for title, matrix in (
        ("True Pose:", truth),
        ("\nRANSAC Estimated Pose:", coarse),
        ("\nICP Refined Pose:", refined),
    ):
        print(title)
        print(matrix)