Skip to content

Commit

Permalink
Merge pull request #666 from anibalsolon/parallelize_conformation
Browse files Browse the repository at this point in the history
[ENH] Parallelize conformation
  • Loading branch information
effigies authored Aug 21, 2017
2 parents 3022d47 + 095fd30 commit 9542c4f
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 105 deletions.
2 changes: 1 addition & 1 deletion docs/anat/base.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
T1-weighted Conformation
------------------------

.. autoclass:: fmriprep.interfaces.images.ConformSeries
.. autoclass:: fmriprep.interfaces.images.Conform
:members:
:undoc-members:
:show-inheritance:
Expand Down
4 changes: 3 additions & 1 deletion fmriprep/interfaces/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from .bids import (
ReadSidecarJSON, DerivativesDataSink, BIDSDataGrabber, BIDSFreeSurferDir, BIDSInfo
)
from .images import IntraModalMerge, InvertT1w, ValidateImage, ConformSeries
from .images import (
IntraModalMerge, InvertT1w, ValidateImage, TemplateDimensions, Conform, Reorient
)
from .freesurfer import (
StructuralReference, MakeMidthickness, FSInjectBrainExtracted, FSDetectInputs
)
Expand Down
237 changes: 147 additions & 90 deletions fmriprep/interfaces/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import nilearn.image as nli

from niworkflows.nipype import logging
from niworkflows.nipype.utils.filemanip import fname_presuffix, copyfile
from niworkflows.nipype.utils.filemanip import fname_presuffix
from niworkflows.nipype.interfaces.base import (
traits, TraitedSpec, BaseInterfaceInputSpec,
File, InputMultiPath, OutputMultiPath)
Expand Down Expand Up @@ -124,7 +124,7 @@ def _run_interface(self, runtime):
return runtime


CONFORMSERIES_TEMPLATE = """\t\t<h3 class="elem-title">Anatomical Conformation</h3>
CONFORMATION_TEMPLATE = """\t\t<h3 class="elem-title">Anatomical Conformation</h3>
\t\t<ul class="elem-desc">
\t\t\t<li>Input T1w images: {n_t1w}</li>
\t\t\t<li>Output orientation: RAS</li>
Expand All @@ -134,28 +134,32 @@ def _run_interface(self, runtime):
{discard_list}
\t\t</ul>
"""

DISCARD_TEMPLATE = """\t\t\t\t<li><abbr title="{path}">{basename}</abbr></li>"""


class ConformSeriesInputSpec(BaseInterfaceInputSpec):
t1w_list = InputMultiPath(File(exists=True), mandatory=True,
desc='input T1w images')
class TemplateDimensionsInputSpec(BaseInterfaceInputSpec):
t1w_list = InputMultiPath(File(exists=True), mandatory=True, desc='input T1w images')
max_scale = traits.Float(3.0, usedefault=True,
desc='Maximum scaling factor in images to accept')


class ConformSeriesOutputSpec(TraitedSpec):
t1w_list = OutputMultiPath(exists=True, desc='conformed T1w images')
class TemplateDimensionsOutputSpec(TraitedSpec):
t1w_valid_list = OutputMultiPath(exists=True, desc='valid T1w images')
target_zooms = traits.Tuple(traits.Float, traits.Float, traits.Float,
desc='Target zoom information')
target_shape = traits.Tuple(traits.Int, traits.Int, traits.Int,
desc='Target shape information')
out_report = File(exists=True, desc='conformation report')


class ConformSeries(SimpleInterface):
"""Conform a series of T1w images to enable merging.
Performs two basic functions:
class TemplateDimensions(SimpleInterface):
"""
Finds template target dimensions for a series of T1w images, filtering low-resolution images,
if necessary.
#. Orient to RAS (left-right, posterior-anterior, inferior-superior)
#. Along each dimension, resample to minimum voxel size, maximum number of voxels
Along each axis, the minimum voxel size (zoom) and the maximum number of voxels (shape) are
found across images.
The ``max_scale`` parameter sets a bound on the degree of up-sampling performed.
By default, an image with a voxel size greater than 3x the smallest voxel size
Expand All @@ -164,37 +168,19 @@ class ConformSeries(SimpleInterface):
To select images that require no scaling (i.e. all have smallest voxel sizes),
set ``max_scale=1``.
"""
input_spec = ConformSeriesInputSpec
output_spec = ConformSeriesOutputSpec

