Skip to content
This repository has been archived by the owner on Jul 24, 2024. It is now read-only.

[Merged by Bors] - feat(measure_theory/probability_mass_function): Measures of sets under pmf monad operations #11613

Closed
wants to merge 5 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 102 additions & 7 deletions src/measure_theory/probability_mass_function/monad.lean
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,27 @@ lemma mem_support_pure_iff: a' ∈ (pure a).support ↔ a' = a := by simp

instance [inhabited α] : inhabited (pmf α) := ⟨pure default⟩

section measure

variable (s : set α)

lemma to_outer_measure_pure_apply : (pure a).to_outer_measure s = if a ∈ s then 1 else 0 :=
begin
refine (to_outer_measure_apply' (pure a) s).trans _,
split_ifs with ha ha,
{ refine ennreal.coe_eq_one.2 ((tsum_congr (λ b, _)).trans (tsum_ite_eq a 1)),
exact ite_eq_left_iff.2 (λ hb, symm (ite_eq_right_iff.2 (λ h, (hb $ h.symm ▸ ha).elim))) },
{ refine ennreal.coe_eq_zero.2 ((tsum_congr (λ b, _)).trans (tsum_zero)),
exact ite_eq_right_iff.2 (λ hb, ite_eq_right_iff.2 (λ h, (ha $ h ▸ hb).elim)) }
end

/-- The measure of a set under `pure a` is `1` for sets containing `a` and `0` otherwise -/
lemma to_measure_pure_apply [measurable_space α] (hs : measurable_set s) :
(pure a).to_measure s = if a ∈ s then 1 else 0 :=
(to_measure_apply_eq_to_outer_measure_apply (pure a) s hs).trans (to_outer_measure_pure_apply a s)

end measure

end pure

section bind
Expand All @@ -63,7 +84,7 @@ def bind (p : pmf α) (f : α → pmf β) : pmf β :=
(ennreal.coe_tsum p.summable_coe).symm]
end⟩

variables (p : pmf α) (f : α → pmf β)
variables (p : pmf α) (f : α → pmf β) (g : β → pmf γ)

@[simp] lemma bind_apply (b : β) : p.bind f b = ∑'a, p a * f a b := rfl

Expand All @@ -73,22 +94,20 @@ set.ext (λ b, by simp [mem_support_iff, tsum_eq_zero_iff (bind.summable p f b),
lemma mem_support_bind_iff (b : β) : b ∈ (p.bind f).support ↔ ∃ a ∈ p.support, b ∈ (f a).support :=
by simp

lemma coe_bind_apply (p : pmf α) (f : α → pmf β) (b : β) :
(p.bind f b : ℝ≥0∞) = ∑'a, p a * f a b :=
lemma coe_bind_apply (b : β) : (p.bind f b : ℝ≥0∞) = ∑'a, p a * f a b :=
eq.trans (ennreal.coe_tsum $ bind.summable p f b) $ by simp

@[simp] lemma pure_bind (a : α) (f : α → pmf β) : (pure a).bind f = f a :=
@[simp] lemma pure_bind (a : α) : (pure a).bind f = f a :=
have ∀ b a', ite (a' = a) 1 0 * f a' b = ite (a' = a) (f a b) 0, from
assume b a', by split_ifs; simp; subst h; simp,
by ext b; simp [this]

@[simp] lemma bind_pure (p : pmf α) : p.bind pure = p :=
@[simp] lemma bind_pure : p.bind pure = p :=
have ∀ a a', (p a * ite (a' = a) 1 0) = ite (a = a') (p a') 0, from
assume a a', begin split_ifs; try { subst a }; try { subst a' }; simp * at * end,
by ext b; simp [this]

@[simp] lemma bind_bind (p : pmf α) (f : α → pmf β) (g : β → pmf γ) :
(p.bind f).bind g = p.bind (λ a, (f a).bind g) :=
@[simp] lemma bind_bind : (p.bind f).bind g = p.bind (λ a, (f a).bind g) :=
begin
ext1 b,
simp only [ennreal.coe_eq_coe.symm, coe_bind_apply, ennreal.tsum_mul_left.symm,
Expand All @@ -107,6 +126,39 @@ begin
simp [mul_assoc, mul_left_comm, mul_comm]
end

section measure

variable (s : set β)

lemma to_outer_measure_bind_apply :
(p.bind f).to_outer_measure s = ∑' (a : α), (p a : ℝ≥0∞) * (f a).to_outer_measure s :=
calc (p.bind f).to_outer_measure s
= ∑' (b : β), if b ∈ s then (↑(∑' (a : α), p a * f a b) : ℝ≥0∞) else 0 :
by simp [to_outer_measure_apply, set.indicator_apply]
... = ∑' (b : β), ↑(∑' (a : α), p a * (if b ∈ s then f a b else 0)) :
tsum_congr (λ b, by split_ifs; simp)
... = ∑' (b : β) (a : α), ↑(p a * (if b ∈ s then f a b else 0)) :
tsum_congr (λ b, ennreal.coe_tsum $
nnreal.summable_of_le (by split_ifs; simp) (bind.summable p f b))
... = ∑' (a : α) (b : β), ↑(p a) * ↑(if b ∈ s then f a b else 0) :
ennreal.tsum_comm.trans (tsum_congr $ λ a, tsum_congr $ λ b, ennreal.coe_mul)
... = ∑' (a : α), ↑(p a) * ∑' (b : β), ↑(if b ∈ s then f a b else 0) :
tsum_congr (λ a, ennreal.tsum_mul_left)
... = ∑' (a : α), ↑(p a) * ∑' (b : β), if b ∈ s then ↑(f a b) else (0 : ℝ≥0∞) :
tsum_congr (λ a, congr_arg (λ x, ↑(p a) * x) $ tsum_congr (λ b, by split_ifs; refl))
... = ∑' (a : α), ↑(p a) * (f a).to_outer_measure s :
tsum_congr (λ a, by rw [to_outer_measure_apply, set.indicator])

