Skip to content

Commit

Permalink
Add cookie to keep-alive protocol
Browse files Browse the repository at this point in the history
Add a random cookie to keep-alive messages so that a rouge responder
can't send a MsgKeepAliveResponse before it has read the MsgKeepAlive.
  • Loading branch information
karknu committed Aug 11, 2020
1 parent c743526 commit 92c32a5
Show file tree
Hide file tree
Showing 14 changed files with 96 additions and 53 deletions.
Expand Up @@ -54,6 +54,7 @@ import qualified Data.Set as Set
import qualified Data.Typeable as Typeable
import Data.Void (Void)
import GHC.Stack
import System.Random (mkStdGen)

import qualified Ouroboros.Network.AnchoredFragment as AF
import Ouroboros.Network.BlockFetch (BlockFetchConfiguration (..))
Expand Down Expand Up @@ -911,6 +912,8 @@ runThreadNetwork systemTime ThreadNetworkArgs
(ledgerState <$>
ChainDB.getCurrentLedger chainDB)

let kaRng = case seed of
Seed s -> mkStdGen s
let nodeArgs = NodeArgs
{ tracers
, registry
Expand All @@ -922,6 +925,7 @@ runThreadNetwork systemTime ThreadNetworkArgs
, blockFetchSize = nodeBlockFetchSize
, maxTxCapacityOverride = NoMaxTxCapacityOverride
, mempoolCapacityOverride = NoMempoolCapacityBytesOverride
, keepAliveRng = kaRng
, miniProtocolParameters = MiniProtocolParameters {
chainSyncPipeliningHighMark = 4,
chainSyncPipeliningLowMark = 2,
Expand Down
Expand Up @@ -162,7 +162,7 @@ mkHandlers
-> NodeKernel m remotePeer localPeer blk
-> Handlers m remotePeer blk
mkHandlers
NodeArgs {miniProtocolParameters}
NodeArgs {keepAliveRng, miniProtocolParameters}
NodeKernel {getChainDB, getMempool, getTopLevelConfig, getTracers = tracers} =
Handlers {
hChainSyncClient =
Expand Down Expand Up @@ -197,7 +197,7 @@ mkHandlers
(getMempoolReader getMempool)
(getMempoolWriter getMempool)
version
, hKeepAliveClient = \_version -> keepAliveClient (Node.keepAliveClientTracer tracers)
, hKeepAliveClient = \_version -> keepAliveClient (Node.keepAliveClientTracer tracers) keepAliveRng
, hKeepAliveServer = \_version _peer -> keepAliveServer
}

Expand Down
9 changes: 7 additions & 2 deletions ouroboros-consensus/src/Ouroboros/Consensus/Node.hs
Expand Up @@ -40,7 +40,7 @@ import Control.Tracer (Tracer, contramap)
import Data.ByteString.Lazy (ByteString)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import System.Random (randomIO, randomRIO)
import System.Random (StdGen, newStdGen, randomIO, randomRIO)

import Ouroboros.Network.BlockFetch (BlockFetchConfiguration (..))
import Ouroboros.Network.Diffusion
Expand Down Expand Up @@ -206,6 +206,7 @@ run runargs@RunNodeArgs{..} =
(ledgerState <$>
ChainDB.getCurrentLedger chainDB)
bfsalt <- randomIO -- Per-node specific value used by blockfetch when ranking peers.
keepAliveRng <- newStdGen
nodeArgs <- nodeArgsEnforceInvariants . rnCustomiseNodeArgs <$>
mkNodeArgs
registry
Expand All @@ -215,9 +216,11 @@ run runargs@RunNodeArgs{..} =
btime
chainDB
bfsalt
keepAliveRng
nodeKernel <- initNodeKernel nodeArgs
rnNodeKernelHook registry nodeKernel


let ntnApps = mkNodeToNodeApps nodeArgs nodeKernel
ntcApps = mkNodeToClientApps nodeArgs nodeKernel
diffusionApplications = mkDiffusionApplications
Expand Down Expand Up @@ -426,8 +429,9 @@ mkNodeArgs
-> BlockchainTime IO
-> ChainDB IO blk
-> Int
-> StdGen
-> IO (NodeArgs IO RemoteConnectionId LocalConnectionId blk)
mkNodeArgs registry cfg mInitBlockForging tracers btime chainDB bfsalt = do
mkNodeArgs registry cfg mInitBlockForging tracers btime chainDB bfsalt keepAliveRng = do
mBlockForging <- sequence mInitBlockForging
return NodeArgs
{ tracers
Expand All @@ -442,6 +446,7 @@ mkNodeArgs registry cfg mInitBlockForging tracers btime chainDB bfsalt = do
, mempoolCapacityOverride = NoMempoolCapacityBytesOverride
, miniProtocolParameters = defaultMiniProtocolParameters
, blockFetchConfiguration = defaultBlockFetchConfiguration
, keepAliveRng = keepAliveRng
}
where
defaultBlockFetchConfiguration :: BlockFetchConfiguration
Expand Down
2 changes: 2 additions & 0 deletions ouroboros-consensus/src/Ouroboros/Consensus/NodeKernel.hs
Expand Up @@ -29,6 +29,7 @@ import Data.Map.Strict (Map)
import Data.Maybe (isJust)
import Data.Proxy
import Data.Word (Word32)
import System.Random (StdGen)

import Control.Tracer

Expand Down Expand Up @@ -126,6 +127,7 @@ data NodeArgs m remotePeer localPeer blk = NodeArgs {
, mempoolCapacityOverride :: MempoolCapacityBytesOverride
, miniProtocolParameters :: MiniProtocolParameters
, blockFetchConfiguration :: BlockFetchConfiguration
, keepAliveRng :: StdGen
}

initNodeKernel
Expand Down
2 changes: 2 additions & 0 deletions ouroboros-network/ouroboros-network.cabal
Expand Up @@ -149,6 +149,7 @@ library
network >=3.1 && <3.2,
psqueues >=0.2.3 && <0.3,
serialise >=0.2 && <0.3,
random,
stm >=2.4 && <2.6,
time >=1.6 && <1.10,

Expand Down Expand Up @@ -274,6 +275,7 @@ test-suite test-network
pipes,
process,
psqueues,
random,
serialise,
splitmix,
stm,
Expand Down
Expand Up @@ -13,7 +13,7 @@ direct KeepAliveServer { recvMsgDone }
(SendMsgDone mdone) =
(,) <$> recvMsgDone <*> mdone
direct KeepAliveServer { recvMsgKeepAlive }
(SendMsgKeepAlive mclient) = do
(SendMsgKeepAlive _cookie mclient) = do
server <- recvMsgKeepAlive
client <- mclient
direct server client
Expand Up @@ -4,6 +4,7 @@ module Ouroboros.Network.Protocol.KeepAlive.Examples where

import Ouroboros.Network.Protocol.KeepAlive.Server
import Ouroboros.Network.Protocol.KeepAlive.Client
import Ouroboros.Network.Protocol.KeepAlive.Type


-- | A client which applies a function whenever it receives
Expand All @@ -23,7 +24,7 @@ keepAliveClientApply f = go
= SendMsgDone (pure acc)

| otherwise
= SendMsgKeepAlive $
= SendMsgKeepAlive (Cookie $ fromIntegral n) $
pure $ go (f acc) (pred n)


Expand Down
Expand Up @@ -15,7 +15,6 @@ import Control.Monad.IOSim (runSimOrThrow)
import Control.Tracer (nullTracer)

import qualified Codec.CBOR.Read as CBOR
import Data.Functor.Identity (Identity (..))
import Data.ByteString.Lazy (ByteString)
import qualified Data.ByteString.Lazy as BL

Expand Down Expand Up @@ -67,7 +66,7 @@ tests = testGroup "Ouroboros.Network.Protocol.KeepAlive"

prop_direct :: (Int -> Int) -> NonNegative Int -> Property
prop_direct f (NonNegative n) =
runIdentity
runSimOrThrow
(direct
keepAliveServerCount
(keepAliveClientApply f 0 n))
Expand All @@ -82,7 +81,7 @@ prop_connect :: (Int -> Int)
-> NonNegative Int
-> Bool
prop_connect f (NonNegative n) =
case runIdentity
case runSimOrThrow
(connect
(keepAliveServerPeer keepAliveServerCount)
(keepAliveClientPeer $ keepAliveClientApply f 0 n))
Expand Down Expand Up @@ -130,19 +129,21 @@ prop_channel_IO f (NonNegative n) =
--

instance Arbitrary (AnyMessageAndAgency KeepAlive) where
arbitrary = oneof
[ pure $ AnyMessageAndAgency (ClientAgency TokClient) MsgKeepAlive
, pure $ AnyMessageAndAgency (ServerAgency TokServer) MsgKeepAliveResponse
, pure $ AnyMessageAndAgency (ClientAgency TokClient) MsgDone
]
arbitrary = do
c <- arbitrary
oneof
[ pure $ AnyMessageAndAgency (ClientAgency TokClient) (MsgKeepAlive $ Cookie c)
, pure $ AnyMessageAndAgency (ServerAgency TokServer) (MsgKeepAliveResponse $ Cookie c)
, pure $ AnyMessageAndAgency (ClientAgency TokClient) MsgDone
]

instance Show (AnyMessageAndAgency KeepAlive) where
show (AnyMessageAndAgency _ msg) = show msg

instance Eq (AnyMessage KeepAlive) where
AnyMessage MsgKeepAlive == AnyMessage MsgKeepAlive = True
AnyMessage (MsgKeepAliveResponse) == AnyMessage (MsgKeepAliveResponse) = True
AnyMessage MsgDone == AnyMessage MsgDone = True
AnyMessage (MsgKeepAlive cookieA) == AnyMessage (MsgKeepAlive cookieB) = cookieA == cookieB
AnyMessage (MsgKeepAliveResponse cookieA) == AnyMessage (MsgKeepAliveResponse cookieB) = cookieA == cookieB
AnyMessage MsgDone == AnyMessage MsgDone = True
_ == _ = False

prop_codec :: AnyMessageAndAgency KeepAlive -> Bool
Expand Down
16 changes: 11 additions & 5 deletions ouroboros-network/src/Ouroboros/Network/KeepAlive.hs
Expand Up @@ -20,11 +20,13 @@ import Control.Monad.Class.MonadTimer
import Control.Tracer (Tracer, traceWith)
import Data.Maybe (fromJust)
import qualified Data.Map.Strict as M
import System.Random (StdGen, random)

import Ouroboros.Network.Mux (RunOrStop (..), ScheduledStop)
import Ouroboros.Network.DeltaQ
import Ouroboros.Network.Protocol.KeepAlive.Client
import Ouroboros.Network.Protocol.KeepAlive.Server
import Ouroboros.Network.Protocol.KeepAlive.Type


newtype KeepAliveInterval = KeepAliveInterval { keepAliveInterval :: DiffTime }
Expand All @@ -44,13 +46,15 @@ keepAliveClient
, Ord peer
)
=> Tracer m (TraceKeepAliveClient peer)
-> StdGen
-> ScheduledStop m
-> peer
-> (StrictTVar m (M.Map peer PeerGSV))
-> KeepAliveInterval
-> KeepAliveClient m ()
keepAliveClient tracer shouldStopSTM peer dqCtx KeepAliveInterval { keepAliveInterval } =
SendMsgKeepAlive (go Nothing)
keepAliveClient tracer inRng shouldStopSTM peer dqCtx KeepAliveInterval { keepAliveInterval } =
let (cookie, rng) = random inRng in
SendMsgKeepAlive (Cookie cookie) (go rng Nothing)
where
payloadSize = 2

Expand All @@ -66,8 +70,8 @@ keepAliveClient tracer shouldStopSTM peer dqCtx KeepAliveInterval { keepAliveInt
then return Run
else retry

go :: Maybe Time -> m (KeepAliveClient m ())
go startTime_m = do
go :: StdGen -> Maybe Time -> m (KeepAliveClient m ())
go rng startTime_m = do
endTime <- getMonotonicTime
case startTime_m of
Just startTime -> do
Expand All @@ -91,7 +95,9 @@ keepAliveClient tracer shouldStopSTM peer dqCtx KeepAliveInterval { keepAliveInt
decision <- atomically (decisionSTM delayVar)
now <- getMonotonicTime
case decision of
Run -> pure (SendMsgKeepAlive $ go $ Just now)
Run ->
let (cookie, rng') = random rng in
pure (SendMsgKeepAlive (Cookie cookie) $ go rng' $ Just now)
Stop -> pure (SendMsgDone (pure ()))


Expand Down
Expand Up @@ -6,13 +6,15 @@ module Ouroboros.Network.Protocol.KeepAlive.Client (
keepAliveClientPeer
) where

import Control.Monad.Class.MonadThrow
import Network.TypedProtocol.Core
import Ouroboros.Network.Protocol.KeepAlive.Type


data KeepAliveClient m a where
SendMsgKeepAlive
:: (m (KeepAliveClient m a))
:: Cookie
-> (m (KeepAliveClient m a))
-> KeepAliveClient m a

SendMsgDone
Expand All @@ -24,15 +26,16 @@ data KeepAliveClient m a where
-- 'KeepAlive' protocol.
--
keepAliveClientPeer
:: Functor m
:: MonadThrow m
=> KeepAliveClient m a
-> Peer KeepAlive AsClient StClient m a

keepAliveClientPeer (SendMsgDone mresult) =
Yield (ClientAgency TokClient) MsgDone $
Effect (Done TokDone <$> mresult)

keepAliveClientPeer (SendMsgKeepAlive next) =
Yield (ClientAgency TokClient) MsgKeepAlive $
Await (ServerAgency TokServer) $ \MsgKeepAliveResponse ->
Effect $ keepAliveClientPeer <$> next
keepAliveClientPeer (SendMsgKeepAlive cookieReq next) =
Yield (ClientAgency TokClient) (MsgKeepAlive cookieReq) $
Await (ServerAgency TokServer) $ \(MsgKeepAliveResponse cookieRsp) ->
if cookieReq == cookieRsp then Effect $ keepAliveClientPeer <$> next
else Effect $ throwM $ KeepAliveCookieMissmatch cookieReq cookieRsp
Expand Up @@ -18,9 +18,9 @@ import Control.Monad.Class.MonadTime (DiffTime)

import Data.ByteString.Lazy (ByteString)

import qualified Codec.CBOR.Encoding as CBOR (Encoding, encodeWord)
import qualified Codec.CBOR.Encoding as CBOR (Encoding, encodeWord, encodeWord16)
import qualified Codec.CBOR.Read as CBOR
import qualified Codec.CBOR.Decoding as CBOR (Decoder, decodeWord)
import qualified Codec.CBOR.Decoding as CBOR (Decoder, decodeWord, decodeWord16)

import Network.TypedProtocol.Core

Expand All @@ -40,18 +40,22 @@ codecKeepAlive = mkCodecCborLazyBS encodeMsg decodeMsg
PeerHasAgency pr st
-> Message KeepAlive st st'
-> CBOR.Encoding
encodeMsg (ClientAgency TokClient) MsgKeepAlive = CBOR.encodeWord 0
encodeMsg (ServerAgency TokServer) MsgKeepAliveResponse = CBOR.encodeWord 1
encodeMsg (ClientAgency TokClient) MsgDone = CBOR.encodeWord 2
encodeMsg (ClientAgency TokClient) (MsgKeepAlive (Cookie c)) = CBOR.encodeWord 0 <> CBOR.encodeWord16 c
encodeMsg (ServerAgency TokServer) (MsgKeepAliveResponse (Cookie c)) = CBOR.encodeWord 1 <> CBOR.encodeWord16 c
encodeMsg (ClientAgency TokClient) MsgDone = CBOR.encodeWord 2

decodeMsg :: forall (pr :: PeerRole) s (st :: KeepAlive).
PeerHasAgency pr st
-> CBOR.Decoder s (SomeMessage st)
decodeMsg stok = do
key <- CBOR.decodeWord
case (stok, key) of
(ClientAgency TokClient, 0) -> pure (SomeMessage MsgKeepAlive)
(ServerAgency TokServer, 1) -> pure (SomeMessage MsgKeepAliveResponse)
(ClientAgency TokClient, 0) -> do
cookie <- CBOR.decodeWord16
return (SomeMessage $ MsgKeepAlive $ Cookie cookie)
(ServerAgency TokServer, 1) -> do
cookie <- CBOR.decodeWord16
return (SomeMessage $ MsgKeepAliveResponse $ Cookie cookie)
(ClientAgency TokClient, 2) -> pure (SomeMessage MsgDone)

(ClientAgency TokClient, _) ->
Expand Down Expand Up @@ -97,9 +101,9 @@ codecKeepAliveId = Codec encodeMsg decodeMsg
CodecFailure m (SomeMessage st))
decodeMsg stok = return $ DecodePartial $ \bytes -> return $
case (stok, bytes) of
(ClientAgency TokClient, Just (AnyMessage msg@(MsgKeepAlive)))
(ClientAgency TokClient, Just (AnyMessage msg@(MsgKeepAlive {})))
-> DecodeDone (SomeMessage msg) Nothing
(ServerAgency TokServer, Just (AnyMessage msg@(MsgKeepAliveResponse)))
(ServerAgency TokServer, Just (AnyMessage msg@(MsgKeepAliveResponse {})))
-> DecodeDone (SomeMessage msg) Nothing
(ClientAgency TokClient, Just (AnyMessage msg@(MsgDone)))
-> DecodeDone (SomeMessage msg) Nothing
Expand Down
Expand Up @@ -27,10 +27,10 @@ keepAliveServerPeer KeepAliveServer { recvMsgKeepAlive, recvMsgDone } =
case msg of
MsgDone -> Effect $ Done TokDone <$> recvMsgDone

MsgKeepAlive ->
MsgKeepAlive cookie ->
Effect $
fmap (\server ->
Yield (ServerAgency TokServer)
MsgKeepAliveResponse
(MsgKeepAliveResponse cookie)
(keepAliveServerPeer server))
recvMsgKeepAlive

0 comments on commit 92c32a5

Please sign in to comment.