In [None]:
#| default_exp renderers

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import torch
from torch.nn.functional import grid_sample

## Siddon's Method

DRRs are generated by modeling the geometry of an idealized projectional radiography system.
Let $\mathbf s \in \mathbb R^3$ be the X-ray source and $\mathbf p \in \mathbb R^3$ be a target pixel on the detector plane.
Then, $R(\alpha) = \mathbf s + \alpha (\mathbf p - \mathbf s)$ is a ray that originates from $\mathbf s$ ($\alpha=0$), passes through the imaged volume, and hits the detector plane at $\mathbf p$ ($\alpha=1$).
The proportion of energy attenuation experienced by the X-ray at the time it reaches pixel $\mathbf p$ is given by the following line integral:

\begin{equation}
    E(R) = \|\mathbf p - \mathbf s\|_2 \int_0^1 \mathbf V \left( \mathbf s + \alpha (\mathbf p - \mathbf s) \right) \, \mathrm d\alpha \,,
\end{equation}

where $\mathbf V : \mathbb R^3 \mapsto \mathbb R$ is the imaged volume.
The units term $\|\mathbf p - \mathbf s\|_2$ serves to cancel out the units of $\mathbf V(\cdot)$, reciprocal length, such that the final proportion $E$ is unitless.
For DRR synthesis, $\mathbf V$ is approximated by a discrete 3D CT volume, and the first equation becomes

\begin{equation}
    E(R) = \|\mathbf p - \mathbf s\|_2 \sum_{m=1}^{M-1} (\alpha_{m+1} - \alpha_m) \mathbf V \left[ \mathbf s + \frac{\alpha_{m+1} + \alpha_m}{2} (\mathbf p - \mathbf s) \right] \,,
\end{equation}

where $\alpha_m$ parameterizes the locations where ray $R$ intersects one of the orthogonal planes comprising the CT volume, and $M$ is the number of such intersections.

Siddon's method provides a parametric approach for identifying the plane intersections $\{\alpha_m\}_{m=1}^M$.
Let $\Delta X$ be the CT voxel size in the $x$-direction and $b_x$ be the location of the $0$-th plane in this direction.
Then the intersection of ray $R$ with the $i$-th plane in the $x$-direction is given by
\begin{equation}
    \alpha_x(i) = \frac{b_x + i \Delta X - \mathbf s_x}{\mathbf p_x - \mathbf s_x} ,
\end{equation}
with analogous expressions for $\alpha_y(\cdot)$ and $\alpha_z(\cdot)$.

We can use this equation to compute the values $\mathbf \alpha_x$ for all the intersections between $R$ and the planes in the $x$-direction:
\begin{equation}
    \mathbf\alpha_x = \{ \alpha_x(i_{\min}), \dots, \alpha_x(i_{\max}) \} ,
\end{equation}
where $i_{\min}$ and $i_{\max}$ denote the first and last intersections of $R$ with the $x$-direction planes.

Defining $\mathbf\alpha_y$ and $\mathbf\alpha_z$ analogously, we construct the array
\begin{equation}
    \mathbf\alpha = \mathrm{sort}(\mathbf\alpha_x, \mathbf\alpha_y, \mathbf\alpha_z) ,
\end{equation}
which contains $M$ values of $\alpha$ parameterizing the intersections between $R$ and the orthogonal $x$-, $y$-, and $z$-directional planes. 
We substitute the values in the sorted set $\mathbf\alpha$ into the discretized line integral (the second equation above) to evaluate $E(R)$, which corresponds to the intensity of pixel $\mathbf p$ in the synthesized DRR.

In [None]:
#| export
class Siddon(torch.nn.Module):
    """Differentiable X-ray renderer implemented with Siddon's method for exact raytracing.

    For every ray from `source` to `target`, the renderer finds the parametric
    intersections with the planes of the voxel grid, samples the voxel at the
    midpoint of each pair of adjacent intersections, and sums the sampled values
    weighted by the parametric length of each segment.
    """

    def __init__(
        self,
        mode: str = "nearest",  # Interpolation mode for grid_sample
        stop_gradients_through_grid_sample: bool = False,  # Apply torch.no_grad when calling grid_sample
        filter_intersections_outside_volume: bool = True,  # Use alphamin/max to filter the intersections
        eps: float = 1e-8,  # Small constant to avoid div by zero errors
    ):
        super().__init__()
        self.mode = mode
        self.stop_gradients_through_grid_sample = stop_gradients_through_grid_sample
        self.filter_intersections_outside_volume = filter_intersections_outside_volume
        self.eps = eps

    def dims(self, volume):
        """Return the shape of `volume` as a tensor with the dtype and device of `volume`."""
        return torch.tensor(volume.shape).to(volume)

    def forward(
        self,
        volume,
        source,
        target,
        align_corners=False,
        mask=None,
    ):
        """Render `volume` along the rays running from `source` to `target`.

        `volume` is a 3D tensor. `source` and `target` hold one XYZ point per ray
        in their last dimension, expressed in voxel units (the grid planes sit at
        `0, 1, ..., dims` along each axis). `align_corners` is forwarded to
        `grid_sample`. `mask`, if given, is a label volume sampled at the same
        points as `volume`; it is used to split the rendering into one channel
        per label.

        Returns a tensor of shape `(B, 1, D)` when `mask` is None, and
        `(B, C, D)` with `C = mask.max() + 1` otherwise, where `(B, D)` are the
        leading dimensions of the per-ray samples.
        """
        dims = self.dims(volume)

        # Calculate the intersections of each ray with the planes comprising the CT volume
        alphas = _get_alphas(
            source,
            target,
            dims,
            self.eps,
            self.filter_intersections_outside_volume,
        )

        # Calculate the midpoint of every pair of adjacent intersections
        # These midpoints lie exclusively in a single voxel
        alphamid = (alphas[..., 0:-1] + alphas[..., 1:]) / 2

        # Get the XYZ coordinate of each midpoint (normalized to [-1, +1]^3)
        xyzs = _get_xyzs(alphamid, source, target, dims, self.eps)

        # Use torch.nn.functional.grid_sample to lookup the values of each intersected voxel
        if self.stop_gradients_through_grid_sample:
            with torch.no_grad():
                img = _get_voxel(volume, xyzs, self.mode, align_corners=align_corners)
        else:
            img = _get_voxel(volume, xyzs, self.mode, align_corners=align_corners)

        # Weight each intersected voxel by the length of the ray's intersection with the voxel
        intersection_length = torch.diff(alphas, dim=-1)
        img = img * intersection_length

        # Without a mask, collapse all samples along each ray into a single channel
        if mask is None:
            return img.sum(dim=-1).unsqueeze(1)

        # Thanks to @Ivan for the clutch assist w/ pytorch tensor ops
        # https://stackoverflow.com/questions/78323859/broadcast-pytorch-array-across-channels-based-on-another-array/78324614#78324614
        B, D, _ = img.shape
        C = int(mask.max().item() + 1)

        # Labels are categorical: always use nearest-neighbour lookup so that an
        # interpolating `self.mode` cannot blend two labels into a third one
        channels = _get_voxel(mask, xyzs, "nearest", align_corners=align_corners).long()

        # Accumulate every sample into the channel given by its label
        per_label = torch.zeros(B, C, D, dtype=img.dtype, device=img.device)
        return per_label.scatter_add_(
            1, channels.transpose(-1, -2), img.transpose(-1, -2)
        )

In [None]:
#| export
def _get_alphas(source, target, dims, eps, filter_intersections_outside_volume):
    """Calculates the parametric intersections of each ray with the planes of the CT volume."""
    # Parameterize the parallel XYZ planes that comprise the CT volumes
    alphax = torch.arange(dims[0] + 1).to(source)
    alphay = torch.arange(dims[1] + 1).to(source)
    alphaz = torch.arange(dims[2] + 1).to(source)

    # Calculate the parametric intersection of each ray with every plane
    sx, sy, sz = source[..., 0:1], source[..., 1:2], source[..., 2:3]
    tx, ty, tz = target[..., 0:1], target[..., 1:2], target[..., 2:3]
    alphax = (alphax.expand(len(source), 1, -1) - sx) / (tx - sx + eps)
    alphay = (alphay.expand(len(source), 1, -1) - sy) / (ty - sy + eps)
    alphaz = (alphaz.expand(len(source), 1, -1) - sz) / (tz - sz + eps)
    alphas = torch.cat([alphax, alphay, alphaz], dim=-1)

    # Sort the intersections
    alphas = torch.sort(alphas, dim=-1).values
    if filter_intersections_outside_volume:
        alphas = _filter_intersections_outside_volume(alphas, source, target, dims, eps)
    return alphas


def _filter_intersections_outside_volume(alphas, source, target, dims, eps):
    """Drop the intersections that fall outside of the volume for every ray.

    An entry along the last dimension of `alphas` is kept if, for at least one
    ray, its value lies within that ray's `[alphamin, alphamax]` interval as
    returned by `_get_alpha_minmax`.
    """
    lower, upper = _get_alpha_minmax(source, target, dims, eps)
    inside = (alphas >= lower) & (alphas <= upper)
    # Collapse the two leading (ray) dimensions: keep a column if any ray needs it
    keep = inside.any(dim=0).any(dim=0)
    return alphas[..., keep]


def _get_alpha_minmax(source, target, dims, eps):
    """Calculate the first and last intersections of each ray with the volume."""
    sdd = target - source + eps

    alpha0 = (torch.zeros(3).to(source) - source) / sdd
    alpha1 = ((dims + 1).to(source) - source) / sdd
    alphas = torch.stack([alpha0, alpha1])

    alphamin = alphas.min(dim=0).values.max(dim=-1).values.unsqueeze(-1)
    alphamax = alphas.max(dim=0).values.min(dim=-1).values.unsqueeze(-1)

    alphamin = torch.where(alphamin < 0.0, 0.0, alphamin)
    alphamax = torch.where(alphamax > 1.0, 1.0, alphamax)
    return alphamin, alphamax


def _get_xyzs(alpha, source, target, dims, eps):
    """Given a set of rays and parametric coordinates, calculates the XYZ coordinates."""
    # Get the world coordinates of every point parameterized by alpha
    xyzs = (
        source.unsqueeze(-2)
        + alpha.unsqueeze(-1) * (target - source + eps).unsqueeze(2)
    ).unsqueeze(1)

    # Normalize coordinates to be in [-1, +1] for grid_sample
    xyzs = 2 * xyzs / dims - 1
    return xyzs


def _get_voxel(volume, xyzs, mode, align_corners):
    """Wraps torch.nn.functional.grid_sample to sample a volume at XYZ coordinates."""
    batch_size = len(xyzs)
    voxels = grid_sample(
        input=volume.permute(2, 1, 0)[None, None].expand(batch_size, -1, -1, -1, -1),
        grid=xyzs,
        mode=mode,
        align_corners=align_corners,
    )[:, 0, 0]
    return voxels

## Trilinear interpolation

Instead of computing the exact line integral over the voxel grid (i.e., Siddon's method), we can sample the volume at evenly spaced points along each ray using trilinear interpolation.

Now, the rendering equation is
\begin{equation}
    E(R) = \|\mathbf p - \mathbf s\|_2\frac{\alpha_{\max} - \alpha_{\min}}{M-1} \sum_{m=1}^{M} \mathbf V \left[ \mathbf s + \alpha_m (\mathbf p - \mathbf s) \right] \,,
\end{equation}
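where $\mathbf V[\cdot]$ is the trilinear interpolation function, $M$ is the number of points sampled per ray, and $\alpha_1, \dots, \alpha_M$ are evenly spaced between $\alpha_{\min}$ and $\alpha_{\max}$ (inclusive), the bounds of the sampled portion of the ray.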

In [None]:
#| export
class Trilinear(torch.nn.Module):
    """Differentiable X-ray renderer implemented with trilinear interpolation.

    Instead of visiting every voxel crossed by a ray, the renderer samples the
    volume at `n_points` evenly spaced locations along each ray and sums the
    samples scaled by the parametric step size.
    """

    def __init__(
        self,
        mode: str = "bilinear",  # Interpolation mode for grid_sample
        eps: float = 1e-8,  # Small constant to avoid div by zero errors
    ):
        super().__init__()
        self.mode = mode
        self.eps = eps

    def dims(self, volume):
        """Return the shape of `volume` as a tensor with the dtype and device of `volume`."""
        return torch.tensor(volume.shape).to(volume)

    def forward(
        self,
        volume,
        source,
        target,
        n_points=500,
        align_corners=False,
        mask=None,
        alphamin=None,
        alphamax=None,
    ):
        """Render `volume` along the rays running from `source` to `target`.

        `volume` is a 3D tensor. `source` and `target` hold one XYZ point per ray
        in their last dimension. `n_points` samples are taken per ray between the
        parametric bounds `alphamin` and `alphamax`; a bound that is not supplied
        is computed with `_get_alpha_minmax` and reduced over all rays (the
        smallest `alphamin` / the largest `alphamax`). `align_corners` is
        forwarded to `grid_sample`. `mask`, if given, is a label volume used to
        split the rendering into one channel per label.

        Returns a tensor of shape `(B, 1, D)` when `mask` is None, and
        `(B, C, D)` with `C = mask.max() + 1` otherwise, where `(B, D)` are the
        leading dimensions of the per-ray samples.
        """
        dims = self.dims(volume)

        # Sample points along the rays and rescale to [-1, 1]
        if alphamin is None or alphamax is None:
            auto_min, auto_max = _get_alpha_minmax(source, target, dims, self.eps)
            # Only fill in the bound(s) the caller left unspecified
            if alphamin is None:
                alphamin = auto_min.min()
            if alphamax is None:
                alphamax = auto_max.max()
        alphas = torch.linspace(0, 1, n_points)[None, None].to(volume)
        alphas = alphas * (alphamax - alphamin) + alphamin

        # Render the DRR
        # Get the XYZ coordinate of each alpha, normalized for grid_sample
        xyzs = _get_xyzs(alphas, source, target, dims, self.eps)

        # Sample the volume with trilinear interpolation
        img = _get_voxel(volume, xyzs, self.mode, align_corners=align_corners)

        # Multiply by the step size to compute the rectangular rule for integration
        step_size = (alphamax - alphamin) / (n_points - 1)
        img = img * step_size

        # Without a mask, collapse all samples along each ray into a single channel
        if mask is None:
            return img.sum(dim=-1).unsqueeze(1)

        B, D, _ = img.shape
        C = int(mask.max().item() + 1)

        # Labels are categorical: sample them with nearest-neighbour lookup.
        # Interpolating (e.g., with the default "bilinear" mode) and truncating
        # with .long() would invent labels at the boundary between structures.
        channels = _get_voxel(mask, xyzs, "nearest", align_corners=align_corners).long()

        # Accumulate every sample into the channel given by its label
        per_label = torch.zeros(B, C, D, dtype=img.dtype, device=img.device)
        return per_label.scatter_add_(
            1, channels.transpose(-1, -2), img.transpose(-1, -2)
        )

In [None]:
#| hide
# Run nbdev's export step so the `#| export` cells of this notebook are written
# to the module named by the `#| default_exp` directive at the top of the notebook
import nbdev

nbdev.nbdev_export()