
Commit

Merge 1e4b977 into 90f29a6
MaheshRavishankar committed Jun 20, 2024
2 parents 90f29a6 + 1e4b977 commit 88f32fa
Showing 2 changed files with 55 additions and 29 deletions.
@@ -80,8 +80,9 @@ getLoopRangesImpl(ReifyRankedShapedTypeOpInterface shapedOp, Location loc,
   LogicalResult status = shapedOp.reifyResultShapes(builder, resultDims);
   (void)status;
   assert(succeeded(status) && "reifyResultShapes failed");
-  return llvm::map_to_vector(
-      resultDims[0], [&](OpFoldResult v) { return Range{zero, v, one}; });
+  return llvm::map_to_vector(resultDims[0], [&](OpFoldResult v) {
+    return Range{zero, v, one};
+  });
 }
 
 /// For a given operation returns the loop ranges needed to compute the op.
@@ -543,40 +544,40 @@ bool isDequantizationLikeOp(Operation *op) {
     return false;
   }
 
-  // Check that only one input has an identity map, and the rest are projected
-  // permutations and not full permutations
-  OpOperand *identityInput = nullptr;
+  // Check that all operands that have the highest rank have bit width
+  // less than the output bit-width.
+  DenseMap<int64_t, SmallVector<RankedTensorType>> rankBuckets;
+  int64_t maxRank = 0;
   for (OpOperand *input : genericOp.getDpsInputOperands()) {
-    auto inputMap = genericOp.getMatchingIndexingMap(input);
-    if (inputMap.isIdentity()) {
-      if (identityInput) {
-        return false;
-      }
-      identityInput = input;
-    } else if (!inputMap.isProjectedPermutation(true) ||
-               inputMap.isPermutation()) {
-      return false;
+    auto inputType = dyn_cast<RankedTensorType>(input->get().getType());
+    if (!inputType) {
+      continue;
     }
+    int64_t currRank = inputType.getRank();
+    maxRank = std::max(currRank, maxRank);
+    rankBuckets[currRank].push_back(inputType);
   }
 
-  if (!identityInput) {
+  if (rankBuckets[maxRank].empty()) {
     return false;
   }
 
-  auto indexingMaps = genericOp.getIndexingMapsArray();
-  if (!indexingMaps.back().isIdentity()) {
-    return false;
+  unsigned int maxInputElementBitWidth = 0;
+  for (auto t : rankBuckets[maxRank]) {
+    Type elementType = t.getElementType();
+    if (!elementType.isIntOrFloat()) {
+      return false;
+    }
+    maxInputElementBitWidth = std::max(
+        maxInputElementBitWidth, t.getElementType().getIntOrFloatBitWidth());
   }
 
   // Check that the identity input element bitwidth is smaller than the output
   // element bitwidth.
-  Type inputElementType = getElementTypeOrSelf(identityInput->get().getType());
   Type outputElementType = getElementTypeOrSelf(genericOp->getResultTypes()[0]);
-  if (!inputElementType.isIntOrFloat() || !outputElementType.isIntOrFloat()) {
+  if (!outputElementType.isIntOrFloat()) {
     return false;
   }
-  if (inputElementType.getIntOrFloatBitWidth() >=
-      outputElementType.getIntOrFloatBitWidth()) {
+  if (maxInputElementBitWidth >= outputElementType.getIntOrFloatBitWidth()) {
     return false;
   }
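For context, the sketch below distills the new heuristic in this hunk: inputs are bucketed by rank, and the op is treated as dequantization-like only when every highest-rank input has an element bit width strictly smaller than the output's. It is a minimal, standalone C++ approximation, not IREE code; TensorDesc, elementBits, and looksLikeDequantization are hypothetical names, and it assumes all inputs are ranked tensors (the real check simply skips unranked operands).

// Hypothetical standalone sketch (not IREE code): the bucket-by-rank /
// bit-width heuristic, using plain structs instead of MLIR types so it can
// be compiled and run directly.
#include <algorithm>
#include <cstdint>
#include <map>
#include <vector>

struct TensorDesc {
  int64_t rank;          // number of dimensions
  unsigned elementBits;  // element bit width (0 if not an int/float type)
};

// Returns true if every input of maximal rank has an element type strictly
// narrower than the output element type -- the shape of the new check.
bool looksLikeDequantization(const std::vector<TensorDesc> &inputs,
                             const TensorDesc &output) {
  std::map<int64_t, std::vector<TensorDesc>> rankBuckets;
  int64_t maxRank = 0;
  for (const TensorDesc &in : inputs) {
    maxRank = std::max(maxRank, in.rank);
    rankBuckets[in.rank].push_back(in);
  }
  if (rankBuckets[maxRank].empty())
    return false;

  unsigned maxInputBits = 0;
  for (const TensorDesc &t : rankBuckets[maxRank]) {
    if (t.elementBits == 0)  // not an int/float element type
      return false;
    maxInputBits = std::max(maxInputBits, t.elementBits);
  }
  if (output.elementBits == 0)
    return false;
  return maxInputBits < output.elementBits;
}

int main() {
  // Mirrors the @broadcasting_dequant_op test below: a rank-2 i8 input
  // extended to a rank-3 i32 output qualifies, since 8 < 32 for the
  // highest-rank input.
  TensorDesc i8Input{2, 8};
  TensorDesc i32Output{3, 32};
  return looksLikeDequantization({i8Input}, i32Output) ? 0 : 1;
}

Note this drops the old requirement that exactly one input have an identity indexing map, which is why a broadcasting extension (non-identity map on the widened input) can now be classified as dequantization-like.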

@@ -502,18 +502,18 @@ util.func public @scf_nested_dispatch(%arg0 : tensor<?xi32>) -> (tensor<?xi32>)
 
 // -----
 
-util.func public @no_dequantization_fusion(%arg0: tensor<4096x32x128xi8>, %arg1: tensor<1x1x32x128xf32>, %arg2: tensor<4096x32x1xf32>, %arg3: tensor<4096x32x1xf32>) -> tensor<1x1x4096xf32> {
+util.func public @no_dequantization_fusion(%arg0: tensor<4096x32x128xi8>, %arg1: tensor<1x1x32x128xf32>, %arg2: tensor<4096x32xf32>, %arg3: tensor<4096x32xf32>) -> tensor<1x1x4096xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.empty() : tensor<1x1x4096xf32>
   %1 = tensor.empty() : tensor<4096x32x128xf32>
   %2 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x1x4096xf32>) -> tensor<1x1x4096xf32>
   %3 = linalg.generic {
       indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
-                       affine_map<(d0, d1, d2) -> (d0, d1, 0)>,
-                       affine_map<(d0, d1, d2) -> (d0, d1, 0)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>,
                        affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
       iterator_types = ["parallel", "parallel", "parallel"]}
-      ins(%arg0, %arg2, %arg3 : tensor<4096x32x128xi8>, tensor<4096x32x1xf32>, tensor<4096x32x1xf32>) outs(%1 : tensor<4096x32x128xf32>) {
+      ins(%arg0, %arg2, %arg3 : tensor<4096x32x128xi8>, tensor<4096x32xf32>, tensor<4096x32xf32>) outs(%1 : tensor<4096x32x128xf32>) {
   ^bb0(%in: i8, %in_0: f32, %in_1: f32, %out: f32):
     %5 = arith.extui %in : i8 to i32
     %6 = arith.uitofp %5 : i32 to f32
@@ -537,8 +537,8 @@ util.func public @no_dequantization_fusion(%arg0: tensor<4096x32x128xi8>, %arg1:
 // CHECK: util.func public @no_dequantization_fusion
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4096x32x128xi8>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<1x1x32x128xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<4096x32x1xf32>
-// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<4096x32x1xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<4096x32xf32>
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<4096x32xf32>
 // CHECK-DAG: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG: %[[INIT1:.+]] = tensor.empty() : tensor<1x1x4096xf32>
 // CHECK-DAG: %[[INIT0:.+]] = tensor.empty() : tensor<4096x32x128xf32>
@@ -605,3 +605,28 @@ util.func public @no_dequantization_like_fusion(%arg0: tensor<32x1x16x1x8xi16>,
 // CHECK-SAME: outs(%[[FILL]] :
 // CHECK: flow.return %[[MMT4D]] :
 // CHECK: util.return %[[DISP]]
+
+// -----
+
+util.func public @broadcasting_dequant_op(%arg0 : tensor<?x?xi8>,
+    %rhs : tensor<?x?x?xi32>, %init : tensor<?x?x?xi32>) -> tensor<?x?x?xi32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d1 = tensor.dim %arg0, %c0 : tensor<?x?xi8>
+  %d2 = tensor.dim %arg0, %c1 : tensor<?x?xi8>
+  %d0 = tensor.dim %rhs, %c0 : tensor<?x?x?xi32>
+  %empty = tensor.empty(%d0, %d1, %d2) : tensor<?x?x?xi32>
+  %dequant = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      ins(%arg0 : tensor<?x?xi8>) outs(%empty : tensor<?x?x?xi32>) {
+  ^bb0(%in: i8, %out: i32):
+    %12 = arith.extui %in : i8 to i32
+    linalg.yield %12 : i32
+  } -> tensor<?x?x?xi32>
+  %op = linalg.batch_matmul_transpose_b
+      ins(%dequant, %rhs : tensor<?x?x?xi32>, tensor<?x?x?xi32>)
+      outs(%init : tensor<?x?x?xi32>) -> tensor<?x?x?xi32>
+  util.return %op : tensor<?x?x?xi32>
+}
