Skip to content

Commit

Permalink
ENH: A 3D tensor B-Spline approximator and extrapolator
Browse files Browse the repository at this point in the history
This PR finally adds an implementation for B-Spline smoothing and
extrapolation of fieldmaps.

References: #71, #22.
Resolves: #72.
Resolves: #14.
  • Loading branch information
oesteban committed Nov 26, 2020
1 parent 2e4359d commit 339cd5a
Show file tree
Hide file tree
Showing 4 changed files with 203 additions and 14 deletions.
188 changes: 188 additions & 0 deletions sdcflows/interfaces/bspline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
"""
B-Spline filtering.
.. testsetup::
>>> tmpdir = getfixture('tmpdir')
>>> tmp = tmpdir.chdir() # changing to a temporary directory
>>> nb.Nifti1Image(np.zeros((90, 90, 60)), None, None).to_filename(
... tmpdir.join('epi.nii.gz').strpath)
"""
from pathlib import Path
import numpy as np
import nibabel as nb

from nipype.utils.filemanip import fname_presuffix
from nipype.interfaces.base import (
BaseInterfaceInputSpec,
TraitedSpec,
File,
traits,
SimpleInterface,
InputMultiObject,
OutputMultiObject,
)


DEFAULT_ZOOMS_MM = (40.0, 40.0, 20.0) # For human adults (mid-frequency), in mm
DEFAULT_LF_ZOOMS_MM = (100.0, 100.0, 40.0) # For human adults (low-frequency), in mm
DEFAULT_HF_ZOOMS_MM = (16.0, 16.0, 10.0) # For human adults (high-frequency), in mm


class _BSplineApproxInputSpec(BaseInterfaceInputSpec):
in_data = File(exists=True, mandatory=True, desc="path to a fieldmap")
in_mask = File(exists=True, mandatory=True, desc="path to a brain mask")
bs_spacing = InputMultiObject(
[DEFAULT_ZOOMS_MM],
traits.Tuple(traits.Float, traits.Float, traits.Float),
usedefault=True,
desc="spacing between B-Spline control points",
)
ridge_alpha = traits.Float(
1e-4, usedefault=True, desc="controls the regularization"
)


class _BSplineApproxOutputSpec(TraitedSpec):
out_field = File(exists=True)
out_coeff = OutputMultiObject(File(exists=True))


class BSplineApprox(SimpleInterface):
"""
Approximate the field to smooth it removing spikes and extrapolating beyond the brain mask.
Examples
--------
"""

input_spec = _BSplineApproxInputSpec
output_spec = _BSplineApproxOutputSpec

def _run_interface(self, runtime):
from gridbspline.maths import cubic
from sklearn import linear_model as lm

_vbspl = np.vectorize(cubic)

# Load in the fieldmap
fmapnii = nb.load(self.inputs.in_data)
data = fmapnii.get_fdata()
mask = nb.load(self.inputs.in_mask).get_fdata() > 0
bs_spacing = [np.array(sp, dtype="float32") for sp in self.inputs.bs_spacing]

# Calculate B-Splines grid(s)
bs_levels = []
for sp in bs_spacing:
bs_levels.append(bspline_grid(fmapnii, control_zooms_mm=sp))

# Calculate spatial location of voxels, and normalize per B-Spline grid
fmap_points = grid_coords(fmapnii)
sample_points = []
for sp in bs_spacing:
sample_points.append((fmap_points / sp).astype("float32"))

# Calculate the spatial location of control points
bs_x = []
ncoeff = []
for sp, level, points in zip(bs_spacing, bs_levels, sample_points):
ncoeff.append(level.dataobj.size)
control_points = grid_coords(level, control_zooms_mm=sp)
bs_x.append(control_points[:, np.newaxis, :] - points[np.newaxis, ...])

# Calculate the cubic spline weights per dimension and tensor-product
dist = np.vstack(bs_x)
dist_support = (np.abs(dist) < 2).all(axis=-1)
weights = _vbspl(dist[dist_support]).prod(axis=-1)

