To do:  
- allow `--field` to accept `File` instead of `Str` and reformat command line argument
- add input options for correcting slice-to-volume movement
    - `--fwhm`, `--mporder`, `--slspec`, `--json`, `--s2v_niter`, `--s2v_lambda`, `--s2v_interp`
- add multi-band factor options
    - `--mb`, `--mb_offs`
- add additional enumerator options for modeling eddy current (movement)
- update default options so they are captured in `command.txt`
- add additional options for classifying outliers
    - `--ol_nstd`, `--ol_nvox`, `--ol_type`, `--ol_pos`, `--ol_sqr`
- add additional options for correcting susceptibility-by-movement interactions
    - `--estimate_move_by_susceptibility`, `--mbs_niter`, `--mbs_lambda`, `--mbs_ksp`, `--dont_sep_offs_move`
- add new outputs from FSL 6.0.0 and 6.0.1

In [None]:
import os

from bids import BIDSLayout
from nipype.pipeline import engine as pe
from nipype.interfaces import fsl, utility as niu

In [None]:
def init_eddy_wf(num_threads=None):
    """
    Build a nipype workflow that runs FSL TOPUP followed by FSL eddy.

    The workflow is driven through ``inputnode``, whose fields are
    ``dwi_file``, ``dwi_meta``, ``bvec_file``, ``bval_file``, ``ap_file``
    and ``pa_file``.  It derives the eddy index and acquisition-parameter
    files from the DWI image and its metadata, masks an averaged b0 with
    BET, estimates the susceptibility field with TOPUP from the merged
    ``ap_file``/``pa_file`` pair, and feeds everything to eddy.  The
    corrected series is exposed as ``outputnode.out_file``.

    Parameters
    ----------
    num_threads : int or None, optional
        Number of threads handed to eddy.  When ``None`` (the default) the
        number of CPUs reported by ``multiprocessing.cpu_count()`` is used.

    Returns
    -------
    wf : nipype.pipeline.engine.Workflow
        The assembled workflow, named ``eddy_wf``.
    """
    import multiprocessing

    wf = pe.Workflow(name="eddy_wf")

    inputnode = pe.Node(
        niu.IdentityInterface(
            fields=[
                "dwi_file",
                "dwi_meta",
                "bvec_file",
                "bval_file",
                "ap_file",
                "pa_file",
            ]
        ),
        name="inputnode",
    )

    outputnode = pe.Node(niu.IdentityInterface(fields=["out_file"]), name="outputnode")

    def gen_index(in_file):
        """Write a one-row text file holding a ``1`` for every volume of *in_file*."""
        import os
        import numpy as np
        import nibabel as nib
        from nipype.utils.filemanip import fname_presuffix

        out_file = fname_presuffix(
            in_file,
            suffix="_index.txt",
            newpath=os.path.abspath("."),
            use_ext=False)

        # The image shape is enough to count volumes; the voxel data is not needed.
        vols = nib.load(in_file).shape[-1]
        np.savetxt(out_file, np.ones((1, vols)), fmt="%i")
        return out_file

    gen_idx = pe.Node(
        niu.Function(
            input_names=["in_file"],
            output_names=["out_file"],
            function=gen_index),
        name="gen_index")

    def gen_acqparams(in_file, metadata):
        """
        Write a two-row acquisition-parameter file for *in_file*.

        Row 1 encodes ``metadata["PhaseEncodingDirection"]`` and row 2 the
        opposite direction; both rows use ``metadata["TotalReadoutTime"]``.
        Raises ``ValueError`` when either metadata entry is unusable.
        """
        import os
        from nipype.utils.filemanip import fname_presuffix

        out_file = fname_presuffix(
            in_file,
            suffix="_acqparams.txt",
            newpath=os.path.abspath("."),
            use_ext=False)

        acq_param_dict = {
            "j": "0 1 0 %.7f",
            "j-": "0 -1 0 %.7f",
            "i": "1 0 0 %.7f",
            "i-": "-1 0 0 %.7f",
            "k": "0 0 1 %.7f",
            "k-": "0 0 -1 %.7f"}

        pe_dir = metadata.get("PhaseEncodingDirection")
        total_readout = metadata.get("TotalReadoutTime")

        # Fail with an explicit message instead of an opaque KeyError/TypeError.
        if pe_dir not in acq_param_dict:
            raise ValueError(
                "Unsupported or missing PhaseEncodingDirection: %r" % (pe_dir,))
        if total_readout is None:
            raise ValueError("TotalReadoutTime is missing from the metadata")

        # "j" <-> "j-", "i" <-> "i-", "k" <-> "k-"
        if pe_dir.endswith("-"):
            opposite_pe_dir = pe_dir[0]
        else:
            opposite_pe_dir = "%s-" % pe_dir

        acq_param_lines = "\n".join((
            acq_param_dict[pe_dir] % total_readout,
            acq_param_dict[opposite_pe_dir] % total_readout))

        with open(out_file, "w") as f:
            f.write(acq_param_lines)

        return out_file

    acqp = pe.Node(
        niu.Function(
            input_names=["in_file", "metadata"],
            output_names=["out_file"],
            function=gen_acqparams,
        ),
        name="acqp",
    )

    def b0_average(in_dwi, in_bval, b0_thresh=0, out_file=None):
        """
        A function that averages the *b0* volumes from a DWI dataset.
        As current dMRI data are being acquired with all b-values > 0.0,
        the *lowb* volumes are selected by specifying the parameter b0_thresh.

        Raises ``ValueError`` when no volume has a b-value <= ``b0_thresh``.

        .. warning:: *b0* should be already registered (head motion artifact
        should be corrected).
        """
        import os
        import numpy as np
        import nibabel as nib
        from nipype.utils import NUMPY_MMAP
        from nipype.utils.filemanip import fname_presuffix

        if out_file is None:
            out_file = fname_presuffix(
                in_dwi, suffix="_avg_b0", newpath=os.path.abspath(".")
            )

        imgs = nib.four_to_three(nib.load(in_dwi, mmap=NUMPY_MMAP))
        # atleast_1d keeps a single-value bval file indexable.
        bval = np.atleast_1d(np.loadtxt(in_bval))
        index = np.flatnonzero(bval <= b0_thresh).tolist()
        if not index:
            raise ValueError(
                "No volume in %s has a b-value <= %s" % (in_bval, b0_thresh))

        # np.asanyarray(img.dataobj) is nibabel's documented replacement for
        # the deprecated img.get_data().
        b0s = [np.asanyarray(imgs[i].dataobj).astype(np.float32) for i in index]
        b0 = np.average(np.array(b0s), axis=0)

        hdr = imgs[0].header.copy()
        hdr.set_data_shape(b0.shape)
        hdr.set_xyzt_units("mm")
        hdr.set_data_dtype(np.float32)
        nib.Nifti1Image(b0, imgs[0].affine, hdr).to_filename(out_file)
        return out_file

    avg_b0_0 = pe.Node(
        niu.Function(
            input_names=["in_dwi", "in_bval", "b0_thresh"],
            output_names=["out_file"],
            function=b0_average,
        ),
        name="b0_avg_pre",
    )

    bet_dwi0 = pe.Node(
        fsl.BET(frac=0.6, mask=True, robust=True),
        name="bet_dwi_pre",
    )

    list_merge = pe.Node(niu.Merge(numinputs=2), name="list_merge")

    merge = pe.Node(fsl.Merge(dimension="t"), name="mergeAPPA")

    topup = pe.Node(fsl.TOPUP(), name="topup")

    ecc = pe.Node(
        fsl.Eddy(repol=True, cnr_maps=True, residuals=True),
        name="fsl_eddy",
    )

    # Use every available CPU unless the caller asked for a specific count.
    if num_threads is None:
        num_threads = multiprocessing.cpu_count()
    ecc.inputs.num_threads = num_threads

    # NOTE(review): no `cuda` name is imported in the visible part of this
    # file, so unless another cell defines it this lookup raises NameError
    # and use_cuda ends up False.  Provide a real GPU probe to enable CUDA.
    try:
        use_cuda = bool(cuda.gpus)  # noqa: F821
    except Exception:
        use_cuda = False
    ecc.inputs.use_cuda = use_cuda

    # NOTE(review): the acqp file lists the DWI phase-encoding direction first
    # and its opposite second, while the merged image lists ap_file first and
    # pa_file second.  This presumably requires ap_file to share the DWI's
    # phase-encoding direction -- TODO confirm against the dataset.
    wf.connect(
        [(inputnode, gen_idx, [("dwi_file", "in_file")]),
         (inputnode, acqp, [("dwi_file", "in_file"),
                            ("dwi_meta", "metadata")]),
         (inputnode, avg_b0_0, [("dwi_file", "in_dwi"),
                                ("bval_file", "in_bval")]),
         (avg_b0_0, bet_dwi0, [("out_file", "in_file")]),
         (inputnode, list_merge, [("ap_file", "in1"),
                                  ("pa_file", "in2")]),
         (list_merge, merge, [("out", "in_files")]),
         (merge, topup, [("merged_file", "in_file")]),
         (acqp, topup, [("out_file", "encoding_file")]),
         (topup, ecc, [("out_fieldcoef", "in_topup_fieldcoef"),
                       ("out_movpar", "in_topup_movpar")]),
         (acqp, ecc, [("out_file", "in_acqp")]),
         (inputnode, ecc, [("dwi_file", "in_file"),
                           ("bval_file", "in_bval"),
                           ("bvec_file", "in_bvec")]),
         (bet_dwi0, ecc, [("mask_file", "in_mask")]),
         (gen_idx, ecc, [("out_file", "in_index")]),
         # Expose the corrected series; outputnode was previously left dangling.
         (ecc, outputnode, [("out_corrected", "out_file")]),
        ]
    )

    return wf

