In [1]:
import logging
import os
import random
from copy import deepcopy

# OpenCV only decodes .exr files when this flag is set. Set it before importing cv2 so it is
# already in place whenever the OpenCV build reads it (setting it after import is not reliable).
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

import cv2
import matplotlib.pyplot as plt
import numpy as np

In [2]:
def plot_img(img, kpts, sph=False, color=(255, 0, 255), radius=1, thickness=5):
    """
    Draw keypoints on a copy of an image and display the result with Matplotlib.

    Args:
        img (np.ndarray): Input image in BGR channel order (as read by cv2), or a
            single-channel grayscale image. The array itself is not modified.
        kpts (np.ndarray): (N, 2) keypoints. Pixel coordinates (x, y) when ``sph`` is
            False; spherical coordinates (phi, theta) when ``sph`` is True.
        sph (bool): If True, convert ``kpts`` with ``standard_spherical_to_pixel``
            (using this image's width/height) before drawing.
        color (tuple): BGR color for the keypoints.
        radius (int): Radius of the circles representing keypoints.
        thickness (int): Thickness of the circle outline.
    """
    h, w = img.shape[:2]
    if sph:
        pixel_coords = standard_spherical_to_pixel(kpts, w, h)
    else:
        pixel_coords = kpts

    # cv2.circle draws in place and needs a contiguous buffer: work on a C-ordered copy
    # so the caller's image is left untouched.
    canvas = img.copy()
    for point in pixel_coords:
        x, y = int(round(float(point[0]))), int(round(float(point[1])))
        cv2.circle(canvas, (x, y), radius, color, thickness)

    # Matplotlib expects RGB, while OpenCV images are BGR (or single-channel).
    conversion = cv2.COLOR_GRAY2RGB if canvas.ndim == 2 else cv2.COLOR_BGR2RGB
    img_rgb = cv2.cvtColor(canvas, conversion)

    fig, ax = plt.subplots(figsize=(16, 8))
    ax.imshow(img_rgb)
    ax.set_title("Image with Plotted Keypoints")
    ax.axis('off')  # hide the axes for a cleaner look
    plt.show()

In [3]:
def read_depth(p):
    """Read a depth map with OpenCV and return it as a 2D float32 array.

    Args:
        p: Path to the depth image (str or path-like), e.g. an .exr file.

    Returns:
        np.ndarray: (H, W) float32 array. For multi-channel files only channel 0
        (in OpenCV's channel order) is kept.

    Raises:
        OSError: If OpenCV cannot read the file. ``cv2.imread`` returns None instead
            of raising, e.g. for a missing file or when EXR support is disabled.
    """
    d = cv2.imread(str(p), cv2.IMREAD_UNCHANGED)
    if d is None:
        raise OSError(
            f"Could not read depth map {p!s} (missing file, or OPENCV_IO_ENABLE_OPENEXR "
            "not set before importing cv2?)"
        )
    if d.ndim == 3:
        d = d[..., 0]
    return d.astype(np.float32)

In [4]:
def load_pose_and_get_c2w_matrix(filepath):
    """Read a camera pose from a ``.dat`` file and return it as C2W and W2C matrices.

    File layout, as read below: 2 header lines, the 3x3 rotation ``R`` on the next
    3 lines, one marker line (e.g. ``t``), then the 3 components of ``t``.

    Following convention, ``R`` is the world-to-camera rotation
    (``P_camera = R @ P_world``). Unconventionally, ``t`` is the camera *position* in
    world coordinates (the camera centre ``C``), not the world-to-camera translation.
    This is convenient for camera distances, but both translations must be derived:

    * C2W: rotation ``R.T``, translation ``C``.
    * W2C: rotation ``R``, translation ``-R @ C``.

    Args:
        filepath: Path to the .dat file.

    Returns:
        tuple[np.ndarray, np.ndarray]: ``(pose_c2w, pose_w2c)``, two 4x4 float64
        homogeneous matrices (inverses of each other when ``R`` is orthonormal).

    Raises:
        ValueError: If the file does not yield a 3x3 rotation and a 3-element position.
    """
    # Rotation: skip the 2 header lines and read exactly the 3 matrix rows.
    R_w2c = np.loadtxt(filepath, skiprows=2, max_rows=3)
    # Camera centre: skip 2 header + 3 rotation + 1 marker lines. loadtxt yields 3 values
    # whether t is stored on one line or one value per line; flatten defensively.
    C_world = np.loadtxt(filepath, skiprows=6).reshape(-1)

    if R_w2c.shape != (3, 3):
        raise ValueError(f"Expected a 3x3 rotation in {filepath}, got shape {R_w2c.shape}")
    if C_world.shape != (3,):
        raise ValueError(f"Expected a 3-element camera position in {filepath}, got {C_world.size} values")

    pose_c2w = np.eye(4)
    pose_c2w[:3, :3] = R_w2c.T  # inverse of a rotation is its transpose
    pose_c2w[:3, 3] = C_world   # C2W translation is the camera centre itself

    pose_w2c = np.eye(4)
    pose_w2c[:3, :3] = R_w2c
    pose_w2c[:3, 3] = -R_w2c @ C_world

    return pose_c2w, pose_w2c

