diff --git a/vector/src/Data/Vector/Generic/Mutable.hs b/vector/src/Data/Vector/Generic/Mutable.hs index 7ccf21dc..45326f3c 100644 --- a/vector/src/Data/Vector/Generic/Mutable.hs +++ b/vector/src/Data/Vector/Generic/Mutable.hs @@ -79,6 +79,7 @@ module Data.Vector.Generic.Mutable ( ) where import Control.Monad ((<=<)) +import Control.Monad.ST import Data.Vector.Generic.Mutable.Base import qualified Data.Vector.Generic.Base as V @@ -91,7 +92,7 @@ import Data.Vector.Fusion.Bundle.Size import Data.Vector.Fusion.Util ( delay_inline ) import Data.Vector.Internal.Check -import Control.Monad.Primitive ( PrimMonad(..), RealWorld, stToPrim ) +import Control.Monad.Primitive ( PrimMonad(..), stToPrim ) import Prelude ( Ord, Monad, Bool(..), Int, Maybe(..), Either(..), Ordering(..) @@ -106,8 +107,7 @@ import Data.Bits ( Bits(shiftR) ) -- Internal functions -- ------------------ -unsafeAppend1 :: (PrimMonad m, MVector v a) - => v (PrimState m) a -> Int -> a -> m (v (PrimState m) a) +unsafeAppend1 :: (MVector v a) => v s a -> Int -> a -> ST s (v s a) {-# INLINE_INNER unsafeAppend1 #-} -- NOTE: The case distinction has to be on the outside because -- GHC creates a join point for the unsafeWrite even when everything @@ -122,8 +122,7 @@ unsafeAppend1 v i x checkIndex Internal i (length v') $ unsafeWrite v' i x return v' -unsafePrepend1 :: (PrimMonad m, MVector v a) - => v (PrimState m) a -> Int -> a -> m (v (PrimState m) a, Int) +unsafePrepend1 :: (MVector v a) => v s a -> Int -> a -> ST s (v s a, Int) {-# INLINE_INNER unsafePrepend1 #-} unsafePrepend1 v i x | i /= 0 = do @@ -207,7 +206,7 @@ unstream :: (PrimMonad m, MVector v a) => Bundle u a -> m (v (PrimState m) a) -- NOTE: replace INLINE_FUSED by INLINE? (also in unstreamR) {-# INLINE_FUSED unstream #-} -unstream s = munstream (Bundle.lift s) +unstream s = stToPrim $ munstream (Bundle.lift s) -- | Create a new mutable vector and fill it with elements from the monadic -- stream. The vector will grow exponentially if the maximum size of the stream @@ -243,9 +242,8 @@ munstreamUnknown s $ unsafeSlice 0 n v' where {-# INLINE_INNER put #-} - put (v,i) x = do - v' <- unsafeAppend1 v i x - return (v',i+1) + put (v,i) x = stToPrim $ do v' <- unsafeAppend1 v i x + return (v',i+1) -- | Create a new mutable vector and fill it with elements from the 'Bundle'. @@ -255,7 +253,7 @@ vunstream :: (PrimMonad m, V.Vector v a) => Bundle v a -> m (V.Mutable v (PrimState m) a) -- NOTE: replace INLINE_FUSED by INLINE? (also in unstreamR) {-# INLINE_FUSED vunstream #-} -vunstream s = vmunstream (Bundle.lift s) +vunstream s = stToPrim $ vmunstream (Bundle.lift s) -- | Create a new mutable vector and fill it with elements from the monadic -- stream. The vector will grow exponentially if the maximum size of the stream @@ -311,7 +309,7 @@ unstreamR :: (PrimMonad m, MVector v a) => Bundle u a -> m (v (PrimState m) a) -- NOTE: replace INLINE_FUSED by INLINE? (also in unstream) {-# INLINE_FUSED unstreamR #-} -unstreamR s = munstreamR (Bundle.lift s) +unstreamR s = stToPrim $ munstreamR (Bundle.lift s) -- | Create a new mutable vector and fill it with elements from the monadic -- stream from right to left. The vector will grow exponentially if the maximum @@ -350,7 +348,7 @@ munstreamRUnknown s $ unsafeSlice i (n-i) v' where {-# INLINE_INNER put #-} - put (v,i) x = unsafePrepend1 v i x + put (v,i) x = stToPrim $ unsafePrepend1 v i x -- Length -- ------ @@ -563,10 +561,9 @@ enlarge_delta :: MVector v a => v s a -> Int enlarge_delta v = max (length v) 1 -- | Grow a vector logarithmically. -enlarge :: (PrimMonad m, MVector v a) - => v (PrimState m) a -> m (v (PrimState m) a) +enlarge :: (MVector v a) => v s a -> ST s (v s a) {-# INLINE enlarge #-} -enlarge v = stToPrim $ do +enlarge v = do vnew <- unsafeGrow v by basicInitialize $ basicUnsafeSlice (length v) by vnew return vnew @@ -996,10 +993,10 @@ unsafeMove dst src = check Unsafe "length mismatch" (length dst == length src) accum :: forall m v a b u. (HasCallStack, PrimMonad m, MVector v a) => (a -> b -> a) -> v (PrimState m) a -> Bundle u (Int, b) -> m () {-# INLINE accum #-} -accum f !v s = Bundle.mapM_ upd s +accum f !v s = stToPrim $ Bundle.mapM_ upd s where {-# INLINE_INNER upd #-} - upd :: HasCallStack => (Int, b) -> m () + upd :: HasCallStack => (Int, b) -> ST (PrimState m) () upd (i,b) = do a <- checkIndex Bounds i n $ unsafeRead v i unsafeWrite v i (f a b) @@ -1008,10 +1005,10 @@ accum f !v s = Bundle.mapM_ upd s update :: forall m v a u. (HasCallStack, PrimMonad m, MVector v a) => v (PrimState m) a -> Bundle u (Int, a) -> m () {-# INLINE update #-} -update !v s = Bundle.mapM_ upd s +update !v s = stToPrim $ Bundle.mapM_ upd s where {-# INLINE_INNER upd #-} - upd :: HasCallStack => (Int, a) -> m () + upd :: HasCallStack => (Int, a) -> ST (PrimState m) () upd (i,b) = checkIndex Bounds i n $ unsafeWrite v i b !n = length v @@ -1019,7 +1016,7 @@ update !v s = Bundle.mapM_ upd s unsafeAccum :: (PrimMonad m, MVector v a) => (a -> b -> a) -> v (PrimState m) a -> Bundle u (Int, b) -> m () {-# INLINE unsafeAccum #-} -unsafeAccum f !v s = Bundle.mapM_ upd s +unsafeAccum f !v s = stToPrim $ Bundle.mapM_ upd s where {-# INLINE_INNER upd #-} upd (i,b) = do @@ -1028,9 +1025,9 @@ unsafeAccum f !v s = Bundle.mapM_ upd s !n = length v unsafeUpdate :: (PrimMonad m, MVector v a) - => v (PrimState m) a -> Bundle u (Int, a) -> m () + => v (PrimState m) a -> Bundle u (Int, a) -> m () {-# INLINE unsafeUpdate #-} -unsafeUpdate !v s = Bundle.mapM_ upd s +unsafeUpdate !v s = stToPrim $ Bundle.mapM_ upd s where {-# INLINE_INNER upd #-} upd (i,b) = checkIndex Unsafe i n $ unsafeWrite v i b @@ -1038,7 +1035,7 @@ unsafeUpdate !v s = Bundle.mapM_ upd s reverse :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m () {-# INLINE reverse #-} -reverse !v = reverse_loop 0 (length v - 1) +reverse !v = stToPrim $ reverse_loop 0 (length v - 1) where reverse_loop i j | i < j = do unsafeSwap v i j @@ -1048,11 +1045,11 @@ reverse !v = reverse_loop 0 (length v - 1) unstablePartition :: forall m v a. (PrimMonad m, MVector v a) => (a -> Bool) -> v (PrimState m) a -> m Int {-# INLINE unstablePartition #-} -unstablePartition f !v = from_left 0 (length v) +unstablePartition f !v = stToPrim $ from_left 0 (length v) where -- NOTE: GHC 6.10.4 panics without the signatures on from_left and -- from_right - from_left :: Int -> Int -> m Int + from_left :: Int -> Int -> ST (PrimState m) Int from_left i j | i == j = return i | otherwise = do @@ -1061,7 +1058,7 @@ unstablePartition f !v = from_left 0 (length v) then from_left (i+1) j else from_right i (j-1) - from_right :: Int -> Int -> m Int + from_right :: Int -> Int -> ST (PrimState m) Int from_right i j | i == j = return i | otherwise = do @@ -1078,7 +1075,8 @@ unstablePartitionBundle :: (PrimMonad m, MVector v a) => (a -> Bool) -> Bundle u a -> m (v (PrimState m) a, v (PrimState m) a) {-# INLINE unstablePartitionBundle #-} unstablePartitionBundle f s - = case upperBound (Bundle.size s) of + = stToPrim + $ case upperBound (Bundle.size s) of Just n -> unstablePartitionMax f s n Nothing -> partitionUnknown f s @@ -1087,7 +1085,7 @@ unstablePartitionMax :: (PrimMonad m, MVector v a) -> m (v (PrimState m) a, v (PrimState m) a) {-# INLINE unstablePartitionMax #-} unstablePartitionMax f s n - = do + = stToPrim $ do v <- checkLength Internal n $ unsafeNew n let {-# INLINE_INNER put #-} put (i, j) x @@ -1105,15 +1103,15 @@ partitionBundle :: (PrimMonad m, MVector v a) => (a -> Bool) -> Bundle u a -> m (v (PrimState m) a, v (PrimState m) a) {-# INLINE partitionBundle #-} partitionBundle f s - = case upperBound (Bundle.size s) of + = stToPrim + $ case upperBound (Bundle.size s) of Just n -> partitionMax f s n Nothing -> partitionUnknown f s partitionMax :: (PrimMonad m, MVector v a) => (a -> Bool) -> Bundle u a -> Int -> m (v (PrimState m) a, v (PrimState m) a) {-# INLINE partitionMax #-} -partitionMax f s n - = do +partitionMax f s n = stToPrim $ do v <- checkLength Internal n $ unsafeNew n let {-# INLINE_INNER put #-} @@ -1138,8 +1136,7 @@ partitionMax f s n partitionUnknown :: (PrimMonad m, MVector v a) => (a -> Bool) -> Bundle u a -> m (v (PrimState m) a, v (PrimState m) a) {-# INLINE partitionUnknown #-} -partitionUnknown f s - = do +partitionUnknown f s = stToPrim $ do v1 <- unsafeNew 0 v2 <- unsafeNew 0 (v1', n1, v2', n2) <- Bundle.foldM' put (v1, 0, v2, 0) s @@ -1165,7 +1162,8 @@ partitionWithBundle :: (PrimMonad m, MVector v a, MVector v b, MVector v c) => (a -> Either b c) -> Bundle u a -> m (v (PrimState m) b, v (PrimState m) c) {-# INLINE partitionWithBundle #-} partitionWithBundle f s - = case upperBound (Bundle.size s) of + = stToPrim + $ case upperBound (Bundle.size s) of Just n -> partitionWithMax f s n Nothing -> partitionWithUnknown f s @@ -1173,7 +1171,7 @@ partitionWithMax :: (PrimMonad m, MVector v a, MVector v b, MVector v c) => (a -> Either b c) -> Bundle u a -> Int -> m (v (PrimState m) b, v (PrimState m) c) {-# INLINE partitionWithMax #-} partitionWithMax f s n - = do + = stToPrim $ do v1 <- unsafeNew n v2 <- unsafeNew n let {-# INLINE_INNER put #-} @@ -1194,7 +1192,7 @@ partitionWithUnknown :: forall m v u a b c. => (a -> Either b c) -> Bundle u a -> m (v (PrimState m) b, v (PrimState m) c) {-# INLINE partitionWithUnknown #-} partitionWithUnknown f s - = do + = stToPrim $ do v1 <- unsafeNew 0 v2 <- unsafeNew 0 (v1', n1, v2', n2) <- Bundle.foldM' put (v1, 0, v2, 0) s @@ -1204,14 +1202,14 @@ partitionWithUnknown f s where put :: (v (PrimState m) b, Int, v (PrimState m) c, Int) -> a - -> m (v (PrimState m) b, Int, v (PrimState m) c, Int) + -> ST (PrimState m) (v (PrimState m) b, Int, v (PrimState m) c, Int) {-# INLINE_INNER put #-} put (v1, i1, v2, i2) x = case f x of Left b -> do - v1' <- unsafeAppend1 v1 i1 b + v1' <- stToPrim $ unsafeAppend1 v1 i1 b return (v1', i1+1, v2, i2) Right c -> do - v2' <- unsafeAppend1 v2 i2 c + v2' <- stToPrim $ unsafeAppend1 v2 i2 c return (v1, i1, v2', i2+1) -- Modifying vectors