Skip to content

Commit

Permalink
Fix Snocket bug in the accept/connect
Browse files Browse the repository at this point in the history
Basically, if I connect  to someone and someone connects to me, before the connect returns (and before the remote accept returns as well) the local accept can return first masking itself as the remote one because we have no way to distinguish directions.
  • Loading branch information
bolt12 committed Oct 26, 2021
1 parent cb23efe commit 29dcdfe
Showing 1 changed file with 20 additions and 9 deletions.
29 changes: 20 additions & 9 deletions ouroboros-network-framework/src/Simulation/Network/Snocket.hs
Expand Up @@ -95,7 +95,7 @@ stepScriptSTM scriptVar = do
return x


data Connection m = Connection
data Connection m addr = Connection
{ -- | Attenuated channels of a connection.
--
connChannelLocal :: !(AttenuatedChannel m)
Expand All @@ -109,6 +109,10 @@ data Connection m = Connection
-- open.
--
, connState :: !OpenState

-- | Provider of this Connection, so one can know its origin and decide
-- accordingly when accepting/connecting a connection.
, connProvider :: !addr
}


Expand All @@ -125,7 +129,7 @@ data OpenState
deriving (Eq, Show)


dualConnection :: Connection m -> Connection m
dualConnection :: Connection m addr -> Connection m addr
dualConnection conn@Connection { connChannelLocal, connChannelRemote } =
conn { connChannelLocal = connChannelRemote
, connChannelRemote = connChannelLocal
Expand All @@ -142,7 +146,7 @@ mkConnection :: ( MonadLabelledSTM m
(SnocketTrace m (TestAddress addr)))
-> BearerInfo
-> ConnectionId (TestAddress addr)
-> STM m (Connection m)
-> STM m (Connection m (TestAddress addr))
mkConnection tr bearerInfo connId@ConnectionId { localAddress, remoteAddress } = do
(channelLocal, channelRemote) <-
newConnectedAttenuatedChannelPair
Expand All @@ -169,6 +173,7 @@ mkConnection tr bearerInfo connId@ConnectionId { localAddress, remoteAddress } =
channelRemote
(biSDUSize bearerInfo)
HalfOpened
localAddress


-- | Connection id independent of who provisioned the connection. 'NormalisedId'
Expand Down Expand Up @@ -202,7 +207,7 @@ data NetworkState m addr = NetworkState {

-- | Registry of active connections.
--
nsConnections :: StrictTVar m (Map (NormalisedId addr) (Connection m)),
nsConnections :: StrictTVar m (Map (NormalisedId addr) (Connection m addr)),

-- | Get an unused ephemeral address.
--
Expand Down Expand Up @@ -440,15 +445,15 @@ data FD_ m addr
-- assigned to it.
--
| FDConnecting !(ConnectionId addr)
!(Connection m)
!(Connection m addr)

-- | 'FD_' for snockets in connected state.
--
-- 'FDConnected' is created by either 'connect' or 'accept'.
| FDConnected
!(ConnectionId addr)
-- ^ local and remote addresses
!(Connection m)
!(Connection m addr)
-- ^ connection

-- | 'FD_' of a closed file descriptor; we keep 'ConnectionId' just for
Expand Down Expand Up @@ -963,9 +968,14 @@ mkSnocket state tr = Snocket { getLocalAddr
let connId = ConnectionId localAddress (cwiAddress cwi)

case Map.lookup (normaliseId connId) connMap of
Nothing -> return True
Just (Connection _ _ _ HalfOpened) -> return True
_ -> return False
Nothing ->
return True
Just (Connection _ _ _ HalfOpened provider) ->
return ( provider /= localAddress
|| localAddress == cwiAddress cwi
)
_ ->
return False

accept_ = Accept $ \unmask -> do
bracketOnError
Expand Down Expand Up @@ -1046,6 +1056,7 @@ mkSnocket state tr = Snocket { getLocalAddr
, connChannelRemote = channelRemote
, connSDUSize = sduSize
, connState = Established
, connProvider = remoteAddress
})

traceWith tr (WithAddr (Just (localAddress connId)) Nothing
Expand Down

0 comments on commit 29dcdfe

Please sign in to comment.