Skip to content

Commit

Permalink
server-test: use DataFlowProtocolData
Browse files Browse the repository at this point in the history
This patch providess UnversionedProtocol which allows to negotiate
'DataFlow'.
  • Loading branch information
bolt12 authored and coot committed Oct 14, 2021
1 parent d2a0f0d commit 22a238a
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 11 deletions.
Expand Up @@ -8,6 +8,9 @@ module Ouroboros.Network.Protocol.Handshake.Unversioned
, unversionedHandshakeCodec
, unversionedProtocolDataCodec
, unversionedProtocol
, DataFlowProtocolData (..)
, dataFlowProtocolDataCodec
, dataFlowProtocol
) where

import Control.Monad.Class.MonadST
Expand All @@ -22,6 +25,7 @@ import Data.ByteString.Lazy (ByteString)
import Network.TypedProtocol.Codec

import Ouroboros.Network.CodecCBORTerm
import Ouroboros.Network.ConnectionManager.Types (DataFlow (..))
import Ouroboros.Network.Protocol.Handshake.Codec
import Ouroboros.Network.Protocol.Handshake.Type
import Ouroboros.Network.Protocol.Handshake.Version
Expand All @@ -37,7 +41,6 @@ data UnversionedProtocol = UnversionedProtocol
data UnversionedProtocolData = UnversionedProtocolData
deriving (Eq, Show)


instance Acceptable UnversionedProtocolData where
acceptableVersion UnversionedProtocolData
UnversionedProtocolData = Accept UnversionedProtocolData
Expand All @@ -55,7 +58,6 @@ unversionedProtocolDataCodec = cborTermVersionDataCodec
decodeTerm CBOR.TNull = Right UnversionedProtocolData
decodeTerm t = Left $ T.pack $ "unexpected term: " ++ show t


-- | Make a 'Versions' for an unversioned protocol. Only use this for
-- tests and demos where proper versioning is excessive.
--
Expand All @@ -66,6 +68,36 @@ unversionedProtocol :: app
unversionedProtocol = simpleSingletonVersions UnversionedProtocol UnversionedProtocolData


-- | Alternative for 'UnversionedProtocolData' which contains 'DataFlow'.
--
newtype DataFlowProtocolData =
DataFlowProtocolData { getProtocolDataFlow :: DataFlow }
deriving (Eq, Show)

instance Acceptable DataFlowProtocolData where
acceptableVersion (DataFlowProtocolData local) (DataFlowProtocolData remote) =
Accept (DataFlowProtocolData $ local `min` remote)

dataFlowProtocolDataCodec :: UnversionedProtocol -> CodecCBORTerm Text DataFlowProtocolData
dataFlowProtocolDataCodec _ = CodecCBORTerm {encodeTerm, decodeTerm}
where
encodeTerm :: DataFlowProtocolData -> CBOR.Term
encodeTerm (DataFlowProtocolData Unidirectional) = CBOR.TBool False
encodeTerm (DataFlowProtocolData Duplex) = CBOR.TBool True

decodeTerm :: CBOR.Term -> Either Text DataFlowProtocolData
decodeTerm (CBOR.TBool False) = Right (DataFlowProtocolData Unidirectional)
decodeTerm (CBOR.TBool True) = Right (DataFlowProtocolData Duplex)
decodeTerm t = Left $ T.pack $ "unexpected term: " ++ show t

dataFlowProtocol :: DataFlow
-> app
-> Versions UnversionedProtocol
DataFlowProtocolData
app
dataFlowProtocol dataFlow =
simpleSingletonVersions UnversionedProtocol (DataFlowProtocolData dataFlow)

