Skip to content

Commit

Permalink
Merge pull request #1224 from nipreps/fix/spikes-percentage
Browse files Browse the repository at this point in the history
ENH: Add computation of spiking voxels mask and percent IQMs
  • Loading branch information
oesteban committed Apr 2, 2024
2 parents d62756f + b4cab9f commit bb4deb0
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 63 deletions.
128 changes: 85 additions & 43 deletions mriqc/interfaces/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
'FilterShells',
'NumberOfShells',
'ReadDWIMetadata',
'SpikingVoxelsMask',
'SplitShells',
'WeightedStat',
)
Expand Down Expand Up @@ -100,7 +101,8 @@ class _DiffusionQCInputSpec(_BaseInterfaceInputSpec):
in_md = File(exists=True, mandatory=True, desc='input MD map')
brain_mask = File(exists=True, mandatory=True, desc='input probabilistic brain mask')
wm_mask = File(exists=True, mandatory=True, desc='input probabilistic white-matter mask')
cc_mask = File(exists=True, mandatory=True, desc='input probabilistic white-matter mask')
cc_mask = File(exists=True, mandatory=True, desc='input binary mask of the corpus callosum')
spikes_mask = File(exists=True, mandatory=True, desc='input binary mask of spiking voxels')
direction = traits.Enum(
'all',
'x',
Expand Down Expand Up @@ -128,6 +130,7 @@ class _DiffusionQCOutputSpec(TraitedSpec):
efc = traits.Dict
fber = traits.Dict
fd = traits.Dict
spikes_ppm = traits.Dict
# snr = traits.Float
# gsr = traits.Dict
# tsnr = traits.Float
Expand Down Expand Up @@ -217,6 +220,10 @@ def _run_interface(self, runtime):
b_vectors=self.inputs.in_bvec,
)

# Get cc mask data
spmask = np.asanyarray(nb.load(self.inputs.spikes_mask).dataobj) > 0.0
self._results['spikes_ppm'] = dqc.spike_ppm(spmask)

