-
Notifications
You must be signed in to change notification settings - Fork 4
/
evaluation.lean
1246 lines (1060 loc) · 48.1 KB
/
evaluation.lean
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/-
Copyright (c) 2020 Jason Rute. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author(s): Jason Rute and Jesse Michael Han
REPL for interacting with Lean
-/
import data.nat.basic
import system.io
import tactic.core
import tactic
import utils
import .tactic_state
import basic.queue
section tactic_state
open interaction_monad.result
setup_tactic_parser
/- set this to tt to enable tracing
`tactic.set_bool_option` is tactic state specific,
and is difficult to globally set -/
/-- Runs `tactic.num_goals` against the tactic state `ts` and returns the
resulting count, or `none` when the tactic raises an exception. -/
meta def num_goals' : tactic_state → option ℕ :=
λ ts, match tactic.num_goals ts with
| (exception _ _ _) := none
| (success n _) := some n
end
-- TODO(jesse): this is a hack. might be better to do this in python
/-- Runs the parser `p` on the input string `inp` and returns the second
component of the `with_input` result (the string left over after parsing). -/
meta def consume_with_parser {α} (p : lean.parser α) : string → io string := λ inp,
io.run_tactic' (do {
  parsed ← lean.parser.run_with_input (with_input p inp) "",
  pure parsed.snd })
-- TODO(jesse): performance
/-- Drops every leading `' '` character from a string; the string is
returned unchanged as soon as its first character is not a space. -/
meta def consume_spaces : string → string
| s@⟨[]⟩ := s
| s@⟨(c::cs)⟩ := if c ≠ ' ' then s else consume_spaces ⟨cs⟩
-- WARNING: this is a hack
/-- Splits `str` on the separator character `c`, strips the leading spaces
from every piece with `consume_spaces`, and joins the pieces back together
using `c` as the separator.

