Skip to content

Commit

Permalink
[MLIR][Shape] Remove type conversion from lowering to standard
Browse files Browse the repository at this point in the history
Operating on indices and extent tensors directly, the type conversion is no
longer needed for the supported cases.

Differential Revision: https://reviews.llvm.org/D84442
  • Loading branch information
frgossen committed Jul 29, 2020
1 parent 5d9f33a commit b6b9d3e
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 61 deletions.
33 changes: 3 additions & 30 deletions mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
Expand Up @@ -219,25 +219,6 @@ RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef<Value> operands,
return success();
}

namespace {
/// Type conversions.
class ShapeTypeConverter : public TypeConverter {
public:
using TypeConverter::convertType;

ShapeTypeConverter(MLIRContext *ctx) {
// Add default pass-through conversion.
addConversion([&](Type type) { return type; });

addConversion([ctx](SizeType type) { return IndexType::get(ctx); });
addConversion([ctx](ShapeType type) {
return RankedTensorType::get({ShapedType::kDynamicSize},
IndexType::get(ctx));
});
}
};
} // namespace

namespace {
/// Conversion pass.
class ConvertShapeToStandardPass
Expand All @@ -248,23 +229,15 @@ class ConvertShapeToStandardPass
} // namespace

void ConvertShapeToStandardPass::runOnOperation() {
// Setup type conversion.
MLIRContext &ctx = getContext();
ShapeTypeConverter typeConverter(&ctx);

// Setup target legality.
MLIRContext &ctx = getContext();
ConversionTarget target(ctx);
target.addLegalDialect<scf::SCFDialect, StandardOpsDialect>();
target.addLegalOp<ModuleOp, ModuleTerminatorOp, ReturnOp>();
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
return typeConverter.isSignatureLegal(op.getType()) &&
typeConverter.isLegal(&op.getBody());
});
target.addLegalDialect<StandardOpsDialect>();
target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp>();

// Setup conversion patterns.
OwningRewritePatternList patterns;
populateShapeToStandardConversionPatterns(patterns, &ctx);
populateFuncOpTypeConversionPattern(patterns, &ctx, typeConverter);

// Apply conversion.
auto module = getOperation();
Expand Down
33 changes: 2 additions & 31 deletions mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -1,40 +1,11 @@
// RUN: mlir-opt --split-input-file --convert-shape-to-std --verify-diagnostics %s | FileCheck %s

// Convert `size` to `index` type.
// CHECK-LABEL: @size_id
// CHECK-SAME: (%[[SIZE:.*]]: index)
func @size_id(%size : !shape.size) -> !shape.size {
// CHECK: return %[[SIZE]] : index
return %size : !shape.size
}

// -----

// Convert `shape` to `tensor<?xindex>` type.
// CHECK-LABEL: @shape_id
// CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>)
func @shape_id(%shape : !shape.shape) -> !shape.shape {
// CHECK: return %[[SHAPE]] : tensor<?xindex>
return %shape : !shape.shape
}

// -----

// Lower binary ops.
// CHECK-LABEL: @binary_ops
// CHECK-SAME: (%[[LHS:.*]]: index, %[[RHS:.*]]: index)
func @binary_ops(%lhs : !shape.size, %rhs : !shape.size) {
// CHECK: addi %[[LHS]], %[[RHS]] : index
%sum = "shape.add"(%lhs, %rhs) : (!shape.size, !shape.size) -> !shape.size
return
}

// -----

// Lower binary ops.
// CHECK-LABEL: @binary_ops
// CHECK-SAME: (%[[LHS:.*]]: index, %[[RHS:.*]]: index)
func @binary_ops(%lhs : index, %rhs : index) {
// CHECK: addi %[[LHS]], %[[RHS]] : index
%sum = shape.add %lhs, %rhs : index, index -> index
// CHECK: muli %[[LHS]], %[[RHS]] : index
%product = shape.mul %lhs, %rhs : index, index -> index
return
Expand Down

0 comments on commit b6b9d3e

Please sign in to comment.