# FBER
self._results['fber'] = {
f'b{int(bval):d}': aqc.fber(bdata, mskdata.astype(np.uint8))
Expand Down Expand Up @@ -925,6 +932,66 @@ def _run_interface(self, runtime):
return runtime


class _SpikingVoxelsMaskInputSpec(_BaseInterfaceInputSpec):
in_file = File(exists=True, mandatory=True, desc='a DWI 4D file')
brain_mask = File(exists=True, mandatory=True, desc='input probabilistic brain 3D mask')
z_threshold = traits.Float(3.0, usedefault=True, desc='z-score threshold')
b_masks = traits.List(
traits.List(traits.Int, minlen=1),
minlen=1,
mandatory=True,
desc='list of ``n_shells`` b-value-wise indices lists'
)


class _SpikingVoxelsMaskOutputSpec(_TraitedSpec):
out_mask = File(exists=True, desc='a 4D binary mask of spiking voxels')


class SpikingVoxelsMask(SimpleInterface):
"""Computes :abbr:`QC (Quality Control)` measures on the input DWI EPI scan."""

input_spec = _SpikingVoxelsMaskInputSpec
output_spec = _SpikingVoxelsMaskOutputSpec

def _run_interface(self, runtime):
self._results['out_mask'] = fname_presuffix(
self.inputs.in_file,
suffix='spikesmask',
newpath=runtime.cwd,
)

in_nii = nb.load(self.inputs.in_file)
data = np.round(in_nii.get_fdata(), 4).astype('float32')

bmask_nii = nb.load(self.inputs.brain_mask)
brainmask = np.round(bmask_nii.get_fdata(), 2).astype('float32')

spikes_mask = get_spike_mask(
data,
shell_masks=self.inputs.b_masks,
brainmask=brainmask,
z_threshold=self.inputs.z_threshold,
)

header = bmask_nii.header.copy()
header.set_data_dtype(np.uint8)
header.set_xyzt_units('mm')
header.set_intent('estimate', name='spiking voxels mask')
header['cal_max'] = 1
header['cal_min'] = 0

# Write out binary WM mask after binary opening
spikes_mask_nii = nb.Nifti1Image(
spikes_mask.astype(np.uint8),
bmask_nii.affine,
header,
)
spikes_mask_nii.to_filename(self._results['out_mask'])

return runtime


def _rms(estimator, X):
"""
Callable to pass to GridSearchCV that will calculate a distance score.
Expand Down Expand Up @@ -1029,9 +1096,9 @@ def segment_corpus_callosum(

def get_spike_mask(
data: np.ndarray,
shell_masks: list,
brainmask: np.ndarray,
z_threshold: float = 3.0,
grouping_vals: np.ndarray | None = None,
bmag: int | None = None,
) -> np.ndarray:
"""
Creates a binary mask classifying voxels in the data array as spike or non-spike.
Expand All @@ -1048,18 +1115,10 @@ def get_spike_mask(
z_threshold : :obj:`float`, optional (default=3.0)
The number of standard deviations to use above the mean as the threshold
multiplier.
grouping_vals : :obj:`~numpy.ndarray`, optional
If provided, this array is used to group voxels for thresholding. Voxels
with the same value in ``grouping_vals`` are considered to belong to the same
group. The threshold will be calculated independently for each group.
- If ``grouping_vals`` has the same shape as ``data`` (4D), it is assumed to be
a mask where each voxel value indicates the group it belongs to.
- If ``grouping_vals`` has a 3D shape, it is assumed to represent b-values
corresponding to each voxel in the 4D ``data`` array. In this case, voxels
with the same b-value are grouped together.
bmag : int, optional
The order of magnitude for b-value rounding (used only if
``grouping_vals`` is provided as b-values). Default: None (derived from max b-value).
brainmask : :obj:`~numpy.ndarray`
The brain mask.
shell_masks : :obj:`list`
A list of :obj:`~numpy.ndarray` objects
Returns:
-------
Expand All @@ -1069,33 +1128,16 @@ def get_spike_mask(
data array.
"""
from dipy.core.gradients import round_bvals, unique_bvals_magnitude

if grouping_vals is None:
threshold = np.round((z_threshold * np.std(data)) + np.mean(data), 3)
spike_mask = np.round(data, 3) > threshold
return spike_mask

threshold_mask = np.zeros(data.shape)

rounded_grouping_vals = round_bvals(grouping_vals, bmag)
gvals = unique_bvals_magnitude(grouping_vals, bmag)

if grouping_vals.shape == data.shape:
for gval in gvals:
gval_data = data[rounded_grouping_vals == gval]
gval_threshold = ((z_threshold * np.std(gval_data))
+ np.mean(gval_data))
threshold_mask[rounded_grouping_vals == gval] = (
gval_threshold * np.ones(gval_data.shape))
else:
for gval in gvals:
gval_data = data[..., rounded_grouping_vals == gval]
gval_threshold = ((z_threshold * np.std(gval_data))
+ np.mean(gval_data))
threshold_mask[..., rounded_grouping_vals == gval] = (
gval_threshold * np.ones(gval_data.shape))

spike_mask = data > threshold_mask

spike_mask = np.zeros_like(data, dtype=bool)

brainmask = brainmask >= 0.5

for b_mask in shell_masks:
shelldata = data[..., b_mask]

a_thres = z_threshold * shelldata[brainmask].std() + shelldata[brainmask].mean()

spike_mask[..., b_mask] = shelldata > a_thres

return spike_mask
28 changes: 15 additions & 13 deletions mriqc/qc/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,47 +259,49 @@ def cc_snr(
return cc_snr_estimates


def spike_percentage(
data: np.ndarray,
def spike_ppm(
spike_mask: np.ndarray,
slice_threshold: float = 0.05,
decimals: int = 8,
) -> dict[str, float | np.ndarray]:
"""
Calculates the percentage of voxels classified as spikes (global and slice-wise).
Calculates fractions (global and slice-wise) of voxels classified as spikes in ppm.
This function computes two metrics:
* Global spike percentage: The average fraction of voxels exceeding the spike
* Global spike parts-per-million [ppm]: Fraction of voxels exceeding the spike
threshold across the entire data array.
* Slice-wise spiking percentage: The fraction of slices along each dimension of
* Slice-wise spiking [ppm]: The fraction of slices along each dimension of
the data array where the average fraction of spiking voxels within the slice
exceeds a user-defined threshold (``slice_threshold``).
Parameters
----------
data : :obj:`~numpy.ndarray` (float, 4D)
The data array used to generate the spike mask.
spike_mask : :obj:`~numpy.ndarray` (bool, same shape as data)
The binary mask indicating spike voxels (True) and non-spike voxels (False).
slice_threshold : :obj:`float`, optional (default=0.05)
The minimum fraction of voxels in a slice that must be classified as spikes
for the slice to be considered spiking.
decimals : :obj:`int`
The number of decimals to round the fractions.
Returns
-------
:obj:`dict`
A dictionary containing the calculated spike percentages:
* 'spike_perc_global': :obj:`float` - Global percentage of spiking voxels.
* 'spike_perc_slice': :obj:`list` of :obj:`float` - List of slice-wise
spiking percentages for each dimension of the data array.
* 'global': :obj:`float` - global spiking voxels ppm.
* 'slice': :obj:`list` of :obj:`float` - List of slice-wise spiking voxel
fractions in ppm for each dimension of the data array.
"""

spike_perc_global = float(np.mean(np.ravel(spike_mask)))
spike_perc_global = round(float(np.mean(np.ravel(spike_mask))), decimals) * 1e6
spike_perc_slice = [
float(np.mean(np.mean(spike_mask, axis=axis) > slice_threshold))
for axis in range(data.ndim)
round(
float(np.mean(np.mean(spike_mask, axis=axis) > slice_threshold)), decimals
) * 1e6
for axis in range(spike_mask.ndim)
]

return {'global': spike_perc_global, 'slice': spike_perc_slice}
13 changes: 6 additions & 7 deletions mriqc/qc/tests/test_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,15 @@

import numpy as np

from mriqc.qc.diffusion import spike_percentage
from mriqc.qc.diffusion import spike_ppm


def test_spike_percentage():
img = np.random.normal(loc=10, scale=1.0, size=(76, 76, 64, 124))
def test_spike_ppm():
msk = np.random.randint(0, high=2, size=(76, 76, 64, 124), dtype=bool)
val = spike_percentage(img, msk, .5)
val = spike_ppm(msk, .5)

assert np.isclose(val['global'], 0.5, rtol=1, atol=1)
assert np.isclose(val['global'], 0.5e6, rtol=1, atol=1)

assert np.min(val['slice']) >= 0
assert np.max(val['slice']) <= 1
assert len(val['slice']) == img.ndim
assert np.max(val['slice']) <= 1e6
assert len(val['slice']) == msk.ndim
9 changes: 9 additions & 0 deletions mriqc/workflows/diffusion/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def dmri_qc_workflow(name='dwiMRIQC'):
FilterShells,
NumberOfShells,
ReadDWIMetadata,
SpikingVoxelsMask,
WeightedStat,
)
from mriqc.messages import BUILDING_WORKFLOW
Expand Down Expand Up @@ -189,6 +190,8 @@ def dmri_qc_workflow(name='dwiMRIQC'):
name='dti',
)

sp_mask = pe.Node(SpikingVoxelsMask(), name='sp_mask')

# Calculate CC mask
cc_mask = pe.Node(CCSegmentation(), name='cc_mask')

Expand All @@ -211,7 +214,9 @@ def dmri_qc_workflow(name='dwiMRIQC'):
(datalad_get, iqms_wf, [('in_file', 'inputnode.in_file')]),
(datalad_get, sanitize, [('in_file', 'in_file')]),
(sanitize, dwi_ref, [('out_file', 'in_file')]),
(sanitize, sp_mask, [('out_file', 'in_file')]),
(shells, dwi_ref, [(('b_masks', _first), 't_mask')]),
(shells, sp_mask, [('b_masks', 'b_masks')]),
(meta, shells, [('out_bval_file', 'in_bvals')]),
(sanitize, drift, [('out_file', 'full_epi')]),
(shells, get_lowb, [(('b_indices', _first), 'indices')]),
Expand All @@ -222,6 +227,7 @@ def dmri_qc_workflow(name='dwiMRIQC'):
(hmc_b0, drift, [('out_file', 'in_file')]),
(shells, drift, [(('b_indices', _first), 'b0_ixs')]),
(dwi_ref, dmri_bmsk, [('out_file', 'inputnode.in_files')]),
(dmri_bmsk, sp_mask, [('outputnode.out_mask', 'brain_mask')]),
(dmri_bmsk, drift, [('outputnode.out_mask', 'brainmask_file')]),
(drift, hmcwf, [('out_full_file', 'inputnode.in_file')]),
(drift, averages, [('out_full_file', 'in_file')]),
Expand All @@ -244,6 +250,7 @@ def dmri_qc_workflow(name='dwiMRIQC'):
(dti, cc_mask, [('out_fa', 'in_fa'),
('out_cfa', 'in_cfa')]),
(averages, iqms_wf, [(('out_file', _first), 'inputnode.in_b0')]),
(sp_mask, iqms_wf, [('out_mask', 'inputnode.spikes_mask')]),
(hmcwf, iqms_wf, [('outputnode.out_fd', 'inputnode.framewise_displacement')]),
(dti, iqms_wf, [('out_fa', 'inputnode.in_fa'),
('out_cfa', 'inputnode.in_cfa'),
Expand Down Expand Up @@ -313,6 +320,7 @@ def compute_iqms(name='ComputeIQMs'):
'brain_mask',
'wm_mask',
'cc_mask',
'spikes_mask',
'framewise_displacement',
]
),
Expand Down Expand Up @@ -363,6 +371,7 @@ def compute_iqms(name='ComputeIQMs'):
('brain_mask', 'brain_mask'),
('wm_mask', 'wm_mask'),
('cc_mask', 'cc_mask'),
('spikes_mask', 'spikes_mask'),
('in_fa', 'in_fa'),
('in_md', 'in_md'),
('in_cfa', 'in_cfa'),
Expand Down

0 comments on commit bb4deb0

Please sign in to comment.