In [None]:
import os

from bids import BIDSLayout

In [None]:
# Resolve the BIDS dataset root relative to the notebook's working directory, then index it.
data_dir = os.path.abspath(os.path.join("..", "data"))
layout = BIDSLayout(data_dir)

In [None]:
# Paths of every diffusion-weighted NIfTI image (compressed or not) in the dataset.
dwi_files = layout.get(
    return_type="file",
    datatype="dwi",
    extension=["nii.gz", "nii"],
)

In [None]:
# Which DWI run (index into `dwi_files`) to process; negative values count from the end.
DWI_INDEX = -5
if not -len(dwi_files) <= DWI_INDEX < len(dwi_files):
    raise IndexError(
        "DWI_INDEX=%d is out of range: only %d DWI file(s) found under %s"
        % (DWI_INDEX, len(dwi_files), data_dir))

dwi_file = dwi_files[DWI_INDEX]
bvec_file = layout.get_bvec(dwi_file)
bval_file = layout.get_bval(dwi_file)
dwi_meta = layout.get_metadata(dwi_file)

# Fieldmaps pybids associates with this run. Each record keeps its image path
# under the key named by its own "suffix" entry; that path is used to attach
# the sidecar metadata to the record.
fmaps = layout.get_fieldmap(dwi_file, return_list=True)
for fmap in fmaps:
    fmap["metadata"] = layout.get_metadata(fmap[fmap["suffix"]])

fmap_files = fmaps

In [None]:
def init_artefact_removal_wf(ignore=()):
    """Build a workflow that denoises a DWI series and removes Gibbs ringing.

    The enabled steps run in this order, each feeding the next:

    * ``denoise`` -- MRtrix3 ``DWIDenoise``
    * ``unring``  -- MRtrix3 ``MRDeGibbs``

    Parameters
    ----------
    ignore : str or iterable of str, optional
        Step name(s) to skip (``"denoise"`` and/or ``"unring"``), in any
        order. ``None`` or an empty iterable runs both steps. When both are
        skipped, the input file is passed straight through to the output.

    Returns
    -------
    nipype.pipeline.engine.Workflow
        Workflow named ``artefact_remove_wf`` with input
        ``inputnode.dwi_file`` and output ``outputnode.out_file``.

    Raises
    ------
    ValueError
        If *ignore* names a step other than ``"denoise"`` or ``"unring"``.
    """
    from nipype.pipeline import engine as pe
    from nipype.interfaces import mrtrix3, utility as niu

    known_steps = ("denoise", "unring")
    # A bare string is a single step name, not a sequence of characters.
    if ignore is None:
        ignored = set()
    elif isinstance(ignore, str):
        ignored = {ignore}
    else:
        ignored = set(ignore)
    unknown = ignored.difference(known_steps)
    if unknown:
        raise ValueError(
            "Unknown artefact-removal step(s) to ignore: %s; expected any of %s"
            % (sorted(unknown), list(known_steps)))

    wf = pe.Workflow(name="artefact_remove_wf")

    # Named so parent workflows can connect to "inputnode.dwi_file".
    inputnode = pe.Node(niu.IdentityInterface(fields=["dwi_file"]), name="inputnode")
    outputnode = pe.Node(niu.IdentityInterface(fields=["out_file"]), name="outputnode")

    # Enabled processing nodes, in execution order.
    steps = []
    if "denoise" not in ignored:
        steps.append(pe.Node(mrtrix3.DWIDenoise(), name="denoise"))
    if "unring" not in ignored:
        steps.append(pe.Node(mrtrix3.MRDeGibbs(), name="unring"))

    if not steps:
        # Both steps skipped: expose the input file unchanged.
        wf.connect([(inputnode, outputnode, [("dwi_file", "out_file")])])
        return wf

    connections = [(inputnode, steps[0], [("dwi_file", "in_file")])]
    connections.extend(
        (upstream, downstream, [("out_file", "in_file")])
        for upstream, downstream in zip(steps, steps[1:])
    )
    connections.append((steps[-1], outputnode, [("out_file", "out_file")]))
    wf.connect(connections)
    return wf

