## Function Definition

In [None]:
def dwi_to_b0(dwi_path,dwi_name,b_prefix,output_name,b_output_name):
    '''
    Register every volume of a 4D DWI series to its b0 volume and save the result.

    Inputs (all read from dwi_path):
        <dwi_name>.nii   4D diffusion data
        <b_prefix>.bval  b-values
        <b_prefix>.bvec  b-vectors
    Outputs (all written to dwi_path):
        <output_name>          registered 4D series (saved with the input affine/header)
        <b_output_name>.bval   b-values of the reoriented gradient table
        <b_output_name>.bvec   b-vectors reoriented with the per-volume affines

    Example:
        dwi_to_b0(dwi_path=brain_path,dwi_name='dwi_topuped',b_prefix='dwi',
                  output_name='dwi_to_b0',b_output_name='dwi_to_b0')
    '''
    import os
    import numpy as np
    from dipy.io.image import load_nifti,save_nifti
    from dipy.io import read_bvals_bvecs
    from dipy.core.gradients import gradient_table
    from dipy.core.gradients import reorient_bvecs
    from dipy.align import center_of_mass,translation,rigid
    from dipy.align._public import affine,register_dwi_series

    dwi_file = os.path.join(dwi_path,dwi_name+'.nii') # 4D diffusion data file name
    bval = os.path.join(dwi_path,b_prefix+'.bval') # bval file name
    bvec = os.path.join(dwi_path,b_prefix+'.bvec') # bvec file name
    #------------
    bvals, bvecs = read_bvals_bvecs(bval,bvec)
    gtab = gradient_table(bvals, bvecs,atol=0.15,b0_threshold=5)
    img_data,img_affine,img = load_nifti(dwi_file,return_img=True)
    xformed,affines = register_dwi_series(img_data,gtab,affine=img_affine,b0_ref=0,
                                            pipeline= [center_of_mass, translation,rigid,affine])
    # Bring the last axis of the affine stack to the front and keep only the
    # entries of the non-b0 volumes (assumes affines is (4, 4, n_volumes) -- TODO confirm).
    affines_dwi = np.moveaxis(affines,-1,0)[~gtab.b0s_mask]
    gtab_new = reorient_bvecs(gtab, affines_dwi, atol=0.15)
    np.savetxt(os.path.join(dwi_path,b_output_name+'.bval'),gtab_new.bvals[np.newaxis,:],fmt='%1.5f')
    np.savetxt(os.path.join(dwi_path,b_output_name+'.bvec'),gtab_new.bvecs.T,fmt='%1.5f')
    # Image.get_data() was removed from recent nibabel releases; reading the
    # data object directly returns the same array on old and new versions.
    registered = np.asanyarray(xformed.dataobj)
    save_nifti(os.path.join(dwi_path,output_name), registered, img_affine,hdr=img.header)

In [None]:
def data_padding(static,moving_data):
    '''
    Center-crop and/or zero-pad the first two axes of moving_data so that
    their sizes equal those of static.

    For each of axis 0 and axis 1 independently:
      * moving larger than static  -> a centered window of static's size is kept;
      * moving smaller than static -> zeros are added on both sides;
      * equal sizes                -> the axis is left untouched.
    When the size difference is odd, the extra voxel is removed from / added
    to the trailing side, so cropping followed by padding back restores the
    original alignment. All remaining axes are passed through unchanged.

    static      : array whose shape[0] and shape[1] define the target size.
    moving_data : array with at least two axes.
    Returns the adjusted array (a view of the input when only cropping).
    '''
    import numpy as np
    for axis in (0, 1):
        target = static.shape[axis]
        diff = moving_data.shape[axis] - target
        if diff > 0:
            # Explicit stop index: a negative stop of "-0" would select nothing.
            start = diff // 2
            window = [slice(None)] * moving_data.ndim
            window[axis] = slice(start, start + target)
            moving_data = moving_data[tuple(window)]
        elif diff < 0:
            before = (-diff) // 2
            widths = [(0, 0)] * moving_data.ndim
            widths[axis] = (before, -diff - before)
            moving_data = np.pad(moving_data, widths,
                                 mode='constant', constant_values=0)
    return moving_data


