In [13]:
# Import modules and set experiment-specific parameters
import copy
import os
import numpy as np
from os.path import join as opj
from nipype.pipeline.engine import Workflow, Node, MapNode, JoinNode
from nipype.interfaces.io import SelectFiles, DataSink
from nipype.interfaces.utility import IdentityInterface, Merge, Select
from nipype.interfaces.ants import Registration, ApplyTransforms, AverageImages
from nipype import config, logging
from nipype.interfaces.freesurfer import MRIConvert

# Turn on nipype debug mode and apply its logging settings.
config.enable_debug_mode()
logging.update_logging(config)

# Run out of the current working directory: SelectFiles reads inputs from,
# and the workflows write results under, ``datadir``. A notebook kernel has
# no ``__file__`` variable, so the directory is taken from the CWD;
# ``realpath`` resolves any symlinks in it.
filepath = os.path.realpath(os.getcwd())
datadir = filepath
os.chdir(datadir)

# Subjects to process (29 IDs). d718, d719, d721, d725 and d733 are not in
# this list -- presumably excluded on purpose; confirm.
subject_list = ['d701', 'd702', 'd703', 'd704', 'd705', 'd706', 'd707',
                'd708', 'd709', 'd710', 'd711', 'd712', 'd713', 'd714',
                'd715', 'd716', 'd717', 'd720', 'd722', 'd723', 'd724',
                'd726', 'd727', 'd728', 'd729', 'd730', 'd731', 'd732',
                'd734']

# --- Node factories -------------------------------------------------------
# The three template-building passes use identical node settings. Only the
# node names (which become the on-disk output directory names) and the
# merge JoinNode's joinsource differ, so the settings are defined once here.


def _rigid_registration():
    """Return an ANTs ``Registration`` interface configured for one rigid stage.

    ``fixed_image`` and ``moving_image`` are left empty; each workflow
    connects them from its SelectFiles node.
    """
    reg = Registration()
    reg.inputs.float = True
    reg.inputs.collapse_output_transforms = True
    reg.inputs.output_transform_prefix = 'rigid_'
    reg.inputs.fixed_image = []
    reg.inputs.moving_image = []
    reg.inputs.initial_moving_transform_com = 1
    reg.inputs.output_warped_image = True
    reg.inputs.transforms = ['Rigid']
    reg.inputs.terminal_output = 'file'
    reg.inputs.winsorize_lower_quantile = 0.005
    reg.inputs.winsorize_upper_quantile = 0.995
    reg.inputs.convergence_threshold = [1e-06]
    reg.inputs.convergence_window_size = [10]
    # A single stage driven by three weighted metrics.
    reg.inputs.metric = [['MeanSquares', 'MI', 'MI']]
    reg.inputs.metric_weight = [[0.75, 0.125, 0.125]]
    reg.inputs.number_of_iterations = [[1000, 500, 250, 0]]
    reg.inputs.smoothing_sigmas = [[4, 3, 2, 1]]
    reg.inputs.sigma_units = ['vox']
    reg.inputs.radius_or_number_of_bins = [[0, 32, 32]]
    reg.inputs.sampling_strategy = [['None', 'Regular', 'Regular']]
    reg.inputs.sampling_percentage = [[0, 0.25, 0.25]]
    reg.inputs.shrink_factors = [[12, 8, 4, 2]]
    # Presumably the Rigid gradient step. NOTE(review): nipype examples
    # write this as [(0.1,)] (a 1-tuple); confirm [[0.1]] is parsed as intended.
    reg.inputs.transform_parameters = [[0.1]]
    reg.inputs.use_histogram_matching = True
    return reg


def _apply_transforms_mapnode(interface, name):
    """Wrap an ``ApplyTransforms`` interface in a nested MapNode.

    The node maps over input image, reference image and transforms together,
    so those three inputs must be connected with lists of equal length.
    """
    node = MapNode(interface,
                   name=name,
                   iterfield=['input_image', 'reference_image', 'transforms'],
                   nested=True)
    node.inputs.input_image = []
    node.inputs.reference_image = []
    node.inputs.transforms = []
    node.inputs.terminal_output = 'file'
    return node


def _select_node(name):
    """Return a Select node that iterates over image-type indices 0, 1 and 2.

    The index count must agree with ``reptrans``, which repeats each
    subject's transforms three times (one per image type).
    """
    node = Node(Select(), name=name)
    node.inputs.inlist = []
    node.inputs.index = []
    node.iterables = ('index', [0, 1, 2])
    return node


def _merge_joinnode(name, joinsource):
    """Return a JoinNode that gathers ``in1`` over every subject of ``joinsource``.

    ``axis='hstack'`` keeps one single-item list per joined subject, and
    ``flatten_nlist`` flattens that downstream. With Merge's default
    ``vstack`` axis the joined list would already be flat, and
    ``flatten_nlist`` would split each file name into characters.
    """
    node = JoinNode(Merge(1), name=name, joinsource=joinsource, joinfield='in1')
    node.inputs.in1 = []
    node.inputs.axis = 'hstack'
    return node


