Skip to content

Commit e076b33

Browse files
committed
Improve performance of folds.
We use a wordsize-dependent implementation for GHC, for both 32-bit and 64-bit architectures. It is based on a fast constant-time implementation of indexOfTheOnlyBit, which computes the index of the only bit set in a word, suggested by Edward Kmett. Using that we can enumerate the indexes of 1 bits, in order from LSB to MSB. That results in fast foldl implementations. The foldr implementations bit-reverse the word and then iterate from LSB to MSB using an accumulator. That is faster than either not using an accumulator or iterating from MSB to LSB.
1 parent a7d29bd commit e076b33

File tree

1 file changed

+148
-45
lines changed

1 file changed

+148
-45
lines changed

Data/IntSet.hs

Lines changed: 148 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ import Data.Data (Data(..), mkNoRepType)
148148

149149
#if __GLASGOW_HASKELL__
150150
import GHC.Exts ( Word(..), Int(..) )
151-
import GHC.Prim ( uncheckedShiftL#, uncheckedShiftRL# )
151+
import GHC.Prim ( uncheckedShiftL#, uncheckedShiftRL#, indexInt8OffAddr# )
152152
#else
153153
import Data.Word
154154
#endif
@@ -158,6 +158,7 @@ import Data.Word
158158
-- We do not use BangPatterns, because they are not in any standard and we
159159
-- want the compilers to be compiled by as many compilers as possible.
160160
#define STRICT_1_OF_2(fn) fn arg _ | arg `seq` False = undefined
161+
#define STRICT_2_OF_2(fn) fn _ arg | arg `seq` False = undefined
161162
#define STRICT_1_OF_3(fn) fn arg _ _ | arg `seq` False = undefined
162163
#define STRICT_2_OF_3(fn) fn _ arg _ | arg `seq` False = undefined
163164

@@ -523,10 +524,10 @@ filter predicate t
523524
Bin p m l r
524525
-> bin p m (filter predicate l) (filter predicate r)
525526
Tip kx bm
526-
-> tip kx (foldr'Bits 0 (bitPred kx) 0 bm)
527+
-> tip kx (foldl'Bits 0 (bitPred kx) 0 bm)
527528
Nil -> Nil
528-
where bitPred kx i bm | predicate (kx + i) = bm .|. bitmapOfSuffix i
529-
| otherwise = bm
529+
where bitPred kx bm bi | predicate (kx + bi) = bm .|. bitmapOfSuffix bi
530+
| otherwise = bm
530531
{-# INLINE bitPred #-}
531532

532533
-- | /O(n)/. partition the set according to some predicate.
@@ -538,12 +539,12 @@ partition predicate t
538539
(r1,r2) = partition predicate r
539540
in (bin p m l1 r1, bin p m l2 r2)
540541
Tip kx bm
541-
-> let (bm1,bm2) = foldr'Bits 0 (bitPart kx) (0,0) bm
542-
in (tip kx bm1, tip kx bm2)
542+
-> let bm1 = foldl'Bits 0 (bitPred kx) 0 bm
543+
in (tip kx bm1, tip kx (bm `xor` bm1))
543544
Nil -> (Nil,Nil)
544-
where bitPart kx i (bm1,bm2) | predicate (kx + i) = (bm1 .|. bitmapOfSuffix i, bm2)
545-
| otherwise = (bm1, bm2 .|. bitmapOfSuffix i)
546-
{-# INLINE bitPart #-}
545+
where bitPred kx bm bi | predicate (kx + bi) = bm .|. bitmapOfSuffix bi
546+
| otherwise = bm
547+
{-# INLINE bitPred #-}
547548

548549

549550
-- | /O(min(n,W))/. The expression (@'split' x set@) is a pair @(set1,set2)@
@@ -696,7 +697,7 @@ fold = foldr
696697
-- > toAscList set = foldr (:) [] set
697698
foldr :: (Int -> b -> b) -> b -> IntSet -> b
698699
foldr f z t =
699-
case t of Bin 0 m l r | m < 0 -> go (go z l) r -- put negative numbers before
700+
case t of Bin _ m l r | m < 0 -> go (go z l) r -- put negative numbers before
700701
_ -> go z t
701702
where
702703
go z' Nil = z'
@@ -709,7 +710,7 @@ foldr f z t =
709710
-- function is strict in the starting value.
710711
foldr' :: (Int -> b -> b) -> b -> IntSet -> b
711712
foldr' f z t =
712-
case t of Bin 0 m l r | m < 0 -> go (go z l) r -- put negative numbers before
713+
case t of Bin _ m l r | m < 0 -> go (go z l) r -- put negative numbers before
713714
_ -> go z t
714715
where
715716
STRICT_1_OF_2(go)
@@ -726,7 +727,7 @@ foldr' f z t =
726727
-- > toDescList set = foldl (flip (:)) [] set
727728
foldl :: (a -> Int -> a) -> a -> IntSet -> a
728729
foldl f z t =
729-
case t of Bin 0 m l r | m < 0 -> go (go z r) l -- put negative numbers before
730+
case t of Bin _ m l r | m < 0 -> go (go z r) l -- put negative numbers before
730731
_ -> go z t
731732
where
732733
STRICT_1_OF_2(go)
@@ -740,7 +741,7 @@ foldl f z t =
740741
-- function is strict in the starting value.
741742
foldl' :: (a -> Int -> a) -> a -> IntSet -> a
742743
foldl' f z t =
743-
case t of Bin 0 m l r | m < 0 -> go (go z r) l -- put negative numbers before
744+
case t of Bin _ m l r | m < 0 -> go (go z r) l -- put negative numbers before
744745
_ -> go z t
745746
where
746747
STRICT_1_OF_2(go)
@@ -1134,75 +1135,177 @@ highestBitMask x0
11341135
{-# INLINE highestBitMask #-}
11351136

11361137
{----------------------------------------------------------------------
1137-
Finds the index of the lowest resp. highest bit set in a word. The following
1138-
code works fine for bit sizes up to 64. A possibly faster but
1139-
wordsize-dependant implementation based on multiplication and DeBrujn indeces
1140-
is proposed by Edward Kmett
1141-
<http://haskell.org/pipermail/libraries/2011-September/016749.html>
1142-
Some architectures, notably x86, also offer machine instructions for this
1143-
operation (bsr and bsl).
1138+
To get best performance, we provide fast implementations of
1139+
lowestBitSet, highestBitSet and fold[lr][l]Bits for GHC.
1140+
If the intel bsf and bsr instructions ever become GHC primops,
1141+
this code should be reimplemented using these.
1142+
1143+
Performance of this code is crucial for folds, toList, filter, partition.
1144+
1145+
The signatures of methods in question are placed after this comment.
1146+
----------------------------------------------------------------------}
1147+
1148+
lowestBitSet :: Nat -> Int
1149+
highestBitSet :: Nat -> Int
1150+
foldlBits :: Int -> (a -> Int -> a) -> a -> Nat -> a
1151+
foldl'Bits :: Int -> (a -> Int -> a) -> a -> Nat -> a
1152+
foldrBits :: Int -> (Int -> a -> a) -> a -> Nat -> a
1153+
foldr'Bits :: Int -> (Int -> a -> a) -> a -> Nat -> a
1154+
1155+
{-# INLINE lowestBitSet #-}
1156+
{-# INLINE highestBitSet #-}
1157+
{-# INLINE foldlBits #-}
1158+
{-# INLINE foldl'Bits #-}
1159+
{-# INLINE foldrBits #-}
1160+
{-# INLINE foldr'Bits #-}
1161+
1162+
#if defined(__GLASGOW_HASKELL__)
1163+
#include "MachDeps.h"
1164+
#endif
1165+
1166+
#if defined(__GLASGOW_HASKELL__) && (WORD_SIZE_IN_BITS==32 || WORD_SIZE_IN_BITS==64)
1167+
{----------------------------------------------------------------------
1168+
For lowestBitSet we use wordsize-dependant implementation based on
1169+
multiplication and DeBrujn indeces, which was proposed by Edward Kmett
1170+
<http://haskell.org/pipermail/libraries/2011-September/016749.html>
1171+
1172+
The core of this implementation is fast indexOfTheOnlyBit,
1173+
which is given a Nat with exactly one bit set, and returns
1174+
its index.
1175+
1176+
Lot of effort was put in these implementations, please benchmark carefully
1177+
before changing this code.
1178+
----------------------------------------------------------------------}
1179+
1180+
indexOfTheOnlyBit :: Nat -> Int
1181+
{-# INLINE indexOfTheOnlyBit #-}
1182+
indexOfTheOnlyBit bit =
1183+
I# (lsbArray `indexInt8OffAddr#` unboxInt (intFromNat ((bit * magic) `shiftRL` offset)))
1184+
where unboxInt (I# i) = i
1185+
#if WORD_SIZE_IN_BITS==32
1186+
magic = 0x077CB531
1187+
offset = 27
1188+
!lsbArray = "\0\1\28\2\29\14\24\3\30\22\20\15\25\17\4\8\31\27\13\23\21\19\16\7\26\12\18\6\11\5\10\9"#
1189+
#else
1190+
magic = 0x07EDD5E59A4E28C2
1191+
offset = 58
1192+
!lsbArray = "\63\0\58\1\59\47\53\2\60\39\48\27\54\33\42\3\61\51\37\40\49\18\28\20\55\30\34\11\43\14\22\4\62\57\46\52\38\26\32\41\50\36\17\19\29\10\13\21\56\45\25\31\35\16\9\12\44\24\15\8\23\7\6\5"#
1193+
#endif
1194+
-- The lsbArray gets inlined to every call site of indexOfTheOnlyBit.
1195+
-- That cannot be easily avoided, as GHC forbids top-level Addr# literal.
1196+
-- One could go around that by supplying getLsbArray :: () -> Addr# marked
1197+
-- as NOINLINE. But the code size of calling it and processing the result
1198+
-- is 48B on 32-bit and 56B on 64-bit architectures -- so the 32B and 64B array
1199+
-- is actually an improvement on 32-bit and only an 8B size increase on 64-bit.
1200+
1201+
lowestBitMask :: Nat -> Nat
1202+
lowestBitMask x = x .&. negate x
1203+
{-# INLINE lowestBitMask #-}
1204+
1205+
-- Reverse the order of bits in the Nat.
1206+
revNat :: Nat -> Nat
1207+
#if WORD_SIZE_IN_BITS==32
1208+
revNat x1 = case ((x1 `shiftRL` 1) .&. 0x55555555) .|. ((x1 .&. 0x55555555) `shiftLL` 1) of
1209+
x2 -> case ((x2 `shiftRL` 2) .&. 0x33333333) .|. ((x2 .&. 0x33333333) `shiftLL` 2) of
1210+
x3 -> case ((x3 `shiftRL` 4) .&. 0x0F0F0F0F) .|. ((x3 .&. 0x0F0F0F0F) `shiftLL` 4) of
1211+
x4 -> case ((x4 `shiftRL` 8) .&. 0x00FF00FF) .|. ((x4 .&. 0x00FF00FF) `shiftLL` 8) of
1212+
x5 -> ( x5 `shiftRL` 16 ) .|. ( x5 `shiftLL` 16);
1213+
#else
1214+
revNat x1 = case ((x1 `shiftRL` 1) .&. 0x5555555555555555) .|. ((x1 .&. 0x5555555555555555) `shiftLL` 1) of
1215+
x2 -> case ((x2 `shiftRL` 2) .&. 0x3333333333333333) .|. ((x2 .&. 0x3333333333333333) `shiftLL` 2) of
1216+
x3 -> case ((x3 `shiftRL` 4) .&. 0x0F0F0F0F0F0F0F0F) .|. ((x3 .&. 0x0F0F0F0F0F0F0F0F) `shiftLL` 4) of
1217+
x4 -> case ((x4 `shiftRL` 8) .&. 0x00FF00FF00FF00FF) .|. ((x4 .&. 0x00FF00FF00FF00FF) `shiftLL` 8) of
1218+
x5 -> case ((x5 `shiftRL` 16) .&. 0x0000FFFF0000FFFF) .|. ((x5 .&. 0x0000FFFF0000FFFF) `shiftLL` 16) of
1219+
x6 -> ( x6 `shiftRL` 32 ) .|. ( x6 `shiftLL` 32);
1220+
#endif
1221+
1222+
lowestBitSet x = indexOfTheOnlyBit (lowestBitMask x)
1223+
1224+
highestBitSet x = indexOfTheOnlyBit (highestBitMask x)
1225+
1226+
foldlBits shift f z bm = go bm z
1227+
where go bm z | bm == 0 = z
1228+
| otherwise = case lowestBitMask bm of
1229+
bit -> bit `seq` case indexOfTheOnlyBit bit of
1230+
bi -> bi `seq` go (bm `xor` bit) ((f z) $! (shift+bi))
1231+
1232+
foldl'Bits shift f z bm = go bm z
1233+
where STRICT_2_OF_2(go)
1234+
go bm z | bm == 0 = z
1235+
| otherwise = case lowestBitMask bm of
1236+
bit -> bit `seq` case indexOfTheOnlyBit bit of
1237+
bi -> bi `seq` go (bm `xor` bit) ((f z) $! (shift+bi))
1238+
1239+
foldrBits shift f z bm = go (revNat bm) z
1240+
where go bm z | bm == 0 = z
1241+
| otherwise = case lowestBitMask bm of
1242+
bit -> bit `seq` case indexOfTheOnlyBit bit of
1243+
bi -> bi `seq` go (bm `xor` bit) ((f $! (shift+(WORD_SIZE_IN_BITS-1)-bi)) z)
1244+
1245+
foldr'Bits shift f z bm = go (revNat bm) z
1246+
where STRICT_2_OF_2(go)
1247+
go bm z | bm == 0 = z
1248+
| otherwise = case lowestBitMask bm of
1249+
bit -> bit `seq` case indexOfTheOnlyBit bit of
1250+
bi -> bi `seq` go (bm `xor` bit) ((f $! (shift+(WORD_SIZE_IN_BITS-1)-bi)) z)
1251+
1252+
#else
1253+
{----------------------------------------------------------------------
1254+
In general case we use logarithmic implementation of
1255+
lowestBitSet and highestBitSet, which works up to bit sizes of 64.
1256+
1257+
Folds are linear scans.
11441258
----------------------------------------------------------------------}
11451259

1146-
lowestBitSet :: Word -> Int
11471260
lowestBitSet n0 =
11481261
let (n1,b1) = if n0 .&. 0xFFFFFFFF /= 0 then (n0,0) else (n0 `shiftRL` 32, 32)
11491262
(n2,b2) = if n1 .&. 0xFFFF /= 0 then (n1,b1) else (n1 `shiftRL` 16, 16+b1)
11501263
(n3,b3) = if n2 .&. 0xFF /= 0 then (n2,b2) else (n2 `shiftRL` 8, 8+b2)
11511264
(n4,b4) = if n3 .&. 0xF /= 0 then (n3,b3) else (n3 `shiftRL` 4, 4+b3)
11521265
(n5,b5) = if n4 .&. 0x3 /= 0 then (n4,b4) else (n4 `shiftRL` 2, 2+b4)
11531266
b6 = if n5 .&. 0x1 /= 0 then b5 else 1+b5
1154-
in b6
1155-
{-# INLINE lowestBitSet #-}
1267+
in b6
11561268

1157-
highestBitSet :: Word -> Int
11581269
highestBitSet n0 =
11591270
let (n1,b1) = if n0 .&. 0xFFFFFFFF00000000 /= 0 then (n0 `shiftRL` 32, 32) else (n0,0)
11601271
(n2,b2) = if n1 .&. 0xFFFF0000 /= 0 then (n1 `shiftRL` 16, 16+b1) else (n1,b1)
11611272
(n3,b3) = if n2 .&. 0xFF00 /= 0 then (n2 `shiftRL` 8, 8+b2) else (n2,b2)
11621273
(n4,b4) = if n3 .&. 0xF0 /= 0 then (n3 `shiftRL` 4, 4+b3) else (n3,b3)
11631274
(n5,b5) = if n4 .&. 0xC /= 0 then (n4 `shiftRL` 2, 2+b4) else (n4,b4)
11641275
b6 = if n5 .&. 0x2 /= 0 then 1+b5 else b5
1165-
in b6
1166-
{-# INLINE highestBitSet #-}
1167-
1168-
1169-
{----------------------------------------------------------------------
1170-
Folds over bitmaps. These are crucial for good speed in toList, filter,
1171-
partition. Futher optimization is welcome.
1172-
----------------------------------------------------------------------}
1276+
in b6
11731277

1174-
foldlBits :: Int -> (a -> Int -> a) -> a -> Word -> a
1175-
foldlBits shift f x bm = let lb = lowestBitSet bm
1176-
in go (shift+lb) x (bm `shiftRL` lb)
1177-
where STRICT_2_OF_3(go)
1278+
foldlBits shift f z bm = let lb = lowestBitSet bm
1279+
in go (shift+lb) z (bm `shiftRL` lb)
1280+
where STRICT_1_OF_3(go)
11781281
go bi acc 0 = acc
11791282
go bi acc n | n `testBit` 0 = go (bi + 1) (f acc bi) (n `shiftRL` 1)
11801283
| otherwise = go (bi + 1) acc (n `shiftRL` 1)
11811284

1182-
foldl'Bits :: Int -> (a -> Int -> a) -> a -> Word -> a
1183-
foldl'Bits shift f x bm = let lb = lowestBitSet bm
1184-
in go (shift+lb) x (bm `shiftRL` lb)
1285+
foldl'Bits shift f z bm = let lb = lowestBitSet bm
1286+
in go (shift+lb) z (bm `shiftRL` lb)
11851287
where STRICT_1_OF_3(go)
11861288
STRICT_2_OF_3(go)
11871289
go bi acc 0 = acc
11881290
go bi acc n | n `testBit` 0 = go (bi + 1) (f acc bi) (n `shiftRL` 1)
11891291
| otherwise = go (bi + 1) acc (n `shiftRL` 1)
11901292

1191-
foldrBits :: Int -> (Int -> a -> a) -> a -> Word -> a
1192-
foldrBits shift f x bm = let lb = lowestBitSet bm
1293+
foldrBits shift f z bm = let lb = lowestBitSet bm
11931294
in go (shift+lb) (bm `shiftRL` lb)
11941295
where STRICT_1_OF_2(go)
1195-
go bi 0 = x
1296+
go bi 0 = z
11961297
go bi n | n `testBit` 0 = f bi (go (bi + 1) (n `shiftRL` 1))
11971298
| otherwise = go (bi + 1) (n `shiftRL` 1)
11981299

1199-
foldr'Bits :: Int -> (Int -> a -> a) -> a -> Word -> a
1200-
foldr'Bits shift f x bm = let lb = lowestBitSet bm
1300+
foldr'Bits shift f z bm = let lb = lowestBitSet bm
12011301
in go (shift+lb) (bm `shiftRL` lb)
1202-
where go bi 0 = x
1302+
where STRICT_1_OF_2(go)
1303+
go bi 0 = z
12031304
go bi n | n `testBit` 0 = f bi $! go (bi + 1) (n `shiftRL` 1)
12041305
| otherwise = go (bi + 1) (n `shiftRL` 1)
12051306

1307+
#endif
1308+
12061309
{----------------------------------------------------------------------
12071310
[bitcount] as posted by David F. Place to haskell-cafe on April 11, 2006,
12081311
based on the code on

0 commit comments

Comments
 (0)