Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Merged by Bors] - feat(measure_theory/probability_mass_function): Measures of sets under pmf monad operations #11613

Closed
wants to merge 5 commits into from
Closed
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 102 additions & 7 deletions src/measure_theory/probability_mass_function/monad.lean
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,27 @@ lemma mem_support_pure_iff: a' ∈ (pure a).support ↔ a' = a := by simp

instance [inhabited α] : inhabited (pmf α) := ⟨pure default⟩

section measure

variable (s : set α)

lemma to_outer_measure_pure_apply : (pure a).to_outer_measure s = if a ∈ s then 1 else 0 :=
begin
refine (to_outer_measure_apply' (pure a) s).trans _,
split_ifs with ha ha,
{ refine ennreal.coe_eq_one.2 ((tsum_congr (λ b, _)).trans (tsum_ite_eq a 1)),
exact ite_eq_left_iff.2 (λ hb, symm (ite_eq_right_iff.2 (λ h, (hb $ h.symm ▸ ha).elim))) },
{ refine ennreal.coe_eq_zero.2 ((tsum_congr (λ b, _)).trans (tsum_zero)),
exact ite_eq_right_iff.2 (λ hb, ite_eq_right_iff.2 (λ h, (ha $ h ▸ hb).elim)) }
end

/-- The measure of a set under `pure a` is `1` for sets containing `a` and `0` otherwise -/
lemma to_measure_pure_apply [measurable_space α] (hs : measurable_set s) :
(pure a).to_measure s = if a ∈ s then 1 else 0 :=
(to_measure_apply_eq_to_outer_measure_apply (pure a) s hs).trans (to_outer_measure_pure_apply a s)

end measure

end pure

section bind
Expand All @@ -63,7 +84,7 @@ def bind (p : pmf α) (f : α → pmf β) : pmf β :=
(ennreal.coe_tsum p.summable_coe).symm]
end⟩

variables (p : pmf α) (f : α → pmf β)
variables (p : pmf α) (f : α → pmf β) (g : β → pmf γ)

@[simp] lemma bind_apply (b : β) : p.bind f b = ∑'a, p a * f a b := rfl

Expand All @@ -73,22 +94,20 @@ set.ext (λ b, by simp [mem_support_iff, tsum_eq_zero_iff (bind.summable p f b),
lemma mem_support_bind_iff (b : β) : b ∈ (p.bind f).support ↔ ∃ a ∈ p.support, b ∈ (f a).support :=
by simp

lemma coe_bind_apply (p : pmf α) (f : α → pmf β) (b : β) :
(p.bind f b : ℝ≥0∞) = ∑'a, p a * f a b :=
lemma coe_bind_apply (b : β) : (p.bind f b : ℝ≥0∞) = ∑'a, p a * f a b :=
eq.trans (ennreal.coe_tsum $ bind.summable p f b) $ by simp

@[simp] lemma pure_bind (a : α) (f : α → pmf β) : (pure a).bind f = f a :=
@[simp] lemma pure_bind (a : α) : (pure a).bind f = f a :=
have ∀ b a', ite (a' = a) 1 0 * f a' b = ite (a' = a) (f a b) 0, from
assume b a', by split_ifs; simp; subst h; simp,
by ext b; simp [this]

@[simp] lemma bind_pure (p : pmf α) : p.bind pure = p :=
@[simp] lemma bind_pure : p.bind pure = p :=
have ∀ a a', (p a * ite (a' = a) 1 0) = ite (a = a') (p a') 0, from
assume a a', begin split_ifs; try { subst a }; try { subst a' }; simp * at * end,
by ext b; simp [this]

@[simp] lemma bind_bind (p : pmf α) (f : α → pmf β) (g : β → pmf γ) :
(p.bind f).bind g = p.bind (λ a, (f a).bind g) :=
@[simp] lemma bind_bind : (p.bind f).bind g = p.bind (λ a, (f a).bind g) :=
begin
ext1 b,
simp only [ennreal.coe_eq_coe.symm, coe_bind_apply, ennreal.tsum_mul_left.symm,
Expand All @@ -107,6 +126,39 @@ begin
simp [mul_assoc, mul_left_comm, mul_comm]
end

section measure

variable (s : set β)

lemma to_outer_measure_bind_apply :
(p.bind f).to_outer_measure s = ∑' (a : α), (p a : ℝ≥0∞) * (f a).to_outer_measure s :=
calc (p.bind f).to_outer_measure s
= ∑' (b : β), if b ∈ s then (↑(∑' (a : α), p a * f a b) : ℝ≥0∞) else 0
: by simp [to_outer_measure_apply, set.indicator_apply]
... = ∑' (b : β), ↑(∑' (a : α), p a * (if b ∈ s then f a b else 0))
: tsum_congr (λ b, by split_ifs; simp)
... = ∑' (b : β) (a : α), ↑(p a * (if b ∈ s then f a b else 0))
: tsum_congr (λ b, ennreal.coe_tsum $
nnreal.summable_of_le (by split_ifs; simp) (bind.summable p f b))
... = ∑' (a : α) (b : β), ↑(p a) * ↑(if b ∈ s then f a b else 0)
: ennreal.tsum_comm.trans (tsum_congr $ λ a, tsum_congr $ λ b, ennreal.coe_mul)
... = ∑' (a : α), ↑(p a) * ∑' (b : β), ↑(if b ∈ s then f a b else 0)
: tsum_congr (λ a, ennreal.tsum_mul_left)
... = ∑' (a : α), ↑(p a) * ∑' (b : β), if b ∈ s then ↑(f a b) else (0 : ℝ≥0∞)
: tsum_congr (λ a, congr_arg (λ x, ↑(p a) * x) $ tsum_congr (λ b, by split_ifs; refl))
... = ∑' (a : α), ↑(p a) * (f a).to_outer_measure s
: tsum_congr (λ a, by rw [to_outer_measure_apply, set.indicator])
dtumad marked this conversation as resolved.
Show resolved Hide resolved

