In [3]:
import foam_ct_phantom
import numpy as np
from pathlib import Path
import pylab as pl
from scipy.optimize import minimize
import tomopy
import msdnet
import glob
pl.gray()

# configuration parameters
random_seed = 3141519
phantom_path = Path('phantom.h5') # stores phantom object path
phantom_par_path = Path('projs_par.h5') # stores phantom parameters path
phantom_vol_path = Path('phantom_volume.h5') # generated phantom volume path
phantom_projs_path = Path('phantom_projs.h5')  # simulated projections through phantom path
size = 256  # image size for projections
max_shift = 16  # max amount of shifting (we'll need this extra area for cropping)
sim_size = size + 2*max_shift
pixel_size = 3/size
#thetas = np.linspace(0, np.pi, 128, False)
thetas = np.linspace(np.deg2rad(30), np.deg2rad(150), 128, False)
# shift variables
random_shift_size = 10 # shift up to this many pixels in x and y
center_offset_shift = 6 # center offset amount
random_walk_size = 0.5 # shift up to this many pixels in x and y, cumulative per projection
# msdnet parameters
msdnet_root = Path('msdnet')
msdnet_fn = 'regr_params.h5'
dilation_size = 5 # dilations in network go from [1, dilation_size]
num_layers = 50 # number of NN layers
n_above_and_below = 2 # how many sinograms to load above and below the active one
num_train_datasets = 1#5 # how many tomography datasets to create for training
num_validate_datasets = 1 # how many tomography datasets to create for validation
gpu=True

def generate_shifts(num_projs, random_shift_size, center_offset_shift, random_walk_size, max_shift=None):
    # apply random shifts
    shifts = np.random.random_sample((num_projs, 2)) # (y,x) order
    shifts *= random_shift_size * 2
    shifts -= random_shift_size

    # apply center offset shifts
    shifts[:, 1] += center_offset_shift

    # apply random walk shifts
    random_walk_shifts = np.random.random_sample((num_projs, 2)) # (y,x) order
    random_walk_shifts *= random_walk_size * 2
    random_walk_shifts -= random_walk_size
    random_walk_shifts = np.cumsum(random_walk_shifts, axis=0)
    shifts += random_walk_shifts

    # limit overall shift size to max_shift
    if max_shift is not None:
        np.clip(shifts, -max_shift, max_shift, shifts)
    return shifts