/-- The measure of a set under `p.bind f` is the sum over `a : α`
of the probability of `a` under `p` times the measure of the set under `f a` -/
lemma to_measure_bind_apply [measurable_space β] (hs : measurable_set s) :
(p.bind f).to_measure s = ∑' (a : α), (p a : ℝ≥0∞) * (f a).to_measure s :=
(to_measure_apply_eq_to_outer_measure_apply (p.bind f) s hs).trans
((to_outer_measure_bind_apply p f s).trans (tsum_congr (λ a, congr_arg (λ x, p a * x)
(to_measure_apply_eq_to_outer_measure_apply (f a) s hs).symm)))

end measure

end bind

instance : monad pmf :=
Expand Down Expand Up @@ -230,6 +282,49 @@ begin
split_ifs with h1 h2 h2; ring,
end

section measure

variable (s : set β)

lemma to_outer_measure_bind_on_support_apply :
(p.bind_on_support f).to_outer_measure s =
∑' (a : α), (p a : ℝ≥0) * if h : p a = 0 then 0 else (f a h).to_outer_measure s :=
let g : α → β → ℝ≥0 := λ a b, if h : p a = 0 then 0 else f a h b in
calc (p.bind_on_support f).to_outer_measure s
= ∑' (b : β), if b ∈ s then ↑(∑' (a : α), p a * g a b) else 0 :
by simp [to_outer_measure_apply, set.indicator_apply]
... = ∑' (b : β), ↑(∑' (a : α), p a * (if b ∈ s then g a b else 0)) :
tsum_congr (λ b, by split_ifs; simp)
... = ∑' (b : β) (a : α), ↑(p a * (if b ∈ s then g a b else 0)) :
tsum_congr (λ b, ennreal.coe_tsum $
nnreal.summable_of_le (by split_ifs; simp) (bind_on_support.summable p f b))
... = ∑' (a : α) (b : β), ↑(p a) * ↑(if b ∈ s then g a b else 0) :
ennreal.tsum_comm.trans (tsum_congr $ λ a, tsum_congr $ λ b, ennreal.coe_mul)
... = ∑' (a : α), ↑(p a) * ∑' (b : β), ↑(if b ∈ s then g a b else 0) :
tsum_congr (λ a, ennreal.tsum_mul_left)
... = ∑' (a : α), ↑(p a) * ∑' (b : β), if b ∈ s then ↑(g a b) else (0 : ℝ≥0∞) :
tsum_congr (λ a, congr_arg (λ x, ↑(p a) * x) $ tsum_congr (λ b, by split_ifs; refl))
... = ∑' (a : α), ↑(p a) * if h : p a = 0 then 0 else (f a h).to_outer_measure s :
tsum_congr (λ a, congr_arg (has_mul.mul ↑(p a)) begin
split_ifs with h h,
{ exact ennreal.tsum_eq_zero.mpr (λ x,
(by simp [g, h] : (0 : ℝ≥0∞) = ↑(g a x)) ▸ (if_t_t (x ∈ s) 0)) },
{ simp [to_outer_measure_apply, g, h, set.indicator_apply] }
end)

/-- The measure of a set under `p.bind_on_support f` is the sum over `a : α`
of the probability of `a` under `p` times the measure of the set under `f a _`.
The additional if statement is needed since `f` is only a partial function -/
lemma to_measure_bind_on_support_apply [measurable_space β] (hs : measurable_set s) :
(p.bind_on_support f).to_measure s =
∑' (a : α), (p a : ℝ≥0∞) * if h : p a = 0 then 0 else (f a h).to_measure s :=
(to_measure_apply_eq_to_outer_measure_apply (p.bind_on_support f) s hs).trans
((to_outer_measure_bind_on_support_apply f s).trans
(tsum_congr $ λ a, congr_arg (has_mul.mul ↑(p a)) (congr_arg (dite (p a = 0) (λ _, 0))
$ funext (λ h, symm $ to_measure_apply_eq_to_outer_measure_apply (f a h) s hs))))

end measure

end bind_on_support

end pmf