Skip to content

Commit

Permalink
Add handshake protocol definition
Browse files Browse the repository at this point in the history
  • Loading branch information
jhbertra committed Feb 8, 2023
1 parent 60176ec commit 32a33f5
Show file tree
Hide file tree
Showing 6 changed files with 365 additions and 0 deletions.
4 changes: 4 additions & 0 deletions marlowe-protocols/marlowe-protocols.cabal
Expand Up @@ -55,6 +55,10 @@ library
Network.Protocol.ChainSeek.Server
Network.Protocol.ChainSeek.Types
Network.Protocol.ChainSeek.TH
Network.Protocol.Handshake.Client
Network.Protocol.Handshake.Codec
Network.Protocol.Handshake.Server
Network.Protocol.Handshake.Types
Network.Protocol.Job.Client
Network.Protocol.Job.Codec
Network.Protocol.Job.Server
Expand Down
2 changes: 2 additions & 0 deletions marlowe-protocols/src/Network/Protocol/Codec.hs
Expand Up @@ -8,6 +8,8 @@ module Network.Protocol.Codec
, GetMessage
, PutMessage
, binaryCodec
, decodeGet
, encodePut
) where

import Control.Exception (Exception)
Expand Down
56 changes: 56 additions & 0 deletions marlowe-protocols/src/Network/Protocol/Handshake/Client.hs
@@ -0,0 +1,56 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}

-- | A generic client for the handshake protocol. Includes a function for
-- interpreting a client as a typed-protocols peer that can be executed with a
-- driver and a codec.

module Network.Protocol.Handshake.Client
where

import Network.Protocol.Handshake.Types
import Network.TypedProtocol

-- | A generic client for the handshake protocol.
data HandshakeClient h client m a = HandshakeClient
{ handshake :: m h
, recvMsgReject :: m a
, recvMsgAccept :: m (client m a)
}
deriving Functor

hoistHandshakeClient
:: Functor m
=> (forall x. (forall y. m y -> n y) -> client m x -> client n x)
-> (forall x. m x -> n x)
-> HandshakeClient h client m a
-> HandshakeClient h client n a
hoistHandshakeClient hoistClient f HandshakeClient{..} = HandshakeClient
{ handshake = f handshake
, recvMsgReject = f recvMsgReject
, recvMsgAccept = f $ hoistClient f <$> recvMsgAccept
}

handshakeClientPeer
:: forall client m h ps st a
. Functor m
=> (forall x. client m x -> Peer ps 'AsClient st m x)
-> HandshakeClient h client m a
-> Peer (Handshake h ps) 'AsClient ('StInit st) m a
handshakeClientPeer clientPeer HandshakeClient{..} =
Effect $ peerInit <$> handshake
where
peerInit :: h -> Peer (Handshake h ps) 'AsClient ('StInit st) m a
peerInit h =
Yield (ClientAgency TokInit) (MsgHandshake h) $
Await (ServerAgency TokHandshake) \case
MsgReject -> Effect $ Done TokDone <$> recvMsgReject
MsgAccept -> Effect $ liftPeer . clientPeer <$> recvMsgAccept

liftPeer :: forall st'. Peer ps 'AsClient st' m a -> Peer (Handshake h ps) 'AsClient ('StLift st') m a
liftPeer = \case
Effect m -> Effect $ liftPeer <$> m
Done tok a -> Done (TokLiftNobody tok) a
Yield (ClientAgency tok) msg next -> Yield (ClientAgency $ TokLiftClient tok) (MsgLift msg) $ liftPeer next
Await (ServerAgency tok) next -> Await (ServerAgency $ TokLiftServer tok) \(MsgLift msg) -> liftPeer $ next msg
93 changes: 93 additions & 0 deletions marlowe-protocols/src/Network/Protocol/Handshake/Codec.hs
@@ -0,0 +1,93 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExplicitNamespaces #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}

module Network.Protocol.Handshake.Codec
where

import Control.Monad ((<=<))
import Data.Binary
import Data.Binary.Put (runPut)
import Data.ByteString.Lazy (toStrict)
import qualified Data.ByteString.Lazy as LBS
import Data.Foldable (Foldable(fold))
import Network.Protocol.Codec (DeserializeError(..), decodeGet)
import Network.Protocol.Handshake.Types
import Network.TypedProtocol.Codec

