Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize findIndex, findIndexEnd and map #347

Merged
merged 1 commit into from
Jan 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 27 additions & 21 deletions Data/ByteString.hs
Original file line number Diff line number Diff line change
Expand Up @@ -442,16 +442,18 @@ append = mappend
-- | /O(n)/ 'map' @f xs@ is the ByteString obtained by applying @f@ to each
-- element of @xs@.
map :: (Word8 -> Word8) -> ByteString -> ByteString
map f (BS fp len) = unsafeDupablePerformIO $ unsafeWithForeignPtr fp $ \a ->
create len $ map_ 0 a
map f (BS fp len) = unsafeDupablePerformIO $ unsafeWithForeignPtr fp $ \srcPtr ->
create len $ \dstPtr -> m srcPtr dstPtr
where
map_ :: Int -> Ptr Word8 -> Ptr Word8 -> IO ()
map_ !n !p1 !p2
| n >= len = return ()
| otherwise = do
x <- peekByteOff p1 n
pokeByteOff p2 n (f x)
map_ (n+1) p1 p2
m !p1 !p2 = map_ 0
where
map_ :: Int -> IO ()
map_ !n
| n >= len = return ()
| otherwise = do
x <- peekByteOff p1 n
pokeByteOff p2 n (f x)
map_ (n+1)
Bodigrim marked this conversation as resolved.
Show resolved Hide resolved
{-# INLINE map #-}
Boarders marked this conversation as resolved.
Show resolved Hide resolved

-- | /O(n)/ 'reverse' @xs@ efficiently returns the elements of @xs@ in reverse order.
Expand Down Expand Up @@ -1342,13 +1344,15 @@ count w (BS x m) = accursedUnutterablePerformIO $ unsafeWithForeignPtr x $ \p ->
-- returns the index of the first element in the ByteString
-- satisfying the predicate.
findIndex :: (Word8 -> Bool) -> ByteString -> Maybe Int
findIndex k (BS x l) = accursedUnutterablePerformIO $ withForeignPtr x $ \f -> go f 0
findIndex k (BS x l) = accursedUnutterablePerformIO $ withForeignPtr x g
where
go !ptr !n | n >= l = return Nothing
| otherwise = do w <- peek ptr
if k w
then return (Just n)
else go (ptr `plusPtr` 1) (n+1)
g !ptr = go 0
where
go !n | n >= l = return Nothing
| otherwise = do w <- peek $ ptr `plusPtr` n
if k w
then return (Just n)
else go (n+1)
{-# INLINE [1] findIndex #-}

-- | /O(n)/ The 'findIndexEnd' function takes a predicate and a 'ByteString' and
Expand All @@ -1357,13 +1361,15 @@ findIndex k (BS x l) = accursedUnutterablePerformIO $ withForeignPtr x $ \f -> g
--
-- @since 0.10.12.0
findIndexEnd :: (Word8 -> Bool) -> ByteString -> Maybe Int
findIndexEnd k (BS x l) = accursedUnutterablePerformIO $ withForeignPtr x $ \ f -> go f (l-1)
findIndexEnd k (BS x l) = accursedUnutterablePerformIO $ withForeignPtr x g
where
go !ptr !n | n < 0 = return Nothing
| otherwise = do w <- peekByteOff ptr n
if k w
then return (Just n)
else go ptr (n-1)
g !ptr = go (l-1)
where
go !n | n < 0 = return Nothing
| otherwise = do w <- peekByteOff ptr n
if k w
then return (Just n)
else go (n-1)
{-# INLINE findIndexEnd #-}

-- | /O(n)/ The 'findIndices' function extends 'findIndex', by returning the
Expand Down
21 changes: 18 additions & 3 deletions bench/BenchAll.hs
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,13 @@ zeroes = L.replicate 10000 0
zeroOneRepeating :: L.ByteString
zeroOneRepeating = L.take 10000 (L.cycle (L.pack [0,1]))


largeTraversalInput :: S.ByteString
largeTraversalInput = S.concat (replicate 10 byteStringData)

smallTraversalInput :: S.ByteString
smallTraversalInput = S8.pack "The quick brown fox"

main :: IO ()
main = do
mapM_ putStrLn sanityCheckInfo
Expand Down Expand Up @@ -424,8 +431,16 @@ main = do
, bench "groupBy (>=)" $ nf (L.groupBy (>=)) zeroes
, bench "groupBy (>)" $ nf (L.groupBy (>)) zeroes
]
, bgroup "findIndex"
[ bench "findIndices" $ nf (sum . S.findIndices even) byteStringData
, bench "find" $ nf (S.find (>= 9998)) byteStringData
, bgroup "findIndex_"
[ bench "findIndices" $ nf (sum . S.findIndices (\x -> x == 129 || x == 72)) byteStringData
, bench "find" $ nf (S.find (>= 198)) byteStringData
Bodigrim marked this conversation as resolved.
Show resolved Hide resolved
]
, bgroup "findIndexEnd"
[ bench "findIndexEnd" $ nf (S.findIndexEnd (<= 57)) byteStringData
, bench "elemIndexInd" $ nf (S.elemIndexEnd 42) byteStringData
]
, bgroup "traversals"
[ bench "map (+1)" $ nf (S.map (+ 1)) largeTraversalInput
, bench "map (+1)" $ nf (S.map (+ 1)) smallTraversalInput
]
]