In [None]:
from pathlib import Path
import random
import dipy.io.image
import dipy.reconst.dti
import matplotlib.pyplot as plt
import numpy as np
import monai

In [None]:
def preview(img, slice0=62, slice1=75, slice2=80):
    """Show three orthogonal slices of a 3D volume side by side.

    Args:
        img: 3D array-like indexed as img[axis0, axis1, axis2] (numpy array or torch tensor).
        slice0: index of the slice taken along axis 0 (first panel).
        slice1: index of the slice taken along axis 1 (third panel).
        slice2: index of the slice taken along axis 2 (second panel).
            The defaults are the slice indices this notebook has always used.
    """
    _, axs = plt.subplots(1, 3, figsize=(20, 10))
    # Each 2D slice is transposed and drawn with origin='lower' so that the first remaining
    # axis runs horizontally and the second runs upward.
    axs[0].imshow(img[slice0, :, :].T, origin='lower', cmap='gray')
    axs[1].imshow(img[:, :, slice2].T, origin='lower', cmap='gray')
    axs[2].imshow(img[:, slice1, :].T, origin='lower', cmap='gray')
    plt.show()
    
# Sorted so the candidate list does not depend on filesystem enumeration order.
dti_image_paths = sorted(Path('dti_fit_images_nontest/dti/').glob('*'))

In [None]:
# Pick one subject's DTI at random and report which file it is.
dti_image_path = random.choice(dti_image_paths)
print(str(dti_image_path))

In [None]:
# Read the DTI volume and its affine from the NIfTI file, then show the array's shape.
img_data, affine = dipy.io.image.load_nifti(
    dti_image_path
)
img_data.shape

Now `img_data` is a numpy array of shape (140,140,140,6). At each voxel of the (140,140,140) grid it stores the 6 lower-triangular entries of a symmetric diffusion tensor. These should be in the order Dxx, Dxy, Dyy, Dxz, Dyz, Dzz, which is the order used by dipy's `lower_triangular` / `from_lower_triangular`; see [here](https://dipy.org/documentation/1.4.0./reference/dipy.reconst).

In [None]:
dti = dipy.reconst.dti.from_lower_triangular(img_data)

# Each voxel's 3x3 matrix should be symmetric: compare every off-diagonal entry with its mirror.
assert all(
    (dti[..., row, col] == dti[..., col, row]).all()
    for row, col in ((0, 1), (1, 2), (0, 2))
)
dti.shape

Above, we have reconstructed the full symmetric 3x3 matrices from their lower-triangular entries.

In [None]:
import torch

In [None]:
# Move the two 3x3 matrix axes in front of the three spatial axes, then add a leading batch axis.
# TODO: think about dipy's axis order more carefully; for now just see whether solving works at all.
dti_tensor = torch.tensor(dti)
dti_tensor = dti_tensor.permute(3, 4, 0, 1, 2).unsqueeze(0)
dti_tensor.shape

In [None]:
# Register two FA images purely to obtain a realistic warp to experiment with.
# The registration model is instantiated on the CPU.
from fa_deformable_registration_models.reg_model1 import RegModel

reg_model = RegModel(
    device='cpu',
)

In [None]:
# Load the FA of this DTI image and then also load some other random FA image.
# The other image is drawn from the FA directory with the current subject excluded, so the
# registration below really aligns two different images.

fa_dir = dti_image_path.parent.parent / 'fa'
fa_image_path = fa_dir / dti_image_path.name
other_fa_paths = [path for path in sorted(fa_dir.glob('*')) if path.name != dti_image_path.name]
fa_image_path2 = random.choice(other_fa_paths)
print(fa_image_path, fa_image_path2, sep='\n')

In [None]:
# Read both FA volumes, wrap each one as a float32 tensor with a leading axis, and ask the
# registration model for a displacement field (ddf) aligning FA image 1 to FA image 2.
# The model also returns FA image 1 warped by that field.

