In [None]:
from pathlib import Path
import random
import dipy.io.image
import dipy.reconst.dti
import matplotlib.pyplot as plt

In [None]:
def preview(img, x_index=62, y_index=75, z_index=80):
    """Show three orthogonal grayscale slices of a 3D volume side by side.

    Panels, left to right: ``img[x_index, :, :]``, ``img[:, :, z_index]`` and
    ``img[:, y_index, :]``. Each slice is transposed and drawn with
    ``origin='lower'``.

    Args:
        img: 3D array indexable as ``img[i, j, k]``.
        x_index, y_index, z_index: slice positions along axes 0, 1 and 2.
            The defaults (62, 75, 80) are the positions this notebook was
            originally written for; pass other values for volumes of other sizes.
    """
    _, axs = plt.subplots(1, 3, figsize=(20, 10))
    panels = (
        (img[x_index, :, :], f"axis 0 = {x_index}"),
        (img[:, :, z_index], f"axis 2 = {z_index}"),
        (img[:, y_index, :], f"axis 1 = {y_index}"),
    )
    for ax, (slice_2d, title) in zip(axs, panels):
        ax.imshow(slice_2d.T, origin='lower', cmap='gray')
        ax.set_title(title)
    plt.show()
    
# Path.glob yields entries in filesystem (os.scandir) order, which is arbitrary, so sort
# them; otherwise even a seeded random.choice could pick different images on different machines.
RANDOM_SEED = 0  # set to None to draw fresh random images on every run
random.seed(RANDOM_SEED)
dti_image_paths = sorted(Path('dti_fit_images_nontest/dti/').glob('*'))

In [None]:
# random.choice on an empty list raises a bare IndexError; say what actually went wrong instead.
if not dti_image_paths:
    raise FileNotFoundError("No DTI images were found; check the DTI image directory path.")
dti_image_path = random.choice(dti_image_paths)
print(dti_image_path)

In [None]:
img_data, affine = dipy.io.image.load_nifti(dti_image_path)
# The DTI fit stores 6 lower-triangular tensor entries per voxel along the last axis.
assert img_data.ndim == 4 and img_data.shape[-1] == 6, f"unexpected DTI array shape {img_data.shape}"
img_data.shape

Now `img_data` is a numpy array of shape (140,140,140,6). At each voxel of the (140,140,140) grid, it holds the 6 lower-triangular entries of that voxel's diffusion tensor. I believe they are in the order Dxx, Dxy, Dyy, Dxz, Dyz, Dzz; see [here](https://dipy.org/documentation/1.4.0./reference/dipy.reconst).

In [None]:
# Expand the 6 lower-triangular entries into full 3x3 matrices: voxel axes first, then (row, col).
dti = dipy.reconst.dti.from_lower_triangular(img_data)

# Sanity check: each off-diagonal entry must equal its mirror across the diagonal.
for row, col in ((0, 1), (1, 2), (0, 2)):
    assert (dti[:, :, :, row, col] == dti[:, :, :, col, row]).all()
dti.shape

Above we have built the full symmetric 3x3 diffusion tensor at each voxel from its lower-triangular entries.

In [None]:
import torch

In [None]:
# dipy keeps the matrix axes last, (x, y, z, 3, 3). Move them to the front and add a batch
# axis, giving (1, 3, 3, x, y, z) in the channel-first layout used with the DDF below.
# TODO: think about dipy's axis order more carefully; for now just see whether solving works at all.
dti_tensor = torch.tensor(dti).permute(3, 4, 0, 1, 2)[None]
dti_tensor.shape

In [None]:
# Let's compute a warp from FA images, just so we have a warp to play with

# Project-local FA registration model. Later cells use reg_model.forward(...) to predict
# a DDF, reg_model.device, and reg_model.model.warp(...) to resample images with a DDF.
from fa_deformable_registration_models.reg_model1 import RegModel

reg_model = RegModel()

In [None]:
# Load the FA of this DTI image and then also load some other random FA image

fa_dir = dti_image_path.parent.parent / 'fa'
fa_image_path = fa_dir / dti_image_path.name
# Exclude this subject's own FA so the second image really is a different one,
# and sort the candidates so a seeded random pick is reproducible.
other_fa_paths = sorted(p for p in fa_dir.glob('*') if p != fa_image_path)
fa_image_path2 = random.choice(other_fa_paths)
print(fa_image_path, fa_image_path2, sep='\n')

In [None]:
# Turn the FA images into tensors and compute a deformation that aligns our original FA image to the random one

# Give each FA affine its own name rather than overwriting `affine` twice,
# so `affine` keeps referring to the DTI volume's affine.
fa_img, fa_affine = dipy.io.image.load_nifti(fa_image_path)
fa_img2, fa_affine2 = dipy.io.image.load_nifti(fa_image_path2)

fa_tensor1 = torch.tensor(fa_img, dtype=torch.float32).unsqueeze(0)
fa_tensor2 = torch.tensor(fa_img2, dtype=torch.float32).unsqueeze(0)

ddf = reg_model.forward(fa_tensor2, fa_tensor1)
# Later cells pair every DTI voxel with the DDF (and its derivative) at the same voxel,
# so the DDF must be a 3-channel field on the DTI's spatial grid.
assert tuple(ddf.shape) == (3, *img_data.shape[:3]), (tuple(ddf.shape), img_data.shape)

from util import preview_3D_vector_field
preview_3D_vector_field(ddf)

In [None]:
# Project-local module that differentiates a DDF spatially. The next cell reshapes its
# output to (b, 3, 3, h, w, d), i.e. one 3x3 matrix per voxel.
# NOTE(review): whether that 3x3 is indexed [displacement component, derivative direction]
# or the reverse is not visible from this notebook — TODO confirm in spatial_derivatives.
from spatial_derivatives import DerivativeOfDDF

