Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 31 additions & 15 deletions mlir/lib/Bindings/Python/DialectLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,37 @@ static void populateDialectLLVMSubmodule(const nanobind::module_ &m) {
auto llvmStructType =
mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType);

llvmStructType.def_classmethod(
"get_literal",
[](const nb::object &cls, const std::vector<MlirType> &elements,
bool packed, MlirLocation loc) {
CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));

MlirType type = mlirLLVMStructTypeLiteralGetChecked(
loc, elements.size(), elements.data(), packed);
if (mlirTypeIsNull(type)) {
throw nb::value_error(scope.takeMessage().c_str());
}
return cls(type);
},
"cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
"loc"_a = nb::none());
llvmStructType
.def_classmethod(
"get_literal",
[](const nb::object &cls, const std::vector<MlirType> &elements,
bool packed, MlirLocation loc) {
CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));

MlirType type = mlirLLVMStructTypeLiteralGetChecked(
loc, elements.size(), elements.data(), packed);
if (mlirTypeIsNull(type)) {
throw nb::value_error(scope.takeMessage().c_str());
}
return cls(type);
},
"cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
"loc"_a = nb::none())
.def_classmethod(
"get_literal_unchecked",
[](const nb::object &cls, const std::vector<MlirType> &elements,
bool packed, MlirContext context) {
CollectDiagnosticsToStringScope scope(context);

MlirType type = mlirLLVMStructTypeLiteralGet(
context, elements.size(), elements.data(), packed);
if (mlirTypeIsNull(type)) {
throw nb::value_error(scope.takeMessage().c_str());
}
return cls(type);
},
"cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
"context"_a = nb::none());

llvmStructType.def_classmethod(
"get_identified",
Expand Down
12 changes: 12 additions & 0 deletions mlir/lib/Bindings/Python/IRAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,18 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
},
nb::arg("type"), nb::arg("value"), nb::arg("loc") = nb::none(),
"Gets an uniqued float point attribute associated to a type");
c.def_static(
"get_unchecked",
[](PyType &type, double value, DefaultingPyMlirContext context) {
PyMlirContext::ErrorCapture errors(context->getRef());
MlirAttribute attr =
mlirFloatAttrDoubleGet(context.get()->get(), type, value);
if (mlirAttributeIsNull(attr))
throw MLIRError("Invalid attribute", errors.take());
return PyFloatAttribute(type.getContext(), attr);
},
nb::arg("type"), nb::arg("value"), nb::arg("context") = nb::none(),
"Gets an uniqued float point attribute associated to a type");
c.def_static(
"get_f32",
[](double value, DefaultingPyMlirContext context) {
Expand Down
116 changes: 111 additions & 5 deletions mlir/lib/Bindings/Python/IRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -639,11 +639,16 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
using PyConcreteType::PyConcreteType;

static void bindDerived(ClassTy &c) {
c.def_static("get", &PyVectorType::get, nb::arg("shape"),
c.def_static("get", &PyVectorType::getChecked, nb::arg("shape"),
nb::arg("element_type"), nb::kw_only(),
nb::arg("scalable") = nb::none(),
nb::arg("scalable_dims") = nb::none(),
nb::arg("loc") = nb::none(), "Create a vector type")
.def_static("get_unchecked", &PyVectorType::get, nb::arg("shape"),
nb::arg("element_type"), nb::kw_only(),
nb::arg("scalable") = nb::none(),
nb::arg("scalable_dims") = nb::none(),
nb::arg("context") = nb::none(), "Create a vector type")
.def_prop_ro(
"scalable",
[](MlirType self) { return mlirVectorTypeIsScalable(self); })
Expand All @@ -658,10 +663,11 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
}

private:
static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
std::optional<nb::list> scalable,
std::optional<std::vector<int64_t>> scalableDims,
DefaultingPyLocation loc) {
static PyVectorType
getChecked(std::vector<int64_t> shape, PyType &elementType,
std::optional<nb::list> scalable,
std::optional<std::vector<int64_t>> scalableDims,
DefaultingPyLocation loc) {
if (scalable && scalableDims) {
throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
"are mutually exclusive.");
Expand Down Expand Up @@ -696,6 +702,42 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
throw MLIRError("Invalid type", errors.take());
return PyVectorType(elementType.getContext(), type);
}

static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
std::optional<nb::list> scalable,
std::optional<std::vector<int64_t>> scalableDims,
DefaultingPyMlirContext context) {
if (scalable && scalableDims) {
throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
"are mutually exclusive.");
}

PyMlirContext::ErrorCapture errors(context->getRef());
MlirType type;
if (scalable) {
if (scalable->size() != shape.size())
throw nb::value_error("Expected len(scalable) == len(shape).");

SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
*scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
scalableDimFlags.data(), elementType);
} else if (scalableDims) {
SmallVector<bool> scalableDimFlags(shape.size(), false);
for (int64_t dim : *scalableDims) {
if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
throw nb::value_error("Scalable dimension index out of bounds.");
scalableDimFlags[dim] = true;
}
type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
scalableDimFlags.data(), elementType);
} else {
type = mlirVectorTypeGet(shape.size(), shape.data(), elementType);
}
if (mlirTypeIsNull(type))
throw MLIRError("Invalid type", errors.take());
return PyVectorType(elementType.getContext(), type);
}
};

