Skip to content

Commit

Permalink
Revert "[mlir][spirv] Lower arith overflow flags to corresponding S…
Browse files Browse the repository at this point in the history
…PIR-V op decorations (llvm#77714)"

Temporaryly reverting as it broke python bindings

This reverts commit 4278d9b.
  • Loading branch information
Hardcode84 authored and justinfargnoli committed Jan 28, 2024
1 parent 7ff779c commit 8b507ae
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 96 deletions.
59 changes: 3 additions & 56 deletions mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
Expand Up @@ -158,61 +158,8 @@ getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) {
return getTypeConversionFailure(rewriter, op, op->getResultTypes().front());
}

// TODO: Move to some common place?
static std::string getDecorationString(spirv::Decoration decor) {
return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decor));
}

namespace {

/// Converts elementwise unary, binary and ternary arith operations to SPIR-V
/// operations. Op can potentially support overflow flags.
template <typename Op, typename SPIRVOp>
struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
using OpConversionPattern<Op>::OpConversionPattern;

LogicalResult
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
assert(adaptor.getOperands().size() <= 3);
auto converter = this->template getTypeConverter<SPIRVTypeConverter>();
Type dstType = converter->convertType(op.getType());
if (!dstType) {
return rewriter.notifyMatchFailure(
op->getLoc(),
llvm::formatv("failed to convert type {0} for SPIR-V", op.getType()));
}

if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
!getElementTypeOrSelf(op.getType()).isIndex() &&
dstType != op.getType()) {
return op.emitError("bitwidth emulation is not implemented yet on "
"unsigned op pattern version");
}

auto overflowFlags = arith::IntegerOverflowFlags::none;
if (auto overflowIface =
dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
if (converter->getTargetEnv().allows(
spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
overflowFlags = overflowIface.getOverflowAttr().getValue();
}

auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
op, dstType, adaptor.getOperands());

if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
newOp->setAttr(getDecorationString(spirv::Decoration::NoSignedWrap),
rewriter.getUnitAttr());

if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
newOp->setAttr(getDecorationString(spirv::Decoration::NoUnsignedWrap),
rewriter.getUnitAttr());

return success();
}
};

//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1207,9 +1154,9 @@ void mlir::arith::populateArithToSPIRVPatterns(
patterns.add<
ConstantCompositeOpPattern,
ConstantScalarOpPattern,
ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
spirv::ElementwiseOpPattern<arith::AddIOp, spirv::IAddOp>,
spirv::ElementwiseOpPattern<arith::SubIOp, spirv::ISubOp>,
spirv::ElementwiseOpPattern<arith::MulIOp, spirv::IMulOp>,
spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>,
spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>,
spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>,
Expand Down
40 changes: 0 additions & 40 deletions mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
Expand Up @@ -1407,43 +1407,3 @@ func.func @float_scalar(%arg0: f16) {
}

} // end module

// -----

module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Kernel], [SPV_KHR_no_integer_wrap_decoration]>, #spirv.resource_limits<>>
} {

// CHECK-LABEL: @ops_flags
func.func @ops_flags(%arg0: i64, %arg1: i64) {
// CHECK: %{{.*}} = spirv.IAdd %{{.*}}, %{{.*}} {no_signed_wrap} : i64
%0 = arith.addi %arg0, %arg1 overflow<nsw> : i64
// CHECK: %{{.*}} = spirv.ISub %{{.*}}, %{{.*}} {no_unsigned_wrap} : i64
%1 = arith.subi %arg0, %arg1 overflow<nuw> : i64
// CHECK: %{{.*}} = spirv.IMul %{{.*}}, %{{.*}} {no_signed_wrap, no_unsigned_wrap} : i64
%2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
return
}

} // end module


// -----

module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64], []>, #spirv.resource_limits<>>
} {

// No decorations should be generated is corresponding Extensions/Capabilities are missing
// CHECK-LABEL: @ops_flags
func.func @ops_flags(%arg0: i64, %arg1: i64) {
// CHECK: %{{.*}} = spirv.IAdd %{{.*}}, %{{.*}} : i64
%0 = arith.addi %arg0, %arg1 overflow<nsw> : i64
// CHECK: %{{.*}} = spirv.ISub %{{.*}}, %{{.*}} : i64
%1 = arith.subi %arg0, %arg1 overflow<nuw> : i64
// CHECK: %{{.*}} = spirv.IMul %{{.*}}, %{{.*}} : i64
%2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
return
}

} // end module

0 comments on commit 8b507ae

Please sign in to comment.