CAMH studies with single-shell or multi-shell scans acquired in separate images include:
- COGBDO
- COGBDY
- DBDC
- DTI15T
- DTI3T
- PACTMD
- PASD01
- RTMSWM

What combinations of scans do we expect from each study? Write doctests to cover them.
```
['.../ds001/sub-01/ses-01/dwi/sub-01_ses-01_acq-multishelldir30b1000_dwi.nii.gz',
 '.../ds001/sub-01/ses-01/dwi/sub-01_ses-01_acq-multishelldir30b3000_dwi.nii.gz',
 '.../ds001/sub-01/ses-01/dwi/sub-01_ses-01_acq-multishelldir30b4500_dwi.nii.gz',
 '.../ds001/sub-01/ses-01/dwi/sub-01_ses-01_acq-singleshelldir60b1000_dwi.nii.gz']
['.../ds001/sub-01/ses-01/dwi/sub-01_ses-01_acq-singleshelldir21b1000_dwi.nii.gz',
 '.../ds001/sub-01/ses-01/dwi/sub-01_ses-01_acq-singleshelldir22b1000_dwi.nii.gz',
 '.../ds001/sub-01/ses-01/dwi/sub-01_ses-01_acq-singleshelldir23b1000_dwi.nii.gz',
 '.../ds001/sub-01/ses-02/dwi/sub-01_ses-02_acq-singleshelldir21b1000_dwi.nii.gz',
 '.../ds001/sub-01/ses-02/dwi/sub-01_ses-02_acq-singleshelldir22b1000_dwi.nii.gz',
 '.../ds001/sub-01/ses-02/dwi/sub-01_ses-02_acq-singleshelldir23b1000_dwi.nii.gz']
['.../ds001/sub-01/ses-01/dwi/sub-01_ses-01_acq-singleshelldir20b1000_run-1_dwi.nii.gz',
 '.../ds001/sub-01/ses-01/dwi/sub-01_ses-01_acq-singleshelldir20b1000_run-2_dwi.nii.gz',
 '.../ds001/sub-01/ses-01/dwi/sub-01_ses-01_acq-singleshelldir20b1000_run-3_dwi.nii.gz']
```

Group scans by:
1. session label (`ses-<label>`)
2. acquisition label (`_acq-<label>`)
3. run index (`_run-<index>`)

Some scans have `NumberOfAverages` stored in the header.

In [89]:
import os

from bids import BIDSLayout
from nipype.pipeline import engine as pe
from nipype.interfaces import fsl, utility as niu
from nipype.utils.filemanip import fname_presuffix

In [90]:
class BIDSError(ValueError):
    """
    Error raised for problems with a BIDS dataset.

    The message is wrapped in a banner naming the BIDS root folder. The
    formatted text is kept in ``msg`` (and passed to ``ValueError``), and the
    root itself is kept in ``bids_root``.
    """

    def __init__(self, message, bids_root):
        pad = 10
        rule = "-" * pad
        banner = '{0} BIDS root folder: "{1}" {0}'.format(rule, bids_root)
        # The closing rule spans the full width of the banner line.
        closing = "-" * len(banner)
        self.msg = "\n{}\n{}{}\n{}".format(
            banner, " " * (pad + 1), message, closing
        )
        super().__init__(self.msg)
        self.bids_root = bids_root


class BIDSWarning(RuntimeWarning):
    """Warning category for non-fatal problems with a BIDS dataset."""