In [None]:
def dwi_to_t2(dwi_path,t2w_path,transform_file_path,dwi_name,t2w_name,b_input_name,reg_aff_name,sym_name,output_name_base):
    '''
    Register a DWI (b0) volume to a T2w volume: affine first, then symmetric
    diffeomorphic (SyN), and store the transforms for later reuse.

    Both images are treated in voxel space (identity grid-to-world); the
    moving image is first cropped/padded in-plane to the static image size
    with data_padding.

    Parameters
        dwi_path, dwi_name      : folder / file name of the moving image.
        t2w_path, t2w_name      : folder / file name of the static image.
        transform_file_path     : folder receiving the transform files.
        b_input_name            : base name (no extension) of the .bval/.bvec
                                  pair in dwi_path to reorient.
        reg_aff_name, sym_name  : suffixes of the affine / SyN transform files.
        output_name_base        : common prefix of every output.

    Outputs
        <output_name_base>_reorient.bval/.bvec   (relative to the current
            working directory) gradient table reoriented by the affine
        <transform_file_path>/<base><reg_aff_name>.txt              affine matrix
        <transform_file_path>/<base><reg_aff_name>_mapping.txt      mapping.prealign
        <transform_file_path>/<base><reg_aff_name>_mapping_inv.txt  mapping.prealign_inv
        <transform_file_path>/<base><sym_name>                      SyN mapping (write_mapping)
        <dwi_path>/<base>_aff_sym_forward_test.nii                  warped moving image
        three overlay figure names are passed to regtools.overlay_slices
            (0_/1_/2_b0_to_t2w_*_registration_test.png).

    Example:
        dwi_to_t2(dwi_path,t2w_path,transform_file_path,
                  'dwi_iso_resampled_b0_brain.nii.gz','T2w_resampled_brain.nii.gz',
                  'dwi_to_b0','_aff','_aff_sym.nii','b0_to_t2w')
    '''
    import os
    import numpy as np
    from dipy.align.imaffine import transform_centers_of_mass
    from dipy.io.image import load_nifti,save_nifti
    from dipy.align import affine_registration,write_mapping
    from dipy.io.gradients import read_bvals_bvecs
    from dipy.core.gradients import gradient_table,reorient_bvecs
    from dipy.align.metrics import CCMetric
    from dipy.align.imwarp import SymmetricDiffeomorphicRegistration
    from dipy.viz import regtools

    moving_data,_,_ = load_nifti(os.path.join(dwi_path,dwi_name),return_img=True)
    static,static_affine_original,static_img = load_nifti(os.path.join(t2w_path,t2w_name),return_img=True)

    # The registration runs in voxel space: both grids use an identity
    # grid-to-world matrix; the file affines are only used when saving.
    moving_affine = np.eye(4)
    static_affine = np.eye(4)

    # padding to make registration working well
    moving_data = data_padding(static,moving_data)

    # Center-of-mass alignment, only used for the first quality-control figure.
    c_of_mass = transform_centers_of_mass(static, static_affine,
                                          moving_data, moving_affine)
    transformed = c_of_mass.transform(moving_data)
    print('Image after center of mass transform')
    regtools.overlay_slices(static, transformed, None, 2,
                            "Static", "Transformed", "0_b0_to_t2w_center_of_mass_registration_test.png",dpi=500)

    # affine registration
    xformed_img, reg_affine = affine_registration(
        moving_data,
        static,
        moving_affine=moving_affine,
        static_affine=static_affine,
        nbins=32,
        metric='MI',
        pipeline=["center_of_mass","translation", "rigid", "affine"],
        level_iters=[10000, 1000, 500],
        sigmas=[3.0, 1.0, 0.0],
        factors=[4, 2, 1],
        )
    regtools.overlay_slices(static, xformed_img, None, 2, 'Static',
                            'Warped moving', '1_b0_to_t2w_aff_registration_test.png')

    # --------------------------------
    # reorient b table: the same affine is applied to every non-b0 direction
    bval = os.path.join(dwi_path,b_input_name+'.bval') # bval file name
    bvec = os.path.join(dwi_path,b_input_name+'.bvec') # bvec file name
    bvals, bvecs = read_bvals_bvecs(bval,bvec)
    gtab = gradient_table(bvals, bvecs,atol=0.15,b0_threshold=5)
    n_dwi = int(np.count_nonzero(~gtab.b0s_mask))
    dwi_affines = np.tile(reg_affine,(n_dwi,1,1))
    gtab_new = reorient_bvecs(gtab, dwi_affines, atol=0.15)
    # NOTE: these two files are written relative to the current working directory.
    np.savetxt(output_name_base+'_reorient.bval',gtab_new.bvals[np.newaxis,:],fmt='%1.5f')
    np.savetxt(output_name_base+'_reorient.bvec',gtab_new.bvecs.T,fmt='%1.5f')
    #----------------------------------

    transform_base = os.path.join(transform_file_path,output_name_base)
    np.savetxt(transform_base+reg_aff_name+'.txt',reg_affine)#for b0 to t2w

    # symmetric diffeomorphic registration, initialised with the affine result
    sdr = SymmetricDiffeomorphicRegistration(CCMetric(3), [50, 30, 10])
    mapping = sdr.optimize(static, moving_data, static_affine, moving_affine, reg_affine)
    warped_moving = mapping.transform(moving_data)

    # save and display results
    np.savetxt(transform_base+reg_aff_name+'_mapping.txt',mapping.prealign)#for b0 to t2w
    np.savetxt(transform_base+reg_aff_name+'_mapping_inv.txt',mapping.prealign_inv)#for b0 to t2w
    write_mapping(mapping,transform_base+sym_name)
    save_nifti(os.path.join(dwi_path,output_name_base+'_aff_sym_forward_test.nii'),warped_moving,static_affine_original,static_img.header)
    regtools.overlay_slices(static, warped_moving, None, 2, 'Static',
                            'Warped moving', '2_b0_to_t2w_aff_sym_registration_test.png')

