In [1]:
import matplotlib.pyplot as plt
import torch
from tqdm.autonotebook import tqdm

from src import read_dicom, Detector, Siddon

In [2]:
# Load the DICOM series from disk together with the spacing it reports
volume, spacing = read_dicom("../data/cxr/")
volume = torch.tensor(volume)

# Siddon ray tracer over the volume, with the isocenter at the world origin
siddon = Siddon(
    volume=volume,
    spacing=spacing,
    isocenter=[0.0, 0.0, 0.0],
)

In [3]:
# X-ray source position and the point the detector is centred on
source = [-10., -12., -15.]
target = [400., 375., 350.]

# A 5 x 5 detector; delx/dely are presumably the pixel spacing per axis (TODO confirm)
detector = Detector(
    source=source,
    center=target,
    width=5,
    height=5,
    dely=1.,
    delx=1.,
)

# One target point per detector pixel
targets = detector.make_xrays()
targets.shape  # (detector.height, detector.width, 3)

torch.Size([5, 5, 3])

In [4]:
# Maybe don't need
def make_big(tensor, height=detector.height, width=detector.width):
    """Copy a 3-vector to a tensor of shape (3, height, width).

    Every ``[:, i, j]`` slice of the result equals ``tensor``. ``repeat`` makes a
    real copy, not a view.

    Note: the defaults are captured from ``detector`` when this cell runs.
    Redefining ``detector`` later does not update them.

    Args:
        tensor: a 1-D tensor with 3 elements.
        height: number of rows in the output.
        width: number of columns in the output.

    Returns:
        Tensor of shape (3, height, width).
    """
    # repeat -> (height, width, 3). permute moves the vector axis to the front
    # while keeping (height, width) order. The old transpose(0, 2) produced
    # (3, width, height).
    return tensor.repeat(height, width, 1).permute(2, 0, 1)


# Make sources big!
# sources = make_big(detector.source)
# sources

In [5]:
# Per-ray displacement from the source to each detector target, shape
# (height, width, 3). Despite the name `sdd`, this is a vector, not a scalar distance.
sdd = targets - detector.source
sdd.shape

torch.Size([5, 5, 3])

In [6]:
### GET ALPHA MINMAX
# A point on each ray is source + alpha * sdd, so alpha = 0 is the source and
# alpha = 1 is the detector target. Find the alpha at which every ray meets the
# first and last bounding plane of each axis.

# Get intersection with first planes (plane index 0 on every axis)
planes = torch.tensor([0, 0, 0])
alpha0 = (siddon.isocenter + planes * siddon.spacing - detector.source) / sdd

# Get intersection with last planes (plane index dims - 1 on every axis)
planes = siddon.dims - 1
alpha1 = (siddon.isocenter + planes * siddon.spacing - detector.source) / sdd

# Merge intersections: (2, height, width, 3).
# Named alpha_bounds so it does not collide with the per-plane `alphas` built later.
alpha_bounds = torch.stack([alpha0, alpha1])
print(alpha_bounds.shape)

# Per axis, the earlier (entry) and later (exit) boundary crossing
minis = alpha_bounds.min(dim=0).values
maxis = alpha_bounds.max(dim=0).values
print(maxis.shape)

# The ray is inside the volume only after it has entered on every axis and
# before it leaves on any axis. Following Siddon (1985), restrict this to the
# source->target segment by clamping to [0, 1]. The clamp matters when the
# source or target lies inside the volume.
# NOTE: alphamin >= alphamax means the ray does not pass through the volume.
alphamin = minis.max(dim=-1, keepdim=True).values.clamp(min=0.0)
alphamax = maxis.min(dim=-1, keepdim=True).values.clamp(max=1.0)
print(alphamax.shape)

torch.Size([2, 5, 5, 3])
torch.Size([5, 5, 3])
torch.Size([5, 5, 1])


In [7]:
### GET IDX MINMAX
# For each ray and each axis, find the range of plane indices [minidx, maxidx]
# crossed between alphamin and alphamax. Uses the index bounds of
# Jacobs et al. (1998) for Siddon's algorithm.

# Fractional plane index where each ray enters (alphamin) and exits (alphamax)
idxmin = ((alphamin * sdd) + detector.source - siddon.isocenter) / siddon.spacing
idxmax = ((alphamax * sdd) + detector.source - siddon.isocenter) / siddon.spacing

# True where the ray enters / exits through this axis's boundary plane
enters_here = alphamin == minis
exits_here = alphamax == maxis

