Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions NodeGraphQt/base/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,11 @@ def __init__(self, parent=None, **kwargs):
kwargs.get('viewer') or NodeViewer(undo_stack=self._undo_stack))
self._viewer.set_layout_direction(layout_direction)

# viewer needs a reference to the model port connection constrains
# for the user interaction with the live pipe.
self._viewer.accept_connection_types = self._model.accept_connection_types
self._viewer.reject_connection_types = self._model.reject_connection_types

self._context_menu = {}

self._register_context_menu()
Expand Down Expand Up @@ -1143,6 +1148,39 @@ def create_node(self, node_type, name=None, selected=True, color=None,
node_attrs[node.type_][pname].update(pattrs)
self.model.set_node_common_properties(node_attrs)

accept_types = node.model.__dict__.pop(
'_TEMP_accept_connection_types'
)
for ptype, pdata in accept_types.get(node.type_, {}).items():
for pname, accept_data in pdata.items():
for accept_ntype, accept_ndata in accept_data.items():
for accept_ptype, accept_pnames in accept_ndata.items():
for accept_pname in accept_pnames:
self._model.add_port_accept_connection_type(
port_name=pname,
port_type=ptype,
node_type=node.type_,
accept_pname=accept_pname,
accept_ptype=accept_ptype,
accept_ntype=accept_ntype
)
reject_types = node.model.__dict__.pop(
'_TEMP_reject_connection_types'
)
for ptype, pdata in reject_types.get(node.type_, {}).items():
for pname, reject_data in pdata.items():
for reject_ntype, reject_ndata in reject_data.items():
for reject_ptype, reject_pnames in reject_ndata.items():
for reject_pname in reject_pnames:
self._model.add_port_reject_connection_type(
port_name=pname,
port_type=ptype,
node_type=node.type_,
reject_pname=reject_pname,
reject_ptype=reject_ptype,
reject_ntype=reject_ntype
)

node.NODE_NAME = self.get_unique_name(name or node.NODE_NAME)
node.model.name = node.NODE_NAME
node.model.selected = selected
Expand Down Expand Up @@ -1207,6 +1245,39 @@ def add_node(self, node, pos=None, selected=True, push_undo=True):
node_attrs[node.type_][pname].update(pattrs)
self.model.set_node_common_properties(node_attrs)

accept_types = node.model.__dict__.pop(
'_TEMP_accept_connection_types'
)
for ptype, pdata in accept_types.get(node.type_, {}).items():
for pname, accept_data in pdata.items():
for accept_ntype, accept_ndata in accept_data.items():
for accept_ptype, accept_pnames in accept_ndata.items():
for accept_pname in accept_pnames:
self._model.add_port_accept_connection_type(
port_name=pname,
port_type=ptype,
node_type=node.type_,
accept_pname=accept_pname,
accept_ptype=accept_ptype,
accept_ntype=accept_ntype
)
reject_types = node.model.__dict__.pop(
'_TEMP_reject_connection_types'
)
for ptype, pdata in reject_types.get(node.type_, {}).items():
for pname, reject_data in pdata.items():
for reject_ntype, reject_ndata in reject_data.items():
for reject_ptype, reject_pnames in reject_ndata.items():
for reject_pname in reject_pnames:
self._model.add_port_reject_connection_type(
port_name=pname,
port_type=ptype,
node_type=node.type_,
reject_pname=reject_pname,
reject_ptype=reject_ptype,
reject_ntype=reject_ntype
)

node._graph = self
node.NODE_NAME = self.get_unique_name(node.NODE_NAME)
node.model._graph_model = self.model
Expand Down Expand Up @@ -1554,6 +1625,10 @@ def _serialize(self, nodes):
serial_data['graph']['pipe_collision'] = self.pipe_collision()
serial_data['graph']['pipe_slicing'] = self.pipe_slicing()

# connection constrains.
serial_data['graph']['accept_connection_types'] = self.model.accept_connection_types
serial_data['graph']['reject_connection_types'] = self.model.reject_connection_types

# serialize nodes.
for n in nodes:
# update the node model.
Expand Down Expand Up @@ -1618,6 +1693,12 @@ def _deserialize(self, data, relative_pos=False, pos=None):
elif attr_name == 'pipe_slicing':
self.set_pipe_slicing(attr_value)

# connection constrains.
elif attr_name == 'accept_connection_types':
self.model.accept_connection_types = attr_value
elif attr_name == 'reject_connection_types':
self.model.reject_connection_types = attr_value

# build the nodes.
nodes = {}
for n_id, n_data in data.get('nodes', {}).items():
Expand Down
177 changes: 177 additions & 0 deletions NodeGraphQt/base/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ def __init__(self):
'outputs': NodePropWidgetEnum.HIDDEN.value,
}

# temp store connection constrains.
# (deleted when node is added to the graph)
self._TEMP_accept_connection_types = {}
self._TEMP_reject_connection_types = {}

def __repr__(self):
return '<{}(\'{}\') object at {}>'.format(
self.__class__.__name__, self.name, self.id)
Expand Down Expand Up @@ -223,6 +228,85 @@ def get_tab_name(self, name):
return
return model.get_node_common_properties(self.type_)[name]['tab']

