An X-ray C-arm can be modeled as a pinhole camera with its own extrinsic and intrinsic matrices. 
This module provides utilities for parsing these matrices and working with rigid transforms.

In [None]:
#| default_exp calibration

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import torch

## Rigid transformations

We represent rigid transforms as $4 \times 4$ matrices (following the right-handed, row-vector convention of `PyTorch3D`, in which points are row vectors multiplied on the right by the transform),

\begin{equation}
\begin{bmatrix}
    \mathbf R^T & \mathbf 0 \\
    \mathbf t^T & 1
\end{bmatrix}
\in \mathbf{SE}(3) \,,
\end{equation}

where $\mathbf R \in \mathbf{SO}(3)$ is a rotation matrix and $\mathbf t\in \mathbb R^3$ represents a translation.

Note that since rotation matrices are orthogonal, we have a simple closed-form equation for the inverse:
\begin{equation}
\begin{bmatrix}
    \mathbf R^T & \mathbf 0 \\
    \mathbf t^T & 1
\end{bmatrix}^{-1} =
\begin{bmatrix}
    \mathbf R & \mathbf 0 \\
    -\mathbf t^T \mathbf R & 1
\end{bmatrix} \,.
\end{equation}

For convenience, we add a wrapper of `pytorch3d.transforms.Transform3d` that can be constructed from a (batched) rotation matrix and translation vector. This module also includes the closed-form inverse specific to rigid transforms.

In [None]:
#| export
from typing import Optional

from beartype import beartype
from diffdrr.utils import convert as convert_so3
from jaxtyping import Float, jaxtyped
from pytorch3d.transforms import Transform3d
from pytorchse3.se3 import se3_exp_map, se3_log_map

In [None]:
#| export
@beartype
class RigidTransform(Transform3d):
    """Wrapper of pytorch3d.transforms.Transform3d with extra functionalities.

    The transform is stored as a batch of 4x4 matrices whose upper-left 3x3
    block holds the transposed rotation and whose bottom row holds the
    translation.
    """

    @jaxtyped
    def __init__(
        self,
        R: Float[torch.Tensor, "..."],
        t: Float[torch.Tensor, "... 3"],
        parameterization: str = "matrix",
        convention: Optional[str] = None,
        device=None,
        dtype=torch.float32,
    ):
        """Build a rigid transform from a rotation and a translation.

        Args:
            R: Rotation expressed in `parameterization`; it is converted to a
                rotation matrix with `diffdrr.utils.convert`.
            t: Translation vector(s) with a trailing dimension of 3.
            parameterization: Name of the parameterization `R` is given in.
            convention: Extra convention forwarded to the rotation converter.
            device: Device of the stored matrix. When omitted and `R` and `t`
                live on the same device, that device is used.
            dtype: Data type of the stored matrix.

        Raises:
            AssertionError: If `R` and `t` have different batch sizes.
        """
        if device is None and (R.device == t.device):
            device = R.device

        R = convert_so3(R, parameterization, "matrix", convention)
        # Promote a single (unbatched) rotation/translation pair to batch size 1
        if R.dim() == 2 and t.dim() == 1:
            R = R.unsqueeze(0)
            t = t.unsqueeze(0)

        # The batch size must be bound outside of an `assert`: assert statements
        # are removed under `python -O`, which previously left `batch_size`
        # undefined. AssertionError is raised explicitly so existing callers
        # observe the same exception type and message.
        batch_size = len(R)
        if batch_size != len(t):
            raise AssertionError("R and t need same batch size")

        matrix = torch.zeros(batch_size, 4, 4, device=device, dtype=dtype)
        matrix[..., :3, :3] = R.transpose(-1, -2)
        matrix[..., 3, :3] = t
        matrix[..., 3, 3] = 1

        super().__init__(matrix=matrix, device=device, dtype=dtype)

    def get_rotation(self, parameterization=None, convention=None):
        """Return the rotation, optionally converted to another parameterization.

        With `parameterization=None` the rotation matrix itself is returned;
        otherwise it is converted with `diffdrr.utils.convert`, passing
        `convention` as the output convention.
        """
        # Undo the transpose applied when the matrix was stored
        R = self.get_matrix()[..., :3, :3].transpose(-1, -2)
        if parameterization is not None:
            R = convert_so3(R, "matrix", parameterization, None, convention)
        return R

    def get_translation(self):
        """Return the translation stored in the bottom row of the matrix."""
        return self.get_matrix()[..., 3, :3]

    def inverse(self):
        """Closed-form inverse for rigid transforms.

        For a transform with rotation R and translation t, the inverse has
        rotation R^T and translation -R^T t.
        """
        R = self.get_rotation().transpose(-1, -2)
        t = self.get_translation()
        t = -torch.einsum("bij,bj->bi", R, t)
        return RigidTransform(R, t, device=self.device, dtype=self.dtype)

    def compose(self, other):
        """Compose with `other` and return the result as a RigidTransform.

        The composition itself is delegated to `Transform3d.compose`; the
        resulting matrix is then re-wrapped so the return value keeps the
        RigidTransform interface.
        """
        # Evaluate the composed matrix once and slice both parts from it
        matrix = super().compose(other).get_matrix()
        R = matrix[..., :3, :3].transpose(-1, -2)
        t = matrix[..., 3, :3]
        return RigidTransform(R, t, device=self.device, dtype=self.dtype)

    def clone(self):
        """Return a new RigidTransform built from copies of R and t."""
        matrix = self.get_matrix()
        R = matrix[..., :3, :3].transpose(-1, -2).clone()
        t = matrix[..., 3, :3].clone()
        return RigidTransform(R, t, device=self.device, dtype=self.dtype)

    def get_se3_log(self):
        """Return `se3_log_map` applied to the transposed 4x4 matrix.

        NOTE(review): the transpose presumably converts the stored layout to
        the one `pytorchse3.se3.se3_log_map` expects -- confirm against that
        library.
        """
        return se3_log_map(self.get_matrix().transpose(-1, -2))

