Skip to content

Commit

Permalink
Add a test for testing early termination and fix implementations to p…
Browse files Browse the repository at this point in the history
…ass.
  • Loading branch information
merijn committed Sep 18, 2018
1 parent 5551e8e commit 5047f01
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 17 deletions.
26 changes: 24 additions & 2 deletions broadcast-chan-tests/BroadcastChan/Test.hs
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ import Control.Concurrent (threadDelay)
import Control.Concurrent.Async (wait, withAsync)
import Control.Concurrent.MVar
import Control.Concurrent.STM
import Control.Monad (void)
import Control.Monad.IO.Class (MonadIO(liftIO))
import Control.Exception (Exception, try)
import Control.Exception (Exception, throwIO, try)
import Data.Bifunctor (second)
import Data.List (sort)
import Data.Map (Map)
Expand Down Expand Up @@ -222,6 +223,26 @@ outputTest seqSink parSink threads inputs label =
parTest :: Handle -> IO r
parTest hndl = parSink inputs (doPrint hndl) threads

data TestException = TestException deriving (Eq, Show, Typeable)
instance Exception TestException

exceptionTests
:: (Eq r, Show r)
=> ([Int] -> (Int -> IO Int) -> IO r)
-> (Handler IO Int -> [Int] -> (Int -> IO Int) -> Int -> IO r)
-> TestTree
exceptionTests _seqImpl parImpl = testGroup "exceptions" $
[ testCase "termination" . expect TestException . void $
withSystemTempFile "terminate.out" $ \_ hndl ->
parImpl (Simple Terminate) inputs (dropEven hndl) 4
]
where
inputs = [1..100]

dropEven hnd n
| even n = throwIO TestException
| otherwise = doPrint hnd n

newtype SlowTests = SlowTests Bool
deriving (Eq, Ord, Typeable)

Expand Down Expand Up @@ -258,7 +279,7 @@ genStreamTests
genStreamTests name f g = askOption $ \(SlowTests slow) ->
withResource (newTVarIO M.empty) (const $ return ()) $ \getCache ->
let
testTree = growTree (Just "/") testGroup
testTree = growTree (Just ".") testGroup
params
| slow = simpleParam "threads" [1,2,5]
. derivedParam (enumFromTo 0) "inputs" [600]
Expand All @@ -269,6 +290,7 @@ genStreamTests name f g = askOption $ \(SlowTests slow) ->
in testGroup name
[ testTree "output" (outputTest f (g term)) params
, testTree "speedup" (speedupTest getCache f (g term)) $ params . pause
, exceptionTests f g
]
where
term = Simple Terminate
Expand Down
36 changes: 21 additions & 15 deletions broadcast-chan/BroadcastChan/Extra.hs
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,12 @@ module BroadcastChan.Extra
#if !MIN_VERSION_base(4,8,0)
import Control.Applicative ((<*))
#endif
import Control.Concurrent
(ThreadId, forkFinally, killThread, mkWeakThreadId, myThreadId)
import Control.Concurrent (ThreadId, forkFinally, mkWeakThreadId, myThreadId)
import Control.Concurrent.MVar
import Control.Concurrent.QSem
import Control.Concurrent.QSemN
import Control.Exception
(Exception(..), SomeException(..), catch, mask_, throwIO, throwTo)
import Control.Exception (Exception(..), SomeException(..))
import qualified Control.Exception as Exc
import Control.Monad ((>=>), replicateM, void)
import Control.Monad.IO.Unlift (MonadIO(..))
import Data.Typeable (Typeable)
Expand All @@ -51,7 +50,7 @@ import BroadcastChan.Internal
unsafeWriteBChan :: MonadIO m => BroadcastChan In a -> a -> m ()
unsafeWriteBChan (BChan writeVar) val = liftIO $ do
new_hole <- newEmptyMVar
mask_ $ do
Exc.mask_ $ do
old_hole <- takeMVar writeVar
-- old_hole is only full if the channel was previously closed
item <- tryTakeMVar old_hole
Expand Down Expand Up @@ -126,12 +125,10 @@ parallelCore hndl threads f = liftIO $ do
simpleHandler val exc act = case act of
Drop -> return ()
Retry -> unsafeWriteBChan inChanIn val
Terminate -> do
throwTo originTid exc
myThreadId >>= killThread
Terminate -> Exc.throwIO exc

handler :: a -> SomeException -> IO ()
handler _ exc | Just Shutdown <- fromException exc = throwIO exc
handler _ exc | Just Shutdown <- fromException exc = Exc.throwIO exc
handler val exc = case hndl of
Simple a -> simpleHandler val exc a
Handle h -> h val exc >>= simpleHandler val exc
Expand All @@ -142,17 +139,26 @@ parallelCore hndl threads f = liftIO $ do
case x of
Nothing -> signalQSemN endSem 1
Just a -> do
f a `catch` handler a
f a `Exc.catch` handler a
processInput

allocate :: IO [Weak ThreadId]
allocate = liftIO $ do
tids <- replicateM threads $
forkFinally processInput (\_ -> signalQSemN shutdownSem 1)
tids <- replicateM threads . forkFinally processInput $ \exit -> do
signalQSemN shutdownSem 1
case exit of
Left exc
| Just Shutdown <- fromException exc -> return ()
| otherwise ->
Exc.throwTo originTid exc `Exc.catch` shutdownHandler
Right () -> return ()

mapM mkWeakThreadId tids
where
shutdownHandler Shutdown = return ()

cleanup :: [Weak ThreadId] -> IO ()
cleanup threadIds = liftIO $ do
cleanup threadIds = liftIO . Exc.uninterruptibleMask_ $ do
mapM_ killWeakThread threadIds
waitQSemN shutdownSem threads

Expand All @@ -168,7 +174,7 @@ parallelCore hndl threads f = liftIO $ do
tid <- deRefWeak wTid
case tid of
Nothing -> return ()
Just t -> throwTo t Shutdown
Just t -> Exc.throwTo t Shutdown

-- | Sets up parallel processing.
--
Expand Down Expand Up @@ -231,7 +237,7 @@ runParallel yielder hndl threads work pipe = do

action :: n r
action = do
result <- pipe process queueAndYield
result <- pipe (liftIO . bufferValue) queueAndYield
wait
closeBChan outChanIn
finish result
Expand Down

0 comments on commit 5047f01

Please sign in to comment.