Skip to content

Commit

Permalink
Merge pull request #319 from effigies/enh/t2w-alignment
Browse files Browse the repository at this point in the history
ENH: Merge T2w images and coregister to T1w template
  • Loading branch information
effigies committed Jan 26, 2023
2 parents 1a07895 + b41116b commit 99b0cfd
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 54 deletions.
175 changes: 121 additions & 54 deletions smriprep/workflows/anatomical.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from nipype.interfaces.ants import N4BiasFieldCorrection

from niworkflows.engine.workflows import LiterateWorkflow as Workflow
from niworkflows.interfaces.fixes import FixHeaderApplyTransforms as ApplyTransforms
from niworkflows.interfaces.freesurfer import (
StructuralReference,
PatchedLTAConvert as LTAConvert,
Expand All @@ -62,6 +63,7 @@ def init_anat_preproc_wf(
hires,
longitudinal,
t1w,
t2w,
omp_nthreads,
output_dir,
skull_strip_mode,
Expand Down Expand Up @@ -98,6 +100,7 @@ def init_anat_preproc_wf(
hires=True,
longitudinal=False,
t1w=['t1w.nii.gz'],
t2w=[],
omp_nthreads=1,
output_dir='.',
skull_strip_mode='force',
Expand Down Expand Up @@ -228,7 +231,8 @@ def init_anat_preproc_wf(

outputnode = pe.Node(
niu.IdentityInterface(
fields=["template", "subjects_dir", "subject_id"] + get_outputnode_spec()
fields=["template", "subjects_dir", "subject_id", "t2w_preproc"]
+ get_outputnode_spec()
),
name="outputnode",
)
Expand Down Expand Up @@ -338,7 +342,10 @@ def init_anat_preproc_wf(

# 1. Anatomical reference generation - average input T1w images.
anat_template_wf = init_anat_template_wf(
longitudinal=longitudinal, omp_nthreads=omp_nthreads, num_t1w=num_t1w
longitudinal=longitudinal,
omp_nthreads=omp_nthreads,
num_files=num_t1w,
contrast="T1w",
)

anat_validate = pe.Node(
Expand Down Expand Up @@ -393,15 +400,15 @@ def _check_img(img):
# fmt:off
workflow.connect([
# Step 1.
(inputnode, anat_template_wf, [('t1w', 'inputnode.t1w')]),
(inputnode, anat_template_wf, [('t1w', 'inputnode.anat_files')]),
(anat_template_wf, anat_validate, [
('outputnode.t1w_ref', 'in_file')]),
('outputnode.anat_ref', 'in_file')]),
(anat_validate, brain_extraction_wf, [
('out_file', 'inputnode.in_files')]),
(brain_extraction_wf, outputnode, [
(('outputnode.bias_corrected', _pop), 't1w_preproc')]),
(anat_template_wf, outputnode, [
('outputnode.t1w_realign_xfm', 't1w_ref_xfms')]),
('outputnode.anat_realign_xfm', 't1w_ref_xfms')]),
(buffernode, outputnode, [('t1w_brain', 't1w_brain'),
('t1w_mask', 't1w_mask')]),
# Steps 2, 3 and 4
Expand Down Expand Up @@ -454,6 +461,7 @@ def _check_img(img):
bids_root=bids_root,
freesurfer=freesurfer,
num_t1w=num_t1w,
t2w=t2w,
output_dir=output_dir,
spaces=spaces,
)
Expand All @@ -462,7 +470,7 @@ def _check_img(img):
workflow.connect([
# Connect derivatives
(anat_template_wf, anat_derivatives_wf, [
('outputnode.t1w_valid_list', 'inputnode.source_files')]),
('outputnode.anat_valid_list', 'inputnode.source_files')]),
(anat_norm_wf, anat_derivatives_wf, [
('outputnode.template', 'inputnode.template'),
('outputnode.anat2std_xfm', 'inputnode.anat2std_xfm'),
Expand All @@ -474,6 +482,7 @@ def _check_img(img):
('t1w_mask', 'inputnode.t1w_mask'),
('t1w_dseg', 'inputnode.t1w_dseg'),
('t1w_tpms', 'inputnode.t1w_tpms'),
('t2w_preproc', 'inputnode.t2w_preproc'),
]),
])
# fmt:on
Expand Down Expand Up @@ -509,6 +518,7 @@ def _check_img(img):
(('outputnode.out_file', _pop), 't1w_brain'),
('outputnode.out_mask', 't1w_mask')]),
])
# fmt:on
return workflow

