Skip to content

Commit

Permalink
feat(algebra/to_additive): do not additivize operations on constant t…
Browse files Browse the repository at this point in the history
…ypes (#7792)

* Fixes #4210
* Adds a heuristic to `@[to_additive]` that decides which multiplicative identifiers to replace with their additive counterparts. 
* See [Zulip](https://leanprover.zulipchat.com/#narrow/stream/113488-general/topic/to_additive.20and.20fixed.20types) thread or documentation for the precise heuristic.
* We tag some types with `@[to_additive]`, so that they are handled correctly by the heurstic. These types `pempty`, `empty`, `unit` and `punit`.
* We make the following change to enable to above bullet point: you are allowed to translate a declaration to itself, only if you write its name again as argument of the attribute (if you don't specify an argument we want to raise an error, since that likely is a mistake).
* Because of this heuristic, all declarations with the `@[to_additive]` attribute should have a type with a multiplicative structure on it as its first argument. The first argument should not be an arbitrary indexing type. This means that in `finset.prod` and `finprod` we reorder the first two (implicit) arguments, so that the first argument is the codomain of the function.
* This will eliminate many (but not all) type mismatches generated by `@[to_additive]`.
* This heuristic doesn't catch all cases: for example, the declaration could have two type arguments with multiplicative structure, and the second one is `ℕ`, but the first one is a variable.



Co-authored-by: Floris van Doorn <fpv@andrew.cmu.edu>
  • Loading branch information
fpvandoorn and Floris van Doorn committed Jun 29, 2021
1 parent e6ec901 commit 2c749b1
Show file tree
Hide file tree
Showing 14 changed files with 229 additions and 121 deletions.
91 changes: 23 additions & 68 deletions src/algebra/big_operators/basic.lean
Expand Up @@ -27,10 +27,16 @@ Let `s` be a `finset α`, and `f : α → β` a function.
* `∑ x, f x` is notation for `finset.sum finset.univ f`
(assuming `α` is a `fintype` and `β` is an `add_comm_monoid`)
## Implementation Notes
The first arguments in all definitions and lemmas is the codomain of the function of the big
operator. This is necessary for the heuristic in `@[to_additive]`.
See the documentation of `to_additive.attr` for more information.
-/

universes u v w
variables {α : Type u} {β : Type v} {γ : Type w}
variables {β : Type u} {α : Type v} {γ : Type w}

namespace finset

Expand All @@ -42,7 +48,7 @@ as `x` ranges over the elements of the finite set `s`.
of the finite set `s`."]
protected def prod [comm_monoid β] (s : finset α) (f : α → β) : β := (s.1.map f).prod

@[simp, to_additive] lemma prod_mk [comm_monoid β] (s : multiset α) (hs) (f : α → β) :
@[simp, to_additive] lemma prod_mk [comm_monoid β] (s : multiset α) (hs : s.nodup) (f : α → β) :
(⟨s, hs⟩ : finset α).prod f = (s.map f).prod :=
rfl

Expand Down Expand Up @@ -88,7 +94,7 @@ variables {s s₁ s₂ : finset α} {a : α} {f g : α → β}

@[to_additive]
theorem prod_eq_fold [comm_monoid β] (s : finset α) (f : α → β) :
(∏ x in s, f x) = s.fold (*) 1 f :=
∏ x in s, f x = s.fold (*) 1 f :=
rfl

@[simp] lemma sum_multiset_singleton (s : finset α) :
Expand Down Expand Up @@ -157,7 +163,7 @@ section comm_monoid
variables [comm_monoid β]

@[simp, to_additive]
lemma prod_empty {α : Type u} {f : α → β} : (∏ x in (∅:finset α), f x) = 1 := rfl
lemma prod_empty {f : α → β} : (∏ x in (∅:finset α), f x) = 1 := rfl

@[simp, to_additive]
lemma prod_insert [decidable_eq α] : a ∉ s → (∏ x in (insert a s), f x) = f a * ∏ x in s, f x :=
Expand Down Expand Up @@ -197,9 +203,9 @@ by rw [prod_insert (not_mem_singleton.2 h), prod_singleton]

@[simp, priority 1100] lemma prod_const_one : (∏ x in s, (1 : β)) = 1 :=
by simp only [finset.prod, multiset.map_const, multiset.prod_repeat, one_pow]
@[simp, priority 1100] lemma sum_const_zero {β} {s : finset α} [add_comm_monoid β] :
@[simp, priority 1100] lemma sum_const_zero α} {s : finset α} [add_comm_monoid β] :
(∑ x in s, (0 : β)) = 0 :=
@prod_const_one _ (multiplicative β) _ _
@prod_const_one (multiplicative β) _ _ _
attribute [to_additive] prod_const_one

@[simp, to_additive]
Expand Down Expand Up @@ -740,14 +746,6 @@ begin
exact prod_congr rfl hfg
end

lemma sum_range_succ_comm {β} [add_comm_monoid β] (f : ℕ → β) (n : ℕ) :
∑ x in range (n + 1), f x = f n + ∑ x in range n, f x :=
by rw [range_succ, sum_insert not_mem_range_self]

lemma sum_range_succ {β} [add_comm_monoid β] (f : ℕ → β) (n : ℕ) :
∑ x in range (n + 1), f x = ∑ x in range n, f x + f n :=
by simp only [add_comm, sum_range_succ_comm]

@[to_additive]
lemma prod_range_succ_comm (f : ℕ → β) (n : ℕ) :
∏ x in range (n + 1), f x = f n * ∏ x in range n, f x :=
Expand All @@ -758,11 +756,13 @@ lemma prod_range_succ (f : ℕ → β) (n : ℕ) :
∏ x in range (n + 1), f x = (∏ x in range n, f x) * f n :=
by simp only [mul_comm, prod_range_succ_comm]

@[to_additive]
lemma prod_range_succ' (f : ℕ → β) :
∀ n : ℕ, (∏ k in range (n + 1), f k) = (∏ k in range n, f (k+1)) * f 0
| 0 := prod_range_succ _ _
| (n + 1) := by rw [prod_range_succ _ n, mul_right_comm, ← prod_range_succ', prod_range_succ]

@[to_additive]
lemma eventually_constant_prod {u : ℕ → β} {N : ℕ} (hu : ∀ n ≥ N, u n = 1) {n : ℕ} (hn : N ≤ n) :
∏ k in range (n + 1), u k = ∏ k in range (N + 1), u k :=
begin
Expand All @@ -774,13 +774,7 @@ begin
simp [hu]
end

lemma eventually_constant_sum {β} [add_comm_monoid β] {u : ℕ → β} {N : ℕ}
(hu : ∀ n ≥ N, u n = 0) {n : ℕ} (hn : N ≤ n) :
∑ k in range (n + 1), u k = ∑ k in range (N + 1), u k :=
@eventually_constant_prod (multiplicative β) _ _ _ hu _ hn

attribute [to_additive] eventually_constant_prod

@[to_additive]
lemma prod_range_add (f : ℕ → β) (n m : ℕ) :
∏ x in range (n + m), f x =
(∏ x in range n, f x) * (∏ x in range m, f (n + x)) :=
Expand All @@ -795,15 +789,10 @@ lemma prod_range_zero (f : ℕ → β) :
∏ k in range 0, f k = 1 :=
by rw [range_zero, prod_empty]

@[to_additive sum_range_one]
lemma prod_range_one (f : ℕ → β) :
∏ k in range 1, f k = f 0 :=
by { rw [range_one], apply @prod_singleton ℕ β 0 f }

lemma sum_range_one {δ : Type*} [add_comm_monoid δ] (f : ℕ → δ) :
∑ k in range 1, f k = f 0 :=
@prod_range_one (multiplicative δ) _ f

attribute [to_additive finset.sum_range_one] prod_range_one
by { rw [range_one], apply @prod_singleton β ℕ 0 f }

open multiset

Expand Down Expand Up @@ -938,7 +927,7 @@ lemma prod_pow (s : finset α) (n : ℕ) (f : α → β) :
by haveI := classical.dec_eq α; exact
finset.induction_on s (by simp) (by simp [mul_pow] {contextual := tt})

-- `to_additive` fails on this lemma, so we prove it manually below
@[to_additive]
lemma prod_flip {n : ℕ} (f : ℕ → β) :
∏ r in range (n + 1), f (n - r) = ∏ k in range (n + 1), f k :=
begin
Expand Down Expand Up @@ -1068,6 +1057,8 @@ by { rw [update_eq_piecewise, prod_piecewise], simp [h] }

/-- If a product of a `finset` of size at most 1 has a given value, so
do the terms in that product. -/
@[to_additive eq_of_card_le_one_of_sum_eq "If a sum of a `finset` of size at most 1 has a given
value, so do the terms in that sum."]
lemma eq_of_card_le_one_of_prod_eq {s : finset α} (hc : s.card ≤ 1) {f : α → β} {b : β}
(h : ∏ x in s, f x = b) : ∀ x ∈ s, f x = b :=
begin
Expand All @@ -1084,26 +1075,6 @@ begin
exact h }
end

/-- If a sum of a `finset` of size at most 1 has a given value, so do
the terms in that sum. -/
lemma eq_of_card_le_one_of_sum_eq [add_comm_monoid γ] {s : finset α} (hc : s.card ≤ 1)
{f : α → γ} {b : γ} (h : ∑ x in s, f x = b) : ∀ x ∈ s, f x = b :=
begin
intros x hx,
by_cases hc0 : s.card = 0,
{ exact false.elim (card_ne_zero_of_mem hx hc0) },
{ have h1 : s.card = 1 := le_antisymm hc (nat.one_le_of_lt (nat.pos_of_ne_zero hc0)),
rw card_eq_one at h1,
cases h1 with x2 hx2,
rw [hx2, mem_singleton] at hx,
simp_rw hx2 at h,
rw hx,
rw sum_singleton at h,
exact h }
end

attribute [to_additive eq_of_card_le_one_of_sum_eq] eq_of_card_le_one_of_prod_eq

/-- If a function applied at a point is 1, a product is unchanged by
removing that point, if present, from a `finset`. -/
@[to_additive "If a function applied at a point is 0, a sum is unchanged by
Expand Down Expand Up @@ -1158,12 +1129,12 @@ attribute [to_additive] prod_update_of_mem

lemma sum_nsmul [add_comm_monoid β] (s : finset α) (n : ℕ) (f : α → β) :
(∑ x in s, n • (f x)) = n • ((∑ x in s, f x)) :=
@prod_pow _ (multiplicative β) _ _ _ _
@prod_pow (multiplicative β) _ _ _ _ _
attribute [to_additive sum_nsmul] prod_pow

@[simp] lemma sum_const [add_comm_monoid β] (b : β) :
(∑ x in s, b) = s.card • b :=
@prod_const _ (multiplicative β) _ _ _
@prod_const (multiplicative β) _ _ _ _
attribute [to_additive] prod_const

lemma card_eq_sum_ones (s : finset α) : s.card = ∑ _ in s, 1 :=
Expand Down Expand Up @@ -1193,15 +1164,10 @@ lemma sum_int_cast [add_comm_group β] [has_one β] (s : finset α) (f : α →

lemma sum_comp [add_comm_monoid β] [decidable_eq γ] {s : finset α} (f : γ → β) (g : α → γ) :
∑ a in s, f (g a) = ∑ b in s.image g, (s.filter (λ a, g a = b)).card • (f b) :=
@prod_comp _ (multiplicative β) _ _ _ _ _ _
@prod_comp (multiplicative β) _ _ _ _ _ _ _
attribute [to_additive "The sum of the composition of functions `f` and `g`, is the sum
over `b ∈ s.image g` of `f b` times of the cardinality of the fibre of `b`"] prod_comp

lemma sum_range_succ' [add_comm_monoid β] (f : ℕ → β) :
∀ n : ℕ, (∑ i in range (n + 1), f i) = (∑ i in range n, f (i + 1)) + f 0 :=
@prod_range_succ' (multiplicative β) _ _
attribute [to_additive] prod_range_succ'

lemma eq_sum_range_sub [add_comm_group β] (f : ℕ → β) (n : ℕ) :
f n = f 0 + ∑ i in range n, (f (i+1) - f i) :=
by { rw finset.sum_range_sub, abel }
Expand All @@ -1213,17 +1179,6 @@ begin
simp [finset.sum_range_succ', add_comm]
end

lemma sum_range_add {β} [add_comm_monoid β] (f : ℕ → β) (n : ℕ) (m : ℕ) :
(∑ x in range (n + m), f x) =
(∑ x in range n, f x) + (∑ x in range m, f (n + x)) :=
@prod_range_add (multiplicative β) _ _ _ _
attribute [to_additive] prod_range_add

lemma sum_flip [add_comm_monoid β] {n : ℕ} (f : ℕ → β) :
(∑ i in range (n + 1), f (n - i)) = (∑ i in range (n + 1), f i) :=
@prod_flip (multiplicative β) _ _ _
attribute [to_additive] prod_flip

section opposite

open opposite
Expand Down
10 changes: 7 additions & 3 deletions src/algebra/big_operators/finprod.lean
Expand Up @@ -59,6 +59,10 @@ Another application is the construction of a partition of unity from a collectio
function. In this case the finite set depends on the point and it's convenient to have a definition
that does not mention the set explicitly.
The first arguments in all definitions and lemmas is the codomain of the function of the big
operator. This is necessary for the heuristic in `@[to_additive]`.
See the documentation of `to_additive.attr` for more information.
## Todo
We did not add `is_finite (X : Type) : Prop`, because it is simply `nonempty (fintype X)`.
Expand All @@ -77,7 +81,7 @@ open function set
-/

section sort
variables {α β ι : Sort*} {M N : Type*} [comm_monoid M] [comm_monoid N]
variables {M N : Type*} {α β ι : Sort*} [comm_monoid M] [comm_monoid N]

open_locale big_operators

Expand All @@ -89,7 +93,7 @@ open_locale classical

/-- Sum of `f x` as `x` ranges over the elements of the support of `f`, if it's finite. Zero
otherwise. -/
@[irreducible] noncomputable def finsum {M} [add_comm_monoid M] (f : α → M) : M :=
@[irreducible] noncomputable def finsum {M α} [add_comm_monoid M] (f : α → M) : M :=
if h : finite (support (f ∘ plift.down)) then ∑ i in h.to_finset, f i.down else 0

/-- Product of `f x` as `x` ranges over the elements of the multiplicative support of `f`, if it's
Expand Down Expand Up @@ -143,7 +147,7 @@ begin
end

@[simp, to_additive] lemma finprod_true (f : true → M) : ∏ᶠ i, f i = f trivial :=
@finprod_unique true M _ ⟨⟨trivial⟩, λ _, rfl⟩ f
@finprod_unique M true _ ⟨⟨trivial⟩, λ _, rfl⟩ f

@[to_additive] lemma finprod_eq_dif {p : Prop} [decidable p] (f : p → M) :
∏ᶠ i, f i = if h : p then f h else 1 :=
Expand Down
7 changes: 5 additions & 2 deletions src/algebra/category/Group/basic.lean
Expand Up @@ -33,7 +33,8 @@ namespace Group
@[to_additive]
instance : bundled_hom.parent_projection group.to_monoid := ⟨⟩

attribute [derive [has_coe_to_sort, large_category, concrete_category]] Group AddGroup
attribute [derive [has_coe_to_sort, large_category, concrete_category]] Group
attribute [to_additive] Group.has_coe_to_sort Group.large_category Group.concrete_category

/-- Construct a bundled `Group` from the underlying type and typeclass. -/
@[to_additive] def of (X : Type u) [group X] : Group := bundled.of X
Expand Down Expand Up @@ -87,7 +88,9 @@ namespace CommGroup
@[to_additive]
instance : bundled_hom.parent_projection comm_group.to_group := ⟨⟩

attribute [derive [has_coe_to_sort, large_category, concrete_category]] CommGroup AddCommGroup
attribute [derive [has_coe_to_sort, large_category, concrete_category]] CommGroup
attribute [to_additive] CommGroup.has_coe_to_sort CommGroup.large_category
CommGroup.concrete_category

/-- Construct a bundled `CommGroup` from the underlying type and typeclass. -/
@[to_additive] def of (G : Type u) [comm_group G] : CommGroup := bundled.of G
Expand Down
6 changes: 4 additions & 2 deletions src/algebra/category/Mon/basic.lean
Expand Up @@ -45,7 +45,8 @@ instance bundled_hom : bundled_hom assoc_monoid_hom :=
λ M N P [monoid M] [monoid N] [monoid P], by exactI @monoid_hom.comp M N P _ _ _,
λ M N [monoid M] [monoid N], by exactI @monoid_hom.coe_inj M N _ _⟩

attribute [derive [has_coe_to_sort, large_category, concrete_category]] Mon AddMon
attribute [derive [has_coe_to_sort, large_category, concrete_category]] Mon
attribute [to_additive] Mon.has_coe_to_sort Mon.large_category Mon.concrete_category

/-- Construct a bundled `Mon` from the underlying type and typeclass. -/
@[to_additive]
Expand Down Expand Up @@ -79,7 +80,8 @@ namespace CommMon
@[to_additive]
instance : bundled_hom.parent_projection comm_monoid.to_monoid := ⟨⟩

attribute [derive [has_coe_to_sort, large_category, concrete_category]] CommMon AddCommMon
attribute [derive [has_coe_to_sort, large_category, concrete_category]] CommMon
attribute [to_additive] CommMon.has_coe_to_sort CommMon.large_category CommMon.concrete_category

/-- Construct a bundled `CommMon` from the underlying type and typeclass. -/
@[to_additive]
Expand Down
7 changes: 5 additions & 2 deletions src/algebra/category/Semigroup/basic.lean
Expand Up @@ -42,7 +42,8 @@ namespace Magma
instance bundled_hom : bundled_hom @mul_hom :=
⟨@mul_hom.to_fun, @mul_hom.id, @mul_hom.comp, @mul_hom.coe_inj⟩

attribute [derive [has_coe_to_sort, large_category, concrete_category]] Magma AddMagma
attribute [derive [has_coe_to_sort, large_category, concrete_category]] Magma
attribute [to_additive] Magma.has_coe_to_sort Magma.large_category Magma.concrete_category

/-- Construct a bundled `Magma` from the underlying type and typeclass. -/
@[to_additive]
Expand Down Expand Up @@ -73,7 +74,9 @@ namespace Semigroup
@[to_additive]
instance : bundled_hom.parent_projection semigroup.to_has_mul := ⟨⟩

attribute [derive [has_coe_to_sort, large_category, concrete_category]] Semigroup AddSemigroup
attribute [derive [has_coe_to_sort, large_category, concrete_category]] Semigroup
attribute [to_additive] Semigroup.has_coe_to_sort Semigroup.large_category
Semigroup.concrete_category

/-- Construct a bundled `Semigroup` from the underlying type and typeclass. -/
@[to_additive]
Expand Down

0 comments on commit 2c749b1

Please sign in to comment.