In this notebook I explore different ways of working with image deformations as displacement fields, the representation used by MONAI's registration components.

In [None]:
import sys

In [None]:
import monai
import torch
from utils import plot_2D_vector_field, plot_2D_deformation
import math
import matplotlib.pyplot as plt
import numpy as np

In [None]:
def get_example_ddf(s_x, s_y=None, variant=0):
    """Get an example DDF (dense displacement field).
    Arguments:
        s_x, s_y: The x and y scale. Provide s_x only to use the same scale for both.
            "Scale" here really means "resolution." Think of it as the same underlying displacement,
            but meant to be applied to images at different resolutions.
        variant: integer selector for which variant of example to return.
    Returns:
        A tensor of shape (2, s_y, s_x). Channel 0 is the displacement along y (rows)
        and channel 1 is the displacement along x (columns), in pixels.
    """
    if s_y is None:
        s_y=s_x
    if variant==0:
        ddf = torch.tensor(
            [[
                [(s_y/32)*math.sin(2*math.pi*(y/s_y) * 3),(s_x/32)*2*math.cos(2*math.pi* (x/s_x) * 2)]
                for x in range(s_x)]
                for y in range(s_y)
            ]
        ).permute((2,0,1))
    elif variant==1:
        ddf = torch.tensor(
            [[
                [(s_y/32)*math.sin(2*math.pi*(x/s_x) * 3),(s_x/32)*2*math.cos(2*math.pi* (y/s_y) * 2)]
                for x in range(s_x)]
                for y in range(s_y)
            ]
        ).permute((2,0,1))
    else:
        raise ValueError(f"There is no variant {variant}")
    return ddf
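For reference, variant 0 can also be written with numpy broadcasting instead of nested list comprehensions. The sketch below uses a hypothetical helper, `example_ddf_variant0`, which is not part of the notebook's utilities. It also makes the "scale means resolution" remark concrete: sampling every other pixel of the 128-pixel field matches twice the 64-pixel field. In other words, it is the same displacement, measured in the pixels of a finer grid.

```python
import math
import numpy as np

def example_ddf_variant0(s_x, s_y=None):
    """Hypothetical vectorized version of variant 0 of get_example_ddf.

    Returns a float32 array of shape (2, s_y, s_x). Channel 0 holds the
    displacement along rows (y) and channel 1 the displacement along columns (x).
    """
    if s_y is None:
        s_y = s_x
    y = np.arange(s_y, dtype=np.float64)[:, None]  # row indices as a column vector
    x = np.arange(s_x, dtype=np.float64)[None, :]  # column indices as a row vector
    d_y = (s_y / 32) * np.sin(2 * math.pi * (y / s_y) * 3)      # shape (s_y, 1)
    d_x = (s_x / 32) * 2 * np.cos(2 * math.pi * (x / s_x) * 2)  # shape (1, s_x)
    # Broadcast both components to the full (s_y, s_x) grid and stack them as channels
    return np.stack(np.broadcast_arrays(d_y, d_x)).astype(np.float32)

# Same underlying displacement at two resolutions: every other pixel of the
# 128-pixel field lands on the 64-pixel grid, with displacements twice as large in pixels.
ddf_64 = example_ddf_variant0(64)
ddf_128 = example_ddf_variant0(128)
```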

In [None]:
scale = 128

load_image = monai.transforms.Compose([
    monai.transforms.LoadImage(image_only=True, reader='pilreader', dtype=np.float32),
    monai.transforms.ScaleIntensityRange(0,255,0,1),
    monai.transforms.Transpose((2,1,0)),
    monai.transforms.Resize((scale,scale)),
    monai.transforms.ToTensor()
])

img_path = '/home/ebrahim/Pictures/pumpkin_face_autumn_october.jpg' # can put any image here

img = load_image(img_path)

# show an image stored in (channels, height, width) format
def plt_image(img, title=None):
    plt.imshow(np.transpose(np.asarray(img), axes=(1,2,0)))
    if title is not None:
        plt.title(title)
    plt.show()

In [None]:
warp = monai.networks.blocks.Warp(mode="bilinear", padding_mode="border")

# apply the MONAI warp to a single, unbatched (C, H, W) image or displacement field
apply_warp = lambda img,ddf : warp(img.unsqueeze(0), ddf.unsqueeze(0))[0]

Let $\Omega$ denote the image domain, so that for example a grayscale image is a map $\Omega\to\mathbb{R}$.

Deformations $\phi:\Omega\to\Omega$ act upon images like this:
$$
I.\phi = I\circ \phi
$$
Note that this is a _right_ action: $(I.\phi).\psi = (I\circ\phi)\circ\psi = I.(\phi\circ\psi)$. This is important to keep in mind in order to get composition in the correct order later on.
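To make the order concrete, here is a toy pure-Python sketch. It is not MONAI code. A 1D "image" is just a function on the real line, deformations are functions from the line to itself, and the hypothetical helper `act` implements the right action $I.\phi = I\circ\phi$. Acting by $\phi$ and then by $\psi$ agrees with acting once by $\phi\circ\psi$, not by $\psi\circ\phi$. The particular image, $\phi$ and $\psi$ are arbitrary choices.

```python
def act(image, deformation):
    """Right action: image.deformation = image o deformation."""
    return lambda x: image(deformation(x))

def compose(f, g):
    """Ordinary function composition: (f o g)(x) = f(g(x))."""
    return lambda x: f(g(x))

image = lambda x: x ** 2   # toy 1D intensity profile
phi = lambda x: x + 1.0    # shift by 1
psi = lambda x: 3.0 * x    # scale by 3

x = 2.0
lhs = act(act(image, phi), psi)(x)        # (image.phi).psi
rhs = act(image, compose(phi, psi))(x)    # image.(phi o psi): same as lhs
wrong = act(image, compose(psi, phi))(x)  # image.(psi o phi): different
```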

# How not to compute the inverse of a displacement field

Suppose $\phi$ is the deformation $x\mapsto x+u(x)$. Here $u(x)$ can be thought of as a displacement field, so one might think that $-u(x)$ gives a reasonable approximate inverse. But it is actually quite bad wherever the deformation has a sizable derivative. Let's work out how the displacement field for the inverse of $\phi$ is actually related to the displacement field for $\phi$.

Let $\bar{u}(y)$ be the displacement field for $\phi^{-1}$; i.e. $\bar{u}(y)$ is such that $\phi^{-1}$ maps $y$ to $y+\bar{u}(y)$. Then

$$x = \phi^{-1}(\phi(x)) = \phi^{-1}(x+u(x)) = x+u(x)+\bar{u}(x+u(x)) $$

$$y = \phi(\phi^{-1}(y)) = \phi(y+\bar{u}(y)) = y+\bar{u}(y)+u(y+\bar{u}(y)) $$

so

$$ \bar{u}(x+u(x)) = -u(x) $$
$$ \bar{u}(y) = -u(y+\bar{u}(y)) $$

or in other words

$$ \bar{u}.\phi=-u$$
$$ \bar{u}=-u.\phi^{-1}$$

You can't get $\bar{u}$ without already having a way to apply $\phi^{-1}$, which is nontrivial for arbitrary invertible $\phi$.
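The second identity, $\bar{u}(y) = -u(y+\bar{u}(y))$, is a fixed-point equation, though. That suggests an iterative workaround for mild deformations: start from the naive guess $-u$ and repeatedly set $\bar{u} \leftarrow -u.(\mathrm{id}+\bar{u})$. Below is a numpy sketch of that fixed-point iteration. MONAI's `Warp` does not do this by itself. `warp_np` and `invert_ddf` are hypothetical helpers, and `warp_np` is assumed to mimic `apply_warp`: bilinear interpolation, border padding, channel 0 displacing rows. Convergence is only expected when the deformation is small and smooth, roughly when the derivative of $u$ stays below 1 in magnitude.

```python
import math
import numpy as np

def warp_np(field, ddf):
    """Sample `field` (C, H, W) at p + ddf(p) with bilinear interpolation.

    Assumed stand-in for MONAI's Warp(mode="bilinear", padding_mode="border"):
    ddf has shape (2, H, W), channel 0 is the row displacement and channel 1
    the column displacement, both in pixels; out-of-range positions are clamped.
    """
    _, h, w = field.shape
    rows, cols = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    r = np.clip(rows + ddf[0], 0, h - 1)  # border padding = clamp sample positions
    c = np.clip(cols + ddf[1], 0, w - 1)
    r0 = np.floor(r).astype(int)
    c0 = np.floor(c).astype(int)
    r1 = np.minimum(r0 + 1, h - 1)
    c1 = np.minimum(c0 + 1, w - 1)
    fr, fc = r - r0, c - c0  # fractional offsets inside the pixel cell
    return ((1 - fr) * (1 - fc) * field[:, r0, c0] + (1 - fr) * fc * field[:, r0, c1]
            + fr * (1 - fc) * field[:, r1, c0] + fr * fc * field[:, r1, c1])

def invert_ddf(u, iterations=50):
    """Approximate the displacement ubar of phi^{-1} by iterating ubar <- -u.(id + ubar).

    Fixed-point iteration on ubar(y) = -u(y + ubar(y)); only expected to
    converge when u is small and smooth (roughly |grad u| < 1).
    """
    ubar = -u  # the naive negated field is the starting guess
    for _ in range(iterations):
        ubar = -warp_np(u, ubar)
    return ubar

# Variant 0 of the example DDF at a 64-pixel resolution, recomputed with numpy
s = 64
y = np.arange(s)[:, None]
x = np.arange(s)[None, :]
u = np.stack(np.broadcast_arrays(
    (s / 32) * np.sin(2 * math.pi * (y / s) * 3),
    (s / 32) * 2 * np.cos(2 * math.pi * (x / s) * 2),
)).astype(float)

ubar = invert_ddf(u)
# How far each candidate is from satisfying ubar + u.(id + ubar) = 0
naive_residual = np.abs(-u + warp_np(u, -u)).max()
fixed_point_residual = np.abs(ubar + warp_np(u, ubar)).max()
```

For this field the naive guess violates the identity noticeably, while the iterate satisfies it to a small tolerance.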

## Demo of negative displacement failing to work

In [None]:
strength = 1

u = strength*get_example_ddf(scale, variant=0) 
ub = -u # a bad attempt at approximating the inverse of phi

plot_2D_vector_field(u, 4)
plt.title('phi')
plt.show()
plot_2D_vector_field(ub, 4)
plt.title('phi with negated displacements; let\'s call it psi')
plt.show()

In [None]:
img1 = apply_warp(img,u)
img2 = apply_warp(img,ub)
img3 = apply_warp(img1,ub)
img4 = apply_warp(img2,u)
plt_image(img,"img")
plt_image(img1,"img.phi")
plt_image(img2,"img.psi")
plt_image(img3,"img.phi.psi")
plt_image(img4,"img.psi.phi")

The last two images should look like the first image, perhaps with some loss. Instead they look warped because, as discussed above, negating the displacement field does not produce an inverse.
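A 1D calculation makes the failure concrete. With $\phi(x)=x+u(x)$ and $\psi(x)=x-u(x)$, we get $\psi(\phi(x)) - x = u(x) - u(x+u(x))$. To first order this is $-u(x)\,u'(x)$, which vanishes only where $u$ is locally constant. This matches the claim that negation is worst where the deformation has a sizable derivative. The values below ($a = 0.5$, $x = 1$) are arbitrary choices.

```python
import math

a = 0.5  # hypothetical deformation strength

def u(x):
    """Displacement of phi."""
    return a * math.sin(x)

def phi(x):
    return x + u(x)

def psi(x):
    """The 'inverse' obtained by negating the displacement."""
    return x - u(x)

x = 1.0
error = psi(phi(x)) - x  # would be 0 if psi were the inverse of phi
```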

# How to compose two displacement fields

Say $\phi(x)=x+u(x)$ and $\psi(y)=y+v(y)$. Then 
$$\psi(\phi(x)) = x+u(x)+v(x+u(x))$$
so the displacement field for $\psi\circ \phi$ is
$$u(x)+v(x+u(x))$$
or in other words
$$u+v.\phi$$
The displacement field $v$ is just a kind of image whose channels are the vector components, so we can compute $v.\phi$ using the same method that we use to warp images.

Potential source of confusion: Acting on an image by $\psi\circ\phi$ first applies the deformation $\psi$ and then applies the deformation $\phi$. This is because deforming images is a _right_ action. This is confusing, because thinking of $\psi\circ\phi$ as a function $\Omega\to\Omega$ usually means you are thinking about how it acts on points, which is a left action and should be interpreted in the "$\phi$ then $\psi$" order.
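The formula $u + v.\phi$ can be sanity-checked in 1D with arbitrary smooth displacements. The particular `u` and `v` below are made up. With them, $x + w(x)$ reproduces $\psi(\phi(x))$ exactly, whereas the naive sum $u + v$ does not.

```python
import math

def u(x):
    return 0.3 * math.sin(x)      # displacement of phi (hypothetical)

def v(y):
    return 0.2 * math.cos(2 * y)  # displacement of psi (hypothetical)

def phi(x):
    return x + u(x)

def psi(y):
    return y + v(y)

def w(x):
    """Displacement of psi o phi according to the formula u + v.phi."""
    return u(x) + v(x + u(x))

def naive(x):
    """Simply adding the two fields, which ignores where phi moved x."""
    return u(x) + v(x)

samples = [-2.0, -0.5, 0.0, 1.0, 2.5]
max_formula_error = max(abs(psi(phi(x)) - (x + w(x))) for x in samples)
```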

In [None]:
strength = 1

u = strength*get_example_ddf(scale, variant=1) 
v = strength*get_example_ddf(scale, variant=0) 

plot_2D_vector_field(u, 4)
plt.title('phi')
plt.show()
plot_2D_vector_field(v, 4)
plt.title('psi')
plt.show()

In [None]:
plt_image(img,"img")
plt_image(apply_warp(img,v),"img.psi")
plt_image(apply_warp(apply_warp(img,v),u),"(img.psi).phi")
w = u+apply_warp(v,u)
plt_image(apply_warp(img,w),"img.(psi.phi)")

The fact that the last two images look the same is a good verification of the suggested technique for displacement field composition:
```
w = u + apply_warp(v,u)
```
This is the way to create a displacement field for the deformation that first warps via displacement `v` and then warps via displacement `u`.
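The same check can be done without MONAI or a photo. In the numpy sketch below, `warp_np` is a hypothetical stand-in for `apply_warp`: bilinear interpolation, border padding, channel 0 displacing rows. The test image is the identity coordinate map, for which bilinear interpolation is exact. The made-up displacement fields vanish on the border and have small derivatives, so no sample position leaves the image. Under those assumptions, warping twice and warping once by `u + warp_np(v, u)` agree to floating-point precision, while the naive sum `u + v` does not.

```python
import numpy as np

def warp_np(field, ddf):
    """Sample `field` (C, H, W) at p + ddf(p): bilinear, border padding (assumed MONAI-like)."""
    _, h, w = field.shape
    rows, cols = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    r = np.clip(rows + ddf[0], 0, h - 1)
    c = np.clip(cols + ddf[1], 0, w - 1)
    r0 = np.floor(r).astype(int)
    c0 = np.floor(c).astype(int)
    r1 = np.minimum(r0 + 1, h - 1)
    c1 = np.minimum(c0 + 1, w - 1)
    fr, fc = r - r0, c - c0
    return ((1 - fr) * (1 - fc) * field[:, r0, c0] + (1 - fr) * fc * field[:, r0, c1]
            + fr * (1 - fc) * field[:, r1, c0] + fr * fc * field[:, r1, c1])

n = 32
rows, cols = np.meshgrid(np.arange(n, dtype=float), np.arange(n, dtype=float), indexing="ij")
sr = np.sin(np.pi * rows / (n - 1))  # vanishes on the top and bottom edges
sc = np.sin(np.pi * cols / (n - 1))  # vanishes on the left and right edges
# Hypothetical smooth fields that vanish on the border with derivatives well below 1
u = np.stack([3 * sr * sc, 2 * sr * np.sin(2 * np.pi * cols / (n - 1))])
v = np.stack([2 * np.sin(2 * np.pi * rows / (n - 1)) * sc, 3 * sr * sc])

# Test image: the identity coordinate map (channel 0 = row index, channel 1 = column index)
coords = np.stack([rows, cols])

twice = warp_np(warp_np(coords, v), u)  # (img.psi).phi
w_field = u + warp_np(v, u)             # displacement of psi o phi
once = warp_np(coords, w_field)         # img.(psi o phi)
naive = warp_np(coords, u + v)          # just adding the fields
```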

For fun, here's a view of the composite displacement field:

In [None]:
plot_2D_vector_field(w, 4)
plt.title("psi.phi")
plt.show()

# TODO: resolve a couple of confusions here

The following picture is currently wrong: the function `plot_2D_deformation` plots the deformation acting in the wrong direction. This needs to be fixed in the deep atlas tutorial.

In [None]:
plot_2D_deformation(v,4)

The following picture confuses me. I would expect the inverse of the displacement described by the vector field shown below to push axis-aligned grid lines around so that they become curved. Yet when I apply MONAI's warp to an image of a grid, I get straight lines. What am I missing here?

In [None]:
plot_2D_vector_field(v,4)
plt.show()
plt_image(apply_warp(load_image('/home/ebrahim/Pictures/grid.jpg'),v))
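A possible explanation, worth checking: in variant 0, the row component of `v` depends only on the row index and the column component depends only on the column index. The warped image is then $I(y+f(y),\,x+g(x))$ for two 1D functions $f$ and $g$. A horizontal grid line at row $r_0$ is therefore hit exactly on the rows where $y+f(y)=r_0$, so it stays horizontal and straight, and only the spacing between lines changes. The same holds for vertical lines. The numpy sketch below recomputes `get_example_ddf`, assuming channel 0 displaces rows as in MONAI's `Warp`. It confirms that variant 0 is separable in this sense and variant 1 is not, so warping the grid with variant 1 instead should produce curved lines.

```python
import math
import numpy as np

s = 128
y = np.arange(s)[:, None]  # row indices
x = np.arange(s)[None, :]  # column indices

# Variant 0 of get_example_ddf, recomputed with numpy (channel 0 = rows, channel 1 = columns)
v0_row = np.broadcast_to((s / 32) * np.sin(2 * math.pi * (y / s) * 3), (s, s))
v0_col = np.broadcast_to((s / 32) * 2 * np.cos(2 * math.pi * (x / s) * 2), (s, s))

# Variant 1 swaps which coordinate each component depends on
v1_row = np.broadcast_to((s / 32) * np.sin(2 * math.pi * (x / s) * 3), (s, s))
v1_col = np.broadcast_to((s / 32) * 2 * np.cos(2 * math.pi * (y / s) * 2), (s, s))

def is_separable(d_row, d_col):
    """True if the row shift depends only on the row and the column shift only on the column."""
    return bool(np.allclose(d_row, d_row[:, :1]) and np.allclose(d_col, d_col[:1, :]))
```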

## Make a grid and right-act upon it by displacement field

In [None]:
vector_field = v
grid_spacing = 4 # must be int

spatial_dims, height, width = vector_field.shape
assert spatial_dims == 2


In [None]:
img = np.zeros((height, width))
img[np.arange(0, height, grid_spacing),:]=1
img[:,np.arange(0, width, grid_spacing)]=1
img = torch.tensor(img, dtype=torch.float32).unsqueeze(0)
plt_image(img)
warped_img = apply_warp(img, vector_field)
plt_image(warped_img)

## Kind of reimplement MONAI warp to make sure it's doing what I think

In [None]:
grid_sample = torch.nn.functional.grid_sample

In [None]:
input = img.unsqueeze(0)

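# reference grid of pixel coordinates; channel 0 = row index, channel 1 = column index
identity_phi = torch.stack(torch.meshgrid(torch.linspace(0,height-1,height), torch.linspace(0,width-1,width), indexing='ij'))
phi = identity_phi + vector_field
# normalize to [-1, 1] to match MONAI's Warp, which uses the align_corners=True convention
phi_normalized = phi/(torch.tensor([height,width]).reshape((2,1,1))-1)*2-1
# grid_sample expects the last grid dimension in (x, y) = (column, row) order
grid = torch.flip(phi_normalized, dims=[0]).permute((1,2,0)).unsqueeze(0)

In [None]:
out = grid_sample(input, grid, align_corners=True, mode='bilinear', padding_mode='border')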

plt.imshow(out[0,0])
plt.show()

Hmm yes, it is definitely doing what I think. I am still puzzled by the output.
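When comparing against `grid_sample` by hand, two conventions are easy to mix up. First, `grid_sample` reads the last dimension of `grid` as (x, y), i.e. column coordinate first, whereas a `torch.meshgrid(..., indexing='ij')` stack is in (row, column) order. Second, the meaning of the normalized coordinates -1 and 1 depends on `align_corners`. The plain-Python sketch below shows the pixel-to-normalized mapping for both settings; `to_normalized` is a hypothetical helper. MONAI's `Warp` appears to use the `align_corners=True` convention, which is worth confirming against the installed version's source.

```python
def to_normalized(p, size, align_corners):
    """Map pixel index p (0 .. size-1) to grid_sample's normalized [-1, 1] coordinate."""
    if align_corners:
        # -1 and 1 are the centers of the first and last pixels
        return 2 * p / (size - 1) - 1
    # -1 and 1 are the outer edges of the first and last pixels
    return (2 * p + 1) / size - 1

# Centers of pixels 0 and 127 on a 128-pixel axis under both conventions
corners_true = (to_normalized(0, 128, True), to_normalized(127, 128, True))
corners_false = (to_normalized(0, 128, False), to_normalized(127, 128, False))
```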