In [None]:
#| export
def convert(
    transform,
    input_parameterization,
    output_parameterization,
    input_convention=None,
    output_convention=None,
):
    """Convert between representations of SE(3).

    Every input is first turned into a `RigidTransform`, which is then
    emitted in the requested output representation:

    - "se3_exp_map": `transform` is used / returned as a `RigidTransform`.
    - "se3_log_map": `transform` is a pair of tensors that are concatenated
      along the last axis before `se3_exp_map`; on output the log vector is
      split back into its first three and remaining components.
    - anything else: `transform` is a `(rotation, translation)` pair whose
      rotation uses the named parameterization (and optional convention).
    """

    # Step 1: normalise the input to a RigidTransform
    if input_parameterization == "se3_exp_map":
        rigid = transform
    elif input_parameterization == "se3_log_map":
        log_vector = torch.concat(list(transform), axis=-1)
        exp_matrix = se3_exp_map(log_vector)
        rigid = RigidTransform(
            R=exp_matrix[..., :3, :3].transpose(-1, -2),
            t=exp_matrix[..., 3, :3],
            device=exp_matrix.device,
            dtype=exp_matrix.dtype,
        )
    else:
        rigid = RigidTransform(
            R=transform[0],
            t=transform[1],
            parameterization=input_parameterization,
            convention=input_convention,
        )

    # Step 2: emit the RigidTransform in the requested representation
    if output_parameterization == "se3_exp_map":
        return rigid

    if output_parameterization == "se3_log_map":
        log_vector = rigid.get_se3_log()
        return log_vector[..., :3], log_vector[..., 3:]

    rotation = rigid.get_rotation(output_parameterization, output_convention)
    translation = rigid.get_translation()
    return (rotation, translation)

## Computing a perspective projection

Given an `extrinsic` and `intrinsic` camera matrix, we can compute the perspective projection of a batch of points.
This is used for computing where fiducials in world coordinates get mapped onto the image plane.

In [None]:
#| export
@beartype
@jaxtyped
def perspective_projection(
    extrinsic: RigidTransform,  # Extrinsic camera matrix (world to camera)
    intrinsic: Float[torch.Tensor, "3 3"],  # Intrinsic camera matrix (camera to image)
    x: Float[torch.Tensor, "b n 3"],  # World coordinates
) -> Float[torch.Tensor, "b n 2"]:
    """Project a batch of world-coordinate points onto the image plane."""
    # World -> camera coordinates
    camera_points = extrinsic.transform_points(x)
    # Camera -> homogeneous image coordinates
    homogeneous = torch.einsum("ij, bnj -> bni", intrinsic, camera_points)
    # Perspective divide by the last component, then keep the first two
    depth = homogeneous[..., -1:].clone()
    return (homogeneous / depth)[..., :2]

In [None]:
#| hide
import nbdev

# Run nbdev's export step, which writes the cells marked `#| export` out to
# the library's Python modules.
nbdev.nbdev_export()