# Compose the interpolation matrix
interp_mat = np.zeros(dist.shape[:2])
interp_mat[dist_support] = weights

# Fit the model
model = lm.Ridge(alpha=self.inputs.ridge_alpha, fit_intercept=False)
model.fit(
interp_mat[..., mask.reshape(-1)].T, # Regress only within brainmask
data[mask],
)

# Store outputs
out_name = str(
Path(
fname_presuffix(
self.inputs.in_data, suffix="_field", newpath=runtime.cwd
)
).absolute()
)
hdr = fmapnii.header.copy()
hdr.set_data_dtype("float32")
nb.Nifti1Image(
(model.intercept_ + np.array(model.coef_) @ interp_mat)
.astype("float32") # Interpolation
.reshape(data.shape),
fmapnii.affine,
hdr,
).to_filename(out_name)
self._results["out_field"] = out_name

index = 0
self._results["out_coeff"] = []
for i, (n, bsl) in enumerate(zip(ncoeff, bs_levels)):
out_level = out_name.replace("_field.", f"_coeff{i:03}.")
nb.Nifti1Image(
np.array(model.coef_, dtype="float32")[index : index + n].reshape(
bsl.shape
),
bsl.affine,
bsl.header,
).to_filename(out_level)
index += n
self._results["out_coeff"].append(out_level)
return runtime


def bspline_grid(img, control_zooms_mm=DEFAULT_ZOOMS_MM):
"""Calculate a Nifti1Image object encoding the location of control points."""
if isinstance(img, (str, Path)):
img = nb.load(img)

im_zooms = np.array(img.header.get_zooms())
im_shape = np.array(img.shape[:3])

# Calculate the direction cosines of the target image
dir_cos = img.affine[:3, :3] / im_zooms

# Initialize the affine of the B-Spline grid
bs_affine = np.diag(np.hstack((np.array(control_zooms_mm) @ dir_cos, 1)))
bs_zooms = nb.affines.voxel_sizes(bs_affine)