In [5]:
# Equirectangular panorama size used for all (un)projections in this notebook.
W = 2048
H = 1024

def standard_spherical_to_pixel(kpts_sph_np, W, H):
    """
    Convert spherical keypoints (phi, theta) to equirectangular pixel coordinates (x, y).

    phi   (longitude, [-pi, pi])    -> x = (phi / (2*pi) + 0.5) * (W - 1) - 0.5, x in [-0.5, W - 1.5]
    theta (latitude, [-pi/2, pi/2]) -> y = (0.5 - theta / pi) * (H - 1) - 0.5, y in [-0.5, H - 1.5]
    Larger theta maps to smaller y (theta = pi/2 is the top row).

    NOTE(review): the formula scales by (W - 1)/(H - 1) *and* shifts by -0.5, which mixes the
    corner-aligned and pixel-centre conventions (unproject_spherical/project_spherical use a
    period of W pixels with pixel centres at integer coordinates). It is kept unchanged because
    it must invert whatever produced the stored spherical keypoints -- confirm against that code.

    Args:
        kpts_sph_np: Array of shape (..., 2) holding (phi, theta) in radians.
        W: Image width in pixels.
        H: Image height in pixels.

    Returns:
        np.ndarray: Array of shape (..., 2) holding (x, y) pixel coordinates.
    """
    kpts = np.asarray(kpts_sph_np)
    phi = kpts[..., 0]
    theta = kpts[..., 1]

    # Normalize phi and theta to [0, 1], then scale to pixel coordinates.
    px = (phi / (2 * np.pi) + 0.5) * (W - 1) - 0.5
    py = (-theta / np.pi + 0.5) * (H - 1) - 0.5

    return np.stack([px, py], axis=-1)

In [6]:
def unproject_spherical(uv, d, w, h):
    """Lift equirectangular pixels with a per-pixel range value to 3D points.

    Pixel centres sit at integer coordinates (hence the ``+ 0.5``). The row gives the
    polar angle, 0 at the top of the image. The column gives the azimuth, which
    decreases from left to right. ``d`` is used as the distance along the viewing ray.

    Args:
        uv: (N, 2) pixel coordinates (u = column, v = row).
        d: N range values; squeezed, so (N,) or (N, 1) both work.
        w: Image width in pixels.
        h: Image height in pixels.

    Returns:
        (N, 3) float array of 3D points.
    """
    cols = uv[:, 0].astype(float)
    rows = uv[:, 1].astype(float)
    radius = d.squeeze().astype(float)

    polar = np.pi * (rows + 0.5) / h
    azimuth = 2.0 * np.pi * (1.0 - (cols + 0.5) / w)

    sin_polar = np.sin(polar)
    return np.column_stack((
        radius * np.cos(azimuth) * sin_polar,
        radius * np.sin(azimuth) * sin_polar,
        radius * np.cos(polar),
    ))