In [91]:
def collect_participants(bids_dir, participant_label=None, strict=False, bids_validate=True):
    """
    List the participants of a BIDS dataset and check requested labels against them.

    Parameters
    ----------
    bids_dir : str, path-like or BIDSLayout
        BIDS root folder, or an already-built layout (used as is).
    participant_label : str or list of str, optional
        Label(s) to look for; a leading ``sub-`` is dropped and duplicates
        are removed. When empty or ``None`` every participant is returned.
    strict : bool
        If true, raise instead of warning when only some labels are missing.
    bids_validate : bool
        Forwarded as ``validate=`` when a ``BIDSLayout`` has to be built.

    Returns
    -------
    list of str
        When ``participant_label`` is empty: the sorted labels of all
        participants in the dataset.
    tuple (set of str, list of str)
        Otherwise: all participants of the dataset and the sorted subset of
        requested labels that were found.

    Raises
    ------
    BIDSError
        If the dataset has no participants, if none of the requested labels
        exist, or (with ``strict``) if some of them are missing.

    Notes
    -----
    NOTE(review): the two return shapes differ. A caller that unpacks two
    values must always pass ``participant_label``. The shapes are left as
    they are so existing callers keep working.
    """
    # Imported locally: the warning path below needs it and the surrounding
    # file does not import it at module level.
    import warnings

    if isinstance(bids_dir, BIDSLayout):
        layout = bids_dir
    else:
        layout = BIDSLayout(str(bids_dir), validate=bids_validate)

    all_participants = set(layout.get_subjects())

    # Error: bids_dir does not contain subjects
    if not all_participants:
        raise BIDSError(
            "Could not find participants. Please make sure the BIDS data "
            "structure is present and correct. Datasets can be validated online "
            "using the BIDS Validator (http://bids-standard.github.io/bids-validator/).\n"
            "If you are using Docker for Mac or Docker for Windows, you "
            'may need to adjust your "File sharing" preferences.',
            bids_dir,
        )

    # No --participant-label was set, return all
    if not participant_label:
        return sorted(all_participants)

    if isinstance(participant_label, str):
        participant_label = [participant_label]

    # Drop sub- prefixes
    participant_label = [
        sub[4:] if sub.startswith("sub-") else sub for sub in participant_label
    ]
    # Remove duplicates
    participant_label = sorted(set(participant_label))

    # Remove labels not found
    found_label = sorted(set(participant_label) & all_participants)
    if not found_label:
        raise BIDSError(
            "Could not find participants [{}]".format(", ".join(participant_label)),
            bids_dir,
        )

    # Warn if some IDs were not found
    notfound_label = sorted(set(participant_label) - all_participants)
    if notfound_label:
        exc = BIDSError(
            "Some participants were not found: {}".format(", ".join(notfound_label)),
            bids_dir,
        )
        if strict:
            raise exc
        warnings.warn(exc.msg, BIDSWarning)

    return all_participants, found_label

In [92]:
def group_dwi(dwi_files, session_list, concat_dwis):
    """
    Split DWI file paths into single scans and per-session groups to concatenate.

    Parameters
    ----------
    dwi_files : list of str
        Paths of the DWI images.
    session_list : str, list of str or None
        Session label(s) without the ``ses-`` prefix. When empty, all files
        are treated as one session. Files matching none of the listed
        sessions are left out of the result.
    concat_dwis : str, list of str or None
        Substring(s) (e.g. acquisition labels) marking the files that belong
        to a concatenation group.

    Returns
    -------
    list
        For each session in order: the paths that match no ``concat_dwis``
        entry as plain strings, followed by one list holding the paths that
        do. No list is added for a session without matching files, so
        consumers never receive an empty group.
    """
    import re

    # A bare string would otherwise be iterated character by character.
    if isinstance(concat_dwis, str):
        concat_dwis = [concat_dwis]
    concat_dwis = concat_dwis or []
    if isinstance(session_list, str):
        session_list = [session_list]

    def _split(files):
        # Keep input order: single scans first, then the (non-empty) group.
        singles, group = [], []
        for path in files:
            if any(acq in path for acq in concat_dwis):
                group.append(path)
            else:
                singles.append(path)
        return singles + ([group] if group else [])

    if not session_list:
        return _split(dwi_files)

    all_dwis = []
    for session in session_list:
        # Match the complete label so that e.g. "1" does not also pick up
        # files of session "10".
        ses_entity = re.compile(r'ses-%s(?![A-Za-z0-9])' % re.escape(str(session)))
        all_dwis.extend(_split([img for img in dwi_files if ses_entity.search(img)]))

    return all_dwis

In [93]:
def collect_data(bids_dir, participant_label, concat_dwis, session_label=None,
                 bids_validate=True):
    """
    Collect the fieldmap, DWI and T1w image paths of one participant.

    Parameters
    ----------
    bids_dir : str, path-like or BIDSLayout
        BIDS root folder, or an already-built layout (used as is).
    participant_label : str
        Subject label passed to the layout query.
    concat_dwis : list of str
        Forwarded to ``group_dwi`` to mark the DWI files to be grouped.
    session_label : str or list of str, optional
        Session(s) to query. When empty, the sessions reported by
        ``layout.get_sessions()`` are used.
    bids_validate : bool
        Forwarded as ``validate=`` when a ``BIDSLayout`` has to be built.

    Returns
    -------
    dict
        Keys ``'fmap'``, ``'dwi'`` and ``'t1w'``. ``'fmap'`` and ``'t1w'``
        hold sorted lists of file paths; ``'dwi'`` holds the result of
        ``group_dwi`` applied to the sorted DWI paths.
    """
    if isinstance(bids_dir, BIDSLayout):
        layout = bids_dir
    else:
        layout = BIDSLayout(str(bids_dir), validate=bids_validate)

    queries = {
        'fmap': {'datatype': 'fmap'},
        'dwi': {'datatype': 'dwi', 'suffix': 'dwi'},
        't1w': {'datatype': 'anat', 'suffix': 'T1w'}
    }

    if not session_label:
        session_label = layout.get_sessions()

    filters = {'subject': participant_label}
    # Only filter on sessions when there are any; an empty list is not
    # passed on as a filter, so a dataset without sessions is still queried
    # by subject alone.
    if session_label:
        filters['session'] = session_label

    subj_data = {
        dtype: sorted(layout.get(return_type='file',
                                 extension=['nii', 'nii.gz'],
                                 **filters, **query))
        for dtype, query in queries.items()}

    subj_data['dwi'] = group_dwi(subj_data['dwi'], session_label, concat_dwis)

    return subj_data

