Skip to content

Commit

Permalink
Implement ChainDB to fix header-body split
Browse files Browse the repository at this point in the history
  • Loading branch information
bolt12 committed Mar 27, 2023
1 parent 5235b3f commit 5f0b937
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 46 deletions.
1 change: 1 addition & 0 deletions ouroboros-network/ouroboros-network.cabal
Expand Up @@ -152,6 +152,7 @@ test-suite test
Test.Ouroboros.Network.Diffusion.Node
Test.Ouroboros.Network.Diffusion.Node.NodeKernel
Test.Ouroboros.Network.Diffusion.Node.MiniProtocols
Test.Ouroboros.Network.Diffusion.Node.ChainDB
Test.Ouroboros.Network.Diffusion.Policies
Test.Ouroboros.Network.BlockFetch
Test.Ouroboros.Network.KeepAlive
Expand Down
23 changes: 11 additions & 12 deletions ouroboros-network/test/Test/Ouroboros/Network/Diffusion/Node.hs
Expand Up @@ -65,7 +65,7 @@ import Ouroboros.Network.Mock.ConcreteBlock (Block (..),
import Ouroboros.Network.Mock.ProducerState (ChainProducerState (..))

import qualified Ouroboros.Network.AnchoredFragment as AF
import Ouroboros.Network.Block (MaxSlotNo (..), Point,
import Ouroboros.Network.Block (MaxSlotNo (..),
maxSlotNoFromWithOrigin, pointSlot)
import Ouroboros.Network.BlockFetch
import Ouroboros.Network.ConnectionManager.Types (DataFlow (..))
Expand Down Expand Up @@ -102,6 +102,8 @@ import Ouroboros.Network.Testing.Data.Script (Script (..))

import Simulation.Network.Snocket (AddressType (..), FD)

import Test.Ouroboros.Network.Diffusion.Node.ChainDB (addBlock,
getBlockPointSet)
import qualified Test.Ouroboros.Network.Diffusion.Node.MiniProtocols as Node
import qualified Test.Ouroboros.Network.Diffusion.Node.NodeKernel as Node
import Test.Ouroboros.Network.Diffusion.Node.NodeKernel
Expand Down Expand Up @@ -265,12 +267,10 @@ run blockGeneratorArgs limits ni na tracersExtra tracerBlockFetch =
blockFetch :: NodeKernel BlockHeader Block m
-> m Void
blockFetch nodeKernel = do
blockHeapVar <- LazySTM.newTVarIO Set.empty

blockFetchLogic
nullTracer
tracerBlockFetch
(blockFetchPolicy blockHeapVar nodeKernel)
(blockFetchPolicy nodeKernel)
(nkFetchClientRegistry nodeKernel)
(BlockFetchConfiguration {
bfcMaxConcurrencyBulkSync = 1,
Expand All @@ -280,25 +280,24 @@ run blockGeneratorArgs limits ni na tracersExtra tracerBlockFetch =
bfcSalt = 0
})

blockFetchPolicy :: LazySTM.TVar m (Set (Point Block))
-> NodeKernel BlockHeader Block m
blockFetchPolicy :: NodeKernel BlockHeader Block m
-> BlockFetchConsensusInterface NtNAddr BlockHeader Block m
blockFetchPolicy blockHeapVar nodeKernel =
blockFetchPolicy nodeKernel =
BlockFetchConsensusInterface {
readCandidateChains = readTVar (nkClientChains nodeKernel)
>>= traverse (readTVar
>=> (return . toAnchoredFragmentHeader)),
readCurrentChain = readTVar (nkChainProducerState nodeKernel)
>>= (return . toAnchoredFragmentHeader . chainState),
readFetchMode = return FetchModeBulkSync,
readFetchedBlocks = flip Set.member <$> LazySTM.readTVar blockHeapVar,
readFetchedBlocks = flip Set.member <$> getBlockPointSet (nkChainDB nodeKernel),
readFetchedMaxSlotNo = foldl' max NoMaxSlotNo .
map (maxSlotNoFromWithOrigin . pointSlot) .
Set.elems <$>
LazySTM.readTVar blockHeapVar,
mkAddFetchedBlock = \_enablePipelining -> do
pure $ \p _b ->
atomically (LazySTM.modifyTVar' blockHeapVar (Set.insert p)),
getBlockPointSet (nkChainDB nodeKernel),
mkAddFetchedBlock = \_enablePipelining ->
pure $ \_p b ->
atomically (addBlock b (nkChainDB nodeKernel)),

plausibleCandidateChain,
compareCandidateChains,
Expand Down
@@ -0,0 +1,107 @@
module Test.Ouroboros.Network.Diffusion.Node.ChainDB
( ChainDB (..)
, SelectChain (..)
, newChainDB
, addBlock
, getBlockPointSet
) where

import Control.Concurrent.Class.MonadSTM (MonadSTM (..))
import Data.Coerce (coerce)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Set (Set)
import qualified Data.Set as Set
import Ouroboros.Network.AnchoredFragment (Point)
import Ouroboros.Network.Block (ChainHash (..), HasFullHeader,
HasHeader, blockHash, blockPoint, blockPrevHash)
import Ouroboros.Network.Mock.Chain (Chain (..), selectChain)
import qualified Ouroboros.Network.Mock.Chain as Chain

-- | ChainDB is an in memory store of all fetched (downloaded) blocks.
--
-- This type holds an index mapping previous hashes to their blocks (i.e. if a
-- block "A" has prevHash "H" then the entry "H -> [A]" exists in the map) and
-- the current version of the longest chain.
--
-- Used to simulate real world ChainDB, it offers the invariant that
-- 'cdbLongestChainVar' is always the longest known chain of downloaded blocks.
-- Whenever a node generates a new block it gets added here, and whenever it gets
-- a block via block fetch it gets added here as well. Everytime 'addBlock' is
-- called the possibly new longest chain gets computed, since the API is atomic
-- we can guarantee that in each moment ChainDB has the current longest chain.
--
-- This type is used in diffusion simulation.
--
data ChainDB block m = ChainDB { cdbIndexVar :: TVar m (Map (ChainHash block) [block]),
cdbLongestChainVar :: TVar m (Chain block)
}

-- | Constructs a new ChainDB, the index has only 1 value which is the
-- 'GenesisHash' but this hash does not map to any block.
--
newChainDB :: MonadSTM m => m (ChainDB block m)
newChainDB = do
indexVar <- newTVarIO (Map.singleton GenesisHash [])
longestChain <- newTVarIO Genesis
return (ChainDB indexVar longestChain)

-- | Adds a block to ChainDB.
--
-- This function also recomputes the longest chain with the new block
-- information.
--
addBlock :: (MonadSTM m, HasFullHeader block)
=> block -> ChainDB block m -> STM m ()
addBlock block chainDB@(ChainDB indexVar lchainVar) = do
modifyTVar' indexVar $ \index ->
case Map.lookup (blockPrevHash block) index of
Nothing -> Map.insertWith (++) GenesisHash [block] index
Just _ -> Map.insertWith (++) (blockPrevHash block) [block] index
longestChain <- getLongestChain chainDB
writeTVar lchainVar longestChain

-- | Constructs the block Point set of all downloaded blocks
--
getBlockPointSet :: (MonadSTM m, HasHeader block)
=> ChainDB block m -> STM m (Set (Point block))
getBlockPointSet (ChainDB indexVar _) = do
index <- readTVar indexVar
return (foldMap (Set.fromList . map blockPoint) index)

-- | Computes the longest chain from Genesis
--
getLongestChain :: (HasHeader block, MonadSTM m)
=> ChainDB block m
-> STM m (Chain block)
getLongestChain (ChainDB indexVar _) = do
index <- readTVar indexVar
return (go Nothing Genesis index)
where
go :: HasHeader block
=> Maybe block
-> Chain block
-> Map (ChainHash block) [block]
-> Chain block
go mbblock chain m =
let hash = maybe GenesisHash (BlockHash . blockHash) mbblock
in case Map.lookup hash m of
Nothing -> maybe Genesis (`Chain.addBlock` chain) mbblock
Just blocks ->
let longestChain = getSelectedChain
$ foldMap (\b -> SelectChain
$ go (Just b) chain m)
blocks
in maybe longestChain (`Chain.addBlock` longestChain) mbblock

-- | Chain selection as a 'Monoid'.
--
newtype SelectChain block = SelectChain { getSelectedChain :: Chain block }

instance HasHeader block => Semigroup (SelectChain block) where
(<>) = (coerce :: ( Chain block -> Chain block -> Chain block)
-> SelectChain block -> SelectChain block -> SelectChain block)
selectChain

instance HasHeader block => Monoid (SelectChain block) where
mempty = SelectChain Genesis
Expand Up @@ -28,14 +28,14 @@ module Test.Ouroboros.Network.Diffusion.Node.NodeKernel

import GHC.Generics (Generic)

import qualified Control.Concurrent.Class.MonadSTM as LazySTM
import Control.Concurrent.Class.MonadSTM.Strict
import Control.Monad (replicateM, when)
import Control.Monad.Class.MonadAsync
import Control.Monad.Class.MonadThrow
import Control.Monad.Class.MonadTime
import Control.Monad.Class.MonadTimer
import qualified Data.ByteString.Char8 as BSC
import Data.Coerce (coerce)
import Data.Hashable (Hashable)
import Data.IP (IP (..), toIPv4, toIPv6)
import qualified Data.IP as IP
Expand All @@ -52,14 +52,13 @@ import Data.Monoid.Synchronisation
import Network.Socket (PortNumber)

import Ouroboros.Network.AnchoredFragment (Anchor (..))
import Ouroboros.Network.Block (HasHeader, SlotNo)
import Ouroboros.Network.Block (HasFullHeader, SlotNo)
import qualified Ouroboros.Network.Block as Block
import Ouroboros.Network.BlockFetch
import Ouroboros.Network.NodeToNode.Version (DiffusionMode (..))
import Ouroboros.Network.Protocol.Handshake.Unversioned
import Ouroboros.Network.Snocket (TestAddress (..))

import Ouroboros.Network.Mock.Chain (Chain)
import qualified Ouroboros.Network.Mock.Chain as Chain
import Ouroboros.Network.Mock.ConcreteBlock (Block)
import qualified Ouroboros.Network.Mock.ConcreteBlock as ConcreteBlock
Expand All @@ -70,6 +69,9 @@ import Simulation.Network.Snocket (AddressType (..),

import Test.Ouroboros.Network.Orphans ()

import Ouroboros.Network.Mock.Chain (Chain (..))
import qualified Test.Ouroboros.Network.Diffusion.Node.ChainDB as ChainDB
import Test.Ouroboros.Network.Diffusion.Node.ChainDB (ChainDB (..))
import Test.QuickCheck (Arbitrary (..), choose, chooseInt, frequency,
oneof)

Expand Down Expand Up @@ -175,14 +177,17 @@ data NodeKernel header block m = NodeKernel {
nkChainProducerState
:: StrictTVar m (ChainProducerState block),

nkFetchClientRegistry :: FetchClientRegistry NtNAddr header block m
nkFetchClientRegistry :: FetchClientRegistry NtNAddr header block m,

nkChainDB :: ChainDB block m
}

newNodeKernel :: MonadSTM m => m (NodeKernel header block m)
newNodeKernel = NodeKernel
<$> newTVarIO Map.empty
<*> newTVarIO (ChainProducerState Chain.Genesis Map.empty 0)
<*> newFetchClientRegistry
<*> ChainDB.newChainDB

-- | Register a new upstream chain-sync client.
--
Expand All @@ -191,7 +196,7 @@ registerClientChains :: MonadSTM m
-> NtNAddr
-> m (StrictTVar m (Chain block))
registerClientChains NodeKernel { nkClientChains } peerAddr = atomically $ do
chainVar <- newTVar Chain.Genesis
chainVar <- newTVar Genesis
modifyTVar nkClientChains (Map.insert peerAddr chainVar)
return chainVar

Expand Down Expand Up @@ -241,20 +246,6 @@ withSlotTime slotDuration k = do
atomically $ writeTVar slotVar next
go (succ next)


-- | Chain selection as a 'Monoid'.
--
newtype SelectChain block = SelectChain { getSelectedChain :: Chain block }

instance HasHeader block => Semigroup (SelectChain block) where
(<>) = (coerce :: ( Chain block -> Chain block -> Chain block)
-> SelectChain block -> SelectChain block -> SelectChain block)
Chain.selectChain

instance HasHeader block => Monoid (SelectChain block) where
mempty = SelectChain Chain.Genesis


-- | Node kernel erros.
--
data NodeKernelError = UnexpectedSlot !SlotNo !SlotNo
Expand All @@ -272,7 +263,7 @@ withNodeKernelThread
, MonadTimer m
, MonadThrow m
, MonadThrow (STM m)
, HasHeader block
, HasFullHeader block
)
=> BlockGeneratorArgs block seed
-> (NodeKernel header block m -> Async m Void -> m a)
Expand All @@ -288,7 +279,7 @@ withNodeKernelThread BlockGeneratorArgs { bgaSlotDuration, bgaBlockGenerator, bg
blockProducerThread :: NodeKernel header block m
-> (SlotNo -> STM m SlotNo)
-> m Void
blockProducerThread NodeKernel { nkClientChains, nkChainProducerState }
blockProducerThread NodeKernel { nkChainProducerState, nkChainDB }
waitForSlot
= loop (Block.SlotNo 1) bgaSeed
where
Expand Down Expand Up @@ -317,9 +308,21 @@ withNodeKernelThread BlockGeneratorArgs { bgaSlotDuration, bgaBlockGenerator, bg
(Just block, seed')
| Block.blockPoint block
>= Chain.headPoint chainState
-> let chainState' = Chain.addBlock block chainState in
-> do
-- Forged a block add it to our ChainDB this will
-- make the new block available for computing
-- longestChain
ChainDB.addBlock block nkChainDB

-- Get possibly new longest chain
longestChain <-
LazySTM.readTVar (cdbLongestChainVar nkChainDB)

-- Switch to it and update our current state so we
-- can serve other nodes through block fetch.
let cps' = switchFork longestChain cps
writeTVar nkChainProducerState
cps { chainState = chainState' }
cps' { chainState = longestChain }
>> return (succ nextSlot, seed')
(_, seed')
-> return (succ nextSlot, seed')
Expand All @@ -329,15 +332,24 @@ withNodeKernelThread BlockGeneratorArgs { bgaSlotDuration, bgaBlockGenerator, bg
-- chain selection
--
<> FirstToFinish
( do chains <- readTVar nkClientChains
>>= traverse readTVar
cps <- readTVar nkChainProducerState
let candidateChain = getSelectedChain
$ foldMap SelectChain chains
<> SelectChain (chainState cps)
cps' = switchFork candidateChain cps
check $ Chain.headPoint (chainState cps)
/= Chain.headPoint candidateChain
( do
-- Get our current chain
cps@ChainProducerState { chainState } <-
readTVar nkChainProducerState
-- Get what ChainDB sees as the longest chain
longestChain <-
LazySTM.readTVar (cdbLongestChainVar nkChainDB)

-- Only update the chain if it's different than our current
-- one, else retry
check $ Chain.headPoint chainState
/= Chain.headPoint longestChain

-- If it's different, switch to it and update our current
-- state so we can serve other nodes through block fetch.
let cps' = switchFork longestChain cps
writeTVar nkChainProducerState
cps' { chainState = longestChain }
writeTVar nkChainProducerState cps'
-- do not update 'nextSlot'; This stm branch might run
-- multiple times within the current slot.
Expand Down
16 changes: 14 additions & 2 deletions ouroboros-network/test/Test/Ouroboros/Network/Testnet.hs
Expand Up @@ -70,7 +70,7 @@ import Test.QuickCheck
import Test.Tasty
import Test.Tasty.QuickCheck (testProperty)

import Ouroboros.Network.BlockFetch (TraceFetchClientState)
import Ouroboros.Network.BlockFetch (TraceFetchClientState (..))
import Ouroboros.Network.Mock.ConcreteBlock (BlockHeader)
import Ouroboros.Network.NodeToNode (DiffusionMode (..))
import TestLib.ConnectionManager (abstractStateIsFinalTransition,
Expand Down Expand Up @@ -349,12 +349,24 @@ prop_fetch_client_state_trace_coverage defaultBearerInfo diffScript =
. traceEvents
$ runSimTrace sim

transitionsSeenNames = map show events
transitionsSeenNames = map traceFetchClientStateMap events

-- TODO: Add checkCoverage here
in tabulate "fetch client state trace" transitionsSeenNames
True
where
traceFetchClientStateMap :: TraceFetchClientState BlockHeader
-> String
traceFetchClientStateMap AddedFetchRequest{} = "AddedFetchRequest"
traceFetchClientStateMap AcknowledgedFetchRequest{} =
"AcknowledgedFetchRequest"
traceFetchClientStateMap SendFetchRequest{} = "SendFetchRequest"
traceFetchClientStateMap StartedFetchBatch{} = "StartedFetchBatch"
traceFetchClientStateMap CompletedBlockFetch{} = "CompletedBlockFetch"
traceFetchClientStateMap CompletedFetchBatch{} = "CompletedFetchBatch"
traceFetchClientStateMap RejectedFetchBatch{} = "RejectedFetchBatch"
traceFetchClientStateMap (ClientTerminating n) = "ClientTerminating "
++ show n

-- | Unit test which covers issue #4177
--
Expand Down

0 comments on commit 5f0b937

Please sign in to comment.