Skip to content

Commit

Permalink
Fix match expressions not dropping all values
Browse files Browse the repository at this point in the history
When matching against owned values, depending on the pattern(s) used,
not all values extracted as part of a match would be dropped. This could
then lead to dangling reference panics further down the line.

This fixes #563.

Changelog: fixed
  • Loading branch information
yorickpeterse committed Jun 5, 2023
1 parent c5af219 commit a5a10f1
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 73 deletions.
180 changes: 107 additions & 73 deletions compiler/src/mir/passes.rs
Expand Up @@ -274,6 +274,22 @@ impl Scope {
}
}

/// A type describing the action to take when destructuring an object as part of
/// a pattern.
#[derive(Copy, Clone)]
enum RegisterAction {
/// A field is to be moved into a new register.
///
/// The wrapped value is the register that owned the field.
Move(RegisterId),

/// A field is to be incremented, and the reference moved into a new
/// register.
///
/// The wrapped value is the register that owned the field.
Increment(RegisterId),
}

struct DecisionState {
/// The register to write the results of a case body to.
output: RegisterId,
Expand All @@ -285,9 +301,9 @@ struct DecisionState {
/// the variables.
registers: Vec<RegisterId>,

/// Registers for which the reference count should be incremented when they
/// are bound to a variable.
increment: HashSet<RegisterId>,
/// The action to take per register when destructuring a value such as an
/// enum variant of class.
actions: HashMap<RegisterId, RegisterAction>,

/// The basic blocks for every case body, and the code to compile for them.
bodies: HashMap<
Expand All @@ -302,9 +318,6 @@ struct DecisionState {
/// The location of the `match` expression.
location: LocationId,

/// If the original input value is owned or not.
owned: bool,

/// If the result of a match arm should be written to a register or ignored.
write_result: bool,
}
Expand All @@ -313,18 +326,16 @@ impl DecisionState {
fn new(
output: RegisterId,
after_block: BlockId,
owned: bool,
write_result: bool,
location: LocationId,
) -> Self {
Self {
output,
after_block,
registers: Vec::new(),
increment: HashSet::new(),
actions: HashMap::new(),
bodies: HashMap::new(),
guards: HashMap::new(),
owned,
location,
write_result,
}
Expand Down Expand Up @@ -2411,13 +2422,8 @@ impl<'a> LowerMethod<'a> {
let input_var = vars.new_variable(input_type);
let after_block = self.add_block();
let loc = self.add_location(node.location.clone());
let mut state = DecisionState::new(
output_reg,
after_block,
input_type.is_owned_or_uni(self.db()),
node.write_result,
loc,
);
let mut state =
DecisionState::new(output_reg, after_block, node.write_result, loc);

for case in node.cases {
let var_regs = self.match_binding_registers(case.variable_ids);
Expand Down Expand Up @@ -2662,48 +2668,57 @@ impl<'a> LowerMethod<'a> {
self.mark_register_as_moved(source);
self.add_drop_flag(target, loc);

if state.increment.contains(&source) {
let typ = self.register_type(source);

if typ.is_value_type(self.db()) {
let copy =
self.clone_value_type(source, typ, false, loc);

self.mark_register_as_moved(copy);
match state.actions.get(&source) {
Some(&RegisterAction::Move(parent)) => {
// We mark the parent as _partially_ moved so we can
// still deallocate it, but know not to run its
// destructor.
self.mark_register_as_partially_moved(parent);
self.current_block_mut()
.move_register(target, copy, loc);
} else {
.move_register(target, source, loc);
}
Some(&RegisterAction::Increment(_)) => {
let typ = self.register_type(source);

if typ.is_value_type(self.db()) {
let copy = self
.clone_value_type(source, typ, false, loc);

self.mark_register_as_moved(copy);
self.current_block_mut()
.move_register(target, copy, loc);
} else {
self.current_block_mut()
.reference(target, source, loc);
}
}
None => {
self.current_block_mut()
.reference(target, source, loc);
.move_register(target, source, loc);
}
} else {
self.current_block_mut()
.move_register(target, source, loc);
}
}
pmatch::Binding::Ignored(pvar) => {
let reg = state.registers[pvar.0];

self.mark_register_as_moved(reg);

if self.register_type(reg).is_permanent(self.db()) {
continue;
}
match state.actions.get(&reg) {
Some(&RegisterAction::Move(parent)) => {
self.mark_register_as_partially_moved(parent);

// If the value matched against is owned, it's destructured
// as part of the match. This means any fields ignored need
// to be dropped. If the match input is a reference no
// change is needed, because the reference count isn't
// modified unless we bind the fields to a variable.
if !state.increment.contains(&reg) {
if state.owned {
// Owned values are destructured, so we can't run
// their destructors.
self.current_block_mut()
.drop_without_dropper(reg, loc);
} else {
self.drop_register(reg, loc);
if self.register_type(reg).is_permanent(self.db()) {
self.mark_register_as_moved(reg);
} else {
self.drop_register(reg, loc);
}
}
None => {
if self.register_type(reg).is_permanent(self.db()) {
self.mark_register_as_moved(reg);
} else {
self.drop_register(reg, loc);
}
}
_ => self.mark_register_as_moved(reg),
}
}
}
Expand Down Expand Up @@ -2768,23 +2783,36 @@ impl<'a> LowerMethod<'a> {
let loc = state.location;

while let Some(reg) = registers.pop() {
if !self.register_is_available(reg) {
// We may encounter values partially moved, such as for the pattern
// `(a, b)` where the surrounding tuple is partially moved.
if self.register_is_moved(reg) {
continue;
}

match state.actions.get(&reg) {
Some(
&RegisterAction::Move(parent)
| &RegisterAction::Increment(parent),
) if self.register_is_moved(parent) => {
continue;
}
Some(&RegisterAction::Increment(_)) => {
// Registers are only incremented when bound. If we reach
// this point it means the register is never bound, and thus
// no dropping is needed.
self.mark_register_as_moved(reg);
continue;
}
_ => {}
}

self.mark_register_as_moved(reg);

if self.register_type(reg).is_permanent(self.db()) {
continue;
}

// If the register is the match input register, we always need to
// drop it. If the input is a reference, we don't want to drop
// unbound intermediate registers, because they were never
// incremented in the first place.
if reg == state.input_register() || state.owned {
self.current_block_mut().drop_without_dropper(reg, loc);
}
self.current_block_mut().drop_without_dropper(reg, loc);
}
}

Expand Down Expand Up @@ -2819,13 +2847,14 @@ impl<'a> LowerMethod<'a> {
cases: Vec<pmatch::Case>,
fallback_node: pmatch::Decision,
parent_block: BlockId,
registers: Vec<RegisterId>,
mut registers: Vec<RegisterId>,
) -> BlockId {
let blocks = self.add_blocks(cases.len());
let loc = state.location;

self.add_edge(parent_block, blocks[0]);
self.connect_block_sequence(&blocks);
registers.push(test_reg);

let fallback = self.decision(
state,
Expand Down Expand Up @@ -2877,13 +2906,14 @@ impl<'a> LowerMethod<'a> {
cases: Vec<pmatch::Case>,
fallback_node: pmatch::Decision,
parent_block: BlockId,
registers: Vec<RegisterId>,
mut registers: Vec<RegisterId>,
) -> BlockId {
let loc = state.location;
let blocks = self.add_blocks(cases.len());

self.add_edge(parent_block, blocks[0]);
self.connect_block_sequence(&blocks);
registers.push(test_reg);

let fallback = self.decision(
state,
Expand Down Expand Up @@ -2961,9 +2991,13 @@ impl<'a> LowerMethod<'a> {
let class =
self.register_type(test_reg).class_id(self.db()).unwrap();

if !test_type.is_owned_or_uni(self.db()) {
state.increment.insert(reg);
}
let action = if test_type.is_owned_or_uni(self.db()) {
RegisterAction::Move(test_reg)
} else {
RegisterAction::Increment(test_reg)
};

state.actions.insert(reg, action);

self.block_mut(parent_block)
.get_field(reg, test_reg, class, field, loc);
Expand Down Expand Up @@ -3017,11 +3051,13 @@ impl<'a> LowerMethod<'a> {
case.arguments.into_iter().zip(member_regs.iter())
{
let reg = state.registers[arg.0];
let action = if test_type.is_owned_or_uni(self.db()) {
RegisterAction::Move(test_reg)
} else {
RegisterAction::Increment(test_reg)
};

if !test_type.is_owned_or_uni(self.db()) {
state.increment.insert(reg);
}

state.actions.insert(reg, action);
self.block_mut(block).move_register(reg, member_reg, loc);
self.mark_register_as_moved_in_block(member_reg, block);
}
Expand Down Expand Up @@ -3513,10 +3549,7 @@ impl<'a> LowerMethod<'a> {
self.record_loop_move(register, location);

if self.register_kind(register).is_field() {
self.update_register_state(
self.self_register,
RegisterState::PartiallyMoved,
);
self.mark_register_as_partially_moved(self.self_register);
}

if let Some(flag) = self.drop_flags.get(&register).cloned() {
Expand Down Expand Up @@ -3623,10 +3656,7 @@ impl<'a> LowerMethod<'a> {
return;
}

self.update_register_state(
self.self_register,
RegisterState::PartiallyMoved,
);
self.mark_register_as_partially_moved(self.self_register);
}

fn clone_value_type(
Expand Down Expand Up @@ -4168,6 +4198,10 @@ impl<'a> LowerMethod<'a> {
final_state
}

fn mark_register_as_partially_moved(&mut self, register: RegisterId) {
self.update_register_state(register, RegisterState::PartiallyMoved);
}

fn mark_register_as_moved(&mut self, register: RegisterId) {
self.update_register_state(register, RegisterState::Moved);
}
Expand Down
10 changes: 10 additions & 0 deletions std/test/compiler/test_pattern_matching.inko
Expand Up @@ -40,4 +40,14 @@ fn pub tests(t: mut Tests) {

drop(state)
}

t.no_panic('match with a wildcard drops all components') fn {
let a = Letter.A
let b = Letter.B

match (ref a, ref b) {
case (A, A) -> true
case _ -> false
}
}
}

0 comments on commit a5a10f1

Please sign in to comment.