diff --git a/petprep/cli/parser.py b/petprep/cli/parser.py
index 50114f31..7c470800 100644
--- a/petprep/cli/parser.py
+++ b/petprep/cli/parser.py
@@ -188,8 +188,13 @@ def _bids_filter(value, parser):
         'identifier (the sub- prefix can be removed)',
     )
     # Re-enable when option is actually implemented
-    # g_bids.add_argument('-s', '--session-id', action='store', default='single_session',
-    #                     help='Select a specific session to be processed')
+    g_bids.add_argument(
+        '--session-label',
+        nargs='+',
+        type=lambda label: label.removeprefix('ses-'),
+        help='A space delimited list of session identifiers or a single '
+        'identifier (the ses- prefix can be removed)',
+    )
     # Re-enable when option is actually implemented
     # g_bids.add_argument('-r', '--run-id', action='store', default='single_run',
     #                     help='Select a specific run to be processed')
@@ -749,6 +754,13 @@ def parse_args(args=None, namespace=None):
     config.execution.log_level = int(max(25 - 5 * opts.verbose_count, logging.DEBUG))
     config.from_dict(vars(opts), init=['nipype'])
 
+    if config.execution.session_label:
+        config.execution.bids_filters = config.execution.bids_filters or {}
+        config.execution.bids_filters['pet'] = {
+            **config.execution.bids_filters.get('pet', {}),
+            'session': config.execution.session_label,
+        }
+
     pvc_vals = (opts.pvc_tool, opts.pvc_method, opts.pvc_psf)
     if any(val is not None for val in pvc_vals) and not all(val is not None for val in pvc_vals):
         parser.error('Options --pvc-tool, --pvc-method and --pvc-psf must be used together.')
@@ -908,5 +920,16 @@ def parse_args(args=None, namespace=None):
             f'One or more participant labels were not found in the BIDS directory: {", ".join(missing_subjects)}.'
         )
 
+    if config.execution.session_label:
+        available_sessions = set(
+            config.execution.layout.get_sessions(subject=list(participant_label) or None)
+        )
+        missing_sessions = set(config.execution.session_label) - available_sessions
+        if missing_sessions:
+            parser.error(
+                'One or more session labels were not found in the BIDS directory: '
+                f'{", ".join(sorted(missing_sessions))}.'
+            )
+
     config.execution.participant_label = sorted(participant_label)
     config.workflow.skull_strip_template = config.workflow.skull_strip_template[0]
diff --git a/petprep/cli/tests/test_parser.py b/petprep/cli/tests/test_parser.py
index 78d47ebb..d259f6f8 100644
--- a/petprep/cli/tests/test_parser.py
+++ b/petprep/cli/tests/test_parser.py
@@ -24,6 +24,8 @@
 
 from argparse import ArgumentError
 
+import nibabel as nb
+import numpy as np
 import pytest
 from packaging.version import Version
 
@@ -225,6 +227,45 @@ def test_derivatives(tmp_path):
     _reset_config()
 
 
+def test_session_label_only_filters_pet(tmp_path):
+    bids = tmp_path / 'bids'
+    out_dir = tmp_path / 'out'
+    work_dir = tmp_path / 'work'
+    bids.mkdir()
+    (bids / 'dataset_description.json').write_text('{"Name": "Test", "BIDSVersion": "1.8.0"}')
+
+    anat_path = bids / 'sub-01' / 'anat' / 'sub-01_T1w.nii.gz'
+    anat_path.parent.mkdir(parents=True, exist_ok=True)
+    nb.Nifti1Image(np.zeros((5, 5, 5)), np.eye(4)).to_filename(anat_path)
+
+    pet_path = bids / 'sub-01' / 'ses-blocked' / 'pet' / 'sub-01_ses-blocked_pet.nii.gz'
+    pet_path.parent.mkdir(parents=True, exist_ok=True)
+    nb.Nifti1Image(np.zeros((5, 5, 5, 1)), np.eye(4)).to_filename(pet_path)
+    (pet_path.with_suffix('').with_suffix('.json')).write_text(
+        '{"FrameTimesStart": [0], "FrameDuration": [1]}'
+    )
+
+    try:
+        parse_args(
+            args=[
+                str(bids),
+                str(out_dir),
+                'participant',
+                '--session-label',
+                'blocked',
+                '--skip-bids-validation',
+                '-w',
+                str(work_dir),
+            ]
+        )
+
+        filters = config.execution.bids_filters
+        assert filters.get('pet', {}).get('session') == ['blocked']
+        assert 'session' not in filters.get('anat', {})
+    finally:
+        _reset_config()
+
+
 def test_pvc_argument_handling(tmp_path, minimal_bids):
     out_dir = tmp_path / 'out'
     work_dir = tmp_path / 'work'
