Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[llvm-update] update llvm 6127f15e5b4834411e8f2e700e25c40490deec35 #347

Merged
merged 6 commits into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions compiler/cmake/mhlo.cmake
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
add_subdirectory(${BYTEIR_SRC_DIR}/../external/mlir-hlo ${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo EXCLUDE_FROM_ALL)

# FIXME: remove this
target_link_libraries(ChloPasses PUBLIC StablehloPasses)

include_directories(${BYTEIR_SRC_DIR}/../external/mlir-hlo)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo)
include_directories(${BYTEIR_SRC_DIR}/../external/mlir-hlo/stablehlo)
Expand Down
4 changes: 2 additions & 2 deletions compiler/dialects/include/byteir/Dialect/Ccl/IR/CclOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ class Ccl_ReplicaGroupsOp<string mnemonic, list<Trait> traits = []> :
SmallVector<ReplicaGroupsIndices, 4> replicaGroupsIndices;
for (auto attr : *maybeReplicaGroups)
replicaGroupsIndices.push_back(llvm::to_vector(
llvm::map_range(attr.cast<ArrayAttr>(), [&](Attribute indexAttr) {
return indexAttr.cast<IntegerAttr>().getInt();
llvm::map_range(cast<ArrayAttr>(attr), [&](Attribute indexAttr) {
return cast<IntegerAttr>(indexAttr).getInt();
})));
return replicaGroupsIndices;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#define BYTEIR_DIALECT_CCL_TRANSFORMOPS_CCLTRANSFORMOPS_H

#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/StringRef.h"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

include "mlir/Dialect/PDL/IR/PDLTypes.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

include "mlir/Dialect/PDL/IR/PDLTypes.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/BuiltinAttributes.td"
Expand Down
2 changes: 1 addition & 1 deletion compiler/include/byteir/Dialect/GPU/TransformOps/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"

Expand Down
4 changes: 2 additions & 2 deletions compiler/include/byteir/Dialect/Lccl/LcclOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ class Lccl_ReplicaGroupsOp<string mnemonic, list<Trait> traits = []> :
SmallVector<ReplicaGroupsIndices, 4> replicaGroupsIndices;
for (auto attr : *maybeReplicaGroups)
replicaGroupsIndices.push_back(llvm::to_vector(
llvm::map_range(attr.cast<ArrayAttr>(), [&](Attribute indexAttr) {
return indexAttr.cast<IntegerAttr>().getInt();
llvm::map_range(cast<ArrayAttr>(attr), [&](Attribute indexAttr) {
return cast<IntegerAttr>(indexAttr).getInt();
})));
return replicaGroupsIndices;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ namespace detail {
LogicalResult verifyLinalgExtOpInterface(Operation *op);
}

} // namespace linalg_ext
} // namespace mlir

#include "byteir/Dialect/Linalg/IR/LinalgExtOps.h.inc" // IWYU pragma: export

/// Include the generated interface declarations.
#include "byteir/Dialect/Linalg/IR/LinalgExtOpInterfaces.h.inc" // IWYU pragma: export

} // namespace linalg_ext
} // namespace mlir

#endif // BYTEIR_DIALECT_LINALG_IR_LINALGEXTINTERFACES_H
10 changes: 6 additions & 4 deletions compiler/include/byteir/Dialect/Linalg/IR/LinalgExtInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def LinalgExtInterface : OpInterface<"LinalgExtOp"> {
llvm::transform(getOutputBufferOperands(),
std::back_inserter(result),
[](OpOperand *opOperands) {
return opOperands->get().getType().cast<MemRefType>();
return cast<MemRefType>(opOperands->get().getType());
});
return result;
}]
Expand All @@ -286,7 +286,7 @@ def LinalgExtInterface : OpInterface<"LinalgExtOp"> {
llvm::transform(getOutputTensorOperands(),
std::back_inserter(result),
[](OpOperand *opOperands) {
return opOperands->get().getType().cast<RankedTensorType>();
return cast<RankedTensorType>(opOperands->get().getType());
});
return result;
}]
Expand Down Expand Up @@ -544,6 +544,8 @@ def LinalgExtInterface : OpInterface<"LinalgExtOp"> {
>
];

let cppNamespace = "::mlir::linalg_ext";

