Skip to content

Commit

Permalink
Refactor how guard patterns are compiled
Browse files Browse the repository at this point in the history
Instead of allocating guard blocks ahead of time and reusing those, the
pattern match compiler is given the HIR expressions of guards directly.
This simplifies the compiler, and ensures that when OR patterns are
guarded, the guards for each branch are compiled separately.

This fixes #599.

Changelog: fixed
  • Loading branch information
yorickpeterse committed Feb 1, 2024
1 parent ac24bcc commit 895826b
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 56 deletions.
32 changes: 4 additions & 28 deletions compiler/src/mir/passes.rs
Expand Up @@ -316,10 +316,6 @@ struct DecisionState {
(Vec<hir::Expression>, Vec<RegisterId>, SourceLocation),
>,

/// The basic blocks for every guard, and the expression to compile for
/// them.
guards: HashMap<BlockId, hir::Expression>,

/// The location of the `match` expression.
location: LocationId,

Expand All @@ -340,7 +336,6 @@ impl DecisionState {
registers: Vec::new(),
actions: HashMap::new(),
bodies: HashMap::new(),
guards: HashMap::new(),
location,
write_result,
}
Expand Down Expand Up @@ -2553,21 +2548,14 @@ impl<'a> LowerMethod<'a> {

for case in node.cases {
let var_regs = self.match_binding_registers(case.variable_ids);
let guard = case.guard.map(|expr| {
let block = self.add_block();

state.guards.insert(block, expr);
block
});

let block = self.add_block();
let pat =
pmatch::Pattern::from_hir(self.db(), self.mir, case.pattern);
let col = pmatch::Column::new(input_var, pat);
let body = pmatch::Body::new(block);

state.bodies.insert(block, (case.body, var_regs, case.location));
rows.push(pmatch::Row::new(vec![col], guard, body));
rows.push(pmatch::Row::new(vec![col], case.guard, body));
}

let bounds = self.method.id.bounds(self.db()).clone();
Expand Down Expand Up @@ -2639,22 +2627,10 @@ impl<'a> LowerMethod<'a> {
self.decision_body(state, self.current_block, body_block);
vars_block
}
pmatch::Decision::Guard(guard, ok, fail) => {
self.add_edge(parent_block, guard);

let guard_node = if let Some(node) = state.guards.remove(&guard)
{
node
} else {
// It's possible we visit the same guard twice, such as when
// encountering the case `case A or B if X -> {}`, as the
// guard is visited for every pattern in the OR pattern. In
// this case we compile the guard on the first visit, then
// return the block as-is (making sure to still connect the
// parent block above).
return guard;
};
pmatch::Decision::Guard(guard_node, ok, fail) => {
let guard = self.add_block();

self.add_edge(parent_block, guard);
self.enter_scope();

// Bindings are defined _after_ the guard, otherwise the failure
Expand Down
104 changes: 76 additions & 28 deletions compiler/src/mir/pattern_matching.rs
Expand Up @@ -305,14 +305,14 @@ impl Variable {
#[derive(Clone, Eq, PartialEq, Debug)]
pub(crate) struct Row {
columns: Vec<Column>,
guard: Option<BlockId>,
guard: Option<hir::Expression>,
body: Body,
}

impl Row {
pub(crate) fn new(
columns: Vec<Column>,
guard: Option<BlockId>,
guard: Option<hir::Expression>,
body: Body,
) -> Self {
Self { columns, guard, body }
Expand Down Expand Up @@ -378,7 +378,7 @@ pub(crate) enum Decision {
/// 1. The guard to evaluate.
/// 2. The body to evaluate if the guard matches.
/// 3. The sub tree to evaluate when the guard fails.
Guard(BlockId, Body, Box<Decision>),
Guard(hir::Expression, Body, Box<Decision>),

/// Checks if a value is any of the given patterns.
///
Expand Down Expand Up @@ -884,13 +884,22 @@ impl<'a> Compiler<'a> {
mod tests {
use super::*;
use crate::config::Config;
use ast::source_location::SourceLocation;
use similar_asserts::assert_eq;
use types::module_name::ModuleName;
use types::{
Class, ClassInstance, ClassKind, Module, TypeId,
Variable as VariableType, VariableLocation, Visibility,
};

fn expr(value: i64) -> hir::Expression {
hir::Expression::Int(Box::new(hir::IntLiteral {
resolved_type: types::TypeRef::Unknown,
value,
location: SourceLocation::new(1..=1, 1..=1),
}))
}

fn state() -> State {
State::new(Config::new())
}
Expand All @@ -906,7 +915,7 @@ mod tests {

fn rules_with_guard(
input: Variable,
patterns: Vec<(Pattern, Option<BlockId>, BlockId)>,
patterns: Vec<(Pattern, Option<hir::Expression>, BlockId)>,
) -> Vec<Row> {
patterns
.into_iter()
Expand All @@ -924,18 +933,22 @@ mod tests {
Decision::Success(Body::new(block))
}

fn guard(block: BlockId, body: BlockId, fallback: Decision) -> Decision {
Decision::Guard(block, Body::new(body), Box::new(fallback))
fn guard(
code: hir::Expression,
body: BlockId,
fallback: Decision,
) -> Decision {
Decision::Guard(code, Body::new(body), Box::new(fallback))
}

fn guard_with_bindings(
block: BlockId,
code: hir::Expression,
bindings: Vec<Binding>,
body: BlockId,
fallback: Decision,
) -> Decision {
Decision::Guard(
block,
code,
Body { bindings, block_id: body },
Box::new(fallback),
)
Expand Down Expand Up @@ -1612,7 +1625,7 @@ mod tests {
input,
vec![
(Pattern::Int(4), None, BlockId(1)),
(Pattern::Wildcard, Some(BlockId(3)), BlockId(2)),
(Pattern::Wildcard, Some(expr(3)), BlockId(2)),
],
));

Expand All @@ -1626,7 +1639,7 @@ mod tests {
success(BlockId(1))
)],
Some(Box::new(guard_with_bindings(
BlockId(3),
expr(3),
vec![Binding::Ignored(input)],
BlockId(2),
fail()
Expand All @@ -1645,7 +1658,7 @@ mod tests {
let result = compiler.compile(rules_with_guard(
input,
vec![
(Pattern::Int(4), Some(BlockId(3)), BlockId(1)),
(Pattern::Int(4), Some(expr(3)), BlockId(1)),
(Pattern::Wildcard, None, BlockId(2)),
],
));
Expand All @@ -1658,7 +1671,7 @@ mod tests {
Constructor::Int(4),
Vec::new(),
guard(
BlockId(3),
expr(3),
BlockId(1),
success_with_bindings(
vec![Binding::Ignored(input)],
Expand All @@ -1674,6 +1687,49 @@ mod tests {
);
}

#[test]
fn test_guard_with_or_pattern() {
let mut state = state();
let mut compiler = compiler(&mut state);
let input = compiler.new_variable(TypeRef::int());
let result = compiler.compile(rules_with_guard(
input,
vec![
(
Pattern::Or(vec![Pattern::Int(4), Pattern::Int(5)]),
Some(expr(42)),
BlockId(1),
),
(Pattern::Int(4), None, BlockId(2)),
(Pattern::Int(5), None, BlockId(3)),
(Pattern::Wildcard, None, BlockId(4)),
],
));

assert_eq!(
result.tree,
Decision::Switch(
input,
vec![
Case::new(
Constructor::Int(4),
Vec::new(),
guard(expr(42), BlockId(1), success(BlockId(2)))
),
Case::new(
Constructor::Int(5),
Vec::new(),
guard(expr(42), BlockId(1), success(BlockId(3)))
)
],
Some(Box::new(success_with_bindings(
vec![Binding::Ignored(input)],
BlockId(4)
)))
)
);
}

#[test]
fn test_guard_with_same_int() {
let mut state = state();
Expand All @@ -1682,8 +1738,8 @@ mod tests {
let result = compiler.compile(rules_with_guard(
input,
vec![
(Pattern::Int(4), Some(BlockId(10)), BlockId(1)),
(Pattern::Int(4), Some(BlockId(20)), BlockId(2)),
(Pattern::Int(4), Some(expr(10)), BlockId(1)),
(Pattern::Int(4), Some(expr(20)), BlockId(2)),
(Pattern::Wildcard, None, BlockId(3)),
],
));
Expand All @@ -1696,10 +1752,10 @@ mod tests {
Constructor::Int(4),
Vec::new(),
guard(
BlockId(10),
expr(10),
BlockId(1),
guard(
BlockId(20),
expr(20),
BlockId(2),
success_with_bindings(
vec![Binding::Ignored(input)],
Expand All @@ -1724,16 +1780,8 @@ mod tests {
let result = compiler.compile(rules_with_guard(
input,
vec![
(
Pattern::String("a".to_string()),
Some(BlockId(3)),
BlockId(1),
),
(
Pattern::String("a".to_string()),
Some(BlockId(4)),
BlockId(2),
),
(Pattern::String("a".to_string()), Some(expr(3)), BlockId(1)),
(Pattern::String("a".to_string()), Some(expr(4)), BlockId(2)),
(Pattern::Wildcard, None, BlockId(3)),
],
));
Expand All @@ -1746,10 +1794,10 @@ mod tests {
Constructor::String("a".to_string()),
Vec::new(),
guard(
BlockId(3),
expr(3),
BlockId(1),
guard(
BlockId(4),
expr(4),
BlockId(2),
success_with_bindings(
vec![Binding::Ignored(input)],
Expand Down

0 comments on commit 895826b

Please sign in to comment.