Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add action to set custom cost #355

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 33 additions & 3 deletions src/actions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,18 @@ impl<'a> ActionCompiler<'a> {
self.do_atom_term(e);
self.instructions.push(Instruction::Set(func.name));
}
GenericCoreAction::Cost(_ann, f, args, e) => {
let ResolvedCall::Func(func) = f else {
panic!(
"Cannot set cost of primitive- should have been caught by typechecking!!!"
)
};
for arg in args {
self.do_atom_term(arg);
}
self.do_atom_term(e);
self.instructions.push(Instruction::Cost(func.name));
}
GenericCoreAction::Change(_ann, change, f, args) => {
let ResolvedCall::Func(func) = f else {
panic!("Cannot change primitive- should have been caught by typechecking!!!")
Expand Down Expand Up @@ -128,12 +140,15 @@ enum Instruction {
/// Pop primitive arguments off the stack, calls the primitive,
/// and push the result onto the stack.
CallPrimitive(SpecializedPrimitive, usize),
/// Pop function arguments off the stack and either deletes or subsumes the corresponding row
/// in the function.
/// Pop function arguments off the stack and either deletes, subsumes, or changes the cost
/// of the corresponding row in the function.
saulshanabrook marked this conversation as resolved.
Show resolved Hide resolved
Change(Change, Symbol),
/// Pop the value to be set and the function arguments off the stack.
/// Set the function at the given arguments to the new value.
Set(Symbol),
/// Pop the value to have its cost set and the function arguments off the stack.
/// Set the function at the given arguments to the new cost.
Cost(Symbol),
/// Union the last `n` values on the stack.
Union(usize, ArcSort),
/// Extract the best expression. `n` is always 2.
Expand Down Expand Up @@ -346,10 +361,25 @@ impl EGraph {
// set to union
let new_value = stack.pop().unwrap();
let new_len = stack.len() - function.schema.input.len();

self.perform_set(*f, new_value, stack)?;
stack.truncate(new_len)
}
Instruction::Cost(f) => {
let function = self.functions.get_mut(f).unwrap();
let new_cost = stack.pop().unwrap();
let new_len = stack.len() - function.schema.input.len();

let function = self.functions.get_mut(f).unwrap();

let args = &stack[new_len..];

let i64sort: Arc<I64Sort> = self.type_info.get_sort_nofail();
let cost = i64::load(&i64sort, &new_cost);
function.update_cost(args, cost.try_into().unwrap());

stack.truncate(new_len);
}

Instruction::Union(arity, sort) => {
let new_len = stack.len() - arity;
let values = &stack[new_len..];
Expand Down
29 changes: 29 additions & 0 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1222,6 +1222,13 @@ where
Vec<GenericExpr<Head, Leaf>>,
GenericExpr<Head, Leaf>,
),
/// `cost` sets the cost of a function to a particular value.
Cost(
Span,
Head,
Vec<GenericExpr<Head, Leaf>>,
GenericExpr<Head, Leaf>,
),
/// Delete or subsume (mark as hidden from future rewrites and unextractable) an entry from a function.
Change(Span, Change, Head, Vec<GenericExpr<Head, Leaf>>),
/// `union` two datatypes, making them equal
Expand Down Expand Up @@ -1304,6 +1311,7 @@ where
match self {
GenericAction::Let(_ann, lhs, rhs) => list!("let", lhs, rhs),
GenericAction::Set(_ann, lhs, args, rhs) => list!("set", list!(lhs, ++ args), rhs),
GenericAction::Cost(_ann, lhs, args, rhs) => list!("cost", list!(lhs, ++ args), rhs),
GenericAction::Union(_ann, lhs, rhs) => list!("union", lhs, rhs),
GenericAction::Change(_ann, change, lhs, args) => {
list!(
Expand Down Expand Up @@ -1344,6 +1352,15 @@ where
right,
)
}
GenericAction::Cost(span, lhs, args, rhs) => {
let right = f(rhs);
GenericAction::Cost(
span.clone(),
lhs.clone(),
args.iter().map(f).collect(),
right,
)
}
GenericAction::Change(span, change, lhs, args) => GenericAction::Change(
span.clone(),
*change,
Expand Down Expand Up @@ -1378,6 +1395,10 @@ where
let args = args.into_iter().map(|e| e.visit_exprs(f)).collect();
GenericAction::Set(span, lhs.clone(), args, rhs.visit_exprs(f))
}
GenericAction::Cost(span, lhs, args, rhs) => {
let args = args.into_iter().map(|e| e.visit_exprs(f)).collect();
GenericAction::Cost(span, lhs.clone(), args, rhs.visit_exprs(f))
}
GenericAction::Change(span, change, lhs, args) => {
let args = args.into_iter().map(|e| e.visit_exprs(f)).collect();
GenericAction::Change(span, change, lhs.clone(), args)
Expand Down Expand Up @@ -1417,6 +1438,14 @@ where
let rhs = rhs.subst_leaf(&mut fvar_expr!());
GenericAction::Set(span, lhs.clone(), args, rhs)
}
GenericAction::Cost(span, lhs, args, rhs) => {
let args = args
.into_iter()
.map(|e| e.subst_leaf(&mut fvar_expr!()))
.collect();
let rhs = rhs.subst_leaf(&mut fvar_expr!());
GenericAction::Cost(span, lhs.clone(), args, rhs)
}
GenericAction::Change(span, change, lhs, args) => {
let args = args
.into_iter()
Expand Down
6 changes: 6 additions & 0 deletions src/ast/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,12 @@ fn non_let_action(ctx: &Context) -> Res<Action> {
expr,
))
.map(|((), (f, args), v), span| Action::Set(span, f, args, v))(ctx),
"cost" => parens(sequence3(
text("cost"),
parens(sequence(ident, repeat_until_end_paren(expr))),
expr,
))
.map(|((), (f, args), v), span| Action::Cost(span, f, args, v))(ctx),
"delete" => parens(sequence(
text("delete"),
parens(sequence(ident, repeat_until_end_paren(expr))),
Expand Down
44 changes: 44 additions & 0 deletions src/constraint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,32 @@ impl Assignment<AtomTerm, ArcSort> {
rhs,
))
}
// (cost (f [*x]) rhs) where rhs should have type I64
GenericAction::Cost(
span,
CorrespondingVar {
head,
to: _mapped_var,
},
children,
rhs,
) => {
let children: Vec<_> = children
.iter()
.map(|child| self.annotate_expr(child, typeinfo))
.collect();
let types: Vec<_> = children.iter().map(|child| child.output_type()).collect();
let resolved_call =
ResolvedCall::from_resolution_func_types(head, &types, typeinfo)
.ok_or_else(|| TypeError::UnboundFunction(*head, span.clone()))?;
let resolved_rhs = self.annotate_expr(rhs, typeinfo);
Ok(ResolvedAction::Cost(
span.clone(),
resolved_call,
children.clone(),
resolved_rhs,
))
}
// Note mapped_var for delete is a dummy variable that does not mean anything
GenericAction::Change(
span,
Expand Down Expand Up @@ -541,6 +567,24 @@ impl CoreAction {
)?)
.collect())
}
CoreAction::Cost(span, head, args, rhs) => {
let mut args = args.clone();
let var = symbol_gen.fresh(head);
args.push(AtomTerm::Var(span.clone(), var));

let mut all_terms = args.clone();
all_terms.push(rhs.clone());

Ok(get_literal_and_global_constraints(&all_terms, typeinfo)
.chain(get_atom_application_constraints(
head, &args, span, typeinfo,
)?)
.chain(once(Constraint::Assign(
rhs.clone(),
typeinfo.get_sort_nofail::<I64Sort>() as ArcSort,
)))
.collect())
}
CoreAction::Change(span, _change, head, args) => {
let mut args = args.clone();
// Add a dummy last output argument
Expand Down
34 changes: 34 additions & 0 deletions src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,12 @@ pub enum GenericCoreAction<Head, Leaf> {
Vec<GenericAtomTerm<Leaf>>,
GenericAtomTerm<Leaf>,
),
Cost(
Span,
Head,
Vec<GenericAtomTerm<Leaf>>,
GenericAtomTerm<Leaf>,
),
Change(Span, Change, Head, Vec<GenericAtomTerm<Leaf>>),
Union(Span, GenericAtomTerm<Leaf>, GenericAtomTerm<Leaf>),
Panic(Span, String),
Expand Down Expand Up @@ -492,6 +498,34 @@ where
mapped_expr,
));
}
GenericAction::Cost(span, head, args, expr) => {
let mut mapped_args = vec![];
for arg in args {
let (actions, mapped_arg) =
arg.to_core_actions(typeinfo, binding, fresh_gen)?;
norm_actions.extend(actions.0);
mapped_args.push(mapped_arg);
}
let (actions, mapped_expr) =
expr.to_core_actions(typeinfo, binding, fresh_gen)?;
norm_actions.extend(actions.0);
norm_actions.push(GenericCoreAction::Cost(
span.clone(),
head.clone(),
mapped_args
.iter()
.map(|e| e.get_corresponding_var_or_lit(typeinfo))
.collect(),
mapped_expr.get_corresponding_var_or_lit(typeinfo),
));
let v = fresh_gen.fresh(head);
mapped_actions.0.push(GenericAction::Cost(
span.clone(),
CorrespondingVar::new(head.clone(), v),
mapped_args,
mapped_expr,
));
}
GenericAction::Change(span, change, head, args) => {
let mut mapped_args = vec![];
for arg in args {
Expand Down
5 changes: 3 additions & 2 deletions src/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,9 @@ impl<'a> Extractor<'a> {
function: &Function,
children: &[Value],
termdag: &mut TermDag,
cost: Option<usize>,
) -> Option<(Vec<Term>, Cost)> {
let mut cost = function.decl.cost.unwrap_or(1);
let mut cost = cost.unwrap_or(function.decl.cost.unwrap_or(1));
saulshanabrook marked this conversation as resolved.
Show resolved Hide resolved
let types = &function.schema.input;
let mut terms: Vec<Term> = vec![];
for (ty, value) in types.iter().zip(children) {
Expand All @@ -188,7 +189,7 @@ impl<'a> Extractor<'a> {
if func.schema.output.is_eq_sort() {
for (inputs, output) in func.nodes.iter(false) {
if let Some((term_inputs, new_cost)) =
self.node_total_cost(func, inputs, termdag)
self.node_total_cost(func, inputs, termdag, output.cost)
{
let make_new_pair = || (new_cost, termdag.app(sym, term_inputs));

Expand Down
8 changes: 7 additions & 1 deletion src/function/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ pub struct TupleOutput {
pub value: Value,
pub timestamp: u32,
pub subsumed: bool,
pub cost: Option<usize>,
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -240,6 +241,11 @@ impl Function {
self.nodes.get_mut(inputs).unwrap().subsumed = true;
}

/// Updates the cost of the given inputs.
pub fn update_cost(&mut self, inputs: &[Value], cost: usize) {
self.nodes.get_mut(inputs).unwrap().cost = Some(cost);
}

/// Return a column index that contains (a superset of) the offsets for the
/// given column. This method can return nothing if the indexes available
/// contain too many irrelevant offsets.
Expand Down Expand Up @@ -443,7 +449,7 @@ impl Function {
}
let out_ty = &self.schema.output;
self.nodes
.insert_and_merge(scratch, timestamp, out.subsumed, |prev| {
.insert_and_merge(scratch, timestamp, out.subsumed, out.cost, |prev| {
if let Some(mut prev) = prev {
out_ty.canonicalize(&mut prev, uf);
let mut appended = false;
Expand Down
9 changes: 7 additions & 2 deletions src/function/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ impl Table {
/// previous value, if there was one.
pub(crate) fn insert(&mut self, inputs: &[Value], out: Value, ts: u32) -> Option<Value> {
let mut res = None;
self.insert_and_merge(inputs, ts, false, |prev| {
self.insert_and_merge(inputs, ts, false, None, |prev| {
res = prev;
out
});
Expand All @@ -154,6 +154,7 @@ impl Table {
inputs: &[Value],
ts: u32,
subsumed: bool,
cost: Option<usize>,
on_merge: impl FnOnce(Option<Value>) -> Value,
) {
assert!(ts >= self.max_ts);
Expand All @@ -164,8 +165,9 @@ impl Table {
{
let (inp, prev) = &mut self.vals[*off];
let prev_subsumed = prev.subsumed;
let prev_cost = prev.cost;
let next = on_merge(Some(prev.value));
if next == prev.value && prev_subsumed == subsumed {
if next == prev.value && prev_subsumed == subsumed && prev_cost == cost {
return;
}
inp.stale_at = ts;
Expand All @@ -178,6 +180,8 @@ impl Table {
value: next,
timestamp: ts,
subsumed: subsumed || prev_subsumed,
// Take miminum of cost and prev_cost, if both exist
cost: cost.or(prev_cost.map(|x| cost.map_or(x, |y| x.min(y)))),
},
));
*off = new_offset;
Expand All @@ -190,6 +194,7 @@ impl Table {
value: on_merge(None),
timestamp: ts,
subsumed,
cost: None,
},
));
let to = TableOffset {
Expand Down
3 changes: 2 additions & 1 deletion src/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ impl EGraph {
egraph_serialize::Node {
op: func.decl.name.to_string(),
eclass: class_id.clone(),
cost: NotNan::new(func.decl.cost.unwrap_or(1) as f64).unwrap(),
cost: NotNan::new(output.cost.unwrap_or(func.decl.cost.unwrap_or(1)) as f64)
.unwrap(),
children,
subsumed: output.subsumed,
},
Expand Down
6 changes: 3 additions & 3 deletions tests/eggcc-extraction.egg
Original file line number Diff line number Diff line change
Expand Up @@ -1362,20 +1362,20 @@

;; if we reach a new context, union
(rule ((= theta (Theta pred inputs outputs))
(= (BodyAndCost extracted cost)
(= (BodyAndCost extracted cost_)
(ExtractedBody theta)))
((union theta extracted))
:ruleset fast-analyses)
(rule ((= gamma (Gamma pred inputs outputs))
(= (BodyAndCost extracted cost)
(= (BodyAndCost extracted cost_)
(ExtractedBody gamma)))
((union gamma extracted))
:ruleset fast-analyses)


;; if we reach the function at the top level, union
(rule ((= func (Func name intypes outtypes body))
(= (VecOperandAndCost extracted cost)
(= (VecOperandAndCost extracted cost_)
(ExtractedVecOperand body)))
((union func
(Func name intypes outtypes extracted)))
Expand Down
12 changes: 12 additions & 0 deletions tests/extract-costfn.egg
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
(datatype E
(Foo String))

(union (Foo "x") (Foo "y"))
(union (Foo "y") (Foo "z"))

(cost (Foo "x") 17)
(cost (Foo "y") 11)
(cost (Foo "z") 15)

(extract (Foo "y"))
saulshanabrook marked this conversation as resolved.
Show resolved Hide resolved

3 changes: 3 additions & 0 deletions tests/fail-typecheck/custom-cost-wrong-type.egg
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
;;; test that setting the cost of a function to a non i64 will be a type error
(relation a ())
(cost (a) "hi")
saulshanabrook marked this conversation as resolved.
Show resolved Hide resolved
Loading
Loading