Skip to content

Commit

Permalink
Merge #2691
Browse files Browse the repository at this point in the history
2691: Handshake Improvements. r=karknu a=karknu

Change acceptableVersion so that it also returns the "acceptable" data
agreed upon by the client and server.
Update the client so that it checks that response is acceptable.
Change runApplication to take the result of the version negotiation
instead of a local and remote version of protocol options.

Co-authored-by: Karl Knutsson <karl.knutsson@iohk.io>
  • Loading branch information
iohk-bors[bot] and karknu committed Oct 20, 2020
2 parents 495e023 + db6d28c commit 9ec4dfe
Show file tree
Hide file tree
Showing 18 changed files with 93 additions and 86 deletions.
2 changes: 2 additions & 0 deletions ouroboros-network-framework/demo/ping-pong.hs
Expand Up @@ -111,6 +111,7 @@ clientPingPong pipelined =
unversionedHandshakeCodec
cborTermVersionDataCodec
nullNetworkConnectTracers
(\DictVersion {} -> acceptableVersion)
(unversionedProtocol app)
Nothing
defaultLocalSocketAddr
Expand Down Expand Up @@ -206,6 +207,7 @@ clientPingPong2 =
unversionedHandshakeCodec
cborTermVersionDataCodec
nullNetworkConnectTracers
(\DictVersion {} -> acceptableVersion)
(unversionedProtocol app)
Nothing
defaultLocalSocketAddr
Expand Down
Expand Up @@ -108,11 +108,13 @@ runHandshakeClient
)
=> MuxBearer m
-> connectionId
-> (forall vData. extra vData -> vData -> vData -> Accept vData)
-> HandshakeArguments connectionId vNumber extra m application agreedOptions
-> m (Either (HandshakeException (HandshakeClientProtocolError vNumber))
(application, agreedOptions))
runHandshakeClient bearer
connectionId
acceptVersion
HandshakeArguments {
haHandshakeTracer,
haHandshakeCodec,
Expand All @@ -127,7 +129,7 @@ runHandshakeClient bearer
byteLimitsHandshake
timeLimitsHandshake
(fromChannel (muxBearerAsChannel bearer handshakeProtocolNum InitiatorDir))
(handshakeClientPeer haVersionDataCodec haVersions))
(handshakeClientPeer haVersionDataCodec acceptVersion haVersions))


