In [None]:
#| default_exp metrics

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from __future__ import annotations

import torch

## Image similarity metrics

Compute the similarity between a fixed X-ray $\mathbf I$ and a moving X-ray $\mathbf{\hat I}$, where $\mathbf{\hat I}$ is rendered from an estimated camera pose (registration) or volume (reconstruction).

We implement patchwise variants of the following metrics:

- Normalized Cross Correlation (NCC)
- Multiscale Normalized Cross Correlation (mNCC)
- Gradient Normalized Cross Correlation (gNCC)

In [None]:
#| exporti
from einops import rearrange


def to_patches(x, patch_size):
    """Extract every overlapping `patch_size x patch_size` patch (stride 1).

    For an input of shape (b, c, H, W) the result has shape
    (b, c * (H - patch_size + 1) * (W - patch_size + 1), patch_size, patch_size):
    each patch becomes its own channel, ordered by input channel, then patch
    row, then patch column.
    """
    # After both unfolds: (b, c, n_rows, n_cols, patch_size, patch_size)
    windows = x.unfold(2, patch_size, 1).unfold(3, patch_size, 1).contiguous()
    b, c, n_rows, n_cols, ph, pw = windows.shape
    # Merge (c, n_rows, n_cols) into a single "patch" channel axis
    return windows.reshape(b, c * n_rows * n_cols, ph, pw)

In [None]:
#| export
class NormalizedCrossCorrelation2d(torch.nn.Module):
    """Compute Normalized Cross Correlation between two batches of images.

    Each channel of each input is standardized over its two spatial axes (with
    `eps` regularizing the variance). The score for every batch element is the
    mean, over channels and pixels, of the product of the two standardized
    inputs. When `patch_size` is set, the inputs are first split into all
    overlapping `patch_size x patch_size` patches with `to_patches`, so the
    score becomes the mean NCC over patches.

    Args:
        patch_size: Side length of the square patches, or `None` for whole-image NCC.
        eps: Added to the variance before the square root, guarding against
            division by zero on constant regions.
    """

    def __init__(self, patch_size=None, eps=1e-5):
        super().__init__()
        self.patch_size = patch_size
        self.eps = eps

    def forward(self, x1, x2):
        """Return one NCC score per batch element, shape (b,)."""
        patch_size = self.patch_size
        if patch_size is not None:
            x1, x2 = to_patches(x1, patch_size), to_patches(x2, patch_size)
        assert x1.shape == x2.shape, "Input images must be the same size"
        _, c, h, w = x1.shape
        z1 = self.norm(x1)
        z2 = self.norm(x2)
        # Sum of elementwise products per batch element, then average
        total = torch.einsum("b...,b...->b", z1, z2)
        return total / (c * h * w)

    def norm(self, x):
        """Standardize `x` over its last two (spatial) axes."""
        spatial = (-2, -1)
        centered = x - x.mean(dim=spatial, keepdim=True)
        std = torch.sqrt(x.var(dim=spatial, keepdim=True, correction=0) + self.eps)
        return centered / std

In [None]:
#| export
class MultiscaleNormalizedCrossCorrelation2d(torch.nn.Module):
    """Compute Normalized Cross Correlation between two batches of images at multiple scales.

    One `NormalizedCrossCorrelation2d` is built per entry of `patch_sizes`
    (`None` means whole-image NCC). The returned score is the weighted sum of
    the per-scale scores, using `patch_weights`.

    Args:
        patch_sizes: Patch size for each scale (`None` for whole-image NCC).
        patch_weights: Weight for each scale; must match `patch_sizes` in length.
        eps: Variance regularizer forwarded to every per-scale NCC.
    """

    def __init__(self, patch_sizes=(None,), patch_weights=(1.0,), eps=1e-5):
        super().__init__()

        # Copy, so neither the caller's sequences nor the defaults are aliased on self
        patch_sizes = list(patch_sizes)
        patch_weights = list(patch_weights)
        assert len(patch_sizes) == len(patch_weights), "Each scale must have a weight"
        assert len(patch_sizes) > 0, "At least one scale is required"

        # Forward `eps` so it actually reaches every scale. ModuleList registers
        # the per-scale metrics as submodules (repr, .to(), module traversal).
        self.nccs = torch.nn.ModuleList(
            [NormalizedCrossCorrelation2d(patch_size, eps=eps) for patch_size in patch_sizes]
        )
        self.patch_weights = patch_weights

    def forward(self, x1, x2):
        """Return the weighted sum of the per-scale NCC scores."""
        scores = [weight * ncc(x1, x2) for weight, ncc in zip(self.patch_weights, self.nccs)]
        return torch.stack(scores, dim=0).sum(dim=0)

