Merge pull request #1272 from nipreps/fix/dwi-group-iqm-names
FIX: Finalized naming and connection of DWI IQMs
oesteban committed Apr 15, 2024
2 parents 5cca275 + c1b6e66 commit ef04a62
Showing 6 changed files with 191 additions and 122 deletions.
135 changes: 73 additions & 62 deletions mriqc/interfaces/diffusion.py
@@ -76,12 +76,13 @@ class _DiffusionQCInputSpec(_BaseInterfaceInputSpec):
mandatory=True,
desc='DWI data after HMC and split by shells (indexed by in_bval)'
)
- in_bval = traits.List(
+ in_shells_bval = traits.List(
traits.Float,
minlen=1,
mandatory=True,
desc='list of unique b-values (one per shell), ordered by growing intensity',
)
+ in_bval_file = File(exists=True, mandatory=True, desc='original b-vals file')
in_bvec = traits.List(
traits.List(
traits.Tuple(traits.Float, traits.Float, traits.Float),
@@ -149,21 +150,19 @@ class _DiffusionQCInputSpec(_BaseInterfaceInputSpec):

class _DiffusionQCOutputSpec(TraitedSpec):
bdiffs = traits.Dict
- cc_snr = traits.Dict
efc = traits.Dict
fa_degenerate = traits.Float
fa_nans = traits.Float
fber = traits.Dict
fd = traits.Dict
ndc = traits.Float
- sigma_cc = traits.Float
- sigma_pca = traits.Float
- sigma_piesno = traits.Float
- spikes_ppm = traits.Dict
+ sigma = traits.Dict
+ spikes = traits.Dict
# gsr = traits.Dict
# tsnr = traits.Float
# fwhm = traits.Dict(desc='full width half-maximum measure')
# size = traits.Dict
+ snr_cc = traits.Dict
summary = traits.Dict

out_qc = traits.Dict(desc='output flattened dictionary with all measures')
@@ -240,45 +239,44 @@ def _run_interface(self, runtime):
self._results['summary'] = stats

# CC mask SNR and std
- self._results['cc_snr'], cc_sigma = dqc.cc_snr(
+ self._results['snr_cc'], cc_sigma = dqc.cc_snr(
in_b0=b0data,
dwi_shells=shelldata,
cc_mask=ccdata,
- b_values=self.inputs.in_bval,
+ b_values=self.inputs.in_shells_bval,
b_vectors=self.inputs.in_bvec,
)
- self._results['sigma_cc'] = round(float(cc_sigma), 4)

fa_nans_mask = np.asanyarray(nb.load(self.inputs.in_fa_nans).dataobj) > 0.0
- self._results['fa_nans'] = np.round(float(fa_nans_mask[mskdata > 0.5].mean()), 8) * 1e6
+ self._results['fa_nans'] = round(float(1e6 * fa_nans_mask[mskdata > 0.5].mean()), 2)

fa_degenerate_mask = np.asanyarray(nb.load(self.inputs.in_fa_degenerate).dataobj) > 0.0
- self._results['fa_degenerate'] = np.round(
- float(fa_degenerate_mask[mskdata > 0.5].mean()),
- 8,
- ) * 1e6
+ self._results['fa_degenerate'] = round(
+ float(1e6 * fa_degenerate_mask[mskdata > 0.5].mean()),
+ 2,
+ )

# Get spikes-mask data
spmask = np.asanyarray(nb.load(self.inputs.spikes_mask).dataobj) > 0.0
- self._results['spikes_ppm'] = dqc.spike_ppm(spmask)
+ self._results['spikes'] = dqc.spike_ppm(spmask)

# FBER
self._results['fber'] = {
- f'b{int(bval):d}': aqc.fber(bdata, mskdata.astype(np.uint8))
- for bval, bdata in zip(self.inputs.in_bval, shelldata)
+ f'shell{i + 1:02d}': aqc.fber(bdata, mskdata.astype(np.uint8))
+ for i, bdata in enumerate(shelldata)
}

# EFC
self._results['efc'] = {
- f'b{int(bval):d}': aqc.efc(bdata)
- for bval, bdata in zip(self.inputs.in_bval, shelldata)
+ f'shell{i + 1:02d}': aqc.efc(bdata)
+ for i, bdata in enumerate(shelldata)
}

# FD
fd_data = np.loadtxt(self.inputs.in_fd, skiprows=1)
num_fd = (fd_data > self.inputs.fd_thres).sum()
self._results['fd'] = {
- 'mean': float(fd_data.mean()),
+ 'mean': round(float(fd_data.mean()), 4),
'num': int(num_fd),
'perc': float(num_fd * 100 / (len(fd_data) + 1)),
}
@@ -288,19 +286,18 @@ def _run_interface(self, runtime):
np.nan_to_num(nb.load(self.inputs.in_file).get_fdata()),
3,
)
- self._results['ndc'] = float(
- dqc.neighboring_dwi_correlation(
- dwidata,
- neighbor_indices=self.inputs.qspace_neighbors,
- mask=mskdata > 0.5,
- )
+ self._results['ndc'] = dqc.neighboring_dwi_correlation(
+ dwidata,
+ neighbor_indices=self.inputs.qspace_neighbors,
+ mask=mskdata > 0.5,
)

- # PIESNO
- self._results['sigma_piesno'] = round(self.inputs.piesno_sigma, 4)
-
- # dwidenoise - Marchenko-Pastur PCA
- self._results['sigma_pca'] = round(self.inputs.noise_floor, 4)
+ # Sigmas
+ self._results['sigma'] = {
+ 'cc': round(float(cc_sigma), 4),
+ 'piesno': round(self.inputs.piesno_sigma, 4),
+ 'pca': round(self.inputs.noise_floor, 4),
+ }

# rotated b-vecs deviations
diffs = np.array(self.inputs.in_bvec_diff)
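
Net effect of the hunks above: the three separate noise estimates (sigma_cc, sigma_pca, sigma_piesno) collapse into a single sigma dictionary, spikes_ppm becomes spikes, and cc_snr becomes snr_cc with per-shell entries. MRIQC derives the final IQM names by flattening these nested result dictionaries; the sketch below mimics that with a simplified stand-in helper (flatten is hypothetical, and all values are made up for illustration):

import json

# Hypothetical nested results, as assembled by DiffusionQC._run_interface above
results = {
    'snr_cc': {'shell0': 20.1, 'shell1_worst': 4.2, 'shell1_best': 7.9},
    'sigma': {'cc': 0.1234, 'piesno': 0.1198, 'pca': 0.1175},
    'spikes': {'global': 12.5, 'slice_i': 0.0, 'slice_j': 0.0, 'slice_k': 41.7},
    'ndc': 0.85,
}

def flatten(d: dict, prefix: str = '') -> dict:
    """Join nested keys with underscores, e.g. 'sigma' + 'pca' -> 'sigma_pca'."""
    out = {}
    for key, value in d.items():
        name = f'{prefix}_{key}' if prefix else key
        out.update(flatten(value, name) if isinstance(value, dict) else {name: value})
    return out

print(json.dumps(flatten(results), indent=2))
# Yields IQM names such as snr_cc_shell1_worst, sigma_pca, spikes_slice_k, ndc
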
@@ -396,6 +393,9 @@ def _run_interface(self, runtime):
class _NumberOfShellsInputSpec(_BaseInterfaceInputSpec):
in_bvals = File(mandatory=True, desc='bvals file')
b0_threshold = traits.Float(50, usedefault=True, desc='a threshold for the low-b values')
+ dsi_threshold = traits.Int(
+ 11, usedefault=True, desc='number of shells to call a dataset DSI'
+ )


class _NumberOfShellsOutputSpec(_TraitedSpec):
@@ -447,37 +447,48 @@ class NumberOfShells(SimpleInterface):
def _run_interface(self, runtime):
in_data = np.squeeze(np.loadtxt(self.inputs.in_bvals))
highb_mask = in_data > self.inputs.b0_threshold
- grid_search = GridSearchCV(
- KMeans(), param_grid={'n_clusters': range(1, 10)}, scoring=_rms
- ).fit(in_data[highb_mask].reshape(-1, 1))
-
- results = np.array(sorted(zip(
- grid_search.cv_results_['mean_test_score'] * -1.0,
- grid_search.cv_results_['param_n_clusters'],
- )))
-
- self._results['models'] = results[:, 1].astype(int).tolist()
- self._results['n_shells'] = int(grid_search.best_params_['n_clusters'])
-
- out_data = np.zeros_like(in_data)
- predicted_shell = np.rint(np.squeeze(
- grid_search.best_estimator_.cluster_centers_[
- grid_search.best_estimator_.predict(in_data[highb_mask].reshape(-1, 1))
- ],
- )).astype(int)
- original_bvals = np.unique(np.rint(in_data[highb_mask]).astype(int))
-
- # If estimated shells matches direct count, probably right -- do not change b-vals
- if len(original_bvals) == self._results['n_shells']:
- # Find closest b-values
- indices = np.abs(predicted_shell[:, np.newaxis] - original_bvals).argmin(axis=1)
- predicted_shell = original_bvals[indices]
-
- out_data[highb_mask] = predicted_shell
- self._results['out_data'] = np.round(out_data.astype(float), 2).tolist()
- self._results['b_values'] = sorted(
- np.unique(np.round(predicted_shell.astype(float), 2)).tolist()
- )
+ original_bvals = sorted(set(np.rint(in_data[highb_mask]).astype(int)))
+ round_bvals = np.round(in_data, -2).astype(int)
+ shell_bvals = sorted(set(round_bvals))
+
+ if len(shell_bvals) <= self.inputs.dsi_threshold:
+ self._results['n_shells'] = len(shell_bvals) - 1
+ self._results['models'] = [self._results['n_shells']]
+ self._results['out_data'] = round_bvals.tolist()
+ self._results['b_values'] = shell_bvals
+ else:
+ # For datasets identified as DSI, fit a k-means
+ grid_search = GridSearchCV(
+ KMeans(), param_grid={'n_clusters': range(1, 10)}, scoring=_rms
+ ).fit(in_data[highb_mask].reshape(-1, 1))
+
+ results = np.array(sorted(zip(
+ grid_search.cv_results_['mean_test_score'] * -1.0,
+ grid_search.cv_results_['param_n_clusters'],
+ )))
+
+ self._results['models'] = results[:, 1].astype(int).tolist()
+ self._results['n_shells'] = int(grid_search.best_params_['n_clusters'])
+
+ out_data = np.zeros_like(in_data)
+ predicted_shell = np.rint(np.squeeze(
+ grid_search.best_estimator_.cluster_centers_[
+ grid_search.best_estimator_.predict(in_data[highb_mask].reshape(-1, 1))
+ ],
+ )).astype(int)
+
+ # If estimated shells matches direct count, probably right -- do not change b-vals
+ if len(original_bvals) == self._results['n_shells']:
+ # Find closest b-values
+ indices = np.abs(predicted_shell[:, np.newaxis] - original_bvals).argmin(axis=1)
+ predicted_shell = original_bvals[indices]
+
+ out_data[highb_mask] = predicted_shell
+ self._results['out_data'] = np.round(out_data.astype(float), 2).tolist()
+ self._results['b_values'] = sorted(
+ np.unique(np.round(predicted_shell.astype(float), 2)).tolist()
+ )

self._results['b_masks'] = [(~highb_mask).tolist()] + [
np.isclose(self._results['out_data'], bvalue).tolist()
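The rewritten NumberOfShells logic avoids clustering for ordinary shelled acquisitions: b-values are snapped to the nearest 100 and the distinct values counted; only when that count exceeds dsi_threshold (suggesting the dense q-space sampling of DSI) does it fall back to the grid-searched KMeans. A self-contained sketch of the fast path, using a toy two-shell b-value table (values illustrative, thresholds as in the interface defaults):

import numpy as np

bvals = np.array([0, 0, 995, 1000, 1005, 1995, 2000, 2005], dtype=float)
b0_threshold, dsi_threshold = 50.0, 11

highb_mask = bvals > b0_threshold
round_bvals = np.round(bvals, -2).astype(int)  # snap to the nearest multiple of 100
shell_bvals = sorted(set(round_bvals))         # [0, 1000, 2000]

if len(shell_bvals) <= dsi_threshold:
    # Shelled scheme: count distinct rounded values, excluding b=0
    n_shells = len(shell_bvals) - 1
    print(n_shells, shell_bvals)  # -> 2 [0, 1000, 2000]
else:
    # Densely sampled (DSI-like) scheme: cluster the high-b values instead,
    # as the interface does with GridSearchCV(KMeans(), ...)
    print('fall back to k-means clustering of', bvals[highb_mask])
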
35 changes: 21 additions & 14 deletions mriqc/qc/anatomical.py
@@ -318,7 +318,7 @@ def cjv(mu_wm, mu_gm, sigma_wm, sigma_gm):
return float((sigma_wm + sigma_gm) / abs(mu_wm - mu_gm))


- def fber(img, headmask, rotmask=None):
+ def fber(img, headmask, rotmask=None, decimals=4):
r"""
Calculate the :abbr:`FBER (Foreground-Background Energy Ratio)` [Shehzad2015]_,
defined as the mean energy of image values within the head relative
@@ -347,10 +347,10 @@ def fber(img, headmask, rotmask=None):
bg_mu = np.median(np.abs(img[airmask == 1]) ** 2)
if bg_mu < 1.0e-3:
return -1.0
- return float(fg_mu / bg_mu)
+ return round(float(fg_mu / bg_mu), decimals)


- def efc(img, framemask=None):
+ def efc(img, framemask=None, decimals=4):
r"""
Calculate the :abbr:`EFC (Entropy Focus Criterion)` [Atkinson1997]_.
Uses the Shannon entropy of voxel intensities as an indication of ghosting
@@ -390,9 +390,15 @@ def efc(img, framemask=None):
b_max = np.sqrt((img[framemask == 0] ** 2).sum())

# Calculate EFC (add 1e-16 to the image data to keep log happy)
- return float(
- (1.0 / efc_max)
- * np.sum((img[framemask == 0] / b_max) * np.log((img[framemask == 0] + 1e-16) / b_max))
+ return round(
+ float(
+ (1.0 / efc_max)
+ * np.sum(
+ (img[framemask == 0] / b_max)
+ * np.log((img[framemask == 0] + 1e-16) / b_max)
+ ),
+ ),
+ decimals,
)


@@ -562,7 +568,8 @@ def summary_stats(
data: np.ndarray,
pvms: dict[str, np.ndarray],
rprec_data: int = 0,
- rprec_prob: int = 3
+ rprec_prob: int = 3,
+ decimals: int = 4,
) -> dict[str, dict[str, float]]:
"""
Estimates weighted summary statistics for each tissue distribution in the data.
@@ -619,13 +626,13 @@ def summary_stats(
thresholded = data[probmap > (0.5 * probmap.max())]

output[label] = {
- 'mean': float(wstats.mean),
- 'median': float(median),
- 'p95': float(p95),
- 'p05': float(p05),
- 'k': float(kurtosis(thresholded)),
- 'stdv': float(wstats.std),
- 'mad': float(mad(thresholded, center=median)),
+ 'mean': round(float(wstats.mean), decimals),
+ 'median': round(float(median), decimals),
+ 'p95': round(float(p95), decimals),
+ 'p05': round(float(p05), decimals),
+ 'k': round(float(kurtosis(thresholded)), decimals),
+ 'stdv': round(float(wstats.std), decimals),
+ 'mad': round(float(mad(thresholded, center=median)), decimals),
'n': float(nvox),
}

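For reference, the quantity being rounded in the efc() hunk is the Shannon-entropy focus criterion computed over the framemask. A numeric sketch on a toy volume follows; the efc_max normalization is assumed from the full function body, which this excerpt elides:

import numpy as np

rng = np.random.default_rng(2024)
img = rng.random((16, 16, 8))                   # toy volume, illustrative only
framemask = np.zeros_like(img, dtype=np.uint8)  # no voxels excluded

# Assumed normalization: entropy of a maximally uniform image of n_vox voxels
n_vox = np.sum(1 - framemask)
efc_max = 1.0 * n_vox * (1.0 / np.sqrt(n_vox)) * np.log(1.0 / np.sqrt(n_vox))

b_max = np.sqrt((img[framemask == 0] ** 2).sum())
efc = (1.0 / efc_max) * np.sum(
    (img[framemask == 0] / b_max) * np.log((img[framemask == 0] + 1e-16) / b_max)
)
print(round(float(efc), 4))  # the patched code now rounds to `decimals` (default 4)
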
48 changes: 29 additions & 19 deletions mriqc/qc/diffusion.py
@@ -159,6 +159,8 @@ def cc_snr(
cc_mask: np.ndarray,
b_values: np.ndarray,
b_vectors: np.ndarray,
+ bval_thres: int = 50,
+ decimals: int = 2,
) -> dict[int, (float, float)]:
"""
Calculates the worst-case and best-case signal-to-noise ratio (SNR) within the corpus callosum.
@@ -206,13 +208,16 @@ def cc_snr(
xyz = np.eye(3)

b_values = np.rint(b_values).astype(np.uint16)
+ n_shells = len(b_values)

- # Shell-wise calculation
- for bval, bvecs, shell_data in zip(b_values, b_vectors, dwi_shells):
- if bval == 0:
- cc_snr_estimates[f'b{bval:d}'] = in_b0[cc_mask].mean() / std_signal
- continue
+ cc_snr_estimates['shell0'] = round(
+ float(in_b0[cc_mask].mean() / std_signal), decimals
+ )
+
+ # Shell-wise calculation
+ for shell_index, bvecs, shell_data in zip(
+ range(1, n_shells + 1), b_vectors, dwi_shells
+ ):
shell_data = shell_data[cc_mask]

# Find main directions of diffusion
@@ -230,9 +235,11 @@ def cc_snr(
mean_signal_worst = np.mean(data_X)
mean_signal_best = 0.5 * (np.mean(data_Y) + np.mean(data_Z))

- cc_snr_estimates[f'b{bval:d}'] = (
- np.mean(mean_signal_worst / std_signal),
- np.mean(mean_signal_best / std_signal),
+ cc_snr_estimates[f'shell{shell_index:d}_worst'] = round(
+ float(np.mean(mean_signal_worst / std_signal)), decimals
)
+ cc_snr_estimates[f'shell{shell_index:d}_best'] = round(
+ float(np.mean(mean_signal_best / std_signal)), decimals
+ )

return cc_snr_estimates, std_signal
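
With the new keys, the b=0 SNR lives under shell0 and each diffusion shell contributes two scalar entries, so the dictionary flattens into one IQM per entry instead of a tuple per b-value. A hypothetical return value for a b=0 plus two-shell scan (numbers made up for illustration):

cc_snr_estimates = {
    'shell0': 21.3,        # b=0 volumes (previously keyed 'b0')
    'shell1_worst': 4.1,   # SNR along the worst-case diffusion axis
    'shell1_best': 8.7,    # mean SNR of the two best-case axes
    'shell2_worst': 2.9,
    'shell2_best': 5.4,
}
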
@@ -241,7 +248,7 @@
def spike_ppm(
spike_mask: np.ndarray,
slice_threshold: float = 0.05,
- decimals: int = 8,
+ decimals: int = 2,
) -> dict[str, float | np.ndarray]:
"""
Calculates fractions (global and slice-wise) of voxels classified as spikes in ppm.
@@ -270,26 +277,29 @@ def spike_ppm(
A dictionary containing the calculated spike percentages:
* 'global': :obj:`float` - global spiking voxels ppm.
- * 'slice': :obj:`list` of :obj:`float` - List of slice-wise spiking voxel
+ * 'slice_{i,j,k,t}': :obj:`float` - Slice-wise spiking voxel
fractions in ppm for each dimension of the data array.
"""

- spike_perc_global = round(float(np.mean(np.ravel(spike_mask))), decimals) * 1e6
- spike_perc_slice = [
- round(
- float(np.mean(np.mean(spike_mask, axis=axis) > slice_threshold)), decimals
- ) * 1e6
- for axis in range(spike_mask.ndim)
- ]
+ axisnames = 'ijkt'
+
+ spike_global = round(float(1e6 * np.mean(np.ravel(spike_mask))), decimals)
+ spike_slice = {
+ f'slice_{axisnames[axis]}': round(
+ float(1e6 * np.mean(np.mean(spike_mask, axis=axis) > slice_threshold)), decimals
+ )
+ for axis in range(min(spike_mask.ndim, 3))
+ }

- return {'global': spike_perc_global, 'slice': spike_perc_slice}
+ return {'global': spike_global} | spike_slice
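
A quick numeric check of the new return shape, using a toy 4-D mask in which one k-slice spikes in every volume (threshold and axis names as in the function above):

import numpy as np

spike_mask = np.zeros((8, 8, 10, 5), dtype=bool)
spike_mask[:, :, 3, :] = True  # one spiking k-slice out of ten

axisnames = 'ijkt'
slice_threshold = 0.05
out = {'global': round(float(1e6 * spike_mask.mean()), 2)} | {
    f'slice_{axisnames[axis]}': round(
        float(1e6 * np.mean(spike_mask.mean(axis=axis) > slice_threshold)), 2
    )
    for axis in range(min(spike_mask.ndim, 3))
}
print(out)
# {'global': 100000.0, 'slice_i': 100000.0, 'slice_j': 100000.0, 'slice_k': 1000000.0}
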


def neighboring_dwi_correlation(
dwi_data: np.ndarray,
neighbor_indices: list[tuple[int, int]],
mask: np.ndarray | None = None,
+ decimals: int = 4,
) -> float:
"""
Calculates the Neighboring DWI Correlation (NDC) from diffusion MRI (dMRI) data.
Expand Down Expand Up @@ -340,4 +350,4 @@ def neighboring_dwi_correlation(
np.corrcoef(flat_from_image, flat_to_image)[0, 1]
)

- return np.round(np.mean(neighbor_correlations), 4)
+ return round(float(np.mean(neighbor_correlations)), decimals)
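
The final hunk makes NDC return a rounded built-in float rather than a NumPy scalar. A minimal sketch of the underlying computation with hypothetical neighbour pairs (in MRIQC, neighbor_indices come from the q-space structure computed upstream):

import numpy as np

rng = np.random.default_rng(1272)
dwi_data = rng.random((6, 6, 6, 4))          # toy 4-D dMRI series
neighbor_indices = [(0, 1), (1, 2), (2, 3)]  # assumed q-space nearest neighbours
mask = np.ones(dwi_data.shape[:3], dtype=bool)

correlations = [
    np.corrcoef(dwi_data[mask, i], dwi_data[mask, j])[0, 1]
    for i, j in neighbor_indices
]
print(round(float(np.mean(correlations)), 4))  # a plain Python float
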
