Skip to content

Commit

Permalink
MonadAsync: added various interfaces
Browse files Browse the repository at this point in the history
* `asyncBound`
* `asyncOn`
* `asyncOnWithUnmask`
* `withAsyncBound`
* `withAsyncOn`
* `withAsyncWithUnmask`
* `withAsyncOnWithUnmask`
* `compareAsyncs`
  • Loading branch information
coot committed Sep 23, 2022
1 parent 3efb144 commit 74dc225
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 2 deletions.
86 changes: 84 additions & 2 deletions io-classes/src/Control/Monad/Class/MonadAsync.hs
Expand Up @@ -53,15 +53,19 @@ class ( MonadSTM m
, MonadThread m
) => MonadAsync m where

{-# MINIMAL async, asyncThreadId, cancel, cancelWith, asyncWithUnmask,
waitCatchSTM, pollSTM #-}
{-# MINIMAL async, asyncBound, asyncOn, asyncThreadId, cancel, cancelWith,
asyncWithUnmask, asyncOnWithUnmask, waitCatchSTM, pollSTM #-}

-- | An asynchronous action
type Async m = (async :: Type -> Type) | async -> m

async :: m a -> m (Async m a)
asyncBound :: m a -> m (Async m a)
asyncOn :: Int -> m a -> m (Async m a)
asyncThreadId :: Async m a -> ThreadId m
withAsync :: m a -> (Async m a -> m b) -> m b
withAsyncBound :: m a -> (Async m a -> m b) -> m b
withAsyncOn :: Int -> m a -> (Async m a -> m b) -> m b

waitSTM :: Async m a -> STM m a
pollSTM :: Async m a -> STM m (Maybe (Either SomeException a))
Expand Down Expand Up @@ -144,9 +148,23 @@ class ( MonadSTM m
concurrently_ :: m a -> m b -> m ()

asyncWithUnmask :: ((forall b . m b -> m b) -> m a) -> m (Async m a)
asyncOnWithUnmask :: Int -> ((forall b . m b -> m b) -> m a) -> m (Async m a)
withAsyncWithUnmask :: ((forall c. m c -> m c) -> m a) -> (Async m a -> m b) -> m b
withAsyncOnWithUnmask :: Int -> ((forall c. m c -> m c) -> m a) -> (Async m a -> m b) -> m b

compareAsyncs :: Async m a -> Async m b -> Ordering

-- default implementations
default withAsync :: MonadMask m => m a -> (Async m a -> m b) -> m b
default withAsyncBound:: MonadMask m => m a -> (Async m a -> m b) -> m b
default withAsyncOn :: MonadMask m => Int -> m a -> (Async m a -> m b) -> m b
default withAsyncWithUnmask
:: MonadMask m => ((forall c. m c -> m c) -> m a)
-> (Async m a -> m b) -> m b
default withAsyncOnWithUnmask
:: MonadMask m => Int
-> ((forall c. m c -> m c) -> m a)
-> (Async m a -> m b) -> m b
default uninterruptibleCancel
:: MonadMask m => Async m a -> m ()
default waitAnyCancel :: MonadThrow m => [Async m a] -> m (Async m a, a)
Expand All @@ -157,12 +175,35 @@ class ( MonadSTM m
default waitEitherCatchCancel :: MonadThrow m => Async m a -> Async m b
-> m (Either (Either SomeException a)
(Either SomeException b))
default compareAsyncs :: Ord (ThreadId m)
=> Async m a -> Async m b -> Ordering

withAsync action inner = mask $ \restore -> do
a <- async (restore action)
restore (inner a)
`finally` uninterruptibleCancel a

withAsyncBound action inner = mask $ \restore -> do
a <- asyncBound (restore action)
restore (inner a)
`finally` uninterruptibleCancel a

withAsyncOn n action inner = mask $ \restore -> do
a <- asyncOn n (restore action)
restore (inner a)
`finally` uninterruptibleCancel a


withAsyncWithUnmask action inner = mask $ \restore -> do
a <- asyncWithUnmask action
restore (inner a)
`finally` uninterruptibleCancel a

withAsyncOnWithUnmask n action inner = mask $ \restore -> do
a <- asyncOnWithUnmask n action
restore (inner a)
`finally` uninterruptibleCancel a

wait = atomically . waitSTM
poll = atomically . pollSTM
waitCatch = atomically . waitCatchSTM
Expand Down Expand Up @@ -202,6 +243,8 @@ class ( MonadSTM m

concurrently_ left right = void $ concurrently left right

compareAsyncs a b = asyncThreadId a `compare` asyncThreadId b

-- | Similar to 'Async.Concurrently' but which works for any 'MonadAsync'
-- instance.
--
Expand Down Expand Up @@ -265,8 +308,12 @@ instance MonadAsync IO where
type Async IO = Async.Async

async = Async.async
asyncBound = Async.asyncBound
asyncOn = Async.asyncOn
asyncThreadId = Async.asyncThreadId
withAsync = Async.withAsync
withAsyncBound = Async.withAsyncBound
withAsyncOn = Async.withAsyncOn

waitSTM = Async.waitSTM
pollSTM = Async.pollSTM
Expand Down Expand Up @@ -303,6 +350,11 @@ instance MonadAsync IO where
concurrently_ = Async.concurrently_

asyncWithUnmask = Async.asyncWithUnmask
asyncOnWithUnmask = Async.asyncOnWithUnmask
withAsyncWithUnmask = Async.withAsyncWithUnmask
withAsyncOnWithUnmask = Async.withAsyncOnWithUnmask

compareAsyncs = Async.compareAsyncs


--
Expand Down Expand Up @@ -410,15 +462,45 @@ instance ( MonadAsync m
asyncThreadId (WrappedAsync a) = asyncThreadId a

async (ReaderT ma) = ReaderT $ \r -> WrappedAsync <$> async (ma r)
asyncBound (ReaderT ma) = ReaderT $ \r -> WrappedAsync <$> asyncBound (ma r)
asyncOn n (ReaderT ma) = ReaderT $ \r -> WrappedAsync <$> asyncOn n (ma r)
withAsync (ReaderT ma) f = ReaderT $ \r -> withAsync (ma r)
$ \a -> runReaderT (f (WrappedAsync a)) r
withAsyncBound (ReaderT ma) f = ReaderT $ \r -> withAsyncBound (ma r)
$ \a -> runReaderT (f (WrappedAsync a)) r
withAsyncOn n (ReaderT ma) f = ReaderT $ \r -> withAsyncOn n (ma r)
$ \a -> runReaderT (f (WrappedAsync a)) r

asyncWithUnmask f = ReaderT $ \r -> fmap WrappedAsync
$ asyncWithUnmask
$ \unmask -> runReaderT (f (liftF unmask)) r
where
liftF :: (m a -> m a) -> ReaderT r m a -> ReaderT r m a
liftF g (ReaderT r) = ReaderT (g . r)

asyncOnWithUnmask n f = ReaderT $ \r -> fmap WrappedAsync
$ asyncOnWithUnmask n
$ \unmask -> runReaderT (f (liftF unmask)) r
where
liftF :: (m a -> m a) -> ReaderT r m a -> ReaderT r m a
liftF g (ReaderT r) = ReaderT (g . r)

withAsyncWithUnmask action f =
ReaderT $ \r -> withAsyncWithUnmask (\unmask -> case action (liftF unmask) of
ReaderT ma -> ma r)
$ \a -> runReaderT (f (WrappedAsync a)) r
where
liftF :: (m a -> m a) -> ReaderT r m a -> ReaderT r m a
liftF g (ReaderT r) = ReaderT (g . r)

withAsyncOnWithUnmask n action f =
ReaderT $ \r -> withAsyncOnWithUnmask n (\unmask -> case action (liftF unmask) of
ReaderT ma -> ma r)
$ \a -> runReaderT (f (WrappedAsync a)) r
where
liftF :: (m a -> m a) -> ReaderT r m a -> ReaderT r m a
liftF g (ReaderT r) = ReaderT (g . r)

waitCatchSTM = WrappedSTM . waitCatchSTM . unWrapAsync
pollSTM = WrappedSTM . pollSTM . unWrapAsync

Expand Down
3 changes: 3 additions & 0 deletions io-classes/src/Control/Monad/Class/MonadFork.hs
Expand Up @@ -32,6 +32,7 @@ class (Monad m, Eq (ThreadId m),
class MonadThread m => MonadFork m where

forkIO :: m () -> m (ThreadId m)
forkOn :: Int -> m () -> m (ThreadId m)
forkIOWithUnmask :: ((forall a. m a -> m a) -> m ()) -> m (ThreadId m)
throwTo :: Exception e => ThreadId m -> e -> m ()

Expand All @@ -56,6 +57,7 @@ instance MonadThread IO where

instance MonadFork IO where
forkIO = IO.forkIO
forkOn = IO.forkOn
forkIOWithUnmask = IO.forkIOWithUnmask
throwTo = IO.throwTo
killThread = IO.killThread
Expand All @@ -68,6 +70,7 @@ instance MonadThread m => MonadThread (ReaderT r m) where

instance MonadFork m => MonadFork (ReaderT e m) where
forkIO (ReaderT f) = ReaderT $ \e -> forkIO (f e)
forkOn n (ReaderT f) = ReaderT $ \e -> forkOn n (f e)
forkIOWithUnmask k = ReaderT $ \e -> forkIOWithUnmask $ \restore ->
let restore' :: ReaderT e m a -> ReaderT e m a
restore' (ReaderT f) = ReaderT $ restore . f
Expand Down
5 changes: 5 additions & 0 deletions io-sim/src/Control/Monad/IOSim/Types.hs
Expand Up @@ -369,6 +369,7 @@ instance MonadThread (IOSim s) where

instance MonadFork (IOSim s) where
forkIO task = IOSim $ oneShot $ \k -> Fork task k
forkOn _ task = IOSim $ oneShot $ \k -> Fork task k
forkIOWithUnmask f = forkIO (f unblock)
throwTo tid e = IOSim $ oneShot $ \k -> ThrowTo (toException e) tid (k ())
yield = IOSim $ oneShot $ \k -> YieldSim (k ())
Expand Down Expand Up @@ -470,6 +471,9 @@ instance MonadAsync (IOSim s) where
MonadSTM.labelTMVarIO var ("async-" ++ show tid)
return (Async tid (MonadSTM.readTMVar var))

asyncOn _ = async
asyncBound = async

asyncThreadId (Async tid _) = tid

waitCatchSTM (Async _ w) = w
Expand All @@ -479,6 +483,7 @@ instance MonadAsync (IOSim s) where
cancelWith a@(Async tid _) e = throwTo tid e <* waitCatch a

asyncWithUnmask k = async (k unblock)
asyncOnWithUnmask _ k = async (k unblock)

instance MonadST (IOSim s) where
withLiftST f = f liftST
Expand Down

0 comments on commit 74dc225

Please sign in to comment.