Skip to content

Commit

Permalink
handshake: resiliant for simultanous open
Browse files Browse the repository at this point in the history
  • Loading branch information
coot committed Sep 16, 2021
1 parent 9aab52b commit bb28ec3
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 62 deletions.
@@ -1,11 +1,13 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Ouroboros.Network.Protocol.Handshake.Client
( handshakeClientPeer
, acceptOrRefuse
) where

import Data.Map (Map)
Expand All @@ -26,7 +28,8 @@ import Ouroboros.Network.Protocol.Handshake.Version
-- TODO: GADT encoding of the client (@Handshake.Client@ module).
--
handshakeClientPeer
:: Ord vNumber
:: ( Ord vNumber
)
=> VersionDataCodec CBOR.Term vNumber vData
-> (vData -> vData -> Accept vData)
-> Versions vNumber vData r
Expand All @@ -35,11 +38,18 @@ handshakeClientPeer
(Either
(HandshakeProtocolError vNumber)
(r, vNumber, vData))
handshakeClientPeer VersionDataCodec {encodeData, decodeData} acceptVersion versions =
handshakeClientPeer codec@VersionDataCodec {encodeData, decodeData}
acceptVersion versions =
-- send known versions
Yield (ClientAgency TokPropose) (MsgProposeVersions $ encodeVersions encodeData versions) $

Await (ServerAgency TokConfirm) $ \msg -> case msg of
MsgProposeVersions' vMap ->
-- simultanous open; 'accept' will choose version (the greatest common
-- version), and check if we can accept received version data.
Done TokDone $ case acceptOrRefuse codec acceptVersion versions vMap of
Right r -> Right r
Left vReason -> Left (HandshakeError vReason)

-- the server refused common highest version
MsgRefuse vReason ->
Expand Down Expand Up @@ -76,3 +86,37 @@ encodeVersions encoder (Versions vs) = go `Map.mapWithKey` vs
where
go :: vNumber -> Version vData r -> vParams
go vNumber Version {versionData} = encoder vNumber versionData


acceptOrRefuse
:: forall vParams vNumber vData r.
Ord vNumber
=> VersionDataCodec vParams vNumber vData
-> (vData -> vData -> Accept vData)
-> Versions vNumber vData r
-> Map vNumber vParams
-- ^ proposed versions received either with `MsgProposeVersions` or
-- `MsgProposeVersions'`
-> Either (RefuseReason vNumber) (r, vNumber, vData)
acceptOrRefuse VersionDataCodec {decodeData}
acceptVersion versions versionMap =
case lookupGreatestCommonKey versionMap (getVersions versions) of
Nothing ->
Left $ VersionMismatch (Map.keys $ getVersions versions) []

Just (vNumber, (vParams, Version app vData)) ->
case decodeData vNumber vParams of
Left err ->
Left (HandshakeDecodeError vNumber err)

Right vData' ->
case acceptVersion vData vData' of
Accept agreedData ->
Right (app agreedData, vNumber, agreedData)

Refuse err ->
Left (Refused vNumber err)


lookupGreatestCommonKey :: Ord k => Map k a -> Map k b -> Maybe (k, (a, b))
lookupGreatestCommonKey l r = Map.lookupMax $ Map.intersectionWith (,) l r
Expand Up @@ -131,15 +131,16 @@ codecHandshake versionNumberCodec = mkCodecCborLazyBS encodeMsg decodeMsg
-> CBOR.Encoding

