In [None]:
import cv2
import numpy as np
import cycpd
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import Ridge, RANSACRegressor
from sklearn.pipeline import make_pipeline
from skimage import transform
from scipy import ndimage

from matplotlib import pyplot as plt

plt.rcParams.update({"figure.figsize": [20, 10]})  # default size (inches) for every figure below

from ai_ct_scans import (
    data_loading,
    non_rigid_alignment,
    phase_correlation_image_processing,
    image_processing_utils,
)
from ai_ct_scans.phase_correlation import shift_nd

In [None]:
# Load full scan data and the previously saved non-rigid alignment.
patient_id = "1"  # sub-directory of the data root holding this patient's scans
# Path to a transform saved earlier with non_rigid_alignment; must be filled in before running.
alignment_file_path = ""

# Fail fast with a clear message: checked before the (slow) scan loads below.
if not alignment_file_path:
    raise ValueError(
        "alignment_file_path is empty - set it to a saved non-rigid transform before running this cell."
    )

patient_dir = data_loading.data_root_directory() / patient_id
patient_loader = data_loading.PatientLoader(patient_dir)

patient_loader.abdo.scan_1.load_scan()
patient_loader.abdo.scan_2.load_scan()

# Load alignment
trans = non_rigid_alignment.read_transform(alignment_file_path)

In [None]:
# Apply alignment to full scan and display a 2D slice
slice_index = 255  # index along axis 1 of the 2D slice shown in the comparison figure
reference = patient_loader.abdo.scan_1.full_scan
to_align = patient_loader.abdo.scan_2.full_scan
# Warp scan 2 with the loaded transform; the final argument (10) is the chunk
# thickness passed to transform_3d_volume_in_chunks.
aligned = non_rigid_alignment.transform_3d_volume_in_chunks(
    to_align, trans.predict, 10
)


def overlay_with_reference(volume):
    """Return a 2D overlay of `reference` and `volume` at `slice_index` (axis 1).

    Both slices are normalised with image_processing_utils.normalise before
    being passed to generate_overlay_2d.
    """
    return phase_correlation_image_processing.generate_overlay_2d(
        [
            image_processing_utils.normalise(reference[:, slice_index, :]),
            image_processing_utils.normalise(volume[:, slice_index, :]),
        ],
        False,
    )


f, axarr = plt.subplots(1, 2)
panels = [
    (to_align, "Before alignment"),
    (aligned, "After non-rigid alignment"),
]
for panel_ax, (volume, title) in zip(axarr, panels):
    panel_ax.imshow(overlay_with_reference(volume))
    panel_ax.title.set_text(title)

In [None]:
# Calculate the per-voxel displacement of the non-rigid transform across the
# full 3D volume, processing `chunk_thickness` slices along axis 0 at a time.
chunk_thickness = 10
num_slices, y_len, z_len = aligned.shape

# Explicit float buffer: displacements are fractional, and np.empty_like(aligned)
# would inherit `aligned`'s dtype (truncating values if it is an integer type).
shifts = np.zeros((num_slices, y_len, z_len, 3))
for chunk_start in range(0, num_slices, chunk_thickness):
    chunk_end = min(chunk_start + chunk_thickness, num_slices)
    x_len = chunk_end - chunk_start
    grid_points = non_rigid_alignment._get_grid_points(
        (x_len, y_len, z_len), offset=chunk_start
    )
    # Displacement = coordinates the transform maps each grid point to, minus
    # the grid point itself; one row of 3 components per voxel.
    shift = trans.predict(grid_points) - grid_points
    shifts[chunk_start:chunk_end] = shift.reshape(x_len, y_len, z_len, 3)

# Warp magnitude per voxel: sum of the absolute displacement components (L1 norm).
warp = np.sum(np.absolute(shifts), axis=-1)

plt.imshow(warp[:, 255, :])
plt.title("Warp magnitude (sum of absolute shift components)");

In [None]:
# Split the (..., 3) displacement field into one view per component by moving
# the component axis to the front and unpacking it.
x_shift, y_shift, z_shift = np.moveaxis(shifts, -1, 0)

In [None]:
def divergence(f, spacing=1.0):
    """Compute the divergence of a vector field sampled on a regular grid.

    Args:
        f: sequence of N arrays, one per vector component. Component ``i`` is
            differentiated along array axis ``i``. All components must share
            the same shape, with at least N dimensions.
        spacing: grid spacing given to ``np.gradient``. Either one scalar used
            for every axis or a sequence of N per-axis spacings (e.g. voxel
            sizes). The default of 1 gives derivatives in grid-index units,
            as before.

    Returns:
        ndarray with the components' shape: sum over i of d f[i] / d axis_i.

    Raises:
        ValueError: if ``f`` is empty, the components differ in shape, there
            are more components than array dimensions, or ``spacing`` cannot
            be broadcast to one value per component.
    """
    components = [np.asarray(component) for component in f]
    num_dims = len(components)
    if num_dims == 0:
        raise ValueError("divergence needs at least one vector component")
    shape = components[0].shape
    if any(component.shape != shape for component in components):
        raise ValueError("all vector components must have the same shape")
    if len(shape) < num_dims:
        raise ValueError(
            f"{num_dims} components given but arrays only have {len(shape)} dimensions"
        )
    spacings = np.broadcast_to(np.asarray(spacing, dtype=float), (num_dims,))

    # Accumulate one gradient at a time rather than stacking all N gradient
    # volumes into a single array first (keeps peak memory low for 3D scans).
    result = np.gradient(components[0], spacings[0], axis=0)
    for axis in range(1, num_dims):
        result = result + np.gradient(components[axis], spacings[axis], axis=axis)
    return result

In [None]:
# Divergence of the displacement field (sum of each component's derivative
# along its own axis), overlaid on the matching slice of the aligned scan.
displacement_components = [x_shift, y_shift, z_shift]
div = divergence(displacement_components)

aligned_slice = aligned[:, 255, :]
div_slice = div[:, 255, :]
overlay_slice = image_processing_utils.overlay_warp_on_slice(aligned_slice, div_slice)
plt.imshow(overlay_slice)

In [None]:
# Generate a stand-alone colour bar for the GUI, spanning the min/max of the
# divergence values in the displayed slice and using the jet colour map.
import matplotlib.cm as cm
import matplotlib.colors as mcolors

div_slice = div[:, 255, :]
normalize = mcolors.Normalize(vmin=div_slice.min(), vmax=div_slice.max())
colour_mappable = cm.ScalarMappable(norm=normalize, cmap=cm.jet)

fig, ax = plt.subplots(figsize=(1, 10))  # tall, narrow figure holding only the bar
fig.colorbar(colour_mappable, cax=ax)