Skip to content

Commit

Permalink
[CIR] Add initial support for complex types
Browse files Browse the repository at this point in the history
This patch adds an initial support for the C complex type, i.e. `_Complex`. It
introduces the following new types, attributes, and operations:

- `!cir.complex`, which represents the C complex number type;
- `#cir.imag`, which represents an imaginary number literal in C;
- `cir.complex.create`, which creates a complex number from its real and
  imaginary parts;
- `cir.complex.get_real`, which derives a pointer to the real part of a complex
  number given a pointer to the complex number;
- `cir.complex.get_imag`, which derives a pointer to the imaginary part of a
  complex number given a pointer to the complex number.

CIRGen for some basic complex number operations is also included in this patch.
  • Loading branch information
Lancern committed May 14, 2024
1 parent 3bad644 commit 5a93a8c
Show file tree
Hide file tree
Showing 18 changed files with 1,076 additions and 58 deletions.
29 changes: 29 additions & 0 deletions clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,35 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
return getPointerTo(::mlir::cir::VoidType::get(getContext()), addressSpace);
}

mlir::cir::BoolAttr getCIRBoolAttr(bool state) {
return mlir::cir::BoolAttr::get(getContext(), getBoolTy(), state);
}

mlir::TypedAttr getZeroAttr(mlir::Type t) {
return mlir::cir::ZeroAttr::get(getContext(), t);
}

mlir::TypedAttr getZeroInitAttr(mlir::Type ty) {
if (ty.isa<mlir::cir::IntType>())
return mlir::cir::IntAttr::get(ty, 0);
if (auto fltType = ty.dyn_cast<mlir::cir::SingleType>())
return mlir::cir::FPAttr::getZero(fltType);
if (auto fltType = ty.dyn_cast<mlir::cir::DoubleType>())
return mlir::cir::FPAttr::getZero(fltType);
if (auto complexType = ty.dyn_cast<mlir::cir::ComplexType>())
return mlir::cir::ImaginaryAttr::getZero(complexType);
if (auto arrTy = ty.dyn_cast<mlir::cir::ArrayType>())
return getZeroAttr(arrTy);
if (auto ptrTy = ty.dyn_cast<mlir::cir::PointerType>())
return getConstPtrAttr(ptrTy, 0);
if (auto structTy = ty.dyn_cast<mlir::cir::StructType>())
return getZeroAttr(structTy);
if (ty.isa<mlir::cir::BoolType>()) {
return getCIRBoolAttr(false);
}
llvm_unreachable("Zero initializer for given type is NYI");
}

mlir::Value createLoad(mlir::Location loc, mlir::Value ptr) {
return create<mlir::cir::LoadOp>(loc, ptr, /*isDeref=*/false,
/*is_volatile=*/false,
Expand Down
34 changes: 34 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,40 @@ def FPAttr : CIR_Attr<"FP", "fp", [TypedAttrInterface]> {
}];
}

//===----------------------------------------------------------------------===//
// ImaginaryAttr
//===----------------------------------------------------------------------===//

def ImaginaryAttr : CIR_Attr<"Imaginary", "imag", [TypedAttrInterface]> {
let summary = "An attribute containing an imaginary value";
let description = [{
A `#cir.imag` attribute is a literal attribute that represents an imaginary
number value of the specified complex type.

The `value` parameter gives the imaginary part of the complex constant
represented by the attribute.
}];

let parameters = (ins AttributeSelfTypeParameter<"">:$type,
"TypedAttr":$value);

let builders = [
AttrBuilderWithInferredContext<(ins "Type":$type, "TypedAttr":$value), [{
return $_get(type.getContext(), type, value);
}]>,
];

let extraClassDeclaration = [{
static ImaginaryAttr getZero(Type type);
}];

let genVerifyDecl = 1;

let assemblyFormat = [{
`<` $value `>`
}];
}

//===----------------------------------------------------------------------===//
// ConstPointerAttr
//===----------------------------------------------------------------------===//
Expand Down
84 changes: 84 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1154,6 +1154,90 @@ def BinOpOverflowOp : CIR_Op<"binop.overflow", [Pure, SameTypeOperands]> {
];
}

//===----------------------------------------------------------------------===//
// ComplexCreateOp
//===----------------------------------------------------------------------===//