def project_spherical(pts, w, h):
    """Project 3D points onto an equirectangular image (inverse of ``unproject_spherical``).

    Args:
        pts: (N, 3) points (x, y, z).
        w: Image width in pixels.
        h: Image height in pixels.

    Returns:
        tuple: ``(uv, d)``, where ``uv`` is the (N, 2) pixel coordinates (u, v) and ``d``
        is the (N,) Euclidean norm of each point.
    """
    dist = np.linalg.norm(pts, axis=-1)
    # Guard the division below against points at the origin.
    dist_floor = np.maximum(dist, 1e-8)

    polar = np.arccos(np.clip(pts[:, 2] / dist_floor, -1.0, 1.0))
    # arctan2 yields [-pi, pi]; np.mod (alias of np.remainder) wraps it into [0, 2*pi].
    azimuth = np.mod(np.arctan2(pts[:, 1], pts[:, 0]), 2.0 * np.pi)

    u = (1.0 - azimuth / (2.0 * np.pi)) * w - 0.5
    v = (polar / np.pi) * h - 0.5
    return np.column_stack((u, v)), dist

In [7]:
# Scene and frames to load. Override the data root with the SPHERECRAFT_ROOT environment
# variable instead of editing the absolute path.
# Note the 1-based variable names: *1 = frame 0, *2 = frame 1.
DATA_ROOT = os.environ.get("SPHERECRAFT_ROOT", "/data/code/glue-factory/datasets/spherecraft_data")
SCENE = "barbershop"
FRAME0, FRAME1 = "00000000", "00000001"
scene_dir = os.path.join(DATA_ROOT, SCENE)

img1 = cv2.imread(os.path.join(scene_dir, "images", f"{FRAME0}.jpg"))
img2 = cv2.imread(os.path.join(scene_dir, "images", f"{FRAME1}.jpg"))
# cv2.imread returns None instead of raising when a file is missing or unreadable.
if img1 is None or img2 is None:
    raise FileNotFoundError(f"Could not read RGB images {FRAME0}/{FRAME1} in {scene_dir}/images")

depth1 = read_depth(os.path.join(scene_dir, "depthmaps", f"{FRAME0}.exr"))
depth2 = read_depth(os.path.join(scene_dir, "depthmaps", f"{FRAME1}.exr"))

# np.load on an .npz returns an NpzFile holding an open handle; the context manager closes it.
with np.load(os.path.join(scene_dir, "features_xfeat_spherical", f"{FRAME0}.npz")) as feats:
    kpts1_sph = feats["keypoints"]
with np.load(os.path.join(scene_dir, "features_xfeat_spherical", f"{FRAME1}.npz")) as feats:
    kpts2_sph = feats["keypoints"]

p0_c2w, p0_w2c = load_pose_and_get_c2w_matrix(os.path.join(scene_dir, "extr", f"{FRAME0}.dat"))
p1_c2w, p1_w2c = load_pose_and_get_c2w_matrix(os.path.join(scene_dir, "extr", f"{FRAME1}.dat"))

In [8]:
# Map frame 0's stored spherical keypoints to pixel coordinates of the W x H panorama.
kpts1 = standard_spherical_to_pixel(kpts1_sph, W=W, H=H)

In [9]:
# Number of keypoints in frame 0.
kpts1.shape[0]

2048

In [13]:
# Sample up to 100 keypoints for the projection round-trip check. A seeded generator keeps
# the sample (and every output below) identical across kernel restarts.
SAMPLE_SEED = 0
sample_N = min(100, len(kpts1))
ids_rt = random.Random(SAMPLE_SEED).sample(range(len(kpts1)), sample_N)
uv_rt = kpts1[ids_rt]

In [14]:
# Nearest-pixel depth lookup. With the (u + 0.5) / (v + 0.5) convention of unproject_spherical,
# pixel centres sit at integer coordinates, so round instead of truncating. Clamp as a safeguard
# so keypoints on the image border (coordinates down to -0.5) always index inside the depth map.
rows_rt = np.clip(np.rint(uv_rt[:, 1]).astype(int), 0, depth1.shape[0] - 1)
cols_rt = np.clip(np.rint(uv_rt[:, 0]).astype(int), 0, depth1.shape[1] - 1)
d_rt = depth1[rows_rt, cols_rt]
pts3d_rt = unproject_spherical(uv_rt, d_rt, W, H)
uv_reproj, _ = project_spherical(pts3d_rt, W, H)

In [15]:
# First five sampled pixel positions (x, y).
uv_rt[0:5]

array([[586.9989 , 483.16373],
       [820.103  , 910.27014],
       [803.877  , 599.0131 ],
       [292.90872, 533.43787],
       [279.7636 , 857.49567]], dtype=float32)

