diff --git a/NodeGraphQt/base/factory.py b/NodeGraphQt/base/factory.py index 7f2df8a5..9db79785 100644 --- a/NodeGraphQt/base/factory.py +++ b/NodeGraphQt/base/factory.py @@ -42,7 +42,7 @@ def nodes(self): """ return self.__nodes - def create_node_instance(self, node_type=None): + def create_node_instance(self, node_type=None, node_layout_direction=None): """ create node object by the node type identifier or alias. @@ -58,7 +58,7 @@ def create_node_instance(self, node_type=None): _NodeClass = self.__nodes.get(node_type) if not _NodeClass: print('can\'t find node type {}'.format(node_type)) - return _NodeClass() + return _NodeClass(node_layout_direction=node_layout_direction) def register_node(self, node, alias=None): """ diff --git a/NodeGraphQt/base/graph.py b/NodeGraphQt/base/graph.py index caa27ad6..e24ddf79 100644 --- a/NodeGraphQt/base/graph.py +++ b/NodeGraphQt/base/graph.py @@ -17,7 +17,7 @@ from NodeGraphQt.base.node import NodeObject from NodeGraphQt.base.port import Port from NodeGraphQt.constants import ( - NODE_LAYOUT_DIRECTION, NODE_LAYOUT_HORIZONTAL, NODE_LAYOUT_VERTICAL, + NODE_LAYOUT_HORIZONTAL, NODE_LAYOUT_VERTICAL, PipeLayoutEnum, URI_SCHEME, URN_SCHEME, PortTypeEnum, @@ -127,7 +127,8 @@ def __init__(self, parent=None, **kwargs): kwargs.get('model') or NodeGraphModel()) self._node_factory = ( kwargs.get('node_factory') or NodeFactory()) - + self._node_layout_direction = kwargs.get('node_layout_direction', + None) self._undo_view = None self._undo_stack = ( kwargs.get('undo_stack') or QtWidgets.QUndoStack(self)) @@ -137,7 +138,7 @@ def __init__(self, parent=None, **kwargs): self._sub_graphs = {} self._viewer = ( - kwargs.get('viewer') or NodeViewer(undo_stack=self._undo_stack)) + kwargs.get('viewer') or NodeViewer(undo_stack=self._undo_stack, node_layout_direction=self._node_layout_direction)) self._build_context_menu() self._register_builtin_nodes() @@ -889,7 +890,7 @@ def create_node(self, node_type, name=None, selected=True, color=None, Returns: BaseNode: the created instance of the node. """ - node = self._node_factory.create_node_instance(node_type) + node = self._node_factory.create_node_instance(node_type, node_layout_direction=self._node_layout_direction) if node: node._graph = self node.model._graph_model = self.model @@ -1672,7 +1673,7 @@ def auto_layout_nodes(self, nodes=None, down_stream=True, start_nodes=None): else: rank_map[rank] = [node] - if NODE_LAYOUT_DIRECTION is NODE_LAYOUT_HORIZONTAL: + if self._node_layout_direction is NODE_LAYOUT_HORIZONTAL: current_x = 0 node_height = 120 for rank in sorted(range(len(rank_map)), reverse=not down_stream): @@ -1687,7 +1688,7 @@ def auto_layout_nodes(self, nodes=None, down_stream=True, start_nodes=None): current_y += dy * 0.5 + 10 current_x += max_width * 0.5 + 100 - elif NODE_LAYOUT_DIRECTION is NODE_LAYOUT_VERTICAL: + elif self._node_layout_direction is NODE_LAYOUT_VERTICAL: current_y = 0 node_width = 250 for rank in sorted(range(len(rank_map)), reverse=not down_stream): @@ -1965,9 +1966,9 @@ def _build_port_nodes(self): input_nodes[port.name()] = input_node self.add_node(input_node, selected=False, push_undo=False) x, y = input_node.pos() - if NODE_LAYOUT_DIRECTION is NODE_LAYOUT_HORIZONTAL: + if self._node_layout_direction is NODE_LAYOUT_HORIZONTAL: x -= 100 - elif NODE_LAYOUT_DIRECTION is NODE_LAYOUT_VERTICAL: + elif self._node_layout_direction is NODE_LAYOUT_VERTICAL: y -= 100 input_node.set_property('pos', [x, y], push_undo=False) @@ -1983,9 +1984,9 @@ def _build_port_nodes(self): output_nodes[port.name()] = output_node self.add_node(output_node, selected=False, push_undo=False) x, y = output_node.pos() - if NODE_LAYOUT_DIRECTION is NODE_LAYOUT_HORIZONTAL: + if self._node_layout_direction is NODE_LAYOUT_HORIZONTAL: x += 100 - elif NODE_LAYOUT_DIRECTION is NODE_LAYOUT_VERTICAL: + elif self._node_layout_direction is NODE_LAYOUT_VERTICAL: y += 100 output_node.set_property('pos', [x, y], push_undo=False) diff --git a/NodeGraphQt/base/node.py b/NodeGraphQt/base/node.py index 57fb4340..bd8011b8 100644 --- a/NodeGraphQt/base/node.py +++ b/NodeGraphQt/base/node.py @@ -51,16 +51,17 @@ def __init__(self, qgraphics_views=None): # Base node name. NODE_NAME = None - def __init__(self, qgraphics_views=None): + def __init__(self, qgraphics_views=None, node_layout_direction=None): self._graph = None self._model = NodeModel() self._model.type_ = self.type_ self._model.name = self.NODE_NAME - _NodeItem = None - if NODE_LAYOUT_DIRECTION is NODE_LAYOUT_VERTICAL: + if node_layout_direction is None: + node_layout_direction = NODE_LAYOUT_DIRECTION + if node_layout_direction is NODE_LAYOUT_VERTICAL: _NodeItem = qgraphics_views.get(NODE_LAYOUT_VERTICAL) - elif NODE_LAYOUT_DIRECTION is NODE_LAYOUT_HORIZONTAL: + elif node_layout_direction is NODE_LAYOUT_HORIZONTAL: _NodeItem = qgraphics_views.get(NODE_LAYOUT_HORIZONTAL) if _NodeItem is None: @@ -68,7 +69,7 @@ def __init__(self, qgraphics_views=None): 'qgraphics item for the {} node layout can\'t be None!'.format({ NODE_LAYOUT_VERTICAL: 'vertical', NODE_LAYOUT_HORIZONTAL: 'horizontal' - }[NODE_LAYOUT_DIRECTION])) + }[node_layout_direction])) self._view = _NodeItem() self._view.type_ = self.type_ diff --git a/NodeGraphQt/nodes/base_node.py b/NodeGraphQt/nodes/base_node.py index 8ad5a84e..f6302276 100644 --- a/NodeGraphQt/nodes/base_node.py +++ b/NodeGraphQt/nodes/base_node.py @@ -57,12 +57,13 @@ def __init__(self): NODE_NAME = 'Node' - def __init__(self, qgraphics_views=None): + def __init__(self, qgraphics_views=None, node_layout_direction=None): qgraphics_views = qgraphics_views or { NODE_LAYOUT_HORIZONTAL: NodeItem, NODE_LAYOUT_VERTICAL: NodeItemVertical } - super(BaseNode, self).__init__(qgraphics_views) + super(BaseNode, self).__init__(qgraphics_views=qgraphics_views, + node_layout_direction=node_layout_direction) self._inputs = [] self._outputs = [] diff --git a/NodeGraphQt/nodes/group_node.py b/NodeGraphQt/nodes/group_node.py index 595c7891..0eb7f56a 100644 --- a/NodeGraphQt/nodes/group_node.py +++ b/NodeGraphQt/nodes/group_node.py @@ -24,12 +24,13 @@ class GroupNode(BaseNode): NODE_NAME = 'Group' - def __init__(self, qgraphics_views=None): + def __init__(self, qgraphics_views=None, node_layout_direction=None): qgraphics_views = qgraphics_views or { NODE_LAYOUT_HORIZONTAL: GroupNodeItem, NODE_LAYOUT_VERTICAL: GroupNodeVerticalItem } - super(GroupNode, self).__init__(qgraphics_views) + super(GroupNode, self).__init__(qgraphics_views=qgraphics_views, + node_layout_direction=node_layout_direction) self._input_port_nodes = {} self._output_port_nodes = {} diff --git a/NodeGraphQt/nodes/port_node.py b/NodeGraphQt/nodes/port_node.py index f7ddc2f2..2dae4509 100644 --- a/NodeGraphQt/nodes/port_node.py +++ b/NodeGraphQt/nodes/port_node.py @@ -26,12 +26,14 @@ class PortInputNode(BaseNode): NODE_NAME = 'InputPort' - def __init__(self, qgraphics_views=None, parent_port=None): + def __init__(self, qgraphics_views=None, parent_port=None, + node_layout_direction=None): qgraphics_views = qgraphics_views or { NODE_LAYOUT_HORIZONTAL: PortInputNodeItem, NODE_LAYOUT_VERTICAL: PortInputNodeVerticalItem } - super(PortInputNode, self).__init__(qgraphics_views) + super(PortInputNode, self).__init__(qgraphics_views=qgraphics_views, + node_layout_direction=node_layout_direction) self._parent_port = parent_port @property @@ -87,12 +89,14 @@ class PortOutputNode(BaseNode): NODE_NAME = 'OutputPort' - def __init__(self, qgraphics_views=None, parent_port=None): + def __init__(self, qgraphics_views=None, parent_port=None, + node_layout_direction=None): qgraphics_views = qgraphics_views or { NODE_LAYOUT_HORIZONTAL: PortOutputNodeItem, NODE_LAYOUT_VERTICAL: PortOutputNodeVerticalItem } - super(PortOutputNode, self).__init__(qgraphics_views) + super(PortOutputNode, self).__init__(qgraphics_views=qgraphics_views, + node_layout_direction=node_layout_direction) self._parent_port = parent_port @property diff --git a/NodeGraphQt/qgraphics/pipe.py b/NodeGraphQt/qgraphics/pipe.py index 29beada0..9e6176ce 100644 --- a/NodeGraphQt/qgraphics/pipe.py +++ b/NodeGraphQt/qgraphics/pipe.py @@ -6,9 +6,8 @@ from NodeGraphQt.constants import ( PipeEnum, PipeLayoutEnum, PortTypeEnum, Z_VAL_PIPE, Z_VAL_NODE_WIDGET, - ITEM_CACHE_MODE, - NODE_LAYOUT_VERTICAL, NODE_LAYOUT_HORIZONTAL, - NODE_LAYOUT_DIRECTION + ITEM_CACHE_MODE, NODE_LAYOUT_DIRECTION, + NODE_LAYOUT_VERTICAL, NODE_LAYOUT_HORIZONTAL ) from NodeGraphQt.qgraphics.port import PortItem @@ -24,13 +23,15 @@ class PipeItem(QtWidgets.QGraphicsPathItem): Base Pipe item used for drawing node connections. """ - def __init__(self, input_port=None, output_port=None): + def __init__(self, input_port=None, output_port=None, + node_layout_direction=None): super(PipeItem, self).__init__() self.setZValue(Z_VAL_PIPE) self.setAcceptHoverEvents(True) self.setFlag(QtWidgets.QGraphicsItem.ItemIsSelectable) self._color = PipeEnum.COLOR.value self._style = PipeEnum.DRAW_TYPE_DEFAULT.value + self._active = False self._highlight = False self._input_port = input_port @@ -42,6 +43,10 @@ def __init__(self, input_port=None, output_port=None): self._arrow.append(QtCore.QPointF(size, size)) self.setCacheMode(ITEM_CACHE_MODE) + if node_layout_direction is None: + node_layout_direction = NODE_LAYOUT_DIRECTION + self._node_layout_direction = node_layout_direction + def __repr__(self): in_name = self._input_port.name if self._input_port else '' out_name = self._output_port.name if self._output_port else '' @@ -262,9 +267,9 @@ def draw_path(self, start_port, end_port=None, cursor_pos=None): self.setPath(path) return else: - if NODE_LAYOUT_DIRECTION is NODE_LAYOUT_VERTICAL: + if self._node_layout_direction is NODE_LAYOUT_VERTICAL: self.__draw_path_vertical(start_port, pos1, pos2, path) - elif NODE_LAYOUT_DIRECTION is NODE_LAYOUT_HORIZONTAL: + elif self._node_layout_direction is NODE_LAYOUT_HORIZONTAL: self.__draw_path_horizontal(start_port, pos1, pos2, path) def reset_path(self): diff --git a/NodeGraphQt/widgets/viewer.py b/NodeGraphQt/widgets/viewer.py index 6fc3e744..d8ded679 100644 --- a/NodeGraphQt/widgets/viewer.py +++ b/NodeGraphQt/widgets/viewer.py @@ -45,7 +45,7 @@ class NodeViewer(QtWidgets.QGraphicsView): node_double_clicked = QtCore.Signal(str) data_dropped = QtCore.Signal(QtCore.QMimeData, QtCore.QPoint) - def __init__(self, parent=None, undo_stack=None): + def __init__(self, parent=None, undo_stack=None, node_layout_direction=None): """ Args: parent: @@ -65,6 +65,7 @@ def __init__(self, parent=None, undo_stack=None): self.setAcceptDrops(True) self.resize(850, 800) + self._node_layout_direction = node_layout_direction self._scene_range = QtCore.QRectF( 0, 0, self.size().width(), self.size().height()) @@ -853,7 +854,7 @@ def establish_connection(self, start_port, end_port): establish a new pipe connection. (adds a new pipe item to draw between 2 ports) """ - pipe = PipeItem() + pipe = PipeItem(node_layout_direction=self._node_layout_direction) self.scene().addItem(pipe) pipe.set_connections(start_port, end_port) pipe.draw_path(pipe.input_port, pipe.output_port)