Skip to content

Commit

Permalink
[mlir][python] Adapt to segment_sizes attribute type change.
Browse files Browse the repository at this point in the history
  • Loading branch information
stellaraccident committed Mar 20, 2021
1 parent f380066 commit 8d05a28
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
12 changes: 6 additions & 6 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1034,8 +1034,8 @@ PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");

std::vector<uint64_t> operandSegmentLengths;
std::vector<uint64_t> resultSegmentLengths;
std::vector<uint32_t> operandSegmentLengths;
std::vector<uint32_t> resultSegmentLengths;

// Validate/determine region count.
auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
Expand Down Expand Up @@ -1247,8 +1247,8 @@ PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
// Add result_segment_sizes attribute.
if (!resultSegmentLengths.empty()) {
int64_t size = resultSegmentLengths.size();
MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt64Get(
mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 64)),
MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
resultSegmentLengths.size(), resultSegmentLengths.data());
(*attributes)["result_segment_sizes"] =
PyAttribute(context, segmentLengthAttr);
Expand All @@ -1257,8 +1257,8 @@ PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
// Add operand_segment_sizes attribute.
if (!operandSegmentLengths.empty()) {
int64_t size = operandSegmentLengths.size();
MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt64Get(
mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 64)),
MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
operandSegmentLengths.size(), operandSegmentLengths.data());
(*attributes)["operand_segment_sizes"] =
PyAttribute(context, segmentLengthAttr);
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Bindings/Python/ods_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ class TestOp(OpView):
# CHECK: %[[V2:.+]] = "custom.value"
# CHECK: %[[V3:.+]] = "custom.value"
# CHECK: "custom.test_op"(%[[V0]], %[[V1]], %[[V2]], %[[V3]])
# CHECK-SAME: operand_segment_sizes = dense<[1, 2, 1]> : vector<3xi64>
# CHECK-SAME: result_segment_sizes = dense<[2, 1, 1]> : vector<3xi64>
# CHECK-SAME: operand_segment_sizes = dense<[1, 2, 1]> : vector<3xi32>
# CHECK-SAME: result_segment_sizes = dense<[2, 1, 1]> : vector<3xi32>
# CHECK-SAME: : (i32, i32, i32, i32) -> (i8, i16, i32, i64)
op = TestOp.build_generic(
results=[[t0, t1], t2, t3],
Expand Down

0 comments on commit 8d05a28

Please sign in to comment.