In [16]:
# Round-trip check: unproject -> project should reproduce the sampled pixels.
# Only points with a finite, positive depth are checked (zero depth collapses to the origin).
# The horizontal error is wrapped modulo W because x = -0.5 and x = W - 0.5 are the same longitude.
valid_rt = np.isfinite(d_rt) & (d_rt > 0)
err_x = np.remainder(uv_reproj[:, 0] - uv_rt[:, 0] + W / 2, W) - W / 2
err_y = uv_reproj[:, 1] - uv_rt[:, 1]
assert np.all(np.abs(err_x[valid_rt]) < 1e-3), "spherical round trip failed in x"
assert np.all(np.abs(err_y[valid_rt]) < 1e-3), "spherical round trip failed in y"
uv_reproj[:5]

array([[586.99890137, 483.16372681],
       [820.10302734, 910.2701416 ],
       [803.87701416, 599.01312256],
       [292.90872192, 533.43786621],
       [279.76361084, 857.4956665 ]])

In [17]:
class EmptyTensorError(Exception):
    """Raised by ``interpolate_depth`` when no query point survives its validity checks."""


def interpolate_depth(pos, depth):
    """Bilinearly interpolate a depth map at sub-pixel positions.

    A point is kept only if its four neighbouring pixels (floor/ceil of its row and
    column) all lie inside ``depth`` and all have depth > 0.

    Args:
        pos (np.ndarray): (N, 2) query positions as (x, y) = (column, row), in pixels.
        depth (np.ndarray): (h, w) depth map.

    Returns:
        list: ``[interpolated_depth, pos_valid, ids, ids_valid_corners, ids_valid_depth]``

        * interpolated_depth: (M,) interpolated depths of the surviving points.
        * pos_valid: (M, 2) surviving input positions, still (x, y), dtype of ``pos``.
        * ids: (M,) indices of the surviving points into ``pos``.
        * ids_valid_corners: indices into ``pos`` of the points that passed the bounds check.
        * ids_valid_depth: a copy of ``ids``.

    Raises:
        EmptyTensorError: If no point passes the bounds check, or none of those passes
            the depth check.
    """
    # Reorder (N, 2) (x, y) into (2, N) (row, col); no integer conversion happens here.
    pos = pos.T[[1, 0]]
    h, w = depth.shape

    i = pos[0, :].astype(float)  # rows
    j = pos[1, :].astype(float)  # columns

    # The four neighbours are built from just two row and two column indices.
    i_top = np.floor(i).astype(int)
    i_bottom = np.ceil(i).astype(int)
    j_left = np.floor(j).astype(int)
    j_right = np.ceil(j).astype(int)

    # All four corners are in bounds iff top/left are >= 0 and bottom/right are < size.
    valid_corners = (i_top >= 0) & (j_left >= 0) & (i_bottom < h) & (j_right < w)
    ids_valid_corners = np.arange(pos.shape[1])[valid_corners]
    if ids_valid_corners.size == 0:
        raise EmptyTensorError

    i_top, i_bottom = i_top[valid_corners], i_bottom[valid_corners]
    j_left, j_right = j_left[valid_corners], j_right[valid_corners]

    d_top_left = depth[i_top, j_left]
    d_top_right = depth[i_top, j_right]
    d_bottom_left = depth[i_bottom, j_left]
    d_bottom_right = depth[i_bottom, j_right]

    # Reject points that touch any pixel without a positive depth.
    valid_depth = (d_top_left > 0) & (d_top_right > 0) & (d_bottom_left > 0) & (d_bottom_right > 0)
    ids = ids_valid_corners[valid_depth]
    if ids.size == 0:
        raise EmptyTensorError
    ids_valid_depth = ids.copy()

    i = i[ids]
    j = j[ids]
    i_top = i_top[valid_depth]
    j_left = j_left[valid_depth]
    d_top_left = d_top_left[valid_depth]
    d_top_right = d_top_right[valid_depth]
    d_bottom_left = d_bottom_left[valid_depth]
    d_bottom_right = d_bottom_right[valid_depth]

    # Bilinear weights from the fractional offset inside the pixel cell. When a coordinate is
    # integral, floor == ceil and the weight of the "far" corner is zero.
    dist_i = i - i_top
    dist_j = j - j_left
    interpolated_depth = ((1 - dist_i) * (1 - dist_j) * d_top_left
                          + (1 - dist_i) * dist_j * d_top_right
                          + dist_i * (1 - dist_j) * d_bottom_left
                          + dist_i * dist_j * d_bottom_right)

    # Back to (M, 2) in (x, y) order.
    pos_valid = pos[:, ids][[1, 0]].T

    return [interpolated_depth, pos_valid, ids, ids_valid_corners, ids_valid_depth]

