[DT][CPU] Some casting ops are not fused with mmt4d kernels #15826

Closed
hanhanW opened this issue Dec 7, 2023 · 3 comments · Fixed by #15663
hanhanW commented Dec 7, 2023

We set encodings on CastOpInterface ops so that cast ops can be fused into the mmt4d dispatch. However, we only fuse cast ops when they are not "group dequantization" ops. The fusion does not happen for a case like arith.truncf %lhs : f32 to f16 feeding an f16.f16.f16 matmul, which results in additional dispatches: we end up with a pack dispatch, an arith.truncf dispatch, and an mmt4d dispatch. We should fix this either in set_encoding or in dispatch formation. With the former, we would have an arith.truncf + pack dispatch and an mmt4d dispatch. With the latter, we would have an [optional consumers] + pack dispatch and an arith.truncf + mmt4d dispatch.

To repro:

Run iree-compile --output-format=vm-bytecode --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu=cascadelake --iree-llvmcpu-target-triple=x86_64-unknown-linux-gnu ~/repro.mlir -o /tmp/a.vmfb

#map = affine_map<(d0, d1) -> (d0, d1)>
module {
  func.func @main(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf16>) -> tensor<?x?xf16> {
    %cst = arith.constant 0.000000e+00 : f16
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
    %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
    %0 = tensor.empty(%dim, %dim_0) : tensor<?x?xf16>
    %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf16>) {
    ^bb0(%in: f32, %out: f16):
      %5 = arith.truncf %in : f32 to f16
      linalg.yield %5 : f16
    } -> tensor<?x?xf16>
    %dim_1 = tensor.dim %arg1, %c1 : tensor<?x?xf16>
    %2 = tensor.empty(%dim, %dim_1) : tensor<?x?xf16>
    %3 = linalg.fill ins(%cst : f16) outs(%2 : tensor<?x?xf16>) -> tensor<?x?xf16>
    %4 = linalg.matmul ins(%1, %arg1 : tensor<?x?xf16>, tensor<?x?xf16>) outs(%3 : tensor<?x?xf16>) -> tensor<?x?xf16>
    return %4 : tensor<?x?xf16>
  }
}
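
Compiling this repro currently produces three dispatches. Roughly, the current grouping and the two candidate fixes described above look like this (schematic comments only, not actual compiler output; dispatch numbering is illustrative):

// Current: three dispatches.
//   dispatch 0: linalg.generic (arith.truncf f32 -> f16)
//   dispatch 1: tensor.pack
//   dispatch 2: linalg.mmt4d
//
// Fix in set_encoding: fold the truncf into the pack dispatch.
//   dispatch 0: linalg.generic (arith.truncf) + tensor.pack
//   dispatch 1: linalg.mmt4d
//
// Fix in dispatch formation: pull the truncf into the mmt4d dispatch.
//   dispatch 0: [optional consumers] + tensor.pack
//   dispatch 1: linalg.generic (arith.truncf) + linalg.mmt4d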
@hanhanW hanhanW added the codegen Shared code generation infrastructure and dialects label Dec 7, 2023
@hanhanW hanhanW self-assigned this Dec 7, 2023
hanhanW commented Dec 7, 2023

@Max191 @MaheshRavishankar and I had an offline discussion about the issue. If we fuse the truncf with the mmt4d op and the ukernel does not support the type, we pay load/store costs before calling the ukernels.

We should not set encodings on all casting ops; we should only do it for casting ops that widen the bitwidth. In this context, the truncf op will be fused with its producers (i.e., generic ops), so we no longer hit the issue of setting encodings on cast ops without fusing them into mmt4d dispatches.
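
A minimal sketch of the two cases (the encoding attributes are elided; the op names and types are only meant to illustrate which casts should get an encoding):

#map = affine_map<(d0, d1) -> (d0, d1)>

// Narrowing cast (f32 -> f16): do not set an encoding; this generic should
// fuse with its producer and the pack op instead of the mmt4d dispatch.
func.func @narrowing(%src: tensor<?x?xf32>, %init: tensor<?x?xf16>) -> tensor<?x?xf16> {
  %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
      ins(%src : tensor<?x?xf32>) outs(%init : tensor<?x?xf16>) {
  ^bb0(%in: f32, %out: f16):
    %t = arith.truncf %in : f32 to f16
    linalg.yield %t : f16
  } -> tensor<?x?xf16>
  return %0 : tensor<?x?xf16>
}

// Widening cast (f16 -> f32): setting an encoding here is worthwhile, so the
// cast can be fused into the mmt4d dispatch.
func.func @widening(%src: tensor<?x?xf16>, %init: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
      ins(%src : tensor<?x?xf16>) outs(%init : tensor<?x?xf32>) {
  ^bb0(%in: f16, %out: f32):
    %e = arith.extf %in : f16 to f32
    linalg.yield %e : f32
  } -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}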

There are two sub-tasks to address the issue:

  1. Do not set encoding on casting ops that reduce bitwidth.
  2. Fuse those cast ops with mmt4d kernels.

@Max191 given that you've touched these files recently, can you help fix the issue?

(FYI @bjacob )


pzread commented Dec 7, 2023

I think #15760 is also related to this, so I have a question:

We want to extend i8 to f32 on the LHS/RHS in some cases, so even if we don't have the corresponding ukernels, we can still use f32 ukernels with the LHS/RHS extended.

My current idea is to fuse cast ops with the pack op (and potentially its producer, so generic + truncf + pack) so we can hide the overhead. IIUC, in this issue we are actually trying to fuse cast ops with the mmt4d ukernel? How would that work, since from my understanding we can't actually fuse truncf with the ukernel call in the same dispatch?


hanhanW commented Dec 7, 2023

> I think #15760 is also related to this, so I have a question:
>
> We want to extend i8 to f32 on the LHS/RHS in some cases, so even if we don't have the corresponding ukernels, we can still use f32 ukernels with the LHS/RHS extended.
>
> My current idea is to fuse cast ops with the pack op (and potentially its producer, so generic + truncf + pack) so we can hide the overhead. IIUC, in this issue we are actually trying to fuse cast ops with the mmt4d ukernel? How would that work, since from my understanding we can't actually fuse truncf with the ukernel call in the same dispatch?

Hey, sorry, I wrote the tasks the opposite way; I've updated the comment. We don't want to set encodings on truncf, so the op order is truncf + pack. In this context, we will create a generic + pack dispatch.
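
With that ordering, the expected grouping would look roughly like this (a schematic sketch only; the dispatch boundaries are shown as comments rather than actual dispatch region ops):

// No encoding on the truncf, so it stays with its producer and the pack:
//   dispatch 0: linalg.generic (producer) + linalg.generic (arith.truncf) + tensor.pack
//   dispatch 1: linalg.mmt4d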
