In [13]:
import functools
import re
import typing
import jax
import jax.numpy as jnp
from jax.scipy.spatial.transform import Rotation
from jax.scipy.spatial.transform import Rotation, Slerp
import jax.numpy as jnp


In [28]:
#adapted from https://stackoverflow.com/questions/74519927/best-way-to-rotate-and-translate-a-set-of-points-in-python

def rotate_and_translate(points: jnp.ndarray
                         , center_point: jnp.ndarray, rotation_vector: jnp.ndarray,
                         translation_vector: jnp.ndarray) -> jnp.ndarray:
    # rotation_matrix = Rotation.from_rotvec(rotation_vector).as_matrix()
    rotation_matrix =Rotation.from_euler('xyz', rotation_vector, degrees=False).as_matrix()
    return (points - center_point) @ rotation_matrix.T + center_point + translation_vector



# rotateMatrix = Rotation.from_rotvec(jnp.array([ jnp.pi,0.0, 0.0])).as_matrix()
points=jnp.array([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, -1.0, 0.0]])
center_point=jnp.array([0.0, 0.0, 0.0])
# newxy = (points-center_point) @ rotateMatrix.T + center_point

newxy=rotate_and_translate(points, center_point, jnp.array([ jnp.pi,0.0, 0.0]), jnp.array([2.0, 2.0, 2.0]))
newxy.round(2)

Array([[2., 2., 2.],
       [2., 1., 2.],
       [2., 3., 2.]], dtype=float32)

In [3]:
weights=jnp.array([1.0, 1.0, 1.0,2.0, 2.0, 2.0,3.0, 3.0, 3.0])
weights[0:3]
weights[3:6]

Array([2., 2., 2.], dtype=float32)

In [4]:
import functools
import re
import typing
import jax
import jax.numpy as jnp
from jax.scipy.spatial.transform import Rotation


def rotate_and_translate(points: jnp.ndarray
                         , center_point: jnp.ndarray, rotation_vector: jnp.ndarray,
                         translation_vector: jnp.ndarray) -> jnp.ndarray:
    # rotation_matrix = Rotation.from_rotvec(rotation_vector).as_matrix()

    rotation_matrix =Rotation.from_euler('xyz', rotation_vector, degrees=False).as_matrix()
    return (points - center_point) @ rotation_matrix.T + center_point + translation_vector


def get_fiducial_loos(weights,from_landmarsk,to_landmarks,image_shape):
    """
    first entries in in weights are :
        0-3 first rotation vector   
        3-6 translation vector   
    so we interpret the weights as input for transormations rotation;translation
    we apply this transformation to the fiducial points of moving image and 
    calculate the square distance between transformed fiducial points and fiducial points on fixed image           
    """
    center_point=(jnp.asarray(image_shape) - 1.) / 2
    res=rotate_and_translate(to_landmarks, center_point, weights[0:3],weights[3:6])
    #calculate the square distance between transformed fiducial points and fiducial points on fixed image 
    return jnp.sum(((from_landmarsk-res)**2).flatten())
    
    
def transform_image(image,weights):
    """
    first entries in in weights are :
    0-3 first rotation vector   
    3-6 translation vector   
    so we interpret the weights as input for transormations rotation;translation;rotation
    """    
    r = Rotation.from_rotvec(weights[0:3])
    image=r.apply(image)
    image=jax.image.scale_and_translate(image, image.shape,jnp.array([0,1,2]), jnp.array([1.0,1.0,1.0]), weights[3:6], "bicubic")
    return image





In [5]:
import SimpleITK as sitk
import jax.numpy as jnp


def resample_ct_to_suv(ct: sitk.Image, suv: sitk.Image) -> sitk.Image:
    """
    Resample a CT image to the same size as a SUV image
    """
    resampler = sitk.ResampleImageFilter()
    resampler.SetInterpolator(sitk.sitkBSpline)
    resampler.SetOutputSpacing(suv.GetSpacing())
    resampler.SetSize(suv.GetSize())
    resampler.SetOutputDirection(suv.GetDirection())
    resampler.SetOutputOrigin(suv.GetOrigin())
    ct= resampler.Execute(ct)
    
    ct_arr=sitk.GetArrayFromImage(ct)
    suv_arr=sitk.GetArrayFromImage(suv)
    
    res=jnp.stack([jnp.array(suv_arr),jnp.array(ct_arr)],axis=0)
    return res

