Skip to content

Commit

Permalink
[mlir] Allow vector.contract to have mixed types operands
Browse files Browse the repository at this point in the history
Allow lhs and rhs to have different type than accumulator/destination. Some
hardware like GPUs support natively operations like uint8xuint8xuint32.

Differential Revision: https://reviews.llvm.org/D82069
  • Loading branch information
ThomasRaoux committed Jun 20, 2020
1 parent c310bf8 commit e4bc08f
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 7 deletions.
12 changes: 7 additions & 5 deletions mlir/include/mlir/Dialect/Vector/VectorOps.td
Expand Up @@ -40,12 +40,9 @@ class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
// with operators other than the current set: {*, +}.
def Vector_ContractionOp :
Vector_Op<"contract", [NoSideEffect,
PredOpTrait<"first operand lhs and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
PredOpTrait<"second operand rhs and result have same element type",
TCresVTEtIsSameAsOpBase<0, 1>>,
PredOpTrait<"lhs and rhs have same element type", TCopVTEtIsSameAs<0, 1>>,
PredOpTrait<"third operand acc and result have same element type",
TCresVTEtIsSameAsOpBase<0, 1>>]>,
TCresVTEtIsSameAsOpBase<0, 2>>]>,
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc,
Variadic<VectorOf<[I1]>>:$masks,
AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>,
Expand Down Expand Up @@ -140,6 +137,11 @@ def Vector_ContractionOp :

%5 = vector.contract #contraction_trait %0, %1, %2, %lhs_mask, %rhs_mask
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>

// Vector contraction with mixed typed. lhs/rhs have different element
// types than accumulator/result.
%6 = vector.contract #contraction_trait %0, %1, %2
: vector<10xf16>, vector<10xf16> into f32
```
}];
let builders = [OpBuilder<
Expand Down
6 changes: 6 additions & 0 deletions mlir/lib/Dialect/Vector/VectorTransforms.cpp
Expand Up @@ -28,6 +28,7 @@
#include "mlir/IR/Module.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"

#include "llvm/Support/CommandLine.h"
Expand Down Expand Up @@ -1731,6 +1732,11 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
// TODO(ajcbik): implement masks.
if (llvm::size(op.masks()) != 0)
return failure();
// TODO(thomasraoux): support mixed mode contract lowering.
if (op.getLhsType().getElementType() !=
getElementTypeOrSelf(op.getAccType()) ||
op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
return failure();

// TODO(ntv, ajcbik): implement benefits, cost models.
MLIRContext *ctx = op.getContext();
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Vector/invalid.mlir
Expand Up @@ -760,7 +760,7 @@ func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
func @contraction(%arg0: vector<4x3xi32>,
%arg1: vector<3x7xf32>,
%arg2: vector<4x7xf32>) -> vector<4x7xf32> {
// expected-error@+1 {{'vector.contract' op failed to verify that first operand lhs and result have same element type}}
// expected-error@+1 {{'vector.contract' op failed to verify that lhs and rhs have same element type}}
%0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
: vector<4x3xi32>, vector<3x7xf32> into vector<4x7xf32>
}
Expand Down
6 changes: 5 additions & 1 deletion mlir/test/Dialect/Vector/ops.mlir
Expand Up @@ -175,7 +175,7 @@ func @contraction_to_scalar(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32
// CHECK-LABEL: @contraction
func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
%arg2 : vector<8x15x5xf32>, %arg3 : vector<8x8x15x5xf32>,
%arg4 : index) {
%arg4 : vector<7x8x16x15xf16>, %arg5 : vector<8x16x7x5xf16>) {
// Test contraction with batch and contracting dims.
// CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
%0 = vector.contract #contraction_trait0 %arg0, %arg1, %arg2
Expand All @@ -193,6 +193,10 @@ func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
%2 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3, %lhs_mask,
%rhs_mask
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
// Test contraction with mixed type.
// CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32>
%3 = vector.contract #contraction_trait1 %arg4, %arg5, %arg3
: vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32>
return
}

Expand Down

0 comments on commit e4bc08f

Please sign in to comment.