18 changes: 18 additions & 0 deletions include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -86,6 +86,24 @@ FailureOr<Value> squeezeTensor(PatternRewriter &rewriter, Operation *op,
FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
Value input, Value dim);

// In Dynamo import paths, we can assume that dynamic dimensions are strict
// quantities (never 0 or 1) and are not ambiguous with '1' dimensions that
// can signal an expansion in various broadcasting scenarios. In the
// torch.compile eager path, this precondition is assured by guards on 0/1
// dimension values, and on the torch.export graph-capture path, the shape
// solver guarantees this.
//
// We let lowerings assume this on a per-scope basis if the
// torch.assume_strict_symbolic_shapes unit attribute is present on any parent
// of the block.
bool isAssumingStrictSymbolicShapes(Block *scope);

// Helper that uses the block from an OpBuilder for determining whether we
// are assuming strict symbolic shapes.
inline bool isAssumingStrictSymbolicShapes(OpBuilder &builder) {
return isAssumingStrictSymbolicShapes(builder.getBlock());
}

} // namespace Torch
} // namespace torch
} // namespace mlir
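For reference, every lowering touched in this change applies the same guard pattern: consult the helper once and emit the runtime compare/assert pair only when strict symbolic shapes are not assumed. A minimal sketch of that pattern follows; the emitDimsEqualCheck wrapper and its message are illustrative, not code from this PR.

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

using namespace mlir;

// Hypothetical wrapper: emit an equality assert between two dimension sizes,
// unless the enclosing scope opts into strict symbolic shapes.
static void emitDimsEqualCheck(OpBuilder &b, Location loc, Value lhsDim,
                               Value rhsDim) {
  // Under strict symbolic shapes the frontend already guarantees equality,
  // so both the compare and the assert can be elided.
  if (torch::Torch::isAssumingStrictSymbolicShapes(b))
    return;
  Value equal =
      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, lhsDim, rhsDim);
  b.create<cf::AssertOp>(loc, equal,
                         b.getStringAttr("mismatched dimension sizes"));
}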
21 changes: 12 additions & 9 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -619,20 +619,23 @@ class ConvertAtenSqueezeOp : public OpConversionPattern<AtenSqueezeOp> {
reassociation[0].push_back(headOnesCount++);
}

// TODO: Add support for size-1 dynamic dimensions.
Value one = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
int64_t j = -1;
bool elideDynamicBroadcastDimCheck =
isAssumingStrictSymbolicShapes(rewriter);
for (auto i : llvm::seq<int64_t>(headOnesCount, inputRank)) {
if (inputType.isDynamicDim(i)) {
// Make sure that size-1 dynamic dimension does not exist.
Value dimSize = getDimOp(rewriter, loc, input, i);
Value dimSizeNotOne = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ne, dimSize, one);
rewriter.create<cf::AssertOp>(
loc, dimSizeNotOne,
rewriter.getStringAttr(
"unimplemented: size 1 dynamic dimension is not supported"));
if (!elideDynamicBroadcastDimCheck) {
// Make sure that size-1 dynamic dimension does not exist.
Value dimSize = getDimOp(rewriter, loc, input, i);
Value dimSizeNotOne = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ne, dimSize, one);
rewriter.create<cf::AssertOp>(
loc, dimSizeNotOne,
rewriter.getStringAttr(
"unimplemented: size 1 dynamic dimension is not supported"));
}
++j;
} else if (inputType.getDimSize(i) != 1) {
++j;
18 changes: 10 additions & 8 deletions lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
@@ -644,14 +644,16 @@ class ConvertAtenIndexTensorHackedTwinOp : public OpConversionPattern<AtenIndexT
return rewriter.notifyMatchFailure(
op,
"unimplemented: index tensors with overlapping dynamic dims");
if (staticDimSize > 1) {
Value cstStaticDimSize = getConstant(rewriter, loc, staticDimSize,
rewriter.getIndexType());
auto equalToRunning = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, cstStaticDimSize,
dynamicDims[0]);
rewriter.create<cf::AssertOp>(loc, equalToRunning,
"mismatched size for broadcast");
if (!isAssumingStrictSymbolicShapes(rewriter)) {
if (staticDimSize > 1) {
Value cstStaticDimSize = getConstant(rewriter, loc, staticDimSize,
rewriter.getIndexType());
auto equalToRunning = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, cstStaticDimSize,
dynamicDims[0]);
rewriter.create<cf::AssertOp>(loc, equalToRunning,
"mismatched size for broadcast");
}
}
broadcastedIndexShape.push_back(dynamicDims[0]);
} else {
19 changes: 11 additions & 8 deletions lib/Conversion/TorchToLinalg/Linear.cpp
@@ -58,15 +58,18 @@ class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
}

Value lhsDim0 = rewriter.create<tensor::DimOp>(loc, lhs, 0);
Value lhsDim1 = rewriter.create<tensor::DimOp>(loc, lhs, 1);
Value rhsDim0 = rewriter.create<tensor::DimOp>(loc, rhs, 0);
Value rhsDim1 = rewriter.create<tensor::DimOp>(loc, rhs, 1);
Value contractingDimEqual = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, lhsDim1, rhsDim0);
rewriter.create<cf::AssertOp>(
loc, contractingDimEqual,
rewriter.getStringAttr(
"mismatching contracting dimension for torch.aten.mm"));

if (!isAssumingStrictSymbolicShapes(rewriter)) {
Value lhsDim1 = rewriter.create<tensor::DimOp>(loc, lhs, 1);
Value rhsDim0 = rewriter.create<tensor::DimOp>(loc, rhs, 0);
Value contractingDimEqual = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, lhsDim1, rhsDim0);
rewriter.create<cf::AssertOp>(
loc, contractingDimEqual,
rewriter.getStringAttr(
"mismatching contracting dimension for torch.aten.mm"));
}

Type newResultType = getTypeConverter()->convertType(op.getType());
Type elementType = newResultType.cast<TensorType>().getElementType();
4 changes: 3 additions & 1 deletion lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp
@@ -42,7 +42,9 @@ class ConvertAtenSizeIntOp : public OpConversionPattern<AtenSizeIntOp> {
Value inputRank = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(type.getRank()));
Value dimPositive = toPositiveDimDynamic(rewriter, loc, dim, inputRank);
assertIsValidDim(rewriter, loc, dimPositive, inputRank);
if (!isAssumingStrictSymbolicShapes(rewriter)) {
assertIsValidDim(rewriter, loc, dimPositive, inputRank);
}
Value size = rewriter.create<tensor::DimOp>(
loc, adaptor.getSelf(), castIntToIndex(rewriter, loc, dimPositive));
rewriter.replaceOp(op, castIndexToInt64(rewriter, loc, size));
10 changes: 6 additions & 4 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -1477,10 +1477,12 @@ class ConvertAtenBatchNormOp : public OpConversionPattern<AtenBatchNormOp> {
rewriter.getStringAttr(
"expect the size of dim 0 equal to the number of features"));
};
contractingDim0EqualsNumFeatures(weight);
contractingDim0EqualsNumFeatures(bias);
contractingDim0EqualsNumFeatures(runningMean);
contractingDim0EqualsNumFeatures(runningVar);
if (!isAssumingStrictSymbolicShapes(rewriter)) {
contractingDim0EqualsNumFeatures(weight);
contractingDim0EqualsNumFeatures(bias);
contractingDim0EqualsNumFeatures(runningMean);
contractingDim0EqualsNumFeatures(runningVar);
}

auto indexingMap = AffineMap::get(
/*dimCount=*/inputRank,
31 changes: 19 additions & 12 deletions lib/Conversion/TorchToLinalg/Utils.cpp
@@ -231,7 +231,8 @@ Value torch_to_linalg::createElementwiseLinalgGeneric(
// if this is the first tensor operand that didn't continue above:
// take its dimension size as the size of the non-broadcasted
// traversal along this dimension (this may include a dynamic size-1,
// **non-broadcasted** traversal!)
// **non-broadcasted** traversal, unless
// isAssumingStrictSymbolicShapes!)
// emit error check "if the size does not match the non-broadcasted
// traversal size along this dimension, error"
// ```
@@ -251,6 +252,7 @@
auto c1 = b.create<arith::ConstantIndexOp>(loc, /*value=*/1);
SmallVector<Value> resultShape(resultRank, c1);
SmallVector<AffineMap> indexingMaps;
bool elideDynamicBroadcastCheck = isAssumingStrictSymbolicShapes(b);
for (Value tensorOperand : tensorOperands) {
SmallVector<AffineExpr> exprs;
auto type = tensorOperand.getType().cast<RankedTensorType>();
@@ -294,11 +296,13 @@
// This is the check which protects against the undefined behavior of
// the generated linalg op in the case of iterating two operands with
// dimensions sizes that are expected to match.
auto equalToRunning =
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
resultShape[resultDim], currentDimSize);
b.create<cf::AssertOp>(loc, equalToRunning,
"mismatched size for broadcast");
if (!elideDynamicBroadcastCheck) {
auto equalToRunning =
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
resultShape[resultDim], currentDimSize);
b.create<cf::AssertOp>(loc, equalToRunning,
"mismatched size for broadcast");
}
}
indexingMaps.push_back(AffineMap::get(
/*dimCount=*/resultRank, /*symbolCount=*/0, exprs, b.getContext()));
@@ -337,6 +341,7 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
Type elementType = inputType.getElementType();
Location loc = op->getLoc();
SmallVector<Value> outShape;
bool elideDynamicBroadcastCheck = isAssumingStrictSymbolicShapes(rewriter);

// Create affine map and shapes for tensor initialization.
SmallVector<AffineExpr> outExpr;
@@ -351,12 +356,14 @@
Value shapeValue = broadcastToShape[i];
size_t j = i - diff;
if (i < diff) {
Value isValid = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, shapeValue, zero);
rewriter.create<cf::AssertOp>(
loc, isValid,
rewriter.getStringAttr(
"negative values not allowed in new dimensions"));
if (!elideDynamicBroadcastCheck) {
Value isValid = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, shapeValue, zero);
rewriter.create<cf::AssertOp>(
loc, isValid,
rewriter.getStringAttr(
"negative values not allowed in new dimensions"));
}
outShape.push_back(castIntToIndex(rewriter, loc, shapeValue));
continue;
}
43 changes: 24 additions & 19 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -3484,11 +3484,13 @@ class DecomposeAtenAdaptiveAvgPool1dOp
: rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(inputShape[rank - 1])));
} else {
Value cond = rewriter.create<AtenEqIntOp>(loc, inputSize, outputSize);
rewriter.create<RuntimeAssertOp>(
loc, cond,
"unimplemented: only support cases where input and output size are "
"equal for non-unit output size");
if (!isAssumingStrictSymbolicShapes(rewriter)) {
Value cond = rewriter.create<AtenEqIntOp>(loc, inputSize, outputSize);
rewriter.create<RuntimeAssertOp>(
loc, cond,
"unimplemented: only support cases where input and output size are "
"equal for non-unit output size");
}
kernelSize.push_back(constantOne);
}

@@ -3586,13 +3588,14 @@ class DecomposeAtenAdaptiveAvgPool2dOp
loc, rewriter.getI64IntegerAttr(
inputShape[rank - 2 + i])));
} else {
Value cond = rewriter.create<AtenEqIntOp>(loc, inputHW[i],
outputShapeSizesTorchInt[i]);
rewriter.create<RuntimeAssertOp>(
loc, cond,
"unimplemented: only support cases where input and output size are "
"equal for non-unit output size");

if (!isAssumingStrictSymbolicShapes(rewriter)) {
Value cond = rewriter.create<AtenEqIntOp>(
loc, inputHW[i], outputShapeSizesTorchInt[i]);
rewriter.create<RuntimeAssertOp>(loc, cond,
"unimplemented: only support cases "
"where input and output size are "
"equal for non-unit output size");
}
Value outMinusOne = rewriter.create<AtenSubIntOp>(
loc, outputShapeSizesTorchInt[i], constantOne);
kernelSize.push_back(
@@ -3822,13 +3825,15 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
loc, rewriter.getF64FloatAttr(correction));
// The `correction` value should be less than or equal to `productDimSize +
// 1`.
Value productDimSizePlusOne = rewriter.create<AtenAddOp>(
loc, productDimSize.getType(), productDimSize, constantOne);
Value cond =
rewriter.create<AtenGeFloatOp>(loc, productDimSizePlusOne, cstCorrection);
rewriter.create<RuntimeAssertOp>(
loc, cond,
"correction value should be less than or equal to productDimSize + 1");
if (!isAssumingStrictSymbolicShapes(rewriter)) {
Value productDimSizePlusOne = rewriter.create<AtenAddOp>(
loc, productDimSize.getType(), productDimSize, constantOne);
Value cond = rewriter.create<AtenGeFloatOp>(loc, productDimSizePlusOne,
cstCorrection);
rewriter.create<RuntimeAssertOp>(
loc, cond,
"correction value should be less than or equal to productDimSize + 1");
}
Value productDimSizeSubCorrection =
rewriter.create<AtenSubFloatOp>(loc, productDimSize, cstCorrection);
Value result = rewriter.create<AtenDivScalarOp>(loc, newOutputType, squareSum,
9 changes: 9 additions & 0 deletions lib/Dialect/Torch/Utils/Utils.cpp
@@ -324,3 +324,12 @@ FailureOr<Value> Torch::unsqueezeTensor(PatternRewriter &rewriter,
op->getLoc(), unsqueezedType, input, dim);
return unsqueezed;
}

bool Torch::isAssumingStrictSymbolicShapes(Block *block) {
for (Operation *parentOp = block->getParentOp(); parentOp;
parentOp = parentOp->getParentOp()) {
if (parentOp->hasAttr("torch.assume_strict_symbolic_shapes"))
return true;
}
return false;
}
17 changes: 14 additions & 3 deletions test/Conversion/TorchToLinalg/basic.mlir
@@ -8,11 +8,11 @@
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[LHS_DIM_0:.*]] = tensor.dim %[[LHS]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[RHS_DIM_1:.*]] = tensor.dim %[[RHS]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[LHS_DIM_1:.*]] = tensor.dim %[[LHS]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[RHS_DIM_0:.*]] = tensor.dim %[[RHS]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[RHS_DIM_1:.*]] = tensor.dim %[[RHS]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[EQ:.*]] = arith.cmpi eq, %[[LHS_DIM_1]], %[[RHS_DIM_0]] : index
// CHECK: assert %[[EQ]], "mismatching contracting dimension for torch.aten.mm"
// CHECK: %[[INIT_TENSOR:.*]] = tensor.empty(%[[LHS_DIM_0]], %[[RHS_DIM_1]]) : tensor<?x?xf32>
@@ -29,6 +29,17 @@ func.func @torch.aten.mm$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.v

// -----

// CHECK-LABEL: func.func @torch.aten.mm$basic_strict(
// CHECK-NOT: assert
func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32>
attributes {torch.assume_strict_symbolic_shapes}
{
%0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,2],f32>
return %0 : !torch.vtensor<[?,2],f32>
}

// -----

// If the operands are missing dtype, we cannot lower it.
func.func @torch.aten.mm$no_convert$missing_dtype(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
// expected-error@+1 {{failed to legalize}}
@@ -264,4 +275,4 @@ func.func @torch.aten.neg.bf16(%arg0: !torch.vtensor<[?,?],bf16>) -> !torch.vten
func.func @torch.aten.neg.f16(%arg0: !torch.vtensor<[?,?],f16>) -> !torch.vtensor<[?,?],f16> {
%0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f16> -> !torch.vtensor<[?,?],f16>
return %0 : !torch.vtensor<[?,?],f16>
}
}