In [None]:
def dwi_to_t2_transform(dwi_path,t2w_path,transform_file_path,dwi_name,t2w_name,reg_aff_name,sym_name,output_name_base,transform_dir):
    '''
    Apply the stored affine and symmetric-diffeomorphic transforms between a
    DWI volume (moving) and a T2w volume (static), in both directions.

    The transforms are read from transform_file_path:
        <output_name_base><reg_aff_name>.txt           affine matrix
        <output_name_base><reg_aff_name>_mapping.txt   prealign matrix of the SyN mapping
        <output_name_base><sym_name>                   SyN mapping (read_mapping)

    Files written:
        <t2w_path>/codomain_test.nii, <dwi_path>/domain_test.nii
            identity-affine copies of the static / padded moving image, used
            as grid references by read_mapping
        <dwi_path>/<output_name_base>_aff_forward.nii        moving -> static, affine
        <dwi_path>/<output_name_base>_aff_backward.nii       static -> moving, affine
        <dwi_path>/<output_name_base>_aff_sym_forward.nii    moving -> static, affine + SyN
        <dwi_path>/<output_name_base>_aff_sym_backward.nii   static -> moving, affine + SyN

    transform_dir: either 'forward' or 'backward'. Currently not used by the
        body; all four results are always produced. Kept for caller
        compatibility.
    '''
    import os
    import numpy as np
    from dipy.io.image import load_nifti,save_nifti
    from dipy.align.imaffine import AffineMap
    from dipy.align import read_mapping

    # static (T2w) image, handled in voxel space with an identity affine
    static_data,static_affine_original,static_img = load_nifti(os.path.join(t2w_path,t2w_name),
                                            return_img=True)
    static_affine = np.eye(4)
    static = np.squeeze(static_data[:,:,:])

    codomain_img = os.path.join(t2w_path,'codomain_test.nii')
    save_nifti(codomain_img,static,static_affine,static_img.header)

    # load moving data
    moving_data_original, moving_affine_original, moving_img = load_nifti(os.path.join(dwi_path,dwi_name), return_img=True)
    moving_affine = np.eye(4)
    moving_data = np.squeeze(moving_data_original).copy()

    # padding
    moving_data = data_padding(static,moving_data)

    domain_img = os.path.join(dwi_path,'domain_test.nii')
    save_nifti(domain_img,moving_data,moving_affine,moving_img.header)

    # affine forward transform
    transform_base = os.path.join(transform_file_path,output_name_base)
    prealign = np.loadtxt(transform_base+reg_aff_name+'.txt')
    # The grid-shape arguments take shapes, not image arrays.
    affinemap = AffineMap(prealign,
                        domain_grid_shape=moving_data.shape,
                        domain_grid2world=moving_affine,
                        codomain_grid_shape=static.shape,
                        codomain_grid2world=static_affine)
    affine_forward = affinemap.transform(moving_data, image_grid2world=moving_affine, sampling_grid_shape=static.shape,
                        sampling_grid2world=static_affine, resample_only=False)
    save_nifti(os.path.join(dwi_path,output_name_base+'_aff_forward.nii'),affine_forward,static_affine_original,static_img.header)

    # affine reverse transform
    affine_inv = affinemap.transform_inverse(static, image_grid2world=static_affine, sampling_grid_shape=moving_data.shape,
                        sampling_grid2world=None, resample_only=False)
    # undo the in-plane crop/pad so the result fits the original moving grid
    data_aff_back = data_padding(moving_data_original,affine_inv)
    save_nifti(os.path.join(dwi_path,output_name_base+'_aff_backward.nii'),data_aff_back,moving_affine_original,moving_img.header)

    # affine + SyN forward transform
    prealign_mapping = np.loadtxt(transform_base+reg_aff_name+'_mapping.txt')
    mapping = read_mapping(transform_base+sym_name,domain_img,codomain_img,prealign_mapping)
    x_transformed = mapping.transform(moving_data)
    save_nifti(os.path.join(dwi_path,output_name_base+'_aff_sym_forward.nii'),x_transformed,static_affine_original,static_img.header)

    # affine + SyN reverse transform
    x_transformed_inv = mapping.transform_inverse(static,image_world2grid=static_affine)
    x_affine_inv = affinemap.transform_inverse(x_transformed_inv, image_grid2world=mapping.prealign, sampling_grid_shape=moving_data.shape,
                        sampling_grid2world=None, resample_only=False)
    data_aff_sym_back = data_padding(moving_data_original,x_affine_inv)
    save_nifti(os.path.join(dwi_path,output_name_base+'_aff_sym_backward.nii'),data_aff_sym_back,moving_affine_original,moving_img.header)