# Calculate the shape of the B-Spline grid
im_extent = im_zooms * (im_shape - 1)
bs_shape = (im_extent // bs_zooms + 3).astype(int)

# Center both images
im_center = img.affine @ np.hstack((0.5 * (im_shape - 1), 1))
bs_center = bs_affine @ np.hstack((0.5 * (bs_shape - 1), 1))
bs_affine[:3, 3] = im_center[:3] - bs_center[:3]

return nb.Nifti1Image(np.zeros(bs_shape, dtype="float32"), bs_affine)


def grid_coords(img, control_zooms_mm=None, dtype="float32"):
"""Create a linear space of physical coordinates."""
if isinstance(img, (str, Path)):
img = nb.load(img)

grid = np.array(
np.meshgrid(*[range(s) for s in img.shape[:3]]), dtype=dtype
).reshape(3, -1)
coords = (img.affine @ np.vstack((grid, np.ones(grid.shape[-1])))).T[..., :3]

if control_zooms_mm is not None:
coords /= np.array(control_zooms_mm)

return coords.astype(dtype)
26 changes: 13 additions & 13 deletions sdcflows/workflows/fit/fieldmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@
from niworkflows.engine.workflows import LiterateWorkflow as Workflow


def init_fmap_wf(omp_nthreads=1, mode="phasediff", name="fmap_wf"):
def init_fmap_wf(omp_nthreads=1, debug=False, mode="phasediff", name="fmap_wf"):
"""
Estimate the fieldmap based on a field-mapping MRI acquisition.
Expand Down Expand Up @@ -156,6 +156,10 @@ def init_fmap_wf(omp_nthreads=1, mode="phasediff", name="fmap_wf"):
pair.
"""
from ...interfaces.bspline import (
BSplineApprox, DEFAULT_LF_ZOOMS_MM, DEFAULT_HF_ZOOMS_MM
)

workflow = Workflow(name=name)

inputnode = pe.Node(
Expand All @@ -167,19 +171,19 @@ def init_fmap_wf(omp_nthreads=1, mode="phasediff", name="fmap_wf"):
)

magnitude_wf = init_magnitude_wf(omp_nthreads=omp_nthreads)
fmap_postproc_wf = init_fmap_postproc_wf(omp_nthreads=omp_nthreads)
bs_filter = pe.Node(BSplineApprox(
bs_spacing=[DEFAULT_LF_ZOOMS_MM] if debug else [DEFAULT_LF_ZOOMS_MM, DEFAULT_HF_ZOOMS_MM],
), n_procs=omp_nthreads, name="bs_filter")

# fmt: off
workflow.connect([
(inputnode, magnitude_wf, [("magnitude", "inputnode.magnitude")]),
(magnitude_wf, fmap_postproc_wf, [
("outputnode.fmap_mask", "inputnode.fmap_mask"),
("outputnode.fmap_ref", "inputnode.fmap_ref")]),
(magnitude_wf, bs_filter, [("outputnode.fmap_mask", "in_mask")]),
(magnitude_wf, outputnode, [
("outputnode.fmap_mask", "fmap_mask"),
("outputnode.fmap_ref", "fmap_ref"),
]),
(fmap_postproc_wf, outputnode, [("outputnode.out_fmap", "fmap")]),
(bs_filter, outputnode, [("out_field", "fmap")]),
])
# fmt: on

Expand All @@ -198,13 +202,12 @@ def init_fmap_wf(omp_nthreads=1, mode="phasediff", name="fmap_wf"):
("outputnode.fmap_ref", "inputnode.magnitude"),
("outputnode.fmap_mask", "inputnode.mask"),
]),
(phdiff_wf, fmap_postproc_wf, [
("outputnode.fieldmap", "inputnode.fmap"),
(phdiff_wf, bs_filter, [
("outputnode.fieldmap", "in_data"),
]),
])
# fmt: on
else:
from niworkflows.interfaces.nibabel import ApplyMask
from niworkflows.interfaces.images import IntraModalMerge

workflow.__desc__ = """\
Expand All @@ -215,13 +218,10 @@ def init_fmap_wf(omp_nthreads=1, mode="phasediff", name="fmap_wf"):
fmapmrg = pe.Node(
IntraModalMerge(zero_based_avg=False, hmc=False), name="fmapmrg"
)
applymsk = pe.Node(ApplyMask(), name="applymsk")
# fmt: off
workflow.connect([
(inputnode, fmapmrg, [("fieldmap", "in_files")]),
(fmapmrg, applymsk, [("out_avg", "in_file")]),
(magnitude_wf, applymsk, [("outputnode.fmap_mask", "in_mask")]),
(applymsk, fmap_postproc_wf, [("out_file", "inputnode.fmap")]),
(fmapmrg, bs_filter, [("out_avg", "in_data")]),
])
# fmt: on

Expand Down
2 changes: 1 addition & 1 deletion sdcflows/workflows/fit/tests/test_phdiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_phdiff(tmpdir, datadir, workdir, outdir, fmap_path):
wf = Workflow(
name=f"phdiff_{fmap_path[0].name.replace('.nii.gz', '').replace('-', '_')}"
)
phdiff_wf = init_fmap_wf(omp_nthreads=2)
phdiff_wf = init_fmap_wf(omp_nthreads=2, debug=True)
phdiff_wf.inputs.inputnode.fieldmap = fieldmaps
phdiff_wf.inputs.inputnode.magnitude = [
f.replace("diff", "1").replace("phase", "magnitude") for f, _ in fieldmaps
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ setup_requires =
setuptools_scm >= 3.4
toml
install_requires =
gridbspline
nibabel >=3.0.1
niflow-nipype1-workflows ~= 0.0.1
nipype >=1.5.1,<2.0
Expand Down

0 comments on commit 339cd5a

Please sign in to comment.