In [None]:
#| default_exp metrics

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from diffdrr.metrics import (
    GradientNormalizedCrossCorrelation2d,
    MultiscaleNormalizedCrossCorrelation2d,
    NormalizedCrossCorrelation2d,
)
from torchmetrics import Metric

## Image similarity metrics

These metrics quantify the similarity between ground-truth X-rays ($\mathbf I$) and synthetic X-rays rendered from estimated camera poses ($\hat{\mathbf I}$). If a metric is differentiable, it can be used as a loss to optimize camera poses with `DiffDRR`.

In [None]:
#| exporti
class CustomMetric(Metric):
    """`torchmetrics.Metric` that averages a per-sample score over every sample seen.

    Wraps an image-similarity module (e.g. one from `diffdrr.metrics`) and
    accumulates its outputs across `update` calls.
    """

    # Must be an assignment, not an annotation: `is_differentiable: True` only
    # records a type hint and leaves the inherited torchmetrics default in place.
    is_differentiable = True

    def __init__(self, LossClass, **kwargs):
        """
        Args:
            LossClass: class instantiated as `LossClass(**kwargs)`. Calling the instance
                on `(preds, target)` must return a tensor whose `.sum()` is the batch
                total (presumably one score per image — TODO confirm).
            **kwargs: forwarded to `LossClass`, not to `Metric`.
        """
        super().__init__()
        self.lossfn = LossClass(**kwargs)
        # Running sum of scores and number of samples; both are summed across
        # processes when the metric is synchronized.
        self.add_state("loss", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds, target):
        """Accumulate the batch's summed score and its sample count (`len(preds)`)."""
        self.loss += self.lossfn(preds, target).sum()
        self.count += len(preds)

    def compute(self):
        """Return the mean score over all samples passed to `update` (0 / 0 -> NaN if none)."""
        return self.loss.float() / self.count