In [173]:
def init_dwi_concat_wf(ref_file):
    """
    Build a workflow that merges several DWI series and their gradient tables.

    Parameters
    ----------
    ref_file : str or None
        Path whose file name is reused for the merged outputs. When given it
        is set on the input node; it can still be overwritten there.

    Inputs (``inputnode``)
    ----------------------
    ref_file
        Reference path used to name the outputs.
    dwi_list
        DWI images to concatenate.
    bvec_list
        Gradient direction files, one per entry of ``dwi_list``.
    bval_list
        b-value files, one per entry of ``dwi_list``.

    Outputs (``outputnode``)
    ------------------------
    dwi_file
        Concatenated image.
    bvec_file
        Concatenated gradient directions.
    bval_file
        Concatenated b-values.

    Returns
    -------
    nipype.pipeline.engine.Workflow
        Workflow named ``dwi_concat_wf``.
    """

    wf = pe.Workflow(name='dwi_concat_wf')

    inputnode = pe.Node(niu.IdentityInterface(fields=['ref_file',
                                                      'dwi_list',
                                                      'bvec_list',
                                                      'bval_list']),
                        name='inputnode')
    if ref_file is not None:
        inputnode.inputs.ref_file = ref_file

    outputnode = pe.Node(niu.IdentityInterface(fields=['dwi_file',
                                                       'bvec_file',
                                                       'bval_file']),
                         name='outputnode')

    def concat_dwis(ref_file, dwi_list):
        """
        Concatenate the images in ``dwi_list`` and save the result in the
        current directory under the file name of ``ref_file``.
        """
        import os
        import numpy as np
        from nipype.utils.filemanip import fname_presuffix
        import nibabel as nib
        from nilearn.image import concat_imgs

        out_file = fname_presuffix(
            ref_file,
            newpath=os.path.abspath('.')
        )

        dwi_imgs = [nib.load(dwi) for dwi in dwi_list]

        merged = concat_imgs(dwi_imgs)

        # Reuse the header and affine of the first image, updated for the
        # merged shape and a float32 on-disk type.
        hdr = dwi_imgs[0].header.copy()
        hdr.set_data_shape(merged.shape)
        hdr.set_xyzt_units('mm')
        hdr.set_data_dtype(np.float32)
        # get_fdata replaces the deprecated get_data accessor.
        nib.Nifti1Image(merged.get_fdata(dtype=np.float32),
                        dwi_imgs[0].affine, hdr).to_filename(out_file)
        return out_file

    dwi_node = pe.Node(
        niu.Function(
            input_names=['ref_file', 'dwi_list'],
            output_names=['out_file'],
            function=concat_dwis
        ),
        name='concat_dwis')

    def concat_bvecs(ref_file, bvec_list):
        """
        Join the tables in ``bvec_list`` column-wise and write them to
        ``<ref_file name>.bvec`` in the current directory.
        """

        import os
        import numpy as np
        from nipype.utils.filemanip import fname_presuffix

        out_file = fname_presuffix(
            ref_file,
            suffix='.bvec',
            newpath=os.path.abspath('.'),
            use_ext=False
        )

        # ndmin=2 keeps a single-volume file two-dimensional so it can be
        # joined along axis 1.
        # NOTE(review): assumes one row per gradient axis and one column per
        # volume - confirm against the input files.
        bvec_vals = [np.loadtxt(bvec, ndmin=2) for bvec in bvec_list]
        np.savetxt(out_file,
                   np.concatenate(bvec_vals, axis=1),
                   fmt='%.4f',
                   delimiter=' ')
        return out_file

    bvec_node = pe.Node(
        niu.Function(
            input_names=['ref_file', 'bvec_list'],
            output_names=['out_file'],
            function=concat_bvecs
        ),
        name='concat_bvecs')

    def concat_bvals(ref_file, bval_list):
        """
        Join the values in ``bval_list`` end to end and write them, separated
        by spaces, to ``<ref_file name>.bval`` in the current directory.
        """

        import os
        import numpy as np
        from nipype.utils.filemanip import fname_presuffix

        out_file = fname_presuffix(
            ref_file,
            suffix='.bval',
            newpath=os.path.abspath('.'),
            use_ext=False
        )

        # ndmin=1 keeps a file holding a single value one-dimensional so it
        # can be concatenated.
        bval_vals = [np.loadtxt(bval, ndmin=1) for bval in bval_list]
        np.savetxt(out_file,
                   np.concatenate(bval_vals, axis=0),
                   fmt='%i',
                   delimiter=' ',
                   newline=' ')
        return out_file

    bval_node = pe.Node(
        niu.Function(
            input_names=['ref_file', 'bval_list'],
            output_names=['out_file'],
            function=concat_bvals
        ),
        name='concat_bvals')

    wf.connect([
        (inputnode, dwi_node, [('ref_file', 'ref_file'),
                               ('dwi_list', 'dwi_list')]),
        (inputnode, bvec_node, [('ref_file', 'ref_file'),
                                ('bvec_list', 'bvec_list')]),
        (inputnode, bval_node, [('ref_file', 'ref_file'),
                                ('bval_list', 'bval_list')]),
        # The Function node declares 'out_file' as its only output.
        (dwi_node, outputnode, [('out_file', 'dwi_file')]),
        (bvec_node, outputnode, [('out_file', 'bvec_file')]),
        (bval_node, outputnode, [('out_file', 'bval_file')])
    ])

    return wf