def create_test_dataset(root_path, thetas, random_shift_size, center_offset_shift, random_walk_size, max_shift, size,
                        sim_size, pixel_size, display=False):
    root_path.mkdir(parents=True, exist_ok=True)
    # create phantom
    phantom_path = root_path / "phantom.h5"
    if not phantom_path.exists():
        # Note that nspheres_per_unit is set to a low value to reduce the computation time here.
        # The default value is 100000.
        foam_ct_phantom.FoamPhantom.generate(str(phantom_path), np.random.random_sample(), nspheres_per_unit=1000)
    phantom = foam_ct_phantom.FoamPhantom(str(phantom_path))

    # create phantom volume
    phantom_path = root_path / "phantom_volume.h5"
    geom = foam_ct_phantom.VolumeGeometry(sim_size, sim_size, sim_size, pixel_size)
    if not phantom_vol_path.exists():
        phantom.generate_volume(str(phantom_vol_path), geom)
    vol = foam_ct_phantom.load_volume(str(phantom_vol_path))
    if display:
        pl.imshow(vol[vol.shape[0]//2])
        pl.show()

    # create projections
    proj_geom = foam_ct_phantom.ParallelGeometry(sim_size, sim_size, thetas, pixel_size)
    if not phantom_projs_path.exists():
        phantom.generate_projections(str(phantom_projs_path), proj_geom)
    projs = foam_ct_phantom.load_projections(str(phantom_projs_path))
    if display:
        pl.imshow(projs[projs.shape[0]//2])
        pl.show()

    # Reconstruct with no shifts applied
    if display:
        rec = tomopy.recon(projs, thetas, algorithm="sirt", num_iter=30)
        pl.imshow(rec[rec.shape[0]//2])
        pl.show()

    # create shifts for projections
    shifts = generate_shifts(projs.shape[0], random_shift_size, center_offset_shift, random_walk_size, max_shift)
    if display:
        pl.plot(np.rad2deg(thetas), shifts[:, 1], label="x-shifts")
        pl.plot(np.rad2deg(thetas), shifts[:, 0], label="y-shifts")
        pl.xlabel("deg")
        pl.ylabel("shifts")
        pl.legend()
        pl.show()

    # shift projections and then crop using max_shift (shift in projection space)
    tomopy.exp_minus(projs, out=projs)
    projs = tomopy.shift_images(projs, shifts[:, 1], shifts[:, 0])
    projs = projs[:, max_shift:max_shift+size, max_shift:max_shift+size]
    np.clip(projs, 1e-6, 1-1e-6, projs)
    tomopy.minus_log(projs, out=projs)

    # crop phantom vol to match reconstruction vol
    vol = vol[max_shift:max_shift+size, max_shift:max_shift+size, max_shift:max_shift+size]

    # Reconstruct with shifts applied
    rec = tomopy.recon(projs, thetas, algorithm="sirt", num_iter=1)
    if display:
        pl.imshow(rec[rec.shape[0]//2])
        pl.show()

    return vol, rec, projs, shifts

<Figure size 432x288 with 0 Axes>

In [4]:
msdnet_root.mkdir(parents=True, exist_ok=True)
msdnet_params_path = msdnet_root / msdnet_fn
if not msdnet_params_path.exists():
    # create and train network
    dilations = msdnet.dilations.IncrementDilations(dilation_size)
    input_channels = 2 * n_above_and_below + 1
    n = msdnet.network.MSDNet(num_layers, dilations, input_channels, 1, gpu=gpu)
    # Initialize network parameters
    n.initialize()
    # create test data
    train_dats = []
    for i in range(num_train_datasets):
        train_path = msdnet_root / f"train_{i:05}"
        vol, rec, projs, shifts = create_test_dataset(train_path, thetas, random_shift_size, center_offset_shift, random_walk_size, max_shift, size, sim_size, pixel_size, display=False)
        # load rec and ground truth (vol) into datasets
        dats = []
        for j in range(vol.shape[0]):
            dats.append(msdnet.data.ArrayDataPoint(rec[j], vol[j]))
        # Convert input slices to input slabs (i.e. multiple slices as input)
        dats = msdnet.data.convert_to_slabs(dats, n_above_and_below, flip=False)
        train_dats.extend(dats)
    # Normalize input and output of network to zero mean and unit variance using
    # training data images
    n.normalizeinout(train_dats)
    # Use image batches of a single image
    bprov = msdnet.data.BatchProvider(train_dats, 1)
    validate_dats = []
    for i in range(num_validate_datasets):
        validate_path = msdnet_root / f"validate_{i:05}"
        vol, rec, projs, shifts = create_test_dataset(validate_path, thetas, random_shift_size, center_offset_shift, random_walk_size, max_shift, size, sim_size, pixel_size, display=False)
        # load rec and ground truth (vol) into datasets
        dats = []
        for j in range(vol.shape[0]):
            dats.append(msdnet.data.ArrayDataPoint(rec[j], vol[j]))
        # Convert input slices to input slabs (i.e. multiple slices as input)
        dats = msdnet.data.convert_to_slabs(dats, n_above_and_below, flip=False)
        validate_dats.extend(dats)

    # Validate with Mean-Squared Error
    val = msdnet.validate.MSEValidation(validate_dats)

    # Use ADAM training algorithms
    t = msdnet.train.AdamAlgorithm(n)

    # Log error metrics to console
    consolelog = msdnet.loggers.ConsoleLogger()
    # Log error metrics to file
    filelog = msdnet.loggers.FileLogger(str(msdnet_root/'log_tomo_regr.txt'))
    # Log typical, worst, and best images to image files
    imagelog = msdnet.loggers.ImageLogger(str(msdnet_root/'log_tomo_regr'), onlyifbetter=True, chan_in=2)

    # Train network until program is stopped manually (or stops improving)
    stopcrit = msdnet.stoppingcriterion.NonImprovingValidationSteps(5)
    msdnet.train.train(n, t, val, bprov, str(msdnet_params_path),loggers=[consolelog,filelog,imagelog], val_every=len(validate_dats), stopcrit=stopcrit, progress=True)

100%|██████████| 1500/1500 [00:51<00:00, 29.34it/s] 
100%|██████████| 1500/1500 [00:51<00:00, 29.05it/s] 
  0%|          | 0/256 [00:00<?, ?it/s]

ValueError: Number of input channels (1280) does not match expected number (5).

In [None]:
# Load network from file
n = msdnet.network.MSDNet.from_file(str(msdnet_params_path), gpu=gpu)

# create test dataset
test_path = msdnet_root / f"test"
vol, rec, projs, shifts = create_test_dataset(test_path, thetas, random_shift_size, center_offset_shift, random_walk_size, max_shift, size, sim_size, pixel_size, display=True)
# load rec and ground truth (vol) into datasets
dats = []
for j in range(vol.shape[0]):
    dats.append(msdnet.data.ArrayDataPoint(rec[j], vol[j]))
# Convert input slices to input slabs (i.e. multiple slices as input)
dats = msdnet.data.convert_to_slabs(dats, n_above_and_below, flip=False)
# run network on reconstruction
nn_out = np.zeros_like(rec)
for i in range(rec.shape[0]):
    nn_out = n.forward(dats[i].input)
pl.imshow(nn_out[nn_out.shape[0]//2])
pl.show()


In [None]:
# standard shift correction
align_shifts = np.zeros_like(shifts)
aligned_projs, align_shifts[:,1], align_shifts[:,0], error = tomopy.align_seq(projs, thetas, iters=10, upsample_factor=100, blur=False, save=True)

# plot difference
diff_shifts = shifts + align_shifts
pl.plot(np.rad2deg(thetas), diff_shifts[:, 1], label="x-shifts")
pl.plot(np.rad2deg(thetas), diff_shifts[:, 0], label="y-shifts")
pl.xlabel("deg")
pl.ylabel("Diff in Shifts")
pl.legend()
pl.show()

# recon after shifts corrected
rec = tomopy.recon(aligned_projs, thetas, algorithm="sirt", num_iter=30)

pl.imshow(rec[rec.shape[0]//2])
pl.show()

In [None]:
# calculate shift error with x,y,z motion applied
# TODO: should we include center shift?
def calc_shift_error(thetas, shifts, true_shifts):
    # calculate Y-shift that reduces error
    options = {'ftol':0.000001}

    def _y_cost(args):
        y_shift = args
        return np.sum(np.square(true_shifts[:,0] - (shifts[:,0] - y_shift)))

    res = minimize(_y_cost, (0.0,), method='Powell', options=options)
    y_shift = float(res.x)

    # now calculate x_shift and z_shift that reduces error in shifts
    # z_shift is the sin(theta) and x_shift is the cos(theta)
    def _xz_cost(args):
        x_shift, z_shift, center_shift = args
        new_shifts = shifts[:,1] - (z_shift * np.sin(thetas) + x_shift * np.cos(thetas) + center_shift)
        #x_shift, z_shift = args
        #new_shifts = shifts[:,1] - (z_shift * np.sin(thetas) + x_shift * np.cos(thetas))
        return np.sum(np.square(true_shifts[:,1] - new_shifts))

    res = minimize(_xz_cost, (0.0, 0.0, 0.0), method='Powell', options=options)
    x_shift, z_shift, center_shift = res.x
    #res = minimize(_xz_cost, (0.0, 0.0), method='Powell', options=options)
    #center_shift = 0
    #x_shift, z_shift, = res.x
    new_shifts = shifts.copy()
    new_shifts[:,0] -= y_shift
    new_shifts[:,1] -= z_shift * np.sin(thetas) + x_shift * np.cos(thetas) + center_shift
    error = np.sum(np.square(true_shifts - new_shifts))
    return x_shift, y_shift, z_shift, center_shift, error

# minimize error in shift output
x_shift, y_shift, z_shift, center_shift, error = calc_shift_error(thetas, align_shifts, -shifts)
# calc shift error before minimization
orig_error = np.sum(np.square(shifts + align_shifts))
print(f"min err: {error}, orig err: {orig_error}, x_shift: {x_shift}, y_shift: {y_shift}, z_shift: {z_shift}, center_shift: {center_shift}")

# calculate aligned shifts with minimized error for display
align_shifts_min_err = np.copy(align_shifts)
align_shifts_min_err[:, 0] -= y_shift
align_shifts_min_err[:, 1] -= z_shift * np.sin(thetas) + x_shift * np.cos(thetas) + center_shift

diff_shifts = shifts + align_shifts_min_err
pl.plot(np.rad2deg(thetas), diff_shifts[:, 1], label="x-shifts")
pl.plot(np.rad2deg(thetas), diff_shifts[:, 0], label="y-shifts")
pl.xlabel("deg")
pl.ylabel("Diff in shifts after minimize")
pl.legend()
pl.show()
