# TO CHANGE

In [None]:
# Input / output volumes; every path lives under the data directory D0.
D0 = '/n/boslfs02/LABS/lichtman_lab/donglai/ExM/Chi/'
f_vol_fix, f_vol_move, f_vol_out = (
    D0 + basename
    for basename in (
        '20211130_larva1_barcode3.nd2',        # fixed (reference) volume
        '20211130_larva1_barcode5.nd2',        # moving volume to be aligned
        '20211130_larva1_barcode5_warped.h5',  # warped result
    )
)

# Registration stages, run in the listed order:
#   ['rigid']             -> rigid transformation (rotation + translation)
#   ['affine']            -> affine transformation (rigid + scaling/shear)
#   ['affine', 'bspline'] -> non-rigid transformation (affine + b-spline)
m_transform_type = ['affine']
# Channel used to drive the registration; a partial name is enough.
m_channel_name = '405'
# Voxel size in um, given in xyz order (the image volume itself is zyx).
m_resolution = [1.625, 1.625, 4]

# TO RUN

In [None]:
# library and utility functions
import numpy as np
from nd2reader import ND2Reader
import h5py

def nd2ToVol(filename, channel_name='405 SD', ratio=1):
    """Load one channel of an ND2 file into a uint16 volume (zyx order).

    Args:
        filename: path of the .nd2 file to read.
        channel_name: full or partial channel name; it must match exactly
            one entry of the file's ``metadata['channels']``.
        ratio: in-plane subsampling step; every ``ratio``-th row and column
            of each frame is kept (``1`` keeps full resolution).

    Returns:
        ``np.uint16`` array with one 2D frame per index along axis 0.

    Raises:
        ValueError: if ``channel_name`` matches zero or several channels.
    """
    # Context manager so the file handle is released even if reading fails.
    with ND2Reader(filename) as vol:
        channel_names = vol.metadata['channels']
        print('Availabel channels:', channel_names)
        matches = [idx for idx, name in enumerate(channel_names) if channel_name in name]
        if len(matches) != 1:
            raise ValueError(
                'channel_name %r must match exactly one channel in %r (matched %d)'
                % (channel_name, channel_names, len(matches)))
        channel_id = matches[0]

        # NOTE(review): len(vol) is used as the number of z-slices; this
        # relies on the reader iterating over z by default - confirm for
        # files with several time points or fields of view.
        num_z = len(vol)
        height, width = vol[0].shape[0], vol[0].shape[1]
        # A [::ratio] slice has ceil(size / ratio) elements, so the buffer is
        # sized with a ceiling division; floor division makes the assignment
        # below fail whenever the size is not a multiple of ratio.
        out = np.zeros(
            [num_z, (height + ratio - 1) // ratio, (width + ratio - 1) // ratio],
            np.uint16)
        for z in range(num_z):
            out[z] = vol.get_frame_2D(c=channel_id, t=0, z=z, x=0, y=0, v=0)[::ratio, ::ratio]
    return out

In [None]:
import SimpleITK as sitk

# main function to run
def _to_uint16(volume):
    """Cast a numeric array to uint16, clipping to [0, 65535] first.

    A bare ``astype(np.uint16)`` wraps around for out-of-range values
    (e.g. slightly negative intensities after interpolation).
    """
    limits = np.iinfo(np.uint16)
    return np.clip(volume, limits.min, limits.max).astype(np.uint16)


elastixImageFilter = sitk.ElastixImageFilter()

# 1. set transformation parameters
if len(m_transform_type) == 1:
    # single stage: start from the default map and override a few settings
    param_map = sitk.GetDefaultParameterMap(m_transform_type[0])
    param_map['NumberOfSamplesForExactGradient'] = ['100000']
    param_map['MaximumNumberOfIterations'] = ['10000']
    param_map['MaximumNumberOfSamplingAttempts'] = ['15']
    param_map['FinalBSplineInterpolationOrder'] = ['1']
    elastixImageFilter.SetParameterMap(param_map)
else:
    # multi-stage: one default parameter map per stage, applied in order
    parameterMapVector = sitk.VectorOfParameterMap()
    for trans in m_transform_type:
        parameterMapVector.append(sitk.GetDefaultParameterMap(trans))
    elastixImageFilter.SetParameterMap(parameterMapVector)

# 2. load volume
img_np = nd2ToVol(f_vol_fix, m_channel_name)
print('vol-fix shape:', img_np.shape)
img = sitk.GetImageFromArray(img_np)
img.SetSpacing(m_resolution)
elastixImageFilter.SetFixedImage(img)

img_np = nd2ToVol(f_vol_move, m_channel_name)
print('vol-move shape:', img_np.shape)
img = sitk.GetImageFromArray(img_np)
img.SetSpacing(m_resolution)
elastixImageFilter.SetMovingImage(img)

# 3. compute transformation
elastixImageFilter.Execute()

# 4. save output
# save transformation param
# Keep the parameter maps of ALL stages: using only the first one would drop
# the b-spline stage of an ['affine', 'bspline'] registration, so the other
# channels would be warped differently from the registered channel.
transform_param_maps = elastixImageFilter.GetTransformParameterMap()
param_map = transform_param_maps[0]
out_base = f_vol_out[:f_vol_out.rfind('.')]
# the first stage keeps the original '<output>.txt' name
sitk.WriteParameterFile(param_map, out_base + '.txt')
for stage_id in range(1, len(transform_param_maps)):
    sitk.WriteParameterFile(transform_param_maps[stage_id],
                            out_base + '_stage%d.txt' % stage_id)

# save warped channels
with ND2Reader(f_vol_move) as nd2_move:
    channel_names = nd2_move.metadata['channels']

result_img = elastixImageFilter.GetResultImage()
# image type: float -> np.uint16 (clipped so out-of-range values do not wrap)
img_out = _to_uint16(sitk.GetArrayFromImage(result_img))

if len(channel_names) == 1:
    # directly save
    img_u16 = sitk.GetImageFromArray(img_out)
    img_u16.CopyInformation(result_img)  # keep spacing / origin / direction
    sitk.WriteImage(img_u16, f_vol_out)
else:
    # the context manager closes the file even if warping a channel fails
    with h5py.File(f_vol_out, 'w') as fid:
        # stored as float: an int cast would truncate e.g. 1.625 um to 1
        fid.create_dataset('spacing', data=np.asarray(m_resolution, dtype=float))
        ds = fid.create_dataset([x for x in channel_names if m_channel_name in x][0], img_out.shape, compression="gzip", dtype=img_out.dtype)
        ds[:] = img_out

        # warp other channels with the full (all-stage) transformation
        transformixImageFilter = sitk.TransformixImageFilter()
        transformixImageFilter.SetTransformParameterMap(transform_param_maps)
        for channel_name in channel_names:
            if m_channel_name in channel_name:
                # registration channel: already written above
                continue
            img_np = nd2ToVol(f_vol_move, channel_name)
            print('vol 2:', channel_name, img_np.shape)
            img = sitk.GetImageFromArray(img_np)
            img.SetSpacing(m_resolution)
            transformixImageFilter.SetMovingImage(img)
            transformixImageFilter.Execute()
            img_out = _to_uint16(sitk.GetArrayFromImage(transformixImageFilter.GetResultImage()))
            ds = fid.create_dataset(channel_name, img_out.shape, compression="gzip", dtype=img_out.dtype)
            ds[:] = img_out

# TO VISUALIZE

In [None]:
# You may need to switch to a different kernel to run this cell.
# If you use a new kernel, re-run the previous code cells, except the alignment cell.

# library and utility functions
import napari
import numpy as np
from nd2reader import ND2Reader
import h5py

# napari viewer that display_vol() adds its image layers to
viewer = napari.Viewer()

def display_vol(f_vol_fix, f_vol_out, channel_name, ratio=(1, 1, 1)):
    """Add the fixed and the warped volume of one channel to the napari viewer.

    Uses the module-level ``viewer``, ``m_resolution`` and ``nd2ToVol``.

    Args:
        f_vol_fix: path of the fixed .nd2 volume.
        f_vol_out: path of the warped .h5 file; it must contain a dataset
            named ``channel_name``.
        channel_name: channel to display. It is used both to select the
            channel in the .nd2 file and as the dataset name in the .h5 file.
        ratio: per-axis subsampling steps (same axis order as the volume) to
            display a downsampled volume.
    """
    ratio = tuple(ratio)
    subsample = tuple(slice(None, None, step) for step in ratio)
    # m_resolution is xyz while the volume is zyx, hence the reversal; each
    # voxel of the subsampled volume spans `step` original voxels, so the
    # scale is multiplied by the step to keep the physical size unchanged.
    scale = [res * step for res, step in zip(m_resolution[::-1], ratio)]

    img_fix = nd2ToVol(f_vol_fix, channel_name)
    viewer.add_image(img_fix[subsample],
                     name='fixed-' + channel_name,
                     scale=scale)

    # load the dataset into memory, then release the file handle
    with h5py.File(f_vol_out, 'r') as fid:
        img_warp = np.array(fid[channel_name])
    viewer.add_image(img_warp[subsample],
                     name='warped-' + channel_name,
                     scale=scale)



# napari visualization tips:
# top-left: for the two volumes, set the colormaps to "red" and "green" respectively.
# A good alignment result gives a yellow-ish image where green and red match.
# bottom-left: select "Toggle ndisplay" for the 3D viewer
# Read the channel list inside a context manager so the file handle is closed.
with ND2Reader(f_vol_fix) as nd2_fix:
    channel_names = nd2_fix.metadata['channels']
# display the first channel of the fixed volume next to its warped counterpart
display_vol(f_vol_fix, f_vol_out, channel_names[0])