Skip to content

Commit

Permalink
Merge 1727460 into 2ff4102
Browse files Browse the repository at this point in the history
  • Loading branch information
antiagainst committed Jun 13, 2024
2 parents 2ff4102 + 1727460 commit dddd1e0
Show file tree
Hide file tree
Showing 35 changed files with 125 additions and 123 deletions.
38 changes: 1 addition & 37 deletions compiler/plugins/input/StableHLO/Conversion/LegalizeCHLO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "stablehlo/dialect/BroadcastUtils.h"
Expand Down Expand Up @@ -443,38 +442,6 @@ struct ConvertSelectOp final
}
};

struct ConvertDynamicReshapeOp final
: OpRewritePattern<mlir::chlo::DynamicReshapeOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::chlo::DynamicReshapeOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
TypedValue<TensorType> tensor = op.getOperand();
TypedValue<RankedTensorType> shape = op.getOutputShape();

auto shapeTy = cast<ShapedType>(shape.getType());
auto resultTy = cast<ShapedType>(op.getType());

Value inputShape = rewriter.create<shape::ShapeOfOp>(loc, tensor);
Value numEls = rewriter.create<shape::NumElementsOp>(loc, inputShape);
Value cstr =
rewriter.create<mlir::stablehlo::CstrReshapableOp>(loc, numEls, shape);
rewriter.replaceOpWithNewOp<shape::AssumingOp>(
op, cstr, [&](OpBuilder &b, Location l) {
Value computedShape =
b.create<mlir::stablehlo::ComputeReshapeShapeOp>(l, shapeTy,
numEls, shape);
SmallVector<Value> result;
result.push_back(b.create<mlir::stablehlo::DynamicReshapeOp>(
l, resultTy, tensor, computedShape));
return result;
});

return success();
}
};

//===----------------------------------------------------------------------===//
// Decomposition Patterns.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2192,7 +2159,6 @@ struct LegalizeChlo final : impl::LegalizeChloBase<LegalizeChlo> {
ConversionTarget conversionTarget(getContext());
RewritePatternSet conversionPatterns(ctx);
conversionTarget.addIllegalDialect<chlo::ChloDialect>();
conversionTarget.addLegalOp<chlo::MinimumBroadcastShapesOp>();
conversionTarget.addLegalDialect<
mlir::stablehlo::StablehloDialect, mlir::arith::ArithDialect,
mlir::shape::ShapeDialect, mlir::scf::SCFDialect,
Expand Down Expand Up @@ -2239,9 +2205,7 @@ static void populateBroadcastingPatterns(MLIRContext *context,
context, patterns, 10);
populateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
context, patterns, 5);
patterns
->add<ConvertConstantLikeOp, ConvertDynamicReshapeOp, ConvertSelectOp>(
context);
patterns->add<ConvertConstantLikeOp, ConvertSelectOp>(context);
}

static void populateDecompositionPatterns(MLIRContext *context,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,8 @@ struct ScatterImplicitBatch final

auto newDimNumbers = mlir::stablehlo::ScatterDimensionNumbersAttr::get(
op.getContext(), newUpdateWindowDims,
dimNumbers.getInsertedWindowDims(),
dimNumbers.getInsertedWindowDims(), dimNumbers.getInputBatchingDims(),
dimNumbers.getScatterIndicesBatchingDims(),
dimNumbers.getScatterDimsToOperandDims(),
dimNumbers.getIndexVectorDim() + 1);

Expand Down Expand Up @@ -700,7 +701,8 @@ struct ScatterCollapseBatch final

auto newDimNumbers = mlir::stablehlo::ScatterDimensionNumbersAttr::get(
op.getContext(), newUpdatedWindowDims,
dimNumbers.getInsertedWindowDims(),
dimNumbers.getInsertedWindowDims(), dimNumbers.getInputBatchingDims(),
dimNumbers.getScatterIndicesBatchingDims(),
dimNumbers.getScatterDimsToOperandDims(),
/*indexVectorDim=*/1);

Expand Down Expand Up @@ -801,7 +803,8 @@ struct ScatterBatchFirst final : OpRewritePattern<mlir::stablehlo::ScatterOp> {

auto newDimNumbers = mlir::stablehlo::ScatterDimensionNumbersAttr::get(
op.getContext(), newUpdatedWindowDims,
dimNumbers.getInsertedWindowDims(),
dimNumbers.getInsertedWindowDims(), dimNumbers.getInputBatchingDims(),
dimNumbers.getScatterIndicesBatchingDims(),
dimNumbers.getScatterDimsToOperandDims(),
/*indexVectorDim=*/indexVectorDim);

Expand Down Expand Up @@ -939,6 +942,8 @@ struct ScatterMaterializeInsertedDim final

auto newDimNumbers = mlir::stablehlo::ScatterDimensionNumbersAttr::get(
op.getContext(), newUpdatedWindowDims, newInsertedWindowDims,
dimNumbers.getInputBatchingDims(),
dimNumbers.getScatterIndicesBatchingDims(),
dimNumbers.getScatterDimsToOperandDims(),
/*indexVectorDim=*/1);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ struct GenericTypeConvert final : ConversionPattern {
return rewriter.notifyMatchFailure(op,
"argument type conversion failed");
}
rewriter.applySignatureConversion(newRegion, result);
rewriter.applySignatureConversion(&newRegion->front(), result);
}
Operation *newOp = rewriter.create(state);
rewriter.replaceOp(op, newOp->getResults());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1653,7 +1653,7 @@ struct MapOpToGenericConverter final
}
signatureConverter.addInputs(resultType.getElementType());