In [None]:
def dwi_to_t2_transform_atlas(dwi_path,transform_file_path,atlas_path,dwi_name,reg_aff_name,sym_name,atlas_name,output_name_base):
    '''
    Bring an integer-labelled atlas from static (T2w) space back onto the DWI
    grid using the stored affine + symmetric-diffeomorphic transforms.

    Every label 1..max(atlas) is turned into a 0/1 mask, back-transformed on
    its own, and each voxel finally receives the label whose back-transformed
    mask has the largest value there (label 0 when no mask is positive).

    The transforms are read from transform_file_path:
        <output_name_base><reg_aff_name>.txt           affine matrix
        <output_name_base><reg_aff_name>_mapping.txt   prealign matrix of the SyN mapping
        <output_name_base><sym_name>                   SyN mapping (read_mapping)

    Files written:
        <atlas_path>/codomain_test.nii, <dwi_path>/domain_test.nii
            identity-affine grid references used by read_mapping
        <dwi_path>/<output_name_base>_aff_sym_backward_atlas.nii
            label volume with a trailing singleton axis, saved with the
            affine and header of the DWI file
    '''
    import os
    import numpy as np
    from dipy.io.image import load_nifti,save_nifti
    from dipy.align.imaffine import AffineMap
    from dipy.align import read_mapping

    # The atlas is both the label source and the static grid reference.
    atlas_data,static_affine_original,static_img = load_nifti(os.path.join(atlas_path,atlas_name),
                                            return_img=True)
    static_affine = np.eye(4)
    static = np.squeeze(atlas_data[:,:,:])

    codomain_img = os.path.join(atlas_path,'codomain_test.nii')
    save_nifti(codomain_img,static,static_affine,static_img.header)

    # load moving data
    moving_data_original, moving_affine_original, moving_img = load_nifti(os.path.join(dwi_path,dwi_name), return_img=True)
    moving_affine = np.eye(4)
    moving_data = np.squeeze(moving_data_original).copy()

    # padding
    moving_data = data_padding(static,moving_data)

    domain_img = os.path.join(dwi_path,'domain_test.nii')
    save_nifti(domain_img,moving_data,moving_affine,moving_img.header)

    # The transforms do not depend on the ROI, so they are loaded once.
    transform_base = os.path.join(transform_file_path,output_name_base)
    prealign = np.loadtxt(transform_base+reg_aff_name+'.txt')
    # The grid-shape arguments take shapes, not image arrays.
    affinemap = AffineMap(prealign, domain_grid_shape=moving_data.shape, domain_grid2world=moving_affine,
                codomain_grid_shape=static.shape, codomain_grid2world=static_affine)
    prealign_mapping = np.loadtxt(transform_base+reg_aff_name+'_mapping.txt')
    mapping = read_mapping(transform_base+sym_name,domain_img,codomain_img,prealign_mapping)

    roi_max = int(np.max(atlas_data))
    print(roi_max)
    # One back-transformed mask per label along the last axis; index 0 stays
    # all-zero and therefore acts as the background label.
    atlas_all = np.zeros(moving_data_original.shape[0:3]+(roi_max+1,))
    for i_atlas in range(1,roi_max+1):
        i_roi_mask = np.where(atlas_data==i_atlas,1,0)
        x_transformed_inv = mapping.transform_inverse(i_roi_mask,image_world2grid=static_affine)
        x_affine_inv = affinemap.transform_inverse(x_transformed_inv, image_grid2world=mapping.prealign, sampling_grid_shape=moving_data.shape,
                            sampling_grid2world=None, resample_only=False)
        # undo the in-plane crop/pad so the mask fits the original moving grid
        # (save this per-ROI volume here if individual ROIs are needed)
        atlas_all[:,:,:,i_atlas] = data_padding(moving_data_original,x_affine_inv)
    #--
    # The winning index along the last axis is the label value itself.
    atlas_inv_final_ind = np.argmax(atlas_all,axis=-1)
    atlas_inv_final = np.expand_dims(atlas_inv_final_ind, axis=-1).astype(np.float64)
    save_nifti(os.path.join(dwi_path,output_name_base+'_aff_sym_backward_atlas.nii'),atlas_inv_final,moving_affine_original,moving_img.header)

