Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Parallelize conformation #666

Merged
merged 20 commits into from
Aug 21, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/anat/base.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
T1-weighted Conformation
------------------------

.. autoclass:: fmriprep.interfaces.images.ConformSeries
.. autoclass:: fmriprep.interfaces.images.Conform
:members:
:undoc-members:
:show-inheritance:
Expand Down
4 changes: 3 additions & 1 deletion fmriprep/interfaces/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from .bids import (
ReadSidecarJSON, DerivativesDataSink, BIDSDataGrabber, BIDSFreeSurferDir, BIDSInfo
)
from .images import IntraModalMerge, InvertT1w, ValidateImage, ConformSeries
from .images import (
IntraModalMerge, InvertT1w, ValidateImage, TemplateDimensions, Conform, Reorient
)
from .freesurfer import (
StructuralReference, MakeMidthickness, FSInjectBrainExtracted, FSDetectInputs
)
Expand Down
237 changes: 147 additions & 90 deletions fmriprep/interfaces/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import nilearn.image as nli

from niworkflows.nipype import logging
from niworkflows.nipype.utils.filemanip import fname_presuffix, copyfile
from niworkflows.nipype.utils.filemanip import fname_presuffix
from niworkflows.nipype.interfaces.base import (
traits, TraitedSpec, BaseInterfaceInputSpec,
File, InputMultiPath, OutputMultiPath)
Expand Down Expand Up @@ -124,7 +124,7 @@ def _run_interface(self, runtime):
return runtime


CONFORMSERIES_TEMPLATE = """\t\t<h3 class="elem-title">Anatomical Conformation</h3>
CONFORMATION_TEMPLATE = """\t\t<h3 class="elem-title">Anatomical Conformation</h3>
\t\t<ul class="elem-desc">
\t\t\t<li>Input T1w images: {n_t1w}</li>
\t\t\t<li>Output orientation: RAS</li>
Expand All @@ -134,28 +134,32 @@ def _run_interface(self, runtime):
{discard_list}
\t\t</ul>
"""

DISCARD_TEMPLATE = """\t\t\t\t<li><abbr title="{path}">{basename}</abbr></li>"""


class ConformSeriesInputSpec(BaseInterfaceInputSpec):
t1w_list = InputMultiPath(File(exists=True), mandatory=True,
desc='input T1w images')
class TemplateDimensionsInputSpec(BaseInterfaceInputSpec):
t1w_list = InputMultiPath(File(exists=True), mandatory=True, desc='input T1w images')
max_scale = traits.Float(3.0, usedefault=True,
desc='Maximum scaling factor in images to accept')


class ConformSeriesOutputSpec(TraitedSpec):
t1w_list = OutputMultiPath(exists=True, desc='conformed T1w images')
class TemplateDimensionsOutputSpec(TraitedSpec):
t1w_valid_list = OutputMultiPath(exists=True, desc='valid T1w images')
target_zooms = traits.Tuple(traits.Float, traits.Float, traits.Float,
desc='Target zoom information')
target_shape = traits.Tuple(traits.Int, traits.Int, traits.Int,
desc='Target shape information')
out_report = File(exists=True, desc='conformation report')


class ConformSeries(SimpleInterface):
"""Conform a series of T1w images to enable merging.

Performs two basic functions:
class TemplateDimensions(SimpleInterface):
"""
Finds template target dimensions for a series of T1w images, filtering low-resolution images,
if necessary.

#. Orient to RAS (left-right, posterior-anterior, inferior-superior)
#. Along each dimension, resample to minimum voxel size, maximum number of voxels
Along each axis, the minimum voxel size (zoom) and the maximum number of voxels (shape) are
found across images.

The ``max_scale`` parameter sets a bound on the degree of up-sampling performed.
By default, an image with a voxel size greater than 3x the smallest voxel size
Expand All @@ -164,37 +168,19 @@ class ConformSeries(SimpleInterface):
To select images that require no scaling (i.e. all have smallest voxel sizes),
set ``max_scale=1``.
"""
input_spec = ConformSeriesInputSpec
output_spec = ConformSeriesOutputSpec

def _prune_zooms(self, all_zooms, max_scale):
"""Iteratively prune zooms until all scaling factors will be within
``max_scale``.

Removes the largest zooms, and recalculates scaling factors with
remaining zooms.
"""
valid = np.ones(all_zooms.shape[0], dtype=bool)
while valid.any():
target_zooms = all_zooms[valid].min(axis=0)
scales = all_zooms[valid] / target_zooms
if np.all(scales < max_scale):
break

