Skip to content

Commit

Permalink
Add cookie to keep-alive protocol
Browse files Browse the repository at this point in the history
Add a random cookie to keep-alive messages so that a rouge responder
can't send a MsgKeepAliveResponse before it has read the MsgKeepAlive.
  • Loading branch information
karknu committed Aug 12, 2020
1 parent fd39a33 commit c7c455f
Show file tree
Hide file tree
Showing 14 changed files with 92 additions and 52 deletions.
Expand Up @@ -54,6 +54,7 @@ import qualified Data.Set as Set
import qualified Data.Typeable as Typeable
import Data.Void (Void)
import GHC.Stack
import System.Random (mkStdGen)

import qualified Ouroboros.Network.AnchoredFragment as AF
import Ouroboros.Network.BlockFetch (BlockFetchConfiguration (..))
Expand Down Expand Up @@ -911,6 +912,8 @@ runThreadNetwork systemTime ThreadNetworkArgs
(ledgerState <$>
ChainDB.getCurrentLedger chainDB)

let kaRng = case seed of
Seed s -> mkStdGen s
let nodeArgs = NodeArgs
{ tracers
, registry
Expand All @@ -922,6 +925,7 @@ runThreadNetwork systemTime ThreadNetworkArgs
, blockFetchSize = nodeBlockFetchSize
, maxTxCapacityOverride = NoMaxTxCapacityOverride
, mempoolCapacityOverride = NoMempoolCapacityBytesOverride
, keepAliveRng = kaRng
, miniProtocolParameters = MiniProtocolParameters {
chainSyncPipeliningHighMark = 4,
chainSyncPipeliningLowMark = 2,
Expand Down
Expand Up @@ -162,7 +162,7 @@ mkHandlers
-> NodeKernel m remotePeer localPeer blk
-> Handlers m remotePeer blk
mkHandlers
NodeArgs {miniProtocolParameters}
NodeArgs {keepAliveRng, miniProtocolParameters}
NodeKernel {getChainDB, getMempool, getTopLevelConfig, getTracers = tracers} =
Handlers {
hChainSyncClient =
Expand Down Expand Up @@ -197,7 +197,7 @@ mkHandlers
(getMempoolReader getMempool)
(getMempoolWriter getMempool)
version
, hKeepAliveClient = \_version -> keepAliveClient (Node.keepAliveClientTracer tracers)
, hKeepAliveClient = \_version -> keepAliveClient (Node.keepAliveClientTracer tracers) keepAliveRng
, hKeepAliveServer = \_version _peer -> keepAliveServer
}

Expand Down
4 changes: 3 additions & 1 deletion ouroboros-consensus/src/Ouroboros/Consensus/Node.hs
Expand Up @@ -40,7 +40,7 @@ import Control.Tracer (Tracer, contramap)
import Data.ByteString.Lazy (ByteString)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import System.Random (randomIO, randomRIO)
import System.Random (newStdGen, randomIO, randomRIO)

import Ouroboros.Network.BlockFetch (BlockFetchConfiguration (..))
import Ouroboros.Network.Diffusion
Expand Down Expand Up @@ -427,6 +427,7 @@ mkNodeArgs
mkNodeArgs registry cfg mInitBlockForging tracers btime chainDB = do
mBlockForging <- sequence mInitBlockForging
bfsalt <- randomIO -- Per-node specific value used by blockfetch when ranking peers.
keepAliveRng <- newStdGen
return NodeArgs
{ tracers
, registry
Expand All @@ -440,6 +441,7 @@ mkNodeArgs registry cfg mInitBlockForging tracers btime chainDB = do
, mempoolCapacityOverride = NoMempoolCapacityBytesOverride
, miniProtocolParameters = defaultMiniProtocolParameters
, blockFetchConfiguration = defaultBlockFetchConfiguration bfsalt
, keepAliveRng = keepAliveRng
}
where
defaultBlockFetchConfiguration :: Int -> BlockFetchConfiguration
Expand Down
2 changes: 2 additions & 0 deletions ouroboros-consensus/src/Ouroboros/Consensus/NodeKernel.hs
Expand Up @@ -29,6 +29,7 @@ import Data.Map.Strict (Map)
import Data.Maybe (isJust)
import Data.Proxy
import Data.Word (Word32)
import System.Random (StdGen)

import Control.Tracer

Expand Down Expand Up @@ -126,6 +127,7 @@ data NodeArgs m remotePeer localPeer blk = NodeArgs {
, mempoolCapacityOverride :: MempoolCapacityBytesOverride
, miniProtocolParameters :: MiniProtocolParameters
, blockFetchConfiguration :: BlockFetchConfiguration
, keepAliveRng :: StdGen
}

initNodeKernel
Expand Down
2 changes: 2 additions & 0 deletions ouroboros-network/ouroboros-network.cabal
Expand Up @@ -149,6 +149,7 @@ library
network >=3.1 && <3.2,
psqueues >=0.2.3 && <0.3,
serialise >=0.2 && <0.3,
random,
stm >=2.4 && <2.6,
time >=1.6 && <1.10,

Expand Down Expand Up @@ -274,6 +275,7 @@ test-suite test-network
pipes,
process,
psqueues,
random,
serialise,
splitmix,
stm,
Expand Down
Expand Up @@ -13,7 +13,7 @@ direct KeepAliveServer { recvMsgDone }
(SendMsgDone mdone) =
(,) <$> recvMsgDone <*> mdone
direct KeepAliveServer { recvMsgKeepAlive }
(SendMsgKeepAlive mclient) = do
(SendMsgKeepAlive _cookie mclient) = do
server <- recvMsgKeepAlive
client <- mclient
direct server client
Expand Up @@ -4,6 +4,7 @@ module Ouroboros.Network.Protocol.KeepAlive.Examples where

import Ouroboros.Network.Protocol.KeepAlive.Server
import Ouroboros.Network.Protocol.KeepAlive.Client
import Ouroboros.Network.Protocol.KeepAlive.Type


-- | A client which applies a function whenever it receives
Expand All @@ -23,7 +24,7 @@ keepAliveClientApply f = go
= SendMsgDone (pure acc)

| otherwise
= SendMsgKeepAlive $
= SendMsgKeepAlive (Cookie $ fromIntegral n) $
pure $ go (f acc) (pred n)


Expand Down
Expand Up @@ -15,7 +15,6 @@ import Control.Monad.IOSim (runSimOrThrow)
import Control.Tracer (nullTracer)

import qualified Codec.CBOR.Read as CBOR
import Data.Functor.Identity (Identity (..))
import Data.ByteString.Lazy (ByteString)
import qualified Data.ByteString.Lazy as BL

Expand Down Expand Up @@ -67,7 +66,7 @@ tests = testGroup "Ouroboros.Network.Protocol.KeepAlive"

prop_direct :: (Int -> Int) -> NonNegative Int -> Property
prop_direct f (NonNegative n) =
runIdentity
runSimOrThrow
(direct
keepAliveServerCount
(keepAliveClientApply f 0 n))
Expand All @@ -82,7 +81,7 @@ prop_connect :: (Int -> Int)
-> NonNegative Int
-> Bool
prop_connect f (NonNegative n) =
case runIdentity
case runSimOrThrow
(connect
(keepAliveServerPeer keepAliveServerCount)
(keepAliveClientPeer $ keepAliveClientApply f 0 n))
Expand Down Expand Up @@ -130,19 +129,21 @@ prop_channel_IO f (NonNegative n) =
--

instance Arbitrary (AnyMessageAndAgency KeepAlive) where
arbitrary = oneof
[ pure $ AnyMessageAndAgency (ClientAgency TokClient) MsgKeepAlive
, pure $ AnyMessageAndAgency (ServerAgency TokServer) MsgKeepAliveResponse
, pure $ AnyMessageAndAgency (ClientAgency TokClient) MsgDone
]
arbitrary = do
c <- arbitrary
oneof
[ pure $ AnyMessageAndAgency (ClientAgency TokClient) (MsgKeepAlive $ Cookie c)
, pure $ AnyMessageAndAgency (ServerAgency TokServer) (MsgKeepAliveResponse $ Cookie c)
, pure $ AnyMessageAndAgency (ClientAgency TokClient) MsgDone
]

instance Show (AnyMessageAndAgency KeepAlive) where
show (AnyMessageAndAgency _ msg) = show msg

instance Eq (AnyMessage KeepAlive) where
AnyMessage MsgKeepAlive == AnyMessage MsgKeepAlive = True
AnyMessage (MsgKeepAliveResponse) == AnyMessage (MsgKeepAliveResponse) = True
AnyMessage MsgDone == AnyMessage MsgDone = True
AnyMessage (MsgKeepAlive cookieA) == AnyMessage (MsgKeepAlive cookieB) = cookieA == cookieB
AnyMessage (MsgKeepAliveResponse cookieA) == AnyMessage (MsgKeepAliveResponse cookieB) = cookieA == cookieB
AnyMessage MsgDone == AnyMessage MsgDone = True
_ == _ = False

prop_codec :: AnyMessageAndAgency KeepAlive -> Bool
Expand Down
16 changes: 11 additions & 5 deletions ouroboros-network/src/Ouroboros/Network/KeepAlive.hs
Expand Up @@ -20,11 +20,13 @@ import Control.Monad.Class.MonadTimer
import Control.Tracer (Tracer, traceWith)
import Data.Maybe (fromJust)
import qualified Data.Map.Strict as M
import System.Random (StdGen, random)

import Ouroboros.Network.Mux (RunOrStop (..), ScheduledStop)
import Ouroboros.Network.DeltaQ
import Ouroboros.Network.Protocol.KeepAlive.Client
import Ouroboros.Network.Protocol.KeepAlive.Server
import Ouroboros.Network.Protocol.KeepAlive.Type


newtype KeepAliveInterval = KeepAliveInterval { keepAliveInterval :: DiffTime }
Expand All @@ -44,13 +46,15 @@ keepAliveClient
, Ord peer
)
=> Tracer m (TraceKeepAliveClient peer)
-> StdGen
-> ScheduledStop m
-> peer
-> (StrictTVar m (M.Map peer PeerGSV))
-> KeepAliveInterval
-> KeepAliveClient m ()
keepAliveClient tracer shouldStopSTM peer dqCtx KeepAliveInterval { keepAliveInterval } =
SendMsgKeepAlive (go Nothing)
keepAliveClient tracer inRng shouldStopSTM peer dqCtx KeepAliveInterval { keepAliveInterval } =
let (cookie, rng) = random inRng in
SendMsgKeepAlive (Cookie cookie) (go rng Nothing)
where
payloadSize = 2

Expand All @@ -66,8 +70,8 @@ keepAliveClient tracer shouldStopSTM peer dqCtx KeepAliveInterval { keepAliveInt
then return Run
else retry

go :: Maybe Time -> m (KeepAliveClient m ())
go startTime_m = do
go :: StdGen -> Maybe Time -> m (KeepAliveClient m ())
go rng startTime_m = do
endTime <- getMonotonicTime
case startTime_m of
Just startTime -> do
Expand All @@ -91,7 +95,9 @@ keepAliveClient tracer shouldStopSTM peer dqCtx KeepAliveInterval { keepAliveInt
decision <- atomically (decisionSTM delayVar)
now <- getMonotonicTime
case decision of
Run -> pure (SendMsgKeepAlive $ go $ Just now)
Run ->
let (cookie, rng') = random rng in
pure (SendMsgKeepAlive (Cookie cookie) $ go rng' $ Just now)
Stop -> pure (SendMsgDone (pure ()))


Expand Down
Expand Up @@ -6,13 +6,15 @@ module Ouroboros.Network.Protocol.KeepAlive.Client (
keepAliveClientPeer
) where

import Control.Monad.Class.MonadThrow
import Network.TypedProtocol.Core
import Ouroboros.Network.Protocol.KeepAlive.Type


data KeepAliveClient m a where
SendMsgKeepAlive
:: (m (KeepAliveClient m a))
:: Cookie
-> (m (KeepAliveClient m a))
-> KeepAliveClient m a

SendMsgDone
Expand All @@ -24,15 +26,16 @@ data KeepAliveClient m a where
-- 'KeepAlive' protocol.
--
keepAliveClientPeer
:: Functor m
:: MonadThrow m
=> KeepAliveClient m a
-> Peer KeepAlive AsClient StClient m a

keepAliveClientPeer (SendMsgDone mresult) =
Yield (ClientAgency TokClient) MsgDone $
Effect (Done TokDone <$> mresult)

keepAliveClientPeer (SendMsgKeepAlive next) =
Yield (ClientAgency TokClient) MsgKeepAlive $
Await (ServerAgency TokServer) $ \MsgKeepAliveResponse ->
Effect $ keepAliveClientPeer <$> next
keepAliveClientPeer (SendMsgKeepAlive cookieReq next) =
Yield (ClientAgency TokClient) (MsgKeepAlive cookieReq) $
Await (ServerAgency TokServer) $ \(MsgKeepAliveResponse cookieRsp) ->
if cookieReq == cookieRsp then Effect $ keepAliveClientPeer <$> next
else Effect $ throwM $ KeepAliveCookieMissmatch cookieReq cookieRsp
Expand Up @@ -18,9 +18,9 @@ import Control.Monad.Class.MonadTime (DiffTime)

import Data.ByteString.Lazy (ByteString)

import qualified Codec.CBOR.Encoding as CBOR (Encoding, encodeWord)
import qualified Codec.CBOR.Encoding as CBOR (Encoding, encodeWord, encodeWord16)
import qualified Codec.CBOR.Read as CBOR
import qualified Codec.CBOR.Decoding as CBOR (Decoder, decodeWord)
import qualified Codec.CBOR.Decoding as CBOR (Decoder, decodeWord, decodeWord16)

import Network.TypedProtocol.Core

Expand All @@ -40,18 +40,22 @@ codecKeepAlive = mkCodecCborLazyBS encodeMsg decodeMsg
PeerHasAgency pr st
-> Message KeepAlive st st'
-> CBOR.Encoding
encodeMsg (ClientAgency TokClient) MsgKeepAlive = CBOR.encodeWord 0
encodeMsg (ServerAgency TokServer) MsgKeepAliveResponse = CBOR.encodeWord 1
encodeMsg (ClientAgency TokClient) MsgDone = CBOR.encodeWord 2
encodeMsg (ClientAgency TokClient) (MsgKeepAlive (Cookie c)) = CBOR.encodeWord 0 <> CBOR.encodeWord16 c
encodeMsg (ServerAgency TokServer) (MsgKeepAliveResponse (Cookie c)) = CBOR.encodeWord 1 <> CBOR.encodeWord16 c
encodeMsg (ClientAgency TokClient) MsgDone = CBOR.encodeWord 2

decodeMsg :: forall (pr :: PeerRole) s (st :: KeepAlive).
PeerHasAgency pr st
-> CBOR.Decoder s (SomeMessage st)
decodeMsg stok = do
key <- CBOR.decodeWord
case (stok, key) of
(ClientAgency TokClient, 0) -> pure (SomeMessage MsgKeepAlive)
(ServerAgency TokServer, 1) -> pure (SomeMessage MsgKeepAliveResponse)
(ClientAgency TokClient, 0) -> do
cookie <- CBOR.decodeWord16
return (SomeMessage $ MsgKeepAlive $ Cookie cookie)
(ServerAgency TokServer, 1) -> do
cookie <- CBOR.decodeWord16
return (SomeMessage $ MsgKeepAliveResponse $ Cookie cookie)
(ClientAgency TokClient, 2) -> pure (SomeMessage MsgDone)

(ClientAgency TokClient, _) ->
Expand Down Expand Up @@ -97,9 +101,9 @@ codecKeepAliveId = Codec encodeMsg decodeMsg
CodecFailure m (SomeMessage st))
decodeMsg stok = return $ DecodePartial $ \bytes -> return $
case (stok, bytes) of
(ClientAgency TokClient, Just (AnyMessage msg@(MsgKeepAlive)))
(ClientAgency TokClient, Just (AnyMessage msg@(MsgKeepAlive {})))
-> DecodeDone (SomeMessage msg) Nothing
(ServerAgency TokServer, Just (AnyMessage msg@(MsgKeepAliveResponse)))
(ServerAgency TokServer, Just (AnyMessage msg@(MsgKeepAliveResponse {})))
-> DecodeDone (SomeMessage msg) Nothing
(ClientAgency TokClient, Just (AnyMessage msg@(MsgDone)))
-> DecodeDone (SomeMessage msg) Nothing
Expand Down
Expand Up @@ -27,10 +27,10 @@ keepAliveServerPeer KeepAliveServer { recvMsgKeepAlive, recvMsgDone } =
case msg of
MsgDone -> Effect $ Done TokDone <$> recvMsgDone

MsgKeepAlive ->
MsgKeepAlive cookie ->
Effect $
fmap (\server ->
Yield (ServerAgency TokServer)
MsgKeepAliveResponse
(MsgKeepAliveResponse cookie)
(keepAliveServerPeer server))
recvMsgKeepAlive
21 changes: 16 additions & 5 deletions ouroboros-network/src/Ouroboros/Network/Protocol/KeepAlive/Type.hs
Expand Up @@ -23,9 +23,18 @@
--
module Ouroboros.Network.Protocol.KeepAlive.Type where

import Control.Monad.Class.MonadThrow (Exception)
import Data.Word (Word16)
import Network.TypedProtocol.Core
import Ouroboros.Network.Util.ShowProxy (ShowProxy (..))

-- | A 16bit value used to match responses to requests.
newtype Cookie = Cookie {unCookie :: Word16 } deriving (Eq, Show)

data KeepAliveProtocolFailure =
KeepAliveCookieMissmatch Cookie Cookie deriving (Eq, Show)

instance Exception KeepAliveProtocolFailure

-- | A kind to identify our protocol, and the types of the states in the state
-- transition diagram of the protocol.
Expand Down Expand Up @@ -57,12 +66,14 @@ instance Protocol KeepAlive where
-- | Send a keep alive message.
--
MsgKeepAlive
:: Message KeepAlive StClient StServer
:: Cookie
-> Message KeepAlive StClient StServer

-- | Keep alive response.
--
MsgKeepAliveResponse
:: Message KeepAlive StServer StClient
:: Cookie
-> Message KeepAlive StServer StClient

-- | The client side terminating message of the protocol.
--
Expand All @@ -84,9 +95,9 @@ instance Protocol KeepAlive where


instance Show (Message KeepAlive from to) where
show MsgKeepAlive = "MsgKeepAlive"
show MsgKeepAliveResponse = "MsgKeepAliveResponse"
show MsgDone = "MsgDone"
show (MsgKeepAlive cookie) = "MsgKeepAlive " ++ show cookie
show (MsgKeepAliveResponse cookie) = "MsgKeepAliveResponse " ++ show cookie
show MsgDone = "MsgDone"

instance Show (ClientHasAgency (st :: KeepAlive)) where
show TokClient = "TokClient"
Expand Down

0 comments on commit c7c455f

Please sign in to comment.