Skip to content

Commit

Permalink
enh: replace FSL FAST with ANTs Atropos for brain tissue segmentation
Browse files Browse the repository at this point in the history
One more step toward #1032.

Partially-Supersedes: #803, #832.
  • Loading branch information
oesteban committed Apr 1, 2023
1 parent f0eeb68 commit 8589e33
Showing 1 changed file with 121 additions and 82 deletions.
203 changes: 121 additions & 82 deletions mriqc/workflows/anatomical/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,48 +132,7 @@ def anat_qc_workflow(name="anatMRIQC"):
# 5. Air mask (with and without artifacts)
amw = airmsk_wf()
# 6. Brain tissue segmentation
if config.workflow.species.lower() == "human":
from nipype.interfaces.fsl import FAST

segment = pe.Node(
FAST(segments=True, out_basename="segment"),
name="segmentation",
mem_gb=5,
)
seg_in_file = "in_files"
dseg_out = "tissue_class_map"
pve_out = "partial_volume_files"
else:
from nipype.interfaces.ants import Atropos

format_tpm_names = pe.Node(
niu.Function(
input_names=["in_files"],
output_names=["file_format"],
function=_format_tpm_names,
execution={"keep_inputs": True, "remove_unnecessary_outputs": False},
),
name="format_tpm_names",
)

segment = pe.Node(
Atropos(
initialization="PriorProbabilityImages",
number_of_tissue_classes=3,
prior_weighting=0.1,
mrf_radius=[1, 1, 1],
mrf_smoothing_factor=0.01,
save_posteriors=True,
out_classified_image_name="segment.nii.gz",
output_posteriors_name_template="segment_%02d.nii.gz",
),
name="segmentation",
mem_gb=5,
)
seg_in_file = "intensity_images"
dseg_out = "classified_image"
pve_out = "posteriors"

bts = init_brain_tissue_segmentation()
# 7. Compute IQMs
iqmswf = compute_iqms()
# Reports
Expand All @@ -190,18 +149,20 @@ def anat_qc_workflow(name="anatMRIQC"):
(datalad_get, iqmswf, [("in_file", "inputnode.in_file")]),
(datalad_get, norm, [(("in_file", _get_mod), "inputnode.modality")]),
(to_ras, skull_stripping, [("out_file", "inputnode.in_files")]),
(skull_stripping, segment, [("outputnode.out_brain", seg_in_file)]),
(skull_stripping, bts, [("outputnode.out_corrected", "inputnode.in_file"),
("outputnode.out_mask", "inputnode.brainmask")]),
(skull_stripping, hmsk, [("outputnode.out_corrected", "inputnode.in_file")]),
(segment, hmsk, [(dseg_out, "inputnode.in_segm")]),
(skull_stripping, norm, [
("outputnode.out_corrected", "inputnode.moving_image"),
("outputnode.out_mask", "inputnode.moving_mask")]),
(norm, bts, [("outputnode.out_tpms", "inputnode.std_tpms")]),
(norm, amw, [
("outputnode.ind2std_xfm", "inputnode.ind2std_xfm")]),
(norm, iqmswf, [
("outputnode.out_tpms", "inputnode.std_tpms")]),
(norm, anat_report_wf, ([
("outputnode.out_report", "inputnode.mni_report")])),
(bts, hmsk, [("outputnode.out_segm", "inputnode.in_segm")]),
(to_ras, amw, [("out_file", "inputnode.in_file")]),
(skull_stripping, amw, [("outputnode.out_mask", "inputnode.in_mask")]),
(hmsk, amw, [("outputnode.out_file", "inputnode.head_mask")]),
Expand All @@ -213,8 +174,8 @@ def anat_qc_workflow(name="anatMRIQC"):
("outputnode.hat_mask", "inputnode.hatmask"),
("outputnode.art_mask", "inputnode.artmask"),
("outputnode.rot_mask", "inputnode.rotmask")]),
(segment, iqmswf, [(dseg_out, "inputnode.segmentation"),
(pve_out, "inputnode.pvms")]),
(bts, iqmswf, [("outputnode.out_segm", "inputnode.segmentation"),
("outputnode.out_pvms", "inputnode.pvms")]),
(hmsk, iqmswf, [("outputnode.out_file", "inputnode.headmask")]),
(to_ras, anat_report_wf, [("out_file", "inputnode.in_ras")]),
(skull_stripping, anat_report_wf, [
Expand All @@ -226,22 +187,11 @@ def anat_qc_workflow(name="anatMRIQC"):
("outputnode.art_mask", "inputnode.artmask"),
("outputnode.rot_mask", "inputnode.rotmask"),
]),
(segment, anat_report_wf, [(dseg_out, "inputnode.segmentation")]),
(bts, anat_report_wf, [("outputnode.out_segm", "inputnode.segmentation")]),
(iqmswf, anat_report_wf, [("outputnode.noisefit", "inputnode.noisefit")]),
(iqmswf, anat_report_wf, [("outputnode.out_file", "inputnode.in_iqms")]),
(iqmswf, outputnode, [("outputnode.out_file", "out_json")]),
])

