Skip to content

Commit

Permalink
fix: unsqueeze axes became an input in onnx op set v13 (#93)
Browse files Browse the repository at this point in the history
  • Loading branch information
dlshriver committed Jul 6, 2022
1 parent c9c5899 commit 7d6d643
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions dnnv/nn/operations/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,12 +247,19 @@ def __init__(self, x, axes, *, name: Optional[str] = None):
@classmethod
def from_onnx(cls, onnx_node, *inputs):
attributes = {a.name: as_numpy(a) for a in onnx_node.attribute}
axes = attributes.get("axes")
if isinstance(inputs[0], np.ndarray):
a = inputs[0]
for axis in axes:
a = np.expand_dims(a, axis)
return a
if len(inputs) == 2:
axes = inputs[1]
else:
axes = attributes["axes"]
if not isinstance(axes, Operation):
a = inputs[0]
for axis in axes:
a = np.expand_dims(a, axis)
return a
if len(inputs) == 2:
return cls(*inputs, name=onnx_node.name)
axes = attributes["axes"]
return cls(*inputs, axes=axes, name=onnx_node.name)


Expand Down

0 comments on commit 7d6d643

Please sign in to comment.