Skip to content

Commit

Permalink
Fix nit
Browse files Browse the repository at this point in the history
  • Loading branch information
raikonenfnu committed Jul 11, 2024
1 parent 1d023bb commit 5182e7b
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_PureOp<"attention",
statistics while processing the current tile.

With indexing_maps we can also infer when variables are transposed.
for example, if V has indexing_map of (m, k1, k2, n) -> (n, k2) insted
For example, if V has indexing_map of (m, k1, k2, n) -> (n, k2) instead
of (m, k1, k2, n) -> (k2, n), then we'd know that V is transposed (V.T),
and we can expect computation to look like:

Expand Down Expand Up @@ -669,8 +669,9 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_PureOp<"attention",
return sumType->getRank();
}

// Use indexing_map to infer if V is transposed.
bool getTransposeV() {
// Uses the indexing map of V to infer whether V is transposed. Since we expect
// map_V to be (b, m, k1, k2, n) -> (b, k2, n), this function returns true if n
// is not the fastest (i.e. last result) dim of map_V.
bool isTransposeV() {
AffineMap vMap = getValueMap();
auto vFastDim = llvm::cast<AffineDimExpr>(vMap.getResults().back());
int64_t nDim = vMap.getNumInputs() - 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ void decomposeTiledAttention(IREE::LinalgExt::AttentionOp tiledAttnOp,
auto [result, newMax, newSum] = createAttentionBody(
keySlice, valueSlice, querySlice, tiledResult, max, sum,
sequenceTileLength, keyValueTileLength, headDimension, elementType, ops,
tiledAttnOp.getTransposeV(), loc, rewriter);
tiledAttnOp.isTransposeV(), loc, rewriter);

rewriter.replaceOp(tiledAttnOp, ValueRange{result, newMax, newSum});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,15 +297,15 @@ IREE::LinalgExt::AttentionOp tileAttention(IREE::LinalgExt::AttentionOp attnOp,
headDimension, elementType, loc, rewriter);
Value valueSlice =
extractSlice(value, keyShape, ivs, keyValueTileLength, headDimension,
elementType, loc, rewriter, attnOp.getTransposeV());
elementType, loc, rewriter, attnOp.isTransposeV());
Value querySlice = extractSlice(query, queryShape, {}, sequenceTileLength,
headDimension, elementType, loc, rewriter);

Value scale = attnOp.getScale();

int64_t tiledInputRank = cast<ShapedType>(querySlice.getType()).getRank();
SmallVector<AffineMap> tiledIndexingMaps = getTileAttentionIndexingMaps(
rewriter, tiledInputRank, attnOp.getTransposeV());
rewriter, tiledInputRank, attnOp.isTransposeV());

auto tiledAttentionOp = rewriter.create<IREE::LinalgExt::AttentionOp>(
attnOp.getLoc(),
Expand Down

0 comments on commit 5182e7b

Please sign in to comment.