In [16]:
import torch
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
from scipy import interpolate

In [27]:
class SpatialTransformer(nn.Module):
    """2D or 3D spatial transformer that warps an image with a dense displacement field.

    The identity sampling grid and the coefficients that map voxel coordinates
    to the [-1, 1] range used by ``F.grid_sample`` are cached per spatial shape,
    so repeated calls with the same shape do not rebuild them.
    """

    def __init__(self, dim):
        """
        dim: number of spatial dimensions of the images to warp (2 or 3)
        """
        super().__init__()
        self.dim = dim
        self.grid_dict = {}  # spatial shape -> identity grid of shape (1, dim, *spatial)
        self.norm_coeff_dict = {}  # spatial shape -> 2 / (size - 1) per axis, in [w, h(, d)] order

    def forward(self, input_image, flow):
        '''
        input_image: (n, 1, h, w) or (n, 1, d, h, w)
        flow: (n, 2, h, w) or (n, 3, d, h, w); channel 0 is the displacement
              along w, channel 1 along h, channel 2 (3D only) along d, in voxels

        If the two batch sizes differ, the operand whose batch size is 1 is
        broadcast to the other one. Neither argument is modified in place.

        return:
            warped moving image, (n, 1, h, w) or (n, 1, d, h, w)

        raises:
            ValueError if the batch sizes differ and neither of them is 1
        '''
        img_shape = input_image.shape[2:]
        if img_shape in self.grid_dict:
            # The cache may have been filled with another dtype/device;
            # .to() is a no-op when both already match.
            grid = self.grid_dict[img_shape].to(dtype=flow.dtype, device=flow.device)
            norm_coeff = self.norm_coeff_dict[img_shape].to(dtype=flow.dtype, device=flow.device)
        else:
            grids = torch.meshgrid([torch.arange(0, s) for s in img_shape])
            # 2 x h x w or 3 x d x h x w; reversing puts the channels in
            # [w, h, d] order, which is the (x, y, z) order grid_sample expects.
            grid = torch.stack(grids[::-1], dim=0)
            grid = torch.unsqueeze(grid, 0)
            grid = grid.to(dtype=flow.dtype, device=flow.device)
            # Coefficients that map voxel coordinates onto [0, 2]; the "- 1."
            # applied before sampling shifts that to [-1, 1].
            norm_coeff = 2. / (torch.tensor(img_shape[::-1], dtype=flow.dtype,
                                            device=flow.device) - 1.)
            self.grid_dict[img_shape] = grid
            self.norm_coeff_dict[img_shape] = norm_coeff
        new_grid = grid + flow

        if self.dim == 2:
            new_grid = new_grid.permute(0, 2, 3, 1)  # n x h x w x 2
        elif self.dim == 3:
            new_grid = new_grid.permute(0, 2, 3, 4, 1)  # n x d x h x w x 3

        n_img, n_grid = input_image.shape[0], new_grid.shape[0]
        if n_img != n_grid:
            # Broadcast along the batch axis with expand(): it creates a view,
            # so the caller's tensors are left untouched.
            if n_img == 1:
                input_image = input_image.expand(n_grid, *input_image.shape[1:])
            elif n_grid == 1:
                new_grid = new_grid.expand(n_img, *new_grid.shape[1:])
            else:
                raise ValueError(
                    'batch sizes of input_image (%d) and flow (%d) cannot be broadcast'
                    % (n_img, n_grid))

        warped_input_img = F.grid_sample(input_image, new_grid * norm_coeff - 1., mode='bilinear',
                                         align_corners=True, padding_mode='border')
        return warped_input_img

def pixel_to_itk(pixel, origin, spacing):
    """Convert an (x, y, z) pixel index into a physical coordinate.

    Per axis the result is ``(pixel - 1) * spacing + origin``, so pixel index 1
    maps onto ``origin``. Only the first three entries of each argument are read.

    return: [x, y, z] list
    """
    return [(pixel[axis] - 1) * spacing[axis] + origin[axis] for axis in range(3)]

def itk_to_pixel(itk, origin, spacing):
    """Convert a physical (x, y, z) coordinate into a pixel index.

    Per axis the result is ``(itk - origin) / spacing + 1``, so ``origin``
    maps onto pixel index 1. Only the first three entries of each argument
    are read.

    return: [x, y, z] list
    """
    return [(itk[axis] - origin[axis]) / spacing[axis] + 1 for axis in range(3)]

def pixel_to_crop_pixel(pixel, crop_range):
    """Shift an (x, y, z) pixel index into the coordinate frame of a cropped volume.

    ``crop_range`` is read in reverse order relative to ``pixel``: entry 2 is
    subtracted from x, entry 1 from y and entry 0 from z.

    return: [x, y, z] list
    """
    return [pixel[axis] - crop_range[2 - axis] for axis in range(3)]