In [None]:
#| exporti
from torchvision.transforms.functional import gaussian_blur


class Sobel(torch.nn.Module):
    """Gaussian-smoothed Sobel image gradients.

    The input is first blurred with a Gaussian (`kernel_size`, `sigma`). The
    blurred image is then convolved with fixed 3x3 Sobel kernels, giving a
    2-channel output (X- and Y-gradients) with the same spatial size as the
    input.

    Args:
        sigma: Standard deviation of the Gaussian blur.
        kernel_size: Size of the Gaussian kernel (torchvision requires an odd,
            positive size). Defaults to 5.

    NOTE(review): the convolution has `in_channels=1`, so inputs are presumably
    single-channel batches of shape (b, 1, h, w).
    """

    def __init__(self, sigma, kernel_size=5):
        super().__init__()
        self.sigma = sigma
        self.kernel_size = kernel_size
        self.filter = torch.nn.Conv2d(
            in_channels=1,
            out_channels=2,  # X- and Y-gradients
            kernel_size=3,
            stride=1,
            padding=1,  # Return images of the same size as inputs
            bias=False,
        )

        Gx = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).to(torch.float32)
        Gy = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).to(torch.float32)
        G = torch.stack([Gx, Gy]).unsqueeze(1)  # (out=2, in=1, 3, 3)
        # Fixed, non-trainable Sobel kernels
        self.filter.weight = torch.nn.Parameter(G, requires_grad=False)

    def forward(self, img):
        # Differentiate the *smoothed* image; filtering `img` directly would
        # discard the blur and make `sigma` a no-op.
        x = gaussian_blur(img, self.kernel_size, self.sigma)
        return self.filter(x)

In [None]:
#| export
class GradientNormalizedCrossCorrelation2d(NormalizedCrossCorrelation2d):
    """Compute Normalized Cross Correlation between the image gradients of two batches of images.

    Both inputs are passed through `Sobel(sigma)`, which yields two gradient
    channels (X and Y). These gradient maps are then compared with the parent
    class's (optionally patchwise) NCC. Extra keyword arguments (e.g. `eps`)
    are forwarded to `NormalizedCrossCorrelation2d`.
    """

    def __init__(self, patch_size=None, sigma=1.0, **kwargs):
        super().__init__(patch_size=patch_size, **kwargs)
        self.sobel = Sobel(sigma)

    def forward(self, x1, x2):
        # Compare gradient maps rather than raw intensities
        grad_1 = self.sobel(x1)
        grad_2 = self.sobel(x2)
        return super().forward(grad_1, grad_2)

In [None]:
#| hide
torch.manual_seed(0)  # Fixed seed so the displayed scores are reproducible
x1 = torch.randn(8, 1, 128, 128)
x2 = torch.randn(8, 1, 128, 128)

# Whole-image NCC: one score per batch element, ~1 for identical inputs
ncc = NormalizedCrossCorrelation2d()
assert ncc(x1, x2).shape == (8,)
assert torch.allclose(ncc(x1, x1), torch.ones(8), atol=1e-3)

# A larger eps only regularizes the variance; the output shape is unchanged
ncc = NormalizedCrossCorrelation2d(eps=1e-1)
assert ncc(x1, x2).shape == (8,)

# Patchwise NCC
ncc = NormalizedCrossCorrelation2d(patch_size=9)
assert ncc(x1, x2).shape == (8,)

