Skip to content

Commit

Permalink
Force the components of returned pairs
Browse files Browse the repository at this point in the history
Some functions, like partition, return a pair of values. Before this
change these functions would do almost no work and return immediately,
due to suspending most of the work in closures. This could cause space
leaks.

Closes #14.
  • Loading branch information
tibbe committed Aug 25, 2012
1 parent 0f616f4 commit 8a661a5
Show file tree
Hide file tree
Showing 7 changed files with 290 additions and 203 deletions.
92 changes: 56 additions & 36 deletions Data/IntMap/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,8 @@ import Control.Applicative (Applicative(pure,(<*>)),(<$>))
import Control.Monad ( liftM )
import Control.DeepSeq (NFData(rnf))

import Data.StrictPair

#if __GLASGOW_HASKELL__
import Text.Read
import Data.Data (Data(..), mkNoRepType)
Expand Down Expand Up @@ -1402,16 +1404,18 @@ partition p m
-- > partitionWithKey (\ k _ -> k > 7) (fromList [(5,"a"), (3,"b")]) == (empty, fromList [(3, "b"), (5, "a")])

partitionWithKey :: (Key -> a -> Bool) -> IntMap a -> (IntMap a,IntMap a)
partitionWithKey predicate t
= case t of
Bin p m l r
-> let (l1,l2) = partitionWithKey predicate l
(r1,r2) = partitionWithKey predicate r
in (bin p m l1 r1, bin p m l2 r2)
Tip k x
| predicate k x -> (t,Nil)
| otherwise -> (Nil,t)
Nil -> (Nil,Nil)
partitionWithKey predicate0 t0 = toPair $ go predicate0 t0
where
go predicate t
= case t of
Bin p m l r
-> let (l1 :*: l2) = go predicate l
(r1 :*: r2) = go predicate r
in bin p m l1 r1 :*: bin p m l2 r2
Tip k x
| predicate k x -> (t :*: Nil)
| otherwise -> (Nil :*: t)
Nil -> (Nil :*: Nil)

-- | /O(n)/. Map values and collect the 'Just' results.
--
Expand Down Expand Up @@ -1457,15 +1461,17 @@ mapEither f m
-- > == (empty, fromList [(1,"x"), (3,"b"), (5,"a"), (7,"z")])

mapEitherWithKey :: (Key -> a -> Either b c) -> IntMap a -> (IntMap b, IntMap c)
mapEitherWithKey f (Bin p m l r)
= (bin p m l1 r1, bin p m l2 r2)
mapEitherWithKey f0 t0 = toPair $ go f0 t0
where
(l1,l2) = mapEitherWithKey f l
(r1,r2) = mapEitherWithKey f r
mapEitherWithKey f (Tip k x) = case f k x of
Left y -> (Tip k y, Nil)
Right z -> (Nil, Tip k z)
mapEitherWithKey _ Nil = (Nil, Nil)
go f (Bin p m l r)
= bin p m l1 r1 :*: bin p m l2 r2
where
(l1 :*: l2) = go f l
(r1 :*: r2) = go f r
go f (Tip k x) = case f k x of
Left y -> (Tip k y :*: Nil)
Right z -> (Nil :*: Tip k z)
go _ Nil = (Nil :*: Nil)

-- | /O(min(n,W))/. The expression (@'split' k map@) is a pair @(map1,map2)@
-- where all keys in @map1@ are lower than @k@ and all keys in
Expand All @@ -1479,18 +1485,23 @@ mapEitherWithKey _ Nil = (Nil, Nil)

split :: Key -> IntMap a -> (IntMap a, IntMap a)
split k t =
case t of Bin _ m l r | m < 0 -> if k >= 0 -- handle negative numbers.
then case go k l of (lt, gt) -> (union r lt, gt)
else case go k r of (lt, gt) -> (lt, union gt l)
_ -> go k t
case t of
Bin _ m l r
| m < 0 -> if k >= 0 -- handle negative numbers.
then case go k l of (lt :*: gt) -> let lt' = union r lt
in lt' `seq` (lt', gt)
else case go k r of (lt :*: gt) -> let gt' = union gt l
in gt' `seq` (lt, gt')
_ -> case go k t of
(lt :*: gt) -> (lt, gt)
where
go k' t'@(Bin p m l r) | nomatch k' p m = if k' > p then (t', Nil) else (Nil, t')
| zero k' m = case go k' l of (lt, gt) -> (lt, union gt r)
| otherwise = case go k' r of (lt, gt) -> (union l lt, gt)
go k' t'@(Tip ky _) | k' > ky = (t', Nil)
| k' < ky = (Nil, t')
| otherwise = (Nil, Nil)
go _ Nil = (Nil, Nil)
go k' t'@(Bin p m l r) | nomatch k' p m = if k' > p then t' :*: Nil else Nil :*: t'
| zero k' m = case go k' l of (lt :*: gt) -> lt :*: union gt r
| otherwise = case go k' r of (lt :*: gt) -> union l lt :*: gt
go k' t'@(Tip ky _) | k' > ky = (t' :*: Nil)
| k' < ky = (Nil :*: t')
| otherwise = (Nil :*: Nil)
go _ Nil = (Nil :*: Nil)

