Skip to content

Commit

Permalink
feat(data/matrix/basic): lemmas about mul_vec and single (#13835)
Browse files Browse the repository at this point in the history
We seem to be proving variants of the same statement over and over again; this introduces a new lemma that we can use to prove all these variants trivially in term mode. The new lemmas are:

* `matrix.mul_vec_single`
* `matrix.single_vec_mul`
* `matrix.diagonal_mul_vec_single`
* `matrix.single_vec_mul_diagonal`

A lot of the proofs got shorter by avoiding `ext` which invokes a more powerful lemma than we actually need.
  • Loading branch information
eric-wieser committed Jun 1, 2022
1 parent f359d55 commit 892f889
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 30 deletions.
28 changes: 28 additions & 0 deletions src/data/matrix/basic.lean
Expand Up @@ -1142,6 +1142,34 @@ lemma mul_vec_smul [fintype n] [monoid R] [non_unital_non_assoc_semiring S] [dis
M.mul_vec (b • v) = b • M.mul_vec v :=
by { ext i, simp only [mul_vec, dot_product, finset.smul_sum, pi.smul_apply, mul_smul_comm] }

@[simp] lemma mul_vec_single [fintype n] [decidable_eq n] [non_unital_non_assoc_semiring R]
(M : matrix m n R) (j : n) (x : R) :
M.mul_vec (pi.single j x) = (λ i, M i j * x) :=
funext $ λ i, dot_product_single _ _ _

@[simp] lemma single_vec_mul [fintype m] [decidable_eq m] [non_unital_non_assoc_semiring R]
(M : matrix m n R) (i : m) (x : R) :
vec_mul (pi.single i x) M = (λ j, x * M i j) :=
funext $ λ i, single_dot_product _ _ _

@[simp] lemma diagonal_mul_vec_single [fintype n] [decidable_eq n] [non_unital_non_assoc_semiring R]
(v : n → R) (j : n) (x : R) :
(diagonal v).mul_vec (pi.single j x) = pi.single j (v j * x) :=
begin
ext i,
rw mul_vec_diagonal,
exact pi.apply_single (λ i x, v i * x) (λ i, mul_zero _) j x i,
end

@[simp] lemma single_vec_mul_diagonal [fintype n] [decidable_eq n] [non_unital_non_assoc_semiring R]
(v : n → R) (j : n) (x : R) :
vec_mul (pi.single j x) (diagonal v) = pi.single j (x * v j) :=
begin
ext i,
rw vec_mul_diagonal,
exact pi.apply_single (λ i x, x * v i) (λ i, zero_mul _) j x i,
end

end non_unital_non_assoc_semiring

section non_unital_semiring
Expand Down
13 changes: 3 additions & 10 deletions src/linear_algebra/matrix/diagonal.lean
Expand Up @@ -32,23 +32,16 @@ variables {n : Type*} [fintype n] [decidable_eq n] {R : Type v} [comm_ring R]

lemma proj_diagonal (i : n) (w : n → R) :
(proj i).comp (to_lin' (diagonal w)) = (w i) • proj i :=
by ext j; simp [mul_vec_diagonal]
linear_map.ext $ λ j, mul_vec_diagonal _ _ _

lemma diagonal_comp_std_basis (w : n → R) (i : n) :
(diagonal w).to_lin'.comp (linear_map.std_basis R (λ_:n, R) i) =
(w i) • linear_map.std_basis R (λ_:n, R) i :=
begin
ext j,
simp_rw [linear_map.comp_apply, to_lin'_apply, mul_vec_diagonal, linear_map.smul_apply,
pi.smul_apply, algebra.id.smul_eq_mul],
by_cases i = j,
{ subst h },
{ rw [std_basis_ne R (λ_:n, R) _ _ (ne.symm h), _root_.mul_zero, _root_.mul_zero] }
end
linear_map.ext $ λ x, (diagonal_mul_vec_single w _ _).trans (pi.single_smul' i (w i) _)

lemma diagonal_to_lin' (w : n → R) :
(diagonal w).to_lin' = linear_map.pi (λi, w i • linear_map.proj i) :=
by ext v j; simp [mul_vec_diagonal]
linear_map.ext $ λ v, funext $ λ i, mul_vec_diagonal _ _ _

end comm_ring

Expand Down
27 changes: 8 additions & 19 deletions src/linear_algebra/matrix/to_lin.lean
Expand Up @@ -125,12 +125,12 @@ linear_map.to_matrix_right'.symm
@[simp] lemma matrix.to_linear_map_right'_mul [fintype l] [decidable_eq l] (M : matrix l m R)
(N : matrix m n R) : matrix.to_linear_map_right' (M ⬝ N) =
(matrix.to_linear_map_right' N).comp (matrix.to_linear_map_right' M) :=
by { ext, simp, }
linear_map.ext $ λ x, (vec_mul_vec_mul _ M N).symm

lemma matrix.to_linear_map_right'_mul_apply [fintype l] [decidable_eq l] (M : matrix l m R)
(N : matrix m n R) (x) : matrix.to_linear_map_right' (M ⬝ N) x =
(matrix.to_linear_map_right' N (matrix.to_linear_map_right' M x)) :=
by rw [matrix.to_linear_map_right'_mul, linear_map.comp_apply]
(vec_mul_vec_mul _ M N).symm

@[simp] lemma matrix.to_linear_map_right'_one :
matrix.to_linear_map_right' (1 : matrix m m R) = id :=
Expand Down Expand Up @@ -174,15 +174,7 @@ variables [fintype n] [decidable_eq n]

@[simp] lemma matrix.mul_vec_std_basis (M : matrix m n R) (i j) :
M.mul_vec (std_basis R (λ _, R) j 1) i = M i j :=
begin
have : (∑ j', M i j' * if j = j' then 1 else 0) = M i j,
{ simp_rw [mul_boole, finset.sum_ite_eq, finset.mem_univ, if_true] },
convert this,
ext,
split_ifs with h; simp only [std_basis_apply],
{ rw [h, function.update_same] },
{ rw [function.update_noteq (ne.symm h), pi.zero_apply] }
end
(congr_fun (matrix.mul_vec_single _ _ (1 : R)) i).trans $ mul_one _

/-- Linear maps `(n → R) →ₗ[R] (m → R)` are linearly equivalent to `matrix m n R`. -/
def linear_map.to_matrix' : ((n → R) →ₗ[R] (m → R)) ≃ₗ[R] matrix m n R :=
Expand Down Expand Up @@ -242,7 +234,7 @@ by { ext, rw [matrix.one_apply, linear_map.to_matrix'_apply, id_apply] }

@[simp] lemma matrix.to_lin'_mul [fintype m] [decidable_eq m] (M : matrix l m R)
(N : matrix m n R) : matrix.to_lin' (M ⬝ N) = (matrix.to_lin' M).comp (matrix.to_lin' N) :=
by { ext, simp, }
linear_map.ext $ λ x, (mul_vec_mul_vec _ _ _).symm

/-- Shortcut lemma for `matrix.to_lin'_mul` and `linear_map.comp_apply` -/
lemma matrix.to_lin'_mul_apply [fintype m] [decidable_eq m] (M : matrix l m R)
Expand Down Expand Up @@ -317,23 +309,20 @@ by simp [linear_map.to_matrix_alg_equiv']

@[simp] lemma matrix.to_lin_alg_equiv'_one :
matrix.to_lin_alg_equiv' (1 : matrix n n R) = id :=
by { ext, simp [matrix.one_apply, std_basis_apply] }
matrix.to_lin'_one

@[simp] lemma linear_map.to_matrix_alg_equiv'_id :
(linear_map.to_matrix_alg_equiv' (linear_map.id : (n → R) →ₗ[R] (n → R))) = 1 :=
by { ext, rw [matrix.one_apply, linear_map.to_matrix_alg_equiv'_apply, id_apply] }
linear_map.to_matrix'_id

@[simp] lemma matrix.to_lin_alg_equiv'_mul (M N : matrix n n R) :
matrix.to_lin_alg_equiv' (M ⬝ N) =
(matrix.to_lin_alg_equiv' M).comp (matrix.to_lin_alg_equiv' N) :=
by { ext, simp }
matrix.to_lin'_mul _ _

lemma linear_map.to_matrix_alg_equiv'_comp (f g : (n → R) →ₗ[R] (n → R)) :
(f.comp g).to_matrix_alg_equiv' = f.to_matrix_alg_equiv' ⬝ g.to_matrix_alg_equiv' :=
suffices (f.comp g) = (f.to_matrix_alg_equiv' ⬝ g.to_matrix_alg_equiv').to_lin_alg_equiv',
by rw [this, linear_map.to_matrix_alg_equiv'_to_lin_alg_equiv'],
by rw [matrix.to_lin_alg_equiv'_mul, matrix.to_lin_alg_equiv'_to_matrix_alg_equiv',
matrix.to_lin_alg_equiv'_to_matrix_alg_equiv']
linear_map.to_matrix'_comp _ _

lemma linear_map.to_matrix_alg_equiv'_mul
(f g : (n → R) →ₗ[R] (n → R)) :
Expand Down
2 changes: 1 addition & 1 deletion src/linear_algebra/std_basis.lean
Expand Up @@ -48,7 +48,7 @@ lemma std_basis_apply (i : ι) (b : φ i) : std_basis R φ i b = update 0 i b :=
rfl

lemma coe_std_basis (i : ι) : ⇑(std_basis R φ i) = pi.single i :=
funext $ std_basis_apply R φ i
rfl

@[simp] lemma std_basis_same (i : ι) (b : φ i) : std_basis R φ i b i = b :=
by rw [std_basis_apply, update_same]
Expand Down

0 comments on commit 892f889

Please sign in to comment.