Skip to content
This repository was archived by the owner on Jul 24, 2024. It is now read-only.

Commit fc307f9

Browse files
committed
feat(tactic/norm_num): make norm_num extensible (#4820)
This allows you to extend `norm_num` by defining additional tactics of type `expr → tactic (expr × expr)` with the `@[norm_num]` attribute. It still requires some tactic proficiency to use correctly, but it at least allows us to move all the possible norm_num extensions to their own files instead of the current dependency cycle problem. This could potentially become a performance problem if too many things are marked `@[norm_num]`, as they are simply looked through in linear order. It could be improved by having extensions register a finite set of constants that they wish to evaluate, and dispatch to the right extension tactic using a `name_map`. ```lean def foo : ℕ := 1 @[norm_num] meta def eval_foo : expr → tactic (expr × expr) | `(foo) := pure (`(1:ℕ), `(eq.refl 1)) | _ := tactic.failed example : foo = 1 := by norm_num ```
1 parent 2c7efdf commit fc307f9

File tree

4 files changed

+86
-16
lines changed

4 files changed

+86
-16
lines changed

src/tactic/abel.lean

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ meta def eval_add (c : cache) : normal_expr → normal_expr → tactic (normal_e
136136
(a', h) ← eval_add he₁ a₂,
137137
return (term' c n₂ x₂ a', c.iapp ``const_add_term [e₁, n₂.1, x₂, a₂, a', h])
138138
else do
139-
(n', h₁) ← mk_app ``has_add.add [n₁.1, n₂.1] >>= norm_num.derive',
139+
(n', h₁) ← mk_app ``has_add.add [n₁.1, n₂.1] >>= norm_num.eval_field,
140140
(a', h₂) ← eval_add a₁ a₂,
141141
let k := n₁.2 + n₂.2,
142142
let p₁ := c.iapp ``term_add_term [n₁.1, x₁, a₁, n₂.1, a₂, n', a', h₁, h₂],
@@ -155,7 +155,7 @@ meta def eval_neg (c : cache) : normal_expr → tactic (normal_expr × expr)
155155
p ← c.mk_app ``neg_zero ``add_group [],
156156
return (zero' c, p)
157157
| (nterm e n x a) := do
158-
(n', h₁) ← mk_app ``has_neg.neg [n.1] >>= norm_num.derive',
158+
(n', h₁) ← mk_app ``has_neg.neg [n.1] >>= norm_num.eval_field,
159159
(a', h₂) ← eval_neg a,
160160
return (term' c (n', -n.2) x a',
161161
c.app ``term_neg c.inst [n.1, x, a, n', a', h₁, h₂])
@@ -183,7 +183,7 @@ meta def eval_smul (c : cache) (k : expr × ℤ) :
183183
normal_expr → tactic (normal_expr × expr)
184184
| (zero _) := return (zero' c, c.iapp ``zero_smul [k.1])
185185
| (nterm e n x a) := do
186-
(n', h₁) ← mk_app ``has_mul.mul [k.1, n.1] >>= norm_num.derive',
186+
(n', h₁) ← mk_app ``has_mul.mul [k.1, n.1] >>= norm_num.eval_field,
187187
(a', h₂) ← eval_smul a,
188188
return (term' c (n', k.2 * n.2) x a',
189189
c.iapp ``term_smul [k.1, n.1, x, a, n', a', h₁, h₂])

src/tactic/norm_num.lean

Lines changed: 72 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1384,33 +1384,96 @@ meta def eval_prime : expr → tactic (expr × expr)
13841384
| _ := failed
13851385

13861386
/-- This version of `derive` does not fail when the input is already a numeral -/
1387-
meta def derive' (e : expr) : tactic (expr × expr) :=
1387+
meta def derive.step (e : expr) : tactic (expr × expr) :=
13881388
eval_field e <|> eval_nat_int_ext e <|>
13891389
eval_pow e <|> eval_ineq e <|> eval_prime e
13901390

1391-
meta def derive : expr → tactic (expr × expr) | e :=
1391+
/-- An attribute for adding additional extensions to `norm_num`. To use this attribute, put
1392+
`@[norm_num]` on a tactic of type `expr → tactic (expr × expr)`; the tactic will be called on
1393+
subterms by `norm_num`, and it is responsible for identifying that the expression is a numerical
1394+
function applied to numerals, for example `nat.fib 17`, and should return the reduced numerical
1395+
expression (which must be in `norm_num`-normal form: a natural or rational numeral, i.e. `37`,
1396+
`12 / 7` or `-(2 / 3)`, although this can be an expression in any type), and the proof that the
1397+
original expression is equal to the rewritten expression.
1398+
1399+
Failure is used to indicate that this tactic does not apply to the term. For performance reasons,
1400+
it is best to detect non-applicability as soon as possible so that the next tactic can have a go,
1401+
so generally it will start with a pattern match and then checking that the arguments to the term
1402+
are numerals or of the appropriate form, followed by proof construction, which should not fail.
1403+
1404+
Propositions are treated like any other term. The normal form for propositions is `true` or
1405+
`false`, so it should produce a proof of the form `p = true` or `p = false`. `eq_true_intro` can be
1406+
used to help here.
1407+
-/
1408+
@[user_attribute]
1409+
protected meta def attr : user_attribute (expr → tactic (expr × expr)) unit :=
1410+
{ name := `norm_num,
1411+
descr := "Add norm_num derivers",
1412+
cache_cfg :=
1413+
{ mk_cache := λ ns, do {
1414+
t ← ns.mfoldl
1415+
(λ (t : expr → tactic (expr × expr)) n, do
1416+
t' ← eval_expr (expr → tactic (expr × expr)) (expr.const n []),
1417+
pure (λ e, t' e <|> t e))
1418+
(λ _, failed),
1419+
pure (λ e, derive.step e <|> t e) },
1420+
dependencies := [] } }
1421+
1422+
add_tactic_doc
1423+
{ name := "norm_num",
1424+
category := doc_category.attr,
1425+
decl_names := [`norm_num.attr],
1426+
tags := ["arithmetic", "decision_procedure"] }
1427+
1428+
/-- Look up the `norm_num` extensions in the cache and return a tactic extending `derive.step` with
1429+
additional reduction procedures. -/
1430+
meta def get_step : tactic (expr → tactic (expr × expr)) := norm_num.attr.get_cache
1431+
1432+
/-- Simplify an expression bottom-up using `step` to simplify the subexpressions. -/
1433+
meta def derive' (step : expr → tactic (expr × expr))
1434+
: expr → tactic (expr × expr) | e :=
13921435
do e ← instantiate_mvars e,
13931436
(_, e', pr) ←
13941437
ext_simplify_core () {} simp_lemmas.mk (λ _, failed) (λ _ _ _ _ _, failed)
13951438
(λ _ _ _ _ e,
1396-
do (new_e, pr) ← derive' e,
1439+
do (new_e, pr) ← step e,
13971440
guard (¬ new_e =ₐ e),
13981441
return ((), new_e, some pr, tt))
13991442
`eq e,
14001443
return (e', pr)
14011444

1445+
/-- Simplify an expression bottom-up using the default `norm_num` set to simplify the
1446+
subexpressions. -/
1447+
meta def derive (e : expr) : tactic (expr × expr) := do f ← get_step, derive' f e
1448+
14021449
end norm_num
14031450

1451+
/-- Basic version of `norm_num` that does not call `simp`. It uses the provided `step` tactic
1452+
to simplify the expression; use `get_step` to get the default `norm_num` set and `derive.step` for
1453+
the basic builtin set of simplifications. -/
1454+
meta def tactic.norm_num1 (step : expr → tactic (expr × expr))
1455+
(loc : interactive.loc) : tactic unit :=
1456+
do ns ← loc.get_locals,
1457+
tt ← tactic.replace_at (norm_num.derive' step) ns loc.include_goal
1458+
| fail "norm_num failed to simplify",
1459+
when loc.include_goal $ try tactic.triv,
1460+
when (¬ ns.empty) $ try tactic.contradiction
1461+
1462+
/-- Normalize numerical expressions. It uses the provided `step` tactic to simplify the expression;
1463+
use `get_step` to get the default `norm_num` set and `derive.step` for the basic builtin set of
1464+
simplifications. -/
1465+
meta def tactic.norm_num (step : expr → tactic (expr × expr))
1466+
(hs : list simp_arg_type) (l : interactive.loc) : tactic unit :=
1467+
repeat1 $ orelse' (tactic.norm_num1 step l) $
1468+
interactive.simp_core {} (tactic.norm_num1 step (interactive.loc.ns [none]))
1469+
ff (simp_arg_type.except ``one_div :: hs) [] l
1470+
14041471
namespace tactic.interactive
14051472
open norm_num interactive interactive.types
14061473

14071474
/-- Basic version of `norm_num` that does not call `simp`. -/
14081475
meta def norm_num1 (loc : parse location) : tactic unit :=
1409-
do ns ← loc.get_locals,
1410-
tt ← tactic.replace_at derive ns loc.include_goal
1411-
| fail "norm_num failed to simplify",
1412-
when loc.include_goal $ try tactic.triv,
1413-
when (¬ ns.empty) $ try tactic.contradiction
1476+
do f ← get_step, tactic.norm_num1 f loc
14141477

14151478
/-- Normalize numerical expressions. Supports the operations
14161479
`+` `-` `*` `/` `^` and `%` over numerical types such as
@@ -1419,8 +1482,7 @@ and can prove goals of the form `A = B`, `A ≠ B`, `A < B` and `A ≤ B`,
14191482
where `A` and `B` are numerical expressions.
14201483
It also has a relatively simple primality prover. -/
14211484
meta def norm_num (hs : parse simp_arg_list) (l : parse location) : tactic unit :=
1422-
repeat1 $ orelse' (norm_num1 l) $
1423-
simp_core {} (norm_num1 (loc.ns [none])) ff (simp_arg_type.except ``one_div :: hs) [] l
1485+
do f ← get_step, tactic.norm_num f hs l
14241486

14251487
add_hint_tactic "norm_num"
14261488

src/tactic/ring_exp.lean

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -627,7 +627,7 @@ with the proof of `expr.of_rat p + expr.of_rat q = expr.of_rat (p + q)`.
627627
meta def add_coeff (p_p q_p : expr) (p q : coeff) : ring_exp_m (ex prod) := do
628628
ctx ← get_context,
629629
pq_o ← mk_add [p_p, q_p],
630-
(pq_p, pq_pf) ← lift $ norm_num.derive' pq_o,
630+
(pq_p, pq_pf) ← lift $ norm_num.eval_field pq_o,
631631
pure $ ex.coeff ⟨pq_o, pq_p, pq_pf⟩ ⟨p.1 + q.1
632632

633633
lemma mul_coeff_pf_one_mul (q : α) : 1 * q = q := one_mul q
@@ -654,7 +654,7 @@ match p.1, q.1 with -- Special case to speed up multiplication with 1.
654654
| _, _ := do
655655
ctx ← get_context,
656656
pq' ← mk_mul [p_p, q_p],
657-
(pq_p, pq_pf) ← lift $ norm_num.derive' pq',
657+
(pq_p, pq_pf) ← lift $ norm_num.eval_field pq',
658658
pure $ ex.coeff ⟨pq_p, pq_p, pq_pf⟩ ⟨p.1 * q.1
659659
end
660660

@@ -975,7 +975,7 @@ with the proof of `expr.of_rat p ^ expr.of_rat q = expr.of_rat (p ^ q)`.
975975
meta def pow_coeff (p_p q_p : expr) (p q : coeff) : ring_exp_m (ex prod) := do
976976
ctx ← get_context,
977977
pq' ← mk_pow [p_p, q_p],
978-
(pq_p, pq_pf) ← lift $ norm_num.derive' pq',
978+
(pq_p, pq_pf) ← lift $ norm_num.eval_pow pq',
979979
pure $ ex.coeff ⟨pq_p, pq_p, pq_pf⟩ ⟨p.1 * q.1
980980

981981
/--

test/norm_num.lean

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,14 @@ example : 100 - 100 = 0 := by norm_num
9191
example : 5 * (2 - 3) = 0 := by norm_num
9292
example : 10 - 5 * 5 + (7 - 3) * 6 = 27 - 3 := by norm_num
9393

94+
def foo : ℕ := 1
95+
96+
@[norm_num] meta def eval_foo : expr → tactic (expr × expr)
97+
| `(foo) := pure (`(1:ℕ), `(eq.refl 1))
98+
| _ := tactic.failed
99+
100+
example : foo = 1 := by norm_num
101+
94102
-- ordered field examples
95103

96104
variable {α : Type}

0 commit comments

Comments
 (0)