Skip to content

Commit

Permalink
RF: Move anat_ribbon into fit workflow for efficiency
Browse files Browse the repository at this point in the history
  • Loading branch information
effigies committed Sep 29, 2023
1 parent fe7332b commit 53cb30c
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 48 deletions.
56 changes: 40 additions & 16 deletions smriprep/workflows/anatomical.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ def init_anat_preproc_wf(
("outputnode.fsnative2t1w_xfm", "fsnative2t1w_xfm"),
("outputnode.sphere_reg", "sphere_reg"),
(f"outputnode.sphere_reg_{'msm' if msm_sulc else 'fsLR'}", "sphere_reg_fsLR"),
("outputnode.anat_ribbon", "anat_ribbon"),
]),
(anat_fit_wf, anat_second_derivatives_wf, [
('outputnode.template', 'inputnode.template'),
Expand All @@ -308,7 +309,6 @@ def init_anat_preproc_wf(
surface_derivatives_wf = init_surface_derivatives_wf(
cifti_output=cifti_output,
)
anat_ribbon_wf = init_anat_ribbon_wf()
ds_surfaces_wf = init_ds_surfaces_wf(
bids_root=bids_root, output_dir=output_dir, surfaces=["inflated"]
)
Expand All @@ -324,11 +324,6 @@ def init_anat_preproc_wf(
('outputnode.thickness', 'inputnode.thickness'),
('outputnode.sulc', 'inputnode.sulc'),
]),
(anat_fit_wf, anat_ribbon_wf, [
('outputnode.t1w_mask', 'inputnode.ref_file'),
('outputnode.white', 'inputnode.white'),
('outputnode.pial', 'inputnode.pial'),
]),
(anat_fit_wf, ds_surfaces_wf, [
('outputnode.t1w_valid_list', 'inputnode.source_files'),
]),
Expand All @@ -346,12 +341,6 @@ def init_anat_preproc_wf(
('outputnode.out_aseg', 't1w_aseg'),
('outputnode.out_aparc', 't1w_aparc'),
]),
(anat_ribbon_wf, outputnode, [
("outputnode.anat_ribbon", "anat_ribbon"),
]),
(anat_ribbon_wf, anat_second_derivatives_wf, [
("outputnode.anat_ribbon", "inputnode.anat_ribbon"),
]),
])
# fmt:on

Expand Down Expand Up @@ -558,6 +547,7 @@ def init_anat_fit_wf(
"sphere_reg",
"sphere_reg_fsLR",
"sphere_reg_msm",
"anat_ribbon",
# Reverse transform; not computable from forward transform
"std2anat_xfm",
# Metadata
Expand Down Expand Up @@ -789,12 +779,17 @@ def init_anat_fit_wf(
])
# fmt:on

ds_mask_wf = init_ds_mask_wf(bids_root=bids_root, output_dir=output_dir)
ds_t1w_mask_wf = init_ds_mask_wf(
bids_root=bids_root,
output_dir=output_dir,
mask_type="brain",
name="ds_t1w_mask_wf",
)
# fmt:off
workflow.connect([
(sourcefile_buffer, ds_mask_wf, [("source_files", "inputnode.source_files")]),
(refined_buffer, ds_mask_wf, [("t1w_mask", "inputnode.t1w_mask")]),
(ds_mask_wf, outputnode, [("outputnode.t1w_mask", "t1w_mask")]),
(sourcefile_buffer, ds_t1w_mask_wf, [("source_files", "inputnode.source_files")]),
(refined_buffer, ds_t1w_mask_wf, [("t1w_mask", "inputnode.mask_file")]),
(ds_t1w_mask_wf, outputnode, [("outputnode.mask_file", "t1w_mask")]),
])
# fmt:on
else:
Expand Down Expand Up @@ -1182,6 +1177,35 @@ def init_anat_fit_wf(
])
# fmt:on

if "anat_ribbon" not in precomputed:
LOGGER.info("ANAT Stage 8a: Creating cortical ribbon mask")
anat_ribbon_wf = init_anat_ribbon_wf()
ds_ribbon_mask_wf = init_ds_mask_wf(
bids_root=bids_root,
output_dir=output_dir,
mask_type="ribbon",
name="ds_ribbon_mask_wf",
)
# fmt:off
workflow.connect([
(t1w_buffer, anat_ribbon_wf, [
("t1w_preproc", "inputnode.ref_file"),
]),
(surfaces_buffer, anat_ribbon_wf, [
("white", "inputnode.white"),
("pial", "inputnode.pial"),
]),
(sourcefile_buffer, ds_ribbon_mask_wf, [("source_files", "inputnode.source_files")]),
(anat_ribbon_wf, ds_ribbon_mask_wf, [
("outputnode.anat_ribbon", "inputnode.mask_file"),
]),
(ds_ribbon_mask_wf, outputnode, [("outputnode.mask_file", "anat_ribbon")]),
])
# fmt:on
else:
LOGGER.info("ANAT Stage 8a: Found pre-computed cortical ribbon mask")
outputnode.inputs.anat_ribbon = precomputed["anat_ribbon"]

