Skip to content

Commit

Permalink
Create persistence handle inside of network code
Browse files Browse the repository at this point in the history
  • Loading branch information
v0d1ch committed Oct 9, 2023
1 parent b60365d commit 8442b26
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 65 deletions.
5 changes: 1 addition & 4 deletions hydra-node/exe/hydra-node/Main.hs
Expand Up @@ -103,10 +103,7 @@ main = do
withAPIServer apiHost apiPort party apiPersistence (contramap APIServer tracer) chain pparams (putEvent . ClientEvent) $ \server -> do

-- Network
msgPersistence <- createPersistenceIncremental $ persistenceDir <> "/network-messages"
ackPersistence <- createPersistence $ persistenceDir <> "/acks"

withNetwork tracer msgPersistence ackPersistence (connectionMessages server) signingKey otherParties host port peers nodeId putNetworkEvent $ \hn -> do
withNetwork tracer persistenceDir (connectionMessages server) signingKey otherParties host port peers nodeId putNetworkEvent $ \hn -> do
-- Main loop
runHydraNode (contramap Node tracer) $
HydraNode
Expand Down
59 changes: 46 additions & 13 deletions hydra-node/src/Hydra/Network/Reliability.hs
Expand Up @@ -74,6 +74,7 @@ import Data.Vector (
fromList,
generate,
length,
replicate,
zipWith,
(!?),
)
Expand All @@ -82,6 +83,7 @@ import Hydra.Network (Network (..), NetworkComponent)
import Hydra.Network.Authenticate (Authenticated (..))
import Hydra.Network.Heartbeat (Heartbeat (..), isPing)
import Hydra.Party (Party)
import Hydra.Persistence (Persistence (..), PersistenceIncremental (..))
import Test.QuickCheck (getPositive, listOf)

data ReliableMsg msg = ReliableMsg
Expand Down Expand Up @@ -128,13 +130,43 @@ instance Arbitrary ReliabilityLog where
-- | Handle for all persistence operations in the Reliability network layer.
-- This handle takes care of storing and retreiving vector clock and all
-- messages.
data MessagePersistence m msg =
MessagePersistence
{ loadAcks :: m (Vector Int)
, saveAcks :: Vector Int -> m ()
, loadMessages :: m [Heartbeat msg]
, appendMessage :: Heartbeat msg -> m ()
}
data MessagePersistence m msg = MessagePersistence
{ loadAcks :: m (Vector Int)
, saveAcks :: Vector Int -> m ()
, loadMessages :: m [Heartbeat msg]
, appendMessage :: Heartbeat msg -> m ()
}

-- | Create 'MessagePersistence' out of 'PersistenceIncremental' and
-- 'Persistence' handles. This handle loads and saves acks (vector clock data)
-- and can load and append network messages.
-- On start we construct empty ack vector from all parties in case nothing
-- is stored on disk.
-- NOTE: This handle is returned in the underlying context just for the sake of
-- convenience.
mkMessagePersistence ::
(MonadThrow m, FromJSON msg, ToJSON msg) =>
Vector a ->
m (PersistenceIncremental (Heartbeat msg) m) ->
m (Persistence (Vector Int) m) ->
m (MessagePersistence m msg)
mkMessagePersistence allParties msgPersistence' ackPersistence' = do
msgPersistence <- msgPersistence'
ackPersistence <- ackPersistence'
pure
MessagePersistence
{ loadAcks = do
macks <- load ackPersistence
case macks of
Nothing -> pure $ replicate (length allParties) 0
Just acks -> pure acks
, saveAcks = \acks -> do
save ackPersistence acks
, loadMessages = do
loadAll msgPersistence
, appendMessage = \msg -> do
append msgPersistence msg
}

