Skip to content

Commit

Permalink
feat(measure_theory/probability_mass_function): Define uniform pmf on…
Browse files Browse the repository at this point in the history
… an inhabited fintype (#9920)

This PR defines uniform probability mass functions on nonempty finsets and inhabited fintypes.
  • Loading branch information
dtumad committed Nov 2, 2021
1 parent f6894c4 commit fc12ca8
Showing 1 changed file with 103 additions and 5 deletions.
108 changes: 103 additions & 5 deletions src/measure_theory/probability_mass_function.lean
Expand Up @@ -50,6 +50,11 @@ def support (p : pmf α) : set α := {a | p.1 a ≠ 0}
@[simp] lemma mem_support_iff (p : pmf α) (a : α) :
a ∈ p.support ↔ p a ≠ 0 := iff.rfl

lemma coe_le_one (p : pmf α) (a : α) : p a ≤ 1 :=
has_sum_le (by intro b; split_ifs; simp [h]; exact le_refl _) (has_sum_ite_eq a (p a)) p.2

section pure

/-- The pure `pmf` is the `pmf` where all the mass lies in one point.
The value of `pure a` is `1` at `a` and `0` elsewhere. -/
def pure (a : α) : pmf α := ⟨λ a', if a' = a then 1 else 0, has_sum_ite_eq _ _⟩
Expand All @@ -61,8 +66,9 @@ by simp

instance [inhabited α] : inhabited (pmf α) := ⟨pure (default α)⟩

lemma coe_le_one (p : pmf α) (a : α) : p a ≤ 1 :=
has_sum_le (by intro b; split_ifs; simp [h]; exact le_refl _) (has_sum_ite_eq a (p a)) p.2
end pure

section bind

protected lemma bind.summable (p : pmf α) (f : α → pmf β) (b : β) :
summable (λ a : α, p a * f a b) :=
Expand Down Expand Up @@ -120,6 +126,10 @@ begin
simp [mul_assoc, mul_left_comm, mul_comm]
end

end bind

section bind_on_support

protected lemma bind_on_support.summable (p : pmf α) (f : ∀ a ∈ p.support, pmf β) (b : β) :
summable (λ a : α, p a * if h : p a = 0 then 0 else f a h b) :=
begin
Expand Down Expand Up @@ -228,6 +238,10 @@ begin
split_ifs with h1 h2 h2; ring,
end

end bind_on_support

section map

/-- The functorial action of a function on a `pmf`. -/
def map (f : α → β) (p : pmf α) : pmf β := bind p (pure ∘ f)

Expand All @@ -241,9 +255,37 @@ by simp [map]
lemma pure_map (a : α) (f : α → β) : (pure a).map f = pure (f a) :=
by simp [map]

end map

/-- The monadic sequencing operation for `pmf`. -/
def seq (f : pmf (α → β)) (p : pmf α) : pmf β := f.bind (λ m, p.bind $ λ a, pure (m a))

section of_finite

/-- Given a finset `s` and a function `f : α → ℝ≥0` with sum `1` on `s`,
such that `f x = 0` for `x ∉ s`, we get a `pmf` -/
def of_finset (f : α → ℝ≥0) (s : finset α) (h : ∑ x in s, f x = 1)
(h' : ∀ x ∉ s, f x = 0) : pmf α :=
⟨f, h ▸ has_sum_sum_of_ne_finset_zero h'⟩

@[simp]
lemma of_finset_apply {f : α → ℝ≥0} {s : finset α} (h : ∑ x in s, f x = 1)
(h' : ∀ x ∉ s, f x = 0) (a : α) : of_finset f s h h' a = f a :=
rfl

lemma of_finset_apply_of_not_mem {f : α → ℝ≥0} {s : finset α} (h : ∑ x in s, f x = 1)
(h' : ∀ x ∉ s, f x = 0) {a : α} (ha : a ∉ s) : of_finset f s h h' a = 0 :=
h' a ha

/-- Given a finite type `α` and a function `f : α → ℝ≥0` with sum 1, we get a `pmf`. -/
def of_fintype [fintype α] (f : α → ℝ≥0) (h : ∑ x, f x = 1) : pmf α :=
of_finset f finset.univ h (λ x hx, absurd (finset.mem_univ x) hx)

@[simp]
lemma of_fintype_apply [fintype α] {f : α → ℝ≥0} (h : ∑ x, f x = 1)
(a : α) : of_fintype f h a = f a :=
rfl

/-- Given a non-empty multiset `s` we construct the `pmf` which sends `a` to the fraction of
elements in `s` that are `a`. -/
def of_multiset (s : multiset α) (hs : s ≠ 0) : pmf α :=
Expand All @@ -258,9 +300,52 @@ def of_multiset (s : multiset α) (hs : s ≠ 0) : pmf α :=
simp {contextual := tt},
end

/-- Given a finite type `α` and a function `f : α → ℝ≥0` with sum 1, we get a `pmf`. -/
def of_fintype [fintype α] (f : α → ℝ≥0) (h : ∑ x, f x = 1) : pmf α :=
⟨f, h ▸ has_sum_sum_of_ne_finset_zero (by simp)⟩
@[simp]
lemma of_multiset_apply {s : multiset α} (hs : s ≠ 0) (a : α) :
of_multiset s hs a = s.count a / s.card :=
rfl

lemma of_multiset_apply_of_not_mem {s : multiset α} (hs : s ≠ 0)
{a : α} (ha : a ∉ s) : of_multiset s hs a = 0 :=
div_eq_zero_iff.2 (or.inl $ nat.cast_eq_zero.2 $ multiset.count_eq_zero_of_not_mem ha)

end of_finite

section uniform

/-- Uniform distribution taking the same non-zero probability on the nonempty finset `s` -/
def uniform_of_finset (s : finset α) (hs : s.nonempty) : pmf α :=
of_finset (λ a, if a ∈ s then (s.card : ℝ≥0)⁻¹ else 0) s (Exists.rec_on hs (λ x hx,
calc ∑ (a : α) in s, ite (a ∈ s) (s.card : ℝ≥0)⁻¹ 0
= ∑ (a : α) in s, (s.card : ℝ≥0)⁻¹ : finset.sum_congr rfl (λ x hx, by simp [hx])
... = s.card • (s.card : ℝ≥0)⁻¹ : finset.sum_const _
... = (s.card : ℝ≥0) * (s.card : ℝ≥0)⁻¹ : by rw nsmul_eq_mul
... = 1 : div_self (nat.cast_ne_zero.2 $ finset.card_ne_zero_of_mem hx)
)) (λ x hx, by simp only [hx, if_false])

@[simp]
lemma uniform_of_finset_apply {s : finset α} (hs : s.nonempty) (a : α) :
uniform_of_finset s hs a = if a ∈ s then (s.card : ℝ≥0)⁻¹ else 0 :=
rfl

lemma uniform_of_finset_apply_of_mem {s : finset α} (hs : s.nonempty) {a : α} (ha : a ∈ s) :
uniform_of_finset s hs a = (s.card)⁻¹ :=
by simp [ha]

lemma uniform_of_finset_apply_of_not_mem {s : finset α} (hs : s.nonempty) {a : α} (ha : a ∉ s) :
uniform_of_finset s hs a = 0 :=
by simp [ha]

/-- The uniform pmf taking the same uniform value on all of the fintype `α` -/
def uniform_of_fintype (α : Type*) [fintype α] [nonempty α] : pmf α :=
uniform_of_finset (finset.univ) (finset.univ_nonempty)

@[simp]
lemma uniform_of_fintype_apply [fintype α] [nonempty α] (a : α) :
uniform_of_fintype α a = (fintype.card α)⁻¹ :=
by simpa only [uniform_of_fintype, finset.mem_univ, if_true, uniform_of_finset_apply]

end uniform

/-- Given a `f` with non-zero sum, we get a `pmf` by normalizing `f` by its `tsum` -/
def normalize (f : α → ℝ≥0) (hf0 : tsum f ≠ 0) : pmf α :=
Expand All @@ -271,6 +356,8 @@ def normalize (f : α → ℝ≥0) (hf0 : tsum f ≠ 0) : pmf α :=
lemma normalize_apply {f : α → ℝ≥0} (hf0 : tsum f ≠ 0) (a : α) :
(normalize f hf0) a = f a * (∑' x, f x)⁻¹ := rfl

section filter

/-- Create new `pmf` by filtering on a set with non-zero measure and normalizing -/
def filter (p : pmf α) (s : set α) (h : ∃ a ∈ s, p a ≠ 0) : pmf α :=
pmf.normalize (s.indicator p) $ nnreal.tsum_indicator_ne_zero p.2.summable h
Expand Down Expand Up @@ -299,8 +386,19 @@ lemma filter_apply_ne_zero_iff (p : pmf α) {s : set α} (h : ∃ a ∈ s, p a
(p.filter s h) a ≠ 0 ↔ a ∈ (p.support ∩ s) :=
by rw [← not_iff, filter_apply_eq_zero_iff, not_iff, not_not]

end filter

section bernoulli

/-- A `pmf` which assigns probability `p` to `tt` and `1 - p` to `ff`. -/
def bernoulli (p : ℝ≥0) (h : p ≤ 1) : pmf bool :=
of_fintype (λ b, cond b p (1 - p)) (nnreal.eq $ by simp [h])

@[simp]
lemma bernuolli_apply {p : ℝ≥0} (h : p ≤ 1) (b : bool) :
bernoulli p h b = cond b p (1 - p) :=
rfl

end bernoulli

end pmf

0 comments on commit fc12ca8

Please sign in to comment.