# Extra transform exercise

In [1]:
#: standard imports
import numpy as np
import numpy.linalg as npl
# print arrays to 4 decimal places
np.set_printoptions(precision=4, suppress=True)
import matplotlib.pyplot as plt
#: gray colormap by default
plt.rcParams['image.cmap'] = 'gray'
from scipy.optimize import fmin_powell
import scipy.ndimage as snd

import nibabel as nib

This exercise follows on from the [Optimizing
Space](https://textbook.nipraxis.org) page.

A malicious entity (us) has taken a volume from an FMRI time-series, and
translated it by a few voxels along some or all of the three axes (X, Y, and Z).

Here is one plane of the unshifted volume from the FMRI time series:

In [2]:
unshifted_vol = nib.load('unshifted_vol.nii').get_fdata()
# Plane 15 of the third axis; the shifted volume is shown at the same plane.
plt.imshow(unshifted_vol[:, :, 15])
# Trailing `;` suppresses the AxesImage repr so only the figure is shown.
plt.title('Unshifted volume, slice 15');

Here is one plane of the *shifted* volume:

In [3]:
shifted_vol = nib.load('shifted_vol.nii').get_fdata()
# Same plane (15 on the third axis) as the unshifted volume, for comparison.
plt.imshow(shifted_vol[:, :, 15])
# Trailing `;` suppresses the AxesImage repr so only the figure is shown.
plt.title('Shifted volume, slice 15');

As you can see, they are slightly different because of the shift we applied.

Your job is to take the machinery from the page linked above and use it to
estimate the shift we applied.

We have made your job slightly more difficult by giving the functions here
different names from the ones in the Optimizing Space page, but you will find
implementations of these functions on that page if you know what to look for.
Or you can implement the functions yourself, using different code or
algorithms.

In [4]:
def mismatch_func(arr0, arr1):
    """ Return value that is lower when `arr0` and `arr1` are better matched.

    You can choose your mismatch calculation.  This one is the negative
    Pearson correlation between the flattened arrays: -1 for a perfect
    positive linear match, up to 1 for a perfect negative one.  NumPy gives
    NaN (with a RuntimeWarning) if either array is constant.
    """
    flat_pair = (arr0.ravel(), arr1.ravel())
    corr_matrix = np.corrcoef(*flat_pair)
    # The off-diagonal entry is the correlation between the two inputs.
    return -corr_matrix[0, 1]

In [5]:
# Test that an array has more mismatch when displaced.
# Seed the generator so img0 / img1 (reused by later checks) are identical on
# every Restart & Run All.
rng = np.random.default_rng(42)
img0 = rng.normal(size=(10, 11, 12))
# img1 holds img0's contents moved 1, 2 and 3 voxels toward the start of each
# axis; the vacated voxels at the far edges stay zero.
img1 = np.zeros(img0.shape)
img1[:-1, :-2, :-3] = img0[1:, 2:, 3:]
assert mismatch_func(img0, img0) < mismatch_func(img0, img1)

In [6]:
def apply_params(vol, x_y_z_trans):
    """ Apply translations `x_y_z_trans` to 3D volume `vol`

    x_y_z_trans is a sequence or array length 3, containing
    the (x, y, z) translations in voxels.

    Values in `x_y_z_trans` can be positive or negative,
    and can be floats.

    Returns a new array, `vol` resampled with linear interpolation so its
    contents move by `x_y_z_trans` voxels.  Positions that sample from outside
    `vol` are filled with 0 (affine_transform's default constant mode).
    """
    # affine_transform samples the input at matrix @ o + offset for each output
    # coordinate o.  With an identity (diagonal [1, 1, 1]) matrix and an offset
    # of -trans, it reads vol at o - trans, which moves the contents by +trans.
    offset = -np.array(x_y_z_trans)
    return snd.affine_transform(vol, [1, 1, 1], offset=offset, order=1)

In [7]:
# Undoing shift above gives, for valid voxels, values similar to original.
# Shift img1 back by the (1, 2, 3) voxels it was displaced by.  Where the
# result samples real img1 data, it should reproduce img0.
undone = apply_params(img1, (1, 2, 3))
valid = (slice(1, None), slice(2, None), slice(3, None))
assert np.allclose(img0[valid], undone[valid])

In [8]:
def cost_func(x_y_z_trans, target, moving):
    """ Cost xyz translation `x_y_z_trans`, given `target` and `moving` images

    `target` is the array we are trying to match to.  `moving` is the array we
    are trying to match, by using the `x_y_z_trans` transform.

    `x_y_z_trans` are the x, y, z translations mapping from the `moving` to the
    `target` volume.

    Returns the mismatch (from `mismatch_func`) between `target` and `moving`
    after `moving` has been translated by `apply_params`; lower is better, so
    this is suitable for a minimizer such as `fmin_powell`.
    """
    moved = apply_params(moving, x_y_z_trans)
    return mismatch_func(moved, target)

In [9]:
# A zero translation leaves the volume unchanged, so the cost should equal the
# plain mismatch.  Compare with a tolerance because the value passes through
# resampling and a correlation calculation.
assert np.isclose(
    cost_func([0, 0, 0], img0, img0),
    mismatch_func(img0, img0))
# The cost of a translation is the mismatch after applying that translation.
assert np.isclose(
    cost_func([1, 2, 3], img0, img1),
    mismatch_func(img0, undone))

Do the optimization of `cost_func` using `fmin_powell`:

In [10]:
# Estimate the translation mapping `shifted_vol` (moving) onto `unshifted_vol`
# (target), starting the search from no shift at all.
cost_args = (unshifted_vol, shifted_vol)
best_params = fmin_powell(cost_func, x0=[0, 0, 0], args=cost_args)
# Show the result
best_params

In [11]:
# For this particular volume, the estimated translations happen to almost
# cancel out, so their sum should lie within 0.01 of zero.
assert -0.01 < np.sum(best_params) < 0.01

Make a new copy of `shifted_vol` that applies the estimated parameters, to make
an image similar to `unshifted_vol`.

In [12]:
# Resample with the same helper (and so the same linear interpolation) that
# `cost_func` used during the optimization, so the corrected volume is exactly
# the one the estimated parameters were judged on.
re_unshifted_vol = apply_params(shifted_vol, best_params)
# Show slice 15, the same plane shown for the original volumes.
plt.imshow(re_unshifted_vol[:, :, 15])
# Trailing `;` suppresses the AxesImage repr so only the figure is shown.
plt.title('Re-unshifted volume, slice 15');