Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Merged by Bors] - feat(data/fin): add_comm_monoid and simp lemmas #5010

Closed
wants to merge 7 commits into from
Closed
32 changes: 31 additions & 1 deletion src/data/fin.lean
Expand Up @@ -567,7 +567,7 @@ end
lemma nat_add_zero {n : ℕ} : fin.nat_add 0 = (fin.cast (zero_add n).symm).to_rel_embedding :=
by { ext, apply zero_add }

/-- `min n m` as an element of `fin (m + 1)`. -/
/-- `min n m` as an element of `fin (m + 1)` -/
def clamp (n m : ℕ) : fin (m + 1) := of_nat $ min n m

@[simp] lemma coe_clamp (n m : ℕ) : (clamp n m : ℕ) = min n m :=
Expand Down Expand Up @@ -1318,4 +1318,34 @@ by rw [← of_nat_eq_coe]; refl
(@fin.of_nat' _ I n : ℕ) = n % m :=
rfl

section monoid

@[simp] protected lemma add_zero (k : fin (n + 1)) : k + 0 = k :=
by simp [eq_iff_veq, add_def, mod_eq_of_lt (is_lt k)]

@[simp] protected lemma zero_add (k : fin (n + 1)) : (0 : fin (n + 1)) + k = k :=
by simp [eq_iff_veq, add_def, mod_eq_of_lt (is_lt k)]

@[simp] protected lemma mul_one (k : fin (n + 1)) : k * 1 = k :=
by { cases n, simp, simp [eq_iff_veq, mul_def, mod_eq_of_lt (is_lt k)] }

@[simp] protected lemma one_mul (k : fin (n + 1)) : (1 : fin (n + 1)) * k = k :=
by { cases n, simp, simp [eq_iff_veq, mul_def, mod_eq_of_lt (is_lt k)] }

@[simp] protected lemma mul_zero (k : fin (n + 1)) : k * 0 = 0 :=
by simp [eq_iff_veq, mul_def]

@[simp] protected lemma zero_mul (k : fin (n + 1)) : (0 : fin (n + 1)) * k = 0 :=
by simp [eq_iff_veq, mul_def]

instance add_comm_monoid (n : ℕ) : add_comm_monoid (fin (n + 1)) :=
{ add := (+),
add_assoc := by simp [eq_iff_veq, add_def, add_assoc],
zero := 0,
zero_add := fin.zero_add,
add_zero := fin.add_zero,
add_comm := by simp [eq_iff_veq, add_def, add_comm] }

end monoid

end fin
25 changes: 5 additions & 20 deletions src/data/zmod/basic.lean
Expand Up @@ -53,17 +53,6 @@ begin
rwa [← int.coe_nat_lt, nat_mod, to_nat_of_nonneg (int.mod_nonneg _ h)]
end⟩⟩

/-- Additive commutative semigroup structure on `fin (n+1)`. -/
def add_comm_semigroup (n : ℕ) : add_comm_semigroup (fin (n+1)) :=
{ add_assoc := λ ⟨a, ha⟩ ⟨b, hb⟩ ⟨c, hc⟩, fin.eq_of_veq
(show ((a + b) % (n+1) + c) ≡ (a + (b + c) % (n+1)) [MOD (n+1)],
from calc ((a + b) % (n+1) + c) ≡ a + b + c [MOD (n+1)] : modeq_add (nat.mod_mod _ _) rfl
... ≡ a + (b + c) [MOD (n+1)] : by rw add_assoc
... ≡ (a + (b + c) % (n+1)) [MOD (n+1)] : modeq_add rfl (nat.mod_mod _ _).symm),
add_comm := λ ⟨a, _⟩ ⟨b, _⟩,
fin.eq_of_veq (show (a + b) % (n+1) = (b + a) % (n+1), by rw add_comm),
..fin.has_add }

/-- Multiplicative commutative semigroup structure on `fin (n+1)`. -/
def comm_semigroup (n : ℕ) : comm_semigroup (fin (n+1)) :=
{ mul_assoc := λ ⟨a, ha⟩ ⟨b, hb⟩ ⟨c, hc⟩, fin.eq_of_veq
Expand All @@ -74,7 +63,7 @@ def comm_semigroup (n : ℕ) : comm_semigroup (fin (n+1)) :=
fin.eq_of_veq (show (a * b) % (n+1) = (b * a) % (n+1), by rw mul_comm),
..fin.has_mul }

local attribute [instance] fin.add_comm_semigroup fin.comm_semigroup
local attribute [instance] fin.comm_semigroup

private lemma one_mul_aux (n : ℕ) (a : fin (n+1)) : (1 : fin (n+1)) * a = a :=
begin
Expand All @@ -94,10 +83,7 @@ private lemma left_distrib_aux (n : ℕ) : ∀ a b c : fin (n+1), a * (b + c) =

/-- Commutative ring structure on `fin (n+1)`. -/
def comm_ring (n : ℕ) : comm_ring (fin (n+1)) :=
{ zero_add := λ ⟨a, ha⟩, fin.eq_of_veq (show (0 + a) % (n+1) = a,
by rw zero_add; exact nat.mod_eq_of_lt ha),
add_zero := λ ⟨a, ha⟩, fin.eq_of_veq (nat.mod_eq_of_lt ha),
add_left_neg :=
{ add_left_neg :=
λ ⟨a, ha⟩, fin.eq_of_veq (show (((-a : ℤ) % (n+1)).to_nat + a) % (n+1) = 0,
from int.coe_nat_inj
begin
Expand All @@ -106,14 +92,13 @@ def comm_ring (n : ℕ) : comm_ring (fin (n+1)) :=
rw [int.coe_nat_mod, int.coe_nat_add, to_nat_of_nonneg (int.mod_nonneg _ hn), add_comm],
simp,
end),
one_mul := one_mul_aux n,
mul_one := λ a, by rw mul_comm; exact one_mul_aux n a,
one_mul := fin.one_mul,
mul_one := fin.mul_one,
left_distrib := left_distrib_aux n,
right_distrib := λ a b c, by rw [mul_comm, left_distrib_aux, mul_comm _ b, mul_comm]; refl,
..fin.has_zero,
..fin.has_one,
..fin.has_neg (n+1),
..fin.add_comm_semigroup n,
..fin.add_comm_monoid n,
..fin.comm_semigroup n }

end fin
Expand Down