Skip to content

Commit

Permalink
fix the get_at not folding bug
Browse files Browse the repository at this point in the history
  • Loading branch information
clyben committed Jun 7, 2024
1 parent f47470f commit 900f5f9
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 112 deletions.
86 changes: 20 additions & 66 deletions dag_in_context/src/pretty_print.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,14 @@ use crate::{
};
use egglog::{Term, TermDag};

use std::{
collections::{BTreeMap, HashMap},
hash::Hash,
rc::Rc,
vec,
};
use std::{collections::HashMap, hash::Hash, rc::Rc, vec};

#[derive(Default)]
pub struct PrettyPrinter {
// Type/Assum/BaseType -> intermediate variables
symbols: HashMap<NodeRef, String>,
// intermediate variable -> Type/Assum/BaseType lookup
table: BTreeMap<String, AstNode>,
table: HashMap<String, AstNode>,
fresh_count: u64,
}

Expand Down Expand Up @@ -136,7 +131,7 @@ impl PrettyPrinter {
}

pub fn to_egglog_default(&mut self, expr: &RcExpr) -> (String, String) {
self.to_egglog(expr, &|rc, len| (rc > 1 && len > 30) || len > 80)
self.to_egglog(expr, &|rc, len| (rc > 3 && len > 30) || len > 80)
}

// turn the Expr to a nested egglog with intermediate variables.
Expand Down Expand Up @@ -168,7 +163,7 @@ impl PrettyPrinter {
}

pub fn to_rust_default(&mut self, expr: &RcExpr) -> (String, String) {
self.to_rust(expr, &|rc, len| (rc > 1 && len > 30) || len > 80)
self.to_rust(expr, &|rc, len| (rc > 3 && len > 30) || len > 80)
}

