Skip to content

Commit

Permalink
TN interface pack/unpack: allow raw arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Aug 31, 2023
1 parent 1643140 commit 965df75
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions quimb/tensor/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,15 @@ def pack(obj):
skeleton : Tensor, TensorNetwork, or similar
A copy of ``obj`` with all references to the original data removed.
"""
skeleton = obj.copy()
params = skeleton.get_params()
placeholders = tree_map(Placeholder, params)
skeleton.set_params(placeholders)
try:
skeleton = obj.copy()
params = skeleton.get_params()
placeholders = tree_map(Placeholder, params)
skeleton.set_params(placeholders)
except AttributeError:
# assume it's a raw array
params = obj
skeleton = Placeholder(obj)
return params, skeleton


Expand All @@ -62,8 +67,12 @@ def unpack(params, skeleton):
obj : Tensor, TensorNetwork, or similar
A copy of ``skeleton`` with parameters inserted.
"""
obj = skeleton.copy()
obj.set_params(params)
try:
obj = skeleton.copy()
obj.set_params(params)
except AttributeError:
# assume it's a raw array
obj = params
return obj


Expand Down

0 comments on commit 965df75

Please sign in to comment.