
Commit 624fccb

[mlir] Add linalg.tiled_loop op.
`subtensor_insert` was used instead of `linalg.subtensor_yield` to make this PR smaller. Verification will be added in a follow-up PR.

Differential Revision: https://reviews.llvm.org/D96943
1 parent b80357d commit 624fccb
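For context, the pattern the commit message refers to is visible in the examples added below: instead of a dedicated `linalg.subtensor_yield`, the loop body writes each computed tile back with `subtensor_insert` and then yields the full tensor. A minimal sketch of that body tail, with hypothetical value names and the shapes used in the example:

```mlir
// Write the computed tile back into the full output tensor, then yield the
// whole tensor from the loop body (hypothetical names; shapes as in the
// example added below).
%updated = subtensor_insert %tile into %out[%i, 0] [%c4, %c64] [1, 1]
    : tensor<?x?xi8> into tensor<24x64xi8>
linalg.yield %updated : tensor<24x64xi8>
```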

File tree: 3 files changed, +310 -0 lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 55 additions & 0 deletions
@@ -15,6 +15,7 @@

 include "mlir/Dialect/Linalg/IR/LinalgBase.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"

@@ -485,4 +486,58 @@ def Linalg_YieldOp : Linalg_Op<"yield", [NoSideEffect, ReturnLike, Terminator]>,
   }];
 }

+def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
+     AttrSizedOperandSegments,
+     DeclareOpInterfaceMethods<LoopLikeOpInterface>,
+     RecursiveSideEffects,
+     SingleBlockImplicitTerminator<"linalg::YieldOp">
+   ]> {
+  let summary = "Linalg tiled loop operation";
+  let description = [{
+    This is a loop-like operation with additional properties. The arguments
+    also include the input and the output tensors and the attributes to specify
+    the iterator types. The body region of the loop contains `subtensor`
+    operations applied to every tensor argument of TiledLoopOp.
+
+    The body region must contain exactly one block that terminates with
+    `linalg.yield` with the operands resulting from `subtensor_insert`
+    operations.
+
+    Parsing TiledLoopOp will set all elements of the `iterator_types` attribute
+    to "parallel" type, when it is absent from the custom format.
+
+    Example:
+
+    ```mlir
+    linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
+        ins(%lhs, %rhs : tensor<24x64xi8>, tensor<24x64xi8>)
+        outs(%out : tensor<24x64xi8>)
+        iterators("parallel") {
+      %lhs_sub = subtensor %lhs[%i, 0] [%c4, %c64] [1, 1]
+        : tensor<24x64xi8> to tensor<?x?xi8>
+      %rhs_sub = subtensor %rhs[%i, 0] [%c4, %c64] [1, 1]
+        : tensor<24x64xi8> to tensor<?x?xi8>
+      %out_sub = subtensor %out[%i, 0] [%c4, %c64] [1, 1]
+        : tensor<24x64xi8> to tensor<?x?xi8>
+
+      %result_sub = linalg.generic ...
+
+      %result = subtensor_insert %result_sub into %out[%i, 0][%c4, %c64][1, 1]
+        : tensor<?x?xi8> into tensor<24x64xi8>
+      linalg.yield %result : tensor<24x64xi8>
+    }
+    ```
+  }];
+
+  let arguments = (ins Variadic<Index>:$lowerBound,
+                       Variadic<Index>:$upperBound,
+                       Variadic<Index>:$step,
+                       Variadic<AnyRankedTensor>:$inputs,
+                       Variadic<AnyRankedTensor>:$outputs,
+                       ArrayAttr:$iterator_types);
+  let results = (outs Variadic<AnyRankedTensor>:$results);
+  let regions = (region SizedRegion<1>:$region);
+}
+
 #endif // LINALG_OPS
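Note that `$results` has no dedicated clause in the custom syntax: the parser added in LinalgOps.cpp below reuses the `outs` types as result types, so the op returns one tensor per output operand. A minimal sketch, with hypothetical value names:

```mlir
// One result per `outs` tensor, with the same type.
%res = linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
    ins(%lhs : tensor<24x64xi8>)
    outs(%out : tensor<24x64xi8>) {
  ...
  linalg.yield %full : tensor<24x64xi8>
}
// %res : tensor<24x64xi8>
```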

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 148 additions & 0 deletions
@@ -1704,9 +1704,157 @@ static LogicalResult verify(linalg::YieldOp op) {
     return success();
   }

+  if (auto tiledLoopOp = dyn_cast<linalg::TiledLoopOp>(parentOp)) {
+    return success();
+  }
   return op.emitOpError("expected parent op with LinalgOp interface");
 }

+//===----------------------------------------------------------------------===//
+// TiledLoopOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter &p, TiledLoopOp op) {
+  p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = ("
+    << op.lowerBound() << ") to (" << op.upperBound() << ") step (" << op.step()
+    << ")";
+
+  if (!op.inputs().empty())
+    p << " ins (" << op.inputs() << ")";
+  if (!op.outputs().empty())
+    p << " outs (" << op.outputs() << ")";
+
+  if (llvm::any_of(op.iterator_types(), [](Attribute attr) {
+        return attr.cast<StringAttr>().getValue() !=
+               getParallelIteratorTypeName();
+      })) {
+    p << " iterators(" << op.iterator_types() << ")";
+  }
+
+  p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
+  p.printOptionalAttrDict(
+      op.getAttrs(), /*elidedAttrs=*/{TiledLoopOp::getOperandSegmentSizeAttr(),
+                                      getIteratorTypesAttrName()});
+}
+
+static ParseResult parseTiledLoopOp(OpAsmParser &parser,
+                                    OperationState &result) {
+  auto &builder = parser.getBuilder();
+  // Parse an opening `(` followed by induction variables followed by `)`
+  SmallVector<OpAsmParser::OperandType, 4> ivs;
+  if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
+                                     OpAsmParser::Delimiter::Paren))
+    return failure();
+
+  // Parse loop bounds.
+  SmallVector<OpAsmParser::OperandType, 4> lower;
+  if (parser.parseEqual() ||
+      parser.parseOperandList(lower, ivs.size(),
+                              OpAsmParser::Delimiter::Paren) ||
+      parser.resolveOperands(lower, builder.getIndexType(), result.operands))
+    return failure();
+
+  SmallVector<OpAsmParser::OperandType, 4> upper;
+  if (parser.parseKeyword("to") ||
+      parser.parseOperandList(upper, ivs.size(),
+                              OpAsmParser::Delimiter::Paren) ||
+      parser.resolveOperands(upper, builder.getIndexType(), result.operands))
+    return failure();
+
+  // Parse step values.
+  SmallVector<OpAsmParser::OperandType, 4> steps;
+  if (parser.parseKeyword("step") ||
+      parser.parseOperandList(steps, ivs.size(),
+                              OpAsmParser::Delimiter::Paren) ||
+      parser.resolveOperands(steps, builder.getIndexType(), result.operands))
+    return failure();
+
+  // Parse input tensors.
+  SmallVector<OpAsmParser::OperandType, 4> inputs;
+  if (succeeded(parser.parseOptionalKeyword("ins"))) {
+    SmallVector<Type, 4> inputTypes;
+    llvm::SMLoc inputsOperandsLoc = parser.getCurrentLocation();
+
+    if (parser.parseLParen() || parser.parseOperandList(inputs) ||
+        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
+      return failure();
+
+    if (parser.resolveOperands(inputs, inputTypes, inputsOperandsLoc,
+                               result.operands))
+      return failure();
+  }
+
+  // Parse output tensors.
+  SmallVector<OpAsmParser::OperandType, 4> outputs;
+  if (succeeded(parser.parseOptionalKeyword("outs"))) {
+    SmallVector<Type, 4> outputTypes;
+    llvm::SMLoc outputsOperandsLoc = parser.getCurrentLocation();
+
+    if (parser.parseLParen() || parser.parseOperandList(outputs) ||
+        parser.parseColonTypeList(outputTypes) || parser.parseRParen())
+      return failure();
+
+    if (parser.resolveOperands(outputs, outputTypes, outputsOperandsLoc,
+                               result.operands))
+      return failure();
+    result.addTypes(outputTypes);
+  }
+
+  // Parse attributes.
+  SmallVector<Attribute, 4> iterTypes;
+  if (succeeded(parser.parseOptionalKeyword("iterators"))) {
+    StringAttr iterType;
+
+    if (parser.parseLParen() || parser.parseAttribute(iterType))
+      return failure();
+    iterTypes.push_back(iterType);
+    for (int i = 1, e = ivs.size(); i < e; ++i) {
+      if (parser.parseComma() || parser.parseAttribute(iterType))
+        return failure();
+      iterTypes.push_back(iterType);
+    }
+    if (parser.parseRParen())
+      return failure();
+  } else {
+    auto parallelIter = builder.getStringAttr(getParallelIteratorTypeName());
+    iterTypes = SmallVector<Attribute, 4>(ivs.size(), parallelIter);
+  }
+  result.addAttribute(getIteratorTypesAttrName(),
+                      builder.getArrayAttr(iterTypes));
+  result.addAttribute(
+      TiledLoopOp::getOperandSegmentSizeAttr(),
+      builder.getI32VectorAttr({static_cast<int32_t>(lower.size()),
+                                static_cast<int32_t>(upper.size()),
+                                static_cast<int32_t>(steps.size()),
+                                static_cast<int32_t>(inputs.size()),
+                                static_cast<int32_t>(outputs.size())}));
+
+  // Parse the body.
+  Region *body = result.addRegion();
+  SmallVector<Type, 4> types(ivs.size(), builder.getIndexType());
+  if (parser.parseRegion(*body, ivs, types))
+    return failure();
+
+  // Parse optional attributes.
+  parser.parseOptionalAttrDict(result.attributes);
+
+  return success();
+}
+
+Region &TiledLoopOp::getLoopBody() { return region(); }
+
+LogicalResult TiledLoopOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
+  for (auto *op : ops)
+    op->moveBefore(*this);
+  return success();
+}
+
+bool TiledLoopOp::isDefinedOutsideOfLoop(Value value) {
+  return !region().isAncestor(value.getParentRegion());
+}
+
+static LogicalResult verify(TiledLoopOp op) { return success(); }
+
 /////// Operations corresponding to library calls defined with Tablegen ////////

 template <typename LinalgPoolingOp>
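As the parser and printer above imply, the `iterators(...)` clause is optional: when it is omitted, every induction variable gets the "parallel" iterator type, and a loop whose iterator types are all "parallel" is printed without the clause, so the two forms round-trip to the same op. A small sketch using the operands from the roundtrip test that follows:

```mlir
// Explicit form.
%r0 = linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
    ins(%lhs, %rhs : tensor<24x64xi8>, tensor<24x64xi8>)
    outs(%out : tensor<24x64xi8>)
    iterators("parallel") { ... }

// Equivalent form: parsing fills `iterator_types` with "parallel" and the
// printer elides the clause again.
%r1 = linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
    ins(%lhs, %rhs : tensor<24x64xi8>, tensor<24x64xi8>)
    outs(%out : tensor<24x64xi8>) { ... }
```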

mlir/test/Dialect/Linalg/roundtrip.mlir

Lines changed: 107 additions & 0 deletions
@@ -794,3 +794,110 @@ func @fill_tensor(%arg0 : index, %arg1 : index, %arg2 : f32) -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
 // CHECK: %{{.+}} = linalg.fill(%{{.+}}, %{{.+}}) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
+
+// -----
+
+#accesses = [
+  affine_map<(i, j) -> (i, j)>,
+  affine_map<(i, j) -> (i, j)>,
+  affine_map<(i, j) -> (i, j)>
+]
+
+#trait = {
+  indexing_maps = #accesses,
+  iterator_types = ["parallel", "parallel"]
+}
+
+func @tiled_loop(%lhs: tensor<24x64xi8>, %rhs: tensor<24x64xi8>,
+                 %out: tensor<24x64xi8>) -> tensor<24x64xi8> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %c24 = constant 24 : index
+  %c64 = constant 64 : index
+  %prod = linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
+      ins(%lhs, %rhs : tensor<24x64xi8>, tensor<24x64xi8>)
+      outs(%out : tensor<24x64xi8>) {
+    %lhs_sub = subtensor %lhs[%i, 0] [%c4, %c64] [1, 1]
+      : tensor<24x64xi8> to tensor<?x?xi8>
+    %rhs_sub = subtensor %rhs[%i, 0] [%c4, %c64] [1, 1]
+      : tensor<24x64xi8> to tensor<?x?xi8>
+    %out_sub = subtensor %out[%i, 0] [%c4, %c64] [1, 1]
+      : tensor<24x64xi8> to tensor<?x?xi8>
+
+    %sum = linalg.generic #trait
+      ins(%lhs_sub, %rhs_sub : tensor<?x?xi8>, tensor<?x?xi8>)
+      outs(%out_sub : tensor<?x?xi8>) {
+      ^bb(%l: i8, %r: i8, %o: i8) :
+        %s = addi %l, %r : i8
+        linalg.yield %s : i8
+      } -> tensor<?x?xi8>
+
+    %sum_sub = subtensor_insert %sum into %out[%i, 0][%c4, %c64][1, 1]
+      : tensor<?x?xi8> into tensor<24x64xi8>
+    linalg.yield %sum_sub : tensor<24x64xi8>
+  }
+  return %prod : tensor<24x64xi8>
+}
+// CHECK-LABEL: func @tiled_loop
+// CHECK-NOT: iterators(
+
+// -----
+
+#id_3d = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#id_2d = affine_map<(d0, d1, d2) -> (d0, d2)>
+#id_1d = affine_map<(d0, d1, d2) -> (d1)>
+
+#trait = {
+  indexing_maps = [
+    #id_3d,
+    #id_2d,
+    #id_1d,
+    #id_1d
+  ],
+  iterator_types = ["reduction", "parallel", "reduction"]
+}
+
+func @tiled_loop_reduction(%input_3d: tensor<16x24x32xf32>,
+                           %input_2d: tensor<16x32xf32>,
+                           %input_1d: tensor<24xf32>,
+                           %output: tensor<24xf32>) -> tensor<24xf32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c4 = constant 4 : index
+  %c8 = constant 8 : index
+  %X = dim %input_3d, %c0 : tensor<16x24x32xf32>
+  %Y = dim %input_3d, %c1 : tensor<16x24x32xf32>
+  %Z = dim %input_3d, %c2 : tensor<16x24x32xf32>
+  %result = linalg.tiled_loop (%i, %j, %k)
+      = (%c0, %c0, %c0) to (%X, %Y, %Z) step (%c2, %c4, %c8)
+      ins(%input_3d, %input_2d: tensor<16x24x32xf32>, tensor<16x32xf32>)
+      outs( %output: tensor<24xf32>)
+      iterators("reduction", "parallel", "reduction") {
+    %sub_3d = subtensor %input_3d[%i, %j, %k][2, 4, 8][1, 1, 1]
+      : tensor<16x24x32xf32> to tensor<2x4x8xf32>
+    %sub_2d = subtensor %input_2d[%i, %k][2, 8][1, 1]
+      : tensor<16x32xf32> to tensor<2x8xf32>
+    %sub_1d = subtensor %input_1d[%j] [4] [1]
+      : tensor<24xf32> to tensor<4xf32>
+    %sub_out = subtensor %output[%j] [4] [1]
+      : tensor<24xf32> to tensor<4xf32>
+    %acc = linalg.generic #trait
+      ins(%sub_3d, %sub_2d, %sub_1d
+        : tensor<2x4x8xf32>, tensor<2x8xf32>, tensor<4xf32>)
+      outs(%sub_out : tensor<4xf32>) {
+      ^bb0(%i3d: f32, %i2d: f32, %i1d: f32, %o: f32):
+        %0 = addf %i3d, %i2d : f32
+        %1 = addf %0, %i1d : f32
+        linalg.yield %1 : f32
+      } -> tensor<4xf32>
+
+    %sum_sub = subtensor_insert %acc into %output[%j][%c4][1]
+      : tensor<4xf32> into tensor<24xf32>
+    linalg.yield %sum_sub : tensor<24xf32>
+  }
+  return %result : tensor<24xf32>
+}
+// CHECK-LABEL: func @tiled_loop_reduction
+// CHECK: iterators(