-- | Run server side of the 'Handshake' protocol.
Expand All @@ -143,7 +145,7 @@ runHandshakeServer
)
=> MuxBearer m
-> connectionId
-> (forall vData. extra vData -> vData -> vData -> Accept)
-> (forall vData. extra vData -> vData -> vData -> Accept vData)
-> HandshakeArguments connectionId vNumber extra m application agreedOptions
-> m (Either
(HandshakeException (RefuseReason vNumber))
Expand Down
Expand Up @@ -28,13 +28,14 @@ import Ouroboros.Network.Protocol.Handshake.Version
handshakeClientPeer
:: Ord vNumber
=> VersionDataCodec extra CBOR.Term vNumber agreedOptions
-> (forall vData. extra vData -> vData -> vData -> Accept vData)
-> Versions vNumber extra r
-> Peer (Handshake vNumber CBOR.Term)
AsClient StPropose m
(Either
(HandshakeClientProtocolError vNumber)
(r, agreedOptions))
handshakeClientPeer VersionDataCodec {encodeData, decodeData, getAgreedOptions} versions =
handshakeClientPeer VersionDataCodec {encodeData, decodeData, getAgreedOptions} acceptVersion versions =
-- send known versions
Yield (ClientAgency TokPropose) (MsgProposeVersions $ encodeVersions encodeData versions) $

Expand All @@ -56,11 +57,13 @@ handshakeClientPeer VersionDataCodec {encodeData, decodeData, getAgreedOptions}
Done TokDone (Left (HandshakeError $ HandshakeDecodeError vNumber err))

Right vData' ->
-- TODO: we should check that we agree on received @vData'@,
-- this might be less trivial than testing for equality.
Done TokDone $ Right $ ( runApplication (versionApplication version) vData vData'
, getAgreedOptions (versionExtra version) vNumber vData'
)
case acceptVersion (versionExtra version) vData vData' of
Accept agreedData ->
Done TokDone $ Right $ ( runApplication (versionApplication version) agreedData
, getAgreedOptions (versionExtra version) vNumber agreedData
)
Refuse err ->
Done TokDone (Left (InvalidServerSelection vNumber err))

encodeVersions
:: forall vNumber extra r vParams.
Expand Down
Expand Up @@ -26,7 +26,7 @@ import Ouroboros.Network.Protocol.Handshake.Version
handshakeServerPeer
:: Ord vNumber
=> VersionDataCodec extra vParams vNumber agreedOptions
-> (forall vData. extra vData -> vData -> vData -> Accept)
-> (forall vData. extra vData -> vData -> vData -> Accept vData)
-> Versions vNumber extra r
-> Peer (Handshake vNumber vParams)
AsServer StPropose m
Expand Down Expand Up @@ -61,12 +61,12 @@ handshakeServerPeer VersionDataCodec {encodeData, decodeData, getAgreedOptions}
-- We agree on the version; send back the agreed version
-- number @vNumber@ and encoded data associated with our
-- version.
Accept ->
Accept agreedData ->
Yield (ServerAgency TokConfirm)
(MsgAcceptVersion vNumber (encodeData (versionExtra version) vData))
(MsgAcceptVersion vNumber (encodeData (versionExtra version) agreedData))
(Done TokDone $ Right $
( runApplication (versionApplication version) vData vData'
, getAgreedOptions (versionExtra version) vNumber vData'
( runApplication (versionApplication version) agreedData
, getAgreedOptions (versionExtra version) vNumber agreedData
))

-- We disagree on the version.
Expand Down
Expand Up @@ -127,6 +127,7 @@ instance Show (ServerHasAgency (st :: Handshake vNumber vParams)) where
data HandshakeClientProtocolError vNumber
= HandshakeError (RefuseReason vNumber)
| NotRecognisedVersion vNumber
| InvalidServerSelection vNumber Text
deriving (Eq, Show)

instance (Typeable vNumber, Show vNumber)
Expand Down
Expand Up @@ -40,7 +40,7 @@ data UnversionedProtocolData = UnversionedProtocolData

instance Acceptable UnversionedProtocolData where
acceptableVersion UnversionedProtocolData
UnversionedProtocolData = Accept
UnversionedProtocolData = Accept UnversionedProtocolData


unversionedProtocolDataCodec :: CodecCBORTerm Text UnversionedProtocolData
Expand Down
Expand Up @@ -16,7 +16,6 @@ module Ouroboros.Network.Protocol.Handshake.Version
, Acceptable (..)
, Dict (..)
, DictVersion (..)
, pickVersions
, VersionMismatch (..)

-- * Simple or no versioning
Expand All @@ -29,7 +28,7 @@ import Data.Foldable (toList)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Text (Text)
import Data.Typeable ((:~:) (Refl), Typeable, eqT)
import Data.Typeable (Typeable)
import GHC.Stack (HasCallStack)

import Ouroboros.Network.CodecCBORTerm
Expand Down Expand Up @@ -68,7 +67,7 @@ newtype Versions vNum extra r = Versions
instance Functor (Versions vNum extra) where
fmap f (Versions vs) = Versions $ Map.map fmapSigma vs
where
fmapSigma (Sigma t (Version (Application app) extra)) = Sigma t (Version (Application $ \x y -> f (app x y)) extra)
fmapSigma (Sigma t (Version (Application app) extra)) = Sigma t (Version (Application $ \x -> f (app x)) extra)

data Sigma f where
Sigma :: !t -> !(f t) -> Sigma f
Expand Down Expand Up @@ -96,17 +95,17 @@ combineVersions = foldMapVersions id
-- |
-- A @'Maybe'@ like type which better explains its purpose.
--
data Accept
= Accept
data Accept vData
= Accept vData
| Refuse !Text
deriving (Eq, Show)

class Acceptable v where
acceptableVersion :: v -> v -> Accept
acceptableVersion :: v -> v -> Accept v

-- | Takes a pair of version data: local then remote.
newtype Application r vData = Application
{ runApplication :: vData -> vData -> r
{ runApplication :: vData -> r
}

data Version extra r vData = Version
Expand Down Expand Up @@ -135,40 +134,6 @@ data DictVersion vNumber agreedOptions vData where
-- ^ agreed vData
-> DictVersion vNumber agreedOptions vData

-- | Pick the version with the highest version number (by `Ord vNum`) common
-- in both maps.
--
-- This is a useful guide for comparison with a version negotiation scheme for
-- use in production between different processes. If the `Versions` maps
-- used by each process are given to `pickVersions`, it should come up with
-- the same result as the production version negotiation.
--
-- It is _assumed_ that if the maps agree on a key, then the existential
-- types in the `Sigma` value at the key are also equal.
--
-- So, the issue here is that they may not have the same version data type.
-- This becomes a non-issue on the network because the decoder/encoder
-- basically fills the role of a safe dynamic type cast.
pickVersions
:: ( Ord vNum )
=> (forall vData . extra vData -> Dict Typeable vData)
-> Versions vNum extra r
-> Versions vNum extra r
-> Either (VersionMismatch vNum) (r, r)
pickVersions isTypeable lversions rversions = case Map.toDescList commonVersions of
[] -> Left NoCommonVersion
(vNum, (Sigma (ldata :: ldata) lversion, Sigma (rdata :: rdata) rversion)) : _ ->
case (isTypeable (versionExtra lversion), isTypeable (versionExtra rversion)) of
(Dict, Dict) -> case eqT :: Maybe (ldata :~: rdata) of
Nothing -> Left $ InconsistentVersion vNum
Just Refl ->
let lapp = versionApplication lversion
rapp = versionApplication rversion
in Right (runApplication lapp ldata rdata, runApplication rapp rdata rdata)
where
commonVersions = getVersions lversions `intersect` getVersions rversions
intersect = Map.intersectionWith (,)

--
-- Simple version negotation
--
Expand All @@ -184,4 +149,4 @@ simpleSingletonVersions
simpleSingletonVersions vNum vData extra r =
Versions
$ Map.singleton vNum
(Sigma vData (Version (Application $ \_ _ -> r) extra))
(Sigma vData (Version (Application $ \_ -> r) extra))
21 changes: 13 additions & 8 deletions ouroboros-network-framework/src/Ouroboros/Network/Socket.hs
Expand Up @@ -195,14 +195,15 @@ connectToNode
-> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO BL.ByteString
-> VersionDataCodec extra CBOR.Term vNumber agreedOptions
-> NetworkConnectTracers addr vNumber
-> (forall vData. extra vData -> vData -> vData -> Accept vData)
-> Versions vNumber extra (OuroborosApplication appType addr BL.ByteString IO a b)
-- ^ application to run over the connection
-> Maybe addr
-- ^ local address; the created socket will bind to it
-> addr
-- ^ remote address
-> IO ()
connectToNode sn handshakeCodec versionDataCodec tracers versions localAddr remoteAddr =
connectToNode sn handshakeCodec versionDataCodec tracers acceptVersion versions localAddr remoteAddr =
bracket
(Snocket.openToConnect sn remoteAddr)
(Snocket.close sn)
Expand All @@ -211,7 +212,7 @@ connectToNode sn handshakeCodec versionDataCodec tracers versions localAddr remo
Just addr -> Snocket.bind sn sd addr
Nothing -> return ()
Snocket.connect sn sd remoteAddr
connectToNode' sn handshakeCodec versionDataCodec tracers versions sd
connectToNode' sn handshakeCodec versionDataCodec tracers acceptVersion versions sd
)

-- |
Expand All @@ -233,11 +234,12 @@ connectToNode'
-> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO BL.ByteString
-> VersionDataCodec extra CBOR.Term vNumber agreedOptions
-> NetworkConnectTracers addr vNumber
-> (forall vData. extra vData -> vData -> vData -> Accept vData)
-> Versions vNumber extra (OuroborosApplication appType addr BL.ByteString IO a b)
-- ^ application to run over the connection
-> fd
-> IO ()
connectToNode' sn handshakeCodec versionDataCodec NetworkConnectTracers {nctMuxTracer, nctHandshakeTracer } versions sd = do
connectToNode' sn handshakeCodec versionDataCodec NetworkConnectTracers {nctMuxTracer, nctHandshakeTracer } acceptVersion versions sd = do
connectionId <- ConnectionId <$> Snocket.getLocalAddr sn sd <*> Snocket.getRemoteAddr sn sd
muxTracer <- initDeltaQTracer' $ Mx.WithMuxBearer connectionId `contramap` nctMuxTracer
ts_start <- getMonotonicTime
Expand All @@ -246,6 +248,7 @@ connectToNode' sn handshakeCodec versionDataCodec NetworkConnectTracers {nctMuxT
runHandshakeClient
(Snocket.toBearer sn sduHandshakeTimeout muxTracer sd)
connectionId
acceptVersion
-- TODO: push 'HandshakeArguments' up the call stack.
HandshakeArguments {
haHandshakeTracer = nctHandshakeTracer,
Expand Down Expand Up @@ -283,16 +286,18 @@ connectToNodeSocket
-> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO BL.ByteString
-> VersionDataCodec extra CBOR.Term vNumber agreedOptions
-> NetworkConnectTracers Socket.SockAddr vNumber
-> (forall vData. extra vData -> vData -> vData -> Accept vData)
-> Versions vNumber extra (OuroborosApplication appType Socket.SockAddr BL.ByteString IO a b)
-- ^ application to run over the connection
-> Socket.Socket
-> IO ()
connectToNodeSocket iocp handshakeCodec versionDataCodec tracers versions sd =
connectToNodeSocket iocp handshakeCodec versionDataCodec tracers acceptVersion versions sd =
connectToNode'
(Snocket.socketSnocket iocp)
handshakeCodec
versionDataCodec
tracers
acceptVersion
versions
sd

Expand Down Expand Up @@ -347,7 +352,7 @@ beginConnection
-> Tracer IO (Mx.WithMuxBearer (ConnectionId addr) (TraceSendRecv (Handshake vNumber CBOR.Term)))
-> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO BL.ByteString
-> VersionDataCodec extra CBOR.Term vNumber agreedOptions
-> (forall vData. extra vData -> vData -> vData -> Accept)
-> (forall vData. extra vData -> vData -> vData -> Accept vData)
-> (Time -> addr -> st -> STM.STM (AcceptConnection st vNumber extra addr IO BL.ByteString))
-- ^ either accept or reject a connection.
-> Server.BeginConnection addr fd st ()
Expand Down Expand Up @@ -507,7 +512,7 @@ runServerThread
-> AcceptedConnectionsLimit
-> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO BL.ByteString
-> VersionDataCodec extra CBOR.Term vNumber agreedOptions
-> (forall vData. extra vData -> vData -> vData -> Accept)
-> (forall vData. extra vData -> vData -> vData -> Accept vData)
-> Versions vNumber extra (SomeResponderApplication addr BL.ByteString IO b)
-> ErrorPolicies
-> IO Void
Expand Down Expand Up @@ -588,7 +593,7 @@ withServerNode
-> addr
-> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO BL.ByteString
-> VersionDataCodec extra CBOR.Term vNumber agreedOptions
-> (forall vData. extra vData -> vData -> vData -> Accept)
-> (forall vData. extra vData -> vData -> vData -> Accept vData)
-> Versions vNumber extra (SomeResponderApplication addr BL.ByteString IO b)
-- ^ The mux application that will be run on each incoming connection from
-- a given address. Note that if @'MuxClientAndServerApplication'@ is
Expand Down Expand Up @@ -654,7 +659,7 @@ withServerNode'
-> fd
-> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO BL.ByteString
-> VersionDataCodec extra CBOR.Term vNumber agreedOptions
-> (forall vData. extra vData -> vData -> vData -> Accept)
-> (forall vData. extra vData -> vData -> vData -> Accept vData)
-> Versions vNumber extra (SomeResponderApplication addr BL.ByteString IO b)
-- ^ The mux application that will be run on each incoming connection from
-- a given address. Note that if @'MuxClientAndServerApplication'@ is
Expand Down
Expand Up @@ -255,6 +255,7 @@ prop_socket_send_recv initiatorAddr responderAddr f xs =
unversionedHandshakeCodec
cborTermVersionDataCodec
(NetworkConnectTracers activeMuxTracer nullTracer)
(\DictVersion {} -> acceptableVersion)
(unversionedProtocol initiatorApp)
(Just initiatorAddr)
responderAddr
Expand Down Expand Up @@ -489,6 +490,7 @@ prop_socket_client_connect_error _ xs =
unversionedHandshakeCodec
cborTermVersionDataCodec
nullNetworkConnectTracers
(\DictVersion {} -> acceptableVersion)
(unversionedProtocol app)
(Just $ Socket.addrAddress clientAddr)
(Socket.addrAddress serverAddr)
Expand Down
Expand Up @@ -612,6 +612,7 @@ prop_send_recv f xs _first = ioProperty $ withIOManager $ \iocp -> do
unversionedHandshakeCodec
cborTermVersionDataCodec
nullNetworkConnectTracers
(\DictVersion {} -> acceptableVersion)
(unversionedProtocol initiatorApp))

res <- atomically $ (,) <$> takeTMVar sv <*> takeTMVar cv
Expand Down Expand Up @@ -776,6 +777,7 @@ prop_send_recv_init_and_rsp f xs = ioProperty $ withIOManager $ \iocp -> do
unversionedHandshakeCodec
cborTermVersionDataCodec
nullNetworkConnectTracers
(\DictVersion {} -> acceptableVersion)
(unversionedProtocol (appX rrcfg)))

atomically $ (,) <$> takeTMVar (rrcServerVar rrcfg)
Expand Down Expand Up @@ -847,6 +849,7 @@ _demo = ioProperty $ withIOManager $ \iocp -> do
unversionedHandshakeCodec
cborTermVersionDataCodec
nullNetworkConnectTracers
(\DictVersion {} -> acceptableVersion)
(unversionedProtocol appReq))