codecHandshake
:: forall ps h m
. (Monad m, Binary h)
=> Codec ps DeserializeError m LBS.ByteString
-> Codec (Handshake h ps) DeserializeError m LBS.ByteString
codecHandshake (Codec encodeMsg decodeMsg) = Codec
( \case
ClientAgency TokInit -> \case
MsgHandshake h -> runPut do
putWord8 0x00
put h
ClientAgency (TokLiftClient tok) -> \case
MsgLift msg -> fold
[ runPut $ putWord8 0x3
, encodeMsg (ClientAgency tok) msg
]

ServerAgency TokHandshake -> runPut . \case
MsgAccept -> putWord8 0x01
MsgReject -> putWord8 0x02

ServerAgency (TokLiftServer tok) -> \case
MsgLift msg -> fold
[ runPut $ putWord8 0x3
, encodeMsg (ServerAgency tok) msg
]
)
(\tok -> decodeGet getWord8 >>= handleTag tok)
where
handleTag
:: PeerHasAgency pr (st :: Handshake h ps)
-> DecodeStep LBS.ByteString DeserializeError m Word8
-> m (DecodeStep LBS.ByteString DeserializeError m (SomeMessage st))
handleTag tok = \case
DecodeFail f -> pure $ DecodeFail f
DecodePartial next -> pure $ DecodePartial $ handleTag tok <=< next
DecodeDone tag mUnconsumed -> case tag of
0x00 -> case tok of
ClientAgency TokInit -> handleH <$> decodeGet get
_ -> failInvalidTag "Invalid protocol state for MsgHandshake" mUnconsumed
0x01 -> case tok of
ServerAgency TokHandshake -> pure $ DecodeDone (SomeMessage MsgAccept) mUnconsumed
_ -> failInvalidTag "Invalid protocol state for MsgAccept" mUnconsumed
0x02 -> case tok of
ServerAgency TokHandshake -> pure $ DecodeDone (SomeMessage MsgAccept) mUnconsumed
_ -> failInvalidTag "Invalid protocol state for MsgAccept" mUnconsumed
0x03 -> case tok of
ClientAgency (TokLiftClient tok') -> handleMsg <$> decodeMsg (ClientAgency tok')
ServerAgency (TokLiftServer tok') -> handleMsg <$> decodeMsg (ServerAgency tok')
_ -> failInvalidTag "Invalid protocol state for MsgLift" mUnconsumed
_ -> failInvalidTag ("Invalid msg tag " <> show tag) mUnconsumed
where
failInvalidTag message mUnconsumed = pure $ DecodeFail DeserializeError
{ message
, offset = 0
, unconsumedInput = foldMap toStrict mUnconsumed
}

handleH
:: DecodeStep LBS.ByteString DeserializeError m h
-> DecodeStep LBS.ByteString DeserializeError m (SomeMessage ('StInit st' :: Handshake h ps))
handleH = \case
DecodeFail f -> DecodeFail f
DecodePartial next -> DecodePartial $ fmap handleH <$> next
DecodeDone h mUnconsumed -> DecodeDone (SomeMessage $ MsgHandshake h) mUnconsumed

handleMsg
:: DecodeStep LBS.ByteString DeserializeError m (SomeMessage st')
-> DecodeStep LBS.ByteString DeserializeError m (SomeMessage ('StLift st'))
handleMsg = \case
DecodeFail f -> DecodeFail f
DecodePartial next -> DecodePartial $ fmap handleMsg <$> next
DecodeDone (SomeMessage msg) mUnconsumed -> DecodeDone (SomeMessage $ MsgLift msg) mUnconsumed
56 changes: 56 additions & 0 deletions marlowe-protocols/src/Network/Protocol/Handshake/Server.hs
@@ -0,0 +1,56 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}

-- | A generic server for the handshake protocol. Includes a function for
-- interpreting a server as a typed-protocols peer that can be executed with a
-- driver and a codec.

module Network.Protocol.Handshake.Server
where

import Data.Bifunctor (Bifunctor(bimap))
import Data.Functor ((<&>))
import Network.Protocol.Handshake.Types
import Network.TypedProtocol

-- | A generic server for the handshake protocol.
newtype HandshakeServer h server m a = HandshakeServer
{ recvMsgHandshake :: h -> m (Either a (server m a))
}

instance (Functor m, Functor (server m)) => Functor (HandshakeServer h server m) where
fmap f HandshakeServer{..} = HandshakeServer
{ recvMsgHandshake = fmap (bimap f $ fmap f) . recvMsgHandshake
}

hoistHandshakeServer
:: Functor m
=> (forall x. (forall y. m y -> n y) -> server m x -> server n x)
-> (forall x. m x -> n x)
-> HandshakeServer h server m a
-> HandshakeServer h server n a
hoistHandshakeServer hoistServer f HandshakeServer{..} = HandshakeServer
{ recvMsgHandshake = f . (fmap . fmap) (hoistServer f) . recvMsgHandshake
}

handshakeServerPeer
:: forall client m h ps st a
. Functor m
=> (forall x. client m x -> Peer ps 'AsServer st m x)
-> HandshakeServer h client m a
-> Peer (Handshake h ps) 'AsServer ('StInit st) m a
handshakeServerPeer serverPeer HandshakeServer{..} =
Await (ClientAgency TokInit) \case
MsgHandshake h -> Effect $ recvMsgHandshake h <&> \case
Left a ->
Yield (ServerAgency TokHandshake) MsgReject $ Done TokDone a
Right server ->
Yield (ServerAgency TokHandshake) MsgAccept $ liftPeer $ serverPeer server
where
liftPeer :: forall st'. Peer ps 'AsServer st' m a -> Peer (Handshake h ps) 'AsServer ('StLift st') m a
liftPeer = \case
Effect m -> Effect $ liftPeer <$> m
Done tok a -> Done (TokLiftNobody tok) a
Yield (ServerAgency tok) msg next -> Yield (ServerAgency $ TokLiftServer tok) (MsgLift msg) $ liftPeer next
Await (ClientAgency tok) next -> Await (ClientAgency $ TokLiftClient tok) \(MsgLift msg) -> liftPeer $ next msg
154 changes: 154 additions & 0 deletions marlowe-protocols/src/Network/Protocol/Handshake/Types.hs
@@ -0,0 +1,154 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

-- | The type of the handshake protocol.
--
-- The job protocol is used to establish a connection between two peers, and
-- gives them a chance to confirm that they speak the same protocol.

module Network.Protocol.Handshake.Types
where

import Data.Aeson (ToJSON, Value(..), object, (.=))
import GHC.Show (showSpace)
import Network.Protocol.Codec.Spec (ArbitraryMessage(..), MessageEq(..), ShowProtocol(..))
import Network.Protocol.Driver (MessageToJSON(..))
import Network.TypedProtocol
import Network.TypedProtocol.Codec (AnyMessageAndAgency(..))
import Test.QuickCheck (Arbitrary(arbitrary), oneof, shrink)

data Handshake h ps where
StInit :: ps -> Handshake h ps
StHandshake :: ps -> Handshake h ps
StLift :: ps -> Handshake h ps
StDone :: Handshake h ps

