# Register Skullstripped T1 and T2 MRIs to Atlas

In [None]:
#Register to atlas image (MNI305) using FLIRT tool from FSL
import glob
import multiprocessing
import os
import traceback
from subprocess import PIPE, run

import nibabel as nib
import pandas as pd
import multiprocessing


In [None]:
# Variables to change

# Root directory of the project.
rt_dir = '../infantBrainAge/'

# Sub-directory of rt_dir holding the skull-stripped NIfTI images to register.
nifti_dir = 'nifti_all_nih'

# Atlas (reference) image, stored under rt_dir. It must be skull-stripped and in NIfTI form.
ref_im = rt_dir + 'resources/mni305/average305_t1_tal_lin_skullstripped.nii'

# glob patterns that select the T1 and the T2 images inside nifti_dir.
T1_pattern = 'T1__*_skullstripped.nii.gz'
T2_pattern = 'T2__*_skullstripped.nii.gz'

# CSV (file name or full path) listing the registered images; this file is used for training.
sv_file_name = 'T1_T2_registered.csv'

In [None]:
def _nifti_stem(path):
    """Return *path* without a trailing '.nii.gz' or '.nii' extension.

    Paths with neither extension are returned unchanged, so the output names
    derived from the stem always differ from the input and FLIRT can never
    overwrite the input image.
    """
    for ext in ('.nii.gz', '.nii'):
        if path.endswith(ext):
            return path[:-len(ext)]
    return path


def register_im(in_file, ref=None, dof=12):
    """Register one skull-stripped image to the atlas with FSL FLIRT.

    Writes ``<stem>_registered.nii.gz`` (the registered image) and
    ``<stem>_flirt.mat`` (the transformation matrix) next to *in_file*, where
    ``<stem>`` is *in_file* without its NIfTI extension.

    Args:
        in_file: Path to the skull-stripped input image.
        ref: Reference (atlas) image; defaults to the module-level ``ref_im``.
            It must be skull-stripped and in NIfTI form.
        dof: Degrees of freedom passed to FLIRT's ``-dof`` option.

    Returns:
        A 3-tuple ``(in_file, log, registered_path)``:

        - ``log`` is FLIRT's stdout on success and its stderr when FLIRT exits
          non-zero. It is a traceback if FLIRT could not be launched, or a
          message if *in_file* does not exist.
        - ``registered_path`` is None whenever registration did not succeed.

        Callers can therefore always unpack three values.
    """
    if not os.path.exists(in_file):
        # Always return a 3-tuple: pool.map results are unpacked with zip(*...).
        return in_file, 'no files found', None

    if ref is None:
        ref = ref_im
    stem = _nifti_stem(in_file)
    registered = stem + '_registered.nii.gz'
    command = ['flirt', '-in', in_file, '-ref', ref, '-out', registered,
               '-omat', stem + '_flirt.mat', '-dof', str(dof)]
    try:
        result = run(command, stdout=PIPE, stderr=PIPE, universal_newlines=True)
    except Exception:
        # e.g. FileNotFoundError when the `flirt` executable is not on PATH.
        return in_file, traceback.format_exc(), None

    if result.returncode != 0:
        # Report FLIRT's error output, and don't hand back a path that was
        # never written.
        return in_file, str(result.stderr), None
    return in_file, str(result.stdout), registered

In [None]:
def _register_all(pattern, n_workers=28):
    """Register, in parallel, every file matching *pattern* in rt_dir/nifti_dir.

    Prints the number of matched files. Returns three tuples aligned with the
    sorted list of matched files: the values returned by ``register_im`` for
    each file, unzipped (input path, log text, registered-image path). If
    nothing matches, returns three empty tuples instead of failing to unpack.
    """
    # Sort so the processing order (and the returned tuples) are reproducible.
    files = sorted(glob.glob(os.path.join(rt_dir, nifti_dir, pattern)))
    print(len(files))
    if not files:
        return (), (), ()
    # The context manager shuts the worker processes down once map() has
    # collected every result, instead of leaking one pool per modality.
    with multiprocessing.Pool(n_workers) as pool:
        results = pool.map(register_im, files)
    return tuple(zip(*results))


# Processing T1 files
outputs_t1, fxn_outputs_t1, all_registered_ims_t1 = _register_all(T1_pattern)

# Processing T2 files
outputs_t2, fxn_outputs_t2, all_registered_ims_t2 = _register_all(T2_pattern)

In [None]:
# Saving outputs for use during training
def _subject_key(path, pattern):
    """Return a key shared by the T1 and T2 images of one subject.

    The key is *path* with the literal prefix of *pattern* (the text before
    its first '*', e.g. 'T1__') removed from the file name. For example,
    '.../T1__abc_skullstripped_registered.nii.gz' and
    '.../T2__abc_skullstripped_registered.nii.gz' map to the same key.
    """
    prefix = pattern.split('*', 1)[0]
    folder, name = os.path.split(path)
    if name.startswith(prefix):
        name = name[len(prefix):]
    return os.path.join(folder, name)


# Failed registrations are None; drop them so they neither crash the path
# handling below nor end up as rows in the training CSV.
t1_by_subject = {_subject_key(p, T1_pattern): p for p in all_registered_ims_t1 if p is not None}
t2_by_subject = {_subject_key(p, T2_pattern): p for p in all_registered_ims_t2 if p is not None}

# glob() returns files in arbitrary order, and the two modalities may have
# different counts, so pair the images by subject key rather than by position.
paired_keys = sorted(t1_by_subject.keys() & t2_by_subject.keys())
print(len(paired_keys), 'subjects with both registered T1 and T2;',
      len(t1_by_subject) - len(paired_keys), 'T1-only and',
      len(t2_by_subject) - len(paired_keys), 'T2-only images skipped')

t1_paired = [t1_by_subject[k] for k in paired_keys]
registered_df = pd.DataFrame({
    # Parent directory name of the T1 image (same derivation as x.split('/')[-2]).
    # NOTE(review): with the flat glob used in this notebook every image sits
    # directly in nifti_dir, so this value is identical for all rows. Confirm
    # where the accession number should really come from.
    'acc_nums': [os.path.basename(os.path.dirname(p)) for p in t1_paired],
    'T1_registered_images': t1_paired,
    'T2_registered_images': [t2_by_subject[k] for k in paired_keys],
})
registered_df.to_csv(sv_file_name, index=False)