Skip to content

Commit

Permalink
refactor(linear_algebra/trace): unbundle matrix.trace (#13712)
Browse files Browse the repository at this point in the history
The extra type arguments of the bundled `matrix.trace` are awkward to work with in many cases, especially when Lean doesn't have any information to infer the mostly-irrelevant `R` argument from. This came up while trying to work with `continuous.matrix_trace`, which is cumbersome to use for that reason.
The old bundled version is still available as `matrix.trace_linear_map`.

The cost of this change is that we have to copy across the usual set of obvious lemmas about additive maps; but we already do this for `diagonal`, `transpose` etc anyway.
  • Loading branch information
eric-wieser committed May 2, 2022
1 parent a627569 commit 320df45
Show file tree
Hide file tree
Showing 10 changed files with 138 additions and 77 deletions.
10 changes: 5 additions & 5 deletions src/algebra/lie/classical.lean
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ variables [decidable_eq n] [decidable_eq p] [decidable_eq q] [decidable_eq l]
variables [comm_ring R]

@[simp] lemma matrix_trace_commutator_zero [fintype n] (X Y : matrix n n R) :
matrix.trace n R R ⁅X, Y⁆ = 0 :=
calc _ = matrix.trace n R R (X ⬝ Y) - matrix.trace n R R (Y ⬝ X) : linear_map.map_sub _ _ _
... = matrix.trace n R R (X ⬝ Y) - matrix.trace n R R (X ⬝ Y) :
matrix.trace ⁅X, Y⁆ = 0 :=
calc _ = matrix.trace (X ⬝ Y) - matrix.trace (Y ⬝ X) : trace_sub _ _
... = matrix.trace (X ⬝ Y) - matrix.trace (X ⬝ Y) :
congr_arg (λ x, _ - x) (matrix.trace_mul_comm Y X)
... = 0 : sub_self _

Expand All @@ -85,7 +85,7 @@ namespace special_linear
/-- The special linear Lie algebra: square matrices of trace zero. -/
def sl [fintype n] : lie_subalgebra R (matrix n n R) :=
{ lie_mem' := λ X Y _ _, linear_map.mem_ker.2 $ matrix_trace_commutator_zero _ _ _ _,
..linear_map.ker (matrix.trace n R R) }
..linear_map.ker (matrix.trace_linear_map n R R) }

/-- The underlying matrix of the Lie bracket of two elements of `sl n R` is the commutator
`A.val ⬝ B.val - B.val ⬝ A.val` of their underlying matrices; this holds by `rfl`. -/
lemma sl_bracket [fintype n] (A B : sl n R) : ⁅A, B⁆.val = A.val ⬝ B.val - B.val ⬝ A.val := rfl

Expand All @@ -97,7 +97,7 @@ variables {n} [fintype n] (i j : n)
basis of sl n R. -/
def Eb (h : j ≠ i) : sl n R :=
⟨matrix.std_basis_matrix i j (1 : R),
show matrix.std_basis_matrix i j (1 : R) ∈ linear_map.ker (matrix.trace n R R),
show matrix.std_basis_matrix i j (1 : R) ∈ linear_map.ker (matrix.trace_linear_map n R R),
from matrix.std_basis_matrix.trace_zero i j (1 : R) h⟩

/-- The underlying matrix of `Eb R i j h` is `matrix.std_basis_matrix i j 1`; this holds by
`rfl`. -/
@[simp] lemma Eb_val (h : j ≠ i) : (Eb R i j h).val = matrix.std_basis_matrix i j 1 := rfl
Expand Down
6 changes: 3 additions & 3 deletions src/combinatorics/simple_graph/adj_matrix.lean
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,9 @@ by simp [mul_apply, neighbor_finset_eq_filter, sum_filter, adj_comm]

variable (α)

theorem trace_adj_matrix [non_assoc_semiring α] [semiring β] [module β α]:
matrix.trace _ β _ (G.adj_matrix α) = 0 :=
by simp
@[simp] theorem trace_adj_matrix [add_comm_monoid α] [has_one α] :
matrix.trace (G.adj_matrix α) = 0 :=
by simp [matrix.trace]