// turn the Expr to a rust ast macro string.
Expand All @@ -182,7 +177,7 @@ impl PrettyPrinter {
fold_when: &dyn Fn(usize, usize) -> bool,
) -> (String, String) {
let mut log = vec![];
let res = self.refactor_shared_expr(expr, fold_when, false, &mut log);
let res = self.refactor_shared_expr(expr, fold_when, true, &mut log);
let log = log
.iter()
.map(|fresh_var| {
Expand All @@ -207,10 +202,7 @@ impl PrettyPrinter {

fn try_insert_fresh(&mut self, var: NodeRef, info: String) -> String {
match self.symbols.get(&var) {
Some(binding) => {
println!("found {}", binding.clone());
binding.clone()
}
Some(binding) => binding.clone(),
None => {
let fresh_var = &self.mk_fresh(info);
self.symbols.insert(var, fresh_var.clone());
Expand Down Expand Up @@ -276,16 +268,13 @@ impl PrettyPrinter {
log: &mut Vec<String>,
) -> RcExpr {
let old_expr_addr = Rc::as_ptr(expr);
// println!("{:?} has address {:?}", expr, old_expr_addr);
let num_shared = Rc::strong_count(expr);
let fold = |pp: &mut PrettyPrinter, new_expr: RcExpr, log: &mut Vec<String>| {
let binding = pp.try_insert_fresh(NodeRef::Expr(old_expr_addr), expr.abbrev());
if !pp.table.contains_key(&binding) {
log.push(binding.clone());
pp.table
.insert(binding.clone(), AstNode::Expr(new_expr.as_ref().clone()));
} else {
println!("found dup");
}
Rc::new(Expr::Symbolic(binding))
};
Expand All @@ -306,35 +295,34 @@ impl PrettyPrinter {
Rc::new(Expr::Symbolic(binding.to_owned()))
} else {
match expr.as_ref() {
Expr::Const(c, ty, assum) if !to_rust => {
Expr::Const(c, ty, assum) => {
let ty = self.refactor_shared_type(ty, log);
let assum = self.refactor_shared_assum(assum, fold_when, to_rust, log);
let c = Rc::new(Expr::Const(c.clone(), ty, assum));
fold(self, c, log)
if to_rust {
c
} else {
fold(self, c, log)
}
}
Expr::Get(x, pos) if matches!(x.as_ref(), Expr::Arg(..)) && !to_rust => {
// fold Get Arg i anyway
let sub_expr = self.refactor_shared_expr(x, fold_when, to_rust, log);
let get = Rc::new(Expr::Get(sub_expr, *pos));
return fold(self, get, log);
Expr::Get(x, pos) if matches!(x.as_ref(), Expr::Arg(..)) => {
if to_rust {
expr.clone()
} else {
let sub_expr = self.refactor_shared_expr(x, fold_when, to_rust, log);
let get = Rc::new(Expr::Get(sub_expr, *pos));
fold(self, get, log)
}
}
Expr::Symbolic(_) => panic!("Expected non symbolic"),
_ => {
// but you should not put it here as it will not fold ty and assum
// and type and assume must be folded to make sure it's correct
let expr2 = expr.map_expr_type(|ty| self.refactor_shared_type(ty, log));
//println!("{}",expr2.clone());
let expr3 = expr2.map_expr_assum(|assum| {
self.refactor_shared_assum(assum, fold_when, to_rust, log)
});
//println!("{}",expr3.clone());
// you should not call refactor_shard_expr on new expr
let mapped_expr = expr3.map_expr_children(|e| {
self.refactor_shared_expr(e, fold_when, to_rust, log)
});
//println!("{}",mapped_expr.clone());
// in conclusion this three function must happen at the same time.
// but mut reference don't allow this to happen
fold_or_plain(self, mapped_expr, log)
}
}
Expand Down Expand Up @@ -698,41 +686,7 @@ fn test_pretty_print() -> crate::Result {
let expr_str = concat_loop.to_string();
let (egglog, binding) = PrettyPrinter::default().to_egglog_default(&concat_loop);
let (ast, _) = PrettyPrinter::default().to_rust_default(&concat_loop);
println!("{}", egglog.clone());
assert_snapshot!(ast);
let check = format!("(let unfold {expr_str})\n {egglog} \n(check (= {binding} unfold))\n");

egglog_test(
"",
&check,
vec![],
Value::Tuple(vec![]),
Value::Tuple(vec![]),
vec![],
)
}

#[test]
fn test_pretty_print2() -> crate::Result {
use crate::ast::*;
use crate::egglog_test;
use crate::Value;
use insta::assert_snapshot;
let output_ty = tuplet!(intt(), intt(), statet());
let getat0 = getat(0).with_arg_types(output_ty.clone(), base(intt()));
let getat1 = getat(1).with_arg_types(output_ty.clone(), base(intt()));
let inv = sub(getat0.clone(), getat1.clone());
let pred = ttrue();
let my_loop = dowhile(
parallel!(int(1), int(2), getat(0)),
concat(parallel!(pred.clone(), getat0, inv), single(getat(2))),
)
.with_arg_types(tuplet!(statet()), output_ty.clone())
.add_ctx(schema::Assumption::dummy());
let expr_str = my_loop.to_string();
let (egglog, binding) = PrettyPrinter::default().to_egglog_default(&my_loop);
let (ast, _) = PrettyPrinter::default().to_rust_default(&my_loop);
println!("{}", egglog.clone());
assert_snapshot!(ast);
let check = format!("(let unfold {expr_str})\n {egglog} \n(check (= {binding} unfold))\n");

Expand Down
2 changes: 1 addition & 1 deletion dag_in_context/src/schema_helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ impl Expr {
Expr::Switch(cond, inputs, branches) => {
let br = branches
.iter()
.map(|branch| map_child(branch))
.map(&mut map_child)
.collect::<Vec<_>>();
Rc::new(Expr::Switch(map_child(cond), map_child(inputs), br))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,52 +2,48 @@
source: src/pretty_print.rs
expression: ast
---
let in_func_v0 = infunc("dummy");
let tpl_s_v1 = tuplet!(statet());
let int1_v2 = int(1);
let int2_v3 = int(2);
let int3_v4 = int(3);
let int4_v5 = int(4);
let tpl_i_i_i_i_s_v6 = tuplet!(intt(), intt(), intt(), intt(), statet());
let less_than_v7 = less_than(getat(0),
let tpl_s_v0 = tuplet!(statet());
let in_func_v1 = infunc("dummy");
let concat_v2 = parallel!(int(4), getat(0));
let concat_v3 = concat(single(int(2)),
concat(single(int(3)),
concat_v2.clone()));
let concat_v4 = concat(single(int(1)),
concat_v3.clone());
let less_than_v5 = less_than(getat(0),
getat(3));
let sub_v8 = sub(getat(2),
let single_v6 = single(getat(0));
let single_v7 = single(getat(1));
let single_v8 = single(getat(2));
let single_v9 = single(getat(3));
let sub_v10 = sub(getat(2),
getat(1));
let tprint_v9 = tprint(sub_v8.clone(),
let tprint_v11 = tprint(sub_v10.clone(),
getat(4));
let in_loop_v10 = inloop(parallel!(int1_v2.clone(), int2_v3.clone(), int3_v4.clone(), int4_v5.clone(), getat(0)),
parallel!(less_than_v7.clone(), getat(0), getat(1), getat(2), getat(3), tprint_v9.clone()));
let less_than_v11 = less_than(getat(0),
getat(3));
let sub_v12 = sub(getat(2),
getat(1));
let tprint_v13 = tprint(sub_v12.clone(),
getat(4));
let dowhile_v14 = dowhile(parallel!(int1_v2.clone(), int2_v3.clone(), int3_v4.clone(), int4_v5.clone(), getat(0)),
parallel!(less_than_v11.clone(), getat(0), getat(1), getat(2), getat(3), tprint_v13.clone()));
let tpl__v15 = emptyt();
let int1_v16 = int(1);
let tpl_i_v17 = tuplet!(intt());
let int3_v18 = int(3);
let less_than_v19 = less_than(getat(0),
int3_v18.clone());
let int0_v20 = int(0);
let int4_v21 = int(4);
let int5_v22 = int(5);
let switch_v23 = switch!(int0_v20.clone(), arg(); parallel!(int4_v21.clone(), int5_v22.clone()));
let in_loop_v24 = inloop(single(int1_v16.clone()),
parallel!(less_than_v19.clone(), get(switch_v23.clone(), 0)));
let int3_v25 = int(3);
let less_than_v26 = less_than(getat(0),
int3_v25.clone());
let int0_v27 = int(0);
let in_switch_v28 = inswitch(0,
int0_v27.clone(),
let concat_v12 = concat(concat(single(less_than_v5.clone()),
concat(single_v6.clone(),
single_v7.clone())),
concat(concat(single_v8.clone(),
single_v9.clone()),
single(tprint_v11.clone())));
let tpl__v13 = emptyt();
let single_v14 = single(int(1));
let tpl_i_v15 = tuplet!(intt());
let single_v16 = single(less_than(getat(0),
int(3)));
let switch_v17 = switch!(int(0), arg(); parallel!(int(4), int(5)));
let concat_v18 = concat(single_v16.clone(),
single(get(switch_v17.clone(), 0)));
let in_loop_v19 = inloop(single_v14.clone(),
concat_v18.clone());
let less_than_v20 = less_than(getat(0),
int(3));
let in_switch_v21 = inswitch(0,
int(0),
arg());
let int4_v29 = int(4);
let int5_v30 = int(5);
let switch_v31 = switch!(int0_v27.clone(), arg(); parallel!(int4_v29.clone(), int5_v30.clone()));
let dowhile_v32 = dowhile(single(int1_v16.clone()),
parallel!(less_than_v26.clone(), get(switch_v31.clone(), 0)));
let concat_v33 = concat(dowhile_v14.clone(),
dowhile_v32.clone());
let switch_v22 = switch!(int(0), arg(); parallel!(int(4), int(5)));
let concat_v23 = concat(dowhile(concat_v4.clone(),
concat_v12.clone()),
dowhile(single_v14.clone(),
parallel!(less_than_v20.clone(), get(switch_v22.clone(), 0))));
let concat_v24 = concat_v23.clone();

0 comments on commit 900f5f9

Please sign in to comment.