rewriter.applySignatureConversion(&region, signatureConverter,
rewriter.applySignatureConversion(&region.front(), signatureConverter,
getTypeConverter());
rewriter.replaceOp(op, linalgOp.getResults());
return success();
Expand Down Expand Up @@ -1706,7 +1706,7 @@ struct MapOpToMapConverter final : OpConversionPattern<mlir::stablehlo::MapOp> {
signatureConverter.addInputs(idx, convertedTy);
}

rewriter.applySignatureConversion(&region, signatureConverter,
rewriter.applySignatureConversion(&region.front(), signatureConverter,
getTypeConverter());
auto result = rewriter.createOrFold<tensor::CastOp>(loc, resultType,
linalgOp.getResults());
Expand Down Expand Up @@ -2073,8 +2073,8 @@ struct SelectAndScatterNoOverlapConverter final
reduceSignConverter.addInputs(srcETy);
reduceSignConverter.addInputs(1, destETy);
reduceSignConverter.addInputs(indexETy);
rewriter.applySignatureConversion(&reduceRegion, reduceSignConverter,
getTypeConverter());
rewriter.applySignatureConversion(&reduceRegion.front(),
reduceSignConverter, getTypeConverter());

// Grab the terminator and use the turned value to now select the
// correct index and value.
Expand Down Expand Up @@ -2179,8 +2179,8 @@ struct SelectAndScatterNoOverlapConverter final
scatterSignConverter.addInputs(indexETy);
scatterSignConverter.addInputs(0, sourceTy.getElementType());
scatterSignConverter.addInputs(1, sourceTy.getElementType());
rewriter.applySignatureConversion(&scatterRegion, scatterSignConverter,
getTypeConverter());
rewriter.applySignatureConversion(&scatterRegion.front(),
scatterSignConverter, getTypeConverter());

auto &scatterBlock = scatterRegion.front();
auto scatterTerminator = scatterBlock.getTerminator();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ struct SortOpConversion final : OpConversionPattern<mlir::stablehlo::SortOp> {
idx, getTypeConverter()->convertType(
getElementTypeOrSelf(argument.getType())));
}
rewriter.applySignatureConversion(&region, signature_converter);
rewriter.applySignatureConversion(&region.front(), signature_converter);

rewriter.replaceOp(op, sortOp->getResults());
return success();
Expand Down Expand Up @@ -281,7 +281,7 @@ struct ScatterOpConversion final
// where output[O] maps to block args #1 in linalg_ext.scatter ops.
signatureConverter.addInputs(1, argType);
signatureConverter.addInputs(0, argType);
rewriter.applySignatureConversion(&region, signatureConverter);
rewriter.applySignatureConversion(&region.front(), signatureConverter);

rewriter.replaceOp(op, scatterOp->getResults());
return success();
Expand Down Expand Up @@ -598,7 +598,8 @@ struct ScanOpConversion final
TypeConverter::SignatureConversion signatureConverter(2);
signatureConverter.addInputs(0, input0Ty.getElementType());
signatureConverter.addInputs(1, init0Ty.getElementType());
rewriter.applySignatureConversion(&scanOp.getRegion(), signatureConverter);
rewriter.applySignatureConversion(&scanOp.getRegion().front(),
signatureConverter);

rewriter.replaceOp(op, scanOp.getResult(0));
return success();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ struct ReduceOpToGenericConverter final
cast<ShapedType>(val.getType()).getElementType()));
}

