Skip to content

Commit

Permalink
Improve parallelization of iterators
Browse files Browse the repository at this point in the history
  • Loading branch information
lehins committed Aug 7, 2022
1 parent 33df063 commit f96d02a
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 116 deletions.
64 changes: 50 additions & 14 deletions massiv-bench/bench/Iter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import Data.Typeable

baseline :: Sz1 -> IO ()
baseline (Sz sz) = loopA_ 0 (< sz) (+ 1) $ \i -> i `seq` pure ()
{-# NOINLINE baseline #-}


main :: IO ()
main = do
Expand Down Expand Up @@ -51,6 +51,14 @@ main = do
, iterFullBenchPar sz4
, iterFullBenchPar sz5
]
, bgroup
"Stride"
[ iterStrideBenchPar (Stride 3) sz1
, iterStrideBenchPar (Stride 3) sz2
, iterStrideBenchPar (Stride 3) sz3
, iterStrideBenchPar (Stride 3) sz4
, iterStrideBenchPar (Stride 3) sz5
]
]
]

Expand All @@ -59,26 +67,36 @@ iterFullBench :: Index ix => Sz ix -> Benchmark
iterFullBench !sz =
bgroup
(show (typeOf sz))
[ bench "RowMajor" $ whnfIO $
stToIO $ iterTargetFullST_ defRowMajor trivialScheduler_ 0 sz seqAction
, bench "iterA_" $ whnfIO $
[ bench "iterA_" $ whnfIO $
iterA_ zeroIndex (unSz sz) oneIndex (<) (\ix -> ix `seq` pure ())
, bench "iterFullA_" $ whnfIO $
, bench "iterFullA_ (RowMajor)" $ whnfIO $
iterFullA_ defRowMajor zeroIndex sz (\ix -> ix `seq` pure ())
, bench "RowMajorLinear" $ whnfIO $
, bench "iterTargetFullST_ (RowMajor)" $ whnfIO $
stToIO $ iterTargetFullST_ defRowMajor trivialScheduler_ 0 sz seqAction
, bench "iterFullST (RowMajor)" $ whnfIO $
stToIO $ iterFullST defRowMajor trivialScheduler_ zeroIndex sz () noopSplit seqAction
, bench "iterFullA_ (RowMajorLinear)" $ whnfIO $
iterFullA_ defRowMajorLinear zeroIndex sz (\ix -> ix `seq` pure ())
, bench "iterTargetFullST_ (RowMajorLinear)" $ whnfIO $
stToIO $ iterTargetFullST_ defRowMajorLinear trivialScheduler_ 0 sz seqAction
, bench "iterFullST (RowMajorLinear)" $ whnfIO $
stToIO $ iterFullST defRowMajorLinear trivialScheduler_ zeroIndex sz () noopSplit seqAction
]

noopSplit :: () -> ST s ((), ())
noopSplit () = pure ((), ())

iterStrideBench :: (Num ix, Index ix) => Stride ix -> Sz ix -> Benchmark
iterStrideBench !stride !sz =
bgroup
(show (typeOf sz))
[ bench "RowMajor" $
[ bench "iterM" $
whnfIO $
() <$ iterM zeroIndex (unSz sz * s) s (<) (0 :: Int) (\ix i -> (i + 1) <$ seqAction i ix)
, bench "iterTargetFullWithStrideST_ (RowMajor)" $
whnfIO $
stToIO $ iterTargetFullWithStrideST_ defRowMajor trivialScheduler_ 0 sz stride seqAction
, bench "iterM" $
whnfIO $ () <$ iterM zeroIndex (unSz sz * s) s (<) 0 (\ix i -> (i + 1) <$ seqAction i ix)
, bench "RowMajorLinear" $
, bench "iterTargetFullWithStrideST_ (RowMajorLinear)" $
whnfIO $
stToIO $ iterTargetFullWithStrideST_ defRowMajorLinear trivialScheduler_ 0 sz stride seqAction
]
Expand All @@ -90,16 +108,34 @@ iterFullBenchPar :: Index ix => Sz ix -> Benchmark
iterFullBenchPar sz =
bgroup
(show (typeOf sz))
[ bench "RowMajor" $ whnfIO $
[ bench "iterTargetFullST_ (RowMajor)" $ whnfIO $
withMassivScheduler_ Par $ \ scheduler ->
stToIO $ iterTargetFullST_ defRowMajor scheduler 0 sz seqAction
, bench "RowMajorLinear" $ whnfIO $
, bench "iterFullST (RowMajor)" $ whnfIO $
withMassivScheduler_ Par $ \ scheduler ->
stToIO $ iterFullST defRowMajor scheduler zeroIndex sz () noopSplit seqAction
, bench "iterTargetFullST_ (RowMajorLinear)" $ whnfIO $
withMassivScheduler_ Par $ \ scheduler ->
stToIO $ iterTargetFullST_ defRowMajorLinear scheduler 0 sz seqAction
, bench "iterFullST (RowMajorLinear)" $ whnfIO $
withMassivScheduler_ Par $ \ scheduler ->
stToIO $ iterFullST defRowMajorLinear scheduler zeroIndex sz () noopSplit seqAction
]

iterStrideBenchPar :: (Num ix, Index ix) => Stride ix -> Sz ix -> Benchmark
iterStrideBenchPar stride sz =
bgroup
(show (typeOf sz))
[ bench "iterTargetFullWithStrideST_ (RowMajor)" $ whnfIO $
withMassivScheduler_ Par $ \ scheduler ->
stToIO $ iterTargetFullWithStrideST_ defRowMajor scheduler 0 sz stride seqAction
, bench "iterTargetFullWithStrideST_ (RowMajorLinear)" $ whnfIO $
withMassivScheduler_ Par $ \ scheduler ->
stToIO $ iterTargetFullWithStrideST_ defRowMajorLinear scheduler 0 sz stride seqAction
]

seqAction :: Monad m => Int -> ix -> m ()
seqAction i ix = i `seq` ix `seq` pure ()
seqAction :: Monad m => a -> b -> m ()
seqAction a b = a `seq` b `seq` pure ()
{-# INLINE seqAction #-}


Expand Down
1 change: 0 additions & 1 deletion massiv/src/Data/Massiv/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ module Data.Massiv.Core
, initWorkerStates
, scheduleWork
, scheduleWork_
, withMassivScheduler_
, module Data.Massiv.Core.Index
-- * Numeric
, FoldNumeric
Expand Down
18 changes: 1 addition & 17 deletions massiv/src/Data/Massiv/Core/Common.hs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ module Data.Massiv.Core.Common
, numWorkers
, scheduleWork
, scheduleWork_
, withMassivScheduler_
, WorkerStates
, unsafeRead
, unsafeWrite
Expand Down Expand Up @@ -99,9 +98,7 @@ import Control.Monad.IO.Unlift (MonadIO(liftIO), MonadUnliftIO(..))
import Control.Monad.Primitive
import Control.Monad.ST
import Control.Scheduler (Comp(..), Scheduler, WorkerStates, numWorkers,
scheduleWork, scheduleWork_, trivialScheduler_,
withScheduler_)
import Control.Scheduler.Global
scheduleWork, scheduleWork_, trivialScheduler_)
import GHC.Exts (IsList)
import Data.Massiv.Core.Exception
import Data.Massiv.Core.Index
Expand Down Expand Up @@ -667,19 +664,6 @@ unsafeDefaultLinearShrink marr sz = do
pure marr'
{-# INLINE unsafeDefaultLinearShrink #-}


-- | Selects an optimal scheduler for the supplied strategy, but it works only in `IO`
--
-- @since 1.0.0
withMassivScheduler_ :: Comp -> (Scheduler RealWorld () -> IO ()) -> IO ()
withMassivScheduler_ comp f =
case comp of
Par -> withGlobalScheduler_ globalScheduler f
Seq -> f trivialScheduler_
_ -> withScheduler_ comp f
{-# INLINE withMassivScheduler_ #-}


-- | Read an array element
--
-- @since 0.1.0
Expand Down
125 changes: 58 additions & 67 deletions massiv/src/Data/Massiv/Core/Index/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -650,14 +650,14 @@ class ( Eq ix
{-# INLINE iterM #-}

iterRowMajorST :: Int -- ^ Scheduler multiplying factor. Must be positive
-> Scheduler s a -- ^ Scheduler to use
-> ix -- ^ Start index
-> ix -- ^ Stride
-> Sz ix -- ^ Size
-> a -- ^ Initial accumulator
-> (a -> ST s (a, a)) -- ^ Function that splits accumulator for each scheduled job.
-> (ix -> a -> ST s a) -- ^ Action
-> ST s ()
-> Scheduler s a -- ^ Scheduler to use
-> ix -- ^ Start index
-> ix -- ^ Stride
-> Sz ix -- ^ Size
-> a -- ^ Initial accumulator
-> (a -> ST s (a, a)) -- ^ Function that splits accumulator for each scheduled job.
-> (ix -> a -> ST s a) -- ^ Action
-> ST s ()
default iterRowMajorST :: Index (Lower ix)
=> Int
-> Scheduler s a
Expand All @@ -669,26 +669,20 @@ class ( Eq ix
-> (ix -> a -> ST s a)
-> ST s ()
iterRowMajorST !fact scheduler ixStart ixStride sz initAcc splitAcc f = do
let !(!n, !nL) = unconsDim (unSz sz)
!(!start, !ixL) = unconsDim ixStart
!(!stride, !sL) = unconsDim ixStride
!nw = numWorkers scheduler
if nw == 1 || n == 0
then
scheduleWork scheduler $ iterM ixStart (unSz sz) ixStride (<) initAcc f
else
if fact > 1 && n < nw * fact
then
let newFact = 1 + (fact `quot` n)
szL = SafeSz nL
in loopA_ start (< start + n * stride) (+ stride) $ \j ->
iterRowMajorST newFact scheduler ixL sL szL initAcc splitAcc $ \jl ->
f (consDim j jl)
else
splitWorkWithFactorST fact scheduler start stride n initAcc splitAcc $
\ _ _ chunkStartAdj chunkStopAdj acc ->
loopM chunkStartAdj (< chunkStopAdj) (+ stride) acc $ \j a ->
iterM ixL nL sL (<) a $ \jl -> f (consDim j jl)
let !(SafeSz n, szL@(SafeSz nL)) = unconsSz sz
when (n > 0) $ do
let !(!start, !ixL) = unconsDim ixStart
!(!stride, !sL) = unconsDim ixStride
if fact > 1 && n < numWorkers scheduler * fact
then do
let !newFact = 1 + (fact `quot` n)
loopA_ start (< start + n * stride) (+ stride) $ \j ->
iterRowMajorST newFact scheduler ixL sL szL initAcc splitAcc (f . consDim j)
else
splitWorkWithFactorST fact scheduler start stride n initAcc splitAcc $
\ _ _ chunkStartAdj chunkStopAdj acc ->
loopM chunkStartAdj (< chunkStopAdj) (+ stride) acc $ \j a ->
iterM ixL nL sL (<) a (f . consDim j)

-- Initial implementation:
-- let !(!n, !nL) = unconsDim (unSz sz)
Expand Down Expand Up @@ -827,28 +821,23 @@ class ( Eq ix
-> (Ix1 -> ix -> a -> ST s a)
-> ST s ()
iterTargetRowMajorAccST !iAcc !fact scheduler iStart sz ixStart ixStride initAcc splitAcc f = do
let !(SafeSz n, !nL) = unconsSz sz
!(!start, !ixL) = unconsDim ixStart
!(!stride, !sL) = unconsDim ixStride
!iAccL = iAcc * n
!nw = numWorkers scheduler
if nw == 1 || n == 0
then
scheduleWork scheduler $
iterTargetRowMajorAccM iAcc iStart sz ixStart ixStride initAcc f
else
-- FIXME: slower for manifest arrays, but faster for D (resized)
if fact > 1 && n < nw * fact
then
let newFact = 1 + (fact `quot` n)
in iloopA_ iAccL start (< start + n * stride) (+ stride) $ \k j ->
iterTargetRowMajorAccST k newFact scheduler iStart nL ixL sL initAcc splitAcc $
\i jl -> f i (consDim j jl)
else
splitWorkWithFactorST fact scheduler start stride n initAcc splitAcc $
\ chunkStart _ chunkStartAdj chunkStopAdj acc ->
iloopM (iAccL + chunkStart) chunkStartAdj (< chunkStopAdj) (+ stride) acc $ \k j a ->
iterTargetRowMajorAccM k iStart nL ixL sL a $ \i jl -> f i (consDim j jl)
let !(SafeSz n, nL) = unconsSz sz
when (n > 0) $ do
let !(!start, !ixL) = unconsDim ixStart
!(!stride, !sL) = unconsDim ixStride
!iAccL = iAcc * n
-- FIXME: slower for manifest arrays, but faster for D (resized)
if fact > 1 && n < numWorkers scheduler * fact
then do
let newFact = 1 + (fact `quot` n)
iloopA_ iAccL start (< start + n * stride) (+ stride) $ \k j ->
iterTargetRowMajorAccST k newFact scheduler iStart nL ixL sL initAcc splitAcc $
\i -> f i . consDim j
else
splitWorkWithFactorST fact scheduler start stride n initAcc splitAcc $
\ chunkStart _ chunkStartAdj chunkStopAdj acc ->
iloopM (iAccL + chunkStart) chunkStartAdj (< chunkStopAdj) (+ stride) acc $ \k j a ->
iterTargetRowMajorAccM k iStart nL ixL sL a $ \i -> f i . consDim j

-- Initial implementation:
-- !n = n' * stride
Expand Down Expand Up @@ -897,23 +886,25 @@ class ( Eq ix
-> (Ix1 -> ix -> a -> ST s a)
-> ST s ()
iterTargetRowMajorAccST_ !iAcc !fact scheduler iStart sz ixStart ixStride initAcc splitAcc f = do
let !(SafeSz n, !nL) = unconsSz sz
!(!start, !ixL) = unconsDim ixStart
!(!stride, !sL) = unconsDim ixStride
!iAccL = iAcc * n
-- TODO: Attempt to optimize:
-- if n < numWorkers scheduler
-- then
-- void $ iloopM iAccL start (< n * stride) (+ stride) initAcc $ \k j acc -> do
-- (accCur, accNext) <- splitAcc acc
-- iterTargetRowMajorAccST_ k 1 scheduler iStart nL ixL sL accCur splitAcc $ \i jl ->
-- f i (consDim j jl)
-- pure accNext
-- else
splitWorkWithFactorST fact scheduler start stride n initAcc splitAcc $
\ chunkStart _ chunkStartAdj chunkStopAdj acc ->
void $ iloopM (iAccL + chunkStart) chunkStartAdj (< chunkStopAdj) (+ stride) acc $ \k j a ->
iterTargetRowMajorAccM k iStart nL ixL sL a $ \i jl -> f i (consDim j jl)
let !(SafeSz n, szL) = unconsSz sz
when (n > 0) $ do
let !(!start, !ixL) = unconsDim ixStart
!(!stride, !sL) = unconsDim ixStride
!iAccL = iAcc * n
if fact > 1 && n < numWorkers scheduler * fact
then do
let !newFact = 1 + (fact `quot` n)
void $ iloopM iAccL start (< n * stride) (+ stride) initAcc $ \k j acc -> do
(accCur, accNext) <- splitAcc acc
iterTargetRowMajorAccST_ k newFact scheduler iStart szL ixL sL accCur splitAcc $ \i ->
f i . consDim j
pure accNext
else
splitWorkWithFactorST fact scheduler start stride n initAcc splitAcc $
\ chunkStart _ chunkStartAdj chunkStopAdj acc ->
void $
iloopM (iAccL + chunkStart) chunkStartAdj (< chunkStopAdj) (+ stride) acc $ \k j a ->
iterTargetRowMajorAccM k iStart szL ixL sL a $ \i -> f i . consDim j

-- Initial implementation:
-- let !(SafeSz n', !nL) = unconsSz szEnd
Expand Down
53 changes: 36 additions & 17 deletions massiv/src/Data/Massiv/Core/Loop.hs
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,21 @@ module Data.Massiv.Core.Loop
, stepStartAdjust
-- * Experimental
, splitWorkWithFactorST
, scheduleMassivWork
, withMassivScheduler_
) where

import Data.Functor.Identity
import Control.Monad
import Control.Monad.IO.Unlift
import Control.Monad (void, when)
import Control.Monad.IO.Unlift (MonadUnliftIO(..))
import Control.Monad.Primitive
import Control.Monad.ST
import Control.Scheduler
import Control.Monad.ST (ST)
import Control.Scheduler (Comp(..), Scheduler, SchedulerWS,
numWorkers, scheduleWork, scheduleWorkState_,
scheduleWork_, trivialScheduler_, unwrapSchedulerWS,
withScheduler_)
import Control.Scheduler.Global (globalScheduler, withGlobalScheduler_)
import Data.Coerce
import Data.Functor.Identity

-- | Efficient loop with an accumulator
--
Expand Down Expand Up @@ -328,29 +334,19 @@ splitWorkWithFactorST fact scheduler start step totalLength initAcc splitAcc' ac
accSlack <-
loopM 0 (< slackStart) (+ chunkLength) initAcc $ \ !chunkStart !acc -> do
(accCur, accNext) <- splitAcc acc
scheduleWork' scheduler $ do
scheduleMassivWork scheduler $ do
let !chunkStartAdj = start + chunkStart * step
!chunkStopAdj = chunkStartAdj + chunkLength * step
action chunkStart chunkLength chunkStartAdj chunkStopAdj accCur
pure accNext
let !slackLength = totalLength - slackStart
when (slackLength > 0) $
scheduleWork' scheduler $ do
scheduleMassivWork scheduler $ do
let !slackStartAdj = start + slackStart * step
!slackStopAdj = slackStartAdj + slackLength * step
action slackStart slackLength slackStartAdj slackStopAdj accSlack
{-# INLINE splitWorkWithFactorST #-}

scheduleWork' :: PrimBase m => Scheduler (PrimState m) a -> m a -> m ()
scheduleWork' = scheduleWork
{-# INLINE[0] scheduleWork' #-}

{-# RULES
"scheduleWork/scheduleWork_/ST" forall (scheduler :: Scheduler s ()) (action :: ST s ()) . scheduleWork' scheduler action = scheduleWork_ scheduler action
"scheduleWork/scheduleWork_/IO" forall (scheduler :: Scheduler RealWorld ()) (action :: IO ()) . scheduleWork' scheduler action = scheduleWork_ scheduler action
#-}


-- | Linear iterator that supports multiplying factor
--
-- @since 1.0.2
Expand Down Expand Up @@ -432,3 +428,26 @@ splitNumChunks splitAcc' fact nw totalLength =
stepStartAdjust :: Int -> Int -> Int
stepStartAdjust step ix = ix + ((step - (ix `mod` step)) `mod` step)
{-# INLINE stepStartAdjust #-}


-- | Internal version of a `scheduleWork` that will be replaced by
-- `scheduleWork_` by the compiler whenever action produces `()`
scheduleMassivWork :: PrimBase m => Scheduler (PrimState m) a -> m a -> m ()
scheduleMassivWork = scheduleWork
{-# INLINE[0] scheduleMassivWork #-}

{-# RULES
"scheduleWork/scheduleWork_/ST" forall (scheduler :: Scheduler s ()) (action :: ST s ()) . scheduleMassivWork scheduler action = scheduleWork_ scheduler action
"scheduleWork/scheduleWork_/IO" forall (scheduler :: Scheduler RealWorld ()) (action :: IO ()) . scheduleMassivWork scheduler action = scheduleWork_ scheduler action
#-}

-- | Selects an optimal scheduler for the supplied strategy, but it works only in `IO`
--
-- @since 1.0.0
withMassivScheduler_ :: Comp -> (Scheduler RealWorld () -> IO ()) -> IO ()
withMassivScheduler_ comp f =
case comp of
Par -> withGlobalScheduler_ globalScheduler f
Seq -> f trivialScheduler_
_ -> withScheduler_ comp f
{-# INLINE withMassivScheduler_ #-}

0 comments on commit f96d02a

Please sign in to comment.