From 93ceed72664d72d3d2dc540f60943b60f1e872cc Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Tue, 22 Mar 2022 16:59:12 -0700 Subject: [PATCH] Add reveal_at_least, a more clever form of reveal (#1167) Depth determines which indices get expanded, but all references to the same index get expanded if they appear in the output. This is because Joel's latest examples require a rewriting pass that needs either uneven revealing or the ability to check expression equality modulo the dag in the middle of rewriting. https://github.com/mit-plv/fiat-crypto/pull/1134#issuecomment-1075320548 Roughly the issue is that if we want to turn `a + 4 * a` into `5 * a`, we need to reveal enough structure to see `4 * a`, but we need to see that the two instances of `a` are the same (e.g., if `a` is an ExprRef pointing to `b|c`, and we reveal uniformly, then we need to recognize `b|c + 4 * a`) --- src/Assembly/Symbolic.v | 82 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 80 insertions(+), 2 deletions(-) diff --git a/src/Assembly/Symbolic.v b/src/Assembly/Symbolic.v index 0e388bc2cd..dc7e819077 100644 --- a/src/Assembly/Symbolic.v +++ b/src/Assembly/Symbolic.v @@ -24,6 +24,7 @@ Require Import Crypto.Util.ListUtil.Filter. Require Import Crypto.Util.ListUtil.PermutationCompat. Import ListUtil.PermutationCompat.Coq.Sorting.Permutation. Require Import Crypto.Util.NUtil.Sorting. Require Import Crypto.Util.NUtil.Testbit. +Require Import Crypto.Util.MSetN. Require Import Crypto.Util.ListUtil.PermutationCompat. Require Import Crypto.Util.Bool.LeCompat. Require Import Crypto.Util.Tactics.DestructHead. @@ -238,6 +239,54 @@ Section WithDag. Definition reveal_node n '(op, args) := ExprApp (op, List.map (reveal n) args). + (** given a set of indices, get the set of indices of their arguments *) + Definition reveal_gather_deps_args (ls : NSet.t) : NSet.t + := fold_right + (fun i so_far => match List.nth_error dag (N.to_nat i) with + | None => so_far + | Some (_op, args) => fold_right NSet.add so_far args + end) + NSet.empty + (NSet.elements ls). + + (** given a set of seen indices and a set of newly-revealed indices, + we want to merge the new indices into what's been seen and recurse + on the new indices *) + Definition reveal_gather_deps_step reveal_gather_deps (so_far : NSet.t) (new_idxs : NSet.t) : NSet.t + := let new_idxs := NSet.diff new_idxs so_far in + if NSet.is_empty new_idxs + then so_far + else reveal_gather_deps (NSet.union so_far new_idxs) (reveal_gather_deps_args new_idxs). + + Fixpoint reveal_gather_deps_list (n : nat) (so_far : NSet.t) (new_idxs : NSet.t) : NSet.t + := match n with + | O => NSet.union so_far new_idxs + | S n => reveal_gather_deps_step (reveal_gather_deps_list n) so_far new_idxs + end. + + Definition reveal_gather_deps (n : nat) (i : idx) : NSet.t + := reveal_gather_deps_list n NSet.empty (NSet.singleton i). + + Definition reveal_step_from_deps reveal (deps : NSet.t) (i : idx) : expr + := if NSet.mem i deps + then match List.nth_error dag (N.to_nat i) with + | None => (* undefined *) ExprRef i + | Some (op, args) => ExprApp (op, List.map reveal args) + end + else ExprRef i. + Fixpoint reveal_from_deps_fueled (fuel : nat) (deps : NSet.t) (i : idx) := + match fuel with + | O => ExprRef i + | S fuel => reveal_step_from_deps (reveal_from_deps_fueled fuel deps) deps i + end. + (** depth determines which indices get expanded, but all references + to the same index get expanded if they appear in the output *) + Definition reveal_at_least n (i : idx) : expr + := reveal_from_deps_fueled (S (List.length dag)) (reveal_gather_deps n i) i. + + Definition reveal_node_at_least n '(op, args) := + ExprApp (op, List.map (reveal_at_least n) args). + Local Unset Elimination Schemes. Inductive eval : expr -> Z -> Prop := | ERef i op args args' n @@ -316,6 +365,35 @@ Section WithDag. eapply Forall2_weaken; try eassumption; []; cbv beta; intros. eapply eval_reveal; eauto. Qed. + + Lemma eval_reveal_from_deps_fueled deps : forall n i, forall v, eval (ExprRef i) v -> + forall e, reveal_from_deps_fueled n deps i = e -> eval e v. + Proof using Type. + induction n; cbn [reveal_from_deps_fueled]; cbv [reveal_step_from_deps]; intros; subst; eauto; []. + break_innermost_match_step; eauto; []. + inversion H; subst; clear H. + rewrite H1; econstructor; try eassumption; []. + eapply (proj1 (Forall2_map_l _ _ _)) in H2. + clear dependent i; clear dependent v. + induction H2; cbn; eauto. + Qed. + + Lemma eval_reveal_at_least : forall n i, forall v, eval (ExprRef i) v -> + forall e, reveal_at_least n i = e -> eval e v. + Proof using Type. + cbv [reveal_at_least]. + intros; eapply eval_reveal_from_deps_fueled; eassumption. + Qed. + + Lemma eval_node_reveal_node_at_least : forall n v, eval_node n v -> + forall f e, reveal_node_at_least f n = e -> eval e v. + Proof using Type. + cbv [reveal_node]; inversion 1; intros; subst. + econstructor; eauto. + eapply (proj1 (Forall2_map_l _ _ _)) in H0; eapply Forall2_map_l. + eapply Forall2_weaken; try eassumption; []; cbv beta; intros. + eapply eval_reveal_at_least; eauto. + Qed. End WithDag. Module dag. @@ -1950,10 +2028,10 @@ Qed. End Rewrite. Definition simplify (dag : dag) (e : node idx) := - Rewrite.expr (reveal_node dag 3 e). + Rewrite.expr (reveal_node_at_least dag 3 e). Lemma eval_simplify G d n v : eval_node G d n v -> eval G d (simplify d n) v. -Proof using Type. eauto using Rewrite.eval_expr, eval_node_reveal_node. Qed. +Proof using Type. eauto using Rewrite.eval_expr, eval_node_reveal_node_at_least. Qed. Definition reg_state := Tuple.tuple (option idx) 16. Definition flag_state := Tuple.tuple (option idx) 6.