def _average_node(name):
    """Return a node that averages the rigidly registered images into a template."""
    node = Node(AverageImages(), name=name)
    node.inputs.dimension = 3
    node.inputs.images = []
    node.inputs.normalize = True
    node.inputs.terminal_output = 'file'
    return node


def _mri_convert_node(name):
    """Return a node converting the averaged image to ``.nii.gz``."""
    converter = MRIConvert()
    converter.inputs.out_type = 'niigz'
    return Node(converter, name=name)


# --- Pass 1: register subjects to the initial template --------------------
antsreg = _rigid_registration()
antsreg_rigid = Node(antsreg, name='antsreg_rigid')
apply_rigid_reg = ApplyTransforms()
apply_rigid = _apply_transforms_mapnode(apply_rigid_reg, 'apply_rigid')
sl = _select_node('select_lists')
ml = _merge_joinnode('merge_lists', 'info_source')
avg_rigid = _average_node('average_rigid')
mc = _mri_convert_node('mri_convert')

# --- Pass 2: register subjects to the pass-1 average ----------------------
antsreg2 = _rigid_registration()
antsreg_rigid2 = Node(antsreg2, name='antsreg_rigid_2')
apply_rigid_reg2 = ApplyTransforms()
apply_rigid2 = _apply_transforms_mapnode(apply_rigid_reg2, 'apply_rigid_2')
sl2 = _select_node('select_lists_2')
ml3 = _merge_joinnode('merge_lists_3', 'info_source_2')
avg_rigid2 = _average_node('average_rigid_2')
mc2 = _mri_convert_node('mri_convert_2')

# --- Pass 3: register subjects to the pass-2 average ----------------------
antsreg3 = _rigid_registration()
antsreg_rigid3 = Node(antsreg3, name='antsreg_rigid_3')
apply_rigid_reg3 = ApplyTransforms()
apply_rigid3 = _apply_transforms_mapnode(apply_rigid_reg3, 'apply_rigid_3')
sl3 = _select_node('select_lists_3')
ml4 = _merge_joinnode('merge_lists_4', 'info_source_3')
avg_rigid3 = _average_node('average_rigid_3')
mc3 = _mri_convert_node('mri_convert_3')

In [None]:
# Establish input/output stream

#create subject ID iterable
def _subject_source(name):
    """Return an IdentityInterface node iterating ``subject_id`` over ``subject_list``.

    Each of these nodes is the joinsource of its workflow's merge JoinNode,
    so the names must remain 'info_source', 'info_source_2', 'info_source_3'.
    """
    node = Node(IdentityInterface(fields=['subject_id']), name=name)
    node.iterables = [('subject_id', subject_list)]
    return node


# One subject iterator per workflow.
infosource = _subject_source('info_source')
infosource2 = _subject_source('info_source_2')
infosource3 = _subject_source('info_source_3')

# SelectFiles glob templates. ``[0-2]`` matches exactly the three image-type
# indices; a class spelled ``[0, 1, 2]`` would also match ',' and ' '.
lhtemplate_files = opj(datadir, 'lhtemplate[0-2].nii.gz')
mi_files = opj(datadir, '{subject_id}-*.nii.gz')
# Averages written by pass 1 and pass 2 (one per Select index); these become
# the fixed images of the following pass.
new_template_files = opj(datadir, 'normflow/_index_[0-2]/mri_convert/average_out.nii.gz')
new_template_files_2 = opj(datadir, 'normflow2/_index_[0-2]/mri_convert_2/average_out.nii.gz')

templates = {'lhtemplate': lhtemplate_files,
             'mi': mi_files,
             }

templates2 = {'mi': mi_files,
              'new_template': new_template_files,
              }

templates3 = {'mi': mi_files,
              'new_template_2': new_template_files_2,
              }

# Select images organized by subject (sorted so image order is stable).
selectfiles = Node(SelectFiles(templates, force_lists=['lhtemplate', 'mi'],
                               sort_filelist=True,
                               base_directory=datadir),
                   name="select_files")

selectfiles2 = Node(SelectFiles(templates2, force_lists=['mi', 'new_template'],
                                sort_filelist=True,
                                base_directory=datadir),
                    name="select_files_2")

selectfiles3 = Node(SelectFiles(templates3, force_lists=['mi', 'new_template_2'],
                                sort_filelist=True,
                                base_directory=datadir),
                    name="select_files_3")

#datasink = Node(DataSink(base_directory= datadir, container = 'output_dir'), name = "datasink")
#substitutions = [('_subject_id_',''),
#                ]

#Define function to replicate fwd transforms to match iterfield length
def reptrans(forward_transforms):
    """Repeat ``forward_transforms`` three times as one flat list.

    Used as a connection function so the transform list has the same length
    as the three-image iterfields of the ApplyTransforms MapNode.
    numpy is imported inside the function so that the function stays
    self-contained when nipype runs it as a connection function.
    """
    import numpy as np
    tiled_rows = np.tile(forward_transforms, (1, 3)).tolist()
    repeated = []
    for row in tiled_rows:
        repeated.extend(row)
    return repeated

