New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Ltac2: use preterm in exact / eexact #18157
Conversation
@coqbot bench |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good!
Should we also add Constr.Preterm.of_constr (c : constr) : preterm := preterm:($c).
(this works, right?), either in this or another PR, so it's easier to see how to convert from constr to preterm? (I guess Constr.to_preterm
would also be a fine name?)
It works but I'm not sure what the point would be |
Yeah, I guess I don't have a concrete use-case in mind at the moment, and was thinking more about completeness of API. I'll retract my suggestion until/unless I have a concrete use-case in mind |
(Constr.Pretype.OfType (Control.goal())) | ||
c | ||
in | ||
Control.refine (fun _ => c)). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should probably have a refine_nocheck too
🏁 Bench results:
INFO: failed to install coq-metacoq-safechecker (dependency coq-metacoq-pcuic failed) 🐢 Top 25 slow downs┌────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ │ TOP 25 SLOW DOWNS │ │ │ │ OLD NEW DIFF %DIFF Ln FILE │ ├────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ 159.8330 162.5840 2.7510 1.72% 233 coq-fiat-crypto-with-bedrock/rupicola/bedrock2/deps/riscv-coq/src/riscv/Proofs/DecodeByExtension.v.html │ │ 22.5260 23.5440 1.0180 4.52% 660 coq-perennial/src/program_proof/vrsm/replica/roapply_proof.v.html │ │ 11.5640 12.4960 0.9320 8.06% 462 coq-perennial/src/program_proof/simple/write.v.html │ │ 54.5520 55.2410 0.6890 1.26% 609 coq-bedrock2/bedrock2/src/bedrock2Examples/lightbulb.v.html │ │ 34.0800 34.6760 0.5960 1.75% 522 coq-perennial/src/program_proof/txn/twophase_refinement_proof.v.html │ │ 62.5440 63.1210 0.5770 0.92% 139 coq-fiat-parsers/src/Parsers/Refinement/SharpenedJSON.v.html │ │ 17.1090 17.5110 0.4020 2.35% 32 coq-performance-tests-lite/src/pattern.v.html │ │ 32.5230 32.9190 0.3960 1.22% 12 coq-fourcolor/theories/job323to383.v.html │ │ 3.2680 3.6380 0.3700 11.32% 252 coq-perennial/src/program_proof/simple/iread.v.html │ │ 29.7340 30.0610 0.3270 1.10% 912 coq-perennial/src/program_proof/vrsm/reconfig/proof.v.html │ │ 4.3280 4.6500 0.3220 7.44% 5 coq-fiat-crypto-with-bedrock/src/Assembly/Parse/Examples/fiat_p256_mul_optimised_seed4.v.html │ │ 27.4330 27.7540 0.3210 1.17% 12 coq-fourcolor/theories/job279to282.v.html │ │ 90.8810 91.1790 0.2980 0.33% 103 coq-fiat-crypto-with-bedrock/src/Arithmetic/BarrettReduction.v.html │ │ 33.9620 34.2420 0.2800 0.82% 12 coq-fourcolor/theories/job254to270.v.html │ │ 3.2320 3.4860 0.2540 7.86% 32 coq-performance-tests-lite/src/n_polymorphic_universes.v.html │ │ 1.0770 1.3250 0.2480 23.03% 1677 coq-fiat-crypto-with-bedrock/src/Bedrock/End2End/RupicolaCrypto/ChaCha20.v.html │ │ 22.3510 22.5930 0.2420 1.08% 12 coq-fourcolor/theories/job283to286.v.html │ │ 39.1310 39.3590 0.2280 0.58% 85 coq-fiat-crypto-with-bedrock/src/Curves/Montgomery/AffineProofs.v.html │ │ 20.3710 20.5980 0.2270 1.11% 12 coq-fourcolor/theories/job311to314.v.html │ │ 44.1490 44.3580 0.2090 0.47% 558 coq-fiat-crypto-with-bedrock/rupicola/bedrock2/bedrock2/src/bedrock2Examples/insertionsort.v.html │ │ 28.9400 29.1450 0.2050 0.71% 12 coq-fourcolor/theories/job611to617.v.html │ │ 23.3510 23.5550 0.2040 0.87% 12 coq-fourcolor/theories/job542to545.v.html │ │ 33.7670 33.9690 0.2020 0.60% 548 coq-fiat-crypto-with-bedrock/rupicola/bedrock2/compiler/src/compiler/MMIO.v.html │ │ 2.1700 2.3710 0.2010 9.26% 1363 coq-perennial/src/program_proof/buf/buf_proof.v.html │ │ 0.1270 0.3210 0.1940 152.76% 167 coq-fiat-crypto-with-bedrock/src/PushButtonSynthesis/FancyMontgomeryReduction.v.html │ └────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ 🐇 Top 25 speed ups┌────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ │ TOP 25 SPEED UPS │ │ │ │ OLD NEW DIFF %DIFF Ln FILE │ ├────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ 112.3530 111.5240 -0.8290 -0.74% 48 coq-fiat-crypto-with-bedrock/src/Curves/Weierstrass/AffineProofs.v.html │ │ 43.1860 42.5950 -0.5910 -1.37% 236 coq-rewriter/src/Rewriter/Rewriter/Examples/PerfTesting/LiftLetsMap.v.html │ │ 42.4810 41.9090 -0.5720 -1.35% 827 coq-vst/veric/binop_lemmas4.v.html │ │ 65.2320 64.6740 -0.5580 -0.86% 361 coq-perennial/src/program_proof/buf/buf_proof.v.html │ │ 2.8080 2.4020 -0.4060 -14.46% 736 coq-stdlib/Reals/Cauchy/ConstructiveCauchyReals.v.html │ │ 39.9090 39.5030 -0.4060 -1.02% 835 coq-fiat-crypto-with-bedrock/src/Fancy/Compiler.v.html │ │ 9.2030 8.8220 -0.3810 -4.14% 24 coq-perennial/src/program_proof/txn/twophase_refinement_thm.v.html │ │ 27.0100 26.6390 -0.3710 -1.37% 68 coq-fiat-crypto-with-bedrock/rupicola/bedrock2/deps/riscv-coq/src/riscv/Proofs/VerifyDecode.v.html │ │ 27.7820 27.4380 -0.3440 -1.24% 12 coq-fourcolor/theories/job618to622.v.html │ │ 21.8170 21.4800 -0.3370 -1.54% 12 coq-fourcolor/theories/job490to494.v.html │ │ 17.6940 17.3730 -0.3210 -1.81% 3158 coq-fiat-crypto-with-bedrock/src/Assembly/WithBedrock/Proofs.v.html │ │ 31.8510 31.5310 -0.3200 -1.00% 12 coq-fourcolor/theories/job107to164.v.html │ │ 143.6780 143.3620 -0.3160 -0.22% 1190 coq-unimath/UniMath/CategoryTheory/GrothendieckConstruction/IsPullback.v.html │ │ 29.2430 28.9350 -0.3080 -1.05% 12 coq-fourcolor/theories/job531to534.v.html │ │ 60.4180 60.1110 -0.3070 -0.51% 27 coq-fiat-crypto-with-bedrock/src/Rewriter/Passes/ToFancyWithCasts.v.html │ │ 20.9250 20.6180 -0.3070 -1.47% 12 coq-fourcolor/theories/job207to214.v.html │ │ 42.4210 42.1150 -0.3060 -0.72% 224 coq-performance-tests-lite/PerformanceExperiments/rewrite_lift_lets_map.v.html │ │ 14.2900 13.9870 -0.3030 -2.12% 3090 coq-fiat-crypto-with-bedrock/src/Assembly/WithBedrock/Proofs.v.html │ │ 27.4910 27.2020 -0.2890 -1.05% 823 coq-perennial/src/program_proof/aof/proof.v.html │ │ 34.4980 34.2200 -0.2780 -0.81% 97 coq-vst/veric/binop_lemmas5.v.html │ │ 26.5990 26.3320 -0.2670 -1.00% 2293 coq-perennial/src/goose_lang/logical_reln_fund.v.html │ │ 18.4140 18.1510 -0.2630 -1.43% 12 coq-fourcolor/theories/job235to238.v.html │ │ 52.9840 52.7350 -0.2490 -0.47% 915 coq-fiat-crypto-with-bedrock/src/Bedrock/End2End/X25519/GarageDoor.v.html │ │ 13.2700 13.0260 -0.2440 -1.84% 187 coq-perennial/src/goose_lang/interpreter/disk_interpreter.v.html │ │ 22.4170 22.1800 -0.2370 -1.06% 12 coq-fourcolor/theories/job546to549.v.html │ └────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ |
ie spurious network issue I don't think metacoq uses ltac2 much so no point rerunning a bench |
@coqbot ci minimize ci-neural_net_interp |
I am now running minimization at commit d132302 on requested target ci-neural_net_interp. I'll come back to you with the results once it's done. |
Minimized File /github/workspace/builds/coq/coq-failing/_build_ci/neural_net_interp/theories/TransformerLens/HookedTransformer.v (from ci-neural_net_interp) (full log on GitHub Actions) We are collecting data on the user experience of the Coq Bug Minimizer. Partially Minimized Coq File (could not inline Ltac2.Notations)(* -*- mode: coq; coq-prog-args: ("-emacs" "-q" "-w" "+implicit-core-hint-db,+implicits-in-term,+non-reversible-notation,+deprecated-intros-until-0,+deprecated-focus,+unused-intro-pattern,+variable-collision,+unexpected-implicit-declaration,+omega-is-deprecated,+deprecated-instantiate-syntax,+non-recursive,+undeclared-scope,+deprecated-hint-rewrite-without-locality,+deprecated-hint-without-locality,+deprecated-instance-without-locality,+deprecated-typeclasses-transparency-without-locality,-ltac2-missing-notation-var,unsupported-attributes" "-w" "-deprecated-native-compiler-option" "-native-compiler" "ondemand" "-R" "/github/workspace/builds/coq/coq-failing/_build_ci/neural_net_interp/theories" "NeuralNetInterp" "-Q" "/github/workspace/cwd" "Top" "-Q" "/github/workspace/builds/coq/coq-failing/_install_ci/lib/coq/user-contrib/Ltac2" "Ltac2" "-top" "NeuralNetInterp.TransformerLens.HookedTransformer") -*- *)
(* File reduced by coq-bug-minimizer from original input, then from 473 lines to 51 lines, then from 64 lines to 441 lines, then from 445 lines to 322 lines, then from 335 lines to 1735 lines, then from 1740 lines to 476 lines, then from 489 lines to 648 lines, then from 653 lines to 476 lines, then from 489 lines to 788 lines, then from 793 lines to 467 lines, then from 480 lines to 1012 lines, then from 1017 lines to 469 lines, then from 482 lines to 847 lines, then from 852 lines to 477 lines, then from 490 lines to 718 lines, then from 723 lines to 501 lines, then from 514 lines to 548 lines, then from 553 lines to 512 lines, then from 525 lines to 559 lines, then from 564 lines to 513 lines, then from 526 lines to 560 lines, then from 565 lines to 519 lines, then from 532 lines to 637 lines, then from 642 lines to 522 lines, then from 535 lines to 624 lines, then from 629 lines to 527 lines, then from 540 lines to 609 lines, then from 614 lines to 539 lines, then from 552 lines to 630 lines, then from 635 lines to 552 lines, then from 565 lines to 617 lines, then from 622 lines to 568 lines, then from 581 lines to 1192 lines, then from 1198 lines to 677 lines, then from 691 lines to 757 lines, then from 763 lines to 681 lines, then from 687 lines to 682 lines *)
(* coqc version 8.19+alpha compiled with OCaml 4.09.0
coqtop version 8b9980c1a092:/builds/coq/coq/_build/default,(HEAD detached at 2564ef3) (2564ef398cf18455f1f1b63024e337c08cfbab26)
Modules that could not be inlined: Ltac2.Notations
Expected coqc runtime on this file: 0.662 sec *)
Require Ltac2.Constr.
Require Ltac2.Notations.
Require Coq.Array.PArray.
Module Export Printf.
Import Ltac2.Message.
Ltac2 printf fmt := Format.kfprintf print fmt.
Module Export List.
Import Ltac2.Init.
Ltac2 rec length (ls : 'a list) :=
match ls with
| [] => 0
| _ :: xs => Int.add 1 (length xs)
end.
Ltac2 rec append ls1 ls2 :=
match ls1 with
| [] => ls2
| x :: xs => x :: append xs ls2
end.
Ltac2 rec map (f : 'a -> 'b) (ls : 'a list) :=
match ls with
| [] => []
| l :: ls => f l :: map f ls
end.
Ltac2 rec firstn (n : int) (ls : 'a list) :=
Control.assert_valid_argument "List.firstn" (Int.ge n 0);
match Int.equal n 0 with
| true => []
| false
=> match ls with
| [] => Control.throw_out_of_bounds "List.firstn"
| x :: xs
=> x :: firstn (Int.sub n 1) xs
end
end.
Ltac2 rec seq (start : int) (step : int) (last : int) :=
match Int.lt (Int.sub last start) step with
| true
=> []
| false
=> start :: seq (Int.add start step) step last
end.
Ltac2 init (len : int) (f : int -> 'a) :=
Control.assert_valid_argument "List.init" (Int.ge len 0);
map f (seq 0 1 len).
Ltac2 repeat (x : 'a) (n : 'int) :=
init n (fun _ => x).
Ltac2 rec merge (cmp : 'a -> 'a -> int) (l1 : 'a list) (l2 : 'b list) :=
let rec merge_aux l2 :=
match l1 with
| [] => l2
| a1 :: l1'
=> match l2 with
| [] => l1
| a2 :: l2'
=> match Int.le (cmp a1 a2) 0 with
| true => a1 :: merge cmp l1' l2
| false => a2 :: merge_aux l2'
end
end
end in
merge_aux l2.
Ltac2 rec merge_list_to_stack cmp stack l :=
match stack with
| [] => [Some l]
| l' :: stack'
=> match l' with
| None => Some l :: stack'
| Some l'
=> None :: merge_list_to_stack cmp stack' (merge cmp l' l)
end
end.
Ltac2 rec merge_stack cmp stack :=
match stack with
| [] => []
| l :: stack'
=> match l with
| None => merge_stack cmp stack'
| Some l => merge cmp l (merge_stack cmp stack')
end
end.
Ltac2 rec iter_merge cmp stack l :=
match l with
| [] => merge_stack cmp stack
| a::l' => iter_merge cmp (merge_list_to_stack cmp stack [a]) l'
end.
Ltac2 sort cmp l := iter_merge cmp [] l.
Ltac2 sort_uniq (cmp : 'a -> 'a -> int) (l : 'a list) :=
let rec uniq l :=
match l with
| [] => []
| x1 :: xs
=> match xs with
| [] => x1 :: xs
| x2 :: _
=> match Int.equal (cmp x1 x2) 0 with
| true => uniq xs
| false => x1 :: uniq xs
end
end
end in
uniq (sort cmp l).
End List.
Module Ltac2_DOT_Ltac2_WRAPPED.
Module Export Ltac2.
Export Ltac2.Init.
End Ltac2.
End Ltac2_DOT_Ltac2_WRAPPED.
Module Export Ltac2.
Module Ltac2.
Include Ltac2_DOT_Ltac2_WRAPPED.Ltac2.
End Ltac2.
Module Export Constr.
Import Ltac2.Ltac2.
Module Export Unsafe.
Export Ltac2.Constr.Unsafe.
Ltac2 rec kind_nocast (c : constr)
:= let k := kind c in
match k with
| Cast c _ _ => kind_nocast c
| _ => k
end.
End Unsafe.
End Constr.
Module Export MakeAbbreviations.
Import Ltac2.Ltac2.
Ltac2 mkApp (f : constr) (args : constr list) :=
make (App f (Array.of_list args)).
Ltac2 mkLambda b (body : constr) :=
make (Lambda b body).
Ltac2 mkRel (i : int) :=
make (Rel i).
Ltac2 mkVar (i : ident) :=
make (Var i).
End MakeAbbreviations.
Export Ltac2.Notations.
Ltac2 Notation "eval" "cbv" s(strategy) "in" c(tactic(6)) :=
Std.eval_cbv s c.
Ltac2 Notation "eval" "cbn" s(strategy) "in" c(tactic(6)) :=
Std.eval_cbn s c.
Reserved Infix "::'" (at level 59, left associativity).
Reserved Infix "++'" (at level 59, left associativity).
Reserved Infix "+'" (at level 48, left associativity).
Reserved Notation "√ x" (at level 5, right associativity, format "√ x").
Module Export Nat.
Fixpoint radd (n m : nat) {struct m} : nat.
exact (match m with
| 0 => n
| S p => S (radd n p)
end).
Defined.
Infix "+'" := radd : nat_scope.
Import Coq.ZArith.ZArith.
Class pointed T := point : T.
#[export] Instance default_Z : pointed Z.
Admitted.
Import Coq.Strings.String.
Definition with_default (name : string) {A} (x : A) := A.
#[global] Arguments with_default _ {_} _, _ _ _.
Existing Class with_default.
Ltac fill_default _ :=
lazymatch goal with
| [ |- @with_default ?name ?A ?x ]
=> match goal with
| [ H : @with_default ?name' ?A' _ |- _ ] => constr_eq A A'; constr_eq name name'; fail 1
| _ => exact x
end
end.
#[global] Hint Extern 0 (with_default _ _) => fill_default () : typeclass_instances.
Module Export NeuralNetInterp_DOT_Util_DOT_Arith_DOT_Classes_WRAPPED.
Module Export Classes.
Class has_eqb A := eqb : A -> A -> bool.
Class has_add_with A B C := add : A -> B -> C.
Notation has_add A := (has_add_with A A A).
Class has_sub_with A B C := sub : A -> B -> C.
Notation has_sub A := (has_sub_with A A A).
Class has_mul_with A B C := mul : A -> B -> C.
Notation has_mul A := (has_mul_with A A A).
Class has_zero A := zero : A.
Class has_one A := one : A.
Class has_max A := max : A -> A -> A.
Class has_sqrt A := sqrt : A -> A.
Class has_div_by A B C := div : A -> B -> C.
Notation has_div A := (has_div_by A A A).
Class has_exp_to A B := exp : A -> B.
Notation has_exp A := (has_exp_to A A).
Class has_coer A B := coer : A -> B.
Notation "√ x" := (sqrt x) : core_scope.
End Classes.
Module Export Arith.
Module Export Classes.
Include NeuralNetInterp_DOT_Util_DOT_Arith_DOT_Classes_WRAPPED.Classes.
#[export] Instance bool_has_eqb : has_eqb bool.
Admitted.
#[export] Instance bool_has_zero : has_zero bool.
Admitted.
#[export] Instance bool_has_one : has_one bool.
Admitted.
Coercion Z.of_N : N >-> Z.
Coercion Uint63.of_Z : Z >-> Uint63.int.
Import Coq.Numbers.Cyclic.Int63.Uint63.
Module Import Reduction.
Definition sum {A} {zeroA : has_zero A} {addA : has_add A} (start : int) (stop : int) (step : int) (f : int -> A) : A.
Admitted.
Module NeuralNetInterp_DOT_Util_DOT_PArray_WRAPPED.
Module Export PArray.
Open Scope uint63_scope.
End PArray.
End NeuralNetInterp_DOT_Util_DOT_PArray_WRAPPED.
Module Export NeuralNetInterp.
Module Export Util.
Module PArray.
Include NeuralNetInterp_DOT_Util_DOT_PArray_WRAPPED.PArray.
End PArray.
Module Export Tensor.
Import Coq.Lists.List.
Import ListNotations.
Definition Rank := nat.
Module Type IndexType.
Parameter t : Type.
End IndexType.
Module Type ExtendedIndexType.
Include IndexType.
End ExtendedIndexType.
Module Export IndexGen.
Module Make (IndexType : IndexType).
Notation IndexType := IndexType.t.
Fixpoint t (r : Rank) : Type
:= match r with
| O => unit
| S r => t r * IndexType.t
end.
Notation Index := t.
Definition nil : t 0.
Admitted.
Definition snoc {r} (s : t r) x : t (S r) := (s, x).
Module Import IndexNotations0.
Declare Scope index_scope.
Delimit Scope index_scope with index.
Notation "xs ::' x" := (snoc xs x) : index_scope.
Notation "[ ]" := nil : index_scope.
End IndexNotations0.
#[local] Open Scope index_scope.
Definition hd {r : Rank} : Index (S r) -> Index r.
exact (@fst _ _).
Defined.
Definition tl {r : Rank} : Index (S r) -> IndexType.
exact (@snd _ _).
Defined.
Fixpoint app {r1 r2 : Rank} {struct r2} : Index r1 -> Index r2 -> Index (r1 +' r2).
exact (match r2 with
| 0%nat => fun sz _tt => sz
| S r2 => fun sz1 sz2 => @app r1 r2 sz1 (hd sz2) ::' tl sz2
end%index).
Defined.
Fixpoint curriedT_dep {r : Rank} : (Index r -> Type) -> Type
:= match r with
| O => fun f => f []
| S r => fun f => curriedT_dep (fun init => forall i, f (init ::' i))
end.
Definition curriedT {r} (T : Type) : Type := @curriedT_dep r (fun _ => T).
Fixpoint uncurry_map_dep {r} : forall {A B}, (forall i, A i -> B i) -> @curriedT_dep r A -> (forall i : Index r, B i).
Admitted.
Definition uncurry_dep {r} {T} : @curriedT_dep r T -> (forall i : Index r, T i).
Admitted.
Definition uncurry {r T} : @curriedT r T -> (Index r -> T).
Admitted.
Definition uncurry_S {A r} : (Index 1 -> Index r -> A) -> (Index (1 +' r) -> A).
Admitted.
Module Export UncurryNotation.
Notation "'uncurry_fun' x1 .. xn => body"
:= (match _ return _ with
| ty => uncurry_S (fun x1 => .. (uncurry_S (fun xn => match body return Index 0 -> ty with v => fun 'tt => v end)) .. )
end)
(only parsing, at level 200, x1 binder, xn binder, body at level 200).
End UncurryNotation.
End Make.
Module ExtendedMake (IndexType : ExtendedIndexType).
Include Make IndexType.
End ExtendedMake.
End IndexGen.
Module Export Shape.
Module ShapeType <: ExtendedIndexType.
Definition t : Type.
exact (int).
Defined.
End ShapeType.
Include IndexGen.ExtendedMake ShapeType.
Module Export ShapeNotations.
Declare Scope shape_scope.
Delimit Scope shape_scope with shape.
Bind Scope shape_scope with t.
Notation "xs ::' x" := (snoc xs x) : shape_scope.
Notation "[ ]" := nil : shape_scope.
Notation "[ x ]" := (snoc nil x) : shape_scope.
Notation "[ x ; y ; .. ; z ]" := (snoc .. (snoc (snoc nil x) y) .. z) : shape_scope.
Notation "s1 ++' s2" := (app s1 s2) : shape_scope.
End ShapeNotations.
End Shape.
Notation ShapeType := Shape.IndexType.
Notation Shape := Shape.Index.
Module Export RawIndex.
Module RawIndexType <: ExtendedIndexType.
Definition t : Type.
exact (int).
Defined.
End RawIndexType.
Include IndexGen.ExtendedMake RawIndexType.
End RawIndex.
Notation RawIndexType := RawIndex.IndexType.
Notation RawIndex := RawIndex.Index.
Module Export Index.
Module IndexType <: ExtendedIndexType.
Definition t : Type.
Admitted.
End IndexType.
End Index.
Monomorphic Definition tensor {r : Rank} (s : Shape r) (A : Type) : Type.
Admitted.
Declare Scope tensor_scope.
Delimit Scope tensor_scope with tensor.
Definition with_shape {r} (s : Shape r) {A} : @Shape.curriedT r A -> A.
Admitted.
Definition ones {r} (s : Shape r) {A} {one : has_one A} : tensor s A.
Admitted.
Definition raw_get {r} {s : Shape r} {A} (t : tensor s A) (idxs : RawIndex r) : A.
Admitted.
Definition uncurry {r} {s : Shape r} {A} : @RawIndex.curriedT r A -> tensor s A.
Admitted.
Definition map' {ra1 ra2 rb} {sa1 : Shape ra1} {sa2 : Shape ra2} {sb : Shape rb} {A B} (f : tensor sa2 A -> tensor sb B) (t : tensor (sa1 ++' sa2) A) : tensor (sa1 ++' sb) B.
Admitted.
Definition to_bool {r} {s : Shape r} {A} {zero : has_zero A} {eqb : has_eqb A} (xs : tensor s A) : tensor s bool.
Admitted.
Definition tril {rnk} {s : Shape rnk} {r c} {A} {zero : has_zero A}
{diagonal : with_default "diagonal" int 0%int63} (input : tensor (s ++' [r; c]) A)
: tensor (s ++' [r; c]) A.
Admitted.
Module Export Einsum.
Import Coq.Unicode.Utf8.
Import Ltac2.Ltac2.
Ltac2 mutable debug () := false.
Module Import Internals.
Ltac2 debug_printf fmt := if debug () then Printf.printf fmt else Message.Format.kfprintf (fun x => ()) fmt.
Ltac2 Notation "debug_printf" fmt(format) := debug_printf fmt.
Ltac2 rec get_body (at_head : bool) (c : constr) :=
match Constr.Unsafe.kind_nocast c with
| Constr.Unsafe.Var v
=> let r := Std.VarRef v in
eval cbv delta [$r] in c
| Constr.Unsafe.Constant n _
=> let r := Std.ConstRef n in
eval cbv delta [$r] in c
| Constr.Unsafe.App f args
=> if at_head
then let f := get_body at_head f in
Constr.Unsafe.make (Constr.Unsafe.App f args)
else c
| _ => c
end.
Ltac2 shape_to_list (c : constr) : constr list
:= let rec aux c acc
:= lazy_match! c with
| Shape.nil => acc
| Shape.snoc ?cs ?c
=> aux cs (c :: acc)
end in
aux c [].
Ltac2 ident_of_constr (c : constr) : ident option
:= match Constr.Unsafe.kind_nocast c with
| Constr.Unsafe.Var v => Some v
| _ => None
end.
Ltac2 toplevel_rels (c : constr) : int list
:= let rec aux (acc : int list) (c : constr)
:= match Constr.Unsafe.kind_nocast c with
| Constr.Unsafe.Rel i => i :: acc
| Constr.Unsafe.App f args
=> let acc := aux acc f in
Array.fold_right aux acc args
| _ => acc
end in
List.sort_uniq Int.compare (aux [] c).
Local Notation try_tc := (ltac:(try typeclasses eauto)) (only parsing).
Ltac2 insert_all_einsums (sum_to : constr -> constr -> constr) (names : ident option list) (body : constr) : constr
:= let rawindexT := 'RawIndexType in
let nbinders := List.length names in
let body := Constr.Unsafe.liftn nbinders (Int.add 1 nbinders) body in
let rec aux (names : ident option list) (rel_above : int) (body : constr) : int * constr
:= match names with
| [] => (1, body)
| name :: names
=> let (cur_rel, body) := aux names (Int.add 1 rel_above) body in
(Int.add 1 cur_rel,
sum_to (mkRel (Int.add cur_rel rel_above)) (mkLambda (Constr.Binder.make name rawindexT) body))
end in
let (_, body) := aux names 0 body in
body.
Ltac2 constr_dropn (n : int) (k : int) (c : constr) : constr
:= let k := Int.sub k 1 in
let invalid := mkVar (ident:(__CONSTR_DROPN_INVALID)) in
debug_printf "dropping %i %i %t" n k c;
let res := Constr.Unsafe.substnl (List.repeat invalid n) k c in
debug_printf "dropped %i %i %t" n k res;
res.
Ltac2 Type exn ::= [ InternalEinsumNotEqual (constr, constr) ].
Ltac2 Type exn ::= [ InternalEinsumNotEnoughArgs (int, constr, constr array) ].
Ltac2 Type exn ::= [ InternalEinsumBadKind (Constr.Unsafe.kind) ].
Ltac2 rec remove_dead_einsum_helper (hd_c : constr) (nargs : int) (names : 'a list) (body : constr) : int * int list * ( int * constr)
:= match names with
| [] => (1, toplevel_rels body, (0, body))
| _ :: names
=> match Constr.Unsafe.kind_nocast body with
| Constr.Unsafe.App f args
=> (if Int.ge nargs (Array.length args)
then Control.throw (InternalEinsumNotEnoughArgs nargs f args)
else ());
(let first_args := List.firstn nargs (Array.to_list args) in
let fargs := mkApp f first_args in
if Bool.neg (Constr.equal fargs hd_c)
then Control.throw (InternalEinsumNotEqual fargs hd_c)
else ());
let lam_body_pos := Int.sub (Array.length args) 1 in
let lam_body := Array.get args lam_body_pos in
match Constr.Unsafe.kind_nocast lam_body with
| Constr.Unsafe.Lambda b body
=> let (cur_rel, used_rels, (accumulated_shift, body)) := remove_dead_einsum_helper hd_c nargs names body in
let (cur_rel_used, used_rels)
:= match used_rels with
| r :: rs => if Int.equal cur_rel r
then (true, rs)
else (false, used_rels)
| [] => (false, used_rels)
end in
(Int.add cur_rel 1,
used_rels,
if cur_rel_used
then
let body := constr_dropn (Int.neg accumulated_shift) 1 body in
let lam_body := mkLambda b body in
Array.set args lam_body_pos lam_body;
(0, Constr.Unsafe.make (Constr.Unsafe.App f args))
else
(Int.sub accumulated_shift 1,
body))
| k => Control.throw (InternalEinsumBadKind k)
end
| k => Control.throw (InternalEinsumBadKind k)
end
end.
Ltac2 remove_dead_einsum (hd_c : constr) (nargs : int) (names : 'a list) (body : constr) : constr
:= debug_printf "remove dead from %t" body;
let (_cur_rel, _used_rels, (accumulated_shift, body)) := remove_dead_einsum_helper hd_c nargs names body in
constr_dropn (Int.neg accumulated_shift) 1 body.
Ltac2 insert_einsums (ty : constr) (names : ident option list) (body : constr) : constr
:= let start := '(0%uint63) in
let step := '(1%uint63) in
let sum := '(@Reduction.sum $ty try_tc try_tc) in
let sum_to stop body := mkApp sum [start; stop; step; body] in
let body := insert_all_einsums sum_to names body in
let n_sum_args := match Constr.Unsafe.kind_nocast sum with
| Constr.Unsafe.App _ args => Array.length args
| k => Control.throw (InternalEinsumBadKind k)
end in
let body := remove_dead_einsum sum n_sum_args names body in
body.
Ltac2 Type exn ::= [ EinsumExtraNames (Constr.Unsafe.kind, ident option list) ].
Ltac2 insert_all_einsums_below (ty : constr) (names : ident option list) (body : constr) : constr
:= let rec aux (cur_names : ident option list) (body : constr)
:= match cur_names with
| [] => insert_einsums ty names body
| n :: ns
=> match Constr.Unsafe.kind_nocast body with
| Constr.Unsafe.Lambda b body
=> mkLambda b (aux ns body)
| k => Control.throw (EinsumExtraNames k cur_names)
end
end in
aux names body.
Ltac2 make_einsum (shapes : constr list) (body : constr) : constr
:= let ty := Constr.type body in
let c := Std.eval_pattern (List.map (fun s => (s, Std.AllOccurrences)) shapes) body in
let names := List.map ident_of_constr shapes in
match Constr.Unsafe.kind_nocast c with
| Constr.Unsafe.App f shape_args
=> let f := insert_all_einsums_below ty names f in
let f := Constr.Unsafe.make (Constr.Unsafe.App f shape_args) in
let f := (eval cbv beta in f) in
f
| k
=> Control.throw (InternalEinsumBadKind k)
end.
Ltac subst_type_lets_in_goal _ :=
repeat match goal with
| [ H := [ _ ] : Type |- _ ] => match goal with |- context[H] => idtac end; subst H
end.
End Internals.
Local Notation try_tc := (ltac2:(ltac1:(try typeclasses eauto))) (only parsing).
Local Notation indirect_einsum tensor_value ishape jshape
:= ltac2:(let get_body v := get_body false (Constr.pretype v) in
let get_shape v := shape_to_list (get_body v) in
let t := get_body tensor_value in
let shapes := List.append (get_shape ishape) (get_shape jshape) in
let t := make_einsum shapes t in
exact $t)
(only parsing).
#[local] Notation "'unify_rank_from_idxs' r @ i1 .. i_"
:= ((uncurry_fun i1 .. i_ => I) : RawIndex r -> True)
(only parsing, i1 binder, i_ binder, at level 10).
#[export] Hint Extern 1 => progress subst_type_lets_in_goal () : typeclass_instances.
Declare Custom Entry einsum_args.
Notation "{{{ {{ i1 .. i_ , j1 .. j_ -> k1 .. k_ }} , t1 , t2 }}}"
:= (match t1%tensor, t2%tensor, _ as A, _ as B, _ as C, _ as r1, _ as r2, _ as r3, _ as s1, _ as s2, _ as s3 return @tensor _ s3 C with
| t1', t2', A, B, C, r1, r2, r3, s1, s2, s3
=> match t1' : @tensor r1 s1 A, t2' : @tensor r2 s2 B return @tensor r3 s3 C with
| t1', t2'
=> match
unify_rank_from_idxs r1 @ i1 .. i_,
unify_rank_from_idxs r2 @ j1 .. j_,
unify_rank_from_idxs r3 @ k1 .. k_
return @tensor r3 s3 C
with
| _, _, _
=> @with_shape
r1 s1 (tensor s3 C)
(λ i1 .. i_ ,
@with_shape
r2 s2 (tensor s3 C)
(λ j1 .. j_ ,
(match
(Shape.snoc .. (Shape.snoc Shape.nil i1) .. i_),
(Shape.snoc .. (Shape.snoc Shape.nil j1) .. j_)
return @tensor r3 s3 C
with
| __I_SHAPE, __J_SHAPE
=> @Tensor.uncurry
r3 s3 C
(λ k1 .. k_ ,
match @Arith.Classes.mul
A B C try_tc
(raw_get t1' (RawIndex.snoc .. (RawIndex.snoc RawIndex.nil i1) .. i_))
(raw_get t2' (RawIndex.snoc .. (RawIndex.snoc RawIndex.nil j1) .. j_))
return C
with
| __EINSUM_TENSOR_VALUE
=> indirect_einsum
__EINSUM_TENSOR_VALUE __I_SHAPE __J_SHAPE
end)
end)))
end
end
end)
(only parsing, in custom einsum_args at level 0, i1 binder, i_ binder, j1 binder, j_ binder, k1 binder, k_ binder, t1 constr at level 10, t2 constr at level 10).
Notation "'weaksauce_einsum' x"
:= (match x return _ with
| y => ltac2:(let y := get_body false &y in
let z := (eval cbv beta iota delta [Tensor.with_shape Shape.uncurry Tensor.uncurry RawIndex.uncurry Shape.uncurry_dep RawIndex.uncurry_dep Shape.uncurry_map_dep RawIndex.uncurry_map_dep] in y) in
let z := (eval cbn beta iota delta [Nat.radd] in z) in
let z := (eval cbn beta iota zeta delta [Shape.snoc Shape.nil] in z) in
exact $z)
end)
(x custom einsum_args at level 10, at level 10, only parsing).
End Einsum.
Import NeuralNetInterp.Util.PArray.
Section __.
Context {r} {batch : Shape r}
{pos n_heads d_model d_head} {n_ctx:N}
{use_split_qkv_input : with_default "use_split_qkv_input" bool false}
{A}
{zeroA : has_zero A} {coerZ : has_coer Z A}
{addA : has_add A} {subA : has_sub A} {mulA : has_mul A} {divA : has_div A}
{maxA : has_max A}
{sqrtA : has_sqrt A} {expA : has_exp A}
(defaultA : pointed A := @coer _ _ coerZ point)
{use_checkpoint : with_default "use_checkpoint" bool true}
(W_Q W_K W_V W_O : tensor [n_heads; d_model; d_head] A)
(b_Q b_K b_V : tensor [n_heads; d_head] A)
(b_O : tensor [d_model] A)
(IGNORE : A := coerZ (-1 * 10 ^ 5)%Z)
(attn_scale : A := √(coer (Uint63.to_Z d_head)))
(maybe_n_heads := fun b : bool => (if b return Shape (if b then _ else _) then [n_heads] else [])%shape)
(query_input key_input value_input : tensor ((batch ::' pos) ++' (maybe_n_heads use_split_qkv_input ::' d_model)) A)
(mask : tensor [n_ctx; n_ctx] bool := to_bool (tril (A:=bool) (ones [n_ctx; n_ctx]))).
Definition einsum_input
(input : tensor ((batch ::' pos) ++' (maybe_n_heads use_split_qkv_input ::' d_model)) A)
(W : tensor [n_heads; d_model; d_head] A)
: tensor ((batch ::' pos) ++' [n_heads; d_head]) A.
exact (Tensor.map'
(if use_split_qkv_input return tensor (maybe_n_heads use_split_qkv_input ::' d_model) A -> tensor [n_heads; d_head] A
then fun input => weaksauce_einsum {{{ {{ head_index d_model,
head_index d_model d_head
-> head_index d_head }}
, input
, W }}}
else fun input => weaksauce_einsum {{{ {{ d_model,
head_index d_model d_head
-> head_index d_head }}
, input
, W }}})
input). Intermediate Coq File (useful for debugging if minimization did not go as far as you wanted) (truncated to 8.0KiB; full 42KiB file on GitHub Actions Artifacts under
|
It seems the exact in Notation "'weaksauce_einsum' x"
:= (match x return _ with
| y => ltac2:(let y := get_body false &y in
let z := (eval cbv beta iota delta [Tensor.with_shape Shape.uncurry Tensor.uncurry RawIndex.uncurry Shape.uncurry_dep RawIndex.uncurry_dep Shape.uncurry_map_dep RawIndex.uncurry_map_dep] in y) in
let z := (eval cbn beta iota delta [Nat.radd] in z) in
let z := (eval cbn beta iota zeta delta [Shape.snoc Shape.nil] in z) in
exact $z)
end)
(x custom einsum_args at level 10, at level 10, only parsing). produces an evar when unifying the type with the goal. With the master implementation of exact With this PR it is rejected (we do IDK which behaviour should be considered correct. |
cc @JasonGross who wrote the code leading to the example |
I'm a bit confused, can you explain what's going on a bit more? My current understanding:
Ultimately, I'm happy to replace the |
The goal is an evar and z contains an evar.
Not sure exactly but I guess so
yes
there are new evars (?)
It does the same pretype call |
d132302
to
4cf4ca4
Compare
The job library:ci-fiat_crypto_legacy has failed in allow failure mode |
4cf4ca4
to
f5e20e0
Compare
user-contrib/Ltac2/Constr.v
Outdated
(** Allows new unsolved evars. *) | ||
End Flags. | ||
|
||
Ltac2 Type expected_type := [ IsType | OfType (constr) | WithoutTypeConstraint ]. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think it's a good idea to expose this type as an ADT. For future-proofness it should come with opaque constructors.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What change are you expecting?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
An opaque type where the constructors are turned into externals.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I mean what future are you proofing aganst.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't know about the future, that's the point. As a rule of thumb, I think that Ltac2 should not transparently expose type implementations bound to the OCaml ones, except maybe low-level stuff like constr views.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed
f5e20e0
to
a4dab25
Compare
Fix coq#12827 The implementation uses a generalization of Constr.pretype which takes flags (an opaque type) and a typing constraint. Changing `refine` is left to the future as the notation takes a tactic thunk at constr type so would be backwards incompatible.
a4dab25
to
3451f1f
Compare
@coqbot run full ci |
@coqbot run full ci |
@coqbot merge now |
Fix #12827
The implementation uses a generalization of Constr.pretype which takes flags (an opaque type) and a typing constraint.
Changing
refine
is left to the future as the notation takes a tactic thunk at constr type so would be backwards incompatible.Overlays: