Skip to content

Commit

Permalink
handshake refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
coot committed Oct 27, 2020
1 parent 20a0a41 commit 9d5832c
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 106 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ tryHandshake doHandshake = do

-- | Common arguments for both 'Handshake' client & server.
--
data HandshakeArguments connectionId vNumber extra m application agreedOptions = HandshakeArguments {
data HandshakeArguments connectionId vNumber vData m application = HandshakeArguments {
-- | 'Handshake' tracer
--
haHandshakeTracer :: Tracer m (WithMuxBearer connectionId
Expand All @@ -88,10 +88,10 @@ data HandshakeArguments connectionId vNumber extra m application agreedOptions =
-- | A codec for protocol parameters.
--
haVersionDataCodec
:: VersionDataCodec extra CBOR.Term vNumber agreedOptions,
:: VersionDataCodec CBOR.Term vNumber vData,

-- | versioned application aggreed upon with the 'Handshake' protocol.
haVersions :: Versions vNumber extra application
haVersions :: Versions vNumber vData application
}


Expand All @@ -108,10 +108,10 @@ runHandshakeClient
)
=> MuxBearer m
-> connectionId
-> (forall vData. extra vData -> vData -> vData -> Accept vData)
-> HandshakeArguments connectionId vNumber extra m application agreedOptions
-> (vData -> vData -> Accept vData)
-> HandshakeArguments connectionId vNumber vData m application
-> m (Either (HandshakeException (HandshakeClientProtocolError vNumber))
(application, agreedOptions))
(application, vNumber, vData))
runHandshakeClient bearer
connectionId
acceptVersion
Expand Down Expand Up @@ -145,11 +145,11 @@ runHandshakeServer
)
=> MuxBearer m
-> connectionId
-> (forall vData. extra vData -> vData -> vData -> Accept vData)
-> HandshakeArguments connectionId vNumber extra m application agreedOptions
-> (vData -> vData -> Accept vData)
-> HandshakeArguments connectionId vNumber vData m application
-> m (Either
(HandshakeException (RefuseReason vNumber))
(application, agreedOptions))
(application, vNumber, vData))
runHandshakeServer bearer
connectionId
acceptVersion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ import Ouroboros.Network.Protocol.Handshake.Version
--
handshakeClientPeer
:: Ord vNumber
=> VersionDataCodec extra CBOR.Term vNumber agreedOptions
-> (forall vData. extra vData -> vData -> vData -> Accept vData)
-> Versions vNumber extra r
=> VersionDataCodec CBOR.Term vNumber vData
-> (vData -> vData -> Accept vData)
-> Versions vNumber vData r
-> Peer (Handshake vNumber CBOR.Term)
AsClient StPropose m
(Either
(HandshakeClientProtocolError vNumber)
(r, agreedOptions))
handshakeClientPeer VersionDataCodec {encodeData, decodeData, getAgreedOptions} acceptVersion versions =
(r, vNumber, vData))
handshakeClientPeer VersionDataCodec {encodeData, decodeData} acceptVersion versions =
-- send known versions
Yield (ClientAgency TokPropose) (MsgProposeVersions $ encodeVersions encodeData versions) $

Expand All @@ -50,27 +50,29 @@ handshakeClientPeer VersionDataCodec {encodeData, decodeData, getAgreedOptions}
MsgAcceptVersion vNumber vParams ->
case vNumber `Map.lookup` getVersions versions of
Nothing -> Done TokDone (Left $ NotRecognisedVersion vNumber)
Just (Sigma vData version) ->
case decodeData (versionExtra version) vParams of
Just (Version app vData) ->
case decodeData vParams of

Left err ->
Done TokDone (Left (HandshakeError $ HandshakeDecodeError vNumber err))