let extraClassDeclaration = [{
/// Returns the value that expresses the shape of the output in terms of
/// shape of the input operands where possible.
Expand All @@ -559,8 +561,8 @@ def LinalgExtInterface : OpInterface<"LinalgExtOp"> {

private:
void setOperandSegmentAt(unsigned idx, unsigned val) {
auto attr = (*this)->getAttr("operand_segment_sizes")
.cast<DenseIntElementsAttr>();
auto attr = cast<DenseIntElementsAttr>(
(*this)->getAttr("operand_segment_sizes"));
unsigned i = 0;
auto newAttr = attr.mapValues(IntegerType::get(getContext(), 32),
[&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; });
Expand Down
18 changes: 9 additions & 9 deletions compiler/include/byteir/Dialect/Linalg/IR/LinalgExtOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def LinalgExt_SoftmaxOp : LinalgExt_Op<"softmax",
return getOutputOperand(3)->get();
}
ShapedType getOperandType() {
return output().getType().cast<ShapedType>();
return cast<ShapedType>(output().getType());
}
int64_t getOperandRank() {
return getOperandType().getRank();
Expand Down Expand Up @@ -292,7 +292,7 @@ def LinalgExt_UnnormalizedSoftmaxOp : LinalgExt_Op<"unnorm_softmax",
return getOutputOperand(3)->get();
}
ShapedType getOperandType() {
return output().getType().cast<ShapedType>();
return cast<ShapedType>(output().getType());
}
int64_t getOperandRank() {
return getOperandType().getRank();
Expand Down Expand Up @@ -352,7 +352,7 @@ def LinalgExt_DiagOp : LinalgExt_Op<"diag",
return getOutputOperand(0)->get();
}
ShapedType getOperandType() {
return output().getType().cast<ShapedType>();
return cast<ShapedType>(output().getType());
}
int64_t getOperandRank() {
return getOperandType().getRank();
Expand Down Expand Up @@ -413,7 +413,7 @@ def LinalgExt_ScanOp : LinalgExt_Op<"scan",
return getOutputOperand(0)->get();
}
ShapedType getOperandType() {
return input().getType().cast<ShapedType>();
return cast<ShapedType>(input().getType());
}
int64_t getOperandRank() {
return getOperandType().getRank();
Expand Down Expand Up @@ -498,13 +498,13 @@ def LinalgExt_ScatterOp : LinalgExt_Op<"scatter",
return getOutputOperand(0)->get();
}
ShapedType getUpdateType() {
return update().getType().cast<ShapedType>();
return cast<ShapedType>(update().getType());
}
ShapedType getIndicesType() {
return indices().getType().cast<ShapedType>();
return cast<ShapedType>(indices().getType());
}
ShapedType getSrcType() {
return src().getType().cast<ShapedType>();
return cast<ShapedType>(src().getType());
}
int64_t getUpdateRank() {
return getUpdateType().getRank();
Expand Down Expand Up @@ -573,7 +573,7 @@ def LinalgExt_TopkOp : LinalgExt_Op<"topk",
return getOutputOperand(1)->get();
}
ShapedType getInputType() {
return values().getType().cast<ShapedType>();
return cast<ShapedType>(values().getType());
}
int64_t getInputRank() {
return getInputType().getRank();
Expand Down Expand Up @@ -679,7 +679,7 @@ def LinalgExt_LayerNormOp : LinalgExt_Op<"layer_norm",
return getOutputOperand(2)->get();
}
ShapedType getOperandType(int64_t idx) {
return getInputOperand(idx)->get().getType().cast<ShapedType>();
return cast<ShapedType>(getInputOperand(idx)->get().getType());
}
int64_t getOperandRank(int64_t idx) {
return getOperandType(idx).getRank();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#define BYTEIR_DIALECT_LINALG_TRANSFORMOPS_LINALGEXTTRANSFORMOPS_H

#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/StringRef.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

include "mlir/Dialect/PDL/IR/PDLTypes.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
Expand Down
7 changes: 7 additions & 0 deletions compiler/include/byteir/Dialect/Linalg/Util/Util.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,13 @@ static SmallVector<DstOpTy> castToTypedOperations(ArrayRef<Operation *> ops) {
}
template <typename DstOpTy>
static SmallVector<DstOpTy>
castToTypedOperations(ArrayRef<LoopLikeOpInterface> ifaces) {
return llvm::to_vector(llvm::map_range(ifaces, [](LoopLikeOpInterface iface) {
return cast<DstOpTy>(iface.getOperation());
}));
}
template <typename DstOpTy>
static SmallVector<DstOpTy>
castToTypedOperations(const SmallVector<Operation *> &ops) {
return castToTypedOperations<DstOpTy>(ArrayRef<Operation *>(ops));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/OpDefinition.h"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#define BYTEIR_DIALECT_TRANSFORM_IR_TRANSFORM_EXT_OPS

include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"

Expand Down
90 changes: 2 additions & 88 deletions compiler/include/byteir/Dialect/mhlo/Analysis/ShapeAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ class MhloShapeAnalysisBase : public ShapeAnalysis<ShapeKnowledgeType> {
wrapperShapeValueKnowledges);

return inferFunc(op->getContext(), op->getLoc(), range,
op->getAttrDictionary(), op->getRegions(), results);
op->getAttrDictionary(), op->getPropertiesStorage(),
op->getRegions(), results);
}
};

Expand Down Expand Up @@ -404,93 +405,6 @@ class MhloShapeValueAnalysisBase
lattice->join(mlir::dataflow::ConstantValue(
outAttr, op->getDialect())));
})
.template Case<mhlo::ComputeReshapeShapeOp>([&](Operation *op) {
mhlo::ComputeReshapeShapeOp computeReshapeShapeOp =
dyn_cast<mhlo::ComputeReshapeShapeOp>(op);
Value dynamicShapeV = computeReshapeShapeOp.getDynamicShape();
auto *boundedShapeValue =
this->template getOrCreate<BoundedValueLattice>(dynamicShapeV);
boundedShapeValue->useDefSubscribe(this);

const ShapeValueLattice *product = operands[0];
const ShapeValueLattice *shapeValue = operands[1];
if (product->getValue().isUninitialized()) {
return;
}
if (shapeValue->getValue().isUninitialized()) {
return;
}
if (boundedShapeValue->getValue().isUninitialized()) {
return;
}

if (!product->getValue().getConstantValue()) {
return this->setAllToEntryStates(results);
}
if (!shapeValue->getValue().getConstantValue() &&
boundedShapeValue->getValue().isUnknown()) {
return this->setAllToEntryStates(results);
}
Attribute productAttr = product->getValue().getConstantValue();
Attribute constShapeAttr = shapeValue->getValue().getConstantValue();
Attribute upperShapeAttr = boundedShapeValue->getValue().upper();
if (constShapeAttr) {
assert(!upperShapeAttr);
} else {
assert(upperShapeAttr);
constShapeAttr = upperShapeAttr;
}

Attribute resAttr = constShapeAttr;
ShapeValueLattice *lattice = results[0];
// in some cases, the shape in computeReshapeShapeOp is dense<[-1, x,
// ....]>, we need calculate firstly
do {
auto denseInt =
dyn_cast_or_null<DenseIntElementsAttr>(constShapeAttr);
if (denseInt == nullptr) {
break;
}
auto dataType = dyn_cast<IntegerType>(denseInt.getElementType());
// is int32
if (dataType == nullptr || dataType.isUnsigned() ||
dataType.getWidth() != 32) {
break;
}
llvm::SmallVector<int32_t> shape =
llvm::to_vector(denseInt.getValues<int32_t>());

// check whether has dimSize < 0, aka dynamic in mhlo
int cntDynamic = llvm::count_if(
shape, [](int32_t dimSize) { return dimSize < 0; });

if (cntDynamic == 1) {
if (auto num = dyn_cast_or_null<IntegerAttr>(productAttr)) {
int64_t number = num.getInt();
if (number < 0) {
break;
}

int32_t index = K_INITIAL;
for (auto elem : llvm::enumerate(shape)) {
if (elem.value() < 0) {
index = elem.index();
} else {
number /= elem.value();
}
}
assert(index != K_INITIAL);
shape[index] = number;
resAttr = DenseIntElementsAttr::get(denseInt.getType(), shape);
}
}
} while (0);

LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << resAttr << "\n");
this->propagateIfChanged(lattice,
lattice->join(mlir::dataflow::ConstantValue(
resAttr, op->getDialect())));
})
.Default([&](Operation *op) {
ShapeValueAnalysis<ShapeKnowledgeType>::visitOperation(op, operands,
results);
Expand Down
2 changes: 1 addition & 1 deletion compiler/include/byteir/Dialect/mhlo/Util/ShapeInferUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ inferBoundedReturnTypeComponents(llvm::StringRef name);

using InferReturnTypeComponents = std::function<LogicalResult(
MLIRContext *, std::optional<Location>, ValueShapeRange operands,
DictionaryAttr, RegionRange,
DictionaryAttr, OpaqueProperties, RegionRange,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnTypes)>;

struct InferReturnTypeComponentsRegistration {
Expand Down
5 changes: 3 additions & 2 deletions compiler/lib/CAPI/PDLValue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ bool mlirRegisterPDLConstraintFn(MlirContext ctx, MlirStringRef name, void *pfn,
return registerPDLConstraintFunction(
unwrap(ctx), unwrap(name),
[fn = *reinterpret_cast<std::function<bool(std::vector<MlirPDLValue>)> *>(
pfn)](PatternRewriter &,
pfn)](PatternRewriter &, PDLResultList &,
ArrayRef<PDLValue> pdlValues) -> LogicalResult {
std::vector<MlirPDLValue> wrapped;
wrapped.reserve(pdlValues.size());
Expand Down Expand Up @@ -167,7 +167,8 @@ bool mlirRegisterPDLRewriteFn(MlirContext ctx, MlirStringRef name, void *pfn,
insertionPoint = wrap(&*rewriter.getInsertionPoint());

auto onOperationInserted = [&](MlirOperation op) {
rewriter.getListener()->notifyOperationInserted(unwrap(op));
rewriter.getListener()->notifyOperationInserted(unwrap(op),
/*previous=*/{});
};

if (!fn(insertionPoint, wrap(resultList), wrapped, onOperationInserted))
Expand Down
Loading
Loading