Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(Tactic/NormNum): change to isNat function #49

Merged
merged 1 commit into from Sep 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions Mathlib/Algebra/Group/Defs.lean
Expand Up @@ -18,13 +18,13 @@ import Mathlib.Tactic.Spread
class Zero (α : Type u) where
zero : α

instance [Zero α] : OfNat α (nat_lit 0) where
instance instOfNatZero [Zero α] : OfNat α (nat_lit 0) where
ofNat := Zero.zero

class One (α : Type u) where
one : α

instance [One α] : OfNat α (nat_lit 1) where
instance instOfNatOne [One α] : OfNat α (nat_lit 1) where
ofNat := One.one

class Inv (α : Type u) where
Expand Down
6 changes: 3 additions & 3 deletions Mathlib/Algebra/Ring/Basic.lean
Expand Up @@ -77,7 +77,7 @@ instance (R : Type u) [CommRing R] : CommSemiring R where
namespace Nat

instance : Numeric Nat := ⟨id⟩
@[simp] theorem ofNat_eq_Nat (n : Nat): Numeric.ofNat n = n := rfl
@[simp] theorem ofNat_eq_Nat (n : Nat) : Numeric.ofNat n = n := rfl

instance : CommSemiring Nat where
mul_comm := Nat.mul_comm
Expand All @@ -89,7 +89,7 @@ instance : CommSemiring Nat where
ofNat_zero := rfl
mul_one := Nat.mul_one
one_mul := Nat.one_mul
npow (n x) := HPow.hPow x n
npow (n x) := x ^ n
npow_zero' := Nat.pow_zero
npow_succ' n x := by simp [Nat.pow_succ, Nat.mul_comm]
one := 1
Expand All @@ -99,7 +99,7 @@ instance : CommSemiring Nat where
add_assoc := Nat.add_assoc
add_zero := Nat.add_zero
zero_add := Nat.zero_add
nsmul := HMul.hMul
nsmul := (·*·)
nsmul_zero' := Nat.zero_mul
nsmul_succ' n x := by simp [Nat.add_comm, (Nat.succ_mul n x)]
zero_mul := Nat.zero_mul
Expand Down
172 changes: 90 additions & 82 deletions Mathlib/Tactic/NormNum.lean
Expand Up @@ -35,94 +35,105 @@ def mkOfNatLit (u : Level) (α sα n : Expr) : Expr :=

namespace NormNum

theorem ofNat_nat (n : ℕ) : n = @OfNat.ofNat _ n (@Numeric.OfNat _ _ _) := rfl
set_option pp.all true
def isNat [Semiring α] (a : α) (n : ℕ) := a = OfNat.ofNat n

theorem ofNat_add {α} [Semiring α] : (a b : α) → (a' b' c : Nat) →
a = OfNat.ofNat a' → b = OfNat.ofNat b' → a' + b' = c → a + b = OfNat.ofNat c
class LawfulOfNat (α) [Semiring α] (n) [OfNat α n] : Prop where
isNat_ofNat : isNat (OfNat.ofNat n : α) n

instance (α) [Semiring α] : LawfulOfNat α n := ⟨rfl⟩
instance (α) [Semiring α] : LawfulOfNat α (nat_lit 0) := ⟨Semiring.ofNat_zero.symm⟩
instance (α) [Semiring α] : LawfulOfNat α (nat_lit 1) := ⟨Semiring.ofNat_one.symm⟩
instance : LawfulOfNat Nat n := ⟨rfl⟩
instance : LawfulOfNat Int n := ⟨rfl⟩

theorem isNat_rawNat (n : ℕ) : isNat n n := rfl

class LawfulZero (α) [Semiring α] [Zero α] : Prop where
isNat_zero : isNat (Zero.zero : α) (nat_lit 0)

instance (α) [Semiring α] : LawfulZero α := ⟨Semiring.ofNat_zero.symm⟩

class LawfulOne (α) [Semiring α] [One α] : Prop where
isNat_one : isNat (One.one : α) (nat_lit 1)

instance (α) [Semiring α] : LawfulOne α := ⟨Semiring.ofNat_one.symm⟩

theorem isNat_add {α} [Semiring α] : (a b : α) → (a' b' c : Nat) →
isNat a a' → isNat b b' → Nat.add a' b' = c → isNat (a + b) c
| _, _, _, _, _, rfl, rfl, rfl => (Semiring.ofNat_add _ _).symm

theorem ofNat_mul {α} [Semiring α] : (a b : α) → (a' b' c : Nat) →
a = OfNat.ofNat a' → b = OfNat.ofNat b' → a' * b' = c → a * b = OfNat.ofNat c
theorem isNat_mul {α} [Semiring α] : (a b : α) → (a' b' c : Nat) →
isNat a a' → isNat b b' → Nat.mul a' b' = c → isNat (a * b) c
| _, _, _, _, _, rfl, rfl, rfl => (Semiring.ofNat_mul _ _).symm

