diff --git a/CMakeLists.txt b/CMakeLists.txt
index 94e41f3ab45c..229e7bd5ab2a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -372,14 +372,14 @@ cmake_dependent_option(IREE_TARGET_BACKEND_WEBGPU "Enables the 'webgpu' compiler
 # Compiler Input Dialects
 #-------------------------------------------------------------------------------
 
-cmake_dependent_option(IREE_INPUT_MHLO "Builds support for compiling MHLO programs" ON ${IREE_BUILD_COMPILER} OFF)
+cmake_dependent_option(IREE_INPUT_STABLEHLO "Builds support for compiling StableHLO programs" ON ${IREE_BUILD_COMPILER} OFF)
 cmake_dependent_option(IREE_INPUT_TORCH "Builds support for compiling Torch MLIR programs" ON ${IREE_BUILD_COMPILER} OFF)
 cmake_dependent_option(IREE_INPUT_TOSA "Builds support for compiling TOSA programs" ON ${IREE_BUILD_COMPILER} OFF)
 
 if(IREE_BUILD_COMPILER)
   message(STATUS "IREE compiler input dialects:")
-  if(IREE_INPUT_MHLO)
-    message(STATUS "  - MHLO")
+  if(IREE_INPUT_STABLEHLO)
+    message(STATUS "  - StableHLO")
   endif()
   if(IREE_INPUT_TORCH)
     message(STATUS "  - Torch MLIR")
diff --git a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
index 36d668f657c8..7c70e30d3291 100644
--- a/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
+++ b/build_tools/bazel_to_cmake/bazel_to_cmake_targets.py
@@ -77,47 +77,9 @@ def __init__(self, repo_map: Dict[str, str]):
         "@llvm-project//mlir:MlirOptLib": ["MLIROptLib"],
         "@llvm-project//mlir:VectorOps": ["MLIRVector"],
 
-        # MHLO.
-        # TODO: Rework this upstream so that Bazel and CMake rules match up
-        # better.
-        # All of these have to depend on tensorflow::external_mhlo_includes to
-        # ensure that include directories are inherited.
-        "@mlir-hlo//:chlo_legalize_to_hlo": [
-            "tensorflow::external_mhlo_includes",
-            "ChloPasses",
-        ],
-        "@mlir-hlo//:mlir_hlo": [
-            "tensorflow::external_mhlo_includes",
-            "MhloDialect",
-            "MLIRMhloUtils",
-        ],
-        "@mlir-hlo//:map_chlo_to_hlo_op": [
-            "ChloOps",
-            "MhloDialect",
-        ],
-        "@mlir-hlo//:map_mhlo_to_scalar_op": [
-            "tensorflow::external_mhlo_includes",
-            "MhloDialect",
-        ],
-        "@mlir-hlo//:mhlo_passes": [
-            "tensorflow::external_mhlo_includes",
-            "MhloPasses",
-            "MhloShapeOpsToStandard",
-            "MhloToLinalg",
-            "MhloToStablehlo",
-            "MhloToStandard",
-            "StablehloToMhlo",
-            # Note: We deliberately omit some passes that we do not use in IREE,
-            # e.g.: MhloToArithmeticConversion, MhloToLhloConversion, or
-            # MhloToMemrefConversion.
-        ],
-        "@mlir-hlo//:unfuse_batch_norm": [
-            "tensorflow::external_mhlo_includes",
-            "MhloPasses",
-        ],
+        # StableHLO.
         "@mlir-hlo//stablehlo:chlo_ops": ["ChloOps",],
         "@mlir-hlo//stablehlo:stablehlo_ops": ["StablehloOps",],
-        "@mlir-hlo//:stablehlo_legalize_to_hlo_pass": ["StablehloToMhlo",],
         "@mlir-hlo//stablehlo:broadcast_utils": ["StablehloBroadcastUtils",],
 
         # NCCL
diff --git a/build_tools/cmake/test_riscv.sh b/build_tools/cmake/test_riscv.sh
index 0f995de906dd..163ff4296bc9 100755
--- a/build_tools/cmake/test_riscv.sh
+++ b/build_tools/cmake/test_riscv.sh
@@ -80,11 +80,10 @@ echo "******** Running tools CTest ********"
 ctest ${tools_ctest_args[@]}
 
 if [[ "${RISCV_PLATFORM}-${RISCV_ARCH}" == "linux-riscv_32" ]]; then
-  # mhlo.power is also disabled because musl math library is not compiled for
+  # stablehlo.power is also disabled because the musl math library is not compiled for
   # 32-bit.
   test_exclude_args+=(
     "stablehlo.*llvm-cpu.*pow"
-    "xla.*llvm-cpu.*pow"
   )
 fi
 
@@ -96,7 +95,6 @@ test_exclude_args+=(
   "iree/tests/e2e/tensor_ops/check_llvm-cpu_local-task_pack_dynamic_inner_tiles.mlir"
   # TODO(#13421): Enable the tests
   "iree/tests/e2e/stablehlo_ops/check_llvm-cpu_local-task_dot.mlir"
-  "iree/tests/e2e/xla_ops/check_llvm-cpu_local-task_dot.mlir"
   "iree/tests/e2e/matmul/e2e_matmul_direct_i8_small_llvm-cpu_local-task"
   "iree/tests/e2e/matmul/e2e_matmul_direct_f32_small_llvm-cpu_local-task"
   "iree/tests/e2e/matmul/e2e_matmul_direct_f32_small_no_padding_llvm-cpu_local-task"
diff --git a/compiler/bindings/python/iree/compiler/tools/core.py b/compiler/bindings/python/iree/compiler/tools/core.py
index 90814d5c9ab0..ae3f5c48a556 100644
--- a/compiler/bindings/python/iree/compiler/tools/core.py
+++ b/compiler/bindings/python/iree/compiler/tools/core.py
@@ -46,8 +46,6 @@ class InputType(Enum):
   STABLEHLO_XLA = "stablehlo_xla"
   TOSA = "tosa"
   TM_TENSOR = "tm_tensor"
-  MHLO_LEGACY = "mhlo_legacy"
-  XLA_LEGACY = "xla_legacy"
 
   @staticmethod
   def parse(spec: Union[str, InputType]) -> InputType:
diff --git a/compiler/bindings/python/iree/compiler/tools/tf.py b/compiler/bindings/python/iree/compiler/tools/tf.py
index f3b957dd0c10..00c157349b39 100644
--- a/compiler/bindings/python/iree/compiler/tools/tf.py
+++ b/compiler/bindings/python/iree/compiler/tools/tf.py
@@ -97,7 +97,7 @@ class ImportOptions(CompilerOptions):
   exported_names: Sequence[str] = ()
   import_only: bool = False
   import_type: ImportType = ImportType.OBJECT_GRAPH
-  input_type: Union[InputType, str] = InputType.XLA_LEGACY
+  input_type: Union[InputType, str] = InputType.STABLEHLO_XLA
   saved_model_tags: Set[str] = field(default_factory=set)
   save_temp_iree_input: Optional[str] = None
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/VerifyInputLegality.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/VerifyInputLegality.cpp
index 2fcf59b8590d..07429ad29626 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/VerifyInputLegality.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/VerifyInputLegality.cpp
@@ -26,8 +26,9 @@ class VerifyInputLegalityPass
     target.addLegalOp<tosa::ApplyScaleOp>();
     // We're already depending on the Tosa Dialect
     target.addIllegalDialect<tosa::TosaDialect>();
-    // Avoid MHLO dependency
-    target.addIllegalDialect("mhlo");
+    // Avoid StableHLO dependency
+    target.addIllegalDialect("chlo");
+    target.addIllegalDialect("stablehlo");
     target.addIllegalOp<UnrealizedConversionCastOp>();
 
     if (failed(iree_compiler::verifyAllOperationsAreLegal(getOperation(),
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/verify_input_ir.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/verify_input_ir.mlir
index 9044969a3d89..82822c6e6e58 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/verify_input_ir.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/verify_input_ir.mlir
@@ -1,10 +1,12 @@
 // RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-verify-input-legality))" --verify-diagnostics %s -split-input-file
 
 // expected-error@below {{illegal operations still remain}}
-func.func @check_no_mhlo(%arg0: tensor<i32>, %arg1 : tensor<i32>) -> tensor<i32> {
+func.func @check_no_stablehlo(%arg0: tensor<i32>, %arg1 : tensor<i32>) -> tensor<i32> {
   // expected-error@+1 {{illegal op still exists}}
-  %0 = "mhlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
-  return %0 : tensor<i32>
+  %0 = stablehlo.add %arg0, %arg1 : tensor<i32>
+  // expected-error@+1 {{illegal op still exists}}
+  %1 = chlo.broadcast_add %0, %arg1 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  return %1 : tensor<i32>
 }
 
 // -----
diff --git a/compiler/src/iree/compiler/InputConversion/CMakeLists.txt b/compiler/src/iree/compiler/InputConversion/CMakeLists.txt
index 474215f2d60d..6e578b36bd61 100644
--- a/compiler/src/iree/compiler/InputConversion/CMakeLists.txt
+++ b/compiler/src/iree/compiler/InputConversion/CMakeLists.txt
@@ -6,8 +6,7 @@
 
 add_subdirectory(Common)
 
-if(IREE_INPUT_MHLO)
-  add_subdirectory(MHLO)
+if(IREE_INPUT_STABLEHLO)
   add_subdirectory(StableHLO)
 endif()
 if(IREE_INPUT_TORCH)
diff --git a/compiler/src/iree/compiler/InputConversion/Common/AutoInputConversionPipeline.cpp b/compiler/src/iree/compiler/InputConversion/Common/AutoInputConversionPipeline.cpp
index 42a36c98f935..be7a9b76d428 100644
--- a/compiler/src/iree/compiler/InputConversion/Common/AutoInputConversionPipeline.cpp
+++ b/compiler/src/iree/compiler/InputConversion/Common/AutoInputConversionPipeline.cpp
@@ -14,12 +14,10 @@
 #include "mlir/Pass/PassManager.h"
 
 // Dialect specific
-#ifdef IREE_HAVE_MHLO_INPUT
-#include "iree/compiler/InputConversion/MHLO/Passes.h"
+#ifdef IREE_HAVE_STABLEHLO_INPUT
 #include "iree/compiler/InputConversion/StableHLO/Passes.h"
-#include "mhlo/IR/hlo_ops.h"
 #include "stablehlo/dialect/StablehloOps.h"
-#endif  // IREE_HAVE_MHLO_INPUT
+#endif  // IREE_HAVE_STABLEHLO_INPUT
 #ifdef IREE_HAVE_TOSA_INPUT
 #include "iree/compiler/InputConversion/TOSA/Passes.h"
 #endif  // IREE_HAVE_TOSA_INPUT
@@ -45,7 +43,6 @@ struct AutoInputConversionPipelinePass final
   struct InputFeatures {
     // HLO features.
     bool hasStableHLO = false;
-    bool hasMHLO = false;
     // - XLA import features.
     bool hasTuples = false;
@@ -93,7 +90,6 @@ static void populateHloFeatures(Operation* op, InputFeatures& features) {
 }
 
 static void populateFeatures(Operation* op, const Dialect* stablehloDialect,
-                             const Dialect* mhloDialect,
                              const Dialect* tmTensorDialect,
                              const Dialect* tosaDialect,
                              InputFeatures& features) {
@@ -102,10 +98,6 @@ static void populateFeatures(Operation* op, const Dialect* stablehloDialect,
     features.hasStableHLO = true;
     return populateHloFeatures(op, features);
   }
-  if (d == mhloDialect) {
-    features.hasMHLO = true;
-    return populateHloFeatures(op, features);
-  }
   if (d == tosaDialect) {
     features.hasTOSA = true;
     return;
@@ -122,22 +114,20 @@ void AutoInputConversionPipelinePass::runOnOperation() {
   InputFeatures features;
 
   const Dialect* stablehloDialect = ctxt->getLoadedDialect("stablehlo");
-  const Dialect* mhloDialect = ctxt->getLoadedDialect("mhlo");
   const Dialect* tosaDialect = ctxt->getLoadedDialect("tosa");
   const Dialect* tmTensorDialect = ctxt->getLoadedDialect("tm_tensor");
-  if (!stablehloDialect && !mhloDialect && !tosaDialect && !tmTensorDialect) {
+  if (!stablehloDialect && !tosaDialect && !tmTensorDialect) {
     return;
   }
 
   auto res = module.walk([&](Operation* op) {
-    populateFeatures(op, stablehloDialect, mhloDialect, tmTensorDialect,
-                     tosaDialect, features);
-    bool hasAnyHLO = features.hasStableHLO || features.hasMHLO;
-    if (hasAnyHLO && features.hasTOSA) {
+    populateFeatures(op, stablehloDialect, tmTensorDialect, tosaDialect,
+                     features);
+    if (features.hasStableHLO && features.hasTOSA) {
       module.emitError("not yet implemented mixture of *HLO and TOSA");
       return WalkResult::interrupt();
     }
-    if (hasAnyHLO && features.hasTmTensor) {
+    if (features.hasStableHLO && features.hasTmTensor) {
       module.emitError("not yet implemented mixture of *HLO and TM Tensor");
       return WalkResult::interrupt();
     }
@@ -150,15 +140,14 @@ void AutoInputConversionPipelinePass::runOnOperation() {
   if (res.wasInterrupted()) {
    return signalPassFailure();
   }
-  if (!features.hasStableHLO && !features.hasMHLO && !features.hasTOSA &&
-      !features.hasTmTensor) {
+  if (!features.hasStableHLO && !features.hasTOSA && !features.hasTmTensor) {
     return;
   }
 
   OpPassManager pm(ModuleOp::getOperationName(),
                    OpPassManager::Nesting::Explicit);
-#ifdef IREE_HAVE_MHLO_INPUT
-  if (features.hasStableHLO && !features.hasMHLO) {
+#ifdef IREE_HAVE_STABLEHLO_INPUT
+  if (features.hasStableHLO) {
     stablehlo::StableHloOptions options;
     options.demoteI64ToI32 = demoteI64ToI32;
     options.demoteF64ToF32 = demoteF64ToF32;
@@ -169,14 +158,7 @@ void AutoInputConversionPipelinePass::runOnOperation() {
       stablehlo::buildStableHLOInputConversionPassPipeline(pm, options);
     }
   }
-  if (features.hasMHLO) {
-    if (features.hasTuples) {
-      MHLO::buildXLAInputConversionPassPipeline(pm);
-    } else {
-      MHLO::buildMHLOInputConversionPassPipeline(pm);
-    }
-  }
-#endif  // IREE_HAVE_MHLO_INPUT
+#endif  // IREE_HAVE_STABLEHLO_INPUT
 #ifdef IREE_HAVE_TOSA_INPUT
   if (features.hasTOSA) {
     buildTOSAInputConversionPassPipeline(pm);
@@ -209,7 +191,7 @@ void AutoInputConversionPipelinePass::getDependentDialects(
         pm.getDependentDialects(registry);
       };
 
-#ifdef IREE_HAVE_MHLO_INPUT
+#ifdef IREE_HAVE_STABLEHLO_INPUT
   auto appendStablehloPipelineDialects =
       [&registry](function_ref<void(OpPassManager&,
                                     const stablehlo::StableHloOptions&)>
@@ -224,10 +206,7 @@ void AutoInputConversionPipelinePass::getDependentDialects(
       stablehlo::buildStableHLOInputConversionPassPipeline);
   appendStablehloPipelineDialects(
       stablehlo::buildStableHLOXLAInputConversionPassPipeline);
-
-  appendPipelineDialects(MHLO::buildMHLOInputConversionPassPipeline);
-  appendPipelineDialects(MHLO::buildXLAInputConversionPassPipeline);
-#endif  // IREE_HAVE_MHLO_INPUT
+#endif  // IREE_HAVE_STABLEHLO_INPUT
 
 #ifdef IREE_HAVE_TOSA_INPUT
   appendPipelineDialects(buildTOSAInputConversionPassPipeline);
diff --git a/compiler/src/iree/compiler/InputConversion/Common/BUILD.bazel b/compiler/src/iree/compiler/InputConversion/Common/BUILD.bazel
index 8f61d1d27a0e..3f944a582667 100644
--- a/compiler/src/iree/compiler/InputConversion/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/InputConversion/Common/BUILD.bazel
@@ -95,7 +95,6 @@ iree_compiler_cc_library(
     deps = [
        ":PassHeaders",
        ":PassesIncGen",
-        "//compiler/src/iree/compiler/InputConversion/MHLO",
        "//compiler/src/iree/compiler/InputConversion/StableHLO",
        "//compiler/src/iree/compiler/InputConversion/TMTensor",
        "//compiler/src/iree/compiler/InputConversion/TOSA",
@@ -106,7 +105,6 @@ iree_compiler_cc_library(
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:TosaDialect",
        "@llvm-project//mlir:Transforms",
-        "@mlir-hlo//:mlir_hlo",
        "@mlir-hlo//stablehlo:stablehlo_ops",
        "@torch-mlir-dialects//:TorchMLIRTMTensorDialect",
     ],
diff --git a/compiler/src/iree/compiler/InputConversion/Common/CMakeLists.txt b/compiler/src/iree/compiler/InputConversion/Common/CMakeLists.txt
index e446adf33e94..6864b3cb4d7d 100644
--- a/compiler/src/iree/compiler/InputConversion/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/InputConversion/Common/CMakeLists.txt
@@ -6,8 +6,7 @@
 
 # Enable input dialects based on options.
 set(IREE_INPUT_DEPS "")
-if(IREE_INPUT_MHLO)
-  list(APPEND IREE_INPUT_DEPS iree::compiler::InputConversion::MHLO)
+if(IREE_INPUT_STABLEHLO)
   list(APPEND IREE_INPUT_DEPS iree::compiler::InputConversion::StableHLO)
 endif()
 if(IREE_INPUT_TORCH)
diff --git a/compiler/src/iree/compiler/InputConversion/Common/test/auto_input_conversion_pipeline.mlir b/compiler/src/iree/compiler/InputConversion/Common/test/auto_input_conversion_pipeline.mlir
index 11ac3449ffcf..b2ff749f8ef4 100644
--- a/compiler/src/iree/compiler/InputConversion/Common/test/auto_input_conversion_pipeline.mlir
+++ b/compiler/src/iree/compiler/InputConversion/Common/test/auto_input_conversion_pipeline.mlir
@@ -8,12 +8,3 @@ func.func @simple_add_stablehlo(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>)
   %0 = stablehlo.add %arg0, %arg1 : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
   return %0 : tensor<2x2xi32>
 }
-
-// -----
-
-// CHECK-LABEL: func.func @simple_add_mhlo
-// CHECK: arith.addi
-func.func @simple_add_mhlo(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi32> {
-  %0 = "mhlo.add"(%arg0, %arg1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
-  return %0 : tensor<2x2xi32>
-}
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/BUILD.bazel b/compiler/src/iree/compiler/InputConversion/MHLO/BUILD.bazel
deleted file mode 100644
index e0fc50675a8a..000000000000
--- a/compiler/src/iree/compiler/InputConversion/MHLO/BUILD.bazel
+++ /dev/null
@@ -1,111 +0,0 @@
-# Copyright 2021 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_library", "iree_gentbl_cc_library")
-
-package(
-    default_visibility = ["//visibility:public"],
-    features = ["layering_check"],
-    licenses = ["notice"],  # Apache 2.0
-)
-
-iree_gentbl_cc_library(
-    name = "PassesIncGen",
-    tbl_outs = [
-        (
-            ["--gen-pass-decls"],
-            "Passes.h.inc",
-        ),
-    ],
-    tblgen = "@llvm-project//mlir:mlir-tblgen",
-    td_file = "Passes.td",
-    deps = [
-        "@llvm-project//mlir:PassBaseTdFiles",
-    ],
-)
-
-iree_compiler_cc_library(
-    name = "PassHeaders",
-    hdrs = [
-        "PassDetail.h",
-        "Passes.h",
-        "Passes.h.inc",
-        "Rewriters.h",
-    ],
-    deps = [
-        ":PassesIncGen",
-        "@llvm-project//mlir:Pass",
-        "@llvm-project//mlir:Transforms",
-    ],
-)
-
-iree_compiler_cc_library(
-    name = "MHLO",
-    srcs = [
-        "BroadcastingToLinalgPatterns.cpp",
-        "ConvertCollectiveOps.cpp",
-        "ConvertComplexToReal.cpp",
-        "ConvertMHLOToFlow.cpp",
-        "ConvertMHLOToFlow.h",
-        "ConvertMHLOToLinalgExt.cpp",
-        "ConvertMHLOToStableHLO.cpp",
-        "FlattenTuplesInCFG.cpp",
-        "MHLOToLinalgOnTensors.cpp",
-        "MHLOToMHLOPreprocessing.cpp",
-        "Passes.cpp",
-        "VerifyCompilerMHLOInputLegality.cpp",
-    ],
-    hdrs = [
-        "Passes.h",
-    ],
-    defines = [
-        "IREE_HAVE_MHLO_INPUT",
-    ],
-    deps = [
-        ":PassHeaders",
-        ":PassesIncGen",
-        "//compiler/src/iree/compiler/Dialect/Flow/IR",
-        "//compiler/src/iree/compiler/Dialect/Util/IR",
-        "//compiler/src/iree/compiler/Dialect/Util/Transforms",
-        "//compiler/src/iree/compiler/InputConversion/Common",
-        "//compiler/src/iree/compiler/Utils",
-        "//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
-        "//llvm-external-projects/iree-dialects:IREELinalgExtPasses",
-        "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:AffineDialect",
-        "@llvm-project//mlir:AffineUtils",
-        "@llvm-project//mlir:ArithDialect",
-        "@llvm-project//mlir:ComplexDialect",
"@llvm-project//mlir:ControlFlowDialect", - "@llvm-project//mlir:DialectUtils", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LinalgDialect", - "@llvm-project//mlir:LinalgTransforms", - "@llvm-project//mlir:MLProgramDialect", - "@llvm-project//mlir:MathDialect", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:ReconcileUnrealizedCasts", - "@llvm-project//mlir:SCFToControlFlow", - "@llvm-project//mlir:SCFTransforms", - "@llvm-project//mlir:ShapeDialect", - "@llvm-project//mlir:ShapeToStandard", - "@llvm-project//mlir:ShapeTransforms", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TensorUtils", - "@llvm-project//mlir:Transforms", - "@mlir-hlo//:chlo_legalize_to_hlo", - "@mlir-hlo//:map_chlo_to_hlo_op", - "@mlir-hlo//:map_mhlo_to_scalar_op", - "@mlir-hlo//:mhlo_passes", - "@mlir-hlo//:mlir_hlo", - "@mlir-hlo//stablehlo:broadcast_utils", - "@mlir-hlo//stablehlo:chlo_ops", - "@mlir-hlo//stablehlo:stablehlo_ops", - ], -) diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/BroadcastingToLinalgPatterns.cpp b/compiler/src/iree/compiler/InputConversion/MHLO/BroadcastingToLinalgPatterns.cpp deleted file mode 100644 index e8b54e4b2d91..000000000000 --- a/compiler/src/iree/compiler/InputConversion/MHLO/BroadcastingToLinalgPatterns.cpp +++ /dev/null @@ -1,822 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -// Patterns for lowering from dynamic-shape sensitive CHLO/MHLO ops. This -// primarily involves broadcasting ops but also includes other ops that have -// an impact on dynamic shape conversions. - -#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" -#include "iree/compiler/InputConversion/MHLO/Rewriters.h" -#include "mhlo/IR/hlo_ops.h" -#include "mhlo/transforms/map_chlo_to_hlo_op.h" -#include "mhlo/transforms/rewriters.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "stablehlo/dialect/BroadcastUtils.h" -#include "stablehlo/dialect/ChloOps.h" - -namespace mlir { -namespace iree_compiler { -namespace MHLO { - -namespace { - -// ----------------------------------------------------------------------------- -// Broadcasting utilities -// ----------------------------------------------------------------------------- - -/// Whether an element type is legal for codegen via linalg on IREE. -bool isElementTypeLegalForCodegen(Type t) { return !llvm::isa(t); } - -/// Returns an ArrayAttr that contains `nLoops` attributes. All the attributes -/// are "parallel" except the last `nReduction` elements, where are "reduction" -/// attributes. -SmallVector getParallelAndReductionIterators( - int nLoops, int nReduction) { - SmallVector res(nLoops - nReduction, - utils::IteratorType::parallel); - res.append(nReduction, utils::IteratorType::reduction); - return res; -} - -SmallVector getNParallelLoopsAttrs(int nParallelLoops) { - return getParallelAndReductionIterators(nParallelLoops, 0); -} - -// Holds a static extent or Value for dynamic extents. 
-class Extent { - public: - Extent() {} - Extent(int64_t extent) : extent(extent) {} - Extent(Value value) : value(value) {} - - bool isStatic() const { return !value; } - bool isUnitExtent() const { return isStatic() && getStatic() == 1; } - int64_t getStatic() const { - assert(isStatic()); - return extent; - } - Value getValue() const { - assert(!isStatic()); - return value; - } - - Value convertToValue(OpBuilder &builder, Location loc) { - if (!isStatic()) return getValue(); - return builder.create(loc, getStatic()); - } - - private: - int64_t extent; - Value value; -}; - -inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, - const Extent &extent) { - if (extent.isStatic()) { - os << "DIM[" << extent.getStatic() << "]"; - } else { - os << "DIM[" << extent.getValue() << "]"; - } - return os; -} - -Value broadcast(OpBuilder &builder, Location loc, Value operand, - SmallVectorImpl &resultExtents, - SmallVectorImpl &isExpansion) { - auto operandType = llvm::cast(operand.getType()); - SmallVector resultShape; - SmallVector dynDims; - for (Extent &dim : resultExtents) { - if (dim.isStatic()) { - resultShape.push_back(dim.getStatic()); - } else { - resultShape.push_back(ShapedType::kDynamic); - dynDims.push_back(dim.getValue()); - } - } - - // Traverse the right aligned operand dimensions and form expressions. - // We keep 1-dims in place instead of reshaping them away, relying on the - // DropUnitDims pass to run later. - SmallVector dimExprs; - dimExprs.reserve(operandType.getRank()); - for (int i = resultExtents.size() - operandType.getRank(); - i < resultExtents.size(); ++i) { - if (isExpansion[i]) { - dimExprs.push_back(builder.getAffineConstantExpr(0)); - } else { - dimExprs.push_back(builder.getAffineDimExpr(i)); - } - } - - int nloops = resultExtents.size(); - Value init = builder.create( - loc, resultShape, operandType.getElementType(), dynDims); - auto generic = builder.create( - loc, TypeRange{init.getType()}, ValueRange{operand}, - /*outputBuffers=*/ValueRange{init}, - llvm::ArrayRef({ - AffineMap::get(/*dimCount=*/nloops, /*symbolCount=*/0, dimExprs, - builder.getContext()), - builder.getMultiDimIdentityMap(nloops), - }), - getNParallelLoopsAttrs(nloops), - [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - nestedBuilder.create(loc, *args.begin()); - }); - return generic.getResult(0); -} - -Value broadcastScalar(OpBuilder &builder, Location loc, Value scalarValue, - SmallVectorImpl &resultExtents) { - SmallVector isExpansion(resultExtents.size()); - for (int i = 0, e = resultExtents.size(); i < e; ++i) { - isExpansion[i] = true; - } - return broadcast(builder, loc, scalarValue, resultExtents, isExpansion); -} - -std::optional computeBinaryResultExtent(OpBuilder &builder, - Location loc, Extent &lhsDim, - Extent &rhsDim, - bool &isLhsExpansion, - bool &isRhsExpansion) { - if (lhsDim.isStatic() && rhsDim.isStatic()) { - // Both are static. Just check. - if (lhsDim.getStatic() != rhsDim.getStatic() && - !(lhsDim.getStatic() == 1 || rhsDim.getStatic() == 1)) { - // Statically illegal. - emitError(loc) << "cannot broadcast extents of differing size unless " - "if one of them is 1 (got " - << lhsDim.getStatic() << ", " << rhsDim.getStatic() << ")"; - return std::nullopt; - } - - // Static expansions. - if (lhsDim.isUnitExtent() && rhsDim.isUnitExtent()) { - // For the fully static case, we can trivially check the 1-equality, - // and know we are not expanding. 
- isLhsExpansion = false; - isRhsExpansion = false; - } else { - // Otherwise, mark the dim as expanding if it is 1. - isLhsExpansion = lhsDim.isUnitExtent(); - isRhsExpansion = rhsDim.isUnitExtent(); - } - return Extent(std::max(lhsDim.getStatic(), rhsDim.getStatic())); - } - - // At least one of them is dynamic. - // Branch on whether one of them is a static-1, which is the only case - // we allow for dynamic expansion. - if (lhsDim.isUnitExtent() || rhsDim.isUnitExtent()) { - if (lhsDim.isUnitExtent()) { - isLhsExpansion = true; - isRhsExpansion = false; - return rhsDim; - } else { - isLhsExpansion = false; - isRhsExpansion = true; - return lhsDim; - } - } - - // At least one is dynamic and neither are a static 1. - // In this case, we do not allow either to be an expanding dim and - // error if this is the case at runtime. - isLhsExpansion = false; - isRhsExpansion = false; - Value lhsExtentValue = lhsDim.convertToValue(builder, loc); - Value rhsExtentValue = rhsDim.convertToValue(builder, loc); - - Value isEqual = builder.create(loc, arith::CmpIPredicate::eq, - lhsExtentValue, rhsExtentValue); - builder.create( - loc, isEqual, - builder.getStringAttr("mismatched dynamic broadcast extents")); - - // Here, if one of them is static, that has to be the result extent - // (because we checked the error condition above). - if (lhsDim.isStatic()) { - return Extent(lhsDim.getStatic()); - } else if (rhsDim.isStatic()) { - return Extent(rhsDim.getStatic()); - } - - // Both are dynamic. Compute the max. - return Extent(lhsExtentValue); -} - -std::optional computeTernaryResultExtent(OpBuilder &builder, - Location loc, Extent &aValue, - Extent &bValue, Extent &cValue, - bool &isAExpansion, - bool &isBExpansion, - bool &isCExpansion) { - // Collect non unit extents (which includes, implicitly, dynamic dims). - SmallVector nonUnitExtents; - if (!aValue.isUnitExtent()) nonUnitExtents.push_back(aValue); - if (!bValue.isUnitExtent()) nonUnitExtents.push_back(bValue); - if (!cValue.isUnitExtent()) nonUnitExtents.push_back(cValue); - - // Early exit if all unit extents. - if (nonUnitExtents.empty()) { - isAExpansion = false; - isBExpansion = false; - isCExpansion = false; - return aValue; - } - - // Are any a unit? - bool hasUnitExtent = false; - if (aValue.isUnitExtent()) hasUnitExtent = true; - if (bValue.isUnitExtent()) hasUnitExtent = true; - if (cValue.isUnitExtent()) hasUnitExtent = true; - - // Mark expansion for any unit. - if (hasUnitExtent) { - if (aValue.isUnitExtent()) isAExpansion = true; - if (bValue.isUnitExtent()) isBExpansion = true; - if (cValue.isUnitExtent()) isCExpansion = true; - } - - // By default, compare against the first non unit extent; however, prefer - // a static extent if present. - int nonUnitCompareExtentIndex = 0; - for (int i = 0, e = nonUnitExtents.size(); i < e; i++) { - if (nonUnitExtents[i].isStatic()) nonUnitCompareExtentIndex = i; - } - - // Generate checks for each non unit extent. - for (int i = 0, e = nonUnitExtents.size(); i < e; i++) { - if (i == nonUnitCompareExtentIndex) continue; - Extent &cmpLhs = nonUnitExtents[nonUnitCompareExtentIndex]; - Extent &cmpRhs = nonUnitExtents[i]; - // Static check. - if (cmpLhs.isStatic() && cmpRhs.isStatic()) { - if (cmpLhs.getStatic() != cmpRhs.getStatic()) { - // Statically illegal. - emitError(loc) << "cannot broadcast extents of differing size unless " - "if one of them is 1 (got " - << cmpLhs.getStatic() << ", " << cmpRhs.getStatic() - << ")"; - return std::nullopt; - } - continue; - } - // Dynamic check. 
- Value cmpLhsValue = cmpLhs.convertToValue(builder, loc); - Value cmpRhsValue = cmpRhs.convertToValue(builder, loc); - Value isEqual = builder.create(loc, arith::CmpIPredicate::eq, - cmpLhsValue, cmpRhsValue); - builder.create( - loc, isEqual, - builder.getStringAttr("mismatched dynamic broadcast extents")); - } - - // The result must be one of the non unit extents. Just take the one - // used for comparison. - return nonUnitExtents[nonUnitCompareExtentIndex]; -} - -void padExtents(SmallVectorImpl &extents, int size) { - for (int i = 0; i < size; ++i) { - extents.push_back({1}); - } -} - -void appendExtents(OpBuilder &builder, Location loc, - SmallVectorImpl &extents, Value v, - RankedTensorType t) { - for (int i = 0; i < t.getRank(); ++i) { - if (t.isDynamicDim(i)) { - // Emit a dim op. - Value dim = builder.create(loc, v, i); - extents.push_back(dim); - } else { - // Static dim. - extents.push_back({t.getDimSize(i)}); - } - } -} - -// ----------------------------------------------------------------------------- -// Structural op conversions -// ----------------------------------------------------------------------------- - -struct ConvertConstantLikeOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - chlo::ConstantLikeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto resultTy = llvm::cast(op.getType()); - if (!resultTy.hasRank()) - return rewriter.notifyMatchFailure(op, "only supports ranked"); - // Lower to MHLO constant if statically shaped. - if (resultTy.hasStaticShape()) { - rewriter.replaceOpWithNewOp( - op, DenseElementsAttr::get(resultTy, op.getValue())); - return success(); - } - - Location loc = op.getLoc(); - - int resultRank = resultTy.getRank(); - SmallVector resultExtents; - resultExtents.reserve(resultRank); - appendExtents(rewriter, loc, resultExtents, adaptor.getOperand(), resultTy); - - auto resultTy0D = RankedTensorType::get({}, resultTy.getElementType()); - Value scalarConst = rewriter.create( - loc, DenseElementsAttr::get(resultTy0D, op.getValue())); - Value broadcasted = - broadcastScalar(rewriter, loc, scalarConst, resultExtents); - rewriter.replaceOp(op, {broadcasted}); - return success(); - } -}; - -// ----------------------------------------------------------------------------- -// Binary broadcasting op conversions -// ----------------------------------------------------------------------------- - -// Adapter base class for adapting binary elementwise broadcasting ops -// via generic patterns. Implemented as a virtual class in order to reduce -// high fanout template instantiations. -struct BinaryBroadcastingAdaptor { - using BroadcastValues = std::pair; - virtual ~BinaryBroadcastingAdaptor() = default; - virtual StringRef getFromOperationName() = 0; - virtual LogicalResult verifyBroadcastCompatibility( - Operation *op, ArrayRef operands) = 0; - virtual BroadcastValues getFromBroadcastValues(Operation *op, - ArrayRef operands) = 0; - virtual Operation *createTargetOperation(Location loc, Operation *op, - Type resultType, - ArrayRef operands, - BroadcastValues broadcastValues, - OpBuilder &builder) = 0; -}; - -// Adaptor for simple binary elementwise operations which have exactly two -// operands and are matched from src -> target by name. 
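The adaptor comment above captures a deliberate design choice: the broadcasting rewrite is written once against a virtual interface, and each source op contributes only a tiny singleton adaptor, instead of instantiating the full pattern template per op. A minimal, self-contained sketch of that shape (hypothetical names, no MLIR dependencies):

```cpp
#include <cassert>

// The rewrite logic is written once against this interface; only the small
// per-op adaptors multiply, not the pattern class itself.
struct BinaryOpAdaptor {
  virtual ~BinaryOpAdaptor() = default;
  virtual int apply(int lhs, int rhs) const = 0;
};

struct AddAdaptor final : BinaryOpAdaptor {
  static const BinaryOpAdaptor &getInstance() {
    static AddAdaptor instance;
    return instance;
  }
  int apply(int lhs, int rhs) const override { return lhs + rhs; }
};

struct MulAdaptor final : BinaryOpAdaptor {
  static const BinaryOpAdaptor &getInstance() {
    static MulAdaptor instance;
    return instance;
  }
  int apply(int lhs, int rhs) const override { return lhs * rhs; }
};

// One non-template "pattern" parameterized by a runtime reference, mirroring
// how ConvertRankedBroadcastBinaryOp receives a BinaryBroadcastingAdaptor.
struct GenericBinaryPattern {
  explicit GenericBinaryPattern(const BinaryOpAdaptor &a) : adaptor(a) {}
  int rewrite(int lhs, int rhs) const { return adaptor.apply(lhs, rhs); }
  const BinaryOpAdaptor &adaptor;
};

int main() {
  assert(GenericBinaryPattern(AddAdaptor::getInstance()).rewrite(2, 3) == 5);
  assert(GenericBinaryPattern(MulAdaptor::getInstance()).rewrite(2, 3) == 6);
}
```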
-template -struct SimpleBinaryBroadcastingAdaptor : public BinaryBroadcastingAdaptor { - static BinaryBroadcastingAdaptor &getInstance() { - static SimpleBinaryBroadcastingAdaptor instance; - return instance; - } - StringRef getFromOperationName() override { - return FromOpTy::getOperationName(); - } - LogicalResult verifyBroadcastCompatibility( - Operation *op, ArrayRef operands) override { - auto broadcastDimensions = - llvm::cast(op).getBroadcastDimensions(); - if (broadcastDimensions && - !hlo::isLegalNumpyRankedBroadcast(operands[0], operands[1], - *broadcastDimensions)) { - return failure(); - } - return success(); - } - BroadcastValues getFromBroadcastValues(Operation *op, - ArrayRef operands) override { - assert(operands.size() == 2); - return std::make_pair(operands[0], operands[1]); - } - Operation *createTargetOperation(Location loc, Operation *op, Type resultType, - ArrayRef operands, - BroadcastValues broadcastValues, - OpBuilder &builder) override { - return builder.create(loc, resultType, broadcastValues.first, - broadcastValues.second); - } -}; - -struct CompareBinaryBroadcastingAdaptor : public BinaryBroadcastingAdaptor { - static BinaryBroadcastingAdaptor &getInstance() { - static CompareBinaryBroadcastingAdaptor instance; - return instance; - } - StringRef getFromOperationName() override { - return chlo::BroadcastCompareOp::getOperationName(); - } - LogicalResult verifyBroadcastCompatibility( - Operation *op, ArrayRef operands) override { - auto broadcastDimensions = - llvm::cast(op).getBroadcastDimensions(); - if (broadcastDimensions && - !hlo::isLegalNumpyRankedBroadcast(operands[0], operands[1], - *broadcastDimensions)) { - return failure(); - } - return success(); - } - BroadcastValues getFromBroadcastValues(Operation *op, - ArrayRef operands) override { - chlo::BroadcastCompareOpAdaptor adaptor(operands, op->getAttrDictionary()); - return std::make_pair(adaptor.getLhs(), adaptor.getRhs()); - } - Operation *createTargetOperation(Location loc, Operation *op, Type resultType, - ArrayRef operands, - BroadcastValues broadcastValues, - OpBuilder &builder) override { - chlo::BroadcastCompareOpAdaptor adaptor(operands, op->getAttrDictionary()); - std::optional chloCmpType = adaptor.getCompareType(); - mhlo::ComparisonTypeAttr mhloCmpType; - if (chloCmpType) - mhloCmpType = mhlo::ComparisonTypeAttr::get( - builder.getContext(), *chlo::mhloComparisonType(*chloCmpType)); - return builder.create( - loc, resultType, broadcastValues.first, broadcastValues.second, - *chlo::mhloComparisonDirection(adaptor.getComparisonDirection()), - mhloCmpType); - } -}; - -struct ConvertRankedBroadcastBinaryOp : public ConversionPattern { - ConvertRankedBroadcastBinaryOp(MLIRContext *context, - TypeConverter &typeConverter, - PatternBenefit benefit, - BinaryBroadcastingAdaptor &bcastAdaptor) - : ConversionPattern(typeConverter, bcastAdaptor.getFromOperationName(), - benefit, context), - bcastAdaptor(bcastAdaptor) {} - - LogicalResult matchAndRewrite( - Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - // Only rewrite for statically determinable non-broadcasting cases. 
- auto bcastOperands = bcastAdaptor.getFromBroadcastValues(op, operands); - Value lhs = bcastOperands.first; - Value rhs = bcastOperands.second; - auto lhsType = llvm::dyn_cast(lhs.getType()); - auto rhsType = llvm::dyn_cast(rhs.getType()); - if (!lhsType || !rhsType) - return rewriter.notifyMatchFailure(op, "not ranked tensors"); - - if (failed(bcastAdaptor.verifyBroadcastCompatibility(op, operands))) { - return rewriter.notifyMatchFailure(op, "not legal broadcasting"); - } - if (!isElementTypeLegalForCodegen(lhsType.getElementType()) || - !isElementTypeLegalForCodegen(rhsType.getElementType())) { - return rewriter.notifyMatchFailure(op, - "not legal element type for codegen"); - } - - // Extract the original extents. - SmallVector lhsOrigExtents; - lhsOrigExtents.reserve(lhsType.getRank()); - appendExtents(rewriter, loc, lhsOrigExtents, lhs, lhsType); - SmallVector rhsOrigExtents; - rhsOrigExtents.reserve(rhsType.getRank()); - appendExtents(rewriter, loc, rhsOrigExtents, rhs, rhsType); - - // Left pad with 1-extents to the result rank. - int resultRank = std::max(lhsType.getRank(), rhsType.getRank()); - SmallVector lhsBcastExtents; - lhsBcastExtents.reserve(resultRank); - SmallVector rhsBcastExtents; - rhsBcastExtents.reserve(resultRank); - padExtents(lhsBcastExtents, resultRank - lhsType.getRank()); - lhsBcastExtents.append(lhsOrigExtents); - padExtents(rhsBcastExtents, resultRank - rhsType.getRank()); - rhsBcastExtents.append(rhsOrigExtents); - - // Compute the result extents. - SmallVector resultExtents(resultRank); - SmallVector isLhsExpansion(resultRank); - SmallVector isRhsExpansion(resultRank); - bool lhsNeedsBroadcast = resultRank != lhsType.getRank(); - bool rhsNeedsBroadcast = resultRank != rhsType.getRank(); - for (int i = 0; i < resultRank; i++) { - auto resultExtent = computeBinaryResultExtent( - rewriter, loc, lhsBcastExtents[i], rhsBcastExtents[i], - isLhsExpansion[i], isRhsExpansion[i]); - if (!resultExtent) { - return rewriter.notifyMatchFailure(op, - "could not compute result extent"); - } - resultExtents[i] = *resultExtent; - if (isLhsExpansion[i]) lhsNeedsBroadcast = true; - if (isRhsExpansion[i]) rhsNeedsBroadcast = true; - } - - // Broadcast the operands. - Value lhsBcast = - lhsNeedsBroadcast - ? broadcast(rewriter, loc, lhs, resultExtents, isLhsExpansion) - : lhs; - Value rhsBcast = - rhsNeedsBroadcast - ? broadcast(rewriter, loc, rhs, resultExtents, isRhsExpansion) - : rhs; - - // TODO: Don't do this result type change. - rewriter.replaceOp(op, - bcastAdaptor - .createTargetOperation( - loc, op, op->getResult(0).getType(), operands, - std::make_pair(lhsBcast, rhsBcast), rewriter) - ->getResults()); - return success(); - } - - BinaryBroadcastingAdaptor &bcastAdaptor; -}; - -// Converts binary ops that statically are determined to not broadcast directly -// to the corresponding mhlo non-broadcasting op. -struct ConvertTrivialNonBroadcastBinaryOp : public ConversionPattern { - ConvertTrivialNonBroadcastBinaryOp(MLIRContext *context, - TypeConverter &typeConverter, - PatternBenefit benefit, - BinaryBroadcastingAdaptor &bcastAdaptor) - : ConversionPattern(typeConverter, bcastAdaptor.getFromOperationName(), - benefit, context), - bcastAdaptor(bcastAdaptor) {} - - LogicalResult matchAndRewrite( - Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - // Only rewrite for statically determinable non-broadcasting cases. 
- auto bcastOperands = bcastAdaptor.getFromBroadcastValues(op, operands); - auto lhsType = - llvm::dyn_cast(bcastOperands.first.getType()); - auto rhsType = - llvm::dyn_cast(bcastOperands.second.getType()); - if (!lhsType || !rhsType) - return rewriter.notifyMatchFailure(op, "not ranked tensors"); - if (!isElementTypeLegalForCodegen(lhsType.getElementType()) || - !isElementTypeLegalForCodegen(rhsType.getElementType())) { - return rewriter.notifyMatchFailure(op, - "not legal element type for codegen"); - } - - // Requires rank broadcast. - if (lhsType.getRank() != rhsType.getRank()) - return rewriter.notifyMatchFailure(op, "not same rank"); - // Any dynamic dimension may require broadcasting and requires more - // analysis. - if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape()) - return rewriter.notifyMatchFailure(op, "not static shapes"); - - for (auto [lhsExtent, rhsExtent] : - llvm::zip_equal(lhsType.getShape(), rhsType.getShape())) { - if (lhsExtent != rhsExtent) { - return rewriter.notifyMatchFailure(op, "not equal extents"); - } - } - - if (failed(bcastAdaptor.verifyBroadcastCompatibility(op, operands))) { - return rewriter.notifyMatchFailure(op, "not legal broadcasting"); - } - - rewriter.replaceOp(op, bcastAdaptor - .createTargetOperation( - op->getLoc(), op, op->getResult(0).getType(), - operands, bcastOperands, rewriter) - ->getResults()); - return success(); - } - - BinaryBroadcastingAdaptor &bcastAdaptor; -}; - -// ----------------------------------------------------------------------------- -// Ternary broadcasting op conversions -// ----------------------------------------------------------------------------- - -// Sepecial case conversion for the BroadcastSelectOp into primitives. -// This follows the new convention of SelectV2, which allows a true ternary -// select (whereas the original definition only supported one broadcasting -// value). -struct ConvertSelectOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - chlo::BroadcastSelectOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - - // Only support ranked operands. - Value pred = adaptor.getPred(); - Value thenValue = adaptor.getOnTrue(); - Value elseValue = adaptor.getOnFalse(); - auto predType = llvm::dyn_cast(pred.getType()); - auto thenType = llvm::dyn_cast(thenValue.getType()); - auto elseType = llvm::dyn_cast(elseValue.getType()); - auto resultType = - llvm::dyn_cast(op.getResult().getType()); - if (!predType || !thenType || !elseType || !resultType) { - return rewriter.notifyMatchFailure(op, "cannot convert unranked tensors"); - } - if (!isElementTypeLegalForCodegen(resultType.getElementType())) { - return rewriter.notifyMatchFailure(op, - "not legal element type for codegen"); - } - - // Short-circuit if all types are statically equal. - if (predType == thenType && predType == elseType) { - // No broadcasting. This includes the 0d -> 0d case. - rewriter.replaceOpWithNewOp(op, resultType, pred, - thenValue, elseValue); - return success(); - } - - // Full ternary broadcast. See ConvertBroadcastBinaryOp for the - // simplified version. - // Extract the original extents. 
- SmallVector predOrigExtents; - predOrigExtents.reserve(predType.getRank()); - appendExtents(rewriter, loc, predOrigExtents, pred, predType); - SmallVector thenOrigExtents; - thenOrigExtents.reserve(thenType.getRank()); - appendExtents(rewriter, loc, thenOrigExtents, thenValue, thenType); - SmallVector elseOrigExtents; - elseOrigExtents.reserve(elseType.getRank()); - appendExtents(rewriter, loc, elseOrigExtents, elseValue, elseType); - - // Left pad with 1-extents to the result rank. - int resultRank = std::max(std::max(predType.getRank(), thenType.getRank()), - elseType.getRank()); - SmallVector predBcastExtents; - predBcastExtents.reserve(resultRank); - padExtents(predBcastExtents, resultRank - predType.getRank()); - predBcastExtents.append(predOrigExtents); - - SmallVector thenBcastExtents; - thenBcastExtents.reserve(resultRank); - padExtents(thenBcastExtents, resultRank - thenType.getRank()); - thenBcastExtents.append(thenOrigExtents); - - SmallVector elseBcastExtents; - elseBcastExtents.reserve(resultRank); - padExtents(elseBcastExtents, resultRank - elseType.getRank()); - elseBcastExtents.append(elseOrigExtents); - - // Compute the result extents. - SmallVector resultExtents(resultRank); - SmallVector isPredExpansion(resultRank); - SmallVector isThenExpansion(resultRank); - SmallVector isElseExpansion(resultRank); - bool predNeedsBroadcast = resultRank != predType.getRank(); - bool thenNeedsBroadcast = resultRank != thenType.getRank(); - bool elseNeedsBroadcast = resultRank != elseType.getRank(); - for (int i = 0; i < resultRank; i++) { - auto resultExtent = computeTernaryResultExtent( - rewriter, loc, predBcastExtents[i], thenBcastExtents[i], - elseBcastExtents[i], isPredExpansion[i], isThenExpansion[i], - isElseExpansion[i]); - if (!resultExtent) { - return rewriter.notifyMatchFailure(op, - "could not compute result extent"); - } - resultExtents[i] = *resultExtent; - if (isPredExpansion[i]) predNeedsBroadcast = true; - if (isThenExpansion[i]) thenNeedsBroadcast = true; - if (isElseExpansion[i]) elseNeedsBroadcast = true; - } - - // Broadcast all. - Value predBcast = - predNeedsBroadcast - ? broadcast(rewriter, loc, pred, resultExtents, isPredExpansion) - : pred; - Value thenBcast = thenNeedsBroadcast - ? broadcast(rewriter, loc, thenValue, resultExtents, - isThenExpansion) - : thenValue; - Value elseBcast = elseNeedsBroadcast - ? broadcast(rewriter, loc, elseValue, resultExtents, - isElseExpansion) - : elseValue; - - rewriter.replaceOpWithNewOp(op, resultType, predBcast, - thenBcast, elseBcast); - return success(); - } -}; - -// Fallback conversion of mhlo.dynamic_reshape to flow.tensor.reshape. -// This is not the most optimal way to lower most reshapes, and higher -// benefit patterns should match more specific ops and lower them to -// Linalg expanding and contracting reshapes. -// -// Note that as a low-level op, it is assumed that invariants have been -// satisfied externally in some fashion and further checks are not inserted -// at this time. This may need to be re-evaluated as more user-driven -// reshapes are permitted. 
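The comment above states the contract of the deleted `ConvertDynamicReshapeOp`: `flow.tensor.reshape` receives only the dynamic extents as operands, while static extents travel in the result type. A simplified sketch of that dimension-gathering step, with plain integers standing in for SSA values (hypothetical helper, not the IREE API):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Stand-in for ShapedType::kDynamic; any sentinel works for the sketch.
constexpr std::int64_t kDynamic = -1;

// Mirrors the loop in the deleted ConvertDynamicReshapeOp: walk the result
// type's shape and pull a runtime extent out of the output_shape operand
// only for the dimensions the type marks as dynamic.
std::vector<std::int64_t> gatherDynamicExtents(
    const std::vector<std::int64_t>& resultTypeShape,
    const std::vector<std::int64_t>& runtimeShape) {
  std::vector<std::int64_t> targetDims;
  for (std::size_t i = 0; i < resultTypeShape.size(); ++i) {
    if (resultTypeShape[i] == kDynamic) {
      targetDims.push_back(runtimeShape[i]);
    }
  }
  return targetDims;
}

// Example: a result type of tensor<4x?x?xf32> with a runtime output_shape of
// [4, 10, 3] yields target dims [10, 3].
```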
-struct ConvertDynamicReshapeOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - mhlo::DynamicReshapeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Value input = adaptor.getOperand(); - Value outputShape = adaptor.getOutputShape(); - auto outputShapeType = - llvm::dyn_cast(outputShape.getType()); - auto resultType = llvm::dyn_cast_if_present( - typeConverter->convertType(op.getType())); - if (!outputShapeType || !resultType) { - return rewriter.notifyMatchFailure(op, "not ranked"); - } - SmallVector targetDims; - assert(resultType.getRank() == outputShapeType.getNumElements() && - "mismatched rank"); - for (int i = 0, e = resultType.getRank(); i < e; ++i) { - if (resultType.isDynamicDim(i)) { - Value index = rewriter.create(loc, i); - targetDims.push_back( - rewriter.create(loc, outputShape, index)); - } - } - - SmallVector castedTargetDims; - for (Value dim : targetDims) { - if (llvm::isa(dim.getType())) { - dim = rewriter.create(loc, rewriter.getIndexType(), - dim); - } - castedTargetDims.push_back(dim); - } - - rewriter.replaceOpWithNewOp( - op, resultType, input, castedTargetDims); - return success(); - } -}; - -} // namespace - -} // namespace MHLO -} // namespace iree_compiler -} // namespace mlir - -void mlir::iree_compiler::MHLO::populateMHLOBroadcastingToLinalgPatterns( - MLIRContext *context, TypeConverter &typeConverter, - RewritePatternSet &patterns) { -#define POPULATE_SIMPLE_BCAST(ChloOp, HloOp) \ - patterns.insert( \ - context, typeConverter, 10, \ - SimpleBinaryBroadcastingAdaptor::getInstance()); \ - patterns.insert( \ - context, typeConverter, 5, \ - SimpleBinaryBroadcastingAdaptor::getInstance()); - - POPULATE_SIMPLE_BCAST(chlo::BroadcastAddOp, mhlo::AddOp); - POPULATE_SIMPLE_BCAST(chlo::BroadcastAndOp, mhlo::AndOp); - POPULATE_SIMPLE_BCAST(chlo::BroadcastAtan2Op, mhlo::Atan2Op); - POPULATE_SIMPLE_BCAST(chlo::BroadcastComplexOp, mhlo::ComplexOp); - POPULATE_SIMPLE_BCAST(chlo::BroadcastDivOp, mhlo::DivOp); - POPULATE_SIMPLE_BCAST(chlo::BroadcastMaxOp, mhlo::MaxOp); - POPULATE_SIMPLE_BCAST(chlo::BroadcastMinOp, mhlo::MinOp); - POPULATE_SIMPLE_BCAST(chlo::BroadcastMulOp, mhlo::MulOp); - POPULATE_SIMPLE_BCAST(chlo::BroadcastOrOp, mhlo::OrOp); - POPULATE_SIMPLE_BCAST(chlo::BroadcastPolygammaOp, chlo::PolygammaOp); - POPULATE_SIMPLE_BCAST(chlo::BroadcastPowOp, mhlo::PowOp); - POPULATE_SIMPLE_BCAST(chlo::BroadcastRemOp, mhlo::RemOp); - POPULATE_SIMPLE_BCAST(chlo::BroadcastShiftLeftOp, mhlo::ShiftLeftOp); - POPULATE_SIMPLE_BCAST(chlo::BroadcastShiftRightArithmeticOp, - mhlo::ShiftRightArithmeticOp); - POPULATE_SIMPLE_BCAST(chlo::BroadcastShiftRightLogicalOp, - mhlo::ShiftRightLogicalOp); - POPULATE_SIMPLE_BCAST(chlo::BroadcastSubOp, mhlo::SubtractOp); - POPULATE_SIMPLE_BCAST(chlo::BroadcastXorOp, mhlo::XorOp); - POPULATE_SIMPLE_BCAST(chlo::BroadcastZetaOp, chlo::ZetaOp); - - // Special case for Compare (not a simple signature). - patterns.insert( - context, typeConverter, 10, - CompareBinaryBroadcastingAdaptor::getInstance()); - patterns.insert( - context, typeConverter, 5, - CompareBinaryBroadcastingAdaptor::getInstance()); - - // Other ops. - // TODO: Remove the benefit after it is removed upstream. - patterns.insert(typeConverter, context, 1000); - patterns.insert(typeConverter, context); - patterns.insert(typeConverter, context); - - // Make mixed scalar broadcasting of Clamp explicit. 
- // NOTE: Because we are doing a full conversion out of HLO, we do not use - // the corresponding setup legality, since that explicitly marks clamp as - // conditionally legal. - // TODO: Rename this upstream or find a better place to shove it. - mhlo::populateMaterializeBroadcastsPatterns(context, &patterns); -} diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/CMakeLists.txt b/compiler/src/iree/compiler/InputConversion/MHLO/CMakeLists.txt deleted file mode 100644 index faf680c02cac..000000000000 --- a/compiler/src/iree/compiler/InputConversion/MHLO/CMakeLists.txt +++ /dev/null @@ -1,112 +0,0 @@ -################################################################################ -# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # -# compiler/src/iree/compiler/InputConversion/MHLO/BUILD.bazel # -# # -# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # -# CMake-only content. # -# # -# To disable autogeneration for this file entirely, delete this header. # -################################################################################ - -iree_add_all_subdirs() - -iree_tablegen_library( - NAME - PassesIncGen - TD_FILE - "Passes.td" - OUTS - --gen-pass-decls Passes.h.inc -) - -iree_cc_library( - NAME - PassHeaders - HDRS - "PassDetail.h" - "Passes.h" - "Passes.h.inc" - "Rewriters.h" - DEPS - ::PassesIncGen - MLIRPass - MLIRTransforms - PUBLIC -) - -iree_cc_library( - NAME - MHLO - HDRS - "Passes.h" - SRCS - "BroadcastingToLinalgPatterns.cpp" - "ConvertCollectiveOps.cpp" - "ConvertComplexToReal.cpp" - "ConvertMHLOToFlow.cpp" - "ConvertMHLOToFlow.h" - "ConvertMHLOToLinalgExt.cpp" - "ConvertMHLOToStableHLO.cpp" - "FlattenTuplesInCFG.cpp" - "MHLOToLinalgOnTensors.cpp" - "MHLOToMHLOPreprocessing.cpp" - "Passes.cpp" - "VerifyCompilerMHLOInputLegality.cpp" - DEPS - ::PassHeaders - ::PassesIncGen - ChloOps - ChloPasses - IREELinalgExtDialect - IREELinalgExtPasses - LLVMSupport - MLIRAffineDialect - MLIRAffineUtils - MLIRArithDialect - MLIRComplexDialect - MLIRControlFlowDialect - MLIRFuncDialect - MLIRIR - MLIRLinalgDialect - MLIRLinalgTransforms - MLIRMLProgramDialect - MLIRMathDialect - MLIRMemRefDialect - MLIRMhloUtils - MLIRPass - MLIRReconcileUnrealizedCasts - MLIRSCFToControlFlow - MLIRSCFTransforms - MLIRShapeDialect - MLIRShapeOpsTransforms - MLIRShapeToStandard - MLIRSupport - MLIRTensorDialect - MLIRTensorUtils - MLIRTransforms - MhloDialect - MhloPasses - MhloShapeOpsToStandard - MhloToLinalg - MhloToStablehlo - MhloToStandard - StablehloBroadcastUtils - StablehloOps - StablehloToMhlo - iree::compiler::Dialect::Flow::IR - iree::compiler::Dialect::Util::IR - iree::compiler::Dialect::Util::Transforms - iree::compiler::InputConversion::Common - iree::compiler::Utils - tensorflow::external_mhlo_includes - DEFINES - "IREE_HAVE_MHLO_INPUT" - PUBLIC -) - -### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### -# TODO: For some reason, these dependencies are not being added automatically. -add_dependencies( - iree_compiler_InputConversion_MHLO_PassHeaders - iree_compiler_InputConversion_MHLO_PassesIncGen -) diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/ConvertCollectiveOps.cpp b/compiler/src/iree/compiler/InputConversion/MHLO/ConvertCollectiveOps.cpp deleted file mode 100644 index b68e843dee09..000000000000 --- a/compiler/src/iree/compiler/InputConversion/MHLO/ConvertCollectiveOps.cpp +++ /dev/null @@ -1,1010 +0,0 @@ -// Copyright 2023 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. 
-// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include - -#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" -#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h" -#include "iree/compiler/InputConversion/MHLO/PassDetail.h" -#include "iree/compiler/InputConversion/MHLO/Passes.h" -#include "iree/compiler/InputConversion/MHLO/Rewriters.h" -#include "iree/compiler/Utils/IndexSet.h" -#include "mhlo/IR/hlo_ops.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/Transforms/DialectConversion.h" - -namespace mlir { -namespace iree_compiler { -namespace MHLO { - -// Work in progress. The implementation is planned as several stages. -// -// For the first stage, a few simplifications are made to support simple models. -// -// 1. Single stream with deterministic order of execution -// 2. Single replica group for all collective ops -// 3. Only replicas without partition_id used -// -// These allow us to use a default channel for all communications, and there is -// 1:1 mapping from the replica IDs to the communication ranks. The attribute, -// use_global_device_ids, is always set in this case. -// -// The next stage is to support multiple replica groups. This needs a channel -// creation with a subset of processes, which should have another communication -// among the group. A possible strategy is to have the root process in the group -// (the first rank of the group) creates a channel and the other processes query -// the channel info from the root process. A key-value store using gRPC might be -// a good solution. -// -// Supporting partition_id comes next. This includes the support for various -// mode combinations for cross-replica and cross partition communication. See -// the stablehlo specification for more details about the different modes. 
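Under the stage-one simplifications just described (single stream, one replica group, `use_global_device_ids`), replica and partition IDs map 1:1 onto ranks of the default channel. A small sketch of that flattening, using the same `replicaId * numPartitions + partitionId` indexing as the `convertToRankGroupsBy*` helpers later in this file (hypothetical standalone function, not the IREE API):

```cpp
#include <cassert>
#include <cstdint>

// Row-major flattening of (replica, partition) onto communicator ranks, as
// assumed by the first-stage collectives support described above.
std::int64_t flattenedRank(std::int64_t replicaId, std::int64_t partitionId,
                           std::int64_t numPartitions) {
  assert(replicaId >= 0 && partitionId >= 0 && partitionId < numPartitions);
  return replicaId * numPartitions + partitionId;
}

int main() {
  // 2 replicas x 2 partitions -> ranks 0..3 in row-major order.
  assert(flattenedRank(0, 0, 2) == 0);
  assert(flattenedRank(0, 1, 2) == 1);
  assert(flattenedRank(1, 0, 2) == 2);
  assert(flattenedRank(1, 1, 2) == 3);
}
```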
- -namespace { - -static std::optional -convertToFlowCollectiveElementType(Type type) { - if (type.isF32()) { - return IREE::Flow::CollectiveElementType::Float32; - } - - if (type.isInteger(32)) { - if (type.isSignedInteger()) { - return IREE::Flow::CollectiveElementType::Sint32; - } else { - return IREE::Flow::CollectiveElementType::Uint32; - } - } - - if (type.isF16()) { - return IREE::Flow::CollectiveElementType::Float16; - } - - if (type.isInteger(8)) { - if (type.isSignedInteger()) { - return IREE::Flow::CollectiveElementType::Sint8; - } else { - return IREE::Flow::CollectiveElementType::Uint8; - } - } - - if (type.isInteger(16)) { - if (type.isSignedInteger()) { - return IREE::Flow::CollectiveElementType::Sint16; - } else { - return IREE::Flow::CollectiveElementType::Uint16; - } - } - - if (type.isBF16()) { - return IREE::Flow::CollectiveElementType::BFloat16; - } - - if (type.isF64()) { - return IREE::Flow::CollectiveElementType::Float64; - } - - if (type.isInteger(64)) { - if (type.isSignedInteger()) { - return IREE::Flow::CollectiveElementType::Sint64; - } else { - return IREE::Flow::CollectiveElementType::Uint64; - } - } - - return std::nullopt; -} - -static std::optional -convertToFlowCollectiveReductionOp(const Operation &op) { - if (isa(op)) { - return IREE::Flow::CollectiveReductionOp::ReductionSum; - } else if (isa(op)) { - return IREE::Flow::CollectiveReductionOp::ReductionProduct; - } else if (isa(op)) { - return IREE::Flow::CollectiveReductionOp::ReductionMinimum; - } else if (isa(op)) { - return IREE::Flow::CollectiveReductionOp::ReductionMaximum; - } else { - // TODO: we may be able to detect an average operation and convert it - // into IREE::Flow::CollectiveReductionOp::ReductionAverage. - return std::nullopt; - } -} - -static IREE::Flow::CollectiveElementTypeAttr getCollectiveElementTypeAttr( - MLIRContext *context, RankedTensorType type) { - std::optional collectiveElemType = - convertToFlowCollectiveElementType(type.getElementType()); - if (!collectiveElemType) { - return IREE::Flow::CollectiveElementTypeAttr(); - } - return IREE::Flow::CollectiveElementTypeAttr::get(context, - *collectiveElemType); -} - -template -static LogicalResult checkCollectiveAttrs(T op, PatternRewriter &rewriter) { - // Note that the channel handle attribute consists of two 64-bit values, - // handle and type. - int64_t handle = - op.getChannelHandle() ? op.getChannelHandleAttr().getHandle() : 0; - if (handle <= 0) { - // When the channel handle attribute is not present, it means the - // handle (a.k.a. channel_id in stablehlo) is 0. When this case is combined - // with `use_global_device_ids=false`, the communication type is - // `cross-replica`, but since there is only one replica group, it is - // effectively the same as `flatten_ids`, which is supported. - if (op.getUseGlobalDeviceIds()) { - return rewriter.notifyMatchFailure( - op, "must not set use_global_device_ids when channel_id <= 0"); - } - } - - return success(); -} - -/// Returns `color` and `key` parameter values indexed by the rank of the -/// participant in |baseChannel|. 
-/// -/// Examples: -/// (0),(1) => colors=[0,1], keys=[0,0] -/// (0,1),(2,3) => colors=[0,0,1,1], keys=[0,1,0,1] -static std::pair makeSplitColorAndKey(Location loc, - Value baseChannel, - DenseIntElementsAttr groups, - OpBuilder &builder) { - IndexSet indexSet(loc, builder); - Value noColor = indexSet.get(-1); - if (!groups) return std::make_pair(noColor, noColor); - - auto groupsType = llvm::cast(groups.getType()); - assert(groupsType.getRank() == 2); - int64_t rows = groupsType.getShape()[0]; - int64_t cols = groupsType.getShape()[1]; - auto values = groups.getValues(); - - // Find the max rank so we can size our tables. Today the tables are always - // dense starting from rank 0 but we could offset the rank lookup if for - // example all ranks started at some offset. - int64_t maxRank = 0; - for (int64_t rank : values) { - maxRank = std::max(maxRank, rank); - } - - // Table of pairs indexed by rank. -1 is used to indicate that - // a particular rank does not participate in any group. - SmallVector colorTable(maxRank + 1, noColor); - SmallVector keyTable(maxRank + 1, noColor); - - // Sparsely populate table with each rank getting a color/key pair. - // Rows equate to colors (groups) and columns equate to keys (local ranks). - for (int64_t i = 0; i < rows; ++i) { - for (int64_t j = 0; j < cols; ++j) { - const int64_t index = i * cols + j; - int64_t rank = values[index]; - // -1 represents a null value in a group, where the group does not - // fully occupy the space in the row, e.g., [[0,1,2,3], [4,5,-1,-1]]. - if (rank != -1) { - colorTable[rank] = indexSet.get(i); - keyTable[rank] = indexSet.get(j); - } - } - } - - // Lookup the color/key split parameters by indexing into the tables we - // generated from the static op information. - Value rank = builder.create(loc, baseChannel); - Value color = - builder.create(loc, rank, noColor, colorTable); - Value key = - builder.create(loc, rank, noColor, keyTable); - return std::make_pair(color, key); -} - -static DenseIntElementsAttr convertToRankGroupsByCrossReplica( - DenseIntElementsAttr replicaGroups, int32_t numPartitions, - OpBuilder &builder) { - if (numPartitions <= 1) { - // Treat as a single partition. - return replicaGroups; - } - - auto groupsType = llvm::cast(replicaGroups.getType()); - assert(groupsType.getRank() == 2); - int rows = groupsType.getShape()[0]; - int cols = groupsType.getShape()[1]; - auto values = replicaGroups.getValues(); - SmallVector newValues; - - // The number of groups is (rows * numPartitions). - for (int i = 0; i < rows; ++i) { - for (int p = 0; p < numPartitions; ++p) { - // Each group starts here. The group size is the same as the column size. - for (int j = 0; j < cols; ++j) { - const int index = i * cols + j; - const int64_t replicaId = values[index]; - const int64_t value = - (replicaId == -1) ? -1 : replicaId * numPartitions + p; - newValues.push_back(builder.getI64IntegerAttr(value)); - } - } - } - - auto type = - RankedTensorType::get({rows * numPartitions, cols}, builder.getI64Type()); - return DenseIntElementsAttr::get(type, newValues); -} - -static DenseIntElementsAttr convertToRankGroupsByCrossPartition( - DenseIntElementsAttr partitionGroups, int32_t numReplicas, - OpBuilder &builder) { - if (numReplicas <= 1) { - // Treat as a single replica. 
- return partitionGroups; - } - - auto groupsType = llvm::cast(partitionGroups.getType()); - assert(groupsType.getRank() == 2); - int rows = groupsType.getShape()[0]; - int cols = groupsType.getShape()[1]; - auto values = partitionGroups.getValues(); - SmallVector newValues; - // partitionGroups must have unique elements and cover all partition_ids, so - // numPartitions == values.size(). - int64_t numPartitions = values.size(); - - // The number of groups is (rows * numReplicas). - for (int i = 0; i < rows; ++i) { - for (int r = 0; r < numReplicas; ++r) { - // Each group starts here. The group size is the same as the column size. - for (int j = 0; j < cols; ++j) { - const int index = i * cols + j; - const int64_t partitionId = values[index]; - const int64_t value = - (partitionId == -1) ? -1 : r * numPartitions + partitionId; - - newValues.push_back(builder.getI64IntegerAttr(value)); - } - } - } - - auto type = - RankedTensorType::get({rows * numReplicas, cols}, builder.getI64Type()); - return DenseIntElementsAttr::get(type, newValues); -} - -static DenseIntElementsAttr convertToRankGroupsByCrossReplicaAndPartition( - DenseIntElementsAttr replicaGroups, int32_t numPartitions, - OpBuilder &builder) { - if (numPartitions <= 1) { - // Treat as a single partition. - return replicaGroups; - } - - auto groupsType = llvm::cast(replicaGroups.getType()); - assert(groupsType.getRank() == 2); - int rows = groupsType.getShape()[0]; - int cols = groupsType.getShape()[1]; - auto values = replicaGroups.getValues(); - SmallVector newValues; - - // The number of groups is the same as the number of rows. - for (int i = 0; i < rows; ++i) { - // Each group starts here. The group size is (numPartitions * cols). - for (int p = 0; p < numPartitions; ++p) { - for (int j = 0; j < cols; ++j) { - const int index = i * cols + j; - const int64_t replicaId = values[index]; - const int64_t value = - (replicaId == -1) ? -1 : replicaId * numPartitions + p; - newValues.push_back(builder.getI64IntegerAttr(value)); - } - } - } - auto type = - RankedTensorType::get({rows, numPartitions * cols}, builder.getI64Type()); - return DenseIntElementsAttr::get(type, newValues); -} - -// The collective group mode determines how the StableHLO process grid is split -// into independent process groups. -enum class CollectiveOpGroupMode { - // Only cross-replica communications happen within each process group. - CrossReplica, - // Only cross-partition communications happen within each process group. - CrossPartition, - // Both cross-replica and cross-partition communications may happen within - // each process group. - CrossReplicaAndPartition, - // A list of flattened process ids is used to specify the process groups. 
- FlattenedIds, -}; - -// clang-format off -// +--------------------+-----------+--------------------+--------------------------+ -// | Collective | channelId | useGlobalDeviceIds | Collective Group Mode | -// +--------------------+-----------+--------------------+--------------------------+ -// | all_gather | <= 0 | false | CrossReplica | -// | | > 0 | false | CrossReplicaAndPartition | -// | | > 0 | true | FlattenedIds | -// +--------------------+-----------+--------------------+--------------------------+ -// | all_reduce | <= 0 | false | CrossReplica | -// | | > 0 | false | CrossReplicaAndPartition | -// | | > 0 | true | FlattenedIds | -// +--------------------+-----------+--------------------+--------------------------+ -// | all_to_all | <= 0 | | CrossReplica | -// | | > 0 | | CrossPartition | -// +--------------------+-----------+--------------------+--------------------------+ -// | collective_permute | <= 0 | | CrossReplica | -// | | > 0 | | CrossPartition | -// +--------------------+-----------+--------------------+--------------------------+ -// | reduce_scatter | <= 0 | false | CrossReplica | -// | | > 0 | false | CrossReplicaAndPartition | -// | | > 0 | true | FlattenedIds | -// +--------------------+-----------+--------------------+--------------------------+ -// clang-format on -static CollectiveOpGroupMode getCollectiveOpGroupMode( - int64_t channelId, std::optional useGlobalDeviceIds) { - if (channelId <= 0) { - assert(!useGlobalDeviceIds.has_value() || !*useGlobalDeviceIds); - return CollectiveOpGroupMode::CrossReplica; - } else { - if (!useGlobalDeviceIds.has_value()) { - return CollectiveOpGroupMode::CrossPartition; - } else if (!*useGlobalDeviceIds) { - return CollectiveOpGroupMode::CrossReplicaAndPartition; - } else { - return CollectiveOpGroupMode::FlattenedIds; - } - } -} - -/// Creates a channel matching the given |channelHandleAttr| scoped to the -/// requested group. -static Value createChannelWithGroupInfo( - Location loc, mhlo::ChannelHandleAttr channelHandleAttr, - int32_t numReplicas, int32_t numPartitions, - DenseIntElementsAttr replicaGroups, std::optional useGlobalDeviceIds, - OpBuilder &builder) { - // Set numPartitions, numReplicas to 1 if not set by the user. - if (numPartitions == -1) numPartitions = 1; - if (numReplicas == -1) numReplicas = 1; - - // Base channel that may be split by the group info. - Value baseChannel = - builder.create(loc, /*group=*/StringAttr{}); - - // No need to split if there is a single group. - ShapedType replicaGroupType = replicaGroups.getType(); - assert(replicaGroupType.getRank() == 2); - if (numPartitions == 1 && replicaGroupType.getDimSize(0) == 1) { - return baseChannel; - } - - // Convert replica_groups into flattened IDs depending on group mode. - DenseIntElementsAttr rankGroups; - int64_t channelId = channelHandleAttr ? channelHandleAttr.getHandle() : 0; - CollectiveOpGroupMode mode = - getCollectiveOpGroupMode(channelId, useGlobalDeviceIds); - if (mode == CollectiveOpGroupMode::CrossReplica) { - rankGroups = convertToRankGroupsByCrossReplica(replicaGroups, numPartitions, - builder); - } else if (mode == CollectiveOpGroupMode::CrossPartition) { - rankGroups = convertToRankGroupsByCrossPartition(replicaGroups, numReplicas, - builder); - } else if (mode == CollectiveOpGroupMode::CrossReplicaAndPartition) { - rankGroups = convertToRankGroupsByCrossReplicaAndPartition( - replicaGroups, numPartitions, builder); - } else if (mode == CollectiveOpGroupMode::FlattenedIds) { - // already flattened. 
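The branching in `getCollectiveOpGroupMode` can be checked row by row against the table above. Below is a self-contained restatement in plain C++ (no MLIR types), with asserts covering each row; the all_to_all and collective_permute rows correspond to callers that pass `std::nullopt` for `useGlobalDeviceIds`.

```c++
#include <cassert>
#include <cstdint>
#include <optional>

enum class CollectiveOpGroupMode {
  CrossReplica,
  CrossPartition,
  CrossReplicaAndPartition,
  FlattenedIds,
};

// Mirrors the decision logic of getCollectiveOpGroupMode.
CollectiveOpGroupMode groupMode(int64_t channelId,
                                std::optional<bool> useGlobalDeviceIds) {
  if (channelId <= 0) {
    assert(!useGlobalDeviceIds.value_or(false));
    return CollectiveOpGroupMode::CrossReplica;
  }
  if (!useGlobalDeviceIds) return CollectiveOpGroupMode::CrossPartition;
  return *useGlobalDeviceIds ? CollectiveOpGroupMode::FlattenedIds
                             : CollectiveOpGroupMode::CrossReplicaAndPartition;
}

int main() {
  // all_gather / all_reduce / reduce_scatter rows:
  assert(groupMode(0, false) == CollectiveOpGroupMode::CrossReplica);
  assert(groupMode(1, false) ==
         CollectiveOpGroupMode::CrossReplicaAndPartition);
  assert(groupMode(1, true) == CollectiveOpGroupMode::FlattenedIds);
  // all_to_all / collective_permute rows (no use_global_device_ids):
  assert(groupMode(0, std::nullopt) == CollectiveOpGroupMode::CrossReplica);
  assert(groupMode(1, std::nullopt) == CollectiveOpGroupMode::CrossPartition);
}
```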
- rankGroups = replicaGroups; - } - - // Construct lookups for color and key split parameters. - // Note that `replica_groups` can be interpreted in multiple ways based on the - // other attributes. - auto [color, key] = - makeSplitColorAndKey(loc, baseChannel, rankGroups, builder); - - // Split the channel. Note that this is an expensive operation. - return builder.create(loc, baseChannel, color, - key); -} - -static Value emitTranspose(ConversionPatternRewriter &rewriter, Location loc, - Value input, int64_t srcDim, int64_t dstDim) { - // Creates a transpose op that swaps dimensions srcDim and dstDim in the - // input. - auto inputType = cast(input.getType()); - SmallVector inputShape(inputType.getShape()); - SmallVector permutation = - llvm::to_vector(llvm::seq(0, inputShape.size())); - std::swap(permutation[srcDim], permutation[dstDim]); - std::swap(inputShape[srcDim], inputShape[dstDim]); - DenseIntElementsAttr permutationAttr = rewriter.getI64VectorAttr(permutation); - return rewriter.create( - loc, RankedTensorType::get(inputShape, inputType.getElementType()), input, - permutationAttr); -} - -static int32_t getNumReplicas(ModuleOp moduleOp) { - if (!moduleOp) { - return -1; - } - if (auto numReplicasAttr = - moduleOp->getAttrOfType("mhlo.num_replicas")) { - return numReplicasAttr.getInt(); - } else { - return -1; - } -} - -static int32_t getNumPartitions(ModuleOp moduleOp) { - if (!moduleOp) { - return -1; - } - if (auto numPartitionsAttr = - moduleOp->getAttrOfType("mhlo.num_partitions")) { - return numPartitionsAttr.getInt(); - } else { - return -1; - } -} - -} // namespace - -/// Converts mhlo.partition_id to (flow.channel.rank % numPartitions) -struct PartitionIdOpConversion - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - mhlo::PartitionIdOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - // PartitionId = rank % numPartitions - auto moduleOp = op->getParentOfType(); - int32_t numPartitions = getNumPartitions(moduleOp); - Value value; - if (numPartitions <= 1) { - value = rewriter.create(loc, 0); - } else { - auto channel = rewriter.create( - loc, /*group=*/StringAttr{}); - Value rank = rewriter.create(loc, channel); - auto cst = - rewriter.create(loc, - /*value=*/numPartitions); - value = rewriter.create(loc, rank, cst); - } - auto resultType = - llvm::cast(op.getType()); // tensor - auto elemType = resultType.getElementType(); - // index -> ui32 - auto rankElem = rewriter.create(loc, elemType, value); - // tensor - auto rankTensor = rewriter.create( - loc, resultType, rankElem.getResult()); - rewriter.replaceOp(op, rankTensor.getResult()); - return success(); - } -}; - -/// Converts mhlo.replica_id to floor_div(flow.channel.rank, numPartitions) -struct ReplicaIdOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - mhlo::ReplicaIdOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto channel = rewriter.create( - loc, /*group=*/StringAttr{}); - Value rank = rewriter.create(loc, channel); - - // ReplicaId = floor_div(rank, numPartitions) - auto moduleOp = op->getParentOfType(); - int32_t numPartitions = getNumPartitions(moduleOp); - auto cst = rewriter.create(loc, - /*value=*/numPartitions); - if (numPartitions > 1) { - rank = rewriter.create(loc, rank, cst); - } - - auto resultType = - llvm::cast(op.getType()); // 
tensor - auto elemType = resultType.getElementType(); - // index -> ui32 - auto rankElem = rewriter.create(loc, elemType, rank); - // tensor - auto rankTensor = rewriter.create( - loc, resultType, rankElem.getResult()); - rewriter.replaceOp(op, rankTensor.getResult()); - return success(); - } -}; - -struct AllGatherOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - mhlo::AllGatherOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (checkCollectiveAttrs(op, rewriter).failed()) { - return failure(); - } - - auto loc = op.getLoc(); - - auto moduleOp = op->getParentOfType(); - int32_t numReplicas = getNumReplicas(moduleOp); - int32_t numPartitions = getNumPartitions(moduleOp); - - // Create a channel. - Value channel = createChannelWithGroupInfo( - loc, op.getChannelHandleAttr(), numReplicas, numPartitions, - op.getReplicaGroups(), op.getUseGlobalDeviceIds(), rewriter); - - // Get the collective element type attribute. - auto resultType = llvm::cast(op.getResult().getType()); - IREE::Flow::CollectiveElementTypeAttr elementTypeAttr = - getCollectiveElementTypeAttr(op.getContext(), resultType); - if (!elementTypeAttr) { - return rewriter.notifyMatchFailure( - op, "unsupported element type for collective op"); - } - uint64_t allGatherDim = op.getAllGatherDim(); - Value gatherInput = adaptor.getOperand(); - SmallVector gatherResultShape(resultType.getShape()); - - // When all_gather_dim != 0, we need to transpose between 0 and - // all_gather_dim before and after the flow allgather op. - const bool requiresTranspose = allGatherDim != 0; - if (requiresTranspose) { - std::swap(gatherResultShape[0], gatherResultShape[allGatherDim]); - gatherInput = emitTranspose(rewriter, loc, gatherInput, 0, allGatherDim); - } - - // Create an empty tensor for the result. - Value target = rewriter.create( - loc, gatherResultShape, - getElementTypeOrSelf(adaptor.getOperand().getType())); - Value gatherResult = rewriter.create( - op.getLoc(), elementTypeAttr, target, gatherInput, channel); - - if (requiresTranspose) { - gatherResult = - emitTranspose(rewriter, loc, gatherResult, allGatherDim, 0); - } - - rewriter.replaceOp(op, gatherResult); - return success(); - } -}; - -struct AllReduceOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - mhlo::AllReduceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (checkCollectiveAttrs(op, rewriter).failed()) { - return failure(); - } - - // Only single elementwise op is supported. - Block &block = op.getComputation().front(); - - if (block.empty() || llvm::hasSingleElement(block) || - std::next(block.begin(), 2) != block.end()) { - return rewriter.notifyMatchFailure(op, "must have two ops in the block"); - } - - if (block.getNumArguments() != 2) { - return rewriter.notifyMatchFailure(op, "must have two block args"); - } - - Operation &op1 = block.front(); - Operation &op2 = *(++block.begin()); - - if (op1.getNumResults() != 1 || - !op1.hasTrait<::mlir::OpTrait::Elementwise>()) { - return rewriter.notifyMatchFailure(op, "must have elementwise trait"); - } - - // Convert mhlo reduction op into flow reduction op. 
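The `PartitionIdOpConversion` and `ReplicaIdOpConversion` patterns above both rely on the flattened channel rank decomposing as `rank = replica_id * numPartitions + partition_id`. A quick numeric check of that decomposition, standalone and with assumed example sizes:

```c++
#include <cassert>
#include <cstdint>

int main() {
  const int64_t numReplicas = 3, numPartitions = 4;  // assumed example sizes
  for (int64_t replica = 0; replica < numReplicas; ++replica) {
    for (int64_t partition = 0; partition < numPartitions; ++partition) {
      // Flattened rank, as produced by flow.channel.rank on the base channel.
      int64_t rank = replica * numPartitions + partition;
      assert(rank % numPartitions == partition);  // mhlo.partition_id
      assert(rank / numPartitions == replica);    // mhlo.replica_id
    }
  }
}
```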
- std::optional redOp = - convertToFlowCollectiveReductionOp(op1); - if (!redOp) { - return rewriter.notifyMatchFailure(op, "unsupported operation."); - } - - if (!op2.mightHaveTrait()) { - return rewriter.notifyMatchFailure(op, - "the second op must be a terminator"); - } - - auto loc = op.getLoc(); - - auto moduleOp = op->getParentOfType(); - int32_t numReplicas = getNumReplicas(moduleOp); - int32_t numPartitions = getNumPartitions(moduleOp); - - // Create a channel. - Value channel = createChannelWithGroupInfo( - loc, op.getChannelHandleAttr(), numReplicas, numPartitions, - op.getReplicaGroups(), op.getUseGlobalDeviceIds(), rewriter); - - // Convert mhlo reduction op into flow reduction op. - auto reductionOpAttr = - IREE::Flow::CollectiveReductionOpAttr::get(op.getContext(), *redOp); - - auto inputType = llvm::cast(op.getOperand().getType()); - - // Get the collective element type attribute. - IREE::Flow::CollectiveElementTypeAttr elementTypeAttr = - getCollectiveElementTypeAttr(op.getContext(), inputType); - if (!elementTypeAttr) { - return rewriter.notifyMatchFailure(op, "unsupported input type"); - } - - // Create an empty tensor for the result. - ArrayRef inputShape = inputType.getShape(); - Value target = rewriter.create( - loc, inputShape, getElementTypeOrSelf(adaptor.getOperand().getType())); - auto allReduceOp = rewriter.create( - op.getLoc(), reductionOpAttr, elementTypeAttr, target, - adaptor.getOperand(), channel); - rewriter.replaceOp(op, allReduceOp.getResult()); - return success(); - } -}; - -static Value splitAndConcatForAllToAll(ConversionPatternRewriter &rewriter, - Location loc, Value input, - uint64_t splitDim, uint64_t concatDim, - uint64_t splitCount) { - // Helper function to rearrange data after all-to-all. - auto inputType = llvm::cast(input.getType()); - ArrayRef inputShape = inputType.getShape(); - - // Reshape - const int64_t rank = inputShape.size(); - llvm::SmallVector newShape; - for (int64_t i = 0; i < rank; ++i) { - if (i != splitDim) { - newShape.push_back(inputShape[i]); - continue; - } - newShape.push_back(splitCount); - newShape.push_back(inputShape[i] / splitCount); - } - Value result = rewriter.create( - loc, RankedTensorType::get(newShape, inputType.getElementType()), input); - - // Transpose - SmallVector permutation; - permutation.reserve(rank + 1); - for (int64_t i = 0; i < rank; ++i) { - int64_t dimAfterReshape = i >= splitDim ? 
i + 1 : i; - if (i == concatDim) { - permutation.push_back(splitDim); - } - permutation.push_back(dimAfterReshape); - } - SmallVector transposeResultShape; - transposeResultShape.reserve(rank + 1); - for (int64_t i = 0; i < rank + 1; ++i) - transposeResultShape.push_back(newShape[permutation[i]]); - result = rewriter.create( - loc, - RankedTensorType::get(transposeResultShape, inputType.getElementType()), - result, rewriter.getI64VectorAttr(permutation)); - - // Reshape - llvm::SmallVector finalShape(inputShape); - finalShape[concatDim] *= splitCount; - finalShape[splitDim] /= splitCount; - return rewriter.create( - loc, RankedTensorType::get(finalShape, inputType.getElementType()), - result); -} - -struct AllToAllOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - mhlo::AllToAllOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - auto moduleOp = op->getParentOfType(); - int32_t numReplicas = getNumReplicas(moduleOp); - int32_t numPartitions = getNumPartitions(moduleOp); - - // Create a channel. - Value channel = createChannelWithGroupInfo( - loc, op.getChannelHandleAttr(), numReplicas, numPartitions, - op.getReplicaGroups(), /*useGlobalDeviceIds=*/std::nullopt, rewriter); - - // Get the collective element type attribute. - auto resultType = llvm::cast(op.getResult(0).getType()); - IREE::Flow::CollectiveElementTypeAttr elementTypeAttr = - getCollectiveElementTypeAttr(op.getContext(), resultType); - if (!elementTypeAttr) { - return rewriter.notifyMatchFailure( - op, "unsupported element type for collective op"); - } - if (op.getNumOperands() != 1) { - return rewriter.notifyMatchFailure(op, - "tuple all-to-all is not supported"); - } - if (!op.getSplitDimension() || !op.getConcatDimension() || - !op.getSplitCount()) { - return rewriter.notifyMatchFailure( - op, - "split_dimension, concat_dimension, and split_count must be present " - "for array all-to-all"); - } - - uint64_t splitDim = *op.getSplitDimension(); - uint64_t concatDim = *op.getConcatDimension(); - uint64_t splitCount = *op.getSplitCount(); - Value allToAllInput = adaptor.getOperand().front(); - - // When splitDim != 0, we need to transpose splitDim to 0 before and after - // the all-to-all. - const bool requiresTranspose = splitDim != 0; - // When the concatDim != splitDim, we need to rearrange the data after the - // all-to-all. - const bool requiresSplitAndConcat = concatDim != splitDim; - if (requiresTranspose) { - allToAllInput = emitTranspose(rewriter, loc, allToAllInput, 0, splitDim); - } - - // Create an empty tensor for the result. - Value target = rewriter.create( - loc, cast(allToAllInput.getType()).getShape(), - getElementTypeOrSelf(allToAllInput.getType())); - // Create all-to-all. 
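To see what `splitAndConcatForAllToAll` above does to shapes, here is a shape-level walk-through of its reshape/transpose/reshape steps for one assumed example (input `[4, 6]`, `splitDim=1`, `concatDim=0`, `splitCount=2`), using plain vectors instead of tensors:

```c++
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t rank = 2, splitDim = 1, concatDim = 0, splitCount = 2;
  std::vector<int64_t> shape = {4, 6};

  // Reshape: split splitDim into (splitCount, shape[splitDim] / splitCount).
  std::vector<int64_t> reshaped;
  for (int64_t i = 0; i < rank; ++i) {
    if (i != splitDim) { reshaped.push_back(shape[i]); continue; }
    reshaped.push_back(splitCount);
    reshaped.push_back(shape[i] / splitCount);
  }
  // reshaped = [4, 2, 3]

  // Transpose: move the new splitCount dim in front of concatDim.
  std::vector<int64_t> perm;
  for (int64_t i = 0; i < rank; ++i) {
    int64_t dimAfterReshape = i >= splitDim ? i + 1 : i;
    if (i == concatDim) perm.push_back(splitDim);
    perm.push_back(dimAfterReshape);
  }
  // perm = [1, 0, 2] -> transposed shape [2, 4, 3]

  // Final reshape: fold splitCount into concatDim.
  std::vector<int64_t> finalShape = shape;
  finalShape[concatDim] *= splitCount;  // 4 -> 8
  finalShape[splitDim] /= splitCount;   // 6 -> 3
  for (int64_t d : finalShape) std::cout << d << " ";  // prints: 8 3
  std::cout << "\n";
}
```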
- Value allToAllResult = rewriter.create( - op.getLoc(), elementTypeAttr, target, allToAllInput, channel); - - if (requiresTranspose) { - allToAllResult = - emitTranspose(rewriter, loc, allToAllResult, splitDim, 0); - } - if (requiresSplitAndConcat) { - allToAllResult = splitAndConcatForAllToAll( - rewriter, loc, allToAllResult, splitDim, concatDim, splitCount); - } - - rewriter.replaceOp(op, allToAllResult); - return success(); - } -}; - -struct ReduceScatterOpConversion - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - mhlo::ReduceScatterOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (checkCollectiveAttrs(op, rewriter).failed()) { - return failure(); - } - - // Only single elementwise op is supported. - Block &block = op.getComputation().front(); - - if (block.empty() || llvm::hasSingleElement(block) || - std::next(block.begin(), 2) != block.end()) { - return rewriter.notifyMatchFailure(op, "must have two ops in the block"); - } - - if (block.getNumArguments() != 2) { - return rewriter.notifyMatchFailure(op, "must have two block args"); - } - - Operation &op1 = block.front(); - Operation &op2 = *(++block.begin()); - - if (op1.getNumResults() != 1 || - !op1.hasTrait<::mlir::OpTrait::Elementwise>()) { - return rewriter.notifyMatchFailure(op, "must have elementwise trait"); - } - - // Convert mhlo reduction op into flow reduction op. - std::optional redOp = - convertToFlowCollectiveReductionOp(op1); - if (!redOp) { - return rewriter.notifyMatchFailure(op, "unsupported operation."); - } - - if (!op2.mightHaveTrait()) { - return rewriter.notifyMatchFailure(op, - "the second op must be a terminator"); - } - - // Convert mhlo reduction op into flow reduction op. - auto reductionOpAttr = - IREE::Flow::CollectiveReductionOpAttr::get(op.getContext(), *redOp); - - auto loc = op.getLoc(); - - auto moduleOp = op->getParentOfType(); - int32_t numReplicas = getNumReplicas(moduleOp); - int32_t numPartitions = getNumPartitions(moduleOp); - - // Create a channel. - Value channel = createChannelWithGroupInfo( - loc, op.getChannelHandleAttr(), numReplicas, numPartitions, - op.getReplicaGroups(), op.getUseGlobalDeviceIds(), rewriter); - - // Get the collective element type attribute. - auto resultType = llvm::cast(op.getResult().getType()); - IREE::Flow::CollectiveElementTypeAttr elementTypeAttr = - getCollectiveElementTypeAttr(op.getContext(), resultType); - if (!elementTypeAttr) { - return rewriter.notifyMatchFailure(op, "unsupported input type"); - } - - // When scatter_dimension != 0, we need to transpose between 0 and - // scatter_dimension before and after the flow reduce_scatter op. - uint64_t scatterDim = op.getScatterDimension(); - auto inputType = llvm::cast(op.getOperand().getType()); - SmallVector reduceInputShape(inputType.getShape()); - Value reduceInput = adaptor.getOperand(); - DenseIntElementsAttr permutationAttr; - - SmallVector scatterResultShape(resultType.getShape()); - auto elemType = getElementTypeOrSelf(reduceInput.getType()); - - if (scatterDim != 0) { - SmallVector permutation = - llvm::to_vector(llvm::seq(0, scatterResultShape.size())); - std::swap(permutation[0], permutation[scatterDim]); - permutationAttr = rewriter.getI64VectorAttr(permutation); - std::swap(reduceInputShape[0], reduceInputShape[scatterDim]); - std::swap(scatterResultShape[0], scatterResultShape[scatterDim]); - // Transpose the input. 
- reduceInput = rewriter.create( - loc, RankedTensorType::get(reduceInputShape, elemType), reduceInput, - permutationAttr); - } - - // Create an empty tensor for the result. - Value target = - rewriter.create(loc, scatterResultShape, elemType); - Value scatterResult = - rewriter.create( - op.getLoc(), reductionOpAttr, elementTypeAttr, target, reduceInput, - channel); - - if (scatterDim != 0) { - scatterResult = rewriter.create( - loc, resultType, scatterResult, permutationAttr); - } - - rewriter.replaceOp(op, scatterResult); - return success(); - } -}; - -struct CollectivePermuteOpConversion - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - mhlo::CollectivePermuteOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - auto moduleOp = op->getParentOfType(); - int32_t numReplicas = getNumReplicas(moduleOp); - int32_t numPartitions = getNumPartitions(moduleOp); - - // Replica group consists of all partitions or all replicas depending on the - // mode. If numPartitions is not set, a single group will result in the base - // channel being used. - int64_t channelId = - op.getChannelHandleAttr() ? op.getChannelHandleAttr().getHandle() : 0; - auto mode = getCollectiveOpGroupMode(channelId, - /*useGlobalDeviceIds=*/std::nullopt); - int64_t numParticipants = mode == CollectiveOpGroupMode::CrossReplica - ? numReplicas - : numPartitions; - if (numParticipants == -1) numParticipants = 1; - SmallVector replicaGroups; - for (int64_t i = 0; i < numParticipants; ++i) { - replicaGroups.push_back(rewriter.getI64IntegerAttr(i)); - } - auto type = - RankedTensorType::get({1, numParticipants}, rewriter.getI64Type()); - auto replicaGroupsAttr = DenseIntElementsAttr::get(type, replicaGroups); - - // Create a channel. - Value channel = createChannelWithGroupInfo( - loc, op.getChannelHandleAttr(), numReplicas, numPartitions, - replicaGroupsAttr, /*useGlobalDeviceIds=*/std::nullopt, rewriter); - - auto inputType = llvm::cast(op.getOperand().getType()); - - // Get the collective element type attribute. - IREE::Flow::CollectiveElementTypeAttr elementTypeAttr = - getCollectiveElementTypeAttr(op.getContext(), inputType); - if (!elementTypeAttr) { - return rewriter.notifyMatchFailure(op, "unsupported input type"); - } - - // Convert source target pairs into a constant table that can be indexed by - // rank to find which ids that rank should send to and recv from, or -1 for - // no send/recv. - DenseIntElementsAttr sourceTargetPairs = op.getSourceTargetPairs(); - llvm::DenseMap sendMap, recvMap; - auto values = sourceTargetPairs.getValues(); - // Find the max rank so we can size our tables. - int64_t maxRank = 0; - for (auto rank : values) { - if (rank > std::numeric_limits::max()) { - return rewriter.notifyMatchFailure( - op, "source or target id exceeds maximum value of 16-bit integer"); - } - maxRank = std::max(maxRank, rank); - } - // Create tables. -1 is used to indicate no send or recv. - IndexSet indexSet(loc, rewriter); - Value noSendOrRecv = indexSet.get(-1); - SmallVector sendTable(maxRank + 1, noSendOrRecv); - SmallVector recvTable(maxRank + 1, noSendOrRecv); - for (auto i = values.begin(); i != values.end(); ++i) { - int64_t source = (*i); - int64_t target = (*++i); - sendTable[source] = indexSet.get(target); - recvTable[target] = indexSet.get(source); - } - // Look up the local send/recv values using rank. 
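The send/recv tables being built here can be modeled with plain integers. The sketch below assumes the example pairs `[[0,1],[1,2],[2,0]]` (a three-rank rotation) and comes before the rank-indexed lookup that follows in the deleted code:

```c++
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // source_target_pairs = [[0,1],[1,2],[2,0]], flattened row-major.
  const std::vector<int64_t> pairs = {0, 1, 1, 2, 2, 0};
  const int64_t maxRank = 2;

  // For each (source, target) pair: `source` sends to `target`, and
  // `target` receives from `source`; -1 means no send/recv for that rank.
  std::vector<int64_t> send(maxRank + 1, -1), recv(maxRank + 1, -1);
  for (size_t i = 0; i < pairs.size(); i += 2) {
    int64_t source = pairs[i], target = pairs[i + 1];
    send[source] = target;
    recv[target] = source;
  }

  for (int64_t r = 0; r <= maxRank; ++r)
    std::cout << "rank " << r << ": send=" << send[r]
              << " recv=" << recv[r] << "\n";
  // rank 0: send=1 recv=2; rank 1: send=2 recv=0; rank 2: send=0 recv=1.
}
```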
- Value rank = - rewriter.create(loc, channel).getResult(); - Value send = rewriter.create(loc, rank, noSendOrRecv, - sendTable); - Value recv = rewriter.create(loc, rank, noSendOrRecv, - recvTable); - - // Create an empty tensor for the result. - auto input = adaptor.getOperand(); - ArrayRef inputShape = inputType.getShape(); - Value target = rewriter.create( - loc, inputShape, getElementTypeOrSelf(input.getType())); - auto collectiveSendRecvOp = - rewriter.create( - op.getLoc(), elementTypeAttr, target, input, channel, send, recv); - - rewriter.replaceOp(op, collectiveSendRecvOp.getResult()); - return success(); - } -}; - -void populateMHLOCollectiveOpsConversionPatterns(MLIRContext *context, - TypeConverter &typeConverter, - RewritePatternSet &patterns) { - patterns.insert(typeConverter, context); - patterns.insert(typeConverter, context); - patterns.insert(typeConverter, context); - patterns.insert(typeConverter, context); - patterns.insert(typeConverter, context); - patterns.insert(typeConverter, context); - patterns.insert(typeConverter, context); -} - -} // namespace MHLO -} // namespace iree_compiler -} // namespace mlir diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/ConvertComplexToReal.cpp b/compiler/src/iree/compiler/InputConversion/MHLO/ConvertComplexToReal.cpp deleted file mode 100644 index 5052338cba1b..000000000000 --- a/compiler/src/iree/compiler/InputConversion/MHLO/ConvertComplexToReal.cpp +++ /dev/null @@ -1,535 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree/compiler/InputConversion/MHLO/PassDetail.h" -#include "iree/compiler/InputConversion/MHLO/Passes.h" -#include "iree/compiler/InputConversion/MHLO/Rewriters.h" -#include "mhlo/IR/hlo_ops.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Transforms/DialectConversion.h" -#include "stablehlo/dialect/ChloOps.h" - -namespace mlir { -namespace iree_compiler { -namespace MHLO { - -namespace { - -inline std::optional chloComparisonDirection( - mhlo::ComparisonDirection value) { - switch (value) { - case mhlo::ComparisonDirection::EQ: - return chlo::ComparisonDirection::EQ; - case mhlo::ComparisonDirection::NE: - return chlo::ComparisonDirection::NE; - case mhlo::ComparisonDirection::GE: - return chlo::ComparisonDirection::GE; - case mhlo::ComparisonDirection::GT: - return chlo::ComparisonDirection::GT; - case mhlo::ComparisonDirection::LE: - return chlo::ComparisonDirection::LE; - case mhlo::ComparisonDirection::LT: - return chlo::ComparisonDirection::LT; - default: - return {}; - } -} - -inline std::optional chloComparisonType( - mhlo::ComparisonType value) { - switch (value) { - case mhlo::ComparisonType::NOTYPE: - return chlo::ComparisonType::NOTYPE; - case mhlo::ComparisonType::FLOAT: - return chlo::ComparisonType::FLOAT; - case mhlo::ComparisonType::TOTALORDER: - return chlo::ComparisonType::TOTALORDER; - case mhlo::ComparisonType::SIGNED: - return chlo::ComparisonType::SIGNED; - case mhlo::ComparisonType::UNSIGNED: - return chlo::ComparisonType::UNSIGNED; - default: - return {}; - } -} - -bool isComplexTensor(Value v) { - if (auto tt = llvm::dyn_cast(v.getType())) { - return llvm::isa(tt.getElementType()); - } - return false; -} - -Type convertComplexTensorTypeToReal(Type complexTensorType) { - auto newElementType = - llvm::cast( - 
complexTensorType.cast().getElementType()) - .getElementType(); - if (auto tt = llvm::dyn_cast(complexTensorType)) { - return RankedTensorType::get(tt.getShape(), newElementType, - tt.getEncoding()); - } else if (auto tt = llvm::dyn_cast(complexTensorType)) { - return UnrankedTensorType::get(newElementType); - } - assert(false && "unknown TensorType subclass"); - return Type(); -} - -// Add and subtraction are elementwise and can be distributed across the real -// and imaginary components. -template -struct ConvertAddSubOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - static Value createOp(OpBuilder &b, mhlo::AddOp op, Value lhs, Value rhs) { - return b.create(op.getLoc(), lhs, rhs); - } - static Value createOp(OpBuilder &b, mhlo::SubtractOp op, Value lhs, - Value rhs) { - return b.create(op.getLoc(), lhs, rhs); - } - static Value createOp(OpBuilder &b, chlo::BroadcastAddOp op, Value lhs, - Value rhs) { - return b.create(op.getLoc(), lhs, rhs, nullptr); - } - static Value createOp(OpBuilder &b, chlo::BroadcastSubOp op, Value lhs, - Value rhs) { - return b.create(op.getLoc(), lhs, rhs, nullptr); - } - - LogicalResult matchAndRewrite( - OpTy op, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - if (!isComplexTensor(adaptor.getLhs()) || - !isComplexTensor(adaptor.getRhs())) { - return rewriter.notifyMatchFailure(op, "not complex tensor"); - } - - Value real = - createOp(rewriter, op, - rewriter.createOrFold(loc, adaptor.getLhs()), - rewriter.createOrFold(loc, adaptor.getRhs())); - Value imag = - createOp(rewriter, op, - rewriter.createOrFold(loc, adaptor.getLhs()), - rewriter.createOrFold(loc, adaptor.getRhs())); - Value result = rewriter.create(loc, real, imag); - rewriter.replaceOp(op, result); - return success(); - } -}; - -// Complex multiplication results in a cross product multiplication between the -// real and imaginary components such that: -// result.real = lhs.real * rhs.real - lhs.imag * rhs.imag -// result.imag = lhs.imag * rhs.real + lhs.real * rhs.imag -template -struct ConvertMulOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - MulOpTy op, typename MulOpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - - if (!isComplexTensor(adaptor.getLhs()) || - !isComplexTensor(adaptor.getRhs())) { - return rewriter.notifyMatchFailure(op, "not complex tensor"); - } - - auto lhsReal = rewriter.createOrFold(loc, adaptor.getLhs()); - auto lhsImag = rewriter.createOrFold(loc, adaptor.getLhs()); - auto rhsReal = rewriter.createOrFold(loc, adaptor.getRhs()); - auto rhsImag = rewriter.createOrFold(loc, adaptor.getRhs()); - - auto realComponent = rewriter.create( - loc, - rewriter.create(loc, lhsReal, rhsReal, - /*broadcast_dimensions=*/nullptr), - rewriter.create( - loc, lhsImag, rhsImag, /*broadcast_dimensions=*/nullptr)); - auto imagComponent = rewriter.create( - loc, - rewriter.create(loc, lhsReal, rhsImag, - /*broadcast_dimensions=*/nullptr), - rewriter.create( - loc, lhsImag, rhsReal, /*broadcast_dimensions=*/nullptr)); - Value result = rewriter.createOrFold(loc, realComponent, - imagComponent); - rewriter.replaceOp(op, result); - return success(); - } -}; - -// Division is performed by normalizing the denominator by multiplying by the -// conjugate of the rhs. 
-// numerator = lhs * conj(rhs) -// denominator = rhs * conj(rhs) -template -struct ConvertDivOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - DivOpTy op, typename DivOpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - - if (!isComplexTensor(adaptor.getLhs()) || - !isComplexTensor(adaptor.getRhs())) { - return rewriter.notifyMatchFailure(op, "not complex tensor"); - } - - auto lhs = adaptor.getLhs(); - auto rhs = adaptor.getRhs(); - auto rhsReal = rewriter.createOrFold(loc, rhs); - auto rhsImag = rewriter.createOrFold(loc, rhs); - - Value conj = rewriter.createOrFold( - loc, rhsReal, rewriter.create(loc, rhsImag)); - Value complexNumerator = rewriter.create( - loc, lhs, conj, /*broadcast_dimensions=*/nullptr); - Value denominator = rewriter.create( - loc, rewriter.create(loc, rhsReal, rhsReal), - rewriter.create(loc, rhsImag, rhsImag)); - - Value realComponent = rewriter.create( - loc, rewriter.create(loc, complexNumerator), denominator, - /*broadcast_dimensions=*/nullptr); - Value imagComponent = rewriter.create( - loc, rewriter.create(loc, complexNumerator), denominator, - /*broadcast_dimensions=*/nullptr); - - Value result = rewriter.createOrFold(loc, realComponent, - imagComponent); - rewriter.replaceOp(op, result); - return success(); - } -}; - -// Absolute value is evaluated as: -// result = sqrt(val.real * val.real + val.imag * val.imag) -struct ConvertAbsOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - mhlo::AbsOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - - if (!isComplexTensor(adaptor.getOperand())) { - return rewriter.notifyMatchFailure(op, "not complex tensor"); - } - - auto operandReal = - rewriter.createOrFold(loc, adaptor.getOperand()); - auto operandImag = - rewriter.createOrFold(loc, adaptor.getOperand()); - rewriter.replaceOpWithNewOp( - op, - rewriter.create( - loc, rewriter.create(loc, operandReal, operandReal), - rewriter.create(loc, operandImag, operandImag))); - return success(); - } -}; - -// Exponential can be lowered to an exponential on the real component and a -// sum of sinusoids of the imaginary component, which equates to a normal -// exponential operator multiplied by Euler's formula. 
-// -// Exp(a + ib) = Exp(a) * Exp(ib) = Exp(a) * Cos(b) + Exp(a) * iSin(b)) -struct ConvertExpOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - mhlo::ExpOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - - if (!isComplexTensor(adaptor.getOperand())) { - return rewriter.notifyMatchFailure(op, "not complex tensor"); - } - - auto operandReal = rewriter.create(loc, adaptor.getOperand()); - auto operandImag = rewriter.create(loc, adaptor.getOperand()); - - Value expReal = rewriter.create(loc, operandReal); - Value result = rewriter.createOrFold( - loc, - rewriter.create( - loc, rewriter.create(loc, operandImag), expReal), - rewriter.create( - loc, rewriter.create(loc, operandImag), expReal)); - rewriter.replaceOp(op, result); - return success(); - } -}; - -template -struct ConvertCHLOCompareOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - ConvertCHLOCompareOp(TypeConverter &typeConverter, MLIRContext *context, - chlo::ComparisonDirection direction) - : OpConversionPattern(typeConverter, context), - direction(direction) {} - - LogicalResult matchAndRewrite( - CompareOpTy op, typename CompareOpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - - if (!isComplexTensor(adaptor.getLhs()) || - !isComplexTensor(adaptor.getRhs())) { - return rewriter.notifyMatchFailure(op, "not complex tensor"); - } - if (direction != op.getComparisonDirection()) { - return rewriter.notifyMatchFailure(op, "not matching direction"); - } - - auto lhs = adaptor.getLhs(); - auto rhs = adaptor.getRhs(); - auto lhsReal = rewriter.createOrFold(loc, lhs); - auto lhsImag = rewriter.createOrFold(loc, lhs); - auto rhsReal = rewriter.createOrFold(loc, rhs); - auto rhsImag = rewriter.createOrFold(loc, rhs); - - rewriter.replaceOpWithNewOp( - op, - rewriter.create( - loc, lhsReal, rhsReal, - /*broadcast_dimensions=*/nullptr, - adaptor.getComparisonDirectionAttr(), adaptor.getCompareTypeAttr()), - rewriter.create( - loc, lhsImag, rhsImag, - /*broadcast_dimensions=*/nullptr, - adaptor.getComparisonDirectionAttr(), - adaptor.getCompareTypeAttr())); - - return success(); - } - - chlo::ComparisonDirection direction; -}; - -template -struct ConvertMHLOCompareOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - ConvertMHLOCompareOp(TypeConverter &typeConverter, MLIRContext *context, - mhlo::ComparisonDirection direction) - : OpConversionPattern(typeConverter, context), - direction(direction) {} - - LogicalResult matchAndRewrite( - CompareOpTy op, typename CompareOpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - - if (!isComplexTensor(adaptor.getLhs()) || - !isComplexTensor(adaptor.getRhs())) { - return rewriter.notifyMatchFailure(op, "not complex tensor"); - } - if (direction != op.getComparisonDirection()) { - return rewriter.notifyMatchFailure(op, "not matching direction"); - } - - auto lhs = adaptor.getLhs(); - auto rhs = adaptor.getRhs(); - auto lhsReal = rewriter.createOrFold(loc, lhs); - auto lhsImag = rewriter.createOrFold(loc, lhs); - auto rhsReal = rewriter.createOrFold(loc, rhs); - auto rhsImag = rewriter.createOrFold(loc, rhs); - - // If the input op is an mhlo op, we need to convert the attributes to the - // corresponding chlo one.. 
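The algebraic identities relied on by the mul, div, abs, and exp patterns above can be spot-checked numerically against `std::complex`. A minimal standalone check, with arbitrary assumed operand values:

```c++
#include <cassert>
#include <cmath>
#include <complex>

int main() {
  const std::complex<double> a(1.5, -2.0), b(0.25, 3.0);

  // Multiplication as a cross product of components.
  std::complex<double> mul(a.real() * b.real() - a.imag() * b.imag(),
                           a.imag() * b.real() + a.real() * b.imag());
  assert(std::abs(mul - a * b) < 1e-12);

  // Division by normalizing with the conjugate of the denominator.
  std::complex<double> num = a * std::conj(b);
  double den = b.real() * b.real() + b.imag() * b.imag();
  std::complex<double> div(num.real() / den, num.imag() / den);
  assert(std::abs(div - a / b) < 1e-12);

  // |a| = sqrt(re^2 + im^2).
  assert(std::abs(std::sqrt(a.real() * a.real() + a.imag() * a.imag()) -
                  std::abs(a)) < 1e-12);

  // exp(a) = exp(re) * (cos(im) + i*sin(im)), i.e. Euler's formula.
  std::complex<double> e(std::exp(a.real()) * std::cos(a.imag()),
                         std::exp(a.real()) * std::sin(a.imag()));
  assert(std::abs(e - std::exp(a)) < 1e-12);
}
```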
- chlo::ComparisonDirection chloCmpDirection = - *chloComparisonDirection(adaptor.getComparisonDirection()); - - std::optional mhloCmpType = adaptor.getCompareType(); - chlo::ComparisonTypeAttr chloCmpType; - if (mhloCmpType) - chloCmpType = chlo::ComparisonTypeAttr::get( - rewriter.getContext(), *chloComparisonType(*mhloCmpType)); - - rewriter.replaceOpWithNewOp( - op, - rewriter.create( - loc, lhsReal, rhsReal, - /*broadcast_dimensions=*/nullptr, chloCmpDirection, chloCmpType), - rewriter.create( - loc, lhsImag, rhsImag, - /*broadcast_dimensions=*/nullptr, chloCmpDirection, chloCmpType)); - - return success(); - } - - mhlo::ComparisonDirection direction; -}; - -struct ElideComplexPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - mhlo::ComplexOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - rewriter.eraseOp(op); - return success(); - } -}; - -struct ElideRealPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - mhlo::RealOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto complexProducer = - adaptor.getOperands()[0].getDefiningOp(); - if (complexProducer) { - rewriter.replaceOp(op, complexProducer.getLhs()); - return success(); - } - return failure(); - } -}; - -struct ElideImagPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - mhlo::ImagOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto complexProducer = - adaptor.getOperands()[0].getDefiningOp(); - if (complexProducer) { - rewriter.replaceOp(op, complexProducer.getRhs()); - return success(); - } - return failure(); - } -}; - -} // namespace - -void populateMHLOComplexToRealPatterns(MLIRContext *context, - TypeConverter &typeConverter, - RewritePatternSet &patterns) { - // Add an subtract patterns. - patterns.insert>(typeConverter, context); - patterns.insert>(typeConverter, context); - patterns.insert>(typeConverter, - context); - patterns.insert>(typeConverter, - context); - - // Mul patterns. - patterns.insert>(typeConverter, context); - patterns.insert>(typeConverter, context); - - // Div patterns. - patterns.insert>(typeConverter, context); - patterns.insert>(typeConverter, context); - - // Unary ops. - patterns.insert(typeConverter, context); - patterns.insert(typeConverter, context); - - // Compare ops. - patterns.insert>( - typeConverter, context, mhlo::ComparisonDirection::NE); - patterns.insert>( - typeConverter, context, mhlo::ComparisonDirection::EQ); - patterns.insert>( - typeConverter, context, chlo::ComparisonDirection::NE); - patterns.insert>( - typeConverter, context, chlo::ComparisonDirection::EQ); - - // Complex/Real/Imag conversions should fold away. - // Note that this is an opinion taken because these patterns are targeted - // at full conversion scenarios and we would rather know eagerly if - // conversion is not possible. A more lax conversion would not include the - // ElideComplexPattern. - // Doing it this way makes error messages nice because a failure will report - // which remaining live op is keeping it from being erased. 
- patterns.insert(typeConverter, context, 0); - patterns.insert(typeConverter, context); - patterns.insert(typeConverter, context); -} - -namespace { - -struct TestMHLOConvertComplexToRealPass - : public TestMHLOConvertComplexToRealBase< - TestMHLOConvertComplexToRealPass> { - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override { - RewritePatternSet patterns(&getContext()); - MLIRContext *context = &getContext(); - TypeConverter typeConverter; - typeConverter.addConversion([](Type t) { return t; }); - - populateMHLOComplexToRealPatterns(context, typeConverter, patterns); - - ConversionTarget target(*context); - auto hasNoComplexTypes = [](Operation *op) { - for (Value operand : op->getOperands()) { - if (auto st = llvm::dyn_cast(operand.getType())) { - if (llvm::isa(st.getElementType())) { - return false; - } - } - } - for (Value result : op->getResults()) { - if (auto st = llvm::dyn_cast(result.getType())) { - if (llvm::isa(st.getElementType())) { - return false; - } - } - } - return true; - }; - - target.addLegalDialect(); - target.addLegalDialect(); - target.addLegalDialect(); - - // For the test, require that casts fully convert. - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - - // Binary elementwise. - target.addDynamicallyLegalOp(hasNoComplexTypes); - target.addDynamicallyLegalOp(hasNoComplexTypes); - target.addDynamicallyLegalOp(hasNoComplexTypes); - target.addDynamicallyLegalOp(hasNoComplexTypes); - target.addDynamicallyLegalOp(hasNoComplexTypes); - target.addDynamicallyLegalOp(hasNoComplexTypes); - target.addDynamicallyLegalOp(hasNoComplexTypes); - target.addDynamicallyLegalOp(hasNoComplexTypes); - - // Unary. - target.addDynamicallyLegalOp(hasNoComplexTypes); - target.addDynamicallyLegalOp(hasNoComplexTypes); - - // Compare. - target.addDynamicallyLegalOp(hasNoComplexTypes); - target.addDynamicallyLegalOp(hasNoComplexTypes); - - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) { - return signalPassFailure(); - } - } -}; - -} // namespace - -std::unique_ptr> -createTestMHLOConvertComplexToRealPass() { - return std::make_unique(); -} - -} // namespace MHLO -} // namespace iree_compiler -} // namespace mlir diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/ConvertMHLOToFlow.cpp b/compiler/src/iree/compiler/InputConversion/MHLO/ConvertMHLOToFlow.cpp deleted file mode 100644 index 483b7ad01e2a..000000000000 --- a/compiler/src/iree/compiler/InputConversion/MHLO/ConvertMHLOToFlow.cpp +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2019 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/InputConversion/MHLO/ConvertMHLOToFlow.h"
-
-#include <memory>
-
-#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
-#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
-#include "mhlo/IR/hlo_ops.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/PatternMatch.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace MHLO {
-
-namespace {
-
-struct ConstOpLowering : public OpRewritePattern<mhlo::ConstantOp> {
-  using OpRewritePattern<mhlo::ConstantOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(mhlo::ConstantOp op,
-                                PatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.getValue());
-    return success();
-  }
-};
-
-} // namespace
-
-void setupDirectMHLOToFlowLegality(MLIRContext *context,
-                                   ConversionTarget &conversionTarget) {
-  conversionTarget.addIllegalOp<mhlo::ConstantOp>();
-}
-
-void populateMHLOToFlowPatterns(MLIRContext *context,
-                                RewritePatternSet &patterns) {
-  patterns.insert<ConstOpLowering>(context);
-}
-
-} // namespace MHLO
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/ConvertMHLOToFlow.h b/compiler/src/iree/compiler/InputConversion/MHLO/ConvertMHLOToFlow.h
deleted file mode 100644
index d81d5b2f0176..000000000000
--- a/compiler/src/iree/compiler/InputConversion/MHLO/ConvertMHLOToFlow.h
+++ /dev/null
@@ -1,32 +0,0 @@
-// Copyright 2019 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_COMPILER_INPUTCONVERSION_MHLO_CONVERTMHLOTOFLOW_H_
-#define IREE_COMPILER_INPUTCONVERSION_MHLO_CONVERTMHLOTOFLOW_H_
-
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace MHLO {
-
-// Setup the |conversionTarget| op legality for early-phase direct-to-flow
-// conversion from the MHLO dialect. This will make certain ops illegal that we
-// know we have good patterns for such that we can be sure we catch them before
-// they are outlined into dispatch regions.
-void setupDirectMHLOToFlowLegality(MLIRContext *context,
-                                   ConversionTarget &conversionTarget);
-
-// Appends all patterns for converting MHLO ops to flow ops.
-void populateMHLOToFlowPatterns(MLIRContext *context,
-                                RewritePatternSet &patterns);
-
-} // namespace MHLO
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_INPUTCONVERSION_MHLO_CONVERTMHLOTOFLOW_H_
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp b/compiler/src/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
deleted file mode 100644
index a1508218d121..000000000000
--- a/compiler/src/iree/compiler/InputConversion/MHLO/ConvertMHLOToLinalgExt.cpp
+++ /dev/null
@@ -1,617 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include -#include - -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h" -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" -#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" -#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" -#include "iree/compiler/Dialect/Util/IR/UtilOps.h" -#include "iree/compiler/InputConversion/MHLO/PassDetail.h" -#include "iree/compiler/InputConversion/MHLO/Passes.h" -#include "iree/compiler/InputConversion/MHLO/Rewriters.h" -#include "mhlo/IR/hlo_ops.h" -#include "mhlo/transforms/map_mhlo_to_scalar_op.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/Utils/Utils.h" -#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/DialectConversion.h" -#include "stablehlo/dialect/ChloOps.h" - -namespace mlir { -namespace iree_compiler { -namespace MHLO { - -namespace { - -static Type convertIntegerToSignless(IntegerType intType) { - return IntegerType::get(intType.getContext(), - intType.getIntOrFloatBitWidth()); -} - -static std::optional convertRank0TensorToScalar( - RankedTensorType tensorType) { - if (tensorType.getRank() != 0) return std::nullopt; - Type elementType = tensorType.getElementType(); - if (auto intType = llvm::dyn_cast(elementType)) { - elementType = convertIntegerToSignless(intType); - } - return elementType; -} - -static Type convertShapedToSignless(ShapedType shapedType) { - if (auto intType = llvm::dyn_cast(shapedType.getElementType())) - return shapedType.clone(convertIntegerToSignless(intType)); - return shapedType; -} - -static std::optional materializeCast(OpBuilder &builder, Type toType, - ValueRange inputs, Location loc) { - assert(inputs.size() == 1 && "too many inputs to type conversion"); - Value fromValue = inputs[0]; - auto fromType = llvm::dyn_cast(fromValue.getType()); - if (!fromType) return std::nullopt; - - if (auto intFromType = - llvm::dyn_cast(fromType.getElementType())) { - Type castType = getElementTypeOrSelf(toType); - if (auto shapedType = llvm::dyn_cast(fromType)) - castType = shapedType.clone(castType); - - if (castType != fromType) - fromValue = builder.create(loc, castType, fromValue) - ->getResult(0); - } - - if (fromType.getRank() != 0) return fromValue; - - Type extractType = getElementTypeOrSelf(toType); - return builder.createOrFold(loc, extractType, fromValue); -} - -/// Note: only designed to work for casts involving rank-0 tensors and scalars -/// implicitly captured within op regions. 
-class MhloToStdTypeConverter : public TypeConverter { - public: - MhloToStdTypeConverter() { - addConversion([](Type type) { return type; }); - - addConversion(convertShapedToSignless); - addConversion(convertRank0TensorToScalar); - addConversion(convertIntegerToSignless); - - addArgumentMaterialization(materializeCast); - addSourceMaterialization(materializeCast); - addTargetMaterialization(materializeCast); - } -}; - -//===----------------------------------------------------------------------===// -// Utils -//===----------------------------------------------------------------------===// - -static bool isInBodyOfLinalgExtOps(Operation *op) { - auto parent_op = op->getParentRegion()->getParentOp(); - return parent_op->getDialect() == - parent_op->getContext() - ->getLoadedDialect(); -} - -//===----------------------------------------------------------------------===// -// Region operations lowering. -//===----------------------------------------------------------------------===// - -template -struct LinalgExtRegionHLOOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - OpTy op, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - if (!isInBodyOfLinalgExtOps(op)) return failure(); - TensorType origRetType = llvm::dyn_cast(op.getType()); - if (!origRetType) return failure(); - SmallVector scalarArgs; - Type newRetType = getElementTypeOrSelf( - this->typeConverter->convertType(origRetType.getElementType())); - Value result = mhlo::MhloOpToStdScalarOp::mapOp( - op, newRetType, adaptor.getOperands(), &rewriter); - rewriter.replaceOp(op, result); - return success(); - } -}; - -struct LinalgExtRegionReturnOpConversion - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - mhlo::ReturnOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - if (!isInBodyOfLinalgExtOps(op)) return failure(); - rewriter.replaceOpWithNewOp( - op, adaptor.getOperands()); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// SortOp -//===----------------------------------------------------------------------===// - -struct SortOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - mhlo::SortOp mhloSortOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - Location loc = mhloSortOp.getLoc(); - - llvm::SmallVector resultTypes; - if (this->typeConverter - ->convertTypes(mhloSortOp.getResultTypes(), resultTypes) - .failed()) { - return failure(); - }; - auto sortOp = rewriter.create( - loc, resultTypes, - /*inputs=*/ValueRange{}, adaptor.getOperands(), - mhloSortOp.getDimensionAttr()); - rewriter.inlineRegionBefore(mhloSortOp.getComparator(), sortOp.getRegion(), - sortOp.getRegion().begin()); - Region ®ion = sortOp.getRegion(); - Block &block = region.front(); - TypeConverter::SignatureConversion signature_converter( - block.getNumArguments()); - for (auto en : llvm::enumerate(block.getArguments())) { - signature_converter.addInputs( - en.index(), this->typeConverter->convertType( - getElementTypeOrSelf(en.value().getType()))); - } - rewriter.applySignatureConversion(®ion, signature_converter); - - rewriter.replaceOp(mhloSortOp, sortOp->getResults()); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// 
ScatterOp -//===----------------------------------------------------------------------===// - -struct ScatterOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - /// Returns true if the `dimensionNumbers` from the mhlo.scatter op follows a - /// canonical form: - /// - /// * The rank of indices is greater than or equal to two. - /// * The index_vector_dim is the last dim of indices. - /// * Scatter dims to operand dims order: (0, ... , n) - /// * Inserted window dims order: (0, ... , d) - /// * Update window dims order: (d + 1, ... , m) - static bool hasCanonicalDimensionNumbers(mhlo::ScatterOp op) { - auto dimNumbers = op.getScatterDimensionNumbers(); - auto indicesType = llvm::cast(op.getScatterIndices().getType()); - auto indicesRank = indicesType.getRank(); - auto indexVectorDim = dimNumbers.getIndexVectorDim(); - auto indexDepth = indicesType.getShape().back(); - auto scatterDimsToOperandDims = dimNumbers.getScatterDimsToOperandDims(); - - if (indicesRank != 2) return false; - if (indexVectorDim != indicesRank - 1) return false; - if (scatterDimsToOperandDims.size() != indexDepth) return false; - - auto insertedWindowDims = dimNumbers.getInsertedWindowDims(); - for (auto en : llvm::enumerate(insertedWindowDims)) { - if (en.index() != en.value()) return false; - } - - // Check that there is only one batch dimension in the updates. - for (auto en : llvm::enumerate(dimNumbers.getUpdateWindowDims())) { - if (en.index() + 1 != en.value()) return false; - } - - return true; - } - - LogicalResult matchAndRewrite( - mhlo::ScatterOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - if (!hasCanonicalDimensionNumbers(op)) return failure(); - if (llvm::size(op.getInputs()) != 1) - return op.emitError("NYI variadic operands scatter"); - if (llvm::size(op.getUpdates()) != 1) - return op.emitError("NYI variadic updates scatter"); - - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - - Value original = adaptor.getInputs().front(); - Value indices = adaptor.getScatterIndices(); - Value updates = adaptor.getUpdates().front(); - - llvm::SmallVector scatterDimMap; - for (auto dim : - op.getScatterDimensionNumbers().getScatterDimsToOperandDims()) { - scatterDimMap.push_back(dim); - } - - auto scatterOp = rewriter.create( - op.getLoc(), op->getResultTypes(), ValueRange{updates, indices}, - ValueRange{original}, scatterDimMap, op.getUniqueIndices()); - - rewriter.inlineRegionBefore(op.getUpdateComputation(), - scatterOp.getRegion(), - scatterOp.getRegion().begin()); - Region ®ion = scatterOp.getRegion(); - TypeConverter::SignatureConversion signatureConverter(2); - Type argType = getElementTypeOrSelf(original.getType()); - // mhlo.scatter ops takes: - // output[O] = update_computation(output[O], updates[U]) - // where output[O] maps to block args #1 in linalg_ext.scatter ops. 
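The canonical-form conditions checked by `hasCanonicalDimensionNumbers` above can be restated over plain data and exercised on one assumed canonical example (`[N, 1]` indices scattering rows of a rank-2 operand); a standalone sketch, placed here before the signature conversion that follows:

```c++
#include <cassert>
#include <cstdint>
#include <vector>

struct ScatterDims {
  int64_t indicesRank;
  int64_t indexVectorDim;
  int64_t indexDepth;  // size of the last dim of the indices shape
  std::vector<int64_t> scatterDimsToOperandDims;
  std::vector<int64_t> insertedWindowDims;
  std::vector<int64_t> updateWindowDims;
};

// Mirrors the conditions of hasCanonicalDimensionNumbers.
bool isCanonical(const ScatterDims &d) {
  if (d.indicesRank != 2) return false;
  if (d.indexVectorDim != d.indicesRank - 1) return false;
  if ((int64_t)d.scatterDimsToOperandDims.size() != d.indexDepth) return false;
  // Inserted window dims must be (0, ..., d).
  for (size_t i = 0; i < d.insertedWindowDims.size(); ++i)
    if ((int64_t)i != d.insertedWindowDims[i]) return false;
  // Update window dims must be (d + 1, ..., m): one batch dim in updates.
  for (size_t i = 0; i < d.updateWindowDims.size(); ++i)
    if ((int64_t)i + 1 != d.updateWindowDims[i]) return false;
  return true;
}

int main() {
  // e.g. scattering rows into an operand with [N, 1] indices.
  ScatterDims d{/*indicesRank=*/2, /*indexVectorDim=*/1, /*indexDepth=*/1,
                /*scatterDimsToOperandDims=*/{0},
                /*insertedWindowDims=*/{0},
                /*updateWindowDims=*/{1}};
  assert(isCanonical(d));
}
```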
- signatureConverter.addInputs(1, argType); - signatureConverter.addInputs(0, argType); - rewriter.applySignatureConversion(®ion, signatureConverter); - - rewriter.replaceOp(op, scatterOp->getResults()); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// FftOp -//===----------------------------------------------------------------------===// - -struct FftOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - static Value getBitReversalBuffer(ImplicitLocOpBuilder &b, int fftLength) { - SmallVector values; - int logn = std::log(fftLength) / std::log(2); - for (int i = 0; i < fftLength; ++i) { - int r = 0; - for (int j = 0; j < logn; ++j) { - r |= ((i >> j) & 1) << (logn - j - 1); - } - values.push_back(b.getI32IntegerAttr(r)); - } - auto type = RankedTensorType::get({fftLength}, b.getI32Type()); - return b.create(type, - DenseIntElementsAttr::get(type, values)); - } - - static SmallVector getBitReversalOrder(ImplicitLocOpBuilder &b, - Value real, int fftLength) { - auto realType = llvm::cast(real.getType()); - auto rank = realType.getRank(); - - SmallVector mixedSizes = - tensor::createDimValues(b, b.getLoc(), real); - Value emptyTensor = - b.create(mixedSizes, realType.getElementType()); - - SmallVector maps; - maps.push_back( - AffineMap::get(rank, 0, b.getAffineDimExpr(rank - 1), b.getContext())); - maps.push_back(b.getMultiDimIdentityMap(rank)); - SmallVector iterTypes(rank, - utils::IteratorType::parallel); - - Value indices = getBitReversalBuffer(b, fftLength); - auto genericOp = b.create( - TypeRange{realType}, indices, emptyTensor, maps, iterTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - SmallVector ivs; - for (auto i : llvm::seq(0, rank - 1)) { - ivs.push_back(b.create(loc, i)); - } - ivs.push_back( - b.create(loc, b.getIndexType(), args[0])); - b.create( - loc, b.create(loc, real, ivs).getResult()); - }); - return {genericOp.getResult(0), - b.create( - realType, - DenseFPElementsAttr::get( - realType, llvm::cast(b.getF32FloatAttr(0.0))))}; - } - - static SmallVector getCoeffConstants(ImplicitLocOpBuilder &b, - int stage) { - constexpr std::complex kI(0, 1); - int m = 1 << stage; - int mh = m >> 1; - SmallVector real, imag; - for (auto i : llvm::seq(0, mh)) { - auto v = std::exp(-2 * M_PI * i / m * kI); - real.push_back(b.getF32FloatAttr(v.real())); - imag.push_back(b.getF32FloatAttr(v.imag())); - } - auto type = RankedTensorType::get({mh}, b.getF32Type()); - return { - b.create(type, DenseFPElementsAttr::get(type, real)), - b.create(type, - DenseFPElementsAttr::get(type, imag))}; - } - - LogicalResult matchAndRewrite( - mhlo::FftOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - // Only handle 2^n fft length. 
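The two constant tables the FFT lowering builds above, the bit-reversal permutation from `getBitReversalBuffer` and the per-stage twiddle factors from `getCoeffConstants`, can be reproduced standalone. A sketch for an assumed `fftLength` of 8:

```c++
#include <cmath>
#include <complex>
#include <iostream>

int main() {
  const int fftLength = 8;
  const int logn = std::log(fftLength) / std::log(2);  // 3

  // Bit-reversal permutation, as in getBitReversalBuffer.
  for (int i = 0; i < fftLength; ++i) {
    int r = 0;
    for (int j = 0; j < logn; ++j) r |= ((i >> j) & 1) << (logn - j - 1);
    std::cout << r << " ";  // prints: 0 4 2 6 1 5 3 7
  }
  std::cout << "\n";

  // Twiddle factors for stage s, as in getCoeffConstants:
  // exp(-2*pi*i*k/m) for k in [0, m/2), with m = 1 << s.
  const std::complex<double> kI(0, 1);
  for (int s = 1; s <= logn; ++s) {
    const int m = 1 << s, mh = m >> 1;
    for (int k = 0; k < mh; ++k) {
      std::complex<double> w = std::exp(-2.0 * M_PI * k / m * kI);
      std::cout << "stage " << s << " w[" << k << "]=(" << w.real() << ","
                << w.imag() << ") ";
    }
    std::cout << "\n";
  }
}
```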
-    auto operandType =
-        llvm::dyn_cast<RankedTensorType>(adaptor.getOperand().getType());
-    if (!operandType || !operandType.hasStaticShape()) {
-      return failure();
-    }
-    int fftLength = op.getFftLength().getSplatValue<IntegerAttr>().getInt();
-    if (fftLength & (fftLength - 1)) {
-      return rewriter.notifyMatchFailure(
-          op, "expected FFT length to be a power of two");
-    }
-
-    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-    SmallVector<Value> results =
-        getBitReversalOrder(b, adaptor.getOperand(), fftLength);
-    int lognPlus1 = std::log(fftLength) / std::log(2) + 1;
-    for (auto s : llvm::seq<unsigned>(1, lognPlus1)) {
-      SmallVector<Value> inputs;
-      inputs.push_back(b.create<arith::ConstantIndexOp>(s));
-      inputs.append(getCoeffConstants(b, s));
-      auto fft = b.create<IREE::LinalgExt::FftOp>(
-          TypeRange{results[0].getType(), results[1].getType()}, inputs,
-          results);
-      results = fft.getResults();
-    }
-
-    SmallVector<int64_t> shape(operandType.getShape().begin(),
-                               operandType.getShape().end());
-    shape.back() = fftLength / 2 + 1;
-    auto ty = RankedTensorType::get(shape, operandType.getElementType());
-    SmallVector<OpFoldResult> offsets(ty.getRank(), b.getIndexAttr(0));
-    SmallVector<OpFoldResult> strides(ty.getRank(), b.getIndexAttr(1));
-    SmallVector<OpFoldResult> sizes =
-        tensor::createDimValues(b, b.getLoc(), adaptor.getOperand());
-    sizes.back() = b.getIndexAttr(shape.back());
-    auto real = b.create<tensor::ExtractSliceOp>(ty, results[0], offsets,
-                                                 sizes, strides);
-    auto imag = b.create<tensor::ExtractSliceOp>(ty, results[1], offsets,
-                                                 sizes, strides);
-    rewriter.replaceOpWithNewOp<mhlo::ComplexOp>(op, op.getType(), real, imag);
-    return success();
-  }
-};
-
-//===----------------------------------------------------------------------===//
-// ReverseOp
-//===----------------------------------------------------------------------===//
-
-struct ReverseOpConversion : public OpConversionPattern<mhlo::ReverseOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult matchAndRewrite(
-      mhlo::ReverseOp op, OpAdaptor adaptor,
-      ConversionPatternRewriter &rewriter) const final {
-    auto ty =
-        llvm::dyn_cast<RankedTensorType>(adaptor.getOperands()[0].getType());
-    if (!ty) return failure();
-
-    Location loc = op.getLoc();
-    SmallVector<OpFoldResult> mixedSizes =
-        tensor::createDimValues(rewriter, loc, adaptor.getOperands()[0]);
-    Value emptyTensor =
-        rewriter.create<tensor::EmptyOp>(loc, mixedSizes, ty.getElementType());
-    rewriter.replaceOpWithNewOp<IREE::LinalgExt::ReverseOp>(
-        op, typeConverter->convertType(op.getType()), adaptor.getOperands(),
-        emptyTensor, op.getDimensions());
-    return success();
-  }
-};
-
-//===----------------------------------------------------------------------===//
-// TopkOp
-//===----------------------------------------------------------------------===//
-
-struct TopkOpConversion : public OpConversionPattern<chlo::TopKOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult matchAndRewrite(
-      chlo::TopKOp op, OpAdaptor adaptor,
-      ConversionPatternRewriter &rewriter) const final {
-    Location loc = op.getLoc();
-    Value operand = adaptor.getOperand();
-
-    auto inputValuesType = llvm::dyn_cast<ShapedType>(operand.getType());
-    auto outputValuesType =
-        llvm::dyn_cast<ShapedType>(op.getValues().getType());
-    auto outputIndicesType =
-        llvm::dyn_cast<ShapedType>(op.getIndices().getType());
-    if (!inputValuesType || !outputValuesType || !outputIndicesType) {
-      return rewriter.notifyMatchFailure(
-          op, "Input and output must be of ShapedType");
-    }
-
-    Type valueElementType = outputValuesType.getElementType();
-    Type indicesElementType = outputIndicesType.getElementType();
-    // Only handle integer types for indices. Index type is not supported.
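// [Illustrative aside, editorial; not part of the original file.] The stage
// loop above performs log2(fftLength) butterfly passes, feeding each pass the
// stage index plus the mh = m/2 twiddle factors exp(-2*pi*i*k/m) built by
// getCoeffConstants. The stage/twiddle bookkeeping by itself:
#include <cmath>
#include <complex>
#include <cstdio>

int main() {
  const int fftLength = 16;
  int logn = 0;
  while ((1 << logn) < fftLength) ++logn;
  for (int s = 1; s <= logn; ++s) {
    const int m = 1 << s, mh = m >> 1;  // butterfly span and twiddle count
    std::complex<double> w =
        std::exp(std::complex<double>(0.0, -2.0 * M_PI / m));  // k = 1 twiddle
    printf("stage %d: m=%d, %d twiddles, w1=(%f, %f)\n", s, m, mh, w.real(),
           w.imag());
  }
  return 0;
}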
- if (!llvm::isa(indicesElementType)) { - return rewriter.notifyMatchFailure( - op, "Output indices must be of integer type."); - } - - // Create and initialize output tensors for LinalgExt TopK results - // Define the output types based on the results of CHLO TopK - SmallVector mixedSizes = - tensor::createDimValues(rewriter, loc, adaptor.getOperand()); - mixedSizes.back() = rewriter.getIndexAttr(adaptor.getK()); - Value emptyTensorOutputValues = rewriter.create( - loc, mixedSizes, valueElementType); - Value emptyTensorOutputIndices = rewriter.create( - loc, mixedSizes, indicesElementType); - // Initialize indices to 0 and values to negative infinity - TypedAttr negInfAttr; - if (auto intType = llvm::dyn_cast(valueElementType)) { - negInfAttr = rewriter.getIntegerAttr( - intType, APInt::getSignedMinValue(intType.getWidth())); - } else { - auto negApFloat = APFloat::getInf( - llvm::cast(valueElementType).getFloatSemantics(), - /*Negative=*/true); - negInfAttr = rewriter.getFloatAttr(valueElementType, negApFloat); - } - Value negInf = rewriter.create(loc, negInfAttr); - TypedAttr posInfAttr = rewriter.getIntegerAttr( - indicesElementType, APInt::getSignedMaxValue(32)); - Value posInf = rewriter.create(loc, posInfAttr); - Value negInfTensor = - rewriter.create(loc, negInf, emptyTensorOutputValues) - .result(); - Value posInfTensor = - rewriter.create(loc, posInf, emptyTensorOutputIndices) - .result(); - - // Replace the CHLO TopK with LinalgExt TopK - uint64_t kDim = inputValuesType.getRank() - 1; - auto topkOp = rewriter.replaceOpWithNewOp( - op, op->getResultTypes(), ValueRange{operand}, - ValueRange{negInfTensor, posInfTensor}, kDim); - - // Define the region of TopK with a GT comparison - SmallVector types(2, valueElementType); - SmallVector locations(2, loc); - Block *block = - rewriter.createBlock(&topkOp.getRegion(), {}, types, locations); - { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(block); - Value lhs = block->getArgument(0); - Value rhs = block->getArgument(1); - Value condition; - if (llvm::isa(valueElementType)) { - condition = rewriter.create( - loc, arith::CmpIPredicate::sge, lhs, rhs); - } else { - condition = rewriter.create( - loc, arith::CmpFPredicate::OGT, lhs, rhs); - } - rewriter.create(loc, condition); - } - - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// Pass -//===----------------------------------------------------------------------===// - -struct ConvertMHLOToLinalgExtPass - : public ConvertMHLOToLinalgExtBase { - void getDependentDialects(DialectRegistry ®istry) const override { - registry - .insert(); - } - - void runOnOperation() override { - RewritePatternSet patterns(&getContext()); - MLIRContext *context = &getContext(); - - MhloToStdTypeConverter typeConverter; - patterns.insert(typeConverter, - context); - // FIXME: It shouldn't be necessary to list every matching MHLO op here, - // especially since they're already listed in - // populateHLOToLinalgConversionPattern and in HloOpToStdScalarOp. These - // lists are all the same. Can we leverage SFINAE here? 
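// [Illustrative aside, editorial; not part of the original file.] The TopkOp
// conversion above seeds the value output with -inf (so any real value wins)
// and builds a sge/OGT comparator region. The same selection rule in plain
// C++:
#include <algorithm>
#include <cstddef>
#include <functional>
#include <vector>

std::vector<float> topK(std::vector<float> v, size_t k) {
  k = std::min(k, v.size());
  // std::greater mirrors the CmpFPredicate::OGT comparison in the region.
  std::partial_sort(v.begin(), v.begin() + k, v.end(), std::greater<float>());
  v.resize(k);
  return v;
}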
- patterns - .insert, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionHLOOpConversion, - LinalgExtRegionReturnOpConversion>(typeConverter, context); - - ConversionTarget target(getContext()); - target.addLegalDialect(); - // TODO: Scatter is not marked as illegal to allow falling back to the - // generic LinAlg lowering, the generic lowering is not always performant - // and even though only used in fallback here, may hide performance - // issues and we'd rather know when the optimized lowering fails. - target.addIllegalOp(); - // FFT conversion creates complex ops which will be converted by the normal - // MHLO lowering, but these should still be converted if present inside - // other linalg_ext op regions. - target.addDynamicallyLegalOp( - [](mhlo::ComplexOp complexOp) { - return !isInBodyOfLinalgExtOps(complexOp); - }); - - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) { - signalPassFailure(); - } - } -}; -} // namespace - -std::unique_ptr> -createConvertMHLOToLinalgExtPass() { - return std::make_unique(); -} - -} // namespace MHLO -} // namespace iree_compiler -} // namespace mlir diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/ConvertMHLOToStableHLO.cpp b/compiler/src/iree/compiler/InputConversion/MHLO/ConvertMHLOToStableHLO.cpp deleted file mode 100644 index c64b983ae302..000000000000 --- a/compiler/src/iree/compiler/InputConversion/MHLO/ConvertMHLOToStableHLO.cpp +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2023 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/InputConversion/MHLO/PassDetail.h"
-#include "iree/compiler/InputConversion/MHLO/Passes.h"
-#include "mhlo/transforms/passes.h"
-#include "mlir/IR/BuiltinDialect.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/DialectRegistry.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassManager.h"
-#include "stablehlo/dialect/StablehloOps.h"
-
-namespace mlir::iree_compiler::MHLO {
-namespace {
-struct ConvertMHLOToStableHLOPass final
-    : ConvertMHLOToStableHLOPassBase<ConvertMHLOToStableHLOPass> {
-  void runOnOperation() override {
-    OpPassManager pm(ModuleOp::getOperationName(),
-                     OpPassManager::Nesting::Explicit);
-    pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass());
-
-    if (failed(runPipeline(pm, getOperation()))) {
-      signalPassFailure();
-    }
-  }
-
-  void getDependentDialects(DialectRegistry& registry) const override {
-    registry.insert<stablehlo::StablehloDialect>();
-  }
-};
-} // namespace
-
-std::unique_ptr<OperationPass<ModuleOp>> createConvertMHLOToStableHLOPass() {
-  return std::make_unique<ConvertMHLOToStableHLOPass>();
-}
-} // namespace mlir::iree_compiler::MHLO
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/FlattenTuplesInCFG.cpp b/compiler/src/iree/compiler/InputConversion/MHLO/FlattenTuplesInCFG.cpp
deleted file mode 100644
index df2bae368b6a..000000000000
--- a/compiler/src/iree/compiler/InputConversion/MHLO/FlattenTuplesInCFG.cpp
+++ /dev/null
@@ -1,349 +0,0 @@
-// Copyright 2019 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/InputConversion/MHLO/PassDetail.h"
-#include "iree/compiler/InputConversion/MHLO/Passes.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/iterator_range.h"
-#include "mhlo/IR/hlo_ops.h"
-#include "mlir/Dialect/Affine/Utils.h"
-#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/IRMapping.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace MHLO {
-
-namespace {
-
-// Given a set of types, unpack to a list of types, removing all tuples.
-void untupleTypes(TypeRange types, llvm::SmallVectorImpl<Type> &newTypes) {
-  for (Type type : types) {
-    if (llvm::isa<TupleType>(type)) {
-      untupleTypes(llvm::dyn_cast<TupleType>(type).getTypes(), newTypes);
-    } else {
-      newTypes.push_back(type);
-    }
-  }
-}
-
-Value processTuple(Type type, Location loc, Block *block, OpBuilder &builder) {
-  if (!llvm::isa<TupleType>(type)) {
-    return block->addArgument(type, loc);
-  }
-
-  auto tupleType = llvm::dyn_cast<TupleType>(type);
-  llvm::SmallVector<Value> values;
-  values.reserve(tupleType.size());
-  for (auto subtype : tupleType.getTypes()) {
-    values.push_back(processTuple(subtype, loc, block, builder));
-  }
-
-  return builder.create<mhlo::TupleOp>(loc, tupleType, values);
-}
-
-void copyOperationAttrs(Operation *oldOp, Operation *newOp) {
-  for (const auto &oldAttr : oldOp->getAttrs()) {
-    // Don't copy segment attributes as these correspond to the number of
-    // operands, which may be different.
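// [Illustrative aside, editorial; not part of the original file.]
// untupleTypes/processTuple above are a standard recursion over a tree of
// tuples; the flattening half, sketched over a toy tree type:
#include <vector>

struct Node {                 // a leaf when children is empty
  int value = 0;
  std::vector<Node> children;
};

static void flatten(const Node &n, std::vector<int> &out) {
  if (n.children.empty()) {
    out.push_back(n.value);   // leaf: emit, like a non-tuple Type
    return;
  }
  for (const Node &c : n.children) flatten(c, out);  // tuple: recurse
}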
- if (oldAttr.getName() == "operand_segment_sizes" || - oldAttr.getName() == "result_segment_sizes") - continue; - - newOp->setAttr(oldAttr.getName(), oldAttr.getValue()); - } -} - -bool recursiveUntuple(Value value, Location loc, OpBuilder &builder, - IRMapping *mapping, - llvm::SmallVectorImpl *newValues) { - Type type = value.getType(); - // We can return the value as is. - if (!llvm::isa(type)) { - newValues->push_back(value); - return false; - } - - TupleType tupleType = llvm::dyn_cast(type); - for (int i = 0; i < tupleType.size(); i++) { - auto subType = tupleType.getType(i); - - auto elementOp = builder.create( - loc, subType, value, builder.getI32IntegerAttr(i)); - recursiveUntuple(elementOp.getResult(), loc, builder, mapping, newValues); - } - - return false; -} - -Value recursiveRetuple(Type oldType, Operation::result_range *values, - OpBuilder &builder, Location loc) { - if (!llvm::isa(oldType)) { - Value returnValue = *values->begin(); - *values = {values->begin() + 1, values->end()}; - return returnValue; - } - - TupleType tupleType = llvm::dyn_cast(oldType); - llvm::SmallVector subValues; - for (auto subtype : tupleType.getTypes()) { - subValues.push_back(recursiveRetuple(subtype, values, builder, loc)); - } - - return builder.create(loc, tupleType, subValues).getResult(); -} - -template -bool untupleAndLookupValues(T values, llvm::SmallVectorImpl *newValues, - OpBuilder &builder, Location loc, - IRMapping *mapping) { - for (auto operand : values) { - auto newValue = mapping->lookupOrNull(operand); - if (!newValue) { - return true; - } - - recursiveUntuple(newValue, loc, builder, mapping, newValues); - } - - return false; -} - -bool convertReturnOp(mlir::func::ReturnOp *op, OpBuilder &builder, - IRMapping *mapping) { - llvm::SmallVector newOperands; - if (untupleAndLookupValues(op->getOperands(), &newOperands, builder, - op->getLoc(), mapping)) { - return true; - } - - builder.create(op->getLoc(), newOperands); - return false; -} - -bool convertCallOp(func::CallOp *oldOp, OpBuilder &builder, - IRMapping *mapping) { - llvm::SmallVector newArgs; - if (untupleAndLookupValues(oldOp->getOperands(), &newArgs, builder, - oldOp->getLoc(), mapping)) { - return true; - } - - SmallVector resultTypes; - untupleTypes(oldOp->getOperation()->getResultTypes(), resultTypes); - auto newOp = builder.create(oldOp->getLoc(), oldOp->getCallee(), - resultTypes, newArgs); - copyOperationAttrs(oldOp->getOperation(), newOp.getOperation()); - - auto newResults = newOp.getResults(); - for (auto oldResult : oldOp->getResults()) { - llvm::SmallVector subValues; - auto newResult = recursiveRetuple(oldResult.getType(), &newResults, builder, - oldOp->getLoc()); - mapping->map(oldResult, newResult); - } - - return false; -} - -bool convertIndirectCallOp(func::CallIndirectOp *oldOp, OpBuilder &builder, - IRMapping *mapping) { - llvm::SmallVector newArgs; - if (untupleAndLookupValues(oldOp->getOperands(), &newArgs, builder, - oldOp->getLoc(), mapping)) { - return true; - } - - auto newOp = builder.create( - oldOp->getLoc(), oldOp->getCallee(), newArgs); - copyOperationAttrs(oldOp->getOperation(), newOp.getOperation()); - - for (int i = 0; i < newOp.getNumResults(); ++i) { - auto oldResult = oldOp->getResult(i); - auto newResult = newOp.getResult(i); - mapping->map(oldResult, newResult); - } - - return false; -} - -bool convertBranchOp(cf::BranchOp *oldOp, OpBuilder &builder, - IRMapping *mapping) { - llvm::SmallVector newArgs; - if (untupleAndLookupValues(oldOp->getOperands(), &newArgs, builder, - 
oldOp->getLoc(), mapping)) { - return true; - } - - auto newOp = builder.create( - oldOp->getLoc(), mapping->lookupOrNull(oldOp->getDest()), newArgs); - - copyOperationAttrs(oldOp->getOperation(), newOp.getOperation()); - - return false; -} - -bool convertCondBranchOp(cf::CondBranchOp *oldOp, OpBuilder &builder, - IRMapping *mapping) { - llvm::SmallVector trueArgs; - if (untupleAndLookupValues(oldOp->getTrueOperands(), &trueArgs, builder, - oldOp->getLoc(), mapping)) { - return true; - } - - llvm::SmallVector falseArgs; - if (untupleAndLookupValues(oldOp->getFalseOperands(), &falseArgs, builder, - oldOp->getLoc(), mapping)) { - return true; - } - - auto newOp = builder.create( - oldOp->getLoc(), mapping->lookupOrNull(oldOp->getCondition()), - mapping->lookupOrNull(oldOp->getTrueDest()), trueArgs, - mapping->lookupOrNull(oldOp->getFalseDest()), falseArgs); - - copyOperationAttrs(oldOp->getOperation(), newOp.getOperation()); - - return false; -} - -bool convertOperation(Operation *op, OpBuilder &builder, IRMapping *mapping) { - if (auto returnOp = dyn_cast(op)) { - return convertReturnOp(&returnOp, builder, mapping); - } else if (auto callOp = dyn_cast(op)) { - return convertCallOp(&callOp, builder, mapping); - } else if (auto callIndirectOp = dyn_cast(op)) { - return convertIndirectCallOp(&callIndirectOp, builder, mapping); - } else if (auto branchOp = dyn_cast(op)) { - return convertBranchOp(&branchOp, builder, mapping); - } else if (auto condBranchOp = dyn_cast(op)) { - return convertCondBranchOp(&condBranchOp, builder, mapping); - } - - builder.clone(*op, *mapping); - return false; -} - -bool convertFunction(func::FuncOp oldFunction, func::FuncOp newFunction) { - OpBuilder builder(newFunction.getBody()); - IRMapping mapping; - - // Check whether has tuple in signature. - bool hasTupleSig = (oldFunction.getArgumentTypes().size() != - newFunction.getArgumentTypes().size()) || - (oldFunction.getResultTypes().size() != - newFunction.getResultTypes().size()); - - // Cache unused XLA ABI marker names. - auto xlaAbiParam = StringAttr::get(newFunction.getContext(), - "xla_entry_computation_parameter_layouts"), - xlaAbiLayout = StringAttr::get(newFunction.getContext(), - "xla_entry_computation_result_layout"); - - for (auto attr : oldFunction->getAttrs()) { - if (attr.getName() == oldFunction.getFunctionTypeAttrName() || - // Currently skipping all arg, result and XLA specific ABI attributes. - attr.getName() == xlaAbiParam || attr.getName() == xlaAbiLayout) - continue; - // If it has tuples in sig, then skip arg and res attrs. None of the - // existing ones along path that produces tuples are used further, so just - // remove instead of flattening. - if (hasTupleSig && (attr.getName() == oldFunction.getArgAttrsAttrName() || - attr.getName() == oldFunction.getResAttrsAttrName())) - continue; - newFunction->setAttr(attr.getName(), attr.getValue()); - } - - newFunction.getBlocks().clear(); - for (auto &oldBlock : oldFunction.getBlocks()) { - auto *newBlock = builder.createBlock(&newFunction.getBody()); - for (auto oldArg : oldBlock.getArguments()) { - llvm::SmallVector newTypes; - untupleTypes(oldArg.getType(), newTypes); - - Value newTuple = processTuple(oldArg.getType(), oldFunction.getLoc(), - newBlock, builder); - if (!newTuple) { - return true; - } - - mapping.map(oldArg, newTuple); - } - mapping.map(&oldBlock, newBlock); - } - - // Convert all ops in the blocks. 
-  for (auto &oldBlock : oldFunction.getBlocks()) {
-    builder.setInsertionPointToEnd(mapping.lookupOrNull(&oldBlock));
-    for (auto &oldOp : oldBlock.getOperations()) {
-      if (convertOperation(&oldOp, builder, &mapping)) {
-        return true;
-      }
-    }
-  }
-
-  return false;
-}
-
-class FlattenTuplesInCFGPass
-    : public FlattenTuplesInCFGBase<FlattenTuplesInCFGPass> {
- public:
-  void runOnOperation() override {
-    auto module = getOperation();
-    Builder builder(module.getContext());
-
-    // Build a list of (oldFunction, newFunction) for all functions we need to
-    // replace. This will ensure that when we go to convert function bodies we
-    // have only new functions defined.
-    std::vector<std::pair<func::FuncOp, func::FuncOp>> convertedFunctions;
-
-    for (auto oldFunction : module.getOps<func::FuncOp>()) {
-      auto oldFunctionType = oldFunction.getFunctionType();
-
-      llvm::SmallVector<Type> newInputTypes;
-      untupleTypes(oldFunctionType.getInputs(), newInputTypes);
-
-      llvm::SmallVector<Type> newResultTypes;
-      untupleTypes(oldFunctionType.getResults(), newResultTypes);
-
-      auto newFunctionType =
-          builder.getFunctionType(newInputTypes, newResultTypes);
-      auto newFunction =
-          func::FuncOp::create(oldFunction.getLoc(), oldFunction.getName(),
-                               newFunctionType, oldFunction->getDialectAttrs());
-      convertedFunctions.push_back({oldFunction, newFunction});
-
-      // Perform the actual body conversion now that we have proper signatures.
-      if (convertFunction(oldFunction, newFunction)) {
-        return signalPassFailure();
-      }
-    }
-
-    // Replace functions in the module.
-    for (auto &pair : convertedFunctions) {
-      pair.first.erase();
-      module.push_back(pair.second);
-    }
-  }
-};
-
-} // namespace
-
-std::unique_ptr<OperationPass<ModuleOp>> createFlattenTuplesInCFGPass() {
-  return std::make_unique<FlattenTuplesInCFGPass>();
-}
-
-static PassRegistration<FlattenTuplesInCFGPass> pass;
-
-} // namespace MHLO
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp b/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp
deleted file mode 100644
index 75d0de53ab80..000000000000
--- a/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToLinalgOnTensors.cpp
+++ /dev/null
@@ -1,616 +0,0 @@
-// Copyright 2020 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-//===- XLAToLinalgOnTensors.cpp - Pass to convert XLA to Linalg on tensors-===//
-//
-// Pass to convert from XLA to linalg on tensors. Uses the patterns from
-// tensorflow/compiler/mlir/xla/transforms/legalize_to_linalg.cc along with
-// some IREE specific patterns.
-// -//===----------------------------------------------------------------------===// -#include - -#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h" -#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" -#include "iree/compiler/Dialect/Util/IR/UtilDialect.h" -#include "iree/compiler/Dialect/Util/IR/UtilOps.h" -#include "iree/compiler/InputConversion/MHLO/ConvertMHLOToFlow.h" -#include "iree/compiler/InputConversion/MHLO/PassDetail.h" -#include "iree/compiler/InputConversion/MHLO/Passes.h" -#include "iree/compiler/InputConversion/MHLO/Rewriters.h" -#include "iree/compiler/Utils/ConversionUtils.h" -#include "mhlo/IR/hlo_ops.h" -#include "mhlo/transforms/rewriters.h" -#include "mhlo/utils/legalize_to_linalg_utils.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Complex/IR/Complex.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/MLProgram/IR/MLProgram.h" -#include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Transforms/Passes.h" -#include "stablehlo/dialect/ChloOps.h" - -namespace mlir { -namespace iree_compiler { -namespace MHLO { - -//===----------------------------------------------------------------------===// -// mhlo.concatenate conversion patterns. -//===----------------------------------------------------------------------===// - -namespace { -/// Converts mhlo.concatenate operation to extract_slice ops + insert_slice ops. 
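// [Illustrative aside, editorial; not part of the original file.] The pattern
// below materializes concat as a running tensor.insert_slice: each operand is
// written at an accumulated offset along the concatenation dim. The offset
// bookkeeping in scalar form:
#include <vector>

// For operand sizes {3, 4, 5} along the concat dim, returns {0, 3, 7}.
std::vector<int> concatOffsets(const std::vector<int> &sizesAlongDim) {
  std::vector<int> offsets;
  int acc = 0;  // plays the role of accBound in the pattern below
  for (int s : sizesAlongDim) {
    offsets.push_back(acc);
    acc += s;
  }
  return offsets;
}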
-struct ConcatenateOpConversion
-    : public OpConversionPattern<mhlo::ConcatenateOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult matchAndRewrite(
-      mhlo::ConcatenateOp op, OpAdaptor adaptor,
-      ConversionPatternRewriter &rewriter) const override {
-    auto resultType = llvm::dyn_cast<RankedTensorType>(
-        this->typeConverter->convertType(op.getResult().getType()));
-    if (!resultType || !resultType.hasStaticShape()) {
-      return rewriter.notifyMatchFailure(op,
-                                         "expected static shape for output");
-    }
-
-    Location loc = op.getLoc();
-    int dim = op.getDimension();
-    int rank = resultType.getRank();
-    SmallVector<Value> offsets, sizes, strides;
-    for (int i = 0; i < rank; ++i) {
-      offsets.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
-      sizes.push_back(rewriter.createOrFold<tensor::DimOp>(
-          loc, adaptor.getOperands()[0], i));
-      strides.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 1));
-    }
-    Value resultDimSize = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    for (auto arg : adaptor.getOperands()) {
-      auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, dim);
-      resultDimSize =
-          rewriter.createOrFold<arith::AddIOp>(loc, resultDimSize, size);
-    }
-    sizes[dim] = resultDimSize;
-    Value result = rewriter.create<tensor::EmptyOp>(
-        loc, resultType.getShape(), resultType.getElementType());
-
-    auto toOpFoldResult = [](Value v) -> OpFoldResult {
-      auto op = v.getDefiningOp<arith::ConstantIndexOp>();
-      if (!op) return v;
-      return op.getValue();
-    };
-
-    Value accBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    for (auto arg : adaptor.getOperands()) {
-      offsets[dim] = accBound;
-      sizes[dim] = rewriter.createOrFold<tensor::DimOp>(loc, arg, dim);
-      result = rewriter.create<tensor::InsertSliceOp>(
-          loc, arg, result, llvm::map_to_vector(offsets, toOpFoldResult),
-          llvm::map_to_vector(sizes, toOpFoldResult),
-          llvm::map_to_vector(strides, toOpFoldResult));
-      accBound = rewriter.create<arith::AddIOp>(loc, accBound, sizes[dim]);
-    }
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
-
-//===----------------------------------------------------------------------===//
-// mhlo.fft conversion patterns.
-//===----------------------------------------------------------------------===//
-
-/// Creates coefficients based on the DFT definition, see
-/// https://en.wikipedia.org/wiki/Discrete_Fourier_transform
-Value getDFTMatmulCoeff(OpBuilder b, Location loc, RankedTensorType matrixType,
-                        bool isRealPart) {
-  // scale = 2 * pi / N
-  double scale = 2 * M_PI / matrixType.getDimSize(0);
-
-  SmallVector<Attribute> values;
-  assert(matrixType.getRank() == 2 && "expected 2D matrix");
-  for (auto i : llvm::seq<unsigned>(0, matrixType.getDimSize(0))) {
-    for (auto j : llvm::seq<unsigned>(0, matrixType.getDimSize(1))) {
-      double v = scale * i * j;
-      if (isRealPart) {
-        v = cos(v);
-      } else {
-        v = -sin(v);
-      }
-      values.push_back(b.getF32FloatAttr(v));
-    }
-  }
-  return b.create<arith::ConstantOp>(
-      loc, matrixType, DenseFPElementsAttr::get(matrixType, values));
-}
-
-Value createLinalgMatmulOnTensors(OpBuilder b, Location loc,
-                                  RankedTensorType resultType, Value lhs,
-                                  Value rhs) {
-  Value zero = b.create<arith::ConstantOp>(
-      loc, b.getZeroAttr(resultType.getElementType()));
-  Value emptyTensor = b.create<tensor::EmptyOp>(
-      loc, resultType.getShape(), resultType.getElementType(),
-      /*dyn_size=*/ValueRange{});
-  Value zeroTensor =
-      b.create<linalg::FillOp>(loc, zero, emptyTensor).getResult(0);
-
-  switch (llvm::cast<RankedTensorType>(lhs.getType()).getRank()) {
-    case 1:
-      return b
-          .create<linalg::VecmatOp>(loc, TypeRange{resultType},
-                                    ValueRange{lhs, rhs},
-                                    ValueRange{zeroTensor})
-          .getResult(0);
-    case 2:
-      return b
-          .create<linalg::MatmulOp>(loc, TypeRange{resultType},
-                                    ValueRange{lhs, rhs},
-                                    ValueRange{zeroTensor})
-          .getResult(0);
-    default:
-      assert(false && "unhandled matmul type");
-      return Value();
-  }
-}
-
-/// Converts mhlo.fft operation to Linalg ops.
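// [Illustrative aside, editorial; not part of the original file.]
// getDFTMatmulCoeff above materializes the DFT as two real matrices; entry
// (i, j) of the size-n transform, as a standalone function:
#include <cmath>

static float dftCoeff(int i, int j, int n, bool isRealPart) {
  const double v = 2.0 * M_PI / n * i * j;  // scale * i * j, as above
  return static_cast<float>(isRealPart ? std::cos(v) : -std::sin(v));
}
// The RFFT result then comes from two matmuls against these matrices, packed
// into a complex value by the pattern below.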
-struct FftOpConversion : public OpConversionPattern<mhlo::FftOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult matchAndRewrite(
-      mhlo::FftOp op, OpAdaptor adaptor,
-      ConversionPatternRewriter &rewriter) const override {
-    if (op.getFftType() != mhlo::FftType::RFFT) {
-      return rewriter.notifyMatchFailure(
-          op, "non-RFFT types are not supported yet");
-    }
-
-    auto inputType =
-        llvm::dyn_cast<RankedTensorType>(adaptor.getOperand().getType());
-    if (!inputType || !inputType.hasStaticShape() || inputType.getRank() > 2) {
-      return rewriter.notifyMatchFailure(op, "only static 1D or 2D dft ops");
-    }
-
-    int rank = inputType.getRank();
-    int n = inputType.getDimSize(rank - 1);
-    int fftLength =
-        op.getFftLength().getSplatValue<IntegerAttr>().getInt() / 2 + 1;
-
-    Location loc = op.getLoc();
-    auto matrixType =
-        RankedTensorType::get({n, fftLength}, inputType.getElementType());
-    auto resultType = RankedTensorType::get(
-        llvm::cast<ShapedType>(op.getType()).getShape(),
-        inputType.getElementType());
-
-    auto realMatrix =
-        getDFTMatmulCoeff(rewriter, loc, matrixType, /*isRealPart=*/true);
-    auto real = createLinalgMatmulOnTensors(rewriter, loc, resultType,
-                                            adaptor.getOperand(), realMatrix);
-
-    auto imagMatrix =
-        getDFTMatmulCoeff(rewriter, loc, matrixType, /*isRealPart=*/false);
-    auto imag = createLinalgMatmulOnTensors(rewriter, loc, resultType,
-                                            adaptor.getOperand(), imagMatrix);
-
-    // Pack the results back into an mhlo::ComplexOp.
-    rewriter.replaceOpWithNewOp<mhlo::ComplexOp>(op, op.getType(), real, imag);
-    return success();
-  }
-};
-
-struct OptimizationBarrierOpConversion
-    : public OpConversionPattern<mhlo::OptimizationBarrierOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult matchAndRewrite(
-      mhlo::OptimizationBarrierOp op, OpAdaptor adaptor,
-      ConversionPatternRewriter &rewriter) const override {
-    SmallVector<Value> outputs;
-    for (auto operand : adaptor.getOperands()) {
-      outputs.push_back(
-          rewriter
-              .create<IREE::Util::OptimizationBarrierOp>(op.getLoc(), operand)
-              .getResult(0));
-    }
-    rewriter.replaceOp(op, outputs);
-    return success();
-  }
-};
-
-// Returns true if all attributes in the given dictionary are valid for IREE
-// input dialects.
-static bool isValidFuncAttr(DictionaryAttr attrs) {
-  // TODO: switch to using a dialect-based exclusion list or some other way
-  // that is not a big string table.
-  for (auto attr : attrs) {
-    if (attr.getName() == "tf.aliasing_output") return false;
-  }
-  return true;
-}
-
-// Adds iree.abi.encoding attributes for arguments and results when they have
-// had their type changed during conversion.
-static void setFuncEncodings(func::FuncOp funcOp, FunctionType oldFuncType,
-                             FunctionType newFuncType) {
-  auto encodingName = StringAttr::get(funcOp.getContext(), "iree.abi.encoding");
-  for (auto [i, oldType, newType] :
-       llvm::enumerate(oldFuncType.getInputs(), newFuncType.getInputs())) {
-    if (oldType != newType)
-      funcOp.setArgAttr(i, encodingName, TypeAttr::get(oldType));
-  }
-  for (auto [i, oldType, newType] :
-       llvm::enumerate(oldFuncType.getResults(), newFuncType.getResults())) {
-    if (oldType != newType)
-      funcOp.setResultAttr(i, encodingName, TypeAttr::get(oldType));
-  }
-}
-
-// Rewrites function attributes coming from HLO-based frontends to the
-// versions IREE supports.
-static void rewriteFuncAttrs(func::FuncOp funcOp) { - auto *context = funcOp.getContext(); - auto indexType = IndexType::get(context); - auto abiOutputName = StringAttr::get(context, "iree.abi.output"); - auto aliasingOutputName = StringAttr::get(context, "tf.aliasing_output"); - auto rewriteAttrs = [&](DictionaryAttr &allAttrs) { - SmallVector newAttrs; - newAttrs.reserve(allAttrs.size()); - for (auto attr : allAttrs) { - if (attr.getName() == aliasingOutputName) { - newAttrs.push_back({ - abiOutputName, - IntegerAttr::get(indexType, - llvm::cast(attr.getValue()).getInt()), - }); - } else { - newAttrs.push_back(attr); - } - } - allAttrs = DictionaryAttr::get(context, newAttrs); - }; - SmallVector argAttrs; - funcOp.getAllArgAttrs(argAttrs); - llvm::for_each(argAttrs, rewriteAttrs); - funcOp.setAllArgAttrs(argAttrs); - SmallVector resultAttrs; - funcOp.getAllResultAttrs(resultAttrs); - llvm::for_each(resultAttrs, rewriteAttrs); - funcOp.setAllResultAttrs(resultAttrs); -} - -// We need to convert func ops in order to convert types. -class BuiltinFuncOpPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - func::FuncOp srcOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - FunctionType srcFuncType = srcOp.getFunctionType(); - TypeConverter::SignatureConversion signatureConversion( - srcOp.getNumArguments()); - - // Convert function arguments. - for (unsigned i = 0, e = srcFuncType.getNumInputs(); i < e; ++i) { - if (failed(getTypeConverter()->convertSignatureArg( - i, srcFuncType.getInput(i), signatureConversion))) { - return rewriter.notifyMatchFailure(srcOp, "argument failed to convert"); - } - } - - // Convert function results. - SmallVector convertedResultTypes; - if (failed(getTypeConverter()->convertTypes(srcFuncType.getResults(), - convertedResultTypes))) { - return rewriter.notifyMatchFailure(srcOp, "results failed to convert"); - } - - // Create new function with converted argument and result types. - auto oldFuncType = srcOp.getFunctionType(); - auto newFuncType = mlir::FunctionType::get( - srcOp.getContext(), signatureConversion.getConvertedTypes(), - convertedResultTypes); - - // Update the function in place. - rewriter.startRootUpdate(srcOp); - srcOp.setType(newFuncType); - rewriteFuncAttrs(srcOp); - setFuncEncodings(srcOp, oldFuncType, newFuncType); - - // Tell the rewriter to convert the region signature. 
- TypeConverter &typeConverter = *getTypeConverter(); - if (failed(rewriter.convertRegionTypes(&srcOp.getBody(), typeConverter, - &signatureConversion))) { - return failure(); - } - - rewriter.finalizeRootUpdate(srcOp); - return success(); - } -}; - -class GlobalOpPattern : public OpConversionPattern { - public: - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - ml_program::GlobalOp globalOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto oldType = globalOp.getType(); - auto newType = getTypeConverter()->convertType(oldType); - if (newType == oldType) return failure(); - if (!newType) { - return rewriter.notifyMatchFailure(globalOp, - "result type conversion failed"); - } - rewriter.updateRootInPlace(globalOp, [&]() { - globalOp.setType(newType); - if (auto oldValue = globalOp.getValueAttr()) { - globalOp.setValueAttr( - convertAttribute(globalOp.getLoc(), oldValue, *getTypeConverter())); - } - }); - return success(); - } -}; - -class GenericTypeConvert : public ConversionPattern { - public: - GenericTypeConvert(StringRef rootName, TypeConverter &converter, - MLIRContext *context, PatternBenefit benefit = 0) - : ConversionPattern(converter, rootName, benefit, context) {} - LogicalResult matchAndRewrite( - Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - llvm::SmallVector newAttr; - llvm::append_range(newAttr, op->getAttrs()); - llvm::SmallVector newResults; - if (failed(getTypeConverter()->convertTypes(op->getResultTypes(), - newResults))) { - return rewriter.notifyMatchFailure(op, "result type conversion failed"); - } - OperationState state(op->getLoc(), op->getName().getStringRef(), operands, - newResults, newAttr, op->getSuccessors()); - for (Region &r : op->getRegions()) { - Region *newRegion = state.addRegion(); - rewriter.inlineRegionBefore(r, *newRegion, newRegion->begin()); - TypeConverter::SignatureConversion result(newRegion->getNumArguments()); - if (failed(getTypeConverter()->convertSignatureArgs( - newRegion->getArgumentTypes(), result))) { - return rewriter.notifyMatchFailure(op, - "argument type conversion failed"); - } - rewriter.applySignatureConversion(newRegion, result); - } - Operation *newOp = rewriter.create(state); - rewriter.replaceOp(op, newOp->getResults()); - return success(); - } -}; - -std::optional scalarToTensor(OpBuilder &builder, Type /*type*/, - ValueRange inputs, Location loc) { - assert(inputs.size() == 1); - if (llvm::isa(inputs.front().getType())) { - return std::nullopt; - } - return builder - .create( - loc, RankedTensorType::get({}, inputs.front().getType()), - inputs.front()) - .getResult(); -} - -std::optional materializeCastFromIllegal(OpBuilder &builder, Type type, - ValueRange inputs, - Location loc) { - Type fromType = getElementTypeOrSelf(inputs[0].getType()); - Type toType = getElementTypeOrSelf(type); - if ((!fromType.isSignedInteger() && !fromType.isUnsignedInteger()) || - !toType.isSignlessInteger()) - return std::nullopt; - // Use bitcast to do signless->signful conversions. - return builder.create(loc, type, inputs[0])->getResult(0); -} - -std::optional materializeCastToIllegal(OpBuilder &builder, Type type, - ValueRange inputs, Location loc) { - Type fromType = getElementTypeOrSelf(inputs[0].getType()); - Type toType = getElementTypeOrSelf(type); - if (!fromType.isSignlessInteger() || - (!toType.isSignedInteger() && !toType.isUnsignedInteger())) - return std::nullopt; - // Use bitcast to do signless->signful conversions. 
- return builder.create(loc, type, inputs[0])->getResult(0); -} - -struct ConvertMHLOToLinalgOnTensorsPass - : public ConvertMHLOToLinalgOnTensorsBase< - ConvertMHLOToLinalgOnTensorsPass> { - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert< - IREE::Flow::FlowDialect, IREE::Util::UtilDialect, linalg::LinalgDialect, - mhlo::MhloDialect, shape::ShapeDialect, tensor::TensorDialect, - math::MathDialect, memref::MemRefDialect, complex::ComplexDialect>(); - } - - void runOnOperation() override { - RewritePatternSet patterns(&getContext()); - MLIRContext *context = &getContext(); - - auto typeConverter = mhlo::createHloToLinalgTypeConverter(); - typeConverter->addArgumentMaterialization(scalarToTensor); - typeConverter->addArgumentMaterialization(materializeCastFromIllegal); - typeConverter->addTargetMaterialization(materializeCastFromIllegal); - typeConverter->addSourceMaterialization(materializeCastToIllegal); - // NOTE: not using corresponding setupMHLOToFlowPatterns because the entire - // MHLO dialects are marked illegal by this pass. - // TODO: Collapse/rework all of these patterns once the consolidation - // lands. There is little reason to have these so spread out. - populateMHLOToFlowPatterns(context, patterns); - - chlo::populateDecomposeChloPatterns(context, &patterns); - populateMHLOBroadcastingToLinalgPatterns(context, *typeConverter, patterns); - mhlo::populateScalarHloToArithmeticConversionPatterns( - context, *typeConverter, &patterns, - [](Operation *op) { return mhlo::isInBodyOfLinalgOps(op); }); - populateMHLOToLinalgOnTensorsConversionPatterns(context, *typeConverter, - patterns); - populateMHLOComplexToRealPatterns(context, *typeConverter, patterns); - - populateMHLOCollectiveOpsConversionPatterns(context, *typeConverter, - patterns); - // TODO(*): expose patterns that do this much better from - // iree/compiler/Dialect/Util/Transforms/ConvertPrimitiveType.cpp - - // Structural patterns (functions, cfg, terminators). - patterns.insert(*typeConverter, context); - patterns.insert(func::ReturnOp::getOperationName(), - *typeConverter, context); - patterns.insert(func::CallOp::getOperationName(), - *typeConverter, context); - patterns.insert(cf::CondBranchOp::getOperationName(), - *typeConverter, context); - patterns.insert(cf::BranchOp::getOperationName(), - *typeConverter, context); - patterns.insert(*typeConverter, context); - patterns.insert( - ml_program::GlobalLoadOp::getOperationName(), *typeConverter, context); - patterns.insert( - ml_program::GlobalLoadConstOp::getOperationName(), *typeConverter, - context); - patterns.insert( - ml_program::GlobalStoreOp::getOperationName(), *typeConverter, context); - // This is needed when converting mhlo::ReplicaIDOp. - patterns.insert( - tensor::FromElementsOp::getOperationName(), *typeConverter, context); - patterns.insert( - arith::IndexCastUIOp::getOperationName(), *typeConverter, context); - ConversionTarget target(getContext()); - - auto isIllegalType = [&](Type t) { return !typeConverter->isLegal(t); }; - auto isLegallyTypedOp = [&](Operation *op) -> bool { - for (Type type : op->getResultTypes()) { - if (isIllegalType(type)) return false; - } - for (Type type : op->getOperandTypes()) { - if (isIllegalType(type)) return false; - } - return true; - }; - - target.addIllegalDialect(); - target.addIllegalDialect(); - - // Functions must have legal types. 
- target.addDynamicallyLegalOp([&](func::FuncOp funcOp) { - if (auto attrs = funcOp.getAllArgAttrs()) { - if (!llvm::all_of(attrs.getAsRange(), - isValidFuncAttr)) { - return false; - } - } - if (auto attrs = funcOp.getAllResultAttrs()) { - if (!llvm::all_of(attrs.getAsRange(), - isValidFuncAttr)) { - return false; - } - } - for (Type type : funcOp.getFunctionType().getInputs()) { - if (isIllegalType(type)) return false; - } - for (Type type : funcOp.getFunctionType().getResults()) { - if (isIllegalType(type)) return false; - } - for (Block &block : funcOp.getFunctionBody()) { - for (Type type : block.getArgumentTypes()) { - if (isIllegalType(type)) return false; - } - } - return true; - }); - target.addDynamicallyLegalOp([&](func::ReturnOp op) { - return llvm::all_of(op.getOperandTypes(), - [&](Type type) { return !isIllegalType(type); }); - }); - target.addDynamicallyLegalOp([&](func::CallOp op) { - return llvm::all_of(op.getOperandTypes(), - [&](Type type) { return !isIllegalType(type); }); - }); - target.addDynamicallyLegalOp([&](cf::CondBranchOp op) { - return llvm::all_of(op.getOperandTypes(), - [&](Type type) { return !isIllegalType(type); }); - }); - target.addDynamicallyLegalOp([&](cf::BranchOp op) { - return llvm::all_of(op.getOperandTypes(), - [&](Type type) { return !isIllegalType(type); }); - }); - target.addDynamicallyLegalOp( - [&](ml_program::GlobalOp op) { - return typeConverter->isLegal(op.getType()); - }); - - // Let the rest fall through. - target.addLegalDialect(); - target.addLegalDialect(); - target.addLegalOp(); - target.markUnknownOpDynamicallyLegal(isLegallyTypedOp); - - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) { - return signalPassFailure(); - } - - { - // Apply the patterns to remove unused operands and results. - RewritePatternSet removeUnusedOperandsResultsPatterns(&getContext()); - linalg::populateEraseUnusedOperandsAndResultsPatterns( - removeUnusedOperandsResultsPatterns); - if (failed(applyPatternsAndFoldGreedily( - getOperation(), - std::move(removeUnusedOperandsResultsPatterns)))) { - return signalPassFailure(); - } - } - } -}; - -} // namespace - -void populateMHLOToLinalgOnTensorsConversionPatterns( - MLIRContext *context, TypeConverter &typeConverter, - RewritePatternSet &patterns) { - mhlo::populateHloToLinalgConversionPattern(context, typeConverter, &patterns); - // TODO(#5809): Drop ConcatenateOp lowering in favor of the upstream version - // then remove the PatternBenefit here - patterns.insert(typeConverter, context, - PatternBenefit(1000)); -} - -std::unique_ptr> createMHLOToLinalgOnTensorsPass() { - return std::make_unique(); -} - -} // namespace MHLO -} // namespace iree_compiler -} // namespace mlir diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToMHLOPreprocessing.cpp b/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToMHLOPreprocessing.cpp deleted file mode 100644 index c627f2513963..000000000000 --- a/compiler/src/iree/compiler/InputConversion/MHLO/MHLOToMHLOPreprocessing.cpp +++ /dev/null @@ -1,1535 +0,0 @@ -// Copyright 2020 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include -#include - -#include "iree/compiler/InputConversion/MHLO/PassDetail.h" -#include "iree/compiler/InputConversion/MHLO/Passes.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/Support/Casting.h" -#include "mhlo/IR/hlo_ops.h" -#include "mhlo/transforms/rewriters.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/Shape/IR/Shape.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "stablehlo/dialect/ChloOps.h" - -namespace mlir { -namespace iree_compiler { -namespace MHLO { - -namespace { - -static bool isIota(ArrayRef array) { - for (auto it : llvm::enumerate(array)) { - if (it.index() != it.value()) { - return false; - } - } - return true; -} - -static DenseIntElementsAttr make1DElementsAttr(OpBuilder &b, - ArrayRef integers) { - auto type = RankedTensorType::get({static_cast(integers.size())}, - b.getIntegerType(64)); - return DenseIntElementsAttr::get(type, integers); -} - -static DenseIntElementsAttr make1DElementsAttr(OpBuilder &b, int64_t start, - int64_t num) { - return make1DElementsAttr( - b, llvm::to_vector<4>(llvm::seq(start, start + num))); -} - -static Value getF32Const(ImplicitLocOpBuilder b, ArrayRef shapes, - ArrayRef values) { - RankedTensorType ty = RankedTensorType::get(shapes, b.getF32Type()); - return b.create(DenseFPElementsAttr::get(ty, values)) - .getResult(); -} - -// Guarantee that the input dimensions are ordered batch, spatial_dims, feature -// dim. -class ReorderConvOpInputDimensions - : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(mhlo::ConvolutionOp op, - PatternRewriter &rewriter) const override { - auto lhsType = llvm::cast(op.getLhs().getType()); - auto lhsShape = lhsType.getShape(); - if (!lhsType.hasRank()) { - return failure(); - } - - auto dimensionNumbers = op.getDimensionNumbers(); - auto spatialDims = dimensionNumbers.getInputSpatialDimensions(); - - // Compute the permutation required to create a standard order. - llvm::SmallVector permutations; - permutations.push_back(dimensionNumbers.getInputBatchDimension()); - permutations.append(spatialDims.begin(), spatialDims.end()); - permutations.push_back(dimensionNumbers.getInputFeatureDimension()); - - // If the permutation is iota then no reordering is required. 
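// [Illustrative aside, editorial; not part of the original file.] The
// permutation assembled above is simply {batch, spatial..., feature};
// applying such a permutation to a shape (the transposeShape loop just below)
// amounts to:
#include <cstdint>
#include <vector>

static std::vector<int64_t> permuteShape(const std::vector<int64_t> &shape,
                                         const std::vector<int64_t> &perm) {
  std::vector<int64_t> out;
  out.reserve(perm.size());
  for (int64_t p : perm) out.push_back(shape[p]);
  return out;  // an iota perm returns the shape unchanged, hence the bail-out
}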
- if (isIota(permutations)) { - return failure(); - } - - llvm::SmallVector transposeShape; - for (auto p : permutations) { - transposeShape.push_back(lhsShape[p]); - } - - auto transposed = rewriter.create( - op.getLoc(), - RankedTensorType::get(transposeShape, lhsType.getElementType()), - op.getLhs(), rewriter.getI64TensorAttr(permutations)); - - llvm::SmallVector newSpatialDimensions(spatialDims.size()); - std::iota(newSpatialDimensions.begin(), newSpatialDimensions.end(), 1); - - auto newDimensionNumbers = mhlo::ConvDimensionNumbersAttr::get( - op.getContext(), - /*input_batch_dimension=*/0, - /*input_feature_dimension=*/newSpatialDimensions.size() + 1, - /*input_spatial_dimensions=*/newSpatialDimensions, - dimensionNumbers.getKernelInputFeatureDimension(), - dimensionNumbers.getKernelOutputFeatureDimension(), - dimensionNumbers.getKernelSpatialDimensions(), - dimensionNumbers.getOutputBatchDimension(), - dimensionNumbers.getOutputFeatureDimension(), - dimensionNumbers.getOutputSpatialDimensions()); - - SmallVector operands = {transposed, op.getRhs()}; - auto newConv = rewriter.create( - op.getLoc(), op.getType(), operands, op->getAttrs()); - newConv.setDimensionNumbersAttr(newDimensionNumbers); - rewriter.replaceOp(op, newConv.getResult()); - - return success(); - } -}; - -struct ReorderConvOpKernelDimensions - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(mhlo::ConvolutionOp op, - PatternRewriter &rewriter) const override { - auto kernel = op.getRhs(); - auto kernelType = llvm::cast(kernel.getType()); - if (!kernelType.hasRank()) return failure(); - auto kernelShape = kernelType.getShape(); - - auto dimensionNumbers = op.getDimensionNumbers(); - - auto spatialDims = dimensionNumbers.getKernelSpatialDimensions(); - - auto inputFeatureDimension = - dimensionNumbers.getKernelInputFeatureDimension(); - auto outputFeatureDimension = - dimensionNumbers.getKernelOutputFeatureDimension(); - - // Compute the permutation for the transpose. - llvm::SmallVector permutation(spatialDims.begin(), - spatialDims.end()); - permutation.push_back(inputFeatureDimension); - permutation.push_back(outputFeatureDimension); - - // If the permutation is iota, then no transpose is required. 
-    if (isIota(permutation)) return failure();
-
-    llvm::SmallVector<int64_t> transposeShape;
-    for (auto perm : permutation) {
-      transposeShape.push_back(kernelShape[perm]);
-    }
-
-    llvm::SmallVector<int64_t> newSpatialDimensions(spatialDims.size());
-    std::iota(newSpatialDimensions.begin(), newSpatialDimensions.end(), 0);
-
-    auto transposeKernel = rewriter.create<mhlo::TransposeOp>(
-        op.getLoc(),
-        RankedTensorType::get(transposeShape, kernelType.getElementType()),
-        kernel, rewriter.getI64TensorAttr(permutation));
-
-    auto newDimensionNumbers = mhlo::ConvDimensionNumbersAttr::get(
-        op.getContext(), dimensionNumbers.getInputBatchDimension(),
-        dimensionNumbers.getInputFeatureDimension(),
-        dimensionNumbers.getInputSpatialDimensions(),
-        /*kernel_input_feature_dimension=*/
-        newSpatialDimensions.size(),
-        /*kernel_output_feature_dimension=*/
-        newSpatialDimensions.size() + 1, newSpatialDimensions,
-        dimensionNumbers.getOutputBatchDimension(),
-        dimensionNumbers.getOutputFeatureDimension(),
-        dimensionNumbers.getOutputSpatialDimensions());
-
-    SmallVector<Value> operands = {op.getLhs(), transposeKernel};
-    mhlo::ConvolutionOp newConv = rewriter.create<mhlo::ConvolutionOp>(
-        op.getLoc(), op.getType(), operands, op->getAttrs());
-    newConv.setDimensionNumbersAttr(newDimensionNumbers);
-
-    rewriter.replaceOp(op, {newConv.getResult()});
-    return success();
-  }
-};
-
-// Guarantee that the output dimensions are ordered batch, spatial_dims,
-// feature dim.
-class ReorderConvOpOutputDimensions
-    : public OpRewritePattern<mhlo::ConvolutionOp> {
- public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(mhlo::ConvolutionOp op,
-                                PatternRewriter &rewriter) const override {
-    auto resultType = llvm::cast<ShapedType>(op.getType());
-    auto resultShape = resultType.getShape();
-    if (!resultType.hasRank()) {
-      return failure();
-    }
-
-    auto dimensionNumbers = op.getDimensionNumbers();
-    auto spatialDims = dimensionNumbers.getOutputSpatialDimensions();
-
-    // Compute the permutation to transpose to an ordered output.
-    llvm::SmallVector<int64_t> permutation;
-    permutation.push_back(dimensionNumbers.getOutputBatchDimension());
-    permutation.append(spatialDims.begin(), spatialDims.end());
-    permutation.push_back(dimensionNumbers.getOutputFeatureDimension());
-
-    // If the permutation is iota then no reordering is required.
-    if (isIota(permutation)) {
-      return failure();
-    }
-
-    // Compute what the new conv shape should be.
-    llvm::SmallVector<int64_t> convShape;
-    for (auto p : permutation) {
-      convShape.push_back(resultShape[p]);
-    }
-
-    // Compute the inverse permutation that maps the reordered output back to
-    // the original order.
-    llvm::SmallVector<int64_t> invertPermutation(permutation.size());
-    for (auto it : llvm::enumerate(permutation)) {
-      invertPermutation[it.value()] = it.index();
-    }
-
-    llvm::SmallVector<int64_t> newSpatialDimensions(spatialDims.size());
-    std::iota(newSpatialDimensions.begin(), newSpatialDimensions.end(), 1);
-
-    auto newDimensionNumbers = mhlo::ConvDimensionNumbersAttr::get(
-        op.getContext(), dimensionNumbers.getInputBatchDimension(),
-        dimensionNumbers.getInputFeatureDimension(),
-        dimensionNumbers.getInputSpatialDimensions(),
-        dimensionNumbers.getKernelInputFeatureDimension(),
-        dimensionNumbers.getKernelOutputFeatureDimension(),
-        dimensionNumbers.getKernelSpatialDimensions(),
-        /*output_batch_dimension=*/0,
-        /*output_feature_dimension=*/newSpatialDimensions.size() + 1,
-        /*output_spatial_dimensions=*/newSpatialDimensions);
-
-    SmallVector<Value> operands = {op.getLhs(), op.getRhs()};
-    auto newConv = rewriter.create<mhlo::ConvolutionOp>(
-        op.getLoc(),
-        RankedTensorType::get(convShape, resultType.getElementType()), operands,
-        op->getAttrs());
-    newConv.setDimensionNumbersAttr(newDimensionNumbers);
-
-    auto transposed = rewriter.create<mhlo::TransposeOp>(
-        op.getLoc(), resultType, newConv,
-        rewriter.getI64TensorAttr(invertPermutation));
-
-    rewriter.replaceOp(op, transposed.getResult());
-    return success();
-  }
-};
-
-bool isConsecutive(ArrayRef<int64_t> array) {
-  for (int i = 1; i < array.size(); ++i) {
-    if (array[i] - array[i - 1] != 1) return false;
-  }
-  return true;
-}
-
-// Rewrites mhlo.dot_general so lhs contraction dimensions are innermost and
-// rhs contraction dimensions are the dims right after the batch dimensions.
-// The pattern inserts transposes so the dot_general always has the form:
-//   lhs: {batch_dims, parallel_dims, contraction_dims}
-//   rhs: {batch_dims, contraction_dims, parallel_dims}
-// After that, batch_dims, contraction_dims, and parallel_dims are in
-// consecutive order and do not split the domain. The pattern then inserts
-// reshapes to collapse consecutive reduction and parallel dims so it always
-// generates a rank-3 dot_general op.
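// [Illustrative aside, editorial; not part of the original file.]
// ReshapeIfNonStandard in the class below collapses the {batch, parallel,
// contraction} runs into a rank-3 shape with std::accumulate; the arithmetic
// in isolation:
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// b0 and b1 are the run borders, like dimsBorder0/dimsBorder1 below.
static std::vector<int64_t> collapseTo3D(const std::vector<int64_t> &shape,
                                         size_t b0, size_t b1) {
  auto prod = [&](size_t lo, size_t hi) {
    return std::accumulate(shape.begin() + lo, shape.begin() + hi,
                           int64_t{1}, std::multiplies<int64_t>());
  };
  return {prod(0, b0), prod(b0, b1), prod(b1, shape.size())};
}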
-class TransposeReshapeGenericDotGeneral
-    : public OpRewritePattern<mhlo::DotGeneralOp> {
- public:
-  using OpRewritePattern::OpRewritePattern;
-
-  Value TransposeIfNonConsecutive(OpBuilder &b, Location loc, Value src,
-                                  ArrayRef<int64_t> targetOrder) const {
-    if (isConsecutive(targetOrder)) return src;
-    auto type = llvm::cast<RankedTensorType>(src.getType());
-    SmallVector<int64_t> transposeShape;
-    for (auto i : targetOrder) {
-      transposeShape.push_back(type.getDimSize(i));
-    }
-    return b.create<mhlo::TransposeOp>(
-        loc, RankedTensorType::get(transposeShape, type.getElementType()), src,
-        b.getI64TensorAttr(targetOrder));
-  }
-
-  Value ReshapeIfNonStandard(OpBuilder &b, Location loc, Value src,
-                             size_t dimsBorder0, size_t dimsBorder1) const {
-    auto type = llvm::cast<RankedTensorType>(src.getType());
-    auto shape = type.getShape();
-    if (dimsBorder0 <= 1 && dimsBorder1 - dimsBorder0 <= 1 &&
-        shape.size() - dimsBorder1 <= 1)
-      return src;
-    SmallVector<int64_t> result_shape = {
-        std::accumulate(shape.begin(), shape.begin() + dimsBorder0, 1,
-                        std::multiplies<int64_t>()),
-        std::accumulate(shape.begin() + dimsBorder0,
-                        shape.begin() + dimsBorder1, 1,
-                        std::multiplies<int64_t>()),
-        std::accumulate(shape.begin() + dimsBorder1, shape.end(), 1,
-                        std::multiplies<int64_t>())};
-    return b.create<mhlo::ReshapeOp>(
-        loc, RankedTensorType::get(result_shape, type.getElementType()), src);
-  }
-
-  LogicalResult matchAndRewrite(mhlo::DotGeneralOp op,
-                                PatternRewriter &rewriter) const override {
-    auto lhsShapeType = llvm::dyn_cast<ShapedType>(op.getLhs().getType());
-    auto rhsShapeType = llvm::dyn_cast<ShapedType>(op.getRhs().getType());
-    auto resultType =
-        llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
-    if (!lhsShapeType || !rhsShapeType || !resultType) return failure();
-
-    // TODO(jpienaar): This pattern is not safe for dynamic shapes and seems to
-    // be (now) redundant with a later pass that does handle them. To decouple
-    // fixing from verifying the redundancy, this is limited to static shapes
-    // for now and will be removed in a follow-up.
-    if (!lhsShapeType.hasStaticShape() || !rhsShapeType.hasStaticShape())
-      return failure();
-
-    SmallVector<int64_t> lhsTargetOrder, rhsTargetOrder;
-    mhlo::DotDimensionNumbersAttr dimNumbers = op.getDotDimensionNumbers();
-    auto lhsBatchingDims = dimNumbers.getLhsBatchingDimensions();
-    auto lhsContractingDims = dimNumbers.getLhsContractingDimensions();
-    auto rhsBatchingDims = dimNumbers.getRhsBatchingDimensions();
-    auto rhsContractingDims = dimNumbers.getRhsContractingDimensions();
-
-    // No contraction dims means this can be represented as a mul.
-    if (lhsContractingDims.size() == 0 || rhsContractingDims.size() == 0)
-      return rewriter.notifyMatchFailure(op, "can be represented as mhlo.mul");
-
-    // No batching dimensions means this can be represented as a dot.
-    if (lhsBatchingDims.size() == 0 || rhsBatchingDims.size() == 0)
-      return rewriter.notifyMatchFailure(op, "can be represented as mhlo.dot");
-
-    SmallVector<bool> isLhsParallel(lhsShapeType.getRank(), true);
-    for (auto i : lhsBatchingDims) {
-      lhsTargetOrder.push_back(i);
-      isLhsParallel[i] = false;
-    }
-    for (auto i : lhsContractingDims) {
-      isLhsParallel[i] = false;
-    }
-    for (int64_t i = 0, e = lhsShapeType.getRank(); i < e; ++i) {
-      if (isLhsParallel[i]) {
-        lhsTargetOrder.push_back(i);
-      }
-    }
-    for (auto i : lhsContractingDims) {
-      lhsTargetOrder.push_back(i);
-    }
-
-    SmallVector<bool> isRhsParallel(rhsShapeType.getRank(), true);
-
-    for (auto i : rhsBatchingDims) {
-      rhsTargetOrder.push_back(i);
-      isRhsParallel[i] = false;
-    }
-    for (auto i : rhsContractingDims) {
-      rhsTargetOrder.push_back(i);
-      isRhsParallel[i] = false;
-    }
-    for (int64_t i = 0, e = rhsShapeType.getRank(); i < e; ++i) {
-      if (isRhsParallel[i]) {
-        rhsTargetOrder.push_back(i);
-      }
-    }
-
-    Value lhs = TransposeIfNonConsecutive(rewriter, op.getLoc(), op.getLhs(),
-                                          lhsTargetOrder);
-    Value rhs = TransposeIfNonConsecutive(rewriter, op.getLoc(), op.getRhs(),
-                                          rhsTargetOrder);
-
-    // The dimensions of this will always be transposed into {batch_dims,
-    // parallel_dims, contraction_dims}, and the following logic is based on
-    // this assumption.
-    // TODO(#7443): If we consider transpose performance, the above assumption
-    // may not hold.
-    int64_t numLhsContractionDims = lhsContractingDims.size();
-    int64_t lhsContractionBase = lhsShapeType.getRank() - numLhsContractionDims;
-    int64_t rhsContractionBase = rhsBatchingDims.size();
-    int64_t numRhsContractionDims =
-        rhsContractionBase + rhsContractingDims.size();
-
-    lhs = ReshapeIfNonStandard(rewriter, op.getLoc(), lhs,
-                               lhsBatchingDims.size(), lhsContractionBase);
-    rhs = ReshapeIfNonStandard(rewriter, op.getLoc(), rhs,
-                               rhsBatchingDims.size(), numRhsContractionDims);
-
-    if (lhs == op.getLhs() && rhs == op.getRhs())
-      return rewriter.notifyMatchFailure(op, "already in canonical form");
-
-    auto dimensionNumbers = mhlo::DotDimensionNumbersAttr::get(
-        rewriter.getContext(), /*lhsBatchingDimensions=*/0,
-        /*rhsBatchingDimensions=*/0,
-        /*lhsContractingDimensions=*/
-        llvm::cast<ShapedType>(lhs.getType()).getRank() - 1,
-        /*rhsContractingDimensions=*/1);
-    auto lhsNewType = llvm::cast<RankedTensorType>(lhs.getType());
-    auto rhsNewType = llvm::cast<RankedTensorType>(rhs.getType());
-
-    // If the lhs or rhs shape was collapsed, we need to reshape the result.
-    bool needReshapeResult = lhsNewType.getRank() < lhsShapeType.getRank() ||
-                             rhsNewType.getRank() < rhsShapeType.getRank();
-    // The {batching, lhs parallel, rhs parallel} dimension order is a
-    // convention.
-    SmallVector<int64_t> newShape = {lhsNewType.getShape()[0],
-                                     lhsNewType.getShape()[1]};
-    if (rhsNewType.getRank() > 2) newShape.push_back(rhsNewType.getDimSize(2));
-
-    auto newResultType =
-        needReshapeResult
-            ? RankedTensorType::get(newShape, resultType.getElementType())
-            : op.getType();
-
-    auto newOp = rewriter.create<mhlo::DotGeneralOp>(
-        op.getLoc(), newResultType, lhs, rhs, dimensionNumbers,
-        op.getPrecisionConfigAttr());
-
-    // Copy over unknown attributes as we currently rely on them to let users
-    // tune lowering parameters.
-struct ScatterInt64Indices : public OpRewritePattern<mhlo::ScatterOp> {
-  using OpRewritePattern<mhlo::ScatterOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(mhlo::ScatterOp op,
-                                PatternRewriter &rewriter) const final {
-    auto indices = op.getScatterIndices();
-    auto indicesTy = indices.getType();
-    auto indicesETy = indicesTy.getElementType();
-    if (indicesETy.isInteger(32))
-      return rewriter.notifyMatchFailure(op, "already has i32 index type");
-
-    if (!indicesTy.hasStaticShape())
-      return rewriter.notifyMatchFailure(op, "cannot validate legal size");
-
-    uint64_t maxSize = std::numeric_limits<int32_t>::max();
-    if (indicesETy.getIntOrFloatBitWidth() > 32) {
-      for (int i = 0, s = indicesTy.getRank(); i < s; ++i) {
-        if (indicesTy.getDimSize(i) > maxSize) {
-          return rewriter.notifyMatchFailure(op, "index may exceed i32 max");
-        }
-      }
-    }
-
-    indices = rewriter.create<mhlo::ConvertOp>(
-        op.getLoc(), indicesTy.clone(rewriter.getI32Type()), indices);
-
-    auto newScatter = rewriter.create<mhlo::ScatterOp>(
-        op.getLoc(), op.getResultTypes(), op.getInputs(), indices,
-        op.getUpdates(), op.getScatterDimensionNumbers(),
-        op.getIndicesAreSorted(), op.getUniqueIndices());
-
-    Region &region = newScatter.getUpdateComputation();
-    rewriter.cloneRegionBefore(op.getUpdateComputation(), region, region.end());
-    rewriter.replaceOp(op, newScatter.getResults());
-
-    return success();
-  }
-};
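A minimal sketch of the effect (hand-written shapes, not from the deleted tests): once the pattern has proven every indexed extent fits in i32, the indices are narrowed ahead of the scatter and the dimension numbers are carried over unchanged:

    %idx32 = "mhlo.convert"(%idx) : (tensor<4x1xi64>) -> tensor<4x1xi32>
    // The rebuilt mhlo.scatter then consumes %idx32 in place of %idx.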
-// If the indices tensor has an implicit index vector dim we expand and make it
-// an explicit dim.
-struct ScatterImplicitIndex : public OpRewritePattern<mhlo::ScatterOp> {
-  using OpRewritePattern<mhlo::ScatterOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(mhlo::ScatterOp op,
-                                PatternRewriter &rewriter) const final {
-    auto dimNumbers = op.getScatterDimensionNumbers();
-    auto indexVectorDim = dimNumbers.getIndexVectorDim();
-    Value indices = op.getScatterIndices();
-    auto indicesTy = llvm::cast<ShapedType>(indices.getType());
-
-    // Check indices vector has an implicit dim.
-    if (indexVectorDim != indicesTy.getRank()) {
-      return rewriter.notifyMatchFailure(op, "no implicit index dim");
-    }
-
-    // Materialize the implicit indices dim.
-    SmallVector<ReassociationExprs, 4> reassociationMap;
-    reassociationMap.resize(indicesTy.getRank());
-    SmallVector<int64_t> newShape;
-    for (int i = 0, s = indicesTy.getRank(); i < s; i++) {
-      reassociationMap[i].push_back(rewriter.getAffineDimExpr(i));
-      newShape.push_back(indicesTy.getDimSize(i));
-    }
-    if (!reassociationMap.empty()) {
-      reassociationMap.back().push_back(
-          rewriter.getAffineDimExpr(indicesTy.getRank()));
-    }
-    newShape.push_back(1);
-    indicesTy = RankedTensorType::get(newShape, indicesTy.getElementType());
-    indices = rewriter.create<tensor::ExpandShapeOp>(op.getLoc(), indicesTy,
-                                                     indices, reassociationMap);
-
-    auto newScatter = rewriter.create<mhlo::ScatterOp>(
-        op.getLoc(), op.getResultTypes(), op.getInputs(), indices,
-        op.getUpdates(), dimNumbers, op.getIndicesAreSorted(),
-        op.getUniqueIndices());
-    Region &region = newScatter.getUpdateComputation();
-    rewriter.cloneRegionBefore(op.getUpdateComputation(), region, region.end());
-    rewriter.replaceOp(op, newScatter.getResults());
-    return success();
-  }
-};
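Sketch with assumed shapes (not from this patch): for indices of shape tensor<4xi32> with index_vector_dim = 1 — equal to the rank, hence implicit — the pattern materializes the trailing unit index dimension:

    %expanded = tensor.expand_shape %indices [[0, 1]]
        : tensor<4xi32> into tensor<4x1xi32>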
-struct ScatterImplicitBatch : public OpRewritePattern<mhlo::ScatterOp> {
-  using OpRewritePattern<mhlo::ScatterOp>::OpRewritePattern;
-
-  static Value addUnitBatchDim(Location loc, Value value,
-                               PatternRewriter &rewriter) {
-    ShapedType valueTy = llvm::cast<ShapedType>(value.getType());
-    if (!valueTy.hasRank()) return nullptr;
-
-    // Materialize the implicit indices dim.
-    SmallVector<ReassociationExprs, 4> reassociationMap(valueTy.getRank());
-    if (!reassociationMap.empty()) {
-      reassociationMap.front().push_back(rewriter.getAffineDimExpr(0));
-    }
-
-    SmallVector<int64_t> newShape = {1};
-    for (int i = 0, s = valueTy.getRank(); i < s; i++) {
-      reassociationMap[i].push_back(rewriter.getAffineDimExpr(i + 1));
-      newShape.push_back(valueTy.getDimSize(i));
-    }
-
-    valueTy = RankedTensorType::get(newShape, valueTy.getElementType());
-    return rewriter.create<tensor::ExpandShapeOp>(loc, valueTy, value,
-                                                  reassociationMap);
-  }
-
-  LogicalResult matchAndRewrite(mhlo::ScatterOp op,
-                                PatternRewriter &rewriter) const final {
-    auto dimNumbers = op.getScatterDimensionNumbers();
-    auto indexVectorDim = dimNumbers.getIndexVectorDim();
-    auto indices = llvm::cast<Value>(op.getScatterIndices());
-    auto indicesTy = llvm::dyn_cast<RankedTensorType>(indices.getType());
-
-    // Check whether indices has no batch dimension.
-    if (!indicesTy) return failure();
-    if (indicesTy.getRank() != 1 || indexVectorDim != 0) {
-      return rewriter.notifyMatchFailure(op,
-                                         "no implicit batch dimension to add.");
-    }
-
-    indices = addUnitBatchDim(op.getLoc(), indices, rewriter);
-    if (!indices) {
-      return rewriter.notifyMatchFailure(
-          op, "Unable to add implicit batch dim to indices.");
-    }
-
-    llvm::SmallVector<int64_t> newUpdateWindowDims;
-    for (auto dim : dimNumbers.getUpdateWindowDims()) {
-      // The batch dimension is inserted at the start, so window dimensions are
-      // shifted forward by one.
-      newUpdateWindowDims.push_back(dim + 1);
-    }
-
-    llvm::SmallVector<Value> updates;
-    for (Value update : op.getUpdates()) {
-      update = addUnitBatchDim(op.getLoc(), update, rewriter);
-      if (!update) {
-        return rewriter.notifyMatchFailure(
-            op, "Unable to add implicit batch dim to update.");
-      }
-      updates.push_back(update);
-    }
-
-    auto newDimNumbers = mhlo::ScatterDimensionNumbersAttr::get(
-        op.getContext(), newUpdateWindowDims,
-        dimNumbers.getInsertedWindowDims(),
-        dimNumbers.getScatterDimsToOperandDims(),
-        dimNumbers.getIndexVectorDim() + 1);
-
-    auto newScatter = rewriter.create<mhlo::ScatterOp>(
-        op.getLoc(), op.getResultTypes(), op.getInputs(), indices, updates,
-        newDimNumbers, op.getIndicesAreSorted(), op.getUniqueIndices());
-    Region &region = newScatter.getUpdateComputation();
-    rewriter.cloneRegionBefore(op.getUpdateComputation(), region, region.end());
-    rewriter.replaceOp(op, newScatter.getResults());
-    return success();
-  }
-};
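Sketch with assumed shapes: rank-1 indices with index_vector_dim = 0 gain a leading unit batch dimension, and each update gains a matching one (so update_window_dims = [0, 1] becomes [1, 2]):

    %indices2 = tensor.expand_shape %indices [[0, 1]]
        : tensor<1xi32> into tensor<1x1xi32>
    %updates2 = tensor.expand_shape %updates [[0, 1], [2]]
        : tensor<4x8xf32> into tensor<1x4x8xf32>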
-struct ScatterCollapseBatch : public OpRewritePattern<mhlo::ScatterOp> {
-  using OpRewritePattern<mhlo::ScatterOp>::OpRewritePattern;
-
-  static Value collapseBatchDims(Location loc, Value value, int64_t batchCount,
-                                 PatternRewriter &rewriter) {
-    auto valueTy = llvm::dyn_cast<ShapedType>(value.getType());
-    if (!valueTy) return nullptr;
-
-    SmallVector<ReassociationExprs, 4> reassociationMap(1);
-    reassociationMap.reserve(valueTy.getRank() - batchCount + 1);
-    int64_t batchSize = 1;
-    for (int i = 0, s = batchCount; i < s; i++) {
-      reassociationMap.front().push_back(rewriter.getAffineDimExpr(i));
-      bool isDynamic =
-          valueTy.isDynamicDim(i) || batchSize == ShapedType::kDynamic;
-      batchSize =
-          isDynamic ? ShapedType::kDynamic : valueTy.getDimSize(i) * batchSize;
-    }
-
-    SmallVector<int64_t> newShape = {batchSize};
-    for (int i = batchCount, s = valueTy.getRank(); i < s; i++) {
-      reassociationMap.push_back({rewriter.getAffineDimExpr(i)});
-      newShape.push_back(valueTy.getDimSize(i));
-    }
-
-    valueTy = RankedTensorType::get(newShape, valueTy.getElementType());
-    return rewriter.create<tensor::CollapseShapeOp>(loc, valueTy, value,
-                                                    reassociationMap);
-  }
-
-  LogicalResult matchAndRewrite(mhlo::ScatterOp op,
-                                PatternRewriter &rewriter) const final {
-    auto dimNumbers = op.getScatterDimensionNumbers();
-    auto indexVectorDim = dimNumbers.getIndexVectorDim();
-    auto indices = llvm::cast<Value>(op.getScatterIndices());
-    auto indicesTy = llvm::cast<ShapedType>(indices.getType());
-    auto updatedWindowDims = dimNumbers.getUpdateWindowDims();
-
-    if (!indicesTy.hasRank()) {
-      return rewriter.notifyMatchFailure(op, "indices has unknown rank");
-    }
-
-    // Check for an explicit index vector dimension.
-    if (indexVectorDim != indicesTy.getRank() - 1) {
-      return rewriter.notifyMatchFailure(op, "no explicit indices dimension");
-    }
-
-    // Check that there are multiple batch dimensions.
-    if (indicesTy.getRank() < 3) {
-      return rewriter.notifyMatchFailure(op, "no multiple batch dimensions");
-    }
-
-    const int64_t batchCount = indicesTy.getRank() - 1;
-    for (auto it : llvm::enumerate(updatedWindowDims)) {
-      if (it.index() != it.value() - batchCount) {
-        return rewriter.notifyMatchFailure(
-            op, "update windows should be at the end.");
-      }
-    }
-
-    indices = collapseBatchDims(op.getLoc(), indices, batchCount, rewriter);
-    if (!indices) {
-      return rewriter.notifyMatchFailure(op,
-                                         "cannot collapse indices batch dims");
-    }
-
-    llvm::SmallVector<Value> updates;
-    for (Value update : op.getUpdates()) {
-      update = collapseBatchDims(op.getLoc(), update, batchCount, rewriter);
-      if (!update) {
-        return rewriter.notifyMatchFailure(op,
-                                           "cannot collapse update batch dims");
-      }
-      updates.push_back(update);
-    }
-
-    llvm::SmallVector<int64_t> newUpdatedWindowDims;
-    for (auto dim : updatedWindowDims) {
-      newUpdatedWindowDims.push_back(dim - batchCount + 1);
-    }
-
-    auto newDimNumbers = mhlo::ScatterDimensionNumbersAttr::get(
-        op.getContext(), newUpdatedWindowDims,
-        dimNumbers.getInsertedWindowDims(),
-        dimNumbers.getScatterDimsToOperandDims(),
-        /*indexVectorDim=*/1);
-
-    auto newScatter = rewriter.create<mhlo::ScatterOp>(
-        op.getLoc(), op.getResultTypes(), op.getInputs(), indices, updates,
-        newDimNumbers, op.getIndicesAreSorted(), op.getUniqueIndices());
-    Region &region = newScatter.getUpdateComputation();
-    rewriter.cloneRegionBefore(op.getUpdateComputation(), region, region.end());
-    rewriter.replaceOp(op, newScatter.getResults());
-    return success();
-  }
-};
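Sketch with assumed shapes: two leading batch dimensions on the indices and updates are flattened into one so only a single batch dimension remains:

    %i = tensor.collapse_shape %indices [[0, 1], [2]]
        : tensor<2x3x1xi32> into tensor<6x1xi32>
    %u = tensor.collapse_shape %updates [[0, 1], [2]]
        : tensor<2x3x8xf32> into tensor<6x8xf32>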
-// Ensure the batch dimensions of both the indices and updates are the first
-// dimensions. If they are not, transpose them to the start.
-struct ScatterBatchFirst : public OpRewritePattern<mhlo::ScatterOp> {
-  using OpRewritePattern<mhlo::ScatterOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(mhlo::ScatterOp op,
-                                PatternRewriter &rewriter) const final {
-    ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
-    auto dimNumbers = op.getScatterDimensionNumbers();
-
-    // If the index vector dim is not implicitly or explicitly at the end,
-    // we need to transpose the batch dimensions to the start.
-    Value indices = op.getScatterIndices();
-    auto indicesTy = llvm::cast<ShapedType>(indices.getType());
-    auto indexVectorDim = dimNumbers.getIndexVectorDim();
-    if (indexVectorDim < indicesTy.getRank() - 1) {
-      llvm::SmallVector<int64_t> perm;
-      perm.reserve(indicesTy.getRank());
-      for (int i = 0, s = indicesTy.getRank(); i < s; ++i)
-        if (i != indexVectorDim) perm.push_back(i);
-
-      if (perm.size() < indicesTy.getRank()) perm.push_back(indexVectorDim);
-
-      llvm::SmallVector<int64_t> newShape;
-      for (int i = 0, s = perm.size(); i < s; ++i)
-        newShape.push_back(indicesTy.getDimSize(perm[i]));
-
-      indices = builder.create<mhlo::TransposeOp>(
-          indicesTy.clone(newShape), indices, builder.getI64TensorAttr(perm));
-      indicesTy = llvm::cast<ShapedType>(indices.getType());
-      indexVectorDim = indicesTy.getRank() - 1;
-    }
-
-    // Compute the permutation required to transpose the batch dimensions to
-    // the beginning.
-    auto updates = op.getUpdates();
-    auto updates0 = updates.front();
-    auto updates0Ty = llvm::cast<ShapedType>(updates0.getType());
-    auto updatedWindowDims = dimNumbers.getUpdateWindowDims();
-
-    // Determine which dimensions are batch dimensions.
-    llvm::SmallVector<bool> isBatch(updates0Ty.getRank(), true);
-    for (int i = 0, s = updatedWindowDims.size(); i < s; ++i)
-      isBatch[updatedWindowDims[i]] = false;
-
-    // Permute batch dimensions to the start of the update tensor.
-    llvm::SmallVector<int64_t> updatePerm;
-    updatePerm.reserve(updates0Ty.getRank());
-    for (int i = 0, s = isBatch.size(); i < s; ++i)
-      if (isBatch[i]) updatePerm.push_back(i);
-    updatePerm.append(updatedWindowDims.begin(), updatedWindowDims.end());
-
-    llvm::SmallVector<int64_t> newUpdatedWindowDims;
-    int64_t batchCount = updates0Ty.getRank() - updatedWindowDims.size();
-    for (int i = batchCount, s = updates0Ty.getRank(); i < s; i++)
-      newUpdatedWindowDims.push_back(i);
-
-    bool indicesChanged = indices != op.getScatterIndices();
-    bool updatesChanged =
-        llvm::any_of(llvm::enumerate(updatePerm),
-                     [](auto it) { return it.index() != it.value(); });
-    llvm::SmallVector<Value> newUpdates(updates.begin(), updates.end());
-    if (updatesChanged) {
-      for (Value &update : newUpdates) {
-        auto updateTy = llvm::cast<ShapedType>(update.getType());
-        llvm::SmallVector<int64_t> newShape;
-        newShape.reserve(updateTy.getRank());
-        for (int i = 0, s = updatePerm.size(); i < s; i++)
-          newShape.push_back(updateTy.getDimSize(updatePerm[i]));
-        update = builder.create<mhlo::TransposeOp>(
-            updateTy.clone(newShape), update,
-            builder.getI64TensorAttr(updatePerm));
-      }
-    }
-
-    if (!indicesChanged && !updatesChanged)
-      return rewriter.notifyMatchFailure(
-          op, "batch dimensions are already leading");
-
-    auto newDimNumbers = mhlo::ScatterDimensionNumbersAttr::get(
-        op.getContext(), newUpdatedWindowDims,
-        dimNumbers.getInsertedWindowDims(),
-        dimNumbers.getScatterDimsToOperandDims(),
-        /*indexVectorDim=*/indexVectorDim);
-
-    auto newScatter = rewriter.create<mhlo::ScatterOp>(
-        op.getLoc(), op.getResultTypes(), op.getInputs(), indices, newUpdates,
-        newDimNumbers, op.getIndicesAreSorted(), op.getUniqueIndices());
-    Region &region = newScatter.getUpdateComputation();
-    rewriter.cloneRegionBefore(op.getUpdateComputation(), region, region.end());
-    rewriter.replaceOp(op, newScatter.getResults());
-    return success();
-  }
-};
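Sketch with assumed shapes: an update tensor whose window dimension leads is transposed so the batch dimension comes first, with update_window_dims rewritten from [0] to [1]:

    %u = "mhlo.transpose"(%updates) {permutation = dense<[1, 0]> : tensor<2xi64>}
        : (tensor<8x2xf32>) -> tensor<2x8xf32>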
-// mhlo.scatter can materialize a unit dimension at both indexed dimensions or
-// at unit dimensions in the destination matrix. linalg_ext.scatter only
-// allows unit dimensions at indexed dimensions. This pattern inserts all
-// unit dimensions that are not index dimensions to be compatible with
-// linalg_ext.scatter.
-//
-// It converts an mhlo.scatter as below:
-//  %result = "mhlo.scatter"(...) ({
-//    indices_are_sorted = true,
-//    scatter_dimension_numbers = #mhlo.scatter<
-//      update_window_dims = [1],
-//      inserted_window_dims = [0, 2],
-//      scatter_dims_to_operand_dims = [0],
-//      index_vector_dim = 1>,
-//    unique_indices = true} :
-//    (tensor<5x4x1xi32>, tensor<1x1xi32>, tensor<1x4xi32>)
-//
-// To:
-//  %result = "mhlo.scatter"(...) ({
-//    indices_are_sorted = true,
-//    scatter_dimension_numbers = #mhlo.scatter<
-//      update_window_dims = [1, 2],
-//      inserted_window_dims = [0],
-//      scatter_dims_to_operand_dims = [0],
-//      index_vector_dim = 1>,
-//    unique_indices = true} :
-//    (tensor<5x4x1xi32>, tensor<1x1xi32>, tensor<1x4x1xi32>)
-//  return %0 : tensor<5x4x1xi32>
-struct ScatterMaterializeInsertedDim
-    : public OpRewritePattern<mhlo::ScatterOp> {
-  using OpRewritePattern<mhlo::ScatterOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(mhlo::ScatterOp op,
-                                PatternRewriter &rewriter) const final {
-    auto indices = op.getScatterIndices();
-    auto operand = op.getInputs().front();
-    auto indicesTy = llvm::cast<ShapedType>(indices.getType());
-    auto operandTy = llvm::cast<ShapedType>(operand.getType());
-
-    if (!operandTy.hasRank() || !indicesTy.hasRank()) {
-      return rewriter.notifyMatchFailure(op, "operand/indices have no rank");
-    }
-
-    auto dimNumbers = op.getScatterDimensionNumbers();
-    auto updateDims = dimNumbers.getUpdateWindowDims();
-
-    if (indicesTy.getRank() != 2 || dimNumbers.getIndexVectorDim() != 1) {
-      return rewriter.notifyMatchFailure(
-          op, "indices is not of shape [batch, indices]");
-    }
-
-    if (!updateDims.empty() && updateDims.front() == 0) {
-      return rewriter.notifyMatchFailure(
-          op, "updates is not of shape [batch, ...]");
-    }
-
-    auto scatterDimsToOperandDims = dimNumbers.getScatterDimsToOperandDims();
-    llvm::SmallVector<bool> isIndexDim(operandTy.getRank(), false);
-    for (auto val : scatterDimsToOperandDims) {
-      isIndexDim[val] = true;
-    }
-
-    int64_t firstNonIndex = 0;
-    for (int64_t s = scatterDimsToOperandDims.size(); firstNonIndex < s;
-         ++firstNonIndex) {
-      if (!isIndexDim[firstNonIndex]) break;
-    }
-
-    llvm::SmallVector<bool> isInsertDims(operandTy.getRank(), false);
-    for (auto val : dimNumbers.getInsertedWindowDims()) {
-      isInsertDims[val] = true;
-    }
-
-    int64_t frontInsertedDims = 0;
-    for (; frontInsertedDims < firstNonIndex; ++frontInsertedDims) {
-      if (!isInsertDims[frontInsertedDims]) {
-        break;
-      }
-    }
-
-    llvm::ArrayRef<bool> toInsertDims =
-        llvm::ArrayRef<bool>(isInsertDims).drop_front(frontInsertedDims);
-    if (!llvm::any_of(toInsertDims, [](auto d) { return d; })) {
-      return rewriter.notifyMatchFailure(op, "no dimensions to insert");
-    }
-
-    // Create a reassociation map that starts with the batch dims.
-    SmallVector<ReassociationExprs> reassociationMap;
-    reassociationMap.push_back({rewriter.getAffineDimExpr(0)});
-
-    for (auto it : llvm::enumerate(llvm::ArrayRef<bool>(toInsertDims))) {
-      if (!it.value()) reassociationMap.push_back({});
-      reassociationMap.back().push_back(
-          rewriter.getAffineDimExpr(it.index() + 1));
-    }
-
-    llvm::SmallVector<Value> expandedUpdates;
-    for (auto update : op.getUpdates()) {
-      auto updatesTy = llvm::cast<ShapedType>(update.getType());
-
-      llvm::SmallVector<int64_t> newShape;
-      for (int i = 0, s = reassociationMap.size(); i < s; ++i) {
-        newShape.push_back(updatesTy.getDimSize(i));
-        for (int j = 1, s = reassociationMap[i].size(); j < s; ++j) {
-          newShape.push_back(1);
-        }
-      }
-
-      Value expandUpdate = rewriter.create<tensor::ExpandShapeOp>(
-          op.getLoc(),
-          RankedTensorType::get(newShape, updatesTy.getElementType()), update,
-          reassociationMap);
-      expandedUpdates.push_back(expandUpdate);
-    }
-
-    llvm::SmallVector<int64_t> newUpdatedWindowDims(toInsertDims.size());
-    llvm::SmallVector<int64_t> newInsertedWindowDims(frontInsertedDims);
-    std::iota(newUpdatedWindowDims.begin(), newUpdatedWindowDims.end(), 1);
-    std::iota(newInsertedWindowDims.begin(), newInsertedWindowDims.end(), 0);
-
-    auto newDimNumbers = mhlo::ScatterDimensionNumbersAttr::get(
-        op.getContext(), newUpdatedWindowDims, newInsertedWindowDims,
-        dimNumbers.getScatterDimsToOperandDims(),
-        /*indexVectorDim=*/1);
-
-    auto newScatter = rewriter.create<mhlo::ScatterOp>(
-        op.getLoc(), op.getResultTypes(), op.getInputs(),
-        op.getScatterIndices(), expandedUpdates, newDimNumbers,
-        op.getIndicesAreSorted(), op.getUniqueIndices());
-    Region &region = newScatter.getUpdateComputation();
-    rewriter.cloneRegionBefore(op.getUpdateComputation(), region, region.end());
-    rewriter.replaceOp(op, newScatter.getResults());
-    return success();
-  }
-};
-// Traverse upward past common operations to see if the value came from a
-// boolean tensor.
-bool isFromBool(Value val) {
-  while (true) {
-    Operation *op = val.getDefiningOp();
-    if (!op) return false;
-
-    if (auto convertOp = dyn_cast<mhlo::ConvertOp>(op)) {
-      auto inTy = llvm::cast<ShapedType>(convertOp.getOperand().getType());
-      if (inTy.getElementType().isInteger(1)) {
-        return true;
-      }
-      val = convertOp.getOperand();
-      continue;
-    }
-
-    if (isa<mhlo::DynamicBroadcastInDimOp>(op) ||
-        isa<mhlo::BroadcastInDimOp>(op) || isa<mhlo::BroadcastOp>(op)) {
-      val = op->getOperand(0);
-      continue;
-    }
-
-    return false;
-  }
-}
-
-// For XLA, multiplying a non-finite value (e.g. NaN, inf) by 0.0 produces
-// 0.0. To preserve this behavior when lowering to linalg we need to convert
-// such multiplies to select operations.
-class MulCastOfBool : public OpRewritePattern<mhlo::MulOp> {
- public:
-  using OpRewritePattern<mhlo::MulOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(mhlo::MulOp op,
-                                PatternRewriter &rewriter) const override {
-    auto resultTy = llvm::cast<ShapedType>(op.getType());
-    if (!llvm::isa<FloatType>(resultTy.getElementType())) return failure();
-    Value lhs = op.getLhs();
-    Value rhs = op.getRhs();
-    bool lhsIsBool = isFromBool(lhs);
-    bool rhsIsBool = isFromBool(rhs);
-
-    if (lhsIsBool == rhsIsBool) return failure();
-    if (rhsIsBool) std::swap(lhs, rhs);
-
-    Type eType = resultTy.getElementType();
-    auto lhsTy = llvm::cast<ShapedType>(lhs.getType());
-    Value lhsBool = rewriter.create<mhlo::ConvertOp>(
-        op.getLoc(), lhsTy.clone(rewriter.getIntegerType(1)), lhs);
-    Value zero = rewriter.create<mhlo::ConstantOp>(
-        op.getLoc(), DenseElementsAttr::get(RankedTensorType::get({}, eType),
-                                            rewriter.getZeroAttr(eType)));
-
-    auto lhsShape = rewriter.create<shape::ShapeOfOp>(
-        op.getLoc(),
-        RankedTensorType::get({lhsTy.getRank()}, rewriter.getIndexType()), lhs);
-
-    int64_t resultRank = resultTy.getRank();
-    auto broadcast = [&](Value value) -> Value {
-      auto valueTy = llvm::cast<ShapedType>(value.getType());
-      auto newTy =
-          RankedTensorType::get(resultTy.getShape(), valueTy.getElementType());
-      if (valueTy == newTy) return value;
-      auto dimensions = llvm::to_vector<4>(
-          llvm::seq<int64_t>(resultRank - valueTy.getRank(), resultRank));
-      return rewriter.create<mhlo::DynamicBroadcastInDimOp>(
-          op.getLoc(), newTy, value, lhsShape,
-          rewriter.getI64TensorAttr(dimensions));
-    };
-
-    zero = broadcast(zero);
-
-    rewriter.replaceOpWithNewOp<mhlo::SelectOp>(op, resultTy, lhsBool, rhs,
-                                                zero);
-    return success();
-  }
-};
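Sketch with assumed types: a float multiply whose lhs traces back to an i1 tensor becomes a select against zero, so non-finite values in %x get zeroed rather than propagated (the real pattern computes the zero via a dynamically broadcast scalar; that is omitted here):

    // Before:
    %m = "mhlo.convert"(%mask) : (tensor<4xi1>) -> tensor<4xf32>
    %r = mhlo.multiply %m, %x : tensor<4xf32>
    // After (conceptually):
    %b = "mhlo.convert"(%m) : (tensor<4xf32>) -> tensor<4xi1>
    %zero = mhlo.constant dense<0.000000e+00> : tensor<4xf32>
    %r2 = "mhlo.select"(%b, %x, %zero)
        : (tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>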
-// Generates Gaussian noise with a uniform random generator, based on the
-// Box-Muller transform.
-class ExpandRngNormal : public OpRewritePattern<mhlo::RngOp> {
- public:
-  using OpRewritePattern<mhlo::RngOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(mhlo::RngOp op,
-                                PatternRewriter &rewriter) const override {
-    if (op.getRngDistribution() != mhlo::RngDistribution::NORMAL)
-      return failure();
-
-    auto resTy = llvm::dyn_cast<ShapedType>(op.getType());
-    // We could support dynamic shapes, but it is easier to implement the
-    // Box-Muller transform when the number of elements is known statically.
-    if (!resTy || !resTy.hasStaticShape()) return failure();
-
-    // The algorithm requires even numbers and will generate pairs.
-    auto numElems = resTy.getNumElements();
-    if (numElems & 1) numElems++;
-    auto halfNumElems = numElems / 2;
-
-    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-
-    // Explicitly set the seed to 0, so we have a stateless generator. This is
-    // not a hard limit; random generation is still an evolving topic, and we
-    // start with a stateless generator.
-    std::mt19937 rng{0};
-    std::uniform_real_distribution<> runif(0.0, 1.0);
-    SmallVector<float> sqrtValues(halfNumElems), cosValues(halfNumElems),
-        sinValues(halfNumElems);
-    for (auto i : llvm::seq<int>(0, numElems / 2)) {
-      constexpr float kEpsilon = std::numeric_limits<float>::epsilon();
-      constexpr float kTwoPi = static_cast<float>(2.0 * M_PI);
-      float u1, u2;
-      do {
-        u1 = runif(rng);
-        u2 = runif(rng);
-      } while (u1 <= kEpsilon);
-      sqrtValues[i] = -2.0 * log(u1);
-      cosValues[i] = cos(kTwoPi * u2);
-      sinValues[i] = sin(kTwoPi * u2);
-    }
-
-    // mag = sigma * sqrt(-2.0 * log(u1));
-    Value mag = getF32Const(b, /*shapes=*/{halfNumElems}, sqrtValues);
-    Value sigma = b.create<mhlo::BroadcastOp>(
-        mag.getType(), op.getB(), make1DElementsAttr(b, halfNumElems));
-    mag = b.create<mhlo::MulOp>(sigma, b.create<mhlo::SqrtOp>(mag));
-
-    // z0 = mag * cos(two_pi * u2) + mu;
-    // z1 = mag * sin(two_pi * u2) + mu;
-    Value mu = b.create<mhlo::BroadcastOp>(mag.getType(), op.getA(),
-                                           make1DElementsAttr(b, halfNumElems));
-    Value z0 = getF32Const(b, /*shapes=*/{halfNumElems}, cosValues);
-    z0 = b.create<mhlo::MulOp>(mag, z0);
-    z0 = b.create<mhlo::AddOp>(z0, mu);
-    Value z1 = getF32Const(b, /*shapes=*/{halfNumElems}, sinValues);
-    z1 = b.create<mhlo::MulOp>(mag, z1);
-    z1 = b.create<mhlo::AddOp>(z1, mu);
-
-    Value res = b.create<mhlo::ConcatenateOp>(ValueRange{z0, z1},
-                                              b.getI64IntegerAttr(0));
-    if (numElems != resTy.getNumElements()) {
-      OpFoldResult zero = b.getIndexAttr(0);
-      OpFoldResult one = b.getIndexAttr(1);
-      OpFoldResult size = b.getIndexAttr(resTy.getNumElements());
-      res = b.create<tensor::ExtractSliceOp>(res, zero, size, one);
-    }
-    if (resTy.getRank() != 1) {
-      res = b.create<mhlo::ReshapeOp>(resTy, res);
-    }
-    rewriter.replaceOp(op, res);
-    return success();
-  }
-};
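For reference, the host-side loop above evaluates the standard Box-Muller identities, with u_1, u_2 drawn uniformly from (0, 1]:

    z_0 = \mu + \sigma \sqrt{-2 \ln u_1} \cos(2\pi u_2)
    z_1 = \mu + \sigma \sqrt{-2 \ln u_1} \sin(2\pi u_2)

Each (u_1, u_2) pair yields two independent normal samples, which is why the pattern rounds the element count up to an even number and concatenates the z_0 and z_1 halves before slicing back to the requested size.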
-// clang-format off
-//
-// Reorder BroadcastInDimOp and N-ary elementwise op.
-//
-// Rewrites the following pattern (take binary elementwise op as example)
-//
-// %bcastx = "mhlo.broadcast_in_dim"(%x) {broadcast_dimensions = %[[BCAST_DIMS]]} : (%[[SHAPE_BEFORE_BCAST]]) -> %[[SHAPE_AFTER_BCAST]]
-// %bcasty = "mhlo.broadcast_in_dim"(%y) {broadcast_dimensions = %[[BCAST_DIMS]]} : (%[[SHAPE_BEFORE_BCAST]]) -> %[[SHAPE_AFTER_BCAST]]
-// %result = "BinaryElementwiseOpT"(%bcastx, %bcasty) : (%[[SHAPE_AFTER_BCAST]], %[[SHAPE_AFTER_BCAST]]) -> %[[SHAPE_AFTER_BCAST]]
-//
-// into
-//
-// %z = "BinaryElementwiseOpT"(%x, %y) : (%[[SHAPE_BEFORE_BCAST]], %[[SHAPE_BEFORE_BCAST]]) -> %[[SHAPE_BEFORE_BCAST]]
-// %result = "mhlo.broadcast_in_dim"(%z) {broadcast_dimensions = %[[BCAST_DIMS]]} : (%[[SHAPE_BEFORE_BCAST]]) -> %[[SHAPE_AFTER_BCAST]]
-//
-// clang-format on
-template <typename ElementwiseOpT>
-class ReorderBroadcastInDimOpAndElementwiseOp
-    : public OpRewritePattern<ElementwiseOpT> {
- public:
-  using OpRewritePattern<ElementwiseOpT>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ElementwiseOpT op,
-                                PatternRewriter &rewriter) const override {
-    Operation *operation = op.getOperation();
-    assert(operation->getNumOperands() >= 1 && operation->getNumResults() == 1);
-
-    // Verify that all operands come from BroadcastInDimOps that share the same
-    // broadcast_dimensions.
-    llvm::SmallVector<mhlo::BroadcastInDimOp, 2> bcastOps;
-    for (auto operand : operation->getOperands()) {
-      if (auto bcastOp = operand.getDefiningOp<mhlo::BroadcastInDimOp>()) {
-        bcastOps.push_back(bcastOp);
-      } else {
-        return failure();
-      }
-    }
-
-    if (llvm::any_of(bcastOps, [&bcastOps](mhlo::BroadcastInDimOp bcastOp) {
-          return bcastOp.getBroadcastDimensions() !=
-                 bcastOps[0].getBroadcastDimensions();
-        })) {
-      return failure();
-    }
-
-    // Verify that all operands of the BroadcastInDimOps have the same type and
-    // a static shape.
-    auto bcastOperandType =
-        llvm::dyn_cast<ShapedType>(bcastOps[0].getOperand().getType());
-    llvm::SmallVector<Value, 2> bcastOperands;
-    for (auto bcastOp : bcastOps) {
-      auto bcastOperand = bcastOp.getOperand();
-      auto type = llvm::dyn_cast<ShapedType>(bcastOperand.getType());
-      if (!type || !type.hasStaticShape() || type != bcastOperandType) {
-        return failure();
-      }
-      bcastOperands.push_back(bcastOperand);
-    }
-
-    // Some elementwise ops, mhlo::RealOp for example, do not have the
-    // SameOperandsAndResultType trait, so resultType might differ from
-    // bcastOperandType.
-    auto elementType = getElementTypeOrSelf(op.getResult());
-    auto resultShape = bcastOperandType.getShape();
-    auto resultType = RankedTensorType::get(resultShape, elementType);
-
-    Value result =
-        rewriter.create<ElementwiseOpT>(op.getLoc(), resultType, bcastOperands);
-    rewriter.replaceOpWithNewOp<mhlo::BroadcastInDimOp>(
-        op, op.getType(), result, bcastOps[0].getBroadcastDimensions());
-
-    for (auto bcastOp : bcastOps) {
-      if (bcastOp.getOperation()->use_empty()) {
-        rewriter.eraseOp(bcastOp);
-      }
-    }
-
-    return success();
-  }
-};
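A concrete sketch of the reorder with made-up shapes: the elementwise op runs on the small pre-broadcast operands, and a single broadcast follows:

    // Before:
    %bx = "mhlo.broadcast_in_dim"(%x) {broadcast_dimensions = dense<1> : tensor<1xi64>}
        : (tensor<4xf32>) -> tensor<8x4xf32>
    %by = "mhlo.broadcast_in_dim"(%y) {broadcast_dimensions = dense<1> : tensor<1xi64>}
        : (tensor<4xf32>) -> tensor<8x4xf32>
    %r = mhlo.add %bx, %by : tensor<8x4xf32>
    // After:
    %z = mhlo.add %x, %y : tensor<4xf32>
    %r2 = "mhlo.broadcast_in_dim"(%z) {broadcast_dimensions = dense<1> : tensor<1xi64>}
        : (tensor<4xf32>) -> tensor<8x4xf32>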
-// Identifies cases where a dense operation has inputs that come from widening
-// operations. For instance, a dot product widening from FP16 to FP32 is better
-// off having the casting operation fused into the dot operation. This
-// decreases the loading required during a dense computation.
-template <class Op>
-struct FuseWidenOperands : public OpRewritePattern<Op> {
-  using OpRewritePattern<Op>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(Op op,
-                                PatternRewriter &rewriter) const override {
-    llvm::SmallVector<Value> operands;
-    for (Value operand : op->getOperands()) {
-      auto convertOp =
-          dyn_cast_or_null<mhlo::ConvertOp>(operand.getDefiningOp());
-      if (convertOp) {
-        auto inputType = getElementTypeOrSelf(convertOp.getOperand().getType());
-        auto castedType = getElementTypeOrSelf(convertOp.getResult().getType());
-        if (inputType.getIntOrFloatBitWidth() <
-            castedType.getIntOrFloatBitWidth()) {
-          operands.push_back(convertOp.getOperand());
-          continue;
-        }
-      }
-      operands.push_back(operand);
-    }
-
-    if (llvm::all_of(
-            llvm::zip_equal(operands, op->getOperands()),
-            [](auto pair) { return std::get<0>(pair) == std::get<1>(pair); }))
-      return failure();
-
-    rewriter.replaceOpWithNewOp<Op>(op, op->getResultTypes(), operands,
-                                    op->getAttrs());
-    return success();
-  }
-};
-
-struct DotToMul : public OpRewritePattern<mhlo::DotOp> {
-  using OpRewritePattern<mhlo::DotOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(mhlo::DotOp op,
-                                PatternRewriter &rewriter) const override {
-    auto lhs = op.getLhs();
-    auto rhs = op.getRhs();
-    auto lhsTy = llvm::dyn_cast<ShapedType>(lhs.getType());
-    auto rhsTy = llvm::dyn_cast<ShapedType>(rhs.getType());
-    auto resultTy = llvm::cast<ShapedType>(op.getType());
-
-    if (!lhsTy || !rhsTy) {
-      return rewriter.notifyMatchFailure(op, "lhs and rhs must be ranked");
-    }
-
-    if (lhsTy.getRank() != 2 || rhsTy.getRank() != 2) {
-      return rewriter.notifyMatchFailure(op, "lhs and rhs must be rank-2");
-    }
-
-    if (lhsTy.getDimSize(1) != 1) return failure();
-
-    // Dynamically compute the shape of the result of the DotOp by querying
-    // the 0-th dimension of the left and the 1st dimension of the right, and
-    // concatenating them together to make the final shape.
-    Value batchSize = rewriter.create<mhlo::GetDimensionSizeOp>(
-        op.getLoc(), lhs, rewriter.getI64IntegerAttr(0));
-    Value batchSize1 = rewriter.create<mhlo::ReshapeOp>(
-        op.getLoc(), RankedTensorType::get({1}, rewriter.getI32Type()),
-        batchSize);
-
-    Value featureSize = rewriter.create<mhlo::GetDimensionSizeOp>(
-        op.getLoc(), rhs, rewriter.getI64IntegerAttr(1));
-    Value featureSize1 = rewriter.create<mhlo::ReshapeOp>(
-        op.getLoc(), RankedTensorType::get({1}, rewriter.getI32Type()),
-        featureSize);
-
-    Value outSize = rewriter.create<mhlo::ConcatenateOp>(
-        op.getLoc(), RankedTensorType::get({2}, rewriter.getI32Type()),
-        ValueRange{batchSize1, featureSize1}, rewriter.getI64IntegerAttr(0));
-
-    lhs = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
-        op.getLoc(), resultTy.clone(lhsTy.getElementType()), lhs, outSize,
-        rewriter.getI64TensorAttr({0, 1}));
-
-    rhs = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
-        op.getLoc(), resultTy.clone(rhsTy.getElementType()), rhs, outSize,
-        rewriter.getI64TensorAttr({0, 1}));
-
-    auto computeETy = lhsTy.getElementType();
-    if (computeETy.getIntOrFloatBitWidth() < rhsTy.getElementTypeBitWidth())
-      computeETy = rhsTy.getElementType();
-    if (computeETy.getIntOrFloatBitWidth() < resultTy.getElementTypeBitWidth())
-      computeETy = resultTy.getElementType();
-
-    auto computeTy = resultTy.clone(computeETy);
-
-    rhs = rewriter.create<mhlo::ConvertOp>(op.getLoc(), computeTy, rhs);
-    lhs = rewriter.create<mhlo::ConvertOp>(op.getLoc(), computeTy, lhs);
-
-    auto result = rewriter.create<mhlo::MulOp>(
-        op.getLoc(), resultTy.clone(computeETy), lhs, rhs);
-    rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(op, resultTy, result);
-    return success();
-  }
-};
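Sketch (static shapes for readability; the actual pattern emits mhlo.get_dimension_size and mhlo.dynamic_broadcast_in_dim to stay shape-agnostic): a [4,1] x [1,5] dot is an outer product, i.e. a broadcasted multiply:

    // Before:
    %r = "mhlo.dot"(%lhs, %rhs)
        : (tensor<4x1xf32>, tensor<1x5xf32>) -> tensor<4x5xf32>
    // After (conceptually):
    %lb = "mhlo.broadcast_in_dim"(%lhs) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
        : (tensor<4x1xf32>) -> tensor<4x5xf32>
    %rb = "mhlo.broadcast_in_dim"(%rhs) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
        : (tensor<1x5xf32>) -> tensor<4x5xf32>
    %r2 = mhlo.multiply %lb, %rb : tensor<4x5xf32>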
-// Similar to DotToMul, this finds the case where a dot_general
-// can be represented using a mul operation. This includes possibly making
-// an implicit cast explicit prior to the mul.
-struct DotGeneralIsMul : public OpRewritePattern<mhlo::DotGeneralOp> {
-  using OpRewritePattern<mhlo::DotGeneralOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(mhlo::DotGeneralOp op,
-                                PatternRewriter &rewriter) const override {
-    auto lhs = llvm::cast<Value>(op.getLhs());
-    auto rhs = llvm::cast<Value>(op.getRhs());
-    auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
-    auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
-    auto resultTy = llvm::dyn_cast<RankedTensorType>(op.getType());
-    ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
-
-    if (!lhsTy || !rhsTy || !resultTy) return failure();
-
-    auto dNums = op.getDotDimensionNumbers();
-    auto batchDimsL = dNums.getLhsBatchingDimensions();
-    auto batchDimsR = dNums.getRhsBatchingDimensions();
-    auto contractDimsL = dNums.getLhsContractingDimensions();
-    auto contractDimsR = dNums.getRhsContractingDimensions();
-
-    llvm::SmallVector<bool> isLhsParallelDim(lhsTy.getRank(), true);
-    llvm::SmallVector<bool> isRhsParallelDim(rhsTy.getRank(), true);
-
-    for (auto dim : batchDimsL) isLhsParallelDim[dim] = false;
-    for (auto dim : batchDimsR) isRhsParallelDim[dim] = false;
-    for (auto dim : contractDimsL) isLhsParallelDim[dim] = false;
-    for (auto dim : contractDimsR) isRhsParallelDim[dim] = false;
-
-    for (auto dim : contractDimsL) {
-      if (lhsTy.getDimSize(dim) != 1) {
-        return rewriter.notifyMatchFailure(op,
-                                           "non-unit contraction dimensions");
-      }
-    }
-
-    // Generate the permutation matrix to order BatchDims, ParallelDims,
-    // ContractDims.
-    llvm::SmallVector<int64_t> permLhs;
-    llvm::SmallVector<int64_t> permRhs;
-    permLhs.append(batchDimsL.begin(), batchDimsL.end());
-    permRhs.append(batchDimsR.begin(), batchDimsR.end());
-
-    for (auto it : llvm::enumerate(isLhsParallelDim)) {
-      if (it.value()) permLhs.push_back(it.index());
-    }
-
-    for (auto it : llvm::enumerate(isRhsParallelDim)) {
-      if (it.value()) permRhs.push_back(it.index());
-    }
-
-    permLhs.append(contractDimsL.begin(), contractDimsL.end());
-    permRhs.append(contractDimsR.begin(), contractDimsR.end());
-
-    // Determine the transpose shape based on the generated permutations.
-    llvm::SmallVector<int64_t> lhsTransposeShape;
-    llvm::SmallVector<int64_t> rhsTransposeShape;
-    for (auto dim : permLhs) lhsTransposeShape.push_back(lhsTy.getDimSize(dim));
-    for (auto dim : permRhs) rhsTransposeShape.push_back(rhsTy.getDimSize(dim));
-
-    // Transpose the left hand side and the right hand side.
-    lhs = builder.create<mhlo::TransposeOp>(
-        RankedTensorType::get(lhsTransposeShape, lhsTy.getElementType()), lhs,
-        builder.getI64TensorAttr(permLhs));
-    lhsTy = llvm::cast<RankedTensorType>(lhs.getType());
-
-    rhs = builder.create<mhlo::TransposeOp>(
-        RankedTensorType::get(rhsTransposeShape, rhsTy.getElementType()), rhs,
-        builder.getI64TensorAttr(permRhs));
-    rhsTy = llvm::cast<RankedTensorType>(rhs.getType());
-
-    auto dimI32Ty = RankedTensorType::get({1}, builder.getI32Type());
-
-    // Reshape away the (unit) contraction dimensions of the lhs, keeping only
-    // the batch and parallel dimensions.
-    llvm::SmallVector<Value> lhsReshapeDims;
-    for (int i = 0, s = lhsTy.getRank() - contractDimsL.size(); i < s; i++) {
-      Value dim = builder.create<mhlo::GetDimensionSizeOp>(lhs, i);
-      lhsReshapeDims.push_back(builder.create<mhlo::ReshapeOp>(dimI32Ty, dim));
-    }
-    Value lhsDynShape = builder.create<mhlo::ConcatenateOp>(
-        RankedTensorType::get({static_cast<int64_t>(lhsReshapeDims.size())},
-                              builder.getI32Type()),
-        lhsReshapeDims, 0);
-    lhsTy =
-        RankedTensorType::get(lhsTy.getShape().drop_back(contractDimsL.size()),
-                              lhsTy.getElementType());
-    lhs = builder.create<mhlo::DynamicReshapeOp>(lhsTy, lhs, lhsDynShape);
-
-    // Reshape away the (unit) contraction dimensions of the rhs, keeping only
-    // the batch and parallel dimensions.
-    llvm::SmallVector<Value> rhsReshapeDims;
-    for (int i = 0, s = rhsTy.getRank() - contractDimsR.size(); i < s; i++) {
-      Value dim = builder.create<mhlo::GetDimensionSizeOp>(rhs, i);
-      rhsReshapeDims.push_back(builder.create<mhlo::ReshapeOp>(dimI32Ty, dim));
-    }
-    Value rhsDynShape = builder.create<mhlo::ConcatenateOp>(
-        RankedTensorType::get({static_cast<int64_t>(rhsReshapeDims.size())},
-                              builder.getI32Type()),
-        rhsReshapeDims, 0);
-    rhsTy =
-        RankedTensorType::get(rhsTy.getShape().drop_back(contractDimsR.size()),
-                              rhsTy.getElementType());
-    rhs = builder.create<mhlo::DynamicReshapeOp>(rhsTy, rhs, rhsDynShape);
-
-    // Compute the size of the output shape with dynamic shape support, using
-    // the lhs and rhs dimensions.
-    llvm::SmallVector<Value> outputDims;
-    outputDims.append(lhsReshapeDims);
-    outputDims.append(rhsReshapeDims.begin() + batchDimsR.size(),
-                      rhsReshapeDims.end());
-    Value outputShape = builder.create<mhlo::ConcatenateOp>(
-        RankedTensorType::get({resultTy.getRank()}, builder.getI32Type()),
-        outputDims, 0);
-
-    // Broadcast the left hand side to match the expected output shape.
-    llvm::SmallVector<int64_t> lhsDimMapping(lhsTy.getRank());
-    std::iota(lhsDimMapping.begin(), lhsDimMapping.end(), 0);
-    auto lhsBroadcastTy =
-        RankedTensorType::get(resultTy.getShape(), lhsTy.getElementType());
-    lhs = builder.createOrFold<mhlo::DynamicBroadcastInDimOp>(
-        lhsBroadcastTy, lhs, outputShape,
-        rewriter.getI64TensorAttr(lhsDimMapping));
-
-    // Broadcast the right hand side to match the expected output shape.
-    llvm::SmallVector<int64_t> rhsDimMapping(rhsTy.getRank());
-    std::iota(rhsDimMapping.begin(), rhsDimMapping.begin() + batchDimsR.size(),
-              0);
-    std::iota(rhsDimMapping.begin() + batchDimsR.size(), rhsDimMapping.end(),
-              lhsTy.getRank());
-    auto rhsBroadcastTy =
-        RankedTensorType::get(resultTy.getShape(), rhsTy.getElementType());
-    rhs = builder.createOrFold<mhlo::DynamicBroadcastInDimOp>(
-        rhsBroadcastTy, rhs, outputShape,
-        rewriter.getI64TensorAttr(rhsDimMapping));
-
-    lhs = builder.createOrFold<mhlo::ConvertOp>(resultTy, lhs);
-    rhs = builder.createOrFold<mhlo::ConvertOp>(resultTy, rhs);
-    rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, resultTy, lhs, rhs);
-    return success();
-  }
-};
-
-struct MHLOToMHLOPreprocessingPass
-    : public MHLOToMHLOPreprocessingBase<MHLOToMHLOPreprocessingPass> {
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<shape::ShapeDialect, tensor::TensorDialect>();
-  }
-
-  void runOnOperation() override {
-    MLIRContext *context = &getContext();
-    ConversionTarget conversionTarget(*context);
-    RewritePatternSet conversionPatterns(&getContext());
-    // Note that various input modalities may do their own legalization of
-    // CHLO. Converting here allows IREE to accept CHLO dialect regardless of
-    // whether it was legalized away at a higher level.
-    // chlo::PopulateLegalizeChloToHloPatterns(context, &conversionPatterns);
-    conversionTarget.addLegalDialect<
-        shape::ShapeDialect, chlo::ChloDialect, mhlo::MhloDialect,
-        math::MathDialect, mlir::func::FuncDialect, mlir::arith::ArithDialect,
-        mlir::tensor::TensorDialect>();
-    // conversionTarget.addIllegalDialect<chlo::ChloDialect>();
-    if (failed(applyPartialConversion(getOperation(), conversionTarget,
-                                      std::move(conversionPatterns)))) {
-      return signalPassFailure();
-    }
-
-    RewritePatternSet patterns(&getContext());
-    // TODO: Remove once we have a general contraction to matmul pass.
-    mhlo::populateEinsumToDotGeneralPatterns(context, &patterns);
-    mhlo::populateUnfuseBatchNormPatterns(context, &patterns);
-    mhlo::populateComplexLoweringPatterns(context, &patterns);
-    mhlo::populateGatherToTorchIndexSelectPatterns(context, &patterns);
-    patterns.insert<ExpandRngNormal, MulCastOfBool>(context);
-
-    // scatter canonicalization patterns
-    patterns.insert<ScatterInt64Indices, ScatterImplicitIndex,
-                    ScatterImplicitBatch, ScatterMaterializeInsertedDim,
-                    ScatterCollapseBatch, ScatterBatchFirst>(context);
-
-    // dot_general canonicalization patterns.
-    mhlo::populateGeneralDotOpLoweringPatterns(&patterns, context);
-    // TODO(jpienaar): This may be redundant with lower_general_dot. Remove if
-    // so.
-    patterns.insert<TransposeReshapeGenericDotGeneral>(context,
-                                                       /*benefit=*/200);
-    patterns.insert<DotGeneralIsMul>(context, /*benefit=*/300);
-
-    // Fusion operations.
-    patterns.insert<FuseWidenOperands<mhlo::DotOp>,
-                    FuseWidenOperands<mhlo::DotGeneralOp>,
-                    FuseWidenOperands<mhlo::ConvolutionOp>>(context,
-                                                            /*benefit=*/400);
-
-    // Additional canonicalizers that simplify to computationally
-    // less-complex operations.
-    patterns.insert<DotToMul>(context);
-
-    // Unary elementwise op.
-    patterns.insert<
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::AbsOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::CeilOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::ConvertOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::ClzOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::CosineOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::ExpOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::Expm1Op>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::FloorOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::ImagOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::IsFiniteOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::LogOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::Log1pOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::LogisticOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::NotOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::NegOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::PopulationCountOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::RealOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::RoundOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::RsqrtOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::SignOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::SineOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::SqrtOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::TanhOp>>(context);
-    // Binary elementwise op.
-    patterns.insert<
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::AddOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::Atan2Op>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::ComplexOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::DivOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::MaxOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::MinOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::MulOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::PowOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::RemOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::ShiftLeftOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::ShiftRightArithmeticOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::ShiftRightLogicalOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::SubtractOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::AndOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::OrOp>,
-        ReorderBroadcastInDimOpAndElementwiseOp<mhlo::XorOp>>(context);
-    if (orderConvFeatures) {
-      patterns.insert<ReorderConvOpInputDimensions>(context);
-      patterns.insert<ReorderConvOpKernelDimensions>(context);
-      patterns.insert<ReorderConvOpOutputDimensions>(context);
-    }
-    if (failed(applyPatternsAndFoldGreedily(getOperation(),
-                                            std::move(patterns)))) {
-      return signalPassFailure();
-    }
-  }
-};
-
-} // namespace
-
-std::unique_ptr<OperationPass<func::FuncOp>>
-createMHLOToMHLOPreprocessingPass() {
-  return std::make_unique<MHLOToMHLOPreprocessingPass>();
-}
-
-} // namespace MHLO
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/PassDetail.h b/compiler/src/iree/compiler/InputConversion/MHLO/PassDetail.h
deleted file mode 100644
index a1320a8b809e..000000000000
--- a/compiler/src/iree/compiler/InputConversion/MHLO/PassDetail.h
+++ /dev/null
@@ -1,25 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_COMPILER_INPUTCONVERSION_MHLO_PASSDETAIL_H_
-#define IREE_COMPILER_INPUTCONVERSION_MHLO_PASSDETAIL_H_
-
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/Pass/Pass.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace MHLO {
-
-#define GEN_PASS_CLASSES
-#include "iree/compiler/InputConversion/MHLO/Passes.h.inc"
-
-} // namespace MHLO
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_INPUTCONVERSION_MHLO_PASSDETAIL_H_
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/Passes.cpp b/compiler/src/iree/compiler/InputConversion/MHLO/Passes.cpp
deleted file mode 100644
index edd267509384..000000000000
--- a/compiler/src/iree/compiler/InputConversion/MHLO/Passes.cpp
+++ /dev/null
@@ -1,148 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/InputConversion/MHLO/Passes.h"
-
-#include "iree/compiler/Dialect/Util/Transforms/Passes.h"
-#include "iree/compiler/InputConversion/Common/Passes.h"
-#include "mhlo/transforms/passes.h"
-#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
-#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
-#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
-#include "mlir/Dialect/SCF/Transforms/Passes.h"
-#include "mlir/Dialect/Shape/Transforms/Passes.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Pass/PassOptions.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Transforms/Passes.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace MHLO {
-
-// TODO(#8745): remove these flags when the -iree-flow-demote-* flags can be
-// used without tripping upstream verifier issues.
-static llvm::cl::opt<bool> clDemoteI64ToI32(
-    "iree-mhlo-demote-i64-to-i32",
-    llvm::cl::desc(
-        "Converts all MHLO i64 ops and values into i32 counterparts."),
-    llvm::cl::init(true));
-static llvm::cl::opt<bool> clDemoteF64ToF32(
-    "iree-mhlo-demote-f64-to-f32",
-    llvm::cl::desc(
-        "Converts all MHLO f64 ops and values into f32 counterparts."),
-    llvm::cl::init(true));
-static llvm::cl::opt<bool> clPromoteBF16ToF32(
-    "iree-mhlo-promote-bf16-to-f32",
-    llvm::cl::desc(
-        "Converts all MHLO bf16 ops and values into f32 counterparts."),
-    llvm::cl::init(false));
-
-void registerMHLOConversionPassPipeline() {
-  PassPipelineRegistration<> mhlo(
-      "iree-mhlo-input-transformation-pipeline",
-      "Runs the MHLO IREE flow dialect transformation pipeline",
-      [](OpPassManager &passManager) {
-        buildMHLOInputConversionPassPipeline(passManager);
-      });
-  PassPipelineRegistration<> xla(
-      "iree-xla-input-transformation-pipeline",
-      "Runs the XLA IREE flow dialect transformation pipeline",
-      [](OpPassManager &passManager) {
-        buildXLAInputConversionPassPipeline(passManager);
-      });
-}
-
-// Prepares HLO for use as an input to the Flow dialect.
-static void buildMHLOInputConversionPassPipelineImpl(OpPassManager &passManager,
-                                                     bool detuple) {
-  passManager.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
-  passManager.addNestedPass<func::FuncOp>(mlir::createCanonicalizerPass());
-  passManager.addNestedPass<func::FuncOp>(
-      mhlo::createLegalizeControlFlowPass());
-
-  // Currently we don't handle SCF ops well and have to convert them all to CFG.
-  // In the future it would be nice if we could have all of flow be both scf
-  // and cfg compatible.
-  passManager.addNestedPass<func::FuncOp>(createTopLevelSCFToCFGPass());
-  if (detuple) passManager.addPass(createFlattenTuplesInCFGPass());
-
-  passManager.addNestedPass<func::FuncOp>(createMHLOToMHLOPreprocessingPass());
-  passManager.addNestedPass<func::FuncOp>(mlir::createCanonicalizerPass());
-
-  // Various shape functions may have been materialized in the `shape.shape_of`
-  // style of treating shapes as tensors. We prefer to legalize these to
-  // scalar ops as early as possible to avoid having them persist as tensor
-  // computations.
-  passManager.addNestedPass<func::FuncOp>(createShapeToShapeLowering());
-  passManager.addPass(createConvertShapeToStandardPass());
-  passManager.addNestedPass<func::FuncOp>(mlir::createCanonicalizerPass());
-
-  // We also don't handle calls well on the old codepath; until we remove the
-  // use of the CFG we can continue inlining.
-  passManager.addPass(mlir::createInlinerPass());
-
-  // Hacky type conversion to work around lack of type support lower in the
-  // stack. This is often required because of implicit i64 insertion by JAX/HLO
-  // that we don't want forcing 32-bit embedded devices to support.
-  // TODO(#8745): remove these and prefer the flow pipeline options instead.
-  if (clDemoteI64ToI32) {
-    passManager.addPass(IREE::Util::createDemoteI64ToI32Pass());
-  }
-  if (clDemoteF64ToF32) {
-    passManager.addPass(IREE::Util::createDemoteF64ToF32Pass());
-  }
-  if (clPromoteBF16ToF32) {
-    passManager.addPass(IREE::Util::createPromoteBF16ToF32Pass());
-  }
-
-  // Perform initial cleanup. createLegalizeInputTypes could rewrite types. In
-  // this context, some operations could be folded away.
-  passManager.addNestedPass<func::FuncOp>(mlir::createCanonicalizerPass());
-  passManager.addNestedPass<func::FuncOp>(mlir::createCSEPass());
-
-  // Convert to Linalg. After this point, MHLO will be eliminated.
-  passManager.addNestedPass<func::FuncOp>(
-      mhlo::createLegalizeShapeComputationsPass());
-  passManager.addNestedPass<func::FuncOp>(createConvertMHLOToLinalgExtPass());
-  passManager.addPass(createMHLOToLinalgOnTensorsPass());
-  // Ensure conversion completed.
-  passManager.addPass(createReconcileUnrealizedCastsPass());
-
-  // Note that some MHLO ops are left by the above and must resolve via
-  // canonicalization. See comments in the above pass and find a better way.
-  passManager.addNestedPass<func::FuncOp>(mlir::createCanonicalizerPass());
-
-  //----------------------------------------------------------------------------
-  // Entry dialect cleanup
-  //----------------------------------------------------------------------------
-  passManager.addPass(createVerifyCompilerMHLOInputLegality());
-}
-
-void buildMHLOInputConversionPassPipeline(OpPassManager &passManager) {
-  buildMHLOInputConversionPassPipelineImpl(passManager, /*detuple=*/false);
-}
-
-void buildXLAInputConversionPassPipeline(OpPassManager &passManager) {
-  buildMHLOInputConversionPassPipelineImpl(passManager, /*detuple=*/true);
-}
-
-namespace {
-#define GEN_PASS_REGISTRATION
-#include "iree/compiler/InputConversion/MHLO/Passes.h.inc"  // IWYU pragma: export
-} // namespace
-
-void registerMHLOConversionPasses() {
-  // Generated.
-  registerPasses();
-
-  // Pipelines.
-  registerMHLOConversionPassPipeline();
-}
-
-} // namespace MHLO
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/Passes.h b/compiler/src/iree/compiler/InputConversion/MHLO/Passes.h
deleted file mode 100644
index 15c945468e0b..000000000000
--- a/compiler/src/iree/compiler/InputConversion/MHLO/Passes.h
+++ /dev/null
@@ -1,83 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_COMPILER_INPUTCONVERSION_MHLO_PASSES_H_
-#define IREE_COMPILER_INPUTCONVERSION_MHLO_PASSES_H_
-
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/Pass/Pass.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace MHLO {
-
-//===----------------------------------------------------------------------===//
-// Pipelines
-//===----------------------------------------------------------------------===//
-
-// Performs input legalization for specific combinations of input dialects.
-void buildMHLOInputConversionPassPipeline(OpPassManager &passManager);
-
-// Performs input legalization on programs that may have originated from an XLA
-// import (or were made to interop with it).
-void buildXLAInputConversionPassPipeline(OpPassManager &passManager);
-
-void registerMHLOConversionPassPipelines();
-
-//------------------------------------------------------------------------------
-// Cleanup passes
-//------------------------------------------------------------------------------
-
-// Flattens tuples in functions and CFG control flow. This is a common
-// form of MHLO as produced by XLA-based systems.
-std::unique_ptr<OperationPass<ModuleOp>> createFlattenTuplesInCFGPass();
-
-//------------------------------------------------------------------------------
-// Conversions into Linalg
-//------------------------------------------------------------------------------
-
-/// Creates the XLA-HLO to Linalg on tensors transformation pass.
-std::unique_ptr<OperationPass<ModuleOp>> createMHLOToLinalgOnTensorsPass();
-
-/// Creates the XLA-HLO to LinalgExt pass.
-std::unique_ptr<OperationPass<func::FuncOp>> createConvertMHLOToLinalgExtPass();
-
-/// Creates the XLA-HLO preprocessing transformation pass. This pass should
-/// contain all mhlo -> mhlo transformations that are shared between all
-/// backends.
-std::unique_ptr<OperationPass<func::FuncOp>>
-createMHLOToMHLOPreprocessingPass();
-
-// Verifies that a module being input to the core compiler pipeline only
-// contains IR structures that are supported at that level.
-std::unique_ptr<OperationPass<ModuleOp>>
-createVerifyCompilerMHLOInputLegality();
-
-//------------------------------------------------------------------------------
-// Passes to aid in the MHLO to StableHLO transition
-//------------------------------------------------------------------------------
-
-std::unique_ptr<OperationPass<ModuleOp>> createConvertMHLOToStableHLOPass();
-
-//------------------------------------------------------------------------------
-// Test passes
-//------------------------------------------------------------------------------
-
-std::unique_ptr<OperationPass<func::FuncOp>>
-createTestMHLOConvertComplexToRealPass();
-
-//===----------------------------------------------------------------------===//
-// Register all Passes
-//===----------------------------------------------------------------------===//
-
-void registerMHLOConversionPasses();
-
-} // namespace MHLO
-} // namespace iree_compiler
-} // namespace mlir
-
-#endif // IREE_COMPILER_INPUTCONVERSION_MHLO_PASSES_H_
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/Passes.td b/compiler/src/iree/compiler/InputConversion/MHLO/Passes.td
deleted file mode 100644
index 85d5ed252c1d..000000000000
--- a/compiler/src/iree/compiler/InputConversion/MHLO/Passes.td
+++ /dev/null
@@ -1,63 +0,0 @@
-// Copyright 2021 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#ifndef IREE_COMPILER_INPUTCONVERSION_MHLO_PASSES
-#define IREE_COMPILER_INPUTCONVERSION_MHLO_PASSES
-
-include "mlir/Pass/PassBase.td"
-
-def ConvertMHLOToLinalgOnTensors :
-    Pass<"iree-mhlo-to-linalg-on-tensors", "ModuleOp"> {
-  let summary = "Convert from XLA-HLO ops to Linalg ops on tensors";
-  let constructor = "mlir::iree_compiler::MHLO::createMHLOToLinalgOnTensorsPass()";
-}
-
-def ConvertMHLOToLinalgExt
-    : Pass<"iree-mhlo-to-linalg-ext", "func::FuncOp"> {
-  let summary =
-      "Convert from XLA-HLO ops to LinalgExt ops and distribute to Flow ops";
-  let constructor =
-      "mlir::iree_compiler::MHLO::createConvertMHLOToLinalgExtPass()";
-}
-
-def FlattenTuplesInCFG :
-    Pass<"iree-mhlo-flatten-tuples-in-cfg", "ModuleOp"> {
-  let summary = "Flattens tuples in a CFG form of MHLO";
-  let constructor = "mlir::iree_compiler::MHLO::createFlattenTuplesInCFGPass()";
-}
-
-def MHLOToMHLOPreprocessing :
-    Pass<"iree-mhlo-to-mhlo-preprocessing", "func::FuncOp"> {
-  let summary = "Apply mhlo to mhlo transformations for some mhlo ops";
-  let constructor = "mlir::iree_compiler::MHLO::createMHLOToMHLOPreprocessingPass()";
-  let options = [
-    Option<"orderConvFeatures", "order-conv-features", "bool", /*default=*/"true",
-           "Guarantees input/output features ordered from conv kernel">
-  ];
-}
-
-def VerifyCompilerMHLOInputLegality :
-    Pass<"iree-mhlo-verify-compiler-input-legality", "ModuleOp"> {
-  let summary = "Verifies that only supported IR constructs are passed to the compiler.";
-  let constructor = "mlir::iree_compiler::MHLO::createVerifyCompilerMHLOInputLegality()";
-}
-
-def ConvertMHLOToStableHLOPass : Pass<"iree-convert-mhlo-to-stablehlo", "ModuleOp"> {
-  let summary = "Convert MHLO to StableHLO, to aid in the transition to StableHLO.";
-  let constructor = "mlir::iree_compiler::MHLO::createConvertMHLOToStableHLOPass()";
-}
-
-//------------------------------------------------------------------------------
-// Test passes
-//------------------------------------------------------------------------------
-
-def TestMHLOConvertComplexToReal :
Pass<"iree-test-mhlo-convert-complex-to-real", "func::FuncOp"> { - let summary = "Test pass that does an MHLO->MHLO conversion of just complex arithmetic ops."; - let constructor = "mlir::iree_compiler::MHLO::createTestMHLOConvertComplexToRealPass()"; -} - -#endif // IREE_COMPILER_INPUTCONVERSION_MHLO_PASSES diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/Rewriters.h b/compiler/src/iree/compiler/InputConversion/MHLO/Rewriters.h deleted file mode 100644 index e839f38c843c..000000000000 --- a/compiler/src/iree/compiler/InputConversion/MHLO/Rewriters.h +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_COMPILER_INPUTCONVERSION_MHLO_REWRITER_H_ -#define IREE_COMPILER_INPUTCONVERSION_MHLO_REWRITER_H_ - -#include "mlir/Transforms/DialectConversion.h" - -namespace mlir { -namespace iree_compiler { -namespace MHLO { - -/// Populates the patterns that convert from MHLO to Linalg on tensors. Imports -/// patterns from XLA, as well as some IREE specific modifications. -void populateMHLOToLinalgOnTensorsConversionPatterns( - MLIRContext *context, TypeConverter &typeConverter, - RewritePatternSet &patterns); - -/// Populates IREE specific patterns to convert HLO broadcasting ops to Linalg. -/// These are being maintained separately because they are a standalone unit -/// that is both intricate and possible to upstream, should there be alignment -/// to do so. -void populateMHLOBroadcastingToLinalgPatterns(MLIRContext *context, - TypeConverter &typeConverter, - RewritePatternSet &patterns); - -/// Populates patterns to convert MHLO collective ops to Stream ops. -void populateMHLOCollectiveOpsConversionPatterns(MLIRContext *context, - TypeConverter &typeConverter, - RewritePatternSet &patterns); - -/// Populates patterns to convert MHLO/CHLO arithmetic on complex tensors to -/// equivalent HLO level real arithmetic. -void populateMHLOComplexToRealPatterns(MLIRContext *context, - TypeConverter &typeConverter, - RewritePatternSet &patterns); - -} // namespace MHLO -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_INPUTCONVERSION_MHLO_REWRITER_H_ diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/VerifyCompilerMHLOInputLegality.cpp b/compiler/src/iree/compiler/InputConversion/MHLO/VerifyCompilerMHLOInputLegality.cpp deleted file mode 100644 index bde4fa51ad1e..000000000000 --- a/compiler/src/iree/compiler/InputConversion/MHLO/VerifyCompilerMHLOInputLegality.cpp +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/InputConversion/MHLO/PassDetail.h"
-#include "iree/compiler/InputConversion/MHLO/Passes.h"
-#include "mhlo/IR/hlo_ops.h"
-#include "mlir/Dialect/Shape/IR/Shape.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "stablehlo/dialect/ChloOps.h"
-
-namespace mlir {
-namespace iree_compiler {
-namespace MHLO {
-
-struct VerifyCompilerMHLOInputLegalityPass
-    : public VerifyCompilerMHLOInputLegalityBase<
-          VerifyCompilerMHLOInputLegalityPass> {
-  void runOnOperation() override {
-    auto *context = &getContext();
-    ConversionTarget conversionTarget(*context);
-    RewritePatternSet conversionPatterns(&getContext());
-
-    // Note that we would prefer allow-lists of what we positively support.
-    // However, it is so common to sneak input-level ops into the pipeline
-    // that we explicitly deny the dialects we know about.
-    conversionTarget.addIllegalDialect<mhlo::MhloDialect>();
-    conversionTarget.addIllegalDialect<chlo::ChloDialect>();
-    conversionTarget.addIllegalDialect<shape::ShapeDialect>();
-
-    // NOTE: It is not fully illegal to tunnel input dialect ops through to
-    // backends that expect them. When such situations arise, the container
-    // op should be marked recursively legal here.
-    SmallVector<Diagnostic> failures;
-    {
-      ScopedDiagnosticHandler diag(context,
-                                   [&](Diagnostic &d) -> LogicalResult {
-                                     failures.push_back(std::move(d));
-                                     return success();
-                                   });
-      if (succeeded(applyPartialConversion(getOperation(), conversionTarget,
-                                           std::move(conversionPatterns)))) {
-        return;
-      }
-    }
-
-    // Error fall-through. Attach all reported issues as notes.
-    InFlightDiagnostic errorDiag =
-        emitError(getOperation().getLoc())
-        << "one or more illegal operations were found in the compiler input "
-           "(are you missing an --iree-input-type= flag, or did you mean to "
-           "pre-process through an IREE importer frontend?)";
-    for (auto &failureDiag : failures) {
-      Diagnostic &note = errorDiag.attachNote(failureDiag.getLocation());
-      for (auto &arg : failureDiag.getArguments()) {
-        note.append(arg);
-      }
-    }
-
-    signalPassFailure();
-  }
-};
-
-std::unique_ptr<OperationPass<ModuleOp>>
-createVerifyCompilerMHLOInputLegality() {
-  return std::make_unique<VerifyCompilerMHLOInputLegalityPass>();
-}
-
-} // namespace MHLO
-} // namespace iree_compiler
-} // namespace mlir
diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/BUILD.bazel b/compiler/src/iree/compiler/InputConversion/MHLO/test/BUILD.bazel
deleted file mode 100644
index a9ff945d5f73..000000000000
--- a/compiler/src/iree/compiler/InputConversion/MHLO/test/BUILD.bazel
+++ /dev/null
@@ -1,45 +0,0 @@
-# Copyright 2019 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-# Tests for common transforms.
- -load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite") -load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") - -package( - features = ["layering_check"], - licenses = ["notice"], # Apache 2.0 -) - -iree_lit_test_suite( - name = "lit", - srcs = enforce_glob( - [ - "broadcasting.mlir", - "convert_mhlo_to_linalg_ext.mlir", - "convert_mhlo_to_stablehlo.mlir", - "convert_collective_ops.mlir", - "convert_complex_to_real.mlir", - "convert_structural_types.mlir", - "dynamic_shape.mlir", - "fft.mlir", - "flatten_tuples_in_cfg.mlir", - "mhlo_to_linalg.mlir", - "mhlo_to_mhlo_preprocessing.mlir", - "mhlo_to_mhlo_preprocessing_canonicalize_dot_general.mlir", - "mhlo_to_mhlo_scatter.mlir", - "missing_legalizations.mlir", - "transformation_pipeline.mlir", - "verify_compiler_mhlo_input_legality.mlir", - ], - include = ["*.mlir"], - ), - cfg = "//compiler:lit.cfg.py", - tools = [ - "//tools:iree-opt", - "@llvm-project//llvm:FileCheck", - ], -) diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/CMakeLists.txt b/compiler/src/iree/compiler/InputConversion/MHLO/test/CMakeLists.txt deleted file mode 100644 index 679cebd889d2..000000000000 --- a/compiler/src/iree/compiler/InputConversion/MHLO/test/CMakeLists.txt +++ /dev/null @@ -1,38 +0,0 @@ -################################################################################ -# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # -# compiler/src/iree/compiler/InputConversion/MHLO/test/BUILD.bazel # -# # -# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # -# CMake-only content. # -# # -# To disable autogeneration for this file entirely, delete this header. # -################################################################################ - -iree_add_all_subdirs() - -iree_lit_test_suite( - NAME - lit - SRCS - "broadcasting.mlir" - "convert_collective_ops.mlir" - "convert_complex_to_real.mlir" - "convert_mhlo_to_linalg_ext.mlir" - "convert_mhlo_to_stablehlo.mlir" - "convert_structural_types.mlir" - "dynamic_shape.mlir" - "fft.mlir" - "flatten_tuples_in_cfg.mlir" - "mhlo_to_linalg.mlir" - "mhlo_to_mhlo_preprocessing.mlir" - "mhlo_to_mhlo_preprocessing_canonicalize_dot_general.mlir" - "mhlo_to_mhlo_scatter.mlir" - "missing_legalizations.mlir" - "transformation_pipeline.mlir" - "verify_compiler_mhlo_input_legality.mlir" - TOOLS - FileCheck - iree-opt -) - -### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/broadcasting.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/broadcasting.mlir deleted file mode 100644 index 39da1653db3a..000000000000 --- a/compiler/src/iree/compiler/InputConversion/MHLO/test/broadcasting.mlir +++ /dev/null @@ -1,442 +0,0 @@ -// RUN: iree-opt --split-input-file --iree-mhlo-to-linalg-on-tensors %s | FileCheck %s - -// Check the non-broadcast case for each registered op, then just check a -// representative op for detailed broadcast semantics. Since the broadcasting -// implementation lowers through mhlo ops, we are primarily checking broadcast -// semantics and not exhaustively checking that the non broadcasting ops lower -// to the right linalg sequences. 
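-// For orientation, a minimal sketch of the shape of these lowerings (the -// shapes and names here are illustrative, not taken from any original test): -// a rank-broadcasting op such as -// %0 = chlo.broadcast_add %lhs, %rhs : (tensor<4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> -// lowers to one linalg.generic that materializes the broadcast of %lhs via an -// indexing map like affine_map<(d0, d1) -> (d1)>, followed by a second -// linalg.generic that performs the elementwise arith.addf.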
- - // CHECK-LABEL: @addWithoutBroadcast -func.func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: linalg.generic - // CHECK-SAME: outs(%0 : tensor<4xf32> - // CHECK: addf - // CHECK-NOT: linalg.generic - %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %0 : tensor<4xf32> -} - -// ----- -// CHECK: #map = affine_map<(d0, d1) -> (d1)> -// CHECK: #map1 = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-LABEL: @dynamicBroadcast -func.func @dynamicBroadcast(%arg0: tensor, %arg1: tensor) -> tensor { - // Should broadcast %arg0 -> %arg1 and cf.assert on dynamic expansion. - - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index - // CHECK-DAG: %[[ARG0_D0:.*]] = tensor.dim %arg0, %[[C0]] - // CHECK-DAG: %[[ARG1_D0:.*]] = tensor.dim %arg1, %[[C0]] : tensor - // CHECK-DAG: %[[ARG1_D1:.*]] = tensor.dim %arg1, %[[C1]] : tensor - // CHECK: %[[EQ:.*]] = arith.cmpi eq, %[[ARG0_D0]], %[[ARG1_D1]] : index - // CHECK: cf.assert %[[EQ]], "mismatched dynamic broadcast extents" - - // CHECK: %[[INIT_0:.*]] = tensor.empty(%[[ARG1_D0]], %[[ARG0_D0]]) : tensor - // CHECK: %[[BCAST_ARG0:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} - // CHECK-SAME: ins(%arg0 : tensor) outs(%[[INIT_0]] : tensor) - - // CHECK: %[[RESULT:.*]] = linalg.generic - // CHECK-SAME: ins(%[[BCAST_ARG0]], %arg1 : tensor, tensor) - - // CHECK-NOT: mhlo.add - %0 = chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor - return %0 : tensor -} - -// ----- -// Verifies that ops with valid broadcast_dimensions attributes are accepted. -// CHECK-LABEL: @dynamicNonScalarBroadcastDimensions -func.func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- -// Verifies that ops with valid broadcast_dimensions attributes are accepted. -// CHECK-LABEL: @dynamicNonScalarByScalarBroadcastDimensions -func.func @dynamicNonScalarByScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor) -> tensor<1x4xf32> { - %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x4xf32>, tensor) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- -// CHECK-LABEL: @dynamicBroadcastComplex -func.func @dynamicBroadcastComplex(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { - // CHECK-NOT: mhlo.complex - // CHECK-NOT: chlo.broadcast_complex - %0 = chlo.broadcast_complex %arg0, %arg1 : (tensor, tensor) -> tensor> - - %1 = "mhlo.real"(%0) : (tensor>) -> tensor - %2 = "mhlo.imag"(%0) : (tensor>) -> tensor - - return %1, %2 : tensor, tensor -} - -// ----- -// CHECK-LABEL: @dynamicBroadcastCompare -func.func @dynamicBroadcastCompare(%arg0: tensor, %arg1: tensor) -> tensor { - // NOTE: compare is unique because of the element type switch. The pattern - // will fail or the verifier will catch it if wrong. - // CHECK-NOT: mhlo.compare - %0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = #chlo} : (tensor, tensor) -> tensor - return %0 : tensor -} - -// ----- -// CHECK-LABEL: func.func @selectv2 -func.func @selectv2(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { - // All same type: should just short-circuit to one mhlo.select / one generic.
- // CHECK: linalg.generic - // CHECK: %[[BODY:.*]] = arith.select - // CHECK-NOT: linalg.generic - %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - return %0: tensor<2xi32> -} - -// ----- -// CHECK: #map = affine_map<(d0) -> ()> -// CHECK: #map1 = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @selectv2_pred_scalar -func.func @selectv2_pred_scalar(%arg0: tensor, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { - // CHECK: %[[INIT_0:.*]] = tensor.empty() : tensor<2xi1> - // CHECK: %[[BCAST_PRED:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel"]} ins(%arg0 : tensor) outs(%[[INIT_0]] : tensor<2xi1>) - // CHECK: %[[INIT_1:.*]] = tensor.empty() : tensor<2xi32> - // CHECK: linalg.generic - // CHECK-SAME: ins(%[[BCAST_PRED]], %arg1, %arg2 : tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) outs(%[[INIT_1]] : tensor<2xi32>) - %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - return %0: tensor<2xi32> -} - -// ----- -// CHECK: #map = affine_map<(d0, d1, d2) -> ()> -// CHECK: #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK: #map2 = affine_map<(d0, d1, d2) -> (d1, 0)> -// CHECK-LABEL: func.func @selectv2_broadcast_then -func.func @selectv2_broadcast_then(%arg0: tensor, %arg1: tensor<8x1xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> { - // CHECK: %[[BCAST_PRED:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor) - // CHECK: %[[BCAST_THEN:.*]] = linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1 : tensor<8x1xi32>) - // CHECK: linalg.generic - // CHECK-SAME: ins(%[[BCAST_PRED]], %[[BCAST_THEN]], %arg2 : tensor<2x8x8xi1>, tensor<2x8x8xi32>, tensor<2x8x8xi32>) - // CHECK: arith.select - %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor, tensor<8x1xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32> - return %0: tensor<2x8x8xi32> -} - -// ----- -// CHECK: #map = affine_map<(d0, d1, d2) -> ()> -// CHECK: #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK: #map2 = affine_map<(d0, d1, d2) -> (d1, 0)> -// CHECK-LABEL: func.func @selectv2_broadcast_else -func.func @selectv2_broadcast_else(%arg0: tensor, %arg1: tensor<2x8x8xi32>, %arg2: tensor<8x1xi32>) -> tensor<2x8x8xi32> { - // CHECK: %[[BCAST_PRED:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor) - // CHECK: %[[BCAST_ELSE:.*]] = linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg2 : tensor<8x1xi32>) - // CHECK: linalg.generic - // CHECK-SAME: ins(%[[BCAST_PRED]], %arg1, %[[BCAST_ELSE]] : tensor<2x8x8xi1>, tensor<2x8x8xi32>, tensor<2x8x8xi32>) - // CHECK: arith.select - %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor, tensor<2x8x8xi32>, tensor<8x1xi32>) -> tensor<2x8x8xi32> - return %0: tensor<2x8x8xi32> -} - -// ----- -// CHECK: #map = affine_map<(d0, d1, d2) -> (0)> -// CHECK: #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK-LABEL: func.func @selectv2_broadcast_pred -func.func @selectv2_broadcast_pred(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> { - // CHECK: %[[BCAST_PRED:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1xi1>) - // CHECK: 
linalg.generic - // CHECK-SAME: ins(%[[BCAST_PRED]], %arg1, %arg2 : tensor<2x8x8xi1>, tensor<2x8x8xi32>, tensor<2x8x8xi32>) - // CHECK: arith.select - %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32> - return %0: tensor<2x8x8xi32> -} - -// ----- -// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, 0, 0)> -// CHECK: #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK: #map2 = affine_map<(d0, d1, d2) -> (0, d1, 0)> -// CHECK: #map3 = affine_map<(d0, d1, d2) -> (0, 0, d2)> -// CHECK-LABEL: func.func @selectv2_broadcast_all -func.func @selectv2_broadcast_all(%arg0: tensor<8x1x1xi1>, %arg1: tensor<1x8x1xi32>, %arg2: tensor<1x1x8xi32>) -> tensor<8x8x8xi32> { - // CHECK: %[[BCAST_PRED:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<8x1x1xi1>) - // CHECK: %[[BCAST_THEN:.*]] = linalg.generic {indexing_maps = [#map2, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1 : tensor<1x8x1xi32>) - // CHECK: %[[BCAST_ELSE:.*]] = linalg.generic {indexing_maps = [#map3, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1x1x8xi32>) - // CHECK: linalg.generic - // CHECK-SAME: ins(%[[BCAST_PRED]], %[[BCAST_THEN]], %[[BCAST_ELSE]] : tensor<8x8x8xi1>, tensor<8x8x8xi32>, tensor<8x8x8xi32>) - %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<8x1x1xi1>, tensor<1x8x1xi32>, tensor<1x1x8xi32>) -> tensor<8x8x8xi32> - return %0: tensor<8x8x8xi32> -} - -// ----- -// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, 0, 0)> -// CHECK: #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK: #map2 = affine_map<(d0, d1, d2) -> (0, d1, 0)> -// CHECK: #map3 = affine_map<(d0, d1, d2) -> (0, 0, d2)> -// CHECK-LABEL: func.func @selectv2_broadcast_dyn_pred -func.func @selectv2_broadcast_dyn_pred(%arg0: tensor, %arg1: tensor<1x8x1xi32>, %arg2: tensor<1x1x8xi32>) -> tensor { - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[DIM_PRED_0:.*]] = tensor.dim %arg0, %[[C0]] - // CHECK: %[[INIT_PRED:.*]] = tensor.empty(%[[DIM_PRED_0]]) - // CHECK: %[[BCAST_PRED:.*]] = linalg.generic - // CHECK-SAME: indexing_maps = [#map, #map1] - // CHECK-SAME: ins(%arg0 : tensor) outs(%[[INIT_PRED]] : tensor) - // CHECK: %[[INIT_THEN:.*]] = tensor.empty(%[[DIM_PRED_0]]) - // CHECK: %[[BCAST_THEN:.*]] = linalg.generic - // CHECK-SAME: indexing_maps = [#map2, #map1] - // CHECK-SAME: ins(%arg1 : tensor<1x8x1xi32>) outs(%[[INIT_THEN]] : tensor) - // CHECK: %[[INIT_ELSE:.*]] = tensor.empty(%[[DIM_PRED_0]]) - // CHECK: %[[BCAST_ELSE:.*]] = linalg.generic - // CHECK-SAME: indexing_maps = [#map3, #map1] - // CHECK-SAME: ins(%arg2 : tensor<1x1x8xi32>) outs(%[[INIT_ELSE]] : tensor) - // CHECK: %[[SHAPE_BCAST_THEN:.*]] = shape.shape_of %[[BCAST_THEN]] - // CHECK: %[[DIM_BCAST_THEN_0:.*]] = tensor.extract %[[SHAPE_BCAST_THEN]][%[[C0]]] - // CHECK: %[[INIT_RESULT:.*]] = tensor.empty(%[[DIM_BCAST_THEN_0]]) - // CHECK: linalg.generic - // CHECK-SAME: ins(%[[BCAST_PRED]], %[[BCAST_THEN]], %[[BCAST_ELSE]] : tensor, tensor, tensor) outs(%[[INIT_RESULT]] : tensor) - %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor, tensor<1x8x1xi32>, tensor<1x1x8xi32>) -> tensor - return %0: tensor -} - -// ----- -// CHECK-LABEL: func.func @selectv2_broadcast_dyn_then -func.func @selectv2_broadcast_dyn_then(%arg0: tensor<8x1x1xi1>, %arg1: tensor<1x?x1xi32>, %arg2: tensor<1x1x8xi32>) -> tensor<8x?x8xi32> { - // CHECK: %[[C1:.*]] = arith.constant 1 : 
index - // CHECK: %[[DIM_THEN_1:.*]] = tensor.dim %arg1, %[[C1]] - // CHECK: %[[INIT_PRED:.*]] = tensor.empty(%[[DIM_THEN_1]]) - // CHECK: %[[BCAST_PRED:.*]] = linalg.generic - // CHECK-SAME: indexing_maps = [#map, #map1] - // CHECK-SAME: ins(%arg0 : tensor<8x1x1xi1>) outs(%[[INIT_PRED]] : tensor<8x?x8xi1>) - // CHECK: %[[INIT_THEN:.*]] = tensor.empty(%[[DIM_THEN_1]]) - // CHECK: %[[BCAST_THEN:.*]] = linalg.generic - // CHECK-SAME: indexing_maps = [#map2, #map1] - // CHECK-SAME: ins(%arg1 : tensor<1x?x1xi32>) outs(%[[INIT_THEN]] : tensor<8x?x8xi32>) - // CHECK: %[[INIT_ELSE:.*]] = tensor.empty(%[[DIM_THEN_1]]) - // CHECK: %[[BCAST_ELSE:.*]] = linalg.generic - // CHECK-SAME: indexing_maps = [#map3, #map1] - // CHECK-SAME: ins(%arg2 : tensor<1x1x8xi32>) outs(%[[INIT_ELSE]] : tensor<8x?x8xi32>) - // CHECK: %[[SHAPE_BCAST_THEN:.*]] = shape.shape_of %[[BCAST_THEN]] - // CHECK: %[[DIM_BCAST_THEN_1:.*]] = tensor.extract %[[SHAPE_BCAST_THEN]][%[[C1]]] - // CHECK: %[[INIT_RESULT:.*]] = tensor.empty(%[[DIM_BCAST_THEN_1]]) - // CHECK: linalg.generic - // CHECK-SAME: ins(%[[BCAST_PRED]], %[[BCAST_THEN]], %[[BCAST_ELSE]] : tensor<8x?x8xi1>, tensor<8x?x8xi32>, tensor<8x?x8xi32>) outs(%[[INIT_RESULT]] : tensor<8x?x8xi32>) - %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<8x1x1xi1>, tensor<1x?x1xi32>, tensor<1x1x8xi32>) -> tensor<8x?x8xi32> - return %0: tensor<8x?x8xi32> -} - -// ----- -// CHECK-LABEL: func.func @selectv2_broadcast_dyn_else -func.func @selectv2_broadcast_dyn_else(%arg0: tensor<8x1x1xi1>, %arg1: tensor<1x8x1xi32>, %arg2: tensor<1x1x?xi32>) -> tensor<8x8x?xi32> { - // CHECK: %[[C2:.*]] = arith.constant 2 : index - // CHECK: %[[DIM_ELSE_2:.*]] = tensor.dim %arg2, %[[C2]] - // CHECK: %[[INIT_PRED:.*]] = tensor.empty(%[[DIM_ELSE_2]]) - // CHECK: %[[BCAST_PRED:.*]] = linalg.generic - // CHECK-SAME: indexing_maps = [#map, #map1] - // CHECK-SAME: ins(%arg0 : tensor<8x1x1xi1>) outs(%[[INIT_PRED]] : tensor<8x8x?xi1>) - - // CHECK: %[[INIT_THEN:.*]] = tensor.empty(%[[DIM_ELSE_2]]) - // CHECK: %[[BCAST_THEN:.*]] = linalg.generic - // CHECK-SAME: indexing_maps = [#map2, #map1] - // CHECK-SAME: ins(%arg1 : tensor<1x8x1xi32>) outs(%[[INIT_THEN]] : tensor<8x8x?xi32>) - // CHECK: %[[INIT_ELSE:.*]] = tensor.empty(%[[DIM_ELSE_2]]) - // CHECK: %[[BCAST_ELSE:.*]] = linalg.generic - // CHECK-SAME: indexing_maps = [#map3, #map1] - // CHECK-SAME: ins(%arg2 : tensor<1x1x?xi32>) outs(%[[INIT_ELSE]] : tensor<8x8x?xi32>) - // CHECK: %[[SHAPE_BCAST_THEN:.*]] = shape.shape_of %[[BCAST_THEN]] - // CHECK: %[[DIM_BCAST_THEN_1:.*]] = tensor.extract %[[SHAPE_BCAST_THEN]][%[[C2]]] - // CHECK: %[[INIT_RESULT:.*]] = tensor.empty(%[[DIM_BCAST_THEN_1]]) - // CHECK: linalg.generic - // CHECK-SAME: ins(%[[BCAST_PRED]], %[[BCAST_THEN]], %[[BCAST_ELSE]] : tensor<8x8x?xi1>, tensor<8x8x?xi32>, tensor<8x8x?xi32>) outs(%[[INIT_RESULT]] : tensor<8x8x?xi32>) - %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor<8x1x1xi1>, tensor<1x8x1xi32>, tensor<1x1x?xi32>) -> tensor<8x8x?xi32> - return %0: tensor<8x8x?xi32> -} - -// ----- -// CHECK-LABEL: func.func @selectv2_broadcast_dyn_all -func.func @selectv2_broadcast_dyn_all(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index - // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index - // CHECK-DAG: %[[PRED_D0:.*]] = tensor.dim %arg0, %[[C0]] : tensor - // CHECK-DAG: %[[THEN_D0:.*]] = tensor.dim %arg1, %[[C0]] : tensor - // CHECK-DAG: %[[ELSE_D0:.*]] = tensor.dim %arg2, %[[C0]] : tensor - // CHECK-DAG: %[[ELSE_D2:.*]] = 
tensor.dim %arg2, %[[C2]] : tensor - // CHECK: %[[CMP_0:.*]] = arith.cmpi eq, %[[PRED_D0]], %[[THEN_D0]] : index - // CHECK: cf.assert %[[CMP_0]], "mismatched dynamic broadcast extents" - // CHECK: %[[CMP_1:.*]] = arith.cmpi eq, %[[PRED_D0]], %[[ELSE_D0]] : index - // CHECK: cf.assert %[[CMP_1]], "mismatched dynamic broadcast extents" - // Only two cf.asserts are needed. The rest are statically verified. - // CHECK-NOT: cf.assert - %0 = "chlo.broadcast_select"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor - return %0: tensor -} - -// ----- -// Note that broadcast_add is used as a proxy for all of the template -// expansions. Tests below merely verify that the op has an expansion. -// CHECK-LABEL: @andWithoutBroadcast -func.func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { - // CHECK-NOT: mhlo.and - %0 = chlo.broadcast_and %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> - return %0 : tensor<4xi1> -} - -// ----- -// CHECK-LABEL: @atan2WithoutBroadcast -func.func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NOT: mhlo.atan2 - %0 = chlo.broadcast_atan2 %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %0 : tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @compareWithoutBroadcast -func.func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xi1> { - // CHECK-NOT: mhlo.compare - %0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = #chlo} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> - return %0 : tensor<4xi1> -} - -// ----- -// CHECK-LABEL: @complexWithoutBroadcast -func.func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { - // CHECK-NOT: mhlo.complex - // CHECK-NOT: chlo.broadcast_complex - %0 = chlo.broadcast_complex %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> - - %1 = "mhlo.real"(%0) : (tensor<4xcomplex>) -> tensor<4xf32> - %2 = "mhlo.imag"(%0) : (tensor<4xcomplex>) -> tensor<4xf32> - - return %1, %2 : tensor<4xf32>, tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @divideWithoutBroadcast -func.func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NOT: mhlo.divide - %0 = chlo.broadcast_divide %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %0 : tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @maximumWithoutBroadcast -func.func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NOT: mhlo.maximum - %0 = chlo.broadcast_maximum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %0 : tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @minimumWithoutBroadcast -func.func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NOT: mhlo.minimum - %0 = chlo.broadcast_minimum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %0 : tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @multiplyWithoutBroadcast -func.func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NOT: mhlo.multiply - %0 = chlo.broadcast_multiply %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %0 : tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @orWithoutBroadcast -func.func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { - // CHECK-NOT: mhlo.or - %0 = chlo.broadcast_or %arg0, %arg1 : 
(tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> - return %0 : tensor<4xi1> -} - -// ----- -// CHECK-LABEL: @powerWithoutBroadcast -func.func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NOT: mhlo.power - %0 = chlo.broadcast_power %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %0 : tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @remainderWithoutBroadcast -func.func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NOT: mhlo.remainder - %0 = chlo.broadcast_remainder %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %0 : tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @subWithoutBroadcast -func.func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NOT: mhlo.subtract - %0 = chlo.broadcast_subtract %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %0 : tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @xorWithoutBroadcast -func.func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { - // CHECK-NOT: mhlo.xor - %0 = chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> - return %0 : tensor<4xi1> -} - -// ----- -// CHECK-LABEL: @ZetaWithoutBroadcast -func.func @ZetaWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) - -> tensor<4xf32> { - // This is a composition: it should lower completely. - // CHECK-NOT: mhlo. - %0 = chlo.broadcast_zeta %arg0, %arg1 - : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %0 : tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @PolygammaWithoutBroadcast -// CHECK-SAME: (%[[LHS:.*]]: tensor<4xf32>, %[[RHS:.*]]: tensor<4xf32>) -func.func @PolygammaWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) - -> tensor<4xf32> { - // This is a composition: it should lower completely. - // CHECK-NOT: mhlo. 
- %0 = chlo.broadcast_polygamma %arg0, %arg1 - : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %0 : tensor<4xf32> -} - -// ----- -// CHECK-LABEL: @fallbackDynamicReshape -func.func @fallbackDynamicReshape(%arg0 : tensor<4x?x3x?xui32>, %arg1 : tensor<5xindex>) -> tensor<12x?x?x1x?xui32> { - // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index - // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index - // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index - // CHECK-DAG: %[[RESULT_D1:.*]] = tensor.extract %arg1[%[[C1]]] : tensor<5xindex> - // CHECK-DAG: %[[RESULT_D2:.*]] = tensor.extract %arg1[%[[C2]]] : tensor<5xindex> - // CHECK-DAG: %[[RESULT_D4:.*]] = tensor.extract %arg1[%[[C4]]] : tensor<5xindex> - // CHECK-DAG: %[[ARG_D1:.*]] = tensor.dim %arg0, %[[C1]] : tensor<4x?x3x?xi32> - // CHECK-DAG: %[[ARG_D3:.*]] = tensor.dim %arg0, %[[C3]] : tensor<4x?x3x?xi32> - // CHECK-DAG: %[[RESULT:.*]] = flow.tensor.reshape %arg0 : tensor<4x?x3x?xi32>{%[[ARG_D1]], %[[ARG_D3]]} -> tensor<12x?x?x1x?xi32>{%[[RESULT_D1]], %[[RESULT_D2]], %[[RESULT_D4]]} - %0 = "mhlo.dynamic_reshape"(%arg0, %arg1) : (tensor<4x?x3x?xui32>, tensor<5xindex>) -> tensor<12x?x?x1x?xui32> - // CHECK: return %[[RESULT]] - return %0 : tensor<12x?x?x1x?xui32> -} - -// ----- -// CHECK-LABEL: @fallbackDynamicReshapeInt -func.func @fallbackDynamicReshapeInt(%arg0 : tensor<4x?x3x?xui32>, %arg1 : tensor<5xi32>) -> tensor<12x?x?x1x?xui32> { - // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index - // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index - // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index - // CHECK-DAG: %[[D1:.*]] = tensor.extract %arg1[%[[C1]]] : tensor<5xi32> - // CHECK-DAG: %[[D2:.*]] = tensor.extract %arg1[%[[C2]]] : tensor<5xi32> - // CHECK-DAG: %[[D4:.*]] = tensor.extract %arg1[%[[C4]]] : tensor<5xi32> - // CHECK-DAG: %[[RESULT_D1:.*]] = arith.index_cast %[[D1]] : i32 to index - // CHECK-DAG: %[[RESULT_D2:.*]] = arith.index_cast %[[D2]] : i32 to index - // CHECK-DAG: %[[RESULT_D4:.*]] = arith.index_cast %[[D4]] : i32 to index - // CHECK-DAG: %[[ARG_D1:.*]] = tensor.dim %arg0, %[[C1]] : tensor<4x?x3x?xi32> - // CHECK-DAG: %[[ARG_D3:.*]] = tensor.dim %arg0, %[[C3]] : tensor<4x?x3x?xi32> - // CHECK-DAG: %[[RESULT:.*]] = flow.tensor.reshape %arg0 : tensor<4x?x3x?xi32>{%[[ARG_D1]], %[[ARG_D3]]} -> tensor<12x?x?x1x?xi32>{%[[RESULT_D1]], %[[RESULT_D2]], %[[RESULT_D4]]} - %0 = "mhlo.dynamic_reshape"(%arg0, %arg1) : (tensor<4x?x3x?xui32>, tensor<5xi32>) -> tensor<12x?x?x1x?xui32> - // CHECK: return %[[RESULT]] - return %0 : tensor<12x?x?x1x?xui32> -} diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_collective_ops.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_collective_ops.mlir deleted file mode 100644 index 14cd3078d101..000000000000 --- a/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_collective_ops.mlir +++ /dev/null @@ -1,658 +0,0 @@ -// RUN: iree-opt --split-input-file --iree-mhlo-to-linalg-on-tensors --canonicalize -cse %s | FileCheck %s - -// CHECK-LABEL: @replica_id -func.func @replica_id() -> tensor { - // CHECK-DAG: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel - // CHECK-DAG: %[[RANK:.+]] = flow.channel.rank %[[CHANNEL]] : index - // CHECK-DAG: %[[CAST:.+]] = arith.index_castui %[[RANK]] : index to i32 - // CHECK-DAG: %[[TENSOR:.+]] = tensor.from_elements %[[CAST]] : tensor - // CHECK-DAG: return %[[TENSOR]] : tensor - %id = mhlo.replica_id 
: tensor - return %id : tensor -} - -// ----- - -module @jit_fn attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 4 : i32 } { - // CHECK-LABEL: @replica_id_with_partitions - func.func @replica_id_with_partitions() -> tensor { - // CHECK-DAG: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel - // CHECK-DAG: %[[RANK:.+]] = flow.channel.rank %[[CHANNEL]] : index - // CHECK-DAG: %[[DIV2:.+]] = arith.divui %[[RANK]], %c2 : index - // CHECK-DAG: %[[CAST:.+]] = arith.index_castui %[[DIV2]] : index to i32 - // CHECK-DAG: %[[TENSOR:.+]] = tensor.from_elements %[[CAST]] : tensor - // CHECK-DAG: return %[[TENSOR]] : tensor - %id = mhlo.replica_id : tensor - return %id : tensor - } -} - -// ----- - -// Returns 0 since num_partitions is not set. - -// CHECK-LABEL: @partition_id -func.func @partition_id() -> tensor { - // CHECK-DAG: %[[CST0:.+]] = arith.constant dense<0> : tensor - // CHECK-DAG: return %[[CST0]] : tensor - %id = mhlo.partition_id : tensor - return %id : tensor -} - -// ----- - -module @jit_fn attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 4 : i32 } { - // CHECK-LABEL: @partition_id_with_partitions - func.func @partition_id_with_partitions() -> tensor { - // CHECK-DAG: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel - // CHECK-DAG: %[[RANK:.+]] = flow.channel.rank %[[CHANNEL]] : index - // CHECK-DAG: %[[REM2:.+]] = arith.remui %[[RANK]], %c2 : index - // CHECK-DAG: %[[CAST:.+]] = arith.index_castui %[[REM2]] : index to i32 - // CHECK-DAG: %[[TENSOR:.+]] = tensor.from_elements %[[CAST]] : tensor - // CHECK-DAG: return %[[TENSOR]] : tensor - %id = mhlo.partition_id : tensor - return %id : tensor - } -} - -// ----- - -// CHECK-LABEL: @all_reduce_sum -// CHECK-SAME: (%[[ARG0:.+]]: tensor<2304xf32>) -func.func @all_reduce_sum(%input : tensor<2304xf32>) -> tensor<2304xf32> { - // CHECK: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel - // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2304xf32> - // CHECK: %[[OP:.+]] = flow.collective.all_reduce sum, f32, %[[EMPTY]], %[[ARG0]], %[[CHANNEL]] : (tensor<2304xf32>, tensor<2304xf32>, !flow.channel) -> %[[EMPTY]] as tensor<2304xf32> - // CHECK: return %[[OP]] : tensor<2304xf32> - %out = "mhlo.all_reduce"(%input) ({ - ^bb0(%arg0: tensor, %arg1: tensor): - %sum = mhlo.add %arg0, %arg1 : tensor - mhlo.return %sum : tensor - }) {channel_handle = #mhlo.channel_handle, - replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, - use_global_device_ids} : (tensor<2304xf32>) -> tensor<2304xf32> - return %out : tensor<2304xf32> -} - -// ----- - -// CHECK-LABEL: @all_reduce_sum_uint -// CHECK-SAME: (%[[ARG0:.+]]: tensor<2304xi32> -func.func @all_reduce_sum_uint(%input : tensor<2304xui32>) -> tensor<2304xui32> { - // CHECK: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel - // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2304xi32> - // CHECK: %[[OP:.+]] = flow.collective.all_reduce sum, ui32, %[[EMPTY]], %[[ARG0]], %[[CHANNEL]] : (tensor<2304xi32>, tensor<2304xi32>, !flow.channel) -> %[[EMPTY]] as tensor<2304xi32> - // CHECK: return %[[OP]] : tensor<2304xi32> - %out = "mhlo.all_reduce"(%input) ({ - ^bb0(%arg0: tensor, %arg1: tensor): - %sum = mhlo.add %arg0, %arg1 : tensor - mhlo.return %sum : tensor - }) {channel_handle = #mhlo.channel_handle, - replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, - use_global_device_ids} : (tensor<2304xui32>) -> tensor<2304xui32> - return %out : tensor<2304xui32> -} - -// ----- - -// CHECK-LABEL: @all_reduce_product -// CHECK-SAME: 
(%[[ARG0:.+]]: tensor<2304xf32>) -func.func @all_reduce_product(%input : tensor<2304xf32>) -> tensor<2304xf32> { - // CHECK: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel - // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2304xf32> - // CHECK: %[[OP:.+]] = flow.collective.all_reduce product, f32, %[[EMPTY]], %[[ARG0]], %[[CHANNEL]] : (tensor<2304xf32>, tensor<2304xf32>, !flow.channel) -> %[[EMPTY]] as tensor<2304xf32> - // CHECK: return %[[OP]] : tensor<2304xf32> - %out = "mhlo.all_reduce"(%input) ({ - ^bb0(%arg0: tensor, %arg1: tensor): - %mul = mhlo.multiply %arg0, %arg1 : tensor - mhlo.return %mul : tensor - }) {channel_handle = #mhlo.channel_handle, - replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, - use_global_device_ids} : (tensor<2304xf32>) -> tensor<2304xf32> - return %out : tensor<2304xf32> -} - -// ----- - -// CHECK-LABEL: @all_reduce_minimum -// CHECK-SAME: (%[[ARG0:.+]]: tensor<2304xf32>) -func.func @all_reduce_minimum(%input : tensor<2304xf32>) -> tensor<2304xf32> { - // CHECK: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel - // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2304xf32> - // CHECK: %[[OP:.+]] = flow.collective.all_reduce minimum, f32, %[[EMPTY]], %[[ARG0]], %[[CHANNEL]] : (tensor<2304xf32>, tensor<2304xf32>, !flow.channel) -> %[[EMPTY]] as tensor<2304xf32> - // CHECK: return %[[OP]] : tensor<2304xf32> - %out = "mhlo.all_reduce"(%input) ({ - ^bb0(%arg0: tensor, %arg1: tensor): - %mul = mhlo.minimum %arg0, %arg1 : tensor - mhlo.return %mul : tensor - }) {channel_handle = #mhlo.channel_handle, - replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, - use_global_device_ids} : (tensor<2304xf32>) -> tensor<2304xf32> - return %out : tensor<2304xf32> -} - -// ----- - -// CHECK-LABEL: @all_reduce_maximum -// CHECK-SAME: (%[[ARG0:.+]]: tensor<2304xf32>) -func.func @all_reduce_maximum(%input : tensor<2304xf32>) -> tensor<2304xf32> { - // CHECK: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel - // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2304xf32> - // CHECK: %[[OP:.+]] = flow.collective.all_reduce maximum, f32, %[[EMPTY]], %[[ARG0]], %[[CHANNEL]] : (tensor<2304xf32>, tensor<2304xf32>, !flow.channel) -> %[[EMPTY]] as tensor<2304xf32> - // CHECK: return %[[OP]] : tensor<2304xf32> - %out = "mhlo.all_reduce"(%input) ({ - ^bb0(%arg0: tensor, %arg1: tensor): - %mul = mhlo.maximum %arg0, %arg1 : tensor - mhlo.return %mul : tensor - }) {channel_handle = #mhlo.channel_handle, - replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, - use_global_device_ids} : (tensor<2304xf32>) -> tensor<2304xf32> - return %out : tensor<2304xf32> -} - -// ----- - -// CHECK-LABEL: @all_reduce_maximum_optional_attrs -// CHECK-SAME: (%[[ARG0:.+]]: tensor<2304xf32>) -func.func @all_reduce_maximum_optional_attrs(%input : tensor<2304xf32>) -> tensor<2304xf32> { - // CHECK: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel - // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2304xf32> - // CHECK: %[[OP:.+]] = flow.collective.all_reduce maximum, f32, %[[EMPTY]], %[[ARG0]], %[[CHANNEL]] : (tensor<2304xf32>, tensor<2304xf32>, !flow.channel) -> %[[EMPTY]] as tensor<2304xf32> - // CHECK: return %[[OP]] : tensor<2304xf32> - %out = "mhlo.all_reduce"(%input) ({ - ^bb0(%arg0: tensor, %arg1: tensor): - %mul = mhlo.maximum %arg0, %arg1 : tensor - mhlo.return %mul : tensor - }) {replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>} : (tensor<2304xf32>) -> tensor<2304xf32> - return %out : tensor<2304xf32> -} - -// 
----- - -// CHECK-LABEL: @all_reduce_sum_with_groups -// CHECK-SAME: (%[[ARG0:.+]]: tensor<2x4xi32>) -func.func @all_reduce_sum_with_groups(%input : tensor<2x4xi32>) -> tensor<2x4xi32> { - // CHECK: %[[BASE_CHANNEL:.+]] = flow.channel.default : !flow.channel - // CHECK: %[[BASE_RANK:.+]] = flow.channel.rank %[[BASE_CHANNEL]] - // CHECK: %[[SPLIT_COLOR:.+]] = util.switch index from [%c0, %c1] at %[[BASE_RANK]] else %c-1 - // CHECK: %[[SPLIT_KEY:.+]] = util.switch index from [%c0, %c0] at %[[BASE_RANK]] else %c-1 - // CHECK: %[[SPLIT_CHANNEL:.+]] = flow.channel.split %[[BASE_CHANNEL]], %[[SPLIT_COLOR]], %[[SPLIT_KEY]] : !flow.channel - // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x4xi32> - // CHECK: %[[OP:.+]] = flow.collective.all_reduce sum, ui32, %[[EMPTY]], %[[ARG0]], %[[SPLIT_CHANNEL]] : (tensor<2x4xi32>, tensor<2x4xi32>, !flow.channel) -> %[[EMPTY]] as tensor<2x4xi32> - // CHECK: return %[[OP]] : tensor<2x4xi32> - %out = "mhlo.all_reduce"(%input) ({ - ^bb0(%arg0: tensor, %arg1: tensor): - %sum = mhlo.add %arg0, %arg1 : tensor - mhlo.return %sum : tensor - }) {channel_handle = #mhlo.channel_handle, - replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, - use_global_device_ids} : (tensor<2x4xi32>) -> tensor<2x4xi32> - return %out : tensor<2x4xi32> -} - -// ----- - -// CHECK-LABEL: @all_gather_dim_0 -// CHECK-SAME: (%[[ARG0:.+]]: tensor<512xf32>) -> tensor<1024xf32> -func.func @all_gather_dim_0(%input : tensor<512xf32>) -> tensor<1024xf32> { - // CHECK: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel - // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1024xf32> - // CHECK: %[[OP:.+]] = flow.collective.all_gather f32, %[[EMPTY]], %[[ARG0]], %[[CHANNEL]] : (tensor<1024xf32>, tensor<512xf32>, !flow.channel) -> %[[EMPTY]] as tensor<1024xf32> - // CHECK: return %[[OP]] : tensor<1024xf32> - %out = "mhlo.all_gather"(%input) {all_gather_dim = 0 : i64, - channel_handle = #mhlo.channel_handle, - replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, - use_global_device_ids} : (tensor<512xf32>) -> tensor<1024xf32> - return %out : tensor<1024xf32> -} - -// ----- - -// CHECK-LABEL: @all_gather_dim_0_uint -// CHECK-SAME: (%[[ARG0:.+]]: tensor<512xi32> -func.func @all_gather_dim_0_uint(%input : tensor<512xui32>) -> tensor<1024xui32> { - // CHECK: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel - // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1024xi32> - // CHECK: %[[OP:.+]] = flow.collective.all_gather ui32, %[[EMPTY]], %[[ARG0]], %[[CHANNEL]] : (tensor<1024xi32>, tensor<512xi32>, !flow.channel) -> %[[EMPTY]] as tensor<1024xi32> - // CHECK: return %[[OP]] : tensor<1024xi32> - %out = "mhlo.all_gather"(%input) {all_gather_dim = 0 : i64, - channel_handle = #mhlo.channel_handle, - replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, - use_global_device_ids} : (tensor<512xui32>) -> tensor<1024xui32> - return %out : tensor<1024xui32> -} - -// ----- - -// CHECK-LABEL: @all_gather_dim_1 -// CHECK-SAME: (%[[ARG0:.+]]: tensor<2x2xf32>) -> tensor<2x4xf32> -func.func @all_gather_dim_1(%input : tensor<2x2xf32>) -> tensor<2x4xf32> { - // CHECK: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel - // CHECK: tensor.empty() : tensor<2x2xf32> - // CHECK: %[[TRANSPOSE_ARG:.+]] = linalg.generic - // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x2xf32> - // CHECK: %[[OP:.+]] = flow.collective.all_gather f32, %[[EMPTY]], %[[TRANSPOSE_ARG]], %[[CHANNEL]] : (tensor<4x2xf32>, tensor<2x2xf32>, !flow.channel) -> %[[EMPTY]] as tensor<4x2xf32> - // CHECK: tensor.empty() : tensor<2x4xf32> - // CHECK: 
%[[TRANSPOSE_OUT:.+]] = linalg.generic - // CHECK: return %[[TRANSPOSE_OUT]] : tensor<2x4xf32> - %out = "mhlo.all_gather"(%input) {all_gather_dim = 1 : i64, - channel_handle = #mhlo.channel_handle, - replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, - use_global_device_ids} : (tensor<2x2xf32>) -> tensor<2x4xf32> - return %out : tensor<2x4xf32> -} - -// ----- - -// CHECK-LABEL: @all_gather_dim_0_optional_attrs -// CHECK-SAME: (%[[ARG0:.+]]: tensor<512xf32>) -> tensor<1024xf32> -func.func @all_gather_dim_0_optional_attrs(%input : tensor<512xf32>) -> tensor<1024xf32> { - // CHECK: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel - // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1024xf32> - // CHECK: %[[OP:.+]] = flow.collective.all_gather f32, %[[EMPTY]], %[[ARG0]], %[[CHANNEL]] : (tensor<1024xf32>, tensor<512xf32>, !flow.channel) -> %[[EMPTY]] as tensor<1024xf32> - // CHECK: return %[[OP]] : tensor<1024xf32> - %out = "mhlo.all_gather"(%input) {all_gather_dim = 0 : i64, - replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>} : (tensor<512xf32>) -> tensor<1024xf32> - return %out : tensor<1024xf32> -} - -// ----- - -// CHECK-LABEL: @all_to_all_split_concat_same -// CHECK-SAME: (%[[ARG0:.+]]: tensor<1024xf32>) -> tensor<1024xf32> -func.func @all_to_all_split_concat_same(%input : tensor<1024xf32>) -> tensor<1024xf32> { - // CHECK: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel - // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1024xf32> - // CHECK: %[[OP:.+]] = flow.collective.all_to_all f32, %[[EMPTY]], %[[ARG0]], %[[CHANNEL]] : (tensor<1024xf32>, tensor<1024xf32>, !flow.channel) -> %[[EMPTY]] as tensor<1024xf32> - // CHECK: return %[[OP]] : tensor<1024xf32> - %out = "mhlo.all_to_all"(%input) { - split_dimension = 0 : i64, - concat_dimension = 0 : i64, - split_count = 2 : i64, - channel_handle = #mhlo.channel_handle, - replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>} : (tensor<1024xf32>) -> tensor<1024xf32> - return %out : tensor<1024xf32> -} - -// ----- - -// CHECK-LABEL: @all_to_all_split_concat_same_uint -// CHECK-SAME: (%[[ARG0:.+]]: tensor<1024xi32> -func.func @all_to_all_split_concat_same_uint(%input : tensor<1024xui32>) -> tensor<1024xui32> { - // CHECK: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel - // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1024xi32> - // CHECK: %[[OP:.+]] = flow.collective.all_to_all ui32, %[[EMPTY]], %[[ARG0]], %[[CHANNEL]] : (tensor<1024xi32>, tensor<1024xi32>, !flow.channel) -> %[[EMPTY]] as tensor<1024xi32> - // CHECK: return %[[OP]] : tensor<1024xi32> - %out = "mhlo.all_to_all"(%input) { - split_dimension = 0 : i64, - concat_dimension = 0 : i64, - split_count = 2 : i64, - channel_handle = #mhlo.channel_handle, - replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>} : (tensor<1024xui32>) -> tensor<1024xui32> - return %out : tensor<1024xui32> -} - -// ----- - -// CHECK-LABEL: @all_to_all_split_concat_same_dim_1 -// CHECK-SAME: (%[[ARG0:.+]]: tensor<2x4xf32>) -> tensor<2x4xf32> -func.func @all_to_all_split_concat_same_dim_1(%input : tensor<2x4xf32>) -> tensor<2x4xf32> { - // CHECK: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel - // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x2xf32> - // CHECK: %[[TRANSPOSE_ARG:.+]] = linalg.generic - // CHECK: %[[OP:.+]] = flow.collective.all_to_all f32, %[[EMPTY]], %[[TRANSPOSE_ARG]], %[[CHANNEL]] : (tensor<4x2xf32>, tensor<4x2xf32>, !flow.channel) -> %[[EMPTY]] as tensor<4x2xf32> - // CHECK: %[[TRANSPOSE_OUT:.+]] = linalg.generic - // CHECK: return %[[TRANSPOSE_OUT]] : 
tensor<2x4xf32> - %out = "mhlo.all_to_all"(%input) { - split_dimension = 1 : i64, - concat_dimension = 1 : i64, - split_count = 2 : i64, - channel_handle = #mhlo.channel_handle, - replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>} : (tensor<2x4xf32>) -> tensor<2x4xf32> - return %out : tensor<2x4xf32> -} - -// ----- - -// CHECK-LABEL: @all_to_all_split_dim_0 -// CHECK-SAME: (%[[ARG0:.+]]: tensor<4x4xf32>) -> tensor<2x8xf32> -func.func @all_to_all_split_dim_0(%input : tensor<4x4xf32>) -> tensor<2x8xf32> { - // CHECK: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel - // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x4xf32> - // CHECK: %[[OP:.+]] = flow.collective.all_to_all f32, %[[EMPTY]], %[[ARG0]], %[[CHANNEL]] : (tensor<4x4xf32>, tensor<4x4xf32>, !flow.channel) -> %[[EMPTY]] as tensor<4x4xf32> - // CHECK: %[[REARRANGE_RESHAPE:.+]] = tensor.expand_shape %[[OP]] {{\[}}[0, 1], [2]] : tensor<4x4xf32> into tensor<2x2x4xf32> - // CHECK: %[[REARRANGE_TRANSPOSE:.+]] = linalg.generic - // CHECK: %[[RESHAPE_OUT:.+]] = tensor.collapse_shape %[[REARRANGE_TRANSPOSE]] {{\[}}[0], [1, 2]] : tensor<2x2x4xf32> into tensor<2x8xf32> - // CHECK: return %[[RESHAPE_OUT]] : tensor<2x8xf32> - %out = "mhlo.all_to_all"(%input) { - split_dimension = 0 : i64, - concat_dimension = 1 : i64, - split_count = 2 : i64, - channel_handle = #mhlo.channel_handle, - replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>} : (tensor<4x4xf32>) -> tensor<2x8xf32> - return %out : tensor<2x8xf32> -} - -// ----- - -// CHECK-LABEL: @all_to_all_split_dim_1 -// CHECK-SAME: (%[[ARG0:.+]]: tensor<4x4xf32>) -> tensor<8x2xf32> -func.func @all_to_all_split_dim_1(%input : tensor<4x4xf32>) -> tensor<8x2xf32> { - // CHECK: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel - // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x4xf32> - // CHECK: %[[TRANSPOSE_ARG:.+]] = linalg.generic - // CHECK: %[[OP:.+]] = flow.collective.all_to_all f32, %[[EMPTY]], %[[TRANSPOSE_ARG]], %[[CHANNEL]] : (tensor<4x4xf32>, tensor<4x4xf32>, !flow.channel) -> %[[EMPTY]] as tensor<4x4xf32> - // CHECK: %[[TRANSPOSE_OUT:.+]] = linalg.generic - // CHECK: %[[REARRANGE_RESHAPE1:.+]] = tensor.expand_shape %[[TRANSPOSE_OUT]] {{\[}}[0], [1, 2]] : tensor<4x4xf32> into tensor<4x2x2xf32> - // CHECK: %[[EMPTY2:.+]] = tensor.empty() : tensor<2x4x2xf32> - // CHECK: %[[REARRANGE_TRANSPOSE:.+]] = linalg.generic - // CHECK: %[[REARRANGE_RESHAPE2:.+]] = tensor.collapse_shape %[[REARRANGE_TRANSPOSE]] {{\[}}[0, 1], [2]] : tensor<2x4x2xf32> into tensor<8x2xf32> - // CHECK: return %[[REARRANGE_RESHAPE2]] : tensor<8x2xf32> - %out = "mhlo.all_to_all"(%input) { - split_dimension = 1 : i64, - concat_dimension = 0 : i64, - split_count = 2 : i64, - channel_handle = #mhlo.channel_handle, - replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>} : (tensor<4x4xf32>) -> tensor<8x2xf32> - return %out : tensor<8x2xf32> -} - -// ----- - -// CHECK-LABEL: @all_to_all_3d_split_dim_1 -// CHECK-SAME: (%[[ARG0:.+]]: tensor<4x4x4xf32>) -> tensor<4x2x8xf32> -func.func @all_to_all_3d_split_dim_1(%input : tensor<4x4x4xf32>) -> tensor<4x2x8xf32> { - // CHECK: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel - // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x4x4xf32> - // CHECK: %[[TRANSPOSE_ARG:.+]] = linalg.generic - // CHECK: %[[OP:.+]] = flow.collective.all_to_all f32, %[[EMPTY]], %[[TRANSPOSE_ARG]], %[[CHANNEL]] : (tensor<4x4x4xf32>, tensor<4x4x4xf32>, !flow.channel) -> %[[EMPTY]] as tensor<4x4x4xf32> - // CHECK: %[[TRANSPOSE_OUT:.+]] = linalg.generic - // CHECK: %[[REARRANGE_RESHAPE1:.+]] = 
tensor.expand_shape %[[TRANSPOSE_OUT]] {{\[}}[0], [1, 2], [3]] : tensor<4x4x4xf32> into tensor<4x2x2x4xf32> - // CHECK: %[[EMPTY_1:.+]] = tensor.empty() : tensor<4x2x2x4xf32> - // CHECK: %[[REARRANGE_TRANSPOSE:.+]] = linalg.generic - // CHECK: %[[REARRANGE_RESHAPE2:.+]] = tensor.collapse_shape %[[REARRANGE_TRANSPOSE]] {{\[}}[0], [1], [2, 3]] : tensor<4x2x2x4xf32> into tensor<4x2x8xf32> - // CHECK: return %[[REARRANGE_RESHAPE2]] : tensor<4x2x8xf32> - %out = "mhlo.all_to_all"(%input) { - split_dimension = 1 : i64, - concat_dimension = 2 : i64, - split_count = 2 : i64, - channel_handle = #mhlo.channel_handle, - replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>} : (tensor<4x4x4xf32>) -> tensor<4x2x8xf32> - return %out : tensor<4x2x8xf32> -} - -// ----- - -// CHECK-LABEL: @reduce_scatter_dim_0 -// CHECK-SAME: (%[[ARG0:.+]]: tensor<4x2xf32>) -> tensor<2x2xf32> -func.func @reduce_scatter_dim_0(%input : tensor<4x2xf32>) -> tensor<2x2xf32> { - // CHECK: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel - // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x2xf32> - // CHECK: %[[OP:.+]] = flow.collective.reduce_scatter sum, f32, %[[EMPTY]], %[[ARG0]], %[[CHANNEL]] : (tensor<2x2xf32>, tensor<4x2xf32>, !flow.channel) -> %[[EMPTY]] as tensor<2x2xf32> - // CHECK: return %[[OP]] : tensor<2x2xf32> - %out = "mhlo.reduce_scatter"(%input) ({ - ^bb0(%arg0: tensor , %arg1: tensor) : - %sum = mhlo.add %arg0, %arg1 : tensor - mhlo.return %sum : tensor - }) {scatter_dimension = 0 : i64, - channel_handle = #mhlo.channel_handle, - replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, - use_global_device_ids} : (tensor<4x2xf32>) -> tensor<2x2xf32> - return %out : tensor<2x2xf32> -} - -// ----- - -// CHECK-LABEL: @reduce_scatter_dim_0_uint -// CHECK-SAME: (%[[ARG0:.+]]: tensor<4x2xi32> -func.func @reduce_scatter_dim_0_uint(%input : tensor<4x2xui32>) -> tensor<2x2xui32> { - // CHECK: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel - // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x2xi32> - // CHECK: %[[OP:.+]] = flow.collective.reduce_scatter sum, ui32, %[[EMPTY]], %[[ARG0]], %[[CHANNEL]] : (tensor<2x2xi32>, tensor<4x2xi32>, !flow.channel) -> %[[EMPTY]] as tensor<2x2xi32> - // CHECK: return %[[OP]] : tensor<2x2xi32> - %out = "mhlo.reduce_scatter"(%input) ({ - ^bb0(%arg0: tensor , %arg1: tensor) : - %sum = mhlo.add %arg0, %arg1 : tensor - mhlo.return %sum : tensor - }) {scatter_dimension = 0 : i64, - channel_handle = #mhlo.channel_handle, - replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, - use_global_device_ids} : (tensor<4x2xui32>) -> tensor<2x2xui32> - return %out : tensor<2x2xui32> -} - -// ----- - -// CHECK-LABEL: @reduce_scatter_dim_1 -// CHECK-SAME: (%[[ARG0:.+]]: tensor<2x4xf32>) -> tensor<2x2xf32> -func.func @reduce_scatter_dim_1(%input : tensor<2x4xf32>) -> tensor<2x2xf32> { - // CHECK: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel - // CHECK: tensor.empty() : tensor<4x2xf32> - // CHECK: %[[TRANSPOSE_ARG:.+]] = linalg.generic - // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x2xf32> - // CHECK: %[[OP:.+]] = flow.collective.reduce_scatter sum, f32, %[[EMPTY]], %[[TRANSPOSE_ARG]], %[[CHANNEL]] : (tensor<2x2xf32>, tensor<4x2xf32>, !flow.channel) -> %[[EMPTY]] as tensor<2x2xf32> - // CHECK: %[[TRANSPOSE_OUT:.+]] = linalg.generic - // CHECK: return %[[TRANSPOSE_OUT]] : tensor<2x2xf32> - %out = "mhlo.reduce_scatter"(%input) ({ - ^bb0(%arg0: tensor , %arg1: tensor) : - %sum = mhlo.add %arg0, %arg1 : tensor - mhlo.return %sum : tensor - }) {scatter_dimension = 1 : i64, - channel_handle = 
#mhlo.channel_handle, - replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, - use_global_device_ids} : (tensor<2x4xf32>) -> tensor<2x2xf32> - return %out : tensor<2x2xf32> -} - -// ----- - -// CHECK-LABEL: @reduce_scatter_dim_0_optional_attrs -// CHECK-SAME: (%[[ARG0:.+]]: tensor<4x2xf32>) -> tensor<2x2xf32> -func.func @reduce_scatter_dim_0_optional_attrs(%input : tensor<4x2xf32>) -> tensor<2x2xf32> { - // CHECK: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel - // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x2xf32> - // CHECK: %[[OP:.+]] = flow.collective.reduce_scatter sum, f32, %[[EMPTY]], %[[ARG0]], %[[CHANNEL]] : (tensor<2x2xf32>, tensor<4x2xf32>, !flow.channel) -> %[[EMPTY]] as tensor<2x2xf32> - // CHECK: return %[[OP]] : tensor<2x2xf32> - %out = "mhlo.reduce_scatter"(%input) ({ - ^bb0(%arg0: tensor , %arg1: tensor) : - %sum = mhlo.add %arg0, %arg1 : tensor - mhlo.return %sum : tensor - }) {scatter_dimension = 0 : i64, - replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>} : (tensor<4x2xf32>) -> tensor<2x2xf32> - return %out : tensor<2x2xf32> -} - -// ----- - -// flattened_ids: channel_id > 0 && use_global_device_ids = true -module @jit_fn attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 8 : i32 } { - // CHECK-LABEL: @flattened_ids - // CHECK-SAME: ([[ARG0:%.+]]: tensor<2304xf32>) - func.func @flattened_ids(%input : tensor<2304xf32>) -> tensor<2304xf32> { - // CHECK: [[CHANNEL:%.+]] = flow.channel.default : !flow.channel - // CHECK: [[EMPTY:%.+]] = tensor.empty() : tensor<2304xf32> - // CHECK: [[ALLREDUCE:%.+]] = flow.collective.all_reduce sum, f32, [[EMPTY]], [[ARG0]], [[CHANNEL]] : (tensor<2304xf32>, tensor<2304xf32>, !flow.channel) -> [[EMPTY]] as tensor<2304xf32> - // CHECK: return [[ALLREDUCE]] : tensor<2304xf32> - %out = "mhlo.all_reduce"(%input) ({ - ^bb0(%arg0: tensor, %arg1: tensor): - %sum = mhlo.add %arg0, %arg1 : tensor - mhlo.return %sum : tensor - }) {channel_handle = #mhlo.channel_handle, - replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, - use_global_device_ids} : (tensor<2304xf32>) -> tensor<2304xf32> - return %out : tensor<2304xf32> - } -} - -// ----- - -// cross-replica: channel_id <= 0 && use_global_device_ids = false -module @jit_fn attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 4 : i32 } { - // CHECK-LABEL: @cross_replica - func.func @cross_replica(%input : tensor<2304xf32>) -> tensor<2304xf32> { - // Cross replica should form groups (0,2,4,6),(1,3,5,7), where each number represents a cell below. 
- // +---+---+ - // | 0 | 1 | - // | 2 | 3 | - // | 4 | 5 | - // | 6 | 7 | - // +---+---+ - // rank: 0 1 2 3 4 5 6 7 - // CHECK: util.switch index from [%c0, %c1, %c0, %c1, %c0, %c1, %c0, %c1] at %channel_rank else %c-1 : index - // CHECK: util.switch index from [%c0, %c0, %c1, %c1, %c2, %c2, %c3, %c3] at %channel_rank else %c-1 : index - %out = "mhlo.all_reduce"(%input) ({ - ^bb0(%arg0: tensor, %arg1: tensor): - %sum = mhlo.add %arg0, %arg1 : tensor - mhlo.return %sum : tensor - }) {channel_handle = #mhlo.channel_handle, - replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> - } : (tensor<2304xf32>) -> tensor<2304xf32> - return %out : tensor<2304xf32> - } -} - -// ----- - -// cross_replica_and_partition: channel_id > 0 && use_global_device_ids = false -module @jit_fn attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 4 : i32 } { - // CHECK-LABEL: @cross_replica_and_partition - func.func @cross_replica_and_partition(%input : tensor<2304xf32>) -> tensor<2304xf32> { - // Cross replica_and_partition should form groups (0,2,1,3),(4,6,5,7), where each number represents a cell below. - // Note that ranks are assigned within a partition first, e.g., ranks 0 and 1 are assigned to cells 0 and 2, respectively. - // +---+---+ - // | 0 1 | - // | 2 3 | - // |---+---| - // | 4 5 | - // | 6 7 | - // +---+---+ - // rank: 0 1 2 3 4 5 6 7 - // CHECK: util.switch index from [%c0, %c0, %c0, %c0, %c1, %c1, %c1, %c1] at %channel_rank else %c-1 : index - // CHECK: util.switch index from [%c0, %c2, %c1, %c3, %c0, %c2, %c1, %c3] at %channel_rank else %c-1 : index - %out = "mhlo.all_reduce"(%input) ({ - ^bb0(%arg0: tensor, %arg1: tensor): - %sum = mhlo.add %arg0, %arg1 : tensor - mhlo.return %sum : tensor - }) {channel_handle = #mhlo.channel_handle, - replica_groups = dense<[[0, 1], [2, 3]]> : tensor<2x2xi64> - } : (tensor<2304xf32>) -> tensor<2304xf32> - return %out : tensor<2304xf32> - } -} - -// ----- - -// cross_partition: channel_id > 0 -module @jit_fn attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 4 : i32 } { - // CHECK-LABEL: @cross_partition - func.func @cross_partition(%input : tensor<2304xf32>) -> tensor<2304xf32> { - // Cross partition should form groups (0,1),(2,3),(4,5),(6,7) where each number represents a cell below.
- // +---+---+ - // | 0 1 | - // +---+---+ - // | 2 3 | - // +---+---+ - // | 4 5 | - // +---+---+ - // | 6 7 | - // +---+---+ - // rank: 0 1 2 3 4 5 6 7 - // CHECK: util.switch index from [%c0, %c0, %c1, %c1, %c2, %c2, %c3, %c3] at %channel_rank else %c-1 : index - // CHECK: util.switch index from [%c0, %c1, %c0, %c1, %c0, %c1, %c0, %c1] at %channel_rank else %c-1 : index - %out = "mhlo.all_to_all"(%input) { - split_dimension = 0 : i64, - concat_dimension = 0 : i64, - split_count = 2 : i64, - channel_handle = #mhlo.channel_handle, - replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>} : (tensor<2304xf32>) -> tensor<2304xf32> - return %out : tensor<2304xf32> - } -} - -// ----- - -// CHECK-LABEL: @collective_permute -// CHECK-SAME: (%[[ARG0:.+]]: tensor<8xf32>) -> tensor<8xf32> -func.func @collective_permute(%input : tensor<8xf32>) -> tensor<8xf32> { - // CHECK: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel - // CHECK: %[[RANK:.+]] = flow.channel.rank %[[CHANNEL]] : index - // CHECK: %[[SEND:.+]] = util.switch index from [%c1, %c2, %c3, %c0] at %[[RANK]] else %c-1 - // CHECK: %[[RECV:.+]] = util.switch index from [%c3, %c0, %c1, %c2] at %[[RANK]] else %c-1 - // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32> - // CHECK: %[[OP:.+]] = flow.collective.send_recv f32, %[[EMPTY]], %[[ARG0]], %[[CHANNEL]], %[[SEND]], %[[RECV]] : (tensor<8xf32>, tensor<8xf32>, !flow.channel, index, index) -> %[[EMPTY]] as tensor<8xf32> - // CHECK: return %[[OP]] : tensor<8xf32> - %out = "mhlo.collective_permute"(%input) { - source_target_pairs = dense<[[0, 1], [1, 2], [2, 3], [3, 0]]> : tensor<4x2xi64>, - channel_handle = #mhlo.channel_handle} : (tensor<8xf32>) -> tensor<8xf32> - return %out : tensor<8xf32> -} - -// ----- - -// CHECK-LABEL: @collective_permute_uint -// CHECK-SAME: (%[[ARG0:.+]]: tensor<8xi32> -func.func @collective_permute_uint(%input : tensor<8xui32>) -> tensor<8xui32> { - // CHECK: %[[CHANNEL:.+]] = flow.channel.default : !flow.channel - // CHECK: %[[RANK:.+]] = flow.channel.rank %[[CHANNEL]] : index - // CHECK: %[[SEND:.+]] = util.switch index from [%c1, %c2, %c3, %c0] at %[[RANK]] else %c-1 - // CHECK: %[[RECV:.+]] = util.switch index from [%c3, %c0, %c1, %c2] at %[[RANK]] else %c-1 - // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xi32> - // CHECK: %[[OP:.+]] = flow.collective.send_recv ui32, %[[EMPTY]], %[[ARG0]], %[[CHANNEL]], %[[SEND]], %[[RECV]] : (tensor<8xi32>, tensor<8xi32>, !flow.channel, index, index) -> %[[EMPTY]] as tensor<8xi32> - // CHECK: return %[[OP]] : tensor<8xi32> - %out = "mhlo.collective_permute"(%input) { - source_target_pairs = dense<[[0, 1], [1, 2], [2, 3], [3, 0]]> : tensor<4x2xi64>, - channel_handle = #mhlo.channel_handle} : (tensor<8xui32>) -> tensor<8xui32> - return %out : tensor<8xui32> -} - -// ----- - -// collective_permute cross_replica: channel_id <= 0 -module @jit_fn attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 4 : i32 } { - // CHECK-LABEL: @collective_permute_cross_replica - func.func @collective_permute_cross_replica(%input : tensor<8xf32>) -> tensor<8xf32> { - // Cross replica should form groups (0,2,4,6),(1,3,5,7) where each number represents a cell below. 
- // +---+---+ - // | 0 | 1 | - // | | | - // | 2 | 3 | - // | | | - // | 4 | 5 | - // | | | - // | 6 | 7 | - // +---+---+ - // rank: 0 1 2 3 4 5 6 7 - // CHECK: util.switch index from [%c0, %c1, %c0, %c1, %c0, %c1, %c0, %c1] at %channel_rank else %c-1 : index - // CHECK: util.switch index from [%c0, %c0, %c1, %c1, %c2, %c2, %c3, %c3] at %channel_rank else %c-1 : index - %out = "mhlo.collective_permute"(%input) { - source_target_pairs = dense<[[0, 1], [1, 2], [2, 3], [3, 0]]> : tensor<4x2xi64>, - channel_handle = #mhlo.channel_handle} : (tensor<8xf32>) -> tensor<8xf32> - return %out : tensor<8xf32> - } -} - -// ----- - -// collective_permute cross_partition: channel_id > 0 -module @jit_fn attributes {mhlo.num_partitions = 2 : i32, mhlo.num_replicas = 4 : i32 } { - // CHECK-LABEL: @collective_permute_cross_partition - func.func @collective_permute_cross_partition(%input : tensor<8xf32>) -> tensor<8xf32> { - // Cross partition should form groups (0,1),(2,3),(4,5),(6,7) where each number represents a cell below. - // +---+---+ - // | 0 1 | - // +---+---+ - // | 2 3 | - // +---+---+ - // | 4 5 | - // +---+---+ - // | 6 7 | - // +---+---+ - // rank: 0 1 2 3 4 5 6 7 - // CHECK: util.switch index from [%c0, %c0, %c1, %c1, %c2, %c2, %c3, %c3] at %channel_rank else %c-1 : index - // CHECK: util.switch index from [%c0, %c1, %c0, %c1, %c0, %c1, %c0, %c1] at %channel_rank else %c-1 : index - %out = "mhlo.collective_permute"(%input) { - source_target_pairs = dense<[[0, 1]]> : tensor<1x2xi64>, - channel_handle = #mhlo.channel_handle} : (tensor<8xf32>) -> tensor<8xf32> - return %out : tensor<8xf32> - } -} diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_complex_to_real.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_complex_to_real.mlir deleted file mode 100644 index 7c32f4706ccb..000000000000 --- a/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_complex_to_real.mlir +++ /dev/null @@ -1,147 +0,0 @@ -// RUN: iree-opt --iree-test-mhlo-convert-complex-to-real %s | FileCheck %s - -// CHECK-LABEL: @add -func.func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { - %2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>) - %3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>) - - // CHECK-DAG: [[VAL0:%.+]] = mhlo.add %arg0, %arg2 - // CHECK-DAG: [[VAL1:%.+]] = mhlo.add %arg1, %arg3 - %4 = "mhlo.add"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>) - %5 = mhlo.real %4 : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>) - %6 = mhlo.imag %4 : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>) - - // CHECK: return [[VAL0]], [[VAL1]] - return %5, %6 : tensor<2xf32>, tensor<2xf32> -} - -// CHECK-LABEL: @sub -func.func @sub(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { - %2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>) - %3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>) - - // CHECK-DAG: [[VAL0:%.+]] = mhlo.subtract %arg0, %arg2 - // CHECK-DAG: [[VAL1:%.+]] = mhlo.subtract %arg1, %arg3 - %4 = "mhlo.subtract"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>) - %5 = mhlo.real %4 : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>) - %6 = mhlo.imag %4 : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>) - - // CHECK: return [[VAL0]], [[VAL1]] - return %5,
%6 : tensor<2xf32>, tensor<2xf32>
-}
-
-// CHECK-LABEL: @mul
-func.func @mul(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
-  %2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
-  %3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
-
-  // CHECK-DAG: %[[VAL0:.+]] = chlo.broadcast_multiply %arg0, %arg2
-  // CHECK-DAG: %[[VAL1:.+]] = chlo.broadcast_multiply %arg1, %arg3
-  // CHECK-DAG: %[[VAL2:.+]] = mhlo.subtract %[[VAL0]], %[[VAL1]]
-  // CHECK-DAG: %[[VAL3:.+]] = chlo.broadcast_multiply %arg0, %arg3
-  // CHECK-DAG: %[[VAL4:.+]] = chlo.broadcast_multiply %arg1, %arg2
-  // CHECK-DAG: %[[VAL5:.+]] = mhlo.add %[[VAL3]], %[[VAL4]]
-  %4 = "mhlo.multiply"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
-  %5 = mhlo.real %4 : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
-  %6 = mhlo.imag %4 : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
-
-  // CHECK: return %[[VAL2]], %[[VAL5]] : tensor<2xf32>, tensor<2xf32>
-  return %5, %6 : tensor<2xf32>, tensor<2xf32>
-}
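A quick Python check of the expansion identities encoded by the @mul CHECK lines above and the @div test below, compared against builtin complex arithmetic (illustrative values only):

a, b = 3.0, -2.0   # lhs.real, lhs.imag
c, d = 0.5, 4.0    # rhs.real, rhs.imag

# multiply: (a + bi)(c + di) = (ac - bd) + (ad + bc)i
assert abs(complex(a, b) * complex(c, d) - complex(a*c - b*d, a*d + b*c)) < 1e-12

# divide: multiply by conj(rhs), then divide by |rhs|^2 = c*c + d*d
den = c*c + d*d
q = complex((a*c + b*d) / den, (b*c - a*d) / den)
assert abs(complex(a, b) / complex(c, d) - q) < 1e-12

-// CHECK-LABEL: @div
-func.func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
-  %2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
-  %3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
-
-  // CHECK-DAG: %[[VAL0:.+]] = mhlo.negate %arg3
-
-  // Compute the numerator's real component:
-  //   numerator.real = lhs.real * rhs.real + lhs.imag * rhs.imag
-  // CHECK-DAG: %[[VAL1:.+]] = chlo.broadcast_multiply %arg0, %arg2
-  // CHECK-DAG: %[[VAL2:.+]] = chlo.broadcast_multiply %arg1, %[[VAL0]]
-  // CHECK-DAG: %[[VAL3:.+]] = mhlo.subtract %[[VAL1]], %[[VAL2]]
-
-  // Compute the real valued denominator as rhs * conj(rhs):
-  //   denominator = rhs.real * rhs.real + rhs.imag * rhs.imag
-  // CHECK-DAG: %[[VAL4:.+]] = mhlo.multiply %arg2, %arg2
-  // CHECK-DAG: %[[VAL5:.+]] = mhlo.multiply %arg3, %arg3
-  // CHECK-DAG: %[[VAL6:.+]] = mhlo.add %[[VAL4]], %[[VAL5]]
-
-  // Compute the numerator's imaginary component:
-  //   numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag
-  // CHECK-DAG: %[[VAL7:.+]] = chlo.broadcast_multiply %arg1, %arg2
-  // CHECK-DAG: %[[VAL8:.+]] = chlo.broadcast_multiply %arg0, %[[VAL0]]
-  // CHECK-DAG: %[[VAL9:.+]] = mhlo.add %[[VAL8]], %[[VAL7]]
-
-  // Divide the numerator by the real valued denominator.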
-  // CHECK-DAG: %[[VAL10:.+]] = chlo.broadcast_divide %[[VAL3]], %[[VAL6]]
-  // CHECK-DAG: %[[VAL11:.+]] = chlo.broadcast_divide %[[VAL9]], %[[VAL6]]
-  %4 = "mhlo.divide"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
-
-  %5 = mhlo.real %4 : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
-  %6 = mhlo.imag %4 : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
-
-  // CHECK: return %[[VAL10]], %[[VAL11]]
-  return %5, %6 : tensor<2xf32>, tensor<2xf32>
-}
-
-// CHECK-LABEL: @abs
-func.func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) {
-  %0 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
-
-  // CHECK-DAG: %[[VAL0:.+]] = mhlo.multiply %arg0, %arg0
-  // CHECK-DAG: %[[VAL1:.+]] = mhlo.multiply %arg1, %arg1
-  // CHECK-DAG: %[[VAL2:.+]] = mhlo.add %[[VAL0]], %[[VAL1]]
-  // CHECK-DAG: %[[VAL3:.+]] = mhlo.sqrt %[[VAL2]]
-  %1 = mhlo.abs %0 : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
-
-  // CHECK: return %[[VAL3]]
-  return %1 : tensor<2xf32>
-}
-
-// CHECK-LABEL: @exp
-func.func @exp(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
-  %0 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
-
-  // CHECK-DAG: %[[EXP:.+]] = mhlo.exponential %arg0
-  // CHECK-DAG: %[[COS:.+]] = mhlo.cosine %arg1
-  // CHECK-DAG: %[[SIN:.+]] = mhlo.sine %arg1
-  // CHECK-DAG: %[[OUTR:.+]] = mhlo.multiply %[[COS]], %[[EXP]]
-  // CHECK-DAG: %[[OUTI:.+]] = mhlo.multiply %[[SIN]], %[[EXP]]
-  %1 = mhlo.exponential %0 : tensor<2xcomplex<f32>>
-
-  %2 = mhlo.real %1 : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
-  %3 = mhlo.imag %1 : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
-
-  // CHECK: return %[[OUTR]], %[[OUTI]]
-  return %2, %3 : tensor<2xf32>, tensor<2xf32>
-}
-
-// CHECK-LABEL: @compare_eq
-func.func @compare_eq(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>,
-                      %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xi1>) {
-  %lhs = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
-  %rhs = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
-  // CHECK-DAG: %[[OUTR:.+]] = chlo.broadcast_compare %arg0, %arg2 {comparison_direction = #chlo<comparison_direction EQ>}
-  // CHECK-DAG: %[[OUTI:.+]] = chlo.broadcast_compare %arg1, %arg3 {comparison_direction = #chlo<comparison_direction EQ>}
-  // CHECK-DAG: %[[OUT:.+]] = mhlo.and %[[OUTR]], %[[OUTI]]
-  %0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> tensor<2xi1>
-
-  // CHECK: return %[[OUT]]
-  return %0 : tensor<2xi1>
-}
-
-// CHECK-LABEL: @compare_ne
-func.func @compare_ne(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>,
-                      %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xi1>) {
-  %lhs = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
-  %rhs = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
-  // CHECK-DAG: %[[OUTR:.+]] = chlo.broadcast_compare %arg0, %arg2 {comparison_direction = #chlo<comparison_direction NE>}
-  // CHECK-DAG: %[[OUTI:.+]] = chlo.broadcast_compare %arg1, %arg3 {comparison_direction = #chlo<comparison_direction NE>}
-  // CHECK-DAG: %[[OUT:.+]] = mhlo.or %[[OUTR]], %[[OUTI]]
-  %0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> tensor<2xi1>
-
-  // CHECK: return %[[OUT]]
-  return %0 : tensor<2xi1>
-}
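The deleted convert_mhlo_to_linalg_ext tests that follow exercise lowering mhlo.sort, whose comparator region plays the role of a cmp-style predicate. A minimal Python sketch of that contract (illustrative, not IREE code):

from functools import cmp_to_key

def compare_gt(lhs: int, rhs: int) -> int:
    # Mirrors a comparator region with comparison_direction GT:
    # answering "lhs greater" places lhs first, i.e. a descending sort.
    return -1 if lhs > rhs else 1

assert sorted([3, 1, 4, 1, 5], key=cmp_to_key(compare_gt)) == [5, 4, 3, 1, 1]

diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir
deleted file mode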
100644 index 7d00f42df2eb..000000000000 --- a/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_linalg_ext.mlir +++ /dev/null @@ -1,564 +0,0 @@ -// RUN: iree-opt --split-input-file --iree-mhlo-to-linalg-ext %s | FileCheck %s -// Also ensure that full lowering to linalg doesn't error. -// RUN: iree-opt --split-input-file --iree-mhlo-to-linalg-ext --iree-mhlo-to-linalg-on-tensors --reconcile-unrealized-casts %s - -func.func @sort_1d(%arg0: tensor<128xi32>) -> (tensor<128xi32>) { - %0 = "mhlo.sort"(%arg0) ( { - ^bb0(%arg2: tensor, %arg3: tensor): - %1 = "mhlo.compare"(%arg2, %arg3) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - "mhlo.return"(%1) : (tensor) -> () - }) {dimension = 0 : i64, is_stable = false} : (tensor<128xi32>) -> (tensor<128xi32>) - return %0 : tensor<128xi32> -} -// CHECK-LABEL: func.func @sort_1d( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK-SAME: ) -// CHECK: %[[SORT:.+]] = iree_linalg_ext.sort -// CHECK-SAME: dimension(0) -// CHECK-SAME: outs(%[[ARG0]] : tensor<128xi32>) -// CHECK: ^bb0(%[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32) -// CHECK: %[[CMP:.+]] = arith.cmpi sgt, %[[ARG1]], %[[ARG2]] -// CHECK: iree_linalg_ext.yield %[[CMP]] -// CHECK: return %[[SORT]] - -// ----- - -func.func @sort_1d_ui(%arg0: tensor<128xui32>) -> (tensor<128xui32>) { - %0 = "mhlo.sort"(%arg0) ( { - ^bb0(%arg2: tensor, %arg3: tensor): // no predecessors - %1 = "mhlo.compare"(%arg2, %arg3) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - "mhlo.return"(%1) : (tensor) -> () - }) {dimension = 0 : i64, is_stable = false} : (tensor<128xui32>) -> (tensor<128xui32>) - return %0 : tensor<128xui32> -} -// CHECK-LABEL: func.func @sort_1d_ui( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK-SAME: ) -// CHECK: %[[CAST:.+]] = tensor.bitcast %[[ARG0]] : tensor<128xui32> to tensor<128xi32> -// CHECK: %[[SORT:.+]] = iree_linalg_ext.sort -// CHECK-SAME: dimension(0) -// CHECK-SAME: outs(%[[CAST]] : tensor<128xi32>) -// CHECK: ^bb0(%[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32) -// CHECK: %[[CMP:.+]] = arith.cmpi ugt, %[[ARG1]], %[[ARG2]] -// CHECK: iree_linalg_ext.yield %[[CMP]] -// CHECK: %[[RESULT:.+]] = tensor.bitcast %[[SORT]] : tensor<128xi32> to tensor<128xui32> -// CHECK: return %[[RESULT]] - -// ----- - -func.func @sort_cst_capture(%arg0: tensor<1x10xi32>) -> tensor<1x10xi32> { - %0 = mhlo.constant dense<0> : tensor - %1 = "mhlo.sort"(%arg0) ( { - ^bb0(%arg1: tensor, %arg3: tensor): - %2 = "mhlo.compare"(%arg1, %0) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - "mhlo.return"(%2) : (tensor) -> () - }) {dimension = 1 : i64, is_stable = true} : (tensor<1x10xi32>) -> tensor<1x10xi32> - return %1 : tensor<1x10xi32> -} - -// CHECK-LABEL: func.func @sort_cst_capture( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK-SAME: ) -// CHECK: %[[SCALAR:.+]] = arith.constant 0 : i32 -// CHECK: %[[SORT:.+]] = iree_linalg_ext.sort dimension(1) outs(%[[ARG0]] : tensor<1x10xi32>) { -// CHECK: ^bb0(%[[ARG1:.+]]: i32, %{{.*}}: i32) -// CHECK: %[[RES:.+]] = arith.cmpi slt, %[[ARG1]], %[[SCALAR]] : i32 -// CHECK: iree_linalg_ext.yield %[[RES]] : i1 -// CHECK: } -> tensor<1x10xi32> -// CHECK: return %[[SORT]] - -// ----- - -func.func @sort_argument_capture(%arg0: tensor<1x10xi32>, %arg1 : tensor) -> tensor<1x10xi32> { - %1 = "mhlo.sort"(%arg0) ( { - ^bb0(%arg2: tensor, %arg3: tensor): - %2 = "mhlo.compare"(%arg2, %arg1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - "mhlo.return"(%2) : (tensor) -> () - }) {dimension = 1 : i64, is_stable = true} : (tensor<1x10xi32>) 
-> tensor<1x10xi32> - return %1 : tensor<1x10xi32> -} - -// CHECK-LABEL: func.func @sort_argument_capture( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK-SAME: ) -// CHECK: %[[SCALAR:.+]] = tensor.extract %[[ARG1]][] : tensor -// CHECK: %[[SORT:.+]] = iree_linalg_ext.sort dimension(1) outs(%[[ARG0]] : tensor<1x10xi32>) { -// CHECK: ^bb0(%[[ARG2:.+]]: i32, %{{.*}}: i32) -// CHECK: %[[RES:.+]] = arith.cmpi slt, %[[ARG2]], %[[SCALAR]] : i32 -// CHECK: iree_linalg_ext.yield %[[RES]] : i1 -// CHECK: } -> tensor<1x10xi32> -// CHECK: return %[[SORT]] - -// ----- - -func.func @sort_2d(%arg0: tensor<16x32xi32>) -> (tensor<16x32xi32>) { - %0 = "mhlo.sort"(%arg0) ( { - ^bb0(%arg2: tensor, %arg3: tensor): - %1 = "mhlo.compare"(%arg2, %arg3) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - "mhlo.return"(%1) : (tensor) -> () - }) {dimension = 0 : i64, is_stable = false} : (tensor<16x32xi32>) -> (tensor<16x32xi32>) - return %0 : tensor<16x32xi32> -} -// CHECK-LABEL: func.func @sort_2d( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK-SAME: ) -// CHECK: %[[SORT:.+]] = iree_linalg_ext.sort -// CHECK-SAME: dimension(0) -// CHECK-SAME: outs(%[[ARG0]] : tensor<16x32xi32>) -// CHECK: ^bb0(%[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32) -// CHECK: %[[CMP:.+]] = arith.cmpi sgt, %[[ARG1]], %[[ARG2]] -// CHECK: iree_linalg_ext.yield %[[CMP]] -// CHECK: return %[[SORT]] - -// ----- - -func.func @sort_unsigned(%arg0: tensor<1x5xf32>) -> tensor<1x5xf32> { - %1 = "mhlo.sort"(%arg0) ( { - ^bb0(%arg1: tensor, %arg2: tensor): - %2 = "mhlo.bitcast_convert"(%arg1) : (tensor) -> tensor - %3 = "mhlo.bitcast_convert"(%arg2) : (tensor) -> tensor - %4 = "mhlo.compare"(%2, %3) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - "mhlo.return"(%4) : (tensor) -> () - }) {dimension = 1 : i64, is_stable = true} : (tensor<1x5xf32>) -> tensor<1x5xf32> - return %1 : tensor<1x5xf32> -} - -// CHECK-LABEL: func.func @sort_unsigned( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK-SAME: ) -// CHECK: %[[SORT:.+]] = iree_linalg_ext.sort -// CHECK-SAME: dimension(1) -// CHECK-SAME: outs(%[[ARG0]] : tensor<1x5xf32>) -// CHECK: ^bb0(%[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32) -// CHECK: %[[CAST1:.+]] = arith.bitcast %[[ARG1]] : f32 to i32 -// CHECK: %[[CAST2:.+]] = arith.bitcast %[[ARG2]] : f32 to i32 -// CHECK: %[[CMP:.+]] = arith.cmpi ult, %[[CAST1]], %[[CAST2]] : i32 -// CHECK: iree_linalg_ext.yield %[[CMP]] -// CHECK: return %[[SORT]] - -// ----- - -func.func @sort_unsigned_cst_capture(%arg0: tensor<1x5xf32>) -> tensor<1x5xf32> { - %ui32 = mhlo.constant dense<2> : tensor - %1 = "mhlo.sort"(%arg0) ( { - ^bb0(%arg1: tensor, %arg2: tensor): - %2 = "mhlo.bitcast_convert"(%arg1) : (tensor) -> tensor - %3 = "mhlo.compare"(%2, %ui32) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - "mhlo.return"(%3) : (tensor) -> () - }) {dimension = 1 : i64, is_stable = true} : (tensor<1x5xf32>) -> tensor<1x5xf32> - return %1 : tensor<1x5xf32> -} - -// CHECK-LABEL: func.func @sort_unsigned_cst_capture( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK-SAME: ) -// CHECK: %[[UI32:.+]] = mhlo.constant dense<2> : tensor -// CHECK: %[[CONVERSION_CAST_CST:.+]] = tensor.bitcast %[[UI32]] : tensor to tensor -// CHECK: %[[EXTRACT_CST:.+]] = tensor.extract %[[CONVERSION_CAST_CST]][] : tensor -// CHECK: %[[SORT:.+]] = iree_linalg_ext.sort -// CHECK-SAME: dimension(1) -// CHECK-SAME: outs(%[[ARG0]] : tensor<1x5xf32>) -// CHECK: ^bb0(%[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32) -// CHECK: %[[CAST1:.+]] = arith.bitcast 
%[[ARG1]] : f32 to i32 -// CHECK: %[[CMP:.+]] = arith.cmpi ult, %[[CAST1]], %[[EXTRACT_CST]] : i32 -// CHECK: iree_linalg_ext.yield %[[CMP]] -// CHECK: return %[[SORT]] - -// ----- - -// For testing that complex within an iree_linalg_ext.op gets lowered -func.func @sort_complex(%arg0: tensor<1x5xf32>, %arg1 : tensor>) -> tensor<1x5xf32> { - %ui32 = mhlo.constant dense<2> : tensor - %1 = "mhlo.sort"(%arg0) ( { - ^bb0(%arg2: tensor, %arg3: tensor): - %2 = "mhlo.complex"(%arg2, %arg3) : (tensor, tensor) -> tensor> - %3 = mhlo.add %2, %arg1 : tensor> - %4 = "mhlo.real"(%3) : (tensor>) -> tensor - %5 = "mhlo.imag"(%3) : (tensor>) -> tensor - %6 = "mhlo.compare"(%4, %5) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - "mhlo.return"(%6) : (tensor) -> () - }) {dimension = 1 : i64, is_stable = true} : (tensor<1x5xf32>) -> tensor<1x5xf32> - return %1 : tensor<1x5xf32> -} - -// CHECK-LABEL: func.func @sort_complex( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK-SAME: ) -// CHECK: %[[SORT:.+]] = iree_linalg_ext.sort -// CHECK-SAME: dimension(1) -// CHECK-SAME: outs(%[[ARG0]] : tensor<1x5xf32>) -// CHECK: ^bb0(%[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32) -// CHECK-NOT: mhlo.complex -// CHECK: %[[CMP:.+]] = arith.cmpf olt, %{{.+}}, %{{.+}} : f32 -// CHECK: iree_linalg_ext.yield %[[CMP]] -// CHECK: return %[[SORT]] - -// ----- - -func.func @topk(%arg0: tensor<128xi32>, %arg1: tensor<128xi32>) -> (tensor<128xi32>) { - %0:2 = "mhlo.sort"(%arg0, %arg1) ( { - ^bb0(%arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor): - %1 = "mhlo.compare"(%arg2, %arg3) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - "mhlo.return"(%1) : (tensor) -> () - }) {dimension = 0 : i64, is_stable = false} : (tensor<128xi32>, tensor<128xi32>) -> (tensor<128xi32>, tensor<128xi32>) - return %0#0 : tensor<128xi32> -} -// CHECK-LABEL: func.func @topk -// CHECK: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK: %[[SORT:.+]]:2 = iree_linalg_ext.sort -// CHECK-SAME: dimension(0) -// CHECK-SAME: outs(%[[ARG0]], %[[ARG1]] : tensor<128xi32>, tensor<128xi32>) -// CHECK: ^bb0(%[[ARG2:.+]]: i32, %[[ARG3:.+]]: i32, %{{.*}}: i32, %{{.*}}: i32) -// CHECK: %[[CMP:.+]] = arith.cmpi sgt, %[[ARG2]], %[[ARG3]] -// CHECK: iree_linalg_ext.yield %[[CMP]] -// CHECK: return %[[SORT]]#0 - -// ----- - -func.func @scatter_update_scalar_1D(%arg0: tensor<8xi32>, %arg1: tensor<4x1xi32>, - %arg2: tensor<4xi32>) -> tensor<8xi32> { - %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( { - ^bb0(%arg3: tensor, %arg4: tensor): - "mhlo.return"(%arg4) : (tensor) -> () - }) { - indices_are_sorted = false, - scatter_dimension_numbers = #mhlo.scatter< - inserted_window_dims = [0], - scatter_dims_to_operand_dims = [0], - index_vector_dim = 1, - >, - unique_indices = true - } : (tensor<8xi32>, tensor<4x1xi32>, tensor<4xi32>) -> tensor<8xi32> - return %0 : tensor<8xi32> -} -// CHECK-LABEL: func.func @scatter_update_scalar_1D -// CHECK: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK: %[[ARG2:[a-zA-Z0-9]+]] -// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter -// CHECK-SAME: unique_indices(true) -// CHECK-SAME: ins(%[[ARG2]], %[[ARG1]] : tensor<4xi32>, tensor<4x1xi32>) -// CHECK-SAME: outs(%[[ARG0]] : tensor<8xi32>) -// CHECK: ^bb0(%[[V1:.+]]: i32, %[[V2:.+]]: i32): -// CHECK: iree_linalg_ext.yield %[[V1]] -// CHECK: return %[[SCATTER]] - -// ----- - -func.func @scatter_update_scalar_2D(%arg0: tensor<4x3xi32>, %arg1: tensor<3x2xi32>, - %arg2: tensor<3xi32>) -> tensor<4x3xi32> { - %0 = 
"mhlo.scatter"(%arg0, %arg1, %arg2) ( { - ^bb0(%arg3: tensor, %arg4: tensor): - "mhlo.return"(%arg4) : (tensor) -> () - }) {indices_are_sorted = false, - scatter_dimension_numbers = #mhlo.scatter< - inserted_window_dims = [0, 1], - scatter_dims_to_operand_dims = [0, 1], - index_vector_dim = 1, - >, - unique_indices = true - } : (tensor<4x3xi32>, tensor<3x2xi32>, tensor<3xi32>) -> tensor<4x3xi32> - return %0 : tensor<4x3xi32> -} -// CHECK-LABEL: func.func @scatter_update_scalar_2D -// CHECK: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK: %[[ARG2:[a-zA-Z0-9]+]] -// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter -// CHECK-SAME: unique_indices(true) -// CHECK-SAME: ins(%[[ARG2]], %[[ARG1]] : tensor<3xi32>, tensor<3x2xi32>) -// CHECK-SAME: outs(%[[ARG0]] : tensor<4x3xi32>) -// CHECK: ^bb0(%[[V1:.+]]: i32, %[[V2:.+]]: i32): -// CHECK: iree_linalg_ext.yield %[[V1]] -// CHECK: return %[[SCATTER]] - -// ----- - -func.func @scatter_update_slice_2D(%arg0: tensor<6x3xi32>, %arg1: tensor<2x1xi32>, - %arg2: tensor<2x3xi32>) -> tensor<6x3xi32> { - %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( { - ^bb0(%arg3: tensor, %arg4: tensor): - "mhlo.return"(%arg4) : (tensor) -> () - }) { - indices_are_sorted = false, - scatter_dimension_numbers = #mhlo.scatter< - update_window_dims = [1], - inserted_window_dims = [0], - scatter_dims_to_operand_dims = [0], - index_vector_dim = 1, - >, - unique_indices = true - } : (tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x3xi32>) -> tensor<6x3xi32> - return %0 : tensor<6x3xi32> -} -// CHECK-LABEL: func.func @scatter_update_slice_2D -// CHECK: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK: %[[ARG2:[a-zA-Z0-9]+]] -// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter -// CHECK-SAME: unique_indices(true) -// CHECK-SAME: ins(%[[ARG2]], %[[ARG1]] : tensor<2x3xi32>, tensor<2x1xi32>) -// CHECK-SAME: outs(%[[ARG0]] : tensor<6x3xi32>) -// CHECK: ^bb0(%[[V1:.+]]: i32, %[[V2:.+]]: i32): -// CHECK: iree_linalg_ext.yield %[[V1]] -// CHECK: return %[[SCATTER]] - -// ----- - -func.func @scatter_add_slice_2D(%arg0: tensor<6x3xi32>, %arg1: tensor<2x1xi32>, - %arg2: tensor<2x3xi32>) -> tensor<6x3xi32> { - %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( { - ^bb0(%arg3: tensor, %arg4: tensor): - %1 = mhlo.add %arg3, %arg4 : tensor - "mhlo.return"(%1) : (tensor) -> () - }) { - indices_are_sorted = false, - scatter_dimension_numbers = #mhlo.scatter< - update_window_dims = [1], - inserted_window_dims = [0], - scatter_dims_to_operand_dims = [0], - index_vector_dim = 1, - >, - unique_indices = false - } : (tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x3xi32>) -> tensor<6x3xi32> - return %0 : tensor<6x3xi32> -} -// CHECK-LABEL: func.func @scatter_add_slice_2D -// CHECK: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK: %[[ARG2:[a-zA-Z0-9]+]] -// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter -// CHECK-SAME: unique_indices(false) -// CHECK-SAME: ins(%[[ARG2]], %[[ARG1]] : tensor<2x3xi32>, tensor<2x1xi32>) -// CHECK-SAME: outs(%[[ARG0]] : tensor<6x3xi32>) -// CHECK: ^bb0(%[[V1:.+]]: i32, %[[V2:.+]]: i32): -// -// The order is reverse. 
-// CHECK:         %[[V3:.+]] = arith.addi %[[V2]], %[[V1]]
-// CHECK:         iree_linalg_ext.yield %[[V3]]
-// CHECK:       return %[[SCATTER]]
-
-// -----
-
-func.func @scatter_partial(%arg0: tensor<10x5xf32>, %arg1: tensor<3x1xi32>, %arg2: tensor<3x3xf32>) -> tensor<10x5xf32> {
-  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
-  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):  // no predecessors
-    %1 = mhlo.add %arg3, %arg4 : tensor<f32>
-    "mhlo.return"(%1) : (tensor<f32>) -> ()
-  }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false} : (tensor<10x5xf32>, tensor<3x1xi32>, tensor<3x3xf32>) -> tensor<10x5xf32>
-  return %0 : tensor<10x5xf32>
-}
-
-// CHECK-LABEL: func.func @scatter_partial
-// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]
-// CHECK:         %[[SCATTER:.+]] = iree_linalg_ext.scatter
-// CHECK-SAME:      unique_indices(false)
-// CHECK-SAME:      ins(%[[ARG2]], %[[ARG1]] : tensor<3x3xf32>, tensor<3x1xi32>)
-// CHECK-SAME:      outs(%[[ARG0]] : tensor<10x5xf32>)
-// CHECK:         return %[[SCATTER]]
-
-// -----
-
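The rfft tests that follow check the decomposition iree_linalg_ext.fft relies on: permute the input into bit-reversed order ([0, 4, 2, 6, 1, 5, 3, 7] for n = 8), then run log2(n) butterfly stages, and for an rfft keep only the first n/2 + 1 bins. A small numpy sketch of the same decomposition (illustrative only, checked against np.fft):

import numpy as np

def fft8(x):
    # Bit-reversed reorder, matching the CHECK'd index table.
    x = np.asarray(x, dtype=complex)[[0, 4, 2, 6, 1, 5, 3, 7]]
    for stage in (1, 2, 3):  # the three iree_linalg_ext.fft ops below
        half = 1 << (stage - 1)
        w = np.exp(-2j * np.pi * np.arange(half) / (2 * half))  # twiddles
        x = x.reshape(-1, 2 * half)
        a, b = x[:, :half], x[:, half:] * w
        x = np.concatenate([a + b, a - b], axis=1).reshape(-1)
    return x

x = np.arange(8.0)
assert np.allclose(fft8(x), np.fft.fft(x))
assert np.allclose(fft8(x)[:5], np.fft.rfft(x))  # extract_slice [5]

-func.func @rfft_1d(%input: tensor<8xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
-  %0 = "mhlo.fft"(%input) {
-    fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo<fft_type RFFT>
-  } : (tensor<8xf32>) -> tensor<5xcomplex<f32>>
-  %1 = "mhlo.real"(%0) : (tensor<5xcomplex<f32>>) -> tensor<5xf32>
-  %2 = "mhlo.imag"(%0) : (tensor<5xcomplex<f32>>) -> tensor<5xf32>
-  return %1, %2 : tensor<5xf32>, tensor<5xf32>
-}
-// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0) -> (d0)>
-// CHECK:     func.func @rfft_1d
-// CHECK-SAME:  %[[REAL:[a-zA-Z0-9]+]]
-// CHECK-DAG:   %[[INDICES:.+]] = arith.constant dense<[0, 4, 2, 6, 1, 5, 3, 7]> : tensor<8xi32>
-// CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
-// CHECK:       %[[REORDERED:.+]] = linalg.generic
-// CHECK-SAME:    {indexing_maps = [#[[MAP]], #[[MAP]]]
-// CHECK-SAME:    iterator_types = ["parallel"]
-// CHECK-SAME:    ins(%[[INDICES]]
-// CHECK-SAME:    outs(%[[EMPTY]]
-// CHECK:       ^bb0(%[[IDX:.+]]: i32, %{{.+}}: f32):
-// CHECK:         %[[IDXVAL:.+]] = arith.index_cast %[[IDX]] : i32 to index
-// CHECK:         %[[LOAD:.+]] = tensor.extract %[[REAL]][%[[IDXVAL]]] : tensor<8xf32>
-// CHECK:         linalg.yield %[[LOAD]] : f32
-// CHECK-DAG:   %[[IMAG:.+]] = arith.constant dense<0.000000e+00> : tensor<8xf32>
-// CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG:   %[[COEF_REAL:.+]] = arith.constant dense<{{.+}}> : tensor<1xf32>
-// CHECK-DAG:   %[[COEF_IMAG:.+]] = arith.constant dense<{{.+}}> : tensor<1xf32>
-// CHECK:       %[[R1:.+]]:2 = iree_linalg_ext.fft
-// CHECK-SAME:    ins(%[[C1]], %[[COEF_REAL]], %[[COEF_IMAG]]
-// CHECK-SAME:    outs(%[[REORDERED]], %[[IMAG]]
-// CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
-// CHECK-DAG:   %[[COEF_REAL:.+]] = arith.constant dense<{{.+}}> : tensor<2xf32>
-// CHECK-DAG:   %[[COEF_IMAG:.+]] = arith.constant dense<{{.+}}> : tensor<2xf32>
-// CHECK:       %[[R2:.+]]:2 = iree_linalg_ext.fft
-// CHECK-SAME:    ins(%[[C2]], %[[COEF_REAL]], %[[COEF_IMAG]]
-// CHECK-SAME:    outs(%[[R1]]#0, %[[R1]]#1
-// CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
-// CHECK-DAG:   %[[COEF_REAL:.+]] = arith.constant dense<{{.+}}> : tensor<4xf32>
-// CHECK-DAG:   %[[COEF_IMAG:.+]] = arith.constant dense<{{.+}}> : tensor<4xf32>
-// CHECK:       %[[R3:.+]]:2 = iree_linalg_ext.fft
-// CHECK-SAME:    ins(%[[C3]], %[[COEF_REAL]], %[[COEF_IMAG]]
-// CHECK-SAME:    outs(%[[R2]]#0, %[[R2]]#1
-// CHECK:       %[[RES_REAL:.+]] = tensor.extract_slice %[[R3]]#0[0] [5] [1] : tensor<8xf32> to tensor<5xf32>
-// CHECK:       %[[RES_IMAG:.+]] = tensor.extract_slice %[[R3]]#1[0] [5] [1] : tensor<8xf32> to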
tensor<5xf32> -// CHECK: %{{.+}} = mhlo.complex %[[RES_REAL]], %[[RES_IMAG]] - -// ----- - -func.func @rfft_2d(%input: tensor<4x8xf32>) -> (tensor<4x5xf32>, tensor<4x5xf32>) { - %0 = "mhlo.fft"(%input) { - fft_length = dense<8> : tensor<1xi64>, fft_type = #mhlo - } : (tensor<4x8xf32>) -> tensor<4x5xcomplex> - %1 = "mhlo.real"(%0) : (tensor<4x5xcomplex>) -> tensor<4x5xf32> - %2 = "mhlo.imag"(%0) : (tensor<4x5xcomplex>) -> tensor<4x5xf32> - return %1, %2 : tensor<4x5xf32>, tensor<4x5xf32> -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d1)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: func.func @rfft_2d -// CHECK-SAME: %[[REAL:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[INDICES:.+]] = arith.constant dense<[0, 4, 2, 6, 1, 5, 3, 7]> : tensor<8xi32> -// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<4x8xf32> -// CHECK: %[[REORDERED:.+]] = linalg.generic -// CHECK-SAME: {indexing_maps = [#[[MAP0]], #[[MAP1]]] -// CHECK-SAME: iterator_types = ["parallel", "parallel"] -// CHECK-SAME: ins(%[[INDICES]] -// CHECK-SAME: outs(%[[EMPTY]] -// CHECK: ^bb0(%[[IDX:.+]]: i32, %{{.+}}: f32): -// CHECK: %[[I:.+]] = linalg.index 0 -// CHECK: %[[IDXVAL:.+]] = arith.index_cast %[[IDX]] : i32 to index -// CHECK: %[[LOAD:.+]] = tensor.extract %[[REAL]][%[[I]], %[[IDXVAL]]] : tensor<4x8xf32> -// CHECK: linalg.yield %[[LOAD]] : f32 -// CHECK-DAG: %[[IMAG:.+]] = arith.constant dense<0.000000e+00> : tensor<4x8xf32> -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[COEF_REAL:.+]] = arith.constant dense<{{.+}}> : tensor<1xf32> -// CHECK-DAG: %[[COEF_IMAG:.+]] = arith.constant dense<{{.+}}> : tensor<1xf32> -// CHECK: %[[R1:.+]]:2 = iree_linalg_ext.fft -// CHECK-SAME: ins(%[[C1]], %[[COEF_REAL]], %[[COEF_IMAG]] -// CHECK-SAME: outs(%[[REORDERED]], %[[IMAG]] -// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index -// CHECK-DAG: %[[COEF_REAL:.+]] = arith.constant dense<{{.+}}> : tensor<2xf32> -// CHECK-DAG: %[[COEF_IMAG:.+]] = arith.constant dense<{{.+}}> : tensor<2xf32> -// CHECK: %[[R2:.+]]:2 = iree_linalg_ext.fft -// CHECK-SAME: ins(%[[C2]], %[[COEF_REAL]], %[[COEF_IMAG]] -// CHECK-SAME: outs(%[[R1]]#0, %[[R1]]#1 -// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index -// CHECK-DAG: %[[COEF_REAL:.+]] = arith.constant dense<{{.+}}> : tensor<4xf32> -// CHECK-DAG: %[[COEF_IMAG:.+]] = arith.constant dense<{{.+}}> : tensor<4xf32> -// CHECK: %[[R3:.+]]:2 = iree_linalg_ext.fft -// CHECK-SAME: ins(%[[C3]], %[[COEF_REAL]], %[[COEF_IMAG]] -// CHECK-SAME: outs(%[[R2]]#0, %[[R2]]#1 -// CHECK: %[[RES_REAL:.+]] = tensor.extract_slice %[[R3]]#0[0, 0] [4, 5] [1, 1] : tensor<4x8xf32> to tensor<4x5xf32> -// CHECK: %[[RES_IMAG:.+]] = tensor.extract_slice %[[R3]]#1[0, 0] [4, 5] [1, 1] : tensor<4x8xf32> to tensor<4x5xf32> -// CHECK: %{{.+}} = mhlo.complex %[[RES_REAL]], %[[RES_IMAG]] - -// ----- - -func.func @reverse_dim1(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> { - %0 = "mhlo.reverse"(%arg0) { - dimensions = dense<1> : tensor<1xi64> - } : (tensor<3x5xi32>) -> tensor<3x5xi32> - return %0 : tensor<3x5xi32> -} -// CHECK-LABEL: func.func @reverse_dim1 -// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]] -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<3x5xi32> -// CHECK: %[[REV:.+]] = iree_linalg_ext.reverse -// CHECK-SAME: dimensions(dense<1> : tensor<1xi64>) -// CHECK-SAME: ins(%[[IN]] : tensor<3x5xi32>) -// CHECK-SAME: outs(%[[INIT]] : tensor<3x5xi32>) : tensor<3x5xi32> -// CHECK: return %[[REV]] - -// ----- - -func.func @reverse_unsigned(%arg0: tensor<3x5xui32>) -> tensor<3x5xui32> { - %0 = 
"mhlo.reverse"(%arg0) { - dimensions = dense<1> : tensor<1xi64> - } : (tensor<3x5xui32>) -> tensor<3x5xui32> - return %0 : tensor<3x5xui32> -} -// CHECK-LABEL: func.func @reverse_unsigned -// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]] -// CHECK: %[[BITCAST:.+]] = tensor.bitcast %[[IN]] : tensor<3x5xui32> to tensor<3x5xi32> -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<3x5xi32> -// CHECK: %[[REV:.+]] = iree_linalg_ext.reverse -// CHECK-SAME: dimensions(dense<1> : tensor<1xi64>) -// CHECK-SAME: ins(%[[BITCAST]] : tensor<3x5xi32>) -// CHECK-SAME: outs(%[[INIT]] : tensor<3x5xi32>) : tensor<3x5xi32> -// CHECK: %[[BITCAST:.+]] = tensor.bitcast %[[REV]] : tensor<3x5xi32> to tensor<3x5xui32> -// CHECK: return %[[BITCAST]] - -// ----- - -func.func @reverse_multi_dim(%arg0: tensor) -> tensor { - %0 = "mhlo.reverse"(%arg0) { - dimensions = dense<[0, 1]> : tensor<2xi64> - } : (tensor) -> tensor - return %0 : tensor -} -// CHECK-LABEL: func.func @reverse_multi_dim -// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[IN]], %[[C0]] -// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[IN]], %[[C1]] -// CHECK: %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]]) : tensor -// CHECK: %[[REV:.+]] = iree_linalg_ext.reverse -// CHECK-SAME: dimensions(dense<[0, 1]> : tensor<2xi64>) -// CHECK-SAME: ins(%[[IN]] : tensor) -// CHECK-SAME: outs(%[[INIT]] : tensor) : tensor -// CHECK: return %[[REV]] - -// ----- - -func.func @chlo_top_k_int(%arg : tensor<16x16xi32>) -> (tensor<16x8xi32>, tensor<16x8xi32>) { - %1:2 = chlo.top_k(%arg, k=8) : tensor<16x16xi32> -> (tensor<16x8xi32>, tensor<16x8xi32>) - return %1#0, %1#1 : tensor<16x8xi32>, tensor<16x8xi32> -} - -// CHECK: func.func @chlo_top_k_int -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK: %[[D2:.+]] = tensor.empty() : tensor<16x8xi32> -// CHECK: %[[D3:.+]] = tensor.empty() : tensor<16x8xi32> -// CHECK-DAG: %[[CNEG:.+]] = arith.constant -2147483648 : i32 -// CHECK-DAG: %[[CPOS:.+]] = arith.constant 2147483647 : i32 -// CHECK-DAG: %[[D4:.+]] = linalg.fill ins(%[[CNEG]] : i32) outs(%[[D2]] -// CHECK-DAG: %[[D5:.+]] = linalg.fill ins(%[[CPOS]] : i32) outs(%[[D3]] -// CHECK: %[[D6:.+]]:2 = iree_linalg_ext.topk -// CHECK-SAME: dimension(1) -// CHECK-SAME: ins(%[[ARG0]] -// CHECK-SAME: outs(%[[D4]], %[[D5]] -// CHECK: ^bb0(%[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32) -// CHECK: %[[D7:.+]] = arith.cmpi sge, %[[ARG1]], %[[ARG2]] : i32 -// CHECK: iree_linalg_ext.yield %[[D7]] : i1 -// CHECK: return %[[D6]]#0, %[[D6]]#1 - -// ----- - -func.func @chlo_top_k_float(%arg : tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) { - %1:2 = chlo.top_k(%arg, k=8) : tensor<16x16xf32> -> (tensor<16x8xf32>, tensor<16x8xi32>) - return %1#0, %1#1 : tensor<16x8xf32>, tensor<16x8xi32> -} - -// CHECK: func.func @chlo_top_k_float -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK: %[[D2:.+]] = tensor.empty() : tensor<16x8xf32> -// CHECK: %[[D3:.+]] = tensor.empty() : tensor<16x8xi32> -// CHECK-DAG: %[[CNEG:.+]] = arith.constant 0xFF800000 : f32 -// CHECK-DAG: %[[CPOS:.+]] = arith.constant 2147483647 : i32 -// CHECK-DAG: %[[D4:.+]] = linalg.fill ins(%[[CNEG]] : f32) outs(%[[D2]] -// CHECK-DAG: %[[D5:.+]] = linalg.fill ins(%[[CPOS]] : i32) outs(%[[D3]] -// CHECK: %[[D6:.+]]:2 = iree_linalg_ext.topk -// CHECK-SAME: dimension(1) -// CHECK-SAME: ins(%[[ARG0]] -// CHECK-SAME: outs(%[[D4]], %[[D5]] -// CHECK: ^bb0(%[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32) -// CHECK: %[[D7:.+]] = arith.cmpf ogt, 
%[[ARG1]], %[[ARG2]] : f32 -// CHECK: iree_linalg_ext.yield %[[D7]] : i1 -// CHECK: return %[[D6]]#0, %[[D6]]#1 diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_stablehlo.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_stablehlo.mlir deleted file mode 100644 index a40972cf7518..000000000000 --- a/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_mhlo_to_stablehlo.mlir +++ /dev/null @@ -1,13 +0,0 @@ -// RUN: iree-opt --iree-convert-mhlo-to-stablehlo %s | FileCheck %s - -// CHECK-LABEL: func.func @add -// CHECK-NEXT: stablehlo.add -// CHECK-NEXT: chlo.broadcast_add -// CHECK-NEXT: stablehlo.add -// CHECK-NEXT: return -func.func @add(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - %0 = mhlo.add %arg0, %arg1 : tensor<4xf32> - %1 = chlo.broadcast_add %0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - %2 = stablehlo.add %1, %arg1 : tensor<4xf32> - return %2 : tensor<4xf32> -} diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_structural_types.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_structural_types.mlir deleted file mode 100644 index 509c079ec265..000000000000 --- a/compiler/src/iree/compiler/InputConversion/MHLO/test/convert_structural_types.mlir +++ /dev/null @@ -1,30 +0,0 @@ -// RUN: iree-opt --split-input-file --iree-mhlo-to-linalg-on-tensors %s | FileCheck %s - -// CHECK-LABEL: @func_cfg_conversion -module @func_cfg_conversion { - // CHECK: func.func @caller(%arg0: tensor<2xi32>, %arg1: i1) -> tensor<2xi32> - func.func @caller(%arg0: tensor<2xi32>, %arg1 : i1) -> tensor<2xi32> { - // CHECK: %[[RESULT:.*]] = call @callee(%arg0, %arg1) : (tensor<2xi32>, i1) -> tensor<2xi32> - %1 = call @callee(%arg0, %arg1) : (tensor<2xi32>, i1) -> tensor<2xi32> - // CHECK: return %[[RESULT]] : tensor<2xi32> - return %1 : tensor<2xi32> - } - - // CHECK: func.func @callee(%arg0: tensor<2xi32>, %arg1: i1) -> tensor<2xi32> - func.func @callee(%arg0: tensor<2xi32>, %arg1: i1) -> tensor<2xi32> { - // CHECK: cf.cond_br %arg1, ^bb1(%arg0 : tensor<2xi32>), ^bb2(%arg0 : tensor<2xi32>) - cf.cond_br %arg1, ^bb1(%arg0 : tensor<2xi32>), ^bb2(%arg0 : tensor<2xi32>) - // CHECK: ^bb1(%[[BB1_PHI:.*]]: tensor<2xi32>) - ^bb1(%phi0 : tensor<2xi32>) : - // CHECK: %[[BB1_PHI_ADD:.*]] = linalg.generic - // CHECK: cf.br ^bb2(%[[BB1_PHI_ADD]] : tensor<2xi32>) - %0 = "mhlo.add"(%phi0, %phi0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - cf.br ^bb2(%0 : tensor<2xi32>) - // CHECK: ^bb2(%[[BB2_PHI:.*]]: tensor<2xi32>) - ^bb2(%phi1 : tensor<2xi32>): - // CHECK: %[[BB2_PHI_ADD:.*]] = linalg.generic - // CHECK: return %[[BB2_PHI_ADD]] : tensor<2xi32> - %1 = "mhlo.add"(%phi1, %phi1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - return %1 : tensor<2xi32> - } -} diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/dynamic_shape.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/dynamic_shape.mlir deleted file mode 100644 index dcc27db9bcb5..000000000000 --- a/compiler/src/iree/compiler/InputConversion/MHLO/test/dynamic_shape.mlir +++ /dev/null @@ -1,26 +0,0 @@ -// RUN: iree-opt --split-input-file --iree-mhlo-to-linalg-on-tensors %s | FileCheck %s - -func.func @dynamic_shape(%operand: tensor) -> (tensor) -attributes {iree.dispatch_fn_name = ""} { - %result = "mhlo.exponential"(%operand) : (tensor) -> tensor - return %result : tensor -} - -// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: func.func @dynamic_shape -// CHECK-SAME: %[[ARG0:.+]]: 
tensor -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK: %[[SHAPE:.+]] = shape.shape_of %[[ARG0]] -// CHECK: %[[T0:.+]] = tensor.extract %[[SHAPE]][%[[C0]]] -// CHECK: %[[T1:.+]] = tensor.extract %[[SHAPE]][%[[C1]]] -// CHECK: %[[T2:.+]] = tensor.empty(%[[T0]], %[[T1]]) -// CHECK: %[[T3:.+]] = linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]]] -// CHECK-SAME: iterator_types = ["parallel", "parallel"]} -// CHECK-SAME: ins(%[[ARG0]] : tensor) -// CHECK-SAME: outs(%[[T2]] : tensor) -// CHECK-NEXT: ^{{.+}}(%[[OPERAND_IN:[a-zA-Z0-9_]+]]: f32, %{{.+}}: f32): -// CHECK-NEXT: %[[RESULT:.+]] = math.exp %[[OPERAND_IN]] : f32 -// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 -// CHECK: return %[[T3]] diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/fft.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/fft.mlir deleted file mode 100644 index 354ce7e147db..000000000000 --- a/compiler/src/iree/compiler/InputConversion/MHLO/test/fft.mlir +++ /dev/null @@ -1,71 +0,0 @@ -// RUN: iree-opt --split-input-file --iree-mhlo-to-linalg-on-tensors --canonicalize %s | FileCheck %s - -func.func @rfft_1d(%input: tensor<32xf32>) -> (tensor<17xf32>, tensor<17xf32>) { - %0 = "mhlo.fft"(%input) { - fft_length = dense<32> : tensor<1xi64>, fft_type = #mhlo - } : (tensor<32xf32>) -> tensor<17xcomplex> - %1 = "mhlo.real"(%0) : (tensor<17xcomplex>) -> tensor<17xf32> - %2 = "mhlo.imag"(%0) : (tensor<17xcomplex>) -> tensor<17xf32> - return %1, %2 : tensor<17xf32>, tensor<17xf32> -} -// CHECK: func.func @rfft_1d -// CHECK-SAME: %[[Arg0:[a-zA-Z0-9_]*]] -// CHECK-DAG: %[[RealMatrix:.+]] = arith.constant dense<"0x0000803F{{.*}}"> : tensor<32x17xf32> -// CHECK-DAG: %[[ImagMatrix:.+]] = arith.constant dense<"0x00000080{{.*}}"> : tensor<32x17xf32> -// CHECK-DAG: %[[Zero:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[RealInit:.+]] = tensor.empty() : tensor<17xf32> -// CHECK: %[[RealFill:.+]] = linalg.fill -// CHECK-SAME: ins(%[[Zero]] : -// CHECK-SAME: outs(%[[RealInit]] : -// CHECK: %[[RealRes:.+]] = linalg.vecmat -// CHECK-SAME: ins(%[[Arg0]], %[[RealMatrix]] : tensor<32xf32>, tensor<32x17xf32>) -// CHECK-SAME: outs(%[[RealFill]] : tensor<17xf32>) -> tensor<17xf32> -// CHECK: %[[ImagInit:.+]] = tensor.empty() : tensor<17xf32> -// CHECK: %[[ImagFill:.+]] = linalg.fill -// CHECK-SAME: ins(%[[Zero]] : -// CHECK-SAME: outs(%[[ImagInit]] : -// CHECK: %[[ImagRes:.+]] = linalg.vecmat -// CHECK-SAME: ins(%[[Arg0]], %[[ImagMatrix]] : tensor<32xf32>, tensor<32x17xf32>) -// CHECK-SAME: outs(%[[ImagFill]] : tensor<17xf32>) -> tensor<17xf32> -// CHECK: %[[ComplexRes:.*]] = linalg.generic -// CHECK: %[[ReRes:.*]] = linalg.generic -// CHECK-SAME: ins(%[[ComplexRes]] -// CHECK: %[[ImRes:.*]] = linalg.generic -// CHECK-SAME: ins(%[[ComplexRes]] -// CHECK: return %[[ReRes]], %[[ImRes]] : tensor<17xf32>, tensor<17xf32> - -// ----- - -func.func @rfft_2d(%input: tensor<1x32xf32>) -> (tensor<1x17xf32>, tensor<1x17xf32>) { - %0 = "mhlo.fft"(%input) { - fft_length = dense<32> : tensor<1xi64>, fft_type = #mhlo - } : (tensor<1x32xf32>) -> tensor<1x17xcomplex> - %1 = "mhlo.real"(%0) : (tensor<1x17xcomplex>) -> tensor<1x17xf32> - %2 = "mhlo.imag"(%0) : (tensor<1x17xcomplex>) -> tensor<1x17xf32> - return %1, %2 : tensor<1x17xf32>, tensor<1x17xf32> -} -// CHECK: func.func @rfft_2d -// CHECK-SAME: %[[Arg0:[a-zA-Z0-9_]*]] -// CHECK-DAG: %[[RealMatrix:.+]] = arith.constant dense<"0x0000803F{{.*}}"> : tensor<32x17xf32> -// CHECK-DAG: %[[ImagMatrix:.+]] 
= arith.constant dense<"0x00000080{{.*}}"> : tensor<32x17xf32> -// CHECK-DAG: %[[Zero:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[RealInit:.+]] = tensor.empty() : tensor<1x17xf32> -// CHECK: %[[RealFill:.+]] = linalg.fill -// CHECK-SAME: ins(%[[Zero]] : -// CHECK-SAME: outs(%[[RealInit]] : -// CHECK: %[[RealRes:.+]] = linalg.matmul -// CHECK-SAME: ins(%[[Arg0]], %[[RealMatrix]] : tensor<1x32xf32>, tensor<32x17xf32>) -// CHECK-SAME: outs(%[[RealFill]] : tensor<1x17xf32>) -> tensor<1x17xf32> -// CHECK: %[[ImagInit:.+]] = tensor.empty() : tensor<1x17xf32> -// CHECK: %[[ImagFill:.+]] = linalg.fill -// CHECK-SAME: ins(%[[Zero]] : -// CHECK-SAME: outs(%[[ImagInit]] : -// CHECK: %[[ImagRes:.+]] = linalg.matmul -// CHECK-SAME: ins(%[[Arg0]], %[[ImagMatrix]] : tensor<1x32xf32>, tensor<32x17xf32>) -// CHECK-SAME: outs(%[[ImagFill]] : tensor<1x17xf32>) -> tensor<1x17xf32> -// CHECK: %[[ComplexRes:.*]] = linalg.generic -// CHECK: %[[ReRes:.*]] = linalg.generic -// CHECK-SAME: ins(%[[ComplexRes]] -// CHECK: %[[ImRes:.*]] = linalg.generic -// CHECK-SAME: ins(%[[ComplexRes]] -// CHECK: return %[[ReRes]], %[[ImRes]] : tensor<1x17xf32>, tensor<1x17xf32> diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/flatten_tuples_in_cfg.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/flatten_tuples_in_cfg.mlir deleted file mode 100644 index 107ea14e0d30..000000000000 --- a/compiler/src/iree/compiler/InputConversion/MHLO/test/flatten_tuples_in_cfg.mlir +++ /dev/null @@ -1,34 +0,0 @@ -// RUN: iree-opt --split-input-file --iree-mhlo-flatten-tuples-in-cfg --canonicalize %s | FileCheck %s -// We rely on canonicalization to cancel out tuple/get_element operations, so -// we test this followed by the canonicalizer rather than just the pass in -// isolation. -// TODO: It would be better if the pass was standalone. 
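A minimal sketch of what this flattening does to a signature, as plain Python (illustrative; the tensor<f32> leaf types are assumed for the example): a nested tuple type becomes one flat positional argument list, which is exactly the caller/callee rewrite the @flatten_func test below checks.

def flatten(ty):
    # A type is either a leaf (str) or a tuple of component types.
    if isinstance(ty, tuple):
        return [leaf for t in ty for leaf in flatten(t)]
    return [ty]

# (i1, tuple<tensor<f32>, tuple<tensor<f32>, tensor<f32>>>) -> four args.
callee_sig = ("i1", ("tensor<f32>", ("tensor<f32>", "tensor<f32>")))
assert flatten(callee_sig) == ["i1", "tensor<f32>", "tensor<f32>", "tensor<f32>"]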
- -// CHECK-LABEL: @flatten_func -module @flatten_func { - // CHECK: func.func @caller(%arg0: i1, %arg1: tensor) -> tensor - func.func @caller(%arg0 : i1, %arg1: tensor) -> tensor { - // CHECK: %[[RESULT:.*]]:2 = call @callee(%arg0, %arg1, %arg1, %arg1) : (i1, tensor, tensor, tensor) -> (tensor, tensor) - %0 = "mhlo.tuple"(%arg1, %arg1) : (tensor, tensor) -> tuple, tensor> - %1 = "mhlo.tuple"(%arg1, %0) : (tensor, tuple, tensor>) -> tuple, tuple, tensor>> - %2 = call @callee(%arg0, %1) : (i1, tuple, tuple, tensor>>) -> tuple, tensor> - %3 = "mhlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple, tensor>) -> tensor - // CHECK: return %[[RESULT]]#0 : tensor - return %3 : tensor - } - - // CHECK: func.func private @callee(%arg0: i1, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor, tensor) - func.func private @callee(%arg0: i1, %arg1: tuple, tuple, tensor>>) -> tuple, tensor> { - // CHECK-DAG: %[[RESULT0:.*]] = arith.select %arg0, %arg2, %arg1 : tensor - // CHECK-DAG: %[[RESULT1:.*]] = arith.select %arg0, %arg3, %arg1 : tensor - // CHECK: return %[[RESULT0]], %[[RESULT1]] : tensor, tensor - %0 = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple, tuple, tensor>>) -> tensor - %1 = "mhlo.get_tuple_element"(%arg1) {index = 1 : i32} : (tuple, tuple, tensor>>) -> tuple, tensor> - cf.cond_br %arg0, ^bb1(%1 : tuple, tensor>), ^bb2(%0 : tensor) - ^bb1(%phi0 : tuple, tensor>): - return %phi0 : tuple, tensor> - ^bb2(%phi1 : tensor): - %2 = "mhlo.tuple"(%phi1, %phi1) : (tensor, tensor) -> tuple, tensor> - cf.br ^bb1(%2 : tuple, tensor>) - } -} diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_linalg.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_linalg.mlir deleted file mode 100644 index 0c4ae295ee9c..000000000000 --- a/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_linalg.mlir +++ /dev/null @@ -1,64 +0,0 @@ -// RUN: iree-opt --split-input-file --iree-mhlo-to-linalg-on-tensors --canonicalize -cse %s | FileCheck %s - -func.func @concatenate(%arg0: tensor<2x2xi32>, %arg1: tensor<2x4xi32>) -> tensor<2x9xi32> { - %cst = mhlo.constant dense<514> : tensor<2x3xi32> - %0 = "mhlo.concatenate"(%arg0, %cst, %arg1) {dimension = 1} : (tensor<2x2xi32>, tensor<2x3xi32>, tensor<2x4xi32>) -> tensor<2x9xi32> - return %0 : tensor<2x9xi32> -} -// CHECK: func.func @concatenate -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]] -// CHECK: %[[CST:.+]] = arith.constant dense<514> : tensor<2x3xi32> -// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<2x9xi32> -// CHECK: %[[T0:.+]] = tensor.insert_slice %[[ARG0]] into %[[INIT]][0, 0] [2, 2] [1, 1] -// CHECK: %[[T1:.+]] = tensor.insert_slice %[[CST]] into %[[T0]][0, 2] [2, 3] [1, 1] -// CHECK: %[[T2:.+]] = tensor.insert_slice %[[ARG1]] into %[[T1]][0, 5] [2, 4] [1, 1] -// CHECK: return %[[T2]] - -// ----- - -// CHECK: ml_program.global private mutable @variable(dense<0> : tensor<2xi32>) : tensor<2xi32> -ml_program.global private mutable @variable(dense<0> : tensor<2xui32>) : tensor<2xui32> -// CHECK: func.func @global_types() -> (tensor<2xi32> {iree.abi.encoding = tensor<2xui32>}) -func.func @global_types() -> tensor<2xui32> { - // CHECK-NEXT: %[[VALUE:.+]] = ml_program.global_load @variable : tensor<2xi32> - %0 = ml_program.global_load @variable : tensor<2xui32> - // CHECK: return %[[VALUE]] : tensor<2xi32> - return %0 : tensor<2xui32> -} - -// ----- - -// CHECK: func.func @optimization_barrier -// CHECK: %[[RESULT1:.+]] = util.optimization_barrier %arg0 : tensor<3x4xf32 
-// CHECK: %[[RESULT2:.+]] = util.optimization_barrier %arg1 : tensor<4xi32> -// CHECK: return %[[RESULT1]], %[[RESULT2]] -func.func @optimization_barrier(%arg0: tensor<3x4xf32>, %arg1: tensor<4xi32>) -> (tensor<3x4xf32>, tensor<4xi32>) { - %0, %1 = "mhlo.optimization_barrier"(%arg0, %arg1) : (tensor<3x4xf32>, tensor<4xi32>) -> (tensor<3x4xf32>, tensor<4xi32>) - return %0, %1 : tensor<3x4xf32>, tensor<4xi32> -} - -// ----- - -// CHECK: @unsigned_integer_input_output(%[[ARG0:.*]]: tensor<2x2xi32> {iree.abi.encoding = tensor<2x2xui32>}, %[[ARG1:.*]]: tensor<2x2xi32> {iree.abi.encoding = tensor<2x2xui32>}) -> (tensor<2x2xi32> {iree.abi.encoding = tensor<2x2xui32>}) -func.func @unsigned_integer_input_output(%arg0: tensor<2x2xui32>, %arg1: tensor<2x2xui32>) -> tensor<2x2xui32> { - // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<2x2xi32> - // CHECK: %[[RESULT:.*]] = linalg.generic - // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x2xi32>, tensor<2x2xi32> - // CHECK-SAME: outs(%[[INIT]] : tensor<2x2xi32>) - // CHECK: ^bb0(%[[IN0:.*]]: i32, %[[IN1:.*]]: i32, %out: i32): - // CHECK: %[[ADD:.*]] = arith.addi %[[IN0]], %[[IN1]] : i32 - // CHECK: linalg.yield %[[ADD:.*]] : i32 - %0 = "mhlo.add"(%arg0, %arg1) : (tensor<2x2xui32>, tensor<2x2xui32>) -> tensor<2x2xui32> - // CHECK: return %[[RESULT]] : tensor<2x2xi32> - return %0 : tensor<2x2xui32> -} - -// ----- - -// CHECK: func.func @aliasing_output -// CHECK-SAME: %[[ARG0:[^:]+]]: tensor<3x4xf32> {iree.abi.output = 1 : index} -// CHECK-SAME: %[[ARG1:[^:]+]]: tensor<4xi32> {iree.abi.encoding = tensor<4xui32>} -func.func @aliasing_output(%arg0: tensor<3x4xf32> {tf.aliasing_output = 1 : i32}, %arg1: tensor<4xui32>) -> (tensor<4xui32>, tensor<3x4xf32>) { - return %arg1, %arg0 : tensor<4xui32>, tensor<3x4xf32> -} diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing.mlir deleted file mode 100644 index 5a0ec4708574..000000000000 --- a/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing.mlir +++ /dev/null @@ -1,354 +0,0 @@ -// RUN: iree-opt --split-input-file --verify-diagnostics --iree-mhlo-to-mhlo-preprocessing %s | FileCheck %s - -// CHECK-LABEL: @batch_norm_inference -// CHECK-SAME: %[[X:[^:[:space:]]+]] -// CHECK-SAME: %[[SCALE:[^:[:space:]]+]] -// CHECK-SAME: %[[OFFSET:[^:[:space:]]+]] -// CHECK-SAME: %[[MEAN:[^:[:space:]]+]] -// CHECK-SAME: %[[VARIANCE:[^:[:space:]]+]] -func.func @batch_norm_inference( - %x: tensor<4x256xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>, - %mean: tensor<256xf32>, %variance: tensor<256xf32>) - -> (tensor<4x256xf32>) { - // CHECK-DAG: %[[EPS_BCAST:.+]] = mhlo.constant dense<1.001000e-05> : tensor<256xf32> - // CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32> - // CHECK-DAG: %[[STDDEV:.+]] = mhlo.sqrt %[[VARIANCE_EPS]] : tensor<256xf32> - // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[STDDEV]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> - // CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> - // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[OFFSET]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> - // CHECK-DAG: %[[MEAN_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[MEAN]]) {broadcast_dimensions = 
dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> - // CHECK-DAG: %[[X_CENTER:.+]] = mhlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<4x256xf32> - // CHECK-DAG: %[[X_SCALED:.+]] = mhlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<4x256xf32> - // CHECK-DAG: %[[X_NORMED:.+]] = mhlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<4x256xf32> - // CHECK-DAG: %[[RESULT:.+]] = mhlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<4x256xf32> - %0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) - {epsilon = 1.001000e-05 : f32, feature_index = 1 : i64} : - (tensor<4x256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, - tensor<256xf32>) -> tensor<4x256xf32> - // CHECK-DAG: return %[[RESULT]] - return %0 : tensor<4x256xf32> -} - -// ----- - -// CHECK: @reorder_broadcast_in_dim_scalar_binary(%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor, %[[ARG3:.*]]: tensor) -func.func @reorder_broadcast_in_dim_scalar_binary(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xi32>, tensor<1x8x8x64xi32>, tensor<1x8x8x64xi32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xi32>, tensor<1x8x8x64xi32>, tensor<1x8x8x64xi32>) { - // CHECK: %[[ADD:.*]] = mhlo.add %[[ARG0]], %[[ARG1]] : tensor - // CHECK: "mhlo.broadcast_in_dim"(%[[ADD]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xf32> - // CHECK: %[[ATAN2:.*]] = mhlo.atan2 %[[ARG0]], %[[ARG1]] : tensor - // CHECK: "mhlo.broadcast_in_dim"(%[[ATAN2]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xf32> - // CHECK: %[[DIV:.*]] = mhlo.divide %[[ARG0]], %[[ARG1]] : tensor - // CHECK: "mhlo.broadcast_in_dim"(%[[DIV]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xf32> - // CHECK: %[[MAX:.*]] = mhlo.maximum %[[ARG0]], %[[ARG1]] : tensor - // CHECK: "mhlo.broadcast_in_dim"(%[[MAX]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xf32> - // CHECK: %[[MIN:.*]] = mhlo.minimum %[[ARG0]], %[[ARG1]] : tensor - // CHECK: "mhlo.broadcast_in_dim"(%[[MIN]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xf32> - // CHECK: %[[MUL:.*]] = mhlo.multiply %[[ARG0]], %[[ARG1]] : tensor - // CHECK: "mhlo.broadcast_in_dim"(%[[MUL]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xf32> - // CHECK: %[[POW:.*]] = mhlo.power %[[ARG0]], %[[ARG1]] : tensor - // CHECK: "mhlo.broadcast_in_dim"(%[[POW]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xf32> - // CHECK: %[[REM:.*]] = mhlo.remainder %[[ARG0]], %[[ARG1]] : tensor - // CHECK: "mhlo.broadcast_in_dim"(%[[REM]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xf32> - // CHECK: %[[SL:.*]] = mhlo.shift_left %[[ARG2]], %[[ARG3]] : tensor - // CHECK: "mhlo.broadcast_in_dim"(%[[SL]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xi32> - // CHECK: %[[SRA:.*]] = mhlo.shift_right_arithmetic %[[ARG2]], %[[ARG3]] : tensor - // CHECK: "mhlo.broadcast_in_dim"(%[[SRA]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xi32> - // CHECK: %[[SRL:.*]] = mhlo.shift_right_logical %[[ARG2]], %[[ARG3]] : tensor - // CHECK: "mhlo.broadcast_in_dim"(%[[SRL]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : 
(tensor) -> tensor<1x8x8x64xi32> - // CHECK: %[[SUB:.*]] = mhlo.subtract %[[ARG0]], %[[ARG1]] : tensor - // CHECK: "mhlo.broadcast_in_dim"(%[[SUB]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xf32> - // CHECK: %[[AND:.*]] = mhlo.and %[[ARG2]], %[[ARG3]] : tensor - // CHECK: "mhlo.broadcast_in_dim"(%[[AND]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xi32> - // CHECK: %[[OR:.*]] = mhlo.or %[[ARG2]], %[[ARG3]] : tensor - // CHECK: "mhlo.broadcast_in_dim"(%[[OR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xi32> - // CHECK: %[[XOR:.*]] = mhlo.xor %[[ARG2]], %[[ARG3]] : tensor - // CHECK: "mhlo.broadcast_in_dim"(%[[XOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xi32> - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xf32> - %1 = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xf32> - %2 = "mhlo.broadcast_in_dim"(%arg2) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xi32> - %3 = "mhlo.broadcast_in_dim"(%arg3) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xi32> - %4 = mhlo.add %0, %1 : tensor<1x8x8x64xf32> - %5 = mhlo.atan2 %0, %1 : tensor<1x8x8x64xf32> - %6 = mhlo.divide %0, %1 : tensor<1x8x8x64xf32> - %7 = mhlo.maximum %0, %1 : tensor<1x8x8x64xf32> - %8 = mhlo.minimum %0, %1 : tensor<1x8x8x64xf32> - %9 = mhlo.multiply %0, %1 : tensor<1x8x8x64xf32> - %10 = mhlo.power %0, %1 : tensor<1x8x8x64xf32> - %11 = mhlo.remainder %0, %1 : tensor<1x8x8x64xf32> - %12 = mhlo.shift_left %2, %3 : tensor<1x8x8x64xi32> - %13 = mhlo.shift_right_arithmetic %2, %3 : tensor<1x8x8x64xi32> - %14 = mhlo.shift_right_logical %2, %3 : tensor<1x8x8x64xi32> - %15 = mhlo.subtract %0, %1 : tensor<1x8x8x64xf32> - %16 = mhlo.and %2, %3 : tensor<1x8x8x64xi32> - %17 = mhlo.or %2, %3 : tensor<1x8x8x64xi32> - %18 = mhlo.xor %2, %3 : tensor<1x8x8x64xi32> - return %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18 : tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xi32>, tensor<1x8x8x64xi32>, tensor<1x8x8x64xi32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xi32>, tensor<1x8x8x64xi32>, tensor<1x8x8x64xi32> -} - -// ----- - -// CHECK: @reorder_broadcast_in_dim_scalar_binary_diff_type(%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) -> tensor<1x8x8x64xcomplex> -func.func @reorder_broadcast_in_dim_scalar_binary_diff_type(%arg0: tensor, %arg1: tensor) -> tensor<1x8x8x64xcomplex> { - // CHECK: %0 = mhlo.complex %[[ARG0]], %[[ARG1]] : tensor> - // CHECK: "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor>) -> tensor<1x8x8x64xcomplex> - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xf32> - %1 = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xf32> - %2 = "mhlo.complex"(%0, %1) : (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xcomplex> - return %2 : tensor<1x8x8x64xcomplex> -} - -// ----- - -// CHECK: @reorder_broadcast_in_dim_1d_binary(%[[ARG0:.*]]: tensor<3xf32>, %[[ARG1:.*]]: tensor<3xf32>) -> tensor<4x3xf32> -func.func @reorder_broadcast_in_dim_1d_binary(%arg0: tensor<3xf32>, 
%arg1: tensor<3xf32>) -> tensor<4x3xf32> { - // CHECK: %[[ATAN2:.*]] = mhlo.atan2 %[[ARG0]], %[[ARG1]] : tensor<3xf32> - // CHECK: %[[BCAST:.*]] = "mhlo.broadcast_in_dim"(%[[ATAN2]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xf32>) -> tensor<4x3xf32> - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<3xf32>) -> tensor<4x3xf32> - %1 = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<3xf32>) -> tensor<4x3xf32> - %2 = mhlo.atan2 %0, %1 : tensor<4x3xf32> - // CHECK: return %[[BCAST]] - return %2 : tensor<4x3xf32> -} - -// ----- - -// CHECK: @reorder_broadcast_in_dim_2d_binary(%[[ARG0:.*]]: tensor<2x4xi32>, %[[ARG1:.*]]: tensor<2x4xi32>) -> tensor<3x2x4xi32> -func.func @reorder_broadcast_in_dim_2d_binary(%arg0: tensor<2x4xi32>, %arg1: tensor<2x4xi32>) -> tensor<3x2x4xi32> { - // CHECK: %[[POWER:.*]] = mhlo.power %[[ARG0]], %[[ARG1]] : tensor<2x4xi32> - // CHECK: %[[BCAST:.*]] = "mhlo.broadcast_in_dim"(%[[POWER]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<2x4xi32>) -> tensor<3x2x4xi32> - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<2x4xi32>) -> tensor<3x2x4xi32> - %1 = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<2x4xi32>) -> tensor<3x2x4xi32> - %2 = mhlo.power %0, %1 : tensor<3x2x4xi32> - // CHECK: return %[[BCAST]] - return %2 : tensor<3x2x4xi32> -} - -// ----- - -// CHECK: @reorder_broadcast_in_dim_scalar_unary(%[[ARG0:.*]]: tensor) -func.func @reorder_broadcast_in_dim_scalar_unary(%arg0: tensor) -> (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>) { - // CHECK: %[[ABS:.*]] = mhlo.abs %[[ARG0]] : tensor - // CHECK: "mhlo.broadcast_in_dim"(%[[ABS]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xf32> - // CHECK: %[[CEIL:.*]] = mhlo.ceil %[[ARG0]] : tensor - // CHECK: "mhlo.broadcast_in_dim"(%[[CEIL]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xf32> - // CHECK: %[[COSINE:.*]] = mhlo.cosine %[[ARG0]] : tensor - // CHECK: "mhlo.broadcast_in_dim"(%[[COSINE]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xf32> - // CHECK: %[[EXP:.*]] = mhlo.exponential %[[ARG0]] : tensor - // CHECK: "mhlo.broadcast_in_dim"(%[[EXP]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xf32> - // CHECK: %[[FLOOR:.*]] = mhlo.floor %[[ARG0]] : tensor - // CHECK: "mhlo.broadcast_in_dim"(%[[FLOOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xf32> - // CHECK: %[[LOG:.*]] = mhlo.log %[[ARG0]] : tensor - // CHECK: "mhlo.broadcast_in_dim"(%[[LOG]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xf32> - // CHECK: %[[NEG:.*]] = mhlo.negate %[[ARG0]] : tensor - // CHECK: "mhlo.broadcast_in_dim"(%[[NEG]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xf32> - // CHECK: %[[ROUND:.*]] = mhlo.round_nearest_afz %[[ARG0]] : tensor - // CHECK: "mhlo.broadcast_in_dim"(%[[ROUND]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xf32> - // CHECK: %[[RSQRT:.*]] = mhlo.rsqrt %[[ARG0]] : tensor - // CHECK: 
"mhlo.broadcast_in_dim"(%[[RSQRT]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xf32> - // CHECK: %[[SIGN:.*]] = mhlo.sign %[[ARG0]] : tensor - // CHECK: "mhlo.broadcast_in_dim"(%[[SIGN]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xf32> - // CHECK: %[[SINE:.*]] = mhlo.sine %[[ARG0]] : tensor - // CHECK: "mhlo.broadcast_in_dim"(%[[SINE]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xf32> - // CHECK: %[[SQRT:.*]] = mhlo.sqrt %[[ARG0]] : tensor - // CHECK: "mhlo.broadcast_in_dim"(%[[SQRT]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xf32> - // CHECK: %[[TANH:.*]] = mhlo.tanh %[[ARG0]] : tensor - // CHECK: "mhlo.broadcast_in_dim"(%[[TANH]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xf32> - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xf32> - %1 = mhlo.abs %0 : tensor<1x8x8x64xf32> - %2 = mhlo.ceil %0 : tensor<1x8x8x64xf32> - %3 = mhlo.cosine %0 : tensor<1x8x8x64xf32> - %4 = mhlo.exponential %0 : tensor<1x8x8x64xf32> - %5 = mhlo.floor %0 : tensor<1x8x8x64xf32> - %6 = mhlo.log %0 : tensor<1x8x8x64xf32> - %7 = mhlo.negate %0 : tensor<1x8x8x64xf32> - %8 = mhlo.round_nearest_afz %0 : tensor<1x8x8x64xf32> - %9 = mhlo.rsqrt %0 : tensor<1x8x8x64xf32> - %10 = mhlo.sign %0 : tensor<1x8x8x64xf32> - %11 = mhlo.sine %0 : tensor<1x8x8x64xf32> - %12 = mhlo.sqrt %0 : tensor<1x8x8x64xf32> - %13 = mhlo.tanh %0 : tensor<1x8x8x64xf32> - return %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13: tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32> -} - -// ----- - -// CHECK: @reorder_broadcast_in_dim_1d_unary(%[[ARG0:.*]]: tensor<3xf32>) -> tensor<4x3xf32> -func.func @reorder_broadcast_in_dim_1d_unary(%arg0: tensor<3xf32>) -> tensor<4x3xf32> { - // CHECK: %[[COS:.*]] = mhlo.cosine %[[ARG0]] : tensor<3xf32> - // CHECK: %[[BCAST:.*]] = "mhlo.broadcast_in_dim"(%[[COS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xf32>) -> tensor<4x3xf32> - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<3xf32>) -> tensor<4x3xf32> - %1 = mhlo.cosine %0 : tensor<4x3xf32> - // CHECK: return %[[BCAST]] - return %1 : tensor<4x3xf32> -} - -// ----- - -// CHECK: @reorder_in_dim_2d_unary(%[[ARG0:.*]]: tensor<2x4xf32>) -> tensor<3x2x4xf32> -func.func @reorder_in_dim_2d_unary(%arg0: tensor<2x4xf32>) -> tensor<3x2x4xf32> { - // CHECK: %[[LOG:.*]] = mhlo.log %[[ARG0]] : tensor<2x4xf32> - // CHECK: %[[BCAST:.*]] = "mhlo.broadcast_in_dim"(%[[LOG]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<2x4xf32>) -> tensor<3x2x4xf32> - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<2x4xf32>) -> tensor<3x2x4xf32> - %1 = mhlo.log %0 : tensor<3x2x4xf32> - // CHECK: return %[[BCAST]] - return %1 : tensor<3x2x4xf32> -} - -// ----- - -// CHECK: @reorder_broadcast_in_dim_scalar_unary_diff_type(%[[ARG0:.*]]: tensor>) -> (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>) -func.func @reorder_broadcast_in_dim_scalar_unary_diff_type(%arg0: tensor>) -> (tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32>) { - // CHECK: %[[REAL:.*]] = mhlo.real %[[ARG0]] : (tensor>) -> 
tensor - // CHECK: "mhlo.broadcast_in_dim"(%[[REAL]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xf32> - // CHECK: %[[IMAG:.*]] = mhlo.imag %[[ARG0]] : (tensor>) -> tensor - // CHECK: "mhlo.broadcast_in_dim"(%[[IMAG]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<1x8x8x64xf32> - %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor>) -> tensor<1x8x8x64xcomplex> - %1 = mhlo.real %0 : (tensor<1x8x8x64xcomplex>) -> tensor<1x8x8x64xf32> - %2 = mhlo.imag %0 : (tensor<1x8x8x64xcomplex>) -> tensor<1x8x8x64xf32> - return %1, %2: tensor<1x8x8x64xf32>, tensor<1x8x8x64xf32> -} - -// ----- - -func.func @rng_normal(%arg0: tensor, %arg1: tensor) -> tensor<3x5xf32> { - %shape = mhlo.constant dense<[3, 5]> : tensor<2xi64> - %0 = "mhlo.rng"(%arg0, %arg1, %shape) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<2xi64>) -> tensor<3x5xf32> - return %0 : tensor<3x5xf32> -} -// CHECK-LABEL: func.func @rng_normal -// CHECK: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK-DAG: %{{.*}} = mhlo.constant dense<{{.*}}> : tensor<8xf32> -// CHECK-DAG: %{{.*}} = mhlo.constant dense<{{.*}}> : tensor<8xf32> -// CHECK-DAG: %{{.*}} = mhlo.constant dense<{{.*}}> : tensor<8xf32> -// CHECK: %[[SIGMA:.+]] = "mhlo.broadcast"(%[[ARG1]]) {broadcast_sizes = dense<8> : tensor<1xi64>} : (tensor) -> tensor<8xf32> -// -// mag = sigma * sqrt(-2.0 * log(u1)) where sqrt values are -// constants. -// -// CHECK: %[[MAG:.+]] = mhlo.multiply %[[SIGMA]], %{{.*}} : tensor<8xf32> -// -// z0 = mag * cos(two_pi * u2) + mu; -// z1 = mag * sin(two_pi * u2) + mu; -// -// CHECK: %[[MU:.+]] = "mhlo.broadcast"(%[[ARG0]]) {broadcast_sizes = dense<8> : tensor<1xi64>} : (tensor) -> tensor<8xf32> -// CHECK: %[[T1:.+]] = mhlo.multiply %[[MAG]], %{{.*}} : tensor<8xf32> -// CHECK: %[[Z0:.+]] = mhlo.add %[[T1:.+]], %[[MU]] : tensor<8xf32> -// CHECK: %[[T2:.+]] = mhlo.multiply %[[MAG]], %{{.*}} : tensor<8xf32> -// CHECK: %[[Z1:.+]] = mhlo.add %[[T2:.+]], %[[MU]] : tensor<8xf32> -// -// Concate and reshape the output. 
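The deleted CHECK trace above documents the Box-Muller transform that ExpandRngNormal emits (the StableHLO port of this pattern appears later in this diff): two uniform vectors u1 and u2 become two normally distributed vectors z0 and z1, which are concatenated, sliced to the requested element count, and reshaped. A minimal NumPy sketch of the same math, using the 3x5 result shape from this test (illustrative only, not IREE code):

import numpy as np

def rng_normal_3x5(mu, sigma):
    # Two uniform draws of 8 elements each; 2 * 8 = 16 samples cover the 15 needed.
    u1 = np.random.uniform(size=8)
    u2 = np.random.uniform(size=8)
    # mag = sigma * sqrt(-2.0 * log(u1)); in the lowered IR these factors fold to constants.
    mag = sigma * np.sqrt(-2.0 * np.log(u1))
    # z0 = mag * cos(two_pi * u2) + mu;  z1 = mag * sin(two_pi * u2) + mu
    z0 = mag * np.cos(2.0 * np.pi * u2) + mu
    z1 = mag * np.sin(2.0 * np.pi * u2) + mu
    # Concatenate to 16 elements, slice to 15, reshape to 3x5 (the CON/SLICE/RES lines above).
    return np.concatenate([z0, z1])[:15].reshape(3, 5)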
-// CHECK: %[[CON:.+]] = "mhlo.concatenate"(%[[Z0]], %[[Z1]]) {dimension = 0 : i64} : (tensor<8xf32>, tensor<8xf32>) -> tensor<16xf32> -// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[CON]][0] [15] [1] : tensor<16xf32> to tensor<15xf32> -// CHECK: %[[RES:.+]] = mhlo.reshape %[[SLICE]] : (tensor<15xf32>) -> tensor<3x5xf32> -// CHECK: return %[[RES]] - -// ----- - -func.func @mul_float_bool_cast(%arg0 : tensor, %arg1 : tensor) -> tensor { - %0 = mhlo.convert %arg0 : (tensor) -> tensor - %1 = "mhlo.multiply"(%0, %arg1) : (tensor, tensor) -> tensor - return %1 : tensor -} - -// CHECK-LABEL: @mul_float_bool_cast -// CHECK: %[[ZERO:.+]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[BTOF:.+]] = mhlo.convert %arg0 : (tensor) -> tensor -// CHECK: %[[FTOB:.+]] = mhlo.convert %[[BTOF]] : (tensor) -> tensor -// CHECK: %[[SHP:.+]] = shape.shape_of %[[BTOF]] : tensor -> tensor<1xindex> -// CHECK: %[[BROADCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ZERO]], %[[SHP]]) {broadcast_dimensions = dense<> : tensor<0xi64>} -// CHECK: %[[SELECT:.+]] = mhlo.select %[[FTOB]], %arg1, %[[BROADCAST]] - -// ----- - -func.func @mul_float_bool_cast_broadcast(%arg0: tensor<5xi1>, %arg1: tensor<5x6xf32>) -> tensor<5x6xf32> { - %0 = mhlo.convert %arg0 : (tensor<5xi1>) -> tensor<5xf32> - %1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<5xf32>) -> tensor<5x6xf32> - %2 = mhlo.multiply %1, %arg1 : tensor<5x6xf32> - return %2 : tensor<5x6xf32> -} - -// CHECK-LABEL: @mul_float_bool_cast_broadcast -// CHECK: mhlo.select - -// ----- - -func.func @mul_float_bool_cast_dyn_broadcast(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = mhlo.convert %arg0 : (tensor) -> tensor - %1 = shape.shape_of %arg1 : tensor -> tensor<2xindex> - %2 = "mhlo.dynamic_broadcast_in_dim"(%0, %1) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor - %3 = mhlo.multiply %2, %arg1 : tensor - return %3 : tensor -} - -// CHECK-LABEL: @mul_float_bool_cast_dyn_broadcast -// CHECK: mhlo.select - -// ----- - -// CHECK-LABEL: @dot_general_fuse_both_with_attrs -func.func @dot_general_fuse_both_with_attrs(%arg0: tensor<16x64x128xf16>, %arg1: tensor<16x128x3072xf16>) -> tensor<16x64x3072xf32> { - %0 = mhlo.convert %arg0 : (tensor<16x64x128xf16>) -> tensor<16x64x128xf32> - %1 = mhlo.convert %arg1 : (tensor<16x128x3072xf16>) -> tensor<16x128x3072xf32> - // CHECK: "mhlo.dot_general"(%arg0, %arg1) - // CHECK-SAME: dot_dimension_numbers = #mhlo.dot, - // CHECK-SAME: precision_config = [#mhlo, #mhlo] - // CHECK-SAME: -> tensor<16x64x3072xf32> - %2 = "mhlo.dot_general"(%0, %1) {dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo]} : (tensor<16x64x128xf32>, tensor<16x128x3072xf32>) -> tensor<16x64x3072xf32> - return %2 : tensor<16x64x3072xf32> -} - -// ----- - -// CHECK-LABEL: @dot_general_fuse_one -func.func @dot_general_fuse_one(%arg0: tensor<16x64x128xf64>, %arg1: tensor<16x128x3072xf16>) -> tensor<16x64x3072xf32> { - %0 = mhlo.convert %arg0 : (tensor<16x64x128xf64>) -> tensor<16x64x128xf32> - %1 = mhlo.convert%arg1 : (tensor<16x128x3072xf16>) -> tensor<16x128x3072xf32> - // CHECK: %[[CONVERT:.+]] = mhlo.convert %arg0 - // CHECK: "mhlo.dot_general"(%[[CONVERT]], %arg1) - %2 = "mhlo.dot_general"(%0, %1) {dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo]} : (tensor<16x64x128xf32>, tensor<16x128x3072xf32>) -> tensor<16x64x3072xf32> - return %2 : tensor<16x64x3072xf32> -} - -// ----- - -// CHECK-LABEL: @dot_basic -func.func @dot_basic(%arg0: 
tensor<4x4xf16>, %arg1: tensor<4x4xf16>) -> tensor<4x4xf32> { - %0 = mhlo.convert %arg0 : (tensor<4x4xf16>) -> tensor<4x4xf32> - %1 = mhlo.convert %arg1 : (tensor<4x4xf16>) -> tensor<4x4xf32> - // CHECK: %[[DOT:.+]] = "mhlo.dot"(%arg0, %arg1) - %2 = "mhlo.dot"(%0, %1) {precision_config = [#mhlo, #mhlo]} : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> - // CHECK: return %[[DOT]] - return %2 : tensor<4x4xf32> -} - -// ----- - -// CHECK-LABEL: @convolution -func.func @convolution(%arg0: tensor<16x32x256xbf16>, %arg1: tensor<1x256x256xbf16>) -> tensor<16x32x256xf32> { - %cast = mhlo.convert %arg0 : (tensor<16x32x256xbf16>) -> tensor<16x32x256xf32> - // CHECK: %[[CONV:.+]] = mhlo.convolution(%arg0, %arg1) - // CHECK-SAME: -> tensor<16x32x256xf32> - %0 = "mhlo.convolution"(%cast, %arg1) { - batch_group_count = 1 : i64, - dimension_numbers = #mhlo.conv<[b, 0, f]x[0, i, o]->[b, 0, f]>, - feature_group_count = 1 : i64, - lhs_dilation = dense<1> : tensor<1xi64>, - padding = dense<0> : tensor<1x2xi64>, - precision_config = [#mhlo, #mhlo], - rhs_dilation = dense<1> : tensor<1xi64>, - window_strides = dense<1> : tensor<1xi64> - } : (tensor<16x32x256xf32>, tensor<1x256x256xbf16>) -> tensor<16x32x256xf32> - // CHECK: return %[[CONV]] - func.return %0 : tensor<16x32x256xf32> -} - -// ----- - -// CHECK-LABEL: @dynamic_dot_general -// This verifies non-crashing, the lowering to linalg happens elsewhere. -func.func @dynamic_dot_general(%arg1: tensor, %arg2: tensor) -> tensor { - %2 = "mhlo.dot_general"(%arg2, %arg1) {dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo]} : (tensor, tensor) -> tensor - return %2 : tensor -} - diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing_canonicalize_dot_general.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing_canonicalize_dot_general.mlir deleted file mode 100644 index 248c784b9362..000000000000 --- a/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_preprocessing_canonicalize_dot_general.mlir +++ /dev/null @@ -1,35 +0,0 @@ -// RUN: iree-opt --split-input-file --verify-diagnostics --iree-mhlo-to-mhlo-preprocessing %s | FileCheck %s - -// CHECK-LABEL: @dot_general_2d -func.func public @dot_general_2d(%arg0: tensor<4x3xf32> {mhlo.sharding = ""}, %arg1: tensor<4x3xf32> {mhlo.sharding = ""}) -> tensor<3xf32> { - %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo]} : (tensor<4x3xf32>, tensor<4x3xf32>) -> tensor<3xf32> - - // CHECK: %[[LHS:.+]] = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x3xf32>) -> tensor<3x4xf32> - // CHECK: %[[RHS:.+]] = "mhlo.transpose"(%arg1) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x3xf32>) -> tensor<3x4xf32> - // CHECK: "mhlo.dot_general"(%[[LHS]], %[[RHS]]) - // CHECK-SAME: dot_dimension_numbers = #mhlo.dot< - // CHECK-SAME: lhs_batching_dimensions = [0] - // CHECK-SAME: rhs_batching_dimensions = [0] - // CHECK-SAME: lhs_contracting_dimensions = [1] - // CHECK-SAME: rhs_contracting_dimensions = [1]> - // CHECK-SAME: precision_config = [#mhlo, #mhlo] - return %0 : tensor<3xf32> -} - -// CHECK-LABEL: @dot_general_4d -func.func public @dot_general_4d(%arg0: tensor<1x2x3xf32> {mhlo.sharding = ""}, %arg1: tensor<1x4x2x3xf32> {mhlo.sharding = ""}) -> tensor<1x2x4xf32> { - %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo]} : (tensor<1x2x3xf32>, tensor<1x4x2x3xf32>) -> 
tensor<1x2x4xf32> - - // CHECK: %[[RHS_T:.+]] = "mhlo.transpose"(%arg1) {permutation = dense<[0, 2, 3, 1]> : tensor<4xi64>} : (tensor<1x4x2x3xf32>) -> tensor<1x2x3x4xf32> - // CHECK: %[[LHS_R:.+]] = mhlo.reshape %arg0 : (tensor<1x2x3xf32>) -> tensor<2x1x3xf32> - // CHECK: %[[RHS_R:.+]] = mhlo.reshape %[[RHS_T]] : (tensor<1x2x3x4xf32>) -> tensor<2x3x4xf32> - // CHECK: %[[DOT:.+]] = "mhlo.dot_general"(%[[LHS_R]], %[[RHS_R]]) - // CHECK-SAME: dot_dimension_numbers = #mhlo.dot< - // CHECK-SAME: lhs_batching_dimensions = [0] - // CHECK-SAME: rhs_batching_dimensions = [0] - // CHECK-SAME: lhs_contracting_dimensions = [2] - // CHECK-SAME: rhs_contracting_dimensions = [1]> - // CHECK-SAME: precision_config = [#mhlo, #mhlo] - // CHECK: mhlo.reshape %[[DOT]] : (tensor<2x1x4xf32>) -> tensor<1x2x4xf32> - return %0 : tensor<1x2x4xf32> -} diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_scatter.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_scatter.mlir deleted file mode 100644 index c7ffeadb7746..000000000000 --- a/compiler/src/iree/compiler/InputConversion/MHLO/test/mhlo_to_mhlo_scatter.mlir +++ /dev/null @@ -1,295 +0,0 @@ -// RUN: iree-opt --split-input-file --verify-diagnostics --iree-mhlo-to-mhlo-preprocessing %s | FileCheck %s - -func.func @scatter_implicit_batch(%arg0: tensor<5x6x7xi32>, %arg1: tensor<2xi32>, %arg2: tensor<7xi32>) -> tensor<5x6x7xi32> { - %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({ - ^bb0(%arg3: tensor, %arg4: tensor): - "mhlo.return"(%arg4) : (tensor) -> () - }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter, unique_indices = true} : (tensor<5x6x7xi32>, tensor<2xi32>, tensor<7xi32>) -> tensor<5x6x7xi32> - return %0 : tensor<5x6x7xi32> -} - -// CHECK-LABEL: func.func @scatter_implicit_batch -// CHECK-DAG: %[[RE_I:.+]] = tensor.expand_shape %{{.*}} {{\[\[}}0, 1]] : tensor<2xi32> into tensor<1x2xi32> -// CHECK-DAG: %[[RE_U:.+]] = tensor.expand_shape %{{.*}} {{\[\[}}0, 1]] : tensor<7xi32> into tensor<1x7xi32> -// CHECK: %[[SCATTER:.+]] = "mhlo.scatter"(%{{.*}}, %[[RE_I]], %[[RE_U]]) -// CHECK: mhlo.return %{{.*}} -// CHECK: update_window_dims = [1], -// CHECK-SAME: inserted_window_dims = [0, 1] -// CHECK-SAME: scatter_dims_to_operand_dims = [0, 1] - -// ----- - -func.func @scatter_implicit_indices(%arg0: tensor<17x11xf32>, - %arg1: tensor<7xi32>, %arg2: tensor<7x11xf32>) -> tensor<17x11xf32> { - %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({ - ^bb0(%arg3: tensor, %arg4: tensor): - %1 = mhlo.add %arg3, %arg4 : tensor - "mhlo.return"(%1) : (tensor) -> () - }) {indices_are_sorted = false, - scatter_dimension_numbers = #mhlo.scatter< - update_window_dims = [1], - inserted_window_dims = [0], - scatter_dims_to_operand_dims = [0], - index_vector_dim = 1>, - unique_indices = false - } : (tensor<17x11xf32>, tensor<7xi32>, tensor<7x11xf32>) -> tensor<17x11xf32> - return %0 : tensor<17x11xf32> -} - -// CHECK-LABEL: func.func @scatter_implicit_indices -// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %arg1 {{\[\[}}0, 1]] : tensor<7xi32> into tensor<7x1xi32> -// CHECK: %[[SCATTER:.+]] = "mhlo.scatter"(%arg0, %[[EXPAND]], %arg2) ({ -// CHECK-NEXT: ^bb0(%[[A0:.+]]: tensor, %[[A1:.+]]: tensor): -// CHECK-NEXT: %[[ADD:.+]] = mhlo.add %[[A0]], %[[A1]] : tensor -// CHECK-NEXT: mhlo.return %[[ADD]] -// CHECK-NEXT: }) -// CHECK-SAME: indices_are_sorted = false, -// CHECK-SAME: scatter_dimension_numbers = #mhlo.scatter< -// CHECK-SAME: update_window_dims = [1], -// CHECK-SAME: inserted_window_dims = [0], -// CHECK-SAME: 
scatter_dims_to_operand_dims = [0], -// CHECK-SAME: index_vector_dim = 1>, -// CHECK-SAME: unique_indices = false - -// ----- - -func.func @scatter_collapse_batch(%arg0: tensor<1x24x512xi32>, - %arg1: tensor<2x3x2xi32>, %arg2: tensor<2x3x512xi32>) -> tensor<1x24x512xi32> { - %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( { - ^bb0(%arg3: tensor, %arg4: tensor): - "mhlo.return"(%arg4) : (tensor) -> () - }) {indices_are_sorted = false, - scatter_dimension_numbers = #mhlo.scatter< - update_window_dims = [2], - inserted_window_dims = [0, 1], - scatter_dims_to_operand_dims = [0, 1], - index_vector_dim = 2, - >, - unique_indices = true - } : (tensor<1x24x512xi32>, tensor<2x3x2xi32>, tensor<2x3x512xi32>) -> tensor<1x24x512xi32> - return %0 : tensor<1x24x512xi32> -} - -// CHECK-LABEL: func.func @scatter_collapse_batch -// CHECK: %[[COLLAPSE0:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0, 1], [2]] : tensor<2x3x2xi32> into tensor<6x2xi32> -// CHECK: %[[COLLAPSE1:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0, 1], [2]] : tensor<2x3x512xi32> into tensor<6x512xi32> -// CHECK: %[[SCATTER:.+]] = "mhlo.scatter"(%arg0, %[[COLLAPSE0]], %[[COLLAPSE1]]) -// CHECK: ^bb0(%[[ARG0:.+]]: tensor, %[[ARG1:.+]]: tensor): -// CHECK: mhlo.return %[[ARG1]] -// CHECK: }) { -// CHECK: indices_are_sorted = false, -// CHECK-SAME: scatter_dimension_numbers = #mhlo.scatter< -// CHECK-SAME: update_window_dims = [1] -// CHECK-SAME: inserted_window_dims = [0, 1] -// CHECK-SAME: scatter_dims_to_operand_dims = [0, 1] -// CHECK-SAME: index_vector_dim = 1> -// CHECK-SAME: unique_indices = true -// CHECK: return %[[SCATTER]] - -// ----- - -func.func @scatter_materialize_index_update(%arg0: tensor<5x1x1xi32>, %arg1: tensor<1x2xi32>, %arg2: tensor<1x4xi32>) -> tensor<5x1x1xi32> { - %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({ - ^bb0(%arg3: tensor, %arg4: tensor): - "mhlo.return"(%arg4) : (tensor) -> () - }) { - indices_are_sorted = true, - scatter_dimension_numbers = #mhlo.scatter, - unique_indices = true} : (tensor<5x1x1xi32>, tensor<1x2xi32>, tensor<1x4xi32>) -> tensor<5x1x1xi32> - return %0 : tensor<5x1x1xi32> -} - -// CHECK-LABEL: @scatter_materialize_index_update -// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %arg2 {{\[\[}}0], [1, 2, 3]] : tensor<1x4xi32> into tensor<1x4x1x1xi32> -// CHECK: %[[SCATTER:.+]] = "mhlo.scatter"(%arg0, %arg1, %[[EXPAND]]) -// CHECK: indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter< -// CHECK-SAME: update_window_dims = [1, 2, 3] -// CHECK-SAME: scatter_dims_to_operand_dims = [0, 1] -// CHECK-SAME: index_vector_dim = 1>, unique_indices = true - -// ----- - -func.func @scatter_materialize_one_dim(%arg0: tensor<5x1x1xi32>, %arg1: tensor<1x2xi32>, %arg2: tensor<1xi32>) -> tensor<5x1x1xi32> { - %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({ - ^bb0(%arg3: tensor, %arg4: tensor): - "mhlo.return"(%arg4) : (tensor) -> () - }) { - indices_are_sorted = true, - scatter_dimension_numbers = #mhlo.scatter, - unique_indices = true} : (tensor<5x1x1xi32>, tensor<1x2xi32>, tensor<1xi32>) -> tensor<5x1x1xi32> - return %0 : tensor<5x1x1xi32> -} - -// CHECK-LABEL: @scatter_materialize_one_dim -// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %arg2 {{\[\[}}0, 1]] : tensor<1xi32> into tensor<1x1xi32> -// CHECK: %[[SCATTER:.+]] = "mhlo.scatter"(%arg0, %arg1, %[[EXPAND]]) -// CHECK: indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter< -// CHECK-SAME: update_window_dims = [1] -// CHECK-SAME: inserted_window_dims = [0, 1] -// CHECK-SAME: scatter_dims_to_operand_dims = [0, 1] -// CHECK-SAME: index_vector_dim 
= 1>, unique_indices = true - -// ----- - -func.func @scatter_materialize_two_dims(%arg0: tensor<5x1x1xi32>, %arg1: tensor<1x1xi32>, %arg2: tensor<1xi32>) -> tensor<5x1x1xi32> { - %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({ - ^bb0(%arg3: tensor, %arg4: tensor): - "mhlo.return"(%arg4) : (tensor) -> () - }) { - indices_are_sorted = true, - scatter_dimension_numbers = #mhlo.scatter, - unique_indices = true} : (tensor<5x1x1xi32>, tensor<1x1xi32>, tensor<1xi32>) -> tensor<5x1x1xi32> - return %0 : tensor<5x1x1xi32> -} - -// CHECK-LABEL: @scatter_materialize_two_dims -// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %arg2 {{\[\[}}0, 1, 2]] : tensor<1xi32> into tensor<1x1x1xi32> -// CHECK: %[[SCATTER:.+]] = "mhlo.scatter"(%arg0, %arg1, %[[EXPAND]]) -// CHECK: indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter< -// CHECK-SAME: update_window_dims = [1, 2] -// CHECK-SAME: inserted_window_dims = [0] -// CHECK-SAME: scatter_dims_to_operand_dims = [0] -// CHECK-SAME: index_vector_dim = 1>, unique_indices = true - -// ----- - -func.func @scatter_materialize_comprehensive(%arg0: tensor<5x4x1xi32>, %arg1: tensor<1x1xi32>, %arg2: tensor<1x4xi32>) -> tensor<5x4x1xi32> { - %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({ - ^bb0(%arg3: tensor, %arg4: tensor): - "mhlo.return"(%arg4) : (tensor) -> () - }) { - indices_are_sorted = true, - scatter_dimension_numbers = #mhlo.scatter, - unique_indices = true} : (tensor<5x4x1xi32>, tensor<1x1xi32>, tensor<1x4xi32>) -> tensor<5x4x1xi32> - return %0 : tensor<5x4x1xi32> -} - -// CHECK-LABEL: @scatter_materialize_comprehensive -// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %arg2 {{\[\[}}0], [1, 2]] : tensor<1x4xi32> into tensor<1x4x1xi32> -// CHECK: %[[SCATTER:.+]] = "mhlo.scatter"(%arg0, %arg1, %[[EXPAND]]) -// CHECK: indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter< -// CHECK-SAME: update_window_dims = [1, 2] -// CHECK-SAME: inserted_window_dims = [0] -// CHECK-SAME: scatter_dims_to_operand_dims = [0] -// CHECK-SAME: index_vector_dim = 1>, unique_indices = true - -// ----- - -func.func @scatter_operand_map(%arg0: tensor<5x4x1xi32>, %arg1: tensor<1x2xi32>, %arg2: tensor<1xi32>) -> tensor<5x4x1xi32> { - %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({ - ^bb0(%arg3: tensor, %arg4: tensor): - "mhlo.return"(%arg4) : (tensor) -> () - }) { - indices_are_sorted = true, - scatter_dimension_numbers = #mhlo.scatter, - unique_indices = true} : (tensor<5x4x1xi32>, tensor<1x2xi32>, tensor<1xi32>) -> tensor<5x4x1xi32> - return %0 : tensor<5x4x1xi32> -} - -// CHECK-LABEL: @scatter_operand_map -// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %arg2 {{\[\[}}0, 1, 2]] : tensor<1xi32> into tensor<1x1x1xi32> -// CHECK: %[[SCATTER:.+]] = "mhlo.scatter"(%arg0, %arg1, %[[EXPAND]]) -// CHECK: indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter< -// CHECK-SAME: update_window_dims = [1, 2], -// CHECK-SAME: inserted_window_dims = [0], -// CHECK-SAME: scatter_dims_to_operand_dims = [0, 2], -// CHECK-SAME: index_vector_dim = 1>, unique_indices = true - -// ----- - -func.func @scatter_update_transpose(%a: tensor<16x17x8x384xf32>, %b: tensor<15x1xi32>, %c: tensor<16x17x15x384xf32>) -> tensor<16x17x8x384xf32> -{ - %out = "mhlo.scatter"(%a, %b, %c) ({ - ^bb0(%arg0: tensor, %arg1: tensor): - %add = mhlo.add %arg0, %arg1 : tensor - mhlo.return %add : tensor - }) {indices_are_sorted = false, - scatter_dimension_numbers = #mhlo.scatter, - unique_indices = false} : (tensor<16x17x8x384xf32>, tensor<15x1xi32>, tensor<16x17x15x384xf32>) -> tensor<16x17x8x384xf32> - 
return %out : tensor<16x17x8x384xf32> -} - -// CHECK-LABEL: @scatter_update_transpose -// CHECK-SAME: %[[ARG0:.+]]: tensor<16x17x8x384xf32> -// CHECK-SAME: %[[ARG1:.+]]: tensor<15x1xi32> -// CHECK-SAME: %[[ARG2:.+]]: tensor<16x17x15x384xf32> -// CHECK: %[[TRANSPOSE:.+]] = "mhlo.transpose"(%[[ARG2]]) {permutation = dense<[2, 0, 1, 3]> : tensor<4xi64>} -// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[TRANSPOSE]] -// CHECK-NEXT{literal}: [[0], [1], [2, 3], [4]] : tensor<15x16x17x384xf32> into tensor<15x16x17x1x384xf32> -// CHECK: %[[SCATTER:.+]] = "mhlo.scatter"(%[[ARG0]], %[[ARG1]], %[[EXPANDED]]) ({ -// CHECK: ^bb0(%[[ARG3:.+]]: tensor, %[[ARG4:.+]]: tensor): -// CHECK: %[[ADD:.+]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor -// CHECK: mhlo.return %[[ADD]] : tensor -// CHECK: }) -// CHECK-SAME: indices_are_sorted = false -// CHECK-SAME: scatter_dimension_numbers = #mhlo.scatter -// CHECK-SAME: unique_indices = false -// CHECK: return %[[SCATTER]] - -// ----- - -func.func @scatter_transpose_indices(%arg0: tensor<1x64x32x640xf32>, %arg1: tensor<1x44xi32>, %arg2: tensor<44x1x64x640xf32>) -> tensor<1x64x32x640xf32> { - %0 = "mhlo.transpose"(%arg1) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x44xi32>) -> tensor<44x1xi32> - %expanded = tensor.expand_shape %arg2 [[0], [1], [2, 3], [4]] : tensor<44x1x64x640xf32> into tensor<44x1x64x1x640xf32> - %1 = "mhlo.scatter"(%arg0, %0, %expanded) ({ - ^bb0(%arg3: tensor, %arg4: tensor): - %2 = mhlo.add %arg3, %arg4 : tensor - mhlo.return %2 : tensor - }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter, unique_indices = false} : (tensor<1x64x32x640xf32>, tensor<44x1xi32>, tensor<44x1x64x1x640xf32>) -> tensor<1x64x32x640xf32> - return %1 : tensor<1x64x32x640xf32> -} - -// CHECK-LABEL: @scatter_transpose_indices -// CHECK-SAME: %[[ARG0:.+]]: tensor<1x64x32x640xf32> -// CHECK-SAME: %[[ARG1:.+]]: tensor<1x44xi32> -// CHECK-SAME: %[[ARG2:.+]]: tensor<44x1x64x640xf32> -// CHECK: %[[TRANSPOSE:.+]] = "mhlo.transpose"(%[[ARG1]]) {permutation = dense<[1, 0]> : tensor<2xi64>} -// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG2]] -// CHECK-SAME{literal}: [[0], [1], [2, 3], [4]] : tensor<44x1x64x640xf32> into tensor<44x1x64x1x640xf32> -// CHECK: %[[SCATTER:.+]] = "mhlo.scatter"(%[[ARG0]], %[[TRANSPOSE]], %[[EXPANDED]]) -// CHECK: ^bb0(%arg3: tensor, %arg4: tensor): -// CHECK: %2 = mhlo.add %arg3, %arg4 : tensor -// CHECK: mhlo.return %2 : tensor -// CHECK: indices_are_sorted = false -// CHECK-SAME: scatter_dimension_numbers = #mhlo.scatter -// CHECK-SAME: unique_indices = false -// CHECK: return %[[SCATTER]] : tensor<1x64x32x640xf32> - -// ----- - -func.func @scatter_i64_indices(%arg0: tensor<5x6x7xi32>, %arg1: tensor<1x2xi64>, %arg2: tensor<1x7xi32>) -> tensor<5x6x7xi32> { - %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({ - ^bb0(%arg3: tensor, %arg4: tensor): - mhlo.return %arg4 : tensor - }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter, unique_indices = true} : (tensor<5x6x7xi32>, tensor<1x2xi64>, tensor<1x7xi32>) -> tensor<5x6x7xi32> - return %0 : tensor<5x6x7xi32> -} - -// CHECK-LABEL: func.func @scatter_i64_indices -// CHECK-SAME: %[[ARG0:.+]]: tensor<5x6x7xi32> -// CHECK-SAME: %[[ARG1:.+]]: tensor<1x2xi64> -// CHECK-SAME: %[[ARG2:.+]]: tensor<1x7xi32> -// CHECK-DAG: %[[CONVERT:.+]] = mhlo.convert %[[ARG1]] : (tensor<1x2xi64>) -> tensor<1x2xi32> -// CHECK: %[[SCATTER:.+]] = "mhlo.scatter"(%[[ARG0]], %[[CONVERT]], %[[ARG2]]) -// CHECK: mhlo.return %{{.*}} -// CHECK: update_window_dims = [1], -// 
CHECK-SAME: inserted_window_dims = [0, 1] -// CHECK-SAME: scatter_dims_to_operand_dims = [0, 1] diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/missing_legalizations.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/missing_legalizations.mlir deleted file mode 100644 index 349af8ee23cc..000000000000 --- a/compiler/src/iree/compiler/InputConversion/MHLO/test/missing_legalizations.mlir +++ /dev/null @@ -1,18 +0,0 @@ -// RUN: iree-opt --split-input-file --iree-mhlo-to-linalg-on-tensors --verify-diagnostics %s - -// Non-numpy compatible broadcast_dimensions are not supported. -// Note: This is by design and support is not planned. -func.func @dynamicNonScalarBroadcastDimensionsSizeMismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // expected-error@+1 {{failed to legalize operation 'chlo.broadcast_add' that was explicitly marked illegal}} - %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- -// Non-numpy compatible broadcast_dimensions are not supported. -// Note: This is by design and support is not planned. -func.func @dynamicNonScalarBroadcastDimensionsMismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // expected-error@+1 {{failed to legalize operation 'chlo.broadcast_add' that was explicitly marked illegal}} - %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/transformation_pipeline.mlir b/compiler/src/iree/compiler/InputConversion/MHLO/test/transformation_pipeline.mlir deleted file mode 100644 index 2a6a8717abc0..000000000000 --- a/compiler/src/iree/compiler/InputConversion/MHLO/test/transformation_pipeline.mlir +++ /dev/null @@ -1,102 +0,0 @@ -// RUN: iree-opt --split-input-file --iree-mhlo-input-transformation-pipeline %s | FileCheck %s - -// CHECK-LABEL: @empty -func.func @empty() { - // CHECK-NEXT: return - return -} - -// ----- - -func.func @mhloElementwiseOps(%arg0 : tensor<4xf32>) -> tensor<4xf32> { - %0 = mhlo.add %arg0, %arg0 : tensor<4xf32> - %1 = mhlo.subtract %0, %arg0 : tensor<4xf32> - %2 = mhlo.multiply %1, %arg0 : tensor<4xf32> - return %2 : tensor<4xf32> -} - -// CHECK: #map = affine_map<(d0) -> (d0)> -// CHECK-NEXT: module { -// CHECK-NEXT: func.func @mhloElementwiseOps(%arg0: tensor<4xf32>) -> tensor<4xf32> { -// CHECK-NEXT: %0 = tensor.empty() : tensor<4xf32> -// CHECK-NEXT: %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<4xf32>) outs(%0 : tensor<4xf32>) { -// CHECK-NEXT: ^bb0(%[[ARG1:.*]]: f32, %out: f32): -// CHECK-NEXT: %6 = arith.addf %[[ARG1]], %[[ARG1]] : f32 -// CHECK-NEXT: linalg.yield %6 : f32 -// CHECK-NEXT: } -> tensor<4xf32> -// CHECK-NEXT: %2 = tensor.empty() : tensor<4xf32> -// CHECK-NEXT: %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%1, %arg0 : tensor<4xf32>, tensor<4xf32>) outs(%2 : tensor<4xf32>) { -// CHECK-NEXT: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %out: f32): -// CHECK-NEXT: %6 = arith.subf %[[ARG1]], %[[ARG2]] : f32 -// CHECK-NEXT: linalg.yield %6 : f32 -// CHECK-NEXT: } -> tensor<4xf32> -// CHECK-NEXT: %4 = tensor.empty() : tensor<4xf32> -// CHECK-NEXT: %5 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} 
ins(%3, %arg0 : tensor<4xf32>, tensor<4xf32>) outs(%4 : tensor<4xf32>) { -// CHECK-NEXT: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %out: f32): -// CHECK-NEXT: %6 = arith.mulf %[[ARG1]], %[[ARG2]] : f32 -// CHECK-NEXT: linalg.yield %6 : f32 -// CHECK-NEXT: } -> tensor<4xf32> -// CHECK-NEXT: return %5 : tensor<4xf32> -// CHECK-NEXT: } -// CHECK-NEXT: } - -// ----- - -func.func @interleavedDot(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> { - %0 = "stablehlo.add"(%arg0, %arg0) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> - %1 = "stablehlo.dot"(%0, %arg0) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> - %2 = "stablehlo.multiply"(%1, %arg0) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> - return %2 : tensor<4x4xf32> -} - -// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-NEXT: module { -// CHECK-NEXT: func.func @interleavedDot(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { -// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 -// CHECK-NEXT: %0 = tensor.empty() : tensor<4x4xf32> -// CHECK-NEXT: %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<4x4xf32>) outs(%0 : tensor<4x4xf32>) { -// CHECK-NEXT: ^bb0(%[[ARG1:.*]]: f32, %out: f32): -// CHECK-NEXT: %7 = arith.addf %[[ARG1]], %[[ARG1]] : f32 -// CHECK-NEXT: linalg.yield %7 : f32 -// CHECK-NEXT: } -> tensor<4x4xf32> -// CHECK-NEXT: %2 = tensor.empty() : tensor<4x4xf32> -// CHECK-NEXT: %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<4x4xf32>) -> tensor<4x4xf32> -// CHECK-NEXT: %4 = linalg.matmul ins(%1, %arg0 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%3 : tensor<4x4xf32>) -> tensor<4x4xf32> -// CHECK-NEXT: %5 = tensor.empty() : tensor<4x4xf32> -// CHECK-NEXT: %6 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%4, %arg0 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%5 : tensor<4x4xf32>) { -// CHECK-NEXT: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %out: f32): -// CHECK-NEXT: %7 = arith.mulf %[[ARG1]], %[[ARG2]] : f32 -// CHECK-NEXT: linalg.yield %7 : f32 -// CHECK-NEXT: } -> tensor<4x4xf32> -// CHECK-NEXT: return %6 : tensor<4x4xf32> -// CHECK-NEXT: } -// CHECK-NEXT: } - - -// ----- - -func.func @reduction(%arg0 : tensor<4x8xf32>) -> tensor<4xf32> { - %0 = arith.constant dense<0.0> : tensor - %1 = "mhlo.reduce"(%arg0, %0) ( { - ^bb0(%arg1 : tensor, %arg2 : tensor): - %2 = mhlo.add %arg1, %arg2 : tensor - "mhlo.return"(%2) : (tensor) -> () - }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<4x8xf32>, tensor) -> tensor<4xf32> - return %1 : tensor<4xf32> -} - -// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-NEXT: #map1 = affine_map<(d0, d1) -> (d0)> -// CHECK-NEXT: module { -// CHECK-NEXT: func.func @reduction(%arg0: tensor<4x8xf32>) -> tensor<4xf32> { -// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 -// CHECK-NEXT: %0 = tensor.empty() : tensor<4xf32> -// CHECK-NEXT: %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<4xf32>) -> tensor<4xf32> -// CHECK-NEXT: %2 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<4x8xf32>) outs(%1 : tensor<4xf32>) { -// CHECK-NEXT: ^bb0(%[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32): -// CHECK-NEXT: %3 = arith.addf %[[ARG2]], %[[ARG1]] : f32 -// CHECK-NEXT: linalg.yield %3 : f32 -// CHECK-NEXT: } -> tensor<4xf32> -// CHECK-NEXT: return %2 : tensor<4xf32> -// CHECK-NEXT: } -// CHECK-NEXT: } diff --git a/compiler/src/iree/compiler/InputConversion/MHLO/test/verify_compiler_mhlo_input_legality.mlir 
b/compiler/src/iree/compiler/InputConversion/MHLO/test/verify_compiler_mhlo_input_legality.mlir deleted file mode 100644 index ffb7998765dd..000000000000 --- a/compiler/src/iree/compiler/InputConversion/MHLO/test/verify_compiler_mhlo_input_legality.mlir +++ /dev/null @@ -1,31 +0,0 @@ -// RUN: iree-opt --split-input-file --iree-mhlo-verify-compiler-input-legality --verify-diagnostics %s - -// expected-error@+1 {{one or more illegal operations were found in the compiler input}} -module { -func.func @illegal_chlo(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // expected-note@+1 {{failed to legalize operation 'chlo.broadcast_add' that was explicitly marked illegal}} - %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %0 : tensor<4xf32> -} -} - -// ----- -// expected-error@+1 {{one or more illegal operations were found in the compiler input}} -module { -func.func @illegal_mhlo(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // expected-note@+1 {{failed to legalize operation 'mhlo.add' that was explicitly marked illegal}} - %0 = mhlo.add %arg0, %arg1 : tensor<4xf32> - return %0 : tensor<4xf32> -} -} - -// ----- -// expected-error@+1 {{one or more illegal operations were found in the compiler input}} -module { -func.func @illegal_shape(%arg0: tensor<*xf32>) -> index { - // expected-note@+1 {{failed to legalize operation 'shape.shape_of' that was explicitly marked illegal}} - %arg_shape = shape.shape_of %arg0 : tensor<*xf32> -> tensor - %rank = shape.rank %arg_shape : tensor -> index - return %rank : index -} -} diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/BUILD.bazel b/compiler/src/iree/compiler/InputConversion/StableHLO/BUILD.bazel index 7070947b57df..ff0e92c1c407 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/BUILD.bazel +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/BUILD.bazel @@ -129,6 +129,9 @@ iree_compiler_cc_library( hdrs = [ "Passes.h", ], + defines = [ + "IREE_HAVE_STABLEHLO_INPUT", + ], deps = [ ":PassHeaders", ":StableHLOLegalization", diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/CMakeLists.txt b/compiler/src/iree/compiler/InputConversion/StableHLO/CMakeLists.txt index 26c54f4c25c4..f23c246c4c64 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/CMakeLists.txt +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/CMakeLists.txt @@ -134,6 +134,8 @@ iree_cc_library( iree::compiler::Dialect::Util::Transforms iree::compiler::InputConversion::Common iree::compiler::InputConversion::StableHLO::Preprocessing + DEFINES + "IREE_HAVE_STABLEHLO_INPUT" PUBLIC ) diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeControlFlow.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeControlFlow.cpp index d41410037c00..c9eaefc2b34d 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeControlFlow.cpp +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/LegalizeControlFlow.cpp @@ -119,7 +119,8 @@ struct WhileOpPattern final : OpConversionPattern { extractTensorValue(rewriter, bounds->step), adaptor.getOperands()); rewriter.setInsertionPointToEnd(newForOp.getBody()); - // Inline while body, and only replace the mhlo.return with an scf.yield. + // Inline while body, and only replace the stablehlo.return with an + // scf.yield. 
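For readers following the WhileOpPattern hunks above: the pass inlines the StableHLO loop body into the SCF region and rewrites only the terminator, using scf.for when loop bounds can be extracted and scf.while otherwise. A rough Python sketch of that control-flow rewrite (a sketch of the shape of the transformation, not the pass implementation):

# General case: stablehlo.while keeps iterating while the condition region
# returns true; the body's stablehlo.return carries the next loop state.
def while_form(state, cond, body):
    while cond(state):
        state = body(state)
    return state

# Bounded case: with extractable lower/upper/step, the same computation is an
# scf.for; the body's stablehlo.return becomes an scf.yield of carried values.
def for_form(state, lower, upper, step, body):
    for _ in range(lower, upper, step):
        state = body(state)
    return state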
inlineStableHloRegionIntoSCFRegion(rewriter, op.getBody(), newForOp.getRegion()); BlockArgument indexArg = newForOp.getRegion().insertArgument( @@ -149,7 +150,8 @@ struct WhileOpPattern final : OpConversionPattern { rewriter.replaceOpWithNewOp( conditionReturn, i1, newWhileOp.getBeforeArguments()); - // Inline while body, and only replace the mhlo.return with an scf.yield. + // Inline while body, and only replace the stablehlo.return with an + // scf.yield. inlineStableHloRegionIntoSCFRegion(rewriter, op.getBody(), newWhileOp.getAfter()); diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/MapStableHLOToScalarOp.h b/compiler/src/iree/compiler/InputConversion/StableHLO/MapStableHLOToScalarOp.h index 2a78aa12fe0f..fb25d72d1435 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/MapStableHLOToScalarOp.h +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/MapStableHLOToScalarOp.h @@ -897,21 +897,6 @@ struct CompareSelectOpToStdScalarOp()) { - Value isnan = b->create(loc, arith::CmpFPredicate::UNO, - args[0], args[1]); - - auto nanApfloat = APFloat::getQNaN(floatType.getFloatSemantics()); - Value nan = getConstantOrSplat(b, loc, args[0].getType(), - b->getFloatAttr(floatType, nanApfloat)); - v = b->create(loc, isnan, nan, v); - } - return v; -} - template <> inline Value mapStableHloOpToStdScalarOp( Location loc, ArrayRef resultTypes, ArrayRef argTypes, @@ -1268,7 +1253,7 @@ inline Value mapStableHloOpToStdScalarOp( } // namespace impl struct StableHloOpToStdScalarOp { - // Converts mhlo 'op' to linalg and arith ops. + // Converts stablehlo 'op' to linalg and arith ops. template static Value mapOp(StableHloOpTy op, ArrayRef resultTypes, ValueRange args, OpBuilder* b) { @@ -1276,8 +1261,8 @@ struct StableHloOpToStdScalarOp { return mapOpWithArgTypes(op, resultTypes, argTypes, args, b); } - // Converts mhlo 'op' to linalg and arith ops. The types of 'args' may already - // be converted, 'argTypes' are their original types. + // Converts stablehlo 'op' to linalg and arith ops. The types of 'args' may + // already be converted, 'argTypes' are their original types. template static Value mapOpWithArgTypes(StableHloOpTy op, ArrayRef resultTypes, ArrayRef argTypes, ValueRange args, @@ -1296,7 +1281,7 @@ struct StableHloOpToStdScalarOp { resultTypes, argTypes, args, b); } - // Converts mhlo 'op' to linalg and arith ops. + // Converts stablehlo 'op' to linalg and arith ops. 
template static Value mapOpOfType(Location loc, ArrayRef resultTypes, ArrayRef argTypes, diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp index 65864bd0a4e4..bc61b2e3aebb 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp @@ -1105,14 +1105,14 @@ struct ExpandRngNormal final : OpRewritePattern { // // Rewrites the following pattern (take binary elementwise op as example) // -// %bcastx = "mhlo.broadcast_in_dim"(%x) {broadcast_dimensions = %[[BCAST_DIMS]]} : (%[[SHAPE_BEFORE_BCAST]]) -> %[[SHAPE_AFTER_BCAST]] -// %bcasty = "mhlo.broadcast_in_dim"(%y) {broadcast_dimensions = %[[BCAST_DIMS]]} : (%[[SHAPE_BEFORE_BCAST]]) -> %[[SHAPE_AFTER_BCAST]] +// %bcastx = "stablehlo.broadcast_in_dim"(%x) {broadcast_dimensions = %[[BCAST_DIMS]]} : (%[[SHAPE_BEFORE_BCAST]]) -> %[[SHAPE_AFTER_BCAST]] +// %bcasty = "stablehlo.broadcast_in_dim"(%y) {broadcast_dimensions = %[[BCAST_DIMS]]} : (%[[SHAPE_BEFORE_BCAST]]) -> %[[SHAPE_AFTER_BCAST]] // %result = "BinaryElementwiseOpT"(%bcastx, %bcasty) : (%[[SHAPE_AFTER_BCAST]], %[[SHAPE_AFTER_BCAST]]) -> %[[SHAPE_AFTER_BCAST]] // // into // // %z = "BinaryElementwiseOpT"(%x, %y) : (%[[SHAPE_BEFORE_BCAST]], %[[SHAPE_BEFORE_BCAST]]) -> %[[SHAPE_BEFORE_BCAST]] -// %result = "mhlo.broadcast_in_dim"(%z) {broadcast_dimensions = %[[BCAST_DIMS]]} : (%[[SHAPE_BEFORE_BCAST]]) -> %[[SHAPE_AFTER_BCAST]] +// %result = "stablehlo.broadcast_in_dim"(%z) {broadcast_dimensions = %[[BCAST_DIMS]]} : (%[[SHAPE_BEFORE_BCAST]]) -> %[[SHAPE_AFTER_BCAST]] // // clang-format on template diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalg.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalg.cpp index 1f852a3035ac..9a03d7a1aed8 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalg.cpp +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalg.cpp @@ -243,7 +243,7 @@ SmallVector getExprFromConfig( return exprs; } -// Convert mhlo.einsum op into linalg.generic. +// Convert stablehlo.einsum op into linalg.generic. // Algorithm in general 3 steps: // Step1) Dissect entire einsum_config to different operands @@ -719,8 +719,8 @@ class BroadcastInDimOpToBroadcastConverter // broadcast and go directly to `linalg.generic`. // This also covers the important case of broadcasting a scalar. Ideally the -// pattern (`mhlo.constant` -> `mhlo.dynamic_broadcast_in_dim`) should be -// converted to a tensor dialect op similar to TF's `ConstantLikeOp`. +// pattern (`stablehlo.constant` -> `stablehlo.dynamic_broadcast_in_dim`) should +// be converted to a tensor dialect op similar to TF's `ConstantLikeOp`. class HloDynamicBroadcastInDimConverter : public OpConversionPattern { public: @@ -1069,7 +1069,7 @@ class BitcastConvertConverter } }; -// Lowers mhlo.RealDynamicSliceOp to tensor.extract_slice and other +// Lowers stablehlo.RealDynamicSliceOp to tensor.extract_slice and other // arith/tensor dialect ops. 
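A note on the broadcast-reordering rewrite documented in the StableHLOToStableHLO.cpp hunk above (and exercised by the reorder_broadcast_in_dim_* tests deleted earlier in this diff): applying the elementwise op before broadcasting computes the result on the small pre-broadcast shape and is mathematically equivalent. A NumPy check of the equivalence, with shapes borrowed from the 3 -> 4x3 atan2 test (illustrative only):

import numpy as np

x = np.random.rand(3).astype(np.float32)
y = np.random.rand(3).astype(np.float32)

# Original form: broadcast both operands, then apply the op on 4x3 elements.
before = np.arctan2(np.broadcast_to(x, (4, 3)), np.broadcast_to(y, (4, 3)))
# Rewritten form: apply the op on 3 elements, then broadcast the result.
after = np.broadcast_to(np.arctan2(x, y), (4, 3))

assert np.array_equal(before, after)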
class RealDynamicSliceConverter
    : public OpConversionPattern {
@@ -1573,7 +1573,7 @@ class DynamicSliceConverter
       int64_t size = std::get<1>(en.value());
       sizes.push_back(rewriter.getI64IntegerAttr(size));
-      // By mhlo.DynamicSlice definition:
+      // By stablehlo.DynamicSlice definition:
       //   `start_indices[i] = clamp(start_indices[i],
       //       0, operand.dimension_size[i] - size_indices[i])`
       Value startIndex = extractIndexFromTensor(
@@ -1640,7 +1640,7 @@ class DynamicUpdateSliceConverter
     SmallVector startIndices;
     Value zero = rewriter.create(loc, 0);
     for (const auto& en : llvm::enumerate(adaptor.getStartIndices())) {
-      // By mhlo.DynamicUpdateSlice definition:
+      // By stablehlo.DynamicUpdateSlice definition:
       //   `start_indices[i] = clamp(start_indices[i],
       //       0, operand.dimension_size[i] - update.dimension_size[i])`
       Value startIndex = extractIndexFromTensor(
@@ -2036,7 +2036,7 @@ struct SelectAndScatterNoOverlapConverter
     }
     // The first linalg.generic operation computes the relevant index over
-    // window for the defined mhlo.select_and_scatter. This involves
+    // window for the defined stablehlo.select_and_scatter. This involves
     // iterating over the window of the operand a computing the index.
     // Rather than storing N indices we compute the row major identifier
     // in the window, to specify which location should be scattered to.
@@ -2116,7 +2116,8 @@ struct SelectAndScatterNoOverlapConverter
     rewriter.cloneRegionBefore(op.getSelect(), reduceRegion,
                                reduceRegion.end());
-    // This includes convert `mhlo` scalar-tensor regions to `linalg` scalars.
+    // This includes converting `stablehlo` scalar-tensor regions to `linalg`
+    // scalars.
     TypeConverter::SignatureConversion reduceSignConverter(4);
     reduceSignConverter.addInputs(0, srcETy);
     reduceSignConverter.addInputs(srcETy);
@@ -2166,7 +2167,7 @@ struct SelectAndScatterNoOverlapConverter
         b.create(selectPred, selectInVal, selectOutVal);
     b.create(ValueRange{selectedValue, selectedIdx});
-    // Original terminator is a stablehlo.return we no longer need.
+    // Original terminator is a stablehlo.return we no longer need.
     rewriter.eraseOp(reduceTerminator);
     b.setInsertionPoint(op);
@@ -2384,7 +2385,7 @@ struct PadOpNegativePaddingConversion
   }
 };
-/// Converts mhlo.pad operation to tensor.pad or tensor.insert_slice.
+/// Converts stablehlo.pad operation to tensor.pad or tensor.insert_slice.
struct PadOpConversion : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgPointwise.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgPointwise.cpp
index fde8bd564a6e..d0326c53bfee 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgPointwise.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgPointwise.cpp
@@ -60,7 +60,7 @@ FailureOr checkOperandsAndResults(
   int64_t maxRank = getMaxRank(operands);
   // Apply only if all operands are scalar or have the same rank. Some ops,
-  // like `mhlo.select`, support implicit broadcasting of scalars.
+  // like `stablehlo.select`, support implicit broadcasting of scalars.
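The `if` that follows implements exactly the rule stated in the comment above; as a plain-Python predicate (an illustrative sketch, not the converter itself):

def applies(operand_ranks):
    # Rank-0 scalars may mix with tensors of the common maximum rank,
    # mirroring the implicit scalar broadcasting of ops like stablehlo.select.
    max_rank = max(operand_ranks)
    return all(r == 0 or r == max_rank for r in operand_ranks)

# select(scalar_pred, lhs_2d, rhs_2d) qualifies; mixing 1-D and 2-D does not.
assert applies([0, 2, 2]) and not applies([1, 2, 2])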
if (!llvm::all_of(operands, [&](Value v) { int64_t r = getRank(v); return r == 0 || r == maxRank; diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgReduce.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgReduce.cpp index cc9d7820aacc..ebb366d43fa9 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgReduce.cpp +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/StableHLOToLinalgReduce.cpp @@ -279,10 +279,10 @@ struct ReduceOpToReduceConverter final // apply function has a signature with tensor types, this is converted to a // function with element types. E.g. the signature "(tensor, // tensor) -> tensor" will be converted to "(f32, f32) -> f32". - // Also, we need to swap the operands of the function. The mhlo.reduce op - // expects the init values to be the first parameters of the apply function, - // while the linalg.reduction op expects the init values as the last - // parameters of the 'combiner' region apply function. + // Also, we need to swap the operands of the function. The stablehlo.reduce + // op expects the init values to be the first parameters of the apply + // function, while the linalg.reduction op expects the init values as the + // last parameters of the 'combiner' region apply function. TypeConverter::SignatureConversion signatureConverter( linalgOp.getNumDpsInputs() * 2); assert(linalgOp.getNumDpsInputs() == linalgOp.getNumDpsInits()); diff --git a/compiler/src/iree/compiler/Pipelines/BUILD.bazel b/compiler/src/iree/compiler/Pipelines/BUILD.bazel index b003ce0adae2..fb04b06d1ac8 100644 --- a/compiler/src/iree/compiler/Pipelines/BUILD.bazel +++ b/compiler/src/iree/compiler/Pipelines/BUILD.bazel @@ -17,7 +17,6 @@ iree_compiler_cc_library( srcs = ["Options.cpp"], hdrs = ["Options.h"], deps = [ - "//compiler/src/iree/compiler/InputConversion/MHLO", "//compiler/src/iree/compiler/InputConversion/StableHLO", "//compiler/src/iree/compiler/InputConversion/TMTensor", "//compiler/src/iree/compiler/InputConversion/TOSA", @@ -49,7 +48,6 @@ iree_compiler_cc_library( "//compiler/src/iree/compiler/Dialect/VM/Transforms", "//compiler/src/iree/compiler/InputConversion/Common", "//compiler/src/iree/compiler/InputConversion/Common:AutoInputConversionPipeline", - "//compiler/src/iree/compiler/InputConversion/MHLO", "//compiler/src/iree/compiler/InputConversion/StableHLO", "//compiler/src/iree/compiler/InputConversion/TMTensor", "//compiler/src/iree/compiler/InputConversion/TOSA", diff --git a/compiler/src/iree/compiler/Pipelines/CMakeLists.txt b/compiler/src/iree/compiler/Pipelines/CMakeLists.txt index 525eaa1ba01c..76440ca649ca 100644 --- a/compiler/src/iree/compiler/Pipelines/CMakeLists.txt +++ b/compiler/src/iree/compiler/Pipelines/CMakeLists.txt @@ -6,8 +6,7 @@ # Enable input dialects based on options. 
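To make the argument-order swap described in the ReduceOpToReduceConverter comment above concrete: stablehlo.reduce passes the init/accumulator values first, the linalg combiner region passes them last, and the signature conversion reorders the block arguments accordingly. A plain-Python analogue (illustrative only; the converter remaps block arguments, not callables):

from functools import reduce

xs = [1.0, 2.0, 3.0]

def stablehlo_combiner(acc, x):  # init/accumulator comes first
    return acc + x

def linalg_combiner(x, acc):     # init/accumulator comes last
    return acc + x

assert (reduce(stablehlo_combiner, xs, 0.0)
        == reduce(lambda acc, x: linalg_combiner(x, acc), xs, 0.0))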
set(IREE_INPUT_DEPS "") -if(IREE_INPUT_MHLO) - list(APPEND IREE_INPUT_DEPS iree::compiler::InputConversion::MHLO) +if(IREE_INPUT_STABLEHLO) list(APPEND IREE_INPUT_DEPS iree::compiler::InputConversion::StableHLO) endif() if(IREE_INPUT_TORCH) diff --git a/compiler/src/iree/compiler/Pipelines/Options.cpp b/compiler/src/iree/compiler/Pipelines/Options.cpp index 904dc31e541c..2326e98d1c48 100644 --- a/compiler/src/iree/compiler/Pipelines/Options.cpp +++ b/compiler/src/iree/compiler/Pipelines/Options.cpp @@ -44,19 +44,14 @@ void InputDialectOptions::bindOptions(OptionsBinder &binder) { clEnumValN(InputDialectOptions::Type::auto_detect, "auto", "Analyze the input program to choose conversion.") // clang-format off -#ifdef IREE_HAVE_MHLO_INPUT +#ifdef IREE_HAVE_STABLEHLO_INPUT , clEnumValN(InputDialectOptions::Type::stablehlo, "stablehlo", "Legalize from StableHLO ops.") , clEnumValN(InputDialectOptions::Type::stablehlo_xla, "stablehlo_xla", "Legalize from StableHLO ops (with XLA cleanup preprocessing). ") - , clEnumValN(InputDialectOptions::Type::mhlo_legacy, "mhlo_legacy", - "Legalize from MHLO ops. (Deprecated.)") - , clEnumValN(InputDialectOptions::Type::xla_legacy, "xla_legacy", - "Legalize from MHLO ops (with XLA cleanup preprocessing). " - "(Deprecated.)") -#endif // IREE_HAVE_MHLO_INPUT +#endif // IREE_HAVE_STABLEHLO_INPUT #ifdef IREE_HAVE_TORCH_INPUT , clEnumValN(InputDialectOptions::Type::tm_tensor, "tm_tensor", "Legalize from TMTensor ops.") @@ -69,7 +64,7 @@ void InputDialectOptions::bindOptions(OptionsBinder &binder) { // clang-format on llvm::cl::cat(category)); -#ifdef IREE_HAVE_MHLO_INPUT +#ifdef IREE_HAVE_STABLEHLO_INPUT binder.opt( "iree-input-demote-i64-to-i32", demoteI64ToI32, llvm::cl::desc("Converts all i64 ops and values into i32 counterparts."), @@ -84,7 +79,7 @@ void InputDialectOptions::bindOptions(OptionsBinder &binder) { "iree-input-promote-bf16-to-f32", promoteBF16ToF32, llvm::cl::desc("Converts all bf16 ops and values into f32 counterparts."), llvm::cl::cat(category)); -#endif +#endif // IREE_HAVE_STABLEHLO_INPUT } void HighLevelOptimizationOptions::bindOptions(OptionsBinder &binder) { diff --git a/compiler/src/iree/compiler/Pipelines/Options.h b/compiler/src/iree/compiler/Pipelines/Options.h index 40b9504f9474..82d5926453e8 100644 --- a/compiler/src/iree/compiler/Pipelines/Options.h +++ b/compiler/src/iree/compiler/Pipelines/Options.h @@ -36,18 +36,13 @@ struct InputDialectOptions { none, // Analyses the input to determine what input dialect pipeline to use. auto_detect, -#ifdef IREE_HAVE_MHLO_INPUT +#ifdef IREE_HAVE_STABLEHLO_INPUT // Legalizes input defined over StableHLO ops. stablehlo, // Special case of 'stablehlo' legalization which also performs some XLA // preprocessing, e.g., flattening of tuples. stablehlo_xla, - // Legalizes input defined over MHLO ops. (Deprecated.) - mhlo_legacy, - // Special case of 'mhlo' legalization which also performs some XLA - // cleanup activities. (Deprecated.) - xla_legacy, -#endif // IREE_HAVE_MHLO_INPUT +#endif // IREE_HAVE_STABLEHLO_INPUT #ifdef IREE_HAVE_TORCH_INPUT // Legalizes input defined over TMTensor ops. 
tm_tensor, diff --git a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp index 2b78d77115e2..39b2eb1dbd60 100644 --- a/compiler/src/iree/compiler/Pipelines/Pipelines.cpp +++ b/compiler/src/iree/compiler/Pipelines/Pipelines.cpp @@ -19,10 +19,9 @@ #include "iree/compiler/Preprocessing/Passes.h" #include "iree/compiler/Utils/TracingUtils.h" -#ifdef IREE_HAVE_MHLO_INPUT -#include "iree/compiler/InputConversion/MHLO/Passes.h" +#ifdef IREE_HAVE_STABLEHLO_INPUT #include "iree/compiler/InputConversion/StableHLO/Passes.h" -#endif // IREE_HAVE_MHLO_INPUT +#endif // IREE_HAVE_STABLEHLO_INPUT #ifdef IREE_HAVE_TORCH_INPUT #include "iree/compiler/InputConversion/TMTensor/Passes.h" #endif // IREE_HAVE_TORCH_INPUT @@ -62,19 +61,20 @@ void buildIREEVMTransformPassPipeline( } AutoInputConversionPipelineOptions autoOptions; -#ifdef IREE_HAVE_MHLO_INPUT +#ifdef IREE_HAVE_STABLEHLO_INPUT stablehlo::StableHloOptions stablehloOptions; stablehloOptions.demoteI64ToI32 = inputOptions.demoteI64ToI32; stablehloOptions.demoteF64ToF32 = inputOptions.demoteF64ToF32; stablehloOptions.promoteBF16ToF32 = inputOptions.promoteBF16ToF32; -#endif +#endif // IREE_HAVE_STABLEHLO_INPUT + switch (inputOptions.type) { case InputDialectOptions::Type::none: break; case InputDialectOptions::Type::auto_detect: passManager.addPass(createAutoInputConversionPipelinePass(autoOptions)); break; -#ifdef IREE_HAVE_MHLO_INPUT +#ifdef IREE_HAVE_STABLEHLO_INPUT case InputDialectOptions::Type::stablehlo: stablehlo::buildStableHLOInputConversionPassPipeline(passManager, stablehloOptions); @@ -83,13 +83,7 @@ void buildIREEVMTransformPassPipeline( stablehlo::buildStableHLOXLAInputConversionPassPipeline(passManager, stablehloOptions); break; - case InputDialectOptions::Type::mhlo_legacy: - MHLO::buildMHLOInputConversionPassPipeline(passManager); - break; - case InputDialectOptions::Type::xla_legacy: - MHLO::buildXLAInputConversionPassPipeline(passManager); - break; -#endif // IREE_HAVE_MHLO_INPUT +#endif // IREE_HAVE_STABLEHLO_INPUT #ifdef IREE_HAVE_TORCH_INPUT case InputDialectOptions::Type::tm_tensor: passManager.addNestedPass( diff --git a/compiler/src/iree/compiler/Preprocessing/Passes.h b/compiler/src/iree/compiler/Preprocessing/Passes.h index 7768e82d3fde..22ed1b582e07 100644 --- a/compiler/src/iree/compiler/Preprocessing/Passes.h +++ b/compiler/src/iree/compiler/Preprocessing/Passes.h @@ -21,7 +21,7 @@ namespace IREE { /// passes specified in textual pass-pipeline format using /// `iree-preprocessing-pass-pipeline`. This allows some user control /// on the sequence of preprocessing passes to run after conversion from input -/// dialects like `mhlo`/`tosa` before running the core IREE compilation +/// dialects like `stablehlo`/`tosa` before running the core IREE compilation /// pipelines (starting with the flow pipeline). 
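For API users, the net effect of the Options.cpp/Pipelines.cpp changes above is that the `mhlo_legacy` and `xla_legacy` input types are gone and StableHLO input is selected via `stablehlo` (or `stablehlo_xla`), as the mnist_train_test.py hunk below reflects. A minimal sketch of the Python entry point (file names and backend are hypothetical placeholders):

from iree.compiler.tools import InputType, compile_file

compile_file(
    input_file="model.mlirbc",       # hypothetical input path
    output_file="model.vmfb",        # hypothetical output path
    target_backends=["llvm-cpu"],    # backend chosen only for illustration
    input_type=InputType.STABLEHLO,
)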
void buildPreprocessingPassPipeline( OpPassManager &passManager, const PreprocessingOptions &options, diff --git a/compiler/src/iree/compiler/Tools/BUILD.bazel b/compiler/src/iree/compiler/Tools/BUILD.bazel index ca74e5ca7856..799372ed115b 100644 --- a/compiler/src/iree/compiler/Tools/BUILD.bazel +++ b/compiler/src/iree/compiler/Tools/BUILD.bazel @@ -37,14 +37,12 @@ iree_compiler_cc_library( ], deps = [ "//compiler/src/iree/compiler/InputConversion/Common", - "//compiler/src/iree/compiler/InputConversion/MHLO", "//compiler/src/iree/compiler/InputConversion/StableHLO", "//compiler/src/iree/compiler/InputConversion/TMTensor", "//compiler/src/iree/compiler/InputConversion/TOSA", "@llvm-project//mlir:ConversionPasses", "@llvm-project//mlir:IR", "@llvm-project//mlir:TosaDialect", - "@mlir-hlo//:mlir_hlo", "@mlir-hlo//stablehlo:chlo_ops", "@mlir-hlo//stablehlo:stablehlo_ops", "@torch-mlir-dialects//:TorchMLIRTMTensorDialect", diff --git a/compiler/src/iree/compiler/Tools/CMakeLists.txt b/compiler/src/iree/compiler/Tools/CMakeLists.txt index 36aa3918b318..fd1a433880cc 100644 --- a/compiler/src/iree/compiler/Tools/CMakeLists.txt +++ b/compiler/src/iree/compiler/Tools/CMakeLists.txt @@ -42,12 +42,9 @@ endif() # Enable input dialects based on options. set(IREE_INPUT_DEPS "") -if(IREE_INPUT_MHLO) - list(APPEND IREE_INPUT_DEPS iree::compiler::InputConversion::MHLO) +if(IREE_INPUT_STABLEHLO) list(APPEND IREE_INPUT_DEPS iree::compiler::InputConversion::StableHLO) - list(APPEND IREE_INPUT_DEPS tensorflow::external_mhlo_includes) list(APPEND IREE_INPUT_DEPS ChloOps) - list(APPEND IREE_INPUT_DEPS MhloDialect) list(APPEND IREE_INPUT_DEPS StablehloOps) endif() if(IREE_INPUT_TORCH) diff --git a/compiler/src/iree/compiler/Tools/init_input_dialects.cc b/compiler/src/iree/compiler/Tools/init_input_dialects.cc index 30aae1938f15..d7a329617cfe 100644 --- a/compiler/src/iree/compiler/Tools/init_input_dialects.cc +++ b/compiler/src/iree/compiler/Tools/init_input_dialects.cc @@ -6,11 +6,10 @@ #include "iree/compiler/Tools/init_input_dialects.h" -#ifdef IREE_HAVE_MHLO_INPUT -#include "mhlo/IR/hlo_ops.h" +#ifdef IREE_HAVE_STABLEHLO_INPUT #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" -#endif // IREE_HAVE_MHLO_INPUT +#endif // IREE_HAVE_STABLEHLO_INPUT #ifdef IREE_HAVE_TORCH_INPUT #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h" #endif @@ -22,10 +21,9 @@ namespace mlir { namespace iree_compiler { void registerInputDialects(DialectRegistry ®istry) { -#ifdef IREE_HAVE_MHLO_INPUT - registry.insert(); -#endif // IREE_HAVE_MHLO_INPUT +#ifdef IREE_HAVE_STABLEHLO_INPUT + registry.insert(); +#endif // IREE_HAVE_STABLEHLO_INPUT #ifdef IREE_HAVE_TORCH_INPUT registry.insert(); #endif // IREE_HAVE_TORCH_INPUT diff --git a/compiler/src/iree/compiler/Tools/init_input_passes.cc b/compiler/src/iree/compiler/Tools/init_input_passes.cc index fe9f41f79cfe..ab2faeca410b 100644 --- a/compiler/src/iree/compiler/Tools/init_input_passes.cc +++ b/compiler/src/iree/compiler/Tools/init_input_passes.cc @@ -8,10 +8,9 @@ #include "iree/compiler/InputConversion/Common/Passes.h" -#ifdef IREE_HAVE_MHLO_INPUT -#include "iree/compiler/InputConversion/MHLO/Passes.h" +#ifdef IREE_HAVE_STABLEHLO_INPUT #include "iree/compiler/InputConversion/StableHLO/Passes.h" -#endif // IREE_HAVE_MHLO_INPUT +#endif // IREE_HAVE_STABLEHLO_INPUT #ifdef IREE_HAVE_TORCH_INPUT #include "iree/compiler/InputConversion/TMTensor/Passes.h" #endif // IREE_HAVE_TORCH_INPUT @@ -27,10 +26,9 @@ namespace iree_compiler { void 
registerInputPasses() { registerCommonInputConversionPasses(); -#ifdef IREE_HAVE_MHLO_INPUT - MHLO::registerMHLOConversionPasses(); +#ifdef IREE_HAVE_STABLEHLO_INPUT stablehlo::registerStableHLOConversionPasses(); -#endif // IREE_HAVE_MHLO_INPUT +#endif // IREE_HAVE_STABLEHLO_INPUT #ifdef IREE_HAVE_TORCH_INPUT TMTensor::registerTMTensorConversionPasses(); #endif diff --git a/tests/e2e/models/mnist_train_test/mnist_train_test.py b/tests/e2e/models/mnist_train_test/mnist_train_test.py index d84c40685b51..578ad090fef5 100644 --- a/tests/e2e/models/mnist_train_test/mnist_train_test.py +++ b/tests/e2e/models/mnist_train_test/mnist_train_test.py @@ -28,7 +28,7 @@ def build_module(artifacts_dir: str): compile_file(input_file=os.path.join(artifacts_dir, "mnist_train.mlirbc"), output_file=vmfb_file, target_backends=[args.target_backend], - input_type=InputType.MHLO_LEGACY) + input_type=InputType.STABLEHLO) return load_vm_flatbuffer_file(vmfb_file, driver=args.driver) diff --git a/tests/e2e/xla_ops/BUILD.bazel b/tests/e2e/xla_ops/BUILD.bazel deleted file mode 100644 index 14c3a9fe333c..000000000000 --- a/tests/e2e/xla_ops/BUILD.bazel +++ /dev/null @@ -1,496 +0,0 @@ -# Copyright 2019 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -# Tests of end-to-end IREE support for individual ops in the XLA HLO dialect. -# Each test file should have a name matching the corresponding XLA HLO op and test only the -# functionality of that op (though may make use of other ops where necessary). Tests should be -# written using the IREE Check framework and should always pass on the reference VMVX backend. -# See https://github.com/openxla/iree/blob/main/docs/developers/developing_iree/testing_guide.md#iree-core-end-to-end-tests. - -load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") -load("//build_tools/bazel:iree_check_test.bzl", "iree_check_single_backend_test_suite") - -package( - features = ["layering_check"], - licenses = ["notice"], # Apache 2.0 -) - -iree_check_single_backend_test_suite( - name = "check_cuda_graph", - srcs = enforce_glob( - # keep sorted - [ - "abs.mlir", - "add.mlir", - "batch_norm_inference.mlir", - "bitcast_convert.mlir", - "broadcast.mlir", - "broadcast_add.mlir", - "broadcast_in_dim.mlir", - "clamp.mlir", - "compare.mlir", - "complex.mlir", - "concatenate.mlir", - "constant.mlir", - "convert.mlir", - "convolution.mlir", - "cosine.mlir", - "divide.mlir", - "dot.mlir", - "dot_bf16.mlir", - "dot_general.mlir", - "dynamic_slice.mlir", - "dynamic_update_slice.mlir", - "exponential.mlir", - "exponential_fp16.mlir", - "exponential_minus_one.mlir", - "finite.mlir", - "floor.mlir", - "gather.mlir", - "iota.mlir", - "log.mlir", - "log_plus_one.mlir", - "maximum.mlir", - "minimum.mlir", - "multiply.mlir", - "negate.mlir", - "pad.mlir", - "pow.mlir", - "reduce.mlir", - "reduce_window.mlir", - "remainder.mlir", - "reshape.mlir", - "reverse.mlir", - "rng_normal.mlir", - "rng_uniform.mlir", - "round.mlir", - "rsqrt.mlir", - "scatter.mlir", - "scatter_dynamic.mlir", - "select.mlir", - "sine.mlir", - "slice.mlir", - "sort.mlir", - "sqrt.mlir", - "subtract.mlir", - "tanh.mlir", - "torch_index_select.mlir", - "transpose.mlir", - "while.mlir", - ], - include = ["*.mlir"], - exclude = [ - "fft.mlir", # TODO(#9583) - ], - ), - compiler_flags = [ - "--iree-input-type=mhlo_legacy", - # TODO(#13984): memset emulation required for graphs. 
- "--iree-stream-emulate-memset", - ], - driver = "cuda", - runner_args = ["--cuda_use_streams=false"], - tags = [ - # CUDA cuInit fails with sanitizer on. - "noasan", - "nomsan", - "notsan", - "noubsan", - "requires-gpu-nvidia", - ], - target_backend = "cuda", -) - -# Run cuda tests using stream command buffer -iree_check_single_backend_test_suite( - name = "check_cuda_streams", - srcs = enforce_glob( - # keep sorted - [ - "abs.mlir", - "add.mlir", - "batch_norm_inference.mlir", - "bitcast_convert.mlir", - "broadcast.mlir", - "broadcast_add.mlir", - "broadcast_in_dim.mlir", - "clamp.mlir", - "compare.mlir", - "complex.mlir", - "concatenate.mlir", - "constant.mlir", - "convert.mlir", - "convolution.mlir", - "cosine.mlir", - "divide.mlir", - "dot.mlir", - "dot_bf16.mlir", - "dot_general.mlir", - "dynamic_slice.mlir", - "dynamic_update_slice.mlir", - "exponential.mlir", - "exponential_fp16.mlir", - "exponential_minus_one.mlir", - "finite.mlir", - "floor.mlir", - "gather.mlir", - "iota.mlir", - "log.mlir", - "log_plus_one.mlir", - "maximum.mlir", - "minimum.mlir", - "multiply.mlir", - "negate.mlir", - "pad.mlir", - "pow.mlir", - "reduce.mlir", - "reduce_window.mlir", - "remainder.mlir", - "reshape.mlir", - "reverse.mlir", - "rng_normal.mlir", - "rng_uniform.mlir", - "round.mlir", - "rsqrt.mlir", - "scatter.mlir", - "scatter_dynamic.mlir", - "select.mlir", - "sine.mlir", - "slice.mlir", - "sort.mlir", - "sqrt.mlir", - "subtract.mlir", - "tanh.mlir", - "torch_index_select.mlir", - "transpose.mlir", - "while.mlir", - ], - include = ["*.mlir"], - exclude = [ - "fft.mlir", # TODO(#9583) - ], - ), - compiler_flags = ["--iree-input-type=mhlo_legacy"], - driver = "cuda", - runner_args = ["--cuda_use_streams=true"], - tags = [ - # CUDA cuInit fails with sanitizer on. 
- "noasan", - "nomsan", - "notsan", - "noubsan", - "requires-gpu-nvidia", - ], - target_backend = "cuda", -) - -iree_check_single_backend_test_suite( - name = "check_llvm-cpu_local-task", - srcs = enforce_glob( - # keep sorted - [ - "abs.mlir", - "add.mlir", - "batch_norm_inference.mlir", - "bitcast_convert.mlir", - "broadcast.mlir", - "broadcast_add.mlir", - "broadcast_in_dim.mlir", - "clamp.mlir", - "compare.mlir", - "complex.mlir", - "concatenate.mlir", - "constant.mlir", - "convert.mlir", - "convolution.mlir", - "cosine.mlir", - "divide.mlir", - "dot.mlir", - "dot_bf16.mlir", - "dot_general.mlir", - "dynamic_slice.mlir", - "dynamic_update_slice.mlir", - "exponential.mlir", - "exponential_fp16.mlir", - "exponential_minus_one.mlir", - "fft.mlir", - "finite.mlir", - "floor.mlir", - "gather.mlir", - "iota.mlir", - "log.mlir", - "log_plus_one.mlir", - "maximum.mlir", - "minimum.mlir", - "multiply.mlir", - "negate.mlir", - "pad.mlir", - "pow.mlir", - "reduce.mlir", - "reduce_window.mlir", - "remainder.mlir", - "reshape.mlir", - "reverse.mlir", - "rng_normal.mlir", - "rng_uniform.mlir", - "round.mlir", - "rsqrt.mlir", - "scatter.mlir", - "scatter_dynamic.mlir", - "select.mlir", - "sine.mlir", - "slice.mlir", - "sort.mlir", - "sqrt.mlir", - "subtract.mlir", - "tanh.mlir", - "torch_index_select.mlir", - "transpose.mlir", - "while.mlir", - ], - include = ["*.mlir"], - ), - compiler_flags = ["--iree-input-type=mhlo_legacy"], - driver = "local-task", - target_backend = "llvm-cpu", -) - -iree_check_single_backend_test_suite( - name = "check_vmvx_local-task", - srcs = enforce_glob( - # keep sorted - [ - "abs.mlir", - "add.mlir", - "batch_norm_inference.mlir", - "bitcast_convert.mlir", - "broadcast.mlir", - "broadcast_add.mlir", - "broadcast_in_dim.mlir", - "clamp.mlir", - "compare.mlir", - "complex.mlir", - "concatenate.mlir", - "constant.mlir", - "convert.mlir", - "convolution.mlir", - "cosine.mlir", - "divide.mlir", - "dot.mlir", - "dot_general.mlir", - "dynamic_slice.mlir", - "dynamic_update_slice.mlir", - "exponential.mlir", - "exponential_minus_one.mlir", - "fft.mlir", - "finite.mlir", - "floor.mlir", - "gather.mlir", - "iota.mlir", - "log.mlir", - "log_plus_one.mlir", - "maximum.mlir", - "minimum.mlir", - "multiply.mlir", - "negate.mlir", - "pad.mlir", - "pow.mlir", - "reduce.mlir", - "reduce_window.mlir", - "remainder.mlir", - "reshape.mlir", - "reverse.mlir", - "rng_normal.mlir", - "rng_uniform.mlir", - "round.mlir", - "rsqrt.mlir", - "scatter.mlir", - "scatter_dynamic.mlir", - "select.mlir", - "sine.mlir", - "slice.mlir", - "sort.mlir", - "sqrt.mlir", - "subtract.mlir", - "tanh.mlir", - "torch_index_select.mlir", - "transpose.mlir", - "while.mlir", - ], - include = ["*.mlir"], - exclude = [ - "dot_bf16.mlir", # Missing BF16 support on VMVX buffer ops - "exponential_fp16.mlir", - ], - ), - compiler_flags = ["--iree-input-type=mhlo_legacy"], - driver = "local-task", - target_backend = "vmvx", -) - -iree_check_single_backend_test_suite( - name = "check_vulkan-spirv_vulkan", - srcs = enforce_glob( - # keep sorted - [ - "abs.mlir", - "add.mlir", - "batch_norm_inference.mlir", - "bitcast_convert.mlir", - "broadcast.mlir", - "broadcast_add.mlir", - "broadcast_in_dim.mlir", - "clamp.mlir", - "compare.mlir", - "complex.mlir", - "concatenate.mlir", - "constant.mlir", - "convert.mlir", - "convolution.mlir", - "cosine.mlir", - "divide.mlir", - "dot.mlir", - "dot_bf16.mlir", - "dot_general.mlir", - "dynamic_slice.mlir", - "dynamic_update_slice.mlir", - "exponential.mlir", - "exponential_minus_one.mlir", 
- "finite.mlir", - "floor.mlir", - "gather.mlir", - "iota.mlir", - "log.mlir", - "log_plus_one.mlir", - "maximum.mlir", - "minimum.mlir", - "multiply.mlir", - "negate.mlir", - "pad.mlir", - "pow.mlir", - "reduce.mlir", - "reduce_window.mlir", - "remainder.mlir", - "reshape.mlir", - "rng_normal.mlir", - "rng_uniform.mlir", - "round.mlir", - "rsqrt.mlir", - "scatter.mlir", - "scatter_dynamic.mlir", - "select.mlir", - "sine.mlir", - "slice.mlir", - "sort.mlir", - "sqrt.mlir", - "subtract.mlir", - "tanh.mlir", - "torch_index_select.mlir", - "transpose.mlir", - "while.mlir", - ], - include = ["*.mlir"], - exclude = [ - "exponential_fp16.mlir", - "fft.mlir", # TODO(#9583) - "reverse.mlir", #TODO(#12415): disabled due to miscompilation on Pixel 6. - ], - ), - compiler_flags = ["--iree-input-type=mhlo_legacy"], - driver = "vulkan", - target_backend = "vulkan-spirv", -) - -# Check host features compilation (LLVM backend with host cpu features). -iree_check_single_backend_test_suite( - name = "check_llvm-cpu-host_local-task", - srcs = enforce_glob( - # keep sorted - [ - "abs.mlir", - "add.mlir", - "batch_norm_inference.mlir", - "bitcast_convert.mlir", - "broadcast.mlir", - "broadcast_add.mlir", - "broadcast_in_dim.mlir", - "clamp.mlir", - "compare.mlir", - "complex.mlir", - "concatenate.mlir", - "constant.mlir", - "convert.mlir", - "convolution.mlir", - "cosine.mlir", - "divide.mlir", - "dot.mlir", - "dot_bf16.mlir", - "dot_general.mlir", - "dynamic_slice.mlir", - "dynamic_update_slice.mlir", - "exponential.mlir", - "exponential_fp16.mlir", - "exponential_minus_one.mlir", - "fft.mlir", - "finite.mlir", - "floor.mlir", - "gather.mlir", - "iota.mlir", - "log.mlir", - "log_plus_one.mlir", - "maximum.mlir", - "minimum.mlir", - "multiply.mlir", - "negate.mlir", - "pad.mlir", - "pow.mlir", - "reduce.mlir", - "reduce_window.mlir", - "remainder.mlir", - "reshape.mlir", - "reverse.mlir", - "rng_normal.mlir", - "rng_uniform.mlir", - "round.mlir", - "rsqrt.mlir", - "scatter.mlir", - "scatter_dynamic.mlir", - "select.mlir", - "sine.mlir", - "slice.mlir", - "sort.mlir", - "sqrt.mlir", - "subtract.mlir", - "tanh.mlir", - "torch_index_select.mlir", - "transpose.mlir", - "while.mlir", - ], - include = ["*.mlir"], - ), - compiler_flags = [ - "--iree-input-type=mhlo_legacy", - "--iree-llvmcpu-target-cpu-features=host", - ], - driver = "local-task", - # Building and testing must be on the same architecture, which doesn't work - # with remote execution in general. - tags = [ - "hostonly", - "local", - ], - target_backend = "llvm-cpu", -) - -test_suite( - name = "check", - tests = [ - ":check_llvm-cpu-host_local-task", - ":check_llvm-cpu_local-task", - ":check_vmvx_local-task", - ":check_vulkan-spirv_vulkan", - ], -) diff --git a/tests/e2e/xla_ops/CMakeLists.txt b/tests/e2e/xla_ops/CMakeLists.txt deleted file mode 100644 index c23e05f37af9..000000000000 --- a/tests/e2e/xla_ops/CMakeLists.txt +++ /dev/null @@ -1,518 +0,0 @@ -################################################################################ -# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # -# tests/e2e/xla_ops/BUILD.bazel # -# # -# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # -# CMake-only content. # -# # -# To disable autogeneration for this file entirely, delete this header. 
# -################################################################################ - -iree_add_all_subdirs() - -iree_check_single_backend_test_suite( - NAME - check_cuda_graph - SRCS - "abs.mlir" - "add.mlir" - "batch_norm_inference.mlir" - "bitcast_convert.mlir" - "broadcast.mlir" - "broadcast_add.mlir" - "broadcast_in_dim.mlir" - "clamp.mlir" - "compare.mlir" - "complex.mlir" - "concatenate.mlir" - "constant.mlir" - "convert.mlir" - "convolution.mlir" - "cosine.mlir" - "divide.mlir" - "dot.mlir" - "dot_bf16.mlir" - "dot_general.mlir" - "dynamic_slice.mlir" - "dynamic_update_slice.mlir" - "exponential.mlir" - "exponential_fp16.mlir" - "exponential_minus_one.mlir" - "finite.mlir" - "floor.mlir" - "gather.mlir" - "iota.mlir" - "log.mlir" - "log_plus_one.mlir" - "maximum.mlir" - "minimum.mlir" - "multiply.mlir" - "negate.mlir" - "pad.mlir" - "pow.mlir" - "reduce.mlir" - "reduce_window.mlir" - "remainder.mlir" - "reshape.mlir" - "reverse.mlir" - "rng_normal.mlir" - "rng_uniform.mlir" - "round.mlir" - "rsqrt.mlir" - "scatter.mlir" - "scatter_dynamic.mlir" - "select.mlir" - "sine.mlir" - "slice.mlir" - "sort.mlir" - "sqrt.mlir" - "subtract.mlir" - "tanh.mlir" - "torch_index_select.mlir" - "transpose.mlir" - "while.mlir" - TARGET_BACKEND - "cuda" - DRIVER - "cuda" - COMPILER_FLAGS - "--iree-input-type=mhlo_legacy" - "--iree-stream-emulate-memset" - RUNNER_ARGS - "--cuda_use_streams=false" - LABELS - "noasan" - "nomsan" - "notsan" - "noubsan" - "requires-gpu-nvidia" -) - -iree_check_single_backend_test_suite( - NAME - check_cuda_streams - SRCS - "abs.mlir" - "add.mlir" - "batch_norm_inference.mlir" - "bitcast_convert.mlir" - "broadcast.mlir" - "broadcast_add.mlir" - "broadcast_in_dim.mlir" - "clamp.mlir" - "compare.mlir" - "complex.mlir" - "concatenate.mlir" - "constant.mlir" - "convert.mlir" - "convolution.mlir" - "cosine.mlir" - "divide.mlir" - "dot.mlir" - "dot_bf16.mlir" - "dot_general.mlir" - "dynamic_slice.mlir" - "dynamic_update_slice.mlir" - "exponential.mlir" - "exponential_fp16.mlir" - "exponential_minus_one.mlir" - "finite.mlir" - "floor.mlir" - "gather.mlir" - "iota.mlir" - "log.mlir" - "log_plus_one.mlir" - "maximum.mlir" - "minimum.mlir" - "multiply.mlir" - "negate.mlir" - "pad.mlir" - "pow.mlir" - "reduce.mlir" - "reduce_window.mlir" - "remainder.mlir" - "reshape.mlir" - "reverse.mlir" - "rng_normal.mlir" - "rng_uniform.mlir" - "round.mlir" - "rsqrt.mlir" - "scatter.mlir" - "scatter_dynamic.mlir" - "select.mlir" - "sine.mlir" - "slice.mlir" - "sort.mlir" - "sqrt.mlir" - "subtract.mlir" - "tanh.mlir" - "torch_index_select.mlir" - "transpose.mlir" - "while.mlir" - TARGET_BACKEND - "cuda" - DRIVER - "cuda" - COMPILER_FLAGS - "--iree-input-type=mhlo_legacy" - RUNNER_ARGS - "--cuda_use_streams=true" - LABELS - "noasan" - "nomsan" - "notsan" - "noubsan" - "requires-gpu-nvidia" -) - -iree_check_single_backend_test_suite( - NAME - check_llvm-cpu_local-task - SRCS - "abs.mlir" - "add.mlir" - "batch_norm_inference.mlir" - "bitcast_convert.mlir" - "broadcast.mlir" - "broadcast_add.mlir" - "broadcast_in_dim.mlir" - "clamp.mlir" - "compare.mlir" - "complex.mlir" - "concatenate.mlir" - "constant.mlir" - "convert.mlir" - "convolution.mlir" - "cosine.mlir" - "divide.mlir" - "dot.mlir" - "dot_bf16.mlir" - "dot_general.mlir" - "dynamic_slice.mlir" - "dynamic_update_slice.mlir" - "exponential.mlir" - "exponential_fp16.mlir" - "exponential_minus_one.mlir" - "fft.mlir" - "finite.mlir" - "floor.mlir" - "gather.mlir" - "iota.mlir" - "log.mlir" - "log_plus_one.mlir" - "maximum.mlir" - "minimum.mlir" - 
"multiply.mlir" - "negate.mlir" - "pad.mlir" - "pow.mlir" - "reduce.mlir" - "reduce_window.mlir" - "remainder.mlir" - "reshape.mlir" - "reverse.mlir" - "rng_normal.mlir" - "rng_uniform.mlir" - "round.mlir" - "rsqrt.mlir" - "scatter.mlir" - "scatter_dynamic.mlir" - "select.mlir" - "sine.mlir" - "slice.mlir" - "sort.mlir" - "sqrt.mlir" - "subtract.mlir" - "tanh.mlir" - "torch_index_select.mlir" - "transpose.mlir" - "while.mlir" - TARGET_BACKEND - "llvm-cpu" - DRIVER - "local-task" - COMPILER_FLAGS - "--iree-input-type=mhlo_legacy" -) - -iree_check_single_backend_test_suite( - NAME - check_vmvx_local-task - SRCS - "abs.mlir" - "add.mlir" - "batch_norm_inference.mlir" - "bitcast_convert.mlir" - "broadcast.mlir" - "broadcast_add.mlir" - "broadcast_in_dim.mlir" - "clamp.mlir" - "compare.mlir" - "complex.mlir" - "concatenate.mlir" - "constant.mlir" - "convert.mlir" - "convolution.mlir" - "cosine.mlir" - "divide.mlir" - "dot.mlir" - "dot_general.mlir" - "dynamic_slice.mlir" - "dynamic_update_slice.mlir" - "exponential.mlir" - "exponential_minus_one.mlir" - "fft.mlir" - "finite.mlir" - "floor.mlir" - "gather.mlir" - "iota.mlir" - "log.mlir" - "log_plus_one.mlir" - "maximum.mlir" - "minimum.mlir" - "multiply.mlir" - "negate.mlir" - "pad.mlir" - "pow.mlir" - "reduce.mlir" - "reduce_window.mlir" - "remainder.mlir" - "reshape.mlir" - "reverse.mlir" - "rng_normal.mlir" - "rng_uniform.mlir" - "round.mlir" - "rsqrt.mlir" - "scatter.mlir" - "scatter_dynamic.mlir" - "select.mlir" - "sine.mlir" - "slice.mlir" - "sort.mlir" - "sqrt.mlir" - "subtract.mlir" - "tanh.mlir" - "torch_index_select.mlir" - "transpose.mlir" - "while.mlir" - TARGET_BACKEND - "vmvx" - DRIVER - "local-task" - COMPILER_FLAGS - "--iree-input-type=mhlo_legacy" -) - -iree_check_single_backend_test_suite( - NAME - check_vulkan-spirv_vulkan - SRCS - "abs.mlir" - "add.mlir" - "batch_norm_inference.mlir" - "bitcast_convert.mlir" - "broadcast.mlir" - "broadcast_add.mlir" - "broadcast_in_dim.mlir" - "clamp.mlir" - "compare.mlir" - "complex.mlir" - "concatenate.mlir" - "constant.mlir" - "convert.mlir" - "convolution.mlir" - "cosine.mlir" - "divide.mlir" - "dot.mlir" - "dot_bf16.mlir" - "dot_general.mlir" - "dynamic_slice.mlir" - "dynamic_update_slice.mlir" - "exponential.mlir" - "exponential_minus_one.mlir" - "finite.mlir" - "floor.mlir" - "gather.mlir" - "iota.mlir" - "log.mlir" - "log_plus_one.mlir" - "maximum.mlir" - "minimum.mlir" - "multiply.mlir" - "negate.mlir" - "pad.mlir" - "pow.mlir" - "reduce.mlir" - "reduce_window.mlir" - "remainder.mlir" - "reshape.mlir" - "rng_normal.mlir" - "rng_uniform.mlir" - "round.mlir" - "rsqrt.mlir" - "scatter.mlir" - "scatter_dynamic.mlir" - "select.mlir" - "sine.mlir" - "slice.mlir" - "sort.mlir" - "sqrt.mlir" - "subtract.mlir" - "tanh.mlir" - "torch_index_select.mlir" - "transpose.mlir" - "while.mlir" - TARGET_BACKEND - "vulkan-spirv" - DRIVER - "vulkan" - COMPILER_FLAGS - "--iree-input-type=mhlo_legacy" -) - -iree_check_single_backend_test_suite( - NAME - check_llvm-cpu-host_local-task - SRCS - "abs.mlir" - "add.mlir" - "batch_norm_inference.mlir" - "bitcast_convert.mlir" - "broadcast.mlir" - "broadcast_add.mlir" - "broadcast_in_dim.mlir" - "clamp.mlir" - "compare.mlir" - "complex.mlir" - "concatenate.mlir" - "constant.mlir" - "convert.mlir" - "convolution.mlir" - "cosine.mlir" - "divide.mlir" - "dot.mlir" - "dot_bf16.mlir" - "dot_general.mlir" - "dynamic_slice.mlir" - "dynamic_update_slice.mlir" - "exponential.mlir" - "exponential_fp16.mlir" - "exponential_minus_one.mlir" - "fft.mlir" - "finite.mlir" - 
"floor.mlir" - "gather.mlir" - "iota.mlir" - "log.mlir" - "log_plus_one.mlir" - "maximum.mlir" - "minimum.mlir" - "multiply.mlir" - "negate.mlir" - "pad.mlir" - "pow.mlir" - "reduce.mlir" - "reduce_window.mlir" - "remainder.mlir" - "reshape.mlir" - "reverse.mlir" - "rng_normal.mlir" - "rng_uniform.mlir" - "round.mlir" - "rsqrt.mlir" - "scatter.mlir" - "scatter_dynamic.mlir" - "select.mlir" - "sine.mlir" - "slice.mlir" - "sort.mlir" - "sqrt.mlir" - "subtract.mlir" - "tanh.mlir" - "torch_index_select.mlir" - "transpose.mlir" - "while.mlir" - TARGET_BACKEND - "llvm-cpu" - DRIVER - "local-task" - COMPILER_FLAGS - "--iree-input-type=mhlo_legacy" - "--iree-llvmcpu-target-cpu-features=host" - LABELS - "hostonly" - "local" -) - -### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### - -iree_check_single_backend_test_suite( - NAME - check_webgpu - SRCS - "abs.mlir" - "add.mlir" - "batch_norm_inference.mlir" - "bitcast_convert.mlir" - "broadcast.mlir" - "broadcast_add.mlir" - "broadcast_in_dim.mlir" - # "clamp.mlir" # TODO(#10906): fix (i8/i16?) - # "compare.mlir" # TODO(#10906): fix (i8/i16?) - # "complex.mlir" # TODO(#11054) - "concatenate.mlir" - "constant.mlir" - # "convert.mlir" # TODO(#10906): fix (i8/i16?) - "convolution.mlir" - "cosine.mlir" - "divide.mlir" - "dot.mlir" - "dot_general.mlir" - "dynamic_slice.mlir" - "dynamic_update_slice.mlir" - "exponential.mlir" - "exponential_fp16.mlir" - "exponential_minus_one.mlir" - # "fft.mlir" # TODO(#9583): fix (fft codegen via spirv) - # "finite.mlir" # TODO(#11321): error: value cannot be represented as 'f32': inf - "floor.mlir" - "gather.mlir" - "iota.mlir" - "log.mlir" - "log_plus_one.mlir" - # "maximum.mlir" # TODO(#10906): fix (i8/i16?) - # "minimum.mlir" # TODO(#10906): fix (i8/i16?) - "multiply.mlir" - "negate.mlir" - "pad.mlir" - "pow.mlir" - "reduce.mlir" - "reduce_window.mlir" - "remainder.mlir" - "reshape.mlir" - "reverse.mlir" - "rng_normal.mlir" - "rng_uniform.mlir" - "round.mlir" - "rsqrt.mlir" - "scatter.mlir" - "scatter_dynamic.mlir" - "select.mlir" - "sine.mlir" - "slice.mlir" - "sort.mlir" - "sqrt.mlir" - "subtract.mlir" - "tanh.mlir" - "torch_index_select.mlir" - "transpose.mlir" - # "while.mlir" # TODO(#12509): WebGPU SPIR-V broken - TARGET_BACKEND - "webgpu" - # Only test compilation for now, the WebGPU driver is not stable/tested yet. 
-  # DRIVER
-  #   "webgpu"
-  COMPILER_FLAGS
-    "--iree-input-type=mhlo_legacy"
-    "--iree-codegen-gpu-native-math-precision=true"  # TODO(#11321): Infer/flip default
-)
diff --git a/tests/e2e/xla_ops/abs.mlir b/tests/e2e/xla_ops/abs.mlir
deleted file mode 100644
index cb8c66c264ed..000000000000
--- a/tests/e2e/xla_ops/abs.mlir
+++ /dev/null
@@ -1,13 +0,0 @@
-func.func @tensor() {
-  %input = util.unfoldable_constant dense<[-1.0, -2.0, 3.0, 4.0]> : tensor<4xf32>
-  %result = "mhlo.abs"(%input) : (tensor<4xf32>) -> tensor<4xf32>
-  check.expect_almost_eq_const(%result, dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>) : tensor<4xf32>
-  return
-}
-
-func.func @scalar() {
-  %input = util.unfoldable_constant dense<-4.0> : tensor<f32>
-  %result = "mhlo.abs"(%input) : (tensor<f32>) -> tensor<f32>
-  check.expect_almost_eq_const(%result, dense<4.0> : tensor<f32>) : tensor<f32>
-  return
-}
diff --git a/tests/e2e/xla_ops/add.mlir b/tests/e2e/xla_ops/add.mlir
deleted file mode 100644
index a69cc861290e..000000000000
--- a/tests/e2e/xla_ops/add.mlir
+++ /dev/null
@@ -1,28 +0,0 @@
-func.func @tensor() {
-  %0 = util.unfoldable_constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
-  %1 = util.unfoldable_constant dense<[5.0, 6.0, 7.0, 8.0]> : tensor<4xf32>
-  %result = "mhlo.add"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
-  check.expect_almost_eq_const(%result, dense<[6.0, 8.0, 10.0, 12.0]> : tensor<4xf32>) : tensor<4xf32>
-  return
-}
-
-func.func @tensor_4d() {
-  %0 = util.unfoldable_constant dense<[[[[1.0, 2.0], [3.0, 4.0]],
-                                        [[5.0, 6.0], [7.0, 8.0]]],
-                                       [[[9.0, 10.0], [11.0, 12.0]],
-                                        [[13.0, 14.0], [15.0, 16.0]]]]> :
-    tensor<2x2x2x2xf32>
-  %1 = util.unfoldable_constant dense<[[[[1.0, 2.0], [3.0, 4.0]],
-                                        [[5.0, 6.0], [7.0, 8.0]]],
-                                       [[[9.0, 10.0], [11.0, 12.0]],
-                                        [[13.0, 14.0], [15.0, 16.0]]]]> :
-    tensor<2x2x2x2xf32>
-  %result = "mhlo.add"(%0, %1) : (tensor<2x2x2x2xf32>, tensor<2x2x2x2xf32>)
-      -> tensor<2x2x2x2xf32>
-  check.expect_almost_eq_const(%result, dense<[[[[2.0, 4.0], [6.0, 8.0]],
-                                                [[10.0, 12.0], [14.0, 16.0]]],
-                                               [[[18.0, 20.0], [22.0, 24.0]],
-                                                [[26.0, 28.0], [30.0, 32.0]]]]> :
-    tensor<2x2x2x2xf32>) : tensor<2x2x2x2xf32>
-  return
-}
diff --git a/tests/e2e/xla_ops/batch_norm_inference.mlir b/tests/e2e/xla_ops/batch_norm_inference.mlir
deleted file mode 100644
index fe569e017736..000000000000
--- a/tests/e2e/xla_ops/batch_norm_inference.mlir
+++ /dev/null
@@ -1,13 +0,0 @@
-func.func @batchnorm_inference_4x2() {
-  %x = util.unfoldable_constant dense<[[1.0, 2.0, 3.0, 4.0],[5.0, 6.0, 7.0, 8.0]]> : tensor<2x4xf32>
-  %mean = util.unfoldable_constant dense<[1.0, 1.0, 1.0, 1.0]> : tensor<4xf32>
-  %var = util.unfoldable_constant dense<[2.0, 2.0, 2.0, 2.0]> : tensor<4xf32>
-  %offset = util.unfoldable_constant dense<[1.0, 1.0, 1.0, 1.0]> : tensor<4xf32>
-  %scale = util.unfoldable_constant dense<[1.0, 1.0, 1.0, 1.0]> : tensor<4xf32>
-  %result = "mhlo.batch_norm_inference"(%x, %mean, %var, %offset, %scale) {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64} : (tensor<2x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<2x4xf32>
-  // TODO(gcmn): This should probably be a fuzzier check with round values.
-  check.expect_almost_eq_const(%result, dense<[
-      [2.0, 2.9995, 3.999, 4.9985],
-      [5.998, 6.9975, 7.997, 8.9965]]> : tensor<2x4xf32>) : tensor<2x4xf32>
-  return
-}
diff --git a/tests/e2e/xla_ops/bitcast_convert.mlir b/tests/e2e/xla_ops/bitcast_convert.mlir
deleted file mode 100644
index 58add394babd..000000000000
--- a/tests/e2e/xla_ops/bitcast_convert.mlir
+++ /dev/null
@@ -1,6 +0,0 @@
-func.func @bitcast() {
-  %input = util.unfoldable_constant dense<0> : tensor<4xi32>
-  %result = "mhlo.bitcast_convert"(%input) : (tensor<4xi32>) -> tensor<4xf32>
-  check.expect_eq_const(%result, dense<0.0> : tensor<4xf32>) : tensor<4xf32>
-  return
-}
diff --git a/tests/e2e/xla_ops/broadcast.mlir b/tests/e2e/xla_ops/broadcast.mlir
deleted file mode 100644
index 1bfafe229cf3..000000000000
--- a/tests/e2e/xla_ops/broadcast.mlir
+++ /dev/null
@@ -1,20 +0,0 @@
-func.func @broadcast_2D_3D() {
-  %input = util.unfoldable_constant dense<[[1, 2, 3, 4],
-                                           [5, 6, 7, 8]]> : tensor<2x4xi32>
-  %result = "mhlo.broadcast"(%input) {broadcast_sizes = dense<3> : tensor<1xi64>} : (tensor<2x4xi32>) -> tensor<3x2x4xi32>
-  check.expect_eq_const(%result, dense<[
-      [[1, 2, 3, 4], [5, 6, 7, 8]],
-      [[1, 2, 3, 4], [5, 6, 7, 8]],
-      [[1, 2, 3, 4], [5, 6, 7, 8]]]> : tensor<3x2x4xi32>) : tensor<3x2x4xi32>
-  return
-}
-
-func.func @broadcast_3D_scalar() {
-  %input = util.unfoldable_constant dense<42> : tensor<i32>
-  %result = "mhlo.broadcast"(%input) {broadcast_sizes = dense<[3, 2, 4]> : tensor<3xi64>} : (tensor<i32>) -> tensor<3x2x4xi32>
-  check.expect_eq_const(%result, dense<[
-      [[42, 42, 42, 42], [42, 42, 42, 42]],
-      [[42, 42, 42, 42], [42, 42, 42, 42]],
-      [[42, 42, 42, 42], [42, 42, 42, 42]]]> : tensor<3x2x4xi32>) : tensor<3x2x4xi32>
-  return
-}
diff --git a/tests/e2e/xla_ops/broadcast_add.mlir b/tests/e2e/xla_ops/broadcast_add.mlir
deleted file mode 100644
index 3649a21da35a..000000000000
--- a/tests/e2e/xla_ops/broadcast_add.mlir
+++ /dev/null
@@ -1,10 +0,0 @@
-func.func @tensor() {
-  %0 = util.unfoldable_constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
-  %1 = util.unfoldable_constant dense<2.0> : tensor<3x4xf32>
-  %result = "chlo.broadcast_add"(%0, %1) : (tensor<4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32>
-  check.expect_almost_eq_const(%result,
-      dense<[[3.0, 4.0, 5.0, 6.0],
-             [3.0, 4.0, 5.0, 6.0],
-             [3.0, 4.0, 5.0, 6.0]]> : tensor<3x4xf32>) : tensor<3x4xf32>
-  return
-}
diff --git a/tests/e2e/xla_ops/broadcast_in_dim.mlir b/tests/e2e/xla_ops/broadcast_in_dim.mlir
deleted file mode 100644
index 75120071696b..000000000000
--- a/tests/e2e/xla_ops/broadcast_in_dim.mlir
+++ /dev/null
@@ -1,17 +0,0 @@
-func.func @broadcast_in_dim_2D_3D() {
-  %input = util.unfoldable_constant dense<[[1, 2, 3, 4],
-                                           [5, 6, 7, 8]]> : tensor<2x4xi32>
-  %res = "mhlo.broadcast_in_dim"(%input) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<2x4xi32>) -> tensor<3x2x4xi32>
-  check.expect_eq_const(%res, dense<[
-      [[1, 2, 3, 4], [5, 6, 7, 8]],
-      [[1, 2, 3, 4], [5, 6, 7, 8]],
-      [[1, 2, 3, 4], [5, 6, 7, 8]]]> : tensor<3x2x4xi32>) : tensor<3x2x4xi32>
-  return
-}
-
-func.func @broadcast_in_dim_3D_scalar() {
-  %input = util.unfoldable_constant dense<42> : tensor<i32>
-  %res = "mhlo.broadcast_in_dim"(%input) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>) -> tensor<3x2x4xi32>
-  check.expect_eq_const(%res, dense<42> : tensor<3x2x4xi32>) : tensor<3x2x4xi32>
-  return
-}
diff --git a/tests/e2e/xla_ops/clamp.mlir b/tests/e2e/xla_ops/clamp.mlir
deleted file mode 100644
index b033b47f083e..000000000000
--- a/tests/e2e/xla_ops/clamp.mlir
+++ /dev/null
@@ -1,35 +0,0 @@
-func.func @i8() {
-  %min = util.unfoldable_constant dense<[0, 0, 0, 0]> : tensor<4xi8>
-  %val = util.unfoldable_constant dense<[-2, 4, 8, 12]> : tensor<4xi8>
-  %max = util.unfoldable_constant dense<[10, 10, 10, 10]> : tensor<4xi8>
-  %result = "mhlo.clamp"(%min, %val, %max) : (tensor<4xi8>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
-  check.expect_eq_const(%result, dense<[0, 4, 8, 10]> : tensor<4xi8>) : tensor<4xi8>
-  return
-}
-
-func.func @i16() {
-  %min = util.unfoldable_constant dense<[0, 0, 0, 0]> : tensor<4xi16>
-  %val = util.unfoldable_constant dense<[-2, 4, 8, 12]> : tensor<4xi16>
-  %max = util.unfoldable_constant dense<[10, 10, 10, 10]> : tensor<4xi16>
-  %result = "mhlo.clamp"(%min, %val, %max) : (tensor<4xi16>, tensor<4xi16>, tensor<4xi16>) -> tensor<4xi16>
-  check.expect_eq_const(%result, dense<[0, 4, 8, 10]> : tensor<4xi16>) : tensor<4xi16>
-  return
-}
-
-func.func @i32() {
-  %min = util.unfoldable_constant dense<[0, 0, 0, 0]> : tensor<4xi32>
-  %val = util.unfoldable_constant dense<[-2, 4, 8, 12]> : tensor<4xi32>
-  %max = util.unfoldable_constant dense<[10, 10, 10, 10]> : tensor<4xi32>
-  %result = "mhlo.clamp"(%min, %val, %max) : (tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
-  check.expect_eq_const(%result, dense<[0, 4, 8, 10]> : tensor<4xi32>) : tensor<4xi32>
-  return
-}
-
-func.func @f32() {
-  %min = util.unfoldable_constant dense<[0.0, 0.0, 0.0, 0.0]> : tensor<4xf32>
-  %val = util.unfoldable_constant dense<[-2.0, 4.0, 8.0, 12.0]> : tensor<4xf32>
-  %max = util.unfoldable_constant dense<[10.0, 10.0, 10.0, 10.0]> : tensor<4xf32>
-  %result = "mhlo.clamp"(%min, %val, %max) : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
-  check.expect_eq_const(%result, dense<[0.0, 4.0, 8.0, 10.0]> : tensor<4xf32>) : tensor<4xf32>
-  return
-}
diff --git a/tests/e2e/xla_ops/compare.mlir b/tests/e2e/xla_ops/compare.mlir
deleted file mode 100644
index b52ad8a7d8e9..000000000000
--- a/tests/e2e/xla_ops/compare.mlir
+++ /dev/null
@@ -1,164 +0,0 @@
-func.func @compare_tensor() {
-  %lhs = util.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32>
-  %rhs = util.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32>
-  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
-  %c0 = util.unfoldable_constant dense<0> : tensor<4xi8>
-  %c1 = util.unfoldable_constant dense<1> : tensor<4xi8>
-  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
-  check.expect_eq_const(%output, dense<[0, 1, 0, 1]> : tensor<4xi8>) : tensor<4xi8>
-  return
-}
-
-func.func @compare_scalar() {
-  %lhs = util.unfoldable_constant dense<1> : tensor<i32>
-  %rhs = util.unfoldable_constant dense<5> : tensor<i32>
-  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
-  %c0 = util.unfoldable_constant dense<0> : tensor<i8>
-  %c1 = util.unfoldable_constant dense<1> : tensor<i8>
-  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
-  check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
-  return
-}
-
-func.func @compare_i8() {
-  %lhs = util.unfoldable_constant dense<1> : tensor<i8>
-  %rhs = util.unfoldable_constant dense<5> : tensor<i8>
-  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<i8>, tensor<i8>) -> tensor<i1>
-  %c0 = util.unfoldable_constant dense<0> : tensor<i8>
-  %c1 = util.unfoldable_constant dense<1> : tensor<i8>
-  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
-  check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
-  return
-}
-
-func.func @compare_i16() {
-  %lhs = util.unfoldable_constant dense<1> : tensor<i16>
-  %rhs = util.unfoldable_constant dense<5> : tensor<i16>
-  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<i16>, tensor<i16>) -> tensor<i1>
-  %c0 = util.unfoldable_constant dense<0> : tensor<i8>
-  %c1 = util.unfoldable_constant dense<1> : tensor<i8>
-  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
-  check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
-  return
-}
-
-func.func @compare_i32() {
-  %lhs = util.unfoldable_constant dense<1> : tensor<i32>
-  %rhs = util.unfoldable_constant dense<5> : tensor<i32>
-  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
-  %c0 = util.unfoldable_constant dense<0> : tensor<i8>
-  %c1 = util.unfoldable_constant dense<1> : tensor<i8>
-  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
-  check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
-  return
-}
-
-func.func @compare_i64() {
-  %lhs = util.unfoldable_constant dense<1> : tensor<i64>
-  %rhs = util.unfoldable_constant dense<5> : tensor<i64>
-  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<i64>, tensor<i64>) -> tensor<i1>
-  %c0 = util.unfoldable_constant dense<0> : tensor<i8>
-  %c1 = util.unfoldable_constant dense<1> : tensor<i8>
-  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
-  check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
-  return
-}
-
-func.func @compare_f32() {
-  %lhs = util.unfoldable_constant dense<1.0> : tensor<f32>
-  %rhs = util.unfoldable_constant dense<5.0> : tensor<f32>
-  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
-  %c0 = util.unfoldable_constant dense<0> : tensor<i8>
-  %c1 = util.unfoldable_constant dense<1> : tensor<i8>
-  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
-  check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
-  return
-}
-
-func.func @compare_f64() {
-  %lhs = util.unfoldable_constant dense<1.0> : tensor<f64>
-  %rhs = util.unfoldable_constant dense<5.0> : tensor<f64>
-  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<f64>, tensor<f64>) -> tensor<i1>
-  %c0 = util.unfoldable_constant dense<0> : tensor<i8>
-  %c1 = util.unfoldable_constant dense<1> : tensor<i8>
-  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<i1>, tensor<i8>, tensor<i8>) -> tensor<i8>
-  check.expect_eq_const(%output, dense<0> : tensor<i8>) : tensor<i8>
-  return
-}
-
-func.func @compare_tensor_odd_length() {
-  %lhs = util.unfoldable_constant dense<[1, 2, 7]> : tensor<3xi32>
-  %rhs = util.unfoldable_constant dense<[5, 2, 3]> : tensor<3xi32>
-  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
-  %c0 = util.unfoldable_constant dense<0> : tensor<3xi8>
-  %c1 = util.unfoldable_constant dense<1> : tensor<3xi8>
-  %output = "mhlo.select"(%result, %c1, %c0) : (tensor<3xi1>, tensor<3xi8>, tensor<3xi8>) -> tensor<3xi8>
-  check.expect_eq_const(%output, dense<[0, 1, 0]> : tensor<3xi8>) : tensor<3xi8>
-  return
-}
-
-func.func @compare_eq() {
-  %lhs = util.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32>
-  %rhs = util.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32>
-  %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
-  %c0 = util.unfoldable_constant dense<0> : tensor<4xi8>
-  %c1 = util.unfoldable_constant dense<1> : tensor<4xi8>
"mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8> - check.expect_eq_const(%output, dense<[0, 1, 0, 1]> : tensor<4xi8>) : tensor<4xi8> - return -} - -func.func @compare_ne() { - %lhs = util.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32> - %rhs = util.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32> - %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = #mhlo} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> - %c0 = util.unfoldable_constant dense<0> : tensor<4xi8> - %c1 = util.unfoldable_constant dense<1> : tensor<4xi8> - %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8> - check.expect_eq_const(%output, dense<[1, 0, 1, 0]> : tensor<4xi8>) : tensor<4xi8> - return -} - -func.func @compare_lt() { - %lhs = util.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32> - %rhs = util.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32> - %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = #mhlo} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> - %c0 = util.unfoldable_constant dense<0> : tensor<4xi8> - %c1 = util.unfoldable_constant dense<1> : tensor<4xi8> - %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8> - check.expect_eq_const(%output, dense<[1, 0, 0, 0]> : tensor<4xi8>) : tensor<4xi8> - return -} - -func.func @compare_le() { - %lhs = util.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32> - %rhs = util.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32> - %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = #mhlo} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> - %c0 = util.unfoldable_constant dense<0> : tensor<4xi8> - %c1 = util.unfoldable_constant dense<1> : tensor<4xi8> - %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8> - check.expect_eq_const(%output, dense<[1, 1, 0, 1]> : tensor<4xi8>) : tensor<4xi8> - return -} - -func.func @compare_gt() { - %lhs = util.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32> - %rhs = util.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32> - %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = #mhlo} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> - %c0 = util.unfoldable_constant dense<0> : tensor<4xi8> - %c1 = util.unfoldable_constant dense<1> : tensor<4xi8> - %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8> - check.expect_eq_const(%output, dense<[0, 0, 1, 0]> : tensor<4xi8>) : tensor<4xi8> - return -} - -func.func @compare_ge() { - %lhs = util.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32> - %rhs = util.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32> - %result = "mhlo.compare"(%lhs, %rhs) {comparison_direction = #mhlo} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> - %c0 = util.unfoldable_constant dense<0> : tensor<4xi8> - %c1 = util.unfoldable_constant dense<1> : tensor<4xi8> - %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8> - check.expect_eq_const(%output, dense<[0, 1, 1, 1]> : tensor<4xi8>) : tensor<4xi8> - return -} diff --git a/tests/e2e/xla_ops/complex.mlir b/tests/e2e/xla_ops/complex.mlir deleted file mode 100644 index 63898c5dd0e2..000000000000 --- a/tests/e2e/xla_ops/complex.mlir +++ /dev/null @@ -1,23 +0,0 @@ -func.func @math_sin() { - %real = util.unfoldable_constant dense<[0., 1., 1., -1.]> : tensor<4xf32> - %imag = util.unfoldable_constant 
diff --git a/tests/e2e/xla_ops/complex.mlir b/tests/e2e/xla_ops/complex.mlir
deleted file mode 100644
index 63898c5dd0e2..000000000000
--- a/tests/e2e/xla_ops/complex.mlir
+++ /dev/null
@@ -1,23 +0,0 @@
-func.func @math_sin() {
-  %real = util.unfoldable_constant dense<[0., 1., 1., -1.]> : tensor<4xf32>
-  %imag = util.unfoldable_constant dense<[0., 1., -1., 1.]> : tensor<4xf32>
-  %complex = "mhlo.complex"(%real, %imag) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
-  %result = "mhlo.sine"(%complex) : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
-  %result_real = "mhlo.real"(%result) : (tensor<4xcomplex<f32>>) -> tensor<4xf32>
-  %result_imag = "mhlo.imag"(%result) : (tensor<4xcomplex<f32>>) -> tensor<4xf32>
-  check.expect_almost_eq_const(%result_real, dense<[0., 1.29846, 1.29846, -1.29846]> : tensor<4xf32>) : tensor<4xf32>
-  check.expect_almost_eq_const(%result_imag, dense<[0., 0.634964, -0.634964, 0.634964]> : tensor<4xf32>) : tensor<4xf32>
-  return
-}
-
-func.func @math_exp() {
-  %real = util.unfoldable_constant dense<[0., 1., 1., -1.]> : tensor<4xf32>
-  %imag = util.unfoldable_constant dense<[0., 1., -1., 1.]> : tensor<4xf32>
-  %complex = "mhlo.complex"(%real, %imag) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
-  %result = "mhlo.exponential"(%complex) : (tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>>
-  %result_real = "mhlo.real"(%result) : (tensor<4xcomplex<f32>>) -> tensor<4xf32>
-  %result_imag = "mhlo.imag"(%result) : (tensor<4xcomplex<f32>>) -> tensor<4xf32>
-  check.expect_almost_eq_const(%result_real, dense<[1., 1.46869, 1.46869, 0.19876]> : tensor<4xf32>) : tensor<4xf32>
-  check.expect_almost_eq_const(%result_imag, dense<[0., 2.28735, -2.28735, 0.30956]> : tensor<4xf32>) : tensor<4xf32>
-  return
-}
diff --git a/tests/e2e/xla_ops/concatenate.mlir b/tests/e2e/xla_ops/concatenate.mlir
deleted file mode 100644
index 2c0b82e82adb..000000000000
--- a/tests/e2e/xla_ops/concatenate.mlir
+++ /dev/null
@@ -1,26 +0,0 @@
-func.func @xla_concatenate() {
-  %c0 = util.unfoldable_constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
-  %c1 = util.unfoldable_constant dense<[[5, 6, 7], [8, 9, 10]]> : tensor<2x3xi32>
-  %c2 = util.unfoldable_constant dense<[[11, 12], [13, 14]]> : tensor<2x2xi32>
-
-  %0 = "mhlo.concatenate"(%c0, %c1) {dimension = 1} : (tensor<2x2xi32>, tensor<2x3xi32>) -> tensor<2x5xi32>
-  check.expect_eq_const(%0, dense<[[1, 2, 5, 6, 7], [3, 4, 8, 9, 10]]> : tensor<2x5xi32>) : tensor<2x5xi32>
-
-  %1 = "mhlo.concatenate"(%c1, %c0) {dimension = 1} : (tensor<2x3xi32>, tensor<2x2xi32>) -> tensor<2x5xi32>
-  check.expect_eq_const(%1, dense<[[5, 6, 7, 1, 2], [8, 9, 10, 3, 4]]> : tensor<2x5xi32>) : tensor<2x5xi32>
-
-  %2 = "mhlo.concatenate"(%c0, %c1, %c2) {dimension = 1} : (tensor<2x2xi32>, tensor<2x3xi32>, tensor<2x2xi32>) -> tensor<2x7xi32>
-  check.expect_eq_const(%2, dense<[[1, 2, 5, 6, 7, 11, 12], [3, 4, 8, 9, 10, 13, 14]]> : tensor<2x7xi32>) : tensor<2x7xi32>
-
-  %3 = "mhlo.concatenate"(%c0, %c2) {dimension = 0} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<4x2xi32>
-  check.expect_eq_const(%3, dense<[[1, 2], [3, 4], [11, 12], [13, 14]]> : tensor<4x2xi32>) : tensor<4x2xi32>
-  return
-}
-
-func.func @concatenate_cst() {
-  %c0 = util.unfoldable_constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
-  %c1 = mhlo.constant dense<0> : tensor<2x3xi32>
-  %0 = "mhlo.concatenate"(%c0, %c1) {dimension = 1} : (tensor<2x2xi32>, tensor<2x3xi32>) -> tensor<2x5xi32>
-  check.expect_eq_const(%0, dense<[[1, 2, 0, 0, 0], [3, 4, 0, 0, 0]]> : tensor<2x5xi32>) : tensor<2x5xi32>
-  return
-}
diff --git a/tests/e2e/xla_ops/constant.mlir b/tests/e2e/xla_ops/constant.mlir
deleted file mode 100644
index 2bea25c2895b..000000000000
--- a/tests/e2e/xla_ops/constant.mlir
+++ /dev/null
@@ -1,26 +0,0 @@
-func.func @high_rank () {
-  %dense = mhlo.constant dense<[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]> : tensor<2x2x3xi32>
-  check.expect_eq_const(%dense, dense<[[[1, 2,
3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]> : tensor<2x2x3xi32>) : tensor<2x2x3xi32> - - // %splat = mhlo.constant dense<1> : tensor<2x2x3xi32> - // check.expect_eq_const(%splat, dense<1> : tensor<2x2x3xi32>) : tensor<2x2x3xi32> - return -} - -// func.func @i8() { -// %c = mhlo.constant dense<[1, 2]> : tensor<2xi8> -// check.expect_eq_const(%c, dense<[1, 2]> : tensor<2xi8>) : tensor<2xi8> -// return -// } - -// func.func @i32 () { -// %c = mhlo.constant dense<[1, 2]> : tensor<2xi32> -// check.expect_eq_const(%c, dense<[1, 2]> : tensor<2xi32>) : tensor<2xi32> -// return -// } - -// func.func @f32 () { -// %c = mhlo.constant dense<[1.1, 2.1]> : tensor<2xf32> -// check.expect_almost_eq_const(%c, dense<[1.1, 2.1]> : tensor<2xf32>) : tensor<2xf32> -// return -// } diff --git a/tests/e2e/xla_ops/convert.mlir b/tests/e2e/xla_ops/convert.mlir deleted file mode 100644 index c75fe28c38a4..000000000000 --- a/tests/e2e/xla_ops/convert.mlir +++ /dev/null @@ -1,61 +0,0 @@ -func.func @narrow_int_i32_i8() { - %input = util.unfoldable_constant dense<[-42, 0, 42]> : tensor<3xi32> - %res = "mhlo.convert"(%input) : (tensor<3xi32>) -> tensor<3xi8> - check.expect_eq_const(%res, dense<[-42, 0, 42]> : tensor<3xi8>) : tensor<3xi8> - return -} - -func.func @widen_int_i8_i32() { - %input = util.unfoldable_constant dense<[-42, 0, 42]> : tensor<3xi8> - %res = "mhlo.convert"(%input) : (tensor<3xi8>) -> tensor<3xi32> - check.expect_eq_const(%res, dense<[-42, 0, 42]> : tensor<3xi32>) : tensor<3xi32> - return -} - -func.func @narrow_int_i32_i16() { - %input = util.unfoldable_constant dense<[-42, 0, 42]> : tensor<3xi32> - %res = "mhlo.convert"(%input) : (tensor<3xi32>) -> tensor<3xi16> - check.expect_eq_const(%res, dense<[-42, 0, 42]> : tensor<3xi16>) : tensor<3xi16> - return -} - -func.func @widen_int_i16_i32() { - %input = util.unfoldable_constant dense<[-42, 0, 42]> : tensor<3xi16> - %res = "mhlo.convert"(%input) : (tensor<3xi16>) -> tensor<3xi32> - check.expect_eq_const(%res, dense<[-42, 0, 42]> : tensor<3xi32>) : tensor<3xi32> - return -} - -func.func @narrow_int_i64_i32() { - %input = util.unfoldable_constant dense<[-42, 0, 42]> : tensor<3xi64> - %res = "mhlo.convert"(%input) : (tensor<3xi64>) -> tensor<3xi32> - check.expect_eq_const(%res, dense<[-42, 0, 42]> : tensor<3xi32>) : tensor<3xi32> - return -} - -func.func @widen_int_i32_i64() { - %input = util.unfoldable_constant dense<[-42, 0, 42]> : tensor<3xi32> - %res = "mhlo.convert"(%input) : (tensor<3xi32>) -> tensor<3xi64> - check.expect_eq_const(%res, dense<[-42, 0, 42]> : tensor<3xi64>) : tensor<3xi64> - return -} - -func.func @int_to_float() { - %input = util.unfoldable_constant dense<[-42, 0, 42]> : tensor<3xi32> - %res = "mhlo.convert"(%input) : (tensor<3xi32>) -> tensor<3xf32> - check.expect_almost_eq_const(%res, dense<[-42.0, 0.0, 42.0]> : tensor<3xf32>) : tensor<3xf32> - return -} - -// TODO(#6160): XLA does not specify the rounding behavior, meaning that we -// can't test something like -10.5 as that could be -11 (roundf) or -10 (rint -// with round-to-even mode). 
-// -// For casting rules, see -// https://www.tensorflow.org/xla/operation_semantics#convertelementtype -// func.func @float_to_int() { -// %input = util.unfoldable_constant dense<[-10.5, -4.4, 4.4, 10.5]> : tensor<4xf32> -// %res = "mhlo.convert"(%input) : (tensor<4xf32>) -> tensor<4xi32> -// check.expect_eq_const(%res, dense<[-10, -4, 4, 10]> : tensor<4xi32>) : tensor<4xi32> -// return -// } diff --git a/tests/e2e/xla_ops/convolution.mlir b/tests/e2e/xla_ops/convolution.mlir deleted file mode 100644 index 8b294dccdee4..000000000000 --- a/tests/e2e/xla_ops/convolution.mlir +++ /dev/null @@ -1,434 +0,0 @@ -func.func @conv2d_nopadding() { - %inputs = util.unfoldable_constant dense<[[ - [[ 1.0, 2.0], [ 3.0, 4.0], [ 5.0, 6.0], [ 7.0, 8.0]], - [[11.0, 12.0], [13.0, 14.0], [15.0, 16.0], [17.0, 18.0]], - [[21.0, 22.0], [23.0, 24.0], [25.0, 26.0], [27.0, 28.0]], - [[31.0, 32.0], [33.0, 34.0], [35.0, 36.0], [37.0, 38.0]]]]> : tensor<1x4x4x2xf32> - %weights = util.unfoldable_constant dense<[ - [[[ 1.0], [ 2.0]], [[ 3.0], [ 4.0]]], - [[[ 5.0], [ 6.0]], [[ 7.0], [ 8.0]]], - [[[ 9.0], [10.0]], [[11.0], [12.0]]]]> : tensor<3x2x2x1xf32> - %res = "mhlo.convolution"(%inputs, %weights) { - batch_group_count = 1 : i64, - dimension_numbers = #mhlo.conv, - feature_group_count = 1 : i64, - rhs_dilation = dense<1> : tensor<2xi64>, - window_strides = dense<1> : tensor<2xi64>} : (tensor<1x4x4x2xf32>, tensor<3x2x2x1xf32>) -> tensor<1x2x3x1xf32> - check.expect_almost_eq_const(%res, dense<[[ - [[1310.0],[1466.0],[1622.0]], - [[2090.0],[2246.0],[2402.0]] - ]]> : tensor<1x2x3x1xf32>) : tensor<1x2x3x1xf32> - return -} - -func.func @conv2d_nopadding_batch_feature() { - %inputs = util.unfoldable_constant dense<[ - [[[ 1.0], [ 3.0], [ 5.0], [ 7.0]], - [[11.0], [13.0], [15.0], [17.0]], - [[21.0], [23.0], [25.0], [27.0]], - [[31.0], [33.0], [35.0], [37.0]]], - [[[ 2.0], [ 4.0], [ 6.0], [ 8.0]], - [[12.0], [14.0], [16.0], [18.0]], - [[22.0], [24.0], [26.0], [28.0]], - [[32.0], [34.0], [36.0], [38.0]]] - ]> : tensor<2x4x4x1xf32> - %weights = util.unfoldable_constant dense<[ - [[[ 1.0], [ 2.0]], [[ 3.0], [ 4.0]]], - [[[ 5.0], [ 6.0]], [[ 7.0], [ 8.0]]], - [[[ 9.0], [10.0]], [[11.0], [12.0]]]]> : tensor<3x2x2x1xf32> - %res = "mhlo.convolution"(%inputs, %weights) { - batch_group_count = 1 : i64, - dimension_numbers = #mhlo.conv, - feature_group_count = 1 : i64, - rhs_dilation = dense<1> : tensor<2xi64>, - window_strides = dense<1> : tensor<2xi64>} : (tensor<2x4x4x1xf32>, tensor<3x2x2x1xf32>) -> tensor<1x2x3x1xf32> - check.expect_almost_eq_const(%res, dense<[[ - [[1310.0],[1466.0],[1622.0]], - [[2090.0],[2246.0],[2402.0]] - ]]> : tensor<1x2x3x1xf32>) : tensor<1x2x3x1xf32> - return -} - -func.func @conv2d_reorder_input_spatial() { - %inputs = util.unfoldable_constant dense<[[ - [[ 1.0, 2.0], [11.0, 12.0], [21.0, 22.0], [31.0, 32.0]], - [[ 3.0, 4.0], [13.0, 14.0], [23.0, 24.0], [33.0, 34.0]], - [[ 5.0, 6.0], [15.0, 16.0], [25.0, 26.0], [35.0, 36.0]], - [[ 7.0, 8.0], [17.0, 18.0], [27.0, 28.0], [37.0, 38.0]]]]> : tensor<1x4x4x2xf32> - %weights = util.unfoldable_constant dense<[ - [[[ 1.0], [ 2.0]], [[ 3.0], [ 4.0]]], - [[[ 5.0], [ 6.0]], [[ 7.0], [ 8.0]]], - [[[ 9.0], [10.0]], [[11.0], [12.0]]]]> : tensor<3x2x2x1xf32> - %res = "mhlo.convolution"(%inputs, %weights) { - batch_group_count = 1 : i64, - dimension_numbers = #mhlo.conv, - feature_group_count = 1 : i64, - rhs_dilation = dense<1> : tensor<2xi64>, - window_strides = dense<1> : tensor<2xi64>} : (tensor<1x4x4x2xf32>, tensor<3x2x2x1xf32>) -> tensor<1x2x3x1xf32> - 
check.expect_almost_eq_const(%res, dense<[[ - [[1310.0],[1466.0],[1622.0]], - [[2090.0],[2246.0],[2402.0]] - ]]> : tensor<1x2x3x1xf32>) : tensor<1x2x3x1xf32> - return -} - -func.func @conv2d_reorder_kernel() { - %inputs = util.unfoldable_constant dense<[[ - [[ 1.0, 2.0], [ 3.0, 4.0], [ 5.0, 6.0], [ 7.0, 8.0]], - [[11.0, 12.0], [13.0, 14.0], [15.0, 16.0], [17.0, 18.0]], - [[21.0, 22.0], [23.0, 24.0], [25.0, 26.0], [27.0, 28.0]], - [[31.0, 32.0], [33.0, 34.0], [35.0, 36.0], [37.0, 38.0]]]]> : tensor<1x4x4x2xf32> - %weights = util.unfoldable_constant dense< - [[[[ 1.0, 3.0], [ 2.0, 4.0]], - [[ 5.0, 7.0], [ 6.0, 8.0]], - [[ 9.0, 11.0], [10.0, 12.0]]]]> : tensor<1x3x2x2xf32> - %res = "mhlo.convolution"(%inputs, %weights) { - batch_group_count = 1 : i64, - dimension_numbers = #mhlo.conv, - feature_group_count = 1 : i64, - rhs_dilation = dense<1> : tensor<2xi64>, - window_strides = dense<1> : tensor<2xi64>} : (tensor<1x4x4x2xf32>, tensor<1x3x2x2xf32>) -> tensor<1x2x3x1xf32> - check.expect_almost_eq_const(%res, dense<[[ - [[1310.0],[1466.0],[1622.0]], - [[2090.0],[2246.0],[2402.0]] - ]]> : tensor<1x2x3x1xf32>) : tensor<1x2x3x1xf32> - return -} - -func.func @conv2d_reorder_output() { - %inputs = util.unfoldable_constant dense<[[ - [[ 1.0, 2.0], [ 3.0, 4.0], [ 5.0, 6.0], [ 7.0, 8.0]], - [[11.0, 12.0], [13.0, 14.0], [15.0, 16.0], [17.0, 18.0]], - [[21.0, 22.0], [23.0, 24.0], [25.0, 26.0], [27.0, 28.0]], - [[31.0, 32.0], [33.0, 34.0], [35.0, 36.0], [37.0, 38.0]]]]> : tensor<1x4x4x2xf32> - %weights = util.unfoldable_constant dense<[ - [[[ 1.0], [ 2.0]], [[ 3.0], [ 4.0]]], - [[[ 5.0], [ 6.0]], [[ 7.0], [ 8.0]]], - [[[ 9.0], [10.0]], [[11.0], [12.0]]]]> : tensor<3x2x2x1xf32> - %res = "mhlo.convolution"(%inputs, %weights) { - batch_group_count = 1 : i64, - dimension_numbers = #mhlo.conv, - feature_group_count = 1 : i64, - rhs_dilation = dense<1> : tensor<2xi64>, - window_strides = dense<1> : tensor<2xi64>} : (tensor<1x4x4x2xf32>, tensor<3x2x2x1xf32>) -> tensor<1x3x1x2xf32> - check.expect_almost_eq_const(%res, dense<[[ - [[1310.0, 2090.0]], - [[1466.0, 2246.0]], - [[1622.0, 2402.0]] - ]]> : tensor<1x3x1x2xf32>) : tensor<1x3x1x2xf32> - return -} - -func.func @conv2d_1452x3221_same() { - %inputs = util.unfoldable_constant dense<[[ - [[ 1.0, 2.0], [ 3.0, 4.0], [ 5.0, 6.0], [ 7.0, 8.0], [ 9.0, 10.0]], - [[11.0, 12.0], [13.0, 14.0], [15.0, 16.0], [17.0, 18.0], [19.0, 20.0]], - [[21.0, 22.0], [23.0, 24.0], [25.0, 26.0], [27.0, 28.0], [29.0, 30.0]], - [[31.0, 32.0], [33.0, 34.0], [35.0, 36.0], [37.0, 38.0], [39.0, 40.0]]]]> : tensor<1x4x5x2xf32> - %weights = util.unfoldable_constant dense<[ - [[[ 1.0], [ 2.0]], [[ 3.0], [ 4.0]]], - [[[ 5.0], [ 6.0]], [[ 7.0], [ 8.0]]], - [[[ 9.0], [10.0]], [[11.0], [12.0]]]]> : tensor<3x2x2x1xf32> - %res = "mhlo.convolution"(%inputs, %weights) { - batch_group_count = 1 : i64, - dimension_numbers = #mhlo.conv, - feature_group_count = 1 : i64, - padding = dense<[[1, 1], [0, 1]]> : tensor<2x2xi64>, - rhs_dilation = dense<1> : tensor<2xi64>, - window_strides = dense<1> : tensor<2xi64>} : - (tensor<1x4x5x2xf32>, tensor<3x2x2x1xf32>) -> tensor<1x4x5x1xf32> - check.expect_almost_eq_const(%res, dense<[[ - [[ 600.0], [ 736.0], [ 872.0], [1008.0], [ 476.0]], - [[1310.0], [1466.0], [1622.0], [1778.0], [ 805.0]], - [[2090.0], [2246.0], [2402.0], [2558.0], [1135.0]], - [[1080.0], [1152.0], [1224.0], [1296.0], [ 524.0]]]]> : tensor<1x4x5x1xf32>) : tensor<1x4x5x1xf32> - return -} - -func.func @conv2d_2451x2311_same() { - %inputs = util.unfoldable_constant dense<[ - [[[ 1.0], [ 2.0], [ 3.0], [ 
4.0], [ 5.0]], - [[ 6.0], [ 7.0], [ 8.0], [ 9.0], [10.0]], - [[11.0], [12.0], [13.0], [14.0], [15.0]], - [[16.0], [17.0], [18.0], [19.0], [20.0]]], - [[[21.0], [22.0], [23.0], [24.0], [25.0]], - [[26.0], [27.0], [28.0], [29.0], [30.0]], - [[31.0], [32.0], [33.0], [34.0], [35.0]], - [[36.0], [37.0], [38.0], [39.0], [40.0]]]]> : tensor <2x4x5x1xf32> - %weights = util.unfoldable_constant dense<[ - [[[1.0]], [[2.0]], [[3.0]]], - [[[4.0]], [[5.0]], [[6.0]]]]> : tensor <2x3x1x1xf32> - %res = "mhlo.convolution"(%inputs, %weights) { - batch_group_count = 1 : i64, - dimension_numbers = #mhlo.conv, - feature_group_count = 1 : i64, - padding = dense<[[0, 1], [1, 1]]> : tensor<2x2xi64>, - rhs_dilation = dense<1> : tensor<2xi64>, - window_strides = dense<1> : tensor<2xi64>} : - (tensor<2x4x5x1xf32>, tensor<2x3x1x1xf32>) -> tensor<2x4x5x1xf32> - check.expect_almost_eq_const(%res, dense<[ - [[[ 80.0], [121.0], [142.0], [163.0], [100.0]], - [[160.0], [226.0], [247.0], [268.0], [160.0]], - [[240.0], [331.0], [352.0], [373.0], [220.0]], - [[ 83.0], [104.0], [110.0], [116.0], [ 59.0]]], - [[[400.0], [541.0], [562.0], [583.0], [340.0]], - [[480.0], [646.0], [667.0], [688.0], [400.0]], - [[560.0], [751.0], [772.0], [793.0], [460.0]], - [[183.0], [224.0], [230.0], [236.0], [119.0]]]]> : tensor<2x4x5x1xf32>) : tensor<2x4x5x1xf32> - return -} - -func.func @conv2d_no_padding2() { - %inputs = util.unfoldable_constant dense<[ - [[[ 1.0, 2.0, 3.0], - [ 4.0, 5.0, 6.0], - [ 7.0, 8.0, 9.0], - [ 10.0, 11.0, 12.0], - [ 13.0, 14.0, 15.0]], - [[ 16.0, 17.0, 18.0], - [ 19.0, 20.0, 21.0], - [ 22.0, 23.0, 24.0], - [ 25.0, 26.0, 27.0], - [ 28.0, 29.0, 30.0]], - [[ 31.0, 32.0, 33.0], - [ 34.0, 35.0, 36.0], - [ 37.0, 38.0, 39.0], - [ 40.0, 41.0, 42.0], - [ 43.0, 44.0, 45.0]], - [[ 46.0, 47.0, 48.0], - [ 49.0, 50.0, 51.0], - [ 52.0, 53.0, 54.0], - [ 55.0, 56.0, 57.0], - [ 58.0, 59.0, 60.0]]], - [[[ 61.0, 62.0, 63.0], - [ 64.0, 65.0, 66.0], - [ 67.0, 68.0, 69.0], - [ 70.0, 71.0, 72.0], - [ 73.0, 74.0, 75.0]], - [[ 76.0, 77.0, 78.0], - [ 79.0, 80.0, 81.0], - [ 82.0, 83.0, 84.0], - [ 85.0, 86.0, 87.0], - [ 88.0, 89.0, 90.0]], - [[ 91.0, 92.0, 93.0], - [ 94.0, 95.0, 96.0], - [ 97.0, 98.0, 99.0], - [100.0, 101.0, 102.0], - [103.0, 104.0, 105.0]], - [[106.0, 107.0, 108.0], - [109.0, 110.0, 111.0], - [112.0, 113.0, 114.0], - [115.0, 116.0, 117.0], - [118.0, 119.0, 120.0]]]]> : tensor<2x4x5x3xf32> - %weights = util.unfoldable_constant dense<[ - [[[ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0], - [ 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], - [ 13.0, 14.0, 15.0, 16.0, 17.0, 18.0]], - [[ 19.0, 20.0, 21.0, 22.0, 23.0, 24.0], - [ 25.0, 26.0, 27.0, 28.0, 29.0, 30.0], - [ 31.0, 32.0, 33.0, 34.0, 35.0, 36.0]], - [[ 37.0, 38.0, 39.0, 40.0, 41.0, 42.0], - [ 43.0, 44.0, 45.0, 46.0, 47.0, 48.0], - [ 49.0, 50.0, 51.0, 52.0, 53.0, 54.0]]], - [[[ 55.0, 56.0, 57.0, 58.0, 59.0, 60.0], - [ 61.0, 62.0, 63.0, 64.0, 65.0, 66.0], - [ 67.0, 68.0, 69.0, 70.0, 71.0, 72.0]], - [[ 73.0, 74.0, 75.0, 76.0, 77.0, 78.0], - [ 79.0, 80.0, 81.0, 82.0, 83.0, 84.0], - [ 85.0, 86.0, 87.0, 88.0, 89.0, 90.0]], - [[ 91.0, 92.0, 93.0, 94.0, 95.0, 96.0], - [ 97.0, 98.0, 99.0, 100.0, 101.0, 102.0], - [103.0, 104.0, 105.0, 106.0, 107.0, 108.0]]]]> : tensor<2x3x3x6xf32> - %res = "mhlo.convolution"(%inputs, %weights) { - batch_group_count = 1 : i64, - dimension_numbers = #mhlo.conv, - feature_group_count = 1 : i64, - rhs_dilation = dense<1> : tensor<2xi64>, - window_strides = dense<1> : tensor<2xi64>} : - (tensor<2x4x5x3xf32>, tensor<2x3x3x6xf32>) -> tensor<2x3x3x6xf32> - check.expect_almost_eq_const(%res, 
dense<[ - [[[16065.0, 16290.0, 16515.0, 16740.0, 16965.0, 17190.0], - [18873.0, 19152.0, 19431.0, 19710.0, 19989.0, 20268.0], - [21681.0, 22014.0, 22347.0, 22680.0, 23013.0, 23346.0]], - [[30105.0, 30600.0, 31095.0, 31590.0, 32085.0, 32580.0], - [32913.0, 33462.0, 34011.0, 34560.0, 35109.0, 35658.0], - [35721.0, 36324.0, 36927.0, 37530.0, 38133.0, 38736.0]], - [[44145.0, 44910.0, 45675.0, 46440.0, 47205.0, 47970.0], - [46953.0, 47772.0, 48591.0, 49410.0, 50229.0, 51048.0], - [49761.0, 50634.0, 51507.0, 52380.0, 53253.0, 54126.0]]], - [[[72225.0, 73530.0, 74835.0, 76140.0, 77445.0, 78750.0], - [75033.0, 76392.0, 77751.0, 79110.0, 80469.0, 81828.0], - [77841.0, 79254.0, 80667.0, 82080.0, 83493.0, 84906.0]], - [[86265.0, 87840.0, 89415.0, 90990.0, 92565.0, 94140.0], - [89073.0, 90702.0, 92331.0, 93960.0, 95589.0, 97218.0], - [91881.0, 93564.0, 95247.0, 96930.0, 98613.0, 100296.0]], - [[100305.0, 102150.0, 103995.0, 105840.0, 107685.0, 109530.0], - [103113.0, 105012.0, 106911.0, 108810.0, 110709.0, 112608.0], - [105921.0, 107874.0, 109827.0, 111780.0, 113733.0, 115686.0]]]]> : tensor<2x3x3x6xf32>) : tensor<2x3x3x6xf32> - return -} - -func.func @conv2d_1452x2223_dilated_valid() { - %inputs = util.unfoldable_constant dense< - [[[[0.09762701, 0.43037874], - [ 0.20552675, 0.08976637], - [-0.1526904, 0.29178822], - [-0.12482557, 0.78354603], - [ 0.92732555, -0.23311697]], - [[ 0.5834501, 0.05778984], - [ 0.13608912, 0.85119325], - [-0.85792786, -0.8257414 ], - [-0.9595632, 0.6652397 ], - [ 0.5563135, 0.74002427]], - [[ 0.9572367, 0.59831715], - [-0.07704128, 0.56105834], - [-0.76345116, 0.27984205], - [-0.71329343, 0.88933784], - [ 0.04369664, -0.17067613]], - [[-0.47088876, 0.5484674 ], - [-0.08769934, 0.1368679 ], - [-0.9624204, 0.23527099], - [ 0.22419144, 0.23386799], - [ 0.8874962, 0.3636406 ]]]]> : tensor<1x4x5x2xf32> - %weights = util.unfoldable_constant dense< - [[[[-0.2809842, -0.12593609, 0.3952624 ], - [-0.8795491, 0.33353344, 0.34127575]], - [[-0.5792349, -0.7421474, -0.3691433 ], - [-0.27257845, 0.14039354, -0.12279698]]], - [[[ 0.9767477, -0.79591036, -0.5822465 ], - [-0.677381, 0.30621666, -0.4934168 ]], - [[-0.06737845, -0.5111488, -0.68206084], - [-0.7792497, 0.31265917, -0.7236341 ]]]]> : tensor<2x2x2x3xf32> - %res = "mhlo.convolution"(%inputs, %weights) { - batch_group_count = 1 : i64, - dimension_numbers = #mhlo.conv, - feature_group_count = 1 : i64, - padding = dense<0> : tensor<2x2xi64>, - rhs_dilation = dense<[2, 1]> : tensor<2xi64>, - window_strides = dense<1> : tensor<2xi64> - } : (tensor<1x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<1x2x4x3xf32> - check.expect_almost_eq_const(%res, dense< - [[[[-0.45181108, -0.37253797, -1.1074474 ], - [-0.74972206, 0.8691965, 0.21864426], - [-1.9352274, 1.6551838, 0.13848126], - [-2.296763, 0.32046723, -0.02542188]], - [[-1.4578199, 0.59465677, 0.0599021 ], - [-0.3617443, 1.4647548, 1.2320882 ], - [ 0.04506956, 1.4347346, -0.22625303], - [-1.122044, -0.41301775, -1.5628793 ]]]]> : tensor<1x2x4x3xf32>) : tensor<1x2x4x3xf32> - return -} - -func.func @depthwise_conv_non_1_channel_multiplier() { - %arg0 = util.unfoldable_constant dense<1.0> : tensor<2x4x5x2xf32> - %arg1 = util.unfoldable_constant dense<1.0> : tensor<2x2x1x6xf32> - %res = "mhlo.convolution"(%arg0, %arg1) { - batch_group_count = 1 : i64, - dimension_numbers = #mhlo.conv, - feature_group_count = 2 : i64, - padding = dense<0> : tensor<2x2xi64>, - rhs_dilation = dense<1> : tensor<2xi64>, - window_strides = dense<1> : tensor<2xi64>} : (tensor<2x4x5x2xf32>, tensor<2x2x1x6xf32>) -> 
tensor<2x3x4x6xf32> - check.expect_almost_eq_const(%res, dense<4.0> : tensor<2x3x4x6xf32>) : tensor<2x3x4x6xf32> - return -} diff --git a/tests/e2e/xla_ops/cosine.mlir b/tests/e2e/xla_ops/cosine.mlir deleted file mode 100644 index 16684e1f5fcb..000000000000 --- a/tests/e2e/xla_ops/cosine.mlir +++ /dev/null @@ -1,13 +0,0 @@ -func.func @tensor() { - %input = util.unfoldable_constant dense<[0.0, 1.0, 1.5, 2.0]> : tensor<4xf32> - %result = "mhlo.cosine"(%input) : (tensor<4xf32>) -> tensor<4xf32> - check.expect_almost_eq_const(%result, dense<[1.0, 0.5403, 0.0707, -0.4161]> : tensor<4xf32>) : tensor<4xf32> - return -} - -func.func @scalar() { - %input = util.unfoldable_constant dense<3.0> : tensor - %result = "mhlo.cosine"(%input) : (tensor) -> tensor - check.expect_almost_eq_const(%result, dense<-0.99> : tensor) : tensor - return -} diff --git a/tests/e2e/xla_ops/divide.mlir b/tests/e2e/xla_ops/divide.mlir deleted file mode 100644 index 3a99d869b1d8..000000000000 --- a/tests/e2e/xla_ops/divide.mlir +++ /dev/null @@ -1,15 +0,0 @@ -func.func @i32() { - %0 = util.unfoldable_constant dense<[5, 6, 7, 8]> : tensor<4xi32> - %1 = util.unfoldable_constant dense<[1, 2, 3, 4]> : tensor<4xi32> - %result = "mhlo.divide"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - check.expect_eq_const(%result, dense<[5, 3, 2, 2]> : tensor<4xi32>) : tensor<4xi32> - return -} - -func.func @f32() { - %0 = util.unfoldable_constant dense<[5.0, 6.0, 7.0, 8.0]> : tensor<4xf32> - %1 = util.unfoldable_constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32> - %result = "mhlo.divide"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - check.expect_almost_eq_const(%result, dense<[5.0, 3.0, 2.333333, 2.0]> : tensor<4xf32>) : tensor<4xf32> - return -} diff --git a/tests/e2e/xla_ops/dot.mlir b/tests/e2e/xla_ops/dot.mlir deleted file mode 100644 index 11bdd5c1af03..000000000000 --- a/tests/e2e/xla_ops/dot.mlir +++ /dev/null @@ -1,68 +0,0 @@ -func.func @f32() { - %lhs = util.unfoldable_constant dense<[ - [15.0, 14.0, 13.0], - [12.0, 11.0, 10.0], - [09.0, 08.0, 07.0], - [06.0, 05.0, 04.0], - [03.0, 02.0, 01.0]]> : tensor<5x3xf32> - %rhs = util.unfoldable_constant dense<[ - [15.0, 14.0, 13.0, 12.0, 11.0], - [10.0, 09.0, 08.0, 07.0, 06.0], - [05.0, 04.0, 03.0, 02.0, 01.0]]> : tensor<3x5xf32> - %res = "mhlo.dot"(%lhs, %rhs) : (tensor<5x3xf32>, tensor<3x5xf32>) -> tensor<5x5xf32> - check.expect_almost_eq_const(%res, dense<[ - [430.0, 388.0, 346.0, 304.0, 262.0], - [340.0, 307.0, 274.0, 241.0, 208.0], - [250.0, 226.0, 202.0, 178.0, 154.0], - [160.0, 145.0, 130.0, 115.0, 100.0], - [70.0, 64.0, 58.0, 52.0, 46.0]]> : tensor<5x5xf32>) : tensor<5x5xf32> - return -} - -func.func @i32i32.i32() { - %lhs = util.unfoldable_constant dense<3> : tensor<2x4xi32> - %rhs = util.unfoldable_constant dense<2> : tensor<4x2xi32> - %res = "mhlo.dot"(%lhs, %rhs) : (tensor<2x4xi32>, tensor<4x2xi32>) -> tensor<2x2xi32> - check.expect_eq_const(%res, dense<24> : tensor<2x2xi32>) : tensor<2x2xi32> - return -} - -func.func @i8i8.i32() { - %lhs = util.unfoldable_constant dense<3> : tensor<2x4xi8> - %rhs = util.unfoldable_constant dense<2> : tensor<4x2xi8> - %res = "mhlo.dot"(%lhs, %rhs) : (tensor<2x4xi8>, tensor<4x2xi8>) -> tensor<2x2xi32> - check.expect_eq_const(%res, dense<24> : tensor<2x2xi32>) : tensor<2x2xi32> - return -} - -func.func @i16i16.i32() { - %lhs = util.unfoldable_constant dense<3> : tensor<2x4xi16> - %rhs = util.unfoldable_constant dense<2> : tensor<4x2xi16> - %res = "mhlo.dot"(%lhs, %rhs) : (tensor<2x4xi16>, tensor<4x2xi16>) -> 
tensor<2x2xi32> - check.expect_eq_const(%res, dense<24> : tensor<2x2xi32>) : tensor<2x2xi32> - return -} - -func.func @large() { - %lhs = util.unfoldable_constant dense<1.0> : tensor<15x16xf32> - %rhs = util.unfoldable_constant dense<0.4> : tensor<16x17xf32> - %res = "mhlo.dot"(%lhs, %rhs) : (tensor<15x16xf32>, tensor<16x17xf32>) -> tensor<15x17xf32> - check.expect_almost_eq_const(%res, dense<6.4> : tensor<15x17xf32>) : tensor<15x17xf32> - return -} - -func.func @matvec() { - %lhs = util.unfoldable_constant dense<1.0> : tensor<15x32xf32> - %rhs = util.unfoldable_constant dense<0.5> : tensor<32xf32> - %res = "mhlo.dot"(%lhs, %rhs) : (tensor<15x32xf32>, tensor<32xf32>) -> tensor<15xf32> - check.expect_almost_eq_const(%res, dense<16.0> : tensor<15xf32>) : tensor<15xf32> - return -} - -func.func @dot() { - %lhs = util.unfoldable_constant dense<1.0> : tensor<1024xf32> - %rhs = util.unfoldable_constant dense<0.5> : tensor<1024xf32> - %res = "mhlo.dot"(%lhs, %rhs) : (tensor<1024xf32>, tensor<1024xf32>) -> tensor - check.expect_almost_eq_const(%res, dense<512.0> : tensor) : tensor - return -} diff --git a/tests/e2e/xla_ops/dot_bf16.mlir b/tests/e2e/xla_ops/dot_bf16.mlir deleted file mode 100644 index aa5a5121364b..000000000000 --- a/tests/e2e/xla_ops/dot_bf16.mlir +++ /dev/null @@ -1,20 +0,0 @@ -func.func @bf16() { - %lhs = util.unfoldable_constant dense<[ - [15.0, 14.0, 13.0], - [12.0, 11.0, 10.0], - [09.0, 08.0, 07.0], - [06.0, 05.0, 04.0], - [03.0, 02.0, 01.0]]> : tensor<5x3xbf16> - %rhs = util.unfoldable_constant dense<[ - [15.0, 14.0, 13.0, 12.0, 11.0], - [10.0, 09.0, 08.0, 07.0, 06.0], - [05.0, 04.0, 03.0, 02.0, 01.0]]> : tensor<3x5xbf16> - %res = "mhlo.dot"(%lhs, %rhs) : (tensor<5x3xbf16>, tensor<3x5xbf16>) -> tensor<5x5xf32> - check.expect_almost_eq_const(%res, dense<[ - [430.0, 388.0, 346.0, 304.0, 262.0], - [340.0, 307.0, 274.0, 241.0, 208.0], - [250.0, 226.0, 202.0, 178.0, 154.0], - [160.0, 145.0, 130.0, 115.0, 100.0], - [70.0, 64.0, 58.0, 52.0, 46.0]]> : tensor<5x5xf32>) : tensor<5x5xf32> - return -} diff --git a/tests/e2e/xla_ops/dot_general.mlir b/tests/e2e/xla_ops/dot_general.mlir deleted file mode 100644 index ccfad7c9d68e..000000000000 --- a/tests/e2e/xla_ops/dot_general.mlir +++ /dev/null @@ -1,157 +0,0 @@ -func.func @dot_general_lower() { - %lhs = util.unfoldable_constant dense<[[[0.3, 0.5]]]> : tensor<1x1x2xf32> - %rhs = util.unfoldable_constant dense<[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]> : tensor<2x3xf32> - %res = "mhlo.dot_general"(%lhs, %rhs) { - dot_dimension_numbers = #mhlo.dot< - lhs_batching_dimensions = [], - lhs_contracting_dimensions = [2], - rhs_batching_dimensions = [], - rhs_contracting_dimensions = [0], - >, - precision_config = [#mhlo, #mhlo] - } : (tensor<1x1x2xf32>, tensor<2x3xf32>) -> tensor<1x1x3xf32> - check.expect_almost_eq_const(%res, dense<[[[0.23, 0.31, 0.39]]]> : tensor<1x1x3xf32>) : tensor<1x1x3xf32> - return -} - -func.func @dot_general_lower_swapped() { - %lhs = util.unfoldable_constant dense<[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]> : tensor<2x3xf32> - %rhs = util.unfoldable_constant dense<[[[0.3, 0.5]]]> : tensor<1x1x2xf32> - %res = "mhlo.dot_general"(%lhs, %rhs) { - dot_dimension_numbers = #mhlo.dot< - lhs_batching_dimensions = [], - lhs_contracting_dimensions = [0], - rhs_batching_dimensions = [], - rhs_contracting_dimensions = [2], - >, - precision_config = [#mhlo, #mhlo] - } : (tensor<2x3xf32>, tensor<1x1x2xf32>) -> tensor<3x1x1xf32> - check.expect_almost_eq_const(%res, dense<[[[0.23]],[[0.31]],[[0.39]]]> : tensor<3x1x1xf32>) : tensor<3x1x1xf32> - 
return -} - -func.func @dot_general_trivial_batching_dimension() { - %lhs = util.unfoldable_constant dense<[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]]> : tensor<1x2x3xf32> - %rhs = util.unfoldable_constant dense<[[ - [1.0, 2.0, 3.0, 4.0], - [1.0, 2.0, 3.0, 4.0], - [1.0, 2.0, 3.0, 4.0]]]> : tensor<1x3x4xf32> - %res = "mhlo.dot_general"(%lhs, %rhs) { - dot_dimension_numbers = #mhlo.dot< - lhs_batching_dimensions = [0], - lhs_contracting_dimensions = [2], - rhs_batching_dimensions = [0], - rhs_contracting_dimensions = [1], - >, - precision_config = [#mhlo, #mhlo] - } : (tensor<1x2x3xf32>, tensor<1x3x4xf32>) -> tensor<1x2x4xf32> - check.expect_almost_eq_const(%res, dense<[[[0.6, 1.2, 1.8, 2.4],[1.5, 3.0, 4.5, 6.0]]]> : tensor<1x2x4xf32>) : tensor<1x2x4xf32> - return -} - -func.func @dot_general_matmul() { - %lhs = util.unfoldable_constant dense<3.0> : tensor<2x4xf32> - %rhs = util.unfoldable_constant dense<2.0> : tensor<4x2xf32> - %res = "mhlo.dot_general"(%lhs, %rhs) { - dot_dimension_numbers = #mhlo.dot< - lhs_batching_dimensions = [], - lhs_contracting_dimensions = [1], - rhs_batching_dimensions = [], - rhs_contracting_dimensions = [0], - > - } : (tensor<2x4xf32>, tensor<4x2xf32>) -> tensor<2x2xf32> - check.expect_eq_const(%res, dense<24.0> : tensor<2x2xf32>) : tensor<2x2xf32> - return -} - -func.func @dot_general_matmul_i32.i32.i32() { - %lhs = util.unfoldable_constant dense<3> : tensor<2x4xi32> - %rhs = util.unfoldable_constant dense<2> : tensor<4x2xi32> - %res = "mhlo.dot_general"(%lhs, %rhs) { - dot_dimension_numbers = #mhlo.dot< - lhs_batching_dimensions = [], - lhs_contracting_dimensions = [1], - rhs_batching_dimensions = [], - rhs_contracting_dimensions = [0], - > - } : (tensor<2x4xi32>, tensor<4x2xi32>) -> tensor<2x2xi32> - check.expect_eq_const(%res, dense<24> : tensor<2x2xi32>) : tensor<2x2xi32> - return -} - -func.func @dot_general_nontrivial_batching_dimension() { - %lhs = util.unfoldable_constant dense<[ - [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]> : tensor<2x2x3xf32> - %rhs = util.unfoldable_constant dense<[[ - [1.0, 2.0, 3.0, 4.0], - [1.0, 2.0, 3.0, 4.0], - [1.0, 2.0, 3.0, 4.0] - ], [ - [1.0, 2.0, 3.0, 4.0], - [1.0, 2.0, 3.0, 4.0], - [1.0, 2.0, 3.0, 4.0]]]> : tensor<2x3x4xf32> - %res = "mhlo.dot_general"(%lhs, %rhs) { - dot_dimension_numbers = #mhlo.dot< - lhs_batching_dimensions = [0], - lhs_contracting_dimensions = [2], - rhs_batching_dimensions = [0], - rhs_contracting_dimensions = [1], - >, - precision_config = [#mhlo, #mhlo] - } : (tensor<2x2x3xf32>, tensor<2x3x4xf32>) -> tensor<2x2x4xf32> - check.expect_almost_eq_const(%res, dense<[ - [ - [0.6, 1.2, 1.8, 2.4], - [1.5, 3.0, 4.5, 6.0] - ], [ - [6.0, 12.0, 18.0, 24.0], - [15.0, 30.0, 45.0, 60.0]]]> : tensor<2x2x4xf32>) : tensor<2x2x4xf32> - return -} - -func.func @large_dot_general() { - %lhs = util.unfoldable_constant dense<1.0> : tensor<4x8x128xf32> - %rhs = util.unfoldable_constant dense<0.4> : tensor<4x128x16xf32> - %res = "mhlo.dot_general"(%lhs, %rhs) { - dot_dimension_numbers = #mhlo.dot< - lhs_batching_dimensions = [0], - lhs_contracting_dimensions = [2], - rhs_batching_dimensions = [0], - rhs_contracting_dimensions = [1], - >, - precision_config = [#mhlo, #mhlo] - } : (tensor<4x8x128xf32>, tensor<4x128x16xf32>) -> tensor<4x8x16xf32> - check.expect_almost_eq_const(%res, dense<51.2> : tensor<4x8x16xf32>) : tensor<4x8x16xf32> - return -} - -func.func @dot_general_nontrivial_batching_mutliple_parallel_dimension() { - %lhs = util.unfoldable_constant dense<[ - [[[0.0], [1.0]], [[2.0], [3.0]], [[ 
4.0], [ 5.0]]], - [[[6.0], [7.0]], [[8.0], [9.0]], [[10.0], [11.0]]] - ]> : tensor<2x3x2x1xf32> - %rhs = util.unfoldable_constant dense<[ - [[0.0], [1.0]], [[2.0], [3.0]] - ]> : tensor<2x2x1xf32> - %res = "mhlo.dot_general"(%lhs, %rhs) { - dot_dimension_numbers = #mhlo.dot< - lhs_batching_dimensions = [2], - rhs_batching_dimensions = [1], - lhs_contracting_dimensions = [3], - rhs_contracting_dimensions = [2] - >, - precision_config = [#mhlo, #mhlo] - } : (tensor<2x3x2x1xf32>, tensor<2x2x1xf32>) -> tensor<2x2x3x2xf32> - check.expect_almost_eq_const(%res, dense<[ - [ - [[0.0, 0.0], [0.0, 4.0], [0.0, 8.0]], - [[0.0, 12.0], [0.0, 16.0], [0.0, 20.0]] - ], - [ - [[1.0, 3.0], [3.0, 9.0], [ 5.0, 15.0]], - [[7.0, 21.0], [9.0, 27.0], [11.0, 33.0]] - ] - ]> : tensor<2x2x3x2xf32>) : tensor<2x2x3x2xf32> - return -} diff --git a/tests/e2e/xla_ops/dynamic_slice.mlir b/tests/e2e/xla_ops/dynamic_slice.mlir deleted file mode 100644 index 86aaaa1797f3..000000000000 --- a/tests/e2e/xla_ops/dynamic_slice.mlir +++ /dev/null @@ -1,40 +0,0 @@ -func.func @dynamic_slice() { - %input = util.unfoldable_constant dense<[ - [01, 02, 03, 04], - [05, 06, 07, 08], - [09, 10, 11, 12]]> : tensor<3x4xi32> - %start1 = util.unfoldable_constant dense<1> : tensor - %start2 = util.unfoldable_constant dense<2> : tensor - %result = "mhlo.dynamic_slice"(%input, %start1, %start2) { - slice_sizes = dense<[2, 2]> : tensor<2xi64> - } : (tensor<3x4xi32>, tensor, tensor) -> tensor<2x2xi32> - check.expect_eq_const(%result, dense<[ - [7, 8], - [11, 12]]> : tensor<2x2xi32>) : tensor<2x2xi32> - return -} - -func.func @dynamic_unit_slice() { - %input = util.unfoldable_constant dense<[ - [01, 02, 03, 04], - [05, 06, 07, 08], - [09, 10, 11, 12]]> : tensor<3x4xi32> - %start1 = util.unfoldable_constant dense<1> : tensor - %start2 = util.unfoldable_constant dense<2> : tensor - %result = "mhlo.dynamic_slice"(%input, %start1, %start2) { - slice_sizes = dense<[1, 2]> : tensor<2xi64> - } : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x2xi32> - check.expect_eq_const(%result, dense<[ - [7, 8]]> : tensor<1x2xi32>) : tensor<1x2xi32> - return -} - -func.func @dynamic_1d_slice() { - %input = util.unfoldable_constant dense<[1, 2, 3, 4]> : tensor<4xi32> - %start1 = util.unfoldable_constant dense<1> : tensor - %result = "mhlo.dynamic_slice"(%input, %start1) { - slice_sizes = dense<[2]> : tensor<1xi64> - } : (tensor<4xi32>, tensor) -> tensor<2xi32> - check.expect_eq_const(%result, dense<[2, 3]> : tensor<2xi32>) : tensor<2xi32> - return -} diff --git a/tests/e2e/xla_ops/dynamic_update_slice.mlir b/tests/e2e/xla_ops/dynamic_update_slice.mlir deleted file mode 100644 index 7031373ddceb..000000000000 --- a/tests/e2e/xla_ops/dynamic_update_slice.mlir +++ /dev/null @@ -1,35 +0,0 @@ -func.func @dynamic_update_slice_2x2() { - %target = util.unfoldable_constant dense<2> : tensor<3x3xi32> - %update = util.unfoldable_constant dense<1> : tensor<2x2xi32> - %c0 = util.unfoldable_constant dense<0> : tensor - %result = "mhlo.dynamic_update_slice"(%target, %update, %c0, %c0) - : (tensor<3x3xi32>, tensor<2x2xi32>, tensor, tensor) -> tensor<3x3xi32> - check.expect_eq_const(%result, dense<[ - [1, 1, 2], - [1, 1, 2], - [2, 2, 2]]> : tensor<3x3xi32>) : tensor<3x3xi32> - return -} - -func.func @dynamic_update_slice_1x3() { - %target = util.unfoldable_constant dense<2> : tensor<3x3xi32> - %update = util.unfoldable_constant dense<1> : tensor<1x3xi32> - %c0 = util.unfoldable_constant dense<0> : tensor - %c1 = util.unfoldable_constant dense<1> : tensor - %result = 
"mhlo.dynamic_update_slice"(%target, %update, %c1, %c0) - : (tensor<3x3xi32>, tensor<1x3xi32>, tensor, tensor) -> tensor<3x3xi32> - check.expect_eq_const(%result, dense<[ - [2, 2, 2], - [1, 1, 1], - [2, 2, 2]]> : tensor<3x3xi32>) : tensor<3x3xi32> - return -} - -func.func @into_constant() { - %update = util.unfoldable_constant dense<2> : tensor<1xi32> - %target = mhlo.constant dense<1> : tensor<4xi32> - %index = mhlo.constant dense<0> : tensor - %result = "mhlo.dynamic_update_slice"(%target, %update, %index) : (tensor<4xi32>, tensor<1xi32>, tensor) -> tensor<4xi32> - check.expect_eq_const(%result, dense<[2, 1, 1, 1]> : tensor<4xi32>) : tensor<4xi32> - return -} diff --git a/tests/e2e/xla_ops/exponential.mlir b/tests/e2e/xla_ops/exponential.mlir deleted file mode 100644 index d18bbd6d84bc..000000000000 --- a/tests/e2e/xla_ops/exponential.mlir +++ /dev/null @@ -1,27 +0,0 @@ -func.func @tensor() { - %input = util.unfoldable_constant dense<[0.0, 1.0, 2.0, 4.0]> : tensor<4xf32> - %result = "mhlo.exponential"(%input) : (tensor<4xf32>) -> tensor<4xf32> - check.expect_almost_eq_const(%result, dense<[1.0, 2.7183, 7.3891, 54.5981]> : tensor<4xf32>) : tensor<4xf32> - return -} - -func.func @scalar() { - %input = util.unfoldable_constant dense<1.0> : tensor - %result = "mhlo.exponential"(%input) : (tensor) -> tensor - check.expect_almost_eq_const(%result, dense<2.7183> : tensor) : tensor - return -} - -func.func @double() { - %input = util.unfoldable_constant dense<1.0> : tensor - %result = "mhlo.exponential"(%input) : (tensor) -> tensor - check.expect_almost_eq_const(%result, dense<2.7183> : tensor) : tensor - return -} - -func.func @negative() { - %input = util.unfoldable_constant dense<-1.0> : tensor - %result = "mhlo.exponential"(%input) : (tensor) -> tensor - check.expect_almost_eq_const(%result, dense<0.367879> : tensor) : tensor - return -} diff --git a/tests/e2e/xla_ops/exponential_fp16.mlir b/tests/e2e/xla_ops/exponential_fp16.mlir deleted file mode 100644 index 9a21293c1823..000000000000 --- a/tests/e2e/xla_ops/exponential_fp16.mlir +++ /dev/null @@ -1,6 +0,0 @@ -func.func @tensor_fp16() { - %input = util.unfoldable_constant dense<[0.0, 1.0, 2.0, 4.0]> : tensor<4xf16> - %result = "mhlo.exponential"(%input) : (tensor<4xf16>) -> tensor<4xf16> - check.expect_almost_eq_const(%result, dense<[1.0, 2.7183, 7.3891, 54.5981]> : tensor<4xf16>) : tensor<4xf16> - return -} diff --git a/tests/e2e/xla_ops/exponential_minus_one.mlir b/tests/e2e/xla_ops/exponential_minus_one.mlir deleted file mode 100644 index 1c3cb38a59a5..000000000000 --- a/tests/e2e/xla_ops/exponential_minus_one.mlir +++ /dev/null @@ -1,6 +0,0 @@ -func.func @exponential_minus_one() { - %input = util.unfoldable_constant dense<[0.0, 0.5, 1.0, -1.0]> : tensor<4xf32> - %result = "mhlo.exponential_minus_one"(%input) : (tensor<4xf32>) -> tensor<4xf32> - check.expect_almost_eq_const(%result, dense<[0.0, 0.6487213, 1.7182818, -0.6321205]> : tensor<4xf32>) : tensor<4xf32> - return -} diff --git a/tests/e2e/xla_ops/fft.mlir b/tests/e2e/xla_ops/fft.mlir deleted file mode 100644 index e347822c1e06..000000000000 --- a/tests/e2e/xla_ops/fft.mlir +++ /dev/null @@ -1,31 +0,0 @@ -// TODO(hanchung): Add other types of fft tests, e.g. fft, ifft, irfft. 
- -func.func @rfft_1d() { - %input = util.unfoldable_constant dense<[ - 9.0, 1.0, 4.5, -0.3, 10.0, -1.0, 5.5, 0.3, 299.0, 3.5, -0.777, 2.0, 1.7, - 3.5, -4.5, 0.0, 9.0, 1.0, 4.5, -0.3, 10.0, -1.0, 5.5, 0.3, 299.0, 3.5, - -0.777, 2.0, 1.7, 3.5, -4.5, 0.0]> : tensor<32xf32> - %0 = "mhlo.fft"(%input) { - fft_length = dense<32> : tensor, fft_type = #mhlo - } : (tensor<32xf32>) -> tensor<17xcomplex> - %1 = "mhlo.real"(%0) : (tensor<17xcomplex>) -> tensor<17xf32> - %2 = "mhlo.imag"(%0) : (tensor<17xcomplex>) -> tensor<17xf32> - check.expect_almost_eq_const(%1, dense<[666.8460, 0.0, -590.16925, 0.0, 593.4485, 0.0, -579.52875, 0.0, 629.95404, 0.0, -567.1126, 0.0, 591.75146, 0.0, -583.1894, 0.0, 630.846]> : tensor<17xf32>) : tensor<17xf32> - check.expect_almost_eq_const(%2, dense<[0.0, 0.0, -23.956373, 0.0, -10.254326, 0.0, -6.1443653, 0.0, -10.0, 0.0, 3.865515, 0.0, 0.63767385, 0.0, 52.453506, 0.0, 0.0]> : tensor<17xf32>) : tensor<17xf32> - return -} - -func.func @rfft_2d() { - %input = util.unfoldable_constant dense<[[ - 9.0, 1.0, 4.5, -0.3, 10.0, -1.0, 5.5, 0.3, 299.0, 3.5, -0.777, 2.0, 1.7, - 3.5, -4.5, 0.0, 9.0, 1.0, 4.5, -0.3, 10.0, -1.0, 5.5, 0.3, 299.0, 3.5, - -0.777, 2.0, 1.7, 3.5, -4.5, 0.0]]> : tensor<1x32xf32> - %0 = "mhlo.fft"(%input) { - fft_length = dense<32> : tensor<1xi64>, fft_type = #mhlo - } : (tensor<1x32xf32>) -> tensor<1x17xcomplex> - %1 = "mhlo.real"(%0) : (tensor<1x17xcomplex>) -> tensor<1x17xf32> - %2 = "mhlo.imag"(%0) : (tensor<1x17xcomplex>) -> tensor<1x17xf32> - check.expect_almost_eq_const(%1, dense<[[666.8460, 0.0, -590.16925, 0.0, 593.4485, 0.0, -579.52875, 0.0, 629.95404, 0.0, -567.1126, 0.0, 591.75146, 0.0, -583.1894, 0.0, 630.846]]> : tensor<1x17xf32>) : tensor<1x17xf32> - check.expect_almost_eq_const(%2, dense<[[0.0, 0.0, -23.956373, 0.0, -10.254326, 0.0, -6.1443653, 0.0, -10.0, 0.0, 3.865515, 0.0, 0.63767385, 0.0, 52.453506, 0.0, 0.0]]> : tensor<1x17xf32>) : tensor<1x17xf32> - return -} diff --git a/tests/e2e/xla_ops/finite.mlir b/tests/e2e/xla_ops/finite.mlir deleted file mode 100644 index 3179b18b8212..000000000000 --- a/tests/e2e/xla_ops/finite.mlir +++ /dev/null @@ -1,11 +0,0 @@ -func.func @f32() { - %0 = util.unfoldable_constant dense<[1.0, 6.0, -6.0, 0.0]> : tensor<4xf32> - %1 = util.unfoldable_constant dense<[0.0, 2.0, 3.0, 4.0]> : tensor<4xf32> - %2 = "mhlo.divide"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - %result = "mhlo.is_finite"(%2) : (tensor<4xf32>) -> tensor<4xi1> - %c0 = util.unfoldable_constant dense<0> : tensor<4xi8> - %c1 = util.unfoldable_constant dense<1> : tensor<4xi8> - %output = "mhlo.select"(%result, %c1, %c0) : (tensor<4xi1>, tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8> - check.expect_eq_const(%output, dense<[0, 1, 1, 1]> : tensor<4xi8>) : tensor<4xi8> - return -} diff --git a/tests/e2e/xla_ops/floor.mlir b/tests/e2e/xla_ops/floor.mlir deleted file mode 100644 index da10a87fbf8f..000000000000 --- a/tests/e2e/xla_ops/floor.mlir +++ /dev/null @@ -1,20 +0,0 @@ -func.func @tensor() { - %input = util.unfoldable_constant dense<[0.0, 1.1, 2.5, 4.9]> : tensor<4xf32> - %result = "mhlo.floor"(%input) : (tensor<4xf32>) -> tensor<4xf32> - check.expect_almost_eq_const(%result, dense<[0.0, 1.0, 2.0, 4.0]> : tensor<4xf32>): tensor<4xf32> - return -} - -func.func @scalar() { - %input = util.unfoldable_constant dense<101.3> : tensor - %result = "mhlo.floor"(%input) : (tensor) -> tensor - check.expect_almost_eq_const(%result, dense<101.0> : tensor): tensor - return -} - -func.func @negative() { - %input = util.unfoldable_constant 
dense<-1.1> : tensor - %result = "mhlo.floor"(%input) : (tensor) -> tensor - check.expect_almost_eq_const(%result, dense<-2.0> : tensor): tensor - return -} diff --git a/tests/e2e/xla_ops/gather.mlir b/tests/e2e/xla_ops/gather.mlir deleted file mode 100644 index d563af1f9e83..000000000000 --- a/tests/e2e/xla_ops/gather.mlir +++ /dev/null @@ -1,169 +0,0 @@ -func.func @foo() { - %input = util.unfoldable_constant dense<[ - [[01, 02, 03, 04, 05]], - [[06, 07, 08, 09, 10]], - [[11, 12, 13, 14, 15]], - [[16, 17, 18, 19, 20]], - [[21, 22, 23, 24, 25]]]> : tensor<5x1x5xi32> - %start_indices = util.unfoldable_constant dense<2> : tensor - %res = "mhlo.gather"(%input, %start_indices) { - dimension_numbers = #mhlo.gather< - collapsed_slice_dims = [0], - index_vector_dim = 0, - offset_dims = [0, 1], - start_index_map = [0], - >, - slice_sizes = dense<[1, 1, 5]> : tensor<3xi64> - } : (tensor<5x1x5xi32>, tensor) -> tensor<1x5xi32> - check.expect_eq_const(%res, dense<[[11, 12, 13, 14, 15]]> : tensor<1x5xi32>) : tensor<1x5xi32> - return -} - -func.func @via_torch_index_select() { - %input = util.unfoldable_constant dense<[ - [[01, 02, 03, 04, 05]], - [[06, 07, 08, 09, 10]], - [[11, 12, 13, 14, 15]], - [[16, 17, 18, 19, 20]], - [[21, 22, 23, 24, 25]]]> : tensor<5x1x5xi32> - %start_indices = util.unfoldable_constant dense<2> : tensor - %res = "mhlo.gather"(%input, %start_indices) { - dimension_numbers = #mhlo.gather< - collapsed_slice_dims = [0], - index_vector_dim = 0, - offset_dims = [0, 1], - start_index_map = [0], - >, - slice_sizes = dense<[1, 1, 5]> : tensor<3xi64> - } : (tensor<5x1x5xi32>, tensor) -> tensor<1x5xi32> - check.expect_eq_const(%res, dense<[[11, 12, 13, 14, 15]]> : tensor<1x5xi32>) : tensor<1x5xi32> - return -} - - -func.func @general_but_just_index_select() { - %operand = util.unfoldable_constant dense<[[ - [ 0, 1, 2, 3, 4, 5, 6, 7], - [ 8, 9, 10, 11, 12, 13, 14, 15], - [16, 17, 18, 19, 20, 21, 22, 23], - [24, 25, 26, 27, 28, 29, 30, 31]]]> : tensor<1x4x8xi32> - %start_indices = util.unfoldable_constant dense<[[ - [0, 1], - [0, 2], - [0, 3], - [0, 0], - [0, 0], - [0, 1], - [0, 2], - [0, 3]]]> : tensor<1x8x2xi32> - %result = "mhlo.gather"(%operand, %start_indices) { - dimension_numbers = #mhlo.gather< - collapsed_slice_dims = [0, 1], - index_vector_dim = 2, - offset_dims = [2], - start_index_map = [0, 1] - >, - indices_are_sorted = false, - slice_sizes = dense<[1, 1, 8]> : tensor<3xi64> - } : (tensor<1x4x8xi32>, tensor<1x8x2xi32>) -> tensor<1x8x8xi32> - check.expect_eq_const(%result, dense<[[ - [ 8, 9, 10, 11, 12, 13, 14, 15], - [16, 17, 18, 19, 20, 21, 22, 23], - [24, 25, 26, 27, 28, 29, 30, 31], - [ 0, 1, 2, 3, 4, 5, 6, 7], - [ 0, 1, 2, 3, 4, 5, 6, 7], - [ 8, 9, 10, 11, 12, 13, 14, 15], - [16, 17, 18, 19, 20, 21, 22, 23], - [24, 25, 26, 27, 28, 29, 30, 31]]]> : tensor<1x8x8xi32>) : tensor<1x8x8xi32> - return -} - -func.func @small_slices() { - %operand = util.unfoldable_constant dense<[[ - [ 0, 1, 2, 3, 4, 5, 6, 7], - [ 8, 9, 10, 11, 12, 13, 14, 15], - [16, 17, 18, 19, 20, 21, 22, 23], - [24, 25, 26, 27, 28, 29, 30, 31]]]> : tensor<1x4x8xi32> - %start_indices = util.unfoldable_constant dense<[[ - [0, 1], - [0, 2], - [0, 3], - [0, 0]]]> : tensor<1x4x2xi32> - %result = "mhlo.gather"(%operand, %start_indices) { - dimension_numbers = #mhlo.gather< - collapsed_slice_dims = [0, 1], - index_vector_dim = 2, - offset_dims = [2], - start_index_map = [0, 1] - >, - indices_are_sorted = false, - slice_sizes = dense<[1, 1, 3]> : tensor<3xi64> - } : (tensor<1x4x8xi32>, tensor<1x4x2xi32>) -> 
tensor<1x4x3xi32> - check.expect_eq_const(%result, dense<[[ - [ 8, 9, 10], - [16, 17, 18], - [24, 25, 26], - [ 0, 1, 2]]]> : tensor<1x4x3xi32>) : tensor<1x4x3xi32> - return -} - -func.func @nonstandard_offset_dims() { - %operand = util.unfoldable_constant dense<[[ - [ 0, 1, 2, 3, 4, 5, 6, 7], - [ 8, 9, 10, 11, 12, 13, 14, 15], - [16, 17, 18, 19, 20, 21, 22, 23], - [24, 25, 26, 27, 28, 29, 30, 31]]]> : tensor<1x4x8xi32> - %start_indices = util.unfoldable_constant dense<[[ - [0, 1], - [0, 2], - [0, 2], - [0, 0]]]> : tensor<1x4x2xi32> - %result = "mhlo.gather"(%operand, %start_indices) { - dimension_numbers = #mhlo.gather< - collapsed_slice_dims = [0], - index_vector_dim = 2, - offset_dims = [1, 2], - start_index_map = [0, 1] - >, - indices_are_sorted = false, - slice_sizes = dense<[1, 2, 3]> : tensor<3xi64> - } : (tensor<1x4x8xi32>, tensor<1x4x2xi32>) -> tensor<1x2x3x4xi32> - check.expect_eq_const(%result, dense<[[ - [[ 8, 16, 16, 0], - [ 9, 17, 17, 1], - [10, 18, 18, 2]], - [[16, 24, 24, 8], - [17, 25, 25, 9], - [18, 26, 26, 10]]]]> : tensor<1x2x3x4xi32>) : tensor<1x2x3x4xi32> - return -} - -func.func @reordered_start_index() { - %operand = util.unfoldable_constant dense<[[ - [[ 0, 1, 2, 3], - [ 4, 5, 6, 7]], - [[ 8, 9, 10, 11], - [12, 13, 14, 15]], - [[16, 17, 18, 19], - [20, 21, 22, 23]]]]> : tensor<1x3x2x4xi32> - %start_indices = util.unfoldable_constant dense<[ - [0, 1, 0, 0], - [1, 0, 0, 0]]> : tensor<2x4xi32> - %result = "mhlo.gather"(%operand, %start_indices) { - dimension_numbers = #mhlo.gather< - collapsed_slice_dims = [0, 2], - index_vector_dim = 1, - offset_dims = [1, 2], - start_index_map = [3, 2, 0, 1] - >, - indices_are_sorted = false, - slice_sizes = dense<[1, 2, 1, 3]> : tensor<4xi64> - } : (tensor<1x3x2x4xi32>, tensor<2x4xi32>) -> tensor<2x2x3xi32> - - check.expect_eq_const(%result, dense<[ - [[ 4, 5, 6], - [12, 13, 14]], - [[ 1, 2, 3], - [ 9, 10, 11]]]> : tensor<2x2x3xi32>) : tensor<2x2x3xi32> - return -} diff --git a/tests/e2e/xla_ops/iota.mlir b/tests/e2e/xla_ops/iota.mlir deleted file mode 100644 index fba0f93e0659..000000000000 --- a/tests/e2e/xla_ops/iota.mlir +++ /dev/null @@ -1,16 +0,0 @@ -func.func @iota_dim0() { - %result = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x3xf32> - check.expect_almost_eq_const(%result, dense<[ - [0.0, 0.0, 0.0], - [1.0, 1.0, 1.0]]> : tensor<2x3xf32>) : tensor<2x3xf32> - return -} - - -func.func @iota_dim1() { - %result = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3xf32> - check.expect_almost_eq_const(%result, dense<[ - [0.0, 1.0, 2.0], - [0.0, 1.0, 2.0]]> : tensor<2x3xf32>) : tensor<2x3xf32> - return -} diff --git a/tests/e2e/xla_ops/log.mlir b/tests/e2e/xla_ops/log.mlir deleted file mode 100644 index d0cf00fc6b26..000000000000 --- a/tests/e2e/xla_ops/log.mlir +++ /dev/null @@ -1,20 +0,0 @@ -func.func @tensor() { - %input = util.unfoldable_constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32> - %result = "mhlo.log"(%input) : (tensor<4xf32>) -> tensor<4xf32> - check.expect_almost_eq_const(%result, dense<[0.0, 0.693147, 1.09861, 1.38629]> : tensor<4xf32>) : tensor<4xf32> - return -} - -func.func @scalar() { - %input = util.unfoldable_constant dense<4.0> : tensor - %result = "mhlo.log"(%input) : (tensor) -> tensor - check.expect_almost_eq_const(%result, dense<1.3863> : tensor) : tensor - return -} - -func.func @double() { - %input = util.unfoldable_constant dense<4.0> : tensor - %result = "mhlo.log"(%input) : (tensor) -> tensor - check.expect_almost_eq_const(%result, dense<1.3863> : tensor) : tensor - return -} 
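For reference, each deleted tests/e2e/xla_ops suite in this change has a StableHLO counterpart under tests/e2e/stablehlo_ops. As a minimal sketch (assuming the counterpart simply mirrors the deleted file, with mhlo.* ops renamed to their stablehlo.* equivalents), the migrated form of the log test above would look roughly like:

func.func @tensor() {
  %input = util.unfoldable_constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
  // Same inputs and expected values as the deleted mhlo version; only the
  // op dialect changes.
  %result = "stablehlo.log"(%input) : (tensor<4xf32>) -> tensor<4xf32>
  check.expect_almost_eq_const(%result, dense<[0.0, 0.693147, 1.09861, 1.38629]> : tensor<4xf32>) : tensor<4xf32>
  return
}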
diff --git a/tests/e2e/xla_ops/log_plus_one.mlir b/tests/e2e/xla_ops/log_plus_one.mlir deleted file mode 100644 index acdd62664e10..000000000000 --- a/tests/e2e/xla_ops/log_plus_one.mlir +++ /dev/null @@ -1,6 +0,0 @@ -func.func @log_plus_one() { - %input = util.unfoldable_constant dense<[0.0, 0.5, 1.0, 5.0]> : tensor<4xf32> - %result = "mhlo.log_plus_one"(%input) : (tensor<4xf32>) -> tensor<4xf32> - check.expect_almost_eq_const(%result, dense<[0.0, 0.4054651, 0.6931472, 1.7917595]> : tensor<4xf32>) : tensor<4xf32> - return -} diff --git a/tests/e2e/xla_ops/maximum.mlir b/tests/e2e/xla_ops/maximum.mlir deleted file mode 100644 index 906ab9841f0d..000000000000 --- a/tests/e2e/xla_ops/maximum.mlir +++ /dev/null @@ -1,87 +0,0 @@ -func.func @tensor_i32() { - %lhs = util.unfoldable_constant dense<[1, 6, 7, 8]> : tensor<4xi32> - %rhs = util.unfoldable_constant dense<[5, 6, 3, 8]> : tensor<4xi32> - %result = "mhlo.maximum"(%lhs, %rhs) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - check.expect_eq_const(%result, dense<[5, 6, 7, 8]> : tensor<4xi32>) : tensor<4xi32> - return -} - -func.func @tensor_odd_dim() { - %lhs = util.unfoldable_constant dense<[1, 6, 7]> : tensor<3xi32> - %rhs = util.unfoldable_constant dense<[5, 6, 3]> : tensor<3xi32> - %result = "mhlo.maximum"(%lhs, %rhs) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> - check.expect_eq_const(%result, dense<[5, 6, 7]> : tensor<3xi32>) : tensor<3xi32> - return -} - -func.func @scalar_i32() { - %lhs = util.unfoldable_constant dense<1> : tensor - %rhs = util.unfoldable_constant dense<2> : tensor - %result = "mhlo.maximum"(%lhs, %rhs) : (tensor, tensor) -> tensor - check.expect_eq_const(%result, dense<2> : tensor) : tensor - return -} - -func.func @negative_i32() { - %lhs = util.unfoldable_constant dense<1> : tensor - %rhs = util.unfoldable_constant dense<-2> : tensor - %result = "mhlo.maximum"(%lhs, %rhs) : (tensor, tensor) -> tensor - check.expect_eq_const(%result, dense<1> : tensor) : tensor - return -} - -func.func @i8() { - %lhs = util.unfoldable_constant dense<1> : tensor - %rhs = util.unfoldable_constant dense<2> : tensor - %result = "mhlo.maximum"(%lhs, %rhs) : (tensor, tensor) -> tensor - check.expect_eq_const(%result, dense<2> : tensor) : tensor - return -} - -func.func @i16() { - %lhs = util.unfoldable_constant dense<1> : tensor - %rhs = util.unfoldable_constant dense<2> : tensor - %result = "mhlo.maximum"(%lhs, %rhs) : (tensor, tensor) -> tensor - check.expect_eq_const(%result, dense<2> : tensor) : tensor - return -} - -func.func @i64() { - %lhs = util.unfoldable_constant dense<1> : tensor - %rhs = util.unfoldable_constant dense<2> : tensor - %result = "mhlo.maximum"(%lhs, %rhs) : (tensor, tensor) -> tensor - check.expect_eq_const(%result, dense<2> : tensor) : tensor - return -} - -func.func @tensor_f32() { - %lhs = util.unfoldable_constant dense<[1.0, 2.0, 7.0, 4.0]> : tensor<4xf32> - %rhs = util.unfoldable_constant dense<[5.0, 2.0, 3.0, 4.0]> : tensor<4xf32> - %result = "mhlo.maximum"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - check.expect_almost_eq_const(%result, dense<[5.0, 2.0, 7.0, 4.0]> : tensor<4xf32>) : tensor<4xf32> - return -} - -func.func @scalar_f32() { - %lhs = util.unfoldable_constant dense<1.0> : tensor - %rhs = util.unfoldable_constant dense<2.0> : tensor - %result = "mhlo.maximum"(%lhs, %rhs) : (tensor, tensor) -> tensor - check.expect_almost_eq_const(%result, dense<2.0> : tensor) : tensor - return -} - -func.func @double() { - %lhs = util.unfoldable_constant dense<1.0> : tensor - %rhs = 
util.unfoldable_constant dense<2.0> : tensor - %result = "mhlo.maximum"(%lhs, %rhs) : (tensor, tensor) -> tensor - check.expect_almost_eq_const(%result, dense<2.0> : tensor) : tensor - return -} - -func.func @negative_f32() { - %lhs = util.unfoldable_constant dense<1.0> : tensor - %rhs = util.unfoldable_constant dense<-2.0> : tensor - %result = "mhlo.maximum"(%lhs, %rhs) : (tensor, tensor) -> tensor - check.expect_almost_eq_const(%result, dense<1.0> : tensor) : tensor - return -} diff --git a/tests/e2e/xla_ops/minimum.mlir b/tests/e2e/xla_ops/minimum.mlir deleted file mode 100644 index ceb9159b0e78..000000000000 --- a/tests/e2e/xla_ops/minimum.mlir +++ /dev/null @@ -1,87 +0,0 @@ -func.func @tensor_i32() { - %lhs = util.unfoldable_constant dense<[1, 2, 7, 4]> : tensor<4xi32> - %rhs = util.unfoldable_constant dense<[5, 2, 3, 4]> : tensor<4xi32> - %result = "mhlo.minimum"(%lhs, %rhs) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - check.expect_eq_const(%result, dense<[1, 2, 3, 4]> : tensor<4xi32>) : tensor<4xi32> - return -} - -func.func @tensor_odd_dim() { - %lhs = util.unfoldable_constant dense<[1, 2, 7]> : tensor<3xi32> - %rhs = util.unfoldable_constant dense<[5, 2, 3]> : tensor<3xi32> - %result = "mhlo.minimum"(%lhs, %rhs) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> - check.expect_eq_const(%result, dense<[1, 2, 3]> : tensor<3xi32>) : tensor<3xi32> - return -} - -func.func @scalar_i32() { - %lhs = util.unfoldable_constant dense<1> : tensor - %rhs = util.unfoldable_constant dense<2> : tensor - %result = "mhlo.minimum"(%lhs, %rhs) : (tensor, tensor) -> tensor - check.expect_eq_const(%result, dense<1> : tensor) : tensor - return -} - -func.func @negative_i32() { - %lhs = util.unfoldable_constant dense<1> : tensor - %rhs = util.unfoldable_constant dense<-2> : tensor - %result = "mhlo.minimum"(%lhs, %rhs) : (tensor, tensor) -> tensor - check.expect_eq_const(%result, dense<-2> : tensor) : tensor - return -} - -func.func @i8() { - %lhs = util.unfoldable_constant dense<1> : tensor - %rhs = util.unfoldable_constant dense<2> : tensor - %result = "mhlo.minimum"(%lhs, %rhs) : (tensor, tensor) -> tensor - check.expect_eq_const(%result, dense<1> : tensor) : tensor - return -} - -func.func @i16() { - %lhs = util.unfoldable_constant dense<1> : tensor - %rhs = util.unfoldable_constant dense<2> : tensor - %result = "mhlo.minimum"(%lhs, %rhs) : (tensor, tensor) -> tensor - check.expect_eq_const(%result, dense<1> : tensor) : tensor - return -} - -func.func @i64() { - %lhs = util.unfoldable_constant dense<1> : tensor - %rhs = util.unfoldable_constant dense<2> : tensor - %result = "mhlo.minimum"(%lhs, %rhs) : (tensor, tensor) -> tensor - check.expect_eq_const(%result, dense<1> : tensor) : tensor - return -} - -func.func @tensor_f32() { - %lhs = util.unfoldable_constant dense<[1.0, 2.0, 7.0, 4.0]> : tensor<4xf32> - %rhs = util.unfoldable_constant dense<[5.0, 2.0, 3.0, 4.0]> : tensor<4xf32> - %result = "mhlo.minimum"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - check.expect_almost_eq_const(%result, dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>) : tensor<4xf32> - return -} - -func.func @scalar_f32() { - %lhs = util.unfoldable_constant dense<1.0> : tensor - %rhs = util.unfoldable_constant dense<2.0> : tensor - %result = "mhlo.minimum"(%lhs, %rhs) : (tensor, tensor) -> tensor - check.expect_almost_eq_const(%result, dense<1.0> : tensor) : tensor - return -} - -func.func @double() { - %lhs = util.unfoldable_constant dense<1.0> : tensor - %rhs = util.unfoldable_constant dense<2.0> : tensor - 
%result = "mhlo.minimum"(%lhs, %rhs) : (tensor, tensor) -> tensor - check.expect_almost_eq_const(%result, dense<1.0> : tensor) : tensor - return -} - -func.func @negative_f32() { - %lhs = util.unfoldable_constant dense<1.0> : tensor - %rhs = util.unfoldable_constant dense<-2.0> : tensor - %result = "mhlo.minimum"(%lhs, %rhs) : (tensor, tensor) -> tensor - check.expect_almost_eq_const(%result, dense<-2.0> : tensor) : tensor - return -} diff --git a/tests/e2e/xla_ops/multiply.mlir b/tests/e2e/xla_ops/multiply.mlir deleted file mode 100644 index bb31176456a5..000000000000 --- a/tests/e2e/xla_ops/multiply.mlir +++ /dev/null @@ -1,6 +0,0 @@ -func.func @multiply () { - %c2 = util.unfoldable_constant dense<2.0> : tensor - %res = "mhlo.multiply"(%c2, %c2) : (tensor, tensor) -> tensor - check.expect_almost_eq_const(%res, dense<4.0> : tensor) : tensor - return -} diff --git a/tests/e2e/xla_ops/negate.mlir b/tests/e2e/xla_ops/negate.mlir deleted file mode 100644 index 9a6ebd39e630..000000000000 --- a/tests/e2e/xla_ops/negate.mlir +++ /dev/null @@ -1,13 +0,0 @@ -func.func @tensor() { - %input = util.unfoldable_constant dense<[-1.0, -2.0, 3.0, 4.0]> : tensor<4xf32> - %result = "mhlo.negate"(%input) : (tensor<4xf32>) -> tensor<4xf32> - check.expect_almost_eq_const(%result, dense<[1.0, 2.0, -3.0, -4.0]> : tensor<4xf32>) : tensor<4xf32> - return -} - -func.func @scalar() { - %input = util.unfoldable_constant dense<-4.0> : tensor - %result = "mhlo.negate"(%input) : (tensor) -> tensor - check.expect_almost_eq_const(%result, dense<4.0> : tensor) : tensor - return -} diff --git a/tests/e2e/xla_ops/pad.mlir b/tests/e2e/xla_ops/pad.mlir deleted file mode 100644 index a47c63671a1a..000000000000 --- a/tests/e2e/xla_ops/pad.mlir +++ /dev/null @@ -1,22 +0,0 @@ -func.func @pad_test() { - %input = util.unfoldable_constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> - %c0 = arith.constant dense<0> : tensor - %res = "mhlo.pad"(%input, %c0) { - edge_padding_low = dense<[0, 1]> : tensor<2xi64>, - edge_padding_high = dense<[1, 5]> : tensor<2xi64>, - interior_padding = dense<0> : tensor<2xi64> - } : (tensor<2x3xi32>, tensor) -> tensor<3x9xi32> - check.expect_eq_const(%res, dense<[ - [0, 1, 2, 3, 0, 0, 0, 0, 0], - [0, 4, 5, 6, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0]]> : tensor<3x9xi32>) : tensor<3x9xi32> - return -} - -func.func @pad_no_op() { - %input = util.unfoldable_constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> - %c0 = arith.constant dense<0> : tensor - %res = "mhlo.pad"(%input, %c0) {edge_padding_high = dense<[0, 0]> : tensor<2xi64>, edge_padding_low = dense<[0, 0]> : tensor<2xi64>, interior_padding = dense<0> : tensor<2xi64>} : (tensor<2x3xi32>, tensor) -> tensor<2x3xi32> - check.expect_eq(%res, %input) : tensor<2x3xi32> - return -} diff --git a/tests/e2e/xla_ops/pow.mlir b/tests/e2e/xla_ops/pow.mlir deleted file mode 100644 index 376ba1cfae01..000000000000 --- a/tests/e2e/xla_ops/pow.mlir +++ /dev/null @@ -1,15 +0,0 @@ -func.func @tensor() { - %cst = mhlo.constant dense<3.0e+00> : tensor<4xf32> - %input = util.unfoldable_constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32> - %result = "mhlo.power"(%input, %cst) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - check.expect_almost_eq_const(%result, dense<[1.0, 8.0, 27.0, 64.0]> : tensor<4xf32>) : tensor<4xf32> - return -} - -func.func @scalar() { - %cst = mhlo.constant dense<2.0e+00> : tensor - %input = util.unfoldable_constant dense<16.0> : tensor - %result = "mhlo.power"(%input, %cst) : (tensor, tensor) -> tensor - 
check.expect_almost_eq_const(%result, dense<256.0> : tensor) : tensor - return -} diff --git a/tests/e2e/xla_ops/reduce.mlir b/tests/e2e/xla_ops/reduce.mlir deleted file mode 100644 index a78b595c485f..000000000000 --- a/tests/e2e/xla_ops/reduce.mlir +++ /dev/null @@ -1,360 +0,0 @@ -// Int sum values from [1, 10] -func.func @reduce_sum_1x10xi32() { - %0 = util.unfoldable_constant dense<[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]> : tensor<1x10xi32> - %1 = util.unfoldable_constant dense<0> : tensor - %res = "mhlo.reduce"(%0, %1) ( { - ^bb0(%arg0: tensor, %arg1: tensor): // no predecessors - %3 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor - "mhlo.return"(%3) : (tensor) -> () - }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xi32>, tensor) -> tensor<1xi32> - check.expect_eq_const(%res, dense<55> : tensor<1xi32>) : tensor<1xi32> - return -} - -// Int max values from [1, 10] -func.func @reduce_max_1x10xi32() { - %0 = util.unfoldable_constant dense<[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]> : tensor<1x10xi32> - %1 = util.unfoldable_constant dense<0> : tensor - %res = "mhlo.reduce"(%0, %1) ( { - ^bb0(%arg0: tensor, %arg1: tensor): // no predecessors - %3 = "mhlo.maximum"(%arg0, %arg1) : (tensor, tensor) -> tensor - "mhlo.return"(%3) : (tensor) -> () - }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xi32>, tensor) -> tensor<1xi32> - check.expect_eq_const(%res, dense<10> : tensor<1xi32>) : tensor<1xi32> - return -} - -// Int min values, along multiple dimensions. Expected to just be a reshape in this case. -func.func @reduce_min_5x1x1xi32() { - %0 = util.unfoldable_constant dense<[[[1]],[[2]],[[3]],[[4]],[[5]]]> : tensor<5x1x1xi32> - %1 = util.unfoldable_constant dense<999> : tensor - %res = "mhlo.reduce"(%0, %1) ( { - ^bb0(%arg0: tensor, %arg1: tensor): // no predecessors - %3 = "mhlo.minimum"(%arg0, %arg1) : (tensor, tensor) -> tensor - "mhlo.return"(%3) : (tensor) -> () - }) {dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<5x1x1xi32>, tensor) -> tensor<5xi32> - check.expect_eq_const(%res, dense<[1, 2, 3, 4, 5]> : tensor<5xi32>) : tensor<5xi32> - return -} - - -// The following cases match the examples presented at -// https://www.tensorflow.org/xla/operation_semantics#reduce - -func.func @reduce_sum_2x3xi32_dim0() { - %0 = util.unfoldable_constant dense<[ - [1, 2, 3], - [4, 5, 6]]> : tensor<2x3xi32> - %1 = util.unfoldable_constant dense<0> : tensor - %res = "mhlo.reduce"(%0, %1) ( { - ^bb0(%arg0: tensor, %arg1: tensor): // no predecessors - %3 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor - "mhlo.return"(%3) : (tensor) -> () - }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<2x3xi32>, tensor) -> tensor<3xi32> - check.expect_eq_const(%res, dense<[5, 7, 9]> : tensor<3xi32>) : tensor<3xi32> - return -} - -func.func @reduce_sum_2x3xi32_dim1() { - %0 = util.unfoldable_constant dense<[ - [1, 2, 3], - [4, 5, 6]]> : tensor<2x3xi32> - %1 = util.unfoldable_constant dense<0> : tensor - %res = "mhlo.reduce"(%0, %1) ( { - ^bb0(%arg0: tensor, %arg1: tensor): // no predecessors - %3 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor - "mhlo.return"(%3) : (tensor) -> () - }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor) -> tensor<2xi32> - check.expect_eq_const(%res, dense<[6, 15]> : tensor<2xi32>) : tensor<2xi32> - return -} - -func.func @reduce_sum_4x2x3xi32_dim0() { - %0 = util.unfoldable_constant dense<[ - [[1, 2, 3], [4, 5, 6]], - [[1, 2, 3], [4, 5, 6]], - [[1, 2, 3], [4, 5, 6]], - [[1, 2, 3], [4, 5, 6]]]> : tensor<4x2x3xi32> - %1 = 
util.unfoldable_constant dense<0> : tensor - %res = "mhlo.reduce"(%0, %1) ( { - ^bb0(%arg0: tensor, %arg1: tensor): // no predecessors - %3 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor - "mhlo.return"(%3) : (tensor) -> () - }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4x2x3xi32>, tensor) -> tensor<2x3xi32> - check.expect_eq_const(%res, dense<[[4, 8, 12],[16, 20, 24]]> : tensor<2x3xi32>) : tensor<2x3xi32> - return -} - -func.func @reduce_sum_4x2x3xi32_dim2() { - %0 = util.unfoldable_constant dense<[ - [[1, 2, 3], [4, 5, 6]], - [[1, 2, 3], [4, 5, 6]], - [[1, 2, 3], [4, 5, 6]], - [[1, 2, 3], [4, 5, 6]]]> : tensor<4x2x3xi32> - %1 = util.unfoldable_constant dense<0> : tensor - %res = "mhlo.reduce"(%0, %1) ( { - ^bb0(%arg0: tensor, %arg1: tensor): // no predecessors - %3 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor - "mhlo.return"(%3) : (tensor) -> () - }) {dimensions = dense<2> : tensor<1xi64>} : (tensor<4x2x3xi32>, tensor) -> tensor<4x2xi32> - check.expect_eq_const(%res, dense<[[6, 15],[6, 15],[6, 15],[6, 15]]> : tensor<4x2xi32>) : tensor<4x2xi32> - return -} - -func.func @reduce_sum_4x2x3xi32_dims_0_1() { - %0 = util.unfoldable_constant dense<[ - [[1, 2, 3], [4, 5, 6]], - [[1, 2, 3], [4, 5, 6]], - [[1, 2, 3], [4, 5, 6]], - [[1, 2, 3], [4, 5, 6]]]> : tensor<4x2x3xi32> - %1 = util.unfoldable_constant dense<0> : tensor - %res = "mhlo.reduce"(%0, %1) ( { - ^bb0(%arg0: tensor, %arg1: tensor): // no predecessors - %3 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor - "mhlo.return"(%3) : (tensor) -> () - }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x2x3xi32>, tensor) -> tensor<3xi32> - check.expect_eq_const(%res, dense<[20, 28, 36]> : tensor<3xi32>) : tensor<3xi32> - return -} - -func.func @reduce_sum_4x2x3xi32_dims_0_1_2() { - %0 = util.unfoldable_constant dense<[ - [[1, 2, 3], [4, 5, 6]], - [[1, 2, 3], [4, 5, 6]], - [[1, 2, 3], [4, 5, 6]], - [[1, 2, 3], [4, 5, 6]]]> : tensor<4x2x3xi32> - %1 = util.unfoldable_constant dense<0> : tensor - %res = "mhlo.reduce"(%0, %1) ( { - ^bb0(%arg0: tensor, %arg1: tensor): // no predecessors - %3 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor - "mhlo.return"(%3) : (tensor) -> () - }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<4x2x3xi32>, tensor) -> tensor - check.expect_eq_const(%res, dense<84> : tensor) : tensor - return -} - -// Float sum values from [1.0, 10.0] -func.func @reduce_sum_1x10xf32() { - %0 = util.unfoldable_constant dense<[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]]> : tensor<1x10xf32> - %1 = util.unfoldable_constant dense<0.0> : tensor - %res = "mhlo.reduce"(%0, %1) ( { - ^bb0(%arg0: tensor, %arg1: tensor): // no predecessors - %3 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor - "mhlo.return"(%3) : (tensor) -> () - }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor) -> tensor<1xf32> - check.expect_almost_eq_const(%res, dense<55.0> : tensor<1xf32>) : tensor<1xf32> - return -} - -// Float max values from [1.0, 10.0] -func.func @reduce_max_1x10xf32() { - %0 = util.unfoldable_constant dense<[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]]> : tensor<1x10xf32> - %1 = util.unfoldable_constant dense<0.0> : tensor - %res = "mhlo.reduce"(%0, %1) - ( { - ^bb0(%arg0: tensor, %arg1: tensor): // no predecessors - %3 = "mhlo.maximum"(%arg0, %arg1) : (tensor, tensor) -> tensor - "mhlo.return"(%3) : (tensor) -> () - }) - {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor) -> tensor<1xf32> - check.expect_almost_eq_const(%res, 
dense<10.0> : tensor<1xf32>) : tensor<1xf32> - return -} - -// Float min values, along multiple dimensions. Expected to just be a reshape in this case. -func.func @reduce_min_5x1x1xf32() { - %0 = util.unfoldable_constant dense<[[[1.0]],[[2.0]],[[3.0]],[[4.0]],[[5.0]]]> : tensor<5x1x1xf32> - %1 = util.unfoldable_constant dense<999.0> : tensor - %res = "mhlo.reduce"(%0, %1) ( { - ^bb0(%arg0: tensor, %arg1: tensor): // no predecessors - %3 = "mhlo.minimum"(%arg0, %arg1) : (tensor, tensor) -> tensor - "mhlo.return"(%3) : (tensor) -> () - }) {dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<5x1x1xf32>, tensor) -> tensor<5xf32> - check.expect_almost_eq_const(%res, dense<[1.0, 2.0, 3.0, 4.0, 5.0]> : tensor<5xf32>) : tensor<5xf32> - return -} - -// The following cases match the examples presented at -// https://www.tensorflow.org/xla/operation_semantics#reduce - -func.func @reduce_sum_2x3xf32_dim0() { - %0 = util.unfoldable_constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> - %1 = util.unfoldable_constant dense<0.0> : tensor - %res = "mhlo.reduce"(%0, %1) ( { - ^bb0(%arg0: tensor, %arg1: tensor): // no predecessors - %3 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor - "mhlo.return"(%3) : (tensor) -> () - }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<2x3xf32>, tensor) -> tensor<3xf32> - check.expect_almost_eq_const(%res, dense<[5.0, 7.0, 9.0]> : tensor<3xf32>) : tensor<3xf32> - return -} - -func.func @reduce_sum_2x3xf32_dim1() { - %0 = util.unfoldable_constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> - %1 = util.unfoldable_constant dense<0.0> : tensor - %res = "mhlo.reduce"(%0, %1) ( { - ^bb0(%arg0: tensor, %arg1: tensor): // no predecessors - %3 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor - "mhlo.return"(%3) : (tensor) -> () - }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf32>, tensor) -> tensor<2xf32> - check.expect_almost_eq_const(%res, dense<[6.0, 15.0]> : tensor<2xf32>) : tensor<2xf32> - return -} - -func.func @reduce_sum_4x2x3xf32_dim0() { - %0 = util.unfoldable_constant dense<[ - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]> : tensor<4x2x3xf32> - %1 = util.unfoldable_constant dense<0.0> : tensor - %res = "mhlo.reduce"(%0, %1) ( { - ^bb0(%arg0: tensor, %arg1: tensor): // no predecessors - %3 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor - "mhlo.return"(%3) : (tensor) -> () - }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4x2x3xf32>, tensor) -> tensor<2x3xf32> - check.expect_almost_eq_const(%res, dense<[[4.0, 8.0, 12.0],[16.0, 20.0, 24.0]]> : tensor<2x3xf32>) : tensor<2x3xf32> - return -} - -func.func @reduce_sum_4x2x3xf32_dim1() { - %0 = util.unfoldable_constant dense<[ - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]> : tensor<4x2x3xf32> - %1 = util.unfoldable_constant dense<0.0> : tensor - %res = "mhlo.reduce"(%0, %1) ( { - ^bb0(%arg0: tensor, %arg1: tensor): // no predecessors - %3 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor - "mhlo.return"(%3) : (tensor) -> () - }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x2x3xf32>, tensor) -> tensor<4x3xf32> - check.expect_almost_eq_const(%res, dense<[ - [5.0, 7.0, 9.0], - [5.0, 7.0, 9.0], - [5.0, 7.0, 9.0], - [5.0, 7.0, 9.0]]> : tensor<4x3xf32>) : tensor<4x3xf32> - return -} - -func.func @reduce_sum_4x2x3xf32_dim2() { - %0 = 
util.unfoldable_constant dense<[ - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]> : tensor<4x2x3xf32> - %1 = util.unfoldable_constant dense<0.0> : tensor - %res = "mhlo.reduce"(%0, %1) ( { - ^bb0(%arg0: tensor, %arg1: tensor): // no predecessors - %3 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor - "mhlo.return"(%3) : (tensor) -> () - }) {dimensions = dense<2> : tensor<1xi64>} : (tensor<4x2x3xf32>, tensor) -> tensor<4x2xf32> - check.expect_almost_eq_const(%res, dense<[ - [6.0, 15.0], - [6.0, 15.0], - [6.0, 15.0], - [6.0, 15.0]]> : tensor<4x2xf32>) : tensor<4x2xf32> - return -} - -func.func @reduce_sum_4x2x3xf32_dims_0_1() { - %0 = util.unfoldable_constant dense<[ - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]> : tensor<4x2x3xf32> - %1 = util.unfoldable_constant dense<0.0> : tensor - %res = "mhlo.reduce"(%0, %1) ( { - ^bb0(%arg0: tensor, %arg1: tensor): // no predecessors - %3 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor - "mhlo.return"(%3) : (tensor) -> () - }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x2x3xf32>, tensor) -> tensor<3xf32> - check.expect_almost_eq_const(%res, dense<[20.0, 28.0, 36.0]> : tensor<3xf32>) : tensor<3xf32> - return -} - -func.func @reduce_sum_4x2x3xf32_dims_0_1_2() { - %0 = util.unfoldable_constant dense<[ - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]> : tensor<4x2x3xf32> - %1 = util.unfoldable_constant dense<0.0> : tensor - %res = "mhlo.reduce"(%0, %1) ( { - ^bb0(%arg0: tensor, %arg1: tensor): // no predecessors - %3 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor - "mhlo.return"(%3) : (tensor) -> () - }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<4x2x3xf32>, tensor) -> tensor - check.expect_almost_eq_const(%res, dense<84.0> : tensor) : tensor - return -} - -func.func @reducemulti_result() { - %cst0 = mhlo.constant dense<-2147483648> : tensor - %cst1 = mhlo.constant dense<0> : tensor - %arg0 = util.unfoldable_constant dense<[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16], [17, 18]]> : tensor<9x2xi32> - %arg1 = util.unfoldable_constant dense<[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15], [16, 17]]> : tensor<9x2xi32> - %res0, %res1 = "mhlo.reduce"(%arg0, %arg1, %cst0, %cst1) ( { - ^bb0(%arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor): // no predecessors - %0 = "mhlo.compare"(%arg2, %arg4) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - %1 = "mhlo.select"(%0, %arg2, %arg4) : (tensor, tensor, tensor) -> tensor - %2 = "mhlo.compare"(%arg2, %arg4) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - %3 = mhlo.minimum %arg3, %arg5 : tensor - %4 = "mhlo.select"(%0, %arg3, %arg5) : (tensor, tensor, tensor) -> tensor - %5 = "mhlo.select"(%2, %3, %4) : (tensor, tensor, tensor) -> tensor - "mhlo.return"(%1, %5) : (tensor, tensor) -> () - }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<9x2xi32>, tensor<9x2xi32>, tensor, tensor) -> (tensor<2xi32>, tensor<2xi32>) - check.expect_eq_const(%res0, dense<[17, 18]> : tensor<2xi32>) : tensor<2xi32> - check.expect_eq_const(%res1, dense<[16, 17]> : tensor<2xi32>) : tensor<2xi32> - return -} - -func.func @reduce_dim_1() { - %0 = util.unfoldable_constant dense<[[1, 2, 3, 4, 5], [6, 7, 8, 
9, 10]]> : tensor<2x5xi32> - %1 = util.unfoldable_constant dense<10> : tensor - %2 = "mhlo.reduce"(%0, %1) ( { - ^bb0(%arg0 : tensor, %arg1 : tensor): - %3 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor - "mhlo.return"(%3) : (tensor) -> () - }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<2x5xi32>, tensor) -> tensor<2xi32> - check.expect_eq_const(%2, dense<[25, 50]> : tensor<2xi32>) : tensor<2xi32> - return -} - -// Constants get folded into linalg.indexed_generic ops. Check to -// make sure this works as expected. -func.func @reduce_dim_1_const() { - %0 = util.unfoldable_constant dense<[[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]> : tensor<2x5xi32> - %1 = arith.constant dense<10> : tensor - %2 = "mhlo.reduce"(%0, %1) ( { - ^bb0(%arg0 : tensor, %arg1 : tensor): - %3 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor - "mhlo.return"(%3) : (tensor) -> () - }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<2x5xi32>, tensor) -> tensor<2xi32> - check.expect_eq_const(%2, dense<[25, 50]> : tensor<2xi32>) : tensor<2xi32> - return -} - -func.func @reduce_dim_0() { - %0 = util.unfoldable_constant dense<[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]> : tensor<1x10xi32> - %1 = util.unfoldable_constant dense<10> : tensor - %2 = "mhlo.reduce"(%0, %1) ( { - ^bb0(%arg0 : tensor, %arg1 : tensor): - %3 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor - "mhlo.return"(%3) : (tensor) -> () - }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xi32>, tensor) -> tensor<1xi32> - check.expect_eq_const(%2, dense<[65]> : tensor<1xi32>) : tensor<1xi32> - return -} - -func.func @reduce_to_scalar() { - %0 = util.unfoldable_constant dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]> : tensor<10xi32> - %1 = util.unfoldable_constant dense<10> : tensor - %2 = "mhlo.reduce"(%0, %1) ( { - ^bb0(%arg0 : tensor, %arg1 : tensor): - %3 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor - "mhlo.return"(%3) : (tensor) -> () - }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<10xi32>, tensor) -> tensor - check.expect_eq_const(%2, dense<65> : tensor) : tensor - return -} diff --git a/tests/e2e/xla_ops/reduce_window.mlir b/tests/e2e/xla_ops/reduce_window.mlir deleted file mode 100644 index 09dd5b596833..000000000000 --- a/tests/e2e/xla_ops/reduce_window.mlir +++ /dev/null @@ -1,98 +0,0 @@ -func.func @reduce_window_nonoverlapping_1x4x6x1xf32() { - %0 = util.unfoldable_constant dense<[[[[ 1.0], [ 2.0], [ 3.0], [ 4.0], [ 5.0], [ 6.0]], - [[ 7.0], [ 8.0], [ 9.0], [10.0], [11.0], [12.0]], - [[13.0], [14.0], [15.0], [16.0], [17.0], [18.0]], - [[19.0], [20.0], [21.0], [22.0], [23.0], [24.0]]]]> : tensor<1x4x6x1xf32> - %1 = util.unfoldable_constant dense<0.0> : tensor - %res = "mhlo.reduce_window"(%0, %1) ( { - ^bb0(%arg0: tensor, %arg1: tensor): // no predecessors - %3 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor - "mhlo.return"(%3) : (tensor) -> () - }) {window_dimensions = dense<[1, 2, 3, 1]> : tensor<4xi64>, - window_strides = dense<[1, 2, 3, 1]> : tensor<4xi64>} : (tensor<1x4x6x1xf32>, tensor) -> tensor<1x2x2x1xf32> - check.expect_eq_const(%res, dense<[[[[30.0], [48.0]],[[102.0], [120.0]]]]> : tensor<1x2x2x1xf32>) : tensor<1x2x2x1xf32> - return -} - -func.func @reduce_window_overlapping_4x6xf32() { - %0 = util.unfoldable_constant dense<[[[[ 1.0], [ 2.0], [ 3.0], [ 4.0], [ 5.0], [ 6.0]], - [[ 7.0], [ 8.0], [ 9.0], [10.0], [11.0], [12.0]], - [[13.0], [14.0], [15.0], [16.0], [17.0], [18.0]], - [[19.0], [20.0], [21.0], [22.0], [23.0], [24.0]]]]> : tensor<1x4x6x1xf32> - %1 = util.unfoldable_constant dense<0.0> : tensor 
- %res = "mhlo.reduce_window"(%0, %1) ( { - ^bb0(%arg0: tensor, %arg1: tensor): // no predecessors - %3 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor - "mhlo.return"(%3) : (tensor) -> () - }) {window_dimensions = dense<[1, 2, 3, 1]> : tensor<4xi64>, - window_strides = dense<[1, 1, 1, 1]> : tensor<4xi64>} : (tensor<1x4x6x1xf32>, tensor) -> tensor<1x3x4x1xf32> - check.expect_eq_const(%res, dense<[[ - [[ 30.0], [ 36.0], [ 42.0], [ 48.0]], - [[ 66.0], [ 72.0], [ 78.0], [ 84.0]], - [[102.0], [108.0], [114.0], [120.0]]]]> : tensor<1x3x4x1xf32>) : tensor<1x3x4x1xf32> - return -} - -func.func @reduce_window_max_4x6xf32() { - %0 = util.unfoldable_constant dense<[[[[ 1.0], [ 2.0], [ 3.0], [ 4.0], [ 5.0], [ 6.0]], - [[ 7.0], [ 8.0], [ 9.0], [10.0], [11.0], [12.0]], - [[13.0], [14.0], [15.0], [16.0], [17.0], [18.0]], - [[19.0], [20.0], [21.0], [22.0], [23.0], [24.0]]]]> : tensor<1x4x6x1xf32> - %1 = util.unfoldable_constant dense<0.0> : tensor - %res = "mhlo.reduce_window"(%0, %1) ( { - ^bb0(%arg0: tensor, %arg1: tensor): // no predecessors - %3 = "mhlo.maximum"(%arg0, %arg1) : (tensor, tensor) -> tensor - "mhlo.return"(%3) : (tensor) -> () - }) {window_dimensions = dense<[1, 2, 3, 1]> : tensor<4xi64>, - window_strides = dense<[1, 2, 3, 1]> : tensor<4xi64>} : (tensor<1x4x6x1xf32>, tensor) -> tensor<1x2x2x1xf32> - check.expect_almost_eq_const(%res, dense<[[[[9.0], [12.0]], [[21.0], [24.0]]]]> : tensor<1x2x2x1xf32>) : tensor<1x2x2x1xf32> - return -} - -func.func @reduce_window_min_4x6xf32() { - %0 = util.unfoldable_constant dense<[[[[ 1.0], [ 2.0], [ 3.0], [ 4.0], [ 5.0], [ 6.0]], - [[ 7.0], [ 8.0], [ 9.0], [10.0], [11.0], [12.0]], - [[13.0], [14.0], [15.0], [16.0], [17.0], [18.0]], - [[19.0], [20.0], [21.0], [22.0], [23.0], [24.0]]]]> : tensor<1x4x6x1xf32> - %1 = util.unfoldable_constant dense<14.0> : tensor - %res = "mhlo.reduce_window"(%0, %1) ( { - ^bb0(%arg0: tensor, %arg1: tensor): // no predecessors - %3 = "mhlo.minimum"(%arg0, %arg1) : (tensor, tensor) -> tensor - "mhlo.return"(%3) : (tensor) -> () - }) {window_dimensions = dense<[1, 2, 3, 1]> : tensor<4xi64>, - window_strides = dense<[1, 2, 3, 1]> : tensor<4xi64>} : (tensor<1x4x6x1xf32>, tensor) -> tensor<1x2x2x1xf32> - check.expect_almost_eq_const(%res, dense<[[[[1.0], [4.0]], [[13.0], [14.0]]]]> : tensor<1x2x2x1xf32>) : tensor<1x2x2x1xf32> - return -} - -func.func @reduce_window_max_with_padding_4x6xf32() { - %0 = util.unfoldable_constant dense<[[[[ 1.0], [ 2.0], [ 3.0], [ 4.0], [ 5.0], [ 6.0]], - [[ 7.0], [ 8.0], [ 9.0], [10.0], [11.0], [12.0]], - [[13.0], [14.0], [15.0], [16.0], [17.0], [18.0]], - [[19.0], [20.0], [21.0], [22.0], [23.0], [24.0]]]]> : tensor<1x4x6x1xf32> - %1 = util.unfoldable_constant dense<0.0> : tensor - %res = "mhlo.reduce_window"(%0, %1) ( { - ^bb0(%arg0: tensor, %arg1: tensor): // no predecessors - %3 = "mhlo.maximum"(%arg0, %arg1) : (tensor, tensor) -> tensor - "mhlo.return"(%3) : (tensor) -> () - }) {window_dimensions = dense<[1, 2, 3, 1]> : tensor<4xi64>, - window_strides = dense<[1, 2, 3, 1]> : tensor<4xi64>, - padding = dense<[[0, 0], [1, 1], [0, 0], [0, 0]]> : tensor<4x2xi64>} : (tensor<1x4x6x1xf32>, tensor) -> tensor<1x3x2x1xf32> - check.expect_almost_eq_const(%res, dense<[[[[3.0], [6.0]], [[15.0], [18.0]], [[21.0], [24.0]]]]> : tensor<1x3x2x1xf32>) : tensor<1x3x2x1xf32> - return -} - -func.func @cumsum_f32() { - %0 = mhlo.constant dense<0.000000e+00> : tensor - %1 = util.unfoldable_constant dense<1.0> : tensor<2x2x2xf32> - %res = "mhlo.reduce_window"(%1, %0) ({ - ^bb0(%arg1: tensor, %arg2: tensor): - 
-    %4 = mhlo.add %arg1, %arg2 : tensor<f32>
-    "mhlo.return"(%4) : (tensor<f32>) -> ()
-  }) {padding = dense<[[1, 0], [0, 0], [0, 0]]> : tensor<3x2xi64>,
-      window_dimensions = dense<[2, 1, 1]> : tensor<3xi64>,
-      window_strides = dense<1> : tensor<3xi64>
-  } : (tensor<2x2x2xf32>, tensor<f32>) -> tensor<2x2x2xf32>
-  check.expect_almost_eq_const(%res, dense<[[[1.0, 1.0], [1.0, 1.0]], [[2.0, 2.0], [2.0, 2.0]]]> : tensor<2x2x2xf32>) : tensor<2x2x2xf32>
-  return
-}
diff --git a/tests/e2e/xla_ops/remainder.mlir b/tests/e2e/xla_ops/remainder.mlir
deleted file mode 100644
index b225bc17d919..000000000000
--- a/tests/e2e/xla_ops/remainder.mlir
+++ /dev/null
@@ -1,63 +0,0 @@
-func.func @scalar() {
-  %input1 = util.unfoldable_constant dense<16.0> : tensor<f32>
-  %input2 = util.unfoldable_constant dense<7.0> : tensor<f32>
-  %result = "mhlo.remainder"(%input1, %input2) : (tensor<f32>, tensor<f32>) -> tensor<f32>
-  check.expect_almost_eq_const(%result, dense<2.0> : tensor<f32>) : tensor<f32>
-  return
-}
-
-func.func @tensor() {
-  %input1 = util.unfoldable_constant dense<[16.0, 17.0, 18.0]> : tensor<3xf32>
-  %input2 = util.unfoldable_constant dense<[7.0, 8.0, 9.0]> : tensor<3xf32>
-  %result = "mhlo.remainder"(%input1, %input2) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
-  check.expect_almost_eq_const(%result, dense<[2.0, 1.0, 0.0]> : tensor<3xf32>) : tensor<3xf32>
-  return
-}
-
-func.func @negative_den() {
-  %input1 = util.unfoldable_constant dense<16.0> : tensor<f32>
-  %input2 = util.unfoldable_constant dense<-7.0> : tensor<f32>
-  %result = "mhlo.remainder"(%input1, %input2) : (tensor<f32>, tensor<f32>) -> tensor<f32>
-  check.expect_almost_eq_const(%result, dense<2.0> : tensor<f32>) : tensor<f32>
-  return
-}
-
-func.func @negative_num() {
-  %input1 = util.unfoldable_constant dense<-16.0> : tensor<f32>
-  %input2 = util.unfoldable_constant dense<7.0> : tensor<f32>
-  %result = "mhlo.remainder"(%input1, %input2) : (tensor<f32>, tensor<f32>) -> tensor<f32>
-  check.expect_almost_eq_const(%result, dense<-2.0> : tensor<f32>) : tensor<f32>
-  return
-}
-
-func.func @scalar_int() {
-  %input1 = util.unfoldable_constant dense<16> : tensor<i32>
-  %input2 = util.unfoldable_constant dense<7> : tensor<i32>
-  %result = "mhlo.remainder"(%input1, %input2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
-  check.expect_eq_const(%result, dense<2> : tensor<i32>) : tensor<i32>
-  return
-}
-
-func.func @tensor_int() {
-  %input1 = util.unfoldable_constant dense<[16, 17, 18]> : tensor<3xi32>
-  %input2 = util.unfoldable_constant dense<[7, 8, 9]> : tensor<3xi32>
-  %result = "mhlo.remainder"(%input1, %input2) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
-  check.expect_eq_const(%result, dense<[2, 1, 0]> : tensor<3xi32>) : tensor<3xi32>
-  return
-}
-
-func.func @negative_den_int() {
-  %input1 = util.unfoldable_constant dense<16> : tensor<i32>
-  %input2 = util.unfoldable_constant dense<-7> : tensor<i32>
-  %result = "mhlo.remainder"(%input1, %input2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
-  check.expect_eq_const(%result, dense<2> : tensor<i32>) : tensor<i32>
-  return
-}
-
-func.func @negative_num_int() {
-  %input1 = util.unfoldable_constant dense<-16> : tensor<i32>
-  %input2 = util.unfoldable_constant dense<7> : tensor<i32>
-  %result = "mhlo.remainder"(%input1, %input2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
-  check.expect_eq_const(%result, dense<-2> : tensor<i32>) : tensor<i32>
-  return
-}
diff --git a/tests/e2e/xla_ops/reshape.mlir b/tests/e2e/xla_ops/reshape.mlir
deleted file mode 100644
index cf0a451eacb8..000000000000
--- a/tests/e2e/xla_ops/reshape.mlir
+++ /dev/null
@@ -1,32 +0,0 @@
-func.func @reshape_1D_2D() {
-  %input = util.unfoldable_constant dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]> : tensor<12xi32>
-  %result = "mhlo.reshape"(%input) : (tensor<12xi32>) -> tensor<3x4xi32>
"mhlo.reshape"(%input) : (tensor<12xi32>) -> tensor<3x4xi32> - check.expect_eq_const(%result, dense<[ - [1, 2, 3, 4], - [5, 6, 7, 8], - [9, 10, 11, 12]]> : tensor<3x4xi32>) : tensor<3x4xi32> - return -} - -// func.func @reshape_1D_3D() { -// %input = util.unfoldable_constant dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]> : tensor<12xi32> -// %result = "mhlo.reshape"(%input) : (tensor<12xi32>) -> tensor<2x2x3xi32> -// check.expect_eq_const(%result, dense<[ -// [[1, 2, 3], [4, 5, 6]], -// [[7, 8, 9], [10, 11, 12]]]> : tensor<2x2x3xi32>) : tensor<2x2x3xi32> -// return -// } - -// func.func @reshape_2D_3D() { -// %input = util.unfoldable_constant dense<[[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]> : tensor<2x6xi32> -// %result = "mhlo.reshape"(%input) : (tensor<2x6xi32>) -> tensor<2x1x6xi32> -// check.expect_eq_const(%result, dense<[[[1, 2, 3, 4, 5, 6]], [[7, 8, 9, 10, 11, 12]]]> : tensor<2x1x6xi32>) : tensor<2x1x6xi32> -// return -// } - -// func.func @reshape_3D_1D() { -// %input = util.unfoldable_constant dense<[[[1, 2, 3, 4, 5, 6]], [[7, 8, 9, 10, 11, 12]]]> : tensor<2x1x6xi32> -// %result = "mhlo.reshape"(%input) : (tensor<2x1x6xi32>) -> tensor<2x6xi32> -// check.expect_eq_const(%result, dense<[[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]> : tensor<2x6xi32>) : tensor<2x6xi32> -// return -// } diff --git a/tests/e2e/xla_ops/reverse.mlir b/tests/e2e/xla_ops/reverse.mlir deleted file mode 100644 index 57328e1e42f3..000000000000 --- a/tests/e2e/xla_ops/reverse.mlir +++ /dev/null @@ -1,22 +0,0 @@ -func.func @xla_reverse() { - %t1 = util.unfoldable_constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> - - %dim0 = "mhlo.reverse"(%t1) {dimensions = dense<0> : tensor<1xi64>} : (tensor<2x3xf32>) -> tensor<2x3xf32> - check.expect_almost_eq_const( - %dim0, - dense<[[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]]> : tensor<2x3xf32> - ) : tensor<2x3xf32> - - %dim1 = "mhlo.reverse"(%t1) {dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf32>) -> tensor<2x3xf32> - check.expect_almost_eq_const( - %dim1, - dense<[[3.0, 2.0, 1.0], [6.0, 5.0, 4.0]]> : tensor<2x3xf32> - ) : tensor<2x3xf32> - - %both_dims = "mhlo.reverse"(%t1) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<2x3xf32> - check.expect_almost_eq_const( - %both_dims, - dense<[[6.0, 5.0, 4.0], [3.0, 2.0, 1.0]]> : tensor<2x3xf32> - ) : tensor<2x3xf32> - return -} diff --git a/tests/e2e/xla_ops/rng_normal.mlir b/tests/e2e/xla_ops/rng_normal.mlir deleted file mode 100644 index 3711ca79a923..000000000000 --- a/tests/e2e/xla_ops/rng_normal.mlir +++ /dev/null @@ -1,11 +0,0 @@ -func.func @rng_normal_2d() { - %mu = util.unfoldable_constant dense<0.0> : tensor - %sigma = util.unfoldable_constant dense<1.0> : tensor - %shape = util.unfoldable_constant dense<[3, 5]> : tensor<2xi64> - %res = "mhlo.rng"(%mu, %sigma, %shape) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<2xi64>) -> tensor<3x5xf32> - check.expect_almost_eq_const(%res, - dense<[[0.570861, 0.317593, -0.726538, 1.45925, -1.59632], - [-0.639956, 0.703875, -0.8801, -0.848389, -0.453391], - [0.645563, 0.543174, 0.2255, 0.0809385, -1.17198]]> : tensor<3x5xf32>) : tensor<3x5xf32> - return -} diff --git a/tests/e2e/xla_ops/rng_uniform.mlir b/tests/e2e/xla_ops/rng_uniform.mlir deleted file mode 100644 index 8f1bb836033d..000000000000 --- a/tests/e2e/xla_ops/rng_uniform.mlir +++ /dev/null @@ -1,34 +0,0 @@ -// Note that they are stateless random generators, so they have fixed results. 
-func.func @rng_uniform_1d() {
-  %min = util.unfoldable_constant dense<-10.0> : tensor<f32>
-  %max = util.unfoldable_constant dense<10.0> : tensor<f32>
-  %shape = util.unfoldable_constant dense<[10]> : tensor<1xi32>
-  %res = "mhlo.rng"(%min, %max, %shape) {rng_distribution = #mhlo.rng_distribution<UNIFORM>} : (tensor<f32>, tensor<f32>, tensor<1xi32>) -> tensor<10xf32>
-  check.expect_almost_eq_const(%res, dense<[
-      -9.99994, -4.8613, 0.277344, 5.41599, -9.44537, -4.30673, 0.831918, 5.97056, -8.8908, -3.75215
-  ]> : tensor<10xf32>) : tensor<10xf32>
-  return
-}
-
-func.func @rng_uniform_2d() {
-  %min = util.unfoldable_constant dense<-10.0> : tensor<f32>
-  %max = util.unfoldable_constant dense<10.0> : tensor<f32>
-  %shape = util.unfoldable_constant dense<[3, 3]> : tensor<2xi32>
-  %res = "mhlo.rng"(%min, %max, %shape) {rng_distribution = #mhlo.rng_distribution<UNIFORM>} : (tensor<f32>, tensor<f32>, tensor<2xi32>) -> tensor<3x3xf32>
-  check.expect_almost_eq_const(%res, dense<[
-      [6.55154, -8.30982, -3.17117],
-      [1.75741, 6.89606, -7.9653],
-      [-3.03671, 2.10193, 7.24057]]> : tensor<3x3xf32>) : tensor<3x3xf32>
-  return
-}
-
-func.func @rng_uniform_3d() {
-  %min = util.unfoldable_constant dense<-10.0> : tensor<f32>
-  %max = util.unfoldable_constant dense<10.0> : tensor<f32>
-  %shape = util.unfoldable_constant dense<[2, 2, 2]> : tensor<3xi32>
-  %res = "mhlo.rng"(%min, %max, %shape) {rng_distribution = #mhlo.rng_distribution<UNIFORM>} : (tensor<f32>, tensor<f32>, tensor<3xi32>) -> tensor<2x2x2xf32>
-  check.expect_almost_eq_const(%res, dense<[
-      [[3.04814, 8.18679], [-1.74598, 3.39266]],
-      [[-6.91349, -1.77484], [8.29239, -6.56897]]]> : tensor<2x2x2xf32>) : tensor<2x2x2xf32>
-  return
-}
diff --git a/tests/e2e/xla_ops/round.mlir b/tests/e2e/xla_ops/round.mlir
deleted file mode 100644
index 29e5bc42786c..000000000000
--- a/tests/e2e/xla_ops/round.mlir
+++ /dev/null
@@ -1,7 +0,0 @@
-func.func @tensor() {
-  %input = util.unfoldable_constant dense<[-0.7, -0.5, -0.2, 0.0, 0.2, 0.5, 0.7]> : tensor<7xf32>
-  %result = "mhlo.round_nearest_afz"(%input) : (tensor<7xf32>) -> tensor<7xf32>
-  check.expect_almost_eq_const(%result, dense<[-1.0, -1.0, 0.0, 0.0, 0.0, 1.0, 1.0]> : tensor<7xf32>) : tensor<7xf32>
-  return
-}
-
diff --git a/tests/e2e/xla_ops/rsqrt.mlir b/tests/e2e/xla_ops/rsqrt.mlir
deleted file mode 100644
index 595bfc67052d..000000000000
--- a/tests/e2e/xla_ops/rsqrt.mlir
+++ /dev/null
@@ -1,13 +0,0 @@
-func.func @tensor() {
-  %input = util.unfoldable_constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
-  %result = "mhlo.rsqrt"(%input) : (tensor<4xf32>) -> tensor<4xf32>
-  check.expect_almost_eq_const(%result, dense<[1.0, 0.707107, 0.57735, 0.5]> : tensor<4xf32>) : tensor<4xf32>
-  return
-}
-
-func.func @scalar() {
-  %input = util.unfoldable_constant dense<16.0> : tensor<f32>
-  %result = "mhlo.rsqrt"(%input) : (tensor<f32>) -> tensor<f32>
-  check.expect_almost_eq_const(%result, dense<0.25> : tensor<f32>) : tensor<f32>
-  return
-}
diff --git a/tests/e2e/xla_ops/scatter.mlir b/tests/e2e/xla_ops/scatter.mlir
deleted file mode 100644
index e994f1dc6659..000000000000
--- a/tests/e2e/xla_ops/scatter.mlir
+++ /dev/null
@@ -1,238 +0,0 @@
-func.func @scatter_update_scalar_1D() {
-  %arg0 = util.unfoldable_constant dense<0> : tensor<8xi32>
-  %arg1 = util.unfoldable_constant dense<[[1], [3], [4], [7]]> : tensor<4x1xi32>
-  %arg2 = util.unfoldable_constant dense<[9, 10, 11, 12]> : tensor<4xi32>
-  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
-  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):  // no predecessors
-    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
-  }) {
-    indices_are_sorted = false,
-    scatter_dimension_numbers = #mhlo.scatter<
-      inserted_window_dims = [0],
-      scatter_dims_to_operand_dims = [0],
-      index_vector_dim = 1,
-    >,
-    unique_indices = true
-  } : (tensor<8xi32>, tensor<4x1xi32>, tensor<4xi32>) -> tensor<8xi32>
-  check.expect_eq_const(%0, dense<[0, 9, 0, 10, 11, 0, 0, 12]> : tensor<8xi32>) : tensor<8xi32>
-  return
-}
-
-func.func @scatter_repeated_update_scalar_1D() {
-  %arg0 = util.unfoldable_constant dense<0> : tensor<8xi32>
-  %arg1 = util.unfoldable_constant dense<[[1], [1], [7], [7]]> : tensor<4x1xi32>
-  %arg2 = util.unfoldable_constant dense<[9, 10, 11, 12]> : tensor<4xi32>
-  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
-  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):  // no predecessors
-    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
-  }) {
-    indices_are_sorted = false,
-    scatter_dimension_numbers = #mhlo.scatter<
-      inserted_window_dims = [0],
-      scatter_dims_to_operand_dims = [0],
-      index_vector_dim = 1,
-    >,
-    unique_indices = false
-  } : (tensor<8xi32>, tensor<4x1xi32>, tensor<4xi32>) -> tensor<8xi32>
-  check.expect_eq_const(%0, dense<[0, 10, 0, 0, 0, 0, 0, 12]> : tensor<8xi32>) : tensor<8xi32>
-  return
-}
-
-func.func @scatter_update_scalar_2D() {
-  %arg0 = util.unfoldable_constant dense<0> : tensor<4x3xi32>
-  %arg1 = util.unfoldable_constant dense<[[0, 0], [1, 1], [2, 2]]> : tensor<3x2xi32>
-  %arg2 = util.unfoldable_constant dense<[1, 2, 3]> : tensor<3xi32>
-  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
-  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):  // no predecessors
-    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
-  }) {indices_are_sorted = false,
-    scatter_dimension_numbers = #mhlo.scatter<
-      inserted_window_dims = [0, 1],
-      scatter_dims_to_operand_dims = [0, 1],
-      index_vector_dim = 1
-    >,
-    unique_indices = true
-  } : (tensor<4x3xi32>, tensor<3x2xi32>, tensor<3xi32>) -> tensor<4x3xi32>
-  check.expect_eq_const(%0, dense<[[1, 0, 0],
-                                   [0, 2, 0],
-                                   [0, 0, 3],
-                                   [0, 0, 0]]> : tensor<4x3xi32>) : tensor<4x3xi32>
-  return
-}
-
-func.func @scatter_update_slice_2D() {
-  %arg0 = util.unfoldable_constant dense<0> : tensor<6x3xi32>
-  %arg1 = util.unfoldable_constant dense<[[2], [4]]> : tensor<2x1xi32>
-  %arg2 = util.unfoldable_constant dense<[[1, 2, 3],
-                                          [4, 5, 6]]> : tensor<2x3xi32>
-  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
-  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):  // no predecessors
-    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
-  }) {
-    indices_are_sorted = false,
-    scatter_dimension_numbers = #mhlo.scatter<
-      update_window_dims = [1],
-      inserted_window_dims = [0],
-      scatter_dims_to_operand_dims = [0],
-      index_vector_dim = 1,
-    >,
-    unique_indices = true
-  } : (tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x3xi32>) -> tensor<6x3xi32>
-  check.expect_eq_const(%0, dense<[[0, 0, 0],
-                                   [0, 0, 0],
-                                   [1, 2, 3],
-                                   [0, 0, 0],
-                                   [4, 5, 6],
-                                   [0, 0, 0]]> : tensor<6x3xi32>) : tensor<6x3xi32>
-  return
-}
-
-func.func @scatter_update_slice_partial_2D() {
-  %arg0 = util.unfoldable_constant dense<0> : tensor<6x3xi32>
-  %arg1 = util.unfoldable_constant dense<[[2], [4]]> : tensor<2x1xi32>
-  %arg2 = util.unfoldable_constant dense<[[1, 2],
-                                          [4, 5]]> : tensor<2x2xi32>
-  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
-  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):  // no predecessors
-    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
-  }) {
-    indices_are_sorted = false,
-    scatter_dimension_numbers = #mhlo.scatter<
-      update_window_dims = [1],
-      inserted_window_dims = [0],
-      scatter_dims_to_operand_dims = [0],
-      index_vector_dim = 1,
-    >,
-    unique_indices = true
-  } : (tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x2xi32>) -> tensor<6x3xi32>
-  check.expect_eq_const(%0, dense<[[0, 0, 0],
-                                   [0, 0, 0],
-                                   [1, 2, 0],
-                                   [0, 0, 0],
-                                   [4, 5, 0],
-                                   [0, 0, 0]]> : tensor<6x3xi32>) : tensor<6x3xi32>
-  return
-}
-
-func.func @scatter_add_slice_2D() {
-  %arg0 = util.unfoldable_constant dense<1> : tensor<6x3xi32>
-  %arg1 = util.unfoldable_constant dense<[[2], [4]]> : tensor<2x1xi32>
-  %arg2 = util.unfoldable_constant dense<[[1, 2, 3],
-                                          [4, 5, 6]]> : tensor<2x3xi32>
-  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
-  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):  // no predecessors
-    %1 = mhlo.add %arg3, %arg4 : tensor<i32>
-    "mhlo.return"(%1) : (tensor<i32>) -> ()
-  }) {
-    indices_are_sorted = false,
-    scatter_dimension_numbers = #mhlo.scatter<
-      update_window_dims = [1],
-      inserted_window_dims = [0],
-      scatter_dims_to_operand_dims = [0],
-      index_vector_dim = 1,
-    >,
-    unique_indices = true
-  } : (tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x3xi32>) -> tensor<6x3xi32>
-  check.expect_eq_const(%0, dense<[[1, 1, 1],
-                                   [1, 1, 1],
-                                   [2, 3, 4],
-                                   [1, 1, 1],
-                                   [5, 6, 7],
-                                   [1, 1, 1]]> : tensor<6x3xi32>) : tensor<6x3xi32>
-  return
-}
-
-func.func @scatter_1D_large() {
-  %original = util.unfoldable_constant dense<1> : tensor<1400xi32>
-  %update = util.unfoldable_constant dense<2> : tensor<1400xi32>
-  %init = tensor.empty() : tensor<1400xi32>
-  %indices = linalg.generic {
-      indexing_maps = [affine_map<(d0) -> (d0)>],
-      iterator_types = ["parallel"]}
-      outs(%init : tensor<1400xi32>) {
-    ^bb0(%arg0: i32):
-      %0 = linalg.index 0 : index
-      %1 = arith.index_cast %0 : index to i32
-      linalg.yield %1 : i32
-  } -> tensor<1400xi32>
-  %indices_reshaped = tensor.expand_shape %indices [[0, 1]] :
-      tensor<1400xi32> into tensor<1400x1xi32>
-  %result = "mhlo.scatter"(%original, %indices_reshaped, %update)({
-  ^bb0(%arg3 : tensor<i32>, %arg4 : tensor<i32>):
-    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
-  }) {
-    indices_are_sorted = false,
-    scatter_dimension_numbers = #mhlo.scatter<
-      inserted_window_dims = [0],
-      scatter_dims_to_operand_dims = [0],
-      index_vector_dim = 1,
-    >,
-    unique_indices = true
-  } : (tensor<1400xi32>, tensor<1400x1xi32>, tensor<1400xi32>) -> tensor<1400xi32>
-  check.expect_eq_const(%result, dense<2> : tensor<1400xi32>) : tensor<1400xi32>
-  return
-}
-
-func.func @scatter_2D_large() {
-  %original = util.unfoldable_constant dense<1> : tensor<200x300xi32>
-  %update = util.unfoldable_constant dense<2> : tensor<200x300xi32>
-  %init = tensor.empty() : tensor<200xi32>
-  %indices = linalg.generic {
-      indexing_maps = [affine_map<(d0) -> (d0)>],
-      iterator_types = ["parallel"]}
-      outs(%init : tensor<200xi32>) {
-    ^bb0(%arg0: i32):
-      %0 = linalg.index 0 : index
-      %1 = arith.index_cast %0 : index to i32
-      linalg.yield %1 : i32
-  } -> tensor<200xi32>
-  %indices_reshaped = tensor.expand_shape %indices [[0, 1]] :
-      tensor<200xi32> into tensor<200x1xi32>
-  %result = "mhlo.scatter"(%original, %indices_reshaped, %update)({
-  ^bb0(%arg3 : tensor<i32>, %arg4 : tensor<i32>):
-    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
-  }) {
-    indices_are_sorted = false,
-    scatter_dimension_numbers = #mhlo.scatter<
-      update_window_dims = [1],
-      inserted_window_dims = [0],
-      scatter_dims_to_operand_dims = [0],
-      index_vector_dim = 1,
-    >,
-    unique_indices = true
-  } : (tensor<200x300xi32>, tensor<200x1xi32>, tensor<200x300xi32>) -> tensor<200x300xi32>
-  check.expect_eq_const(%result, dense<2> : tensor<200x300xi32>) : tensor<200x300xi32>
-  return
-}
-
-func.func @scatter_2D_large_permuted() {
-  %original = util.unfoldable_constant dense<1> : tensor<200x300xi32>
-  %update = util.unfoldable_constant dense<2> : tensor<300x200xi32>
-  %init = tensor.empty() : tensor<300xi32>
-  %indices = linalg.generic {
-      indexing_maps = [affine_map<(d0) -> (d0)>],
-      iterator_types = ["parallel"]}
-      outs(%init : tensor<300xi32>) {
-    ^bb0(%arg0: i32):
-      %0 = linalg.index 0 : index
-      %1 = arith.index_cast %0 : index to i32
-      linalg.yield %1 : i32
-  } -> tensor<300xi32>
-  %indices_reshaped = tensor.expand_shape %indices [[0, 1]] :
-      tensor<300xi32> into tensor<300x1xi32>
-  %result = "mhlo.scatter"(%original, %indices_reshaped, %update)({
-  ^bb0(%arg3 : tensor<i32>, %arg4 : tensor<i32>):
-    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
-  }) {
-    indices_are_sorted = false,
-    scatter_dimension_numbers = #mhlo.scatter<
-      update_window_dims = [1],
-      inserted_window_dims = [1],
-      scatter_dims_to_operand_dims = [1],
-      index_vector_dim = 1,
-    >,
-    unique_indices = true
-  } : (tensor<200x300xi32>, tensor<300x1xi32>, tensor<300x200xi32>) -> tensor<200x300xi32>
-  check.expect_eq_const(%result, dense<2> : tensor<200x300xi32>) : tensor<200x300xi32>
-  return
-}
diff --git a/tests/e2e/xla_ops/scatter_dynamic.mlir b/tests/e2e/xla_ops/scatter_dynamic.mlir
deleted file mode 100644
index a7bb913a0154..000000000000
--- a/tests/e2e/xla_ops/scatter_dynamic.mlir
+++ /dev/null
@@ -1,28 +0,0 @@
-func.func @scatter_add_slice_2D_dynamic_num_updates() {
-  %arg0 = util.unfoldable_constant dense<1> : tensor<6x3xi32>
-  %arg1 = flow.tensor.constant dense<[[2], [4]]> : tensor<2x1xi32> -> tensor<?x1xi32>
-  %arg2 = flow.tensor.constant dense<[[1, 2, 3],
-                                      [4, 5, 6]]> : tensor<2x3xi32> -> tensor<?x3xi32>
-  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
-  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):  // no predecessors
-    %1 = mhlo.add %arg3, %arg4 : tensor<i32>
-    "mhlo.return"(%1) : (tensor<i32>) -> ()
-  }) {
-    indices_are_sorted = false,
-    scatter_dimension_numbers = #mhlo.scatter<
-      update_window_dims = [1],
-      inserted_window_dims = [0],
-      scatter_dims_to_operand_dims = [0],
-      index_vector_dim = 1,
-    >,
-    unique_indices = false
-  } : (tensor<6x3xi32>, tensor<?x1xi32>, tensor<?x3xi32>) -> tensor<6x3xi32>
-  check.expect_eq_const(%0, dense<[[1, 1, 1],
-                                   [1, 1, 1],
-                                   [2, 3, 4],
-                                   [1, 1, 1],
-                                   [5, 6, 7],
-                                   [1, 1, 1]]> : tensor<6x3xi32>) : tensor<6x3xi32>
-  return
-}
-
diff --git a/tests/e2e/xla_ops/select.mlir b/tests/e2e/xla_ops/select.mlir
deleted file mode 100644
index 20570591f2fb..000000000000
--- a/tests/e2e/xla_ops/select.mlir
+++ /dev/null
@@ -1,10 +0,0 @@
-func.func @select() {
-  %input = util.unfoldable_constant dense<[1, 0, 1, 0]> : tensor<4xi1>
-  %zeros = util.unfoldable_constant dense<0> : tensor<4xi1>
-  %cond = "mhlo.compare"(%input, %zeros) {comparison_direction = #mhlo<comparison_direction NE>} : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
-  %lhs = util.unfoldable_constant dense<[1, 2, 3, 4]> : tensor<4xi32>
-  %rhs = util.unfoldable_constant dense<[5, 6, 7, 8]> : tensor<4xi32>
-  %result = "mhlo.select"(%cond, %lhs, %rhs) : (tensor<4xi1>, tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
-  check.expect_eq_const(%result, dense<[1, 6, 3, 8]> : tensor<4xi32>) : tensor<4xi32>
-  return
-}
diff --git a/tests/e2e/xla_ops/sine.mlir b/tests/e2e/xla_ops/sine.mlir
deleted file mode 100644
index 8d697405bae1..000000000000
--- a/tests/e2e/xla_ops/sine.mlir
+++ /dev/null
@@ -1,13 +0,0 @@
-func.func @tensor() {
-  %input = util.unfoldable_constant dense<[0.0, 1.0, 1.5, 2.0]> : tensor<4xf32>
-  %result = "mhlo.sine"(%input) : (tensor<4xf32>) -> tensor<4xf32>
-  check.expect_almost_eq_const(%result, dense<[0.0, 0.8415, 0.9975, 0.9093]> : tensor<4xf32>) : tensor<4xf32>
-  return
-}
-
-func.func @scalar() {
-  %input = util.unfoldable_constant dense<3.0> : tensor<f32>
-  %result = "mhlo.sine"(%input) : (tensor<f32>) -> tensor<f32>
-  check.expect_almost_eq_const(%result, dense<0.14112> : tensor<f32>) : tensor<f32>
-  return
-}
diff --git a/tests/e2e/xla_ops/slice.mlir b/tests/e2e/xla_ops/slice.mlir
deleted file mode 100644
index 27a6875d1ae0..000000000000
--- a/tests/e2e/xla_ops/slice.mlir
+++ /dev/null
@@ -1,60 +0,0 @@
-func.func @slice_whole_buffer() {
-  %input = util.unfoldable_constant dense<[
-      [01, 02, 03, 04],
-      [05, 06, 07, 08],
-      [09, 10, 11, 12]]> : tensor<3x4xi32>
-  %result = "mhlo.slice"(%input) {
-    start_indices = dense<[0, 0]> : tensor<2xi64>,
-    limit_indices = dense<[3, 4]> : tensor<2xi64>,
-    strides = dense<1> : tensor<2xi64>
-  } : (tensor<3x4xi32>) -> tensor<3x4xi32>
-  check.expect_eq_const(%result, dense<[
-      [1, 2, 3, 4],
-      [5, 6, 7, 8],
-      [9, 10, 11, 12]]> : tensor<3x4xi32>) : tensor<3x4xi32>
-  return
-}
-
-func.func @slice_whole_stride() {
-  %input = util.unfoldable_constant dense<[
-      [01, 02, 03, 04],
-      [05, 06, 07, 08],
-      [09, 10, 11, 12]]> : tensor<3x4xi32>
-  %result = "mhlo.slice"(%input) {
-    start_indices = dense<[1, 0]> : tensor<2xi64>,
-    limit_indices = dense<[2, 4]> : tensor<2xi64>,
-    strides = dense<1> : tensor<2xi64>
-  } : (tensor<3x4xi32>) -> tensor<1x4xi32>
-  check.expect_eq_const(%result, dense<[[5, 6, 7, 8]]> : tensor<1x4xi32>) : tensor<1x4xi32>
-  return
-}
-
-func.func @slice_stride_part() {
-  %input = util.unfoldable_constant dense<[
-      [01, 02, 03, 04],
-      [05, 06, 07, 08],
-      [09, 10, 11, 12]]> : tensor<3x4xi32>
-  %result = "mhlo.slice"(%input) {
-    start_indices = dense<[1, 1]> : tensor<2xi64>,
-    limit_indices = dense<[2, 3]> : tensor<2xi64>,
-    strides = dense<1> : tensor<2xi64>
-  } : (tensor<3x4xi32>) -> tensor<1x2xi32>
-  check.expect_eq_const(%result, dense<[[6, 7]]> : tensor<1x2xi32>) : tensor<1x2xi32>
-  return
-}
-
-func.func @slice_multi_stride() {
-  %input = util.unfoldable_constant dense<[
-      [01, 02, 03, 04],
-      [05, 06, 07, 08],
-      [09, 10, 11, 12]]> : tensor<3x4xi32>
-  %result = "mhlo.slice"(%input) {
-    start_indices = dense<[1, 0]> : tensor<2xi64>,
-    limit_indices = dense<[3, 4]> : tensor<2xi64>,
-    strides = dense<1> : tensor<2xi64>
-  } : (tensor<3x4xi32>) -> tensor<2x4xi32>
-  check.expect_eq_const(%result, dense<[
-      [5, 6, 7, 8],
-      [9, 10, 11, 12]]> : tensor<2x4xi32>) : tensor<2x4xi32>
-  return
-}
diff --git a/tests/e2e/xla_ops/sort.mlir b/tests/e2e/xla_ops/sort.mlir
deleted file mode 100644
index 81a5e1d9d5b1..000000000000
--- a/tests/e2e/xla_ops/sort.mlir
+++ /dev/null
@@ -1,53 +0,0 @@
-func.func @sort1D() {
-  %input = util.unfoldable_constant dense<[3, 2, 1, 4]> : tensor<4xi32>
-
-  %sort = "mhlo.sort"(%input) ( {
-  ^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>):  // no predecessors
-    %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
-    "mhlo.return"(%compare) : (tensor<i1>) -> ()
-  }) {dimension = 0 : i64, is_stable = false} : (tensor<4xi32>) -> tensor<4xi32>
-
-  check.expect_eq_const(%sort, dense<[1, 2, 3, 4]> : tensor<4xi32>) : tensor<4xi32>
-  return
-}
-
-func.func @sort2D() {
-  %input = util.unfoldable_constant dense<[[1, 2, 3, 4],
-                                           [4, 3, 2, 1]]> : tensor<2x4xi32>
-
-  %sort = "mhlo.sort"(%input) ( {
-  ^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>):  // no predecessors
-    %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
-    "mhlo.return"(%compare) : (tensor<i1>) -> ()
-  }) {dimension = 1 : i64, is_stable = false} : (tensor<2x4xi32>) -> tensor<2x4xi32>
-
-  check.expect_eq_const(%sort, dense<[[1, 2, 3, 4], [1, 2, 3, 4]]> : tensor<2x4xi32>) : tensor<2x4xi32>
-  return
-}
-
-func.func @sort3D() {
-  %input = util.unfoldable_constant dense<[[[1, 2, 3, 4],
-                                            [4, 3, 2, 1]]]> : tensor<1x2x4xi32>
-
-  %sort = "mhlo.sort"(%input) ( {
-  ^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>):  // no predecessors
-    %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
-    "mhlo.return"(%compare) : (tensor<i1>) -> ()
-  }) {dimension = 2 : i64, is_stable = false} : (tensor<1x2x4xi32>) -> tensor<1x2x4xi32>
-
-  check.expect_eq_const(%sort, dense<[[[1, 2, 3, 4], [1, 2, 3, 4]]]> : tensor<1x2x4xi32>) : tensor<1x2x4xi32>
-  return
-}
-
-func.func @sort_to_decreasing_seq() {
-  %input = util.unfoldable_constant dense<[3, 2, 1, 4]> : tensor<4xi32>
-
-  %sort = "mhlo.sort"(%input) ( {
-  ^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>):  // no predecessors
-    %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = #mhlo<comparison_direction GT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
-    "mhlo.return"(%compare) : (tensor<i1>) -> ()
-  }) {dimension = 0 : i64, is_stable = false} : (tensor<4xi32>) -> tensor<4xi32>
-
-  check.expect_eq_const(%sort, dense<[4, 3, 2, 1]> : tensor<4xi32>) : tensor<4xi32>
-  return
-}
diff --git a/tests/e2e/xla_ops/sqrt.mlir b/tests/e2e/xla_ops/sqrt.mlir
deleted file mode 100644
index 54f1256b24a2..000000000000
--- a/tests/e2e/xla_ops/sqrt.mlir
+++ /dev/null
@@ -1,13 +0,0 @@
-func.func @tensor() {
-  %input = util.unfoldable_constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
-  %result = "mhlo.sqrt"(%input) : (tensor<4xf32>) -> tensor<4xf32>
-  check.expect_almost_eq_const(%result, dense<[1.0, 1.4142, 1.7321, 2.0]> : tensor<4xf32>) : tensor<4xf32>
-  return
-}
-
-func.func @scalar() {
-  %input = util.unfoldable_constant dense<16.0> : tensor<f32>
-  %result = "mhlo.sqrt"(%input) : (tensor<f32>) -> tensor<f32>
-  check.expect_almost_eq_const(%result, dense<4.0> : tensor<f32>) : tensor<f32>
-  return
-}
diff --git a/tests/e2e/xla_ops/subtract.mlir b/tests/e2e/xla_ops/subtract.mlir
deleted file mode 100644
index 4d6aa795bae9..000000000000
--- a/tests/e2e/xla_ops/subtract.mlir
+++ /dev/null
@@ -1,15 +0,0 @@
-func.func @i32() {
-  %0 = util.unfoldable_constant dense<[5, 6, 3, 4]> : tensor<4xi32>
-  %1 = util.unfoldable_constant dense<[1, 4, 7, 6]> : tensor<4xi32>
-  %result = "mhlo.subtract"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
-  check.expect_eq_const(%result, dense<[4, 2, -4, -2]> : tensor<4xi32>) : tensor<4xi32>
-  return
-}
-
-func.func @f32() {
-  %0 = util.unfoldable_constant dense<[5.0, 6.0, 3.0, 4.0]> : tensor<4xf32>
-  %1 = util.unfoldable_constant dense<[1.0, 4.0, 7.0, 6.0]> : tensor<4xf32>
-  %result = "mhlo.subtract"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
-  check.expect_almost_eq_const(%result, dense<[4.0, 2.0, -4.0, -2.0]> : tensor<4xf32>) : tensor<4xf32>
-  return
-}
diff --git a/tests/e2e/xla_ops/tanh.mlir b/tests/e2e/xla_ops/tanh.mlir
deleted file mode 100644
index db15e8c7ffde..000000000000
--- a/tests/e2e/xla_ops/tanh.mlir
+++ /dev/null
@@ -1,10 +0,0 @@
-func.func @tanh() {
-  %input = util.unfoldable_constant dense<
-      [[-100.0, -5.0, -0.5, 1.0],
-       [ 1.2, 2.0, 3.0, 100.0]]> : tensor<2x4xf32>
-  %result = "mhlo.tanh"(%input) : (tensor<2x4xf32>) -> tensor<2x4xf32>
-  check.expect_almost_eq_const(%result, dense<
-      [[-1.0000, -0.9999, -0.4622, 0.7616],
-       [ 0.8337, 0.9640, 0.9951, 1.0000]]> : tensor<2x4xf32>) : tensor<2x4xf32>
-  return
-}
diff --git a/tests/e2e/xla_ops/torch_index_select.mlir b/tests/e2e/xla_ops/torch_index_select.mlir
deleted file mode 100644
index 6284fece939d..000000000000
--- a/tests/e2e/xla_ops/torch_index_select.mlir
+++ /dev/null
@@ -1,45 +0,0 @@
-func.func @torch_select_index_0() {
-  %input = util.unfoldable_constant dense<[
-    [[01, 02, 03, 04, 05]],
-    [[06, 07, 08, 09, 10]],
-    [[11, 12, 13, 14, 15]],
-    [[16, 17, 18, 19, 20]],
-    [[21, 22, 23, 24, 25]]]> : tensor<5x1x5xi32>
-  %indices = util.unfoldable_constant dense<[0, 2]> : tensor<2xi32>
-  %res = "mhlo.torch_index_select"(%input, %indices) {
-    dim = 0 : i64,
-    batch_dims = 0 : i64
-  } : (tensor<5x1x5xi32>, tensor<2xi32>) -> tensor<2x1x5xi32>
-  check.expect_eq_const(%res, dense<[[[01, 02, 03, 04, 05]], [[11, 12, 13, 14, 15]]]> : tensor<2x1x5xi32>) : tensor<2x1x5xi32>
-  return
-}
-
-func.func @torch_select_index_1() {
-  %input = util.unfoldable_constant dense<[
-    [[ 1, 2],[ 3, 4]],
-    [[ 5, 6],[ 7, 8]],
-    [[ 9, 10],[11, 12]]]> : tensor<3x2x2xi32>
-  %indices = util.unfoldable_constant dense<[0, 1]> : tensor<2xi32>
-  %res = "mhlo.torch_index_select"(%input, %indices) {
-    dim = 1 : i64,
-    batch_dims = 0 : i64
-  } : (tensor<3x2x2xi32>, tensor<2xi32>) -> tensor<3x2x2xi32>
-  check.expect_eq_const(%res, dense<[[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]> : tensor<3x2x2xi32>) : tensor<3x2x2xi32>
-  return
-}
-
-func.func @torch_select_index_2() {
-  %input = util.unfoldable_constant dense<[
-    [[01, 02, 03, 04, 05]],
-    [[06, 07, 08, 09, 10]],
-    [[11, 12, 13, 14, 15]],
-    [[16, 17, 18, 19, 20]],
-    [[21, 22, 23, 24, 25]]]> : tensor<5x1x5xi32>
-  %indices = util.unfoldable_constant dense<0> : tensor<i32>
-  %res = "mhlo.torch_index_select"(%input, %indices) {
-    dim = 0 : i64,
-    batch_dims = 0 : i64
-  } : (tensor<5x1x5xi32>, tensor<i32>) -> tensor<1x5xi32>
-  check.expect_eq_const(%res, dense<[[01, 02, 03, 04, 05]]> : tensor<1x5xi32>) : tensor<1x5xi32>
-  return
-}
diff --git a/tests/e2e/xla_ops/transpose.mlir b/tests/e2e/xla_ops/transpose.mlir
deleted file mode 100644
index 34e769d97f9b..000000000000
--- a/tests/e2e/xla_ops/transpose.mlir
+++ /dev/null
@@ -1,29 +0,0 @@
-func.func @transpose_2d() {
-  %input = util.unfoldable_constant dense<[[1, 2, 3],
-                                           [4, 5, 6]]> : tensor<2x3xi32>
-  %0 = "mhlo.transpose"(%input) {
-    permutation = dense<[1, 0]> : tensor<2xi64>
-  } : (tensor<2x3xi32>) -> tensor<3x2xi32>
-  check.expect_eq_const(%0, dense<[[1, 4],
-                                   [2, 5],
-                                   [3, 6]]> : tensor<3x2xi32>) : tensor<3x2xi32>
-  return
-}
-
-func.func @transpose_3d() {
-  %input = util.unfoldable_constant dense<[[[ 1, 2, 3],
-                                            [ 4, 5, 6]],
-                                           [[ 7, 8, 9],
-                                            [10, 11, 12]]]> : tensor<2x2x3xi32>
-  %0 = "mhlo.transpose"(%input) {
-    permutation = dense<[0, 2, 1]> : tensor<3xi64>
-  } : (tensor<2x2x3xi32>) -> tensor<2x3x2xi32>
-  check.expect_eq_const(%0, dense<[
-      [[ 1, 4],
-       [ 2, 5],
-       [ 3, 6]],
-      [[ 7, 10],
-       [ 8, 11],
-       [ 9, 12]]]> : tensor<2x3x2xi32>) : tensor<2x3x2xi32>
-  return
-}
diff --git a/tests/e2e/xla_ops/while.mlir b/tests/e2e/xla_ops/while.mlir
deleted file mode 100644
index c079172fa04c..000000000000
--- a/tests/e2e/xla_ops/while.mlir
+++ /dev/null
@@ -1,17 +0,0 @@
-// NOTE: this has already been legalized to CFG form in the TF import tools.
-func.func @while() {
-  %start = util.unfoldable_constant dense<1> : tensor<i32>
-  %bound = util.unfoldable_constant dense<3> : tensor<i32>
-  %cst_1 = arith.constant dense<4> : tensor<i32>
-  cf.br ^bb1(%start : tensor<i32>)
-^bb1(%2: tensor<i32>):
-  %3 = "mhlo.compare"(%2, %bound) {comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
-  %4 = tensor.extract %3[] : tensor<i1>
-  cf.cond_br %4, ^bb2(%2 : tensor<i32>), ^bb3(%2 : tensor<i32>)
-^bb2(%5: tensor<i32>):
-  %6 = mhlo.add %5, %5 : tensor<i32>
-  cf.br ^bb1(%6 : tensor<i32>)
-^bb3(%7: tensor<i32>):
-  check.expect_eq_const(%7, dense<4> : tensor<i32>) : tensor<i32>
-  return
-}