@@ -1704,9 +1704,157 @@ static LogicalResult verify(linalg::YieldOp op) {
1704
1704
return success ();
1705
1705
}
1706
1706
1707
+ if (auto tiledLoopOp = dyn_cast<linalg::TiledLoopOp>(parentOp)) {
1708
+ return success ();
1709
+ }
1707
1710
return op.emitOpError (" expected parent op with LinalgOp interface" );
1708
1711
}
1709
1712
1713
+ // ===----------------------------------------------------------------------===//
1714
+ // TiledLoopOp
1715
+ // ===----------------------------------------------------------------------===//
1716
+
1717
+ static void print (OpAsmPrinter &p, TiledLoopOp op) {
1718
+ p << op.getOperationName () << " (" << op.getBody ()->getArguments () << " ) = ("
1719
+ << op.lowerBound () << " ) to (" << op.upperBound () << " ) step (" << op.step ()
1720
+ << " )" ;
1721
+
1722
+ if (!op.inputs ().empty ())
1723
+ p << " ins (" << op.inputs () << " )" ;
1724
+ if (!op.outputs ().empty ())
1725
+ p << " outs (" << op.outputs () << " )" ;
1726
+
1727
+ if (llvm::any_of (op.iterator_types (), [](Attribute attr) {
1728
+ return attr.cast <StringAttr>().getValue () !=
1729
+ getParallelIteratorTypeName ();
1730
+ })) {
1731
+ p << " iterators(" << op.iterator_types () << " )" ;
1732
+ }
1733
+
1734
+ p.printRegion (op.region (), /* printEntryBlockArgs=*/ false );
1735
+ p.printOptionalAttrDict (
1736
+ op.getAttrs (), /* elidedAttrs=*/ {TiledLoopOp::getOperandSegmentSizeAttr (),
1737
+ getIteratorTypesAttrName ()});
1738
+ }
1739
+
1740
+ static ParseResult parseTiledLoopOp (OpAsmParser &parser,
1741
+ OperationState &result) {
1742
+ auto &builder = parser.getBuilder ();
1743
+ // Parse an opening `(` followed by induction variables followed by `)`
1744
+ SmallVector<OpAsmParser::OperandType, 4 > ivs;
1745
+ if (parser.parseRegionArgumentList (ivs, /* requiredOperandCount=*/ -1 ,
1746
+ OpAsmParser::Delimiter::Paren))
1747
+ return failure ();
1748
+
1749
+ // Parse loop bounds.
1750
+ SmallVector<OpAsmParser::OperandType, 4 > lower;
1751
+ if (parser.parseEqual () ||
1752
+ parser.parseOperandList (lower, ivs.size (),
1753
+ OpAsmParser::Delimiter::Paren) ||
1754
+ parser.resolveOperands (lower, builder.getIndexType (), result.operands ))
1755
+ return failure ();
1756
+
1757
+ SmallVector<OpAsmParser::OperandType, 4 > upper;
1758
+ if (parser.parseKeyword (" to" ) ||
1759
+ parser.parseOperandList (upper, ivs.size (),
1760
+ OpAsmParser::Delimiter::Paren) ||
1761
+ parser.resolveOperands (upper, builder.getIndexType (), result.operands ))
1762
+ return failure ();
1763
+
1764
+ // Parse step values.
1765
+ SmallVector<OpAsmParser::OperandType, 4 > steps;
1766
+ if (parser.parseKeyword (" step" ) ||
1767
+ parser.parseOperandList (steps, ivs.size (),
1768
+ OpAsmParser::Delimiter::Paren) ||
1769
+ parser.resolveOperands (steps, builder.getIndexType (), result.operands ))
1770
+ return failure ();
1771
+
1772
+ // Parse input tensors.
1773
+ SmallVector<OpAsmParser::OperandType, 4 > inputs;
1774
+ if (succeeded (parser.parseOptionalKeyword (" ins" ))) {
1775
+ SmallVector<Type, 4 > inputTypes;
1776
+ llvm::SMLoc inputsOperandsLoc = parser.getCurrentLocation ();
1777
+
1778
+ if (parser.parseLParen () || parser.parseOperandList (inputs) ||
1779
+ parser.parseColonTypeList (inputTypes) || parser.parseRParen ())
1780
+ return failure ();
1781
+
1782
+ if (parser.resolveOperands (inputs, inputTypes, inputsOperandsLoc,
1783
+ result.operands ))
1784
+ return failure ();
1785
+ }
1786
+
1787
+ // Parse output tensors.
1788
+ SmallVector<OpAsmParser::OperandType, 4 > outputs;
1789
+ if (succeeded (parser.parseOptionalKeyword (" outs" ))) {
1790
+ SmallVector<Type, 4 > outputTypes;
1791
+ llvm::SMLoc outputsOperandsLoc = parser.getCurrentLocation ();
1792
+
1793
+ if (parser.parseLParen () || parser.parseOperandList (outputs) ||
1794
+ parser.parseColonTypeList (outputTypes) || parser.parseRParen ())
1795
+ return failure ();
1796
+
1797
+ if (parser.resolveOperands (outputs, outputTypes, outputsOperandsLoc,
1798
+ result.operands ))
1799
+ return failure ();
1800
+ result.addTypes (outputTypes);
1801
+ }
1802
+
1803
+ // Parse attributes.
1804
+ SmallVector<Attribute, 4 > iterTypes;
1805
+ if (succeeded (parser.parseOptionalKeyword (" iterators" ))) {
1806
+ StringAttr iterType;
1807
+
1808
+ if (parser.parseLParen () || parser.parseAttribute (iterType))
1809
+ return failure ();
1810
+ iterTypes.push_back (iterType);
1811
+ for (int i = 1 , e = ivs.size (); i < e; ++i) {
1812
+ if (parser.parseComma () || parser.parseAttribute (iterType))
1813
+ return failure ();
1814
+ iterTypes.push_back (iterType);
1815
+ }
1816
+ if (parser.parseRParen ())
1817
+ return failure ();
1818
+ } else {
1819
+ auto parallelIter = builder.getStringAttr (getParallelIteratorTypeName ());
1820
+ iterTypes = SmallVector<Attribute, 4 >(ivs.size (), parallelIter);
1821
+ }
1822
+ result.addAttribute (getIteratorTypesAttrName (),
1823
+ builder.getArrayAttr (iterTypes));
1824
+ result.addAttribute (
1825
+ TiledLoopOp::getOperandSegmentSizeAttr (),
1826
+ builder.getI32VectorAttr ({static_cast <int32_t >(lower.size ()),
1827
+ static_cast <int32_t >(upper.size ()),
1828
+ static_cast <int32_t >(steps.size ()),
1829
+ static_cast <int32_t >(inputs.size ()),
1830
+ static_cast <int32_t >(outputs.size ())}));
1831
+
1832
+ // Parse the body.
1833
+ Region *body = result.addRegion ();
1834
+ SmallVector<Type, 4 > types (ivs.size (), builder.getIndexType ());
1835
+ if (parser.parseRegion (*body, ivs, types))
1836
+ return failure ();
1837
+
1838
+ // Parse optional attributes.
1839
+ parser.parseOptionalAttrDict (result.attributes );
1840
+
1841
+ return success ();
1842
+ }
1843
+
1844
+ Region &TiledLoopOp::getLoopBody () { return region (); }
1845
+
1846
+ LogicalResult TiledLoopOp::moveOutOfLoop (ArrayRef<Operation *> ops) {
1847
+ for (auto *op : ops)
1848
+ op->moveBefore (*this );
1849
+ return success ();
1850
+ }
1851
+
1852
+ bool TiledLoopOp::isDefinedOutsideOfLoop (Value value) {
1853
+ return !region ().isAncestor (value.getParentRegion ());
1854
+ }
1855
+
1856
+ static LogicalResult verify (TiledLoopOp op) { return success (); }
1857
+
1710
1858
// ///// Operations corresponding to library calls defined with Tablegen ////////
1711
1859
1712
1860
template <typename LinalgPoolingOp>
0 commit comments