Skip to content

Commit

Permalink
add ZDD.uniformM
Browse files Browse the repository at this point in the history
  • Loading branch information
msakai committed Oct 21, 2021
1 parent 40bc620 commit 764a77d
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 1 deletion.
10 changes: 9 additions & 1 deletion decision-diagrams.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ library
, hashable
, hashtables
, intern >=0.9.1.2 && <1.0.0.0
, mwc-random
, primitive
, random
, reflection
, unordered-containers
default-language: Haskell2010
Expand All @@ -54,13 +57,18 @@ test-suite decision-diagrams-test
test
ghc-options: -threaded -rtsopts -with-rtsopts=-N
build-depends:
base >=4.7 && <5
QuickCheck
, base >=4.7 && <5
, containers
, decision-diagrams
, hashable
, hashtables
, intern >=0.9.1.2 && <1.0.0.0
, mwc-random
, primitive
, random
, reflection
, statistics
, tasty >=0.10.1
, tasty-hunit >=0.9 && <0.11
, tasty-quickcheck >=0.8 && <0.11
Expand Down
5 changes: 5 additions & 0 deletions package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ dependencies:
- hashable
- hashtables
- intern >=0.9.1.2 && <1.0.0.0
- mwc-random
- primitive
- random
- reflection
- unordered-containers

Expand All @@ -41,6 +44,8 @@ tests:
- -with-rtsopts=-N
dependencies:
- decision-diagrams
- QuickCheck
- statistics
- tasty >=0.10.1
- tasty-hunit >=0.9 && <0.11
- tasty-quickcheck >=0.8 && <0.11
Expand Down
77 changes: 77 additions & 0 deletions src/Data/DecisionDiagram/ZDD.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
Expand Down Expand Up @@ -71,6 +72,9 @@ module Data.DecisionDiagram.ZDD
, minimalHittingSetsKnuth
, minimalHittingSetsImai

-- * Random sampling
, uniformM

-- * Misc
, flatten
, fold
Expand All @@ -84,19 +88,31 @@ module Data.DecisionDiagram.ZDD
import Prelude hiding (null)

import Control.Monad
#if !MIN_VERSION_mwc_random(0,15,0)
import Control.Monad.Primitive
#endif
import Control.Monad.ST
import qualified Data.Foldable as Foldable
import Data.Hashable
import Data.HashMap.Lazy (HashMap)
import qualified Data.HashMap.Lazy as HashMap
import qualified Data.HashTable.Class as H
import qualified Data.HashTable.ST.Cuckoo as C
import Data.IntSet (IntSet)
import qualified Data.IntSet as IntSet
import Data.List (sortBy)
import Data.Proxy
import Data.Ratio
import Data.Set (Set)
import qualified Data.Set as Set
import qualified GHC.Exts as Exts
import Numeric.Natural
#if MIN_VERSION_mwc_random(0,15,0)
import System.Random.Stateful (StatefulGen (..))
#else
import System.Random.MWC (Gen)
#endif
import System.Random.MWC.Distributions (bernoulli)

import Data.DecisionDiagram.BDD.Internal
import qualified Data.DecisionDiagram.BDD as BDD
Expand Down Expand Up @@ -498,3 +514,64 @@ fold' !ff !tt br (ZDD node) = runST $ do
f node

-- ------------------------------------------------------------------------

-- | Sample a set from uniform distribution over elements of the ZDD.
--
-- The function constructs a table internally and the table is shared across
-- multiple use of the resulting action (@m IntSet@).
-- Therefore, the code
--
-- @
-- let g = uniformM zdd gen
-- s1 <- g
-- s2 <- g
-- @
--
-- is more efficient than
--
-- @
-- s1 <- uniformM zdd gen
-- s2 <- uniformM zdd gen
-- @
-- .
#if MIN_VERSION_mwc_random(0,15,0)
uniformM :: forall a g m. (ItemOrder a, StatefulGen g m) => ZDD a -> g -> m IntSet
#else
uniformM :: forall a m. (ItemOrder a, PrimMonad m) => ZDD a -> Gen (PrimState m) -> m IntSet
#endif
uniformM (ZDD F) = error "Data.DecisionDiagram.ZDD.uniformM: empty ZDD"
uniformM (ZDD node) = func
where
func gen = f node []
where
f F _ = error "Data.DecisionDiagram.ZDD.uniformM: should not happen"
f T r = return $ IntSet.fromList r
f p@(Branch top p0 p1) r = do
b <- bernoulli (table HashMap.! p) gen
if b then
f p1 (top : r)
else
f p0 r

table :: HashMap Node Double
table = runST $ do
h <- C.newSized defaultTableSize
let f F = return (0 :: Integer)
f T = return 1
f p@(Branch _ p0 p1) = do
m <- H.lookup h p
case m of
Just (ret, _) -> return ret
Nothing -> do
n0 <- f p0
n1 <- f p1
let s = n0 + n1
r :: Double
r = realToFrac (n1 % (n0 + n1))
seq r $ H.insert h p (s, r)
return s
_ <- f node
xs <- H.toList h
return $ HashMap.fromList [(n, r) | (n, (_, r)) <- xs]

-- ------------------------------------------------------------------------
24 changes: 24 additions & 0 deletions test/TestZDD.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,15 @@ import Control.Monad
import Data.IntSet (IntSet)
import qualified Data.IntSet as IntSet
import Data.List
import qualified Data.Map.Strict as Map
import Data.Proxy
import Data.Set (Set)
import qualified Data.Set as Set
import qualified GHC.Exts as Exts
import Statistics.Distribution
import Statistics.Distribution.ChiSquared (chiSquared)
import qualified System.Random.MWC as Rand
import qualified Test.QuickCheck.Monadic as QM
import Test.Tasty
import Test.Tasty.QuickCheck
import Test.Tasty.TH
Expand Down Expand Up @@ -397,6 +402,25 @@ prop_flatten =
forAll arbitrary $ \(a :: ZDD o) ->
ZDD.flatten a === IntSet.unions (ZDD.toListOfIntSets a)

prop_uniformM :: Property
prop_uniformM =
withDefaultOrder $ \(_ :: Proxy o) ->
forAll (arbitrary `suchThat` ((>= (2::Integer)) . ZDD.size)) $ \(a :: ZDD o) ->
QM.monadicIO $ do
gen <- QM.run Rand.create
let m :: Integer
m = ZDD.size a
n = 1000
samples <- QM.run $ replicateM n $ ZDD.uniformM a gen
let hist_actual = Map.fromListWith (+) [(s, 1) | s <- samples]
hist_expected = [(s, fromIntegral n / fromIntegral m) | s <- ZDD.toListOfIntSets a]
chi_sq = sum [(Map.findWithDefault 0 s hist_actual - cnt) ** 2 / cnt | (s, cnt) <- hist_expected]
threshold = complQuantile (chiSquared (fromIntegral m - 1)) 0.001
QM.monitor $ counterexample $ show hist_actual ++ " /= " ++ show (Map.fromList hist_expected)
QM.assert $ and [xs `ZDD.member` a | xs <- Map.keys hist_actual]
QM.monitor $ counterexample $ "χ² = " ++ show chi_sq ++ " >= " ++ show threshold
QM.assert $ chi_sq < threshold

-- ------------------------------------------------------------------------

zddTestGroup :: TestTree
Expand Down

0 comments on commit 764a77d

Please sign in to comment.