Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 152 additions & 6 deletions NodeGraphQt/base/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@
from .model import NodeGraphModel
from .node import NodeObject, BaseNode, BackdropNode
from .port import Port
from ..constants import (URI_SCHEME, URN_SCHEME,
PIPE_LAYOUT_CURVED,
PIPE_LAYOUT_STRAIGHT,
PIPE_LAYOUT_ANGLE,
IN_PORT, OUT_PORT,
VIEWER_GRID_LINES)
from ..constants import (
URI_SCHEME, URN_SCHEME,
NODE_LAYOUT_DIRECTION, NODE_LAYOUT_HORIZONTAL, NODE_LAYOUT_VERTICAL,
PIPE_LAYOUT_CURVED, PIPE_LAYOUT_STRAIGHT, PIPE_LAYOUT_ANGLE,
IN_PORT, OUT_PORT,
VIEWER_GRID_LINES
)
from ..widgets.node_space_bar import node_space_bar
from ..widgets.viewer import NodeViewer

Expand Down Expand Up @@ -1558,6 +1559,149 @@ def disable_nodes(self, nodes, mode=None):
return
nodes[0].set_disabled(mode)

# auto layout node functions.

@staticmethod
def _update_node_rank(node, nodes_rank, down_stream=True):
"""
Recursive function for updating the node ranking.

Args:
node (NodeGraphQt.BaseNode): node to start from.
nodes_rank (dict): node ranking object to be updated.
down_stream (bool): true to rank down stram.
"""
if down_stream:
node_values = node.connected_output_nodes().values()
else:
node_values = node.connected_input_nodes().values()

connected_nodes = set()
for nodes in node_values:
connected_nodes.update(nodes)

rank = nodes_rank[node] + 1
for n in connected_nodes:
if n in nodes_rank:
nodes_rank[n] = max(nodes_rank[n], rank)
else:
nodes_rank[n] = rank
NodeGraph._update_node_rank(n, nodes_rank, down_stream)

@staticmethod
def _compute_node_rank(nodes, down_stream=True):
"""
Compute the ranking of nodes.

Args:
nodes (list[NodeGraphQt.BaseNode]): nodes to start ranking from.
down_stream (bool): true to compute down stream.

Returns:
dict: {NodeGraphQt.BaseNode: node_rank, ...}
"""
nodes_rank = {}
for node in nodes:
nodes_rank[node] = 0
NodeGraph._update_node_rank(node, nodes_rank, down_stream)
return nodes_rank

def auto_layout_nodes(self, nodes=None, down_stream=True, start_nodes=None):
"""
Auto layout the nodes in the node graph.

Note:
If the node graph is acyclic then the ``start_nodes`` will need
to be specified.

Args:
nodes (list[NodeGraphQt.BaseNode]): list of nodes to auto layout
if nodes is None then all nodes is layed out.
down_stream (bool): false to layout up stream.
start_nodes (list[NodeGraphQt.BaseNode]):
list of nodes to start the auto layout from (Optional).
"""
self.begin_undo('Auto Layout Nodes')

nodes = nodes or self.all_nodes()

# filter out the backdrops.
backdrops = {
n: n.nodes() for n in nodes if isinstance(n, BackdropNode)
}
filtered_nodes = [n for n in nodes if not isinstance(n, BackdropNode)]

start_nodes = start_nodes or []
if down_stream:
start_nodes += [
n for n in filtered_nodes
if not any(n.connected_input_nodes().values())
]
else:
start_nodes += [
n for n in filtered_nodes
if not any(n.connected_output_nodes().values())
]

if not start_nodes:
return

node_views = [n.view for n in nodes]
nodes_center_0 = self.viewer().nodes_rect_center(node_views)

nodes_rank = NodeGraph._compute_node_rank(start_nodes, down_stream)

rank_map = {}
for node, rank in nodes_rank.items():
if rank in rank_map:
rank_map[rank].append(node)
else:
rank_map[rank] = [node]

if NODE_LAYOUT_DIRECTION is NODE_LAYOUT_HORIZONTAL:
current_x = 0
node_height = 120
for rank in sorted(range(len(rank_map)), reverse=not down_stream):
ranked_nodes = rank_map[rank]
max_width = max([node.view.width for node in ranked_nodes])
current_x += max_width
current_y = 0
for idx, node in enumerate(ranked_nodes):
dy = max(node_height, node.view.height)
current_y += 0 if idx == 0 else dy
node.set_pos(current_x, current_y)
current_y += dy * 0.5 + 10

current_x += max_width * 0.5 + 100
elif NODE_LAYOUT_DIRECTION is NODE_LAYOUT_VERTICAL:
current_y = 0
node_width = 250
for rank in sorted(range(len(rank_map)), reverse=not down_stream):
ranked_nodes = rank_map[rank]
max_height = max([node.view.height for node in ranked_nodes])
current_y += max_height
current_x = 0
for idx, node in enumerate(ranked_nodes):
dx = max(node_width, node.view.width)
current_x += 0 if idx == 0 else dx
node.set_pos(current_x, current_y)
current_x += dx * 0.5 + 10

current_y += max_height * 0.5 + 100

nodes_center_1 = self.viewer().nodes_rect_center(node_views)
dx = nodes_center_0[0] - nodes_center_1[0]
dy = nodes_center_0[1] - nodes_center_1[1]
[n.set_pos(n.x_pos() + dx, n.y_pos() + dy) for n in nodes]

# wrap the backdrop nodes.
for backdrop, contained_nodes in backdrops.items():
backdrop.wrap_nodes(contained_nodes)

self.end_undo()

# prompt dialog functions.

def question_dialog(self, text, title='Node Graph'):
"""
Prompts a question open dialog with ``"Yes"`` and ``"No"`` buttons in
Expand Down Expand Up @@ -1624,6 +1768,8 @@ def save_dialog(self, current_dir=None, ext=None):
"""
return self._viewer.save_dialog(current_dir, ext)

### ---

def use_OpenGL(self):
"""
Use OpenGL to draw the graph.
Expand Down
Loading