In [1]:
import torch
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
from scipy import interpolate

In [2]:
# Forward (fixed -> moving) displacement field in voxel units; printed size is
# (3, d, h, w). get_dvf consumes the 3 channels as (x, y, z) components.
# map_location='cpu' lets the file load on machines without a GPU;
# inverse_disp moves it to the compute device itself.
disp_f2m_ = torch.load('./disp_5.pth', map_location='cpu')
print(disp_f2m_.size())

# Image geometry; points are (x, y, z) for itk_to_pixel / pixel_to_itk.
origin_itk_ = [-250, -201, -158.5]      # physical origin
cor_pixel_ = [257, 207, 33]             # NOTE(review): not referenced in the cells shown
pixel_spacing_ = [0.9766, 0.9766, 5]    # voxel size per axis (x, y, z)
crop_range_ = [0, 170, 50]              # crop offsets in (z, y, x) order, see pixel_to_crop_pixel

torch.Size([3, 66, 260, 400])


In [3]:
class SpatialTransformer(nn.Module):
    '''
    2D or 3D spatial transformer: warp an image with a dense displacement field.

    Channel 0 of ``flow`` displaces along the last spatial axis (w / x), channel 1
    along h (y) and, in 3D, channel 2 along d (z). This is the coordinate order
    ``F.grid_sample`` expects.

    Identity grids and normalization coefficients are cached per
    (spatial shape, dtype, device), so repeated calls do not rebuild them.
    '''

    def __init__(self, dim):
        '''
        dim: number of spatial dimensions, 2 or 3.
        '''
        super().__init__()
        self.dim = dim
        # Both caches are keyed by (spatial shape, dtype, device). Keying on shape
        # alone would reuse a grid of the wrong dtype/device on later calls.
        self.grid_dict = {}
        self.norm_coeff_dict = {}

    def forward(self, input_image, flow):
        '''
        input_image: (n, c, h, w) or (n, c, d, h, w); every channel is warped the same way
        flow: (n, 2, h, w) or (n, 3, d, h, w), displacement in voxels, channels (x, y[, z])
            A batch size of 1 on either argument is broadcast to the other's batch size.

        return:
            warped image, (n, c, h, w) or (n, c, d, h, w), sampled with grid_sample
            (mode='bilinear', align_corners=True, padding_mode='border')

        raises:
            ValueError if the batch sizes differ and neither is 1
        '''
        img_shape = tuple(input_image.shape[2:])
        key = (img_shape, flow.dtype, flow.device)
        if key in self.grid_dict:
            grid = self.grid_dict[key]
            norm_coeff = self.norm_coeff_dict[key]
        else:
            grids = torch.meshgrid([torch.arange(0, s) for s in img_shape])
            # reversed so the channel order is [w, h(, d)] = (x, y(, z))
            grid = torch.stack(grids[::-1], dim=0)  # 2 x h x w or 3 x d x h x w
            grid = torch.unsqueeze(grid, 0).to(dtype=flow.dtype, device=flow.device)
            sizes = torch.tensor(img_shape[::-1], dtype=flow.dtype, device=flow.device)
            # Maps voxel index i in [0, s-1] to [-1, 1] (align_corners=True).
            # The clamp avoids 2/0 = inf on size-1 axes, where index 0 then maps to -1.
            norm_coeff = 2. / torch.clamp(sizes - 1., min=1.)
            self.grid_dict[key] = grid
            self.norm_coeff_dict[key] = norm_coeff
        new_grid = grid + flow

        if self.dim == 2:
            new_grid = new_grid.permute(0, 2, 3, 1)  # n x h x w x 2
        elif self.dim == 3:
            new_grid = new_grid.permute(0, 2, 3, 4, 1)  # n x d x h x w x 3

        n_img, n_grid = input_image.shape[0], new_grid.shape[0]
        if n_img != n_grid:
            # Broadcast a batch of 1 with a view; never modify the caller's tensors.
            if n_img == 1:
                input_image = input_image.expand(n_grid, *input_image.shape[1:])
            elif n_grid == 1:
                new_grid = new_grid.expand(n_img, *new_grid.shape[1:])
            else:
                raise ValueError(
                    f'batch sizes {n_img} (image) and {n_grid} (flow) cannot be broadcast')

        warped_input_img = F.grid_sample(input_image, new_grid * norm_coeff - 1., mode='bilinear', align_corners=True,
                                         padding_mode='border')
        return warped_input_img

def pixel_to_itk(pixel, origin, spacing):
    '''
    Convert an (x, y, z) voxel index to a physical (ITK) coordinate.

    The index is treated as 1-based: each axis is mapped as
    ``(index - 1) * spacing + origin``. Only the first three entries are used.

    return: [x, y, z] physical coordinate
    '''
    return [(pixel[axis] - 1) * spacing[axis] + origin[axis] for axis in range(3)]

