Skip to content

Commit

Permalink
Merge pull request #1242 from nipreps/fix/only-one-bzero
Browse files Browse the repository at this point in the history
FIX: Drift should not be estimated when less than three low-b volumes present
  • Loading branch information
oesteban committed Apr 8, 2024
2 parents bf516ea + 951ff1f commit 781da0a
Showing 1 changed file with 37 additions and 6 deletions.
43 changes: 37 additions & 6 deletions mriqc/interfaces/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,11 @@ class CorrectSignalDrift(SimpleInterface):
output_spec = _CorrectSignalDriftOutputSpec

def _run_interface(self, runtime):
from mriqc import config

bvals = np.loadtxt(self.inputs.bval_file)
len_dmri = bvals.size

img = nb.load(self.inputs.in_file)
data = img.get_fdata()
bmask = np.ones_like(data[..., 0], dtype=bool)
Expand All @@ -578,28 +583,54 @@ def _run_interface(self, runtime):
data *= nb.load(self.inputs.bias_file).get_fdata()[..., np.newaxis]

if isdefined(self.inputs.brainmask_file):
bmask = np.asanyarray(nb.load(self.inputs.brainmask_file).dataobj) > 1e-3
bmask = np.round(nb.load(self.inputs.brainmask_file).get_fdata(), 2) > 0.5

self._results['out_file'] = fname_presuffix(
self.inputs.in_file, suffix='_nodrift', newpath=runtime.cwd
)

if (b0len := int(data.ndim < 4)) or (b0len := data.shape[3]) < 3:
config.loggers.interface.warn(
f'Insufficient number of low-b orientations ({b0len}) '
'to safely calculate signal drift.'
)

img.__class__(
np.round(data.astype('float32'), 4), img.affine, img.header,
).to_filename(self._results['out_file'])

if isdefined(self.inputs.full_epi):
self._results['out_full_file'] = self.inputs.full_epi

self._results['b0_drift'] = [1.0] * b0len
self._results['signal_drift'] = [1.0] * len_dmri

return runtime

global_signal = np.array([
np.median(data[..., n_b0][bmask]) for n_b0 in range(img.shape[-1])
]).astype('float32')

# Normalize and correct
global_signal /= global_signal[0]
self._results['b0_drift'] = [float(gs) for gs in global_signal]
self._results['b0_drift'] = [
round(float(gs), 4) for gs in global_signal
]

config.loggers.interface.info(
f'Correcting drift with {len(global_signal)} b=0 volumes, with '
f'global signal estimated at {",".join(self._results["b0_drift"])}.'
)

data *= 1.0 / global_signal[np.newaxis, np.newaxis, np.newaxis, :]

self._results['out_file'] = fname_presuffix(
self.inputs.in_file, suffix='_nodrift', newpath=runtime.cwd
)
img.__class__(
data.astype(img.header.get_data_dtype()), img.affine, img.header,
).to_filename(self._results['out_file'])

# Fit line to log-transformed drifts
K, A_log = np.polyfit(self.inputs.b0_ixs, np.log(global_signal), 1)

len_dmri = np.loadtxt(self.inputs.bval_file).size
t_points = np.arange(len_dmri, dtype=int)
fitted = np.squeeze(_exp_func(t_points, np.exp(A_log), K, 0))
self._results['signal_drift'] = fitted.astype(float).tolist()
Expand Down

0 comments on commit 781da0a

Please sign in to comment.