Skip to content

Commit

Permalink
FIX: Tag memory estimates in resamplers (#3150)
Browse files Browse the repository at this point in the history
We'll start with the baseline assumption that our resampler uses about
4*the original BOLD series size, which was generally a good estimate for
antsApplyTransform.

Also tagging a few things with `run_without_submitting` and some tasks
that are surely using more than the default amount with 1GB, which is at
least better. STC I assume uses 2x the total amount.

All of this could stand profiling, and I'm curious to try out
[memray](https://bloomberg.github.io/memray/) on the Python stuff at
least.
  • Loading branch information
effigies committed Nov 21, 2023
1 parent d9e92a9 commit e2c0fc0
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 11 deletions.
22 changes: 16 additions & 6 deletions fmriprep/workflows/bold/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
def init_bold_volumetric_resample_wf(
*,
metadata: dict,
mem_gb: dict[str, float],
fieldmap_id: str | None = None,
omp_nthreads: int = 1,
name: str = 'bold_volumetric_resample_wf',
Expand Down Expand Up @@ -119,9 +120,14 @@ def init_bold_volumetric_resample_wf(

gen_ref = pe.Node(GenerateSamplingReference(), name='gen_ref', mem_gb=0.3)

boldref2target = pe.Node(niu.Merge(2), name='boldref2target')
bold2target = pe.Node(niu.Merge(2), name='bold2target')
resample = pe.Node(ResampleSeries(), name="resample", n_procs=omp_nthreads)
boldref2target = pe.Node(niu.Merge(2), name='boldref2target', run_without_submitting=True)
bold2target = pe.Node(niu.Merge(2), name='bold2target', run_without_submitting=True)
resample = pe.Node(
ResampleSeries(),
name="resample",
n_procs=omp_nthreads,
mem_gb=mem_gb['resampled'],
)

workflow.connect([
(inputnode, gen_ref, [
Expand Down Expand Up @@ -156,10 +162,14 @@ def init_bold_volumetric_resample_wf(
name="distortion_params",
run_without_submitting=True,
)
fmap2target = pe.Node(niu.Merge(2), name='fmap2target')
inverses = pe.Node(niu.Function(function=_gen_inverses), name='inverses')
fmap2target = pe.Node(niu.Merge(2), name='fmap2target', run_without_submitting=True)
inverses = pe.Node(
niu.Function(function=_gen_inverses),
name='inverses',
run_without_submitting=True,
)

fmap_recon = pe.Node(ReconstructFieldmap(), name="fmap_recon")
fmap_recon = pe.Node(ReconstructFieldmap(), name="fmap_recon", mem_gb=1)

workflow.connect([
(inputnode, fmap_select, [
Expand Down
5 changes: 4 additions & 1 deletion fmriprep/workflows/bold/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ def init_bold_wf(
metadata=all_metadata[0],
fieldmap_id=fieldmap_id if not multiecho else None,
omp_nthreads=omp_nthreads,
mem_gb=mem_gb,
name='bold_anat_wf',
)
bold_anat_wf.inputs.inputnode.resolution = "native"
Expand Down Expand Up @@ -446,6 +447,7 @@ def init_bold_wf(
metadata=all_metadata[0],
fieldmap_id=fieldmap_id if not multiecho else None,
omp_nthreads=omp_nthreads,
mem_gb=mem_gb,
name='bold_std_wf',
)
ds_bold_std_wf = init_ds_volumes_wf(
Expand Down Expand Up @@ -525,6 +527,7 @@ def init_bold_wf(
metadata=all_metadata[0],
fieldmap_id=fieldmap_id if not multiecho else None,
omp_nthreads=omp_nthreads,
mem_gb=mem_gb,
name='bold_MNI6_wf',
)

Expand All @@ -537,7 +540,7 @@ def init_bold_wf(

bold_grayords_wf = init_bold_grayords_wf(
grayord_density=config.workflow.cifti_output,
mem_gb=mem_gb["resampled"],
mem_gb=1,
repetition_time=all_metadata[0]["RepetitionTime"],
)

Expand Down
12 changes: 9 additions & 3 deletions fmriprep/workflows/bold/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,7 +789,7 @@ def init_bold_native_wf(

# Slice-timing correction
if run_stc:
bold_stc_wf = init_bold_stc_wf(name="bold_stc_wf", metadata=metadata)
bold_stc_wf = init_bold_stc_wf(metadata=metadata, mem_gb=mem_gb)
workflow.connect([
(inputnode, bold_stc_wf, [("dummy_scans", "inputnode.skip_vols")]),
(validate_bold, bold_stc_wf, [("out_file", "inputnode.bold_file")]),
Expand Down Expand Up @@ -824,7 +824,12 @@ def init_bold_native_wf(
]) # fmt:skip

# Resample to boldref
boldref_bold = pe.Node(ResampleSeries(), name="boldref_bold", n_procs=omp_nthreads)
boldref_bold = pe.Node(
ResampleSeries(),
name="boldref_bold",
n_procs=omp_nthreads,
mem_gb=mem_gb["resampled"],
)

workflow.connect([
(inputnode, boldref_bold, [
Expand All @@ -839,7 +844,7 @@ def init_bold_native_wf(
]) # fmt:skip

if fieldmap_id:
boldref_fmap = pe.Node(ReconstructFieldmap(inverse=[True]), name="boldref_fmap")
boldref_fmap = pe.Node(ReconstructFieldmap(inverse=[True]), name="boldref_fmap", mem_gb=1)
workflow.connect([
(inputnode, boldref_fmap, [
("boldref", "target_ref_file"),
Expand All @@ -858,6 +863,7 @@ def init_bold_native_wf(
joinsource="echo_index",
joinfield=["bold_files"],
name="join_echos",
run_without_submitting=True,
)

# create optimal combination, adaptive T2* map
Expand Down
3 changes: 3 additions & 0 deletions fmriprep/workflows/bold/resampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,12 +661,14 @@ def init_bold_fsLR_resampling_wf(
metric_dilate = pe.Node(
MetricDilate(distance=10, nearest=True),
name="metric_dilate",
mem_gb=1,
n_procs=omp_nthreads,
)
mask_native = pe.Node(MetricMask(), name="mask_native")
resample_to_fsLR = pe.Node(
MetricResample(method='ADAP_BARY_AREA', area_surfs=True),
name="resample_to_fsLR",
mem_gb=1,
n_procs=omp_nthreads,
)
# ... line 89
Expand Down Expand Up @@ -812,6 +814,7 @@ def init_bold_grayords_wf(
grayordinates=grayord_density,
),
name="gen_cifti",
mem_gb=mem_gb,
)

workflow.connect([
Expand Down
8 changes: 7 additions & 1 deletion fmriprep/workflows/bold/stc.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,12 @@ def _pre_run_hook(self, runtime):
return runtime


def init_bold_stc_wf(metadata: dict, name='bold_stc_wf'):
def init_bold_stc_wf(
*,
mem_gb: dict,
metadata: dict,
name='bold_stc_wf',
):
"""
Create a workflow for :abbr:`STC (slice-timing correction)`.
Expand Down Expand Up @@ -119,6 +124,7 @@ def init_bold_stc_wf(metadata: dict, name='bold_stc_wf'):
slice_encoding_direction=metadata.get('SliceEncodingDirection', 'k'),
tzero=tzero,
),
mem_gb=mem_gb['filesize'] * 2,
name='slice_timing_correction',
)

Expand Down

0 comments on commit e2c0fc0

Please sign in to comment.