In [23]:
# Index the DTI3T BIDS dataset; the resulting layout is used by the cells below.
bids_dir = '/archive/data/DTI3T/data/bids'
layout = BIDSLayout(bids_dir)

In [38]:
# Subject and session to process.
participant_label = 'CMHH166'
session_label = ['01']
# Substrings of the DWI file names that mark the scans to be grouped for concatenation.
concat_dwis = ["multishelldir30b1000", "multishelldir30b3000", "multishelldir30b4500"]

In [25]:
# A label is given, so the call returns a pair: all subjects and the requested labels that were found.
all_subjects, subject_list = collect_participants(layout, participant_label)

In [39]:
# NOTE(review): subject_data is rebound on every iteration, so only the last
# subject's files remain once the loop has finished.
for subject_id in subject_list:
    subject_data = collect_data(layout, subject_id, concat_dwis, session_label)

In [174]:
# Build and run the concatenation workflow for every list-valued entry of
# subject_data['dwi']; plain string entries are left untouched.
for dwi_file in subject_data['dwi']:
    # An empty list has no first element to serve as reference, so it is
    # skipped rather than treated as a group.
    multiple_dwis = isinstance(dwi_file, list) and len(dwi_file) > 0

    if multiple_dwis:
        # The first scan of the group names the merged outputs.
        ref_file = dwi_file[0]
        bval_list = []
        bvec_list = []
        for f in dwi_file:
            bvec_list.append(layout.get_bvec(f))
            bval_list.append(layout.get_bval(f))

        test_wf = init_dwi_concat_wf(ref_file)
        test_wf.base_dir = os.getcwd()

        inputspec = test_wf.get_node("inputnode")
        inputspec.inputs.ref_file = ref_file
        inputspec.inputs.dwi_list = dwi_file
        inputspec.inputs.bvec_list = bvec_list
        inputspec.inputs.bval_list = bval_list

        test_wf.write_graph(graph2use="colored")
        test_wf.config["execution"]["remove_unnecessary_outputs"] = False
        test_wf.config["execution"]["keep_inputs"] = True
        test_wf.config["execution"]["crashfile_format"] = "txt"

        test_wf.run()

190903-14:01:45,761 nipype.workflow INFO:
	 Generated workflow graph: /mnt/tigrlab/projects/mjoseph/pipelines/dmriprep-notebooks/notebooks/dwi_concat_wf/graph.png (graph2use=colored, simple_form=True).
190903-14:01:45,765 nipype.workflow INFO:
	 Workflow dwi_concat_wf settings: ['check', 'execution', 'logging', 'monitoring']
190903-14:01:45,798 nipype.workflow INFO:
	 Running serially.
190903-14:01:45,799 nipype.workflow INFO:
	 [Node] Setting-up "dwi_concat_wf.concat_bvals" in "/mnt/tigrlab/projects/mjoseph/pipelines/dmriprep-notebooks/notebooks/dwi_concat_wf/concat_bvals".
190903-14:01:45,809 nipype.workflow INFO:
	 [Node] Cached "dwi_concat_wf.concat_bvals" - collecting precomputed outputs
190903-14:01:45,811 nipype.workflow INFO:
	 [Node] "dwi_concat_wf.concat_bvals" found cached.
190903-14:01:45,812 nipype.workflow INFO:
	 [Node] Setting-up "dwi_concat_wf.concat_bvecs" in "/mnt/tigrlab/projects/mjoseph/pipelines/dmriprep-notebooks/notebooks/dwi_concat_wf/concat_bvecs".
190903-14:0