`NCC`, Multiscale `NCC`, and `GradNCC` were originally implemented in [`diffdrr.metrics`](https://github.com/eigenvivek/DiffDRR/blob/main/notebooks/api/05_metrics.ipynb).
`DiffPose` provides `torchmetrics` wrappers for these functions.

In [None]:
#| export
class NormalizedCrossCorrelation(CustomMetric):
    """`torchmetrics` wrapper for NCC (averages `NormalizedCrossCorrelation2d` over samples)."""

    # Must be an assignment, not an annotation (`higher_is_better: True`), or
    # torchmetrics never sees the value.
    higher_is_better = True

    def __init__(self, patch_size=None):
        # `patch_size` is forwarded unchanged; see `diffdrr.metrics` for the
        # meaning of `None`.
        super().__init__(NormalizedCrossCorrelation2d, patch_size=patch_size)


class MultiscaleNormalizedCrossCorrelation(CustomMetric):
    """`torchmetrics` wrapper for Multiscale NCC (averages `MultiscaleNormalizedCrossCorrelation2d` over samples)."""

    # Must be an assignment, not an annotation (`higher_is_better: True`), or
    # torchmetrics never sees the value.
    higher_is_better = True

    def __init__(self, patch_sizes, patch_weights):
        # Both arguments are forwarded unchanged; presumably one weight per patch
        # size — TODO confirm against `diffdrr.metrics`.
        super().__init__(
            MultiscaleNormalizedCrossCorrelation2d,
            patch_sizes=patch_sizes,
            patch_weights=patch_weights,
        )


class GradientNormalizedCrossCorrelation(CustomMetric):
    """`torchmetrics` wrapper for GradNCC (averages `GradientNormalizedCrossCorrelation2d` over samples)."""

    # Must be an assignment, not an annotation (`higher_is_better: True`), or
    # torchmetrics never sees the value.
    higher_is_better = True

    def __init__(self, patch_size=None):
        # `patch_size` is forwarded unchanged; see `diffdrr.metrics` for the
        # meaning of `None`.
        super().__init__(GradientNormalizedCrossCorrelation2d, patch_size=patch_size)

## Geodesic distances for SO(3) and SE(3)

One can define geodesic pseudo-distances on SO(3) and SE(3).[^1] This lets us measure registration error directly on poses (in radians and millimeters, respectively), rather than needing to compute the projection of fiducials.

- **For SO(3)**, the geodesic distance between two rotation matrices $\mathbf R_A$ and $\mathbf R_B$ is
\begin{equation}
    d_\theta(\mathbf R_A, \mathbf R_B; r) = r \left| \arccos \left( \frac{\mathrm{trace}(\mathbf R_A^\top \mathbf R_B) - 1}{2} \right ) \right| \,,
\end{equation}
where $r$, the source-to-detector radius, is used to convert radians to millimeters.

- **For SE(3)**, we decompose the transformation matrix into a rotation and a translation, i.e., $\mathbf T = (\mathbf R, \mathbf t)$.
Then, we compute the geodesic on translations (just Euclidean distance),
\begin{equation}
    d_t(\mathbf t_A, \mathbf t_B) = \| \mathbf t_A - \mathbf t_B \|_2 \,.
\end{equation}
Finally, we compute the *double geodesic* on the rotations and translations:
\begin{equation}
    d(\mathbf T_A, \mathbf T_B) = \sqrt{d_\theta(\mathbf R_A, \mathbf R_B; r)^2 + d_t(\mathbf t_A, \mathbf t_B)^2} \,.
\end{equation}

[^1]: [https://vnav.mit.edu/material/04-05-LieGroups-notes.pdf](https://vnav.mit.edu/material/04-05-LieGroups-notes.pdf)

In [None]:
#| export
import torch
from beartype import beartype
from diffdrr.pose import convert, so3_log_map
from jaxtyping import Float, jaxtyped

from diffpose.calibration import RigidTransform

In [None]:
#| export
class GeodesicSO3(torch.nn.Module):
    """Calculate the angular distance between two rotations in SO(3).

    The relative rotation `R_1^T R_2` is passed through the SO(3) logarithm, and
    the norm of the resulting vector is returned (radians, one value per batch item).
    """

    def __init__(self):
        super().__init__()

    @jaxtyped(typechecker=beartype)
    def forward(
        self,
        pose_1: RigidTransform,
        pose_2: RigidTransform,
    ) -> Float[torch.Tensor, "b"]:
        # The top-left 3x3 block of each pose matrix is its rotation.
        rot_1, rot_2 = (pose.matrix[..., :3, :3] for pose in (pose_1, pose_2))
        relative_rotation = torch.matmul(rot_1.transpose(-1, -2), rot_2)
        return so3_log_map(relative_rotation).norm(dim=-1)


class GeodesicTranslation(torch.nn.Module):
    """Calculate the Euclidean distance between the translations of two poses in R^3."""

    def __init__(self):
        super().__init__()

    @jaxtyped(typechecker=beartype)
    def forward(
        self,
        pose_1: RigidTransform,
        pose_2: RigidTransform,
    ) -> Float[torch.Tensor, "b"]:
        # Rows 0-2 of column 3 of each pose matrix hold the translation.
        t1 = pose_1.matrix[..., :3, 3]
        t2 = pose_2.matrix[..., :3, 3]
        # Reduce over the xyz axis; dim=-1 matches GeodesicSO3 and does not
        # assume exactly one leading batch dimension.
        return (t1 - t2).norm(dim=-1)

In [None]:
#| export
class GeodesicSE3(torch.nn.Module):
    """Calculate the distance between transforms in the log-space of SE(3).

    Forms the relative transform `pose_2 ∘ pose_1^{-1}`, takes its SE(3) logarithm,
    and returns the norm of that log vector for each batch item.
    """

    def __init__(self):
        super().__init__()

    @jaxtyped(typechecker=beartype)
    def forward(
        self,
        pose_1: RigidTransform,
        pose_2: RigidTransform,
    ) -> Float[torch.Tensor, "b"]:
        relative = pose_2.compose(pose_1.inverse())
        log_vector = relative.get_se3_log()
        return log_vector.norm(dim=1)

In [None]:
#| export
@beartype
class DoubleGeodesic(torch.nn.Module):
    """Calculate the angular and translational geodesics between two SE(3) transformation matrices.

    The angular geodesic from `GeodesicSO3` (radians) is multiplied by `sdr` so
    that it shares a length unit with the translational geodesic. The two are
    then combined as `sqrt(angular**2 + translation**2 + eps)`.
    """

    def __init__(
        self,
        sdr: float,  # Source-to-detector radius; converts radians to a length
        eps: float = 1e-4,  # Keeps sqrt away from 0, where its gradient is infinite
    ):
        super().__init__()
        self.sdr = sdr
        self.eps = eps

        self.rotation = GeodesicSO3()
        self.translation = GeodesicTranslation()

    @jaxtyped(typechecker=beartype)
    def forward(self, pose_1: RigidTransform, pose_2: RigidTransform):
        """Return `(angular_geodesic, translation_geodesic, double_geodesic)`, each of shape `(b,)`.

        `eps` biases `double_geodesic` upward: identical poses yield `sqrt(eps)`.
        """
        angular_geodesic = self.sdr * self.rotation(pose_1, pose_2)
        translation_geodesic = self.translation(pose_1, pose_2)
        double_geodesic = (
            angular_geodesic.square() + translation_geodesic.square() + self.eps
        ).sqrt()
        return angular_geodesic, translation_geodesic, double_geodesic

In [None]:
# SO(3) distance
geodesic_so3 = GeodesicSO3()
euler_zyx = dict(parameterization="euler_angles", convention="ZYX")

# Identical rotations are 0 rad apart
pose_1 = convert(torch.tensor([[0.1, 1.0, torch.pi]]), torch.ones(1, 3), **euler_zyx)
pose_2 = convert(torch.tensor([[0.1, 1.0, torch.pi]]), torch.ones(1, 3), **euler_zyx)
dist_identical = geodesic_so3(pose_1, pose_2)
print(dist_identical)  # Angular distance in radians

# Changing a single Euler angle by 0.1 rad gives a relative rotation of 0.1 rad
pose_2 = convert(torch.tensor([[0.1, 1.1, torch.pi]]), torch.ones(1, 3), **euler_zyx)
dist_perturbed = geodesic_so3(pose_1, pose_2)
print(dist_perturbed)  # Angular distance in radians

assert torch.allclose(dist_identical, torch.zeros(1), atol=1e-4)
assert torch.allclose(dist_perturbed, torch.tensor([0.1]), atol=1e-4)

tensor([0.])
tensor([0.1000])


In [None]:
# SE(3) distance
geodesic_se3 = GeodesicSE3()
euler_zyx = dict(parameterization="euler_angles", convention="ZYX")

# Rotations differ in one Euler angle by 0.1 rad; translations differ by (1, 1, 1)
pose_1 = convert(torch.tensor([[0.1, 1.0, torch.pi]]), torch.ones(1, 3), **euler_zyx)
pose_2 = convert(torch.tensor([[0.1, 1.1, torch.pi]]), torch.zeros(1, 3), **euler_zyx)

geodesic_se3(pose_1, pose_2)

tensor([1.7355])

In [None]:
# Angular distance and translational distance both in mm
sdr = 1020 / 2  # Source-to-detector radius used to scale radians to mm
double_geodesic = DoubleGeodesic(sdr)
euler_zyx = dict(parameterization="euler_angles", convention="ZYX")

pose_1 = convert(torch.tensor([[0.1, 1.0, torch.pi]]), torch.ones(1, 3), **euler_zyx)
pose_2 = convert(torch.tensor([[0.1, 1.1, torch.pi]]), torch.zeros(1, 3), **euler_zyx)

double_geodesic(pose_1, pose_2)

(tensor([51.0000]), tensor([1.7321]), tensor([51.0294]))

In [None]:
#| hide
from nbdev import nbdev_export

# Run nbdev's export step so the `#| export` cells are written to the library modules
nbdev_export()