Skip to content

Commit

Permalink
Add support for 3D tomographic projection with astra (#427)
Browse files Browse the repository at this point in the history
Co-authored-by: Michael McCann <mccann@lanl.gov>
Co-authored-by: Li-Ta (Ollie) Lo <ollie@lanl.gov>
Co-authored-by: Brendt Wohlberg <brendt@ieee.org>
  • Loading branch information
4 people committed Jul 11, 2023
1 parent 523d662 commit a07809c
Show file tree
Hide file tree
Showing 13 changed files with 322 additions and 66 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_examples.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,4 @@ jobs:
# Run example test
- name: Run example test
run: |
${GITHUB_WORKSPACE}/examples/scriptcheck.sh -e -d -t
${GITHUB_WORKSPACE}/examples/scriptcheck.sh -e -d -t -g
11 changes: 6 additions & 5 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,22 @@ SCICO Release Notes
Version 0.0.4 (unreleased)
----------------------------

New `Function` class for representing array-to-array mappings with more than
Add new `Function` class for representing array-to-array mappings with more than
one input.
New methods and a function for computing Jacobian-vector products for `Operator`
Add new methods and a function for computing Jacobian-vector products for `Operator`
objects.
New proximal ADMM solvers.
New ADMM subproblem solvers for problems involving a sum-of-convolutions
Add new proximal ADMM solvers.
Add new ADMM subproblem solvers for problems involving a sum-of-convolutions
operator.
• Extend support for other ML models including UNet, ODP and MoDL.
• Add functionality for training Flax-based ML models and for data generation.
• Enable diagnostics for ML training loops.
• Change required packages and version numbers, including more recent version
for `flax`.
New methods and a function for computing Jacobian-vector products for
Add new methods and a function for computing Jacobian-vector products for
`Operator` objects.
• Drop support for Python 3.7.
• Add support for 3D tomographic projection with the ASTRA Toolbox.



Expand Down
2 changes: 1 addition & 1 deletion data
3 changes: 3 additions & 0 deletions docs/source/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Computed Tomography
examples/ct_abel_tv_admm
examples/ct_abel_tv_admm_tune
examples/ct_astra_noreg_pcg
examples/ct_astra_3d_tv_admm
examples/ct_astra_tv_admm
examples/ct_astra_weighted_tv_admm
examples/ct_svmbir_tv_multi
Expand Down Expand Up @@ -126,6 +127,7 @@ Total Variation
examples/ct_abel_tv_admm
examples/ct_abel_tv_admm_tune
examples/ct_astra_tv_admm
examples/ct_astra_3d_tv_admm
examples/ct_astra_weighted_tv_admm
examples/ct_svmbir_tv_multi
examples/deconv_circ_tv_admm
Expand Down Expand Up @@ -193,6 +195,7 @@ ADMM
examples/ct_abel_tv_admm
examples/ct_abel_tv_admm_tune
examples/ct_astra_tv_admm
examples/ct_astra_3d_tv_admm
examples/ct_astra_weighted_tv_admm
examples/ct_svmbir_tv_multi
examples/ct_svmbir_ppp_bm3d_admm_cg
Expand Down
7 changes: 4 additions & 3 deletions docs/source/team.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ Emeritus Developers
Contributors
------------

- `Weijie Gan <https://github.com/wjgancn>`_ (Non-blind variant of DnCNN)
- `Oleg Korobkin <https://github.com/korobkin>`_ (BlockArray improvements)
- `Yanpeng Yuan <https://github.com/yanpeng7>`_ (ASTRA interface improvements)
- `Saurav Maheshkar <https://github.com/SauravMaheshkar>`_ (Improvements to pre-commit configuration)
- `Andrew Leong <https://scholar.google.com/citations?user=-2wRWbcAAAAJ&hl=en>`_ (Improvements to optics module documentation)
- `Weijie Gan <https://github.com/wjgancn>`_ (Non-blind variant of DnCNN)
- `Saurav Maheshkar <https://github.com/SauravMaheshkar>`_ (Improvements to pre-commit configuration)
- `Yanpeng Yuan <https://github.com/yanpeng7>`_ (ASTRA interface improvements)
- `Li-Ta (Ollie) Lo <https://github.com/ollielo>`_ (ASTRA interface improvements)
35 changes: 21 additions & 14 deletions examples/scriptcheck.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,23 @@ Usage: $SCRIPT [-h] [-d]
[-e] Display excerpt of error message on failure
[-d] Skip tests involving additional data downloads
[-t] Skip tests related to learned model training
[-g] Skip tests that need a GPU
EOF
)

