diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index 8a0071d6f7c3..815a1892d161 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -761,6 +761,7 @@
     "NewEmptyModuleNonDefaultIntDtype_basic",
     "NewEmptyStridedModuleDefaultDtype_basic",
     "EmptyStridedModule_basic",
+    "EmptyStridedSizeIntStrideModule_basic",
     "PermuteModule_basic",
     "PermuteNegativeIndexModule_basic",
     "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
@@ -1434,6 +1435,7 @@
     "UniformStaticShapeModule_basic",
     "AtenEmbeddingBagStaticModule_basic",
     "EmptyStridedModule_basic",
+    "EmptyStridedSizeIntStrideModule_basic",
     "ElementwiseBitwiseAndScalarInt64Module_basic",
     "ElementwiseBitwiseAndScalarInt32Module_basic",
     "ElementwiseBitwiseAndScalarInt8Module_basic",
diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
index f913e70345f4..1b536f435f46 100644
--- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h
+++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -104,6 +104,12 @@ inline bool isAssumingStrictSymbolicShapes(OpBuilder &builder) {
   return isAssumingStrictSymbolicShapes(builder.getBlock());
 }
 
+// Helper for AtenEmptyStrided and friends: fails if the strides are statically
+// known to be non-default; for non-constant strides it emits a runtime assert.
+LogicalResult checkDefaultStrideHelper(Operation *op, PatternRewriter &rewriter,
+                                       Value opSize, Value opStride,
+                                       Location loc);
+
 } // namespace Torch
 } // namespace torch
 } // namespace mlir
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 2cf3cc742f92..f362a661b15f 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -332,7 +332,8 @@ class DecomposeAtenNarrowOp : public OpRewritePattern<AtenNarrowOp> {
         rewriter.create<AtenAddIntOp>(loc, one.getType(), start, length);
 
     rewriter.replaceOpWithNewOp<AtenSliceTensorOp>(
-        op, op.getResult().getType(), op.getSelf(), /*dim=*/dim, /*start=*/start,
+        op, op.getResult().getType(), op.getSelf(), /*dim=*/dim,
+        /*start=*/start,
         /*end=*/startPlusLength, /*step=*/one);
 
     return success();
@@ -404,16 +405,15 @@ class DecomposeAtenGluOp : public OpRewritePattern<AtenGluOp> {
 } // namespace
 
 namespace {
-class DecomposeAtenZeroOp
-    : public OpRewritePattern<AtenZeroOp> {
+class DecomposeAtenZeroOp : public OpRewritePattern<AtenZeroOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenZeroOp op,
                                 PatternRewriter &rewriter) const override {
     Value zero = rewriter.create<ConstantIntOp>(op.getLoc(),
                                                 rewriter.getI64IntegerAttr(0));
-    rewriter.replaceOpWithNewOp<AtenFillScalarOp>(op, op.getType(), op.getSelf(),
-                                                  zero);
+    rewriter.replaceOpWithNewOp<AtenFillScalarOp>(op, op.getType(),
+                                                  op.getSelf(), zero);
     return success();
   }
 };
@@ -1139,14 +1139,21 @@ class DecomposeAtenEluOp : public OpRewritePattern<AtenEluOp> {
     Value constantOne =
         rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
     Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero);
-    Value maxZeroX = rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, input);
-    Value positiveOutput = rewriter.create<AtenMulScalarOp>(loc, resType, maxZeroX, scale);
-    Value minZeroX = rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, input);
-    Value scaledMinZeroX = rewriter.create<AtenMulScalarOp>(loc, resType, minZeroX, inputScale);
+    Value maxZeroX =
+        rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, input);
+    Value positiveOutput =
+        rewriter.create<AtenMulScalarOp>(loc, resType, maxZeroX, scale);
+    Value minZeroX =
+        rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, input);
+    Value scaledMinZeroX =
+        rewriter.create<AtenMulScalarOp>(loc, resType, minZeroX, inputScale);
     Value expX = rewriter.create<AtenExpOp>(loc, resType, scaledMinZeroX);
-    Value expXM1 = rewriter.create<AtenSubScalarOp>(loc, resType, expX, constantOne, constantOne);
-    Value scaledExpXM1 = rewriter.create<AtenMulScalarOp>(loc, resType, expXM1, scale);
-    Value negativeOutput = rewriter.create<AtenMulScalarOp>(loc, resType, scaledExpXM1, alpha);
+    Value expXM1 = rewriter.create<AtenSubScalarOp>(loc, resType, expX,
+                                                    constantOne, constantOne);
+    Value scaledExpXM1 =
+        rewriter.create<AtenMulScalarOp>(loc, resType, expXM1, scale);
+    Value negativeOutput =
+        rewriter.create<AtenMulScalarOp>(loc, resType, scaledExpXM1, alpha);
     Value eluOutput = rewriter.create<AtenAddTensorOp>(
         loc, resType, positiveOutput, negativeOutput, constantOne);
@@ -1419,8 +1426,8 @@ class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp> {
         rewriter.create<PrimListConstructOp>(loc, listType, expandedSizes);
     Value reshapedDims =
         rewriter.create<PrimListConstructOp>(loc, listType, reshapedSizes);
-    auto reshaped = rewriter.create<AtenViewOp>(loc, unsqueezedType, op.getSelf(),
-                                                unsqueezedDims);
+    auto reshaped = rewriter.create<AtenViewOp>(loc, unsqueezedType,
+                                                op.getSelf(), unsqueezedDims);
     auto expanded = rewriter.create<AtenBroadcastToOp>(loc, expandedType,
                                                        reshaped, expandedDims);
@@ -1507,8 +1514,8 @@ class DecomposeAtenExpandOp : public OpRewritePattern<AtenExpandOp> {
       return rewriter.notifyMatchFailure(
           op, "unimplemented: requires implicit to be false");
     }
-    rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(), op.getSelf(),
-                                                   op.getSize());
+    rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(),
+                                                   op.getSelf(), op.getSize());
     return success();
   }
 };
@@ -1527,7 +1534,8 @@ class DecomposeAtenWhereScalarOp : public OpRewritePattern<AtenWhereScalarOp> {
       return rewriter.notifyMatchFailure(op, "result should have dtype");
     }
     Value selfTensor = createRank0Tensor(rewriter, loc, resType, op.getSelf());
-    Value otherTensor = createRank0Tensor(rewriter, loc, resType, op.getOther());
+    Value otherTensor =
+        createRank0Tensor(rewriter, loc, resType, op.getOther());
     rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.getCondition(),
                                                  selfTensor, otherTensor);
     return success();
@@ -1548,7 +1556,8 @@ class DecomposeAtenWhereScalarOtherOp
     if (!resType.hasDtype()) {
       return rewriter.notifyMatchFailure(op, "result should have dtype");
     }
-    Value otherTensor = createRank0Tensor(rewriter, loc, resType, op.getOther());
+    Value otherTensor =
+        createRank0Tensor(rewriter, loc, resType, op.getOther());
     rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, op.getCondition(),
                                                  op.getSelf(), otherTensor);
     return success();
@@ -1592,8 +1601,8 @@ class DecomposeAtenMaskedFillScalarOp
     }
     Value mask = op.getMask();
     Value value = createRank0Tensor(rewriter, loc, resType, op.getValue());
-    rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, mask,
-                                                 value, op.getSelf());
+    rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resType, mask, value,
+                                                 op.getSelf());
     return success();
   }
 };
@@ -1653,8 +1662,8 @@ class DecomposeAtenConvTranspose2dOp
     Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), true);
     rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
         op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(),
-        op.getStride(), op.getPadding(), op.getDilation(), /*transposed=*/cstTrue,
-        op.getOutputPadding(), op.getGroups());
+        op.getStride(), op.getPadding(), op.getDilation(),
+        /*transposed=*/cstTrue, op.getOutputPadding(), op.getGroups());
     return success();
   }
 };
@@ -2406,9 +2415,9 @@ class DecomposeAtenStdDimOp : public OpRewritePattern<AtenStdDimOp> {
           op, "aten.std.dim expects input tensor of floating-point type");
     }
 
-    Value varDim =
-        rewriter.create<AtenVarDimOp>(op->getLoc(), op.getType(), self,
-                                      op.getDim(), op.getUnbiased(), op.getKeepdim());
+    Value varDim = rewriter.create<AtenVarDimOp>(
+        op->getLoc(), op.getType(), self, op.getDim(), op.getUnbiased(),
+        op.getKeepdim());
     rewriter.replaceOpWithNewOp<AtenSqrtOp>(op, op.getType(), varDim);
     return success();
   }
@@ -2532,8 +2541,8 @@ class DecomposeAtenRandLikeOp : public OpRewritePattern<AtenRandLikeOp> {
     Value one =
         rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
     Value emptyTensor = rewriter.create<AtenFullLikeOp>(
-        loc, resultType, input, zero, op.getDtype(), op.getLayout(), op.getDevice(),
-        op.getPinMemory(), op.getMemoryFormat());
+        loc, resultType, input, zero, op.getDtype(), op.getLayout(),
+        op.getDevice(), op.getPinMemory(), op.getMemoryFormat());
     rewriter.replaceOpWithNewOp<AtenUniformOp>(op, resultType, emptyTensor,
                                                /*from=*/zero, /*to=*/one,
                                                /*generator=*/none);
@@ -2735,7 +2744,8 @@ class DecomposeAtenNativeLayerNormOp
     SmallVector<Value> normalizedShapeSizesTorchInt;
     getListConstructElements(normalizedShape, normalizedShapeSizesTorchInt);
     int64_t axis = inputRank - normalizedShapeSizesTorchInt.size();
-    auto reduceDimInts = llvm::to_vector<4>(llvm::seq<int64_t>(axis, inputRank));
+    auto reduceDimInts =
+        llvm::to_vector<4>(llvm::seq<int64_t>(axis, inputRank));
     auto reducedTy = op.getResult(1).getType();
     auto sizeListType = ListType::get(IntType::get(context));
@@ -2811,8 +2821,8 @@ class DecomposeAtenEmptyLikeOp : public OpRewritePattern<AtenEmptyLikeOp> {
     Value sizeList =
         rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.getSelf());
     rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
-        op, op.getType(), sizeList, op.getDtype(), op.getLayout(), op.getDevice(),
-        op.getPinMemory(), op.getMemoryFormat());
+        op, op.getType(), sizeList, op.getDtype(), op.getLayout(),
+        op.getDevice(), op.getPinMemory(), op.getMemoryFormat());
     return success();
   }
 };
@@ -2833,8 +2843,8 @@ class DecomposeAtenArangeOp : public OpRewritePattern<AtenArangeOp> {
     step = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
 
     rewriter.replaceOpWithNewOp<AtenArangeStartStepOp>(
-        op, op.getType(), start, op.getEnd(), step, op.getDtype(), op.getLayout(),
-        op.getDevice(), op.getPinMemory());
+        op, op.getType(), start, op.getEnd(), step, op.getDtype(),
+        op.getLayout(), op.getDevice(), op.getPinMemory());
     return success();
   }
 };
@@ -2853,8 +2863,8 @@ class DecomposeAtenArangeStartOp : public OpRewritePattern<AtenArangeStartOp> {
     step = rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
 
     rewriter.replaceOpWithNewOp<AtenArangeStartStepOp>(
-        op, op.getType(), op.getStart(), op.getEnd(), step, op.getDtype(), op.getLayout(),
-        op.getDevice(), op.getPinMemory());
+        op, op.getType(), op.getStart(), op.getEnd(), step, op.getDtype(),
+        op.getLayout(), op.getDevice(), op.getPinMemory());
     return success();
   }
 };
@@ -2941,7 +2951,8 @@ class DecomposeAtenNativeBatchNormOp
         loc, ListType::get(IntType::get(context)), runningStatsShape);
 
     SmallVector<int64_t> runningStatsShapeInt(inputRank, 1);
-    runningStatsShapeInt[1] = runningMean.getType().cast<BaseTensorType>().getSizes()[0];
+    runningStatsShapeInt[1] =
+        runningMean.getType().cast<BaseTensorType>().getSizes()[0];
     Type dtype = input.getType().cast<BaseTensorType>().getOptionalDtype();
     Type reshapeType = ValueTensorType::get(
         context, llvm::ArrayRef(runningStatsShapeInt), dtype);
@@ -3226,11 +3237,10 @@ class DecomposeAtenNewFullOp : public OpRewritePattern<AtenNewFullOp> {
           getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
     }
     rewriter.replaceOpWithNewOp<AtenFullOp>(
-        op, op.getType(), op.getSize(), op.getFillValue(), dtype, op.getLayout(), op.getDevice(),
-        op.getPinMemory());
+        op, op.getType(), op.getSize(), op.getFillValue(), dtype,
+        op.getLayout(), op.getDevice(), op.getPinMemory());
 
     return success();
-
   }
 };
 } // namespace
@@ -3244,7 +3254,8 @@ class DecomposeAtenIndexPutOp : public OpRewritePattern<AtenIndexPutOp> {
                 PatternRewriter &rewriter) const override {
     Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
     rewriter.replaceOpWithNewOp<Aten_IndexPutImplOp>(
-        op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(), op.getAccumulate(),
+        op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(),
+        op.getAccumulate(),
         /*unsafe=*/cstFalse);
     return success();
   }
@@ -3261,8 +3272,8 @@ class DecomposeAtenExpandAsOp : public OpRewritePattern<AtenExpandAsOp> {
         Torch::ListType::get(Torch::IntType::get(op.getContext()));
     Value sizeList =
         rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.getOther());
-    rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(), op.getSelf(),
-                                                   sizeList);
+    rewriter.replaceOpWithNewOp<AtenBroadcastToOp>(op, op.getType(),
+                                                   op.getSelf(), sizeList);
     return success();
   }
 };
@@ -3284,8 +3295,9 @@ class DecomposeAten_ToCopyOp : public OpRewritePattern<Aten_ToCopyOp> {
     Value zero = getConstantWithGivenDtypeAndValue(rewriter, op.getLoc(), 0.0,
                                                    resultDtype);
     Value emptyTensor = rewriter.create<AtenFullLikeOp>(
-        op.getLoc(), op.getType(), op.getSelf(), zero, op.getDtype(), op.getLayout(),
-        op.getDevice(), op.getPinMemory(), op.getMemoryFormat());
+        op.getLoc(), op.getType(), op.getSelf(), zero, op.getDtype(),
+        op.getLayout(), op.getDevice(), op.getPinMemory(),
+        op.getMemoryFormat());
     rewriter.replaceOpWithNewOp<AtenCopyOp>(op, op.getType(), emptyTensor,
                                             op.getSelf(), op.getNonBlocking());
     return success();
@@ -3356,7 +3368,8 @@ class DecomposeAtenIndexPutHackedTwinOp
                 PatternRewriter &rewriter) const override {
     Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
    rewriter.replaceOpWithNewOp<Aten_IndexPutImpl_HackedTwinOp>(
-        op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(), op.getAccumulate(),
+        op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(),
+        op.getAccumulate(),
         /*unsafe=*/cstFalse);
     return success();
   }
@@ -3445,9 +3458,9 @@ class DecomposeAtenToDtypeLayoutOp
           op, "unimplemented: layout is expected to be strided");
     }
 
-    rewriter.replaceOpWithNewOp<AtenToDtypeOp>(op, op.getType(), op.getSelf(),
-                                               op.getDtype(), op.getNonBlocking(),
-                                               op.getCopy(), op.getMemoryFormat());
+    rewriter.replaceOpWithNewOp<AtenToDtypeOp>(
+        op, op.getType(), op.getSelf(), op.getDtype(), op.getNonBlocking(),
+        op.getCopy(), op.getMemoryFormat());
     return success();
   }
 };
@@ -3463,9 +3476,9 @@ class DecomposeAtenToDeviceOp : public OpRewritePattern<AtenToDeviceOp> {
 
     // Device information isn't relevant to torch-mlir, so we can drop that info
     // here.
-    rewriter.replaceOpWithNewOp<AtenToDtypeOp>(op, op.getType(), op.getSelf(),
-                                               op.getDtype(), op.getNonBlocking(),
-                                               op.getCopy(), op.getMemoryFormat());
+    rewriter.replaceOpWithNewOp<AtenToDtypeOp>(
+        op, op.getType(), op.getSelf(), op.getDtype(), op.getNonBlocking(),
+        op.getCopy(), op.getMemoryFormat());
 
     return success();
   }
@@ -3704,8 +3717,8 @@ class DecomposeAtenBaddbmmOp : public OpRewritePattern<AtenBaddbmmOp> {
   LogicalResult matchAndRewrite(AtenBaddbmmOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    Value bmm =
-        rewriter.create<AtenBmmOp>(loc, op.getType(), op.getBatch1(), op.getBatch2());
+    Value bmm = rewriter.create<AtenBmmOp>(loc, op.getType(), op.getBatch1(),
+                                           op.getBatch2());
     Value alphaTimesBmm =
         rewriter.create<AtenMulScalarOp>(loc, op.getType(), bmm, op.getAlpha());
     Value input = op.getSelf();
@@ -4066,7 +4079,8 @@ class DecomposeAtenMseLossOp : public OpRewritePattern<AtenMseLossOp> {
                                   resultType.getOptionalDtype())
             .cast<BaseTensorType>();
-    Value sub = createTensorSub(rewriter, loc, subType, op.getSelf(), op.getTarget());
+    Value sub =
+        createTensorSub(rewriter, loc, subType, op.getSelf(), op.getTarget());
     Value result = rewriter.create<AtenSquareOp>(loc, subType, sub);
     if (reductionType == torch_upstream::Reduction::None) {
       rewriter.replaceOp(op, result);
@@ -4148,7 +4162,8 @@ class DecomposeAtenRandintLowOp : public OpRewritePattern<AtenRandintLowOp> {
                                   rewriter.getF32Type())
             .cast<BaseTensorType>();
     Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
-        loc, floatResultType, op.getSize(), /*dtype=*/none, /*layout=*/op.getLayout(),
+        loc, floatResultType, op.getSize(), /*dtype=*/none,
+        /*layout=*/op.getLayout(),
         /*device=*/op.getDevice(), /*pinMemory=*/op.getPinMemory(),
         /*memoryFormat=*/none);
@@ -4178,11 +4193,11 @@ class DecomposeAtenRandintOp : public OpRewritePattern<AtenRandintOp> {
 
     Value low = rewriter.create<Torch::ConstantIntOp>(
         loc, rewriter.getI64IntegerAttr(0));
-    
+
     rewriter.replaceOpWithNewOp<AtenRandintLowOp>(
-        op, resultType, low, op.getHigh(), op.getSize(), op.getDtype(), op.getLayout(),
-        op.getDevice(), op.getPinMemory());
-    
+        op, resultType, low, op.getHigh(), op.getSize(), op.getDtype(),
+        op.getLayout(), op.getDevice(), op.getPinMemory());
+
     return success();
   }
 };
@@ -4200,10 +4215,11 @@ class DecomposeAtenVarMeanCorrectionOp
     Location loc = op.getLoc();
     Value noneVal = rewriter.create<ConstantNoneOp>(loc);
     Value var = rewriter.create<AtenVarCorrectionOp>(
-        loc, op.getType(0), op.getSelf(), op.getDim(), op.getCorrection(), op.getKeepdim());
-    Value mean =
-        rewriter.create<AtenMeanDimOp>(loc, op.getType(0), op.getSelf(), op.getDim(),
-                                       op.getKeepdim(), /*dtype=*/noneVal);
+        loc, op.getType(0), op.getSelf(), op.getDim(), op.getCorrection(),
+        op.getKeepdim());
+    Value mean = rewriter.create<AtenMeanDimOp>(
+        loc, op.getType(0), op.getSelf(), op.getDim(), op.getKeepdim(),
+        /*dtype=*/noneVal);
     rewriter.replaceOp(op, {var, mean});
     return success();
   }
@@ -4300,14 +4316,16 @@ class DecomposeAtenRandnGeneratorOp
         /*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(),
         /*memory_format=*/none);
 
-    Value uOne = rewriter.create<AtenUniformOp>(loc, resultType, emptyTensorA,
-                                                /*from=*/low,
-                                                /*to=*/high,
-                                                /*generator=*/op.getGenerator());
-    Value uTwo = rewriter.create<AtenUniformOp>(loc, resultType, emptyTensorB,
-                                                /*from=*/low,
-                                                /*to=*/high,
-                                                /*generator=*/op.getGenerator());
+    Value uOne =
+        rewriter.create<AtenUniformOp>(loc, resultType, emptyTensorA,
+                                       /*from=*/low,
+                                       /*to=*/high,
+                                       /*generator=*/op.getGenerator());
+    Value uTwo =
+        rewriter.create<AtenUniformOp>(loc, resultType, emptyTensorB,
+                                       /*from=*/low,
+                                       /*to=*/high,
+                                       /*generator=*/op.getGenerator());
 
     Value logUOne = rewriter.create<AtenLogOp>(loc, resultType, uOne);
     Value minusTwoLogUOne =
@@ -4432,37 +4450,18 @@ class DecomposeAtenNewEmptyStridedOp
   using OpRewritePattern::OpRewritePattern;
   LogicalResult
   matchAndRewrite(AtenNewEmptyStridedOp op,
                   PatternRewriter &rewriter) const override {
-    SmallVector<int64_t> sizeListInts, strideListInts;
-    if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizeListInts)))
-      return rewriter.notifyMatchFailure(
-          op, "all size list elements must be constant ints");
-    if (!matchPattern(op.getStride(),
-                      m_TorchListOfConstantInts(strideListInts)))
-      return rewriter.notifyMatchFailure(
-          op, "all stride list elements must be constant ints");
-
-    // We only support the cases with default stride values.
-    // For ex: aten.new_empty_strided(self, size=[2, 3, 4], stride=[12, 4, 1])
-    // Here the stride[0] == size[1] * size[2], stride[1] == size[2], and
-    // stride[2] == 1.
-    bool isDefaultStride = true;
-    for (unsigned i = 0; i < strideListInts.size(); i++) {
-      int64_t defaultStride = 1;
-      for (unsigned j = i + 1; j < sizeListInts.size(); j++)
-        defaultStride *= sizeListInts[j];
-      if (defaultStride != strideListInts[i]) {
-        isDefaultStride = false;
-        break;
-      }
-    }
+    Location loc = op.getLoc();
+    Value opSize = op.getSize();
+    Value opStride = op.getStride();
 
-    if (!isDefaultStride)
+    if (failed(checkDefaultStrideHelper(op, rewriter, opSize, opStride, loc)))
       return rewriter.notifyMatchFailure(
-          op, "only default strides supported for new_empty_strided op");
+          op, "unable to determine if stride is default");
 
     rewriter.replaceOpWithNewOp<AtenNewEmptyOp>(
         op, op.getType(), op.getSelf(), op.getSize(), op.getDtype(),
         op.getLayout(), op.getDevice(), op.getPinMemory());
+
     return success();
   }
 };
@@ -4475,42 +4474,20 @@ class DecomposeAtenEmptyStridedOp
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenEmptyStridedOp op,
                                 PatternRewriter &rewriter) const override {
-    SmallVector<int64_t> sizeListInts, strideListInts;
-    if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizeListInts)))
-      return rewriter.notifyMatchFailure(
-          op, "all size list elements must be constant ints");
-    if (!matchPattern(op.getStride(),
-                      m_TorchListOfConstantInts(strideListInts)))
-      return rewriter.notifyMatchFailure(
-          op, "all stride list elements must be constant ints");
-
-    // We only support the cases with default stride values.
-    // For ex: aten.new_empty_strided(self, size=[2, 3, 4], stride=[12, 4, 1])
-    // Here the stride[0] == size[1] * size[2], stride[1] == size[2], and
-    // stride[2] == 1.
-    bool isDefaultStride = true;
-    for (unsigned i = 0; i < strideListInts.size(); i++) {
-      int64_t defaultStride = 1;
-      for (unsigned j = i + 1; j < sizeListInts.size(); j++)
-        defaultStride *= sizeListInts[j];
-      if (defaultStride != strideListInts[i]) {
-        isDefaultStride = false;
-        break;
-      }
-    }
-    if (!isDefaultStride)
+    Location loc = op.getLoc();
+    Value opSize = op.getSize();
+    Value opStride = op.getStride();
+
+    if (failed(checkDefaultStrideHelper(op, rewriter, opSize, opStride, loc)))
       return rewriter.notifyMatchFailure(
-          op, "only default strides supported for new_empty_strided op");
+          op, "unable to determine if stride is default");
 
     Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
     rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
-        op, op.getType(), op.getSize(), op.getDtype(), op.getLayout(), op.getDevice(),
-        op.getPinMemory(), /*memoryFormat=*/noneVal);
-
+        op, op.getType(), op.getSize(), op.getDtype(), op.getLayout(),
+        op.getDevice(), op.getPinMemory(), /*memoryFormat=*/noneVal);
     return success();
-
-
   }
 };
 } // namespace
@@ -4899,8 +4876,8 @@ class DecomposeAtenSignOp : public OpRewritePattern<AtenSignOp> {
     auto selectGreater =
         rewriter.create<AtenWhereScalarOp>(loc, outType, greater, one, zero);
 
-    rewriter.replaceOpWithNewOp<AtenWhereScalarOtherOp>(op, outType, greaterEqual,
-                                                        selectGreater, minusOne);
+    rewriter.replaceOpWithNewOp<AtenWhereScalarOtherOp>(
+        op, outType, greaterEqual, selectGreater, minusOne);
     return success();
   }
 };
@@ -5312,7 +5289,8 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenRandintLowOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenRandintOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanCorrectionOp>(patterns);
-    addPatternIfTargetOpIsIllegal<DecomposePrimsConvertElementTypeOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposePrimsConvertElementTypeOp>(
+        patterns);
     addPatternIfTargetOpIsIllegal<DecomposePrimsVarOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposePrimsSqrtOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenRandnOp>(patterns);
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index ddc95bd4b2fd..14a264ada342 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -333,3 +333,61 @@ bool Torch::isAssumingStrictSymbolicShapes(Block *block) {
   }
   return false;
 }
+
+LogicalResult Torch::checkDefaultStrideHelper(Operation *op,
+                                              PatternRewriter &rewriter,
+                                              Value opSize, Value opStride,
+                                              Location loc) {
+
+  SmallVector<int64_t> sizeListInts, strideListInts;
+  if (matchPattern(opSize, m_TorchListOfConstantInts(sizeListInts)) &&
+      matchPattern(opStride, m_TorchListOfConstantInts(strideListInts))) {
+
+    // We only support the cases with default stride values.
+    // For ex: aten.new_empty_strided(self, size=[2, 3, 4], stride=[12, 4, 1])
+    // Here the stride[0] == size[1] * size[2], stride[1] == size[2], and
+    // stride[2] == 1.
+    bool isDefaultStride = true;
+    for (unsigned i = 0; i < strideListInts.size(); i++) {
+      int64_t defaultStride = 1;
+      for (unsigned j = i + 1; j < sizeListInts.size(); j++)
+        defaultStride *= sizeListInts[j];
+      if (defaultStride != strideListInts[i]) {
+        isDefaultStride = false;
+        break;
+      }
+    }
+    if (!isDefaultStride)
+      return rewriter.notifyMatchFailure(
+          op, "only default strides supported for empty_strided op");
+
+    return success();
+
+  } else {
+    SmallVector<Value> sizeListValues;
+    if (!getListConstructElements(opSize, sizeListValues))
+      return rewriter.notifyMatchFailure(op, "couldn't get size list values");
+    SmallVector<Value> strideListValues;
+    if (!getListConstructElements(opStride, strideListValues))
+      return rewriter.notifyMatchFailure(op,
+                                         "couldn't get stride list values");
+    SmallVector<Value> boolVector;
+    for (unsigned i = 0; i < strideListValues.size(); i++) {
+      Value defaultStride = rewriter.createOrFold<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(1));
+      for (unsigned j = i + 1; j < sizeListValues.size(); j++) {
+        defaultStride = rewriter.createOrFold<AtenMulIntOp>(
+            loc, defaultStride, sizeListValues[j]);
+      }
+      boolVector.push_back(rewriter.createOrFold<AtenEqIntOp>(
+          loc, defaultStride, strideListValues[i]));
+    }
+    Value allBoolOpList = rewriter.createOrFold<PrimListConstructOp>(
+        loc, Torch::ListType::get(rewriter.getType<Torch::BoolType>()),
+        boolVector);
+    Value cmp = rewriter.createOrFold<AtenAllBoolOp>(loc, allBoolOpList);
+    rewriter.createOrFold<Torch::RuntimeAssertOp>(
+        loc, cmp, "not all strides are default");
+    return success();
+  }
+}
diff --git a/python/torch_mlir_e2e_test/test_suite/constant_alloc.py b/python/torch_mlir_e2e_test/test_suite/constant_alloc.py
index 27cf2eb4a0d9..675e327572d2 100644
--- a/python/torch_mlir_e2e_test/test_suite/constant_alloc.py
+++ b/python/torch_mlir_e2e_test/test_suite/constant_alloc.py
@@ -1629,7 +1629,6 @@ def forward(self, a):
 def NewEmptyStridedModuleDefaultDtype_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 3, 4))
 
-
 # ==============================================================================
 
 
@@ -1651,4 +1650,27 @@ def forward(self, a):
 
 @register_test_case(module_factory=lambda: EmptyStridedModule())
 def EmptyStridedModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(2, 3, 4))
+    module.forward(tu.rand(4, 3, 4))
+
+# ==============================================================================
+
+
+class EmptyStridedSizeIntStrideModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 3, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        x = torch.ops.aten.empty_strided(a.size(), stride=[12, a.size(2), 1])
+        y = x.copy_(a)
+        return y
+
+
+@register_test_case(module_factory=lambda: EmptyStridedSizeIntStrideModule())
+def EmptyStridedSizeIntStrideModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 4))
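Reviewer note on the decomposition logic: the "default" strides that checkDefaultStrideHelper accepts are simply the contiguous row-major strides, i.e. stride[i] is the product of all sizes after dimension i. The Python sketch below mirrors the nested defaultStride loop in the helper and checks it against PyTorch; the function name default_strides is illustrative only, not part of this patch.

    import torch

    def default_strides(sizes):
        # stride[i] of a contiguous tensor is the product of sizes[i+1:].
        strides = []
        for i in range(len(sizes)):
            stride = 1
            for j in range(i + 1, len(sizes)):
                stride *= sizes[j]
            strides.append(stride)
        return strides

    assert default_strides([2, 3, 4]) == [12, 4, 1]
    # Freshly allocated tensors are contiguous, so their strides must match.
    assert list(torch.empty(2, 3, 4).stride()) == default_strides([2, 3, 4])

When the size or stride lists are not compile-time constants, the helper instead builds the same products with torch.aten.mul.int, compares them with torch.aten.eq.int, and guards the result with a torch.runtime.assert, deferring the check to run time; that dynamic path is what EmptyStridedSizeIntStrideModule_basic exercises via its non-constant a.size(2) stride element.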