diff --git a/petprep/cli/parser.py b/petprep/cli/parser.py index 82b9046a..d0cb6d31 100644 --- a/petprep/cli/parser.py +++ b/petprep/cli/parser.py @@ -358,10 +358,14 @@ def _bids_filter(value, parser): '6 degrees (rotation and translation) are used by default.', ) g_conf.add_argument( - '--pet2anat-robust', - action='store_true', - help='Use FreeSurfer mri_robust_register with an NMI cost function for' - 'PET-to-T1w co-registration. This option is limited to 6 dof.', + '--pet2anat-method', + action='store', + default='mri_coreg', + choices=['mri_coreg', 'robust', 'ants'], + help='Method for PET-to-anatomical registration. ' + '"mri_coreg" (default) uses FreeSurfer mri_coreg. ' + '"robust" uses FreeSurfer mri_robust_register (6 DoF only). ' + '"ants" uses ANTs rigid registration (6 DoF only).', ) g_conf.add_argument( '--force-bbr', @@ -769,8 +773,11 @@ def parse_args(args=None, namespace=None): parser = _build_parser() opts = parser.parse_args(args, namespace) - if getattr(opts, 'pet2anat_robust', False) and opts.pet2anat_dof != 6: - parser.error('--pet2anat-robust requires --pet2anat-dof=6.') + # Validate DoF constraints for registration methods + if opts.pet2anat_method in ('robust', 'ants') and opts.pet2anat_dof != 6: + parser.error( + f'--pet2anat-method {opts.pet2anat_method} requires --pet2anat-dof=6.' + ) if opts.config_file: skip = {} if opts.reports_only else {'execution': ('run_uuid',)} diff --git a/petprep/config.py b/petprep/config.py index b246720d..1a805389 100644 --- a/petprep/config.py +++ b/petprep/config.py @@ -556,8 +556,8 @@ class workflow(_Config): """Degrees of freedom of the PET-to-anatomical registration steps.""" pet2anat_init = 'auto' """Initial transform for PET-to-anatomical registration.""" - pet2anat_robust = False - """Use ``mri_robust_register`` for PET-to-anatomical alignment.""" + pet2anat_method: str = 'mri_coreg' + """PET-to-anatomical registration method (mri_coreg, robust, or ants).""" petref: str = 'template' """Strategy for building the PET reference (``'template'``, ``'twa'`` or ``'sum'``).""" cifti_output = None diff --git a/petprep/interfaces/reports.py b/petprep/interfaces/reports.py index cda4e8dc..060ccf94 100644 --- a/petprep/interfaces/reports.py +++ b/petprep/interfaces/reports.py @@ -229,6 +229,7 @@ class FunctionalSummaryInputSpec(TraitedSpec): registration = traits.Enum( 'mri_coreg', 'mri_robust_register', + 'ants_registration', 'Precomputed', mandatory=True, desc='PET/anatomical registration method', @@ -257,8 +258,12 @@ def _generate_segment(self): reg = 'Precomputed affine transformation' elif self.inputs.registration == 'mri_coreg': reg = f'FreeSurfer mri_coreg - {dof} dof' - else: + elif self.inputs.registration == 'ants_registration': + reg = 'ANTs rigid registration (6 DoF)' + elif self.inputs.registration == 'mri_robust_register': reg = 'FreeSurfer mri_robust_register (NMI cost)' + else: + reg = f'Unknown registration method: {self.inputs.registration}' reference_map = { 'template': 'Motion correction template', diff --git a/petprep/interfaces/tests/test_reports.py b/petprep/interfaces/tests/test_reports.py index 46d14c48..90b16389 100644 --- a/petprep/interfaces/tests/test_reports.py +++ b/petprep/interfaces/tests/test_reports.py @@ -76,7 +76,7 @@ def test_subject_summary_handles_missing_task(tmp_path): @pytest.mark.parametrize( 'registration', - ['mri_coreg', 'mri_robust_register'], + ['mri_coreg', 'mri_robust_register', 'ants_registration'], ) def test_functional_summary_with_metadata(registration): from ..reports import FunctionalSummary diff --git a/petprep/workflows/pet/fit.py b/petprep/workflows/pet/fit.py index 6994874d..5611c519 100644 --- a/petprep/workflows/pet/fit.py +++ b/petprep/workflows/pet/fit.py @@ -372,9 +372,11 @@ def init_pet_fit_wf( registration_method = 'Precomputed' if not petref2anat_xform: - registration_method = ( - 'mri_robust_register' if config.workflow.pet2anat_robust else 'mri_coreg' - ) + registration_method = { + 'mri_coreg': 'mri_coreg', + 'robust': 'mri_robust_register', + 'ants': 'ants_registration', + }[config.workflow.pet2anat_method] if hmc_disabled: config.execution.work_dir.mkdir(parents=True, exist_ok=True) petref = petref or reference_function(pet_file, **reference_kwargs) @@ -528,7 +530,7 @@ def init_pet_fit_wf( pet2anat_dof=config.workflow.pet2anat_dof, omp_nthreads=omp_nthreads, mem_gb=mem_gb['resampled'], - use_robust_register=config.workflow.pet2anat_robust, + pet2anat_method=config.workflow.pet2anat_method, sloppy=config.execution.sloppy, ) diff --git a/petprep/workflows/pet/registration.py b/petprep/workflows/pet/registration.py index 370da756..6dd49935 100644 --- a/petprep/workflows/pet/registration.py +++ b/petprep/workflows/pet/registration.py @@ -36,12 +36,17 @@ AffineDOF = ty.Literal[6, 9, 12] +def _get_first(in_list): + """Extract first element from a list (for ANTs transform output).""" + return in_list[0] + + def init_pet_reg_wf( *, pet2anat_dof: AffineDOF, mem_gb: float, omp_nthreads: int, - use_robust_register: bool = False, + pet2anat_method: str = 'flirt', name: str = 'pet_reg_wf', sloppy: bool = False, ): @@ -70,10 +75,10 @@ def init_pet_reg_wf( Size of PET file in GB omp_nthreads : :obj:`int` Maximum number of threads an individual process may use - use_robust_register : :obj:`bool` - Run FreeSurfer ``mri_robust_register`` with an NMI cost function for - PET-to-anatomical alignment. Only rigid-body (6 dof) alignment is - supported in this mode. + pet2anat_method : :obj:`str` + Method for PET-to-anatomical registration. Options are 'flirt' (default, + uses FSL FLIRT), 'robust' (uses FreeSurfer mri_robust_register with NMI, + 6 DoF only), or 'ants' (uses ANTs rigid registration, 6 DoF only). name : :obj:`str` Name of workflow (default: ``pet_reg_wf``) @@ -95,6 +100,7 @@ def init_pet_reg_wf( Affine transform from anatomical space to PET space (ITK format) """ + from nipype.interfaces.ants import Registration from nipype.interfaces.freesurfer import MRICoreg, RobustRegister from niworkflows.engine.workflows import LiterateWorkflow as Workflow from niworkflows.interfaces.nibabel import ApplyMask @@ -112,7 +118,43 @@ def init_pet_reg_wf( ) mask_brain = pe.Node(ApplyMask(), name='mask_brain') - if use_robust_register: + + if pet2anat_method == 'ants': + coreg = pe.Node( + Registration( + dimension=3, + float=True, + output_transform_prefix='pet2anat_', + output_warped_image='pet2anat_Warped.nii.gz', + transforms=['Rigid'], + transform_parameters=[(0.1,)], + metric=['MI'], + metric_weight=[1], + radius_or_number_of_bins=[32], + sampling_strategy=['Regular'], + sampling_percentage=[0.25], + number_of_iterations=[[1000, 500, 250]], + convergence_threshold=[1e-6], + convergence_window_size=[10], + shrink_factors=[[4, 2, 1]], + smoothing_sigmas=[[2, 1, 0]], + sigma_units=['vox'], + use_histogram_matching=False, + initial_moving_transform_com=1, + winsorize_lower_quantile=0.005, + winsorize_upper_quantile=0.995, + ), + name='ants_registration', + n_procs=omp_nthreads, + mem_gb=mem_gb * 2, + ) + coreg_target = 'fixed_image' + coreg_mask = 'fixed_image_masks' + coreg_moving = 'moving_image' + coreg_output = 'forward_transforms' + coreg_output_is_list = True + + elif pet2anat_method == 'robust': coreg = pe.Node( RobustRegister( auto_sens=False, @@ -128,8 +170,11 @@ def init_pet_reg_wf( mem_gb=5, ) coreg_target = 'target_file' + coreg_moving = 'source_file' coreg_output = 'out_reg_file' - else: + coreg_output_is_list = False + + else: # mri_coreg (default) coreg = pe.Node( MRICoreg(dof=pet2anat_dof, sep=[4], ftol=0.0001, linmintol=0.01), name='mri_coreg', @@ -137,30 +182,63 @@ def init_pet_reg_wf( mem_gb=5, ) coreg_target = 'reference_file' + coreg_moving = 'source_file' coreg_output = 'out_lta_file' + coreg_output_is_list = False + convert_xfm = pe.Node(ConcatenateXFMs(inverse=True), name='convert_xfm') - connections = [ - ( - inputnode, - mask_brain, - [ - ('anat_preproc', 'in_file'), - ('anat_mask', 'in_mask'), - ], - ), - (inputnode, coreg, [('ref_pet_brain', 'source_file')]), - (mask_brain, coreg, [('out_file', coreg_target)]), - (coreg, convert_xfm, [(coreg_output, 'in_xfms')]), - ( - convert_xfm, - outputnode, - [ - ('out_xfm', 'itk_pet_to_t1'), - ('out_inv', 'itk_t1_to_pet'), - ], - ), - ] + # Build connections dynamically based on output type + if coreg_output_is_list: + # ANTs outputs a list of transforms; take the first (and only) one + # ANTs gets unmasked T1W + separate mask (not pre-masked image) + connections = [ + ( + inputnode, + mask_brain, + [ + ('anat_preproc', 'in_file'), + ('anat_mask', 'in_mask'), + ], + ), + (inputnode, coreg, [ + ('ref_pet_brain', coreg_moving), + ('anat_preproc', coreg_target), + ('anat_mask', coreg_mask), + ]), + (coreg, convert_xfm, [((coreg_output, _get_first), 'in_xfms')]), + ( + convert_xfm, + outputnode, + [ + ('out_xfm', 'itk_pet_to_t1'), + ('out_inv', 'itk_t1_to_pet'), + ], + ), + ] + else: + # FLIRT and Robust output single transform file + connections = [ + ( + inputnode, + mask_brain, + [ + ('anat_preproc', 'in_file'), + ('anat_mask', 'in_mask'), + ], + ), + (inputnode, coreg, [('ref_pet_brain', coreg_moving)]), + (mask_brain, coreg, [('out_file', coreg_target)]), + (coreg, convert_xfm, [(coreg_output, 'in_xfms')]), + ( + convert_xfm, + outputnode, + [ + ('out_xfm', 'itk_pet_to_t1'), + ('out_inv', 'itk_t1_to_pet'), + ], + ), + ] workflow.connect(connections) # fmt:skip diff --git a/petprep/workflows/pet/tests/test_fit.py b/petprep/workflows/pet/tests/test_fit.py index 81f2f9c0..af736e57 100644 --- a/petprep/workflows/pet/tests/test_fit.py +++ b/petprep/workflows/pet/tests/test_fit.py @@ -355,13 +355,35 @@ def test_pet_fit_robust_registration(bids_root: Path, tmp_path: Path): ) with mock_config(bids_dir=bids_root): - config.workflow.pet2anat_robust = True + config.workflow.pet2anat_method = 'robust' config.workflow.pet2anat_dof = 6 wf = init_pet_fit_wf(pet_series=pet_series, precomputed={}, omp_nthreads=1) node_names = wf.list_node_names() assert 'pet_reg_wf.mri_robust_register' in node_names assert 'pet_reg_wf.mri_coreg' not in node_names + assert 'pet_reg_wf.ants_registration' not in node_names + + +def test_init_pet_fit_wf_ants_registration(bids_root: Path, tmp_path: Path): + """Test PET fit workflow with ANTs registration.""" + pet_series = [str(bids_root / 'sub-01' / 'pet' / 'sub-01_task-rest_run-1_pet.nii.gz')] + img = nb.Nifti1Image(np.zeros((2, 2, 2, 1)), np.eye(4)) + for path in pet_series: + img.to_filename(path) + Path(path).with_suffix('').with_suffix('.json').write_text( + '{"FrameTimesStart": [0], "FrameDuration": [1]}' + ) + + with mock_config(bids_dir=bids_root): + config.workflow.pet2anat_method = 'ants' + config.workflow.pet2anat_dof = 6 + wf = init_pet_fit_wf(pet_series=pet_series, precomputed={}, omp_nthreads=1) + + node_names = wf.list_node_names() + assert 'pet_reg_wf.ants_registration' in node_names + assert 'pet_reg_wf.mri_coreg' not in node_names + assert 'pet_reg_wf.mri_robust_register' not in node_names def test_pet_fit_requires_both_derivatives(bids_root: Path, tmp_path: Path):