Skip to content

Commit

Permalink
feat(data/matrix/block): matrix.block_diagonal is a ring homomorphi…
Browse files Browse the repository at this point in the history
…sm (#13489)

This is one of the steps on the path to showing that the matrix exponential of a block diagonal matrix is a block diagonal matrix of the exponents of the blocks.

As well as adding the new bundled homomorphisms, this generalizes the typeclasses in this file and tidies up the order of arguments.

Finally, this protects some `map_*` lemmas to prevent clashes with the global lemmas of the same name.
  • Loading branch information
eric-wieser committed Apr 19, 2022
1 parent eb22ba4 commit fb44330
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 60 deletions.
8 changes: 5 additions & 3 deletions src/data/matrix/basic.lean
Expand Up @@ -135,15 +135,17 @@ instance [monoid R] [add_monoid α] [distrib_mul_action R α] :
instance [semiring R] [add_comm_monoid α] [module R α] :
module R (matrix m n α) := pi.module _ _ _

@[simp] lemma map_zero [has_zero α] [has_zero β] (f : α → β) (h : f 0 = 0) :
@[simp] protected lemma map_zero [has_zero α] [has_zero β] (f : α → β) (h : f 0 = 0) :
(0 : matrix m n α).map f = 0 :=
by { ext, simp [h], }

lemma map_add [has_add α] [has_add β] (f : α → β) (hf : ∀ a₁ a₂, f (a₁ + a₂) = f a₁ + f a₂)
protected lemma map_add [has_add α] [has_add β] (f : α → β)
(hf : ∀ a₁ a₂, f (a₁ + a₂) = f a₁ + f a₂)
(M N : matrix m n α) : (M + N).map f = M.map f + N.map f :=
ext $ λ _ _, hf _ _

lemma map_sub [has_sub α] [has_sub β] (f : α → β) (hf : ∀ a₁ a₂, f (a₁ - a₂) = f a₁ - f a₂)
protected lemma map_sub [has_sub α] [has_sub β] (f : α → β)
(hf : ∀ a₁ a₂, f (a₁ - a₂) = f a₁ - f a₂)
(M N : matrix m n α) : (M - N).map f = M.map f - N.map f :=
ext $ λ _ _, hf _ _

Expand Down
154 changes: 97 additions & 57 deletions src/data/matrix/block.lean
Expand Up @@ -13,11 +13,13 @@ import data.matrix.basic
* `matrix.from_blocks`: build a block matrix out of 4 blocks
* `matrix.to_blocks₁₁`, `matrix.to_blocks₁₂`, `matrix.to_blocks₂₁`, `matrix.to_blocks₂₂`:
extract each of the four blocks from `matrix.from_blocks`.
* `matrix.block_diagonal`: block diagonal of equally sized blocks
* `matrix.block_diagonal'`: block diagonal of unequally sized blocks
* `matrix.block_diagonal`: block diagonal of equally sized blocks. On square blocks, this is a
ring homomorphisms, `matrix.block_diagonal_ring_hom`.
* `matrix.block_diagonal'`: block diagonal of unequally sized blocks. On square blocks, this is a
ring homomorphisms, `matrix.block_diagonal'_ring_hom`.
-/

variables {l m n o : Type*} {m' : o → Type*} {n' : o → Type*}
variables {l m n o p q : Type*} {m' n' p' : o → Type*}
variables {R : Type*} {S : Type*} {α : Type*} {β : Type*}

open_locale matrix
Expand Down Expand Up @@ -157,16 +159,14 @@ def to_square_block_prop (M : matrix m m α) (p : m → Prop) :
@[simp] lemma to_square_block_prop_def (M : matrix m m α) (p : m → Prop) :
to_square_block_prop M p = λ i j, M ↑i ↑j := rfl

variables [semiring α]

lemma from_blocks_smul
(x : α) (A : matrix n l α) (B : matrix n m α) (C : matrix o l α) (D : matrix o m α) :
lemma from_blocks_smul [has_scalar R α]
(x : R) (A : matrix n l α) (B : matrix n m α) (C : matrix o l α) (D : matrix o m α) :
x • (from_blocks A B C D) = from_blocks (x • A) (x • B) (x • C) (x • D) :=
begin
ext i j, rcases i; rcases j; simp [from_blocks],
end

lemma from_blocks_add
lemma from_blocks_add [has_add α]
(A : matrix n l α) (B : matrix n m α) (C : matrix o l α) (D : matrix o m α)
(A' : matrix n l α) (B' : matrix n m α) (C' : matrix o l α) (D' : matrix o m α) :
(from_blocks A B C D) + (from_blocks A' B' C' D') =
Expand All @@ -176,7 +176,7 @@ begin
ext i j, rcases i; rcases j; refl,
end

lemma from_blocks_multiply {p q : Type*} [fintype l] [fintype m]
lemma from_blocks_multiply [fintype l] [fintype m] [non_unital_non_assoc_semiring α]
(A : matrix n l α) (B : matrix n m α) (C : matrix o l α) (D : matrix o m α)
(A' : matrix l p α) (B' : matrix l q α) (C' : matrix m p α) (D' : matrix m q α) :
(from_blocks A B C D) ⬝ (from_blocks A' B' C' D') =
Expand All @@ -190,23 +190,22 @@ end

variables [decidable_eq l] [decidable_eq m]

@[simp] lemma from_blocks_diagonal (d₁ : l → α) (d₂ : m → α) :
@[simp] lemma from_blocks_diagonal [has_zero α] (d₁ : l → α) (d₂ : m → α) :
from_blocks (diagonal d₁) 0 0 (diagonal d₂) = diagonal (sum.elim d₁ d₂) :=
begin
ext i j, rcases i; rcases j; simp [diagonal],
end

@[simp] lemma from_blocks_one : from_blocks (1 : matrix l l α) 0 0 (1 : matrix m m α) = 1 :=
@[simp] lemma from_blocks_one [has_zero α] [has_one α] :
from_blocks (1 : matrix l l α) 0 0 (1 : matrix m m α) = 1 :=
by { ext i j, rcases i; rcases j; simp [one_apply] }

end block_matrices

section block_diagonal

variables (M N : o → matrix m n α) [decidable_eq o]
variables [decidable_eq o]

section has_zero

variables [has_zero α] [has_zero β]

/-- `matrix.block_diagonal M` turns a homogenously-indexed collection of matrices
Expand All @@ -215,31 +214,31 @@ the diagonal and zero elsewhere.
See also `matrix.block_diagonal'` if the matrices may not have the same size everywhere.
-/
def block_diagonal : matrix (m × o) (n × o) α
def block_diagonal (M : o → matrix m n α) : matrix (m × o) (n × o) α
| ⟨i, k⟩ ⟨j, k'⟩ := if k = k' then M k i j else 0

lemma block_diagonal_apply (ik jk) :
lemma block_diagonal_apply (M : o → matrix m n α) (ik jk) :
block_diagonal M ik jk = if ik.2 = jk.2 then M ik.2 ik.1 jk.1 else 0 :=
by { cases ik, cases jk, refl }

@[simp]
lemma block_diagonal_apply_eq (i j k) :
lemma block_diagonal_apply_eq (M : o → matrix m n α) (i j k) :
block_diagonal M (i, k) (j, k) = M k i j :=
if_pos rfl

lemma block_diagonal_apply_ne (i j) {k k'} (h : k ≠ k') :
lemma block_diagonal_apply_ne (M : o → matrix m n α) (i j) {k k'} (h : k ≠ k') :
block_diagonal M (i, k) (j, k') = 0 :=
if_neg h

lemma block_diagonal_map (f : α → β) (hf : f 0 = 0) :
lemma block_diagonal_map (M : o → matrix m n α) (f : α → β) (hf : f 0 = 0) :
(block_diagonal M).map f = block_diagonal (λ k, (M k).map f) :=
begin
ext,
simp only [map_apply, block_diagonal_apply, eq_comm],
rw [apply_ite f, hf],
end

@[simp] lemma block_diagonal_transpose :
@[simp] lemma block_diagonal_transpose (M : o → matrix m n α) :
(block_diagonal M)ᵀ = block_diagonal (λ k, (M k)ᵀ) :=
begin
ext,
Expand All @@ -250,7 +249,7 @@ begin
end

@[simp] lemma block_diagonal_conj_transpose
{α : Type*} [semiring α] [star_ring α] (M : o → matrix m n α) :
{α : Type*} [add_monoid α] [star_add_monoid α] (M : o → matrix m n α) :
(block_diagonal M)ᴴ = block_diagonal (λ k, (M k)ᴴ) :=
begin
simp only [conj_transpose, block_diagonal_transpose],
Expand All @@ -277,44 +276,62 @@ by rw [block_diagonal_diagonal]

end has_zero

@[simp] lemma block_diagonal_add [add_monoid α] :
@[simp] lemma block_diagonal_add [add_zero_class α] (M N : o → matrix m n α) :
block_diagonal (M + N) = block_diagonal M + block_diagonal N :=
begin
ext,
simp only [block_diagonal_apply, pi.add_apply],
split_ifs; simp
end

@[simp] lemma block_diagonal_neg [add_group α] :
block_diagonal (-M) = - block_diagonal M :=
begin
ext,
simp only [block_diagonal_apply, pi.neg_apply],
split_ifs; simp
section
variables (o m n α)
/-- `matrix.block_diagonal` as an `add_monoid_hom`. -/
@[simps] def block_diagonal_add_monoid_hom [add_zero_class α] :
(o → matrix m n α) →+ matrix (m × o) (n × o) α :=
{ to_fun := block_diagonal,
map_zero' := block_diagonal_zero,
map_add' := block_diagonal_add }
end

@[simp] lemma block_diagonal_sub [add_group α] :
@[simp] lemma block_diagonal_neg [add_group α] (M : o → matrix m n α) :
block_diagonal (-M) = - block_diagonal M :=
map_neg (block_diagonal_add_monoid_hom m n o α) M

@[simp] lemma block_diagonal_sub [add_group α] (M N : o → matrix m n α) :
block_diagonal (M - N) = block_diagonal M - block_diagonal N :=
by simp [sub_eq_add_neg]
map_sub (block_diagonal_add_monoid_hom m n o α) M N

@[simp] lemma block_diagonal_mul {p : Type*} [fintype n] [fintype o] [semiring α]
(N : o → matrix n p α) :
@[simp] lemma block_diagonal_mul [fintype n] [fintype o] [non_unital_non_assoc_semiring α]
(M : o → matrix m n α) (N : o → matrix n p α) :
block_diagonal (λ k, M k ⬝ N k) = block_diagonal M ⬝ block_diagonal N :=
begin
ext ⟨i, k⟩ ⟨j, k'⟩,
simp only [block_diagonal_apply, mul_apply, ← finset.univ_product_univ, finset.sum_product],
split_ifs with h; simp [h]
end

@[simp] lemma block_diagonal_smul {R : Type*} [semiring R] [add_comm_monoid α] [module R α]
(x : R) : block_diagonal (x • M) = x • block_diagonal M :=
section
variables (α m o)
/-- `matrix.block_diagonal` as a `ring_hom`. -/
@[simps]
def block_diagonal_ring_hom [decidable_eq m] [fintype o] [fintype m] [non_assoc_semiring α] :
(o → matrix m m α) →+* matrix (m × o) (m × o) α :=
{ to_fun := block_diagonal,
map_one' := block_diagonal_one,
map_mul' := block_diagonal_mul,
..block_diagonal_add_monoid_hom m m o α }
end

@[simp] lemma block_diagonal_smul {R : Type*} [monoid R] [add_monoid α] [distrib_mul_action R α]
(x : R) (M : o → matrix m n α) : block_diagonal (x • M) = x • block_diagonal M :=
by { ext, simp only [block_diagonal_apply, pi.smul_apply], split_ifs; simp }

end block_diagonal

section block_diagonal'

variables (M N : Π i, matrix (m' i) (n' i) α) [decidable_eq o]
variables [decidable_eq o]

section has_zero

Expand All @@ -325,7 +342,7 @@ variables [has_zero α] [has_zero β]
and zero elsewhere.
This is the dependently-typed version of `matrix.block_diagonal`. -/
def block_diagonal' : matrix (Σ i, m' i) (Σ i, n' i) α
def block_diagonal' (M : Π i, matrix (m' i) (n' i) α) : matrix (Σ i, m' i) (Σ i, n' i) α
| ⟨k, i⟩ ⟨k', j⟩ := if h : k = k' then M k i (cast (congr_arg n' h.symm) j) else 0

lemma block_diagonal'_eq_block_diagonal (M : o → matrix m n α) {k k'} (i j) :
Expand All @@ -337,37 +354,37 @@ lemma block_diagonal'_minor_eq_block_diagonal (M : o → matrix m n α) :
block_diagonal M :=
matrix.ext $ λ ⟨k, i⟩ ⟨k', j⟩, rfl

lemma block_diagonal'_apply (ik jk) :
lemma block_diagonal'_apply (M : Π i, matrix (m' i) (n' i) α) (ik jk) :
block_diagonal' M ik jk = if h : ik.1 = jk.1 then
M ik.1 ik.2 (cast (congr_arg n' h.symm) jk.2) else 0 :=
by { cases ik, cases jk, refl }

@[simp]
lemma block_diagonal'_apply_eq (k i j) :
lemma block_diagonal'_apply_eq (M : Π i, matrix (m' i) (n' i) α) (k i j) :
block_diagonal' M ⟨k, i⟩ ⟨k, j⟩ = M k i j :=
dif_pos rfl

lemma block_diagonal'_apply_ne {k k'} (i j) (h : k ≠ k') :
lemma block_diagonal'_apply_ne (M : Π i, matrix (m' i) (n' i) α) {k k'} (i j) (h : k ≠ k') :
block_diagonal' M ⟨k, i⟩ ⟨k', j⟩ = 0 :=
dif_neg h

lemma block_diagonal'_map (f : α → β) (hf : f 0 = 0) :
lemma block_diagonal'_map (M : Π i, matrix (m' i) (n' i) α) (f : α → β) (hf : f 0 = 0) :
(block_diagonal' M).map f = block_diagonal' (λ k, (M k).map f) :=
begin
ext,
simp only [map_apply, block_diagonal'_apply, eq_comm],
rw [apply_dite f, hf],
end

@[simp] lemma block_diagonal'_transpose :
@[simp] lemma block_diagonal'_transpose (M : Π i, matrix (m' i) (n' i) α) :
(block_diagonal' M)ᵀ = block_diagonal' (λ k, (M k)ᵀ) :=
begin
ext ⟨ii, ix⟩ ⟨ji, jx⟩,
simp only [transpose_apply, block_diagonal'_apply],
split_ifs; cc
end

@[simp] lemma block_diagonal'_conj_transpose {α} [semiring α] [star_ring α]
@[simp] lemma block_diagonal'_conj_transpose {α} [add_monoid α] [star_add_monoid α]
(M : Π i, matrix (m' i) (n' i) α) :
(block_diagonal' M)ᴴ = block_diagonal' (λ k, (M k)ᴴ) :=
begin
Expand All @@ -379,12 +396,14 @@ end
block_diagonal' (0 : Π i, matrix (m' i) (n' i) α) = 0 :=
by { ext, simp [block_diagonal'_apply] }

@[simp] lemma block_diagonal'_diagonal [ i, decidable_eq (m' i)] (d : Π i, m' i → α) :
@[simp] lemma block_diagonal'_diagonal [Π i, decidable_eq (m' i)] (d : Π i, m' i → α) :
block_diagonal' (λ k, diagonal (d k)) = diagonal (λ ik, d ik.1 ik.2) :=
begin
ext ⟨i, k⟩ ⟨j, k'⟩,
simp only [block_diagonal'_apply, diagonal],
split_ifs; cc
obtain rfl | hij := decidable.eq_or_ne i j,
{ simp, },
{ simp [hij] },
end

@[simp] lemma block_diagonal'_one [∀ i, decidable_eq (m' i)] [has_one α] :
Expand All @@ -394,29 +413,37 @@ by rw [block_diagonal'_diagonal]

end has_zero

@[simp] lemma block_diagonal'_add [add_monoid α] :
@[simp] lemma block_diagonal'_add [add_zero_class α] (M N : Π i, matrix (m' i) (n' i) α) :
block_diagonal' (M + N) = block_diagonal' M + block_diagonal' N :=
begin
ext,
simp only [block_diagonal'_apply, pi.add_apply],
split_ifs; simp
end

@[simp] lemma block_diagonal'_neg [add_group α] :
block_diagonal' (-M) = - block_diagonal' M :=
begin
ext,
simp only [block_diagonal'_apply, pi.neg_apply],
split_ifs; simp

section
variables (m' n' α)
/-- `matrix.block_diagonal'` as an `add_monoid_hom`. -/
@[simps] def block_diagonal'_add_monoid_hom [add_zero_class α] :
(Π i, matrix (m' i) (n' i) α) →+ matrix (Σ i, m' i) (Σ i, n' i) α :=
{ to_fun := block_diagonal',
map_zero' := block_diagonal'_zero,
map_add' := block_diagonal'_add }
end

@[simp] lemma block_diagonal'_sub [add_group α] :
@[simp] lemma block_diagonal'_neg [add_group α] (M : Π i, matrix (m' i) (n' i) α) :
block_diagonal' (-M) = - block_diagonal' M :=
map_neg (block_diagonal'_add_monoid_hom m' n' α) M

@[simp] lemma block_diagonal'_sub [add_group α] (M N : Π i, matrix (m' i) (n' i) α) :
block_diagonal' (M - N) = block_diagonal' M - block_diagonal' N :=
by simp [sub_eq_add_neg]
map_sub (block_diagonal'_add_monoid_hom m' n' α) M N

@[simp] lemma block_diagonal'_mul {p : o → Type*} [semiring α] [Π i, fintype (n' i)] [fintype o]
(N : Π i, matrix (n' i) (p i) α) :
block_diagonal' (λ k, M k ⬝ N k) = block_diagonal' M ⬝ block_diagonal' N :=
@[simp] lemma block_diagonal'_mul [non_unital_non_assoc_semiring α]
[Π i, fintype (n' i)] [fintype o]
(M : Π i, matrix (m' i) (n' i) α) (N : Π i, matrix (n' i) (p' i) α) :
block_diagonal' (λ k, M k ⬝ N k) = block_diagonal' M ⬝ block_diagonal' N :=
begin
ext ⟨k, i⟩ ⟨k', j⟩,
simp only [block_diagonal'_apply, mul_apply, ← finset.univ_sigma_univ, finset.sum_sigma],
Expand All @@ -425,8 +452,21 @@ begin
{ intros j' hj', exact finset.sum_eq_zero (λ _ _, by rw [dif_neg hj'.symm, zero_mul]) },
end

section
variables (α m')
/-- `matrix.block_diagonal'` as a `ring_hom`. -/
@[simps]
def block_diagonal'_ring_hom [Π i, decidable_eq (m' i)] [fintype o] [Π i, fintype (m' i)]
[non_assoc_semiring α] :
(Π i, matrix (m' i) (m' i) α) →+* matrix (Σ i, m' i) (Σ i, m' i) α :=
{ to_fun := block_diagonal',
map_one' := block_diagonal'_one,
map_mul' := block_diagonal'_mul,
..block_diagonal'_add_monoid_hom m' m' α }
end

@[simp] lemma block_diagonal'_smul {R : Type*} [semiring R] [add_comm_monoid α] [module R α]
(x : R) : block_diagonal' (x • M) = x • block_diagonal' M :=
(x : R) (M : Π i, matrix (m' i) (n' i) α) : block_diagonal' (x • M) = x • block_diagonal' M :=
by { ext, simp only [block_diagonal'_apply, pi.smul_apply], split_ifs; simp }

end block_diagonal'
Expand Down

0 comments on commit fb44330

Please sign in to comment.