Ghost Tomography reconstruction example

In [None]:
from collections.abc import Sequence
import corrct as cct
import matplotlib.pyplot as plt
import numpy as np
from numpy.typing import NDArray

# local packages
import reconstruction_utils as rec_utils

%matplotlib widget

In [None]:
# Size handed to the phantom generator, and compression ratio: the mask collection
# below is generated with buckets_fraction = 1 / compression_ratio.
FoV_size = 101
compression_ratio = 5

# Create the test phantom, subtract its background, and scale it by its own
# value range (max - min).
phantom, background, resolutions = cct.testing.create_phantom_nuclei3d(FoV_size=FoV_size)
phantom -= background
phantom /= phantom.max() - phantom.min()

# Mask shape: the first and the last of the phantom's last three dimensions
# (presumably the shape of one projection image - confirm against corrct's axis conventions).
mask_shape = np.array(phantom.shape)[[-3, -1]]
# 60 equally spaced rotation angles in [0, 180) degrees, converted to radians.
rot_angles_deg = np.linspace(0, 180, 60, endpoint=False)
rot_angles_rad = np.deg2rad(rot_angles_deg)

# Volume geometry derived from the phantom, and projection geometry with a
# rotation axis shift of 1 pixel.
vol_geom = cct.models.get_vol_geom_from_volume(phantom)
prj_geom = cct.models.get_prj_geom_parallel(rot_axis_shift_pix=1.0)

# Generate the collection of illumination masks.
masks_gen = cct.struct_illum.MaskGeneratorHalfGaussian(mask_shape)
mc = masks_gen.generate_collection(buckets_fraction=1 / compression_ratio)

# The ghost tomography projector combines the mask collection with the tomographic
# projector; it is only used while the tomographic projector's context is open.
with cct.projectors.ProjectorUncorrected(vol_geom=vol_geom, angles_rot_rad=rot_angles_rad, prj_geom=prj_geom) as prj_tomo:
    prj_gt = cct.struct_illum.ProjectorGhostTomography(mc, prj_tomo)

    # Noiseless bucket signals obtained by applying the projector to the phantom.
    buckets_clean = prj_gt(phantom)

# Index of the central slice along the phantom's first axis; reused by the plotting cells.
slice_ind = phantom.shape[0] // 2

# Preview: central phantom slice (left) and the first entry of the bucket data (right).
fig, axs = plt.subplots(1, 2)
axs[0].imshow(phantom[slice_ind])
axs[1].plot(buckets_clean[0])
fig.tight_layout()

In [None]:
# Acquisition settings: the photon count per realization is the product of the
# dwell time and the mean photon flux emitted per voxel.
dwell_time = 1
mean_vox_emit_photon_flux = 1e3

ph_per_realization = dwell_time * mean_vox_emit_photon_flux

add_poisson = True

# Add noise for the chosen photon count, then divide by that count to bring the
# noisy buckets back to the scale of `buckets_clean`
# (assumes `add_noise` scales the signal by `num_photons` - confirm against corrct).
buckets_noise, _, _ = cct.testing.add_noise(buckets_clean, num_photons=ph_per_realization, add_poisson=add_poisson)
buckets_noise /= ph_per_realization

# Report the standard deviation of the noise (noisy minus clean buckets) and its size
# relative to the standard deviation of the clean signal. This is a noise-to-signal
# figure (larger means noisier), i.e. the inverse of an SNR, so it is labelled as such.
noise_std = np.std(buckets_noise - buckets_clean)
print(f"Noise std: {noise_std} (noise-to-signal ratio: {noise_std / buckets_clean.std():%})")

## Reconstructions

In [None]:
from typing import Callable

# Algorithms parameters
# `iterations` is the iteration count passed to the PDHG solvers in the cells below;
# `lower_limit` is the value `find_reg_weight` passes as the solver's `lower_limit`.
iterations = 5000
lower_limit = 0.0

# Regularizer classes used for the 2D (ghost imaging) and 3D (ghost tomography) reconstructions.
reg_type_2d = cct.regularizers.Regularizer_TV2D
reg_type_3d = cct.regularizers.Regularizer_TV3D

# When True, the reconstruction cells fit the regularization weight with `find_reg_weight`;
# when False, they fall back to hard-coded weights.
fit_reg_weight = True