fa_img, affine = dipy.io.image.load_nifti(fa_image_path)
fa_img2, affine = dipy.io.image.load_nifti(fa_image_path2)

fa_tensor1, fa_tensor2 = (
    torch.tensor(volume, dtype=torch.float32).unsqueeze(0)
    for volume in (fa_img, fa_img2)
)

ddf, fa_tensor1_warped = reg_model.forward(
    fa_tensor2,
    fa_tensor1,
    include_warped_image=True,
)

from util import preview_3D_vector_field
preview_3D_vector_field(ddf)

In [None]:
# Now we have a warp to work with, so let's start from the DTI img_data again and show how we warp it

In [None]:
# Operator that computes the spatial derivative of a displacement field,
# placed on the same device as the registration model.
from spatial_derivatives import DerivativeOfDDF

deriv_ddf = DerivativeOfDDF(
    device=reg_model.device,
)

In [None]:
# Compute the derivative (Jacobian) matrix field of the warp: a 3x3 matrix per voxel.
# NOTE(review): DerivativeOfDDF is defined elsewhere. If it returns the gradient of the displacement u
# rather than of the full map x + u, the rotation extracted below should be taken from I + J. Confirm.

c, h, w, d = ddf.shape
b = 1  # a single volume in the batch
assert c == 3  # ddf is expected to carry 3 displacement channels
J = deriv_ddf(ddf.unsqueeze(0))
J = J.reshape(b, 3, 3, h, w, d)
J.shape

In [None]:
# Name some operations to make it easier to interpret the steps below.
# Naming convention: "dipy" layouts keep the tensor axes last (b, h, w, d, ...), while "torch" layouts
# put them right after the batch axis (b, ..., h, w, d). "lotri" is the 6-vector of lower-triangular
# entries and "mat" is the full 3x3 matrix.
from util import batchify


def dipy2torch_lotri_batch(t):
    """(b, h, w, d, 6) -> (b, 6, h, w, d)."""
    return t.permute(0, 4, 1, 2, 3)


def torch2dipy_lotri_batch(t):
    """(b, 6, h, w, d) -> (b, h, w, d, 6)."""
    return t.permute(0, 2, 3, 4, 1)


def dipy2torch_mat_batch(t):
    """(b, h, w, d, 3, 3) -> (b, 3, 3, h, w, d)."""
    return t.permute(0, 4, 5, 1, 2, 3)


def torch2dipy_mat_batch(t):
    """(b, 3, 3, h, w, d) -> (b, h, w, d, 3, 3)."""
    return t.permute(0, 3, 4, 5, 1, 2)


dipy_lotri2mat = dipy.reconst.dti.from_lower_triangular
dipy_lotri2mat_batch = batchify(dipy_lotri2mat)
dipy_mat2lotri = dipy.reconst.dti.lower_triangular
dipy_mat2lotri_batch = batchify(dipy_mat2lotri)


def torch_lotri2mat_batch(t):
    """(b, 6, h, w, d) lower-triangular entries -> (b, 3, 3, h, w, d) full matrices, via dipy."""
    return dipy2torch_mat_batch(dipy_lotri2mat_batch(torch2dipy_lotri_batch(t)))


def torch_mat2lotri_batch(t):
    """(b, 3, 3, h, w, d) full matrices -> (b, 6, h, w, d) lower-triangular entries, via dipy."""
    return dipy2torch_lotri_batch(dipy_mat2lotri_batch(torch2dipy_mat_batch(t)))


def torch_mat_batch_absorbspatial(t):
    """(b, 3, 3, h, w, d) -> (b*h*w*d, 3, 3): one batch entry per voxel matrix."""
    return t.permute((0, 3, 4, 5, 1, 2)).reshape((-1, 3, 3))


def torch_mat_batch_expandspatial(t, h, w, d):
    """Inverse of torch_mat_batch_absorbspatial: (b*h*w*d, 3, 3) -> (b, 3, 3, h, w, d).

    The batch size b is inferred from the size of t. It is not read from a global variable, so this
    works for any batch size and does not depend on state defined in another cell.
    """
    return t.reshape(-1, h, w, d, 3, 3).permute((0, 4, 5, 1, 2, 3))