instance Protocol ps => Protocol (Handshake h ps) where
data Message (Handshake h ps) st st' where
MsgHandshake :: h -> Message (Handshake h ps)
('StInit st)
('StHandshake st)
MsgAccept :: Message (Handshake h ps)
('StHandshake st)
('StLift st)
MsgReject :: Message (Handshake h ps)
('StHandshake st)
'StDone
MsgLift :: Message ps st st' -> Message (Handshake h ps)
('StLift st)
('StLift st')

data ClientHasAgency st where
TokInit :: ClientHasAgency ('StInit st')
TokLiftClient :: ClientHasAgency st' -> ClientHasAgency ('StLift st')

data ServerHasAgency st where
TokHandshake :: ServerHasAgency ('StHandshake st')
TokLiftServer :: ServerHasAgency st' -> ServerHasAgency ('StLift st')

data NobodyHasAgency st where
TokDone :: NobodyHasAgency 'StDone
TokLiftNobody :: NobodyHasAgency st' -> NobodyHasAgency ('StLift st')

exclusionLemma_ClientAndServerHaveAgency = \case
TokInit -> \case
TokLiftClient c_tok -> \case
TokLiftServer s_tok -> exclusionLemma_ClientAndServerHaveAgency c_tok s_tok

exclusionLemma_NobodyAndClientHaveAgency = \case
TokDone -> \case
TokLiftNobody n_tok -> \case
TokLiftClient c_tok -> exclusionLemma_NobodyAndClientHaveAgency n_tok c_tok

exclusionLemma_NobodyAndServerHaveAgency = \case
TokDone -> \case
TokLiftNobody n_tok -> \case
TokLiftServer s_tok -> exclusionLemma_NobodyAndServerHaveAgency n_tok s_tok

instance (ArbitraryMessage ps, Arbitrary h) => ArbitraryMessage (Handshake h ps) where
arbitraryMessage = oneof
[ AnyMessageAndAgency (ClientAgency TokInit) . MsgHandshake <$> arbitrary
, pure $ AnyMessageAndAgency (ServerAgency TokHandshake) MsgReject
, pure $ AnyMessageAndAgency (ServerAgency TokHandshake) MsgAccept
, do
AnyMessageAndAgency tok msg <- arbitraryMessage @ps
pure case tok of
ClientAgency tok' -> AnyMessageAndAgency (ClientAgency $ TokLiftClient tok') (MsgLift msg)
ServerAgency tok' -> AnyMessageAndAgency (ServerAgency $ TokLiftServer tok') (MsgLift msg)
]
shrinkMessage = \case
ClientAgency TokInit -> \case
MsgHandshake h -> MsgHandshake <$> shrink h
ClientAgency (TokLiftClient tok) -> \case
MsgLift msg -> MsgLift <$> shrinkMessage (ClientAgency tok) msg
ServerAgency TokHandshake -> const []
ServerAgency (TokLiftServer tok) -> \case
MsgLift msg -> MsgLift <$> shrinkMessage (ServerAgency tok) msg

instance (MessageEq ps, Eq h) => MessageEq (Handshake h ps) where
messageEq (AnyMessageAndAgency tok1 msg1) (AnyMessageAndAgency tok2 msg2)= case (tok1, tok2) of
(ClientAgency TokInit, ClientAgency TokInit) -> case (msg1, msg2) of
(MsgHandshake h, MsgHandshake h') -> h == h'
(ClientAgency TokInit, _) -> False
(ClientAgency (TokLiftClient tok1'), ClientAgency (TokLiftClient tok2')) -> case (msg1, msg2) of
(MsgLift msg1', MsgLift msg2') ->
messageEq (AnyMessageAndAgency (ClientAgency tok1') msg1') (AnyMessageAndAgency (ClientAgency tok2') msg2')
(ClientAgency (TokLiftClient _), _) -> False
(ServerAgency TokHandshake, ServerAgency TokHandshake) -> case (msg1, msg2) of
(MsgAccept, MsgAccept) -> True
(MsgAccept, MsgReject) -> False
(MsgReject, MsgAccept) -> False
(MsgReject, MsgReject) -> True
(ServerAgency TokHandshake, _) -> False
(ServerAgency (TokLiftServer tok1'), ServerAgency (TokLiftServer tok2')) -> case (msg1, msg2) of
(MsgLift msg1', MsgLift msg2') ->
messageEq (AnyMessageAndAgency (ServerAgency tok1') msg1') (AnyMessageAndAgency (ServerAgency tok2') msg2')
(ServerAgency (TokLiftServer _), _) -> False

instance (ShowProtocol ps, Show h) => ShowProtocol (Handshake h ps) where
showsPrecMessage p tok = \case
MsgHandshake h -> showParen (p >= 11)
( showString "MsgHandshake"
. showSpace
. showsPrec 11 h
)
MsgAccept -> showString "MsgAccept"
MsgReject -> showString "MsgReject"
MsgLift msg -> showParen (p >= 11)
( showString "MsgLift"
. showSpace
. case tok of
ClientAgency (TokLiftClient tok') -> showsPrecMessage 11 (ClientAgency tok') msg
ServerAgency (TokLiftServer tok') -> showsPrecMessage 11 (ServerAgency tok') msg
)
showsPrecServerHasAgency p = \case
TokHandshake -> showString "TokHandshake"
TokLiftServer tok -> showParen (p >= 11)
( showString "TokLiftServer"
. showSpace
. showsPrecServerHasAgency 11 tok
)
showsPrecClientHasAgency p = \case
TokInit -> showString "TokInit"
TokLiftClient tok -> showParen (p >= 11)
( showString "TokLiftClient"
. showSpace
. showsPrecClientHasAgency 11 tok
)

instance (MessageToJSON ps, ToJSON h) => MessageToJSON (Handshake h ps) where
messageToJSON = \case
ClientAgency TokInit -> \case
MsgHandshake h -> object [ "handshake" .= h ]
ClientAgency (TokLiftClient tok) -> \case
MsgLift msg -> object [ "lift" .= messageToJSON (ClientAgency tok) msg ]
ServerAgency TokHandshake -> String . \case
MsgAccept -> "accept"
MsgReject -> "reject"
ServerAgency (TokLiftServer tok) -> \case
MsgLift msg -> object [ "lift" .= messageToJSON (ServerAgency tok) msg ]

0 comments on commit 32a33f5

Please sign in to comment.