In [1]:
import torch
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
import json
from scipy import interpolate

In [2]:
class SpatialTransformer(nn.Module):
    # 2D or 3d spatial transformer network to calculate the warped moving image

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.grid_dict = {}
        self.norm_coeff_dict = {}

    def forward(self, input_image, flow):
        '''
        input_image: (n, 1, h, w) or (n, 1, d, h, w)
        flow: (n, 2, h, w) or (n, 3, d, h, w)

        return:
            warped moving image, (n, 1, h, w) or (n, 1, d, h, w)
        '''
        img_shape = input_image.shape[2:]
        if img_shape in self.grid_dict:
            grid = self.grid_dict[img_shape]
            norm_coeff = self.norm_coeff_dict[img_shape]
        else:
            grids = torch.meshgrid([torch.arange(0, s) for s in img_shape])
            grid = torch.stack(grids[::-1],
                               dim=0)  # 2 x h x w or 3 x d x h x w, the data in second dimension is in the order of [w, h, d]
            grid = torch.unsqueeze(grid, 0)
            grid = grid.to(dtype=flow.dtype, device=flow.device)
            norm_coeff = 2. / (torch.tensor(img_shape[::-1], dtype=flow.dtype,
                                            device=flow.device) - 1.)  # the coefficients to map image coordinates to [-1, 1]
            self.grid_dict[img_shape] = grid
            self.norm_coeff_dict[img_shape] = norm_coeff
            # logging.info(f'\nAdd grid shape {tuple(img_shape)}')
        new_grid = grid + flow

        if self.dim == 2:
            new_grid = new_grid.permute(0, 2, 3, 1)  # n x h x w x 2
        elif self.dim == 3:
            new_grid = new_grid.permute(0, 2, 3, 4, 1)  # n x d x h x w x 3

        if len(input_image) != len(new_grid):
            # make the image shape compatable by broadcasting
            input_image += torch.zeros_like(new_grid)
            new_grid += torch.zeros_like(input_image)

        warped_input_img = F.grid_sample(input_image, new_grid * norm_coeff - 1., mode='bilinear', align_corners=True,
                                         padding_mode='border')
        return warped_input_img

def pixel_to_itk(pixel, origin, spacing):
    x = (pixel[0]-1)*spacing[0]+origin[0]
    y = (pixel[1]-1)*spacing[1]+origin[1]
    z = (pixel[2]-1)*spacing[2]+origin[2]
    return [x,y,z]

def itk_to_pixel(itk, origin, spacing):
    x = (itk[0] -origin[0])/spacing[0]+1
    y = (itk[1] -origin[1])/spacing[1]+1
    z = (itk[2] -origin[2])/spacing[2]+1
    return [x,y,z]

def pixel_to_crop_pixel(pixel, crop_range):
    x = pixel[0]-crop_range[2]
    y = pixel[1]-crop_range[1]
    z = pixel[2]-crop_range[0]
    return [x,y,z]

def inverse_disp(disp, threshold=0.01, max_iteration=20):
    '''
    compute the inverse field. implementation of "A simple fixed‐point approach to invert a deformation field"

    disp : (n, 2, h, w) or (n, 3, d, h, w) or (2, h, w) or (3, d, h, w)
        displacement field
    '''
    dim=3
    spatial_transformer=SpatialTransformer(dim)
    forward_disp = disp.detach().cuda()
    if disp.ndim < dim+2:
        forward_disp = torch.unsqueeze(forward_disp, 0)
    backward_disp = torch.zeros_like(forward_disp)
    backward_disp_old = backward_disp.clone()
    for i in range(max_iteration):
        backward_disp = -spatial_transformer(forward_disp, backward_disp)
        diff = torch.max(torch.abs(backward_disp - backward_disp_old)).item()
        if diff < threshold:
            break
        backward_disp_old = backward_disp.clone()
    if disp.ndim < dim + 2:
        backward_disp = torch.squeeze(backward_disp, 0)

    return backward_disp

# 从指定的形变场中提取指定位置的形变值
def get_dvf(coordinate_itk, origin, spacing, crop_range, disp_m2f):
    coordinate_pixel = itk_to_pixel(coordinate_itk,origin, spacing)
    coordinate_crop_pixel = pixel_to_crop_pixel(coordinate_pixel, crop_range)

    image_shape = disp_m2f.size()[1:]
    grid_tuple = [np.arange(grid_length, dtype=np.float32) for grid_length in image_shape]

    inter = interpolate.RegularGridInterpolator(grid_tuple,
                                                np.moveaxis(disp_m2f.detach().cpu().numpy(), 0, -1))
    pred = inter(np.flip(coordinate_crop_pixel,0))
    x = pred[0][0]*spacing[0]
    y = pred[0][1]*spacing[1]
    z = pred[0][2]*spacing[2]
    return [round(coordinate_itk[0]+x,2),
            round(coordinate_itk[1]+y,2),
            round(coordinate_itk[2]+z,2)]

