Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions petprep/cli/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,10 +358,14 @@ def _bids_filter(value, parser):
'6 degrees (rotation and translation) are used by default.',
)
g_conf.add_argument(
'--pet2anat-robust',
action='store_true',
help='Use FreeSurfer mri_robust_register with an NMI cost function for'
'PET-to-T1w co-registration. This option is limited to 6 dof.',
'--pet2anat-method',
action='store',
default='mri_coreg',
choices=['mri_coreg', 'robust', 'ants'],
help='Method for PET-to-anatomical registration. '
'"mri_coreg" (default) uses FreeSurfer mri_coreg. '
'"robust" uses FreeSurfer mri_robust_register (6 DoF only). '
'"ants" uses ANTs rigid registration (6 DoF only).',
)
g_conf.add_argument(
'--force-bbr',
Expand Down Expand Up @@ -769,8 +773,11 @@ def parse_args(args=None, namespace=None):
parser = _build_parser()
opts = parser.parse_args(args, namespace)

if getattr(opts, 'pet2anat_robust', False) and opts.pet2anat_dof != 6:
parser.error('--pet2anat-robust requires --pet2anat-dof=6.')
# Validate DoF constraints for registration methods
if opts.pet2anat_method in ('robust', 'ants') and opts.pet2anat_dof != 6:
parser.error(
f'--pet2anat-method {opts.pet2anat_method} requires --pet2anat-dof=6.'
)

if opts.config_file:
skip = {} if opts.reports_only else {'execution': ('run_uuid',)}
Expand Down
4 changes: 2 additions & 2 deletions petprep/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,8 +556,8 @@ class workflow(_Config):
"""Degrees of freedom of the PET-to-anatomical registration steps."""
pet2anat_init = 'auto'
"""Initial transform for PET-to-anatomical registration."""
pet2anat_robust = False
"""Use ``mri_robust_register`` for PET-to-anatomical alignment."""
pet2anat_method: str = 'mri_coreg'
"""PET-to-anatomical registration method (mri_coreg, robust, or ants)."""
petref: str = 'template'
"""Strategy for building the PET reference (``'template'``, ``'twa'`` or ``'sum'``)."""
cifti_output = None
Expand Down
7 changes: 6 additions & 1 deletion petprep/interfaces/reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ class FunctionalSummaryInputSpec(TraitedSpec):
registration = traits.Enum(
'mri_coreg',
'mri_robust_register',
'ants_registration',
'Precomputed',
mandatory=True,
desc='PET/anatomical registration method',
Expand Down Expand Up @@ -257,8 +258,12 @@ def _generate_segment(self):
reg = 'Precomputed affine transformation'
elif self.inputs.registration == 'mri_coreg':
reg = f'FreeSurfer <code>mri_coreg</code> - {dof} dof'
else:
elif self.inputs.registration == 'ants_registration':
reg = 'ANTs rigid registration (6 DoF)'
elif self.inputs.registration == 'mri_robust_register':
reg = 'FreeSurfer <code>mri_robust_register</code> (NMI cost)'
else:
reg = f'Unknown registration method: {self.inputs.registration}'

reference_map = {
'template': 'Motion correction template',
Expand Down
2 changes: 1 addition & 1 deletion petprep/interfaces/tests/test_reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def test_subject_summary_handles_missing_task(tmp_path):

@pytest.mark.parametrize(
'registration',
['mri_coreg', 'mri_robust_register'],
['mri_coreg', 'mri_robust_register', 'ants_registration'],
)
def test_functional_summary_with_metadata(registration):
from ..reports import FunctionalSummary
Expand Down
10 changes: 6 additions & 4 deletions petprep/workflows/pet/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,9 +372,11 @@ def init_pet_fit_wf(

registration_method = 'Precomputed'
if not petref2anat_xform:
registration_method = (
'mri_robust_register' if config.workflow.pet2anat_robust else 'mri_coreg'
)
registration_method = {
'mri_coreg': 'mri_coreg',
'robust': 'mri_robust_register',
'ants': 'ants_registration',
}[config.workflow.pet2anat_method]
if hmc_disabled:
config.execution.work_dir.mkdir(parents=True, exist_ok=True)
petref = petref or reference_function(pet_file, **reference_kwargs)
Expand Down Expand Up @@ -528,7 +530,7 @@ def init_pet_fit_wf(
pet2anat_dof=config.workflow.pet2anat_dof,
omp_nthreads=omp_nthreads,
mem_gb=mem_gb['resampled'],
use_robust_register=config.workflow.pet2anat_robust,
pet2anat_method=config.workflow.pet2anat_method,
sloppy=config.execution.sloppy,
)

Expand Down
134 changes: 106 additions & 28 deletions petprep/workflows/pet/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,17 @@
AffineDOF = ty.Literal[6, 9, 12]


def _get_first(in_list):
"""Extract first element from a list (for ANTs transform output)."""
return in_list[0]


def init_pet_reg_wf(
*,
pet2anat_dof: AffineDOF,
mem_gb: float,
omp_nthreads: int,
use_robust_register: bool = False,
pet2anat_method: str = 'flirt',
name: str = 'pet_reg_wf',
sloppy: bool = False,
):
Expand Down Expand Up @@ -70,10 +75,10 @@ def init_pet_reg_wf(
Size of PET file in GB
omp_nthreads : :obj:`int`
Maximum number of threads an individual process may use
use_robust_register : :obj:`bool`
Run FreeSurfer ``mri_robust_register`` with an NMI cost function for
PET-to-anatomical alignment. Only rigid-body (6 dof) alignment is
supported in this mode.
pet2anat_method : :obj:`str`
Method for PET-to-anatomical registration. Options are 'flirt' (default,
uses FSL FLIRT), 'robust' (uses FreeSurfer mri_robust_register with NMI,
6 DoF only), or 'ants' (uses ANTs rigid registration, 6 DoF only).
name : :obj:`str`
Name of workflow (default: ``pet_reg_wf``)

Expand All @@ -95,6 +100,7 @@ def init_pet_reg_wf(
Affine transform from anatomical space to PET space (ITK format)

"""
from nipype.interfaces.ants import Registration
from nipype.interfaces.freesurfer import MRICoreg, RobustRegister
from niworkflows.engine.workflows import LiterateWorkflow as Workflow
from niworkflows.interfaces.nibabel import ApplyMask
Expand All @@ -112,7 +118,43 @@ def init_pet_reg_wf(
)

mask_brain = pe.Node(ApplyMask(), name='mask_brain')
if use_robust_register:

if pet2anat_method == 'ants':
coreg = pe.Node(
Registration(
dimension=3,
float=True,
output_transform_prefix='pet2anat_',
output_warped_image='pet2anat_Warped.nii.gz',
transforms=['Rigid'],
transform_parameters=[(0.1,)],
metric=['MI'],
metric_weight=[1],
radius_or_number_of_bins=[32],
sampling_strategy=['Regular'],
sampling_percentage=[0.25],
number_of_iterations=[[1000, 500, 250]],
convergence_threshold=[1e-6],
convergence_window_size=[10],
shrink_factors=[[4, 2, 1]],
smoothing_sigmas=[[2, 1, 0]],
sigma_units=['vox'],
use_histogram_matching=False,
initial_moving_transform_com=1,
winsorize_lower_quantile=0.005,
winsorize_upper_quantile=0.995,
),
name='ants_registration',
n_procs=omp_nthreads,
mem_gb=mem_gb * 2,
)
coreg_target = 'fixed_image'
coreg_mask = 'fixed_image_masks'
coreg_moving = 'moving_image'
coreg_output = 'forward_transforms'
coreg_output_is_list = True

elif pet2anat_method == 'robust':
coreg = pe.Node(
RobustRegister(
auto_sens=False,
Expand All @@ -128,39 +170,75 @@ def init_pet_reg_wf(
mem_gb=5,
)
coreg_target = 'target_file'
coreg_moving = 'source_file'
coreg_output = 'out_reg_file'
else:
coreg_output_is_list = False

else: # mri_coreg (default)
coreg = pe.Node(
MRICoreg(dof=pet2anat_dof, sep=[4], ftol=0.0001, linmintol=0.01),
name='mri_coreg',
n_procs=omp_nthreads,
mem_gb=5,
)
coreg_target = 'reference_file'
coreg_moving = 'source_file'
coreg_output = 'out_lta_file'
coreg_output_is_list = False

convert_xfm = pe.Node(ConcatenateXFMs(inverse=True), name='convert_xfm')

connections = [
(
inputnode,
mask_brain,
[
('anat_preproc', 'in_file'),
('anat_mask', 'in_mask'),
],
),
(inputnode, coreg, [('ref_pet_brain', 'source_file')]),
(mask_brain, coreg, [('out_file', coreg_target)]),
(coreg, convert_xfm, [(coreg_output, 'in_xfms')]),
(
convert_xfm,
outputnode,
[
('out_xfm', 'itk_pet_to_t1'),
('out_inv', 'itk_t1_to_pet'),
],
),
]
# Build connections dynamically based on output type
if coreg_output_is_list:
# ANTs outputs a list of transforms; take the first (and only) one
# ANTs gets unmasked T1W + separate mask (not pre-masked image)
connections = [
(
inputnode,
mask_brain,
[
('anat_preproc', 'in_file'),
('anat_mask', 'in_mask'),
],
),
(inputnode, coreg, [
('ref_pet_brain', coreg_moving),
('anat_preproc', coreg_target),
('anat_mask', coreg_mask),
]),
(coreg, convert_xfm, [((coreg_output, _get_first), 'in_xfms')]),
(
convert_xfm,
outputnode,
[
('out_xfm', 'itk_pet_to_t1'),
('out_inv', 'itk_t1_to_pet'),
],
),
]
else:
# FLIRT and Robust output single transform file
connections = [
(
inputnode,
mask_brain,
[
('anat_preproc', 'in_file'),
('anat_mask', 'in_mask'),
],
),
(inputnode, coreg, [('ref_pet_brain', coreg_moving)]),
(mask_brain, coreg, [('out_file', coreg_target)]),
(coreg, convert_xfm, [(coreg_output, 'in_xfms')]),
(
convert_xfm,
outputnode,
[
('out_xfm', 'itk_pet_to_t1'),
('out_inv', 'itk_t1_to_pet'),
],
),
]

workflow.connect(connections) # fmt:skip

Expand Down
24 changes: 23 additions & 1 deletion petprep/workflows/pet/tests/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,13 +355,35 @@ def test_pet_fit_robust_registration(bids_root: Path, tmp_path: Path):
)

with mock_config(bids_dir=bids_root):
config.workflow.pet2anat_robust = True
config.workflow.pet2anat_method = 'robust'
config.workflow.pet2anat_dof = 6
wf = init_pet_fit_wf(pet_series=pet_series, precomputed={}, omp_nthreads=1)

node_names = wf.list_node_names()
assert 'pet_reg_wf.mri_robust_register' in node_names
assert 'pet_reg_wf.mri_coreg' not in node_names
assert 'pet_reg_wf.ants_registration' not in node_names


def test_init_pet_fit_wf_ants_registration(bids_root: Path, tmp_path: Path):
"""Test PET fit workflow with ANTs registration."""
pet_series = [str(bids_root / 'sub-01' / 'pet' / 'sub-01_task-rest_run-1_pet.nii.gz')]
img = nb.Nifti1Image(np.zeros((2, 2, 2, 1)), np.eye(4))
for path in pet_series:
img.to_filename(path)
Path(path).with_suffix('').with_suffix('.json').write_text(
'{"FrameTimesStart": [0], "FrameDuration": [1]}'
)

with mock_config(bids_dir=bids_root):
config.workflow.pet2anat_method = 'ants'
config.workflow.pet2anat_dof = 6
wf = init_pet_fit_wf(pet_series=pet_series, precomputed={}, omp_nthreads=1)

node_names = wf.list_node_names()
assert 'pet_reg_wf.ants_registration' in node_names
assert 'pet_reg_wf.mri_coreg' not in node_names
assert 'pet_reg_wf.mri_robust_register' not in node_names


def test_pet_fit_requires_both_derivatives(bids_root: Path, tmp_path: Path):
Expand Down