Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion petprep/data/reports-spec.yml
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ sections:
subtitle: PET Summary and Carpet Plot

- bids: {datatype: figures, desc: hmc, suffix: pet}
caption: Animated frames before and after PET head motion correction (keep cursor over image to restart).
caption: Animated frames before and after PET head motion correction with synchronized framewise displacement trace (keep cursor over image to restart). Below is a lineplot of the Framewise Displacement (FD) in mm per frame.
static: false
subtitle: Motion correction

Expand Down
120 changes: 118 additions & 2 deletions petprep/interfaces/motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@

import nibabel as nib
import numpy as np
import pandas as pd
from imageio import v2 as imageio
from nilearn import image
from nilearn.plotting import plot_epi
from nilearn.plotting.find_cuts import find_xyz_cut_coords
from nipype.interfaces.base import (
BaseInterfaceInputSpec,
File,
isdefined,
SimpleInterface,
TraitedSpec,
traits,
Expand All @@ -38,6 +40,7 @@ class MotionPlotInputSpec(BaseInterfaceInputSpec):
'transforms to the original data in native PET space'
),
)
fd_file = File(exists=True, desc='Confounds file containing framewise displacement')
duration = traits.Float(0.2, usedefault=True, desc='Frame duration for the GIF (seconds)')


Expand Down Expand Up @@ -69,6 +72,10 @@ def _run_interface(self, runtime):
)
_, _, vmin_corr, vmax_corr = self._compute_display_params(self.inputs.corrected_pet)

fd_values = None
if isdefined(self.inputs.fd_file):
fd_values = self._load_framewise_displacement(self.inputs.fd_file)

svg_file = self._build_animation(
output_path=svg_file,
cut_coords_orig=cut_coords_orig,
Expand All @@ -77,6 +84,7 @@ def _run_interface(self, runtime):
vmax_orig=vmax_orig,
vmin_corr=vmin_corr,
vmax_corr=vmax_corr,
fd_values=fd_values,
)

self._results['svg_file'] = str(svg_file)
Expand All @@ -97,6 +105,21 @@ def _compute_display_params(self, in_file: str):

return mid_img, cut_coords, vmin, vmax

def _load_framewise_displacement(self, fd_file: str) -> np.ndarray:
framewise_disp = pd.read_csv(fd_file, sep='\t')
if 'framewise_displacement' in framewise_disp:
fd_values = framewise_disp['framewise_displacement']
elif 'FD' in framewise_disp:
fd_values = framewise_disp['FD']
else:
available = ', '.join(framewise_disp.columns)
raise ValueError(
'Could not find framewise displacement column in confounds file '
f'(available columns: {available})'
)

return np.asarray(fd_values.fillna(0.0), dtype=float)

