Skip to content

Commit

Permalink
connection-manager: clean connection shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
coot committed Jun 8, 2021
1 parent db19141 commit c4dfb0f
Show file tree
Hide file tree
Showing 7 changed files with 181 additions and 25 deletions.
1 change: 1 addition & 0 deletions ouroboros-network-framework/demo/connection-manager.hs
Expand Up @@ -227,6 +227,7 @@ withBidirectionalConnectionManager snocket socket
cmAddressType = \_ -> Just IPv4Address,
cmSnocket = snocket,
cmTimeWaitTimeout = timeWaitTimeout,
cmOutboundIdleTimeout = protocolIdleTimeout,
connectionDataFlow = const Duplex,
cmPrunePolicy = simplePrunePolicy,
cmConnectionsLimits = AcceptedConnectionsLimit {
Expand Down
Expand Up @@ -18,6 +18,7 @@ module Ouroboros.Network.ConnectionManager.Core
, withConnectionManager
, defaultTimeWaitTimeout
, defaultProtocolIdleTimeout
, defaultResetTimeout

, ConnectionState (..)
, abstractState
Expand All @@ -30,6 +31,7 @@ import Control.Monad.Class.MonadAsync
import Control.Monad.Class.MonadThrow hiding (handle)
import Control.Monad.Class.MonadTimer
import Control.Monad.Class.MonadSTM.Strict
import qualified Control.Monad.Class.MonadSTM as LazySTM
import Control.Tracer (Tracer, traceWith, contramap)
import Data.Foldable (traverse_)
import Data.Functor (($>))
Expand All @@ -42,6 +44,8 @@ import GHC.Stack (CallStack, HasCallStack, callStack)
import Data.Map (Map)
import qualified Data.Map as Map

import Data.Monoid.Synchronisation

import Network.Mux.Types (MuxMode)
import Network.Mux.Trace (MuxTrace, WithMuxBearer (..))

Expand Down Expand Up @@ -100,6 +104,11 @@ data ConnectionManagerArguments handlerTrace socket peerAddr handle handleError
--
cmTimeWaitTimeout :: DiffTime,

-- | Inactivity timeout before the connection will be reset. It is the
-- timeout attached to the 'OutboundIdleState'.
--
cmOutboundIdleTimeout :: DiffTime,

-- | @version@ represents the tuple of @versionNumber@ and
-- @agreedOptions@.
--
Expand Down Expand Up @@ -146,12 +155,28 @@ data ConnectionState peerAddr handle handleError version m =

-- | Either @OutboundState Duplex@ or @OutobundState^\tau Duplex@.
| OutboundDupState !(ConnectionId peerAddr) !(Async m ()) !handle !TimeoutExpired

-- | Before connection is reset it is put in 'OutboundIdleState' for the
-- duration of 'cmOutboundIdleTimeout'.
--
| OutboundIdleState !(ConnectionId peerAddr) !(Async m ()) !handle !DataFlow
| InboundIdleState !(ConnectionId peerAddr) !(Async m ()) !handle !DataFlow
| InboundState !(ConnectionId peerAddr) !(Async m ()) !handle !DataFlow
| DuplexState !(ConnectionId peerAddr) !(Async m ()) !handle
| TerminatingState !(ConnectionId peerAddr) !(Async m ()) !(Maybe handleError)
| TerminatedState !(Maybe handleError)


-- | Return 'True' for states in which the connection was already closed.
--
connectionTerminated :: ConnectionState peerAddr handle handleError version m
-> Bool
connectionTerminated TerminatingState {} = True
connectionTerminated TerminatedState {} = True
connectionTerminated _ = False



-- | Perform counting from an 'AbstractState'
connectionStateToCounters
:: ConnectionState peerAddr handle handleError version m
Expand All @@ -170,6 +195,7 @@ connectionStateToCounters state =
<> duplexConn
<> outgoingConn

OutboundIdleState _ _ _ _ -> mempty
InboundIdleState _ _ _ Unidirectional -> prunableConn
<> uniConn
<> incomingConn
Expand Down Expand Up @@ -230,6 +256,14 @@ instance ( Show peerAddr
, " "
, show expired
]
show (OutboundIdleState connId connThread _handle df) =
concat [ "OutboundIdleState "
, show connId
, " "
, show (asyncThreadId (Proxy :: Proxy m) connThread)
, " "
, show df
]
show (InboundIdleState connId connThread _handle df) =
concat ([ "InboundIdleState "
, show connId
Expand Down Expand Up @@ -270,6 +304,7 @@ getConnThread ReservedOutboundState = Nothin
getConnThread (UnnegotiatedState _pr _connId connThread) = Just connThread
getConnThread (OutboundUniState _connId connThread _handle ) = Just connThread
getConnThread (OutboundDupState _connId connThread _handle _te) = Just connThread
getConnThread (OutboundIdleState _connId connThread _handle _df) = Just connThread
getConnThread (InboundIdleState _connId connThread _handle _df) = Just connThread
getConnThread (InboundState _connId connThread _handle _df) = Just connThread
getConnThread (DuplexState _connId connThread _handle) = Just connThread
Expand All @@ -286,6 +321,7 @@ getConnType ReservedOutboundState = Nothing
getConnType (UnnegotiatedState pr _connId _connThread) = Just (UnnegotiatedConn pr)
getConnType (OutboundUniState _connId _connThread _handle) = Just (NegotiatedConn Outbound Unidirectional)
getConnType (OutboundDupState _connId _connThread _handle _te) = Just (NegotiatedConn Outbound Duplex)
getConnType (OutboundIdleState _connId _connThread _handle df) = Just (OutboundIdleConn df)
getConnType (InboundIdleState _connId _connThread _handle df) = Just (InboundIdleConn df)
getConnType (InboundState _connId _connThread _handle df) = Just (NegotiatedConn Inbound df)
getConnType (DuplexState _connId _connThread _handle) = Just DuplexConn
Expand All @@ -304,6 +340,7 @@ abstractState = \s -> case s of
go (UnnegotiatedState pr _ _) = UnnegotiatedSt pr
go (OutboundUniState _ _ _) = OutboundUniSt
go (OutboundDupState _ _ _ te) = OutboundDupSt te
go (OutboundIdleState _ _ _ df) = OutboundIdleSt df
go (InboundIdleState _ _ _ df) = InboundIdleSt df
go (InboundState _ _ _ df) = InboundSt df
go DuplexState {} = DuplexSt
Expand All @@ -322,6 +359,9 @@ defaultTimeWaitTimeout = 60
defaultProtocolIdleTimeout :: DiffTime
defaultProtocolIdleTimeout = 5

defaultResetTimeout :: DiffTime
defaultResetTimeout = 5


-- | A wedge product
-- <https://hackage.haskell.org/package/smash/docs/Data-Wedge.html#t:Wedge>
Expand All @@ -345,6 +385,7 @@ data DemoteToColdLocal peerAddr handlerTrace handle handleError version m
--
= DemotedToColdLocal (ConnectionId peerAddr)
(Async m ())
(StrictTVar m (ConnectionState peerAddr handle handleError version m))
!(Transition (ConnectionState peerAddr handle handleError version m))

-- | Any @DemoteToCold@ transition which does not terminate the connection, i.e.
Expand Down Expand Up @@ -415,6 +456,7 @@ withConnectionManager ConnectionManagerArguments {
cmAddressType,
cmSnocket,
cmTimeWaitTimeout,
cmOutboundIdleTimeout,
connectionDataFlow,
cmPrunePolicy,
cmConnectionsLimits
Expand Down Expand Up @@ -581,6 +623,9 @@ withConnectionManager ConnectionManagerArguments {
OutboundDupState {} -> do
writeTVar connVar (TerminatedState Nothing)
return $ There connState
OutboundIdleState {} -> do
writeTVar connVar (TerminatedState Nothing)
return $ There connState
InboundIdleState {} -> do
writeTVar connVar (TerminatedState Nothing)
return $ There connState
Expand Down Expand Up @@ -792,6 +837,13 @@ withConnectionManager ConnectionManagerArguments {
OutboundDupState _connId _connThread _handle _expired ->
throwSTM (withCallStack (ImpossibleState peerAddr))

OutboundIdleState _ _ _ dataFlow' -> do
let connState' = InboundIdleState
connId connThread handle
dataFlow'
writeTVar connVar connState'
return (mkTransition connState connState')

InboundIdleState {} ->
throwSTM (withCallStack (ImpossibleState peerAddr))

Expand Down Expand Up @@ -886,6 +938,14 @@ withConnectionManager ConnectionManagerArguments {
, Nothing
, UnsupportedState OutboundUniSt )

-- unexpected state, this state is reachable only from outbound
-- states
OutboundIdleState _connId _connThread _handle _dataFlow ->
assert False $
return ( Nothing
, Nothing
, OperationSuccess CommitTr )

-- @
-- Commit^{dataFlow} : InboundIdleState dataFlow
-- → TerminatingState
Expand Down Expand Up @@ -993,6 +1053,13 @@ withConnectionManager ConnectionManagerArguments {
(ConnectionExists provenance peerAddr))
)

OutboundIdleState _connId _connThread _handle _dataFlow ->
let tr = abstractState (Known connState) in
return ( Just (Right (TrForbiddenOperation peerAddr tr))
, connVar
, Left (withCallStack (ForbiddenOperation peerAddr tr))
)

InboundIdleState connId _connThread _handle Unidirectional -> do
return ( Just (Right (TrForbiddenConnection connId))
, connVar
Expand Down Expand Up @@ -1227,6 +1294,10 @@ withConnectionManager ConnectionManagerArguments {
OutboundDupState {} ->
throwSTM (withCallStack (ConnectionExists provenance connId))

OutboundIdleState _connId _connThread _handle _dataFlow ->
let tr = abstractState (Known connState) in
throwSTM (withCallStack (ForbiddenOperation peerAddr tr))

InboundIdleState _connId connThread handle dataFlow@Duplex -> do
-- @
-- Awake^{Duplex}_{Local} : InboundIdleState Duplex
Expand Down Expand Up @@ -1333,26 +1404,28 @@ withConnectionManager ConnectionManagerArguments {
(TrForbiddenOperation peerAddr st)
st

OutboundUniState connId connThread _handle -> do
OutboundUniState connId connThread handle -> do
-- @
-- DemotedToCold^{Unidirectional}_{Local}
-- : OutboundState Unidirectional
-- → TerminatingState
-- @
let connState' = TerminatingState connId connThread Nothing
let connState' = OutboundIdleState connId connThread handle
Unidirectional
writeTVar connVar connState'
return (DemotedToColdLocal connId connThread
return (DemotedToColdLocal connId connThread connVar
(mkTransition connState connState'))

OutboundDupState connId connThread _handle Expired -> do
OutboundDupState connId connThread handle Expired -> do
-- @
-- DemotedToCold^{Duplex}_{Local}
-- : OutboundState Duplex
-- → InboundIdleState^\tau
-- @
let connState' = TerminatingState connId connThread Nothing
let connState' = OutboundIdleState connId connThread handle
Duplex
writeTVar connVar connState'
return (DemotedToColdLocal connId connThread
return (DemotedToColdLocal connId connThread connVar
(mkTransition connState connState'))

OutboundDupState connId connThread handle Ticking -> do
Expand All @@ -1365,6 +1438,9 @@ withConnectionManager ConnectionManagerArguments {
writeTVar connVar connState'
return (DemoteToColdLocalNoop (mkTransition connState connState'))

OutboundIdleState _connId _connThread _handleError _dataFlow ->
return (DemoteToColdLocalNoop (mkTransition connState connState))

InboundIdleState _connId _connThread _handle dataFlow ->
assert (dataFlow == Duplex) $
return (DemoteToColdLocalNoop (mkTransition connState connState))
Expand Down Expand Up @@ -1429,14 +1505,39 @@ withConnectionManager ConnectionManagerArguments {

traceCounters stateVar
case transition of
DemotedToColdLocal _connId connThread tr -> do
DemotedToColdLocal connId connThread connVar tr -> do
traceWith trTracer (TransitionTrace peerAddr tr)
cancel connThread
-- We relay on the `finally` handler of connection thread to:
--
-- - close the socket,
-- - set the state to 'TerminatedState'
return (OperationSuccess (abstractState (fromState tr)))
timeoutVar <- registerDelay cmOutboundIdleTimeout
r <- atomically $ runFirstToFinish $
FirstToFinish (do connState <- readTVar connVar
check (case connState of
OutboundIdleState {} -> False
_ -> True
)
return (Left connState)
)
<> FirstToFinish (do b <- LazySTM.readTVar timeoutVar
check b
Right <$> readTVar connVar
)
case r of
Right connState -> do
let connState' = TerminatingState connId connThread Nothing
atomically $ writeTVar connVar connState'
traceWith trTracer (TransitionTrace peerAddr
(mkTransition connState connState'))
-- We relay on the `finally` handler of connection thread to:
--
-- - close the socket,
-- - set the state to 'TerminatedState'
cancel connThread
return (OperationSuccess (abstractState $ Known connState'))

Left connState | connectionTerminated connState
->
return (OperationSuccess (abstractState $ Known connState))
Left connState ->
return (UnsupportedState (abstractState $ Known connState))

PruneConnections _connId pruneMap tr -> do
traceWith trTracer (TransitionTrace peerAddr tr)
Expand Down Expand Up @@ -1483,6 +1584,21 @@ withConnectionManager ConnectionManagerArguments {
let connState' = DuplexState connId connThread handle
writeTVar connVar connState'
return (OperationSuccess (mkTransition connState connState'))
-- @
-- Awake^{Duplex}_{Remote} : OutboundIdleState Duplex
-- → InboundState Duplex
-- @
OutboundIdleState connId connThread handle dataFlow@Duplex -> do
-- @
-- Awake^{Duplex}_{Remote} : OutboundIdleState Duplex
-- → InboundState Duplex
-- @
let connState' = InboundState connId connThread handle dataFlow
writeTVar connVar connState
return (OperationSuccess (mkTransition connState connState'))
OutboundIdleState _connId _connThread _handle
dataFlow@Unidirectional ->
return (UnsupportedState (OutboundIdleSt dataFlow))
InboundIdleState connId connThread handle dataFlow -> do
-- @
-- Awake^{dataFlow}_{Remote} : InboundIdleState Duplex
Expand Down Expand Up @@ -1534,6 +1650,10 @@ withConnectionManager ConnectionManagerArguments {
return (UnsupportedState OutboundUniSt)
OutboundDupState _connId _connThread _handle expired ->
return (UnsupportedState (OutboundDupSt expired))
-- one can only enter 'OutboundIdleState' if remote state is
-- already cold.
OutboundIdleState _connId _connThread _handle dataFlow ->
return (UnsupportedState (OutboundIdleSt dataFlow))
InboundIdleState _connId _connThread _handle dataFlow ->
return (UnsupportedState (InboundIdleSt dataFlow))

Expand Down
Expand Up @@ -218,6 +218,10 @@ data ConnectionType
--
= UnnegotiatedConn !Provenance

-- | An outbound idle connection.
--
| OutboundIdleConn !DataFlow

-- | An inbound idle connection.
--
| InboundIdleConn !DataFlow
Expand Down Expand Up @@ -634,12 +638,14 @@ data AbstractState
| InboundSt !DataFlow
| OutboundUniSt
| OutboundDupSt !TimeoutExpired
| OutboundIdleSt !DataFlow
| DuplexSt
| WaitRemoteIdleSt
| TerminatingSt
| TerminatedSt
deriving (Eq, Show, Typeable)


-- | Counters for tracing and analysis purposes
--
data ConnectionManagerCounters = ConnectionManagerCounters {
Expand Down
Expand Up @@ -463,6 +463,7 @@ randomPrunePolicy stateVar mp n = do
case connType of
UnnegotiatedConn Outbound -> True
UnnegotiatedConn Inbound -> False
OutboundIdleConn _ -> True
InboundIdleConn _ -> False
NegotiatedConn Outbound _ -> True
NegotiatedConn Inbound _ -> False
Expand Down

0 comments on commit c4dfb0f

Please sign in to comment.