diff --git a/petprep/config.py b/petprep/config.py
index c61319b1..2cfc100b 100644
--- a/petprep/config.py
+++ b/petprep/config.py
@@ -432,6 +432,8 @@ class execution(_Config):
     """Unique identifier of this particular run."""
     participant_label = None
     """List of participant identifiers that are to be preprocessed."""
+    session_label = None
+    """List of session identifiers that are to be preprocessed."""
     task_id = None
     """Select a particular task from all available in the dataset."""
     templateflow_home = _templateflow_home
diff --git a/petprep/data/reports-spec-pet.yml b/petprep/data/reports-spec-pet.yml
index 1ec41495..580e9f40 100644
--- a/petprep/data/reports-spec-pet.yml
+++ b/petprep/data/reports-spec-pet.yml
@@ -6,6 +6,7 @@ sections:
   reportlets:
   - bids: {datatype: figures, desc: summary, suffix: pet}
   - bids: {datatype: figures, desc: validation, suffix: pet}
+  - bids: {datatype: figures, desc: hmc, suffix: pet}
   - bids: {datatype: figures, desc: carpetplot, suffix: pet}
   - bids: {datatype: figures, desc: confoundcorr, suffix: pet}
   - bids: {datatype: figures, desc: coreg, suffix: pet}
diff --git a/petprep/data/reports-spec.yml b/petprep/data/reports-spec.yml
index 8d57713d..66853ce9 100644
--- a/petprep/data/reports-spec.yml
+++ b/petprep/data/reports-spec.yml
@@ -127,6 +127,11 @@ sections:
     static: false
    subtitle: PET Summary and Carpet Plot
 
+  - bids: {datatype: figures, desc: hmc, suffix: pet}
+    caption: Animated frames before and after PET head motion correction with a synchronized framewise displacement trace (hover the cursor over the image to restart). Below is a line plot of the framewise displacement (FD) in mm per frame. Red line segments between points indicate frames where FD exceeds 3 mm, suggesting motion that may impact data quality.
+    static: false
+    subtitle: Motion correction
+
   - bids: {datatype: figures, desc: confoundcorr, suffix: pet}
     caption: |
       Left: Correlation heatmap illustrating relationships among PET-derived confound variables (e.g., motion parameters, global signal).
diff --git a/petprep/interfaces/__init__.py b/petprep/interfaces/__init__.py
index af99d939..56c9b9ea 100644
--- a/petprep/interfaces/__init__.py
+++ b/petprep/interfaces/__init__.py
@@ -3,6 +3,7 @@
 from niworkflows.interfaces.bids import DerivativesDataSink as _DDSink
 
 from .cifti import GeneratePetCifti
+from .motion import MotionPlot
 from .tacs import ExtractRefTAC, ExtractTACs
 
 
@@ -15,4 +16,5 @@ class DerivativesDataSink(_DDSink):
     'GeneratePetCifti',
     'ExtractTACs',
     'ExtractRefTAC',
+    'MotionPlot',
 )
