From a9b8ce54a528b6b7c62088c81f8a01ae382f5a88 Mon Sep 17 00:00:00 2001 From: Udaya Ranga Date: Wed, 5 Nov 2025 14:01:00 +0000 Subject: [PATCH] [mlir][tosa] Add e2e tests for matmul_t_block_scaled Added tests for dtypes fp6e3m2, fp6e2m3, fp4e2m1, mxint8 Signed-off-by: Udaya Ranga --- mlir/test/Dialect/Tosa/ops.mlir | 38 ++++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index 22fde3b7d28a5..80886c31cb58f 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -1276,6 +1276,42 @@ func.func @test_matmul_t_block_scaled_mxint8(%arg0: tensor<4x8x32x!tosa.mxint8>, return %0 : tensor<4x8x16xf32> } +// ----- +// CHECK-LABEL: test_matmul_t_block_scaled_fp6e3m2_e2e +func.func @test_matmul_t_block_scaled_fp6e3m2_e2e(%arg0: tensor<6x2x32xf32>, %arg1: tensor<6x64x32xf32>) -> tensor<6x2x64xf32> { + %a, %sa = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size : i32} : (tensor<6x2x32xf32>) -> (tensor<6x2x32xf6E3M2FN>, tensor<6x2x1xf8E8M0FNU>) + %b, %sb = tosa.cast_to_block_scaled %arg1 {block_size = #tosa.block_size : i32} : (tensor<6x64x32xf32>) -> (tensor<6x64x32xf6E3M2FN>, tensor<6x64x1xf8E8M0FNU>) + %res = tosa.matmul_t_block_scaled %a, %sa, %b, %sb {block_size = #tosa.block_size : i32} : (tensor<6x2x32xf6E3M2FN>, tensor<6x2x1xf8E8M0FNU>, tensor<6x64x32xf6E3M2FN>, tensor<6x64x1xf8E8M0FNU>) -> tensor<6x2x64xf32> + return %res : tensor<6x2x64xf32> +} + +// ----- +// CHECK-LABEL: test_matmul_t_block_scaled_fp6e2m3_e2e +func.func @test_matmul_t_block_scaled_fp6e2m3_e2e(%arg0: tensor<6x2x32xf32>, %arg1: tensor<6x64x32xf32>) -> tensor<6x2x64xf32> { + %a, %sa = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size : i32} : (tensor<6x2x32xf32>) -> (tensor<6x2x32xf6E2M3FN>, tensor<6x2x1xf8E8M0FNU>) + %b, %sb = tosa.cast_to_block_scaled %arg1 {block_size = #tosa.block_size : i32} : (tensor<6x64x32xf32>) -> (tensor<6x64x32xf6E2M3FN>, tensor<6x64x1xf8E8M0FNU>) + %res = tosa.matmul_t_block_scaled %a, %sa, %b, %sb {block_size = #tosa.block_size : i32} : (tensor<6x2x32xf6E2M3FN>, tensor<6x2x1xf8E8M0FNU>, tensor<6x64x32xf6E2M3FN>, tensor<6x64x1xf8E8M0FNU>) -> tensor<6x2x64xf32> + return %res : tensor<6x2x64xf32> +} + +// ----- +// CHECK-LABEL: test_matmul_t_block_scaled_fp4e2m1_e2e +func.func @test_matmul_t_block_scaled_fp4e2m1_e2e(%arg0: tensor<6x2x32xf32>, %arg1: tensor<6x64x32xf32>) -> tensor<6x2x64xf32> { + %a, %sa = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size : i32} : (tensor<6x2x32xf32>) -> (tensor<6x2x32xf4E2M1FN>, tensor<6x2x1xf8E8M0FNU>) + %b, %sb = tosa.cast_to_block_scaled %arg1 {block_size = #tosa.block_size : i32} : (tensor<6x64x32xf32>) -> (tensor<6x64x32xf4E2M1FN>, tensor<6x64x1xf8E8M0FNU>) + %res = tosa.matmul_t_block_scaled %a, %sa, %b, %sb {block_size = #tosa.block_size : i32} : (tensor<6x2x32xf4E2M1FN>, tensor<6x2x1xf8E8M0FNU>, tensor<6x64x32xf4E2M1FN>, tensor<6x64x1xf8E8M0FNU>) -> tensor<6x2x64xf32> + return %res : tensor<6x2x64xf32> +} + +// ----- +// CHECK-LABEL: test_matmul_t_block_scaled_mxint8_e2e +func.func @test_matmul_t_block_scaled_mxint8_e2e(%arg0: tensor<6x2x32xf32>, %arg1: tensor<6x64x32xf32>) -> tensor<6x2x64xf32> { + %a, %sa = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size : i32} : (tensor<6x2x32xf32>) -> (tensor<6x2x32x!tosa.mxint8>, tensor<6x2x1xf8E8M0FNU>) + %b, %sb = tosa.cast_to_block_scaled %arg1 {block_size = #tosa.block_size : i32} : (tensor<6x64x32xf32>) -> (tensor<6x64x32x!tosa.mxint8>, tensor<6x64x1xf8E8M0FNU>) + %res = tosa.matmul_t_block_scaled %a, %sa, %b, %sb {block_size = #tosa.block_size : i32} : (tensor<6x2x32x!tosa.mxint8>, tensor<6x2x1xf8E8M0FNU>, tensor<6x64x32x!tosa.mxint8>, tensor<6x64x1xf8E8M0FNU>) -> tensor<6x2x64xf32> + return %res : tensor<6x2x64xf32> +} + // ----- // CHECK-LABEL: test_cast_from_block_scaled_static func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> { @@ -1307,7 +1343,7 @@ func.func @test_cast_to_block_scaled_unranked(%arg0: tensor<*xf32>) -> (tensor<* // ----- // CHECK-LABEL: test_cast_to_block_scaled_mxint8 func.func @test_cast_to_block_scaled_mxint8(%arg0: tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) { - %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size : i32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) + %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size : i32} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) return %0#0, %0#1 : tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU> }