[mlir][linalg] Add support for scalar input operands.
Up to now, all structured op operands have been assumed to be shaped. This patch relaxes that assumption and allows scalar input operands. In contrast to shaped operands, scalar operands are not indexed and are forwarded directly to the body of the operation. Like all other operands, a scalar operand is associated with an indexing map, which for a scalar or 0-D operand has an empty range.

We will use scalar operands as a replacement for the capture mechanism. In contrast to captures, this approach ensures that the function signature can be generated from the operand list, and it prevents stale capture values when a transformation updates only the capture operand but not the hidden body of a named operation.

Removing captures and updating existing operations such as linalg.fill is left for a later patch.
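
As an illustration (a minimal sketch, not part of the patch; the operand names and the addf body are assumed), a linalg.generic ties a scalar input operand to an empty-range indexing map and receives it in the region as an unindexed block argument:

  #map0 = affine_map<(d0, d1) -> (d0, d1)>
  #map1 = affine_map<(d0, d1) -> ()>   // scalar operand: empty-range indexing map

  %0 = linalg.generic {indexing_maps = [#map0, #map1, #map0],
                       iterator_types = ["parallel", "parallel"]}
      ins(%tensor, %scalar : tensor<?x?xf32>, f32)
      outs(%init : tensor<?x?xf32>) {
    ^bb0(%t: f32, %s: f32, %o: f32):   // %s is the scalar value, forwarded as-is
      %1 = addf %t, %s : f32
      linalg.yield %1 : f32
  } -> tensor<?x?xf32>

The updated drop-unit-extent-dims.mlir test below exercises the same pattern with a 5-D iteration space.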

The patch depends on https://reviews.llvm.org/D103891 and https://reviews.llvm.org/D103890.

Differential Revision: https://reviews.llvm.org/D104109
Tobias Gysi committed Jun 14, 2021
1 parent ddda52c commit 046922e
Showing 24 changed files with 366 additions and 107 deletions.
2 changes: 0 additions & 2 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -15,8 +15,6 @@

include "mlir/IR/OpBase.td"

def LinalgOperand: AnyTypeOf<[AnyRankedTensor, AnyStridedMemRef]>;

def Linalg_Dialect : Dialect {
let name = "linalg";
let description = [{
32 changes: 26 additions & 6 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -584,6 +584,19 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
return {};
}]
>,
InterfaceMethod<
/*desc=*/[{
Return true if the `opOperand` is a scalar value.
}],
/*retTy=*/"bool",
/*methodName=*/"isScalar",
/*args=*/(ins "OpOperand*":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(opOperand->getOwner() == this->getOperation());
return !opOperand->get().getType().template isa<ShapedType>();
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the input or output indexing map for `opOperand`.
@@ -694,10 +707,13 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*methodBody=*/"",
/*defaultImplementation=*/[{
return this->getOperation()->getNumResults() == 0 &&
llvm::all_of(getInputAndOutputOperands(),
[](OpOperand *opOperand) {
return opOperand->get().getType().template isa<MemRefType>();
});
llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) {
return isScalar(opOperand) ||
opOperand->get().getType().template isa<MemRefType>();
}) &&
llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) {
return opOperand->get().getType().template isa<MemRefType>();
});
}]
>,
InterfaceMethod<
@@ -709,8 +725,12 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return llvm::all_of(getInputAndOutputOperands(),
[](OpOperand *opOperand) {
return
llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) {
return isScalar(opOperand) ||
opOperand->get().getType().template isa<RankedTensorType>();
}) &&
llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) {
return opOperand->get().getType().template isa<RankedTensorType>();
});
}]
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -640,8 +640,8 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
let arguments = (ins Variadic<Index>:$lowerBound,
Variadic<Index>:$upperBound,
Variadic<Index>:$step,
Variadic<LinalgOperand>:$inputs,
Variadic<LinalgOperand>:$outputs,
Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs,
ArrayAttr:$iterator_types,
OptionalAttr<ArrayAttr>:$distribution_types);
let results = (outs Variadic<AnyRankedTensor>:$results);
7 changes: 1 addition & 6 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -517,17 +517,12 @@ def PoolingSumOp: SingleInputPoolingBase_Op<"pooling_sum"> {
//===----------------------------------------------------------------------===//
// Generic Linalg ops.
//===----------------------------------------------------------------------===//
class LinalgOperandOfRank<int rank>: Type<
And<[
LinalgOperand.predicate,
CPred<"$_self.cast<ShapedType>().getRank() == " # rank>]
>>;

class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, [
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
SingleBlockImplicitTerminator<"YieldOp">]> {
let arguments = (ins Variadic<AnyShaped>:$inputs,
let arguments = (ins Variadic<AnyType>:$inputs,
Variadic<AnyShaped>:$outputs,
AffineMapArrayAttr:$indexing_maps,
ArrayAttr:$iterator_types,
23 changes: 12 additions & 11 deletions mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -338,7 +338,7 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
return failure();

// All shaped operands must be indexed.
// All input/output operands must be indexed.
if (static_cast<int64_t>(linalgOp.indexing_maps().size()) !=
linalgOp.getNumInputsAndOutputs())
return op->emitOpError("expected the number of indexing_map (")
@@ -363,7 +363,7 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {

int64_t rank = linalgOp.getRank(opOperand);
if (indexingMap.getNumResults() != rank)
return op->emitOpError("expected shaped value rank (")
return op->emitOpError("expected operand rank (")
<< rank << ") to match the result rank of indexing_map #"
<< opOperand->getOperandNumber() << " ("
<< indexingMap.getNumResults() << ")";
@@ -444,22 +444,22 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {

if (linalgOp.getNumInputsAndOutputs() + numBBIvs != block.getNumArguments())
return op->emitOpError("expected as many non-induction variable region "
"arguments as the number of shaped operands");
"arguments as the number of input/output operands");

// Note: the number and type of yield values are checked in the YieldOp.
for (unsigned i = 0; i < numBBIvs; ++i)
if (!block.getArgument(i).getType().isIndex())
return op->emitOpError("expected index block argument #") << i;

for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
Type elementType = getElementTypeOrSelf(opOperand->get().getType());
Type elementType = getElementTypeOrSelf(opOperand->get());
Type argType =
block.getArgument(numBBIvs + opOperand->getOperandNumber()).getType();
if (elementType != argType)
return op->emitOpError("expected type of bb argument #")
<< numBBIvs + opOperand->getOperandNumber() << " (" << argType
<< ")"
<< " to match element type of corresponding shaped operand ("
<< " to match element or self type of the corresponding operand ("
<< elementType << ")";
}

@@ -489,10 +489,11 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {

// The first index or last index should be the maximum or the minimum in
// the inferred index ranges since the range is increasing or
// decreasing. The size of dimensions of shaped operands and the maximum
// value + 1 in the inferred range should be the same. But, for now we
// check if the inferred ranges are in boundary of shaped operands' size
// or not in case that Affine Expressions are complicated such as d0 * 3
// decreasing. The size of dimensions of input/output operands and the
// maximum value + 1 in the inferred range should be the same. But, for
// now we check if the inferred ranges are in boundary of input/output
// operands' size or not in case that Affine Expressions are complicated
// such as d0 * 3
// + d1 since it is not easy to handle the issues.
// Found the case that this solution can't check, for example, (d0, d1)
// -> (d1 - d0)
@@ -510,14 +511,14 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
}
if (indexingMap.getResult(dim).dyn_cast<AffineDimExpr>()) {
if (inferredDimSize != shape[dim]) {
return op->emitOpError("inferred shaped operand #")
return op->emitOpError("inferred input/output operand #")
<< opOperand->getOperandNumber()
<< " has shape's dimension #" << dim << " to be "
<< inferredDimSize << ", but found " << shape[dim];
}
} else {
if (inferredDimSize > shape[dim]) {
return op->emitOpError("inferred shaped operand #")
return op->emitOpError("inferred input/output operand #")
<< opOperand->getOperandNumber()
<< " has shape's dimension #" << dim
<< " to be greater than or equal to " << inferredDimSize
10 changes: 4 additions & 6 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -377,8 +377,7 @@ void printCopyOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {}
static LogicalResult verify(CopyOp op) {
OpOperand *output = op.getOutputOperand(0);
OpOperand *input = op.getInputOperand(0);
if (getElementTypeOrSelf(input->get().getType()) !=
getElementTypeOrSelf(output->get().getType()))
if (getElementTypeOrSelf(input->get()) != getElementTypeOrSelf(output->get()))
return op.emitOpError("expects views of the same type");
if (op.getRank(input) != op.getRank(output))
return op.emitOpError("expects views of the same rank");
@@ -452,7 +451,7 @@ void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Value) {}
static LogicalResult verify(FillOp op) {
OpOperand *output = op.getOutputOperand(0);
Type fillType = op.value().getType();
if (getElementTypeOrSelf(output->get().getType()) != fillType)
if (getElementTypeOrSelf(output->get()) != fillType)
return op.emitOpError("expects fill type to match view elemental type");
if (!op.getNumResults() && !output->get().getType().isa<MemRefType>()) {
return op.emitOpError(
@@ -489,7 +488,7 @@ void GenericOp::build(
SmallVector<Type, 4> blockArgTypes;
for (ValueRange container : {inputs, outputs})
for (Value v : container)
blockArgTypes.push_back(v.getType().cast<ShapedType>().getElementType());
blockArgTypes.push_back(getElementTypeOrSelf(v));

OpBuilder::InsertionGuard guard(builder);
auto &region = *result.regions.front();
@@ -545,7 +544,7 @@ void IndexedGenericOp::build(
SmallVector<Type, 4> blockArgTypes(nLoops, builder.getIndexType());
for (ValueRange container : {inputs, outputs})
for (Value v : container)
blockArgTypes.push_back(v.getType().cast<ShapedType>().getElementType());
blockArgTypes.push_back(getElementTypeOrSelf(v));

OpBuilder::InsertionGuard guard(builder);
auto &region = *result.regions.front();
@@ -2949,7 +2948,6 @@ fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
TypeRange inputTypes, TypeRange outputTypes,
ValueRange captures,
std::function<void(unsigned, unsigned)> errorHandler) {
assert(llvm::all_of(inputTypes, [](Type t) { return t.isa<ShapedType>(); }));
assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); }));

// TODO: atm all operands go through getElementTypeOrSelf,
17 changes: 10 additions & 7 deletions mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -484,18 +484,21 @@ static LogicalResult bufferize(OpBuilder &b, LinalgOp op,

b.setInsertionPoint(op);
Location loc = op.getLoc();
SmallVector<Value, 2> newInputBuffers;
newInputBuffers.reserve(op.getNumInputs());
SmallVector<Value> newInputs;
newInputs.reserve(op.getNumInputs());
for (OpOperand *opOperand : op.getInputOperands()) {
Value v = lookup(bvm, opOperand->get());
if (!v)
if (op.isScalar(opOperand)) {
newInputs.push_back(opOperand->get());
continue;
}
newInputs.push_back(lookup(bvm, opOperand->get()));
if (!newInputs.back())
return failure();
newInputBuffers.push_back(v);
}
SmallVector<Value, 2> newOutputBuffers;
SmallVector<Value> newOutputBuffers;
if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm)))
return failure();
finalizeBufferAllocation(b, op, newInputBuffers, newOutputBuffers, bvm);
finalizeBufferAllocation(b, op, newInputs, newOutputBuffers, bvm);
return success();
}

2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -301,7 +301,7 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
++dim;
}
// Compute the tensor or scalar replacement type.
Type elementType = getElementTypeOrSelf(opOperand->get().getType());
Type elementType = getElementTypeOrSelf(opOperand->get());
Type replacementType = elementType == opOperand->get().getType()
? elementType
: RankedTensorType::get(newShape, elementType);
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -129,14 +129,14 @@ static SmallVector<Value> getTiledOperands(OpBuilder &b, LinalgOp producer) {
assert(producer.hasTensorSemantics() &&
"only fusion on tensors is currently supported for TiledLinalgOp");

for (OpOperand *producerInput : producer.getInputTensorOperands()) {
for (OpOperand *producerInput : producer.getInputOperands()) {
OpOperand *addedInput = tiledLoop.findInputOperand(producerInput->get());
if (addedInput == nullptr)
addedInput = &tiledLoop.appendInputOperand(b, producerInput->get());
BlockArgument addedBlockArg = tiledLoop.getTiedBlockArgument(*addedInput);
tiledOperands.push_back(addedBlockArg);
}
for (OpOperand *producerOutput : producer.getOutputTensorOperands()) {
for (OpOperand *producerOutput : producer.getOutputOperands()) {
OpResult result = producer.getTiedOpResult(producerOutput);
OpOperand *resultInputOperand = tiledLoop.findInputOperand(result);
OpOperand *resultOutputOperand = tiledLoop.findOutputOperand(result);
6 changes: 5 additions & 1 deletion mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -126,8 +126,12 @@ static void emitScalarImplementation(OpBuilder &b, Location loc,

// TODO: Avoid the loads if the corresponding argument of the
// region has no uses.
// 1.a. Emit load from input views.
// 1.a. Emit load from input operand or for scalars access the operand itself.
for (OpOperand *inputOperand : linalgOp.getInputOperands()) {
if (linalgOp.isScalar(inputOperand)) {
indexedValues.push_back(inputOperand->get());
continue;
}
auto indexing = makeCanonicalAffineApplies(
b, loc, linalgOp.getTiedIndexingMap(inputOperand), allIvsPlusDims);
indexedValues.push_back(
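
For intuition, a hedged sketch of what the scalar branch above changes in the emitted loops (assuming a 1-D buffer-semantics op with inputs %arg0 : memref<?xf32> and %arg1 : f32, output %out : memref<?xf32>, and an addf body; not taken from the patch's tests):

  scf.for %i = %c0 to %dim step %c1 {
    // Shaped input operands are still loaded per iteration.
    %a = memref.load %arg0[%i] : memref<?xf32>
    // The scalar operand %arg1 is not loaded; its SSA value is used directly.
    %sum = addf %a, %arg1 : f32
    memref.store %sum, %out[%i] : memref<?xf32>
  }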
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -149,7 +149,7 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
}
Value pad = options.paddingValueComputationFunction(rewriter, *opOperand);
auto staticTensorType = RankedTensorType::get(
staticSizes, getElementTypeOrSelf(opOperand->get().getType()));
staticSizes, getElementTypeOrSelf(opOperand->get()));
result = linalg::PadTensorOp::createPadHighOp(
staticTensorType, opOperand->get(), pad, opToPad->getLoc(), rewriter);
return success();
15 changes: 9 additions & 6 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -479,6 +479,10 @@ LogicalResult vectorizeAsLinalgGeneric(
SmallVector<AffineMap> indexings;
for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
BlockArgument bbarg = block.getArgument(opOperand->getOperandNumber());
if (linalgOp.isScalar(opOperand)) {
bvm.map(bbarg, opOperand->get());
continue;
}
// TODO: 0-d vectors.
if (linalgOp.getShape(opOperand).empty()) {
Value loaded =
@@ -494,14 +498,13 @@
if (broadcastToMaximalCommonShape) {
map = inverseAndBroadcastProjectedPermuation(
linalgOp.getTiedIndexingMap(opOperand));
vectorType = VectorType::get(
commonVectorShape, getElementTypeOrSelf(opOperand->get().getType()));
vectorType = VectorType::get(commonVectorShape,
getElementTypeOrSelf(opOperand->get()));
} else {
map = inversePermutation(
reindexIndexingMap(linalgOp.getTiedIndexingMap(opOperand)));
vectorType =
VectorType::get(map.compose(linalgOp.getShape(opOperand)),
getElementTypeOrSelf(opOperand->get().getType()));
vectorType = VectorType::get(map.compose(linalgOp.getShape(opOperand)),
getElementTypeOrSelf(opOperand->get()));
}
Value vectorRead = buildVectorRead(b, opOperand->get(), vectorType, map);
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
@@ -1157,7 +1160,7 @@ LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(

int64_t rank = op.getRank(input);
int64_t numDims = mapping.size();
Type elemType = getElementTypeOrSelf(input->get().getType());
Type elemType = getElementTypeOrSelf(input->get());

auto map = AffineMap::get(rank, 0, mapping, context);
SmallVector<Value, 4> zeros(rank, rewriter.create<ConstantIndexOp>(loc, 0));
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1372,6 +1372,8 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
// Detects sparse annotations and translate the per-dimension sparsity
// information for all tensors to loop indices in the kernel.
assert(op.getNumOutputs() == 1);
assert(llvm::none_of(op.getInputAndOutputOperands(),
[&](OpOperand *t) { return op.isScalar(t); }));
unsigned numTensors = op.getNumInputsAndOutputs();
unsigned numLoops = op.iterator_types().getValue().size();
Merger merger(numTensors, numLoops);
14 changes: 8 additions & 6 deletions mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -2,6 +2,7 @@

#accesses = [
affine_map<(i, j, k, l, m) -> (i, k, m)>,
affine_map<(i, j, k, l, m) -> ()>,
affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
]

@@ -11,21 +12,22 @@
library_call = "some_external_func"
}

func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>, %shape: tensor<?x1x?x1x?xf32>) -> tensor<?x1x?x1x?xf32> {
func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>, %arg1 : f32, %shape: tensor<?x1x?x1x?xf32>) -> tensor<?x1x?x1x?xf32> {
%0 = linalg.generic #trait
ins(%arg0 : tensor<?x1x?xf32>)
ins(%arg0, %arg1 : tensor<?x1x?xf32>, f32)
outs(%shape : tensor<?x1x?x1x?xf32>) {
^bb0(%arg2 : f32, %arg3 : f32) :
linalg.yield %arg2 : f32
^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32) :
linalg.yield %arg3 : f32
} -> tensor<?x1x?x1x?xf32>
return %0 : tensor<?x1x?x1x?xf32>
}
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @drop_one_trip_loops
// CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0, 1], [2]]
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP2]], #[[$MAP3]]]
// CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP3]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
// CHECK: linalg.tensor_expand_shape %{{.*}} {{\[}}[0, 1], [2, 3], [4]]