In [None]:
def dwi_resample(brain_path,data_in,data_out,zooms_new):
    '''
    Reslice a 3D or 4D NIfTI image to a new voxel size.

    brain_path : folder containing data_in; it also becomes the current
                 working directory (os.chdir), so a relative data_out is
                 written there.
    data_in    : input file name inside brain_path.
    data_out   : output file name.
    zooms_new  : new zooms, one entry per image axis (e.g. [1,1,1] for a 3D
                 image, [1,1,1,8] for a 4D one). The first three entries are
                 the target voxel size; the full list is written to the header.

    Raises ValueError when the image is neither 3D nor 4D.

    Example:
        dwi_resample(brain_path,data_in='dwi_to_b0.nii',data_out='dwi_iso.nii',
                     zooms_new=[1.5,1.5,1.5,8])
    '''
    import os
    import numpy as np
    from dipy.io.image import load_nifti, save_nifti
    from dipy.align.reslice import reslice
    os.chdir(brain_path)
    img_data,img_affine,img = load_nifti(os.path.join(brain_path,data_in),return_img=True)
    zooms_old = img.header.get_zooms()[:3]
    new_zooms = zooms_new[0:3]
    if img_data.ndim == 4:
        # Reslice volume by volume; the first volume fixes the output grid.
        n_vol = img_data.shape[-1]
        first_vol,affine_new = reslice(img_data[:,:,:,0], img_affine, zooms_old, new_zooms)
        data_new = np.zeros(first_vol.shape+(n_vol,))
        data_new[:,:,:,0] = first_vol
        for i_vol in range(1,n_vol):
            data_new[:,:,:,i_vol],_ = reslice(img_data[:,:,:,i_vol], img_affine, zooms_old, new_zooms)
    elif img_data.ndim == 3:
        data_new,affine_new = reslice(img_data, img_affine, zooms_old, new_zooms)
    else:
        raise ValueError('input file error! expected a 3D or 4D image, got %d dimensions' % img_data.ndim)
    img.header.set_zooms(zooms_new)
    # Save with the affine returned by reslice so that the stored
    # grid-to-world transform matches the new voxel size.
    save_nifti(data_out, data_new, affine_new,hdr=img.header)