def inverse_disp(disp, threshold=0.01, max_iteration=20, dim=3):
    '''
    Compute the inverse of a displacement field by fixed-point iteration.
    Implementation of "A simple fixed-point approach to invert a deformation field".

    disp : (n, 2, h, w) or (n, 3, d, h, w) or (2, h, w) or (3, d, h, w)
        displacement field
    threshold : the iteration stops once the largest absolute change between
        two successive estimates falls below this value
    max_iteration : upper bound on the number of fixed-point iterations
    dim : number of spatial dimensions of the field (2 or 3). Defaults to 3;
        pass dim=2 for 2D fields.

    return:
        inverse displacement field with the same rank as ``disp``. The result
        is on the GPU when CUDA is available, otherwise on the device of ``disp``.

    raises:
        ValueError if dim is neither 2 nor 3
    '''
    if dim not in (2, 3):
        raise ValueError('dim must be 2 or 3, got %r' % (dim,))

    spatial_transformer = SpatialTransformer(dim)
    forward_disp = disp.detach()
    # Prefer the GPU, but keep working on machines without CUDA.
    if torch.cuda.is_available():
        forward_disp = forward_disp.cuda()

    is_unbatched = disp.ndim < dim + 2
    if is_unbatched:
        forward_disp = torch.unsqueeze(forward_disp, 0)

    backward_disp = torch.zeros_like(forward_disp)
    for _ in range(max_iteration):
        # Each update builds a new tensor, so keeping a reference to the
        # previous estimate is enough; no clone is needed.
        backward_disp_old = backward_disp
        # Fixed-point update: v(x) <- -u(x + v(x))
        backward_disp = -spatial_transformer(forward_disp, backward_disp)
        diff = torch.max(torch.abs(backward_disp - backward_disp_old)).item()
        if diff < threshold:
            break

    if is_unbatched:
        backward_disp = torch.squeeze(backward_disp, 0)

    return backward_disp

def get_dvf(coordinate_itk, origin, spacing,crop_range, disp_m2f):
    """Move a physical point by the displacement interpolated from ``disp_m2f``.

    coordinate_itk: physical (x, y, z) coordinate of the point
    origin, spacing: (x, y, z) values forwarded to itk_to_pixel; ``spacing``
        also scales the interpolated displacement
    crop_range: crop offsets forwarded to pixel_to_crop_pixel
    disp_m2f: tensor whose first axis holds the displacement components and
        whose remaining axes are the spatial grid
        (presumably (3, d, h, w) in voxel units; verify against the caller)

    return: [x, y, z], the input coordinate plus the scaled displacement,
        each component rounded to 2 decimals
    """
    # NOTE(review): itk_to_pixel adds 1 to every index, while the grid axes
    # built below start at 0; confirm that crop_range accounts for this offset.
    crop_pixel = pixel_to_crop_pixel(itk_to_pixel(coordinate_itk, origin, spacing),
                                     crop_range)

    # One coordinate axis per spatial dimension of the field.
    axes = [np.arange(axis_len, dtype=np.float32) for axis_len in disp_m2f.size()[1:]]
    # Move the component axis last so each grid node carries its displacement vector.
    field = np.moveaxis(disp_m2f.detach().cpu().numpy(), 0, -1)
    sampler = interpolate.RegularGridInterpolator(axes, field)

    # The crop pixel is (x, y, z) but the grid axes follow the tensor layout,
    # so the point is reversed before the lookup.
    displacement = sampler(np.flip(crop_pixel, 0))[0]

    return [round(coordinate_itk[axis] + displacement[axis] * spacing[axis], 2)
            for axis in range(3)]

def get_start_end_needle(str_ori):
    """
    Parse one needle line into its start and end points.

    input: a string such as "[120.835 -60.0364 17.1076]  [55.7537 129.078 17.1076]"
    func: split the string into the two bracketed groups and convert each to
          floats. Any amount of whitespace is accepted between the groups,
          between the numbers and around the line.
    output: two lists, [x_start, y_start, z_start] and [x_end, y_end, z_end]

    raises:
        ValueError if a bracket is missing or a component is not a number
        IndexError if a group holds fewer than three components
    """
    # Locate the brackets one after another instead of relying on an exact
    # "]  [" separator.
    first_open = str_ori.index("[")
    first_close = str_ori.index("]", first_open)
    second_open = str_ori.index("[", first_close)
    second_close = str_ori.index("]", second_open)

    # split() without an argument collapses runs of spaces and tabs.
    start_parts = str_ori[first_open + 1:first_close].split()
    end_parts = str_ori[second_open + 1:second_close].split()

    start = [float(start_parts[i]) for i in range(3)]
    end = [float(end_parts[i]) for i in range(3)]
    return start, end