In [None]:
# Resolve the `data` directory that sits beside the current working directory
# and index it as a BIDS dataset.
data_dir = os.path.abspath(os.path.join(os.pardir, "data"))
layout = BIDSLayout(data_dir)

In [None]:
# File paths of every NIfTI image (compressed or not) in the `dwi` datatype.
dwi_files = layout.get(
    return_type="file",
    datatype="dwi",
    extension=["nii.gz", "nii"],
)

In [None]:
# Pick one DWI run and look up its gradient tables and sidecar metadata.
dwi_file = dwi_files[-3]
bvec_file = layout.get_bvec(dwi_file)
bval_file = layout.get_bval(dwi_file)
dwi_meta = layout.get_metadata(dwi_file)

# Fieldmaps associated with the DWI run, each annotated with its own metadata.
fmaps = layout.get_fieldmap(dwi_file, return_list=True)
for fmap in fmaps:
    fmap["metadata"] = layout.get_metadata(fmap[fmap["suffix"]])

# Two EPI fieldmaps are needed below; fail clearly rather than with IndexError.
if len(fmaps) < 2:
    raise ValueError(
        "Expected at least two fieldmaps for %s, found %d" % (dwi_file, len(fmaps)))

# NOTE(review): assumes the first fieldmap is the AP acquisition and the
# second the PA one -- confirm via each fmap["metadata"] entry.
ap_file = fmaps[0]['epi']
pa_file = fmaps[1]['epi']

