Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Straighten folds and scans. #364

Merged
merged 15 commits into from
Aug 19, 2021
62 changes: 56 additions & 6 deletions Data/ByteString/Lazy.hs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ module Data.ByteString.Lazy (
foldl1,
foldl1',
foldr,
foldr',
foldr1,
foldr1',

-- ** Special folds
concat,
Expand All @@ -106,9 +108,9 @@ module Data.ByteString.Lazy (
-- * Building ByteStrings
-- ** Scans
scanl,
-- scanl1,
-- scanr,
-- scanr1,
scanl1,
scanr,
scanr1,

-- ** Accumulating maps
mapAccumL,
Expand Down Expand Up @@ -460,6 +462,14 @@ foldr :: (Word8 -> a -> a) -> a -> ByteString -> a
foldr k = foldrChunks (flip (S.foldr k))
{-# INLINE foldr #-}

-- | 'foldr'' is like 'foldr', but strict in the accumulator.
--
-- The result of folding the remaining chunks is used as the starting
-- accumulator for the current chunk, which is reduced with 'S.foldr''.
foldr' :: (Word8 -> a -> a) -> a -> ByteString -> a
foldr' f a = go
  where
    -- Recurse through the local worker rather than through 'foldr'' itself:
    -- a self-recursive top-level binding acts as a loop breaker, so the
    -- INLINE pragma below would otherwise have no effect.
    go Empty = a
    go (Chunk c cs) = S.foldr' f (go cs) c
{-# INLINE foldr' #-}

-- | 'foldl1' is a variant of 'foldl' that has no starting value
-- argument, and thus must be applied to non-empty 'ByteString's.
foldl1 :: (Word8 -> Word8 -> Word8) -> ByteString -> Word8
Expand All @@ -479,6 +489,13 @@ foldr1 f (Chunk c0 cs0) = go c0 cs0
where go c Empty = S.foldr1 f c
go c (Chunk c' cs) = S.foldr f (go c' cs) c

-- | 'foldr1'' is like 'foldr1', but strict in the accumulator.
-- Calling it on an empty 'ByteString' is an error.
foldr1' :: (Word8 -> Word8 -> Word8) -> ByteString -> Word8
foldr1' f byteStream = case byteStream of
  Empty -> errorEmptyList "foldr1'"
  Chunk firstChunk otherChunks -> reduce firstChunk otherChunks
  where
    -- The last chunk supplies the seed via 'S.foldr1''; every earlier chunk
    -- is folded on top of the result obtained from the chunks after it.
    reduce chunk Empty = S.foldr1' f chunk
    reduce chunk (Chunk next rest) = S.foldr' f (reduce next rest) chunk

-- ---------------------------------------------------------------------
-- Special folds

Expand Down Expand Up @@ -617,11 +634,44 @@ scanl
-- ^ input of length n
-> ByteString
-- ^ output of length n+1
scanl f z = snd . foldl k (z,singleton z)
where
k (c,acc) a = let n = f c a in (n, acc `snoc` n)
scanl function = fmap (uncurry (flip snoc)) . mapAccumL (\x y -> (function x y, x))
{-# INLINE scanl #-}

-- | 'scanl1' is 'scanl' seeded with the first byte of the input instead of
-- an explicit starting value. An empty input yields an empty result.
--
-- > scanl1 f [x1, x2, ...] == [x1, x1 `f` x2, ...]
scanl1 :: (Word8 -> Word8 -> Word8) -> ByteString -> ByteString
scanl1 f bytes = maybe Empty (uncurry (scanl f)) (uncons bytes)

-- | 'scanr' is similar to 'foldr', but collects the successive reduced
-- values from the right into a 'ByteString'.
--
-- > scanr f z [..., x{n-1}, xn] == [..., x{n-1} `f` (xn `f` z), xn `f` z, z]
--
-- Note that
--
-- > head (scanr f z xs) == foldr f z xs
-- > last (scanr f z xs) == z
--
scanr
    :: (Word8 -> Word8 -> Word8)
    -- ^ element -> accumulator -> new accumulator
    -> Word8
    -- ^ starting value of accumulator
    -> ByteString
    -- ^ input of length n
    -> ByteString
    -- ^ output of length n+1
scanr f z bytes = uncurry cons (mapAccumR step z bytes)
  where
    -- Emit the accumulator as it was /before/ the current byte is folded in;
    -- the final accumulator returned by 'mapAccumR' is consed onto the front.
    step acc byte = (f byte acc, acc)

-- | 'scanr1' is 'scanr' seeded with the last byte of the input instead of
-- an explicit starting value. An empty input yields an empty result.
scanr1 :: (Word8 -> Word8 -> Word8) -> ByteString -> ByteString
scanr1 f bytes = maybe Empty fromRight (unsnoc bytes)
  where
    fromRight (front, final) = scanr f final front

-- ---------------------------------------------------------------------
-- Unfolds and replicates

Expand Down
52 changes: 48 additions & 4 deletions Data/ByteString/Lazy/Char8.hs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ module Data.ByteString.Lazy.Char8 (
foldl1,
foldl1',
foldr,
foldr',
foldr1,
foldr1',

-- ** Special folds
concat,
Expand All @@ -84,9 +86,9 @@ module Data.ByteString.Lazy.Char8 (
-- * Building ByteStrings
-- ** Scans
scanl,
-- scanl1,
-- scanr,
-- scanr1,
scanl1,
scanr,
scanr1,

-- ** Accumulating maps
mapAccumL,
Expand Down Expand Up @@ -238,7 +240,7 @@ import Foreign.Storable (peek)
import Prelude hiding
(reverse,head,tail,last,init,null,length,map,lines,foldl,foldr,unlines
,concat,any,take,drop,splitAt,takeWhile,dropWhile,span,break,elem,filter
,unwords,words,maximum,minimum,all,concatMap,scanl,scanl1,foldl1,foldr1
,unwords,words,maximum,minimum,all,concatMap,scanl,scanl1,scanr,scanr1,foldl1,foldr1
,readFile,writeFile,appendFile,replicate,getContents,getLine,putStr,putStrLn
,zip,zipWith,unzip,notElem,repeat,iterate,interact,cycle)

Expand Down Expand Up @@ -347,6 +349,10 @@ foldr :: (Char -> a -> a) -> a -> ByteString -> a
foldr f = L.foldr (f . w2c)
{-# INLINE foldr #-}

-- | 'foldr'' is like 'foldr', but strict in the accumulator.
--
-- Each byte is converted with 'w2c' before it is handed to the supplied
-- function; the fold itself is delegated to 'L.foldr''.
foldr' :: (Char -> a -> a) -> a -> ByteString -> a
foldr' f = L.foldr' (f . w2c)
{-# INLINE foldr' #-}

-- | 'foldl1' is a variant of 'foldl' that has no starting value
-- argument, and thus must be applied to non-empty 'ByteString's.
foldl1 :: (Char -> Char -> Char) -> ByteString -> Char
Expand All @@ -363,6 +369,10 @@ foldr1 :: (Char -> Char -> Char) -> ByteString -> Char
foldr1 f ps = w2c (L.foldr1 (\x y -> c2w (f (w2c x) (w2c y))) ps)
{-# INLINE foldr1 #-}

-- | 'foldr1'' is like 'foldr1', but strict in the accumulator.
--
-- The supplied function works on 'Char's; its arguments are converted with
-- 'w2c' and its result is converted back with 'c2w' so that the fold can be
-- delegated to 'L.foldr1''.
foldr1' :: (Char -> Char -> Char) -> ByteString -> Char
foldr1' f ps = w2c (L.foldr1' (\x y -> c2w (f (w2c x) (w2c y))) ps)
{-# INLINE foldr1' #-}

-- | Map a function over a 'ByteString' and concatenate the results
concatMap :: (Char -> ByteString) -> ByteString -> ByteString
concatMap f = L.concatMap (f . w2c)
Expand Down Expand Up @@ -404,6 +414,40 @@ minimum = w2c . L.minimum
scanl :: (Char -> Char -> Char) -> Char -> ByteString -> ByteString
scanl f z = L.scanl (\a b -> c2w (f (w2c a) (w2c b))) (c2w z)

-- | 'scanl1' is 'scanl' without an explicit starting value.
--
-- > scanl1 f [x1, x2, ...] == [x1, x1 `f` x2, ...]
scanl1 :: (Char -> Char -> Char) -> ByteString -> ByteString
scanl1 f = L.scanl1 (\acc byte -> c2w (f (w2c acc) (w2c byte)))

-- | 'scanr' is similar to 'foldr', but collects the successive reduced
-- values from the right into a 'ByteString'.
--
-- > scanr f z [..., x{n-1}, xn] == [..., x{n-1} `f` (xn `f` z), xn `f` z, z]
--
-- Note that
--
-- > head (scanr f z xs) == foldr f z xs
-- > last (scanr f z xs) == z
--
scanr
    :: (Char -> Char -> Char)
    -- ^ element -> accumulator -> new accumulator
    -> Char
    -- ^ starting value of accumulator
    -> ByteString
    -- ^ input of length n
    -> ByteString
    -- ^ output of length n+1
scanr f z = L.scanr (\byte acc -> c2w (f (w2c byte) (w2c acc))) (c2w z)

-- | 'scanr1' is 'scanr' without an explicit starting value.
scanr1 :: (Char -> Char -> Char) -> ByteString -> ByteString
scanr1 f = L.scanr1 (\byte acc -> c2w (f (w2c byte) (w2c acc)))

-- | The 'mapAccumL' function behaves like a combination of 'map' and
-- 'foldl'; it applies a function to each element of a ByteString,
-- passing an accumulating parameter from left to right, and returning a
Expand Down
56 changes: 40 additions & 16 deletions bench/BenchAll.hs
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,9 @@ sortInputs = map (`S.take` S.pack [122, 121 .. 32]) [10..25]
-- | Strict benchmark inputs, one per exponent @k@ in @[0..16]@: the first
-- @2^k@ bytes of @[32..95]@ while @k <= 6@, and @2^(k-6)@ repetitions of the
-- whole range afterwards.
foldInputs :: [S.ByteString]
foldInputs = map (S.pack . bytesFor) [0..16]
  where
    bytesFor k
      | k <= 6 = take (2 ^ k) [32..95]
      | otherwise = concat (replicate (2 ^ (k - 6)) [32..95])

-- | Lazy benchmark inputs, one per exponent @k@ in @[0..16]@: the first
-- @2^k@ bytes of @[32..95]@ while @k <= 6@, and @2^(k-6)@ repetitions of the
-- whole range afterwards. Packed with 'L.pack'.
foldInputsLazy :: [L.ByteString]
foldInputsLazy = map (L.pack . bytesFor) [0..16]
  where
    bytesFor k
      | k <= 6 = take (2 ^ k) [32..95]
      | otherwise = concat (replicate (2 ^ (k - 6)) [32..95])

-- | A lazy 'L.ByteString' built by replicating the byte 0 10000 times.
zeroes :: L.ByteString
zeroes = L.replicate 10000 0

Expand Down Expand Up @@ -401,22 +404,43 @@ main = do
, bench "one huge word" $ nf S8.words byteStringData
]
, bgroup "folds"
[ bgroup "foldl'" $ map (\s -> bench (show $ S.length s) $
nf (S.foldl' (\acc x -> acc + fromIntegral x) (0 :: Int)) s) foldInputs
, bgroup "foldr'" $ map (\s -> bench (show $ S.length s) $
nf (S.foldr' (\x acc -> fromIntegral x + acc) (0 :: Int)) s) foldInputs
, bgroup "unfoldrN" $ map (\s -> bench (show $ S.length s) $
nf (S.unfoldrN (S.length s) (\a -> Just (a, a + 1))) 0) foldInputs
, bgroup "mapAccumL" $ map (\s -> bench (show $ S.length s) $
nf (S.mapAccumL (\acc x -> (acc + fromIntegral x, succ x)) (0 :: Int)) s) foldInputs
, bgroup "mapAccumR" $ map (\s -> bench (show $ S.length s) $
nf (S.mapAccumR (\acc x -> (fromIntegral x + acc, succ x)) (0 :: Int)) s) foldInputs
, bgroup "scanl" $ map (\s -> bench (show $ S.length s) $
nf (S.scanl (+) 0) s) foldInputs
, bgroup "scanr" $ map (\s -> bench (show $ S.length s) $
nf (S.scanr (+) 0) s) foldInputs
, bgroup "filter" $ map (\s -> bench (show $ S.length s) $
nf (S.filter odd) s) foldInputs
[ bgroup "strict"
[ bgroup "foldl'" $ map (\s -> bench (show $ S.length s) $
nf (S.foldl' (\acc x -> acc + fromIntegral x) (0 :: Int)) s) foldInputs
, bgroup "foldr'" $ map (\s -> bench (show $ S.length s) $
nf (S.foldr' (\x acc -> fromIntegral x + acc) (0 :: Int)) s) foldInputs
, bgroup "foldr1'" $ map (\s -> bench (show $ S.length s) $
nf (S.foldr1' (\x acc -> fromIntegral x + acc)) s) foldInputs
, bgroup "unfoldrN" $ map (\s -> bench (show $ S.length s) $
nf (S.unfoldrN (S.length s) (\a -> Just (a, a + 1))) 0) foldInputs
, bgroup "mapAccumL" $ map (\s -> bench (show $ S.length s) $
nf (S.mapAccumL (\acc x -> (acc + fromIntegral x, succ x)) (0 :: Int)) s) foldInputs
, bgroup "mapAccumR" $ map (\s -> bench (show $ S.length s) $
nf (S.mapAccumR (\acc x -> (fromIntegral x + acc, succ x)) (0 :: Int)) s) foldInputs
, bgroup "scanl" $ map (\s -> bench (show $ S.length s) $
nf (S.scanl (+) 0) s) foldInputs
, bgroup "scanr" $ map (\s -> bench (show $ S.length s) $
nf (S.scanr (+) 0) s) foldInputs
, bgroup "filter" $ map (\s -> bench (show $ S.length s) $
nf (S.filter odd) s) foldInputs
]
, bgroup "lazy"
[ bgroup "foldl'" $ map (\s -> bench (show $ L.length s) $
nf (L.foldl' (\acc x -> acc + fromIntegral x) (0 :: Int)) s) foldInputsLazy
, bgroup "foldr'" $ map (\s -> bench (show $ L.length s) $
nf (L.foldr' (\x acc -> fromIntegral x + acc) (0 :: Int)) s) foldInputsLazy
, bgroup "foldr1'" $ map (\s -> bench (show $ L.length s) $
nf (L.foldr1' (\x acc -> fromIntegral x + acc)) s) foldInputsLazy
, bgroup "mapAccumL" $ map (\s -> bench (show $ L.length s) $
nf (L.mapAccumL (\acc x -> (acc + fromIntegral x, succ x)) (0 :: Int)) s) foldInputsLazy
, bgroup "mapAccumR" $ map (\s -> bench (show $ L.length s) $
nf (L.mapAccumR (\acc x -> (fromIntegral x + acc, succ x)) (0 :: Int)) s) foldInputsLazy
, bgroup "scanl" $ map (\s -> bench (show $ L.length s) $
nf (L.scanl (+) 0) s) foldInputsLazy
, bgroup "scanr" $ map (\s -> bench (show $ L.length s) $
nf (L.scanr (+) 0) s) foldInputsLazy
]

]
, bgroup "findIndexOrLength"
[ bench "takeWhile" $ nf (L.takeWhile even) zeroes
Expand Down
40 changes: 39 additions & 1 deletion tests/Properties.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ import Control.Concurrent
import Control.Exception
import System.Posix.Internals (c_unlink)

import qualified Data.List as List
import Data.Char
import Data.Word
import Data.Maybe
import Data.Int (Int64)
import Data.Monoid
import Data.Semigroup
import GHC.Exts (Int(..), newPinnedByteArray#, unsafeFreezeByteArray#)
import GHC.ST (ST(..), runST)
Expand Down Expand Up @@ -463,6 +463,12 @@ short_tests =
, testProperty "pinned" prop_short_pinned
]

------------------------------------------------------------------------
-- Strictness checks.

-- | Append an erroring tail to a lazy byte string, so that any consumer
-- that forces the data past the given prefix raises the error below.
explosiveTail :: L.ByteString -> L.ByteString
explosiveTail prefix = L.append prefix (error "Tail of this byte string is undefined!")

------------------------------------------------------------------------
-- The entry point

Expand All @@ -475,6 +481,7 @@ main = defaultMain $ testGroup "All"
, testGroup "Misc" misc_tests
, testGroup "IO" io_tests
, testGroup "Short" short_tests
, testGroup "Strictness" strictness_checks
]

io_tests =
Expand Down Expand Up @@ -535,5 +542,36 @@ misc_tests =
, testProperty "readIntegerUnsafe" prop_readIntegerUnsafe
]

strictness_checks =
[ testGroup "Lazy Word8"
[ testProperty "foldr is lazy" $ \ xs ->
List.genericTake (L.length xs) (L.foldr (:) [ ] (explosiveTail xs)) === L.unpack xs
, testProperty "foldr' is strict" $ expectFailure $ \ xs ys ->
List.genericTake (L.length xs) (L.foldr' (:) [ ] (explosiveTail (xs <> ys))) === L.unpack xs
, testProperty "foldr1 is lazy" $ \ xs -> L.length xs > 0 ==>
L.foldr1 const (explosiveTail (xs <> L.singleton 1)) === L.head xs
, testProperty "foldr1' is strict" $ expectFailure $ \ xs ys -> L.length xs > 0 ==>
L.foldr1' const (explosiveTail (xs <> L.singleton 1 <> ys)) === L.head xs
, testProperty "scanl is lazy" $ \ xs ->
L.take (L.length xs + 1) (L.scanl (+) 0 (explosiveTail (xs <> L.singleton 1))) === (L.pack . fmap (L.foldr (+) 0) . L.inits) xs
, testProperty "scanl1 is lazy" $ \ xs -> L.length xs > 0 ==>
L.take (L.length xs) (L.scanl1 (+) (explosiveTail (xs <> L.singleton 1))) === (L.pack . fmap (L.foldr1 (+)) . tail . L.inits) xs
]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wondering why scanr is less lazy than scanl. The thing is that its output starts from an accumulator, and Data.ByteString.mapAccumR is too strict in this respect.

go src dst = mapAccumR_ acc (len-1)
where
mapAccumR_ !s (-1) = return s
mapAccumR_ !s !n = do
x <- peekByteOff src n
let (s', y) = f s x
pokeByteOff dst n y
mapAccumR_ s' (n-1)
acc' <- unsafeWithForeignPtr gp (go a)
return (acc', BS gp len)

I think this is fine: there are no particular expectations about strictness of scanr (there is no scanr' in Prelude).

Copy link
Contributor Author

@kindaro kindaro Aug 19, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I follow. Is there a specific proposition you are reasoning towards? Or a question I may answer? For example:

Proposition: Data.ByteString.Lazy.scanr cannot be lazy.

Proof

As you noted, the output of a scanr starts from the end, so this is the sort of laziness we can have:

λ take 2 . reverse $ Prelude.scanr (+) 0 [undefined, 1]
[0,1]

So, first the spine of the input list is evaluated to the end, then elements are evaluated from the end backwards. (Whether the accumulator is evaluated before or after the first element depends on the order of evaluation of +.) Similarly, the byte stream's spine would have to be evaluated first. But the spine of the byte stream is strict in the leaf:

-- | A space-efficient representation of a 'Word8' vector, supporting many
-- efficient operations.
--
-- A lazy 'ByteString' contains 8-bit bytes, or by using the operations
-- from "Data.ByteString.Lazy.Char8" it can be interpreted as containing
-- 8-bit characters.
--
data ByteString = Empty | Chunk {-# UNPACK #-} !S.ByteString ByteString
deriving (Typeable, TH.Lift)
-- See 'invariant' function later in this module for internal invariants.

The leaf itself is a byte array and therefore also strict throughout. So, once we force the spine, every byte is also forced. There is no lazy scanr for byte streams. ∎

Something like this?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was just a remark for myself and @sjakobi and anyone else who is puzzled why we have laziness properties for scanl, but not for scanr.

It's not like you cannot make Data.ByteString.Lazy.scanr a bit lazier. E. g., for the proposed implementation

> Data.ByteString.Lazy.head $ Data.ByteString.Lazy.scanr const 42 ("foo" <> undefined)
*** Exception: Prelude.undefined
CallStack (from HasCallStack):
  error, called at libraries/base/GHC/Err.hs:79:14 in base:GHC.Err
  undefined, called at <interactive>:11:75 in interactive:Ghci1

However, if we are ready to sacrifice performance, one can define

scanr f z bs = cons hd tl
  where
    (_, tl) = mapAccumR (\x y -> (f y x, x)) z bs
    (hd, _) = List.mapAccumR (\x y -> (f y x, x)) z (unpack bs)

for which

> Data.ByteString.Lazy.head $ Data.ByteString.Lazy.scanr const 42 ("foo" <> undefined)
102

You can define an even lazier (and slower) version, capable of returning the first few chunks of the bytestring, as long as f is very lazy (e.g., f = const).

My point is that this is a rare use case, which does not justify performance sacrifices, especially given that there is no general expectation how lazy scanr should be. I'm fine with your implementation, no action required.

, testGroup "Lazy Char"
[ testProperty "foldr is lazy" $ \ xs ->
List.genericTake (D.length xs) (D.foldr (:) [ ] (explosiveTail xs)) === D.unpack xs
, testProperty "foldr' is strict" $ expectFailure $ \ xs ys ->
List.genericTake (D.length xs) (D.foldr' (:) [ ] (explosiveTail (xs <> ys))) === D.unpack xs
, testProperty "foldr1 is lazy" $ \ xs -> D.length xs > 0 ==>
D.foldr1 const (explosiveTail (xs <> D.singleton 'x')) === D.head xs
, testProperty "foldr1' is strict" $ expectFailure $ \ xs ys -> D.length xs > 0 ==>
D.foldr1' const (explosiveTail (xs <> D.singleton 'x' <> ys)) === D.head xs
, testProperty "scanl is lazy" $ \ xs -> let char1 +. char2 = toEnum (fromEnum char1 + fromEnum char2) in
D.take (D.length xs + 1) (D.scanl (+.) '\NUL' (explosiveTail (xs <> D.singleton '\SOH'))) === (D.pack . fmap (D.foldr (+.) '\NUL') . D.inits) xs
, testProperty "scanl1 is lazy" $ \ xs -> D.length xs > 0 ==> let char1 +. char2 = toEnum (fromEnum char1 + fromEnum char2) in
D.take (D.length xs) (D.scanl1 (+.) (explosiveTail (xs <> D.singleton '\SOH'))) === (D.pack . fmap (D.foldr1 (+.)) . tail . D.inits) xs
]
]

-- | Unlink the file at the given path via 'c_unlink', discarding the
-- value that 'c_unlink' returns.
removeFile :: String -> IO ()
removeFile path = do
  _ <- withCString path c_unlink
  return ()
7 changes: 2 additions & 5 deletions tests/Properties/ByteString.hs
Original file line number Diff line number Diff line change
Expand Up @@ -410,10 +410,8 @@ tests =
\f (toElem -> c) x -> B.foldl' ((toElem .) . f) c x === foldl' ((toElem .) . f) c (B.unpack x)
, testProperty "foldr" $
\f (toElem -> c) x -> B.foldr ((toElem .) . f) c x === foldr ((toElem .) . f) c (B.unpack x)
#ifndef BYTESTRING_LAZY
, testProperty "foldr'" $
\f (toElem -> c) x -> B.foldr' ((toElem .) . f) c x === foldr' ((toElem .) . f) c (B.unpack x)
#endif

, testProperty "foldl cons" $
\x -> B.foldl (flip B.cons) B.empty x === B.reverse x
Expand All @@ -432,10 +430,8 @@ tests =
\f x -> not (B.null x) ==> B.foldl1' ((toElem .) . f) x === List.foldl1' ((toElem .) . f) (B.unpack x)
, testProperty "foldr1" $
\f x -> not (B.null x) ==> B.foldr1 ((toElem .) . f) x === foldr1 ((toElem .) . f) (B.unpack x)
#ifndef BYTESTRING_LAZY
, testProperty "foldr1'" $ -- there is not Data.List.foldr1'
\f x -> not (B.null x) ==> B.foldr1' ((toElem .) . f) x === foldr1 ((toElem .) . f) (B.unpack x)
#endif

, testProperty "foldl1 const" $
\x -> not (B.null x) ==> B.foldl1 const x === B.head x
Expand All @@ -455,7 +451,6 @@ tests =
, testProperty "scanl foldl" $
\f (toElem -> c) x -> not (B.null x) ==> B.last (B.scanl ((toElem .) . f) c x) === B.foldl ((toElem .) . f) c x

#ifndef BYTESTRING_LAZY
, testProperty "scanr" $
\f (toElem -> c) x -> B.unpack (B.scanr ((toElem .) . f) c x) === scanr ((toElem .) . f) c (B.unpack x)
, testProperty "scanl1" $
Expand All @@ -466,6 +461,8 @@ tests =
\f x -> B.unpack (B.scanr1 ((toElem .) . f) x) === scanr1 ((toElem .) . f) (B.unpack x)
, testProperty "scanr1 empty" $
\f -> B.scanr1 f B.empty === B.empty

#ifndef BYTESTRING_LAZY
, testProperty "sort" $
\x -> B.unpack (B.sort x) === List.sort (B.unpack x)
#endif
Expand Down