variable {α}

Expand Down
12 changes: 12 additions & 0 deletions src/data/matrix/basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,18 @@ lemma diag_map {f : α → β} {A : matrix n n α} : diag (A.map f) = f ∘ diag
@[simp] lemma diag_conj_transpose [add_monoid α] [star_add_monoid α] (A : matrix n n α) :
diag Aᴴ = star (diag A) := rfl

/-- `diag` of the sum of a list of square matrices is the sum of the list of their `diag`s.
Obtained by applying `map_list_sum` to `diag_add_monoid_hom n α`. -/
@[simp] lemma diag_list_sum [add_monoid α] (l : list (matrix n n α)) :
diag l.sum = (l.map diag).sum :=
map_list_sum (diag_add_monoid_hom n α) l

/-- `diag` of the sum of a multiset of square matrices is the sum of the multiset of their
`diag`s. Obtained by applying `map_multiset_sum` to `diag_add_monoid_hom n α`. -/
@[simp] lemma diag_multiset_sum [add_comm_monoid α] (s : multiset (matrix n n α)) :
diag s.sum = (s.map diag).sum :=
map_multiset_sum (diag_add_monoid_hom n α) s

/-- `diag` of a finite sum `∑ i in s, f i` of square matrices is the sum of the `diag (f i)`.
Obtained by applying `map_sum` to `diag_add_monoid_hom n α`. -/
@[simp] lemma diag_sum {ι} [add_comm_monoid α] (s : finset ι) (f : ι → matrix n n α) :
diag (∑ i in s, f i) = ∑ i in s, diag (f i) :=
map_sum (diag_add_monoid_hom n α) f s

end diag

section dot_product
Expand Down
4 changes: 3 additions & 1 deletion src/data/matrix/basis.lean
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,9 @@ by { ext j, by_cases hij : i = j; try {rw hij}; simp [hij] }

variable [fintype n]

lemma trace_zero (h : j ≠ i) : trace n α α (std_basis_matrix i j c) = 0 := by simp [h]
@[simp] lemma trace_zero (h : j ≠ i) : trace (std_basis_matrix i j c) = 0 := by simp [trace, h]

/-- The trace of `std_basis_matrix i i c` (both indices equal) is `c`. Proved by unfolding
`trace` and simplifying. -/
@[simp] lemma trace_eq : trace (std_basis_matrix i i c) = c := by simp [trace]

@[simp] lemma mul_left_apply_same (b : n) (M : matrix n n α) :
(std_basis_matrix i j c ⬝ M) i b = c * M j b :=
Expand Down
4 changes: 2 additions & 2 deletions src/data/matrix/hadamard.lean
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,11 @@ section trace
variables [fintype m] [fintype n]
variables (R) [semiring α] [semiring R] [module R α]

lemma sum_hadamard_eq : ∑ (i : m) (j : n), (A ⊙ B) i j = trace m R α (A ⬝ Bᵀ) :=
lemma sum_hadamard_eq : ∑ (i : m) (j : n), (A ⊙ B) i j = trace (A ⬝ Bᵀ) :=
rfl

lemma dot_product_vec_mul_hadamard [decidable_eq m] [decidable_eq n] (v : m → α) (w : n → α) :
dot_product (vec_mul v (A ⊙ B)) w = trace m R α (diagonal v ⬝ A ⬝ (B ⬝ diagonal w)ᵀ) :=
dot_product (vec_mul v (A ⊙ B)) w = trace (diagonal v ⬝ A ⬝ (B ⬝ diagonal w)ᵀ) :=
begin
rw [←sum_hadamard_eq, finset.sum_comm],
simp [dot_product, vec_mul, finset.sum_mul, mul_assoc],
Expand Down
15 changes: 8 additions & 7 deletions src/linear_algebra/matrix/charpoly/coeff.lean
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,12 @@ begin
end