In [None]:
def dwi_choose_vol(brain_path,data_in,vol_arr,data_out):
    '''
    Extract selected volumes from a 4D NIfTI image and save them as a new file.

    brain_path : folder containing data_in; it also becomes the current
                 working directory (os.chdir), so a relative data_out is
                 written there.
    data_in    : input file name inside brain_path.
    vol_arr    : index (or list of indices) applied to the 4th axis; a list
                 such as [0] keeps the 4th axis in the output.
    data_out   : output file name; saved with the input affine and header.

    Example:
        dwi_choose_vol(brain_path,'dwi_iso_resampled.nii',[0],'dwi_iso_resampled_b0.nii')
    '''
    import os
    from dipy.io.image import load_nifti, save_nifti
    os.chdir(brain_path)
    dwi_file = os.path.join(brain_path,data_in) # 4D diffusion data file name
    img_data,img_affine,img = load_nifti(dwi_file,return_img=True)
    save_nifti(data_out, img_data[:,:,:,vol_arr], img_affine,hdr=img.header)

In [None]:
def flip_t22_xy(brain_path,input_name,flip_axis,output_path,output_name):
    '''
    Mirror a NIfTI image along one array axis and write it to a new file.

    brain_path, input_name   : folder / file name of the image to read.
    flip_axis                : axis handed to numpy.flip (0 reverses the first array axis).
    output_path, output_name : folder / file name of the flipped image.

    Only the voxel array is reversed; the affine and header of the input are
    reused unchanged for the output.
    '''
    import os
    import numpy as np
    import nibabel as nib
    from dipy.io.image import load_nifti,save_nifti

    source_file = os.path.join(brain_path,input_name)
    target_file = os.path.join(output_path,output_name)

    volume, volume_affine, volume_img = load_nifti(source_file,return_img=True)
    flipped = np.flip(volume,axis=flip_axis)
    save_nifti(target_file, flipped, volume_affine,hdr=volume_img.header)

## Processing

### Flip T2w along x

In [None]:
brain_path = '/media/erjun/One_Touch/Newborns/sub-B027S1/ses-01/Registration_test'
# Mirror T2w.nii along array axis 0 and store the result next to the input.
flip_t22_xy(brain_path,'T2w.nii',0,brain_path,'T2w_brain_flip_x.nii')

### dwi to b0

In [None]:
brain_path = '/media/erjun/One_Touch/Newborns/sub-B027S1/ses-01/Registration_test'
# Register the topup-corrected series to its b0; the image and the
# reoriented gradient table share the base name 'dwi_to_b0'.
dwi_to_b0(brain_path,'dwi_topuped','dwi_mb0','dwi_to_b0','dwi_to_b0')

### Resample DWI to isotropic voxels

