-
Notifications
You must be signed in to change notification settings - Fork 10.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[mlir][memref] Bufferize memref.tensor_store op
This change adds the BufferizableOpInterface implementation for memref.tensor_store. Differential Revision: https://reviews.llvm.org/D144080
- Loading branch information
1 parent
01581e2
commit c645eb0
Showing
7 changed files
with
151 additions
and
0 deletions.
There are no files selected for viewing
21 changes: 21 additions & 0 deletions
21
mlir/include/mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_DIALECT_MEMREF_BUFFERIZABLEOPINTERFACEIMPL_H | ||
#define MLIR_DIALECT_MEMREF_BUFFERIZABLEOPINTERFACEIMPL_H | ||
|
||
namespace mlir { | ||
|
||
class DialectRegistry; | ||
|
||
namespace memref { | ||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry); | ||
} // namespace memref | ||
} // namespace mlir | ||
|
||
#endif // MLIR_DIALECT_MEMREF_BUFFERIZABLEOPINTERFACEIMPL_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
63 changes: 63 additions & 0 deletions
63
mlir/lib/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h" | ||
|
||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" | ||
#include "mlir/Dialect/MemRef/IR/MemRef.h" | ||
#include "mlir/IR/Dialect.h" | ||
#include "mlir/IR/Operation.h" | ||
|
||
using namespace mlir; | ||
using namespace mlir::bufferization; | ||
|
||
namespace { | ||
/// Bufferization of memref.tensor_store. Replace with memref.copy. | ||
struct TensorStoreOpInterface | ||
: public BufferizableOpInterface::ExternalModel<TensorStoreOpInterface, | ||
memref::TensorStoreOp> { | ||
AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, | ||
const AnalysisState &state) const { | ||
return {}; | ||
} | ||
|
||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, | ||
const AnalysisState &state) const { | ||
assert(opOperand.getOperandNumber() == 0 && "expected src operand"); | ||
return true; | ||
} | ||
|
||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, | ||
const AnalysisState &state) const { | ||
// The memref operand is written but not the tensor operand. | ||
assert(opOperand.getOperandNumber() == 0 && "expected src operand"); | ||
return false; | ||
} | ||
|
||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, | ||
const BufferizationOptions &options) const { | ||
auto tensorStoreOp = cast<memref::TensorStoreOp>(op); | ||
auto srcBuffer = getBuffer(rewriter, tensorStoreOp.getTensor(), options); | ||
if (failed(srcBuffer)) | ||
return failure(); | ||
if (failed(options.createMemCpy(rewriter, op->getLoc(), *srcBuffer, | ||
tensorStoreOp.getMemref()))) | ||
return failure(); | ||
rewriter.eraseOp(tensorStoreOp); | ||
return success(); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
void mlir::memref::registerBufferizableOpInterfaceExternalModels( | ||
DialectRegistry ®istry) { | ||
registry.addExtension(+[](MLIRContext *ctx, MemRefDialect *dialect) { | ||
TensorStoreOp::attachInterface<TensorStoreOpInterface>(*ctx); | ||
}); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
// RUN: mlir-opt -one-shot-bufferize %s | FileCheck %s | ||
|
||
// CHECK-LABEL: func @tensor_store( | ||
// CHECK-SAME: %[[t:.*]]: tensor<?xf32>, %[[m:.*]]: memref<?xf32> | ||
// CHECK: %[[src:.*]] = bufferization.to_memref %[[t]] | ||
// CHECK: memref.copy %[[src]], %[[m]] | ||
// CHECK: return | ||
func.func @tensor_store(%t: tensor<?xf32>, %m: memref<?xf32>) { | ||
memref.tensor_store %t, %m : memref<?xf32> | ||
return | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters