104 changes: 38 additions & 66 deletions mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -16,20 +16,16 @@ using namespace mlir::scf;

namespace {

// Unpacks the single unrealized_conversion_cast using the list of inputs
// e.g., return [%b, %c, %d] for %a = unrealized_conversion_cast(%b, %c, %d)
static void unpackUnrealizedConversionCast(Value v,
SmallVectorImpl<Value> &unpacked) {
if (auto cast =
dyn_cast_or_null<UnrealizedConversionCastOp>(v.getDefiningOp())) {
if (cast.getInputs().size() != 1) {
// 1 : N type conversion.
unpacked.append(cast.getInputs().begin(), cast.getInputs().end());
return;
}
}
// 1 : 1 type conversion.
unpacked.push_back(v);
static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
SmallVector<Value> result;
for (const auto &vals : values)
llvm::append_range(result, vals);
return result;
}

static Value getSingleValue(ValueRange values) {
assert(values.size() == 1 && "expected single value");
return values.front();
}
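
With the 1:N conversion driver, every original operand reaches a pattern as a ValueRange carrying however many SSA values the type converter mapped it to, rather than a single value hidden behind an unrealized_conversion_cast. A rough sketch of how the two helpers above are meant to be used inside such a pattern; MyOp and its accessors are hypothetical stand-ins, and 'adaptor' denotes the OneToNOpAdaptor passed to matchAndRewrite:

// Sketch only, assuming a hypothetical MyOp with one scalar operand 'cond'
// and one variadic operand list 'inits'.
//
// A non-variadic operand still arrives as a ValueRange; getSingleValue()
// asserts that the type converter kept it 1:1.
Value cond = getSingleValue(adaptor.getCond());
// A variadic operand arrives as ArrayRef<ValueRange>, one range per original
// value; flattenValues() concatenates them into the flat list builders expect.
SmallVector<Value> flatInits = flattenValues(adaptor.getInits());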

// CRTP
@@ -40,19 +36,21 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
public:
using OpConversionPattern<SourceOp>::typeConverter;
using OpConversionPattern<SourceOp>::OpConversionPattern;
using OpAdaptor = typename OpConversionPattern<SourceOp>::OpAdaptor;
using OneToNOpAdaptor =
typename OpConversionPattern<SourceOp>::OneToNOpAdaptor;

//
// Derived classes should provide the following method which performs the
// actual conversion. It should return std::nullopt upon conversion failure
// and return the converted operation upon success.
//
// std::optional<SourceOp> convertSourceOp(SourceOp op, OpAdaptor adaptor,
// ConversionPatternRewriter &rewriter,
// TypeRange dstTypes) const;
// std::optional<SourceOp> convertSourceOp(
// SourceOp op, OneToNOpAdaptor adaptor,
// ConversionPatternRewriter &rewriter,
// TypeRange dstTypes) const;

LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Type> dstTypes;
SmallVector<unsigned> offsets;
@@ -73,28 +71,15 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
return rewriter.notifyMatchFailure(op, "could not convert operation");

// Pack the return values.
SmallVector<Value> packedRets;
SmallVector<ValueRange> packedRets;
for (unsigned i = 1, e = offsets.size(); i < e; i++) {
unsigned start = offsets[i - 1], end = offsets[i];
unsigned len = end - start;
ValueRange mappedValue = newOp->getResults().slice(start, len);
if (len != 1) {
// 1 : N type conversion.
Type origType = op.getResultTypes()[i - 1];
Value mat = typeConverter->materializeSourceConversion(
rewriter, op.getLoc(), origType, mappedValue);
if (!mat) {
return rewriter.notifyMatchFailure(
op, "Failed to materialize 1:N type conversion");
}
packedRets.push_back(mat);
} else {
// 1 : 1 type conversion.
packedRets.push_back(mappedValue.front());
}
packedRets.push_back(mappedValue);
}

rewriter.replaceOp(op, packedRets);
rewriter.replaceOpWithMultiple(op, packedRets);
return success();
}
};
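
The pattern above leans on ConversionPatternRewriter::replaceOpWithMultiple, which takes one ValueRange per original result, so a result that the converter expanded 1:N no longer needs a materializeSourceConversion back to the original type. A hedged before/after sketch (op, newOp, and the particular split are illustrative only):

// Before: replaceOp required exactly one value per original result, so a
// 1:N expansion had to be packed through a source materialization first.
// After: each original result maps to a whole range of replacement values,
// and the conversion driver records that mapping itself.
SmallVector<ValueRange> replacements;
replacements.push_back(newOp->getResults().slice(0, 2)); // result #0 -> 2 values
replacements.push_back(newOp->getResults().slice(2, 1)); // result #1 -> 1 value
rewriter.replaceOpWithMultiple(op, replacements);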
@@ -105,7 +90,7 @@ class ConvertForOpTypes
using Structural1ToNConversionPattern::Structural1ToNConversionPattern;

// The callback required by CRTP.
std::optional<ForOp> convertSourceOp(ForOp op, OpAdaptor adaptor,
std::optional<ForOp> convertSourceOp(ForOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
TypeRange dstTypes) const {
// Create an empty new op and inline the regions from the old op.
@@ -129,16 +114,13 @@ class ConvertForOpTypes
if (failed(rewriter.convertRegionTypes(&op.getRegion(), *typeConverter)))
return std::nullopt;

// Unpacked the iteration arguments.
SmallVector<Value> flatArgs;
for (Value arg : adaptor.getInitArgs())
unpackUnrealizedConversionCast(arg, flatArgs);

// We cannot simply clone the op because the number of result types after
// conversion might be different.
ForOp newOp = rewriter.create<ForOp>(op.getLoc(), adaptor.getLowerBound(),
adaptor.getUpperBound(),
adaptor.getStep(), flatArgs);
ForOp newOp = rewriter.create<ForOp>(
op.getLoc(), getSingleValue(adaptor.getLowerBound()),
getSingleValue(adaptor.getUpperBound()),
getSingleValue(adaptor.getStep()),
flattenValues(adaptor.getInitArgs()));

// Preserve the attributes of the original op.
newOp->setAttrs(op->getAttrs());
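
The rest of this hunk is collapsed above; the customary continuation for region-carrying ops is to discard the placeholder body the builder created and inline the already signature-converted region of the old op, roughly as follows (a sketch, not the literal collapsed lines):

// Drop the block the ForOp builder inserted, then move the old region over;
// its block arguments were already rewritten by convertRegionTypes above.
rewriter.eraseBlock(newOp.getBody());
rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
                            newOp.getRegion().end());
return newOp;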
@@ -160,12 +142,12 @@ class ConvertIfOpTypes
public:
using Structural1ToNConversionPattern::Structural1ToNConversionPattern;

std::optional<IfOp> convertSourceOp(IfOp op, OpAdaptor adaptor,
std::optional<IfOp> convertSourceOp(IfOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
TypeRange dstTypes) const {

IfOp newOp = rewriter.create<IfOp>(op.getLoc(), dstTypes,
adaptor.getCondition(), true);
IfOp newOp = rewriter.create<IfOp>(
op.getLoc(), dstTypes, getSingleValue(adaptor.getCondition()), true);
newOp->setAttrs(op->getAttrs());

// We do not need the empty blocks created by rewriter.
@@ -189,15 +171,11 @@ class ConvertWhileOpTypes
public:
using Structural1ToNConversionPattern::Structural1ToNConversionPattern;

std::optional<WhileOp> convertSourceOp(WhileOp op, OpAdaptor adaptor,
std::optional<WhileOp> convertSourceOp(WhileOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
TypeRange dstTypes) const {
// Unpacked the iteration arguments.
SmallVector<Value> flatArgs;
for (Value arg : adaptor.getOperands())
unpackUnrealizedConversionCast(arg, flatArgs);

auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes, flatArgs);
auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes,
flattenValues(adaptor.getOperands()));

for (auto i : {0u, 1u}) {
if (failed(rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter)))
@@ -218,13 +196,10 @@ class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
matchAndRewrite(scf::YieldOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Value> unpackedYield;
for (Value operand : adaptor.getOperands())
unpackUnrealizedConversionCast(operand, unpackedYield);

rewriter.replaceOpWithNewOp<scf::YieldOp>(op, unpackedYield);
rewriter.replaceOpWithNewOp<scf::YieldOp>(
op, flattenValues(adaptor.getOperands()));
return success();
}
};
@@ -235,13 +210,10 @@ class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
public:
using OpConversionPattern<ConditionOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ConditionOp op, OpAdaptor adaptor,
matchAndRewrite(ConditionOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Value> unpackedYield;
for (Value operand : adaptor.getOperands())
unpackUnrealizedConversionCast(operand, unpackedYield);

rewriter.modifyOpInPlace(op, [&]() { op->setOperands(unpackedYield); });
rewriter.modifyOpInPlace(
op, [&]() { op->setOperands(flattenValues(adaptor.getOperands())); });
return success();
}
};
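
These structural patterns are wired up through the usual dialect-conversion entry points; once the type converter carries a 1:N rule, the SCF ops above get rewritten without intermediate unrealized_conversion_cast ops. A minimal driver sketch, assuming 'using namespace mlir;'; the tuple-splitting rule and the driver function itself are illustrative, and the populate function's exact signature may differ between MLIR revisions:

void convertSCFStructuralTypes(Operation *root, MLIRContext *ctx) {
  TypeConverter converter;
  // Fallback: leave unrelated types alone.
  converter.addConversion([](Type type) { return type; });
  // Illustrative 1:N rule: expand a tuple into its element types.
  converter.addConversion(
      [](TupleType type, SmallVectorImpl<Type> &results) {
        llvm::append_range(results, type.getTypes());
        return success();
      });

  RewritePatternSet patterns(ctx);
  ConversionTarget target(*ctx);
  scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                       target);
  if (failed(applyPartialConversion(root, target, std::move(patterns))))
    root->emitError("structural type conversion failed");
}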
170 changes: 90 additions & 80 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Large diffs are not rendered by default.

mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
@@ -554,11 +554,6 @@ sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
.getResult();
}

Value sparse_tensor::genValMemSize(OpBuilder &builder, Location loc,
Value tensor) {
return getDescriptorFromTensorTuple(tensor).getValMemSize(builder, loc);
}

Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc,
Value tensor, Dimension dim) {
auto enc = getSparseTensorEncoding(tensor.getType());
3 changes: 0 additions & 3 deletions mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
@@ -270,9 +270,6 @@ void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs,
TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc,
Value tensor);

/// Generates code to retrieve the values size for the sparse tensor.
Value genValMemSize(OpBuilder &builder, Location loc, Value tensor);

/// Generates code to retrieve the slice offset for the sparse tensor slice,
/// return a constant if the offset is statically known.
Value createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc, Value tensor,
mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h
@@ -228,11 +228,6 @@ class MutSparseTensorDescriptor
}
};

/// Returns the "tuple" value of the adapted tensor.
inline UnrealizedConversionCastOp getTuple(Value tensor) {
return llvm::cast<UnrealizedConversionCastOp>(tensor.getDefiningOp());
}

/// Packs the given values as a "tuple" value.
inline Value genTuple(OpBuilder &builder, Location loc, Type tp,
ValueRange values) {
@@ -245,18 +240,17 @@ inline Value genTuple(OpBuilder &builder, Location loc,
return genTuple(builder, loc, desc.getRankedTensorType(), desc.getFields());
}

inline SparseTensorDescriptor getDescriptorFromTensorTuple(Value tensor) {
auto tuple = getTuple(tensor);
SparseTensorType stt(cast<RankedTensorType>(tuple.getResultTypes()[0]));
return SparseTensorDescriptor(stt, tuple.getInputs());
inline SparseTensorDescriptor
getDescriptorFromTensorTuple(ValueRange adaptorValues, RankedTensorType type) {
return SparseTensorDescriptor(SparseTensorType(type), adaptorValues);
}

inline MutSparseTensorDescriptor
getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields) {
auto tuple = getTuple(tensor);
fields.assign(tuple.getInputs().begin(), tuple.getInputs().end());
SparseTensorType stt(cast<RankedTensorType>(tuple.getResultTypes()[0]));
return MutSparseTensorDescriptor(stt, fields);
getMutDescriptorFromTensorTuple(ValueRange adaptorValues,
SmallVectorImpl<Value> &fields,
RankedTensorType type) {
fields.assign(adaptorValues.begin(), adaptorValues.end());
return MutSparseTensorDescriptor(SparseTensorType(type), fields);
}

} // namespace sparse_tensor
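
The callers of these helpers sit in SparseTensorCodegen.cpp, whose diff is collapsed above. With the new signatures a pattern passes the adaptor's ValueRange (the flattened descriptor fields) together with the original tensor type, instead of peeking through an unrealized_conversion_cast. A plausible call shape, with hypothetical op and adaptor accessors:

// Sketch only: 'op' has one sparse-tensor operand 'tensor', and 'adaptor' is
// the pattern's OneToNOpAdaptor, so getTensor() returns the fields that the
// type converter produced for that operand.
auto rtt = cast<RankedTensorType>(op.getTensor().getType());
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(), rtt);
Value valMemSize = desc.getValMemSize(rewriter, op.getLoc());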
219 changes: 150 additions & 69 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp

Large diffs are not rendered by default.

38 changes: 6 additions & 32 deletions mlir/test/Transforms/decompose-call-graph-types.mlir
@@ -9,10 +9,7 @@
// CHECK-LABEL: func @identity(
// CHECK-SAME: %[[ARG0:.*]]: i1,
// CHECK-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) {
// CHECK: %[[ARG_MATERIALIZED:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> tuple<i1, i32>
// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1
// CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32
// CHECK: return %[[RET0]], %[[RET1]] : i1, i32
// CHECK: return %[[ARG0]], %[[ARG1]] : i1, i32
// CHECK-12N-LABEL: func @identity(
// CHECK-12N-SAME: %[[ARG0:.*]]: i1,
// CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) {
@@ -56,18 +53,7 @@ func.func @recursive_decomposition(%arg0: tuple<tuple<tuple<i1>>>) -> tuple<tupl
// CHECK-LABEL: func @mixed_recursive_decomposition(
// CHECK-SAME: %[[ARG0:.*]]: i1,
// CHECK-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) {
// CHECK: %[[V0:.*]] = "test.make_tuple"() : () -> tuple<>
// CHECK: %[[V1:.*]] = "test.make_tuple"(%[[ARG0]]) : (i1) -> tuple<i1>
// CHECK: %[[V2:.*]] = "test.make_tuple"(%[[ARG1]]) : (i2) -> tuple<i2>
// CHECK: %[[V3:.*]] = "test.make_tuple"(%[[V2]]) : (tuple<i2>) -> tuple<tuple<i2>>
// CHECK: %[[V4:.*]] = "test.make_tuple"(%[[V0]], %[[V1]], %[[V3]]) : (tuple<>, tuple<i1>, tuple<tuple<i2>>) -> tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>
// CHECK: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 0 : i32}> : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<>
// CHECK: %[[V6:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 1 : i32}> : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<i1>
// CHECK: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 0 : i32}> : (tuple<i1>) -> i1
// CHECK: %[[V8:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 2 : i32}> : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
// CHECK: %[[V9:.*]] = "test.get_tuple_element"(%[[V8]]) <{index = 0 : i32}> : (tuple<tuple<i2>>) -> tuple<i2>
// CHECK: %[[V10:.*]] = "test.get_tuple_element"(%[[V9]]) <{index = 0 : i32}> : (tuple<i2>) -> i2
// CHECK: return %[[V7]], %[[V10]] : i1, i2
// CHECK: return %[[ARG0]], %[[ARG1]] : i1, i2
// CHECK-12N-LABEL: func @mixed_recursive_decomposition(
// CHECK-12N-SAME: %[[ARG0:.*]]: i1,
// CHECK-12N-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) {
@@ -87,14 +73,8 @@ func.func private @callee(tuple<i1, i32>) -> tuple<i1, i32>
// CHECK-LABEL: func @caller(
// CHECK-SAME: %[[ARG0:.*]]: i1,
// CHECK-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) {
// CHECK: %[[ARG_MATERIALIZED:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> tuple<i1, i32>
// CHECK: %[[CALL_ARG0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1
// CHECK: %[[CALL_ARG1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32
// CHECK: %[[DECOMPOSED:.*]]:2 = call @callee(%[[CALL_ARG0]], %[[CALL_ARG1]]) : (i1, i32) -> (i1, i32)
// CHECK: %[[CALL_RESULT_RECOMPOSED:.*]] = "test.make_tuple"(%[[DECOMPOSED]]#0, %[[DECOMPOSED]]#1) : (i1, i32) -> tuple<i1, i32>
// CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1
// CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32
// CHECK: return %[[RET0]], %[[RET1]] : i1, i32
// CHECK: %[[V0:.*]]:2 = call @callee(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> (i1, i32)
// CHECK: return %[[V0]]#0, %[[V0]]#1 : i1, i32
// CHECK-12N-LABEL: func @caller(
// CHECK-12N-SAME: %[[ARG0:.*]]: i1,
// CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) {
@@ -190,14 +170,8 @@ func.func private @callee(tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6) -> (tup
// CHECK-SAME: %[[I4:.*]]: i4,
// CHECK-SAME: %[[I5:.*]]: i5,
// CHECK-SAME: %[[I6:.*]]: i6) -> (i1, i2, i3, i4, i5, i6) {
// CHECK: %[[ARG_TUPLE:.*]] = "test.make_tuple"(%[[I4]], %[[I5]]) : (i4, i5) -> tuple<i4, i5>
// CHECK: %[[ARG_TUPLE_0:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) <{index = 0 : i32}> : (tuple<i4, i5>) -> i4
// CHECK: %[[ARG_TUPLE_1:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) <{index = 1 : i32}> : (tuple<i4, i5>) -> i5
// CHECK: %[[CALL:.*]]:6 = call @callee(%[[I1]], %[[I2]], %[[I3]], %[[ARG_TUPLE_0]], %[[ARG_TUPLE_1]], %[[I6]]) : (i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6)
// CHECK: %[[RET_TUPLE:.*]] = "test.make_tuple"(%[[CALL]]#3, %[[CALL]]#4) : (i4, i5) -> tuple<i4, i5>
// CHECK: %[[RET_TUPLE_0:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) <{index = 0 : i32}> : (tuple<i4, i5>) -> i4
// CHECK: %[[RET_TUPLE_1:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) <{index = 1 : i32}> : (tuple<i4, i5>) -> i5
// CHECK: return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[RET_TUPLE_0]], %[[RET_TUPLE_1]], %[[CALL]]#5 : i1, i2, i3, i4, i5, i6
// CHECK: %[[CALL:.*]]:6 = call @callee(%[[I1]], %[[I2]], %[[I3]], %[[I4]], %[[I5]], %[[I6]]) : (i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6)
// CHECK: return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[CALL]]#3, %[[CALL]]#4, %[[CALL]]#5 : i1, i2, i3, i4, i5, i6
// CHECK-12N-LABEL: func @caller(
// CHECK-12N-SAME: %[[I1:.*]]: i1,
// CHECK-12N-SAME: %[[I2:.*]]: i2,