From e2b33abb01af7fd385dd8df535e9d962b14e681d Mon Sep 17 00:00:00 2001 From: Eric Wieser Date: Mon, 31 Oct 2022 12:47:31 +0000 Subject: [PATCH] refactor(data/fintype/basic): simplify the defeq of `sum.fintype` (#17236) Using `finset.disj_sum` instead of `finset.union` removes a handful of proof obligations, and makes some results defeq. This removes the generalization on `univ_sum_type` that includes handling `fintype.subsingleton`, as it probably isn't useful; but the new version holds without decidable equality. This removes `finset.prod_on_sum` which is a duplicate of `fintype.prod_sum_type` --- src/algebra/big_operators/basic.lean | 37 ++++++---------------------- src/algebra/big_operators/fin.lean | 2 +- src/analysis/convex/combination.lean | 10 +++----- src/data/finset/sum.lean | 10 ++++++++ src/data/fintype/basic.lean | 29 ++++++---------------- src/data/fintype/card.lean | 6 ++--- 6 files changed, 34 insertions(+), 60 deletions(-) diff --git a/src/algebra/big_operators/basic.lean b/src/algebra/big_operators/basic.lean index 451151089d152..b21f23f8c341a 100644 --- a/src/algebra/big_operators/basic.lean +++ b/src/algebra/big_operators/basic.lean @@ -349,38 +349,17 @@ lemma prod_sdiff [decidable_eq α] (h : s₁ ⊆ s₂) : by rw [←prod_union sdiff_disjoint, sdiff_union_of_subset h] @[simp, to_additive] -lemma prod_sum_elim [decidable_eq (α ⊕ γ)] - (s : finset α) (t : finset γ) (f : α → β) (g : γ → β) : - ∏ x in s.map function.embedding.inl ∪ t.map function.embedding.inr, sum.elim f g x = - (∏ x in s, f x) * (∏ x in t, g x) := +lemma prod_disj_sum (s : finset α) (t : finset γ) (f : α ⊕ γ → β) : + ∏ x in s.disj_sum t, f x = (∏ x in s, f (sum.inl x)) * (∏ x in t, f (sum.inr x)) := begin - rw [prod_union, prod_map, prod_map], - { simp only [sum.elim_inl, function.embedding.inl_apply, function.embedding.inr_apply, - sum.elim_inr] }, - { simp only [disjoint_left, finset.mem_map, finset.mem_map], - rintros _ ⟨i, hi, rfl⟩ ⟨j, hj, H⟩, - cases H } + rw [←map_inl_disj_union_map_inr, prod_disj_union, prod_map, prod_map], + refl, end -@[simp, to_additive] -lemma prod_on_sum [fintype α] [fintype γ] (f : α ⊕ γ → β) : - ∏ (x : α ⊕ γ), f x = - (∏ (x : α), f (sum.inl x)) * (∏ (x : γ), f (sum.inr x)) := -begin - haveI := classical.dec_eq (α ⊕ γ), - convert prod_sum_elim univ univ (λ x, f (sum.inl x)) (λ x, f (sum.inr x)), - { ext a, - split, - { intro x, - cases a, - { simp only [mem_union, mem_map, mem_univ, function.embedding.inl_apply, or_false, - exists_true_left, exists_apply_eq_apply, function.embedding.inr_apply, exists_false], }, - { simp only [mem_union, mem_map, mem_univ, function.embedding.inl_apply, false_or, - exists_true_left, exists_false, function.embedding.inr_apply, - exists_apply_eq_apply], }, }, - { simp only [mem_univ, implies_true_iff], }, }, - { simp only [sum.elim_comp_inl_inr], }, -end +@[to_additive] +lemma prod_sum_elim (s : finset α) (t : finset γ) (f : α → β) (g : γ → β) : + ∏ x in s.disj_sum t, sum.elim f g x = (∏ x in s, f x) * (∏ x in t, g x) := +by simp @[to_additive] lemma prod_bUnion [decidable_eq α] {s : finset γ} {t : γ → finset α} diff --git a/src/algebra/big_operators/fin.lean b/src/algebra/big_operators/fin.lean index 13040f2bf89b1..52af1c42c6cfb 100644 --- a/src/algebra/big_operators/fin.lean +++ b/src/algebra/big_operators/fin.lean @@ -146,7 +146,7 @@ begin rw fintype.prod_equiv fin_sum_fin_equiv.symm f (λ i, f (fin_sum_fin_equiv.to_fun i)), swap, { intro x, simp only [equiv.to_fun_as_coe, equiv.apply_symm_apply], }, - apply prod_on_sum, + apply fintype.prod_sum_type, end @[to_additive] diff --git a/src/analysis/convex/combination.lean b/src/analysis/convex/combination.lean index cc139723d51c8..22e67ec1acad2 100644 --- a/src/analysis/convex/combination.lean +++ b/src/analysis/convex/combination.lean @@ -73,9 +73,7 @@ lemma finset.center_mass_segment' (s : finset ι) (t : finset ι') (ws : ι → R) (zs : ι → E) (wt : ι' → R) (zt : ι' → E) (hws : ∑ i in s, ws i = 1) (hwt : ∑ i in t, wt i = 1) (a b : R) (hab : a + b = 1) : a • s.center_mass ws zs + b • t.center_mass wt zt = - (s.map embedding.inl ∪ t.map embedding.inr).center_mass - (sum.elim (λ i, a * ws i) (λ j, b * wt j)) - (sum.elim zs zt) := + (s.disj_sum t).center_mass (sum.elim (λ i, a * ws i) (λ j, b * wt j)) (sum.elim zs zt) := begin rw [s.center_mass_eq_of_sum_1 _ hws, t.center_mass_eq_of_sum_1 _ hwt, smul_sum, smul_sum, ← finset.sum_sum_elim, finset.center_mass_eq_of_sum_1], @@ -287,13 +285,13 @@ begin rw [finset.center_mass_segment' _ _ _ _ _ _ hwx₁ hwy₁ _ _ hab], refine ⟨_, _, _, _, _, _, _, rfl⟩, { rintros i hi, - rw [finset.mem_union, finset.mem_map, finset.mem_map] at hi, + rw [finset.mem_disj_sum] at hi, rcases hi with ⟨j, hj, rfl⟩|⟨j, hj, rfl⟩; simp only [sum.elim_inl, sum.elim_inr]; apply_rules [mul_nonneg, hwx₀, hwy₀] }, - { simp [finset.sum_sum_elim, finset.mul_sum.symm, *] }, + { simp [finset.sum_sum_elim, finset.mul_sum.symm, *], }, { intros i hi, - rw [finset.mem_union, finset.mem_map, finset.mem_map] at hi, + rw [finset.mem_disj_sum] at hi, rcases hi with ⟨j, hj, rfl⟩|⟨j, hj, rfl⟩; apply_rules [hzx, hzy] } }, { rintros _ ⟨ι, t, w, z, hw₀, hw₁, hz, rfl⟩, exact t.center_mass_mem_convex_hull hw₀ (hw₁.symm ▸ zero_lt_one) hz } diff --git a/src/data/finset/sum.lean b/src/data/finset/sum.lean index 8933dd3cdc3e6..acae13b4c23e8 100644 --- a/src/data/finset/sum.lean +++ b/src/data/finset/sum.lean @@ -34,6 +34,16 @@ val_inj.1 $ multiset.disj_sum_zero _ @[simp] lemma card_disj_sum : (s.disj_sum t).card = s.card + t.card := multiset.card_disj_sum _ _ +/-- Note that this is not stated with `disjoint` so that it can be used with `finset.disj_union`. -/ +lemma disjoint_map_inl_map_inr {α β : Type*} (s : finset α) (t : finset β) (a : α ⊕ β) : + a ∈ (s.map embedding.inl : finset (α ⊕ β)) → a ∉ (t.map embedding.inr : finset (α ⊕ β)) := +by { simp_rw mem_map, rintro ⟨a, _, rfl⟩ ⟨b, _, ⟨⟩⟩ } + +@[simp] +lemma map_inl_disj_union_map_inr : + (s.map embedding.inl).disj_union (t.map embedding.inr) (disjoint_map_inl_map_inr _ _) = + s.disj_sum t := rfl + variables {s t} {s₁ s₂ : finset α} {t₁ t₂ : finset β} {a : α} {b : β} {x : α ⊕ β} lemma mem_disj_sum : x ∈ s.disj_sum t ↔ (∃ a, a ∈ s ∧ inl a = x) ∨ ∃ b, b ∈ t ∧ inr b = x := diff --git a/src/data/fintype/basic.lean b/src/data/fintype/basic.lean index 9554897591a95..75388787661ff 100644 --- a/src/data/fintype/basic.lean +++ b/src/data/fintype/basic.lean @@ -10,6 +10,7 @@ import data.finset.pi import data.finset.powerset import data.finset.prod import data.finset.sigma +import data.finset.sum import data.finite.defs import data.list.nodup_equiv_fin import data.sym.basic @@ -975,19 +976,13 @@ instance (α : Type*) [fintype α] : fintype (lex α) := ‹fintype α› @[simp] lemma fintype.card_lex (α : Type*) [fintype α] : fintype.card (lex α) = fintype.card α := rfl -lemma univ_sum_type {α β : Type*} [fintype α] [fintype β] [fintype (α ⊕ β)] [decidable_eq (α ⊕ β)] : - (univ : finset (α ⊕ β)) = map function.embedding.inl univ ∪ map function.embedding.inr univ := -begin - rw [eq_comm, eq_univ_iff_forall], simp only [mem_union, mem_map, exists_prop, mem_univ, true_and], - rintro (x|y), exacts [or.inl ⟨x, rfl⟩, or.inr ⟨y, rfl⟩] -end - instance (α : Type u) (β : Type v) [fintype α] [fintype β] : fintype (α ⊕ β) := -@fintype.of_equiv _ _ (@sigma.fintype _ - (λ b, cond b (ulift α) (ulift.{(max u v) v} β)) _ - (λ b, by cases b; apply ulift.fintype)) - ((equiv.sum_equiv_sigma_bool _ _).symm.trans - (equiv.sum_congr equiv.ulift equiv.ulift)) +{ elems := univ.disj_sum univ, + complete := by rintro (_ | _); simp } + +@[simp] lemma finset.univ_disj_sum_univ {α β : Type*} [fintype α] [fintype β] : + univ.disj_sum univ = (univ : finset (α ⊕ β)) := +rfl /-- Given that `α ⊕ β` is a fintype, `α` is also a fintype. This is non-computable as it uses that `sum.inl` is an injection, but there's no clear inverse if `α` is empty. -/ @@ -1001,15 +996,7 @@ fintype.of_injective (sum.inr : β → α ⊕ β) sum.inr_injective @[simp] theorem fintype.card_sum [fintype α] [fintype β] : fintype.card (α ⊕ β) = fintype.card α + fintype.card β := -begin - classical, - rw [←finset.card_univ, univ_sum_type, finset.card_union_eq], - { simp [finset.card_univ] }, - { intros x hx, - rsuffices ⟨⟨a, rfl⟩, ⟨b, hb⟩⟩ : (∃ (a : α), sum.inl a = x) ∧ ∃ (b : β), sum.inr b = x, - { simpa using hb }, - simpa using hx } -end +card_disj_sum _ _ /-- If the subtype of all-but-one elements is a `fintype` then the type itself is a `fintype`. -/ def fintype_of_fintype_ne (a : α) (h : fintype {b // b ≠ a}) : fintype α := diff --git a/src/data/fintype/card.lean b/src/data/fintype/card.lean index 8d5532a2a5ee6..336770c457b41 100644 --- a/src/data/fintype/card.lean +++ b/src/data/fintype/card.lean @@ -241,11 +241,11 @@ variables {α₁ : Type*} {α₂ : Type*} {M : Type*} [fintype α₁] [fintype @[to_additive] lemma fintype.prod_sum_elim (f : α₁ → M) (g : α₂ → M) : (∏ x, sum.elim f g x) = (∏ a₁, f a₁) * (∏ a₂, g a₂) := -by { classical, rw [univ_sum_type, prod_sum_elim] } +prod_disj_sum _ _ _ -@[to_additive] +@[simp, to_additive] lemma fintype.prod_sum_type (f : α₁ ⊕ α₂ → M) : (∏ x, f x) = (∏ a₁, f (sum.inl a₁)) * (∏ a₂, f (sum.inr a₂)) := -by simp only [← fintype.prod_sum_elim, sum.elim_comp_inl_inr] +prod_disj_sum _ _ _ end