Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 21 additions & 20 deletions pydra/engine/auxiliary.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import pdb
import inspect, os
# import pdb
import os
import inspect
import logging
logger = logging.getLogger('nipype.workflow')
from nipype import Node

logger = logging.getLogger('nipype.workflow')


# dj: might create a new class or move to State

Expand Down Expand Up @@ -50,32 +52,34 @@ def _ordering(el, i, output_mapper, current_sign=None, other_mappers=None):
output_mapper.append(el)
else:
raise Exception("mapper has to be a string, a tuple or a list")

if i > 0:
output_mapper.append(current_sign)


def _iterate_list(element, sign, other_mappers, output_mapper):
""" Used in the mapper2rpn to get recursion. """
for i, el in enumerate(element):
_ordering(el, i, current_sign=sign, other_mappers=other_mappers, output_mapper=output_mapper)
_ordering(el, i, current_sign=sign, other_mappers=other_mappers,
output_mapper=output_mapper)


# functions used in State to know which element should be used for a specific axis

def mapping_axis(state_inputs, mapper_rpn):
"""Having inputs and mapper (in rpn notation), functions returns the axes of output for every input."""
"""Given inputs and mapper (in rpn notation), return the axes of output of each input."""
axis_for_input = {}
stack = []
current_axis = None
current_shape = None
#pdb.set_trace()
# pdb.set_trace()
for el in mapper_rpn:
if el == ".":
right = stack.pop()
left = stack.pop()
if left == "OUT":
if state_inputs[right].shape == current_shape: #todo:should we allow for one-element array?
# todo:should we allow for one-element array?
if state_inputs[right].shape == current_shape:
axis_for_input[right] = current_axis
else:
raise Exception("arrays for scalar operations should have the same size")
Expand All @@ -94,7 +98,7 @@ def mapping_axis(state_inputs, mapper_rpn):
axis_for_input[right] = current_axis
else:
raise Exception("arrays for scalar operations should have the same size")

stack.append("OUT")

elif el == "*":
Expand All @@ -120,7 +124,7 @@ def mapping_axis(state_inputs, mapper_rpn):
axis_for_input[right] = [i + state_inputs[left].ndim
for i in range(state_inputs[right].ndim)]
current_axis = axis_for_input[left] + axis_for_input[right]
current_shape = tuple([i for i in
current_shape = tuple([i for i in
state_inputs[left].shape + state_inputs[right].shape])
stack.append("OUT")

Expand Down Expand Up @@ -149,12 +153,12 @@ def converting_axis2input(state_inputs, axis_for_input, ndim):
for i in range(ndim):
input_for_axis.append([])
shape.append(0)

for inp, axis in axis_for_input.items():
for (i, ax) in enumerate(axis):
input_for_axis[ax].append(inp)
shape[ax] = state_inputs[inp].shape[i]

return input_for_axis, shape


Expand Down Expand Up @@ -190,8 +194,6 @@ def _add_name(mlist, name):
return mlist


#Function interface

class FunctionInterface(object):
""" A new function interface """
def __init__(self, function, output_nm, out_read=False, input_map=None):
Expand All @@ -209,7 +211,6 @@ def __init__(self, function, output_nm, out_read=False, input_map=None):
# flags if we want to read the txt file to save in node.output
self.out_read = out_read


def run(self, input):
self.output = {}
if self.input_map:
Expand All @@ -226,7 +227,7 @@ def run(self, input):
self.output[self._output_nm[i]] = out
else:
raise Exception("length of output_nm doesnt match length of the function output")
elif len(self._output_nm)==1:
elif len(self._output_nm) == 1:
self.output[self._output_nm[0]] = fun_output
else:
raise Exception("output_nm doesnt match length of the function output")
Expand All @@ -241,8 +242,8 @@ class DotDict(dict):
"""dot.notation access to dictionary attributes"""
def __getattr__(self, attr):
return self.get(attr)
__setattr__= dict.__setitem__
__delattr__= dict.__delitem__
__setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__

def __getstate__(self):
return self
Expand All @@ -262,7 +263,7 @@ def run(self, inputs, base_dir, dir_nm_el):
for key, val in inputs.items():
key = key.split(".")[-1]
setattr(self.nn.inputs, key, val)
#have to set again self._output_dir in case of mapper
# have to set again self._output_dir in case of mapper
self.nn._output_dir = os.path.join(self.nn.base_dir, self.nn.name)
res = self.nn.run()
return res
return res
Loading