diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
index 37aaed9cd704..f913e70345f4 100644
--- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h
+++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -86,6 +86,24 @@ FailureOr<Value> squeezeTensor(PatternRewriter &rewriter, Operation *op,
 FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
                                  Value input, Value dim);
 
+// In Dynamo import paths, we can assume that dynamic dimensions are strictly
+// quantities and are not ambiguous with '1' symbols that can be interpreted
+// to signal an expansion in various broadcasting scenarios. In the
+// torch.compile eager path, this precondition is assured by guards on 0/1
+// dimension values, and on the torch.export graph-capture path, the shape
+// solver guarantees this.
+//
+// We let lowerings assume this on a per-scope basis if the
+// torch.assume_strict_symbolic_shapes unit attribute is present on any parent
+// of the block.
+bool isAssumingStrictSymbolicShapes(Block *scope);
+
+// Helper that uses the block from an OpBuilder for determining whether we
+// are assuming strict symbolic shapes.
+inline bool isAssumingStrictSymbolicShapes(OpBuilder &builder) {
+  return isAssumingStrictSymbolicShapes(builder.getBlock());
+}
+
 } // namespace Torch
 } // namespace torch
 } // namespace mlir
diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index 9ec6a6006be7..6ed9d369e8e5 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -619,20 +619,23 @@ class ConvertAtenSqueezeOp : public OpConversionPattern<AtenSqueezeOp> {
       reassociation[0].push_back(headOnesCount++);
     }
 
-    // TODO: Add support for size-1 dynamic dimensions.
     Value one = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
     int64_t j = -1;
+    bool elideDynamicBroadcastDimCheck =
+        isAssumingStrictSymbolicShapes(rewriter);
     for (auto i : llvm::seq<int64_t>(headOnesCount, inputRank)) {
       if (inputType.isDynamicDim(i)) {
-        // Make sure that size-1 dynamic dimension does not exist.
-        Value dimSize = getDimOp(rewriter, loc, input, i);
-        Value dimSizeNotOne = rewriter.create<arith::CmpIOp>(
-            loc, arith::CmpIPredicate::ne, dimSize, one);
-        rewriter.create<cf::AssertOp>(
-            loc, dimSizeNotOne,
-            rewriter.getStringAttr(
-                "unimplemented: size 1 dynamic dimension is not supported"));
+        if (!elideDynamicBroadcastDimCheck) {
+          // Make sure that size-1 dynamic dimension does not exist.
+          Value dimSize = getDimOp(rewriter, loc, input, i);
+          Value dimSizeNotOne = rewriter.create<arith::CmpIOp>(
+              loc, arith::CmpIPredicate::ne, dimSize, one);
+          rewriter.create<cf::AssertOp>(
+              loc, dimSizeNotOne,
+              rewriter.getStringAttr(
+                  "unimplemented: size 1 dynamic dimension is not supported"));
+        }
         ++j;
       } else if (inputType.getDimSize(i) != 1) {
         ++j;
diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
index cfbac2632a28..0e89d822669f 100644
--- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
@@ -644,14 +644,16 @@ class ConvertAtenIndexTensorHackedTwinOp : public OpConversionPattern<AtenIndexTensorHackedTwinOp> {
-          if (staticDimSize > 1) {
-            Value cstStaticDimSize = getConstant(rewriter, loc, staticDimSize,
-                                                 rewriter.getIndexType());
-            auto equalToRunning = rewriter.create<arith::CmpIOp>(
-                loc, arith::CmpIPredicate::eq, cstStaticDimSize,
-                dynamicDims[0]);
-            rewriter.create<cf::AssertOp>(loc, equalToRunning,
-                                          "mismatched size for broadcast");
+          if (!isAssumingStrictSymbolicShapes(rewriter)) {
+            if (staticDimSize > 1) {
+              Value cstStaticDimSize = getConstant(rewriter, loc, staticDimSize,
+                                                   rewriter.getIndexType());
+              auto equalToRunning = rewriter.create<arith::CmpIOp>(
+                  loc, arith::CmpIPredicate::eq, cstStaticDimSize,
+                  dynamicDims[0]);
+              rewriter.create<cf::AssertOp>(loc, equalToRunning,
+                                            "mismatched size for broadcast");
+            }
           }
           broadcastedIndexShape.push_back(dynamicDims[0]);
         } else {
diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index 23528bb01f80..66380dea9a89 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -58,15 +58,18 @@ class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
     }
 
     Value lhsDim0 = rewriter.create<tensor::DimOp>(loc, lhs, 0);
-    Value lhsDim1 = rewriter.create<tensor::DimOp>(loc, lhs, 1);
-    Value rhsDim0 = rewriter.create<tensor::DimOp>(loc, rhs, 0);
     Value rhsDim1 = rewriter.create<tensor::DimOp>(loc, rhs, 1);
-    Value contractingDimEqual = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::eq, lhsDim1, rhsDim0);
-    rewriter.create<cf::AssertOp>(
-        loc, contractingDimEqual,
-        rewriter.getStringAttr(
-            "mismatching contracting dimension for torch.aten.mm"));
+
+    if (!isAssumingStrictSymbolicShapes(rewriter)) {
+      Value lhsDim1 = rewriter.create<tensor::DimOp>(loc, lhs, 1);
+      Value rhsDim0 = rewriter.create<tensor::DimOp>(loc, rhs, 0);
+      Value contractingDimEqual = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::eq, lhsDim1, rhsDim0);
+      rewriter.create<cf::AssertOp>(
+          loc, contractingDimEqual,
+          rewriter.getStringAttr(
+              "mismatching contracting dimension for torch.aten.mm"));
+    }
 
     Type newResultType = getTypeConverter()->convertType(op.getType());
     Type elementType = newResultType.cast<RankedTensorType>().getElementType();
diff --git a/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp b/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp
index 262d3cf62e54..a1e8e5fb72d9 100644
--- a/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp
+++ b/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp
@@ -42,7 +42,9 @@ class ConvertAtenSizeIntOp : public OpConversionPattern<AtenSizeIntOp> {
     Value inputRank = rewriter.create<arith::ConstantOp>(
         loc, rewriter.getI64IntegerAttr(type.getRank()));
     Value dimPositive = toPositiveDimDynamic(rewriter, loc, dim, inputRank);
-    assertIsValidDim(rewriter, loc, dimPositive, inputRank);
+    if (!isAssumingStrictSymbolicShapes(rewriter)) {
+      assertIsValidDim(rewriter, loc, dimPositive, inputRank);
+    }
     Value size = rewriter.create<tensor::DimOp>(
         loc, adaptor.getSelf(), castIntToIndex(rewriter, loc, dimPositive));
     rewriter.replaceOp(op, castIndexToInt64(rewriter, loc, size));
   }
diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index 1d25d22720d2..3934e2f649ec 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -1477,10 +1477,12 @@ class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
           rewriter.getStringAttr(
               "expect the size of dim 0 equal to the number of features"));
     };
-    contractingDim0EqualsNumFeatures(weight);
-    contractingDim0EqualsNumFeatures(bias);
-    contractingDim0EqualsNumFeatures(runningMean);
-    contractingDim0EqualsNumFeatures(runningVar);
+    if (!isAssumingStrictSymbolicShapes(rewriter)) {
+      contractingDim0EqualsNumFeatures(weight);
+      contractingDim0EqualsNumFeatures(bias);
+      contractingDim0EqualsNumFeatures(runningMean);
+      contractingDim0EqualsNumFeatures(runningVar);
+    }
 
     auto indexingMap = AffineMap::get(
         /*dimCount=*/inputRank,
diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp
index 42c5d0b441cc..99b86027b8e5 100644
--- a/lib/Conversion/TorchToLinalg/Utils.cpp
+++ b/lib/Conversion/TorchToLinalg/Utils.cpp
@@ -231,7 +231,8 @@ Value torch_to_linalg::createElementwiseLinalgGeneric(
  //       if this is the first tensor operand that didn't continue above:
  //         take its dimension size as the size of the non-broadcasted
  //         traversal along this dimension (this may include a dynamic size-1,
- //         **non-broadcasted** traversal!)
+ //         **non-broadcasted** traversal unless
+ //         isAssumingStrictSymbolicShapes!)
  //       emit error check "if the size does not match the non-broadcasted
  //       traversal size along this dimension, error"
  // ```
@@ -251,6 +252,7 @@ Value torch_to_linalg::createElementwiseLinalgGeneric(
   auto c1 = b.create<arith::ConstantIndexOp>(loc, /*value=*/1);
   SmallVector<Value> resultShape(resultRank, c1);
   SmallVector<AffineMap> indexingMaps;
+  bool elideDynamicBroadcastCheck = isAssumingStrictSymbolicShapes(b);
   for (Value tensorOperand : tensorOperands) {
     SmallVector<AffineExpr> exprs;
     auto type = tensorOperand.getType().cast<RankedTensorType>();
@@ -294,11 +296,13 @@ Value torch_to_linalg::createElementwiseLinalgGeneric(
       // This is the check which protects against the undefined behavior of
       // the generated linalg op in the case of iterating two operands with
       // dimensions sizes that are expected to match.
-      auto equalToRunning =
-          b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
-                                  resultShape[resultDim], currentDimSize);
-      b.create<cf::AssertOp>(loc, equalToRunning,
-                             "mismatched size for broadcast");
+      if (!elideDynamicBroadcastCheck) {
+        auto equalToRunning =
+            b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+                                    resultShape[resultDim], currentDimSize);
+        b.create<cf::AssertOp>(loc, equalToRunning,
+                               "mismatched size for broadcast");
+      }
     }
     indexingMaps.push_back(AffineMap::get(
         /*dimCount=*/resultRank, /*symbolCount=*/0, exprs, b.getContext()));
@@ -337,6 +341,7 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
   Type elementType = inputType.getElementType();
   Location loc = op->getLoc();
   SmallVector<Value> outShape;
+  bool elideDynamicBroadcastCheck = isAssumingStrictSymbolicShapes(rewriter);
 
   // Create affine map and shapes for tensor initialization.
  SmallVector<AffineExpr> outExpr;
@@ -351,12 +356,14 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
     Value shapeValue = broadcastToShape[i];
     size_t j = i - diff;
     if (i < diff) {
-      Value isValid = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::sge, shapeValue, zero);
-      rewriter.create<cf::AssertOp>(
-          loc, isValid,
-          rewriter.getStringAttr(
-              "negative values not allowed in new dimensions"));
+      if (!elideDynamicBroadcastCheck) {
+        Value isValid = rewriter.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::sge, shapeValue, zero);
+        rewriter.create<cf::AssertOp>(
+            loc, isValid,
+            rewriter.getStringAttr(
+                "negative values not allowed in new dimensions"));
+      }
       outShape.push_back(castIntToIndex(rewriter, loc, shapeValue));
       continue;
     }
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 6136db09221d..0bdfca26ddc1 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -3484,11 +3484,13 @@ class DecomposeAtenAdaptiveAvgPool1dOp
           rewriter.create<Torch::ConstantIntOp>(
               loc, rewriter.getI64IntegerAttr(inputShape[rank - 1])));
     } else {
-      Value cond = rewriter.create<AtenEqIntOp>(loc, inputSize, outputSize);
-      rewriter.create<RuntimeAssertOp>(
-          loc, cond,
-          "unimplemented: only support cases where input and output size are "
-          "equal for non-unit output size");
+      if (!isAssumingStrictSymbolicShapes(rewriter)) {
+        Value cond = rewriter.create<AtenEqIntOp>(loc, inputSize, outputSize);
+        rewriter.create<RuntimeAssertOp>(
+            loc, cond,
+            "unimplemented: only support cases where input and output size are "
+            "equal for non-unit output size");
+      }
       kernelSize.push_back(constantOne);
     }
 
@@ -3586,13 +3588,14 @@ class DecomposeAtenAdaptiveAvgPool2dOp
                 loc, rewriter.getI64IntegerAttr(
                          inputShape[rank - 2 + i])));
       } else {
-        Value cond = rewriter.create<AtenEqIntOp>(loc, inputHW[i],
-                                                  outputShapeSizesTorchInt[i]);
-        rewriter.create<RuntimeAssertOp>(
-            loc, cond,
-            "unimplemented: only support cases where input and output size are "
-            "equal for non-unit output size");
-
+        if (!isAssumingStrictSymbolicShapes(rewriter)) {
+          Value cond = rewriter.create<AtenEqIntOp>(
+              loc, inputHW[i], outputShapeSizesTorchInt[i]);
+          rewriter.create<RuntimeAssertOp>(loc, cond,
+                                           "unimplemented: only support cases "
+                                           "where input and output size are "
+                                           "equal for non-unit output size");
+        }
         Value outMinusOne = rewriter.create<AtenSubIntOp>(
             loc, outputShapeSizesTorchInt[i], constantOne);
         kernelSize.push_back(
@@ -3822,13 +3825,15 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
       loc, rewriter.getF64FloatAttr(correction));
   // The `correction` value should be less than or equal to `productDimSize +
   // 1`.
-  Value productDimSizePlusOne = rewriter.create(
-      loc, productDimSize.getType(), productDimSize, constantOne);
-  Value cond =
-      rewriter.create(loc, productDimSizePlusOne, cstCorrection);
-  rewriter.create(
-      loc, cond,
-      "correction value should be less than or equal to productDimSize + 1");
+  if (!isAssumingStrictSymbolicShapes(rewriter)) {
+    Value productDimSizePlusOne = rewriter.create(
+        loc, productDimSize.getType(), productDimSize, constantOne);
+    Value cond = rewriter.create(loc, productDimSizePlusOne,
+                                 cstCorrection);
+    rewriter.create(
+        loc, cond,
+        "correction value should be less than or equal to productDimSize + 1");
+  }
   Value productDimSizeSubCorrection =
       rewriter.create(loc, productDimSize, cstCorrection);
   Value result = rewriter.create(loc, newOutputType, squareSum,
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index 10c4bea67dc0..5de777763ea5 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -324,3 +324,12 @@ FailureOr<Value> Torch::unsqueezeTensor(PatternRewriter &rewriter,
       op->getLoc(), unsqueezedType, input, dim);
   return unsqueezed;
 }
+
+bool Torch::isAssumingStrictSymbolicShapes(Block *block) {
+  for (Operation *parentOp = block->getParentOp(); parentOp;
+       parentOp = parentOp->getParentOp()) {
+    if (parentOp->hasAttr("torch.assume_strict_symbolic_shapes"))
+      return true;
+  }
+  return false;
+}
diff --git a/test/Conversion/TorchToLinalg/basic.mlir b/test/Conversion/TorchToLinalg/basic.mlir
index d95b7e1d87cf..470962e2494d 100644
--- a/test/Conversion/TorchToLinalg/basic.mlir
+++ b/test/Conversion/TorchToLinalg/basic.mlir
@@ -8,11 +8,11 @@
 // CHECK:         %[[C0:.*]] = arith.constant 0 : index
 // CHECK:         %[[LHS_DIM_0:.*]] = tensor.dim %[[LHS]], %[[C0]] : tensor<?x?xf32>
 // CHECK:         %[[C1:.*]] = arith.constant 1 : index
+// CHECK:         %[[RHS_DIM_1:.*]] = tensor.dim %[[RHS]], %[[C1]] : tensor<?x?xf32>
+// CHECK:         %[[C1:.*]] = arith.constant 1 : index
 // CHECK:         %[[LHS_DIM_1:.*]] = tensor.dim %[[LHS]], %[[C1]] : tensor<?x?xf32>
 // CHECK:         %[[C0:.*]] = arith.constant 0 : index
 // CHECK:         %[[RHS_DIM_0:.*]] = tensor.dim %[[RHS]], %[[C0]] : tensor<?x?xf32>
-// CHECK:         %[[C1:.*]] = arith.constant 1 : index
-// CHECK:         %[[RHS_DIM_1:.*]] = tensor.dim %[[RHS]], %[[C1]] : tensor<?x?xf32>
 // CHECK:         %[[EQ:.*]] = arith.cmpi eq, %[[LHS_DIM_1]], %[[RHS_DIM_0]] : index
 // CHECK:         assert %[[EQ]], "mismatching contracting dimension for torch.aten.mm"
 // CHECK:         %[[INIT_TENSOR:.*]] = tensor.empty(%[[LHS_DIM_0]], %[[RHS_DIM_1]]) : tensor<?x?xf32>
@@ -29,6 +29,17 @@ func.func @torch.aten.mm$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.v
 
 // -----
 
+// CHECK-LABEL: func.func @torch.aten.mm$basic_strict(
+// CHECK-NOT:     assert
+func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32>
+  attributes {torch.assume_strict_symbolic_shapes}
+{
+  %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,2],f32>
+  return %0 : !torch.vtensor<[?,2],f32>
+}
+
+// -----
+
 // If the operands are missing dtype, we cannot lower it.
 func.func @torch.aten.mm$no_convert$missing_dtype(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
   // expected-error@+1 {{failed to legalize}}
@@ -264,4 +275,4 @@ func.func @torch.aten.neg.bf16(%arg0: !torch.vtensor<[?,?],bf16>) -> !torch.vten
 func.func @torch.aten.neg.f16(%arg0: !torch.vtensor<[?,?],f16>) -> !torch.vtensor<[?,?],f16> {
   %0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f16> -> !torch.vtensor<[?,?],f16>
   return %0 : !torch.vtensor<[?,?],f16>
-}
\ No newline at end of file
+}
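
Usage note (illustrative sketch, not part of the patch): because Torch::isAssumingStrictSymbolicShapes walks every ancestor op of the current block, the unit attribute does not have to be attached to each function as in the new test above; annotating an enclosing op such as the module should opt in everything nested under it. The function name @mm and the value names below are hypothetical, chosen only to show the shape of such an annotation.

module attributes {torch.assume_strict_symbolic_shapes} {
  // Hypothetical function: any func nested under the annotated module is
  // covered, so the mm lowering may elide the contracting-dimension assert
  // shown in Linear.cpp above.
  func.func @mm(%lhs: !torch.vtensor<[?,?],f32>, %rhs: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
    %0 = torch.aten.mm %lhs, %rhs : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
    return %0 : !torch.vtensor<[?,?],f32>
  }
}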