Skip to content

Commit

Permalink
Reimplement <*>
Browse files Browse the repository at this point in the history
Use `coerce` for the `Functor` instance of `Elem`

Using `fmap = coerce` for `Elem` speeds up `<*>` by somewhere
around 20%.

Benchmark results:

OLD:

benchmarking <*>/ix1000/500000
time                 11.47 ms   (11.37 ms .. 11.59 ms)
                     0.999 R²   (0.998 R² .. 1.000 R²)
mean                 11.61 ms   (11.52 ms .. 11.73 ms)
std dev              279.9 μs   (209.5 μs .. 385.6 μs)

benchmarking <*>/nf100/2500/rep
time                 8.530 ms   (8.499 ms .. 8.568 ms)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 8.511 ms   (8.498 ms .. 8.528 ms)
std dev              40.40 μs   (28.55 μs .. 63.84 μs)

benchmarking <*>/nf100/2500/ff
time                 27.13 ms   (26.16 ms .. 28.70 ms)
                     0.994 R²   (0.988 R² .. 1.000 R²)
mean                 26.49 ms   (26.29 ms .. 27.43 ms)
std dev              697.1 μs   (153.0 μs .. 1.443 ms)

benchmarking <*>/nf500/500/rep
time                 8.421 ms   (8.331 ms .. 8.491 ms)
                     0.991 R²   (0.967 R² .. 1.000 R²)
mean                 8.518 ms   (8.417 ms .. 9.003 ms)
std dev              529.9 μs   (40.37 μs .. 1.176 ms)
variance introduced by outliers: 32% (moderately inflated)

benchmarking <*>/nf500/500/ff
time                 33.71 ms   (33.58 ms .. 33.86 ms)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 33.69 ms   (33.62 ms .. 33.76 ms)
std dev              150.0 μs   (119.0 μs .. 191.0 μs)

benchmarking <*>/nf2500/100/rep
time                 8.390 ms   (8.259 ms .. 8.456 ms)
                     0.997 R²   (0.992 R² .. 1.000 R²)
mean                 8.544 ms   (8.441 ms .. 8.798 ms)
std dev              402.6 μs   (21.25 μs .. 714.9 μs)
variance introduced by outliers: 23% (moderately inflated)

benchmarking <*>/nf2500/100/ff
time                 53.69 ms   (53.33 ms .. 54.08 ms)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 53.59 ms   (53.38 ms .. 53.75 ms)
std dev              341.2 μs   (231.7 μs .. 473.9 μs)

NEW

benchmarking <*>/ix1000/500000
time                 2.688 μs   (2.607 μs .. 2.798 μs)
                     0.994 R²   (0.988 R² .. 1.000 R²)
mean                 2.632 μs   (2.607 μs .. 2.715 μs)
std dev              129.9 ns   (65.93 ns .. 242.8 ns)
variance introduced by outliers: 64% (severely inflated)

benchmarking <*>/nf100/2500/rep
time                 8.371 ms   (8.064 ms .. 8.535 ms)
                     0.983 R²   (0.947 R² .. 1.000 R²)
mean                 8.822 ms   (8.590 ms .. 9.463 ms)
std dev              991.2 μs   (381.3 μs .. 1.809 ms)
variance introduced by outliers: 61% (severely inflated)

benchmarking <*>/nf100/2500/ff
time                 22.84 ms   (22.74 ms .. 22.94 ms)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 22.78 ms   (22.71 ms .. 22.86 ms)
std dev              183.3 μs   (116.3 μs .. 291.3 μs)

benchmarking <*>/nf500/500/rep
time                 8.320 ms   (8.102 ms .. 8.514 ms)
                     0.995 R²   (0.990 R² .. 0.999 R²)
mean                 8.902 ms   (8.675 ms .. 9.407 ms)
std dev              952.4 μs   (435.5 μs .. 1.672 ms)
variance introduced by outliers: 58% (severely inflated)