diff --git a/petprep/interfaces/motion.py b/petprep/interfaces/motion.py
new file mode 100644
index 00000000..35fc02c0
--- /dev/null
+++ b/petprep/interfaces/motion.py
@@ -0,0 +1,418 @@
+# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
+# vi: set ft=python sts=4 ts=4 sw=4 et:
+"""Reportlets illustrating motion correction."""
+
+from __future__ import annotations
+
+from base64 import b64encode
+from io import BytesIO
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+import nibabel as nib
+import numpy as np
+import pandas as pd
+from imageio import v2 as imageio
+from nilearn import image
+from nilearn.plotting import plot_epi
+from nilearn.plotting.find_cuts import find_xyz_cut_coords
+from nipype.interfaces.base import (
+    BaseInterfaceInputSpec,
+    File,
+    SimpleInterface,
+    TraitedSpec,
+    isdefined,
+    traits,
+)
+
+
+class MotionPlotInputSpec(BaseInterfaceInputSpec):
+    original_pet = File(
+        exists=True,
+        mandatory=True,
+        desc='Original (uncorrected) PET series in native PET space',
+    )
+    corrected_pet = File(
+        exists=True,
+        mandatory=True,
+        desc=(
+            'Motion-corrected PET series derived by applying the estimated motion '
+            'transforms to the original data in native PET space'
+        ),
+    )
+    fd_file = File(exists=True, desc='Confounds file containing framewise displacement')
+    duration = traits.Float(0.2, usedefault=True, desc='Frame duration for the animation (seconds)')
+
+
+class MotionPlotOutputSpec(TraitedSpec):
+    svg_file = File(exists=True, desc='Animated before/after motion correction SVG')
+
+
+class MotionPlot(SimpleInterface):
+    """Generate animated visualizations before and after motion correction.
+
+    A single animated SVG is created using ortho views with consistent cut
+    coordinates and color scaling derived from the midpoint frame of each series.
+    The per-frame views of the original and motion-corrected series are concatenated
+    horizontally, allowing the main PET report to display the animation
+    directly.
+ """ + + input_spec = MotionPlotInputSpec + output_spec = MotionPlotOutputSpec + + def _run_interface(self, runtime): + runtime.cwd = Path(runtime.cwd) + + svg_file = runtime.cwd / 'pet_motion_hmc.svg' + svg_file.parent.mkdir(parents=True, exist_ok=True) + + mid_orig, cut_coords_orig, vmin_orig, vmax_orig = self._compute_display_params( + self.inputs.original_pet + ) + _, _, vmin_corr, vmax_corr = self._compute_display_params(self.inputs.corrected_pet) + + fd_values = None + if isdefined(self.inputs.fd_file): + fd_values = self._load_framewise_displacement(self.inputs.fd_file) + + svg_file = self._build_animation( + output_path=svg_file, + cut_coords_orig=cut_coords_orig, + cut_coords_corr=cut_coords_orig, + vmin_orig=vmin_orig, + vmax_orig=vmax_orig, + vmin_corr=vmin_corr, + vmax_corr=vmax_corr, + fd_values=fd_values, + ) + + self._results['svg_file'] = str(svg_file) + + return runtime + + def _compute_display_params(self, in_file: str): + img = nib.load(in_file) + if img.ndim == 3: + mid_img = img + else: + mid_img = image.index_img(in_file, img.shape[-1] // 2) + + data = mid_img.get_fdata().astype(float) + vmax = float(np.percentile(data.flatten(), 99.9)) + vmin = float(np.percentile(data.flatten(), 80)) + cut_coords = find_xyz_cut_coords(mid_img) + + return mid_img, cut_coords, vmin, vmax + + def _load_framewise_displacement(self, fd_file: str) -> np.ndarray: + framewise_disp = pd.read_csv(fd_file, sep='\t') + if 'framewise_displacement' in framewise_disp: + fd_values = framewise_disp['framewise_displacement'] + elif 'FD' in framewise_disp: + fd_values = framewise_disp['FD'] + else: + available = ', '.join(framewise_disp.columns) + raise ValueError( + 'Could not find framewise displacement column in confounds file ' + f'(available columns: {available})' + ) + + return np.asarray(fd_values.fillna(0.0), dtype=float) + + def _build_animation( + self, + *, + output_path: Path, + cut_coords_orig: tuple[float, float, float], + cut_coords_corr: tuple[float, float, float], + vmin_orig: float, + vmax_orig: float, + vmin_corr: float, + vmax_corr: float, + fd_values: np.ndarray | None, + ) -> Path: + orig_img = nib.load(self.inputs.original_pet) + corr_img = nib.load(self.inputs.corrected_pet) + + n_frames = min( + orig_img.shape[-1] if orig_img.ndim > 3 else 1, + corr_img.shape[-1] if corr_img.ndim > 3 else 1, + ) + + if fd_values is not None: + fd_values = np.asarray(fd_values[:n_frames], dtype=float) + n_frames = min(n_frames, len(fd_values)) + + with TemporaryDirectory() as tmpdir: + frames = [] + for idx in range(n_frames): + orig_png = Path(tmpdir) / f'orig_{idx:04d}.png' + corr_png = Path(tmpdir) / f'corr_{idx:04d}.png' + + plot_epi( + image.index_img(self.inputs.original_pet, idx), + colorbar=True, + display_mode='ortho', + title=f'Before motion correction | Frame {idx + 1}', + cut_coords=cut_coords_orig, + vmin=vmin_orig, + vmax=vmax_orig, + output_file=str(orig_png), + ) + plot_epi( + image.index_img(self.inputs.corrected_pet, idx), + colorbar=True, + display_mode='ortho', + title=f'After motion correction | Frame {idx + 1}', + cut_coords=cut_coords_corr, + vmin=vmin_corr, + vmax=vmax_corr, + output_file=str(corr_png), + ) + + orig_arr = np.asarray(imageio.imread(orig_png)) + corr_arr = np.asarray(imageio.imread(corr_png)) + + max_height = max(orig_arr.shape[0], corr_arr.shape[0]) + if orig_arr.shape[0] < max_height: + pad = max_height - orig_arr.shape[0] + orig_arr = np.pad( + orig_arr, ((0, pad), (0, 0), (0, 0)), mode='constant', constant_values=255 + ) + if corr_arr.shape[0] < 
max_height: + pad = max_height - corr_arr.shape[0] + corr_arr = np.pad( + corr_arr, ((0, pad), (0, 0), (0, 0)), mode='constant', constant_values=255 + ) + + combined = np.concatenate([orig_arr, corr_arr], axis=1) + frames.append(combined.astype(orig_arr.dtype, copy=False)) + + width = int(frames[0].shape[1]) + frame_height = int(frames[0].shape[0]) + fd_height = 220 if fd_values is not None else 0 + height = frame_height + fd_height + total_duration = self.inputs.duration * n_frames + + svg_parts = [ + '', + '') + + for idx, frame in enumerate(frames): + buffer = BytesIO() + imageio.imwrite(buffer, frame, format='PNG') + data_uri = b64encode(buffer.getvalue()).decode('ascii') + svg_parts.append( + f'' + ) + + if fd_values is not None: + fd_padding = 45 + fd_chart_height = fd_height + fd_x_start = fd_padding + fd_x_end = width - fd_padding + fd_axis_y = frame_height + fd_chart_height - fd_padding + fd_axis_y_top = frame_height + fd_padding + fd_y_range = fd_axis_y - fd_axis_y_top + fd_max = float(np.nanmax(fd_values)) if np.any(fd_values) else 0.0 + if fd_max <= 0: + fd_max = 1.0 + + x_scale = (fd_x_end - fd_x_start) / max(n_frames - 1, 1) + points = [] + point_elems = [] + line_elems = [] + fd_threshold = 3.0 + for idx, value in enumerate(fd_values): + x_coord = fd_x_start + x_scale * idx + y_coord = fd_axis_y - (value / fd_max) * fd_y_range + points.append(f'{x_coord:.2f},{y_coord:.2f}') + point_elems.append( + f'' + ) + if idx > 0: + prev_x, prev_y = map(float, points[idx - 1].split(',')) + line_class = ( + 'fd-line-alert' if value >= fd_threshold else 'fd-line-primary' + ) + line_elems.append( + f'' + ) + + fd_label_y = fd_axis_y_top + (fd_y_range / 2) + fd_label_offset = 35 + + tick_values = np.linspace(0, fd_max, num=3) + tick_length = 6 + tick_elems = [] + label_elems = [] + for tick_value in tick_values: + y_coord = fd_axis_y - (tick_value / fd_max) * fd_y_range + tick_elems.append( + f'' + ) + label_elems.append( + f'' + f'{tick_value:.1f}' + ) + + # X-axis ticks show every other frame (plus the last) to avoid clutter + if n_frames <= 1: + x_tick_indices = np.array([0]) + else: + tick_stride = 2 + x_tick_indices = np.arange(0, n_frames, tick_stride) + if x_tick_indices[-1] != n_frames - 1: + x_tick_indices = np.append(x_tick_indices, n_frames - 1) + + x_tick_length = 6 + x_tick_elems = [] + x_label_elems = [] + for tick_idx in x_tick_indices: + x_coord = fd_x_start + x_scale * tick_idx + x_tick_elems.append( + f'' + ) + x_label_elems.append( + f'' + f'{tick_idx + 1}' + ) + + svg_parts.extend( + [ + '', + f'', + f'', + *tick_elems, + *label_elems, + *line_elems, + *point_elems, + *x_tick_elems, + *x_label_elems, + f'', + f'', + f'' + 'FD (mm)', + f'' + 'Frames', + '', + ] + ) + + svg_parts.extend( + [ + '', + '', + ] + ) + + output_path.write_text('\n'.join(svg_parts), encoding='utf-8') + + return output_path + + +__all__ = ['MotionPlot'] diff --git a/petprep/interfaces/tests/test_motion.py b/petprep/interfaces/tests/test_motion.py new file mode 100644 index 00000000..3acdb95d --- /dev/null +++ b/petprep/interfaces/tests/test_motion.py @@ -0,0 +1,101 @@ +from pathlib import Path + +import nibabel as nb +import numpy as np +import pytest + +from petprep.interfaces.motion import MotionPlot + + +def _write_image(path: Path, shape): + data = np.linspace(0, 1, int(np.prod(shape)), dtype=float).reshape(shape) + img = nb.Nifti1Image(data, np.eye(4)) + img.to_filename(path) + return path + + +def test_motion_plot_builds_svg(tmp_path, monkeypatch): + orig_path = _write_image(tmp_path / 
'orig.nii.gz', (4, 4, 4, 2)) + corr_path = _write_image(tmp_path / 'corr.nii.gz', (4, 4, 4, 2)) + + call_count = {'count': 0} + + def fake_plot_epi(img, **kwargs): + height = 10 if call_count['count'] % 2 == 0 else 6 + array = np.ones((height, 8, 3), dtype=np.uint8) * 255 + from imageio import v2 as imageio + + imageio.imwrite(kwargs['output_file'], array) + call_count['count'] += 1 + + monkeypatch.setattr('petprep.interfaces.motion.plot_epi', fake_plot_epi) + + motion = MotionPlot() + motion.inputs.original_pet = str(orig_path) + motion.inputs.corrected_pet = str(corr_path) + motion.inputs.duration = 0.05 + + result = motion.run(cwd=tmp_path) + svg_file = Path(result.outputs.svg_file) + + content = svg_file.read_text() + assert 'frame-0' in content + assert 'animation-delay: 0.05s' in content + assert call_count['count'] == 4 + + +def test_compute_display_params_handles_single_frame(tmp_path): + img_path = _write_image(tmp_path / 'single.nii.gz', (5, 5, 5)) + + motion = MotionPlot() + mid_img, cut_coords, vmin, vmax = motion._compute_display_params(str(img_path)) + + assert mid_img.ndim == 3 + assert len(cut_coords) == 3 + assert vmin <= vmax + + +def test_load_framewise_displacement_variants(tmp_path): + fd_path = tmp_path / 'fd.tsv' + fd_path.write_text('framewise_displacement\n0.1\n0.2\n') + + motion = MotionPlot() + values = motion._load_framewise_displacement(str(fd_path)) + assert np.allclose(values, [0.1, 0.2]) + + fd_path.write_text('FD\n0.0\n') + values = motion._load_framewise_displacement(str(fd_path)) + assert np.allclose(values, [0.0]) + + fd_path.write_text('other\n1.0\n') + with pytest.raises(ValueError): + motion._load_framewise_displacement(str(fd_path)) + + +def test_build_animation_includes_fd_plot(tmp_path, monkeypatch): + orig_path = _write_image(tmp_path / 'orig.nii.gz', (4, 4, 4, 3)) + corr_path = _write_image(tmp_path / 'corr.nii.gz', (4, 4, 4, 3)) + fd_path = tmp_path / 'fd.tsv' + fd_path.write_text('FD\n0\n0\n') + + def fake_plot_epi(img, **kwargs): + array = np.ones((8, 8, 3), dtype=np.uint8) * 255 + from imageio import v2 as imageio + + imageio.imwrite(kwargs['output_file'], array) + + monkeypatch.setattr('petprep.interfaces.motion.plot_epi', fake_plot_epi) + + motion = MotionPlot() + motion.inputs.original_pet = str(orig_path) + motion.inputs.corrected_pet = str(corr_path) + motion.inputs.fd_file = str(fd_path) + motion.inputs.duration = 0.01 + + result = motion.run(cwd=tmp_path) + svg_file = Path(result.outputs.svg_file) + content = svg_file.read_text() + + assert 'fd-plot' in content + assert 'FD (mm)' in content + assert 'frame-2' not in content # limited to FD length diff --git a/petprep/workflows/pet/base.py b/petprep/workflows/pet/base.py index 7d58a9eb..934b1bcc 100644 --- a/petprep/workflows/pet/base.py +++ b/petprep/workflows/pet/base.py @@ -39,7 +39,7 @@ from niworkflows.utils.connections import listify from ... 
import config -from ...interfaces import DerivativesDataSink +from ...interfaces import DerivativesDataSink, MotionPlot from ...utils.misc import estimate_pet_mem_usage # PET workflows @@ -275,6 +275,39 @@ def init_pet_wf( ]), ]) # fmt:skip + if nvols > 1: + motion_report = pe.Node(MotionPlot(), name='motion_report', mem_gb=0.1) + ds_motion_report = pe.Node( + DerivativesDataSink( + base_directory=petprep_dir, + desc='hmc', + datatype='figures', + suffix='pet', + ), + name='ds_report_motion', + run_without_submitting=True, + mem_gb=config.DEFAULT_MEMORY_MIN_GB, + ) + ds_motion_report.inputs.source_file = pet_file + + workflow.connect( + [ + ( + pet_native_wf, + motion_report, + [ + ('outputnode.pet_minimal', 'original_pet'), + ('outputnode.pet_native', 'corrected_pet'), + ], + ), + (motion_report, ds_motion_report, [('svg_file', 'in_file')]), + ] + ) + else: + config.loggers.workflow.warning( + f'Motion report will be skipped - series has only {nvols} frame(s)' + ) + petref_out = bool(nonstd_spaces.intersection(('pet', 'run', 'petref'))) petref_out &= config.workflow.level == 'full' @@ -791,6 +824,14 @@ def init_pet_wf( ]), ]) # fmt:skip + if nvols > 1: + workflow.connect( + pet_confounds_wf, + 'outputnode.confounds_file', + motion_report, + 'fd_file', + ) + if spaces.get_spaces(nonstandard=False, dim=(3,)): carpetplot_wf = init_carpetplot_wf( mem_gb=mem_gb['resampled'],