Skip to content

Commit

Permalink
typed-protocols: provide ProtocolState and PeerHasAgency
Browse files Browse the repository at this point in the history
'SingProtocolState' is a 'ProtocolState' singleton. The type of
'ProtocolState' kind provides information about current state and its
agency.

'SingPeerHasAgency' is a 'PeerHasAgency' singleton.  The kinds
'PeerHasAgency' and 'ProtocolState' are isomorphic.  Like
'SingProtocolState' it gives access to the protocol state and agency,
but it allows to match only for 'ClientAgency' or 'ServerAgency', while
'SingProtocolState' allows to also match for 'NobodyAgency'.  Limiting
the possible pattern matches to only active agencies, is, for example,
useful in a 'Codec', where we know that the protocol must be in a state
whose agency is either on the client or server side.

This patch also modifies 'Yield', 'YieldPipelined' and 'Message'
constructors constraints.   This prepares for future changes where we
will need to deduce that a sent message changes the agency.
  • Loading branch information
coot committed Sep 26, 2021
1 parent fe9b32a commit 2032b29
Show file tree
Hide file tree
Showing 16 changed files with 307 additions and 99 deletions.
49 changes: 31 additions & 18 deletions ouroboros-network-framework/test/Test/Ouroboros/Network/Driver.hs
Expand Up @@ -12,8 +12,12 @@

module Test.Ouroboros.Network.Driver (tests) where

import Data.List (intercalate)

import Network.TypedProtocol.Core
import Network.TypedProtocol.Codec
import Network.TypedProtocol.Peer.Client (Client)
import Network.TypedProtocol.Peer.Server (Server)

import Ouroboros.Network.Channel
import Ouroboros.Network.Driver
Expand All @@ -32,7 +36,9 @@ import Control.Monad.Class.MonadFork
import Control.Monad.Class.MonadTime
import Control.Monad.Class.MonadTimer
import Control.Monad.Class.MonadThrow
import Control.Monad.IOSim (runSimOrThrow)
import Control.Monad.Class.MonadSay
import Control.Monad.IOSim
import Control.Exception (throw)
import Control.Tracer

import Test.Ouroboros.Network.Orphans ()
Expand Down Expand Up @@ -63,10 +69,10 @@ byteLimitsReqResp
-> ProtocolSizeLimits (ReqResp req resp) String
byteLimitsReqResp limit = ProtocolSizeLimits stateToLimit (fromIntegral . length)
where
stateToLimit :: forall (pr :: PeerRole) (st :: ReqResp req resp).
PeerHasAgency pr st -> Word
stateToLimit (ClientAgency TokIdle) = limit
stateToLimit (ServerAgency TokBusy) = limit
stateToLimit :: forall (st :: ReqResp req resp).
SingPeerHasAgency st -> Word
stateToLimit (SingClientHasAgency SingIdle) = limit
stateToLimit (SingServerHasAgency SingBusy) = limit


serverTimeout :: DiffTime
Expand All @@ -76,19 +82,19 @@ serverTimeout = 0.2 -- 200 ms
timeLimitsReqResp :: forall req resp. ProtocolTimeLimits (ReqResp req resp)
timeLimitsReqResp = ProtocolTimeLimits stateToLimit
where
stateToLimit :: forall (pr :: PeerRole) (st :: ReqResp req resp).
PeerHasAgency pr st -> Maybe DiffTime
stateToLimit (ClientAgency TokIdle) = Just serverTimeout
stateToLimit (ServerAgency TokBusy) = Just serverTimeout
stateToLimit :: forall (st :: ReqResp req resp).
SingPeerHasAgency st -> Maybe DiffTime
stateToLimit (SingClientHasAgency SingIdle) = Just serverTimeout
stateToLimit (SingServerHasAgency SingBusy) = Just serverTimeout

-- Unlimited Time
timeUnLimitsReqResp :: forall req resp. ProtocolTimeLimits (ReqResp req resp)
timeUnLimitsReqResp = ProtocolTimeLimits stateToLimit
where
stateToLimit :: forall (pr :: PeerRole) (st :: ReqResp req resp).
PeerHasAgency pr st -> Maybe DiffTime
stateToLimit (ClientAgency TokIdle) = Nothing
stateToLimit (ServerAgency TokBusy) = Nothing
stateToLimit :: forall (st :: ReqResp req resp).
SingPeerHasAgency st -> Maybe DiffTime
stateToLimit (SingClientHasAgency SingIdle) = Nothing
stateToLimit (SingServerHasAgency SingBusy) = Nothing


