From 070c21b6b705d908cc9203fa0b93865563f16bb4 Mon Sep 17 00:00:00 2001 From: Eric Rodriguez Date: Mon, 25 Apr 2022 05:10:33 +0000 Subject: [PATCH] chore(data/matrix): generalisation linter (#13655) --- src/data/matrix/basic.lean | 147 ++++++++++++++++++++----------------- 1 file changed, 79 insertions(+), 68 deletions(-) diff --git a/src/data/matrix/basic.lean b/src/data/matrix/basic.lean index c272c62249e75..283fe8f91ec52 100644 --- a/src/data/matrix/basic.lean +++ b/src/data/matrix/basic.lean @@ -261,7 +261,7 @@ variables {n α R} (diagonal d).map f = diagonal (λ m, f (d m)) := by { ext, simp only [diagonal, map_apply], split_ifs; simp [h], } -@[simp] lemma diagonal_conj_transpose [semiring α] [star_ring α] (v : n → α) : +@[simp] lemma diagonal_conj_transpose [add_monoid α] [star_add_monoid α] (v : n → α) : (diagonal v)ᴴ = diagonal (star v) := begin rw [conj_transpose, diagonal_transpose, diagonal_map (star_zero _)], @@ -300,7 +300,7 @@ section numeral @[simp] lemma bit0_apply [has_add α] (M : matrix m m α) (i : m) (j : m) : (bit0 M) i j = bit0 (M i j) := rfl -variables [add_monoid α] [has_one α] +variables [add_zero_class α] [has_one α] lemma bit1_apply (M : matrix n n α) (i : n) (j : n) : (bit1 M) i j = if i = j then bit1 (M i j) else bit0 (M i j) := @@ -337,7 +337,7 @@ lemma dot_product_assoc [fintype n] [non_unital_semiring α] (u : m → α) (w : (λ j, u ⬝ᵥ (λ i, v i j)) ⬝ᵥ w = u ⬝ᵥ (λ i, (v i) ⬝ᵥ w) := by simpa [dot_product, finset.mul_sum, finset.sum_mul, mul_assoc] using finset.sum_comm -lemma dot_product_comm [comm_semiring α] (v w : m → α) : +lemma dot_product_comm [add_comm_monoid α] [comm_semigroup α] (v w : m → α) : v ⬝ᵥ w = w ⬝ᵥ v := by simp_rw [dot_product, mul_comm] @@ -465,6 +465,21 @@ lemma sum_apply [add_comm_monoid α] (i : m) (j : n) (∑ c in s, g c) i j = ∑ c in s, g c i j := (congr_fun (s.sum_apply i g) j).trans (s.sum_apply j _) +section add_comm_monoid + +variables [add_comm_monoid α] [has_mul α] + +@[simp] lemma smul_mul [fintype n] [monoid R] [distrib_mul_action R α] [is_scalar_tower R α α] + (a : R) (M : matrix m n α) (N : matrix n l α) : + (a • M) ⬝ N = a • M ⬝ N := +by { ext, apply smul_dot_product } + +@[simp] lemma mul_smul [fintype n] [monoid R] [distrib_mul_action R α] [smul_comm_class R α α] + (M : matrix m n α) (a : R) (N : matrix n l α) : M ⬝ (a • N) = a • M ⬝ N := +by { ext, apply dot_product_smul } + +end add_comm_monoid + section non_unital_non_assoc_semiring variables [non_unital_non_assoc_semiring α] @@ -506,6 +521,10 @@ theorem diagonal_mul_diagonal' [fintype n] [decidable_eq n] (d₁ d₂ : n → diagonal d₁ * diagonal d₂ = diagonal (λ i, d₁ i * d₂ i) := diagonal_mul_diagonal _ _ +lemma smul_eq_diagonal_mul [fintype m] [decidable_eq m] (M : matrix m n α) (a : α) : + a • M = diagonal (λ _, a) ⬝ M := +by { ext, simp } + /-- Left multiplication by a matrix, as an `add_monoid_hom` from matrices to matrices. -/ @[simps] def add_monoid_hom_mul_left [fintype m] (M : matrix l m α) : matrix m n α →+ matrix l n α := @@ -528,6 +547,16 @@ protected lemma mul_sum [fintype m] (s : finset β) (f : β → matrix m n α) (M : matrix l m α) : M ⬝ ∑ a in s, f a = ∑ a in s, M ⬝ f a := (add_monoid_hom_mul_left M : matrix m n α →+ _).map_sum f s +/-- This instance enables use with `smul_mul_assoc`. -/ +instance semiring.is_scalar_tower [fintype n] [monoid R] [distrib_mul_action R α] + [is_scalar_tower R α α] : is_scalar_tower R (matrix n n α) (matrix n n α) := +⟨λ r m n, matrix.smul_mul r m n⟩ + +/-- This instance enables use with `mul_smul_comm`. -/ +instance semiring.smul_comm_class [fintype n] [monoid R] [distrib_mul_action R α] + [smul_comm_class R α α] : smul_comm_class R (matrix n n α) (matrix n n α) := +⟨λ r m n, (matrix.mul_smul m r n).symm⟩ + end non_unital_non_assoc_semiring section non_assoc_semiring @@ -620,29 +649,6 @@ instance [fintype n] [decidable_eq n] [ring α] : ring (matrix n n α) := section semiring variables [semiring α] -lemma smul_eq_diagonal_mul [fintype m] [decidable_eq m] (M : matrix m n α) (a : α) : - a • M = diagonal (λ _, a) ⬝ M := -by { ext, simp } - -@[simp] lemma smul_mul [fintype n] [monoid R] [distrib_mul_action R α] [is_scalar_tower R α α] - (a : R) (M : matrix m n α) (N : matrix n l α) : - (a • M) ⬝ N = a • M ⬝ N := -by { ext, apply smul_dot_product } - -/-- This instance enables use with `smul_mul_assoc`. -/ -instance semiring.is_scalar_tower [fintype n] [monoid R] [distrib_mul_action R α] - [is_scalar_tower R α α] : is_scalar_tower R (matrix n n α) (matrix n n α) := -⟨λ r m n, matrix.smul_mul r m n⟩ - -@[simp] lemma mul_smul [fintype n] [monoid R] [distrib_mul_action R α] [smul_comm_class R α α] - (M : matrix m n α) (a : R) (N : matrix n l α) : M ⬝ (a • N) = a • M ⬝ N := -by { ext, apply dot_product_smul } - -/-- This instance enables use with `mul_smul_comm`. -/ -instance semiring.smul_comm_class [fintype n] [monoid R] [distrib_mul_action R α] - [smul_comm_class R α α] : smul_comm_class R (matrix n n α) (matrix n n α) := -⟨λ r m n, (matrix.mul_smul m r n).symm⟩ - @[simp] lemma mul_mul_left [fintype n] (M : matrix m n α) (N : matrix n o α) (a : α) : (λ i j, a * M i j) ⬝ N = a • (M ⬝ N) := smul_mul a M N @@ -979,6 +985,10 @@ namespace matrix def vec_mul_vec [has_mul α] (w : m → α) (v : n → α) : matrix m n α | x y := w x * v y +lemma vec_mul_vec_eq [has_mul α] [add_comm_monoid α] (w : m → α) (v : n → α) : + vec_mul_vec w v = (col w) ⬝ (row v) := +by { ext i j, simp only [vec_mul_vec, mul_apply, fintype.univ_punit, finset.sum_singleton], refl } + section non_unital_non_assoc_semiring variables [non_unital_non_assoc_semiring α] @@ -1027,10 +1037,6 @@ by { ext, simp [mul_vec] } @[simp] lemma vec_mul_zero [fintype m] (v : m → α) : vec_mul v (0 : matrix m n α) = 0 := by { ext, simp [vec_mul] } -lemma vec_mul_vec_eq (w : m → α) (v : n → α) : - vec_mul_vec w v = (col w) ⬝ (row v) := -by { ext i j, simp [vec_mul_vec, mul_apply], refl } - lemma smul_mul_vec_assoc [fintype n] [monoid R] [distrib_mul_action R α] [is_scalar_tower R α α] (a : R) (A : matrix m n α) (b : n → α) : (a • A).mul_vec b = a • (A.mul_vec b) := @@ -1052,13 +1058,13 @@ lemma add_vec_mul [fintype m] (A : matrix m n α) (x y : m → α) : vec_mul (x + y) A = vec_mul x A + vec_mul y A := by { ext, apply add_dot_product } -lemma vec_mul_smul [fintype m] [comm_semiring R] [semiring S] [algebra R S] - (M : matrix m n S) (b : R) (v : m → S) : +lemma vec_mul_smul [fintype n] [monoid R] [non_unital_non_assoc_semiring S] [distrib_mul_action R S] + [is_scalar_tower R S S] (M : matrix n m S) (b : R) (v : n → S) : M.vec_mul (b • v) = b • M.vec_mul v := by { ext i, simp only [vec_mul, dot_product, finset.smul_sum, pi.smul_apply, smul_mul_assoc] } -lemma mul_vec_smul [fintype n] [comm_semiring R] [semiring S] [algebra R S] - (M : matrix m n S) (b : R) (v : n → S) : +lemma mul_vec_smul [fintype n] [monoid R] [non_unital_non_assoc_semiring S] [distrib_mul_action R S] + [smul_comm_class R S S] (M : matrix m n S) (b : R) (v : n → S) : M.mul_vec (b • v) = b • M.mul_vec v := by { ext i, simp only [mul_vec, dot_product, finset.smul_sum, pi.smul_apply, mul_smul_comm] } @@ -1158,8 +1164,8 @@ by { ext i j, simp } (M - N)ᵀ = Mᵀ - Nᵀ := by { ext i j, simp } -@[simp] lemma transpose_mul [comm_semiring α] [fintype n] (M : matrix m n α) (N : matrix n l α) : - (M ⬝ N)ᵀ = Nᵀ ⬝ Mᵀ := +@[simp] lemma transpose_mul [add_comm_monoid α] [comm_semigroup α] [fintype n] + (M : matrix m n α) (N : matrix n l α) : (M ⬝ N)ᵀ = Nᵀ ⬝ Mᵀ := begin ext i j, apply dot_product_comm @@ -1202,7 +1208,8 @@ lemma transpose_sum [add_comm_monoid α] {ι : Type*} (s : finset ι) (M : ι /-- `matrix.transpose` as a `ring_equiv` to the opposite ring -/ @[simps] -def transpose_ring_equiv [comm_semiring α] [fintype m] : matrix m m α ≃+* (matrix m m α)ᵐᵒᵖ := +def transpose_ring_equiv [add_comm_monoid α] [comm_semigroup α] [fintype m] : + matrix m m α ≃+* (matrix m m α)ᵐᵒᵖ := { to_fun := λ M, mul_opposite.op (Mᵀ), inv_fun := λ M, M.unopᵀ, map_mul' := λ M N, (congr_arg mul_opposite.op (transpose_mul M N)).trans @@ -1229,30 +1236,34 @@ open_locale matrix @[simp] lemma conj_transpose_conj_transpose [has_involutive_star α] (M : matrix m n α) : Mᴴᴴ = M := -by ext; simp +matrix.ext $ by simp -@[simp] lemma conj_transpose_zero [semiring α] [star_ring α] : (0 : matrix m n α)ᴴ = 0 := -by ext i j; simp +@[simp] lemma conj_transpose_zero [add_monoid α] [star_add_monoid α] : (0 : matrix m n α)ᴴ = 0 := +matrix.ext $ by simp @[simp] lemma conj_transpose_one [decidable_eq n] [semiring α] [star_ring α]: (1 : matrix n n α)ᴴ = 1 := by simp [conj_transpose] @[simp] lemma conj_transpose_add [add_monoid α] [star_add_monoid α] (M N : matrix m n α) : - (M + N)ᴴ = Mᴴ + Nᴴ := by ext i j; simp + (M + N)ᴴ = Mᴴ + Nᴴ := +matrix.ext $ by simp @[simp] lemma conj_transpose_sub [add_group α] [star_add_monoid α] (M N : matrix m n α) : - (M - N)ᴴ = Mᴴ - Nᴴ := by ext i j; simp + (M - N)ᴴ = Mᴴ - Nᴴ := +matrix.ext $ by simp -@[simp] lemma conj_transpose_smul [comm_monoid α] [star_semigroup α] (c : α) (M : matrix m n α) : +@[simp] lemma conj_transpose_smul [comm_semigroup α] [star_semigroup α] (c : α) (M : matrix m n α) : (c • M)ᴴ = (star c) • Mᴴ := -by ext i j; simp [mul_comm] +matrix.ext $ by simp -@[simp] lemma conj_transpose_mul [fintype n] [semiring α] [star_ring α] - (M : matrix m n α) (N : matrix n l α) : (M ⬝ N)ᴴ = Nᴴ ⬝ Mᴴ := by ext i j; simp [mul_apply] +@[simp] lemma conj_transpose_mul [fintype n] [non_unital_semiring α] [star_ring α] + (M : matrix m n α) (N : matrix n l α) : (M ⬝ N)ᴴ = Nᴴ ⬝ Mᴴ := +matrix.ext $ by simp [mul_apply] -@[simp] lemma conj_transpose_neg [non_unital_ring α] [star_ring α] (M : matrix m n α) : - (- M)ᴴ = - Mᴴ := by ext i j; simp +@[simp] lemma conj_transpose_neg [add_group α] [star_add_monoid α] (M : matrix m n α) : + (- M)ᴴ = - Mᴴ := +matrix.ext $ by simp /-- `matrix.conj_transpose` as an `add_equiv` -/ @[simps apply] @@ -1321,7 +1332,7 @@ instance [fintype n] [semiring α] [star_ring α] : star_ring (matrix n n α) := star_mul := conj_transpose_mul, } /-- A version of `star_mul` for `⬝` instead of `*`. -/ -lemma star_mul [fintype n] [semiring α] [star_ring α] (M N : matrix n n α) : +lemma star_mul [fintype n] [non_unital_semiring α] [star_ring α] (M N : matrix n n α) : star (M ⬝ N) = star N ⬝ star M := conj_transpose_mul _ _ end star @@ -1363,12 +1374,10 @@ lemma minor_neg [has_neg α] (A : matrix m n α) : lemma minor_sub [has_sub α] (A B : matrix m n α) : ((A - B).minor : (l → m) → (o → n) → matrix l o α) = A.minor - B.minor := rfl -@[simp] -lemma minor_zero [has_zero α] : +@[simp] lemma minor_zero [has_zero α] : ((0 : matrix m n α).minor : (l → m) → (o → n) → matrix l o α) = 0 := rfl -lemma minor_smul {R : Type*} [semiring R] [add_comm_monoid α] [module R α] (r : R) - (A : matrix m n α) : +lemma minor_smul {R : Type*} [has_scalar R α] (r : R) (A : matrix m n α) : ((r • A : matrix m n α).minor : (l → m) → (o → n) → matrix l o α) = r • A.minor := rfl lemma minor_map (f : α → β) (e₁ : l → m) (e₂ : o → n) (A : matrix m n α) : @@ -1391,7 +1400,7 @@ lemma minor_one [has_zero α] [has_one α] [decidable_eq m] [decidable_eq l] (e (1 : matrix m m α).minor e e = 1 := minor_diagonal _ e he -lemma minor_mul [fintype n] [fintype o] [semiring α] {p q : Type*} +lemma minor_mul [fintype n] [fintype o] [has_mul α] [add_comm_monoid α] {p q : Type*} (M : matrix m n α) (N : matrix n p α) (e₁ : l → m) (e₂ : o → n) (e₃ : q → p) (he₂ : function.bijective e₂) : (M ⬝ N).minor e₁ e₃ = (M.minor e₁ e₂) ⬝ (N.minor e₂ e₃) := @@ -1424,13 +1433,14 @@ lemma minor_one_equiv [has_zero α] [has_one α] [decidable_eq m] [decidable_eq minor_one e e.injective @[simp] -lemma minor_mul_equiv [fintype n] [fintype o] [semiring α] {p q : Type*} +lemma minor_mul_equiv [fintype n] [fintype o] [add_comm_monoid α] [has_mul α] {p q : Type*} (M : matrix m n α) (N : matrix n p α) (e₁ : l → m) (e₂ : o ≃ n) (e₃ : q → p) : (M.minor e₁ e₂) ⬝ (N.minor e₂ e₃) = (M ⬝ N).minor e₁ e₃ := (minor_mul M N e₁ e₂ e₃ e₂.bijective).symm -lemma mul_minor_one [fintype n] [fintype o] [semiring α] [decidable_eq o] (e₁ : n ≃ o) (e₂ : l → o) - (M : matrix m n α) : M ⬝ (1 : matrix o o α).minor e₁ e₂ = minor M id (e₁.symm ∘ e₂) := +lemma mul_minor_one [fintype n] [fintype o] [non_assoc_semiring α] [decidable_eq o] (e₁ : n ≃ o) + (e₂ : l → o) (M : matrix m n α) : + M ⬝ (1 : matrix o o α).minor e₁ e₂ = minor M id (e₁.symm ∘ e₂) := begin let A := M.minor id e₁.symm, have : M = A.minor id e₁, @@ -1440,8 +1450,9 @@ begin equiv.symm_comp_self], end -lemma one_minor_mul [fintype m] [fintype o] [semiring α] [decidable_eq o] (e₁ : l → o) (e₂ : m ≃ o) - (M : matrix m n α) : ((1 : matrix o o α).minor e₁ e₂).mul M = minor M (e₂.symm ∘ e₁) id := +lemma one_minor_mul [fintype m] [fintype o] [non_assoc_semiring α] [decidable_eq o] (e₁ : l → o) + (e₂ : m ≃ o) (M : matrix m n α) : + ((1 : matrix o o α).minor e₁ e₂).mul M = minor M (e₂.symm ∘ e₁) id := begin let A := M.minor e₂.symm id, have : M = A.minor e₂ id, @@ -1485,7 +1496,7 @@ lemma conj_transpose_reindex [has_star α] (eₘ : m ≃ l) (eₙ : n ≃ o) (M rfl @[simp] -lemma minor_mul_transpose_minor [fintype m] [fintype n] [semiring α] +lemma minor_mul_transpose_minor [fintype m] [fintype n] [add_comm_monoid α] [has_mul α] (e : m ≃ n) (M : matrix m n α) : (M.minor id e) ⬝ (Mᵀ).minor e id = M ⬝ Mᵀ := by rw [minor_mul_equiv, minor_id_id] @@ -1562,13 +1573,13 @@ lemma conj_transpose_col [has_star α] (v : m → α) : (col v)ᴴ = row (star v @[simp] lemma conj_transpose_row [has_star α] (v : m → α) : (row v)ᴴ = col (star v) := by { ext, refl } -lemma row_vec_mul [fintype m] [semiring α] (M : matrix m n α) (v : m → α) : +lemma row_vec_mul [fintype m] [non_unital_non_assoc_semiring α] (M : matrix m n α) (v : m → α) : matrix.row (matrix.vec_mul v M) = matrix.row v ⬝ M := by {ext, refl} -lemma col_vec_mul [fintype m] [semiring α] (M : matrix m n α) (v : m → α) : +lemma col_vec_mul [fintype m] [non_unital_non_assoc_semiring α] (M : matrix m n α) (v : m → α) : matrix.col (matrix.vec_mul v M) = (matrix.row v ⬝ M)ᵀ := by {ext, refl} -lemma col_mul_vec [fintype n] [semiring α] (M : matrix m n α) (v : n → α) : +lemma col_mul_vec [fintype n] [non_unital_non_assoc_semiring α] (M : matrix m n α) (v : n → α) : matrix.col (matrix.mul_vec M v) = M ⬝ matrix.col v := by {ext, refl} -lemma row_mul_vec [fintype n] [semiring α] (M : matrix m n α) (v : n → α) : +lemma row_mul_vec [fintype n] [non_unital_non_assoc_semiring α] (M : matrix m n α) (v : n → α) : matrix.row (matrix.mul_vec M v) = (M ⬝ matrix.col v)ᵀ := by {ext, refl} @[simp] @@ -1695,22 +1706,22 @@ end update end matrix namespace ring_hom -variables [fintype n] [semiring α] [semiring β] +variables [fintype n] [non_assoc_semiring α] [non_assoc_semiring β] lemma map_matrix_mul (M : matrix m n α) (N : matrix n o α) (i : m) (j : o) (f : α →+* β) : f (matrix.mul M N i j) = matrix.mul (λ i j, f (M i j)) (λ i j, f (N i j)) i j := by simp [matrix.mul_apply, ring_hom.map_sum] -lemma map_dot_product [semiring R] [semiring S] (f : R →+* S) (v w : n → R) : +lemma map_dot_product [non_assoc_semiring R] [non_assoc_semiring S] (f : R →+* S) (v w : n → R) : f (v ⬝ᵥ w) = (f ∘ v) ⬝ᵥ (f ∘ w) := by simp only [matrix.dot_product, f.map_sum, f.map_mul] -lemma map_vec_mul [semiring R] [semiring S] +lemma map_vec_mul [non_assoc_semiring R] [non_assoc_semiring S] (f : R →+* S) (M : matrix n m R) (v : n → R) (i : m) : f (M.vec_mul v i) = ((M.map f).vec_mul (f ∘ v) i) := by simp only [matrix.vec_mul, matrix.map_apply, ring_hom.map_dot_product] -lemma map_mul_vec [semiring R] [semiring S] +lemma map_mul_vec [non_assoc_semiring R] [non_assoc_semiring S] (f : R →+* S) (M : matrix m n R) (v : n → R) (i : m) : f (M.mul_vec v i) = ((M.map f).mul_vec (f ∘ v) i) := by simp only [matrix.mul_vec, matrix.map_apply, ring_hom.map_dot_product]