diff --git a/docs/spec.md b/docs/spec.md
index b65d60a865e..e86d4c48488 100644
--- a/docs/spec.md
+++ b/docs/spec.md
@@ -227,7 +227,7 @@ BooleanType ::= 'i1'
 IntegerType ::= 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
               | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
 FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'f8E4M3FNUZ' | 'f8E5M2FNUZ'
-            | 'bf16' | 'f16' | 'f32' | 'f64'
+            | 'f8E4M3B11FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
 ComplexType ::= 'complex' '<' ('f32' | 'f64') '>'
 ```
@@ -250,6 +250,8 @@ values of type `tensor<T>`).
   * `f8E4M3FNUZ` and `f8E5M2FNUZ` types corresponding to the `E4M3` and
     `E5M2` encodings of the FP8 formats described in
     [8-bit Numerical Formats for Deep Neural Networks](https://arxiv.org/abs/2206.02915).
+  * `f8E4M3B11FNUZ` type corresponding to the `E4M3` encoding of the FP8 formats described in
+    [Hybrid 8-bit Floating Point (HFP8) Training and Inference for Deep Neural Networks](https://proceedings.neurips.cc/paper_files/paper/2019/file/65fc9fb4897a89789352e211ca2d398f-Paper.pdf).
   * `bf16` type corresponding to the `bfloat16` format described in
     [BFloat16: The secret to high performance on Cloud TPUs](https://cloud.google.com/blog/products/ai-machine-learning/bfloat16-the-secret-to-high-performance-on-cloud-tpus).
   * `f16`, `f32` and `f64` types corresponding to respectively
diff --git a/stablehlo/dialect/Base.td b/stablehlo/dialect/Base.td
index 8934326b504..544c3eb1c2c 100644
--- a/stablehlo/dialect/Base.td
+++ b/stablehlo/dialect/Base.td
@@ -34,7 +34,7 @@ def HLO_UInt : UnsignedIntOfWidths<[4, 8, 16, 32, 64]>;
 def HLO_Int : AnyTypeOf<[HLO_SInt, HLO_UInt]>;
 
 def HLO_Float : AnyTypeOf<[F8E4M3FN, F8E5M2, F8E4M3FNUZ, F8E5M2FNUZ,
-                           F16, F32, F64, BF16]>;
+                           F8E4M3B11FNUZ, F16, F32, F64, BF16]>;
 def HLO_Float32Or64 : AnyTypeOf<[F32, F64]>;
 
 def HLO_Complex : Complex<AnyTypeOf<[F32, F64]>>;
diff --git a/stablehlo/dialect/Version.h b/stablehlo/dialect/Version.h
index 1715cccc2f3..e2109959096 100644
--- a/stablehlo/dialect/Version.h
+++ b/stablehlo/dialect/Version.h
@@ -38,7 +38,7 @@ class Version {
   static FailureOr<Version> fromString(llvm::StringRef versionRef);
 
   /// Return a Version representing the current VHLO dialect version.
-  static Version getCurrentVersion() { return Version(0, 10, 1); }
+  static Version getCurrentVersion() { return Version(0, 11, 0); }
 
   /// Return a Version representing the minimum supported VHLO dialect version.
   static Version getMinimumVersion() { return Version(0, 9, 0); }
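A note on the format itself, since the spec entry above is terse: `f8E4M3B11FNUZ` has 1 sign, 4 exponent, and 3 mantissa bits with an exponent bias of 11 (the `B11`), no infinities, and a single NaN at the bit pattern `0x80`, the pattern that would otherwise encode negative zero (the `FNUZ` part). Compared to `f8E4M3FN` (bias 7, max 448), the larger bias trades range for precision near zero: the largest finite value is ±30 and the smallest subnormal is 2^-13. A minimal standalone sketch of decoding a raw byte with LLVM's `APFloat`, the same semantics object this patch wires in (illustrative, not part of the diff):

```cpp
#include <cstdint>
#include <cstdio>

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"

// Bitcast a raw 8-bit pattern into an APFloat with f8E4M3B11FNUZ semantics,
// then widen it exactly to double for printing.
double decodeF8E4M3B11FNUZ(uint8_t raw) {
  llvm::APFloat value(llvm::APFloat::Float8E4M3B11FNUZ(), llvm::APInt(8, raw));
  bool losesInfo = false;
  value.convert(llvm::APFloat::IEEEdouble(),
                llvm::APFloat::rmNearestTiesToEven, &losesInfo);
  return value.convertToDouble();
}

int main() {
  std::printf("%g\n", decodeF8E4M3B11FNUZ(0x7F));  // 30, the largest finite value
  std::printf("%g\n", decodeF8E4M3B11FNUZ(0x01));  // 0.0001220703125, i.e. 2^-13
}
```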
diff --git a/stablehlo/dialect/VhloBytecode.cpp b/stablehlo/dialect/VhloBytecode.cpp
index 4e7c873b108..c19201010bb 100644
--- a/stablehlo/dialect/VhloBytecode.cpp
+++ b/stablehlo/dialect/VhloBytecode.cpp
@@ -324,6 +324,10 @@ enum TypeCode {
   ///   FloatF8E5M2FNUZV1Type {
   ///   }
   kFloatF8E5M2FNUZV1Type = 28,
+
+  ///   FloatF8E4M3B11FNUZV1Type {
+  ///   }
+  kFloatF8E4M3B11FNUZV1Type = 29,
 };
 
 } // namespace vhlo_encoding
@@ -702,6 +706,7 @@ const llvm::fltSemantics &getFloatSemantics(Type type) {
   if (type.isa<FloatF32V1Type>()) return APFloat::IEEEsingle();
   if (type.isa<FloatF64V1Type>()) return APFloat::IEEEdouble();
   if (type.isa<FloatF8E4M3FNUZV1Type>()) return APFloat::Float8E4M3FNUZ();
+  if (type.isa<FloatF8E4M3B11FNUZV1Type>()) return APFloat::Float8E4M3B11FNUZ();
   if (type.isa<FloatF8E4M3FNV1Type>()) return APFloat::Float8E4M3FN();
   if (type.isa<FloatF8E5M2FNUZV1Type>()) return APFloat::Float8E5M2FNUZ();
   if (type.isa<FloatF8E5M2V1Type>()) return APFloat::Float8E5M2();
@@ -975,6 +980,8 @@ Type VhloBytecodeInterface::readType(DialectBytecodeReader &reader) const {
       return FloatF8E5M2FNUZV1Type::get(getContext());
     case vhlo_encoding::kFloatF8E4M3FNUZV1Type:
       return FloatF8E4M3FNUZV1Type::get(getContext());
+    case vhlo_encoding::kFloatF8E4M3B11FNUZV1Type:
+      return FloatF8E4M3B11FNUZV1Type::get(getContext());
     case vhlo_encoding::kFunctionV1Type:
       return readFunctionV1Type(reader);
     case vhlo_encoding::kIndexV1Type:
@@ -1063,6 +1070,11 @@ LogicalResult VhloBytecodeInterface::writeType(
             return writer.writeVarInt(vhlo_encoding::kFloatF8E4M3FNUZV1Type),
                    success();
           })
+          .Case([&](FloatF8E4M3B11FNUZV1Type) {
+            LOG_WRITE_CALL;
+            return writer.writeVarInt(vhlo_encoding::kFloatF8E4M3B11FNUZV1Type),
+                   success();
+          })
          .Case([&](FloatF8E5M2FNUZV1Type) {
            LOG_WRITE_CALL;
            return writer.writeVarInt(vhlo_encoding::kFloatF8E5M2FNUZV1Type),
diff --git a/stablehlo/dialect/VhloDialect.td b/stablehlo/dialect/VhloDialect.td
index 7261c7bbb21..7b8edf34638 100644
--- a/stablehlo/dialect/VhloDialect.td
+++ b/stablehlo/dialect/VhloDialect.td
@@ -29,6 +29,7 @@ def VHLO_Dialect : Dialect {
     Version log:
       0.9.0: Initial stability guarantees.
       0.10.0: Introduce `f8E4M3FNUZ` and `f8E5M2FNUZ` types.
+      0.11.0: Introduce `f8E4M3B11FNUZ` type.
   }];
 
   let useDefaultAttributePrinterParser = 0;
diff --git a/stablehlo/dialect/VhloTypes.cpp b/stablehlo/dialect/VhloTypes.cpp
index 36cf7e94149..35a76c9e79f 100644
--- a/stablehlo/dialect/VhloTypes.cpp
+++ b/stablehlo/dialect/VhloTypes.cpp
@@ -80,6 +80,9 @@ void VhloTypeConverter::addBuiltinToVhloConversions() {
   addConversion([&](Float8E4M3FNUZType type) {
     return FloatF8E4M3FNUZV1Type::get(type.getContext());
   });
+  addConversion([&](Float8E4M3B11FNUZType type) {
+    return FloatF8E4M3B11FNUZV1Type::get(type.getContext());
+  });
   addConversion([&](Float8E5M2FNUZType type) {
     return FloatF8E5M2FNUZV1Type::get(type.getContext());
   });
@@ -152,6 +155,9 @@ void VhloTypeConverter::addVhloToBuiltinConversions() {
   addConversion([&](FloatF8E4M3FNUZV1Type type) {
     return Float8E4M3FNUZType::get(type.getContext());
   });
+  addConversion([&](FloatF8E4M3B11FNUZV1Type type) {
+    return Float8E4M3B11FNUZType::get(type.getContext());
+  });
   addConversion([&](FloatF8E5M2FNUZV1Type type) {
     return Float8E5M2FNUZType::get(type.getContext());
   });
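The `VhloDialect.td` version-log entry is what makes this a compatibility-relevant change: payloads containing the new type can only be produced at target 0.11.0 or later (the RUN lines in the new test file below use `--target=0.11.0`), and older consumers will not recognize type code 29. A sketch of the gate a producer might apply, using the `Version` API from `Version.h` above; the helper name is hypothetical and not part of this change:

```cpp
#include <cassert>

#include "stablehlo/dialect/Version.h"

using mlir::vhlo::Version;

// Hypothetical helper: f8E4M3B11FNUZ only exists in VHLO from 0.11.0 onwards,
// so a producer should refuse to serialize it to older targets.
bool targetSupportsF8E4M3B11FNUZ(const Version &target) {
  auto introducedIn = Version::fromString("0.11.0");
  assert(mlir::succeeded(introducedIn));  // a literal version string always parses
  return !(target < *introducedIn);       // i.e. target >= 0.11.0
}
```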
diff --git a/stablehlo/dialect/VhloTypes.td b/stablehlo/dialect/VhloTypes.td
index c3f8d69c103..97c79c6aa4f 100644
--- a/stablehlo/dialect/VhloTypes.td
+++ b/stablehlo/dialect/VhloTypes.td
@@ -89,6 +89,9 @@ def VHLO_FloatF8E5M2V1 : VHLO_TypeDef<"FloatF8E5M2V1", "f8E5M2_v1", "0.9.0", "cu
 
 // Corresponds to the 'f8E4M3FNUZ' FloatType from the StableHLO spec.
 def VHLO_FloatF8E4M3FNUZV1 : VHLO_TypeDef<"FloatF8E4M3FNUZV1", "f8E4M3FNUZ_v1", "0.10.0", "current">;
 
+// Corresponds to the 'f8E4M3B11FNUZ' FloatType from the StableHLO spec.
+def VHLO_FloatF8E4M3B11FNUZV1 : VHLO_TypeDef<"FloatF8E4M3B11FNUZV1", "f8E4M3B11FNUZ_v1", "0.11.0", "current">;
+
 // Corresponds to the 'f8E5M2FNUZ' FloatType from the StableHLO spec.
 def VHLO_FloatF8E5M2FNUZV1 : VHLO_TypeDef<"FloatF8E5M2FNUZV1", "f8E5M2FNUZ_v1", "0.10.0", "current">;
diff --git a/stablehlo/reference/Tensor.cpp b/stablehlo/reference/Tensor.cpp
index d2758db2a81..7f745bcdace 100644
--- a/stablehlo/reference/Tensor.cpp
+++ b/stablehlo/reference/Tensor.cpp
@@ -105,6 +105,11 @@ Element Tensor::get(const Index &index) const {
     return Element(elementType, APFloat(llvm::APFloatBase::Float8E4M3FNUZ(),
                                         APInt(8, *elementData)));
   }
+  if (elementType.isFloat8E4M3B11FNUZ()) {
+    auto elementData = reinterpret_cast<const uint8_t *>(elementPtr);
+    return Element(elementType, APFloat(llvm::APFloatBase::Float8E4M3B11FNUZ(),
+                                        APInt(8, *elementData)));
+  }
   if (elementType.isFloat8E5M2FNUZ()) {
     auto elementData = reinterpret_cast<const uint8_t *>(elementPtr);
     return Element(elementType, APFloat(llvm::APFloatBase::Float8E5M2FNUZ(),
@@ -217,7 +222,8 @@ void Tensor::set(const Index &index, const Element &element) {
       getSizeInBytes(elementType) * flattenIndex(getShape(), index);
 
   // Handle floating-point types.
-  if (elementType.isFloat8E4M3FNUZ() || elementType.isFloat8E5M2FNUZ()) {
+  if (elementType.isFloat8E4M3FNUZ() || elementType.isFloat8E4M3B11FNUZ() ||
+      elementType.isFloat8E5M2FNUZ()) {
     auto elementData = reinterpret_cast<uint8_t *>(elementPtr);
     auto value = element.getFloatValue();
     *elementData = (uint8_t)value.bitcastToAPInt().getZExtValue();
@@ -371,14 +377,16 @@ Tensor makeTensor(DenseElementsAttr attr) {
   auto elemType = type.getElementType();
 
   // Handle floating-point types.
-  if (elemType.isFloat8E4M3FNUZ() || elemType.isFloat8E5M2FNUZ()) {
+  if (elemType.isFloat8E4M3FNUZ() || elemType.isFloat8E4M3B11FNUZ() ||
+      elemType.isFloat8E5M2FNUZ()) {
     auto floatValues = llvm::to_vector(llvm::map_range(
         attr.getValues<APFloat>(), [&](APFloat value) -> uint8_t {
           return value.bitcastToAPInt().getZExtValue();
         }));
 
-    // For both f8E4M3FNUZ and f8E5M2FNUZ floating-point types, we use uint8_t
-    // as their storage type because there are no builtin types for those.
+    // For f8E4M3FNUZ, f8E4M3B11FNUZ, and f8E5M2FNUZ floating-point types, we
+    // use uint8_t as their storage type because there are no builtin types for
+    // those.
     return Tensor(type, HeapAsmResourceBlob::allocateAndCopyInferAlign<uint8_t>(
                             floatValues));
   }
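As the comment in `makeTensor` notes, there is no native 8-bit-float storage type, so the interpreter stores the raw bit pattern in a `uint8_t` and rebuilds an `APFloat` on access. A standalone sketch of that round trip, using the same `APFloat` calls as `Tensor::set`/`Tensor::get` (assumed to run outside any Tensor):

```cpp
#include <cassert>
#include <cstdint>

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"

int main() {
  // set(): an f8E4M3B11FNUZ element is stored as its raw 8-bit pattern.
  llvm::APFloat value(llvm::APFloat::Float8E4M3B11FNUZ(), "30.0");
  uint8_t stored =
      static_cast<uint8_t>(value.bitcastToAPInt().getZExtValue());
  assert(stored == 0x7F);  // 30.0 is the largest finite value

  // get(): the byte is bitcast back into an APFloat with the same semantics.
  llvm::APFloat roundTripped(llvm::APFloat::Float8E4M3B11FNUZ(),
                             llvm::APInt(8, stored));
  assert(roundTripped.bitwiseIsEqual(value));
}
```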
diff --git a/stablehlo/reference/Types.cpp b/stablehlo/reference/Types.cpp
index b1fdc4b0be5..8324642ec9d 100644
--- a/stablehlo/reference/Types.cpp
+++ b/stablehlo/reference/Types.cpp
@@ -44,8 +44,9 @@ bool isSupportedIntegerType(Type type) {
 }
 
 bool isSupportedFloatType(Type type) {
-  return type.isFloat8E4M3FNUZ() || type.isFloat8E5M2FNUZ() || type.isF16() ||
-         type.isBF16() || type.isF32() || type.isF64();
+  return type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ() ||
+         type.isFloat8E5M2FNUZ() || type.isF16() || type.isBF16() ||
+         type.isF32() || type.isF64();
 }
 
 bool isSupportedComplexType(Type type) {
diff --git a/stablehlo/tests/interpret_constant.mlir b/stablehlo/tests/interpret_constant.mlir
index 82558065fa7..04f2e2472b0 100644
--- a/stablehlo/tests/interpret_constant.mlir
+++ b/stablehlo/tests/interpret_constant.mlir
@@ -88,6 +88,13 @@ func.func @constant_op_test_f8_e4m3_fnuz() {
 
 // -----
 
+func.func @constant_op_test_f8_e4m3b11_fnuz() {
+  %0 = stablehlo.constant dense<[0.0, -0.0, 1.0, 0.125, 0.1, 3.1415, 0x7F, 0xFF, 0x01, 0x81]> : tensor<10xf8E4M3B11FNUZ>
+  check.expect_eq_const %0, dense<[0.0, 0.0, 1.0, 0.125, 0.1, 3.25, 30.0, -30.0, 0.0001220703125, -0.0001220703125]> : tensor<10xf8E4M3B11FNUZ>
+  func.return
+}
+
+// -----
 
 func.func @constant_op_test_f8_e5m2_fnuz() {
   %0 = stablehlo.constant dense<[0.0, -0.0, 1.0, 0.125, 0.1, 3.1415, 0x7F, 0xFF, 0x01, 0x81]> : tensor<10xf8E5M2FNUZ>
   check.expect_eq_const %0, dense<[0.0, 0.0, 1.0, 0.125, 0.1, 3.0, 57344.0, -57344.0, 7.62939e-06, -7.62939e-06]> : tensor<10xf8E5M2FNUZ>
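The expected values in the new interpreter test follow from round-to-nearest-even at 3 mantissa bits: 0.0, 1.0, and 0.125 are exact; `-0.0` collapses to `+0.0` because FNUZ has no negative zero; and `0.1` matches on both sides because `check.expect_eq_const` parses its literal at the same element type, so both occurrences round to the same value (1.101b × 2^-4 = 0.1015625). The byte patterns 0x7F/0xFF are the extreme finite values ±30, and 0x01/0x81 are the smallest subnormals ±2^-13. A quick check of the 3.1415 case, assuming the `APFloat` string constructor rounds to nearest even (its default):

```cpp
#include <cassert>

#include "llvm/ADT/APFloat.h"

int main() {
  // 3.1415 = 1.57075 x 2^1; the nearest 3-bit significand is 1.101b = 1.625,
  // so the constant rounds to 1.625 x 2 = 3.25, as the test expects.
  llvm::APFloat value(llvm::APFloat::Float8E4M3B11FNUZ(), "3.1415");
  llvm::APFloat expected(llvm::APFloat::Float8E4M3B11FNUZ(), "3.25");
  assert(value.bitwiseIsEqual(expected));
}
```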
diff --git a/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/tests/ops_stablehlo.mlir
index aacef7942d9..742052475ea 100644
--- a/stablehlo/tests/ops_stablehlo.mlir
+++ b/stablehlo/tests/ops_stablehlo.mlir
@@ -1437,7 +1437,7 @@ func.func @cholesky_invalid_rank(%arg0: tensor<1xf32>) -> tensor<1xf32> {
 // -----
 
 func.func @cholesky_invalid_elt(%arg0: tensor<1x2x2xi32>) -> tensor<1x2x2xi32> {
-  // expected-error@+1 {{operand #0 must be tensor of f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements values, but got 'tensor<1x2x2xi32>}}
+  // expected-error@+1 {{operand #0 must be tensor of f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or f8E4M3B11FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements values, but got 'tensor<1x2x2xi32>}}
   %0 = "stablehlo.cholesky"(%arg0) { lower = true } : (tensor<1x2x2xi32>) -> tensor<1x2x2xi32>
   func.return %0: tensor<1x2x2xi32>
 }
@@ -1908,7 +1908,7 @@ func.func @rng_normal_invalid_shape(%arg0: tensor<f32>, %arg1: tensor<f32>) {
 
 func.func @rng_normal_invalid_mu_rank(%mu: tensor<1xf32>, %sigma: tensor<f32>) -> tensor<2x3x5xf32> {
   %shape = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64>
-  // expected-error@+1 {{#0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}}
+  // expected-error@+1 {{#0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or f8E4M3B11FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}}
   %0 = "stablehlo.rng"(%mu, %sigma, %shape) {rng_distribution = #stablehlo<rng_distribution NORMAL>}: (tensor<1xf32>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
   func.return %0 : tensor<2x3x5xf32>
 }
@@ -1917,7 +1917,7 @@ func.func @rng_normal_invalid_sigma_rank(%mu: tensor<f32>, %sigma: tensor<1xf32>) -> tensor<2x3x5xf32> {
   %shape = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64>
-  // expected-error@+1 {{#1 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}}
+  // expected-error@+1 {{#1 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or f8E4M3B11FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}}
   %0 = "stablehlo.rng"(%mu, %sigma, %shape) {rng_distribution = #stablehlo<rng_distribution NORMAL>}: (tensor<f32>, tensor<1xf32>, tensor<3xi64>) -> tensor<2x3x5xf32>
   func.return %0 : tensor<2x3x5xf32>
 }
@@ -1935,7 +1935,7 @@ func.func @rng_normal_invalid_shape_rank(%mu: tensor<f32>, %sigma: tensor<f32>)
 
 func.func @rng_normal_invalid_type(%arg0: tensor<complex<f32>>, %arg1: tensor<f32>) {
   %cst = "stablehlo.constant"() {value = dense<7> : tensor<1xi64>} : () -> tensor<1xi64>
-  // expected-error @+1 {{#0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<complex<f32>>'}}
+  // expected-error @+1 {{#0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or f8E4M3B11FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<complex<f32>>'}}
   %0 = "stablehlo.rng"(%arg0, %arg1, %cst) {rng_distribution = #stablehlo<rng_distribution NORMAL>}: (tensor<complex<f32>>, tensor<f32>, tensor<1xi64>) -> tensor<7xf32>
   func.return
 }
@@ -1977,7 +1977,7 @@ func.func @rng_uniform_invalid_shape(%arg0: tensor<f32>, %arg1: tensor<f32>, %ar
 
 func.func @rng_uniform_invalid_a_rank(%a: tensor<1xf32>, %b: tensor<f32>) -> tensor<2x3x5xf32> {
   %shape = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64>
-  // expected-error@+1 {{operand #0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}}
+  // expected-error@+1 {{operand #0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or f8E4M3B11FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}}
   %0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo<rng_distribution UNIFORM>}: (tensor<1xf32>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
   func.return %0 : tensor<2x3x5xf32>
 }
@@ -1987,7 +1987,7 @@ func.func @rng_uniform_invalid_a_rank(%a: tensor<1xf32>, %b: tensor<f32>) -> ten
 
 func.func @rng_uniform_invalid_b_rank(%a: tensor<f32>, %b: tensor<1xf32>) -> tensor<2x3x5xf32> {
   %shape = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64>
-  // expected-error@+1 {{operand #1 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}}
+  // expected-error@+1 {{operand #1 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or f8E4M3B11FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}}
   %0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo<rng_distribution UNIFORM>}: (tensor<f32>, tensor<1xf32>, tensor<3xi64>) -> tensor<2x3x5xf32>
   func.return %0 : tensor<2x3x5xf32>
 }
@@ -2005,7 +2005,7 @@ func.func @rng_uniform_invalid_shape_rank(%a: tensor<f32>, %b: tensor<f32>) -> t
 
 func.func @rng_uniform_invalid_type(%a: tensor<complex<f32>>, %b: tensor<f32>) -> tensor<2x3x5xf32> {
   %shape = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64>
-  // expected-error@+1 {{operand #0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<complex<f32>>'}}
+  // expected-error@+1 {{operand #0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or f8E4M3B11FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<complex<f32>>'}}
   %0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo<rng_distribution UNIFORM>}: (tensor<complex<f32>>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
   func.return %0 : tensor<2x3x5xf32>
 }
@@ -2611,7 +2611,7 @@ func.func @or_invalid_f32_type(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> te
 // -----
 
 func.func @floor_invalid_i32_type(%arg0: tensor<4xi32>) -> tensor<4xi32> {
-  // expected-error@+1 {{must be tensor of f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<4xi32>'}}
+  // expected-error@+1 {{must be tensor of f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or f8E4M3B11FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<4xi32>'}}
   %0 = "stablehlo.floor"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
   func.return %0 : tensor<4xi32>
 }
@@ -4778,7 +4778,7 @@ func.func @error_batch_norm_grad(%input: tensor<2x2x2x2xf32>, %scale: tensor<4xf
 // -----
 
 func.func @error_batch_norm_grad(%input: tensor<2x2x2x2xi32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> {
-  // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<2x2x2x2xi32>'}}
+  // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or f8E4M3B11FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<2x2x2x2xi32>'}}
   %0:3 = "stablehlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2xi32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>)
   func.return %0#0 : tensor<2x2x2x2xf32>
 }
@@ -4810,7 +4810,7 @@ func.func @error_batch_norm_grad(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf
 // -----
 
 func.func @error_batch_norm_grad(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> {
-  // expected-error@+1 {{result #1 must be 1D tensor of f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<2x2xf32>'}}
+  // expected-error@+1 {{result #1 must be 1D tensor of f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or f8E4M3B11FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<2x2xf32>'}}
   %0:3 = "stablehlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> (tensor<2x2x2x2xf32>, tensor<2x2xf32>, tensor<2xf32>)
   func.return %0#0 : tensor<2x2x2x2xf32>
 }
@@ -4818,7 +4818,7 @@ func.func @error_batch_norm_grad(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf
 // -----
 
 func.func @error_batch_norm_grad(%input: tensor<*xf32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tensor<*xf32> {
-  // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<*xf32>'}}
+  // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or f8E4M3B11FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<*xf32>'}}
   %0:3 = "stablehlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<*xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> (tensor<*xf32>, tensor<2xf32>, tensor<2xf32>)
   func.return %0#0 : tensor<*xf32>
 }
@@ -5575,7 +5575,7 @@ func.func @is_finite(%arg0: tensor<3xf32>) -> tensor<3xi1> {
 // -----
 
 func.func @is_finite_int_input(%arg0: tensor<3xi32>) -> tensor<3xi1> {
-  // expected-error@+1 {{operand #0 must be tensor of f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<3xi32>'}}
+  // expected-error@+1 {{operand #0 must be tensor of f8E4M3FN type or f8E5M2 type or f8E4M3FNUZ type or f8E5M2FNUZ type or f8E4M3B11FNUZ type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<3xi32>'}}
   %0 = "stablehlo.is_finite"(%arg0) {} : (tensor<3xi32>) -> tensor<3xi1>
   func.return %0 : tensor<3xi1>
 }
diff --git a/stablehlo/tests/stablehlo_legalize_to_vhlo.0_11_0.mlir b/stablehlo/tests/stablehlo_legalize_to_vhlo.0_11_0.mlir
new file mode 100644
index 
00000000000..e17406d0ba6 --- /dev/null +++ b/stablehlo/tests/stablehlo_legalize_to_vhlo.0_11_0.mlir @@ -0,0 +1,2261 @@ +// RUN: stablehlo-opt --mlir-print-op-generic %s.bc | FileCheck %s +// RUN: stablehlo-translate --deserialize %s.bc | stablehlo-translate --serialize --target=0.11.0 | stablehlo-opt --mlir-print-op-generic | FileCheck %s +// RUN: diff <(stablehlo-translate --deserialize %s.bc | stablehlo-opt) <(stablehlo-opt %s) + +// ============ ATTRIBUTES ============ + +func.func @attr_comparison_direction_eq(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_comparison_direction_eq" + +func.func @attr_comparison_direction_ne(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_comparison_direction_ne" + +func.func @attr_comparison_direction_ge(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_comparison_direction_ge" + +func.func @attr_comparison_direction_gt(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_comparison_direction_gt" + +func.func @attr_comparison_direction_le(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_comparison_direction_le" + +func.func @attr_comparison_direction_lt(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_comparison_direction_lt" + +func.func @attr_comparison_type_notype(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo + // CHECK: compare_type = #vhlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_comparison_type_notype" + +func.func @attr_comparison_type_float(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + // CHECK: compare_type = #vhlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_comparison_type_float" + +func.func @attr_comparison_type_totalorder(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + // CHECK: compare_type = #vhlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_comparison_type_totalorder" + +func.func @attr_comparison_type_signed(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + // CHECK: compare_type = #vhlo, + 
compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_comparison_type_signed" + +func.func @attr_comparison_type_unsigned(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + // CHECK: compare_type = #vhlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_comparison_type_unsigned" + +// ConvDimensionNumbers aka #stablehlo.conv is covered below. + +func.func @attr_custom_call_api_version_unspecified(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + // CHECK: api_version = #vhlo + api_version = 0 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_custom_call_api_version_unspecified" + +func.func @attr_custom_call_api_version_original(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + // CHECK: api_version = #vhlo + api_version = 1 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_custom_call_api_version_original" + +func.func @attr_custom_call_api_version_status_returning(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + // CHECK: api_version = #vhlo + api_version = 2 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_custom_call_api_version_status_returning" + +func.func @attr_custom_call_api_version_status_returning_unified(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + // CHECK: api_version = #vhlo + api_version = 3 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_custom_call_api_version_status_returning_unified" + +// DotDimensionNumbers aka #stablehlo.dot is covered below. + +func.func @attr_fft_type_fft(%arg0: tensor<16xcomplex>) -> tensor<16xcomplex> { + %0 = "stablehlo.fft"(%arg0) { + // CHECK: fft_type = #vhlo + fft_type = #stablehlo, + fft_length = dense<16> : tensor<1xi64> + } : (tensor<16xcomplex>) -> tensor<16xcomplex> + func.return %0 : tensor<16xcomplex> +} +// CHECK-LABEL: "attr_fft_type_fft" + +func.func @attr_fft_type_ifft(%arg0: tensor<16xcomplex>) -> tensor<16xcomplex> { + %0 = "stablehlo.fft"(%arg0) { + // CHECK: fft_type = #vhlo + fft_type = #stablehlo, + fft_length = dense<16> : tensor<1xi64> + } : (tensor<16xcomplex>) -> tensor<16xcomplex> + func.return %0 : tensor<16xcomplex> +} +// CHECK-LABEL: "attr_fft_type_ifft" + +func.func @attr_fft_type_rfft(%arg0: tensor<16xf32>) -> tensor<9xcomplex> { + %0 = "stablehlo.fft"(%arg0) { + // CHECK: fft_type = #vhlo + fft_type = #stablehlo, + fft_length = dense<16> : tensor<1xi64> + } : (tensor<16xf32>) -> tensor<9xcomplex> + func.return %0 : tensor<9xcomplex> +} +// CHECK-LABEL: "attr_fft_type_rfft" + +func.func @attr_fft_type_irfft(%arg0: tensor<9xcomplex>) -> tensor<16xf32> { + %0 = "stablehlo.fft"(%arg0) { + // CHECK: fft_type = #vhlo + fft_type = #stablehlo, + fft_length = dense<16> : tensor<1xi64> + } : (tensor<9xcomplex>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} +// CHECK-LABEL: "attr_fft_type_irfft" + +// GatherDimensionNumbers aka #stablehlo.gather is covered below. 
+ +func.func @attr_precision_config_default(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + %0 = "stablehlo.dot"(%arg0, %arg1) { + // CHECK: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]> + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} +// CHECK-LABEL: "attr_precision_config_default" + +func.func @attr_precision_config_high(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + %0 = "stablehlo.dot"(%arg0, %arg1) { + // CHECK: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]> + precision_config = [#stablehlo, #stablehlo] + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} +// CHECK-LABEL: "attr_precision_config_high" + +func.func @attr_precision_config_highest(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + %0 = "stablehlo.dot"(%arg0, %arg1) { + // CHECK: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]> + precision_config = [#stablehlo, #stablehlo] + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} +// CHECK-LABEL: "attr_precision_config_highest" + +func.func @attr_rng_algorithm_default(%arg0: tensor) -> (tensor, tensor) { + %0:2 = "stablehlo.rng_bit_generator"(%arg0) { + // CHECK: rng_algorithm = #vhlo + rng_algorithm = #stablehlo + } : (tensor) -> (tensor, tensor) + func.return %0#0, %0#1 : tensor, tensor +} +// CHECK-LABEL: "attr_rng_algorithm_default" + +func.func @attr_rng_algorithm_three_fry(%arg0: tensor) -> (tensor, tensor) { + %0:2 = "stablehlo.rng_bit_generator"(%arg0) { + // CHECK: rng_algorithm = #vhlo + rng_algorithm = #stablehlo + } : (tensor) -> (tensor, tensor) + func.return %0#0, %0#1 : tensor, tensor +} +// CHECK-LABEL: "attr_rng_algorithm_three_fry" + +func.func @attr_rng_algorithm_philox(%arg0: tensor) -> (tensor, tensor) { + %0:2 = "stablehlo.rng_bit_generator"(%arg0) { + // CHECK: rng_algorithm = #vhlo + rng_algorithm = #stablehlo + } : (tensor) -> (tensor, tensor) + func.return %0#0, %0#1 : tensor, tensor +} +// CHECK-LABEL: "attr_rng_algorithm_philox" + +func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { + // CHECK: rng_distribution = #vhlo + rng_distribution = #stablehlo + } : (tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_rng_distribution_uniform" + +func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { + // CHECK: rng_distribution = #vhlo + rng_distribution = #stablehlo + } : (tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "attr_rng_distribution_normal" + +// ScatterDimensionNumbers aka #stablehlo.scatter is covered below. 
+ +func.func @attr_transpose_no_transpose(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { + %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { + left_side = true, + lower = true, + unit_diagonal = true, + // transpose_a = #vhlo, + transpose_a = #stablehlo + } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} +// CHECK-LABEL: "attr_transpose_no_transpose" + +func.func @attr_transpose_transpose(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { + %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { + left_side = true, + lower = true, + unit_diagonal = true, + // transpose_a = #vhlo, + transpose_a = #stablehlo + } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} +// CHECK-LABEL: "attr_transpose_transpose" + +func.func @attr_transpose_adjoint(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { + %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { + left_side = true, + lower = true, + unit_diagonal = true, + // transpose_a = #vhlo, + transpose_a = #stablehlo + } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} +// CHECK-LABEL: "attr_transpose_adjoint" + +// TypeExtensionsAttr aka #stablehlo.type_extensions is covered below. + +func.func @attr_type_extensions_bounds( + %arg0: tensor>) + -> tensor> { + // CHECK: "vhlo.return_v1"(%arg0) : (!vhlo.tensor_v1>) -> () + func.return %arg0 : tensor> +} + +// CHECK-LABEL: "attr_type_extensions_bounds" + +// ============ DEFAULTS ============ + +func.func @default_all_gather(%arg0: tensor<16x8xf32>) -> tensor<16x16xf32> { + // CHECK: "vhlo.all_gather_v1"(%arg0) { + // CHECK-SAME: all_gather_dim = #vhlo.integer_v1<1 : i64> + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: } : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> + %0 = "stablehlo.all_gather"(%arg0) { + all_gather_dim = 1 : i64, + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + } : (tensor<16x8xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} +// CHECK-LABEL: "default_all_gather" + +func.func @default_all_reduce(%arg0: tensor) -> tensor { + // CHECK: "vhlo.all_reduce_v1"(%arg0) ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) { + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: } : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.all_reduce"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + } : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "default_all_reduce" + +func.func @default_all_to_all(%arg0: tensor<4x16xf32>) -> tensor<16x4xf32> { + // CHECK: "vhlo.all_to_all_v1"(%arg0) { + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: 
concat_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<1x4xi64>>, + // CHECK-SAME: split_count = #vhlo.integer_v1<4 : i64> + // CHECK-SAME: split_dimension = #vhlo.integer_v1<1 : i64> + // CHECK-SAME: } : (!vhlo.tensor_v1<4x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x4x!vhlo.f32_v1> + %0 = "stablehlo.all_to_all"(%arg0) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 4 : i64, + replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> + } : (tensor<4x16xf32>) -> tensor<16x4xf32> + func.return %0 : tensor<16x4xf32> +} +// CHECK-LABEL: "default_all_to_all" + +func.func @default_cholesky(%arg0: tensor<1x16x16xf32>) -> tensor<1x16x16xf32> { + // CHECK: "vhlo.cholesky_v1"(%arg0) { + // CHECK-SAME: lower = #vhlo.bool_v1 + // CHECK-SAME: } : (!vhlo.tensor_v1<1x16x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x16x16x!vhlo.f32_v1> + %0 = "stablehlo.cholesky"(%arg0) : (tensor<1x16x16xf32>) -> tensor<1x16x16xf32> + func.return %0 : tensor<1x16x16xf32> +} +// CHECK-LABEL: "default_cholesky" + +func.func @default_collective_permute(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { + // CHECK: "vhlo.collective_permute_v1"(%arg0) { + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: source_target_pairs = #vhlo.tensor_v1 : tensor<3x2xi64>> + // CHECK-SAME: } : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x8x!vhlo.f32_v1> + %0 = "stablehlo.collective_permute"(%arg0) { + source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64> + } : (tensor<16x8xf32>) -> tensor<16x8xf32> + func.return %0 : tensor<16x8xf32> +} +// CHECK-LABEL: "default_collective_permute" + +func.func @default_compare(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.compare_v1"(%arg0, %arg1) { + // CHECK-SAME: compare_type = #vhlo, + // CHECK-SAME: comparison_direction = #vhlo + // CHECK-SAME: } : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "default_compare" + +func.func @default_convolution(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32> { + // CHECK: "vhlo.convolution_v1"(%arg0, %arg1) { + // CHECK-SAME: batch_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: feature_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: input_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: input_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: input_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: kernel_input_feature_dimension = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: kernel_output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: kernel_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: lhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: output_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: output_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<2x2xi64>>, + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: window_reversal = #vhlo.tensor_v1 : tensor<2xi1>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: } : 
(!vhlo.tensor_v1<1x8x8x207x!vhlo.f32_v1>, !vhlo.tensor_v1<3x3x207x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x6x6x16x!vhlo.f32_v1> + %0 = "stablehlo.convolution"(%arg0, %arg1) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64 + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32> + func.return %0 : tensor<1x6x6x16xf32> +} +// CHECK-LABEL: "default_convolution" + +func.func @default_custom_call(%arg0: tensor) -> tensor { + // CHECK: "vhlo.custom_call_v1"(%arg0) { + // CHECK-SAME: api_version = #vhlo, + // CHECK-SAME: backend_config = #vhlo.string_v1<"">, + // CHECK-SAME: call_target_name = #vhlo.string_v1<"foo">, + // CHECK-SAME: called_computations = #vhlo.array_v1<[]>, + // CHECK-SAME: has_side_effect = #vhlo.bool_v1, + // CHECK-SAME: operand_layouts = #vhlo.array_v1<[]>, + // CHECK-SAME: output_operand_aliases = #vhlo.array_v1<[]> + // CHECK-SAME: result_layouts = #vhlo.array_v1<[]> + // CHECK-SAME: } : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo" + } : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "default_custom_call" + +func.func @default_dot_general(%arg0: tensor<8x8x16xf32>, %arg1: tensor<8x16x8xf32>) -> tensor<8x8x8xf32> { + // CHECK: "vhlo.dot_general_v1"(%arg0, %arg1) { + // CHECK-SAME: lhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: lhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: rhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: } : (!vhlo.tensor_v1<8x8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<8x16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x8x!vhlo.f32_v1> + %0 = "stablehlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [0], + lhs_contracting_dimensions = [2], + rhs_batching_dimensions = [0], + rhs_contracting_dimensions = [1] + > + } : (tensor<8x8x16xf32>, tensor<8x16x8xf32>) -> tensor<8x8x8xf32> + func.return %0 : tensor<8x8x8xf32> +} +// CHECK-LABEL: "default_dot_general" + +func.func @default_dot(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + // CHECK: "vhlo.dot_v1"(%arg0, %arg1) { + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]> + // CHECK-SAME: } : (!vhlo.tensor_v1<8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x!vhlo.f32_v1> + %0 = "stablehlo.dot"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} +// CHECK-LABEL: "default_dot" + +func.func @default_dynamic_broadcast_in_dim(%arg0: tensor, %arg1: tensor<2xindex>) -> tensor { + // CHECK: "vhlo.dynamic_broadcast_in_dim_v1"(%arg0, %arg1) { + // CHECK-SAME: broadcast_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: known_expanding_dimensions = #vhlo.tensor_v1 : tensor<0xi64>>, + // CHECK-SAME: known_nonexpanding_dimensions = #vhlo.tensor_v1 : tensor<0xi64>> + // CHECK-SAME: } : (!vhlo.tensor_v1, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { + broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> + } : (tensor, tensor<2xindex>) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "default_dynamic_broadcast_in_dim" + +func.func 
@default_dynamic_conv(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>, %arg2: tensor<4xi32>) -> tensor<1x?x?x16xf32> { + // CHECK: "vhlo.dynamic_conv_v1"(%arg0, %arg1, %arg2) { + // CHECK-SAME: batch_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: feature_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: input_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: input_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: input_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: kernel_input_feature_dimension = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: kernel_output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: kernel_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: lhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: output_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: output_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<2x2xi64>>, + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: window_reversal = #vhlo.tensor_v1 : tensor<2xi1>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: } : (!vhlo.tensor_v1<1x8x8x207x!vhlo.f32_v1>, !vhlo.tensor_v1<3x3x207x16x!vhlo.f32_v1>, !vhlo.tensor_v1<4x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x?x?x16x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_conv"(%arg0, %arg1, %arg2) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64 + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<4xi32>) -> tensor<1x?x?x16xf32> + func.return %0 : tensor<1x?x?x16xf32> +} +// CHECK-LABEL: "default_dynamic_conv" + +func.func @default_dynamic_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>, %arg2 : tensor<3xi32>) -> tensor<1x5x8xf32> { + // CHECK: "vhlo.dynamic_gather_v1"(%arg0, %arg1, %arg2) { + // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: } : (!vhlo.tensor_v1<2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>, !vhlo.tensor_v1<3x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x8x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %arg2) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2], + collapsed_slice_dims = [0, 1], + start_index_map = [0, 1], + index_vector_dim = 2 + > + } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>, tensor<3xi32>) -> tensor<1x5x8xf32> + func.return %0 : tensor<1x5x8xf32> +} +// CHECK-LABEL: "default_dynamic_gather" + +func.func @default_func(%arg0: tensor) -> tensor { + // CHECK: "vhlo.func_v1"() ({ + // CHECK-NEXT: ^[[BB:bb.*]](%arg0: !vhlo.tensor_v1): + // CHECK-NEXT: "vhlo.return_v1"(%arg0) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) { + // CHECK-SAME: arg_attrs = #vhlo.array_v1<[]>, + // CHECK-SAME: function_type = #vhlo.type_v1) -> !vhlo.tensor_v1>>, + // CHECK-SAME: res_attrs = #vhlo.array_v1<[]>, + // CHECK-SAME: sym_name = #vhlo.string_v1<"default_func">, + // CHECK-SAME: sym_visibility = #vhlo.string_v1<""> + // CHECK-SAME: } : () -> () + func.return 
%arg0 : tensor +} + +func.func @dynamic_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>) -> tensor<1x5x1xf32> { + // CHECK: "vhlo.gather_v1"(%arg0, %arg1) { + // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: slice_sizes = #vhlo.tensor_v1 : tensor<3xi64>>, + // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: } : (!vhlo.tensor_v1<2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x1x!vhlo.f32_v1> + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2], + collapsed_slice_dims = [0, 1], + start_index_map = [0, 1], + index_vector_dim = 2 + >, + slice_sizes = dense<1> : tensor<3xi64> + } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>) -> tensor<1x5x1xf32> + func.return %0 : tensor<1x5x1xf32> +} +// CHECK-LABEL: "dynamic_gather" + +func.func @default_infeed(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { + // CHECK: "vhlo.infeed_v1"(%arg0) { + // CHECK-SAME: infeed_config = #vhlo.string_v1<"">, + // CHECK-SAME{LITERAL}: layout = #vhlo.array_v1<[]> + // CHECK-SAME: } : (!vhlo.token_v1) -> (!vhlo.tensor_v1, !vhlo.token_v1) + %0:2 = "stablehlo.infeed"(%arg0) : (!stablehlo.token) -> (tensor, !stablehlo.token) + func.return %0#0, %0#1 : tensor, !stablehlo.token +} +// CHECK-LABEL: "default_infeed" + +func.func @default_outfeed(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.outfeed_v1"(%arg0, %arg1) { + // CHECK-SAME: outfeed_config = #vhlo.string_v1<""> + // CHECK-SAME: } : (!vhlo.tensor_v1, !vhlo.token_v1) -> !vhlo.token_v1 + %0 = "stablehlo.outfeed"(%arg0, %arg1) : (tensor, !stablehlo.token) -> !stablehlo.token + func.return %0 : !stablehlo.token +} +// CHECK-LABEL: "default_outfeed" + +func.func @default_recv(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { + // CHECK: "vhlo.recv_v1"(%arg0) { + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: channel_type = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: is_host_transfer = #vhlo.bool_v1 + // CHECK-SAME: } : (!vhlo.token_v1) -> (!vhlo.tensor_v1, !vhlo.token_v1) + %0:2 = "stablehlo.recv"(%arg0) { + channel_handle = #stablehlo.channel_handle + } : (!stablehlo.token) -> (tensor, !stablehlo.token) + func.return %0#0, %0#1 : tensor, !stablehlo.token +} +// CHECK-LABEL: "default_recv" + +func.func @default_send(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.send_v1"(%arg0, %arg1) { + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: channel_type = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: is_host_transfer = #vhlo.bool_v1 + // CHECK-SAME: } : (!vhlo.tensor_v1, !vhlo.token_v1) -> !vhlo.token_v1 + %0 = "stablehlo.send"(%arg0, %arg1) { + channel_handle = #stablehlo.channel_handle + } : (tensor, !stablehlo.token) -> !stablehlo.token + func.return %0 : !stablehlo.token +} +// CHECK-LABEL: "default_send" + +func.func @default_reduce_scatter(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.reduce_scatter_v1"(%arg0) ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) 
-> () + // CHECK-NEXT: }) { + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: scatter_dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: } : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.reduce_scatter"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + scatter_dimension = 0 : i64, + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + } : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} +// CHECK-LABEL: "default_reduce_scatter" + +func.func @default_reduce_window(%arg0: tensor<2x17x31x7xf32>, %arg1: tensor) -> tensor<2x16x30x7xf32> { + // CHECK: "vhlo.reduce_window_v1"(%arg0, %arg1) ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG2:arg.*]]: !vhlo.tensor_v1, %[[ARG3:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.maximum_v1"(%[[ARG2]], %[[ARG3]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) { + // CHECK-SAME: base_dilations = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME{LITERAL}: padding = #vhlo.tensor_v1 : tensor<4x2xi64>>, + // CHECK-SAME: window_dilations = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_dimensions = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<4xi64>> + // CHECK-SAME: } : (!vhlo.tensor_v1<2x17x31x7x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<2x16x30x7x!vhlo.f32_v1> + %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = "stablehlo.maximum"(%arg2, %arg3) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64> + } : (tensor<2x17x31x7xf32>, tensor) -> tensor<2x16x30x7xf32> + func.return %0 : tensor<2x16x30x7xf32> +} +// CHECK-LABEL: "default_reduce_window" + +func.func @default_scatter(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>, %arg2: tensor<10x300xf32>) -> tensor<200x100x300xf32> { + // CHECK: "vhlo.scatter_v1"(%arg0, %arg1, %arg2) ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG3:arg.*]]: !vhlo.tensor_v1, %[[ARG4:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG3]], %[[ARG4]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) { + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: inserted_window_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: scatter_dims_to_operand_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: unique_indices = #vhlo.bool_v1, + // CHECK-SAME: update_window_dims = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: } : (!vhlo.tensor_v1<200x100x300x!vhlo.f32_v1>, !vhlo.tensor_v1<10x2x!vhlo.i32_v1>, !vhlo.tensor_v1<10x300x!vhlo.f32_v1>) -> !vhlo.tensor_v1<200x100x300x!vhlo.f32_v1> + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + 
scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 1 + > + } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<200x100x300xf32> + func.return %0 : tensor<200x100x300xf32> +} +// CHECK-LABEL: "default_scatter" + +func.func @default_select_and_scatter(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x23x23x64xf32>, %arg2: tensor) -> tensor<10x24x24x64xf32> { + // CHECK: "vhlo.select_and_scatter_v1"(%arg0, %arg1, %arg2) ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG31:arg.*]]: !vhlo.tensor_v1, %[[ARG41:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL11:.*]] = "vhlo.compare_v1"(%[[ARG31]], %[[ARG41]]) {compare_type = #vhlo, comparison_direction = #vhlo} : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL11]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }, { + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG32:arg.*]]: !vhlo.tensor_v1, %[[ARG42:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL12:.*]] = "vhlo.add_v1"(%[[ARG32]], %[[ARG42]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL12]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) { + // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<4x2xi64>>, + // CHECK-SAME: window_dimensions = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<4xi64>> + // CHECK-SAME: } : (!vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1>, !vhlo.tensor_v1<10x23x23x64x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1> + %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.compare"(%arg3, %arg4) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64> + } : (tensor<10x24x24x64xf32>, tensor<10x23x23x64xf32>, tensor) -> tensor<10x24x24x64xf32> + func.return %0 : tensor<10x24x24x64xf32> +} +// CHECK-LABEL: "default_select_and_scatter" + +func.func @default_sort(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.sort_v1"(%arg0) ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.compare_v1"(%[[ARG1]], %[[ARG2]]) {compare_type = #vhlo, comparison_direction = #vhlo} : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) { + // CHECK-SAME: dimension = #vhlo.integer_v1<-1 : i64> + // CHECK-SAME: is_stable = #vhlo.bool_v1 + // CHECK-SAME: } : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.sort"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.compare"(%arg1, %arg2) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} +// CHECK-LABEL: "default_sort" + +// ============ OPS ============ + +func.func @op_abs(%arg0: tensor) -> tensor { + // CHECK: "vhlo.abs_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.abs"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_abs" + +func.func @op_add(%arg0: tensor, %arg1: tensor) -> 
+  // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_add"
+
+func.func @op_after_all(%arg0: !stablehlo.token) -> !stablehlo.token {
+  // CHECK: "vhlo.after_all_v1"(%arg0) : (!vhlo.token_v1) -> !vhlo.token_v1
+  %0 = "stablehlo.after_all"(%arg0) : (!stablehlo.token) -> !stablehlo.token
+  func.return %0 : !stablehlo.token
+}
+// CHECK-LABEL: "op_after_all"
+
+func.func @op_all_gather(%arg0: tensor<16x8xf32>) -> tensor<16x16xf32> {
+  // CHECK: "vhlo.all_gather_v1"(%arg0) {
+  // CHECK-SAME: all_gather_dim = #vhlo.integer_v1<1 : i64>
+  // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>,
+  // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1<dense<[[0], [1]]> : tensor<2x1xi64>>,
+  // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1<true>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1>
+  %0 = "stablehlo.all_gather"(%arg0) {
+    all_gather_dim = 1 : i64,
+    replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>,
+    channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>,
+    use_global_device_ids
+  } : (tensor<16x8xf32>) -> tensor<16x16xf32>
+  func.return %0 : tensor<16x16xf32>
+}
+// CHECK-LABEL: "op_all_gather"
+
+func.func @op_all_reduce(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.all_reduce_v1"(%arg0) ({
+  // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1<!vhlo.f32_v1>, %[[ARG2:arg.*]]: !vhlo.tensor_v1<!vhlo.f32_v1>):
+  // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> ()
+  // CHECK-NEXT: }) {
+  // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>,
+  // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1<dense<[[0], [1]]> : tensor<2x1xi64>>,
+  // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1<true>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.all_reduce"(%arg0) ({
+    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+      %1 = "stablehlo.add"(%arg1, %arg2) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+      "stablehlo.return"(%1) : (tensor<f32>) -> ()
+  }) {
+    replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>,
+    channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>,
+    use_global_device_ids
+  } : (tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_all_reduce"
+
+func.func @op_all_to_all(%arg0: tensor<4x16xf32>) -> tensor<16x4xf32> {
+  // CHECK: "vhlo.all_to_all_v1"(%arg0) {
+  // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>,
+  // CHECK-SAME: concat_dimension = #vhlo.integer_v1<0 : i64>,
+  // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1<dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>>,
+  // CHECK-SAME: split_count = #vhlo.integer_v1<4 : i64>
+  // CHECK-SAME: split_dimension = #vhlo.integer_v1<1 : i64>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<4x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x4x!vhlo.f32_v1>
+  %0 = "stablehlo.all_to_all"(%arg0) {
+    split_dimension = 1 : i64,
+    concat_dimension = 0 : i64,
+    split_count = 4 : i64,
+    replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
+    channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>
+  } : (tensor<4x16xf32>) -> tensor<16x4xf32>
+  func.return %0 : tensor<16x4xf32>
+}
+// CHECK-LABEL: "op_all_to_all"
+
+func.func @op_and(%arg0: tensor<i1>, %arg1: tensor<i1>) -> tensor<i1> {
+  // CHECK: "vhlo.and_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.i1_v1>, !vhlo.tensor_v1<!vhlo.i1_v1>) -> !vhlo.tensor_v1<!vhlo.i1_v1>
+  %0 = "stablehlo.and"(%arg0, %arg1) : (tensor<i1>, tensor<i1>) -> tensor<i1>
+  func.return %0 : tensor<i1>
+}
+// CHECK-LABEL: "op_and"
+
+func.func @op_atan2(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.atan2_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.atan2"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_atan2"
+
+func.func @op_batch_norm_grad(%arg0: tensor<16x16x16x16xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: tensor<16x16x16x16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) {
+  // CHECK: "vhlo.batch_norm_grad_v1"(%arg0, %arg1, %arg2, %arg3, %arg4) {
+  // CHECK-SAME: epsilon = #vhlo.float_v1<1.000000e-03 : !vhlo.f32_v1>,
+  // CHECK-SAME: feature_index = #vhlo.integer_v1<0 : i64>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>) -> (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>)
+  %0:3 = "stablehlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4) {
+    epsilon = 0.001 : f32,
+    feature_index = 0 : i64
+  } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16x16x16x16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>)
+  func.return %0#0, %0#1, %0#2 : tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>
+}
+// CHECK-LABEL: "op_batch_norm_grad"
+
+func.func @op_batch_norm_inference(%arg0: tensor<16x16x16x16xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: tensor<16xf32>) -> tensor<16x16x16x16xf32> {
+  // CHECK: "vhlo.batch_norm_inference_v1"(%arg0, %arg1, %arg2, %arg3, %arg4) {
+  // CHECK-SAME: epsilon = #vhlo.float_v1<1.000000e-03 : !vhlo.f32_v1>,
+  // CHECK-SAME: feature_index = #vhlo.integer_v1<0 : i64>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>
+  %0 = "stablehlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) {
+    epsilon = 0.001 : f32,
+    feature_index = 0 : i64
+  } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<16x16x16x16xf32>
+  func.return %0 : tensor<16x16x16x16xf32>
+}
+// CHECK-LABEL: "op_batch_norm_inference"
+
+func.func @op_batch_norm_training(%arg0: tensor<16x16x16x16xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) {
+  // CHECK: "vhlo.batch_norm_training_v1"(%arg0, %arg1, %arg2) {
+  // CHECK-SAME: epsilon = #vhlo.float_v1<1.000000e-03 : !vhlo.f32_v1>,
+  // CHECK-SAME: feature_index = #vhlo.integer_v1<0 : i64>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>) -> (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>)
+  %0:3 = "stablehlo.batch_norm_training"(%arg0, %arg1, %arg2) {
+    epsilon = 0.001 : f32,
+    feature_index = 0 : i64
+  } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>)
+  func.return %0#0, %0#1, %0#2 : tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>
+}
+// CHECK-LABEL: "op_batch_norm_training"
+
+func.func @op_bitcast_convert(%arg0: tensor<i32>) -> tensor<f32> {
+  // CHECK: "vhlo.bitcast_convert_v1"(%arg0) : (!vhlo.tensor_v1<!vhlo.i32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.bitcast_convert"(%arg0) : (tensor<i32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_bitcast_convert"
+
+func.func @op_broadcast_in_dim(%arg0: tensor<16xf32>) -> tensor<16x16xf32> {
+  // CHECK: "vhlo.broadcast_in_dim_v1"(%arg0) {
+  // CHECK-SAME: broadcast_dimensions = #vhlo.tensor_v1<dense<1> : tensor<1xi64>>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1>
+  %0 = "stablehlo.broadcast_in_dim"(%arg0) {
+    broadcast_dimensions = dense<1> : tensor<1xi64>
+  } : (tensor<16xf32>) -> tensor<16x16xf32>
+  func.return %0 : tensor<16x16xf32>
+}
+// CHECK-LABEL: "op_broadcast_in_dim"
+
+func.func @op_broadcast(%arg0: tensor<16xf32>) -> tensor<16x16xf32> {
+  // CHECK: "vhlo.broadcast_v1"(%arg0) {
+  // CHECK-SAME: broadcast_sizes = #vhlo.tensor_v1<dense<16> : tensor<1xi64>>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1>
+  %0 = "stablehlo.broadcast"(%arg0) {
+    broadcast_sizes = dense<16> : tensor<1xi64>
+  } : (tensor<16xf32>) -> tensor<16x16xf32>
+  func.return %0 : tensor<16x16xf32>
+}
+// CHECK-LABEL: "op_broadcast"
+
+func.func @op_case(%arg0: tensor<i32>, %arg1: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.case_v1"(%arg0) ({
+  // CHECK-NEXT: "vhlo.return_v1"(%arg1) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> ()
+  // CHECK-NEXT: }) : (!vhlo.tensor_v1<!vhlo.i32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.case"(%arg0) ({
+    "stablehlo.return"(%arg1) : (tensor<f32>) -> ()
+  }) : (tensor<i32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_case"
+
+func.func @op_cbrt(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.cbrt_v1"(%arg0) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.cbrt"(%arg0) : (tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_cbrt"
+
+func.func @op_ceil(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.ceil_v1"(%arg0) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.ceil"(%arg0) : (tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_ceil"
+
+func.func @op_cholesky(%arg0: tensor<1x16x16xf32>) -> tensor<1x16x16xf32> {
+  // CHECK: "vhlo.cholesky_v1"(%arg0) {
+  // CHECK-SAME: lower = #vhlo.bool_v1<true>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<1x16x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x16x16x!vhlo.f32_v1>
+  %0 = "stablehlo.cholesky"(%arg0) {
+    lower = true
+  } : (tensor<1x16x16xf32>) -> tensor<1x16x16xf32>
+  func.return %0 : tensor<1x16x16xf32>
+}
+// CHECK-LABEL: "op_cholesky"
+
+func.func @op_clamp(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.clamp_v1"(%arg0, %arg1, %arg2) : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.clamp"(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_clamp"
+
+func.func @op_count_leading_zeros(%arg0: tensor<i32>) -> tensor<i32> {
+  // CHECK: "vhlo.count_leading_zeros_v1"(%arg0) : (!vhlo.tensor_v1<!vhlo.i32_v1>) -> !vhlo.tensor_v1<!vhlo.i32_v1>
+  %0 = "stablehlo.count_leading_zeros"(%arg0) : (tensor<i32>) -> tensor<i32>
+  func.return %0 : tensor<i32>
+}
+// CHECK-LABEL: "op_count_leading_zeros"
+
+func.func @op_collective_permute(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> {
+  // CHECK: "vhlo.collective_permute_v1"(%arg0) {
+  // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>,
+  // CHECK-SAME{LITERAL}: source_target_pairs = #vhlo.tensor_v1<dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x8x!vhlo.f32_v1>
+  %0 = "stablehlo.collective_permute"(%arg0) {
+    source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>,
+    channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>
+  } : (tensor<16x8xf32>) -> tensor<16x8xf32>
+  func.return %0 : tensor<16x8xf32>
+}
+// CHECK-LABEL: "op_collective_permute"
+
+func.func @op_compare(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<i1> {
+  // CHECK: "vhlo.compare_v1"(%arg0, %arg1) {
+  // CHECK-SAME: compare_type = #vhlo<comparison_type FLOAT>,
+  // CHECK-SAME: comparison_direction = #vhlo<comparison_direction EQ>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.i1_v1>
+  %0 = "stablehlo.compare"(%arg0, %arg1) {
+    comparison_direction = #stablehlo<comparison_direction EQ>,
+    compare_type = #stablehlo<comparison_type FLOAT>
+  } : (tensor<f32>, tensor<f32>) -> tensor<i1>
+  func.return %0 : tensor<i1>
+}
+// CHECK-LABEL: "op_compare"
+
+func.func @op_complex(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<complex<f32>> {
+  // CHECK: "vhlo.complex_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.complex_v1<!vhlo.f32_v1>>
+  %0 = "stablehlo.complex"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<complex<f32>>
+  func.return %0 : tensor<complex<f32>>
+}
+// CHECK-LABEL: "op_complex"
+
+func.func @op_compute_reshape_shape(%arg0: index, %arg1: tensor<1xindex>) -> tensor<1xindex> {
+  // CHECK: "vhlo.compute_reshape_shape_v1"(%arg0, %arg1) : (!vhlo.index_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1<1x!vhlo.index_v1>
+  %0 = "stablehlo.compute_reshape_shape"(%arg0, %arg1) : (index, tensor<1xindex>) -> tensor<1xindex>
+  func.return %0 : tensor<1xindex>
+}
+// CHECK-LABEL: "op_compute_reshape_shape"
+
+func.func @op_concatenate(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<16xf32> {
+  // CHECK: "vhlo.concatenate_v1"(%arg0, %arg1) {
+  // CHECK-SAME: dimension = #vhlo.integer_v1<0 : i64>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<8x!vhlo.f32_v1>, !vhlo.tensor_v1<8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1>
+  %0 = "stablehlo.concatenate"(%arg0, %arg1) {
+    dimension = 0 : i64
+  } : (tensor<8xf32>, tensor<8xf32>) -> tensor<16xf32>
+  func.return %0 : tensor<16xf32>
+}
+// CHECK-LABEL: "op_concatenate"
+
+func.func @op_constant(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.constant_v1"() {
+  // CHECK-SAME: value = #vhlo.tensor_v1<dense<0.000000e+00> : tensor<f32>>
+  // CHECK-SAME: } : () -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.constant"() {
+    value = dense<0.0> : tensor<f32>
+  } : () -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_constant"
+
+func.func @op_convert(%arg0: tensor<i32>) -> tensor<f32> {
+  // CHECK: "vhlo.convert_v1"(%arg0) : (!vhlo.tensor_v1<!vhlo.i32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.convert"(%arg0) : (tensor<i32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_convert"
+
+func.func @op_convolution(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x7x7x16xf32> {
+  // CHECK: "vhlo.convolution_v1"(%arg0, %arg1) {
+  // CHECK-SAME: batch_group_count = #vhlo.integer_v1<1 : i64>,
+  // CHECK-SAME: feature_group_count = #vhlo.integer_v1<1 : i64>,
+  // CHECK-SAME: input_batch_dimension = #vhlo.integer_v1<0 : i64>,
+  // CHECK-SAME: input_feature_dimension = #vhlo.integer_v1<3 : i64>,
+  // CHECK-SAME: input_spatial_dimensions = #vhlo.tensor_v1<dense<[1, 2]> : tensor<2xi64>>,
+  // CHECK-SAME: kernel_input_feature_dimension = #vhlo.integer_v1<2 : i64>,
+  // CHECK-SAME: kernel_output_feature_dimension = #vhlo.integer_v1<3 : i64>,
+  // CHECK-SAME: kernel_spatial_dimensions = #vhlo.tensor_v1<dense<[0, 1]> : tensor<2xi64>>,
+  // CHECK-SAME: lhs_dilation = #vhlo.tensor_v1<dense<2> : tensor<2xi64>>,
+  // CHECK-SAME: output_batch_dimension = #vhlo.integer_v1<0 : i64>,
+  // CHECK-SAME: output_feature_dimension = #vhlo.integer_v1<3 : i64>,
+  // CHECK-SAME: output_spatial_dimensions = #vhlo.tensor_v1<dense<[1, 2]> : tensor<2xi64>>,
+  // CHECK-SAME: padding = #vhlo.tensor_v1<dense<1> : tensor<2x2xi64>>,
+  // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo<precision DEFAULT>, #vhlo<precision DEFAULT>]>,
+  // CHECK-SAME: rhs_dilation = #vhlo.tensor_v1<dense<2> : tensor<2xi64>>,
+  // CHECK-SAME: window_reversal = #vhlo.tensor_v1<dense<false> : tensor<2xi1>>,
+  // CHECK-SAME: window_strides = #vhlo.tensor_v1<dense<2> : tensor<2xi64>>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<1x8x8x207x!vhlo.f32_v1>, !vhlo.tensor_v1<3x3x207x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x7x7x16x!vhlo.f32_v1>
+  %0 = "stablehlo.convolution"(%arg0, %arg1) {
+    window_strides = dense<2> : tensor<2xi64>,
+    padding = dense<1> : tensor<2x2xi64>,
+    lhs_dilation = dense<2> : tensor<2xi64>,
+    rhs_dilation = dense<2> : tensor<2xi64>,
+    window_reversal = dense<false> : tensor<2xi1>,
+    dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
+    feature_group_count = 1 : i64,
+    batch_group_count = 1 : i64,
+    precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
+  } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x7x7x16xf32>
+  func.return %0 : tensor<1x7x7x16xf32>
+}
+// CHECK-LABEL: "op_convolution"
+
+func.func @op_cosine(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.cosine_v1"(%arg0) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.cosine"(%arg0) : (tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_cosine"
+
+func.func @op_create_token() -> !stablehlo.token {
+  // CHECK: "vhlo.create_token_v1"() : () -> !vhlo.token_v1
+  %0 = "stablehlo.create_token"() : () -> !stablehlo.token
+  func.return %0 : !stablehlo.token
+}
+// CHECK-LABEL: "op_create_token"
+
+func.func @op_cross_replica_sum(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.cross-replica-sum_v1"(%arg0) {
+  // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1<dense<[[0], [1]]> : tensor<2x1xi64>>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.cross-replica-sum"(%arg0) {
+    replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>
+  } : (tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_cross_replica_sum"
+
+func.func @op_cstr_reshapable(%arg0: index, %arg1: tensor<1xindex>) -> !shape.witness {
+  // CHECK: "vhlo.cstr_reshapable_v1"(%arg0, %arg1) : (!vhlo.index_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.witness_v1
+  %0 = "stablehlo.cstr_reshapable"(%arg0, %arg1) : (index, tensor<1xindex>) -> !shape.witness
+  func.return %0 : !shape.witness
+}
+// CHECK-LABEL: "op_cstr_reshapable"
+
+func.func @op_custom_call(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.custom_call_v1"(%arg0) {
+  // CHECK-SAME: api_version = #vhlo<api_version API_VERSION_STATUS_RETURNING>,
+  // CHECK-SAME: backend_config = #vhlo.string_v1<"\08\03\1A\02">,
+  // CHECK-SAME: call_target_name = #vhlo.string_v1<"foo">,
+  // CHECK-SAME: called_computations = #vhlo.array_v1<[#vhlo.string_v1<"foo">]>,
+  // CHECK-SAME: has_side_effect = #vhlo.bool_v1<true>,
+  // CHECK-SAME: operand_layouts = #vhlo.array_v1<[#vhlo.tensor_v1<dense<> : tensor<0xindex>>]>,
+  // CHECK-SAME: output_operand_aliases = #vhlo.array_v1<[
+  // CHECK-SAME: #vhlo.output_operand_alias_v1<
+  // CHECK-SAME: outputTupleIndices = [],
+  // CHECK-SAME: operandIndex = 0,
+  // CHECK-SAME: operandTupleIndices = []>]>
+  // CHECK-SAME: result_layouts = #vhlo.array_v1<[#vhlo.tensor_v1<dense<> : tensor<0xindex>>]>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.custom_call"(%arg0) {
+    call_target_name = "foo",
+    has_side_effect = true,
+    backend_config = "\08\03\1A\02",
+    api_version = 2 : i32,
+    called_computations = [@foo],
+    operand_layouts = [dense<> : tensor<0xindex>],
+    output_operand_aliases = [
+      #stablehlo.output_operand_alias<output_tuple_indices = [],
+                                      operand_index = 0,
+                                      operand_tuple_indices = []>],
+    result_layouts = [dense<> : tensor<0xindex>]
+  } : (tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_custom_call"
+
+func.func @op_divide(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.divide_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.divide"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_divide"
+
+func.func @op_dot_general(%arg0: tensor<8x8x16xf32>, %arg1: tensor<8x16x8xf32>) -> tensor<8x8x8xf32> {
+  // CHECK: "vhlo.dot_general_v1"(%arg0, %arg1) {
+  // CHECK-SAME: lhs_batching_dimensions = #vhlo.tensor_v1<dense<0> : tensor<1xi64>>,
+  // CHECK-SAME: lhs_contracting_dimensions = #vhlo.tensor_v1<dense<2> : tensor<1xi64>>,
+  // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo<precision DEFAULT>, #vhlo<precision DEFAULT>]>,
+  // CHECK-SAME: rhs_batching_dimensions = #vhlo.tensor_v1<dense<0> : tensor<1xi64>>,
+  // CHECK-SAME: rhs_contracting_dimensions = #vhlo.tensor_v1<dense<1> : tensor<1xi64>>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<8x8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<8x16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x8x!vhlo.f32_v1>
+  %0 = "stablehlo.dot_general"(%arg0, %arg1) {
+    dot_dimension_numbers = #stablehlo.dot<
+      lhs_batching_dimensions = [0],
+      lhs_contracting_dimensions = [2],
+      rhs_batching_dimensions = [0],
+      rhs_contracting_dimensions = [1]
+    >,
+    precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
+  } : (tensor<8x8x16xf32>, tensor<8x16x8xf32>) -> tensor<8x8x8xf32>
+  func.return %0 : tensor<8x8x8xf32>
+}
+// CHECK-LABEL: "op_dot_general"
+
+func.func @op_dot(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> {
+  // CHECK: "vhlo.dot_v1"(%arg0, %arg1) {
+  // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo<precision DEFAULT>, #vhlo<precision DEFAULT>]>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x!vhlo.f32_v1>
+  %0 = "stablehlo.dot"(%arg0, %arg1) {
+    precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
+  } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32>
+  func.return %0 : tensor<8x8xf32>
+}
+// CHECK-LABEL: "op_dot"
+
+func.func @op_dynamic_broadcast_in_dim(%arg0: tensor<?x?xf32>, %arg1: tensor<2xindex>) -> tensor<?x?xf32> {
+  // CHECK: "vhlo.dynamic_broadcast_in_dim_v1"(%arg0, %arg1) {
+  // CHECK-SAME: broadcast_dimensions = #vhlo.tensor_v1<dense<[0, 1]> : tensor<2xi64>>,
+  // CHECK-SAME: known_expanding_dimensions = #vhlo.tensor_v1<dense<0> : tensor<1xi64>>,
+  // CHECK-SAME: known_nonexpanding_dimensions = #vhlo.tensor_v1<dense<1> : tensor<1xi64>>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<?x?x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1<?x?x!vhlo.f32_v1>
+  %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %arg1) {
+    broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>,
+    known_expanding_dimensions = dense<0> : tensor<1xi64>,
+    known_nonexpanding_dimensions = dense<1> : tensor<1xi64>
+  } : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+  func.return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: "op_dynamic_broadcast_in_dim"
+
+func.func @op_dynamic_conv(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>, %arg2: tensor<4xi32>) -> tensor<1x?x?x16xf32> {
+  // CHECK: "vhlo.dynamic_conv_v1"(%arg0, %arg1, %arg2) {
+  // CHECK-SAME: batch_group_count = #vhlo.integer_v1<1 : i64>,
+  // CHECK-SAME: feature_group_count = #vhlo.integer_v1<1 : i64>,
+  // CHECK-SAME: input_batch_dimension = #vhlo.integer_v1<0 : i64>,
+  // CHECK-SAME: input_feature_dimension = #vhlo.integer_v1<3 : i64>,
+  // CHECK-SAME: input_spatial_dimensions = #vhlo.tensor_v1<dense<[1, 2]> : tensor<2xi64>>,
+  // CHECK-SAME: kernel_input_feature_dimension = #vhlo.integer_v1<2 : i64>,
+  // CHECK-SAME: kernel_output_feature_dimension = #vhlo.integer_v1<3 : i64>,
+  // CHECK-SAME: kernel_spatial_dimensions = #vhlo.tensor_v1<dense<[0, 1]> : tensor<2xi64>>,
+  // CHECK-SAME: lhs_dilation = #vhlo.tensor_v1<dense<2> : tensor<2xi64>>,
+  // CHECK-SAME: output_batch_dimension = #vhlo.integer_v1<0 : i64>,
+  // CHECK-SAME: output_feature_dimension = #vhlo.integer_v1<3 : i64>,
+  // CHECK-SAME: output_spatial_dimensions = #vhlo.tensor_v1<dense<[1, 2]> : tensor<2xi64>>,
+  // CHECK-SAME: padding = #vhlo.tensor_v1<dense<1> : tensor<2x2xi64>>,
+  // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo<precision DEFAULT>, #vhlo<precision DEFAULT>]>,
+  // CHECK-SAME: rhs_dilation = #vhlo.tensor_v1<dense<2> : tensor<2xi64>>,
+  // CHECK-SAME: window_reversal = #vhlo.tensor_v1<dense<false> : tensor<2xi1>>,
+  // CHECK-SAME: window_strides = #vhlo.tensor_v1<dense<2> : tensor<2xi64>>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<1x8x8x207x!vhlo.f32_v1>, !vhlo.tensor_v1<3x3x207x16x!vhlo.f32_v1>, !vhlo.tensor_v1<4x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x?x?x16x!vhlo.f32_v1>
+  %0 = "stablehlo.dynamic_conv"(%arg0, %arg1, %arg2) {
+    window_strides = dense<2> : tensor<2xi64>,
+    padding = dense<1> : tensor<2x2xi64>,
+    lhs_dilation = dense<2> : tensor<2xi64>,
+    rhs_dilation = dense<2> : tensor<2xi64>,
+    window_reversal = dense<false> : tensor<2xi1>,
+    dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>,
+    feature_group_count = 1 : i64,
+    batch_group_count = 1 : i64,
+    precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
+  } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<4xi32>) -> tensor<1x?x?x16xf32>
+  func.return %0 : tensor<1x?x?x16xf32>
+}
+// CHECK-LABEL: "op_dynamic_conv"
+
+func.func @op_dynamic_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>, %arg2 : tensor<3xi32>) -> tensor<1x5x8xf32> {
+  // CHECK: "vhlo.dynamic_gather_v1"(%arg0, %arg1, %arg2) {
+  // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1<dense<[0, 1]> : tensor<2xi64>>,
+  // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>,
+  // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1<true>,
+  // CHECK-SAME: offset_dims = #vhlo.tensor_v1<dense<2> : tensor<1xi64>>,
+  // CHECK-SAME: start_index_map = #vhlo.tensor_v1<dense<[0, 1]> : tensor<2xi64>>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>, !vhlo.tensor_v1<3x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x8x!vhlo.f32_v1>
+  %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %arg2) {
+    dimension_numbers = #stablehlo.gather<
+      offset_dims = [2],
+      collapsed_slice_dims = [0, 1],
+      start_index_map = [0, 1],
+      index_vector_dim = 2
+    >,
+    indices_are_sorted = true
+  } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>, tensor<3xi32>) -> tensor<1x5x8xf32>
+  func.return %0 : tensor<1x5x8xf32>
+}
+// CHECK-LABEL: "op_dynamic_gather"
+
+func.func @op_dynamic_iota(%arg0: tensor<1xindex>) -> tensor<?xf32> {
+  // CHECK: "vhlo.dynamic_iota_v1"(%arg0) {
+  // CHECK-SAME: iota_dimension = #vhlo.integer_v1<0 : i64>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1<?x!vhlo.f32_v1>
+  %0 = "stablehlo.dynamic_iota"(%arg0) {
+    iota_dimension = 0 : i64
+  } : (tensor<1xindex>) -> tensor<?xf32>
+  func.return %0 : tensor<?xf32>
+}
+// CHECK-LABEL: "op_dynamic_iota"
+
+func.func @op_dynamic_pad(%arg0: tensor<?xf32>, %arg1: tensor<f32>, %arg2: tensor<1xindex>, %arg3: tensor<1xindex>, %arg4: tensor<1xindex>) -> tensor<?xf32> {
+  // CHECK: "vhlo.dynamic_pad_v1"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!vhlo.tensor_v1<?x!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1<?x!vhlo.f32_v1>
+  %0 = "stablehlo.dynamic_pad"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<?xf32>, tensor<f32>, tensor<1xindex>, tensor<1xindex>, tensor<1xindex>) -> tensor<?xf32>
+  func.return %0 : tensor<?xf32>
+}
+// CHECK-LABEL: "op_dynamic_pad"
+
+func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
+  // CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<?x!vhlo.index_v1>) -> !vhlo.tensor_v1<?x?x!vhlo.f32_v1>
+  %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<?xindex>) -> tensor<?x?xf32>
+  func.return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: "op_dynamic_reshape"
+
+func.func @op_dynamic_slice(%arg0: tensor<16xf32>, %arg1: tensor<i64>) -> tensor<4xf32> {
+  // CHECK: "vhlo.dynamic_slice_v1"(%arg0, %arg1) {
+  // CHECK-SAME: slice_sizes = #vhlo.tensor_v1<dense<4> : tensor<1xi64>>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.i64_v1>) -> !vhlo.tensor_v1<4x!vhlo.f32_v1>
+  %0 = "stablehlo.dynamic_slice"(%arg0, %arg1) {
+    slice_sizes = dense<4> : tensor<1xi64>
+  } : (tensor<16xf32>, tensor<i64>) -> tensor<4xf32>
+  func.return %0 : tensor<4xf32>
+}
+// CHECK-LABEL: "op_dynamic_slice"
+
+func.func @op_dynamic_update_slice(%arg0: tensor<16xf32>, %arg1: tensor<4xf32>, %arg2: tensor<i64>) -> tensor<16xf32> {
+  // CHECK: "vhlo.dynamic_update_slice_v1"(%arg0, %arg1, %arg2) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<4x!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.i64_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1>
+  %0 = "stablehlo.dynamic_update_slice"(%arg0, %arg1, %arg2) : (tensor<16xf32>, tensor<4xf32>, tensor<i64>) -> tensor<16xf32>
+  func.return %0 : tensor<16xf32>
+}
+// CHECK-LABEL: "op_dynamic_update_slice"
+
+func.func @op_einsum(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> {
+  // CHECK: "vhlo.einsum_v1"(%arg0, %arg1) {
+  // CHECK-SAME: einsum_config = #vhlo.string_v1<"ab,bc->ac">
+  // CHECK-SAME: } : (!vhlo.tensor_v1<8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x!vhlo.f32_v1>
+  %0 = "stablehlo.einsum"(%arg0, %arg1) {
+    einsum_config = "ab,bc->ac"
+  } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32>
+  func.return %0 : tensor<8x8xf32>
+}
+// CHECK-LABEL: "op_einsum"
+
+func.func @op_exponential_minus_one(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.exponential_minus_one_v1"(%arg0) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.exponential_minus_one"(%arg0) : (tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_exponential_minus_one"
+
+func.func @op_exponential(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.exponential_v1"(%arg0) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.exponential"(%arg0) : (tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_exponential"
+
+func.func @op_fft(%arg0: tensor<16xcomplex<f32>>) -> tensor<16xcomplex<f32>> {
+  // CHECK: "vhlo.fft_v1"(%arg0) {
+  // CHECK-SAME: fft_length = #vhlo.tensor_v1<dense<16> : tensor<1xi64>>,
+  // CHECK-SAME: fft_type = #vhlo<fft_type FFT>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<16x!vhlo.complex_v1<!vhlo.f32_v1>>) -> !vhlo.tensor_v1<16x!vhlo.complex_v1<!vhlo.f32_v1>>
+  %0 = "stablehlo.fft"(%arg0) {
+    fft_type = #stablehlo<fft_type FFT>,
+    fft_length = dense<16> : tensor<1xi64>
+  } : (tensor<16xcomplex<f32>>) -> tensor<16xcomplex<f32>>
+  func.return %0 : tensor<16xcomplex<f32>>
+}
+// CHECK-LABEL: "op_fft"
+
+func.func @op_floor(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.floor_v1"(%arg0) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_floor"
+
+func.func private @op_func(%arg0: tensor<f32> {stablehlo.arg = "0"}) -> (tensor<f32> {stablehlo.result = "0"}) {
+  // CHECK: "vhlo.func_v1"() ({
+  // CHECK-NEXT: ^[[BB:bb.*]](%arg0: !vhlo.tensor_v1<!vhlo.f32_v1>):
+  // CHECK-NEXT: "vhlo.return_v1"(%arg0) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> ()
+  // CHECK-NEXT: }) {
+  // CHECK-SAME: arg_attrs = #vhlo.array_v1<[#vhlo.dict_v1<{#vhlo.string_v1<"stablehlo.arg"> = #vhlo.string_v1<"0">}>]>,
+  // CHECK-SAME: function_type = #vhlo.type_v1<!vhlo.func_v1<(!vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>>>,
+  // CHECK-SAME: res_attrs = #vhlo.array_v1<[#vhlo.dict_v1<{#vhlo.string_v1<"stablehlo.result"> = #vhlo.string_v1<"0">}>]>,
+  // CHECK-SAME: sym_name = #vhlo.string_v1<"op_func">,
+  // CHECK-SAME: sym_visibility = #vhlo.string_v1<"private">
+  // CHECK-SAME: } : () -> ()
+  func.return %arg0 : tensor<f32>
+}
+
+func.func @op_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>) -> tensor<1x5x1xf32> {
+  // CHECK: "vhlo.gather_v1"(%arg0, %arg1) {
+  // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1<dense<[0, 1]> : tensor<2xi64>>,
+  // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>,
+  // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1<true>,
+  // CHECK-SAME: offset_dims = #vhlo.tensor_v1<dense<2> : tensor<1xi64>>,
+  // CHECK-SAME: slice_sizes = #vhlo.tensor_v1<dense<1> : tensor<3xi64>>,
+  // CHECK-SAME: start_index_map = #vhlo.tensor_v1<dense<[0, 1]> : tensor<2xi64>>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x1x!vhlo.f32_v1>
+  %0 = "stablehlo.gather"(%arg0, %arg1) {
+    dimension_numbers = #stablehlo.gather<
+      offset_dims = [2],
+      collapsed_slice_dims = [0, 1],
+      start_index_map = [0, 1],
+      index_vector_dim = 2
+    >,
+    slice_sizes = dense<1> : tensor<3xi64>,
+    indices_are_sorted = true
+  } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>) -> tensor<1x5x1xf32>
+  func.return %0 : tensor<1x5x1xf32>
+}
+// CHECK-LABEL: "op_gather"
+
+func.func @op_get_dimension_size(%arg0: tensor<?xf32>) -> tensor<i32> {
+  // CHECK: "vhlo.get_dimension_size_v1"(%arg0) {
+  // CHECK-SAME: dimension = #vhlo.integer_v1<0 : i64>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<?x!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.i32_v1>
+  %0 = "stablehlo.get_dimension_size"(%arg0) {
+    dimension = 0 : i64
+  } : (tensor<?xf32>) -> tensor<i32>
+  func.return %0 : tensor<i32>
+}
+// CHECK-LABEL: "op_get_dimension_size"
+
+func.func @op_get_tuple_element(%arg0: tuple<tensor<f32>, tensor<f32>>) -> tensor<f32> {
+  // CHECK: "vhlo.get_tuple_element_v1"(%arg0) {
+  // CHECK-SAME: index = #vhlo.integer_v1<0 : i32>
+  // CHECK-SAME: } : (!vhlo.tuple_v1<!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.get_tuple_element"(%arg0) {
+    index = 0 : i32
+  } : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_get_tuple_element"
+
+func.func @op_if(%arg0: tensor<i1>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.if_v1"(%arg0) ({
+  // CHECK-NEXT: "vhlo.return_v1"(%arg1) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> ()
+  // CHECK-NEXT: }, {
+  // CHECK-NEXT: "vhlo.return_v1"(%arg2) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> ()
+  // CHECK-NEXT: }) : (!vhlo.tensor_v1<!vhlo.i1_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.if"(%arg0) ({
+    "stablehlo.return"(%arg1) : (tensor<f32>) -> ()
+  }, {
+    "stablehlo.return"(%arg2) : (tensor<f32>) -> ()
+  }) : (tensor<i1>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_if"
+
+func.func @op_imag(%arg0: tensor<complex<f32>>) -> tensor<f32> {
+  // CHECK: "vhlo.imag_v1"(%arg0) : (!vhlo.tensor_v1<!vhlo.complex_v1<!vhlo.f32_v1>>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.imag"(%arg0) : (tensor<complex<f32>>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_imag"
+
+func.func @op_infeed(%arg0: !stablehlo.token) -> (tensor<f32>, !stablehlo.token) {
+  // CHECK: "vhlo.infeed_v1"(%arg0) {
+  // CHECK-SAME: infeed_config = #vhlo.string_v1<"foo">,
+  // CHECK-SAME{LITERAL}: layout = #vhlo.array_v1<[#vhlo.array_v1<[]>]>
+  // CHECK-SAME: } : (!vhlo.token_v1) -> (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.token_v1)
+  %0:2 = "stablehlo.infeed"(%arg0) {
+    infeed_config = "foo",
+    layout = [[]]
+  } : (!stablehlo.token) -> (tensor<f32>, !stablehlo.token)
+  func.return %0#0, %0#1 : tensor<f32>, !stablehlo.token
+}
+// CHECK-LABEL: "op_infeed"
+
+func.func @op_iota() -> tensor<16xf32> {
+  // CHECK: "vhlo.iota_v1"() {
+  // CHECK-SAME: iota_dimension = #vhlo.integer_v1<0 : i64>
+  // CHECK-SAME: } : () -> !vhlo.tensor_v1<16x!vhlo.f32_v1>
+  %0 = "stablehlo.iota"() {
+    iota_dimension = 0 : i64
+  } : () -> tensor<16xf32>
+  func.return %0 : tensor<16xf32>
+}
+// CHECK-LABEL: "op_iota"
+
+func.func @op_is_finite(%arg0: tensor<f32>) -> tensor<i1> {
+  // CHECK: "vhlo.is_finite_v1"(%arg0) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.i1_v1>
+  %0 = "stablehlo.is_finite"(%arg0) : (tensor<f32>) -> tensor<i1>
+  func.return %0 : tensor<i1>
+}
+// CHECK-LABEL: "op_is_finite"
+
+func.func @op_log(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.log_v1"(%arg0) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.log"(%arg0) : (tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_log"
+
+func.func @op_log_plus_one(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.log_plus_one_v1"(%arg0) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.log_plus_one"(%arg0) : (tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_log_plus_one"
+
+func.func @op_logistic(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.logistic_v1"(%arg0) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.logistic"(%arg0) : (tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_logistic"
+
+func.func @op_map(%arg0: tensor<16xf32>) -> tensor<16xf32> {
+  // CHECK: "vhlo.map_v1"(%arg0) ({
+  // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1<!vhlo.f32_v1>):
+  // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.abs_v1"(%[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> ()
+  // CHECK-NEXT: }) {
+  // CHECK-SAME: dimensions = #vhlo.tensor_v1<dense<0> : tensor<1xi64>>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1>
+  %0 = "stablehlo.map"(%arg0) ({
+    ^bb0(%arg1: tensor<f32>):
+      %1 = "stablehlo.abs"(%arg1) : (tensor<f32>) -> tensor<f32>
+      "stablehlo.return"(%1) : (tensor<f32>) -> ()
+  }) {
+    dimensions = dense<0> : tensor<1xi64>
+  } : (tensor<16xf32>) -> tensor<16xf32>
+  func.return %0 : tensor<16xf32>
+}
+// CHECK-LABEL: "op_map"
+
+func.func @op_maximum(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.maximum_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.maximum"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_maximum"
+
+func.func @op_minimum(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.minimum_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.minimum"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_minimum"
+
+func.func @op_multiply(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.multiply_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.multiply"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_multiply"
+
+func.func @op_negate(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.negate_v1"(%arg0) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_negate"
+
+func.func @op_not(%arg0: tensor<i1>) -> tensor<i1> {
+  // CHECK: "vhlo.not_v1"(%arg0) : (!vhlo.tensor_v1<!vhlo.i1_v1>) -> !vhlo.tensor_v1<!vhlo.i1_v1>
+  %0 = "stablehlo.not"(%arg0) : (tensor<i1>) -> tensor<i1>
+  func.return %0 : tensor<i1>
+}
+// CHECK-LABEL: "op_not"
+
+func.func @op_optimization_barrier(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.optimization_barrier_v1"(%arg0) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.optimization_barrier"(%arg0) : (tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_optimization_barrier"
+
+func.func @op_or(%arg0: tensor<i1>, %arg1: tensor<i1>) -> tensor<i1> {
+  // CHECK: "vhlo.or_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.i1_v1>, !vhlo.tensor_v1<!vhlo.i1_v1>) -> !vhlo.tensor_v1<!vhlo.i1_v1>
+  %0 = "stablehlo.or"(%arg0, %arg1) : (tensor<i1>, tensor<i1>) -> tensor<i1>
+  func.return %0 : tensor<i1>
+}
+// CHECK-LABEL: "op_or"
+
+func.func @op_outfeed(%arg0: tensor<f32>, %arg1: !stablehlo.token) -> !stablehlo.token {
+  // CHECK: "vhlo.outfeed_v1"(%arg0, %arg1) {
+  // CHECK-SAME: outfeed_config = #vhlo.string_v1<"foo">
+  // CHECK-SAME: } : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.token_v1) -> !vhlo.token_v1
+  %0 = "stablehlo.outfeed"(%arg0, %arg1) {
+    outfeed_config = "foo"
+  } : (tensor<f32>, !stablehlo.token) -> !stablehlo.token
+  func.return %0 : !stablehlo.token
+}
+// CHECK-LABEL: "op_outfeed"
+
+func.func @op_pad(%arg0: tensor<8xf32>, %arg1: tensor<f32>) -> tensor<16xf32> {
+  // CHECK: "vhlo.pad_v1"(%arg0, %arg1) {
+  // CHECK-SAME: edge_padding_high = #vhlo.tensor_v1<dense<4> : tensor<1xi64>>,
+  // CHECK-SAME: edge_padding_low = #vhlo.tensor_v1<dense<4> : tensor<1xi64>>,
+  // CHECK-SAME: interior_padding = #vhlo.tensor_v1<dense<0> : tensor<1xi64>>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<8x!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1>
+  %0 = "stablehlo.pad"(%arg0, %arg1) {
+    edge_padding_high = dense<4> : tensor<1xi64>,
+    edge_padding_low = dense<4> : tensor<1xi64>,
+    interior_padding = dense<0> : tensor<1xi64>
+  } : (tensor<8xf32>, tensor<f32>) -> tensor<16xf32>
+  func.return %0 : tensor<16xf32>
+}
+// CHECK-LABEL: "op_pad"
+
+func.func @op_popcnt(%arg0: tensor<i32>) -> tensor<i32> {
+  // CHECK: "vhlo.popcnt_v1"(%arg0) : (!vhlo.tensor_v1<!vhlo.i32_v1>) -> !vhlo.tensor_v1<!vhlo.i32_v1>
+  %0 = "stablehlo.popcnt"(%arg0) : (tensor<i32>) -> tensor<i32>
+  func.return %0 : tensor<i32>
+}
+// CHECK-LABEL: "op_popcnt"
+
+func.func @op_power(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.power_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.power"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_power"
+
+func.func @op_real_dynamic_slice(%arg0: tensor<?xf32>, %arg1: tensor<1xindex>, %arg2: tensor<1xindex>, %arg3: tensor<1xindex>) -> tensor<?xf32> {
+  // CHECK: "vhlo.real_dynamic_slice_v1"(%arg0, %arg1, %arg2, %arg3) : (!vhlo.tensor_v1<?x!vhlo.f32_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1<?x!vhlo.f32_v1>
+  %0 = "stablehlo.real_dynamic_slice"(%arg0, %arg1, %arg2, %arg3) : (tensor<?xf32>, tensor<1xindex>, tensor<1xindex>, tensor<1xindex>) -> tensor<?xf32>
+  func.return %0 : tensor<?xf32>
+}
+// CHECK-LABEL: "op_real_dynamic_slice"
+
+func.func @op_real(%arg0: tensor<complex<f32>>) -> tensor<f32> {
+  // CHECK: "vhlo.real_v1"(%arg0) : (!vhlo.tensor_v1<!vhlo.complex_v1<!vhlo.f32_v1>>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.real"(%arg0) : (tensor<complex<f32>>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_real"
+
+func.func @op_recv(%arg0: !stablehlo.token) -> (tensor<f32>, !stablehlo.token) {
+  // CHECK: "vhlo.recv_v1"(%arg0) {
+  // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>,
+  // CHECK-SAME: channel_type = #vhlo.integer_v1<0 : i64>,
+  // CHECK-SAME: is_host_transfer = #vhlo.bool_v1<true>
+  // CHECK-SAME: } : (!vhlo.token_v1) -> (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.token_v1)
+  %0:2 = "stablehlo.recv"(%arg0) {
+    channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>,
+    is_host_transfer = true
+  } : (!stablehlo.token) -> (tensor<f32>, !stablehlo.token)
+  func.return %0#0, %0#1 : tensor<f32>, !stablehlo.token
+}
+// CHECK-LABEL: "op_recv"
+
+func.func @op_reduce(%arg0: tensor<16xf32>, %arg1: tensor<f32>) -> tensor<f32> {
+  %0 = "stablehlo.reduce"(%arg0, %arg1) ({
+    ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+      %1 = "stablehlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+      "stablehlo.return"(%1) : (tensor<f32>) -> ()
+  }) {
+    dimensions = dense<0> : tensor<1xi64>
+  } : (tensor<16xf32>, tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_reduce"
+
+func.func @op_reduce_precision(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.reduce_precision_v1"(%arg0) {
+  // CHECK-SAME: exponent_bits = #vhlo.integer_v1<8 : i32>
+  // CHECK-SAME: mantissa_bits = #vhlo.integer_v1<10 : i32>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.reduce_precision"(%arg0) {
+    exponent_bits = 8 : i32,
+    mantissa_bits = 10 : i32
+  } : (tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_reduce_precision"
+
+func.func @op_reduce_scatter(%arg0: tensor<16xf32>) -> tensor<16xf32> {
+  // CHECK: "vhlo.reduce_scatter_v1"(%arg0) ({
+  // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1<!vhlo.f32_v1>, %[[ARG2:arg.*]]: !vhlo.tensor_v1<!vhlo.f32_v1>):
+  // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> ()
+  // CHECK-NEXT: }) {
+  // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>,
+  // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1<dense<[[0], [1]]> : tensor<2x1xi64>>,
+  // CHECK-SAME: scatter_dimension = #vhlo.integer_v1<0 : i64>
+  // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1<true>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1>
+  %0 = "stablehlo.reduce_scatter"(%arg0) ({
+    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+      %1 = "stablehlo.add"(%arg1, %arg2) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+      "stablehlo.return"(%1) : (tensor<f32>) -> ()
+  }) {
+    scatter_dimension = 0 : i64,
+    replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>,
+    channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>,
+    use_global_device_ids
+  } : (tensor<16xf32>) -> tensor<16xf32>
+  func.return %0 : tensor<16xf32>
+}
+// CHECK-LABEL: "op_reduce_scatter"
+
+func.func @op_reduce_window(%arg0: tensor<2x17x31x7xf32>, %arg1: tensor<f32>) -> tensor<2x9x16x7xf32> {
+  // CHECK: "vhlo.reduce_window_v1"(%arg0, %arg1) ({
+  // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG2:arg.*]]: !vhlo.tensor_v1<!vhlo.f32_v1>, %[[ARG3:arg.*]]: !vhlo.tensor_v1<!vhlo.f32_v1>):
+  // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.maximum_v1"(%[[ARG2]], %[[ARG3]]) : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> ()
+  // CHECK-NEXT: }) {
+  // CHECK-SAME: base_dilations = #vhlo.tensor_v1<dense<[1, 2, 2, 1]> : tensor<4xi64>>,
+  // CHECK-SAME{LITERAL}: padding = #vhlo.tensor_v1<dense<[[0, 0], [2, 0], [0, 2], [0, 0]]> : tensor<4x2xi64>>,
+  // CHECK-SAME: window_dilations = #vhlo.tensor_v1<dense<[1, 2, 2, 1]> : tensor<4xi64>>,
+  // CHECK-SAME: window_dimensions = #vhlo.tensor_v1<dense<[1, 2, 2, 1]> : tensor<4xi64>>,
+  // CHECK-SAME: window_strides = #vhlo.tensor_v1<dense<[1, 4, 4, 1]> : tensor<4xi64>>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<2x17x31x7x!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<2x9x16x7x!vhlo.f32_v1>
+  %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({
+    ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+      %1 = "stablehlo.maximum"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+      "stablehlo.return"(%1) : (tensor<f32>) -> ()
+  }) {
+    window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
+    window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>,
+    base_dilations = dense<[1, 2, 2, 1]> : tensor<4xi64>,
+    window_dilations = dense<[1, 2, 2, 1]> : tensor<4xi64>,
+    padding = dense<[[0, 0], [2, 0], [0, 2], [0, 0]]> : tensor<4x2xi64>
+  } : (tensor<2x17x31x7xf32>, tensor<f32>) -> tensor<2x9x16x7xf32>
+  func.return %0 : tensor<2x9x16x7xf32>
+}
+// CHECK-LABEL: "op_reduce_window"
+
+func.func @op_remainder(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.remainder_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.remainder"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_remainder"
+
+func.func @op_replica_id() -> tensor<ui32> {
+  // CHECK: "vhlo.replica_id_v1"() : () -> !vhlo.tensor_v1<!vhlo.ui32_v1>
+  %0 = "stablehlo.replica_id"() : () -> tensor<ui32>
+  func.return %0 : tensor<ui32>
+}
+// CHECK-LABEL: "op_replica_id"
+
+func.func @op_partition_id() -> tensor<ui32> {
+  // CHECK: "vhlo.partition_id_v1"() : () -> !vhlo.tensor_v1<!vhlo.ui32_v1>
+  %0 = "stablehlo.partition_id"() : () -> tensor<ui32>
+  func.return %0 : tensor<ui32>
+}
+// CHECK-LABEL: "op_partition_id"
+
+func.func @op_reshape(%arg0: tensor<16xf32>) -> tensor<4x4xf32> {
+  // CHECK: "vhlo.reshape_v1"(%arg0) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<4x4x!vhlo.f32_v1>
+  %0 = "stablehlo.reshape"(%arg0) : (tensor<16xf32>) -> tensor<4x4xf32>
+  func.return %0 : tensor<4x4xf32>
+}
+// CHECK-LABEL: "op_reshape"
+
+func.func @op_return(%arg0: tensor<i32>, %arg1: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.case_v1"(%arg0) ({
+  // CHECK-NEXT: "vhlo.return_v1"(%arg1) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> ()
+  // CHECK-NEXT: }) : (!vhlo.tensor_v1<!vhlo.i32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.case"(%arg0) ({
+    "stablehlo.return"(%arg1) : (tensor<f32>) -> ()
+  }) : (tensor<i32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_return"
+
+func.func @op_reverse(%arg0: tensor<16xf32>) -> tensor<16xf32> {
+  // CHECK: "vhlo.reverse_v1"(%arg0) {
+  // CHECK-SAME: dimensions = #vhlo.tensor_v1<dense<0> : tensor<1xi64>>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1>
+  %0 = "stablehlo.reverse"(%arg0) {
+    dimensions = dense<0> : tensor<1xi64>
+  } : (tensor<16xf32>) -> tensor<16xf32>
+  func.return %0 : tensor<16xf32>
+}
+// CHECK-LABEL: "op_reverse"
+
+func.func @op_rng_bit_generator(%arg0: tensor<ui64>) -> (tensor<ui64>, tensor<ui64>) {
+  // CHECK: "vhlo.rng_bit_generator_v1"(%arg0) {
+  // CHECK-SAME: rng_algorithm = #vhlo<rng_algorithm PHILOX>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<!vhlo.ui64_v1>) -> (!vhlo.tensor_v1<!vhlo.ui64_v1>, !vhlo.tensor_v1<!vhlo.ui64_v1>)
+  %0:2 = "stablehlo.rng_bit_generator"(%arg0) {
+    rng_algorithm = #stablehlo<rng_algorithm PHILOX>
+  } : (tensor<ui64>) -> (tensor<ui64>, tensor<ui64>)
+  func.return %0#0, %0#1 : tensor<ui64>, tensor<ui64>
+}
+// CHECK-LABEL: "op_rng_bit_generator"
+
+func.func @op_rng(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<0xindex>) -> tensor<f32> {
+  // CHECK: "vhlo.rng_v1"(%arg0, %arg1, %arg2) {
+  // CHECK-SAME: rng_distribution = #vhlo<rng_distribution NORMAL>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<0x!vhlo.index_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) {
+    rng_distribution = #stablehlo<rng_distribution NORMAL>
+  } : (tensor<f32>, tensor<f32>, tensor<0xindex>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_rng"
+
+func.func @op_round_nearest_afz(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.round_nearest_afz_v1"(%arg0) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.round_nearest_afz"(%arg0) : (tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_round_nearest_afz"
+
+func.func @op_round_nearest_even(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.round_nearest_even_v1"(%arg0) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.round_nearest_even"(%arg0) : (tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_round_nearest_even"
+
+func.func @op_rsqrt(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.rsqrt_v1"(%arg0) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.rsqrt"(%arg0) : (tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_rsqrt"
+
+func.func @op_scatter(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>, %arg2: tensor<10x300xf32>) -> tensor<200x100x300xf32> {
+  // CHECK: "vhlo.scatter_v1"(%arg0, %arg1, %arg2) ({
+  // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG3:arg.*]]: !vhlo.tensor_v1<!vhlo.f32_v1>, %[[ARG4:arg.*]]: !vhlo.tensor_v1<!vhlo.f32_v1>):
+  // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG3]], %[[ARG4]]) : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> ()
+  // CHECK-NEXT: }) {
+  // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<1 : i64>,
+  // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1<true>,
+  // CHECK-SAME: inserted_window_dims = #vhlo.tensor_v1<dense<[0, 1]> : tensor<2xi64>>,
+  // CHECK-SAME: scatter_dims_to_operand_dims = #vhlo.tensor_v1<dense<[0, 1]> : tensor<2xi64>>,
+  // CHECK-SAME: unique_indices = #vhlo.bool_v1<true>,
+  // CHECK-SAME: update_window_dims = #vhlo.tensor_v1<dense<1> : tensor<1xi64>>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<200x100x300x!vhlo.f32_v1>, !vhlo.tensor_v1<10x2x!vhlo.i32_v1>, !vhlo.tensor_v1<10x300x!vhlo.f32_v1>) -> !vhlo.tensor_v1<200x100x300x!vhlo.f32_v1>
+  %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({
+    ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+      %1 = "stablehlo.add"(%arg3, %arg4) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+      "stablehlo.return"(%1) : (tensor<f32>) -> ()
+  }) {
+    scatter_dimension_numbers = #stablehlo.scatter<
+      update_window_dims = [1],
+      inserted_window_dims = [0, 1],
+      scatter_dims_to_operand_dims = [0, 1],
+      index_vector_dim = 1
+    >,
+    indices_are_sorted = true,
+    unique_indices = true
+  } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<200x100x300xf32>
+  func.return %0 : tensor<200x100x300xf32>
+}
+// CHECK-LABEL: "op_scatter"
+
+func.func @op_select_and_scatter(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<12x13x13x66xf32>, %arg2: tensor<f32>) -> tensor<10x24x24x64xf32> {
+  // CHECK: "vhlo.select_and_scatter_v1"(%arg0, %arg1, %arg2) ({
+  // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG31:arg.*]]: !vhlo.tensor_v1<!vhlo.f32_v1>, %[[ARG41:arg.*]]: !vhlo.tensor_v1<!vhlo.f32_v1>):
+  // CHECK-NEXT: %[[VAL11:.*]] = "vhlo.compare_v1"(%[[ARG31]], %[[ARG41]]) {compare_type = #vhlo<comparison_type TOTALORDER>, comparison_direction = #vhlo<comparison_direction GE>} : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.i1_v1>
+  // CHECK-NEXT: "vhlo.return_v1"(%[[VAL11]]) : (!vhlo.tensor_v1<!vhlo.i1_v1>) -> ()
+  // CHECK-NEXT: }, {
+  // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG32:arg.*]]: !vhlo.tensor_v1<!vhlo.f32_v1>, %[[ARG42:arg.*]]: !vhlo.tensor_v1<!vhlo.f32_v1>):
+  // CHECK-NEXT: %[[VAL12:.*]] = "vhlo.add_v1"(%[[ARG32]], %[[ARG42]]) : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  // CHECK-NEXT: "vhlo.return_v1"(%[[VAL12]]) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> ()
+  // CHECK-NEXT: }) {
+  // CHECK-SAME: padding = #vhlo.tensor_v1<dense<1> : tensor<4x2xi64>>,
+  // CHECK-SAME: window_dimensions = #vhlo.tensor_v1<dense<[1, 2, 2, 1]> : tensor<4xi64>>,
+  // CHECK-SAME: window_strides = #vhlo.tensor_v1<dense<[1, 2, 2, 1]> : tensor<4xi64>>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1>, !vhlo.tensor_v1<12x13x13x66x!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1>
+  %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({
+    ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+      %1 = "stablehlo.compare"(%arg3, %arg4) {compare_type = #stablehlo<comparison_type TOTALORDER>, comparison_direction = #stablehlo<comparison_direction GE>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+      "stablehlo.return"(%1) : (tensor<i1>) -> ()
+  }, {
+    ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+      %1 = "stablehlo.add"(%arg3, %arg4) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+      "stablehlo.return"(%1) : (tensor<f32>) -> ()
+  }) {
+    window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
+    window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>,
+    padding = dense<1> : tensor<4x2xi64>
+  } : (tensor<10x24x24x64xf32>, tensor<12x13x13x66xf32>, tensor<f32>) -> tensor<10x24x24x64xf32>
+  func.return %0 : tensor<10x24x24x64xf32>
+}
+// CHECK-LABEL: "op_select_and_scatter"
+
+func.func @op_select(%arg0: tensor<i1>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.select_v1"(%arg0, %arg1, %arg2) : (!vhlo.tensor_v1<!vhlo.i1_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_select"
+
+func.func @op_send(%arg0: tensor<f32>, %arg1: !stablehlo.token) -> !stablehlo.token {
+  // CHECK: "vhlo.send_v1"(%arg0, %arg1) {
+  // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>,
+  // CHECK-SAME: channel_type = #vhlo.integer_v1<0 : i64>,
+  // CHECK-SAME: is_host_transfer = #vhlo.bool_v1<true>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.token_v1) -> !vhlo.token_v1
+  %0 = "stablehlo.send"(%arg0, %arg1) {
+    channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>,
+    is_host_transfer = true
+  } : (tensor<f32>, !stablehlo.token) -> !stablehlo.token
+  func.return %0 : !stablehlo.token
+}
+// CHECK-LABEL: "op_send"
+
+func.func @op_set_dimension_size(%arg0: tensor<?xf32>, %arg1: tensor<i32>) -> tensor<16xf32> {
+  // CHECK: "vhlo.set_dimension_size_v1"(%arg0, %arg1) {
+  // CHECK-SAME: dimension = #vhlo.integer_v1<0 : i64>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<?x!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.i32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1>
+  %0 = "stablehlo.set_dimension_size"(%arg0, %arg1) {
+    dimension = 0 : i64
+  } : (tensor<?xf32>, tensor<i32>) -> tensor<16xf32>
+  func.return %0 : tensor<16xf32>
+}
+// CHECK-LABEL: "op_set_dimension_size"
+
+func.func @op_shift_left(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
+  // CHECK: "vhlo.shift_left_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.i32_v1>, !vhlo.tensor_v1<!vhlo.i32_v1>) -> !vhlo.tensor_v1<!vhlo.i32_v1>
+  %0 = "stablehlo.shift_left"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  func.return %0 : tensor<i32>
+}
+// CHECK-LABEL: "op_shift_left"
+
+func.func @op_shift_right_arithmetic(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
+  // CHECK: "vhlo.shift_right_arithmetic_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.i32_v1>, !vhlo.tensor_v1<!vhlo.i32_v1>) -> !vhlo.tensor_v1<!vhlo.i32_v1>
+  %0 = "stablehlo.shift_right_arithmetic"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  func.return %0 : tensor<i32>
+}
+// CHECK-LABEL: "op_shift_right_arithmetic"
+
+func.func @op_shift_right_logical(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
+  // CHECK: "vhlo.shift_right_logical_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.i32_v1>, !vhlo.tensor_v1<!vhlo.i32_v1>) -> !vhlo.tensor_v1<!vhlo.i32_v1>
+  %0 = "stablehlo.shift_right_logical"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  func.return %0 : tensor<i32>
+}
+// CHECK-LABEL: "op_shift_right_logical"
+
+func.func @op_sign(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.sign_v1"(%arg0) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.sign"(%arg0) : (tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_sign"
+
+func.func @op_sine(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.sine_v1"(%arg0) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.sine"(%arg0) : (tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_sine"
+
+func.func @op_slice(%arg0: tensor<16xf32>) -> tensor<4xf32> {
+  // CHECK: "vhlo.slice_v1"(%arg0) {
"vhlo.slice_v1"(%arg0) { + // CHECK-SAME: limit_indices = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: start_indices = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: strides = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: } : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<4x!vhlo.f32_v1> + %0 = "stablehlo.slice"(%arg0) { + start_indices = dense<0> : tensor<1xi64>, + limit_indices = dense<4> : tensor<1xi64>, + strides = dense<1> : tensor<1xi64> + } : (tensor<16xf32>) -> tensor<4xf32> + func.return %0 : tensor<4xf32> +} +// CHECK-LABEL: "op_slice" + +func.func @op_sort(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.sort_v1"(%arg0) ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.compare_v1"(%[[ARG1]], %[[ARG2]]) {compare_type = #vhlo, comparison_direction = #vhlo} : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) { + // CHECK-SAME: dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: is_stable = #vhlo.bool_v1 + // CHECK-SAME: } : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.sort"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.compare"(%arg1, %arg2) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + dimension = 0 : i64, + is_stable = true + } : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} +// CHECK-LABEL: "op_sort" + +func.func @op_sqrt(%arg0: tensor) -> tensor { + // CHECK: "vhlo.sqrt_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.sqrt"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_sqrt" + +func.func @op_subtract(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.subtract_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.subtract"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_subtract" + +func.func @op_tanh(%arg0: tensor) -> tensor { + // CHECK: "vhlo.tanh_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.tanh"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_tanh" + +func.func @op_torch_index_select(%arg0: tensor<5x1x5xf32>, %arg1: tensor<2xi32>) -> tensor<2x1x5xf32> { + // CHECK: "vhlo.torch_index_select_v1"(%arg0, %arg1) { + // CHECK-SAME: batch_dims = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: dim = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: } : (!vhlo.tensor_v1<5x1x5x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.i32_v1>) -> !vhlo.tensor_v1<2x1x5x!vhlo.f32_v1> + %0 = "stablehlo.torch_index_select"(%arg0, %arg1) { + dim = 0 : i64, + batch_dims = 0 : i64 + } : (tensor<5x1x5xf32>, tensor<2xi32>) -> tensor<2x1x5xf32> + func.return %0 : tensor<2x1x5xf32> +} +// CHECK-LABEL: "op_torch_index_select" + +func.func @op_trace(%arg0: tensor) { + // CHECK: "vhlo.trace_v1"(%arg0) { + // CHECK-SAME: tag = #vhlo.string_v1<"foo"> + // CHECK-SAME: } : (!vhlo.tensor_v1) -> () + "stablehlo.trace"(%arg0) { + tag = "foo" + } : (tensor) -> () + func.return +} +// CHECK-LABEL: "op_trace" + +func.func @op_transpose(%arg0: tensor<16x8xf32>) -> tensor<8x16xf32> { + // CHECK: "vhlo.transpose_v1"(%arg0) { + // CHECK-SAME: permutation = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: } : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> 
+  %0 = "stablehlo.transpose"(%arg0) {
+    permutation = dense<[1, 0]> : tensor<2xi64>
+  } : (tensor<16x8xf32>) -> tensor<8x16xf32>
+  func.return %0 : tensor<8x16xf32>
+}
+// CHECK-LABEL: "op_transpose"
+
+func.func @op_triangular_solve(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> {
+  // CHECK: "vhlo.triangular_solve_v1"(%arg0, %arg1) {
+  // CHECK-SAME: left_side = #vhlo.bool_v1<true>,
+  // CHECK-SAME: lower = #vhlo.bool_v1<true>,
+  // CHECK-SAME: transpose_a = #vhlo<transpose NO_TRANSPOSE>,
+  // CHECK-SAME: unit_diagonal = #vhlo.bool_v1<true>
+  // CHECK-SAME: } : (!vhlo.tensor_v1<16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1>
+  %0 = "stablehlo.triangular_solve"(%arg0, %arg1) {
+    left_side = true,
+    lower = true,
+    unit_diagonal = true,
+    transpose_a = #stablehlo<transpose NO_TRANSPOSE>
+  } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32>
+  func.return %0 : tensor<16x16xf32>
+}
+// CHECK-LABEL: "op_triangular_solve"
+
+func.func @op_tuple(%arg0: tensor<f32>) -> tuple<tensor<f32>> {
+  // CHECK: "vhlo.tuple_v1"(%arg0) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tuple_v1<!vhlo.tensor_v1<!vhlo.f32_v1>>
+  %0 = "stablehlo.tuple"(%arg0) : (tensor<f32>) -> tuple<tensor<f32>>
+  func.return %0 : tuple<tensor<f32>>
+}
+// CHECK-LABEL: "op_tuple"
+
+func.func @op_unary_einsum(%arg0: tensor<8x16xf32>) -> tensor<8xf32> {
+  // CHECK: "vhlo.unary_einsum_v1"(%arg0) {
+  // CHECK-SAME: einsum_config = #vhlo.string_v1<"ab->a">
+  // CHECK-SAME: } : (!vhlo.tensor_v1<8x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x!vhlo.f32_v1>
+  %0 = "stablehlo.unary_einsum"(%arg0) {
+    einsum_config = "ab->a"
+  } : (tensor<8x16xf32>) -> tensor<8xf32>
+  func.return %0 : tensor<8xf32>
+}
+// CHECK-LABEL: "op_unary_einsum"
+
+func.func @op_uniform_dequantize(%arg0: tensor<!quant.uniform<i8:f32, 3.400000e+01:16>>) -> tensor<f32> {
+  // CHECK: "vhlo.uniform_dequantize_v1"(%arg0) : (!vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.uniform_dequantize"(%arg0) : (tensor<!quant.uniform<i8:f32, 3.400000e+01:16>>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "op_uniform_dequantize"
+
+func.func @op_uniform_quantize(%arg0: tensor<f32>) -> tensor<!quant.uniform<i8:f32, 3.400000e+01:16>> {
+  // CHECK: "vhlo.uniform_quantize_v1"(%arg0) : (!vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>
+  %0 = "stablehlo.uniform_quantize"(%arg0) : (tensor<f32>) -> tensor<!quant.uniform<i8:f32, 3.400000e+01:16>>
+  func.return %0 : tensor<!quant.uniform<i8:f32, 3.400000e+01:16>>
+}
+// CHECK-LABEL: "op_uniform_quantize"
+
+func.func @op_while(%arg0: tensor<i1>) -> tensor<i1> {
+  // CHECK: "vhlo.while_v1"(%arg0) ({
+  // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1<!vhlo.i1_v1>):
+  // CHECK-NEXT: "vhlo.return_v1"(%[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.i1_v1>) -> ()
+  // CHECK-NEXT: }, {
+  // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1<!vhlo.i1_v1>)
+  // CHECK-NEXT: "vhlo.return_v1"(%[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.i1_v1>) -> ()
+  // CHECK-NEXT: }) : (!vhlo.tensor_v1<!vhlo.i1_v1>) -> !vhlo.tensor_v1<!vhlo.i1_v1>
+  %0 = "stablehlo.while"(%arg0) ({
+    ^bb0(%arg1: tensor<i1>):
+      "stablehlo.return"(%arg1) : (tensor<i1>) -> ()
+  }, {
+    ^bb0(%arg1: tensor<i1>):
+      "stablehlo.return"(%arg1) : (tensor<i1>) -> ()
+  }) : (tensor<i1>) -> tensor<i1>
+  func.return %0: tensor<i1>
+}
+// CHECK-LABEL: "op_while"
+
+func.func @op_xor(%arg0: tensor<i1>, %arg1: tensor<i1>) -> tensor<i1> {
+  // CHECK: "vhlo.xor_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.i1_v1>, !vhlo.tensor_v1<!vhlo.i1_v1>) -> !vhlo.tensor_v1<!vhlo.i1_v1>
+  %0 = "stablehlo.xor"(%arg0, %arg1) : (tensor<i1>, tensor<i1>) -> tensor<i1>
+  func.return %0 : tensor<i1>
+}
+// CHECK-LABEL: "op_xor"
+
+// ============ TYPES ============
+
+func.func @type_i1(%arg0: tensor<i1>, %arg1: tensor<i1>) -> tensor<i1> {
+  // CHECK: "vhlo.and_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.i1_v1>, !vhlo.tensor_v1<!vhlo.i1_v1>) -> !vhlo.tensor_v1<!vhlo.i1_v1>
+  %0 = "stablehlo.and"(%arg0, %arg1) : (tensor<i1>, tensor<i1>) -> tensor<i1>
+  func.return %0 : tensor<i1>
+}
+// CHECK-LABEL: "type_i1"
+
+func.func @type_i4(%arg0: tensor<i4>, %arg1: tensor<i4>) -> tensor<i4> {
+func.func @type_i4(%arg0: tensor<i4>, %arg1: tensor<i4>) -> tensor<i4> {
+  // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.i4_v1>, !vhlo.tensor_v1<!vhlo.i4_v1>) -> !vhlo.tensor_v1<!vhlo.i4_v1>
+  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i4>, tensor<i4>) -> tensor<i4>
+  func.return %0 : tensor<i4>
+}
+// CHECK-LABEL: "type_i4"
+
+func.func @type_i8(%arg0: tensor<i8>, %arg1: tensor<i8>) -> tensor<i8> {
+  // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.i8_v1>, !vhlo.tensor_v1<!vhlo.i8_v1>) -> !vhlo.tensor_v1<!vhlo.i8_v1>
+  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i8>, tensor<i8>) -> tensor<i8>
+  func.return %0 : tensor<i8>
+}
+// CHECK-LABEL: "type_i8"
+
+func.func @type_i16(%arg0: tensor<i16>, %arg1: tensor<i16>) -> tensor<i16> {
+  // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.i16_v1>, !vhlo.tensor_v1<!vhlo.i16_v1>) -> !vhlo.tensor_v1<!vhlo.i16_v1>
+  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i16>, tensor<i16>) -> tensor<i16>
+  func.return %0 : tensor<i16>
+}
+// CHECK-LABEL: "type_i16"
+
+func.func @type_i32(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
+  // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.i32_v1>, !vhlo.tensor_v1<!vhlo.i32_v1>) -> !vhlo.tensor_v1<!vhlo.i32_v1>
+  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  func.return %0 : tensor<i32>
+}
+// CHECK-LABEL: "type_i32"
+
+func.func @type_i64(%arg0: tensor<i64>, %arg1: tensor<i64>) -> tensor<i64> {
+  // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.i64_v1>, !vhlo.tensor_v1<!vhlo.i64_v1>) -> !vhlo.tensor_v1<!vhlo.i64_v1>
+  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
+  func.return %0 : tensor<i64>
+}
+// CHECK-LABEL: "type_i64"
+
+func.func @type_ui4(%arg0: tensor<ui4>, %arg1: tensor<ui4>) -> tensor<ui4> {
+  // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.ui4_v1>, !vhlo.tensor_v1<!vhlo.ui4_v1>) -> !vhlo.tensor_v1<!vhlo.ui4_v1>
+  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<ui4>, tensor<ui4>) -> tensor<ui4>
+  func.return %0 : tensor<ui4>
+}
+// CHECK-LABEL: "type_ui4"
+
+func.func @type_ui8(%arg0: tensor<ui8>, %arg1: tensor<ui8>) -> tensor<ui8> {
+  // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.ui8_v1>, !vhlo.tensor_v1<!vhlo.ui8_v1>) -> !vhlo.tensor_v1<!vhlo.ui8_v1>
+  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<ui8>, tensor<ui8>) -> tensor<ui8>
+  func.return %0 : tensor<ui8>
+}
+// CHECK-LABEL: "type_ui8"
+
+func.func @type_ui16(%arg0: tensor<ui16>, %arg1: tensor<ui16>) -> tensor<ui16> {
+  // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.ui16_v1>, !vhlo.tensor_v1<!vhlo.ui16_v1>) -> !vhlo.tensor_v1<!vhlo.ui16_v1>
+  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<ui16>, tensor<ui16>) -> tensor<ui16>
+  func.return %0 : tensor<ui16>
+}
+// CHECK-LABEL: "type_ui16"
+
+func.func @type_ui32(%arg0: tensor<ui32>, %arg1: tensor<ui32>) -> tensor<ui32> {
+  // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.ui32_v1>, !vhlo.tensor_v1<!vhlo.ui32_v1>) -> !vhlo.tensor_v1<!vhlo.ui32_v1>
+  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<ui32>, tensor<ui32>) -> tensor<ui32>
+  func.return %0 : tensor<ui32>
+}
+// CHECK-LABEL: "type_ui32"
+
+func.func @type_ui64(%arg0: tensor<ui64>, %arg1: tensor<ui64>) -> tensor<ui64> {
+  // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.ui64_v1>, !vhlo.tensor_v1<!vhlo.ui64_v1>) -> !vhlo.tensor_v1<!vhlo.ui64_v1>
+  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<ui64>, tensor<ui64>) -> tensor<ui64>
+  func.return %0 : tensor<ui64>
+}
+// CHECK-LABEL: "type_ui64"
+
+func.func @type_f8E4M3FN(%arg0: tensor<f8E4M3FN>, %arg1: tensor<f8E4M3FN>) -> tensor<f8E4M3FN> {
+  // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.f8E4M3FN_v1>, !vhlo.tensor_v1<!vhlo.f8E4M3FN_v1>) -> !vhlo.tensor_v1<!vhlo.f8E4M3FN_v1>
+  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<f8E4M3FN>, tensor<f8E4M3FN>) -> tensor<f8E4M3FN>
+  func.return %0 : tensor<f8E4M3FN>
+}
+// CHECK-LABEL: "type_f8E4M3FN"
+
+func.func @type_f8E5M2(%arg0: tensor<f8E5M2>, %arg1: tensor<f8E5M2>) -> tensor<f8E5M2> {
+  // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.f8E5M2_v1>, !vhlo.tensor_v1<!vhlo.f8E5M2_v1>) -> !vhlo.tensor_v1<!vhlo.f8E5M2_v1>
+  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<f8E5M2>, tensor<f8E5M2>) -> tensor<f8E5M2>
+  func.return %0 : tensor<f8E5M2>
+}
+// CHECK-LABEL: "type_f8E5M2"
+
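+// The FNUZ-family FP8 types follow, including f8E4M3B11FNUZ, the type this
+// versioned test file (0.11.0) newly covers.
+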
+func.func @type_f8E4M3FNUZ(%arg0: tensor<f8E4M3FNUZ>, %arg1: tensor<f8E4M3FNUZ>) -> tensor<f8E4M3FNUZ> {
+  // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.f8E4M3FNUZ_v1>, !vhlo.tensor_v1<!vhlo.f8E4M3FNUZ_v1>) -> !vhlo.tensor_v1<!vhlo.f8E4M3FNUZ_v1>
+  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<f8E4M3FNUZ>, tensor<f8E4M3FNUZ>) -> tensor<f8E4M3FNUZ>
+  func.return %0 : tensor<f8E4M3FNUZ>
+}
+// CHECK-LABEL: "type_f8E4M3FNUZ"
+
+func.func @type_f8E4M3B11FNUZ(%arg0: tensor<f8E4M3B11FNUZ>, %arg1: tensor<f8E4M3B11FNUZ>) -> tensor<f8E4M3B11FNUZ> {
+  // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.f8E4M3B11FNUZ_v1>, !vhlo.tensor_v1<!vhlo.f8E4M3B11FNUZ_v1>) -> !vhlo.tensor_v1<!vhlo.f8E4M3B11FNUZ_v1>
+  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<f8E4M3B11FNUZ>, tensor<f8E4M3B11FNUZ>) -> tensor<f8E4M3B11FNUZ>
+  func.return %0 : tensor<f8E4M3B11FNUZ>
+}
+// CHECK-LABEL: "type_f8E4M3B11FNUZ"
+
+func.func @type_f8E5M2FNUZ(%arg0: tensor<f8E5M2FNUZ>, %arg1: tensor<f8E5M2FNUZ>) -> tensor<f8E5M2FNUZ> {
+  // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.f8E5M2FNUZ_v1>, !vhlo.tensor_v1<!vhlo.f8E5M2FNUZ_v1>) -> !vhlo.tensor_v1<!vhlo.f8E5M2FNUZ_v1>
+  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<f8E5M2FNUZ>, tensor<f8E5M2FNUZ>) -> tensor<f8E5M2FNUZ>
+  func.return %0 : tensor<f8E5M2FNUZ>
+}
+// CHECK-LABEL: "type_f8E5M2FNUZ"
+
+func.func @type_bf16(%arg0: tensor<bf16>, %arg1: tensor<bf16>) -> tensor<bf16> {
+  // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.bf16_v1>, !vhlo.tensor_v1<!vhlo.bf16_v1>) -> !vhlo.tensor_v1<!vhlo.bf16_v1>
+  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<bf16>, tensor<bf16>) -> tensor<bf16>
+  func.return %0 : tensor<bf16>
+}
+// CHECK-LABEL: "type_bf16"
+
+func.func @type_f16(%arg0: tensor<f16>, %arg1: tensor<f16>) -> tensor<f16> {
+  // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.f16_v1>, !vhlo.tensor_v1<!vhlo.f16_v1>) -> !vhlo.tensor_v1<!vhlo.f16_v1>
+  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<f16>, tensor<f16>) -> tensor<f16>
+  func.return %0 : tensor<f16>
+}
+// CHECK-LABEL: "type_f16"
+
+func.func @type_f32(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "type_f32"
+
+func.func @type_f64(%arg0: tensor<f64>, %arg1: tensor<f64>) -> tensor<f64> {
+  // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.f64_v1>, !vhlo.tensor_v1<!vhlo.f64_v1>) -> !vhlo.tensor_v1<!vhlo.f64_v1>
+  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<f64>, tensor<f64>) -> tensor<f64>
+  func.return %0 : tensor<f64>
+}
+// CHECK-LABEL: "type_f64"
+
+func.func @type_complex_f32(%arg0: tensor<complex<f32>>, %arg1: tensor<complex<f32>>) -> tensor<complex<f32>> {
+  // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.complex_v1<!vhlo.f32_v1>>, !vhlo.tensor_v1<!vhlo.complex_v1<!vhlo.f32_v1>>) -> !vhlo.tensor_v1<!vhlo.complex_v1<!vhlo.f32_v1>>
+  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<complex<f32>>, tensor<complex<f32>>) -> tensor<complex<f32>>
+  func.return %0 : tensor<complex<f32>>
+}
+// CHECK-LABEL: "type_complex_f32"
+
+func.func @type_complex_f64(%arg0: tensor<complex<f64>>, %arg1: tensor<complex<f64>>) -> tensor<complex<f64>> {
+  // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.complex_v1<!vhlo.f64_v1>>, !vhlo.tensor_v1<!vhlo.complex_v1<!vhlo.f64_v1>>) -> !vhlo.tensor_v1<!vhlo.complex_v1<!vhlo.f64_v1>>
+  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<complex<f64>>, tensor<complex<f64>>) -> tensor<complex<f64>>
+  func.return %0 : tensor<complex<f64>>
+}
+// CHECK-LABEL: "type_complex_f64"
+
+func.func @type_dynamism_ranked(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  // CHECK: "vhlo.abs_v1"(%arg0) : (!vhlo.tensor_v1<?x!vhlo.f32_v1>) -> !vhlo.tensor_v1<?x!vhlo.f32_v1>
+  %0 = "stablehlo.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
+  func.return %0 : tensor<?xf32>
+}
+// CHECK-LABEL: "type_dynamism_ranked"
+
+func.func @type_dynamism_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK: "vhlo.abs_v1"(%arg0) : (!vhlo.unranked_tensor_v1<!vhlo.f32_v1>) -> !vhlo.unranked_tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
+  func.return %0 : tensor<*xf32>
+}
+// CHECK-LABEL: "type_dynamism_unranked"
+
+func.func @type_quantization(%arg0: tensor<!quant.uniform<i8:f32, 34.0:16>>, %arg1: tensor<f32>) -> tensor<f32> {
+  // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<!vhlo.quant_v1<!vhlo.i8_v1:!vhlo.f32_v1, 3.400000e+01:16, -128:127, 1>>, !vhlo.tensor_v1<!vhlo.f32_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
+  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32, 34.0:16>>, tensor<f32>) -> tensor<f32>
+  func.return %0 : tensor<f32>
+}
+// CHECK-LABEL: "type_quantization"
+
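+// The token tests below additionally verify that the enclosing function
+// signature round-trips as a #vhlo.type_v1-wrapped !vhlo.func_v1 type.
+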
+func.func @type_token_callee(%arg0: !stablehlo.token) -> !stablehlo.token {
+  // CHECK: "vhlo.return_v1"(%arg0) : (!vhlo.token_v1) -> ()
+  return %arg0 : !stablehlo.token
+}
+// CHECK: function_type = #vhlo.type_v1<!vhlo.func_v1<(!vhlo.token_v1) -> !vhlo.token_v1>>
+// CHECK-LABEL: "type_token_callee"
+
+func.func @type_token_caller(%arg0: !stablehlo.token) -> !stablehlo.token {
+  // CHECK: "vhlo.call_v1"(%arg0) {callee = #vhlo.string_v1<"type_token_callee">}
+  // CHECK-SAME: (!vhlo.token_v1) -> !vhlo.token_v1
+  %0 = func.call @type_token_callee(%arg0) : (!stablehlo.token) -> !stablehlo.token
+  return %0 : !stablehlo.token
+}
+// CHECK: function_type = #vhlo.type_v1<!vhlo.func_v1<(!vhlo.token_v1) -> !vhlo.token_v1>>
+// CHECK-LABEL: "type_token_caller"
+
+func.func @type_tuple(%arg0: tuple<tensor<f32>>) -> tuple<!stablehlo.token> {
+  %0 = "stablehlo.custom_call"(%arg0) {
+    call_target_name = "foo"
+  // CHECK: (!vhlo.tuple_v1<!vhlo.tensor_v1<!vhlo.f32_v1>>) -> !vhlo.tuple_v1<!vhlo.token_v1>
+  } : (tuple<tensor<f32>>) -> tuple<!stablehlo.token>
+  return %0 : tuple<!stablehlo.token>
+}
+// CHECK-LABEL: "type_tuple"
diff --git a/stablehlo/tests/stablehlo_legalize_to_vhlo.0_11_0.mlir.bc b/stablehlo/tests/stablehlo_legalize_to_vhlo.0_11_0.mlir.bc
new file mode 100644
index 00000000000..790824d7b91
Binary files /dev/null and b/stablehlo/tests/stablehlo_legalize_to_vhlo.0_11_0.mlir.bc differ