encodeMsg (ClientAgency TokPropose) (MsgProposeVersions vs) =
let vs' = Map.toAscList vs
in
CBOR.encodeListLen 2
<> CBOR.encodeWord 0
<> CBOR.encodeMapLen (fromIntegral $ length vs')
<> mconcat [ CBOR.encodeTerm (encodeTerm versionNumberCodec vNumber)
<> CBOR.encodeTerm vParams
| (vNumber, vParams) <- vs'
]
<> encodeVersions versionNumberCodec vs

-- Although `MsgProposeVersions'` shall not be sent, for testing purposes
-- it is useful to have an encoder for it.
encodeMsg (ServerAgency TokConfirm) (MsgProposeVersions' vs) =
CBOR.encodeListLen 2
<> CBOR.encodeWord 0
<> encodeVersions versionNumberCodec vs

encodeMsg (ServerAgency TokConfirm) (MsgAcceptVersion vNumber vParams) =
CBOR.encodeListLen 3
Expand All @@ -152,6 +153,7 @@ codecHandshake versionNumberCodec = mkCodecCborLazyBS encodeMsg decodeMsg
<> CBOR.encodeWord 2
<> encodeRefuseReason versionNumberCodec vReason


-- decode a map checking the assumption that
-- * keys are different
-- * keys are encoded in ascending order
Expand Down Expand Up @@ -186,6 +188,10 @@ codecHandshake versionNumberCodec = mkCodecCborLazyBS encodeMsg decodeMsg
l <- CBOR.decodeMapLen
vMap <- decodeMap l Nothing []
pure $ SomeMessage $ MsgProposeVersions vMap
(ServerAgency TokConfirm, 0, 2) -> do
l <- CBOR.decodeMapLen
vMap <- decodeMap l Nothing []
pure $ SomeMessage $ MsgProposeVersions' vMap
(ServerAgency TokConfirm, 1, 3) -> do
v <- decodeTerm versionNumberCodec <$> CBOR.decodeTerm
case v of
Expand All @@ -204,6 +210,21 @@ codecHandshake versionNumberCodec = mkCodecCborLazyBS encodeMsg decodeMsg
fail $ printf "codecHandshake (%s) unexpected key (%d, %d)" (show stok) key len


-- | Encode version map preserving the ascending order of keys.
--
encodeVersions :: CodecCBORTerm (failure, Maybe Int) vNumber
-> Map vNumber CBOR.Term
-> CBOR.Encoding
encodeVersions versionNumberCodec vs =
CBOR.encodeMapLen (fromIntegral (Map.size vs))
<> Map.foldMapWithKey
(\vNumber vParams ->
CBOR.encodeTerm (encodeTerm versionNumberCodec vNumber)
<> CBOR.encodeTerm vParams
)
vs


encodeRefuseReason :: CodecCBORTerm fail vNumber
-> RefuseReason vNumber
-> CBOR.Encoding
Expand Down
Expand Up @@ -8,71 +8,37 @@ module Ouroboros.Network.Protocol.Handshake.Server
( handshakeServerPeer
) where

import Data.Map (Map)
import qualified Data.Map as Map
import qualified Codec.CBOR.Term as CBOR

import Network.TypedProtocol.Core

import Ouroboros.Network.Protocol.Handshake.Codec
import Ouroboros.Network.Protocol.Handshake.Client (acceptOrRefuse)
import Ouroboros.Network.Protocol.Handshake.Type
import Ouroboros.Network.Protocol.Handshake.Version


-- | Server following the handshake protocol; it accepts highest version offered
-- by the peer that also belongs to the server @versions@.
--
-- TODO: GADT encoding of the server (@Handshake.Server@ module).
--
handshakeServerPeer
:: Ord vNumber
=> VersionDataCodec vParams vNumber vData
:: ( Ord vNumber
)
=> VersionDataCodec CBOR.Term vNumber vData
-> (vData -> vData -> Accept vData)
-> Versions vNumber vData r
-> Peer (Handshake vNumber vParams)
-> Peer (Handshake vNumber CBOR.Term)
AsServer StPropose m
(Either (HandshakeProtocolError vNumber) (r, vNumber, vData))
handshakeServerPeer VersionDataCodec {encodeData, decodeData} acceptVersion versions =
-- await for versions proposed by a client
handshakeServerPeer codec@VersionDataCodec {encodeData} acceptVersion versions =
Await (ClientAgency TokPropose) $ \msg -> case msg of

MsgProposeVersions vMap ->
-- Compute intersection of local and remote versions.
case lookupGreatestCommonKey vMap (getVersions versions) of
Nothing ->
let vReason = VersionMismatch (Map.keys $ getVersions versions) []
in Yield (ServerAgency TokConfirm)
(MsgRefuse vReason)
(Done TokDone (Left $ HandshakeError vReason))

Just (vNumber, (vParams, Version app vData)) ->
case decodeData vNumber vParams of
Left err ->
let vReason = HandshakeDecodeError vNumber err
in Yield (ServerAgency TokConfirm)
(MsgRefuse vReason)
(Done TokDone $ Left $ HandshakeError vReason)

Right vData' ->
case acceptVersion vData vData' of

-- We agree on the version; send back the agreed version
-- number @vNumber@ and encoded data associated with our
-- version.
Accept agreedData ->
Yield (ServerAgency TokConfirm)
(MsgAcceptVersion vNumber (encodeData vNumber agreedData))
(Done TokDone $ Right $
( app agreedData
, vNumber
, agreedData
))

-- We disagree on the version.
Refuse err ->
let vReason = Refused vNumber err
in Yield (ServerAgency TokConfirm)
(MsgRefuse vReason)
(Done TokDone $ Left $ HandshakeError vReason)

lookupGreatestCommonKey :: Ord k => Map k a -> Map k b -> Maybe (k, (a, b))
lookupGreatestCommonKey l r = Map.lookupMax $ Map.intersectionWith (,) l r
MsgProposeVersions vMap ->
case acceptOrRefuse codec acceptVersion versions vMap of
(Right r@(_, vNumber, agreedData)) ->
Yield (ServerAgency TokConfirm)
(MsgAcceptVersion vNumber (encodeData vNumber agreedData))
(Done TokDone (Right r))
(Left vReason) ->
Yield (ServerAgency TokConfirm)
(MsgRefuse vReason)
(Done TokDone (Left (HandshakeError vReason)))
Expand Up @@ -20,6 +20,7 @@ module Ouroboros.Network.Protocol.Handshake.Type
, ClientHasAgency (..)
, ServerHasAgency (..)
, NobodyHasAgency (..)
-- $simultanous-open
, RefuseReason (..)
, HandshakeProtocolError (..)
)
Expand All @@ -34,7 +35,6 @@ import Data.Map (Map)
import Network.TypedProtocol.Core
import Ouroboros.Network.Util.ShowProxy (ShowProxy (..))


-- |
-- The handshake mini-protocol is used initially to agree the version and
-- associated parameters of the protocol to use for all subsequent
Expand Down Expand Up @@ -82,6 +82,16 @@ instance Protocol (Handshake vNumber vParams) where
:: Map vNumber vParams
-> Message (Handshake vNumber vParams) StPropose StConfirm

-- |
-- `MsgProposeVersions'` received as a response to 'MsgProposeVersions'.
-- It is not supported to explicitly send this message. It can only be
-- received as a copy of 'MsgProposeVersions' in a simultanous open
-- scenario.
--
MsgProposeVersions'
:: Map vNumber vParams
-> Message (Handshake vNumber vParams) StConfirm StDone

-- |
-- The remote end decides which version to use and sends chosen version.
-- The server is allowed to modify version parameters.
Expand Down Expand Up @@ -111,11 +121,21 @@ instance Protocol (Handshake vNumber vParams) where
exclusionLemma_NobodyAndClientHaveAgency TokDone tok = case tok of {}
exclusionLemma_NobodyAndServerHaveAgency TokDone tok = case tok of {}

-- $simultanous-open
--
-- On simultanous open both sides will send `MsgProposeVersions`, which will be
-- decoded as `MsgProposeVersions'` which is a terminal message of the
-- protocol. It is important to stress that in this case both sides will make
-- the choice which version and parameters to pick. Our algorithm for picking
-- version is symmetric, which ensures that both sides will endup with the same
-- choice. If one side decides to refuse the version it will close the
-- connection, without sending the reason to the other side.

deriving instance (Show vNumber, Show vParams)
=> Show (Message (Handshake vNumber vParams) from to)

instance Show (ClientHasAgency (st :: Handshake vNumber vParams)) where
show TokPropose = "TokPropose"
show TokPropose = "TokPropose"

instance Show (ServerHasAgency (st :: Handshake vNumber vParams)) where
show TokConfirm = "TokConfirm"
Expand Down
Expand Up @@ -547,6 +547,9 @@ instance Eq (AnyMessage (Handshake VersionNumber CBOR.Term)) where
AnyMessage (MsgProposeVersions vs) == AnyMessage (MsgProposeVersions vs')
= vs == vs'

AnyMessage (MsgProposeVersions' vs) == AnyMessage (MsgProposeVersions' vs')
= vs == vs'

AnyMessage (MsgAcceptVersion vNumber vParams) == AnyMessage (MsgAcceptVersion vNumber' vParams')
= vNumber == vNumber' && vParams == vParams'

Expand All @@ -563,6 +566,12 @@ instance Arbitrary (AnyMessageAndAgency (Handshake VersionNumber CBOR.Term)) whe
. getVersions
<$> genVersions

, AnyMessageAndAgency (ServerAgency TokConfirm)
. MsgProposeVersions'
. Map.mapWithKey (\v -> encodeTerm (dataCodecCBORTerm v) . versionData)
. getVersions
<$> genVersions

, AnyMessageAndAgency (ServerAgency TokConfirm)
. uncurry MsgAcceptVersion
<$> genValidVersion'
Expand Down

0 comments on commit bb28ec3

Please sign in to comment.