In [None]:
def init_sdc_wf(layout, fmap_list=None):
    """Start building the susceptibility-distortion-correction (SDC) workflow.

    Work in progress: the preferred fieldmap is selected by type and an
    ``acqp`` node that writes an acquisition-parameters text file is created,
    but nothing is connected yet, so the returned workflow is still empty.

    Parameters
    ----------
    layout : bids.BIDSLayout
        Kept for interface compatibility; not used by the current draft.
    fmap_list : list of dict, optional
        Fieldmap records as produced by
        ``layout.get_fieldmap(..., return_list=True)``; each needs a
        ``"suffix"`` key. Defaults to the notebook-level ``fmaps`` list.
        The list is not modified.

    Returns
    -------
    nipype.pipeline.engine.Workflow
        Workflow named ``sdc_wf``.
    """
    from nipype.pipeline import engine as pe
    from nipype.interfaces import utility as niu

    wf = pe.Workflow(name="sdc_wf")

    inputnode = pe.Node(
        niu.IdentityInterface(fields=["dwi_file", "dwi_meta", "fmap_files"]),
        name="inputnode")
    outputnode = pe.Node(niu.IdentityInterface(fields=["out_file"]), name="outputnode")

    # Lower value = preferred fieldmap type.
    FMAP_PRIORITY = {"epi": 0, "fieldmap": 1, "phasediff": 2, "phase": 3, "syn": 4}

    if fmap_list is None:
        fmap_list = fmaps  # notebook-level fieldmaps of the selected DWI run
    # sorted() builds a new list instead of reordering the caller's list in place.
    ranked = sorted(fmap_list, key=lambda rec: FMAP_PRIORITY[rec["suffix"]])
    if not ranked:
        # TODO: no fieldmap available -- fieldmap-less ("syn") SDC is not implemented yet.
        return wf
    fmap = ranked[0]

    if fmap["suffix"] in ("fieldmap", "phasediff", "phase"):
        fmap_type = "field"
        # TODO: build the fieldmap-based correction for fmap["suffix"].
    elif fmap["suffix"] == "epi":
        fmap_type = "topup"
        # TODO: feed these EPI references into the TOPUP-based correction.
        epi_fmaps = [rec["epi"] for rec in ranked if rec["suffix"] == "epi"]

    def gen_acqparams(in_file, metadata):
        """Write an acquisition-parameters text file next to the working directory.

        Runs inside a nipype Function node, which executes only this function's
        source, so everything it needs is imported locally and passed in as
        arguments (no access to ``layout``).

        Parameters
        ----------
        in_file : str
            Image whose name is used to derive ``<name>_acqparams.txt``.
        metadata : dict or list of dict
            BIDS sidecar metadata; one output line is written per dict. Each
            must provide ``PhaseEncodingDirection`` and ``TotalReadoutTime``.

        Returns
        -------
        str
            Absolute path of the written file.
        """
        import os
        from nipype.utils.filemanip import fname_presuffix

        out_file = fname_presuffix(
            in_file,
            suffix="_acqparams.txt",
            newpath=os.path.abspath("."),
            use_ext=False)

        # Phase-encoding direction -> "x y z readout_time" row template.
        acq_param_dict = {
            "j": "0 1 0 %.7f",
            "j-": "0 -1 0 %.7f",
            "i": "1 0 0 %.7f",
            "i-": "-1 0 0 %.7f",
            "k": "0 0 1 %.7f",
            "k-": "0 0 -1 %.7f"}

        if isinstance(metadata, dict):
            metadata = [metadata]

        acq_param_lines = []
        for meta in metadata:
            pe_dir = meta.get("PhaseEncodingDirection")
            total_readout = meta.get("TotalReadoutTime")
            if pe_dir not in acq_param_dict or total_readout is None:
                raise ValueError(
                    "Need a valid PhaseEncodingDirection and TotalReadoutTime, "
                    "got %r and %r" % (pe_dir, total_readout))
            acq_param_lines.append(acq_param_dict[pe_dir] % total_readout)

        with open(out_file, "w") as f:
            f.write("\n".join(acq_param_lines))

        return out_file

    acqp = pe.Node(
        niu.Function(
            input_names=["in_file", "metadata"],
            output_names=["out_file"],
            function=gen_acqparams,
        ),
        name="acqp",
    )

    return wf

In [None]:
def init_phasediff_wf(layout):
    """Create the workflow for phase-difference fieldmaps (placeholder).

    No processing nodes have been added yet, so the returned workflow is
    empty.

    Parameters
    ----------
    layout : bids.BIDSLayout
        Not used yet; kept for the intended interface.

    Returns
    -------
    nipype.pipeline.engine.Workflow
        Workflow named ``phasediff_wf``.
    """
    from nipype.pipeline import engine as pe

    wf = pe.Workflow(name="phasediff_wf")
    # TODO: add the phase-difference processing nodes.
    return wf

