Skip to content

Commit

Permalink
Merge branch 'wip-spilling-2'
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgruetter committed May 8, 2021
2 parents 7db6053 + bdf7c23 commit 91b22b2
Show file tree
Hide file tree
Showing 18 changed files with 802 additions and 136 deletions.
2 changes: 1 addition & 1 deletion compiler/src/compiler/CompilerInvariant.v
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Require Import bedrock2.Map.SeparationLogic.
Require Import compiler.SeparationLogic.
Require Import coqutil.Tactics.Simp.
Require Import compiler.FlatToRiscvFunctions.
Require Import compiler.PipelineWithRename.
Require Import compiler.Pipeline.
Require Import compiler.ExprImpEventLoopSpec.
Require Import compiler.ToplevelLoop.

Expand Down
2 changes: 1 addition & 1 deletion compiler/src/compiler/ExprImpEventLoopSpec.v
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Require Coq.Strings.String.
Require Import coqutil.Map.Interface coqutil.Word.Interface.
Require Import bedrock2.MetricLogging.
Require Import compiler.SeparationLogic.
Require Import compiler.PipelineWithRename.
Require Import compiler.Pipeline.

Section Params1.
Context {p: Semantics.parameters}.
Expand Down
78 changes: 62 additions & 16 deletions compiler/src/compiler/FlatImp.v
Original file line number Diff line number Diff line change
Expand Up @@ -93,31 +93,75 @@ Section Syntax.
| SCall binds _ _ | SInteract binds _ _ => list_union veq binds []
end.

Definition ForallVars_bcond(P: varname -> Prop)(cond: bcond) : Prop :=
Definition ForallVars_bcond_gen{R: Type}(and: R -> R -> R)(P: varname -> R)(cond: bcond): R :=
match cond with
| CondBinary _ x y => P x /\ P y
| CondBinary _ x y => and (P x) (P y)
| CondNez x => P x
end.

Definition Forall_vars_stmt(P: varname -> Prop)(P_calls: varname -> Prop): stmt -> Prop :=
Definition Forall_vars_stmt_gen{R: Type}(T: R)(and: R -> R -> R)(all: (varname -> R) -> list varname -> R)
(C: (varname -> R) -> bcond -> R)(P: varname -> R)(P_calls: varname -> R): stmt -> R :=
fix rec s :=
match s with
| SLoad _ x a _ => P x /\ P a
| SStore _ a x _ => P a /\ P x
| SInlinetable _ x _ i => P x /\ P i
| SStackalloc x n body => P x /\ rec body
| SLoad _ x a _ => and (P x) (P a)
| SStore _ a x _ => and (P a) (P x)
| SInlinetable _ x _ i => and (P x) (P i)
| SStackalloc x n body => and (P x) (rec body)
| SLit x _ => P x
| SOp x _ y z => P x /\ P y /\ P z
| SSet x y => P x /\ P y
| SIf c s1 s2 => ForallVars_bcond P c /\ rec s1 /\ rec s2
| SLoop s1 c s2 => ForallVars_bcond P c /\ rec s1 /\ rec s2
| SSeq s1 s2 => rec s1 /\ rec s2
| SSkip => True
| SCall binds _ args => Forall P_calls binds /\ Forall P_calls args
| SInteract binds _ args => Forall P_calls binds /\ Forall P_calls args
| SOp x _ y z => and (P x) (and (P y) (P z))
| SSet x y => and (P x) (P y)
| SIf c s1 s2 => and (C P c) (and (rec s1) (rec s2))
| SLoop s1 c s2 => and (C P c) (and (rec s1) (rec s2))
| SSeq s1 s2 => and (rec s1) (rec s2)
| SSkip => T
| SCall binds _ args => and (all P_calls binds) (all P_calls args)
| SInteract binds _ args => and (all P_calls binds) (all P_calls args)
end.

Definition ForallVars_stmt P := Forall_vars_stmt P P.
Definition ForallVars_bcond(P: varname -> Prop)(cond: bcond): Prop :=
Eval unfold ForallVars_bcond_gen in ForallVars_bcond_gen and P cond.

