Skip to content

Commit

Permalink
[Fixbug] Binary arthmatic ops raise error when one is scalar on GPU (#…
Browse files Browse the repository at this point in the history
…109)

fix a bug in binary arthmatic ops
  • Loading branch information
yaoyaoding committed Feb 17, 2023
1 parent 7ec22cc commit 7820672
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions python/hidet/graph/ops/definitions/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,15 +475,15 @@ def binary_arithmetic(
elif isinstance(x, float):
x = dtypes.float32(x)
elif isinstance(x, Tensor) and len(x.shape) == 0:
if x.trace is None:
if x.trace is None and x.storage is not None:
x = x.dtype(x.item())

if isinstance(y, int):
y = dtypes.int32(y)
elif isinstance(y, float):
y = dtypes.float32(y)
elif isinstance(y, Tensor) and len(y.shape) == 0:
if y.trace is None:
if y.trace is None and y.storage is not None:
y = y.dtype(y.item())

if isinstance(x, Tensor) and isinstance(y, Tensor):
Expand Down

0 comments on commit 7820672

Please sign in to comment.