def load_landmark_data(folder_path:str):
    """
    given path to folder with landmarks files and images after general registaration we load the data
    we want to first load the suv and ct images resample them to the same size and then load the landmarks
    we need to load separately study 0 and 1 
    the output should be in form of a dictionary with keys 'study_0','study_1','From`,`To`' where `From` and `To` are the landmarks
    all the data should be in form of jnp.arrays
    """
    ct_0=sitk.ReadImage(folder_path+'/study_0_ct_soft.nii.gz')
    suv_0=sitk.ReadImage(folder_path+'/study_0_SUVS.nii.gz')
    # Resample ct_0 to match ct_1
    arr_0 = resample_ct_to_suv(ct_0, suv_0)
            
    ct_1=sitk.ReadImage(folder_path+'/study_1_ct_soft.nii.gz')
    suv_1=sitk.ReadImage(folder_path+'/study_1_SUVS.nii.gz')    
    arr_1 = resample_ct_to_suv(ct_1, suv_1)

    return {'study_0':arr_0,'study_1':arr_1, 'From':jnp.load(folder_path+'/From.npy'),'To':jnp.load(folder_path+'/To.npy')}




folder_path='/root/data/pat_2/general_transform'
load_landmark_data(folder_path)

# cp /root/data/pat_2/To.npy /root/data/pat_2/general_transform/To.npy

