diff --git a/NodeGraphQt/base/graph.py b/NodeGraphQt/base/graph.py index 6dc22011..842ded98 100644 --- a/NodeGraphQt/base/graph.py +++ b/NodeGraphQt/base/graph.py @@ -160,6 +160,11 @@ def __init__(self, parent=None, **kwargs): kwargs.get('viewer') or NodeViewer(undo_stack=self._undo_stack)) self._viewer.set_layout_direction(layout_direction) + # viewer needs a reference to the model port connection constrains + # for the user interaction with the live pipe. + self._viewer.accept_connection_types = self._model.accept_connection_types + self._viewer.reject_connection_types = self._model.reject_connection_types + self._context_menu = {} self._register_context_menu() @@ -1143,6 +1148,39 @@ def create_node(self, node_type, name=None, selected=True, color=None, node_attrs[node.type_][pname].update(pattrs) self.model.set_node_common_properties(node_attrs) + accept_types = node.model.__dict__.pop( + '_TEMP_accept_connection_types' + ) + for ptype, pdata in accept_types.get(node.type_, {}).items(): + for pname, accept_data in pdata.items(): + for accept_ntype, accept_ndata in accept_data.items(): + for accept_ptype, accept_pnames in accept_ndata.items(): + for accept_pname in accept_pnames: + self._model.add_port_accept_connection_type( + port_name=pname, + port_type=ptype, + node_type=node.type_, + accept_pname=accept_pname, + accept_ptype=accept_ptype, + accept_ntype=accept_ntype + ) + reject_types = node.model.__dict__.pop( + '_TEMP_reject_connection_types' + ) + for ptype, pdata in reject_types.get(node.type_, {}).items(): + for pname, reject_data in pdata.items(): + for reject_ntype, reject_ndata in reject_data.items(): + for reject_ptype, reject_pnames in reject_ndata.items(): + for reject_pname in reject_pnames: + self._model.add_port_reject_connection_type( + port_name=pname, + port_type=ptype, + node_type=node.type_, + reject_pname=reject_pname, + reject_ptype=reject_ptype, + reject_ntype=reject_ntype + ) + node.NODE_NAME = self.get_unique_name(name or node.NODE_NAME) node.model.name = node.NODE_NAME node.model.selected = selected @@ -1207,6 +1245,39 @@ def add_node(self, node, pos=None, selected=True, push_undo=True): node_attrs[node.type_][pname].update(pattrs) self.model.set_node_common_properties(node_attrs) + accept_types = node.model.__dict__.pop( + '_TEMP_accept_connection_types' + ) + for ptype, pdata in accept_types.get(node.type_, {}).items(): + for pname, accept_data in pdata.items(): + for accept_ntype, accept_ndata in accept_data.items(): + for accept_ptype, accept_pnames in accept_ndata.items(): + for accept_pname in accept_pnames: + self._model.add_port_accept_connection_type( + port_name=pname, + port_type=ptype, + node_type=node.type_, + accept_pname=accept_pname, + accept_ptype=accept_ptype, + accept_ntype=accept_ntype + ) + reject_types = node.model.__dict__.pop( + '_TEMP_reject_connection_types' + ) + for ptype, pdata in reject_types.get(node.type_, {}).items(): + for pname, reject_data in pdata.items(): + for reject_ntype, reject_ndata in reject_data.items(): + for reject_ptype, reject_pnames in reject_ndata.items(): + for reject_pname in reject_pnames: + self._model.add_port_reject_connection_type( + port_name=pname, + port_type=ptype, + node_type=node.type_, + reject_pname=reject_pname, + reject_ptype=reject_ptype, + reject_ntype=reject_ntype + ) + node._graph = self node.NODE_NAME = self.get_unique_name(node.NODE_NAME) node.model._graph_model = self.model @@ -1554,6 +1625,10 @@ def _serialize(self, nodes): serial_data['graph']['pipe_collision'] = self.pipe_collision() serial_data['graph']['pipe_slicing'] = self.pipe_slicing() + # connection constrains. + serial_data['graph']['accept_connection_types'] = self.model.accept_connection_types + serial_data['graph']['reject_connection_types'] = self.model.reject_connection_types + # serialize nodes. for n in nodes: # update the node model. @@ -1618,6 +1693,12 @@ def _deserialize(self, data, relative_pos=False, pos=None): elif attr_name == 'pipe_slicing': self.set_pipe_slicing(attr_value) + # connection constrains. + elif attr_name == 'accept_connection_types': + self.model.accept_connection_types = attr_value + elif attr_name == 'reject_connection_types': + self.model.reject_connection_types = attr_value + # build the nodes. nodes = {} for n_id, n_data in data.get('nodes', {}).items(): diff --git a/NodeGraphQt/base/model.py b/NodeGraphQt/base/model.py index a599cdab..c31b45d6 100644 --- a/NodeGraphQt/base/model.py +++ b/NodeGraphQt/base/model.py @@ -107,6 +107,11 @@ def __init__(self): 'outputs': NodePropWidgetEnum.HIDDEN.value, } + # temp store connection constrains. + # (deleted when node is added to the graph) + self._TEMP_accept_connection_types = {} + self._TEMP_reject_connection_types = {} + def __repr__(self): return '<{}(\'{}\') object at {}>'.format( self.__class__.__name__, self.name, self.id) @@ -223,6 +228,85 @@ def get_tab_name(self, name): return return model.get_node_common_properties(self.type_)[name]['tab'] + def add_port_accept_connection_type( + self, + port_name, port_type, node_type, + accept_pname, accept_ptype, accept_ntype + ): + """ + Convenience function for adding to the "accept_connection_types" dict. + If the node graph model is unavailable yet then we store it to a + temp var that gets deleted. + + Args: + port_name (str): current port name. + port_type (str): current port type. + node_type (str): current port node type. + accept_pname (str):port name to accept. + accept_ptype (str): port type accept. + accept_ntype (str):port node type to accept. + """ + model = self._graph_model + if model: + model.add_port_accept_connection_type( + port_name, port_type, node_type, + accept_pname, accept_ptype, accept_ntype + ) + return + + connection_data = self._TEMP_accept_connection_types + keys = [node_type, port_type, port_name, accept_ntype] + for key in keys: + if key not in connection_data.keys(): + connection_data[key] = {} + connection_data = connection_data[key] + + if accept_ptype not in connection_data: + connection_data[accept_ptype] = set([accept_pname]) + else: + connection_data[accept_ptype].add(accept_pname) + + def add_port_reject_connection_type( + self, + port_name, port_type, node_type, + reject_pname, reject_ptype, reject_ntype + ): + """ + Convenience function for adding to the "reject_connection_types" dict. + If the node graph model is unavailable yet then we store it to a + temp var that gets deleted. + + Args: + port_name (str): current port name. + port_type (str): current port type. + node_type (str): current port node type. + reject_pname: + reject_ptype: + reject_ntype: + + Returns: + + """ + model = self._graph_model + if model: + model.add_port_reject_connection_type( + port_name, port_type, node_type, + reject_pname, reject_ptype, reject_ntype + ) + return + + connection_data = self._TEMP_reject_connection_types + keys = [node_type, port_type, port_name, reject_ntype] + for key in keys: + if key not in connection_data.keys(): + connection_data[key] = {} + connection_data = connection_data[key] + + if reject_ptype not in connection_data: + connection_data[reject_ptype] = set([reject_pname]) + else: + connection_data[reject_ptype].add(reject_pname) + @property def properties(self): """ @@ -352,6 +436,9 @@ class NodeGraphModel(object): def __init__(self): self.__common_node_props = {} + self.accept_connection_types = {} + self.reject_connection_types = {} + self.nodes = {} self.session = '' self.acyclic = True @@ -421,6 +508,96 @@ def get_node_common_properties(self, node_type): """ return self.__common_node_props.get(node_type) + def add_port_accept_connection_type( + self, + port_name, port_type, node_type, + accept_pname, accept_ptype, accept_ntype + ): + """ + Convenience function for adding to the "accept_connection_types" dict. + + Args: + port_name (str): current port name. + port_type (str): current port type. + node_type (str): current port node type. + accept_pname (str):port name to accept. + accept_ptype (str): port type accept. + accept_ntype (str):port node type to accept. + """ + connection_data = self.accept_connection_types + keys = [node_type, port_type, port_name, accept_ntype] + for key in keys: + if key not in connection_data.keys(): + connection_data[key] = {} + connection_data = connection_data[key] + + if accept_ptype not in connection_data: + connection_data[accept_ptype] = set([accept_pname]) + else: + connection_data[accept_ptype].add(accept_pname) + + def port_accept_connection_types(self, node_type, port_type, port_name): + """ + Convenience function for getting the accepted port types from the + "accept_connection_types" dict. + + Args: + node_type (str): + port_type (str): + port_name (str): + + Returns: + dict: {: {: []}} + """ + data = self.accept_connection_types.get(node_type) or {} + accepted_types = data.get(port_type) or {} + return accepted_types.get(port_name) or {} + + def add_port_reject_connection_type( + self, + port_name, port_type, node_type, + reject_pname, reject_ptype, reject_ntype + ): + """ + Convenience function for adding to the "reject_connection_types" dict. + + Args: + port_name (str): current port name. + port_type (str): current port type. + node_type (str): current port node type. + reject_pname (str): port name to reject. + reject_ptype (str): port type to reject. + reject_ntype (str): port node type to reject. + """ + connection_data = self.reject_connection_types + keys = [node_type, port_type, port_name, reject_ntype] + for key in keys: + if key not in connection_data.keys(): + connection_data[key] = {} + connection_data = connection_data[key] + + if reject_ptype not in connection_data: + connection_data[reject_ptype] = set([reject_pname]) + else: + connection_data[reject_ptype].add(reject_pname) + + def port_reject_connection_types(self, node_type, port_type, port_name): + """ + Convenience function for getting the accepted port types from the + "reject_connection_types" dict. + + Args: + node_type (str): + port_type (str): + port_name (str): + + Returns: + dict: {: {: []}} + """ + data = self.reject_connection_types.get(node_type) or {} + rejected_types = data.get(port_type) or {} + return rejected_types.get(port_name) or {} + if __name__ == '__main__': p = PortModel(None) diff --git a/NodeGraphQt/base/port.py b/NodeGraphQt/base/port.py index f2b0798f..9ff38d36 100644 --- a/NodeGraphQt/base/port.py +++ b/NodeGraphQt/base/port.py @@ -221,6 +221,35 @@ def connect_to(self, port=None, push_undo=True): raise PortError( 'Can\'t connect port because "{}" is locked.'.format(name)) + # validate accept connection. + node_type = self.node().type_ + accepted_types = port.accepted_port_types().get(node_type) + if accepted_types: + accepted_pnames = accepted_types.get(self.type_()) or set([]) + if self.name() not in accepted_pnames: + return + node_type = port.node().type_ + accepted_types = self.accepted_port_types().get(node_type) + if accepted_types: + accepted_pnames = accepted_types.get(port.type_()) or set([]) + if port.name() not in accepted_pnames: + return + + # validate reject connection. + node_type = self.node().type_ + rejected_types = port.rejected_port_types().get(node_type) + if rejected_types: + rejected_pnames = rejected_types.get(self.type_()) or set([]) + if self.name() in rejected_pnames: + return + node_type = port.node().type_ + rejected_types = self.rejected_port_types().get(node_type) + if rejected_types: + rejected_pnames = rejected_types.get(port.type_()) or set([]) + if port.name() in rejected_pnames: + return + + # make the connection from here. graph = self.node().graph viewer = graph.viewer() @@ -349,6 +378,92 @@ def clear_connections(self, push_undo=True): for cp in self.connected_ports(): self.disconnect_from(cp, push_undo=False) + def add_accept_port_type(self, port_name, port_type, node_type): + """ + Add a constrain to "accept" a pipe connection. + + Once a constrain has been added only ports of that type specified will + be allowed a pipe connection. + + `Implemented in` ``v0.6.0`` + + See Also: + :meth:`NodeGraphQt.Port.add_reject_ports_type`, + :meth:`NodeGraphQt.BaseNode.add_accept_port_type` + + Args: + port_name (str): name of the port. + port_type (str): port type. + node_type (str): port node type. + """ + # storing the connection constrain at the graph level instead of the + # port level so we don't serialize the same data for every port + # instance. + self.node().add_accept_port_type( + port=self, + port_type_data={ + 'port_name': port_name, + 'port_type': port_type, + 'node_type': node_type, + } + ) + + def accepted_port_types(self): + """ + Returns a dictionary of connection constrains of the port types + that allow for a pipe connection to this node. + + See Also: + :meth:`NodeGraphQt.BaseNode.accepted_port_types` + + Returns: + dict: {: {: []}} + """ + return self.node().accepted_port_types(self) + + def add_reject_port_type(self, port_name, port_type, node_type): + """ + Add a constrain to "reject" a pipe connection. + + Once a constrain has been added only ports of that type specified will + be rejected a pipe connection. + + `Implemented in` ``v0.6.0`` + + See Also: + :meth:`NodeGraphQt.Port.add_accept_ports_type`, + :meth:`NodeGraphQt.BaseNode.add_reject_port_type` + + Args: + port_name (str): name of the port. + port_type (str): port type. + node_type (str): port node type. + """ + # storing the connection constrain at the graph level instead of the + # port level so we don't serialize the same data for every port + # instance. + self.node().add_reject_port_type( + port=self, + port_type_data={ + 'port_name': port_name, + 'port_type': port_type, + 'node_type': node_type, + } + ) + + def rejected_port_types(self): + """ + Returns a dictionary of connection constrains of the port types + that are NOT allowed for a pipe connection to this node. + + See Also: + :meth:`NodeGraphQt.BaseNode.rejected_port_types` + + Returns: + dict: {: {: []}} + """ + return self.node().rejected_port_types(self) + @property def color(self): return self.__view.color diff --git a/NodeGraphQt/nodes/base_node.py b/NodeGraphQt/nodes/base_node.py index e80fb89d..3894d589 100644 --- a/NodeGraphQt/nodes/base_node.py +++ b/NodeGraphQt/nodes/base_node.py @@ -657,6 +657,124 @@ def connected_output_nodes(self): nodes[p] = [cp.node() for cp in p.connected_ports()] return nodes + def add_accept_port_type(self, port, port_type_data): + """ + Add a accept constrain to a specified node port. + + Once a constrain has been added only ports of that type specified will + be allowed a pipe connection. + + port type data example + + .. highlight:: python + .. code-block:: python + { + 'port_name': 'foo' + 'port_type': PortTypeEnum.IN.value + 'node_type': 'io.github.jchanvfx.NodeClass' + } + + See Also: + :meth:`NodeGraphQt.BaseNode.accepted_port_types` + + Args: + port (NodeGraphQt.Port): port to assign constrain to. + port_type_data (dict): port type data to accept a connection + """ + node_ports = self._inputs + self._outputs + if port not in node_ports: + raise PortError('Node does not contain port: "{}"'.format(port)) + + self._model.add_port_accept_connection_type( + port_name=port.name(), + port_type=port.type_(), + node_type=self.type_, + accept_pname=port_type_data['port_name'], + accept_ptype=port_type_data['port_type'], + accept_ntype=port_type_data['node_type'] + ) + + def accepted_port_types(self, port): + """ + Returns a dictionary of connection constrains of the port types + that allow for a pipe connection to this node. + + Args: + port (NodeGraphQt.Port): port object. + + Returns: + dict: {: {: []}} + """ + ports = self._inputs + self._outputs + if port not in ports: + raise PortError('Node does not contain port "{}"'.format(port)) + + accepted_types = self.graph.model.port_accept_connection_types( + node_type=self.type_, + port_type=port.type_(), + port_name=port.name() + ) + return accepted_types + + def add_reject_port_type(self, port, port_type_data): + """ + Add a reject constrain to a specified node port. + + Once a constrain has been added only ports of that type specified will + NOT be allowed a pipe connection. + + port type data example + + .. highlight:: python + .. code-block:: python + { + 'port_name': 'foo' + 'port_type': PortTypeEnum.IN.value + 'node_type': 'io.github.jchanvfx.NodeClass' + } + + See Also: + :meth:`NodeGraphQt.Port.rejected_port_types` + + Args: + port (NodeGraphQt.Port): port to assign constrain to. + port_type_data (dict): port type data to reject a connection + """ + node_ports = self._inputs + self._outputs + if port not in node_ports: + raise PortError('Node does not contain port: "{}"'.format(port)) + + self._model.add_port_reject_connection_type( + port_name=port.name(), + port_type=port.type_(), + node_type=self.type_, + reject_pname=port_type_data['port_name'], + reject_ptype=port_type_data['port_type'], + reject_ntype=port_type_data['node_type'] + ) + + def rejected_port_types(self, port): + """ + Returns a dictionary of connection constrains of the port types + that are NOT allowed for a pipe connection to this node. + + Args: + port (NodeGraphQt.Port): port object. + + Returns: + dict: {: {: []}} + """ + ports = self._inputs + self._outputs + if port not in ports: + raise PortError('Node does not contain port "{}"'.format(port)) + + rejected_types = self.graph.model.port_reject_connection_types( + node_type=self.type_, + port_type=port.type_(), + port_name=port.name() + ) + return rejected_types + def on_input_connected(self, in_port, out_port): """ Callback triggered when a new pipe connection is made. diff --git a/NodeGraphQt/pkg_info.py b/NodeGraphQt/pkg_info.py index 3f5f4766..555fbffa 100644 --- a/NodeGraphQt/pkg_info.py +++ b/NodeGraphQt/pkg_info.py @@ -1,6 +1,6 @@ #!/usr/bin/python # -*- coding: utf-8 -*- -__version__ = '0.5.12' +__version__ = '0.6.0' __status__ = 'Work in Progress' __license__ = 'MIT' diff --git a/NodeGraphQt/qgraphics/node_base.py b/NodeGraphQt/qgraphics/node_base.py index 9886f67f..15b91389 100644 --- a/NodeGraphQt/qgraphics/node_base.py +++ b/NodeGraphQt/qgraphics/node_base.py @@ -761,15 +761,23 @@ def set_proxy_mode(self, mode): for w in self._widgets.values(): w.widget().setVisible(visible) + # port text is not visible in vertical layout. + if self.layout_direction is LayoutDirectionEnum.VERTICAL.value: + port_text_visible = False + else: + port_text_visible = visible + # input port text visibility. for port, text in self._input_items.items(): if port.display_name: - text.setVisible(visible) + text.setVisible(port_text_visible) # output port text visibility. for port, text in self._output_items.items(): if port.display_name: - text.setVisible(visible) + text.setVisible(port_text_visible) + + self._text_item.setVisible(visible) self._icon_item.setVisible(visible) diff --git a/NodeGraphQt/qgraphics/pipe.py b/NodeGraphQt/qgraphics/pipe.py index f9edcabc..a91462a2 100644 --- a/NodeGraphQt/qgraphics/pipe.py +++ b/NodeGraphQt/qgraphics/pipe.py @@ -583,8 +583,7 @@ def hoverEnterEvent(self, event): """ QtWidgets.QGraphicsPathItem.hoverEnterEvent(self, event) - def draw_path(self, start_port, end_port=None, cursor_pos=None, - color_mode=None): + def draw_path(self, start_port, end_port=None, cursor_pos=None, color=None): """ re-implemented to also update the index pointer arrow position. @@ -593,13 +592,12 @@ def draw_path(self, start_port, end_port=None, cursor_pos=None, end_port (PortItem): port used to draw the end point. cursor_pos (QtCore.QPointF): cursor position if specified this will be the draw end point. - color_mode (str): arrow index pointer color mode - ('accept', 'reject' or None). + color (list[int]): override arrow index pointer color. (r, g, b) """ super(LivePipeItem, self).draw_path(start_port, end_port, cursor_pos) - self.draw_index_pointer(start_port, cursor_pos, color_mode) + self.draw_index_pointer(start_port, cursor_pos, color) - def draw_index_pointer(self, start_port, cursor_pos, color_mode=None): + def draw_index_pointer(self, start_port, cursor_pos, color=None): """ Update the index pointer arrow position and direction when the live pipe path is redrawn. @@ -607,8 +605,7 @@ def draw_index_pointer(self, start_port, cursor_pos, color_mode=None): Args: start_port (PortItem): start port item. cursor_pos (QtCore.QPoint): cursor scene position. - color_mode (str): arrow index pointer color mode - ('accept', 'reject' or None). + color (list[int]): override arrow index pointer color. (r, g, b). """ text_rect = self._idx_text.boundingRect() @@ -635,16 +632,13 @@ def draw_index_pointer(self, start_port, cursor_pos, color_mode=None): self._idx_pointer.setPolygon(transform.map(self._poly)) - if color_mode == 'accept': - color = QtGui.QColor(*PipeEnum.HIGHLIGHT_COLOR.value) - elif color_mode == 'reject': - color = QtGui.QColor(*PipeEnum.DISABLED_COLOR.value) - else: - color = QtGui.QColor(*PipeEnum.ACTIVE_COLOR.value) + pen_color = QtGui.QColor(*PipeEnum.HIGHLIGHT_COLOR.value) + if isinstance(color, (list, tuple)): + pen_color = QtGui.QColor(*color) pen = self._idx_pointer.pen() - pen.setColor(color) - self._idx_pointer.setBrush(color.darker(300)) + pen.setColor(pen_color) + self._idx_pointer.setBrush(pen_color.darker(300)) self._idx_pointer.setPen(pen) diff --git a/NodeGraphQt/widgets/viewer.py b/NodeGraphQt/widgets/viewer.py index 1b09018a..84c62f33 100644 --- a/NodeGraphQt/widgets/viewer.py +++ b/NodeGraphQt/widgets/viewer.py @@ -9,6 +9,7 @@ from NodeGraphQt.constants import ( LayoutDirectionEnum, PortTypeEnum, + PipeEnum, PipeLayoutEnum, ViewerEnum, Z_VAL_PIPE, @@ -88,6 +89,7 @@ def __init__(self, parent=None, undo_stack=None): self._prev_selection_nodes = [] self._prev_selection_pipes = [] self._node_positions = {} + self._rubber_band = QtWidgets.QRubberBand( QtWidgets.QRubberBand.Rectangle, self ) @@ -150,6 +152,11 @@ def __init__(self, parent=None, undo_stack=None): self.SHIFT_state = False self.COLLIDING_state = False + # connection constrains. + # TODO: maybe this should be a reference to the graph model instead? + self.accept_connection_types = None + self.reject_connection_types = None + def __repr__(self): return '<{}() object at {}>'.format( self.__class__.__name__, hex(id(self))) @@ -750,26 +757,37 @@ def sceneMouseMoveEvent(self, event): return pos = event.scenePos() - color_mode = None + pointer_color = None for item in self.scene().items(pos): - if isinstance(item, PortItem): - x = item.boundingRect().width() / 2 - y = item.boundingRect().height() / 2 - pos = item.scenePos() - pos.setX(pos.x() + x) - pos.setY(pos.y() + y) - if item == self._start_port: - break - color_mode = 'accept' - if self.acyclic: - if item.node == self._start_port.node: - color_mode = 'reject' - elif item.port_type == self._start_port.port_type: - color_mode = 'reject' + if not isinstance(item, PortItem): + continue + + x = item.boundingRect().width() / 2 + y = item.boundingRect().height() / 2 + pos = item.scenePos() + pos.setX(pos.x() + x) + pos.setY(pos.y() + y) + if item == self._start_port: + break + pointer_color = PipeEnum.HIGHLIGHT_COLOR.value + accept = self._validate_accept_connection(self._start_port, item) + if not accept: + pointer_color = [150, 60, 255] break + reject = self._validate_reject_connection(self._start_port, item) + if reject: + pointer_color = [150, 60, 255] + break + + if self.acyclic: + if item.node == self._start_port.node: + pointer_color = PipeEnum.DISABLED_COLOR.value + elif item.port_type == self._start_port.port_type: + pointer_color = PipeEnum.DISABLED_COLOR.value + break self._LIVE_PIPE.draw_path( - self._start_port, cursor_pos=pos, color_mode=color_mode + self._start_port, cursor_pos=pos, color=pointer_color ) def sceneMousePressEvent(self, event): @@ -873,6 +891,78 @@ def sceneMouseReleaseEvent(self, event): # --- port connections --- + def _validate_accept_connection(self, from_port, to_port): + """ + Check if a pipe connection is allowed if there are a constrains set + on the ports. + + Args: + from_port (PortItem): + to_port (PortItem): + + Returns: + bool: true to allow connection. + """ + to_ptype = to_port.port_type + from_ptype = from_port.port_type + + to_data = self.accept_connection_types.get(to_port.node.type_) or {} + constraints = to_data.get(to_ptype, {}).get(to_port.name, {}) + accept_data = constraints.get(from_port.node.type_, {}) + + accepted_pnames = accept_data.get(from_ptype) + if accepted_pnames: + if from_port.name in accepted_pnames: + return True + return False + + from_data = self.accept_connection_types.get(from_port.node.type_) or {} + constraints = from_data.get(from_ptype, {}).get(from_port.name, {}) + accept_data = constraints.get(to_port.node.type_, {}) + + accepted_pnames = accept_data.get(to_ptype) + if accepted_pnames: + if from_port.name in accepted_pnames: + return True + return False + return True + + def _validate_reject_connection(self, from_port, to_port): + """ + Check if a pipe connection is NOT allowed if there are a constrains set + on the ports. + + Args: + from_port (PortItem): + to_port (PortItem): + + Returns: + bool: true to reject connection. + """ + to_ptype = to_port.port_type + from_ptype = from_port.port_type + + to_data = self.reject_connection_types.get(to_port.node.type_) or {} + constraints = to_data.get(to_ptype, {}).get(to_port.name, {}) + reject_data = constraints.get(from_port.node.type_, {}) + + rejected_pnames = reject_data.get(from_ptype) + if rejected_pnames: + if from_port.name in rejected_pnames: + return True + return False + + from_data = self.reject_connection_types.get(from_port.node.type_) or {} + constraints = from_data.get(from_ptype, {}).get(from_port.name, {}) + reject_data = constraints.get(to_port.node.type_, {}) + + rejected_pnames = reject_data.get(to_ptype) + if rejected_pnames: + if to_port.name in rejected_pnames: + return True + return False + return False + def apply_live_connection(self, event): """ triggered mouse press/release event for the scene. @@ -926,6 +1016,14 @@ def apply_live_connection(self, event): # allow a node cycle connection. same_node_connection = False + # constrain check + accept_connection = self._validate_accept_connection( + self._start_port, end_port + ) + reject_connection = self._validate_reject_connection( + self._start_port, end_port + ) + # restore connection check. restore_connection = any([ # if the end port is locked. @@ -937,7 +1035,11 @@ def apply_live_connection(self, event): # if end port is the start port. end_port == self._start_port, # if detached port is the end port. - self._detached_port == end_port + self._detached_port == end_port, + # if a port has a accept port type constrain. + not accept_connection, + # if a port has a reject port type constrain. + reject_connection ]) if restore_connection: if self._detached_port: diff --git a/docs/examples/ex_port.rst b/docs/examples/ex_port.rst index 966a5f42..958c8eaf 100644 --- a/docs/examples/ex_port.rst +++ b/docs/examples/ex_port.rst @@ -141,3 +141,68 @@ And here's another example function for drawing a Square port. painter.drawRect(rect) painter.restore() + + +Connection Constrains +********************* + +From version ``v0.6.0`` port object can now have pipe connection constraints the functions implemented are: + +:meth:`NodeGraphQt.Port.add_accept_ports_type` and :meth:`NodeGraphQt.Port.add_reject_ports_type` + +this can also be set on the ``BaseNode`` level as well with: +:meth:`NodeGraphQt.BaseNode.add_accept_port_type`, :meth:`NodeGraphQt.BaseNode.add_accept_port_type` + +Here's an example snippet to add pipe connection constraints to a port. + +.. code-block:: python + :linenos: + + from NodeGraphQt import BaseNode + from NodeGraphQt.constants import PortTypeEnum + + + class BasicNodeA(BaseNode): + + # unique node identifier. + __identifier__ = 'io.github.jchanvfx' + + # initial default node name. + NODE_NAME = 'node A' + + def __init__(self): + super(BasicNode, self).__init__() + + # create node output ports. + self.add_output('output 1') + self.add_output('output 2') + + + class BasicNodeB(BaseNode): + + # unique node identifier. + __identifier__ = 'io.github.jchanvfx' + + # initial default node name. + NODE_NAME = 'node B' + + def __init__(self): + super(BasicNode, self).__init__() + + # create node inputs. + + # port "in A" will only accept pipe connections from port "output 1" under the node "BasicNodeA". + in_port_a = self.add_input('in A') + in_port_a.add_accept_port_type( + port_name='output 1', + port_type=PortTypeEnum.OUT.value, + node_type='io.github.jchanvfx.BasicNodeA' + ) + + # port "in A" will reject pipe connections from port "output 1" under the node "BasicNodeA". + in_port_b = self.add_input('in B') + in_port_b.add_reject_port_type( + port_name='output 1', + port_type=PortTypeEnum.OUT.value, + node_type='io.github.jchanvfx.BasicNodeA' + )