In [18]:
def warp_points3d(points3d0: np.ndarray, pose01: np.ndarray) -> np.ndarray:
    """Apply a rigid transform to 3D points.

    Args:
        points3d0: (N, 3) points in the source frame.
        pose01: 4x4 homogeneous transform, or just its top 3x4 ``[R | t]`` block.
            Only the top three rows are used (they fully determine the 3D result).

    Returns:
        (N, 3) transformed points, ``R @ p + t`` for every point ``p``.
    """
    pose01 = np.asarray(pose01)
    rotation = pose01[:3, :3]
    translation = pose01[:3, 3]
    # Row-vector form of R @ p + t; avoids building an (N, 4) homogeneous copy.
    return np.asarray(points3d0) @ rotation.T + translation

In [19]:
import logging

def warp_se3_spherical(kpts0: np.ndarray, params: dict) -> tuple:
    """Warp keypoints from one equirectangular image to another via depth and a relative pose.

    Steps:
    1. Sample depth0 at kpts0.
    2. Unproject the samples to 3D, apply ``pose01``, and project into image 1.
    3. Sample depth1 at the projections.
    4. Keep points whose projected range agrees with depth1 (occlusion check).

    Args:
        kpts0: (N, 2) keypoints (x, y) in image 0, in pixels.
        params: dict with
            'pose01': transform applied to camera-0 3D points by ``warp_points3d``
                (this notebook passes ``p1_w2c @ p0_c2w``);
            'depth0', 'depth1': depth maps for images 0 and 1 (squeezed to 2D);
            'width', 'height': panorama size used for (un)projection;
            optional 'abs_tol' (default 0.05) and 'rel_tol' (default 0.02): tolerances
                of the depth-consistency check.

    Returns:
        tuple ``(k0, k1, ids, extra)``:
            k0: (M, 2) surviving keypoints in image 0;
            k1: (M, 2) their warped positions in image 1 (row i pairs with k0[i]);
            ids: (M,) int64 indices of the survivors into ``kpts0``;
            extra: always an empty int64 array.
        If no keypoint survives one of the two depth-validity checks, a warning is logged
        and ``(kpts0, kpts0, empty, empty)`` is returned, with int64 empty arrays.
    """
    pose01 = params['pose01']
    depth0 = params['depth0'].squeeze()
    depth1 = params['depth1'].squeeze()
    W, H = params['width'], params['height']
    abs_tol = params.get('abs_tol', 0.05)
    rel_tol = params.get('rel_tol', 0.02)

    try:
        # 1) Depth at the image-0 keypoints
        z0, k0v, ids0, _, _ = interpolate_depth(kpts0, depth0)
    except EmptyTensorError:
        logging.warning("No valid keypoints after img0 depth check.")
        # np.long is missing in NumPy 1.24-1.26 and is C long (platform dependent) in 2.x;
        # use int64 like the success path.
        return kpts0, kpts0, np.empty(0, dtype=np.int64), np.empty(0, dtype=np.int64)

    # 2) Unproject -> warp -> project
    pts3d_0 = unproject_spherical(k0v, z0, W, H)
    pts3d_1 = warp_points3d(pts3d_0, pose01)
    uv1_pred, z1_proj = project_spherical(pts3d_1, W, H)

    try:
        # 3) Depth at the predicted image-1 positions
        z1i, k1v, ids1, _, _ = interpolate_depth(uv1_pred, depth1)
    except EmptyTensorError:
        # ids1 is unbound here (the call above raised), so it cannot be returned.
        logging.warning("All warped keypoints invalid in img1.")
        return kpts0, kpts0, np.empty(0, dtype=np.int64), np.empty(0, dtype=np.int64)

    # 4) Occlusion check: projected range must match the depth observed in image 1
    abs_diff = np.abs(z1_proj[ids1] - z1i)
    rel_diff = abs_diff / np.clip(z1i, a_min=1e-6, a_max=None)
    mask = (abs_diff < abs_tol) & (rel_diff < rel_tol)

    # ids1 indexes the image-0 survivors (k0v / ids0), so compose both selections.
    final_ids = ids0[ids1][mask].astype(np.int64)
    k0_final = k0v[ids1][mask]
    k1_final = k1v[mask]  # boolean indexing copies, so k1v is not modified below

    # 5) Canonical ranges in image 1. Longitude is periodic with a period of W pixels (a full turn
    # spans W pixels in project_spherical); latitude is clamped. interpolate_depth already keeps
    # positions inside depth1, so this only matters if depth1's size differs from (H, W).
    k1_final[:, 0] = np.remainder(k1_final[:, 0], W)
    k1_final[:, 1] = np.clip(k1_final[:, 1], a_min=0, a_max=H - 1)

    return k0_final, k1_final, final_ids, np.empty(0, dtype=np.int64)

