Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tcp] Add lowering from Torch to Tcp for tanh, add, and broadcast ops + Enable e2e tests for Tcp #1695

Merged
merged 6 commits into from
Dec 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ if(TORCH_MLIR_ENABLE_MHLO)
add_definitions(-DTORCH_MLIR_ENABLE_MHLO)
endif()

option(TORCH_MLIR_ENABLE_TCP "Add TCP dialect" OFF)
if (TORCH_MLIR_ENABLE_TCP)
add_definitions(-DTORCH_MLIR_ENABLE_TCP)
endif()

option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF)

if(TORCH_MLIR_ENABLE_LTC)
Expand Down
5 changes: 4 additions & 1 deletion build_tools/python_deploy/build_linux_packages.sh
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ function build_in_tree() {
-DLLVM_TARGETS_TO_BUILD=host \
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-DTORCH_MLIR_ENABLE_LTC=ON \
-DTORCH_MLIR_DIALECTS_ENABLE_TCP=OFF \
-DTORCH_MLIR_ENABLE_TCP=ON \
-DTORCH_MLIR_USE_INSTALLED_PYTORCH="$torch_from_bin" \
-DTORCH_MLIR_SRC_PYTORCH_REPO=${TORCH_MLIR_SRC_PYTORCH_REPO} \
-DTORCH_MLIR_SRC_PYTORCH_BRANCH=${TORCH_MLIR_SRC_PYTORCH_BRANCH} \
Expand Down Expand Up @@ -277,6 +277,9 @@ function test_in_tree() {

echo ":::: Run TorchDynamo e2e integration tests"
python -m e2e_testing.main --config=torchdynamo -v

echo ":::: Run TCP e2e integration tests"
python -m e2e_testing.main --config=tcp -v
}

function setup_venv() {
Expand Down
10 changes: 8 additions & 2 deletions e2e_testing/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
LinalgOnTensorsBackendTestConfig,
MhloBackendTestConfig,
NativeTorchTestConfig,
TcpBackendTestConfig,
TorchScriptTestConfig,
TosaBackendTestConfig,
EagerModeTestConfig,
Expand All @@ -26,16 +27,17 @@

from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
from torch_mlir_e2e_test.mhlo_backends.linalg_on_tensors import LinalgOnTensorsMhloBackend
from torch_mlir_e2e_test.tcp_backends.linalg_on_tensors import LinalgOnTensorsTcpBackend
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend

from .xfail_sets import REFBACKEND_XFAIL_SET, MHLO_PASS_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET
from .xfail_sets import REFBACKEND_XFAIL_SET, MHLO_PASS_SET, TCP_PASS_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET

# Import tests to register them in the global registry.
from torch_mlir_e2e_test.test_suite import register_all_tests
register_all_tests()

def _get_argparse():
config_choices = ['native_torch', 'torchscript', 'refbackend', 'mhlo', 'tosa', 'eager_mode', 'lazy_tensor_core', 'torchdynamo']
config_choices = ['native_torch', 'torchscript', 'refbackend', 'mhlo', 'tcp', 'tosa', 'eager_mode', 'lazy_tensor_core', 'torchdynamo']
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
parser.add_argument('-c', '--config',
choices=config_choices,
Expand All @@ -44,6 +46,7 @@ def _get_argparse():
Meaning of options:
"refbackend": run through torch-mlir's RefBackend.
"mhlo": run through torch-mlir's default MHLO backend.
"tcp": run through torch-mlir's default TCP backend.
"tosa": run through torch-mlir's default TOSA backend.
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
Expand Down Expand Up @@ -85,6 +88,9 @@ def main():
if args.config == 'mhlo':
config = MhloBackendTestConfig(LinalgOnTensorsMhloBackend())
xfail_set = all_test_unique_names - MHLO_PASS_SET
elif args.config == 'tcp':
config = TcpBackendTestConfig(LinalgOnTensorsTcpBackend())
xfail_set = all_test_unique_names - TCP_PASS_SET
elif args.config == 'native_torch':
config = NativeTorchTestConfig()
xfail_set = {}
Expand Down
23 changes: 23 additions & 0 deletions e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,29 @@
"AtenRoundIntModule_basic",
}

TCP_PASS_SET = {
"AtenRoundIntModule_basic",
"AtenToDeviceModule_basic",
"BoolTensorReturnFalseModule_basic",
"BoolTensorReturnMixedModule_basic",
"BoolTensorReturnTrueModule_basic",
"BroadcastToModule_basic",
"DropoutEvalFloatModule_basic",
"DropoutEvalIntModule_basic",
"ElementwiseAddModule_basic",
"ElementwiseToDtypeIdentityModule_basic",
"ElementwiseUnaryModule_basic",
"ExpandModule_basic",
"ReturnThreeTensorFloat32_basic",
"ReturnTwoTensorF32I64_basic",
"TModuleRank0_basic",
"TModuleRank1_basic",
"TestMultipleTensorReturn_basic",
"TypeAsSameModule_basic",
"UnsafeView1DFoldModule_basic",
"View1DFoldModule_basic",
}

# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ endif()

option(MLIR_ENABLE_BINDINGS_PYTHON "Enables MLIR Python Bindings" OFF)

option(TORCH_MLIR_DIALECTS_ENABLE_TCP "Add TCP dialect" OFF)
if(TORCH_MLIR_DIALECTS_ENABLE_TCP)
if(TORCH_MLIR_ENABLE_TCP)
option(TORCH_MLIR_DIALECTS_ENABLE_TCP "Add TCP dialect" ON)
add_definitions(-DTORCH_MLIR_DIALECTS_ENABLE_TCP)
else()
option(TORCH_MLIR_DIALECTS_ENABLE_TCP "Add TCP dialect" OFF)
endif()

set(TORCH_MLIR_DIALECTS_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ Value createElementwiseLinalgGeneric(
SmallVector<AffineMap> indexingMaps{tensorOperands.size() + 1,
b.getMultiDimIdentityMap(resultRank)};

SmallVector<StringRef> iteratorTypes(resultRank,
getParallelIteratorTypeName());
SmallVector<utils::IteratorType> iteratorTypes(resultRank,
utils::IteratorType::parallel);

Value emptyTensor = b.create<tensor::EmptyOp>(
loc, resultDimSizes, resultTensorType.getElementType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ class ConvertBroadcastOp : public OpConversionPattern<BroadcastOp> {
indexingMaps.push_back(inputIndexingMap);
indexingMaps.push_back(outputIndexingMap);

SmallVector<StringRef> iteratorTypes(resultRank,
getParallelIteratorTypeName());
SmallVector<utils::IteratorType> iteratorTypes(
resultRank, utils::IteratorType::parallel);

Value emptyTensor =
b.create<tensor::EmptyOp>(loc, getAsOpFoldResult(resultDimSizes),
Expand Down
12 changes: 10 additions & 2 deletions include/torch-mlir/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
if(TORCH_MLIR_ENABLE_MHLO)
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO)
if(TORCH_MLIR_ENABLE_TCP)
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO -DTORCH_MLIR_ENABLE_TCP)
else()
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO)
endif()
else()
mlir_tablegen(Passes.h.inc -gen-pass-decls)
if(TORCH_MLIR_ENABLE_TCP)
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_TCP)
else()
mlir_tablegen(Passes.h.inc -gen-pass-decls)
endif()
endif()
add_public_tablegen_target(TorchMLIRConversionPassIncGen)

Expand Down
10 changes: 10 additions & 0 deletions include/torch-mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -154,4 +154,14 @@ def ConvertTorchToMhlo : Pass<"convert-torch-to-mhlo", "func::FuncOp"> {
}
#endif

#ifdef TORCH_MLIR_ENABLE_TCP
def ConvertTorchToTcp : Pass<"convert-torch-to-tcp", "func::FuncOp"> {
let summary = "Convert Torch ops to Tcp ops";
let description = [{
Convert Torch ops to Tcp ops.
}];
let constructor = "mlir::torch::createConvertTorchToTcpPass()";
}
#endif // TORCH_MLIR_ENABLE_TCP

#endif // TORCHMLIR_CONVERSION_PASSES
24 changes: 24 additions & 0 deletions include/torch-mlir/Conversion/TorchToTcp/TorchToTcp.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#ifndef TORCHMLIR_CONVERSION_TORCHTOTCP_TORCHTOTCP_H
#define TORCHMLIR_CONVERSION_TORCHTOTCP_TORCHTOTCP_H

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
namespace torch {

std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTcpPass();

} // namespace torch
} // namespace mlir

#endif // TORCHMLIR_CONVERSION_TORCHTOTCP_TORCHTOTCP_H
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
if(TORCH_MLIR_ENABLE_MHLO)
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO)
if(TORCH_MLIR_ENABLE_TCP)
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO -DTORCH_MLIR_ENABLE_TCP)
else()
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO)
endif()
else()
mlir_tablegen(Passes.h.inc -gen-pass-decls)
if(TORCH_MLIR_ENABLE_TCP)
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_TCP)
else()
mlir_tablegen(Passes.h.inc -gen-pass-decls)
endif()
endif()
add_public_tablegen_target(TorchMLIRTorchConversionPassIncGen)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ void createTorchBackendToMhloBackendPipeline(
std::unique_ptr<OperationPass<ModuleOp>> createVerifyMhloBackendContractPass();
#endif

#ifdef TORCH_MLIR_ENABLE_TCP
// Creates a pipeline that lowers from the torch backend contract to the TCP
// backend contract.
void createTorchBackendToTcpBackendPipeline(OpPassManager &pm);
std::unique_ptr<OperationPass<ModuleOp>> createVerifyTcpBackendContractPass();
#endif // TORCH_MLIR_ENABLE_TCP

std::unique_ptr<OperationPass<ModuleOp>> createFuncBackendTypeConversionPass();

std::unique_ptr<OperationPass<func::FuncOp>>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,12 @@ def VerifyMhloBackendContract : Pass<"torch-verify-mhlo-backend-contract", "Modu
let constructor = "mlir::torch::TorchConversion::createVerifyMhloBackendContractPass()";
}
#endif // TORCH_MLIR_ENABLE_MHLO