In [None]:
# Names below spell out layout and shape: the Jacobian is already a torch-layout matrix batch,
# and the original DTI (dipy layout, 6 entries last) becomes a batch of one with J's dtype and device.
J_torch_mat_batch = J
F_dipy_lotri_batch = torch.tensor(img_data).unsqueeze(0).float().to(J)

In [None]:
# Resample the DTI through the warp. Tensors are moved to new locations here but NOT yet reoriented.
F_torch_lotri_batch = dipy2torch_lotri_batch(F_dipy_lotri_batch)
F_warped_torch_lotri_batch = reg_model.model.warp(
    F_torch_lotri_batch,
    ddf.unsqueeze(0),
)
# Expand the 6 resampled entries back into full 3x3 matrices
F_warped_torch_mat_batch = torch_lotri2mat_batch(F_warped_torch_lotri_batch)

In [None]:
# Fold the spatial axes into the batch axis so every voxel's 3x3 matrix is its own batch entry.

J_torch_mat_batch_nospatial = torch_mat_batch_absorbspatial(J_torch_mat_batch)
F_warped_torch_mat_batch_nospatial = torch_mat_batch_absorbspatial(F_warped_torch_mat_batch)

In [None]:
# Batched SVD of the per-voxel Jacobians: J = U diag(S) Vh
U, S, Vh = torch.linalg.svd(J_torch_mat_batch_nospatial, full_matrices=True)

In [None]:
# Orthogonal factor R of the polar decomposition J = R P, where R = U Vh and P = V diag(S) Vh.
# NOTE(review): when det(J) < 0 this R is a reflection rather than a proper rotation.
Jrot_torch_mat_batch_nospatial = U @ Vh

In [None]:
# Sanity check that Jrot is an orthogonal matrix: the largest *absolute* entry of
# Jrot @ Jrot^T - I over all voxels should be ~0. A signed max would miss negative deviations.
# The identity is built with Jrot's dtype/device and broadcast over the batch.
(
    torch.matmul(Jrot_torch_mat_batch_nospatial, Jrot_torch_mat_batch_nospatial.transpose(1, 2))
    - torch.eye(3, dtype=Jrot_torch_mat_batch_nospatial.dtype, device=Jrot_torch_mat_batch_nospatial.device)
).abs().max().item()

In [None]:
# Apply the tensor transformation law using only the rotational part Jrot of J, per voxel:
#     F' = Jrot^T (F Jrot)
# Since Jrot is orthogonal, this is a similarity transform and leaves eigenvalues unchanged.
F_warped_transformed_torch_mat_batch_nospatial = Jrot_torch_mat_batch_nospatial.transpose(1, 2) @ (
    F_warped_torch_mat_batch_nospatial @ Jrot_torch_mat_batch_nospatial
)

In [None]:
# Unfold the per-voxel batch back into the (b, 3, 3, h, w, d) layout.

F_warped_transformed_torch_mat_batch = torch_mat_batch_expandspatial(
    F_warped_transformed_torch_mat_batch_nospatial, h, w, d
)

In [None]:
# Convert to dipy's layout (matrix axes last), drop the batch axis, and collapse each matrix to its
# 6 lower-triangular entries, so dipy can compute an FA image from the fully transformed DTI.
F_warped_transformed_dipy_lotri = dipy_mat2lotri(
    torch2dipy_mat_batch(F_warped_transformed_torch_mat_batch)[0]
)

# eig_from_lo_tri packs the 3 eigenvalues first, then the eigenvectors; FA needs only the eigenvalues.
eig = dipy.reconst.dti.eig_from_lo_tri(F_warped_transformed_dipy_lotri)
eigvals = eig[..., :3]
fa_after_transform = dipy.reconst.dti.fractional_anisotropy(eigvals)
print(fa_after_transform.shape)

