From 07f5efbbff5420f0ee68286f50b8138b75016482 Mon Sep 17 00:00:00 2001 From: David Feuer Date: Sun, 18 Mar 2018 01:50:20 -0400 Subject: [PATCH] Fold better; fix foldl1 bug * Index `Array`s and `SmallArray`s eagerly. * Fix `foldl1` for `Array`, which was completely wrong. * Make the structure of the `Foldable` instances for `Array` and `SmallArray` the same, taking what I thought was the best from each. Fixes #86 --- Data/Primitive/Array.hs | 117 +++++++++++++++++++----------- Data/Primitive/SmallArray.hs | 133 ++++++++++++++++++++++++----------- 2 files changed, 170 insertions(+), 80 deletions(-) diff --git a/Data/Primitive/Array.hs b/Data/Primitive/Array.hs index 32f4c76b..cfea04d2 100644 --- a/Data/Primitive/Array.hs +++ b/Data/Primitive/Array.hs @@ -115,6 +115,13 @@ indexArray :: Array a -> Int -> a {-# INLINE indexArray #-} indexArray arr (I# i#) = case indexArray# (array# arr) i# of (# x #) -> x +-- | Read a value from the immutable array at the given index, returning +-- the result in an unboxed unary tuple. This is currently used to implement +-- folds. +indexArray## :: Array a -> Int -> (# a #) +indexArray## arr (I# i) = indexArray# (array# arr) i +{-# INLINE indexArray## #-} + -- | Monadically read a value from the immutable array at the given index. -- This allows us to be strict in the array while remaining lazy in the read -- element which is very useful for collective operations. Suppose we want to @@ -346,60 +353,90 @@ instance Ord a => Ord (Array a) where | otherwise = compare (sizeofArray a1) (sizeofArray a2) instance Foldable Array where - foldr f z a = go 0 - where go i | i < sizeofArray a = f (indexArray a i) (go $ i+1) - | otherwise = z + -- Note: we perform the array lookups eagerly so we won't + -- create thunks to perform lookups even if GHC can't see + -- that the folding function is strict. + foldr f = \z !ary -> + let + !sz = sizeofArray ary + go i + | i == sz = z + | (# x #) <- indexArray## ary i + = f x (go (i+1)) + in go 0 {-# INLINE foldr #-} - foldl f z a = go (sizeofArray a - 1) - where go i | i < 0 = z - | otherwise = f (go $ i-1) (indexArray a i) + foldl f = \z !ary -> + let + go i + | i < 0 = z + | (# x #) <- indexArray## ary i + = f (go (i-1)) x + in go (sizeofArray ary - 1) {-# INLINE foldl #-} - foldr1 f a | sz < 0 = die "foldr1" "empty array" - | otherwise = go 0 - where sz = sizeofArray a - 1 - z = indexArray a sz - go i | i < sz = f (indexArray a i) (go $ i+1) - | otherwise = z + foldr1 f = \ !ary -> + let + !sz = sizeofArray ary - 1 + go i = + case indexArray## ary i of + (# x #) | i == sz -> x + | otherwise -> f x (go (i+1)) + in if sz < 0 + then die "foldr1" "empty array" + else go 0 {-# INLINE foldr1 #-} - foldl1 f a | sz == 0 = die "foldl1" "empty array" - | otherwise = go $ sz-1 - where sz = sizeofArray a - z = indexArray a 0 - go i | i < 1 = f (go $ i-1) (indexArray a i) - | otherwise = z + foldl1 f = \ !ary -> + let + !sz = sizeofArray ary - 1 + go i = + case indexArray## ary i of + (# x #) | i == 0 -> x + | otherwise -> f x (go (i - 1)) + in if sz < 0 + then die "foldl1" "empty array" + else go sz {-# INLINE foldl1 #-} #if MIN_VERSION_base(4,6,0) - foldr' f z a = go (sizeofArray a - 1) z - where go i !acc | i < 0 = acc - | otherwise = go (i-1) (f (indexArray a i) acc) + foldr' f = \z !ary -> + let + go i !acc + | i == -1 = acc + | (# x #) <- indexArray## ary i + = go (i-1) (f x acc) + in go (sizeofArray ary - 1) z {-# INLINE foldr' #-} - foldl' f z a = go 0 z - where go i !acc | i < sizeofArray a = go (i+1) (f acc $ indexArray a i) - | otherwise = acc + foldl' f = \z !ary -> + let + !sz = sizeofArray ary + go i !acc + | i == sz = acc + | (# x #) <- indexArray## ary i + = go (i+1) (f acc x) + in go 0 z {-# INLINE foldl' #-} #endif #if MIN_VERSION_base(4,8,0) - toList a = Exts.build $ \c z -> let - sz = sizeofArray a - go i | i < sz = c (indexArray a i) (go $ i+1) - | otherwise = z - in go 0 - {-# INLINE toList #-} null a = sizeofArray a == 0 {-# INLINE null #-} length = sizeofArray {-# INLINE length #-} - maximum a | sz == 0 = die "maximum" "empty array" - | otherwise = go 1 (indexArray a 0) - where sz = sizeofArray a - go i !e | i < sz = go (i+1) (max e $ indexArray a i) - | otherwise = e + maximum ary | sz == 0 = die "maximum" "empty array" + | (# frst #) <- indexArray## ary 0 + = go 1 frst + where + sz = sizeofArray ary + go i !e + | i == sz = e + | (# x #) <- indexArray## ary i + = go (i+1) (max e x) {-# INLINE maximum #-} - minimum a | sz == 0 = die "minimum" "empty array" - | otherwise = go 1 (indexArray a 0) - where sz = sizeofArray a - go i !e | i < sz = go (i+1) (min e $ indexArray a i) - | otherwise = e + minimum ary | sz == 0 = die "minimum" "empty array" + | (# frst #) <- indexArray## ary 0 + = go 1 frst + where sz = sizeofArray ary + go i !e + | i == sz = e + | (# x #) <- indexArray## ary i + = go (i+1) (min e x) {-# INLINE minimum #-} sum = foldl' (+) 0 {-# INLINE sum #-} diff --git a/Data/Primitive/SmallArray.hs b/Data/Primitive/SmallArray.hs index aa56de8e..6a9990e7 100644 --- a/Data/Primitive/SmallArray.hs +++ b/Data/Primitive/SmallArray.hs @@ -6,6 +6,7 @@ {-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE BangPatterns #-} -- | -- Module : Data.Primitive.SmallArray @@ -225,6 +226,15 @@ indexSmallArray (SmallArray a) = indexArray a #endif {-# INLINE indexSmallArray #-} +#if HAVE_SMALL_ARRAY +-- | Read a value from the immutable array at the given index, returning +-- the result in an unboxed unary tuple. This is currently used to implement +-- folds. +indexSmallArray## :: SmallArray a -> Int -> (# a #) +indexSmallArray## (SmallArray ary) (I# i) = indexSmallArray# ary i +{-# INLINE indexSmallArray## #-} +#endif + -- | Create a copy of a slice of an immutable array. cloneSmallArray :: SmallArray a -- ^ source @@ -424,53 +434,96 @@ instance Ord a => Ord (SmallArray a) where where l = length sl `min` length sr instance Foldable SmallArray where - foldr f z sa = fix ? 0 $ \go i -> - if i < length sa - then f (indexSmallArray sa i) (go $ i+1) - else z + -- Note: we perform the array lookups eagerly so we won't + -- create thunks to perform lookups even if GHC can't see + -- that the folding function is strict. + foldr f = \z !ary -> + let + !sz = sizeofSmallArray ary + go i + | i == sz = z + | (# x #) <- indexSmallArray## ary i + = f x (go (i+1)) + in go 0 {-# INLINE foldr #-} - - foldr' f z sa = fix ? z ? length sa - 1 $ \go acc i -> - if i < 0 - then acc - else go (f (indexSmallArray sa i) acc) (i-1) - {-# INLINE foldr' #-} - - foldl f z sa = fix ? length sa - 1 $ \go i -> - if i < 0 - then z - else f (go $ i-1) $ indexSmallArray sa i + foldl f = \z !ary -> + let + go i + | i < 0 = z + | (# x #) <- indexSmallArray## ary i + = f (go (i-1)) x + in go (sizeofSmallArray ary - 1) {-# INLINE foldl #-} - - foldl' f z sa = fix ? z ? 0 $ \go acc i -> - if i < length sa - then go (f acc $ indexSmallArray sa i) (i+1) - else acc - {-# INLINE foldl' #-} - - foldr1 f sa - | sz == 0 = die "foldr1" "empty list" - | otherwise = fix ? 0 $ \go i -> - if i < sz-1 - then f (indexSmallArray sa i) (go $ i+1) - else indexSmallArray sa $ sz-1 - where sz = sizeofSmallArray sa + foldr1 f = \ !ary -> + let + !sz = sizeofSmallArray ary - 1 + go i = + case indexSmallArray## ary i of + (# x #) | i == sz -> x + | otherwise -> f x (go (i+1)) + in if sz < 0 + then die "foldr1" "Empty SmallArray" + else go 0 {-# INLINE foldr1 #-} - - foldl1 f sa - | sz == 0 = die "foldl1" "empty list" - | otherwise = fix ? sz-1 $ \go i -> - if i < 1 - then indexSmallArray sa 0 - else f (go $ i-1) (indexSmallArray sa i) - where sz = sizeofSmallArray sa + foldl1 f = \ !ary -> + let + !sz = sizeofSmallArray ary - 1 + go i = + case indexSmallArray## ary i of + (# x #) | i == 0 -> x + | otherwise -> f x (go (i - 1)) + in if sz < 0 + then die "foldl1" "Empty SmallArray" + else go sz {-# INLINE foldl1 #-} - - null sa = sizeofSmallArray sa == 0 +#if MIN_VERSION_base(4,6,0) + foldr' f = \z !ary -> + let + go i !acc + | i == -1 = acc + | (# x #) <- indexSmallArray## ary i + = go (i-1) (f x acc) + in go (sizeofSmallArray ary - 1) z + {-# INLINE foldr' #-} + foldl' f = \z !ary -> + let + !sz = sizeofSmallArray ary + go i !acc + | i == sz = acc + | (# x #) <- indexSmallArray## ary i + = go (i+1) (f acc x) + in go 0 z + {-# INLINE foldl' #-} +#endif +#if MIN_VERSION_base(4,8,0) + null a = sizeofSmallArray a == 0 {-# INLINE null #-} - length = sizeofSmallArray {-# INLINE length #-} + maximum ary | sz == 0 = die "maximum" "Empty SmallArray" + | (# frst #) <- indexSmallArray## ary 0 + = go 1 frst + where + sz = sizeofSmallArray ary + go i !e + | i == sz = e + | (# x #) <- indexSmallArray## ary i + = go (i+1) (max e x) + {-# INLINE maximum #-} + minimum ary | sz == 0 = die "minimum" "Empty SmallArray" + | (# frst #) <- indexSmallArray## ary 0 + = go 1 frst + where sz = sizeofSmallArray ary + go i !e + | i == sz = e + | (# x #) <- indexSmallArray## ary i + = go (i+1) (min e x) + {-# INLINE minimum #-} + sum = foldl' (+) 0 + {-# INLINE sum #-} + product = foldl' (*) 1 + {-# INLINE product #-} +#endif instance Traversable SmallArray where traverse f sa = fromListN l <$> traverse (f . indexSmallArray sa) [0..l-1]