Packing the result of broadcast introduces huge memory allocation #17377

Open
Tracked by #17608
hanhanW opened this issue May 13, 2024 · 0 comments
Labels
codegen Shared code generation infrastructure and dialects

hanhanW (Contributor) commented May 13, 2024

#map = affine_map<(d0, d1, d2) -> (d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module {
  util.func public @turbine_llm_mmtfp_3d_8640_3200_f32f16(%arg0: tensor<?x?x3200xf32>, %arg1: tensor<8640x3200xf16>) -> tensor<?x?x8640xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %dim = tensor.dim %arg0, %c0 : tensor<?x?x3200xf32>
    %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?x3200xf32>
    %0 = tensor.empty(%dim) : tensor<?x8640x3200xf16>
    %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1 : tensor<8640x3200xf16>) outs(%0 : tensor<?x8640x3200xf16>) {
    ^bb0(%in: f16, %out: f16):
      linalg.yield %in : f16
    } -> tensor<?x8640x3200xf16>
    %2 = tensor.empty(%dim, %dim_0) : tensor<?x?x8640xf32>
    %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<?x?x8640xf32>) -> tensor<?x?x8640xf32>
    %4 = linalg.batch_matmul_transpose_b ins(%arg0, %1 : tensor<?x?x3200xf32>, tensor<?x8640x3200xf16>) outs(%3 : tensor<?x?x8640xf32>) -> tensor<?x?x8640xf32>
    util.return %4 : tensor<?x?x8640xf32>
  }
}

Coming from #17022 and a Discord discussion, we are seeing a pack(broadcast) -> mmt4d pattern. This is bad because we allocate a big buffer for the broadcast -> pack dispatch and pass the result to the mmt4d kernel. What happens today is:

Set encodings on matmul operands:

%bcast = linalg.generic ins(%src) ... // broadcast for batch dimension
%lhs = set_encoding(%original_lhs)
%rhs = set_encoding(%bcast)
%gemm = linalg.batch_matmul ins(%lhs, %rhs) ...

If we write it in a materialized form, it is:

%bcast = linalg.generic ins(%src) ... // broadcast for batch dimension
%lhs = tensor.pack %original_lhs
%rhs = tensor.pack %bcast
%gemm = linalg.batch_mmt4d ins(%lhs, %rhs) ...

The dispatch formation results in

dispatch {
  %bcast ...
  %rhs = tensor.pack %bcast
  return %rhs
}
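
Rough numbers from the IR at the top (a back-of-the-envelope estimate, modulo the pack layout and padding): the packed broadcast result that crosses the dispatch boundary has shape ?x8640x3200 in f16, i.e. about %dim x 8640 x 3200 x 2 bytes ≈ %dim x 52.7 MiB, where %dim is the dynamic batch size. Packing only the 8640x3200 source would take about 52.7 MiB regardless of the batch size.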

This is why we have the big memory allocation. However, it is not a hard limit for the data-tiling path. What we can do here is set encodings on the source of the broadcast. This allows us to swap the broadcast and the set_encoding/tensor.pack op, which results in

%packed_src = tensor.pack %src
%rhs = linalg.generic ins(%packed_src) ... // broadcast for batch dimension
%lhs = tensor.pack %original_lhs

%gemm = linalg.batch_mmt4d ins(%lhs, %rhs) ...
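
For completeness, a rough sketch of the same idea in encoding form (before materialization), reusing the pseudo-notation from above; the names here are illustrative:

%encoded_src = set_encoding(%src)
%rhs = linalg.generic ins(%encoded_src) ... // broadcast for batch dimension
%lhs = set_encoding(%original_lhs)
%gemm = linalg.batch_matmul ins(%lhs, %rhs) ...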

We should be able to make dispatch formation result in

dispatch {
  tensor.pack %src
}

dispatch {
  %rhs = linalg.generic ...
  %gemm = linalg.batch_mmt4d ins(%lhs, %rhs) ...
}

In this context, the memory allocation is much smaller because we don't allocate the buffer with the batch dimension. The remaining action item is how we codegen the broadcast + batch_mmt4d dispatch. It can be done the same way as the existing batch_mmt4d codegen: we tile the batch dimension with size=1 and then lower to mmt4d codegen/ukernels.

After TileAndFuse with batch_size=1:

for (int i = 0; i < batch_size; i += 1) {
  %lhs_slice = tensor.extract_slice %lhs …
  %rhs_slice = linalg.generic(%rhs_wo_broadcast) … -> tensor<1xN0xK0xN1xK1xf16>
  %res = batch_mmt4d(%lhs_slice, %rhs_slice)
} 

After batch_mmt4d -> mmt4d decomposition:

for (int i = 0; i < batch_size; i += 1) {
  %lhs_slice = tensor.extract_slice %lhs … -> tensor<1xM0xK0xM1xK1xf16>
  %rhs_slice = linalg.generic(%rhs_wo_broadcast) … -> tensor<1xN0xK0xN1xK1xf16>
  %lhs_wo_batch = tensor.extract_slice %lhs_slice … -> tensor<M0xK0xM1xK1xf16>
  %rhs_wo_batch = tensor.extract_slice %rhs_slice … -> tensor<N0xK0xN1xK1xf16>
  %res = mmt4d(%lhs_wo_batch, %rhs_wo_batch)
}

With this flow, we should be able to get rid of the huge memory allocation.

hanhanW added the codegen (Shared code generation infrastructure and dialects) label on May 13, 2024