In [None]:
# Show the two input FA images, then FA image 1 warped in two ways: via the DTI and directly.
for caption, volume in (
    ("FA image 1:", fa_img),
    ("FA image 2:", fa_img2),
    ("The result of inferring a deformation from FA image 1 to FA image 2, using that deformation to transform DTI image 1, and then computing the FA of the resulting transformed DTI:", fa_after_transform),
    ("The result of applying that same deformation directly to FA image 1:", fa_tensor1_warped[0]),
):
    print(caption)
    preview(volume)

In [None]:
# Compare transforming the FA image directly with transforming the DTI and recomputing FA.
# In theory the two should agree: reorienting the diffusion tensors with the orthogonal Jrot
# preserves their eigenvalues, and FA depends only on eigenvalues.

absolute_difference = np.abs(fa_tensor1_warped[0] - fa_after_transform)
for label, value in (
    ("Mean absolute difference:", np.mean(absolute_difference)),
    ("99.9th percentile:", np.percentile(absolute_difference, 99.9)),
    ("Max:", np.max(absolute_difference)),
):
    print(label, value)
preview(absolute_difference)

Moving voxels around for the sake of spatial correspondence should not affect our description of white matter microstructure. Water diffusion at the molecular scale in a fiber bundle is a microstructure property, so it shouldn't change just because some voxels were moved around. It therefore makes sense to transform DTs with orientation changes alone, and never to scale the eigenvalues while doing so. In the DT description, only the rotational aspect depends on the spatial arrangement of other voxels. Every other aspect is a microstructure descriptor that should essentially be treated like a scalar, i.e. invariant under a change of coordinate system.

If we have successfully preserved eigenvalues in our chain of transformations, the difference image above should be zero everywhere.
However, we see that there is some error, especially at the brain mask edge.

In [None]:
# Peel off a few layers from the brain mask boundary and check the absolute difference again, to see
# the extent to which the errors occur at the mask boundary.
# The mask belongs to FA image 2, the image the warp aligns FA image 1 to.

import ants
brainmask_path_2 = fa_image_path2.parent.parent/'brainmask'/fa_image_path2.name
brainmask2 = ants.image_read(str(brainmask_path_2))

mask = brainmask2.morphology('erode',3)
mask_array = mask.numpy()

absolute_difference_masked = absolute_difference * mask_array

# Compute the statistics over in-mask voxels only. Averaging the whole volume would also count the
# zeros forced outside the mask, which dilutes the mean and the percentile.
in_mask_difference = np.asarray(absolute_difference)[mask_array > 0]
print("Mean absolute difference:", np.mean(in_mask_difference))
print("99.9th percentile:", np.percentile(in_mask_difference, 99.9))
print("Max:", np.max(in_mask_difference))
preview(absolute_difference_masked)

If we go above and leave out the tensor transformation step, e.g. by setting `F_warped_transformed_torch_mat_batch_nospatial = F_warped_torch_mat_batch_nospatial`, then there's not much change in the error. There's still a significant error of about the same magnitude. Therefore I believe the error could come from the method of _interpolation_ of diffusion tensors, rather than coming from an error in the transformations. If I go above and change the interpolation to nearest-neighbor (on both of the warps, the warp directly being applied to the FA and the warp being applied to the DTI), then the error goes away almost entirely. So that's strong evidence that the incorrect linear interpolation of diffusion tensors is messing with the eigenvalues.

Or at least this is what we're observing: linear interpolation does not commute with FA value computation. Linear interpolation of DTs is not very natural (as pointed out [here](https://onlinelibrary.wiley.com/doi/pdf/10.1002/mrm.20334)), but I don't think linear interpolation of FA values is totally natural either. Perhaps it's alright for the purpose of learning. Let's stick with it for now, see how it does with deep learning, and leave improving the interpolation as a future direction.

Here is an encapsulation of the above code for DTI transformation into a convenient module:

In [None]:
# The DTI warping steps above, packaged as a reusable module
import dti_warp

