diff --git a/petprep/data/reports-spec.yml b/petprep/data/reports-spec.yml
index ba31269b..4e4ad7b0 100644
--- a/petprep/data/reports-spec.yml
+++ b/petprep/data/reports-spec.yml
@@ -128,7 +128,7 @@ sections:
subtitle: PET Summary and Carpet Plot
- bids: {datatype: figures, desc: hmc, suffix: pet}
- caption: Animated frames before and after PET head motion correction (keep cursor over image to restart).
+ caption: Animated frames before and after PET head motion correction with synchronized framewise displacement trace (keep cursor over image to restart). Below is a lineplot of the Framewise Displacement (FD) in mm per frame.
static: false
subtitle: Motion correction
diff --git a/petprep/interfaces/motion.py b/petprep/interfaces/motion.py
index 137d4071..ec9d50af 100644
--- a/petprep/interfaces/motion.py
+++ b/petprep/interfaces/motion.py
@@ -11,6 +11,7 @@
import nibabel as nib
import numpy as np
+import pandas as pd
from imageio import v2 as imageio
from nilearn import image
from nilearn.plotting import plot_epi
@@ -18,6 +19,7 @@
from nipype.interfaces.base import (
BaseInterfaceInputSpec,
File,
+ isdefined,
SimpleInterface,
TraitedSpec,
traits,
@@ -38,6 +40,7 @@ class MotionPlotInputSpec(BaseInterfaceInputSpec):
'transforms to the original data in native PET space'
),
)
+ fd_file = File(exists=True, desc='Confounds file containing framewise displacement')
duration = traits.Float(0.2, usedefault=True, desc='Frame duration for the GIF (seconds)')
@@ -69,6 +72,10 @@ def _run_interface(self, runtime):
)
_, _, vmin_corr, vmax_corr = self._compute_display_params(self.inputs.corrected_pet)
+ fd_values = None
+ if isdefined(self.inputs.fd_file):
+ fd_values = self._load_framewise_displacement(self.inputs.fd_file)
+
svg_file = self._build_animation(
output_path=svg_file,
cut_coords_orig=cut_coords_orig,
@@ -77,6 +84,7 @@ def _run_interface(self, runtime):
vmax_orig=vmax_orig,
vmin_corr=vmin_corr,
vmax_corr=vmax_corr,
+ fd_values=fd_values,
)
self._results['svg_file'] = str(svg_file)
@@ -97,6 +105,21 @@ def _compute_display_params(self, in_file: str):
return mid_img, cut_coords, vmin, vmax
+ def _load_framewise_displacement(self, fd_file: str) -> np.ndarray:
+ framewise_disp = pd.read_csv(fd_file, sep='\t')
+ if 'framewise_displacement' in framewise_disp:
+ fd_values = framewise_disp['framewise_displacement']
+ elif 'FD' in framewise_disp:
+ fd_values = framewise_disp['FD']
+ else:
+ available = ', '.join(framewise_disp.columns)
+ raise ValueError(
+ 'Could not find framewise displacement column in confounds file '
+ f'(available columns: {available})'
+ )
+
+ return np.asarray(fd_values.fillna(0.0), dtype=float)
+
def _build_animation(
self,
*,
@@ -107,6 +130,7 @@ def _build_animation(
vmax_orig: float,
vmin_corr: float,
vmax_corr: float,
+ fd_values: np.ndarray | None,
) -> Path:
orig_img = nib.load(self.inputs.original_pet)
corr_img = nib.load(self.inputs.corrected_pet)
@@ -116,6 +140,10 @@ def _build_animation(
corr_img.shape[-1] if corr_img.ndim > 3 else 1,
)
+ if fd_values is not None:
+ fd_values = np.asarray(fd_values[:n_frames], dtype=float)
+ n_frames = min(n_frames, len(fd_values))
+
with TemporaryDirectory() as tmpdir:
frames = []
for idx in range(n_frames):
@@ -162,7 +190,9 @@ def _build_animation(
frames.append(combined.astype(orig_arr.dtype, copy=False))
width = int(frames[0].shape[1])
- height = int(frames[0].shape[0])
+ frame_height = int(frames[0].shape[0])
+ fd_height = 220 if fd_values is not None else 0
+ height = frame_height + fd_height
total_duration = self.inputs.duration * n_frames
svg_parts = [
@@ -176,6 +206,11 @@ def _build_animation(
'}'
),
'.playing .frame {animation-play-state: running;}',
+ '.fd-line {fill: none; stroke: #2c7be5; stroke-width: 2;}',
+ '.fd-axis {stroke: #333; stroke-width: 1;}',
+ '.fd-point {fill: #2c7be5; stroke: white; stroke-width: 1;}',
+ '#fd-marker {fill: #d7263d; stroke: white; stroke-width: 2;}',
+ '#fd-value {font: 14px sans-serif; fill: #1a1a1a;}',
'@keyframes framefade {0%, 80% {opacity: 1;} 100% {opacity: 0;}}',
]
@@ -191,18 +226,87 @@ def _build_animation(
data_uri = b64encode(buffer.getvalue()).decode('ascii')
svg_parts.append(
f''
)
+ if fd_values is not None:
+ fd_padding = 45
+ fd_chart_height = fd_height
+ fd_x_start = fd_padding
+ fd_x_end = width - fd_padding
+ fd_axis_y = frame_height + fd_chart_height - fd_padding
+ fd_axis_y_top = frame_height + fd_padding
+ fd_y_range = fd_axis_y - fd_axis_y_top
+ fd_max = float(np.nanmax(fd_values)) if np.any(fd_values) else 0.0
+ if fd_max <= 0:
+ fd_max = 1.0
+
+ x_scale = (fd_x_end - fd_x_start) / max(n_frames - 1, 1)
+ points = []
+ point_elems = []
+ for idx, value in enumerate(fd_values):
+ x_coord = fd_x_start + x_scale * idx
+ y_coord = fd_axis_y - (value / fd_max) * fd_y_range
+ points.append(f'{x_coord:.2f},{y_coord:.2f}')
+ point_elems.append(
+ f''
+ )
+
+ fd_label_y = fd_axis_y_top + (fd_y_range / 2)
+
+ svg_parts.extend(
+ [
+ f'',
+ f'',
+ f'',
+ f'',
+ *point_elems,
+ f'',
+ f'',
+ f''
+ 'FD (mm)',
+ f''
+ 'Frames',
+ '',
+ ]
+ )
+
svg_parts.extend(
[
'