Skip to content

Commit

Permalink
[StableHLO] Add e2e fft tests (#13788)
Browse files Browse the repository at this point in the history
These require some stablehlo canonicalization patterns during lowering
to IREE dialects.

I also discovered some issues with the generic syntax
(openxla/stablehlo#1539) and updated the tests
to use the custom assembly format.

Issue: #12678
  • Loading branch information
kuhar committed May 25, 2023
1 parent a8a70fb commit 6019731
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Dialect/Flow/IR",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"//compiler/src/iree/compiler/Dialect/Util/Transforms",
"//compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing",
"//compiler/src/iree/compiler/Utils",
"//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
"@llvm-project//llvm:Support",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ iree_cc_library(
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::Util::IR
iree::compiler::Dialect::Util::Transforms
iree::compiler::InputConversion::StableHLO::Preprocessing
iree::compiler::Utils
PUBLIC
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "iree/compiler/InputConversion/StableHLO/LegalizeToLinalgUtils.h"
#include "iree/compiler/InputConversion/StableHLO/PassDetail.h"
#include "iree/compiler/InputConversion/StableHLO/Passes.h"
#include "iree/compiler/InputConversion/StableHLO/Preprocessing/Rewriters.h"
#include "iree/compiler/InputConversion/StableHLO/Rewriters.h"
#include "iree/compiler/InputConversion/StableHLO/TypeConversion.h"
#include "iree/compiler/Utils/ConversionUtils.h"
Expand Down Expand Up @@ -427,6 +428,10 @@ struct ConvertStableHloToIreeInputDialects final
createStableHloToLinalgTypeConverter();
typeConverter->addArgumentMaterialization(scalarToTensor);

// Run stablehlo canonicalization patterns with a high benefit to avoid some
// expensive expansions.
populateCanonicalizationPatterns(context, &patterns, /*benefit=*/1024);

// TODO(#12678): Handle chlo lowering.

populateStableHloToLinalgOnTensorsConversionPatterns(
Expand Down
6 changes: 6 additions & 0 deletions tests/e2e/stablehlo_ops/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ iree_check_single_backend_test_suite(
"dynamic_update_slice.mlir",
"exponential.mlir",
"exponential_minus_one.mlir",
"fft.mlir",
"finite.mlir",
"floor.mlir",
"gather.mlir",
Expand Down Expand Up @@ -115,6 +116,7 @@ iree_check_single_backend_test_suite(
"dynamic_update_slice.mlir",
"exponential.mlir",
"exponential_minus_one.mlir",
"fft.mlir",
"finite.mlir",
"floor.mlir",
"gather.mlir",
Expand Down Expand Up @@ -219,6 +221,7 @@ iree_check_single_backend_test_suite(
],
include = ["*.mlir"],
exclude = [
"fft.mlir", # TODO(#9583)
"reverse.mlir", # TODO(#12415): disabled due to miscompilation on Pixel 6.
],
),
Expand Down Expand Up @@ -255,6 +258,7 @@ iree_check_single_backend_test_suite(
"dynamic_update_slice.mlir",
"exponential.mlir",
"exponential_minus_one.mlir",
"fft.mlir",
"finite.mlir",
"floor.mlir",
"gather.mlir",
Expand Down Expand Up @@ -343,6 +347,7 @@ iree_check_single_backend_test_suite(
"dynamic_update_slice.mlir",
"exponential.mlir",
"exponential_minus_one.mlir",
"fft.mlir",
"finite.mlir",
"floor.mlir",
"gather.mlir",
Expand Down Expand Up @@ -422,6 +427,7 @@ iree_check_single_backend_test_suite(
"dynamic_update_slice.mlir",
"exponential.mlir",
"exponential_minus_one.mlir",
"fft.mlir",
"finite.mlir",
"floor.mlir",
"gather.mlir",
Expand Down
5 changes: 5 additions & 0 deletions tests/e2e/stablehlo_ops/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ iree_check_single_backend_test_suite(
"dynamic_update_slice.mlir"
"exponential.mlir"
"exponential_minus_one.mlir"
"fft.mlir"
"finite.mlir"
"floor.mlir"
"gather.mlir"
Expand Down Expand Up @@ -102,6 +103,7 @@ iree_check_single_backend_test_suite(
"dynamic_update_slice.mlir"
"exponential.mlir"
"exponential_minus_one.mlir"
"fft.mlir"
"finite.mlir"
"floor.mlir"
"gather.mlir"
Expand Down Expand Up @@ -235,6 +237,7 @@ iree_check_single_backend_test_suite(
"dynamic_update_slice.mlir"
"exponential.mlir"
"exponential_minus_one.mlir"
"fft.mlir"
"finite.mlir"
"floor.mlir"
"gather.mlir"
Expand Down Expand Up @@ -306,6 +309,7 @@ iree_check_single_backend_test_suite(
"dynamic_update_slice.mlir"
"exponential.mlir"
"exponential_minus_one.mlir"
"fft.mlir"
"finite.mlir"
"floor.mlir"
"gather.mlir"
Expand Down Expand Up @@ -381,6 +385,7 @@ iree_check_single_backend_test_suite(
"dynamic_update_slice.mlir"
"exponential.mlir"
"exponential_minus_one.mlir"
"fft.mlir"
"finite.mlir"
"floor.mlir"
"gather.mlir"
Expand Down
27 changes: 27 additions & 0 deletions tests/e2e/stablehlo_ops/fft.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// TODO(hanchung): Add other types of fft tests, e.g. fft, ifft, irfft.

func.func @rfft_1d() {
%input = util.unfoldable_constant dense<[
9.0, 1.0, 4.5, -0.3, 10.0, -1.0, 5.5, 0.3, 299.0, 3.5, -0.777, 2.0, 1.7,
3.5, -4.5, 0.0, 9.0, 1.0, 4.5, -0.3, 10.0, -1.0, 5.5, 0.3, 299.0, 3.5,
-0.777, 2.0, 1.7, 3.5, -4.5, 0.0]> : tensor<32xf32>
%0 = stablehlo.fft %input, type = RFFT, length = [32] : (tensor<32xf32>) -> tensor<17xcomplex<f32>>
%1 = stablehlo.real %0 : (tensor<17xcomplex<f32>>) -> tensor<17xf32>
%2 = stablehlo.imag %0 : (tensor<17xcomplex<f32>>) -> tensor<17xf32>
check.expect_almost_eq_const(%1, dense<[666.8460, 0.0, -590.16925, 0.0, 593.4485, 0.0, -579.52875, 0.0, 629.95404, 0.0, -567.1126, 0.0, 591.75146, 0.0, -583.1894, 0.0, 630.846]> : tensor<17xf32>) : tensor<17xf32>
check.expect_almost_eq_const(%2, dense<[0.0, 0.0, -23.956373, 0.0, -10.254326, 0.0, -6.1443653, 0.0, -10.0, 0.0, 3.865515, 0.0, 0.63767385, 0.0, 52.453506, 0.0, 0.0]> : tensor<17xf32>) : tensor<17xf32>
return
}

func.func @rfft_2d() {
%input = util.unfoldable_constant dense<[[
9.0, 1.0, 4.5, -0.3, 10.0, -1.0, 5.5, 0.3, 299.0, 3.5, -0.777, 2.0, 1.7,
3.5, -4.5, 0.0, 9.0, 1.0, 4.5, -0.3, 10.0, -1.0, 5.5, 0.3, 299.0, 3.5,
-0.777, 2.0, 1.7, 3.5, -4.5, 0.0]]> : tensor<1x32xf32>
%0 = stablehlo.fft %input, type = RFFT, length = [32] : (tensor<1x32xf32>) -> tensor<1x17xcomplex<f32>>
%1 = stablehlo.real %0 : (tensor<1x17xcomplex<f32>>) -> tensor<1x17xf32>
%2 = stablehlo.imag %0 : (tensor<1x17xcomplex<f32>>) -> tensor<1x17xf32>
check.expect_almost_eq_const(%1, dense<[[666.8460, 0.0, -590.16925, 0.0, 593.4485, 0.0, -579.52875, 0.0, 629.95404, 0.0, -567.1126, 0.0, 591.75146, 0.0, -583.1894, 0.0, 630.846]]> : tensor<1x17xf32>) : tensor<1x17xf32>
check.expect_almost_eq_const(%2, dense<[[0.0, 0.0, -23.956373, 0.0, -10.254326, 0.0, -6.1443653, 0.0, -10.0, 0.0, 3.865515, 0.0, 0.63767385, 0.0, 52.453506, 0.0, 0.0]]> : tensor<1x17xf32>) : tensor<1x17xf32>
return
}

0 comments on commit 6019731

Please sign in to comment.