def add_port_accept_connection_type(
self,
port_name, port_type, node_type,
accept_pname, accept_ptype, accept_ntype
):
"""
Convenience function for adding to the "accept_connection_types" dict.
If the node graph model is unavailable yet then we store it to a
temp var that gets deleted.

Args:
port_name (str): current port name.
port_type (str): current port type.
node_type (str): current port node type.
accept_pname (str):port name to accept.
accept_ptype (str): port type accept.
accept_ntype (str):port node type to accept.
"""
model = self._graph_model
if model:
model.add_port_accept_connection_type(
port_name, port_type, node_type,
accept_pname, accept_ptype, accept_ntype
)
return

connection_data = self._TEMP_accept_connection_types
keys = [node_type, port_type, port_name, accept_ntype]
for key in keys:
if key not in connection_data.keys():
connection_data[key] = {}
connection_data = connection_data[key]

if accept_ptype not in connection_data:
connection_data[accept_ptype] = set([accept_pname])
else:
connection_data[accept_ptype].add(accept_pname)

def add_port_reject_connection_type(
self,
port_name, port_type, node_type,
reject_pname, reject_ptype, reject_ntype
):
"""
Convenience function for adding to the "reject_connection_types" dict.
If the node graph model is unavailable yet then we store it to a
temp var that gets deleted.

Args:
port_name (str): current port name.
port_type (str): current port type.
node_type (str): current port node type.
reject_pname:
reject_ptype:
reject_ntype:

Returns:

"""
model = self._graph_model
if model:
model.add_port_reject_connection_type(
port_name, port_type, node_type,
reject_pname, reject_ptype, reject_ntype
)
return

connection_data = self._TEMP_reject_connection_types
keys = [node_type, port_type, port_name, reject_ntype]
for key in keys:
if key not in connection_data.keys():
connection_data[key] = {}
connection_data = connection_data[key]

if reject_ptype not in connection_data:
connection_data[reject_ptype] = set([reject_pname])
else:
connection_data[reject_ptype].add(reject_pname)

@property
def properties(self):
"""
Expand Down Expand Up @@ -352,6 +436,9 @@ class NodeGraphModel(object):
def __init__(self):
self.__common_node_props = {}

self.accept_connection_types = {}
self.reject_connection_types = {}

self.nodes = {}
self.session = ''
self.acyclic = True
Expand Down Expand Up @@ -421,6 +508,96 @@ def get_node_common_properties(self, node_type):
"""
return self.__common_node_props.get(node_type)

def add_port_accept_connection_type(
self,
port_name, port_type, node_type,
accept_pname, accept_ptype, accept_ntype
):
"""
Convenience function for adding to the "accept_connection_types" dict.

Args:
port_name (str): current port name.
port_type (str): current port type.
node_type (str): current port node type.
accept_pname (str):port name to accept.
accept_ptype (str): port type accept.
accept_ntype (str):port node type to accept.
"""
connection_data = self.accept_connection_types
keys = [node_type, port_type, port_name, accept_ntype]
for key in keys:
if key not in connection_data.keys():
connection_data[key] = {}
connection_data = connection_data[key]

if accept_ptype not in connection_data:
connection_data[accept_ptype] = set([accept_pname])
else:
connection_data[accept_ptype].add(accept_pname)

def port_accept_connection_types(self, node_type, port_type, port_name):
"""
Convenience function for getting the accepted port types from the
"accept_connection_types" dict.

Args:
node_type (str):
port_type (str):
port_name (str):

Returns:
dict: {<node_type>: {<port_type>: [<port_name>]}}
"""
data = self.accept_connection_types.get(node_type) or {}
accepted_types = data.get(port_type) or {}
return accepted_types.get(port_name) or {}

def add_port_reject_connection_type(
self,
port_name, port_type, node_type,
reject_pname, reject_ptype, reject_ntype
):
"""
Convenience function for adding to the "reject_connection_types" dict.

Args:
port_name (str): current port name.
port_type (str): current port type.
node_type (str): current port node type.
reject_pname (str): port name to reject.
reject_ptype (str): port type to reject.
reject_ntype (str): port node type to reject.
"""
connection_data = self.reject_connection_types
keys = [node_type, port_type, port_name, reject_ntype]
for key in keys:
if key not in connection_data.keys():
connection_data[key] = {}
connection_data = connection_data[key]

if reject_ptype not in connection_data:
connection_data[reject_ptype] = set([reject_pname])
else:
connection_data[reject_ptype].add(reject_pname)

def port_reject_connection_types(self, node_type, port_type, port_name):
"""
Convenience function for getting the accepted port types from the
"reject_connection_types" dict.

Args:
node_type (str):
port_type (str):
port_name (str):

Returns:
dict: {<node_type>: {<port_type>: [<port_name>]}}
"""
data = self.reject_connection_types.get(node_type) or {}
rejected_types = data.get(port_type) or {}
return rejected_types.get(port_name) or {}


if __name__ == '__main__':
p = PortModel(None)
Expand Down
Loading