[LinalgExt] Adding IndexingMaps to linalg_ext.attentionOp (#17864)
To make fusion with other generics, specifically transpose, easier, we
introduce affine maps (indexingMaps) on linalg_ext.attentionOp. With that
we also enforce the number and types of dpsInputs. We also remove the
"transpose_V" attribute in favor of inferring transposition from the
indexingMaps.
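
For example, after this change an attention op carries its maps explicitly.
A minimal sketch using the shapes from the updated tests below (%q, %k, %v,
%acc are hypothetical SSA names):

  %out = iree_linalg_ext.attention
      {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,  // Q
                        affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,  // K
                        affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,  // V
                        affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]} // O
      ins(%q, %k, %v, %scale : tensor<1x3x4xf32>, tensor<1x3x4xf32>, tensor<1x3x4xf32>, f32)
      outs(%acc : tensor<1x3x4xf32>) -> tensor<1x3x4xf32>

A transposed V would instead be read through a map such as
affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3)>, which the op can now infer
from the maps rather than from a separate transpose_V attribute.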

---------

Co-authored-by: Kunwar Grover <groverkss@gmail.com>
raikonenfnu and Groverkss committed Jul 11, 2024
1 parent 20d8308 commit 9d2d766
Showing 15 changed files with 368 additions and 292 deletions.
@@ -71,80 +71,37 @@ struct ScatterOpConversion
};
} // namespace

static Value collapseBatches(PatternRewriter &rewriter, Location loc,
                             Value val) {
  auto valSizes = cast<RankedTensorType>(val.getType()).getShape();
  int64_t newBatch =
      std::accumulate(valSizes.begin(), valSizes.end() - 2, 1,
                      [](int64_t x, int64_t y) { return x * y; });
  Type elementType = cast<RankedTensorType>(val.getType()).getElementType();
  SmallVector<int64_t> newSizes{newBatch};
  newSizes.append(valSizes.end() - 2, valSizes.end());
  Type newType = RankedTensorType::get(newSizes, elementType);

  auto rank = valSizes.size();
  SmallVector<int64_t> collapsed;
  for (auto i = 0; i < rank - 2; i++)
    collapsed.push_back(i);

  SmallVector<ReassociationIndices> reassociation(3);
  reassociation[0].append(collapsed);
  reassociation[1].push_back(rank - 2);
  reassociation[2].push_back(rank - 1);

  return rewriter
      .create<tensor::CollapseShapeOp>(loc, newType, val, reassociation)
      .getResult();
}
static Value expandBatches(PatternRewriter &rewriter, Location loc,
                           SmallVector<int64_t> batchSizes, Value val) {
  auto valSizes = cast<RankedTensorType>(val.getType()).getShape();
  Type elementType = cast<RankedTensorType>(val.getType()).getElementType();
  SmallVector<int64_t> newSizes(batchSizes);
  newSizes.append(valSizes.end() - 2, valSizes.end());
  auto rank = newSizes.size();
  Type newType = RankedTensorType::get(newSizes, elementType);

  SmallVector<ReassociationIndices> reassociation(3);
  for (auto i = 0; i < batchSizes.size(); i++)
    reassociation[0].push_back(i);
  reassociation[1].push_back(rank - 2);
  reassociation[2].push_back(rank - 1);

  return rewriter
      .create<tensor::ExpandShapeOp>(loc, newType, val, reassociation)
      .getResult();
}
static SmallVector<AffineMap>
getStandardAttentionIndexingMaps(MLIRContext *ctx) {
  AffineExpr m, k1, k2, n;
  bindDims(ctx, m, k1, k2, n);

  AffineMap qMap =
      AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, k1}, ctx);
  AffineMap kMap =
      AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {k2, k1}, ctx);
  AffineMap vMap =
      AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {k2, n}, ctx);
  AffineMap rMap =
      AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0, {m, n}, ctx);

  return {qMap, kMap, vMap, rMap};
}
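
// For illustration, in MLIR notation the four maps built above over
// (m, k1, k2, n) = (d0, d1, d2, d3) are:
//   Q: affine_map<(d0, d1, d2, d3) -> (d0, d1)>
//   K: affine_map<(d0, d1, d2, d3) -> (d2, d1)>
//   V: affine_map<(d0, d1, d2, d3) -> (d2, d3)>
//   O: affine_map<(d0, d1, d2, d3) -> (d0, d3)>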

struct AttentionOpConversion
    : public OpRewritePattern<mlir::torch::TMTensor::AttentionOp> {
  using OpRewritePattern<mlir::torch::TMTensor::AttentionOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(mlir::torch::TMTensor::AttentionOp op,
                                PatternRewriter &rewriter) const override {
    MLIRContext *ctx = getContext();
    Location loc = op->getLoc();
    Value query = op.getQuery();
    Value key = op.getKey();
    Value value = op.getValue();
    auto sizes = cast<RankedTensorType>(query.getType()).getShape();
    SmallVector<int64_t> batchSizes(sizes.begin(), sizes.end() - 2);

    if (sizes.size() > 3) {
      query = collapseBatches(rewriter, loc, query);
      key = collapseBatches(rewriter, loc, key);
      value = collapseBatches(rewriter, loc, value);
    }

    SmallVector<int64_t> resultShape(
        cast<RankedTensorType>(op->getResultTypes()[0]).getShape());
    SmallVector<int64_t> collapsedResultShape;
    collapsedResultShape.push_back(
        std::accumulate(resultShape.begin(), resultShape.end() - 2, 1,
                        [](int64_t x, int64_t y) { return x * y; }));
    collapsedResultShape.append(resultShape.end() - 2, resultShape.end());
    Type elementType = cast<RankedTensorType>(query.getType()).getElementType();
    auto collapsedResultType =
        RankedTensorType::get(collapsedResultShape, elementType);
    Value collapsedResult = rewriter.create<tensor::EmptyOp>(
        loc, collapsedResultShape, elementType);
    ShapedType outputType = op.getOutputType();
    Value result = rewriter.create<tensor::EmptyOp>(
        loc, outputType.getShape(), outputType.getElementType());

    // TODO: This is a hack. This should be replaced with a simple getScale()
    // when support for scaling is plumbed to TMTensor on the torch-mlir side.
@@ -174,15 +131,21 @@ struct AttentionOpConversion
    Value scale = rewriter.create<arith::ConstantOp>(
        loc, targetType, rewriter.getFloatAttr(targetType, dk));

    // Add batches to standard attention indexing maps.
    SmallVector<AffineMap> indexingMaps = getStandardAttentionIndexingMaps(ctx);
    int64_t numBatches = op.getQueryType().getRank() - 2;
    for (AffineMap &map : indexingMaps) {
      map = map.shiftDims(numBatches);
      for (int batch : llvm::seq<int>(numBatches)) {
        map = map.insertResult(rewriter.getAffineDimExpr(batch), batch);
      }
    }
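    // As a worked example (matching the first test below): with
    // numBatches = 2, the Q map (d0, d1, d2, d3) -> (d0, d1) becomes
    // (d0, d1, d2, d3, d4, d5) -> (d2, d3) after shiftDims(2), and then
    // (d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3) after inserting the
    // leading batch dims d0 and d1.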

    auto attention = rewriter.create<IREE::LinalgExt::AttentionOp>(
        loc, collapsedResultType, SmallVector<Value>{query, key, value, scale},
        collapsedResult);

    if (sizes.size() > 3)
      rewriter.replaceOp(
          op, expandBatches(rewriter, loc, batchSizes, attention.getResult(0)));
    else
      rewriter.replaceOp(op, attention.getResult(0));
        loc, result.getType(), query, key, value, scale, result,
        rewriter.getAffineMapArrayAttr(indexingMaps));

    rewriter.replaceOp(op, attention.getResult(0));
    return success();
  }
};
37 changes: 22 additions & 15 deletions compiler/plugins/input/Torch/InputConversion/test/attention.mlir
@@ -5,46 +5,53 @@ func.func @attention(%arg0: tensor<5x2x3x4xf32>, %arg1: tensor<5x2x3x4xf32>, %ar
return %0 : tensor<5x2x3x4xf32>
}

// CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
// CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>
// CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>
// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>

// CHECK-LABEL: func.func @attention(
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x2x3x4xf32>, %[[ARG1:.*]]: tensor<5x2x3x4xf32>, %[[ARG2:.*]]: tensor<5x2x3x4xf32>,
// CHECK: %arg3: tensor<5x2x3x4xf32>) -> tensor<5x2x3x4xf32> {
// CHECK: %[[SCALE:.*]] = arith.constant 5.000000e-01 : f32
// CHECK: %[[COL:.*]] = tensor.collapse_shape %[[ARG0]] {{.*}} : tensor<5x2x3x4xf32> into tensor<10x3x4xf32>
// CHECK: %[[COL0:.*]] = tensor.collapse_shape %[[ARG1]] {{.*}} : tensor<5x2x3x4xf32> into tensor<10x3x4xf32>
// CHECK: %[[COL1:.*]] = tensor.collapse_shape %[[ARG2]] {{.*}} : tensor<5x2x3x4xf32> into tensor<10x3x4xf32>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<10x3x4xf32>
// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention ins(%[[COL]], %[[COL0]], %[[COL1]], %[[SCALE]] : tensor<10x3x4xf32>, tensor<10x3x4xf32>, tensor<10x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<10x3x4xf32>) -> tensor<10x3x4xf32>
// CHECK: %[[RET:.*]] = tensor.expand_shape %[[ATTN]] {{.*}} : tensor<10x3x4xf32> into tensor<5x2x3x4xf32>
// CHECK: return %[[RET]] : tensor<5x2x3x4xf32>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<5x2x3x4xf32>
// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<5x2x3x4xf32>) -> tensor<5x2x3x4xf32>
// CHECK: return %[[ATTN]] : tensor<5x2x3x4xf32>

// -----
func.func @attention(%arg0: tensor<5x2x8x4xf32>, %arg1: tensor<5x2x3x4xf32>, %arg2: tensor<5x2x3x4xf32>, %arg3: tensor<5x2x8x4xf32>) -> (tensor<5x2x8x4xf32>) {
%0 = tm_tensor.attention ins(%arg0, %arg1, %arg2 : tensor<5x2x8x4xf32>, tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>) outs(%arg3: tensor<5x2x8x4xf32>) -> tensor<5x2x8x4xf32>
return %0 : tensor<5x2x8x4xf32>
}

// CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
// CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>
// CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>
// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>

// CHECK-LABEL: func.func @attention(
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x2x8x4xf32>, %[[ARG1:.*]]: tensor<5x2x3x4xf32>, %[[ARG2:.*]]: tensor<5x2x3x4xf32>,
// CHECK: %arg3: tensor<5x2x8x4xf32>) -> tensor<5x2x8x4xf32> {
// CHECK: %[[SCALE:.*]] = arith.constant 5.000000e-01 : f32
// CHECK: %[[COL:.*]] = tensor.collapse_shape %[[ARG0]] {{.*}} : tensor<5x2x8x4xf32> into tensor<10x8x4xf32>
// CHECK: %[[COL0:.*]] = tensor.collapse_shape %[[ARG1]] {{.*}} : tensor<5x2x3x4xf32> into tensor<10x3x4xf32>
// CHECK: %[[COL1:.*]] = tensor.collapse_shape %[[ARG2]] {{.*}} : tensor<5x2x3x4xf32> into tensor<10x3x4xf32>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<10x8x4xf32>
// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention ins(%[[COL]], %[[COL0]], %[[COL1]], %[[SCALE]] : tensor<10x8x4xf32>, tensor<10x3x4xf32>, tensor<10x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<10x8x4xf32>) -> tensor<10x8x4xf32>
// CHECK: %[[RET:.*]] = tensor.expand_shape %[[ATTN]] {{.*}} : tensor<10x8x4xf32> into tensor<5x2x8x4xf32>
// CHECK: return %[[RET]] : tensor<5x2x8x4xf32>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<5x2x8x4xf32>
// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<5x2x8x4xf32>, tensor<5x2x3x4xf32>, tensor<5x2x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<5x2x8x4xf32>) -> tensor<5x2x8x4xf32>
// CHECK: return %[[ATTN]] : tensor<5x2x8x4xf32>

// -----
func.func @attention(%arg0: tensor<1x3x4xf32>, %arg1: tensor<1x3x4xf32>, %arg2: tensor<1x3x4xf32>, %arg3: tensor<1x3x4xf32>) -> (tensor<1x3x4xf32>) {
%0 = tm_tensor.attention ins(%arg0, %arg1, %arg2 : tensor<1x3x4xf32>, tensor<1x3x4xf32>, tensor<1x3x4xf32>) outs(%arg3: tensor<1x3x4xf32>) -> tensor<1x3x4xf32>
return %0 : tensor<1x3x4xf32>
}

// CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
// CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
// CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>

// CHECK-LABEL: func.func @attention(
// CHECK-SAME: %[[ARG0:.*]]: tensor<1x3x4xf32>, %[[ARG1:.*]]: tensor<1x3x4xf32>, %[[ARG2:.*]]: tensor<1x3x4xf32>,
// CHECK: %arg3: tensor<1x3x4xf32>) -> tensor<1x3x4xf32> {
// CHECK: %[[SCALE:.*]] = arith.constant 5.000000e-01 : f32
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x3x4xf32>
// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<1x3x4xf32>, tensor<1x3x4xf32>, tensor<1x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<1x3x4xf32>) -> tensor<1x3x4xf32>
// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<1x3x4xf32>, tensor<1x3x4xf32>, tensor<1x3x4xf32>, f32) outs(%[[EMPTY]] : tensor<1x3x4xf32>) -> tensor<1x3x4xf32>
// CHECK: return %[[ATTN]] : tensor<1x3x4xf32>
@@ -1606,7 +1606,10 @@ module {
%5 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>> -> tensor<20x4096x64xf16>
%6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>> -> tensor<20x4096x64xf16>
%7 = tensor.empty() : tensor<20x4096x64xf16>
%8 = iree_linalg_ext.attention
%8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%4, %5, %6, %scale : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16)
outs(%7 : tensor<20x4096x64xf16>) -> tensor<20x4096x64xf16>
flow.dispatch.tensor.store %8, %3, offsets = [0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : tensor<20x4096x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<20x4096x64xf16>>
@@ -1618,7 +1621,7 @@ module {
// CHECK: func.func @attention()
// CHECK-SAME: translation_info = #[[TRANSLATION]]
// CHECK: iree_linalg_ext.attention
// CHECK-SAME: {lowering_config = #[[CONFIG]]}
// CHECK-SAME: lowering_config = #[[CONFIG]]

// -----

@@ -14,7 +14,11 @@ module {
%5 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [192, 1024, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<192x1024x64xf16>> -> tensor<192x1024x64xf16>
%6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [192, 1024, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<192x1024x64xf16>> -> tensor<192x1024x64xf16>
%7 = tensor.empty() : tensor<192x1024x64xf16>
%8 = iree_linalg_ext.attention ins(%4, %5, %6, %cst : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16) outs(%7 : tensor<192x1024x64xf16>) -> tensor<192x1024x64xf16>
%8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%4, %5, %6, %cst : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f16) outs(%7 : tensor<192x1024x64xf16>) -> tensor<192x1024x64xf16>
flow.dispatch.tensor.store %8, %3, offsets = [0, 0, 0], sizes = [192, 1024, 64], strides = [1, 1, 1] : tensor<192x1024x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<192x1024x64xf16>>
return
}
@@ -14,7 +14,11 @@ module {
%5 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [16, 16384, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<16x16384x128xf16>> -> tensor<16x16384x128xf16>
%6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [16, 16384, 128], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<16x16384x128xf16>> -> tensor<16x16384x128xf16>
%7 = tensor.empty() : tensor<16x16384x128xf16>
%8 = iree_linalg_ext.attention ins(%4, %5, %6, %scale : tensor<16x16384x128xf16>, tensor<16x16384x128xf16>, tensor<16x16384x128xf16>, f16) outs(%7 : tensor<16x16384x128xf16>) -> tensor<16x16384x128xf16>
%8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%4, %5, %6, %scale : tensor<16x16384x128xf16>, tensor<16x16384x128xf16>, tensor<16x16384x128xf16>, f16) outs(%7 : tensor<16x16384x128xf16>) -> tensor<16x16384x128xf16>
flow.dispatch.tensor.store %8, %3, offsets = [0, 0, 0], sizes = [16, 16384, 128], strides = [1, 1, 1] : tensor<16x16384x128xf16> -> !flow.dispatch.tensor<writeonly:tensor<16x16384x128xf16>>
return
}