Skip to content

Commit

Permalink
refactor(data/fintype/basic): simplify the defeq of sum.fintype (#1…
Browse files Browse the repository at this point in the history
…7236)

Using `finset.disj_sum` instead of `finset.union` removes a handful of proof obligations, and makes some results defeq.

This removes the generalization on `univ_sum_type` that includes handling `fintype.subsingleton`, as it probably isn't useful; but the new version holds without decidable equality.

This removes `finset.prod_on_sum` which is a duplicate of `fintype.prod_sum_type`
  • Loading branch information
eric-wieser committed Oct 31, 2022
1 parent 88bad67 commit e2b33ab
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 60 deletions.
37 changes: 8 additions & 29 deletions src/algebra/big_operators/basic.lean
Expand Up @@ -349,38 +349,17 @@ lemma prod_sdiff [decidable_eq α] (h : s₁ ⊆ s₂) :
by rw [←prod_union sdiff_disjoint, sdiff_union_of_subset h]

@[simp, to_additive]
lemma prod_sum_elim [decidable_eq (α ⊕ γ)]
(s : finset α) (t : finset γ) (f : α → β) (g : γ → β) :
∏ x in s.map function.embedding.inl ∪ t.map function.embedding.inr, sum.elim f g x =
(∏ x in s, f x) * (∏ x in t, g x) :=
lemma prod_disj_sum (s : finset α) (t : finset γ) (f : α ⊕ γ → β) :
∏ x in s.disj_sum t, f x = (∏ x in s, f (sum.inl x)) * (∏ x in t, f (sum.inr x)) :=
begin
rw [prod_union, prod_map, prod_map],
{ simp only [sum.elim_inl, function.embedding.inl_apply, function.embedding.inr_apply,
sum.elim_inr] },
{ simp only [disjoint_left, finset.mem_map, finset.mem_map],
rintros _ ⟨i, hi, rfl⟩ ⟨j, hj, H⟩,
cases H }
rw [←map_inl_disj_union_map_inr, prod_disj_union, prod_map, prod_map],
refl,
end

@[simp, to_additive]
lemma prod_on_sum [fintype α] [fintype γ] (f : α ⊕ γ → β) :
∏ (x : α ⊕ γ), f x =
(∏ (x : α), f (sum.inl x)) * (∏ (x : γ), f (sum.inr x)) :=
begin
haveI := classical.dec_eq (α ⊕ γ),
convert prod_sum_elim univ univ (λ x, f (sum.inl x)) (λ x, f (sum.inr x)),
{ ext a,
split,
{ intro x,
cases a,
{ simp only [mem_union, mem_map, mem_univ, function.embedding.inl_apply, or_false,
exists_true_left, exists_apply_eq_apply, function.embedding.inr_apply, exists_false], },
{ simp only [mem_union, mem_map, mem_univ, function.embedding.inl_apply, false_or,
exists_true_left, exists_false, function.embedding.inr_apply,
exists_apply_eq_apply], }, },
{ simp only [mem_univ, implies_true_iff], }, },
{ simp only [sum.elim_comp_inl_inr], },
end
@[to_additive]
lemma prod_sum_elim (s : finset α) (t : finset γ) (f : α → β) (g : γ → β) :
∏ x in s.disj_sum t, sum.elim f g x = (∏ x in s, f x) * (∏ x in t, g x) :=
by simp

@[to_additive]
lemma prod_bUnion [decidable_eq α] {s : finset γ} {t : γ → finset α}
Expand Down
2 changes: 1 addition & 1 deletion src/algebra/big_operators/fin.lean
Expand Up @@ -146,7 +146,7 @@ begin
rw fintype.prod_equiv fin_sum_fin_equiv.symm f (λ i, f (fin_sum_fin_equiv.to_fun i)), swap,
{ intro x,
simp only [equiv.to_fun_as_coe, equiv.apply_symm_apply], },
apply prod_on_sum,
apply fintype.prod_sum_type,
end