Definition Forall_vars_stmt(P: varname -> Prop)(P_calls: varname -> Prop): stmt -> Prop :=
Eval unfold Forall_vars_stmt_gen in
Forall_vars_stmt_gen True and (@Forall varname) ForallVars_bcond P P_calls.

Definition forallbVars_bcond(P: varname -> bool)(cond: bcond): bool :=
Eval unfold ForallVars_bcond_gen in ForallVars_bcond_gen andb P cond.

Definition forallb_vars_stmt(P: varname -> bool)(P_calls: varname -> bool): stmt -> bool :=
Eval unfold Forall_vars_stmt_gen in
Forall_vars_stmt_gen true andb (@forallb varname) forallbVars_bcond P P_calls.

Lemma forallb_vars_stmt_correct
(P: varname -> Prop)(p: varname -> bool)(P_calls: varname -> Prop)(p_calls: varname -> bool)
(p_correct: forall x, p x = true <-> P x)
(p_calls_correct: forall x, p_calls x = true <-> P_calls x):
forall s, forallb_vars_stmt p p_calls s = true <-> Forall_vars_stmt P P_calls s.
Proof.
assert (p_correct_fw: forall x, p x = true -> P x). {
intros. eapply p_correct. assumption.
}
assert (p_correct_bw: forall x, P x -> p x = true). {
intros. eapply p_correct. assumption.
}
assert (p_calls_correct_fw: forall x, p_calls x = true -> P_calls x). {
intros. eapply p_calls_correct. assumption.
}
assert (p_calls_correct_bw: forall x, P_calls x -> p_calls x = true). {
intros. eapply p_calls_correct. assumption.
}
clear p_correct p_calls_correct.
induction s; split; simpl; intros; unfold ForallVars_bcond, forallbVars_bcond in *;
repeat match goal with
| c: bcond |- _ => destruct c
| H: andb _ _ = true |- _ => eapply Bool.andb_true_iff in H
| H: _ /\ _ |- _ => destruct H
| H: _ <-> _ |- _ => destruct H
| |- andb _ _ = true => apply Bool.andb_true_iff
| |- _ /\ _ => split
end;
eauto using List.Forall_to_forallb, List.forallb_to_Forall.
Qed.

Lemma ForallVars_bcond_impl: forall (P Q: varname -> Prop),
(forall x, P x -> Q x) ->
Expand All @@ -134,6 +178,8 @@ Section Syntax.
induction s; intros; simpl in *; intuition eauto using ForallVars_bcond_impl, Forall_impl.
Qed.

Definition ForallVars_stmt P := Forall_vars_stmt P P.

Lemma ForallVars_stmt_impl: forall (P Q: varname -> Prop),
(forall x, P x -> Q x) ->
forall s, ForallVars_stmt P s -> ForallVars_stmt Q s.
Expand Down
13 changes: 4 additions & 9 deletions compiler/src/compiler/FlattenExpr.v
Original file line number Diff line number Diff line change
Expand Up @@ -568,8 +568,8 @@ Section FlattenExpr1.
Lemma freshNameGenState_disjoint_fbody: forall (fbody: cmd) (params rets: list String.string),
disjoint (ExprImp.allVars_cmd fbody)
(allFreshVars (@freshNameGenState _ (@NGstate p) (@NG p)
(ListSet.list_union String.eqb (ExprImp.allVars_cmd_as_list fbody)
(ListSet.list_union String.eqb params rets)))).
(ListSet.list_union String.eqb (ListSet.list_union String.eqb params rets)
(ExprImp.allVars_cmd_as_list fbody)))).
Proof.
unfold disjoint. intros.
epose proof (freshNameGenState_spec _ x) as P.
Expand All @@ -579,10 +579,8 @@ Section FlattenExpr1.
+ right. apply P. assumption.
+ left. clear -Ino hyps.
intro. apply Ino.
unshelve eapply ListSet.In_list_union_spec. left.
epose proof (ExprImp.allVars_cmd_allVars_cmd_as_list _ _) as P. destruct P as [P _].
apply P.
apply H.
eauto using ListSet.In_list_union_l, ListSet.In_list_union_r, nth_error_In.
Qed.

