In [None]:
from pathlib import Path

import imageio.v3 as iio
import numpy as np
import pyvista
import torch
from IPython.display import IFrame
from tqdm import tqdm

from diffdrr.data import read
from diffdrr.visualization import _make_camera_frustum_mesh, drr_to_mesh

# Start PyVista's virtual X framebuffer (Xvfb) for off-screen rendering --
# presumably because this notebook runs on a machine without a display; confirm.
pyvista.start_xvfb()

In [None]:
def load(
    datapath,
    proj_rows,
    proj_cols,
    subsample,
    orbits_to_recon=(1, 2, 3),
    geometry_filename="scan_geom_corrected.geom",
    dark_filename="di000000.tif",
    flat_filenames=("io000000.tif", "io000001.tif"),
):
    """Load and preprocess raw projection data for one or more orbits.

    For each ``orbit_id`` in ``orbits_to_recon``, this reads from
    ``datapath / f"tubeV{orbit_id}"``:
    - the scan geometry;
    - the dark field;
    - the flat fields, which are averaged;
    - every ``subsample``-th of the 1200 projections.

    It then applies dark/flat correction and a negative log, and stacks the
    results of all orbits.

    Args:
        datapath: ``pathlib.Path`` of the directory holding the ``tubeV*`` folders.
        proj_rows, proj_cols: shape of each image after ``trafo``.
        subsample: keep every ``subsample``-th projection of each orbit.
        orbits_to_recon: iterable of orbit ids to load.
        geometry_filename: per-orbit geometry file, loaded with ``np.loadtxt``.
        dark_filename: per-orbit dark-field image.
        flat_filenames: per-orbit flat-field images (any number >= 1); averaged.

    Returns:
        (projs, vecs):
        - ``projs`` is a C-contiguous float32 array of shape
          (proj_rows, n_total_projs, proj_cols).
        - ``vecs`` holds the matching geometry rows, shape (n_total_projs, 12).

    Raises:
        ValueError: if ``flat_filenames`` is empty.
    """
    flat_filenames = list(flat_filenames)
    if not flat_filenames:
        raise ValueError("flat_filenames must name at least one flat-field image")

    # Geometry rows kept from each orbit's geometry file
    orbit = range(0, 1200, subsample)
    n_projs_orbit = len(orbit)

    # Projection file indices, reversed due to portrait mode acquisition
    projs_idx = range(1200, 0, -subsample)

    # Collect per-orbit pieces and concatenate once at the end. The empty seed
    # arrays keep the result shapes/dtypes well-defined when no orbit is given.
    projs_parts = [np.zeros((proj_rows, 0, proj_cols), dtype=np.float32)]
    vecs_parts = [np.zeros((0, 12), dtype=np.float32)]

    for orbit_id in orbits_to_recon:
        orbit_datapath = datapath / f"tubeV{orbit_id}"

        # Load the scan geometry
        vecs_orbit = np.loadtxt(orbit_datapath / geometry_filename)
        vecs_parts.append(vecs_orbit[orbit])

        # Load the dark field and average the flat fields. The buffer has one
        # slot per flat-field file, so any number of files is supported.
        dark = trafo(iio.imread(orbit_datapath / dark_filename))
        flat = np.zeros((len(flat_filenames), proj_rows, proj_cols), dtype=np.float32)
        for k, fn in enumerate(flat_filenames):
            flat[k] = trafo(iio.imread(orbit_datapath / fn))
        flat = np.mean(flat, axis=0)

        # Load this orbit's projections
        projs_orbit = np.zeros((n_projs_orbit, proj_rows, proj_cols), dtype=np.float32)
        for k, fn in enumerate(
            tqdm(projs_idx, desc=f"Loading images (tube {orbit_id})")
        ):
            projs_orbit[k] = trafo(iio.imread(orbit_datapath / f"scan_{fn:06}.tif"))

        # Dark/flat correction, then negative log (in place to save memory).
        # NOTE(review): there is no clipping. Pixels where flat == dark, or
        # where the corrected value is <= 0, become inf/NaN here.
        projs_orbit -= dark
        projs_orbit /= flat - dark
        np.log(projs_orbit, out=projs_orbit)
        np.negative(projs_orbit, out=projs_orbit)

        # Permute (n_projs, rows, cols) -> (rows, n_projs, cols), the ASTRA convention
        projs_parts.append(np.transpose(projs_orbit, (1, 0, 2)))

    projs = np.ascontiguousarray(np.concatenate(projs_parts, axis=1))
    vecs = np.concatenate(vecs_parts, axis=0)
    return projs, vecs


