Skip to content

Commit

Permalink
connection-manager: cancel threads using throwTo
Browse files Browse the repository at this point in the history
Also fix connection manager simulation.  The simulation environment
should not kill the connection thread with an asynchronous exception
when the connection handler throws an exception.
  • Loading branch information
coot committed Dec 3, 2021
1 parent 6f86a28 commit 2703d72
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 8 deletions.
Expand Up @@ -27,7 +27,7 @@ module Ouroboros.Network.ConnectionManager.Core

import Control.Exception (assert)
import Control.Monad (forM_, guard, when)
import Control.Monad.Class.MonadFork
import Control.Monad.Class.MonadFork (MonadFork, ThreadId, throwTo)
import Control.Monad.Class.MonadAsync
import Control.Monad.Class.MonadThrow hiding (handle)
import Control.Monad.Class.MonadTimer
Expand Down Expand Up @@ -517,6 +517,8 @@ withConnectionManager
:: forall (muxMode :: MuxMode) peerAddr socket handlerTrace handle handleError version m a.
( Monad m
, MonadLabelledSTM m
-- 'MonadFork' is only to get access to 'throwTo'
, MonadFork m
, MonadAsync m
, MonadEvaluate m
, MonadMask m
Expand Down Expand Up @@ -658,7 +660,7 @@ withConnectionManager ConnectionManagerArguments {
-- with the thread. We put each connection in 'TerminatedState' to
-- try that none of the connection threads will enter
-- 'TerminatingState' (and thus delay shutdown for 'tcp_WAIT_TIME'
-- seconds) when receiving the 'AsyncCancelled' exception. However,
-- seconds) when receiving the 'ThreadKilled' exception. However,
-- we can have a race between the finally handler and the `cleanup`
-- callback. If the finally block loses the race, the AsyncCancelled
-- received should interrupt the threadDelay.
Expand All @@ -675,7 +677,10 @@ withConnectionManager ConnectionManagerArguments {

when shouldTrace $
traceWith trTracer tr
traverse_ cancel (getConnThread connState)
-- using 'cancel' here, since we want to block until connection
-- handler thread terminates.
traverse_ (\thread -> cancel thread)
(getConnThread connState)
)
state
where
Expand Down Expand Up @@ -925,7 +930,14 @@ withConnectionManager ConnectionManagerArguments {
traceWith tracer (TrPruneConnections (Map.keysSet pruneMap)
numberToPrune
(Map.keysSet choiceMap))
forM_ pruneMap $ \(_, connThread', _) -> cancel connThread'
-- we don't block until the thread terminates, delivering the
-- async exception is enough (although in this case, there's no
-- difference, since we put the connection in 'TerminatedState'
-- which avoids the 'cmTimeWaitTimeout').
forM_ pruneMap $ \(_, connThread', _) ->
throwTo (asyncThreadId (Proxy :: Proxy m)
connThread')
AsyncCancelled
)

includeInboundConnectionImpl
Expand Down Expand Up @@ -1321,7 +1333,9 @@ withConnectionManager ConnectionManagerArguments {
traverse_ (traceWith trTracer . TransitionTrace peerAddr) mbTransition
traceCounters stateVar

traverse_ cancel mbThread
-- 'throwTo' avoids blocking until 'cmTimeWaitTimeout' expires.
traverse_ (flip throwTo AsyncCancelled . asyncThreadId (Proxy :: Proxy m))
mbThread

whenJust mbAssertion $ \tr -> do
traceWith tracer tr
Expand Down Expand Up @@ -2039,7 +2053,9 @@ withConnectionManager ConnectionManagerArguments {
--
-- - close the socket,
-- - set the state to 'TerminatedState'
cancel connThread
-- - 'throwTo' avoids blocking until 'cmTimeWaitTimeout' expires.
throwTo (asyncThreadId (Proxy :: Proxy m) connThread)
AsyncCancelled
return (OperationSuccess (abstractState $ Known connState'))

Left connState | connectionTerminated connState
Expand Down
Expand Up @@ -1993,7 +1993,7 @@ prop_connectionManagerSimulation (SkewedBool bindToLocalAddress) scheduleMap =

Right (Just (Disconnected {})) -> pure ()

Right (Just (Connected _ _ handle)) -> do
Right (Just (Connected _ _ _)) -> do
threadDelay (either id id (seActiveDelay conn))
-- if this outbound connection is not
-- executed within inbound connection,
Expand All @@ -2003,7 +2003,7 @@ prop_connectionManagerSimulation (SkewedBool bindToLocalAddress) scheduleMap =
-- 'unregisterOutboundConnection' can
-- block.
case seActiveDelay conn of
Left _ -> killThread (hThreadId handle)
Left _ -> pure ()
Right _ -> do
when ( not (siReused (seExtra conn))
&& seDataFlow conn == Duplex ) $
Expand Down

0 comments on commit 2703d72

Please sign in to comment.