In [1]:
import numpy as np
import napari
import SimpleITK as sitk
import nd2

In [2]:
import h5py

In [4]:
!python -V

# 3.7 and below works with ND2Reader, need to set to 3.7 then re-make the Elastix

Python 3.7.12


In [5]:
from nd2reader import ND2Reader

In [6]:
# About: Round 1 to Round 4 are images of 12 RNAs using HCR-FISH, 
# Round 5 to Round 7 are images of proteins using antibody staining. 
# We want to align rounds of images using the 405 channel (index 0), and register to Round 1.
# We will perform segmentation later. 

In [7]:
from alignment_modules import nd2ToVol

In [9]:
outputs = []

f_vol_fix = './Stress_granule/Round1.nd2'

for i in range(2, 8):
    f_vol_move = './Stress_granule/Round' + str(i) + '.nd2'
    f_vol_out = './Stress_granule/Round' + str(i) + '_warped.h5'
    
    m_transform_type = ['affine'] 
    m_channel_name = '405' # okay to be partial name
    m_resolution = [1.625,1.625, 4] # um: xyz. the image volume is in zyx-order
    
    # main function to run
    elastixImageFilter = sitk.ElastixImageFilter()

    # 1. set transformation parameters
    if len(m_transform_type) == 1:
        param_map = sitk.GetDefaultParameterMap(m_transform_type[0])
        param_map['NumberOfSamplesForExactGradient'] = ['100000']
        param_map['MaximumNumberOfIterations'] = ['10000']
        param_map['MaximumNumberOfSamplingAttempts'] = ['15']
        param_map['FinalBSplineInterpolationOrder'] = ['1']
        elastixImageFilter.SetParameterMap(param_map)
    else:
        parameterMapVector = sitk.VectorOfParameterMap()
        for trans in m_transform_type:
            parameterMapVector.append(sitk.GetDefaultParameterMap(trans))
        elastixImageFilter.SetParameterMap(parameterMapVector)

    # 2. load volume
    img_np = nd2ToVol(f_vol_fix, m_channel_name)
    print('vol-fix shape:', img_np.shape)
    img = sitk.GetImageFromArray(img_np)
    img.SetSpacing(m_resolution)
    elastixImageFilter.SetFixedImage(img)
    
    img_np = nd2ToVol(f_vol_move, m_channel_name)
    print('vol-move shape:', img_np.shape)
    img = sitk.GetImageFromArray(img_np)
    img.SetSpacing(m_resolution)
    elastixImageFilter.SetMovingImage(img)

    # 3. compute transformation
    elastixImageFilter.Execute()

    # 4. save output
    # save transformation param
    param_map = elastixImageFilter.GetTransformParameterMap()[0]
    sitk.WriteParameterFile(param_map, f_vol_out[:f_vol_out.rfind('.')] + '.txt')

    # save warped channels
    channel_names = ND2Reader(f_vol_move).metadata['channels']

    if len(channel_names) == 1:
        # directly save
        sitk.WriteImage(sitk.Cast(elastixImageFilter.GetResultImage(), sitk.sitkUInt16), f_vol_out)
    else:
        fid = h5py.File(f_vol_out, 'w')
        ds = fid.create_dataset('spacing', [3], compression="gzip", dtype=int)
        ds[:] = np.array(m_resolution).astype(int)
        # image type: float -> np.uint16
        img_out = sitk.GetArrayFromImage(elastixImageFilter.GetResultImage()).astype(np.uint16)
        ds = fid.create_dataset([x for x in channel_names if m_channel_name in x][0], img_out.shape, compression="gzip", dtype=img_out.dtype)
        ds[:] = img_out

        # warp other channels
        transformixImageFilter = sitk.TransformixImageFilter()
        transformixImageFilter.SetTransformParameterMap(param_map)
        for channel_name in channel_names:
            if m_channel_name not in channel_name:
                img_np = nd2ToVol(f_vol_move, channel_name)
                print('vol 2:', channel_name, img_np.shape)
                img = sitk.GetImageFromArray(img_np)
                img.SetSpacing(m_resolution)
                transformixImageFilter.SetMovingImage(img)
                transformixImageFilter.Execute()
                img_out = sitk.GetArrayFromImage(transformixImageFilter.GetResultImage()).astype(np.uint16)
                ds = fid.create_dataset(channel_name, img_out.shape, compression="gzip", dtype=img_out.dtype)
                ds[:] = img_out
        fid.close()
        
    outputs.append(f_vol_out)

Available channels: ['640 SD', '561 SD', '488 SD', '405 SD']
vol-fix shape: (60, 2048, 2048)
Available channels: ['640 SD', '561 SD', '488 SD', '405 SD']
vol-move shape: (74, 2048, 2048)
Available channels: ['640 SD', '561 SD', '488 SD', '405 SD']
vol 2: 640 SD (74, 2048, 2048)
Available channels: ['640 SD', '561 SD', '488 SD', '405 SD']
vol 2: 561 SD (74, 2048, 2048)
Available channels: ['640 SD', '561 SD', '488 SD', '405 SD']
vol 2: 488 SD (74, 2048, 2048)
Available channels: ['640 SD', '561 SD', '488 SD', '405 SD']
vol-fix shape: (60, 2048, 2048)
Available channels: ['640 SD', '561 SD', '488 SD', '405 SD']
vol-move shape: (67, 2048, 2048)
Available channels: ['640 SD', '561 SD', '488 SD', '405 SD']
vol 2: 640 SD (67, 2048, 2048)
Available channels: ['640 SD', '561 SD', '488 SD', '405 SD']
vol 2: 561 SD (67, 2048, 2048)
Available channels: ['640 SD', '561 SD', '488 SD', '405 SD']
vol 2: 488 SD (67, 2048, 2048)
Available channels: ['640 SD', '561 SD', '488 SD', '405 SD']
vol-fix shape

In [None]:
# Visualize results

# from alignment_modules import display_vol

# round_ = 2 
# assert round_ in [2, 3, 4, 5, 6, 7]

# f_vol_out = outputs[round_ - 2]
# viewer = napari.Viewer()
# channel_names = ND2Reader(f_vol_fix).metadata['channels']
# display_vol(f_vol_fix, f_vol_out, channel_names[0])