diff --git a/compiler/src/iree/compiler/GlobalOptimization/SetEncoding.cpp b/compiler/src/iree/compiler/GlobalOptimization/SetEncoding.cpp index 04d2d34a7bbb..920679a86c1a 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/SetEncoding.cpp +++ b/compiler/src/iree/compiler/GlobalOptimization/SetEncoding.cpp @@ -128,6 +128,16 @@ static MatmulNarrowSizes getMatmulNarrowSizes(ShapedType outType, if (!ShapedType::isDynamic(N) && N < kNarrowThreshold) { narrow.N = N; } + + // If both M and N are narrow, keep only the smaller one (M wins ties). + if (narrow.M && narrow.N) { + if (*narrow.M <= *narrow.N) { + narrow.N.reset(); + } else { + narrow.M.reset(); + } + } + return narrow; } diff --git a/compiler/src/iree/compiler/GlobalOptimization/test/set_encoding.mlir b/compiler/src/iree/compiler/GlobalOptimization/test/set_encoding.mlir index 37b97e579c2c..a63bc090d11d 100644 --- a/compiler/src/iree/compiler/GlobalOptimization/test/set_encoding.mlir +++ b/compiler/src/iree/compiler/GlobalOptimization/test/set_encoding.mlir @@ -1021,9 +1021,9 @@ util.func public @batch_matmul_f32f32f32_narrow_MN(%arg0 : tensor<64x4x250xf32>, // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)> // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> // CHECK: util.func public @batch_matmul_f32f32f32_narrow_MN( -// CHECK: iree_encoding.upper_bound_tile_size tensor<64x4x250xf32, #iree_encoding.encoding> -// CHECK: iree_encoding.upper_bound_tile_size tensor<64x250x2xf32, #iree_encoding.encoding> -// CHECK: iree_encoding.upper_bound_tile_size tensor<64x4x2xf32, #iree_encoding.encoding> +// CHECK: iree_encoding.upper_bound_tile_size tensor<64x4x250xf32, #iree_encoding.encoding> +// CHECK: iree_encoding.upper_bound_tile_size tensor<64x250x2xf32, #iree_encoding.encoding> +// CHECK: iree_encoding.upper_bound_tile_size tensor<64x4x2xf32, #iree_encoding.encoding> // CHECK: linalg.batch_matmul // -----