[LinalgExt] Generalize attention tiling interface implementation #17408
Conversation
Some high-level questions:
- Why do we need indexing_maps for the attention op?
- Could we have more than one dimension for batch/m/n/k1/k2?
- It looks like tiling on reduction dims in getTiledImplementation() is hard. Do you plan to implement PartialReductionOpInterface for the attention op?
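For readers new to the design, here is a minimal sketch (the iteration-space ordering and the helper name are illustrative assumptions, not taken from the patch) of the affine maps such an attention op would carry in the single-dimension batch/m/k1/k2/n case:

```cpp
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Iteration space: (batch, m, k1, k2, n).
// Q : (batch, m, k1), K : (batch, k2, k1), V : (batch, k2, n),
// output : (batch, m, n). Each operand's map records which iteration
// dimensions it indexes, which is what lets tile offsets/sizes (and
// transposed layouts) be derived generically instead of being hardcoded.
static SmallVector<AffineMap> buildAttentionIndexingMaps(MLIRContext *ctx) {
  AffineExpr batch = getAffineDimExpr(0, ctx);
  AffineExpr m = getAffineDimExpr(1, ctx);
  AffineExpr k1 = getAffineDimExpr(2, ctx);
  AffineExpr k2 = getAffineDimExpr(3, ctx);
  AffineExpr n = getAffineDimExpr(4, ctx);
  auto map = [&](ArrayRef<AffineExpr> results) {
    return AffineMap::get(/*dimCount=*/5, /*symbolCount=*/0, results, ctx);
  };
  return {map({batch, m, k1}),  // Q
          map({batch, k2, k1}), // K
          map({batch, k2, n}),  // V
          map({batch, m, n})};  // output
}
```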
Force-pushed from adf5386 to 07da396.
Thanks for filling me in on the context; it looks more reasonable to me now, and the idea sounds okay. Is the PR ready for review? I'm seeing some code commented out in AttentionOp::verify.
Force-pushed from 98bc9c9 to db189bc.
Fixed. It should be ready to review now.
    outputOffsets.push_back(offsets[dim.getPosition()]);
    outputSizes.push_back(sizes[dim.getPosition()]);
  }
  return {outputOffsets, outputSizes, outputStrides};
nit: use std::make_tuple()
I'm not sure about this one. I prefer the current form, since the tuple type is already spelled out explicitly in the function signature, so there is no ambiguity. I can change it if there is a preference.
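For illustration, a minimal sketch of the two styles under discussion (types and names here are simplified stand-ins, not the signatures from the patch):

```cpp
#include <tuple>
#include <utility>
#include <vector>

using SliceInfo =
    std::tuple<std::vector<int>, std::vector<int>, std::vector<int>>;

// Brace-initialized return: the tuple's element types are pinned by the
// declared return type, so there is no ambiguity at the return statement.
SliceInfo makeSliceBraced(std::vector<int> offsets, std::vector<int> sizes,
                          std::vector<int> strides) {
  return {std::move(offsets), std::move(sizes), std::move(strides)};
}

// std::make_tuple variant: deduces the tuple type from the arguments and
// then converts to the declared return type; some reviewers prefer the
// explicit call, others the terser braced form.
SliceInfo makeSliceWithMakeTuple(std::vector<int> offsets,
                                 std::vector<int> sizes,
                                 std::vector<int> strides) {
  return std::make_tuple(std::move(offsets), std::move(sizes),
                         std::move(strides));
}
```

Either spelling produces the same tuple; the difference is purely stylistic.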
}
if (failed(checkShapeRank(op, "output", outputType, rankToCompareWith))) {
// Check shape compatibility based on indexing maps.
SmallVector<int64_t> shape(getIterationDomainRank(), -1);
The concern is having the magic number -1 and expecting it to be different from ShapedType::kDynamic; they could have the same value. (Actually, ShapedType::kDynamic was -1 in the past, and it is internalized to something else today.) What I'd suggest is having a method to compute the shape first, and then you can compare the result of affineMap.compose(shape) against the target tensor shape. That avoids the -1 magic number.
I don't think I have a way of computing the shape without using these indexing maps. I added a foundDims array so there is no magic number.
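To make the resolution concrete, here is a rough sketch (a standalone helper with simplified types; it is not the actual verifier code from the patch) of checking shape compatibility through indexing maps with a foundDims array rather than a -1 sentinel:

```cpp
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Walk every operand's indexing map, recording the size each iteration
// dimension binds to. foundDims marks dimensions already seen, so no
// sentinel value can collide with ShapedType::kDynamic.
static LogicalResult
checkShapeCompatibility(ArrayRef<AffineMap> indexingMaps,
                        ArrayRef<ShapedType> operandTypes,
                        unsigned iterationDomainRank) {
  SmallVector<int64_t> shape(iterationDomainRank, ShapedType::kDynamic);
  SmallVector<bool> foundDims(iterationDomainRank, false);
  for (auto [map, type] : llvm::zip_equal(indexingMaps, operandTypes)) {
    for (auto [expr, dimSize] :
         llvm::zip_equal(map.getResults(), type.getShape())) {
      auto dimExpr = dyn_cast<AffineDimExpr>(expr);
      if (!dimExpr)
        return failure(); // Only plain dim expressions handled in this sketch.
      unsigned pos = dimExpr.getPosition();
      if (!foundDims[pos]) {
        foundDims[pos] = true;
        shape[pos] = dimSize;
        continue;
      }
      // Dynamic sizes are not checked; static sizes must agree.
      if (!ShapedType::isDynamic(shape[pos]) &&
          !ShapedType::isDynamic(dimSize) && shape[pos] != dimSize)
        return failure();
    }
  }
  return success();
}
```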
Force-pushed from 186a3b0 to dc6517b.
This patch generalizes the tiling implementation for AttentionOp. Previously, only the batch and M dimensions of attention could be tiled. This patch allows tiling of the N dimension as well, and allows transposition based on indexing maps (hardcoded for now).
Tiling on the N dimension is disabled in the CPU backend for now, because the TileAndDecomposeAttention pass is hardcoded with specific dimensions. This will be fixed once we implement the reduction tiling interface for it (after llvm/llvm-project#92624).
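To connect the description back to the review snippet above, here is a rough sketch (the helper name and signature are hypothetical simplifications, not the actual TilingInterface implementation) of how an operand's tile offsets and sizes fall out of its indexing map and the iteration-space tile:

```cpp
#include <tuple>

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// For a projected-permutation indexing map, each map result names one
// iteration dimension, so the operand tile is obtained by gathering that
// dimension's offset and size; strides stay unit. Tiling batch, M, or N
// then needs no per-dimension special casing.
static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
                  SmallVector<OpFoldResult>>
getOperandSlice(OpBuilder &b, AffineMap indexingMap,
                ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) {
  SmallVector<OpFoldResult> outputOffsets, outputSizes, outputStrides;
  OpFoldResult one = b.getIndexAttr(1);
  for (AffineExpr expr : indexingMap.getResults()) {
    auto dim = cast<AffineDimExpr>(expr);
    outputOffsets.push_back(offsets[dim.getPosition()]);
    outputSizes.push_back(sizes[dim.getPosition()]);
    outputStrides.push_back(one);
  }
  return {outputOffsets, outputSizes, outputStrides};
}
```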