theorem ofNat_pow {α} [Semiring α] : (a : α) → (n a' c : Nat) →
a = OfNat.ofNat a' → a'^n = c → a ^ n = OfNat.ofNat c
| _, _, _, _, rfl, rfl => (Semiring.ofNat_pow _ _).symm

partial def eval' (e : Expr) : MetaM (Expr × Expr) := do
match e.getAppFnArgs with
| (``HAdd.hAdd, #[_, _, α, _, a, b]) => evalBinOp ``NormNum.ofNat_add (·+·) α a b
| (``HMul.hMul, #[_, _, α, _, a, b]) => evalBinOp ``NormNum.ofNat_mul (·*·) α a b
| (``HPow.hPow, #[_, _, α, _, a, n]) => evalPow ``NormNum.ofNat_pow (·^·) α a n
| (``OfNat.ofNat, #[α, ln, _]) =>
match ← ln.natLit? with
| some 0 =>
let Level.succ u _ ← getLevel α | throwError "fail"
let nα ← synthInstance (mkApp (mkConst ``Numeric [u]) α)
let sα ← synthInstance (mkApp (mkConst ``Semiring [u]) α)
let e ← mkOfNatLit u α nα (mkRawNatLit 0)
let p ← mkEqSymm (mkApp2 (mkConst ``Semiring.ofNat_zero [u]) α sα)
(e, p)
| some 1 =>
let Level.succ u _ ← getLevel α | throwError "fail"
let nα ← synthInstance (mkApp (mkConst ``Numeric [u]) α)
let sα ← synthInstance (mkApp (mkConst ``Semiring [u]) α)
let e ← mkOfNatLit u α nα (mkRawNatLit 1)
let p ← mkEqSymm (mkApp2 (mkConst ``Semiring.ofNat_one [u]) α sα)
(e, p)
| some _ => pure (e, ← mkEqRefl e)
| none => throwError "fail"
theorem isNat_pow {α} [Semiring α] : (a : α) → (b a' b' c : Nat) →
isNat a a' → isNat b b' → Nat.pow a' b' = c → isNat (a ^ b) c
| _, _, _, _, _, rfl, rfl, rfl => (Semiring.ofNat_pow _ _).symm

def instSemiringNat : Semiring Nat := inferInstance

partial def evalIsNat (u : Level) (α sα e : Expr) : MetaM (Expr × Expr) := do
let (n, p) ← match e.getAppFnArgs with
| (``HAdd.hAdd, #[_, _, _, _, a, b]) => evalBinOp ``NormNum.isNat_add (·+·) a b
| (``HMul.hMul, #[_, _, _, _, a, b]) => evalBinOp ``NormNum.isNat_mul (·*·) a b
| (``HPow.hPow, #[_, _, _, _, a, b]) => evalPow ``NormNum.isNat_pow (·^·) a b
| (``OfNat.ofNat, #[_, ln, inst]) =>
let some n ← ln.natLit? | throwError "fail"
let lawful ← synthInstance (mkApp4 (mkConst ``LawfulOfNat [u]) α sα ln inst)
(ln, mkApp5 (mkConst ``LawfulOfNat.isNat_ofNat [u]) α sα ln inst lawful)
| (``Zero.zero, #[_, inst]) =>
let lawful ← synthInstance (mkApp3 (mkConst ``LawfulZero [u]) α sα inst)
(mkNatLit 0, mkApp4 (mkConst ``LawfulZero.isNat_zero [u]) α sα inst lawful)
| (``One.one, #[_, inst]) =>
let lawful ← synthInstance (mkApp3 (mkConst ``LawfulOne [u]) α sα inst)
(mkNatLit 1, mkApp4 (mkConst ``LawfulOne.isNat_one [u]) α sα inst lawful)
| _ =>
if e.isNatLit then
(mkOfNatLit levelZero (mkConst ``Nat) (mkConst ``Nat.instNumericNat) e,
mkApp (mkConst ``ofNat_nat) e)
if e.isNatLit then (e, mkApp (mkConst ``isNat_rawNat) e)
else throwError "fail"
(n, mkApp2 (mkConst ``id [levelZero]) (mkApp4 (mkConst ``isNat [u]) α sα e n) p)
where
evalBinOp (name : Name) (f : Nat → Nat → Nat) (α a b : Expr) : MetaM (Expr × Expr) := do
let Level.succ u _ ← getLevel α | throwError "fail"
let nα ← synthInstance (mkApp (mkConst ``Numeric [u]) α)
let sα ← synthInstance (mkApp (mkConst ``Semiring [u]) α)
let (a', pa) ← eval' a
let (b', pb) ← eval' b
let la := Expr.getRevArg! a' 1
let lb := Expr.getRevArg! b' 1
let lc := mkRawNatLit (f la.natLit! lb.natLit!)
let c := mkOfNatLit u α nα lc
(c, mkApp10 (mkConst name [u]) α sα a b la lb lc pa pb (← mkEqRefl lc))
evalPow (name : Name) (f : Nat → Nat → Nat) (α a n : Expr) : MetaM (Expr × Expr) := do
let Level.succ u _ ← getLevel α | throwError "fail"
let nα ← synthInstance (mkApp (mkConst ``Numeric [u]) α)
let sα ← synthInstance (mkApp (mkConst ``Semiring [u]) α)
let (a', pa) ← eval' a
let la := Expr.getRevArg! a' 1
let some nn ← n.numeral? | throwError "fail"
let lc := mkRawNatLit (f la.natLit! nn)
let c := mkOfNatLit u α nα lc
(c, mkApp8 (mkConst name [u]) α sα a n la lc pa (← mkEqRefl lc))
evalBinOp (name : Name) (f : Nat → Nat → Nat) (a b : Expr) : MetaM (Expr × Expr) := do
let (la, pa) ← evalIsNat u α sα a
let (lb, pb) ← evalIsNat u α sα b
let a' := la.natLit!
let b' := lb.natLit!
let c' := f a' b'
let lc := mkRawNatLit c'
(lc, mkApp10 (mkConst name [u]) α sα a b la lb lc pa pb (← mkEqRefl lc))
evalPow (name : Name) (f : Nat → Nat → Nat) (a b : Expr) : MetaM (Expr × Expr) := do
let (la, pa) ← evalIsNat u α sα a
let (lb, pb) ← evalIsNat levelZero (mkConst ``Nat) (mkConst ``instSemiringNat) b
let a' := la.natLit!
let b' := lb.natLit!
let c' := f a' b'
let lc := mkRawNatLit c'
(lc, mkApp10 (mkConst name [u]) α sα a b la lb lc pa pb (← mkEqRefl lc))

theorem eval_of_isNat {α} [Semiring α] (n) [OfNat α n] [LawfulOfNat α n] :
(a : α) → isNat a n → a = OfNat.ofNat n
| _, rfl => LawfulOfNat.isNat_ofNat.symm

def eval (e : Expr) : MetaM (Expr × Expr) := do
let (e', p) ← eval' e
e'.withApp fun f args => do
if f.isConstOf ``OfNat.ofNat then
let #[α,ln,_] ← args | throwError "fail"
let some n ← ln.natLit? | throwError "fail"
if n = 0 then
let Level.succ u _ ← getLevel α | throwError "fail"
let sα ← synthInstance (mkApp (mkConst ``Semiring [u]) α)
let nα ← synthInstance (mkApp2 (mkConst ``OfNat [u]) α (mkRawNatLit 0))
let e'' ← mkApp3 (mkConst ``OfNat.ofNat [u]) α (mkRawNatLit 0) nα
let p' ← mkEqTrans p (mkApp2 (mkConst ``Semiring.ofNat_zero [u]) α sα)
(e'', p')
else if n = 1 then
let Level.succ u _ ← getLevel α | throwError "fail"
let sα ← synthInstance (mkApp (mkConst ``Semiring [u]) α)
let nα ← synthInstance (mkApp2 (mkConst ``OfNat [u]) α (mkRawNatLit 1))
let e'' ← mkApp3 (mkConst ``OfNat.ofNat [u]) α (mkRawNatLit 1) nα
let p' ← mkEqTrans p (mkApp2 (mkConst ``Semiring.ofNat_one [u]) α sα)
(e'', p')
else (e', p)
else (e', p)
let α ← inferType e
let Level.succ u _ ← getLevel α | throwError "fail"
let sα ← synthInstance (mkApp (mkConst ``Semiring [u]) α)
let (ln, p) ← evalIsNat u α sα e
let ofNatInst ← synthInstance (mkApp2 (mkConst ``OfNat [u]) α ln)
let lawfulInst ← synthInstance (mkApp4 (mkConst ``LawfulOfNat [u]) α sα ln ofNatInst)
(mkApp3 (mkConst ``OfNat.ofNat [u]) α ln ofNatInst,
mkApp7 (mkConst ``eval_of_isNat [u]) α sα ln ofNatInst lawfulInst e p)

theorem eval_eq_of_isNat {α} [Semiring α] :
(a b : α) → (n : ℕ) → isNat a n → isNat b n → a = b
| _, _, _, rfl, rfl => rfl

def evalEq (α a b : Expr) : MetaM Expr := do
let Level.succ u _ ← getLevel α | throwError "fail"
let sα ← synthInstance (mkApp (mkConst ``Semiring [u]) α)
let (ln, pa) ← evalIsNat u α sα a
let (ln', pb) ← evalIsNat u α sα b
guard (ln.natLit! == ln'.natLit!)
mkApp7 (mkConst ``eval_eq_of_isNat [u]) α sα a b ln pa pb

end NormNum
end Meta
Expand All @@ -134,10 +145,7 @@ open Meta Elab Tactic
@[tactic normNum] def Tactic.evalNormNum : Tactic := fun stx =>
liftMetaTactic fun g => do
let some (α, lhs, rhs) ← matchEq? (← getMVarType g) | throwError "fail"
let (lhs₂, lp) ← NormNum.eval' lhs
let (rhs₂, rp) ← NormNum.eval' rhs
unless ← isDefEq lhs₂ rhs₂ do throwError "fail"
let p ← mkEqTrans lp (← mkEqSymm rp)
let p ← NormNum.evalEq α lhs rhs
assignExprMVar g p
pure []

Expand Down