Visualizing the 3D geometry of the X-ray detector in `DiffDRR` is a helpful sanity check and a useful debugging aid. `DiffDRR` renders its setup with `PyVista`. The required dependencies are `pyvista`, `trame`, and `vtk`.

The 3D visualization functions in `DiffDRR` perform the following:
- Extract a mesh from your CT volume
- Plot a pyramid frustum to visualize the camera pose
- Plot the detector plane with the DRR embedded as a texture
- Draw the principal ray from the X-ray source to the detector plane
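
To make the frustum and principal ray concrete, here is a minimal sketch of the geometry they represent. It is not the `DiffDRR` implementation: `frustum_vertices` is a hypothetical helper, and the canonical pose it uses (isocenter at the origin, source and detector on opposite sides along the z-axis) is an assumption chosen only for illustration. It does rely on `sdr` being half the source-to-detector distance, which matches the `sdr=focal_len / 2` used later in this notebook.

```python
import math


def frustum_vertices(sdr, height, delx, width=None):
    """Vertices of the camera pyramid in an assumed canonical pose.

    Hypothetical helper for illustration only. Assumes the isocenter is at
    the origin and the principal ray runs along the +z axis.
    """
    width = height if width is None else width
    half_w = width * delx / 2.0   # half the detector width, in units of delx
    half_h = height * delx / 2.0  # half the detector height, in units of delx

    source = (0.0, 0.0, -sdr)  # apex of the pyramid (X-ray source)
    center = (0.0, 0.0, sdr)   # where the principal ray meets the detector
    corners = [                # base of the pyramid (detector plane corners)
        (-half_w, -half_h, sdr),
        (half_w, -half_h, sdr),
        (half_w, half_h, sdr),
        (-half_w, half_h, sdr),
    ]
    return source, center, corners


# Same detector parameters as the DRR module in this notebook
source, center, corners = frustum_vertices(sdr=510.0, height=200, delx=2.0)

# The principal ray spans the full source-to-detector distance (2 * sdr)
principal_ray_length = math.dist(source, center)
print(principal_ray_length)
```

In the actual visualization, the camera pose (the rotations and translations passed to `img_to_mesh`) moves this pyramid into place around the CT volume.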

Mesh extraction currently supports two methods:
- MarchingCubes
- [SurfaceNets](https://www.kitware.com/really-fast-isocontouring/)

In the future, we plan to integrate [SurfaceNets](https://www.kitware.com/really-fast-isocontouring/) with [TotalSegmentator](https://github.com/wasserth/TotalSegmentator) so that CT meshes can be rendered from label maps.

In [None]:
import pyvista

pyvista.set_jupyter_backend("trame")

In [None]:
#| code-fold: true
import torch

from diffdrr.data import load_example_ct
from diffdrr.drr import DRR
from diffdrr.visualization import drr_to_mesh, img_to_mesh

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Read in the volume and get the isocenter
volume, spacing = load_example_ct()
bx, by, bz = torch.tensor(volume.shape) * torch.tensor(spacing) / 2
focal_len = 1020.0

# Initialize the DRR module for generating synthetic X-rays
drr = DRR(volume, spacing, sdr=focal_len / 2, height=200, delx=2.0).to(device)

In [None]:
plotter = pyvista.Plotter()

# Make a mesh from the CT volume
ct = drr_to_mesh(drr, "surface_nets", threshold=1150, verbose=False)
plotter.add_mesh(ct)

# Make a mesh from the camera and detector plane
rotations = torch.tensor([[torch.pi, 0.0, torch.pi / 2]], device=device)
translations = torch.tensor([[bx - focal_len / 3, by, bz]], device=device)
camera, detector, texture, principal_ray = img_to_mesh(
    drr, rotations, translations, "euler_angles", "ZYX"
)
plotter.add_mesh(camera, show_edges=True, line_width=1.5)
plotter.add_mesh(principal_ray, color="lime", line_width=3)
plotter.add_mesh(detector, texture=texture)

# Make the plot
plotter.add_axes()
plotter.add_bounding_box()

# Soon, the commented-out command below will natively export a PyVista scene as HTML
# plotter.show(jupyter_backend="html")
plotter.export_html("render.html")

In [None]:
# For now, we manually export the HTML and render it in a new cell
from IPython.display import IFrame

IFrame("render.html", height=500, width=749)