Skip to content

Commit c5ffd81

Browse files
committed
[Torch Dialect] Decompose AtenTriuOp
1 parent 7b94189 commit c5ffd81

File tree

4 files changed

+111
-0
lines changed

4 files changed

+111
-0
lines changed

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,66 @@ class DecomposeAtenAmaxOp : public OpRewritePattern<AtenAmaxOp> {
246246
};
247247
} // end namespace
248248

249+
namespace {
250+
class DecomposeAtenTriuOp : public OpRewritePattern<AtenTriuOp> {
251+
public:
252+
using OpRewritePattern::OpRewritePattern;
253+
LogicalResult matchAndRewrite(AtenTriuOp op,
254+
PatternRewriter &rewriter) const override {
255+
MLIRContext *context = op.getContext();
256+
Location loc = op.getLoc();
257+
Value input = op.getSelf();
258+
auto inputType = input.getType().cast<BaseTensorType>();
259+
if (!inputType.hasSizes() || !inputType.hasDtype()) {
260+
return rewriter.notifyMatchFailure(op, "should have shape and dtype");
261+
}
262+
if (inputType.getSizes().size() < 2) {
263+
return rewriter.notifyMatchFailure(op, "the rank of tensor should >= 2");
264+
}
265+
266+
Value rowDim = rewriter.create<Torch::ConstantIntOp>(
267+
loc, rewriter.getI64IntegerAttr(-2));
268+
Value colDim = rewriter.create<Torch::ConstantIntOp>(
269+
loc, rewriter.getI64IntegerAttr(-1));
270+
Value rowSize = rewriter.create<AtenSizeIntOp>(loc, input, rowDim);
271+
Value colSize = rewriter.create<AtenSizeIntOp>(loc, input, colDim);
272+
273+
auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
274+
275+
Value none = rewriter.create<ConstantNoneOp>(loc);
276+
Value rowArange = rewriter.create<AtenArangeOp>(
277+
loc, baseType, rowSize, /*dtype=*/none, /*layout=*/none,
278+
/*device=*/none, /*pin_memory=*/none);
279+
Value colArange = rewriter.create<AtenArangeOp>(
280+
loc, baseType, colSize, /*dtype=*/none, /*layout=*/none,
281+
/*device=*/none, /*pin_memory=*/none);
282+
283+
Value unsqueezeRowArange = rewriter.create<AtenUnsqueezeOp>(
284+
loc, baseType, rowArange,
285+
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1)));
286+
287+
Value unsqueezeColArange = rewriter.create<AtenUnsqueezeOp>(
288+
loc, baseType, colArange,
289+
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0)));
290+
291+
Value unsqueezeRowArangePlusDiagonal = rewriter.create<AtenAddScalarOp>(
292+
loc, baseType, unsqueezeRowArange,
293+
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1)),
294+
op.getDiagonal());
295+
296+
Value condTensor = rewriter.create<AtenGeTensorOp>(
297+
loc, baseType, unsqueezeColArange, unsqueezeRowArangePlusDiagonal);
298+
299+
auto others =
300+
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
301+
302+
rewriter.replaceOpWithNewOp<AtenWhereScalarOtherOp>(
303+
op, op.getResult().getType(), condTensor, input, others);
304+
return success();
305+
}
306+
};
307+
} // namespace
308+
249309
namespace {
250310
class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
251311
public:
@@ -5819,6 +5879,7 @@ class DecomposeComplexOpsPass
58195879
addPatternIfTargetOpIsIllegal<DecomposeAtenTileOp>(patterns);
58205880
addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeAsOp>(patterns);
58215881
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexTensorOp>(patterns);
5882+
addPatternIfTargetOpIsIllegal<DecomposeAtenTriuOp>(patterns);
58225883

58235884
GreedyRewriteConfig config;
58245885
config.useTopDownTraversal = true;

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
500500
target.addIllegalOp<AtenTypeAsOp>();
501501
target.addIllegalOp<AtenTileOp>();
502502
target.addIllegalOp<AtenReshapeAsOp>();
503+
target.addIllegalOp<AtenTriuOp>();
503504
for (auto &opName : backendLegalOpsSet) {
504505
target.addLegalOp(
505506
OperationName(kTorchOpPrefix + opName.first().str(), context));

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1308,6 +1308,8 @@
13081308
"MeanModule_basic",
13091309
"ArangeStartOutModule_basic",
13101310
"ArangeStartOutViewModule_basic",
1311+
"TriuModule_basic",
1312+
"TriuBroadcastModule_basic",
13111313
}
13121314

13131315
MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {

projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3251,6 +3251,53 @@ def AtenTriuWithPosDiagonalModule_basic(module, tu: TestUtils):
32513251
# ==============================================================================
32523252

32533253

3254+
class TriuModule(torch.nn.Module):
    """Exercises aten.triu on a static-shaped 2-D input with diagonal=1."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([4, 5], torch.float32, True),
    ])
    def forward(self, x):
        # diagonal=1 keeps the strict upper triangle (main diagonal zeroed).
        return torch.ops.aten.triu(x, 1)
3265+
3266+
3267+
@register_test_case(module_factory=lambda: TriuModule())
def TriuModule_basic(module, tu: TestUtils):
    # Fixed input that includes NaN and +/-inf so the decomposed
    # arange/ge/where path is checked against PyTorch's handling of
    # non-finite values, not just well-behaved floats.
    x = torch.tensor([
        [0.5876, -0.0794, -1.8373, 0.6654, 0.2],
        [-0.2447, 0.9556, -1.2919, 1.3378, 0.3],
        [0.4333, 0.3146, 0.6576, -1.0432, 0.4],
        [-0.9888, torch.nan, torch.inf, -torch.inf, 0.5],
    ])
    module.forward(x)
3274+
3275+
3276+
# ==============================================================================
3277+
3278+
3279+
class TriuBroadcastModule(torch.nn.Module):
    """Exercises aten.triu on a rank-4 input with a runtime diagonal value."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 4, 5, 6], torch.float32, True),
        ([], torch.int32, True),
    ])
    def forward(self, x, diagonal):
        # `diagonal` arrives as a 0-d int tensor and is fed directly as
        # triu's diagonal argument; triu applies to the last two dims of x.
        return torch.ops.aten.triu(x, diagonal)
3291+
3292+
3293+
@register_test_case(module_factory=lambda: TriuBroadcastModule())
def TriuBroadcastModule_basic(module, tu: TestUtils):
    # Random batched input plus a randomly drawn scalar diagonal.
    module.forward(tu.rand(3, 4, 5, 6), tu.randint())
3296+
3297+
3298+
# ==============================================================================
3299+
3300+
32543301
class AtenTriuWithNegDiagonalModule(torch.nn.Module):
32553302

32563303
def __init__(self):

0 commit comments

Comments
 (0)