Right vData' ->
case acceptVersion (versionExtra version) vData vData' of
case acceptVersion vData vData' of
Accept agreedData ->
Done TokDone $ Right $ ( runApplication (versionApplication version) agreedData
, getAgreedOptions (versionExtra version) vNumber agreedData
Done TokDone $ Right $ ( runApplication app agreedData
, vNumber
, agreedData
)
Refuse err ->
Done TokDone (Left (InvalidServerSelection vNumber err))


encodeVersions
:: forall vNumber extra r vParams.
(forall vData. extra vData -> vData -> vParams)
-> Versions vNumber extra r
:: forall vNumber r vParams vData.
(vData -> vParams)
-> Versions vNumber vData r
-> Map vNumber vParams
encodeVersions encoder (Versions vs) = go <$> vs
where
go :: Sigma (Version extra r) -> vParams
go (Sigma vData Version {versionExtra}) = encoder versionExtra vData
go :: Version vData r -> vParams
go Version {versionData} = encoder versionData
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ import Ouroboros.Network.CodecCBORTerm
import Ouroboros.Network.Driver.Limits

import Ouroboros.Network.Protocol.Handshake.Type
import Ouroboros.Network.Protocol.Handshake.Version

-- | Codec for version data ('vData' in code) exchanged by the handshake
-- protocol.
Expand All @@ -51,24 +50,21 @@ import Ouroboros.Network.Protocol.Handshake.Version
-- is instatiated to 'NodeToNodeVersionData' in "Ouroboros.Network.NodeToNode"
-- or to '()' in "Ouroboros.Network.NodeToClient".
--
data VersionDataCodec extra bytes vNumber agreedOptions = VersionDataCodec {
encodeData :: forall vData. extra vData -> vData -> bytes,
data VersionDataCodec bytes vNumber vData = VersionDataCodec {
encodeData :: vData -> bytes,
-- ^ encoder of 'vData' which has access to 'extra vData' which can bring
-- extra instances into the scope (by means of pattern matching on a GADT).
decodeData :: forall vData. extra vData -> bytes -> Either Text vData,
decodeData :: bytes -> Either Text vData
-- ^ decoder of 'vData'.
getAgreedOptions :: forall vData. extra vData -> vNumber -> vData -> agreedOptions
-- ^ map negotiated 'vData' into version independent representation
-- 'agreedOptions'.
}

-- TODO: remove this from top level API, this is the only way we encode or
-- decode version data.
cborTermVersionDataCodec :: VersionDataCodec (DictVersion vNumber agreedOptions) CBOR.Term vNumber agreedOptions
cborTermVersionDataCodec = VersionDataCodec {
encodeData = \(DictVersion codec _) -> encodeTerm codec,
decodeData = \(DictVersion codec _) -> decodeTerm codec,
getAgreedOptions = \(DictVersion _ f) -> f
cborTermVersionDataCodec :: CodecCBORTerm Text vData
-> VersionDataCodec CBOR.Term vNumber vData
cborTermVersionDataCodec codec = VersionDataCodec {
encodeData = encodeTerm codec,
decodeData = decodeTerm codec
}

-- |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ import Ouroboros.Network.Protocol.Handshake.Version
--
handshakeServerPeer
:: Ord vNumber
=> VersionDataCodec extra vParams vNumber agreedOptions
-> (forall vData. extra vData -> vData -> vData -> Accept vData)
-> Versions vNumber extra r
=> VersionDataCodec vParams vNumber vData
-> (vData -> vData -> Accept vData)
-> Versions vNumber vData r
-> Peer (Handshake vNumber vParams)
AsServer StPropose m
(Either (RefuseReason vNumber) (r, agreedOptions))
handshakeServerPeer VersionDataCodec {encodeData, decodeData, getAgreedOptions} acceptVersion versions =
(Either (RefuseReason vNumber) (r, vNumber, vData))
handshakeServerPeer VersionDataCodec {encodeData, decodeData} acceptVersion versions =
-- await for versions proposed by a client
Await (ClientAgency TokPropose) $ \msg -> case msg of

Expand All @@ -48,25 +48,26 @@ handshakeServerPeer VersionDataCodec {encodeData, decodeData, getAgreedOptions}

vNumber:_ ->
case (getVersions versions Map.! vNumber, vMap Map.! vNumber) of
(Sigma vData version, vParams) -> case decodeData (versionExtra version) vParams of
(Version app vData, vParams) -> case decodeData vParams of
Left err ->
let vReason = HandshakeDecodeError vNumber err
in Yield (ServerAgency TokConfirm)
(MsgRefuse vReason)
(Done TokDone $ Left vReason)

Right vData' ->
case acceptVersion (versionExtra version) vData vData' of
case acceptVersion vData vData' of

-- We agree on the version; send back the agreed version
-- number @vNumber@ and encoded data associated with our
-- version.
Accept agreedData ->
Yield (ServerAgency TokConfirm)
(MsgAcceptVersion vNumber (encodeData (versionExtra version) agreedData))
(MsgAcceptVersion vNumber (encodeData agreedData))
(Done TokDone $ Right $
( runApplication (versionApplication version) agreedData
, getAgreedOptions (versionExtra version) vNumber agreedData
( runApplication app agreedData
, vNumber
, agreedData
))

-- We disagree on the version.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,11 @@ unversionedProtocolDataCodec = CodecCBORTerm {encodeTerm, decodeTerm}
-- | Make a 'Versions' for an unversioned protocol. Only use this for
-- tests and demos where proper versioning is excessive.
--
unversionedProtocol :: app -> Versions UnversionedProtocol (DictVersion UnversionedProtocol UnversionedProtocolData) app
unversionedProtocol =
simpleSingletonVersions UnversionedProtocol UnversionedProtocolData
(DictVersion unversionedProtocolDataCodec (\_ _ -> UnversionedProtocolData))
unversionedProtocol :: app
-> Versions UnversionedProtocol
UnversionedProtocolData
app
unversionedProtocol = simpleSingletonVersions UnversionedProtocol UnversionedProtocolData


-- | 'Handshake' codec used in various tests.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
Expand All @@ -11,7 +12,6 @@ module Ouroboros.Network.Protocol.Handshake.Version
( Versions (..)
, Application (..)
, Version (..)
, Sigma (..)
, Accept (..)
, Acceptable (..)
, Dict (..)
Expand Down Expand Up @@ -59,18 +59,13 @@ import Ouroboros.Network.CodecCBORTerm
-- > ]
-- >
--
newtype Versions vNum extra r = Versions
{ getVersions :: Map vNum (Sigma (Version extra r))
newtype Versions vNum vData r = Versions
{ getVersions :: Map vNum (Version vData r)
}
deriving (Semigroup)

instance Functor (Versions vNum extra) where
fmap f (Versions vs) = Versions $ Map.map fmapSigma vs
where
fmapSigma (Sigma t (Version (Application app) extra)) = Sigma t (Version (Application $ \x -> f (app x)) extra)

data Sigma f where
Sigma :: !t -> !(f t) -> Sigma f
fmap f (Versions vs) = Versions $ Map.map (fmap f) vs


-- | Useful for folding multiple 'Versions'.
Expand Down Expand Up @@ -104,14 +99,17 @@ class Acceptable v where
acceptableVersion :: v -> v -> Accept v

-- | Takes a pair of version data: local then remote.
newtype Application r vData = Application
newtype Application vData r = Application
{ runApplication :: vData -> r
}
deriving Functor


data Version extra r vData = Version
{ versionApplication :: Application r vData
, versionExtra :: extra vData
data Version vData r = Version
{ versionApplication :: Application vData r
, versionData :: vData
}
deriving Functor

data VersionMismatch vNum where
NoCommonVersion :: VersionMismatch vNum
Expand All @@ -124,15 +122,13 @@ data Dict constraint thing where
-- 'hanshakeParams' is instatiated in either "Ouroboros.Network.NodeToNode" or
-- "Ouroboros.Network.NodeToClient" to 'HandshakeParams'.
--
data DictVersion vNumber agreedOptions vData where
data DictVersion vNumber vData where
DictVersion :: ( Typeable vData
, Acceptable vData
, Show vData
)
=> CodecCBORTerm Text vData
-> (vNumber -> vData -> agreedOptions)
-- ^ agreed vData
-> DictVersion vNumber agreedOptions vData
-> DictVersion vNumber vData

--
-- Simple version negotation
Expand All @@ -143,10 +139,9 @@ data DictVersion vNumber agreedOptions vData where
simpleSingletonVersions
:: vNum
-> vData
-> extra vData
-> r
-> Versions vNum extra r
simpleSingletonVersions vNum vData extra r =
-> Versions vNum vData r
simpleSingletonVersions vNum vData r =
Versions
$ Map.singleton vNum
(Sigma vData (Version (Application $ \_ -> r) extra))
(Version (Application (\_ -> r)) vData)
Loading

0 comments on commit 9d5832c

Please sign in to comment.