Skip to content

Commit

Permalink
feat(data/matrix/basic): dependently-typed block diagonal (#7068)
Browse files Browse the repository at this point in the history
This allows constructing block diagonal matrices whose blocks are different sizes. A notable example of such a matrix is the one from the Jordan Normal Form.

This duplicates all of the API for `block_diagonal` from this file. Most of the proofs copy across cleanly, but the proof for `block_diagonal_mul'` required lots of hand-holding that `simp` could solve by itself for the non-dependent case.
  • Loading branch information
eric-wieser committed Apr 7, 2021
1 parent 8459d0a commit c3c7c34
Showing 1 changed file with 122 additions and 9 deletions.
131 changes: 122 additions & 9 deletions src/data/matrix/basic.lean
Expand Up @@ -25,6 +25,8 @@ def matrix (m : Type u) (n : Type u') [fintype m] [fintype n] (α : Type v) : Ty
m → n → α

variables {l m n o : Type*} [fintype l] [fintype m] [fintype n] [fintype o]
variables {m' : o → Type*} [∀ i, fintype (m' i)]
variables {n' : o → Type*} [∀ i, fintype (n' i)]
variables {α : Type v}

namespace matrix
Expand Down Expand Up @@ -1123,9 +1125,12 @@ section has_zero

variables [has_zero α]

/-- `matrix.block_diagonal M` turns `M : o → matrix m n α'` into a
`m × o`-by`n × o` block matrix which has the entries of `M` along the diagonal
and zero elsewhere. -/
/-- `matrix.block_diagonal M` turns a homogenously-indexed collection of matrices
`M : o → matrix m n α'` into a `m × o`-by-`n × o` block matrix which has the entries of `M` along
the diagonal and zero elsewhere.
See also `matrix.block_diagonal'` if the matrices may not have the same size everywhere.
-/
def block_diagonal : matrix (m × o) (n × o) α
| ⟨i, k⟩ ⟨j, k'⟩ := if k = k' then M k i j else 0

Expand All @@ -1143,7 +1148,7 @@ lemma block_diagonal_apply_ne (i j) {k k'} (h : k ≠ k') :
if_neg h

@[simp] lemma block_diagonal_transpose :
(block_diagonal M)ᵀ = (block_diagonal (λ k, (M k)ᵀ)) :=
(block_diagonal M)ᵀ = block_diagonal (λ k, (M k)ᵀ) :=
begin
ext,
simp only [transpose_apply, block_diagonal_apply, eq_comm],
Expand All @@ -1157,16 +1162,16 @@ end
by { ext, simp [block_diagonal_apply] }

@[simp] lemma block_diagonal_diagonal [decidable_eq m] (d : o → m → α) :
(block_diagonal (λ k, diagonal (d k))) = diagonal (λ ik, d ik.2 ik.1) :=
block_diagonal (λ k, diagonal (d k)) = diagonal (λ ik, d ik.2 ik.1) :=
begin
ext ⟨i, k⟩ ⟨j, k'⟩,
simp only [block_diagonal_apply, diagonal],
split_ifs; finish
end

@[simp] lemma block_diagonal_one [decidable_eq m] [has_one α] :
(block_diagonal (1 : o → matrix m m α)) = 1 :=
show (block_diagonal (λ (_ : o), diagonal (λ (_ : m), (1 : α)))) = diagonal (λ _, 1),
block_diagonal (1 : o → matrix m m α) = 1 :=
show block_diagonal (λ (_ : o), diagonal (λ (_ : m), (1 : α))) = diagonal (λ _, 1),
by rw [block_diagonal_diagonal]

end has_zero
Expand All @@ -1191,8 +1196,8 @@ end
block_diagonal (M - N) = block_diagonal M - block_diagonal N :=
by simp [sub_eq_add_neg]

