Skip to content

Commit

Permalink
fix(Tactic/NormNum): fix normNum bug
Browse files Browse the repository at this point in the history
  • Loading branch information
digama0 committed Aug 18, 2021
1 parent 3d61f90 commit 05b74b0
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 89 deletions.
24 changes: 24 additions & 0 deletions Mathlib/Tactic/Core.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/-
Copyright (c) 2021 Mario Carneiro. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Mario Carneiro, Aurélien Saue
-/

import Lean.Expr

namespace Lean.Expr

private def getAppFnArgsAux : Expr → Array Expr → Nat → Name × Array Expr
| app f a _, as, i => getAppFnArgsAux f (as.set! i a) (i-1)
| const n _ _, as, i => (n, as)
| _, as, _ => (Name.anonymous, as)

def getAppFnArgs (e : Expr) : Name × Array Expr :=
let nargs := e.getAppNumArgs
getAppFnArgsAux e (mkArray nargs arbitrary) (nargs-1)

def natLit! : Expr → Nat
| lit (Literal.natVal v) _ => v
| _ => panic! "nat literal expected"

end Expr
102 changes: 49 additions & 53 deletions Mathlib/Tactic/NormNum.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Authors: Mario Carneiro
-/
import Lean.Elab.Tactic.Basic
import Mathlib.Algebra.Ring.Basic
import Mathlib.Tactic.Core

namespace Lean

Expand Down Expand Up @@ -34,6 +35,9 @@ def mkOfNatLit (u : Level) (α sα n : Expr) : Expr :=

namespace NormNum

theorem ofNat_nat (n : ℕ) : n = @OfNat.ofNat _ n (@Numeric.OfNat _ _ _) := rfl
set_option pp.all true