In [None]:
def init_pepolar_wf():
    """Draft of a TOPUP workflow driven by an AP/PA image pair.

    Intended flow (from the connections below): the AP and PA images are
    collected into a list, merged along time into one 4D image and passed to
    FSL TOPUP. TOPUP's encoding file comes from an ``acqp`` node that is fed
    the DWI file and its metadata.

    FIXME: not runnable as written --
      * ``pe``, ``niu`` and ``fsl`` are imported neither in this function nor
        in the notebook's import cell, so calling it raises ``NameError``;
      * ``inputnode``, ``outputnode`` and ``acqp`` are never defined here;
      * the ``(topup, outputnode, [("", "")])`` connection is a placeholder
        without real field names.

    NOTE(review): apart from ``acqp`` this duplicates ``init_topup_wf``. The
    acqp file is built from the DWI metadata rather than from the AP/PA images
    that TOPUP actually receives -- confirm this is intended.
    """
    wf = pe.Workflow(name="pepolar_wf")

    # Collects ap_file/pa_file into a two-element list for the merge.
    list_merge = pe.Node(niu.Merge(numinputs=2), name="list_merge")

    # Concatenates the AP and PA images along the time axis.
    merge = pe.Node(fsl.Merge(dimension="t"), name="mergeAPPA")

    topup = pe.Node(fsl.TOPUP(), name="topup")
    
    wf.connect(
        [(inputnode, acqp, [("dwi_file", "in_file"),
                            ("dwi_meta", "metadata")]),
         (inputnode, list_merge, [("ap_file", "in1"),
                                  ("pa_file", "in2")]),
         (list_merge, merge, [("out", "in_files")]),
         (merge, topup, [("merged_file", "in_file")]),
         (acqp, topup, [("out_file", "encoding_file")]),
         # FIXME: placeholder -- no real source/destination fields yet.
         (topup, outputnode, [("", "")])
        ]
    )
    
    return wf

In [None]:
def init_topup_wf():
    """Build a workflow that runs FSL TOPUP on a merged AP/PA image pair.

    The AP and PA images are concatenated along time and passed to TOPUP
    together with the acquisition-parameters file given on ``acqp``.

    Inputs (``inputnode``): ``ap_file``, ``pa_file``, ``acqp``.

    Outputs (``outputnode``):
      * ``out_file``  -- TOPUP's ``out_corrected`` image;
      * ``fieldcoef`` -- TOPUP's ``out_fieldcoef``;
      * ``movpar``    -- TOPUP's ``out_movpar``.
    The last two are named to match the ``eddy_wf`` input fields.

    Returns
    -------
    nipype.pipeline.engine.Workflow
        Workflow named ``topup_wf``.
    """
    from nipype.pipeline import engine as pe
    from nipype.interfaces import fsl, utility as niu

    wf = pe.Workflow(name="topup_wf")

    # nipype's Node requires a name; "inputnode" also lets parent workflows address it.
    inputnode = pe.Node(
        niu.IdentityInterface(fields=["ap_file", "pa_file", "acqp"]),
        name="inputnode")
    outputnode = pe.Node(
        niu.IdentityInterface(fields=["out_file", "fieldcoef", "movpar"]),
        name="outputnode")

    list_merge = pe.Node(niu.Merge(numinputs=2), name="list_merge")
    merge = pe.Node(fsl.Merge(dimension="t"), name="mergeAPPA")
    topup = pe.Node(fsl.TOPUP(), name="topup")

    wf.connect([
        (inputnode, list_merge, [("ap_file", "in1"),
                                 ("pa_file", "in2")]),
        (list_merge, merge, [("out", "in_files")]),
        (merge, topup, [("merged_file", "in_file")]),
        (inputnode, topup, [("acqp", "encoding_file")]),
        (topup, outputnode, [("out_corrected", "out_file"),
                             ("out_fieldcoef", "fieldcoef"),
                             ("out_movpar", "movpar")]),
    ])

    return wf