In [None]:
brain_path = '/media/erjun/One_Touch/Newborns/sub-B027S1/ses-01/Registration_test'
# 1.5 isotropic spatial zooms; the fourth entry is the zoom written for the 4th axis.
dwi_resample(brain_path,'dwi_to_b0.nii','dwi_iso.nii',[1.5,1.5,1.5,8])

### resampling b0 to t2w

In [None]:
import nibabel as nib
import os
brain_path = '/media/erjun/One_Touch/Newborns/sub-B027S1/ses-01/Registration_test'
# Bring the 4D DWI series and the flipped 3D T2w image to the same 1.0 zooms.
dwi_resample(brain_path,'dwi_iso.nii','dwi_iso_resampled.nii',[1.0,1.0,1.0,8])
dwi_resample(brain_path,'T2w_brain_flip_x.nii','T2_resampled.nii',[1.0,1.0,1.0])

In [None]:
brain_path = '/media/erjun/One_Touch/Newborns/sub-B027S1/ses-01/Registration_test'
# Keep only volume 0 of the resampled series.
data_in, vol_arr, data_out = 'dwi_iso_resampled.nii', [0], 'dwi_iso_resampled_b0.nii'
dwi_choose_vol(brain_path=brain_path,data_in=data_in,vol_arr=vol_arr,data_out=data_out)

### dwi to t2w

In [None]:

# Images and transform files all live in the same folder.
dwi_path = t2w_path = transform_file_path = '/media/erjun/One_Touch/Newborns/sub-B027S1/ses-01/Registration_test'
# moving / static images
dwi_name = 'dwi_iso_resampled_b0_brain.nii.gz'
t2w_name = 'T2w_resampled_brain.nii.gz'
# gradient table to reorient, transform-file suffixes and output prefix
b_input_name = 'dwi_to_b0'
reg_aff_name, sym_name = '_aff', '_aff_sym.nii'
output_name_base = 'b0_to_t2w'
dwi_to_t2(dwi_path=dwi_path,t2w_path=t2w_path,transform_file_path=transform_file_path,
          dwi_name=dwi_name,t2w_name=t2w_name,b_input_name=b_input_name,
          reg_aff_name=reg_aff_name,sym_name=sym_name,output_name_base=output_name_base)

### transformation

In [None]:
# Images and transform files all live in the same folder.
dwi_path = t2w_path = transform_file_path = '/media/erjun/One_Touch/Newborns/sub-B027S1/ses-01/Registration_test'
# moving / static images
dwi_name = 'dwi_iso_resampled_b0_brain.nii.gz'
t2w_name = 'T2w_resampled_brain.nii.gz'
# transform-file suffixes and output prefix
reg_aff_name, sym_name = '_aff', '_aff_sym.nii'
output_name_base = 'b0_to_t2w'
transform_dir = 'forward'

dwi_to_t2_transform(dwi_path=dwi_path,t2w_path=t2w_path,transform_file_path=transform_file_path,
                    dwi_name=dwi_name,t2w_name=t2w_name,reg_aff_name=reg_aff_name,
                    sym_name=sym_name,output_name_base=output_name_base,transform_dir=transform_dir)

### Resampling back to dwi original resolution

### Create ROI atlas

### Transfer ROI back to subjects

In [None]:
# Images, atlas and transform files all live in the same folder.
dwi_path = atlas_path = transform_file_path = '/media/erjun/One_Touch/Newborns/sub-B027S1/ses-01/Registration_test'
# subject image and atlas
dwi_name = 'dwi_iso_resampled_b0_brain.nii.gz'
atlas_name = 'T2w_brain_atlas.nii.gz'
# transform-file suffixes and output prefix
reg_aff_name, sym_name = '_aff', '_aff_sym.nii'
output_name_base = 'b0_to_t2w'
dwi_to_t2_transform_atlas(dwi_path=dwi_path,transform_file_path=transform_file_path,atlas_path=atlas_path,
                          dwi_name=dwi_name,reg_aff_name=reg_aff_name,sym_name=sym_name,
                          atlas_name=atlas_name,output_name_base=output_name_base)