diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp index 64e9ad3216bf..385c60ba8863 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp @@ -182,17 +182,26 @@ enumerateMatmulTileArm64(TypeRange elementTypes, ExecutableTargetAttr target) { (out.isSignlessInteger(32) || out.isF32())) { if (out.isSignlessInteger(32) && hasFeature(target, "+i8mm")) { return { - TileMxNxK{4, 8, 32}, - TileMxNxK{2, 8, 32}, - TileMxNxK{1, 8, 32}, + TileMxNxK{4, 8, 16}, // SMMLA + TileMxNxK{2, 8, 16}, // Truncations + TileMxNxK{1, 8, 16}, // Truncations + }; + } + + if (out.isSignlessInteger(32) && hasFeature(target, "+dotprod")) { + return { + TileMxNxK{8, 8, 8}, // SDOT + TileMxNxK{4, 8, 8}, // Truncations + TileMxNxK{2, 8, 8}, // Truncations + TileMxNxK{1, 8, 8}, // Truncations }; } // Default. return { - TileMxNxK{4, 16, 1}, // Aim to use SMLAL. - TileMxNxK{2, 32, 1}, // Truncation of the above. - TileMxNxK{1, 64, 1}, // Truncation of the above. + TileMxNxK{4, 16, 2}, // Aim to use SMLAL. + TileMxNxK{2, 16, 2}, // Truncation of the above. + TileMxNxK{1, 16, 2}, // Truncation of the above. }; } }