In [None]:
def init_fmap_wf():
    """Draft of a workflow that brings a fieldmap into alignment with a b0 image.

    Intended flow (from the connections below):
      * ``rad_to_hz`` converts ``fmap_file`` (by its name, presumably from
        rad/s to Hz -- TODO confirm the input units);
      * ``mag_flirt`` registers ``mag_file`` to ``b0`` and produces
        ``out_matrix_file`` (presumably an FSL FLIRT node, judging by its
        name and fields);
      * ``fmap_flirt`` applies that matrix (``in_matrix_file``) to the
        converted fieldmap, using ``b0`` as reference.

    FIXME: not runnable as written --
      * ``pe`` and ``niu`` are imported neither here nor in the notebook's
        import cell;
      * ``inputnode`` is created without ``name=`` (every other node here
        passes one; nipype's ``Node`` requires it);
      * ``rad_to_hz``, ``mag_flirt`` and ``fmap_flirt`` are never defined;
      * nothing is connected to ``outputnode.out_file``.
    """

    wf = pe.Workflow(name="fmap_wf")
    
    inputnode = pe.Node(niu.IdentityInterface(fields=["fmap_file", "mag_file", "b0"]))
    
    outputnode = pe.Node(niu.IdentityInterface(fields=["out_file"]), name="outputnode")
    
    wf.connect(
        [(inputnode, rad_to_hz, [("fmap_file", "in_file")]),
         (inputnode, mag_flirt, [("mag_file", "in_file"),
                                 ("b0", "reference")]),
         (rad_to_hz, fmap_flirt, [("out_file", "in_file")]),
         (inputnode, fmap_flirt, [("b0", "reference")]),
         # The magnitude->b0 matrix is reused to resample the fieldmap.
         (mag_flirt, fmap_flirt, [("out_matrix_file", "in_matrix_file")])
        ]
    )
    
    return wf

In [None]:
def init_eddy_wf(num_threads=None, use_cuda=False):
    """Build a workflow that runs FSL eddy on a DWI series.

    Parameters
    ----------
    num_threads : int, optional
        Threads for eddy. Defaults to ``multiprocessing.cpu_count()``.
    use_cuda : bool, optional
        Request eddy's CUDA variant. Off by default because it needs a
        CUDA-capable GPU on the host.

    Returns
    -------
    nipype.pipeline.engine.Workflow
        Workflow named ``eddy_wf``.

        Inputs on ``inputnode``: ``dwi_file``, ``bvec_file``, ``bval_file``,
        ``mask_file``, ``acqp``, ``index``, ``fieldcoef``, ``movpar``.

        Output: ``outputnode.out_file``, eddy's ``out_corrected``.
    """
    import multiprocessing

    from nipype.pipeline import engine as pe
    from nipype.interfaces import fsl, utility as niu

    wf = pe.Workflow(name="eddy_wf")

    inputnode = pe.Node(niu.IdentityInterface(fields=["dwi_file",
                                                      "bvec_file",
                                                      "bval_file",
                                                      "mask_file",
                                                      "acqp",
                                                      "index",
                                                      "fieldcoef",
                                                      "movpar"]),
                        name="inputnode")

    outputnode = pe.Node(niu.IdentityInterface(fields=["out_file"]), name="outputnode")

    ecc = pe.Node(
        fsl.Eddy(repol=True, cnr_maps=True, residuals=True),
        name="fsl_eddy",
    )

    ecc.inputs.num_threads = (
        multiprocessing.cpu_count() if num_threads is None else num_threads)
    # Explicit opt-in. The earlier probe referenced an undefined name, and a bare
    # `except` hid that error.
    ecc.inputs.use_cuda = bool(use_cuda)

    wf.connect(
        [(inputnode, ecc, [("dwi_file", "in_file"),
                           ("bval_file", "in_bval"),
                           ("bvec_file", "in_bvec"),
                           ("acqp", "in_acqp"),
                           ("index", "in_index"),
                           ("mask_file", "in_mask"),
                           ("fieldcoef", "in_topup_fieldcoef"),
                           ("movpar", "in_topup_movpar")]),
         (ecc, outputnode, [("out_corrected", "out_file")])
        ]
    )

    return wf

