Skip to content

Commit

Permalink
[mlir][NFC] Remove unnecessary attr name getters from StructuredOpsUt…
Browse files Browse the repository at this point in the history
…ils.h.

Those methods were added long time ago. Now we get the same methods generated by tablegen, so there is no need for duplicates.

Differential Revision: https://reviews.llvm.org/D137544
  • Loading branch information
olegshyshkov committed Nov 7, 2022
1 parent 39dbfa7 commit bada353
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 54 deletions.
31 changes: 0 additions & 31 deletions mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
Expand Up @@ -48,37 +48,6 @@ bool isColumnMajorMatmul(ArrayAttr indexingMaps);
/// the reduction.
bool isRowMajorBatchMatmul(ArrayAttr indexingMaps);

/// Attribute name for the AffineArrayAttr which encodes the relationship
/// between a structured op iterators' and its operands.
constexpr StringRef getIndexingMapsAttrName() { return "indexing_maps"; }

/// Attribute name for the StrArrayAttr which encodes the type of a structured
/// op's iterators.
constexpr StringRef getIteratorTypesAttrName() { return "iterator_types"; }

/// Attribute name for the StrArrayAttr which encodes the distribution type for
/// `linalg.tiled_loop`.
constexpr StringRef getDistributionTypesAttrName() {
return "distribution_types";
}

/// Attribute name for the StringAttr which encodes an optional documentation
/// string of the structured op.
constexpr StringRef getDocAttrName() { return "doc"; }

/// Attribute name for the StrArrayAttr which encodes the external library
/// function that implements the structured op.
constexpr StringRef getLibraryCallAttrName() { return "library_call"; }

/// Attribute name for the StrArrayAttr which encodes the value of strides.
constexpr StringRef getStridesAttrName() { return "strides"; }

/// Attribute name for the StrArrayAttr which encodes the value of dilations.
constexpr StringRef getDilationsAttrName() { return "dilations"; }

/// Attribute name for the StrArrayAttr which encodes the value of paddings.
constexpr StringRef getPaddingAttrName() { return "padding"; }

/// Use to encode that a particular iterator type has parallel semantics.
constexpr StringRef getParallelIteratorTypeName() { return "parallel"; }

Expand Down
4 changes: 1 addition & 3 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Expand Up @@ -244,7 +244,7 @@ def Vector_ContractionOp :
return getOperand(4).getType().cast<VectorType>();
}
Type getResultType() { return getResult().getType(); }
ArrayRef<StringRef> getTraitAttrNames();
SmallVector<StringRef> getTraitAttrNames();
static unsigned getAccOperandIndex() { return 2; }

llvm::SmallVector<::mlir::AffineMap, 4> getIndexingMapsArray() {
Expand All @@ -265,8 +265,6 @@ def Vector_ContractionOp :
std::vector<std::pair<int64_t, int64_t>> getContractingDimMap();
std::vector<std::pair<int64_t, int64_t>> getBatchDimMap();

static constexpr StringRef getKindAttrStrName() { return "kind"; }

static CombiningKind getDefaultKind() {
return CombiningKind::ADD;
}
Expand Down
7 changes: 3 additions & 4 deletions mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
Expand Up @@ -73,8 +73,8 @@ mlir::linalg::interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp,
m = m.compose(permutationMap);
newIndexingMaps.push_back(m);
}
genericOp->setAttr(getIndexingMapsAttrName(),
rewriter.getAffineMapArrayAttr(newIndexingMaps));
genericOp.setIndexingMapsAttr(
rewriter.getAffineMapArrayAttr(newIndexingMaps));

// 3. Compute the interchanged iterator types.
ArrayRef<Attribute> itTypes = genericOp.getIteratorTypes().getValue();
Expand All @@ -83,8 +83,7 @@ mlir::linalg::interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp,
SmallVector<int64_t> permutation(interchangeVector.begin(),
interchangeVector.end());
applyPermutationToVector(itTypesVector, permutation);
genericOp->setAttr(getIteratorTypesAttrName(),
ArrayAttr::get(context, itTypesVector));
genericOp.setIteratorTypesAttr(rewriter.getArrayAttr(itTypesVector));

// 4. Transform the index operations by applying the permutation map.
if (genericOp.hasIndexSemantics()) {
Expand Down
31 changes: 15 additions & 16 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Expand Up @@ -508,11 +508,11 @@ void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
ArrayRef<IteratorType> iteratorTypes) {
result.addOperands({lhs, rhs, acc});
result.addTypes(acc.getType());
result.addAttribute(::mlir::getIndexingMapsAttrName(),
result.addAttribute(getIndexingMapsAttrName(result.name),
builder.getAffineMapArrayAttr(
AffineMap::inferFromExprList(indexingExprs)));
result.addAttribute(
::mlir::getIteratorTypesAttrName(),
getIteratorTypesAttrName(result.name),
builder.getArrayAttr(llvm::to_vector(llvm::map_range(
iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
return IteratorTypeAttr::get(builder.getContext(), t);
Expand All @@ -533,9 +533,9 @@ void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
ArrayAttr iteratorTypes, CombiningKind kind) {
result.addOperands({lhs, rhs, acc});
result.addTypes(acc.getType());
result.addAttribute(::mlir::getIndexingMapsAttrName(), indexingMaps);
result.addAttribute(::mlir::getIteratorTypesAttrName(), iteratorTypes);
result.addAttribute(ContractionOp::getKindAttrStrName(),
result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
result.addAttribute(getKindAttrName(result.name),
CombiningKindAttr::get(builder.getContext(), kind));
}

Expand Down Expand Up @@ -570,7 +570,8 @@ ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
// represented as an array of strings.
// TODO: Remove this conversion once tests are fixed.
ArrayAttr iteratorTypes =
result.attributes.get("iterator_types").cast<ArrayAttr>();
result.attributes.get(getIteratorTypesAttrName(result.name))
.cast<ArrayAttr>();

SmallVector<Attribute> iteratorTypeAttrs;

Expand All @@ -579,15 +580,15 @@ ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
if (!maybeIteratorType.has_value())
return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";

iteratorTypeAttrs.push_back(IteratorTypeAttr::get(
parser.getContext(), maybeIteratorType.value()));
iteratorTypeAttrs.push_back(
IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
}
result.attributes.set("iterator_types",
result.attributes.set(getIteratorTypesAttrName(result.name),
parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

if (!result.attributes.get(ContractionOp::getKindAttrStrName())) {
if (!result.attributes.get(getKindAttrName(result.name))) {
result.addAttribute(
ContractionOp::getKindAttrStrName(),
getKindAttrName(result.name),
CombiningKindAttr::get(result.getContext(),
ContractionOp::getDefaultKind()));
}
Expand Down Expand Up @@ -822,11 +823,9 @@ LogicalResult ContractionOp::verify() {
return success();
}

ArrayRef<StringRef> ContractionOp::getTraitAttrNames() {
static constexpr StringRef names[3] = {::mlir::getIndexingMapsAttrName(),
::mlir::getIteratorTypesAttrName(),
ContractionOp::getKindAttrStrName()};
return llvm::makeArrayRef(names);
SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
return SmallVector<StringRef>{getIndexingMapsAttrName(),
getIteratorTypesAttrName(), getKindAttrName()};
}

static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
Expand Down

0 comments on commit bada353

Please sign in to comment.