In [None]:
#| default_exp registration

In [None]:
#| hide
from nbdev.showdoc import *

## Registration

The `Registration` module uses the `DRR` module to perform differentiable 2D-to-3D registration. Initial guesses for the pose parameters are stored as `nn.Parameter`s of the module. This allows the pose parameters to be optimized with any PyTorch optimizer. Furthermore, this design choice allows `DRR` to be used purely as a differentiable renderer.

In [None]:
#| export
import torch
import torch.nn as nn

from diffdrr.drr import DRR
from diffdrr.pose import convert


class Registration(nn.Module):
    """Optimizable camera pose wrapped around a differentiable `DRR` renderer.

    The initial rotation and translation are registered as `nn.Parameter`s of
    this module, so any PyTorch optimizer can update them through the image
    returned by `forward`.
    """

    def __init__(
        self,
        drr: DRR,  # Renderer that has already been constructed
        rotation: torch.Tensor,  # Starting estimate of the rotation
        translation: torch.Tensor,  # Starting estimate of the translation
        parameterization: str,  # Name of the rotation representation in use
        convention: str = None,  # Euler convention; needed when `parameterization` is `euler_angles`
    ):
        super().__init__()
        self.drr = drr

        # Wrapping the tensors in nn.Parameter registers them with the module,
        # which is what exposes them to optimizers via `.parameters()`.
        self._rotation = nn.Parameter(rotation)
        self._translation = nn.Parameter(translation)

        # Plain settings that are handed to `convert` whenever the pose is built.
        self.parameterization, self.convention = parameterization, convention

    @property
    def rotation(self):
        """The learnable rotation parameter."""
        return self._rotation

    @property
    def translation(self):
        """The learnable translation parameter."""
        return self._translation

    @property
    def pose(self):
        """Current pose, produced by `convert` from the learnable rotation and translation."""
        settings = dict(
            parameterization=self.parameterization,
            convention=self.convention,
        )
        return convert(self._rotation, self._translation, **settings)

    def forward(self, **kwargs):
        """Render from the current pose; `kwargs` are passed straight through to `self.drr`."""
        current_pose = self.pose
        return self.drr(current_pose, **kwargs)

## Pose Regressor

We perform patient-specific X-ray to CT registration by pre-training an encoder/decoder architecture. The encoder, `PoseRegressor`, is comprised of two networks:

1. A pretrained backbone (i.e., convolutional or transformer network) that extracts features from an input X-ray image.
2. A set of two linear layers that decodes these features into camera pose parameters (a rotation and a translation).

The decoder is `diffdrr.drr.DRR`, which renders a simulated X-ray from the predicted pose parameters. Because our renderer is differentiable, a loss metric on the simulated X-ray and the input X-ray can be backpropagated to the encoder.

In [None]:
#| export
import timm

from diffdrr.pose import RigidTransform


class PoseRegressor(torch.nn.Module):
    """
    A PoseRegressor is comprised of a pretrained backbone model that extracts features
    from an input X-ray and two linear layers that decode these features into rotational
    and translational camera pose parameters, respectively.

    `forward` returns the raw `(rot, xyz)` outputs of the two linear heads;
    `parameterization` and `convention` are stored on the module but are not
    applied inside `forward`.
    """

    def __init__(
        self,
        model_name,  # Name of the `timm` architecture used as the backbone
        parameterization,  # Rotation representation; must be a key of `N_ANGULAR_COMPONENTS`
        convention=None,  # Stored as `self.convention`; not used by `forward`
        pretrained=False,  # Forwarded to `timm.create_model`
        height=256,  # Side length of the square dummy image used to probe the backbone
        **kwargs,  # Extra keyword arguments forwarded to `timm.create_model`
    ):
        super().__init__()

        self.parameterization = parameterization
        self.convention = convention
        # Raises KeyError if `parameterization` is not a supported name
        n_angular_components = N_ANGULAR_COMPONENTS[parameterization]

        # Single-channel backbone with its classifier removed (`num_classes=0`)
        self.backbone = timm.create_model(
            model_name,
            pretrained,
            num_classes=0,
            in_chans=1,
            **kwargs,
        )

        # Get the size of the output from the backbone by pushing one dummy image
        # through it. The probe runs in eval mode and without autograd so that it
        # (a) does not overwrite BatchNorm running statistics with noise, which
        # would perturb pretrained weights, (b) does not fail on batch-norm layers
        # that reject a single sample in training mode, and (c) does not build an
        # unused autograd graph. The previous training flag is restored afterwards.
        was_training = self.backbone.training
        self.backbone.eval()
        try:
            with torch.no_grad():
                probe = torch.randn(1, 1, height, height)
                output = self.backbone(probe).shape[-1]
        finally:
            self.backbone.train(was_training)

        self.xyz_regression = torch.nn.Linear(output, 3)
        self.rot_regression = torch.nn.Linear(output, n_angular_components)

    def forward(self, x):
        """Return `(rot, xyz)`: the rotation-head and translation-head outputs for `x`.

        `x` is fed directly to the backbone — presumably a batch of single-channel
        images, since the backbone is built with `in_chans=1` (TODO confirm with callers).
        """
        x = self.backbone(x)
        rot = self.rot_regression(x)
        xyz = self.xyz_regression(x)
        return rot, xyz

In [None]:
#| exporti
# Output width of `PoseRegressor.rot_regression` for each supported
# `parameterization` name (looked up in `PoseRegressor.__init__`).
N_ANGULAR_COMPONENTS = dict(
    axis_angle=3,
    euler_angles=3,
    se3_log_map=3,
    quaternion=4,
    rotation_6d=6,
    rotation_9d=9,
    rotation_10d=10,
    quaternion_adjugate=10,
)

In [None]:
#| hide
# Housekeeping cell (hidden from the docs): run nbdev's export step so the
# cells marked for export are written out to the library's Python modules.
import nbdev

nbdev.nbdev_export()