In this notebook we devise and test a technique for computing the discrete spatial derivative (Jacobian) matrix of a 3D transformation that is represented by a direct displacement vector field.

In [None]:
import monai
import torch
import math
import matplotlib.pyplot as plt
import numpy as np
from util import plot_2D_deformation, plot_2D_vector_field, preview_3D_deformation, preview_3D_vector_field, preview_image

In [None]:
# Borrowed from https://github.com/ebrahimebrahim/deep-atlas/blob/main/warp_action_exploration.ipynb

def get_example_ddf_2D(s_x, s_y=None, variant=0):
    """Build a small synthetic 2D DDF (direct displacement field) for experiments.

    Arguments:
        s_x, s_y: The x and y scale. When s_y is omitted it falls back to s_x.
            "Scale" here really means "resolution": it is the same underlying
            displacement, sampled for images of different sizes.
        variant: integer selector for which example to build (0 or 1).

    Returns:
        A tensor of shape (2, s_y, s_x): the component axis first, then the
        y index, then the x index.

    Raises:
        ValueError: if `variant` does not name a known example.
    """
    if s_y is None:
        s_y = s_x

    tau = 2*math.pi

    # Each variant is a function mapping a grid position to its two displacement components.
    if variant == 0:
        def displacement_at(x, y):
            return [
                (s_y/32)*math.sin(tau*(y/s_y) * 3),
                (s_x/32)*2*math.cos(tau*(x/s_x) * 2),
            ]
    elif variant == 1:
        def displacement_at(x, y):
            return [
                (s_y/32)*math.sin(tau*(x/s_x) * 3),
                (s_x/32)*2*math.cos(tau*(y/s_y) * 2),
            ]
    else:
        raise ValueError(f"There is no variant {variant}")

    grid = []
    for y in range(s_y):
        grid.append([displacement_at(x, y) for x in range(s_x)])

    # The nested list is laid out as (y, x, component); move the component axis to the front.
    return torch.tensor(grid).permute((2,0,1))

In [None]:
# Some 3D examples to test things out on

def get_example_ddf_3d(s_x, s_y=None, s_z=None, variant=0):
    """Build a small synthetic 3D DDF (direct displacement field) for experiments.

    Arguments:
        s_x, s_y, s_z: The x, y, z scales. Any of s_y / s_z that is omitted falls back to s_x.
            "Scale" here really means "resolution": it is the same underlying
            displacement, sampled for images of different sizes.
        variant: integer selector for which example to build (0, 1 or 2).
            Variants 0 and 1 are sinusoidal; variant 2 is linear in the voxel coordinates.

    Returns:
        A tensor of shape (3, s_z, s_y, s_x): the component axis first, then the
        z, y and x indices.

    Raises:
        ValueError: if `variant` does not name a known example.
    """
    if s_y is None:
        s_y = s_x
    if s_z is None:
        s_z = s_x

    tau = 2*math.pi

    # Each variant is a function mapping a voxel position to its three displacement components.
    if variant == 0:
        def displacement_at(x, y, z):
            return [
                (s_z/32)*math.sin(tau*(y/s_y + x/s_x) * 2.5),
                (s_y/32)*math.sin(tau*(y/s_y) * 3),
                (s_x/32)*2*math.cos(tau*(x/s_x) * 2),
            ]
    elif variant == 1:
        def displacement_at(x, y, z):
            return [
                (s_z/32)*math.sin(tau*(y/s_y + x/s_x) * 2.5),
                (s_y/32)*math.sin(tau*(x/s_x) * 3),
                (s_x/32)*2*math.cos(tau*(y/s_y) * 2),
            ]
    elif variant == 2:
        def displacement_at(x, y, z):
            return [
                (s_z/32)*( 1*z/s_z + 2*y/s_y + 3*x/s_x ),
                (s_y/32)*( 4*z/s_z + 5*y/s_y + 6*x/s_x ),
                (s_x/32)*( 7*z/s_z + 8*y/s_y + 9*x/s_x ),
            ]
    else:
        raise ValueError(f"There is no variant {variant}")

    volume = []
    for z in range(s_z):
        plane = []
        for y in range(s_y):
            plane.append([displacement_at(x, y, z) for x in range(s_x)])
        volume.append(plane)

    # The nested list is laid out as (z, y, x, component); move the component axis to the front.
    return torch.tensor(volume).permute((3,0,1,2))

