Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions NodeGraphQt/base/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def nodes(self):
"""
return self.__nodes

def create_node_instance(self, node_type=None):
def create_node_instance(self, node_type=None, node_layout_direction=None):
"""
create node object by the node type identifier or alias.

Expand All @@ -58,7 +58,7 @@ def create_node_instance(self, node_type=None):
_NodeClass = self.__nodes.get(node_type)
if not _NodeClass:
print('can\'t find node type {}'.format(node_type))
return _NodeClass()
return _NodeClass(node_layout_direction=node_layout_direction)

def register_node(self, node, alias=None):
"""
Expand Down
21 changes: 11 additions & 10 deletions NodeGraphQt/base/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from NodeGraphQt.base.node import NodeObject
from NodeGraphQt.base.port import Port
from NodeGraphQt.constants import (
NODE_LAYOUT_DIRECTION, NODE_LAYOUT_HORIZONTAL, NODE_LAYOUT_VERTICAL,
NODE_LAYOUT_HORIZONTAL, NODE_LAYOUT_VERTICAL,
PipeLayoutEnum,
URI_SCHEME, URN_SCHEME,
PortTypeEnum,
Expand Down Expand Up @@ -127,7 +127,8 @@ def __init__(self, parent=None, **kwargs):
kwargs.get('model') or NodeGraphModel())
self._node_factory = (
kwargs.get('node_factory') or NodeFactory())

self._node_layout_direction = kwargs.get('node_layout_direction',
None)
self._undo_view = None
self._undo_stack = (
kwargs.get('undo_stack') or QtWidgets.QUndoStack(self))
Expand All @@ -137,7 +138,7 @@ def __init__(self, parent=None, **kwargs):
self._sub_graphs = {}

self._viewer = (
kwargs.get('viewer') or NodeViewer(undo_stack=self._undo_stack))
kwargs.get('viewer') or NodeViewer(undo_stack=self._undo_stack, node_layout_direction=self._node_layout_direction))

self._build_context_menu()
self._register_builtin_nodes()
Expand Down Expand Up @@ -889,7 +890,7 @@ def create_node(self, node_type, name=None, selected=True, color=None,
Returns:
BaseNode: the created instance of the node.
"""
node = self._node_factory.create_node_instance(node_type)
node = self._node_factory.create_node_instance(node_type, node_layout_direction=self._node_layout_direction)
if node:
node._graph = self
node.model._graph_model = self.model
Expand Down Expand Up @@ -1672,7 +1673,7 @@ def auto_layout_nodes(self, nodes=None, down_stream=True, start_nodes=None):
else:
rank_map[rank] = [node]

if NODE_LAYOUT_DIRECTION is NODE_LAYOUT_HORIZONTAL:
if self._node_layout_direction is NODE_LAYOUT_HORIZONTAL:
current_x = 0
node_height = 120
for rank in sorted(range(len(rank_map)), reverse=not down_stream):
Expand All @@ -1687,7 +1688,7 @@ def auto_layout_nodes(self, nodes=None, down_stream=True, start_nodes=None):
current_y += dy * 0.5 + 10

current_x += max_width * 0.5 + 100
elif NODE_LAYOUT_DIRECTION is NODE_LAYOUT_VERTICAL:
elif self._node_layout_direction is NODE_LAYOUT_VERTICAL:
current_y = 0
node_width = 250
for rank in sorted(range(len(rank_map)), reverse=not down_stream):
Expand Down Expand Up @@ -1965,9 +1966,9 @@ def _build_port_nodes(self):
input_nodes[port.name()] = input_node
self.add_node(input_node, selected=False, push_undo=False)
x, y = input_node.pos()
if NODE_LAYOUT_DIRECTION is NODE_LAYOUT_HORIZONTAL:
if self._node_layout_direction is NODE_LAYOUT_HORIZONTAL:
x -= 100
elif NODE_LAYOUT_DIRECTION is NODE_LAYOUT_VERTICAL:
elif self._node_layout_direction is NODE_LAYOUT_VERTICAL:
y -= 100
input_node.set_property('pos', [x, y], push_undo=False)