def find_reg_weight(
    A: cct.operators.ProjectorOperator,
    data: NDArray,
    iterations: int,
    reg: Callable[[float], cct.regularizers.BaseRegularizer],
    lambda_range: tuple[float, float] | NDArray,
    data_term: cct.data_terms.DataFidelityBase | str = "l2",
    parallel_eval: bool = False,
    num_averages: int = 1,
    num_per_order: int = 4,
) -> float:
    """Estimate a regularization weight by cross-validation.

    A PDHG solver is run for every candidate weight in `lambda_range`, the
    cross-validation loss is computed for each of them, and the minimum of the
    resulting curve is fitted.

    Note: the lower bound handed to the solver is read from the notebook-level
    `lower_limit` variable at call time; it is not a parameter of this function.

    Parameters
    ----------
    A : cct.operators.ProjectorOperator
        Projection operator passed to the solver.
    data : NDArray
        Data to reconstruct; its shape also defines the cross-validation masks.
    iterations : int
        Number of solver iterations for each candidate weight.
    reg : Callable[[float], cct.regularizers.BaseRegularizer]
        Factory that builds the regularizer from a weight (e.g. a regularizer class).
    lambda_range : tuple[float, float] | NDArray
        Start and end of the weight range; only the first two entries are used.
    data_term : cct.data_terms.DataFidelityBase | str, optional
        Data fidelity term, used both for the fit and for the test residual. Default: "l2".
    parallel_eval : bool, optional
        Forwarded to the cross-validation helper. Default: False.
    num_averages : int, optional
        Forwarded to the cross-validation helper as `num_averages`. Default: 1.
    num_per_order : int, optional
        Number of candidate weights per order of magnitude. Default: 4.

    Returns
    -------
    float
        The weight at the fitted minimum of the cross-validation loss.
    """

    # Instantiates the solver object, that is later used for computing the reconstruction
    def solver_spawn(lam_reg):
        # Using the PDHG solver from Chambolle and Pock
        return cct.solvers.PDHG(verbose=True, data_term=data_term, regularizer=reg(lam_reg), data_term_test=data_term)

    # Computes the reconstruction for a given solver and a given cross-validation data mask
    def solver_call(solver, b_test_mask=None):
        # `lower_limit` is the notebook-level setting, looked up when the solver is called
        return solver(A, data, iterations, lower_limit=lower_limit, precondition=True, b_test_mask=b_test_mask)

    # Create the regularization weight finding helper object (using cross-validation)
    reg_help_cv = cct.param_tuning.CrossValidation(
        data.shape, num_averages=num_averages, verbose=True, plot_result=True, parallel_eval=parallel_eval
    )
    reg_help_cv.solver_spawning_function = solver_spawn
    reg_help_cv.solver_calling_function = solver_call

    # Define the regularization weight range
    lams_reg = reg_help_cv.get_lambda_range(lambda_range[0], lambda_range[1], num_per_order=num_per_order)

    # Compute the loss function values for all the regularization weights
    f_avgs, _, _ = reg_help_cv.compute_loss_values(lams_reg, return_recs=False)

    # Parabolic fit of the minimum over the computed curve
    lam_min, _ = reg_help_cv.fit_loss_min(lams_reg, f_avgs)

    return lam_min

### 2-step approach

Let's first reconstruct the projections from the noisy buckets, using the ghost imaging projector.

In [None]:
# Ghost imaging projector, built from the mask collection alone.
prj_gi = cct.struct_illum.ProjectorGhostImaging(mc)
if fit_reg_weight:
    # Fit the regularization weight by cross-validation on the first entry of the bucket data only.
    reg_weight_gi = find_reg_weight(
        prj_gi, buckets_noise[0], iterations, reg=reg_type_2d, lambda_range=(1e-1, 1e4), parallel_eval=False
    )
else:
    # Hard-coded fallback weight (presumably taken from an earlier cross-validation run - confirm).
    reg_weight_gi = 7.370e00

# Reconstruct the whole bucket stack with the selected weight. The configured `lower_limit`
# is used (instead of a literal 0.0) so that this solve runs under the same constraint as
# the cross-validation solves in `find_reg_weight`.
solver_pdhg = cct.solvers.PDHG(verbose=True, regularizer=reg_type_2d(reg_weight_gi))
rec_projs_wvu, _ = solver_pdhg(prj_gi, buckets_noise, iterations=iterations, lower_limit=lower_limit)

Let's now reconstruct the volume

In [None]:
# Swap the first two axes of the reconstructed projections
# (w, v, u -> v, w, u, going by the variable names).
rec_projs_vwu = rec_projs_wvu.swapaxes(0, 1)

