diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp index f878672cd912a..f4bb79362ffd3 100644 --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -230,6 +230,11 @@ int64_t ShapedType::getSizeInBits() const { if (elementType.isIntOrFloat()) return elementType.getIntOrFloatBitWidth() * getNumElements(); + if (auto complexType = elementType.dyn_cast()) { + elementType = complexType.getElementType(); + return elementType.getIntOrFloatBitWidth() * getNumElements() * 2; + } + // Tensors can have vectors and other tensors as elements, other shaped types // cannot. assert(isa() && "unsupported element type"); diff --git a/mlir/test/mlir-tblgen/op-derived-attribute.mlir b/mlir/test/mlir-tblgen/op-derived-attribute.mlir index ec4f4dcf7dae4..3312fd54811b2 100644 --- a/mlir/test/mlir-tblgen/op-derived-attribute.mlir +++ b/mlir/test/mlir-tblgen/op-derived-attribute.mlir @@ -5,9 +5,14 @@ func @verifyDerivedAttributes() { // expected-remark @+2 {{element_dtype = f32}} // expected-remark @+1 {{size = 320}} %0 = "test.derived_type_attr"() : () -> tensor<10xf32> + // expected-remark @+2 {{element_dtype = i79}} // expected-remark @+1 {{size = 948}} %1 = "test.derived_type_attr"() : () -> tensor<12xi79> + // expected-remark @+2 {{element_dtype = complex}} + // expected-remark @+1 {{size = 768}} + %2 = "test.derived_type_attr"() : () -> tensor<12xcomplex> + return }