Skip to content

Commit

Permalink
Revert "[mlir][linalg] Replace "string" iterator_types attr with enum…
Browse files Browse the repository at this point in the history
…s in LinalgInterface."

Breaks linalg python tests. Would need to also update python/mlir/dialects/linalg/opdsl.

This reverts commit b809d73.
  • Loading branch information
olegshyshkov committed Nov 9, 2022
1 parent a209744 commit 4128090
Show file tree
Hide file tree
Showing 33 changed files with 329 additions and 385 deletions.
7 changes: 0 additions & 7 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
#ifndef LINALG_BASE
#define LINALG_BASE

include "mlir/Dialect/Utils/StructuredOpsUtils.td"
include "mlir/Dialect/Linalg/IR/LinalgEnums.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"
Expand Down Expand Up @@ -72,10 +71,4 @@ def TypeFnAttr : EnumAttr<Linalg_Dialect, TypeFn, "type_fn"> {
let assemblyFormat = "`<` $value `>`";
}

def IteratorTypeEnum : EnumAttr<Linalg_Dialect, IteratorType, "iterator_type"> {
let assemblyFormat = "`<` $value `>`";
}
def IteratorTypeArrayAttr : TypedArrayAttrBase<IteratorTypeEnum,
"Iterator type should be an enum.">;

#endif // LINALG_BASE
1 change: 0 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

namespace mlir {
namespace linalg {
class IteratorTypeAttr;
class LinalgOp;

namespace detail {
Expand Down
52 changes: 40 additions & 12 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return llvm::count($_op.getIteratorTypesArray(),
utils::IteratorType::parallel);
return getNumIterators(getParallelIteratorTypeName(),
$_op.getIteratorTypesArray());
}]
>,
InterfaceMethod<
Expand All @@ -207,7 +207,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*methodBody=*/"",
/*defaultImplementation=*/[{
return findPositionsOfType($_op.getIteratorTypesArray(),
utils::IteratorType::parallel, res);
getParallelIteratorTypeName(), res);
}]
>,
InterfaceMethod<
Expand All @@ -219,8 +219,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return llvm::count($_op.getIteratorTypesArray(),
utils::IteratorType::reduction);
return getNumIterators(getReductionIteratorTypeName(),
$_op.getIteratorTypesArray());
}]
>,
InterfaceMethod<
Expand All @@ -233,7 +233,33 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*methodBody=*/"",
/*defaultImplementation=*/[{
return findPositionsOfType($_op.getIteratorTypesArray(),
utils::IteratorType::reduction, res);
getReductionIteratorTypeName(), res);
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the number of window loops.
}],
/*retTy=*/"unsigned",
/*methodName=*/"getNumWindowLoops",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return getNumIterators(getWindowIteratorTypeName(),
$_op.getIteratorTypesArray());
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the dims that are window loops.
}],
/*retTy=*/"void",
/*methodName=*/"getWindowDims",
/*args=*/(ins "SmallVectorImpl<unsigned> &":$res),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return findPositionsOfType($_op.getIteratorTypesArray(),
getWindowIteratorTypeName(), res);
}]
>,
InterfaceMethod<
Expand All @@ -245,7 +271,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return $_op.getIteratorTypesArray().size();
return getNumIterators($_op.getIteratorTypesArray());
}]
>,
InterfaceMethod<
Expand All @@ -260,7 +286,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*defaultImplementation=*/[{
auto iters = $_op.getIteratorTypesArray();
return iters.size() == 1 &&
llvm::count(iters, utils::IteratorType::reduction) == 1;
getNumIterators(getReductionIteratorTypeName(), iters) == 1;
}]>,
//===------------------------------------------------------------------===//
// Input and Init arguments handling.
Expand Down Expand Up @@ -480,14 +506,12 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
can be infered from other parameters and in such cases default
getIteratorTypesArray should be overriden.
}],
/*retTy=*/"SmallVector<utils::IteratorType>",
/*retTy=*/"SmallVector<StringRef>",
/*methodName=*/"getIteratorTypesArray",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto range = $_op.getIteratorTypes()
.template getAsValueRange<IteratorTypeAttr,
utils::IteratorType>();
auto range = $_op.getIteratorTypes().template getAsValueRange<StringAttr>();
return {range.begin(), range.end()};
}]
>,
Expand Down Expand Up @@ -743,6 +767,10 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
LogicalResult reifyResultShapes(OpBuilder &b,
ReifiedRankedShapedTypeDims &reifiedReturnShapes);

SmallVector<StringRef> getIteratorTypeNames() {
return getIteratorTypesArray();
}

