Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
365 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
56 changes: 56 additions & 0 deletions
56
marlowe-protocols/src/Network/Protocol/Handshake/Client.hs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
{-# LANGUAGE DataKinds #-} | ||
{-# LANGUAGE GADTs #-} | ||
{-# LANGUAGE RankNTypes #-} | ||
|
||
-- | A generic client for the handshake protocol. Includes a function for | ||
-- interpreting a client as a typed-protocols peer that can be executed with a | ||
-- driver and a codec. | ||
|
||
module Network.Protocol.Handshake.Client | ||
where | ||
|
||
import Network.Protocol.Handshake.Types | ||
import Network.TypedProtocol | ||
|
||
-- | A generic client for the handshake protocol. | ||
data HandshakeClient h client m a = HandshakeClient | ||
{ handshake :: m h | ||
, recvMsgReject :: m a | ||
, recvMsgAccept :: m (client m a) | ||
} | ||
deriving Functor | ||
|
||
hoistHandshakeClient | ||
:: Functor m | ||
=> (forall x. (forall y. m y -> n y) -> client m x -> client n x) | ||
-> (forall x. m x -> n x) | ||
-> HandshakeClient h client m a | ||
-> HandshakeClient h client n a | ||
hoistHandshakeClient hoistClient f HandshakeClient{..} = HandshakeClient | ||
{ handshake = f handshake | ||
, recvMsgReject = f recvMsgReject | ||
, recvMsgAccept = f $ hoistClient f <$> recvMsgAccept | ||
} | ||
|
||
handshakeClientPeer | ||
:: forall client m h ps st a | ||
. Functor m | ||
=> (forall x. client m x -> Peer ps 'AsClient st m x) | ||
-> HandshakeClient h client m a | ||
-> Peer (Handshake h ps) 'AsClient ('StInit st) m a | ||
handshakeClientPeer clientPeer HandshakeClient{..} = | ||
Effect $ peerInit <$> handshake | ||
where | ||
peerInit :: h -> Peer (Handshake h ps) 'AsClient ('StInit st) m a | ||
peerInit h = | ||
Yield (ClientAgency TokInit) (MsgHandshake h) $ | ||
Await (ServerAgency TokHandshake) \case | ||
MsgReject -> Effect $ Done TokDone <$> recvMsgReject | ||
MsgAccept -> Effect $ liftPeer . clientPeer <$> recvMsgAccept | ||
|
||
liftPeer :: forall st'. Peer ps 'AsClient st' m a -> Peer (Handshake h ps) 'AsClient ('StLift st') m a | ||
liftPeer = \case | ||
Effect m -> Effect $ liftPeer <$> m | ||
Done tok a -> Done (TokLiftNobody tok) a | ||
Yield (ClientAgency tok) msg next -> Yield (ClientAgency $ TokLiftClient tok) (MsgLift msg) $ liftPeer next | ||
Await (ServerAgency tok) next -> Await (ServerAgency $ TokLiftServer tok) \(MsgLift msg) -> liftPeer $ next msg |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
{-# LANGUAGE DataKinds #-} | ||
{-# LANGUAGE ExplicitNamespaces #-} | ||
{-# LANGUAGE GADTs #-} | ||
{-# LANGUAGE KindSignatures #-} | ||
{-# LANGUAGE PolyKinds #-} | ||
{-# LANGUAGE RankNTypes #-} | ||
|
||
module Network.Protocol.Handshake.Codec | ||
where | ||
|
||
import Control.Monad ((<=<)) | ||
import Data.Binary | ||
import Data.Binary.Put (runPut) | ||
import Data.ByteString.Lazy (toStrict) | ||
import qualified Data.ByteString.Lazy as LBS | ||
import Data.Foldable (Foldable(fold)) | ||
import Network.Protocol.Codec (DeserializeError(..), decodeGet) | ||
import Network.Protocol.Handshake.Types | ||
import Network.TypedProtocol.Codec | ||
|
||
codecHandshake | ||
:: forall ps h m | ||
. (Monad m, Binary h) | ||
=> Codec ps DeserializeError m LBS.ByteString | ||
-> Codec (Handshake h ps) DeserializeError m LBS.ByteString | ||
codecHandshake (Codec encodeMsg decodeMsg) = Codec | ||
( \case | ||
ClientAgency TokInit -> \case | ||
MsgHandshake h -> runPut do | ||
putWord8 0x00 | ||
put h | ||
ClientAgency (TokLiftClient tok) -> \case | ||
MsgLift msg -> fold | ||
[ runPut $ putWord8 0x3 | ||
, encodeMsg (ClientAgency tok) msg | ||
] | ||
|
||
ServerAgency TokHandshake -> runPut . \case | ||
MsgAccept -> putWord8 0x01 | ||
MsgReject -> putWord8 0x02 | ||
|
||
ServerAgency (TokLiftServer tok) -> \case | ||
MsgLift msg -> fold | ||
[ runPut $ putWord8 0x3 | ||
, encodeMsg (ServerAgency tok) msg | ||
] | ||
) | ||
(\tok -> decodeGet getWord8 >>= handleTag tok) | ||
where | ||
handleTag | ||
:: PeerHasAgency pr (st :: Handshake h ps) | ||
-> DecodeStep LBS.ByteString DeserializeError m Word8 | ||
-> m (DecodeStep LBS.ByteString DeserializeError m (SomeMessage st)) | ||
handleTag tok = \case | ||
DecodeFail f -> pure $ DecodeFail f | ||
DecodePartial next -> pure $ DecodePartial $ handleTag tok <=< next | ||
DecodeDone tag mUnconsumed -> case tag of | ||
0x00 -> case tok of | ||
ClientAgency TokInit -> handleH <$> decodeGet get | ||
_ -> failInvalidTag "Invalid protocol state for MsgHandshake" mUnconsumed | ||
0x01 -> case tok of | ||
ServerAgency TokHandshake -> pure $ DecodeDone (SomeMessage MsgAccept) mUnconsumed | ||
_ -> failInvalidTag "Invalid protocol state for MsgAccept" mUnconsumed | ||
0x02 -> case tok of | ||
ServerAgency TokHandshake -> pure $ DecodeDone (SomeMessage MsgAccept) mUnconsumed | ||
_ -> failInvalidTag "Invalid protocol state for MsgAccept" mUnconsumed | ||
0x03 -> case tok of | ||
ClientAgency (TokLiftClient tok') -> handleMsg <$> decodeMsg (ClientAgency tok') | ||
ServerAgency (TokLiftServer tok') -> handleMsg <$> decodeMsg (ServerAgency tok') | ||
_ -> failInvalidTag "Invalid protocol state for MsgLift" mUnconsumed | ||
_ -> failInvalidTag ("Invalid msg tag " <> show tag) mUnconsumed | ||
where | ||
failInvalidTag message mUnconsumed = pure $ DecodeFail DeserializeError | ||
{ message | ||
, offset = 0 | ||
, unconsumedInput = foldMap toStrict mUnconsumed | ||
} | ||
|
||
handleH | ||
:: DecodeStep LBS.ByteString DeserializeError m h | ||
-> DecodeStep LBS.ByteString DeserializeError m (SomeMessage ('StInit st' :: Handshake h ps)) | ||
handleH = \case | ||
DecodeFail f -> DecodeFail f | ||
DecodePartial next -> DecodePartial $ fmap handleH <$> next | ||
DecodeDone h mUnconsumed -> DecodeDone (SomeMessage $ MsgHandshake h) mUnconsumed | ||
|
||
handleMsg | ||
:: DecodeStep LBS.ByteString DeserializeError m (SomeMessage st') | ||
-> DecodeStep LBS.ByteString DeserializeError m (SomeMessage ('StLift st')) | ||
handleMsg = \case | ||
DecodeFail f -> DecodeFail f | ||
DecodePartial next -> DecodePartial $ fmap handleMsg <$> next | ||
DecodeDone (SomeMessage msg) mUnconsumed -> DecodeDone (SomeMessage $ MsgLift msg) mUnconsumed |
56 changes: 56 additions & 0 deletions
56
marlowe-protocols/src/Network/Protocol/Handshake/Server.hs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
{-# LANGUAGE DataKinds #-} | ||
{-# LANGUAGE GADTs #-} | ||
{-# LANGUAGE RankNTypes #-} | ||
|
||
-- | A generic server for the handshake protocol. Includes a function for | ||
-- interpreting a server as a typed-protocols peer that can be executed with a | ||
-- driver and a codec. | ||
|
||
module Network.Protocol.Handshake.Server | ||
where | ||
|
||
import Data.Bifunctor (Bifunctor(bimap)) | ||
import Data.Functor ((<&>)) | ||
import Network.Protocol.Handshake.Types | ||
import Network.TypedProtocol | ||
|
||
-- | A generic server for the handshake protocol. | ||
newtype HandshakeServer h server m a = HandshakeServer | ||
{ recvMsgHandshake :: h -> m (Either a (server m a)) | ||
} | ||
|
||
instance (Functor m, Functor (server m)) => Functor (HandshakeServer h server m) where | ||
fmap f HandshakeServer{..} = HandshakeServer | ||
{ recvMsgHandshake = fmap (bimap f $ fmap f) . recvMsgHandshake | ||
} | ||
|
||
hoistHandshakeServer | ||
:: Functor m | ||
=> (forall x. (forall y. m y -> n y) -> server m x -> server n x) | ||
-> (forall x. m x -> n x) | ||
-> HandshakeServer h server m a | ||
-> HandshakeServer h server n a | ||
hoistHandshakeServer hoistServer f HandshakeServer{..} = HandshakeServer | ||
{ recvMsgHandshake = f . (fmap . fmap) (hoistServer f) . recvMsgHandshake | ||
} | ||
|
||
handshakeServerPeer | ||
:: forall client m h ps st a | ||
. Functor m | ||
=> (forall x. client m x -> Peer ps 'AsServer st m x) | ||
-> HandshakeServer h client m a | ||
-> Peer (Handshake h ps) 'AsServer ('StInit st) m a | ||
handshakeServerPeer serverPeer HandshakeServer{..} = | ||
Await (ClientAgency TokInit) \case | ||
MsgHandshake h -> Effect $ recvMsgHandshake h <&> \case | ||
Left a -> | ||
Yield (ServerAgency TokHandshake) MsgReject $ Done TokDone a | ||
Right server -> | ||
Yield (ServerAgency TokHandshake) MsgAccept $ liftPeer $ serverPeer server | ||
where | ||
liftPeer :: forall st'. Peer ps 'AsServer st' m a -> Peer (Handshake h ps) 'AsServer ('StLift st') m a | ||
liftPeer = \case | ||
Effect m -> Effect $ liftPeer <$> m | ||
Done tok a -> Done (TokLiftNobody tok) a | ||
Yield (ServerAgency tok) msg next -> Yield (ServerAgency $ TokLiftServer tok) (MsgLift msg) $ liftPeer next | ||
Await (ClientAgency tok) next -> Await (ClientAgency $ TokLiftClient tok) \(MsgLift msg) -> liftPeer $ next msg |
154 changes: 154 additions & 0 deletions
154
marlowe-protocols/src/Network/Protocol/Handshake/Types.hs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
{-# LANGUAGE DataKinds #-} | ||
{-# LANGUAGE EmptyCase #-} | ||
{-# LANGUAGE GADTs #-} | ||
{-# LANGUAGE KindSignatures #-} | ||
{-# LANGUAGE PolyKinds #-} | ||
{-# LANGUAGE TypeFamilies #-} | ||
{-# LANGUAGE TypeOperators #-} | ||
|
||
-- | The type of the handshake protocol. | ||
-- | ||
-- The job protocol is used to establish a connection between two peers, and | ||
-- gives them a chance to confirm that they speak the same protocol. | ||
|
||
module Network.Protocol.Handshake.Types | ||
where | ||
|
||
import Data.Aeson (ToJSON, Value(..), object, (.=)) | ||
import GHC.Show (showSpace) | ||
import Network.Protocol.Codec.Spec (ArbitraryMessage(..), MessageEq(..), ShowProtocol(..)) | ||
import Network.Protocol.Driver (MessageToJSON(..)) | ||
import Network.TypedProtocol | ||
import Network.TypedProtocol.Codec (AnyMessageAndAgency(..)) | ||
import Test.QuickCheck (Arbitrary(arbitrary), oneof, shrink) | ||
|
||
data Handshake h ps where | ||
StInit :: ps -> Handshake h ps | ||
StHandshake :: ps -> Handshake h ps | ||
StLift :: ps -> Handshake h ps | ||
StDone :: Handshake h ps | ||
|
||
instance Protocol ps => Protocol (Handshake h ps) where | ||
data Message (Handshake h ps) st st' where | ||
MsgHandshake :: h -> Message (Handshake h ps) | ||
('StInit st) | ||
('StHandshake st) | ||
MsgAccept :: Message (Handshake h ps) | ||
('StHandshake st) | ||
('StLift st) | ||
MsgReject :: Message (Handshake h ps) | ||
('StHandshake st) | ||
'StDone | ||
MsgLift :: Message ps st st' -> Message (Handshake h ps) | ||
('StLift st) | ||
('StLift st') | ||
|
||
data ClientHasAgency st where | ||
TokInit :: ClientHasAgency ('StInit st') | ||
TokLiftClient :: ClientHasAgency st' -> ClientHasAgency ('StLift st') | ||
|
||
data ServerHasAgency st where | ||
TokHandshake :: ServerHasAgency ('StHandshake st') | ||
TokLiftServer :: ServerHasAgency st' -> ServerHasAgency ('StLift st') | ||
|
||
data NobodyHasAgency st where | ||
TokDone :: NobodyHasAgency 'StDone | ||
TokLiftNobody :: NobodyHasAgency st' -> NobodyHasAgency ('StLift st') | ||
|
||
exclusionLemma_ClientAndServerHaveAgency = \case | ||
TokInit -> \case | ||
TokLiftClient c_tok -> \case | ||
TokLiftServer s_tok -> exclusionLemma_ClientAndServerHaveAgency c_tok s_tok | ||
|
||
exclusionLemma_NobodyAndClientHaveAgency = \case | ||
TokDone -> \case | ||
TokLiftNobody n_tok -> \case | ||
TokLiftClient c_tok -> exclusionLemma_NobodyAndClientHaveAgency n_tok c_tok | ||
|
||
exclusionLemma_NobodyAndServerHaveAgency = \case | ||
TokDone -> \case | ||
TokLiftNobody n_tok -> \case | ||
TokLiftServer s_tok -> exclusionLemma_NobodyAndServerHaveAgency n_tok s_tok | ||
|
||
instance (ArbitraryMessage ps, Arbitrary h) => ArbitraryMessage (Handshake h ps) where | ||
arbitraryMessage = oneof | ||
[ AnyMessageAndAgency (ClientAgency TokInit) . MsgHandshake <$> arbitrary | ||
, pure $ AnyMessageAndAgency (ServerAgency TokHandshake) MsgReject | ||
, pure $ AnyMessageAndAgency (ServerAgency TokHandshake) MsgAccept | ||
, do | ||
AnyMessageAndAgency tok msg <- arbitraryMessage @ps | ||
pure case tok of | ||
ClientAgency tok' -> AnyMessageAndAgency (ClientAgency $ TokLiftClient tok') (MsgLift msg) | ||
ServerAgency tok' -> AnyMessageAndAgency (ServerAgency $ TokLiftServer tok') (MsgLift msg) | ||
] | ||
shrinkMessage = \case | ||
ClientAgency TokInit -> \case | ||
MsgHandshake h -> MsgHandshake <$> shrink h | ||
ClientAgency (TokLiftClient tok) -> \case | ||
MsgLift msg -> MsgLift <$> shrinkMessage (ClientAgency tok) msg | ||
ServerAgency TokHandshake -> const [] | ||
ServerAgency (TokLiftServer tok) -> \case | ||
MsgLift msg -> MsgLift <$> shrinkMessage (ServerAgency tok) msg | ||
|
||
instance (MessageEq ps, Eq h) => MessageEq (Handshake h ps) where | ||
messageEq (AnyMessageAndAgency tok1 msg1) (AnyMessageAndAgency tok2 msg2)= case (tok1, tok2) of | ||
(ClientAgency TokInit, ClientAgency TokInit) -> case (msg1, msg2) of | ||
(MsgHandshake h, MsgHandshake h') -> h == h' | ||
(ClientAgency TokInit, _) -> False | ||
(ClientAgency (TokLiftClient tok1'), ClientAgency (TokLiftClient tok2')) -> case (msg1, msg2) of | ||
(MsgLift msg1', MsgLift msg2') -> | ||
messageEq (AnyMessageAndAgency (ClientAgency tok1') msg1') (AnyMessageAndAgency (ClientAgency tok2') msg2') | ||
(ClientAgency (TokLiftClient _), _) -> False | ||
(ServerAgency TokHandshake, ServerAgency TokHandshake) -> case (msg1, msg2) of | ||
(MsgAccept, MsgAccept) -> True | ||
(MsgAccept, MsgReject) -> False | ||
(MsgReject, MsgAccept) -> False | ||
(MsgReject, MsgReject) -> True | ||
(ServerAgency TokHandshake, _) -> False | ||
(ServerAgency (TokLiftServer tok1'), ServerAgency (TokLiftServer tok2')) -> case (msg1, msg2) of | ||
(MsgLift msg1', MsgLift msg2') -> | ||
messageEq (AnyMessageAndAgency (ServerAgency tok1') msg1') (AnyMessageAndAgency (ServerAgency tok2') msg2') | ||
(ServerAgency (TokLiftServer _), _) -> False | ||
|
||
instance (ShowProtocol ps, Show h) => ShowProtocol (Handshake h ps) where | ||
showsPrecMessage p tok = \case | ||
MsgHandshake h -> showParen (p >= 11) | ||
( showString "MsgHandshake" | ||
. showSpace | ||
. showsPrec 11 h | ||
) | ||
MsgAccept -> showString "MsgAccept" | ||
MsgReject -> showString "MsgReject" | ||
MsgLift msg -> showParen (p >= 11) | ||
( showString "MsgLift" | ||
. showSpace | ||
. case tok of | ||
ClientAgency (TokLiftClient tok') -> showsPrecMessage 11 (ClientAgency tok') msg | ||
ServerAgency (TokLiftServer tok') -> showsPrecMessage 11 (ServerAgency tok') msg | ||
) | ||
showsPrecServerHasAgency p = \case | ||
TokHandshake -> showString "TokHandshake" | ||
TokLiftServer tok -> showParen (p >= 11) | ||
( showString "TokLiftServer" | ||
. showSpace | ||
. showsPrecServerHasAgency 11 tok | ||
) | ||
showsPrecClientHasAgency p = \case | ||
TokInit -> showString "TokInit" | ||
TokLiftClient tok -> showParen (p >= 11) | ||
( showString "TokLiftClient" | ||
. showSpace | ||
. showsPrecClientHasAgency 11 tok | ||
) | ||
|
||
instance (MessageToJSON ps, ToJSON h) => MessageToJSON (Handshake h ps) where | ||
messageToJSON = \case | ||
ClientAgency TokInit -> \case | ||
MsgHandshake h -> object [ "handshake" .= h ] | ||
ClientAgency (TokLiftClient tok) -> \case | ||
MsgLift msg -> object [ "lift" .= messageToJSON (ClientAgency tok) msg ] | ||
ServerAgency TokHandshake -> String . \case | ||
MsgAccept -> "accept" | ||
MsgReject -> "reject" | ||
ServerAgency (TokLiftServer tok) -> \case | ||
MsgLift msg -> object [ "lift" .= messageToJSON (ServerAgency tok) msg ] |