warp_dti = dti_warp.WarpDTI(
    device=reg_model.device,
)

In [None]:
dti = F_torch_lotri_batch  # original DTI, torch lotri layout (b, 6, h, w, d)
dti_warped = warp_dti(
    dti,
    ddf.unsqueeze(0),
)

As another sanity check, let's look at images of the principal direction of diffusion before and after warp.

In [None]:
# dipy's eig_from_lo_tri returns 12 numbers per voxel: the 3 eigenvalues (largest first), then the
# 3x3 eigenvector matrix flattened row by row. The eigenvectors are that matrix's *columns*, so the
# principal eigenvector is column 0 (flat entries 3, 6, 9). The contiguous slice 3:6 would instead
# be the first row, i.e. the x-components of all three eigenvectors.
eig = dipy.reconst.dti.eig_from_lo_tri(torch2dipy_lotri_batch(dti)[0])
princ_diffusion_direction = eig[..., 3:12].reshape(eig.shape[:-1] + (3, 3))[..., :, 0]
princ_eigenvalue = eig[..., 0]

eig2 = dipy.reconst.dti.eig_from_lo_tri(torch2dipy_lotri_batch(dti_warped)[0])
princ_diffusion_direction_after_warp = eig2[..., 3:12].reshape(eig2.shape[:-1] + (3, 3))[..., :, 0]
princ_eigenvalue_after_warp = eig2[..., 0]

In [None]:
# Similar to above but let's view both at the same time

import vtkmodules.vtkInteractionStyle
import vtkmodules.vtkRenderingOpenGL2
from vtkmodules.vtkCommonColor import vtkNamedColors
from vtkmodules.vtkCommonCore import VTK_DOUBLE
from vtkmodules.vtkCommonDataModel import vtkImageData
from vtkmodules.vtkFiltersGeometry import vtkImageDataGeometryFilter
from vtkmodules.vtkInteractionStyle import vtkInteractorStyleTrackballCamera
from vtkmodules.vtkFiltersCore import vtkGlyph3D
from vtkmodules.vtkFiltersCore import vtkTensorGlyph
from vtkmodules.vtkFiltersSources import vtkLineSource
from vtkmodules.vtkCommonCore import (vtkPoints, vtkDoubleArray)
from vtkmodules.vtkRenderingCore import (
    vtkActor,
    vtkPolyDataMapper,
    vtkRenderWindow,
    vtkRenderWindowInteractor,
    vtkRenderer
)
from vtkmodules.vtkCommonDataModel import vtkPolyData

colors = vtkNamedColors()