theorem trace_eq_neg_charpoly_coeff [nonempty n] (M : matrix n n R) :
(trace n R R) M = -M.charpoly.coeff (fintype.card n - 1) :=
trace M = -M.charpoly.coeff (fintype.card n - 1) :=
begin
rw charpoly_coeff_eq_prod_coeff_of_le, swap, refl,
rw [fintype.card, prod_X_sub_C_coeff_card_pred univ (λ i : n, M i i)], simp,
rw [← fintype.card, fintype.card_pos_iff], apply_instance,
rw [fintype.card, prod_X_sub_C_coeff_card_pred univ (λ i : n, M i i) fintype.card_pos, neg_neg,
trace],
refl
end

-- I feel like this should use polynomial.alg_hom_eval₂_algebra_map
Expand Down Expand Up @@ -209,16 +210,16 @@ end
by { have h := finite_field.matrix.charpoly_pow_card M, rwa zmod.card at h, }

lemma finite_field.trace_pow_card {K : Type*} [field K] [fintype K]
(M : matrix n n K) : trace n K K (M ^ (fintype.card K)) = (trace n K K M) ^ (fintype.card K) :=
(M : matrix n n K) : trace (M ^ (fintype.card K)) = trace M ^ (fintype.card K) :=
begin
casesI is_empty_or_nonempty n,
{ simp [zero_pow fintype.card_pos], },
{ simp [zero_pow fintype.card_pos, matrix.trace], },
rw [matrix.trace_eq_neg_charpoly_coeff, matrix.trace_eq_neg_charpoly_coeff,
finite_field.matrix.charpoly_pow_card, finite_field.pow_card]
end

lemma zmod.trace_pow_card {p:ℕ} [fact p.prime] (M : matrix n n (zmod p)) :
trace n (zmod p) (zmod p) (M ^ p) = (trace n (zmod p) (zmod p) M)^p :=
lemma zmod.trace_pow_card {p : ℕ} [fact p.prime] (M : matrix n n (zmod p)) :
trace (M ^ p) = (trace M)^p :=
by { have h := finite_field.trace_pow_card M, rwa zmod.card at h, }

namespace matrix
Expand Down
128 changes: 87 additions & 41 deletions src/linear_algebra/matrix/trace.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import data.matrix.basic
/-!
# Trace of a matrix
This file defines the trace of a matrix, the linear map
sending a matrix to the sum of its diagonal entries.
This file defines the trace of a matrix, the map sending a matrix to the sum of its diagonal
entries.
See also `linear_algebra.trace` for the trace of an endomorphism.
Expand All @@ -19,79 +19,125 @@ matrix, trace, diagonal
-/

open_locale big_operators
open_locale matrix
open_locale big_operators matrix

namespace matrix

section trace
variables {ι m n p : Type*} {α R S : Type*}
variables [fintype m] [fintype n] [fintype p]

universes u v w
section add_comm_monoid
variables [add_comm_monoid R]

variables {m : Type*} (n : Type*) {p : Type*}
variables (R : Type*) (M : Type*) [semiring R] [add_comm_monoid M] [module R M]
/-- The trace of a square matrix: the sum `∑ i, diag A i` of its diagonal entries. This is a
plain function taking only the matrix `A` as an explicit argument. For more bundled versions, see:
* `matrix.trace_add_monoid_hom`
* `matrix.trace_linear_map`
-/
def trace (A : matrix n n R) : R := ∑ i, diag A i

variables (n) (R) (M)
variables (n R)
/-- The trace of the zero matrix is `0`. The proof rewrites the sum with `finset.sum_const` at
the constant `(0 : R)` and then closes the goal with `smul_zero`. -/
@[simp] lemma trace_zero : trace (0 : matrix n n R) = 0 :=
(finset.sum_const (0 : R)).trans $ smul_zero _
variables {n R}

