From 7eb216d16d1aa4d0c19a42cddbb23e2abfd87e74 Mon Sep 17 00:00:00 2001 From: Bryan O'Sullivan Date: Sat, 19 Sep 2009 03:53:19 +0000 Subject: [PATCH] Move resampling code to the ST monad. --- Statistics/Function.hs | 12 ++++++------ Statistics/Resampling.hs | 17 ++++++++--------- statistics.cabal | 1 - 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/Statistics/Function.hs b/Statistics/Function.hs index 87ab5b9a..f7e8ae72 100644 --- a/Statistics/Function.hs +++ b/Statistics/Function.hs @@ -1,4 +1,4 @@ -{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE Rank2Types, TypeOperators #-} -- | -- Module : Statistics.Function -- Copyright : (c) 2009 Bryan O'Sullivan @@ -19,7 +19,7 @@ module Statistics.Function ) where import Control.Exception (assert) -import Control.Monad.ST (unsafeSTToIO) +import Control.Monad.ST (ST) import Data.Array.Vector.Algorithms.Combinators (apply) import Data.Array.Vector import qualified Data.Array.Vector.Algorithms.Intro as I @@ -49,12 +49,12 @@ minMax = fini . foldlU go (MM (1/0) (-1/0)) {-# INLINE minMax #-} -- | Create an array, using the given action to populate each element. -createU :: (UA e) => Int -> (Int -> IO e) -> IO (UArr e) +createU :: (UA e) => forall s. Int -> (Int -> ST s e) -> ST s (UArr e) createU size itemAt = assert (size >= 0) $ - unsafeSTToIO (newMU size) >>= loop 0 + newMU size >>= loop 0 where - loop k arr | k >= size = unsafeSTToIO (unsafeFreezeAllMU arr) + loop k arr | k >= size = unsafeFreezeAllMU arr | otherwise = do r <- itemAt k - unsafeSTToIO (writeMU arr k r) + writeMU arr k r loop (k+1) arr diff --git a/Statistics/Resampling.hs b/Statistics/Resampling.hs index 38956549..40b7d543 100644 --- a/Statistics/Resampling.hs +++ b/Statistics/Resampling.hs @@ -17,12 +17,12 @@ module Statistics.Resampling ) where import Control.Monad (forM_) -import Control.Monad.ST (unsafeSTToIO) +import Control.Monad.ST (ST) import Data.Array.Vector import Data.Array.Vector.Algorithms.Intro (sort) import Statistics.Function (createU) +import Statistics.RandomVariate (Gen, uniform) import Statistics.Types (Estimator, Sample) -import System.Random.Mersenne (MTGen, random) -- | A resample drawn randomly, with replacement, from a set of data -- points. Distinct from a normal array to make it harder for your @@ -33,20 +33,19 @@ newtype Resample = Resample { -- | Resample a data set repeatedly, with replacement, computing each -- estimate over the resampled data. -resample :: MTGen -> [Estimator] -> Int -> Sample -> IO [Resample] +resample :: Gen s -> [Estimator] -> Int -> Sample -> ST s [Resample] resample gen ests numResamples samples = do - results <- unsafeSTToIO . mapM (const (newMU numResamples)) $ ests + results <- mapM (const (newMU numResamples)) $ ests loop 0 (zip ests results) - unsafeSTToIO $ do - mapM_ sort results - mapM (fmap Resample . unsafeFreezeAllMU) results + mapM_ sort results + mapM (fmap Resample . unsafeFreezeAllMU) results where loop k ers | k >= numResamples = return () | otherwise = do re <- createU n $ \_ -> do - r <- random gen + r <- uniform gen return (indexU samples (abs r `mod` n)) - unsafeSTToIO . forM_ ers $ \(est,arr) -> + forM_ ers $ \(est,arr) -> writeMU arr k . est $ re loop (k+1) ers n = lengthU samples diff --git a/statistics.cabal b/statistics.cabal index 1c2855d8..caa3414c 100644 --- a/statistics.cabal +++ b/statistics.cabal @@ -55,7 +55,6 @@ library build-depends: base < 5, erf, - mersenne-random, time, uvector >= 0.1.0.4, uvector-algorithms >= 0.2