In [None]:
def init_dwi_preproc_wf(ignore=(), layout=None):
    """Assemble the DWI preprocessing workflow (work in progress).

    Currently wired:
      * artefact removal of the raw DWI;
      * an index file generated from the raw DWI;
      * the mean b0 of the artefact-corrected series;
      * a BET brain extraction of that mean b0.

    The SDC, eddy and N4 bias-correction pieces are instantiated but not yet
    connected, and ``outputnode.out_file`` is not populated yet.

    Parameters
    ----------
    ignore : iterable of str, optional
        Artefact-removal steps to skip; forwarded to
        ``init_artefact_removal_wf``.
    layout : bids.BIDSLayout, optional
        Forwarded to ``init_sdc_wf``.

    Returns
    -------
    nipype.pipeline.engine.Workflow
        Workflow named ``dwi_preproc_wf``.
    """
    from nipype.pipeline import engine as pe
    from nipype.interfaces import ants, fsl, utility as niu

    wf = pe.Workflow(name="dwi_preproc_wf")

    inputnode = pe.Node(niu.IdentityInterface(fields=["dwi_file",
                                                      "dwi_meta",
                                                      "bvec_file",
                                                      "bval_file",
                                                      "ap_file",
                                                      "pa_file"]),
                        name="inputnode")

    outputnode = pe.Node(niu.IdentityInterface(fields=["out_file"]), name="outputnode")

    def gen_index(in_file):
        """Write ``<name>_index.txt``: one line with a ``1`` per volume of *in_file*.

        Presumably the eddy ``--index`` file, pointing every volume at the first
        acquisition-parameters row -- confirm once eddy is wired in.
        """
        import os
        import numpy as np
        import nibabel as nib
        from nipype.utils.filemanip import fname_presuffix

        out_file = fname_presuffix(
            in_file,
            suffix="_index.txt",
            newpath=os.path.abspath("."),
            use_ext=False)

        # The header alone gives the volume count; no voxel data is read.
        # (get_data() was removed in nibabel 5.)
        n_vols = nib.load(in_file).shape[-1]
        np.savetxt(out_file, np.ones((1, n_vols)), fmt="%i")
        return out_file

    gen_idx = pe.Node(
        niu.Function(
            input_names=["in_file"],
            output_names=["out_file"],
            function=gen_index),
        name="gen_index")

    artefact_removal_wf = init_artefact_removal_wf(ignore)

    def b0_average(in_dwi, in_bval, b0_thresh=0, out_file=None):
        """
        Average the *b0* volumes of a DWI dataset into one 3D image.

        Volumes whose b-value is ``<= b0_thresh`` are selected. Because current
        dMRI data are often acquired with all b-values > 0.0, the *lowb*
        volumes can be selected by raising ``b0_thresh``.

        .. warning:: *b0* should be already registered (head motion artifact
        should be corrected).

        Raises
        ------
        ValueError
            If no volume has a b-value ``<= b0_thresh``.
        """
        import os
        import numpy as np
        import nibabel as nib
        from nipype.utils.filemanip import fname_presuffix

        if out_file is None:
            out_file = fname_presuffix(
                in_dwi, suffix="_avg_b0", newpath=os.path.abspath(".")
            )

        img = nib.load(in_dwi)
        bvals = np.atleast_1d(np.loadtxt(in_bval))
        b0_index = np.flatnonzero(bvals <= b0_thresh)
        if b0_index.size == 0:
            raise ValueError(
                "No volume with b-value <= %s found in %s" % (b0_thresh, in_bval))

        # get_fdata replaces get_data(), which was removed in nibabel 5.
        data = img.get_fdata(dtype=np.float32)
        b0 = data[..., b0_index].mean(axis=-1)

        hdr = img.header.copy()
        hdr.set_data_shape(b0.shape)
        hdr.set_xyzt_units("mm")
        hdr.set_data_dtype(np.float32)
        nib.Nifti1Image(b0, img.affine, hdr).to_filename(out_file)
        return out_file

    avg_b0_0 = pe.Node(
        niu.Function(
            input_names=["in_dwi", "in_bval", "b0_thresh"],
            output_names=["out_file"],
            function=b0_average,
        ),
        name="b0_avg_pre"
    )

    bet_dwi0 = pe.Node(
        fsl.BET(frac=0.6, mask=True, robust=True),
        name="bet_dwi_pre"
    )

    # TODO: connect SDC, eddy and bias correction once their workflows are complete.
    sdc_wf = init_sdc_wf(layout)

    eddy_wf = init_eddy_wf()

    bias_correct = pe.Node(
        ants.N4BiasFieldCorrection(dimension=3),
        name="bias_correct"
    )

    # Ports of a sub-workflow are addressed as "<node>.<field>".
    wf.connect(
        [(inputnode, artefact_removal_wf, [("dwi_file", "inputnode.dwi_file")]),
         (inputnode, gen_idx, [("dwi_file", "in_file")]),
         (inputnode, avg_b0_0, [("bval_file", "in_bval")]),
         (artefact_removal_wf, avg_b0_0, [("outputnode.out_file", "in_dwi")]),
         (avg_b0_0, bet_dwi0, [("out_file", "in_file")])
        ]
    )

    return wf