def get_source_target_vec(
    vecs: np.ndarray, projs_rows: int = 972, projs_cols: int = 768
):
    """Convert per-projection geometry vectors into per-pixel source/target points.

    Each row of ``vecs`` holds 12 values:
    - the X-ray source position (0:3);
    - the center of the detector plane (3:6);
    - detector basis vector u (6:9);
    - detector basis vector v (9:12).

    This is presumably ASTRA ``cone_vec``-style geometry; confirm.

    Args:
        vecs: numpy array of shape (n_projections, 12).
        projs_rows: detector height in pixels (image height, default 972).
        projs_cols: detector width in pixels (image width, default 768).

    Returns:
        (sources, targets): two lists with one float32 tensor of shape
        (projs_rows, projs_cols, 3) per projection.
        - ``sources`` repeats the source position at every pixel.
        - ``targets`` holds the detector pixel centers, ``det + col * (-u) + row * v``.
        - Both are flipped along the column axis and along the xyz-component axis.
    """
    # Pixel-center offsets symmetric about zero, in detector-pixel units:
    # 4 -> [-1.5, -0.5, 0.5, 1.5]; 3 -> [-1, 0, 1]. Odd sizes are supported too.
    # The grid does not depend on the projection, so it is built once.
    rows = torch.arange(projs_rows) - (projs_rows - 1) / 2
    cols = torch.arange(projs_cols) - (projs_cols - 1) / 2
    i, j = torch.meshgrid(rows, cols, indexing="ij")
    i = i.unsqueeze(-1)  # (H, W, 1) so it broadcasts against 3-vectors
    j = j.unsqueeze(-1)

    sources = []
    targets = []
    for row in vecs:
        vec = torch.from_numpy(row).to(torch.float32)
        src = vec[0:3]  # X-ray source
        det = vec[3:6]  # Center of the detector plane
        u = vec[6:9]  # Basis vector one of the detector plane
        v = vec[9:12]  # Basis vector two of the detector plane

        # Change of basis to u and v, then move the plane's center to `det`
        x = j * -u
        y = i * v
        target = det + x + y
        source = src.expand(target.shape)

        # flip([1, 2]) reverses the column order AND the order of the xyz
        # components (dim 2 has size 3).
        # NOTE(review): presumably this matches the volume's axis ordering; confirm.
        sources.append(source.flip([1, 2]))
        targets.append(target.flip([1, 2]))

    return sources, targets


def trafo(x):
    """Reorient a raw detector image: flip it upside down, then transpose it.

    Dark/flat correction and the negative log are applied by the caller
    (``load``), not here.
    NOTE(review): presumably this maps the portrait-mode TIFF layout onto
    (proj_rows, proj_cols); confirm.
    """
    upside_down = np.flipud(x)
    return upside_down.T

In [None]:
# Extract a triangulated surface of the ground-truth walnut volume with
# marching cubes and save it as a PLY file for later cells.
subject = read("../data/Walnut1/gt.nii.gz")
mesh = (
    drr_to_mesh(subject, "marching_cubes", threshold=2.5e-2)
    .extract_geometry()
    .triangulate()
)
mesh.save("walnut.ply")

