Skip to content
Permalink
Browse files

Don't use `Serialise` in codecs

Instead take explicit encoders and decoders. This will be needed for Byron
since we don't _have_ those instances there (instead we have functions that
require additional context).

This also updates the Byron proxy.  This could mean that we don't need the
wrappers anymore, but I'll leave that to Alex.
  • Loading branch information...
edsko committed May 16, 2019
1 parent f178137 commit e0e991183b7a6efeb891240b5abd078a5c52ffb5
@@ -52,4 +52,4 @@ instance Serialise Point where
codec
:: (MonadST m)
=> Codec (ChainSync Block Point) CBOR.DeserialiseFailure m Lazy.ByteString
codec = codecChainSync
codec = codecChainSync encode encode decode decode
@@ -7,7 +7,7 @@ module Run (
runNode
) where

import Codec.Serialise (encode)
import Codec.Serialise (encode, decode)
import qualified Control.Concurrent.Async as Async
import Control.Concurrent.STM
import Control.Monad
@@ -177,11 +177,11 @@ handleSimpleNode p CLI{..} (TopologyInfo myNodeId topologyFile) = do
where
direction = Upstream (producerNodeId :==>: myNodeId)
nodeCommsCS = NodeComms {
ncCodec = codecChainSync
ncCodec = codecChainSync encode encode decode decode
, ncWithChan = NamedPipe.withPipeChannel "chain-sync" direction
}
nodeCommsBF = NodeComms {
ncCodec = codecBlockFetch
ncCodec = codecBlockFetch encode encode decode decode
, ncWithChan = NamedPipe.withPipeChannel "block-fetch" direction
}

@@ -194,10 +194,10 @@ handleSimpleNode p CLI{..} (TopologyInfo myNodeId topologyFile) = do
where
direction = Downstream (myNodeId :==>: consumerNodeId)
nodeCommsCS = NodeComms {
ncCodec = codecChainSync
ncCodec = codecChainSync encode encode decode decode
, ncWithChan = NamedPipe.withPipeChannel "chain-sync" direction
}
nodeCommsBF = NodeComms {
ncCodec = codecBlockFetch
ncCodec = codecBlockFetch encode encode decode decode
, ncWithChan = NamedPipe.withPipeChannel "block-fetch" direction
}
@@ -17,7 +17,7 @@ import Data.Map (Map)
import qualified Data.Set as Set
import Data.Set (Set)
import qualified Data.ByteString.Lazy as LBS
import Codec.Serialise (Serialise)
import Codec.Serialise (Serialise(..))