def get_start_end_needle(str_ori):
    """
    input: [120.835 -60.0364 17.1076]  [55.7537 129.078 17.1076]
    func: 将上述string转成start和end两个数组
    output: start和end两个数组
    """
    index1 = str_ori.index("[")
    index2 = str_ori.index("]  [")
    str_start = str_ori[index1 + 1:index2]
    str_end = str_ori[index2 + 4:-1]
    x_start = float(str_start.split(' ')[0])
    y_start = float(str_start.split(' ')[1])
    z_start = float(str_start.split(' ')[2])

    x_end = float(str_end.split(' ')[0])
    y_end = float(str_end.split(' ')[1])
    z_end = float(str_end.split(' ')[2])
    return [x_start,y_start,z_start],[x_end,y_end,z_end]

def get_seed(str_ori):
    """
    input: [63.8889 105.439 17.1076]
    func: 将上述string转成数组
    output: 数组[x,y,z]
    """
    str_ori = str_ori[1:-1]
    x = float(str_ori.split(' ')[0])
    y = float(str_ori.split(' ')[1])
    z = float(str_ori.split(' ')[2])
    return [x,y,z]

def get_start_end_seed_list():
    start_list=[]
    end_list=[]
    seed_list=[]
    with open("plan.needle", "r") as fo:  # 打开文件
        for line in fo.readlines():                          #依次读取每行
            line = line.strip()                             #去掉每行头尾空白
            start, end = get_start_end_needle(line)
            start_list.append(start)
            end_list.append(end)
        # 关闭文件
        fo.close()

    with open("plan.seed", "r") as fo:  # 打开文件
        for line in fo.readlines():                          #依次读取每行
            line = line.strip()                             #去掉每行头尾空白
            seed = get_seed(line)
            seed_list.append(seed)
        # 关闭文件
        fo.close()
    return start_list,end_list,seed_list

In [3]:
origin_itk = [-250,-201,-158.5]
pixel_spacing = [0.9766, 0.9766, 5]
crop_range = [0,170,50]
disp_list=[]
for index in range(10):
    disp = inverse_disp(torch.load('./disp.pth')[index], threshold=0.01, max_iteration=20)
    disp_list.append(disp)
start_list, end_list, seed_list = get_start_end_seed_list()

In [4]:
'''
    从BUAA提供的规划数据和预测的形变场数据中提取在不同时刻的起点、终点、粒子靶点，写入json
'''
for time_index in range(10):
    disp = disp_list[time_index]
    article_info = {}
    data = json.loads(json.dumps(article_info))
    data['id'] = 3
    data['name'] = 'updatePath'
    pathArray=[]
    for index in range(len(start_list)):
        start =','.join('%s' %i for i in start_list[index])
        end_pred = get_dvf(coordinate_itk=end_list[index], origin=origin_itk,
                spacing=pixel_spacing,crop_range = crop_range,
                disp_m2f=disp)
        end = ','.join('%s' %i for i in end_pred)
        seeds=[]
        z = start_list[index][2]
        for seed in seed_list:
            if seed[2] == z:
                seed_pred = get_dvf(coordinate_itk=seed, origin=origin_itk,
                                    spacing=pixel_spacing,crop_range = crop_range,
                                    disp_m2f=disp)
                seeds.append(','.join('%s' %i for i in seed_pred))

        path={"end": end,"id": index+1,'seed':seeds ,'start': start}
        pathArray.append(path)

    data['pathArray']=pathArray
    data['type']=0

    article = json.dumps(data, ensure_ascii=False)

    print(article)
    with open(f'data{time_index}.json','w',encoding='utf-8') as f:
      f.write(article)

{"id": 3, "name": "updatePath", "pathArray": [{"end": "55.75,129.08,17.11", "id": 1, "seed": ["63.88,105.46,17.14", "67.14,96.0,17.15", "62.25,110.19,17.15", "70.39,86.53,17.2", "72.02,81.79,17.2", "73.65,77.05,17.15", "55.75,129.08,17.11", "68.76,91.27,17.17", "65.51,100.73,17.15", "60.62,114.92,17.16", "58.98,119.65,17.15", "57.37,124.37,17.11"], "start": "120.835,-60.0364,17.1076"}, {"end": "32.0,121.36,7.77", "id": 2, "seed": ["72.45,101.58,7.71", "54.48,110.37,7.73", "81.44,97.19,7.72", "45.5,114.76,7.73", "90.42,92.81,7.75", "85.93,94.99,7.73", "76.95,99.38,7.71", "67.96,103.78,7.71", "63.47,105.98,7.71", "58.98,108.17,7.71", "49.99,112.56,7.73", "41.01,116.95,7.77", "36.5,119.15,7.77", "32.0,121.36,7.77"], "start": "211.761,33.6063,7.75172"}, {"end": "41.31,124.53,22.35", "id": 3, "seed": ["49.46,100.89,22.37", "46.2,110.34,22.35", "52.71,91.42,22.38", "42.94,119.79,22.37", "54.33,86.68,22.39", "55.96,81.95,22.35", "51.09,96.16,22.34", "47.84,105.62,22.3", "44.57,115.07,22.38", 