diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 7ee5d5f4dd744..12a8d80c72fcc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -116,14 +116,14 @@ static VectorType extractVectorTypeFromShapedValue(Value v) {
 /// Linalg. This limitation is motivated by the fact that e.g.
 /// min(max(X)) != max(min(X))
 // TODO: use in LinalgOp verification, there is a circular dependency atm.
-static Operation *getSingleBinaryOpAssumedReduction(OpOperand &outputOperand) {
-  auto linalgOp = cast<LinalgOp>(outputOperand.getOwner());
+static Operation *getSingleBinaryOpAssumedReduction(OpOperand *outputOperand) {
+  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
   auto yieldOp = cast<YieldOp>(linalgOp->getRegion(0).front().getTerminator());
   unsigned yieldNum =
-      outputOperand.getOperandNumber() - linalgOp.getNumInputs();
+      outputOperand->getOperandNumber() - linalgOp.getNumInputs();
   llvm::SetVector<Operation *> backwardSlice, forwardSlice;
   BlockArgument bbArg = linalgOp->getRegion(0).front().getArgument(
-      outputOperand.getOperandNumber());
+      outputOperand->getOperandNumber());
   Value yieldVal = yieldOp->getOperand(yieldNum);
   getBackwardSlice(yieldVal, &backwardSlice, [&](Operation *op) {
     return op->getParentOp() == linalgOp;
@@ -186,16 +186,15 @@ getKindForOp(Operation *reductionOp) {
 /// return a new vector.broadcast to `shape`.
 /// Otherwise, just return value.
 static Value reduceIfNeeded(OpBuilder &b, VectorType targetVectorType,
-                            Value value, OpOperand &outputOperand) {
-  assert(targetVectorType.getShape() ==
-         outputOperand.get().getType().cast<ShapedType>().getShape());
+                            Value value, OpOperand *outputOperand) {
+  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
+  assert(targetVectorType.getShape() == linalgOp.getShape(outputOperand));
   auto vecType = value.getType().dyn_cast<VectorType>();
   if (!vecType || vecType.getShape() == targetVectorType.getShape())
     return value;
   // At this point, we know we need to reduce. Detect the reduction operator.
   // TODO: Use the generic reduction detection util.
   Operation *reductionOp = getSingleBinaryOpAssumedReduction(outputOperand);
-  auto linalgOp = cast<LinalgOp>(outputOperand.getOwner());
   unsigned pos = 0;
   MLIRContext *ctx = b.getContext();
   SmallVector<AffineExpr> exprs;
@@ -235,23 +234,22 @@ static Value buildVectorRead(OpBuilder &b, Value source, VectorType vectorType,
 /// currently being vectorized. If `dest` has null rank, build an memref.store.
 /// Return the produced value or null if no value is produced.
 static Value buildVectorWrite(OpBuilder &b, Value value,
-                              OpOperand &outputOperand) {
+                              OpOperand *outputOperand) {
   Operation *write;
   Location loc = value.getLoc();
-  auto shapedType = outputOperand.get().getType().cast<ShapedType>();
   if (VectorType vectorType =
-          extractVectorTypeFromShapedValue(outputOperand.get())) {
-    auto linalgOp = cast<LinalgOp>(outputOperand.getOwner());
-    AffineMap map = reindexIndexingMap(
-        linalgOp.getIndexingMap(outputOperand.getOperandNumber()));
-    SmallVector<Value> indices(shapedType.getRank(),
+          extractVectorTypeFromShapedValue(outputOperand->get())) {
+    auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
+    AffineMap map =
+        reindexIndexingMap(linalgOp.getTiedIndexingMap(outputOperand));
+    SmallVector<Value> indices(linalgOp.getRank(outputOperand),
                                b.create<ConstantIndexOp>(loc, 0));
     value = broadcastIfNeeded(b, value, vectorType.getShape());
     value = reduceIfNeeded(b, vectorType, value, outputOperand);
-    write = b.create<vector::TransferWriteOp>(loc, value, outputOperand.get(),
+    write = b.create<vector::TransferWriteOp>(loc, value, outputOperand->get(),
                                               indices, map);
   } else {
-    write = b.create<memref::StoreOp>(loc, value, outputOperand.get());
+    write = b.create<memref::StoreOp>(loc, value, outputOperand->get());
   }
   LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: vectorized op: " << *write);
   if (!write->getResults().empty())
@@ -284,7 +282,7 @@ vectorizeLinalgYield(OpBuilder &b, Operation *op,
     // TODO: use a map.
     Value vectorValue = bvm.lookup(outputs.value());
     Value newResult = buildVectorWrite(
-        b, vectorValue, linalgOp.getOutputOpOperands()[outputs.index()]);
+        b, vectorValue, linalgOp.getOutputOperand(outputs.index()));
     if (newResult)
       newResults.push_back(newResult);
   }
@@ -422,8 +420,8 @@ static bool isElementwise(Operation *op) {
   if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
     return false;
   // TODO: relax the restrictions on indexing map.
-  for (unsigned i = 0, e = linalgOp.getNumOutputs(); i < e; i++) {
-    if (!linalgOp.getOutputIndexingMap(i).isIdentity())
+  for (OpOperand *opOperand : linalgOp.getOutputOperands()) {
+    if (!linalgOp.getTiedIndexingMap(opOperand).isIdentity())
       return false;
   }
   if (linalgOp->getNumRegions() != 1)
@@ -479,36 +477,37 @@ LogicalResult vectorizeAsLinalgGeneric(
 
   // 3. Turn all BBArgs into vector.transfer_read / load.
   SmallVector<AffineMap> indexings;
-  for (auto bbarg : block.getArguments()) {
-    Value shapedArg = linalgOp.getShapedOperand(bbarg.getArgNumber());
-    ShapedType shapedType = shapedArg.getType().cast<ShapedType>();
+  for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
+    BlockArgument bbarg = block.getArgument(opOperand->getOperandNumber());
     // TODO: 0-d vectors.
-    if (shapedType.getShape().empty()) {
-      Value loaded = b.create<memref::LoadOp>(linalgOp.getLoc(), shapedArg);
+    if (linalgOp.getShape(opOperand).empty()) {
+      Value loaded =
+          b.create<memref::LoadOp>(linalgOp.getLoc(), opOperand->get());
       LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
                         << bbarg.getArgNumber() << "): " << loaded);
       bvm.map(bbarg, loaded);
-      bvm.map(shapedArg, loaded);
+      bvm.map(opOperand->get(), loaded);
       continue;
     }
     AffineMap map;
     VectorType vectorType;
     if (broadcastToMaximalCommonShape) {
       map = inverseAndBroadcastProjectedPermuation(
-          linalgOp.getIndexingMap(bbarg.getArgNumber()));
-      vectorType =
-          VectorType::get(commonVectorShape, shapedType.getElementType());
+          linalgOp.getTiedIndexingMap(opOperand));
+      vectorType = VectorType::get(
+          commonVectorShape, getElementTypeOrSelf(opOperand->get().getType()));
     } else {
       map = inversePermutation(
-          reindexIndexingMap(linalgOp.getIndexingMap(bbarg.getArgNumber())));
-      vectorType = VectorType::get(map.compose(shapedType.getShape()),
-                                   shapedType.getElementType());
+          reindexIndexingMap(linalgOp.getTiedIndexingMap(opOperand)));
+      vectorType =
+          VectorType::get(map.compose(linalgOp.getShape(opOperand)),
+                          getElementTypeOrSelf(opOperand->get().getType()));
     }
-    Value vectorRead = buildVectorRead(b, shapedArg, vectorType, map);
+    Value vectorRead = buildVectorRead(b, opOperand->get(), vectorType, map);
     LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vectorized bbarg("
                       << bbarg.getArgNumber() << "): " << vectorRead);
     bvm.map(bbarg, vectorRead);
-    bvm.map(shapedArg, vectorRead);
+    bvm.map(opOperand->get(), vectorRead);
   }
 
   auto hooks = llvm::to_vector<4>(customVectorizationHooks);
@@ -562,7 +561,8 @@ static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
           const BlockAndValueMapping &bvm) -> VectorizationResult {
     if (!isa<MulIOp, MulFOp>(op))
       return VectorizationResult{VectorizationStatus::Failure, nullptr};
-    auto outShape = linalgOp.getOutputShapedType(0).getShape();
+    ArrayRef<int64_t> outShape =
+        linalgOp.getShape(linalgOp.getOutputOperand(0));
     auto vType = outShape.empty()
                      ? op->getResult(0).getType()
                      : VectorType::get(outShape, op->getResult(0).getType());
@@ -574,13 +574,14 @@ static LogicalResult vectorizeContraction(OpBuilder &b, LinalgOp linalgOp,
     // TODO: consider dropping contraction special casing altogether, this will
     // require more advanced canonicalizations involving vector.multi_reduction
     // that are not yet available.
-    SmallVector<AffineMap> indexingMaps{
-        inversePermutation(reindexIndexingMap(linalgOp.getIndexingMap(0)))
-            .compose(linalgOp.getIndexingMap(0)),
-        inversePermutation(reindexIndexingMap(linalgOp.getIndexingMap(1)))
-            .compose(linalgOp.getIndexingMap(1)),
-        inversePermutation(reindexIndexingMap(linalgOp.getIndexingMap(2)))
-            .compose(linalgOp.getIndexingMap(2))};
+    SmallVector<AffineMap> indexingMaps;
+    indexingMaps.reserve(linalgOp.getNumInputsAndOutputs());
+    llvm::transform(linalgOp.getIndexingMaps(),
+                    std::back_inserter(indexingMaps),
+                    [](AffineMap indexingMap) {
+                      return inversePermutation(reindexIndexingMap(indexingMap))
+                          .compose(indexingMap);
+                    });
     Operation *contract = b.create<vector::ContractionOp>(
         loc, bvm.lookup(op->getOperand(0)), bvm.lookup(op->getOperand(1)), zero,
         b.getAffineMapArrayAttr(indexingMaps), linalgOp.iterator_types());
@@ -601,8 +602,8 @@ static bool allIndexingsAreProjectedPermutation(LinalgOp op) {
 static LogicalResult reductionPreconditions(LinalgOp op) {
   if (llvm::none_of(op.iterator_types(), isReductionIteratorType))
     return failure();
-  for (auto &operand : op.getOutputOpOperands()) {
-    Operation *reductionOp = getSingleBinaryOpAssumedReduction(operand);
+  for (OpOperand *opOperand : op.getOutputOperands()) {
+    Operation *reductionOp = getSingleBinaryOpAssumedReduction(opOperand);
     if (!getKindForOp(reductionOp))
       return failure();
   }
@@ -612,12 +613,8 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
 LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
   auto linalgOp = cast<LinalgOp>(op);
   // All types must be static shape to go to vector.
-  for (Value operand : linalgOp.getShapedOperands())
-    if (!operand.getType().cast<ShapedType>().hasStaticShape())
-      return failure();
-  for (Type outputTensorType : linalgOp.getOutputTensorTypes())
-    if (!outputTensorType.cast<ShapedType>().hasStaticShape())
-      return failure();
+  if (linalgOp.hasDynamicShape())
+    return failure();
   if (isElementwise(op))
     return success();
   if (isaContractionOpInterface(linalgOp))
@@ -722,13 +719,14 @@ LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
   Location loc = op.getLoc();
   MLIRContext *context = op.getContext();
 
-  ShapedType inShapeType = op.getInputShapedType(0);
-  ShapedType kShapeType = op.getInputShapedType(1);
-
-  ArrayRef<int64_t> inShape = inShapeType.getShape();
-  ArrayRef<int64_t> kShape = kShapeType.getShape();
+  OpOperand *input = op.getInputOperand(0);
+  OpOperand *kernel = op.getInputOperand(1);
+  OpOperand *output = op.getOutputOperand(0);
+  ArrayRef<int64_t> inShape = op.getShape(input);
+  ArrayRef<int64_t> kShape = op.getShape(kernel);
 
-  if (!inShapeType.hasStaticShape() || !kShapeType.hasStaticShape())
+  if (llvm::any_of(inShape, ShapedType::isDynamic) ||
+      llvm::any_of(kShape, ShapedType::isDynamic))
     return failure();
 
   SmallVector<AffineExpr, 4> mapping;
@@ -747,22 +745,18 @@ LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
     }
   }
 
-  Value input = op.getInput(0);
-  Value kernel = op.getInput(1);
-  Value output = op.getOutputBuffer(0);
-
-  unsigned rank = inShapeType.getRank();
-  unsigned numDims = mapping.size();
-  Type elemType = inShapeType.getElementType();
+  int64_t rank = op.getRank(input);
+  int64_t numDims = mapping.size();
+  Type elemType = getElementTypeOrSelf(input->get().getType());
 
   auto map = AffineMap::get(rank, 0, mapping, context);
   SmallVector<Value, 4> zeros(rank, rewriter.create<ConstantIndexOp>(loc, 0));
   auto vecType = VectorType::get(vectorDims, elemType);
 
-  auto inputVec =
-      rewriter.create<vector::TransferReadOp>(loc, vecType, input, zeros, map);
-  auto kernelVec =
-      rewriter.create<vector::TransferReadOp>(loc, vecType, kernel, zeros, map);
+  auto inputVec = rewriter.create<vector::TransferReadOp>(
+      loc, vecType, input->get(), zeros, map);
+  auto kernelVec = rewriter.create<vector::TransferReadOp>(
+      loc, vecType, kernel->get(), zeros, map);
 
   auto acc = rewriter.create<ConstantOp>(loc, elemType,
                                          rewriter.getZeroAttr(elemType));
@@ -779,7 +773,8 @@ LogicalResult ConvOpVectorization<ConvOp, N>::matchAndRewrite(
       rewriter.getAffineMapArrayAttr(indexingMaps),
       rewriter.getStrArrayAttr(iteratorTypes));
 
-  rewriter.create<vector::TransferWriteOp>(loc, result, output, ValueRange(zeros));
+  rewriter.create<vector::TransferWriteOp>(loc, result, output->get(),
+                                           ValueRange(zeros));
   rewriter.eraseOp(op);
   return success();
 }
@@ -939,7 +934,8 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
   CopyOp copyOp;
   for (auto &u : subView.getUses()) {
     if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
-      if (newCopyOp.getOutputBuffer(0) != subView)
+      assert(newCopyOp.output().getType().isa<MemRefType>());
+      if (newCopyOp.output() != subView)
        continue;
      LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
                              << "copy candidate " << *newCopyOp);
@@ -958,7 +954,8 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
   FillOp maybeFillOp;
   for (auto &u : viewOrAlloc.getUses()) {
     if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
-      if (newFillOp.getOutputBuffer(0) != viewOrAlloc)
+      assert(newFillOp.output().getType().isa<MemRefType>());
+      if (newFillOp.output() != viewOrAlloc)
         continue;
       LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
                               << "fill candidate " << *newFillOp);
@@ -976,7 +973,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
                           << "with maybeFillOp " << *maybeFillOp);
 
   // `in` is the subview that linalg.copy reads. Replace it.
-  Value in = copyOp.getInput(0);
+  Value in = copyOp.input();
 
   // linalg.copy + linalg.fill can be used to create a padded local buffer.
   // The `masked` attribute is only valid on this padded buffer.
@@ -1014,7 +1011,7 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
   CopyOp copyOp;
   for (auto &u : subViewOp.getResult().getUses()) {
     if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
-      if (newCopyOp.getInput(0) != subView)
+      if (newCopyOp.getInputOperand(0)->get() != subView)
         continue;
       if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
         continue;
@@ -1026,7 +1023,8 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
     return failure();
 
   // `out` is the subview copied into that we replace.
-  Value out = copyOp.getOutputBuffer(0);
+  assert(copyOp.output().getType().isa<MemRefType>());
+  Value out = copyOp.output();
 
   // Forward vector.transfer into copy.
   // linalg.copy + linalg.fill can be used to create a padded local buffer.
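
Note (not part of the patch): the hunks above migrate Vectorization.cpp from positional accessors such as getIndexingMap(i), getShapedOperand(i), getOutputShapedType(i), and getOutputBuffer(i) to the OpOperand*-based LinalgOp interface (getInputOperand/getOutputOperand, getTiedIndexingMap, getShape, getRank). A minimal sketch of the new traversal style is shown below; it assumes a LinalgOp `linalgOp` is already in scope, and the variable names are purely illustrative.

  // Illustrative sketch, not part of the patch: walk every operand through the
  // OpOperand*-based interface instead of indexing inputs/outputs by position.
  for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
    // Indexing map and static shape tied to this particular operand.
    AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand);
    ArrayRef<int64_t> shape = linalgOp.getShape(opOperand);
    // The underlying SSA value and its element type (tensor or memref alike).
    Value operandValue = opOperand->get();
    Type elementType = getElementTypeOrSelf(operandValue.getType());
    (void)indexingMap;
    (void)shape;
    (void)elementType;
  }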