Lemma flattenStmt_correct_aux: forall max_size eH eL,
Expand Down Expand Up @@ -793,10 +791,7 @@ Section FlattenExpr1.
unfold map.of_list_zip in G.
eapply map.putmany_of_list_zip_find_index in G. 2: eassumption.
rewrite map.get_empty in G. destruct G as [G | G]; [|discriminate G]. simp.
apply ListSet.In_list_union_spec. right.
apply ListSet.In_list_union_spec.
left.
eapply nth_error_In. eassumption.
eauto using ListSet.In_list_union_l, ListSet.In_list_union_r, nth_error_In.
-- eapply freshNameGenState_disjoint_fbody.
* cbv beta. intros. simp.
edestruct H4 as [resvals ?]. 1: eassumption. simp.
Expand Down
5 changes: 3 additions & 2 deletions compiler/src/compiler/FlattenExprDef.v
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,9 @@ Section FlattenExpr1.
Definition flatten_function:
list String.string * list String.string * Syntax.cmd -> option (list String.string * list String.string * FlatImp.stmt string) :=
fun '(argnames, retnames, body) =>
let avoid := ListSet.list_union String.eqb (ExprImp.allVars_cmd_as_list body)
(ListSet.list_union String.eqb argnames retnames) in
let avoid := ListSet.list_union String.eqb
(ListSet.list_union String.eqb argnames retnames)
(ExprImp.allVars_cmd_as_list body) in
let body' := fst (flattenStmt (freshNameGenState avoid) body) in
if FlatImp.stmt_size body' <? max_size then Some (argnames, retnames, body') else None.

Expand Down
9 changes: 0 additions & 9 deletions compiler/src/compiler/GoFlatToRiscv.v
Original file line number Diff line number Diff line change
Expand Up @@ -373,15 +373,6 @@ Section Go.
apply ptsto_bytes_putmany_of_tuple; assumption.
Qed.

Lemma length_flat_map: forall {A B: Type} (f: A -> list B) n (l: list A),
(forall (a: A), length (f a) = n) ->
length (flat_map f l) = (n * length l)%nat.
Proof.
induction l; intros.
- simpl. blia.
- simpl. rewrite app_length. rewrite H. rewrite IHl; assumption || blia.
Qed.

Lemma mod_eq_to_diff: forall e1 e2 m,
m <> 0 ->
e1 mod m = e2 mod m ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ Section Pipeline1.

Definition flattenPhase(prog: source_env): option flat_env := flatten_functions (2^10) prog.
Definition renamePhase(prog: flat_env): option renamed_env :=
rename_functions prog.
rename_functions_old prog.