@[to_additive]
Expand Down
10 changes: 4 additions & 6 deletions src/analysis/convex/combination.lean
Expand Up @@ -73,9 +73,7 @@ lemma finset.center_mass_segment'
(s : finset ι) (t : finset ι') (ws : ι → R) (zs : ι → E) (wt : ι' → R) (zt : ι' → E)
(hws : ∑ i in s, ws i = 1) (hwt : ∑ i in t, wt i = 1) (a b : R) (hab : a + b = 1) :
a • s.center_mass ws zs + b • t.center_mass wt zt =
(s.map embedding.inl ∪ t.map embedding.inr).center_mass
(sum.elim (λ i, a * ws i) (λ j, b * wt j))
(sum.elim zs zt) :=
(s.disj_sum t).center_mass (sum.elim (λ i, a * ws i) (λ j, b * wt j)) (sum.elim zs zt) :=
begin
rw [s.center_mass_eq_of_sum_1 _ hws, t.center_mass_eq_of_sum_1 _ hwt,
smul_sum, smul_sum, ← finset.sum_sum_elim, finset.center_mass_eq_of_sum_1],
Expand Down Expand Up @@ -287,13 +285,13 @@ begin
rw [finset.center_mass_segment' _ _ _ _ _ _ hwx₁ hwy₁ _ _ hab],
refine ⟨_, _, _, _, _, _, _, rfl⟩,
{ rintros i hi,
rw [finset.mem_union, finset.mem_map, finset.mem_map] at hi,
rw [finset.mem_disj_sum] at hi,
rcases hi with ⟨j, hj, rfl⟩|⟨j, hj, rfl⟩;
simp only [sum.elim_inl, sum.elim_inr];
apply_rules [mul_nonneg, hwx₀, hwy₀] },
{ simp [finset.sum_sum_elim, finset.mul_sum.symm, *] },
{ simp [finset.sum_sum_elim, finset.mul_sum.symm, *], },
{ intros i hi,
rw [finset.mem_union, finset.mem_map, finset.mem_map] at hi,
rw [finset.mem_disj_sum] at hi,
rcases hi with ⟨j, hj, rfl⟩|⟨j, hj, rfl⟩; apply_rules [hzx, hzy] } },
{ rintros _ ⟨ι, t, w, z, hw₀, hw₁, hz, rfl⟩,
exact t.center_mass_mem_convex_hull hw₀ (hw₁.symm ▸ zero_lt_one) hz }
Expand Down
10 changes: 10 additions & 0 deletions src/data/finset/sum.lean
Expand Up @@ -34,6 +34,16 @@ val_inj.1 $ multiset.disj_sum_zero _

@[simp] lemma card_disj_sum : (s.disj_sum t).card = s.card + t.card := multiset.card_disj_sum _ _

/-- Note that this is not stated with `disjoint` so that it can be used with `finset.disj_union`. -/
lemma disjoint_map_inl_map_inr {α β : Type*} (s : finset α) (t : finset β) (a : α ⊕ β) :
a ∈ (s.map embedding.inl : finset (α ⊕ β)) → a ∉ (t.map embedding.inr : finset (α ⊕ β)) :=
by { simp_rw mem_map, rintro ⟨a, _, rfl⟩ ⟨b, _, ⟨⟩⟩ }

@[simp]
lemma map_inl_disj_union_map_inr :
(s.map embedding.inl).disj_union (t.map embedding.inr) (disjoint_map_inl_map_inr _ _) =
s.disj_sum t := rfl

variables {s t} {s₁ s₂ : finset α} {t₁ t₂ : finset β} {a : α} {b : β} {x : α ⊕ β}

lemma mem_disj_sum : x ∈ s.disj_sum t ↔ (∃ a, a ∈ s ∧ inl a = x) ∨ ∃ b, b ∈ t ∧ inr b = x :=
Expand Down
29 changes: 8 additions & 21 deletions src/data/fintype/basic.lean
Expand Up @@ -10,6 +10,7 @@ import data.finset.pi
import data.finset.powerset
import data.finset.prod
import data.finset.sigma
import data.finset.sum
import data.finite.defs
import data.list.nodup_equiv_fin
import data.sym.basic
Expand Down Expand Up @@ -975,19 +976,13 @@ instance (α : Type*) [fintype α] : fintype (lex α) := ‹fintype α›
@[simp] lemma fintype.card_lex (α : Type*) [fintype α] :
fintype.card (lex α) = fintype.card α := rfl

lemma univ_sum_type {α β : Type*} [fintype α] [fintype β] [fintype (α ⊕ β)] [decidable_eq (α ⊕ β)] :
(univ : finset (α ⊕ β)) = map function.embedding.inl univ ∪ map function.embedding.inr univ :=
begin
rw [eq_comm, eq_univ_iff_forall], simp only [mem_union, mem_map, exists_prop, mem_univ, true_and],
rintro (x|y), exacts [or.inl ⟨x, rfl⟩, or.inr ⟨y, rfl⟩]
end

instance (α : Type u) (β : Type v) [fintype α] [fintype β] : fintype (α ⊕ β) :=
@fintype.of_equiv _ _ (@sigma.fintype _
(λ b, cond b (ulift α) (ulift.{(max u v) v} β)) _
(λ b, by cases b; apply ulift.fintype))
((equiv.sum_equiv_sigma_bool _ _).symm.trans
(equiv.sum_congr equiv.ulift equiv.ulift))
{ elems := univ.disj_sum univ,
complete := by rintro (_ | _); simp }

@[simp] lemma finset.univ_disj_sum_univ {α β : Type*} [fintype α] [fintype β] :
univ.disj_sum univ = (univ : finset (α ⊕ β)) :=
rfl

/-- Given that `α ⊕ β` is a fintype, `α` is also a fintype. This is non-computable as it uses
that `sum.inl` is an injection, but there's no clear inverse if `α` is empty. -/
Expand All @@ -1001,15 +996,7 @@ fintype.of_injective (sum.inr : β → α ⊕ β) sum.inr_injective

@[simp] theorem fintype.card_sum [fintype α] [fintype β] :
fintype.card (α ⊕ β) = fintype.card α + fintype.card β :=
begin
classical,
rw [←finset.card_univ, univ_sum_type, finset.card_union_eq],
{ simp [finset.card_univ] },
{ intros x hx,
rsuffices ⟨⟨a, rfl⟩, ⟨b, hb⟩⟩ : (∃ (a : α), sum.inl a = x) ∧ ∃ (b : β), sum.inr b = x,
{ simpa using hb },
simpa using hx }
end
card_disj_sum _ _

/-- If the subtype of all-but-one elements is a `fintype` then the type itself is a `fintype`. -/
def fintype_of_fintype_ne (a : α) (h : fintype {b // b ≠ a}) : fintype α :=
Expand Down
6 changes: 3 additions & 3 deletions src/data/fintype/card.lean
Expand Up @@ -241,11 +241,11 @@ variables {α₁ : Type*} {α₂ : Type*} {M : Type*} [fintype α₁] [fintype
@[to_additive]
lemma fintype.prod_sum_elim (f : α₁ → M) (g : α₂ → M) :
(∏ x, sum.elim f g x) = (∏ a₁, f a₁) * (∏ a₂, g a₂) :=
by { classical, rw [univ_sum_type, prod_sum_elim] }
prod_disj_sum _ _ _

@[to_additive]
@[simp, to_additive]
lemma fintype.prod_sum_type (f : α₁ ⊕ α₂ → M) :
(∏ x, f x) = (∏ a₁, f (sum.inl a₁)) * (∏ a₂, f (sum.inr a₂)) :=
by simp only [← fintype.prod_sum_elim, sum.elim_comp_inl_inr]
prod_disj_sum _ _ _

end

0 comments on commit e2b33ab

Please sign in to comment.