diff --git a/dag_in_context/src/pretty_print.rs b/dag_in_context/src/pretty_print.rs index 2e8f1c18..700fe1d1 100644 --- a/dag_in_context/src/pretty_print.rs +++ b/dag_in_context/src/pretty_print.rs @@ -9,19 +9,14 @@ use crate::{ }; use egglog::{Term, TermDag}; -use std::{ - collections::{BTreeMap, HashMap}, - hash::Hash, - rc::Rc, - vec, -}; +use std::{collections::HashMap, hash::Hash, rc::Rc, vec}; #[derive(Default)] pub struct PrettyPrinter { // Type/Assum/BaseType -> intermediate variables symbols: HashMap, // intermediate variable -> Type/Assum/BaseType lookup - table: BTreeMap, + table: HashMap, fresh_count: u64, } @@ -136,7 +131,7 @@ impl PrettyPrinter { } pub fn to_egglog_default(&mut self, expr: &RcExpr) -> (String, String) { - self.to_egglog(expr, &|rc, len| (rc > 1 && len > 30) || len > 80) + self.to_egglog(expr, &|rc, len| (rc > 3 && len > 30) || len > 80) } // turn the Expr to a nested egglog with intermediate variables. @@ -168,7 +163,7 @@ impl PrettyPrinter { } pub fn to_rust_default(&mut self, expr: &RcExpr) -> (String, String) { - self.to_rust(expr, &|rc, len| (rc > 1 && len > 30) || len > 80) + self.to_rust(expr, &|rc, len| (rc > 3 && len > 30) || len > 80) } // turn the Expr to a rust ast macro string. @@ -182,7 +177,7 @@ impl PrettyPrinter { fold_when: &dyn Fn(usize, usize) -> bool, ) -> (String, String) { let mut log = vec![]; - let res = self.refactor_shared_expr(expr, fold_when, false, &mut log); + let res = self.refactor_shared_expr(expr, fold_when, true, &mut log); let log = log .iter() .map(|fresh_var| { @@ -207,10 +202,7 @@ impl PrettyPrinter { fn try_insert_fresh(&mut self, var: NodeRef, info: String) -> String { match self.symbols.get(&var) { - Some(binding) => { - println!("found {}", binding.clone()); - binding.clone() - } + Some(binding) => binding.clone(), None => { let fresh_var = &self.mk_fresh(info); self.symbols.insert(var, fresh_var.clone()); @@ -276,7 +268,6 @@ impl PrettyPrinter { log: &mut Vec, ) -> RcExpr { let old_expr_addr = Rc::as_ptr(expr); - // println!("{:?} has address {:?}", expr, old_expr_addr); let num_shared = Rc::strong_count(expr); let fold = |pp: &mut PrettyPrinter, new_expr: RcExpr, log: &mut Vec| { let binding = pp.try_insert_fresh(NodeRef::Expr(old_expr_addr), expr.abbrev()); @@ -284,8 +275,6 @@ impl PrettyPrinter { log.push(binding.clone()); pp.table .insert(binding.clone(), AstNode::Expr(new_expr.as_ref().clone())); - } else { - println!("found dup"); } Rc::new(Expr::Symbolic(binding)) }; @@ -306,35 +295,34 @@ impl PrettyPrinter { Rc::new(Expr::Symbolic(binding.to_owned())) } else { match expr.as_ref() { - Expr::Const(c, ty, assum) if !to_rust => { + Expr::Const(c, ty, assum) => { let ty = self.refactor_shared_type(ty, log); let assum = self.refactor_shared_assum(assum, fold_when, to_rust, log); let c = Rc::new(Expr::Const(c.clone(), ty, assum)); - fold(self, c, log) + if to_rust { + c + } else { + fold(self, c, log) + } } - Expr::Get(x, pos) if matches!(x.as_ref(), Expr::Arg(..)) && !to_rust => { - // fold Get Arg i anyway - let sub_expr = self.refactor_shared_expr(x, fold_when, to_rust, log); - let get = Rc::new(Expr::Get(sub_expr, *pos)); - return fold(self, get, log); + Expr::Get(x, pos) if matches!(x.as_ref(), Expr::Arg(..)) => { + if to_rust { + expr.clone() + } else { + let sub_expr = self.refactor_shared_expr(x, fold_when, to_rust, log); + let get = Rc::new(Expr::Get(sub_expr, *pos)); + fold(self, get, log) + } } Expr::Symbolic(_) => panic!("Expected non symbolic"), _ => { - // but you should not put it here as it will not fold ty and assum - // and type and assume must be folded to make sure it's correct let expr2 = expr.map_expr_type(|ty| self.refactor_shared_type(ty, log)); - //println!("{}",expr2.clone()); let expr3 = expr2.map_expr_assum(|assum| { self.refactor_shared_assum(assum, fold_when, to_rust, log) }); - //println!("{}",expr3.clone()); - // you should not call refactor_shard_expr on new expr let mapped_expr = expr3.map_expr_children(|e| { self.refactor_shared_expr(e, fold_when, to_rust, log) }); - //println!("{}",mapped_expr.clone()); - // in conclusion this three function must happen at the same time. - // but mut reference don't allow this to happen fold_or_plain(self, mapped_expr, log) } } @@ -698,41 +686,7 @@ fn test_pretty_print() -> crate::Result { let expr_str = concat_loop.to_string(); let (egglog, binding) = PrettyPrinter::default().to_egglog_default(&concat_loop); let (ast, _) = PrettyPrinter::default().to_rust_default(&concat_loop); - println!("{}", egglog.clone()); - assert_snapshot!(ast); - let check = format!("(let unfold {expr_str})\n {egglog} \n(check (= {binding} unfold))\n"); - egglog_test( - "", - &check, - vec![], - Value::Tuple(vec![]), - Value::Tuple(vec![]), - vec![], - ) -} - -#[test] -fn test_pretty_print2() -> crate::Result { - use crate::ast::*; - use crate::egglog_test; - use crate::Value; - use insta::assert_snapshot; - let output_ty = tuplet!(intt(), intt(), statet()); - let getat0 = getat(0).with_arg_types(output_ty.clone(), base(intt())); - let getat1 = getat(1).with_arg_types(output_ty.clone(), base(intt())); - let inv = sub(getat0.clone(), getat1.clone()); - let pred = ttrue(); - let my_loop = dowhile( - parallel!(int(1), int(2), getat(0)), - concat(parallel!(pred.clone(), getat0, inv), single(getat(2))), - ) - .with_arg_types(tuplet!(statet()), output_ty.clone()) - .add_ctx(schema::Assumption::dummy()); - let expr_str = my_loop.to_string(); - let (egglog, binding) = PrettyPrinter::default().to_egglog_default(&my_loop); - let (ast, _) = PrettyPrinter::default().to_rust_default(&my_loop); - println!("{}", egglog.clone()); assert_snapshot!(ast); let check = format!("(let unfold {expr_str})\n {egglog} \n(check (= {binding} unfold))\n"); diff --git a/dag_in_context/src/schema_helpers.rs b/dag_in_context/src/schema_helpers.rs index a88ff06f..54034619 100644 --- a/dag_in_context/src/schema_helpers.rs +++ b/dag_in_context/src/schema_helpers.rs @@ -303,7 +303,7 @@ impl Expr { Expr::Switch(cond, inputs, branches) => { let br = branches .iter() - .map(|branch| map_child(branch)) + .map(&mut map_child) .collect::>(); Rc::new(Expr::Switch(map_child(cond), map_child(inputs), br)) } diff --git a/dag_in_context/src/snapshots/dag_in_context__pretty_print__pretty_print.snap b/dag_in_context/src/snapshots/dag_in_context__pretty_print__pretty_print.snap index 524d0a05..0ecf9879 100644 --- a/dag_in_context/src/snapshots/dag_in_context__pretty_print__pretty_print.snap +++ b/dag_in_context/src/snapshots/dag_in_context__pretty_print__pretty_print.snap @@ -2,52 +2,48 @@ source: src/pretty_print.rs expression: ast --- -let in_func_v0 = infunc("dummy"); -let tpl_s_v1 = tuplet!(statet()); -let int1_v2 = int(1); -let int2_v3 = int(2); -let int3_v4 = int(3); -let int4_v5 = int(4); -let tpl_i_i_i_i_s_v6 = tuplet!(intt(), intt(), intt(), intt(), statet()); -let less_than_v7 = less_than(getat(0), +let tpl_s_v0 = tuplet!(statet()); +let in_func_v1 = infunc("dummy"); +let concat_v2 = parallel!(int(4), getat(0)); +let concat_v3 = concat(single(int(2)), +concat(single(int(3)), +concat_v2.clone())); +let concat_v4 = concat(single(int(1)), +concat_v3.clone()); +let less_than_v5 = less_than(getat(0), getat(3)); -let sub_v8 = sub(getat(2), +let single_v6 = single(getat(0)); +let single_v7 = single(getat(1)); +let single_v8 = single(getat(2)); +let single_v9 = single(getat(3)); +let sub_v10 = sub(getat(2), getat(1)); -let tprint_v9 = tprint(sub_v8.clone(), +let tprint_v11 = tprint(sub_v10.clone(), getat(4)); -let in_loop_v10 = inloop(parallel!(int1_v2.clone(), int2_v3.clone(), int3_v4.clone(), int4_v5.clone(), getat(0)), -parallel!(less_than_v7.clone(), getat(0), getat(1), getat(2), getat(3), tprint_v9.clone())); -let less_than_v11 = less_than(getat(0), -getat(3)); -let sub_v12 = sub(getat(2), -getat(1)); -let tprint_v13 = tprint(sub_v12.clone(), -getat(4)); -let dowhile_v14 = dowhile(parallel!(int1_v2.clone(), int2_v3.clone(), int3_v4.clone(), int4_v5.clone(), getat(0)), -parallel!(less_than_v11.clone(), getat(0), getat(1), getat(2), getat(3), tprint_v13.clone())); -let tpl__v15 = emptyt(); -let int1_v16 = int(1); -let tpl_i_v17 = tuplet!(intt()); -let int3_v18 = int(3); -let less_than_v19 = less_than(getat(0), -int3_v18.clone()); -let int0_v20 = int(0); -let int4_v21 = int(4); -let int5_v22 = int(5); -let switch_v23 = switch!(int0_v20.clone(), arg(); parallel!(int4_v21.clone(), int5_v22.clone())); -let in_loop_v24 = inloop(single(int1_v16.clone()), -parallel!(less_than_v19.clone(), get(switch_v23.clone(), 0))); -let int3_v25 = int(3); -let less_than_v26 = less_than(getat(0), -int3_v25.clone()); -let int0_v27 = int(0); -let in_switch_v28 = inswitch(0, -int0_v27.clone(), +let concat_v12 = concat(concat(single(less_than_v5.clone()), +concat(single_v6.clone(), +single_v7.clone())), +concat(concat(single_v8.clone(), +single_v9.clone()), +single(tprint_v11.clone()))); +let tpl__v13 = emptyt(); +let single_v14 = single(int(1)); +let tpl_i_v15 = tuplet!(intt()); +let single_v16 = single(less_than(getat(0), +int(3))); +let switch_v17 = switch!(int(0), arg(); parallel!(int(4), int(5))); +let concat_v18 = concat(single_v16.clone(), +single(get(switch_v17.clone(), 0))); +let in_loop_v19 = inloop(single_v14.clone(), +concat_v18.clone()); +let less_than_v20 = less_than(getat(0), +int(3)); +let in_switch_v21 = inswitch(0, +int(0), arg()); -let int4_v29 = int(4); -let int5_v30 = int(5); -let switch_v31 = switch!(int0_v27.clone(), arg(); parallel!(int4_v29.clone(), int5_v30.clone())); -let dowhile_v32 = dowhile(single(int1_v16.clone()), -parallel!(less_than_v26.clone(), get(switch_v31.clone(), 0))); -let concat_v33 = concat(dowhile_v14.clone(), -dowhile_v32.clone()); +let switch_v22 = switch!(int(0), arg(); parallel!(int(4), int(5))); +let concat_v23 = concat(dowhile(concat_v4.clone(), +concat_v12.clone()), +dowhile(single_v14.clone(), +parallel!(less_than_v20.clone(), get(switch_v22.clone(), 0)))); +let concat_v24 = concat_v23.clone();