# source < target: indices increase along the ray. The entry plane (index 0)
# is skipped; the exit plane (index dims - 1) is kept.
minidx_fwd = torch.where(enters_here, torch.ones_like(idxmin), (idxmin + 1).trunc())
maxidx_fwd = torch.where(exits_here, (siddon.dims - 1).to(idxmax.dtype), idxmax.trunc())

# source >= target: indices decrease along the ray. The entry plane
# (index dims - 1) is skipped; the exit plane (index 0) is kept. Jacobs et al.
# use i_min = 0 here; 1 would drop the exit plane, unlike the branch above.
minidx_bwd = torch.where(exits_here, torch.zeros_like(idxmax), (idxmax + 1).trunc())
maxidx_bwd = torch.where(enters_here, (siddon.dims - 2).to(idxmin.dtype), idxmin.trunc())

# Choose per axis by ray direction. In the forward pass, torch.where (unlike
# mask * value) never turns a non-finite unselected value into NaN.
forward = detector.source < targets
minidx = torch.where(forward, minidx_fwd, minidx_bwd)
maxidx = torch.where(forward, maxidx_fwd, maxidx_bwd)
print(minidx.shape)

torch.Size([5, 5, 3])


In [8]:
# Get the alphas for every plane-index triple in the volume.
# NOTE(review): arange(0, n) with n = dims - 1 covers plane indices 0..dims-2
# only. The last plane (index dims - 1, used for alpha1 above) is excluded;
# confirm this is intended.
nx, ny, nz = (siddon.dims - 1).to(int).tolist()
idxx = torch.arange(0, nx)
idxy = torch.arange(0, ny)
idxz = torch.arange(0, nz)
idxs = torch.cartesian_prod(idxx, idxy, idxz).reshape(nx, ny, nz, 3)
print(idxs.shape)

# Convert plane indices to world coordinates the same way the alpha-minmax cell
# does (isocenter + index * spacing). Otherwise these alphas are not comparable
# with alphamin / alphamax. Then broadcast against every ray:
# (nx, ny, nz, 1, 1, 3) / (height, width, 3) -> (nx, ny, nz, height, width, 3).
# NOTE: this materialises nx*ny*nz*height*width*3 values. An axis's alpha depends
# only on that axis's index, so a per-axis computation would be far smaller.
planes_xyz = siddon.isocenter + idxs * siddon.spacing
alphas = (planes_xyz - detector.source)[..., None, None, :] / sdd
print(alphas.shape)

torch.Size([512, 512, 133, 3])
torch.Size([512, 512, 133, 5, 5, 3])


In [28]:
# Expect (detector height, detector width, 3): one index per ray and axis
minidx.size()

torch.Size([5, 5, 3])

In [26]:
# Per-ray, per-axis lower bound of the plane-index range computed in the IDX MINMAX cell
minidx

tensor([[[10.,  6.,  1.],
         [10.,  6.,  1.],
         [10.,  6.,  1.],
         [10.,  6.,  1.],
         [10.,  6.,  1.]],

        [[10.,  6.,  1.],
         [10.,  6.,  1.],
         [10.,  6.,  1.],
         [10.,  6.,  1.],
         [10.,  6.,  1.]],

        [[10.,  6.,  1.],
         [10.,  6.,  1.],
         [10.,  6.,  1.],
         [10.,  6.,  1.],
         [10.,  6.,  1.]],

        [[10.,  6.,  1.],
         [10.,  6.,  1.],
         [10.,  6.,  1.],
         [10.,  6.,  1.],
         [10.,  6.,  1.]],

        [[10.,  6.,  1.],
         [10.,  6.,  1.],
         [10.,  6.,  1.],
         [10.,  6.,  1.],
         [10.,  6.,  1.]]], dtype=torch.float64, grad_fn=<AddBackward0>)

In [27]:
# Per-ray, per-axis upper bound of the plane-index range computed in the IDX MINMAX cell
maxidx

tensor([[[512., 483., 127.],
         [512., 483., 126.],
         [512., 483., 126.],
         [512., 483., 125.],
         [512., 483., 125.]],

        [[512., 481., 126.],
         [512., 481., 126.],
         [512., 481., 125.],
         [512., 481., 125.],
         [512., 481., 125.]],

        [[512., 479., 126.],
         [512., 479., 126.],
         [512., 479., 125.],
         [512., 479., 125.],
         [512., 479., 124.]],

        [[512., 477., 126.],
         [512., 477., 125.],
         [512., 477., 125.],
         [512., 477., 125.],
         [512., 477., 124.]],

        [[512., 476., 126.],
         [512., 476., 125.],
         [512., 476., 125.],
         [512., 476., 124.],
         [512., 476., 124.]]], dtype=torch.float64, grad_fn=<AddBackward0>)