# check for older IsRunning files and remove accordingly
Expand All @@ -522,6 +532,56 @@ def _check_img(img):
name="surface_recon_wf", omp_nthreads=omp_nthreads, hires=hires
)
applyrefined = pe.Node(fsl.ApplyMask(), name="applyrefined")

if t2w:
t2w_template_wf = init_anat_template_wf(
longitudinal=longitudinal,
omp_nthreads=omp_nthreads,
num_files=len(t2w),
contrast="T2w",
name="t2w_template_wf",
)
bbreg = pe.Node(
fs.BBRegister(
contrast_type="t2",
init="coreg",
dof=6,
out_lta_file=True,
args="--gm-proj-abs 2 --wm-proj-abs 1",
),
name="bbreg",
)
coreg_xfms = pe.Node(niu.Merge(2), name="merge_xfms", run_without_submitting=True)
t2wtot1w_xfm = pe.Node(ConcatenateXFMs(), name="t2wtot1w_xfm", run_without_submitting=True)
t2w_resample = pe.Node(
ApplyTransforms(
dimension=3,
default_value=0,
float=True,
interpolation="LanczosWindowedSinc",
),
name="t2w_resample",
)
# fmt:off
workflow.connect([
(inputnode, t2w_template_wf, [('t2w', 'inputnode.anat_files')]),
(t2w_template_wf, bbreg, [('outputnode.anat_ref', 'source_file')]),
(surface_recon_wf, bbreg, [
('outputnode.subject_id', 'subject_id'),
('outputnode.subjects_dir', 'subjects_dir'),
]),
(bbreg, coreg_xfms, [('out_lta_file', 'in1')]),
(surface_recon_wf, coreg_xfms, [('outputnode.fsnative2t1w_xfm', 'in2')]),
(coreg_xfms, t2wtot1w_xfm, [('out', 'in_xfms')]),
(t2w_template_wf, t2w_resample, [('outputnode.anat_ref', 'input_image')]),
(brain_extraction_wf, t2w_resample, [
(('outputnode.bias_corrected', _pop), 'reference_image'),
]),
(t2wtot1w_xfm, t2w_resample, [('out_xfm', 'transforms')]),
(t2w_resample, outputnode, [('output_image', 't2w_preproc')]),
])
# fmt:on

