Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 37 additions & 39 deletions vector/src/Data/Vector/Generic/Mutable.hs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ module Data.Vector.Generic.Mutable (
) where

import Control.Monad ((<=<))
import Control.Monad.ST
import Data.Vector.Generic.Mutable.Base
import qualified Data.Vector.Generic.Base as V

Expand All @@ -91,7 +92,7 @@ import Data.Vector.Fusion.Bundle.Size
import Data.Vector.Fusion.Util ( delay_inline )
import Data.Vector.Internal.Check

import Control.Monad.Primitive ( PrimMonad(..), RealWorld, stToPrim )
import Control.Monad.Primitive ( PrimMonad(..), stToPrim )

import Prelude
( Ord, Monad, Bool(..), Int, Maybe(..), Either(..), Ordering(..)
Expand All @@ -106,8 +107,7 @@ import Data.Bits ( Bits(shiftR) )
-- Internal functions
-- ------------------

unsafeAppend1 :: (PrimMonad m, MVector v a)
=> v (PrimState m) a -> Int -> a -> m (v (PrimState m) a)
unsafeAppend1 :: (MVector v a) => v s a -> Int -> a -> ST s (v s a)
{-# INLINE_INNER unsafeAppend1 #-}
-- NOTE: The case distinction has to be on the outside because
-- GHC creates a join point for the unsafeWrite even when everything
Expand All @@ -122,8 +122,7 @@ unsafeAppend1 v i x
checkIndex Internal i (length v') $ unsafeWrite v' i x
return v'

unsafePrepend1 :: (PrimMonad m, MVector v a)
=> v (PrimState m) a -> Int -> a -> m (v (PrimState m) a, Int)
unsafePrepend1 :: (MVector v a) => v s a -> Int -> a -> ST s (v s a, Int)
{-# INLINE_INNER unsafePrepend1 #-}
unsafePrepend1 v i x
| i /= 0 = do
Expand Down Expand Up @@ -207,7 +206,7 @@ unstream :: (PrimMonad m, MVector v a)
=> Bundle u a -> m (v (PrimState m) a)
-- NOTE: replace INLINE_FUSED by INLINE? (also in unstreamR)
{-# INLINE_FUSED unstream #-}
unstream s = munstream (Bundle.lift s)
unstream s = stToPrim $ munstream (Bundle.lift s)

-- | Create a new mutable vector and fill it with elements from the monadic
-- stream. The vector will grow exponentially if the maximum size of the stream
Expand Down Expand Up @@ -243,9 +242,8 @@ munstreamUnknown s
$ unsafeSlice 0 n v'
where
{-# INLINE_INNER put #-}
put (v,i) x = do
v' <- unsafeAppend1 v i x
return (v',i+1)
put (v,i) x = stToPrim $ do v' <- unsafeAppend1 v i x
return (v',i+1)


-- | Create a new mutable vector and fill it with elements from the 'Bundle'.
Expand All @@ -255,7 +253,7 @@ vunstream :: (PrimMonad m, V.Vector v a)
=> Bundle v a -> m (V.Mutable v (PrimState m) a)
-- NOTE: replace INLINE_FUSED by INLINE? (also in unstreamR)
{-# INLINE_FUSED vunstream #-}
vunstream s = vmunstream (Bundle.lift s)
vunstream s = stToPrim $ vmunstream (Bundle.lift s)

-- | Create a new mutable vector and fill it with elements from the monadic
-- stream. The vector will grow exponentially if the maximum size of the stream
Expand Down Expand Up @@ -311,7 +309,7 @@ unstreamR :: (PrimMonad m, MVector v a)
=> Bundle u a -> m (v (PrimState m) a)
-- NOTE: replace INLINE_FUSED by INLINE? (also in unstream)
{-# INLINE_FUSED unstreamR #-}
unstreamR s = munstreamR (Bundle.lift s)
unstreamR s = stToPrim $ munstreamR (Bundle.lift s)

-- | Create a new mutable vector and fill it with elements from the monadic
-- stream from right to left. The vector will grow exponentially if the maximum
Expand Down Expand Up @@ -350,7 +348,7 @@ munstreamRUnknown s
$ unsafeSlice i (n-i) v'
where
{-# INLINE_INNER put #-}
put (v,i) x = unsafePrepend1 v i x
put (v,i) x = stToPrim $ unsafePrepend1 v i x

-- Length
-- ------
Expand Down Expand Up @@ -563,10 +561,9 @@ enlarge_delta :: MVector v a => v s a -> Int
enlarge_delta v = max (length v) 1

-- | Grow a vector logarithmically.
enlarge :: (PrimMonad m, MVector v a)
=> v (PrimState m) a -> m (v (PrimState m) a)
enlarge :: (MVector v a) => v s a -> ST s (v s a)
{-# INLINE enlarge #-}
enlarge v = stToPrim $ do
enlarge v = do
vnew <- unsafeGrow v by
basicInitialize $ basicUnsafeSlice (length v) by vnew
return vnew
Expand Down Expand Up @@ -996,10 +993,10 @@ unsafeMove dst src = check Unsafe "length mismatch" (length dst == length src)
accum :: forall m v a b u. (HasCallStack, PrimMonad m, MVector v a)
=> (a -> b -> a) -> v (PrimState m) a -> Bundle u (Int, b) -> m ()
{-# INLINE accum #-}
accum f !v s = Bundle.mapM_ upd s
accum f !v s = stToPrim $ Bundle.mapM_ upd s
where
{-# INLINE_INNER upd #-}
upd :: HasCallStack => (Int, b) -> m ()
upd :: HasCallStack => (Int, b) -> ST (PrimState m) ()
upd (i,b) = do
a <- checkIndex Bounds i n $ unsafeRead v i
unsafeWrite v i (f a b)
Expand All @@ -1008,18 +1005,18 @@ accum f !v s = Bundle.mapM_ upd s
update :: forall m v a u. (HasCallStack, PrimMonad m, MVector v a)
=> v (PrimState m) a -> Bundle u (Int, a) -> m ()
{-# INLINE update #-}
update !v s = Bundle.mapM_ upd s
update !v s = stToPrim $ Bundle.mapM_ upd s
where
{-# INLINE_INNER upd #-}
upd :: HasCallStack => (Int, a) -> m ()
upd :: HasCallStack => (Int, a) -> ST (PrimState m) ()
upd (i,b) = checkIndex Bounds i n $ unsafeWrite v i b

!n = length v

unsafeAccum :: (PrimMonad m, MVector v a)
=> (a -> b -> a) -> v (PrimState m) a -> Bundle u (Int, b) -> m ()
{-# INLINE unsafeAccum #-}
unsafeAccum f !v s = Bundle.mapM_ upd s
unsafeAccum f !v s = stToPrim $ Bundle.mapM_ upd s
where
{-# INLINE_INNER upd #-}
upd (i,b) = do
Expand All @@ -1028,17 +1025,17 @@ unsafeAccum f !v s = Bundle.mapM_ upd s
!n = length v

unsafeUpdate :: (PrimMonad m, MVector v a)
=> v (PrimState m) a -> Bundle u (Int, a) -> m ()
=> v (PrimState m) a -> Bundle u (Int, a) -> m ()
{-# INLINE unsafeUpdate #-}
unsafeUpdate !v s = Bundle.mapM_ upd s
unsafeUpdate !v s = stToPrim $ Bundle.mapM_ upd s
where
{-# INLINE_INNER upd #-}
upd (i,b) = checkIndex Unsafe i n $ unsafeWrite v i b
!n = length v

reverse :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m ()
{-# INLINE reverse #-}
reverse !v = reverse_loop 0 (length v - 1)
reverse !v = stToPrim $ reverse_loop 0 (length v - 1)
where
reverse_loop i j | i < j = do
unsafeSwap v i j
Expand All @@ -1048,11 +1045,11 @@ reverse !v = reverse_loop 0 (length v - 1)
unstablePartition :: forall m v a. (PrimMonad m, MVector v a)
=> (a -> Bool) -> v (PrimState m) a -> m Int
{-# INLINE unstablePartition #-}
unstablePartition f !v = from_left 0 (length v)
unstablePartition f !v = stToPrim $ from_left 0 (length v)
where
-- NOTE: GHC 6.10.4 panics without the signatures on from_left and
-- from_right
from_left :: Int -> Int -> m Int
from_left :: Int -> Int -> ST (PrimState m) Int
from_left i j
| i == j = return i
| otherwise = do
Expand All @@ -1061,7 +1058,7 @@ unstablePartition f !v = from_left 0 (length v)
then from_left (i+1) j
else from_right i (j-1)

from_right :: Int -> Int -> m Int
from_right :: Int -> Int -> ST (PrimState m) Int
from_right i j
| i == j = return i
| otherwise = do
Expand All @@ -1078,7 +1075,8 @@ unstablePartitionBundle :: (PrimMonad m, MVector v a)
=> (a -> Bool) -> Bundle u a -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE unstablePartitionBundle #-}
unstablePartitionBundle f s
= case upperBound (Bundle.size s) of
= stToPrim
$ case upperBound (Bundle.size s) of
Just n -> unstablePartitionMax f s n
Nothing -> partitionUnknown f s

Expand All @@ -1087,7 +1085,7 @@ unstablePartitionMax :: (PrimMonad m, MVector v a)
-> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE unstablePartitionMax #-}
unstablePartitionMax f s n
= do
= stToPrim $ do
v <- checkLength Internal n $ unsafeNew n
let {-# INLINE_INNER put #-}
put (i, j) x
Expand All @@ -1105,15 +1103,15 @@ partitionBundle :: (PrimMonad m, MVector v a)
=> (a -> Bool) -> Bundle u a -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE partitionBundle #-}
partitionBundle f s
= case upperBound (Bundle.size s) of
= stToPrim
$ case upperBound (Bundle.size s) of
Just n -> partitionMax f s n
Nothing -> partitionUnknown f s

partitionMax :: (PrimMonad m, MVector v a)
=> (a -> Bool) -> Bundle u a -> Int -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE partitionMax #-}
partitionMax f s n
= do
partitionMax f s n = stToPrim $ do
v <- checkLength Internal n $ unsafeNew n

let {-# INLINE_INNER put #-}
Expand All @@ -1138,8 +1136,7 @@ partitionMax f s n
partitionUnknown :: (PrimMonad m, MVector v a)
=> (a -> Bool) -> Bundle u a -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE partitionUnknown #-}
partitionUnknown f s
= do
partitionUnknown f s = stToPrim $ do
v1 <- unsafeNew 0
v2 <- unsafeNew 0
(v1', n1, v2', n2) <- Bundle.foldM' put (v1, 0, v2, 0) s
Expand All @@ -1165,15 +1162,16 @@ partitionWithBundle :: (PrimMonad m, MVector v a, MVector v b, MVector v c)
=> (a -> Either b c) -> Bundle u a -> m (v (PrimState m) b, v (PrimState m) c)
{-# INLINE partitionWithBundle #-}
partitionWithBundle f s
= case upperBound (Bundle.size s) of
= stToPrim
$ case upperBound (Bundle.size s) of
Just n -> partitionWithMax f s n
Nothing -> partitionWithUnknown f s

partitionWithMax :: (PrimMonad m, MVector v a, MVector v b, MVector v c)
=> (a -> Either b c) -> Bundle u a -> Int -> m (v (PrimState m) b, v (PrimState m) c)
{-# INLINE partitionWithMax #-}
partitionWithMax f s n
= do
= stToPrim $ do
v1 <- unsafeNew n
v2 <- unsafeNew n
let {-# INLINE_INNER put #-}
Expand All @@ -1194,7 +1192,7 @@ partitionWithUnknown :: forall m v u a b c.
=> (a -> Either b c) -> Bundle u a -> m (v (PrimState m) b, v (PrimState m) c)
{-# INLINE partitionWithUnknown #-}
partitionWithUnknown f s
= do
= stToPrim $ do
v1 <- unsafeNew 0
v2 <- unsafeNew 0
(v1', n1, v2', n2) <- Bundle.foldM' put (v1, 0, v2, 0) s
Expand All @@ -1204,14 +1202,14 @@ partitionWithUnknown f s
where
put :: (v (PrimState m) b, Int, v (PrimState m) c, Int)
-> a
-> m (v (PrimState m) b, Int, v (PrimState m) c, Int)
-> ST (PrimState m) (v (PrimState m) b, Int, v (PrimState m) c, Int)
{-# INLINE_INNER put #-}
put (v1, i1, v2, i2) x = case f x of
Left b -> do
v1' <- unsafeAppend1 v1 i1 b
v1' <- stToPrim $ unsafeAppend1 v1 i1 b
return (v1', i1+1, v2, i2)
Right c -> do
v2' <- unsafeAppend1 v2 i2 c
v2' <- stToPrim $ unsafeAppend1 v2 i2 c
return (v1, i1, v2', i2+1)

-- Modifying vectors
Expand Down
Loading