OPTIND=1
DISPLAY_ERROR=0
SKIP_DOWNLOAD=0
SKIP_TRAINING=0
while getopts ":hedt" opt; do
SKIP_GPU=0
while getopts ":hedtg" opt; do
case $opt in
h) echo "$USAGE"; exit 0;;
e) DISPLAY_ERROR=1;;
d) SKIP_DOWNLOAD=1;;
t) SKIP_TRAINING=1;;
\?) echo "Error: invalid option -$OPTARG" >&2
h) echo "$USAGE"; exit 0;;
e) DISPLAY_ERROR=1;;
d) SKIP_DOWNLOAD=1;;
t) SKIP_TRAINING=1;;
g) SKIP_GPU=1;;
\?) echo "Error: invalid option -$OPTARG" >&2
echo "$USAGE" >&2
exit 1
;;
Expand Down Expand Up @@ -74,15 +77,19 @@ for f in $SCRIPTPATH/scripts/*.py; do

# Skip problem cases.
if [ $SKIP_DOWNLOAD -eq 1 ] && grep -q '_microscopy' <<< $f; then
printf "%s\n" skipped
continue
printf "%s\n" skipped
continue
fi
if [ $SKIP_TRAINING -eq 1 ]; then
if grep -q '_datagen' <<< $f || grep -q '_train' <<< $f; then
printf "%s\n" skipped
continue
if grep -q '_datagen' <<< $f || grep -q '_train' <<< $f; then
printf "%s\n" skipped
continue
fi
fi
if [ $SKIP_GPU -eq 1 ] && grep -q '_astra_3d' <<< $f; then
printf "%s\n" skipped
continue
fi

# Create temporary copy of script with all algorithm maxiter values set
# to small number and final input statements commented out.
Expand All @@ -95,9 +102,9 @@ for f in $SCRIPTPATH/scripts/*.py; do
else
printf "%s\n" FAILED
retval=1
if [ $DISPLAY_ERROR -eq 1 ]; then
echo "$output" | tail -8 | sed -e 's/^/ /'
fi
if [ $DISPLAY_ERROR -eq 1 ]; then
echo "$output" | tail -8 | sed -e 's/^/ /'
fi
fi

# Remove temporary script.
Expand Down
6 changes: 6 additions & 0 deletions examples/scripts/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ Computed Tomography
Parameter Tuning for TV-Regularized Abel Inversion
`ct_astra_noreg_pcg.py <ct_astra_noreg_pcg.py>`_
CT Reconstruction with CG and PCG
`ct_astra_3d_tv_admm.py <ct_astra_3d_tv_admm.py>`_
3D TV-Regularized Sparse-View CT Reconstruction
`ct_astra_tv_admm.py <ct_astra_tv_admm.py>`_
TV-Regularized Sparse-View CT Reconstruction
`ct_astra_weighted_tv_admm.py <ct_astra_weighted_tv_admm.py>`_
Expand Down Expand Up @@ -153,6 +155,8 @@ Total Variation
Parameter Tuning for TV-Regularized Abel Inversion
`ct_astra_tv_admm.py <ct_astra_tv_admm.py>`_
TV-Regularized Sparse-View CT Reconstruction
`ct_astra_3d_tv_admm.py <ct_astra_3d_tv_admm.py>`_
3D TV-Regularized Sparse-View CT Reconstruction
`ct_astra_weighted_tv_admm.py <ct_astra_weighted_tv_admm.py>`_
TV-Regularized Low-Dose CT Reconstruction
`ct_svmbir_tv_multi.py <ct_svmbir_tv_multi.py>`_
Expand Down Expand Up @@ -244,6 +248,8 @@ ADMM
Parameter Tuning for TV-Regularized Abel Inversion
`ct_astra_tv_admm.py <ct_astra_tv_admm.py>`_
TV-Regularized Sparse-View CT Reconstruction
`ct_astra_3d_tv_admm.py <ct_astra_3d_tv_admm.py>`_
3D TV-Regularized Sparse-View CT Reconstruction
`ct_astra_weighted_tv_admm.py <ct_astra_weighted_tv_admm.py>`_
TV-Regularized Low-Dose CT Reconstruction
`ct_svmbir_tv_multi.py <ct_svmbir_tv_multi.py>`_
Expand Down
116 changes: 116 additions & 0 deletions examples/scripts/ct_astra_3d_tv_admm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the SCICO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.

r"""
3D TV-Regularized Sparse-View CT Reconstruction
===============================================
This example demonstrates solution of a sparse-view, 3D CT
reconstruction problem with isotropic total variation (TV)
regularization
$$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x}
\|_2^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,$$
where $A$ is the Radon transform, $\mathbf{y}$ is the sinogram, $C$ is
a 3D finite difference operator, and $\mathbf{x}$ is the desired
image.
"""


import numpy as np

import jax

from mpl_toolkits.axes_grid1 import make_axes_locatable

from scico import functional, linop, loss, metric, plot
from scico.examples import create_tangle_phantom
from scico.linop.radon_astra import TomographicProjector
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info

"""
Create a ground truth image and projector.
"""

Nx = 128
Ny = 256
Nz = 64

tangle = create_tangle_phantom(Nx, Ny, Nz)
tangle = jax.device_put(tangle)

n_projection = 10 # number of projections
angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
A = TomographicProjector(
tangle.shape, [1.0, 1.0], [Nz, max(Nx, Ny)], angles
) # Radon transform operator
y = A @ tangle # sinogram


"""
Set up ADMM solver object.
"""
λ = 2e0 # L1 norm regularization parameter
ρ = 5e0 # ADMM penalty parameter
maxiter = 25 # number of ADMM iterations
cg_tol = 1e-4 # CG relative tolerance
cg_maxiter = 25 # maximum CG iterations per ADMM iteration

# The append=0 option makes the results of horizontal and vertical
# finite differences the same shape, which is required for the L21Norm,
# which is used so that g(Cx) corresponds to isotropic TV.
C = linop.FiniteDifference(input_shape=tangle.shape, append=0)
g = λ * functional.L21Norm()

f = loss.SquaredL2Loss(y=y, A=A)

x0 = A.T(y)

solver = ADMM(
f=f,
g_list=[g],
C_list=[C],
rho_list=[ρ],
x0=x0,
maxiter=maxiter,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}),
itstat_options={"display": True, "period": 5},
)

"""
Run the solver.
"""
print(f"Solving on {device_info()}\n")
solver.solve()
hist = solver.itstat_object.history(transpose=True)
tangle_recon = solver.x

print(
"TV Restruction\nSNR: %.2f (dB), MAE: %.3f"
% (metric.snr(tangle, tangle_recon), metric.mae(tangle, tangle_recon))
)

"""
Show the recovered image.
"""
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(7, 5))
plot.imview(tangle[32], title="Ground truth (central slice)", cbar=None, fig=fig, ax=ax[0])

plot.imview(
tangle_recon[32],
title="TV Reconstruction (central slice)\nSNR: %.2f (dB), MAE: %.3f"
% (metric.snr(tangle, tangle_recon), metric.mae(tangle, tangle_recon)),
fig=fig,
ax=ax[1],
)
divider = make_axes_locatable(ax[1])
cax = divider.append_axes("right", size="5%", pad=0.2)
fig.colorbar(ax[1].get_images()[0], cax=cax, label="arbitrary units")
fig.show()

input("\nWaiting for input to close figures and exit")
3 changes: 3 additions & 0 deletions examples/scripts/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Computed Tomography
- ct_abel_tv_admm.py
- ct_abel_tv_admm_tune.py
- ct_astra_noreg_pcg.py
- ct_astra_3d_tv_admm.py
- ct_astra_tv_admm.py
- ct_astra_weighted_tv_admm.py
- ct_svmbir_tv_multi.py
Expand Down Expand Up @@ -95,6 +96,7 @@ Total Variation
- ct_abel_tv_admm.py
- ct_abel_tv_admm_tune.py
- ct_astra_tv_admm.py
- ct_astra_3d_tv_admm.py
- ct_astra_weighted_tv_admm.py
- ct_svmbir_tv_multi.py
- deconv_circ_tv_admm.py
Expand Down Expand Up @@ -150,6 +152,7 @@ ADMM
- ct_abel_tv_admm.py
- ct_abel_tv_admm_tune.py
- ct_astra_tv_admm.py
- ct_astra_3d_tv_admm.py
- ct_astra_weighted_tv_admm.py
- ct_svmbir_tv_multi.py
- ct_svmbir_ppp_bm3d_admm_cg.py
Expand Down
34 changes: 34 additions & 0 deletions scico/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,40 @@ def create_conv_sparse_phantom(Nx: int, Nnz: int) -> Tuple[np.ndarray, np.ndarra
return h, x


def create_tangle_phantom(nx: int, ny: int, nz: int) -> snp.Array:
"""Construct a volume phantom.
Args:
nx: x-size of output.
ny: y-size of output.
nz: z-size of output.
Returns:
An array with shape (nz, ny, nx).
"""
xs = 1.0 * np.linspace(-1.0, 1.0, nx)
ys = 1.0 * np.linspace(-1.0, 1.0, ny)
zs = 1.0 * np.linspace(-1.0, 1.0, nz)

# default ordering for meshgrid is `xy`, this makes inputs of length
# M, N, P will create a mesh of N, M, P. Thus we want ys, zs and xs.
xx, yy, zz = np.meshgrid(ys, zs, xs, copy=True)
xx = 3.0 * xx
yy = 3.0 * yy
zz = 3.0 * zz
values = (
xx * xx * xx * xx
- 5.0 * xx * xx
+ yy * yy * yy * yy
- 5.0 * yy * yy
+ zz * zz * zz * zz
- 5.0 * zz * zz
+ 11.8
) * 0.2 + 0.5
return (values < 2.0).astype(float)


def spnoise(
img: Union[np.ndarray, snp.Array], nfrac: float, nmin: float = 0.0, nmax: float = 1.0
) -> Union[np.ndarray, snp.Array]:
Expand Down
Loading

0 comments on commit a07809c

Please sign in to comment.