Skip to content

Commit

Permalink
[Fixbug] Fix a bug in the minimum and maximum operator (#102)
Browse files Browse the repository at this point in the history
* fix a typo in min/max; change signature

* add test
  • Loading branch information
yaoyaoding committed Feb 13, 2023
1 parent 24c1260 commit 33cfcbb
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
10 changes: 5 additions & 5 deletions python/hidet/graph/ops/definitions/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ def __init__(self, cond: Tensor, x: Tensor, y: Tensor):


class MaxOp(Operator):
def __init__(self, tensors: List[Tensor]):
def __init__(self, *tensors: Tensor):
def scalar_max(args: List[expr.Expr]):
if len(args) == 1:
return args[0]
Expand All @@ -434,12 +434,12 @@ def scalar_max(args: List[expr.Expr]):


class MinOp(Operator):
def __init__(self, tensors: List[Tensor]):
def __init__(self, *tensors: Tensor):
def scalar_min(args: List[expr.Expr]):
if len(args) == 1:
return args[0]
else:
return primitives.max(args[0], scalar_min(args[1:]))
return primitives.min(args[0], scalar_min(args[1:]))

super().__init__(
inputs=list(tensors),
Expand Down Expand Up @@ -671,12 +671,12 @@ def where(cond: Tensor, x: Tensor, y: Tensor) -> Tensor:

def maximum(a: Tensor, b: Tensor, *others: Tensor) -> Tensor:
args = [a, b] + list(others)
return MaxOp(args).get_output(0)
return MaxOp(*args).get_output(0)


def minimum(a: Tensor, b: Tensor, *others: Tensor) -> Tensor:
args = [a, b] + list(others)
return MinOp(args).get_output(0)
return MinOp(*args).get_output(0)


def mod(x: Tensor, y: Tensor) -> Tensor:
Expand Down
10 changes: 10 additions & 0 deletions tests/graph/operators/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,16 @@ def test_bitwise_xor(a_shape, b_shape):
check_binary(a_shape, b_shape, np.int32, np.bitwise_xor, ops.bitwise_xor)


@pytest.mark.parametrize("a_shape, b_shape", binary_op_shapes)
def test_minimum(a_shape, b_shape):
check_binary(a_shape, b_shape, np.int32, np.minimum, ops.minimum)


@pytest.mark.parametrize("a_shape, b_shape", binary_op_shapes)
def test_maximum(a_shape, b_shape):
check_binary(a_shape, b_shape, np.int32, np.maximum, ops.maximum)


@pytest.mark.parametrize("a_shape", unary_op_shapes)
def test_ceil(a_shape):
check_unary(a_shape, np.float32, np.ceil, ops.ceil)
Expand Down

0 comments on commit 33cfcbb

Please sign in to comment.