From 49773660fa36d2c00274a17f03b72a25b7862fa7 Mon Sep 17 00:00:00 2001 From: Keyan Date: Wed, 3 Apr 2019 11:41:37 +0200 Subject: [PATCH] bug fix --- discrete/ProbabilityEstimator.go | 9 +++++++++ discrete/sparse/ConditionalMutualInformation.go | 6 +++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/discrete/ProbabilityEstimator.go b/discrete/ProbabilityEstimator.go index 6cf68ce..d83265e 100644 --- a/discrete/ProbabilityEstimator.go +++ b/discrete/ProbabilityEstimator.go @@ -1,6 +1,8 @@ package discrete import ( + "math" + "github.com/kzahedi/goent/sm" ) @@ -137,8 +139,15 @@ func Empirical3DSparse(d [][]int) sm.SparseMatrix { } l := float64(rows) + var sum float64 for _, index := range p.Indices { p.Mul(index, 1.0/l) + v, _ := p.Get(index) + sum += v + } + + if math.Abs(sum-1.0) > 0.00001 { + panic("P does not sum up to one") } return p diff --git a/discrete/sparse/ConditionalMutualInformation.go b/discrete/sparse/ConditionalMutualInformation.go index a493c94..4657323 100644 --- a/discrete/sparse/ConditionalMutualInformation.go +++ b/discrete/sparse/ConditionalMutualInformation.go @@ -34,19 +34,19 @@ func ConditionalMutualInformation(pxyz sm.SparseMatrix, ln lnFunc) float64 { for _, index := range pxyCz.Indices { zi := sm.SparseMatrixIndex{index[2]} v, _ := pz.Get(zi) - pxyCz.Mul(zi, 1.0/v) + pxyCz.Mul(index, 1.0/v) } for _, index := range pxCz.Indices { zi := sm.SparseMatrixIndex{index[1]} v, _ := pz.Get(zi) - pxCz.Mul(zi, 1.0/v) + pxCz.Mul(index, 1.0/v) } for _, index := range pyCz.Indices { zi := sm.SparseMatrixIndex{index[1]} v, _ := pz.Get(zi) - pyCz.Mul(zi, 1.0/v) + pyCz.Mul(index, 1.0/v) } r := 0.0