Skip to content

Commit 05b74b0

Browse files
committed
fix(Tactic/NormNum): fix normNum bug
[Reported on zulip](https://leanprover.zulipchat.com/#narrow/stream/287929-mathlib4/topic/panic.20while.20testing.20ring.20tactic/near/249677205). Also some minor cleanup of normNum and ring.
1 parent 3d61f90 commit 05b74b0

File tree

3 files changed

+93
-89
lines changed

3 files changed

+93
-89
lines changed

Mathlib/Tactic/Core.lean

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/-
2+
Copyright (c) 2021 Mario Carneiro. All rights reserved.
3+
Released under Apache 2.0 license as described in the file LICENSE.
4+
Authors: Mario Carneiro, Aurélien Saue
5+
-/
6+
7+
import Lean.Expr
8+
9+
namespace Lean.Expr
10+
11+
private def getAppFnArgsAux : Expr → Array Expr → Nat → Name × Array Expr
12+
| app f a _, as, i => getAppFnArgsAux f (as.set! i a) (i-1)
13+
| const n _ _, as, i => (n, as)
14+
| _, as, _ => (Name.anonymous, as)
15+
16+
def getAppFnArgs (e : Expr) : Name × Array Expr :=
17+
let nargs := e.getAppNumArgs
18+
getAppFnArgsAux e (mkArray nargs arbitrary) (nargs-1)
19+
20+
def natLit! : Expr → Nat
21+
| lit (Literal.natVal v) _ => v
22+
| _ => panic! "nat literal expected"
23+
24+
end Expr

Mathlib/Tactic/NormNum.lean

Lines changed: 49 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ Authors: Mario Carneiro
55
-/
66
import Lean.Elab.Tactic.Basic
77
import Mathlib.Algebra.Ring.Basic
8+
import Mathlib.Tactic.Core
89

910
namespace Lean
1011

@@ -34,6 +35,9 @@ def mkOfNatLit (u : Level) (α sα n : Expr) : Expr :=
3435

3536
namespace NormNum
3637

38+
theorem ofNat_nat (n : ℕ) : n = @OfNat.ofNat _ n (@Numeric.OfNat _ _ _) := rfl
39+
set_option pp.all true
40+
3741
theorem ofNat_add {α} [Semiring α] : (a b : α) → (a' b' c : Nat) →
3842
a = OfNat.ofNat a' → b = OfNat.ofNat b' → a' + b' = c → a + b = OfNat.ofNat c
3943
| _, _, _, _, _, rfl, rfl, rfl => (Semiring.ofNat_add _ _).symm
@@ -46,67 +50,58 @@ theorem ofNat_pow {α} [Semiring α] : (a : α) → (n a' c : Nat) →
4650
a = OfNat.ofNat a' → a'^n = c → a ^ n = OfNat.ofNat c
4751
| _, _, _, _, rfl, rfl => (Semiring.ofNat_pow _ _).symm
4852

49-
partial def evalAux : Expr → MetaM (Expr × Expr)
50-
| e => e.withApp fun f args => do
51-
if f.isConstOf ``HAdd.hAdd then
52-
evalB ``NormNum.ofNat_add (·+·) args
53-
else if f.isConstOf ``HMul.hMul then
54-
evalB ``NormNum.ofNat_mul (·*·) args
55-
else if f.isConstOf ``HPow.hPow then
56-
evalC ``NormNum.ofNat_pow (·^·) args
57-
else if f.isConstOf ``OfNat.ofNat then
58-
let #[α,ln,_] ← args | throwError "fail"
59-
let some n ← ln.natLit? | throwError "fail"
60-
if n = 0 then
53+
partial def evalAux (e : Expr) : MetaM (Expr × Expr) := do
54+
match e.getAppFnArgs with
55+
| (``HAdd.hAdd, #[_, _, α, _, a, b]) => evalB ``NormNum.ofNat_add (·+·) α a b
56+
| (``HMul.hMul, #[_, _, α, _, a, b]) => evalB ``NormNum.ofNat_mul (·*·) α a b
57+
| (``HPow.hPow, #[_, _, α, _, a, n]) => evalC ``NormNum.ofNat_pow (·^·) α a n
58+
| (``OfNat.ofNat, #[α, ln, _]) =>
59+
match ← ln.natLit? with
60+
| some 0 =>
6161
let Level.succ u _ ← getLevel α | throwError "fail"
6262
let nα ← synthInstance (mkApp (mkConst ``Numeric [u]) α)
6363
let sα ← synthInstance (mkApp (mkConst ``Semiring [u]) α)
6464
let e ← mkOfNatLit u α nα (mkRawNatLit 0)
6565
let p ← mkEqSymm (mkApp2 (mkConst ``Semiring.ofNat_zero [u]) α sα)
66-
return (e,p)
67-
else if n = 1 then
66+
(e, p)
67+
| some 1 =>
6868
let Level.succ u _ ← getLevel α | throwError "fail"
6969
let nα ← synthInstance (mkApp (mkConst ``Numeric [u]) α)
7070
let sα ← synthInstance (mkApp (mkConst ``Semiring [u]) α)
7171
let e ← mkOfNatLit u α nα (mkRawNatLit 1)
7272
let p ← mkEqSymm (mkApp2 (mkConst ``Semiring.ofNat_one [u]) α sα)
73-
return (e,p)
74-
else pure (e, ← mkEqRefl e)
75-
else if f.isNatLit then pure (e, ← mkEqRefl e)
76-
else throwError "fail"
77-
where
78-
evalB (name : Name) (f : Nat → Nat → Nat)
79-
(args : Array Expr) : MetaM (Expr × Expr) := do
80-
if let #[_, _, α, _, a, b] ← args then
81-
let Level.succ u _ ← getLevel α | throwError "fail"
82-
let nα ← synthInstance (mkApp (mkConst ``Numeric [u]) α)
83-
let sα ← synthInstance (mkApp (mkConst ``Semiring [u]) α)
84-
let (a', pa) ← evalAux a
85-
let (b', pb) ← evalAux b
86-
let la := Expr.getRevArg! a' 1
87-
let some na ← la.natLit? | throwError "fail"
88-
let lb := Expr.getRevArg! b' 1
89-
let some nb ← lb.natLit? | throwError "fail"
90-
let lc := mkRawNatLit (f na nb)
91-
let c := mkOfNatLit u α nα lc
92-
pure (c, mkApp10 (mkConst name [u]) α sα a b la lb lc pa pb (← mkEqRefl lc))
93-
else throwError "fail"
94-
evalC (name : Name) (f : Nat → Nat → Nat)
95-
(args : Array Expr) : MetaM (Expr × Expr) := do
96-
if let #[_, _, α, _, a, n] ← args then
97-
let Level.succ u _ ← getLevel α | throwError "fail"
98-
let nα ← synthInstance (mkApp (mkConst ``Numeric [u]) α)
99-
let sα ← synthInstance (mkApp (mkConst ``Semiring [u]) α)
100-
let (a', pa) ← evalAux a
101-
let la := Expr.getRevArg! a' 1
102-
let some na ← la.natLit? | throwError "fail"
103-
let some nn ← n.numeral? | throwError "fail"
104-
let lc := mkRawNatLit (f na nn)
105-
let c := mkOfNatLit u α nα lc
106-
pure (c, mkApp8 (mkConst name [u]) α sα a n la lc pa (← mkEqRefl lc))
73+
(e, p)
74+
| some _ => pure (e, ← mkEqRefl e)
75+
| none => throwError "fail"
76+
| _ =>
77+
if e.isNatLit then
78+
(mkOfNatLit levelZero (mkConst ``Nat) (mkConst ``Nat.instNumericNat) e,
79+
mkApp (mkConst ``ofNat_nat) e)
10780
else throwError "fail"
81+
where
82+
evalB (name : Name) (f : Nat → Nat → Nat) (α a b : Expr) : MetaM (Expr × Expr) := do
83+
let Level.succ u _ ← getLevel α | throwError "fail"
84+
let nα ← synthInstance (mkApp (mkConst ``Numeric [u]) α)
85+
let sα ← synthInstance (mkApp (mkConst ``Semiring [u]) α)
86+
let (a', pa) ← evalAux a
87+
let (b', pb) ← evalAux b
88+
let la := Expr.getRevArg! a' 1
89+
let lb := Expr.getRevArg! b' 1
90+
let lc := mkRawNatLit (f la.natLit! lb.natLit!)
91+
let c := mkOfNatLit u α nα lc
92+
(c, mkApp10 (mkConst name [u]) α sα a b la lb lc pa pb (← mkEqRefl lc))
93+
evalC (name : Name) (f : Nat → Nat → Nat) (α a n : Expr) : MetaM (Expr × Expr) := do
94+
let Level.succ u _ ← getLevel α | throwError "fail"
95+
let nα ← synthInstance (mkApp (mkConst ``Numeric [u]) α)
96+
let sα ← synthInstance (mkApp (mkConst ``Semiring [u]) α)
97+
let (a', pa) ← evalAux a
98+
let la := Expr.getRevArg! a' 1
99+
let some nn ← n.numeral? | throwError "fail"
100+
let lc := mkRawNatLit (f la.natLit! nn)
101+
let c := mkOfNatLit u α nα lc
102+
(c, mkApp8 (mkConst name [u]) α sα a n la lc pa (← mkEqRefl lc))
108103

109-
partial def eval (e : Expr) : MetaM (Expr × Expr) := do
104+
def eval (e : Expr) : MetaM (Expr × Expr) := do
110105
let (e', p) ← evalAux e
111106
e'.withApp fun f args => do
112107
if f.isConstOf ``OfNat.ofNat then
@@ -118,16 +113,16 @@ partial def eval (e : Expr) : MetaM (Expr × Expr) := do
118113
let nα ← synthInstance (mkApp2 (mkConst ``OfNat [u]) α (mkRawNatLit 0))
119114
let e'' ← mkApp3 (mkConst ``OfNat.ofNat [u]) α (mkRawNatLit 0) nα
120115
let p' ← mkEqTrans p (mkApp2 (mkConst ``Semiring.ofNat_zero [u]) α sα)
121-
return (e'',p')
116+
(e'', p')
122117
else if n = 1 then
123118
let Level.succ u _ ← getLevel α | throwError "fail"
124119
let sα ← synthInstance (mkApp (mkConst ``Semiring [u]) α)
125120
let nα ← synthInstance (mkApp2 (mkConst ``OfNat [u]) α (mkRawNatLit 1))
126121
let e'' ← mkApp3 (mkConst ``OfNat.ofNat [u]) α (mkRawNatLit 1) nα
127122
let p' ← mkEqTrans p (mkApp2 (mkConst ``Semiring.ofNat_one [u]) α sα)
128-
return (e'',p')
129-
else pure (e',p)
130-
else pure (e', p)
123+
(e'', p')
124+
else (e', p)
125+
else (e', p)
131126
end NormNum
132127
end Meta
133128

@@ -150,3 +145,4 @@ example : (1 + 0 : α) = 1 := by normNum
150145
example : (0 + (2 + 3) + 1 : α) = 6 := by normNum
151146
example : (70 * (33 + 2) : α) = 2450 := by normNum
152147
example : (8 + 2 ^ 2 * 3 : α) = 20 := by normNum
148+
example : (2 ^ 2 : ℕ) = 4 := by normNum

Mathlib/Tactic/Ring.lean

Lines changed: 20 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,6 @@ Based on <http://www.cs.ru.nl/~freek/courses/tt-2014/read/10.1.1.61.3041.pdf> .
1616

1717
open Lean Parser.Tactic Elab Command Elab.Tactic Meta
1818

19-
open Expr in
20-
private def getAppFnAndArgsAux : Expr → Array Expr → Nat → Option (Name × Array Expr)
21-
| app f a _, as, i => getAppFnAndArgsAux f (as.set! i a) (i-1)
22-
| const n _ _, as, i => some (n,as)
23-
| _, as, _ => none
24-
25-
def Lean.Expr.getAppFnAndArgs (e : Expr) : Option (Name × Array Expr) :=
26-
let dummy := mkSort levelZero
27-
let nargs := e.getAppNumArgs
28-
getAppFnAndArgsAux e (mkArray nargs dummy) (nargs-1)
29-
3019
namespace Tactic
3120
namespace Ring
3221

@@ -41,10 +30,9 @@ structure State :=
4130

4231
/-- The monad that `ring` works in. This is a reader monad containing a cache and
4332
the list of atoms-up-to-defeq encountered thus far, used for atom sorting. -/
44-
abbrev RingM := ReaderT Cache $ StateRefT State TacticM
33+
abbrev RingM := ReaderT Cache $ StateRefT State MetaM
4534

46-
def run (e : Expr) {α} (m : RingM α): TacticM α := do
47-
let ty ← inferType e
35+
def RingM.run (ty : Expr) (m : RingM α) : MetaM α := do
4836
let u ← getLevel ty
4937
(m {α := ty, univ := u}).run' {}
5038

@@ -115,7 +103,7 @@ def xadd' (a : HornerExpr) (x : Expr × ℕ) (n : Expr × ℕ) (b : HornerExpr)
115103
def reflConv (e : HornerExpr) : RingM (HornerExpr × Expr) := do (e, ← mkEqRefl e)
116104

117105
/-- Pretty printer for `horner_expr`. -/
118-
def pp : HornerExpr → TacticM Format
106+
def pp : HornerExpr → MetaM Format
119107
| (const e c) => do
120108
let pe ← PrettyPrinter.ppExpr Name.anonymous [] e
121109
return "[" ++ pe ++ ", " ++ toString c ++ "]"
@@ -343,27 +331,27 @@ partial def evalPow : HornerExpr → Expr × ℕ → RingM (HornerExpr × Expr)
343331
| e, (_, 0) => do
344332
let α1 ← mkAppOptM ``OfNat.ofNat #[(← read).α, mkRawNatLit 1, none]
345333
let p ← mkAppM ``pow_zero #[e]
346-
return (const α1 1, p)
334+
(const α1 1, p)
347335
| e, (_, 1) => do
348336
let p ← mkAppM ``pow_one #[e]
349-
return (e, p)
350-
| (const e coeff), (e₂, m) => do
337+
(e, p)
338+
| const e coeff, (e₂, m) => do
351339
let (e', p) ← NormNum.eval $ ← mkAppM ``HPow.hPow #[e, e₂]
352-
return (const e' (coeff ^ m), p)
340+
(const e' (coeff ^ m), p)
353341
| he@(xadd e a x n b), m =>
354342
match b.e.numeral? with
355343
| some 0 => do
356344
let n' ← mkRawNatLit (n.2 * m.2)
357345
let h₁ ← mkEqRefl n'
358346
let (a', h₂) ← evalPow a m
359347
let α0 ← mkAppOptM ``OfNat.ofNat #[(← read).α, mkRawNatLit 0, none]
360-
return (← xadd' a' x (n', n.2 * m.2) (const α0 0),
348+
(← xadd' a' x (n', n.2 * m.2) (const α0 0),
361349
← mkAppM ``horner_pow #[a, x.1, n.1, m.1, n', a', h₁, h₂])
362350
| _ => do
363351
let e₂ ← mkRawNatLit (m.2 - 1)
364352
let (tl, hl) ← evalPow he (e₂, m.2-1)
365353
let (t, p₂) ← evalMul tl he
366-
return (t, ← mkAppM ``pow_succ_eq #[e, e₂, tl, t, hl, p₂])
354+
(t, ← mkAppM ``pow_succ_eq #[e, e₂, tl, t, hl, p₂])
367355

368356

369357
theorem horner_atom {α} [CommSemiring α] (x : α) : x = horner 1 x 1 0 := by
@@ -374,8 +362,7 @@ def evalAtom (e : Expr) : RingM (HornerExpr × Expr) := do
374362
let i ← addAtom e
375363
let zero ← const (← mkAppOptM ``OfNat.ofNat #[(← read).α, mkRawNatLit 0, none]) 0
376364
let one ← const (← mkAppOptM ``OfNat.ofNat #[(← read).α, mkRawNatLit 1, none]) 1
377-
return (← xadd' one (e,i) (mkRawNatLit 1,1) zero, ← mkAppM ``horner_atom #[e])
378-
365+
(← xadd' one (e,i) (mkRawNatLit 1,1) zero, ← mkAppM ``horner_atom #[e])
379366

380367
theorem subst_into_add {α} [Add α] (l r tl tr t)
381368
(prl : (l : α) = tl) (prr : r = tr) (prt : tl + tr = t) : l + r = t :=
@@ -390,20 +377,20 @@ theorem subst_into_pow {α} [Monoid α] (l r tl tr t)
390377
by rw [prl, prr, prt]
391378

392379
partial def eval (e : Expr) : RingM (HornerExpr × Expr) :=
393-
match e.getAppFnAndArgs with
394-
| some (``HAdd.hAdd, #[_,_,_,_,e₁,e₂]) => do
380+
match e.getAppFnArgs with
381+
| (``HAdd.hAdd, #[_,_,_,_,e₁,e₂]) => do
395382
let (e₁', p₁) ← eval e₁
396383
let (e₂', p₂) ← eval e₂
397384
let (e', p') ← evalAdd e₁' e₂'
398385
let p ← mkAppM ``subst_into_add #[e₁, e₂, e₁', e₂', e', p₁, p₂, p']
399386
(e',p)
400-
| some (``HMul.hMul, #[_,_,_,_,e₁,e₂]) => do
387+
| (``HMul.hMul, #[_,_,_,_,e₁,e₂]) => do
401388
let (e₁', p₁) ← eval e₁
402389
let (e₂', p₂) ← eval e₂
403390
let (e', p') ← evalMul e₁' e₂'
404391
let p ← mkAppM ``subst_into_mul #[e₁, e₂, e₁', e₂', e', p₁, p₂, p']
405392
return (e', p)
406-
| some (``HPow.hPow, #[_,_,_,P,e₁,e₂]) => do
393+
| (``HPow.hPow, #[_,_,_,P,e₁,e₂]) => do
407394
-- let (e₂', p₂) ← lift $ norm_num.derive e₂ <|> refl_conv e₂,
408395
let (e₂', p₂) ← (e₂, ← mkEqRefl e₂)
409396
match e₂'.numeral?, P.getAppFn with
@@ -412,24 +399,21 @@ partial def eval (e : Expr) : RingM (HornerExpr × Expr) :=
412399
let (e', p') ← evalPow e₁' (e₂, k)
413400
let p ← mkAppM ``subst_into_pow #[e₁, e₂, e₁', e₂', e', p₁, p₂, p']
414401
return (e', p)
415-
| _, _ => do ← evalAtom e
416-
evalAtom e
402+
| _, _ => evalAtom e
417403
| _ =>
418404
match e.numeral? with
419405
| some n => (const e n).reflConv
420-
| _ => (evalAtom e)
421-
406+
| _ => evalAtom e
422407

423408
elab "ring" : tactic => do
424409
let g ← getMainTarget
425-
match g.getAppFnAndArgs with
426-
| some (`Eq, #[ty, e₁, e₂]) =>
427-
let ((e₁', p₁), (e₂', p₂)) ← run e₁ $ Prod.mk <$> eval e₁ <*> eval e₂
428-
if (← isDefEq e₁' e₂') then
410+
match g.getAppFnArgs with
411+
| (`Eq, #[ty, e₁, e₂]) =>
412+
let ((e₁', p₁), (e₂', p₂)) ← RingM.run ty $ do (← eval e₁, ← eval e₂)
413+
if ← isDefEq e₁' e₂' then
429414
let p ← mkEqTrans p₁ (← mkEqSymm p₂)
430415
ensureHasNoMVars p
431416
assignExprMVar (← getMainGoal) p
432-
433417
replaceMainGoal []
434418
else
435419
throwError "failed \n{← e₁'.pp}\n{← e₂'.pp}"

0 commit comments

Comments
 (0)