Expand All @@ -1983,9 +1984,9 @@ def _build_port_nodes(self):
output_nodes[port.name()] = output_node
self.add_node(output_node, selected=False, push_undo=False)
x, y = output_node.pos()
if NODE_LAYOUT_DIRECTION is NODE_LAYOUT_HORIZONTAL:
if self._node_layout_direction is NODE_LAYOUT_HORIZONTAL:
x += 100
elif NODE_LAYOUT_DIRECTION is NODE_LAYOUT_VERTICAL:
elif self._node_layout_direction is NODE_LAYOUT_VERTICAL:
y += 100
output_node.set_property('pos', [x, y], push_undo=False)

Expand Down
11 changes: 6 additions & 5 deletions NodeGraphQt/base/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,24 +51,25 @@ def __init__(self, qgraphics_views=None):
# Base node name.
NODE_NAME = None

def __init__(self, qgraphics_views=None):
def __init__(self, qgraphics_views=None, node_layout_direction=None):
self._graph = None
self._model = NodeModel()
self._model.type_ = self.type_
self._model.name = self.NODE_NAME

_NodeItem = None
if NODE_LAYOUT_DIRECTION is NODE_LAYOUT_VERTICAL:
if node_layout_direction is None:
node_layout_direction = NODE_LAYOUT_DIRECTION
if node_layout_direction is NODE_LAYOUT_VERTICAL:
_NodeItem = qgraphics_views.get(NODE_LAYOUT_VERTICAL)
elif NODE_LAYOUT_DIRECTION is NODE_LAYOUT_HORIZONTAL:
elif node_layout_direction is NODE_LAYOUT_HORIZONTAL:
_NodeItem = qgraphics_views.get(NODE_LAYOUT_HORIZONTAL)

if _NodeItem is None:
raise ValueError(
'qgraphics item for the {} node layout can\'t be None!'.format({
NODE_LAYOUT_VERTICAL: 'vertical',
NODE_LAYOUT_HORIZONTAL: 'horizontal'
}[NODE_LAYOUT_DIRECTION]))
}[node_layout_direction]))

self._view = _NodeItem()
self._view.type_ = self.type_
Expand Down
5 changes: 3 additions & 2 deletions NodeGraphQt/nodes/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,13 @@ def __init__(self):

NODE_NAME = 'Node'

def __init__(self, qgraphics_views=None):
def __init__(self, qgraphics_views=None, node_layout_direction=None):
qgraphics_views = qgraphics_views or {
NODE_LAYOUT_HORIZONTAL: NodeItem,
NODE_LAYOUT_VERTICAL: NodeItemVertical
}
super(BaseNode, self).__init__(qgraphics_views)
super(BaseNode, self).__init__(qgraphics_views=qgraphics_views,
node_layout_direction=node_layout_direction)
self._inputs = []
self._outputs = []

Expand Down
5 changes: 3 additions & 2 deletions NodeGraphQt/nodes/group_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@ class GroupNode(BaseNode):

NODE_NAME = 'Group'

def __init__(self, qgraphics_views=None):
def __init__(self, qgraphics_views=None, node_layout_direction=None):
qgraphics_views = qgraphics_views or {
NODE_LAYOUT_HORIZONTAL: GroupNodeItem,
NODE_LAYOUT_VERTICAL: GroupNodeVerticalItem
}
super(GroupNode, self).__init__(qgraphics_views)
super(GroupNode, self).__init__(qgraphics_views=qgraphics_views,
node_layout_direction=node_layout_direction)
self._input_port_nodes = {}
self._output_port_nodes = {}

Expand Down
12 changes: 8 additions & 4 deletions NodeGraphQt/nodes/port_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ class PortInputNode(BaseNode):

NODE_NAME = 'InputPort'

def __init__(self, qgraphics_views=None, parent_port=None):
def __init__(self, qgraphics_views=None, parent_port=None,
node_layout_direction=None):
qgraphics_views = qgraphics_views or {
NODE_LAYOUT_HORIZONTAL: PortInputNodeItem,
NODE_LAYOUT_VERTICAL: PortInputNodeVerticalItem
}
super(PortInputNode, self).__init__(qgraphics_views)
super(PortInputNode, self).__init__(qgraphics_views=qgraphics_views,
node_layout_direction=node_layout_direction)
self._parent_port = parent_port

@property
Expand Down Expand Up @@ -87,12 +89,14 @@ class PortOutputNode(BaseNode):

NODE_NAME = 'OutputPort'