benchmarking <*>/nf500/500/ff
time                 24.50 ms   (24.41 ms .. 24.58 ms)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 24.44 ms   (24.41 ms .. 24.48 ms)
std dev              75.08 μs   (50.16 μs .. 111.3 μs)

benchmarking <*>/nf2500/100/rep
time                 8.419 ms   (8.366 ms .. 8.458 ms)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 8.571 ms   (8.525 ms .. 8.670 ms)
std dev              179.5 μs   (112.0 μs .. 278.1 μs)

benchmarking <*>/nf2500/100/ff
time                 24.14 ms   (24.07 ms .. 24.26 ms)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 24.11 ms   (24.07 ms .. 24.17 ms)
std dev              103.8 μs   (68.34 μs .. 142.0 μs)
  • Loading branch information
treeowl committed Dec 21, 2014
1 parent ae97ceb commit 38b1b81
Showing 1 changed file with 258 additions and 3 deletions.
261 changes: 258 additions & 3 deletions Data/Sequence.hs
Expand Up @@ -194,6 +194,7 @@ import Data.Functor.Identity (Identity(..))

infixr 5 `consTree`
infixl 5 `snocTree`
infixr 5 `appendTree0`

infixr 5 ><
infixr 5 <|, :<
Expand Down Expand Up @@ -258,10 +259,255 @@ instance Monad Seq where

instance Applicative Seq where
pure = singleton
fs <*> xs = foldl' add empty fs

Seq Empty <*> xs = xs `seq` empty
fs <*> Seq Empty = fs `seq` empty
fs <*> Seq (Single (Elem x)) = fmap ($ x) fs
fs <*> xs
| length fs < 4 = foldl' add empty fs
where add ys f = ys >< fmap f xs
fs <*> xs | length xs < 4 = apShort fs xs
fs <*> xs = apty fs xs

xs *> ys = replicateSeq (length xs) ys

-- <*> when the length of the first argument is at least two and
-- the length of the second is two or three.
apShort :: Seq (a -> b) -> Seq a -> Seq b
apShort (Seq fs) xs = Seq $ case toList xs of
[a,b] -> ap2FT fs (a,b)
[a,b,c] -> ap3FT fs (a,b,c)
_ -> error "apShort: not 2-6"

ap2FT :: FingerTree (Elem (a->b)) -> (a,a) -> FingerTree (Elem b)
ap2FT fs (x,y) = Deep (size fs * 2)
(Two (Elem $ firstf x) (Elem $ firstf y))
(mapMulFT 2 (\(Elem f) -> Node2 2 (Elem (f x)) (Elem (f y))) m)
(Two (Elem $ lastf x) (Elem $ lastf y))
where
(Elem firstf, m, Elem lastf) = trimTree fs

ap3FT :: FingerTree (Elem (a->b)) -> (a,a,a) -> FingerTree (Elem b)
ap3FT fs (x,y,z) = Deep (size fs * 3)
(Three (Elem $ firstf x) (Elem $ firstf y) (Elem $ firstf z))
(mapMulFT 3 (\(Elem f) -> Node3 3 (Elem (f x)) (Elem (f y)) (Elem (f z))) m)
(Three (Elem $ lastf x) (Elem $ lastf y) (Elem $ lastf z))
where
(Elem firstf, m, Elem lastf) = trimTree fs

