From 12515db655a0c4cc759ef5e368fc6d8e592ef983 Mon Sep 17 00:00:00 2001 From: Devon Tuma Date: Sun, 24 Oct 2021 22:52:49 +0000 Subject: [PATCH] feat(data/list): product of list.update_nth in terms of inverses (#9800) Expression for the product of `l.update_nth n x` in terms of inverses instead of `take` and `drop`, assuming a group instead of a monoid. --- src/data/list/basic.lean | 42 +++++++++++++++++++++++--------------- src/data/list/zip.lean | 22 ++++++++++++++++++++ src/data/vector/basic.lean | 24 ++++++++++++++++++++++ src/data/vector/zip.lean | 6 ++++++ 4 files changed, 77 insertions(+), 17 deletions(-) diff --git a/src/data/list/basic.lean b/src/data/list/basic.lean index 5321eb2ffa187..8419e641718b7 100644 --- a/src/data/list/basic.lean +++ b/src/data/list/basic.lean @@ -2463,15 +2463,14 @@ begin exact is_unit.mul (u h (mem_cons_self h t)) (prod_is_unit (λ m mt, u m (mem_cons_of_mem h mt))) end --- `to_additive` chokes on the next few lemmas, so we do them by hand below -@[simp] +@[simp, to_additive] lemma prod_take_mul_prod_drop : ∀ (L : list α) (i : ℕ), (L.take i).prod * (L.drop i).prod = L.prod | [] i := by simp | L 0 := by simp | (h :: t) (n+1) := by { dsimp, rw [prod_cons, prod_cons, mul_assoc, prod_take_mul_prod_drop], } -@[simp] +@[simp, to_additive] lemma prod_take_succ : ∀ (L : list α) (i : ℕ) (p), (L.take (i + 1)).prod = (L.take i).prod * L.nth_le i p | [] i p := by cases p @@ -2482,6 +2481,7 @@ lemma prod_take_succ : lemma length_pos_of_prod_ne_one (L : list α) (h : L.prod ≠ 1) : 0 < L.length := by { cases L, { simp at h, cases h, }, { simp, }, } +@[to_additive] lemma prod_update_nth : ∀ (L : list α) (n : ℕ) (a : α), (L.update_nth n a).prod = (L.take n).prod * (if n < L.length then a else 1) * (L.drop (n + 1)).prod @@ -2505,6 +2505,14 @@ lemma prod_inv_reverse : ∀ (L : list α), L.prod⁻¹ = (L.map (λ x, x⁻¹)) lemma prod_reverse_noncomm : ∀ (L : list α), L.reverse.prod = (L.map (λ x, x⁻¹)).prod⁻¹ := by simp [prod_inv_reverse] +/-- Counterpart to `list.prod_take_succ` when we have an inverse operation -/ +@[simp, to_additive /-"Counterpart to `list.sum_take_succ` when we have an negation operation"-/] +lemma prod_drop_succ : + ∀ (L : list α) (i : ℕ) (p), (L.drop (i + 1)).prod = (L.nth_le i p)⁻¹ * (L.drop i).prod +| [] i p := false.elim (nat.not_lt_zero _ p) +| (x :: xs) 0 p := by simp +| (x :: xs) (i + 1) p := prod_drop_succ xs i _ + end group section comm_group @@ -2516,21 +2524,21 @@ lemma prod_inv : ∀ (L : list α), L.prod⁻¹ = (L.map (λ x, x⁻¹)).prod | [] := by simp | (x :: xs) := by simp [mul_comm, prod_inv xs] -end comm_group - -@[simp] -lemma sum_take_add_sum_drop [add_monoid α] : - ∀ (L : list α) (i : ℕ), (L.take i).sum + (L.drop i).sum = L.sum -| [] i := by simp -| L 0 := by simp -| (h :: t) (n+1) := by { dsimp, rw [sum_cons, sum_cons, add_assoc, sum_take_add_sum_drop], } +/-- Alternative version of `list.prod_update_nth` when the list is over a group -/ +@[to_additive /-"Alternative version of `list.sum_update_nth` when the list is over a group"-/] +lemma prod_update_nth' (L : list α) (n : ℕ) (a : α) : + (L.update_nth n a).prod = + L.prod * (if hn : n < L.length then (L.nth_le n hn)⁻¹ * a else 1) := +begin + refine (prod_update_nth L n a).trans _, + split_ifs with hn hn, + { rw [mul_comm _ a, mul_assoc a, prod_drop_succ L n hn, mul_comm _ (drop n L).prod, + ← mul_assoc (take n L).prod, prod_take_mul_prod_drop, mul_comm a, mul_assoc] }, + { simp only [take_all_of_le (le_of_not_lt hn), prod_nil, mul_one, + drop_eq_nil_of_le ((le_of_not_lt hn).trans n.le_succ)] } +end -@[simp] -lemma sum_take_succ [add_monoid α] : - ∀ (L : list α) (i : ℕ) (p), (L.take (i + 1)).sum = (L.take i).sum + L.nth_le i p -| [] i p := by cases p -| (h :: t) 0 _ := by simp -| (h :: t) (n+1) _ := by { dsimp, rw [sum_cons, sum_cons, sum_take_succ, add_assoc], } +end comm_group lemma eq_of_sum_take_eq [add_left_cancel_monoid α] {L L' : list α} (h : L.length = L'.length) (h' : ∀ i ≤ L.length, (L.take i).sum = (L'.take i).sum) : L = L' := diff --git a/src/data/list/zip.lean b/src/data/list/zip.lean index 13782b9af77d1..1f7603843ab2a 100644 --- a/src/data/list/zip.lean +++ b/src/data/list/zip.lean @@ -381,4 +381,26 @@ end end distrib +section comm_monoid + +variables [comm_monoid α] + +@[to_additive] +lemma prod_mul_prod_eq_prod_zip_with_mul_prod_drop : ∀ (L L' : list α), L.prod * L'.prod = + (zip_with (*) L L').prod * (L.drop L'.length).prod * (L'.drop L.length).prod +| [] ys := by simp +| xs [] := by simp +| (x :: xs) (y :: ys) := begin + simp only [drop, length, zip_with_cons_cons, prod_cons], + rw [mul_assoc x, mul_comm xs.prod, mul_assoc y, mul_comm ys.prod, + prod_mul_prod_eq_prod_zip_with_mul_prod_drop xs ys, mul_assoc, mul_assoc, mul_assoc, mul_assoc] +end + +@[to_additive] +lemma prod_mul_prod_eq_prod_zip_with_of_length_eq (L L' : list α) (h : L.length = L'.length) : + L.prod * L'.prod = (zip_with (*) L L').prod := +(prod_mul_prod_eq_prod_zip_with_mul_prod_drop L L').trans (by simp [h]) + +end comm_monoid + end list diff --git a/src/data/vector/basic.lean b/src/data/vector/basic.lean index 02b1d93927468..1bef6b9ece070 100644 --- a/src/data/vector/basic.lean +++ b/src/data/vector/basic.lean @@ -441,6 +441,10 @@ section update_nth def update_nth (v : vector α n) (i : fin n) (a : α) : vector α n := ⟨v.1.update_nth i.1 a, by rw [list.update_nth_length, v.2]⟩ +@[simp] lemma to_list_update_nth (v : vector α n) (i : fin n) (a : α) : + (v.update_nth i a).to_list = v.to_list.update_nth i a := +rfl + @[simp] lemma nth_update_nth_same (v : vector α n) (i : fin n) (a : α) : (v.update_nth i a).nth i = a := by cases v; cases i; simp [vector.update_nth, vector.nth_eq_nth_le] @@ -454,6 +458,26 @@ lemma nth_update_nth_eq_if {v : vector α n} {i j : fin n} (a : α) : (v.update_nth i a).nth j = if i = j then a else v.nth j := by split_ifs; try {simp *}; try {rw nth_update_nth_of_ne}; assumption +@[to_additive] +lemma prod_update_nth [monoid α] (v : vector α n) (i : fin n) (a : α) : + (v.update_nth i a).to_list.prod = + (v.take i).to_list.prod * a * (v.drop (i + 1)).to_list.prod := +begin + refine (list.prod_update_nth v.to_list i a).trans _, + have : ↑i < v.to_list.length := lt_of_lt_of_le i.2 (le_of_eq v.2.symm), + simp [this], +end + +@[to_additive] +lemma prod_update_nth' [comm_group α] (v : vector α n) (i : fin n) (a : α) : + (v.update_nth i a).to_list.prod = + v.to_list.prod * (v.nth i)⁻¹ * a := +begin + refine (list.prod_update_nth' v.to_list i a).trans _, + have : ↑i < v.to_list.length := lt_of_lt_of_le i.2 (le_of_eq v.2.symm), + simp [this, nth_eq_nth_le, mul_assoc], +end + end update_nth end vector diff --git a/src/data/vector/zip.lean b/src/data/vector/zip.lean index 7b4bb6dbae8a5..e47bc9ee582ad 100644 --- a/src/data/vector/zip.lean +++ b/src/data/vector/zip.lean @@ -40,6 +40,12 @@ lemma zip_with_tail (x : vector α n) (y : vector β n) : (vector.zip_with f x y).tail = vector.zip_with f x.tail y.tail := by { ext, simp [nth_tail], } +@[to_additive] +lemma prod_mul_prod_eq_prod_zip_with [comm_monoid α] (x y : vector α n) : + x.to_list.prod * y.to_list.prod = (vector.zip_with (*) x y).to_list.prod := +list.prod_mul_prod_eq_prod_zip_with_of_length_eq x.to_list y.to_list + ((to_list_length x).trans (to_list_length y).symm) + end zip_with end vector