Skip to content

Commit

Permalink
Add reveal_at_least, a more clever form of reveal (#1167)
Browse files Browse the repository at this point in the history
Depth determines which indices get expanded, but all references to the
same index get expanded if they appear in the output.

This is because Joel's latest examples require a rewriting pass that
needs either uneven revealing or the ability to check expression
equality modulo the dag in the middle of rewriting.
#1134 (comment)
Roughly the issue is that if we want to turn `a + 4 * a` into `5 * a`,
we need to reveal enough structure to see `4 * a`, but we need to see
that the two instances of `a` are the same (e.g., if `a` is an ExprRef
pointing to `b|c`, and we reveal uniformly, then we need to recognize
`b|c + 4 * a`)
  • Loading branch information
JasonGross committed Mar 22, 2022
1 parent d8b3f3f commit 93ceed7
Showing 1 changed file with 80 additions and 2 deletions.
82 changes: 80 additions & 2 deletions src/Assembly/Symbolic.v
Expand Up @@ -24,6 +24,7 @@ Require Import Crypto.Util.ListUtil.Filter.
Require Import Crypto.Util.ListUtil.PermutationCompat. Import ListUtil.PermutationCompat.Coq.Sorting.Permutation.
Require Import Crypto.Util.NUtil.Sorting.
Require Import Crypto.Util.NUtil.Testbit.
Require Import Crypto.Util.MSetN.
Require Import Crypto.Util.ListUtil.PermutationCompat.
Require Import Crypto.Util.Bool.LeCompat.
Require Import Crypto.Util.Tactics.DestructHead.
Expand Down Expand Up @@ -238,6 +239,54 @@ Section WithDag.
Definition reveal_node n '(op, args) :=
ExprApp (op, List.map (reveal n) args).

(** given a set of indices, get the set of indices of their arguments *)
Definition reveal_gather_deps_args (ls : NSet.t) : NSet.t
:= fold_right
(fun i so_far => match List.nth_error dag (N.to_nat i) with
| None => so_far
| Some (_op, args) => fold_right NSet.add so_far args
end)
NSet.empty
(NSet.elements ls).

(** given a set of seen indices and a set of newly-revealed indices,
we want to merge the new indices into what's been seen and recurse
on the new indices *)
Definition reveal_gather_deps_step reveal_gather_deps (so_far : NSet.t) (new_idxs : NSet.t) : NSet.t
:= let new_idxs := NSet.diff new_idxs so_far in
if NSet.is_empty new_idxs
then so_far
else reveal_gather_deps (NSet.union so_far new_idxs) (reveal_gather_deps_args new_idxs).

Fixpoint reveal_gather_deps_list (n : nat) (so_far : NSet.t) (new_idxs : NSet.t) : NSet.t
:= match n with
| O => NSet.union so_far new_idxs
| S n => reveal_gather_deps_step (reveal_gather_deps_list n) so_far new_idxs
end.

Definition reveal_gather_deps (n : nat) (i : idx) : NSet.t
:= reveal_gather_deps_list n NSet.empty (NSet.singleton i).

Definition reveal_step_from_deps reveal (deps : NSet.t) (i : idx) : expr
:= if NSet.mem i deps
then match List.nth_error dag (N.to_nat i) with
| None => (* undefined *) ExprRef i
| Some (op, args) => ExprApp (op, List.map reveal args)
end
else ExprRef i.
Fixpoint reveal_from_deps_fueled (fuel : nat) (deps : NSet.t) (i : idx) :=
match fuel with
| O => ExprRef i
| S fuel => reveal_step_from_deps (reveal_from_deps_fueled fuel deps) deps i
end.
(** depth determines which indices get expanded, but all references
to the same index get expanded if they appear in the output *)
Definition reveal_at_least n (i : idx) : expr
:= reveal_from_deps_fueled (S (List.length dag)) (reveal_gather_deps n i) i.

Definition reveal_node_at_least n '(op, args) :=
ExprApp (op, List.map (reveal_at_least n) args).

Local Unset Elimination Schemes.
Inductive eval : expr -> Z -> Prop :=
| ERef i op args args' n
Expand Down Expand Up @@ -316,6 +365,35 @@ Section WithDag.
eapply Forall2_weaken; try eassumption; []; cbv beta; intros.
eapply eval_reveal; eauto.
Qed.

Lemma eval_reveal_from_deps_fueled deps : forall n i, forall v, eval (ExprRef i) v ->
forall e, reveal_from_deps_fueled n deps i = e -> eval e v.
Proof using Type.
induction n; cbn [reveal_from_deps_fueled]; cbv [reveal_step_from_deps]; intros; subst; eauto; [].
break_innermost_match_step; eauto; [].
inversion H; subst; clear H.
rewrite H1; econstructor; try eassumption; [].
eapply (proj1 (Forall2_map_l _ _ _)) in H2.
clear dependent i; clear dependent v.
induction H2; cbn; eauto.
Qed.

Lemma eval_reveal_at_least : forall n i, forall v, eval (ExprRef i) v ->
forall e, reveal_at_least n i = e -> eval e v.
Proof using Type.
cbv [reveal_at_least].
intros; eapply eval_reveal_from_deps_fueled; eassumption.
Qed.

Lemma eval_node_reveal_node_at_least : forall n v, eval_node n v ->
forall f e, reveal_node_at_least f n = e -> eval e v.
Proof using Type.
cbv [reveal_node]; inversion 1; intros; subst.
econstructor; eauto.
eapply (proj1 (Forall2_map_l _ _ _)) in H0; eapply Forall2_map_l.
eapply Forall2_weaken; try eassumption; []; cbv beta; intros.
eapply eval_reveal_at_least; eauto.
Qed.
End WithDag.

Module dag.
Expand Down Expand Up @@ -1950,10 +2028,10 @@ Qed.
End Rewrite.

Definition simplify (dag : dag) (e : node idx) :=
Rewrite.expr (reveal_node dag 3 e).
Rewrite.expr (reveal_node_at_least dag 3 e).

Lemma eval_simplify G d n v : eval_node G d n v -> eval G d (simplify d n) v.
Proof using Type. eauto using Rewrite.eval_expr, eval_node_reveal_node. Qed.
Proof using Type. eauto using Rewrite.eval_expr, eval_node_reveal_node_at_least. Qed.

Definition reg_state := Tuple.tuple (option idx) 16.
Definition flag_state := Tuple.tuple (option idx) 6.
Expand Down

0 comments on commit 93ceed7

Please sign in to comment.