-- <*> when the length of each argument is at least four.
apty :: Seq (a -> b) -> Seq a -> Seq b
apty (Seq fs) (Seq xs@Deep{}) = Seq $
runApState (fmap firstf) (fmap lastf) fmap fs' (ApState xs' xs' xs')
where
(Elem firstf, fs', Elem lastf) = trimTree fs
xs' = rigidify xs
apty _ _ = error "apty: expects a Deep constructor"

data ApState a = ApState (FingerTree a) (FingerTree a) (FingerTree a)

-- | 'runApState' uses three copies of the @xs@ tree to produce the @fs<*>xs@
-- tree. It pulls left digits off the left tree, right digits off the right tree,
-- and squashes down the other four digits. Once it gets to the bottom, it turns
-- the middle tree into a 2-3 tree, applies 'mapMulFT' to produce the main body,
-- and glues all the pieces together.
runApState
:: Sized c =>
(c -> d)
-> (c -> d)
-> ((a -> b) -> c -> d)
-> FingerTree (Elem (a -> b))
-> ApState c
-> FingerTree d
-- Not at the bottom yet
runApState firstf
lastf
map23
fs
(ApState
(Deep sl
prl
(Deep sml prml mml sfml)
sfl)
(Deep sm
prm
(Deep _smm prmm mmm sfmm)
sfm)
(Deep sr
prr
(Deep smr prmr mmr sfmr)
sfr))
= Deep (sl + sr + sm * size fs)
(fmap firstf prl)
(runApState (fmap firstf)
(fmap lastf)
(\f -> fmap (map23 f))
fs
nextState)
(fmap lastf sfr)
where nextState =
ApState
(Deep (sml + size sfl) prml mml (squashR sfml sfl))
(Deep sm (squashL prm prmm) mmm (squashR sfmm sfm))
(Deep (smr + size prr) (squashL prr prmr) mmr sfmr)

-- At the bottom
runApState firstf
lastf
map23
fs
(ApState
(Deep sl prl ml sfl)
(Deep sm prm mm sfm)
(Deep sr prr mr sfr))
= Deep (sl + sr + sm * size fs)
(fmap firstf prl)
((fmap (fmap firstf) ml `snocTree` fmap firstf (digitToNode sfl))
`appendTree0` middle `appendTree0`
(fmap lastf (digitToNode prr) `consTree` fmap (fmap lastf) mr))
(fmap lastf sfr)
where middle = case trimTree $ mapMulFT sm (\(Elem f) -> fmap (fmap (map23 f)) converted) fs of
(firstMapped, restMapped, lastMapped) ->
Deep (size firstMapped + size restMapped + size lastMapped)
(nodeToDigit firstMapped) restMapped (nodeToDigit lastMapped)
converted = case mm of
Empty -> Node2 sm lconv rconv
Single q -> Node3 sm lconv q rconv
Deep{} -> error "runApState: a tree is shallower than the middle tree"
lconv = digitToNode prm
rconv = digitToNode sfm

runApState _ _ _ _ _ = error "runApState: ApState must hold Deep finger trees of the same depth"

{-# SPECIALIZE
runApState
:: (Node c -> d)
-> (Node c -> d)
-> ((a -> b) -> Node c -> d)
-> FingerTree (Elem (a -> b))
-> ApState (Node c)
-> FingerTree d
#-}
{-# SPECIALIZE
runApState
:: (Elem c -> d)
-> (Elem c -> d)
-> ((a -> b) -> Elem c -> d)
-> FingerTree (Elem (a -> b))
-> ApState (Elem c)
-> FingerTree d
#-}

digitToNode :: Sized a => Digit a -> Node a
digitToNode (Two a b) = node2 a b
digitToNode (Three a b c) = node3 a b c
digitToNode _ = error "digitToNode: not representable as a node"

type Digit23 = Digit
type Digit12 = Digit

-- Squash the first argument down onto the left side of the second.
squashL :: Sized a => Digit23 a -> Digit12 (Node a) -> Digit23 (Node a)
squashL (Two a b) (One n) = Two (node2 a b) n
squashL (Two a b) (Two n1 n2) = Three (node2 a b) n1 n2
squashL (Three a b c) (One n) = Two (node3 a b c) n
squashL (Three a b c) (Two n1 n2) = Three (node3 a b c) n1 n2
squashL _ _ = error "squashL: wrong digit types"

-- Squash the second argument down onto the right side of the first
squashR :: Sized a => Digit12 (Node a) -> Digit23 a -> Digit23 (Node a)
squashR (One n) (Two a b) = Two n (node2 a b)
squashR (Two n1 n2) (Two a b) = Three n1 n2 (node2 a b)
squashR (One n) (Three a b c) = Two n (node3 a b c)
squashR (Two n1 n2) (Three a b c) = Three n1 n2 (node3 a b c)
squashR _ _ = error "squashR: wrong digit types"

-- | /O(m*n)/ (incremental) Takes an /O(m)/ function and a finger tree of size
-- /n/ and maps the function over the tree leaves. Unlike the usual 'fmap', the
-- function is applied to the "leaves" of the 'FingerTree' (i.e., given a
-- @FingerTree (Elem a)@, it applies the function to elements of type @Elem
-- a@), replacing the leaves with subtrees of at least the same height, e.g.,
-- @Node(Node(Elem y))@. The multiplier argument serves to make the annotations
-- match up properly.
mapMulFT :: Int -> (a -> b) -> FingerTree a -> FingerTree b
mapMulFT _ _ Empty = Empty
mapMulFT _mul f (Single a) = Single (f a)
mapMulFT mul f (Deep s pr m sf) = Deep (mul * s) (fmap f pr) (mapMulFT mul (mapMulNode mul f) m) (fmap f sf)

mapMulNode :: Int -> (a -> b) -> Node a -> Node b
mapMulNode mul f (Node2 s a b) = Node2 (mul * s) (f a) (f b)
mapMulNode mul f (Node3 s a b c) = Node3 (mul * s) (f a) (f b) (f c)


trimTree :: Sized a => FingerTree a -> (a, FingerTree a, a)
trimTree Empty = error "trim: empty tree"
trimTree Single{} = error "trim: singleton"
trimTree t = case splitTree 0 t of
Split _ hd r ->
case splitTree (size r - 1) r of
Split m tl _ -> (hd, m, tl)

-- | /O(log n)/ (incremental) Takes the extra flexibility out of a 'FingerTree'
-- to make it a genuine 2-3 finger tree. The result of 'rigidify' will have
-- only 'Two' and 'Three' digits at the top level and only 'One' and 'Two'
-- digits elsewhere. It gives an error if the tree has fewer than four
-- elements.
rigidify :: Sized a => FingerTree a -> FingerTree a
-- Note that 'rigidify' may call itself, but it will do so at most
-- once: each call to 'rigidify' will either fix the whole tree or fix one digit
-- and leave the other alone. The patterns below just fix up the top level of
-- the tree; 'rigidify' delegates the hard work to 'thin'.

-- The top of the tree is fine.
rigidify (Deep s pr@Two{} m sf@Three{}) = Deep s pr (thin m) sf
rigidify (Deep s pr@Three{} m sf@Three{}) = Deep s pr (thin m) sf
rigidify (Deep s pr@Two{} m sf@Two{}) = Deep s pr (thin m) sf
rigidify (Deep s pr@Three{} m sf@Two{}) = Deep s pr (thin m) sf

-- One of the Digits is a Four.
rigidify (Deep s (Four a b c d) m sf) =
rigidify $ Deep s (Two a b) (node2 c d `consTree` m) sf
rigidify (Deep s pr m (Four a b c d)) =
rigidify $ Deep s pr (m `snocTree` node2 a b) (Two c d)

-- One of the Digits is a One. If the middle is empty, we can only rigidify the
-- tree if the other Digit is a Three.
rigidify (Deep s (One a) Empty (Three b c d)) = Deep s (Two a b) Empty (Two c d)
rigidify (Deep s (One a) m sf) = rigidify $ case viewLTree m of
Just2 (Node2 _ b c) m' -> Deep s (Three a b c) m' sf
Just2 (Node3 _ b c d) m' -> Deep s (Two a b) (node2 c d `consTree` m') sf
Nothing2 -> error "rigidify: small tree"
rigidify (Deep s (Three a b c) Empty (One d)) = Deep s (Two a b) Empty (Two c d)
rigidify (Deep s pr m (One e)) = rigidify $ case viewRTree m of
Just2 m' (Node2 _ a b) -> Deep s pr m' (Three a b e)
Just2 m' (Node3 _ a b c) -> Deep s pr (m' `snocTree` node2 a b) (Two c e)
Nothing2 -> error "rigidify: small tree"
rigidify Empty = error "rigidify: empty tree"
rigidify Single{} = error "rigidify: singleton"

-- | /O(log n)/ (incremental) Rejigger a finger tree so the digits are all ones
-- and twos.
thin :: Sized a => FingerTree a -> FingerTree a
-- Note that 'thin' may call itself at most once before passing the job on to
-- 'thin12'. 'thin12' will produce a 'Deep' constructor immediately before
-- calling 'thin'.
thin Empty = Empty
thin (Single a) = Single a
thin t@(Deep s pr m sf) =
case pr of
One{} -> thin12 t
Two{} -> thin12 t
Three a b c -> thin $ Deep s (One a) (node2 b c `consTree` m) sf
Four a b c d -> thin $ Deep s (Two a b) (node2 c d `consTree` m) sf

thin12 :: Sized a => FingerTree a -> FingerTree a
thin12 (Deep s pr m sf@One{}) = Deep s pr (thin m) sf
thin12 (Deep s pr m sf@Two{}) = Deep s pr (thin m) sf
thin12 (Deep s pr m (Three a b c)) = Deep s pr (thin $ m `snocTree` node2 a b) (One c)
thin12 (Deep s pr m (Four a b c d)) = Deep s pr (thin $ m `snocTree` node2 a b) (Two c d)
thin12 _ = error "thin12 expects a Deep FingerTree."


instance MonadPlus Seq where
mzero = empty
mplus = (><)
Expand Down Expand Up @@ -559,7 +805,12 @@ instance Sized (Elem a) where
size _ = 1

instance Functor Elem where
#if __GLASGOW_HASKELL__ >= 708
-- This cuts the time for <*> by around a fifth.
fmap = coerce
#else
fmap f (Elem x) = Elem (f x)
#endif

instance Foldable Elem where
foldMap f (Elem x) = f x
Expand Down Expand Up @@ -732,7 +983,9 @@ Seq xs >< Seq ys = Seq (appendTree0 xs ys)

-- The appendTree/addDigits gunk below is machine generated

appendTree0 :: FingerTree (Elem a) -> FingerTree (Elem a) -> FingerTree (Elem a)
{-# SPECIALIZE appendTree0 :: FingerTree (Elem a) -> FingerTree (Elem a) -> FingerTree (Elem a) #-}
{-# SPECIALIZE appendTree0 :: FingerTree (Node a) -> FingerTree (Node a) -> FingerTree (Node a) #-}
appendTree0 :: Sized a => FingerTree a -> FingerTree a -> FingerTree a
appendTree0 Empty xs =
xs
appendTree0 xs Empty =
Expand All @@ -744,7 +997,9 @@ appendTree0 xs (Single x) =
appendTree0 (Deep s1 pr1 m1 sf1) (Deep s2 pr2 m2 sf2) =
Deep (s1 + s2) pr1 (addDigits0 m1 sf1 pr2 m2) sf2

addDigits0 :: FingerTree (Node (Elem a)) -> Digit (Elem a) -> Digit (Elem a) -> FingerTree (Node (Elem a)) -> FingerTree (Node (Elem a))
{-# SPECIALIZE addDigits0 :: FingerTree (Node (Elem a)) -> Digit (Elem a) -> Digit (Elem a) -> FingerTree (Node (Elem a)) -> FingerTree (Node (Elem a)) #-}
{-# SPECIALIZE addDigits0 :: FingerTree (Node (Node a)) -> Digit (Node a) -> Digit (Node a) -> FingerTree (Node (Node a)) -> FingerTree (Node (Node a)) #-}
addDigits0 :: Sized a => FingerTree (Node a) -> Digit a -> Digit a -> FingerTree (Node a) -> FingerTree (Node a)
addDigits0 m1 (One a) (One b) m2 =
appendTree1 m1 (node2 a b) m2
addDigits0 m1 (One a) (Two b c) m2 =
Expand Down

0 comments on commit 38b1b81

Please sign in to comment.