theorem ofNat_add {α} [Semiring α] : (a b : α) → (a' b' c : Nat) →
a = OfNat.ofNat a' → b = OfNat.ofNat b' → a' + b' = c → a + b = OfNat.ofNat c
| _, _, _, _, _, rfl, rfl, rfl => (Semiring.ofNat_add _ _).symm
Expand All @@ -46,67 +50,58 @@ theorem ofNat_pow {α} [Semiring α] : (a : α) → (n a' c : Nat) →
a = OfNat.ofNat a' → a'^n = c → a ^ n = OfNat.ofNat c
| _, _, _, _, rfl, rfl => (Semiring.ofNat_pow _ _).symm

partial def evalAux : Expr → MetaM (Expr × Expr)
| e => e.withApp fun f args => do
if f.isConstOf ``HAdd.hAdd then
evalB ``NormNum.ofNat_add (·+·) args
else if f.isConstOf ``HMul.hMul then
evalB ``NormNum.ofNat_mul (·*·) args
else if f.isConstOf ``HPow.hPow then
evalC ``NormNum.ofNat_pow (·^·) args
else if f.isConstOf ``OfNat.ofNat then
let #[α,ln,_] ← args | throwError "fail"
let some n ← ln.natLit? | throwError "fail"
if n = 0 then
partial def evalAux (e : Expr) : MetaM (Expr × Expr) := do
match e.getAppFnArgs with
| (``HAdd.hAdd, #[_, _, α, _, a, b]) => evalB ``NormNum.ofNat_add (·+·) α a b
| (``HMul.hMul, #[_, _, α, _, a, b]) => evalB ``NormNum.ofNat_mul (·*·) α a b
| (``HPow.hPow, #[_, _, α, _, a, n]) => evalC ``NormNum.ofNat_pow (·^·) α a n
| (``OfNat.ofNat, #[α, ln, _]) =>
match ← ln.natLit? with
| some 0 =>
let Level.succ u _ ← getLevel α | throwError "fail"
let nα ← synthInstance (mkApp (mkConst ``Numeric [u]) α)
let sα ← synthInstance (mkApp (mkConst ``Semiring [u]) α)
let e ← mkOfNatLit u α nα (mkRawNatLit 0)
let p ← mkEqSymm (mkApp2 (mkConst ``Semiring.ofNat_zero [u]) α sα)
return (e,p)
else if n = 1 then
(e, p)
| some 1 =>
let Level.succ u _ ← getLevel α | throwError "fail"
let nα ← synthInstance (mkApp (mkConst ``Numeric [u]) α)
let sα ← synthInstance (mkApp (mkConst ``Semiring [u]) α)
let e ← mkOfNatLit u α nα (mkRawNatLit 1)
let p ← mkEqSymm (mkApp2 (mkConst ``Semiring.ofNat_one [u]) α sα)
return (e,p)
else pure (e, ← mkEqRefl e)
else if f.isNatLit then pure (e, ← mkEqRefl e)
else throwError "fail"
where
evalB (name : Name) (f : Nat → Nat → Nat)
(args : Array Expr) : MetaM (Expr × Expr) := do
if let #[_, _, α, _, a, b] ← args then
let Level.succ u _ ← getLevel α | throwError "fail"
let nα ← synthInstance (mkApp (mkConst ``Numeric [u]) α)
let sα ← synthInstance (mkApp (mkConst ``Semiring [u]) α)
let (a', pa) ← evalAux a
let (b', pb) ← evalAux b
let la := Expr.getRevArg! a' 1
let some na ← la.natLit? | throwError "fail"
let lb := Expr.getRevArg! b' 1
let some nb ← lb.natLit? | throwError "fail"
let lc := mkRawNatLit (f na nb)
let c := mkOfNatLit u α nα lc
pure (c, mkApp10 (mkConst name [u]) α sα a b la lb lc pa pb (← mkEqRefl lc))
else throwError "fail"
evalC (name : Name) (f : Nat → Nat → Nat)
(args : Array Expr) : MetaM (Expr × Expr) := do
if let #[_, _, α, _, a, n] ← args then
let Level.succ u _ ← getLevel α | throwError "fail"
let nα ← synthInstance (mkApp (mkConst ``Numeric [u]) α)
let sα ← synthInstance (mkApp (mkConst ``Semiring [u]) α)
let (a', pa) ← evalAux a
let la := Expr.getRevArg! a' 1
let some na ← la.natLit? | throwError "fail"
let some nn ← n.numeral? | throwError "fail"
let lc := mkRawNatLit (f na nn)
let c := mkOfNatLit u α nα lc
pure (c, mkApp8 (mkConst name [u]) α sα a n la lc pa (← mkEqRefl lc))
(e, p)
| some _ => pure (e, ← mkEqRefl e)
| none => throwError "fail"
| _ =>
if e.isNatLit then
(mkOfNatLit levelZero (mkConst ``Nat) (mkConst ``Nat.instNumericNat) e,
mkApp (mkConst ``ofNat_nat) e)
else throwError "fail"
where
evalB (name : Name) (f : Nat → Nat → Nat) (α a b : Expr) : MetaM (Expr × Expr) := do
let Level.succ u _ ← getLevel α | throwError "fail"
let nα ← synthInstance (mkApp (mkConst ``Numeric [u]) α)
let sα ← synthInstance (mkApp (mkConst ``Semiring [u]) α)
let (a', pa) ← evalAux a
let (b', pb) ← evalAux b
let la := Expr.getRevArg! a' 1
let lb := Expr.getRevArg! b' 1
let lc := mkRawNatLit (f la.natLit! lb.natLit!)
let c := mkOfNatLit u α nα lc
(c, mkApp10 (mkConst name [u]) α sα a b la lb lc pa pb (← mkEqRefl lc))
evalC (name : Name) (f : Nat → Nat → Nat) (α a n : Expr) : MetaM (Expr × Expr) := do
let Level.succ u _ ← getLevel α | throwError "fail"
let nα ← synthInstance (mkApp (mkConst ``Numeric [u]) α)
let sα ← synthInstance (mkApp (mkConst ``Semiring [u]) α)
let (a', pa) ← evalAux a
let la := Expr.getRevArg! a' 1
let some nn ← n.numeral? | throwError "fail"
let lc := mkRawNatLit (f la.natLit! nn)
let c := mkOfNatLit u α nα lc
(c, mkApp8 (mkConst name [u]) α sα a n la lc pa (← mkEqRefl lc))

partial def eval (e : Expr) : MetaM (Expr × Expr) := do
def eval (e : Expr) : MetaM (Expr × Expr) := do
let (e', p) ← evalAux e
e'.withApp fun f args => do
if f.isConstOf ``OfNat.ofNat then
Expand All @@ -118,16 +113,16 @@ partial def eval (e : Expr) : MetaM (Expr × Expr) := do
let nα ← synthInstance (mkApp2 (mkConst ``OfNat [u]) α (mkRawNatLit 0))
let e'' ← mkApp3 (mkConst ``OfNat.ofNat [u]) α (mkRawNatLit 0) nα
let p' ← mkEqTrans p (mkApp2 (mkConst ``Semiring.ofNat_zero [u]) α sα)
return (e'',p')
(e'', p')
else if n = 1 then
let Level.succ u _ ← getLevel α | throwError "fail"
let sα ← synthInstance (mkApp (mkConst ``Semiring [u]) α)
let nα ← synthInstance (mkApp2 (mkConst ``OfNat [u]) α (mkRawNatLit 1))
let e'' ← mkApp3 (mkConst ``OfNat.ofNat [u]) α (mkRawNatLit 1) nα
let p' ← mkEqTrans p (mkApp2 (mkConst ``Semiring.ofNat_one [u]) α sα)
return (e'',p')
else pure (e',p)
else pure (e', p)
(e'', p')
else (e', p)
else (e', p)
end NormNum
end Meta

Expand All @@ -150,3 +145,4 @@ example : (1 + 0 : α) = 1 := by normNum
example : (0 + (2 + 3) + 1 : α) = 6 := by normNum
example : (70 * (33 + 2) : α) = 2450 := by normNum
example : (8 + 2 ^ 2 * 3 : α) = 20 := by normNum
example : (2 ^ 2 : ℕ) = 4 := by normNum
56 changes: 20 additions & 36 deletions Mathlib/Tactic/Ring.lean
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,6 @@ Based on <http://www.cs.ru.nl/~freek/courses/tt-2014/read/10.1.1.61.3041.pdf> .

open Lean Parser.Tactic Elab Command Elab.Tactic Meta

open Expr in
private def getAppFnAndArgsAux : Expr → Array Expr → Nat → Option (Name × Array Expr)
| app f a _, as, i => getAppFnAndArgsAux f (as.set! i a) (i-1)
| const n _ _, as, i => some (n,as)
| _, as, _ => none

def Lean.Expr.getAppFnAndArgs (e : Expr) : Option (Name × Array Expr) :=
let dummy := mkSort levelZero
let nargs := e.getAppNumArgs
getAppFnAndArgsAux e (mkArray nargs dummy) (nargs-1)

namespace Tactic
namespace Ring

Expand All @@ -41,10 +30,9 @@ structure State :=

/-- The monad that `ring` works in. This is a reader monad containing a cache and
the list of atoms-up-to-defeq encountered thus far, used for atom sorting. -/
abbrev RingM := ReaderT Cache $ StateRefT State TacticM
abbrev RingM := ReaderT Cache $ StateRefT State MetaM

def run (e : Expr) {α} (m : RingM α): TacticM α := do
let ty ← inferType e
def RingM.run (ty : Expr) (m : RingM α) : MetaM α := do
let u ← getLevel ty
(m {α := ty, univ := u}).run' {}

Expand Down Expand Up @@ -115,7 +103,7 @@ def xadd' (a : HornerExpr) (x : Expr × ℕ) (n : Expr × ℕ) (b : HornerExpr)
def reflConv (e : HornerExpr) : RingM (HornerExpr × Expr) := do (e, ← mkEqRefl e)

/-- Pretty printer for `horner_expr`. -/
def pp : HornerExpr → TacticM Format
def pp : HornerExpr → MetaM Format
| (const e c) => do
let pe ← PrettyPrinter.ppExpr Name.anonymous [] e
return "[" ++ pe ++ ", " ++ toString c ++ "]"
Expand Down Expand Up @@ -343,27 +331,27 @@ partial def evalPow : HornerExpr → Expr × ℕ → RingM (HornerExpr × Expr)
| e, (_, 0) => do
let α1 ← mkAppOptM ``OfNat.ofNat #[(← read).α, mkRawNatLit 1, none]
let p ← mkAppM ``pow_zero #[e]
return (const α1 1, p)
(const α1 1, p)
| e, (_, 1) => do
let p ← mkAppM ``pow_one #[e]
return (e, p)
| (const e coeff), (e₂, m) => do
(e, p)
| const e coeff, (e₂, m) => do
let (e', p) ← NormNum.eval $ ← mkAppM ``HPow.hPow #[e, e₂]
return (const e' (coeff ^ m), p)
(const e' (coeff ^ m), p)
| he@(xadd e a x n b), m =>
match b.e.numeral? with
| some 0 => do
let n' ← mkRawNatLit (n.2 * m.2)
let h₁ ← mkEqRefl n'
let (a', h₂) ← evalPow a m
let α0 ← mkAppOptM ``OfNat.ofNat #[(← read).α, mkRawNatLit 0, none]
return (← xadd' a' x (n', n.2 * m.2) (const α0 0),
(← xadd' a' x (n', n.2 * m.2) (const α0 0),
← mkAppM ``horner_pow #[a, x.1, n.1, m.1, n', a', h₁, h₂])
| _ => do
let e₂ ← mkRawNatLit (m.2 - 1)
let (tl, hl) ← evalPow he (e₂, m.2-1)
let (t, p₂) ← evalMul tl he
return (t, ← mkAppM ``pow_succ_eq #[e, e₂, tl, t, hl, p₂])
(t, ← mkAppM ``pow_succ_eq #[e, e₂, tl, t, hl, p₂])


theorem horner_atom {α} [CommSemiring α] (x : α) : x = horner 1 x 1 0 := by
Expand All @@ -374,8 +362,7 @@ def evalAtom (e : Expr) : RingM (HornerExpr × Expr) := do
let i ← addAtom e
let zero ← const (← mkAppOptM ``OfNat.ofNat #[(← read).α, mkRawNatLit 0, none]) 0
let one ← const (← mkAppOptM ``OfNat.ofNat #[(← read).α, mkRawNatLit 1, none]) 1
return (← xadd' one (e,i) (mkRawNatLit 1,1) zero, ← mkAppM ``horner_atom #[e])

(← xadd' one (e,i) (mkRawNatLit 1,1) zero, ← mkAppM ``horner_atom #[e])

theorem subst_into_add {α} [Add α] (l r tl tr t)
(prl : (l : α) = tl) (prr : r = tr) (prt : tl + tr = t) : l + r = t :=
Expand All @@ -390,20 +377,20 @@ theorem subst_into_pow {α} [Monoid α] (l r tl tr t)
by rw [prl, prr, prt]

partial def eval (e : Expr) : RingM (HornerExpr × Expr) :=
match e.getAppFnAndArgs with
| some (``HAdd.hAdd, #[_,_,_,_,e₁,e₂]) => do
match e.getAppFnArgs with
| (``HAdd.hAdd, #[_,_,_,_,e₁,e₂]) => do
let (e₁', p₁) ← eval e₁
let (e₂', p₂) ← eval e₂
let (e', p') ← evalAdd e₁' e₂'
let p ← mkAppM ``subst_into_add #[e₁, e₂, e₁', e₂', e', p₁, p₂, p']
(e',p)
| some (``HMul.hMul, #[_,_,_,_,e₁,e₂]) => do
| (``HMul.hMul, #[_,_,_,_,e₁,e₂]) => do
let (e₁', p₁) ← eval e₁
let (e₂', p₂) ← eval e₂
let (e', p') ← evalMul e₁' e₂'
let p ← mkAppM ``subst_into_mul #[e₁, e₂, e₁', e₂', e', p₁, p₂, p']
return (e', p)
| some (``HPow.hPow, #[_,_,_,P,e₁,e₂]) => do
| (``HPow.hPow, #[_,_,_,P,e₁,e₂]) => do
-- let (e₂', p₂) ← lift $ norm_num.derive e₂ <|> refl_conv e₂,
let (e₂', p₂) ← (e₂, ← mkEqRefl e₂)
match e₂'.numeral?, P.getAppFn with
Expand All @@ -412,24 +399,21 @@ partial def eval (e : Expr) : RingM (HornerExpr × Expr) :=
let (e', p') ← evalPow e₁' (e₂, k)
let p ← mkAppM ``subst_into_pow #[e₁, e₂, e₁', e₂', e', p₁, p₂, p']
return (e', p)
| _, _ => do ← evalAtom e
evalAtom e
| _, _ => evalAtom e
| _ =>
match e.numeral? with
| some n => (const e n).reflConv
| _ => (evalAtom e)

| _ => evalAtom e

elab "ring" : tactic => do
let g ← getMainTarget
match g.getAppFnAndArgs with
| some (`Eq, #[ty, e₁, e₂]) =>
let ((e₁', p₁), (e₂', p₂)) ← run e₁ $ Prod.mk <$> eval e₁ <*> eval e₂
if (← isDefEq e₁' e₂') then
match g.getAppFnArgs with
| (`Eq, #[ty, e₁, e₂]) =>
let ((e₁', p₁), (e₂', p₂)) ← RingM.run ty $ do (← eval e₁, ← eval e₂)
if ← isDefEq e₁' e₂' then
let p ← mkEqTrans p₁ (← mkEqSymm p₂)
ensureHasNoMVars p
assignExprMVar (← getMainGoal) p

replaceMainGoal []
else
throwError "failed \n{← e₁'.pp}\n{← e₂'.pp}"
Expand Down

0 comments on commit 05b74b0

Please sign in to comment.