--
Expand Down Expand Up @@ -139,10 +145,10 @@ prop_runPeerWithLimits tracer limit reqPayloads = do
Nothing -> False

where
sendPeer :: Peer (ReqResp String ()) AsClient StIdle m [()]
sendPeer :: Client (ReqResp String ()) NonPipelined Empty StIdle m [()]
sendPeer = reqRespClientPeer $ reqRespClientMap $ map fst reqPayloads

recvPeer :: Peer (ReqResp String ()) AsServer StIdle m [DiffTime]
recvPeer :: Server (ReqResp String ()) NonPipelined Empty StIdle m [DiffTime]
recvPeer = reqRespServerPeer $ reqRespServerMapAccumL
(\a _ -> case a of
[] -> error "prop_runPeerWithLimits: empty list"
Expand All @@ -156,12 +162,14 @@ prop_runPeerWithLimits tracer limit reqPayloads = do
shouldFail :: [(String, DiffTime)] -> Maybe ShouldFail
shouldFail [] =
-- Check @MsgDone@ which is always sent
let msgDone = encode (codecReqResp @String @() @m) (ClientAgency TokIdle) MsgDone in
let msgDone = encode (codecReqResp @String @() @m)
MsgDone in
if length msgDone > fromIntegral limit
then Just ShouldExceededSizeLimit
else Nothing
shouldFail ((msg, delay):cmds) =
let msg' = encode (codecReqResp @String @() @m) (ClientAgency TokIdle) (MsgReq msg) in
let msg' = encode (codecReqResp @String @() @m)
(MsgReq msg) in
if length msg' > fromIntegral limit
then Just ShouldExceededSizeLimit
else if delay >= serverTimeout
Expand Down Expand Up @@ -213,7 +221,12 @@ prop_runPeerWithLimits_ST
-> Property
prop_runPeerWithLimits_ST (ReqRespPayloadWithLimit limit payload) =
tabulate "Limit Boundaries" (labelExamples limit payload) $
runSimOrThrow (prop_runPeerWithLimits nullTracer limit [payload])
let trace = runSimTrace (prop_runPeerWithLimits (Tracer (say . show)) limit [payload])
in counterexample (intercalate "\n" $ map show $ traceEvents trace)
$ case traceResult True trace of
Left e -> throw e
Right x -> x

where
labelExamples :: Word -> (String, DiffTime) -> [String]
labelExamples l (p,_) =
Expand Down
10 changes: 6 additions & 4 deletions typed-protocols-cborg/src/Network/TypedProtocol/Codec/CBOR.hs
@@ -1,4 +1,6 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MonoLocalBinds #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
Expand Down Expand Up @@ -46,11 +48,11 @@ mkCodecCborStrictBS
:: forall ps m. MonadST m

=> (forall (st :: ps) (st' :: ps).
SingI st
SingI (PeerHasAgency st)
=> Message ps st st' -> CBOR.Encoding)

-> (forall (st :: ps) s.
SingI st
SingI (PeerHasAgency st)
=> CBOR.Decoder s (SomeMessage st))

-> Codec ps DeserialiseFailure m BS.ByteString
Expand Down Expand Up @@ -100,11 +102,11 @@ mkCodecCborLazyBS
:: forall ps m. MonadST m

=> (forall (st :: ps) (st' :: ps).
SingI st
SingI (PeerHasAgency st)
=> Message ps st st' -> CBOR.Encoding)

-> (forall (st :: ps) s.
SingI st
SingI (PeerHasAgency st)
=> CBOR.Decoder s (SomeMessage st))

-> Codec ps CBOR.DeserialiseFailure m LBS.ByteString
Expand Down
Expand Up @@ -91,7 +91,7 @@ driverSimple tracer Codec{encode, decode} channel@Channel{send} =
Driver { sendMessage, recvMessage, tryRecvMessage, startDState = Nothing }
where
sendMessage :: forall (st :: ps) (st' :: ps).
SingI st
SingI (PeerHasAgency st)
=> (ReflRelativeAgency (StateAgency st)
WeHaveAgency
(Relative pr (StateAgency st)))
Expand All @@ -102,7 +102,7 @@ driverSimple tracer Codec{encode, decode} channel@Channel{send} =
traceWith tracer (TraceSendMsg (AnyMessage msg))

recvMessage :: forall (st :: ps).
SingI st
SingI (PeerHasAgency st)
=> (ReflRelativeAgency (StateAgency st)
TheyHaveAgency
(Relative pr (StateAgency st)))
Expand All @@ -123,7 +123,7 @@ driverSimple tracer Codec{encode, decode} channel@Channel{send} =
throwIO failure

tryRecvMessage :: forall (st :: ps).
SingI st
SingI (PeerHasAgency st)
=> (ReflRelativeAgency (StateAgency st)
TheyHaveAgency
(Relative pr (StateAgency st)))
Expand Down
@@ -1,5 +1,6 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE NamedFieldPuns #-}
Expand All @@ -8,6 +9,7 @@ module Network.TypedProtocol.PingPong.Codec where

import Data.Singletons

import Network.TypedProtocol.Core
import Network.TypedProtocol.Codec
import Network.TypedProtocol.PingPong.Type

Expand All @@ -26,14 +28,17 @@ codecPingPong =
encode MsgPong = "pong\n"

decode :: forall (st :: PingPong).
SingI st
SingI (PeerHasAgency st)
=> m (DecodeStep String CodecFailure m (SomeMessage st))
decode =
decodeTerminatedFrame '\n' $ \str trailing ->
case (sing :: Sing st, str) of
(SingBusy, "pong") -> DecodeDone (SomeMessage MsgPong) trailing
(SingIdle, "ping") -> DecodeDone (SomeMessage MsgPing) trailing
(SingIdle, "done") -> DecodeDone (SomeMessage MsgDone) trailing
case (sing :: Sing (PeerHasAgency st), str) of
(SingServerHasAgency SingBusy, "pong") ->
DecodeDone (SomeMessage MsgPong) trailing
(SingClientHasAgency SingIdle, "ping") ->
DecodeDone (SomeMessage MsgPing) trailing
(SingClientHasAgency SingIdle, "done") ->
DecodeDone (SomeMessage MsgDone) trailing

(_ , _ ) -> DecodeFail failure
where failure = CodecFailure ("unexpected server message: " ++ str)
Expand Down
@@ -1,5 +1,6 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
Expand Down Expand Up @@ -34,18 +35,15 @@ codecPingPong = mkCodecCborLazyBS encodeMsg decodeMsg
encodeMsg MsgDone = CBOR.encodeWord 2

decodeMsg :: forall s (st :: PingPong).
SingI st
SingI (PeerHasAgency st)
=> CBOR.Decoder s (SomeMessage st)
decodeMsg = do
key <- CBOR.decodeWord
case (sing :: Sing st, key) of
(SingIdle, 0) -> return $ SomeMessage MsgPing
(SingBusy, 1) -> return $ SomeMessage MsgPong
(SingIdle, 2) -> return $ SomeMessage MsgDone
case (sing :: Sing (PeerHasAgency st), key) of
(SingClientHasAgency SingIdle, 0) -> return $ SomeMessage MsgPing
(SingServerHasAgency SingBusy, 1) -> return $ SomeMessage MsgPong
(SingClientHasAgency SingIdle, 2) -> return $ SomeMessage MsgDone

-- TODO proper exceptions
(SingIdle, _) -> fail "codecPingPong.StIdle: unexpected key"
(SingBusy, _) -> fail "codecPingPong.StBusy: unexpected key"
(SingDone, _) -> fail "codecPingPong.StDone: unexpected key"


(SingClientHasAgency SingIdle, _) -> fail "codecPingPong.StIdle: unexpected key"
(SingServerHasAgency SingBusy, _) -> fail "codecPingPong.StBusy: unexpected key"
Expand Up @@ -10,6 +10,7 @@ module Network.TypedProtocol.ReqResp.Codec where

import Data.Singletons

import Network.TypedProtocol.Core
import Network.TypedProtocol.Codec
import Network.TypedProtocol.ReqResp.Type
import Network.TypedProtocol.PingPong.Codec (decodeTerminatedFrame)
Expand All @@ -33,17 +34,17 @@ codecReqResp =

decode :: forall req' resp' m'
(st :: ReqResp req' resp')
. (Monad m', SingI st, Read req', Read resp')
. (Monad m', SingI (PeerHasAgency st), Read req', Read resp')
=> m' (DecodeStep String CodecFailure m' (SomeMessage st))
decode =
decodeTerminatedFrame '\n' $ \str trailing ->
case (sing :: Sing st, break (==' ') str) of
(SingIdle, ("MsgReq", str'))
case (sing :: Sing (PeerHasAgency st), break (==' ') str) of
(SingClientHasAgency SingIdle, ("MsgReq", str'))
| Just resp <- readMaybe str'
-> DecodeDone (SomeMessage (MsgReq resp)) trailing
(SingIdle, ("MsgDone", ""))
(SingClientHasAgency SingIdle, ("MsgDone", ""))
-> DecodeDone (SomeMessage MsgDone) trailing
(SingBusy, ("MsgResp", str'))
(SingServerHasAgency SingBusy, ("MsgResp", str'))
| Just resp <- readMaybe str'
-> DecodeDone (SomeMessage (MsgResp resp)) trailing

Expand Down
@@ -1,5 +1,6 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
Expand Down Expand Up @@ -42,18 +43,17 @@ codecReqResp = mkCodecCborLazyBS encodeMsg decodeMsg
CBOR.encodeListLen 1 <> CBOR.encodeWord 2

decodeMsg :: forall s (st :: ReqResp req resp).
SingI st
SingI (PeerHasAgency st)
=> CBOR.Decoder s (SomeMessage st)
decodeMsg = do
_ <- CBOR.decodeListLen
key <- CBOR.decodeWord
case (sing :: Sing st, key) of
(SingIdle, 0) -> SomeMessage . MsgReq <$> CBOR.decode
(SingBusy, 1) -> SomeMessage . MsgResp <$> CBOR.decode
(SingIdle, 2) -> return $ SomeMessage MsgDone
case (sing :: Sing (PeerHasAgency st), key) of
(SingClientHasAgency SingIdle, 0) -> SomeMessage . MsgReq <$> CBOR.decode
(SingServerHasAgency SingBusy, 1) -> SomeMessage . MsgResp <$> CBOR.decode
(SingClientHasAgency SingIdle, 2) -> return $ SomeMessage MsgDone

-- TODO proper exceptions
(SingIdle, _) -> fail "codecReqResp.StIdle: unexpected key"
(SingBusy, _) -> fail "codecReqResp.StBusy: unexpected key"
(SingDone, _) -> fail "codecReqResp.StBusy: unexpected key"
(SingClientHasAgency SingIdle, _) -> fail "codecReqResp.StIdle: unexpected key"
(SingServerHasAgency SingBusy, _) -> fail "codecReqResp.StBusy: unexpected key"

Expand Up @@ -189,8 +189,10 @@ prop_connect (NonNegative n) =
(pingPongClientPeer (pingPongClientCount n))
(pingPongServerPeer pingPongServerCount))

of ((), n', TerminalStates SingDone ReflNobodyAgency
SingDone ReflNobodyAgency) ->
of ((), n', TerminalStates (SingProtocolState SingDone)
ReflNobodyAgency
(SingProtocolState SingDone)
ReflNobodyAgency) ->
n == n'


Expand All @@ -211,8 +213,10 @@ connect_pipelined client cs =
(pingPongClientPeerPipelined client)
(pingPongServerPeer pingPongServerCount))

of (reqResps, n, TerminalStates SingDone ReflNobodyAgency
SingDone ReflNobodyAgency) ->
of (reqResps, n, TerminalStates (SingProtocolState SingDone)
ReflNobodyAgency
(SingProtocolState SingDone)
ReflNobodyAgency) ->
(n, reqResps)


Expand Down
Expand Up @@ -169,8 +169,10 @@ prop_connect f xs =
(reqRespClientPeer (reqRespClientMap xs))
(reqRespServerPeer (reqRespServerMapAccumL (\a -> pure . f a) 0)))

of (c, s, TerminalStates SingDone ReflNobodyAgency
SingDone ReflNobodyAgency) ->
of (c, s, TerminalStates (SingProtocolState SingDone)
ReflNobodyAgency
(SingProtocolState SingDone)
ReflNobodyAgency) ->
(s, c) == mapAccumL f 0 xs


Expand All @@ -182,8 +184,10 @@ prop_connectPipelined cs f xs =
(reqRespServerPeer (reqRespServerMapAccumL
(\a -> pure . f a) 0)))

of (c, s, TerminalStates SingDone ReflNobodyAgency
SingDone ReflNobodyAgency) ->
of (c, s, TerminalStates (SingProtocolState SingDone)
ReflNobodyAgency
(SingProtocolState SingDone)
ReflNobodyAgency) ->
(s, c) == mapAccumL f 0 xs


Expand Down

0 comments on commit 2032b29

Please sign in to comment.