103 changes: 28 additions & 75 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -149,59 +149,18 @@ static Value createFpOpWithDtype(OpBuilder &b, const TypeConverter *converter,
return convertScalarToDtype(b, loc, newOp, outTy, std::nullopt, outTTy);
}

template <typename OpTy>
static Value createCompareTensorOp(OpBuilder &b, Location loc, OpTy op,
Value lhs, Value rhs) {
static_assert(std::is_same<OpTy, AtenLtTensorOp>() ||
std::is_same<OpTy, AtenLeTensorOp>() ||
std::is_same<OpTy, AtenGtTensorOp>() ||
std::is_same<OpTy, AtenGeTensorOp>() ||
std::is_same<OpTy, AtenEqTensorOp>() ||
std::is_same<OpTy, AtenNeTensorOp>(),
"unimplemented: op type not supported");

Type lhsDtype = lhs.getType();
Type rhsDtype = rhs.getType();

// TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs
// to be handled.
if (lhsDtype != rhsDtype) {
op.emitError("unimplemented: lhs and rhs dtype must be same");
return nullptr;
}

Type elementalType = cast<BaseTensorType>(op.getSelf().getType()).getDtype();
if constexpr (std::is_same<OpTy, AtenLtTensorOp>()) {
return createLessThan(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenLeTensorOp>()) {
return createLessThanOrEqual(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenGtTensorOp>()) {
return createGreaterThan(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenGeTensorOp>()) {
return createGreaterThanOrEqual(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenEqTensorOp>()) {
return createEqual(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenNeTensorOp>()) {
return createNotEqual(b, loc, elementalType, lhs, rhs);
}
llvm_unreachable("unimplemented: op type not supported");
}
template <class T, class... Ts>
struct is_any_same : std::disjunction<std::is_same<T, Ts>...> {};

template <typename OpTy>
static Value createCompareScalarOp(OpBuilder &b, Location loc, OpTy op,
Value lhs, Value rhs) {
static_assert(std::is_same<OpTy, AtenLtScalarOp>() ||
std::is_same<OpTy, AtenLeScalarOp>() ||
std::is_same<OpTy, AtenEqScalarOp>() ||
std::is_same<OpTy, AtenNeScalarOp>() ||
std::is_same<OpTy, AtenGtScalarOp>() ||
std::is_same<OpTy, AtenGeScalarOp>(),
"unimplemented: op type not supported");
static Value createCompareOp(OpBuilder &b, Location loc, OpTy op, Value lhs,
Value rhs) {
static_assert(
is_any_same<OpTy, AtenLtScalarOp, AtenLeScalarOp, AtenEqScalarOp,
AtenNeScalarOp, AtenGtScalarOp, AtenGeScalarOp,
AtenLtTensorOp, AtenLeTensorOp, AtenGtTensorOp,
AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp>(),
"unimplemented: op type not supported");

Type lhsDtype = lhs.getType();
Type rhsDtype = rhs.getType();
@@ -229,22 +188,22 @@ static Value createCompareScalarOp(OpBuilder &b, Location loc, OpTy op,
return nullptr;
}

if constexpr (std::is_same<OpTy, AtenLtScalarOp>()) {
if constexpr (is_any_same<OpTy, AtenLtScalarOp, AtenLtTensorOp>()) {
return createLessThan(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenLeScalarOp>()) {
if constexpr (is_any_same<OpTy, AtenLeScalarOp, AtenLeTensorOp>()) {
return createLessThanOrEqual(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenGtScalarOp>()) {
if constexpr (is_any_same<OpTy, AtenGtScalarOp, AtenGtTensorOp>()) {
return createGreaterThan(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenGeScalarOp>()) {
if constexpr (is_any_same<OpTy, AtenGeScalarOp, AtenGeTensorOp>()) {
return createGreaterThanOrEqual(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenEqScalarOp>()) {
if constexpr (is_any_same<OpTy, AtenEqScalarOp, AtenEqTensorOp>()) {
return createEqual(b, loc, elementalType, lhs, rhs);
}
if constexpr (std::is_same<OpTy, AtenNeScalarOp>()) {
if constexpr (is_any_same<OpTy, AtenNeScalarOp, AtenNeTensorOp>()) {
return createNotEqual(b, loc, elementalType, lhs, rhs);
}
llvm_unreachable("unimplemented: op type not supported");
@@ -892,28 +851,22 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return b.create<math::Atan2Op>(loc, lhs, rhs);
}
if (auto ltTensor = dyn_cast<AtenLtTensorOp>(op)) {
return createCompareTensorOp(b, loc, ltTensor, payloadArgs[0],
payloadArgs[1]);
return createCompareOp(b, loc, ltTensor, payloadArgs[0], payloadArgs[1]);
}
if (auto leTensor = dyn_cast<AtenLeTensorOp>(op)) {
return createCompareTensorOp(b, loc, leTensor, payloadArgs[0],
payloadArgs[1]);
return createCompareOp(b, loc, leTensor, payloadArgs[0], payloadArgs[1]);
}
if (auto gtTensor = dyn_cast<AtenGtTensorOp>(op)) {
return createCompareTensorOp(b, loc, gtTensor, payloadArgs[0],
payloadArgs[1]);
return createCompareOp(b, loc, gtTensor, payloadArgs[0], payloadArgs[1]);
}
if (auto geTensor = dyn_cast<AtenGeTensorOp>(op)) {
return createCompareTensorOp(b, loc, geTensor, payloadArgs[0],
payloadArgs[1]);
return createCompareOp(b, loc, geTensor, payloadArgs[0], payloadArgs[1]);
}
if (auto eqTensor = dyn_cast<AtenEqTensorOp>(op)) {
return createCompareTensorOp(b, loc, eqTensor, payloadArgs[0],
payloadArgs[1]);
return createCompareOp(b, loc, eqTensor, payloadArgs[0], payloadArgs[1]);
}
if (auto neTensor = dyn_cast<AtenNeTensorOp>(op)) {
return createCompareTensorOp(b, loc, neTensor, payloadArgs[0],
payloadArgs[1]);
return createCompareOp(b, loc, neTensor, payloadArgs[0], payloadArgs[1]);
}
if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
AtenDivTensorOp::Adaptor adaptor(operands);
@@ -996,27 +949,27 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
}

if (auto gtScalar = dyn_cast<AtenGtScalarOp>(op)) {
return createCompareScalarOp(b, loc, gtScalar, payloadArgs[0], operands[1]);
return createCompareOp(b, loc, gtScalar, payloadArgs[0], operands[1]);
}

if (auto geScalar = dyn_cast<AtenGeScalarOp>(op)) {
return createCompareScalarOp(b, loc, geScalar, payloadArgs[0], operands[1]);
return createCompareOp(b, loc, geScalar, payloadArgs[0], operands[1]);
}

if (auto eqScalar = dyn_cast<AtenEqScalarOp>(op)) {
return createCompareScalarOp(b, loc, eqScalar, payloadArgs[0], operands[1]);
return createCompareOp(b, loc, eqScalar, payloadArgs[0], operands[1]);
}

if (auto neScalar = dyn_cast<AtenNeScalarOp>(op)) {
return createCompareScalarOp(b, loc, neScalar, payloadArgs[0], operands[1]);
return createCompareOp(b, loc, neScalar, payloadArgs[0], operands[1]);
}

if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) {
return createCompareScalarOp(b, loc, ltScalar, payloadArgs[0], operands[1]);
return createCompareOp(b, loc, ltScalar, payloadArgs[0], operands[1]);
}

if (auto leScalar = dyn_cast<AtenLeScalarOp>(op)) {
return createCompareScalarOp(b, loc, leScalar, payloadArgs[0], operands[1]);
return createCompareOp(b, loc, leScalar, payloadArgs[0], operands[1]);
}

if (auto whereSelf = dyn_cast<AtenWhereSelfOp>(op)) {
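The net effect of the C++ change above is that the former `createCompareTensorOp` and `createCompareScalarOp` collapse into a single `createCompareOp`, with the new `is_any_same` trait (a thin wrapper over `std::disjunction`) letting one `static_assert` and one `if constexpr` chain cover both the scalar and tensor variants of each comparison. Below is a minimal standalone sketch of that pattern, not the torch-mlir sources; the stub op structs and the `compareKind` helper are invented for illustration.

```cpp
#include <cstdio>
#include <type_traits>

// Stub op types standing in for the Aten*ScalarOp / Aten*TensorOp classes.
struct AtenLtScalarOp {};
struct AtenLtTensorOp {};
struct AtenEqScalarOp {};
struct AtenEqTensorOp {};

// True if T is the same type as any of Ts... (the trait added in this PR).
template <class T, class... Ts>
struct is_any_same : std::disjunction<std::is_same<T, Ts>...> {};

// One template handles both the scalar and the tensor flavor of each op.
template <typename OpTy>
const char *compareKind() {
  static_assert(is_any_same<OpTy, AtenLtScalarOp, AtenLtTensorOp,
                            AtenEqScalarOp, AtenEqTensorOp>(),
                "unimplemented: op type not supported");
  if constexpr (is_any_same<OpTy, AtenLtScalarOp, AtenLtTensorOp>())
    return "less-than";
  if constexpr (is_any_same<OpTy, AtenEqScalarOp, AtenEqTensorOp>())
    return "equal";
  return "unreachable"; // plays the role of llvm_unreachable in the real code
}

int main() {
  std::puts(compareKind<AtenLtScalarOp>()); // prints "less-than"
  std::puts(compareKind<AtenEqTensorOp>()); // prints "equal"
}
```

Merging the two helpers removes the duplicated dispatch chain at the cost of a slightly broader `static_assert`, which is what the 28-additions/75-deletions delta in this file reflects.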
3 changes: 3 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -27,6 +27,7 @@
"InterpolateDynamicModule_sizes_nearest",
"InterpolateStaticModule_scales_bilinear_align_corners",
"InterpolateDynamicModule_scales_recompute_bilinear",
"ElementwiseFloatTensorGtIntTensorModule_basic",
}

LINALG_CRASHING_SET = {
@@ -2706,6 +2707,7 @@
"ElementwiseTanIntModule_basic",
"ElementwiseToDtypeI64ToUI8Module_basic",
"ElementwiseUnaryIntModule_basic",
"ElementwiseFloatTensorGtIntTensorModule_basic",
"MaskedFillTensorFloatValueModule_basic",
"NativeDropoutTrainModule_basic",
"NativeDropoutTrainStaticShapeModule_basic",
@@ -3785,6 +3787,7 @@
"ElementwiseExpm1IntModule_basic",
"ElementwiseExpm1Module_basic",
"ElementwiseFlattenBroadcastModule_basic",
"ElementwiseFloatTensorGtIntTensorModule_basic",
"ElementwiseFmodTensor_Float_basic",
"ElementwiseFmodTensor_Int_Float_basic",
"ElementwiseFmodTensor_Int_basic",
@@ -599,6 +599,51 @@ def ElementwiseLtIntTensorModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10))


class ElementwiseIntTensorLtFloatTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1], torch.int64, True),
([-1], torch.float64, True),
]
)
def forward(self, x, y):
return torch.lt(x, y)


@register_test_case(module_factory=lambda: ElementwiseIntTensorLtFloatTensorModule())
def ElementwiseIntTensorLtFloatTensorModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 5, high=10), tu.rand(5, high=10).to(torch.float64))


class ElementwiseFloatTensorGtIntTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1], torch.float32, True),
([-1], torch.int32, True),
]
)
def forward(self, x, y):
return torch.gt(x, y)


@register_test_case(module_factory=lambda: ElementwiseFloatTensorGtIntTensorModule())
def ElementwiseFloatTensorGtIntTensorModule_basic(module, tu: TestUtils):
module.forward(
tu.rand(3, 5, high=10).to(torch.float32),
tu.randint(5, high=10, dtype=torch.int32),
)


# ==============================================================================

