Skip to content

Commit 5fbb5bb

Browse files
authored
refactor(Tactic/NormNum): change to isNat function (#49)
1 parent bd3ce2f commit 5fbb5bb

File tree

3 files changed

+95
-87
lines changed

3 files changed

+95
-87
lines changed

Mathlib/Algebra/Group/Defs.lean

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@ import Mathlib.Tactic.Spread
1818
class Zero (α : Type u) where
1919
zero : α
2020

21-
instance [Zero α] : OfNat α (nat_lit 0) where
21+
instance instOfNatZero [Zero α] : OfNat α (nat_lit 0) where
2222
ofNat := Zero.zero
2323

2424
class One (α : Type u) where
2525
one : α
2626

27-
instance [One α] : OfNat α (nat_lit 1) where
27+
instance instOfNatOne [One α] : OfNat α (nat_lit 1) where
2828
ofNat := One.one
2929

3030
class Inv (α : Type u) where

Mathlib/Algebra/Ring/Basic.lean

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ instance (R : Type u) [CommRing R] : CommSemiring R where
7777
namespace Nat
7878

7979
instance : Numeric Nat := ⟨id⟩
80-
@[simp] theorem ofNat_eq_Nat (n : Nat): Numeric.ofNat n = n := rfl
80+
@[simp] theorem ofNat_eq_Nat (n : Nat) : Numeric.ofNat n = n := rfl
8181

8282
instance : CommSemiring Nat where
8383
mul_comm := Nat.mul_comm
@@ -89,7 +89,7 @@ instance : CommSemiring Nat where
8989
ofNat_zero := rfl
9090
mul_one := Nat.mul_one
9191
one_mul := Nat.one_mul
92-
npow (n x) := HPow.hPow x n
92+
npow (n x) := x ^ n
9393
npow_zero' := Nat.pow_zero
9494
npow_succ' n x := by simp [Nat.pow_succ, Nat.mul_comm]
9595
one := 1
@@ -99,7 +99,7 @@ instance : CommSemiring Nat where
9999
add_assoc := Nat.add_assoc
100100
add_zero := Nat.add_zero
101101
zero_add := Nat.zero_add
102-
nsmul := HMul.hMul
102+
nsmul := (·*·)
103103
nsmul_zero' := Nat.zero_mul
104104
nsmul_succ' n x := by simp [Nat.add_comm, (Nat.succ_mul n x)]
105105
zero_mul := Nat.zero_mul

Mathlib/Tactic/NormNum.lean

Lines changed: 90 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -35,94 +35,105 @@ def mkOfNatLit (u : Level) (α sα n : Expr) : Expr :=
3535

3636
namespace NormNum
3737

38-
theorem ofNat_nat (n : ℕ) : n = @OfNat.ofNat _ n (@Numeric.OfNat _ _ _) := rfl
39-
set_option pp.all true
38+
def isNat [Semiring α] (a : α) (n : ℕ) := a = OfNat.ofNat n
4039

41-
theorem ofNat_add {α} [Semiring α] : (a b : α) → (a' b' c : Nat) →
42-
a = OfNat.ofNat a' → b = OfNat.ofNat b' → a' + b' = c → a + b = OfNat.ofNat c
40+
class LawfulOfNat (α) [Semiring α] (n) [OfNat α n] : Prop where
41+
isNat_ofNat : isNat (OfNat.ofNat n : α) n
42+
43+
instance (α) [Semiring α] : LawfulOfNat α n := ⟨rfl⟩
44+
instance (α) [Semiring α] : LawfulOfNat α (nat_lit 0) := ⟨Semiring.ofNat_zero.symm⟩
45+
instance (α) [Semiring α] : LawfulOfNat α (nat_lit 1) := ⟨Semiring.ofNat_one.symm⟩
46+
instance : LawfulOfNat Nat n := ⟨rfl⟩
47+
instance : LawfulOfNat Int n := ⟨rfl⟩
48+
49+
theorem isNat_rawNat (n : ℕ) : isNat n n := rfl
50+
51+
class LawfulZero (α) [Semiring α] [Zero α] : Prop where
52+
isNat_zero : isNat (Zero.zero : α) (nat_lit 0)
53+
54+
instance (α) [Semiring α] : LawfulZero α := ⟨Semiring.ofNat_zero.symm⟩
55+
56+
class LawfulOne (α) [Semiring α] [One α] : Prop where
57+
isNat_one : isNat (One.one : α) (nat_lit 1)
58+
59+
instance (α) [Semiring α] : LawfulOne α := ⟨Semiring.ofNat_one.symm⟩
60+
61+
theorem isNat_add {α} [Semiring α] : (a b : α) → (a' b' c : Nat) →
62+
isNat a a' → isNat b b' → Nat.add a' b' = c → isNat (a + b) c
4363
| _, _, _, _, _, rfl, rfl, rfl => (Semiring.ofNat_add _ _).symm
4464

45-
theorem ofNat_mul {α} [Semiring α] : (a b : α) → (a' b' c : Nat) →
46-
a = OfNat.ofNat a' → b = OfNat.ofNat b' → a' * b' = c → a * b = OfNat.ofNat c
65+
theorem isNat_mul {α} [Semiring α] : (a b : α) → (a' b' c : Nat) →
66+
isNat a a' → isNat b b' → Nat.mul a' b' = c → isNat (a * b) c
4767
| _, _, _, _, _, rfl, rfl, rfl => (Semiring.ofNat_mul _ _).symm
4868

49-
theorem ofNat_pow {α} [Semiring α] : (a : α) → (n a' c : Nat) →
50-
a = OfNat.ofNat a' → a'^n = c → a ^ n = OfNat.ofNat c
51-
| _, _, _, _, rfl, rfl => (Semiring.ofNat_pow _ _).symm
52-
53-
partial def eval' (e : Expr) : MetaM (Expr × Expr) := do
54-
match e.getAppFnArgs with
55-
| (``HAdd.hAdd, #[_, _, α, _, a, b]) => evalBinOp ``NormNum.ofNat_add (·+·) α a b
56-
| (``HMul.hMul, #[_, _, α, _, a, b]) => evalBinOp ``NormNum.ofNat_mul (·*·) α a b
57-
| (``HPow.hPow, #[_, _, α, _, a, n]) => evalPow ``NormNum.ofNat_pow (·^·) α a n
58-
| (``OfNat.ofNat, #[α, ln, _]) =>
59-
match ← ln.natLit? with
60-
| some 0 =>
61-
let Level.succ u _ ← getLevel α | throwError "fail"
62-
let nα ← synthInstance (mkApp (mkConst ``Numeric [u]) α)
63-
let sα ← synthInstance (mkApp (mkConst ``Semiring [u]) α)
64-
let e ← mkOfNatLit u α nα (mkRawNatLit 0)
65-
let p ← mkEqSymm (mkApp2 (mkConst ``Semiring.ofNat_zero [u]) α sα)
66-
(e, p)
67-
| some 1 =>
68-
let Level.succ u _ ← getLevel α | throwError "fail"
69-
let nα ← synthInstance (mkApp (mkConst ``Numeric [u]) α)
70-
let sα ← synthInstance (mkApp (mkConst ``Semiring [u]) α)
71-
let e ← mkOfNatLit u α nα (mkRawNatLit 1)
72-
let p ← mkEqSymm (mkApp2 (mkConst ``Semiring.ofNat_one [u]) α sα)
73-
(e, p)
74-
| some _ => pure (e, ← mkEqRefl e)
75-
| none => throwError "fail"
69+
theorem isNat_pow {α} [Semiring α] : (a : α) → (b a' b' c : Nat) →
70+
isNat a a' → isNat b b' → Nat.pow a' b' = c → isNat (a ^ b) c
71+
| _, _, _, _, _, rfl, rfl, rfl => (Semiring.ofNat_pow _ _).symm
72+
73+
def instSemiringNat : Semiring Nat := inferInstance
74+
75+
partial def evalIsNat (u : Level) (α sα e : Expr) : MetaM (Expr × Expr) := do
76+
let (n, p) ← match e.getAppFnArgs with
77+
| (``HAdd.hAdd, #[_, _, _, _, a, b]) => evalBinOp ``NormNum.isNat_add (·+·) a b
78+
| (``HMul.hMul, #[_, _, _, _, a, b]) => evalBinOp ``NormNum.isNat_mul (·*·) a b
79+
| (``HPow.hPow, #[_, _, _, _, a, b]) => evalPow ``NormNum.isNat_pow (·^·) a b
80+
| (``OfNat.ofNat, #[_, ln, inst]) =>
81+
let some n ← ln.natLit? | throwError "fail"
82+
let lawful ← synthInstance (mkApp4 (mkConst ``LawfulOfNat [u]) α sα ln inst)
83+
(ln, mkApp5 (mkConst ``LawfulOfNat.isNat_ofNat [u]) α sα ln inst lawful)
84+
| (``Zero.zero, #[_, inst]) =>
85+
let lawful ← synthInstance (mkApp3 (mkConst ``LawfulZero [u]) α sα inst)
86+
(mkNatLit 0, mkApp4 (mkConst ``LawfulZero.isNat_zero [u]) α sα inst lawful)
87+
| (``One.one, #[_, inst]) =>
88+
let lawful ← synthInstance (mkApp3 (mkConst ``LawfulOne [u]) α sα inst)
89+
(mkNatLit 1, mkApp4 (mkConst ``LawfulOne.isNat_one [u]) α sα inst lawful)
7690
| _ =>
77-
if e.isNatLit then
78-
(mkOfNatLit levelZero (mkConst ``Nat) (mkConst ``Nat.instNumericNat) e,
79-
mkApp (mkConst ``ofNat_nat) e)
91+
if e.isNatLit then (e, mkApp (mkConst ``isNat_rawNat) e)
8092
else throwError "fail"
93+
(n, mkApp2 (mkConst ``id [levelZero]) (mkApp4 (mkConst ``isNat [u]) α sα e n) p)
8194
where
82-
evalBinOp (name : Name) (f : Nat → Nat → Nat) (α a b : Expr) : MetaM (Expr × Expr) := do
83-
let Level.succ u _ ← getLevel α | throwError "fail"
84-
let nα ← synthInstance (mkApp (mkConst ``Numeric [u]) α)
85-
let sα ← synthInstance (mkApp (mkConst ``Semiring [u]) α)
86-
let (a', pa) ← eval' a
87-
let (b', pb) ← eval' b
88-
let la := Expr.getRevArg! a' 1
89-
let lb := Expr.getRevArg! b' 1
90-
let lc := mkRawNatLit (f la.natLit! lb.natLit!)
91-
let c := mkOfNatLit u α nα lc
92-
(c, mkApp10 (mkConst name [u]) α sα a b la lb lc pa pb (← mkEqRefl lc))
93-
evalPow (name : Name) (f : Nat → Nat → Nat) (α a n : Expr) : MetaM (Expr × Expr) := do
94-
let Level.succ u _ ← getLevel α | throwError "fail"
95-
let nα ← synthInstance (mkApp (mkConst ``Numeric [u]) α)
96-
let sα ← synthInstance (mkApp (mkConst ``Semiring [u]) α)
97-
let (a', pa) ← eval' a
98-
let la := Expr.getRevArg! a' 1
99-
let some nn ← n.numeral? | throwError "fail"
100-
let lc := mkRawNatLit (f la.natLit! nn)
101-
let c := mkOfNatLit u α nα lc
102-
(c, mkApp8 (mkConst name [u]) α sα a n la lc pa (← mkEqRefl lc))
95+
evalBinOp (name : Name) (f : Nat → Nat → Nat) (a b : Expr) : MetaM (Expr × Expr) := do
96+
let (la, pa) ← evalIsNat u α sα a
97+
let (lb, pb) ← evalIsNat u α sα b
98+
let a' := la.natLit!
99+
let b' := lb.natLit!
100+
let c' := f a' b'
101+
let lc := mkRawNatLit c'
102+
(lc, mkApp10 (mkConst name [u]) α sα a b la lb lc pa pb (← mkEqRefl lc))
103+
evalPow (name : Name) (f : Nat → Nat → Nat) (a b : Expr) : MetaM (Expr × Expr) := do
104+
let (la, pa) ← evalIsNat u α sα a
105+
let (lb, pb) ← evalIsNat levelZero (mkConst ``Nat) (mkConst ``instSemiringNat) b
106+
let a' := la.natLit!
107+
let b' := lb.natLit!
108+
let c' := f a' b'
109+
let lc := mkRawNatLit c'
110+
(lc, mkApp10 (mkConst name [u]) α sα a b la lb lc pa pb (← mkEqRefl lc))
111+
112+
theorem eval_of_isNat {α} [Semiring α] (n) [OfNat α n] [LawfulOfNat α n] :
113+
(a : α) → isNat a n → a = OfNat.ofNat n
114+
| _, rfl => LawfulOfNat.isNat_ofNat.symm
103115

104116
def eval (e : Expr) : MetaM (Expr × Expr) := do
105-
let (e', p) ← eval' e
106-
e'.withApp fun f args => do
107-
if f.isConstOf ``OfNat.ofNat then
108-
let #[α,ln,_] ← args | throwError "fail"
109-
let some n ← ln.natLit? | throwError "fail"
110-
if n = 0 then
111-
let Level.succ u _ ← getLevel α | throwError "fail"
112-
let sα ← synthInstance (mkApp (mkConst ``Semiring [u]) α)
113-
let nα ← synthInstance (mkApp2 (mkConst ``OfNat [u]) α (mkRawNatLit 0))
114-
let e'' ← mkApp3 (mkConst ``OfNat.ofNat [u]) α (mkRawNatLit 0) nα
115-
let p' ← mkEqTrans p (mkApp2 (mkConst ``Semiring.ofNat_zero [u]) α sα)
116-
(e'', p')
117-
else if n = 1 then
118-
let Level.succ u _ ← getLevel α | throwError "fail"
119-
let sα ← synthInstance (mkApp (mkConst ``Semiring [u]) α)
120-
let nα ← synthInstance (mkApp2 (mkConst ``OfNat [u]) α (mkRawNatLit 1))
121-
let e'' ← mkApp3 (mkConst ``OfNat.ofNat [u]) α (mkRawNatLit 1) nα
122-
let p' ← mkEqTrans p (mkApp2 (mkConst ``Semiring.ofNat_one [u]) α sα)
123-
(e'', p')
124-
else (e', p)
125-
else (e', p)
117+
let α ← inferType e
118+
let Level.succ u _ ← getLevel α | throwError "fail"
119+
let sα ← synthInstance (mkApp (mkConst ``Semiring [u]) α)
120+
let (ln, p) ← evalIsNat u α sα e
121+
let ofNatInst ← synthInstance (mkApp2 (mkConst ``OfNat [u]) α ln)
122+
let lawfulInst ← synthInstance (mkApp4 (mkConst ``LawfulOfNat [u]) α sα ln ofNatInst)
123+
(mkApp3 (mkConst ``OfNat.ofNat [u]) α ln ofNatInst,
124+
mkApp7 (mkConst ``eval_of_isNat [u]) α sα ln ofNatInst lawfulInst e p)
125+
126+
theorem eval_eq_of_isNat {α} [Semiring α] :
127+
(a b : α) → (n : ℕ) → isNat a n → isNat b n → a = b
128+
| _, _, _, rfl, rfl => rfl
129+
130+
def evalEq (α a b : Expr) : MetaM Expr := do
131+
let Level.succ u _ ← getLevel α | throwError "fail"
132+
let sα ← synthInstance (mkApp (mkConst ``Semiring [u]) α)
133+
let (ln, pa) ← evalIsNat u α sα a
134+
let (ln', pb) ← evalIsNat u α sα b
135+
guard (ln.natLit! == ln'.natLit!)
136+
mkApp7 (mkConst ``eval_eq_of_isNat [u]) α sα a b ln pa pb
126137

127138
end NormNum
128139
end Meta
@@ -134,10 +145,7 @@ open Meta Elab Tactic
134145
@[tactic normNum] def Tactic.evalNormNum : Tactic := fun stx =>
135146
liftMetaTactic fun g => do
136147
let some (α, lhs, rhs) ← matchEq? (← getMVarType g) | throwError "fail"
137-
let (lhs₂, lp) ← NormNum.eval' lhs
138-
let (rhs₂, rp) ← NormNum.eval' rhs
139-
unless ← isDefEq lhs₂ rhs₂ do throwError "fail"
140-
let p ← mkEqTrans lp (← mkEqSymm rp)
148+
let p ← NormNum.evalEq α lhs rhs
141149
assignExprMVar g p
142150
pure []
143151

0 commit comments

Comments
 (0)