Skip to content

Commit

Permalink
[mlir][spirv] Define spv.GLSL.Fma and add lowerings
Browse files Browse the repository at this point in the history
Also changes some rewriter.create + rewriter.replaceOp calls
into rewriter.replaceOpWithNewOp calls.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D94965
  • Loading branch information
antiagainst committed Jan 19, 2021
1 parent 6259fbd commit 3a56a96
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 19 deletions.
40 changes: 40 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -972,4 +972,44 @@ def SPV_GLSLSClampOp : SPV_GLSLTernaryArithmeticOp<"SClamp", 45, SPV_SignedInt>
}];
}

// -----

def SPV_GLSLFmaOp : SPV_GLSLTernaryArithmeticOp<"Fma", 50, SPV_Float> {
let summary = "Computes a * b + c.";

let description = [{
In uses where this operation is decorated with NoContraction:

- fma is considered a single operation, whereas the expression a * b + c
is considered two operations.
- The precision of fma can differ from the precision of the expression
a * b + c.
- fma will be computed with the same precision as any other fma decorated
with NoContraction, giving invariant results for the same input values
of a, b, and c.

Otherwise, in the absence of a NoContraction decoration, there are no
special constraints on the number of operations or difference in precision
between fma and the expression a * b +c.

The operands must all be a scalar or vector whose component type is
floating-point.

Result Type and the type of all operands must be the same type. Results
are computed per component.

<!-- End of AutoGen section -->
```
fma-op ::= ssa-id `=` `spv.GLSL.Fma` ssa-use, ssa-use, ssa-use `:`
float-scalar-vector-type
```
#### Example:

```mlir
%0 = spv.GLSL.Fma %a, %b, %c : f32
%1 = spv.GLSL.Fma %a, %b, %c : vector<3xf16>
```
}];
}

#endif // MLIR_DIALECT_SPIRV_IR_GLSL_OPS
49 changes: 30 additions & 19 deletions mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,8 @@ struct VectorBroadcastConvert final
vector::BroadcastOp::Adaptor adaptor(operands);
SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(),
adaptor.source());
Value construct = rewriter.create<spirv::CompositeConstructOp>(
broadcastOp.getLoc(), broadcastOp.getVectorType(), source);
rewriter.replaceOp(broadcastOp, construct);
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
broadcastOp, broadcastOp.getVectorType(), source);
return success();
}
};
Expand All @@ -55,9 +54,23 @@ struct VectorExtractOpConvert final
return failure();
vector::ExtractOp::Adaptor adaptor(operands);
int32_t id = extractOp.position().begin()->cast<IntegerAttr>().getInt();
Value newExtract = rewriter.create<spirv::CompositeExtractOp>(
extractOp.getLoc(), adaptor.vector(), id);
rewriter.replaceOp(extractOp, newExtract);
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
extractOp, adaptor.vector(), id);
return success();
}
};

struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (!spirv::CompositeType::isValid(fmaOp.getVectorType()))
return failure();
vector::FMAOp::Adaptor adaptor(operands);
rewriter.replaceOpWithNewOp<spirv::GLSLFmaOp>(
fmaOp, fmaOp.getType(), adaptor.lhs(), adaptor.rhs(), adaptor.acc());
return success();
}
};
Expand All @@ -74,9 +87,8 @@ struct VectorInsertOpConvert final
return failure();
vector::InsertOp::Adaptor adaptor(operands);
int32_t id = insertOp.position().begin()->cast<IntegerAttr>().getInt();
Value newInsert = rewriter.create<spirv::CompositeInsertOp>(
insertOp.getLoc(), adaptor.source(), adaptor.dest(), id);
rewriter.replaceOp(insertOp, newInsert);
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
insertOp, adaptor.source(), adaptor.dest(), id);
return success();
}
};
Expand All @@ -92,10 +104,9 @@ struct VectorExtractElementOpConvert final
if (!spirv::CompositeType::isValid(extractElementOp.getVectorType()))
return failure();
vector::ExtractElementOp::Adaptor adaptor(operands);
Value newExtractElement = rewriter.create<spirv::VectorExtractDynamicOp>(
extractElementOp.getLoc(), extractElementOp.getType(), adaptor.vector(),
rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
extractElementOp, extractElementOp.getType(), adaptor.vector(),
extractElementOp.position());
rewriter.replaceOp(extractElementOp, newExtractElement);
return success();
}
};
Expand All @@ -111,10 +122,9 @@ struct VectorInsertElementOpConvert final
if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType()))
return failure();
vector::InsertElementOp::Adaptor adaptor(operands);
Value newInsertElement = rewriter.create<spirv::VectorInsertDynamicOp>(
insertElementOp.getLoc(), insertElementOp.getType(),
insertElementOp.dest(), adaptor.source(), insertElementOp.position());
rewriter.replaceOp(insertElementOp, newInsertElement);
rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
insertElementOp, insertElementOp.getType(), insertElementOp.dest(),
adaptor.source(), insertElementOp.position());
return success();
}
};
Expand All @@ -124,7 +134,8 @@ struct VectorInsertElementOpConvert final
void mlir::populateVectorToSPIRVPatterns(MLIRContext *context,
SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
patterns.insert<VectorBroadcastConvert, VectorExtractOpConvert,
VectorInsertOpConvert, VectorExtractElementOpConvert,
VectorInsertElementOpConvert>(typeConverter, context);
patterns.insert<VectorBroadcastConvert, VectorExtractElementOpConvert,
VectorExtractOpConvert, VectorFmaOpConvert,
VectorInsertOpConvert, VectorInsertElementOpConvert>(
typeConverter, context);
}
10 changes: 10 additions & 0 deletions mlir/test/Conversion/VectorToSPIRV/simple.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,13 @@ func @insert_element_negative(%val: f32, %arg0 : vector<5xf32>, %id : i32) {
%0 = vector.insertelement %val, %arg0[%id : i32] : vector<5xf32>
spv.Return
}

// -----

// CHECK-LABEL: func @fma
// CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>
// CHECK: spv.GLSL.Fma %[[A]], %[[B]], %[[C]] : vector<4xf32>
func @fma(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>) {
%0 = vector.fma %a, %b, %c: vector<4xf32>
spv.Return
}
20 changes: 20 additions & 0 deletions mlir/test/Dialect/SPIRV/IR/glsl-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -345,3 +345,23 @@ func @fclamp(%arg0 : i32, %min : i32, %max : i32) -> () {
%2 = spv.GLSL.SClamp %arg0, %min, %max : i32
return
}

// -----

//===----------------------------------------------------------------------===//
// spv.GLSL.Fma
//===----------------------------------------------------------------------===//

func @fma(%a : f32, %b : f32, %c : f32) -> () {
// CHECK: spv.GLSL.Fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32
%2 = spv.GLSL.Fma %a, %b, %c : f32
return
}

// -----

func @fma(%a : vector<3xf32>, %b : vector<3xf32>, %c : vector<3xf32>) -> () {
// CHECK: spv.GLSL.Fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : vector<3xf32>
%2 = spv.GLSL.Fma %a, %b, %c : vector<3xf32>
return
}
6 changes: 6 additions & 0 deletions mlir/test/Target/SPIRV/glsl-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,10 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
%13 = spv.GLSL.SClamp %arg0, %arg1, %arg2 : si32
spv.Return
}

spv.func @fma(%arg0 : f32, %arg1 : f32, %arg2 : f32) "None" {
// CHECK: spv.GLSL.Fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32
%13 = spv.GLSL.Fma %arg0, %arg1, %arg2 : f32
spv.Return
}
}

0 comments on commit 3a56a96

Please sign in to comment.