Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert folds to take two arguments #345

Merged
merged 1 commit into from
May 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .hlint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@
within:
- Data.ByteString.Builder.Internal
- Data.ByteString.Builder.Prim
- ignore:
name: Reduce duplication
within:
- Data.ByteString
- ignore:
name: Redundant lambda
within:
- Data.ByteString.Builder.Internal
- Data.ByteString
159 changes: 93 additions & 66 deletions Data/ByteString.hs
Original file line number Diff line number Diff line change
Expand Up @@ -482,8 +482,8 @@ transpose = P.map pack . List.transpose . P.map unpack
-- ByteString using the binary operator, from left to right.
--
foldl :: (a -> Word8 -> a) -> a -> ByteString -> a
foldl f z (BS fp len) = go (end `plusPtr` len)
where
foldl f z = \(BS fp len) ->
let
end = unsafeForeignPtrToPtr fp `plusPtr` (-1)
-- not tail recursive; traverses array right to left
go !p | p == end = z
Expand All @@ -492,29 +492,44 @@ foldl f z (BS fp len) = go (end `plusPtr` len)
touchForeignPtr fp
return x'
in f (go (p `plusPtr` (-1))) x

in
go (end `plusPtr` len)
{-# INLINE foldl #-}

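{-
Note [fold inlining]:

GHC will only inline a function marked INLINE
if it is fully saturated (meaning the number of
arguments provided at the call site is at least
equal to the number of lhs arguments).

The folds in this module therefore bind only their
first two arguments on the lhs and take the ByteString
through a lambda, so that a partial application such as
`foldl f z` already counts as saturated and can be inlined.

-}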
-- | 'foldl'' is like 'foldl', but strict in the accumulator.
--
foldl' :: (a -> Word8 -> a) -> a -> ByteString -> a
foldl' f v (BS fp len) =
accursedUnutterablePerformIO $ unsafeWithForeignPtr fp g
where
foldl' f v = \(BS fp len) ->
-- see Note [fold inlining]
let
g ptr = go v ptr
where
end = ptr `plusPtr` len
-- tail recursive; traverses array left to right
go !z !p | p == end = return z
| otherwise = do x <- peek p
go (f z x) (p `plusPtr` 1)
in
accursedUnutterablePerformIO $ unsafeWithForeignPtr fp g
{-# INLINE foldl' #-}

-- | 'foldr', applied to a binary operator, a starting value
-- (typically the right-identity of the operator), and a ByteString,
-- reduces the ByteString using the binary operator, from right to left.
foldr :: (Word8 -> a -> a) -> a -> ByteString -> a
foldr k z (BS fp len) = go ptr
where
foldr k z = \(BS fp len) ->
-- see Note [fold inlining]
let
ptr = unsafeForeignPtrToPtr fp
end = ptr `plusPtr` len
-- not tail recursive; traverses array left to right
Expand All @@ -524,20 +539,25 @@ foldr k z (BS fp len) = go ptr
touchForeignPtr fp
return x'
in k x (go (p `plusPtr` 1))
in
go ptr
{-# INLINE foldr #-}

-- | 'foldr'' is like 'foldr', but strict in the accumulator.
foldr' :: (Word8 -> a -> a) -> a -> ByteString -> a
foldr' k v (BS fp len) =
accursedUnutterablePerformIO $ unsafeWithForeignPtr fp g
where
foldr' k v = \(BS fp len) ->
-- see Note [fold inlining]
let
g ptr = go v (end `plusPtr` len)
where
end = ptr `plusPtr` (-1)
-- tail recursive; traverses array right to left
go !z !p | p == end = return z
| otherwise = do x <- peek p
go (k x z) (p `plusPtr` (-1))
in
accursedUnutterablePerformIO $ unsafeWithForeignPtr fp g

{-# INLINE foldr' #-}

-- | 'foldl1' is a variant of 'foldl' that has no starting value
Expand Down Expand Up @@ -688,40 +708,42 @@ minimum xs@(BS x l)
-- passing an accumulating parameter from left to right, and returning a
-- final value of this accumulator together with the new ByteString.
mapAccumL :: (acc -> Word8 -> (acc, Word8)) -> acc -> ByteString -> (acc, ByteString)
mapAccumL f acc (BS fp len) = unsafeDupablePerformIO $ unsafeWithForeignPtr fp $ \a -> do
mapAccumL f acc = \(BS fp len) -> unsafeDupablePerformIO $ unsafeWithForeignPtr fp $ \a -> do
-- see Note [fold inlining]
gp <- mallocByteString len
let
go src dst = mapAccumL_ acc 0
where
mapAccumL_ !s !n
| n >= len = return s
| otherwise = do
x <- peekByteOff src n
let (s', y) = f s x
pokeByteOff dst n y
mapAccumL_ s' (n+1)
acc' <- unsafeWithForeignPtr gp (go a)
return (acc', BS gp len)
where
go src dst = mapAccumL_ acc 0
where
mapAccumL_ !s !n
| n >= len = return s
| otherwise = do
x <- peekByteOff src n
let (s', y) = f s x
pokeByteOff dst n y
mapAccumL_ s' (n+1)
{-# INLINE mapAccumL #-}

-- | The 'mapAccumR' function behaves like a combination of 'map' and
-- 'foldr'; it applies a function to each element of a ByteString,
-- passing an accumulating parameter from right to left, and returning a
-- final value of this accumulator together with the new ByteString.
mapAccumR :: (acc -> Word8 -> (acc, Word8)) -> acc -> ByteString -> (acc, ByteString)
mapAccumR f acc (BS fp len) = unsafeDupablePerformIO $ unsafeWithForeignPtr fp $ \a -> do
mapAccumR f acc = \(BS fp len) -> unsafeDupablePerformIO $ unsafeWithForeignPtr fp $ \a -> do
-- see Note [fold inlining]
gp <- mallocByteString len
let
go src dst = mapAccumR_ acc (len-1)
where
mapAccumR_ !s (-1) = return s
mapAccumR_ !s !n = do
x <- peekByteOff src n
let (s', y) = f s x
pokeByteOff dst n y
mapAccumR_ s' (n-1)
acc' <- unsafeWithForeignPtr gp (go a)
return (acc', BS gp len)
where
go src dst = mapAccumR_ acc (len-1)
where
mapAccumR_ !s (-1) = return s
mapAccumR_ !s !n = do
x <- peekByteOff src n
let (s', y) = f s x
pokeByteOff dst n y
mapAccumR_ s' (n-1)
{-# INLINE mapAccumR #-}

-- ---------------------------------------------------------------------
Expand All @@ -746,20 +768,21 @@ scanl
-- ^ input of length n
-> ByteString
-- ^ output of length n+1
scanl f v (BS fp len) = unsafeDupablePerformIO $ unsafeWithForeignPtr fp $ \a ->
scanl f v = \(BS fp len) -> unsafeDupablePerformIO $ unsafeWithForeignPtr fp $ \a ->
-- see Note [fold inlining]
create (len+1) $ \q -> do
poke q v
let
go src dst = scanl_ v 0
where
scanl_ !z !n
| n >= len = return ()
| otherwise = do
x <- peekByteOff src n
let z' = f z x
pokeByteOff dst n z'
scanl_ z' (n+1)
go a (q `plusPtr` 1)
where
go src dst = scanl_ v 0
where
scanl_ !z !n
| n >= len = return ()
| otherwise = do
x <- peekByteOff src n
let z' = f z x
pokeByteOff dst n z'
scanl_ z' (n+1)
{-# INLINE scanl #-}

-- n.b. haskell's List scan returns a list one bigger than the
Expand Down Expand Up @@ -795,20 +818,21 @@ scanr
-- ^ input of length n
-> ByteString
-- ^ output of length n+1
scanr f v (BS fp len) = unsafeDupablePerformIO $ unsafeWithForeignPtr fp $ \a ->
scanr f v = \(BS fp len) -> unsafeDupablePerformIO $ unsafeWithForeignPtr fp $ \a ->
-- see Note [fold inlining]
create (len+1) $ \q -> do
poke (q `plusPtr` len) v
let
go p q = scanr_ v (len-1)
where
scanr_ !z !n
| n < 0 = return ()
| otherwise = do
x <- peekByteOff p n
let z' = f x z
pokeByteOff q n z'
scanr_ z' (n-1)
go a q
where
go p q = scanr_ v (len-1)
where
scanr_ !z !n
| n < 0 = return ()
| otherwise = do
x <- peekByteOff p n
let z' = f x z
pokeByteOff q n z'
scanr_ z' (n-1)
{-# INLINE scanr #-}

-- | 'scanr1' is a variant of 'scanr' that has no starting value argument.
Expand Down Expand Up @@ -1412,21 +1436,24 @@ notElem c ps = not (c `elem` ps)
-- returns a ByteString containing those characters that satisfy the
-- predicate.
filter :: (Word8 -> Bool) -> ByteString -> ByteString
filter k ps@(BS x l)
| null ps = ps
| otherwise = unsafePerformIO $ createAndTrim l $ \p -> withForeignPtr x $ \f -> do
filter k = \ps@(BS x l) ->
-- see Note [fold inlining]
if null ps
then ps
else
unsafePerformIO $ createAndTrim l $ \p -> withForeignPtr x $ \f -> do
let
go' pf pt = go pf pt
where
end = pf `plusPtr` l
go !f !t | f == end = return t
| otherwise = do
w <- peek f
if k w
then poke t w >> go (f `plusPtr` 1) (t `plusPtr` 1)
else go (f `plusPtr` 1) t
t <- go' f p
return $! t `minusPtr` p -- actual length
where
go' pf pt = go pf pt
where
end = pf `plusPtr` l
go !f !t | f == end = return t
| otherwise = do
w <- peek f
if k w
then poke t w >> go (f `plusPtr` 1) (t `plusPtr` 1)
else go (f `plusPtr` 1) t
{-# INLINE filter #-}

{-
Expand Down