def itk_to_pixel(itk, origin, spacing):
    '''
    Convert a physical (ITK) coordinate to an (x, y, z) voxel index.

    This is the inverse of ``pixel_to_itk``: each axis is mapped as
    ``(coord - origin) / spacing + 1``, giving a 1-based, possibly fractional index.

    return: [x, y, z] voxel index
    '''
    return [(itk[axis] - origin[axis]) / spacing[axis] + 1 for axis in range(3)]

def pixel_to_crop_pixel(pixel, crop_range):
    '''
    Shift an (x, y, z) voxel index into the coordinates of a cropped volume.

    ``crop_range`` lists the crop offsets in reversed (z, y, x) order. The x
    offset is therefore ``crop_range[2]`` and the z offset is ``crop_range[0]``.

    return: [x, y, z] index relative to the crop
    '''
    return [pixel[axis] - crop_range[2 - axis] for axis in range(3)]

def inverse_disp(disp, threshold=0.01, max_iteration=20, dim=3, device=None):
    '''
    Compute the inverse of a displacement field. This implements
    "A simple fixed-point approach to invert a deformation field":

        v_{k+1}(x) = -u(x + v_k(x)),   v_0 = 0

    where u is ``disp`` and v is the returned inverse.

    disp : (n, 2, h, w) or (n, 3, d, h, w) or (2, h, w) or (3, d, h, w)
        displacement field in voxels, channels ordered (x, y[, z])
    threshold : stop once the max absolute change between two iterations is below this
    max_iteration : upper bound on the number of iterations
    dim : number of spatial dimensions; pass 2 for 2D fields (default 3)
    device : where to run; defaults to CUDA when available, otherwise disp's device

    return: inverse field with the same shape as ``disp``, on ``device``
    '''
    if device is None:
        device = torch.device('cuda') if torch.cuda.is_available() else disp.device
    spatial_transformer = SpatialTransformer(dim)
    forward_disp = disp.detach().to(device)
    unbatched = disp.ndim < dim + 2
    if unbatched:
        forward_disp = torch.unsqueeze(forward_disp, 0)

    backward_disp = torch.zeros_like(forward_disp)
    for _ in range(max_iteration):
        # Resample the forward field at x + v_k(x). The result is a fresh tensor,
        # so the previous iterate can be kept without cloning.
        updated = -spatial_transformer(forward_disp, backward_disp)
        diff = torch.max(torch.abs(updated - backward_disp)).item()
        backward_disp = updated
        if diff < threshold:
            break

    if unbatched:
        backward_disp = torch.squeeze(backward_disp, 0)
    return backward_disp

def get_dvf(coordinate_itk, origin_itk, pixel_spacing, crop_range, disp_f2m, disp_m2f=None):
    '''
    Look up the inverted-field (disp_m2f) displacement at one physical (ITK) point.

    coordinate_itk : [x, y, z] physical coordinate of the query point
    origin_itk, pixel_spacing : image origin and voxel size, (x, y, z) order
    crop_range : crop offsets in (z, y, x) order (see pixel_to_crop_pixel)
    disp_f2m : (3, d, h, w) displacement field in voxels, channels (x, y, z)
    disp_m2f : optional precomputed inverse of disp_f2m with the same layout.
        Pass it when querying many points, so the iterative inversion runs only once.

    return: [x, y, z] displacement multiplied by pixel_spacing (voxels -> physical units)

    raises: ValueError (from RegularGridInterpolator) if the point lies outside the field
    '''
    coordinate_pixel = itk_to_pixel(coordinate_itk, origin_itk, pixel_spacing)
    # NOTE(review): itk_to_pixel returns 1-based indices (it adds 1), while the
    # interpolation grid below is 0-based. Confirm that no -1 offset is needed.
    coordinate_crop_pixel = pixel_to_crop_pixel(coordinate_pixel, crop_range)

    if disp_m2f is None:
        # expensive: fixed-point inversion of the whole field
        disp_m2f = inverse_disp(disp_f2m, threshold=0.01, max_iteration=20)

    # (C, d, h, w) -> (d, h, w, C) so the vector components form the trailing value axis
    field = np.moveaxis(disp_m2f.detach().cpu().numpy(), 0, -1)
    axes = [np.arange(n, dtype=np.float32) for n in field.shape[:-1]]
    interpolator = interpolate.RegularGridInterpolator(axes, field)

    # grid axes are (d, h, w) = (z, y, x), so query with the reversed point
    query = np.asarray(coordinate_crop_pixel[::-1], dtype=np.float64).reshape(1, -1)
    pred = interpolator(query)[0]
    return [pred[0] * pixel_spacing[0], pred[1] * pixel_spacing[1], pred[2] * pixel_spacing[2]]

In [4]:
# Inverted-field displacement ([x, y, z], scaled by pixel_spacing_) at one ITK point.
get_dvf(
    coordinate_itk=[74.23, 98.82, 6.48],
    origin_itk=origin_itk_,
    pixel_spacing=pixel_spacing_,
    crop_range=crop_range_,
    disp_f2m=disp_f2m_,
)


[-0.4760943556981456, -1.8668366722960572, 5.167355481570456]