`c` defaults to a tab. Previously the argument was accepted but ignored
(the tab was hard-coded in both the split and the join); it is now honoured,
which leaves the behaviour for the default/tab case unchanged. -/
meta def remove_indents_with_split (c : char := '\t'): string → string := λ str,
let pieces := str.split (λ ch, ch = c) in
string.intercalate (⟨[c]⟩ : string) (consume_spaces <$> pieces)
/-- Pretty-prints `ts` on a single line: newlines are replaced by tabs and the
leading spaces of every tab-separated piece are removed. When `ts` has more
than one goal, a leading `small_nat` and `ident` are first consumed from the
printed state. Traces a warning and fails if any of these steps fails. -/
meta def postprocess_tactic_state : tactic_state → tactic string := λ ts, do
let tabify : string → string := λ s, string.replace_char s '\n' '\t',
let render : tactic string := do {
  let raw := ts.to_format.to_string,
  flat ← (if (num_goals' ts).get_or_else 0 ≤ 1
    then pure (tabify raw)
    else tactic.unsafe_run_io $
      tabify <$> (consume_with_parser small_nat >=> consume_with_parser ident) raw),
  pure (remove_indents_with_split '\t' flat)
},
let failure_msg := "[postprocess_tactic_state] WARNING: POSTPROCESSING FAILED",
render <|> (tactic.trace failure_msg *> tactic.fail failure_msg)
end tactic_state
/-- Replaces the current environment with the result of `execute_open nm`
applied to it. -/
meta def add_open_namespace : name → tactic unit := λ nm,
tactic.get_env >>= λ env, tactic.set_env (env.execute_open nm)
/-- Applies `add_open_namespace` to each name in `nms`, in order. -/
meta def add_open_namespaces (nms : list name) : tactic unit :=
list.mmap' add_open_namespace nms
section run_with_state'
namespace interaction_monad
open interaction_monad.result
/-- Runs `tac` on the fixed state `state` (of type `σ₁`) from inside an
`interaction_monad σ₂`. The value or exception produced by `tac` is kept,
the state produced by `tac` is discarded, and the ambient `σ₂` state is
passed through unchanged. -/
meta def run_with_state' {σ₁ σ₂ : Type} {α : Type*} (state : σ₁) (tac : interaction_monad σ₁ α) : interaction_monad σ₂ α :=
λ outer, match tac state with
| (exception msg pos _) := exception msg pos outer
| (success a _) := success a outer
end
end interaction_monad
end run_with_state'
namespace tactic
open interaction_monad.result
/-- Runs `tac` on the current state and returns its raw result. On success the
new state is installed; on exception the current state is left untouched. -/
meta def run (tac : tactic unit) : tactic (interaction_monad.result tactic_state unit) := do
σ ← get_state,
let outcome := tac σ,
match outcome with
| (success _ σ') := interaction_monad.set_state σ' *> pure outcome
| (exception _ _ _) := pure outcome
end
-- meta instance has_format_result {α σ} [has_to_format σ] [has_to_format α] : has_to_format (interaction_monad.result σ α) := ⟨by mk_to_format `interaction_monad.result⟩ -- ayyy
/-- Formats a tactic result: value and new state on success; message,
position and old state on exception ("n/a" when there is no message). -/
meta instance has_to_format_tactic_result {α : Type*} [has_to_format α] : has_to_format (interaction_monad.result tactic_state α) :=
⟨λ outcome,
match outcome with
| (exception thunk pos old_state) :=
  let msg := (thunk.get_or_else (λ _, format.of_string "n/a")) () in
  format!"EXCEPTION!\nMSG: {msg}\nPOS: {pos}\nOLD_STATE: {old_state}"
| (success val new_state) := format!"SUCCESS!\nNEW_STATE: {new_state}\nVAL: {val}"
end⟩
/-- Lifts the `has_to_format` instance for tactic results into `tactic`. -/
meta instance has_to_tactic_format_tactic_result {α : Type*} [has_to_format α] : has_to_tactic_format (interaction_monad.result tactic_state α) :=
⟨λ r, return (has_to_format.to_format r)⟩
end tactic
section parse_eval_tac
setup_tactic_parser
open tactic
/-- Wraps `tactic_string` in braces, parses it as a reflected interactive
tactic using the parser state `ps`, and returns the evaluated tactic together
with the formatted expression it was parsed to. -/
meta def parse_eval_tac
(ps : parser_state)
(tactic_string : string)
: tactic (tactic unit × format) := do
let wrapped := "{" ++ tactic_string ++ "}",
parsed ← interaction_monad.run_with_state' ps (with_input parser.itactic_reflected wrapped),
let texpr := reflected_value.expr parsed.fst,
tac ← eval_expr (tactic unit) texpr,
fmt ← has_to_tactic_format.to_tactic_format texpr,
pure (tac, fmt)
end parse_eval_tac
section frontend
open tactic lean lean.parser interactive
/-- Interactive loop: traces the tactic state, then either finishes (when
`done` succeeds) or reads one tactic string from standard input, parses it
with `parse_eval_tac`, runs it, traces the result and repeats. -/
meta def read_eval_print_loop (ps : parser_state) : tactic unit :=
do
trace "\nTactic state:",
trace_state,
let prompt_and_run : tactic unit := do {
  trace "\nEnter a tactic command:",
  line ← tactic.unsafe_run_io io.get_line,
  parsed ← parse_eval_tac ps line,
  trace "",
  trace ("Running tactic:\n" ++ parsed.snd.to_string),
  -- runs the tactic on the goal. It is crashing
  outcome ← tactic.run parsed.fst,
  eval_trace outcome,
  --- loops forever
  read_eval_print_loop },
done <|> prompt_and_run
--- like main_t, but sets the goal and environment to a user-supplied theorem in mathlib
--- note: if the environment contains all declarations in mathlib,
--- and olean files exist,
--- we don't need to supply the lean file as env.decl_olean will find it automatically.
/-- Prompts for a declaration name on standard input, sets the goal to that
declaration's type, switches the environment to
`environment.for_decl_of_imported_module` for it, and enters
`read_eval_print_loop`. -/
meta def main_t_at : parser_state → tactic unit := λ ps, do
trace "enter declaration to prove",
line ← tactic.unsafe_run_io io.get_line,
⟨nm, _⟩ ← interaction_monad.run_with_state' ps (with_input ident line),
env ← get_env,
decl ← env.get nm,
set_goal_to decl.type,
olean ← env.decl_olean nm,
set_env_core (environment.for_decl_of_imported_module olean nm),
read_eval_print_loop ps
/-- User command `main_app_at`: runs `main_t_at` on the current parser state. -/
@[user_command]
meta def main_app_at
(meta_info : decl_meta_info) (_ : parse (tk "main_app_at")) : lean.parser unit :=
get_state >>= λ ps, of_tactic (main_t_at ps)
end frontend
section main
setup_tactic_parser
open tactic
/-- Top-level interactive loop: prompts for a goal on standard input, parses
it as a term with the parser state `ps`, installs it as the goal, runs
`read_eval_print_loop` on it, and then starts over with a new goal. -/
meta def main_t : parser_state → tactic unit := λ ps,
do
trace "Enter a goal:",
-- start from a trivial goal before the user-supplied one is parsed
set_goal_to `(true),
goal_string <- tactic.unsafe_run_io $ io.get_line,
trace "GOAL STRING: " *> trace goal_string,
(goal_pexpr, _) ← interaction_monad.run_with_state' ps $ with_input types.texpr goal_string,
eval_trace ps.cur_pos,
-- NOTE(review): `<e>` is not defined in this file; presumably notation from an
-- import, or the line was garbled during extraction — confirm against the repository.
set_goal_to <e> goal_pexpr,
(read_eval_print_loop ps) *> main_t ps -- loops forever
/-- User command `main_app`: runs `main_t` on the current parser state. -/
@[user_command]
meta def main_app
(meta_info : decl_meta_info) (_ : parse (tk "main_app")) : lean.parser unit :=
get_state >>= λ ps, of_tactic (main_t ps)
end main
section parse_tac
setup_tactic_parser
open tactic
/-- Parses `tactic_string`, wrapped in `{ ... }`, as a reflected interactive
tactic and returns the underlying expression.
Use `eval_expr (tactic unit)` on the result to obtain a runnable
`tactic unit`. -/
meta def parse_itactic_reflected (tactic_string : string) : tactic expr :=
lean.parser.run $ do
ps ← get_state,
of_tactic $ do
-- make the tactic environment agree with the parser's environment
tactic.set_env ps.env,
-- eval_trace format!"[parse_itactic_reflected] TRYING TO PARSE {itactic_string}",
parsed ← @interaction_monad.run_with_state' parser_state _ _ ps
  (with_input parser.itactic_reflected ("{ " ++ tactic_string ++ " }")),
pure (reflected_value.expr parsed.fst)
/-- Parses `tactic_string` as an interactive tactic and evaluates the
reflected expression to a runnable `tactic unit`. -/
meta def parse_itactic (tactic_string : string) : tactic (tactic unit) :=
parse_itactic_reflected tactic_string >>= eval_expr (tactic unit)
end parse_tac
section evaluation_harness
/-- Installs `ts` as the current tactic state and returns the raw result of
applying `tac` to `ts` (the result itself is returned, not installed). -/
meta def run_tac_with_tactic_state
(tac : tactic unit)
(ts : tactic_state)
: tactic (result _ _) := do
tactic.write ts,
let outcome := tac ts,
pure outcome
/-- Input consumed by `run_evaluation`. -/
meta structure EvaluationInput : Type :=
-- declaration name; passed to `get_env_at_decl` to choose the environment
(decl_nm : name)
-- serialized tactic state; rebuilt with `rebuild_tactic_state`
(ts_data : tactic_state_data)
-- tactic to run, as source text; parsed with `parse_itactic`
(tactic_string : string)
-- namespaces opened (via `add_open_namespaces`) before the tactic is run
(open_namespaces : list name)
/- (before_state, action, after_state) tuple -/
/-- Output of `run_evaluation`: the state the tactic was run from, the tactic
text, and the raw result of running it. -/
meta structure EvaluationResult : Type :=
-- tactic state the tactic was applied to
(before_state : tactic_state)
-- source text of the tactic that was run
(tactic_string : string)
-- success (with new state) or exception produced by the tactic
(result : result tactic_state unit)
/-- Formats an `EvaluationResult` as three highlighted sections: input state,
action, and result. -/
meta instance : has_to_format EvaluationResult :=
⟨λ res,
  -- hmm. why doesn't the highlighting work?
  let pink : format → format := λ header, format.highlight header format.color.pink in
  format.join [
    pink "INPUT STATE:\n",
    res.before_state.to_format,
    "\n\n",
    pink "ACTION:\n",
    res.tactic_string,
    "\n\n",
    pink "RESULT:\n",
    has_to_format.to_format res.result
  ]⟩
/-- Produces a tactic state by running `exact trivial` inside
`io.run_tactic'` and reading the state that results. -/
meta def mk_tactic_state : tactic tactic_state :=
tactic.unsafe_run_io (io.run_tactic' (tactic.exact `(trivial) *> tactic.read))
/-- Builds the `tactic_state_data` one would have at the start of a proof of
the declaration `decl_nm` (currently only theorems are supported): a fresh
state is installed, the goal is set to the declaration's type, the data is
captured, and the fresh state is written back before returning. -/
meta def get_tsd_at_decl (decl_nm : name) : tactic tactic_state_data := do
env ← tactic.get_env,
decl ← env.get decl_nm,
mk_tactic_state >>= tactic.write,
saved ← tactic.read,
tactic.set_goal_to decl.type,
tsd ← tactic_state_data.get,
-- undo the goal change before handing the data back
tactic.write saved,
pure tsd
/-- Evaluates one `EvaluationInput` inside a nested `io.run_tactic'`: parses
the tactic, switches to the environment for the declaration, opens the
requested namespaces, rebuilds the tactic state, runs the tactic on it and
packages the outcome as an `EvaluationResult`. -/
meta def run_evaluation : EvaluationInput → tactic EvaluationResult :=
λ input, tactic.unsafe_run_io $ io.run_tactic' $ do
tac ← parse_itactic input.tactic_string,
env ← get_env_at_decl input.decl_nm,
tactic.set_env_core env,
add_open_namespaces input.open_namespaces,
rebuild_tactic_state input.ts_data,
ts ← tactic.read,
outcome ← run_tac_with_tactic_state tac ts,
pure (EvaluationResult.mk ts input.tactic_string outcome)
section greedy_proof_search
/-- Interface to a prediction model: a single `query` function that maps a
request of type `input_format` (`json` by default) to a `json` response in `io`. -/
meta structure ModelAPI (input_format : Type := json) : Type :=
(query : input_format → io json)
/- for testing: ignores the request and always answers with a fixed failure string -/
meta def dummy_api {α} : ModelAPI α :=
{ query := λ _, pure (json.of_string "[DummyAPI] FAILURE") }
namespace tactic
open interaction_monad interaction_monad.result
/- capture but backtrack the state: the raw result of `t` (success or
exception) is returned as a value, while the ambient state stays the one
`t` was started from -/
meta def capture' {α} (t : tactic α) : tactic (tactic_result α) :=
λ s₀, success (t s₀) s₀
end tactic
/-- Computes a number for the current tactic state: for every goal, the
hashes of the target and of each local-context entry are added up, and the
per-goal totals are summed. -/
meta def tactic_hash : tactic ℕ := do
goals ← tactic.get_goals,
per_goal ← goals.mmap (λ g, do {
  tactic.set_goal_to g,
  tgt ← tactic.target,
  ctx ← tactic.local_context,
  pure ((tgt :: ctx).foldl (λ acc e, acc + e.hash) 0) }),
pure per_goal.sum
/-- Returns `tt` when `tactic_hash` gives the same value for `ts₁` and `ts₂`.
The current tactic state is saved first and restored afterwards. -/
meta def compare_tactic_state : tactic_state → tactic_state → tactic bool := λ ts₁ ts₂, do
saved ← tactic.read,
tactic.write ts₁,
h₁ ← tactic_hash,
tactic.write ts₂,
h₂ ← tactic_hash,
tactic.write saved,
pure (h₁ = h₂)
-- let done_handler : state_t GreedyProofSearchState tactic bool := do {
-- state_t.lift $ do {
-- tactic.done,
-- tactic.result >>= (λ x, tactic.guard_sorry x <|>
-- eval_trace format! "[greedy_proof_search_step] WARNING: result contains sorry" *> tactic.failed),
-- tactic.result >>= (λ x, tactic.type_check x <|>
-- eval_trace format! "[greedy_proof_search_step] WARNING: result failed typechecking" *> tactic.failed)
-- },
/-- Succeeds when `pf` contains no metavariables, passes `tactic.guard_sorry`
and type-checks, all run from a fresh `mk_tactic_state`. The checks run under
`tactic.capture'`, so the caller's state is unaffected; any failure is
reported with a fixed error message. -/
meta def validate_proof (pf : expr) : tactic unit := do
let checks : tactic unit := do {
  mk_tactic_state >>= tactic.write,
  guard (bnot pf.has_meta_var),
  tactic.guard_sorry pf,
  tactic.type_check pf
},
outcome ← tactic.capture' checks,
match outcome with
| (interaction_monad.result.exception _ _ _) := tactic.fail "[validate_proof] ERROR: VALIDATION FAILED"
| (interaction_monad.result.success _ _) := pure ()
end
/-- Returns `tt` when `tactic.done` succeeds in `state`, and `ff` otherwise.
The caller's tactic state is saved first and restored afterwards. -/
meta def tactic_state.is_done (state : tactic_state) : tactic bool := do
saved ← tactic.read,
tactic.write state,
finished ← (tactic.done *> pure tt) <|> pure ff,
tactic.write saved,
pure finished
/-- An exception result is never done; a success result is done exactly when
its resulting state satisfies `tactic_state.is_done`. -/
meta def tactic_result.is_done {α} (tr : tactic_result α) : tactic bool :=
match tr with
| (interaction_monad.result.exception _ _ _) := pure ff
| (interaction_monad.result.success _ final_state) := final_state.is_done
end
/-- Parses `next_candidate` as a tactic and runs it under `tactic.capture'`
with the given `timeout` (and an inner `try_for 200000`). The environment is
restored after parsing. A successful run whose resulting state hashes the same
as the current state (see `compare_tactic_state`) is turned into a
"tactic state no-op" exception. A parse failure makes this tactic itself
fail. -/
meta def get_tac_and_capture_result (next_candidate : string) (timeout : ℕ := 5000) : tactic (tactic_result unit) := do
saved_env ← tactic.get_env,
eval_trace format!"[get_tac_and_capture_result] PARSING TACTIC: {next_candidate}",
tac ← parse_itactic next_candidate,
eval_trace format!"[get_tac_and_capture_result] PARSE SUCCESSFUL",
tactic.set_env saved_env,
eval_trace format!"[get_tac_and_capture_result] TRYING TACTIC: {next_candidate}",
-- if `tac` fails, exception is captured here
outcome ← tactic.capture' (tactic.try_for_time timeout $ tactic.try_for 200000 tac),
eval_trace format!"[get_tac_and_capture_result] RESULT: {outcome}",
/- use tactic state hashing to fail on no-ops modulo permutation -/
match outcome with
| (interaction_monad.result.success _ ts') := do {
    ts ← tactic.read,
    is_noop ← compare_tactic_state ts ts',
    pure (if is_noop
      then interaction_monad.mk_exception "tactic state no-op" none ts'
      else outcome)
  }
| other := pure other
end
/-- Like `get_tac_and_capture_result`, but never fails: if that tactic fails
(e.g. on a parse error), a message is traced and an exception result carrying
the message and the current state is returned instead. -/
meta def try_get_tac_and_capture_result (tac_string : string) (timeout : ℕ := 5000) : tactic (tactic_result unit) :=
let msg : format := format!"[try_get_tac_and_capture_result] parse_itactic failed on {tac_string}" in
get_tac_and_capture_result tac_string timeout <|> do {
  eval_trace msg,
  ts ← tactic.read,
  pure (interaction_monad.mk_exception msg none ts)
}
-- caller of this function is responsible for inspecting results and stopping the search
/-- Obtains scored candidate tactics for `msg` from `get_candidates`,
deduplicates them with `list.dedup'`, drops every candidate whose string
starts with "tidy", and tries the rest in order with
`try_get_tac_and_capture_result` (per-candidate timeout `tac_timeout`).

The loop is driven by `iterate_until` (defined elsewhere), with
`candidates.length` passed as its fuel argument; `stop` asks it to halt once
a result `is_done`.

Returns the successful results (each with its tactic string and score)
together with the strings of all candidates that were considered. -/
meta def run_all_beam_candidates
(get_candidates : json → tactic (list (string × native.float)))
(msg : json)
(tac_timeout : ℕ := 5000)
: tactic (list (tactic_result unit × string × native.float) × list string) := do {
-- loop state: (candidates still to try, successful results collected so far)
let try_candidate_state := (list (string × native.float) × (list $ option $ tactic_result unit × string × native.float)),
-- stop condition: the latest result finished the proof
let stop : option (tactic_result unit × string × native.float) → state_t try_candidate_state tactic bool :=
λ arg, match arg with
| some ⟨result, candidate⟩ := do {
state_t.lift result.is_done
}
| none := pure ff
end,
/- TODO(jesse): get rid of this state_t and just use `run_async <$> ...` instead -/
-- one loop step: pop the next candidate, run it, and record it if it succeeded
let try_candidate : state_t try_candidate_state tactic (option $ tactic_result unit × string × native.float) := do {
state_t.lift $ eval_trace format!"[try_candidate] ENTERING",
ts ← state_t.lift tactic.read,
state_t.lift $ eval_trace format!"[try_candidate] READ TACTIC STATE",
⟨rest, _⟩ ← state_t.get,
match rest with
| [] := do {
state_t.lift $ eval_trace format!"[try_candidate] END OF LOOP",
-- nothing left to try: report a failure result built on the current state
pure $ some ⟨interaction_monad.fail "all candidates failed" ts, "FAILURE", 0.0⟩
}
| (next_candidate::candidates) := do {
state_t.modify (λ ⟨_, rs⟩, ⟨candidates, rs⟩),
result ← monad_lift $ try_get_tac_and_capture_result next_candidate.fst tac_timeout,
-- only successful results are appended to the collected list
when (interaction_monad.result.is_success $ result) $
state_t.modify $ λ ⟨candidates, rs⟩, ⟨candidates, rs ++ [some $ ⟨result, next_candidate⟩]⟩,
state_t.lift $ eval_trace format!"[try_candidate] CAPTURED RESULT: {result}",
pure $ some ⟨result, next_candidate⟩
}
end
},
-- let find_successful_candidates
-- (candidates : list (string × native.float))
-- : tactic (list (tactic_result unit × string × native.float)) := do {
-- tasks ← candidates.mmap (λ arg, flip prod.mk arg <$> tactic.run_async (try_get_tac_and_capture_result arg.fst : tactic $ tactic_result unit)),
-- tactic.using_new_ref ff $ λ flag, do
-- tasks.iterM [] $ λ acc ⟨task, tac_string, score⟩, do {
-- mcond (tactic.read_ref flag) (pure acc) $ do {
-- let result := task.get,
-- if (interaction_monad.result.is_success result) then do {
-- whenM (result.is_done) $ tactic.write_ref flag tt,
-- pure $ acc ++ [⟨result, tac_string, score⟩]
-- } else do {
-- pure acc
-- }
-- }
-- }
-- },
-- this is responsible for gracefully handling "error" JSON messages and should return an empty list of candidates
unwrapped_candidates ← get_candidates msg,
-- eval_trace format!"[run_all_beam_candidates] UNWRAPPED CANDIDATES: {unwrapped_candidates}",
dedup_unwrapped_candidates ← list.dedup' unwrapped_candidates,
eval_trace format!"[run_all_beam_candidates] DEDUP_UNWRAPPED CANDIDATES: {dedup_unwrapped_candidates}",
-- candidates whose string starts with "tidy" are never tried
let candidates := list.filter (λ x, ¬ "tidy".is_prefix_of (prod.fst x)) dedup_unwrapped_candidates,
eval_trace format!"[run_all_beam_candidates] CANDIDATES: {candidates}",
-- old failure callback
-- let failure_callback := do {
-- -- do ts ← state_t.lift tactic.read, pure ⟨interaction_monad.fail "all candidates failed" ts, "FAILURE", 0.0⟩
-- }
--
-- run the loop and keep only the collected-results component of the final loop state
successful_candidates ← (prod.snd <$> prod.snd <$> state_t.run (iterate_until try_candidate stop candidates.length $ pure none) ⟨candidates, []⟩),
-- successful_candidates ← find_successful_candidates candidates,
eval_trace format!"[run_all_beam_candidates] EXITING TRY_CANDIDATE LOOP",
eval_trace format!"[run_all_beam_candidates] SUCCESSFUL CANDIDATES: {successful_candidates}",
pure ⟨successful_candidates.filter_map id, prod.fst <$> candidates⟩
}
/-- Obtains candidate tactic strings for `msg` from `get_candidates`,
deduplicates them with `list.dedup`, drops those starting with "tidy", and
tries them in order with `try_get_tac_and_capture_result` (per-candidate
timeout `tac_timeout`). `stop` asks `iterate_until` (defined elsewhere) to
halt at the first successful result.

Returns the chosen result, the tactic string that produced it, and the list
of candidates that were considered. If the loop itself fails, an
"all candidates failed" exception result on the current state is returned. -/
meta def run_best_beam_candidate
(get_candidates : json → tactic (list string))
(msg : json)
(tac_timeout : ℕ := 5000)
: tactic (tactic_result unit × string × list string) := do {
-- one loop step; the state is the list of candidates still to try
let try_candidates : state_t (list string) tactic ((tactic_result unit) × string) := do {
ts ← state_t.lift (tactic.read),
-- when the list is empty, report a failure result built on the current state
(next_candidate::candidates) ← state_t.get | pure (interaction_monad.fail "all candidates failed" ts, "FAILURE"),
state_t.modify (λ _, candidates),
flip prod.mk next_candidate <$> (state_t.lift $ try_get_tac_and_capture_result next_candidate tac_timeout)
-- let get_tac_and_capture_result : state_t (list string) tactic ((tactic_result unit) × string) := do {
-- tac ← state_t.lift $ do {
-- env ← tactic.get_env,
-- eval_trace format!"[run_best_beam_candidate] PARSING TACTIC: {next_candidate}",
-- tac ← parse_itactic next_candidate, -- this is the only possible point of program failure
-- eval_trace format!"[run_best_beam_candidate] PARSE SUCCESSFUL",
-- tactic.set_env env,
-- pure tac
-- },
-- state_t.lift $ do
-- eval_trace format!"[run_best_beam_candidate] TRYING TACTIC: {next_candidate}",
-- result ← tactic.capture' (tactic.try_for 200000 tac), -- if `tac` fails, exception is captured here
-- eval_trace format!"[run_best_beam_candidate] RESULT: {result}",
-- result ← match result with
-- | (interaction_monad.result.success val ts') := do {
-- ts ← tactic.read,
-- -- ite ((format.to_string $ has_to_format.to_format ts) = (format.to_string $ has_to_format.to_format ts'))
-- mcond (compare_tactic_state ts ts')
-- (pure $ interaction_monad.mk_exception "tactic state no-op" none ts')
-- (pure result)
-- }
-- | exc := pure exc
-- end,
-- pure ⟨result, next_candidate⟩
-- },
-- get_tac_and_capture_result <|> -- handle `parse_itactic` failure here
-- let msg : format := format!"[run_best_beam_candidate.try_candidates] parse_itactic failed on {next_candidate}" in do
-- state_t.lift (eval_trace msg),
-- (pure ⟨interaction_monad.mk_exception
-- msg none ts, next_candidate⟩)
},
-- stop condition: the latest result is a success
let stop : tactic_result unit × string → state_t (list string) tactic bool :=
pure ∘ interaction_monad.result.is_success ∘ prod.fst,
-- candidates starting with "tidy" are never tried
candidates ← list.filter (λ x, ¬ "tidy".is_prefix_of x) <$> (get_candidates msg >>= list.dedup),
eval_trace format!"[run_best_beam_candidate] CANDIDATES: {candidates}",
-- fall back to an exception result on the current state if the loop itself fails
try_candidates_result@⟨result, tac_string⟩ ← ((prod.fst <$> state_t.run (iterate_until try_candidates stop) candidates) <|>
do {
old_state ← tactic.read,
pure (prod.mk (interaction_monad.mk_exception "all candidates failed" none old_state) "all failed")
}),
eval_trace format!"[run_best_beam_candidate] TRY_CANDIDATES_RESULT: {try_candidates_result}",
pure ⟨result, tac_string, candidates⟩
}
-- TODO(jesse): finalize JSON format for serialize_ts/decode_response protocol
/-- Bookkeeping record threaded through the greedy proof search
(`greedy_proof_search_step` / `greedy_proof_search_core`). -/
meta structure GreedyProofSearchState : Type :=
-- incremented by `bump_depth` after each successful step
(depth : ℕ := 0)
-- tactic strings applied so far (appended by `add_tac`)
(tactics : list string := [])
-- candidate lists seen at each step (appended by `add_prediction`)
(predictions : list (list string) := [])
(states : list tactic_state := []) -- TODO(jesse): this might make the logs extremely verbose
-- set to `tt` by the step's done handler
(success : bool := ff)
-- set by `mark_all_failed` when a step returns no successful result
(all_failed : bool := ff)
(task_id : string := "")
(global_timeout : bool := ff)
-- multiplied by 1000 in `greedy_proof_search_step` before being passed to `decode_response`
(tac_timeout : ℕ := 5)
-- set through `has_mark_fuel_exhausted` by the fuel-exhausted callback of the core loop
(fuel_exhausted : bool := ff)
(decl_goal : string := "")
-- derive a `has_to_format` instance for tracing
attribute [derive has_to_format] GreedyProofSearchState
-- Instances of the search-state type classes (declared elsewhere) for
-- `GreedyProofSearchState`; each one reads or updates a single field.
meta instance : has_mark_global_timeout GreedyProofSearchState :=
⟨λ s, { s with global_timeout := tt }⟩
meta instance : has_register_task_id GreedyProofSearchState :=
⟨λ s task, { s with task_id := task }⟩
meta instance : has_set_tac_timeout GreedyProofSearchState :=
⟨λ s timeout, { s with tac_timeout := timeout }⟩
meta instance : has_get_tac_timeout GreedyProofSearchState :=
⟨λ s, s.tac_timeout⟩
meta instance : has_mark_fuel_exhausted GreedyProofSearchState :=
⟨λ s, { s with fuel_exhausted := tt }⟩
meta instance : has_register_decl_goal GreedyProofSearchState :=
⟨λ s goal, { s with decl_goal := goal }⟩
/-- Encodes a list of strings as a JSON array of JSON strings. -/
meta def serialize_list_string : list string → json := λ xs,
json.array (list.map json.of_string xs)
-- we supply our own instance since the default derive handler is too verbose
/-- JSON serialization of the greedy search state. Tactic states are stored
as the strings produced by `postprocess_tactic_state`. -/
meta instance : has_to_tactic_json GreedyProofSearchState :=
⟨λ σ, do {
  serialized_states ← σ.states.mmap postprocess_tactic_state,
  pure $ json.object [
    ("depth", json.of_int σ.depth),
    ("tactics", serialize_list_string σ.tactics),
    ("predictions", json.array $ serialize_list_string <$> σ.predictions),
    -- store only pretty-printed tactic states for now, with newlines replaced by tabs
    ("states", json.array $ json.of_string <$> serialized_states),
    ("success", json.of_bool σ.success),
    ("all_failed", json.of_bool σ.all_failed),
    ("task_id", json.of_string σ.task_id),
    ("global_timeout", json.of_bool σ.global_timeout),
    ("fuel_exhausted", json.of_bool σ.fuel_exhausted),
    ("decl_goal", json.of_string σ.decl_goal)
  ]
}⟩
namespace GreedyProofSearchState
/-- Appends `tac` to the list of applied tactics. -/
meta def add_tac (σ : GreedyProofSearchState) (tac : string) : GreedyProofSearchState :=
{ σ with tactics := σ.tactics ++ [tac] }
/-- Increments the depth counter by one. -/
meta def bump_depth (σ : GreedyProofSearchState) : GreedyProofSearchState :=
{ σ with depth := nat.succ σ.depth }
/-- Appends one list of candidate strings to the recorded predictions. -/
meta def add_prediction (σ : GreedyProofSearchState) (pred : list string) : GreedyProofSearchState :=
{ σ with predictions := σ.predictions ++ [pred] }
/-- Appends a tactic state to the recorded states. -/
meta def add_state (σ : GreedyProofSearchState) (ts : tactic_state) : GreedyProofSearchState :=
{ σ with states := σ.states ++ [ts] }
/-- Sets the `all_failed` flag. -/
meta def mark_all_failed (σ : GreedyProofSearchState) : GreedyProofSearchState :=
{ σ with all_failed := tt }
end GreedyProofSearchState
/-- One step of greedy proof search.

Serializes the current tactic state with `serialize_ts`, queries `api`,
decodes the response with `decode_response` (given the per-tactic timeout
from the search state multiplied by 1000), records the outcome in the
`GreedyProofSearchState`, and resumes from the decoded result when it is a
success.

Returns a flag together with the updated search state. The flag is `tt` when
the proof is finished (`tactic.done` succeeds and the result passes
`guard_sorry` and `type_check`) or when the decoded result was not a
success. -/
meta def greedy_proof_search_step
{input_format : Type}
(api : ModelAPI input_format)
(serialize_ts : tactic_state → tactic input_format)
-- note(jesse): this is responsible for e.g. selecting a candidate from the beam
(decode_response : json → ℕ → tactic (tactic_result unit × string × list string))
: state_t GreedyProofSearchState tactic (bool × GreedyProofSearchState) := do {
state_t.lift $ eval_trace "[greedy_proof_search_step] ENTERING",
-- records the outcome of one decoded response in the search state
let handle_response (response : tactic_result unit × string × list string) : state_t GreedyProofSearchState tactic unit :=
match response with | response@⟨result, tac_string, candidates⟩ := do {
if not result.is_success
then do {
-- let msg := format!"[greedy_proof_search_step.handle_response] UNEXPECTED
-- interaction_monad.result.exception: {result}",
-- eval_trace msg *> tactic.fail msg
-- failure: flag it and log the current state and the candidates
modify $ λ σ, σ.mark_all_failed,
old_state ← state_t.lift $ tactic.read,
modify $ λ σ, σ.add_state old_state,
modify $ λ σ, σ.add_prediction candidates
}
else do {
(interaction_monad.result.success _ ts) ← pure result,
-- state_t.lift $ tactic.unsafe_run_io $ io.run_tactic' $ _
-- success: bump the depth and log the tactic, the candidates and the current state
modify $ λ σ, σ.bump_depth,
modify $ λ σ, σ.add_tac tac_string,
modify $ λ σ, σ.add_prediction candidates,
old_state ← state_t.lift $ tactic.read,
modify $ λ σ, σ.add_state old_state
}
}
end,
let get_response_and_resume : state_t GreedyProofSearchState tactic bool := do {
tac_timeout_seconds ← GreedyProofSearchState.tac_timeout <$> get,
response@⟨result, tac_string, candidates⟩ ← state_t.lift $ do {
ts_msg ← tactic.read >>= serialize_ts,
eval_trace "[greedy_proof_search_step] QUERYING API",
response_msg ← tactic.unsafe_run_io $ api.query ts_msg,
eval_trace format!"[greedy_proof_search_step] RESPONSE MSG {response_msg}",
eval_trace format!"[greedy_proof_search_step] RUNNING DECODE RESPONSE WITH TIMEOUT {tac_timeout_seconds*1000}",
response ← decode_response response_msg $ tac_timeout_seconds*1000,
pure response
},
state_t.lift $ eval_trace "[greedy_proof_search_step] HANDLING RESPONSE",
handle_response response,
state_t.lift $ eval_trace "[greedy_proof_search_step] HANDLED RESPONSE",
state_t.lift $ do {
eval_trace format!"DECODED RESPONSE: {result}",
-- continue from the decoded result only when it is a success
when result.is_success $ tactic.resume result,
eval_trace format!"RESUMED"
},
-- yields `tt` (and marks success) only if no goals remain and the result
-- passes both the sorry check and type checking; otherwise `ff`
let done_handler : state_t GreedyProofSearchState tactic bool := do {
state_t.lift $ do {
tactic.done,
tactic.result >>= (λ x, tactic.guard_sorry x <|>
eval_trace format! "[greedy_proof_search_step] WARNING: result contains sorry" *> tactic.failed),
tactic.result >>= (λ x, tactic.type_check x <|>
eval_trace format! "[greedy_proof_search_step] WARNING: result failed typechecking" *> tactic.failed)
},
modify $ λ σ, {success := tt, ..σ},
pure tt
} <|> pure ff,
-- yields `tt` when the decoded result was not a success
let tactic_exception_handler : state_t GreedyProofSearchState tactic bool := do {
pure (bnot result.is_success)
},
done_flag ← done_handler,
exception_flag ← tactic_exception_handler,
pure (done_flag || exception_flag)
},
done_flag ← get_response_and_resume,
-- let done_handler : state_t GreedyProofSearchState tactic bool := do {
-- state_t.lift done,
-- modify $ λ σ, {success := tt, ..σ},
-- pure tt
-- },
-- done_flag ← done_handler <|> pure ff,
state_t.lift $ eval_trace format! "DONE FLAG: {done_flag}",
-- pure ⟨done_flag⟩
prod.mk done_flag <$> get /- in second branch, no predictions succeeded and the proof search halts early -/
}
/- TODO(jesse): record proof search state + search statistics; wrap this in a state_t -/
/- `decode_response` should not modify the tactic state and currently
should only return successful `tactic_result`s -/
/-
in the full BFS case, `decode_response` should return a list of
`tactic_result unit`s, which may be `success` or `exception`;
these should then be logged by the `handle_response` function in
`greedy_proof_search_step`, and the successful applications
should be added to the search queue.
-/
/-- Repeats `greedy_proof_search_step` through `iterate_until` until a step
returns `tt` as its first component, passing `fuel` as the fuel argument.
The fuel-exhausted callback marks the search state via
`has_mark_fuel_exhausted` and yields `ff` together with the state. The final
step value is discarded. -/
meta def greedy_proof_search_core
{input_format : Type}
(api : ModelAPI input_format)
(serialize_ts : tactic_state → tactic input_format)
(decode_response : json → ℕ → tactic (tactic_result unit × string × list string))
(fuel := 100000)
: state_t GreedyProofSearchState tactic unit :=
let on_fuel_exhausted : state_t GreedyProofSearchState tactic (bool × GreedyProofSearchState) := do {
  state_t.modify has_mark_fuel_exhausted.mark_fuel_exhausted,
  σ ← get,
  pure (prod.mk ff σ)
} in
(iterate_until
  (greedy_proof_search_step api serialize_ts decode_response)
  (λ step, pure $ step.fst = tt)
  fuel
  on_fuel_exhausted) *> pure ()
/-- Runs `greedy_proof_search_core` from a default-initialized
`GreedyProofSearchState` and, when `verbose` is set, traces the final search
state. -/
meta def greedy_proof_search
{input_format : Type}
(api : ModelAPI input_format)
(serialize_ts : tactic_state → tactic input_format)
(decode_response : json → ℕ → tactic (tactic_result unit × string × list string))
(fuel := 100000)
(verbose := ff)
: tactic unit := do
final ← state_t.run (greedy_proof_search_core api serialize_ts decode_response fuel) {},
when verbose (eval_trace final.snd)
end greedy_proof_search
section bfs
/-- A node of the BFS search queue: a tactic state plus bookkeeping. -/
meta structure BFSNode : Type :=
-- tactic state this node represents
(state : tactic_state)
-- used as the key of the `pqueue` in `BFSState.nodes`
(score : ℤ := 0)
-- tactic strings associated with this node
(tactics : list string := [])
(depth : ℕ := 0)
-- derive a `has_to_format` instance for tracing
attribute [derive has_to_format] BFSNode
namespace BFSNode
/-- Builds a depth-0 node from the current tactic state with the given
`score` and tactic list `tacs`. -/
meta def of_current_state (score : ℤ := 0) (tacs : list string := []) : tactic BFSNode :=
(λ ts, BFSNode.mk ts score tacs 0) <$> tactic.read
end BFSNode
/-- Bookkeeping record for the BFS proof search: the search statistics that
are serialized to JSON, plus the priority queue of open nodes. -/
meta structure BFSState : Type :=
(depth : ℕ := 0)
-- incremented by `bump_num_queries`
(num_queries : ℕ := 0)
(tactics : list string := [])
(predictions : list (list string) := [])
(states : list tactic_state := []) -- TODO(jesse): this might make the logs extremely verbose
(success : bool := ff)
(all_failed : bool := ff)
(task_id : string := "")
(max_width : ℕ := 25) -- max qsize
(max_depth : ℕ := 50)
-- open nodes, in a priority queue keyed by `BFSNode.score`
(nodes : pqueue BFSNode.score := pqueue.empty)
(api_failure_count : ℕ := 0)
(all_failed_count : ℕ := 0)
(global_timeout : bool := ff)
(tac_timeout : ℕ := 5)
(fuel_exhausted : bool := ff)
(decl_goal : string := "")
-- derive a `has_to_format` instance for tracing
attribute [derive has_to_format] BFSState
-- Instances of the search-state type classes (declared elsewhere) for
-- `BFSState`; each one reads or updates a single field.
meta instance : has_mark_global_timeout BFSState :=
⟨λ s, { s with global_timeout := tt }⟩
meta instance : has_register_task_id BFSState :=
⟨λ s task, { s with task_id := task }⟩
meta instance : has_set_tac_timeout BFSState :=
⟨λ s timeout, { s with tac_timeout := timeout }⟩
meta instance : has_mark_fuel_exhausted BFSState :=
⟨λ s, { s with fuel_exhausted := tt }⟩
meta instance : has_get_tac_timeout BFSState :=
⟨λ s, s.tac_timeout⟩
meta instance : has_register_decl_goal BFSState :=
⟨λ s goal, { s with decl_goal := goal }⟩
-- meta instance : has_to_tactic_json GreedyProofSearchState :=
-- let fn : GreedyProofSearchState → tactic json := λ σ, do {
-- let serialize_list_string : list string → json := λ xs,
-- json.array $ json.of_string <$> xs,
-- (serialized_states : list string) ← do {
-- σ.states.mmap postprocess_tactic_state
-- },
-- pure $ json.object $
-- [
-- ("depth", json.of_int σ.depth)
-- , ("tactics", serialize_list_string σ.tactics)
-- , ("predictions", json.array $ serialize_list_string <$> σ.predictions)
-- -- store only pretty-printed tactic states for now, with newlines replaced by tabs
-- , ("states", json.array $ json.of_string <$> serialized_states)
-- , ("success", json.of_bool σ.success)
-- , ("all_failed", json.of_bool σ.all_failed)
-- , ("task_id", json.of_string σ.task_id)
-- ]
-- } in ⟨fn⟩
-- TODO(jesse): upgrade to full instance
/-- JSON serialization of the BFS search state. Tactic states are stored as
the strings produced by `postprocess_tactic_state` under
`tactic.with_full_names`. -/
meta instance : has_to_tactic_json BFSState :=
⟨λ σ, do {
  serialized_states ← σ.states.mmap
    (λ ts, tactic.with_full_names (postprocess_tactic_state ts)),
  pure $ json.object [
    ("depth", json.of_int σ.depth),
    ("tactics", json.array (json.of_string <$> σ.tactics)),
    -- , ("predictions", json.array $ serialize_list_string <$> σ.predictions)
    ("states", json.array (json.of_string <$> serialized_states)),
    ("success", json.of_bool σ.success),
    ("num_queries", json.of_int σ.num_queries),
    ("task_id", json.of_string σ.task_id),
    ("api_failure_count", json.of_int σ.api_failure_count),
    ("all_failed_count", json.of_int σ.all_failed_count),
    ("global_timeout", json.of_bool σ.global_timeout),
    ("fuel_exhausted", json.of_bool σ.fuel_exhausted),
    ("decl_goal", json.of_string σ.decl_goal)
  ]
}⟩
section BFSState
namespace BFSState

-- Pure updaters for `BFSState`. Each one returns a copy of its first argument
-- with the named field changed; all other fields are carried over unchanged.

-- Overwrite `task_id`.
meta def register_task_id (σ : BFSState) (task_id : string) : BFSState :=
{ task_id := task_id, ..σ }

-- Increment `num_queries` by one.
meta def bump_num_queries (σ : BFSState) : BFSState :=
{ num_queries := σ.num_queries + 1, ..σ }

-- Enqueue every node of `new_nodes` into the `nodes` priority queue.
meta def push_nodes (σ : BFSState) (new_nodes : list BFSNode) : BFSState :=
{ nodes := σ.nodes.enqueue_many new_nodes, ..σ }

-- Enqueue a single node (delegates to `push_nodes` with a singleton list).
meta def push_node (σ : BFSState) (node : BFSNode) : BFSState :=
σ.push_nodes [node]

-- Append one tactic string to the end of `tactics`.
meta def add_tac (σ : BFSState) (tac : string) : BFSState :=
{ tactics := σ.tactics ++ [tac], ..σ }

-- Append a list of tactic strings to the end of `tactics`.
meta def add_tacs (σ : BFSState) (tacs : list string) : BFSState :=
{ tactics := σ.tactics ++ tacs, ..σ }

-- Increment `depth` by one.
meta def bump_depth (σ : BFSState) : BFSState :=
{ depth := σ.depth + 1, ..σ }

-- Append one batch of candidate strings to the end of `predictions`.
meta def add_prediction (σ : BFSState) (pred : list string) : BFSState :=
{ predictions := σ.predictions ++ [pred], ..σ }

-- Append a tactic state to the end of `states`.
meta def add_state (σ : BFSState) (ts : tactic_state) : BFSState :=
{ states := σ.states ++ [ts], ..σ }

-- Set the `all_failed` flag to `tt`.
meta def mark_all_failed (σ : BFSState) : BFSState :=
{ all_failed := tt, ..σ }

-- Set the `success` flag to `tt`.
meta def mark_success (σ : BFSState) : BFSState :=
{ success := tt, ..σ }

-- Build an initial search state: the queue holds exactly one node, obtained
-- from `BFSNode.of_current_state score`; `max_width`, `max_depth` and
-- `tac_timeout` are copied from the arguments and every other field keeps its
-- structure default.
-- NOTE(review): the default `tac_timeout` here is 5000, whereas the structure
-- default is 5 and `bfs_step` multiplies this field by 1000 before passing it
-- on — confirm the intended unit for callers relying on this default.
meta def of_current_state (score : ℤ := 0) (max_width : ℕ := 25) (max_depth : ℕ := 50) (tac_timeout : ℕ := 5000) : tactic BFSState := do
  root ← BFSNode.of_current_state score,
  pure
    { nodes       := pqueue.empty.enqueue root,
      max_width   := max_width,
      max_depth   := max_depth,
      tac_timeout := tac_timeout }

-- Increment `api_failure_count` by one.
meta def bump_api_failure_count (σ : BFSState) : BFSState :=
{ api_failure_count := σ.api_failure_count + 1, ..σ }

-- Increment `all_failed_count` by one.
meta def bump_all_failed_count (σ : BFSState) : BFSState :=
{ all_failed_count := σ.all_failed_count + 1, ..σ }

end BFSState
end BFSState
-- Convert a float score to an integer: multiply by 1000, then take the floor.
meta def score_of_float (x : native.float) : int :=
native.float.floor $ (1000.0 : native.float) * x
-- Remove and return the next node from the `nodes` queue of the ambient
-- `BFSState`. On success the state's queue is replaced by the remainder;
-- if `dequeue` yields `none`, the underlying tactic fails.
meta def pop_node : state_t BFSState tactic BFSNode := do
  queue ← BFSState.nodes <$> get,
  match queue.dequeue with
  | none := state_t.lift (tactic.fail format!"[pop_node] pqueue empty")
  | (some (head, rest)) :=
    -- store the shrunken queue first, then hand back the dequeued node
    modify (λ s, { nodes := rest, ..s }) >> pure head
  end
/-
Each step of BFS does the following:
- pops a node off the queue
- expands the node, producing a list of descendant nodes
- if any descendant node represents a completed proof search,
extract the result and use it to fill the main goal
- this is accomplished by setting the tactic state
-/
meta def bfs_step
{input_format : Type}
(api : ModelAPI input_format)
(serialize_ts : tactic_state → tactic input_format)
-- note(jesse): this is responsible for e.g. selecting a candidate from the beam
(decode_response : json → ℕ → tactic (list (tactic_result unit × string × native.float) × list string))
: state_t BFSState tactic (bool × BFSState) := do {
σ ← get,
state_t.lift $ eval_trace format!"[bfs_step] ENTERING, QUEUE STATE: {σ.nodes}",
(some next_node) ← (some <$> pop_node <|> pure none) | (state_t.lift $ eval_trace format!"[bfs_step] queue empty") *> pure ⟨tt, σ⟩, -- yikes
ts ← state_t.lift $ tactic.read,
state_t.lift $ tactic.write next_node.state,
let handle_response
(response : list (tactic_result unit × string × native.float) × (list string))
: state_t BFSState tactic unit :=
match response with | response@⟨successes, candidates⟩ := do {
modify $ λ σ, σ.add_prediction candidates,
modify $ λ σ, σ.bump_num_queries,
when (successes.length = 0) $ modify $ λ σ, σ.bump_all_failed_count,
when (candidates.length = 0) $ modify $ λ σ, σ.bump_api_failure_count
-- successes.mmap' $ λ ⟨result, tac_string, score⟩, do {
-- pure ()
-- }
}
end,
let get_response_and_resume : state_t BFSState tactic bool := do {
-- TODO(jesse): `successes` needs to be generalized from `string` to an arbitrary datatype `α`
-- for the best-first case. for now, we just score everything 0 to get this working
tac_timeout_seconds ← BFSState.tac_timeout <$> get,
decl_nm ← BFSState.task_id <$> get,
response@⟨successes, candidates⟩ ← state_t.lift $ do {
ts_msg ← tactic.read >>= serialize_ts,
eval_trace "[bfs_step] QUERYING API",
response_msg ← tactic.unsafe_run_io $ api.query ts_msg,
eval_trace format!"[bfs_step] RESPONSE MSG {response_msg}",
eval_trace format!"[bfs_step] RUNNING DECODE RESPONSE WITH TAC_TIMEOUT {tac_timeout_seconds * 1000}",
do {env ← tactic.get_env, d ← env.get decl_nm, tactic.trace format! "[run_proof_search_step] WARNING: GOT DECL {decl_nm}"} <|> do {tactic.trace format! "[run_proof_search_step] NO GOT DECL {decl_nm}"},
decoded_response@⟨successes, _⟩ ← decode_response response_msg $ tac_timeout_seconds * 1000,
eval_trace $ format!"[bfs_step] SUCCESSFUL CANDIDATES: {successes}",
pure decoded_response
},
handle_response response,
-- turn the successes into a list of new BFSNodes
(new_nodes : list BFSNode) ← successes.mmap (λ ⟨tr, tac, score⟩, match tr with
| (interaction_monad.result.success _ state) := do {
-- assumes that higher float score is better
let scale_factor := (-100000.0 : native.float),
let int_score : int := native.float.floor $ (score * scale_factor) + next_node.score,
pure $ BFSNode.mk state int_score (next_node.tactics ++ [tac]) (next_node.depth + 1)
}
| exc := state_t.lift (tactic.fail format!"[bfs_step] UNEXPECTED TACTIC RESULT WHEN CONVERTING TO new_nodes: {exc}")
end
),
monad_lift $ eval_trace format! "[bfs_step] NODES BEFORE SORTING: {new_nodes}",
let new_nodes : list BFSNode := @list.merge_sort _ (λ x y : BFSNode, x.score < y.score) _ new_nodes,
monad_lift $ eval_trace format! "[bfs_step] NODES AFTER SORTING: {new_nodes}",
/- loop through the new nodes. if any of them are finished (no open goals),
the proof search is done. otherwise, we add all of them to the BFSState pqueue
and continue the search.
-/
-- modify (λ σ, σ.push_nodes new_nodes),
-- for_ new_nodes $ λ ⟨state, score⟩, do {
-- }
let BFSM := state_t BFSState tactic,
let push_node_state := (list BFSNode),
/- note(jesse): this version of `push_node_state_tac` does not quite replicate
the behavior of greedy_proof_search at
max_width := 1, because it will loop over all the successful candidates anyways and check
if the proof is finished regardless of the state of the queue
-/
-- let push_node_state_tac : state_t push_node_state BFSM (option BFSNode) := do {
-- (next_node::xs) ← get | pure none,
-- done_flag ← state_t.lift $ state_t.lift $ next_node.state.is_done,
-- result ← if done_flag then pure (some next_node) else do {
-- q ← BFSState.nodes <$> state_t.lift get,
-- let qsize := q.size,
-- state_t.lift $ state_t.lift $ eval_trace format!"[push_tac] QSIZE: {qsize}",