threadDelay 130
Expand Down
2 changes: 2 additions & 0 deletions ouroboros-network/demo/chain-sync.hs
Expand Up @@ -158,6 +158,7 @@ clientChainSync sockPaths = withIOManager $ \iocp ->
unversionedHandshakeCodec
cborTermVersionDataCodec
nullNetworkConnectTracers
(\DictVersion {} -> acceptableVersion)
(simpleSingletonVersions
UnversionedProtocol
UnversionedProtocolData
Expand Down Expand Up @@ -365,6 +366,7 @@ clientBlockFetch sockAddrs = withIOManager $ \iocp -> do
unversionedHandshakeCodec
cborTermVersionDataCodec
nullNetworkConnectTracers
(\DictVersion {} -> acceptableVersion)
(simpleSingletonVersions
UnversionedProtocol
UnversionedProtocolData
Expand Down
Expand Up @@ -35,9 +35,9 @@ pureHandshake isTypeable acceptVersion (Versions serverVersions) (Versions clien
(Dict, Dict) -> case (cast vData, cast vData') of
(Just d, Just d') ->
( if acceptVersion (versionExtra version) vData d'
then Just $ runApplication (versionApplication version) vData d'
then Just $ runApplication (versionApplication version) d'
else Nothing

, Just $ runApplication (versionApplication version') vData' d
, Just $ runApplication (versionApplication version') d
)
_ -> (Nothing, Nothing)

0 comments on commit 9ec4dfe

Please sign in to comment.