Skip to content

Commit

Permalink
feat(measure_theory/probability_mass_function): Measures of sets unde…
Browse files Browse the repository at this point in the history
…r `pmf` monad operations (#11613)

This PR adds explicit formulas for the measures of sets under `pmf.pure`, `pmf.bind`, and `pmf.bind_on_support`.
  • Loading branch information
dtumad committed Feb 2, 2022
1 parent a687cbf commit c1d2860
Showing 1 changed file with 102 additions and 7 deletions.
109 changes: 102 additions & 7 deletions src/measure_theory/probability_mass_function/monad.lean
Expand Up @@ -40,6 +40,27 @@ lemma mem_support_pure_iff: a' ∈ (pure a).support ↔ a' = a := by simp

instance [inhabited α] : inhabited (pmf α) := ⟨pure default⟩

section measure

variable (s : set α)

lemma to_outer_measure_pure_apply : (pure a).to_outer_measure s = if a ∈ s then 1 else 0 :=
begin
refine (to_outer_measure_apply' (pure a) s).trans _,
split_ifs with ha ha,
{ refine ennreal.coe_eq_one.2 ((tsum_congr (λ b, _)).trans (tsum_ite_eq a 1)),
exact ite_eq_left_iff.2 (λ hb, symm (ite_eq_right_iff.2 (λ h, (hb $ h.symm ▸ ha).elim))) },
{ refine ennreal.coe_eq_zero.2 ((tsum_congr (λ b, _)).trans (tsum_zero)),
exact ite_eq_right_iff.2 (λ hb, ite_eq_right_iff.2 (λ h, (ha $ h ▸ hb).elim)) }
end

/-- The measure of a set under `pure a` is `1` for sets containing `a` and `0` otherwise -/
lemma to_measure_pure_apply [measurable_space α] (hs : measurable_set s) :
(pure a).to_measure s = if a ∈ s then 1 else 0 :=
(to_measure_apply_eq_to_outer_measure_apply (pure a) s hs).trans (to_outer_measure_pure_apply a s)

end measure

end pure

section bind
Expand All @@ -63,7 +84,7 @@ def bind (p : pmf α) (f : α → pmf β) : pmf β :=
(ennreal.coe_tsum p.summable_coe).symm]
end

variables (p : pmf α) (f : α → pmf β)
variables (p : pmf α) (f : α → pmf β) (g : β → pmf γ)

@[simp] lemma bind_apply (b : β) : p.bind f b = ∑'a, p a * f a b := rfl

Expand All @@ -73,22 +94,20 @@ set.ext (λ b, by simp [mem_support_iff, tsum_eq_zero_iff (bind.summable p f b),
lemma mem_support_bind_iff (b : β) : b ∈ (p.bind f).support ↔ ∃ a ∈ p.support, b ∈ (f a).support :=
by simp

lemma coe_bind_apply (p : pmf α) (f : α → pmf β) (b : β) :
(p.bind f b : ℝ≥0∞) = ∑'a, p a * f a b :=
lemma coe_bind_apply (b : β) : (p.bind f b : ℝ≥0∞) = ∑'a, p a * f a b :=
eq.trans (ennreal.coe_tsum $ bind.summable p f b) $ by simp

@[simp] lemma pure_bind (a : α) (f : α → pmf β) : (pure a).bind f = f a :=
@[simp] lemma pure_bind (a : α) : (pure a).bind f = f a :=
have ∀ b a', ite (a' = a) 1 0 * f a' b = ite (a' = a) (f a b) 0, from
assume b a', by split_ifs; simp; subst h; simp,
by ext b; simp [this]

@[simp] lemma bind_pure (p : pmf α) : p.bind pure = p :=
@[simp] lemma bind_pure : p.bind pure = p :=
have ∀ a a', (p a * ite (a' = a) 1 0) = ite (a = a') (p a') 0, from
assume a a', begin split_ifs; try { subst a }; try { subst a' }; simp * at * end,
by ext b; simp [this]

@[simp] lemma bind_bind (p : pmf α) (f : α → pmf β) (g : β → pmf γ) :
(p.bind f).bind g = p.bind (λ a, (f a).bind g) :=
@[simp] lemma bind_bind : (p.bind f).bind g = p.bind (λ a, (f a).bind g) :=
begin
ext1 b,
simp only [ennreal.coe_eq_coe.symm, coe_bind_apply, ennreal.tsum_mul_left.symm,
Expand All @@ -107,6 +126,39 @@ begin
simp [mul_assoc, mul_left_comm, mul_comm]
end

section measure

variable (s : set β)

lemma to_outer_measure_bind_apply :
(p.bind f).to_outer_measure s = ∑' (a : α), (p a : ℝ≥0∞) * (f a).to_outer_measure s :=
calc (p.bind f).to_outer_measure s
= ∑' (b : β), if b ∈ s then (↑(∑' (a : α), p a * f a b) : ℝ≥0∞) else 0 :
by simp [to_outer_measure_apply, set.indicator_apply]
... = ∑' (b : β), ↑(∑' (a : α), p a * (if b ∈ s then f a b else 0)) :
tsum_congr (λ b, by split_ifs; simp)
... = ∑' (b : β) (a : α), ↑(p a * (if b ∈ s then f a b else 0)) :
tsum_congr (λ b, ennreal.coe_tsum $
nnreal.summable_of_le (by split_ifs; simp) (bind.summable p f b))
... = ∑' (a : α) (b : β), ↑(p a) * ↑(if b ∈ s then f a b else 0) :
ennreal.tsum_comm.trans (tsum_congr $ λ a, tsum_congr $ λ b, ennreal.coe_mul)
... = ∑' (a : α), ↑(p a) * ∑' (b : β), ↑(if b ∈ s then f a b else 0) :
tsum_congr (λ a, ennreal.tsum_mul_left)
... = ∑' (a : α), ↑(p a) * ∑' (b : β), if b ∈ s then ↑(f a b) else (0 : ℝ≥0∞) :
tsum_congr (λ a, congr_arg (λ x, ↑(p a) * x) $ tsum_congr (λ b, by split_ifs; refl))
... = ∑' (a : α), ↑(p a) * (f a).to_outer_measure s :
tsum_congr (λ a, by rw [to_outer_measure_apply, set.indicator])