-- | Middleware function to handle message counters tracking and resending logic.
--
Expand All @@ -148,24 +180,25 @@ withReliability ::
(MonadThrow (STM m), MonadThrow m, MonadAsync m) =>
-- | Tracer for logging messages.
Tracer m ReliabilityLog ->
MessagePersistence m msg ->
-- | Our persistence handle
m (MessagePersistence m msg) ->
-- | Our own party identifier.
Party ->
-- | All parties' identifiers.
Vector Party ->
-- | Underlying network component providing consuming and sending channels.
NetworkComponent m (Authenticated (ReliableMsg (Heartbeat msg))) (ReliableMsg (Heartbeat msg)) a ->
NetworkComponent m (Authenticated (Heartbeat msg)) (Heartbeat msg) a
withReliability tracer msgPersistence@MessagePersistence{loadAcks, saveAcks} me allParties withRawNetwork callback action = do
withReliability tracer msgPersistence' me allParties withRawNetwork callback action = do
msgPersistence <- msgPersistence'
resendQ <- newTQueueIO
let ourIndex = fromMaybe (error "This cannot happen because we constructed the list with our party inside.") (findPartyIndex me)
let resend = writeTQueue resendQ
withRawNetwork (reliableCallback resend ourIndex) $ \network@Network{broadcast} -> do
withRawNetwork (reliableCallback msgPersistence resend ourIndex) $ \network@Network{broadcast} -> do
withAsync (forever $ atomically (readTQueue resendQ) >>= broadcast) $ \_ ->
reliableBroadcast ourIndex msgPersistence network
where