# Circular mask over the last two dimensions of the volume shape; the plotting cells
# multiply the displayed slices by it.
circ_mask = cct.processing.circular_mask(vol_geom.shape_zxy[-2:])
sirt = cct.solvers.SIRT(verbose=True)
# Reconstruct the volume from the projections with 100 SIRT iterations and a lower limit of 0.0.
with cct.projectors.ProjectorUncorrected(vol_geom, angles_rot_rad=rot_angles_rad, prj_geom=prj_geom) as prj_tomo:
    rec_vol_2s, _ = sirt(prj_tomo, rec_projs_vwu, iterations=100, lower_limit=0.0)

In [None]:
# Side-by-side comparison of the central slice: phantom (left) and 2-step
# reconstruction (right), the latter multiplied by the circular mask.
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True, figsize=[7, 3.75])
axs[0].imshow(phantom[slice_ind])
axs[0].set_title("Phantom")
axs[1].imshow(rec_vol_2s[slice_ind] * circ_mask)
# The title reports the mask generator description and the fraction of acquired buckets.
axs[1].set_title(f"2-step: {masks_gen.info()} - 1/{compression_ratio} buckets")
fig.tight_layout()
plt.show(block=False)

### 1-step approach

In [None]:
# The ghost tomography projector wraps the tomographic projector, so the weight fit
# and the reconstruction both run while the tomographic projector's context is open.
with cct.projectors.ProjectorUncorrected(vol_geom=vol_geom, angles_rot_rad=rot_angles_rad, prj_geom=prj_geom) as prj_tomo:
    prj_gt = cct.struct_illum.ProjectorGhostTomography(mc, prj_tomo)

    if fit_reg_weight:
        # Fit the regularization weight by cross-validation on the full bucket data.
        reg_weight_gt = find_reg_weight(
            prj_gt, buckets_noise, iterations, reg=reg_type_3d, lambda_range=(1e-1, 1e3), parallel_eval=False
        )
    else:
        # Hard-coded fallback weight (presumably taken from an earlier cross-validation run - confirm).
        reg_weight_gt = 1.363e01

    # Reconstruct the volume directly from the buckets. The configured `lower_limit` is used
    # (instead of a literal 0.0) so that this solve runs under the same constraint as the
    # cross-validation solves in `find_reg_weight`.
    solver_pdhg = cct.solvers.PDHG(verbose=True, regularizer=reg_type_3d(reg_weight_gt))
    rec_vol_1s, _ = solver_pdhg(prj_gt, buckets_noise, iterations=iterations, lower_limit=lower_limit)

In [None]:
fontsize = 14

# Central slice of the phantom and of both reconstructions; the reconstructions
# are multiplied by the circular mask before being displayed.
panels = [
    ("Phantom", phantom[slice_ind]),
    ("2-step", rec_vol_2s[slice_ind] * circ_mask),
    ("1-step", rec_vol_1s[slice_ind] * circ_mask),
]

fig, axs = plt.subplots(1, len(panels), sharex=True, sharey=True, figsize=[10, 3.85])
fig.suptitle(f"{masks_gen.info()} - 1/{compression_ratio} buckets", fontsize=fontsize + 1)

# One axis per panel: image, title and tick label size.
for ax, (panel_title, panel_img) in zip(axs, panels):
    ax.imshow(panel_img)
    ax.set_title(panel_title, fontsize=fontsize)
    ax.tick_params(labelsize=fontsize)

fig.tight_layout()
plt.show(block=False)

In [None]:
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

# Plot the FRC curves of both reconstructions against the phantom.
# NOTE(review): snrt=0.4142 is presumably the SNR threshold of the FRC resolution criterion - confirm.
cct.processing.post.plot_frcs([(phantom, rec_vol_2s), (phantom, rec_vol_1s)], labels=["2-step", "1-step"], snrt=0.4142)

# PSNR and SSIM of each reconstruction against the phantom, computed on the full
# (unmasked) volumes with a data range of 1.0.
print("PSNRs:")
print(f" - 2-step: {psnr(phantom, rec_vol_2s, data_range=1.0)}")
print(f" - 1-step: {psnr(phantom, rec_vol_1s, data_range=1.0)}")
print("SSIMs:")
print(f" - 2-step: {ssim(phantom, rec_vol_2s, data_range=1.0)}")
print(f" - 1-step: {ssim(phantom, rec_vol_1s, data_range=1.0)}")