Skip to content

Commit

Permalink
net-sim: handle async exception in connect
Browse files Browse the repository at this point in the history
Connect has to blocking actions that can throw async exceptions:
* wait for connection delay
* wait for connection to be accepted

In the first case only async exceptions can be thrown, so we can
simplify the error handling with `onException`.

In the second case:

* if an async exception is caught when we are waiting for the connection
  to be accepted we simplify remove it from the known connection.
* on the other side, we let the `accept` skip over unknown connections.
  • Loading branch information
coot committed Oct 26, 2021
1 parent 91d9eca commit d409cef
Showing 1 changed file with 48 additions and 44 deletions.
92 changes: 48 additions & 44 deletions ouroboros-network-framework/src/Simulation/Network/Snocket.hs
Expand Up @@ -732,11 +732,9 @@ mkSnocket state tr = Snocket { getLocalAddr
(STBearerInfo bearerInfo))
-- connection delay
unmask (threadDelay (biConnectionDelay bearerInfo `min` connectTimeout))
-- Can receive not only AsyncException but also MuxError here!
`catch` \(e :: SomeException) -> do
atomically $ modifyTVar (nsConnections state)
(Map.delete (normaliseId connId))
throwIO e
`onException`
atomically (modifyTVar (nsConnections state)
(Map.delete $ normaliseId connId))

traceWith tr (WithAddr (Just (localAddress connId))
(Just remoteAddress)
Expand Down Expand Up @@ -810,37 +808,47 @@ mkSnocket state tr = Snocket { getLocalAddr
traceWith' fd (STDebug "connect: succesful open")
-- successful open

-- wait for a connection to be accepted
-- wait for a connection to be accepted; we can also be
-- interrupted by an asynchronous exception in which case we
-- just forget about the connection.
timeoutVar <-
registerDelay (connectTimeout - biConnectionDelay bearerInfo)
r <- unmask (atomically $ runFirstToFinish $
(FirstToFinish $ do
LazySTM.readTVar timeoutVar >>= check
return Nothing
)
<>
(FirstToFinish $ do
mbConn <- Map.lookup (normaliseId connId)
<$> readTVar (nsConnections state)
case mbConn of
-- it could happen that the 'accept' removes the
-- connection from the state; we treat this as an io
-- exception.
Nothing -> throwSTM $ connectIOError connId
$ "unknown connection: "
++ show (normaliseId connId)
Just Connection { connState } ->
Just <$> check (connState == ESTABLISHED))
)
`onException`
atomically (modifyTVar (nsConnections state)
(Map.delete (normaliseId connId)))
r <-
handleJust
(\e -> case fromException e of
Just SomeAsyncException {} -> Just e
Nothing -> Nothing)
(\e -> atomically $ modifyTVar (nsConnections state)
(Map.delete (normaliseId connId))
>> throwIO e)
$ unmask (atomically $ runFirstToFinish $
(FirstToFinish $ do
LazySTM.readTVar timeoutVar >>= check
modifyTVar (nsConnections state)
(Map.delete (normaliseId connId))
return Nothing
)
<>
(FirstToFinish $ do
mbConn <- Map.lookup (normaliseId connId)
<$> readTVar (nsConnections state)
case mbConn of
-- it could happen that the 'accept' removes the
-- connection from the state; we treat this as an io
-- exception.
Nothing -> do
modifyTVar (nsConnections state)
(Map.delete (normaliseId connId))
throwSTM $ connectIOError connId
$ "unknown connection: "
++ show (normaliseId connId)
Just Connection { connState } ->
Just <$> check (connState == ESTABLISHED))
)

case r of
Nothing -> do
traceWith' fd (STConnectTimeout WaitingToBeAccepted)
atomically $ modifyTVar (nsConnections state)
(Map.delete (normaliseId connId))
throwIO (connectIOError connId "connect timeout: when waiting for being accepted")
Just _ -> traceWith' fd (STConnected fd_' o)

Expand Down Expand Up @@ -966,16 +974,15 @@ mkSnocket state tr = Snocket { getLocalAddr
accept FD { fdVar } = pure accept_
where
-- non-blocking; return 'True' if a connection is in 'SYN_SENT' state
-- or if it was removed from simulation state.
synSentOrUnknown :: TestAddress addr
-> ChannelWithInfo m (TestAddress addr)
-> STM m Bool
synSentOrUnknown localAddress cwi = do
synSent :: TestAddress addr
-> ChannelWithInfo m (TestAddress addr)
-> STM m Bool
synSent localAddress cwi = do
connMap <- readTVar (nsConnections state)
let connId = ConnectionId localAddress (cwiAddress cwi)

case Map.lookup (normaliseId connId) connMap of
Nothing -> return True
Nothing -> return False
Just (Connection _ _ _ SYN_SENT) -> return True
_ -> return False

Expand Down Expand Up @@ -1004,14 +1011,11 @@ mkSnocket state tr = Snocket { getLocalAddr
)

FDListening localAddress queue -> do
-- We should not accept nor fail the 'accept' call
-- in the presence of a connection that is in
-- SYN_SENT state. So we take from the TBQueue
-- until we have found one that is __not__ in SYN_SENT
-- state.
cwi <- readTBQueueUntil
(synSentOrUnknown localAddress)
queue
-- We should not accept nor fail the 'accept' call in the
-- presence of a connection that is __not__ in SYN_SENT
-- state. So we take from the TBQueue until we have found
-- one that is SYN_SENT state.
cwi <- readTBQueueUntil (synSent localAddress) queue

let connId = ConnectionId localAddress (cwiAddress cwi)

Expand Down

0 comments on commit d409cef

Please sign in to comment.