From 5267a5ef1d4f55dc53c151a8170bae99f25fc004 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Fri, 29 Sep 2023 14:33:41 -0700 Subject: [PATCH 1/2] Elide dynamic broadcast checks when in strict symbolic shapes mode. When importing dynamic shaped programs from Dynamo, via torch.compile or torch.export, we can assume that strict symbolic shape checks have been done prior to generating torch IR. Among other shape checking, this eliminates the case where an unknown dimension can be dynamically '1' in a way that signals a broadcast. Adds a `isAssumingStrictSymbolicShapes` utility which consults a `torch.assume_strict_symbolic_shapes` attribute on an enclosing scope and returns true if present. In the linalg pipeline, many runtime checks are elided when this returns true. --- .../torch-mlir/Dialect/Torch/Utils/Utils.h | 18 ++++++++ lib/Conversion/TorchToLinalg/DataMovement.cpp | 21 +++++---- .../TorchToLinalg/IndirectDataMovement.cpp | 18 ++++---- lib/Conversion/TorchToLinalg/Linear.cpp | 19 ++++---- .../TorchToLinalg/TensorScalarInterop.cpp | 4 +- .../TorchToLinalg/Uncategorized.cpp | 10 +++-- lib/Conversion/TorchToLinalg/Utils.cpp | 31 +++++++------ .../Torch/Transforms/DecomposeComplexOps.cpp | 43 +++++++++++-------- lib/Dialect/Torch/Utils/Utils.cpp | 9 ++++ 9 files changed, 112 insertions(+), 61 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h index 37aaed9cd704..f913e70345f4 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h +++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h @@ -86,6 +86,24 @@ FailureOr squeezeTensor(PatternRewriter &rewriter, Operation *op, FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, Value input, Value dim); +// In Dynamo import paths, we can assume that dynamic dimensions are strictly +// quantities and are not ambiguous with '1' symbols that can be interpreted +// to signal an expansion in various broadcasting scenarios. In the +// torch.compile eager path, this precondition is assured by guards on 0/1 +// dimension values, and on the torch.export graph-capture path, the shape +// solver guarantees this. +// +// We let lowerings assume this on a per-scope basis if the +// torch.assume_strict_symbolic_shapes unit attribute is present on any parent +// of the block. +bool isAssumingStrictSymbolicShapes(Block *scope); + +// Helper that uses the block from an OpBuilder for determining whether we +// are assuming strict symbolic shapes. +inline bool isAssumingStrictSymbolicShapes(OpBuilder &builder) { + return isAssumingStrictSymbolicShapes(builder.getBlock()); +} + } // namespace Torch } // namespace torch } // namespace mlir diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 9ec6a6006be7..6ed9d369e8e5 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -619,20 +619,23 @@ class ConvertAtenSqueezeOp : public OpConversionPattern { reassociation[0].push_back(headOnesCount++); } - // TODO: Add support for size-1 dynamic dimensions. Value one = rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); int64_t j = -1; + bool elideDynamicBroadcastDimCheck = + isAssumingStrictSymbolicShapes(rewriter); for (auto i : llvm::seq(headOnesCount, inputRank)) { if (inputType.isDynamicDim(i)) { - // Make sure that size-1 dynamic dimension does not exist. - Value dimSize = getDimOp(rewriter, loc, input, i); - Value dimSizeNotOne = rewriter.create( - loc, arith::CmpIPredicate::ne, dimSize, one); - rewriter.create( - loc, dimSizeNotOne, - rewriter.getStringAttr( - "unimplemented: size 1 dynamic dimension is not supported")); + if (!elideDynamicBroadcastDimCheck) { + // Make sure that size-1 dynamic dimension does not exist. + Value dimSize = getDimOp(rewriter, loc, input, i); + Value dimSizeNotOne = rewriter.create( + loc, arith::CmpIPredicate::ne, dimSize, one); + rewriter.create( + loc, dimSizeNotOne, + rewriter.getStringAttr( + "unimplemented: size 1 dynamic dimension is not supported")); + } ++j; } else if (inputType.getDimSize(i) != 1) { ++j; diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp index cfbac2632a28..0e89d822669f 100644 --- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp @@ -644,14 +644,16 @@ class ConvertAtenIndexTensorHackedTwinOp : public OpConversionPattern 1) { - Value cstStaticDimSize = getConstant(rewriter, loc, staticDimSize, - rewriter.getIndexType()); - auto equalToRunning = rewriter.create( - loc, arith::CmpIPredicate::eq, cstStaticDimSize, - dynamicDims[0]); - rewriter.create(loc, equalToRunning, - "mismatched size for broadcast"); + if (!isAssumingStrictSymbolicShapes(rewriter)) { + if (staticDimSize > 1) { + Value cstStaticDimSize = getConstant(rewriter, loc, staticDimSize, + rewriter.getIndexType()); + auto equalToRunning = rewriter.create( + loc, arith::CmpIPredicate::eq, cstStaticDimSize, + dynamicDims[0]); + rewriter.create(loc, equalToRunning, + "mismatched size for broadcast"); + } } broadcastedIndexShape.push_back(dynamicDims[0]); } else { diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 23528bb01f80..66380dea9a89 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -58,15 +58,18 @@ class ConvertAtenMmOp : public OpConversionPattern { } Value lhsDim0 = rewriter.create(loc, lhs, 0); - Value lhsDim1 = rewriter.create(loc, lhs, 1); - Value rhsDim0 = rewriter.create(loc, rhs, 0); Value rhsDim1 = rewriter.create(loc, rhs, 1); - Value contractingDimEqual = rewriter.create( - loc, arith::CmpIPredicate::eq, lhsDim1, rhsDim0); - rewriter.create( - loc, contractingDimEqual, - rewriter.getStringAttr( - "mismatching contracting dimension for torch.aten.mm")); + + if (!isAssumingStrictSymbolicShapes(rewriter)) { + Value lhsDim1 = rewriter.create(loc, lhs, 1); + Value rhsDim0 = rewriter.create(loc, rhs, 0); + Value contractingDimEqual = rewriter.create( + loc, arith::CmpIPredicate::eq, lhsDim1, rhsDim0); + rewriter.create( + loc, contractingDimEqual, + rewriter.getStringAttr( + "mismatching contracting dimension for torch.aten.mm")); + } Type newResultType = getTypeConverter()->convertType(op.getType()); Type elementType = newResultType.cast().getElementType(); diff --git a/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp b/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp index 262d3cf62e54..a1e8e5fb72d9 100644 --- a/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp +++ b/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp @@ -42,7 +42,9 @@ class ConvertAtenSizeIntOp : public OpConversionPattern { Value inputRank = rewriter.create( loc, rewriter.getI64IntegerAttr(type.getRank())); Value dimPositive = toPositiveDimDynamic(rewriter, loc, dim, inputRank); - assertIsValidDim(rewriter, loc, dimPositive, inputRank); + if (!isAssumingStrictSymbolicShapes(rewriter)) { + assertIsValidDim(rewriter, loc, dimPositive, inputRank); + } Value size = rewriter.create( loc, adaptor.getSelf(), castIntToIndex(rewriter, loc, dimPositive)); rewriter.replaceOp(op, castIndexToInt64(rewriter, loc, size)); diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 1d25d22720d2..3934e2f649ec 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -1477,10 +1477,12 @@ class ConvertAtenBatchNormOp : public OpConversionPattern { rewriter.getStringAttr( "expect the size of dim 0 equal to the number of features")); }; - contractingDim0EqualsNumFeatures(weight); - contractingDim0EqualsNumFeatures(bias); - contractingDim0EqualsNumFeatures(runningMean); - contractingDim0EqualsNumFeatures(runningVar); + if (!isAssumingStrictSymbolicShapes(rewriter)) { + contractingDim0EqualsNumFeatures(weight); + contractingDim0EqualsNumFeatures(bias); + contractingDim0EqualsNumFeatures(runningMean); + contractingDim0EqualsNumFeatures(runningVar); + } auto indexingMap = AffineMap::get( /*dimCount=*/inputRank, diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index 42c5d0b441cc..99b86027b8e5 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -231,7 +231,8 @@ Value torch_to_linalg::createElementwiseLinalgGeneric( // if this is the first tensor operand that didn't continue above: // take its dimension size as the size of the non-broadcasted // traversal along this dimension (this may include a dynamic size-1, - // **non-broadcasted** traversal!) + // **non-broadcasted** traversal unless if + // isAssumingStrictSymbolicShapes!) // emit error check "if the size does not match the non-broadcasted // traversal size along this dimension, error" // ``` @@ -251,6 +252,7 @@ Value torch_to_linalg::createElementwiseLinalgGeneric( auto c1 = b.create(loc, /*value=*/1); SmallVector resultShape(resultRank, c1); SmallVector indexingMaps; + bool elideDynamicBroadcastCheck = isAssumingStrictSymbolicShapes(b); for (Value tensorOperand : tensorOperands) { SmallVector exprs; auto type = tensorOperand.getType().cast(); @@ -294,11 +296,13 @@ Value torch_to_linalg::createElementwiseLinalgGeneric( // This is the check which protects against the undefined behavior of // the generated linalg op in the case of iterating two operands with // dimensions sizes that are expected to match. - auto equalToRunning = - b.create(loc, arith::CmpIPredicate::eq, - resultShape[resultDim], currentDimSize); - b.create(loc, equalToRunning, - "mismatched size for broadcast"); + if (!elideDynamicBroadcastCheck) { + auto equalToRunning = + b.create(loc, arith::CmpIPredicate::eq, + resultShape[resultDim], currentDimSize); + b.create(loc, equalToRunning, + "mismatched size for broadcast"); + } } indexingMaps.push_back(AffineMap::get( /*dimCount=*/resultRank, /*symbolCount=*/0, exprs, b.getContext())); @@ -337,6 +341,7 @@ LogicalResult torch_to_linalg::broadcastToGivenShape( Type elementType = inputType.getElementType(); Location loc = op->getLoc(); SmallVector outShape; + bool elideDynamicBroadcastCheck = isAssumingStrictSymbolicShapes(rewriter); // Create affine map and shapes for tensor initialization. SmallVector outExpr; @@ -351,12 +356,14 @@ LogicalResult torch_to_linalg::broadcastToGivenShape( Value shapeValue = broadcastToShape[i]; size_t j = i - diff; if (i < diff) { - Value isValid = rewriter.create( - loc, arith::CmpIPredicate::sge, shapeValue, zero); - rewriter.create( - loc, isValid, - rewriter.getStringAttr( - "negative values not allowed in new dimensions")); + if (!elideDynamicBroadcastCheck) { + Value isValid = rewriter.create( + loc, arith::CmpIPredicate::sge, shapeValue, zero); + rewriter.create( + loc, isValid, + rewriter.getStringAttr( + "negative values not allowed in new dimensions")); + } outShape.push_back(castIntToIndex(rewriter, loc, shapeValue)); continue; } diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 6136db09221d..0bdfca26ddc1 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -3484,11 +3484,13 @@ class DecomposeAtenAdaptiveAvgPool1dOp : rewriter.create( loc, rewriter.getI64IntegerAttr(inputShape[rank - 1]))); } else { - Value cond = rewriter.create(loc, inputSize, outputSize); - rewriter.create( - loc, cond, - "unimplemented: only support cases where input and output size are " - "equal for non-unit output size"); + if (!isAssumingStrictSymbolicShapes(rewriter)) { + Value cond = rewriter.create(loc, inputSize, outputSize); + rewriter.create( + loc, cond, + "unimplemented: only support cases where input and output size are " + "equal for non-unit output size"); + } kernelSize.push_back(constantOne); } @@ -3586,13 +3588,14 @@ class DecomposeAtenAdaptiveAvgPool2dOp loc, rewriter.getI64IntegerAttr( inputShape[rank - 2 + i]))); } else { - Value cond = rewriter.create(loc, inputHW[i], - outputShapeSizesTorchInt[i]); - rewriter.create( - loc, cond, - "unimplemented: only support cases where input and output size are " - "equal for non-unit output size"); - + if (!isAssumingStrictSymbolicShapes(rewriter)) { + Value cond = rewriter.create( + loc, inputHW[i], outputShapeSizesTorchInt[i]); + rewriter.create(loc, cond, + "unimplemented: only support cases " + "where input and output size are " + "equal for non-unit output size"); + } Value outMinusOne = rewriter.create( loc, outputShapeSizesTorchInt[i], constantOne); kernelSize.push_back( @@ -3822,13 +3825,15 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter, loc, rewriter.getF64FloatAttr(correction)); // The `correction` value should be less than or equal to `productDimSize + // 1`. - Value productDimSizePlusOne = rewriter.create( - loc, productDimSize.getType(), productDimSize, constantOne); - Value cond = - rewriter.create(loc, productDimSizePlusOne, cstCorrection); - rewriter.create( - loc, cond, - "correction value should be less than or equal to productDimSize + 1"); + if (!isAssumingStrictSymbolicShapes(rewriter)) { + Value productDimSizePlusOne = rewriter.create( + loc, productDimSize.getType(), productDimSize, constantOne); + Value cond = rewriter.create(loc, productDimSizePlusOne, + cstCorrection); + rewriter.create( + loc, cond, + "correction value should be less than or equal to productDimSize + 1"); + } Value productDimSizeSubCorrection = rewriter.create(loc, productDimSize, cstCorrection); Value result = rewriter.create(loc, newOutputType, squareSum, diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index 10c4bea67dc0..5de777763ea5 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -324,3 +324,12 @@ FailureOr Torch::unsqueezeTensor(PatternRewriter &rewriter, op->getLoc(), unsqueezedType, input, dim); return unsqueezed; } + +bool Torch::isAssumingStrictSymbolicShapes(Block *block) { + for (Operation *parentOp = block->getParentOp(); parentOp; + parentOp = parentOp->getParentOp()) { + if (parentOp->hasAttr("torch.assume_strict_symbolic_shapes")) + return true; + } + return false; +} From 1f52562d18ba40b8a93146d249615e075bdf8791 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Fri, 29 Sep 2023 15:18:03 -0700 Subject: [PATCH 2/2] Update test --- test/Conversion/TorchToLinalg/basic.mlir | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/test/Conversion/TorchToLinalg/basic.mlir b/test/Conversion/TorchToLinalg/basic.mlir index d95b7e1d87cf..470962e2494d 100644 --- a/test/Conversion/TorchToLinalg/basic.mlir +++ b/test/Conversion/TorchToLinalg/basic.mlir @@ -8,11 +8,11 @@ // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[LHS_DIM_0:.*]] = tensor.dim %[[LHS]], %[[C0]] : tensor // CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[RHS_DIM_1:.*]] = tensor.dim %[[RHS]], %[[C1]] : tensor +// CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[LHS_DIM_1:.*]] = tensor.dim %[[LHS]], %[[C1]] : tensor // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[RHS_DIM_0:.*]] = tensor.dim %[[RHS]], %[[C0]] : tensor -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[RHS_DIM_1:.*]] = tensor.dim %[[RHS]], %[[C1]] : tensor // CHECK: %[[EQ:.*]] = arith.cmpi eq, %[[LHS_DIM_1]], %[[RHS_DIM_0]] : index // CHECK: assert %[[EQ]], "mismatching contracting dimension for torch.aten.mm" // CHECK: %[[INIT_TENSOR:.*]] = tensor.empty(%[[LHS_DIM_0]], %[[RHS_DIM_1]]) : tensor @@ -29,6 +29,17 @@ func.func @torch.aten.mm$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.v // ----- +// CHECK-LABEL: func.func @torch.aten.mm$basic_strict( +// CHECK-NOT: assert +func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32> + attributes {torch.assume_strict_symbolic_shapes} +{ + %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,2],f32> + return %0 : !torch.vtensor<[?,2],f32> +} + +// ----- + // If the operands are missing dtype, we cannot lower it. func.func @torch.aten.mm$no_convert$missing_dtype(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor { // expected-error@+1 {{failed to legalize}} @@ -264,4 +275,4 @@ func.func @torch.aten.neg.bf16(%arg0: !torch.vtensor<[?,?],bf16>) -> !torch.vten func.func @torch.aten.neg.f16(%arg0: !torch.vtensor<[?,?],f16>) -> !torch.vtensor<[?,?],f16> { %0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f16> -> !torch.vtensor<[?,?],f16> return %0 : !torch.vtensor<[?,?],f16> -} \ No newline at end of file +}