From c3c7c3495ce4fc62748d860cc7619c3cac38822b Mon Sep 17 00:00:00 2001 From: Eric Wieser Date: Wed, 7 Apr 2021 15:23:46 +0000 Subject: [PATCH] feat(data/matrix/basic): dependently-typed block diagonal (#7068) This allows constructing block diagonal matrices whose blocks are different sizes. A notable example of such a matrix is the one from the Jordan Normal Form. This duplicates all of the API for `block_diagonal` from this file. Most of the proofs copy across cleanly, but the proof for `block_diagonal_mul'` required lots of hand-holding that `simp` could solve by itself for the non-dependent case. --- src/data/matrix/basic.lean | 131 ++++++++++++++++++++++++++++++++++--- 1 file changed, 122 insertions(+), 9 deletions(-) diff --git a/src/data/matrix/basic.lean b/src/data/matrix/basic.lean index 48d050229e9f3..009f0ea0dfe24 100644 --- a/src/data/matrix/basic.lean +++ b/src/data/matrix/basic.lean @@ -25,6 +25,8 @@ def matrix (m : Type u) (n : Type u') [fintype m] [fintype n] (α : Type v) : Ty m → n → α variables {l m n o : Type*} [fintype l] [fintype m] [fintype n] [fintype o] +variables {m' : o → Type*} [∀ i, fintype (m' i)] +variables {n' : o → Type*} [∀ i, fintype (n' i)] variables {α : Type v} namespace matrix @@ -1123,9 +1125,12 @@ section has_zero variables [has_zero α] -/-- `matrix.block_diagonal M` turns `M : o → matrix m n α'` into a -`m × o`-by`n × o` block matrix which has the entries of `M` along the diagonal -and zero elsewhere. -/ +/-- `matrix.block_diagonal M` turns a homogenously-indexed collection of matrices +`M : o → matrix m n α'` into a `m × o`-by-`n × o` block matrix which has the entries of `M` along +the diagonal and zero elsewhere. + +See also `matrix.block_diagonal'` if the matrices may not have the same size everywhere. +-/ def block_diagonal : matrix (m × o) (n × o) α | ⟨i, k⟩ ⟨j, k'⟩ := if k = k' then M k i j else 0 @@ -1143,7 +1148,7 @@ lemma block_diagonal_apply_ne (i j) {k k'} (h : k ≠ k') : if_neg h @[simp] lemma block_diagonal_transpose : - (block_diagonal M)ᵀ = (block_diagonal (λ k, (M k)ᵀ)) := + (block_diagonal M)ᵀ = block_diagonal (λ k, (M k)ᵀ) := begin ext, simp only [transpose_apply, block_diagonal_apply, eq_comm], @@ -1157,7 +1162,7 @@ end by { ext, simp [block_diagonal_apply] } @[simp] lemma block_diagonal_diagonal [decidable_eq m] (d : o → m → α) : - (block_diagonal (λ k, diagonal (d k))) = diagonal (λ ik, d ik.2 ik.1) := + block_diagonal (λ k, diagonal (d k)) = diagonal (λ ik, d ik.2 ik.1) := begin ext ⟨i, k⟩ ⟨j, k'⟩, simp only [block_diagonal_apply, diagonal], @@ -1165,8 +1170,8 @@ begin end @[simp] lemma block_diagonal_one [decidable_eq m] [has_one α] : - (block_diagonal (1 : o → matrix m m α)) = 1 := -show (block_diagonal (λ (_ : o), diagonal (λ (_ : m), (1 : α)))) = diagonal (λ _, 1), + block_diagonal (1 : o → matrix m m α) = 1 := +show block_diagonal (λ (_ : o), diagonal (λ (_ : m), (1 : α))) = diagonal (λ _, 1), by rw [block_diagonal_diagonal] end has_zero @@ -1191,8 +1196,8 @@ end block_diagonal (M - N) = block_diagonal M - block_diagonal N := by simp [sub_eq_add_neg] -@[simp] lemma block_diagonal_mul {p : Type*} [fintype p] [semiring α] - (N : o → matrix n p α) : block_diagonal (λ k, M k ⬝ N k) = block_diagonal M ⬝ block_diagonal N := +@[simp] lemma block_diagonal_mul {p : Type*} [fintype p] [semiring α] (N : o → matrix n p α) : + block_diagonal (λ k, M k ⬝ N k) = block_diagonal M ⬝ block_diagonal N := begin ext ⟨i, k⟩ ⟨j, k'⟩, simp only [block_diagonal_apply, mul_apply, ← finset.univ_product_univ, finset.sum_product], @@ -1205,6 +1210,114 @@ by { ext, simp only [block_diagonal_apply, pi.smul_apply, smul_apply], split_ifs end block_diagonal +section block_diagonal' + +variables (M N : Π i, matrix (m' i) (n' i) α) [decidable_eq o] + +section has_zero + +variables [has_zero α] + +/-- `matrix.block_diagonal' M` turns `M : Π i, matrix (m i) (n i) α` into a +`Σ i, m i`-by-`Σ i, n i` block matrix which has the entries of `M` along the diagonal +and zero elsewhere. + +This is the dependently-typed version of `matrix.block_diagonal`. -/ +def block_diagonal' : matrix (Σ i, m' i) (Σ i, n' i) α +| ⟨k, i⟩ ⟨k', j⟩ := if h : k = k' then M k i (cast (congr_arg n' h.symm) j) else 0 + +lemma block_diagonal'_eq_block_diagonal (M : o → matrix m n α) {k k'} (i j) : + block_diagonal M (i, k) (j, k') = block_diagonal' M ⟨k, i⟩ ⟨k', j⟩ := +rfl + +lemma block_diagonal'_minor_eq_block_diagonal (M : o → matrix m n α) : + (block_diagonal' M).minor (prod.to_sigma ∘ prod.swap) (prod.to_sigma ∘ prod.swap) = + block_diagonal M := +matrix.ext $ λ ⟨k, i⟩ ⟨k', j⟩, rfl + +lemma block_diagonal'_apply (ik jk) : + block_diagonal' M ik jk = if h : ik.1 = jk.1 then + M ik.1 ik.2 (cast (congr_arg n' h.symm) jk.2) else 0 := +by { cases ik, cases jk, refl } + +@[simp] +lemma block_diagonal'_apply_eq (k i j) : + block_diagonal' M ⟨k, i⟩ ⟨k, j⟩ = M k i j := +dif_pos rfl + +lemma block_diagonal'_apply_ne {k k'} (i j) (h : k ≠ k') : + block_diagonal' M ⟨k, i⟩ ⟨k', j⟩ = 0 := +dif_neg h + +@[simp] lemma block_diagonal'_transpose : + (block_diagonal' M)ᵀ = block_diagonal' (λ k, (M k)ᵀ) := +begin + ext ⟨ii, ix⟩ ⟨ji, jx⟩, + simp only [transpose_apply, block_diagonal'_apply, eq_comm], + dsimp only, + split_ifs with h₁ h₂ h₂, + { subst h₁, refl, }, + { exact (h₂ h₁.symm).elim }, + { exact (h₁ h₂.symm).elim }, + { refl } +end + +@[simp] lemma block_diagonal'_zero : + block_diagonal' (0 : Π i, matrix (m' i) (n' i) α) = 0 := +by { ext, simp [block_diagonal'_apply] } + +@[simp] lemma block_diagonal'_diagonal [∀ i, decidable_eq (m' i)] (d : Π i, m' i → α) : + block_diagonal' (λ k, diagonal (d k)) = diagonal (λ ik, d ik.1 ik.2) := +begin + ext ⟨i, k⟩ ⟨j, k'⟩, + simp only [block_diagonal'_apply, diagonal], + split_ifs; finish +end + +@[simp] lemma block_diagonal'_one [∀ i, decidable_eq (m' i)] [has_one α] : + block_diagonal' (1 : Π i, matrix (m' i) (m' i) α) = 1 := +show block_diagonal' (λ (i : o), diagonal (λ (_ : m' i), (1 : α))) = diagonal (λ _, 1), +by rw [block_diagonal'_diagonal] + +end has_zero + +@[simp] lemma block_diagonal'_add [add_monoid α] : + block_diagonal' (M + N) = block_diagonal' M + block_diagonal' N := +begin + ext, + simp only [block_diagonal'_apply, add_apply], + split_ifs; simp +end + +@[simp] lemma block_diagonal'_neg [add_group α] : + block_diagonal' (-M) = - block_diagonal' M := +begin + ext, + simp only [block_diagonal'_apply, neg_apply], + split_ifs; simp +end + +@[simp] lemma block_diagonal'_sub [add_group α] : + block_diagonal' (M - N) = block_diagonal' M - block_diagonal' N := +by simp [sub_eq_add_neg] + +@[simp] lemma block_diagonal'_mul {p : o → Type*} [Π i, fintype (p i)] [semiring α] + (N : Π i, matrix (n' i) (p i) α) : + block_diagonal' (λ k, M k ⬝ N k) = block_diagonal' M ⬝ block_diagonal' N := +begin + ext ⟨k, i⟩ ⟨k', j⟩, + simp only [block_diagonal'_apply, mul_apply, ← finset.univ_sigma_univ, finset.sum_sigma], + rw fintype.sum_eq_single k, + { split_ifs; simp }, + { intros j' hj', exact finset.sum_eq_zero (λ _ _, by rw [dif_neg hj'.symm, zero_mul]) }, +end + +@[simp] lemma block_diagonal'_smul {R : Type*} [semiring R] [add_comm_monoid α] [semimodule R α] + (x : R) : block_diagonal' (x • M) = x • block_diagonal' M := +by { ext, simp only [block_diagonal'_apply, pi.smul_apply, smul_apply], split_ifs; simp } + +end block_diagonal' + end matrix namespace ring_hom