[MHLO] Init Torch to MHLO conversion. #1025
Conversation
Thank you. Drive-by comments. I know it is only a draft PR, but I wanted to get comments in early.
@@ -0,0 +1,33 @@
# in-tree build
Probably the scripts in scripts/ are meant to be local scripts and not checked in?
Sounds good. I left it here as an example of how to build the project. Will remove it in the end.
f"-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR={src_dir}",
f"-DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR={src_dir}/externals/llvm-external-projects/torch-mlir-dialects",
f"-DLLVM_EXTERNAL_MLIR_HLO_SOURCE_DIR={src_dir}/externals/mlir-hlo",
f"-DMLIR_PDLL_TABLEGEN_EXE=mlir-pdll",  # FIXME: set MLIR_PDLL_TABLEGEN_EXE since mlir-hlo doesn't
Should we submit a PR upstream in mhlo for setting mlir-pdll?
@@ -95,6 +110,9 @@ else()
set(MLIR_INCLUDE_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include)
set(MLIR_GENERATED_INCLUDE_DIR ${LLVM_BINARY_DIR}/tools/mlir/include)
set(MLIR_INCLUDE_DIRS "${MLIR_INCLUDE_DIR};${MLIR_GENERATED_INCLUDE_DIR}")
# since mhlo didn't set INTERFACE_DIRECTORIES for their target, we need to include mhlo directories globally
Probably good to file an issue upstream so they are aware of this, and maybe, if it is easy, add a PR upstream.
EXCLUDE_FROM_ALL)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR})
Is the whole CMAKE_CURRENT_BINARY_DIR required, since we are adding it globally?
@@ -1,3 +1,6 @@
[submodule "external/llvm-project"]
path = externals/llvm-project
url = https://github.com/llvm/llvm-project.git
We probably should land the LLVM update, and even the mhlo submodule, first.
@@ -41,7 +41,13 @@ torch_mlir_add_llvm_external_project(
TORCH_MLIR_DIALECTS
${CMAKE_CURRENT_SOURCE_DIR}/externals/llvm-external-projects/torch-mlir-dialects)

torch_mlir_add_llvm_external_project(
Can we add a top-level TORCH_MLIR_MHLO CMake flag (the name can be anything; @silvasean, any suggestions?) that can enable/disable the MHLO backend? This can be the big hammer in case we have to disable it for any reason (broken on macOS, etc.).
@@ -81,7 +87,16 @@ if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
set(TORCH-MLIR_BUILT_STANDALONE 1)
set(BACKEND_PACKAGE_STRING "LLVM ${LLVM_PACKAGE_VERSION}")
add_subdirectory(externals/llvm-external-projects/torch-mlir-dialects)
Wrap in the same TORCH_MLIR_MHLO flag (?)
One thing I would like is if we could add this to our e2e test suite, even if only a few tests pass. We can model it on the TOSA support. Add a case here:
Other related files:
@@ -81,7 +87,16 @@ if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
set(TORCH-MLIR_BUILT_STANDALONE 1)
set(BACKEND_PACKAGE_STRING "LLVM ${LLVM_PACKAGE_VERSION}")
add_subdirectory(externals/llvm-external-projects/torch-mlir-dialects)

set(MHLO_BUILD_EMBEDDED ON)
What is MHLO_BUILD_EMBEDDED meant for? It seems unused.
fout.write(str(module))
print("MHLO module has been saved to {}".format(fname))
print("MHLO execution is not supported yet. Stopped.")
exit(0)
Nit: the code that follows will never be reached and should be removed.
MLIRContext *context = patterns.getContext();

target.addIllegalOp<AtenViewOp>();
patterns.add<ConvertAtenViewOp>(typeConverter, context);
I think operators like aten::view, aten::expand, aten::flatten, aten::[un]squeeze can be split into another source file, such as ViewLikeOps.cpp.
target.addIllegalOp<AtenBroadcastToOp>();
patterns.add<ConvertAtenBroadcastToOp>(typeConverter, context);

target.addIllegalOp<AtenSliceTensorOp>();
Ditto: those slice-like operators can be moved to SliceLikeOps.cpp.
return op->emitError("unimplemented: dim is not constant");
uint64_t batchDims = 0;

rewriter.replaceOpWithNewOp<mhlo::TorchIndexSelectOp>(
It's weird that MHLO has this operator. I think mhlo::DynamicGatherOp is sufficient to represent the semantics. Does anyone know the difference?
mhlo::TorchIndexSelect was added to simplify lowering tf.GatherV2 to MHLO in this commit 2.5 years ago. It is modelled after the client HLO API TorchIndexSelect.
I agree that it doesn't look particularly essential for the MHLO dialect. We've been recently thinking about moving it to CHLO, along with a few other ops which are modelled after client HLO APIs and don't correspond to dedicated HLO ops. (More specifically, these other ops are: mhlo.broadcast (keeping broadcast_in_dim in MHLO), mhlo.create_token, mhlo.cross_replica_sum, mhlo.dot (keeping dot_general in MHLO), mhlo.einsum, mhlo.unary_einsum.)
Thanks for your advice; we will change this part's implementation in later commits.
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
INSERT_ATENOP_PATTERN(AtenTanhOp);
INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
Ditto: I think operations like aten::gather, aten::index_select, and so on can be moved to MemOps.cpp.
target.addIllegalOp<AtenTransposeIntOp>();
patterns.add<ConvertAtenTransposeIntOp>(typeConverter, context);

target.addIllegalOp<AtenPermuteOp>();
patterns.add<ConvertAtenPermuteOp>(typeConverter, context);
Ditto: can be moved to PermutationsOps.cpp.
int64_t lhsSize = 1;
for (auto &en : llvm::enumerate(lhsTy.getShape())) {
  lhsSize *= en.value();
}
auto constTy = RankedTensorType::get(lhsTy.getShape(), lhsElemTy);
DenseElementsAttr constAttr;
if (lhsElemTy.isa<mlir::FloatType>()) {
  std::vector<APFloat> constVec(
      lhsSize,
      APFloat::getZero(lhsElemTy.cast<mlir::FloatType>().getFloatSemantics(),
                       /*negative=*/false));
  constAttr = DenseElementsAttr::get(constTy, constVec);
} else if (lhsElemTy.isa<mlir::IntegerType>()) {
  std::vector<APInt> constVec(
      lhsSize, APInt::getZero(lhsElemTy.getIntOrFloatBitWidth()));
  constAttr = DenseElementsAttr::get(constTy, constVec);
}
Value rhs =
    rewriter.create<mhlo::ConstantOp>(op.getLoc(), constTy, constAttr);
Nit: Value rhs = chlo::getConstantLike(rewriter, loc, 0.0, input);
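By analogy, in plain Python (illustrative only, not MLIR code): a recursive splat helper plays the role that chlo::getConstantLike plays in the nit above, collapsing the manual element-count/zero-splat construction from the reviewed code into one call.

```python
def get_constant_like(value, like):
    """Build a constant with the same (nested-list) shape as `like`,
    splatted with `value`. Analogue of chlo::getConstantLike(rewriter,
    loc, value, input), which also matches the element type for free."""
    if isinstance(like, list):
        return [get_constant_like(value, e) for e in like]
    return value

# The reviewed code instead counts elements, picks the right zero for
# the element type, splats it into a vector, and builds a constant op.
input_tensor = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
zeros = get_constant_like(0.0, input_tensor)
print(zeros)  # [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
```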
size_t inPos = inShape.size() - 1 - i;
int64_t outDim = outShape[outPos];
int64_t inDim = inShape[inPos];
if (inDim == outDim) {
The implementation doesn't support implicit broadcast when inDim == outDim == -1, i.e. when both are unknown?
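A minimal Python sketch of the concern (illustrative only, not torch-mlir code; -1 stands for a dynamically unknown dimension): when both dims are -1 the `inDim == outDim` check passes, yet a purely static lowering cannot tell whether the input still needs expanding at runtime.

```python
def broadcast_dim(in_dim: int, out_dim: int) -> str:
    """Classify one (input, output) dim pair for right-aligned
    broadcasting; -1 denotes a dynamically unknown dimension."""
    if in_dim == 1:
        return "expand"            # static broadcast: replicate along this dim
    if in_dim == out_dim and in_dim != -1:
        return "keep"              # statically equal sizes: no-op
    if in_dim == -1 and out_dim == -1:
        # Both unknown: the input dim may still turn out to be 1 at
        # runtime and need expanding, so this case is ambiguous here.
        return "ambiguous"
    return "mismatch"

print(broadcast_dim(1, 4))    # expand
print(broadcast_dim(4, 4))    # keep
print(broadcast_dim(-1, -1))  # ambiguous
```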
lhsTensor = mhlo::promoteAndBroadcast(rewriter, lhsTensor, outType);
rhsTensor = mhlo::promoteAndBroadcast(rewriter, rhsTensor, outType);

rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, outType, lhsTensor, rhsTensor);
Use chlo::BroadcastMulOp to support implicit/dynamic broadcast
lhsTensor = mhlo::promoteAndBroadcast(rewriter, lhsTensor, outType);
rhsTensor = mhlo::promoteAndBroadcast(rewriter, rhsTensor, outType);

rewriter.replaceOpWithNewOp<mhlo::DivOp>(op, outType, lhsTensor, rhsTensor);
ditto
#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, MhloOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenAddSubOp<AtenOp, MhloOp>>(typeConverter, context);
INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, mhlo::AddOp)
Use chlo::BroadcastAddOp / chlo::BroadcastSubOp to support implicit/dynamic broadcast.
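For reference, the implicit broadcast these review comments keep pointing at, sketched as a small shape-inference helper in plain Python (an illustration of the NumPy/PyTorch right-aligned rule, not MLIR code): the shape-equal mhlo element-wise ops require the pattern to materialize this broadcast explicitly, while the chlo broadcast variants carry the semantics themselves.

```python
def broadcast_shape(a: tuple, b: tuple) -> tuple:
    """Right-aligned broadcast of two static shapes (NumPy/PyTorch rule)."""
    out = []
    # Left-pad the shorter shape with 1s, then compare dim by dim.
    for x, y in zip((1,) * (len(b) - len(a)) + a,
                    (1,) * (len(a) - len(b)) + b):
        if x != 1 and y != 1 and x != y:
            raise ValueError(f"incompatible dims {x} vs {y}")
        out.append(max(x, y))
    return tuple(out)

# mhlo.add / mhlo.mul require identical operand shapes, so lowering
# must expand (4, 1) + (3,) to (4, 3) first; chlo.broadcast_add models
# this implicit broadcast directly.
print(broadcast_shape((4, 1), (3,)))  # (4, 3)
```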
    op->getContext(), mhlo::ComparisonDirection::NE);
}

rewriter.replaceOpWithNewOp<mhlo::CompareOp>(
nit: chlo::BroadcastCompareOp
SEQ_LEN = 128
data = {
    'input_ids': torch.randint(30522, (BATCH_SIZE, SEQ_LEN)),
    # 'labels': torch.randint(30522, (BATCH_SIZE, SEQ_LEN)),
Remove the commented-out code?
def ConvertTorchToMhlo : Pass<"convert-torch-to-mhlo", "func::FuncOp"> {
  let summary = "Convert Torch ops to MHLO ops";
  let description = [{
    Convert ATen ops to mhlo ops.
"Convert ATen ops to mhlo ops." Should this be "Convert Torch ops ..."? It seems this pass handles more than just ATen ops.
};

template<>
LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(AtenReciprocalOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const {
This line is too long. Does the torch-mlir community follow a code style, like the Google C++ Style Guide?
cc @silvasean
};

template<>
LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(AtenReciprocalOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const {
Add a comment to describe the lowering logic, e.g. Reciprocal(x) = Div(1, x)?
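The suggested identity, sketched in plain Python (illustration only; the actual pattern builds an MLIR constant 1 of the input's element type and an mhlo::DivOp):

```python
def reciprocal(x):
    # Reciprocal(x) = Div(1, x): materialize a constant one with x's
    # element type, then divide elementwise.
    one = 1.0
    return [one / v for v in x]

print(reciprocal([0.5, 2.0, 4.0]))  # [2.0, 0.5, 0.25]
```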
    AtenBatchNormOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value input = adaptor.input();
  // shape = [N C H W]
[N C H W] => [N, C, H, W]
Value bias = adaptor.bias();
Value runningMean = adaptor.running_mean();
Value runningVar = adaptor.running_var();
// momentum is ignored
Remove the momentum-related code?
DenseElementsAttr valueAttr =
    elements.mapValues(builtinTensorElemTy, [&](const APInt &v) {
      return APInt(bitWidth, v.getSExtValue());
What about unsigned integer types? getSExtValue would sign-extend them here.
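The concern, illustrated in plain Python (hypothetical bit widths; APInt::getSExtValue sign-extends, which is the behavior in question when the element type is unsigned):

```python
def sign_extend(value: int, from_bits: int, to_bits: int) -> int:
    """Reinterpret the low from_bits of value as signed, then widen
    (what getSExtValue followed by a wider APInt does)."""
    sign_bit = 1 << (from_bits - 1)
    signed = (value & (sign_bit - 1)) - (value & sign_bit)
    return signed & ((1 << to_bits) - 1)

def zero_extend(value: int, from_bits: int, to_bits: int) -> int:
    """Widen by padding with zero bits (what unsigned types need)."""
    return value & ((1 << from_bits) - 1)

# ui8 value 0xFF: zero-extension to 16 bits keeps 255, while
# sign-extension turns it into 0xFFFF (-1 reinterpreted as i16).
print(hex(zero_extend(0xFF, 8, 16)))  # 0xff
print(hex(sign_extend(0xFF, 8, 16)))  # 0xffff
```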
// -----

// CHECK-LABEL: func.func @torch.aten.addtensor$alpha(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>,
Format the code?
@@ -0,0 +1 @@
./build/bin/torch-mlir-opt < test/Conversion/TorchToMhlo/custom.mlir -convert-torch-to-mhlo -split-input-file -verify-diagnostics
How about the other FileCheck unit tests?
@@ -0,0 +1 @@
./build/bin/torch-mlir-opt < test/Conversion/TorchToMhlo/custom.mlir -convert-torch-to-mhlo -split-input-file -verify-diagnostics
Cannot find custom.mlir in this pull request.
I just found a better reference on coding style for us: https://mlir.llvm.org/getting_started/DeveloperGuide/
@silvasean @powderluv @fortianyou @Yancey1989 Thanks to everyone who reviewed here. Following the offline meeting between ByteDance and Alibaba, we decided to break the proposal into several PRs:
We will address the issues mentioned above in the following PRs and track the progress in RFC #999.
See RFC: #999
TODO:
Co-authored-by: @byronyi @Vremold Xuanrun Zhang