Skip to content

Commit

Permalink
Extract Rust predicate expression generation
Browse files Browse the repository at this point in the history
Also rename `test_simple_filter` to `test_simple_boolean_relation_filter` so it
is easier to grep for where we use boolean relations in tests.
  • Loading branch information
tylerhou committed Apr 6, 2023
1 parent fc2eac3 commit d98c6f6
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 35 deletions.
89 changes: 55 additions & 34 deletions hydroflow_datalog_core/src/join_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,13 @@ fn find_relation_local_constraints<'a>(
indices_grouped_by_var
}

/// Given a mapping from variable names to their repeated indices, builds a Rust expression that
/// tests whether the values at those indices are equal for each variable.
/// Builds a Rust expression that evaluates whether a relation's local constraints are satisfied.
///
/// For example, `rel(a, b, a, a, b)` would give us the map `{ "a" => [0, 2, 3], "b" => [1, 4] }`.
/// Then we would want to generate the code `row.0 == row.2 && row.0 == row.3 && row.1 == row.4`.
fn build_local_constraint_conditions(constraints: &BTreeMap<String, Vec<usize>>) -> syn::Expr {
/// For example, suppose we have the Datalog relation `rel(a, b, a, a, b)`. Calling
/// `find_relation_local_constraints` gives us the map `{ "a" => [0, 2, 3], "b" => [1, 4] }`. A
/// Rust expression that tests whether these local constraints are satisfied is `row.0 == row.2 &&
/// row.0 == row.3 && row.1 == row.4`.
fn build_rust_local_constraint_expr(constraints: &BTreeMap<String, Vec<usize>>) -> syn::Expr {
constraints
.values()
.flat_map(|indices| {
Expand All @@ -175,7 +176,14 @@ fn build_local_constraint_conditions(constraints: &BTreeMap<String, Vec<usize>>)
.unwrap()
}

fn gen_predicate_value_expr(
/// Builds a Rust expression that computes a Datalog IntExpr, recursively.
///
/// For example, given the Datalog IntExpr `1 + 2 + a`, generates the Rust code `1 + 2 + row.0`
/// (assuming the variable `a` is at index 0 in the variable mapping).
///
/// Note: The code currently does not generate parens around subexpressions because addition and
/// subtraction associate.
fn build_rust_int_expr(
expr: &IntExpr,
variable_mapping: &BTreeMap<String, usize>,
diagnostics: &mut Vec<Diagnostic>,
Expand All @@ -200,18 +208,49 @@ fn gen_predicate_value_expr(
lit: syn::Lit::Int(syn::LitInt::new(&i.to_string(), get_span(i.span))),
}),
IntExpr::Add(l, _, r) => {
let l = gen_predicate_value_expr(l, variable_mapping, diagnostics, get_span);
let r = gen_predicate_value_expr(r, variable_mapping, diagnostics, get_span);
let l = build_rust_int_expr(l, variable_mapping, diagnostics, get_span);
let r = build_rust_int_expr(r, variable_mapping, diagnostics, get_span);
parse_quote!(#l + #r)
}
IntExpr::Sub(l, _, r) => {
let l = gen_predicate_value_expr(l, variable_mapping, diagnostics, get_span);
let r = gen_predicate_value_expr(r, variable_mapping, diagnostics, get_span);
let l = build_rust_int_expr(l, variable_mapping, diagnostics, get_span);
let r = build_rust_int_expr(r, variable_mapping, diagnostics, get_span);
parse_quote!(#l - #r)
}
}
}

/// Builds a Rust expression that computes the Dataflow predicates.
///
/// For example, given Dataflow predicates like `( a >= b ), ( b = c )`, builds a Rust expression
/// equivalent to `(a >= b) && (b = c)`. ("Equivalent" because instead of emitting `a`, `b`, `c`,
/// we emit `row.a-idx`, `row.b-idx`, `row.c-idx`, where a-idx, b-idx, c-idx are the respective
/// indices in the variable mapping).
fn build_rust_predicate_expr(
predicates: &[&Spanned<BoolExpr>],
variable_mapping: &BTreeMap<String, usize>,
diagnostics: &mut Vec<Diagnostic>,
get_span: &impl Fn((usize, usize)) -> Span,
) -> syn::Expr {
predicates
.iter()
.map(|p| {
let l = build_rust_int_expr(&p.left, variable_mapping, diagnostics, get_span);
let r = build_rust_int_expr(&p.right, variable_mapping, diagnostics, get_span);

match &p.op {
BoolOp::Lt(_) => parse_quote_spanned!(get_span(p.span)=> #l < #r),
BoolOp::LtEq(_) => parse_quote_spanned!(get_span(p.span)=> #l <= #r),
BoolOp::Gt(_) => parse_quote_spanned!(get_span(p.span)=> #l > #r),
BoolOp::GtEq(_) => parse_quote_spanned!(get_span(p.span)=> #l >= #r),
BoolOp::Eq(_) => parse_quote_spanned!(get_span(p.span)=> #l == #r),
BoolOp::Neq(_) => parse_quote_spanned!(get_span(p.span)=> #l != #r),
}
})
.reduce(|a: syn::Expr, b| parse_quote!(#a && #b))
.unwrap()
}

/// Generates a Hydroflow pipeline that computes the output to a given [`JoinPlan`].
pub fn expand_join_plan(
// The plan we are converting to a Hydroflow pipeline.
Expand Down Expand Up @@ -269,7 +308,7 @@ pub fn expand_join_plan(
Span::call_site(),
);

let conditions = build_local_constraint_conditions(&local_constraints);
let conditions = build_rust_local_constraint_expr(&local_constraints);

flat_graph_builder.add_statement(parse_quote_spanned! {get_span(rule_span)=>
#filter_node = #relation_node [#relation_idx] -> filter(|row: &#row_type| #conditions)
Expand Down Expand Up @@ -409,8 +448,8 @@ pub fn expand_join_plan(
.add_statement(parse_quote_spanned!(get_span(rule_span)=> #join_node = anti_join() -> map(#flatten_closure)));
} else {
flat_graph_builder.add_statement(
parse_quote_spanned!(get_span(rule_span)=> #join_node = join::<'tick>() -> map(#flatten_closure)),
);
parse_quote_spanned!(get_span(rule_span)=> #join_node = join::<'tick>() -> map(#flatten_closure)),
);
}

let output_type = repeat_tuple::<syn::Type, syn::Type>(
Expand Down Expand Up @@ -481,26 +520,8 @@ pub fn expand_join_plan(
let row_type = inner_expanded.tuple_type;
let variable_mapping = &inner_expanded.variable_mapping;

let conditions = predicates
.iter()
.map(|p| {
let l =
gen_predicate_value_expr(&p.left, variable_mapping, diagnostics, get_span);
let r =
gen_predicate_value_expr(&p.right, variable_mapping, diagnostics, get_span);

match &p.op {
BoolOp::Lt(_) => parse_quote_spanned!(get_span(p.span)=> #l < #r),
BoolOp::LtEq(_) => parse_quote_spanned!(get_span(p.span)=> #l <= #r),
BoolOp::Gt(_) => parse_quote_spanned!(get_span(p.span)=> #l > #r),
BoolOp::GtEq(_) => parse_quote_spanned!(get_span(p.span)=> #l >= #r),
BoolOp::Eq(_) => parse_quote_spanned!(get_span(p.span)=> #l == #r),
BoolOp::Neq(_) => parse_quote_spanned!(get_span(p.span)=> #l != #r),
}
})
.reduce(|a: syn::Expr, b| parse_quote!(#a && #b))
.unwrap();

let predicate_expr =
build_rust_predicate_expr(predicates, variable_mapping, diagnostics, get_span);
let predicate_filter_node = syn::Ident::new(
&format!(
"predicate_{}_filter",
Expand All @@ -510,7 +531,7 @@ pub fn expand_join_plan(
);

flat_graph_builder.add_statement(parse_quote_spanned! { get_span(rule_span)=>
#predicate_filter_node = #inner_name -> filter(|row: &#row_type| #conditions )
#predicate_filter_node = #inner_name -> filter(|row: &#row_type| #predicate_expr )
});

IntermediateJoinNode {
Expand Down
2 changes: 1 addition & 1 deletion hydroflow_datalog_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -831,7 +831,7 @@ mod tests {
}

#[test]
fn test_simple_filter() {
fn test_simple_boolean_relation_filter() {
test_snapshots!(
r#"
.input input `source_stream(input)`
Expand Down

0 comments on commit d98c6f6

Please sign in to comment.