In [None]:
# Build the variant-2 example field on a 60^3 grid and hand it to the 3D preview helper.
ddf = get_example_ddf_3d(s_x=60, variant=2)
preview_3D_vector_field(ddf, downsampling=2)

Plan:
- Copy the MedianBlur filter design I made earlier; we will similarly make a 3D conv layer with a fixed kernel
- Define the kernel. It will take us from 3 channels to 9 channels. For each of the 3 input channels i in (x,y,z), there are 3 tensors (indexed by j in (x,y,z)) of shape 3x3x3 that help us compute central differences. The i,j tensor is a 3x3x3 tensor that helps us compute the derivative of the i^th component with respect to the j^th variable.
  - Done below; the best approach turned out to be to move the channel dimension into the batch dimension, compute the gradient, and then return the channel dimension back to where it was.
- Check out MONAI approach to sobel kernel?
  - It does not seem to be intended for what we need (it's a transform and not a network layer).
- Remember to multiply by 1/2 for central difference. Remember to pad reasonably.
  - We replicate-pad after computing the derivative. It's good enough; see the docstring below.
- The convolution with this kernel helps us take a derivative of a function, not a DDF. Correct this by adding the appropriate fixed map.
  - It is now designed to work with a DDF, by adding a bias after convolution.
- Next we must address the problem of applying the inverse of the jacobian to the diffusion tensors, "contracting" it with both indices of the diffusion tensors. Frame this as a problem that can be solved by torch.linalg.solve, changing the view of the tensor fields such that the spatial dimensions and the columns/rows are merged into the batch dimension.
  - Going back to work n

In [None]:
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

def get_kernel_bias() -> Tuple[torch.Tensor, torch.Tensor]:
    """Create the central-difference kernel and identity bias for computing spatial derivatives.

    This is specifically for use with the spatial_derivative function.

    Returns:
        A pair ``(kernel, bias)``:
            kernel: tensor of shape (3,1,3,3,3). Output channel ``j`` is a central-difference
                stencil along the ``j``-th spatial axis: -0.5 one voxel before the centre and
                +0.5 one voxel after it, zero everywhere else. Used as a conv3d weight it
                gives the spatial gradient of a scalar function of shape (b,1,H,W,D).
            bias: tensor of shape (9,), the flattened 3x3 identity matrix. Adding it to the
                flattened derivative of a displacement field turns it into the derivative of
                the transformation (identity map plus displacement).
    """
    kernel: torch.Tensor = torch.zeros(3, 1, 3, 3, 3)

    for axis in range(3):
        # Start from the centre of the 3x3x3 stencil and step one voxel either way along `axis`.
        before = [1, 1, 1]
        after = [1, 1, 1]
        before[axis] = 0
        after[axis] = 2
        kernel[(axis, 0, *before)] = -0.5
        kernel[(axis, 0, *after)] = 0.5

    # The derivative of the identity map is the identity matrix; flattened row-major it
    # puts a 1 at indices 0, 4 and 8.
    bias = torch.eye(3, 3, dtype=kernel.dtype).reshape(-1)

    return kernel, bias

def spatial_derivative(input: torch.Tensor, kernel_bias: Tuple[torch.Tensor, torch.Tensor] = None) -> torch.Tensor:
    r"""Compute the spatial derivative of a 3D transformation represented as a displacement vector field.

    Derivatives are central differences computed with a fixed 3x3x3 convolution. The convolution
    shrinks every spatial dimension by 2, and the result is then replicate-padded back to the
    input size, so values on the boundary are copies of their nearest interior neighbour.

    Args:
        input: the input image representing a spatial transformation as a displacement vector field;
            should have shape :math:`(B,3,H,W,D)`
        kernel_bias: optionally a prebuilt ``(kernel, bias)`` pair to use
             kernel shape: :math:`(3,1,3,3,3)`
             bias shape:   :math:`(9,)`
            When omitted, the pair is built with ``get_kernel_bias()`` and moved to the
            device and dtype of ``input``.

    Returns:
        the jacobian matrix field with shape :math:`(B,9,H,W,D)`,
        where the entry (b,c,x,y,z) has the following interpretation for image at batch index b
        located at (x,y,z) at each value of c:
            c=0: x-derivative of the x-component of the transformation
            c=1: y-derivative of the x-component of the transformation
            c=2: z-derivative of the x-component of the transformation
            c=3: x-derivative of the y-component of the transformation
            c=4: y-derivative of the y-component of the transformation
            c=5: z-derivative of the y-component of the transformation
            c=6: x-derivative of the z-component of the transformation
            c=7: y-derivative of the z-component of the transformation
            c=8: z-derivative of the z-component of the transformation

    Raises:
        TypeError: if ``input`` is not a torch.Tensor.
        ValueError: if ``input`` is not 5-dimensional with 3 channels.

    If you reshape the return value to (B,3,3,H,W,D) then you will have 3x3 jacobian matrices
    living in the dimensions 1 and 2. That is, [b,:,:,x,y,z] will be the 3x3 jacobian matrix
    for the transformation with batch index b at location (x,y,z).
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    if input.dim() != 5 or input.shape[1] != 3:
        raise ValueError(f"Invalid input shape, we expect Bx3xHxWxD. Got: {input.shape}")

    if kernel_bias is None:
        # Build the default pair and match the input's device and dtype.
        kernel, bias = (t.to(input) for t in get_kernel_bias())
    else:
        kernel, bias = kernel_bias

    b, c, h, w, d = input.shape

    # Fold the component axis into the batch axis so each displacement component is
    # differentiated as an independent scalar image. reshape (rather than view) is used
    # so that non-contiguous inputs, e.g. permuted channel-last data, are accepted.
    scalar_images = input.reshape(b*c, 1, h, w, d)
    deriv: torch.Tensor = F.conv3d(scalar_images, kernel, stride=1, groups=1)

    # Restore the spatial size lost to the unpadded convolution.
    deriv = F.pad(deriv, (1,1,1,1,1,1), mode='replicate')

    # Unfold: (b*c, 3, ...) -> (b, 3*c, ...), so channel index = 3*component + derivative axis.
    deriv = deriv.view(b, 3*c, h, w, d)

    # Derivative of the displacement -> derivative of the transformation.
    deriv += bias.view(1, 9, 1, 1, 1)

    return deriv



class DerivativeOfDDF(nn.Module):
    r"""Compute the spatial derivative of a 3D transformation represented as a displacement vector field.

    We use central difference. In order to keep spatial dimensions the same,
    the boundary is padded in "replicate" mode, so the derivatives on the boundary
    are off by one voxel. That is, the derivatives are first computed to produce a smaller image,
    and then replication padding fixes the image size.
    This is not ideal (a better approach would be to use forward-difference
    or backward difference at the boundaries), but this is much simpler, and it shouldn't
    matter very much (it matters only when the displacement field has large second derivatives
    at the boundary *and* when what happens at the boundary is actually important-- this situation
    is just not very likely for my use case.)

    The kernel and bias are stored as non-persistent buffers: they follow the module through
    ``.to(...)`` / ``.cuda()`` calls but are not written to ``state_dict()``, since they are
    fixed constants rather than learned parameters.

    Args:
        device: device on which the fixed kernel and bias are initially placed.

    Returns:
        the jacobian matrix field with shape :math:`(B,9,H,W,D)`,
        where the entry (b,c,x,y,z) has the following interpretation for image at batch index b
        located at (x,y,z) at each value of c:
            c=0: x-derivative of the x-component of the transformation
            c=1: y-derivative of the x-component of the transformation
            c=2: z-derivative of the x-component of the transformation
            c=3: x-derivative of the y-component of the transformation
            c=4: y-derivative of the y-component of the transformation
            c=5: z-derivative of the y-component of the transformation
            c=6: x-derivative of the z-component of the transformation
            c=7: y-derivative of the z-component of the transformation
            c=8: z-derivative of the z-component of the transformation

    Shape:
        - Input: :math:`(B, 3, H, W, D)`, a displacement vector field
        - Output: :math:`(B, 9, H, W, D)`

    If you reshape the output to (B,3,3,H,W,D) then you will have 3x3 jacobian matrices
    living in the dimensions 1 and 2. That is, [b,:,:,x,y,z] will be the 3x3 jacobian matrix
    for the transformation with batch index b at location (x,y,z).
    """

    def __init__(self, device='cpu') -> None:
        super().__init__()
        kernel, bias = get_kernel_bias()
        # Buffers (not plain attributes) so that moving the module also moves these tensors;
        # persistent=False keeps state_dict() free of these constants.
        self.register_buffer("kernel", kernel.to(device), persistent=False)
        self.register_buffer("bias", bias.to(device), persistent=False)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return the (B,9,H,W,D) jacobian field of the (B,3,H,W,D) displacement field `input`."""
        return spatial_derivative(input, (self.kernel, self.bias))

In [None]:
# Use the GPU when one is available; otherwise fall back to the CPU so the cell still runs.
dev = 'cuda' if torch.cuda.is_available() else 'cpu'
deriv_ddf = DerivativeOfDDF(device=dev)

# One batch containing all three example fields.
ddf0 = get_example_ddf_3d(60,variant=0).to(dev)
ddf1 = get_example_ddf_3d(60,variant=1).to(dev)
ddf2 = get_example_ddf_3d(60,variant=2).to(dev)
ddfs = torch.stack([ddf0, ddf1, ddf2])

d = deriv_ddf(ddfs)

# Split the 9 derivative channels into a 3x3 matrix per voxel: (B,9,H,W,D) -> (B,3,3,H,W,D).
# The shape is taken from `d` itself rather than hard-coded.
d_mat = d.view(d.shape[0], 3, 3, *d.shape[2:])
d_mat.shape

In [None]:
# Test it out at a bunch of voxels

B, C, H, W, D = ddfs.shape
assert C == 3

num_checks = 10000
tolerance = 1e-5  # There is some numerical error

for _ in range(num_checks):
    # Pick a random batch entry and a random interior voxel.
    # It does not work exactly on the boundary-- see docstring
    b = np.random.randint(B)
    x = np.random.randint(1, H-1)
    y = np.random.randint(1, W-1)
    z = np.random.randint(1, D-1)
    # i: which component is differentiated; j: which axis it is differentiated along.
    i, j = np.random.randint(3, size=(2,))

    deriv = d_mat[b, i, j, x, y, z]

    # Unit step along axis j, zero along the other two axes.
    dx, dy, dz = (int(j == axis) for axis in range(3))
    ahead = ddfs[b, i, x+dx, y+dy, z+dz]
    behind = ddfs[b, i, x-dx, y-dy, z-dz]
    true_deriv = (ahead - behind)/2
    if i == j:
        # Diagonal entries include the derivative of the identity map.
        true_deriv += 1

    error = (true_deriv - deriv).abs()
    if not error < tolerance:
        print(error, true_deriv, deriv, 'params:', b,i,j,x,y,z)

Seems to work!