valid[valid] ^= np.any(scales == scales.max(), axis=1)

return valid
input_spec = TemplateDimensionsInputSpec
output_spec = TemplateDimensionsOutputSpec

def _generate_segment(self, discards, dims, zooms):
items = [DISCARD_TEMPLATE.format(path=path, basename=os.path.basename(path))
for path in discards]
discard_list = '\n'.join(["\t\t\t<ul>"] + items + ['\t\t\t</ul>']) if items else ''
zoom_fmt = '{:.02g}mm x {:.02g}mm x {:.02g}mm'.format(*zooms)
return CONFORMSERIES_TEMPLATE.format(n_t1w=len(self.inputs.t1w_list),
dims='x'.join(map(str, dims)),
zooms=zoom_fmt,
n_discards=len(discards),
discard_list=discard_list)
return CONFORMATION_TEMPLATE.format(n_t1w=len(self.inputs.t1w_list),
dims='x'.join(map(str, dims)),
zooms=zoom_fmt,
n_discards=len(discards),
discard_list=discard_list)

def _run_interface(self, runtime):
# Load images, orient as RAS, collect shape and zoom data
Expand All @@ -205,76 +191,147 @@ def _run_interface(self, runtime):
all_shapes = np.array([img.shape for img in reoriented])

# Identify images that would require excessive up-sampling
valid = self._prune_zooms(all_zooms, self.inputs.max_scale)
dropped_images = in_names[~valid]
valid = np.ones(all_zooms.shape[0], dtype=bool)
while valid.any():
target_zooms = all_zooms[valid].min(axis=0)
scales = all_zooms[valid] / target_zooms
if np.all(scales < self.inputs.max_scale):
break
valid[valid] ^= np.any(scales == scales.max(), axis=1)

# Ignore dropped images
valid_fnames = in_names[valid]
valid_imgs = orig_imgs[valid]
reoriented = reoriented[valid]
self._results['t1w_valid_list'] = valid_fnames.tolist()

# Set target shape information
target_zooms = all_zooms[valid].min(axis=0)
target_shape = all_shapes[valid].max(axis=0)

self._results['target_zooms'] = tuple(target_zooms.tolist())
self._results['target_shape'] = tuple(target_shape.tolist())

# Create report
dropped_images = in_names[~valid]
segment = self._generate_segment(dropped_images, target_shape, target_zooms)
out_report = os.path.join(runtime.cwd, 'report.html')
with open(out_report, 'w') as fobj:
fobj.write(segment)

self._results['out_report'] = out_report

return runtime


class ConformInputSpec(BaseInterfaceInputSpec):
in_file = File(exists=True, mandatory=True, desc='Input T1w image')
target_zooms = traits.Tuple(traits.Float, traits.Float, traits.Float,
desc='Target zoom information')
target_shape = traits.Tuple(traits.Int, traits.Int, traits.Int,
desc='Target shape information')


class ConformOutputSpec(TraitedSpec):
out_file = File(exists=True, desc='Conformed T1w image')


class Conform(SimpleInterface):
"""Conform a series of T1w images to enable merging.

Performs two basic functions:

#. Orient to RAS (left-right, posterior-anterior, inferior-superior)
#. Resample to target zooms (voxel sizes) and shape (number of voxels)
"""
input_spec = ConformInputSpec
output_spec = ConformOutputSpec

def _run_interface(self, runtime):
# Load image, orient as RAS
fname = self.inputs.in_file
orig_img = nb.load(fname)
reoriented = nb.as_closest_canonical(orig_img)

# Set target shape information
target_zooms = np.array(self.inputs.target_zooms)
target_shape = np.array(self.inputs.target_shape)
target_span = target_shape * target_zooms

out_names = []
for img, orig, fname in zip(reoriented, valid_imgs, valid_fnames):
zooms = np.array(img.header.get_zooms()[:3])
shape = np.array(img.shape)

xyz_unit = img.header.get_xyzt_units()[0]
if xyz_unit == 'unknown':
# Common assumption; if we're wrong, unlikely to be the only thing that breaks
xyz_unit = 'mm'
# Set a 0.05mm threshold to performing rescaling
atol = {'meter': 5e-5, 'mm': 0.05, 'micron': 50}[xyz_unit]

