Drop tile sizes specific to the ukernels-disabled case. (#17631)
This patch makes sure that the matmul tile sizes selected for ARM64
have the same dimensions as the arm64 ukernel tiles.

Signed-off-by: Alan Li <me@alanli.org>
lialan committed Jun 18, 2024
1 parent 2b3c46c commit 3461314
Showing 2 changed files with 20 additions and 60 deletions.
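
The resulting selection logic is easiest to see in isolation. Below is a minimal, compilable C++ sketch of the i8 x i8 -> i32 case, under simplifying assumptions: plain bools and structs stand in for IREE's MLIR types and target queries, enumerateI8I8I32TilesArm64 is a made-up name, and the feature branches hidden in the collapsed diff regions are elided.

#include <cstdio>
#include <vector>

// Stand-in for IREE's TileMxNxK; the real code works with MLIR types and
// queries target attributes (e.g. hasFeature(target, "+i8mm")).
struct TileMxNxK {
  int M, N, K;
};

// Sketch of the unified selection: no hasUkernel() check remains, so the
// same tiles are chosen with and without ukernels. Branches for other
// features (collapsed in the diff below) are elided.
std::vector<TileMxNxK> enumerateI8I8I32TilesArm64(bool hasI8mm) {
  if (hasI8mm) {
    return {
        {8, 8, 8}, // Aim to use SMMLA.
        {4, 8, 8}, // Truncation of the above.
        {2, 8, 8}, // Truncation of the above.
        {1, 8, 8}, // Truncation of the above.
    };
  }
  // Fallback - no architecture-optimized tile size for this case.
  return {};
}

int main() {
  for (const TileMxNxK &t : enumerateI8I8I32TilesArm64(/*hasI8mm=*/true))
    std::printf("%dx%dx%d\n", t.M, t.N, t.K);
  return 0;
}

The point of the patch is precisely that this function no longer branches on hasUkernel(target), so data-tiled layouts always match the arm64 ukernel tile shapes.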
@@ -114,8 +114,8 @@ enumerateMatmulTileArm64(TypeRange elementTypes, ExecutableTargetAttr target) {
     }
   }
 
-  if (hasUkernel(target) && lhs.isSignlessInteger(8) &&
-      rhs.isSignlessInteger(8) && out.isSignlessInteger(32)) {
+  if (lhs.isSignlessInteger(8) && rhs.isSignlessInteger(8) &&
+      out.isSignlessInteger(32)) {
     if (hasFeature(target, "+i8mm")) {
       return {
           TileMxNxK{8, 8, 8}, // Aim to use SMMLA.
@@ -134,8 +134,8 @@ enumerateMatmulTileArm64(TypeRange elementTypes, ExecutableTargetAttr target) {
     }
   }
 
-  if (hasUkernel(target) && lhs.isSignlessInteger(8) &&
-      rhs.isSignlessInteger(4) && out.isSignlessInteger(32)) {
+  if (lhs.isSignlessInteger(8) && rhs.isSignlessInteger(4) &&
+      out.isSignlessInteger(32)) {
     if (hasFeature(target, "+i8mm")) {
       return {
           TileMxNxK{4, 8, 16},
@@ -158,45 +158,6 @@ enumerateMatmulTileArm64(TypeRange elementTypes, ExecutableTargetAttr target) {
     };
   }
 
-  if (!hasUkernel(target)) {
-    if (lhs.isSignlessInteger(8) && rhs.isSignlessInteger(8) &&
-        (out.isSignlessInteger(32) || out.isF32())) {
-      if (out.isSignlessInteger(32) && hasFeature(target, "+i8mm")) {
-        return {
-            TileMxNxK{8, 8, 8}, // Aim to use SMMLA.
-            TileMxNxK{4, 8, 8}, // Truncation of the above.
-            TileMxNxK{2, 8, 8}, // Truncation of the above.
-            TileMxNxK{1, 8, 8}, // Truncation of the above.
-        };
-      }
-
-      // Default.
-      return {
-          TileMxNxK{8, 8, 1}, // Aim to use SMLAL.
-          TileMxNxK{4, 8, 1}, // Truncation of the above.
-          TileMxNxK{2, 8, 1}, // Truncation of the above.
-          TileMxNxK{1, 8, 1}, // Truncation of the above.
-      };
-    }
-    if (lhs.isSignlessInteger(8) && rhs.isSignlessInteger(4) &&
-        (out.isSignlessInteger(32) || out.isF32())) {
-      if (out.isSignlessInteger(32) && hasFeature(target, "+i8mm")) {
-        return {
-            TileMxNxK{4, 8, 32},
-            TileMxNxK{2, 8, 32},
-            TileMxNxK{1, 8, 32},
-        };
-      }
-
-      // Default.
-      return {
-          TileMxNxK{4, 16, 1}, // Aim to use SMLAL.
-          TileMxNxK{2, 32, 1}, // Truncation of the above.
-          TileMxNxK{1, 64, 1}, // Truncation of the above.
-      };
-    }
-  }
-
   // Fallback - no architecture-optimized tile size for this case.
   return {};
 }
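
Net effect of this hunk: tile enumeration for these integer matmuls no longer depends on hasUkernel(target), and a target lacking the relevant CPU features (such as +i8mm) now falls through to the empty fallback, leaving its matmuls untiled, as the updated tests below show.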
@@ -1250,31 +1250,28 @@ func.func @matmul_lowering_i8i8i32_aarch64() attributes {
       -> !flow.dispatch.tensor<readwrite:tensor<?x?xi32, #iree_encoding.encoding<role = RESULT, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2]>>>{%M, %N}
   return
 }
-// CHECK-DAG: #[[$MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
 // CHECK-LABEL: func @matmul_lowering_i8i8i32_aarch64()
 // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
 // CHECK-DAG: %[[M:.+]] = hal.interface.constant.load[0]
 // CHECK-DAG: %[[N:.+]] = hal.interface.constant.load[1]
 // CHECK-DAG: %[[K:.+]] = hal.interface.constant.load[2]
-// CHECK-DAG: %[[TILED_M:.+]] = affine.apply #[[$MAP0]]()[%[[M]]]
 // CHECK: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(0)
-// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?x8x1xi8>>{%[[TILED_M]], %[[K]]}
-// CHECK: %[[TILED_N:.+]] = affine.apply #[[$MAP0]]()[%[[N]]]
+// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?xi8>>{%[[M]], %[[K]]}
 // CHECK: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(1)
-// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?x8x1xi8>>{%[[TILED_N]], %[[K]]}
+// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?xi8>>{%[[K]], %[[N]]}
 // CHECK: %[[OUTS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(2)
-// CHECK-SAME: !flow.dispatch.tensor<readwrite:tensor<?x?x8x8xi32>>{%[[TILED_M]], %[[TILED_N]]}
+// CHECK-SAME: !flow.dispatch.tensor<readwrite:tensor<?x?xi32>>{%[[M]], %[[N]]}
 // CHECK: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]
-// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[K]], 8, 1], strides = [1, 1, 1, 1]
+// CHECK-SAME: offsets = [0, 0], sizes = [%[[M]], %[[K]]], strides = [1, 1]
 // CHECK: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]
-// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_N]], %[[K]], 8, 1], strides = [1, 1, 1, 1]
+// CHECK-SAME: offsets = [0, 0], sizes = [%[[K]], %[[N]]], strides = [1, 1]
 // CHECK: %[[OUTS:.+]] = flow.dispatch.tensor.load %[[OUTS_BINDING]]
-// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[TILED_N]], 8, 8], strides = [1, 1, 1, 1]
-// CHECK: %[[MMT4D:.+]] = linalg.mmt4d
+// CHECK-SAME: offsets = [0, 0], sizes = [%[[M]], %[[N]]], strides = [1, 1]
+// CHECK: %[[MMT4D:.+]] = linalg.matmul
 // CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
 // CHECK-SAME: outs(%[[OUTS]] :
 // CHECK: flow.dispatch.tensor.store %[[MMT4D]], %[[OUTS_BINDING]]
-// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[TILED_N]], 8, 8], strides = [1, 1, 1, 1]
+// CHECK-SAME: offsets = [0, 0], sizes = [%[[M]], %[[N]]], strides = [1, 1]
 
 // -----
 
@@ -1443,24 +1440,26 @@ func.func @matmul_lowering_i8i4i32_aarch64() attributes {
   return
 }
 // CHECK-DAG: #[[$MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
-// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
 // CHECK-LABEL: func @matmul_lowering_i8i4i32_aarch64()
 // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
 // CHECK-DAG: %[[M:.+]] = hal.interface.constant.load[0]
 // CHECK-DAG: %[[N:.+]] = hal.interface.constant.load[1]
 // CHECK-DAG: %[[K:.+]] = hal.interface.constant.load[2]
 // CHECK-DAG: %[[TILED_M:.+]] = affine.apply #[[$MAP0]]()[%[[M]]]
+// CHECK: %[[TILED_K:.+]] = affine.apply #[[$MAP1]]()[%[[K]]]
 // CHECK: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(0)
-// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?x4x1xi8>>{%[[TILED_M]], %[[K]]}
-// CHECK: %[[TILED_N:.+]] = affine.apply #[[$MAP1]]()[%[[N]]]
+// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?x4x2xi8>>{%[[TILED_M]], %[[TILED_K]]}
+// CHECK: %[[TILED_N:.+]] = affine.apply #[[$MAP2]]()[%[[N]]]
 // CHECK: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(1)
-// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?x16x1xi4>>{%[[TILED_N]], %[[K]]}
+// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?x16x2xi4>>{%[[TILED_N]], %[[TILED_K]]}
 // CHECK: %[[OUTS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(2)
 // CHECK-SAME: !flow.dispatch.tensor<readwrite:tensor<?x?x4x16xi32>>{%[[TILED_M]], %[[TILED_N]]}
 // CHECK: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]
-// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[K]], 4, 1], strides = [1, 1, 1, 1]
+// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[TILED_K]], 4, 2], strides = [1, 1, 1, 1]
 // CHECK: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]
-// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_N]], %[[K]], 16, 1], strides = [1, 1, 1, 1]
+// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_N]], %[[TILED_K]], 16, 2], strides = [1, 1, 1, 1]
 // CHECK: %[[OUTS:.+]] = flow.dispatch.tensor.load %[[OUTS_BINDING]]
 // CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[TILED_N]], 4, 16], strides = [1, 1, 1, 1]
 // CHECK: %[[MMT4D:.+]] = linalg.mmt4d
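
Both test updates follow directly from the dropped branch: @matmul_lowering_i8i8i32_aarch64, whose target evidently has neither +i8mm nor another matching feature (its old checks used the removed 8x8x1 SMLAL default), now hits the empty fallback and expects a plain linalg.matmul on unpacked ?x? tensors instead of linalg.mmt4d on packed tiles, while @matmul_lowering_i8i4i32_aarch64 moves from the removed TileMxNxK{4, 16, 1} default to 4x16x2 tiles, so K is now tiled by 2 (#[[$MAP1]]) and the inner tiles become 4x2 and 16x2.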