(* Note: we could also track code size from the source program all the way to the target
program, and a lot of infrastructure is already there, will do once/if we want to get
Expand Down Expand Up @@ -191,11 +191,11 @@ Section Pipeline1.
(e1 : FlattenExpr.ExprImp_env)
(e2 : FlattenExpr.FlatImp_env)
(funname : string)
(Hf: flatten_functions (2 ^ 10) e1 = Some e2).
(Hf: flatten_functions (2^10) e1 = Some e2).
Let c2 := (@FlatImp.SSeq String.string FlatImp.SSkip (FlatImp.SCall [] funname [])).
Let c3 := (@FlatImp.SSeq Z FlatImp.SSkip (FlatImp.SCall [] f_entry_name [])).
Context (av' : Z) (r' : string_keyed_map Z)
(ER: envs_related e2 prog)
(ER: envs_related_old e2 prog)
(Ren: rename map.empty c2 lowest_available_impvar = Some (r', c3, av'))
(H_p_call: word.unsigned p_call mod 4 = 0)
(H_p_functions: word.unsigned p_functions mod 4 = 0)
Expand All @@ -220,13 +220,13 @@ Section Pipeline1.
End Sim.

Lemma rename_fun_valid: forall argnames retnames body impl',
rename_fun (argnames, retnames, body) = Some impl' ->
rename_fun_old (argnames, retnames, body) = Some impl' ->
NoDup argnames ->
NoDup retnames ->
FlatImp.stmt_size body < 2 ^ 10 ->
FlatToRiscvDef.valid_FlatImp_fun impl'.
Proof.
unfold rename_fun, FlatToRiscvDef.valid_FlatImp_fun.
unfold rename_fun_old, FlatToRiscvDef.valid_FlatImp_fun.
intros.
simp.
eapply rename_binds_props in E; cycle 1.
Expand Down Expand Up @@ -651,9 +651,9 @@ Section Pipeline1.
eassumption.
}
{ eassumption. }
{ unfold envs_related.
{ unfold envs_related_old.
intros f [ [argnames resnames] body1 ] G.
unfold rename_functions in *.
unfold rename_functions_old in *.
eapply map.map_all_values_fw.
5: exact G. 4: eassumption.
- eapply String.eqb_spec.
Expand All @@ -667,7 +667,7 @@ Section Pipeline1.
destruct (map.map_all_values_fw _ _ _ _ E _ _ H1) as [ [ [args' rets'] fbody' ] [ F G ] ].
unfold flatten_function in F. simp.
epose proof (map.map_all_values_fw _ _ _ _ E0 _ _ G) as [ [ [args' rets'] fbody' ] [ F G' ] ].
unfold rename_fun in F. simp.
unfold rename_fun_old in F. simp.
eapply compile_funs_nonnil; eassumption.
}
{ unfold riscvPhase in *. simp. exact GetPos. }
Expand All @@ -682,7 +682,7 @@ Section Pipeline1.
destruct (map.map_all_values_fw _ _ _ _ E _ _ H1) as [ [ [args' rets'] fbody' ] [ F G ] ].
unfold flatten_function in F. simp.
epose proof (map.map_all_values_fw _ _ _ _ E0 _ _ G) as [ [ [args' rets'] fbody' ] [ F G' ] ].
unfold rename_fun in F. simp.
unfold rename_fun_old in F. simp.
apply_in_hyps rename_binds_preserves_length.
destruct rets'; [|discriminate].
destruct args'; [|discriminate].
Expand All @@ -696,9 +696,9 @@ Section Pipeline1.
intros.
simpl in *.
match goal with
| H: rename_functions _ = _ |- _ => rename H into RenameEq
| H: rename_functions_old _ = _ |- _ => rename H into RenameEq
end.
unfold rename_functions in RenameEq.
unfold rename_functions_old in RenameEq.
match goal with
| H: _ |- _ => unshelve epose proof (map.map_all_values_bw _ _ _ _ RenameEq _ _ H)
end.
Expand All @@ -721,7 +721,7 @@ Section Pipeline1.
destruct V.
ssplit.
- eapply rename_fun_valid; try eassumption.
unfold ExprImp2FlatImp in *.
unfold rename_fun_old in *.
simp.
repeat match goal with
| H: @eq bool _ _ |- _ => autoforward with typeclass_instances in H
Expand Down Expand Up @@ -834,7 +834,7 @@ Section Pipeline1.
{ assert (0 < bytes_per_word). { (* TODO: deduplicate *)
unfold bytes_per_word; simpl; destruct width_cases as [EE | EE]; rewrite EE; cbv; trivial.
}
rewrite (length_flat_map _ (Z.to_nat bytes_per_word)).
rewrite (List.length_flat_map _ (Z.to_nat bytes_per_word)).
{ rewrite Nat2Z.inj_mul, Z2Nat.id by blia. rewrite Z.sub_0_r in H2p0p1p8p0.
rewrite <-H2p0p1p8p0, <-Z_div_exact_2; try trivial.
{ eapply Z.lt_gt; assumption. }
Expand Down
Loading

0 comments on commit 91b22b2

Please sign in to comment.