Skip to content

Commit

Permalink
chore(data/matrix): generalisation linter (#13655)
Browse files Browse the repository at this point in the history
  • Loading branch information
ericrbg committed Apr 25, 2022
1 parent df4066c commit 070c21b
Showing 1 changed file with 79 additions and 68 deletions.
147 changes: 79 additions & 68 deletions src/data/matrix/basic.lean
Expand Up @@ -261,7 +261,7 @@ variables {n α R}
(diagonal d).map f = diagonal (λ m, f (d m)) :=
by { ext, simp only [diagonal, map_apply], split_ifs; simp [h], }

@[simp] lemma diagonal_conj_transpose [semiring α] [star_ring α] (v : n → α) :
@[simp] lemma diagonal_conj_transpose [add_monoid α] [star_add_monoid α] (v : n → α) :
(diagonal v)ᴴ = diagonal (star v) :=
begin
rw [conj_transpose, diagonal_transpose, diagonal_map (star_zero _)],
Expand Down Expand Up @@ -300,7 +300,7 @@ section numeral
@[simp] lemma bit0_apply [has_add α] (M : matrix m m α) (i : m) (j : m) :
(bit0 M) i j = bit0 (M i j) := rfl

variables [add_monoid α] [has_one α]
variables [add_zero_class α] [has_one α]

lemma bit1_apply (M : matrix n n α) (i : n) (j : n) :
(bit1 M) i j = if i = j then bit1 (M i j) else bit0 (M i j) :=
Expand Down Expand Up @@ -337,7 +337,7 @@ lemma dot_product_assoc [fintype n] [non_unital_semiring α] (u : m → α) (w :
(λ j, u ⬝ᵥ (λ i, v i j)) ⬝ᵥ w = u ⬝ᵥ (λ i, (v i) ⬝ᵥ w) :=
by simpa [dot_product, finset.mul_sum, finset.sum_mul, mul_assoc] using finset.sum_comm

lemma dot_product_comm [comm_semiring α] (v w : m → α) :
lemma dot_product_comm [add_comm_monoid α] [comm_semigroup α] (v w : m → α) :
v ⬝ᵥ w = w ⬝ᵥ v :=
by simp_rw [dot_product, mul_comm]

Expand Down Expand Up @@ -465,6 +465,21 @@ lemma sum_apply [add_comm_monoid α] (i : m) (j : n)
(∑ c in s, g c) i j = ∑ c in s, g c i j :=
(congr_fun (s.sum_apply i g) j).trans (s.sum_apply j _)

section add_comm_monoid

variables [add_comm_monoid α] [has_mul α]

@[simp] lemma smul_mul [fintype n] [monoid R] [distrib_mul_action R α] [is_scalar_tower R α α]
(a : R) (M : matrix m n α) (N : matrix n l α) :
(a • M) ⬝ N = a • M ⬝ N :=
by { ext, apply smul_dot_product }

@[simp] lemma mul_smul [fintype n] [monoid R] [distrib_mul_action R α] [smul_comm_class R α α]
(M : matrix m n α) (a : R) (N : matrix n l α) : M ⬝ (a • N) = a • M ⬝ N :=
by { ext, apply dot_product_smul }

end add_comm_monoid

section non_unital_non_assoc_semiring
variables [non_unital_non_assoc_semiring α]

Expand Down Expand Up @@ -506,6 +521,10 @@ theorem diagonal_mul_diagonal' [fintype n] [decidable_eq n] (d₁ d₂ : n →
diagonal d₁ * diagonal d₂ = diagonal (λ i, d₁ i * d₂ i) :=
diagonal_mul_diagonal _ _

lemma smul_eq_diagonal_mul [fintype m] [decidable_eq m] (M : matrix m n α) (a : α) :
a • M = diagonal (λ _, a) ⬝ M :=
by { ext, simp }

/-- Left multiplication by a matrix, as an `add_monoid_hom` from matrices to matrices. -/
@[simps] def add_monoid_hom_mul_left [fintype m] (M : matrix l m α) :
matrix m n α →+ matrix l n α :=
Expand All @@ -528,6 +547,16 @@ protected lemma mul_sum [fintype m] (s : finset β) (f : β → matrix m n α)
(M : matrix l m α) : M ⬝ ∑ a in s, f a = ∑ a in s, M ⬝ f a :=
(add_monoid_hom_mul_left M : matrix m n α →+ _).map_sum f s

/-- This instance enables use with `smul_mul_assoc`. -/
instance semiring.is_scalar_tower [fintype n] [monoid R] [distrib_mul_action R α]
[is_scalar_tower R α α] : is_scalar_tower R (matrix n n α) (matrix n n α) :=
⟨λ r m n, matrix.smul_mul r m n⟩

/-- This instance enables use with `mul_smul_comm`. -/
instance semiring.smul_comm_class [fintype n] [monoid R] [distrib_mul_action R α]
[smul_comm_class R α α] : smul_comm_class R (matrix n n α) (matrix n n α) :=
⟨λ r m n, (matrix.mul_smul m r n).symm⟩

end non_unital_non_assoc_semiring

section non_assoc_semiring
Expand Down Expand Up @@ -620,29 +649,6 @@ instance [fintype n] [decidable_eq n] [ring α] : ring (matrix n n α) :=
section semiring
variables [semiring α]

lemma smul_eq_diagonal_mul [fintype m] [decidable_eq m] (M : matrix m n α) (a : α) :
a • M = diagonal (λ _, a) ⬝ M :=
by { ext, simp }

@[simp] lemma smul_mul [fintype n] [monoid R] [distrib_mul_action R α] [is_scalar_tower R α α]
(a : R) (M : matrix m n α) (N : matrix n l α) :
(a • M) ⬝ N = a • M ⬝ N :=
by { ext, apply smul_dot_product }

/-- This instance enables use with `smul_mul_assoc`. -/
instance semiring.is_scalar_tower [fintype n] [monoid R] [distrib_mul_action R α]
[is_scalar_tower R α α] : is_scalar_tower R (matrix n n α) (matrix n n α) :=
⟨λ r m n, matrix.smul_mul r m n⟩

@[simp] lemma mul_smul [fintype n] [monoid R] [distrib_mul_action R α] [smul_comm_class R α α]
(M : matrix m n α) (a : R) (N : matrix n l α) : M ⬝ (a • N) = a • M ⬝ N :=
by { ext, apply dot_product_smul }

/-- This instance enables use with `mul_smul_comm`. -/
instance semiring.smul_comm_class [fintype n] [monoid R] [distrib_mul_action R α]
[smul_comm_class R α α] : smul_comm_class R (matrix n n α) (matrix n n α) :=
⟨λ r m n, (matrix.mul_smul m r n).symm⟩

@[simp] lemma mul_mul_left [fintype n] (M : matrix m n α) (N : matrix n o α) (a : α) :
(λ i j, a * M i j) ⬝ N = a • (M ⬝ N) :=
smul_mul a M N
Expand Down Expand Up @@ -979,6 +985,10 @@ namespace matrix
def vec_mul_vec [has_mul α] (w : m → α) (v : n → α) : matrix m n α
| x y := w x * v y

lemma vec_mul_vec_eq [has_mul α] [add_comm_monoid α] (w : m → α) (v : n → α) :
vec_mul_vec w v = (col w) ⬝ (row v) :=
by { ext i j, simp only [vec_mul_vec, mul_apply, fintype.univ_punit, finset.sum_singleton], refl }

section non_unital_non_assoc_semiring
variables [non_unital_non_assoc_semiring α]

Expand Down Expand Up @@ -1027,10 +1037,6 @@ by { ext, simp [mul_vec] }
@[simp] lemma vec_mul_zero [fintype m] (v : m → α) : vec_mul v (0 : matrix m n α) = 0 :=
by { ext, simp [vec_mul] }

lemma vec_mul_vec_eq (w : m → α) (v : n → α) :
vec_mul_vec w v = (col w) ⬝ (row v) :=
by { ext i j, simp [vec_mul_vec, mul_apply], refl }

lemma smul_mul_vec_assoc [fintype n] [monoid R] [distrib_mul_action R α] [is_scalar_tower R α α]
(a : R) (A : matrix m n α) (b : n → α) :
(a • A).mul_vec b = a • (A.mul_vec b) :=
Expand All @@ -1052,13 +1058,13 @@ lemma add_vec_mul [fintype m] (A : matrix m n α) (x y : m → α) :
vec_mul (x + y) A = vec_mul x A + vec_mul y A :=
by { ext, apply add_dot_product }

lemma vec_mul_smul [fintype m] [comm_semiring R] [semiring S] [algebra R S]
(M : matrix m n S) (b : R) (v : m → S) :
lemma vec_mul_smul [fintype n] [monoid R] [non_unital_non_assoc_semiring S] [distrib_mul_action R S]
[is_scalar_tower R S S] (M : matrix n m S) (b : R) (v : n → S) :
M.vec_mul (b • v) = b • M.vec_mul v :=
by { ext i, simp only [vec_mul, dot_product, finset.smul_sum, pi.smul_apply, smul_mul_assoc] }

lemma mul_vec_smul [fintype n] [comm_semiring R] [semiring S] [algebra R S]
(M : matrix m n S) (b : R) (v : n → S) :
lemma mul_vec_smul [fintype n] [monoid R] [non_unital_non_assoc_semiring S] [distrib_mul_action R S]
[smul_comm_class R S S] (M : matrix m n S) (b : R) (v : n → S) :
M.mul_vec (b • v) = b • M.mul_vec v :=
by { ext i, simp only [mul_vec, dot_product, finset.smul_sum, pi.smul_apply, mul_smul_comm] }

Expand Down Expand Up @@ -1158,8 +1164,8 @@ by { ext i j, simp }
(M - N)ᵀ = Mᵀ - Nᵀ :=
by { ext i j, simp }

@[simp] lemma transpose_mul [comm_semiring α] [fintype n] (M : matrix m n α) (N : matrix n l α) :
(M ⬝ N)ᵀ = Nᵀ ⬝ Mᵀ :=
@[simp] lemma transpose_mul [add_comm_monoid α] [comm_semigroup α] [fintype n]
(M : matrix m n α) (N : matrix n l α) : (M ⬝ N)ᵀ = Nᵀ ⬝ Mᵀ :=
begin
ext i j,
apply dot_product_comm
Expand Down Expand Up @@ -1202,7 +1208,8 @@ lemma transpose_sum [add_comm_monoid α] {ι : Type*} (s : finset ι) (M : ι

/-- `matrix.transpose` as a `ring_equiv` to the opposite ring -/
@[simps]
def transpose_ring_equiv [comm_semiring α] [fintype m] : matrix m m α ≃+* (matrix m m α)ᵐᵒᵖ :=
def transpose_ring_equiv [add_comm_monoid α] [comm_semigroup α] [fintype m] :
matrix m m α ≃+* (matrix m m α)ᵐᵒᵖ :=
{ to_fun := λ M, mul_opposite.op (Mᵀ),
inv_fun := λ M, M.unopᵀ,
map_mul' := λ M N, (congr_arg mul_opposite.op (transpose_mul M N)).trans
Expand All @@ -1229,30 +1236,34 @@ open_locale matrix

@[simp] lemma conj_transpose_conj_transpose [has_involutive_star α] (M : matrix m n α) :
Mᴴᴴ = M :=
by ext; simp
matrix.ext $ by simp

@[simp] lemma conj_transpose_zero [semiring α] [star_ring α] : (0 : matrix m n α)ᴴ = 0 :=
by ext i j; simp
@[simp] lemma conj_transpose_zero [add_monoid α] [star_add_monoid α] : (0 : matrix m n α)ᴴ = 0 :=
matrix.ext $ by simp

@[simp] lemma conj_transpose_one [decidable_eq n] [semiring α] [star_ring α]:
(1 : matrix n n α)ᴴ = 1 :=
by simp [conj_transpose]

@[simp] lemma conj_transpose_add [add_monoid α] [star_add_monoid α] (M N : matrix m n α) :
(M + N)ᴴ = Mᴴ + Nᴴ := by ext i j; simp
(M + N)ᴴ = Mᴴ + Nᴴ :=
matrix.ext $ by simp

@[simp] lemma conj_transpose_sub [add_group α] [star_add_monoid α] (M N : matrix m n α) :
(M - N)ᴴ = Mᴴ - Nᴴ := by ext i j; simp
(M - N)ᴴ = Mᴴ - Nᴴ :=
matrix.ext $ by simp

@[simp] lemma conj_transpose_smul [comm_monoid α] [star_semigroup α] (c : α) (M : matrix m n α) :
@[simp] lemma conj_transpose_smul [comm_semigroup α] [star_semigroup α] (c : α) (M : matrix m n α) :
(c • M)ᴴ = (star c) • Mᴴ :=
by ext i j; simp [mul_comm]
matrix.ext $ by simp

@[simp] lemma conj_transpose_mul [fintype n] [semiring α] [star_ring α]
(M : matrix m n α) (N : matrix n l α) : (M ⬝ N)ᴴ = Nᴴ ⬝ Mᴴ := by ext i j; simp [mul_apply]
@[simp] lemma conj_transpose_mul [fintype n] [non_unital_semiring α] [star_ring α]
(M : matrix m n α) (N : matrix n l α) : (M ⬝ N)ᴴ = Nᴴ ⬝ Mᴴ :=
matrix.ext $ by simp [mul_apply]

@[simp] lemma conj_transpose_neg [non_unital_ring α] [star_ring α] (M : matrix m n α) :
(- M)ᴴ = - Mᴴ := by ext i j; simp
@[simp] lemma conj_transpose_neg [add_group α] [star_add_monoid α] (M : matrix m n α) :
(- M)ᴴ = - Mᴴ :=
matrix.ext $ by simp

/-- `matrix.conj_transpose` as an `add_equiv` -/
@[simps apply]
Expand Down Expand Up @@ -1321,7 +1332,7 @@ instance [fintype n] [semiring α] [star_ring α] : star_ring (matrix n n α) :=
star_mul := conj_transpose_mul, }

/-- A version of `star_mul` for `⬝` instead of `*`. -/
lemma star_mul [fintype n] [semiring α] [star_ring α] (M N : matrix n n α) :
lemma star_mul [fintype n] [non_unital_semiring α] [star_ring α] (M N : matrix n n α) :
star (M ⬝ N) = star N ⬝ star M := conj_transpose_mul _ _

end star
Expand Down Expand Up @@ -1363,12 +1374,10 @@ lemma minor_neg [has_neg α] (A : matrix m n α) :
lemma minor_sub [has_sub α] (A B : matrix m n α) :
((A - B).minor : (l → m) → (o → n) → matrix l o α) = A.minor - B.minor := rfl

@[simp]
lemma minor_zero [has_zero α] :
@[simp] lemma minor_zero [has_zero α] :
((0 : matrix m n α).minor : (l → m) → (o → n) → matrix l o α) = 0 := rfl

lemma minor_smul {R : Type*} [semiring R] [add_comm_monoid α] [module R α] (r : R)
(A : matrix m n α) :
lemma minor_smul {R : Type*} [has_scalar R α] (r : R) (A : matrix m n α) :
((r • A : matrix m n α).minor : (l → m) → (o → n) → matrix l o α) = r • A.minor := rfl

lemma minor_map (f : α → β) (e₁ : l → m) (e₂ : o → n) (A : matrix m n α) :
Expand All @@ -1391,7 +1400,7 @@ lemma minor_one [has_zero α] [has_one α] [decidable_eq m] [decidable_eq l] (e
(1 : matrix m m α).minor e e = 1 :=
minor_diagonal _ e he

lemma minor_mul [fintype n] [fintype o] [semiring α] {p q : Type*}
lemma minor_mul [fintype n] [fintype o] [has_mul α] [add_comm_monoid α] {p q : Type*}
(M : matrix m n α) (N : matrix n p α)
(e₁ : l → m) (e₂ : o → n) (e₃ : q → p) (he₂ : function.bijective e₂) :
(M ⬝ N).minor e₁ e₃ = (M.minor e₁ e₂) ⬝ (N.minor e₂ e₃) :=
Expand Down Expand Up @@ -1424,13 +1433,14 @@ lemma minor_one_equiv [has_zero α] [has_one α] [decidable_eq m] [decidable_eq
minor_one e e.injective

@[simp]
lemma minor_mul_equiv [fintype n] [fintype o] [semiring α] {p q : Type*}
lemma minor_mul_equiv [fintype n] [fintype o] [add_comm_monoid α] [has_mul α] {p q : Type*}
(M : matrix m n α) (N : matrix n p α) (e₁ : l → m) (e₂ : o ≃ n) (e₃ : q → p) :
(M.minor e₁ e₂) ⬝ (N.minor e₂ e₃) = (M ⬝ N).minor e₁ e₃ :=
(minor_mul M N e₁ e₂ e₃ e₂.bijective).symm

lemma mul_minor_one [fintype n] [fintype o] [semiring α] [decidable_eq o] (e₁ : n ≃ o) (e₂ : l → o)
(M : matrix m n α) : M ⬝ (1 : matrix o o α).minor e₁ e₂ = minor M id (e₁.symm ∘ e₂) :=
lemma mul_minor_one [fintype n] [fintype o] [non_assoc_semiring α] [decidable_eq o] (e₁ : n ≃ o)
(e₂ : l → o) (M : matrix m n α) :
M ⬝ (1 : matrix o o α).minor e₁ e₂ = minor M id (e₁.symm ∘ e₂) :=
begin
let A := M.minor id e₁.symm,
have : M = A.minor id e₁,
Expand All @@ -1440,8 +1450,9 @@ begin
equiv.symm_comp_self],
end

lemma one_minor_mul [fintype m] [fintype o] [semiring α] [decidable_eq o] (e₁ : l → o) (e₂ : m ≃ o)
(M : matrix m n α) : ((1 : matrix o o α).minor e₁ e₂).mul M = minor M (e₂.symm ∘ e₁) id :=
lemma one_minor_mul [fintype m] [fintype o] [non_assoc_semiring α] [decidable_eq o] (e₁ : l → o)
(e₂ : m ≃ o) (M : matrix m n α) :
((1 : matrix o o α).minor e₁ e₂).mul M = minor M (e₂.symm ∘ e₁) id :=
begin
let A := M.minor e₂.symm id,
have : M = A.minor e₂ id,
Expand Down Expand Up @@ -1485,7 +1496,7 @@ lemma conj_transpose_reindex [has_star α] (eₘ : m ≃ l) (eₙ : n ≃ o) (M
rfl

@[simp]
lemma minor_mul_transpose_minor [fintype m] [fintype n] [semiring α]
lemma minor_mul_transpose_minor [fintype m] [fintype n] [add_comm_monoid α] [has_mul α]
(e : m ≃ n) (M : matrix m n α) :
(M.minor id e) ⬝ (Mᵀ).minor e id = M ⬝ Mᵀ :=
by rw [minor_mul_equiv, minor_id_id]
Expand Down Expand Up @@ -1562,13 +1573,13 @@ lemma conj_transpose_col [has_star α] (v : m → α) : (col v)ᴴ = row (star v
@[simp]
lemma conj_transpose_row [has_star α] (v : m → α) : (row v)ᴴ = col (star v) := by { ext, refl }

lemma row_vec_mul [fintype m] [semiring α] (M : matrix m n α) (v : m → α) :
lemma row_vec_mul [fintype m] [non_unital_non_assoc_semiring α] (M : matrix m n α) (v : m → α) :
matrix.row (matrix.vec_mul v M) = matrix.row v ⬝ M := by {ext, refl}
lemma col_vec_mul [fintype m] [semiring α] (M : matrix m n α) (v : m → α) :
lemma col_vec_mul [fintype m] [non_unital_non_assoc_semiring α] (M : matrix m n α) (v : m → α) :
matrix.col (matrix.vec_mul v M) = (matrix.row v ⬝ M)ᵀ := by {ext, refl}
lemma col_mul_vec [fintype n] [semiring α] (M : matrix m n α) (v : n → α) :
lemma col_mul_vec [fintype n] [non_unital_non_assoc_semiring α] (M : matrix m n α) (v : n → α) :
matrix.col (matrix.mul_vec M v) = M ⬝ matrix.col v := by {ext, refl}
lemma row_mul_vec [fintype n] [semiring α] (M : matrix m n α) (v : n → α) :
lemma row_mul_vec [fintype n] [non_unital_non_assoc_semiring α] (M : matrix m n α) (v : n → α) :
matrix.row (matrix.mul_vec M v) = (M ⬝ matrix.col v)ᵀ := by {ext, refl}

@[simp]
Expand Down Expand Up @@ -1695,22 +1706,22 @@ end update
end matrix

namespace ring_hom
variables [fintype n] [semiring α] [semiring β]
variables [fintype n] [non_assoc_semiring α] [non_assoc_semiring β]

lemma map_matrix_mul (M : matrix m n α) (N : matrix n o α) (i : m) (j : o) (f : α →+* β) :
f (matrix.mul M N i j) = matrix.mul (λ i j, f (M i j)) (λ i j, f (N i j)) i j :=
by simp [matrix.mul_apply, ring_hom.map_sum]

lemma map_dot_product [semiring R] [semiring S] (f : R →+* S) (v w : n → R) :
lemma map_dot_product [non_assoc_semiring R] [non_assoc_semiring S] (f : R →+* S) (v w : n → R) :
f (v ⬝ᵥ w) = (f ∘ v) ⬝ᵥ (f ∘ w) :=
by simp only [matrix.dot_product, f.map_sum, f.map_mul]

lemma map_vec_mul [semiring R] [semiring S]
lemma map_vec_mul [non_assoc_semiring R] [non_assoc_semiring S]
(f : R →+* S) (M : matrix n m R) (v : n → R) (i : m) :
f (M.vec_mul v i) = ((M.map f).vec_mul (f ∘ v) i) :=
by simp only [matrix.vec_mul, matrix.map_apply, ring_hom.map_dot_product]

lemma map_mul_vec [semiring R] [semiring S]
lemma map_mul_vec [non_assoc_semiring R] [non_assoc_semiring S]
(f : R →+* S) (M : matrix m n R) (v : n → R) (i : m) :
f (M.mul_vec v i) = ((M.map f).mul_vec (f ∘ v) i) :=
by simp only [matrix.mul_vec, matrix.map_apply, ring_hom.map_dot_product]
Expand Down

0 comments on commit 070c21b

Please sign in to comment.