diff --git a/Control/Concurrent/Async/Pool/Internal.hs b/Control/Concurrent/Async/Pool/Internal.hs index a859f98..252411b 100644 --- a/Control/Concurrent/Async/Pool/Internal.hs +++ b/Control/Concurrent/Async/Pool/Internal.hs @@ -10,7 +10,7 @@ import Control.Concurrent (ThreadId) import qualified Control.Concurrent.Async as Async (withAsync) import Control.Concurrent.Async.Pool.Async import Control.Concurrent.STM -import Control.Exception (SomeException, throwIO, finally) +import Control.Exception (SomeException, throwIO, finally, bracket_) import Control.Monad hiding (forM, forM_) import Control.Monad.Base import Control.Monad.IO.Class (MonadIO(..)) @@ -160,6 +160,9 @@ asyncAfterAll p parents t = atomically $ do asyncAfter :: TaskGroup -> Async b -> IO a -> IO (Async a) asyncAfter p parent = asyncAfterAll p [taskHandle parent] +extraWorkerWhileBlocked :: TaskGroup -> IO a -> IO a +extraWorkerWhileBlocked p = bracket_ (atomically $ modifyTVar' (avail p) (+ 1)) (atomically $ modifyTVar' (avail p) ((-) 1)) + -- | Helper function used by several of the variants of 'mapTasks' below. mapTasksWorker :: Traversable t => TaskGroup @@ -169,7 +172,7 @@ mapTasksWorker :: Traversable t -> IO (t c) mapTasksWorker p fs f g = do hs <- forM fs $ atomically . asyncUsing p rawForkIO - f $ forM hs g + extraWorkerWhileBlocked p $ f $ forM hs g -- | Execute a group of tasks within the given task group, returning the -- results in order. The order of execution is random, but the results are