def flatten_nlist(out):
    """Flatten one level of nesting, e.g. [[a], [b, c]] -> [a, b, c]."""
    flat = []
    for inner in out:
        flat.extend(inner)
    return flat

# Create pipeline and connect nodes
# Pass 1: rigidly register every subject's images to the initial lhtemplate
# images, then average the registered images per image type.
workflow = Workflow(name='normflow')
workflow.base_dir = datadir
workflow.connect([
    (infosource, selectfiles, [('subject_id', 'subject_id')]),
    (selectfiles, antsreg_rigid, [('lhtemplate', 'fixed_image'), ('mi', 'moving_image')]),
    (selectfiles, apply_rigid, [('lhtemplate', 'reference_image'), ('mi', 'input_image')]),
    # reptrans repeats the forward transforms to match the MapNode iterfields.
    (antsreg_rigid, apply_rigid, [(('forward_transforms', reptrans), 'transforms')]),
    (apply_rigid, sl, [('output_image', 'inlist')]),
    (sl, ml, [('out', 'in1')]),
    (ml, avg_rigid, [(('out', flatten_nlist), 'images')]),
    (avg_rigid, mc, [('output_average_image', 'in_file')]),
])

# Pass 2: same pipeline, registering to the pass-1 averages.
workflow2 = Workflow(name='normflow2')
workflow2.base_dir = datadir
workflow2.connect([
    (infosource2, selectfiles2, [('subject_id', 'subject_id')]),
    (selectfiles2, antsreg_rigid2, [('mi', 'moving_image'), ('new_template', 'fixed_image')]),
    (selectfiles2, apply_rigid2, [('mi', 'input_image'), ('new_template', 'reference_image')]),
    (antsreg_rigid2, apply_rigid2, [(('forward_transforms', reptrans), 'transforms')]),
    (apply_rigid2, sl2, [('output_image', 'inlist')]),
    (sl2, ml3, [('out', 'in1')]),
    (ml3, avg_rigid2, [(('out', flatten_nlist), 'images')]),
    (avg_rigid2, mc2, [('output_average_image', 'in_file')]),
])

# Pass 3: same pipeline, registering to the pass-2 averages.
workflow3 = Workflow(name='normflow3')
workflow3.base_dir = datadir
workflow3.connect([
    (infosource3, selectfiles3, [('subject_id', 'subject_id')]),
    (selectfiles3, antsreg_rigid3, [('mi', 'moving_image'), ('new_template_2', 'fixed_image')]),
    (selectfiles3, apply_rigid3, [('mi', 'input_image'), ('new_template_2', 'reference_image')]),
    (antsreg_rigid3, apply_rigid3, [(('forward_transforms', reptrans), 'transforms')]),
    (apply_rigid3, sl3, [('output_image', 'inlist')]),
    (sl3, ml4, [('out', 'in1')]),
    (ml4, avg_rigid3, [(('out', flatten_nlist), 'images')]),
    (avg_rigid3, mc3, [('output_average_image', 'in_file')]),
])

# Visualize each workflow: a detailed execution graph and a simplified one.
import pydotplus
for wf in (workflow, workflow2, workflow3):
    wf.write_graph(graph2use='exec', format='png')
    wf.write_graph(graph2use='colored', format='png')

# SLURM submission options shared by all three runs.
# In sbatch, '-t' means --time (not a thread count) and '-g' is not a memory
# flag, so CPUs and memory are requested with explicit sbatch options.
# NOTE(review): assumes 4 CPUs / 4 GB per job is the intended request -- confirm.
# 'jobid_re' presumably relies on sbatch printing only the numeric job ID on
# this cluster -- confirm.
slurm_plugin_args = {
    'jobid_re': '([0-9]*)',
    'sbatch_args': '--cpus-per-task=4 --mem=4G --partition=nimh --time=00:30:00',
}

# Run strictly in order: passes 2 and 3 glob the averaged templates written
# by the previous pass, so each run must finish before the next one starts.
for wf in (workflow, workflow2, workflow3):
    wf.run(plugin='SLURM', plugin_args=dict(slurm_plugin_args))

In [11]:
# Scratch pad: quick checks of the list manipulations used in the pipeline.

import numpy as np

# Same tile-then-flatten trick that reptrans uses.
a = [1, 2, 3]
a = np.tile(a, [1, 3]).tolist()
flattened = [item for row in a for item in row]
print(flattened)

# Every third index starting at 0; l2 is l + 1 and l3 is l2 + 2.
l = list(range(0, 102, 3))
l2 = [value + 1 for value in l]
l3 = [value + 2 for value in l2]
print(l2)
print(l3)

# Presumably one index per subject (subject_list has 29 entries).
l4 = list(range(29))
print(l4)

[1, 2, 3, 1, 2, 3, 1, 2, 3]
[1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46, 49, 52, 55, 58, 61, 64, 67, 70, 73, 76, 79, 82, 85, 88, 91, 94, 97, 100]
[3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, 84, 87, 90, 93, 96, 99, 102]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28]
