diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index 22fde3b7d28a5..80886c31cb58f 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -1276,6 +1276,42 @@ func.func @test_matmul_t_block_scaled_mxint8(%arg0: tensor<4x8x32x!tosa.mxint8>, return %0 : tensor<4x8x16xf32> } +// ----- +// CHECK-LABEL: test_matmul_t_block_scaled_fp6e3m2_e2e +func.func @test_matmul_t_block_scaled_fp6e3m2_e2e(%arg0: tensor<6x2x32xf32>, %arg1: tensor<6x64x32xf32>) -> tensor<6x2x64xf32> { + %a, %sa = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size : i32} : (tensor<6x2x32xf32>) -> (tensor<6x2x32xf6E3M2FN>, tensor<6x2x1xf8E8M0FNU>) + %b, %sb = tosa.cast_to_block_scaled %arg1 {block_size = #tosa.block_size : i32} : (tensor<6x64x32xf32>) -> (tensor<6x64x32xf6E3M2FN>, tensor<6x64x1xf8E8M0FNU>) + %res = tosa.matmul_t_block_scaled %a, %sa, %b, %sb {block_size = #tosa.block_size : i32} : (tensor<6x2x32xf6E3M2FN>, tensor<6x2x1xf8E8M0FNU>, tensor<6x64x32xf6E3M2FN>, tensor<6x64x1xf8E8M0FNU>) -> tensor<6x2x64xf32> + return %res : tensor<6x2x64xf32> +} + +// ----- +// CHECK-LABEL: test_matmul_t_block_scaled_fp6e2m3_e2e +func.func @test_matmul_t_block_scaled_fp6e2m3_e2e(%arg0: tensor<6x2x32xf32>, %arg1: tensor<6x64x32xf32>) -> tensor<6x2x64xf32> { + %a, %sa = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size : i32} : (tensor<6x2x32xf32>) -> (tensor<6x2x32xf6E2M3FN>, tensor<6x2x1xf8E8M0FNU>) + %b, %sb = tosa.cast_to_block_scaled %arg1 {block_size = #tosa.block_size : i32} : (tensor<6x64x32xf32>) -> (tensor<6x64x32xf6E2M3FN>, tensor<6x64x1xf8E8M0FNU>) + %res = tosa.matmul_t_block_scaled %a, %sa, %b, %sb {block_size = #tosa.block_size : i32} : (tensor<6x2x32xf6E2M3FN>, tensor<6x2x1xf8E8M0FNU>, tensor<6x64x32xf6E2M3FN>, tensor<6x64x1xf8E8M0FNU>) -> tensor<6x2x64xf32> + return %res : tensor<6x2x64xf32> +} + +// ----- +// CHECK-LABEL: test_matmul_t_block_scaled_fp4e2m1_e2e +func.func @test_matmul_t_block_scaled_fp4e2m1_e2e(%arg0: tensor<6x2x32xf32>, %arg1: tensor<6x64x32xf32>) -> tensor<6x2x64xf32> { + %a, %sa = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size : i32} : (tensor<6x2x32xf32>) -> (tensor<6x2x32xf4E2M1FN>, tensor<6x2x1xf8E8M0FNU>) + %b, %sb = tosa.cast_to_block_scaled %arg1 {block_size = #tosa.block_size : i32} : (tensor<6x64x32xf32>) -> (tensor<6x64x32xf4E2M1FN>, tensor<6x64x1xf8E8M0FNU>) + %res = tosa.matmul_t_block_scaled %a, %sa, %b, %sb {block_size = #tosa.block_size : i32} : (tensor<6x2x32xf4E2M1FN>, tensor<6x2x1xf8E8M0FNU>, tensor<6x64x32xf4E2M1FN>, tensor<6x64x1xf8E8M0FNU>) -> tensor<6x2x64xf32> + return %res : tensor<6x2x64xf32> +} + +// ----- +// CHECK-LABEL: test_matmul_t_block_scaled_mxint8_e2e +func.func @test_matmul_t_block_scaled_mxint8_e2e(%arg0: tensor<6x2x32xf32>, %arg1: tensor<6x64x32xf32>) -> tensor<6x2x64xf32> { + %a, %sa = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size : i32} : (tensor<6x2x32xf32>) -> (tensor<6x2x32x!tosa.mxint8>, tensor<6x2x1xf8E8M0FNU>) + %b, %sb = tosa.cast_to_block_scaled %arg1 {block_size = #tosa.block_size : i32} : (tensor<6x64x32xf32>) -> (tensor<6x64x32x!tosa.mxint8>, tensor<6x64x1xf8E8M0FNU>) + %res = tosa.matmul_t_block_scaled %a, %sa, %b, %sb {block_size = #tosa.block_size : i32} : (tensor<6x2x32x!tosa.mxint8>, tensor<6x2x1xf8E8M0FNU>, tensor<6x64x32x!tosa.mxint8>, tensor<6x64x1xf8E8M0FNU>) -> tensor<6x2x64xf32> + return %res : tensor<6x2x64xf32> +} + // ----- // CHECK-LABEL: test_cast_from_block_scaled_static func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> { @@ -1307,7 +1343,7 @@ func.func @test_cast_to_block_scaled_unranked(%arg0: tensor<*xf32>) -> (tensor<* // ----- // CHECK-LABEL: test_cast_to_block_scaled_mxint8 func.func @test_cast_to_block_scaled_mxint8(%arg0: tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) { - %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size : i32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) + %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size : i32} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) return %0#0, %0#1 : tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU> }