Skip to content

Commit

Permalink
[flang][CodeGen] Fix use-after-free in BoxedProcedurePass (#84376)
Browse files Browse the repository at this point in the history
Avoid inspecting an operation that has been replaced.

This was detected by address sanitizer.
  • Loading branch information
kparzysz committed Mar 11, 2024
1 parent aec9283 commit 546f32d
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ class BoxedProcedurePass
BoxprocTypeRewriter typeConverter(mlir::UnknownLoc::get(context));
mlir::Dialect *firDialect = context->getLoadedDialect("fir");
getModule().walk([&](mlir::Operation *op) {
bool opIsValid = true;
typeConverter.setLocation(op->getLoc());
if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) {
mlir::Type ty = addr.getVal().getType();
Expand All @@ -220,6 +221,7 @@ class BoxedProcedurePass
rewriter.setInsertionPoint(addr);
rewriter.replaceOpWithNewOp<ConvertOp>(
addr, typeConverter.convertType(addr.getType()), addr.getVal());
opIsValid = false;
} else if (typeConverter.needsConversion(resTy)) {
rewriter.startOpModification(op);
op->getResult(0).setType(typeConverter.convertType(resTy));
Expand Down Expand Up @@ -271,10 +273,12 @@ class BoxedProcedurePass
llvm::ArrayRef<mlir::Value>{tramp});
rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
adjustCall.getResult(0));
opIsValid = false;
} else {
// Just forward the function as a pointer.
rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
embox.getFunc());
opIsValid = false;
}
} else if (auto global = mlir::dyn_cast<GlobalOp>(op)) {
auto ty = global.getType();
Expand All @@ -297,6 +301,7 @@ class BoxedProcedurePass
rewriter.replaceOpWithNewOp<AllocaOp>(
mem, toTy, uniqName, bindcName, isPinned, mem.getTypeparams(),
mem.getShape());
opIsValid = false;
}
} else if (auto mem = mlir::dyn_cast<AllocMemOp>(op)) {
auto ty = mem.getType();
Expand All @@ -310,6 +315,7 @@ class BoxedProcedurePass
rewriter.replaceOpWithNewOp<AllocMemOp>(
mem, toTy, uniqName, bindcName, mem.getTypeparams(),
mem.getShape());
opIsValid = false;
}
} else if (auto coor = mlir::dyn_cast<CoordinateOp>(op)) {
auto ty = coor.getType();
Expand All @@ -321,6 +327,7 @@ class BoxedProcedurePass
auto toBaseTy = typeConverter.convertType(baseTy);
rewriter.replaceOpWithNewOp<CoordinateOp>(coor, toTy, coor.getRef(),
coor.getCoor(), toBaseTy);
opIsValid = false;
}
} else if (auto index = mlir::dyn_cast<FieldIndexOp>(op)) {
auto ty = index.getType();
Expand All @@ -332,6 +339,7 @@ class BoxedProcedurePass
auto toOnTy = typeConverter.convertType(onTy);
rewriter.replaceOpWithNewOp<FieldIndexOp>(
index, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
opIsValid = false;
}
} else if (auto index = mlir::dyn_cast<LenParamIndexOp>(op)) {
auto ty = index.getType();
Expand All @@ -343,6 +351,7 @@ class BoxedProcedurePass
auto toOnTy = typeConverter.convertType(onTy);
rewriter.replaceOpWithNewOp<LenParamIndexOp>(
index, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
opIsValid = false;
}
} else if (op->getDialect() == firDialect) {
rewriter.startOpModification(op);
Expand All @@ -354,7 +363,7 @@ class BoxedProcedurePass
rewriter.finalizeOpModification(op);
}
// Ensure block arguments are updated if needed.
if (op->getNumRegions() != 0) {
if (opIsValid && op->getNumRegions() != 0) {
rewriter.startOpModification(op);
for (mlir::Region &region : op->getRegions())
for (mlir::Block &block : region.getBlocks())
Expand Down

0 comments on commit 546f32d

Please sign in to comment.