[SPIR-V] Avoid shifting by non-splat amounts #16913

Open
kuhar opened this issue Mar 27, 2024 · 4 comments
Labels
codegen/spirv (SPIR-V code generation compiler backend), good first issue 🌱 (Good for newcomers), help wanted (Extra attention is needed), performance ⚡ (Performance/optimization related work across the compiler and runtime)

Comments

kuhar (Member) commented Mar 27, 2024

On RDNA GPUs, vector shift instructions are much faster when the shift amount is a splat constant. However, we seem to emit non-splat shifts for the int4 matvec from Llama 2:

Input

func.func @main() {
  %c64 = arith.constant 64 : index
  %c8192 = arith.constant 8192 : index
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f16
  %2 = util.unfoldable_constant dense<1> : tensor<4096x32x128xi4>
  %3 = util.unfoldable_constant dense<1.0> : tensor<4096x32xf16>
  %4 = util.unfoldable_constant dense<1.0> : tensor<4096x32xf16>
  %5 = util.unfoldable_constant dense<1.0> : tensor<32x128xf16>

  %9 = tensor.empty() : tensor<4096xf16>
  %10 = tensor.empty() : tensor<4096x32x128xf16>
  %11 = linalg.fill ins(%cst : f16) outs(%9 : tensor<4096xf16>) -> tensor<4096xf16>
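  // Dequantize: extend each i4 to i32, convert to f16, subtract %4, and multiply by %3.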
  %12 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                     affine_map<(d0, d1, d2) -> (d0, d1)>,
                     affine_map<(d0, d1, d2) -> (d0, d1)>,
                     affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
    iterator_types = ["parallel", "parallel", "parallel"]
  } ins(%2, %3, %4 : tensor<4096x32x128xi4>, tensor<4096x32xf16>, tensor<4096x32xf16>)
    outs(%10 : tensor<4096x32x128xf16>) {
  ^bb0(%in: i4, %in_0: f16, %in_1: f16, %out: f16):
    %14 = arith.extui %in : i4 to i32
    %15 = arith.uitofp %14 : i32 to f16
    %16 = arith.subf %15, %in_1 : f16
    %17 = arith.mulf %16, %in_0 : f16
    linalg.yield %17 : f16
  } -> tensor<4096x32x128xf16>
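  // Multiply by %5 and accumulate over the two reduction dimensions into the 4096-element result (the matvec).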
  %13 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>,
                     affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                     affine_map<(d0, d1, d2) -> (d0)>],
    iterator_types = ["parallel", "reduction", "reduction"]
  } ins(%5, %12 : tensor<32x128xf16>, tensor<4096x32x128xf16>)
    outs(%11 : tensor<4096xf16>) {
  ^bb0(%in: f16, %in_0: f16, %out: f16):
    %14 = arith.mulf %in, %in_0 : f16
    %15 = arith.addf %14, %out : f16
    linalg.yield %15 : f16
  } -> tensor<4096xf16>

  check.expect_eq_const(%13, dense<4096.0> : tensor<4096xf16>) : tensor<4096xf16>
  return
}

Compile command:

tools/iree-compile vmt_int4.mlir \
  --iree-hal-target-backends=vulkan-spirv \
  --iree-vulkan-target-triple=rdna3-7900-linux \
  --iree-hal-dump-executable-files-to=dumps \
  -o vmt_int4.vmfb

The dumps directory contains the IR at the SPIR-V dialect level (dumps/module__main_dispatch_0__main_dispatch_0_generic_4096x32x128_f16.spirv.mlir), which includes the following shift ops:

      %cst_vec_4xi8 = spirv.Constant dense<[0, 4, 0, 4]> : vector<4xi8>
      // ...
      %44 = spirv.VectorShuffle [0 : i32, 1 : i32] %27, %27 : vector<4xi8>, vector<4xi8> -> vector<2xi8>
      %45 = spirv.VectorShuffle [0 : i32, 0 : i32, 1 : i32, 1 : i32] %44, %44 : vector<2xi8>, vector<2xi8> -> vector<4xi8>
      %46 = spirv.BitwiseAnd %45, %cst_vec_4xi8_0 : vector<4xi8>
      %47 = spirv.ShiftRightLogical %46, %cst_vec_4xi8 : vector<4xi8>, vector<4xi8>

We should figure out which pass produces this instruction sequence (e.g., by running with --mlir-print-ir-after-all) and change it to avoid shifting by non-splat amounts.
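For example, the per-pass IR can be captured by adding that flag to the compile command above (the output is verbose, so redirecting stderr to a file, named ir-after-all.log here only for illustration, helps):

tools/iree-compile vmt_int4.mlir \
  --iree-hal-target-backends=vulkan-spirv \
  --iree-vulkan-target-triple=rdna3-7900-linux \
  --iree-hal-dump-executable-files-to=dumps \
  --mlir-print-ir-after-all \
  -o vmt_int4.vmfb 2> ir-after-all.log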

kuhar added the help wanted, good first issue 🌱, performance ⚡, and codegen/spirv labels on Mar 27, 2024
kuhar (Member, Author) commented Mar 27, 2024

cc: @inbelic

benvanik (Collaborator) commented
Dynamic shift amounts can be slow on many ISAs, as they are often implemented as microcoded loops. Since this comes from higher up in the stack, I would be interested to know whether we hit this on other targets too! /cc @hanhanW @antiagainst

inbelic commented Mar 27, 2024

From initial inspection, these operations are a result of the rewrite pattern BitCastRewriter::genericRewrite, implemented here.

In this case we are transforming a vector.bitcast into a sequence of vector.shuffle, arith.andi, and arith.shrui operations (see the sketch below). So we could potentially rewrite this pattern to use splat vectors for the shifts, as mentioned. Or would it also be possible to prevent this pass from expanding the vector.bitcast, to allow a direct lowering to spirv.Bitcast? WDYT?
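For concreteness, a rough vector-level sketch of that expansion, assuming the packed i4 data arrives as a vector<2xi8> value called %bytes here; the exact shapes and constants produced by the pass will differ:

      // Each byte holds two i4 values: duplicate the bytes, mask out the low/high
      // nibble, then shift right by a per-lane amount; the [0, 4, 0, 4] amounts are
      // the non-splat shift seen in the SPIR-V dump above.
      %mask = arith.constant dense<[15, -16, 15, -16]> : vector<4xi8>
      %amounts = arith.constant dense<[0, 4, 0, 4]> : vector<4xi8>
      %dup = vector.shuffle %bytes, %bytes [0, 0, 1, 1] : vector<2xi8>, vector<2xi8>
      %masked = arith.andi %dup, %mask : vector<4xi8>
      %nibbles = arith.shrui %masked, %amounts : vector<4xi8>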

kuhar (Member, Author) commented Mar 27, 2024

> From initial inspection, these operations are a result of the rewrite pattern BitCastRewriter::genericRewrite, implemented here.

IIRC @nicolasvasilache wrote this based on the CPU performance work at the time. This lowers the bitcasts in a principled way, especially for more complicated cases like i5, but the expansion is suboptimal for GPUs. I don't think we can rely on spirv.Bitcast because, IIUC, the vector element type is i4, which needs to be expanded to an integer type supported by the target.

I think the solution would be to change the order of operations to avoid non-splat shifts, i.e., mask first, shift by a splat amount, and recombine the lanes with a shuffle, so that this turns into something like:

      %cst_vec_2xi8 = spirv.Constant dense<[4, 4]> : vector<2xi8>
      %46 = spirv.BitwiseAnd %27, %cst_vec_4xi8_0 : vector<4xi8>
      %44 = spirv.VectorShuffle [0 : i32, 1 : i32] %46, %46 : vector<4xi8>, vector<4xi8> -> vector<2xi8>
      %44_s = spirv.ShiftRightLogical %44, %cst_vec_2xi8 : vector<2xi8>, vector<2xi8>
      %45 = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %44, %44_s : vector<2xi8>, vector<2xi8> -> vector<4xi8>
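Here the only shift uses the splat amount dense<[4, 4]>, and the low/high nibbles are recombined by the final interleaving VectorShuffle. A candidate fix can be sanity-checked by recompiling with the command above and checking the same dump for splat shift amounts.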