# Multiscale NCC: weighted sum of patchwise and whole-image scores
msncc = MultiscaleNormalizedCrossCorrelation2d(
    patch_sizes=[9, None], patch_weights=[0.5, 0.5]
)
assert msncc(x1, x2).shape == (8,)

# Gradient NCC, whole-image and patchwise
gncc = GradientNormalizedCrossCorrelation2d()
assert gncc(x1, x2).shape == (8,)

gncc = GradientNormalizedCrossCorrelation2d(patch_size=9)
gncc(x1, x2)

tensor([-0.0019, -0.0004,  0.0035, -0.0198, -0.0078, -0.0175,  0.0171,  0.0019])

## Geodesic distances for SE(3)

One can define geodesic pseudo-distances on $\mathbf{SO}(3)$ and $\mathbf{SE}(3)$.[^1] This lets us measure registration error (in radians and millimeters, respectively) directly on poses, rather than needing to compute the projection of fiducials.

We implement two geodesics on $\mathbf{SE}(3)$:

- The logarithmic geodesic
- The double geodesic

[^1]: [https://vnav.mit.edu/material/04-05-LieGroups-notes.pdf](https://vnav.mit.edu/material/04-05-LieGroups-notes.pdf)

### Logarithmic Geodesic

Given two rotation matrices $\mathbf R_A, \mathbf R_B \in \mathbf{SO}(3)$, the angular distance between them (the rotation angle of the relative rotation $\mathbf R_A^T \mathbf R_B$) is

$$
    d_\theta(\mathbf R_A, \mathbf R_B) 
    = \arccos \left( \frac{\mathrm{trace}(\mathbf R_A^T \mathbf R_B) - 1}{2} \right)
    = \| \log (\mathbf R_A^T \mathbf R_B) \| \,,
$$

where $\log(\cdot)$ is the logarithm map on $\mathbf{SO}(3)$.[^2]
Using the logarithm map on $\mathbf{SE}(3)$, this generalizes to a geodesic loss function on camera poses ${\mathbf T}_A, {\mathbf T}_B \in \mathbf{SE}(3)$:

$$
    \mathcal L_{\mathrm{log}}({\mathbf T}_A, {\mathbf T}_B) = \| \log({\mathbf T}_A^{-1} {\mathbf T}_B) \| \,.
$$

[^2]: [https://www.cs.cmu.edu/~cga/dynopt/readings/Rmetric.pdf](https://www.cs.cmu.edu/~cga/dynopt/readings/Rmetric.pdf)

In [None]:
#| export
from diffdrr.pose import RigidTransform, convert


class LogGeodesicSE3(torch.nn.Module):
    """
    Calculate the distance between transforms in the log-space of SE(3).

    The distance is the norm (over dim 1) of the SE(3) logarithm of
    `pose_2.compose(pose_1.inverse())`, i.e. of the relative transform between
    the two poses. NOTE(review): presumably this equals ||log(T_1^{-1} T_2)||
    under diffdrr's `compose` convention; confirm.
    """

    def forward(
        self,
        pose_1: RigidTransform,
        pose_2: RigidTransform,
    ) -> torch.Tensor:
        """Return one log-geodesic distance per batch element, shape (b,)."""
        relative = pose_2.compose(pose_1.inverse())
        # NOTE(review): assumes get_se3_log() returns (b, 6) so dim=1 is the twist axis
        return relative.get_se3_log().norm(dim=1)

In [None]:
# SE(3) distance between two poses given as ZYX Euler angles + translations
geodesic_se3 = LogGeodesicSE3()

euler_zyx = dict(parameterization="euler_angles", convention="ZYX")
pose_1 = convert(torch.tensor([[0.1, 1.0, torch.pi]]), torch.ones(1, 3), **euler_zyx)
pose_2 = convert(torch.tensor([[0.1, 1.1, torch.pi]]), torch.zeros(1, 3), **euler_zyx)

geodesic_se3(pose_1, pose_2)

tensor([1.7354])

### Double Geodesic

We can also formulate a geodesic distance on $\mathbf{SE}(3)$ with units of length. Using the camera's focal length $f$ (the source-to-detector distance), we convert the angular distance to an arc length on a sphere of radius $f/2$:

$$
    d_\theta(\mathbf R_A, \mathbf R_B; f) = \frac{f}{2} d_\theta(\mathbf R_A, \mathbf R_B) \,.
$$

When combined with the Euclidean distance on the translations $d_t(\mathbf t_A, \mathbf t_B) = \| \mathbf t_A - \mathbf t_B \|$, this yields the *double geodesic* loss on $\mathbf{SE}(3)$:[^3]

$$
    \mathcal L_{\mathrm{geo}}({\mathbf T}_A, {\mathbf T}_B; f) = \sqrt{d^2_\theta(\mathbf R_A, \mathbf R_B; f) + d^2_t(\mathbf t_A, \mathbf t_B)} \,.
$$

[^3]: [https://rpk.lcsr.jhu.edu/wp-content/uploads/2017/08/Partial-Bi-Invariance-of-SE3-Metrics1.pdf](https://rpk.lcsr.jhu.edu/wp-content/uploads/2017/08/Partial-Bi-Invariance-of-SE3-Metrics1.pdf)

In [None]:
#| export
from diffdrr.pose import so3_log_map


class DoubleGeodesicSE3(torch.nn.Module):
    """
    Calculate the angular and translational geodesics between two SE(3) transformation matrices.

    `forward` returns `(rot, xyz, dou)`:
    - rot: `sdd / 2` times the norm of the SO(3) logarithm of R_1^T R_2
      (presumably the relative rotation angle), i.e. an arc length in the
      units of `sdd`;
    - xyz: Euclidean distance between the translations;
    - dou: sqrt(rot^2 + xyz^2 + eps), the double geodesic.
    """

    def __init__(
        self,
        sdd: float,  # Source-to-detector distance
        eps: float = 1e-6,  # Keeps sqrt differentiable when both distances are 0
    ):
        super().__init__()
        self.sdr = sdd / 2  # Radius that turns the rotation angle into an arc length
        self.eps = eps

    # These are methods rather than lambdas stored on the instance. Lambdas
    # make the module unpicklable (e.g. torch.save). Because they close over
    # `self`, a deep copy would also keep using the original instance's `sdr`.
    def rot_geo(self, r1, r2):
        """Arc-length distance between rotation matrices `r1` and `r2`."""
        return self.sdr * so3_log_map(r1.transpose(-1, -2) @ r2).norm(dim=-1)

    def xyz_geo(self, t1, t2):
        """Euclidean distance between translations `t1` and `t2` (over the last dim)."""
        return (t1 - t2).norm(dim=-1)

    def forward(self, pose_1: RigidTransform, pose_2: RigidTransform):
        r1, t1 = pose_1.convert("matrix")
        r2, t2 = pose_2.convert("matrix")
        rot = self.rot_geo(r1, r2)
        xyz = self.xyz_geo(t1, t2)
        dou = (rot.square() + xyz.square() + self.eps).sqrt()
        return rot, xyz, dou

In [None]:
# Angular distance and translational distance both in mm (sdd = 1020)
double_geodesic = DoubleGeodesicSE3(1020.0)

rotation_1 = torch.tensor([[0.1, 1.0, torch.pi]])
rotation_2 = torch.tensor([[0.1, 1.1, torch.pi]])
pose_1 = convert(
    rotation_1, torch.ones(1, 3), parameterization="euler_angles", convention="ZYX"
)
pose_2 = convert(
    rotation_2, torch.zeros(1, 3), parameterization="euler_angles", convention="ZYX"
)

double_geodesic(pose_1, pose_2)  # Angular, translational, double geodesics

(tensor([51.0000]), tensor([1.7321]), tensor([51.0294]))

In [None]:
#| hide
from nbdev import nbdev_export

# Write the `#| export` cells of this notebook to the library module
nbdev_export()