In [1]:
import functools
import re
import typing
import jax
import jax.numpy as jnp
from jax.scipy.spatial.transform import Rotation


In [2]:
#adapted from https://stackoverflow.com/questions/74519927/best-way-to-rotate-and-translate-a-set-of-points-in-python

def rotate_and_translate(points: jnp.ndarray
                         , center_point: jnp.ndarray, rotation_vector: jnp.ndarray,
                         translation_vector: jnp.ndarray) -> jnp.ndarray:
    rotation_matrix = Rotation.from_rotvec(rotation_vector).as_matrix()
    return (points - center_point) @ rotation_matrix.T + center_point + translation_vector



# rotateMatrix = Rotation.from_rotvec(jnp.array([ jnp.pi,0.0, 0.0])).as_matrix()
points=jnp.array([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, -1.0, 0.0]])
center_point=jnp.array([0.0, 0.0, 0.0])
# newxy = (points-center_point) @ rotateMatrix.T + center_point

newxy=rotate_and_translate(points, center_point, jnp.array([ jnp.pi,0.0, 0.0]), jnp.array([2.0, 2.0, 2.0]))
newxy.round(2)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Array([[2., 2., 2.],
       [2., 1., 2.],
       [2., 3., 2.]], dtype=float32)

In [3]:
weights=jnp.array([1.0, 1.0, 1.0,2.0, 2.0, 2.0,3.0, 3.0, 3.0])
weights[0:3]
weights[3:6]

Array([2., 2., 2.], dtype=float32)

In [4]:
import functools
import re
import typing
import jax
import jax.numpy as jnp
from jax.scipy.spatial.transform import Rotation


def rotate_and_translate(points: jnp.ndarray
                         , center_point: jnp.ndarray, rotation_vector: jnp.ndarray,
                         translation_vector: jnp.ndarray) -> jnp.ndarray:
    rotation_matrix = Rotation.from_rotvec(rotation_vector).as_matrix()
    return (points - center_point) @ rotation_matrix.T + center_point + translation_vector


def get_fiducial_loos(weights,from_landmarsk,to_landmarks,image_shape):
    """
    first entries in in weights are :
        0-3 first rotation vector   
        3-6 translation vector   
    so we interpret the weights as input for transormations rotation;translation
    we apply this transformation to the fiducial points of moving image and 
    calculate the square distance between transformed fiducial points and fiducial points on fixed image           
    """
    center_point=(jnp.asarray(image_shape) - 1.) / 2
    res=rotate_and_translate(to_landmarks, center_point, weights[0:3],weights[3:6])
    #calculate the square distance between transformed fiducial points and fiducial points on fixed image 
    return jnp.sum(((from_landmarsk-res)**2).flatten())
    
    
def transform_image(image,weights):
    """
    first entries in in weights are :
    0-3 first rotation vector   
    3-6 translation vector   
    so we interpret the weights as input for transormations rotation;translation;rotation
    
    """    
    r = Rotation.from_rotvec(weights[0:3])
    image=r.apply(image)
    image=jax.image.scale_and_translate(image, image.shape,jnp.array([0,1,2]), jnp.array([1.0,1.0,1.0]), weights[3:6], "bicubic")
    return image





In [5]:
import SimpleITK as sitk
import jax.numpy as jnp


def resample_ct_to_suv(ct: sitk.Image, suv: sitk.Image) -> sitk.Image:
    """
    Resample a CT image to the same size as a SUV image
    """
    resampler = sitk.ResampleImageFilter()
    resampler.SetInterpolator(sitk.sitkBSpline)
    resampler.SetOutputSpacing(suv.GetSpacing())
    resampler.SetSize(suv.GetSize())
    resampler.SetOutputDirection(suv.GetDirection())
    resampler.SetOutputOrigin(suv.GetOrigin())
    ct= resampler.Execute(ct)
    
    ct_arr=sitk.GetArrayFromImage(ct)
    suv_arr=sitk.GetArrayFromImage(suv)
    
    res=jnp.stack([jnp.array(suv_arr),jnp.array(ct_arr)],axis=0)
    return res

def load_landmark_data(folder_path:str):
    """
    given path to folder with landmarks files and images after general registaration we load the data
    we want to first load the suv and ct images resample them to the same size and then load the landmarks
    we need to load separately study 0 and 1 
    the output should be in form of a dictionary with keys 'study_0','study_1','From`,`To`' where `From` and `To` are the landmarks
    all the data should be in form of jnp.arrays
    """
    ct_0=sitk.ReadImage(folder_path+'/study_0_ct_soft.nii.gz')
    suv_0=sitk.ReadImage(folder_path+'/study_0_SUVS.nii.gz')
    # Resample ct_0 to match ct_1
    arr_0 = resample_ct_to_suv(ct_0, suv_0)
            
    ct_1=sitk.ReadImage(folder_path+'/study_1_ct_soft.nii.gz')
    suv_1=sitk.ReadImage(folder_path+'/study_1_SUVS.nii.gz')    
    arr_1 = resample_ct_to_suv(ct_1, suv_1)

    return {'study_0':arr_0,'study_1':arr_1, 'From':jnp.load(folder_path+'/From.npy'),'To':jnp.load(folder_path+'/To.npy')}