//========================================================================//
// Forwarding functions to access interface methods from the
// DestinationStyleOpInterface.
Expand Down
18 changes: 9 additions & 9 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
let arguments = (ins Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs,
AffineMapArrayAttr:$indexing_maps,
IteratorTypeArrayAttr:$iterator_types,
ArrayAttr:$iterator_types,
OptionalAttr<StrAttr>:$doc,
OptionalAttr<StrAttr>:$library_call);
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
Expand All @@ -178,22 +178,22 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
"ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
"ArrayRef<utils::IteratorType>":$iteratorTypes, "StringRef":$doc,
"ArrayRef<StringRef>":$iteratorTypes, "StringRef":$doc,
"StringRef":$libraryCall,
CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
"ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<utils::IteratorType>":$iteratorTypes,
"ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes,
"StringRef":$doc, "StringRef":$libraryCall,
CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
"ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
"ArrayRef<utils::IteratorType>":$iteratorTypes,
"ArrayRef<StringRef>":$iteratorTypes,
CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputBuffers,
"ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<utils::IteratorType>":$iteratorTypes,
"ArrayRef<AffineMap>":$indexingMaps, "ArrayRef<StringRef>":$iteratorTypes,
CArg<"function_ref<void(OpBuilder &, Location, ValueRange)>", "nullptr">,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
];
Expand Down Expand Up @@ -275,7 +275,7 @@ def MapOp : LinalgStructuredBase_Op<"map", [

let extraClassDeclaration = structuredOpsBaseDecls # [{
// Implement functions necessary for LinalgStructuredInterface.
SmallVector<utils::IteratorType> getIteratorTypesArray();
SmallVector<StringRef> getIteratorTypesArray();
ArrayAttr getIndexingMaps();
std::string getLibraryCallName() {
return "op_has_no_registered_library_name";
Expand Down Expand Up @@ -356,7 +356,7 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [

let extraClassDeclaration = structuredOpsBaseDecls # [{
// Declare functions necessary for LinalgStructuredInterface.
SmallVector<utils::IteratorType> getIteratorTypesArray();
SmallVector<StringRef> getIteratorTypesArray();
ArrayAttr getIndexingMaps();
std::string getLibraryCallName() {
return "op_has_no_registered_library_name";
Expand Down Expand Up @@ -426,7 +426,7 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [

let extraClassDeclaration = structuredOpsBaseDecls # [{
// Declare functions necessary for LinalgStructuredInterface.
SmallVector<utils::IteratorType> getIteratorTypesArray();
SmallVector<StringRef> getIteratorTypesArray();
ArrayAttr getIndexingMaps();
std::string getLibraryCallName() {
return "op_has_no_registered_library_name";
Expand Down Expand Up @@ -502,7 +502,7 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [

let extraClassDeclaration = structuredOpsBaseDecls # [{
// Declare functions necessary for LinalgStructuredInterface.
SmallVector<utils::IteratorType> getIteratorTypesArray();
SmallVector<StringRef> getIteratorTypesArray();
ArrayAttr getIndexingMaps();
std::string getLibraryCallName() {
return "op_has_no_registered_library_name";
Expand Down
7 changes: 3 additions & 4 deletions mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ bool hasOnlyScalarElementwiseOp(Region &r);
bool isElementwise(LinalgOp op);

/// Check if iterator type has "parallel" semantics.
bool isParallelIterator(utils::IteratorType iteratorType);
bool isParallelIterator(StringRef iteratorType);

/// Check if iterator type has "reduction" semantics.
bool isReductionIterator(utils::IteratorType iteratorType);
bool isReductionIterator(StringRef iteratorType);

/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
Expand Down Expand Up @@ -480,8 +480,7 @@ struct RegionMatcher {
template <typename LoopTy>
struct GenerateLoopNest {
static void doit(OpBuilder &b, Location loc, ArrayRef<Range> loopRanges,
LinalgOp linalgOp,
ArrayRef<utils::IteratorType> iteratorTypes,
LinalgOp linalgOp, ArrayRef<StringRef> iteratorTypes,
function_ref<scf::ValueVector(OpBuilder &, Location,
ValueRange, ValueRange)>
bodyBuilderFn,
Expand Down
3 changes: 1 addition & 2 deletions mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ namespace mlir {
namespace tosa {

// Creates a SmallVector of Stringrefs for N parallel loops
SmallVector<utils::IteratorType>
getNParallelLoopsAttrs(unsigned nParallelLoops);
SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops);

// Takes a vector of values and condenses them to a vector with no gaps.
SmallVector<Value> condenseValues(const SmallVector<Value> &values);
Expand Down
59 changes: 47 additions & 12 deletions mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Location.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringRef.h"

// Pull in all enum type definitions and utility function declarations.
#include "mlir/Dialect/Utils/DialectUtilsEnums.h.inc"
Expand All @@ -47,9 +48,42 @@ bool isColumnMajorMatmul(ArrayAttr indexingMaps);
/// the reduction.
bool isRowMajorBatchMatmul(ArrayAttr indexingMaps);

/// Use to encode that a particular iterator type has parallel semantics.
constexpr StringRef getParallelIteratorTypeName() { return "parallel"; }

/// Use to encode that a particular iterator type has reduction semantics.
constexpr StringRef getReductionIteratorTypeName() { return "reduction"; }

/// Use to encode that a particular iterator type has window semantics.
constexpr StringRef getWindowIteratorTypeName() { return "window"; }

/// Use to encode that a particular iterator type has window semantics.
inline ArrayRef<StringRef> getAllIteratorTypeNames() {
static constexpr StringRef names[3] = {getParallelIteratorTypeName(),
getReductionIteratorTypeName(),
getWindowIteratorTypeName()};
return llvm::makeArrayRef(names);
}

/// Returns the iterator of a certain type.
inline unsigned getNumIterators(StringRef name,
ArrayRef<StringRef> iteratorTypes) {
auto names = getAllIteratorTypeNames();
(void)names;
assert(llvm::is_contained(names, name));
return llvm::count(iteratorTypes, name);
}

inline unsigned getNumIterators(ArrayRef<StringRef> iteratorTypes) {
unsigned res = 0;
for (auto n : getAllIteratorTypeNames())
res += getNumIterators(n, iteratorTypes);
return res;
}

/// Return positions in `iteratorTypes` that match `iteratorTypeName`.
inline void findPositionsOfType(ArrayRef<utils::IteratorType> iteratorTypes,
utils::IteratorType iteratorTypeName,
inline void findPositionsOfType(ArrayRef<StringRef> iteratorTypes,
StringRef iteratorTypeName,
SmallVectorImpl<unsigned> &res) {
for (const auto &en : llvm::enumerate(iteratorTypes)) {
if (en.value() == iteratorTypeName)
Expand All @@ -60,28 +94,29 @@ inline void findPositionsOfType(ArrayRef<utils::IteratorType> iteratorTypes,
/// Helper StructuredGenerator class to manipulate and rewrite ops with
/// `StructuredOpInterface`. This is templated for now because VectorOps do not
/// yet implement the StructuredOpInterface itself.
template <typename StructuredOpInterface, typename IteratorTypeT>
template <typename StructuredOpInterface>
class StructuredGenerator {
public:
using MapList = ArrayRef<ArrayRef<AffineExpr>>;

struct IteratorType {
IteratorType(IteratorTypeT iter) : iter(iter) {}
bool isOfType(IteratorTypeT expectedIter) const {
return expectedIter == iter;
}
IteratorTypeT iter;
IteratorType(StringRef strRef) : strRef(strRef) {}
bool isOfType(StringRef typeName) const { return typeName == strRef; }
StringRef strRef;
};
struct Par : public IteratorType {
Par() : IteratorType(IteratorTypeT::parallel) {}
Par() : IteratorType(getParallelIteratorTypeName()) {}
};
struct Red : public IteratorType {
Red() : IteratorType(IteratorTypeT::reduction) {}
Red() : IteratorType(getReductionIteratorTypeName()) {}
};
struct Win : public IteratorType {
Win() : IteratorType(getWindowIteratorTypeName()) {}
};

StructuredGenerator(OpBuilder &builder, StructuredOpInterface op)
: builder(builder), ctx(op.getContext()), loc(op.getLoc()),
iterators(op.getIteratorTypesArray()), maps(op.getIndexingMapsArray()),
iterators(op.getIteratorTypeNames()), maps(op.getIndexingMapsArray()),
op(op) {}

bool iters(ArrayRef<IteratorType> its) {
Expand All @@ -103,7 +138,7 @@ class StructuredGenerator {
OpBuilder &builder;
MLIRContext *ctx;
Location loc;
SmallVector<IteratorTypeT> iterators;
SmallVector<StringRef> iterators;
SmallVector<AffineMap, 4> maps;
Operation *op;
};
Expand Down
11 changes: 6 additions & 5 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -269,11 +269,12 @@ def Vector_ContractionOp :
return CombiningKind::ADD;
}

SmallVector<IteratorType> getIteratorTypesArray() {
auto range =
getIteratorTypes()
.template getAsValueRange<IteratorTypeAttr, IteratorType>();
return {range.begin(), range.end()};
// Returns iterator types in string format.
SmallVector<StringRef> getIteratorTypeNames() {
return llvm::to_vector(
llvm::map_range(getIteratorTypes(), [](Attribute a) {
return stringifyIteratorType(a.cast<IteratorTypeAttr>().getValue());
}));
}
}];

Expand Down

0 comments on commit 4128090

Please sign in to comment.