-- | /O(min(n,W))/. Performs a 'split' but also returns whether the pivot
-- key was found in the original map.
Expand All @@ -1503,14 +1514,23 @@ split k t =

splitLookup :: Key -> IntMap a -> (IntMap a, Maybe a, IntMap a)
splitLookup k t =
case t of Bin _ m l r | m < 0 -> if k >= 0 -- handle negative numbers.
then case go k l of (lt, fnd, gt) -> (union r lt, fnd, gt)
else case go k r of (lt, fnd, gt) -> (lt, fnd, union gt l)
_ -> go k t
case t of
Bin _ m l r
| m < 0 -> if k >= 0 -- handle negative numbers.
then case go k l of
(lt, fnd, gt) -> let lt' = union r lt
in lt' `seq` (lt', fnd, gt)
else case go k r of
(lt, fnd, gt) -> let gt' = union gt l
in gt' `seq` (lt, fnd, gt')
_ -> go k t
where
go k' t'@(Bin p m l r) | nomatch k' p m = if k' > p then (t', Nothing, Nil) else (Nil, Nothing, t')
| zero k' m = case go k' l of (lt, fnd, gt) -> (lt, fnd, union gt r)
| otherwise = case go k' r of (lt, fnd, gt) -> (union l lt, fnd, gt)
go k' t'@(Bin p m l r)
| nomatch k' p m = if k' > p then (t', Nothing, Nil) else (Nil, Nothing, t')
| zero k' m = case go k' l of
(lt, fnd, gt) -> let gt' = union gt r in gt' `seq` (lt, fnd, gt')
| otherwise = case go k' r of
(lt, fnd, gt) -> let lt' = union l lt in lt' `seq` (lt', fnd, gt)
go k' t'@(Tip ky y) | k' > ky = (t', Nothing, Nil)
| k' < ky = (Nil, Nothing, t')
| otherwise = (Nil, Just y, Nil)
Expand Down
98 changes: 54 additions & 44 deletions Data/IntMap/Strict.hs
Original file line number Diff line number Diff line change
Expand Up @@ -390,16 +390,18 @@ insertWithKey f k x t = k `seq` x `seq`
-- > insertLookup 7 "x" (fromList [(5,"a"), (3,"b")]) == (Nothing, fromList [(3, "b"), (5, "a"), (7, "x")])

insertLookupWithKey :: (Key -> a -> a -> a) -> Key -> a -> IntMap a -> (Maybe a, IntMap a)
insertLookupWithKey f k x t = k `seq` x `seq`
case t of
Bin p m l r
| nomatch k p m -> Nothing `strictPair` join k (Tip k x) p t
| zero k m -> let (found,l') = insertLookupWithKey f k x l in (found `strictPair` Bin p m l' r)
| otherwise -> let (found,r') = insertLookupWithKey f k x r in (found `strictPair` Bin p m l r')
Tip ky y
| k==ky -> (Just y `strictPair` (Tip k $! f k x y))
| otherwise -> (Nothing `strictPair` join k (Tip k x) ky t)
Nil -> Nothing `strictPair` Tip k x
insertLookupWithKey f0 k0 x0 t0 = k0 `seq` x0 `seq` toPair $ go f0 k0 x0 t0
where
go f k x t =
case t of
Bin p m l r
| nomatch k p m -> Nothing :*: join k (Tip k x) p t
| zero k m -> let (found :*: l') = go f k x l in (found :*: Bin p m l' r)
| otherwise -> let (found :*: r') = go f k x r in (found :*: Bin p m l r')
Tip ky y
| k==ky -> (Just y :*: (Tip k $! f k x y))
| otherwise -> (Nothing :*: join k (Tip k x) ky t)
Nil -> Nothing :*: Tip k x


