Skip to content

Commit

Permalink
[mlir][capi] Add TypeID to MLIR C-API
Browse files Browse the repository at this point in the history
Exposes mlir::TypeID to the C API as MlirTypeID along with various accessors
and helper functions.

Differential Revision: https://reviews.llvm.org/D110897
  • Loading branch information
trilorez committed Oct 1, 2021
1 parent 4cdee8d commit 782a97a
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 0 deletions.
27 changes: 27 additions & 0 deletions mlir/include/mlir-c/IR.h
Expand Up @@ -60,6 +60,7 @@ DEFINE_C_API_STRUCT(MlirIdentifier, const void);
DEFINE_C_API_STRUCT(MlirLocation, const void);
DEFINE_C_API_STRUCT(MlirModule, const void);
DEFINE_C_API_STRUCT(MlirType, const void);
DEFINE_C_API_STRUCT(MlirTypeID, const void);
DEFINE_C_API_STRUCT(MlirValue, const void);

#undef DEFINE_C_API_STRUCT
Expand Down Expand Up @@ -356,6 +357,11 @@ MLIR_CAPI_EXPORTED bool mlirOperationEqual(MlirOperation op,
/// Gets the context this operation is associated with
MLIR_CAPI_EXPORTED MlirContext mlirOperationGetContext(MlirOperation op);

/// Gets the type id of the operation.
/// Returns null if the operation does not have a registered operation
/// description.
MLIR_CAPI_EXPORTED MlirTypeID mlirOperationGetTypeID(MlirOperation op);

/// Gets the name of the operation as an identifier.
MLIR_CAPI_EXPORTED MlirIdentifier mlirOperationGetName(MlirOperation op);

Expand Down Expand Up @@ -626,6 +632,9 @@ MLIR_CAPI_EXPORTED MlirType mlirTypeParseGet(MlirContext context,
/// Gets the context that a type was created with.
MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type);

/// Gets the type ID of the type.
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeGetTypeID(MlirType type);

/// Checks whether a type is null.
static inline bool mlirTypeIsNull(MlirType type) { return !type.ptr; }

Expand Down Expand Up @@ -655,6 +664,9 @@ MLIR_CAPI_EXPORTED MlirContext mlirAttributeGetContext(MlirAttribute attribute);
/// Gets the type of this attribute.
MLIR_CAPI_EXPORTED MlirType mlirAttributeGetType(MlirAttribute attribute);

/// Gets the type id of the attribute.
MLIR_CAPI_EXPORTED MlirTypeID mlirAttributeGetTypeID(MlirAttribute attribute);

/// Checks whether an attribute is null.
static inline bool mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; }

Expand Down Expand Up @@ -693,6 +705,21 @@ MLIR_CAPI_EXPORTED bool mlirIdentifierEqual(MlirIdentifier ident,
/// Gets the string value of the identifier.
MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident);

//===----------------------------------------------------------------------===//
// TypeID API.
//===----------------------------------------------------------------------===//

/// Checks whether a type id is null.
MLIR_CAPI_EXPORTED static inline bool mlirTypeIDIsNull(MlirTypeID typeID) {
return !typeID.ptr;
}

/// Checks if two type ids are equal.
MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2);

/// Returns the hash value of the type id.
MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID);

#ifdef __cplusplus
}
#endif
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/CAPI/IR.h
Expand Up @@ -33,6 +33,7 @@ DEFINE_C_API_METHODS(MlirIdentifier, mlir::Identifier)
DEFINE_C_API_METHODS(MlirLocation, mlir::Location)
DEFINE_C_API_METHODS(MlirModule, mlir::ModuleOp)
DEFINE_C_API_METHODS(MlirType, mlir::Type)
DEFINE_C_API_METHODS(MlirTypeID, mlir::TypeID)
DEFINE_C_API_METHODS(MlirValue, mlir::Value)

#endif // MLIR_CAPI_IR_H
28 changes: 28 additions & 0 deletions mlir/lib/CAPI/IR/IR.cpp
Expand Up @@ -23,6 +23,7 @@
#include "mlir/Parser.h"

#include "llvm/Support/Debug.h"
#include <cstddef>

using namespace mlir;

Expand Down Expand Up @@ -345,6 +346,13 @@ MlirContext mlirOperationGetContext(MlirOperation op) {
return wrap(unwrap(op)->getContext());
}

MlirTypeID mlirOperationGetTypeID(MlirOperation op) {
if (const auto *abstractOp = unwrap(op)->getAbstractOperation()) {
return wrap(abstractOp->typeID);
}
return {nullptr};
}

MlirIdentifier mlirOperationGetName(MlirOperation op) {
return wrap(unwrap(op)->getName().getIdentifier());
}
Expand Down Expand Up @@ -658,6 +666,10 @@ MlirContext mlirTypeGetContext(MlirType type) {
return wrap(unwrap(type).getContext());
}

MlirTypeID mlirTypeGetTypeID(MlirType type) {
return wrap(unwrap(type).getTypeID());
}