-- | 'Handshake' codec used in various tests.
--
unversionedHandshakeCodec :: MonadST m
Expand Down
Expand Up @@ -86,7 +86,8 @@ import qualified Ouroboros.Network.InboundGovernor.ControlChannel as Server
import Ouroboros.Network.Mux
import Ouroboros.Network.MuxMode
import Ouroboros.Network.Protocol.Handshake
import Ouroboros.Network.Protocol.Handshake.Codec ( noTimeLimitsHandshake
import Ouroboros.Network.Protocol.Handshake.Codec ( cborTermVersionDataCodec
, noTimeLimitsHandshake
, timeLimitsHandshake)
import Ouroboros.Network.Protocol.Handshake.Type (Handshake)
import Ouroboros.Network.Protocol.Handshake.Unversioned
Expand All @@ -102,7 +103,7 @@ import Simulation.Network.Snocket

import Ouroboros.Network.Testing.Utils (genDelayWithPrecision)
import Test.Ouroboros.Network.Orphans () -- ShowProxy ReqResp instance
import Test.Simulation.Network.Snocket (NonFailingBearerInfoScript(..), AbsBearerInfo, toBearerInfo)
import Test.Simulation.Network.Snocket hiding (tests)
import Test.Ouroboros.Network.ConnectionManager (verifyAbstractTransition)

tests :: TestTree
Expand Down Expand Up @@ -332,7 +333,7 @@ withInitiatorOnlyConnectionManager name timeouts cmTrTracer snocket localAddr
cmIPv6Address = Nothing,
cmAddressType = \_ -> Just IPv4Address,
cmSnocket = snocket,
connectionDataFlow = const Unidirectional,
connectionDataFlow = getProtocolDataFlow . snd,
cmPrunePolicy = simplePrunePolicy,
cmConnectionsLimits = AcceptedConnectionsLimit {
acceptedConnectionsHardLimit = maxBound,
Expand All @@ -350,11 +351,11 @@ withInitiatorOnlyConnectionManager name timeouts cmTrTracer snocket localAddr
-- TraceSendRecv
haHandshakeTracer = (name,) `contramap` nullTracer,
haHandshakeCodec = unversionedHandshakeCodec,
haVersionDataCodec = unversionedProtocolDataCodec,
haVersionDataCodec = cborTermVersionDataCodec dataFlowProtocolDataCodec,
haAcceptVersion = acceptableVersion,
haTimeLimits = handshakeTimeLimits
}
(unversionedProtocol clientApplication)
(dataFlowProtocol Unidirectional clientApplication)
(mainThreadId, debugMuxErrorRethrowPolicy
<> debugMuxRuntimeErrorRethrowPolicy
<> debugIOErrorRethrowPolicy
Expand Down Expand Up @@ -516,7 +517,7 @@ withBidirectionalConnectionManager name timeouts
cmSnocket = snocket,
cmTimeWaitTimeout = tTimeWaitTimeout timeouts,
cmOutboundIdleTimeout = tOutboundIdleTimeout timeouts,
connectionDataFlow = const Duplex,
connectionDataFlow = getProtocolDataFlow . snd,
cmPrunePolicy = simplePrunePolicy,
cmConnectionsLimits = AcceptedConnectionsLimit {
acceptedConnectionsHardLimit = maxBound,
Expand All @@ -532,11 +533,11 @@ withBidirectionalConnectionManager name timeouts
-- TraceSendRecv
haHandshakeTracer = WithName name `contramap` nullTracer,
haHandshakeCodec = unversionedHandshakeCodec,
haVersionDataCodec = unversionedProtocolDataCodec,
haVersionDataCodec = cborTermVersionDataCodec dataFlowProtocolDataCodec,
haAcceptVersion = acceptableVersion,
haTimeLimits = handshakeTimeLimits
}
(unversionedProtocol serverApplication)
(dataFlowProtocol Duplex serverApplication)
(mainThreadId, debugMuxErrorRethrowPolicy
<> debugMuxRuntimeErrorRethrowPolicy
<> debugIOErrorRethrowPolicy
Expand Down
Expand Up @@ -16,7 +16,11 @@ module Test.Simulation.Network.Snocket
( tests
, BearerInfoScript(..)
, NonFailingBearerInfoScript(..)
, AbsBearerInfo
, AbsDelay (..)
, AbsSpeed (..)
, AbsSDUSize (..)
, AbsAttenuation (..)
, AbsBearerInfo (..)
, toBearerInfo
) where

Expand Down

0 comments on commit 22a238a

Please sign in to comment.