{--------------------------------------------------------------------
Expand Down Expand Up @@ -475,18 +477,20 @@ updateWithKey f k t = k `seq`
-- > updateLookupWithKey f 3 (fromList [(5,"a"), (3,"b")]) == (Just "b", singleton 5 "a")

updateLookupWithKey :: (Key -> a -> Maybe a) -> Key -> IntMap a -> (Maybe a,IntMap a)
updateLookupWithKey f k t = k `seq`
case t of
Bin p m l r
| nomatch k p m -> (Nothing, t)
| zero k m -> let (found,l') = updateLookupWithKey f k l in (found `strictPair` bin p m l' r)
| otherwise -> let (found,r') = updateLookupWithKey f k r in (found `strictPair` bin p m l r')
Tip ky y
| k==ky -> case f k y of
Just y' -> y' `seq` (Just y `strictPair` Tip ky y')
Nothing -> (Just y, Nil)
| otherwise -> (Nothing,t)
Nil -> (Nothing,Nil)
updateLookupWithKey f0 k0 t0 = k0 `seq` toPair $ go f0 k0 t0
where
go f k t =
case t of
Bin p m l r
| nomatch k p m -> (Nothing :*: t)
| zero k m -> let (found :*: l') = go f k l in (found :*: bin p m l' r)
| otherwise -> let (found :*: r') = go f k r in (found :*: bin p m l r')
Tip ky y
| k==ky -> case f k y of
Just y' -> y' `seq` (Just y :*: Tip ky y')
Nothing -> (Just y :*: Nil)
| otherwise -> (Nothing :*: t)
Nil -> (Nothing :*: Nil)



Expand Down Expand Up @@ -743,24 +747,28 @@ mapAccumWithKey f a t
-- the accumulating argument and the both elements of the
-- result of the function.
mapAccumL :: (a -> Key -> b -> (a,c)) -> a -> IntMap b -> (a,IntMap c)
mapAccumL f a t
= case t of
Bin p m l r -> let (a1,l') = mapAccumL f a l
(a2,r') = mapAccumL f a1 r
in (a2 `strictPair` Bin p m l' r')
Tip k x -> let (a',x') = f a k x in x' `seq` (a' `strictPair` Tip k x')
Nil -> (a `strictPair` Nil)
mapAccumL f0 a0 t0 = toPair $ go f0 a0 t0
where
go f a t
= case t of
Bin p m l r -> let (a1 :*: l') = go f a l
(a2 :*: r') = go f a1 r
in (a2 :*: Bin p m l' r')
Tip k x -> let (a',x') = f a k x in x' `seq` (a' :*: Tip k x')
Nil -> (a :*: Nil)

-- | /O(n)/. The function @'mapAccumR'@ threads an accumulating
-- argument through the map in descending order of keys.
mapAccumRWithKey :: (a -> Key -> b -> (a,c)) -> a -> IntMap b -> (a,IntMap c)
mapAccumRWithKey f a t
= case t of
Bin p m l r -> let (a1,r') = mapAccumRWithKey f a r
(a2,l') = mapAccumRWithKey f a1 l
in (a2 `strictPair` Bin p m l' r')
Tip k x -> let (a',x') = f a k x in x' `seq` (a' `strictPair` Tip k x')
Nil -> (a `strictPair` Nil)
mapAccumRWithKey f0 a0 t0 = toPair $ go f0 a0 t0
where
go f a t
= case t of
Bin p m l r -> let (a1 :*: r') = go f a r
(a2 :*: l') = go f a1 l
in (a2 :*: Bin p m l' r')
Tip k x -> let (a',x') = f a k x in x' `seq` (a' :*: Tip k x')
Nil -> (a :*: Nil)

-- | /O(n*log n)/.
-- @'mapKeysWith' c f s@ is the map obtained by applying @f@ to each key of @s@.
Expand Down Expand Up @@ -822,15 +830,17 @@ mapEither f m
-- > == (empty, fromList [(1,"x"), (3,"b"), (5,"a"), (7,"z")])

mapEitherWithKey :: (Key -> a -> Either b c) -> IntMap a -> (IntMap b, IntMap c)
mapEitherWithKey f (Bin p m l r)
= bin p m l1 r1 `strictPair` bin p m l2 r2
mapEitherWithKey f0 t0 = toPair $ go f0 t0
where
(l1,l2) = mapEitherWithKey f l
(r1,r2) = mapEitherWithKey f r
mapEitherWithKey f (Tip k x) = case f k x of
Left y -> y `seq` (Tip k y, Nil)
Right z -> z `seq` (Nil, Tip k z)
mapEitherWithKey _ Nil = (Nil, Nil)
go f (Bin p m l r)
= bin p m l1 r1 :*: bin p m l2 r2
where
(l1 :*: l2) = go f l
(r1 :*: r2) = go f r
go f (Tip k x) = case f k x of
Left y -> y `seq` (Tip k y :*: Nil)
Right z -> z `seq` (Nil :*: Tip k z)
go _ Nil = (Nil :*: Nil)

{--------------------------------------------------------------------
Conversions
Expand Down
Loading

0 comments on commit 8a661a5

Please sign in to comment.