Skip to content

Commit e4f93e7

Browse files
committed
feat(Probability): characterize conditional independence with conditional distributions (#30024)
Two random variables `f, g` are conditionally independent given a third `k` iff the conditional distribution of `f` given `k` and `g` is equal to the conditional distribution of `f` given `k`. ```lean CondIndepFun (mγ.comap k) hk.comap_le g f μ ↔ condDistrib f (fun ω ↦ (k ω, g ω)) μ =ᵐ[μ.map (fun ω ↦ (k ω, g ω))] (condDistrib f k μ).prodMkRight _ ``` From the LeanBandits project. Co-authored-by: Remy Degenne <remydegenne@gmail.com>
1 parent f5bc879 commit e4f93e7

File tree

3 files changed

+157
-0
lines changed

3 files changed

+157
-0
lines changed

Mathlib/Probability/Independence/Conditional.lean

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,103 @@ lemma condIndepFun_iff_map_prod_eq_prod_comp_trim
790790
rfl
791791
· rw [Measure.compProd_eq_comp_prod]
792792

793+
/-- Two random variables `f, g` are conditionally independent given a third `k` iff the
794+
joint distribution of `k, f, g` factors into a product of their conditional distributions
795+
given `k`. -/
796+
theorem condIndepFun_iff_map_prod_eq_prod_condDistrib_prod_condDistrib
797+
{γ : Type*} {mγ : MeasurableSpace γ} {mβ : MeasurableSpace β} {mβ' : MeasurableSpace β'}
798+
[StandardBorelSpace β] [Nonempty β] [StandardBorelSpace β'] [Nonempty β']
799+
(hf : Measurable f) (hg : Measurable g) {k : Ω → γ} (hk : Measurable k) :
800+
CondIndepFun _ hk.comap_le f g μ ↔
801+
μ.map (fun ω ↦ (k ω, f ω, g ω)) =
802+
(Kernel.id ×ₖ (condDistrib f k μ ×ₖ condDistrib g k μ)) ∘ₘ μ.map k := by
803+
rw [condIndepFun_iff_map_prod_eq_prod_comp_trim hf hg]
804+
simp_rw [Measure.ext_prod₃_iff]
805+
have hk_meas {s : Set γ} (hs : MeasurableSet s) : MeasurableSet[mγ.comap k] (k ⁻¹' s) :=
806+
⟨s, hs, rfl⟩
807+
have h_left {s : Set γ} {t : Set β} {u : Set β'} (hs : MeasurableSet s) (ht : MeasurableSet t)
808+
(hu : MeasurableSet u) :
809+
(μ.map (fun ω ↦ (k ω, f ω, g ω))) (s ×ˢ t ×ˢ u) =
810+
(@Measure.map _ _ _ ((mγ.comap k).prod inferInstance)
811+
(fun ω ↦ (ω, f ω, g ω)) μ) ((k ⁻¹' s) ×ˢ t ×ˢ u) := by
812+
rw [Measure.map_apply (by fun_prop) (hs.prod (ht.prod hu)),
813+
Measure.map_apply _ ((hk_meas hs).prod (ht.prod hu))]
814+
· simp [Set.mk_preimage_prod]
815+
· exact (measurable_id.mono le_rfl hk.comap_le).prodMk (by fun_prop)
816+
have h_right {s : Set γ} {t : Set β} {u : Set β'} (hs : MeasurableSet s) (ht : MeasurableSet t)
817+
(hu : MeasurableSet u) :
818+
((Kernel.id ×ₖ (condDistrib f k μ ×ₖ condDistrib g k μ)) ∘ₘ μ.map k) (s ×ˢ t ×ˢ u) =
819+
((Kernel.id ×ₖ
820+
((condExpKernel μ (mγ.comap k)).map f ×ₖ (condExpKernel μ (mγ.comap k)).map g)) ∘ₘ
821+
μ.trim hk.comap_le) ((k ⁻¹' s) ×ˢ t ×ˢ u) := by
822+
rw [Measure.bind_apply ((hk_meas hs).prod (ht.prod hu)) (by fun_prop),
823+
Measure.bind_apply (hs.prod (ht.prod hu)) (by fun_prop), lintegral_map ?_ (by fun_prop),
824+
lintegral_trim]
825+
rotate_left
826+
· exact Kernel.measurable_coe _ ((hk_meas hs).prod (ht.prod hu))
827+
· exact Kernel.measurable_coe _ (hs.prod (ht.prod hu))
828+
refine lintegral_congr_ae ?_
829+
filter_upwards [condDistrib_apply_ae_eq_condExpKernel_map hf hk ht,
830+
condDistrib_apply_ae_eq_condExpKernel_map hg hk hu] with a haX haT
831+
simp only [Kernel.prod_apply_prod, Kernel.id_apply, Measure.dirac_apply' _ hs]
832+
rw [@Measure.dirac_apply' _ (mγ.comap k) _ _ (hk_meas hs)]
833+
congr
834+
refine ⟨fun h s t u hs ht hu ↦ ?_, fun h ↦ ?_⟩
835+
· convert h (hk_meas hs) ht hu
836+
· exact h_left hs ht hu
837+
· exact h_right hs ht hu
838+
· rintro - t u ⟨s, hs, rfl⟩ ht hu
839+
convert h hs ht hu
840+
· exact (h_left hs ht hu).symm
841+
· exact (h_right hs ht hu).symm
842+
843+
/-- Two random variables `f, g` are conditionally independent given a third `k` iff the
844+
conditional distribution of `f` given `k` and `g` is equal to the conditional distribution of `f`
845+
given `k`. -/
846+
theorem condIndepFun_iff_condDistrib_prod_ae_eq_prodMkLeft
847+
{γ : Type*} {mγ : MeasurableSpace γ} {mβ : MeasurableSpace β} {mβ' : MeasurableSpace β'}
848+
[StandardBorelSpace β] [Nonempty β] [StandardBorelSpace β'] [Nonempty β']
849+
(hf : Measurable f) (hg : Measurable g) {k : Ω → γ} (hk : Measurable k) :
850+
CondIndepFun (mγ.comap k) hk.comap_le g f μ ↔
851+
condDistrib f (fun ω ↦ (k ω, g ω)) μ =ᵐ[μ.map (fun ω ↦ (k ω, g ω))]
852+
(condDistrib f k μ).prodMkRight _ := by
853+
rw [condDistrib_ae_eq_iff_measure_eq_compProd (μ := μ) _ hf.aemeasurable,
854+
condIndepFun_iff_map_prod_eq_prod_condDistrib_prod_condDistrib hg hf hk,
855+
Measure.compProd_eq_comp_prod]
856+
let e : γ × β' × β ≃ᵐ (γ × β') × β := MeasurableEquiv.prodAssoc.symm
857+
have h_eq : ((Kernel.id ×ₖ condDistrib g k μ) ×ₖ condDistrib f k μ) ∘ₘ μ.map k =
858+
(Kernel.id ×ₖ (condDistrib f k μ).prodMkRight _) ∘ₘ μ.map (fun a ↦ (k a, g a)) := by
859+
calc ((Kernel.id ×ₖ condDistrib g k μ) ×ₖ condDistrib f k μ) ∘ₘ μ.map k
860+
_ = (Kernel.id ×ₖ (condDistrib f k μ).prodMkRight _) ∘ₘ (μ.map k ⊗ₘ condDistrib g k μ) := by
861+
rw [Measure.compProd_eq_comp_prod, Measure.comp_assoc]
862+
congr 2
863+
have h := Kernel.prod_prodMkRight_comp_deterministic_prod (condDistrib g k μ)
864+
(condDistrib f k μ) Kernel.id measurable_id
865+
rw [← Kernel.id] at h
866+
simpa using h.symm
867+
_ = (Kernel.id ×ₖ (condDistrib f k μ).prodMkRight _) ∘ₘ μ.map (fun a ↦ (k a, g a)) := by
868+
rw [compProd_map_condDistrib hg.aemeasurable]
869+
rw [← h_eq]
870+
have h1 : μ.map (fun x ↦ ((k x, g x), f x)) = (μ.map (fun a ↦ (k a , g a, f a))).map e := by
871+
rw [Measure.map_map (by fun_prop) (by fun_prop)]
872+
rfl
873+
have h1_symm : μ.map (fun a ↦ (k a , g a, f a)) =
874+
(μ.map (fun x ↦ ((k x, g x), f x))).map e.symm := by
875+
rw [h1, Measure.map_map (by fun_prop) (by fun_prop), MeasurableEquiv.symm_comp_self,
876+
Measure.map_id]
877+
have h2 : ((Kernel.id ×ₖ condDistrib g k μ) ×ₖ condDistrib f k μ) ∘ₘ μ.map k =
878+
((Kernel.id ×ₖ (condDistrib g k μ ×ₖ condDistrib f k μ)) ∘ₘ μ.map k).map e := by
879+
rw [← Measure.deterministic_comp_eq_map e.measurable, Measure.comp_assoc]
880+
congr 2
881+
unfold e
882+
rw [Kernel.deterministic_comp_eq_map, Kernel.prodAssoc_symm_prod]
883+
have h2_symm : (Kernel.id ×ₖ (condDistrib g k μ ×ₖ condDistrib f k μ)) ∘ₘ μ.map k =
884+
(((Kernel.id ×ₖ condDistrib g k μ) ×ₖ condDistrib f k μ) ∘ₘ μ.map k).map e.symm := by
885+
rw [h2, Measure.map_map (by fun_prop) (by fun_prop), MeasurableEquiv.symm_comp_self,
886+
Measure.map_id]
887+
rw [h1, h2]
888+
exact ⟨fun h ↦ by rw [h], fun h ↦ by rw [h1_symm, h1, h2_symm, h2, h]⟩
889+
793890
section iCondIndepFun
794891
variable {β : ι → Type*} {m : ∀ i, MeasurableSpace (β i)} {f : ∀ i, Ω → β i}
795892

Mathlib/Probability/Kernel/Composition/Lemmas.lean

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,56 @@ variable {α β γ δ : Type*} {mα : MeasurableSpace α} {mβ : MeasurableSpace
2121
{mγ : MeasurableSpace γ} {mδ : MeasurableSpace δ}
2222
{μ : Measure α} {ν : Measure β} {κ : Kernel α β}
2323

24+
namespace ProbabilityTheory.Kernel
25+
26+
/-- The composition of two product kernels `(ξ ×ₖ η') ∘ₖ (κ ×ₖ ζ)` is the product of the
27+
compositions `(ξ ∘ₖ (κ ×ₖ ζ)) ×ₖ (η' ∘ₖ (κ ×ₖ ζ))`, if `ζ` is deterministic (of the form
28+
`.deterministic f hf`) and `η'` does not depend on the output of `κ`.
29+
That is, `η'` has the form `η.prodMkLeft β` for a kernel `η`.
30+
31+
If `κ` was deterministic, this would be true even if `η.prodMkLeft β` was a more general
32+
kernel since `κ ×ₖ Kernel.deterministic f hf` would be deterministic and commute with copying.
33+
Here `κ` is not deterministic, but it is discarded in one branch of the copy. -/
34+
lemma prod_prodMkLeft_comp_prod_deterministic {β' ε : Type*}
35+
{mβ' : MeasurableSpace β'} {mε : MeasurableSpace ε}
36+
(κ : Kernel γ β) [IsSFiniteKernel κ] (η : Kernel ε β') [IsSFiniteKernel η]
37+
(ξ : Kernel (β × ε) δ) [IsSFiniteKernel ξ] {f : γ → ε} (hf : Measurable f) :
38+
(ξ ×ₖ η.prodMkLeft β) ∘ₖ (κ ×ₖ deterministic f hf)
39+
= (ξ ∘ₖ (κ ×ₖ deterministic f hf)) ×ₖ (η ∘ₖ deterministic f hf) := by
40+
ext ω s hs
41+
rw [prod_apply' _ _ _ hs, comp_apply' _ _ _ hs, lintegral_prod_deterministic,
42+
lintegral_comp, lintegral_prod_deterministic]
43+
· congr with b
44+
rw [prod_apply' _ _ _ hs, prodMkLeft_apply, comp_deterministic_eq_comap, comap_apply]
45+
· exact (measurable_measure_prodMk_left hs).lintegral_kernel
46+
· exact measurable_measure_prodMk_left hs
47+
· exact Kernel.measurable_coe _ hs
48+
49+
/-- The composition of two product kernels `(ξ ×ₖ η') ∘ₖ (ζ ×ₖ κ)` is the product of the
50+
compositions, `(ξ ∘ₖ (ζ ×ₖ κ)) ×ₖ (η' ∘ₖ (ζ ×ₖ κ))`, if `ζ` is deterministic (of the form
51+
`.deterministic f hf`) and `η'` does not depend on the output of `κ`.
52+
That is, `η'` has the form `η.prodMkRight β` for a kernel `η`.
53+
54+
If `κ` was deterministic, this would be true even if `η.prodMkRight β` was a more general
55+
kernel since `Kernel.deterministic f hf ×ₖ κ` would be deterministic and commute with copying.
56+
Here `κ` is not deterministic, but it is discarded in one branch of the copy. -/
57+
lemma prod_prodMkRight_comp_deterministic_prod {β' ε : Type*}
58+
{mβ' : MeasurableSpace β'} {mε : MeasurableSpace ε}
59+
(κ : Kernel γ β) [IsSFiniteKernel κ] (η : Kernel ε β') [IsSFiniteKernel η]
60+
(ξ : Kernel (ε × β) δ) [IsSFiniteKernel ξ] {f : γ → ε} (hf : Measurable f) :
61+
(ξ ×ₖ η.prodMkRight β) ∘ₖ (deterministic f hf ×ₖ κ)
62+
= (ξ ∘ₖ (deterministic f hf ×ₖ κ)) ×ₖ (η ∘ₖ deterministic f hf) := by
63+
ext ω s hs
64+
rw [prod_apply' _ _ _ hs, comp_apply' _ _ _ hs, lintegral_deterministic_prod,
65+
lintegral_comp, lintegral_deterministic_prod]
66+
· congr with b
67+
rw [prod_apply' _ _ _ hs, prodMkRight_apply, comp_deterministic_eq_comap, comap_apply]
68+
· exact (measurable_measure_prodMk_left hs).lintegral_kernel
69+
· exact measurable_measure_prodMk_left hs
70+
· exact Kernel.measurable_coe _ hs
71+
72+
end ProbabilityTheory.Kernel
73+
2474
namespace MeasureTheory.Measure
2575

2676
lemma compProd_eq_parallelComp_comp_copy_comp [SFinite μ] :

Mathlib/Probability/Kernel/Composition/Prod.lean

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,10 @@ lemma map_prod_swap (κ : Kernel α β) (η : Kernel α γ) [IsSFiniteKernel κ]
200200
refine (lintegral_lintegral_swap ?_).symm
201201
fun_prop
202202

203+
lemma prodComm_prod {κ : Kernel α β} [IsSFiniteKernel κ] {η : Kernel α γ} [IsSFiniteKernel η] :
204+
(κ ×ₖ η).map MeasurableEquiv.prodComm = η ×ₖ κ :=
205+
map_prod_swap κ η
206+
203207
@[simp]
204208
lemma swap_prod {κ : Kernel α β} [IsSFiniteKernel κ] {η : Kernel α γ} [IsSFiniteKernel η] :
205209
(swap β γ) ∘ₖ (κ ×ₖ η) = (η ×ₖ κ) := by
@@ -223,6 +227,12 @@ lemma prodAssoc_prod (κ : Kernel α β) [IsSFiniteKernel κ] (η : Kernel α γ
223227
rw [map_apply _ (by fun_prop), prod_apply, prod_apply, Measure.prodAssoc_prod, prod_apply,
224228
prod_apply]
225229

230+
lemma prodAssoc_symm_prod (κ : Kernel α β) [IsSFiniteKernel κ] (η : Kernel α γ) [IsSFiniteKernel η]
231+
(ξ : Kernel α δ) [IsSFiniteKernel ξ] :
232+
(κ ×ₖ (ξ ×ₖ η)).map MeasurableEquiv.prodAssoc.symm = (κ ×ₖ ξ) ×ₖ η := by
233+
rw [← prodAssoc_prod, ← Kernel.map_comp_right _ (by fun_prop) (by fun_prop)]
234+
simp
235+
226236
lemma prod_const_comp {δ} {mδ : MeasurableSpace δ} (κ : Kernel α β) [IsSFiniteKernel κ]
227237
(η : Kernel β γ) [IsSFiniteKernel η] (μ : Measure δ) [SFinite μ] :
228238
(η ×ₖ (const β μ)) ∘ₖ κ = (η ∘ₖ κ) ×ₖ (const α μ) := by

0 commit comments

Comments
 (0)