In [None]:
# Build the workflow and keep its working directory in the current directory.
test_wf = init_eddy_wf()
test_wf.base_dir = os.getcwd()

# Feed the files selected above into the workflow's input node.
inputspec = test_wf.get_node("inputnode")
inputspec.inputs.trait_set(
    dwi_file=dwi_file,
    dwi_meta=dwi_meta,
    bvec_file=bvec_file,
    bval_file=bval_file,
    ap_file=ap_file,
    pa_file=pa_file,
)

test_wf.write_graph(graph2use="colored")

# Keep every intermediate output and input so the run can be inspected afterwards.
test_wf.config["execution"].update(
    remove_unnecessary_outputs=False,
    keep_inputs=True,
)

test_wf.run()

In [None]:
from nipype.interfaces.base import traits, TraitedSpec, File, isdefined
from nipype.interfaces.fsl.base import FSLCommand, FSLCommandInputSpec

class EddyInputSpec(FSLCommandInputSpec):
    """
    Input specification for the FSL ``eddy`` command-line tool.

    Options that declare ``requires=[...]`` on an optional input deliberately
    do not set ``usedefault=True``: a defaulted value counts as "set" and
    would make nipype demand the required companion input on every run.
    """

    # ---- mandatory inputs -------------------------------------------------
    in_file = File(
        exists=True,
        mandatory=True,
        argstr='--imain=%s',
        desc=('File containing all the images to estimate '
              'distortions for'))
    in_mask = File(
        exists=True,
        mandatory=True,
        argstr='--mask=%s',
        desc='Mask to indicate brain')
    in_index = File(
        exists=True,
        mandatory=True,
        argstr='--index=%s',
        desc=('File containing indices for all volumes in --imain '
              'into --acqp and --topup'))
    in_acqp = File(
        exists=True,
        mandatory=True,
        argstr='--acqp=%s',
        desc='File containing acquisition parameters')
    in_bvec = File(
        exists=True,
        mandatory=True,
        argstr='--bvecs=%s',
        desc=('File containing the b-vectors for all volumes in '
              '--imain'))
    in_bval = File(
        exists=True,
        mandatory=True,
        argstr='--bvals=%s',
        desc=('File containing the b-values for all volumes in '
              '--imain'))

    # ---- output basename and session ---------------------------------------
    out_base = traits.Str(
        'eddy_corrected',
        argstr='--out=%s',
        usedefault=True,
        desc=('basename for output (warped) image'))
    session = File(
        exists=True,
        argstr='--session=%s',
        desc=('File containing session indices for all volumes in '
              '--imain'))

    # ---- susceptibility field (TOPUP or external fieldmap) -----------------
    in_topup_fieldcoef = File(
        exists=True,  # pass file or not?
        argstr='--topup=%s',
        requires=['in_topup_movpar'],
        desc=('topup file containing the field '
              'coefficients'))
    # No argstr: this input is never placed on the command line by itself.
    in_topup_movpar = File(
        exists=True,
        requires=['in_topup_fieldcoef'],
        desc='topup movpar.txt file')
    field = traits.Str(
        argstr='--field=%s',
        desc='NonTOPUP fieldmap scaled in Hz - filename has '
        'to be provided without an extension. TOPUP is '
        'strongly recommended')
    field_mat = File(
        exists=True,
        argstr='--field_mat=%s',
        desc='Matrix that specifies the relative locations of '
        'the field specified by --field and first volume '
        'in file --imain')

    # ---- multi-band ---------------------------------------------------------
    mb = traits.Int(
        argstr='--mb=%s', desc='Multi-band factor', min_ver='5.0.10')
    mb_offs = traits.Enum(
        -1,
        1,
        argstr='--mb_offs=%s',
        desc='Multi-band offset (-1 if bottom slice removed, 1 if top slice removed)',
        requires=['mb'],
        min_ver='5.0.10')

    # ---- slice-to-volume movement ------------------------------------------
    slspec = File(
        exists=True,
        argstr='--slspec=%s',
        desc='Name of text file completely specifying slice/group acquisition',
        xor=['json'],
        min_ver='5.0.11')
    json = File(
        exists=True,
        argstr='--json=%s',
        desc='Name of .json text file with information about slice timing',
        xor=['slspec'],
        min_ver='6.0.1')
    mporder = traits.Int(
        argstr='--mporder=%s',
        desc='Order of slice-to-vol movement model',
        requires=['slspec'],
        min_ver='5.0.11')
    s2v_lambda = traits.Int(
        1,
        argstr='--s2v_lambda=%s',
        desc='Regularisation weight for slice-to-vol movement (reasonable range 1-10)',
        requires=['slspec'],
        min_ver='5.0.11')

    # ---- eddy-current model -------------------------------------------------
    flm = traits.Enum(
        'linear',
        'quadratic',
        'cubic',  # added movement option in 5.0.11
        argstr='--flm=%s',
        desc='First level EC model')

    slm = traits.Enum(
        'none',
        'linear',
        'quadratic',
        argstr='--slm=%s',
        desc='Second level EC model')

    fwhm = traits.Float(
        desc=('FWHM for conditioning filter when estimating '
              'the parameters'),
        argstr='--fwhm=%s')

    niter = traits.Int(5, usedefault=True,
                       argstr='--niter=%s', desc='Number of iterations')
    s2v_niter = traits.Int(
        5,
        argstr='--s2v_niter=%s',
        desc='Number of iterations for slice-to-vol',
        requires=['slspec'],
        min_ver='5.0.11')

    fep = traits.Bool(
        False, argstr='--fep', desc='Fill empty planes in x- or y-directions')

    interp = traits.Enum(
        'spline',
        'trilinear',
        argstr='--interp=%s',
        desc='Interpolation model for estimation step')
    s2v_interp = traits.Enum(
        'trilinear',
        'spline',
        argstr='--s2v_interp=%s',
        desc='Slice-to-vol interpolation model for estimation step',
        requires=['slspec'],
        min_ver='5.0.11')

    method = traits.Enum(
        'jac',
        'lsr',
        argstr='--resamp=%s',
        desc=('Final resampling method (jacobian/least '
              'squares)'))

    # ---- hyperparameter estimation -----------------------------------------
    nvoxhp = traits.Int(
        1000, usedefault=True,
        argstr='--nvoxhp=%s',
        desc=('# of voxels used to estimate the '
              'hyperparameters'))

    initrand = traits.Bool(
        False,
        argstr='--initrand',
        desc='Resets rand for when selecting voxels',
        min_ver='5.0.10')

    fudge_factor = traits.Float(
        10.0, usedefault=True,
        argstr='--ff=%s',
        desc=('Fudge factor for hyperparameter '
              'error variance'))

    # ---- outlier detection and replacement ---------------------------------
    repol = traits.Bool(
        False, argstr='--repol', desc='Detect and replace outlier slices')

    outlier_nstd = traits.Int(
        4,
        argstr='--ol_nstd=%s',
        desc='Number of std off to qualify as outlier',
        requires=['repol'],
        min_ver='5.0.10',
    )
    outlier_nvox = traits.Int(
        250,
        argstr='--ol_nvox=%s',
        desc='Min # of voxels in a slice for inclusion in outlier detection',
        requires=['repol'],
        min_ver='5.0.10',
    )
    outlier_type = traits.Enum(
        'sw',
        'gw',
        'both',
        argstr='--ol_type=%s',
        desc='Type of outliers, slicewise (sw), groupwise (gw) or both (both)',
        requires=['repol'],
        min_ver='5.0.10',
    )
    outlier_pos = traits.Bool(
        False,
        argstr='--ol_pos',
        desc='Consider both positive and negative outliers if set',
        requires=['repol'],
        min_ver='5.0.10',
    )
    outlier_sqr = traits.Bool(
        False,
        argstr='--ol_sqr',
        desc='Consider outliers among sums-of-squared differences if set',
        requires=['repol'],
        min_ver='5.0.10',
    )

    # ---- susceptibility-by-movement interaction ----------------------------
    estimate_move_by_susceptibility = traits.Bool(
        False,
        argstr='--estimate_move_by_susceptibility',
        desc='Estimate how susceptibility field changes with subject movement',
        min_ver='6.0.1',
    )

    mbs_niter = traits.Int(
        10,
        argstr='--mbs_niter=%s',
        desc='Number of iterations for MBS estimation',
        requires=['estimate_move_by_susceptibility'],
        min_ver='6.0.1',
    )
    mbs_lambda = traits.Int(
        10,
        argstr='--mbs_lambda=%s',
        desc='Weighting of regularisation for MBS estimation',
        requires=['estimate_move_by_susceptibility'],
        min_ver='6.0.1',
    )
    mbs_ksp = traits.Int(
        10,
        argstr='--mbs_ksp=%smm',
        desc='Knot-spacing for MBS field estimation',
        requires=['estimate_move_by_susceptibility'],
        min_ver='6.0.1',
    )

    # ---- miscellaneous switches --------------------------------------------
    dont_sep_offs_move = traits.Bool(
        False,
        argstr='--dont_sep_offs_move',
        desc=('Do NOT attempt to separate '
              'field offset from subject '
              'movement'))

    dont_peas = traits.Bool(
        False,
        argstr='--dont_peas',
        desc='Do NOT perform a post-eddy alignment of '
        'shells')

    is_shelled = traits.Bool(
        False,
        argstr='--data_is_shelled',
        desc='Override internal check to ensure that '
        'data are acquired on a set of b-value '
        'shells')

    # ---- execution control (no argstr, so not command-line flags) ----------
    num_threads = traits.Int(
        1,
        usedefault=True,
        nohash=True,
        desc='Number of openmp threads to use')

    use_cuda = traits.Bool(False, desc='Run eddy using cuda gpu')

    # ---- extra outputs ------------------------------------------------------
    cnr_maps = traits.Bool(
        False, desc='Output CNR-Maps', argstr='--cnr_maps', min_ver='5.0.10')
    residuals = traits.Bool(
        False, desc='Output Residuals', argstr='--residuals', min_ver='5.0.10')


