Skip to content

Commit 99dfc6c

Browse files
feat(Algebra/BigOperators): simprocify prod_univ_one/two/three/... (#23425)
Add a simproc `prod_univ_many` that rewrites `∏ (i : Fin n), f i` as `f 0 * f 1 * ... * f (n - 1)`, generalizing `prod_univ_one`, `prod_univ_two`, ..., `prod_univ_eight`. Co-authored-by: Eric Wieser <wieser.eric@gmail.com>
1 parent a7411a8 commit 99dfc6c

File tree

2 files changed

+99
-0
lines changed

2 files changed

+99
-0
lines changed

Mathlib/Data/Fin/Tuple/Reflection.lean

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,4 +164,86 @@ example [CommMonoid α] (a : Fin 3 → α) : ∏ i, a i = a 0 * a 1 * a 2 :=
164164
example [AddCommMonoid α] (a : Fin 3 → α) : ∑ i, a i = a 0 + a 1 + a 2 :=
165165
(sum_eq _).symm
166166

167+
section Meta
168+
open Lean Meta Qq
169+
170+
/-- Produce a term of the form `f 0 * f 1 * ... * f (n - 1)` and an application of `FinVec.prod_eq`
171+
that shows it is equal to `∏ i, f i`. -/
172+
def mkProdEqQ {u : Level} {α : Q(Type u)} (inst : Q(CommMonoid $α)) (n : ℕ) (f : Q(Fin $n → $α)) :
173+
MetaM <| (val : Q($α)) × Q(∏ i, $f i = $val) := do
174+
match n with
175+
| 0 => return ⟨q((1 : $α)), q(Fin.prod_univ_zero $f)⟩
176+
| m + 1 =>
177+
let nezero : Q(NeZero ($m + 1)) := q(⟨Nat.succ_ne_zero _⟩)
178+
let val ← makeRHS (m + 1) f nezero (m + 1)
179+
let _ : $val =Q FinVec.prod $f := ⟨⟩
180+
return ⟨q($val), q(FinVec.prod_eq $f |>.symm)⟩
181+
where
182+
/-- Creates the expression `f 0 * f 1 * ... * f (n - 1)`. -/
183+
makeRHS (n : ℕ) (f : Q(Fin $n → $α)) (nezero : Q(NeZero $n)) (k : ℕ) : MetaM Q($α) := do
184+
match k with
185+
| 0 => failure
186+
| 1 => pure q($f 0)
187+
| m + 1 =>
188+
let pre ← makeRHS n f nezero m
189+
let mRaw : Q(ℕ) := mkRawNatLit m
190+
pure q($pre * $f (OfNat.ofNat $mRaw))
191+
192+
/-- Produce a term of the form `f 0 + f 1 + ... + f (n - 1)` and an application of `FinVec.sum_eq`
193+
that shows it is equal to `∑ i, f i`. -/
194+
def mkSumEqQ {u : Level} {α : Q(Type u)} (inst : Q(AddCommMonoid $α)) (n : ℕ) (f : Q(Fin $n → $α)) :
195+
MetaM <| (val : Q($α)) × Q(∑ i, $f i = $val) := do
196+
match n with
197+
| 0 => return ⟨q((0 : $α)), q(Fin.sum_univ_zero $f)⟩
198+
| m + 1 =>
199+
let nezero : Q(NeZero ($m + 1)) := q(⟨Nat.succ_ne_zero _⟩)
200+
let val ← makeRHS (m + 1) f nezero (m + 1)
201+
let _ : $val =Q FinVec.sum $f := ⟨⟩
202+
return ⟨q($val), q(FinVec.sum_eq $f |>.symm)⟩
203+
where
204+
/-- Creates the expression `f 0 + f 1 + ... + f (n - 1)`. -/
205+
makeRHS (n : ℕ) (f : Q(Fin $n → $α)) (nezero : Q(NeZero $n)) (k : ℕ) : MetaM Q($α) := do
206+
match k with
207+
| 0 => failure
208+
| 1 => pure q($f 0)
209+
| m + 1 =>
210+
let pre ← makeRHS n f nezero m
211+
let mRaw : Q(ℕ) := mkRawNatLit m
212+
pure q($pre + $f (OfNat.ofNat $mRaw))
213+
214+
end Meta
215+
167216
end FinVec
217+
218+
namespace Fin
219+
open Qq Lean FinVec
220+
221+
/-- Rewrites `∏ i : Fin n, f i` as `f 0 * f 1 * ... * f (n - 1)` when `n` is a numeral. -/
222+
simproc_decl prod_univ_ofNat (∏ _ : Fin _, _) := .ofQ fun u _ e => do
223+
match u, e with
224+
| .succ _, ~q(@Finset.prod (Fin $n) _ $inst (@Finset.univ _ $instF) $f) => do
225+
match (generalizing := false) n.nat? with
226+
| .none =>
227+
return .continue
228+
| .some nVal =>
229+
let ⟨res, pf⟩ ← mkProdEqQ inst nVal f
230+
let ⟨_⟩ ← assertDefEqQ q($instF) q(Fin.fintype _)
231+
have _ : $n =Q $nVal := ⟨⟩
232+
return .visit <| .mk q($res) <| some q($pf)
233+
| _, _ => return .continue
234+
235+
/-- Rewrites `∑ i : Fin n, f i` as `f 0 + f 1 + ... + f (n - 1)` when `n` is a numeral. -/
236+
simproc_decl sum_univ_ofNat (∑ _ : Fin _, _) := .ofQ fun u _ e => do
237+
match u, e with
238+
| .succ _, ~q(@Finset.sum (Fin $n) _ $inst (@Finset.univ _ $instF) $f) => do
239+
match n.nat? with
240+
| .none =>
241+
return .continue
242+
| .some nVal =>
243+
let ⟨res, pf⟩ ← mkSumEqQ inst nVal f
244+
let ⟨_⟩ ← assertDefEqQ q($instF) q(Fin.fintype _)
245+
have _ : $n =Q $nVal := ⟨⟩
246+
return .visit <| .mk q($res) <| some q($pf)
247+
| _, _ => return .continue
248+
249+
end Fin
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import Mathlib.Data.Fin.Tuple.Reflection
2+
3+
@[to_additive]
4+
lemma prod_test (R : Type) [CommMonoid R] (f : Fin 10 → R) :
5+
∏ i, f i = f 0 * f 1 * f 2 * f 3 * f 4 * f 5 * f 6 * f 7 * f 8 * f 9 := by
6+
simp only [Fin.prod_univ_ofNat]
7+
8+
/--
9+
info: sum_test (R : Type) [AddCommMonoid R] (f : Fin 10 → R) :
10+
∑ i, f i = f 0 + f 1 + f 2 + f 3 + f 4 + f 5 + f 6 + f 7 + f 8 + f 9
11+
-/
12+
#guard_msgs in
13+
#check sum_test
14+
15+
example (R : Type) [AddCommMonoid R] (f : Fin 10 → R) :
16+
∑ i, f i = f 0 + f 1 + f 2 + f 3 + f 4 + f 5 + f 6 + f 7 + f 8 + f 9 := by
17+
simp only [Fin.sum_univ_ofNat]

0 commit comments

Comments
 (0)