/-- The measure of a set under `p.bind f` is the sum over `a : α`
of the probability of `a` under `p` times the measure of the set under `f a` -/
lemma to_measure_bind_apply [measurable_space β] (hs : measurable_set s) :
(p.bind f).to_measure s = ∑' (a : α), (p a : ℝ≥0∞) * (f a).to_measure s :=
(to_measure_apply_eq_to_outer_measure_apply (p.bind f) s hs).trans
((to_outer_measure_bind_apply p f s).trans (tsum_congr (λ a, congr_arg (λ x, p a * x)
(to_measure_apply_eq_to_outer_measure_apply (f a) s hs).symm)))

end measure

end bind

instance : monad pmf :=
Expand Down Expand Up @@ -230,6 +282,49 @@ begin
split_ifs with h1 h2 h2; ring,
end

section measure

variable (s : set β)

lemma to_outer_measure_bind_on_support_apply :
(p.bind_on_support f).to_outer_measure s =
∑' (a : α), (p a : ℝ≥0) * if h : p a = 0 then 0 else (f a h).to_outer_measure s :=
let g : α → β → ℝ≥0 := λ a b, if h : p a = 0 then 0 else f a h b in
calc (p.bind_on_support f).to_outer_measure s
= ∑' (b : β), if b ∈ s then ↑(∑' (a : α), p a * g a b) else 0
: by simp [to_outer_measure_apply, set.indicator_apply]
... = ∑' (b : β), ↑(∑' (a : α), p a * (if b ∈ s then g a b else 0))
: tsum_congr (λ b, by split_ifs; simp)
... = ∑' (b : β) (a : α), ↑(p a * (if b ∈ s then g a b else 0))
: tsum_congr (λ b, ennreal.coe_tsum $
nnreal.summable_of_le (by split_ifs; simp) (bind_on_support.summable p f b))
... = ∑' (a : α) (b : β), ↑(p a) * ↑(if b ∈ s then g a b else 0)
: ennreal.tsum_comm.trans (tsum_congr $ λ a, tsum_congr $ λ b, ennreal.coe_mul)
... = ∑' (a : α), ↑(p a) * ∑' (b : β), ↑(if b ∈ s then g a b else 0)
: tsum_congr (λ a, ennreal.tsum_mul_left)
... = ∑' (a : α), ↑(p a) * ∑' (b : β), if b ∈ s then ↑(g a b) else (0 : ℝ≥0∞)
: tsum_congr (λ a, congr_arg (λ x, ↑(p a) * x) $ tsum_congr (λ b, by split_ifs; refl))
... = ∑' (a : α), ↑(p a) * if h : p a = 0 then 0 else (f a h).to_outer_measure s
: tsum_congr (λ a, congr_arg (has_mul.mul ↑(p a)) begin
dtumad marked this conversation as resolved.
Show resolved Hide resolved
split_ifs with h h,
{ exact ennreal.tsum_eq_zero.mpr (λ x,
(by simp [g, h] : (0 : ℝ≥0∞) = ↑(g a x)) ▸ (if_t_t (x ∈ s) 0)) },
{ simp [to_outer_measure_apply, g, h, set.indicator_apply] }
end)

/-- The measure of a set under `p.bind_on_support f` is the sum over `a : α`
of the probability of `a` under `p` times the measure of the set under `f a _`.
The additional if statement is needed since `f` is only a partial function -/
lemma to_measure_bind_on_support_apply [measurable_space β] (hs : measurable_set s) :
(p.bind_on_support f).to_measure s =
∑' (a : α), (p a : ℝ≥0∞) * if h : p a = 0 then 0 else (f a h).to_measure s :=
(to_measure_apply_eq_to_outer_measure_apply (p.bind_on_support f) s hs).trans
((to_outer_measure_bind_on_support_apply f s).trans
(tsum_congr $ λ a, congr_arg (has_mul.mul ↑(p a)) (congr_arg (dite (p a = 0) (λ _, 0))
$ funext (λ h, symm $ to_measure_apply_eq_to_outer_measure_apply (f a h) s hs))))

end measure

end bind_on_support

end pmf