diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp index a571b5333682..eba15585ef3a 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp @@ -49,6 +49,12 @@ namespace IREE = mlir::iree_compiler::IREE; // Utils. //===----------------------------------------------------------------------===// +static Type getComplexElementTypeOrSelf(Type ty) { + if (auto complex = dyn_cast_or_null(ty)) + return complex.getElementType(); + return ty; +} + static void getEffectsImpl( SmallVectorImpl> &effects, @@ -218,7 +224,8 @@ LogicalResult ScatterOp::verify() { } Type arg0Type = body->getArgument(0).getType(); Type arg1Type = body->getArgument(1).getType(); - if (!arg0Type.isIntOrFloat() || !arg1Type.isIntOrFloat()) { + if (!getComplexElementTypeOrSelf(arg0Type).isIntOrFloat() || + !getComplexElementTypeOrSelf(arg1Type).isIntOrFloat()) { return op->emitOpError( "expected region to have scalar argument of integer or float types"); }