def get_seed(str_ori):
    """
    Parse one seed line into a coordinate.

    input: a string such as "[63.8889 105.439 17.1076]"
    func: strip surrounding whitespace, remove the brackets and convert the
          whitespace-separated components to floats
    output: list [x, y, z]

    raises:
        ValueError if the text is not enclosed in "[" and "]" or a component
            is not a number
        IndexError if fewer than three components are present
    """
    text = str_ori.strip()
    if len(text) < 2 or text[0] != "[" or text[-1] != "]":
        raise ValueError("expected a seed of the form '[x y z]', got %r" % (str_ori,))

    # split() without an argument collapses runs of spaces and tabs.
    parts = text[1:-1].split()
    return [float(parts[0]), float(parts[1]), float(parts[2])]

def get_start_end_seed_list(needle_path="plan.needle", seed_path="plan.seed"):
    """
    Read the needle and seed plan files.

    needle_path: file with one needle per line, parsed by get_start_end_needle
    seed_path: file with one seed per line, parsed by get_seed

    Blank lines in either file are skipped.

    return: (start_list, end_list, seed_list), each a list of [x, y, z] lists
    """
    start_list = []
    end_list = []
    seed_list = []

    # The with-statement closes the file, so no explicit close() is needed.
    with open(needle_path, "r") as fo:
        for line in fo:
            line = line.strip()  # drop leading/trailing whitespace
            if not line:
                continue
            start, end = get_start_end_needle(line)
            start_list.append(start)
            end_list.append(end)

    with open(seed_path, "r") as fo:
        for line in fo:
            line = line.strip()  # drop leading/trailing whitespace
            if not line:
                continue
            seed_list.append(get_seed(line))

    return start_list, end_list, seed_list

In [24]:
# Geometry of the image volume; both lists are indexed (x, y, z) by itk_to_pixel and get_dvf.
origin_itk = [-250,-201,-158.5]
pixel_spacing = [0.9766, 0.9766, 5]
# Crop offsets in (z, y, x) order: pixel_to_crop_pixel subtracts entry 2 from x,
# entry 1 from y and entry 0 from z.
crop_range = [0,170,50]
# Entry 5 of the object stored in disp.pth is used as the displacement field.
# NOTE(review): the name suggests a fixed-to-moving field; confirm against the code that wrote disp.pth.
# NOTE(review): torch.load unpickles the file, so only load files from a trusted source.
disp_f2m = torch.load('./disp.pth')[5]
# Invert the field with the fixed-point scheme implemented in inverse_disp.
disp_m2f = inverse_disp(disp_f2m, threshold=0.01, max_iteration=20)
# Needle start/end points and seed positions parsed from plan.needle and plan.seed.
start_list, end_list, seed_list = get_start_end_seed_list()

In [28]:
# Map every needle end point through the inverted field and print it before and after.
# get_dvf returns the displaced coordinate (input plus displacement, rounded to
# 2 decimals), so `dvf` holds a position rather than a displacement vector.
for end in end_list:
    print('原坐标：',end)  # label reads "original coordinate"
    dvf = get_dvf(coordinate_itk=end, origin=origin_itk,
            spacing=pixel_spacing,crop_range = crop_range,
            disp_m2f=disp_m2f)
    print('新坐标：',dvf)  # label reads "new coordinate"


原坐标： [55.7537, 129.078, 17.1076]
新坐标： [55.4, 126.08, 23.03]
原坐标： [32.0171, 121.311, 7.75172]
新坐标： [32.11, 119.17, 14.61]
原坐标： [41.3269, 124.507, 22.3141]
新坐标： [41.57, 121.58, 28.42]
原坐标： [60.4165, 125.073, 21.1303]
新坐标： [60.1, 121.88, 27.49]
原坐标： [37.187, 123.36, -2.52513]
新坐标： [37.14, 121.26, 5.46]
原坐标： [45.9483, 110.915, 36.7143]
新坐标： [46.15, 109.35, 42.39]
原坐标： [77.3006, 122.53, 3.4654]
新坐标： [76.86, 119.27, 9.79]
原坐标： [30.8012, 119.034, 0.626806]
新坐标： [30.83, 116.92, 8.47]
原坐标： [52.3672, 118.42, 36.0254]
新坐标： [52.4, 115.88, 41.82]
原坐标： [30.4098, 111.339, 37.4032]
新坐标： [30.78, 109.91, 41.76]
原坐标： [65.3505, 125.655, 26.8033]
新坐标： [65.11, 121.95, 33.74]
原坐标： [41.225, 124.673, -14.6858]
新坐标： [41.17, 123.15, -6.64]
原坐标： [48.773, 129.703, -22.9932]
新坐标： [48.94, 128.11, -14.35]
原坐标： [29.244, 117.653, 16.4379]
新坐标： [29.24, 115.57, 21.71]
原坐标： [44.5351, 106.11, 54.5189]
新坐标： [44.74, 105.07, 59.51]
原坐标： [77.21, 117.818, -11.1537]
新坐标： [76.68, 113.93, -3.8]
