-
Notifications
You must be signed in to change notification settings - Fork 352
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
JSON serialization / deserialization functions for Nodes and Edges #824
Conversation
jacksonwb
commented
Sep 11, 2020
- Add dict (de)serialization helper functions to Node/edge classes
- Add np array (de)serialization methods to np backend
- Update backend tests
- Add (de)serialization tests
- Format changes (sneaky yapf bomb)
and Edges * Add dict (de)serialization helper functions to Node/edge classes * Add np array (de)serialization methods to np backend * Update backend tests * Add (de)serialization tests
@@ -11,6 +11,7 @@ | |||
_remove_trace_edge, _remove_edges) | |||
import tensornetwork as tn | |||
from tensornetwork.backends.abstract_backend import AbstractBackend | |||
from typing import Dict |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note - Relevant Change
def to_serial_dict(self) -> Dict: | ||
return {} | ||
|
||
@classmethod | ||
def from_serial_dict(cls, serial_dict) -> "TestNode": | ||
return cls() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note - Relevant Change
@@ -16,12 +16,15 @@ | |||
from typing import Union | |||
from tensornetwork.backends import abstract_backend | |||
from tensornetwork.backends.numpy import decompositions | |||
import io |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note - Relevant Change
@Thenerdstation |
Codecov Report
@@ Coverage Diff @@
## master #824 +/- ##
==========================================
- Coverage 98.41% 98.39% -0.03%
==========================================
Files 128 129 +1
Lines 21639 21856 +217
==========================================
+ Hits 21297 21506 +209
- Misses 342 350 +8
Continue to review full report at Codecov.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some small things.
* Remove CopyNode to_serial from_serial method definitions * Add mock remote contraction test * Minor changes
functions * Added binding tests
Create a JSON string representing the Tensor Network made up of the given | ||
nodes. Nodes and their attributes, edges and their attributes and tensor | ||
values are included. | ||
|
||
Tensors are serialized according the the format used by each tensors backend. | ||
|
||
For edges spanning included nodes and excluded nodes the edge attributes are | ||
preserved in the serialization but the connection to the excluded node is | ||
dropped. The original edge is not modified. | ||
|
||
Args: | ||
nodes: A list of nodes making up a tensor network. | ||
edge_binding: A dictionary containing {str->edge} bindings. Edges that are | ||
not included in the serialized network are ommited from the dictionary. | ||
|
||
Returns: | ||
A string representing the JSON serialized tensor network. | ||
|
||
Raises: | ||
TypeError: If an edge_binding dict is passed with non string keys, or non | ||
Edge values. | ||
""" | ||
network_dict = { | ||
'nodes': [], | ||
'edges': [], | ||
} | ||
node_id_dict = {} | ||
edge_id_dict = {} | ||
|
||
# Build serialized Nodes | ||
for i, node in enumerate(nodes): | ||
node_id_dict[node] = i | ||
network_dict['nodes'].append({ | ||
'id': i, | ||
'attributes': node.to_serial_dict(), | ||
}) | ||
edges = get_all_edges(nodes) | ||
|
||
# Build serialized edges | ||
for i, edge in enumerate(edges): | ||
edge_id_dict[edge] = i | ||
node_ids = [node_id_dict.get(n) for n in edge.get_nodes()] | ||
attributes = edge.to_serial_dict() | ||
attributes['axes'] = [ | ||
a if node_ids[j] is not None else None | ||
for j, a in enumerate(attributes['axes']) | ||
] | ||
edge_dict = { | ||
'id': i, | ||
'node_ids': node_ids, | ||
'attributes': attributes, | ||
} | ||
network_dict['edges'].append(edge_dict) | ||
|
||
serial_edge_binding = _build_serial_binding(edge_binding, edge_id_dict) | ||
if serial_edge_binding: | ||
network_dict['edge_binding'] = serial_edge_binding | ||
return json.dumps(network_dict) | ||
|
||
|
||
def nodes_from_json(json_str: str) -> Tuple[List[AbstractNode], | ||
Dict[str, Tuple[Edge]]]: | ||
""" | ||
Create a tensor network from a JSON string representation of a tensor network. | ||
|
||
Args: | ||
json_str: A string representing a JSON serialized tensor network. | ||
|
||
Returns: | ||
A list of nodes making up the tensor network. | ||
A dictionary of {str -> (edge,)} bindings. All dictionary values are tuples | ||
of Edges. | ||
|
||
""" | ||
network_dict = json.loads(json_str) | ||
nodes = [] | ||
node_ids = {} | ||
edge_lookup = {} | ||
edge_binding = {} | ||
for n in network_dict['nodes']: | ||
node = Node.from_serial_dict(n['attributes']) | ||
nodes.append(node) | ||
node_ids[n['id']] = node | ||
for e in network_dict['edges']: | ||
e_nodes = [node_ids.get(n_id) for n_id in e['node_ids']] | ||
axes = e['attributes']['axes'] | ||
edge = Edge(node1=e_nodes[0], | ||
axis1=axes[0], | ||
node2=e_nodes[1], | ||
axis2=axes[1], | ||
name=e['attributes']['name']) | ||
edge_lookup[e['id']] = edge | ||
for node, axis in zip(e_nodes, axes): | ||
if node is not None: | ||
node.add_edge(edge, axis, override=True) | ||
for k, v in network_dict.get('edge_binding', {}).items(): | ||
for e_id in v: | ||
edge_binding[k] = edge_binding.get(k, ()) + (edge_lookup[e_id],) | ||
|
||
return nodes, edge_binding |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added binding_dict support. @Thenerdstation take a look.