In [1]:
import torch

from src import read_dicom

# Compute on the first GPU when one is visible to torch, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
class Siddon:
    """Siddon-style ray tracer through a regular voxel grid.

    Along each axis the grid has ``volume.shape + 1`` planes, located at
    ``isocenter + k * spacing`` for ``k = 0 .. volume.shape``. A ray is
    parameterized as ``source + alpha * (target - source)``.

    NOTE(review): unlike Siddon's original method, alphamin/alphamax are not
    clamped to [0, 1], so source and target are presumably expected to lie
    outside the grid — confirm with callers.
    """

    def __init__(self, spacing, isocenter, volume, device=None):
        """Store the grid geometry and the voxel values as tensors on ``device``.

        Args:
            spacing: per-axis voxel size (3 values).
            isocenter: position of plane 0 on each axis (3 values).
            volume: 3-D array of voxel values; its ``.shape`` defines the grid.
            device: torch device for every stored tensor. ``None`` selects
                ``cuda:0`` when available, else ``cpu`` (the notebook default).
        """
        if device is None:
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.spacing = torch.tensor(spacing, device=device)
        self.isocenter = torch.tensor(isocenter, device=device)
        # Number of planes per axis = number of voxels per axis + 1.
        self.dims = torch.tensor(volume.shape, device=device) + 1.0
        self.volume = torch.tensor(volume, device=device)
        self.device = device

    def get_alpha(self, planes, source, target):
        """Return the ray parameter alpha at which the given plane indices are crossed.

        ``planes`` holds plane indices per axis and is broadcast against the
        three axes. Axes with ``source == target`` yield +/-inf (or nan).
        """
        return (self.isocenter + planes * self.spacing - source) / (target - source)

    def get_alpha_minmax(self, source, target):
        """Return where the ray enters and leaves the grid, in units of alpha.

        Returns:
            alphamin: the largest per-axis entry alpha (ray enters the grid).
            alphamax: the smallest per-axis exit alpha (ray leaves the grid).
            minis, maxis: per axis, the smaller / larger alpha of the two
                boundary planes (index 0 and index ``dims - 1``).
        """
        # torch.tensor() cannot build from a list that contains a multi-element
        # tensor, so stack the two boundary-plane rows instead; this also keeps
        # them on self.dims' device.
        planes = torch.stack([torch.zeros_like(self.dims), self.dims - 1])
        alphas = self.get_alpha(planes, source, target)
        minis = torch.min(alphas, dim=0).values
        maxis = torch.max(alphas, dim=0).values
        alphamin = torch.max(minis)
        alphamax = torch.min(maxis)
        return alphamin, alphamax, minis, maxis

    def get_coords(self, alpha, source, target):
        """Return the continuous grid coordinates of the ray point at ``alpha``."""
        pxyz = source + alpha * (target - source)
        return (pxyz - self.isocenter) / self.spacing

    def initialize(self, source, target):
        """Compute the entry/exit alphas and the range of plane indices the ray crosses.

        Follows the Siddon/Jacobs index bounds. Per axis, ``minidx`` and
        ``maxidx`` are the lowest and highest plane indices crossed after entry.
        The ray meets ``minidx`` first on axes where source < target, and
        ``maxidx`` first otherwise.

        Returns:
            (alphamin, alphamax, minidx, maxidx)
        """
        alphamin, alphamax, minis, maxis = self.get_alpha_minmax(source, target)
        idxmin = self.get_coords(alphamin, source, target)
        idxmax = self.get_coords(alphamax, source, target)

        ascending = source < target
        # True on the axis whose boundary plane defines the entry / exit point.
        enters_here = alphamin == minis
        exits_here = alphamax == maxis
        ones = torch.ones_like(self.dims)

        # source < target: planes are crossed in increasing index order.
        asc_min = enters_here * ones + (~enters_here) * (idxmin + 1).trunc()
        asc_max = exits_here * (self.dims - 1) + (~exits_here) * idxmax.trunc()
        # source >= target: planes are crossed in decreasing index order. When
        # the ray exits through this axis' boundary, that boundary is plane 0,
        # so the lowest crossed index is 0 (not 1).
        desc_min = (~exits_here) * (idxmax + 1).trunc()
        desc_max = enters_here * (self.dims - 2) + (~enters_here) * idxmin.trunc()

        minidx = ascending * asc_min + (~ascending) * desc_min
        maxidx = ascending * asc_max + (~ascending) * desc_max
        return alphamin, alphamax, minidx, maxidx

    def get_voxel_idx(self, alpha, source, target):
        """Return the voxel containing the ray point at ``alpha`` as a list of 0-d int64 tensors."""
        # int64 indices are accepted by every torch indexing path (int32 is not on older versions).
        idxs = self.get_coords(alpha, source, target).trunc().long()
        return list(idxs)

    def raytrace(self, source, target):
        """Accumulate voxel values along the ray from ``source`` towards ``target``.

        Returns:
            The sum over traversed voxels of (segment length) * (voxel value).
            Segment lengths are differences of alpha, i.e. fractions of
            ``|target - source|``; multiply by ``torch.norm(target - source)``
            to get a physical path integral.
        """
        ascending = source < target
        # +1 on axes where the ray moves towards higher plane indices, -1 otherwise.
        update_idxs = ascending.long() - (source >= target).long()
        # Alpha distance between consecutive planes along each axis.
        update_alpha = self.spacing / torch.abs(target - source)

        # Initialize the loop
        alphamin, alphamax, minidx, maxidx = self.initialize(source, target)
        alphacurr = alphamin

        # The first plane crossed on each axis is the lowest index when moving
        # up and the highest index when moving down. Axes parallel to the ray
        # are never crossed, so push them to +inf and argmin never picks them.
        first_planes = ascending * minidx + (~ascending) * maxidx
        steps = self.get_alpha(first_planes, source, target)
        steps = steps.masked_fill(source == target, float("inf"))
        idx = steps.argmin()  # axis whose next plane comes first
        alphanext = steps[idx]

        alphamid = (alphacurr + alphanext) / 2
        voxel = self.get_voxel_idx(alphamid, source, target)

        step_length = alphanext - alphacurr
        d12 = step_length * self.volume[voxel]
        alphacurr = alphanext.clone()  # clone: `steps` is updated in place below

        # Loop over all voxels that the ray passes through
        while alphacurr < alphamax and not torch.isclose(alphacurr, alphamax):
            voxel[idx] += update_idxs[idx]
            steps[idx] += update_alpha[idx]
            idx = steps.argmin()
            alphanext = steps[idx]
            step_length = alphanext - alphacurr
            d12 += step_length * self.volume[voxel]
            alphacurr = alphanext.clone()

        return d12

