Skip to content
This repository was archived by the owner on Jul 24, 2024. It is now read-only.

Commit 3e068ec

Browse files
committed
refactor(data/matrix/basic): work around leanprover/lean4#2042 (#18696)
This adjust definitions such that everything is well-behaved in the case that things are unfolded. For each such definition, a lemma is added that replaces the equation lemma. Before this PR, we used ```lean def transpose (M : matrix m n α) : matrix n m α | x y := M y x ``` which has the nice behavior (in Lean 3 only) of `rw transpose` only unfolding the definition when it is of the applied form `transpose M i j`. If `dunfold transpose` is used then it becomes the undesirable `λ x y, M y x` in both Lean versions. After this PR, we use ```lean def transpose (M : matrix m n α) : matrix n m α := of $ λ x y, M y x -- TODO: set as an equation lemma for `transpose`, see mathlib4#3024 @[simp] lemma transpose_apply (M : matrix m n α) (i j) : transpose M i j = M j i := rfl ``` This no longer has the nice `rw` behavior, but we can't have that in Lean4 anyway (leanprover/lean4#2042). It also makes `dunfold` insert the `of`, which is better for type-safety. This affects * `matrix.transpose` * `matrix.row` * `matrix.col` * `matrix.diagonal` * `matrix.vec_mul_vec` * `matrix.block_diagonal` * `matrix.block_diagonal'` * `matrix.hadamard` * `matrix.kronecker_map` * `pequiv.to_matrix` * `matrix.circulant` * `matrix.mv_polynomial_X` * `algebra.trace_matrix` * `algebra.embeddings_matrix` While this just adds `_apply` noise in Lean 3, it is necessary when porting to Lean 4 as there the equation lemma is not generated in the way that we want. This is hopefully exhaustive; it was found by looking for lines ending in `matrix .*` followed by a `|` line
1 parent ce11c3c commit 3e068ec

File tree

16 files changed

+145
-80
lines changed

16 files changed

+145
-80
lines changed

src/algebra/lie/classical.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ begin
328328
ext i j,
329329
rcases i with ⟨⟨i₁ | i₂⟩ | i₃⟩;
330330
rcases j with ⟨⟨j₁ | j₂⟩ | j₃⟩;
331-
simp only [indefinite_diagonal, matrix.diagonal, equiv.sum_assoc_apply_inl_inl,
331+
simp only [indefinite_diagonal, matrix.diagonal_apply, equiv.sum_assoc_apply_inl_inl,
332332
matrix.reindex_lie_equiv_apply, matrix.submatrix_apply, equiv.symm_symm, matrix.reindex_apply,
333333
sum.elim_inl, if_true, eq_self_iff_true, matrix.one_apply_eq, matrix.from_blocks_apply₁₁,
334334
dmatrix.zero_apply, equiv.sum_assoc_apply_inl_inr, if_false, matrix.from_blocks_apply₁₂,

src/combinatorics/simple_graph/adj_matrix.lean

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,11 +148,12 @@ variables (α)
148148

149149
/-- `adj_matrix G α` is the matrix `A` such that `A i j = (1 : α)` if `i` and `j` are
150150
adjacent in the simple graph `G`, and otherwise `A i j = 0`. -/
151-
def adj_matrix [has_zero α] [has_one α] : matrix V V α
152-
| i j := if (G.adj i j) then 1 else 0
151+
def adj_matrix [has_zero α] [has_one α] : matrix V V α :=
152+
of $ λ i j, if (G.adj i j) then (1 : α) else 0
153153

154154
variable {α}
155155

156+
-- TODO: set as an equation lemma for `adj_matrix`, see mathlib4#3024
156157
@[simp]
157158
lemma adj_matrix_apply (v w : V) [has_zero α] [has_one α] :
158159
G.adj_matrix α v w = if (G.adj v w) then 1 else 0 := rfl

src/data/matrix/basic.lean

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,14 @@ The two sides of the equivalence are definitionally equal types. We want to use
7676
to distinguish the types because `matrix` has different instances to pi types (such as `pi.has_mul`,
7777
which performs elementwise multiplication, vs `matrix.has_mul`).
7878
79-
If you are defining a matrix, in terms of its entries, either use `of (λ i j, _)`, or use pattern
80-
matching in a definition as `| i j := _` (which can only be unfolded when fully-applied). The
81-
purpose of this approach is to ensure that terms of the form `(λ i j, _) * (λ i j, _)` do not
79+
If you are defining a matrix, in terms of its entries, use `of (λ i j, _)`. The
80+
purpose of this approach is to ensure that terms of th
81+
e form `(λ i j, _) * (λ i j, _)` do not
8282
appear, as the type of `*` can be misleading.
83+
84+
Porting note: In Lean 3, it is also safe to use pattern matching in a definition as `| i j := _`,
85+
which can only be unfolded when fully-applied. leanprover/lean4#2042 means this does not
86+
(currently) work in Lean 4.
8387
-/
8488
def of : (m → n → α) ≃ matrix m n α := equiv.refl _
8589
@[simp] lemma of_apply (f : m → n → α) (i j) : of f i j = f i j := rfl
@@ -118,8 +122,12 @@ lemma map_injective {f : α → β} (hf : function.injective f) :
118122
λ M N h, ext $ λ i j, hf $ ext_iff.mpr h i j
119123

120124
/-- The transpose of a matrix. -/
121-
def transpose (M : matrix m n α) : matrix n m α
122-
| x y := M y x
125+
def transpose (M : matrix m n α) : matrix n m α :=
126+
of $ λ x y, M y x
127+
128+
-- TODO: set as an equation lemma for `transpose`, see mathlib4#3024
129+
@[simp] lemma transpose_apply (M : matrix m n α) (i j) :
130+
transpose M i j = M j i := rfl
123131

124132
localized "postfix (name := matrix.transpose) `ᵀ`:1500 := matrix.transpose" in matrix
125133

@@ -130,12 +138,19 @@ M.transpose.map star
130138
localized "postfix (name := matrix.conj_transpose) `ᴴ`:1500 := matrix.conj_transpose" in matrix
131139

132140
/-- `matrix.col u` is the column matrix whose entries are given by `u`. -/
133-
def col (w : m → α) : matrix m unit α
134-
| x y := w x
141+
def col (w : m → α) : matrix m unit α :=
142+
of $ λ x y, w x
143+
144+
-- TODO: set as an equation lemma for `col`, see mathlib4#3024
145+
@[simp] lemma col_apply (w : m → α) (i j) :
146+
col w i j = w i := rfl
135147

136148
/-- `matrix.row u` is the row matrix whose entries are given by `u`. -/
137-
def row (v : n → α) : matrix unit n α
138-
| x y := v y
149+
def row (v : n → α) : matrix unit n α :=
150+
of $ λ x y, v y
151+
152+
-- TODO: set as an equation lemma for `row`, see mathlib4#3024
153+
@[simp] lemma row_apply (v : n → α) (i j) : row v i j = v j := rfl
139154

140155
instance [inhabited α] : inhabited (matrix m n α) := pi.inhabited _
141156
instance [has_add α] : has_add (matrix m n α) := pi.has_add
@@ -239,8 +254,12 @@ Note that bundled versions exist as:
239254
* `matrix.diagonal_ring_hom`
240255
* `matrix.diagonal_alg_hom`
241256
-/
242-
def diagonal [has_zero α] (d : n → α) : matrix n n α
243-
| i j := if i = j then d i else 0
257+
def diagonal [has_zero α] (d : n → α) : matrix n n α :=
258+
of $ λ i j, if i = j then d i else 0
259+
260+
-- TODO: set as an equation lemma for `diagonal`, see mathlib4#3024
261+
lemma diagonal_apply [has_zero α] (d : n → α) (i j) : diagonal d i j = if i = j then d i else 0 :=
262+
rfl
244263

245264
@[simp] theorem diagonal_apply_eq [has_zero α] (d : n → α) (i : n) : (diagonal d) i i = d i :=
246265
by simp [diagonal]
@@ -302,7 +321,7 @@ variables {n α R}
302321

303322
@[simp] lemma diagonal_map [has_zero α] [has_zero β] {f : α → β} (h : f 0 = 0) {d : n → α} :
304323
(diagonal d).map f = diagonal (λ m, f (d m)) :=
305-
by { ext, simp only [diagonal, map_apply], split_ifs; simp [h], }
324+
by { ext, simp only [diagonal_apply, map_apply], split_ifs; simp [h], }
306325

307326
@[simp] lemma diagonal_conj_transpose [add_monoid α] [star_add_monoid α] (v : n → α) :
308327
(diagonal v)ᴴ = diagonal (star v) :=
@@ -1113,8 +1132,13 @@ namespace matrix
11131132

11141133
/-- For two vectors `w` and `v`, `vec_mul_vec w v i j` is defined to be `w i * v j`.
11151134
Put another way, `vec_mul_vec w v` is exactly `col w ⬝ row v`. -/
1116-
def vec_mul_vec [has_mul α] (w : m → α) (v : n → α) : matrix m n α
1117-
| x y := w x * v y
1135+
def vec_mul_vec [has_mul α] (w : m → α) (v : n → α) : matrix m n α :=
1136+
of $ λ x y, w x * v y
1137+
1138+
-- TODO: set as an equation lemma for `vec_mul_vec`, see mathlib4#3024
1139+
lemma vec_mul_vec_apply [has_mul α] (w : m → α) (v : n → α) (i j) :
1140+
vec_mul_vec w v i j = w i * v j :=
1141+
rfl
11181142

11191143
lemma vec_mul_vec_eq [has_mul α] [add_comm_monoid α] (w : m → α) (v : n → α) :
11201144
vec_mul_vec w v = (col w) ⬝ (row v) :=
@@ -1336,13 +1360,6 @@ section transpose
13361360

13371361
open_locale matrix
13381362

1339-
/--
1340-
Tell `simp` what the entries are in a transposed matrix.
1341-
1342-
Compare with `mul_apply`, `diagonal_apply_eq`, etc.
1343-
-/
1344-
@[simp] lemma transpose_apply (M : matrix m n α) (i j) : M.transpose j i = M i j := rfl
1345-
13461363
@[simp] lemma transpose_transpose (M : matrix m n α) :
13471364
Mᵀᵀ = M :=
13481365
by ext; refl
@@ -1353,7 +1370,7 @@ by ext i j; refl
13531370
@[simp] lemma transpose_one [decidable_eq n] [has_zero α] [has_one α] : (1 : matrix n n α)ᵀ = 1 :=
13541371
begin
13551372
ext i j,
1356-
unfold has_one.one transpose,
1373+
rw [transpose_apply, ←diagonal_one],
13571374
by_cases i = j,
13581375
{ simp only [h, diagonal_apply_eq] },
13591376
{ simp only [diagonal_apply_ne _ h, diagonal_apply_ne' _ h] }
@@ -1435,6 +1452,8 @@ def transpose_ring_equiv [add_comm_monoid α] [comm_semigroup α] [fintype m] :
14351452
inv_fun := λ M, M.unopᵀ,
14361453
map_mul' := λ M N, (congr_arg mul_opposite.op (transpose_mul M N)).trans
14371454
(mul_opposite.op_mul _ _),
1455+
left_inv := λ M, transpose_transpose M,
1456+
right_inv := λ M, mul_opposite.unop_injective $ transpose_transpose M.unop,
14381457
..(transpose_add_equiv m m α).trans mul_opposite.op_add_equiv }
14391458

14401459
variables {m α}
@@ -1895,9 +1914,6 @@ by { ext, refl }
18951914
@[simp] lemma row_smul [has_smul R α] (x : R) (v : m → α) : row (x • v) = x • row v :=
18961915
by { ext, refl }
18971916

1898-
@[simp] lemma col_apply (v : m → α) (i j) : matrix.col v i j = v i := rfl
1899-
@[simp] lemma row_apply (v : m → α) (i j) : matrix.row v i j = v j := rfl
1900-
19011917
@[simp]
19021918
lemma transpose_col (v : m → α) : (matrix.col v)ᵀ = matrix.row v := by { ext, refl }
19031919
@[simp]

src/data/matrix/block.lean

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -278,8 +278,13 @@ the diagonal and zero elsewhere.
278278
279279
See also `matrix.block_diagonal'` if the matrices may not have the same size everywhere.
280280
-/
281-
def block_diagonal (M : o → matrix m n α) : matrix (m × o) (n × o) α
282-
| ⟨i, k⟩ ⟨j, k'⟩ := if k = k' then M k i j else 0
281+
def block_diagonal (M : o → matrix m n α) : matrix (m × o) (n × o) α :=
282+
of $ (λ ⟨i, k⟩ ⟨j, k'⟩, if k = k' then M k i j else 0 : m × o → n × o → α)
283+
284+
-- TODO: set as an equation lemma for `block_diagonal`, see mathlib4#3024
285+
lemma block_diagonal_apply' (M : o → matrix m n α) (i k j k') :
286+
block_diagonal M ⟨i, k⟩ ⟨j, k'⟩ = if k = k' then M k i j else 0 :=
287+
rfl
283288

