Skip to content

Commit

Permalink
Loosen requirements on iree_linalg_ext.scatter for complex types (ire…
Browse files Browse the repository at this point in the history
…e-org#13055)

We support complex regions in the `iree_linalg_ext.scatter` just loosen
requirements so that we do not fail during successful cases.
  • Loading branch information
rsuderman committed Apr 12, 2023
1 parent 43a9b5a commit 862f414
Showing 1 changed file with 8 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ namespace IREE = mlir::iree_compiler::IREE;
// Utils.
//===----------------------------------------------------------------------===//

static Type getComplexElementTypeOrSelf(Type ty) {
if (auto complex = dyn_cast_or_null<ComplexType>(ty))
return complex.getElementType();
return ty;
}

static void getEffectsImpl(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects,
Expand Down Expand Up @@ -218,7 +224,8 @@ LogicalResult ScatterOp::verify() {
}
Type arg0Type = body->getArgument(0).getType();
Type arg1Type = body->getArgument(1).getType();
if (!arg0Type.isIntOrFloat() || !arg1Type.isIntOrFloat()) {
if (!getComplexElementTypeOrSelf(arg0Type).isIntOrFloat() ||
!getComplexElementTypeOrSelf(arg1Type).isIntOrFloat()) {
return op->emitOpError(
"expected region to have scalar argument of integer or float types");
}
Expand Down

0 comments on commit 862f414

Please sign in to comment.