Skip to content

Commit

Permalink
Add operations needed to support lowering of AffineExpr to SPIR-V.
Browse files Browse the repository at this point in the history
Lowering of CmpIOp, DivISOp, RemISOp, SubIOp and SelectOp to SPIR-V
dialect enables the lowering of operations generated by AffineExpr ->
StandardOps conversion into the SPIR-V dialect.

PiperOrigin-RevId: 280039204
  • Loading branch information
Mahesh Ravishankar authored and tensorflower-gardener committed Nov 12, 2019
1 parent 8082e3a commit 2be5360
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 15 deletions.
86 changes: 72 additions & 14 deletions mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
Expand Up @@ -314,8 +314,9 @@ class ConstantIndexOpConversion final : public ConversionPattern {
return matchFailure();
}

// Use the bitwidth set in the value attribute to decide the result type of
// the SPIR-V constant operation since SPIR-V does not support index types.
// Use the bitwidth set in the value attribute to decide the result type
// of the SPIR-V constant operation since SPIR-V does not support index
// types.
auto constVal = constAttr.getValue();
auto constValType = constAttr.getType().dyn_cast<IndexType>();
if (!constValType) {
Expand All @@ -331,11 +332,47 @@ class ConstantIndexOpConversion final : public ConversionPattern {
}
};

/// Convert integer binary operations to SPIR-V operations. Cannot use tablegen
/// for this. If the integer operation is on variables of IndexType, the type of
/// the return value of the replacement operation differs from that of the
/// replaced operation. This is not handled in tablegen-based pattern
/// specification.
/// Convert compare operation to SPIR-V dialect.
class CmpIOpConversion final : public ConversionPattern {
public:
CmpIOpConversion(MLIRContext *context)
: ConversionPattern(CmpIOp::getOperationName(), 1, context) {}

PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto cmpIOp = cast<CmpIOp>(op);
CmpIOpOperandAdaptor cmpIOpOperands(operands);

switch (cmpIOp.getPredicate()) {
#define DISPATCH(cmpPredicate, spirvOp) \
case cmpPredicate: \
rewriter.replaceOpWithNewOp<spirvOp>(op, op->getResult(0)->getType(), \
cmpIOpOperands.lhs(), \
cmpIOpOperands.rhs()); \
return matchSuccess();

DISPATCH(CmpIPredicate::EQ, spirv::IEqualOp);
DISPATCH(CmpIPredicate::NE, spirv::INotEqualOp);
DISPATCH(CmpIPredicate::SLT, spirv::SLessThanOp);
DISPATCH(CmpIPredicate::SLE, spirv::SLessThanEqualOp);
DISPATCH(CmpIPredicate::SGT, spirv::SGreaterThanOp);
DISPATCH(CmpIPredicate::SGE, spirv::SGreaterThanEqualOp);

#undef DISPATCH

default:
break;
}
return matchFailure();
}
};

/// Convert integer binary operations to SPIR-V operations. Cannot use
/// tablegen for this. If the integer operation is on variables of IndexType,
/// the type of the return value of the replacement operation differs from
/// that of the replaced operation. This is not handled in tablegen-based
/// pattern specification.
template <typename StdOp, typename SPIRVOp>
class IntegerOpConversion final : public ConversionPattern {
public:
Expand Down Expand Up @@ -396,9 +433,25 @@ class ReturnToSPIRVConversion : public ConversionPattern {
}
};

/// Convert store -> spv.StoreOp. The operands of the replaced operation are of
/// IndexType while that of the replacement operation are of type i32. This is
/// not supported in tablegen based pattern specification.
/// Convert select -> spv.Select
class SelectOpConversion : public ConversionPattern {
public:
SelectOpConversion(MLIRContext *context)
: ConversionPattern(SelectOp::getOperationName(), 1, context) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
SelectOpOperandAdaptor selectOperands(operands);
rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, selectOperands.condition(),
selectOperands.true_value(),
selectOperands.false_value());
return matchSuccess();
}
};

/// Convert store -> spv.StoreOp. The operands of the replaced operation are
/// of IndexType while that of the replacement operation are of type i32. This
/// is not supported in tablegen based pattern specification.
// TODO(ravishankarm) : These could potentially be templated on the operation
// being converted, since the same logic should work for linalg.store.
class StoreOpConversion final : public ConversionPattern {
Expand Down Expand Up @@ -437,9 +490,14 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
OwningRewritePatternList &patterns) {
populateWithGenerated(context, &patterns);
// Add the return op conversion.
patterns.insert<ConstantIndexOpConversion,
IntegerOpConversion<AddIOp, spirv::IAddOp>,
IntegerOpConversion<MulIOp, spirv::IMulOp>, LoadOpConversion,
ReturnToSPIRVConversion, StoreOpConversion>(context);
patterns
.insert<ConstantIndexOpConversion, CmpIOpConversion,
IntegerOpConversion<AddIOp, spirv::IAddOp>,
IntegerOpConversion<MulIOp, spirv::IMulOp>,
IntegerOpConversion<DivISOp, spirv::SDivOp>,
IntegerOpConversion<RemISOp, spirv::SModOp>,
IntegerOpConversion<SubIOp, spirv::ISubOp>, LoadOpConversion,
ReturnToSPIRVConversion, SelectOpConversion, StoreOpConversion>(
context);
}
} // namespace mlir
45 changes: 44 additions & 1 deletion mlir/test/Conversion/StandardToSPIRV/op_conversion.mlir
Expand Up @@ -57,4 +57,47 @@ func @constval() {
// CHECK: spv.constant 1 : i32
%4 = constant 1 : index
return
}
}

// CHECK-LABEL: @cmpiop
func @cmpiop(%arg0 : i32, %arg1 : i32) {
// CHECK: spv.IEqual
%0 = cmpi "eq", %arg0, %arg1 : i32
// CHECK: spv.INotEqual
%1 = cmpi "ne", %arg0, %arg1 : i32
// CHECK: spv.SLessThan
%2 = cmpi "slt", %arg0, %arg1 : i32
// CHECK: spv.SLessThanEqual
%3 = cmpi "sle", %arg0, %arg1 : i32
// CHECK: spv.SGreaterThan
%4 = cmpi "sgt", %arg0, %arg1 : i32
// CHECK: spv.SGreaterThanEqual
%5 = cmpi "sge", %arg0, %arg1 : i32
return
}

// CHECK-LABEL: @select
func @selectOp(%arg0 : i32, %arg1 : i32) {
%0 = cmpi "sle", %arg0, %arg1 : i32
// CHECK: spv.Select
%1 = select %0, %arg0, %arg1 : i32
return
}

// CHECK-LABEL: @div_rem
func @div_rem(%arg0 : i32, %arg1 : i32) {
// CHECK: spv.SDiv
%0 = divis %arg0, %arg1 : i32
// CHECK: spv.SMod
%1 = remis %arg0, %arg1 : i32
return
}

// CHECK-LABEL: @add_sub
func @add_sub(%arg0 : i32, %arg1 : i32) {
// CHECK: spv.IAdd
%0 = addi %arg0, %arg1 : i32
// CHECK: spv.ISub
%1 = subi %arg0, %arg1 : i32
return
}

0 comments on commit 2be5360

Please sign in to comment.