deriv_ddf = DerivativeOfDDF(device=reg_model.device)

In [None]:
# Compute the Jacobian matrix field of the warp.
#
# A DDF is a displacement u(x); the warp itself is phi(x) = x + u(x), whose Jacobian is
# I + grad(u). Using grad(u) alone gives matrices close to zero (nearly singular) wherever
# the displacement varies little, and the linear solves in the tensor-transformation step
# then blow up.
# NOTE(review): assumes deriv_ddf returns plain spatial derivatives of the DDF, in the same
# (voxel) units as the DDF, without the identity already added — confirm in spatial_derivatives.
c, h, w, d = ddf.shape
b = 1
assert c == 3, f"expected a 3-channel DDF, got {c} channels"
grad_u = deriv_ddf(ddf.unsqueeze(0)).reshape(b, 3, 3, h, w, d)
identity = torch.eye(3, dtype=grad_u.dtype, device=grad_u.device).reshape(1, 3, 3, 1, 1, 1)
J = identity + grad_u
J.shape

In [None]:
# The DTI is already a torch tensor; give it J's dtype and device so the two can be
# combined in the linear solves below.
dti_tensor = dti_tensor.to(dtype=J.dtype, device=J.device)
dti_tensor.shape

In [None]:
# Next we will show how we transform the DTI tensors based on the warping

# First, fold every voxel into the batch dimension: move the two 3x3 matrix axes to the
# end, then merge all leading axes, giving a stack of shape (b*h*w*d, 3, 3).
F_batched = dti_tensor.movedim((1, 2), (-2, -1)).flatten(0, -3)
J_batched = J.movedim((1, 2), (-2, -1)).flatten(0, -3)

In [None]:
# Use torch.linalg.solve on the batch to compute the transformed diffusion tensors.
#
# Writing A = J^T at each voxel, the two solves compute
#     G  = A^{-1} F
#     F' = (A^{-1} G^T)^T = A^{-1} F A^{-T} = J^{-T} F J^{-1}
# without ever forming an explicit matrix inverse.
# NOTE(review): for a pull-back warp with Jacobian Dphi[i, j] = d phi_i / d x_j, the
# reoriented tensor is Dphi^{-1} F Dphi^{-T}. That matches this code only if J is laid
# out as J[i, j] = d phi_j / d x_i; if J[i, j] = d phi_i / d x_j, drop the transposes.
# TODO confirm the index order produced by DerivativeOfDDF.
J_batched_T = J_batched.transpose(1, 2)
G = torch.linalg.solve(J_batched_T, F_batched)
F_transformed_batched = torch.linalg.solve(J_batched_T, G.transpose(1, 2)).transpose(1, 2)

In [None]:
# Move the spatial dimensions back out of the batch dimension:
# (b*h*w*d, 3, 3) -> (b, h, w, d, 3, 3) -> (b, 3, 3, h, w, d)
F_transformed = F_transformed_batched.reshape(b, h, w, d, 3, 3).movedim((4, 5), (1, 2))
F_transformed.shape

In [None]:
# Convert to lower triangular form: switch to dipy's matrix-axes-last layout (x, y, z, 3, 3),
# let dipy pick out the 6 lower-triangular entries, then go back to channel-first (1, 6, x, y, z).
F_transformed_dipy_layout = F_transformed[0].movedim((0, 1), (-2, -1))
F_transformed_lo_tri = dipy.reconst.dti.lower_triangular(F_transformed_dipy_layout).movedim(-1, 0)[None]
print(F_transformed_lo_tri.shape)

# Apply the spatial transformation to actually move the tensors in space
# (above we transformed the tensors in place, but we didn't move them yet).
# NOTE(review): assuming warp resamples its input at x + ddf(x), the tensor that lands at
# voxel x was reoriented with J evaluated at the voxel it came from, not with J(x). The
# usual order is to warp first and then reorient with J(x); worth revisiting.
F_transformed_lo_tri_warped = reg_model.model.warp(F_transformed_lo_tri, ddf.unsqueeze(0))
print(F_transformed_lo_tri_warped.shape)

In [None]:
# Switch to dipy's channel-last (x, y, z, 6) layout so we can compute a new FA image out of
# our fully transformed DTI. .detach() drops any autograd history (e.g. if the DDF came out
# of the model with requires_grad=True), because .numpy() raises on such tensors.
F_transformed_lo_tri_warped_dipy = F_transformed_lo_tri_warped[0].permute((1, 2, 3, 0)).detach().cpu().numpy()
print(F_transformed_lo_tri_warped_dipy.shape)

# Compute the new FA image.
# eig_from_lo_tri packs eigenvalues and eigenvectors along the last axis, with the
# 3 eigenvalues first; FA only needs the eigenvalues.
eig = dipy.reconst.dti.eig_from_lo_tri(F_transformed_lo_tri_warped_dipy)
eigvals = eig[..., :3]
fa_after_transform = dipy.reconst.dti.fractional_anisotropy(eigvals)
print(fa_after_transform.shape)

In [None]:
# Show the two input FA images and the FA recomputed from the warped, reoriented DTI.
for caption, volume in (
    ("FA image 1:", fa_img),
    ("FA image 2:", fa_img2),
    ("The result of inferring a deformation from FA image 1 to FA image 2, using that deformation to transform DTI image 1, and then computing the FA of the resulting transformed DTI:", fa_after_transform),
):
    print(caption)
    preview(volume)

Something is horribly wrong. I wonder whether my assumptions about DIPY's axis ordering were wrong, or whether my implementation of the tensor transformation law has an error.