# fmt:off
workflow.connect([
(inputnode, fs_isrunning, [
Expand Down Expand Up @@ -573,10 +633,15 @@ def _check_img(img):


def init_anat_template_wf(
*, longitudinal, omp_nthreads, num_t1w, name="anat_template_wf"
*,
longitudinal: bool,
omp_nthreads: int,
num_files: int,
contrast: str,
name: str = "anat_template_wf",
):
"""
Generate a canonically-oriented, structural average from all input T1w images.
Generate a canonically-oriented, structural average from all input images.
Workflow Graph
.. workflow::
Expand All @@ -585,7 +650,8 @@ def init_anat_template_wf(
from smriprep.workflows.anatomical import init_anat_template_wf
wf = init_anat_template_wf(
longitudinal=False, omp_nthreads=1, num_t1w=1)
longitudinal=False, omp_nthreads=1, num_files=1, contrast="T1w"
)
Parameters
----------
Expand All @@ -594,79 +660,80 @@ def init_anat_template_wf(
(may increase runtime)
omp_nthreads : :obj:`int`
Maximum number of threads an individual process may use
num_t1w : :obj:`int`
Number of T1w images
num_files : :obj:`int`
Number of images
contrast : :obj:`str`
Name of contrast, for reporting purposes, e.g., T1w, T2w, PDw
name : :obj:`str`, optional
Workflow name (default: anat_template_wf)
Inputs
------
t1w
List of T1-weighted structural images
anat_files
List of structural images
Outputs
-------
t1w_ref
Structural reference averaging input T1w images, defining the T1w space.
t1w_realign_xfm
List of affine transforms to realign input T1w images
anat_ref
Structural reference averaging input images
anat_valid_list
List of structural images accepted for combination
anat_realign_xfm
List of affine transforms to realign input images to final reference
out_report
Conformation report
"""
workflow = Workflow(name=name)

if num_t1w > 1:
workflow.__desc__ = """\
A T1w-reference map was computed after registration of
{num_t1w} T1w images (after INU-correction) using
if num_files > 1:
fs_ver = fs.Info().looseversion() or "<ver>"
workflow.__desc__ = f"""\
An anatomical {contrast}-reference map was computed after registration of
{num_files} {contrast} images (after INU-correction) using
`mri_robust_template` [FreeSurfer {fs_ver}, @fs_template].
""".format(
num_t1w=num_t1w, fs_ver=fs.Info().looseversion() or "<ver>"
)
"""

inputnode = pe.Node(niu.IdentityInterface(fields=["t1w"]), name="inputnode")
inputnode = pe.Node(niu.IdentityInterface(fields=["anat_files"]), name="inputnode")
outputnode = pe.Node(
niu.IdentityInterface(
fields=["t1w_ref", "t1w_valid_list", "t1w_realign_xfm", "out_report"]
fields=["anat_ref", "anat_valid_list", "anat_realign_xfm", "out_report"]
),
name="outputnode",
)

# 0. Reorient T1w image(s) to RAS and resample to common voxel space
t1w_ref_dimensions = pe.Node(TemplateDimensions(), name="t1w_ref_dimensions")
t1w_conform = pe.MapNode(Conform(), iterfield="in_file", name="t1w_conform")
anat_ref_dimensions = pe.Node(TemplateDimensions(), name="anat_ref_dimensions")
anat_conform = pe.MapNode(Conform(), iterfield="in_file", name="anat_conform")

# fmt:off
workflow.connect([
(inputnode, t1w_ref_dimensions, [('t1w', 't1w_list')]),
(t1w_ref_dimensions, t1w_conform, [
(inputnode, anat_ref_dimensions, [('anat_files', 't1w_list')]),
(anat_ref_dimensions, anat_conform, [
('t1w_valid_list', 'in_file'),
('target_zooms', 'target_zooms'),
('target_shape', 'target_shape')]),
(t1w_ref_dimensions, outputnode, [('out_report', 'out_report'),
('t1w_valid_list', 't1w_valid_list')]),
(anat_ref_dimensions, outputnode, [('out_report', 'out_report'),
('t1w_valid_list', 'anat_valid_list')]),
])
# fmt:on

if num_t1w == 1:
if num_files == 1:
get1st = pe.Node(niu.Select(index=[0]), name="get1st")
outputnode.inputs.t1w_realign_xfm = [
pkgr("smriprep", "data/itkIdentityTransform.txt")
]
outputnode.inputs.anat_realign_xfm = [pkgr("smriprep", "data/itkIdentityTransform.txt")]

# fmt:off
workflow.connect([
(t1w_conform, get1st, [('out_file', 'inlist')]),
(get1st, outputnode, [('out', 't1w_ref')]),
(anat_conform, get1st, [('out_file', 'inlist')]),
(get1st, outputnode, [('out', 'anat_ref')]),
])
# fmt:on
return workflow

t1w_conform_xfm = pe.MapNode(
anat_conform_xfm = pe.MapNode(
LTAConvert(in_lta="identity.nofile", out_lta=True),
iterfield=["source_file", "target_file"],
name="t1w_conform_xfm",
name="anat_conform_xfm",
)

# 1. Template (only if several T1w images)
Expand All @@ -681,7 +748,7 @@ def init_anat_template_wf(
n_procs=1,
) # n_procs=1 for reproducibility
# StructuralReference is fs.RobustTemplate if > 1 volume, copying otherwise
t1w_merge = pe.Node(
anat_merge = pe.Node(
StructuralReference(
auto_detect_sensitivity=True,
initial_timepoint=1, # For deterministic behavior
Expand All @@ -691,12 +758,12 @@ def init_anat_template_wf(
no_iteration=not longitudinal,
transform_outputs=True,
),
mem_gb=2 * num_t1w - 1,
name="t1w_merge",
mem_gb=2 * num_files - 1,
name="anat_merge",
)

# 2. Reorient template to RAS, if needed (mri_robust_template may set to LIA)
t1w_reorient = pe.Node(image.Reorient(), name="t1w_reorient")
anat_reorient = pe.Node(image.Reorient(), name="anat_reorient")

merge_xfm = pe.MapNode(
niu.Merge(2),
Expand All @@ -716,21 +783,21 @@ def _set_threads(in_list, maximum):

# fmt:off
workflow.connect([
(t1w_ref_dimensions, t1w_conform_xfm, [('t1w_valid_list', 'source_file')]),
(t1w_conform, t1w_conform_xfm, [('out_file', 'target_file')]),
(t1w_conform, n4_correct, [('out_file', 'input_image')]),
(t1w_conform, t1w_merge, [
(anat_ref_dimensions, anat_conform_xfm, [('t1w_valid_list', 'source_file')]),
(anat_conform, anat_conform_xfm, [('out_file', 'target_file')]),
(anat_conform, n4_correct, [('out_file', 'input_image')]),
(anat_conform, anat_merge, [
(('out_file', _set_threads, omp_nthreads), 'num_threads'),
(('out_file', add_suffix, '_template'), 'out_file')]),
(n4_correct, t1w_merge, [('output_image', 'in_files')]),
(t1w_merge, t1w_reorient, [('out_file', 'in_file')]),
(n4_correct, anat_merge, [('output_image', 'in_files')]),
(anat_merge, anat_reorient, [('out_file', 'in_file')]),
# Combine orientation and template transforms
(t1w_conform_xfm, merge_xfm, [('out_lta', 'in1')]),
(t1w_merge, merge_xfm, [('transform_outputs', 'in2')]),
(anat_conform_xfm, merge_xfm, [('out_lta', 'in1')]),
(anat_merge, merge_xfm, [('transform_outputs', 'in2')]),
(merge_xfm, concat_xfms, [('out', 'in_xfms')]),
# Output
(t1w_reorient, outputnode, [('out_file', 't1w_ref')]),
(concat_xfms, outputnode, [('out_xfm', 't1w_realign_xfm')]),
(anat_reorient, outputnode, [('out_file', 'anat_ref')]),
(concat_xfms, outputnode, [('out_xfm', 'anat_realign_xfm')]),
])
# fmt:on
return workflow
Expand Down
2 changes: 2 additions & 0 deletions smriprep/workflows/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ def init_single_subject_wf(
# for documentation purposes
subject_data = {
"t1w": ["/completely/made/up/path/sub-01_T1w.nii.gz"],
"t2w": [],
}
else:
subject_data = collect_data(layout, subject_id, bids_filters=bids_filters)[0]
Expand Down Expand Up @@ -406,6 +407,7 @@ def init_single_subject_wf(
longitudinal=longitudinal,
name="anat_preproc_wf",
t1w=subject_data["t1w"],
t2w=subject_data["t2w"],
omp_nthreads=omp_nthreads,
output_dir=output_dir,
skull_strip_fixed_seed=skull_strip_fixed_seed,
Expand Down
Loading

0 comments on commit 99b0cfd

Please sign in to comment.