Permalink
Browse files

Merge strictness annotations.

Conflicts:
	Statistics/Sampler/Slice.hs
  • Loading branch information...
2 parents b960d6d + 0e52508 commit 8cdc985470af10328e9c93629b0be14d4568ca27 @kosmikus kosmikus committed Mar 9, 2012
View
@@ -14,3 +14,5 @@ CODAchain2.txt
CODAindex.txt
test/parser/classic-bugs/vol1/line/jags.cmd
sandpit/*.csv
+*.csv
+*.svg
@@ -1,11 +1,59 @@
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE FlexibleContexts #-}
module Statistics.Distribution.Random.Exponential (
exponential
) where
+import Random.CRI
+import Statistics.Distribution.Random.Uniform
-import qualified System.Random.MWC as R
-import Control.Monad
-import Control.Monad.Primitive (PrimMonad, PrimState)
+-- q[k-1] = sum(log(2)^k / k!) k=1,..,n,
+q :: [Double]
+-- q = let factorial = foldr (*) 1 . enumFromTo 1
+-- qk :: Integer -> Double
+-- qk k = (log 2)^k / (fromIntegral (factorial k))
+-- qs :: [Double]
+-- qs = map qk [1..]
+-- in map (\ n -> sum (take n qs)) (take 16 [1..])
+q = [0.6931471805599453,
+ 0.9333736875190459,
+ 0.9888777961838675,
+ 0.9984959252914960,
+ 0.9998292811061389,
+ 0.9999833164100727,
+ 0.9999985691438767,
+ 0.9999998906925558,
+ 0.9999999924734159,
+ 0.9999999995283275,
+ 0.9999999999728814,
+ 0.9999999999985598,
+ 0.9999999999999289,
+ 0.9999999999999968,
+ 0.9999999999999999,
+ 1.0000000000000000]
+{-# INLINE q #-}
+
+exponential :: Source m g Double => g m -> m Double
+exponential rng =
+ let !q0 = q !! 0
+ mk_au :: Double -> Double -> (Double, Double)
+ mk_au !a !u
+ | u > 1.0 = (a, u - 1.0)
+ | otherwise = mk_au (a + q0) (u + u)
+
+ go a _ umin ![] = return (a + umin * q0)
+ go a u umin !(qi:qs) = do
+ !ustar <- uniform rng
+ let umin' = min ustar umin
+ if u > qi
+ then go a u umin' qs
+ else return (a + umin' * q0)
+ in do
+ !u' <- uniform rng
+ let !(a, u) = mk_au 0.0 (u' + u')
+ if u <= q0
+ then return (a + u)
+ else do
+ !us <- uniform rng
+ go a u us (tail q)
-exponential :: PrimMonad m => R.Gen (PrimState m) -> m Double
-exponential rng = liftM (negate . log) (R.uniform rng)
@@ -1,3 +1,4 @@
+{-# LANGUAGE FlexibleContexts #-}
module Statistics.Distribution.Random.Gamma (
gamma
) where
@@ -19,11 +20,13 @@ module Statistics.Distribution.Random.Gamma (
- Computing, 12, 223-246.
-}
-import qualified System.Random.MWC as R
-import qualified Statistics.Distribution.Random.Exponential as E
-import Control.Monad.Primitive (PrimMonad, PrimState)
+import qualified Statistics.Distribution.Random.Exponential as D
+import qualified Statistics.Distribution.Random.Normal as D
+import qualified Statistics.Distribution.Random.Uniform as D
import Data.List (foldl')
import Data.Number.LogFloat (expm1)
+import Data.Word
+import Random.CRI
sqrt32, exp_m1 :: Double
sqrt32 = sqrt 32
@@ -59,18 +62,18 @@ horner :: [Double] -> Double -> Double
horner q r = foldl' (\ a b -> (a + b) * r) 0 q
{-# INLINE horner #-}
-gamma :: (PrimMonad m) => Double -> Double -> R.Gen (PrimState m) -> m Double
+gamma :: (Source m g Double, Source m g Word32) => Double -> Double -> g m -> m Double
gamma shape scale rng
| shape == 0 = return 0
| shape < 1 = gammaGS shape scale rng
| otherwise = gammaGD shape scale rng
-gammaGS :: (PrimMonad m) => Double -> Double -> R.Gen (PrimState m) -> m Double
+gammaGS :: (Source m g Double, Source m g Word32) => Double -> Double -> g m -> m Double
gammaGS shape scale rng =
let e = 1 + exp_m1 * shape
go = do
- ru <- R.uniform rng
- re <- E.exponential rng
+ ru <- D.uniform rng
+ re <- D.exponential rng
let p = e * ru
x | p >= 1 = - log ((e - p) / shape)
| otherwise = exp (log p / shape)
@@ -79,7 +82,7 @@ gammaGS shape scale rng =
if accept then return (scale * x) else go
in go
-gammaGD :: (PrimMonad m) => Double -> Double -> R.Gen (PrimState m) -> m Double
+gammaGD :: (Source m g Double, Source m g Word32) => Double -> Double -> g m -> m Double
gammaGD shape scale rng =
let s2 = shape - 0.5
@@ -99,8 +102,8 @@ gammaGD shape scale rng =
where v = t / (s + s)
choose_t = do
- e <- E.exponential rng
- u' <- R.uniform rng
+ e <- D.exponential rng
+ u' <- D.uniform rng
let uu = u' + u' - 1
let tt = if uu < 0 then b - si * e else b + si * e
if tt >= -0.71874483771719
@@ -112,14 +115,14 @@ gammaGD shape scale rng =
else choose_t
in do
- t <- R.normal rng
+ t <- D.normal rng
let x = s + 0.5 * t
ret_val = scale * x * x
if t >= 0
then return ret_val
else do
- u <- R.uniform rng
+ u <- D.uniform rng
if d * u <= t * t * t
then return ret_val
else do
@@ -0,0 +1,66 @@
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE BangPatterns #-}
+module Statistics.Distribution.Random.Normal (normal) where
+
+import Control.Monad
+import Data.Bits
+import Data.Word
+import qualified Data.Vector.Unboxed as I
+import Random.CRI
+import Statistics.Distribution.Random.Uniform
+
+-- Copied from mwc-random.
+
+data T = T {-# UNPACK #-} !Double {-# UNPACK #-} !Double
+
+-- | Generate a normally distributed random variate.
+--
+-- The implementation uses Doornik's modified ziggurat algorithm.
+-- Compared to the ziggurat algorithm usually used, this is slower,
+-- but generates more independent variates that pass stringent tests
+-- of randomness.
+normal :: (Source m g Double, Source m g Word32) => g m -> m Double
+normal gen = loop
+ where
+ loop = do
+ u <- (subtract 1 . (*2)) `liftM` uniform gen
+ ri <- uniform gen
+ let i = fromIntegral ((ri :: Word32) .&. 127)
+ bi = I.unsafeIndex blocks i
+ bj = I.unsafeIndex blocks (i+1)
+ if abs u < I.unsafeIndex ratios i
+ then return $! u * bi
+ else if i == 0
+ then normalTail (u < 0)
+ else do
+ let x = u * bi
+ xx = x * x
+ d = exp (-0.5 * (bi * bi - xx))
+ e = exp (-0.5 * (bj * bj - xx))
+ c <- uniform gen
+ if e + c * (d - e) < 1
+ then return x
+ else loop
+ blocks = let f = exp (-0.5 * r * r)
+ in (`I.snoc` 0) . I.cons (v/f) . I.cons r .
+ I.unfoldrN 126 go $! T r f
+ where
+ go (T b g) = let !u = T h (exp (-0.5 * h * h))
+ h = sqrt (-2 * log (v / b + g))
+ in Just (h, u)
+ v = 9.91256303526217e-3
+ {-# NOINLINE blocks #-}
+ r = 3.442619855899
+ ratios = I.zipWith (/) (I.tail blocks) blocks
+ {-# NOINLINE ratios #-}
+ normalTail neg = tailing
+ where tailing = do
+ x <- ((/r) . log) `liftM` uniform gen
+ y <- log `liftM` uniform gen
+ if y * (-2) < x * x
+ then tailing
+ else return $! if neg then x - r else r - x
+{-# INLINE normal #-}
+
+
@@ -0,0 +1,9 @@
+{-# LANGUAGE FlexibleContexts #-}
+module Statistics.Distribution.Random.Uniform (
+ uniform
+) where
+
+import Random.CRI
+
+uniform :: Source m g a => g m -> m a
+uniform = grandom
@@ -1,16 +1,16 @@
-{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE ScopedTypeVariables, FlexibleContexts, BangPatterns #-}
module Statistics.Sampler.Slice (
slice,
newSlicerState,
adaptOff,
SlicerState()
) where
-import qualified System.Random.MWC as R
-import qualified Statistics.Distribution.Random.Exponential as E
-import Statistics.Constants (m_epsilon)
+import Random.CRI
+import qualified Statistics.Distribution.Random.Exponential as D
+import qualified Statistics.Distribution.Random.Uniform as D
+import Numeric.MathFunctions.Constants (m_epsilon)
import Control.Monad (when)
-import Control.Monad.Primitive (PrimMonad, PrimState)
import Prelude hiding (max)
@@ -20,13 +20,13 @@ import Prelude hiding (max)
-- http://projecteuclid.org/euclid.aos/1056562461
data SlicerState = SlicerState {
- lower :: Double, -- ^ lower bound of distribution
- upper :: Double, -- ^ upper bound of distribution
- width :: Double, -- ^ width of step out size (approximate scale parameter)
- steps :: Int, -- ^ maximum number of step outs
- adapt :: Bool, -- ^ adapt phase underway
- sumdiff :: Double, -- ^ store sumdiff for adaption phase
- iter :: Int -- ^ number of iterations
+ lower :: !Double, -- ^ lower bound of distribution
+ upper :: !Double, -- ^ upper bound of distribution
+ width :: !Double, -- ^ width of step out size (approximate scale parameter)
+ steps :: !Int, -- ^ maximum number of step outs
+ adapt :: !Bool, -- ^ adapt phase underway
+ sumdiff :: !Double, -- ^ store sumdiff for adaption phase
+ iter :: !Int -- ^ number of iterations
} deriving Show
newSlicerState :: Double -- ^ lower bound
@@ -46,15 +46,15 @@ newSlicerState l u w = SlicerState {
adaptOff :: SlicerState -> SlicerState
adaptOff st = st { adapt = False }
-slice :: (PrimMonad m) =>
+slice :: (Source m g Double) =>
SlicerState
-> (Double -> Double) -- ^ x -> log(f x) where f is proportional to probibility density
-> Double -- ^ current value
- -> R.Gen (PrimState m) -- ^ a random number generator
+ -> g m -- ^ a random number generator
-> m (SlicerState, Double) -- ^ return slicer state and new sample value
slice st g x0 rng =
do
- let g0 = g x0
+ let !g0 = g x0
when (isInfinite g0) $
error $ "Infinite value found in slice sampler: " ++ show x0 ++ " -> " ++ show g0
@@ -63,26 +63,26 @@ slice st g x0 rng =
error $ "NaN found in slice sampler: " ++ show x0 ++ " -> " ++ show g0
-- 1. define slice
- e <- E.exponential rng
- let z = g0 - e
+ e <- D.exponential rng
+ let !z = g0 - e
-- 2. find interval
- u <- R.uniform rng
- let l = x0 - width st * u
- r = l + width st
+ u <- D.uniform rng
+ let !l = x0 - width st * u
+ !r = l + width st
- v :: Double <- R.uniform rng
- let j = floor (fromIntegral (steps st) * v)
- k = (steps st - 1) - j
+ v :: Double <- D.uniform rng
+ let !j = floor (fromIntegral (steps st) * v)
+ !k = (steps st - 1) - j
- let left = calc_left j l
+ let !left = calc_left j l
calc_left n l'
| l' < lower st = lower st
| n == 0 = l'
| z >= g l' = l'
| otherwise = calc_left (n-1) (l' - width st)
- let right = calc_right k r
+ let !right = calc_right k r
calc_right n r'
| r' > upper st = upper st
| n == 0 = r'
@@ -92,28 +92,28 @@ slice st g x0 rng =
-- 3. loop until accept (guaranteed)
let sample left' right' =
do
- u' <- R.uniform rng
- let x = left' + u' * (right' - left')
+ u' <- D.uniform rng
+ let !x = left' + u' * (right' - left')
if z - m_epsilon <= g x
then return x -- accept
else
if x < x0
then sample x right'
else sample left' x
- x1 <- sample left right
+ !x1 <- sample left right
- let st' | adapt st = adaptSlicer x0 x1 st
- | otherwise = st
+ let !st' | adapt st = adaptSlicer x0 x1 st
+ | otherwise = st
return (st', x1)
adaptSlicer :: Double -> Double -> SlicerState -> SlicerState
adaptSlicer old new oldst = newst
where
- iterf = fromIntegral (iter oldst)
- sumdiff' = sumdiff oldst + iterf * abs (new - old)
- newst = oldst {
+ !iterf = fromIntegral (iter oldst)
+ !sumdiff' = sumdiff oldst + iterf * abs (new - old)
+ !newst = oldst {
sumdiff = sumdiff',
iter = iter oldst + 1,
width = if iter oldst > 50
Oops, something went wrong.

0 comments on commit 8cdc985

Please sign in to comment.