[torch] Improve shape inference for dynamic shapes (#3091)
Shapes can be processed as tensors that represent the set of dimensions.
Since reshapes take a list of scalars, this can result in a single dynamic
dimension blocking inference of the adjacent static dimensions.

This pass attempts to decouple tensor computations related to shapes and
propagate the scalar values, to better support lowering scalar tensor
computations.
rsuderman committed Apr 2, 2024
1 parent 401869e commit f97cd48
Showing 8 changed files with 457 additions and 75 deletions.
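As a rough illustration of the problem described in the commit message (a standalone sketch, not code from this patch): when a reshape's target sizes are read elementwise out of a shape tensor, the constant entries are hidden behind tensor computations unless `item()` of a tensor built from a known scalar folds back to that scalar — which is what the new `AtenItemOp` folds below enable.

```cpp
// Sketch only: models dimensions as optional constants (nullopt == dynamic).
#include <cstdint>
#include <iostream>
#include <optional>
#include <string>
#include <vector>

using Dim = std::optional<int64_t>;

// Without value propagation, a size routed through a runtime shape tensor is
// treated as unknown even when the element it came from was a constant.
Dim itemWithoutFolding(const Dim &) { return std::nullopt; }

// With the new folds, item() of a tensor built from a known scalar collapses
// back to that scalar.
Dim itemWithFolding(const Dim &produced) { return produced; }

static void print(const std::vector<Dim> &dims) {
  for (const Dim &d : dims)
    std::cout << (d ? std::to_string(*d) : "?") << ' ';
  std::cout << '\n';
}

int main() {
  std::vector<Dim> shapeTensor = {std::nullopt, 4, 16}; // [batch, 4, 16]
  std::vector<Dim> before, after;
  for (const Dim &d : shapeTensor) {
    before.push_back(itemWithoutFolding(d));
    after.push_back(itemWithFolding(d));
  }
  print(before); // "? ? ?"  - every size looks dynamic
  print(after);  // "? 4 16" - only the truly dynamic batch dim stays unknown
  return 0;
}
```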
4 changes: 4 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -3616,6 +3616,7 @@ def Torch_AtenAddScalarOp : Torch_Op<"aten.add.Scalar", [
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
}

@@ -3666,6 +3667,7 @@ def Torch_AtenSubScalarOp : Torch_Op<"aten.sub.Scalar", [
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
}

@@ -3715,6 +3717,7 @@ def Torch_AtenMulScalarOp : Torch_Op<"aten.mul.Scalar", [
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
}

@@ -11065,6 +11068,7 @@ def Torch_AtenWhereScalarOp : Torch_Op<"aten.where.Scalar", [
}
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
}

def Torch_AtenWhereScalarOtherOp : Torch_Op<"aten.where.ScalarOther", [
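Three of the ops gaining `hasFolder` here are the scalar arithmetic variants. For context, their PyTorch scalar semantics — which the new folders in `TorchOps.cpp` below implement — include an `alpha` multiplier on the scalar operand for add/sub. A minimal sketch, not code from the repo:

```cpp
#include <cassert>

// aten.add.Scalar(self, other, alpha) == self + alpha * other
double addScalar(double self, double other, double alpha) {
  return self + alpha * other;
}
// aten.sub.Scalar(self, other, alpha) == self - alpha * other
double subScalar(double self, double other, double alpha) {
  return self - alpha * other;
}
// aten.mul.Scalar(self, other) == self * other (no alpha operand)
double mulScalar(double self, double other) { return self * other; }

int main() {
  assert(addScalar(2.0, 3.0, 4.0) == 14.0);
  assert(subScalar(2.0, 3.0, 4.0) == -10.0);
  assert(mulScalar(2.0, 3.0) == 6.0);
  return 0;
}
```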
35 changes: 24 additions & 11 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -1756,19 +1756,32 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
rewriter.create<Torch::AtenEqIntOp>(binder.getLoc(), dim, zero);
isZero =
rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(), isZero);
Value adjustment = zero;
int64_t inputDimsSize = dataSizes.size();
if (i < inputDimsSize) {
adjustment = rewriter.create<Torch::ConstantIntOp>(

int64_t dataRank = dataSizes.size();
if (i < dataRank) {
auto torchIntTy = rewriter.getType<Torch::IntType>();
auto int64Ty = rewriter.getIntegerType(64, true);
auto dimTy = rewriter.getType<Torch::ValueTensorType>(
ArrayRef<int64_t>(), int64Ty);
auto boolTy = rewriter.getType<Torch::ValueTensorType>(
ArrayRef<int64_t>(), rewriter.getI1Type());
Value iv = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(i));
Value inDim = rewriter.create<Torch::AtenSizeIntOp>(
binder.getLoc(), torchIntTy, data, iv);
isZero = rewriter.create<Torch::PrimNumToTensorScalarOp>(
binder.getLoc(), boolTy, isZero);
inDim = rewriter.create<Torch::PrimNumToTensorScalarOp>(
binder.getLoc(), dimTy, inDim);
dim = rewriter.create<Torch::PrimNumToTensorScalarOp>(
binder.getLoc(), dimTy, dim);
Value finalDim = rewriter.create<Torch::AtenWhereSelfOp>(
binder.getLoc(), dimTy, isZero, inDim, dim);
dim = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
dataSizes[i]));
finalDim);
}
Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
binder.getLoc(), isZero, adjustment);
Value finalDim = rewriter.create<Torch::AtenAddIntOp>(
binder.getLoc(), dim, finalOffset);
dimList.push_back(finalDim);
dimList.push_back(dim);
}
Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
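The hunk above changes how a requested dimension of 0 is resolved: instead of adding a statically baked input size scaled by `isZero`, it selects the runtime input size (`aten.size.int`) through a 0-d `aten.where.self` and extracts it with `aten.item`, so dynamic input dimensions are handled too. The underlying scalar rule — assuming this is the ONNX Reshape lowering with `allowzero = 0` — is simply:

```cpp
#include <cassert>
#include <cstdint>

// A 0 in the requested shape means "keep the input's size for that axis".
int64_t resolveReshapeDim(int64_t requested, int64_t inputDim) {
  return requested == 0 ? inputDim : requested;
}

int main() {
  assert(resolveReshapeDim(0, 7) == 7); // copy the input dimension
  assert(resolveReshapeDim(3, 7) == 3); // use the requested size
  return 0;
}
```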
31 changes: 19 additions & 12 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -852,7 +852,6 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
Location loc = op.getLoc();
Value input = adaptor.getSelf();
auto inputType = input.getType().cast<RankedTensorType>();
SmallVector<Value> inputSize = getTensorSizes(rewriter, loc, input);
int64_t inputRank = inputType.getRank();
const TypeConverter *typeConverter = getTypeConverter();
auto resultType =
@@ -893,12 +892,6 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
return rewriter.notifyMatchFailure(
op, "at most one element in size list is allowed to be -1");
}
SmallVector<Value> outputSizeInt = getTypeConvertedValues(
rewriter, loc, typeConverter, outputSizeTorchInt);
if (resultRank != (int64_t)outputSizeInt.size()) {
return rewriter.notifyMatchFailure(
op, "desired size list length mismatches with the result type rank");
}

auto [inputShape, outputShape] =
getInputAndOutputShape(op.getSelf(), outputSizeTorchInt);
@@ -975,6 +968,7 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
// For more information, see description of helper functions used in the
// `if-else` cases inside the while loop.
int64_t inputDim = 0, outputDim = 0;
SmallVector<std::pair<int64_t, int64_t>> checkDimPairs;
for (auto [nextUnchangedInput, nextUnchangedOutput] : unchangedDims) {
// Used for ensuring that we don't have an ambiguous expansion
bool assumedDynamicDimNotSplit = false;
@@ -1021,11 +1015,10 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
/// input and output dimensions in the slice statically
/// known to have the same number of elements.
} else if (inputShapeSlice[0] == kUnknownSize) {
// If the input is dynamic, assume it is not split
checkDimEqualHelper(rewriter, loc, inputSize[inputDim],
outputSizeInt[outputDim]);
// If output dimension is not dynamic, improve static information of
// input
// Defer the dynamic shape check to avoid DialectConversion assertion:
checkDimPairs.push_back(
std::pair<int64_t, int64_t>(inputDim, outputDim));

inputShape[inputDim] = outputShape[outputDim];
inputSliceIndices.push_back(0);
outputSliceIndices.push_back(0);
@@ -1073,6 +1066,20 @@ }
}
}

SmallVector<Value> inputSize = getTensorSizes(rewriter, loc, input);

SmallVector<Value> outputSizeInt = getTypeConvertedValues(
rewriter, loc, typeConverter, outputSizeTorchInt);
if (resultRank != (int64_t)outputSizeInt.size()) {
return rewriter.notifyMatchFailure(
op, "desired size list length mismatches with the result type rank");
}

for (auto [inputDim, outputDim] : checkDimPairs) {
checkDimEqualHelper(rewriter, loc, inputSize[inputDim],
outputSizeInt[outputDim]);
}

auto cast = [&](Location loc, Type t, Value v) -> Value {
return rewriter.createOrFold<tensor::CastOp>(loc, t, v);
};
17 changes: 10 additions & 7 deletions lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp
@@ -178,13 +178,16 @@ class ConvertPrimNumToTensorScalarOp
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
Value a = adaptor.getA();
Value outTensor =
rewriter
.create<tensor::EmptyOp>(loc, ArrayRef<OpFoldResult>{}, a.getType())
->getResult(0);
rewriter.replaceOpWithNewOp<linalg::FillOp>(op, a, outTensor);

RankedTensorType resultType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
Type outElementType = resultType.getElementType();
Value elemVal = adaptor.getA();
Value elemValProm =
convertScalarToDtype(rewriter, loc, elemVal, outElementType);
Value zeroDTensor =
createInitTensor(rewriter, loc, {}, outElementType, elemValProm);
rewriter.replaceOp(op, zeroDTensor);
return success();
}
};
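The rewritten `ConvertPrimNumToTensorScalarOp` now takes the element type from the converted result type rather than from the incoming scalar, converting the scalar first with `convertScalarToDtype`. The effect in miniature (a sketch, not repo code): a scalar whose storage type differs from the result tensor's element type is cast before being splatted into the 0-d tensor.

```cpp
#include <cassert>
#include <cstdint>

// e.g. an i64 torch scalar feeding a 0-d f32 tensor result
float toResultElement(int64_t scalar) { return static_cast<float>(scalar); }

int main() {
  assert(toResultElement(3) == 3.0f);
  return 0;
}
```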
128 changes: 121 additions & 7 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -1339,6 +1339,23 @@ void AtenAddScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
});
}

OpFoldResult AtenAddScalarOp::fold(FoldAdaptor adaptor) {
auto fpFold = [](llvm::ArrayRef<double> inputs) {
assert(inputs.size() == 3);
return inputs[0] + (inputs[1] * inputs[2]);
};

auto intFold = [](llvm::ArrayRef<APInt> inputs) {
assert(inputs.size() == 3);
int64_t bits = inputs[0].getBitWidth();
APInt other(bits, inputs[1].getLimitedValue());
APInt alpha(bits, inputs[2].getLimitedValue());
return inputs[0] + (other * alpha);
};

return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
}

//===----------------------------------------------------------------------===//
// AtenSubTensorOp
//===----------------------------------------------------------------------===//
@@ -1373,6 +1390,23 @@ void AtenSubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
});
}

OpFoldResult AtenSubScalarOp::fold(FoldAdaptor adaptor) {
auto fpFold = [](llvm::ArrayRef<double> inputs) {
assert(inputs.size() == 3);
return inputs[0] - (inputs[1] * inputs[2]);
};

auto intFold = [](llvm::ArrayRef<APInt> inputs) {
assert(inputs.size() == 3);
int64_t bits = inputs[0].getBitWidth();
APInt other(bits, inputs[1].getLimitedValue());
APInt alpha(bits, inputs[2].getLimitedValue());
return inputs[0] - (other * alpha);
};

return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
}

//===----------------------------------------------------------------------===//
// AtenRSubScalarOp
//===----------------------------------------------------------------------===//
@@ -1401,7 +1435,9 @@ OpFoldResult AtenMulTensorOp::fold(FoldAdaptor adaptor) {

auto intFold = [](llvm::ArrayRef<APInt> inputs) {
assert(inputs.size() == 2);
return inputs[0] * inputs[1];
int64_t bits = inputs[0].getBitWidth();
APInt other(bits, inputs[1].getLimitedValue());
return inputs[0] * other;
};

return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
@@ -1604,7 +1640,10 @@ OpFoldResult AtenLeScalarOp::fold(FoldAdaptor adaptor) {
auto fpFold = [](double lhs, double rhs) -> bool { return lhs <= rhs; };

auto intFold = [](APInt lhs, APInt rhs, bool unsign) -> bool {
return unsign ? lhs.ule(rhs) : lhs.sle(rhs);
int64_t bits = std::max(lhs.getBitWidth(), rhs.getBitWidth());
APInt lhsWiden(bits, lhs.getLimitedValue());
APInt rhsWiden(bits, rhs.getLimitedValue());
return unsign ? lhsWiden.ule(rhsWiden) : lhsWiden.sle(rhsWiden);
};

return comparisonScaleFolder(self, other, resultTy, fpFold, intFold);
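The widening added in this and the following comparison folders (Lt/Gt/Ge/Eq/Ne) exists because `llvm::APInt` comparisons require both operands to have the same bit width; comparing, say, an element of an i1 tensor against an i64 scalar attribute would otherwise trip that assertion. A standalone sketch of the idea (not repo code):

```cpp
#include <algorithm>
#include <cassert>
#include "llvm/ADT/APInt.h"

int main() {
  llvm::APInt lhs(/*numBits=*/1, 1);  // e.g. an element of a bool tensor
  llvm::APInt rhs(/*numBits=*/64, 1); // e.g. an integer scalar operand
  // lhs.ule(rhs) would assert here: the bit widths differ.
  unsigned bits = std::max(lhs.getBitWidth(), rhs.getBitWidth());
  llvm::APInt lhsWide(bits, lhs.getLimitedValue());
  llvm::APInt rhsWide(bits, rhs.getLimitedValue());
  assert(lhsWide.ule(rhsWide)); // safe: equal widths after widening
  return 0;
}
```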
@@ -1622,7 +1661,10 @@ OpFoldResult AtenLtScalarOp::fold(FoldAdaptor adaptor) {
auto fpFold = [](double lhs, double rhs) -> bool { return lhs < rhs; };

auto intFold = [](APInt lhs, APInt rhs, bool unsign) -> bool {
return unsign ? lhs.ult(rhs) : lhs.slt(rhs);
int64_t bits = std::max(lhs.getBitWidth(), rhs.getBitWidth());
APInt lhsWiden(bits, lhs.getLimitedValue());
APInt rhsWiden(bits, rhs.getLimitedValue());
return unsign ? lhsWiden.ult(rhsWiden) : lhsWiden.slt(rhsWiden);
};

return comparisonScaleFolder(self, other, resultTy, fpFold, intFold);
@@ -1640,7 +1682,10 @@ OpFoldResult AtenGtScalarOp::fold(FoldAdaptor adaptor) {
auto fpFold = [](double lhs, double rhs) -> bool { return lhs > rhs; };

auto intFold = [](APInt lhs, APInt rhs, bool unsign) -> bool {
return unsign ? lhs.ugt(rhs) : lhs.sgt(rhs);
int64_t bits = std::max(lhs.getBitWidth(), rhs.getBitWidth());
APInt lhsWiden(bits, lhs.getLimitedValue());
APInt rhsWiden(bits, rhs.getLimitedValue());
return unsign ? lhsWiden.ugt(rhsWiden) : lhsWiden.sgt(rhsWiden);
};

return comparisonScaleFolder(self, other, resultTy, fpFold, intFold);
@@ -1658,7 +1703,10 @@ OpFoldResult AtenGeScalarOp::fold(FoldAdaptor adaptor) {
auto fpFold = [](double lhs, double rhs) -> bool { return lhs >= rhs; };

auto intFold = [](APInt lhs, APInt rhs, bool unsign) -> bool {
return unsign ? lhs.uge(rhs) : lhs.sge(rhs);
int64_t bits = std::max(lhs.getBitWidth(), rhs.getBitWidth());
APInt lhsWiden(bits, lhs.getLimitedValue());
APInt rhsWiden(bits, rhs.getLimitedValue());
return unsign ? lhsWiden.uge(rhsWiden) : lhsWiden.sge(rhsWiden);
};

return comparisonScaleFolder(self, other, resultTy, fpFold, intFold);
@@ -1676,7 +1724,10 @@ OpFoldResult AtenEqScalarOp::fold(FoldAdaptor adaptor) {
auto fpFold = [](double lhs, double rhs) -> bool { return lhs == rhs; };

auto intFold = [](APInt lhs, APInt rhs, bool unsign) -> bool {
return lhs.eq(rhs);
int64_t bits = std::max(lhs.getBitWidth(), rhs.getBitWidth());
APInt lhsWiden(bits, lhs.getLimitedValue());
APInt rhsWiden(bits, rhs.getLimitedValue());
return lhsWiden.eq(rhsWiden);
};

return comparisonScaleFolder(self, other, resultTy, fpFold, intFold);
@@ -1694,7 +1745,10 @@ OpFoldResult AtenNeScalarOp::fold(FoldAdaptor adaptor) {
auto fpFold = [](double lhs, double rhs) -> bool { return lhs != rhs; };

auto intFold = [](APInt lhs, APInt rhs, bool unsign) -> bool {
return lhs.ne(rhs);
int64_t bits = std::max(lhs.getBitWidth(), rhs.getBitWidth());
APInt lhsWiden(bits, lhs.getLimitedValue());
APInt rhsWiden(bits, rhs.getLimitedValue());
return lhsWiden.ne(rhsWiden);
};

return comparisonScaleFolder(self, other, resultTy, fpFold, intFold);
@@ -1749,6 +1803,20 @@ void AtenMulScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
});
}

OpFoldResult AtenMulScalarOp::fold(FoldAdaptor adaptor) {
auto fpFold = [](llvm::ArrayRef<double> inputs) {
assert(inputs.size() == 2);
return inputs[0] * inputs[1];
};

auto intFold = [](llvm::ArrayRef<APInt> inputs) {
assert(inputs.size() == 2);
return inputs[0] * inputs[1];
};

return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
}

//===----------------------------------------------------------------------===//
// AtenDivTensorModeOp
//===----------------------------------------------------------------------===//
@@ -3808,6 +3876,15 @@ OpFoldResult AtenItemOp::fold(FoldAdaptor adaptor) {
return nullptr;
}

if (auto full = getOperand().getDefiningOp<Torch::AtenFullOp>()) {
return full.getFillValue();
}

if (auto numToTensor =
getOperand().getDefiningOp<Torch::PrimNumToTensorScalarOp>()) {
return numToTensor.getA();
}

return nullptr;
}

@@ -3984,6 +4061,9 @@ static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) {
}

OpFoldResult AtenWhereSelfOp::fold(FoldAdaptor adaptor) {
if (getSelf() == getOther())
return getSelf();

auto dense = dyn_cast_or_null<DenseElementsAttr>(adaptor.getCondition());
auto resultTy = dyn_cast<ValueTensorType>(getType());
if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes() || !dense ||
@@ -4026,6 +4106,40 @@ OpFoldResult AtenWhereScalarOp::fold(FoldAdaptor adaptor) {
return getBroadcastedAttr(valueAttr, resultTy);
}

void AtenWhereScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {

patterns.add(+[](AtenWhereScalarOp op, PatternRewriter &rewriter) {
auto cond = op.getCondition();
auto self = op.getSelf();
auto other = op.getOther();

if (self != other)
return rewriter.notifyMatchFailure(op, "differing output");

auto condTy = dyn_cast<BaseTensorType>(cond.getType());
if (!condTy || !condTy.hasSizes())
return rewriter.notifyMatchFailure(op, "output size unknown");

SmallVector<Value> dims;
auto torchIntTy = rewriter.getType<Torch::IntType>();
for (int i = 0, s = condTy.getSizes().size(); i < s; ++i) {
Value iv = rewriter.create<Torch::ConstantIntOp>(
op.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(i));
dims.push_back(rewriter.create<Torch::AtenSizeIntOp>(
op.getLoc(), torchIntTy, cond, iv));
}

Value dimsList = rewriter.create<Torch::PrimListConstructOp>(
op.getLoc(), Torch::ListType::get(torchIntTy), dims);

Value none = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
rewriter.replaceOpWithNewOp<Torch::AtenFullOp>(
op, op.getType(), dimsList, self, none, none, none, none);
return success();
});
}
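The new `AtenWhereScalarOp` canonicalization above relies on the identity that where(cond, s, s) == s for every element: when both scalar branches are equal, the result is just a tensor of the condition's shape filled with that scalar, which is why it is rebuilt as `aten.full` over the sizes gathered with `aten.size.int`. A scalar-level sketch (not repo code):

```cpp
#include <cassert>

double whereScalar(bool cond, double self, double other) {
  return cond ? self : other;
}

int main() {
  for (bool c : {false, true})
    assert(whereScalar(c, 2.5, 2.5) == 2.5); // independent of the condition
  return 0;
}
```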

//===----------------------------------------------------------------------===//
// AtenWhereScalarOtherOp
//===----------------------------------------------------------------------===//