From 53a35fd92d72c1bee440b6fadf4823b23eb51e8e Mon Sep 17 00:00:00 2001 From: zhiyuan yan Date: Fri, 17 May 2024 00:38:55 -0700 Subject: [PATCH] one snapshot left??? --- dag_in_context/src/pretty_print.rs | 313 +++++++++++++++++------------ src/test_util.rs | 25 --- src/util.rs | 39 ++-- 3 files changed, 210 insertions(+), 167 deletions(-) diff --git a/dag_in_context/src/pretty_print.rs b/dag_in_context/src/pretty_print.rs index 770efc250..3779e2773 100644 --- a/dag_in_context/src/pretty_print.rs +++ b/dag_in_context/src/pretty_print.rs @@ -1,22 +1,24 @@ use crate::{ from_egglog::FromEgglog, prologue, - schema::{self, Assumption, BaseType, BinaryOp, Expr, RcExpr, TernaryOp, Type, UnaryOp}, + schema::{ + self, Assumption, BaseType, BinaryOp, Expr, RcExpr, TernaryOp, TreeProgram, Type, UnaryOp, + }, schema_helpers::AssumptionRef, to_egglog::TreeToEgglog, }; use egglog::{Term, TermDag}; -use std::{ - collections::{BTreeMap, HashMap}, - hash::Hash, - rc::Rc, -}; +use std::{collections::HashMap, hash::Hash, rc::Rc, vec}; pub struct PrettyPrinter { pub expr: RcExpr, // Type/Assum/BaseType -> intermediate variables symbols: indexmap::IndexMap, + // intermediate variables about to print + log: Vec, + // intermediate variable -> Type/Assum/BaseType lookup + table: std::collections::BTreeMap, } #[derive(PartialEq, Eq, Hash)] @@ -61,6 +63,55 @@ impl SymbolLog { } } +impl TreeProgram { + pub fn pretty_print_to_egglog(&self) -> String { + let main_binding = "fun_main_".to_string(); + let mut pp = PrettyPrinter::from_expr(self.entry.to_owned()); + let main = pp.to_egglog_default(main_binding.clone()); + let mut function_bindings = vec![]; + let functions = self + .functions + .clone() + .into_iter() + .map(|expr| match expr.as_ref() { + schema::Expr::Function(name, ..) => { + let binding = format!("fun_{name}_"); + function_bindings.push(binding.clone()); + pp.new_expr(expr.clone()); + pp.to_egglog_default(binding) + } + _ => panic!("not function at top level"), + }) + .collect::>() + .join("\n\n"); + let function_list = function_bindings + .into_iter() + .rev() + .fold("(Nil)".to_string(), |acc, binding| { + format!("(Cons {binding} {acc})") + }); + format!("{main}\n {functions} \n (let PROG_PP (Program {main_binding} {function_list}))") + } + + pub fn pretty_print_to_rust(&self) -> String { + std::iter::once( + PrettyPrinter::from_expr(self.entry.to_owned()).to_rust_default("fun_main_".into()), + ) + .chain( + self.functions + .clone() + .into_iter() + .map(|expr| match expr.as_ref() { + schema::Expr::Function(name, ..) => PrettyPrinter::from_expr(expr.clone()) + .to_rust_default(format!("fun_{name}_")), + _ => panic!("not function at top level"), + }), + ) + .collect::>() + .join("\n\n") + } +} + impl PrettyPrinter { pub fn from_string(str_expr: String) -> std::result::Result { let bounded_expr = format!("(let EXPR___ {})", str_expr); @@ -83,53 +134,68 @@ impl PrettyPrinter { PrettyPrinter { expr, symbols: indexmap::IndexMap::new(), + log: vec![], + table: std::collections::BTreeMap::new(), } } - pub fn to_egglog_default(&mut self) -> String { - self.to_egglog(&|rc, len| (rc > 1 && len > 30) || len > 80) + // accept new expr and preserve symbols, but clear the log + pub fn new_expr(&mut self, expr: RcExpr) { + self.expr = expr; + self.log = vec![]; + } + + pub fn to_egglog_default(&mut self, binding: String) -> String { + self.to_egglog(&|rc, len| (rc > 1 && len > 30) || len > 80, binding) } // turn the Expr to a nested egglog with intermediate variables. // fold_when: provide a function that decide when to fold the egglog expression to a let binding - pub fn to_egglog(&mut self, fold_when: &dyn Fn(usize, usize) -> bool) -> String { - let mut log = Vec::new(); - let mut table = std::collections::BTreeMap::::new(); + pub fn to_egglog( + &mut self, + fold_when: &dyn Fn(usize, usize) -> bool, + binding: String, + ) -> String { + // let mut log = Vec::new(); + // let mut table = std::collections::BTreeMap::::new(); self.assign_fresh_var(&self.expr.clone()); - let res = self.fold_expr(&self.expr.clone(), &mut log, &mut table, fold_when, false); - let log = log + let res = self.fold_expr(&self.expr.clone(), fold_when, false); + let log = self + .log .iter() .map(|fresh_var| { - let symbol = table.get(fresh_var).unwrap(); + let symbol = self.table.get(fresh_var).unwrap(); let pretty = symbol.symbol_log_to_str(false); format!("(let {fresh_var} \n{pretty})\n") }) .collect::>() .join(""); - log + &format!("\n(let EXPR___\n{})\n", res.pretty()) + log + &format!("\n(let {binding} \n{})\n", res.pretty()) } - pub fn to_rust_default(&mut self) -> String { - self.to_rust(&|rc, len| (rc > 1 && len > 30) || rc > 4 || len > 80) + pub fn to_rust_default(&mut self, binding: String) -> String { + self.to_rust( + &|rc, len| (rc > 1 && len > 30) || rc > 4 || len > 80, + binding, + ) } // turn the Expr to a rust ast macro string. // fold_when: provide a function that decide when to fold the macro to a let binding - pub fn to_rust(&mut self, fold_when: &dyn Fn(usize, usize) -> bool) -> String { - let mut log = Vec::new(); - let mut table = std::collections::BTreeMap::::new(); + pub fn to_rust(&mut self, fold_when: &dyn Fn(usize, usize) -> bool, binding: String) -> String { self.assign_fresh_var(&self.expr.clone()); - let res = self.fold_expr(&self.expr.clone(), &mut log, &mut table, fold_when, false); - let log = log + let res = self.fold_expr(&self.expr.clone(), fold_when, false); + let log = self + .log .iter() .map(|fresh_var| { - let symbol = table.get(fresh_var).unwrap(); + let symbol = self.table.get(fresh_var).unwrap(); let ast = symbol.symbol_log_to_str(true); format!("let {fresh_var} = {ast};") }) .collect::>() .join("\n"); - log + &format!("\nlet expr___ = {};\n", res.to_ast()) + log + &format!("\nlet {binding} = {};\n", res.to_ast()) } fn assign_fresh_var(&mut self, expr: &RcExpr) { @@ -154,7 +220,7 @@ impl PrettyPrinter { fn try_insert_fresh(var: Symbols, info: String, pp: &mut PrettyPrinter) { if !pp.symbols.contains_key(&var) { - let fresh_var = format!("{info}_{}", pp.symbols.len()); + let fresh_var = format!("{info}_v{}", pp.symbols.len()); pp.symbols.insert(var, fresh_var.clone()); } } @@ -174,6 +240,9 @@ impl PrettyPrinter { let c = match c { schema::Constant::Int(i) => format!("int{i}"), schema::Constant::Bool(b) => format!("bool{b}"), + schema::Constant::Float(f) => { + format!("float{}", std::ptr::addr_of!(f) as i64) + } }; try_insert_fresh(expr_symbol, c, self); } @@ -263,35 +332,27 @@ impl PrettyPrinter { fn fold_expr( &mut self, expr: &RcExpr, - log: &mut Vec, - table: &mut BTreeMap, fold_when: &dyn Fn(usize, usize) -> bool, to_rust: bool, ) -> Expr { let old_expr_addr = Rc::as_ptr(expr); - let fold = |pp: &mut PrettyPrinter, - new_expr: schema::Expr, - log: &mut Vec, - table: &mut BTreeMap| { + let fold = |pp: &mut PrettyPrinter, new_expr: schema::Expr| { let fresh_var = pp.symbols.get(&Symbols::Expr(old_expr_addr)).unwrap(); - if !table.contains_key(fresh_var) { - log.push(fresh_var.into()); - table.insert(fresh_var.into(), SymbolLog::Expr(new_expr)); + if !pp.table.contains_key(fresh_var) { + pp.log.push(fresh_var.into()); + pp.table.insert(fresh_var.into(), SymbolLog::Expr(new_expr)); } Expr::Symbolic(fresh_var.into()) }; let num_shared = Rc::strong_count(expr); - let fold_or_plain = |pp: &mut PrettyPrinter, - new_expr: Expr, - log: &mut Vec, - table: &mut BTreeMap| { + let fold_or_plain = |pp: &mut PrettyPrinter, new_expr: Expr| { let size = &new_expr .to_string() .replace(&['(', ')', ' ', ','][..], "") //don't count those char when computing size .len(); if fold_when(num_shared, *size) { - fold(pp, new_expr, log, table) + fold(pp, new_expr) } else { new_expr } @@ -300,8 +361,6 @@ impl PrettyPrinter { fn handle_assum( assum: &Assumption, pp: &mut PrettyPrinter, - log: &mut Vec, - table: &mut BTreeMap, fold_when: &dyn Fn(usize, usize) -> bool, to_rust: bool, ) -> String { @@ -310,156 +369,148 @@ impl PrettyPrinter { .get(&Symbols::Assumption(assum.to_ref())) .unwrap() .clone(); - if !table.contains_key(&old_assume_binding) { + if !pp.table.contains_key(&old_assume_binding) { let new_assum = match assum { Assumption::InFunc(_) => assum.clone(), Assumption::InIf(cond, left, right) => { - let left = pp.fold_expr(left, log, table, fold_when, to_rust); - let right = pp.fold_expr(right, log, table, fold_when, to_rust); + let left = pp.fold_expr(left, fold_when, to_rust); + let right = pp.fold_expr(right, fold_when, to_rust); Assumption::InIf(*cond, Rc::new(left), Rc::new(right)) } Assumption::InLoop(inputs, body) => { - let inputs = pp.fold_expr(inputs, log, table, fold_when, to_rust); - let body = pp.fold_expr(body, log, table, fold_when, to_rust); + let inputs = pp.fold_expr(inputs, fold_when, to_rust); + let body = pp.fold_expr(body, fold_when, to_rust); Assumption::InLoop(Rc::new(inputs), Rc::new(body)) } Assumption::InSwitch(cond, inputs, branch) => { - let inputs = pp.fold_expr(inputs, log, table, fold_when, to_rust); - let branch = pp.fold_expr(branch, log, table, fold_when, to_rust); + let inputs = pp.fold_expr(inputs, fold_when, to_rust); + let branch = pp.fold_expr(branch, fold_when, to_rust); Assumption::InSwitch(*cond, Rc::new(inputs), Rc::new(branch)) } Assumption::WildCard(_) => assum.clone(), }; - log.push(old_assume_binding.clone()); - table.insert(old_assume_binding.clone(), SymbolLog::Assumption(new_assum)); + pp.log.push(old_assume_binding.clone()); + pp.table + .insert(old_assume_binding.clone(), SymbolLog::Assumption(new_assum)); } old_assume_binding } - fn handle_type( - ty: &Type, - pp: &mut PrettyPrinter, - log: &mut Vec, - table: &mut BTreeMap, - ) -> Type { + fn handle_type(ty: &Type, pp: &mut PrettyPrinter) -> Type { let ty_str = pp.symbols.get(&Symbols::Type(ty.clone())).unwrap().clone(); - if !table.contains_key(&ty_str) { - log.push(ty_str.clone()); - table.insert(ty_str.clone(), SymbolLog::Type(ty.clone())); + if !pp.table.contains_key(&ty_str) { + pp.log.push(ty_str.clone()); + pp.table.insert(ty_str.clone(), SymbolLog::Type(ty.clone())); } Type::Symbolic(ty_str) } match expr.as_ref() { Expr::Function(name, inty, outty, body) => { - let inty = handle_type(inty, self, log, table); - let outty = handle_type(outty, self, log, table); - - let body = self.fold_expr(body, log, table, fold_when, to_rust); + let inty = handle_type(inty, self); + let outty = handle_type(outty, self); + let body = self.fold_expr(body, fold_when, to_rust); Expr::Function(name.into(), inty, outty, Rc::new(body)) } Expr::Const(c, ty, assum) => { - let ty = handle_type(ty, self, log, table); - - let old_assum_binding = handle_assum(assum, self, log, table, fold_when, to_rust); + let ty = handle_type(ty, self); + let old_assum_binding = handle_assum(assum, self, fold_when, to_rust); let c = Expr::Const(c.clone(), ty, Assumption::WildCard(old_assum_binding)); if to_rust { c } else { - fold(self, c, log, table) + fold(self, c) } } Expr::Top(op, x, y, z) => { - let left = self.fold_expr(x, log, table, fold_when, to_rust); - let mid = self.fold_expr(y, log, table, fold_when, to_rust); - let right = self.fold_expr(z, log, table, fold_when, to_rust); + let left = self.fold_expr(x, fold_when, to_rust); + let mid = self.fold_expr(y, fold_when, to_rust); + let right = self.fold_expr(z, fold_when, to_rust); let top = Expr::Top(op.clone(), Rc::new(left), Rc::new(mid), Rc::new(right)); - fold_or_plain(self, top, log, table) + fold_or_plain(self, top) } Expr::Bop(op, x, y) => { - let left = self.fold_expr(x, log, table, fold_when, to_rust); - let right = self.fold_expr(y, log, table, fold_when, to_rust); + let left = self.fold_expr(x, fold_when, to_rust); + let right = self.fold_expr(y, fold_when, to_rust); let bop = Expr::Bop(op.clone(), Rc::new(left), Rc::new(right)); - fold_or_plain(self, bop, log, table) + fold_or_plain(self, bop) } Expr::Uop(op, x) => { - let sub_expr = self.fold_expr(x, log, table, fold_when, to_rust); + let sub_expr = self.fold_expr(x, fold_when, to_rust); let uop = Expr::Uop(op.clone(), Rc::new(sub_expr)); - fold_or_plain(self, uop, log, table) + fold_or_plain(self, uop) } Expr::Get(x, pos) => { - let sub_expr = self.fold_expr(x, log, table, fold_when, to_rust); + let sub_expr = self.fold_expr(x, fold_when, to_rust); let get = Expr::Get(Rc::new(sub_expr), *pos); // fold Get Arg i anyway if let Expr::Arg(_, _) = x.as_ref() { if !to_rust { - return fold(self, get, log, table); + return fold(self, get); } } get } Expr::Alloc(id, x, y, ty) => { - let amount = self.fold_expr(x, log, table, fold_when, to_rust); - let state_edge = self.fold_expr(y, log, table, fold_when, to_rust); + let amount = self.fold_expr(x, fold_when, to_rust); + let state_edge = self.fold_expr(y, fold_when, to_rust); let alloc = Expr::Alloc(*id, Rc::new(amount), Rc::new(state_edge), ty.clone()); - fold_or_plain(self, alloc, log, table) + fold_or_plain(self, alloc) } Expr::Call(name, x) => { - let sub_expr = self.fold_expr(x, log, table, fold_when, to_rust); + let sub_expr = self.fold_expr(x, fold_when, to_rust); let call = Expr::Call(name.into(), Rc::new(sub_expr)); - fold_or_plain(self, call, log, table) + fold_or_plain(self, call) } Expr::Empty(ty, assum) => { - let ty = handle_type(ty, self, log, table); - let assum_str = handle_assum(assum, self, log, table, fold_when, to_rust); - + let ty = handle_type(ty, self); + let assum_str = handle_assum(assum, self, fold_when, to_rust); Expr::Empty(ty, Assumption::WildCard(assum_str)) } // doesn't fold Tuple Expr::Single(x) => { - let sub_expr = self.fold_expr(x, log, table, fold_when, to_rust); + let sub_expr = self.fold_expr(x, fold_when, to_rust); Expr::Single(Rc::new(sub_expr)) } Expr::Concat(x, y) => { - let left = self.fold_expr(x, log, table, fold_when, to_rust); - let right = self.fold_expr(y, log, table, fold_when, to_rust); + let left = self.fold_expr(x, fold_when, to_rust); + let right = self.fold_expr(y, fold_when, to_rust); Expr::Concat(Rc::new(left), Rc::new(right)) } Expr::Switch(x, inputs, _branches) => { - let cond = self.fold_expr(x, log, table, fold_when, to_rust); - let inputs = self.fold_expr(inputs, log, table, fold_when, to_rust); + let cond = self.fold_expr(x, fold_when, to_rust); + let inputs = self.fold_expr(inputs, fold_when, to_rust); let branches = _branches .iter() - .map(|branch| Rc::new(self.fold_expr(branch, log, table, fold_when, to_rust))) + .map(|branch| Rc::new(self.fold_expr(branch, fold_when, to_rust))) .collect::>(); let switch = Expr::Switch(Rc::new(cond), Rc::new(inputs), branches); - fold_or_plain(self, switch, log, table) + fold_or_plain(self, switch) } Expr::If(x, inputs, y, z) => { - let pred = self.fold_expr(x, log, table, fold_when, to_rust); - let inputs = self.fold_expr(inputs, log, table, fold_when, to_rust); - let left = self.fold_expr(y, log, table, fold_when, to_rust); - let right = self.fold_expr(z, log, table, fold_when, to_rust); + let pred = self.fold_expr(x, fold_when, to_rust); + let inputs = self.fold_expr(inputs, fold_when, to_rust); + let left = self.fold_expr(y, fold_when, to_rust); + let right = self.fold_expr(z, fold_when, to_rust); let if_expr = Expr::If( Rc::new(pred), Rc::new(inputs), Rc::new(left), Rc::new(right), ); - fold_or_plain(self, if_expr, log, table) + fold_or_plain(self, if_expr) } Expr::DoWhile(inputs, body) => { - let inputs = self.fold_expr(inputs, log, table, fold_when, to_rust); - let body = self.fold_expr(body, log, table, fold_when, to_rust); + let inputs = self.fold_expr(inputs, fold_when, to_rust); + let body = self.fold_expr(body, fold_when, to_rust); let dowhile = Expr::DoWhile(Rc::new(inputs), Rc::new(body)); - fold_or_plain(self, dowhile, log, table) + fold_or_plain(self, dowhile) } Expr::Arg(ty, assum) => { - let ty = handle_type(ty, self, log, table); - let assum_str = handle_assum(assum, self, log, table, fold_when, to_rust); - + let ty = handle_type(ty, self); + let assum_str = handle_assum(assum, self, fold_when, to_rust); Expr::Arg(ty, Assumption::WildCard(assum_str)) } Expr::Symbolic(_) => panic!("No symbolic should occur here"), @@ -490,11 +541,13 @@ impl Expr { } } pub fn to_ast(&self) -> String { + use schema::Constant::*; match self { Expr::Const(c, ..) => match c { - schema::Constant::Bool(true) => "ttrue()".into(), - schema::Constant::Bool(false) => "tfalse()".into(), - schema::Constant::Int(n) => format!("int({})", n), + Bool(true) => "ttrue()".into(), + Bool(false) => "tfalse()".into(), + Int(n) => format!("int({})", n), + Float(f) => format!("float({})", f), }, Expr::Top(op, x, y, z) => { let left = x.to_ast(); @@ -583,7 +636,7 @@ impl Assumption { pub fn to_ast(&self) -> String { match self { - Assumption::InFunc(fun_name) => format!("infunc(\"{fun_name}\".into())"), + Assumption::InFunc(fun_name) => format!("infunc(\"{fun_name}\")"), Assumption::InIf(is, pred, input) => { format!("inif({is}, {}, {})", pred.to_ast(), input.to_ast()) } @@ -627,6 +680,7 @@ impl BaseType { BaseType::BoolT => "boolt()".into(), BaseType::StateT => "statet()".into(), BaseType::PointerT(ptr) => format!("pointert({})", BaseType::to_ast(ptr)), + BaseType::FloatT => "floatt()".into(), } } @@ -636,6 +690,7 @@ impl BaseType { BaseType::BoolT => "b".into(), BaseType::StateT => "s".into(), BaseType::PointerT(ptr) => format!("ptr{}", &ptr.abbrev()), + BaseType::FloatT => "f".into(), } } } @@ -687,6 +742,7 @@ impl Type { impl BinaryOp { pub fn to_ast(&self) -> String { use schema::BinaryOp::*; + // the same as schema_helper's match self { Add => "add", Sub => "sub", @@ -703,6 +759,15 @@ impl BinaryOp { Load => "load", Print => "tprint", Free => "free", + FGreaterEq => "f_greater_eq", + FGreaterThan => "f_greater_than", + FLessEq => "f_less_eq", + FLessThan => "f_less_than", + FAdd => "f_add", + FSub => "f_sub", + FDiv => "f_div", + FMul => "f_mul", + FEq => "f_eq", } .into() } @@ -710,10 +775,11 @@ impl BinaryOp { impl TernaryOp { pub fn to_ast(&self) -> String { - use schema::TernaryOp::Write; match self { - Write => "twrite".into(), + Self::Write => "twrite", + Self::Select => "select", } + .into() } } @@ -727,8 +793,10 @@ impl UnaryOp { } #[test] -fn test_pretty_print() { +fn test_pretty_print() -> crate::Result { use crate::ast::*; + use crate::egglog_test; + use crate::Value; let output_ty = tuplet!(intt(), intt(), intt(), intt(), statet()); let inner_inv = sub(getat(2), getat(1)).with_arg_types(output_ty.clone(), base(intt())); let inv = add(inner_inv.clone(), int(0)).with_arg_types(output_ty.clone(), base(intt())); @@ -751,16 +819,15 @@ fn test_pretty_print() { let expr_str = my_loop.to_string(); let res = PrettyPrinter::from_string(expr_str.clone()) .unwrap() - .to_rust_default(); - println!("{res}"); + .to_egglog_default("EXPR_".into()); + + let check = format!("(let unfold {expr_str})\n {res} \n(check (= EXPR_ unfold))\n"); + egglog_test( + "", + &check, + vec![], + Value::Tuple(vec![]), + Value::Tuple(vec![]), + vec![], + ) } - -// #[test] -// fn test_pretty() { -// use crate::ast::*; -// let expr = "(Function \"main\" (TupleT (TCons (StateT) (TNil))) (TupleT (TCons (StateT) (TNil))) (Single (Bop (Print) (Const (Int 1) (TupleT (TCons (StateT) (TNil))) (InFunc \"main\")) (Get (DoWhile (Single (Get (Arg (TupleT (TCons (StateT) (TNil))) (InFunc \"main\")) 0)) (Concat (Single (Bop (LessThan) (Const (Int 3) (TupleT (TCons (StateT) (TNil))) (InLoop (Single (Get (Arg (TupleT (TCons (StateT) (TNil))) (InFunc \"main\")) 0)) (Concat (Single (Bop (LessThan) (Const (Int 3) (TupleT (TCons (StateT) (TNil))) (InFunc \"dummy\")) (Const (Int 1) (TupleT (TCons (StateT) (TNil))) (InFunc \"dummy\")))) (Single (Get (Arg (TupleT (TCons (StateT) (TNil))) (InFunc \"dummy\")) 0))))) (Const (Int 1) (TupleT (TCons (StateT) (TNil))) (InLoop (Single (Get (Arg (TupleT (TCons (StateT) (TNil))) (InFunc \"main\")) 0)) (Concat (Single (Bop (LessThan) (Const (Int 3) (TupleT (TCons (StateT) (TNil))) (InFunc \"dummy\")) (Const (Int 1) (TupleT (TCons (StateT) (TNil))) (InFunc \"dummy\")))) (Single (Get (Arg (TupleT (TCons (StateT) (TNil))) (InFunc \"dummy\")) 0))))))) (Single (Get (Arg (TupleT (TCons (StateT) (TNil))) (InLoop (Single (Get (Arg (TupleT (TCons (StateT) (TNil))) (InFunc \"main\")) 0)) (Concat (Single (Bop (LessThan) (Const (Int 3) (TupleT (TCons (StateT) (TNil))) (InFunc \"dummy\")) (Const (Int 1) (TupleT (TCons (StateT) (TNil))) (InFunc \"dummy\")))) (Single (Get (Arg (TupleT (TCons (StateT) (TNil))) (InFunc \"dummy\")) 0))))) 0)))) 0))))"; - -// let pp = PrettyPrinter::from_string(expr.into()).unwrap().to_egglog_default(); - -// println!("{pp}"); -// } diff --git a/src/test_util.rs b/src/test_util.rs index 3545d70e0..b12eecc3e 100644 --- a/src/test_util.rs +++ b/src/test_util.rs @@ -10,10 +10,7 @@ macro_rules! to_block { }; } -use std::fs; - use bril_rs::Type; -use dag_in_context::print_with_intermediate_vars; pub(crate) use to_block; macro_rules! rvsdg_svg_test { @@ -82,28 +79,6 @@ macro_rules! cfg_test_equiv { }; } -#[test] -fn test_pretty_print_to_egglog() { - let schema = fs::read_to_string("./dag_in_context/src/schema.egg").unwrap(); - let paths = std::fs::read_dir("./tests/passing/small").unwrap(); - for path in paths { - let path = path.unwrap(); - let program = crate::util::TestProgram::BrilFile(path.path()).read_program(); - let rvsdg = crate::Optimizer::program_to_rvsdg(&program.program).unwrap(); - let dag = rvsdg.to_dag_encoding(true); - let (term, termdag) = dag.entry.to_egglog(); - let unfolded_program = print_with_intermediate_vars(&termdag, term); - let folded_program = - dag_in_context::pretty_print::PrettyPrinter::from_expr(dag.entry).to_egglog_default(); - let program = format!( - "{schema}\n {unfolded_program} \n {folded_program} \n (check (= PROG EXPR___))" - ); - egglog::EGraph::default() - .parse_and_run_program(&program) - .unwrap(); - } -} - pub(crate) use cfg_test_equiv; use crate::cfg::BranchOp; diff --git a/src/util.rs b/src/util.rs index d08952c81..709d9d098 100644 --- a/src/util.rs +++ b/src/util.rs @@ -4,7 +4,6 @@ use crate::{EggCCError, Optimizer}; use bril_rs::Program; use clap::ValueEnum; use dag_in_context::dag2svg::tree_to_svg; -use dag_in_context::pretty_print::PrettyPrinter; use dag_in_context::{build_program, check_roundtrip_egraph}; use dag_in_context::schema::TreeProgram; @@ -209,6 +208,8 @@ pub enum RunType { /// The different configurations are with and without egglog optimization, and with and without /// llvm optimization. TestBenchmark, + // test the pretty printer + TestPrettyPrint, } impl Display for RunType { @@ -243,6 +244,7 @@ impl RunType { | RunType::PrettyPrint | RunType::ToCfg | RunType::OptimizedCfg + | RunType::TestPrettyPrint | RunType::TestBenchmark => false, RunType::BrilToJson => false, } @@ -419,6 +421,7 @@ impl Run { RunType::DagRoundTrip, RunType::Optimize, RunType::CheckExtractIdentical, + RunType::TestPrettyPrint, ] { let default = Run { test_type, @@ -636,14 +639,7 @@ impl Run { RunType::PrettyPrint => { let rvsdg = Optimizer::program_to_rvsdg(&self.prog_with_args.program)?; let dag = rvsdg.to_dag_encoding(true); - let res = std::iter::once(PrettyPrinter::from_expr(dag.entry).to_rust_default()) - .chain( - dag.functions - .into_iter() - .map(|expr| PrettyPrinter::from_expr(expr).to_rust_default()), - ) - .collect::>() - .join("\n\n"); + let res = TreeProgram::pretty_print_to_rust(&dag); ( vec![Visualization { result: res, @@ -657,16 +653,7 @@ impl Run { let rvsdg = Optimizer::program_to_rvsdg(&self.prog_with_args.program)?; let dag = rvsdg.to_dag_encoding(true); let optimized = dag_in_context::optimize(&dag).map_err(EggCCError::EggLog)?; - let res = - std::iter::once(PrettyPrinter::from_expr(optimized.entry).to_rust_default()) - .chain( - optimized - .functions - .into_iter() - .map(|expr| PrettyPrinter::from_expr(expr).to_rust_default()), - ) - .collect::>() - .join("\n\n"); + let res = TreeProgram::pretty_print_to_rust(&optimized); ( vec![Visualization { result: res, @@ -676,6 +663,20 @@ impl Run { None, ) } + RunType::TestPrettyPrint => { + let rvsdg = + crate::Optimizer::program_to_rvsdg(&self.prog_with_args.program).unwrap(); + let tree = rvsdg.to_dag_encoding(true); + let unfolded_program = build_program(&tree, false); + let folded_program = tree.pretty_print_to_egglog(); + let program = + format!("{unfolded_program} \n {folded_program} \n (check (= PROG_PP PROG))"); + //println!("{}", program); + egglog::EGraph::default() + .parse_and_run_program(&program) + .unwrap(); + (vec![], None) + } RunType::DagConversion => { let rvsdg = Optimizer::program_to_rvsdg(&self.prog_with_args.program)?; let tree = rvsdg.to_dag_encoding(true);