rewriter.applySignatureConversion(&region, signatureConverter,
rewriter.applySignatureConversion(&region.front(), signatureConverter,
getTypeConverter());
rewriter.replaceOp(op, linalgOp.getResults());
return success();
Expand Down Expand Up @@ -301,7 +301,7 @@ struct ReduceOpToReduceConverter final
// type for new operand number 'idx' + linalgOp.getNumInputs()
typeConverter->convertType(val.getElementType()));
}
rewriter.applySignatureConversion(&region, signatureConverter,
rewriter.applySignatureConversion(&region.front(), signatureConverter,
getTypeConverter());

// Cast the result to the correct type.
Expand Down Expand Up @@ -470,7 +470,7 @@ struct ReduceWindowOpOnTensorsGenericConversion final
i, cast<ShapedType>(input.getType()).getElementType());
}

rewriter.applySignatureConversion(&region, signatureConverter,
rewriter.applySignatureConversion(&region.front(), signatureConverter,
getTypeConverter());
rewriter.replaceOp(op, linalgOp.getResults());
return success();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class GenericTypeConvert : public ConversionPattern {
TypeConverter::SignatureConversion result(newRegion->getNumArguments());
(void)getTypeConverter()->convertSignatureArgs(
newRegion->getArgumentTypes(), result);
rewriter.applySignatureConversion(newRegion, result);
rewriter.applySignatureConversion(&newRegion->front(), result);
}
Operation *newOp = rewriter.create(state);
rewriter.replaceOp(op, newOp->getResults());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class GenericTypeConvert : public ConversionPattern {
TypeConverter::SignatureConversion result(newRegion->getNumArguments());
(void)getTypeConverter()->convertSignatureArgs(
newRegion->getArgumentTypes(), result);
rewriter.applySignatureConversion(newRegion, result);
rewriter.applySignatureConversion(&newRegion->front(), result);
}
Operation *newOp = rewriter.create(state);
rewriter.replaceOp(op, newOp->getResults());
Expand Down
4 changes: 2 additions & 2 deletions compiler/plugins/input/Torch/torch-mlir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ iree_tablegen_library(
TD_FILE
"${TORCH_MLIR_ROOT_DIR}/include/torch-mlir/Dialect/Torch/IR/TorchOps.td"
OUTS
-gen-dialect-decls Dialect/Torch/IR/TorchDialect.h.inc
-gen-dialect-defs Dialect/Torch/IR/TorchDialect.cpp.inc
-gen-dialect-decls -dialect=torch Dialect/Torch/IR/TorchDialect.h.inc
-gen-dialect-defs -dialect=torch Dialect/Torch/IR/TorchDialect.cpp.inc
-gen-op-decls Dialect/Torch/IR/TorchOps.h.inc
-gen-op-defs Dialect/Torch/IR/TorchOps.cpp.inc
)
Expand Down
24 changes: 24 additions & 0 deletions compiler/src/iree/compiler/API/api_exports.c
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ extern void mlirAttributeIsADictionary();
extern void mlirAttributeIsAElements();
extern void mlirAttributeIsAFlatSymbolRef();
extern void mlirAttributeIsAFloat();
extern void mlirAttributeIsAGPUObjectAttr();
extern void mlirAttributeIsAInteger();
extern void mlirAttributeIsAIntegerSet();
extern void mlirAttributeIsALocation();
Expand All @@ -181,6 +182,7 @@ extern void mlirBlockCreate();
extern void mlirBlockDestroy();
extern void mlirBlockDetach();
extern void mlirBlockEqual();
extern void mlirBlockEraseArgument();
extern void mlirBlockGetArgument();
extern void mlirBlockGetFirstOperation();
extern void mlirBlockGetNextInRegion();
Expand Down Expand Up @@ -355,6 +357,12 @@ extern void mlirFunctionTypeGetNumInputs();
extern void mlirFunctionTypeGetNumResults();
extern void mlirFunctionTypeGetResult();
extern void mlirFunctionTypeGetTypeID();
extern void mlirGPUObjectAttrGet();
extern void mlirGPUObjectAttrGetFormat();
extern void mlirGPUObjectAttrGetObject();
extern void mlirGPUObjectAttrGetProperties();
extern void mlirGPUObjectAttrGetTarget();
extern void mlirGPUObjectAttrHasProperties();
extern void mlirGetDialectHandle__iree_input__();
extern void mlirGetDialectHandle__transform__();
extern void mlirIREELinalgTransformRegisterPasses();
Expand Down Expand Up @@ -399,6 +407,7 @@ extern void mlirIntegerTypeIsSignless();
extern void mlirIntegerTypeIsUnsigned();
extern void mlirIntegerTypeSignedGet();
extern void mlirIntegerTypeUnsignedGet();
extern void mlirIsCurrentDebugType();
extern void mlirIsGlobalDebugEnabled();
extern void mlirLinalgFillBuiltinNamedOpRegion();
extern void mlirLlvmThreadPoolCreate();
Expand All @@ -422,6 +431,7 @@ extern void mlirMemRefTypeGetLayout();
extern void mlirMemRefTypeGetMemorySpace();
extern void mlirMemRefTypeGetStridesAndOffset();
extern void mlirMemRefTypeGetTypeID();
extern void mlirMergeSymbolsIntoFromClone();
extern void mlirModuleCreateEmpty();
extern void mlirModuleCreateParse();
extern void mlirModuleDestroy();
Expand Down Expand Up @@ -547,6 +557,8 @@ extern void mlirRegionInsertOwnedBlockBefore();
extern void mlirRegionTakeBody();
extern void mlirRegisterGPUPasses();
extern void mlirRegisterLinalgPasses();
extern void mlirSetGlobalDebugType();
extern void mlirSetGlobalDebugTypes();
extern void mlirShapedTypeGetDimSize();
extern void mlirShapedTypeGetDynamicSize();
extern void mlirShapedTypeGetDynamicStrideOrOffset();
Expand Down Expand Up @@ -854,6 +866,7 @@ uintptr_t __iree_compiler_hidden_force_extern() {
x += (uintptr_t)&mlirAttributeIsAElements;
x += (uintptr_t)&mlirAttributeIsAFlatSymbolRef;
x += (uintptr_t)&mlirAttributeIsAFloat;
x += (uintptr_t)&mlirAttributeIsAGPUObjectAttr;
x += (uintptr_t)&mlirAttributeIsAInteger;
x += (uintptr_t)&mlirAttributeIsAIntegerSet;
x += (uintptr_t)&mlirAttributeIsALocation;
Expand All @@ -877,6 +890,7 @@ uintptr_t __iree_compiler_hidden_force_extern() {
x += (uintptr_t)&mlirBlockDestroy;
x += (uintptr_t)&mlirBlockDetach;
x += (uintptr_t)&mlirBlockEqual;
x += (uintptr_t)&mlirBlockEraseArgument;
x += (uintptr_t)&mlirBlockGetArgument;
x += (uintptr_t)&mlirBlockGetFirstOperation;
x += (uintptr_t)&mlirBlockGetNextInRegion;
Expand Down Expand Up @@ -1051,6 +1065,12 @@ uintptr_t __iree_compiler_hidden_force_extern() {
x += (uintptr_t)&mlirFunctionTypeGetNumResults;
x += (uintptr_t)&mlirFunctionTypeGetResult;
x += (uintptr_t)&mlirFunctionTypeGetTypeID;
x += (uintptr_t)&mlirGPUObjectAttrGet;
x += (uintptr_t)&mlirGPUObjectAttrGetFormat;
x += (uintptr_t)&mlirGPUObjectAttrGetObject;
x += (uintptr_t)&mlirGPUObjectAttrGetProperties;
x += (uintptr_t)&mlirGPUObjectAttrGetTarget;
x += (uintptr_t)&mlirGPUObjectAttrHasProperties;
x += (uintptr_t)&mlirGetDialectHandle__iree_input__;
x += (uintptr_t)&mlirGetDialectHandle__transform__;
x += (uintptr_t)&mlirIREELinalgTransformRegisterPasses;
Expand Down Expand Up @@ -1095,6 +1115,7 @@ uintptr_t __iree_compiler_hidden_force_extern() {
x += (uintptr_t)&mlirIntegerTypeIsUnsigned;
x += (uintptr_t)&mlirIntegerTypeSignedGet;
x += (uintptr_t)&mlirIntegerTypeUnsignedGet;
x += (uintptr_t)&mlirIsCurrentDebugType;
x += (uintptr_t)&mlirIsGlobalDebugEnabled;
x += (uintptr_t)&mlirLinalgFillBuiltinNamedOpRegion;
x += (uintptr_t)&mlirLlvmThreadPoolCreate;
Expand All @@ -1118,6 +1139,7 @@ uintptr_t __iree_compiler_hidden_force_extern() {
x += (uintptr_t)&mlirMemRefTypeGetMemorySpace;
x += (uintptr_t)&mlirMemRefTypeGetStridesAndOffset;
x += (uintptr_t)&mlirMemRefTypeGetTypeID;
x += (uintptr_t)&mlirMergeSymbolsIntoFromClone;
x += (uintptr_t)&mlirModuleCreateEmpty;
x += (uintptr_t)&mlirModuleCreateParse;
x += (uintptr_t)&mlirModuleDestroy;
Expand Down Expand Up @@ -1243,6 +1265,8 @@ uintptr_t __iree_compiler_hidden_force_extern() {
x += (uintptr_t)&mlirRegionTakeBody;
x += (uintptr_t)&mlirRegisterGPUPasses;
x += (uintptr_t)&mlirRegisterLinalgPasses;
x += (uintptr_t)&mlirSetGlobalDebugType;
x += (uintptr_t)&mlirSetGlobalDebugTypes;
x += (uintptr_t)&mlirShapedTypeGetDimSize;
x += (uintptr_t)&mlirShapedTypeGetDynamicSize;
x += (uintptr_t)&mlirShapedTypeGetDynamicStrideOrOffset;
Expand Down
12 changes: 12 additions & 0 deletions compiler/src/iree/compiler/API/api_exports.def
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ EXPORTS
mlirAttributeIsAElements
mlirAttributeIsAFlatSymbolRef
mlirAttributeIsAFloat
mlirAttributeIsAGPUObjectAttr
mlirAttributeIsAInteger
mlirAttributeIsAIntegerSet
mlirAttributeIsALocation
Expand All @@ -173,6 +174,7 @@ EXPORTS
mlirBlockDestroy
mlirBlockDetach
mlirBlockEqual
mlirBlockEraseArgument
mlirBlockGetArgument
mlirBlockGetFirstOperation
mlirBlockGetNextInRegion
Expand Down Expand Up @@ -347,6 +349,12 @@ EXPORTS
mlirFunctionTypeGetNumResults
mlirFunctionTypeGetResult
mlirFunctionTypeGetTypeID
mlirGPUObjectAttrGet
mlirGPUObjectAttrGetFormat
mlirGPUObjectAttrGetObject
mlirGPUObjectAttrGetProperties
mlirGPUObjectAttrGetTarget
mlirGPUObjectAttrHasProperties
mlirGetDialectHandle__iree_input__
mlirGetDialectHandle__transform__
mlirIREELinalgTransformRegisterPasses
Expand Down Expand Up @@ -391,6 +399,7 @@ EXPORTS
mlirIntegerTypeIsUnsigned
mlirIntegerTypeSignedGet
mlirIntegerTypeUnsignedGet
mlirIsCurrentDebugType
mlirIsGlobalDebugEnabled
mlirLinalgFillBuiltinNamedOpRegion
mlirLlvmThreadPoolCreate
Expand All @@ -414,6 +423,7 @@ EXPORTS
mlirMemRefTypeGetMemorySpace
mlirMemRefTypeGetStridesAndOffset
mlirMemRefTypeGetTypeID
mlirMergeSymbolsIntoFromClone
mlirModuleCreateEmpty
mlirModuleCreateParse
mlirModuleDestroy
Expand Down Expand Up @@ -539,6 +549,8 @@ EXPORTS
mlirRegionTakeBody
mlirRegisterGPUPasses
mlirRegisterLinalgPasses
mlirSetGlobalDebugType
mlirSetGlobalDebugTypes
mlirShapedTypeGetDimSize
mlirShapedTypeGetDynamicSize
mlirShapedTypeGetDynamicStrideOrOffset
Expand Down
Loading

0 comments on commit dddd1e0

Please sign in to comment.