def view_principal_diffusion_direction_renderer_only(princ_eigenvalue, princ_diffusion_direction, axial_slice_index):
    """Build a renderer that draws one axial slice of principal diffusion directions as line glyphs.

    Each voxel of the slice gets a line segment oriented along its direction vector. The segment is
    scaled by that voxel's principal eigenvalue, divided by the maximum eigenvalue in the same slice.

    Args:
        princ_eigenvalue: array of shape (h, w, d), indexed [x, y, z].
        princ_diffusion_direction: array indexed [x, y, z, component] with 3 components.
        axial_slice_index: index along the third spatial axis of the slice to draw.

    Returns:
        A vtkRenderer containing the glyph actor; the caller attaches it to a render window.
    """
    h, w, _ = princ_eigenvalue.shape
    # Vectors and scalars must come from the same slice.
    eigenvalue_slice = princ_eigenvalue[:, :, axial_slice_index]
    direction_slice = princ_diffusion_direction[:, :, axial_slice_index]

    imageData = vtkImageData()
    imageData.SetDimensions(h, w, 1)
    imageData.AllocateScalars(VTK_DOUBLE, 1)

    dims = imageData.GetDimensions()

    # One direction vector per image point. The image uses VTK's default origin and spacing,
    # so a point's coordinates are its integer voxel indices.
    vecs = vtkDoubleArray()
    vecs.SetNumberOfComponents(3)
    vecs.SetName("princ_diffusion_direction")
    for i in range(imageData.GetNumberOfPoints()):
        x, y, _ = (int(coord) for coord in imageData.GetPoint(i))
        for component in range(3):
            vecs.InsertComponent(i, component, direction_slice[x, y, component])

    # Normalize eigenvalues by the slice maximum. Fall back to 1 for a slice with no positive
    # eigenvalue (e.g. entirely outside the brain) to avoid 0/0.
    max_eigenval = np.max(eigenvalue_slice)
    normalizer = max_eigenval if max_eigenval > 0 else 1.0
    for y in range(dims[1]):
        for x in range(dims[0]):
            imageData.SetScalarComponentFromDouble(x, y, 0, 0, eigenvalue_slice[x, y] / normalizer)

    imageData.GetPointData().AddArray(vecs)
    imageData.GetPointData().SetActiveVectors('princ_diffusion_direction')

    # Glyph: a unit line from the origin along +x, oriented by the active vectors and
    # scaled by the active scalars.
    lineSource = vtkLineSource()
    lineSource.SetPoint1(0,0,0)
    lineSource.SetPoint2(1.0,0,0)
    glyph3D = vtkGlyph3D()
    glyph3D.SetSourceConnection(lineSource.GetOutputPort())
    glyph3D.SetInputData(imageData)
    glyph3D.OrientOn()
    glyph3D.ScalingOn()
    glyph3D.SetVectorModeToUseVector()
    glyph3D.SetScaleModeToScaleByScalar()

    mapper = vtkPolyDataMapper()
    mapper.SetInputConnection(glyph3D.GetOutputPort())
    mapper.ScalarVisibilityOn()

    actor = vtkActor()
    actor.SetMapper(mapper)
    actor.GetProperty().SetLineWidth(2)
    actor.GetProperty().SetColor(colors.GetColor3d('White'))

    # Setup rendering
    renderer = vtkRenderer()
    renderer.AddActor(actor)
    renderer.SetBackground(colors.GetColor3d('Black'))
    renderer.ResetCamera()

    return renderer

def view_both(axial_slice, princ_eigenvalue, princ_diffusion_direction, princ_eigenvalue_after_warp, princ_diffusion_direction_after_warp):
    """Open an interactive VTK window comparing the slice before (left) and after (right) warping.

    Both halves share one camera, so panning and zooming affect them together. The call ends by
    starting the interactor's event loop.
    """
    left = view_principal_diffusion_direction_renderer_only(
        princ_eigenvalue, princ_diffusion_direction, axial_slice
    )
    right = view_principal_diffusion_direction_renderer_only(
        princ_eigenvalue_after_warp, princ_diffusion_direction_after_warp, axial_slice
    )

    window = vtkRenderWindow()
    for renderer in (left, right):
        window.AddRenderer(renderer)
    left.SetViewport([0, 0, 0.5, 1])
    right.SetViewport([0.5, 0, 1, 1])
    right.SetActiveCamera(left.GetActiveCamera())
    window.SetSize(1900, 800)

    interactor = vtkRenderWindowInteractor()
    interactor.SetInteractorStyle(vtkInteractorStyleTrackballCamera())
    interactor.SetRenderWindow(window)
    interactor.Initialize()
    interactor.Start()

In [None]:
view_both(
    axial_slice=75,
    princ_eigenvalue=princ_eigenvalue,
    princ_diffusion_direction=princ_diffusion_direction,
    princ_eigenvalue_after_warp=princ_eigenvalue_after_warp,
    princ_diffusion_direction_after_warp=princ_diffusion_direction_after_warp,
)

Let's now try with a handmade rotation+scaling and observe that the colors remain the same (eigenvalues are preserved) and the vectors rotate properly.

