# TO CHANGE: input/output paths and registration settings

In [None]:
# ---- Input / output files ----
# fixed: reference volume; move: volume warped onto the fixed one;
# out: where the warped moving volume is written.
D0 = '/n/boslfs02/LABS/lichtman_lab/donglai/ExM/Chi/'
f_vol_fix  = D0 + '20211130_larva1_barcode3.nd2'
f_vol_move = D0 + '20211130_larva1_barcode5.nd2'
f_vol_out  = D0 + '20211130_larva1_barcode5_warped.h5'

# ---- Registration settings ----
# Stages to run, in order; each entry names an elastix default parameter map:
#   ['rigid']              rotation + translation
#   ['affine']             rigid + scaling + shear
#   ['affine', 'bspline']  affine followed by a non-rigid B-spline deformation
m_transform_type = ['affine']

# Substring of the ND2 channel name to register on; it must match exactly one
# channel of each file.
m_channel_name = '405'

# Voxel size in micrometres.
# NOTE(review): presumably (z, y, x) order, matching the numpy volume axes;
# SimpleITK's SetSpacing expects (x, y, z), so confirm the run cell reverses it.
m_resolution = [4, 1.625, 1.625]

# TO RUN: no edits needed below

In [None]:
# library and utility functions
import numpy as np
import SimpleITK as sitk
from nd2reader import ND2Reader

def nd2ToVol(filename, channel_name='405', ratio=1):
    """Load one channel of an ND2 file as a uint16 numpy volume.

    Args:
        filename: path to the .nd2 file.
        channel_name: substring matched against the file's channel names;
            it must match exactly one channel.
        ratio: integer in-plane subsampling step (every ``ratio``-th row and
            column is kept); 1 keeps full resolution.

    Returns:
        np.uint16 array indexed (z, y, x), read at t=0, v=0.

    Raises:
        ValueError: if ``ratio`` < 1, or ``channel_name`` matches zero or
            several channels.
    """
    if ratio < 1:
        raise ValueError('ratio must be >= 1, got {}'.format(ratio))

    # The context manager closes the ND2 file handle once reading is done.
    with ND2Reader(filename) as vol:
        channel_names = vol.metadata['channels']
        print('Available channels:', channel_names)
        channel_ids = [i for i, name in enumerate(channel_names) if channel_name in name]
        if len(channel_ids) != 1:
            raise ValueError(
                'channel_name {!r} must match exactly one of {}, matched {}'.format(
                    channel_name, channel_names, len(channel_ids)))
        channel_id = channel_ids[0]

        # Count z-planes from the axis sizes rather than len(vol): len(vol) is
        # the length of the reader's iteration axis, which is time (not z) for
        # time-lapse files. The 'z' axis may be absent for single-plane files.
        num_z = vol.sizes.get('z', 1)
        height, width = vol[0].shape[:2]
        # [::ratio] keeps ceil(n / ratio) elements, so the buffer needs ceiling
        # division; floor division fails whenever n is not a multiple of ratio.
        out = np.zeros([num_z, -(-height // ratio), -(-width // ratio)], np.uint16)
        for z in range(num_z):
            out[z] = vol.get_frame_2D(c=channel_id, t=0, z=z, x=0, y=0, v=0)[::ratio, ::ratio]
    return out

In [None]:
# main function to run: register the moving volume onto the fixed volume with
# elastix, then save the warped volume and the transform parameters.
import os

elastixImageFilter = sitk.ElastixImageFilter()

# 1. set transformation parameters: one elastix default parameter map per
# entry of m_transform_type, applied in the listed order. A one-element vector
# behaves the same as passing that single map directly.
if not m_transform_type:
    raise ValueError('m_transform_type must list at least one transform')
parameterMapVector = sitk.VectorOfParameterMap()
for trans in m_transform_type:
    parameterMapVector.append(sitk.GetDefaultParameterMap(trans))
elastixImageFilter.SetParameterMap(parameterMapVector)

# 2. load volumes
# nd2ToVol returns a numpy (z, y, x) array, and GetImageFromArray maps that to
# a SimpleITK image whose axes are (x, y, z); SetSpacing therefore needs the
# spacing in reversed order.
# NOTE(review): assumes m_resolution is written in (z, y, x) order
# ([4, 1.625, 1.625] read as a 4 um z-step with 1.625 um pixels) - confirm.
spacing_xyz = [float(s) for s in reversed(m_resolution)]

img_np = nd2ToVol(f_vol_fix, m_channel_name)
print('vol-fix shape:', img_np.shape)
img = sitk.GetImageFromArray(img_np)
img.SetSpacing(spacing_xyz)
elastixImageFilter.SetFixedImage(img)

img_np = nd2ToVol(f_vol_move, m_channel_name)
print('vol-move shape:', img_np.shape)
img = sitk.GetImageFromArray(img_np)
img.SetSpacing(spacing_xyz)
elastixImageFilter.SetMovingImage(img)

# 3. compute transformation
elastixImageFilter.Execute()

# 4. save output
# The result image is float, and interpolation can produce values outside the
# uint16 range; Clamp bounds them to [0, 65535] while converting, which a
# plain Cast does not do.
result_img = sitk.Clamp(elastixImageFilter.GetResultImage(), sitk.sitkUInt16, 0, 65535)
sitk.WriteImage(result_img, f_vol_out)

# Write the parameters of every registration stage, in the order they were
# applied: stage 0 goes to '<stem>.tfm', stage i > 0 to '<stem>_<i>.tfm'.
out_stem = os.path.splitext(f_vol_out)[0]
for i, transform_map in enumerate(elastixImageFilter.GetTransformParameterMap()):
    suffix = '' if i == 0 else '_' + str(i)
    sitk.WriteParameterFile(transform_map, out_stem + suffix + '.tfm')