def _prune_zooms(self, all_zooms, max_scale):
"""Iteratively prune zooms until all scaling factors will be within
``max_scale``.
Removes the largest zooms, and recalculates scaling factors with
remaining zooms.
"""
valid = np.ones(all_zooms.shape[0], dtype=bool)
while valid.any():
target_zooms = all_zooms[valid].min(axis=0)
scales = all_zooms[valid] / target_zooms
if np.all(scales < max_scale):
break

valid[valid] ^= np.any(scales == scales.max(), axis=1)

return valid
input_spec = TemplateDimensionsInputSpec
output_spec = TemplateDimensionsOutputSpec

def _generate_segment(self, discards, dims, zooms):
items = [DISCARD_TEMPLATE.format(path=path, basename=os.path.basename(path))
for path in discards]
discard_list = '\n'.join(["\t\t\t<ul>"] + items + ['\t\t\t</ul>']) if items else ''
zoom_fmt = '{:.02g}mm x {:.02g}mm x {:.02g}mm'.format(*zooms)
return CONFORMSERIES_TEMPLATE.format(n_t1w=len(self.inputs.t1w_list),
dims='x'.join(map(str, dims)),
zooms=zoom_fmt,
n_discards=len(discards),
discard_list=discard_list)
return CONFORMATION_TEMPLATE.format(n_t1w=len(self.inputs.t1w_list),
dims='x'.join(map(str, dims)),
zooms=zoom_fmt,
n_discards=len(discards),
discard_list=discard_list)

def _run_interface(self, runtime):
# Load images, orient as RAS, collect shape and zoom data
Expand All @@ -205,76 +191,147 @@ def _run_interface(self, runtime):
all_shapes = np.array([img.shape for img in reoriented])

# Identify images that would require excessive up-sampling
valid = self._prune_zooms(all_zooms, self.inputs.max_scale)
dropped_images = in_names[~valid]
valid = np.ones(all_zooms.shape[0], dtype=bool)
while valid.any():
target_zooms = all_zooms[valid].min(axis=0)
scales = all_zooms[valid] / target_zooms
if np.all(scales < self.inputs.max_scale):
break
valid[valid] ^= np.any(scales == scales.max(), axis=1)

# Ignore dropped images
valid_fnames = in_names[valid]
valid_imgs = orig_imgs[valid]
reoriented = reoriented[valid]
self._results['t1w_valid_list'] = valid_fnames.tolist()

# Set target shape information
target_zooms = all_zooms[valid].min(axis=0)
target_shape = all_shapes[valid].max(axis=0)

self._results['target_zooms'] = tuple(target_zooms.tolist())
self._results['target_shape'] = tuple(target_shape.tolist())

# Create report
dropped_images = in_names[~valid]
segment = self._generate_segment(dropped_images, target_shape, target_zooms)
out_report = os.path.join(runtime.cwd, 'report.html')
with open(out_report, 'w') as fobj:
fobj.write(segment)

self._results['out_report'] = out_report

return runtime


class ConformInputSpec(BaseInterfaceInputSpec):
in_file = File(exists=True, mandatory=True, desc='Input T1w image')
target_zooms = traits.Tuple(traits.Float, traits.Float, traits.Float,
desc='Target zoom information')
target_shape = traits.Tuple(traits.Int, traits.Int, traits.Int,
desc='Target shape information')


class ConformOutputSpec(TraitedSpec):
out_file = File(exists=True, desc='Conformed T1w image')


class Conform(SimpleInterface):
"""Conform a series of T1w images to enable merging.
Performs two basic functions:
#. Orient to RAS (left-right, posterior-anterior, inferior-superior)
#. Resample to target zooms (voxel sizes) and shape (number of voxels)
"""
input_spec = ConformInputSpec
output_spec = ConformOutputSpec

def _run_interface(self, runtime):
# Load image, orient as RAS
fname = self.inputs.in_file
orig_img = nb.load(fname)
reoriented = nb.as_closest_canonical(orig_img)

# Set target shape information
target_zooms = np.array(self.inputs.target_zooms)
target_shape = np.array(self.inputs.target_shape)
target_span = target_shape * target_zooms

out_names = []
for img, orig, fname in zip(reoriented, valid_imgs, valid_fnames):
zooms = np.array(img.header.get_zooms()[:3])
shape = np.array(img.shape)

xyz_unit = img.header.get_xyzt_units()[0]
if xyz_unit == 'unknown':
# Common assumption; if we're wrong, unlikely to be the only thing that breaks
xyz_unit = 'mm'
# Set a 0.05mm threshold to performing rescaling
atol = {'meter': 5e-5, 'mm': 0.05, 'micron': 50}[xyz_unit]

