Skip to content

Commit

Permalink
Add runConnectedPeers driver
Browse files Browse the repository at this point in the history
And introduce a 'concurrently' utility like from async to make writing
this easier. This will be replaced with the proper version once we have
MonadAsync.

Also use it in runPipelinedPeer.
  • Loading branch information
dcoutts committed Feb 17, 2019
1 parent 946f3bd commit 5a51b32
Showing 1 changed file with 89 additions and 20 deletions.
109 changes: 89 additions & 20 deletions typed-protocols/src/Network/TypedProtocol/Driver.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ module Network.TypedProtocol.Driver (
-- * Pipelined peers
runPipelinedPeer,

-- * Connected peers
runConnectedPeers,

-- * Driver utilities
-- | This may be useful if you want to write your own driver.
runDecoderWithChannel,
Expand All @@ -33,6 +36,7 @@ import Network.TypedProtocol.Channel
import Network.TypedProtocol.Codec

import Control.Monad.Class.MonadSTM
import Control.Monad.Class.MonadThrow

import Numeric.Natural (Natural)

Expand Down Expand Up @@ -63,6 +67,10 @@ import Numeric.Natural (Natural)
--


--
-- Driver for normal peers
--

-- | Run a peer with the given channel via the given codec.
--
-- This runs the peer to completion (if the protocol allows for termination).
Expand Down Expand Up @@ -96,22 +104,9 @@ runPeer Codec{encode, decode} channel@Channel{send} = go Nothing
Left failure -> return (Left failure)


-- | Run a codec incremental decoder 'DecodeStep' against a channel. It also
-- takes any extra input data and returns any unused trailing data.
--
runDecoderWithChannel :: Monad m
=> Channel m bytes
-> Maybe bytes
-> DecodeStep bytes failure m a
-> m (Either failure (a, Maybe bytes))

runDecoderWithChannel Channel{recv} = go
where
go _ (DecodeDone x trailing) = return (Right (x, trailing))
go _ (DecodeFail failure) = return (Left failure)
go Nothing (DecodePartial k) = recv >>= k >>= go Nothing
go (Just trailing) (DecodePartial k) = k (Just trailing) >>= go Nothing

-- Driver for pipelined peers
--

-- | Run a pipelined peer with the given channel via the given codec.
--
Expand All @@ -122,7 +117,7 @@ runDecoderWithChannel Channel{recv} = go
--
runPipelinedPeer
:: forall ps (st :: ps) pr failure bytes m a.
MonadSTM m
(MonadSTM m, MonadCatch m)
=> Natural
-> Codec ps failure m bytes
-> Channel m bytes
Expand All @@ -131,10 +126,12 @@ runPipelinedPeer
runPipelinedPeer maxOutstanding codec channel (PeerPipelined peer) = do
receiveQueue <- atomically $ newTBQueue maxOutstanding
collectQueue <- atomically $ newTBQueue maxOutstanding
fork $ runPipelinedPeerReceiverQueue receiveQueue collectQueue
codec channel
runPipelinedPeerSender receiveQueue collectQueue
codec channel peer
((), x) <- runPipelinedPeerReceiverQueue receiveQueue collectQueue
codec channel
`concurrently`
runPipelinedPeerSender receiveQueue collectQueue
codec channel peer
return x
--TODO: manage the fork + exceptions here


Expand Down Expand Up @@ -223,3 +220,75 @@ runPipelinedPeerReceiver Codec{decode} channel = go
Right (SomeMessage msg, trailing') -> go trailing' (k msg)
Left _failure -> error "TODO: proper exceptions for runPipelinedPeer"


--
-- Utils
--

-- | Run a codec incremental decoder 'DecodeStep' against a channel. It also
-- takes any extra input data and returns any unused trailing data.
--
runDecoderWithChannel :: Monad m
=> Channel m bytes
-> Maybe bytes
-> DecodeStep bytes failure m a
-> m (Either failure (a, Maybe bytes))

runDecoderWithChannel Channel{recv} = go
where
go _ (DecodeDone x trailing) = return (Right (x, trailing))
go _ (DecodeFail failure) = return (Left failure)
go Nothing (DecodePartial k) = recv >>= k >>= go Nothing
go (Just trailing) (DecodePartial k) = k (Just trailing) >>= go Nothing


-- | Run two 'Peer's via a pair of connected 'Channel's and a common 'Codec'.
--
-- This is useful for tests and quick experiments.
--
-- The first argument is expected to create two channels that are connected,
-- for example 'createConnectedChannels'.
--
runConnectedPeers :: (MonadSTM m, MonadCatch m)
=> m (Channel m bytes, Channel m bytes)
-> Codec ps failure m bytes
-> Peer ps AsClient st m a
-> Peer ps AsServer st m b
-> m (Either failure (a, b))
runConnectedPeers createChannels codec client server = do
(clientChannel, serverChannel) <- createChannels
results <- runPeer codec clientChannel client
`concurrently`
runPeer codec serverChannel server

case results of
(Left err, _) -> return (Left err)
(_, Left err) -> return (Left err)
(Right x, Right y) -> return (Right (x,y))


-- TODO: replace with version from MonadAsync when available
concurrently :: (MonadSTM m, MonadCatch m)
=> m a -> m b -> m (a, b)
concurrently actionA actionB = do

resAVar <- newEmptyTMVarM
resBVar <- newEmptyTMVarM

fork $ try actionA >>= \x -> atomically (putTMVar resAVar x)
fork $ try actionB >>= \x -> atomically (putTMVar resBVar x)

res <- atomically $ do
mresA <- tryReadTMVar resAVar
mresB <- tryReadTMVar resBVar
case (mresA, mresB) of
(Nothing, Nothing) -> retry
(Just (Left e), _) -> return (Left (e :: SomeException))
(_, Just (Left e)) -> return (Left (e :: SomeException))

(Just (Right _), Nothing) -> retry
(Nothing ,Just (Right _)) -> retry
(Just (Right a), Just (Right b)) -> return (Right (a, b))

either throwM return res

0 comments on commit 5a51b32

Please sign in to comment.