diff --git a/dag_in_context/src/pretty_print.rs b/dag_in_context/src/pretty_print.rs index 700fe1d1..96cba0a9 100644 --- a/dag_in_context/src/pretty_print.rs +++ b/dag_in_context/src/pretty_print.rs @@ -662,7 +662,7 @@ fn test_pretty_print() -> crate::Result { let inv = sub(getat(2), getat(1)).with_arg_types(output_ty.clone(), base(intt())); let pred = less_than(getat(0), getat(3)).with_arg_types(output_ty.clone(), base(boolt())); let print = tprint(inv, getat(4)).with_arg_types(output_ty.clone(), base(statet())); - let my_loop = dowhile( + let (my_loop, _) = dowhile( parallel!(int(1), int(2), int(3), int(4), getat(0)), concat( parallel!(pred.clone(), getat(0), getat(1)), @@ -672,7 +672,7 @@ fn test_pretty_print() -> crate::Result { .with_arg_types(tuplet!(statet()), output_ty.clone()) .add_ctx(schema::Assumption::dummy()); - let pureloop = dowhile( + let (pureloop, _) = dowhile( single(int(1)), parallel!( less_than(get(arg(), 0), int(3)), diff --git a/dag_in_context/src/schema_helpers.rs b/dag_in_context/src/schema_helpers.rs index 54034619..47b74297 100644 --- a/dag_in_context/src/schema_helpers.rs +++ b/dag_in_context/src/schema_helpers.rs @@ -301,10 +301,7 @@ impl Expr { Expr::Single(x) => Rc::new(Expr::Single(map_child(x))), Expr::Concat(x, y) => Rc::new(Expr::Concat(map_child(x), map_child(y))), Expr::Switch(cond, inputs, branches) => { - let br = branches - .iter() - .map(&mut map_child) - .collect::>(); + let br = branches.iter().map(&mut map_child).collect::>(); Rc::new(Expr::Switch(map_child(cond), map_child(inputs), br)) } Expr::If(pred, input, then, els) => Rc::new(Expr::If( diff --git a/dag_in_context/src/snapshots/dag_in_context__pretty_print__pretty_print.snap b/dag_in_context/src/snapshots/dag_in_context__pretty_print__pretty_print.snap index 0ecf9879..36b31f6b 100644 --- a/dag_in_context/src/snapshots/dag_in_context__pretty_print__pretty_print.snap +++ b/dag_in_context/src/snapshots/dag_in_context__pretty_print__pretty_print.snap @@ -8,42 +8,35 @@ let concat_v2 = parallel!(int(4), getat(0)); let concat_v3 = concat(single(int(2)), concat(single(int(3)), concat_v2.clone())); -let concat_v4 = concat(single(int(1)), -concat_v3.clone()); -let less_than_v5 = less_than(getat(0), +let less_than_v4 = less_than(getat(0), getat(3)); -let single_v6 = single(getat(0)); -let single_v7 = single(getat(1)); -let single_v8 = single(getat(2)); -let single_v9 = single(getat(3)); -let sub_v10 = sub(getat(2), +let single_v5 = single(getat(0)); +let single_v6 = single(getat(1)); +let single_v7 = single(getat(2)); +let single_v8 = single(getat(3)); +let sub_v9 = sub(getat(2), getat(1)); -let tprint_v11 = tprint(sub_v10.clone(), +let tprint_v10 = tprint(sub_v9.clone(), getat(4)); -let concat_v12 = concat(concat(single(less_than_v5.clone()), -concat(single_v6.clone(), -single_v7.clone())), -concat(concat(single_v8.clone(), -single_v9.clone()), -single(tprint_v11.clone()))); -let tpl__v13 = emptyt(); -let single_v14 = single(int(1)); -let tpl_i_v15 = tuplet!(intt()); -let single_v16 = single(less_than(getat(0), -int(3))); -let switch_v17 = switch!(int(0), arg(); parallel!(int(4), int(5))); -let concat_v18 = concat(single_v16.clone(), -single(get(switch_v17.clone(), 0))); -let in_loop_v19 = inloop(single_v14.clone(), -concat_v18.clone()); -let less_than_v20 = less_than(getat(0), +let concat_v11 = concat(concat(single(less_than_v4.clone()), +concat(single_v5.clone(), +single_v6.clone())), +concat(concat(single_v7.clone(), +single_v8.clone()), +single(tprint_v10.clone()))); +let tpl__v12 = emptyt(); +let tpl_i_v13 = tuplet!(intt()); +let in_func_v14 = infunc(" loop_ctx_0"); +let less_than_v15 = less_than(getat(0), int(3)); -let in_switch_v21 = inswitch(0, +let in_switch_v16 = inswitch(0, int(0), arg()); -let switch_v22 = switch!(int(0), arg(); parallel!(int(4), int(5))); -let concat_v23 = concat(dowhile(concat_v4.clone(), -concat_v12.clone()), -dowhile(single_v14.clone(), -parallel!(less_than_v20.clone(), get(switch_v22.clone(), 0)))); -let concat_v24 = concat_v23.clone(); +let switch_v17 = switch!(int(0), arg(); parallel!(int(4), int(5))); +let dowhile_v18 = dowhile(single(int(1)), +parallel!(less_than_v15.clone(), get(switch_v17.clone(), 0))); +let concat_v19 = concat(dowhile(concat(single(int(1)), +concat_v3.clone()), +concat_v11.clone()), +dowhile_v18.clone()); +let concat_v20 = concat_v19.clone(); diff --git a/src/util.rs b/src/util.rs index 56877ab9..2ba94c1e 100644 --- a/src/util.rs +++ b/src/util.rs @@ -624,7 +624,7 @@ impl Run { } RunType::PrettyPrint => { let rvsdg = Optimizer::program_to_rvsdg(&self.prog_with_args.program)?; - let dag = rvsdg.to_dag_encoding(true); + let (dag, _) = rvsdg.to_dag_encoding(true); let res = TreeProgram::pretty_print_to_rust(&dag); ( vec![Visualization { @@ -637,8 +637,9 @@ impl Run { } RunType::OptimizedPrettyPrint => { let rvsdg = Optimizer::program_to_rvsdg(&self.prog_with_args.program)?; - let dag = rvsdg.to_dag_encoding(true); - let optimized = dag_in_context::optimize(&dag).map_err(EggCCError::EggLog)?; + let (prog, mut ctx_cache) = rvsdg.to_dag_encoding(true); + let optimized = + dag_in_context::optimize(&prog, &mut ctx_cache).map_err(EggCCError::EggLog)?; let res = TreeProgram::pretty_print_to_rust(&optimized); ( vec![Visualization { @@ -652,8 +653,8 @@ impl Run { RunType::TestPrettyPrint => { let rvsdg = crate::Optimizer::program_to_rvsdg(&self.prog_with_args.program).unwrap(); - let tree = rvsdg.to_dag_encoding(true); - let unfolded_program = build_program(&tree, false); + let (tree, mut cache) = rvsdg.to_dag_encoding(true); + let unfolded_program = build_program(&tree, &mut cache, false); let folded_program = tree.pretty_print_to_egglog(); let program = format!("{unfolded_program} \n {folded_program} \n (check (= PROG_PP PROG))");