Skip to content

Commit

Permalink
analyze: rewrite: downgrade Option before unwrapping for deref
Browse files Browse the repository at this point in the history
  • Loading branch information
spernsteiner committed May 15, 2024
1 parent b240afd commit a6cc054
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 25 deletions.
74 changes: 50 additions & 24 deletions c2rust-analyze/src/rewrite/expr/mir_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::type_desc::{self, Ownership, Quantity, TypeDesc};
use crate::util::{self, ty_callee, Callee};
use rustc_ast::Mutability;
use rustc_middle::mir::{
BasicBlock, Body, Location, Operand, Place, PlaceElem, PlaceRef, Rvalue, Statement,
BasicBlock, Body, BorrowKind, Location, Operand, Place, PlaceElem, PlaceRef, Rvalue, Statement,
StatementKind, Terminator, TerminatorKind,
};
use rustc_middle::ty::print::FmtPrinter;
Expand Down Expand Up @@ -339,7 +339,7 @@ impl<'a, 'tcx> ExprRewriteVisitor<'a, 'tcx> {
self.enter_rvalue(|v| v.visit_rvalue(rv, Some(rv_lty)));
// The cast from `rv_lty` to `pl_lty` should be applied to the RHS.
self.enter_rvalue(|v| v.emit_cast_lty_lty(rv_lty, pl_lty));
self.enter_dest(|v| v.visit_place(pl));
self.enter_dest(|v| v.visit_place(pl, true));
}
StatementKind::FakeRead(..) => {}
StatementKind::SetDiscriminant { .. } => todo!("statement {:?}", stmt),
Expand Down Expand Up @@ -553,14 +553,18 @@ impl<'a, 'tcx> ExprRewriteVisitor<'a, 'tcx> {
Rvalue::Repeat(ref op, _) => {
self.enter_rvalue_operand(0, |v| v.visit_operand(op, None));
}
Rvalue::Ref(_rg, _kind, pl) => {
self.enter_rvalue_place(0, |v| v.visit_place(pl));
Rvalue::Ref(_rg, kind, pl) => {
let mutbl = match kind {
BorrowKind::Mut { .. } => true,
BorrowKind::Shared | BorrowKind::Shallow | BorrowKind::Unique => false,
};
self.enter_rvalue_place(0, |v| v.visit_place(pl, mutbl));
}
Rvalue::ThreadLocalRef(_def_id) => {
// TODO
}
Rvalue::AddressOf(mutbl, pl) => {
self.enter_rvalue_place(0, |v| v.visit_place(pl));
self.enter_rvalue_place(0, |v| v.visit_place(pl, mutbl == Mutability::Mut));
if let Some(expect_ty) = expect_ty {
let desc = type_desc::perms_to_desc_with_pointee(
self.acx.tcx(),
Expand All @@ -579,7 +583,7 @@ impl<'a, 'tcx> ExprRewriteVisitor<'a, 'tcx> {
}
}
Rvalue::Len(pl) => {
self.enter_rvalue_place(0, |v| v.visit_place(pl));
self.enter_rvalue_place(0, |v| v.visit_place(pl, false));
}
Rvalue::Cast(_kind, ref op, ty) => {
if util::is_null_const_operand(op) && ty.is_unsafe_ptr() {
Expand Down Expand Up @@ -640,7 +644,7 @@ impl<'a, 'tcx> ExprRewriteVisitor<'a, 'tcx> {
self.enter_rvalue_operand(0, |v| v.visit_operand(op, None));
}
Rvalue::Discriminant(pl) => {
self.enter_rvalue_place(0, |v| v.visit_place(pl));
self.enter_rvalue_place(0, |v| v.visit_place(pl, false));
}
Rvalue::Aggregate(ref _kind, ref ops) => {
for (i, op) in ops.iter().enumerate() {
Expand All @@ -651,7 +655,7 @@ impl<'a, 'tcx> ExprRewriteVisitor<'a, 'tcx> {
self.enter_rvalue_operand(0, |v| v.visit_operand(op, None));
}
Rvalue::CopyForDeref(pl) => {
self.enter_rvalue_place(0, |v| v.visit_place(pl));
self.enter_rvalue_place(0, |v| v.visit_place(pl, false));
}
}
}
Expand All @@ -661,7 +665,7 @@ impl<'a, 'tcx> ExprRewriteVisitor<'a, 'tcx> {
fn visit_operand(&mut self, op: &Operand<'tcx>, expect_ty: Option<LTy<'tcx>>) {
match *op {
Operand::Copy(pl) | Operand::Move(pl) => {
self.enter_operand_place(|v| v.visit_place(pl));
self.enter_operand_place(|v| v.visit_place(pl, false));

if let Some(expect_ty) = expect_ty {
let ptr_lty = self.acx.type_of(pl);
Expand All @@ -678,7 +682,7 @@ impl<'a, 'tcx> ExprRewriteVisitor<'a, 'tcx> {
fn visit_operand_desc(&mut self, op: &Operand<'tcx>, expect_desc: TypeDesc<'tcx>) {
match *op {
Operand::Copy(pl) | Operand::Move(pl) => {
self.visit_place(pl);
self.visit_place(pl, false);

let ptr_lty = self.acx.type_of(pl);
if !ptr_lty.label.is_none() {
Expand All @@ -689,19 +693,25 @@ impl<'a, 'tcx> ExprRewriteVisitor<'a, 'tcx> {
}
}

fn visit_place(&mut self, pl: Place<'tcx>) {
fn visit_place(&mut self, pl: Place<'tcx>, in_mutable_context: bool) {
let mut ltys = Vec::with_capacity(1 + pl.projection.len());
ltys.push(self.acx.type_of(pl.local));
for proj in pl.projection {
let prev_lty = ltys.last().copied().unwrap();
ltys.push(self.acx.projection_lty(prev_lty, &proj));
}
self.visit_place_ref(pl.as_ref(), &ltys);
self.visit_place_ref(pl.as_ref(), &ltys, in_mutable_context);
}

/// Generate rewrites for a `Place` represented as a `PlaceRef`. `proj_ltys` gives the `LTy`
/// for the `Local` and after each projection.
fn visit_place_ref(&mut self, pl: PlaceRef<'tcx>, proj_ltys: &[LTy<'tcx>]) {
/// for the `Local` and after each projection. `in_mutable_context` is `true` if the `Place`
/// is in a mutable context, such as the LHS of an assignment.
fn visit_place_ref(
&mut self,
pl: PlaceRef<'tcx>,
proj_ltys: &[LTy<'tcx>],
in_mutable_context: bool,
) {
let (&last_proj, rest) = match pl.projection.split_last() {
Some(x) => x,
None => return,
Expand All @@ -720,17 +730,36 @@ impl<'a, 'tcx> ExprRewriteVisitor<'a, 'tcx> {
match last_proj {
PlaceElem::Deref => {
self.enter_place_deref_pointer(|v| {
v.visit_place_ref(base_pl, proj_ltys);
if !v.perms[base_lty.label].contains(PermissionSet::NON_NULL) {
v.visit_place_ref(base_pl, proj_ltys, in_mutable_context);
if !v.perms[base_lty.label].contains(PermissionSet::NON_NULL)
&& !v.flags[base_lty.label].contains(FlagSet::FIXED)
{
// If the pointer type is non-copy, downgrade (borrow) before calling
// `unwrap()`.
let desc = type_desc::perms_to_desc(
base_lty.ty,
v.perms[base_lty.label],
v.flags[base_lty.label],
);
if !desc.own.is_copy() {
v.emit(RewriteKind::OptionDowngrade {
mutbl: in_mutable_context,
deref: true,
});
}
v.emit(RewriteKind::OptionUnwrap);
}
});
}
PlaceElem::Field(_idx, _ty) => {
self.enter_place_field_base(|v| v.visit_place_ref(base_pl, proj_ltys));
self.enter_place_field_base(|v| {
v.visit_place_ref(base_pl, proj_ltys, in_mutable_context)
});
}
PlaceElem::Index(_) | PlaceElem::ConstantIndex { .. } | PlaceElem::Subslice { .. } => {
self.enter_place_index_array(|v| v.visit_place_ref(base_pl, proj_ltys));
self.enter_place_index_array(|v| {
v.visit_place_ref(base_pl, proj_ltys, in_mutable_context)
});
}
PlaceElem::Downcast(_, _) => {}
}
Expand Down Expand Up @@ -934,11 +963,8 @@ where
// moving/consuming the input. For example, if the `from` type is `Option<Box<T>>` and
// `to` is `&mut T`, we start by calling `p.as_mut().as_deref()`, which gives
// `Option<&mut T>` without consuming `p`.
match from.own {
Ownership::Raw | Ownership::RawMut | Ownership::Imm | Ownership::Cell => {
// No-op. The `from` type is `Copy`, so we can unwrap it without consequence.
}
Ownership::Mut | Ownership::Rc | Ownership::Box => match to.own {
if !from.own.is_copy() {
match to.own {
Ownership::Raw | Ownership::Imm => {
(self.emit)(RewriteKind::OptionDowngrade {
mutbl: false,
Expand All @@ -956,7 +982,7 @@ where
_ => {
// Remaining cases are unsupported.
}
},
}
}
}

Expand Down
9 changes: 9 additions & 0 deletions c2rust-analyze/src/type_desc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,15 @@ impl PtrDesc {
}
}

impl Ownership {
pub fn is_copy(&self) -> bool {
match *self {
Ownership::Raw | Ownership::RawMut | Ownership::Imm | Ownership::Cell => true,
Ownership::Mut | Ownership::Rc | Ownership::Box => false,
}
}
}

fn perms_to_ptr_desc(perms: PermissionSet, flags: FlagSet) -> PtrDesc {
let own = if perms.contains(PermissionSet::UNIQUE | PermissionSet::WRITE) {
Ownership::Mut
Expand Down
2 changes: 1 addition & 1 deletion c2rust-analyze/tests/filecheck/non_null_rewrites.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ unsafe fn call_use_mut(cond: bool) -> i32 {
// CHECK-SAME: p: core::option::Option<&{{('[^ ]* )?}}mut (i32)>
unsafe fn use_mut(p: *mut i32) -> i32 {
if !p.is_null() {
// CHECK: *(p).unwrap() = 1;
// CHECK: *(p).as_deref_mut().unwrap() = 1;
*p = 1;
}
// CHECK: use_const
Expand Down

0 comments on commit a6cc054

Please sign in to comment.