Coming from #17022 and the discord discussion, we are seeing a `pack(broadcast) -> mmt4d` pattern. This is bad because we allocate a big buffer for the `broadcast -> pack` dispatch and pass the result to the `mmt4d` kernel. What happens today is that we set encodings on the matmul operands; in the materialized form, the broadcast result gets packed at its full broadcasted size, and dispatch formation puts `broadcast -> pack` and the `mmt4d` kernel into separate dispatches.
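A rough sketch of that structure (pseudocode in the style of the snippet further below; the value names and shapes such as `BxN0xK0xN1xK1xf16` are illustrative, not actual IR):

```
// Dispatch 1: broadcast, then pack. The packed result carries the full batch
// dimension B, so this is the big intermediate allocation.
%rhs_bcast  = broadcast(%rhs) …                        // adds the batch dim
%rhs_packed = tensor.pack %rhs_bcast … -> tensor<BxN0xK0xN1xK1xf16>

// Dispatch 2: the packed, broadcasted operand feeds the mmt4d-style kernel.
%res = batch_mmt4d(%lhs_packed, %rhs_packed)
```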
This is why we have the big memory allocation. However, it is not a hard limit for the data-tiling path. What we can do here is set encodings on the source of the broadcast. This allows us to swap the `broadcast` and the `set_encoding`/`tensor.pack` ops, which lets dispatch formation produce a fused `broadcast + batch_mmt4d` dispatch instead.
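A sketch of the swapped form, with the same illustrative names (again pseudocode, not actual IR):

```
// Pack the un-broadcasted source; no batch dimension is materialized here.
%rhs_packed = tensor.pack %rhs_wo_broadcast … -> tensor<N0xK0xN1xK1xf16>

// The broadcast now lives in the same dispatch as the batch_mmt4d, so the
// broadcasted packed tensor never needs its own full-size buffer.
%rhs_bcast = broadcast(%rhs_packed) …                  // -> tensor<BxN0xK0xN1xK1xf16>
%res       = batch_mmt4d(%lhs_packed, %rhs_bcast)
```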
In this context, the memory allocation is much smaller because we don't allocate the packed tensor with the batch dimension. The remaining action item is how we codegen the `broadcast + batch_mmt4d` dispatch. It can be done like the existing batch_mmt4d codegen: we tile the batch dimension with size=1 and lower the body to the mmt4d codegen/ukernels.
After TileAndFuse with batch_size=1:
```
for (int i = 0; i < batch_size; i += 1) {
  // Take the i-th batch slice of the packed LHS.
  %lhs_slice = tensor.extract_slice %lhs …
  // Materialize only a single-batch slice of the broadcasted, packed RHS.
  %rhs_slice = linalg.generic(%rhs_wo_broadcast) … -> tensor<1xN0xK0xN1xK1xf16>
  %res = batch_mmt4d(%lhs_slice, %rhs_slice)
}
```
After batch_mmt4d -> mmt4d decomposition:
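A plausible sketch of the decomposed loop, continuing the same illustrative names (the unit batch dimension is folded away so the body maps onto the existing mmt4d codegen/ukernels):

```
for (int i = 0; i < batch_size; i += 1) {
  // Rank-reduced slices: the unit batch dimension is dropped.
  %lhs_slice = tensor.extract_slice %lhs …
  %rhs_slice = linalg.generic(%rhs_wo_broadcast) … -> tensor<N0xK0xN1xK1xf16>
  // Plain mmt4d, i.e. the existing mmt4d codegen/ukernel path.
  %res = mmt4d(%lhs_slice, %rhs_slice)
}
```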
With this flow, we should be able to get rid of the huge memory allocation.