Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions nipype/interfaces/spm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,10 @@ def _list_outputs(self):

def _format_arg(self, opt, spec, val):
"""Convert input to appropriate format for SPM."""

return val
if spec.is_trait_type(traits.Bool):
return int(val)
else:
return val

def _parse_inputs(self, skip=()):
spmdict = {}
Expand Down
8 changes: 3 additions & 5 deletions nipype/interfaces/spm/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def _format_arg(self, opt, spec, val):
return [val]
else:
return val
return val
return super(Level1Design, self)._format_arg(opt, spec, val)

def _parse_inputs(self):
"""validate spm realign options if set to None ignore
Expand Down Expand Up @@ -200,7 +200,7 @@ def _format_arg(self, opt, spec, val):
return {'%s' % val: 1}
else:
return val
return val
return super(EstimateModel, self)._format_arg(opt, spec, val)

def _parse_inputs(self):
"""validate spm realign options if set to None ignore
Expand Down Expand Up @@ -747,7 +747,7 @@ def _format_arg(self, opt, spec, val):
outdict[mapping[key]] = keyval
outlist.append(outdict)
return outlist
return val
return super(FactorialDesign, self)._format_arg(opt, spec, val)

def _parse_inputs(self):
"""validate spm realign options if set to None ignore
Expand Down Expand Up @@ -893,8 +893,6 @@ def _format_arg(self, opt, spec, val):
"""
if opt in ['in_files']:
return np.array(val, dtype=object)
if opt in ['include_intercept']:
return int(val)
if opt in ['user_covariates']:
outlist = []
mapping = {'name': 'cname', 'vector': 'c',
Expand Down
26 changes: 9 additions & 17 deletions nipype/interfaces/spm/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def _format_arg(self, opt, spec, val):
return scans_for_fnames(filename_to_list(val),
keep4d=False,
separate_sessions=True)
return val
return super(SliceTiming, self)._format_arg(opt, spec, val)

def _list_outputs(self):
outputs = self._outputs().get()
Expand Down Expand Up @@ -177,9 +177,7 @@ def _format_arg(self, opt, spec, val):
return scans_for_fnames(val,
keep4d=True,
separate_sessions=True)
if opt == 'register_to_mean': # XX check if this is necessary
return int(val)
return val
return super(Realign, self)._format_arg(opt, spec, val)

def _parse_inputs(self):
"""validate spm realign options if set to None ignore
Expand Down Expand Up @@ -296,7 +294,7 @@ def _format_arg(self, opt, spec, val):
return scans_for_fnames(val+self.inputs.apply_to_files)
else:
return scans_for_fnames(val)
return val
return super(Coregister, self)._format_arg(opt, spec, val)

def _parse_inputs(self):
"""validate spm coregister options if set to None ignore
Expand Down Expand Up @@ -411,7 +409,7 @@ def _format_arg(self, opt, spec, val):
if opt in ['write_wrap']:
if len(val) != 3:
raise ValueError('%s must have 3 elements' % opt)
return val
return super(Normalize, self)._format_arg(opt, spec, val)

def _parse_inputs(self):
"""validate spm realign options if set to None ignore
Expand Down Expand Up @@ -560,13 +558,11 @@ def _format_arg(self, opt, spec, val):
return scans_for_fname(val)
if 'output_type' in opt:
return [int(v) for v in val]
if opt == 'save_bias_corrected':
return int(val)
if opt == 'mask_image':
return scans_for_fname(val)
if opt == 'clean_masks':
return clean_masks_dict[val]
return val
return super(Segment, self)._format_arg(opt, spec, val)

def _list_outputs(self):
outputs = self._outputs().get()
Expand Down Expand Up @@ -785,10 +781,8 @@ def _format_arg(self, opt, spec, val):
return [val[0], val[0], val[0]]
else:
return val
if opt == 'implicit_masking':
return int(val)

return val
return super(Smooth, self)._format_arg(opt, spec, val)

def _list_outputs(self):
outputs = self._outputs().get()
Expand Down Expand Up @@ -879,7 +873,7 @@ def _format_arg(self, opt, spec, val):
new_param['its'] = val[2]
return [new_param]
else:
return val
return super(DARTEL, self)._format_arg(opt, spec, val)

def _list_outputs(self):
outputs = self._outputs().get()
Expand Down Expand Up @@ -965,10 +959,8 @@ def _format_arg(self, opt, spec, val):
return val
else:
return [val, val, val]
elif opt == 'modulate':
return int(val)
else:
return val
return super(DARTELNorm2MNI, self)._format_arg(opt, spec, val)

def _list_outputs(self):
outputs = self._outputs().get()
Expand Down Expand Up @@ -1040,7 +1032,7 @@ def _format_arg(self, opt, spec, val):
if opt in ['flowfield_files']:
return scans_for_fnames(val, keep4d=True)
else:
return val
return super(CreateWarped, self)._format_arg(opt, spec, val)

def _list_outputs(self):
outputs = self._outputs().get()
Expand Down
16 changes: 15 additions & 1 deletion nipype/interfaces/spm/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import nipype.interfaces.spm.base as spm
from nipype.interfaces.spm import no_spm
import nipype.interfaces.matlab as mlab
from nipype.interfaces.spm.base import SPMCommandInputSpec
from nipype.interfaces.base import traits

try:
matlab_cmd = os.environ['MATLABCMD']
Expand Down Expand Up @@ -121,7 +123,19 @@ class TestClass(spm.SPMCommand):
out = dc._generate_job(prefix='test', contents=contents)
yield assert_equal, out, 'test.onsets = {...\n[1, 2, 3, 4];...\n};\n'


def test_bool():
class TestClassInputSpec(SPMCommandInputSpec):
test_in = include_intercept = traits.Bool(field='testfield')

class TestClass(spm.SPMCommand):
input_spec = TestClassInputSpec
_jobtype = 'jobtype'
_jobname = 'jobname'
dc = TestClass() # dc = derived_class
dc.inputs.test_in = True
out = dc._make_matlab_command(dc._parse_inputs())
yield assert_equal, out.find('jobs{1}.jobtype{1}.jobname{1}.testfield = 1;') > 0, 1

def test_make_matlab_command():
class TestClass(spm.SPMCommand):
_jobtype = 'jobtype'
Expand Down