diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 3a23bbfd70eac..c7fdef74d33bf 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -3789,6 +3789,27 @@ TypeConverter::convertTypeAttribute(Type type, Attribute attr) const { // FunctionOpInterfaceSignatureConversion //===----------------------------------------------------------------------===// +static SmallVector +convertFuncOpAttrs(FunctionOpInterface funcOp, + TypeConverter::SignatureConversion &sigConv, + FunctionType newType) { + if (newType.getNumInputs() == funcOp.getNumArguments()) { + return {}; + } + ArrayAttr allArgAttrs = funcOp.getAllArgAttrs(); + if (!allArgAttrs) + return {}; + + SmallVector newAttrs(newType.getNumInputs()); + for (auto i : llvm::seq(allArgAttrs.size())) { + auto mapping = sigConv.getInputMapping(i); + assert(mapping.has_value()); + auto outIdx = mapping->inputNo; + newAttrs[outIdx] = allArgAttrs[i]; + } + return newAttrs; +} + static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { @@ -3809,7 +3830,16 @@ static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, auto newType = FunctionType::get(rewriter.getContext(), result.getConvertedTypes(), newResults); - rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); }); + // If using 1-to-n type conversion, we must re-map argument attributes + // to the corresponding new argument index. + auto newArgAttrs = convertFuncOpAttrs(funcOp, result, newType); + + rewriter.modifyOpInPlace(funcOp, [&] { + funcOp.setType(newType); + if (!newArgAttrs.empty()) { + funcOp.setAllArgAttrs(newArgAttrs); + } + }); return success(); }