Skip to content

Commit

Permalink
Get rule compilation working
Browse files Browse the repository at this point in the history
This also removes these tests:
    tests/nbe.egg
    tests/typeinfer.egg

Because they didn't typecheck and it was non-trivial
to get them working again.
  • Loading branch information
mwillsey committed Oct 19, 2022
1 parent 66fbeda commit 15cccfa
Show file tree
Hide file tree
Showing 10 changed files with 480 additions and 352 deletions.
2 changes: 1 addition & 1 deletion src/ast/parse.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ NonDefineAction: Action = {
"(" "delete" "(" <f: Ident> <args:Expr*> ")" ")" => Action::Delete ( f, args),
"(" "union" <e1:Expr> <e2:Expr> ")" => Action::Union(<>),
"(" "panic" <msg:String> ")" => Action::Panic(msg),
<e:Expr> => Action::Expr(e),
<e:CallExpr> => Action::Expr(e),
}

Action: Action = {
Expand Down
15 changes: 10 additions & 5 deletions src/gj.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,23 +199,26 @@ pub struct VarInfo {
occurences: Vec<usize>,
}

type VarMap = IndexMap<Symbol, VarInfo>;

#[derive(Debug, Clone)]
pub struct CompiledQuery {
query: Query,
pub vars: VarMap,
pub vars: IndexMap<Symbol, VarInfo>,
}

impl EGraph {
pub(crate) fn compile_gj_query(
&self,
query: Query,
_types: HashMap<Symbol, ArcSort>,
types: &IndexMap<Symbol, ArcSort>,
) -> CompiledQuery {
// NOTE: this vars order only used for ordering the tuple,
// It is not the GJ variable order.
let mut vars: IndexMap<Symbol, VarInfo> = Default::default();

for var in types.keys() {
vars.entry(*var).or_default();
}

for (i, atom) in query.atoms.iter().enumerate() {
for v in atom.vars() {
// only count grounded occurrences
Expand Down Expand Up @@ -333,7 +336,9 @@ impl EGraph {
let mut program: Vec<Instr> = vars
.iter()
.map(|(&v, info)| {
let idx = query.vars.get_index_of(&v).unwrap();
let idx = query.vars.get_index_of(&v).unwrap_or_else(|| {
panic!("variable {} not found in query", v);
});
Instr::Intersect {
value_idx: idx,
trie_accesses: info
Expand Down
94 changes: 65 additions & 29 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ use std::hash::Hash;
use std::io::Read;
use std::ops::{Deref, Range};
use std::path::PathBuf;
use std::rc::Rc;
use std::{fmt::Debug, sync::Arc};
use typecheck::{AtomTerm, Bindings};
use typecheck::Program;

type ArcSort = Arc<dyn Sort>;

Expand All @@ -39,10 +40,20 @@ use crate::typecheck::TypeError;
pub struct Function {
decl: FunctionDecl,
schema: ResolvedSchema,
merge: MergeFn,
nodes: IndexMap<Vec<Value>, TupleOutput>,
updates: usize,
}

#[derive(Clone)]
enum MergeFn {
AssertEq,
Union,
// the rc is make sure it's cheaply clonable, since calling the merge fn
// requires a clone
Expr(Rc<Program>),
}

#[derive(Debug, Clone)]
struct TupleOutput {
value: Value,
Expand Down Expand Up @@ -177,7 +188,7 @@ pub type Subst = IndexMap<Symbol, Value>;

pub trait PrimitiveLike {
fn name(&self) -> Symbol;
fn accept(&self, types: &[&dyn Sort]) -> Option<ArcSort>;
fn accept(&self, types: &[ArcSort]) -> Option<ArcSort>;
fn apply(&self, values: &[Value]) -> Option<Value>;
}

Expand Down Expand Up @@ -232,7 +243,7 @@ impl PrimitiveLike for SimplePrimitive {
fn name(&self) -> Symbol {
self.name
}
fn accept(&self, types: &[&dyn Sort]) -> Option<ArcSort> {
fn accept(&self, types: &[ArcSort]) -> Option<ArcSort> {
if self.input.len() != types.len() {
return None;
}
Expand Down Expand Up @@ -266,7 +277,7 @@ pub struct EGraph {
#[derive(Clone, Debug)]
struct Rule {
query: CompiledQuery,
bindings: Bindings,
program: Program,
head: Vec<Action>,
matches: usize,
times_banned: usize,
Expand Down Expand Up @@ -589,11 +600,26 @@ impl EGraph {
None => return Err(Error::TypeError(TypeError::Unbound(decl.schema.output))),
};

let merge = if let Some(merge_expr) = &decl.merge {
let mut types = IndexMap::<Symbol, ArcSort>::default();
types.insert("old".into(), output.clone());
types.insert("new".into(), output.clone());
let program = self
.compile_expr(&types, merge_expr, output.clone())
.map_err(Error::TypeErrors)?;
MergeFn::Expr(Rc::new(program))
} else if output.is_eq_sort() {
MergeFn::Union
} else {
MergeFn::AssertEq
};

let function = Function {
decl: decl.clone(),
schema: ResolvedSchema { input, output },
nodes: Default::default(),
updates: 0,
merge,
// TODO figure out merge and default here
};

Expand Down Expand Up @@ -699,7 +725,7 @@ impl EGraph {
// HACK
let types = values
.iter()
.map(|v| &*self.sorts[&v.tag])
.map(|v| self.sorts[&v.tag].clone())
.collect::<Vec<_>>();
if prim.accept(&types).is_some() {
if res.is_none() {
Expand Down Expand Up @@ -833,24 +859,25 @@ impl EGraph {
}

fn step_rules(&mut self, iteration: usize) -> [Duration; 2] {
fn make_subst(rule: &Rule, values: &[Value]) -> Subst {
let get_val = |t: &AtomTerm| match t {
AtomTerm::Var(sym) => {
let i = rule
.query
.vars
.get_index_of(sym)
.unwrap_or_else(|| panic!("Couldn't find variable '{sym}'"));
values[i]
}
AtomTerm::Value(val) => *val,
};

rule.bindings
.iter()
.map(|(k, t)| (*k, get_val(t)))
.collect()
}
// fn make_subst(rule: &Rule, values: &[Value]) -> Subst {
// let get_val = |t: &AtomTerm| match t {
// AtomTerm::Var(sym) => {
// let i = rule
// .query
// .vars
// .get_index_of(sym)
// .unwrap_or_else(|| panic!("Couldn't find variable '{sym}'"));
// values[i]
// }
// AtomTerm::Value(val) => *val,
// };

// todo!()
// // rule.bindings
// // .iter()
// // .map(|(k, t)| (*k, get_val(t)))
// // .collect()
// }

let ban_length = 5;

Expand Down Expand Up @@ -895,15 +922,17 @@ impl EGraph {

let rule_apply_start = Instant::now();

let stack = &mut vec![];
for values in all_values.chunks(n) {
rule.matches += 1;
if rule.matches > 10_000_000 {
log::warn!("Rule {} has matched {} times, bailing!", name, rule.matches);
break 'outer;
}
let subst = make_subst(rule, values);
log::trace!("Applying with {subst:?}");
let _result: Result<_, _> = self.eval_actions(Some(subst), &rule.head);
// log::trace!("Applying with {subst:?}");
assert!(stack.is_empty());
self.run_actions(stack, values, &rule.program);
// let _result: Result<_, _> = self.eval_actions(Some(subst), &rule.head);
}

rule.apply_time += rule_apply_start.elapsed();
Expand All @@ -916,16 +945,23 @@ impl EGraph {
fn add_rule_with_name(&mut self, name: String, rule: ast::Rule) -> Result<Symbol, Error> {
let name = Symbol::from(name);
let mut ctx = typecheck::Context::new(self);
let (query0, bindings) = ctx.typecheck_query(&rule.body).map_err(Error::TypeErrors)?;
let query = self.compile_gj_query(query0, ctx.types);
let query0 = ctx.typecheck_query(&rule.body).map_err(Error::TypeErrors)?;
let query = self.compile_gj_query(query0, &ctx.types);
let program = self
.compile_actions(&ctx.types, &rule.head)
.map_err(Error::TypeErrors)?;
// println!(
// "Compiled rule {rule:?}\n{subst:?}to {program:#?}",
// subst = &ctx.types
// );
let compiled_rule = Rule {
query,
bindings,
head: rule.head,
matches: 0,
times_banned: 0,
banned_until: 0,
todo_timestamp: 0,
program,
search_time: Duration::default(),
apply_time: Duration::default(),
};
Expand Down
2 changes: 1 addition & 1 deletion src/sort/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ macro_rules! add_primitives {
$name.into()
}

fn accept(&self, types: &[&dyn Sort]) -> Option<ArcSort> {
fn accept(&self, types: &[ArcSort]) -> Option<ArcSort> {
let mut types = types.iter();
$(
if self.$param.name() != types.next()?.name() {
Expand Down
18 changes: 9 additions & 9 deletions src/sort/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl PrimitiveLike for Ctor {
self.name
}

fn accept(&self, types: &[&dyn Sort]) -> Option<ArcSort> {
fn accept(&self, types: &[ArcSort]) -> Option<ArcSort> {
match types {
[] => Some(self.map.clone()),
_ => None,
Expand All @@ -145,7 +145,7 @@ impl PrimitiveLike for Insert {
self.name
}

fn accept(&self, types: &[&dyn Sort]) -> Option<ArcSort> {
fn accept(&self, types: &[ArcSort]) -> Option<ArcSort> {
match types {
[map, key, value]
if (map.name(), (key.name(), value.name()))
Expand Down Expand Up @@ -174,7 +174,7 @@ impl PrimitiveLike for Get {
self.name
}

fn accept(&self, types: &[&dyn Sort]) -> Option<ArcSort> {
fn accept(&self, types: &[ArcSort]) -> Option<ArcSort> {
match types {
[map, key] if (map.name(), key.name()) == (self.map.name, self.map.key.name()) => {
Some(self.map.value.clone())
Expand All @@ -200,7 +200,7 @@ impl PrimitiveLike for NotContains {
self.name
}

fn accept(&self, types: &[&dyn Sort]) -> Option<ArcSort> {
fn accept(&self, types: &[ArcSort]) -> Option<ArcSort> {
match types {
[map, key] if (map.name(), key.name()) == (self.map.name, self.map.key.name()) => {
Some(self.unit.clone())
Expand Down Expand Up @@ -230,7 +230,7 @@ impl PrimitiveLike for Contains {
self.name
}

fn accept(&self, types: &[&dyn Sort]) -> Option<ArcSort> {
fn accept(&self, types: &[ArcSort]) -> Option<ArcSort> {
match types {
[map, key] if (map.name(), key.name()) == (self.map.name, self.map.key.name()) => {
Some(self.unit.clone())
Expand Down Expand Up @@ -259,7 +259,7 @@ impl PrimitiveLike for Union {
self.name
}

fn accept(&self, types: &[&dyn Sort]) -> Option<ArcSort> {
fn accept(&self, types: &[ArcSort]) -> Option<ArcSort> {
match types {
[map1, map2] if map1.name() == self.map.name && map2.name() == self.map.name() => {
Some(self.map.clone())
Expand Down Expand Up @@ -287,7 +287,7 @@ impl PrimitiveLike for Intersect {
self.name
}

fn accept(&self, types: &[&dyn Sort]) -> Option<ArcSort> {
fn accept(&self, types: &[ArcSort]) -> Option<ArcSort> {
match types {
[map1, map2] if map1.name() == self.map.name && map2.name() == self.map.name() => {
Some(self.map.clone())
Expand Down Expand Up @@ -315,10 +315,10 @@ impl PrimitiveLike for Remove {
self.name
}

fn accept(&self, types: &[&dyn Sort]) -> Option<ArcSort> {
fn accept(&self, types: &[ArcSort]) -> Option<ArcSort> {
match types {
[map, key] if (map.name(), key.name()) == (self.map.name, self.map.key.name()) => {
Some(self.map.value.clone())
Some(self.map.clone())
}
_ => None,
}
Expand Down
2 changes: 1 addition & 1 deletion src/sort/unit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ impl PrimitiveLike for NotEqualPrimitive {
"!=".into()
}

fn accept(&self, types: &[&dyn Sort]) -> Option<ArcSort> {
fn accept(&self, types: &[ArcSort]) -> Option<ArcSort> {
match types {
[a, b] if a.name() == b.name() => Some(self.unit.clone()),
_ => None,
Expand Down
Loading

0 comments on commit 15cccfa

Please sign in to comment.