bool mlirTypeEqual(MlirType t1, MlirType t2) {
return unwrap(t1) == unwrap(t2);
}
Expand Down Expand Up @@ -685,6 +697,10 @@ MlirType mlirAttributeGetType(MlirAttribute attribute) {
return wrap(unwrap(attribute).getType());
}

MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) {
return wrap(unwrap(attr).getTypeID());
}

bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
return unwrap(a1) == unwrap(a2);
}
Expand Down Expand Up @@ -721,3 +737,15 @@ bool mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) {
MlirStringRef mlirIdentifierStr(MlirIdentifier ident) {
return wrap(unwrap(ident).strref());
}

//===----------------------------------------------------------------------===//
// TypeID API.
//===----------------------------------------------------------------------===//

bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2) {
return unwrap(typeID1) == unwrap(typeID2);
}

size_t mlirTypeIDHashValue(MlirTypeID typeID) {
return hash_value(unwrap(typeID));
}
96 changes: 96 additions & 0 deletions mlir/test/CAPI/ir.c
Expand Up @@ -1739,6 +1739,99 @@ void testDiagnostics() {
// CHECK: more test diagnostics
}

int testTypeID(MlirContext ctx) {
fprintf(stderr, "@testTypeID\n");

// Test getting and comparing type and attribute type ids.
MlirType i32 = mlirIntegerTypeGet(ctx, 32);
MlirTypeID i32ID = mlirTypeGetTypeID(i32);
MlirType ui32 = mlirIntegerTypeUnsignedGet(ctx, 32);
MlirTypeID ui32ID = mlirTypeGetTypeID(ui32);
MlirType f32 = mlirF32TypeGet(ctx);
MlirTypeID f32ID = mlirTypeGetTypeID(f32);
MlirAttribute i32Attr = mlirIntegerAttrGet(i32, 1);
MlirTypeID i32AttrID = mlirAttributeGetTypeID(i32Attr);

if (mlirTypeIDIsNull(i32ID) || mlirTypeIDIsNull(ui32ID) ||
mlirTypeIDIsNull(f32ID) || mlirTypeIDIsNull(i32AttrID)) {
fprintf(stderr, "ERROR: Expected type ids to be present\n");
return 1;
}

if (!mlirTypeIDEqual(i32ID, ui32ID) ||
mlirTypeIDHashValue(i32ID) != mlirTypeIDHashValue(ui32ID)) {
fprintf(
stderr,
"ERROR: Expected different integer types to have the same type id\n");
return 2;
}

if (mlirTypeIDEqual(i32ID, f32ID) ||
mlirTypeIDHashValue(i32ID) == mlirTypeIDHashValue(f32ID)) {
fprintf(stderr,
"ERROR: Expected integer type id to not equal float type id\n");
return 3;
}

if (mlirTypeIDEqual(i32ID, i32AttrID) ||
mlirTypeIDHashValue(i32ID) == mlirTypeIDHashValue(i32AttrID)) {
fprintf(stderr, "ERROR: Expected integer type id to not equal integer "
"attribute type id\n");
return 4;
}

MlirLocation loc = mlirLocationUnknownGet(ctx);
MlirType indexType = mlirIndexTypeGet(ctx);
MlirStringRef valueStringRef = mlirStringRefCreateFromCString("value");

// Create a registered operation, which should have a type id.
MlirAttribute indexZeroLiteral =
mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
mlirIdentifierGet(ctx, valueStringRef), indexZeroLiteral);
MlirOperationState constZeroState = mlirOperationStateGet(
mlirStringRefCreateFromCString("std.constant"), loc);
mlirOperationStateAddResults(&constZeroState, 1, &indexType);
mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
MlirOperation constZero = mlirOperationCreate(&constZeroState);

if (mlirOperationIsNull(constZero)) {
fprintf(stderr, "ERROR: Expected registered operation to be present\n");
return 5;
}

MlirTypeID registeredOpID = mlirOperationGetTypeID(constZero);

if (mlirTypeIDIsNull(registeredOpID)) {
fprintf(stderr,
"ERROR: Expected registered operation type id to be present\n");
return 6;
}

// Create an unregistered operation, which should not have a type id.
mlirContextSetAllowUnregisteredDialects(ctx, true);
MlirOperationState opState =
mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op"), loc);
MlirOperation unregisteredOp = mlirOperationCreate(&opState);
if (mlirOperationIsNull(unregisteredOp)) {
fprintf(stderr, "ERROR: Expected unregistered operation to be present\n");
return 7;
}

MlirTypeID unregisteredOpID = mlirOperationGetTypeID(unregisteredOp);

if (!mlirTypeIDIsNull(unregisteredOpID)) {
fprintf(stderr,
"ERROR: Expected unregistered operation type id to be null\n");
return 8;
}

mlirOperationDestroy(constZero);
mlirOperationDestroy(unregisteredOp);

return 0;
}

int main() {
MlirContext ctx = mlirContextCreate();
mlirRegisterAllDialects(ctx);
Expand Down Expand Up @@ -1768,6 +1861,9 @@ int main() {
return 11;
if (testClone())
return 12;
if (testTypeID(ctx)) {
return 13;
}

mlirContextDestroy(ctx);

Expand Down

0 comments on commit 782a97a

Please sign in to comment.