In [None]:
def get_example_ddf_3d(s_x, s_y=None, s_z=None, th=2*np.pi/8, oy=0.5, oz=0.5, scaling=1.0):
    """Get an example DDF (dense displacement field): a rotation, plus optional scaling, in the y-z plane.

    The field does not vary along x. The output's spatial axes are ordered (z, y, x), and channel k
    holds the displacement along spatial axis k, i.e. the channels are the (z, y, x) components.

    Arguments:
        s_x, s_y, s_z: The x, y, z scales. Provide s_x only to have them be the same scale.
            "Scale" here really means "resolution." Think of it as the same underlying displacement,
            but meant to be applied to images at different resolutions.
        th: rotation angle in radians
        oy, oz: the rotation center in [0,1] x [0,1] coordinates (as fractions of s_y and s_z)
        scaling: any scaling to also perform

    Returns:
        float32 tensor of shape (3, s_z, s_y, s_x).
    """
    if s_y is None:
        s_y = s_x
    if s_z is None:
        s_z = s_x
    m = np.array([[np.cos(th), -np.sin(th)], [np.sin(th), np.cos(th)]]) * scaling

    # Coordinates relative to the rotation center, as a column (z) and a row (y) that broadcast to an
    # (s_z, s_y) grid. The arithmetic is done in float64 and cast to float32 only at the end; it is
    # vectorized rather than building ~3 * s_x * s_y * s_z Python objects.
    z_rel = (torch.arange(s_z, dtype=torch.float64) - oz * s_z)[:, None]
    y_rel = (torch.arange(s_y, dtype=torch.float64) - oy * s_y)[None, :]

    # displacement = (rotated and scaled position) - (original position)
    z_disp = y_rel * m[1, 0] + z_rel * m[1, 1] - z_rel
    y_disp = y_rel * m[0, 0] + z_rel * m[0, 1] - y_rel
    x_disp = torch.zeros_like(z_disp)

    # Stack the components as channels, then repeat along x (the field is constant in x).
    ddf = torch.stack([z_disp, y_disp, x_disp]).unsqueeze(-1).repeat(1, 1, 1, s_x)
    return ddf.float()

In [None]:
# 45-degree rotation combined with 0.75 scaling in the y-z plane, on a 140^3 grid
ddf = get_example_ddf_3d(140, th=np.pi / 4, scaling=0.75)

In [None]:
# Visualize the handmade displacement field
preview_3D_vector_field(ddf)

In [None]:
# Warp the DTI with the finite-strain tensor transform.
# NOTE(review): presumably this reorients tensors with only the rotational part of the Jacobian;
# confirm in dti_warp.
warp_dti = dti_warp.WarpDTI(
    device = reg_model.device,
    tensor_transform_type=dti_warp.TensorTransformType.FINITE_STRAIN
)
dti = F_torch_lotri_batch
dti_warped = warp_dti(dti, ddf.unsqueeze(0))

# dipy's eig_from_lo_tri returns 12 numbers per voxel: the 3 eigenvalues (largest first), then the
# 3x3 eigenvector matrix flattened row by row. The eigenvectors are that matrix's *columns*, so the
# principal eigenvector is column 0 (flat entries 3, 6, 9), not the contiguous slice 3:6.
eig = dipy.reconst.dti.eig_from_lo_tri(torch2dipy_lotri_batch(dti)[0])
princ_diffusion_direction = eig[..., 3:12].reshape(eig.shape[:-1] + (3, 3))[..., :, 0]
princ_eigenvalue = eig[..., 0]

eig2 = dipy.reconst.dti.eig_from_lo_tri(torch2dipy_lotri_batch(dti_warped)[0])
princ_diffusion_direction_after_warp = eig2[..., 3:12].reshape(eig2.shape[:-1] + (3, 3))[..., :, 0]
princ_eigenvalue_after_warp = eig2[..., 0]

In [None]:
view_both(
    axial_slice=75,
    princ_eigenvalue=princ_eigenvalue,
    princ_diffusion_direction=princ_diffusion_direction,
    princ_eigenvalue_after_warp=princ_eigenvalue_after_warp,
    princ_diffusion_direction_after_warp=princ_diffusion_direction_after_warp,
)