
[LinalgExt] Generalize attention tiling interface implementation #17408

Merged
6 commits merged into iree-org:main on May 22, 2024

Conversation

@Groverkss (Contributor) commented on May 15, 2024:

This patch generalizes the tiling implementation for AttentionOp. Previously, only the batch and M dimensions of attention could be tiled. This patch allows tiling of the N dimension as well, and allows transposition based on indexing maps (hardcoded for now).

Tiling the N dimension is disabled in the CPU backend for now, because the TileAndDecomposeAttention pass is hardcoded to specific dimensions. This will be fixed once we implement the reduction tiling interface for it (after llvm/llvm-project#92624)
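For reference, here is a minimal sketch of what the hardcoded indexing maps could look like; the helper name, the exact dimension ordering, and the direct use of MLIR's AffineMap API are assumptions for illustration, not code from this PR. It assumes the usual attention operand shapes Q: (batch, m, k1), K: (batch, k2, k1), V: (batch, k2, n), output: (batch, m, n) over the iteration domain (batch, m, n, k1, k2).

```cpp
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"

// Hedged sketch: one indexing map per attention operand, expressed over the
// iteration domain (batch, m, n, k1, k2). The helper name is hypothetical.
static llvm::SmallVector<mlir::AffineMap>
getAttentionIndexingMaps(mlir::MLIRContext *ctx) {
  auto d = [&](unsigned i) { return mlir::getAffineDimExpr(i, ctx); };
  auto map = [&](llvm::ArrayRef<mlir::AffineExpr> results) {
    return mlir::AffineMap::get(/*dimCount=*/5, /*symbolCount=*/0, results, ctx);
  };
  return {map({d(0), d(1), d(3)}),   // Q:   (batch, m,  k1)
          map({d(0), d(4), d(3)}),   // K:   (batch, k2, k1)
          map({d(0), d(4), d(2)}),   // V:   (batch, k2, n)
          map({d(0), d(1), d(2)})};  // Out: (batch, m,  n)
}
```

With maps like these, tiling the N dimension means slicing along d2 wherever it appears in an operand's map, and a fused transpose is just a permutation of a map's results.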

@hanhanW (Contributor) left a comment:

Some high-level questions:

  1. Why do we need indexing_maps for the attention op?
  2. Could we have more than 1 dimension for batch/m/n/k1/k2?
  3. It looks like tiling on reduction dims in getTiledImplementation() is hard. Do you plan to implement PartialReductionOpInterface for the attention op?

@Groverkss (Contributor, PR author) replied:

  1. The indexing maps are there so we can do more fusions with the attention op. One of the fusions we want is fusing transposes into AttentionOp. This PR just hardcodes the maps, but eventually they will be replaced by indexing maps carried on the op itself (see the sketch after this list).

  2. Once I add indexing maps, it should be possible to have more than one dimension for each of batch/m/n/k1/k2. So eventually yes, but not right now. The PyTorch SDPA op actually has two batch dimensions, so it is common to have multiple dimensions for these.

  3. Eventually, yes. The plan is to split TileAndDecomposeAttention into tiling (PartialReductionOpInterface) and decomposition (AggregateOpInterface). This needs some changes upstream, but eventually it will happen.
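To make point 1 concrete: fusing a transpose of K into attention amounts to nothing more than swapping result positions in K's indexing map. A hedged sketch follows, assuming the same (batch, m, n, k1, k2) iteration domain as the sketch above; getKMap is a hypothetical helper, not IREE's API.

```cpp
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"

// Hypothetical helper: K's indexing map with and without a fused transpose.
// Untransposed K has shape (batch, k2, k1); a K stored pre-transposed as
// (batch, k1, k2) only flips the last two results of the map.
static mlir::AffineMap getKMap(mlir::MLIRContext *ctx, bool kIsTransposed) {
  auto d = [&](unsigned i) { return mlir::getAffineDimExpr(i, ctx); };
  llvm::SmallVector<mlir::AffineExpr> results =
      kIsTransposed ? llvm::SmallVector<mlir::AffineExpr>({d(0), d(3), d(4)})
                    : llvm::SmallVector<mlir::AffineExpr>({d(0), d(4), d(3)});
  return mlir::AffineMap::get(/*dimCount=*/5, /*symbolCount=*/0, results, ctx);
}
```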

@Groverkss force-pushed the generalize-attention-tiling branch from adf5386 to 07da396 on May 16, 2024 at 11:20.
@hanhanW (Contributor) left a comment:

Thanks for filling me in on the context; it looks more reasonable to me now, and the idea sounds okay. Is the PR ready for review? I'm seeing some commented-out code in AttentionOp::verify.

@Groverkss requested a review from hanhanW on May 16, 2024 at 19:34.
@Groverkss force-pushed the generalize-attention-tiling branch 2 times, most recently from 98bc9c9 to db189bc, on May 16, 2024 at 19:44.
@Groverkss (Contributor, PR author) replied:


Fixed. It should be ready for review now.

outputOffsets.push_back(offsets[dim.getPosition()]);
outputSizes.push_back(sizes[dim.getPosition()]);
}
return {outputOffsets, outputSizes, outputStrides};
Review comment from hanhanW (Contributor):

nit: use std::make_tuple()

Reply from Groverkss (Contributor, PR author):

I'm not sure about this. I prefer the braced return since the tuple type is already explicitly spelled out in the function signature, so there is no ambiguity. I can change it if there is a preference.
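For illustration, the two styles under discussion are equivalent when the function's declared return type already spells out the tuple. The snippet below uses hypothetical std::vector<int> element types rather than the actual helper's signature.

```cpp
#include <tuple>
#include <vector>

// Hypothetical signature mirroring the shape of the helper being reviewed.
std::tuple<std::vector<int>, std::vector<int>, std::vector<int>>
getOffsetsSizesStrides() {
  std::vector<int> offsets, sizes, strides;
  // Braced return: the tuple type is deduced from the declared return type.
  return {offsets, sizes, strides};
  // Equivalent alternative suggested in the review:
  // return std::make_tuple(offsets, sizes, strides);
}
```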

}
if (failed(checkShapeRank(op, "output", outputType, rankToCompareWith))) {
// Check shape compatibility based on indexing maps.
SmallVector<int64_t> shape(getIterationDomainRank(), -1);
Review comment from hanhanW (Contributor):

The concern is the magic number -1; we'd expect it to be distinct from ShapedType::kDynamic, but they could end up with the same value. (ShapedType::kDynamic was -1 in the past, and today it is internalized to something else.) What I'd suggest is having a method that computes the shape first, and then comparing the result of affineMap.compose(shape) against the target tensor shape. That avoids the -1 magic number.

Reply from Groverkss (Contributor, PR author):

I don't think I have a way of computing the shape without using these indexing maps. I added a foundDims array so there is no magic number.
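A hedged sketch of the foundDims idea (names, signature, and error handling are assumptions, not the actual AttentionOp::verify code): track which iteration-domain dimensions have already been seen instead of seeding the shape with a sentinel value, so nothing can collide with ShapedType::kDynamic.

```cpp
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"

// Hedged sketch: verify that operand shapes agree with their indexing maps.
// Assumes every map result is a plain AffineDimExpr, as in attention's maps.
static bool shapesMatchIndexingMaps(llvm::ArrayRef<mlir::AffineMap> maps,
                                    llvm::ArrayRef<mlir::ShapedType> operands,
                                    unsigned iterationRank) {
  llvm::SmallVector<int64_t> shape(iterationRank, 0);
  llvm::SmallVector<bool> foundDims(iterationRank, false);
  for (auto [map, type] : llvm::zip_equal(maps, operands)) {
    for (auto [expr, dimSize] :
         llvm::zip_equal(map.getResults(), type.getShape())) {
      unsigned pos = llvm::cast<mlir::AffineDimExpr>(expr).getPosition();
      if (mlir::ShapedType::isDynamic(dimSize))
        continue; // This sketch only checks static sizes.
      if (!foundDims[pos]) {
        foundDims[pos] = true;
        shape[pos] = dimSize;
        continue;
      }
      if (shape[pos] != dimSize)
        return false; // Two operands disagree on the same dimension's size.
    }
  }
  return true;
}
```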

@Groverkss merged commit 9fe159d into iree-org:main on May 22, 2024.
56 of 57 checks passed
gglangg pushed two commits referencing this pull request to gglangg/iree on Jun 4, 2024 (commit messages mirror the PR description above).
bangtianliu pushed a commit referencing this pull request to bangtianliu/iree on Jun 5, 2024 (commit message mirrors the PR description above).
LLITCHEV pushed a commit referencing this pull request to LLITCHEV/iree on Jul 30, 2024 (commit message mirrors the PR description above; Signed-off-by: Lubo Litchev <lubol@google.com>).