Skip to content

Commit

Permalink
[mlir][Python] Finish adding RankedTensorType support for encoding.
Browse files Browse the repository at this point in the history
Differential Revision: https://reviews.llvm.org/D102184
  • Loading branch information
stellaraccident committed May 10, 2021
1 parent 463ea28 commit a2c8aeb
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 4 deletions.
4 changes: 4 additions & 0 deletions mlir/include/mlir-c/BuiltinTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,10 @@ MLIR_CAPI_EXPORTED MlirType mlirRankedTensorTypeGetChecked(
MlirLocation loc, intptr_t rank, const int64_t *shape, MlirType elementType,
MlirAttribute encoding);

/// Gets the 'encoding' attribute from the ranked tensor type, returning a null
/// attribute if none.
MLIR_CAPI_EXPORTED MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type);

/// Creates an unranked tensor type with the given element type in the same
/// context as the element type. The type is owned by the context.
MLIR_CAPI_EXPORTED MlirType mlirUnrankedTensorTypeGet(MlirType elementType);
Expand Down
16 changes: 13 additions & 3 deletions mlir/lib/Bindings/Python/IRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -338,10 +338,11 @@ class PyRankedTensorType
c.def_static(
"get",
[](std::vector<int64_t> shape, PyType &elementType,
llvm::Optional<PyAttribute> &encodingAttr,
DefaultingPyLocation loc) {
MlirAttribute encodingAttr = mlirAttributeGetNull();
MlirType t = mlirRankedTensorTypeGetChecked(
loc, shape.size(), shape.data(), elementType, encodingAttr);
loc, shape.size(), shape.data(), elementType,
encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {
Expand All @@ -355,8 +356,17 @@ class PyRankedTensorType
}
return PyRankedTensorType(elementType.getContext(), t);
},
py::arg("shape"), py::arg("element_type"), py::arg("loc") = py::none(),
py::arg("shape"), py::arg("element_type"),
py::arg("encoding") = py::none(), py::arg("loc") = py::none(),
"Create a ranked tensor type");
c.def_property_readonly(
"encoding",
[](PyRankedTensorType &self) -> llvm::Optional<PyAttribute> {
MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
if (mlirAttributeIsNull(encoding))
return llvm::None;
return PyAttribute(self.getContext(), encoding);
});
}
};

Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/CAPI/IR/BuiltinTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,10 @@ MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank,
unwrap(elementType), unwrap(encoding)));
}

MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type) {
return wrap(unwrap(type).cast<RankedTensorType>().getEncoding());
}

MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
return wrap(UnrankedTensorType::get(unwrap(elementType)));
}
Expand Down
3 changes: 2 additions & 1 deletion mlir/test/CAPI/ir.c
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,8 @@ static int printBuiltinTypes(MlirContext ctx) {
MlirType rankedTensor = mlirRankedTensorTypeGet(
sizeof(shape) / sizeof(int64_t), shape, f32, mlirAttributeGetNull());
if (!mlirTypeIsATensor(rankedTensor) ||
!mlirTypeIsARankedTensor(rankedTensor))
!mlirTypeIsARankedTensor(rankedTensor) ||
!mlirAttributeIsNull(mlirRankedTensorTypeGetEncoding(rankedTensor)))
return 16;
mlirTypeDump(rankedTensor);
fprintf(stderr, "\n");
Expand Down
15 changes: 15 additions & 0 deletions mlir/test/python/dialects/sparse_tensor/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,18 @@ def testEncodingAttr2D():
print(created)
# CHECK: created_equal: True
print(f"created_equal: {created == casted}")


# CHECK-LABEL: TEST: testEncodingAttrOnTensor
@run
def testEncodingAttrOnTensor():
with Context() as ctx, Location.unknown():
encoding = st.EncodingAttr(Attribute.parse(
'#sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], '
'pointerBitWidth = 16, indexBitWidth = 32 }>'))
tt = RankedTensorType.get((1024,), F32Type.get(), encoding=encoding)
# CHECK: tensor<1024xf32, #sparse_tensor
print(tt)
# CHECK: #sparse_tensor.encoding
print(tt.encoding)
assert tt.encoding == encoding
3 changes: 3 additions & 0 deletions mlir/test/python/ir/builtin_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,9 @@ def testRankedTensorType():
else:
print("Exception not produced")

# Encoding should be None.
assert RankedTensorType.get(shape, f32).encoding is None


# CHECK-LABEL: TEST: testUnrankedTensorType
@run
Expand Down

0 comments on commit a2c8aeb

Please sign in to comment.