class EddyOutputSpec(TraitedSpec):
    """Output specification listing the files the ``Eddy`` interface can report."""

    out_corrected = File(
        exists=True, desc='4D image file containing all the corrected volumes')
    out_parameter = File(
        exists=True,
        desc=('text file with parameters defining the field and '
              'movement for each scan'))
    out_rotated_bvecs = File(
        exists=True, desc='File containing rotated b-vectors for all volumes')
    # Double-quoted so the inner 'total movement' quotes do not end the string.
    out_movement_rms = File(
        exists=True, desc="Summary of the 'total movement' in each volume")
    out_restricted_movement_rms = File(
        exists=True,
        desc=("Summary of the 'total movement' in each volume "
              "disregarding translation in the PE direction"))
    out_shell_alignment_parameters = File(
        exists=True,
        desc=('File containing rigid body movement parameters '
              'between the different shells as estimated by a '
              'post-hoc mutual information based registration'))
    out_outlier_report = File(
        exists=True,
        desc=('Text-file with a plain language report on what '
              'outlier slices eddy has found'))
    out_cnr_maps = File(
        exists=True, desc='path/name of file with the cnr_maps')
    out_residuals = File(
        exists=True, desc='path/name of file with the residuals')


class Eddy(FSLCommand):
    """
    Interface for FSL eddy, a tool for estimating and correcting eddy
    currents induced distortions. `User guide
    <http://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Eddy/UsersGuide>`_ and
    `more info regarding acqp file
    <http://fsl.fmrib.ox.ac.uk/fsl/fslwiki/eddy/Faq#How_do_I_know_what_to_put_into_my_--acqp_file>`_.

    The executable is ``eddy_cuda`` when ``use_cuda`` is true and
    ``eddy_openmp`` otherwise; at run time ``eddy_openmp`` is swapped for
    plain ``eddy`` when ``$FSLDIR/bin/eddy_openmp`` does not exist.
    ``num_threads`` is exported to the command through ``OMP_NUM_THREADS``.

    Examples
    --------
    >>> from nipype.interfaces.fsl import Eddy
    >>> eddy = Eddy()
    >>> eddy.inputs.in_file = 'epi.nii'
    >>> eddy.inputs.in_mask  = 'epi_mask.nii'
    >>> eddy.inputs.in_index = 'epi_index.txt'
    >>> eddy.inputs.in_acqp  = 'epi_acqp.txt'
    >>> eddy.inputs.in_bvec  = 'bvecs.scheme'
    >>> eddy.inputs.in_bval  = 'bvals.scheme'
    >>> eddy.inputs.use_cuda = True
    >>> eddy.cmdline # doctest: +ELLIPSIS
    'eddy_cuda --ff=10.0 --acqp=epi_acqp.txt --bvals=bvals.scheme \
--bvecs=bvecs.scheme --imain=epi.nii --index=epi_index.txt \
--mask=epi_mask.nii --niter=5 --nvoxhp=1000 --out=.../eddy_corrected'
    >>> eddy.inputs.use_cuda = False
    >>> eddy.cmdline # doctest: +ELLIPSIS
    'eddy_openmp --ff=10.0 --acqp=epi_acqp.txt --bvals=bvals.scheme \
--bvecs=bvecs.scheme --imain=epi.nii --index=epi_index.txt \
--mask=epi_mask.nii --niter=5 --nvoxhp=1000 --out=.../eddy_corrected'
    >>> res = eddy.run() # doctest: +SKIP
    """
    _cmd = 'eddy_openmp'
    input_spec = EddyInputSpec
    output_spec = EddyOutputSpec

    _num_threads = 1

    # (output name, filename template) pairs for files that are reported only
    # when they exist on disk; file generation might depend on the version of
    # EDDY.  Each template is filled with ``inputs.out_base``.
    _optional_outputs = (
        ('out_rotated_bvecs', '%s.eddy_rotated_bvecs'),
        ('out_movement_rms', '%s.eddy_movement_rms'),
        ('out_restricted_movement_rms', '%s.eddy_restricted_movement_rms'),
        ('out_shell_alignment_parameters',
         '%s.eddy_post_eddy_shell_alignment_parameters'),
        ('out_outlier_report', '%s.eddy_outlier_report'),
    )

    def __init__(self, **inputs):
        """Create the interface and keep the thread count and binary in sync with the inputs."""
        super(Eddy, self).__init__(**inputs)
        self.inputs.on_trait_change(self._num_threads_update, 'num_threads')
        if not isdefined(self.inputs.num_threads):
            self.inputs.num_threads = self._num_threads
        else:
            self._num_threads_update()
        self.inputs.on_trait_change(self._use_cuda, 'use_cuda')
        if isdefined(self.inputs.use_cuda):
            self._use_cuda()

    def _num_threads_update(self):
        """Mirror ``inputs.num_threads`` into ``OMP_NUM_THREADS`` of the command environment."""
        self._num_threads = self.inputs.num_threads
        if not isdefined(self.inputs.num_threads):
            if 'OMP_NUM_THREADS' in self.inputs.environ:
                del self.inputs.environ['OMP_NUM_THREADS']
        else:
            self.inputs.environ['OMP_NUM_THREADS'] = str(
                self.inputs.num_threads)

    def _use_cuda(self):
        """Select the CUDA or the OpenMP executable according to ``inputs.use_cuda``."""
        self._cmd = 'eddy_cuda' if self.inputs.use_cuda else 'eddy_openmp'

    def _run_interface(self, runtime):
        """Run the command, falling back to ``eddy`` when ``eddy_openmp`` is not installed."""
        # If 'eddy_openmp' is missing, use 'eddy'
        FSLDIR = os.getenv('FSLDIR', '')
        cmd = self._cmd
        if (FSLDIR != '' and cmd == 'eddy_openmp'
                and not os.path.exists(os.path.join(FSLDIR, 'bin', cmd))):
            self._cmd = 'eddy'
        try:
            runtime = super(Eddy, self)._run_interface(runtime)
        finally:
            # Restore command to avoid side-effects, even when the run fails
            self._cmd = cmd
        return runtime

    def _format_arg(self, name, spec, value):
        """Format one input for the command line, special-casing the topup and output basenames."""
        if name == 'in_topup_fieldcoef':
            # Pass only the part of the path that precedes '_fieldcoef'.
            return spec.argstr % value.split('_fieldcoef')[0]
        if name == 'out_base':
            return spec.argstr % os.path.abspath(value)
        return super(Eddy, self)._format_arg(name, spec, value)

    def _list_outputs(self):
        """Return the output dictionary, including optional files only when they exist."""
        outputs = self.output_spec().get()
        out_base = self.inputs.out_base

        outputs['out_corrected'] = os.path.abspath('%s.nii.gz' % out_base)
        outputs['out_parameter'] = os.path.abspath(
            '%s.eddy_parameters' % out_base)

        candidates = list(self._optional_outputs)
        # CNR maps and residuals are looked for only when they were requested.
        if isdefined(self.inputs.cnr_maps) and self.inputs.cnr_maps:
            candidates.append(('out_cnr_maps', '%s.eddy_cnr_maps.nii.gz'))
        if isdefined(self.inputs.residuals) and self.inputs.residuals:
            candidates.append(('out_residuals', '%s.eddy_residuals.nii.gz'))

        for out_name, template in candidates:
            out_path = os.path.abspath(template % out_base)
            if os.path.exists(out_path):
                outputs[out_name] = out_path

        return outputs