/--
The trace of a square matrix.
-/
def trace [fintype n] : (matrix n n M) →ₗ[R] M :=
{ to_fun := λ A, ∑ i, diag A i,
map_add' := by { intros, apply finset.sum_add_distrib, },
map_smul' := by { intros, simp [finset.smul_sum], } }
/-- The trace is additive: `trace (A + B) = trace A + trace B`. Follows from
`finset.sum_add_distrib`. -/
@[simp] lemma trace_add (A B : matrix n n R) : trace (A + B) = trace A + trace B :=
finset.sum_add_distrib

/-- The trace commutes with scalar multiplication by `r : α`, for any monoid `α` with a
`distrib_mul_action` on `R`. Follows from `finset.smul_sum` read right-to-left. -/
@[simp] lemma trace_smul [monoid α] [distrib_mul_action α R] (r : α) (A : matrix n n R) :
trace (r • A) = r • trace A :=
finset.smul_sum.symm

/-- A square matrix and its transpose have the same trace; this holds by `rfl`. -/
@[simp] lemma trace_transpose (A : matrix n n R) : trace Aᵀ = trace A := rfl

variables (n α R)
/-- `matrix.trace` as an `add_monoid_hom` from `matrix n n R` to `R`. The underlying function
is `trace`; the homomorphism laws are supplied by `trace_zero` and `trace_add`. -/
@[simps]
def trace_add_monoid_hom : matrix n n R →+ R :=
{ to_fun := trace, map_zero' := trace_zero n R, map_add' := trace_add }

/-- `matrix.trace` as a `linear_map` from `matrix n n R` to `R`, linear over any semiring `α`
with `module α R`. The underlying function is `trace`; linearity is supplied by `trace_add` and
`trace_smul`. -/
@[simps]
def trace_linear_map [semiring α] [module α R] : matrix n n R →ₗ[α] R :=
{ to_fun := trace, map_add' := trace_add, map_smul' := trace_smul }
variables {n α R}

/-- The trace of the sum of a list of square matrices is the sum of the list of their traces.
Obtained by applying `map_list_sum` to `trace_add_monoid_hom n R`. -/
@[simp] lemma trace_list_sum (l : list (matrix n n R)) : trace l.sum = (l.map trace).sum :=
map_list_sum (trace_add_monoid_hom n R) l

/-- The trace of the sum of a multiset of square matrices is the sum of the multiset of their
traces. Obtained by applying `map_multiset_sum` to `trace_add_monoid_hom n R`. -/
@[simp] lemma trace_multiset_sum (s : multiset (matrix n n R)) : trace s.sum = (s.map trace).sum :=
map_multiset_sum (trace_add_monoid_hom n R) s

variables {n} {R} {M} [fintype n] [fintype m] [fintype p]
/-- The trace of a finite sum `∑ i in s, f i` of square matrices is the sum of the traces
`trace (f i)`. Obtained by applying `map_sum` to `trace_add_monoid_hom n R`. -/
@[simp] lemma trace_sum (s : finset ι) (f : ι → matrix n n R) :
trace (∑ i in s, f i) = ∑ i in s, trace (f i) :=
map_sum (trace_add_monoid_hom n R) f s

@[simp] lemma trace_diag (A : matrix n n M) : trace n R M A = ∑ i, diag A i := rfl
end add_comm_monoid

lemma trace_apply (A : matrix n n M) : trace n R M A = ∑ i, A i i := rfl
section add_comm_group
variables [add_comm_group R]

@[simp] lemma trace_one [decidable_eq n] :
trace n R R 1 = fintype.card n :=
have h : trace n R R 1 = ∑ i, diag 1 i := rfl,
by simp_rw [h, diag_one, pi.one_def, finset.sum_const, nsmul_one]; refl
/-- The trace of a difference is the difference of the traces. Follows from
`finset.sum_sub_distrib`. -/
@[simp] lemma trace_sub (A B : matrix n n R) : trace (A - B) = trace A - trace B :=
finset.sum_sub_distrib

@[simp] lemma trace_transpose (A : matrix n n M) : trace n R M Aᵀ = trace n R M A := rfl
/-- The trace of a negated matrix is the negation of its trace. Follows from
`finset.sum_neg_distrib`. -/
@[simp] lemma trace_neg (A : matrix n n R) : trace (-A) = -trace A :=
finset.sum_neg_distrib

@[simp] lemma trace_transpose_mul (A : matrix m n R) (B : matrix n m R) :
trace n R R (Aᵀ ⬝ Bᵀ) = trace m R R (A ⬝ B) := finset.sum_comm
end add_comm_group

lemma trace_mul_comm {S : Type v} [comm_semiring S] (A : matrix m n S) (B : matrix n m S) :
trace m S S (A ⬝ B) = trace n S S (B ⬝ A) :=
section one
variables [decidable_eq n] [add_comm_monoid R] [has_one R]

/-- The trace of the identity matrix `(1 : matrix n n R)` is `fintype.card n`, coerced to `R`.
Proved by unfolding `trace` and rewriting with the listed `simp_rw` lemmas. -/
@[simp] lemma trace_one : trace (1 : matrix n n R) = fintype.card n :=
by simp_rw [trace, diag_one, pi.one_def, finset.sum_const, nsmul_one, finset.card_univ]

end one

section mul

/-- For `A : matrix m n R` and `B : matrix n m R`, the trace of `Aᵀ ⬝ Bᵀ` equals the trace of
`A ⬝ B`. Only addition and multiplication on `R` are assumed; follows from `finset.sum_comm`. -/
@[simp] lemma trace_transpose_mul [add_comm_monoid R] [has_mul R]
(A : matrix m n R) (B : matrix n m R) : trace (Aᵀ ⬝ Bᵀ) = trace (A ⬝ B) := finset.sum_comm

/-- The trace of a product does not change when the two factors are swapped:
`trace (A ⬝ B) = trace (B ⬝ A)` for `A : matrix m n R` and `B : matrix n m R`, with commutative
multiplication on `R`. Proved by rewriting with `trace_transpose`, `trace_transpose_mul` and
`transpose_mul`. -/
lemma trace_mul_comm [add_comm_monoid R] [comm_semigroup R] (A : matrix m n R) (B : matrix n m R) :
trace (A ⬝ B) = trace (B ⬝ A) :=
by rw [←trace_transpose, ←trace_transpose_mul, transpose_mul]

lemma trace_mul_cycle {S : Type v} [comm_semiring S]
(A : matrix m n S) (B : matrix n p S) (C : matrix p m S) :
trace _ S S (A ⬝ B ⬝ C) = trace p S S (C ⬝ A ⬝ B) :=
/-- The trace of a triple product is unchanged by moving the last factor to the front:
`trace (A ⬝ B ⬝ C) = trace (C ⬝ A ⬝ B)`. Proved from `trace_mul_comm` and `matrix.mul_assoc`. -/
lemma trace_mul_cycle [non_unital_comm_semiring R]
(A : matrix m n R) (B : matrix n p R) (C : matrix p m R) :
trace (A ⬝ B ⬝ C) = trace (C ⬝ A ⬝ B) :=
by rw [trace_mul_comm, matrix.mul_assoc]

lemma trace_mul_cycle' {S : Type v} [comm_semiring S]
(A : matrix m n S) (B : matrix n p S) (C : matrix p m S) :
trace _ S S (A ⬝ (B ⬝ C)) = trace p S S (C ⬝ (A ⬝ B)) :=
/-- Variant of `trace_mul_cycle` with the products bracketed to the right:
`trace (A ⬝ (B ⬝ C)) = trace (C ⬝ (A ⬝ B))`. Proved from `matrix.mul_assoc` and
`trace_mul_comm`. -/
lemma trace_mul_cycle' [non_unital_comm_semiring R]
(A : matrix m n R) (B : matrix n p R) (C : matrix p m R) :
trace (A ⬝ (B ⬝ C)) = trace (C ⬝ (A ⬝ B)) :=
by rw [←matrix.mul_assoc, trace_mul_comm]

@[simp] lemma trace_col_mul_row (a b : n → R) : trace n R R (col a ⬝ row b) = dot_product a b :=
by simp [dot_product]
/-- The trace of the matrix `col a ⬝ row b` is `dot_product a b`. Proved by `simp` after
unfolding `dot_product` and `trace`. -/
@[simp] lemma trace_col_mul_row [non_unital_non_assoc_semiring R] (a b : n → R) :
trace (col a ⬝ row b) = dot_product a b :=
by simp [dot_product, trace]

end mul

section fin
variables [add_comm_monoid R]

/-! ### Special cases for `fin n`
While `simp [fin.sum_univ_succ]` can prove these, we include them for convenience and consistency
with `matrix.det_fin_two` etc.
-/

@[simp] lemma trace_fin_zero (A : matrix (fin 0) (fin 0) R) : trace _ R R A = 0 :=
/-- The trace of a matrix indexed by `fin 0` is `0`; this holds by `rfl`. -/
@[simp] lemma trace_fin_zero (A : matrix (fin 0) (fin 0) R) : trace A = 0 :=
rfl

lemma trace_fin_one (A : matrix (fin 1) (fin 1) R) : trace _ R R A = A 0 0 :=
/-- The trace of a matrix indexed by `fin 1` is its single diagonal entry `A 0 0`. -/
lemma trace_fin_one (A : matrix (fin 1) (fin 1) R) : trace A = A 0 0 :=
add_zero _

lemma trace_fin_two (A : matrix (fin 2) (fin 2) R) : trace _ R R A = A 0 0 + A 1 1 :=
/-- The trace of a matrix indexed by `fin 2` is `A 0 0 + A 1 1`. -/
lemma trace_fin_two (A : matrix (fin 2) (fin 2) R) : trace A = A 0 0 + A 1 1 :=
congr_arg ((+) _) (add_zero (A 1 1))

lemma trace_fin_three (A : matrix (fin 3) (fin 3) R) : trace _ R R A = A 0 0 + A 1 1 + A 2 2 :=
/-- The trace of a matrix indexed by `fin 3` is `A 0 0 + A 1 1 + A 2 2`. -/
lemma trace_fin_three (A : matrix (fin 3) (fin 3) R) : trace A = A 0 0 + A 1 1 + A 2 2 :=
by { rw [← add_zero (A 2 2), add_assoc], refl }

end trace
end fin

end matrix
22 changes: 11 additions & 11 deletions src/linear_algebra/trace.lean
Original file line number Diff line number Diff line change
Expand Up @@ -41,29 +41,29 @@ variables (b : basis ι R M) (c : basis κ R M)
/-- The trace of an endomorphism given a basis. -/
def trace_aux :
(M →ₗ[R] M) →ₗ[R] R :=
(matrix.trace ι R R) ∘ₗ ↑(linear_map.to_matrix b b)
(matrix.trace_linear_map ι R R) ∘ₗ ↑(linear_map.to_matrix b b)

-- Can't be `simp` because it would cause a loop.
lemma trace_aux_def (b : basis ι R M) (f : M →ₗ[R] M) :
trace_aux R b f = matrix.trace ι R R (linear_map.to_matrix b b f) :=
trace_aux R b f = matrix.trace (linear_map.to_matrix b b f) :=
rfl

theorem trace_aux_eq : trace_aux R b = trace_aux R c :=
linear_map.ext $ λ f,
calc matrix.trace ι R R (linear_map.to_matrix b b f)
= matrix.trace ι R R (linear_map.to_matrix b b ((linear_map.id.comp f).comp linear_map.id)) :
calc matrix.trace (linear_map.to_matrix b b f)
= matrix.trace (linear_map.to_matrix b b ((linear_map.id.comp f).comp linear_map.id)) :
by rw [linear_map.id_comp, linear_map.comp_id]
... = matrix.trace ι R R (linear_map.to_matrix c b linear_map.id ⬝
... = matrix.trace (linear_map.to_matrix c b linear_map.id ⬝
linear_map.to_matrix c c f ⬝
linear_map.to_matrix b c linear_map.id) :
by rw [linear_map.to_matrix_comp _ c, linear_map.to_matrix_comp _ c]
... = matrix.trace κ R R (linear_map.to_matrix c c f ⬝
... = matrix.trace (linear_map.to_matrix c c f ⬝
linear_map.to_matrix b c linear_map.id ⬝
linear_map.to_matrix c b linear_map.id) :
by rw [matrix.mul_assoc, matrix.trace_mul_comm]
... = matrix.trace κ R R (linear_map.to_matrix c c ((f.comp linear_map.id).comp linear_map.id)) :
... = matrix.trace (linear_map.to_matrix c c ((f.comp linear_map.id).comp linear_map.id)) :
by rw [linear_map.to_matrix_comp _ b, linear_map.to_matrix_comp _ c]
... = matrix.trace κ R R (linear_map.to_matrix c c f) :
... = matrix.trace (linear_map.to_matrix c c f) :
by rw [linear_map.comp_id, linear_map.comp_id]

open_locale classical
Expand All @@ -81,13 +81,13 @@ variables (R) {M}
/-- Auxiliary lemma for `trace_eq_matrix_trace`. -/
theorem trace_eq_matrix_trace_of_finset {s : finset M} (b : basis s R M)
(f : M →ₗ[R] M) :
trace R M f = matrix.trace s R R (linear_map.to_matrix b b f) :=
trace R M f = matrix.trace (linear_map.to_matrix b b f) :=
have ∃ (s : finset M), nonempty (basis s R M),
from ⟨s, ⟨b⟩⟩,
by { rw [trace, dif_pos this, ← trace_aux_def], congr' 1, apply trace_aux_eq }

theorem trace_eq_matrix_trace (f : M →ₗ[R] M) :
trace R M f = matrix.trace ι R R (linear_map.to_matrix b b f) :=
trace R M f = matrix.trace (linear_map.to_matrix b b f) :=
by rw [trace_eq_matrix_trace_of_finset R b.reindex_finset_range,
← trace_aux_def, ← trace_aux_def, trace_aux_eq R b]

Expand Down Expand Up @@ -127,7 +127,7 @@ begin
simp only [function.comp_app, basis.tensor_product_apply, basis.coe_dual_basis, coe_comp],
rw [trace_eq_matrix_trace R b, to_matrix_dual_tensor_hom],
by_cases hij : i = j,
{ rw [hij], simp},
{ rw [hij], simp },
rw matrix.std_basis_matrix.trace_zero j i (1:R) hij,
simp [finsupp.single_eq_pi_single, hij],
end
Expand Down
8 changes: 4 additions & 4 deletions src/ring_theory/trace.lean
Original file line number Diff line number Diff line change
Expand Up @@ -98,15 +98,15 @@ variables {R}

-- Can't be a `simp` lemma because it depends on a choice of basis
lemma trace_eq_matrix_trace [decidable_eq ι] (b : basis ι R S) (s : S) :
trace R S s = matrix.trace _ R _ (algebra.left_mul_matrix b s) :=
trace R S s = matrix.trace (algebra.left_mul_matrix b s) :=
by rw [trace_apply, linear_map.trace_eq_matrix_trace _ b, to_matrix_lmul_eq]

/-- If `x` is in the base ring `R`, then the trace of `algebra_map R S x` is `fintype.card ι • x`
(for a field extension `L / K` this is the classical `[L : K] * x`). -/
lemma trace_algebra_map_of_basis (x : R) :
trace R S (algebra_map R S x) = fintype.card ι • x :=
begin
haveI := classical.dec_eq ι,
rw [trace_apply, linear_map.trace_eq_matrix_trace R b, trace_diag],
rw [trace_apply, linear_map.trace_eq_matrix_trace R b, matrix.trace],
convert finset.sum_const _,
ext i,
simp,
Expand All @@ -133,10 +133,10 @@ begin
haveI := classical.dec_eq ι,
haveI := classical.dec_eq κ,
rw [trace_eq_matrix_trace (b.smul c), trace_eq_matrix_trace b, trace_eq_matrix_trace c,
matrix.trace_apply, matrix.trace_apply, matrix.trace_apply,
matrix.trace, matrix.trace, matrix.trace,
← finset.univ_product_univ, finset.sum_product],
refine finset.sum_congr rfl (λ i _, _),
simp only [alg_hom.map_sum, smul_left_mul_matrix, finset.sum_apply,
simp only [alg_hom.map_sum, smul_left_mul_matrix, finset.sum_apply, matrix.diag,
-- The unifier is not smart enough to apply this one by itself:
finset.sum_apply i _ (λ y, left_mul_matrix b (left_mul_matrix c x y y))]
end
Expand Down
Loading

0 comments on commit 320df45

Please sign in to comment.