/-- The measure of a set under `p.bind f` is the sum over `a : α`
of the probability of `a` under `p` times the measure of the set under `f a` -/
lemma to_measure_bind_apply [measurable_space β] (hs : measurable_set s) :
(p.bind f).to_measure s = ∑' (a : α), (p a : ℝ≥0∞) * (f a).to_measure s :=
(to_measure_apply_eq_to_outer_measure_apply (p.bind f) s hs).trans
((to_outer_measure_bind_apply p f s).trans (tsum_congr (λ a, congr_arg (λ x, p a * x)
(to_measure_apply_eq_to_outer_measure_apply (f a) s hs).symm)))

end measure

end bind

instance : monad pmf :=
Expand Down Expand Up @@ -230,6 +282,49 @@ begin
split_ifs with h1 h2 h2; ring,
end

section measure

variable (s : set β)

lemma to_outer_measure_bind_on_support_apply :
(p.bind_on_support f).to_outer_measure s =
∑' (a : α), (p a : ℝ≥0) * if h : p a = 0 then 0 else (f a h).to_outer_measure s :=
let g : α → β → ℝ≥0 := λ a b, if h : p a = 0 then 0 else f a h b in
calc (p.bind_on_support f).to_outer_measure s
= ∑' (b : β), if b ∈ s then ↑(∑' (a : α), p a * g a b) else 0 :
by simp [to_outer_measure_apply, set.indicator_apply]
... = ∑' (b : β), ↑(∑' (a : α), p a * (if b ∈ s then g a b else 0)) :
tsum_congr (λ b, by split_ifs; simp)
... = ∑' (b : β) (a : α), ↑(p a * (if b ∈ s then g a b else 0)) :
tsum_congr (λ b, ennreal.coe_tsum $
nnreal.summable_of_le (by split_ifs; simp) (bind_on_support.summable p f b))
... = ∑' (a : α) (b : β), ↑(p a) * ↑(if b ∈ s then g a b else 0) :
ennreal.tsum_comm.trans (tsum_congr $ λ a, tsum_congr $ λ b, ennreal.coe_mul)
... = ∑' (a : α), ↑(p a) * ∑' (b : β), ↑(if b ∈ s then g a b else 0) :
tsum_congr (λ a, ennreal.tsum_mul_left)
... = ∑' (a : α), ↑(p a) * ∑' (b : β), if b ∈ s then ↑(g a b) else (0 : ℝ≥0∞) :
tsum_congr (λ a, congr_arg (λ x, ↑(p a) * x) $ tsum_congr (λ b, by split_ifs; refl))
... = ∑' (a : α), ↑(p a) * if h : p a = 0 then 0 else (f a h).to_outer_measure s :
tsum_congr (λ a, congr_arg (has_mul.mul ↑(p a)) begin
split_ifs with h h,
{ exact ennreal.tsum_eq_zero.mpr (λ x,
(by simp [g, h] : (0 : ℝ≥0∞) = ↑(g a x)) ▸ (if_t_t (x ∈ s) 0)) },
{ simp [to_outer_measure_apply, g, h, set.indicator_apply] }
end)

/-- The measure of a set under `p.bind_on_support f` is the sum over `a : α`
of the probability of `a` under `p` times the measure of the set under `f a _`.
The additional if statement is needed since `f` is only a partial function -/
lemma to_measure_bind_on_support_apply [measurable_space β] (hs : measurable_set s) :
(p.bind_on_support f).to_measure s =
∑' (a : α), (p a : ℝ≥0∞) * if h : p a = 0 then 0 else (f a h).to_measure s :=
(to_measure_apply_eq_to_outer_measure_apply (p.bind_on_support f) s hs).trans
((to_outer_measure_bind_on_support_apply f s).trans
(tsum_congr $ λ a, congr_arg (has_mul.mul ↑(p a)) (congr_arg (dite (p a = 0) (λ _, 0))
$ funext (λ h, symm $ to_measure_apply_eq_to_outer_measure_apply (f a h) s hs))))

end measure

end bind_on_support

end pmf

0 comments on commit c1d2860

Please sign in to comment.