def __init__(self, qgraphics_views=None, parent_port=None):
def __init__(self, qgraphics_views=None, parent_port=None,
node_layout_direction=None):
qgraphics_views = qgraphics_views or {
NODE_LAYOUT_HORIZONTAL: PortOutputNodeItem,
NODE_LAYOUT_VERTICAL: PortOutputNodeVerticalItem
}
super(PortOutputNode, self).__init__(qgraphics_views)
super(PortOutputNode, self).__init__(qgraphics_views=qgraphics_views,
node_layout_direction=node_layout_direction)
self._parent_port = parent_port

@property
Expand Down
17 changes: 11 additions & 6 deletions NodeGraphQt/qgraphics/pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from NodeGraphQt.constants import (
PipeEnum, PipeLayoutEnum, PortTypeEnum, Z_VAL_PIPE,
Z_VAL_NODE_WIDGET,
ITEM_CACHE_MODE,
NODE_LAYOUT_VERTICAL, NODE_LAYOUT_HORIZONTAL,
NODE_LAYOUT_DIRECTION
ITEM_CACHE_MODE, NODE_LAYOUT_DIRECTION,
NODE_LAYOUT_VERTICAL, NODE_LAYOUT_HORIZONTAL
)
from NodeGraphQt.qgraphics.port import PortItem

Expand All @@ -24,13 +23,15 @@ class PipeItem(QtWidgets.QGraphicsPathItem):
Base Pipe item used for drawing node connections.
"""

def __init__(self, input_port=None, output_port=None):
def __init__(self, input_port=None, output_port=None,
node_layout_direction=None):
super(PipeItem, self).__init__()
self.setZValue(Z_VAL_PIPE)
self.setAcceptHoverEvents(True)
self.setFlag(QtWidgets.QGraphicsItem.ItemIsSelectable)
self._color = PipeEnum.COLOR.value
self._style = PipeEnum.DRAW_TYPE_DEFAULT.value

self._active = False
self._highlight = False
self._input_port = input_port
Expand All @@ -42,6 +43,10 @@ def __init__(self, input_port=None, output_port=None):
self._arrow.append(QtCore.QPointF(size, size))
self.setCacheMode(ITEM_CACHE_MODE)

if node_layout_direction is None:
node_layout_direction = NODE_LAYOUT_DIRECTION
self._node_layout_direction = node_layout_direction

def __repr__(self):
in_name = self._input_port.name if self._input_port else ''
out_name = self._output_port.name if self._output_port else ''
Expand Down Expand Up @@ -262,9 +267,9 @@ def draw_path(self, start_port, end_port=None, cursor_pos=None):
self.setPath(path)
return
else:
if NODE_LAYOUT_DIRECTION is NODE_LAYOUT_VERTICAL:
if self._node_layout_direction is NODE_LAYOUT_VERTICAL:
self.__draw_path_vertical(start_port, pos1, pos2, path)
elif NODE_LAYOUT_DIRECTION is NODE_LAYOUT_HORIZONTAL:
elif self._node_layout_direction is NODE_LAYOUT_HORIZONTAL:
self.__draw_path_horizontal(start_port, pos1, pos2, path)

def reset_path(self):
Expand Down
5 changes: 3 additions & 2 deletions NodeGraphQt/widgets/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class NodeViewer(QtWidgets.QGraphicsView):
node_double_clicked = QtCore.Signal(str)
data_dropped = QtCore.Signal(QtCore.QMimeData, QtCore.QPoint)

def __init__(self, parent=None, undo_stack=None):
def __init__(self, parent=None, undo_stack=None, node_layout_direction=None):
"""
Args:
parent:
Expand All @@ -65,6 +65,7 @@ def __init__(self, parent=None, undo_stack=None):

self.setAcceptDrops(True)
self.resize(850, 800)
self._node_layout_direction = node_layout_direction

self._scene_range = QtCore.QRectF(
0, 0, self.size().width(), self.size().height())
Expand Down Expand Up @@ -853,7 +854,7 @@ def establish_connection(self, start_port, end_port):
establish a new pipe connection.
(adds a new pipe item to draw between 2 ports)
"""
pipe = PipeItem()
pipe = PipeItem(node_layout_direction=self._node_layout_direction)
self.scene().addItem(pipe)
pipe.set_connections(start_port, end_port)
pipe.draw_path(pipe.input_port, pipe.output_port)
Expand Down