import Control.Monad.Class.MonadSTM
import Control.Monad.Class.MonadST
@@ -190,7 +190,7 @@ runFetchClient :: (MonadCatch m, MonadAsync m, MonadST m, Ord peerid,
-> m a
runFetchClient tracer registry peerid channel client =
bracketFetchClient registry peerid $ \stateVars ->
runPipelinedPeer 10 tracer codecBlockFetch channel (client stateVars)
runPipelinedPeer 10 tracer (codecBlockFetch encode encode decode decode) channel (client stateVars)

runFetchServer :: (MonadThrow m, MonadST m,
Serialise header,
@@ -201,7 +201,7 @@ runFetchServer :: (MonadThrow m, MonadST m,
-> BlockFetchServer header block m a
-> m a
runFetchServer tracer channel server =
runPeer tracer codecBlockFetch channel (blockFetchServerPeer server)
runPeer tracer (codecBlockFetch encode encode decode decode) channel (blockFetchServerPeer server)

runFetchClientAndServerAsync
:: (MonadCatch m, MonadAsync m, MonadST m, Ord peerid,
@@ -341,4 +341,3 @@ mkTestFetchedBlockHeap points = do
getTestFetchedBlocks = readTVar v,
addTestFetchedBlock = \p _b -> atomically (modifyTVar' v (Set.insert p))
}

@@ -19,40 +19,48 @@ import Data.ByteString.Lazy (ByteString)
import qualified Codec.CBOR.Encoding as CBOR (Encoding, encodeListLen, encodeWord)
import qualified Codec.CBOR.Read as CBOR
import qualified Codec.CBOR.Decoding as CBOR (Decoder, decodeListLen, decodeWord)
import Codec.Serialise.Class (Serialise)
import qualified Codec.Serialise.Class as CBOR

import Network.TypedProtocol.Codec
import Ouroboros.Network.Codec

import Ouroboros.Network.Block (HeaderHash)
import Ouroboros.Network.Block (HeaderHash, Point)
import qualified Ouroboros.Network.Block as Block
import Ouroboros.Network.Protocol.BlockFetch.Type

codecBlockFetch
:: forall header body m.
( Monad m
, MonadST m
, Serialise header
, Serialise body
, Serialise (HeaderHash header)
)
=> Codec (BlockFetch header body) CBOR.DeserialiseFailure m ByteString
codecBlockFetch = mkCodecCborLazyBS encode decode
=> (body -> CBOR.Encoding)
-> (HeaderHash header -> CBOR.Encoding)
-> (forall s. CBOR.Decoder s body)
-> (forall s. CBOR.Decoder s (HeaderHash header))
-> Codec (BlockFetch header body) CBOR.DeserialiseFailure m ByteString
codecBlockFetch encodeBody encodeHeaderHash
decodeBody decodeHeaderHash =
mkCodecCborLazyBS encode decode
where
encodePoint' :: Point header -> CBOR.Encoding
encodePoint' = Block.encodePoint $ Block.encodeChainHash encodeHeaderHash

decodePoint' :: forall s. CBOR.Decoder s (Point header)
decodePoint' = Block.decodePoint $ Block.decodeChainHash decodeHeaderHash

encode :: forall (pr :: PeerRole) st st'.
PeerHasAgency pr st
-> Message (BlockFetch header body) st st'
-> CBOR.Encoding
encode (ClientAgency TokIdle) (MsgRequestRange (ChainRange from to)) =
CBOR.encodeListLen 2 <> CBOR.encodeWord 0 <> CBOR.encode from <> CBOR.encode to
CBOR.encodeListLen 2 <> CBOR.encodeWord 0 <> encodePoint' from <> encodePoint' to
encode (ClientAgency TokIdle) MsgClientDone =
CBOR.encodeListLen 1 <> CBOR.encodeWord 1
encode (ServerAgency TokBusy) MsgStartBatch =
CBOR.encodeListLen 1 <> CBOR.encodeWord 2
encode (ServerAgency TokBusy) MsgNoBlocks =
CBOR.encodeListLen 1 <> CBOR.encodeWord 3
encode (ServerAgency TokStreaming) (MsgBlock body) =
CBOR.encodeListLen 2 <> CBOR.encodeWord 4 <> CBOR.encode body
CBOR.encodeListLen 2 <> CBOR.encodeWord 4 <> encodeBody body
encode (ServerAgency TokStreaming) MsgBatchDone =
CBOR.encodeListLen 1 <> CBOR.encodeWord 5

@@ -64,13 +72,13 @@ codecBlockFetch = mkCodecCborLazyBS encode decode
key <- CBOR.decodeWord
case (stok, key) of
(ClientAgency TokIdle, 0) -> do
from <- CBOR.decode
to <- CBOR.decode
from <- decodePoint'
to <- decodePoint'
return $ SomeMessage $ MsgRequestRange (ChainRange from to)
(ClientAgency TokIdle, 1) -> return $ SomeMessage MsgClientDone
(ServerAgency TokBusy, 2) -> return $ SomeMessage MsgStartBatch
(ServerAgency TokBusy, 3) -> return $ SomeMessage MsgNoBlocks
(ServerAgency TokStreaming, 4) -> SomeMessage . MsgBlock <$> CBOR.decode
(ServerAgency TokStreaming, 4) -> SomeMessage . MsgBlock <$> decodeBody
(ServerAgency TokStreaming, 5) -> return $ SomeMessage MsgBatchDone

-- TODO proper exceptions
@@ -8,6 +8,7 @@ module Ouroboros.Network.Protocol.BlockFetch.Test (tests) where

import Control.Monad.ST (runST)
import Data.ByteString.Lazy (ByteString)
import qualified Codec.Serialise as S

import Control.Monad.IOSim (runSimOrThrow)
import Control.Monad.Class.MonadST (MonadST)
@@ -124,7 +125,7 @@ prop_direct (TestChainAndPoints chain points) =

-- | Run a pipelined block-fetch client with a server, without going via 'Peer'.
--
--
--
--
prop_directPipelined1 :: TestChainAndPoints -> Bool
prop_directPipelined1 (TestChainAndPoints chain points) =
@@ -267,7 +268,7 @@ prop_channel :: (MonadAsync m, MonadCatch m, MonadST m)
prop_channel createChannels chain points = do
(bodies, ()) <-
runConnectedPeers
createChannels nullTracer codecBlockFetch
createChannels nullTracer (codecBlockFetch S.encode S.encode S.decode S.decode)
(blockFetchClientPeer (testClient chain points))
(blockFetchServerPeer (testServer chain))
return $ reverse bodies === concat (receivedBlockBodies chain points)
@@ -327,19 +328,19 @@ prop_codec_BlockFetch
:: AnyMessageAndAgency (BlockFetch BlockHeader BlockBody)
-> Bool
prop_codec_BlockFetch msg =
runST (prop_codecM codecBlockFetch msg)
runST (prop_codecM (codecBlockFetch S.encode S.encode S.decode S.decode) msg)

prop_codec_splits2_BlockFetch
:: AnyMessageAndAgency (BlockFetch BlockHeader BlockBody)
-> Bool
prop_codec_splits2_BlockFetch msg =
runST (prop_codec_splitsM splits2 codecBlockFetch msg)
runST (prop_codec_splitsM splits2 (codecBlockFetch S.encode S.encode S.decode S.decode) msg)

prop_codec_splits3_BlockFetch
:: AnyMessageAndAgency (BlockFetch BlockHeader BlockBody)
-> Bool
prop_codec_splits3_BlockFetch msg =
runST (prop_codec_splitsM splits3 codecBlockFetch msg)
runST (prop_codec_splitsM splits3 (codecBlockFetch S.encode S.encode S.decode S.decode) msg)


--
@@ -1,9 +1,10 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE NamedFieldPuns #-}

module Ouroboros.Network.Protocol.ChainSync.Codec
( codecChainSync
@@ -22,19 +23,33 @@ import qualified Codec.CBOR.Read as CBOR
import qualified Codec.CBOR.Decoding as CBOR (Decoder)

import Data.ByteString.Lazy (ByteString)
import Codec.CBOR.Encoding (encodeListLen, encodeWord)
import Codec.CBOR.Decoding (decodeListLen, decodeWord)
import Codec.Serialise.Class (Serialise)
import qualified Codec.Serialise.Class as CBOR

import Codec.CBOR.Encoding (
Encoding
, encodeBreak
, encodeListLen
, encodeListLenIndef
, encodeWord
)
import Codec.CBOR.Decoding (
Decoder
, decodeListLen
, decodeListLenOrIndef
, decodeSequenceLenIndef
, decodeSequenceLenN
, decodeWord
)

-- | The main CBOR 'Codec' for the 'ChainSync' protocol.
--
codecChainSync :: forall header point m.
(MonadST m, Serialise header, Serialise point)
=> Codec (ChainSync header point)
MonadST m
=> (header -> Encoding)
-> (point -> Encoding)
-> (forall s. Decoder s header)
-> (forall s. Decoder s point)
-> Codec (ChainSync header point)
CBOR.DeserialiseFailure m ByteString
codecChainSync =
codecChainSync encodeHeader encodePoint decodeHeader decodePoint =
mkCodecCborLazyBS encode decode
where
encode :: forall (pr :: PeerRole) (st :: ChainSync header point) (st' :: ChainSync header point).
@@ -49,19 +64,19 @@ codecChainSync =
encodeListLen 1 <> encodeWord 1

encode (ServerAgency TokNext{}) (MsgRollForward h p) =
encodeListLen 3 <> encodeWord 2 <> CBOR.encode h <> CBOR.encode p
encodeListLen 3 <> encodeWord 2 <> encodeHeader h <> encodePoint p

encode (ServerAgency TokNext{}) (MsgRollBackward p1 p2) =
encodeListLen 3 <> encodeWord 3 <> CBOR.encode p1 <> CBOR.encode p2
encodeListLen 3 <> encodeWord 3 <> encodePoint p1 <> encodePoint p2

encode (ClientAgency TokIdle) (MsgFindIntersect ps) =
encodeListLen 2 <> encodeWord 4 <> CBOR.encode ps
encodeListLen 2 <> encodeWord 4 <> encodeList encodePoint ps

encode (ServerAgency TokIntersect) (MsgIntersectImproved p1 p2) =
encodeListLen 3 <> encodeWord 5 <> CBOR.encode p1 <> CBOR.encode p2
encodeListLen 3 <> encodeWord 5 <> encodePoint p1 <> encodePoint p2

encode (ServerAgency TokIntersect) (MsgIntersectUnchanged p) =
encodeListLen 2 <> encodeWord 6 <> CBOR.encode p
encodeListLen 2 <> encodeWord 6 <> encodePoint p

encode (ClientAgency TokIdle) MsgDone =
encodeListLen 1 <> encodeWord 7
@@ -80,26 +95,26 @@ codecChainSync =
return (SomeMessage MsgAwaitReply)

(2, 3, ServerAgency (TokNext _)) -> do
h <- CBOR.decode
p <- CBOR.decode
h <- decodeHeader
p <- decodePoint
return (SomeMessage (MsgRollForward h p))

(3, 3, ServerAgency (TokNext _)) -> do
p1 <- CBOR.decode
p2 <- CBOR.decode
p1 <- decodePoint
p2 <- decodePoint
return (SomeMessage (MsgRollBackward p1 p2))

(4, 2, ClientAgency TokIdle) -> do
ps <- CBOR.decode
ps <- decodeList decodePoint
return (SomeMessage (MsgFindIntersect ps))

(5, 3, ServerAgency TokIntersect) -> do
p1 <- CBOR.decode
p2 <- CBOR.decode
p1 <- decodePoint
p2 <- decodePoint
return (SomeMessage (MsgIntersectImproved p1 p2))

(6, 2, ServerAgency TokIntersect) -> do
p <- CBOR.decode
p <- decodePoint
return (SomeMessage (MsgIntersectUnchanged p))

(7, 1, ClientAgency TokIdle) ->
@@ -147,3 +162,21 @@ codecChainSyncId = Codec encode decode
(ClientAgency TokIdle, Just (AnyMessage MsgDone)) -> return (DecodeDone (SomeMessage MsgDone) Nothing)

(_, _) -> return $ DecodeFail (CodecFailure "codecChainSync: no matching message")

{-------------------------------------------------------------------------------
Auxiliary
This is adapted from 'defaultEncodeList' and 'defaultDecodeList' from
@serialise@; they should relaly be exported.
-------------------------------------------------------------------------------}

encodeList :: (a -> Encoding) -> [a] -> Encoding
encodeList _ [] = encodeListLen 0
encodeList e xs = encodeListLenIndef <> Prelude.foldr (\x r -> e x <> r) encodeBreak xs

decodeList :: Decoder s a -> Decoder s [a]
decodeList d = do
mn <- decodeListLenOrIndef
case mn of
Nothing -> decodeSequenceLenIndef (flip (:)) [] reverse d
Just n -> decodeSequenceLenN (flip (:)) [] reverse n d
Oops, something went wrong.

0 comments on commit e0e9911

Please sign in to comment.
You can’t perform that action at this time.