reliableBroadcast ourIndex MessagePersistence{appendMessage} Network{broadcast} =
reliableBroadcast ourIndex MessagePersistence{appendMessage, loadAcks, saveAcks} Network{broadcast} =
action $
Network
{ broadcast = \msg ->
Expand All @@ -187,7 +220,7 @@ withReliability tracer msgPersistence@MessagePersistence{loadAcks, saveAcks} me
saveAcks newAcks
pure newAcks

reliableCallback resend ourIndex (Authenticated (ReliableMsg acks msg) party) = do
reliableCallback msgPersistence@MessagePersistence{loadAcks, saveAcks} resend ourIndex (Authenticated (ReliableMsg acks msg) party) = do
if length acks /= length allParties
then pure ()
else do
Expand Down
36 changes: 13 additions & 23 deletions hydra-node/src/Hydra/Node/Network.hs
Expand Up @@ -66,16 +66,16 @@ module Hydra.Node.Network (withNetwork, withFlipHeartbeats) where
import Hydra.Prelude hiding (fromList, replicate)

import Control.Tracer (Tracer)
import Data.Vector (fromList, replicate)
import Hydra.Crypto (HydraKey, SigningKey)
import Hydra.Logging.Messages (HydraLog (..))
import Hydra.Network (Host (..), IP, NetworkComponent, NodeId, PortNumber)
import Hydra.Network.Authenticate (Authenticated (Authenticated), Signed, withAuthentication)
import Hydra.Network.Heartbeat (ConnectionMessages, Heartbeat (..), withHeartbeat)
import Hydra.Network.Ouroboros (TraceOuroborosNetwork, WithHost, withOuroborosNetwork)
import Hydra.Network.Reliability (ReliableMsg, withReliability, MessagePersistence (..))
import Hydra.Network.Reliability (MessagePersistence (..), ReliableMsg, mkMessagePersistence, withReliability)
import Hydra.Party (Party, deriveParty)
import Hydra.Persistence (PersistenceIncremental (..), Persistence (load), save)
import Data.Vector (Vector, replicate, fromList)
import Hydra.Persistence (Persistence (load), PersistenceIncremental (..), createPersistence, createPersistenceIncremental, save)

-- | An alias for logging messages output by network component.
-- The type is made complicated because the various subsystems use part of the tracer only.
Expand All @@ -85,10 +85,8 @@ withNetwork ::
(ToCBOR msg, ToJSON msg, FromJSON msg, FromCBOR msg) =>
-- | Tracer to use for logging messages.
Tracer IO (LogEntry tx msg) ->
-- | Persistence handle to store messages
PersistenceIncremental (Heartbeat msg) IO ->
-- | Persistence handle to store acks
Persistence (Vector Int) IO ->
-- | Persistence directory
FilePath ->
-- | Callback/observer for connectivity changes in peers.
ConnectionMessages IO ->
-- | This node's signing key. This is used to sign messages sent to peers.
Expand All @@ -105,27 +103,19 @@ withNetwork ::
NodeId ->
-- | Produces a `NetworkComponent` that can send `msg` and consumes `Authenticated` @msg@.
NetworkComponent IO (Authenticated msg) msg ()
withNetwork tracer msgPersistence ackPersistence connectionMessages signingKey otherParties host port peers nodeId = do
withNetwork tracer persistenceDir connectionMessages signingKey otherParties host port peers nodeId = do
let localhost = Host{hostname = show host, port}
me = deriveParty signingKey
-- construct sorted vector of parties including ourselves
allParties = fromList $ sort $ me : otherParties
messagePersistence =
MessagePersistence
{ loadAcks = do
macks <- load ackPersistence
case macks of
Nothing -> pure $ replicate (length (me : otherParties)) 0
Just acks -> pure acks
, saveAcks = save ackPersistence
, loadMessages = loadAll msgPersistence
, appendMessage = append msgPersistence
}
msgPersistence = createPersistenceIncremental $ persistenceDir <> "/network-messages"
ackPersistence = createPersistence $ persistenceDir <> "/acks"
messagePersistence = mkMessagePersistence allParties msgPersistence ackPersistence
withHeartbeat nodeId connectionMessages $
withFlipHeartbeats $
withReliability (contramap Reliability tracer) messagePersistence me allParties $
withAuthentication (contramap Authentication tracer) signingKey otherParties $
withOuroborosNetwork (contramap Network tracer) localhost peers
withFlipHeartbeats $
withReliability (contramap Reliability tracer) messagePersistence me allParties $
withAuthentication (contramap Authentication tracer) signingKey otherParties $
withOuroborosNetwork (contramap Network tracer) localhost peers

withFlipHeartbeats ::
NetworkComponent m (Authenticated (Heartbeat msg)) msg1 a ->
Expand Down
45 changes: 20 additions & 25 deletions hydra-node/test/Hydra/Network/ReliabilitySpec.hs
Expand Up @@ -2,7 +2,7 @@

module Hydra.Network.ReliabilitySpec where

import Hydra.Prelude hiding (empty, fromList, head, unlines, replicate)
import Hydra.Prelude hiding (empty, fromList, head, replicate, unlines)
import Test.Hydra.Prelude

import Control.Concurrent.Class.MonadSTM (
Expand All @@ -16,7 +16,7 @@ import Control.Concurrent.Class.MonadSTM (
import Control.Monad.IOSim (runSimOrThrow)
import Control.Tracer (Tracer (..), nullTracer)
import Data.Sequence.Strict ((|>))
import Data.Vector (Vector, empty, fromList, head, snoc, replicate)
import Data.Vector (Vector, empty, fromList, head, replicate, snoc)
import qualified Data.Vector as Vector
import Hydra.Network (Network (..))
import Hydra.Network.Authenticate (Authenticated (..))
Expand Down Expand Up @@ -99,9 +99,8 @@ spec = parallel $ do
prop "broadcast messages to the network assigning a sequential id" $ \(messages :: [String]) ->
let sentMsgs = runSimOrThrow $ do
sentMessages <- newTVarIO empty
messagePersistence <- mockMessagePersistence 1

withReliability nullTracer messagePersistence alice (fromList [alice]) (captureOutgoing sentMessages) noop $ \Network{broadcast} -> do
withReliability nullTracer (mockMessagePersistence 1) alice (fromList [alice]) (captureOutgoing sentMessages) noop $ \Network{broadcast} -> do
mapM_ (broadcast . Data "node-1") messages

fromList . Vector.toList <$> readTVarIO sentMessages
Expand All @@ -119,16 +118,14 @@ spec = parallel $ do
randomSeed <- newTVarIO $ mkStdGen seed
aliceToBob <- newTQueueIO
bobToAlice <- newTQueueIO
aliceMessagePersistence <- mockMessagePersistence 2
bobMessagePersistence <- mockMessagePersistence 2
let
-- this is a NetworkComponent that broadcasts authenticated messages
-- mediated through a read and a write TQueue but drops 0.2 % of them
aliceFailingNetwork = failingNetwork randomSeed alice (bobToAlice, aliceToBob)
bobFailingNetwork = failingNetwork randomSeed bob (aliceToBob, bobToAlice)

bobReliabilityStack = reliabilityStack aliceMessagePersistence bobFailingNetwork emittedTraces "bob" bob (fromList [alice, bob])
aliceReliabilityStack = reliabilityStack bobMessagePersistence aliceFailingNetwork emittedTraces "alice" alice (fromList [alice, bob])
bobReliabilityStack = reliabilityStack (mockMessagePersistence 2) bobFailingNetwork emittedTraces "bob" bob (fromList [alice, bob])
aliceReliabilityStack = reliabilityStack (mockMessagePersistence 2) aliceFailingNetwork emittedTraces "alice" alice (fromList [alice, bob])

runAlice = runPeer aliceReliabilityStack "alice" messagesReceivedByAlice messagesReceivedByBob aliceToBobMessages bobToAliceMessages
runBob = runPeer bobReliabilityStack "bob" messagesReceivedByBob messagesReceivedByAlice bobToAliceMessages aliceToBobMessages
Expand All @@ -150,10 +147,9 @@ spec = parallel $ do
it "broadcast updates counter from peers" $ do
let receivedMsgs = runSimOrThrow $ do
sentMessages <- newTVarIO empty
messagePersistence <- mockMessagePersistence 2
withReliability
nullTracer
messagePersistence
(mockMessagePersistence 2)
alice
(fromList [alice, bob])
( \incoming action -> do
Expand All @@ -170,21 +166,22 @@ spec = parallel $ do
receivedMsgs `shouldBe` [ReliableMsg (fromList [1, 1]) (Data "node-1" msg)]

it "appends messages to disk and can load them back" $ do
withTempDir "" $ \tmpDir -> do
withTempDir "network-messages-persistence" $ \tmpDir -> do
Persistence{load, save} <- createPersistence $ tmpDir <> "/acks"
PersistenceIncremental{loadAll, append} <- createPersistenceIncremental $ tmpDir <> "/network-messages"

let messagePersistence =
MessagePersistence
{ loadAcks = do
mloaded <- load
case mloaded of
Nothing -> pure $ replicate (length [alice, bob]) 0
Just acks -> pure acks
, saveAcks = save
, loadMessages = loadAll
, appendMessage = append
}
pure
MessagePersistence
{ loadAcks = do
mloaded <- load
case mloaded of
Nothing -> pure $ replicate (length [alice, bob]) 0
Just acks -> pure acks
, saveAcks = save
, loadMessages = loadAll
, appendMessage = append
}

receivedMsgs <- do
sentMessages <- newTVarIO empty
Expand Down Expand Up @@ -253,15 +250,14 @@ noop = const $ pure ()
aliceReceivesMessages :: [Authenticated (ReliableMsg (Heartbeat msg))] -> [Authenticated (Heartbeat msg)]
aliceReceivesMessages messages = runSimOrThrow $ do
receivedMessages <- newTVarIO empty
messagePersistence <- mockMessagePersistence 3
let baseNetwork incoming _ = mapM incoming messages

aliceReliabilityStack =
withReliability
nullTracer
messagePersistence
(mockMessagePersistence 3)
alice
(fromList [alice , bob, carol])
(fromList [alice, bob, carol])
baseNetwork

void $ aliceReliabilityStack (captureIncoming receivedMessages) $ \_action ->
Expand Down Expand Up @@ -302,4 +298,3 @@ mockMessagePersistence numberOfParties = do
, loadMessages = toList <$> readTVarIO messages
, appendMessage = \msg -> atomically $ modifyTVar' messages (|> msg)
}

0 comments on commit 8442b26

Please sign in to comment.