# Rescale => change zooms
# Resize => update image dimensions
rescale = not np.allclose(zooms, target_zooms, atol=atol)
resize = not np.all(shape == target_shape)
if rescale or resize:
target_affine = np.eye(4, dtype=img.affine.dtype)
if rescale:
scale_factor = target_zooms / zooms
target_affine[:3, :3] = img.affine[:3, :3].dot(np.diag(scale_factor))
else:
target_affine[:3, :3] = img.affine[:3, :3]

if resize:
# The shift is applied after scaling.
# Use a proportional shift to maintain relative position in dataset
size_factor = target_span / (zooms * shape)
# Use integer shifts to avoid unnecessary interpolation
offset = (img.affine[:3, 3] * size_factor - img.affine[:3, 3]).astype(int)
target_affine[:3, 3] = img.affine[:3, 3] + offset
else:
target_affine[:3, 3] = img.affine[:3, 3]
zooms = np.array(reoriented.header.get_zooms()[:3])
shape = np.array(reoriented.shape)

xyz_unit = reoriented.header.get_xyzt_units()[0]
if xyz_unit == 'unknown':
# Common assumption; if we're wrong, unlikely to be the only thing that breaks
xyz_unit = 'mm'

# Set a 0.05mm threshold to performing rescaling
atol = {'meter': 5e-5, 'mm': 0.05, 'micron': 50}[xyz_unit]

# Rescale => change zooms
# Resize => update image dimensions
rescale = not np.allclose(zooms, target_zooms, atol=atol)
resize = not np.all(shape == target_shape)
if rescale or resize:
target_affine = np.eye(4, dtype=reoriented.affine.dtype)
if rescale:
scale_factor = target_zooms / zooms
target_affine[:3, :3] = reoriented.affine[:3, :3].dot(np.diag(scale_factor))
else:
target_affine[:3, :3] = reoriented.affine[:3, :3]

if resize:
# The shift is applied after scaling.
# Use a proportional shift to maintain relative position in dataset
size_factor = target_span / (zooms * shape)
# Use integer shifts to avoid unnecessary interpolation
offset = (reoriented.affine[:3, 3] * size_factor - reoriented.affine[:3, 3])
target_affine[:3, 3] = reoriented.affine[:3, 3] + offset.astype(int)
else:
target_affine[:3, 3] = reoriented.affine[:3, 3]

data = nli.resample_img(img, target_affine, target_shape).get_data()
img = img.__class__(data, target_affine, img.header)
data = nli.resample_img(reoriented, target_affine, target_shape).get_data()
reoriented = reoriented.__class__(data, target_affine, reoriented.header)

# Image may be reoriented, rescaled, and/or resized
if reoriented is not orig_img:
out_name = fname_presuffix(fname, suffix='_ras', newpath=runtime.cwd)
reoriented.to_filename(out_name)
else:
out_name = fname

# Image may be reoriented, rescaled, and/or resized
if img is not orig:
img.to_filename(out_name)
else:
copyfile(fname, out_name, copy=True, use_hardlink=True)
self._results['out_file'] = out_name

out_names.append(out_name)
return runtime

self._results['t1w_list'] = out_names

# Create report
segment = self._generate_segment(dropped_images, target_shape, target_zooms)
class ReorientInputSpec(BaseInterfaceInputSpec):
in_file = File(exists=True, mandatory=True,
desc='Input T1w image')

out_report = os.path.join(runtime.cwd, 'report.html')
with open(out_report, 'w') as fobj:
fobj.write(segment)

self._results['out_report'] = out_report
class ReorientOutputSpec(TraitedSpec):
out_file = File(exists=True, desc='Reoriented T1w image')


class Reorient(SimpleInterface):
"""Reorient a T1w image to RAS (left-right, posterior-anterior, inferior-superior)"""
input_spec = ReorientInputSpec
output_spec = ReorientOutputSpec

def _run_interface(self, runtime):
# Load image, orient as RAS
fname = self.inputs.in_file
orig_img = nb.load(fname)
reoriented = nb.as_closest_canonical(orig_img)

# Image may be reoriented
if reoriented is not orig_img:
out_name = fname_presuffix(fname, suffix='_ras', newpath=runtime.cwd)
reoriented.to_filename(out_name)
else:
out_name = fname

self._results['out_file'] = out_name

return runtime

Expand Down
Loading

0 comments on commit 9542c4f

Please sign in to comment.