Make tile selection same as ukernel configurations
Signed-off-by: Alan Li <me@alanli.org>
lialan committed Jun 11, 2024
1 parent cda3ccb commit 0439758
Showing 2 changed files with 20 additions and 60 deletions.
First changed file (C++, Arm64 matmul tile enumeration):
@@ -114,8 +114,8 @@ enumerateMatmulTileArm64(TypeRange elementTypes, ExecutableTargetAttr target) {
    }
  }

-  if (hasUkernel(target) && lhs.isSignlessInteger(8) &&
-      rhs.isSignlessInteger(8) && out.isSignlessInteger(32)) {
+  if (lhs.isSignlessInteger(8) && rhs.isSignlessInteger(8) &&
+      out.isSignlessInteger(32)) {
    if (hasFeature(target, "+i8mm")) {
      return {
          TileMxNxK{8, 8, 8}, // Aim to use SMMLA.
@@ -134,8 +134,8 @@ enumerateMatmulTileArm64(TypeRange elementTypes, ExecutableTargetAttr target) {
    }
  }

-  if (hasUkernel(target) && lhs.isSignlessInteger(8) &&
-      rhs.isSignlessInteger(4) && out.isSignlessInteger(32)) {
+  if (lhs.isSignlessInteger(8) && rhs.isSignlessInteger(4) &&
+      out.isSignlessInteger(32)) {
    if (hasFeature(target, "+i8mm")) {
      return {
          TileMxNxK{4, 8, 16},
@@ -158,45 +158,6 @@ enumerateMatmulTileArm64(TypeRange elementTypes, ExecutableTargetAttr target) {
    };
  }

-  if (!hasUkernel(target)) {
-    if (lhs.isSignlessInteger(8) && rhs.isSignlessInteger(8) &&
-        (out.isSignlessInteger(32) || out.isF32())) {
-      if (out.isSignlessInteger(32) && hasFeature(target, "+i8mm")) {
-        return {
-            TileMxNxK{8, 8, 8}, // Aim to use SMMLA.
-            TileMxNxK{4, 8, 8}, // Truncation of the above.
-            TileMxNxK{2, 8, 8}, // Truncation of the above.
-            TileMxNxK{1, 8, 8}, // Truncation of the above.
-        };
-      }
-
-      // Default.
-      return {
-          TileMxNxK{8, 8, 1}, // Aim to use SMLAL.
-          TileMxNxK{4, 8, 1}, // Truncation of the above.
-          TileMxNxK{2, 8, 1}, // Truncation of the above.
-          TileMxNxK{1, 8, 1}, // Truncation of the above.
-      };
-    }
-    if (lhs.isSignlessInteger(8) && rhs.isSignlessInteger(4) &&
-        (out.isSignlessInteger(32) || out.isF32())) {
-      if (out.isSignlessInteger(32) && hasFeature(target, "+i8mm")) {
-        return {
-            TileMxNxK{4, 8, 32},
-            TileMxNxK{2, 8, 32},
-            TileMxNxK{1, 8, 32},
-        };
-      }
-
-      // Default.
-      return {
-          TileMxNxK{4, 16, 1}, // Aim to use SMLAL.
-          TileMxNxK{2, 32, 1}, // Truncation of the above.
-          TileMxNxK{1, 64, 1}, // Truncation of the above.
-      };
-    }
-  }
-
  // Fallback - no architecture-optimized tile size for this case.
  return {};
}
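For context: each TileMxNxK{M0, N0, K0} above is the inner tile of the data-tiled (mmt4d) layout, and each candidate list is ordered from the widest M0 down to M0 = 1 so that narrow matmuls can use a truncated tile instead of padding. A minimal sketch of that selection idea, assuming a hypothetical pickBestTile helper rather than the pass's actual selection code:

#include <cstdint>
#include <vector>

// Inner tile sizes for the data-tiled matmul layout (mirrors the diff above).
struct TileMxNxK {
  int64_t M = 1, N = 1, K = 1;
};

// Hypothetical helper (not the upstream API): candidates are ordered from
// largest M0 to smallest ("Truncation of the above"), so the first tile
// whose M0 fits the static M dimension is the widest usable one.
static TileMxNxK pickBestTile(const std::vector<TileMxNxK> &candidates,
                              int64_t staticM) {
  for (const TileMxNxK &tile : candidates) {
    if (tile.M <= staticM)
      return tile;
  }
  return candidates.back(); // The M0 == 1 entry always fits.
}

With the +i8mm list above, for example, a matmul whose static M is 2 would get TileMxNxK{2, 8, 8} rather than padding M up to 8.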
Second changed file (MLIR FileCheck tests):
@@ -1250,31 +1250,28 @@ func.func @matmul_lowering_i8i8i32_aarch64() attributes {
      -> !flow.dispatch.tensor<readwrite:tensor<?x?xi32, #iree_encoding.encoding<role = RESULT, element_types = [i8, i8, i32], user_indexing_maps = [#map, #map1, #map2]>>>{%M, %N}
  return
}
-// CHECK-DAG: #[[$MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
// CHECK-LABEL: func @matmul_lowering_i8i8i32_aarch64()
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[M:.+]] = hal.interface.constant.load[0]
// CHECK-DAG: %[[N:.+]] = hal.interface.constant.load[1]
// CHECK-DAG: %[[K:.+]] = hal.interface.constant.load[2]
-// CHECK-DAG: %[[TILED_M:.+]] = affine.apply #[[$MAP0]]()[%[[M]]]
// CHECK: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(0)
-// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?x8x1xi8>>{%[[TILED_M]], %[[K]]}
-// CHECK: %[[TILED_N:.+]] = affine.apply #[[$MAP0]]()[%[[N]]]
+// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?xi8>>{%[[M]], %[[K]]}
// CHECK: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(1)
-// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?x8x1xi8>>{%[[TILED_N]], %[[K]]}
+// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?xi8>>{%[[K]], %[[N]]}
// CHECK: %[[OUTS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(2)
-// CHECK-SAME: !flow.dispatch.tensor<readwrite:tensor<?x?x8x8xi32>>{%[[TILED_M]], %[[TILED_N]]}
+// CHECK-SAME: !flow.dispatch.tensor<readwrite:tensor<?x?xi32>>{%[[M]], %[[N]]}
// CHECK: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]
-// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[K]], 8, 1], strides = [1, 1, 1, 1]
+// CHECK-SAME: offsets = [0, 0], sizes = [%[[M]], %[[K]]], strides = [1, 1]
// CHECK: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]
-// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_N]], %[[K]], 8, 1], strides = [1, 1, 1, 1]
+// CHECK-SAME: offsets = [0, 0], sizes = [%[[K]], %[[N]]], strides = [1, 1]
// CHECK: %[[OUTS:.+]] = flow.dispatch.tensor.load %[[OUTS_BINDING]]
-// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[TILED_N]], 8, 8], strides = [1, 1, 1, 1]
-// CHECK: %[[MMT4D:.+]] = linalg.mmt4d
+// CHECK-SAME: offsets = [0, 0], sizes = [%[[M]], %[[N]]], strides = [1, 1]
+// CHECK: %[[MMT4D:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
// CHECK-SAME: outs(%[[OUTS]] :
// CHECK: flow.dispatch.tensor.store %[[MMT4D]], %[[OUTS_BINDING]]
-// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[TILED_N]], 8, 8], strides = [1, 1, 1, 1]
+// CHECK-SAME: offsets = [0, 0], sizes = [%[[M]], %[[N]]], strides = [1, 1]

// -----
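The CHECK updates above follow from the C++ change: this test's target enables neither +i8mm nor +dotprod, and the deleted !hasUkernel(target) block was the only source of default i8 x i8 -> i32 tiles, so tile enumeration now reaches the empty fallback and nothing is data-tiled. A sketch of the assumed caller-side behavior (the container type and surrounding logic are assumptions; only enumerateMatmulTileArm64 and TileMxNxK appear in the diff):

// Assumed control flow, consistent with the updated CHECK lines above.
auto tiles = enumerateMatmulTileArm64(elementTypes, target);
if (tiles.empty()) {
  // Fallback case: no architecture-optimized tile sizes, so the encoding is
  // materialized as identity. Operands keep their tensor<?x?xi8> shapes and
  // the op stays a plain linalg.matmul instead of becoming linalg.mmt4d.
}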

@@ -1443,24 +1440,26 @@ func.func @matmul_lowering_i8i4i32_aarch64() attributes {
  return
}
// CHECK-DAG: #[[$MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
-// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
// CHECK-LABEL: func @matmul_lowering_i8i4i32_aarch64()
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[M:.+]] = hal.interface.constant.load[0]
// CHECK-DAG: %[[N:.+]] = hal.interface.constant.load[1]
// CHECK-DAG: %[[K:.+]] = hal.interface.constant.load[2]
// CHECK-DAG: %[[TILED_M:.+]] = affine.apply #[[$MAP0]]()[%[[M]]]
+// CHECK: %[[TILED_K:.+]] = affine.apply #[[$MAP1]]()[%[[K]]]
// CHECK: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(0)
-// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?x4x1xi8>>{%[[TILED_M]], %[[K]]}
-// CHECK: %[[TILED_N:.+]] = affine.apply #[[$MAP1]]()[%[[N]]]
+// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?x4x2xi8>>{%[[TILED_M]], %[[TILED_K]]}
+// CHECK: %[[TILED_N:.+]] = affine.apply #[[$MAP2]]()[%[[N]]]
// CHECK: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(1)
-// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?x16x1xi4>>{%[[TILED_N]], %[[K]]}
+// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?x16x2xi4>>{%[[TILED_N]], %[[TILED_K]]}
// CHECK: %[[OUTS_BINDING:.+]] = hal.interface.binding.subspan set(0) binding(2)
// CHECK-SAME: !flow.dispatch.tensor<readwrite:tensor<?x?x4x16xi32>>{%[[TILED_M]], %[[TILED_N]]}
// CHECK: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]
-// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[K]], 4, 1], strides = [1, 1, 1, 1]
+// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[TILED_K]], 4, 2], strides = [1, 1, 1, 1]
// CHECK: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]
-// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_N]], %[[K]], 16, 1], strides = [1, 1, 1, 1]
+// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_N]], %[[TILED_K]], 16, 2], strides = [1, 1, 1, 1]
// CHECK: %[[OUTS:.+]] = flow.dispatch.tensor.load %[[OUTS_BINDING]]
// CHECK-SAME: offsets = [0, 0, 0, 0], sizes = [%[[TILED_M]], %[[TILED_N]], 4, 16], strides = [1, 1, 1, 1]
// CHECK: %[[MMT4D:.+]] = linalg.mmt4d
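Read off the new CHECK lines, the i8 x i4 tile here is TileMxNxK{4, 16, 2}: M0 = 4 (the ceildiv 4 map), N0 = 16 (the ceildiv 16 map), and K0 = 2 (the new ceildiv 2 map), which produces the 4x2 LHS, 16x2 RHS, and 4x16 accumulator inner tiles. The packed shapes follow the usual mmt4d convention; a small sketch of that arithmetic, with hypothetical helper names:

#include <cstdint>

struct TileMxNxK {
  int64_t M = 1, N = 1, K = 1;
};

static int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Hypothetical helpers (not the upstream API): outer dims are ceildiv'd by
// the tile, inner dims are the tile itself, matching the CHECK lines above.
struct PackedShape { int64_t dims[4]; };

static PackedShape packLhs(int64_t M, int64_t K, TileMxNxK t) {
  return {{ceilDiv(M, t.M), ceilDiv(K, t.K), t.M, t.K}}; // ?x?x4x2xi8
}
static PackedShape packRhs(int64_t N, int64_t K, TileMxNxK t) {
  return {{ceilDiv(N, t.N), ceilDiv(K, t.K), t.N, t.K}}; // ?x?x16x2xi4
}
static PackedShape packOut(int64_t M, int64_t N, TileMxNxK t) {
  return {{ceilDiv(M, t.M), ceilDiv(N, t.N), t.M, t.N}}; // ?x?x4x16xi32
}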