def ComplexCreateOp : CIR_Op<"complex.create", [Pure, SameTypeOperands]> {
let summary = "Create a complex value from its real and imaginary parts";
let description = [{
`cir.complex.create` operation takes two operands that represent the real
and imaginary part of a complex number, and yields the complex number.

Example:

```mlir
%0 = cir.const #cir.fp<1.000000e+00> : !cir.double
%1 = cir.const #cir.fp<2.000000e+00> : !cir.double
%2 = cir.complex.create %0, %1 : !cir.complex<!cir.double>
```
}];

let results = (outs CIR_ComplexType:$result);
let arguments = (ins CIR_AnyIntOrFloat:$real, CIR_AnyIntOrFloat:$imag);

let assemblyFormat = [{
$real `,` $imag
`:` qualified(type($real)) `->` qualified(type($result)) attr-dict
}];

let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// ComplexGetRealOp and ComplexGetImagOp
//===----------------------------------------------------------------------===//

def ComplexGetRealOp : CIR_Op<"complex.get_real", [Pure]> {
let summary = "Extract the real part of a complex value";
let description = [{
`cir.complex.get_real` operation takes a pointer operand that points to a
complex value of type `!cir.complex` and yields a pointer to the real part
of the operand.

Example:

```mlir
%1 = cir.complex.get_real %0 : !cir.ptr<!cir.complex<!cir.double>> -> !cir.ptr<!cir.double>
```
}];

let results = (outs PrimitiveIntOrFPPtr:$result);
let arguments = (ins ComplexPtr:$operand);

let assemblyFormat = [{
$operand `:`
qualified(type($operand)) `->` qualified(type($result)) attr-dict
}];

let hasVerifier = 1;
}

def ComplexGetImagOp : CIR_Op<"complex.get_imag", [Pure]> {
let summary = "Extract the imaginary part of a complex value";
let description = [{
`cir.complex.get_real` operation takes a pointer operand that points to a
complex value of type `!cir.complex` and yields a pointer to the imaginary
part of the operand.

Example:

```mlir
%1 = cir.complex.get_imag %0 : !cir.ptr<!cir.complex<!cir.double>> -> !cir.ptr<!cir.double>
```
}];

let results = (outs PrimitiveIntOrFPPtr:$result);
let arguments = (ins ComplexPtr:$operand);

let assemblyFormat = [{
$operand `:`
qualified(type($operand)) `->` qualified(type($result)) attr-dict
}];

let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// BitsOp
//===----------------------------------------------------------------------===//
Expand Down
36 changes: 35 additions & 1 deletion clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,32 @@ def CIR_LongDouble : CIR_FloatType<"LongDouble", "long_double"> {
def CIR_AnyFloat: AnyTypeOf<[CIR_Single, CIR_Double, CIR_LongDouble]>;
def CIR_AnyIntOrFloat: AnyTypeOf<[CIR_AnyFloat, CIR_IntType]>;

//===----------------------------------------------------------------------===//
// ComplexType
//===----------------------------------------------------------------------===//

def CIR_ComplexType : CIR_Type<"Complex", "complex",
[DeclareTypeInterfaceMethods<DataLayoutTypeInterface>]> {

let summary = "CIR complex type";
let description = [{
CIR type that represents a C complex number. `cir.complex` models the C type
`T _Complex`.

The parameter `elementTy` gives the type of the real and imaginary part of
the complex number. `elementTy` must be either a CIR integer type or a CIR
floating-point type.
}];

let parameters = (ins "mlir::Type":$elementTy);

let assemblyFormat = [{
`<` $elementTy `>`
}];

let genVerifyDecl = 1;
}

//===----------------------------------------------------------------------===//
// PointerType
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -401,6 +427,14 @@ def PrimitiveIntOrFPPtr : Type<
]>, "{int,void}*"> {
}

def ComplexPtr : Type<
And<[
CPred<"$_self.isa<::mlir::cir::PointerType>()">,
CPred<"$_self.cast<::mlir::cir::PointerType>()"
".getPointee().isa<::mlir::cir::ComplexType>()">,
]>, "!cir.complex*"> {
}

// Pointer to struct
def StructPtr : Type<
And<[
Expand Down Expand Up @@ -475,7 +509,7 @@ def CIR_StructType : Type<CPred<"$_self.isa<::mlir::cir::StructType>()">,
def CIR_AnyType : AnyTypeOf<[
CIR_IntType, CIR_PointerType, CIR_DataMemberType, CIR_BoolType, CIR_ArrayType,
CIR_VectorType, CIR_FuncType, CIR_VoidType, CIR_StructType, CIR_ExceptionInfo,
CIR_AnyFloat,
CIR_AnyFloat, CIR_ComplexType
]>;

#endif // MLIR_CIR_DIALECT_CIR_TYPES
75 changes: 45 additions & 30 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,6 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
return mlir::cir::GlobalViewAttr::get(type, symbol, indices);
}

mlir::TypedAttr getZeroAttr(mlir::Type t) {
return mlir::cir::ZeroAttr::get(getContext(), t);
}

mlir::cir::BoolAttr getCIRBoolAttr(bool state) {
return mlir::cir::BoolAttr::get(getContext(), getBoolTy(), state);
}

mlir::TypedAttr getConstNullPtrAttr(mlir::Type t) {
assert(t.isa<mlir::cir::PointerType>() && "expected cir.ptr");
return mlir::cir::ConstPtrAttr::get(getContext(), t, 0);
Expand Down Expand Up @@ -243,25 +235,6 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
return mlir::cir::DataMemberAttr::get(getContext(), ty, std::nullopt);
}

mlir::TypedAttr getZeroInitAttr(mlir::Type ty) {
if (ty.isa<mlir::cir::IntType>())
return mlir::cir::IntAttr::get(ty, 0);
if (auto fltType = ty.dyn_cast<mlir::cir::SingleType>())
return mlir::cir::FPAttr::getZero(fltType);
if (auto fltType = ty.dyn_cast<mlir::cir::DoubleType>())
return mlir::cir::FPAttr::getZero(fltType);
if (auto arrTy = ty.dyn_cast<mlir::cir::ArrayType>())
return getZeroAttr(arrTy);
if (auto ptrTy = ty.dyn_cast<mlir::cir::PointerType>())
return getConstPtrAttr(ptrTy, 0);
if (auto structTy = ty.dyn_cast<mlir::cir::StructType>())
return getZeroAttr(structTy);
if (ty.isa<mlir::cir::BoolType>()) {
return getCIRBoolAttr(false);
}
llvm_unreachable("Zero initializer for given type is NYI");
}

// TODO(cir): Once we have CIR float types, replace this by something like a
// NullableValueInterface to allow for type-independent queries.
bool isNullValue(mlir::Attribute attr) const {
Expand Down Expand Up @@ -722,6 +695,46 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
return create<mlir::cir::GetMemberOp>(loc, result, base, name, index);
}

mlir::Value createComplexCreate(mlir::Location loc, mlir::Value real,
mlir::Value imag) {
auto resultComplexTy =
mlir::cir::ComplexType::get(getContext(), real.getType());
return create<mlir::cir::ComplexCreateOp>(loc, resultComplexTy, real, imag);
}

/// Create a cir.complex.get_real operation that derives a pointer to the real
/// part of the complex value pointed to by the specified pointer value.
mlir::Value createGetReal(mlir::Location loc, mlir::Value value) {
auto srcComplexElemTy = value.getType()
.cast<mlir::cir::PointerType>()
.getPointee()
.cast<mlir::cir::ComplexType>()
.getElementTy();
return create<mlir::cir::ComplexGetRealOp>(
loc, getPointerTo(srcComplexElemTy), value);
}

Address createGetReal(mlir::Location loc, Address addr) {
return Address{createGetReal(loc, addr.getPointer()), addr.getAlignment()};
}

/// Create a cir.complex.get_imag operation that derives a pointer to the
/// imaginary part of the complex value pointed to by the specified pointer
/// value.
mlir::Value createGetImag(mlir::Location loc, mlir::Value value) {
auto srcComplexElemTy = value.getType()
.cast<mlir::cir::PointerType>()
.getPointee()
.cast<mlir::cir::ComplexType>()
.getElementTy();
return create<mlir::cir::ComplexGetImagOp>(
loc, getPointerTo(srcComplexElemTy), value);
}

Address createGetImag(mlir::Location loc, Address addr) {
return Address{createGetImag(loc, addr.getPointer()), addr.getAlignment()};
}

/// Cast the element type of the given address to a different type,
/// preserving information like the alignment.
cir::Address createElementBitCast(mlir::Location loc, cir::Address addr,
Expand All @@ -734,14 +747,16 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
addr.getAlignment());
}

mlir::Value createLoad(mlir::Location loc, Address addr) {
mlir::Value createLoad(mlir::Location loc, Address addr,
bool isVolatile = false) {
auto ptrTy = addr.getPointer().getType().dyn_cast<mlir::cir::PointerType>();
if (addr.getElementType() != ptrTy.getPointee())
addr = addr.withPointer(
createPtrBitcast(addr.getPointer(), addr.getElementType()));

return create<mlir::cir::LoadOp>(loc, addr.getElementType(),
addr.getPointer());
return create<mlir::cir::LoadOp>(
loc, addr.getElementType(), addr.getPointer(), /*isDeref=*/false,
/*is_volatile=*/isVolatile, /*mem_order=*/mlir::cir::MemOrderAttr{});
}

mlir::Value createAlignedLoad(mlir::Location loc, mlir::Type ty,
Expand Down
6 changes: 5 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,11 @@ void CIRGenFunction::buildExprAsInit(const Expr *init, const ValueDecl *D,
buildScalarInit(init, getLoc(D->getSourceRange()), lvalue);
return;
case TEK_Complex: {
assert(0 && "not implemented");
mlir::Value complex = buildComplexExpr(init);
if (capturedByInit)
llvm_unreachable("NYI");
buildStoreOfComplex(getLoc(init->getExprLoc()), complex, lvalue,
/*init*/ true);
return;
}
case TEK_Aggregate:
Expand Down

0 comments on commit 5a93a8c

Please sign in to comment.