284289
lemma block_diagonal_apply (M : o → matrix m n α) (ik jk) :
285290
block_diagonal M ik jk = if ik.2 = jk.2 then M ik.2 ik.1 jk.1 else 0 :=
@@ -328,7 +333,7 @@ by { ext, simp [block_diagonal_apply] }
328333
block_diagonal (λ k, diagonal (d k)) = diagonal (λ ik, d ik.2 ik.1) :=
329334
begin
330335
ext ⟨i, k⟩ ⟨j, k'⟩,
331-
simp only [block_diagonal_apply, diagonal, prod.mk.inj_iff, ← ite_and],
336+
simp only [block_diagonal_apply, diagonal_apply, prod.mk.inj_iff, ← ite_and],
332337
congr' 1,
333338
rw and_comm,
334339
end
@@ -404,8 +409,12 @@ section block_diag
404409
/-- Extract a block from the diagonal of a block diagonal matrix.
405410
406411
This is the block form of `matrix.diag`, and the left-inverse of `matrix.block_diagonal`. -/
407-
def block_diag (M : matrix (m × o) (n × o) α) (k : o) : matrix m n α
408-
| i j := M (i, k) (j, k)
412+
def block_diag (M : matrix (m × o) (n × o) α) (k : o) : matrix m n α :=
413+
of $ λ i j, M (i, k) (j, k)
414+
415+
-- TODO: set as an equation lemma for `block_diag`, see mathlib4#3024
416+
lemma block_diag_apply (M : matrix (m × o) (n × o) α) (k : o) (i j) :
417+
block_diag M k i j = M (i, k) (j, k) := rfl
409418

