
[MHLO] Init Torch to MHLO conversion. #1025

Closed · wants to merge 1 commit

Conversation

@ZihengJiang (Collaborator) commented Jul 7, 2022

See RFC: #999

TODO:

  • Remove the dependency on FuncTorch in the examples
  • Update the BERT and ResNet examples
  • Configurable Decomposition
  • Clean and format

Co-authored-by: @byronyi @Vremold Xuanrun Zhang

@powderluv (Collaborator) left a comment


Thank you. These are drive-by comments; I know it is only a draft PR, but I wanted to get comments in early.

@@ -0,0 +1,33 @@
# in-tree build
Collaborator

Probably the scripts in scripts/ are meant to be local scripts and not checked in?

Collaborator Author

Sounds good. I left it here as an example of how to build the project; I will remove it in the end.

f"-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR={src_dir}",
f"-DLLVM_EXTERNAL_TORCH_MLIR_DIALECTS_SOURCE_DIR={src_dir}/externals/llvm-external-projects/torch-mlir-dialects",
f"-DLLVM_EXTERNAL_MLIR_HLO_SOURCE_DIR={src_dir}/externals/mlir-hlo",
f"-DMLIR_PDLL_TABLEGEN_EXE=mlir-pdll", # FIXME: set MLIR_PDLL_TABLEGEN_EXE since mlir-hlo doesn't
Collaborator

Should we submit a PR upstream in mhlo for setting mlir-pdll?

@@ -95,6 +110,9 @@ else()
set(MLIR_INCLUDE_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include)
set(MLIR_GENERATED_INCLUDE_DIR ${LLVM_BINARY_DIR}/tools/mlir/include)
set(MLIR_INCLUDE_DIRS "${MLIR_INCLUDE_DIR};${MLIR_GENERATED_INCLUDE_DIR}")
# since mhlo didn't set INTERFACE_DIRECTORIES for their target, we need to include mhlo directories globally
Collaborator

Probably good to file an issue upstream so they are aware of this, and maybe add a PR upstream if it is easy.

EXCLUDE_FROM_ALL)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR})
Collaborator

Is the whole CMAKE_CURRENT_BINARY_DIR required, since we are adding it globally?

@@ -1,3 +1,6 @@
[submodule "external/llvm-project"]
path = externals/llvm-project
url = https://github.com/llvm/llvm-project.git
Collaborator

We should probably land the LLVM update, and even the mhlo submodule, first.

@@ -41,7 +41,13 @@ torch_mlir_add_llvm_external_project(
TORCH_MLIR_DIALECTS
${CMAKE_CURRENT_SOURCE_DIR}/externals/llvm-external-projects/torch-mlir-dialects)

torch_mlir_add_llvm_external_project(
Collaborator

Can we add a top-level TORCH_MLIR_MHLO (the name can be anything; @silvasean, any suggestions?) CMake flag that can enable/disable the MHLO backend? This can be the big hammer in case we have to disable it for any reason (broken on macOS, etc.).

@@ -81,7 +87,16 @@ if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
set(TORCH-MLIR_BUILT_STANDALONE 1)
set(BACKEND_PACKAGE_STRING "LLVM ${LLVM_PACKAGE_VERSION}")
add_subdirectory(externals/llvm-external-projects/torch-mlir-dialects)

Collaborator

Wrap this in the same TORCH_MLIR_MHLO flag(?)

@silvasean (Contributor)

One thing I would like is for us to add this to our e2e test suite, even if only a few tests pass. We can model it on the TOSA support.

Add a case here:

if args.config == 'tosa':

Other related files:

  • python/torch_mlir_e2e_test/torchscript/configs/tosa_backend.py
  • python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py (will require using the MHLOToLinalg lowerings)
  • e2e_testing/torchscript/xfail_sets.py (add the list of tests expected to pass)

@@ -81,7 +87,16 @@ if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
set(TORCH-MLIR_BUILT_STANDALONE 1)
set(BACKEND_PACKAGE_STRING "LLVM ${LLVM_PACKAGE_VERSION}")
add_subdirectory(externals/llvm-external-projects/torch-mlir-dialects)

set(MHLO_BUILD_EMBEDDED ON)
Collaborator

What is MHLO_BUILD_EMBEDDED meant for? It seems to be unused.

fout.write(str(module))
print("MHLO module has been save to {}".format(fname))
print("MHLO execution is not support yet. Stopped.")
exit(0)
Collaborator

Nit: the code following this will never be reached and should be removed.

MLIRContext *context = patterns.getContext();

target.addIllegalOp<AtenViewOp>();
patterns.add<ConvertAtenViewOp>(typeConverter, context);
Collaborator

I think operators like aten::view, aten::expand, aten::flatten, and aten::[un]squeeze can be split out into another source file, such as ViewLikeOps.cpp.

target.addIllegalOp<AtenBroadcastToOp>();
patterns.add<ConvertAtenBroadcastToOp>(typeConverter, context);

target.addIllegalOp<AtenSliceTensorOp>();
Collaborator

Ditto: these slice-like operators can be moved to SliceLikeOps.cpp.

return op->emitError("unimplemented: dim is not constant");
uint64_t batchDims = 0;

rewriter.replaceOpWithNewOp<mhlo::TorchIndexSelectOp>(
Collaborator

It's weird that MHLO has this operator; I think mhlo::DynamicGatherOp is sufficient to represent the semantics. Does anyone know the difference?


mhlo::TorchIndexSelect was added to simplify lowering tf.GatherV2 to MHLO in this commit 2.5 years ago. It is modelled after the client HLO API TorchIndexSelect.

I agree that it doesn't look particularly essential for the MHLO dialect. We've recently been thinking about moving it to CHLO, along with a few other ops that are modelled after client HLO APIs and don't correspond to dedicated HLO ops. (More specifically, these other ops are: mhlo.broadcast (keeping broadcast_in_dim in MHLO), mhlo.create_token, mhlo.cross_replica_sum, mhlo.dot (keeping dot_general in MHLO), mhlo.einsum, and mhlo.unary_einsum.)

Collaborator

Thanks for your advice; we will change this part's implementation in later commits.

patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
INSERT_ATENOP_PATTERN(AtenTanhOp);
INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
Collaborator

Ditto: I think operations like aten::gather, aten::index_select, and so on can be moved to MemOps.cpp.

Comment on lines +947 to +951
target.addIllegalOp<AtenTransposeIntOp>();
patterns.add<ConvertAtenTransposeIntOp>(typeConverter, context);

target.addIllegalOp<AtenPermuteOp>();
patterns.add<ConvertAtenPermuteOp>(typeConverter, context);
Collaborator

Ditto: these can be moved to PermutationsOps.cpp.

Comment on lines +867 to +885
int64_t lhsSize = 1;
for (auto &en : llvm::enumerate(lhsTy.getShape())) {
  lhsSize *= en.value();
}
auto constTy = RankedTensorType::get(lhsTy.getShape(), lhsElemTy);
DenseElementsAttr constAttr;
if (lhsElemTy.isa<mlir::FloatType>()) {
  std::vector<APFloat> constVec(
      lhsSize,
      APFloat::getZero(lhsElemTy.cast<mlir::FloatType>().getFloatSemantics(),
                       /*negative=*/false));
  constAttr = DenseElementsAttr::get(constTy, constVec);
} else if (lhsElemTy.isa<mlir::IntegerType>()) {
  std::vector<APInt> constVec(
      lhsSize, APInt::getZero(lhsElemTy.getIntOrFloatBitWidth()));
  constAttr = DenseElementsAttr::get(constTy, constVec);
}
Value rhs =
    rewriter.create<mhlo::ConstantOp>(op.getLoc(), constTy, constAttr);
Collaborator

Nit: Value rhs = chlo::getConstantLike(rewriter, loc, 0.0, input);
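
For illustration, the suggestion could collapse the whole block above into something like the sketch below (not the PR's code; `input` here stands in for the converted lhs value whose type is `lhsTy`; chlo::getConstantLike creates a chlo.constant_like op and covers both the float and integer branches):

// Hedged sketch: replace the manual zero-splat construction above
// with CHLO's constant-like helper.
Value rhs = chlo::getConstantLike(rewriter, op.getLoc(), 0.0, input);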

size_t inPos = inShape.size() - 1 - i;
int64_t outDim = outShape[outPos];
int64_t inDim = inShape[inPos];
if (inDim == outDim) {
Collaborator

The implementation doesn't support implicit broadcast when inDim == outDim == -1, i.e. when both are unknown?

lhsTensor = mhlo::promoteAndBroadcast(rewriter, lhsTensor, outType);
rhsTensor = mhlo::promoteAndBroadcast(rewriter, rhsTensor, outType);

rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, outType, lhsTensor, rhsTensor);
Collaborator

Use chlo::BroadcastMulOp to support implicit/dynamic broadcast.
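
For illustration, a hedged sketch of that change, reusing the names from the snippet above; CHLO's binary broadcast ops take an optional broadcast_dimensions attribute, and omitting it requests numpy-style implicit broadcasting, which also works for dynamic shapes:

// Sketch only, not the PR's code: let CHLO carry the (possibly dynamic)
// broadcast instead of pre-broadcasting both operands with
// promoteAndBroadcast.
rewriter.replaceOpWithNewOp<chlo::BroadcastMulOp>(
    op, outType, lhsTensor, rhsTensor,
    /*broadcast_dimensions=*/nullptr);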

lhsTensor = mhlo::promoteAndBroadcast(rewriter, lhsTensor, outType);
rhsTensor = mhlo::promoteAndBroadcast(rewriter, rhsTensor, outType);

rewriter.replaceOpWithNewOp<mhlo::DivOp>(op, outType, lhsTensor, rhsTensor);
Collaborator

Ditto (chlo::BroadcastDivOp here).

#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, MhloOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenAddSubOp<AtenOp, MhloOp>>(typeConverter, context);
INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, mhlo::AddOp)
Collaborator

Use chlo::BroadcastAddOp / chlo::BroadcastSubOp to support implicit/dynamic broadcast.

op->getContext(), mhlo::ComparisonDirection::NE);
}

rewriter.replaceOpWithNewOp<mhlo::CompareOp>(
Collaborator

Nit: chlo::BroadcastCompareOp.
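
A hedged sketch of the compare variant, mirroring the mhlo::ComparisonDirection::NE case from the snippet above; `lhs` and `rhs` are placeholder names, and the exact builder signature (in particular how the direction attribute is passed) should be checked against the CHLO dialect definition:

// Sketch only: broadcast-aware comparison, direction taken from the
// mhlo::CompareOp replacement above.
rewriter.replaceOpWithNewOp<chlo::BroadcastCompareOp>(
    op, outType, lhs, rhs,
    /*broadcast_dimensions=*/nullptr,
    mhlo::ComparisonDirection::NE);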

SEQ_LEN = 128
data = {
'input_ids': torch.randint(30522, (BATCH_SIZE, SEQ_LEN)),
# 'labels': torch.randint(30522, (BATCH_SIZE, SEQ_LEN)),
Collaborator

Remove the commented-out code?

def ConvertTorchToMhlo : Pass<"convert-torch-to-mhlo", "func::FuncOp"> {
let summary = "Convert Torch ops to MHLO ops";
let description = [{
Convert ATen ops to mhlo ops.
Collaborator

Convert ATen ops to mhlo ops.

Should this say "Convert Torch ops ..."? It seems this pass handles not only ATen ops.

};

template<>
LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(AtenReciprocalOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const {
Collaborator

This line is too long; does the torch-mlir community follow a code style, like the Google C++ style?

cc @silvasean

};

template<>
LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(AtenReciprocalOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const {
Collaborator

Add a comment describing the lowering logic, e.g. Reciprocal(x) = Div(1, x)?
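
For illustration, a hedged sketch of the lowering with such a comment added, following the ConvertAtenOp pattern shape quoted above; the adaptor accessor name `self()` is an assumption:

// Lowering: Reciprocal(x) = Div(1, x).
template <>
LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
    AtenReciprocalOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value input = adaptor.self(); // assumed accessor for the ATen `self` operand
  auto outTy = getTypeConverter()->convertType(op.getType());
  // Splat of 1.0 with the shape and element type of the input.
  Value one = chlo::getConstantLike(rewriter, op.getLoc(), 1.0, input);
  rewriter.replaceOpWithNewOp<mhlo::DivOp>(op, outTy, one, input);
  return success();
}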

AtenBatchNormOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value input = adaptor.input();
// shape = [N C H W]
Collaborator

[N C H W] => [N, C, H, W]

Value bias = adaptor.bias();
Value runningMean = adaptor.running_mean();
Value runningVar = adaptor.running_var();
// momentum is ignored
Collaborator

Remove the momentum-related code?


DenseElementsAttr valueAttr =
    elements.mapValues(builtinTensorElemTy, [&](const APInt &v) {
      return APInt(bitWidth, v.getSExtValue());
Collaborator

What about unsigned integer types?
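
A hedged sketch of one way to handle it, assuming `elements`, `builtinTensorElemTy`, and `bitWidth` from the snippet above, plus a hypothetical `elemTy` for the original element type:

// Sketch only: zero-extend unsigned element types, sign-extend signed ones.
bool isUnsigned = elemTy.isUnsignedInteger(); // `elemTy` is hypothetical here
DenseElementsAttr valueAttr =
    elements.mapValues(builtinTensorElemTy, [&](const APInt &v) {
      return isUnsigned ? APInt(bitWidth, v.getZExtValue())
                        : APInt(bitWidth, v.getSExtValue());
    });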

// -----

// CHECK-LABEL: func.func @torch.aten.addtensor$alpha(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>,
Collaborator

Format the code?

@@ -0,0 +1 @@
./build/bin/torch-mlir-opt < test/Conversion/TorchToMhlo/custom.mlir -convert-torch-to-mhlo -split-input-file -verify-diagnostics
Collaborator

What about the other FileCheck unit tests?

@@ -0,0 +1 @@
./build/bin/torch-mlir-opt < test/Conversion/TorchToMhlo/custom.mlir -convert-torch-to-mhlo -split-input-file -verify-diagnostics
Collaborator

I cannot find custom.mlir in this pull request.

@silvasean (Contributor)

I just found a better reference on coding style for us: https://mlir.llvm.org/getting_started/DeveloperGuide/

@ZihengJiang (Collaborator, Author) commented Jul 13, 2022

@silvasean @powderluv @fortianyou @Yancey1989

Thanks to everyone who has reviewed here. Following the offline meeting between ByteDance and Alibaba, we have decided to break the proposal into several PRs:

  • Integrate MHLO into TorchMLIR Repo
  • Basic conversion pipeline
  • Operator conversions
  • ResNet and BERT examples

We will address the issues mentioned above in the follow-up PRs and track progress in RFC #999.

qedawkins pushed a commit to nod-ai/torch-mlir that referenced this pull request Oct 3, 2022
…lvm#1025)

* Fix misleading output when a bad path to an onnx model is given

Signed-off-by: Yasushi Negishi <negishi@jp.ibm.com>

* Simplify the fix by avoiding opening the input file twice.

Signed-off-by: Yasushi Negishi <negishi@jp.ibm.com>