/// Ranked Tensor Type subclass - RankedTensorType.
Expand Down Expand Up @@ -724,6 +766,22 @@ class PyRankedTensorType
nb::arg("shape"), nb::arg("element_type"),
nb::arg("encoding") = nb::none(), nb::arg("loc") = nb::none(),
"Create a ranked tensor type");
c.def_static(
"get_unchecked",
[](std::vector<int64_t> shape, PyType &elementType,
std::optional<PyAttribute> &encodingAttr,
DefaultingPyMlirContext context) {
PyMlirContext::ErrorCapture errors(context->getRef());
MlirType t = mlirRankedTensorTypeGet(
shape.size(), shape.data(), elementType,
encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
if (mlirTypeIsNull(t))
throw MLIRError("Invalid type", errors.take());
return PyRankedTensorType(elementType.getContext(), t);
},
nb::arg("shape"), nb::arg("element_type"),
nb::arg("encoding") = nb::none(), nb::arg("context") = nb::none(),
"Create a ranked tensor type");
c.def_prop_ro(
"encoding",
[](PyRankedTensorType &self)
Expand Down Expand Up @@ -758,6 +816,17 @@ class PyUnrankedTensorType
},
nb::arg("element_type"), nb::arg("loc") = nb::none(),
"Create a unranked tensor type");
c.def_static(
"get_unchecked",
[](PyType &elementType, DefaultingPyMlirContext context) {
PyMlirContext::ErrorCapture errors(context->getRef());
MlirType t = mlirUnrankedTensorTypeGet(elementType);
if (mlirTypeIsNull(t))
throw MLIRError("Invalid type", errors.take());
return PyUnrankedTensorType(elementType.getContext(), t);
},
nb::arg("element_type"), nb::arg("context") = nb::none(),
"Create a unranked tensor type");
}
};

Expand Down Expand Up @@ -790,6 +859,27 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
nb::arg("shape"), nb::arg("element_type"),
nb::arg("layout") = nb::none(), nb::arg("memory_space") = nb::none(),
nb::arg("loc") = nb::none(), "Create a memref type")
.def_static(
"get_unchecked",
[](std::vector<int64_t> shape, PyType &elementType,
PyAttribute *layout, PyAttribute *memorySpace,
DefaultingPyMlirContext context) {
PyMlirContext::ErrorCapture errors(context->getRef());
MlirAttribute layoutAttr =
layout ? *layout : mlirAttributeGetNull();
MlirAttribute memSpaceAttr =
memorySpace ? *memorySpace : mlirAttributeGetNull();
MlirType t =
mlirMemRefTypeGet(elementType, shape.size(), shape.data(),
layoutAttr, memSpaceAttr);
if (mlirTypeIsNull(t))
throw MLIRError("Invalid type", errors.take());
return PyMemRefType(elementType.getContext(), t);
},
nb::arg("shape"), nb::arg("element_type"),
nb::arg("layout") = nb::none(),
nb::arg("memory_space") = nb::none(),
nb::arg("context") = nb::none(), "Create a memref type")
.def_prop_ro(
"layout",
[](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> {
Expand Down Expand Up @@ -858,6 +948,22 @@ class PyUnrankedMemRefType
},
nb::arg("element_type"), nb::arg("memory_space").none(),
nb::arg("loc") = nb::none(), "Create a unranked memref type")
.def_static(
"get_unchecked",
[](PyType &elementType, PyAttribute *memorySpace,
DefaultingPyMlirContext context) {
PyMlirContext::ErrorCapture errors(context->getRef());
MlirAttribute memSpaceAttr = {};
if (memorySpace)
memSpaceAttr = *memorySpace;

MlirType t = mlirUnrankedMemRefTypeGet(elementType, memSpaceAttr);
if (mlirTypeIsNull(t))
throw MLIRError("Invalid type", errors.take());
return PyUnrankedMemRefType(elementType.getContext(), t);
},
nb::arg("element_type"), nb::arg("memory_space").none(),
nb::arg("context") = nb::none(), "Create a unranked memref type")
.def_prop_ro(
"memory_space",
[](PyUnrankedMemRefType &self)
Expand Down
11 changes: 8 additions & 3 deletions mlir/test/python/ir/builtin_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,11 +371,16 @@ def testAbstractShapedType():
# CHECK-LABEL: TEST: testVectorType
@run
def testVectorType():
shape = [2, 3]
with Context():
f32 = F32Type.get()
# CHECK: unchecked vector type: vector<2x3xf32>
print("unchecked vector type:", VectorType.get_unchecked(shape, f32))

with Context(), Location.unknown():
f32 = F32Type.get()
shape = [2, 3]
# CHECK: vector type: vector<2x3xf32>
print("vector type:", VectorType.get(shape, f32))
# CHECK: checked vector type: vector<2x3xf32>
print("checked vector type:", VectorType.get(shape, f32))

none = NoneType.get()
try:
Expand Down