410419
lemma block_diag_map (M : matrix (m × o) (n × o) α) (f : α → β) :
411420
block_diag (M.map f) = λ k, (block_diag M k).map f :=
@@ -431,14 +440,14 @@ rfl
431440
block_diag (diagonal d) k = diagonal (λ i, d (i, k)) :=
432441
ext $ λ i j, begin
433442
obtain rfl | hij := decidable.eq_or_ne i j,
434-
{ rw [block_diag, diagonal_apply_eq, diagonal_apply_eq] },
435-
{ rw [block_diag, diagonal_apply_ne _ hij, diagonal_apply_ne _ (mt _ hij)],
443+
{ rw [block_diag_apply, diagonal_apply_eq, diagonal_apply_eq] },
444+
{ rw [block_diag_apply, diagonal_apply_ne _ hij, diagonal_apply_ne _ (mt _ hij)],
436445
exact prod.fst_eq_iff.mpr },
437446
end
438447

439448
@[simp] lemma block_diag_block_diagonal [decidable_eq o] (M : o → matrix m n α) :
440449
block_diag (block_diagonal M) = M :=
441-
funext $ λ k, ext $ λ i j, block_diagonal_apply_eq _ _ _ _
450+
funext $ λ k, ext $ λ i j, block_diagonal_apply_eq M i j _
442451

443452
@[simp] lemma block_diag_one [decidable_eq o] [decidable_eq m] [has_one α] :
444453
block_diag (1 : matrix (m × o) (m × o) α) = 1 :=
@@ -486,8 +495,15 @@ variables [has_zero α] [has_zero β]
486495
and zero elsewhere.
487496
488497
This is the dependently-typed version of `matrix.block_diagonal`. -/
489-
def block_diagonal' (M : Π i, matrix (m' i) (n' i) α) : matrix (Σ i, m' i) (Σ i, n' i) α
490-
| ⟨k, i⟩ ⟨k', j⟩ := if h : k = k' then M k i (cast (congr_arg n' h.symm) j) else 0
498+
def block_diagonal' (M : Π i, matrix (m' i) (n' i) α) : matrix (Σ i, m' i) (Σ i, n' i) α :=
499+
of $ (λ ⟨k, i⟩ ⟨k', j⟩, if h : k = k' then M k i (cast (congr_arg n' h.symm) j) else 0 :
500+
(Σ i, m' i) → (Σ i, n' i) → α)
501+
502+
-- TODO: set as an equation lemma for `block_diagonal'`, see mathlib4#3024
503+
lemma block_diagonal'_apply' (M : Π i, matrix (m' i) (n' i) α) (k i k' j) :
504+
block_diagonal' M ⟨k, i⟩ ⟨k', j⟩ =
505+
if h : k = k' then M k i (cast (congr_arg n' h.symm) j) else 0 :=
506+
rfl
491507

492508
lemma block_diagonal'_eq_block_diagonal (M : o → matrix m n α) {k k'} (i j) :
493509
block_diagonal M (i, k) (j, k') = block_diagonal' M ⟨k, i⟩ ⟨k', j⟩ :=
@@ -625,8 +641,12 @@ section block_diag'
625641
/-- Extract a block from the diagonal of a block diagonal matrix.
626642
627643
This is the block form of `matrix.diag`, and the left-inverse of `matrix.block_diagonal'`. -/
628-
def block_diag' (M : matrix (Σ i, m' i) (Σ i, n' i) α) (k : o) : matrix (m' k) (n' k) α
629-
| i j := M ⟨k, i⟩ ⟨k, j⟩
644+
def block_diag' (M : matrix (Σ i, m' i) (Σ i, n' i) α) (k : o) : matrix (m' k) (n' k) α :=
645+
of $ λ i j, M ⟨k, i⟩ ⟨k, j⟩
646+
647+
-- TODO: set as an equation lemma for `block_diag'`, see mathlib4#3024
648+
lemma block_diag'_apply (M : matrix (Σ i, m' i) (Σ i, n' i) α) (k : o) (i j) :
649+
block_diag' M k i j = M ⟨k, i⟩ ⟨k, j⟩ := rfl
630650

631651
lemma block_diag'_map (M : matrix (Σ i, m' i) (Σ i, n' i) α) (f : α → β) :
632652
block_diag' (M.map f) = λ k, (block_diag' M k).map f :=
@@ -653,14 +673,14 @@ rfl
653673
block_diag' (diagonal d) k = diagonal (λ i, d ⟨k, i⟩) :=
654674
ext $ λ i j, begin
655675
obtain rfl | hij := decidable.eq_or_ne i j,
656-
{ rw [block_diag', diagonal_apply_eq, diagonal_apply_eq] },
657-
{ rw [block_diag', diagonal_apply_ne _ hij, diagonal_apply_ne _ (mt (λ h, _) hij)],
676+
{ rw [block_diag'_apply, diagonal_apply_eq, diagonal_apply_eq] },
677+
{ rw [block_diag'_apply, diagonal_apply_ne _ hij, diagonal_apply_ne _ (mt (λ h, _) hij)],
658678
cases h, refl },
659679
end
660680

661681
@[simp] lemma block_diag'_block_diagonal' [decidable_eq o] (M : Π i, matrix (m' i) (n' i) α) :
662682
block_diag' (block_diagonal' M) = M :=
663-
funext $ λ k, ext $ λ i j, block_diagonal'_apply_eq _ _ _ _
683+
funext $ λ k, ext $ λ i j, block_diagonal'_apply_eq M _ _ _
664684

665685
@[simp] lemma block_diag'_one [decidable_eq o] [Π i, decidable_eq (m' i)] [has_one α] :
666686
block_diag' (1 : matrix (Σ i, m' i) (Σ i, m' i) α) = 1 :=

src/data/matrix/hadamard.lean

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,13 @@ open_locale matrix big_operators
3737

3838
/-- `matrix.hadamard` defines the Hadamard product,
3939
which is the pointwise product of two matrices of the same size.-/
40-
@[simp]
41-
def hadamard [has_mul α] (A : matrix m n α) (B : matrix m n α) : matrix m n α
42-
| i j := A i j * B i j
40+
def hadamard [has_mul α] (A : matrix m n α) (B : matrix m n α) : matrix m n α :=
41+
of $ λ i j, A i j * B i j
4342

43+
-- TODO: set as an equation lemma for `hadamard`, see mathlib4#3024
44+
@[simp]
45+
lemma hadamard_apply [has_mul α] (A : matrix m n α) (B : matrix m n α) (i j) :
46+
hadamard A B i j = A i j * B i j := rfl
4447
localized "infix (name := matrix.hadamard) ` ⊙ `:100 := matrix.hadamard" in matrix
4548

4649
section basic_properties

src/data/matrix/kronecker.lean

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,14 @@ variables {l m n p : Type*} {q r : Type*} {l' m' n' p' : Type*}
5252
section kronecker_map
5353

5454
/-- Produce a matrix with `f` applied to every pair of elements from `A` and `B`. -/
55-
@[simp] def kronecker_map (f : α → β → γ) (A : matrix l m α) (B : matrix n p β) :
56-
matrix (l × n) (m × p) γ
57-
| i j := f (A i.1 j.1) (B i.2 j.2)
55+
def kronecker_map (f : α → β → γ) (A : matrix l m α) (B : matrix n p β) :
56+
matrix (l × n) (m × p) γ :=
57+
of $ λ (i : l × n) (j : m × p), f (A i.1 j.1) (B i.2 j.2)
58+
59+
-- TODO: set as an equation lemma for `kronecker_map`, see mathlib4#3024
60+
@[simp]
61+
lemma kronecker_map_apply (f : α → β → γ) (A : matrix l m α) (B : matrix n p β) (i j) :
62+
kronecker_map f A B i j = f (A i.1 j.1) (B i.2 j.2) := rfl
5863

5964
lemma kronecker_map_transpose (f : α → β → γ)
6065
(A : matrix l m α) (B : matrix n p β) :
@@ -199,7 +204,7 @@ lemma kronecker_map_bilinear_mul_mul [comm_semiring R]
199204
begin
200205
ext ⟨i, i'⟩ ⟨j, j'⟩,
201206
simp only [kronecker_map_bilinear_apply_apply, mul_apply, ← finset.univ_product_univ,
202-
finset.sum_product, kronecker_map],
207+
finset.sum_product, kronecker_map_apply],
203208
simp_rw [f.map_sum, linear_map.sum_apply, linear_map.map_sum, h_comm],
204209
end
205210

@@ -212,7 +217,7 @@ lemma trace_kronecker_map_bilinear [comm_semiring R]
212217
(f : α →ₗ[R] β →ₗ[R] γ) (A : matrix m m α) (B : matrix n n β) :
213218
trace (kronecker_map_bilinear f A B) = f (trace A) (trace B) :=
214219
by simp_rw [matrix.trace, matrix.diag, kronecker_map_bilinear_apply_apply,
215-
linear_map.map_sum₂, map_sum, ←finset.univ_product_univ, finset.sum_product, kronecker_map]
220+
linear_map.map_sum₂, map_sum, ←finset.univ_product_univ, finset.sum_product, kronecker_map_apply]
216221

217222
/-- `determinant` of `matrix.kronecker_map_bilinear`.
218223

src/data/matrix/notation.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ by { ext i, refine fin.cases _ _ i; simp [vec_mul_vec] }
285285

286286
@[simp] lemma vec_mul_vec_cons (v : m' → α) (x : α) (w : fin n → α) :
287287
vec_mul_vec v (vec_cons x w) = λ i, v i • vec_cons x w :=
288-
by { ext i j, rw [vec_mul_vec, pi.smul_apply, smul_eq_mul] }
288+
by { ext i j, rw [vec_mul_vec_apply, pi.smul_apply, smul_eq_mul] }
289289

290290
end vec_mul_vec
291291

src/data/matrix/pequiv.lean

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,13 @@ open_locale matrix
4444

4545
/-- `to_matrix` returns a matrix containing ones and zeros. `f.to_matrix i j` is `1` if
4646
`f i = some j` and `0` otherwise -/
47-
def to_matrix [decidable_eq n] [has_zero α] [has_one α] (f : m ≃. n) : matrix m n α
48-
| i j := if j ∈ f i then 1 else 0
47+
def to_matrix [decidable_eq n] [has_zero α] [has_one α] (f : m ≃. n) : matrix m n α :=
48+
of $ λ i j, if j ∈ f i then (1 : α) else 0
49+
50+
-- TODO: set as an equation lemma for `to_matrix`, see mathlib4#3024
51+
@[simp]
52+
lemma to_matrix_apply [decidable_eq n] [has_zero α] [has_one α] (f : m ≃. n) (i j) :
53+
to_matrix f i j = if j ∈ f i then (1 : α) else 0 := rfl
4954

5055
lemma mul_matrix_apply [fintype m] [decidable_eq m] [semiring α] (f : l ≃. m) (M : matrix m n α)
5156
(i j) : (f.to_matrix ⬝ M) i j = option.cases_on (f i) 0 (λ fi, M fi j) :=
@@ -59,11 +64,11 @@ end
5964

6065
lemma to_matrix_symm [decidable_eq m] [decidable_eq n] [has_zero α] [has_one α] (f : m ≃. n) :
6166
(f.symm.to_matrix : matrix n m α) = f.to_matrixᵀ :=
62-
by ext; simp only [transpose, mem_iff_mem f, to_matrix]; congr
67+
by ext; simp only [transpose, mem_iff_mem f, to_matrix_apply]; congr
6368

6469
@[simp] lemma to_matrix_refl [decidable_eq n] [has_zero α] [has_one α] :
6570
((pequiv.refl n).to_matrix : matrix n n α) = 1 :=
66-
by ext; simp [to_matrix, one_apply]; congr
71+
by ext; simp [to_matrix_apply, one_apply]; congr
6772

6873
lemma matrix_mul_apply [fintype m] [semiring α] [decidable_eq n] (M : matrix l m α) (f : m ≃. n)
6974
(i j) : (M ⬝ f.to_matrix) i j = option.cases_on (f.symm j) 0 (λ fj, M i fj) :=
@@ -104,7 +109,7 @@ begin
104109
classical,
105110
assume f g,
106111
refine not_imp_not.1 _,
107-
simp only [matrix.ext_iff.symm, to_matrix, pequiv.ext_iff,
112+
simp only [matrix.ext_iff.symm, to_matrix_apply, pequiv.ext_iff,
108113
not_forall, exists_imp_distrib],
109114
assume i hi,
110115
use i,

0 commit comments

Comments
 (0)