def _build_animation(
self,
*,
Expand All @@ -107,6 +130,7 @@ def _build_animation(
vmax_orig: float,
vmin_corr: float,
vmax_corr: float,
fd_values: np.ndarray | None,
) -> Path:
orig_img = nib.load(self.inputs.original_pet)
corr_img = nib.load(self.inputs.corrected_pet)
Expand All @@ -116,6 +140,10 @@ def _build_animation(
corr_img.shape[-1] if corr_img.ndim > 3 else 1,
)

if fd_values is not None:
fd_values = np.asarray(fd_values[:n_frames], dtype=float)
n_frames = min(n_frames, len(fd_values))

with TemporaryDirectory() as tmpdir:
frames = []
for idx in range(n_frames):
Expand Down Expand Up @@ -162,7 +190,9 @@ def _build_animation(
frames.append(combined.astype(orig_arr.dtype, copy=False))

width = int(frames[0].shape[1])
height = int(frames[0].shape[0])
frame_height = int(frames[0].shape[0])
fd_height = 220 if fd_values is not None else 0
height = frame_height + fd_height
total_duration = self.inputs.duration * n_frames

svg_parts = [
Expand All @@ -176,6 +206,11 @@ def _build_animation(
'}'
),
'.playing .frame {animation-play-state: running;}',
'.fd-line {fill: none; stroke: #2c7be5; stroke-width: 2;}',
'.fd-axis {stroke: #333; stroke-width: 1;}',
'.fd-point {fill: #2c7be5; stroke: white; stroke-width: 1;}',
'#fd-marker {fill: #d7263d; stroke: white; stroke-width: 2;}',
'#fd-value {font: 14px sans-serif; fill: #1a1a1a;}',
'@keyframes framefade {0%, 80% {opacity: 1;} 100% {opacity: 0;}}',
]

Expand All @@ -191,44 +226,125 @@ def _build_animation(
data_uri = b64encode(buffer.getvalue()).decode('ascii')
svg_parts.append(
f'<image class="frame frame-{idx}" '
f'width="{width}" height="{height}" x="0" y="0" '
f'width="{width}" height="{frame_height}" x="0" y="0" '
f'href="data:image/png;base64,{data_uri}" />'
)

if fd_values is not None:
fd_padding = 45
fd_chart_height = fd_height
fd_x_start = fd_padding
fd_x_end = width - fd_padding
fd_axis_y = frame_height + fd_chart_height - fd_padding
fd_axis_y_top = frame_height + fd_padding
fd_y_range = fd_axis_y - fd_axis_y_top
fd_max = float(np.nanmax(fd_values)) if np.any(fd_values) else 0.0
if fd_max <= 0:
fd_max = 1.0

x_scale = (fd_x_end - fd_x_start) / max(n_frames - 1, 1)
points = []
point_elems = []
for idx, value in enumerate(fd_values):
x_coord = fd_x_start + x_scale * idx
y_coord = fd_axis_y - (value / fd_max) * fd_y_range
points.append(f'{x_coord:.2f},{y_coord:.2f}')
point_elems.append(
f'<circle class="fd-point fd-point-{idx}" cx="{x_coord:.2f}" '
f'cy="{y_coord:.2f}" r="3" data-value="{value:.6f}" />'
)

fd_label_y = fd_axis_y_top + (fd_y_range / 2)

svg_parts.extend(
[
f'<g class="fd-plot" aria-label="Framewise displacement">',
f'<line class="fd-axis" x1="{fd_x_start}" x2="{fd_x_end}" '
f'y1="{fd_axis_y}" y2="{fd_axis_y}" />',
f'<line class="fd-axis" x1="{fd_x_start}" x2="{fd_x_start}" '
f'y1="{fd_axis_y_top}" y2="{fd_axis_y}" />',
f'<polyline class="fd-line" points="{' '.join(points)}" />',
*point_elems,
f'<circle id="fd-marker" r="6" cx="{fd_x_start}" cy="{fd_axis_y}" />',
f'<text id="fd-value" x="{fd_x_start}" '
f'y="{fd_axis_y_top - 12}" aria-live="polite"></text>',
f'<text x="{fd_x_start - 25}" y="{fd_label_y:.2f}" '
'font-size="14" text-anchor="middle" transform='
f'"rotate(-90 {fd_x_start - 25},{fd_label_y:.2f})">'
'FD (mm)</text>',
f'<text x="{(fd_x_start + fd_x_end) / 2:.2f}" '
f'y="{fd_axis_y + 35}" font-size="14" text-anchor="middle">'
'Frames</text>',
'</g>',
]
)

svg_parts.extend(
[
'<script>',
'(() => {',
' const svg = document.currentScript.parentNode;',
" const frames = svg.querySelectorAll('.frame');",
" const fdPoints = Array.from(svg.querySelectorAll('.fd-point'));",
" const fdMarker = svg.querySelector('#fd-marker');",
" const fdValueLabel = svg.querySelector('#fd-value');",
f' const cycleMs = {total_duration * 1000:.0f};',
f' const frameDurationMs = {self.inputs.duration * 1000:.0f};',
' let restartTimer = null;',
' let playbackTimer = null;',
' let currentFrame = 0;',
' const setFdMarker = (index) => {',
' if (!fdMarker || !fdPoints.length) return;',
' const point = fdPoints[index % fdPoints.length];',
' fdMarker.setAttribute("cx", point.getAttribute("cx"));',
' fdMarker.setAttribute("cy", point.getAttribute("cy"));',
' if (fdValueLabel) {',
' const value = parseFloat(point.dataset.value || "0");',
' fdValueLabel.textContent = `Frame ${index + 1}: ${value.toFixed(3)} mm`;',
' }',
' };',
' const showFrame = (index) => {',
' currentFrame = index % frames.length;',
' setFdMarker(currentFrame);',
' };',
' const restart = () => {',
' frames.forEach((frame) => {',
" frame.style.animation = 'none';",
' // Force reflow to restart the CSS animation',
' void frame.getBoundingClientRect();',
" frame.style.animation = '';",
' });',
' showFrame(0);',
' };',
' const start = () => {',
' if (restartTimer) {',
' clearInterval(restartTimer);',
' }',
' if (playbackTimer) {',
' clearInterval(playbackTimer);',
' }',
' svg.classList.add("playing");',
' restart();',
' restartTimer = setInterval(restart, cycleMs);',
' playbackTimer = setInterval(() => {',
' showFrame(currentFrame + 1);',
' }, frameDurationMs);',
' };',
' const stop = () => {',
' svg.classList.remove("playing");',
' if (restartTimer) {',
' clearInterval(restartTimer);',
' restartTimer = null;',
' }',
' if (playbackTimer) {',
' clearInterval(playbackTimer);',
' playbackTimer = null;',
' }',
' frames.forEach((frame) => {',
" frame.style.animation = 'none';",
' });',
' };',
' showFrame(0);',
" svg.addEventListener('mouseenter', start);",
" svg.addEventListener('mouseleave', stop);",
'})();',
Expand Down
8 changes: 8 additions & 0 deletions petprep/workflows/pet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,14 @@ def init_pet_wf(
]),
]) # fmt:skip

if nvols > 1:
workflow.connect(
pet_confounds_wf,
'outputnode.confounds_file',
motion_report,
'fd_file',
)

if spaces.get_spaces(nonstandard=False, dim=(3,)):
carpetplot_wf = init_carpetplot_wf(
mem_gb=mem_gb['resampled'],
Expand Down