Skip to content

Commit

Permalink
Replace StrictSVar where possible by StrictMVar with NoThunks checks
Browse files Browse the repository at this point in the history
  • Loading branch information
jorisdral committed May 31, 2023
1 parent 65fc73a commit 5a6d730
Show file tree
Hide file tree
Showing 20 changed files with 129 additions and 92 deletions.
Expand Up @@ -10,6 +10,7 @@ import Cardano.Tools.DBAnalyser.HasAnalysis
import Cardano.Tools.DBAnalyser.Types
import Codec.CBOR.Decoding (Decoder)
import Codec.Serialise (Serialise (decode))
import Control.Concurrent.Class.MonadMVar.Strict.NoThunks
import Control.Monad.Except (runExceptT)
import Control.Tracer (Tracer (..), nullTracer)
import qualified Debug.Trace as Debug
Expand Down Expand Up @@ -52,7 +53,7 @@ analyse ::
-> IO (Maybe AnalysisResult)
analyse DBAnalyserConfig{analysis, confLimit, dbDir, selectDB, validation, verbose} args =
withRegistry $ \registry -> do
lock <- newSVar ()
lock <- newMVar ()
chainDBTracer <- mkTracer lock verbose
analysisTracer <- mkTracer lock True
ProtocolInfo { pInfoInitLedger = genesisLedger, pInfoConfig = cfg } <-
Expand Down Expand Up @@ -128,7 +129,7 @@ analyse DBAnalyserConfig{analysis, confLimit, dbDir, selectDB, validation, verbo
hPutStrLn stderr $ concat ["[", show diff, "] ", show ev]
hFlush stderr
where
withLock = bracket_ (takeSVar lock) (putSVar lock ())
withLock = bracket_ (takeMVar lock) (putMVar lock ())

immValidationPolicy = case (analysis, validation) of
(_, Just ValidateAllBlocks) -> ImmutableDB.ValidateAllChunks
Expand Down
Expand Up @@ -9,6 +9,7 @@ module Cardano.Tools.DBTruncater.Run (truncate) where
import Cardano.Slotting.Slot (WithOrigin (..))
import Cardano.Tools.DBAnalyser.HasAnalysis
import Cardano.Tools.DBTruncater.Types
import Control.Concurrent.Class.MonadMVar.Strict.NoThunks
import Control.Monad
import Control.Tracer
import Data.Functor.Identity
Expand Down Expand Up @@ -103,15 +104,15 @@ findNewTip target iter =
IteratorResult item -> do
if acceptable item then go (Just item) else pure acc

mkLock :: MonadSTM m => m (StrictSVar m ())
mkLock = newSVar ()
mkLock :: MonadMVar m => m (StrictMVar m ())
mkLock = newMVar ()

mkTracer :: Show a => StrictSVar IO () -> Bool -> IO (Tracer IO a)
mkTracer :: Show a => StrictMVar IO () -> Bool -> IO (Tracer IO a)
mkTracer _ False = pure mempty
mkTracer lock True = do
startTime <- getMonotonicTime
pure $ Tracer $ \ev -> do
bracket_ (takeSVar lock) (putSVar lock ()) $ do
bracket_ (takeMVar lock) (putMVar lock ()) $ do
traceTime <- getMonotonicTime
let diff = diffTime traceTime startTime
hPutStrLn stderr $ concat ["[", show diff, "] ", show ev]
Expand Down
Expand Up @@ -28,6 +28,7 @@ import qualified Cardano.Crypto.KES as Relative (Period)
import Cardano.Ledger.Crypto (Crypto)
import qualified Cardano.Ledger.Keys as SL
import qualified Cardano.Protocol.TPraos.OCert as Absolute (KESPeriod (..))
import Control.Concurrent.Class.MonadMVar.Strict.NoThunks
import Data.Word (Word64)
import GHC.Generics (Generic)
import GHC.Stack (HasCallStack)
Expand Down Expand Up @@ -173,13 +174,13 @@ mkHotKey ::
-> Word64 -- ^ Max KES evolutions
-> m (HotKey c m)
mkHotKey initKey startPeriod@(Absolute.KESPeriod start) maxKESEvolutions = do
varKESState <- newSVar initKESState
varKESState <- newMVar initKESState
return HotKey {
evolve = evolveKey varKESState
, getInfo = kesStateInfo <$> readSVar varKESState
, isPoisoned = kesKeyIsPoisoned . kesStateKey <$> readSVar varKESState
, getInfo = kesStateInfo <$> readMVar varKESState
, isPoisoned = kesKeyIsPoisoned . kesStateKey <$> readMVar varKESState
, sign_ = \toSign -> do
KESState { kesStateInfo, kesStateKey } <- readSVar varKESState
KESState { kesStateInfo, kesStateKey } <- readMVar varKESState
case kesStateKey of
KESKeyPoisoned -> error "trying to sign with a poisoned key"
KESKey key -> do
Expand Down Expand Up @@ -217,8 +218,8 @@ mkHotKey initKey startPeriod@(Absolute.KESPeriod start) maxKESEvolutions = do
-- When the key is poisoned, we always return 'UpdateFailed'.
evolveKey ::
forall m c. (Crypto c, IOLike m)
=> StrictSVar m (KESState c) -> Absolute.KESPeriod -> m KESEvolutionInfo
evolveKey varKESState targetPeriod = modifySVar varKESState $ \kesState -> do
=> StrictMVar m (KESState c) -> Absolute.KESPeriod -> m KESEvolutionInfo
evolveKey varKESState targetPeriod = modifyMVar varKESState $ \kesState -> do
let info = kesStateInfo kesState
-- We mask the evolution process because if we got interrupted after
-- calling 'forgetSignKeyKES', which destructively updates the current
Expand Down
1 change: 1 addition & 0 deletions ouroboros-consensus/ouroboros-consensus.cabal
Expand Up @@ -539,6 +539,7 @@ test-suite consensus-test
, random
, serialise
, si-timers
, strict-mvar
, tasty
, tasty-hunit
, tasty-quickcheck
Expand Down
Expand Up @@ -24,6 +24,7 @@ module Test.Util.LogicalClock (
, tickTracer
) where

import Control.Concurrent.Class.MonadMVar.Strict.NoThunks
import Control.Monad
import Control.Tracer (Tracer, contramapM)
import Data.Time (NominalDiffTime)
Expand Down Expand Up @@ -152,7 +153,7 @@ newWithDelay :: (IOLike m, HasCallStack)
-> m (LogicalClock m)
newWithDelay registry (NumTicks numTicks) tickLen = do
current <- newTVarIO 0
done <- newEmptySVar ()
done <- newEmptyMVar
_thread <- forkThread registry "ticker" $ do
-- Tick 0 is the first tick, so increment @numTicks - 1@ times
replicateM_ (fromIntegral numTicks - 1) $ do
Expand All @@ -163,11 +164,11 @@ newWithDelay registry (NumTicks numTicks) tickLen = do
-- Give tests that need to do some final processing on the last
-- tick a chance to do that before we indicate completion.
threadDelay (nominalDelay tickLen)
putSVar done ()
putMVar done ()

return LogicalClock {
getCurrentTick = Tick <$> readTVar current
, waitUntilDone = readSVar done
, waitUntilDone = readMVar done
, mockSystemTime = BTime.SystemTime {
BTime.systemTimeCurrent = do
tick <- atomically $ readTVar current
Expand Down
@@ -1,9 +1,13 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE NamedFieldPuns #-}
-- TODO: remove ScopedTypeVariables
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -Wno-orphans #-}
module Test.Util.Orphans.NoThunks () where


import Control.Concurrent.Class.MonadMVar.Strict.NoThunks
import Control.Monad.IOSim
import Control.Monad.ST.Lazy
import Control.Monad.ST.Unsafe (unsafeSTToIO)
Expand All @@ -16,3 +20,10 @@ instance NoThunks a => NoThunks (StrictSVar (IOSim s) a) where
wNoThunks ctxt StrictSVar { tvar } = do
a <- unsafeSTToIO $ lazyToStrictST $ inspectTVar (Proxy :: Proxy (IOSim s)) tvar
noThunks ctxt a

-- TODO: we need to be able to inspect the value inside the mvar a la MonadInspectSTM.
instance NoThunks a => NoThunks (StrictMVar (IOSim s) a) where
showTypeOf _ = "StrictMVar IOSim"
wNoThunks ctxt _v = do
a <- undefined :: IO a -- TODO
noThunks ctxt a
Expand Up @@ -11,6 +11,7 @@ module Ouroboros.Consensus.Mock.Node.Praos (

import Cardano.Crypto.KES
import Cardano.Crypto.VRF
import Control.Concurrent.Class.MonadMVar.Strict.NoThunks
import Data.Bifunctor (second)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
Expand Down Expand Up @@ -96,16 +97,17 @@ praosBlockForging ::
-> HotKey PraosMockCrypto
-> m (BlockForging m MockPraosBlock)
praosBlockForging cid initHotKey = do
varHotKey <- newSVar initHotKey
varHotKey <- newMVar initHotKey
return $ BlockForging {
forgeLabel = "praosBlockForging"
, canBeLeader = cid
, updateForgeState = \_ sno _ -> updateSVar varHotKey $
second forgeStateUpdateInfoFromUpdateInfo
, updateForgeState = \_ sno _ -> modifyMVar varHotKey $
pure
. second forgeStateUpdateInfoFromUpdateInfo
. evolveKey sno
, checkCanForge = \_ _ _ _ _ -> return ()
, forgeBlock = \cfg bno sno tickedLedgerSt txs isLeader -> do
hotKey <- readSVar varHotKey
hotKey <- readMVar varHotKey
return $
forgeSimple
(forgePraosExt hotKey)
Expand Down
Expand Up @@ -213,7 +213,7 @@ deriving instance PraosCrypto c => Show (HotKey c)
newtype HotKeyEvolutionError = HotKeyEvolutionError Period
deriving (Show)

-- | To be used in conjunction with, e.g., 'updateSVar'.
-- | To be used in conjunction with, e.g., 'modifyMVar'.
--
-- NOTE: when the key's period is after the target period, we shouldn't use
-- it, but we currently do. In real TPraos we check this in
Expand Down
Expand Up @@ -35,6 +35,7 @@ module Ouroboros.Consensus.Storage.ChainDB.Impl (
, openDBInternal
) where

import Control.Concurrent.Class.MonadMVar.Strict.NoThunks
import Control.Monad (when)
import Control.Monad.Trans.Class (lift)
import Control.Tracer
Expand Down Expand Up @@ -170,7 +171,7 @@ openDBInternal args launchBgTasks = runWithTempRegistry $ do
varFollowers <- newTVarIO Map.empty
varNextIteratorKey <- newTVarIO (IteratorKey 0)
varNextFollowerKey <- newTVarIO (FollowerKey 0)
varCopyLock <- newSVar ()
varCopyLock <- newMVar ()
varKillBgThreads <- newTVarIO $ return ()
blocksToAdd <- newBlocksToAdd (Args.cdbBlocksToAddSize args)

Expand Down
Expand Up @@ -37,6 +37,7 @@ module Ouroboros.Consensus.Storage.ChainDB.Impl.Background (
, addBlockRunner
) where

import Control.Concurrent.Class.MonadMVar.Strict.NoThunks
import Control.Exception (assert)
import Control.Monad (forM_, forever, void)
import Control.Tracer
Expand Down Expand Up @@ -197,8 +198,8 @@ copyToImmutableDB CDB{..} = withCopyLock $ do

withCopyLock :: forall a. HasCallStack => m a -> m a
withCopyLock = bracket_
(fmap mustBeUnlocked $ tryTakeSVar cdbCopyLock)
(putSVar cdbCopyLock ())
(fmap mustBeUnlocked $ tryTakeMVar cdbCopyLock)
(putMVar cdbCopyLock ())

mustBeUnlocked :: forall b. HasCallStack => Maybe b -> b
mustBeUnlocked = fromMaybe
Expand Down
Expand Up @@ -60,6 +60,7 @@ module Ouroboros.Consensus.Storage.ChainDB.Impl.Types (
, TraceValidationEvent (..)
) where

import Control.Concurrent.Class.MonadMVar.Strict.NoThunks
import Control.Tracer
import Data.Map.Strict (Map)
import Data.Maybe.Strict (StrictMaybe (..))
Expand Down Expand Up @@ -225,7 +226,7 @@ data ChainDbEnv m blk = CDB
-- not when hashes are garbage-collected from the map.
, cdbNextIteratorKey :: !(StrictTVar m IteratorKey)
, cdbNextFollowerKey :: !(StrictTVar m FollowerKey)
, cdbCopyLock :: !(StrictSVar m ())
, cdbCopyLock :: !(StrictMVar m ())
-- ^ Lock used to ensure that 'copyToImmutableDB' is not executed more than
-- once concurrently.
--
Expand Down

0 comments on commit 5a6d730

Please sign in to comment.