[StableHlo] Fix AtenWhereSelfOp convert rule (#2093)
* Fix the AtenWhereSelfOp conversion rule to promote operand types

* Use an int scalar in the new tests to exercise type promotion

* Mark the new tests as expected failures for the dynamo config

---------

Co-authored-by: zhekun.zhang <zhekun.zhang@bytedance.com>
zhekunz2 committed May 5, 2023
1 parent eaaaeb6 commit fc62b8e
Showing 3 changed files with 55 additions and 0 deletions.
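The conversion has to mirror PyTorch's own semantics for aten.where.self when one operand is a scalar: the scalar is promoted to the result dtype and the tensor operands are broadcast against the condition. A minimal standalone illustration of that behavior (an assumed example for context, not part of the change):

import torch

cond = torch.rand(3, 4, 5) > 0.5
other = torch.rand(4, 5, dtype=torch.float64)

# The int scalar 8 is promoted to float64 and the [4, 5] operand is
# broadcast against the [3, 4, 5] condition.
out = torch.where(cond, other, 8)
assert out.dtype == torch.float64
assert out.shape == (3, 4, 5)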
4 changes: 4 additions & 0 deletions e2e_testing/xfail_sets.py
@@ -54,6 +54,8 @@
"ElementwiseWhereScalarModule_basic",
"ElementwiseWhereScalarOtherModule_basic",
"ElementwiseWhereScalarSelfModule_basic",
"ElementwiseWhereScalarOtherStaticModule_basic",
"ElementwiseWhereScalarSelfStaticModule_basic",
# %7 = torch.operator "aten._index_put_impl_.hacked_twin"(%1, %6, %5, %true, %false) : (!torch.tensor<*,f32>, !torch.list<tensor>, !torch.tensor<*,f32>, !torch.bool, !torch.bool) -> !torch.tensor
"IndexPutImpl1DFloatAccumulateModule_basic",
"IndexPutImpl1DFloatNonAccumulateModule_basic",
@@ -267,6 +269,8 @@
"ElementwiseAtenLogicalOrOpPromoteBroadcastStaticShapeModule_basic",
"ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic",
"ElementwiseAtenWhereSelfModule_basic",
"ElementwiseWhereScalarOtherStaticModule_basic",
"ElementwiseWhereScalarSelfStaticModule_basic",
"ElementwiseBitwiseAndStaticShapeModule_basic",
"ElementwiseBitwiseNotInt64Module_basic",
"ElementwiseBitwiseNotInt32Module_basic",
6 changes: 6 additions & 0 deletions lib/Conversion/TorchToStablehlo/Basic.cpp
@@ -606,6 +606,12 @@ LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
  Value cond = adaptor.getCondition();
  Value other = adaptor.getOther();

  auto outType =
      getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
  // promote self and other types
  self = hlo::promoteType(rewriter, self, outType);
  other = hlo::promoteType(rewriter, other, outType);

  if (failed(
          broadcastRanks(rewriter, op, self, cond, options.dimSizeIndexBits)))
    return op.emitError("failed broadcast self and condition ranks");
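To see the promoted lowering end to end, a module like the new tests can be compiled to StableHLO. A rough sketch, assuming the torch_mlir.compile entry point and its "stablehlo" output type as they existed around this commit (the exact API may differ):

import torch
import torch_mlir

class WhereScalarOther(torch.nn.Module):
    def forward(self, a, b):
        # The int scalar 8 has a different dtype than b; the promotion added
        # above converts it to the f64 result type before broadcasting.
        return torch.where(a > 0.5, b, 8)

a = torch.rand(3, 4, 5, dtype=torch.float64)
b = torch.rand(4, 5, dtype=torch.float64)
module = torch_mlir.compile(WhereScalarOther(), [a, b], output_type="stablehlo")
print(module)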
45 changes: 45 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -227,6 +227,29 @@ def ElementwiseWhereScalarOtherModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseWhereScalarOtherStaticModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 4, 5], torch.float64, True),
        ([4, 5], torch.float64, True),
    ])
    def forward(self, a, b):
        return torch.where(a > 0.5, b, 8)


@register_test_case(module_factory=lambda: ElementwiseWhereScalarOtherStaticModule())
def ElementwiseWhereScalarOtherStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5).double(), tu.rand(4, 5).double())


# ==============================================================================


class ElementwiseWhereScalarSelfModule(torch.nn.Module):

    def __init__(self):
@@ -246,6 +269,28 @@ def forward(self, a, b):
def ElementwiseWhereScalarSelfModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5).double(), tu.rand(4, 5).double())

# ==============================================================================


class ElementwiseWhereScalarSelfStaticModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 4, 5], torch.float64, True),
        ([4, 5], torch.float64, True),
    ])
    def forward(self, a, b):
        return torch.where(a > 0.5, 4.0, b)


@register_test_case(module_factory=lambda: ElementwiseWhereScalarSelfStaticModule())
def ElementwiseWhereScalarSelfStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5).double(), tu.rand(4, 5).double())


# ==============================================================================