{'study_0': Array([[[[0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          ...,
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.]],
 
         [[0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          ...,
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.]],
 
         [[0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          ...,
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.]],
 
         ...,
 
         [[0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          ...,
          [0., 0., 0., ..., 0., 0., 0.],
  

In [6]:
import os
import jax
import jax.numpy as jnp

folder_path = "/root/data"
folder_names = [name for name in os.listdir(folder_path) if os.path.isdir(os.path.join(folder_path, name))]
folder_names=list(filter(lambda x: re.match(r'pat_\d+',x),folder_names))
folder_names=list(map(lambda x: f"{folder_path}/{x}/general_transform",folder_names))

curr_folder='/root/data/pat_2/general_transform'
data_dict=load_landmark_data(curr_folder)

# v_transform_image=jax.vmap(transform_image , in_axes = (0,None) )

def get_data_for_pretained_model(image_dat,landmark_data, rng):
    """
    given the image data and landmark data we return the data in the format that can be used for the pretrained model
    so we make a random transformation of the image and the landmarks and return the transformed image and landmarks 
    """
    #random weights for the transformation
    weights = jax.random.uniform(rng, shape=(6,), minval=0, maxval=300)
    print(f"image_dat {image_dat.shape} landmark_data {landmark_data.shape} weights {weights.shape}")
    tansf_image=jnp.stack([transform_image(image_dat[0,:,:,:],weights),transform_image(image_dat[1,:,:,:],weights)   ])
    center_point=(jnp.asarray(image_dat) - 1.) / 2
    center_point=center_point[0:3]
    transf_landmarsk=rotate_and_translate(landmark_data, center_point,weights[0:3],weights[3:6])


    return tansf_image,transf_landmarsk

# krowa transform the image save it and check if it tranforms correctly
 

random_seed = 42
rng = jax.random.PRNGKey(random_seed)
get_data_for_pretained_model(data_dict['study_0'],data_dict['From'], rng)

image_dat (2, 425, 200, 200) landmark_data (5, 3) weights (6,)


ValueError: inconsistent size for core dimension 'm': 200 vs 3 on vectorized function with excluded=frozenset() and signature='(m,m),(m),()->(m)'

In [34]:
import chex
from typing import Callable, Sequence, Tuple, Union
from dm_pix._src import augment

def affine_transform(
    image: chex.Array,
    matrix: chex.Array,
    *,
    offset: Union[chex.Array, chex.Numeric] = 0.,
    order: int = 1,
    mode: str = "constant",
    cval: float = 0.0,
) -> chex.Array:
  """Applies an affine transformation given by matrix.

  """

  meshgrid = jnp.meshgrid(*[jnp.arange(size) for size in image.shape],
                          indexing="ij")
  print(f"meshgrid {meshgrid} image.shape {image.shape}")
  indices = jnp.concatenate(
      [jnp.expand_dims(x, axis=-1) for x in meshgrid], axis=-1)

  zz, yy, xx = meshgrid
  z_center, y_center,x_center= (jnp.asarray(image.shape) - 1.) / 2.
  indices = jnp.array([xx - x_center, yy - y_center, zz - z_center])

  # offset = matrix[:image.ndim, image.ndim]
  # matrix = matrix[:image.ndim, :image.ndim]

  coordinates = jnp.tensordot(matrix, indices, axes=((1), (0)))
  # coordinates = indices @ jnp.linalg.inv(matrix).T
  # coordinates = jnp.moveaxis(coordinates, source=-1, destination=0)

  # Alter coordinates to account for offset.
  # offset = jnp.full((3,), fill_value=offset)
  # coordinates += jnp.reshape(a=offset, newshape=(*offset.shape, 1, 1, 1))

  interpolate_function = augment._get_interpolate_function(
      mode=mode,
      order=order,
      cval=cval,
  )
  return interpolate_function(image, coordinates)

angles=jnp.array([1.1, 1.8, 2.1])
affine_transform(jnp.zeros((38,33,23)),Rotation.from_euler('xyz', angles, degrees=False).inv().as_matrix())
# b=rotate_3d(angles[0],angles[1],angles[2])
# print(f"a {a.round(2)} \n b {b.round(2)}")


meshgrid [Array([[[ 0,  0,  0, ...,  0,  0,  0],
        [ 0,  0,  0, ...,  0,  0,  0],
        [ 0,  0,  0, ...,  0,  0,  0],
        ...,
        [ 0,  0,  0, ...,  0,  0,  0],
        [ 0,  0,  0, ...,  0,  0,  0],
        [ 0,  0,  0, ...,  0,  0,  0]],

       [[ 1,  1,  1, ...,  1,  1,  1],
        [ 1,  1,  1, ...,  1,  1,  1],
        [ 1,  1,  1, ...,  1,  1,  1],
        ...,
        [ 1,  1,  1, ...,  1,  1,  1],
        [ 1,  1,  1, ...,  1,  1,  1],
        [ 1,  1,  1, ...,  1,  1,  1]],

       [[ 2,  2,  2, ...,  2,  2,  2],
        [ 2,  2,  2, ...,  2,  2,  2],
        [ 2,  2,  2, ...,  2,  2,  2],
        ...,
        [ 2,  2,  2, ...,  2,  2,  2],
        [ 2,  2,  2, ...,  2,  2,  2],
        [ 2,  2,  2, ...,  2,  2,  2]],

       ...,

       [[35, 35, 35, ..., 35, 35, 35],
        [35, 35, 35, ..., 35, 35, 35],
        [35, 35, 35, ..., 35, 35, 35],
        ...,
        [35, 35, 35, ..., 35, 35, 35],
        [35, 35, 35, ..., 35, 35, 35],
        [35, 35, 35, .

Array([[[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       ...,

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0.

In [None]:
# function ``` def transform_image(image,weights):
#     """
#     first entries in in weights are :
#     0-3 first rotation vector   
#     3-6 translation vector   
#     so we interpret the weights as input for transormations rotation;translation;rotation
#     """    
#     r = Rotation.from_rotvec(weights[0:3])
#     image=r.apply(image)
#     image=jax.image.scale_and_translate(image, image.shape,jnp.array([0,1,2]), jnp.array([1.0,1.0,1.0]), weights[3:6], "bicubic")
#     return image v_transform_image=jax.vmap(transform_image , in_axes = (0,None) ) ```` give error ``` ValueError: inconsistent size for core dimension 'm': 200 vs 3 on vectorized function with excluded=frozenset() and signature='(m,m),(m),()->(m)'``` correct function

In [12]:
r = Rotation.from_rotvec(jnp.array([2*jnp.pi, 0.0, 0.0])) ####### it is in radians
# imagee= jnp.zeros((3,3,3))

# Create consecutive array of integers
imagee = jnp.arange(27)
imagee = imagee.reshape((4, 4, 4))

print(imagee)
imagee=r.apply(imagee)
# imagee=jax.image.scale_and_translate(imagee, imagee.shape,jnp.array([0,1,2]), jnp.array([1.0,1.0,1.0]), jnp.array([0.1,0.0,0.0]), "bicubic")
imagee

InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (27,) and (4, 4, 4)