[NFC] Updated for StructuredOpInterface refactor (#6061)
The StructuredOpInterface was updated to expose a more limited, OpOperand-based operand set. Updated
its uses within IREE to match (a sketch of the accessor change follows the changed-files summary below).
rsuderman committed Jun 8, 2021
1 parent 5b958e7 commit 96e2a29
Showing 7 changed files with 70 additions and 52 deletions.
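
For orientation, the refactor replaces the interface's index-based value accessors with accessors that return OpOperand*, so the underlying Value (and anything derived from it, such as its type or indexing map) is reached through the operand. A minimal before/after sketch of the pattern, using only accessor names that appear in the hunks below (the `op` variable, its enclosing function, and the usual MLIR includes/usings are assumed for illustration, not taken from this commit):

  // Old interface: index-based accessors returned Values and ShapedTypes.
  //   Value input = op.getInput(0);
  //   Value output = op.getOutput(0);
  //   ShapedType outputType = op.getOutputShapedType(0);
  //   AffineMap map = op.getInputIndexingMap(0);
  // New interface: accessors return OpOperand*; the Value is read via get(),
  // and per-operand queries take the OpOperand* instead of a numeric index.
  Value input = op.getInputOperand(0)->get();
  Value output = op.getOutputOperand(0)->get();
  ShapedType outputType =
      op.getOutputOperand(0)->get().getType().cast<ShapedType>();
  AffineMap map = op.getTiedIndexingMap(op.getInputOperand(0));
  bool reads = op.payloadUsesValueFromOperand(op.getOutputOperand(0));

The same substitution is applied mechanically in each file below.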
50 changes: 27 additions & 23 deletions iree/compiler/Conversion/Common/LinalgBufferizePass.cpp
@@ -328,34 +328,35 @@ static LogicalResult analysePadTensorOp(linalg::PadTensorOp padTensorOp,
 static SmallVector<Value> getTiedOperandsForLinalgOps(
     linalg::LinalgOp linalgOp) {
   SmallVector<Value> tiedOperands(linalgOp.getOperation()->getNumResults());
-  for (auto outTensor : llvm::enumerate(linalgOp.getOutputs())) {
-    if (linalgOp.payloadUsesValueFromOutputOperandIndex(outTensor.index())) {
+  auto outputOperands = linalgOp.getOutputOperands();
+  for (auto outTensor : llvm::enumerate(outputOperands)) {
+    if (linalgOp.payloadUsesValueFromOperand(outTensor.value())) {
       // If the `outs` tensor has a single use (this op) and is not from a
       // read-only buffer, the `outs` tensor can be tied to the result.
-      if (outTensor.value().hasOneUse() &&
-          !isFromReadOnlyTensor(outTensor.value())) {
-        tiedOperands[outTensor.index()] = outTensor.value();
+      if (outTensor.value()->get().hasOneUse() &&
+          !isFromReadOnlyTensor(outTensor.value()->get())) {
+        tiedOperands[outTensor.index()] = outTensor.value()->get();
       }
     }
   }
-  for (auto result : llvm::enumerate(linalgOp.getOutputs())) {
+  for (auto result : llvm::enumerate(outputOperands)) {
     // If the output tensor is not actually used (for initialization) by this
     // op, we can reuse the result tensor's buffer for some operands.
     // TODO(#5040): A better way to handle this case is to allocate a buffer and
     // then vectorization + load-store forwarding to remove the intermediate
    // buffer. This requires vectorization to handle all cases downstream. This
    // is a WAR for current use cases.
-    if (linalgOp.payloadUsesValueFromOutputOperandIndex(result.index())) {
+    if (linalgOp.payloadUsesValueFromOperand(result.value())) {
       continue;
     }
-    for (auto input : llvm::enumerate(linalgOp.getInputTensors())) {
-      auto producerOp = input.value().getDefiningOp<linalg::LinalgOp>();
-      if (producerOp && input.value().hasOneUse() &&
-          input.value().getType() == result.value().getType() &&
-          linalgOp.getInputIndexingMap(input.index()) ==
-              linalgOp.getOutputIndexingMap(result.index())) {
+    for (auto input : linalgOp.getInputTensorsOpOperands()) {
+      auto producerOp = input->get().getDefiningOp<linalg::LinalgOp>();
+      if (producerOp && input->get().hasOneUse() &&
+          input->get().getType() == result.value()->get().getType() &&
+          linalgOp.getTiedIndexingMap(input) ==
+              linalgOp.getTiedIndexingMap(result.value())) {
         assert(!tiedOperands[result.index()]);
-        tiedOperands[result.index()] = input.value();
+        tiedOperands[result.index()] = input->get();
         break;
       }
     }
@@ -368,15 +369,15 @@ static SmallVector<Value> getTiedOperandsForLinalgOps(
 static LogicalResult analyseLinalgOps(linalg::LinalgOp linalgOp,
                                       BufferizationPlan &plan) {
   if (!linalgOp.hasTensorSemantics()) return success();
+  auto results = linalgOp->getResults();
   auto tiedOperands = getTiedOperandsForLinalgOps(linalgOp);
-  for (auto it :
-       llvm::enumerate(llvm::zip(linalgOp->getResults(), tiedOperands))) {
+  for (auto it : llvm::enumerate(llvm::zip(results, tiedOperands))) {
     Value resultTensor = std::get<0>(it.value());
     Value tiedOperand = std::get<1>(it.value());
     if (tiedOperand) {
       plan.unionSets(resultTensor, tiedOperand);
     }
-    plan.insert(linalgOp.getOutput(it.index()));
+    plan.insert(linalgOp.getOutputOperand(it.index())->get());
     plan.insert(resultTensor);
   }
   return success();
@@ -995,7 +996,8 @@ static LogicalResult getOrAllocateResultBuffers(
     BufferizationPlan &plan, WorkgroupMemoryAllocationFn allocationFn) {
   assert(tiedOperands.size() == op->getNumResults());
   assert(aliasingBuffers.size() == op->getNumResults());
-  for (auto result : llvm::enumerate(op->getResults())) {
+  auto results = op->getResults();
+  for (auto result : llvm::enumerate(results)) {
     if (!result.value().getType().isa<RankedTensorType>() ||
         bvm.contains(result.value())) {
       continue;
@@ -1064,15 +1066,17 @@ static LogicalResult convertAnyLinalgOp(
     newInputBuffers.push_back(inputBuffer);
   }
   SmallVector<Value, 2> newOutputBuffers;
-  for (auto it : llvm::enumerate(
-           llvm::zip(op.getOperation()->getResults(), op.getOutputs()))) {
-    Value resultTensor = std::get<0>(it.value());
+  auto results = op.getOperation()->getResults();
+  auto outputs = op.getOutputOperands();
+  for (auto it : llvm::zip(results, outputs)) {
+    Value resultTensor = std::get<0>(it);
     Value resultBuffer = bvm.lookup(resultTensor);

-    Value outTensor = std::get<1>(it.value());
+    OpOperand *outOperand = std::get<1>(it);
+    Value outTensor = outOperand->get();
     Value outBuffer = bvm.lookupOrNull(outTensor);
     if (outBuffer && !plan.isEquivalent(outTensor, resultTensor) &&
-        op.payloadUsesValueFromOutputOperandIndex(it.index())) {
+        op.payloadUsesValueFromOperand(outOperand)) {
       b.create<linalg::CopyOp>(loc, outBuffer, resultBuffer);
     }
     newOutputBuffers.push_back(resultBuffer);
15 changes: 9 additions & 6 deletions iree/compiler/Conversion/LinalgToLinalg/Conv2D1x1ToMatmul.cpp
@@ -21,9 +21,12 @@ class Convert1x1ConvolutionMatmulOp

   LogicalResult matchAndRewrite(linalg::ConvInputNHWCFilterHWCFOp convOp,
                                 PatternRewriter &rewriter) const override {
-    ShapedType inputShapeType = convOp.getInputShapedType(0);
-    ShapedType filterShapeType = convOp.getInputShapedType(1);
-    ShapedType outputShapeType = convOp.getOutputShapedType(0);
+    ShapedType inputShapeType =
+        convOp.getInputOperand(0)->get().getType().cast<ShapedType>();
+    ShapedType filterShapeType =
+        convOp.getInputOperand(1)->get().getType().cast<ShapedType>();
+    ShapedType outputShapeType =
+        convOp.getOutputOperand(0)->get().getType().cast<ShapedType>();

     auto inputShape = inputShapeType.getShape();
     auto filterShape = filterShapeType.getShape();
@@ -57,9 +60,9 @@ class Convert1x1ConvolutionMatmulOp
         RankedTensorType::get({outputShape[1] * outputShape[2], outputShape[3]},
                               outputShapeType.getElementType());

-    Value input = convOp.getInput(0);
-    Value filter = convOp.getInput(1);
-    Value output = convOp.getOutput(0);
+    Value input = convOp.getInputOperand(0)->get();
+    Value filter = convOp.getInputOperand(1)->get();
+    Value output = convOp.getOutputOperand(0)->get();
     auto loc = convOp.getLoc();

     Value reshapedInput = rewriter.create<linalg::TensorCollapseShapeOp>(
15 changes: 9 additions & 6 deletions iree/compiler/Conversion/LinalgToLinalg/Conv2DToImg2Col.cpp
@@ -50,17 +50,20 @@ class Conv2DImg2ColMatmulConversion

   LogicalResult matchAndRewrite(linalg::ConvInputNHWCFilterHWCFOp convOp,
                                 PatternRewriter &rewriter) const override {
-    ShapedType inputShapeType = convOp.getInputShapedType(0);
-    ShapedType filterShapeType = convOp.getInputShapedType(1);
-    ShapedType outputShapeType = convOp.getOutputShapedType(0);
+    ShapedType inputShapeType =
+        convOp.getInputOperand(0)->get().getType().cast<ShapedType>();
+    ShapedType filterShapeType =
+        convOp.getInputOperand(1)->get().getType().cast<ShapedType>();
+    ShapedType outputShapeType =
+        convOp.getOutputOperand(0)->get().getType().cast<ShapedType>();

     if (!filterShapeType || !inputShapeType) return failure();
     if (!filterShapeType.hasStaticShape() || !inputShapeType.hasStaticShape())
       return failure();

-    Value input = convOp.getInput(0);
-    Value filter = convOp.getInput(1);
-    Value output = convOp.getOutput(0);
+    Value input = convOp.getInputOperand(0)->get();
+    Value filter = convOp.getInputOperand(1)->get();
+    Value output = convOp.getOutputOperand(0)->get();
     auto filterShape = filterShapeType.getShape();
     auto outputShape = outputShapeType.getShape();

@@ -220,7 +220,8 @@ static void removeOneTripTiledLoops(MLIRContext *context, FuncOp funcOp,
   unsigned numTiledDims =
       std::min<size_t>(numParallelDims, kMaxWorkgroupDimCount);

-  ArrayRef<int64_t> outputShape = getUntiledShape(rootLinalgOp.getOutput(0));
+  ArrayRef<int64_t> outputShape =
+      getUntiledShape(rootLinalgOp.getOutputOperand(0)->get());
   if (outputShape.size() < numParallelDims) return;

   // TODO(ravishankarm, antiagainst): Its pure co-incidence that the
13 changes: 8 additions & 5 deletions iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
@@ -102,13 +102,15 @@ static std::tuple<SmallVector<ShapedType>, SmallVector<ShapedType>>
 getInputOutputTypes(linalg::LinalgOp op) {
   SmallVector<ShapedType> inputTypes(op.getNumInputs()),
       outputTypes(op.getNumOutputs());
-  for (auto operand : enumerate(op.getInputOpOperands())) {
+  auto inputOperands = op.getInputOperands();
+  for (auto operand : enumerate(inputOperands)) {
     inputTypes[operand.index()] =
-        getUntiledType(operand.value().get()).dyn_cast<ShapedType>();
+        getUntiledType(operand.value()->get()).dyn_cast<ShapedType>();
   }
-  for (auto operand : enumerate(op.getOutputOpOperands())) {
+  auto outputOperands = op.getOutputOperands();
+  for (auto operand : enumerate(outputOperands)) {
     outputTypes[operand.index()] =
-        getUntiledType(operand.value().get()).dyn_cast<ShapedType>();
+        getUntiledType(operand.value()->get()).dyn_cast<ShapedType>();
   }
   return std::make_tuple(std::move(inputTypes), std::move(outputTypes));
 }
@@ -380,7 +382,8 @@ LogicalResult getGenericOpLaunchConfig(linalg::LinalgOp linalgOp,
   config.workgroupSize[0] = subgroupSize;
   config.workgroupSize[1] = 1;
   config.workgroupSize[2] = 1;
-  ShapedType outputShape = linalgOp.getOutputShapedType(0);
+  ShapedType outputShape =
+      linalgOp.getOutputOperand(0)->get().getType().cast<ShapedType>();

   SmallVector<int64_t, 4> candidateTileSizes;
   // When Vectororization is not enabled we skil the second level of tiling and
15 changes: 9 additions & 6 deletions iree/compiler/Conversion/LinalgToVector/VectorizeConv.cpp
@@ -65,11 +65,11 @@ struct VectorizeLinalgConv
     }

     auto inputViewOp =
-        convOp.getInputBuffer(0).getDefiningOp<memref::SubViewOp>();
+        convOp.getInputOperand(0)->get().getDefiningOp<memref::SubViewOp>();
     auto filterViewOp =
-        convOp.getInputBuffer(1).getDefiningOp<memref::SubViewOp>();
+        convOp.getInputOperand(1)->get().getDefiningOp<memref::SubViewOp>();
     auto outputViewOp =
-        convOp.getOutputBuffer(0).getDefiningOp<memref::SubViewOp>();
+        convOp.getOutputOperand(0)->get().getDefiningOp<memref::SubViewOp>();
     if (!filterViewOp || !inputViewOp || !outputViewOp) return failure();

     // The filter/input/output view should have static sizes to vectorize.
@@ -234,9 +234,12 @@ struct VectorizeLinalgDepthwiseConv
                                 PatternRewriter &rewriter) const override {
     LLVM_DEBUG(llvm::dbgs() << "inspecting " << convOp << "\n");

-    auto inputViewOp = convOp.getInput(0).getDefiningOp<memref::SubViewOp>();
-    auto filterViewOp = convOp.getInput(1).getDefiningOp<memref::SubViewOp>();
-    auto outputViewOp = convOp.getOutput(0).getDefiningOp<memref::SubViewOp>();
+    auto inputViewOp =
+        convOp.getInputOperand(0)->get().getDefiningOp<memref::SubViewOp>();
+    auto filterViewOp =
+        convOp.getInputOperand(1)->get().getDefiningOp<memref::SubViewOp>();
+    auto outputViewOp =
+        convOp.getOutputOperand(0)->get().getDefiningOp<memref::SubViewOp>();
     if (!filterViewOp || !inputViewOp || !outputViewOp) return failure();

     // The filter/input/output view should have static sizes to vectorize.
@@ -361,7 +361,7 @@ static void pullInProducersInSameGroup(
       OpResult opResult = en.value().cast<OpResult>();
       auto maybeFusionInfo = linalg::fuseProducerOfTensor(
           rewriter, clonedOpToFuse->getResult(opResult.getResultNumber()),
-          tiledOp.getShapedOpOperand(en.index()));
+          *tiledOp.getInputAndOutputOperands()[en.index()]);
       if (!maybeFusionInfo.hasValue()) {
         DEBUG_WITH_TYPE(DEBUG_TYPE, llvm::dbgs()
                                         << "failed to fuse with tensor\n");
@@ -543,7 +543,8 @@ static void tryToTieOperandsAndResults(
      return loadOp.source().cast<BlockArgument>();
    } else if (auto linalgOp = dyn_cast_or_null<linalg::LinalgOp>(tieOp)) {
      unsigned resultIndex = storeOp.value().cast<OpResult>().getResultNumber();
-     auto loadOp = linalgOp.getOutputTensors()[resultIndex]
+     auto loadOp = linalgOp.getOutputTensorOperands()[resultIndex]
+                       ->get()
                        .getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
      if (!loadOp) return nullptr;
      return loadOp.source().cast<BlockArgument>();
@@ -1056,9 +1057,9 @@ static unsigned decideFusableLinalgOps(FuncOp funcOp) {
         if (!consumer ||
             consumer.getNumLoops() != consumer.getNumParallelLoops())
           continue;
-        AffineMap consumerIndexingMap =
-            consumer.getInputIndexingMap(use.getOperandNumber());
-        AffineMap producerIndexingMap = linalgOp.getOutputIndexingMap(0);
+        AffineMap consumerIndexingMap = consumer.getTiedIndexingMap(&use);
+        AffineMap producerIndexingMap =
+            linalgOp.getTiedIndexingMap(linalgOp.getOutputOperand(0));
         if (!consumerIndexingMap.isIdentity() ||
             producerIndexingMap.getResults() !=
                 consumerIndexingMap.getResults()) {
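
The fusion check in the hunk above can be read as a single condition; a minimal sketch, assuming `consumer` and `linalgOp` are the consumer and producer linalg ops and `use` is the consumer's OpOperand that reads the producer's result (names taken from the diff; the fragment is illustrative, not verbatim source):

  // Fusable when the consumer indexes the fused operand with an identity map
  // and the producer's single output map yields the same results.
  AffineMap consumerIndexingMap = consumer.getTiedIndexingMap(&use);
  AffineMap producerIndexingMap =
      linalgOp.getTiedIndexingMap(linalgOp.getOutputOperand(0));
  bool fusable = consumerIndexingMap.isIdentity() &&
                 producerIndexingMap.getResults() ==
                     consumerIndexingMap.getResults();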