In [20]:
# Relative pose: camera-0 coordinates -> world -> camera-1 coordinates.
# Notebook variables are 1-based: depth1 belongs to frame 0, depth2 to frame 1.
pose01 = np.dot(p1_w2c, p0_c2w).astype(np.float64)

params = dict(
    pose01=pose01,
    depth0=depth1,
    depth1=depth2,
    width=W,
    height=H,
    abs_tol=0.05,
    rel_tol=0.02,
)
k0f, k1f, ids_w, _ = warp_se3_spherical(kpts1, params)

In [25]:
# Frame-0 keypoints (x, y) that survived warp_se3_spherical's depth and occlusion checks.
k0f

array([[ 279.28714,  377.64093],
       [1201.1182 ,  446.01025],
       [1468.7258 ,  496.65268],
       ...,
       [ 275.6683 ,  539.1186 ],
       [1412.4904 ,  875.2244 ],
       [1680.4025 ,  421.9269 ]], dtype=float32)

In [26]:
# Warped positions (x, y) of those keypoints in frame 1; row i pairs with k0f[i].
k1f

array([[1672.21916705,  306.81269   ],
       [ 460.1912684 ,  669.15722432],
       [ 780.0181296 ,  610.22466507],
       ...,
       [1624.44552266,  432.81668533],
       [1400.06074255,  854.0879722 ],
       [ 947.18843201,  433.47462863]])

In [30]:
# Indices into kpts1 of the surviving keypoints (same order as k0f / k1f).
ids_w

array([   0,    1,    3, ..., 2045, 2046, 2047])

In [21]:
# Side-by-side visualization of a random subset of the warped correspondences.
# Use new names (vis0/vis1) instead of reassigning img1: overwriting it made a re-run of this
# cell draw frame 1 on both halves.
vis0 = cv2.resize(img1, (W, H), interpolation=cv2.INTER_LINEAR)
vis1 = cv2.resize(img2, (W, H), interpolation=cv2.INTER_LINEAR)
canvas = np.full((H, 2 * W, 3), 255, dtype=np.uint8)
canvas[:, :W] = vis0
canvas[:, W:] = vis1

DRAW_SEED = 0  # fixed seed -> the same correspondences are drawn on every run
draw_N = min(25, len(ids_w))
sel = random.Random(DRAW_SEED).sample(range(len(ids_w)), draw_N)
for i in sel:
    u0, v0 = k0f[i]
    u1, v1 = k1f[i]
    # Round to the nearest pixel; frame-1 points are offset by W on the canvas.
    pt0 = (int(round(float(u0))), int(round(float(v0))))
    pt1 = (int(round(float(u1))) + W, int(round(float(v1))))
    cv2.line(canvas, pt0, pt1, (0, 0, 255), 2)

out_path = 'warp_se3_test.png'
# cv2.imwrite signals failure through its return value instead of raising.
if not cv2.imwrite(out_path, canvas):
    raise OSError(f"Failed to write {out_path}")
# print rather than logging.info: the root logger drops INFO messages by default.
print(f"Saved warp visualization to {out_path}")