@@ -246,6 +246,66 @@ class DecomposeAtenAmaxOp : public OpRewritePattern<AtenAmaxOp> {
246246};
247247} // end namespace
248248
249+ namespace {
250+ class DecomposeAtenTriuOp : public OpRewritePattern <AtenTriuOp> {
251+ public:
252+ using OpRewritePattern::OpRewritePattern;
253+ LogicalResult matchAndRewrite (AtenTriuOp op,
254+ PatternRewriter &rewriter) const override {
255+ MLIRContext *context = op.getContext ();
256+ Location loc = op.getLoc ();
257+ Value input = op.getSelf ();
258+ auto inputType = input.getType ().cast <BaseTensorType>();
259+ if (!inputType.hasSizes () || !inputType.hasDtype ()) {
260+ return rewriter.notifyMatchFailure (op, " should have shape and dtype" );
261+ }
262+ if (inputType.getSizes ().size () < 2 ) {
263+ return rewriter.notifyMatchFailure (op, " the rank of tensor should >= 2" );
264+ }
265+
266+ Value rowDim = rewriter.create <Torch::ConstantIntOp>(
267+ loc, rewriter.getI64IntegerAttr (-2 ));
268+ Value colDim = rewriter.create <Torch::ConstantIntOp>(
269+ loc, rewriter.getI64IntegerAttr (-1 ));
270+ Value rowSize = rewriter.create <AtenSizeIntOp>(loc, input, rowDim);
271+ Value colSize = rewriter.create <AtenSizeIntOp>(loc, input, colDim);
272+
273+ auto baseType = ValueTensorType::getWithLeastStaticInformation (context);
274+
275+ Value none = rewriter.create <ConstantNoneOp>(loc);
276+ Value rowArange = rewriter.create <AtenArangeOp>(
277+ loc, baseType, rowSize, /* dtype=*/ none, /* layout=*/ none,
278+ /* device=*/ none, /* pin_memory=*/ none);
279+ Value colArange = rewriter.create <AtenArangeOp>(
280+ loc, baseType, colSize, /* dtype=*/ none, /* layout=*/ none,
281+ /* device=*/ none, /* pin_memory=*/ none);
282+
283+ Value unsqueezeRowArange = rewriter.create <AtenUnsqueezeOp>(
284+ loc, baseType, rowArange,
285+ rewriter.create <ConstantIntOp>(loc, rewriter.getI64IntegerAttr (1 )));
286+
287+ Value unsqueezeColArange = rewriter.create <AtenUnsqueezeOp>(
288+ loc, baseType, colArange,
289+ rewriter.create <ConstantIntOp>(loc, rewriter.getI64IntegerAttr (0 )));
290+
291+ Value unsqueezeRowArangePlusDiagonal = rewriter.create <AtenAddScalarOp>(
292+ loc, baseType, unsqueezeRowArange,
293+ rewriter.create <ConstantIntOp>(loc, rewriter.getI64IntegerAttr (1 )),
294+ op.getDiagonal ());
295+
296+ Value condTensor = rewriter.create <AtenGeTensorOp>(
297+ loc, baseType, unsqueezeColArange, unsqueezeRowArangePlusDiagonal);
298+
299+ auto others =
300+ rewriter.create <ConstantIntOp>(loc, rewriter.getI64IntegerAttr (0 ));
301+
302+ rewriter.replaceOpWithNewOp <AtenWhereScalarOtherOp>(
303+ op, op.getResult ().getType (), condTensor, input, others);
304+ return success ();
305+ }
306+ };
307+ } // namespace
308+
249309namespace {
250310class DecomposeAtenSizeOp : public OpRewritePattern <AtenSizeOp> {
251311public:
@@ -5819,6 +5879,7 @@ class DecomposeComplexOpsPass
58195879 addPatternIfTargetOpIsIllegal<DecomposeAtenTileOp>(patterns);
58205880 addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeAsOp>(patterns);
58215881 addPatternIfTargetOpIsIllegal<DecomposeAtenIndexTensorOp>(patterns);
5882+ addPatternIfTargetOpIsIllegal<DecomposeAtenTriuOp>(patterns);
58225883
58235884 GreedyRewriteConfig config;
58245885 config.useTopDownTraversal = true ;
0 commit comments