# Stage 9: Baseline fsLR registration
if len(precomputed.get("sphere_reg_fsLR", [])) < 2:
LOGGER.info("ANAT Stage 9: Creating fsLR registration sphere")
Expand Down
56 changes: 24 additions & 32 deletions smriprep/workflows/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,9 @@ def init_ds_template_wf(

def init_ds_mask_wf(
*,
bids_root,
output_dir,
bids_root: str,
output_dir: str,
mask_type: str,
name="ds_mask_wf",
):
"""
Expand All @@ -336,40 +337,48 @@ def init_ds_mask_wf(
------
source_files
List of input T1w images
t1w_mask
Mask of the ``t1w_preproc``
mask_file
Mask to save
Outputs
-------
t1w_mask
The location in the output directory of the T1w mask
mask_file
The location in the output directory of the mask
"""
workflow = Workflow(name=name)

inputnode = pe.Node(
niu.IdentityInterface(fields=["source_files", "t1w_mask"]),
niu.IdentityInterface(fields=["source_files", "mask_file"]),
name="inputnode",
)
outputnode = pe.Node(niu.IdentityInterface(fields=["t1w_mask"]), name="outputnode")
outputnode = pe.Node(niu.IdentityInterface(fields=["mask_file"]), name="outputnode")

raw_sources = pe.Node(niu.Function(function=_bids_relative), name="raw_sources")
raw_sources.inputs.bids_root = bids_root

ds_t1w_mask = pe.Node(
DerivativesDataSink(base_directory=output_dir, desc="brain", suffix="mask", compress=True),
ds_mask = pe.Node(
DerivativesDataSink(
base_directory=output_dir,
desc=mask_type,
suffix="mask",
compress=True,
),
name="ds_t1w_mask",
run_without_submitting=True,
)
ds_t1w_mask.inputs.Type = "Brain"
if mask_type == "brain":
ds_mask.inputs.Type = "Brain"
else:
ds_mask.inputs.Type = "ROI"

# fmt:off
workflow.connect([
(inputnode, raw_sources, [('source_files', 'in_files')]),
(inputnode, ds_t1w_mask, [('t1w_mask', 'in_file'),
('source_files', 'source_file')]),
(raw_sources, ds_t1w_mask, [('out', 'RawSources')]),
(ds_t1w_mask, outputnode, [('out_file', 't1w_mask')]),
(inputnode, ds_mask, [('mask_file', 'in_file'),
('source_files', 'source_file')]),
(raw_sources, ds_mask, [('out', 'RawSources')]),
(ds_mask, outputnode, [('out_file', 'mask_file')]),
])
# fmt:on

Expand Down Expand Up @@ -865,11 +874,9 @@ def init_anat_second_derivatives_wf(
"t1w_dseg",
"t1w_tpms",
"anat2std_xfm",
"surfaces",
"sphere_reg",
"sphere_reg_fsLR",
"morphometrics",
"anat_ribbon",
"t1w_fs_aseg",
"t1w_fs_aparc",
"cifti_morph",
Expand Down Expand Up @@ -1019,18 +1026,6 @@ def init_anat_second_derivatives_wf(
name="ds_morphs",
run_without_submitting=True,
)
# Ribbon volume
ds_anat_ribbon = pe.Node(
DerivativesDataSink(
base_directory=output_dir,
desc="ribbon",
suffix="mask",
extension=".nii.gz",
compress=True,
),
name="ds_anat_ribbon",
run_without_submitting=True,
)

# Parcellations
ds_t1w_fsaseg = pe.Node(
Expand All @@ -1057,9 +1052,6 @@ def init_anat_second_derivatives_wf(
('source_files', 'source_file')]),
(inputnode, ds_t1w_fsparc, [('t1w_fs_aparc', 'in_file'),
('source_files', 'source_file')]),
(inputnode, ds_anat_ribbon, [('anat_ribbon', 'in_file'),
('source_files', 'source_file')]),

])
# fmt:on

Expand Down

0 comments on commit 53cb30c

Please sign in to comment.