folder_path='/root/data/pat_2/general_transform'
load_landmark_data(folder_path)

# cp /root/data/pat_2/To.npy /root/data/pat_2/general_transform/To.npy

{'study_0': Array([[[[0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          ...,
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.]],
 
         [[0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          ...,
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.]],
 
         [[0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          ...,
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.]],
 
         ...,
 
         [[0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          ...,
          [0., 0., 0., ..., 0., 0., 0.],
  

In [10]:
import os
import jax
import jax.numpy as jnp

folder_path = "/root/data"
folder_names = [name for name in os.listdir(folder_path) if os.path.isdir(os.path.join(folder_path, name))]
folder_names=list(filter(lambda x: re.match(r'pat_\d+',x),folder_names))
folder_names=list(map(lambda x: f"{folder_path}/{x}/general_transform",folder_names))

curr_folder='/root/data/pat_2/general_transform'
data_dict=load_landmark_data(curr_folder)

def get_data_for_pretained_model(image_dat,landmark_data, rng):
    """
    given the image data and landmark data we return the data in the format that can be used for the pretrained model
    so we make a random transformation of the image and the landmarks and return the transformed image and landmarks 
    """

    #random weights for the transformation
    weights = jax.random.uniform(rng, shape=(6,), minval=0, maxval=300)

    
    transform_image(image_dat,weights)

    rotate_and_translate(points: jnp.ndarray
                            , center_point: jnp.ndarray, rotation_vector: jnp.ndarray,
                            translation_vector: jnp.ndarray)


    return data_dict['study_0'],data_dict['study_1'],data_dict['From'],data_dict['To']


random_seed = 42
rng = jax.random.PRNGKey(random_seed)

['/root/data/pat_9/general_transform', '/root/data/pat_20/general_transform', '/root/data/pat_18/general_transform', '/root/data/pat_12/general_transform', '/root/data/pat_31/general_transform', '/root/data/pat_24/general_transform', '/root/data/pat_4/general_transform', '/root/data/pat_11/general_transform', '/root/data/pat_28/general_transform', '/root/data/pat_25/general_transform', '/root/data/pat_29/general_transform', '/root/data/pat_15/general_transform', '/root/data/pat_14/general_transform', '/root/data/pat_5/general_transform', '/root/data/pat_27/general_transform', '/root/data/pat_22/general_transform', '/root/data/pat_21/general_transform', '/root/data/pat_3/general_transform', '/root/data/pat_10/general_transform', '/root/data/pat_23/general_transform', '/root/data/pat_8/general_transform', '/root/data/pat_6/general_transform', '/root/data/pat_13/general_transform', '/root/data/pat_7/general_transform', '/root/data/pat_19/general_transform', '/root/data/pat_16/general_tran

In [6]:
r = Rotation.from_rotvec(jnp.array([2*jnp.pi, 0.0, 0.0])) ####### it is in radians
# imagee= jnp.zeros((3,3,3))

# Create consecutive array of integers
imagee = jnp.arange(27)
imagee = imagee.reshape((3, 3, 3))

print(imagee)
imagee=r.apply(imagee)
# imagee=jax.image.scale_and_translate(imagee, imagee.shape,jnp.array([0,1,2]), jnp.array([1.0,1.0,1.0]), jnp.array([0.1,0.0,0.0]), "bicubic")
imagee

[[[ 0  1  2]
  [ 3  4  5]
  [ 6  7  8]]

 [[ 9 10 11]
  [12 13 14]
  [15 16 17]]

 [[18 19 20]
  [21 22 23]
  [24 25 26]]]


Array([[[ 0.        ,  0.99999964,  2.0000002 ],
        [ 3.        ,  3.999999  ,  5.0000005 ],
        [ 6.        ,  6.9999986 ,  8.000001  ]],

       [[ 9.        ,  9.999998  , 11.000002  ],
        [12.        , 12.999997  , 14.000002  ],
        [15.        , 15.999997  , 17.000002  ]],

       [[18.        , 18.999996  , 20.000004  ],
        [21.        , 21.999996  , 23.000004  ],
        [24.        , 24.999996  , 26.000004  ]]], dtype=float32)