In [None]:
#| default_exp drr

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from __future__ import annotations

import numpy as np
import torch
import torch.nn as nn
from fastcore.basics import patch

from diffdrr.detector import Detector
from diffdrr.siddon import siddon_raycast
from diffdrr.utils import reshape_subsampled_drr

## DRR
`DRR` is a PyTorch module that computes differentiable digitally reconstructed radiographs.

### X-ray pose parameters

The viewing angle for the DRR (known generally in computer graphics as *pose parameters*) is parameterized by the following parameters:

- `sdr`   : Source-to-Detector radius (half of the source-to-detector distance)
- `theta` : Azimuthal angle
- `phi`   : Polar angle
- `gamma` : Plane rotation angle
- `bx`    : X-dir translation
- `by`    : Y-dir translation
- `bz`    : Z-dir translation

`(bx, by, bz)` are translational parameters and `(theta, phi, gamma)` are rotational parameters. The rotational parameters are detailed in the [Spherical Coordinates Tutorial](https://vivekg.dev/DiffDRR/tutorials/spherical.html).

In [None]:
#| export
class DRR(nn.Module):
    """PyTorch module that computes differentiable digitally reconstructed radiographs.

    The module owns a `Detector` (which produces the X-ray source/target points)
    and keeps the CT volume and its voxel spacing as buffers so that they move
    with the module across devices.
    """

    def __init__(
        self,
        volume: np.ndarray,  # CT volume
        spacing: np.ndarray,  # Dimensions of voxels in the CT volume
        sdr: float,  # Source-to-detector radius for the C-arm (half of the source-to-detector distance)
        height: int,  # Height of the rendered DRR
        delx: float,  # X-axis pixel size
        width: int | None = None,  # Width of the rendered DRR (if not provided, set to `height`)
        dely: float | None = None,  # Y-axis pixel size (if not provided, set to `delx`)
        p_subsample: float | None = None,  # Proportion of pixels to randomly subsample
        reshape: bool = True,  # Reshape output to image form, (b, 1, h, w) when all pixels are rendered
        convention: str = "diffdrr",  # Either `diffdrr` or `deepdrr`, order of basis matrix multiplication
    ):
        super().__init__()

        # Initialize the X-ray detector (square detector / pixels by default)
        width = height if width is None else width
        dely = delx if dely is None else dely
        # Convert the requested proportion into an absolute pixel count;
        # `None` means "render every pixel".
        n_subsample = (
            None if p_subsample is None else int(height * width * p_subsample)
        )
        self.detector = Detector(
            sdr,
            height,
            width,
            delx,
            dely,
            n_subsample=n_subsample,
            convention=convention,
        )

        # Initialize the volume (stored reversed along its first axis)
        self.register_buffer("spacing", torch.tensor(spacing))
        self.register_buffer("volume", torch.tensor(volume).flip([0]))
        self.reshape = reshape

        # Dummy tensor for device and dtype
        self.register_buffer("dummy", torch.tensor([0.0]))

    def reshape_transform(self, img, batch_size):
        """Reshape raycast output into image form when `self.reshape` is set.

        If `self.reshape` is false, `img` is returned untouched. Otherwise a
        fully rendered image is reshaped to `(-1, 1, height, width)`, and a
        subsampled one is delegated to `reshape_subsampled_drr`.
        """
        if not self.reshape:
            return img
        if self.detector.n_subsample is None:
            # `reshape` rather than `view`: identical result (still a view) for
            # contiguous input, but it does not raise on non-contiguous input.
            return img.reshape(-1, 1, self.detector.height, self.detector.width)
        return reshape_subsampled_drr(img, self.detector, batch_size)

The forward pass of the `DRR` module generates DRRs from the input CT volume. The pose parameters (i.e., viewing angles) from which the DRRs are generated are passed to the forward call.

In [None]:
#| export
@patch
def forward(self: DRR, rotations: torch.Tensor, translations: torch.Tensor):
    """Render DRRs for the given batch of rotation and translation parameters.

    `rotations` and `translations` must have the same length; that length is
    used as the batch size when reshaping the rendered output.
    """
    n_poses = len(rotations)
    assert n_poses == len(translations)

    # Ask the detector for the ray endpoints corresponding to each pose.
    rays = self.detector.make_xrays(
        rotations=rotations,
        translations=translations,
    )
    ray_source, ray_target = rays

    # Trace the rays through the stored CT volume.
    rendered = siddon_raycast(ray_source, ray_target, self.volume, self.spacing)
    return self.reshape_transform(rendered, batch_size=n_poses)

## Registration

The `Registration` module uses the `DRR` module to perform differentiable 2D-to-3D registration. Initial guesses for the pose parameters are stored as `nn.Parameters` of the module. This allows the pose parameters to be optimized with any PyTorch optimizer. Furthermore, this design choice allows `DRR` to be used purely as a differentiable renderer.

In [None]:
#| export
class Registration(nn.Module):
    """Perform automatic 2D-to-3D registration using differentiable rendering.

    The pose (rotations and translations) is held as learnable parameters of
    this module, while rendering is delegated to the wrapped `drr` callable.
    """

    def __init__(
        self,
        drr: DRR,  # Renderer called with (rotations, translations)
        rotations: torch.Tensor,  # Initial guess for the rotational pose parameters
        translations: torch.Tensor,  # Initial guess for the translational pose parameters
    ):
        super().__init__()
        self.drr = drr
        # Register the initial pose guesses so optimizers can update them.
        self.register_parameter("rotations", nn.Parameter(rotations))
        self.register_parameter("translations", nn.Parameter(translations))

    def forward(self):
        """Render with the current pose estimate and return the renderer's output."""
        current_pose = (self.rotations, self.translations)
        return self.drr(*current_pose)

In [None]:
#| hide
import nbdev

# Run nbdev's export step for the project (presumably writes the `#| export`
# cells above to the module named by `default_exp` — confirm against nbdev docs).
nbdev.nbdev_export()