diff --git a/petprep/data/reports-spec.yml b/petprep/data/reports-spec.yml index ba31269b..4e4ad7b0 100644 --- a/petprep/data/reports-spec.yml +++ b/petprep/data/reports-spec.yml @@ -128,7 +128,7 @@ sections: subtitle: PET Summary and Carpet Plot - bids: {datatype: figures, desc: hmc, suffix: pet} - caption: Animated frames before and after PET head motion correction (keep cursor over image to restart). + caption: Animated frames before and after PET head motion correction with synchronized framewise displacement trace (keep cursor over image to restart). Below is a lineplot of the Framewise Displacement (FD) in mm per frame. static: false subtitle: Motion correction diff --git a/petprep/interfaces/motion.py b/petprep/interfaces/motion.py index 137d4071..ec9d50af 100644 --- a/petprep/interfaces/motion.py +++ b/petprep/interfaces/motion.py @@ -11,6 +11,7 @@ import nibabel as nib import numpy as np +import pandas as pd from imageio import v2 as imageio from nilearn import image from nilearn.plotting import plot_epi @@ -18,6 +19,7 @@ from nipype.interfaces.base import ( BaseInterfaceInputSpec, File, + isdefined, SimpleInterface, TraitedSpec, traits, @@ -38,6 +40,7 @@ class MotionPlotInputSpec(BaseInterfaceInputSpec): 'transforms to the original data in native PET space' ), ) + fd_file = File(exists=True, desc='Confounds file containing framewise displacement') duration = traits.Float(0.2, usedefault=True, desc='Frame duration for the GIF (seconds)') @@ -69,6 +72,10 @@ def _run_interface(self, runtime): ) _, _, vmin_corr, vmax_corr = self._compute_display_params(self.inputs.corrected_pet) + fd_values = None + if isdefined(self.inputs.fd_file): + fd_values = self._load_framewise_displacement(self.inputs.fd_file) + svg_file = self._build_animation( output_path=svg_file, cut_coords_orig=cut_coords_orig, @@ -77,6 +84,7 @@ def _run_interface(self, runtime): vmax_orig=vmax_orig, vmin_corr=vmin_corr, vmax_corr=vmax_corr, + fd_values=fd_values, ) self._results['svg_file'] = str(svg_file) @@ -97,6 +105,21 @@ def _compute_display_params(self, in_file: str): return mid_img, cut_coords, vmin, vmax + def _load_framewise_displacement(self, fd_file: str) -> np.ndarray: + framewise_disp = pd.read_csv(fd_file, sep='\t') + if 'framewise_displacement' in framewise_disp: + fd_values = framewise_disp['framewise_displacement'] + elif 'FD' in framewise_disp: + fd_values = framewise_disp['FD'] + else: + available = ', '.join(framewise_disp.columns) + raise ValueError( + 'Could not find framewise displacement column in confounds file ' + f'(available columns: {available})' + ) + + return np.asarray(fd_values.fillna(0.0), dtype=float) + def _build_animation( self, *, @@ -107,6 +130,7 @@ def _build_animation( vmax_orig: float, vmin_corr: float, vmax_corr: float, + fd_values: np.ndarray | None, ) -> Path: orig_img = nib.load(self.inputs.original_pet) corr_img = nib.load(self.inputs.corrected_pet) @@ -116,6 +140,10 @@ def _build_animation( corr_img.shape[-1] if corr_img.ndim > 3 else 1, ) + if fd_values is not None: + fd_values = np.asarray(fd_values[:n_frames], dtype=float) + n_frames = min(n_frames, len(fd_values)) + with TemporaryDirectory() as tmpdir: frames = [] for idx in range(n_frames): @@ -162,7 +190,9 @@ def _build_animation( frames.append(combined.astype(orig_arr.dtype, copy=False)) width = int(frames[0].shape[1]) - height = int(frames[0].shape[0]) + frame_height = int(frames[0].shape[0]) + fd_height = 220 if fd_values is not None else 0 + height = frame_height + fd_height total_duration = self.inputs.duration * n_frames svg_parts = [ @@ -176,6 +206,11 @@ def _build_animation( '}' ), '.playing .frame {animation-play-state: running;}', + '.fd-line {fill: none; stroke: #2c7be5; stroke-width: 2;}', + '.fd-axis {stroke: #333; stroke-width: 1;}', + '.fd-point {fill: #2c7be5; stroke: white; stroke-width: 1;}', + '#fd-marker {fill: #d7263d; stroke: white; stroke-width: 2;}', + '#fd-value {font: 14px sans-serif; fill: #1a1a1a;}', '@keyframes framefade {0%, 80% {opacity: 1;} 100% {opacity: 0;}}', ] @@ -191,18 +226,87 @@ def _build_animation( data_uri = b64encode(buffer.getvalue()).decode('ascii') svg_parts.append( f'' ) + if fd_values is not None: + fd_padding = 45 + fd_chart_height = fd_height + fd_x_start = fd_padding + fd_x_end = width - fd_padding + fd_axis_y = frame_height + fd_chart_height - fd_padding + fd_axis_y_top = frame_height + fd_padding + fd_y_range = fd_axis_y - fd_axis_y_top + fd_max = float(np.nanmax(fd_values)) if np.any(fd_values) else 0.0 + if fd_max <= 0: + fd_max = 1.0 + + x_scale = (fd_x_end - fd_x_start) / max(n_frames - 1, 1) + points = [] + point_elems = [] + for idx, value in enumerate(fd_values): + x_coord = fd_x_start + x_scale * idx + y_coord = fd_axis_y - (value / fd_max) * fd_y_range + points.append(f'{x_coord:.2f},{y_coord:.2f}') + point_elems.append( + f'' + ) + + fd_label_y = fd_axis_y_top + (fd_y_range / 2) + + svg_parts.extend( + [ + f'', + f'', + f'', + f'', + *point_elems, + f'', + f'', + f'' + 'FD (mm)', + f'' + 'Frames', + '', + ] + ) + svg_parts.extend( [ '