@[simp] lemma block_diagonal_mul {p : Type*} [fintype p] [semiring α]
(N : o → matrix n p α) : block_diagonal (λ k, M k ⬝ N k) = block_diagonal M ⬝ block_diagonal N :=
@[simp] lemma block_diagonal_mul {p : Type*} [fintype p] [semiring α] (N : o → matrix n p α) :
block_diagonal (λ k, M k ⬝ N k) = block_diagonal M ⬝ block_diagonal N :=
begin
ext ⟨i, k⟩ ⟨j, k'⟩,
simp only [block_diagonal_apply, mul_apply, ← finset.univ_product_univ, finset.sum_product],
Expand All @@ -1205,6 +1210,114 @@ by { ext, simp only [block_diagonal_apply, pi.smul_apply, smul_apply], split_ifs

end block_diagonal

section block_diagonal'

variables (M N : Π i, matrix (m' i) (n' i) α) [decidable_eq o]

section has_zero

variables [has_zero α]

/-- `matrix.block_diagonal' M` turns `M : Π i, matrix (m i) (n i) α` into a
`Σ i, m i`-by-`Σ i, n i` block matrix which has the entries of `M` along the diagonal
and zero elsewhere.
This is the dependently-typed version of `matrix.block_diagonal`. -/
def block_diagonal' : matrix (Σ i, m' i) (Σ i, n' i) α
| ⟨k, i⟩ ⟨k', j⟩ := if h : k = k' then M k i (cast (congr_arg n' h.symm) j) else 0

lemma block_diagonal'_eq_block_diagonal (M : o → matrix m n α) {k k'} (i j) :
block_diagonal M (i, k) (j, k') = block_diagonal' M ⟨k, i⟩ ⟨k', j⟩ :=
rfl

lemma block_diagonal'_minor_eq_block_diagonal (M : o → matrix m n α) :
(block_diagonal' M).minor (prod.to_sigma ∘ prod.swap) (prod.to_sigma ∘ prod.swap) =
block_diagonal M :=
matrix.ext $ λ ⟨k, i⟩ ⟨k', j⟩, rfl

lemma block_diagonal'_apply (ik jk) :
block_diagonal' M ik jk = if h : ik.1 = jk.1 then
M ik.1 ik.2 (cast (congr_arg n' h.symm) jk.2) else 0 :=
by { cases ik, cases jk, refl }

@[simp]
lemma block_diagonal'_apply_eq (k i j) :
block_diagonal' M ⟨k, i⟩ ⟨k, j⟩ = M k i j :=
dif_pos rfl

lemma block_diagonal'_apply_ne {k k'} (i j) (h : k ≠ k') :
block_diagonal' M ⟨k, i⟩ ⟨k', j⟩ = 0 :=
dif_neg h

@[simp] lemma block_diagonal'_transpose :
(block_diagonal' M)ᵀ = block_diagonal' (λ k, (M k)ᵀ) :=
begin
ext ⟨ii, ix⟩ ⟨ji, jx⟩,
simp only [transpose_apply, block_diagonal'_apply, eq_comm],
dsimp only,
split_ifs with h₁ h₂ h₂,
{ subst h₁, refl, },
{ exact (h₂ h₁.symm).elim },
{ exact (h₁ h₂.symm).elim },
{ refl }
end

@[simp] lemma block_diagonal'_zero :
block_diagonal' (0 : Π i, matrix (m' i) (n' i) α) = 0 :=
by { ext, simp [block_diagonal'_apply] }

@[simp] lemma block_diagonal'_diagonal [∀ i, decidable_eq (m' i)] (d : Π i, m' i → α) :
block_diagonal' (λ k, diagonal (d k)) = diagonal (λ ik, d ik.1 ik.2) :=
begin
ext ⟨i, k⟩ ⟨j, k'⟩,
simp only [block_diagonal'_apply, diagonal],
split_ifs; finish
end

@[simp] lemma block_diagonal'_one [∀ i, decidable_eq (m' i)] [has_one α] :
block_diagonal' (1 : Π i, matrix (m' i) (m' i) α) = 1 :=
show block_diagonal' (λ (i : o), diagonal (λ (_ : m' i), (1 : α))) = diagonal (λ _, 1),
by rw [block_diagonal'_diagonal]

end has_zero

@[simp] lemma block_diagonal'_add [add_monoid α] :
block_diagonal' (M + N) = block_diagonal' M + block_diagonal' N :=
begin
ext,
simp only [block_diagonal'_apply, add_apply],
split_ifs; simp
end

@[simp] lemma block_diagonal'_neg [add_group α] :
block_diagonal' (-M) = - block_diagonal' M :=
begin
ext,
simp only [block_diagonal'_apply, neg_apply],
split_ifs; simp
end

@[simp] lemma block_diagonal'_sub [add_group α] :
block_diagonal' (M - N) = block_diagonal' M - block_diagonal' N :=
by simp [sub_eq_add_neg]

@[simp] lemma block_diagonal'_mul {p : o → Type*} [Π i, fintype (p i)] [semiring α]
(N : Π i, matrix (n' i) (p i) α) :
block_diagonal' (λ k, M k ⬝ N k) = block_diagonal' M ⬝ block_diagonal' N :=
begin
ext ⟨k, i⟩ ⟨k', j⟩,
simp only [block_diagonal'_apply, mul_apply, ← finset.univ_sigma_univ, finset.sum_sigma],
rw fintype.sum_eq_single k,
{ split_ifs; simp },
{ intros j' hj', exact finset.sum_eq_zero (λ _ _, by rw [dif_neg hj'.symm, zero_mul]) },
end

@[simp] lemma block_diagonal'_smul {R : Type*} [semiring R] [add_comm_monoid α] [semimodule R α]
(x : R) : block_diagonal' (x • M) = x • block_diagonal' M :=
by { ext, simp only [block_diagonal'_apply, pi.smul_apply, smul_apply], split_ifs; simp }

end block_diagonal'

end matrix

namespace ring_hom
Expand Down

0 comments on commit c3c7c34

Please sign in to comment.