Skip to content

Commit

Permalink
feat(data/matrix/notation): simp lemmas for constant-indexed elements (
Browse files Browse the repository at this point in the history
…#4491)

If you use the `![]` vector notation to define a vector, then `simp`
can extract elements `0` and `1` from that vector, but not element `2`
or subsequent elements.  (This shows up in particular in geometry if
defining a triangle with a concrete vector of vertices and then
subsequently doing things that need to extract a particular vertex.)
Add `bit0` and `bit1` `simp` lemmas to allow any element indexed by a
numeral to be extracted (even when the numeral is larger than the
length of the list, such numerals in `fin n` being interpreted mod
`n`).

This ends up quite long; `bit0` and `bit1` semantics mean extracting
alternate elements of the vector in a way that can wrap around to the
start of the vector when the numeral is `n` or larger, so the lemmas
need to deal with constructing such a vector of alternate elements.
As I've implemented it, some definitions also need an extra hypothesis
as an argument to control definitional equalities for the vector
lengths, to avoid problems with non-defeq types when stating
subsequent lemmas.  However, even the example added to
`test/matrix.lean` of extracting element `123456789` of a 5-element
vector is processed quite quickly, so this seems efficient enough.

Note also that there are two `@[simp]` lemmas whose proofs are just
`by simp`, but that are in fact needed for `simp` to complete
extracting some elements and that the `simp` linter does not (at least
when used with `#lint` for this single file) complain about.  I'm not
sure what's going on there, since I didn't think a lemma that `simp`
can prove should normally need to be marked as `@[simp]`.
  • Loading branch information
jsm28 committed Oct 15, 2020
1 parent 85d4b57 commit b01ca81
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/data/fin.lean
Expand Up @@ -829,6 +829,14 @@ lemma comp_tail {α : Type*} {β : Type*} (g : α → β) (q : fin n.succ → α
g ∘ (tail q) = tail (g ∘ q) :=
by { ext j, simp [tail] }

/-- `fin.append ho u v` appends two vectors of lengths `m` and `n` to produce
one of length `o = m + n`. `ho` provides control of definitional equality
for the vector length. -/
def append {α : Type*} {o : ℕ} (ho : o = m + n) (u : fin m → α) (v : fin n → α) : fin o → α :=
λ i, if h : (i : ℕ) < m
then u ⟨i, h⟩
else v ⟨(i : ℕ) - m, (nat.sub_lt_left_iff_lt_add (le_of_not_lt h)).2 (ho ▸ i.property)⟩

end tuple

section tuple_right
Expand Down
116 changes: 116 additions & 0 deletions src/data/matrix/notation.lean
Expand Up @@ -130,6 +130,122 @@ cons_val_succ x u 0
vec_cons x u i = x :=
by { fin_cases i, refl }

/-! ### Numeral (`bit0` and `bit1`) indices
The following definitions and `simp` lemmas are to allow any
numeral-indexed element of a vector given with matrix notation to
be extracted by `simp` (even when the numeral is larger than the
number of elements in the vector, which is taken modulo that number
of elements by virtue of the semantics of `bit0` and `bit1` and of
addition on `fin n`).
-/

@[simp] lemma empty_append (v : fin n → α) : fin.append (zero_add _).symm ![] v = v :=
by { ext, simp [fin.append] }

@[simp] lemma cons_append (ho : o + 1 = m + 1 + n) (x : α) (u : fin m → α) (v : fin n → α) :
fin.append ho (vec_cons x u) v =
vec_cons x (fin.append (by rwa [add_assoc, add_comm 1, ←add_assoc,
add_right_cancel_iff] at ho) u v) :=
begin
ext i,
simp_rw [fin.append],
split_ifs with h,
{ rcases i with ⟨⟨⟩ | i, hi⟩,
{ simp },
{ simp only [nat.succ_eq_add_one, add_lt_add_iff_right, fin.coe_mk] at h,
simp [h] } },
{ rcases i with ⟨⟨⟩ | i, hi⟩,
{ simpa using h },
{ rw [not_lt, fin.coe_mk, nat.succ_eq_add_one, add_le_add_iff_right] at h,
simp [h] } }
end

/-- `vec_alt0 v` gives a vector with half the length of `v`, with
only alternate elements (even-numbered). -/
def vec_alt0 (hm : m = n + n) (v : fin m → α) (k : fin n) : α :=
v ⟨(k : ℕ) + k, hm.symm ▸ add_lt_add k.property k.property⟩

/-- `vec_alt1 v` gives a vector with half the length of `v`, with
only alternate elements (odd-numbered). -/
def vec_alt1 (hm : m = n + n) (v : fin m → α) (k : fin n) : α :=
v ⟨(k : ℕ) + k + 1, hm.symm ▸ nat.add_succ_lt_add k.property k.property⟩

lemma vec_alt0_append (v : fin n → α) : vec_alt0 rfl (fin.append rfl v v) = v ∘ bit0 :=
begin
ext i,
simp_rw [function.comp, bit0, vec_alt0, fin.append],
split_ifs with h; congr,
{ rw fin.coe_mk at h,
simp only [fin.ext_iff, fin.coe_add, fin.coe_mk],
exact (nat.mod_eq_of_lt h).symm },
{ rw [fin.coe_mk, not_lt] at h,
simp only [fin.ext_iff, fin.coe_add, fin.coe_mk, nat.mod_eq_sub_mod h],
refine (nat.mod_eq_of_lt _).symm,
rw nat.sub_lt_left_iff_lt_add h,
exact add_lt_add i.property i.property }
end

lemma vec_alt1_append (v : fin (n + 1) → α) : vec_alt1 rfl (fin.append rfl v v) = v ∘ bit1 :=
begin
ext i,
simp_rw [function.comp, vec_alt1, fin.append],
cases n,
{ simp, congr },
{ split_ifs with h; simp_rw [bit1, bit0]; congr,
{ rw fin.coe_mk at h,
simp only [fin.ext_iff, fin.coe_add, fin.coe_mk],
rw nat.mod_eq_of_lt (nat.lt_of_succ_lt h),
exact (nat.mod_eq_of_lt h).symm },
{ rw [fin.coe_mk, not_lt] at h,
simp only [fin.ext_iff, fin.coe_add, fin.coe_mk, nat.mod_add_mod, fin.coe_one,
nat.mod_eq_sub_mod h],
refine (nat.mod_eq_of_lt _).symm,
rw nat.sub_lt_left_iff_lt_add h,
exact nat.add_succ_lt_add i.property i.property } }
end

@[simp] lemma cons_vec_bit0_eq_alt0 (x : α) (u : fin n → α) (i : fin (n + 1)) :
vec_cons x u (bit0 i) = vec_alt0 rfl (fin.append rfl (vec_cons x u) (vec_cons x u)) i :=
by rw vec_alt0_append

@[simp] lemma cons_vec_bit1_eq_alt1 (x : α) (u : fin n → α) (i : fin (n + 1)) :
vec_cons x u (bit1 i) = vec_alt1 rfl (fin.append rfl (vec_cons x u) (vec_cons x u)) i :=
by rw vec_alt1_append

@[simp] lemma cons_vec_alt0 (h : m + 1 + 1 = (n + 1) + (n + 1)) (x y : α) (u : fin m → α) :
vec_alt0 h (vec_cons x (vec_cons y u)) = vec_cons x (vec_alt0
(by rwa [add_assoc n, add_comm 1, ←add_assoc, ←add_assoc, add_right_cancel_iff,
add_right_cancel_iff] at h) u) :=
begin
ext i,
simp_rw [vec_alt0],
rcases i with ⟨⟨⟩ | i, hi⟩,
{ refl },
{ simp [vec_alt0, nat.succ_add] }
end

-- Although proved by simp, extracting element 8 of a five-element
-- vector does not work by simp unless this lemma is present.
@[simp] lemma empty_vec_alt0 (α) {h} : vec_alt0 h (![] : fin 0 → α) = ![] :=
by simp

@[simp] lemma cons_vec_alt1 (h : m + 1 + 1 = (n + 1) + (n + 1)) (x y : α) (u : fin m → α) :
vec_alt1 h (vec_cons x (vec_cons y u)) = vec_cons y (vec_alt1
(by rwa [add_assoc n, add_comm 1, ←add_assoc, ←add_assoc, add_right_cancel_iff,
add_right_cancel_iff] at h) u) :=
begin
ext i,
simp_rw [vec_alt1],
rcases i with ⟨⟨⟩ | i, hi⟩,
{ refl },
{ simp [vec_alt1, nat.succ_add] }
end

-- Although proved by simp, extracting element 9 of a five-element
-- vector does not work by simp unless this lemma is present.
@[simp] lemma empty_vec_alt1 (α) {h} : vec_alt1 h (![] : fin 0 → α) = ![] :=
by simp

end val

section dot_product
Expand Down
28 changes: 28 additions & 0 deletions test/matrix.lean
Expand Up @@ -22,4 +22,32 @@ by simp
example {a b c d : α} : minor ![![a, b], ![c, d]] ![1, 0] ![0] = ![![c], ![a]] :=
by { ext, simp }

example {a b c : α} : ![a, b, c] 0 = a := by simp
example {a b c : α} : ![a, b, c] 1 = b := by simp
example {a b c : α} : ![a, b, c] 2 = c := by simp

example {a b c d : α} : ![a, b, c, d] 0 = a := by simp
example {a b c d : α} : ![a, b, c, d] 1 = b := by simp
example {a b c d : α} : ![a, b, c, d] 2 = c := by simp
example {a b c d : α} : ![a, b, c, d] 3 = d := by simp
example {a b c d : α} : ![a, b, c, d] 42 = c := by simp

example {a b c d e : α} : ![a, b, c, d, e] 0 = a := by simp
example {a b c d e : α} : ![a, b, c, d, e] 1 = b := by simp
example {a b c d e : α} : ![a, b, c, d, e] 2 = c := by simp
example {a b c d e : α} : ![a, b, c, d, e] 3 = d := by simp
example {a b c d e : α} : ![a, b, c, d, e] 4 = e := by simp
example {a b c d e : α} : ![a, b, c, d, e] 5 = a := by simp
example {a b c d e : α} : ![a, b, c, d, e] 6 = b := by simp
example {a b c d e : α} : ![a, b, c, d, e] 7 = c := by simp
example {a b c d e : α} : ![a, b, c, d, e] 8 = d := by simp
example {a b c d e : α} : ![a, b, c, d, e] 9 = e := by simp
example {a b c d e : α} : ![a, b, c, d, e] 123 = d := by simp
example {a b c d e : α} : ![a, b, c, d, e] 123456789 = e := by simp

example {a b c d e f g h : α} : ![a, b, c, d, e, f, g, h] 5 = f := by simp
example {a b c d e f g h : α} : ![a, b, c, d, e, f, g, h] 7 = h := by simp
example {a b c d e f g h : α} : ![a, b, c, d, e, f, g, h] 37 = f := by simp
example {a b c d e f g h : α} : ![a, b, c, d, e, f, g, h] 99 = d := by simp

end matrix

0 comments on commit b01ca81

Please sign in to comment.