In [57]:
class Detector:
    """Flat detector facing an X-ray source.

    The detector plane passes through ``center`` and is perpendicular to the
    line from ``center`` to ``source``. It is sampled on a ``height`` x ``width``
    grid with spacing ``delx`` along rows and ``dely`` along columns.
    """

    def __init__(self, source, center, height, width, delx, dely, device):
        # requires_grad=True lets gradients flow back to the source / center positions.
        self.source = torch.tensor(source, device=device, requires_grad=True)
        self.center = torch.tensor(center, device=device, requires_grad=True)
        self.height = height
        self.width = width
        self.delx = delx
        self.dely = dely
        self.device = device

    def make_xrays(self):
        """Return the 3-D position of every detector pixel, shape (height, width, 3).

        Pixel (i, j) is placed at ``center + t[i] * u + s[j] * v``. Here (u, v)
        span the detector plane, and t, s are integer offsets scaled by
        ``delx`` / ``dely``.
        """
        # Detector-plane normal, pointing from the detector center to the source.
        normal = self.source - self.center
        normal = normal / torch.norm(normal)
        u, v, w = _get_basis(normal)
        assert torch.allclose(w, normal)

        # Construct the detector plane. The offsets are built in the basis' dtype,
        # so integer delx/dely (or float64 geometry) still matmul cleanly with (u, v).
        t = (torch.arange(-self.height // 2, self.height // 2, device=self.device, dtype=u.dtype) + 1) * self.delx
        s = (torch.arange(-self.width // 2, self.width // 2, device=self.device, dtype=u.dtype) + 1) * self.dely
        coefs = torch.cartesian_prod(t, s).reshape(self.height, self.width, 2)
        targets = coefs @ torch.stack([u, v])
        targets += self.center
        return targets


def _get_basis(w):
    t = _get_noncollinear_vector(w)
    t = t / torch.norm(t)
    u = torch.cross(t, w)
    v = torch.cross(u, w)
    return u, v, w


def _get_noncollinear_vector(w):
    t = w.clone()
    i = torch.argmin(torch.abs(w))
    t[i] = 1
    return t

In [58]:
# Load the voxel volume and its per-axis spacing with the project's read_dicom helper.
volume, spacing = read_dicom("../data/cxr/")
isocenter = [0.0, 0.0, 0.0]  # position of plane 0 on every axis

siddon = Siddon(spacing=spacing, isocenter=isocenter, volume=volume)

In [59]:
# Projection geometry: X-ray source position and the center of the detector plane.
source = [-10.0, -10.0, -15.0]
center = [400.0, 375.0, 350.0]

# 25 x 25 detector with unit offsets along both in-plane axes.
detector = Detector(
    source=source,
    center=center,
    height=25,
    width=25,
    delx=1.0,
    dely=1.0,
    device=device,
)
rays = detector.make_xrays()

tensor([ 6.7942e-01, -7.2354e-01, -8.5771e-09], device='cuda:0',
       grad_fn=<CrossBackward>) tensor([ 0.3939,  0.3699, -0.8326], device='cuda:0', grad_fn=<CrossBackward>) tensor([-0.6115, -0.5742, -0.5444], device='cuda:0', grad_fn=<DivBackward0>)


In [61]:
# make_xrays returns one 3-D point per pixel: (height, width, 3).
rays.size()

torch.Size([3, 25, 25])

In [53]:
# Scratch: rebuild the detector's in-plane offset grid, mirroring Detector.make_xrays.
t = (torch.arange(-detector.height // 2, detector.height // 2, device=detector.device) + 1) * detector.delx
s = (torch.arange(-detector.width // 2, detector.width // 2, device=detector.device) + 1) * detector.dely
coef = torch.cartesian_prod(t, s).reshape(detector.height, detector.width, 2)

In [51]:
a = 