# Rescale => change zooms
# Resize => update image dimensions
rescale = not np.allclose(zooms, target_zooms, atol=atol)
resize = not np.all(shape == target_shape)
if rescale or resize:
target_affine = np.eye(4, dtype=img.affine.dtype)
if rescale:
scale_factor = target_zooms / zooms
target_affine[:3, :3] = img.affine[:3, :3].dot(np.diag(scale_factor))
else:
target_affine[:3, :3] = img.affine[:3, :3]

if resize:
# The shift is applied after scaling.
# Use a proportional shift to maintain relative position in dataset
size_factor = target_span / (zooms * shape)
# Use integer shifts to avoid unnecessary interpolation
offset = (img.affine[:3, 3] * size_factor - img.affine[:3, 3]).astype(int)
target_affine[:3, 3] = img.affine[:3, 3] + offset
else:
target_affine[:3, 3] = img.affine[:3, 3]
zooms = np.array(reoriented.header.get_zooms()[:3])
shape = np.array(reoriented.shape)

xyz_unit = reoriented.header.get_xyzt_units()[0]
if xyz_unit == 'unknown':
# Common assumption; if we're wrong, unlikely to be the only thing that breaks
xyz_unit = 'mm'

# Set a 0.05mm threshold to performing rescaling
atol = {'meter': 5e-5, 'mm': 0.05, 'micron': 50}[xyz_unit]

# Rescale => change zooms
# Resize => update image dimensions
rescale = not np.allclose(zooms, target_zooms, atol=atol)
resize = not np.all(shape == target_shape)
if rescale or resize:
target_affine = np.eye(4, dtype=reoriented.affine.dtype)
if rescale:
scale_factor = target_zooms / zooms
target_affine[:3, :3] = reoriented.affine[:3, :3].dot(np.diag(scale_factor))
else:
target_affine[:3, :3] = reoriented.affine[:3, :3]

if resize:
# The shift is applied after scaling.
# Use a proportional shift to maintain relative position in dataset
size_factor = target_span / (zooms * shape)
# Use integer shifts to avoid unnecessary interpolation
offset = (reoriented.affine[:3, 3] * size_factor - reoriented.affine[:3, 3])
target_affine[:3, 3] = reoriented.affine[:3, 3] + offset.astype(int)
else:
target_affine[:3, 3] = reoriented.affine[:3, 3]

data = nli.resample_img(img, target_affine, target_shape).get_data()
img = img.__class__(data, target_affine, img.header)
data = nli.resample_img(reoriented, target_affine, target_shape).get_data()
reoriented = reoriented.__class__(data, target_affine, reoriented.header)

# Image may be reoriented, rescaled, and/or resized
if reoriented is not orig_img:
out_name = fname_presuffix(fname, suffix='_ras', newpath=runtime.cwd)
reoriented.to_filename(out_name)
else:
out_name = fname

# Image may be reoriented, rescaled, and/or resized
if img is not orig:
img.to_filename(out_name)
else:
copyfile(fname, out_name, copy=True, use_hardlink=True)
self._results['out_file'] = out_name

out_names.append(out_name)
return runtime

self._results['t1w_list'] = out_names

# Create report
segment = self._generate_segment(dropped_images, target_shape, target_zooms)
class ReorientInputSpec(BaseInterfaceInputSpec):
in_file = File(exists=True, mandatory=True,
desc='Input T1w image')

out_report = os.path.join(runtime.cwd, 'report.html')
with open(out_report, 'w') as fobj:
fobj.write(segment)

self._results['out_report'] = out_report
class ReorientOutputSpec(TraitedSpec):
out_file = File(exists=True, desc='Reoriented T1w image')


class Reorient(SimpleInterface):
"""Reorient a T1w image to RAS (left-right, posterior-anterior, inferior-superior)"""
input_spec = ReorientInputSpec
output_spec = ReorientOutputSpec

def _run_interface(self, runtime):
# Load image, orient as RAS
fname = self.inputs.in_file
orig_img = nb.load(fname)
reoriented = nb.as_closest_canonical(orig_img)

# Image may be reoriented
if reoriented is not orig_img:
out_name = fname_presuffix(fname, suffix='_ras', newpath=runtime.cwd)
reoriented.to_filename(out_name)
else:
out_name = fname

self._results['out_file'] = out_name

return runtime

Expand Down
Loading