diff --git a/petprep/cli/parser.py b/petprep/cli/parser.py
index 82b9046a..d0cb6d31 100644
--- a/petprep/cli/parser.py
+++ b/petprep/cli/parser.py
@@ -358,10 +358,14 @@ def _bids_filter(value, parser):
'6 degrees (rotation and translation) are used by default.',
)
g_conf.add_argument(
- '--pet2anat-robust',
- action='store_true',
- help='Use FreeSurfer mri_robust_register with an NMI cost function for'
- 'PET-to-T1w co-registration. This option is limited to 6 dof.',
+ '--pet2anat-method',
+ action='store',
+ default='mri_coreg',
+ choices=['mri_coreg', 'robust', 'ants'],
+ help='Method for PET-to-anatomical registration. '
+ '"mri_coreg" (default) uses FreeSurfer mri_coreg. '
+ '"robust" uses FreeSurfer mri_robust_register (6 DoF only). '
+ '"ants" uses ANTs rigid registration (6 DoF only).',
)
g_conf.add_argument(
'--force-bbr',
@@ -769,8 +773,11 @@ def parse_args(args=None, namespace=None):
parser = _build_parser()
opts = parser.parse_args(args, namespace)
- if getattr(opts, 'pet2anat_robust', False) and opts.pet2anat_dof != 6:
- parser.error('--pet2anat-robust requires --pet2anat-dof=6.')
+ # Validate DoF constraints for registration methods
+ if opts.pet2anat_method in ('robust', 'ants') and opts.pet2anat_dof != 6:
+ parser.error(
+ f'--pet2anat-method {opts.pet2anat_method} requires --pet2anat-dof=6.'
+ )
if opts.config_file:
skip = {} if opts.reports_only else {'execution': ('run_uuid',)}
diff --git a/petprep/config.py b/petprep/config.py
index b246720d..1a805389 100644
--- a/petprep/config.py
+++ b/petprep/config.py
@@ -556,8 +556,8 @@ class workflow(_Config):
"""Degrees of freedom of the PET-to-anatomical registration steps."""
pet2anat_init = 'auto'
"""Initial transform for PET-to-anatomical registration."""
- pet2anat_robust = False
- """Use ``mri_robust_register`` for PET-to-anatomical alignment."""
+ pet2anat_method: str = 'mri_coreg'
+ """PET-to-anatomical registration method (mri_coreg, robust, or ants)."""
petref: str = 'template'
"""Strategy for building the PET reference (``'template'``, ``'twa'`` or ``'sum'``)."""
cifti_output = None
diff --git a/petprep/interfaces/reports.py b/petprep/interfaces/reports.py
index cda4e8dc..060ccf94 100644
--- a/petprep/interfaces/reports.py
+++ b/petprep/interfaces/reports.py
@@ -229,6 +229,7 @@ class FunctionalSummaryInputSpec(TraitedSpec):
registration = traits.Enum(
'mri_coreg',
'mri_robust_register',
+ 'ants_registration',
'Precomputed',
mandatory=True,
desc='PET/anatomical registration method',
@@ -257,8 +258,12 @@ def _generate_segment(self):
reg = 'Precomputed affine transformation'
elif self.inputs.registration == 'mri_coreg':
reg = f'FreeSurfer mri_coreg - {dof} dof'
- else:
+ elif self.inputs.registration == 'ants_registration':
+ reg = 'ANTs rigid registration (6 DoF)'
+ elif self.inputs.registration == 'mri_robust_register':
reg = 'FreeSurfer mri_robust_register (NMI cost)'
+ else:
+ reg = f'Unknown registration method: {self.inputs.registration}'
reference_map = {
'template': 'Motion correction template',
diff --git a/petprep/interfaces/tests/test_reports.py b/petprep/interfaces/tests/test_reports.py
index 46d14c48..90b16389 100644
--- a/petprep/interfaces/tests/test_reports.py
+++ b/petprep/interfaces/tests/test_reports.py
@@ -76,7 +76,7 @@ def test_subject_summary_handles_missing_task(tmp_path):
@pytest.mark.parametrize(
'registration',
- ['mri_coreg', 'mri_robust_register'],
+ ['mri_coreg', 'mri_robust_register', 'ants_registration'],
)
def test_functional_summary_with_metadata(registration):
from ..reports import FunctionalSummary
diff --git a/petprep/workflows/pet/fit.py b/petprep/workflows/pet/fit.py
index 6994874d..5611c519 100644
--- a/petprep/workflows/pet/fit.py
+++ b/petprep/workflows/pet/fit.py
@@ -372,9 +372,11 @@ def init_pet_fit_wf(
registration_method = 'Precomputed'
if not petref2anat_xform:
- registration_method = (
- 'mri_robust_register' if config.workflow.pet2anat_robust else 'mri_coreg'
- )
+ registration_method = {
+ 'mri_coreg': 'mri_coreg',
+ 'robust': 'mri_robust_register',
+ 'ants': 'ants_registration',
+ }[config.workflow.pet2anat_method]
if hmc_disabled:
config.execution.work_dir.mkdir(parents=True, exist_ok=True)
petref = petref or reference_function(pet_file, **reference_kwargs)
@@ -528,7 +530,7 @@ def init_pet_fit_wf(
pet2anat_dof=config.workflow.pet2anat_dof,
omp_nthreads=omp_nthreads,
mem_gb=mem_gb['resampled'],
- use_robust_register=config.workflow.pet2anat_robust,
+ pet2anat_method=config.workflow.pet2anat_method,
sloppy=config.execution.sloppy,
)
diff --git a/petprep/workflows/pet/registration.py b/petprep/workflows/pet/registration.py
index 370da756..6dd49935 100644
--- a/petprep/workflows/pet/registration.py
+++ b/petprep/workflows/pet/registration.py
@@ -36,12 +36,17 @@
AffineDOF = ty.Literal[6, 9, 12]
+def _get_first(in_list):
+ """Extract first element from a list (for ANTs transform output)."""
+ return in_list[0]
+
+
def init_pet_reg_wf(
*,
pet2anat_dof: AffineDOF,
mem_gb: float,
omp_nthreads: int,
- use_robust_register: bool = False,
+ pet2anat_method: str = 'flirt',
name: str = 'pet_reg_wf',
sloppy: bool = False,
):
@@ -70,10 +75,10 @@ def init_pet_reg_wf(
Size of PET file in GB
omp_nthreads : :obj:`int`
Maximum number of threads an individual process may use
- use_robust_register : :obj:`bool`
- Run FreeSurfer ``mri_robust_register`` with an NMI cost function for
- PET-to-anatomical alignment. Only rigid-body (6 dof) alignment is
- supported in this mode.
+ pet2anat_method : :obj:`str`
+ Method for PET-to-anatomical registration. Options are 'flirt' (default,
+ uses FSL FLIRT), 'robust' (uses FreeSurfer mri_robust_register with NMI,
+ 6 DoF only), or 'ants' (uses ANTs rigid registration, 6 DoF only).
name : :obj:`str`
Name of workflow (default: ``pet_reg_wf``)
@@ -95,6 +100,7 @@ def init_pet_reg_wf(
Affine transform from anatomical space to PET space (ITK format)
"""
+ from nipype.interfaces.ants import Registration
from nipype.interfaces.freesurfer import MRICoreg, RobustRegister
from niworkflows.engine.workflows import LiterateWorkflow as Workflow
from niworkflows.interfaces.nibabel import ApplyMask
@@ -112,7 +118,43 @@ def init_pet_reg_wf(
)
mask_brain = pe.Node(ApplyMask(), name='mask_brain')
- if use_robust_register:
+
+ if pet2anat_method == 'ants':
+ coreg = pe.Node(
+ Registration(
+ dimension=3,
+ float=True,
+ output_transform_prefix='pet2anat_',
+ output_warped_image='pet2anat_Warped.nii.gz',
+ transforms=['Rigid'],
+ transform_parameters=[(0.1,)],
+ metric=['MI'],
+ metric_weight=[1],
+ radius_or_number_of_bins=[32],
+ sampling_strategy=['Regular'],
+ sampling_percentage=[0.25],
+ number_of_iterations=[[1000, 500, 250]],
+ convergence_threshold=[1e-6],
+ convergence_window_size=[10],
+ shrink_factors=[[4, 2, 1]],
+ smoothing_sigmas=[[2, 1, 0]],
+ sigma_units=['vox'],
+ use_histogram_matching=False,
+ initial_moving_transform_com=1,
+ winsorize_lower_quantile=0.005,
+ winsorize_upper_quantile=0.995,
+ ),
+ name='ants_registration',
+ n_procs=omp_nthreads,
+ mem_gb=mem_gb * 2,
+ )
+ coreg_target = 'fixed_image'
+ coreg_mask = 'fixed_image_masks'
+ coreg_moving = 'moving_image'
+ coreg_output = 'forward_transforms'
+ coreg_output_is_list = True
+
+ elif pet2anat_method == 'robust':
coreg = pe.Node(
RobustRegister(
auto_sens=False,
@@ -128,8 +170,11 @@ def init_pet_reg_wf(
mem_gb=5,
)
coreg_target = 'target_file'
+ coreg_moving = 'source_file'
coreg_output = 'out_reg_file'
- else:
+ coreg_output_is_list = False
+
+ else: # mri_coreg (default)
coreg = pe.Node(
MRICoreg(dof=pet2anat_dof, sep=[4], ftol=0.0001, linmintol=0.01),
name='mri_coreg',
@@ -137,30 +182,63 @@ def init_pet_reg_wf(
mem_gb=5,
)
coreg_target = 'reference_file'
+ coreg_moving = 'source_file'
coreg_output = 'out_lta_file'
+ coreg_output_is_list = False
+
convert_xfm = pe.Node(ConcatenateXFMs(inverse=True), name='convert_xfm')
- connections = [
- (
- inputnode,
- mask_brain,
- [
- ('anat_preproc', 'in_file'),
- ('anat_mask', 'in_mask'),
- ],
- ),
- (inputnode, coreg, [('ref_pet_brain', 'source_file')]),
- (mask_brain, coreg, [('out_file', coreg_target)]),
- (coreg, convert_xfm, [(coreg_output, 'in_xfms')]),
- (
- convert_xfm,
- outputnode,
- [
- ('out_xfm', 'itk_pet_to_t1'),
- ('out_inv', 'itk_t1_to_pet'),
- ],
- ),
- ]
+ # Build connections dynamically based on output type
+ if coreg_output_is_list:
+ # ANTs outputs a list of transforms; take the first (and only) one
+ # ANTs gets unmasked T1W + separate mask (not pre-masked image)
+ connections = [
+ (
+ inputnode,
+ mask_brain,
+ [
+ ('anat_preproc', 'in_file'),
+ ('anat_mask', 'in_mask'),
+ ],
+ ),
+ (inputnode, coreg, [
+ ('ref_pet_brain', coreg_moving),
+ ('anat_preproc', coreg_target),
+ ('anat_mask', coreg_mask),
+ ]),
+ (coreg, convert_xfm, [((coreg_output, _get_first), 'in_xfms')]),
+ (
+ convert_xfm,
+ outputnode,
+ [
+ ('out_xfm', 'itk_pet_to_t1'),
+ ('out_inv', 'itk_t1_to_pet'),
+ ],
+ ),
+ ]
+ else:
+ # FLIRT and Robust output single transform file
+ connections = [
+ (
+ inputnode,
+ mask_brain,
+ [
+ ('anat_preproc', 'in_file'),
+ ('anat_mask', 'in_mask'),
+ ],
+ ),
+ (inputnode, coreg, [('ref_pet_brain', coreg_moving)]),
+ (mask_brain, coreg, [('out_file', coreg_target)]),
+ (coreg, convert_xfm, [(coreg_output, 'in_xfms')]),
+ (
+ convert_xfm,
+ outputnode,
+ [
+ ('out_xfm', 'itk_pet_to_t1'),
+ ('out_inv', 'itk_t1_to_pet'),
+ ],
+ ),
+ ]
workflow.connect(connections) # fmt:skip
diff --git a/petprep/workflows/pet/tests/test_fit.py b/petprep/workflows/pet/tests/test_fit.py
index 81f2f9c0..af736e57 100644
--- a/petprep/workflows/pet/tests/test_fit.py
+++ b/petprep/workflows/pet/tests/test_fit.py
@@ -355,13 +355,35 @@ def test_pet_fit_robust_registration(bids_root: Path, tmp_path: Path):
)
with mock_config(bids_dir=bids_root):
- config.workflow.pet2anat_robust = True
+ config.workflow.pet2anat_method = 'robust'
config.workflow.pet2anat_dof = 6
wf = init_pet_fit_wf(pet_series=pet_series, precomputed={}, omp_nthreads=1)
node_names = wf.list_node_names()
assert 'pet_reg_wf.mri_robust_register' in node_names
assert 'pet_reg_wf.mri_coreg' not in node_names
+ assert 'pet_reg_wf.ants_registration' not in node_names
+
+
+def test_init_pet_fit_wf_ants_registration(bids_root: Path, tmp_path: Path):
+ """Test PET fit workflow with ANTs registration."""
+ pet_series = [str(bids_root / 'sub-01' / 'pet' / 'sub-01_task-rest_run-1_pet.nii.gz')]
+ img = nb.Nifti1Image(np.zeros((2, 2, 2, 1)), np.eye(4))
+ for path in pet_series:
+ img.to_filename(path)
+ Path(path).with_suffix('').with_suffix('.json').write_text(
+ '{"FrameTimesStart": [0], "FrameDuration": [1]}'
+ )
+
+ with mock_config(bids_dir=bids_root):
+ config.workflow.pet2anat_method = 'ants'
+ config.workflow.pet2anat_dof = 6
+ wf = init_pet_fit_wf(pet_series=pet_series, precomputed={}, omp_nthreads=1)
+
+ node_names = wf.list_node_names()
+ assert 'pet_reg_wf.ants_registration' in node_names
+ assert 'pet_reg_wf.mri_coreg' not in node_names
+ assert 'pet_reg_wf.mri_robust_register' not in node_names
def test_pet_fit_requires_both_derivatives(bids_root: Path, tmp_path: Path):