In [None]:
# Load orbit 2, keeping every 100th of its 1200 projections, then reload the
# walnut surface mesh saved by the previous cell.
datapath = Path("../data/Walnut1/Projections")
projs, vecs = load(
    datapath,
    proj_rows=972,
    proj_cols=768,
    subsample=100,
    orbits_to_recon=[2],
)
mesh = pyvista.read("walnut.ply")

<TiffTag.fromfile> raised TiffFileError('<tifffile.TiffTag 320 @1493217> invalid value offset 0')
<TiffTag.fromfile> raised TiffFileError('<tifffile.TiffTag 320 @1493217> invalid value offset 0')
<TiffTag.fromfile> raised TiffFileError('<tifffile.TiffTag 320 @1493217> invalid value offset 0')
Loading images (tube 2): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 156.44it/s]


In [None]:
def load(idx, projs, vecs):
    """Build the PyVista objects that visualize projection ``idx``.

    NOTE: this redefines (and shadows) the data-loading
    ``load(datapath, proj_rows, ...)`` from an earlier cell. Re-running that
    earlier data-loading call after this cell will invoke this function instead.

    Args:
        idx: projection index; negative values count from the end.
        projs: projection stack indexed as ``projs[:, idx]``, i.e. laid out
            (rows, n_projs, cols) as produced by the data loader.
        vecs: geometry array with one 12-value row per projection.

    Returns:
        (camera, detector, texture, principal_ray):
        - ``camera``: the camera frustum mesh.
        - ``detector``: the detector-plane grid, with texture coordinates.
        - ``texture``: the uint8 texture of the projection image.
        - ``principal_ray``: a line from the source to the mean detector pixel.
    """
    img = projs[:, idx]
    # Index with a list so the result is always shape (1, 12). The slice
    # `vecs[idx : idx + 1]` would be empty for idx == -1.
    vec = vecs[[idx]]

    # Make the texture: rescale to [0, 255] as uint8. A constant image has
    # zero range, so map it to all zeros instead of dividing by zero.
    lo = img.min()
    span = img.max() - lo
    img = (img - lo) / span if span > 0 else np.zeros_like(img)
    img = (255.0 * img).astype(np.uint8)
    texture = pyvista.numpy_to_texture(img)

    # Get the source and target points for this single projection. Each
    # returned tensor is (H, W, 3); the source is the same at every pixel.
    sources, targets = get_source_target_vec(vec)
    source = sources[0][0, 0].cpu().detach().numpy()
    target = targets[0].cpu().detach().numpy()
    height, width = target.shape[:2]
    principal_ray = pyvista.Line(source, target.mean(axis=0).mean(axis=0))
    camera = _make_camera_frustum_mesh(source, target, size=0.125)

    # Make a mesh for the detector plane
    detector = pyvista.StructuredGrid(
        target[..., 0],
        target[..., 1],
        target[..., 2],
    )
    detector.add_field_data([height], "height")
    detector.add_field_data([width], "width")
    detector.texture_map_to_plane(
        origin=target[-1, 0],
        point_u=target[-1, -1],
        point_v=target[0, 0],
        inplace=True,
    )

    return camera, detector, texture, principal_ray

In [None]:
# Render the walnut surface together with three of the loaded projections.
# The first view is drawn opaque; the others are drawn at 75% opacity.
pl = pyvista.Plotter()
pl.add_mesh(mesh)

for idx, projection in enumerate([0, 2, -2]):
    camera, detector, texture, principal_ray = load(projection, projs, vecs)
    opacity = 1.0 if idx == 0 else 0.75
    pl.add_mesh(camera, show_edges=True, line_width=7.5, opacity=opacity)
    pl.add_mesh(detector, texture=texture, opacity=opacity)
    pl.add_mesh(principal_ray, color="hotpink", line_width=5, opacity=opacity)

# Save the interactive scene as a standalone HTML file
pl.export_html("render.html")