In [1]:
import torch
print(torch.__version__)

from src import read_dicom

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


1.12.0


In [2]:
class Detector:
    def __init__(self, source, center, height, width, delx, dely, device):
        self.source = torch.tensor(source, device=device, requires_grad=True)
        self.center = torch.tensor(center, device=device, requires_grad=True)
        self.height = height
        self.width = width
        self.delx = delx
        self.dely = dely
        self.device = device

    def make_xrays(self):

        # Get the detector plane normal vector
        normal = self.source - self.center
        normal = normal / torch.norm(normal)
        u, v, w = _get_basis(normal)
        assert torch.allclose(w, normal)

        # Construt the detector plane
        t = (torch.arange(-self.height // 2, self.height // 2, device=self.device) + 1) * self.delx
        s = (torch.arange(-self.width // 2, self.width // 2, device=self.device) + 1) * self.dely
        coefs = torch.cartesian_prod(t, s).reshape(self.height, self.width, 2)
        targets = coefs @ torch.stack([u, v])
        targets += self.center
        return targets


def _get_basis(w):
    t = _get_noncollinear_vector(w)
    t = t / torch.norm(t)
    u = torch.cross(t, w)
    v = torch.cross(u, w)
    return u, v, w


def _get_noncollinear_vector(w):
    t = w.clone()
    i = torch.argmin(torch.abs(w))
    t[i] = 1
    return t

In [3]:
class Siddon:
    def __init__(self, spacing, isocenter, volume, device=device):
        self.spacing = torch.tensor(spacing, device=device)
        self.isocenter = torch.tensor(isocenter, device=device)
        self.dims = torch.tensor(volume.shape, device=device) + 1.0
        self.volume = torch.tensor(volume, device=device)
        self.device = device

    def get_alpha(self, planes, source, target):
        return (self.isocenter + planes * self.spacing - source) / (target - source)

    def get_alpha_minmax(self, source, target):
        planes = torch.tensor([0, 0, 0], device=self.device)
        alpha0 = (self.isocenter + planes * self.spacing - source) / (target - source)
        planes = self.dims - 1
        alpha1 = (self.isocenter + planes * self.spacing - source) / (target - source)
        alphas = torch.stack([alpha0, alpha1])

        minis = torch.min(alphas, dim=0).values
        maxis = torch.max(alphas, dim=0).values
        alphamin = torch.max(minis, dim=-1).values
        alphamax = torch.min(maxis, dim=-1).values
        return alphamin, alphamax, minis, maxis

    def get_coords(self, alpha, source, target):
        pxyz = source + alpha * (target - source)
        return (pxyz - self.isocenter) / self.spacing

    def initialize(self, source, target):
        alphamin, alphamax, minis, maxis = self.get_alpha_minmax(source, target)
        alphamin = alphamin.expand(3, -1, -1).permute(1, 2, 0)
        alphamax = alphamax.expand(3, -1, -1).permute(1, 2, 0)
        idxmin = self.get_coords(alphamin, source, target)
        idxmax = self.get_coords(alphamax, source, target)

        # source < target
        # get minidx
        a = (alphamin == minis) * torch.ones(3, device=self.device)
        b = (alphamin != minis) * (idxmin + 1).trunc()
        # get maxidx
        c = (alphamax == maxis) * (self.dims - 1)
        d = (alphamax != maxis) * idxmax.trunc()
        # source > target
        # get minidx
        e = (alphamax == maxis) * torch.ones(3, device=self.device)
        f = (alphamax != maxis) * (idxmax + 1).trunc()
        # get maxidx
        g = (alphamin == minis) * (self.dims - 2)
        h = (alphamin != minis) * idxmin.trunc()

        minidx = (source < target) * (a + b) + (source >= target) * (e + f)
        maxidx = (source < target) * (c + d) + (source >= target) * (g + h)

        return alphamin, alphamax, minidx, maxidx

    def get_voxel_idx(self, alpha, source, target):
        idxs = self.get_coords(alpha.expand(3, -1, -1).permute(1, 2, 0), source, target).trunc().long()
        idxs = idxs[:, :, 0] + idxs[:, :, 1] * self.volume.shape[1] + idxs[:, :, 2] * self.volume.shape[2]
        return idxs
    
    def get_voxel(self, voxel_idxs):
        return torch.take(self.volume, voxel_idxs)        

    def raytrace(self, source, target):

        # Get the update conditions
        ones = torch.ones(3, device=self.device, dtype=int)
        update_idxs = (source < target) * ones - (source >= target) * ones
        update_alpha = self.spacing / torch.abs(target - source)

        # Initialize the loop
        alphamin, alphamax, minidx, maxidx = self.initialize(source, target)
        alphacurr = alphamin[:, :, 0].clone()

        # Get the potential next steps in the xyz planes
        steps = self.get_alpha(minidx, source, target)
        idx = steps.argmin(dim=-1)  # Find the smallest step
        alphanext = steps[idx]  # I.e., the next plane

        alphamid = (alphacurr + alphanext) / 2
        voxel = self.get_voxel_idx(alphamid, source, target)

        step_length = alphanext - alphacurr
        d12 = step_length * self.volume[voxel]
        alphacurr = alphanext.clone()

        # Loop over all voxels that the ray passes through
        while alphacurr < alphamax and not torch.isclose(alphacurr, alphamax):
            voxel[idx] += update_idxs[idx]
            steps[idx] += update_alpha[idx]
            idx = steps.argmin()
            alphanext = steps[idx]
            step_length = alphanext - alphacurr
            d12 += step_length * self.volume[voxel]
            alphacurr = alphanext.clone()

        return d12

In [4]:
volume, spacing = read_dicom("../data/cxr/")
isocenter = [0.0, 0.0, 0.0]

source = [-10., -10., -15.]
center = [400., 375., 350.]

detector = Detector(
    source=source,
    center=center,
    height=5,
    width=5,
    delx=5.,
    dely=5.,
    device=device
)

source = detector.source
rays = detector.make_xrays()

siddon = Siddon(spacing, isocenter, volume)

In [53]:
# Get the update conditions
ones = torch.ones(3, dtype=int, device=device)
update_idxs = (source < rays) * ones - (source >= rays) * ones
update_alpha = siddon.spacing / torch.abs(rays - source)

# Initialize the loop
alphamin, alphamax, minidx, maxidx = siddon.initialize(source, rays)
alphamax = alphamax[:, :, 0].clone()
alphacurr = alphamin[:, :, 0].clone()

# Get the potneital next steps in the xyz planes
steps = siddon.get_alpha(minidx, source, rays)
alphanext, idxs = steps.min(dim=-1)
idxs = idxs.unsqueeze(2)

alphamids = (alphacurr + alphanext) / 2
voxelidxs = siddon.get_voxel_idx(alphamids, source, rays)

drr = (alphanext - alphacurr) * siddon.get_voxel(voxelidxs)
alphacurr = alphanext.clone()

In [55]:
steps

tensor([[[0.0409, 0.0402, 0.0469],
         [0.0407, 0.0418, 0.0474],
         [0.0422, 0.0416, 0.0479],
         [0.0420, 0.0432, 0.0485],
         [0.0436, 0.0430, 0.0491]],

        [[0.0406, 0.0406, 0.0469],
         [0.0421, 0.0422, 0.0474],
         [0.0419, 0.0420, 0.0479],
         [0.0417, 0.0418, 0.0485],
         [0.0432, 0.0434, 0.0491]],

        [[0.0402, 0.0410, 0.0469],
         [0.0417, 0.0408, 0.0474],
         [0.0415, 0.0424, 0.0479],
         [0.0430, 0.0422, 0.0485],
         [0.0428, 0.0438, 0.0491]],

        [[0.0416, 0.0414, 0.0469],
         [0.0414, 0.0412, 0.0474],
         [0.0412, 0.0428, 0.0479],
         [0.0427, 0.0426, 0.0485],
         [0.0425, 0.0424, 0.0491]],

        [[0.0413, 0.0418, 0.0469],
         [0.0411, 0.0416, 0.0474],
         [0.0425, 0.0414, 0.0479],
         [0.0423, 0.0430, 0.0485],
         [0.0422, 0.0428, 0.0491]]], device='cuda:0', dtype=torch.float64,
       grad_fn=<DivBackward0>)

In [None]:
torch.rand()