Skip to content

Commit

Permalink
server-test: use DataFlowProtocolData
Browse files Browse the repository at this point in the history
This patch providess UnversionedProtocol which allows to negotiate
'DataFlow'.
  • Loading branch information
bolt12 committed Sep 27, 2021
1 parent 73e3f91 commit 4908810
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ module Ouroboros.Network.Protocol.Handshake.Unversioned
, unversionedHandshakeCodec
, unversionedProtocolDataCodec
, unversionedProtocol
, DataFlowProtocolData (..)
, dataFlowProtocolDataCodec
, dataFlowProtocol
) where

import Control.Monad.Class.MonadST
Expand All @@ -22,6 +25,7 @@ import Data.ByteString.Lazy (ByteString)
import Network.TypedProtocol.Codec

import Ouroboros.Network.CodecCBORTerm
import Ouroboros.Network.ConnectionManager.Types (DataFlow (..))
import Ouroboros.Network.Protocol.Handshake.Codec
import Ouroboros.Network.Protocol.Handshake.Type
import Ouroboros.Network.Protocol.Handshake.Version
Expand All @@ -37,7 +41,6 @@ data UnversionedProtocol = UnversionedProtocol
data UnversionedProtocolData = UnversionedProtocolData
deriving (Eq, Show)


instance Acceptable UnversionedProtocolData where
acceptableVersion UnversionedProtocolData
UnversionedProtocolData = Accept UnversionedProtocolData
Expand All @@ -55,7 +58,6 @@ unversionedProtocolDataCodec = cborTermVersionDataCodec
decodeTerm CBOR.TNull = Right UnversionedProtocolData
decodeTerm t = Left $ T.pack $ "unexpected term: " ++ show t


-- | Make a 'Versions' for an unversioned protocol. Only use this for
-- tests and demos where proper versioning is excessive.
--
Expand All @@ -66,6 +68,36 @@ unversionedProtocol :: app
unversionedProtocol = simpleSingletonVersions UnversionedProtocol UnversionedProtocolData


-- | Alternative for 'UnversionedProtocolData' which contains 'DataFlow'.
--
newtype DataFlowProtocolData =
DataFlowProtocolData { getProtocolDataFlow :: DataFlow }
deriving (Eq, Show)

instance Acceptable DataFlowProtocolData where
acceptableVersion (DataFlowProtocolData local) (DataFlowProtocolData remote) =
Accept (DataFlowProtocolData $ local `min` remote)

dataFlowProtocolDataCodec :: UnversionedProtocol -> CodecCBORTerm Text DataFlowProtocolData
dataFlowProtocolDataCodec _ = CodecCBORTerm {encodeTerm, decodeTerm}
where
encodeTerm :: DataFlowProtocolData -> CBOR.Term
encodeTerm (DataFlowProtocolData Unidirectional) = CBOR.TBool False
encodeTerm (DataFlowProtocolData Duplex) = CBOR.TBool True

decodeTerm :: CBOR.Term -> Either Text DataFlowProtocolData
decodeTerm (CBOR.TBool False) = Right (DataFlowProtocolData Unidirectional)
decodeTerm (CBOR.TBool True) = Right (DataFlowProtocolData Duplex)
decodeTerm t = Left $ T.pack $ "unexpected term: " ++ show t

dataFlowProtocol :: DataFlow
-> app
-> Versions UnversionedProtocol
DataFlowProtocolData
app
dataFlowProtocol dataFlow =
simpleSingletonVersions UnversionedProtocol (DataFlowProtocolData dataFlow)

-- | 'Handshake' codec used in various tests.
--
unversionedHandshakeCodec :: MonadST m
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ import Ouroboros.Network.Testing.Utils (genDelayWithPrecision)
import Simulation.Network.Snocket

import Test.Ouroboros.Network.Orphans () -- ShowProxy ReqResp instance
import Test.Simulation.Network.Snocket (NonFailingBearerInfoScript(..), AbsBearerInfo, toBearerInfo)
import Test.Simulation.Network.Snocket hiding (tests)
import Test.Ouroboros.Network.ConnectionManager (verifyAbstractTransition)

tests :: TestTree
Expand Down Expand Up @@ -332,7 +332,7 @@ withInitiatorOnlyConnectionManager name timeouts cmTrTracer snocket localAddr
cmIPv6Address = Nothing,
cmAddressType = \_ -> Just IPv4Address,
cmSnocket = snocket,
connectionDataFlow = const Unidirectional,
connectionDataFlow = getProtocolDataFlow . snd,
cmPrunePolicy = simplePrunePolicy,
cmConnectionsLimits = AcceptedConnectionsLimit {
acceptedConnectionsHardLimit = maxBound,
Expand All @@ -350,7 +350,7 @@ withInitiatorOnlyConnectionManager name timeouts cmTrTracer snocket localAddr
-- TraceSendRecv
haHandshakeTracer = (name,) `contramap` nullTracer,
haHandshakeCodec = unversionedHandshakeCodec,
haVersionDataCodec = unversionedProtocolDataCodec,
haVersionDataCodec = cborTermVersionDataCodec dataFlowProtocolDataCodec,
haAcceptVersion = acceptableVersion,
haTimeLimits = handshakeTimeLimits
}
Expand Down Expand Up @@ -515,7 +515,7 @@ withBidirectionalConnectionManager name timeouts
cmSnocket = snocket,
cmTimeWaitTimeout = tTimeWaitTimeout timeouts,
cmOutboundIdleTimeout = tOutboundIdleTimeout timeouts,
connectionDataFlow = const Duplex,
connectionDataFlow = getProtocolDataFlow . snd,
cmPrunePolicy = simplePrunePolicy,
cmConnectionsLimits = AcceptedConnectionsLimit {
acceptedConnectionsHardLimit = maxBound,
Expand All @@ -531,7 +531,7 @@ withBidirectionalConnectionManager name timeouts
-- TraceSendRecv
haHandshakeTracer = WithName name `contramap` nullTracer,
haHandshakeCodec = unversionedHandshakeCodec,
haVersionDataCodec = unversionedProtocolDataCodec,
haVersionDataCodec = cborTermVersionDataCodec dataFlowProtocolDataCodec,
haAcceptVersion = acceptableVersion,
haTimeLimits = handshakeTimeLimits
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ module Test.Simulation.Network.Snocket
( tests
, BearerInfoScript(..)
, NonFailingBearerInfoScript(..)
, AbsBearerInfo
, AbsDelay (..)
, AbsSpeed (..)
, AbsSDUSize (..)
, AbsAttenuation (..)
, AbsBearerInfo (..)
, toBearerInfo
) where

Expand Down

0 comments on commit 4908810

Please sign in to comment.