if config.workflow.species.lower() == 'human':
workflow.connect([
(datalad_get, segment, [(("in_file", _get_imgtype), "img_type")]),
])
else:
workflow.connect([
(norm, format_tpm_names, [('outputnode.out_tpms', 'in_files')]),
(format_tpm_names, segment, [(('file_format', _pop), 'prior_image')]),
(skull_stripping, segment, [("outputnode.out_mask", "mask_image")]),
])
# fmt: on

# Upload metrics
Expand Down Expand Up @@ -330,6 +280,30 @@ def spatial_normalization(name="SpatialNormalization"):
)
]

# Project MNI segmentation to T1 space
tpms_std2t1w = pe.MapNode(
ApplyTransforms(
dimension=3,
default_value=0,
interpolation="Linear",
float=config.execution.ants_float,
),
iterfield=["input_image"],
name="tpms_std2t1w",
)
tpms_std2t1w.inputs.input_image = [
str(p)
for p in get_template(
config.workflow.template_id,
suffix="probseg",
resolution=(
1 if config.workflow.species.lower() == "human"
else None
),
label=["CSF", "GM", "WM"],
)
]

# fmt: off
workflow.connect([
(inputnode, norm, [("moving_image", "moving_image"),
Expand All @@ -350,6 +324,95 @@ def spatial_normalization(name="SpatialNormalization"):
return workflow


def init_brain_tissue_segmentation(name="brain_tissue_segmentation"):
"""
Setup a workflow for brain tissue segmentation.
.. workflow::
from mriqc.workflows.anatomical.base import init_brain_tissue_segmentation
from mriqc.testing import mock_config
with mock_config():
wf = init_brain_tissue_segmentation()
"""
from nipype.interfaces.ants import Atropos

def _format_tpm_names(in_files, fname_string=None):
from pathlib import Path
import nibabel as nb
import glob

out_path = Path.cwd().absolute()

# copy files to cwd and rename iteratively
for count, fname in enumerate(in_files):
img = nb.load(fname)
extension = "".join(Path(fname).suffixes)
out_fname = f"priors_{1 + count:02}{extension}"
nb.save(img, Path(out_path, out_fname))

if fname_string is None:
fname_string = f"priors_%02d{extension}"

out_files = [
str(prior) for prior in glob.glob(str(Path(out_path, f"priors*{extension}")))
]

# return path with c-style format string for Atropos
file_format = str(Path(out_path, fname_string))
return file_format, out_files

workflow = pe.Workflow(name=name)
inputnode = pe.Node(
niu.IdentityInterface(fields=["in_file", "brainmask", "std_tpms"]),
name="inputnode",
)
outputnode = pe.Node(
niu.IdentityInterface(fields=["out_segm", "out_pvms"]),
name="outputnode",
)

format_tpm_names = pe.Node(
niu.Function(
input_names=["in_files"],
output_names=["file_format"],
function=_format_tpm_names,
execution={"keep_inputs": True, "remove_unnecessary_outputs": False},
),
name="format_tpm_names",
)

segment = pe.Node(
Atropos(
initialization="PriorProbabilityImages",
number_of_tissue_classes=3,
prior_weighting=0.1,
mrf_radius=[1, 1, 1],
mrf_smoothing_factor=0.01,
save_posteriors=True,
out_classified_image_name="segment.nii.gz",
output_posteriors_name_template="segment_%02d.nii.gz",
num_threads=config.nipype.omp_nthreads,
),
name="segmentation",
mem_gb=5,
num_threads=config.nipype.omp_nthreads,
)

# fmt: off
workflow.connect([
(inputnode, segment, [("in_file", "intensity_images"),
("brainmask", "mask_image")]),
(inputnode, format_tpm_names, [('std_tpms', 'in_files')]),
(format_tpm_names, segment, [(('file_format', _pop), 'prior_image')]),
(segment, outputnode, [("classified_image", "out_segm"),
("posteriors", "out_pvms")]),
])
# fmt: on
return workflow


def compute_iqms(name="ComputeIQMs"):
"""
Setup the workflow that actually computes the IQMs.
Expand Down Expand Up @@ -896,30 +959,6 @@ def _get_mod(in_file):
return Path(in_file).name.rstrip(".gz").rstrip(".nii").split("_")[-1]


def _format_tpm_names(in_files, fname_string=None):
from pathlib import Path
import nibabel as nb
import glob

out_path = Path.cwd().absolute()

# copy files to cwd and rename iteratively
for count, fname in enumerate(in_files):
img = nb.load(fname)
extension = "".join(Path(fname).suffixes)
out_fname = f"priors_{1 + count:02}{extension}"
nb.save(img, Path(out_path, out_fname))

if fname_string is None:
fname_string = f"priors_%02d{extension}"

out_files = [str(prior) for prior in glob.glob(str(Path(out_path, f"priors*{extension}")))]

# return path with c-style format string for Atropos
file_format = str(Path(out_path, fname_string))
return file_format, out_files


def _pop(inlist):
if isinstance(inlist, (list, tuple)):
return inlist[0]
Expand Down

0 comments on commit 8589e33

Please sign in to comment.