#ifdef TORCH_MLIR_ENABLE_TCP
def VerifyTcpBackendContract : Pass<"torch-verify-tcp-backend-contract", "ModuleOp"> {
let summary = "Verifies conformity to the tcp backend contract";
let constructor = "mlir::torch::TorchConversion::createVerifyTcpBackendContractPass()";
}
#endif // TORCH_MLIR_ENABLE_TCP

#endif // TORCHMLIR_TORCHCONVERSION_PASSES
7 changes: 7 additions & 0 deletions lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ add_subdirectory(TorchToTosa)
if(TORCH_MLIR_ENABLE_MHLO)
add_subdirectory(TorchToMhlo)
endif()
if (TORCH_MLIR_ENABLE_TCP)
add_subdirectory(TorchToTcp)
endif()
add_subdirectory(TorchToTMTensor)
add_subdirectory(TorchConversionToMLProgram)
add_subdirectory(Utils)
Expand All @@ -20,6 +23,10 @@ set(linked_libs TorchMLIRTorchToLinalg
if(TORCH_MLIR_ENABLE_MHLO)
list(APPEND linked_libs TorchMLIRTorchToMhlo)
endif()
if(TORCH_MLIR_ENABLE_TCP)
list(APPEND linked_libs TorchMLIRTorchToTcp)
list(APPEND linked_libs TorchMLIRTcpToLinalg)
endif()

add_mlir_library(TorchMLIRConversionPasses
Passes.cpp
Expand Down
9 changes: 9 additions & 0 deletions lib/Conversion/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h"
#include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h"
#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
#ifdef TORCH_MLIR_ENABLE_TCP
#include "torch-mlir-dialects/Conversion/TcpToLinalg/TcpToLinalg.h"
#include "torch-mlir/Conversion/TorchToTcp/TorchToTcp.h"
#endif // TORCH_MLIR_ENABLE_TCP

//===----------------------------------------------------------------------===//
// Pass registration
Expand All @@ -40,4 +44,9 @@ void mlir::torch::registerConversionPasses() {
return mlir::createSymbolicShapeOptimizationPass();
});
#endif // TORCH_MLIR_ENABLE_MHLO
#if TORCH_MLIR_ENABLE_TCP
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
return mlir::tcp::createConvertTcpToLinalgPass();
});
#endif // TORCH_MLIR_ENABLE_TCP
}
23 changes: 23 additions & 0 deletions lib/Conversion/TorchToTcp/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
add_mlir_conversion_library(TorchMLIRTorchToTcp
Elementwise.cpp
Misc.cpp
TorchToTcp.cpp
Utils.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToTcp

DEPENDS
TorchMLIRConversionPassIncGen

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
MLIRIR
MLIRPass
TorchMLIRTcpDialect
TorchMLIRTorchDialect
)

torch_mlir_target_includes(TorchMLIRTorchToTcp)
Loading